resolvers.go

  1// Copyright (c) Microsoft Corporation.
  2// Licensed under the MIT license.
  3
// TODO(msal): Write some tests. The original code this was ported from had no tests.
// Like much of the other *Manager code it came from, it was broken because it lacked
// mutex protection.
  7
  8package oauth
  9
 10import (
 11	"context"
 12	"errors"
 13	"fmt"
 14	"strings"
 15	"sync"
 16
 17	"github.com/AzureAD/microsoft-authentication-library-for-go/apps/internal/oauth/ops"
 18	"github.com/AzureAD/microsoft-authentication-library-for-go/apps/internal/oauth/ops/authority"
 19)
 20
 21// ADFS is an active directory federation service authority type.
 22const ADFS = "ADFS"
 23
 24type cacheEntry struct {
 25	Endpoints             authority.Endpoints
 26	ValidForDomainsInList map[string]bool
 27}
 28
 29func createcacheEntry(endpoints authority.Endpoints) cacheEntry {
 30	return cacheEntry{endpoints, map[string]bool{}}
 31}
 32
 33// AuthorityEndpoint retrieves endpoints from an authority for auth and token acquisition.
 34type authorityEndpoint struct {
 35	rest *ops.REST
 36
 37	mu    sync.Mutex
 38	cache map[string]cacheEntry
 39}
 40
 41// newAuthorityEndpoint is the constructor for AuthorityEndpoint.
 42func newAuthorityEndpoint(rest *ops.REST) *authorityEndpoint {
 43	m := &authorityEndpoint{rest: rest, cache: map[string]cacheEntry{}}
 44	return m
 45}
 46
 47// ResolveEndpoints gets the authorization and token endpoints and creates an AuthorityEndpoints instance
 48func (m *authorityEndpoint) ResolveEndpoints(ctx context.Context, authorityInfo authority.Info, userPrincipalName string) (authority.Endpoints, error) {
 49
 50	if endpoints, found := m.cachedEndpoints(authorityInfo, userPrincipalName); found {
 51		return endpoints, nil
 52	}
 53
 54	endpoint, err := m.openIDConfigurationEndpoint(ctx, authorityInfo, userPrincipalName)
 55	if err != nil {
 56		return authority.Endpoints{}, err
 57	}
 58
 59	resp, err := m.rest.Authority().GetTenantDiscoveryResponse(ctx, endpoint)
 60	if err != nil {
 61		return authority.Endpoints{}, err
 62	}
 63	if err := resp.Validate(); err != nil {
 64		return authority.Endpoints{}, fmt.Errorf("ResolveEndpoints(): %w", err)
 65	}
 66
 67	tenant := authorityInfo.Tenant
 68
 69	endpoints := authority.NewEndpoints(
 70		strings.Replace(resp.AuthorizationEndpoint, "{tenant}", tenant, -1),
 71		strings.Replace(resp.TokenEndpoint, "{tenant}", tenant, -1),
 72		strings.Replace(resp.Issuer, "{tenant}", tenant, -1),
 73		authorityInfo.Host)
 74
 75	m.addCachedEndpoints(authorityInfo, userPrincipalName, endpoints)
 76
 77	return endpoints, nil
 78}
 79
 80// cachedEndpoints returns a the cached endpoints if they exists. If not, we return false.
 81func (m *authorityEndpoint) cachedEndpoints(authorityInfo authority.Info, userPrincipalName string) (authority.Endpoints, bool) {
 82	m.mu.Lock()
 83	defer m.mu.Unlock()
 84
 85	if cacheEntry, ok := m.cache[authorityInfo.CanonicalAuthorityURI]; ok {
 86		if authorityInfo.AuthorityType == ADFS {
 87			domain, err := adfsDomainFromUpn(userPrincipalName)
 88			if err == nil {
 89				if _, ok := cacheEntry.ValidForDomainsInList[domain]; ok {
 90					return cacheEntry.Endpoints, true
 91				}
 92			}
 93		}
 94		return cacheEntry.Endpoints, true
 95	}
 96	return authority.Endpoints{}, false
 97}
 98
 99func (m *authorityEndpoint) addCachedEndpoints(authorityInfo authority.Info, userPrincipalName string, endpoints authority.Endpoints) {
100	m.mu.Lock()
101	defer m.mu.Unlock()
102
103	updatedCacheEntry := createcacheEntry(endpoints)
104
105	if authorityInfo.AuthorityType == ADFS {
106		// Since we're here, we've made a call to the backend.  We want to ensure we're caching
107		// the latest values from the server.
108		if cacheEntry, ok := m.cache[authorityInfo.CanonicalAuthorityURI]; ok {
109			for k := range cacheEntry.ValidForDomainsInList {
110				updatedCacheEntry.ValidForDomainsInList[k] = true
111			}
112		}
113		domain, err := adfsDomainFromUpn(userPrincipalName)
114		if err == nil {
115			updatedCacheEntry.ValidForDomainsInList[domain] = true
116		}
117	}
118
119	m.cache[authorityInfo.CanonicalAuthorityURI] = updatedCacheEntry
120}
121
122func (m *authorityEndpoint) openIDConfigurationEndpoint(ctx context.Context, authorityInfo authority.Info, userPrincipalName string) (string, error) {
123	if authorityInfo.Tenant == "adfs" {
124		return fmt.Sprintf("https://%s/adfs/.well-known/openid-configuration", authorityInfo.Host), nil
125	} else if authorityInfo.ValidateAuthority && !authority.TrustedHost(authorityInfo.Host) {
126		resp, err := m.rest.Authority().AADInstanceDiscovery(ctx, authorityInfo)
127		if err != nil {
128			return "", err
129		}
130		return resp.TenantDiscoveryEndpoint, nil
131	} else if authorityInfo.Region != "" {
132		resp, err := m.rest.Authority().AADInstanceDiscovery(ctx, authorityInfo)
133		if err != nil {
134			return "", err
135		}
136		return resp.TenantDiscoveryEndpoint, nil
137
138	}
139
140	return authorityInfo.CanonicalAuthorityURI + "v2.0/.well-known/openid-configuration", nil
141}
142
// adfsDomainFromUpn extracts the domain portion of a user principal name. The domain
// is the text after the first '@', up to (not including) any further '@'. An error is
// returned when userPrincipalName contains no '@'. An empty domain (e.g. "user@") is
// not an error.
func adfsDomainFromUpn(userPrincipalName string) (string, error) {
	at := strings.IndexByte(userPrincipalName, '@')
	if at < 0 {
		return "", errors.New("no @ present in user principal name")
	}
	domain := userPrincipalName[at+1:]
	if next := strings.IndexByte(domain, '@'); next >= 0 {
		domain = domain[:next]
	}
	return domain, nil
}