1// Copyright 2023 Google LLC
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7// http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15package httptransport
16
17import (
18 "context"
19 "crypto/tls"
20 "net"
21 "net/http"
22 "os"
23 "time"
24
25 "cloud.google.com/go/auth"
26 "cloud.google.com/go/auth/credentials"
27 "cloud.google.com/go/auth/internal"
28 "cloud.google.com/go/auth/internal/transport"
29 "cloud.google.com/go/auth/internal/transport/cert"
30 "go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp"
31 "golang.org/x/net/http2"
32)
33
// quotaProjectHeaderKey is the name of the HTTP header that carries the quota
// project for a request.
const quotaProjectHeaderKey = "X-goog-user-project"
37
// newTransport builds the RoundTripper chain for an HTTP client. From the
// outermost layer inward, a request passes through:
//
//   - an auth layer chosen from opts: none when DisableAuthentication is set,
//     an apiKeyTransport when APIKey is set, and otherwise an authTransport
//     that attaches tokens from opts.Credentials (or, when that is nil, from
//     credentials.DetectDefault);
//   - OpenTelemetry instrumentation, unless opts.DisableTelemetry is set;
//   - a headerTransport that sets opts.Headers, plus the quota project header
//     when a quota project is resolved, on every request;
//   - base.
//
// opts.Headers is copied, never modified. An error is returned if default
// credentials cannot be detected or the credentials' quota project cannot be
// resolved.
func newTransport(base http.RoundTripper, opts *Options) (http.RoundTripper, error) {
	// Work on a copy so headers added below do not leak into the caller's
	// Options. Cloning a nil Header yields nil, which headerTransport accepts.
	headers := opts.Headers.Clone()
	ht := &headerTransport{base: base}
	trans := addOpenTelemetryTransport(ht, opts)
	switch {
	case opts.DisableAuthentication:
		// Do nothing.
	case opts.APIKey != "":
		qp := internal.GetQuotaProject(nil, headers.Get(quotaProjectHeaderKey))
		if qp != "" {
			if headers == nil {
				headers = make(http.Header, 1)
			}
			headers.Set(quotaProjectHeaderKey, qp)
		}
		trans = &apiKeyTransport{
			Transport: trans,
			Key:       opts.APIKey,
		}
	default:
		creds := opts.Credentials
		if creds == nil {
			var err error
			creds, err = credentials.DetectDefault(opts.resolveDetectOptions())
			if err != nil {
				return nil, err
			}
		}
		qp, err := creds.QuotaProjectID(context.Background())
		if err != nil {
			return nil, err
		}
		// Don't overwrite a quota project the user specified via headers.
		if qp != "" && headers.Get(quotaProjectHeaderKey) == "" {
			if headers == nil {
				headers = make(http.Header, 1)
			}
			headers.Set(quotaProjectHeaderKey, qp)
		}
		var skipUD bool
		if iOpts := opts.InternalOptions; iOpts != nil {
			skipUD = iOpts.SkipUniverseDomainValidation
		}
		// Note: when the caller supplied opts.Credentials, this replaces its
		// TokenProvider in place.
		creds.TokenProvider = auth.NewCachedTokenProvider(creds.TokenProvider, nil)
		trans = &authTransport{
			base:                         trans,
			creds:                        creds,
			clientUniverseDomain:         opts.UniverseDomain,
			skipUniverseDomainValidation: skipUD,
		}
	}
	// Attach the headers only now. When opts.Headers is nil, the map holding
	// the quota project header is allocated inside the switch. Assigning
	// ht.headers before the switch would silently drop that header.
	ht.headers = headers
	return trans, nil
}
99
// defaultBaseTransport returns the base HTTP transport that the auth layers
// wrap. It is a clone of http.DefaultTransport, or of transport.BaseTransport()
// when http.DefaultTransport has been replaced by something that is not an
// *http.Transport. MaxIdleConnsPerHost on the clone is set to 100.
//
// If clientCertSource is non-nil, it is installed as the TLS client
// certificate callback. Any TLS settings inherited from the cloned transport
// are kept. If dialTLSContext is non-nil, it is used for TLS dials; net/http
// then ignores TLSClientConfig. Finally, HTTP/2 is configured on a best-effort
// basis with a ReadIdleTimeout.
func defaultBaseTransport(clientCertSource cert.Provider, dialTLSContext func(context.Context, string, string) (net.Conn, error)) http.RoundTripper {
	defaultTransport, ok := http.DefaultTransport.(*http.Transport)
	if !ok {
		defaultTransport = transport.BaseTransport()
	}
	trans := defaultTransport.Clone()
	trans.MaxIdleConnsPerHost = 100

	if clientCertSource != nil {
		// Clone deep-copies TLSClientConfig, so modifying it here cannot affect
		// http.DefaultTransport. Only add the certificate callback rather than
		// discarding the inherited settings.
		if trans.TLSClientConfig == nil {
			trans.TLSClientConfig = &tls.Config{}
		}
		trans.TLSClientConfig.GetClientCertificate = clientCertSource
	}
	if dialTLSContext != nil {
		// If DialTLSContext is set, TLSClientConfig will be ignored.
		trans.DialTLSContext = dialTLSContext
	}

	// Best effort: if HTTP/2 cannot be configured the transport still works,
	// so the error is deliberately ignored. ReadIdleTimeout makes the HTTP/2
	// client health-check connections that have received no frames for this
	// long. Broken idle connections are then pruned instead of being reused.
	if http2Trans, err := http2.ConfigureTransports(trans); err == nil {
		http2Trans.ReadIdleTimeout = 31 * time.Second
	}

	return trans
}
134
135type apiKeyTransport struct {
136 // Key is the API Key to set on requests.
137 Key string
138 // Transport is the underlying HTTP transport.
139 // If nil, http.DefaultTransport is used.
140 Transport http.RoundTripper
141}
142
143func (t *apiKeyTransport) RoundTrip(req *http.Request) (*http.Response, error) {
144 newReq := *req
145 args := newReq.URL.Query()
146 args.Set("key", t.Key)
147 newReq.URL.RawQuery = args.Encode()
148 return t.Transport.RoundTrip(&newReq)
149}
150
151type headerTransport struct {
152 headers http.Header
153 base http.RoundTripper
154}
155
156func (t *headerTransport) RoundTrip(req *http.Request) (*http.Response, error) {
157 rt := t.base
158 newReq := *req
159 newReq.Header = make(http.Header)
160 for k, vv := range req.Header {
161 newReq.Header[k] = vv
162 }
163
164 for k, v := range t.headers {
165 newReq.Header[k] = v
166 }
167
168 return rt.RoundTrip(&newReq)
169}
170
// addOpenTelemetryTransport wraps trans with otelhttp instrumentation. When
// opts.DisableTelemetry is set, trans is returned unchanged.
func addOpenTelemetryTransport(trans http.RoundTripper, opts *Options) http.RoundTripper {
	if !opts.DisableTelemetry {
		trans = otelhttp.NewTransport(trans)
	}
	return trans
}
177
178type authTransport struct {
179 creds *auth.Credentials
180 base http.RoundTripper
181 clientUniverseDomain string
182 skipUniverseDomainValidation bool
183}
184
// getClientUniverseDomain reports the universe domain configured for the
// client. This is the value compared against the universe domain configured
// separately for the credentials. The first non-empty source wins:
//
//  1. the client's UniverseDomain option (t.clientUniverseDomain);
//  2. the GOOGLE_CLOUD_UNIVERSE_DOMAIN environment variable;
//  3. the default, "googleapis.com".
func (t *authTransport) getClientUniverseDomain() string {
	ud := t.clientUniverseDomain
	if ud == "" {
		ud = os.Getenv(internal.UniverseDomainEnvVar)
	}
	if ud == "" {
		ud = internal.DefaultUniverseDomain
	}
	return ud
}
203
// RoundTrip implements http.RoundTripper by sending an authorized copy of req
// through t.base. It obtains a token from t.creds. It then compares the
// client's universe domain with the credentials' universe domain, except when
// skipUniverseDomainValidation is set or the token's "auth.google.tokenSource"
// metadata is "compute-metadata". The token is applied to a clone, so req
// itself is never modified. If RoundTrip returns before handing the request to
// t.base, it closes a non-nil req.Body, as the RoundTripper contract requires.
func (t *authTransport) RoundTrip(req *http.Request) (*http.Response, error) {
	// Responsibility for req.Body passes to t.base once it is called. Until
	// then, every early return must close the body.
	handedOff := false
	if body := req.Body; body != nil {
		defer func() {
			if !handedOff {
				body.Close()
			}
		}()
	}

	ctx := req.Context()
	tok, err := t.creds.Token(ctx)
	if err != nil {
		return nil, err
	}

	checkUD := !t.skipUniverseDomainValidation &&
		tok.MetadataString("auth.google.tokenSource") != "compute-metadata"
	if checkUD {
		credsUD, err := t.creds.UniverseDomain(ctx)
		if err != nil {
			return nil, err
		}
		if err := transport.ValidateUniverseDomain(t.getClientUniverseDomain(), credsUD); err != nil {
			return nil, err
		}
	}

	authReq := req.Clone(ctx)
	SetAuthHeader(tok, authReq)
	handedOff = true
	return t.base.RoundTrip(authReq)
}