1//go:build go1.18
2// +build go1.18
3
4// Copyright (c) Microsoft Corporation. All rights reserved.
5// Licensed under the MIT License.
6
7package azidentity
8
9import (
10 "context"
11 "encoding/json"
12 "errors"
13 "fmt"
14 "net/http"
15 "net/url"
16 "os"
17 "path/filepath"
18 "runtime"
19 "strconv"
20 "strings"
21 "time"
22
23 "github.com/Azure/azure-sdk-for-go/sdk/azcore"
24 "github.com/Azure/azure-sdk-for-go/sdk/azcore/policy"
25 azruntime "github.com/Azure/azure-sdk-for-go/sdk/azcore/runtime"
26 "github.com/Azure/azure-sdk-for-go/sdk/azcore/streaming"
27 "github.com/Azure/azure-sdk-for-go/sdk/internal/log"
28 "github.com/AzureAD/microsoft-authentication-library-for-go/apps/confidential"
29)
30
// Environment variables read to detect the managed identity hosting environment
// and to configure its token requests.
const (
	arcIMDSEndpoint          = "IMDS_ENDPOINT"
	defaultIdentityClientID  = "DEFAULT_IDENTITY_CLIENT_ID"
	identityEndpoint         = "IDENTITY_ENDPOINT"
	identityHeader           = "IDENTITY_HEADER"
	identityServerThumbprint = "IDENTITY_SERVER_THUMBPRINT"
	msiEndpoint              = "MSI_ENDPOINT"
	msiSecret                = "MSI_SECRET"
)

// Header and query parameter names used in token requests.
const (
	headerMetadata = "Metadata"
	miResID        = "mi_res_id"
	msiResID       = "msi_res_id"
	qpClientID     = "client_id"
)

// Default endpoint and the API versions sent to each service.
const (
	imdsEndpoint            = "http://169.254.169.254/metadata/identity/oauth2/token"
	imdsAPIVersion          = "2018-02-01"
	azureArcAPIVersion      = "2019-08-15"
	serviceFabricAPIVersion = "2019-07-01-preview"
)
48
// imdsProbeTimeout bounds the single probe request authenticate sends when probing IMDS.
var imdsProbeTimeout = 1 * time.Second

// msiType identifies the hosting environment whose managed identity endpoint a
// managedIdentityClient talks to; it selects the request builder used for token requests.
type msiType int

const (
	msiTypeAppService = msiType(iota)
	msiTypeAzureArc
	msiTypeAzureML
	msiTypeCloudShell
	msiTypeIMDS
	msiTypeServiceFabric
)
61
// managedIdentityClient sends token requests to the managed identity endpoint of the
// hosting environment detected by newManagedIdentityClient.
type managedIdentityClient struct {
	// azClient supplies the pipeline through which every request is sent
	azClient *azcore.Client
	// endpoint is the URL requests are sent to: imdsEndpoint by default, otherwise the
	// value of IDENTITY_ENDPOINT or MSI_ENDPOINT
	endpoint string
	// id is the identity provideToken requests; authenticate treats nil as the
	// system-assigned identity when interpreting IMDS 400 responses
	id ManagedIDKind
	// msiType selects the environment-specific request builder in createAuthRequest
	msiType msiType
	// probeIMDS makes authenticate send a short probe request before its first token
	// request; it's set from options.dac only when no other environment was detected
	probeIMDS bool
}
69
// arcKeyDirectory reports the directory in which the Azure Arc agent is expected to
// place its secret key files. Only Linux and Windows are supported; on Windows the
// location is derived from %ProgramData%. It's a variable so it can be replaced.
var arcKeyDirectory = func() (string, error) {
	goos := runtime.GOOS
	if goos == "linux" {
		return "/var/opt/azcmagent/tokens", nil
	}
	if goos != "windows" {
		return "", fmt.Errorf("unsupported OS %q", goos)
	}
	programData := os.Getenv("ProgramData")
	if len(programData) == 0 {
		return "", errors.New("environment variable ProgramData has no value")
	}
	return filepath.Join(programData, "AzureConnectedMachineAgent", "Tokens"), nil
}
85
// wrappedNumber is a json.Number that also accepts an empty JSON string.
type wrappedNumber json.Number

