1//go:build go1.18
2// +build go1.18
3
4// Copyright (c) Microsoft Corporation. All rights reserved.
5// Licensed under the MIT License.
6
7package azidentity
8
9import (
10 "context"
11 "errors"
12 "fmt"
13 "net/http"
14 "strings"
15 "sync"
16
17 "github.com/Azure/azure-sdk-for-go/sdk/azcore"
18 "github.com/Azure/azure-sdk-for-go/sdk/azcore/cloud"
19 "github.com/Azure/azure-sdk-for-go/sdk/azcore/policy"
20 "github.com/Azure/azure-sdk-for-go/sdk/azcore/runtime"
21 "github.com/Azure/azure-sdk-for-go/sdk/azidentity/internal"
22 "github.com/Azure/azure-sdk-for-go/sdk/internal/log"
23 "github.com/AzureAD/microsoft-authentication-library-for-go/apps/public"
24
25 // this import ensures well-known configurations in azcore/cloud have ARM audiences for Authenticate()
26 _ "github.com/Azure/azure-sdk-for-go/sdk/azcore/arm/runtime"
27)
28
29type publicClientOptions struct {
30 azcore.ClientOptions
31
32 AdditionallyAllowedTenants []string
33 DeviceCodePrompt func(context.Context, DeviceCodeMessage) error
34 DisableAutomaticAuthentication bool
35 DisableInstanceDiscovery bool
36 LoginHint, RedirectURL string
37 Record authenticationRecord
38 TokenCachePersistenceOptions *tokenCachePersistenceOptions
39 Username, Password string
40}
41
42// publicClient wraps the MSAL public client
43type publicClient struct {
44 cae, noCAE msalPublicClient
45 caeMu, noCAEMu, clientMu *sync.Mutex
46 clientID, tenantID string
47 defaultScope []string
48 host string
49 name string
50 opts publicClientOptions
51 record authenticationRecord
52 azClient *azcore.Client
53}
54
// errScopeRequired is returned by Authenticate() when the caller specifies no scope and the
// client has no default scope to fall back on.
var errScopeRequired = errors.New("authenticating in this environment requires specifying a scope in TokenRequestOptions")
56
57func newPublicClient(tenantID, clientID, name string, o publicClientOptions) (*publicClient, error) {
58 if !validTenantID(tenantID) {
59 return nil, errInvalidTenantID
60 }
61 host, err := setAuthorityHost(o.Cloud)
62 if err != nil {
63 return nil, err
64 }
65 // if the application specified a cloud configuration, use its ARM audience as the default scope for Authenticate()
66 audience := o.Cloud.Services[cloud.ResourceManager].Audience
67 if audience == "" {
68 // no cloud configuration, or no ARM audience, specified; try to map the host to a well-known one (all of which have a trailing slash)
69 if !strings.HasSuffix(host, "/") {
70 host += "/"
71 }
72 switch host {
73 case cloud.AzureChina.ActiveDirectoryAuthorityHost:
74 audience = cloud.AzureChina.Services[cloud.ResourceManager].Audience
75 case cloud.AzureGovernment.ActiveDirectoryAuthorityHost:
76 audience = cloud.AzureGovernment.Services[cloud.ResourceManager].Audience
77 case cloud.AzurePublic.ActiveDirectoryAuthorityHost:
78 audience = cloud.AzurePublic.Services[cloud.ResourceManager].Audience
79 }
80 }
81 // if we didn't come up with an audience, the application will have to specify a scope for Authenticate()
82 var defaultScope []string
83 if audience != "" {
84 defaultScope = []string{audience + defaultSuffix}
85 }
86 client, err := azcore.NewClient(module, version, runtime.PipelineOptions{
87 Tracing: runtime.TracingOptions{
88 Namespace: traceNamespace,
89 },
90 }, &o.ClientOptions)
91 if err != nil {
92 return nil, err
93 }
94 o.AdditionallyAllowedTenants = resolveAdditionalTenants(o.AdditionallyAllowedTenants)
95 return &publicClient{
96 caeMu: &sync.Mutex{},
97 clientID: clientID,
98 clientMu: &sync.Mutex{},
99 defaultScope: defaultScope,
100 host: host,
101 name: name,
102 noCAEMu: &sync.Mutex{},
103 opts: o,
104 record: o.Record,
105 tenantID: tenantID,
106 azClient: client,
107 }, nil
108}
109
110func (p *publicClient) Authenticate(ctx context.Context, tro *policy.TokenRequestOptions) (authenticationRecord, error) {
111 if tro == nil {
112 tro = &policy.TokenRequestOptions{}
113 }
114 if len(tro.Scopes) == 0 {
115 if p.defaultScope == nil {
116 return authenticationRecord{}, errScopeRequired
117 }
118 tro.Scopes = p.defaultScope
119 }
120 client, mu, err := p.client(*tro)
121 if err != nil {
122 return authenticationRecord{}, err
123 }
124 mu.Lock()
125 defer mu.Unlock()
126 _, err = p.reqToken(ctx, client, *tro)
127 if err == nil {
128 scope := strings.Join(tro.Scopes, ", ")
129 msg := fmt.Sprintf("%s.Authenticate() acquired a token for scope %q", p.name, scope)
130 log.Write(EventAuthentication, msg)
131 }
132 return p.record, err
133}
134
135// GetToken requests an access token from MSAL, checking the cache first.
136func (p *publicClient) GetToken(ctx context.Context, tro policy.TokenRequestOptions) (azcore.AccessToken, error) {
137 if len(tro.Scopes) < 1 {
138 return azcore.AccessToken{}, fmt.Errorf("%s.GetToken() requires at least one scope", p.name)
139 }
140 tenant, err := p.resolveTenant(tro.TenantID)
141 if err != nil {
142 return azcore.AccessToken{}, err
143 }
144 client, mu, err := p.client(tro)
145 if err != nil {
146 return azcore.AccessToken{}, err
147 }
148 mu.Lock()
149 defer mu.Unlock()
150 ar, err := client.AcquireTokenSilent(ctx, tro.Scopes, public.WithSilentAccount(p.record.account()), public.WithClaims(tro.Claims), public.WithTenantID(tenant))
151 if err == nil {
152 return p.token(ar, err)
153 }
154 if p.opts.DisableAutomaticAuthentication {
155 return azcore.AccessToken{}, newauthenticationRequiredError(p.name, tro)
156 }
157 at, err := p.reqToken(ctx, client, tro)
158 if err == nil {
159 msg := fmt.Sprintf("%s.GetToken() acquired a token for scope %q", p.name, strings.Join(ar.GrantedScopes, ", "))
160 log.Write(EventAuthentication, msg)
161 }
162 return at, err
163}
164
165// reqToken requests a token from the MSAL public client. It's separate from GetToken() to enable Authenticate() to bypass the cache.
166func (p *publicClient) reqToken(ctx context.Context, c msalPublicClient, tro policy.TokenRequestOptions) (azcore.AccessToken, error) {
167 tenant, err := p.resolveTenant(tro.TenantID)
168 if err != nil {
169 return azcore.AccessToken{}, err
170 }
171 var ar public.AuthResult
172 switch p.name {
173 case credNameBrowser:
174 ar, err = c.AcquireTokenInteractive(ctx, tro.Scopes,
175 public.WithClaims(tro.Claims),
176 public.WithLoginHint(p.opts.LoginHint),
177 public.WithRedirectURI(p.opts.RedirectURL),
178 public.WithTenantID(tenant),
179 )
180 case credNameDeviceCode:
181 dc, e := c.AcquireTokenByDeviceCode(ctx, tro.Scopes, public.WithClaims(tro.Claims), public.WithTenantID(tenant))
182 if e != nil {
183 return azcore.AccessToken{}, e
184 }
185 err = p.opts.DeviceCodePrompt(ctx, DeviceCodeMessage{
186 Message: dc.Result.Message,
187 UserCode: dc.Result.UserCode,
188 VerificationURL: dc.Result.VerificationURL,
189 })
190 if err == nil {
191 ar, err = dc.AuthenticationResult(ctx)
192 }
193 case credNameUserPassword:
194 ar, err = c.AcquireTokenByUsernamePassword(ctx, tro.Scopes, p.opts.Username, p.opts.Password, public.WithClaims(tro.Claims), public.WithTenantID(tenant))
195 default:
196 return azcore.AccessToken{}, fmt.Errorf("unknown credential %q", p.name)
197 }
198 return p.token(ar, err)
199}
200
201func (p *publicClient) client(tro policy.TokenRequestOptions) (msalPublicClient, *sync.Mutex, error) {
202 p.clientMu.Lock()
203 defer p.clientMu.Unlock()
204 if tro.EnableCAE {
205 if p.cae == nil {
206 client, err := p.newMSALClient(true)
207 if err != nil {
208 return nil, nil, err
209 }
210 p.cae = client
211 }
212 return p.cae, p.caeMu, nil
213 }
214 if p.noCAE == nil {
215 client, err := p.newMSALClient(false)
216 if err != nil {
217 return nil, nil, err
218 }
219 p.noCAE = client
220 }
221 return p.noCAE, p.noCAEMu, nil
222}
223
224func (p *publicClient) newMSALClient(enableCAE bool) (msalPublicClient, error) {
225 cache, err := internal.NewCache(p.opts.TokenCachePersistenceOptions, enableCAE)
226 if err != nil {
227 return nil, err
228 }
229 o := []public.Option{
230 public.WithAuthority(runtime.JoinPaths(p.host, p.tenantID)),
231 public.WithCache(cache),
232 public.WithHTTPClient(p),
233 }
234 if enableCAE {
235 o = append(o, public.WithClientCapabilities(cp1))
236 }
237 if p.opts.DisableInstanceDiscovery || strings.ToLower(p.tenantID) == "adfs" {
238 o = append(o, public.WithInstanceDiscovery(false))
239 }
240 return public.New(p.clientID, o...)
241}
242
243func (p *publicClient) token(ar public.AuthResult, err error) (azcore.AccessToken, error) {
244 if err == nil {
245 p.record, err = newAuthenticationRecord(ar)
246 } else {
247 res := getResponseFromError(err)
248 err = newAuthenticationFailedError(p.name, err.Error(), res, err)
249 }
250 return azcore.AccessToken{Token: ar.AccessToken, ExpiresOn: ar.ExpiresOn.UTC()}, err
251}
252
253// resolveTenant returns the correct WithTenantID() argument for a token request given the client's
254// configuration, or an error when that configuration doesn't allow the specified tenant
255func (p *publicClient) resolveTenant(specified string) (string, error) {
256 t, err := resolveTenant(p.tenantID, specified, p.name, p.opts.AdditionallyAllowedTenants)
257 if t == p.tenantID {
258 // callers pass this value to MSAL's WithTenantID(). There's no need to redundantly specify
259 // the client's default tenant and doing so is an error when that tenant is "organizations"
260 t = ""
261 }
262 return t, err
263}
264
265// these methods satisfy the MSAL ops.HTTPClient interface
266
// CloseIdleConnections is a no-op.
func (p *publicClient) CloseIdleConnections() {
	// do nothing
}
270
// Do sends r by passing it, together with the client's azcore.Client, to doForClient and
// returns that function's response and error.
func (p *publicClient) Do(r *http.Request) (*http.Response, error) {
	return doForClient(p.azClient, r)
}