public_client.go

  1//go:build go1.18
  2// +build go1.18
  3
  4// Copyright (c) Microsoft Corporation. All rights reserved.
  5// Licensed under the MIT License.
  6
  7package azidentity
  8
  9import (
 10	"context"
 11	"errors"
 12	"fmt"
 13	"net/http"
 14	"strings"
 15	"sync"
 16
 17	"github.com/Azure/azure-sdk-for-go/sdk/azcore"
 18	"github.com/Azure/azure-sdk-for-go/sdk/azcore/cloud"
 19	"github.com/Azure/azure-sdk-for-go/sdk/azcore/policy"
 20	"github.com/Azure/azure-sdk-for-go/sdk/azcore/runtime"
 21	"github.com/Azure/azure-sdk-for-go/sdk/azidentity/internal"
 22	"github.com/Azure/azure-sdk-for-go/sdk/internal/log"
 23	"github.com/AzureAD/microsoft-authentication-library-for-go/apps/public"
 24
 25	// this import ensures well-known configurations in azcore/cloud have ARM audiences for Authenticate()
 26	_ "github.com/Azure/azure-sdk-for-go/sdk/azcore/arm/runtime"
 27)
 28
 29type publicClientOptions struct {
 30	azcore.ClientOptions
 31
 32	AdditionallyAllowedTenants     []string
 33	DeviceCodePrompt               func(context.Context, DeviceCodeMessage) error
 34	DisableAutomaticAuthentication bool
 35	DisableInstanceDiscovery       bool
 36	LoginHint, RedirectURL         string
 37	Record                         authenticationRecord
 38	TokenCachePersistenceOptions   *tokenCachePersistenceOptions
 39	Username, Password             string
 40}
 41
 42// publicClient wraps the MSAL public client
 43type publicClient struct {
 44	cae, noCAE               msalPublicClient
 45	caeMu, noCAEMu, clientMu *sync.Mutex
 46	clientID, tenantID       string
 47	defaultScope             []string
 48	host                     string
 49	name                     string
 50	opts                     publicClientOptions
 51	record                   authenticationRecord
 52	azClient                 *azcore.Client
 53}
 54
 55var errScopeRequired = errors.New("authenticating in this environment requires specifying a scope in TokenRequestOptions")
 56
 57func newPublicClient(tenantID, clientID, name string, o publicClientOptions) (*publicClient, error) {
 58	if !validTenantID(tenantID) {
 59		return nil, errInvalidTenantID
 60	}
 61	host, err := setAuthorityHost(o.Cloud)
 62	if err != nil {
 63		return nil, err
 64	}
 65	// if the application specified a cloud configuration, use its ARM audience as the default scope for Authenticate()
 66	audience := o.Cloud.Services[cloud.ResourceManager].Audience
 67	if audience == "" {
 68		// no cloud configuration, or no ARM audience, specified; try to map the host to a well-known one (all of which have a trailing slash)
 69		if !strings.HasSuffix(host, "/") {
 70			host += "/"
 71		}
 72		switch host {
 73		case cloud.AzureChina.ActiveDirectoryAuthorityHost:
 74			audience = cloud.AzureChina.Services[cloud.ResourceManager].Audience
 75		case cloud.AzureGovernment.ActiveDirectoryAuthorityHost:
 76			audience = cloud.AzureGovernment.Services[cloud.ResourceManager].Audience
 77		case cloud.AzurePublic.ActiveDirectoryAuthorityHost:
 78			audience = cloud.AzurePublic.Services[cloud.ResourceManager].Audience
 79		}
 80	}
 81	// if we didn't come up with an audience, the application will have to specify a scope for Authenticate()
 82	var defaultScope []string
 83	if audience != "" {
 84		defaultScope = []string{audience + defaultSuffix}
 85	}
 86	client, err := azcore.NewClient(module, version, runtime.PipelineOptions{
 87		Tracing: runtime.TracingOptions{
 88			Namespace: traceNamespace,
 89		},
 90	}, &o.ClientOptions)
 91	if err != nil {
 92		return nil, err
 93	}
 94	o.AdditionallyAllowedTenants = resolveAdditionalTenants(o.AdditionallyAllowedTenants)
 95	return &publicClient{
 96		caeMu:        &sync.Mutex{},
 97		clientID:     clientID,
 98		clientMu:     &sync.Mutex{},
 99		defaultScope: defaultScope,
100		host:         host,
101		name:         name,
102		noCAEMu:      &sync.Mutex{},
103		opts:         o,
104		record:       o.Record,
105		tenantID:     tenantID,
106		azClient:     client,
107	}, nil
108}
109
110func (p *publicClient) Authenticate(ctx context.Context, tro *policy.TokenRequestOptions) (authenticationRecord, error) {
111	if tro == nil {
112		tro = &policy.TokenRequestOptions{}
113	}
114	if len(tro.Scopes) == 0 {
115		if p.defaultScope == nil {
116			return authenticationRecord{}, errScopeRequired
117		}
118		tro.Scopes = p.defaultScope
119	}
120	client, mu, err := p.client(*tro)
121	if err != nil {
122		return authenticationRecord{}, err
123	}
124	mu.Lock()
125	defer mu.Unlock()
126	_, err = p.reqToken(ctx, client, *tro)
127	if err == nil {
128		scope := strings.Join(tro.Scopes, ", ")
129		msg := fmt.Sprintf("%s.Authenticate() acquired a token for scope %q", p.name, scope)
130		log.Write(EventAuthentication, msg)
131	}
132	return p.record, err
133}
134
// GetToken requests an access token from MSAL, checking the cache first. If silent acquisition
// fails, it either returns an authentication required error (when automatic authentication is
// disabled) or starts the credential's authentication flow via reqToken.
func (p *publicClient) GetToken(ctx context.Context, tro policy.TokenRequestOptions) (azcore.AccessToken, error) {
	if len(tro.Scopes) < 1 {
		return azcore.AccessToken{}, fmt.Errorf("%s.GetToken() requires at least one scope", p.name)
	}
	tenant, err := p.resolveTenant(tro.TenantID)
	if err != nil {
		return azcore.AccessToken{}, err
	}
	client, mu, err := p.client(tro)
	if err != nil {
		return azcore.AccessToken{}, err
	}
	mu.Lock()
	defer mu.Unlock()
	ar, err := client.AcquireTokenSilent(ctx, tro.Scopes, public.WithSilentAccount(p.record.account()), public.WithClaims(tro.Claims), public.WithTenantID(tenant))
	if err == nil {
		return p.token(ar, err)
	}
	if p.opts.DisableAutomaticAuthentication {
		return azcore.AccessToken{}, newauthenticationRequiredError(p.name, tro)
	}
	at, err := p.reqToken(ctx, client, tro)
	if err == nil {
		// ar is the result of the failed silent request, not of reqToken, so its GrantedScopes
		// don't describe the new token; log the requested scopes as Authenticate() does
		msg := fmt.Sprintf("%s.GetToken() acquired a token for scope %q", p.name, strings.Join(tro.Scopes, ", "))
		log.Write(EventAuthentication, msg)
	}
	return at, err
}
164
// reqToken requests a token from the MSAL public client without consulting the cache. It's
// separate from GetToken() so that Authenticate() can bypass the cache. p.name selects the flow
// (interactive browser, device code or username/password); any other name is an error.
func (p *publicClient) reqToken(ctx context.Context, c msalPublicClient, tro policy.TokenRequestOptions) (azcore.AccessToken, error) {
	tenant, err := p.resolveTenant(tro.TenantID)
	if err != nil {
		return azcore.AccessToken{}, err
	}
	var result public.AuthResult
	switch p.name {
	case credNameBrowser:
		result, err = c.AcquireTokenInteractive(ctx, tro.Scopes,
			public.WithClaims(tro.Claims),
			public.WithLoginHint(p.opts.LoginHint),
			public.WithRedirectURI(p.opts.RedirectURL),
			public.WithTenantID(tenant),
		)
	case credNameDeviceCode:
		dc, dcErr := c.AcquireTokenByDeviceCode(ctx, tro.Scopes, public.WithClaims(tro.Claims), public.WithTenantID(tenant))
		if dcErr != nil {
			// returned as-is; this error isn't wrapped by p.token
			return azcore.AccessToken{}, dcErr
		}
		prompt := DeviceCodeMessage{
			Message:         dc.Result.Message,
			UserCode:        dc.Result.UserCode,
			VerificationURL: dc.Result.VerificationURL,
		}
		// wait for the authentication result only after the application accepted the prompt
		if err = p.opts.DeviceCodePrompt(ctx, prompt); err == nil {
			result, err = dc.AuthenticationResult(ctx)
		}
	case credNameUserPassword:
		result, err = c.AcquireTokenByUsernamePassword(ctx, tro.Scopes, p.opts.Username, p.opts.Password, public.WithClaims(tro.Claims), public.WithTenantID(tenant))
	default:
		return azcore.AccessToken{}, fmt.Errorf("unknown credential %q", p.name)
	}
	return p.token(result, err)
}
200
// client returns the MSAL client for tro (the CAE client when tro.EnableCAE is set, otherwise
// the non-CAE client) along with the mutex that serializes requests on it. The client is
// created on first use; clientMu guards that creation.
func (p *publicClient) client(tro policy.TokenRequestOptions) (msalPublicClient, *sync.Mutex, error) {
	p.clientMu.Lock()
	defer p.clientMu.Unlock()
	slot, mu := &p.noCAE, p.noCAEMu
	if tro.EnableCAE {
		slot, mu = &p.cae, p.caeMu
	}
	if *slot == nil {
		created, err := p.newMSALClient(tro.EnableCAE)
		if err != nil {
			return nil, nil, err
		}
		*slot = created
	}
	return *slot, mu, nil
}
223
224func (p *publicClient) newMSALClient(enableCAE bool) (msalPublicClient, error) {
225	cache, err := internal.NewCache(p.opts.TokenCachePersistenceOptions, enableCAE)
226	if err != nil {
227		return nil, err
228	}
229	o := []public.Option{
230		public.WithAuthority(runtime.JoinPaths(p.host, p.tenantID)),
231		public.WithCache(cache),
232		public.WithHTTPClient(p),
233	}
234	if enableCAE {
235		o = append(o, public.WithClientCapabilities(cp1))
236	}
237	if p.opts.DisableInstanceDiscovery || strings.ToLower(p.tenantID) == "adfs" {
238		o = append(o, public.WithInstanceDiscovery(false))
239	}
240	return public.New(p.clientID, o...)
241}
242
243func (p *publicClient) token(ar public.AuthResult, err error) (azcore.AccessToken, error) {
244	if err == nil {
245		p.record, err = newAuthenticationRecord(ar)
246	} else {
247		res := getResponseFromError(err)
248		err = newAuthenticationFailedError(p.name, err.Error(), res, err)
249	}
250	return azcore.AccessToken{Token: ar.AccessToken, ExpiresOn: ar.ExpiresOn.UTC()}, err
251}
252
253// resolveTenant returns the correct WithTenantID() argument for a token request given the client's
254// configuration, or an error when that configuration doesn't allow the specified tenant
255func (p *publicClient) resolveTenant(specified string) (string, error) {
256	t, err := resolveTenant(p.tenantID, specified, p.name, p.opts.AdditionallyAllowedTenants)
257	if t == p.tenantID {
258		// callers pass this value to MSAL's WithTenantID(). There's no need to redundantly specify
259		// the client's default tenant and doing so is an error when that tenant is "organizations"
260		t = ""
261	}
262	return t, err
263}
264
265// these methods satisfy the MSAL ops.HTTPClient interface
266
267func (p *publicClient) CloseIdleConnections() {
268	// do nothing
269}
270
// Do is part of the MSAL ops.HTTPClient interface. It sends r using the client's azcore
// client, by way of doForClient.
func (p *publicClient) Do(r *http.Request) (*http.Response, error) {
	pipelineClient := p.azClient
	return doForClient(pipelineClient, r)
}