package middleware import ( "context" "net/http" "github.com/gorilla/sessions" ) type contextKey string const ( ContextKeyUserID contextKey = "userID" ContextKeyUsername contextKey = "username" ContextKeyIsAdmin contextKey = "isAdmin" ) type AuthMiddleware struct { store sessions.Store } func NewAuth(store sessions.Store) *AuthMiddleware { return &AuthMiddleware{store: store} } func (a *AuthMiddleware) Require(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { session, err := a.store.Get(r, "fb_session") if err != nil || session.IsNew { http.Error(w, "unauthorized", http.StatusUnauthorized) return } userID, ok := session.Values["userID"].(int64) if !ok || userID == 0 { http.Error(w, "unauthorized", http.StatusUnauthorized) return } ctx := context.WithValue(r.Context(), ContextKeyUserID, userID) if username, ok := session.Values["username"].(string); ok { ctx = context.WithValue(ctx, ContextKeyUsername, username) } if isAdmin, ok := session.Values["isAdmin"].(bool); ok { ctx = context.WithValue(ctx, ContextKeyIsAdmin, isAdmin) } next.ServeHTTP(w, r.WithContext(ctx)) }) } func (a *AuthMiddleware) Optional(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { session, err := a.store.Get(r, "fb_session") if err == nil && !session.IsNew { if userID, ok := session.Values["userID"].(int64); ok && userID != 0 { ctx := context.WithValue(r.Context(), ContextKeyUserID, userID) r = r.WithContext(ctx) } } next.ServeHTTP(w, r) }) } func UserIDFromContext(ctx context.Context) (int64, bool) { id, ok := ctx.Value(ContextKeyUserID).(int64) return id, ok }