1// Copyright 2023 Google LLC
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7// http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15// Package auth provides utilities for managing Google Cloud credentials,
16// including functionality for creating, caching, and refreshing OAuth2 tokens.
17// It offers customizable options for different OAuth2 flows, such as 2-legged
18// (2LO) and 3-legged (3LO) OAuth, along with support for PKCE and automatic
19// token management.
20package auth
21
22import (
23 "context"
24 "encoding/json"
25 "errors"
26 "fmt"
27 "log/slog"
28 "net/http"
29 "net/url"
30 "strings"
31 "sync"
32 "time"
33
34 "cloud.google.com/go/auth/internal"
35 "cloud.google.com/go/auth/internal/jwt"
36 "github.com/googleapis/gax-go/v2/internallog"
37)
38
// Form/query parameter names used for PKCE support.
const (
	// codeChallengeKey and codeChallengeMethodKey are the parameter keys the
	// AuthCodeURL method uses to support PKCE.
	codeChallengeKey       = "code_challenge"
	codeChallengeMethodKey = "code_challenge_method"

	// codeVerifierKey is the parameter key the Exchange method uses to
	// support PKCE.
	codeVerifierKey = "code_verifier"
)

// Package-wide defaults.
const (
	// defaultExpiryDelta is how long before its expiry a token is treated as
	// needing a refresh: 3 minutes and 45 seconds. The shortest MDS cache is
	// 4 minutes, which leaves the MDS 15 seconds to refresh its cache before
	// a token refresh is attempted.
	defaultExpiryDelta = 3*time.Minute + 45*time.Second

	// universeDomainDefault is the universe domain reported when none is
	// configured.
	universeDomainDefault = "googleapis.com"
)
53
// tokenState classifies how usable a cached [Token] currently is.
type tokenState int

const (
	// fresh: the [Token] is valid. It either has no expiry or is neither
	// expired nor close to expiring.
	fresh = tokenState(iota)
	// stale: the [Token] is close to expiring and should be refreshed, but
	// it can still be used normally in the meantime.
	stale
	// invalid: the [Token] is expired or otherwise unusable and must not be
	// used for a normal operation.
	invalid
)
68
69var (
70 defaultGrantType = "urn:ietf:params:oauth:grant-type:jwt-bearer"
71 defaultHeader = &jwt.Header{Algorithm: jwt.HeaderAlgRSA256, Type: jwt.HeaderType}
72
73 // for testing
74 timeNow = time.Now
75)
76
77// TokenProvider specifies an interface for anything that can return a token.
78type TokenProvider interface {
79 // Token returns a Token or an error.
80 // The Token returned must be safe to use
81 // concurrently.
82 // The returned Token must not be modified.
83 // The context provided must be sent along to any requests that are made in
84 // the implementing code.
85 Token(context.Context) (*Token, error)
86}
87
// Token holds the credential token used to authorize requests. Treat every
// field as read-only.
type Token struct {
	// Value is the token that authorizes requests. It is usually an access
	// token, but some flows return other kinds such as ID tokens.
	Value string
	// Type describes what kind of token Value is. When left empty the token
	// should be assumed to be a "Bearer" token.
	Type string
	// Expiry is the time at which the token expires.
	Expiry time.Time
	// Metadata may include, but is not limited to, the body of the token
	// response returned by the server.
	//
	// TODO(codyoss): maybe make a method to flatten metadata to avoid
	// []string for url.Values
	Metadata map[string]interface{}
}
103
104// IsValid reports that a [Token] is non-nil, has a [Token.Value], and has not
105// expired. A token is considered expired if [Token.Expiry] has passed or will
106// pass in the next 225 seconds.
107func (t *Token) IsValid() bool {
108 return t.isValidWithEarlyExpiry(defaultExpiryDelta)
109}
110
111// MetadataString is a convenience method for accessing string values in the
112// token's metadata. Returns an empty string if the metadata is nil or the value
113// for the given key cannot be cast to a string.
114func (t *Token) MetadataString(k string) string {
115 if t.Metadata == nil {
116 return ""
117 }
118 s, ok := t.Metadata[k].(string)
119 if !ok {
120 return ""
121 }
122 return s
123}
124
125func (t *Token) isValidWithEarlyExpiry(earlyExpiry time.Duration) bool {
126 if t.isEmpty() {
127 return false
128 }
129 if t.Expiry.IsZero() {
130 return true
131 }
132 return !t.Expiry.Round(0).Add(-earlyExpiry).Before(timeNow())
133}
134
135func (t *Token) isEmpty() bool {
136 return t == nil || t.Value == ""
137}
138
139// Credentials holds Google credentials, including
140// [Application Default Credentials].
141//
142// [Application Default Credentials]: https://developers.google.com/accounts/docs/application-default-credentials
143type Credentials struct {
144 json []byte
145 projectID CredentialsPropertyProvider
146 quotaProjectID CredentialsPropertyProvider
147 // universeDomain is the default service domain for a given Cloud universe.
148 universeDomain CredentialsPropertyProvider
149
150 TokenProvider
151}
152
153// JSON returns the bytes associated with the the file used to source
154// credentials if one was used.
155func (c *Credentials) JSON() []byte {
156 return c.json
157}
158
159// ProjectID returns the associated project ID from the underlying file or
160// environment.
161func (c *Credentials) ProjectID(ctx context.Context) (string, error) {
162 if c.projectID == nil {
163 return internal.GetProjectID(c.json, ""), nil
164 }
165 v, err := c.projectID.GetProperty(ctx)
166 if err != nil {
167 return "", err
168 }
169 return internal.GetProjectID(c.json, v), nil
170}
171
172// QuotaProjectID returns the associated quota project ID from the underlying
173// file or environment.
174func (c *Credentials) QuotaProjectID(ctx context.Context) (string, error) {
175 if c.quotaProjectID == nil {
176 return internal.GetQuotaProject(c.json, ""), nil
177 }
178 v, err := c.quotaProjectID.GetProperty(ctx)
179 if err != nil {
180 return "", err
181 }
182 return internal.GetQuotaProject(c.json, v), nil
183}
184
185// UniverseDomain returns the default service domain for a given Cloud universe.
186// The default value is "googleapis.com".
187func (c *Credentials) UniverseDomain(ctx context.Context) (string, error) {
188 if c.universeDomain == nil {
189 return universeDomainDefault, nil
190 }
191 v, err := c.universeDomain.GetProperty(ctx)
192 if err != nil {
193 return "", err
194 }
195 if v == "" {
196 return universeDomainDefault, nil
197 }
198 return v, err
199}
200
// CredentialsPropertyProvider is implemented by types that can fetch a
// property value for [Credentials].
type CredentialsPropertyProvider interface {
	// GetProperty returns the property value, or an error if it cannot be
	// obtained.
	GetProperty(ctx context.Context) (string, error)
}
206
// CredentialsPropertyFunc adapts an ordinary function so it can be used as a
// [CredentialsPropertyProvider].
type CredentialsPropertyFunc func(context.Context) (string, error)

