package middleware import ( "context" "net/http" "strings" "github.com/gorilla/sessions" "xorm.io/xorm" ) type contextKey string const ( ContextKeyUserID contextKey = "userID" ContextKeyUsername contextKey = "username" ContextKeyIsAdmin contextKey = "isAdmin" ) // TokenLookupFn is injected to avoid an import cycle with the handlers package. type TokenLookupFn func(db *xorm.Engine, rawToken string) (userID, repoID int64, hasWrite bool, ok bool) type AuthMiddleware struct { store sessions.Store db *xorm.Engine lookupToken TokenLookupFn } func NewAuth(store sessions.Store, db *xorm.Engine, lookupToken TokenLookupFn) *AuthMiddleware { return &AuthMiddleware{store: store, db: db, lookupToken: lookupToken} } func extractBearer(r *http.Request) string { v := r.Header.Get("Authorization") if strings.HasPrefix(v, "Bearer ") { return strings.TrimPrefix(v, "Bearer ") } return "" } func (a *AuthMiddleware) trySession(r *http.Request) (context.Context, bool) { session, err := a.store.Get(r, "fb_session") if err != nil || session.IsNew { return r.Context(), false } userID, ok := session.Values["userID"].(int64) if !ok || userID == 0 { return r.Context(), false } ctx := context.WithValue(r.Context(), ContextKeyUserID, userID) if username, ok := session.Values["username"].(string); ok { ctx = context.WithValue(ctx, ContextKeyUsername, username) } if isAdmin, ok := session.Values["isAdmin"].(bool); ok { ctx = context.WithValue(ctx, ContextKeyIsAdmin, isAdmin) } return ctx, true } func (a *AuthMiddleware) tryBearer(r *http.Request) (context.Context, bool) { raw := extractBearer(r) if raw == "" || a.lookupToken == nil { return r.Context(), false } userID, _, _, ok := a.lookupToken(a.db, raw) if !ok { return r.Context(), false } ctx := context.WithValue(r.Context(), ContextKeyUserID, userID) return ctx, true } func (a *AuthMiddleware) Require(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { if ctx, ok := a.trySession(r); ok { next.ServeHTTP(w, r.WithContext(ctx)) return } if ctx, ok := a.tryBearer(r); ok { next.ServeHTTP(w, r.WithContext(ctx)) return } http.Error(w, "unauthorized", http.StatusUnauthorized) }) } func (a *AuthMiddleware) Optional(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { if ctx, ok := a.trySession(r); ok { next.ServeHTTP(w, r.WithContext(ctx)) return } if ctx, ok := a.tryBearer(r); ok { next.ServeHTTP(w, r.WithContext(ctx)) return } next.ServeHTTP(w, r) }) } func UserIDFromContext(ctx context.Context) (int64, bool) { id, ok := ctx.Value(ContextKeyUserID).(int64) return id, ok }