cba.go

  1// Copyright 2023 Google LLC
  2//
  3// Licensed under the Apache License, Version 2.0 (the "License");
  4// you may not use this file except in compliance with the License.
  5// You may obtain a copy of the License at
  6//
  7//      http://www.apache.org/licenses/LICENSE-2.0
  8//
  9// Unless required by applicable law or agreed to in writing, software
 10// distributed under the License is distributed on an "AS IS" BASIS,
 11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 12// See the License for the specific language governing permissions and
 13// limitations under the License.
 14
 15package transport
 16
 17import (
 18	"context"
 19	"crypto/tls"
 20	"crypto/x509"
 21	"errors"
 22	"log"
 23	"log/slog"
 24	"net"
 25	"net/http"
 26	"net/url"
 27	"os"
 28	"strconv"
 29	"strings"
 30
 31	"cloud.google.com/go/auth/internal"
 32	"cloud.google.com/go/auth/internal/transport/cert"
 33	"github.com/google/s2a-go"
 34	"github.com/google/s2a-go/fallback"
 35	"google.golang.org/grpc/credentials"
 36)
 37
 38const (
 39	mTLSModeAlways = "always"
 40	mTLSModeNever  = "never"
 41	mTLSModeAuto   = "auto"
 42
 43	// Experimental: if true, the code will try MTLS with S2A as the default for transport security. Default value is false.
 44	googleAPIUseS2AEnv     = "EXPERIMENTAL_GOOGLE_API_USE_S2A"
 45	googleAPIUseCertSource = "GOOGLE_API_USE_CLIENT_CERTIFICATE"
 46	googleAPIUseMTLS       = "GOOGLE_API_USE_MTLS_ENDPOINT"
 47	googleAPIUseMTLSOld    = "GOOGLE_API_USE_MTLS"
 48
 49	universeDomainPlaceholder = "UNIVERSE_DOMAIN"
 50
 51	mtlsMDSRoot = "/run/google-mds-mtls/root.crt"
 52	mtlsMDSKey  = "/run/google-mds-mtls/client.key"
 53)
 54
 55// Options is a struct that is duplicated information from the individual
 56// transport packages in order to avoid cyclic deps. It correlates 1:1 with
 57// fields on httptransport.Options and grpctransport.Options.
 58type Options struct {
 59	Endpoint                string
 60	DefaultEndpointTemplate string
 61	DefaultMTLSEndpoint     string
 62	ClientCertProvider      cert.Provider
 63	Client                  *http.Client
 64	UniverseDomain          string
 65	EnableDirectPath        bool
 66	EnableDirectPathXds     bool
 67	Logger                  *slog.Logger
 68}
 69
 70// getUniverseDomain returns the default service domain for a given Cloud
 71// universe.
 72func (o *Options) getUniverseDomain() string {
 73	if o.UniverseDomain == "" {
 74		return internal.DefaultUniverseDomain
 75	}
 76	return o.UniverseDomain
 77}
 78
 79// isUniverseDomainGDU returns true if the universe domain is the default Google
 80// universe.
 81func (o *Options) isUniverseDomainGDU() bool {
 82	return o.getUniverseDomain() == internal.DefaultUniverseDomain
 83}
 84
 85// defaultEndpoint returns the DefaultEndpointTemplate merged with the
 86// universe domain if the DefaultEndpointTemplate is set, otherwise returns an
 87// empty string.
 88func (o *Options) defaultEndpoint() string {
 89	if o.DefaultEndpointTemplate == "" {
 90		return ""
 91	}
 92	return strings.Replace(o.DefaultEndpointTemplate, universeDomainPlaceholder, o.getUniverseDomain(), 1)
 93}
 94
 95// defaultMTLSEndpoint returns the DefaultMTLSEndpointTemplate merged with the
 96// universe domain if the DefaultMTLSEndpointTemplate is set, otherwise returns an
 97// empty string.
 98func (o *Options) defaultMTLSEndpoint() string {
 99	if o.DefaultMTLSEndpoint == "" {
100		return ""
101	}
102	return strings.Replace(o.DefaultMTLSEndpoint, universeDomainPlaceholder, o.getUniverseDomain(), 1)
103}
104
105// mergedEndpoint merges a user-provided Endpoint of format host[:port] with the
106// default endpoint.
107func (o *Options) mergedEndpoint() (string, error) {
108	defaultEndpoint := o.defaultEndpoint()
109	u, err := url.Parse(fixScheme(defaultEndpoint))
110	if err != nil {
111		return "", err
112	}
113	return strings.Replace(defaultEndpoint, u.Host, o.Endpoint, 1), nil
114}
115
// fixScheme returns baseURL unchanged when it already contains a "://"
// scheme separator. Otherwise it returns baseURL prefixed with "https://".
func fixScheme(baseURL string) string {
	if strings.Contains(baseURL, "://") {
		return baseURL
	}
	return "https://" + baseURL
}
122
123// GetGRPCTransportCredsAndEndpoint returns an instance of
124// [google.golang.org/grpc/credentials.TransportCredentials], and the
125// corresponding endpoint to use for GRPC client.
126func GetGRPCTransportCredsAndEndpoint(opts *Options) (credentials.TransportCredentials, string, error) {
127	config, err := getTransportConfig(opts)
128	if err != nil {
129		return nil, "", err
130	}
131
132	defaultTransportCreds := credentials.NewTLS(&tls.Config{
133		GetClientCertificate: config.clientCertSource,
134	})
135
136	var s2aAddr string
137	var transportCredsForS2A credentials.TransportCredentials
138
139	if config.mtlsS2AAddress != "" {
140		s2aAddr = config.mtlsS2AAddress
141		transportCredsForS2A, err = loadMTLSMDSTransportCreds(mtlsMDSRoot, mtlsMDSKey)
142		if err != nil {
143			log.Printf("Loading MTLS MDS credentials failed: %v", err)
144			if config.s2aAddress != "" {
145				s2aAddr = config.s2aAddress
146			} else {
147				return defaultTransportCreds, config.endpoint, nil
148			}
149		}
150	} else if config.s2aAddress != "" {
151		s2aAddr = config.s2aAddress
152	} else {
153		return defaultTransportCreds, config.endpoint, nil
154	}
155
156	var fallbackOpts *s2a.FallbackOptions
157	// In case of S2A failure, fall back to the endpoint that would've been used without S2A.
158	if fallbackHandshake, err := fallback.DefaultFallbackClientHandshakeFunc(config.endpoint); err == nil {
159		fallbackOpts = &s2a.FallbackOptions{
160			FallbackClientHandshakeFunc: fallbackHandshake,
161		}
162	}
163
164	s2aTransportCreds, err := s2a.NewClientCreds(&s2a.ClientOptions{
165		S2AAddress:     s2aAddr,
166		TransportCreds: transportCredsForS2A,
167		FallbackOpts:   fallbackOpts,
168	})
169	if err != nil {
170		// Use default if we cannot initialize S2A client transport credentials.
171		return defaultTransportCreds, config.endpoint, nil
172	}
173	return s2aTransportCreds, config.s2aMTLSEndpoint, nil
174}
175
176// GetHTTPTransportConfig returns a client certificate source and a function for
177// dialing MTLS with S2A.
178func GetHTTPTransportConfig(opts *Options) (cert.Provider, func(context.Context, string, string) (net.Conn, error), error) {
179	config, err := getTransportConfig(opts)
180	if err != nil {
181		return nil, nil, err
182	}
183
184	var s2aAddr string
185	var transportCredsForS2A credentials.TransportCredentials
186
187	if config.mtlsS2AAddress != "" {
188		s2aAddr = config.mtlsS2AAddress
189		transportCredsForS2A, err = loadMTLSMDSTransportCreds(mtlsMDSRoot, mtlsMDSKey)
190		if err != nil {
191			log.Printf("Loading MTLS MDS credentials failed: %v", err)
192			if config.s2aAddress != "" {
193				s2aAddr = config.s2aAddress
194			} else {
195				return config.clientCertSource, nil, nil
196			}
197		}
198	} else if config.s2aAddress != "" {
199		s2aAddr = config.s2aAddress
200	} else {
201		return config.clientCertSource, nil, nil
202	}
203
204	var fallbackOpts *s2a.FallbackOptions
205	// In case of S2A failure, fall back to the endpoint that would've been used without S2A.
206	if fallbackURL, err := url.Parse(config.endpoint); err == nil {
207		if fallbackDialer, fallbackServerAddr, err := fallback.DefaultFallbackDialerAndAddress(fallbackURL.Hostname()); err == nil {
208			fallbackOpts = &s2a.FallbackOptions{
209				FallbackDialer: &s2a.FallbackDialer{
210					Dialer:     fallbackDialer,
211					ServerAddr: fallbackServerAddr,
212				},
213			}
214		}
215	}
216
217	dialTLSContextFunc := s2a.NewS2ADialTLSContextFunc(&s2a.ClientOptions{
218		S2AAddress:     s2aAddr,
219		TransportCreds: transportCredsForS2A,
220		FallbackOpts:   fallbackOpts,
221	})
222	return nil, dialTLSContextFunc, nil
223}
224
225func loadMTLSMDSTransportCreds(mtlsMDSRootFile, mtlsMDSKeyFile string) (credentials.TransportCredentials, error) {
226	rootPEM, err := os.ReadFile(mtlsMDSRootFile)
227	if err != nil {
228		return nil, err
229	}
230	caCertPool := x509.NewCertPool()
231	ok := caCertPool.AppendCertsFromPEM(rootPEM)
232	if !ok {
233		return nil, errors.New("failed to load MTLS MDS root certificate")
234	}
235	// The mTLS MDS credentials are formatted as the concatenation of a PEM-encoded certificate chain
236	// followed by a PEM-encoded private key. For this reason, the concatenation is passed in to the
237	// tls.X509KeyPair function as both the certificate chain and private key arguments.
238	cert, err := tls.LoadX509KeyPair(mtlsMDSKeyFile, mtlsMDSKeyFile)
239	if err != nil {
240		return nil, err
241	}
242	tlsConfig := tls.Config{
243		RootCAs:      caCertPool,
244		Certificates: []tls.Certificate{cert},
245		MinVersion:   tls.VersionTLS13,
246	}
247	return credentials.NewTLS(&tlsConfig), nil
248}
249
250func getTransportConfig(opts *Options) (*transportConfig, error) {
251	clientCertSource, err := GetClientCertificateProvider(opts)
252	if err != nil {
253		return nil, err
254	}
255	endpoint, err := getEndpoint(opts, clientCertSource)
256	if err != nil {
257		return nil, err
258	}
259	defaultTransportConfig := transportConfig{
260		clientCertSource: clientCertSource,
261		endpoint:         endpoint,
262	}
263
264	if !shouldUseS2A(clientCertSource, opts) {
265		return &defaultTransportConfig, nil
266	}
267
268	s2aAddress := GetS2AAddress(opts.Logger)
269	mtlsS2AAddress := GetMTLSS2AAddress(opts.Logger)
270	if s2aAddress == "" && mtlsS2AAddress == "" {
271		return &defaultTransportConfig, nil
272	}
273	return &transportConfig{
274		clientCertSource: clientCertSource,
275		endpoint:         endpoint,
276		s2aAddress:       s2aAddress,
277		mtlsS2AAddress:   mtlsS2AAddress,
278		s2aMTLSEndpoint:  opts.defaultMTLSEndpoint(),
279	}, nil
280}
281
282// GetClientCertificateProvider returns a default client certificate source, if
283// not provided by the user.
284//
285// A nil default source can be returned if the source does not exist. Any exceptions
286// encountered while initializing the default source will be reported as client
287// error (ex. corrupt metadata file).
288func GetClientCertificateProvider(opts *Options) (cert.Provider, error) {
289	if !isClientCertificateEnabled(opts) {
290		return nil, nil
291	} else if opts.ClientCertProvider != nil {
292		return opts.ClientCertProvider, nil
293	}
294	return cert.DefaultProvider()
295
296}
297
298// isClientCertificateEnabled returns true by default for all GDU universe domain, unless explicitly overridden by env var
299func isClientCertificateEnabled(opts *Options) bool {
300	if value, ok := os.LookupEnv(googleAPIUseCertSource); ok {
301		// error as false is OK
302		b, _ := strconv.ParseBool(value)
303		return b
304	}
305	return opts.isUniverseDomainGDU()
306}
307
308type transportConfig struct {
309	// The client certificate source.
310	clientCertSource cert.Provider
311	// The corresponding endpoint to use based on client certificate source.
312	endpoint string
313	// The plaintext S2A address if it can be used, otherwise an empty string.
314	s2aAddress string
315	// The MTLS S2A address if it can be used, otherwise an empty string.
316	mtlsS2AAddress string
317	// The MTLS endpoint to use with S2A.
318	s2aMTLSEndpoint string
319}
320
321// getEndpoint returns the endpoint for the service, taking into account the
322// user-provided endpoint override "settings.Endpoint".
323//
324// If no endpoint override is specified, we will either return the default
325// endpoint or the default mTLS endpoint if a client certificate is available.
326//
327// You can override the default endpoint choice (mTLS vs. regular) by setting
328// the GOOGLE_API_USE_MTLS_ENDPOINT environment variable.
329//
330// If the endpoint override is an address (host:port) rather than full base
331// URL (ex. https://...), then the user-provided address will be merged into
332// the default endpoint. For example, WithEndpoint("myhost:8000") and
333// DefaultEndpointTemplate("https://UNIVERSE_DOMAIN/bar/baz") will return
334// "https://myhost:8080/bar/baz". Note that this does not apply to the mTLS
335// endpoint.
336func getEndpoint(opts *Options, clientCertSource cert.Provider) (string, error) {
337	if opts.Endpoint == "" {
338		mtlsMode := getMTLSMode()
339		if mtlsMode == mTLSModeAlways || (clientCertSource != nil && mtlsMode == mTLSModeAuto) {
340			return opts.defaultMTLSEndpoint(), nil
341		}
342		return opts.defaultEndpoint(), nil
343	}
344	if strings.Contains(opts.Endpoint, "://") {
345		// User passed in a full URL path, use it verbatim.
346		return opts.Endpoint, nil
347	}
348	if opts.defaultEndpoint() == "" {
349		// If DefaultEndpointTemplate is not configured,
350		// use the user provided endpoint verbatim. This allows a naked
351		// "host[:port]" URL to be used with GRPC Direct Path.
352		return opts.Endpoint, nil
353	}
354
355	// Assume user-provided endpoint is host[:port], merge it with the default endpoint.
356	return opts.mergedEndpoint()
357}
358
359func getMTLSMode() string {
360	mode := os.Getenv(googleAPIUseMTLS)
361	if mode == "" {
362		mode = os.Getenv(googleAPIUseMTLSOld) // Deprecated.
363	}
364	if mode == "" {
365		return mTLSModeAuto
366	}
367	return strings.ToLower(mode)
368}