first round of files
This commit is contained in:
@@ -0,0 +1,68 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
|
||||
"github.com/gorilla/sessions"
|
||||
)
|
||||
|
||||
type contextKey string
|
||||
|
||||
const (
|
||||
ContextKeyUserID contextKey = "userID"
|
||||
ContextKeyUsername contextKey = "username"
|
||||
ContextKeyIsAdmin contextKey = "isAdmin"
|
||||
)
|
||||
|
||||
type AuthMiddleware struct {
|
||||
store sessions.Store
|
||||
}
|
||||
|
||||
func NewAuth(store sessions.Store) *AuthMiddleware {
|
||||
return &AuthMiddleware{store: store}
|
||||
}
|
||||
|
||||
func (a *AuthMiddleware) Require(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
session, err := a.store.Get(r, "fb_session")
|
||||
if err != nil || session.IsNew {
|
||||
http.Error(w, "unauthorized", http.StatusUnauthorized)
|
||||
return
|
||||
}
|
||||
|
||||
userID, ok := session.Values["userID"].(int64)
|
||||
if !ok || userID == 0 {
|
||||
http.Error(w, "unauthorized", http.StatusUnauthorized)
|
||||
return
|
||||
}
|
||||
|
||||
ctx := context.WithValue(r.Context(), ContextKeyUserID, userID)
|
||||
if username, ok := session.Values["username"].(string); ok {
|
||||
ctx = context.WithValue(ctx, ContextKeyUsername, username)
|
||||
}
|
||||
if isAdmin, ok := session.Values["isAdmin"].(bool); ok {
|
||||
ctx = context.WithValue(ctx, ContextKeyIsAdmin, isAdmin)
|
||||
}
|
||||
|
||||
next.ServeHTTP(w, r.WithContext(ctx))
|
||||
})
|
||||
}
|
||||
|
||||
func (a *AuthMiddleware) Optional(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
session, err := a.store.Get(r, "fb_session")
|
||||
if err == nil && !session.IsNew {
|
||||
if userID, ok := session.Values["userID"].(int64); ok && userID != 0 {
|
||||
ctx := context.WithValue(r.Context(), ContextKeyUserID, userID)
|
||||
r = r.WithContext(ctx)
|
||||
}
|
||||
}
|
||||
next.ServeHTTP(w, r)
|
||||
})
|
||||
}
|
||||
|
||||
func UserIDFromContext(ctx context.Context) (int64, bool) {
|
||||
id, ok := ctx.Value(ContextKeyUserID).(int64)
|
||||
return id, ok
|
||||
}
|
||||
Reference in New Issue
Block a user