1package web
2
3import (
4 "context"
5 "errors"
6 "fmt"
7 "net/http"
8 "strings"
9
10 "github.com/charmbracelet/log"
11 "github.com/charmbracelet/soft-serve/server/backend"
12 "github.com/charmbracelet/soft-serve/server/config"
13 "github.com/charmbracelet/soft-serve/server/proto"
14 "github.com/golang-jwt/jwt/v5"
15)
16
17// authenticate authenticates the user from the request.
18func authenticate(r *http.Request) (proto.User, error) {
19 // Prefer the Authorization header
20 user, err := parseAuthHdr(r)
21 if err != nil || user == nil {
22 if errors.Is(err, ErrInvalidToken) || errors.Is(err, ErrInvalidPassword) {
23 return nil, err
24 }
25 return nil, proto.ErrUserNotFound
26 }
27
28 return user, nil
29}
30
31// ErrInvalidPassword is returned when the password is invalid.
32var ErrInvalidPassword = errors.New("invalid password")
33
34func parseUsernamePassword(ctx context.Context, username, password string) (proto.User, error) {
35 logger := log.FromContext(ctx)
36 be := backend.FromContext(ctx)
37
38 if username != "" && password != "" {
39 user, err := be.User(ctx, username)
40 if err == nil && user != nil && backend.VerifyPassword(password, user.Password()) {
41 return user, nil
42 }
43
44 // Try to authenticate using access token as the password
45 user, err = be.UserByAccessToken(ctx, password)
46 if err == nil {
47 return user, nil
48 }
49
50 logger.Error("invalid password or token", "username", username, "err", err)
51 return nil, ErrInvalidPassword
52 } else if username != "" {
53 // Try to authenticate using access token as the username
54 logger.Info("trying to authenticate using access token as username", "username", username)
55 user, err := be.UserByAccessToken(ctx, username)
56 if err == nil {
57 return user, nil
58 }
59
60 logger.Error("failed to get user", "err", err)
61 return nil, ErrInvalidToken
62 }
63
64 return nil, proto.ErrUserNotFound
65}
66
67// ErrInvalidHeader is returned when the authorization header is invalid.
68var ErrInvalidHeader = errors.New("invalid authorization header")
69
70func parseAuthHdr(r *http.Request) (proto.User, error) {
71 // Check for auth header
72 header := r.Header.Get("Authorization")
73 if header == "" {
74 return nil, ErrInvalidHeader
75 }
76
77 ctx := r.Context()
78 logger := log.FromContext(ctx).WithPrefix("http.auth")
79 be := backend.FromContext(ctx)
80
81 logger.Debug("authorization auth header", "header", header)
82
83 parts := strings.SplitN(header, " ", 2)
84 if len(parts) != 2 {
85 return nil, errors.New("invalid authorization header")
86 }
87
88 switch strings.ToLower(parts[0]) {
89 case "token":
90 user, err := be.UserByAccessToken(ctx, parts[1])
91 if err != nil {
92 logger.Error("failed to get user", "err", err)
93 return nil, err
94 }
95
96 return user, nil
97 case "bearer":
98 claims, err := parseJWT(ctx, parts[1])
99 if err != nil {
100 return nil, err
101 }
102
103 // Find the user
104 parts := strings.SplitN(claims.Subject, "#", 2)
105 if len(parts) != 2 {
106 logger.Error("invalid jwt subject", "subject", claims.Subject)
107 return nil, errors.New("invalid jwt subject")
108 }
109
110 user, err := be.User(ctx, parts[0])
111 if err != nil {
112 logger.Error("failed to get user", "err", err)
113 return nil, err
114 }
115
116 expectedSubject := fmt.Sprintf("%s#%d", user.Username(), user.ID())
117 if expectedSubject != claims.Subject {
118 logger.Error("invalid jwt subject", "subject", claims.Subject, "expected", expectedSubject)
119 return nil, errors.New("invalid jwt subject")
120 }
121
122 return user, nil
123 default:
124 username, password, ok := r.BasicAuth()
125 if !ok {
126 return nil, ErrInvalidHeader
127 }
128
129 return parseUsernamePassword(ctx, username, password)
130 }
131}
132
133// ErrInvalidToken is returned when a token is invalid.
134var ErrInvalidToken = errors.New("invalid token")
135
136func parseJWT(ctx context.Context, bearer string) (*jwt.RegisteredClaims, error) {
137 cfg := config.FromContext(ctx)
138 logger := log.FromContext(ctx).WithPrefix("http.auth")
139 kp, err := cfg.SSH.KeyPair()
140 if err != nil {
141 return nil, err
142 }
143
144 repo := proto.RepositoryFromContext(ctx)
145 if repo == nil {
146 return nil, errors.New("missing repository")
147 }
148
149 token, err := jwt.ParseWithClaims(bearer, &jwt.RegisteredClaims{}, func(t *jwt.Token) (interface{}, error) {
150 if _, ok := t.Method.(*jwt.SigningMethodEd25519); !ok {
151 return nil, errors.New("invalid signing method")
152 }
153
154 return kp.CryptoPublicKey(), nil
155 },
156 jwt.WithIssuer(cfg.HTTP.PublicURL),
157 jwt.WithIssuedAt(),
158 jwt.WithAudience(repo.Name()),
159 )
160 if err != nil {
161 logger.Error("failed to parse jwt", "err", err)
162 return nil, ErrInvalidToken
163 }
164
165 claims, ok := token.Claims.(*jwt.RegisteredClaims)
166 if !token.Valid || !ok {
167 return nil, ErrInvalidToken
168 }
169
170 return claims, nil
171}