// GetProperty loads the property value by calling the underlying function
// with the given context.
func (fn CredentialsPropertyFunc) GetProperty(ctx context.Context) (string, error) {
	return fn(ctx)
}
215
216// CredentialsOptions are used to configure [Credentials].
217type CredentialsOptions struct {
218 // TokenProvider is a means of sourcing a token for the credentials. Required.
219 TokenProvider TokenProvider
220 // JSON is the raw contents of the credentials file if sourced from a file.
221 JSON []byte
222 // ProjectIDProvider resolves the project ID associated with the
223 // credentials.
224 ProjectIDProvider CredentialsPropertyProvider
225 // QuotaProjectIDProvider resolves the quota project ID associated with the
226 // credentials.
227 QuotaProjectIDProvider CredentialsPropertyProvider
228 // UniverseDomainProvider resolves the universe domain with the credentials.
229 UniverseDomainProvider CredentialsPropertyProvider
230}
231
232// NewCredentials returns new [Credentials] from the provided options.
233func NewCredentials(opts *CredentialsOptions) *Credentials {
234 creds := &Credentials{
235 TokenProvider: opts.TokenProvider,
236 json: opts.JSON,
237 projectID: opts.ProjectIDProvider,
238 quotaProjectID: opts.QuotaProjectIDProvider,
239 universeDomain: opts.UniverseDomainProvider,
240 }
241
242 return creds
243}
244
// CachedTokenProviderOptions configures the cached [TokenProvider] created by
// [NewCachedTokenProvider].
type CachedTokenProviderOptions struct {
	// DisableAutoRefresh makes the TokenProvider keep returning the same
	// token, even after it has expired. Defaults to false. Optional.
	DisableAutoRefresh bool
	// ExpireEarly is how long before a token's expiry it should be
	// refreshed. When unset, 3 minutes and 45 seconds is used. Optional.
	ExpireEarly time.Duration
	// DisableAsyncRefresh selects a synchronous workflow in which token
	// refreshes block the caller. Defaults to false. Optional.
	DisableAsyncRefresh bool
}
259
260func (ctpo *CachedTokenProviderOptions) autoRefresh() bool {
261 if ctpo == nil {
262 return true
263 }
264 return !ctpo.DisableAutoRefresh
265}
266
267func (ctpo *CachedTokenProviderOptions) expireEarly() time.Duration {
268 if ctpo == nil || ctpo.ExpireEarly == 0 {
269 return defaultExpiryDelta
270 }
271 return ctpo.ExpireEarly
272}
273
274func (ctpo *CachedTokenProviderOptions) blockingRefresh() bool {
275 if ctpo == nil {
276 return false
277 }
278 return ctpo.DisableAsyncRefresh
279}
280
281// NewCachedTokenProvider wraps a [TokenProvider] to cache the tokens returned
282// by the underlying provider. By default it will refresh tokens asynchronously
283// a few minutes before they expire.
284func NewCachedTokenProvider(tp TokenProvider, opts *CachedTokenProviderOptions) TokenProvider {
285 if ctp, ok := tp.(*cachedTokenProvider); ok {
286 return ctp
287 }
288 return &cachedTokenProvider{
289 tp: tp,
290 autoRefresh: opts.autoRefresh(),
291 expireEarly: opts.expireEarly(),
292 blockingRefresh: opts.blockingRefresh(),
293 }
294}
295
296type cachedTokenProvider struct {
297 tp TokenProvider
298 autoRefresh bool
299 expireEarly time.Duration
300 blockingRefresh bool
301
302 mu sync.Mutex
303 cachedToken *Token
304 // isRefreshRunning ensures that the non-blocking refresh will only be
305 // attempted once, even if multiple callers enter the Token method.
306 isRefreshRunning bool
307 // isRefreshErr ensures that the non-blocking refresh will only be attempted
308 // once per refresh window if an error is encountered.
309 isRefreshErr bool
310}
311
312func (c *cachedTokenProvider) Token(ctx context.Context) (*Token, error) {
313 if c.blockingRefresh {
314 return c.tokenBlocking(ctx)
315 }
316 return c.tokenNonBlocking(ctx)
317}
318
319func (c *cachedTokenProvider) tokenNonBlocking(ctx context.Context) (*Token, error) {
320 switch c.tokenState() {
321 case fresh:
322 c.mu.Lock()
323 defer c.mu.Unlock()
324 return c.cachedToken, nil
325 case stale:
326 // Call tokenAsync with a new Context because the user-provided context
327 // may have a short timeout incompatible with async token refresh.
328 c.tokenAsync(context.Background())
329 // Return the stale token immediately to not block customer requests to Cloud services.
330 c.mu.Lock()
331 defer c.mu.Unlock()
332 return c.cachedToken, nil
333 default: // invalid
334 return c.tokenBlocking(ctx)
335 }
336}
337
338// tokenState reports the token's validity.
339func (c *cachedTokenProvider) tokenState() tokenState {
340 c.mu.Lock()
341 defer c.mu.Unlock()
342 t := c.cachedToken
343 now := timeNow()
344 if t == nil || t.Value == "" {
345 return invalid
346 } else if t.Expiry.IsZero() {
347 return fresh
348 } else if now.After(t.Expiry.Round(0)) {
349 return invalid
350 } else if now.After(t.Expiry.Round(0).Add(-c.expireEarly)) {
351 return stale
352 }
353 return fresh
354}
355
356// tokenAsync uses a bool to ensure that only one non-blocking token refresh
357// happens at a time, even if multiple callers have entered this function
358// concurrently. This avoids creating an arbitrary number of concurrent
359// goroutines. Retries should be attempted and managed within the Token method.
360// If the refresh attempt fails, no further attempts are made until the refresh
361// window expires and the token enters the invalid state, at which point the
362// blocking call to Token should likely return the same error on the main goroutine.
363func (c *cachedTokenProvider) tokenAsync(ctx context.Context) {
364 fn := func() {
365 c.mu.Lock()
366 c.isRefreshRunning = true
367 c.mu.Unlock()
368 t, err := c.tp.Token(ctx)
369 c.mu.Lock()
370 defer c.mu.Unlock()
371 c.isRefreshRunning = false
372 if err != nil {
373 // Discard errors from the non-blocking refresh, but prevent further
374 // attempts.
375 c.isRefreshErr = true
376 return
377 }
378 c.cachedToken = t
379 }
380 c.mu.Lock()
381 defer c.mu.Unlock()
382 if !c.isRefreshRunning && !c.isRefreshErr {
383 go fn()
384 }
385}
386
387func (c *cachedTokenProvider) tokenBlocking(ctx context.Context) (*Token, error) {
388 c.mu.Lock()
389 defer c.mu.Unlock()
390 c.isRefreshErr = false
391 if c.cachedToken.IsValid() || (!c.autoRefresh && !c.cachedToken.isEmpty()) {
392 return c.cachedToken, nil
393 }
394 t, err := c.tp.Token(ctx)
395 if err != nil {
396 return nil, err
397 }
398 c.cachedToken = t
399 return t, nil
400}
401
// Error is an error associated with retrieving a [Token]. It can hold useful
// additional details for debugging.
type Error struct {
	// Response is the HTTP response associated with error. The body will always
	// be already closed and consumed.
	Response *http.Response
	// Body is the HTTP response body.
	Body []byte
	// Err is the underlying wrapped error.
	Err error

	// code returned in the token response
	code string
	// description returned in the token response
	description string
	// uri returned in the token response
	uri string
}

