chained_token_credential.go

  1//go:build go1.18
  2// +build go1.18
  3
  4// Copyright (c) Microsoft Corporation. All rights reserved.
  5// Licensed under the MIT License.
  6
  7package azidentity
  8
  9import (
 10	"context"
 11	"errors"
 12	"fmt"
 13	"strings"
 14	"sync"
 15
 16	"github.com/Azure/azure-sdk-for-go/sdk/azcore"
 17	"github.com/Azure/azure-sdk-for-go/sdk/azcore/policy"
 18	"github.com/Azure/azure-sdk-for-go/sdk/internal/log"
 19)
 20
 21// ChainedTokenCredentialOptions contains optional parameters for ChainedTokenCredential.
 22type ChainedTokenCredentialOptions struct {
 23	// RetrySources configures how the credential uses its sources. When true, the credential always attempts to
 24	// authenticate through each source in turn, stopping when one succeeds. When false, the credential authenticates
 25	// only through this first successful source--it never again tries the sources which failed.
 26	RetrySources bool
 27}
 28
 29// ChainedTokenCredential links together multiple credentials and tries them sequentially when authenticating. By default,
 30// it tries all the credentials until one authenticates, after which it always uses that credential.
 31type ChainedTokenCredential struct {
 32	cond                 *sync.Cond
 33	iterating            bool
 34	name                 string
 35	retrySources         bool
 36	sources              []azcore.TokenCredential
 37	successfulCredential azcore.TokenCredential
 38}
 39
 40// NewChainedTokenCredential creates a ChainedTokenCredential. Pass nil for options to accept defaults.
 41func NewChainedTokenCredential(sources []azcore.TokenCredential, options *ChainedTokenCredentialOptions) (*ChainedTokenCredential, error) {
 42	if len(sources) == 0 {
 43		return nil, errors.New("sources must contain at least one TokenCredential")
 44	}
 45	for _, source := range sources {
 46		if source == nil { // cannot have a nil credential in the chain or else the application will panic when GetToken() is called on nil
 47			return nil, errors.New("sources cannot contain nil")
 48		}
 49	}
 50	cp := make([]azcore.TokenCredential, len(sources))
 51	copy(cp, sources)
 52	if options == nil {
 53		options = &ChainedTokenCredentialOptions{}
 54	}
 55	return &ChainedTokenCredential{
 56		cond:         sync.NewCond(&sync.Mutex{}),
 57		name:         "ChainedTokenCredential",
 58		retrySources: options.RetrySources,
 59		sources:      cp,
 60	}, nil
 61}
 62
 63// GetToken calls GetToken on the chained credentials in turn, stopping when one returns a token.
 64// This method is called automatically by Azure SDK clients.
 65func (c *ChainedTokenCredential) GetToken(ctx context.Context, opts policy.TokenRequestOptions) (azcore.AccessToken, error) {
 66	if !c.retrySources {
 67		// ensure only one goroutine at a time iterates the sources and perhaps sets c.successfulCredential
 68		c.cond.L.Lock()
 69		for {
 70			if c.successfulCredential != nil {
 71				c.cond.L.Unlock()
 72				return c.successfulCredential.GetToken(ctx, opts)
 73			}
 74			if !c.iterating {
 75				c.iterating = true
 76				// allow other goroutines to wait while this one iterates
 77				c.cond.L.Unlock()
 78				break
 79			}
 80			c.cond.Wait()
 81		}
 82	}
 83
 84	var (
 85		err                  error
 86		errs                 []error
 87		successfulCredential azcore.TokenCredential
 88		token                azcore.AccessToken
 89		unavailableErr       credentialUnavailable
 90	)
 91	for _, cred := range c.sources {
 92		token, err = cred.GetToken(ctx, opts)
 93		if err == nil {
 94			log.Writef(EventAuthentication, "%s authenticated with %s", c.name, extractCredentialName(cred))
 95			successfulCredential = cred
 96			break
 97		}
 98		errs = append(errs, err)
 99		// continue to the next source iff this one returned credentialUnavailableError
100		if !errors.As(err, &unavailableErr) {
101			break
102		}
103	}
104	if c.iterating {
105		c.cond.L.Lock()
106		// this is nil when all credentials returned an error
107		c.successfulCredential = successfulCredential
108		c.iterating = false
109		c.cond.L.Unlock()
110		c.cond.Broadcast()
111	}
112	// err is the error returned by the last GetToken call. It will be nil when that call succeeds
113	if err != nil {
114		// return credentialUnavailableError iff all sources did so; return AuthenticationFailedError otherwise
115		msg := createChainedErrorMessage(errs)
116		if errors.As(err, &unavailableErr) {
117			err = newCredentialUnavailableError(c.name, msg)
118		} else {
119			res := getResponseFromError(err)
120			err = newAuthenticationFailedError(c.name, msg, res, err)
121		}
122	}
123	return token, err
124}
125
// createChainedErrorMessage builds the message for the error ChainedTokenCredential returns when no source
// produced a token: a fixed header followed by each error's message on its own tab-indented line, in the
// order given.
func createChainedErrorMessage(errs []error) string {
	// strings.Builder avoids re-copying the growing message on every iteration, as += would
	var b strings.Builder
	b.WriteString("failed to acquire a token.\nAttempted credentials:")
	for _, err := range errs {
		b.WriteString("\n\t")
		b.WriteString(err.Error())
	}
	return b.String()
}
133
134func extractCredentialName(credential azcore.TokenCredential) string {
135	return strings.TrimPrefix(fmt.Sprintf("%T", credential), "*azidentity.")
136}
137
138var _ azcore.TokenCredential = (*ChainedTokenCredential)(nil)