1// Copyright 2023 Google LLC
  2//
  3// Licensed under the Apache License, Version 2.0 (the "License");
  4// you may not use this file except in compliance with the License.
  5// You may obtain a copy of the License at
  6//
  7//      http://www.apache.org/licenses/LICENSE-2.0
  8//
  9// Unless required by applicable law or agreed to in writing, software
 10// distributed under the License is distributed on an "AS IS" BASIS,
 11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 12// See the License for the specific language governing permissions and
 13// limitations under the License.
 14
 15package transport
 16
 17import (
 18	"context"
 19	"encoding/json"
 20	"fmt"
 21	"log"
 22	"log/slog"
 23	"os"
 24	"strconv"
 25	"sync"
 26
 27	"cloud.google.com/go/auth/internal/transport/cert"
 28	"cloud.google.com/go/compute/metadata"
 29)
 30
 31const (
 32	configEndpointSuffix = "instance/platform-security/auto-mtls-configuration"
 33)
 34
 35var (
 36	mtlsConfiguration *mtlsConfig
 37
 38	mtlsOnce sync.Once
 39)
 40
 41// GetS2AAddress returns the S2A address to be reached via plaintext connection.
 42// Returns empty string if not set or invalid.
 43func GetS2AAddress(logger *slog.Logger) string {
 44	getMetadataMTLSAutoConfig(logger)
 45	if !mtlsConfiguration.valid() {
 46		return ""
 47	}
 48	return mtlsConfiguration.S2A.PlaintextAddress
 49}
 50
 51// GetMTLSS2AAddress returns the S2A address to be reached via MTLS connection.
 52// Returns empty string if not set or invalid.
 53func GetMTLSS2AAddress(logger *slog.Logger) string {
 54	getMetadataMTLSAutoConfig(logger)
 55	if !mtlsConfiguration.valid() {
 56		return ""
 57	}
 58	return mtlsConfiguration.S2A.MTLSAddress
 59}
 60
 61// mtlsConfig contains the configuration for establishing MTLS connections with Google APIs.
 62type mtlsConfig struct {
 63	S2A *s2aAddresses `json:"s2a"`
 64}
 65
 66func (c *mtlsConfig) valid() bool {
 67	return c != nil && c.S2A != nil
 68}
 69
 70// s2aAddresses contains the plaintext and/or MTLS S2A addresses.
 71type s2aAddresses struct {
 72	// PlaintextAddress is the plaintext address to reach S2A
 73	PlaintextAddress string `json:"plaintext_address"`
 74	// MTLSAddress is the MTLS address to reach S2A
 75	MTLSAddress string `json:"mtls_address"`
 76}
 77
 78func getMetadataMTLSAutoConfig(logger *slog.Logger) {
 79	var err error
 80	mtlsOnce.Do(func() {
 81		mtlsConfiguration, err = queryConfig(logger)
 82		if err != nil {
 83			log.Printf("Getting MTLS config failed: %v", err)
 84		}
 85	})
 86}
 87
 88var httpGetMetadataMTLSConfig = func(logger *slog.Logger) (string, error) {
 89	metadataClient := metadata.NewWithOptions(&metadata.Options{
 90		Logger: logger,
 91	})
 92	return metadataClient.GetWithContext(context.Background(), configEndpointSuffix)
 93}
 94
 95func queryConfig(logger *slog.Logger) (*mtlsConfig, error) {
 96	resp, err := httpGetMetadataMTLSConfig(logger)
 97	if err != nil {
 98		return nil, fmt.Errorf("querying MTLS config from MDS endpoint failed: %w", err)
 99	}
100	var config mtlsConfig
101	err = json.Unmarshal([]byte(resp), &config)
102	if err != nil {
103		return nil, fmt.Errorf("unmarshalling MTLS config from MDS endpoint failed: %w", err)
104	}
105	if config.S2A == nil {
106		return nil, fmt.Errorf("returned MTLS config from MDS endpoint is invalid: %v", config)
107	}
108	return &config, nil
109}
110
111func shouldUseS2A(clientCertSource cert.Provider, opts *Options) bool {
112	// If client cert is found, use that over S2A.
113	if clientCertSource != nil {
114		return false
115	}
116	// If EXPERIMENTAL_GOOGLE_API_USE_S2A is not set to true, skip S2A.
117	if !isGoogleS2AEnabled() {
118		return false
119	}
120	// If DefaultMTLSEndpoint is not set or has endpoint override, skip S2A.
121	if opts.DefaultMTLSEndpoint == "" || opts.Endpoint != "" {
122		return false
123	}
124	// If custom HTTP client is provided, skip S2A.
125	if opts.Client != nil {
126		return false
127	}
128	// If directPath is enabled, skip S2A.
129	return !opts.EnableDirectPath && !opts.EnableDirectPathXds
130}
131
132func isGoogleS2AEnabled() bool {
133	b, err := strconv.ParseBool(os.Getenv(googleAPIUseS2AEnv))
134	if err != nil {
135		return false
136	}
137	return b
138}