package middleware

import (
	"context"
	"database/sql"
	"errors"
	"net/http"
	"regexp"
	"strings"

	"github.com/rs/zerolog/log"
	"owo.codes/whats-this/api/lib/apierrors"
	"owo.codes/whats-this/api/lib/db"
)

// UUIDRegex matches a lowercase, hyphenated UUID whose version digit is 1-5
// and whose variant nibble is one of 8, 9, a or b. See
// http://stackoverflow.com/a/13653180.
var UUIDRegex = regexp.MustCompile(`^[0-9a-f]{8}-[0-9a-f]{4}-[1-5][0-9a-f]{3}-[89ab][0-9a-f]{3}-[0-9a-f]{12}$`)

// authKey is an unexported type so that AuthorizedUserKey cannot collide with
// context keys defined by other packages.
type authKey struct{}

// AuthorizedUserKey is the context value key for storing request authorization
// information.
var AuthorizedUserKey authKey

// Authenticator creates middleware that authenticates incoming requests using
// a token and checking it against the database.
//
// The token is taken from the first non-empty of: the "Authorization" header,
// the "key" query parameter, or the "apikey" query parameter. Requests that
// carry no token are passed to next unchanged (unauthenticated).
//
// When a token is present, it is lowercased and must match UUIDRegex;
// otherwise the handler panics with apierrors.BadToken. The token is then
// looked up with db.SelectUserByToken:
//   - no matching row, or a matching user with IsBlocked set, panics with
//     apierrors.Unauthorized;
//   - any other database error is logged and panics with
//     apierrors.InternalServerError;
//   - on success the db.User is stored in the request context under
//     AuthorizedUserKey and the request is passed to next.
//
// NOTE(review): the panics assume an upstream recoverer middleware turns
// apierrors values into HTTP responses — confirm against the router setup.
func Authenticator(next http.Handler) http.Handler {
	fn := func(w http.ResponseWriter, r *http.Request) {
		// Get request token, falling back to query parameters. The query
		// string is parsed at most once.
		token := r.Header.Get("Authorization")
		if token == "" {
			query := r.URL.Query()
			token = query.Get("key")
			if token == "" {
				token = query.Get("apikey")
			}
		}

		if token == "" {
			// No token present, move on
			next.ServeHTTP(w, r)
			return
		}

		// Normalise case so upper-case UUIDs are accepted by UUIDRegex.
		token = strings.ToLower(token)
		if !UUIDRegex.MatchString(token) {
			panic(apierrors.BadToken)
		}

		// Select the user and add to request
		user, err := db.SelectUserByToken(token)
		switch {
		case errors.Is(err, sql.ErrNoRows):
			// errors.Is also matches a wrapped ErrNoRows, so a missing token
			// is never misreported as an internal error.
			panic(apierrors.Unauthorized)
		case err != nil:
			log.Error().Err(err).Msg("failed to run SELECT query on database")
			panic(apierrors.InternalServerError)
		}

		if user.IsBlocked {
			panic(apierrors.Unauthorized)
		}

		ctx := context.WithValue(r.Context(), AuthorizedUserKey, user)
		next.ServeHTTP(w, r.WithContext(ctx))
	}
	return http.HandlerFunc(fn)
}

// GetAuthorizedUser returns the current authorized user information.
//
// It returns the db.User that Authenticator stored in the request context. If
// no user is stored (for example, the request carried no token), the zero
// value of db.User is returned.
func GetAuthorizedUser(r *http.Request) db.User {
	// The ok result is intentionally ignored: a missing or mistyped value
	// yields the zero db.User, which callers treat as "not authenticated".
	user, _ := r.Context().Value(AuthorizedUserKey).(db.User)
	return user
}