auth.go

  1// Copyright 2023 Google LLC
  2//
  3// Licensed under the Apache License, Version 2.0 (the "License");
  4// you may not use this file except in compliance with the License.
  5// You may obtain a copy of the License at
  6//
  7//      http://www.apache.org/licenses/LICENSE-2.0
  8//
  9// Unless required by applicable law or agreed to in writing, software
 10// distributed under the License is distributed on an "AS IS" BASIS,
 11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 12// See the License for the specific language governing permissions and
 13// limitations under the License.
 14
 15// Package auth provides utilities for managing Google Cloud credentials,
 16// including functionality for creating, caching, and refreshing OAuth2 tokens.
 17// It offers customizable options for different OAuth2 flows, such as 2-legged
 18// (2LO) and 3-legged (3LO) OAuth, along with support for PKCE and automatic
 19// token management.
 20package auth
 21
 22import (
 23	"context"
 24	"encoding/json"
 25	"errors"
 26	"fmt"
 27	"log/slog"
 28	"net/http"
 29	"net/url"
 30	"strings"
 31	"sync"
 32	"time"
 33
 34	"cloud.google.com/go/auth/internal"
 35	"cloud.google.com/go/auth/internal/jwt"
 36	"github.com/googleapis/gax-go/v2/internallog"
 37)
 38
 39const (
 40	// Parameter keys for AuthCodeURL method to support PKCE.
 41	codeChallengeKey       = "code_challenge"
 42	codeChallengeMethodKey = "code_challenge_method"
 43
 44	// Parameter key for Exchange method to support PKCE.
 45	codeVerifierKey = "code_verifier"
 46
 47	// 3 minutes and 45 seconds before expiration. The shortest MDS cache is 4 minutes,
 48	// so we give it 15 seconds to refresh it's cache before attempting to refresh a token.
 49	defaultExpiryDelta = 225 * time.Second
 50
 51	universeDomainDefault = "googleapis.com"
 52)
 53
 54// tokenState represents different states for a [Token].
 55type tokenState int
 56
 57const (
 58	// fresh indicates that the [Token] is valid. It is not expired or close to
 59	// expired, or the token has no expiry.
 60	fresh tokenState = iota
 61	// stale indicates that the [Token] is close to expired, and should be
 62	// refreshed. The token can be used normally.
 63	stale
 64	// invalid indicates that the [Token] is expired or invalid. The token
 65	// cannot be used for a normal operation.
 66	invalid
 67)
 68
 69var (
 70	defaultGrantType = "urn:ietf:params:oauth:grant-type:jwt-bearer"
 71	defaultHeader    = &jwt.Header{Algorithm: jwt.HeaderAlgRSA256, Type: jwt.HeaderType}
 72
 73	// for testing
 74	timeNow = time.Now
 75)
 76
 77// TokenProvider specifies an interface for anything that can return a token.
 78type TokenProvider interface {
 79	// Token returns a Token or an error.
 80	// The Token returned must be safe to use
 81	// concurrently.
 82	// The returned Token must not be modified.
 83	// The context provided must be sent along to any requests that are made in
 84	// the implementing code.
 85	Token(context.Context) (*Token, error)
 86}
 87
 88// Token holds the credential token used to authorized requests. All fields are
 89// considered read-only.
 90type Token struct {
 91	// Value is the token used to authorize requests. It is usually an access
 92	// token but may be other types of tokens such as ID tokens in some flows.
 93	Value string
 94	// Type is the type of token Value is. If uninitialized, it should be
 95	// assumed to be a "Bearer" token.
 96	Type string
 97	// Expiry is the time the token is set to expire.
 98	Expiry time.Time
 99	// Metadata  may include, but is not limited to, the body of the token
100	// response returned by the server.
101	Metadata map[string]interface{} // TODO(codyoss): maybe make a method to flatten metadata to avoid []string for url.Values
102}
103
104// IsValid reports that a [Token] is non-nil, has a [Token.Value], and has not
105// expired. A token is considered expired if [Token.Expiry] has passed or will
106// pass in the next 225 seconds.
107func (t *Token) IsValid() bool {
108	return t.isValidWithEarlyExpiry(defaultExpiryDelta)
109}
110
111// MetadataString is a convenience method for accessing string values in the
112// token's metadata. Returns an empty string if the metadata is nil or the value
113// for the given key cannot be cast to a string.
114func (t *Token) MetadataString(k string) string {
115	if t.Metadata == nil {
116		return ""
117	}
118	s, ok := t.Metadata[k].(string)
119	if !ok {
120		return ""
121	}
122	return s
123}
124
125func (t *Token) isValidWithEarlyExpiry(earlyExpiry time.Duration) bool {
126	if t.isEmpty() {
127		return false
128	}
129	if t.Expiry.IsZero() {
130		return true
131	}
132	return !t.Expiry.Round(0).Add(-earlyExpiry).Before(timeNow())
133}
134
135func (t *Token) isEmpty() bool {
136	return t == nil || t.Value == ""
137}
138
139// Credentials holds Google credentials, including
140// [Application Default Credentials].
141//
142// [Application Default Credentials]: https://developers.google.com/accounts/docs/application-default-credentials
143type Credentials struct {
144	json           []byte
145	projectID      CredentialsPropertyProvider
146	quotaProjectID CredentialsPropertyProvider
147	// universeDomain is the default service domain for a given Cloud universe.
148	universeDomain CredentialsPropertyProvider
149
150	TokenProvider
151}
152
153// JSON returns the bytes associated with the the file used to source
154// credentials if one was used.
155func (c *Credentials) JSON() []byte {
156	return c.json
157}
158
159// ProjectID returns the associated project ID from the underlying file or
160// environment.
161func (c *Credentials) ProjectID(ctx context.Context) (string, error) {
162	if c.projectID == nil {
163		return internal.GetProjectID(c.json, ""), nil
164	}
165	v, err := c.projectID.GetProperty(ctx)
166	if err != nil {
167		return "", err
168	}
169	return internal.GetProjectID(c.json, v), nil
170}
171
172// QuotaProjectID returns the associated quota project ID from the underlying
173// file or environment.
174func (c *Credentials) QuotaProjectID(ctx context.Context) (string, error) {
175	if c.quotaProjectID == nil {
176		return internal.GetQuotaProject(c.json, ""), nil
177	}
178	v, err := c.quotaProjectID.GetProperty(ctx)
179	if err != nil {
180		return "", err
181	}
182	return internal.GetQuotaProject(c.json, v), nil
183}
184
185// UniverseDomain returns the default service domain for a given Cloud universe.
186// The default value is "googleapis.com".
187func (c *Credentials) UniverseDomain(ctx context.Context) (string, error) {
188	if c.universeDomain == nil {
189		return universeDomainDefault, nil
190	}
191	v, err := c.universeDomain.GetProperty(ctx)
192	if err != nil {
193		return "", err
194	}
195	if v == "" {
196		return universeDomainDefault, nil
197	}
198	return v, err
199}
200
// CredentialsPropertyProvider provides an implementation to fetch a property
// value for [Credentials].
type CredentialsPropertyProvider interface {
	GetProperty(context.Context) (string, error)
}

