1// Copyright 2023 Google LLC
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7// http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15package transport
16
17import (
18 "context"
19 "crypto/tls"
20 "crypto/x509"
21 "errors"
22 "log"
23 "log/slog"
24 "net"
25 "net/http"
26 "net/url"
27 "os"
28 "strconv"
29 "strings"
30
31 "cloud.google.com/go/auth/internal"
32 "cloud.google.com/go/auth/internal/transport/cert"
33 "github.com/google/s2a-go"
34 "github.com/google/s2a-go/fallback"
35 "google.golang.org/grpc/credentials"
36)
37
// Recognized values of the mTLS endpoint mode. getMTLSMode lower-cases the raw
// setting before it is compared against these.
const (
	mTLSModeAlways = "always"
	mTLSModeNever  = "never"
	mTLSModeAuto   = "auto"
)

// Names of environment variables used to configure transport security.
const (
	// Experimental: if true, the code will try MTLS with S2A as the default for
	// transport security. Default value is false.
	googleAPIUseS2AEnv = "EXPERIMENTAL_GOOGLE_API_USE_S2A"
	// googleAPIUseCertSource explicitly enables or disables client
	// certificates (parsed with strconv.ParseBool).
	googleAPIUseCertSource = "GOOGLE_API_USE_CLIENT_CERTIFICATE"
	// googleAPIUseMTLS selects the mTLS endpoint mode.
	googleAPIUseMTLS = "GOOGLE_API_USE_MTLS_ENDPOINT"
	// googleAPIUseMTLSOld is the deprecated predecessor of googleAPIUseMTLS,
	// consulted only when the newer variable is empty.
	googleAPIUseMTLSOld = "GOOGLE_API_USE_MTLS"
)

// universeDomainPlaceholder is the token in endpoint templates that gets
// replaced by the effective universe domain.
const universeDomainPlaceholder = "UNIVERSE_DOMAIN"

