105 lines
2.7 KiB
Go
105 lines
2.7 KiB
Go
package middleware
|
|
|
|
import (
|
|
"context"
|
|
"net/http"
|
|
"strings"
|
|
|
|
"github.com/gorilla/sessions"
|
|
"xorm.io/xorm"
|
|
)
|
|
|
|
// contextKey is an unexported string type used for request-context keys.
// Because the type is private to this package, values stored under these
// keys cannot collide with context keys defined by other packages.
type contextKey string

// Keys under which the authentication middleware stores identity values in
// the request context.
const (
	// ContextKeyUserID maps to the authenticated user's ID (an int64).
	ContextKeyUserID = contextKey("userID")
	// ContextKeyUsername maps to the session's username (a string), when present.
	ContextKeyUsername = contextKey("username")
	// ContextKeyIsAdmin maps to the session's admin flag (a bool), when present.
	ContextKeyIsAdmin = contextKey("isAdmin")
)
|
|
|
|
// TokenLookupFn is injected to avoid an import cycle with the handlers package.
|
|
type TokenLookupFn func(db *xorm.Engine, rawToken string) (userID, repoID int64, hasWrite bool, ok bool)
|
|
|
|
type AuthMiddleware struct {
|
|
store sessions.Store
|
|
db *xorm.Engine
|
|
lookupToken TokenLookupFn
|
|
}
|
|
|
|
func NewAuth(store sessions.Store, db *xorm.Engine, lookupToken TokenLookupFn) *AuthMiddleware {
|
|
return &AuthMiddleware{store: store, db: db, lookupToken: lookupToken}
|
|
}
|
|
|
|
// extractBearer returns the token from an "Authorization: Bearer <token>"
// request header. It returns "" when the header is missing, uses another
// scheme, or carries no token.
//
// The auth-scheme is matched case-insensitively, as RFC 7235 §2.1 requires,
// so "bearer abc" and "BEARER abc" are accepted. Whitespace surrounding the
// token is trimmed so that extra separating spaces never leak into the
// token passed to the lookup.
func extractBearer(r *http.Request) string {
	const prefix = "Bearer "
	v := r.Header.Get("Authorization")
	if len(v) < len(prefix) || !strings.EqualFold(v[:len(prefix)], prefix) {
		return ""
	}
	return strings.TrimSpace(v[len(prefix):])
}
|
|
|
|
func (a *AuthMiddleware) trySession(r *http.Request) (context.Context, bool) {
|
|
session, err := a.store.Get(r, "fb_session")
|
|
if err != nil || session.IsNew {
|
|
return r.Context(), false
|
|
}
|
|
userID, ok := session.Values["userID"].(int64)
|
|
if !ok || userID == 0 {
|
|
return r.Context(), false
|
|
}
|
|
ctx := context.WithValue(r.Context(), ContextKeyUserID, userID)
|
|
if username, ok := session.Values["username"].(string); ok {
|
|
ctx = context.WithValue(ctx, ContextKeyUsername, username)
|
|
}
|
|
if isAdmin, ok := session.Values["isAdmin"].(bool); ok {
|
|
ctx = context.WithValue(ctx, ContextKeyIsAdmin, isAdmin)
|
|
}
|
|
return ctx, true
|
|
}
|
|
|
|
func (a *AuthMiddleware) tryBearer(r *http.Request) (context.Context, bool) {
|
|
raw := extractBearer(r)
|
|
if raw == "" || a.lookupToken == nil {
|
|
return r.Context(), false
|
|
}
|
|
userID, _, _, ok := a.lookupToken(a.db, raw)
|
|
if !ok {
|
|
return r.Context(), false
|
|
}
|
|
ctx := context.WithValue(r.Context(), ContextKeyUserID, userID)
|
|
return ctx, true
|
|
}
|
|
|
|
func (a *AuthMiddleware) Require(next http.Handler) http.Handler {
|
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
if ctx, ok := a.trySession(r); ok {
|
|
next.ServeHTTP(w, r.WithContext(ctx))
|
|
return
|
|
}
|
|
if ctx, ok := a.tryBearer(r); ok {
|
|
next.ServeHTTP(w, r.WithContext(ctx))
|
|
return
|
|
}
|
|
http.Error(w, "unauthorized", http.StatusUnauthorized)
|
|
})
|
|
}
|
|
|
|
func (a *AuthMiddleware) Optional(next http.Handler) http.Handler {
|
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
if ctx, ok := a.trySession(r); ok {
|
|
next.ServeHTTP(w, r.WithContext(ctx))
|
|
return
|
|
}
|
|
if ctx, ok := a.tryBearer(r); ok {
|
|
next.ServeHTTP(w, r.WithContext(ctx))
|
|
return
|
|
}
|
|
next.ServeHTTP(w, r)
|
|
})
|
|
}
|
|
|
|
func UserIDFromContext(ctx context.Context) (int64, bool) {
|
|
id, ok := ctx.Value(ContextKeyUserID).(int64)
|
|
return id, ok
|
|
}
|