// Error returns a human-readable description of the failure.
//
// When the token response carried an error code, the message is built from
// the quoted code followed by the quoted description and URI when present.
// Otherwise the HTTP status code and response body are reported. If there is
// no Response at all, the message falls back to the wrapped Err (if any)
// instead of dereferencing a nil Response.
func (e *Error) Error() string {
	if e.code != "" {
		s := fmt.Sprintf("auth: %q", e.code)
		if e.description != "" {
			s += fmt.Sprintf(" %q", e.description)
		}
		if e.uri != "" {
			s += fmt.Sprintf(" %q", e.uri)
		}
		return s
	}
	if e.Response == nil {
		if e.Err != nil {
			return fmt.Sprintf("auth: cannot fetch token: %v", e.Err)
		}
		return "auth: cannot fetch token"
	}
	return fmt.Sprintf("auth: cannot fetch token: %v\nResponse: %s", e.Response.StatusCode, e.Body)
}

// Temporary returns true if the error is considered temporary and may be able
// to be retried. Only errors with a Response whose status is 500, 503, 408 or
// 429 are considered temporary.
func (e *Error) Temporary() bool {
	if e.Response == nil {
		return false
	}
	switch e.Response.StatusCode {
	case http.StatusInternalServerError,
		http.StatusServiceUnavailable,
		http.StatusRequestTimeout,
		http.StatusTooManyRequests:
		return true
	}
	return false
}

