threelegged.go

  1// Copyright 2023 Google LLC
  2//
  3// Licensed under the Apache License, Version 2.0 (the "License");
  4// you may not use this file except in compliance with the License.
  5// You may obtain a copy of the License at
  6//
  7//      http://www.apache.org/licenses/LICENSE-2.0
  8//
  9// Unless required by applicable law or agreed to in writing, software
 10// distributed under the License is distributed on an "AS IS" BASIS,
 11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 12// See the License for the specific language governing permissions and
 13// limitations under the License.
 14
 15package auth
 16
 17import (
 18	"bytes"
 19	"context"
 20	"encoding/json"
 21	"errors"
 22	"fmt"
 23	"log/slog"
 24	"mime"
 25	"net/http"
 26	"net/url"
 27	"strconv"
 28	"strings"
 29	"time"
 30
 31	"cloud.google.com/go/auth/internal"
 32	"github.com/googleapis/gax-go/v2/internallog"
 33)
 34
// AuthorizationHandler is a 3-legged-OAuth helper that prompts the user for
// OAuth consent at the specified auth code URL and returns an auth code and
// state upon approval.
//
// The returned state is compared against [AuthorizationHandlerOptions.State]
// before the auth code is exchanged for a token; a mismatch is reported as an
// error and no exchange takes place.
type AuthorizationHandler func(authCodeURL string) (code string, state string, err error)
 39
// Options3LO are the options for doing a 3-legged OAuth2 flow.
type Options3LO struct {
	// ClientID is the application's ID. Required.
	ClientID string
	// ClientSecret is the application's secret. Not required if AuthHandlerOpts
	// is set.
	ClientSecret string
	// AuthURL is the URL of the OAuth2 consent page; it is the base of the auth
	// code URL passed to the AuthorizationHandler. Required.
	AuthURL string
	// TokenURL is the URL for retrieving a token. Required.
	TokenURL string
	// AuthStyle describes how the client ID and secret are sent in the token
	// request: as form parameters (StyleInParams) or via HTTP Basic
	// authentication (StyleInHeader). Required.
	AuthStyle Style
	// RefreshToken is the token used to refresh the credential. Not required
	// if AuthHandlerOpts is set.
	RefreshToken string
	// RedirectURL is sent as the redirect_uri parameter of both the auth code
	// URL and the auth code exchange. Optional.
	RedirectURL string
	// Scopes specifies requested permissions for the Token. They are joined
	// with spaces into the scope parameter of the auth code URL. Optional.
	Scopes []string

	// URLParams are extra values added to every token request (both the auth
	// code exchange and refresh requests). Only the first value of each key is
	// used. Optional.
	URLParams url.Values
	// Client is the client to be used to make the underlying token requests.
	// If nil, a default client is used. Optional.
	Client *http.Client
	// EarlyTokenExpiry is the time before the token expires that it should be
	// refreshed. If not set the default value is 3 minutes and 45 seconds.
	// Optional.
	EarlyTokenExpiry time.Duration

	// AuthHandlerOpts provides a set of options for doing a
	// 3-legged OAuth2 flow with a custom [AuthorizationHandler]. Optional.
	AuthHandlerOpts *AuthorizationHandlerOptions
	// Logger is used for debug logging. If provided, logging will be enabled
	// at the logger's configured level. By default logging is disabled unless
	// enabled by setting GOOGLE_SDK_GO_LOGGING_LEVEL in which case a default
	// logger will be used. Optional.
	Logger *slog.Logger
}
 80
 81func (o *Options3LO) validate() error {
 82	if o == nil {
 83		return errors.New("auth: options must be provided")
 84	}
 85	if o.ClientID == "" {
 86		return errors.New("auth: client ID must be provided")
 87	}
 88	if o.AuthHandlerOpts == nil && o.ClientSecret == "" {
 89		return errors.New("auth: client secret must be provided")
 90	}
 91	if o.AuthURL == "" {
 92		return errors.New("auth: auth URL must be provided")
 93	}
 94	if o.TokenURL == "" {
 95		return errors.New("auth: token URL must be provided")
 96	}
 97	if o.AuthStyle == StyleUnknown {
 98		return errors.New("auth: auth style must be provided")
 99	}
100	if o.AuthHandlerOpts == nil && o.RefreshToken == "" {
101		return errors.New("auth: refresh token must be provided")
102	}
103	return nil
104}
105
106func (o *Options3LO) logger() *slog.Logger {
107	return internallog.New(o.Logger)
108}
109
// PKCEOptions holds parameters to support PKCE (Proof Key for Code Exchange,
// RFC 7636).
type PKCEOptions struct {
	// Challenge is the un-padded, base64-url-encoded string of the transformed
	// code verifier. When non-empty it is added to the auth code URL.
	Challenge string
	// ChallengeMethod is the transformation applied to the verifier to produce
	// the challenge (ex. S256). When non-empty it is added to the auth code URL.
	ChallengeMethod string
	// Verifier is the original, untransformed secret. When non-empty it is sent
	// with the auth code exchange.
	Verifier string
}
119
// tokenJSON mirrors a JSON token endpoint response. Successful responses fill
// the token fields; error responses fill the error fields.
type tokenJSON struct {
	AccessToken  string `json:"access_token"`
	TokenType    string `json:"token_type"`
	RefreshToken string `json:"refresh_token"`
	ExpiresIn    int    `json:"expires_in"`
	// Error fields.
	ErrorCode        string `json:"error"`
	ErrorDescription string `json:"error_description"`
	ErrorURI         string `json:"error_uri"`
}

