gdch.go

  1// Copyright 2023 Google LLC
  2//
  3// Licensed under the Apache License, Version 2.0 (the "License");
  4// you may not use this file except in compliance with the License.
  5// You may obtain a copy of the License at
  6//
  7//      http://www.apache.org/licenses/LICENSE-2.0
  8//
  9// Unless required by applicable law or agreed to in writing, software
 10// distributed under the License is distributed on an "AS IS" BASIS,
 11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 12// See the License for the specific language governing permissions and
 13// limitations under the License.
 14
 15package gdch
 16
 17import (
 18	"context"
 19	"crypto"
 20	"crypto/tls"
 21	"crypto/x509"
 22	"encoding/json"
 23	"errors"
 24	"fmt"
 25	"log/slog"
 26	"net/http"
 27	"net/url"
 28	"os"
 29	"strings"
 30	"time"
 31
 32	"cloud.google.com/go/auth"
 33	"cloud.google.com/go/auth/internal"
 34	"cloud.google.com/go/auth/internal/credsfile"
 35	"cloud.google.com/go/auth/internal/jwt"
 36	"github.com/googleapis/gax-go/v2/internallog"
 37)
 38
 39const (
 40	// GrantType is the grant type for the token request.
 41	GrantType        = "urn:ietf:params:oauth:token-type:token-exchange"
 42	requestTokenType = "urn:ietf:params:oauth:token-type:access_token"
 43	subjectTokenType = "urn:k8s:params:oauth:token-type:serviceaccount"
 44)
 45
 46var (
 47	gdchSupportFormatVersions map[string]bool = map[string]bool{
 48		"1": true,
 49	}
 50)
 51
// Options for [NewTokenProvider].
//
// STSAudience is required: NewTokenProvider returns an error when it is empty.
// It is sent as the "audience" form field of the token-exchange request.
//
// Client is the HTTP client used to POST to the token endpoint. Before each
// token request, the provider makes a best-effort attempt to add the CA
// certificates read from the credential file's CertPath to this client's
// transport (see addCertToTransport), so the client's transport may be
// modified.
//
// Logger receives debug-level logs of the token request and response.
type Options struct {
	STSAudience string
	Client      *http.Client
	Logger      *slog.Logger
}
 58
 59// NewTokenProvider returns a [cloud.google.com/go/auth.TokenProvider] from a
 60// GDCH cred file.
 61func NewTokenProvider(f *credsfile.GDCHServiceAccountFile, o *Options) (auth.TokenProvider, error) {
 62	if !gdchSupportFormatVersions[f.FormatVersion] {
 63		return nil, fmt.Errorf("credentials: unsupported gdch_service_account format %q", f.FormatVersion)
 64	}
 65	if o.STSAudience == "" {
 66		return nil, errors.New("credentials: STSAudience must be set for the GDCH auth flows")
 67	}
 68	signer, err := internal.ParseKey([]byte(f.PrivateKey))
 69	if err != nil {
 70		return nil, err
 71	}
 72	certPool, err := loadCertPool(f.CertPath)
 73	if err != nil {
 74		return nil, err
 75	}
 76
 77	tp := gdchProvider{
 78		serviceIdentity: fmt.Sprintf("system:serviceaccount:%s:%s", f.Project, f.Name),
 79		tokenURL:        f.TokenURL,
 80		aud:             o.STSAudience,
 81		signer:          signer,
 82		pkID:            f.PrivateKeyID,
 83		certPool:        certPool,
 84		client:          o.Client,
 85		logger:          internallog.New(o.Logger),
 86	}
 87	return tp, nil
 88}
 89
 90func loadCertPool(path string) (*x509.CertPool, error) {
 91	pool := x509.NewCertPool()
 92	pem, err := os.ReadFile(path)
 93	if err != nil {
 94		return nil, fmt.Errorf("credentials: failed to read certificate: %w", err)
 95	}
 96	pool.AppendCertsFromPEM(pem)
 97	return pool, nil
 98}
 99