// Unwrap returns the underlying wrapped error, which may be nil.
func (e *Error) Unwrap() error {
	return e.Err
}
448
// Style describes how the token endpoint expects to receive the ClientID and
// ClientSecret.
type Style int

const (
	// StyleUnknown means the value has not been initialized. Sending it in a
	// request will cause the token exchange to fail.
	StyleUnknown = Style(iota)
	// StyleInParams sends the client info in the body of a POST request.
	StyleInParams
	// StyleInHeader sends the client info using a Basic Authorization
	// header.
	StyleInHeader
)
462
// Options2LO holds the configuration for a 2-legged JWT OAuth2 flow.
type Options2LO struct {
	// Email is the OAuth2 client ID. It is used as the "iss" claim of the
	// JWT. Required.
	Email string
	// PrivateKey holds either an RSA private key or the contents of a PEM
	// file containing a private key. It is used to sign the JWT. Required.
	PrivateKey []byte
	// TokenURL is the URL the JWT is sent to. Required.
	TokenURL string
	// PrivateKeyID is the ID of the key used to sign the JWT. It is used as
	// the "kid" in the JWT header. Optional.
	PrivateKeyID string
	// Subject is the user to impersonate. It is used as the "sub" claim of
	// the JWT. Optional.
	Subject string
	// Scopes specifies the permissions requested for the token. Optional.
	Scopes []string
	// Expires specifies the lifetime of the token. Optional.
	Expires time.Duration
	// Audience specifies the "aud" claim of the JWT. When empty, TokenURL is
	// used as the audience. Optional.
	Audience string
	// PrivateClaims allows specifying any custom claims for the JWT.
	// Optional.
	PrivateClaims map[string]interface{}

	// Client is the HTTP client used to make the underlying token requests.
	// Optional.
	Client *http.Client
	// UseIDToken requests that the returned token be an ID token, if the
	// server returns one. Optional.
	UseIDToken bool
	// Logger is used for debug logging. If provided, logging is enabled at
	// the logger's configured level. By default logging is disabled unless
	// GOOGLE_SDK_GO_LOGGING_LEVEL is set, in which case a default logger is
	// used. Optional.
	Logger *slog.Logger
}
501
502func (o *Options2LO) client() *http.Client {
503 if o.Client != nil {
504 return o.Client
505 }
506 return internal.DefaultClient()
507}
508
509func (o *Options2LO) validate() error {
510 if o == nil {
511 return errors.New("auth: options must be provided")
512 }
513 if o.Email == "" {
514 return errors.New("auth: email must be provided")
515 }
516 if len(o.PrivateKey) == 0 {
517 return errors.New("auth: private key must be provided")
518 }
519 if o.TokenURL == "" {
520 return errors.New("auth: token URL must be provided")
521 }
522 return nil
523}
524
525// New2LOTokenProvider returns a [TokenProvider] from the provided options.
526func New2LOTokenProvider(opts *Options2LO) (TokenProvider, error) {
527 if err := opts.validate(); err != nil {
528 return nil, err
529 }
530 return tokenProvider2LO{opts: opts, Client: opts.client(), logger: internallog.New(opts.Logger)}, nil
531}
532
533type tokenProvider2LO struct {
534 opts *Options2LO
535 Client *http.Client
536 logger *slog.Logger
537}
538
539func (tp tokenProvider2LO) Token(ctx context.Context) (*Token, error) {
540 pk, err := internal.ParseKey(tp.opts.PrivateKey)
541 if err != nil {
542 return nil, err
543 }
544 claimSet := &jwt.Claims{
545 Iss: tp.opts.Email,
546 Scope: strings.Join(tp.opts.Scopes, " "),
547 Aud: tp.opts.TokenURL,
548 AdditionalClaims: tp.opts.PrivateClaims,
549 Sub: tp.opts.Subject,
550 }
551 if t := tp.opts.Expires; t > 0 {
552 claimSet.Exp = time.Now().Add(t).Unix()
553 }
554 if aud := tp.opts.Audience; aud != "" {
555 claimSet.Aud = aud
556 }
557 h := *defaultHeader
558 h.KeyID = tp.opts.PrivateKeyID
559 payload, err := jwt.EncodeJWS(&h, claimSet, pk)
560 if err != nil {
561 return nil, err
562 }
563 v := url.Values{}
564 v.Set("grant_type", defaultGrantType)
565 v.Set("assertion", payload)
566 req, err := http.NewRequestWithContext(ctx, "POST", tp.opts.TokenURL, strings.NewReader(v.Encode()))
567 if err != nil {
568 return nil, err
569 }
570 req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
571 tp.logger.DebugContext(ctx, "2LO token request", "request", internallog.HTTPRequest(req, []byte(v.Encode())))
572 resp, body, err := internal.DoRequest(tp.Client, req)
573 if err != nil {
574 return nil, fmt.Errorf("auth: cannot fetch token: %w", err)
575 }
576 tp.logger.DebugContext(ctx, "2LO token response", "response", internallog.HTTPResponse(resp, body))
577 if c := resp.StatusCode; c < http.StatusOK || c >= http.StatusMultipleChoices {
578 return nil, &Error{
579 Response: resp,
580 Body: body,
581 }
582 }
583 // tokenRes is the JSON response body.
584 var tokenRes struct {
585 AccessToken string `json:"access_token"`
586 TokenType string `json:"token_type"`
587 IDToken string `json:"id_token"`
588 ExpiresIn int64 `json:"expires_in"`
589 }
590 if err := json.Unmarshal(body, &tokenRes); err != nil {
591 return nil, fmt.Errorf("auth: cannot fetch token: %w", err)
592 }
593 token := &Token{
594 Value: tokenRes.AccessToken,
595 Type: tokenRes.TokenType,
596 }
597 token.Metadata = make(map[string]interface{})
598 json.Unmarshal(body, &token.Metadata) // no error checks for optional fields
599
600 if secs := tokenRes.ExpiresIn; secs > 0 {
601 token.Expiry = time.Now().Add(time.Duration(secs) * time.Second)
602 }
603 if v := tokenRes.IDToken; v != "" {
604 // decode returned id token to get expiry
605 claimSet, err := jwt.DecodeJWS(v)
606 if err != nil {
607 return nil, fmt.Errorf("auth: error decoding JWT token: %w", err)
608 }
609 token.Expiry = time.Unix(claimSet.Exp, 0)
610 }
611 if tp.opts.UseIDToken {
612 if tokenRes.IDToken == "" {
613 return nil, fmt.Errorf("auth: response doesn't have JWT token")
614 }
615 token.Value = tokenRes.IDToken
616 }
617 return token, nil
618}