1// Copyright 2023 Google LLC
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7// http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15package auth
16
17import (
18 "bytes"
19 "context"
20 "encoding/json"
21 "errors"
22 "fmt"
23 "log/slog"
24 "mime"
25 "net/http"
26 "net/url"
27 "strconv"
28 "strings"
29 "time"
30
31 "cloud.google.com/go/auth/internal"
32 "github.com/googleapis/gax-go/v2/internallog"
33)
34
// AuthorizationHandler is a 3-legged-OAuth helper that prompts the user for
// OAuth consent at the specified auth code URL and returns an auth code and
// state upon approval.
//
// The returned state is compared against the state configured in
// [AuthorizationHandlerOptions]; a non-nil error or a state mismatch aborts
// the flow before the auth code is exchanged for a token.
type AuthorizationHandler func(authCodeURL string) (code string, state string, err error)
39
// Options3LO are the options for doing a 3-legged OAuth2 flow.
type Options3LO struct {
	// ClientID is the application's ID. Required.
	ClientID string
	// ClientSecret is the application's secret. Not required if AuthHandlerOpts
	// is set.
	ClientSecret string
	// AuthURL is the URL for authenticating. Required.
	AuthURL string
	// TokenURL is the URL for retrieving a token. Required.
	TokenURL string
	// AuthStyle describes how the client info (ID and secret) is sent in the
	// token request. Required; StyleUnknown is rejected by validation.
	AuthStyle Style
	// RefreshToken is the token used to refresh the credential. Not required
	// if AuthHandlerOpts is set.
	RefreshToken string
	// RedirectURL is the URL to redirect users to. Optional.
	RedirectURL string
	// Scopes specifies requested permissions for the Token. Optional.
	Scopes []string

	// URLParams are the set of values to apply to the token exchange. They are
	// applied after the built-in form parameters and therefore override them.
	// Only the first value of each key is used. Optional.
	URLParams url.Values
	// Client is the client to be used to make the underlying token requests.
	// Optional.
	Client *http.Client
	// EarlyTokenExpiry is the time before the token expires that it should be
	// refreshed. If not set the default value is 3 minutes and 45 seconds.
	// Optional.
	EarlyTokenExpiry time.Duration

	// AuthHandlerOpts provides a set of options for doing a
	// 3-legged OAuth2 flow with a custom [AuthorizationHandler]. Optional.
	AuthHandlerOpts *AuthorizationHandlerOptions
	// Logger is used for debug logging. If provided, logging will be enabled
	// at the logger's configured level. By default logging is disabled unless
	// enabled by setting GOOGLE_SDK_GO_LOGGING_LEVEL in which case a default
	// logger will be used. Optional.
	Logger *slog.Logger
}
80
81func (o *Options3LO) validate() error {
82 if o == nil {
83 return errors.New("auth: options must be provided")
84 }
85 if o.ClientID == "" {
86 return errors.New("auth: client ID must be provided")
87 }
88 if o.AuthHandlerOpts == nil && o.ClientSecret == "" {
89 return errors.New("auth: client secret must be provided")
90 }
91 if o.AuthURL == "" {
92 return errors.New("auth: auth URL must be provided")
93 }
94 if o.TokenURL == "" {
95 return errors.New("auth: token URL must be provided")
96 }
97 if o.AuthStyle == StyleUnknown {
98 return errors.New("auth: auth style must be provided")
99 }
100 if o.AuthHandlerOpts == nil && o.RefreshToken == "" {
101 return errors.New("auth: refresh token must be provided")
102 }
103 return nil
104}
105
106func (o *Options3LO) logger() *slog.Logger {
107 return internallog.New(o.Logger)
108}
109
// PKCEOptions holds parameters to support PKCE.
type PKCEOptions struct {
	// Challenge is the un-padded, base64-url-encoded string derived from the
	// code verifier. When non-empty it is added to the auth code URL.
	Challenge string
	// ChallengeMethod names the method used to derive Challenge from Verifier
	// (ex. S256). When non-empty it is added to the auth code URL.
	ChallengeMethod string
	// Verifier is the original secret that Challenge was derived from. When
	// non-empty it is sent with the auth code during the token exchange.
	Verifier string
}
119
// tokenJSON is the JSON shape decoded from a token endpoint response. It
// covers both the success fields and the error fields.
type tokenJSON struct {
	AccessToken  string `json:"access_token"`
	TokenType    string `json:"token_type"`
	RefreshToken string `json:"refresh_token"`
	ExpiresIn    int    `json:"expires_in"`
	// error fields
	ErrorCode        string `json:"error"`
	ErrorDescription string `json:"error_description"`
	ErrorURI         string `json:"error_uri"`
}