// CredentialsPropertyFunc is a type adapter that lets an ordinary function be
// used as a [CredentialsPropertyProvider].
type CredentialsPropertyFunc func(context.Context) (string, error)

// GetProperty loads the property value by calling p with the given context.
func (p CredentialsPropertyFunc) GetProperty(ctx context.Context) (string, error) {
	value, err := p(ctx)
	return value, err
}
215
216// CredentialsOptions are used to configure [Credentials].
217type CredentialsOptions struct {
218	// TokenProvider is a means of sourcing a token for the credentials. Required.
219	TokenProvider TokenProvider
220	// JSON is the raw contents of the credentials file if sourced from a file.
221	JSON []byte
222	// ProjectIDProvider resolves the project ID associated with the
223	// credentials.
224	ProjectIDProvider CredentialsPropertyProvider
225	// QuotaProjectIDProvider resolves the quota project ID associated with the
226	// credentials.
227	QuotaProjectIDProvider CredentialsPropertyProvider
228	// UniverseDomainProvider resolves the universe domain with the credentials.
229	UniverseDomainProvider CredentialsPropertyProvider
230}
231
232// NewCredentials returns new [Credentials] from the provided options.
233func NewCredentials(opts *CredentialsOptions) *Credentials {
234	creds := &Credentials{
235		TokenProvider:  opts.TokenProvider,
236		json:           opts.JSON,
237		projectID:      opts.ProjectIDProvider,
238		quotaProjectID: opts.QuotaProjectIDProvider,
239		universeDomain: opts.UniverseDomainProvider,
240	}
241
242	return creds
243}
244
245// CachedTokenProviderOptions provides options for configuring a cached
246// [TokenProvider].
247type CachedTokenProviderOptions struct {
248	// DisableAutoRefresh makes the TokenProvider always return the same token,
249	// even if it is expired. The default is false. Optional.
250	DisableAutoRefresh bool
251	// ExpireEarly configures the amount of time before a token expires, that it
252	// should be refreshed. If unset, the default value is 3 minutes and 45
253	// seconds. Optional.
254	ExpireEarly time.Duration
255	// DisableAsyncRefresh configures a synchronous workflow that refreshes
256	// tokens in a blocking manner. The default is false. Optional.
257	DisableAsyncRefresh bool
258}
259
260func (ctpo *CachedTokenProviderOptions) autoRefresh() bool {
261	if ctpo == nil {
262		return true
263	}
264	return !ctpo.DisableAutoRefresh
265}
266
267func (ctpo *CachedTokenProviderOptions) expireEarly() time.Duration {
268	if ctpo == nil || ctpo.ExpireEarly == 0 {
269		return defaultExpiryDelta
270	}
271	return ctpo.ExpireEarly
272}
273
274func (ctpo *CachedTokenProviderOptions) blockingRefresh() bool {
275	if ctpo == nil {
276		return false
277	}
278	return ctpo.DisableAsyncRefresh
279}
280
281// NewCachedTokenProvider wraps a [TokenProvider] to cache the tokens returned
282// by the underlying provider. By default it will refresh tokens asynchronously
283// a few minutes before they expire.
284func NewCachedTokenProvider(tp TokenProvider, opts *CachedTokenProviderOptions) TokenProvider {
285	if ctp, ok := tp.(*cachedTokenProvider); ok {
286		return ctp
287	}
288	return &cachedTokenProvider{
289		tp:              tp,
290		autoRefresh:     opts.autoRefresh(),
291		expireEarly:     opts.expireEarly(),
292		blockingRefresh: opts.blockingRefresh(),
293	}
294}
295
296type cachedTokenProvider struct {
297	tp              TokenProvider
298	autoRefresh     bool
299	expireEarly     time.Duration
300	blockingRefresh bool
301
302	mu          sync.Mutex
303	cachedToken *Token
304	// isRefreshRunning ensures that the non-blocking refresh will only be
305	// attempted once, even if multiple callers enter the Token method.
306	isRefreshRunning bool
307	// isRefreshErr ensures that the non-blocking refresh will only be attempted
308	// once per refresh window if an error is encountered.
309	isRefreshErr bool
310}
311
312func (c *cachedTokenProvider) Token(ctx context.Context) (*Token, error) {
313	if c.blockingRefresh {
314		return c.tokenBlocking(ctx)
315	}
316	return c.tokenNonBlocking(ctx)
317}
318
319func (c *cachedTokenProvider) tokenNonBlocking(ctx context.Context) (*Token, error) {
320	switch c.tokenState() {
321	case fresh:
322		c.mu.Lock()
323		defer c.mu.Unlock()
324		return c.cachedToken, nil
325	case stale:
326		// Call tokenAsync with a new Context because the user-provided context
327		// may have a short timeout incompatible with async token refresh.
328		c.tokenAsync(context.Background())
329		// Return the stale token immediately to not block customer requests to Cloud services.
330		c.mu.Lock()
331		defer c.mu.Unlock()
332		return c.cachedToken, nil
333	default: // invalid
334		return c.tokenBlocking(ctx)
335	}
336}
337
338// tokenState reports the token's validity.
339func (c *cachedTokenProvider) tokenState() tokenState {
340	c.mu.Lock()
341	defer c.mu.Unlock()
342	t := c.cachedToken
343	now := timeNow()
344	if t == nil || t.Value == "" {
345		return invalid
346	} else if t.Expiry.IsZero() {
347		return fresh
348	} else if now.After(t.Expiry.Round(0)) {
349		return invalid
350	} else if now.After(t.Expiry.Round(0).Add(-c.expireEarly)) {
351		return stale
352	}
353	return fresh
354}
355
356// tokenAsync uses a bool to ensure that only one non-blocking token refresh
357// happens at a time, even if multiple callers have entered this function
358// concurrently. This avoids creating an arbitrary number of concurrent
359// goroutines. Retries should be attempted and managed within the Token method.
360// If the refresh attempt fails, no further attempts are made until the refresh
361// window expires and the token enters the invalid state, at which point the
362// blocking call to Token should likely return the same error on the main goroutine.
363func (c *cachedTokenProvider) tokenAsync(ctx context.Context) {
364	fn := func() {
365		c.mu.Lock()
366		c.isRefreshRunning = true
367		c.mu.Unlock()
368		t, err := c.tp.Token(ctx)
369		c.mu.Lock()
370		defer c.mu.Unlock()
371		c.isRefreshRunning = false
372		if err != nil {
373			// Discard errors from the non-blocking refresh, but prevent further
374			// attempts.
375			c.isRefreshErr = true
376			return
377		}
378		c.cachedToken = t
379	}
380	c.mu.Lock()
381	defer c.mu.Unlock()
382	if !c.isRefreshRunning && !c.isRefreshErr {
383		go fn()
384	}
385}
386
387func (c *cachedTokenProvider) tokenBlocking(ctx context.Context) (*Token, error) {
388	c.mu.Lock()
389	defer c.mu.Unlock()
390	c.isRefreshErr = false
391	if c.cachedToken.IsValid() || (!c.autoRefresh && !c.cachedToken.isEmpty()) {
392		return c.cachedToken, nil
393	}
394	t, err := c.tp.Token(ctx)
395	if err != nil {
396		return nil, err
397	}
398	c.cachedToken = t
399	return t, nil
400}
401
// Error is an error associated with retrieving a [Token]. It can hold useful
// additional details for debugging.
type Error struct {
	// Response is the HTTP response associated with the error. The body will
	// always be already closed and consumed. It may be nil.
	Response *http.Response
	// Body is the HTTP response body.
	Body []byte
	// Err is the underlying wrapped error.
	Err error

	// code returned in the token response
	code string
	// description returned in the token response
	description string
	// uri returned in the token response
	uri string
}

