1//go:build go1.18
2// +build go1.18
3
4// Copyright (c) Microsoft Corporation. All rights reserved.
5// Licensed under the MIT License.
6
7package azidentity
8
9import (
10 "context"
11 "errors"
12 "fmt"
13 "strings"
14 "sync"
15
16 "github.com/Azure/azure-sdk-for-go/sdk/azcore"
17 "github.com/Azure/azure-sdk-for-go/sdk/azcore/policy"
18 "github.com/Azure/azure-sdk-for-go/sdk/internal/log"
19)
20
// ChainedTokenCredentialOptions contains optional parameters for ChainedTokenCredential.
type ChainedTokenCredentialOptions struct {
	// RetrySources configures how the credential uses its sources. When true, the credential always attempts to
	// authenticate through each source in turn, stopping when one succeeds. When false, the credential authenticates
	// only through the first successful source--once a source succeeds, it never again tries the others. If no
	// source succeeds, nothing is remembered and the next authentication attempt tries the sources again.
	//
	// In either mode, the credential stops trying further sources when one fails with an error indicating
	// something other than that the credential is unavailable.
	RetrySources bool
}
28
29// ChainedTokenCredential links together multiple credentials and tries them sequentially when authenticating. By default,
30// it tries all the credentials until one authenticates, after which it always uses that credential.
31type ChainedTokenCredential struct {
32 cond *sync.Cond
33 iterating bool
34 name string
35 retrySources bool
36 sources []azcore.TokenCredential
37 successfulCredential azcore.TokenCredential
38}
39
40// NewChainedTokenCredential creates a ChainedTokenCredential. Pass nil for options to accept defaults.
41func NewChainedTokenCredential(sources []azcore.TokenCredential, options *ChainedTokenCredentialOptions) (*ChainedTokenCredential, error) {
42 if len(sources) == 0 {
43 return nil, errors.New("sources must contain at least one TokenCredential")
44 }
45 for _, source := range sources {
46 if source == nil { // cannot have a nil credential in the chain or else the application will panic when GetToken() is called on nil
47 return nil, errors.New("sources cannot contain nil")
48 }
49 }
50 cp := make([]azcore.TokenCredential, len(sources))
51 copy(cp, sources)
52 if options == nil {
53 options = &ChainedTokenCredentialOptions{}
54 }
55 return &ChainedTokenCredential{
56 cond: sync.NewCond(&sync.Mutex{}),
57 name: "ChainedTokenCredential",
58 retrySources: options.RetrySources,
59 sources: cp,
60 }, nil
61}
62
// GetToken calls GetToken on the chained credentials in turn, stopping when one returns a token.
// This method is called automatically by Azure SDK clients.
//
// Sources are tried in order. Iteration stops at the first source that returns a token, or at the first
// source that fails with an error which is not a credentialUnavailable error. On failure, the returned
// error's message lists every collected error (via createChainedErrorMessage). That error is built by
// newCredentialUnavailableError when the last error was credentialUnavailable, meaning every source was
// tried and all were unavailable. Otherwise it is built by newAuthenticationFailedError.
//
// When c.retrySources is false, only one goroutine at a time iterates the sources. The first source to
// succeed is stored in c.successfulCredential, and every later call delegates directly to it. If all
// attempted sources fail, nothing is stored and the next call iterates again.
func (c *ChainedTokenCredential) GetToken(ctx context.Context, opts policy.TokenRequestOptions) (azcore.AccessToken, error) {
	if !c.retrySources {
		// ensure only one goroutine at a time iterates the sources and perhaps sets c.successfulCredential
		c.cond.L.Lock()
		for {
			if c.successfulCredential != nil {
				// an earlier call found a working source; use only that one from now on
				c.cond.L.Unlock()
				return c.successfulCredential.GetToken(ctx, opts)
			}
			if !c.iterating {
				// claim the right to iterate the sources
				c.iterating = true
				// allow other goroutines to wait while this one iterates
				c.cond.L.Unlock()
				break
			}
			// another goroutine is iterating; sleep until it broadcasts its result, then re-check.
			// NOTE(review): this wait does not observe ctx cancellation.
			c.cond.Wait()
		}
	}

	var (
		err                  error
		errs                 []error
		successfulCredential azcore.TokenCredential
		token                azcore.AccessToken
		unavailableErr       credentialUnavailable
	)
	for _, cred := range c.sources {
		token, err = cred.GetToken(ctx, opts)
		if err == nil {
			log.Writef(EventAuthentication, "%s authenticated with %s", c.name, extractCredentialName(cred))
			successfulCredential = cred
			break
		}
		errs = append(errs, err)
		// continue to the next source iff this one returned credentialUnavailableError
		if !errors.As(err, &unavailableErr) {
			break
		}
	}
	// c.iterating can only be true here for the goroutine that set it above. While it is set, every other
	// caller either returns or waits in the loop above, and nothing sets it when c.retrySources is true.
	// That is why this read doesn't take the lock.
	// NOTE(review): if a source's GetToken panics, this reset never runs and waiting goroutines block forever.
	if c.iterating {
		c.cond.L.Lock()
		// this is nil when all credentials returned an error
		c.successfulCredential = successfulCredential
		c.iterating = false
		c.cond.L.Unlock()
		// wake waiters so each either uses the stored credential or claims iteration itself
		c.cond.Broadcast()
	}
	// err is the error returned by the last GetToken call. It will be nil when that call succeeds
	if err != nil {
		// return credentialUnavailableError iff all sources did so; return AuthenticationFailedError otherwise
		msg := createChainedErrorMessage(errs)
		if errors.As(err, &unavailableErr) {
			err = newCredentialUnavailableError(c.name, msg)
		} else {
			// presumably extracts the failing source's HTTP response, if any (helper defined elsewhere)
			res := getResponseFromError(err)
			err = newAuthenticationFailedError(c.name, msg, res, err)
		}
	}
	return token, err
}
125
// createChainedErrorMessage builds the message for a failed chained authentication. The message is a
// fixed header followed by one newline-and-tab-prefixed line per element of errs, in order. Every
// element of errs must be non-nil, because its Error method is called.
func createChainedErrorMessage(errs []error) string {
	// strings.Builder avoids re-copying the whole message on every append, as += in a loop would.
	var b strings.Builder
	b.WriteString("failed to acquire a token.\nAttempted credentials:")
	for _, err := range errs {
		b.WriteString("\n\t")
		b.WriteString(err.Error())
	}
	return b.String()
}
133
134func extractCredentialName(credential azcore.TokenCredential) string {
135 return strings.TrimPrefix(fmt.Sprintf("%T", credential), "*azidentity.")
136}
137
138var _ azcore.TokenCredential = (*ChainedTokenCredential)(nil)