// Locations of the mTLS MDS root certificate bundle and of the file holding
// the client certificate chain followed by its private key. Both are used to
// build the credentials for connecting to S2A over mTLS.
const (
	mtlsMDSRoot = "/run/google-mds-mtls/root.crt"
	mtlsMDSKey  = "/run/google-mds-mtls/client.key"
)
54
// Options is a struct that is duplicated information from the individual
// transport packages in order to avoid cyclic deps. It correlates 1:1 with
// fields on httptransport.Options and grpctransport.Options.
type Options struct {
	// Endpoint is the user-provided endpoint override; empty means no
	// override. A value containing "://" is used verbatim. Otherwise it is
	// treated as host[:port] and, when DefaultEndpointTemplate is set, merged
	// into the default endpoint.
	Endpoint string

	// DefaultEndpointTemplate is the default endpoint. The first occurrence of
	// "UNIVERSE_DOMAIN" in it is replaced by the effective universe domain.
	DefaultEndpointTemplate string

	// DefaultMTLSEndpoint is the default mTLS endpoint. The first occurrence of
	// "UNIVERSE_DOMAIN" in it is replaced by the effective universe domain.
	DefaultMTLSEndpoint string

	// ClientCertProvider, if non-nil, is used as the client certificate source
	// whenever client certificates are enabled, in place of
	// cert.DefaultProvider.
	ClientCertProvider cert.Provider

	// Client is not read by the code in this section. Presumably it mirrors
	// Options.Client of the transport packages; TODO confirm.
	Client *http.Client

	// UniverseDomain is the Cloud universe domain. Empty means
	// internal.DefaultUniverseDomain.
	UniverseDomain string

	// EnableDirectPath and EnableDirectPathXds are not read by the code in
	// this section. Presumably they toggle gRPC DirectPath; TODO confirm.
	EnableDirectPath    bool
	EnableDirectPathXds bool

	// Logger is handed to GetS2AAddress and GetMTLSS2AAddress when S2A
	// addresses are resolved.
	Logger *slog.Logger
}
69
70// getUniverseDomain returns the default service domain for a given Cloud
71// universe.
72func (o *Options) getUniverseDomain() string {
73 if o.UniverseDomain == "" {
74 return internal.DefaultUniverseDomain
75 }
76 return o.UniverseDomain
77}
78
79// isUniverseDomainGDU returns true if the universe domain is the default Google
80// universe.
81func (o *Options) isUniverseDomainGDU() bool {
82 return o.getUniverseDomain() == internal.DefaultUniverseDomain
83}
84
85// defaultEndpoint returns the DefaultEndpointTemplate merged with the
86// universe domain if the DefaultEndpointTemplate is set, otherwise returns an
87// empty string.
88func (o *Options) defaultEndpoint() string {
89 if o.DefaultEndpointTemplate == "" {
90 return ""
91 }
92 return strings.Replace(o.DefaultEndpointTemplate, universeDomainPlaceholder, o.getUniverseDomain(), 1)
93}
94
95// defaultMTLSEndpoint returns the DefaultMTLSEndpointTemplate merged with the
96// universe domain if the DefaultMTLSEndpointTemplate is set, otherwise returns an
97// empty string.
98func (o *Options) defaultMTLSEndpoint() string {
99 if o.DefaultMTLSEndpoint == "" {
100 return ""
101 }
102 return strings.Replace(o.DefaultMTLSEndpoint, universeDomainPlaceholder, o.getUniverseDomain(), 1)
103}
104
105// mergedEndpoint merges a user-provided Endpoint of format host[:port] with the
106// default endpoint.
107func (o *Options) mergedEndpoint() (string, error) {
108 defaultEndpoint := o.defaultEndpoint()
109 u, err := url.Parse(fixScheme(defaultEndpoint))
110 if err != nil {
111 return "", err
112 }
113 return strings.Replace(defaultEndpoint, u.Host, o.Endpoint, 1), nil
114}
115
// fixScheme prepends "https://" to baseURL when it contains no "://" scheme
// separator. This lets url.Parse read a leading host[:port] as the URL host.
// Values that already carry a scheme are returned unchanged.
func fixScheme(baseURL string) string {
	if strings.Contains(baseURL, "://") {
		return baseURL
	}
	return "https://" + baseURL
}
122
123// GetGRPCTransportCredsAndEndpoint returns an instance of
124// [google.golang.org/grpc/credentials.TransportCredentials], and the
125// corresponding endpoint to use for GRPC client.
126func GetGRPCTransportCredsAndEndpoint(opts *Options) (credentials.TransportCredentials, string, error) {
127 config, err := getTransportConfig(opts)
128 if err != nil {
129 return nil, "", err
130 }
131
132 defaultTransportCreds := credentials.NewTLS(&tls.Config{
133 GetClientCertificate: config.clientCertSource,
134 })
135
136 var s2aAddr string
137 var transportCredsForS2A credentials.TransportCredentials
138
139 if config.mtlsS2AAddress != "" {
140 s2aAddr = config.mtlsS2AAddress
141 transportCredsForS2A, err = loadMTLSMDSTransportCreds(mtlsMDSRoot, mtlsMDSKey)
142 if err != nil {
143 log.Printf("Loading MTLS MDS credentials failed: %v", err)
144 if config.s2aAddress != "" {
145 s2aAddr = config.s2aAddress
146 } else {
147 return defaultTransportCreds, config.endpoint, nil
148 }
149 }
150 } else if config.s2aAddress != "" {
151 s2aAddr = config.s2aAddress
152 } else {
153 return defaultTransportCreds, config.endpoint, nil
154 }
155
156 var fallbackOpts *s2a.FallbackOptions
157 // In case of S2A failure, fall back to the endpoint that would've been used without S2A.
158 if fallbackHandshake, err := fallback.DefaultFallbackClientHandshakeFunc(config.endpoint); err == nil {
159 fallbackOpts = &s2a.FallbackOptions{
160 FallbackClientHandshakeFunc: fallbackHandshake,
161 }
162 }
163
164 s2aTransportCreds, err := s2a.NewClientCreds(&s2a.ClientOptions{
165 S2AAddress: s2aAddr,
166 TransportCreds: transportCredsForS2A,
167 FallbackOpts: fallbackOpts,
168 })
169 if err != nil {
170 // Use default if we cannot initialize S2A client transport credentials.
171 return defaultTransportCreds, config.endpoint, nil
172 }
173 return s2aTransportCreds, config.s2aMTLSEndpoint, nil
174}
175
176// GetHTTPTransportConfig returns a client certificate source and a function for
177// dialing MTLS with S2A.
178func GetHTTPTransportConfig(opts *Options) (cert.Provider, func(context.Context, string, string) (net.Conn, error), error) {
179 config, err := getTransportConfig(opts)
180 if err != nil {
181 return nil, nil, err
182 }
183
184 var s2aAddr string
185 var transportCredsForS2A credentials.TransportCredentials
186
187 if config.mtlsS2AAddress != "" {
188 s2aAddr = config.mtlsS2AAddress
189 transportCredsForS2A, err = loadMTLSMDSTransportCreds(mtlsMDSRoot, mtlsMDSKey)
190 if err != nil {
191 log.Printf("Loading MTLS MDS credentials failed: %v", err)
192 if config.s2aAddress != "" {
193 s2aAddr = config.s2aAddress
194 } else {
195 return config.clientCertSource, nil, nil
196 }
197 }
198 } else if config.s2aAddress != "" {
199 s2aAddr = config.s2aAddress
200 } else {
201 return config.clientCertSource, nil, nil
202 }
203
204 var fallbackOpts *s2a.FallbackOptions
205 // In case of S2A failure, fall back to the endpoint that would've been used without S2A.
206 if fallbackURL, err := url.Parse(config.endpoint); err == nil {
207 if fallbackDialer, fallbackServerAddr, err := fallback.DefaultFallbackDialerAndAddress(fallbackURL.Hostname()); err == nil {
208 fallbackOpts = &s2a.FallbackOptions{
209 FallbackDialer: &s2a.FallbackDialer{
210 Dialer: fallbackDialer,
211 ServerAddr: fallbackServerAddr,
212 },
213 }
214 }
215 }
216
217 dialTLSContextFunc := s2a.NewS2ADialTLSContextFunc(&s2a.ClientOptions{
218 S2AAddress: s2aAddr,
219 TransportCreds: transportCredsForS2A,
220 FallbackOpts: fallbackOpts,
221 })
222 return nil, dialTLSContextFunc, nil
223}
224
225func loadMTLSMDSTransportCreds(mtlsMDSRootFile, mtlsMDSKeyFile string) (credentials.TransportCredentials, error) {
226 rootPEM, err := os.ReadFile(mtlsMDSRootFile)
227 if err != nil {
228 return nil, err
229 }
230 caCertPool := x509.NewCertPool()
231 ok := caCertPool.AppendCertsFromPEM(rootPEM)
232 if !ok {
233 return nil, errors.New("failed to load MTLS MDS root certificate")
234 }
235 // The mTLS MDS credentials are formatted as the concatenation of a PEM-encoded certificate chain
236 // followed by a PEM-encoded private key. For this reason, the concatenation is passed in to the
237 // tls.X509KeyPair function as both the certificate chain and private key arguments.
238 cert, err := tls.LoadX509KeyPair(mtlsMDSKeyFile, mtlsMDSKeyFile)
239 if err != nil {
240 return nil, err
241 }
242 tlsConfig := tls.Config{
243 RootCAs: caCertPool,
244 Certificates: []tls.Certificate{cert},
245 MinVersion: tls.VersionTLS13,
246 }
247 return credentials.NewTLS(&tlsConfig), nil
248}
249
250func getTransportConfig(opts *Options) (*transportConfig, error) {
251 clientCertSource, err := GetClientCertificateProvider(opts)
252 if err != nil {
253 return nil, err
254 }
255 endpoint, err := getEndpoint(opts, clientCertSource)
256 if err != nil {
257 return nil, err
258 }
259 defaultTransportConfig := transportConfig{
260 clientCertSource: clientCertSource,
261 endpoint: endpoint,
262 }
263
264 if !shouldUseS2A(clientCertSource, opts) {
265 return &defaultTransportConfig, nil
266 }
267
268 s2aAddress := GetS2AAddress(opts.Logger)
269 mtlsS2AAddress := GetMTLSS2AAddress(opts.Logger)
270 if s2aAddress == "" && mtlsS2AAddress == "" {
271 return &defaultTransportConfig, nil
272 }
273 return &transportConfig{
274 clientCertSource: clientCertSource,
275 endpoint: endpoint,
276 s2aAddress: s2aAddress,
277 mtlsS2AAddress: mtlsS2AAddress,
278 s2aMTLSEndpoint: opts.defaultMTLSEndpoint(),
279 }, nil
280}
281
282// GetClientCertificateProvider returns a default client certificate source, if
283// not provided by the user.
284//
285// A nil default source can be returned if the source does not exist. Any exceptions
286// encountered while initializing the default source will be reported as client
287// error (ex. corrupt metadata file).
288func GetClientCertificateProvider(opts *Options) (cert.Provider, error) {
289 if !isClientCertificateEnabled(opts) {
290 return nil, nil
291 } else if opts.ClientCertProvider != nil {
292 return opts.ClientCertProvider, nil
293 }
294 return cert.DefaultProvider()
295
296}
297
298// isClientCertificateEnabled returns true by default for all GDU universe domain, unless explicitly overridden by env var
299func isClientCertificateEnabled(opts *Options) bool {
300 if value, ok := os.LookupEnv(googleAPIUseCertSource); ok {
301 // error as false is OK
302 b, _ := strconv.ParseBool(value)
303 return b
304 }
305 return opts.isUniverseDomainGDU()
306}
307
// transportConfig is the resolved transport setup produced by
// getTransportConfig. The S2A-related fields are empty unless S2A may be used.
type transportConfig struct {
	// The client certificate source. Nil when client certificates are
	// disabled, or when no default source exists.
	clientCertSource cert.Provider

	// The corresponding endpoint to use based on client certificate source.
	endpoint string

	// The plaintext S2A address if it can be used, otherwise an empty string.
	s2aAddress string

	// The MTLS S2A address if it can be used, otherwise an empty string.
	mtlsS2AAddress string

	// The MTLS endpoint to use with S2A (the expanded DefaultMTLSEndpoint).
	s2aMTLSEndpoint string
}
320
321// getEndpoint returns the endpoint for the service, taking into account the
322// user-provided endpoint override "settings.Endpoint".
323//
324// If no endpoint override is specified, we will either return the default
325// endpoint or the default mTLS endpoint if a client certificate is available.
326//
327// You can override the default endpoint choice (mTLS vs. regular) by setting
328// the GOOGLE_API_USE_MTLS_ENDPOINT environment variable.
329//
330// If the endpoint override is an address (host:port) rather than full base
331// URL (ex. https://...), then the user-provided address will be merged into
332// the default endpoint. For example, WithEndpoint("myhost:8000") and
333// DefaultEndpointTemplate("https://UNIVERSE_DOMAIN/bar/baz") will return
334// "https://myhost:8080/bar/baz". Note that this does not apply to the mTLS
335// endpoint.
336func getEndpoint(opts *Options, clientCertSource cert.Provider) (string, error) {
337 if opts.Endpoint == "" {
338 mtlsMode := getMTLSMode()
339 if mtlsMode == mTLSModeAlways || (clientCertSource != nil && mtlsMode == mTLSModeAuto) {
340 return opts.defaultMTLSEndpoint(), nil
341 }
342 return opts.defaultEndpoint(), nil
343 }
344 if strings.Contains(opts.Endpoint, "://") {
345 // User passed in a full URL path, use it verbatim.
346 return opts.Endpoint, nil
347 }
348 if opts.defaultEndpoint() == "" {
349 // If DefaultEndpointTemplate is not configured,
350 // use the user provided endpoint verbatim. This allows a naked
351 // "host[:port]" URL to be used with GRPC Direct Path.
352 return opts.Endpoint, nil
353 }
354
355 // Assume user-provided endpoint is host[:port], merge it with the default endpoint.
356 return opts.mergedEndpoint()
357}
358
359func getMTLSMode() string {
360 mode := os.Getenv(googleAPIUseMTLS)
361 if mode == "" {
362 mode = os.Getenv(googleAPIUseMTLSOld) // Deprecated.
363 }
364 if mode == "" {
365 return mTLSModeAuto
366 }
367 return strings.ToLower(mode)
368}