1package web
2
3import (
4 "context"
5 "errors"
6 "fmt"
7 "net/http"
8 "strings"
9
10 "github.com/charmbracelet/log"
11 "github.com/charmbracelet/soft-serve/server/backend"
12 "github.com/charmbracelet/soft-serve/server/config"
13 "github.com/charmbracelet/soft-serve/server/proto"
14 "github.com/golang-jwt/jwt/v5"
15)
16
17// authenticate authenticates the user from the request.
18func authenticate(r *http.Request) (proto.User, error) {
19 ctx := r.Context()
20 logger := log.FromContext(ctx)
21
22 // Check for auth header
23 header := r.Header.Get("Authorization")
24 if header != "" {
25 logger.Debug("authorization", "header", header)
26
27 parts := strings.SplitN(header, " ", 2)
28 if len(parts) != 2 {
29 return nil, errors.New("invalid authorization header")
30 }
31
32 // TODO: add basic, and token types
33 be := backend.FromContext(ctx)
34 switch strings.ToLower(parts[0]) {
35 case "bearer":
36 claims, err := getJWTClaims(ctx, parts[1])
37 if err != nil {
38 return nil, err
39 }
40
41 // Find the user
42 parts := strings.SplitN(claims.Subject, "#", 2)
43 if len(parts) != 2 {
44 logger.Error("invalid jwt subject", "subject", claims.Subject)
45 return nil, errors.New("invalid jwt subject")
46 }
47
48 user, err := be.User(ctx, parts[0])
49 if err != nil {
50 logger.Error("failed to get user", "err", err)
51 return nil, err
52 }
53
54 expectedSubject := fmt.Sprintf("%s#%d", user.Username(), user.ID())
55 if expectedSubject != claims.Subject {
56 logger.Error("invalid jwt subject", "subject", claims.Subject, "expected", expectedSubject)
57 return nil, errors.New("invalid jwt subject")
58 }
59
60 return user, nil
61 default:
62 return nil, errors.New("invalid authorization header")
63 }
64 }
65
66 logger.Debug("no authorization header")
67
68 return nil, proto.ErrUserNotFound
69}
70
var (
	// ErrInvalidToken is returned when a token is invalid: getJWTClaims
	// reports every JWT parse or validation failure as this sentinel.
	ErrInvalidToken = errors.New("invalid token")
)
73
74func getJWTClaims(ctx context.Context, bearer string) (*jwt.RegisteredClaims, error) {
75 cfg := config.FromContext(ctx)
76 logger := log.FromContext(ctx).WithPrefix("http.auth")
77 kp, err := cfg.SSH.KeyPair()
78 if err != nil {
79 return nil, err
80 }
81
82 repo := proto.RepositoryFromContext(ctx)
83 if repo == nil {
84 return nil, errors.New("missing repository")
85 }
86
87 token, err := jwt.ParseWithClaims(bearer, &jwt.RegisteredClaims{}, func(t *jwt.Token) (interface{}, error) {
88 if _, ok := t.Method.(*jwt.SigningMethodEd25519); !ok {
89 return nil, errors.New("invalid signing method")
90 }
91
92 return kp.CryptoPublicKey(), nil
93 },
94 jwt.WithIssuer(cfg.HTTP.PublicURL),
95 jwt.WithIssuedAt(),
96 jwt.WithAudience(repo.Name()),
97 )
98 if err != nil {
99 logger.Error("failed to parse jwt", "err", err)
100 return nil, ErrInvalidToken
101 }
102
103 claims, ok := token.Claims.(*jwt.RegisteredClaims)
104 if !token.Valid || !ok {
105 return nil, ErrInvalidToken
106 }
107
108 return claims, nil
109}