// expiry converts the relative ExpiresIn lifetime (in seconds) into an
// absolute time measured from now. A zero ExpiresIn yields the zero
// time.Time.
func (e *tokenJSON) expiry() time.Time {
	if e.ExpiresIn == 0 {
		return time.Time{}
	}
	lifetime := time.Duration(e.ExpiresIn) * time.Second
	return time.Now().Add(lifetime)
}
137
138func (o *Options3LO) client() *http.Client {
139	if o.Client != nil {
140		return o.Client
141	}
142	return internal.DefaultClient()
143}
144
145// authCodeURL returns a URL that points to a OAuth2 consent page.
146func (o *Options3LO) authCodeURL(state string, values url.Values) string {
147	var buf bytes.Buffer
148	buf.WriteString(o.AuthURL)
149	v := url.Values{
150		"response_type": {"code"},
151		"client_id":     {o.ClientID},
152	}
153	if o.RedirectURL != "" {
154		v.Set("redirect_uri", o.RedirectURL)
155	}
156	if len(o.Scopes) > 0 {
157		v.Set("scope", strings.Join(o.Scopes, " "))
158	}
159	if state != "" {
160		v.Set("state", state)
161	}
162	if o.AuthHandlerOpts != nil {
163		if o.AuthHandlerOpts.PKCEOpts != nil &&
164			o.AuthHandlerOpts.PKCEOpts.Challenge != "" {
165			v.Set(codeChallengeKey, o.AuthHandlerOpts.PKCEOpts.Challenge)
166		}
167		if o.AuthHandlerOpts.PKCEOpts != nil &&
168			o.AuthHandlerOpts.PKCEOpts.ChallengeMethod != "" {
169			v.Set(codeChallengeMethodKey, o.AuthHandlerOpts.PKCEOpts.ChallengeMethod)
170		}
171	}
172	for k := range values {
173		v.Set(k, v.Get(k))
174	}
175	if strings.Contains(o.AuthURL, "?") {
176		buf.WriteByte('&')
177	} else {
178		buf.WriteByte('?')
179	}
180	buf.WriteString(v.Encode())
181	return buf.String()
182}
183
184// New3LOTokenProvider returns a [TokenProvider] based on the 3-legged OAuth2
185// configuration. The TokenProvider is caches and auto-refreshes tokens by
186// default.
187func New3LOTokenProvider(opts *Options3LO) (TokenProvider, error) {
188	if err := opts.validate(); err != nil {
189		return nil, err
190	}
191	if opts.AuthHandlerOpts != nil {
192		return new3LOTokenProviderWithAuthHandler(opts), nil
193	}
194	return NewCachedTokenProvider(&tokenProvider3LO{opts: opts, refreshToken: opts.RefreshToken, client: opts.client()}, &CachedTokenProviderOptions{
195		ExpireEarly: opts.EarlyTokenExpiry,
196	}), nil
197}
198
// AuthorizationHandlerOptions provides a set of options to specify for doing a
// 3-legged OAuth2 flow with a custom [AuthorizationHandler].
type AuthorizationHandlerOptions struct {
	// Handler is the AuthorizationHandler invoked for the authorization part
	// of the flow. It receives the auth code URL and returns the auth code and
	// state.
	Handler AuthorizationHandler
	// State, when non-empty, is sent as the "state" parameter of the auth code
	// URL. The state returned by Handler must equal it before the auth code is
	// exchanged for an OAuth2 token.
	State string
	// PKCEOpts allows setting configurations for PKCE. Optional.
	PKCEOpts *PKCEOptions
}
211
212func new3LOTokenProviderWithAuthHandler(opts *Options3LO) TokenProvider {
213	return NewCachedTokenProvider(&tokenProviderWithHandler{opts: opts, state: opts.AuthHandlerOpts.State}, &CachedTokenProviderOptions{
214		ExpireEarly: opts.EarlyTokenExpiry,
215	})
216}
217
218// exchange handles the final exchange portion of the 3lo flow. Returns a Token,
219// refreshToken, and error.
220func (o *Options3LO) exchange(ctx context.Context, code string) (*Token, string, error) {
221	// Build request
222	v := url.Values{
223		"grant_type": {"authorization_code"},
224		"code":       {code},
225	}
226	if o.RedirectURL != "" {
227		v.Set("redirect_uri", o.RedirectURL)
228	}
229	if o.AuthHandlerOpts != nil &&
230		o.AuthHandlerOpts.PKCEOpts != nil &&
231		o.AuthHandlerOpts.PKCEOpts.Verifier != "" {
232		v.Set(codeVerifierKey, o.AuthHandlerOpts.PKCEOpts.Verifier)
233	}
234	for k := range o.URLParams {
235		v.Set(k, o.URLParams.Get(k))
236	}
237	return fetchToken(ctx, o, v)
238}
239
240// This struct is not safe for concurrent access alone, but the way it is used
241// in this package by wrapping it with a cachedTokenProvider makes it so.
242type tokenProvider3LO struct {
243	opts         *Options3LO
244	client       *http.Client
245	refreshToken string
246}
247
248func (tp *tokenProvider3LO) Token(ctx context.Context) (*Token, error) {
249	if tp.refreshToken == "" {
250		return nil, errors.New("auth: token expired and refresh token is not set")
251	}
252	v := url.Values{
253		"grant_type":    {"refresh_token"},
254		"refresh_token": {tp.refreshToken},
255	}
256	for k := range tp.opts.URLParams {
257		v.Set(k, tp.opts.URLParams.Get(k))
258	}
259
260	tk, rt, err := fetchToken(ctx, tp.opts, v)
261	if err != nil {
262		return nil, err
263	}
264	if tp.refreshToken != rt && rt != "" {
265		tp.refreshToken = rt
266	}
267	return tk, err
268}
269
270type tokenProviderWithHandler struct {
271	opts  *Options3LO
272	state string
273}
274
275func (tp tokenProviderWithHandler) Token(ctx context.Context) (*Token, error) {
276	url := tp.opts.authCodeURL(tp.state, nil)
277	code, state, err := tp.opts.AuthHandlerOpts.Handler(url)
278	if err != nil {
279		return nil, err
280	}
281	if state != tp.state {
282		return nil, errors.New("auth: state mismatch in 3-legged-OAuth flow")
283	}
284	tok, _, err := tp.opts.exchange(ctx, code)
285	return tok, err
286}
287
288// fetchToken returns a Token, refresh token, and/or an error.
289func fetchToken(ctx context.Context, o *Options3LO, v url.Values) (*Token, string, error) {
290	var refreshToken string
291	if o.AuthStyle == StyleInParams {
292		if o.ClientID != "" {
293			v.Set("client_id", o.ClientID)
294		}
295		if o.ClientSecret != "" {
296			v.Set("client_secret", o.ClientSecret)
297		}
298	}
299	req, err := http.NewRequestWithContext(ctx, "POST", o.TokenURL, strings.NewReader(v.Encode()))
300	if err != nil {
301		return nil, refreshToken, err
302	}
303	req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
304	if o.AuthStyle == StyleInHeader {
305		req.SetBasicAuth(url.QueryEscape(o.ClientID), url.QueryEscape(o.ClientSecret))
306	}
307	logger := o.logger()
308
309	logger.DebugContext(ctx, "3LO token request", "request", internallog.HTTPRequest(req, []byte(v.Encode())))
310	// Make request
311	resp, body, err := internal.DoRequest(o.client(), req)
312	if err != nil {
313		return nil, refreshToken, err
314	}
315	logger.DebugContext(ctx, "3LO token response", "response", internallog.HTTPResponse(resp, body))
316	failureStatus := resp.StatusCode < 200 || resp.StatusCode > 299
317	tokError := &Error{
318		Response: resp,
319		Body:     body,
320	}
321
322	var token *Token
323	// errors ignored because of default switch on content
324	content, _, _ := mime.ParseMediaType(resp.Header.Get("Content-Type"))
325	switch content {
326	case "application/x-www-form-urlencoded", "text/plain":
327		// some endpoints return a query string
328		vals, err := url.ParseQuery(string(body))
329		if err != nil {
330			if failureStatus {
331				return nil, refreshToken, tokError
332			}
333			return nil, refreshToken, fmt.Errorf("auth: cannot parse response: %w", err)
334		}
335		tokError.code = vals.Get("error")
336		tokError.description = vals.Get("error_description")
337		tokError.uri = vals.Get("error_uri")
338		token = &Token{
339			Value:    vals.Get("access_token"),
340			Type:     vals.Get("token_type"),
341			Metadata: make(map[string]interface{}, len(vals)),
342		}
343		for k, v := range vals {
344			token.Metadata[k] = v
345		}
346		refreshToken = vals.Get("refresh_token")
347		e := vals.Get("expires_in")
348		expires, _ := strconv.Atoi(e)
349		if expires != 0 {
350			token.Expiry = time.Now().Add(time.Duration(expires) * time.Second)
351		}
352	default:
353		var tj tokenJSON
354		if err = json.Unmarshal(body, &tj); err != nil {
355			if failureStatus {
356				return nil, refreshToken, tokError
357			}
358			return nil, refreshToken, fmt.Errorf("auth: cannot parse json: %w", err)
359		}
360		tokError.code = tj.ErrorCode
361		tokError.description = tj.ErrorDescription
362		tokError.uri = tj.ErrorURI
363		token = &Token{
364			Value:    tj.AccessToken,
365			Type:     tj.TokenType,
366			Expiry:   tj.expiry(),
367			Metadata: make(map[string]interface{}),
368		}
369		json.Unmarshal(body, &token.Metadata) // optional field, skip err check
370		refreshToken = tj.RefreshToken
371	}
372	// according to spec, servers should respond status 400 in error case
373	// https://www.rfc-editor.org/rfc/rfc6749#section-5.2
374	// but some unorthodox servers respond 200 in error case
375	if failureStatus || tokError.code != "" {
376		return nil, refreshToken, tokError
377	}
378	if token.Value == "" {
379		return nil, refreshToken, errors.New("auth: server response missing access_token")
380	}
381	return token, refreshToken, nil
382}