confidential_client.go

  1//go:build go1.18
  2// +build go1.18
  3
  4// Copyright (c) Microsoft Corporation. All rights reserved.
  5// Licensed under the MIT License.
  6
  7package azidentity
  8
  9import (
 10	"context"
 11	"errors"
 12	"fmt"
 13	"net/http"
 14	"os"
 15	"strings"
 16	"sync"
 17
 18	"github.com/Azure/azure-sdk-for-go/sdk/azcore"
 19	"github.com/Azure/azure-sdk-for-go/sdk/azcore/policy"
 20	"github.com/Azure/azure-sdk-for-go/sdk/azcore/runtime"
 21	"github.com/Azure/azure-sdk-for-go/sdk/azidentity/internal"
 22	"github.com/Azure/azure-sdk-for-go/sdk/internal/log"
 23	"github.com/AzureAD/microsoft-authentication-library-for-go/apps/confidential"
 24)
 25
 26type confidentialClientOptions struct {
 27	azcore.ClientOptions
 28
 29	AdditionallyAllowedTenants []string
 30	// Assertion for on-behalf-of authentication
 31	Assertion                         string
 32	DisableInstanceDiscovery, SendX5C bool
 33	tokenCachePersistenceOptions      *tokenCachePersistenceOptions
 34}
 35
 36// confidentialClient wraps the MSAL confidential client
 37type confidentialClient struct {
 38	cae, noCAE               msalConfidentialClient
 39	caeMu, noCAEMu, clientMu *sync.Mutex
 40	clientID, tenantID       string
 41	cred                     confidential.Credential
 42	host                     string
 43	name                     string
 44	opts                     confidentialClientOptions
 45	region                   string
 46	azClient                 *azcore.Client
 47}
 48
 49func newConfidentialClient(tenantID, clientID, name string, cred confidential.Credential, opts confidentialClientOptions) (*confidentialClient, error) {
 50	if !validTenantID(tenantID) {
 51		return nil, errInvalidTenantID
 52	}
 53	host, err := setAuthorityHost(opts.Cloud)
 54	if err != nil {
 55		return nil, err
 56	}
 57	client, err := azcore.NewClient(module, version, runtime.PipelineOptions{
 58		Tracing: runtime.TracingOptions{
 59			Namespace: traceNamespace,
 60		},
 61	}, &opts.ClientOptions)
 62	if err != nil {
 63		return nil, err
 64	}
 65	opts.AdditionallyAllowedTenants = resolveAdditionalTenants(opts.AdditionallyAllowedTenants)
 66	return &confidentialClient{
 67		caeMu:    &sync.Mutex{},
 68		clientID: clientID,
 69		clientMu: &sync.Mutex{},
 70		cred:     cred,
 71		host:     host,
 72		name:     name,
 73		noCAEMu:  &sync.Mutex{},
 74		opts:     opts,
 75		region:   os.Getenv(azureRegionalAuthorityName),
 76		tenantID: tenantID,
 77		azClient: client,
 78	}, nil
 79}
 80
 81// GetToken requests an access token from MSAL, checking the cache first.
 82func (c *confidentialClient) GetToken(ctx context.Context, tro policy.TokenRequestOptions) (azcore.AccessToken, error) {
 83	if len(tro.Scopes) < 1 {
 84		return azcore.AccessToken{}, fmt.Errorf("%s.GetToken() requires at least one scope", c.name)
 85	}
 86	// we don't resolve the tenant for managed identities because they acquire tokens only from their home tenants
 87	if c.name != credNameManagedIdentity {
 88		tenant, err := c.resolveTenant(tro.TenantID)
 89		if err != nil {
 90			return azcore.AccessToken{}, err
 91		}
 92		tro.TenantID = tenant
 93	}
 94	client, mu, err := c.client(tro)
 95	if err != nil {
 96		return azcore.AccessToken{}, err
 97	}
 98	mu.Lock()
 99	defer mu.Unlock()
100	var ar confidential.AuthResult
101	if c.opts.Assertion != "" {
102		ar, err = client.AcquireTokenOnBehalfOf(ctx, c.opts.Assertion, tro.Scopes, confidential.WithClaims(tro.Claims), confidential.WithTenantID(tro.TenantID))
103	} else {
104		ar, err = client.AcquireTokenSilent(ctx, tro.Scopes, confidential.WithClaims(tro.Claims), confidential.WithTenantID(tro.TenantID))
105		if err != nil {
106			ar, err = client.AcquireTokenByCredential(ctx, tro.Scopes, confidential.WithClaims(tro.Claims), confidential.WithTenantID(tro.TenantID))
107		}
108	}
109	if err != nil {
110		// We could get a credentialUnavailableError from managed identity authentication because in that case the error comes from our code.
111		// We return it directly because it affects the behavior of credential chains. Otherwise, we return AuthenticationFailedError.
112		var unavailableErr credentialUnavailable
113		if !errors.As(err, &unavailableErr) {
114			res := getResponseFromError(err)
115			err = newAuthenticationFailedError(c.name, err.Error(), res, err)
116		}
117	} else {
118		msg := fmt.Sprintf("%s.GetToken() acquired a token for scope %q", c.name, strings.Join(ar.GrantedScopes, ", "))
119		log.Write(EventAuthentication, msg)
120	}
121	return azcore.AccessToken{Token: ar.AccessToken, ExpiresOn: ar.ExpiresOn.UTC()}, err
122}
123
124func (c *confidentialClient) client(tro policy.TokenRequestOptions) (msalConfidentialClient, *sync.Mutex, error) {
125	c.clientMu.Lock()
126	defer c.clientMu.Unlock()
127	if tro.EnableCAE {
128		if c.cae == nil {
129			client, err := c.newMSALClient(true)
130			if err != nil {
131				return nil, nil, err
132			}
133			c.cae = client
134		}
135		return c.cae, c.caeMu, nil
136	}
137	if c.noCAE == nil {
138		client, err := c.newMSALClient(false)
139		if err != nil {
140			return nil, nil, err
141		}
142		c.noCAE = client
143	}
144	return c.noCAE, c.noCAEMu, nil
145}
146
147func (c *confidentialClient) newMSALClient(enableCAE bool) (msalConfidentialClient, error) {
148	cache, err := internal.NewCache(c.opts.tokenCachePersistenceOptions, enableCAE)
149	if err != nil {
150		return nil, err
151	}
152	authority := runtime.JoinPaths(c.host, c.tenantID)
153	o := []confidential.Option{
154		confidential.WithAzureRegion(c.region),
155		confidential.WithCache(cache),
156		confidential.WithHTTPClient(c),
157	}
158	if enableCAE {
159		o = append(o, confidential.WithClientCapabilities(cp1))
160	}
161	if c.opts.SendX5C {
162		o = append(o, confidential.WithX5C())
163	}
164	if c.opts.DisableInstanceDiscovery || strings.ToLower(c.tenantID) == "adfs" {
165		o = append(o, confidential.WithInstanceDiscovery(false))
166	}
167	return confidential.New(authority, c.clientID, c.cred, o...)
168}
169
170// resolveTenant returns the correct WithTenantID() argument for a token request given the client's
171// configuration, or an error when that configuration doesn't allow the specified tenant
172func (c *confidentialClient) resolveTenant(specified string) (string, error) {
173	return resolveTenant(c.tenantID, specified, c.name, c.opts.AdditionallyAllowedTenants)
174}
175
176// these methods satisfy the MSAL ops.HTTPClient interface
177
178func (c *confidentialClient) CloseIdleConnections() {
179	// do nothing
180}
181
182func (c *confidentialClient) Do(r *http.Request) (*http.Response, error) {
183	return doForClient(c.azClient, r)
184}