// Error implements the error interface.
//
// If the token response carried an error code, the message reports the
// code, plus the description and URI when present. Otherwise it reports the
// HTTP status code and response body. If there is no Response, it reports
// the wrapped Err (if any) instead.
func (e *Error) Error() string {
	if e.code != "" {
		s := fmt.Sprintf("auth: %q", e.code)
		if e.description != "" {
			s += fmt.Sprintf(" %q", e.description)
		}
		if e.uri != "" {
			s += fmt.Sprintf(" %q", e.uri)
		}
		return s
	}
	if e.Response == nil {
		// Response is an exported field and Temporary already tolerates it
		// being nil. Avoid a nil-pointer panic when building the message.
		if e.Err != nil {
			return fmt.Sprintf("auth: cannot fetch token: %v", e.Err)
		}
		return "auth: cannot fetch token"
	}
	return fmt.Sprintf("auth: cannot fetch token: %v\nResponse: %s", e.Response.StatusCode, e.Body)
}

// Temporary returns true if the error is considered temporary and may be able
// to be retried. That is the case when the response status is 408, 429, 500
// or 503. It returns false when there is no Response.
func (e *Error) Temporary() bool {
	if e.Response == nil {
		return false
	}
	switch e.Response.StatusCode {
	case http.StatusInternalServerError, http.StatusServiceUnavailable,
		http.StatusRequestTimeout, http.StatusTooManyRequests:
		return true
	default:
		return false
	}
}

