1package web
  2
  3import (
  4	"context"
  5	"errors"
  6	"fmt"
  7	"net/http"
  8	"strings"
  9
 10	"github.com/charmbracelet/log"
 11	"github.com/charmbracelet/soft-serve/server/backend"
 12	"github.com/charmbracelet/soft-serve/server/config"
 13	"github.com/charmbracelet/soft-serve/server/proto"
 14	"github.com/golang-jwt/jwt/v5"
 15)
 16
 17// authenticate authenticates the user from the request.
 18func authenticate(r *http.Request) (proto.User, error) {
 19	// Prefer the Authorization header
 20	user, err := parseAuthHdr(r)
 21	if err != nil || user == nil {
 22		return nil, proto.ErrUserNotFound
 23	}
 24
 25	return user, nil
 26}
 27
 28// ErrInvalidPassword is returned when the password is invalid.
 29var ErrInvalidPassword = errors.New("invalid password")
 30
 31func parseUsernamePassword(ctx context.Context, username, password string) (proto.User, error) {
 32	logger := log.FromContext(ctx)
 33	be := backend.FromContext(ctx)
 34
 35	if username != "" && password != "" {
 36		user, err := be.User(ctx, username)
 37		if err == nil && user != nil && backend.VerifyPassword(password, user.Password()) {
 38			return user, nil
 39		}
 40
 41		// Try to authenticate using access token as the password
 42		user, err = be.UserByAccessToken(ctx, password)
 43		if err == nil {
 44			return user, nil
 45		}
 46
 47		logger.Error("invalid password or token", "username", username, "err", err)
 48		return nil, ErrInvalidPassword
 49	} else if username != "" {
 50		// Try to authenticate using access token as the username
 51		logger.Info("trying to authenticate using access token as username", "username", username)
 52		user, err := be.UserByAccessToken(ctx, username)
 53		if err == nil {
 54			return user, nil
 55		}
 56
 57		logger.Error("failed to get user", "err", err)
 58		return nil, ErrInvalidToken
 59	}
 60
 61	return nil, proto.ErrUserNotFound
 62}
 63
 64// ErrInvalidHeader is returned when the authorization header is invalid.
 65var ErrInvalidHeader = errors.New("invalid authorization header")
 66
 67func parseAuthHdr(r *http.Request) (proto.User, error) {
 68	// Check for auth header
 69	header := r.Header.Get("Authorization")
 70	if header == "" {
 71		return nil, ErrInvalidHeader
 72	}
 73
 74	ctx := r.Context()
 75	logger := log.FromContext(ctx).WithPrefix("http.auth")
 76	be := backend.FromContext(ctx)
 77
 78	logger.Debug("authorization auth header", "header", header)
 79
 80	parts := strings.SplitN(header, " ", 2)
 81	if len(parts) != 2 {
 82		return nil, errors.New("invalid authorization header")
 83	}
 84
 85	switch strings.ToLower(parts[0]) {
 86	case "token":
 87		user, err := be.UserByAccessToken(ctx, parts[1])
 88		if err != nil {
 89			logger.Error("failed to get user", "err", err)
 90			return nil, err
 91		}
 92
 93		return user, nil
 94	case "bearer":
 95		claims, err := parseJWT(ctx, parts[1])
 96		if err != nil {
 97			return nil, err
 98		}
 99
100		// Find the user
101		parts := strings.SplitN(claims.Subject, "#", 2)
102		if len(parts) != 2 {
103			logger.Error("invalid jwt subject", "subject", claims.Subject)
104			return nil, errors.New("invalid jwt subject")
105		}
106
107		user, err := be.User(ctx, parts[0])
108		if err != nil {
109			logger.Error("failed to get user", "err", err)
110			return nil, err
111		}
112
113		expectedSubject := fmt.Sprintf("%s#%d", user.Username(), user.ID())
114		if expectedSubject != claims.Subject {
115			logger.Error("invalid jwt subject", "subject", claims.Subject, "expected", expectedSubject)
116			return nil, errors.New("invalid jwt subject")
117		}
118
119		return user, nil
120	default:
121		username, password, ok := r.BasicAuth()
122		if !ok {
123			return nil, ErrInvalidHeader
124		}
125
126		return parseUsernamePassword(ctx, username, password)
127	}
128}
129
130// ErrInvalidToken is returned when a token is invalid.
131var ErrInvalidToken = errors.New("invalid token")
132
133func parseJWT(ctx context.Context, bearer string) (*jwt.RegisteredClaims, error) {
134	cfg := config.FromContext(ctx)
135	logger := log.FromContext(ctx).WithPrefix("http.auth")
136	kp, err := cfg.SSH.KeyPair()
137	if err != nil {
138		return nil, err
139	}
140
141	repo := proto.RepositoryFromContext(ctx)
142	if repo == nil {
143		return nil, errors.New("missing repository")
144	}
145
146	token, err := jwt.ParseWithClaims(bearer, &jwt.RegisteredClaims{}, func(t *jwt.Token) (interface{}, error) {
147		if _, ok := t.Method.(*jwt.SigningMethodEd25519); !ok {
148			return nil, errors.New("invalid signing method")
149		}
150
151		return kp.CryptoPublicKey(), nil
152	},
153		jwt.WithIssuer(cfg.HTTP.PublicURL),
154		jwt.WithIssuedAt(),
155		jwt.WithAudience(repo.Name()),
156	)
157	if err != nil {
158		logger.Error("failed to parse jwt", "err", err)
159		return nil, ErrInvalidToken
160	}
161
162	claims, ok := token.Claims.(*jwt.RegisteredClaims)
163	if !token.Valid || !ok {
164		return nil, ErrInvalidToken
165	}
166
167	return claims, nil
168}