transport.go

  1// Copyright 2023 Google LLC
  2//
  3// Licensed under the Apache License, Version 2.0 (the "License");
  4// you may not use this file except in compliance with the License.
  5// You may obtain a copy of the License at
  6//
  7//      http://www.apache.org/licenses/LICENSE-2.0
  8//
  9// Unless required by applicable law or agreed to in writing, software
 10// distributed under the License is distributed on an "AS IS" BASIS,
 11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 12// See the License for the specific language governing permissions and
 13// limitations under the License.
 14
 15package httptransport
 16
 17import (
 18	"context"
 19	"crypto/tls"
 20	"net"
 21	"net/http"
 22	"os"
 23	"time"
 24
 25	"cloud.google.com/go/auth"
 26	"cloud.google.com/go/auth/credentials"
 27	"cloud.google.com/go/auth/internal"
 28	"cloud.google.com/go/auth/internal/transport"
 29	"cloud.google.com/go/auth/internal/transport/cert"
 30	"go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp"
 31	"golang.org/x/net/http2"
 32)
 33
 34const (
 35	quotaProjectHeaderKey = "X-goog-user-project"
 36)
 37
 38func newTransport(base http.RoundTripper, opts *Options) (http.RoundTripper, error) {
 39	var headers = opts.Headers
 40	ht := &headerTransport{
 41		base:    base,
 42		headers: headers,
 43	}
 44	var trans http.RoundTripper = ht
 45	trans = addOpenTelemetryTransport(trans, opts)
 46	switch {
 47	case opts.DisableAuthentication:
 48		// Do nothing.
 49	case opts.APIKey != "":
 50		qp := internal.GetQuotaProject(nil, opts.Headers.Get(quotaProjectHeaderKey))
 51		if qp != "" {
 52			if headers == nil {
 53				headers = make(map[string][]string, 1)
 54			}
 55			headers.Set(quotaProjectHeaderKey, qp)
 56		}
 57		trans = &apiKeyTransport{
 58			Transport: trans,
 59			Key:       opts.APIKey,
 60		}
 61	default:
 62		var creds *auth.Credentials
 63		if opts.Credentials != nil {
 64			creds = opts.Credentials
 65		} else {
 66			var err error
 67			creds, err = credentials.DetectDefault(opts.resolveDetectOptions())
 68			if err != nil {
 69				return nil, err
 70			}
 71		}
 72		qp, err := creds.QuotaProjectID(context.Background())
 73		if err != nil {
 74			return nil, err
 75		}
 76		if qp != "" {
 77			if headers == nil {
 78				headers = make(map[string][]string, 1)
 79			}
 80			// Don't overwrite user specified quota
 81			if v := headers.Get(quotaProjectHeaderKey); v == "" {
 82				headers.Set(quotaProjectHeaderKey, qp)
 83			}
 84		}
 85		var skipUD bool
 86		if iOpts := opts.InternalOptions; iOpts != nil {
 87			skipUD = iOpts.SkipUniverseDomainValidation
 88		}
 89		creds.TokenProvider = auth.NewCachedTokenProvider(creds.TokenProvider, nil)
 90		trans = &authTransport{
 91			base:                         trans,
 92			creds:                        creds,
 93			clientUniverseDomain:         opts.UniverseDomain,
 94			skipUniverseDomainValidation: skipUD,
 95		}
 96	}
 97	return trans, nil
 98}
 99
100// defaultBaseTransport returns the base HTTP transport.
101// On App Engine, this is urlfetch.Transport.
102// Otherwise, use a default transport, taking most defaults from
103// http.DefaultTransport.
104// If TLSCertificate is available, set TLSClientConfig as well.
105func defaultBaseTransport(clientCertSource cert.Provider, dialTLSContext func(context.Context, string, string) (net.Conn, error)) http.RoundTripper {
106	defaultTransport, ok := http.DefaultTransport.(*http.Transport)
107	if !ok {
108		defaultTransport = transport.BaseTransport()
109	}
110	trans := defaultTransport.Clone()
111	trans.MaxIdleConnsPerHost = 100
112
113	if clientCertSource != nil {
114		trans.TLSClientConfig = &tls.Config{
115			GetClientCertificate: clientCertSource,
116		}
117	}
118	if dialTLSContext != nil {
119		// If DialTLSContext is set, TLSClientConfig wil be ignored
120		trans.DialTLSContext = dialTLSContext
121	}
122
123	// Configures the ReadIdleTimeout HTTP/2 option for the
124	// transport. This allows broken idle connections to be pruned more quickly,
125	// preventing the client from attempting to re-use connections that will no
126	// longer work.
127	http2Trans, err := http2.ConfigureTransports(trans)
128	if err == nil {
129		http2Trans.ReadIdleTimeout = time.Second * 31
130	}
131
132	return trans
133}
134
// apiKeyTransport is an http.RoundTripper that authenticates requests by
// setting the "key" query parameter to an API key.
type apiKeyTransport struct {
	// Key is the API Key to set on requests.
	Key string
	// Transport is the underlying HTTP transport.
	// If nil, http.DefaultTransport is used.
	Transport http.RoundTripper
}