// Unwrap returns the underlying wrapped error, so errors.Is and errors.As
// can inspect it.
func (e *Error) Unwrap() error {
	return e.Err
}
448
// Style describes how the token endpoint expects to receive the ClientID and
// ClientSecret.
type Style int

const (
	// StyleUnknown is the zero value: the style has not been initialized.
	// Sending this in a request will cause the token exchange to fail.
	StyleUnknown Style = 0
	// StyleInParams sends the client info in the body of a POST request.
	StyleInParams Style = 1
	// StyleInHeader sends the client info using a Basic Authorization header.
	StyleInHeader Style = 2
)
462
// Options2LO holds the configuration for a 2-legged JWT OAuth2 flow.
type Options2LO struct {
	// Email is the OAuth2 client ID. It is set as the "iss" claim in the JWT.
	// Required.
	Email string
	// PrivateKey contains the contents of an RSA private key or the contents
	// of a PEM file that contains a private key. It is used to sign the JWT.
	// Required.
	PrivateKey []byte
	// TokenURL is the URL the JWT is sent to. Required.
	TokenURL string
	// PrivateKeyID is the ID of the key used to sign the JWT. It is used as
	// the "kid" in the JWT header. Optional.
	PrivateKeyID string
	// Subject is used to impersonate a user. It is used as the "sub" claim in
	// the JWT. Optional.
	Subject string
	// Scopes specifies the requested permissions for the token. Optional.
	Scopes []string
	// Expires specifies the lifetime of the token. Optional.
	Expires time.Duration
	// Audience specifies the "aud" claim in the JWT. When empty, TokenURL is
	// used. Optional.
	Audience string
	// PrivateClaims allows specifying any custom claims for the JWT. Optional.
	PrivateClaims map[string]any

	// Client is the client used to make the underlying token requests.
	// Optional.
	Client *http.Client
	// UseIDToken requests that the returned token be an ID token if one is
	// returned from the server. Optional.
	UseIDToken bool
	// Logger is used for debug logging. If provided, logging will be enabled
	// at the logger's configured level. By default logging is disabled unless
	// enabled by setting GOOGLE_SDK_GO_LOGGING_LEVEL, in which case a default
	// logger will be used. Optional.
	Logger *slog.Logger
}
501
502func (o *Options2LO) client() *http.Client {
503	if o.Client != nil {
504		return o.Client
505	}
506	return internal.DefaultClient()
507}
508
509func (o *Options2LO) validate() error {
510	if o == nil {
511		return errors.New("auth: options must be provided")
512	}
513	if o.Email == "" {
514		return errors.New("auth: email must be provided")
515	}
516	if len(o.PrivateKey) == 0 {
517		return errors.New("auth: private key must be provided")
518	}
519	if o.TokenURL == "" {
520		return errors.New("auth: token URL must be provided")
521	}
522	return nil
523}
524
525// New2LOTokenProvider returns a [TokenProvider] from the provided options.
526func New2LOTokenProvider(opts *Options2LO) (TokenProvider, error) {
527	if err := opts.validate(); err != nil {
528		return nil, err
529	}
530	return tokenProvider2LO{opts: opts, Client: opts.client(), logger: internallog.New(opts.Logger)}, nil
531}
532
533type tokenProvider2LO struct {
534	opts   *Options2LO
535	Client *http.Client
536	logger *slog.Logger
537}
538
539func (tp tokenProvider2LO) Token(ctx context.Context) (*Token, error) {
540	pk, err := internal.ParseKey(tp.opts.PrivateKey)
541	if err != nil {
542		return nil, err
543	}
544	claimSet := &jwt.Claims{
545		Iss:              tp.opts.Email,
546		Scope:            strings.Join(tp.opts.Scopes, " "),
547		Aud:              tp.opts.TokenURL,
548		AdditionalClaims: tp.opts.PrivateClaims,
549		Sub:              tp.opts.Subject,
550	}
551	if t := tp.opts.Expires; t > 0 {
552		claimSet.Exp = time.Now().Add(t).Unix()
553	}
554	if aud := tp.opts.Audience; aud != "" {
555		claimSet.Aud = aud
556	}
557	h := *defaultHeader
558	h.KeyID = tp.opts.PrivateKeyID
559	payload, err := jwt.EncodeJWS(&h, claimSet, pk)
560	if err != nil {
561		return nil, err
562	}
563	v := url.Values{}
564	v.Set("grant_type", defaultGrantType)
565	v.Set("assertion", payload)
566	req, err := http.NewRequestWithContext(ctx, "POST", tp.opts.TokenURL, strings.NewReader(v.Encode()))
567	if err != nil {
568		return nil, err
569	}
570	req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
571	tp.logger.DebugContext(ctx, "2LO token request", "request", internallog.HTTPRequest(req, []byte(v.Encode())))
572	resp, body, err := internal.DoRequest(tp.Client, req)
573	if err != nil {
574		return nil, fmt.Errorf("auth: cannot fetch token: %w", err)
575	}
576	tp.logger.DebugContext(ctx, "2LO token response", "response", internallog.HTTPResponse(resp, body))
577	if c := resp.StatusCode; c < http.StatusOK || c >= http.StatusMultipleChoices {
578		return nil, &Error{
579			Response: resp,
580			Body:     body,
581		}
582	}
583	// tokenRes is the JSON response body.
584	var tokenRes struct {
585		AccessToken string `json:"access_token"`
586		TokenType   string `json:"token_type"`
587		IDToken     string `json:"id_token"`
588		ExpiresIn   int64  `json:"expires_in"`
589	}
590	if err := json.Unmarshal(body, &tokenRes); err != nil {
591		return nil, fmt.Errorf("auth: cannot fetch token: %w", err)
592	}
593	token := &Token{
594		Value: tokenRes.AccessToken,
595		Type:  tokenRes.TokenType,
596	}
597	token.Metadata = make(map[string]interface{})
598	json.Unmarshal(body, &token.Metadata) // no error checks for optional fields
599
600	if secs := tokenRes.ExpiresIn; secs > 0 {
601		token.Expiry = time.Now().Add(time.Duration(secs) * time.Second)
602	}
603	if v := tokenRes.IDToken; v != "" {
604		// decode returned id token to get expiry
605		claimSet, err := jwt.DecodeJWS(v)
606		if err != nil {
607			return nil, fmt.Errorf("auth: error decoding JWT token: %w", err)
608		}
609		token.Expiry = time.Unix(claimSet.Exp, 0)
610	}
611	if tp.opts.UseIDToken {
612		if tokenRes.IDToken == "" {
613			return nil, fmt.Errorf("auth: response doesn't have JWT token")
614		}
615		token.Value = tokenRes.IDToken
616	}
617	return token, nil
618}