managed_identity_client.go

  1//go:build go1.18
  2// +build go1.18
  3
  4// Copyright (c) Microsoft Corporation. All rights reserved.
  5// Licensed under the MIT License.
  6
  7package azidentity
  8
  9import (
 10	"context"
 11	"encoding/json"
 12	"errors"
 13	"fmt"
 14	"net/http"
 15	"net/url"
 16	"os"
 17	"path/filepath"
 18	"runtime"
 19	"strconv"
 20	"strings"
 21	"time"
 22
 23	"github.com/Azure/azure-sdk-for-go/sdk/azcore"
 24	"github.com/Azure/azure-sdk-for-go/sdk/azcore/policy"
 25	azruntime "github.com/Azure/azure-sdk-for-go/sdk/azcore/runtime"
 26	"github.com/Azure/azure-sdk-for-go/sdk/azcore/streaming"
 27	"github.com/Azure/azure-sdk-for-go/sdk/internal/log"
 28	"github.com/AzureAD/microsoft-authentication-library-for-go/apps/confidential"
 29)
 30
// Names and fixed values used to detect the hosting environment and to build token requests.
//
// arcIMDSEndpoint, defaultIdentityClientID, identityEndpoint, identityHeader,
// identityServerThumbprint, msiEndpoint and msiSecret are names of environment variables
// this file reads. headerMetadata is a request header name. miResID, msiResID and qpClientID
// are query parameter names identifying a user-assigned identity: IMDS requests use msiResID
// for a resource ID while the other request builders use miResID. imdsEndpoint is the
// endpoint used when no environment variable selects another one. The *APIVersion values are
// sent as the "api-version" query parameter.
const (
	arcIMDSEndpoint          = "IMDS_ENDPOINT"
	defaultIdentityClientID  = "DEFAULT_IDENTITY_CLIENT_ID"
	identityEndpoint         = "IDENTITY_ENDPOINT"
	identityHeader           = "IDENTITY_HEADER"
	identityServerThumbprint = "IDENTITY_SERVER_THUMBPRINT"
	headerMetadata           = "Metadata"
	imdsEndpoint             = "http://169.254.169.254/metadata/identity/oauth2/token"
	miResID                  = "mi_res_id"
	msiEndpoint              = "MSI_ENDPOINT"
	msiResID                 = "msi_res_id"
	msiSecret                = "MSI_SECRET"
	imdsAPIVersion           = "2018-02-01"
	azureArcAPIVersion       = "2019-08-15"
	qpClientID               = "client_id"
	serviceFabricAPIVersion  = "2019-07-01-preview"
)
 48
// imdsProbeTimeout bounds the probe request that authenticate sends before the first token
// request when the client's probeIMDS flag is set.
// NOTE(review): presumably a variable rather than a constant so tests can shorten it — confirm.
var imdsProbeTimeout = time.Second
 50
// msiType identifies the managed identity source a client targets. newManagedIdentityClient
// chooses the value from environment variables, and createAuthRequest uses it to select the
// matching request builder.
type msiType int

const (
	msiTypeAppService msiType = iota
	msiTypeAzureArc
	msiTypeAzureML
	msiTypeCloudShell
	msiTypeIMDS
	msiTypeServiceFabric
)
 61
