auth.go

  1// Package web provides web server functionality.
  2package web
  3
  4import (
  5	"context"
  6	"errors"
  7	"fmt"
  8	"net/http"
  9	"strings"
 10
 11	"github.com/charmbracelet/log/v2"
 12	"github.com/charmbracelet/soft-serve/pkg/backend"
 13	"github.com/charmbracelet/soft-serve/pkg/config"
 14	"github.com/charmbracelet/soft-serve/pkg/proto"
 15	"github.com/golang-jwt/jwt/v5"
 16)
 17
 18// authenticate authenticates the user from the request.
 19func authenticate(r *http.Request) (proto.User, error) {
 20	// Prefer the Authorization header
 21	user, err := parseAuthHdr(r)
 22	if err != nil || user == nil {
 23		if errors.Is(err, ErrInvalidToken) || errors.Is(err, ErrInvalidPassword) {
 24			return nil, err
 25		}
 26		return nil, proto.ErrUserNotFound
 27	}
 28
 29	return user, nil
 30}
 31
 32// ErrInvalidPassword is returned when the password is invalid.
 33var ErrInvalidPassword = errors.New("invalid password")
 34
 35func parseUsernamePassword(ctx context.Context, username, password string) (proto.User, error) {
 36	logger := log.FromContext(ctx)
 37	be := backend.FromContext(ctx)
 38
 39	if username != "" && password != "" {
 40		user, err := be.User(ctx, username)
 41		if err == nil && user != nil && backend.VerifyPassword(password, user.Password()) {
 42			return user, nil
 43		}
 44
 45		// Try to authenticate using access token as the password
 46		user, err = be.UserByAccessToken(ctx, password)
 47		if err == nil {
 48			return user, nil
 49		}
 50
 51		logger.Error("invalid password or token", "username", username, "err", err)
 52		return nil, ErrInvalidPassword
 53	} else if username != "" {
 54		// Try to authenticate using access token as the username
 55		logger.Debug("trying to authenticate using access token as username", "username", username)
 56		user, err := be.UserByAccessToken(ctx, username)
 57		if err == nil {
 58			return user, nil
 59		}
 60
 61		logger.Error("failed to get user", "err", err)
 62		return nil, ErrInvalidToken
 63	}
 64
 65	return nil, proto.ErrUserNotFound
 66}
 67
 68// ErrInvalidHeader is returned when the authorization header is invalid.
 69var ErrInvalidHeader = errors.New("invalid authorization header")
 70
 71func parseAuthHdr(r *http.Request) (proto.User, error) {
 72	// Check for auth header
 73	header := r.Header.Get("Authorization")
 74	if header == "" {
 75		return nil, ErrInvalidHeader
 76	}
 77
 78	ctx := r.Context()
 79	logger := log.FromContext(ctx).WithPrefix("http.auth")
 80	be := backend.FromContext(ctx)
 81
 82	logger.Debug("authorization auth header", "header", header)
 83
 84	parts := strings.SplitN(header, " ", 2)
 85	if len(parts) != 2 {
 86		return nil, errors.New("invalid authorization header")
 87	}
 88
 89	switch strings.ToLower(parts[0]) {
 90	case "token":
 91		user, err := be.UserByAccessToken(ctx, parts[1])
 92		if err != nil {
 93			logger.Error("failed to get user", "err", err)
 94			return nil, err //nolint:wrapcheck
 95		}
 96
 97		return user, nil
 98	case "bearer":
 99		claims, err := parseJWT(ctx, parts[1])
100		if err != nil {
101			return nil, err
102		}
103
104		// Find the user
105		parts := strings.SplitN(claims.Subject, "#", 2)
106		if len(parts) != 2 {
107			logger.Error("invalid jwt subject", "subject", claims.Subject)
108			return nil, errors.New("invalid jwt subject")
109		}
110
111		user, err := be.User(ctx, parts[0])
112		if err != nil {
113			logger.Error("failed to get user", "err", err)
114			return nil, err //nolint:wrapcheck
115		}
116
117		expectedSubject := fmt.Sprintf("%s#%d", user.Username(), user.ID())
118		if expectedSubject != claims.Subject {
119			logger.Error("invalid jwt subject", "subject", claims.Subject, "expected", expectedSubject)
120			return nil, errors.New("invalid jwt subject")
121		}
122
123		return user, nil
124	default:
125		username, password, ok := r.BasicAuth()
126		if !ok {
127			return nil, ErrInvalidHeader
128		}
129
130		return parseUsernamePassword(ctx, username, password)
131	}
132}
133
134// ErrInvalidToken is returned when a token is invalid.
135var ErrInvalidToken = errors.New("invalid token")
136
137func parseJWT(ctx context.Context, bearer string) (*jwt.RegisteredClaims, error) {
138	cfg := config.FromContext(ctx)
139	logger := log.FromContext(ctx).WithPrefix("http.auth")
140	kp, err := config.KeyPair(cfg)
141	if err != nil {
142		return nil, err //nolint:wrapcheck
143	}
144
145	repo := proto.RepositoryFromContext(ctx)
146	if repo == nil {
147		return nil, errors.New("missing repository")
148	}
149
150	token, err := jwt.ParseWithClaims(bearer, &jwt.RegisteredClaims{}, func(t *jwt.Token) (interface{}, error) {
151		if _, ok := t.Method.(*jwt.SigningMethodEd25519); !ok {
152			return nil, errors.New("invalid signing method")
153		}
154
155		return kp.CryptoPublicKey(), nil
156	},
157		jwt.WithIssuer(cfg.HTTP.PublicURL),
158		jwt.WithIssuedAt(),
159		jwt.WithAudience(repo.Name()),
160	)
161	if err != nil {
162		logger.Error("failed to parse jwt", "err", err)
163		return nil, ErrInvalidToken
164	}
165
166	claims, ok := token.Claims.(*jwt.RegisteredClaims)
167	if !token.Valid || !ok {
168		return nil, ErrInvalidToken
169	}
170
171	return claims, nil
172}