1//go:build go1.18
2// +build go1.18
3
4// Copyright (c) Microsoft Corporation. All rights reserved.
5// Licensed under the MIT License.
6
7package azidentity
8
9import (
10 "context"
11 "errors"
12 "fmt"
13 "net/http"
14 "os"
15 "strings"
16 "sync"
17
18 "github.com/Azure/azure-sdk-for-go/sdk/azcore"
19 "github.com/Azure/azure-sdk-for-go/sdk/azcore/policy"
20 "github.com/Azure/azure-sdk-for-go/sdk/azcore/runtime"
21 "github.com/Azure/azure-sdk-for-go/sdk/azidentity/internal"
22 "github.com/Azure/azure-sdk-for-go/sdk/internal/log"
23 "github.com/AzureAD/microsoft-authentication-library-for-go/apps/confidential"
24)
25
26type confidentialClientOptions struct {
27 azcore.ClientOptions
28
29 AdditionallyAllowedTenants []string
30 // Assertion for on-behalf-of authentication
31 Assertion string
32 DisableInstanceDiscovery, SendX5C bool
33 tokenCachePersistenceOptions *tokenCachePersistenceOptions
34}
35
36// confidentialClient wraps the MSAL confidential client
37type confidentialClient struct {
38 cae, noCAE msalConfidentialClient
39 caeMu, noCAEMu, clientMu *sync.Mutex
40 clientID, tenantID string
41 cred confidential.Credential
42 host string
43 name string
44 opts confidentialClientOptions
45 region string
46 azClient *azcore.Client
47}
48
49func newConfidentialClient(tenantID, clientID, name string, cred confidential.Credential, opts confidentialClientOptions) (*confidentialClient, error) {
50 if !validTenantID(tenantID) {
51 return nil, errInvalidTenantID
52 }
53 host, err := setAuthorityHost(opts.Cloud)
54 if err != nil {
55 return nil, err
56 }
57 client, err := azcore.NewClient(module, version, runtime.PipelineOptions{
58 Tracing: runtime.TracingOptions{
59 Namespace: traceNamespace,
60 },
61 }, &opts.ClientOptions)
62 if err != nil {
63 return nil, err
64 }
65 opts.AdditionallyAllowedTenants = resolveAdditionalTenants(opts.AdditionallyAllowedTenants)
66 return &confidentialClient{
67 caeMu: &sync.Mutex{},
68 clientID: clientID,
69 clientMu: &sync.Mutex{},
70 cred: cred,
71 host: host,
72 name: name,
73 noCAEMu: &sync.Mutex{},
74 opts: opts,
75 region: os.Getenv(azureRegionalAuthorityName),
76 tenantID: tenantID,
77 azClient: client,
78 }, nil
79}
80
81// GetToken requests an access token from MSAL, checking the cache first.
82func (c *confidentialClient) GetToken(ctx context.Context, tro policy.TokenRequestOptions) (azcore.AccessToken, error) {
83 if len(tro.Scopes) < 1 {
84 return azcore.AccessToken{}, fmt.Errorf("%s.GetToken() requires at least one scope", c.name)
85 }
86 // we don't resolve the tenant for managed identities because they acquire tokens only from their home tenants
87 if c.name != credNameManagedIdentity {
88 tenant, err := c.resolveTenant(tro.TenantID)
89 if err != nil {
90 return azcore.AccessToken{}, err
91 }
92 tro.TenantID = tenant
93 }
94 client, mu, err := c.client(tro)
95 if err != nil {
96 return azcore.AccessToken{}, err
97 }
98 mu.Lock()
99 defer mu.Unlock()
100 var ar confidential.AuthResult
101 if c.opts.Assertion != "" {
102 ar, err = client.AcquireTokenOnBehalfOf(ctx, c.opts.Assertion, tro.Scopes, confidential.WithClaims(tro.Claims), confidential.WithTenantID(tro.TenantID))
103 } else {
104 ar, err = client.AcquireTokenSilent(ctx, tro.Scopes, confidential.WithClaims(tro.Claims), confidential.WithTenantID(tro.TenantID))
105 if err != nil {
106 ar, err = client.AcquireTokenByCredential(ctx, tro.Scopes, confidential.WithClaims(tro.Claims), confidential.WithTenantID(tro.TenantID))
107 }
108 }
109 if err != nil {
110 // We could get a credentialUnavailableError from managed identity authentication because in that case the error comes from our code.
111 // We return it directly because it affects the behavior of credential chains. Otherwise, we return AuthenticationFailedError.
112 var unavailableErr credentialUnavailable
113 if !errors.As(err, &unavailableErr) {
114 res := getResponseFromError(err)
115 err = newAuthenticationFailedError(c.name, err.Error(), res, err)
116 }
117 } else {
118 msg := fmt.Sprintf("%s.GetToken() acquired a token for scope %q", c.name, strings.Join(ar.GrantedScopes, ", "))
119 log.Write(EventAuthentication, msg)
120 }
121 return azcore.AccessToken{Token: ar.AccessToken, ExpiresOn: ar.ExpiresOn.UTC()}, err
122}
123
124func (c *confidentialClient) client(tro policy.TokenRequestOptions) (msalConfidentialClient, *sync.Mutex, error) {
125 c.clientMu.Lock()
126 defer c.clientMu.Unlock()
127 if tro.EnableCAE {
128 if c.cae == nil {
129 client, err := c.newMSALClient(true)
130 if err != nil {
131 return nil, nil, err
132 }
133 c.cae = client
134 }
135 return c.cae, c.caeMu, nil
136 }
137 if c.noCAE == nil {
138 client, err := c.newMSALClient(false)
139 if err != nil {
140 return nil, nil, err
141 }
142 c.noCAE = client
143 }
144 return c.noCAE, c.noCAEMu, nil
145}
146
147func (c *confidentialClient) newMSALClient(enableCAE bool) (msalConfidentialClient, error) {
148 cache, err := internal.NewCache(c.opts.tokenCachePersistenceOptions, enableCAE)
149 if err != nil {
150 return nil, err
151 }
152 authority := runtime.JoinPaths(c.host, c.tenantID)
153 o := []confidential.Option{
154 confidential.WithAzureRegion(c.region),
155 confidential.WithCache(cache),
156 confidential.WithHTTPClient(c),
157 }
158 if enableCAE {
159 o = append(o, confidential.WithClientCapabilities(cp1))
160 }
161 if c.opts.SendX5C {
162 o = append(o, confidential.WithX5C())
163 }
164 if c.opts.DisableInstanceDiscovery || strings.ToLower(c.tenantID) == "adfs" {
165 o = append(o, confidential.WithInstanceDiscovery(false))
166 }
167 return confidential.New(authority, c.clientID, c.cred, o...)
168}
169
170// resolveTenant returns the correct WithTenantID() argument for a token request given the client's
171// configuration, or an error when that configuration doesn't allow the specified tenant
172func (c *confidentialClient) resolveTenant(specified string) (string, error) {
173 return resolveTenant(c.tenantID, specified, c.name, c.opts.AdditionallyAllowedTenants)
174}
175
176// these methods satisfy the MSAL ops.HTTPClient interface
177
178func (c *confidentialClient) CloseIdleConnections() {
179 // do nothing
180}
181
182func (c *confidentialClient) Do(r *http.Request) (*http.Response, error) {
183 return doForClient(c.azClient, r)
184}