// gdchProvider is the auth.TokenProvider built by NewTokenProvider. It holds
// what Token needs to sign a service-account assertion and exchange it at the
// GDCH token endpoint.
type gdchProvider struct {
	// serviceIdentity is used as both the iss and sub claim of the assertion;
	// tokenURL is the token endpoint and also the assertion's aud claim;
	// aud is sent as the "audience" form field of the exchange request.
	serviceIdentity, tokenURL, aud string

	// signer signs the assertion; pkID is placed in its header as the key ID.
	signer crypto.Signer
	pkID   string

	// certPool holds the CA certificates Token attempts to add to the
	// client's transport before contacting tokenURL.
	certPool *x509.CertPool

	// client performs the token exchange; logger receives debug logs of it.
	client *http.Client
	logger *slog.Logger
}
111
112func (g gdchProvider) Token(ctx context.Context) (*auth.Token, error) {
113	addCertToTransport(g.client, g.certPool)
114	iat := time.Now()
115	exp := iat.Add(time.Hour)
116	claims := jwt.Claims{
117		Iss: g.serviceIdentity,
118		Sub: g.serviceIdentity,
119		Aud: g.tokenURL,
120		Iat: iat.Unix(),
121		Exp: exp.Unix(),
122	}
123	h := jwt.Header{
124		Algorithm: jwt.HeaderAlgRSA256,
125		Type:      jwt.HeaderType,
126		KeyID:     string(g.pkID),
127	}
128	payload, err := jwt.EncodeJWS(&h, &claims, g.signer)
129	if err != nil {
130		return nil, err
131	}
132	v := url.Values{}
133	v.Set("grant_type", GrantType)
134	v.Set("audience", g.aud)
135	v.Set("requested_token_type", requestTokenType)
136	v.Set("subject_token", payload)
137	v.Set("subject_token_type", subjectTokenType)
138
139	req, err := http.NewRequestWithContext(ctx, "POST", g.tokenURL, strings.NewReader(v.Encode()))
140	if err != nil {
141		return nil, err
142	}
143	req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
144	g.logger.DebugContext(ctx, "gdch token request", "request", internallog.HTTPRequest(req, []byte(v.Encode())))
145	resp, body, err := internal.DoRequest(g.client, req)
146	if err != nil {
147		return nil, fmt.Errorf("credentials: cannot fetch token: %w", err)
148	}
149	g.logger.DebugContext(ctx, "gdch token response", "response", internallog.HTTPResponse(resp, body))
150	if c := resp.StatusCode; c < http.StatusOK || c > http.StatusMultipleChoices {
151		return nil, &auth.Error{
152			Response: resp,
153			Body:     body,
154		}
155	}
156
157	var tokenRes struct {
158		AccessToken string `json:"access_token"`
159		TokenType   string `json:"token_type"`
160		ExpiresIn   int64  `json:"expires_in"` // relative seconds from now
161	}
162	if err := json.Unmarshal(body, &tokenRes); err != nil {
163		return nil, fmt.Errorf("credentials: cannot fetch token: %w", err)
164	}
165	token := &auth.Token{
166		Value: tokenRes.AccessToken,
167		Type:  tokenRes.TokenType,
168	}
169	raw := make(map[string]interface{})
170	json.Unmarshal(body, &raw) // no error checks for optional fields
171	token.Metadata = raw
172
173	if secs := tokenRes.ExpiresIn; secs > 0 {
174		token.Expiry = time.Now().Add(time.Duration(secs) * time.Second)
175	}
176	return token, nil
177}
178
// addCertToTransport makes a best effort attempt at adding in the cert info to
// the client. It tries to keep all configured transport settings if the
// underlying transport is an http.Transport. Or else it overwrites the
// transport with defaults adding in the certs.
//
// If hc.Transport is an *http.Transport (other than http.DefaultTransport), it
// is updated in place. Its existing TLS settings are kept and only RootCAs is
// set to certPool.
//
// Otherwise, hc.Transport is replaced with a clone of http.DefaultTransport
// that trusts certPool. This covers a nil Transport, a non-*http.Transport
// RoundTripper, and http.DefaultTransport itself. The shared default transport
// is never modified.
//
// A nil hc is ignored.
func addCertToTransport(hc *http.Client, certPool *x509.CertPool) {
	if hc == nil {
		return
	}
	trans, ok := hc.Transport.(*http.Transport)
	if !ok || hc.Transport == http.DefaultTransport {
		trans = http.DefaultTransport.(*http.Transport).Clone()
		// Install the clone; otherwise the certs would be silently discarded.
		hc.Transport = trans
	}
	if cur := trans.TLSClientConfig; cur != nil && cur.RootCAs == certPool {
		// Already configured (e.g. by a previous Token call); avoid rewriting
		// a transport that may be serving concurrent requests.
		return
	}
	// Clone rather than mutate so a tls.Config shared with other users is
	// left untouched, while settings such as ServerName or client
	// certificates are preserved.
	var cfg *tls.Config
	if trans.TLSClientConfig != nil {
		cfg = trans.TLSClientConfig.Clone()
	} else {
		cfg = &tls.Config{}
	}
	cfg.RootCAs = certPool
	trans.TLSClientConfig = cfg
}