// UnmarshalJSON implements json.Unmarshaler. The literal "" (an empty JSON string) is
// accepted and leaves n unchanged; any other input is decoded as a json.Number.
func (n *wrappedNumber) UnmarshalJSON(b []byte) error {
	if string(b) != `""` {
		return json.Unmarshal(b, (*json.Number)(n))
	}
	return nil
}
95
96// setIMDSRetryOptionDefaults sets zero-valued fields to default values appropriate for IMDS
97func setIMDSRetryOptionDefaults(o *policy.RetryOptions) {
98 if o.MaxRetries == 0 {
99 o.MaxRetries = 5
100 }
101 if o.MaxRetryDelay == 0 {
102 o.MaxRetryDelay = 1 * time.Minute
103 }
104 if o.RetryDelay == 0 {
105 o.RetryDelay = 2 * time.Second
106 }
107 if o.StatusCodes == nil {
108 o.StatusCodes = []int{
109 // IMDS docs recommend retrying 404, 410, 429 and 5xx
110 // https://learn.microsoft.com/entra/identity/managed-identities-azure-resources/how-to-use-vm-token#error-handling
111 http.StatusNotFound, // 404
112 http.StatusGone, // 410
113 http.StatusTooManyRequests, // 429
114 http.StatusInternalServerError, // 500
115 http.StatusNotImplemented, // 501
116 http.StatusBadGateway, // 502
117 http.StatusServiceUnavailable, // 503
118 http.StatusGatewayTimeout, // 504
119 http.StatusHTTPVersionNotSupported, // 505
120 http.StatusVariantAlsoNegotiates, // 506
121 http.StatusInsufficientStorage, // 507
122 http.StatusLoopDetected, // 508
123 http.StatusNotExtended, // 510
124 http.StatusNetworkAuthenticationRequired, // 511
125 }
126 }
127 if o.TryTimeout == 0 {
128 o.TryTimeout = 1 * time.Minute
129 }
130}
131
132// newManagedIdentityClient creates a new instance of the ManagedIdentityClient with the ManagedIdentityCredentialOptions
133// that are passed into it along with a default pipeline.
134// options: ManagedIdentityCredentialOptions configure policies for the pipeline and the authority host that
135// will be used to retrieve tokens and authenticate
136func newManagedIdentityClient(options *ManagedIdentityCredentialOptions) (*managedIdentityClient, error) {
137 if options == nil {
138 options = &ManagedIdentityCredentialOptions{}
139 }
140 cp := options.ClientOptions
141 c := managedIdentityClient{id: options.ID, endpoint: imdsEndpoint, msiType: msiTypeIMDS}
142 env := "IMDS"
143 if endpoint, ok := os.LookupEnv(identityEndpoint); ok {
144 if _, ok := os.LookupEnv(identityHeader); ok {
145 if _, ok := os.LookupEnv(identityServerThumbprint); ok {
146 env = "Service Fabric"
147 c.endpoint = endpoint
148 c.msiType = msiTypeServiceFabric
149 } else {
150 env = "App Service"
151 c.endpoint = endpoint
152 c.msiType = msiTypeAppService
153 }
154 } else if _, ok := os.LookupEnv(arcIMDSEndpoint); ok {
155 env = "Azure Arc"
156 c.endpoint = endpoint
157 c.msiType = msiTypeAzureArc
158 }
159 } else if endpoint, ok := os.LookupEnv(msiEndpoint); ok {
160 c.endpoint = endpoint
161 if _, ok := os.LookupEnv(msiSecret); ok {
162 env = "Azure ML"
163 c.msiType = msiTypeAzureML
164 } else {
165 env = "Cloud Shell"
166 c.msiType = msiTypeCloudShell
167 }
168 } else {
169 c.probeIMDS = options.dac
170 setIMDSRetryOptionDefaults(&cp.Retry)
171 }
172
173 client, err := azcore.NewClient(module, version, azruntime.PipelineOptions{
174 Tracing: azruntime.TracingOptions{
175 Namespace: traceNamespace,
176 },
177 }, &cp)
178 if err != nil {
179 return nil, err
180 }
181 c.azClient = client
182
183 if log.Should(EventAuthentication) {
184 log.Writef(EventAuthentication, "Managed Identity Credential will use %s managed identity", env)
185 }
186
187 return &c, nil
188}
189
190// provideToken acquires a token for MSAL's confidential.Client, which caches the token
191func (c *managedIdentityClient) provideToken(ctx context.Context, params confidential.TokenProviderParameters) (confidential.TokenProviderResult, error) {
192 result := confidential.TokenProviderResult{}
193 tk, err := c.authenticate(ctx, c.id, params.Scopes)
194 if err == nil {
195 result.AccessToken = tk.Token
196 result.ExpiresInSeconds = int(time.Until(tk.ExpiresOn).Seconds())
197 }
198 return result, err
199}
200
// authenticate requests an access token for scopes from c.endpoint. id, when non-nil,
// selects a user-assigned identity and is passed to the request builder chosen by
// createAuthRequest.
//
// When c.probeIMDS is set, a single request is first sent to c.endpoint with retries
// disabled (MaxRetries: -1) and bounded by imdsProbeTimeout. If that probe fails, the error
// (or a timeout message) is returned via newCredentialUnavailableError; if anything
// responds, c.probeIMDS is cleared so later calls skip the probe.
//
// 200 and 201 responses are decoded by createAccessToken. For IMDS, a 400 for a
// system-assigned identity (id == nil) and a 403 whose body contains "unreachable" are
// returned via newCredentialUnavailableError; every other response produces an error
// from newAuthenticationFailedError.
func (c *managedIdentityClient) authenticate(ctx context.Context, id ManagedIDKind, scopes []string) (azcore.AccessToken, error) {
	// no need to synchronize around this value because it's true only when DefaultAzureCredential constructed the client,
	// and in that case ChainedTokenCredential.GetToken synchronizes goroutines that would execute this block
	if c.probeIMDS {
		// cancel is deferred, so the probe context lives until authenticate returns
		cx, cancel := context.WithTimeout(ctx, imdsProbeTimeout)
		defer cancel()
		// a probe shouldn't be retried; any failure means "not available here"
		cx = policy.WithRetryOptions(cx, policy.RetryOptions{MaxRetries: -1})
		req, err := azruntime.NewRequest(cx, http.MethodGet, c.endpoint)
		if err == nil {
			// the response content is irrelevant; only whether something answered matters
			_, err = c.azClient.Pipeline().Do(req)
		}
		if err != nil {
			msg := err.Error()
			if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) {
				msg = "managed identity timed out. See https://aka.ms/azsdk/go/identity/troubleshoot#dac for more information"
			}
			return azcore.AccessToken{}, newCredentialUnavailableError(credNameManagedIdentity, msg)
		}
		// send normal token requests from now on because something responded
		c.probeIMDS = false
	}

	msg, err := c.createAuthRequest(ctx, id, scopes)
	if err != nil {
		return azcore.AccessToken{}, err
	}

	resp, err := c.azClient.Pipeline().Do(msg)
	if err != nil {
		return azcore.AccessToken{}, newAuthenticationFailedError(credNameManagedIdentity, err.Error(), nil, err)
	}

	if azruntime.HasStatusCode(resp, http.StatusOK, http.StatusCreated) {
		return c.createAccessToken(resp)
	}

	if c.msiType == msiTypeIMDS {
		switch resp.StatusCode {
		case http.StatusBadRequest:
			// a user-assigned identity that isn't assigned is a hard failure; a system-assigned
			// failure is reported as unavailable so a credential chain can move on
			if id != nil {
				return azcore.AccessToken{}, newAuthenticationFailedError(credNameManagedIdentity, "the requested identity isn't assigned to this resource", resp, nil)
			}
			msg := "failed to authenticate a system assigned identity"
			if body, err := azruntime.Payload(resp); err == nil && len(body) > 0 {
				msg += fmt.Sprintf(". The endpoint responded with %s", body)
			}
			return azcore.AccessToken{}, newCredentialUnavailableError(credNameManagedIdentity, msg)
		case http.StatusForbidden:
			// Docker Desktop runs a proxy that responds 403 to IMDS token requests. If we get that response,
			// we return credentialUnavailableError so credential chains continue to their next credential
			body, err := azruntime.Payload(resp)
			if err == nil && strings.Contains(string(body), "unreachable") {
				return azcore.AccessToken{}, newCredentialUnavailableError(credNameManagedIdentity, fmt.Sprintf("unexpected response %q", string(body)))
			}
		}
	}

	return azcore.AccessToken{}, newAuthenticationFailedError(credNameManagedIdentity, "authentication failed", resp, nil)
}
261
262func (c *managedIdentityClient) createAccessToken(res *http.Response) (azcore.AccessToken, error) {
263 value := struct {
264 // these are the only fields that we use
265 Token string `json:"access_token,omitempty"`
266 RefreshToken string `json:"refresh_token,omitempty"`
267 ExpiresIn wrappedNumber `json:"expires_in,omitempty"` // this field should always return the number of seconds for which a token is valid
268 ExpiresOn interface{} `json:"expires_on,omitempty"` // the value returned in this field varies between a number and a date string
269 }{}
270 if err := azruntime.UnmarshalAsJSON(res, &value); err != nil {
271 return azcore.AccessToken{}, fmt.Errorf("internal AccessToken: %v", err)
272 }
273 if value.ExpiresIn != "" {
274 expiresIn, err := json.Number(value.ExpiresIn).Int64()
275 if err != nil {
276 return azcore.AccessToken{}, err
277 }
278 return azcore.AccessToken{Token: value.Token, ExpiresOn: time.Now().Add(time.Second * time.Duration(expiresIn)).UTC()}, nil
279 }
280 switch v := value.ExpiresOn.(type) {
281 case float64:
282 return azcore.AccessToken{Token: value.Token, ExpiresOn: time.Unix(int64(v), 0).UTC()}, nil
283 case string:
284 if expiresOn, err := strconv.Atoi(v); err == nil {
285 return azcore.AccessToken{Token: value.Token, ExpiresOn: time.Unix(int64(expiresOn), 0).UTC()}, nil
286 }
287 return azcore.AccessToken{}, newAuthenticationFailedError(credNameManagedIdentity, "unexpected expires_on value: "+v, res, nil)
288 default:
289 msg := fmt.Sprintf("unsupported type received in expires_on: %T, %v", v, v)
290 return azcore.AccessToken{}, newAuthenticationFailedError(credNameManagedIdentity, msg, res, nil)
291 }
292}
293
294func (c *managedIdentityClient) createAuthRequest(ctx context.Context, id ManagedIDKind, scopes []string) (*policy.Request, error) {
295 switch c.msiType {
296 case msiTypeIMDS:
297 return c.createIMDSAuthRequest(ctx, id, scopes)
298 case msiTypeAppService:
299 return c.createAppServiceAuthRequest(ctx, id, scopes)
300 case msiTypeAzureArc:
301 // need to perform preliminary request to retreive the secret key challenge provided by the HIMDS service
302 key, err := c.getAzureArcSecretKey(ctx, scopes)
303 if err != nil {
304 msg := fmt.Sprintf("failed to retreive secret key from the identity endpoint: %v", err)
305 return nil, newAuthenticationFailedError(credNameManagedIdentity, msg, nil, err)
306 }
307 return c.createAzureArcAuthRequest(ctx, id, scopes, key)
308 case msiTypeAzureML:
309 return c.createAzureMLAuthRequest(ctx, id, scopes)
310 case msiTypeServiceFabric:
311 return c.createServiceFabricAuthRequest(ctx, id, scopes)
312 case msiTypeCloudShell:
313 return c.createCloudShellAuthRequest(ctx, id, scopes)
314 default:
315 return nil, newCredentialUnavailableError(credNameManagedIdentity, "managed identity isn't supported in this environment")
316 }
317}
318
319func (c *managedIdentityClient) createIMDSAuthRequest(ctx context.Context, id ManagedIDKind, scopes []string) (*policy.Request, error) {
320 request, err := azruntime.NewRequest(ctx, http.MethodGet, c.endpoint)
321 if err != nil {
322 return nil, err
323 }
324 request.Raw().Header.Set(headerMetadata, "true")
325 q := request.Raw().URL.Query()
326 q.Add("api-version", imdsAPIVersion)
327 q.Add("resource", strings.Join(scopes, " "))
328 if id != nil {
329 if id.idKind() == miResourceID {
330 q.Add(msiResID, id.String())
331 } else {
332 q.Add(qpClientID, id.String())
333 }
334 }
335 request.Raw().URL.RawQuery = q.Encode()
336 return request, nil
337}
338
339func (c *managedIdentityClient) createAppServiceAuthRequest(ctx context.Context, id ManagedIDKind, scopes []string) (*policy.Request, error) {
340 request, err := azruntime.NewRequest(ctx, http.MethodGet, c.endpoint)
341 if err != nil {
342 return nil, err
343 }
344 request.Raw().Header.Set("X-IDENTITY-HEADER", os.Getenv(identityHeader))
345 q := request.Raw().URL.Query()
346 q.Add("api-version", "2019-08-01")
347 q.Add("resource", scopes[0])
348 if id != nil {
349 if id.idKind() == miResourceID {
350 q.Add(miResID, id.String())
351 } else {
352 q.Add(qpClientID, id.String())
353 }
354 }
355 request.Raw().URL.RawQuery = q.Encode()
356 return request, nil
357}
358
359func (c *managedIdentityClient) createAzureMLAuthRequest(ctx context.Context, id ManagedIDKind, scopes []string) (*policy.Request, error) {
360 request, err := azruntime.NewRequest(ctx, http.MethodGet, c.endpoint)
361 if err != nil {
362 return nil, err
363 }
364 request.Raw().Header.Set("secret", os.Getenv(msiSecret))
365 q := request.Raw().URL.Query()
366 q.Add("api-version", "2017-09-01")
367 q.Add("resource", strings.Join(scopes, " "))
368 q.Add("clientid", os.Getenv(defaultIdentityClientID))
369 if id != nil {
370 if id.idKind() == miResourceID {
371 log.Write(EventAuthentication, "WARNING: Azure ML doesn't support specifying a managed identity by resource ID")
372 q.Set("clientid", "")
373 q.Set(miResID, id.String())
374 } else {
375 q.Set("clientid", id.String())
376 }
377 }
378 request.Raw().URL.RawQuery = q.Encode()
379 return request, nil
380}
381
382func (c *managedIdentityClient) createServiceFabricAuthRequest(ctx context.Context, id ManagedIDKind, scopes []string) (*policy.Request, error) {
383 request, err := azruntime.NewRequest(ctx, http.MethodGet, c.endpoint)
384 if err != nil {
385 return nil, err
386 }
387 q := request.Raw().URL.Query()
388 request.Raw().Header.Set("Accept", "application/json")
389 request.Raw().Header.Set("Secret", os.Getenv(identityHeader))
390 q.Add("api-version", serviceFabricAPIVersion)
391 q.Add("resource", strings.Join(scopes, " "))
392 if id != nil {
393 log.Write(EventAuthentication, "WARNING: Service Fabric doesn't support selecting a user-assigned identity at runtime")
394 if id.idKind() == miResourceID {
395 q.Add(miResID, id.String())
396 } else {
397 q.Add(qpClientID, id.String())
398 }
399 }
400 request.Raw().URL.RawQuery = q.Encode()
401 return request, nil
402}
403
// getAzureArcSecretKey performs the Azure Arc (HIMDS) challenge step and returns the
// contents of the key file the service points to.
//
// It sends a GET for resources to c.endpoint without an Authorization header and expects a
// 401 whose WWW-Authenticate header names a key file ("Basic realm=<path>"). Before reading
// that file it requires the path's directory (as returned by filepath.Dir, which cleans it
// lexically) to equal arcKeyDirectory() and the path to end in ".key", and it rejects files
// larger than 4096 bytes. Failures after the request is sent are reported through
// newAuthenticationFailedError, except an arcKeyDirectory error, which is returned as is.
func (c *managedIdentityClient) getAzureArcSecretKey(ctx context.Context, resources []string) (string, error) {
	// create the request to retreive the secret key challenge provided by the HIMDS service
	request, err := azruntime.NewRequest(ctx, http.MethodGet, c.endpoint)
	if err != nil {
		return "", err
	}
	request.Raw().Header.Set(headerMetadata, "true")
	q := request.Raw().URL.Query()
	q.Add("api-version", azureArcAPIVersion)
	q.Add("resource", strings.Join(resources, " "))
	request.Raw().URL.RawQuery = q.Encode()
	// send the initial request to get the short-lived secret key
	response, err := c.azClient.Pipeline().Do(request)
	if err != nil {
		return "", err
	}
	// the endpoint is expected to return a 401 with the WWW-Authenticate header set to the location
	// of the secret key file. Any other status code indicates an error in the request.
	if response.StatusCode != 401 {
		msg := fmt.Sprintf("expected a 401 response, received %d", response.StatusCode)
		return "", newAuthenticationFailedError(credNameManagedIdentity, msg, response, nil)
	}
	header := response.Header.Get("WWW-Authenticate")
	if len(header) == 0 {
		return "", newAuthenticationFailedError(credNameManagedIdentity, "HIMDS response has no WWW-Authenticate header", nil, nil)
	}
	// the WWW-Authenticate header is expected in the following format: Basic realm=/some/file/path.key
	_, p, found := strings.Cut(header, "=")
	if !found {
		return "", newAuthenticationFailedError(credNameManagedIdentity, "unexpected WWW-Authenticate header from HIMDS: "+header, nil, nil)
	}
	expected, err := arcKeyDirectory()
	if err != nil {
		return "", err
	}
	// only read files from the agent's key directory, so the service's response can't
	// direct this process to read an arbitrary file
	if filepath.Dir(p) != expected || !strings.HasSuffix(p, ".key") {
		return "", newAuthenticationFailedError(credNameManagedIdentity, "unexpected file path from HIMDS service: "+p, nil, nil)
	}
	f, err := os.Stat(p)
	if err != nil {
		return "", newAuthenticationFailedError(credNameManagedIdentity, fmt.Sprintf("could not stat %q: %v", p, err), nil, nil)
	}
	// NOTE(review): the size limit is checked via Stat, so it doesn't bound a file that
	// changes between Stat and ReadFile
	if s := f.Size(); s > 4096 {
		return "", newAuthenticationFailedError(credNameManagedIdentity, fmt.Sprintf("key is too large (%d bytes)", s), nil, nil)
	}
	key, err := os.ReadFile(p)
	if err != nil {
		return "", newAuthenticationFailedError(credNameManagedIdentity, fmt.Sprintf("could not read %q: %v", p, err), nil, nil)
	}
	return string(key), nil
}
455
456func (c *managedIdentityClient) createAzureArcAuthRequest(ctx context.Context, id ManagedIDKind, resources []string, key string) (*policy.Request, error) {
457 request, err := azruntime.NewRequest(ctx, http.MethodGet, c.endpoint)
458 if err != nil {
459 return nil, err
460 }
461 request.Raw().Header.Set(headerMetadata, "true")
462 request.Raw().Header.Set("Authorization", fmt.Sprintf("Basic %s", key))
463 q := request.Raw().URL.Query()
464 q.Add("api-version", azureArcAPIVersion)
465 q.Add("resource", strings.Join(resources, " "))
466 if id != nil {
467 log.Write(EventAuthentication, "WARNING: Azure Arc doesn't support user-assigned managed identities")
468 if id.idKind() == miResourceID {
469 q.Add(miResID, id.String())
470 } else {
471 q.Add(qpClientID, id.String())
472 }
473 }
474 request.Raw().URL.RawQuery = q.Encode()
475 return request, nil
476}
477
478func (c *managedIdentityClient) createCloudShellAuthRequest(ctx context.Context, id ManagedIDKind, scopes []string) (*policy.Request, error) {
479 request, err := azruntime.NewRequest(ctx, http.MethodPost, c.endpoint)
480 if err != nil {
481 return nil, err
482 }
483 request.Raw().Header.Set(headerMetadata, "true")
484 data := url.Values{}
485 data.Set("resource", strings.Join(scopes, " "))
486 dataEncoded := data.Encode()
487 body := streaming.NopCloser(strings.NewReader(dataEncoded))
488 if err := request.SetBody(body, "application/x-www-form-urlencoded"); err != nil {
489 return nil, err
490 }
491 if id != nil {
492 log.Write(EventAuthentication, "WARNING: Cloud Shell doesn't support user-assigned managed identities")
493 q := request.Raw().URL.Query()
494 if id.idKind() == miResourceID {
495 q.Add(miResID, id.String())
496 } else {
497 q.Add(qpClientID, id.String())
498 }
499 }
500 return request, nil
501}