oauth.go

  1// Copyright (c) Microsoft Corporation.
  2// Licensed under the MIT license.
  3
  4package oauth
  5
  6import (
  7	"context"
  8	"encoding/json"
  9	"fmt"
 10	"io"
 11	"time"
 12
 13	"github.com/AzureAD/microsoft-authentication-library-for-go/apps/errors"
 14	"github.com/AzureAD/microsoft-authentication-library-for-go/apps/internal/exported"
 15	internalTime "github.com/AzureAD/microsoft-authentication-library-for-go/apps/internal/json/types/time"
 16	"github.com/AzureAD/microsoft-authentication-library-for-go/apps/internal/oauth/ops"
 17	"github.com/AzureAD/microsoft-authentication-library-for-go/apps/internal/oauth/ops/accesstokens"
 18	"github.com/AzureAD/microsoft-authentication-library-for-go/apps/internal/oauth/ops/authority"
 19	"github.com/AzureAD/microsoft-authentication-library-for-go/apps/internal/oauth/ops/wstrust"
 20	"github.com/AzureAD/microsoft-authentication-library-for-go/apps/internal/oauth/ops/wstrust/defs"
 21	"github.com/google/uuid"
 22)
 23
 24// ResolveEndpointer contains the methods for resolving authority endpoints.
 25type ResolveEndpointer interface {
 26	ResolveEndpoints(ctx context.Context, authorityInfo authority.Info, userPrincipalName string) (authority.Endpoints, error)
 27}
 28
 29// AccessTokens contains the methods for fetching tokens from different sources.
 30type AccessTokens interface {
 31	DeviceCodeResult(ctx context.Context, authParameters authority.AuthParams) (accesstokens.DeviceCodeResult, error)
 32	FromUsernamePassword(ctx context.Context, authParameters authority.AuthParams) (accesstokens.TokenResponse, error)
 33	FromAuthCode(ctx context.Context, req accesstokens.AuthCodeRequest) (accesstokens.TokenResponse, error)
 34	FromRefreshToken(ctx context.Context, appType accesstokens.AppType, authParams authority.AuthParams, cc *accesstokens.Credential, refreshToken string) (accesstokens.TokenResponse, error)
 35	FromClientSecret(ctx context.Context, authParameters authority.AuthParams, clientSecret string) (accesstokens.TokenResponse, error)
 36	FromAssertion(ctx context.Context, authParameters authority.AuthParams, assertion string) (accesstokens.TokenResponse, error)
 37	FromUserAssertionClientSecret(ctx context.Context, authParameters authority.AuthParams, userAssertion string, clientSecret string) (accesstokens.TokenResponse, error)
 38	FromUserAssertionClientCertificate(ctx context.Context, authParameters authority.AuthParams, userAssertion string, assertion string) (accesstokens.TokenResponse, error)
 39	FromDeviceCodeResult(ctx context.Context, authParameters authority.AuthParams, deviceCodeResult accesstokens.DeviceCodeResult) (accesstokens.TokenResponse, error)
 40	FromSamlGrant(ctx context.Context, authParameters authority.AuthParams, samlGrant wstrust.SamlTokenInfo) (accesstokens.TokenResponse, error)
 41}
 42
 43// FetchAuthority will be implemented by authority.Authority.
 44type FetchAuthority interface {
 45	UserRealm(context.Context, authority.AuthParams) (authority.UserRealm, error)
 46	AADInstanceDiscovery(context.Context, authority.Info) (authority.InstanceDiscoveryResponse, error)
 47}
 48
 49// FetchWSTrust contains the methods for interacting with WSTrust endpoints.
 50type FetchWSTrust interface {
 51	Mex(ctx context.Context, federationMetadataURL string) (defs.MexDocument, error)
 52	SAMLTokenInfo(ctx context.Context, authParameters authority.AuthParams, cloudAudienceURN string, endpoint defs.Endpoint) (wstrust.SamlTokenInfo, error)
 53}
 54
 55// Client provides tokens for various types of token requests.
 56type Client struct {
 57	Resolver     ResolveEndpointer
 58	AccessTokens AccessTokens
 59	Authority    FetchAuthority
 60	WSTrust      FetchWSTrust
 61}
 62
 63// New is the constructor for Token.
 64func New(httpClient ops.HTTPClient) *Client {
 65	r := ops.New(httpClient)
 66	return &Client{
 67		Resolver:     newAuthorityEndpoint(r),
 68		AccessTokens: r.AccessTokens(),
 69		Authority:    r.Authority(),
 70		WSTrust:      r.WSTrust(),
 71	}
 72}
 73
 74// ResolveEndpoints gets the authorization and token endpoints and creates an AuthorityEndpoints instance.
 75func (t *Client) ResolveEndpoints(ctx context.Context, authorityInfo authority.Info, userPrincipalName string) (authority.Endpoints, error) {
 76	return t.Resolver.ResolveEndpoints(ctx, authorityInfo, userPrincipalName)
 77}
 78
 79// AADInstanceDiscovery attempts to discover a tenant endpoint (used in OIDC auth with an authorization endpoint).
 80// This is done by AAD which allows for aliasing of tenants (windows.sts.net is the same as login.windows.com).
 81func (t *Client) AADInstanceDiscovery(ctx context.Context, authorityInfo authority.Info) (authority.InstanceDiscoveryResponse, error) {
 82	return t.Authority.AADInstanceDiscovery(ctx, authorityInfo)
 83}
 84
 85// AuthCode returns a token based on an authorization code.
 86func (t *Client) AuthCode(ctx context.Context, req accesstokens.AuthCodeRequest) (accesstokens.TokenResponse, error) {
 87	if err := scopeError(req.AuthParams); err != nil {
 88		return accesstokens.TokenResponse{}, err
 89	}
 90	if err := t.resolveEndpoint(ctx, &req.AuthParams, ""); err != nil {
 91		return accesstokens.TokenResponse{}, err
 92	}
 93
 94	tResp, err := t.AccessTokens.FromAuthCode(ctx, req)
 95	if err != nil {
 96		return accesstokens.TokenResponse{}, fmt.Errorf("could not retrieve token from auth code: %w", err)
 97	}
 98	return tResp, nil
 99}