// managedIdentityClient requests access tokens from the managed identity endpoint of the
// environment detected by newManagedIdentityClient.
type managedIdentityClient struct {
	azClient  *azcore.Client // pipeline used to send every request
	endpoint  string         // URL that requests are sent to
	id        ManagedIDKind  // user-assigned identity to request; nil requests the system-assigned identity
	msiType   msiType        // detected managed identity source
	probeIMDS bool           // when true, authenticate sends a short probe request before the first token request
}
 69
 70// arcKeyDirectory returns the directory expected to contain Azure Arc keys
 71var arcKeyDirectory = func() (string, error) {
 72	switch runtime.GOOS {
 73	case "linux":
 74		return "/var/opt/azcmagent/tokens", nil
 75	case "windows":
 76		pd := os.Getenv("ProgramData")
 77		if pd == "" {
 78			return "", errors.New("environment variable ProgramData has no value")
 79		}
 80		return filepath.Join(pd, "AzureConnectedMachineAgent", "Tokens"), nil
 81	default:
 82		return "", fmt.Errorf("unsupported OS %q", runtime.GOOS)
 83	}
 84}
 85
 86type wrappedNumber json.Number
 87
 88func (n *wrappedNumber) UnmarshalJSON(b []byte) error {
 89	c := string(b)
 90	if c == "\"\"" {
 91		return nil
 92	}
 93	return json.Unmarshal(b, (*json.Number)(n))
 94}
 95
 96// setIMDSRetryOptionDefaults sets zero-valued fields to default values appropriate for IMDS
 97func setIMDSRetryOptionDefaults(o *policy.RetryOptions) {
 98	if o.MaxRetries == 0 {
 99		o.MaxRetries = 5
100	}
101	if o.MaxRetryDelay == 0 {
102		o.MaxRetryDelay = 1 * time.Minute
103	}
104	if o.RetryDelay == 0 {
105		o.RetryDelay = 2 * time.Second
106	}
107	if o.StatusCodes == nil {
108		o.StatusCodes = []int{
109			// IMDS docs recommend retrying 404, 410, 429 and 5xx
110			// https://learn.microsoft.com/entra/identity/managed-identities-azure-resources/how-to-use-vm-token#error-handling
111			http.StatusNotFound,                      // 404
112			http.StatusGone,                          // 410
113			http.StatusTooManyRequests,               // 429
114			http.StatusInternalServerError,           // 500
115			http.StatusNotImplemented,                // 501
116			http.StatusBadGateway,                    // 502
117			http.StatusServiceUnavailable,            // 503
118			http.StatusGatewayTimeout,                // 504
119			http.StatusHTTPVersionNotSupported,       // 505
120			http.StatusVariantAlsoNegotiates,         // 506
121			http.StatusInsufficientStorage,           // 507
122			http.StatusLoopDetected,                  // 508
123			http.StatusNotExtended,                   // 510
124			http.StatusNetworkAuthenticationRequired, // 511
125		}
126	}
127	if o.TryTimeout == 0 {
128		o.TryTimeout = 1 * time.Minute
129	}
130}
131
132// newManagedIdentityClient creates a new instance of the ManagedIdentityClient with the ManagedIdentityCredentialOptions
133// that are passed into it along with a default pipeline.
134// options: ManagedIdentityCredentialOptions configure policies for the pipeline and the authority host that
135// will be used to retrieve tokens and authenticate
136func newManagedIdentityClient(options *ManagedIdentityCredentialOptions) (*managedIdentityClient, error) {
137	if options == nil {
138		options = &ManagedIdentityCredentialOptions{}
139	}
140	cp := options.ClientOptions
141	c := managedIdentityClient{id: options.ID, endpoint: imdsEndpoint, msiType: msiTypeIMDS}
142	env := "IMDS"
143	if endpoint, ok := os.LookupEnv(identityEndpoint); ok {
144		if _, ok := os.LookupEnv(identityHeader); ok {
145			if _, ok := os.LookupEnv(identityServerThumbprint); ok {
146				env = "Service Fabric"
147				c.endpoint = endpoint
148				c.msiType = msiTypeServiceFabric
149			} else {
150				env = "App Service"
151				c.endpoint = endpoint
152				c.msiType = msiTypeAppService
153			}
154		} else if _, ok := os.LookupEnv(arcIMDSEndpoint); ok {
155			env = "Azure Arc"
156			c.endpoint = endpoint
157			c.msiType = msiTypeAzureArc
158		}
159	} else if endpoint, ok := os.LookupEnv(msiEndpoint); ok {
160		c.endpoint = endpoint
161		if _, ok := os.LookupEnv(msiSecret); ok {
162			env = "Azure ML"
163			c.msiType = msiTypeAzureML
164		} else {
165			env = "Cloud Shell"
166			c.msiType = msiTypeCloudShell
167		}
168	} else {
169		c.probeIMDS = options.dac
170		setIMDSRetryOptionDefaults(&cp.Retry)
171	}
172
173	client, err := azcore.NewClient(module, version, azruntime.PipelineOptions{
174		Tracing: azruntime.TracingOptions{
175			Namespace: traceNamespace,
176		},
177	}, &cp)
178	if err != nil {
179		return nil, err
180	}
181	c.azClient = client
182
183	if log.Should(EventAuthentication) {
184		log.Writef(EventAuthentication, "Managed Identity Credential will use %s managed identity", env)
185	}
186
187	return &c, nil
188}
189
190// provideToken acquires a token for MSAL's confidential.Client, which caches the token
191func (c *managedIdentityClient) provideToken(ctx context.Context, params confidential.TokenProviderParameters) (confidential.TokenProviderResult, error) {
192	result := confidential.TokenProviderResult{}
193	tk, err := c.authenticate(ctx, c.id, params.Scopes)
194	if err == nil {
195		result.AccessToken = tk.Token
196		result.ExpiresInSeconds = int(time.Until(tk.ExpiresOn).Seconds())
197	}
198	return result, err
199}
200
201// authenticate acquires an access token
202func (c *managedIdentityClient) authenticate(ctx context.Context, id ManagedIDKind, scopes []string) (azcore.AccessToken, error) {
203	// no need to synchronize around this value because it's true only when DefaultAzureCredential constructed the client,
204	// and in that case ChainedTokenCredential.GetToken synchronizes goroutines that would execute this block
205	if c.probeIMDS {
206		cx, cancel := context.WithTimeout(ctx, imdsProbeTimeout)
207		defer cancel()
208		cx = policy.WithRetryOptions(cx, policy.RetryOptions{MaxRetries: -1})
209		req, err := azruntime.NewRequest(cx, http.MethodGet, c.endpoint)
210		if err == nil {
211			_, err = c.azClient.Pipeline().Do(req)
212		}
213		if err != nil {
214			msg := err.Error()
215			if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) {
216				msg = "managed identity timed out. See https://aka.ms/azsdk/go/identity/troubleshoot#dac for more information"
217			}
218			return azcore.AccessToken{}, newCredentialUnavailableError(credNameManagedIdentity, msg)
219		}
220		// send normal token requests from now on because something responded
221		c.probeIMDS = false
222	}
223
224	msg, err := c.createAuthRequest(ctx, id, scopes)
225	if err != nil {
226		return azcore.AccessToken{}, err
227	}
228
229	resp, err := c.azClient.Pipeline().Do(msg)
230	if err != nil {
231		return azcore.AccessToken{}, newAuthenticationFailedError(credNameManagedIdentity, err.Error(), nil, err)
232	}
233
234	if azruntime.HasStatusCode(resp, http.StatusOK, http.StatusCreated) {
235		return c.createAccessToken(resp)
236	}
237
238	if c.msiType == msiTypeIMDS {
239		switch resp.StatusCode {
240		case http.StatusBadRequest:
241			if id != nil {
242				return azcore.AccessToken{}, newAuthenticationFailedError(credNameManagedIdentity, "the requested identity isn't assigned to this resource", resp, nil)
243			}
244			msg := "failed to authenticate a system assigned identity"
245			if body, err := azruntime.Payload(resp); err == nil && len(body) > 0 {
246				msg += fmt.Sprintf(". The endpoint responded with %s", body)
247			}
248			return azcore.AccessToken{}, newCredentialUnavailableError(credNameManagedIdentity, msg)
249		case http.StatusForbidden:
250			// Docker Desktop runs a proxy that responds 403 to IMDS token requests. If we get that response,
251			// we return credentialUnavailableError so credential chains continue to their next credential
252			body, err := azruntime.Payload(resp)
253			if err == nil && strings.Contains(string(body), "unreachable") {
254				return azcore.AccessToken{}, newCredentialUnavailableError(credNameManagedIdentity, fmt.Sprintf("unexpected response %q", string(body)))
255			}
256		}
257	}
258
259	return azcore.AccessToken{}, newAuthenticationFailedError(credNameManagedIdentity, "authentication failed", resp, nil)
260}
261
262func (c *managedIdentityClient) createAccessToken(res *http.Response) (azcore.AccessToken, error) {
263	value := struct {
264		// these are the only fields that we use
265		Token        string        `json:"access_token,omitempty"`
266		RefreshToken string        `json:"refresh_token,omitempty"`
267		ExpiresIn    wrappedNumber `json:"expires_in,omitempty"` // this field should always return the number of seconds for which a token is valid
268		ExpiresOn    interface{}   `json:"expires_on,omitempty"` // the value returned in this field varies between a number and a date string
269	}{}
270	if err := azruntime.UnmarshalAsJSON(res, &value); err != nil {
271		return azcore.AccessToken{}, fmt.Errorf("internal AccessToken: %v", err)
272	}
273	if value.ExpiresIn != "" {
274		expiresIn, err := json.Number(value.ExpiresIn).Int64()
275		if err != nil {
276			return azcore.AccessToken{}, err
277		}
278		return azcore.AccessToken{Token: value.Token, ExpiresOn: time.Now().Add(time.Second * time.Duration(expiresIn)).UTC()}, nil
279	}
280	switch v := value.ExpiresOn.(type) {
281	case float64:
282		return azcore.AccessToken{Token: value.Token, ExpiresOn: time.Unix(int64(v), 0).UTC()}, nil
283	case string:
284		if expiresOn, err := strconv.Atoi(v); err == nil {
285			return azcore.AccessToken{Token: value.Token, ExpiresOn: time.Unix(int64(expiresOn), 0).UTC()}, nil
286		}
287		return azcore.AccessToken{}, newAuthenticationFailedError(credNameManagedIdentity, "unexpected expires_on value: "+v, res, nil)
288	default:
289		msg := fmt.Sprintf("unsupported type received in expires_on: %T, %v", v, v)
290		return azcore.AccessToken{}, newAuthenticationFailedError(credNameManagedIdentity, msg, res, nil)
291	}
292}
293
294func (c *managedIdentityClient) createAuthRequest(ctx context.Context, id ManagedIDKind, scopes []string) (*policy.Request, error) {
295	switch c.msiType {
296	case msiTypeIMDS:
297		return c.createIMDSAuthRequest(ctx, id, scopes)
298	case msiTypeAppService:
299		return c.createAppServiceAuthRequest(ctx, id, scopes)
300	case msiTypeAzureArc:
301		// need to perform preliminary request to retreive the secret key challenge provided by the HIMDS service
302		key, err := c.getAzureArcSecretKey(ctx, scopes)
303		if err != nil {
304			msg := fmt.Sprintf("failed to retreive secret key from the identity endpoint: %v", err)
305			return nil, newAuthenticationFailedError(credNameManagedIdentity, msg, nil, err)
306		}
307		return c.createAzureArcAuthRequest(ctx, id, scopes, key)
308	case msiTypeAzureML:
309		return c.createAzureMLAuthRequest(ctx, id, scopes)
310	case msiTypeServiceFabric:
311		return c.createServiceFabricAuthRequest(ctx, id, scopes)
312	case msiTypeCloudShell:
313		return c.createCloudShellAuthRequest(ctx, id, scopes)
314	default:
315		return nil, newCredentialUnavailableError(credNameManagedIdentity, "managed identity isn't supported in this environment")
316	}
317}
318
319func (c *managedIdentityClient) createIMDSAuthRequest(ctx context.Context, id ManagedIDKind, scopes []string) (*policy.Request, error) {
320	request, err := azruntime.NewRequest(ctx, http.MethodGet, c.endpoint)
321	if err != nil {
322		return nil, err
323	}
324	request.Raw().Header.Set(headerMetadata, "true")
325	q := request.Raw().URL.Query()
326	q.Add("api-version", imdsAPIVersion)
327	q.Add("resource", strings.Join(scopes, " "))
328	if id != nil {
329		if id.idKind() == miResourceID {
330			q.Add(msiResID, id.String())
331		} else {
332			q.Add(qpClientID, id.String())
333		}
334	}
335	request.Raw().URL.RawQuery = q.Encode()
336	return request, nil
337}
338
339func (c *managedIdentityClient) createAppServiceAuthRequest(ctx context.Context, id ManagedIDKind, scopes []string) (*policy.Request, error) {
340	request, err := azruntime.NewRequest(ctx, http.MethodGet, c.endpoint)
341	if err != nil {
342		return nil, err
343	}
344	request.Raw().Header.Set("X-IDENTITY-HEADER", os.Getenv(identityHeader))
345	q := request.Raw().URL.Query()
346	q.Add("api-version", "2019-08-01")
347	q.Add("resource", scopes[0])
348	if id != nil {
349		if id.idKind() == miResourceID {
350			q.Add(miResID, id.String())
351		} else {
352			q.Add(qpClientID, id.String())
353		}
354	}
355	request.Raw().URL.RawQuery = q.Encode()
356	return request, nil
357}
358
359func (c *managedIdentityClient) createAzureMLAuthRequest(ctx context.Context, id ManagedIDKind, scopes []string) (*policy.Request, error) {
360	request, err := azruntime.NewRequest(ctx, http.MethodGet, c.endpoint)
361	if err != nil {
362		return nil, err
363	}
364	request.Raw().Header.Set("secret", os.Getenv(msiSecret))
365	q := request.Raw().URL.Query()
366	q.Add("api-version", "2017-09-01")
367	q.Add("resource", strings.Join(scopes, " "))
368	q.Add("clientid", os.Getenv(defaultIdentityClientID))
369	if id != nil {
370		if id.idKind() == miResourceID {
371			log.Write(EventAuthentication, "WARNING: Azure ML doesn't support specifying a managed identity by resource ID")
372			q.Set("clientid", "")
373			q.Set(miResID, id.String())
374		} else {
375			q.Set("clientid", id.String())
376		}
377	}
378	request.Raw().URL.RawQuery = q.Encode()
379	return request, nil
380}
381
382func (c *managedIdentityClient) createServiceFabricAuthRequest(ctx context.Context, id ManagedIDKind, scopes []string) (*policy.Request, error) {
383	request, err := azruntime.NewRequest(ctx, http.MethodGet, c.endpoint)
384	if err != nil {
385		return nil, err
386	}
387	q := request.Raw().URL.Query()
388	request.Raw().Header.Set("Accept", "application/json")
389	request.Raw().Header.Set("Secret", os.Getenv(identityHeader))
390	q.Add("api-version", serviceFabricAPIVersion)
391	q.Add("resource", strings.Join(scopes, " "))
392	if id != nil {
393		log.Write(EventAuthentication, "WARNING: Service Fabric doesn't support selecting a user-assigned identity at runtime")
394		if id.idKind() == miResourceID {
395			q.Add(miResID, id.String())
396		} else {
397			q.Add(qpClientID, id.String())
398		}
399	}
400	request.Raw().URL.RawQuery = q.Encode()
401	return request, nil
402}
403
404func (c *managedIdentityClient) getAzureArcSecretKey(ctx context.Context, resources []string) (string, error) {
405	// create the request to retreive the secret key challenge provided by the HIMDS service
406	request, err := azruntime.NewRequest(ctx, http.MethodGet, c.endpoint)
407	if err != nil {
408		return "", err
409	}
410	request.Raw().Header.Set(headerMetadata, "true")
411	q := request.Raw().URL.Query()
412	q.Add("api-version", azureArcAPIVersion)
413	q.Add("resource", strings.Join(resources, " "))
414	request.Raw().URL.RawQuery = q.Encode()
415	// send the initial request to get the short-lived secret key
416	response, err := c.azClient.Pipeline().Do(request)
417	if err != nil {
418		return "", err
419	}
420	// the endpoint is expected to return a 401 with the WWW-Authenticate header set to the location
421	// of the secret key file. Any other status code indicates an error in the request.
422	if response.StatusCode != 401 {
423		msg := fmt.Sprintf("expected a 401 response, received %d", response.StatusCode)
424		return "", newAuthenticationFailedError(credNameManagedIdentity, msg, response, nil)
425	}
426	header := response.Header.Get("WWW-Authenticate")
427	if len(header) == 0 {
428		return "", newAuthenticationFailedError(credNameManagedIdentity, "HIMDS response has no WWW-Authenticate header", nil, nil)
429	}
430	// the WWW-Authenticate header is expected in the following format: Basic realm=/some/file/path.key
431	_, p, found := strings.Cut(header, "=")
432	if !found {
433		return "", newAuthenticationFailedError(credNameManagedIdentity, "unexpected WWW-Authenticate header from HIMDS: "+header, nil, nil)
434	}
435	expected, err := arcKeyDirectory()
436	if err != nil {
437		return "", err
438	}
439	if filepath.Dir(p) != expected || !strings.HasSuffix(p, ".key") {
440		return "", newAuthenticationFailedError(credNameManagedIdentity, "unexpected file path from HIMDS service: "+p, nil, nil)
441	}
442	f, err := os.Stat(p)
443	if err != nil {
444		return "", newAuthenticationFailedError(credNameManagedIdentity, fmt.Sprintf("could not stat %q: %v", p, err), nil, nil)
445	}
446	if s := f.Size(); s > 4096 {
447		return "", newAuthenticationFailedError(credNameManagedIdentity, fmt.Sprintf("key is too large (%d bytes)", s), nil, nil)
448	}
449	key, err := os.ReadFile(p)
450	if err != nil {
451		return "", newAuthenticationFailedError(credNameManagedIdentity, fmt.Sprintf("could not read %q: %v", p, err), nil, nil)
452	}
453	return string(key), nil
454}
455
456func (c *managedIdentityClient) createAzureArcAuthRequest(ctx context.Context, id ManagedIDKind, resources []string, key string) (*policy.Request, error) {
457	request, err := azruntime.NewRequest(ctx, http.MethodGet, c.endpoint)
458	if err != nil {
459		return nil, err
460	}
461	request.Raw().Header.Set(headerMetadata, "true")
462	request.Raw().Header.Set("Authorization", fmt.Sprintf("Basic %s", key))
463	q := request.Raw().URL.Query()
464	q.Add("api-version", azureArcAPIVersion)
465	q.Add("resource", strings.Join(resources, " "))
466	if id != nil {
467		log.Write(EventAuthentication, "WARNING: Azure Arc doesn't support user-assigned managed identities")
468		if id.idKind() == miResourceID {
469			q.Add(miResID, id.String())
470		} else {
471			q.Add(qpClientID, id.String())
472		}
473	}
474	request.Raw().URL.RawQuery = q.Encode()
475	return request, nil
476}
477
478func (c *managedIdentityClient) createCloudShellAuthRequest(ctx context.Context, id ManagedIDKind, scopes []string) (*policy.Request, error) {
479	request, err := azruntime.NewRequest(ctx, http.MethodPost, c.endpoint)
480	if err != nil {
481		return nil, err
482	}
483	request.Raw().Header.Set(headerMetadata, "true")
484	data := url.Values{}
485	data.Set("resource", strings.Join(scopes, " "))
486	dataEncoded := data.Encode()
487	body := streaming.NopCloser(strings.NewReader(dataEncoded))
488	if err := request.SetBody(body, "application/x-www-form-urlencoded"); err != nil {
489		return nil, err
490	}
491	if id != nil {
492		log.Write(EventAuthentication, "WARNING: Cloud Shell doesn't support user-assigned managed identities")
493		q := request.Raw().URL.Query()
494		if id.idKind() == miResourceID {
495			q.Add(miResID, id.String())
496		} else {
497			q.Add(qpClientID, id.String())
498		}
499	}
500	return request, nil
501}