auth.go

  1package web
  2
  3import (
  4	"context"
  5	"errors"
  6	"fmt"
  7	"net/http"
  8	"strings"
  9
 10	"github.com/charmbracelet/log"
 11	"github.com/charmbracelet/soft-serve/server/backend"
 12	"github.com/charmbracelet/soft-serve/server/config"
 13	"github.com/charmbracelet/soft-serve/server/proto"
 14	"github.com/golang-jwt/jwt/v5"
 15)
 16
 17// authenticate authenticates the user from the request.
 18func authenticate(r *http.Request) (proto.User, error) {
 19	// Prefer the Authorization header
 20	user, err := parseAuthHdr(r)
 21	if err != nil || user == nil {
 22		if errors.Is(err, ErrInvalidToken) || errors.Is(err, ErrInvalidPassword) {
 23			return nil, err
 24		}
 25		return nil, proto.ErrUserNotFound
 26	}
 27
 28	return user, nil
 29}
 30
// ErrInvalidPassword is returned by parseUsernamePassword when both a
// username and a password are supplied but the password neither verifies
// against that user's stored password nor resolves as an access token.
var ErrInvalidPassword = errors.New("invalid password")
 33
 34func parseUsernamePassword(ctx context.Context, username, password string) (proto.User, error) {
 35	logger := log.FromContext(ctx)
 36	be := backend.FromContext(ctx)
 37
 38	if username != "" && password != "" {
 39		user, err := be.User(ctx, username)
 40		if err == nil && user != nil && backend.VerifyPassword(password, user.Password()) {
 41			return user, nil
 42		}
 43
 44		// Try to authenticate using access token as the password
 45		user, err = be.UserByAccessToken(ctx, password)
 46		if err == nil {
 47			return user, nil
 48		}
 49
 50		logger.Error("invalid password or token", "username", username, "err", err)
 51		return nil, ErrInvalidPassword
 52	} else if username != "" {
 53		// Try to authenticate using access token as the username
 54		logger.Info("trying to authenticate using access token as username", "username", username)
 55		user, err := be.UserByAccessToken(ctx, username)
 56		if err == nil {
 57			return user, nil
 58		}
 59
 60		logger.Error("failed to get user", "err", err)
 61		return nil, ErrInvalidToken
 62	}
 63
 64	return nil, proto.ErrUserNotFound
 65}
 66
// ErrInvalidHeader is returned by parseAuthHdr when the Authorization header
// is unusable, for example when it is absent or when a header that is neither
// "token" nor "bearer" does not carry valid HTTP Basic credentials.
var ErrInvalidHeader = errors.New("invalid authorization header")
 69
 70func parseAuthHdr(r *http.Request) (proto.User, error) {
 71	// Check for auth header
 72	header := r.Header.Get("Authorization")
 73	if header == "" {
 74		return nil, ErrInvalidHeader
 75	}
 76
 77	ctx := r.Context()
 78	logger := log.FromContext(ctx).WithPrefix("http.auth")
 79	be := backend.FromContext(ctx)
 80
 81	logger.Debug("authorization auth header", "header", header)
 82
 83	parts := strings.SplitN(header, " ", 2)
 84	if len(parts) != 2 {
 85		return nil, errors.New("invalid authorization header")
 86	}
 87
 88	switch strings.ToLower(parts[0]) {
 89	case "token":
 90		user, err := be.UserByAccessToken(ctx, parts[1])
 91		if err != nil {
 92			logger.Error("failed to get user", "err", err)
 93			return nil, err
 94		}
 95
 96		return user, nil
 97	case "bearer":
 98		claims, err := parseJWT(ctx, parts[1])
 99		if err != nil {
100			return nil, err
101		}
102
103		// Find the user
104		parts := strings.SplitN(claims.Subject, "#", 2)
105		if len(parts) != 2 {
106			logger.Error("invalid jwt subject", "subject", claims.Subject)
107			return nil, errors.New("invalid jwt subject")
108		}
109
110		user, err := be.User(ctx, parts[0])
111		if err != nil {
112			logger.Error("failed to get user", "err", err)
113			return nil, err
114		}
115
116		expectedSubject := fmt.Sprintf("%s#%d", user.Username(), user.ID())
117		if expectedSubject != claims.Subject {
118			logger.Error("invalid jwt subject", "subject", claims.Subject, "expected", expectedSubject)
119			return nil, errors.New("invalid jwt subject")
120		}
121
122		return user, nil
123	default:
124		username, password, ok := r.BasicAuth()
125		if !ok {
126			return nil, ErrInvalidHeader
127		}
128
129		return parseUsernamePassword(ctx, username, password)
130	}
131}
132
// ErrInvalidToken is returned when a token is invalid: by parseJWT when a
// bearer JWT fails parsing or validation, and by parseUsernamePassword when a
// lone username cannot be resolved as an access token.
var ErrInvalidToken = errors.New("invalid token")
135
136func parseJWT(ctx context.Context, bearer string) (*jwt.RegisteredClaims, error) {
137	cfg := config.FromContext(ctx)
138	logger := log.FromContext(ctx).WithPrefix("http.auth")
139	kp, err := cfg.SSH.KeyPair()
140	if err != nil {
141		return nil, err
142	}
143
144	repo := proto.RepositoryFromContext(ctx)
145	if repo == nil {
146		return nil, errors.New("missing repository")
147	}
148
149	token, err := jwt.ParseWithClaims(bearer, &jwt.RegisteredClaims{}, func(t *jwt.Token) (interface{}, error) {
150		if _, ok := t.Method.(*jwt.SigningMethodEd25519); !ok {
151			return nil, errors.New("invalid signing method")
152		}
153
154		return kp.CryptoPublicKey(), nil
155	},
156		jwt.WithIssuer(cfg.HTTP.PublicURL),
157		jwt.WithIssuedAt(),
158		jwt.WithAudience(repo.Name()),
159	)
160	if err != nil {
161		logger.Error("failed to parse jwt", "err", err)
162		return nil, ErrInvalidToken
163	}
164
165	claims, ok := token.Claims.(*jwt.RegisteredClaims)
166	if !token.Valid || !ok {
167		return nil, ErrInvalidToken
168	}
169
170	return claims, nil
171}