100
101// Credential acquires a token from the authority using a client credentials grant.
102func (t *Client) Credential(ctx context.Context, authParams authority.AuthParams, cred *accesstokens.Credential) (accesstokens.TokenResponse, error) {
103	if cred.TokenProvider != nil {
104		now := time.Now()
105		scopes := make([]string, len(authParams.Scopes))
106		copy(scopes, authParams.Scopes)
107		params := exported.TokenProviderParameters{
108			Claims:        authParams.Claims,
109			CorrelationID: uuid.New().String(),
110			Scopes:        scopes,
111			TenantID:      authParams.AuthorityInfo.Tenant,
112		}
113		tr, err := cred.TokenProvider(ctx, params)
114		if err != nil {
115			if len(scopes) == 0 {
116				err = fmt.Errorf("token request had an empty authority.AuthParams.Scopes, which may cause the following error: %w", err)
117				return accesstokens.TokenResponse{}, err
118			}
119			return accesstokens.TokenResponse{}, err
120		}
121		return accesstokens.TokenResponse{
122			TokenType:   authParams.AuthnScheme.AccessTokenType(),
123			AccessToken: tr.AccessToken,
124			ExpiresOn: internalTime.DurationTime{
125				T: now.Add(time.Duration(tr.ExpiresInSeconds) * time.Second),
126			},
127			GrantedScopes: accesstokens.Scopes{Slice: authParams.Scopes},
128		}, nil
129	}
130
131	if err := t.resolveEndpoint(ctx, &authParams, ""); err != nil {
132		return accesstokens.TokenResponse{}, err
133	}
134
135	if cred.Secret != "" {
136		return t.AccessTokens.FromClientSecret(ctx, authParams, cred.Secret)
137	}
138	jwt, err := cred.JWT(ctx, authParams)
139	if err != nil {
140		return accesstokens.TokenResponse{}, err
141	}
142	return t.AccessTokens.FromAssertion(ctx, authParams, jwt)
143}
144
145// Credential acquires a token from the authority using a client credentials grant.
146func (t *Client) OnBehalfOf(ctx context.Context, authParams authority.AuthParams, cred *accesstokens.Credential) (accesstokens.TokenResponse, error) {
147	if err := scopeError(authParams); err != nil {
148		return accesstokens.TokenResponse{}, err
149	}
150	if err := t.resolveEndpoint(ctx, &authParams, ""); err != nil {
151		return accesstokens.TokenResponse{}, err
152	}
153
154	if cred.Secret != "" {
155		return t.AccessTokens.FromUserAssertionClientSecret(ctx, authParams, authParams.UserAssertion, cred.Secret)
156	}
157	jwt, err := cred.JWT(ctx, authParams)
158	if err != nil {
159		return accesstokens.TokenResponse{}, err
160	}
161	tr, err := t.AccessTokens.FromUserAssertionClientCertificate(ctx, authParams, authParams.UserAssertion, jwt)
162	if err != nil {
163		return accesstokens.TokenResponse{}, err
164	}
165	return tr, nil
166}
167
168func (t *Client) Refresh(ctx context.Context, reqType accesstokens.AppType, authParams authority.AuthParams, cc *accesstokens.Credential, refreshToken accesstokens.RefreshToken) (accesstokens.TokenResponse, error) {
169	if err := scopeError(authParams); err != nil {
170		return accesstokens.TokenResponse{}, err
171	}
172	if err := t.resolveEndpoint(ctx, &authParams, ""); err != nil {
173		return accesstokens.TokenResponse{}, err
174	}
175
176	tr, err := t.AccessTokens.FromRefreshToken(ctx, reqType, authParams, cc, refreshToken.Secret)
177	if err != nil {
178		return accesstokens.TokenResponse{}, err
179	}
180	return tr, nil
181}
182
183// UsernamePassword retrieves a token where a username and password is used. However, if this is
184// a user realm of "Federated", this uses SAML tokens. If "Managed", uses normal username/password.
185func (t *Client) UsernamePassword(ctx context.Context, authParams authority.AuthParams) (accesstokens.TokenResponse, error) {
186	if err := scopeError(authParams); err != nil {
187		return accesstokens.TokenResponse{}, err
188	}
189
190	if authParams.AuthorityInfo.AuthorityType == authority.ADFS {
191		if err := t.resolveEndpoint(ctx, &authParams, authParams.Username); err != nil {
192			return accesstokens.TokenResponse{}, err
193		}
194		return t.AccessTokens.FromUsernamePassword(ctx, authParams)
195	}
196	if err := t.resolveEndpoint(ctx, &authParams, ""); err != nil {
197		return accesstokens.TokenResponse{}, err
198	}
199
200	userRealm, err := t.Authority.UserRealm(ctx, authParams)
201	if err != nil {
202		return accesstokens.TokenResponse{}, fmt.Errorf("problem getting user realm from authority: %w", err)
203	}
204
205	switch userRealm.AccountType {
206	case authority.Federated:
207		mexDoc, err := t.WSTrust.Mex(ctx, userRealm.FederationMetadataURL)
208		if err != nil {
209			err = fmt.Errorf("problem getting mex doc from federated url(%s): %w", userRealm.FederationMetadataURL, err)
210			return accesstokens.TokenResponse{}, err
211		}
212
213		saml, err := t.WSTrust.SAMLTokenInfo(ctx, authParams, userRealm.CloudAudienceURN, mexDoc.UsernamePasswordEndpoint)
214		if err != nil {
215			err = fmt.Errorf("problem getting SAML token info: %w", err)
216			return accesstokens.TokenResponse{}, err
217		}
218		tr, err := t.AccessTokens.FromSamlGrant(ctx, authParams, saml)
219		if err != nil {
220			return accesstokens.TokenResponse{}, err
221		}
222		return tr, nil
223	case authority.Managed:
224		if len(authParams.Scopes) == 0 {
225			err = fmt.Errorf("token request had an empty authority.AuthParams.Scopes, which may cause the following error: %w", err)
226			return accesstokens.TokenResponse{}, err
227		}
228		return t.AccessTokens.FromUsernamePassword(ctx, authParams)
229	}
230	return accesstokens.TokenResponse{}, errors.New("unknown account type")
231}
232
233// DeviceCode is the result of a call to Token.DeviceCode().
234type DeviceCode struct {
235	// Result is the device code result from the first call in the device code flow. This allows
236	// the caller to retrieve the displayed code that is used to authorize on the second device.
237	Result     accesstokens.DeviceCodeResult
238	authParams authority.AuthParams
239
240	accessTokens AccessTokens
241}
242
243// Token returns a token AFTER the user uses the user code on the second device. This will block
244// until either: (1) the code is input by the user and the service releases a token, (2) the token
245// expires, (3) the Context passed to .DeviceCode() is cancelled or expires, (4) some other service
246// error occurs.
247func (d DeviceCode) Token(ctx context.Context) (accesstokens.TokenResponse, error) {
248	if d.accessTokens == nil {
249		return accesstokens.TokenResponse{}, fmt.Errorf("DeviceCode was either created outside its package or the creating method had an error. DeviceCode is not valid")
250	}
251
252	var cancel context.CancelFunc
253	if deadline, ok := ctx.Deadline(); !ok || d.Result.ExpiresOn.Before(deadline) {
254		ctx, cancel = context.WithDeadline(ctx, d.Result.ExpiresOn)
255	} else {
256		ctx, cancel = context.WithCancel(ctx)
257	}
258	defer cancel()
259
260	var interval = 50 * time.Millisecond
261	timer := time.NewTimer(interval)
262	defer timer.Stop()
263
264	for {
265		timer.Reset(interval)
266		select {
267		case <-ctx.Done():
268			return accesstokens.TokenResponse{}, ctx.Err()
269		case <-timer.C:
270			interval += interval * 2
271			if interval > 5*time.Second {
272				interval = 5 * time.Second
273			}
274		}
275
276		token, err := d.accessTokens.FromDeviceCodeResult(ctx, d.authParams, d.Result)
277		if err != nil && isWaitDeviceCodeErr(err) {
278			continue
279		}
280		return token, err // This handles if it was a non-wait error or success
281	}
282}
283
// deviceCodeError decodes only the "error" member of a failed device code token response.
// That member is all isWaitDeviceCodeErr needs to decide whether to keep polling.
type deviceCodeError struct {
	Error string `json:"error"` // error code, e.g. "authorization_pending" or "slow_down"
}
287
288func isWaitDeviceCodeErr(err error) bool {
289	var c errors.CallErr
290	if !errors.As(err, &c) {
291		return false
292	}
293	if c.Resp.StatusCode != 400 {
294		return false
295	}
296	var dCErr deviceCodeError
297	defer c.Resp.Body.Close()
298	body, err := io.ReadAll(c.Resp.Body)
299	if err != nil {
300		return false
301	}
302	err = json.Unmarshal(body, &dCErr)
303	if err != nil {
304		return false
305	}
306	if dCErr.Error == "authorization_pending" || dCErr.Error == "slow_down" {
307		return true
308	}
309	return false
310}
311
312// DeviceCode returns a DeviceCode object that can be used to get the code that must be entered on the second
313// device and optionally the token once the code has been entered on the second device.
314func (t *Client) DeviceCode(ctx context.Context, authParams authority.AuthParams) (DeviceCode, error) {
315	if err := scopeError(authParams); err != nil {
316		return DeviceCode{}, err
317	}
318
319	if err := t.resolveEndpoint(ctx, &authParams, ""); err != nil {
320		return DeviceCode{}, err
321	}
322
323	dcr, err := t.AccessTokens.DeviceCodeResult(ctx, authParams)
324	if err != nil {
325		return DeviceCode{}, err
326	}
327
328	return DeviceCode{Result: dcr, authParams: authParams, accessTokens: t.AccessTokens}, nil
329}
330
331func (t *Client) resolveEndpoint(ctx context.Context, authParams *authority.AuthParams, userPrincipalName string) error {
332	endpoints, err := t.Resolver.ResolveEndpoints(ctx, authParams.AuthorityInfo, userPrincipalName)
333	if err != nil {
334		return fmt.Errorf("unable to resolve an endpoint: %s", err)
335	}
336	authParams.Endpoints = endpoints
337	return nil
338}
339
340// scopeError takes an authority.AuthParams and returns an error
341// if len(AuthParams.Scope) == 0.
342func scopeError(a authority.AuthParams) error {
343	// TODO(someone): we could look deeper at the message to determine if
344	// it's a scope error, but this is a good start.
345	/*
346		{error":"invalid_scope","error_description":"AADSTS1002012: The provided value for scope
347		openid offline_access profile is not valid. Client credential flows must have a scope value
348		with /.default suffixed to the resource identifier (application ID URI)...}
349	*/
350	if len(a.Scopes) == 0 {
351		return fmt.Errorf("token request had an empty authority.AuthParams.Scopes, which is invalid")
352	}
353	return nil
354}