s2a_fallback.go

  1/*
  2 *
  3 * Copyright 2023 Google LLC
  4 *
  5 * Licensed under the Apache License, Version 2.0 (the "License");
  6 * you may not use this file except in compliance with the License.
  7 * You may obtain a copy of the License at
  8 *
  9 *     https://www.apache.org/licenses/LICENSE-2.0
 10 *
 11 * Unless required by applicable law or agreed to in writing, software
 12 * distributed under the License is distributed on an "AS IS" BASIS,
 13 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 14 * See the License for the specific language governing permissions and
 15 * limitations under the License.
 16 *
 17 */
 18
 19// Package fallback provides default implementations of fallback options when S2A fails.
 20package fallback
 21
 22import (
 23	"context"
 24	"crypto/tls"
 25	"fmt"
 26	"net"
 27
 28	"google.golang.org/grpc/credentials"
 29	"google.golang.org/grpc/grpclog"
 30)
 31
 32const (
 33	alpnProtoStrH2   = "h2"
 34	alpnProtoStrHTTP = "http/1.1"
 35	defaultHTTPSPort = "443"
 36)
 37
 38// FallbackTLSConfigGRPC is a tls.Config used by the DefaultFallbackClientHandshakeFunc function.
 39// It supports GRPC use case, thus the alpn is set to 'h2'.
 40var FallbackTLSConfigGRPC = tls.Config{
 41	MinVersion:         tls.VersionTLS13,
 42	ClientSessionCache: nil,
 43	NextProtos:         []string{alpnProtoStrH2},
 44}
 45
 46// FallbackTLSConfigHTTP is a tls.Config used by the DefaultFallbackDialerAndAddress func.
 47// It supports the HTTP use case and the alpn is set to both 'http/1.1' and 'h2'.
 48var FallbackTLSConfigHTTP = tls.Config{
 49	MinVersion:         tls.VersionTLS13,
 50	ClientSessionCache: nil,
 51	NextProtos:         []string{alpnProtoStrH2, alpnProtoStrHTTP},
 52}
 53
 54// ClientHandshake establishes a TLS connection and returns it, plus its auth info.
 55// Inputs:
 56//
 57//	targetServer: the server attempted with S2A.
 58//	conn: the tcp connection to the server at address targetServer that was passed into S2A's ClientHandshake func.
 59//	            If fallback is successful, the `conn` should be closed.
 60//	err: the error encountered when performing the client-side TLS handshake with S2A.
 61type ClientHandshake func(ctx context.Context, targetServer string, conn net.Conn, err error) (net.Conn, credentials.AuthInfo, error)
 62
 63// DefaultFallbackClientHandshakeFunc returns a ClientHandshake function,
 64// which establishes a TLS connection to the provided fallbackAddr, returns the new connection and its auth info.
 65// Example use:
 66//
 67//	transportCreds, _ = s2a.NewClientCreds(&s2a.ClientOptions{
 68//		S2AAddress: s2aAddress,
 69//		FallbackOpts: &s2a.FallbackOptions{ // optional
 70//			FallbackClientHandshakeFunc: fallback.DefaultFallbackClientHandshakeFunc(fallbackAddr),
 71//		},
 72//	})
 73//
 74// The fallback server's certificate must be verifiable using OS root store.
 75// The fallbackAddr is expected to be a network address, e.g. example.com:port. If port is not specified,
 76// it uses default port 443.
 77// In the returned function's TLS config, ClientSessionCache is explicitly set to nil to disable TLS resumption,
 78// and min TLS version is set to 1.3.
 79func DefaultFallbackClientHandshakeFunc(fallbackAddr string) (ClientHandshake, error) {
 80	var fallbackDialer = tls.Dialer{Config: &FallbackTLSConfigGRPC}
 81	return defaultFallbackClientHandshakeFuncInternal(fallbackAddr, fallbackDialer.DialContext)
 82}
 83
 84func defaultFallbackClientHandshakeFuncInternal(fallbackAddr string, dialContextFunc func(context.Context, string, string) (net.Conn, error)) (ClientHandshake, error) {
 85	fallbackServerAddr, err := processFallbackAddr(fallbackAddr)
 86	if err != nil {
 87		if grpclog.V(1) {
 88			grpclog.Infof("error processing fallback address [%s]: %v", fallbackAddr, err)
 89		}
 90		return nil, err
 91	}
 92	return func(ctx context.Context, targetServer string, conn net.Conn, s2aErr error) (net.Conn, credentials.AuthInfo, error) {
 93		fbConn, fbErr := dialContextFunc(ctx, "tcp", fallbackServerAddr)
 94		if fbErr != nil {
 95			grpclog.Infof("dialing to fallback server %s failed: %v", fallbackServerAddr, fbErr)
 96			return nil, nil, fmt.Errorf("dialing to fallback server %s failed: %v; S2A client handshake with %s error: %w", fallbackServerAddr, fbErr, targetServer, s2aErr)
 97		}
 98
 99		tc, success := fbConn.(*tls.Conn)
100		if !success {
101			grpclog.Infof("the connection with fallback server is expected to be tls but isn't")
102			return nil, nil, fmt.Errorf("the connection with fallback server is expected to be tls but isn't; S2A client handshake with %s error: %w", targetServer, s2aErr)
103		}
104
105		tlsInfo := credentials.TLSInfo{
106			State: tc.ConnectionState(),
107			CommonAuthInfo: credentials.CommonAuthInfo{
108				SecurityLevel: credentials.PrivacyAndIntegrity,
109			},
110		}
111		if grpclog.V(1) {
112			grpclog.Infof("ConnectionState.NegotiatedProtocol: %v", tc.ConnectionState().NegotiatedProtocol)
113			grpclog.Infof("ConnectionState.HandshakeComplete: %v", tc.ConnectionState().HandshakeComplete)
114			grpclog.Infof("ConnectionState.ServerName: %v", tc.ConnectionState().ServerName)
115		}
116		conn.Close()
117		return fbConn, tlsInfo, nil
118	}, nil
119}
120
121// DefaultFallbackDialerAndAddress returns a TLS dialer and the network address to dial.
122// Example use:
123//
124//	    fallbackDialer, fallbackServerAddr := fallback.DefaultFallbackDialerAndAddress(fallbackAddr)
125//		dialTLSContext := s2a.NewS2aDialTLSContextFunc(&s2a.ClientOptions{
126//			S2AAddress:         s2aAddress, // required
127//			FallbackOpts: &s2a.FallbackOptions{
128//				FallbackDialer: &s2a.FallbackDialer{
129//					Dialer:     fallbackDialer,
130//					ServerAddr: fallbackServerAddr,
131//				},
132//			},
133//	})
134//
135// The fallback server's certificate should be verifiable using OS root store.
136// The fallbackAddr is expected to be a network address, e.g. example.com:port. If port is not specified,
137// it uses default port 443.
138// In the returned function's TLS config, ClientSessionCache is explicitly set to nil to disable TLS resumption,
139// and min TLS version is set to 1.3.
140func DefaultFallbackDialerAndAddress(fallbackAddr string) (*tls.Dialer, string, error) {
141	fallbackServerAddr, err := processFallbackAddr(fallbackAddr)
142	if err != nil {
143		if grpclog.V(1) {
144			grpclog.Infof("error processing fallback address [%s]: %v", fallbackAddr, err)
145		}
146		return nil, "", err
147	}
148	return &tls.Dialer{Config: &FallbackTLSConfigHTTP}, fallbackServerAddr, nil
149}
150
151func processFallbackAddr(fallbackAddr string) (string, error) {
152	var fallbackServerAddr string
153	var err error
154
155	if fallbackAddr == "" {
156		return "", fmt.Errorf("empty fallback address")
157	}
158	_, _, err = net.SplitHostPort(fallbackAddr)
159	if err != nil {
160		// fallbackAddr does not have port suffix
161		fallbackServerAddr = net.JoinHostPort(fallbackAddr, defaultHTTPSPort)
162	} else {
163		// FallbackServerAddr already has port suffix
164		fallbackServerAddr = fallbackAddr
165	}
166	return fallbackServerAddr, nil
167}