auth.go

  1package web
  2
  3import (
  4	"context"
  5	"errors"
  6	"fmt"
  7	"net/http"
  8	"strings"
  9
 10	"github.com/charmbracelet/log"
 11	"github.com/charmbracelet/soft-serve/server/backend"
 12	"github.com/charmbracelet/soft-serve/server/config"
 13	"github.com/charmbracelet/soft-serve/server/proto"
 14	"github.com/golang-jwt/jwt/v5"
 15)
 16
 17// authenticate authenticates the user from the request.
 18func authenticate(r *http.Request) (proto.User, error) {
 19	ctx := r.Context()
 20	logger := log.FromContext(ctx)
 21
 22	// Check for auth header
 23	header := r.Header.Get("Authorization")
 24	if header != "" {
 25		logger.Debug("authorization", "header", header)
 26
 27		parts := strings.SplitN(header, " ", 2)
 28		if len(parts) != 2 {
 29			return nil, errors.New("invalid authorization header")
 30		}
 31
 32		// TODO: add basic, and token types
 33		be := backend.FromContext(ctx)
 34		switch strings.ToLower(parts[0]) {
 35		case "bearer":
 36			claims, err := getJWTClaims(ctx, parts[1])
 37			if err != nil {
 38				return nil, err
 39			}
 40
 41			// Find the user
 42			parts := strings.SplitN(claims.Subject, "#", 2)
 43			if len(parts) != 2 {
 44				logger.Error("invalid jwt subject", "subject", claims.Subject)
 45				return nil, errors.New("invalid jwt subject")
 46			}
 47
 48			user, err := be.User(ctx, parts[0])
 49			if err != nil {
 50				logger.Error("failed to get user", "err", err)
 51				return nil, err
 52			}
 53
 54			expectedSubject := fmt.Sprintf("%s#%d", user.Username(), user.ID())
 55			if expectedSubject != claims.Subject {
 56				logger.Error("invalid jwt subject", "subject", claims.Subject, "expected", expectedSubject)
 57				return nil, errors.New("invalid jwt subject")
 58			}
 59
 60			return user, nil
 61		default:
 62			return nil, errors.New("invalid authorization header")
 63		}
 64	}
 65
 66	logger.Debug("no authorization header")
 67
 68	return nil, proto.ErrUserNotFound
 69}
 70
 71// ErrInvalidToken is returned when a token is invalid.
 72var ErrInvalidToken = errors.New("invalid token")
 73
 74func getJWTClaims(ctx context.Context, bearer string) (*jwt.RegisteredClaims, error) {
 75	cfg := config.FromContext(ctx)
 76	logger := log.FromContext(ctx).WithPrefix("http.auth")
 77	kp, err := cfg.SSH.KeyPair()
 78	if err != nil {
 79		return nil, err
 80	}
 81
 82	repo := proto.RepositoryFromContext(ctx)
 83	if repo == nil {
 84		return nil, errors.New("missing repository")
 85	}
 86
 87	token, err := jwt.ParseWithClaims(bearer, &jwt.RegisteredClaims{}, func(t *jwt.Token) (interface{}, error) {
 88		if _, ok := t.Method.(*jwt.SigningMethodEd25519); !ok {
 89			return nil, errors.New("invalid signing method")
 90		}
 91
 92		return kp.CryptoPublicKey(), nil
 93	},
 94		jwt.WithIssuer(cfg.HTTP.PublicURL),
 95		jwt.WithIssuedAt(),
 96		jwt.WithAudience(repo.Name()),
 97	)
 98	if err != nil {
 99		logger.Error("failed to parse jwt", "err", err)
100		return nil, ErrInvalidToken
101	}
102
103	claims, ok := token.Claims.(*jwt.RegisteredClaims)
104	if !token.Valid || !ok {
105		return nil, ErrInvalidToken
106	}
107
108	return claims, nil
109}