// expiry converts ExpiresIn, a number of seconds counted from now, into an
// absolute time. It returns the zero time when ExpiresIn is 0.
func (e *tokenJSON) expiry() time.Time {
	if e.ExpiresIn == 0 {
		return time.Time{}
	}
	lifetime := time.Duration(e.ExpiresIn) * time.Second
	return time.Now().Add(lifetime)
}
137
138func (o *Options3LO) client() *http.Client {
139 if o.Client != nil {
140 return o.Client
141 }
142 return internal.DefaultClient()
143}
144
145// authCodeURL returns a URL that points to a OAuth2 consent page.
146func (o *Options3LO) authCodeURL(state string, values url.Values) string {
147 var buf bytes.Buffer
148 buf.WriteString(o.AuthURL)
149 v := url.Values{
150 "response_type": {"code"},
151 "client_id": {o.ClientID},
152 }
153 if o.RedirectURL != "" {
154 v.Set("redirect_uri", o.RedirectURL)
155 }
156 if len(o.Scopes) > 0 {
157 v.Set("scope", strings.Join(o.Scopes, " "))
158 }
159 if state != "" {
160 v.Set("state", state)
161 }
162 if o.AuthHandlerOpts != nil {
163 if o.AuthHandlerOpts.PKCEOpts != nil &&
164 o.AuthHandlerOpts.PKCEOpts.Challenge != "" {
165 v.Set(codeChallengeKey, o.AuthHandlerOpts.PKCEOpts.Challenge)
166 }
167 if o.AuthHandlerOpts.PKCEOpts != nil &&
168 o.AuthHandlerOpts.PKCEOpts.ChallengeMethod != "" {
169 v.Set(codeChallengeMethodKey, o.AuthHandlerOpts.PKCEOpts.ChallengeMethod)
170 }
171 }
172 for k := range values {
173 v.Set(k, v.Get(k))
174 }
175 if strings.Contains(o.AuthURL, "?") {
176 buf.WriteByte('&')
177 } else {
178 buf.WriteByte('?')
179 }
180 buf.WriteString(v.Encode())
181 return buf.String()
182}
183
184// New3LOTokenProvider returns a [TokenProvider] based on the 3-legged OAuth2
185// configuration. The TokenProvider is caches and auto-refreshes tokens by
186// default.
187func New3LOTokenProvider(opts *Options3LO) (TokenProvider, error) {
188 if err := opts.validate(); err != nil {
189 return nil, err
190 }
191 if opts.AuthHandlerOpts != nil {
192 return new3LOTokenProviderWithAuthHandler(opts), nil
193 }
194 return NewCachedTokenProvider(&tokenProvider3LO{opts: opts, refreshToken: opts.RefreshToken, client: opts.client()}, &CachedTokenProviderOptions{
195 ExpireEarly: opts.EarlyTokenExpiry,
196 }), nil
197}
198
// AuthorizationHandlerOptions provides a set of options to specify for doing a
// 3-legged OAuth2 flow with a custom [AuthorizationHandler].
type AuthorizationHandlerOptions struct {
	// Handler specifies the handler used for the authorization part of the
	// flow. It is called with the auth code URL each time a token is fetched.
	Handler AuthorizationHandler
	// State is used to verify that the "state" is identical in the request and
	// response before exchanging the auth code for an OAuth2 token.
	State string
	// PKCEOpts allows setting configurations for PKCE. Optional.
	PKCEOpts *PKCEOptions
}
211
212func new3LOTokenProviderWithAuthHandler(opts *Options3LO) TokenProvider {
213 return NewCachedTokenProvider(&tokenProviderWithHandler{opts: opts, state: opts.AuthHandlerOpts.State}, &CachedTokenProviderOptions{
214 ExpireEarly: opts.EarlyTokenExpiry,
215 })
216}
217
218// exchange handles the final exchange portion of the 3lo flow. Returns a Token,
219// refreshToken, and error.
220func (o *Options3LO) exchange(ctx context.Context, code string) (*Token, string, error) {
221 // Build request
222 v := url.Values{
223 "grant_type": {"authorization_code"},
224 "code": {code},
225 }
226 if o.RedirectURL != "" {
227 v.Set("redirect_uri", o.RedirectURL)
228 }
229 if o.AuthHandlerOpts != nil &&
230 o.AuthHandlerOpts.PKCEOpts != nil &&
231 o.AuthHandlerOpts.PKCEOpts.Verifier != "" {
232 v.Set(codeVerifierKey, o.AuthHandlerOpts.PKCEOpts.Verifier)
233 }
234 for k := range o.URLParams {
235 v.Set(k, o.URLParams.Get(k))
236 }
237 return fetchToken(ctx, o, v)
238}
239
// tokenProvider3LO fetches tokens with the refresh_token grant.
//
// This struct is not safe for concurrent access alone, but the way it is used
// in this package by wrapping it with a cachedTokenProvider makes it so.
type tokenProvider3LO struct {
	// opts is the flow configuration passed to fetchToken.
	opts *Options3LO
	// client is set from opts.client() by New3LOTokenProvider.
	// NOTE(review): Token does not read this field in the code visible here;
	// fetchToken calls opts.client() itself. Confirm whether it is still needed.
	client *http.Client
	// refreshToken is seeded from Options3LO.RefreshToken and replaced when a
	// fetch returns a different, non-empty refresh token.
	refreshToken string
}
247
248func (tp *tokenProvider3LO) Token(ctx context.Context) (*Token, error) {
249 if tp.refreshToken == "" {
250 return nil, errors.New("auth: token expired and refresh token is not set")
251 }
252 v := url.Values{
253 "grant_type": {"refresh_token"},
254 "refresh_token": {tp.refreshToken},
255 }
256 for k := range tp.opts.URLParams {
257 v.Set(k, tp.opts.URLParams.Get(k))
258 }
259
260 tk, rt, err := fetchToken(ctx, tp.opts, v)
261 if err != nil {
262 return nil, err
263 }
264 if tp.refreshToken != rt && rt != "" {
265 tp.refreshToken = rt
266 }
267 return tk, err
268}
269
// tokenProviderWithHandler fetches tokens by running the configured
// [AuthorizationHandler] and exchanging the auth code it returns.
type tokenProviderWithHandler struct {
	// opts is the flow configuration; opts.AuthHandlerOpts supplies the handler.
	opts *Options3LO
	// state is sent in the auth code URL and must match the state the handler
	// returns.
	state string
}
274
275func (tp tokenProviderWithHandler) Token(ctx context.Context) (*Token, error) {
276 url := tp.opts.authCodeURL(tp.state, nil)
277 code, state, err := tp.opts.AuthHandlerOpts.Handler(url)
278 if err != nil {
279 return nil, err
280 }
281 if state != tp.state {
282 return nil, errors.New("auth: state mismatch in 3-legged-OAuth flow")
283 }
284 tok, _, err := tp.opts.exchange(ctx, code)
285 return tok, err
286}
287
288// fetchToken returns a Token, refresh token, and/or an error.
289func fetchToken(ctx context.Context, o *Options3LO, v url.Values) (*Token, string, error) {
290 var refreshToken string
291 if o.AuthStyle == StyleInParams {
292 if o.ClientID != "" {
293 v.Set("client_id", o.ClientID)
294 }
295 if o.ClientSecret != "" {
296 v.Set("client_secret", o.ClientSecret)
297 }
298 }
299 req, err := http.NewRequestWithContext(ctx, "POST", o.TokenURL, strings.NewReader(v.Encode()))
300 if err != nil {
301 return nil, refreshToken, err
302 }
303 req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
304 if o.AuthStyle == StyleInHeader {
305 req.SetBasicAuth(url.QueryEscape(o.ClientID), url.QueryEscape(o.ClientSecret))
306 }
307 logger := o.logger()
308
309 logger.DebugContext(ctx, "3LO token request", "request", internallog.HTTPRequest(req, []byte(v.Encode())))
310 // Make request
311 resp, body, err := internal.DoRequest(o.client(), req)
312 if err != nil {
313 return nil, refreshToken, err
314 }
315 logger.DebugContext(ctx, "3LO token response", "response", internallog.HTTPResponse(resp, body))
316 failureStatus := resp.StatusCode < 200 || resp.StatusCode > 299
317 tokError := &Error{
318 Response: resp,
319 Body: body,
320 }
321
322 var token *Token
323 // errors ignored because of default switch on content
324 content, _, _ := mime.ParseMediaType(resp.Header.Get("Content-Type"))
325 switch content {
326 case "application/x-www-form-urlencoded", "text/plain":
327 // some endpoints return a query string
328 vals, err := url.ParseQuery(string(body))
329 if err != nil {
330 if failureStatus {
331 return nil, refreshToken, tokError
332 }
333 return nil, refreshToken, fmt.Errorf("auth: cannot parse response: %w", err)
334 }
335 tokError.code = vals.Get("error")
336 tokError.description = vals.Get("error_description")
337 tokError.uri = vals.Get("error_uri")
338 token = &Token{
339 Value: vals.Get("access_token"),
340 Type: vals.Get("token_type"),
341 Metadata: make(map[string]interface{}, len(vals)),
342 }
343 for k, v := range vals {
344 token.Metadata[k] = v
345 }
346 refreshToken = vals.Get("refresh_token")
347 e := vals.Get("expires_in")
348 expires, _ := strconv.Atoi(e)
349 if expires != 0 {
350 token.Expiry = time.Now().Add(time.Duration(expires) * time.Second)
351 }
352 default:
353 var tj tokenJSON
354 if err = json.Unmarshal(body, &tj); err != nil {
355 if failureStatus {
356 return nil, refreshToken, tokError
357 }
358 return nil, refreshToken, fmt.Errorf("auth: cannot parse json: %w", err)
359 }
360 tokError.code = tj.ErrorCode
361 tokError.description = tj.ErrorDescription
362 tokError.uri = tj.ErrorURI
363 token = &Token{
364 Value: tj.AccessToken,
365 Type: tj.TokenType,
366 Expiry: tj.expiry(),
367 Metadata: make(map[string]interface{}),
368 }
369 json.Unmarshal(body, &token.Metadata) // optional field, skip err check
370 refreshToken = tj.RefreshToken
371 }
372 // according to spec, servers should respond status 400 in error case
373 // https://www.rfc-editor.org/rfc/rfc6749#section-5.2
374 // but some unorthodox servers respond 200 in error case
375 if failureStatus || tokError.code != "" {
376 return nil, refreshToken, tokError
377 }
378 if token.Value == "" {
379 return nil, refreshToken, errors.New("auth: server response missing access_token")
380 }
381 return token, refreshToken, nil
382}