1// Package web provides web server functionality.
2package web
3
4import (
5 "context"
6 "errors"
7 "fmt"
8 "net/http"
9 "strings"
10
11 "github.com/charmbracelet/log/v2"
12 "github.com/charmbracelet/soft-serve/pkg/backend"
13 "github.com/charmbracelet/soft-serve/pkg/config"
14 "github.com/charmbracelet/soft-serve/pkg/proto"
15 "github.com/golang-jwt/jwt/v5"
16)
17
18// authenticate authenticates the user from the request.
19func authenticate(r *http.Request) (proto.User, error) {
20 // Prefer the Authorization header
21 user, err := parseAuthHdr(r)
22 if err != nil || user == nil {
23 if errors.Is(err, ErrInvalidToken) || errors.Is(err, ErrInvalidPassword) {
24 return nil, err
25 }
26 return nil, proto.ErrUserNotFound
27 }
28
29 return user, nil
30}
31
32// ErrInvalidPassword is returned when the password is invalid.
33var ErrInvalidPassword = errors.New("invalid password")
34
35func parseUsernamePassword(ctx context.Context, username, password string) (proto.User, error) {
36 logger := log.FromContext(ctx)
37 be := backend.FromContext(ctx)
38
39 if username != "" && password != "" {
40 user, err := be.User(ctx, username)
41 if err == nil && user != nil && backend.VerifyPassword(password, user.Password()) {
42 return user, nil
43 }
44
45 // Try to authenticate using access token as the password
46 user, err = be.UserByAccessToken(ctx, password)
47 if err == nil {
48 return user, nil
49 }
50
51 logger.Error("invalid password or token", "username", username, "err", err)
52 return nil, ErrInvalidPassword
53 } else if username != "" {
54 // Try to authenticate using access token as the username
55 logger.Debug("trying to authenticate using access token as username", "username", username)
56 user, err := be.UserByAccessToken(ctx, username)
57 if err == nil {
58 return user, nil
59 }
60
61 logger.Error("failed to get user", "err", err)
62 return nil, ErrInvalidToken
63 }
64
65 return nil, proto.ErrUserNotFound
66}
67
// ErrInvalidHeader is returned when a request has no usable Authorization
// header, for example when the header is absent or when it carries neither a
// recognized scheme nor valid Basic credentials.
var (
	ErrInvalidHeader = errors.New("invalid authorization header")
)
70
71func parseAuthHdr(r *http.Request) (proto.User, error) {
72 // Check for auth header
73 header := r.Header.Get("Authorization")
74 if header == "" {
75 return nil, ErrInvalidHeader
76 }
77
78 ctx := r.Context()
79 logger := log.FromContext(ctx).WithPrefix("http.auth")
80 be := backend.FromContext(ctx)
81
82 logger.Debug("authorization auth header", "header", header)
83
84 parts := strings.SplitN(header, " ", 2)
85 if len(parts) != 2 {
86 return nil, errors.New("invalid authorization header")
87 }
88
89 switch strings.ToLower(parts[0]) {
90 case "token":
91 user, err := be.UserByAccessToken(ctx, parts[1])
92 if err != nil {
93 logger.Error("failed to get user", "err", err)
94 return nil, err //nolint:wrapcheck
95 }
96
97 return user, nil
98 case "bearer":
99 claims, err := parseJWT(ctx, parts[1])
100 if err != nil {
101 return nil, err
102 }
103
104 // Find the user
105 parts := strings.SplitN(claims.Subject, "#", 2)
106 if len(parts) != 2 {
107 logger.Error("invalid jwt subject", "subject", claims.Subject)
108 return nil, errors.New("invalid jwt subject")
109 }
110
111 user, err := be.User(ctx, parts[0])
112 if err != nil {
113 logger.Error("failed to get user", "err", err)
114 return nil, err //nolint:wrapcheck
115 }
116
117 expectedSubject := fmt.Sprintf("%s#%d", user.Username(), user.ID())
118 if expectedSubject != claims.Subject {
119 logger.Error("invalid jwt subject", "subject", claims.Subject, "expected", expectedSubject)
120 return nil, errors.New("invalid jwt subject")
121 }
122
123 return user, nil
124 default:
125 username, password, ok := r.BasicAuth()
126 if !ok {
127 return nil, ErrInvalidHeader
128 }
129
130 return parseUsernamePassword(ctx, username, password)
131 }
132}
133
// ErrInvalidToken is returned when an access token or a bearer JWT cannot be
// authenticated.
var (
	ErrInvalidToken = errors.New("invalid token")
)
136
137func parseJWT(ctx context.Context, bearer string) (*jwt.RegisteredClaims, error) {
138 cfg := config.FromContext(ctx)
139 logger := log.FromContext(ctx).WithPrefix("http.auth")
140 kp, err := config.KeyPair(cfg)
141 if err != nil {
142 return nil, err //nolint:wrapcheck
143 }
144
145 repo := proto.RepositoryFromContext(ctx)
146 if repo == nil {
147 return nil, errors.New("missing repository")
148 }
149
150 token, err := jwt.ParseWithClaims(bearer, &jwt.RegisteredClaims{}, func(t *jwt.Token) (interface{}, error) {
151 if _, ok := t.Method.(*jwt.SigningMethodEd25519); !ok {
152 return nil, errors.New("invalid signing method")
153 }
154
155 return kp.CryptoPublicKey(), nil
156 },
157 jwt.WithIssuer(cfg.HTTP.PublicURL),
158 jwt.WithIssuedAt(),
159 jwt.WithAudience(repo.Name()),
160 )
161 if err != nil {
162 logger.Error("failed to parse jwt", "err", err)
163 return nil, ErrInvalidToken
164 }
165
166 claims, ok := token.Claims.(*jwt.RegisteredClaims)
167 if !token.Valid || !ok {
168 return nil, ErrInvalidToken
169 }
170
171 return claims, nil
172}