1// Copyright (c) Microsoft Corporation.
2// Licensed under the MIT license.
3
// TODO(msal): Write some tests. The original code this came from had no tests, and
// they have not been added here yet. Like much of the other *Manager code found
// there, it was broken because it lacked mutex protection.
7
8package oauth
9
10import (
11 "context"
12 "errors"
13 "fmt"
14 "strings"
15 "sync"
16
17 "github.com/AzureAD/microsoft-authentication-library-for-go/apps/internal/oauth/ops"
18 "github.com/AzureAD/microsoft-authentication-library-for-go/apps/internal/oauth/ops/authority"
19)
20
const (
	// ADFS is the authority type value for Active Directory Federation Services
	// authorities; it is compared against authority.Info.AuthorityType.
	ADFS = "ADFS"
)
23
24type cacheEntry struct {
25 Endpoints authority.Endpoints
26 ValidForDomainsInList map[string]bool
27}
28
29func createcacheEntry(endpoints authority.Endpoints) cacheEntry {
30 return cacheEntry{endpoints, map[string]bool{}}
31}
32
33// AuthorityEndpoint retrieves endpoints from an authority for auth and token acquisition.
34type authorityEndpoint struct {
35 rest *ops.REST
36
37 mu sync.Mutex
38 cache map[string]cacheEntry
39}
40
41// newAuthorityEndpoint is the constructor for AuthorityEndpoint.
42func newAuthorityEndpoint(rest *ops.REST) *authorityEndpoint {
43 m := &authorityEndpoint{rest: rest, cache: map[string]cacheEntry{}}
44 return m
45}
46
47// ResolveEndpoints gets the authorization and token endpoints and creates an AuthorityEndpoints instance
48func (m *authorityEndpoint) ResolveEndpoints(ctx context.Context, authorityInfo authority.Info, userPrincipalName string) (authority.Endpoints, error) {
49
50 if endpoints, found := m.cachedEndpoints(authorityInfo, userPrincipalName); found {
51 return endpoints, nil
52 }
53
54 endpoint, err := m.openIDConfigurationEndpoint(ctx, authorityInfo, userPrincipalName)
55 if err != nil {
56 return authority.Endpoints{}, err
57 }
58
59 resp, err := m.rest.Authority().GetTenantDiscoveryResponse(ctx, endpoint)
60 if err != nil {
61 return authority.Endpoints{}, err
62 }
63 if err := resp.Validate(); err != nil {
64 return authority.Endpoints{}, fmt.Errorf("ResolveEndpoints(): %w", err)
65 }
66
67 tenant := authorityInfo.Tenant
68
69 endpoints := authority.NewEndpoints(
70 strings.Replace(resp.AuthorizationEndpoint, "{tenant}", tenant, -1),
71 strings.Replace(resp.TokenEndpoint, "{tenant}", tenant, -1),
72 strings.Replace(resp.Issuer, "{tenant}", tenant, -1),
73 authorityInfo.Host)
74
75 m.addCachedEndpoints(authorityInfo, userPrincipalName, endpoints)
76
77 return endpoints, nil
78}
79
80// cachedEndpoints returns a the cached endpoints if they exists. If not, we return false.
81func (m *authorityEndpoint) cachedEndpoints(authorityInfo authority.Info, userPrincipalName string) (authority.Endpoints, bool) {
82 m.mu.Lock()
83 defer m.mu.Unlock()
84
85 if cacheEntry, ok := m.cache[authorityInfo.CanonicalAuthorityURI]; ok {
86 if authorityInfo.AuthorityType == ADFS {
87 domain, err := adfsDomainFromUpn(userPrincipalName)
88 if err == nil {
89 if _, ok := cacheEntry.ValidForDomainsInList[domain]; ok {
90 return cacheEntry.Endpoints, true
91 }
92 }
93 }
94 return cacheEntry.Endpoints, true
95 }
96 return authority.Endpoints{}, false
97}
98
99func (m *authorityEndpoint) addCachedEndpoints(authorityInfo authority.Info, userPrincipalName string, endpoints authority.Endpoints) {
100 m.mu.Lock()
101 defer m.mu.Unlock()
102
103 updatedCacheEntry := createcacheEntry(endpoints)
104
105 if authorityInfo.AuthorityType == ADFS {
106 // Since we're here, we've made a call to the backend. We want to ensure we're caching
107 // the latest values from the server.
108 if cacheEntry, ok := m.cache[authorityInfo.CanonicalAuthorityURI]; ok {
109 for k := range cacheEntry.ValidForDomainsInList {
110 updatedCacheEntry.ValidForDomainsInList[k] = true
111 }
112 }
113 domain, err := adfsDomainFromUpn(userPrincipalName)
114 if err == nil {
115 updatedCacheEntry.ValidForDomainsInList[domain] = true
116 }
117 }
118
119 m.cache[authorityInfo.CanonicalAuthorityURI] = updatedCacheEntry
120}
121
122func (m *authorityEndpoint) openIDConfigurationEndpoint(ctx context.Context, authorityInfo authority.Info, userPrincipalName string) (string, error) {
123 if authorityInfo.Tenant == "adfs" {
124 return fmt.Sprintf("https://%s/adfs/.well-known/openid-configuration", authorityInfo.Host), nil
125 } else if authorityInfo.ValidateAuthority && !authority.TrustedHost(authorityInfo.Host) {
126 resp, err := m.rest.Authority().AADInstanceDiscovery(ctx, authorityInfo)
127 if err != nil {
128 return "", err
129 }
130 return resp.TenantDiscoveryEndpoint, nil
131 } else if authorityInfo.Region != "" {
132 resp, err := m.rest.Authority().AADInstanceDiscovery(ctx, authorityInfo)
133 if err != nil {
134 return "", err
135 }
136 return resp.TenantDiscoveryEndpoint, nil
137
138 }
139
140 return authorityInfo.CanonicalAuthorityURI + "v2.0/.well-known/openid-configuration", nil
141}
142
// adfsDomainFromUpn returns the domain portion of userPrincipalName: the text
// after the first "@", up to (but not including) any following "@". It returns
// an error when userPrincipalName contains no "@". The domain may be empty
// (e.g. for "user@").
func adfsDomainFromUpn(userPrincipalName string) (string, error) {
	at := strings.IndexByte(userPrincipalName, '@')
	if at < 0 {
		return "", errors.New("no @ present in user principal name")
	}
	domain := userPrincipalName[at+1:]
	if next := strings.IndexByte(domain, '@'); next >= 0 {
		domain = domain[:next]
	}
	return domain, nil
}