1// Copyright (c) Microsoft Corporation.
2// Licensed under the MIT license.
3
4package oauth
5
6import (
7 "context"
8 "encoding/json"
9 "fmt"
10 "io"
11 "time"
12
13 "github.com/AzureAD/microsoft-authentication-library-for-go/apps/errors"
14 "github.com/AzureAD/microsoft-authentication-library-for-go/apps/internal/exported"
15 internalTime "github.com/AzureAD/microsoft-authentication-library-for-go/apps/internal/json/types/time"
16 "github.com/AzureAD/microsoft-authentication-library-for-go/apps/internal/oauth/ops"
17 "github.com/AzureAD/microsoft-authentication-library-for-go/apps/internal/oauth/ops/accesstokens"
18 "github.com/AzureAD/microsoft-authentication-library-for-go/apps/internal/oauth/ops/authority"
19 "github.com/AzureAD/microsoft-authentication-library-for-go/apps/internal/oauth/ops/wstrust"
20 "github.com/AzureAD/microsoft-authentication-library-for-go/apps/internal/oauth/ops/wstrust/defs"
21 "github.com/google/uuid"
22)
23
24// ResolveEndpointer contains the methods for resolving authority endpoints.
25type ResolveEndpointer interface {
26 ResolveEndpoints(ctx context.Context, authorityInfo authority.Info, userPrincipalName string) (authority.Endpoints, error)
27}
28
29// AccessTokens contains the methods for fetching tokens from different sources.
30type AccessTokens interface {
31 DeviceCodeResult(ctx context.Context, authParameters authority.AuthParams) (accesstokens.DeviceCodeResult, error)
32 FromUsernamePassword(ctx context.Context, authParameters authority.AuthParams) (accesstokens.TokenResponse, error)
33 FromAuthCode(ctx context.Context, req accesstokens.AuthCodeRequest) (accesstokens.TokenResponse, error)
34 FromRefreshToken(ctx context.Context, appType accesstokens.AppType, authParams authority.AuthParams, cc *accesstokens.Credential, refreshToken string) (accesstokens.TokenResponse, error)
35 FromClientSecret(ctx context.Context, authParameters authority.AuthParams, clientSecret string) (accesstokens.TokenResponse, error)
36 FromAssertion(ctx context.Context, authParameters authority.AuthParams, assertion string) (accesstokens.TokenResponse, error)
37 FromUserAssertionClientSecret(ctx context.Context, authParameters authority.AuthParams, userAssertion string, clientSecret string) (accesstokens.TokenResponse, error)
38 FromUserAssertionClientCertificate(ctx context.Context, authParameters authority.AuthParams, userAssertion string, assertion string) (accesstokens.TokenResponse, error)
39 FromDeviceCodeResult(ctx context.Context, authParameters authority.AuthParams, deviceCodeResult accesstokens.DeviceCodeResult) (accesstokens.TokenResponse, error)
40 FromSamlGrant(ctx context.Context, authParameters authority.AuthParams, samlGrant wstrust.SamlTokenInfo) (accesstokens.TokenResponse, error)
41}
42
43// FetchAuthority will be implemented by authority.Authority.
44type FetchAuthority interface {
45 UserRealm(context.Context, authority.AuthParams) (authority.UserRealm, error)
46 AADInstanceDiscovery(context.Context, authority.Info) (authority.InstanceDiscoveryResponse, error)
47}
48
49// FetchWSTrust contains the methods for interacting with WSTrust endpoints.
50type FetchWSTrust interface {
51 Mex(ctx context.Context, federationMetadataURL string) (defs.MexDocument, error)
52 SAMLTokenInfo(ctx context.Context, authParameters authority.AuthParams, cloudAudienceURN string, endpoint defs.Endpoint) (wstrust.SamlTokenInfo, error)
53}
54
55// Client provides tokens for various types of token requests.
56type Client struct {
57 Resolver ResolveEndpointer
58 AccessTokens AccessTokens
59 Authority FetchAuthority
60 WSTrust FetchWSTrust
61}
62
63// New is the constructor for Token.
64func New(httpClient ops.HTTPClient) *Client {
65 r := ops.New(httpClient)
66 return &Client{
67 Resolver: newAuthorityEndpoint(r),
68 AccessTokens: r.AccessTokens(),
69 Authority: r.Authority(),
70 WSTrust: r.WSTrust(),
71 }
72}
73
74// ResolveEndpoints gets the authorization and token endpoints and creates an AuthorityEndpoints instance.
75func (t *Client) ResolveEndpoints(ctx context.Context, authorityInfo authority.Info, userPrincipalName string) (authority.Endpoints, error) {
76 return t.Resolver.ResolveEndpoints(ctx, authorityInfo, userPrincipalName)
77}
78
79// AADInstanceDiscovery attempts to discover a tenant endpoint (used in OIDC auth with an authorization endpoint).
80// This is done by AAD which allows for aliasing of tenants (windows.sts.net is the same as login.windows.com).
81func (t *Client) AADInstanceDiscovery(ctx context.Context, authorityInfo authority.Info) (authority.InstanceDiscoveryResponse, error) {
82 return t.Authority.AADInstanceDiscovery(ctx, authorityInfo)
83}
84
85// AuthCode returns a token based on an authorization code.
86func (t *Client) AuthCode(ctx context.Context, req accesstokens.AuthCodeRequest) (accesstokens.TokenResponse, error) {
87 if err := scopeError(req.AuthParams); err != nil {
88 return accesstokens.TokenResponse{}, err
89 }
90 if err := t.resolveEndpoint(ctx, &req.AuthParams, ""); err != nil {
91 return accesstokens.TokenResponse{}, err
92 }
93
94 tResp, err := t.AccessTokens.FromAuthCode(ctx, req)
95 if err != nil {
96 return accesstokens.TokenResponse{}, fmt.Errorf("could not retrieve token from auth code: %w", err)
97 }
98 return tResp, nil
99}
100
101// Credential acquires a token from the authority using a client credentials grant.
102func (t *Client) Credential(ctx context.Context, authParams authority.AuthParams, cred *accesstokens.Credential) (accesstokens.TokenResponse, error) {
103 if cred.TokenProvider != nil {
104 now := time.Now()
105 scopes := make([]string, len(authParams.Scopes))
106 copy(scopes, authParams.Scopes)
107 params := exported.TokenProviderParameters{
108 Claims: authParams.Claims,
109 CorrelationID: uuid.New().String(),
110 Scopes: scopes,
111 TenantID: authParams.AuthorityInfo.Tenant,
112 }
113 tr, err := cred.TokenProvider(ctx, params)
114 if err != nil {
115 if len(scopes) == 0 {
116 err = fmt.Errorf("token request had an empty authority.AuthParams.Scopes, which may cause the following error: %w", err)
117 return accesstokens.TokenResponse{}, err
118 }
119 return accesstokens.TokenResponse{}, err
120 }
121 return accesstokens.TokenResponse{
122 TokenType: authParams.AuthnScheme.AccessTokenType(),
123 AccessToken: tr.AccessToken,
124 ExpiresOn: internalTime.DurationTime{
125 T: now.Add(time.Duration(tr.ExpiresInSeconds) * time.Second),
126 },
127 GrantedScopes: accesstokens.Scopes{Slice: authParams.Scopes},
128 }, nil
129 }
130
131 if err := t.resolveEndpoint(ctx, &authParams, ""); err != nil {
132 return accesstokens.TokenResponse{}, err
133 }
134
135 if cred.Secret != "" {
136 return t.AccessTokens.FromClientSecret(ctx, authParams, cred.Secret)
137 }
138 jwt, err := cred.JWT(ctx, authParams)
139 if err != nil {
140 return accesstokens.TokenResponse{}, err
141 }
142 return t.AccessTokens.FromAssertion(ctx, authParams, jwt)
143}
144
145// Credential acquires a token from the authority using a client credentials grant.
146func (t *Client) OnBehalfOf(ctx context.Context, authParams authority.AuthParams, cred *accesstokens.Credential) (accesstokens.TokenResponse, error) {
147 if err := scopeError(authParams); err != nil {
148 return accesstokens.TokenResponse{}, err
149 }
150 if err := t.resolveEndpoint(ctx, &authParams, ""); err != nil {
151 return accesstokens.TokenResponse{}, err
152 }
153
154 if cred.Secret != "" {
155 return t.AccessTokens.FromUserAssertionClientSecret(ctx, authParams, authParams.UserAssertion, cred.Secret)
156 }
157 jwt, err := cred.JWT(ctx, authParams)
158 if err != nil {
159 return accesstokens.TokenResponse{}, err
160 }
161 tr, err := t.AccessTokens.FromUserAssertionClientCertificate(ctx, authParams, authParams.UserAssertion, jwt)
162 if err != nil {
163 return accesstokens.TokenResponse{}, err
164 }
165 return tr, nil
166}
167
168func (t *Client) Refresh(ctx context.Context, reqType accesstokens.AppType, authParams authority.AuthParams, cc *accesstokens.Credential, refreshToken accesstokens.RefreshToken) (accesstokens.TokenResponse, error) {
169 if err := scopeError(authParams); err != nil {
170 return accesstokens.TokenResponse{}, err
171 }
172 if err := t.resolveEndpoint(ctx, &authParams, ""); err != nil {
173 return accesstokens.TokenResponse{}, err
174 }
175
176 tr, err := t.AccessTokens.FromRefreshToken(ctx, reqType, authParams, cc, refreshToken.Secret)
177 if err != nil {
178 return accesstokens.TokenResponse{}, err
179 }
180 return tr, nil
181}
182
183// UsernamePassword retrieves a token where a username and password is used. However, if this is
184// a user realm of "Federated", this uses SAML tokens. If "Managed", uses normal username/password.
185func (t *Client) UsernamePassword(ctx context.Context, authParams authority.AuthParams) (accesstokens.TokenResponse, error) {
186 if err := scopeError(authParams); err != nil {
187 return accesstokens.TokenResponse{}, err
188 }
189
190 if authParams.AuthorityInfo.AuthorityType == authority.ADFS {
191 if err := t.resolveEndpoint(ctx, &authParams, authParams.Username); err != nil {
192 return accesstokens.TokenResponse{}, err
193 }
194 return t.AccessTokens.FromUsernamePassword(ctx, authParams)
195 }
196 if err := t.resolveEndpoint(ctx, &authParams, ""); err != nil {
197 return accesstokens.TokenResponse{}, err
198 }
199
200 userRealm, err := t.Authority.UserRealm(ctx, authParams)
201 if err != nil {
202 return accesstokens.TokenResponse{}, fmt.Errorf("problem getting user realm from authority: %w", err)
203 }
204
205 switch userRealm.AccountType {
206 case authority.Federated:
207 mexDoc, err := t.WSTrust.Mex(ctx, userRealm.FederationMetadataURL)
208 if err != nil {
209 err = fmt.Errorf("problem getting mex doc from federated url(%s): %w", userRealm.FederationMetadataURL, err)
210 return accesstokens.TokenResponse{}, err
211 }
212
213 saml, err := t.WSTrust.SAMLTokenInfo(ctx, authParams, userRealm.CloudAudienceURN, mexDoc.UsernamePasswordEndpoint)
214 if err != nil {
215 err = fmt.Errorf("problem getting SAML token info: %w", err)
216 return accesstokens.TokenResponse{}, err
217 }
218 tr, err := t.AccessTokens.FromSamlGrant(ctx, authParams, saml)
219 if err != nil {
220 return accesstokens.TokenResponse{}, err
221 }
222 return tr, nil
223 case authority.Managed:
224 if len(authParams.Scopes) == 0 {
225 err = fmt.Errorf("token request had an empty authority.AuthParams.Scopes, which may cause the following error: %w", err)
226 return accesstokens.TokenResponse{}, err
227 }
228 return t.AccessTokens.FromUsernamePassword(ctx, authParams)
229 }
230 return accesstokens.TokenResponse{}, errors.New("unknown account type")
231}
232
233// DeviceCode is the result of a call to Token.DeviceCode().
234type DeviceCode struct {
235 // Result is the device code result from the first call in the device code flow. This allows
236 // the caller to retrieve the displayed code that is used to authorize on the second device.
237 Result accesstokens.DeviceCodeResult
238 authParams authority.AuthParams
239
240 accessTokens AccessTokens
241}
242
243// Token returns a token AFTER the user uses the user code on the second device. This will block
244// until either: (1) the code is input by the user and the service releases a token, (2) the token
245// expires, (3) the Context passed to .DeviceCode() is cancelled or expires, (4) some other service
246// error occurs.
247func (d DeviceCode) Token(ctx context.Context) (accesstokens.TokenResponse, error) {
248 if d.accessTokens == nil {
249 return accesstokens.TokenResponse{}, fmt.Errorf("DeviceCode was either created outside its package or the creating method had an error. DeviceCode is not valid")
250 }
251
252 var cancel context.CancelFunc
253 if deadline, ok := ctx.Deadline(); !ok || d.Result.ExpiresOn.Before(deadline) {
254 ctx, cancel = context.WithDeadline(ctx, d.Result.ExpiresOn)
255 } else {
256 ctx, cancel = context.WithCancel(ctx)
257 }
258 defer cancel()
259
260 var interval = 50 * time.Millisecond
261 timer := time.NewTimer(interval)
262 defer timer.Stop()
263
264 for {
265 timer.Reset(interval)
266 select {
267 case <-ctx.Done():
268 return accesstokens.TokenResponse{}, ctx.Err()
269 case <-timer.C:
270 interval += interval * 2
271 if interval > 5*time.Second {
272 interval = 5 * time.Second
273 }
274 }
275
276 token, err := d.accessTokens.FromDeviceCodeResult(ctx, d.authParams, d.Result)
277 if err != nil && isWaitDeviceCodeErr(err) {
278 continue
279 }
280 return token, err // This handles if it was a non-wait error or success
281 }
282}
283
// deviceCodeError is the part of an error response body that isWaitDeviceCodeErr decodes:
// only the "error" code is kept.
type deviceCodeError struct {
	Error string `json:"error"` // compared against "authorization_pending" and "slow_down"
}
287
288func isWaitDeviceCodeErr(err error) bool {
289 var c errors.CallErr
290 if !errors.As(err, &c) {
291 return false
292 }
293 if c.Resp.StatusCode != 400 {
294 return false
295 }
296 var dCErr deviceCodeError
297 defer c.Resp.Body.Close()
298 body, err := io.ReadAll(c.Resp.Body)
299 if err != nil {
300 return false
301 }
302 err = json.Unmarshal(body, &dCErr)
303 if err != nil {
304 return false
305 }
306 if dCErr.Error == "authorization_pending" || dCErr.Error == "slow_down" {
307 return true
308 }
309 return false
310}
311
312// DeviceCode returns a DeviceCode object that can be used to get the code that must be entered on the second
313// device and optionally the token once the code has been entered on the second device.
314func (t *Client) DeviceCode(ctx context.Context, authParams authority.AuthParams) (DeviceCode, error) {
315 if err := scopeError(authParams); err != nil {
316 return DeviceCode{}, err
317 }
318
319 if err := t.resolveEndpoint(ctx, &authParams, ""); err != nil {
320 return DeviceCode{}, err
321 }
322
323 dcr, err := t.AccessTokens.DeviceCodeResult(ctx, authParams)
324 if err != nil {
325 return DeviceCode{}, err
326 }
327
328 return DeviceCode{Result: dcr, authParams: authParams, accessTokens: t.AccessTokens}, nil
329}
330
331func (t *Client) resolveEndpoint(ctx context.Context, authParams *authority.AuthParams, userPrincipalName string) error {
332 endpoints, err := t.Resolver.ResolveEndpoints(ctx, authParams.AuthorityInfo, userPrincipalName)
333 if err != nil {
334 return fmt.Errorf("unable to resolve an endpoint: %s", err)
335 }
336 authParams.Endpoints = endpoints
337 return nil
338}
339
340// scopeError takes an authority.AuthParams and returns an error
341// if len(AuthParams.Scope) == 0.
342func scopeError(a authority.AuthParams) error {
343 // TODO(someone): we could look deeper at the message to determine if
344 // it's a scope error, but this is a good start.
345 /*
346 {error":"invalid_scope","error_description":"AADSTS1002012: The provided value for scope
347 openid offline_access profile is not valid. Client credential flows must have a scope value
348 with /.default suffixed to the resource identifier (application ID URI)...}
349 */
350 if len(a.Scopes) == 0 {
351 return fmt.Errorf("token request had an empty authority.AuthParams.Scopes, which is invalid")
352 }
353 return nil
354}