// RoundTrip sends a copy of req, whose "key" query parameter is set to t.Key
// (replacing any existing values), to the underlying transport. Per the
// http.RoundTripper contract, neither req nor req.URL is modified.
func (t *apiKeyTransport) RoundTrip(req *http.Request) (*http.Response, error) {
	rt := t.Transport
	if rt == nil {
		rt = http.DefaultTransport
	}

	// Copy the URL too: a shallow copy of the request shares *url.URL with
	// req, so setting RawQuery on it would mutate the caller's request.
	u := *req.URL
	q := u.Query()
	q.Set("key", t.Key)
	u.RawQuery = q.Encode()

	newReq := *req
	newReq.URL = &u
	return rt.RoundTrip(&newReq)
}
150
// headerTransport is an http.RoundTripper that adds a fixed set of headers to
// every request before passing it on to base.
type headerTransport struct {
	// headers are set on each outgoing request; an entry here replaces any
	// values the request already has under the same map key.
	headers http.Header
	// base is the next RoundTripper in the chain.
	base http.RoundTripper
}

// RoundTrip forwards a shallow copy of req whose Header is a new map holding
// req's headers overlaid with t.headers. req itself is not modified.
func (t *headerTransport) RoundTrip(req *http.Request) (*http.Response, error) {
	newReq := *req
	newReq.Header = make(http.Header, len(req.Header)+len(t.headers))
	// Cap each value slice at its length so that a downstream Header.Add has
	// to reallocate rather than append into a backing array shared with req or
	// with t.headers, which is reused by every (possibly concurrent) request.
	for k, vv := range req.Header {
		newReq.Header[k] = vv[:len(vv):len(vv)]
	}
	for k, vv := range t.headers {
		newReq.Header[k] = vv[:len(vv):len(vv)]
	}
	return t.base.RoundTrip(&newReq)
}
170
171func addOpenTelemetryTransport(trans http.RoundTripper, opts *Options) http.RoundTripper {
172	if opts.DisableTelemetry {
173		return trans
174	}
175	return otelhttp.NewTransport(trans)
176}
177
178type authTransport struct {
179	creds                        *auth.Credentials
180	base                         http.RoundTripper
181	clientUniverseDomain         string
182	skipUniverseDomainValidation bool
183}
184
185// getClientUniverseDomain returns the default service domain for a given Cloud
186// universe, with the following precedence:
187//
188// 1. A non-empty option.WithUniverseDomain or similar client option.
189// 2. A non-empty environment variable GOOGLE_CLOUD_UNIVERSE_DOMAIN.
190// 3. The default value "googleapis.com".
191//
192// This is the universe domain configured for the client, which will be compared
193// to the universe domain that is separately configured for the credentials.
194func (t *authTransport) getClientUniverseDomain() string {
195	if t.clientUniverseDomain != "" {
196		return t.clientUniverseDomain
197	}
198	if envUD := os.Getenv(internal.UniverseDomainEnvVar); envUD != "" {
199		return envUD
200	}
201	return internal.DefaultUniverseDomain
202}
203
204// RoundTrip authorizes and authenticates the request with an
205// access token from Transport's Source. Per the RoundTripper contract we must
206// not modify the initial request, so we clone it, and we must close the body
207// on any errors that happens during our token logic.
208func (t *authTransport) RoundTrip(req *http.Request) (*http.Response, error) {
209	reqBodyClosed := false
210	if req.Body != nil {
211		defer func() {
212			if !reqBodyClosed {
213				req.Body.Close()
214			}
215		}()
216	}
217	token, err := t.creds.Token(req.Context())
218	if err != nil {
219		return nil, err
220	}
221	if !t.skipUniverseDomainValidation && token.MetadataString("auth.google.tokenSource") != "compute-metadata" {
222		credentialsUniverseDomain, err := t.creds.UniverseDomain(req.Context())
223		if err != nil {
224			return nil, err
225		}
226		if err := transport.ValidateUniverseDomain(t.getClientUniverseDomain(), credentialsUniverseDomain); err != nil {
227			return nil, err
228		}
229	}
230	req2 := req.Clone(req.Context())
231	SetAuthHeader(token, req2)
232	reqBodyClosed = true
233	return t.base.RoundTrip(req2)
234}