s2a.go

  1/*
  2 *
  3 * Copyright 2021 Google LLC
  4 *
  5 * Licensed under the Apache License, Version 2.0 (the "License");
  6 * you may not use this file except in compliance with the License.
  7 * You may obtain a copy of the License at
  8 *
  9 *     https://www.apache.org/licenses/LICENSE-2.0
 10 *
 11 * Unless required by applicable law or agreed to in writing, software
 12 * distributed under the License is distributed on an "AS IS" BASIS,
 13 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 14 * See the License for the specific language governing permissions and
 15 * limitations under the License.
 16 *
 17 */
 18
 19// Package s2a provides the S2A transport credentials used by a gRPC
 20// application.
 21package s2a
 22
 23import (
 24	"context"
 25	"crypto/tls"
 26	"errors"
 27	"fmt"
 28	"net"
 29	"sync"
 30	"time"
 31
 32	"github.com/google/s2a-go/fallback"
 33	"github.com/google/s2a-go/internal/handshaker"
 34	"github.com/google/s2a-go/internal/handshaker/service"
 35	"github.com/google/s2a-go/internal/tokenmanager"
 36	"github.com/google/s2a-go/internal/v2"
 37	"github.com/google/s2a-go/retry"
 38	"google.golang.org/grpc/credentials"
 39	"google.golang.org/grpc/grpclog"
 40	"google.golang.org/protobuf/proto"
 41
 42	commonpbv1 "github.com/google/s2a-go/internal/proto/common_go_proto"
 43	commonpb "github.com/google/s2a-go/internal/proto/v2/common_go_proto"
 44	s2av2pb "github.com/google/s2a-go/internal/proto/v2/s2a_go_proto"
 45)
 46
 47const (
 48	s2aSecurityProtocol = "tls"
 49	// defaultTimeout specifies the default server handshake timeout.
 50	defaultTimeout = 30.0 * time.Second
 51)
 52
 53// s2aTransportCreds are the transport credentials required for establishing
 54// a secure connection using the S2A. They implement the
 55// credentials.TransportCredentials interface.
 56type s2aTransportCreds struct {
 57	info          *credentials.ProtocolInfo
 58	minTLSVersion commonpbv1.TLSVersion
 59	maxTLSVersion commonpbv1.TLSVersion
 60	// tlsCiphersuites contains the ciphersuites used in the S2A connection.
 61	// Note that these are currently unconfigurable.
 62	tlsCiphersuites []commonpbv1.Ciphersuite
 63	// localIdentity should only be used by the client.
 64	localIdentity *commonpbv1.Identity
 65	// localIdentities should only be used by the server.
 66	localIdentities []*commonpbv1.Identity
 67	// targetIdentities should only be used by the client.
 68	targetIdentities            []*commonpbv1.Identity
 69	isClient                    bool
 70	s2aAddr                     string
 71	ensureProcessSessionTickets *sync.WaitGroup
 72}
 73
 74// NewClientCreds returns a client-side transport credentials object that uses
 75// the S2A to establish a secure connection with a server.
 76func NewClientCreds(opts *ClientOptions) (credentials.TransportCredentials, error) {
 77	if opts == nil {
 78		return nil, errors.New("nil client options")
 79	}
 80	var targetIdentities []*commonpbv1.Identity
 81	for _, targetIdentity := range opts.TargetIdentities {
 82		protoTargetIdentity, err := toProtoIdentity(targetIdentity)
 83		if err != nil {
 84			return nil, err
 85		}
 86		targetIdentities = append(targetIdentities, protoTargetIdentity)
 87	}
 88	localIdentity, err := toProtoIdentity(opts.LocalIdentity)
 89	if err != nil {
 90		return nil, err
 91	}
 92	if opts.EnableLegacyMode {
 93		return &s2aTransportCreds{
 94			info: &credentials.ProtocolInfo{
 95				SecurityProtocol: s2aSecurityProtocol,
 96			},
 97			minTLSVersion: commonpbv1.TLSVersion_TLS1_3,
 98			maxTLSVersion: commonpbv1.TLSVersion_TLS1_3,
 99			tlsCiphersuites: []commonpbv1.Ciphersuite{
100				commonpbv1.Ciphersuite_AES_128_GCM_SHA256,
101				commonpbv1.Ciphersuite_AES_256_GCM_SHA384,
102				commonpbv1.Ciphersuite_CHACHA20_POLY1305_SHA256,
103			},
104			localIdentity:               localIdentity,
105			targetIdentities:            targetIdentities,
106			isClient:                    true,
107			s2aAddr:                     opts.S2AAddress,
108			ensureProcessSessionTickets: opts.EnsureProcessSessionTickets,
109		}, nil
110	}
111	verificationMode := getVerificationMode(opts.VerificationMode)
112	var fallbackFunc fallback.ClientHandshake
113	if opts.FallbackOpts != nil && opts.FallbackOpts.FallbackClientHandshakeFunc != nil {
114		fallbackFunc = opts.FallbackOpts.FallbackClientHandshakeFunc
115	}
116	v2LocalIdentity, err := toV2ProtoIdentity(opts.LocalIdentity)
117	if err != nil {
118		return nil, err
119	}
120	return v2.NewClientCreds(opts.S2AAddress, opts.TransportCreds, v2LocalIdentity, verificationMode, fallbackFunc, opts.getS2AStream, opts.serverAuthorizationPolicy)
121}
122
123// NewServerCreds returns a server-side transport credentials object that uses
124// the S2A to establish a secure connection with a client.
125func NewServerCreds(opts *ServerOptions) (credentials.TransportCredentials, error) {
126	if opts == nil {
127		return nil, errors.New("nil server options")
128	}
129	var localIdentities []*commonpbv1.Identity
130	for _, localIdentity := range opts.LocalIdentities {
131		protoLocalIdentity, err := toProtoIdentity(localIdentity)
132		if err != nil {
133			return nil, err
134		}
135		localIdentities = append(localIdentities, protoLocalIdentity)
136	}
137	if opts.EnableLegacyMode {
138		return &s2aTransportCreds{
139			info: &credentials.ProtocolInfo{
140				SecurityProtocol: s2aSecurityProtocol,
141			},
142			minTLSVersion: commonpbv1.TLSVersion_TLS1_3,
143			maxTLSVersion: commonpbv1.TLSVersion_TLS1_3,
144			tlsCiphersuites: []commonpbv1.Ciphersuite{
145				commonpbv1.Ciphersuite_AES_128_GCM_SHA256,
146				commonpbv1.Ciphersuite_AES_256_GCM_SHA384,
147				commonpbv1.Ciphersuite_CHACHA20_POLY1305_SHA256,
148			},
149			localIdentities: localIdentities,
150			isClient:        false,
151			s2aAddr:         opts.S2AAddress,
152		}, nil
153	}
154	verificationMode := getVerificationMode(opts.VerificationMode)
155	var v2LocalIdentities []*commonpb.Identity
156	for _, localIdentity := range opts.LocalIdentities {
157		protoLocalIdentity, err := toV2ProtoIdentity(localIdentity)
158		if err != nil {
159			return nil, err
160		}
161		v2LocalIdentities = append(v2LocalIdentities, protoLocalIdentity)
162	}
163	return v2.NewServerCreds(opts.S2AAddress, opts.TransportCreds, v2LocalIdentities, verificationMode, opts.getS2AStream)
164}
165
166// ClientHandshake initiates a client-side TLS handshake using the S2A.
167func (c *s2aTransportCreds) ClientHandshake(ctx context.Context, serverAuthority string, rawConn net.Conn) (net.Conn, credentials.AuthInfo, error) {
168	if !c.isClient {
169		return nil, nil, errors.New("client handshake called using server transport credentials")
170	}
171
172	var cancel context.CancelFunc
173	ctx, cancel = context.WithCancel(ctx)
174	defer cancel()
175
176	// Connect to the S2A.
177	hsConn, err := service.Dial(ctx, c.s2aAddr, nil)
178	if err != nil {
179		grpclog.Infof("Failed to connect to S2A: %v", err)
180		return nil, nil, err
181	}
182
183	opts := &handshaker.ClientHandshakerOptions{
184		MinTLSVersion:               c.minTLSVersion,
185		MaxTLSVersion:               c.maxTLSVersion,
186		TLSCiphersuites:             c.tlsCiphersuites,
187		TargetIdentities:            c.targetIdentities,
188		LocalIdentity:               c.localIdentity,
189		TargetName:                  serverAuthority,
190		EnsureProcessSessionTickets: c.ensureProcessSessionTickets,
191	}
192	chs, err := handshaker.NewClientHandshaker(ctx, hsConn, rawConn, c.s2aAddr, opts)
193	if err != nil {
194		grpclog.Infof("Call to handshaker.NewClientHandshaker failed: %v", err)
195		return nil, nil, err
196	}
197	defer func() {
198		if err != nil {
199			if closeErr := chs.Close(); closeErr != nil {
200				grpclog.Infof("Close failed unexpectedly: %v", err)
201				err = fmt.Errorf("%v: close unexpectedly failed: %v", err, closeErr)
202			}
203		}
204	}()
205
206	secConn, authInfo, err := chs.ClientHandshake(context.Background())
207	if err != nil {
208		grpclog.Infof("Handshake failed: %v", err)
209		return nil, nil, err
210	}
211	return secConn, authInfo, nil
212}
213
214// ServerHandshake initiates a server-side TLS handshake using the S2A.
215func (c *s2aTransportCreds) ServerHandshake(rawConn net.Conn) (net.Conn, credentials.AuthInfo, error) {
216	if c.isClient {
217		return nil, nil, errors.New("server handshake called using client transport credentials")
218	}
219
220	ctx, cancel := context.WithTimeout(context.Background(), defaultTimeout)
221	defer cancel()
222
223	// Connect to the S2A.
224	hsConn, err := service.Dial(ctx, c.s2aAddr, nil)
225	if err != nil {
226		grpclog.Infof("Failed to connect to S2A: %v", err)
227		return nil, nil, err
228	}
229
230	opts := &handshaker.ServerHandshakerOptions{
231		MinTLSVersion:   c.minTLSVersion,
232		MaxTLSVersion:   c.maxTLSVersion,
233		TLSCiphersuites: c.tlsCiphersuites,
234		LocalIdentities: c.localIdentities,
235	}
236	shs, err := handshaker.NewServerHandshaker(ctx, hsConn, rawConn, c.s2aAddr, opts)
237	if err != nil {
238		grpclog.Infof("Call to handshaker.NewServerHandshaker failed: %v", err)
239		return nil, nil, err
240	}
241	defer func() {
242		if err != nil {
243			if closeErr := shs.Close(); closeErr != nil {
244				grpclog.Infof("Close failed unexpectedly: %v", err)
245				err = fmt.Errorf("%v: close unexpectedly failed: %v", err, closeErr)
246			}
247		}
248	}()
249
250	secConn, authInfo, err := shs.ServerHandshake(context.Background())
251	if err != nil {
252		grpclog.Infof("Handshake failed: %v", err)
253		return nil, nil, err
254	}
255	return secConn, authInfo, nil
256}
257
258func (c *s2aTransportCreds) Info() credentials.ProtocolInfo {
259	return *c.info
260}
261
262func (c *s2aTransportCreds) Clone() credentials.TransportCredentials {
263	info := *c.info
264	var localIdentity *commonpbv1.Identity
265	if c.localIdentity != nil {
266		localIdentity = proto.Clone(c.localIdentity).(*commonpbv1.Identity)
267	}
268	var localIdentities []*commonpbv1.Identity
269	if c.localIdentities != nil {
270		localIdentities = make([]*commonpbv1.Identity, len(c.localIdentities))
271		for i, localIdentity := range c.localIdentities {
272			localIdentities[i] = proto.Clone(localIdentity).(*commonpbv1.Identity)
273		}
274	}
275	var targetIdentities []*commonpbv1.Identity
276	if c.targetIdentities != nil {
277		targetIdentities = make([]*commonpbv1.Identity, len(c.targetIdentities))
278		for i, targetIdentity := range c.targetIdentities {
279			targetIdentities[i] = proto.Clone(targetIdentity).(*commonpbv1.Identity)
280		}
281	}
282	return &s2aTransportCreds{
283		info:             &info,
284		minTLSVersion:    c.minTLSVersion,
285		maxTLSVersion:    c.maxTLSVersion,
286		tlsCiphersuites:  c.tlsCiphersuites,
287		localIdentity:    localIdentity,
288		localIdentities:  localIdentities,
289		targetIdentities: targetIdentities,
290		isClient:         c.isClient,
291		s2aAddr:          c.s2aAddr,
292	}
293}
294
295func (c *s2aTransportCreds) OverrideServerName(serverNameOverride string) error {
296	c.info.ServerName = serverNameOverride
297	return nil
298}
299
// TLSClientConfigOptions specifies parameters for creating client TLS config.
type TLSClientConfigOptions struct {
	// ServerName is required by s2a as the expected name when verifying the
	// hostname found in the server's certificate. For example:
	//
	//	tlsConfig, _ := factory.Build(ctx, &s2a.TLSClientConfigOptions{
	//		ServerName: "example.com",
	//	})
	ServerName string
}
308
309// TLSClientConfigFactory defines the interface for a client TLS config factory.
310type TLSClientConfigFactory interface {
311	Build(ctx context.Context, opts *TLSClientConfigOptions) (*tls.Config, error)
312}
313
314// NewTLSClientConfigFactory returns an instance of s2aTLSClientConfigFactory.
315func NewTLSClientConfigFactory(opts *ClientOptions) (TLSClientConfigFactory, error) {
316	if opts == nil {
317		return nil, fmt.Errorf("opts must be non-nil")
318	}
319	if opts.EnableLegacyMode {
320		return nil, fmt.Errorf("NewTLSClientConfigFactory only supports S2Av2")
321	}
322	tokenManager, err := tokenmanager.NewSingleTokenAccessTokenManager()
323	if err != nil {
324		// The only possible error is: access token not set in the environment,
325		// which is okay in environments other than serverless.
326		grpclog.Infof("Access token manager not initialized: %v", err)
327		return &s2aTLSClientConfigFactory{
328			s2av2Address:              opts.S2AAddress,
329			transportCreds:            opts.TransportCreds,
330			tokenManager:              nil,
331			verificationMode:          getVerificationMode(opts.VerificationMode),
332			serverAuthorizationPolicy: opts.serverAuthorizationPolicy,
333		}, nil
334	}
335	return &s2aTLSClientConfigFactory{
336		s2av2Address:              opts.S2AAddress,
337		transportCreds:            opts.TransportCreds,
338		tokenManager:              tokenManager,
339		verificationMode:          getVerificationMode(opts.VerificationMode),
340		serverAuthorizationPolicy: opts.serverAuthorizationPolicy,
341	}, nil
342}
343
344type s2aTLSClientConfigFactory struct {
345	s2av2Address              string
346	transportCreds            credentials.TransportCredentials
347	tokenManager              tokenmanager.AccessTokenManager
348	verificationMode          s2av2pb.ValidatePeerCertificateChainReq_VerificationMode
349	serverAuthorizationPolicy []byte
350}
351
352func (f *s2aTLSClientConfigFactory) Build(
353	ctx context.Context, opts *TLSClientConfigOptions) (*tls.Config, error) {
354	serverName := ""
355	if opts != nil && opts.ServerName != "" {
356		serverName = opts.ServerName
357	}
358	return v2.NewClientTLSConfig(ctx, f.s2av2Address, f.transportCreds, f.tokenManager, f.verificationMode, serverName, f.serverAuthorizationPolicy)
359}
360
361func getVerificationMode(verificationMode VerificationModeType) s2av2pb.ValidatePeerCertificateChainReq_VerificationMode {
362	switch verificationMode {
363	case ConnectToGoogle:
364		return s2av2pb.ValidatePeerCertificateChainReq_CONNECT_TO_GOOGLE
365	case Spiffe:
366		return s2av2pb.ValidatePeerCertificateChainReq_SPIFFE
367	case ReservedCustomVerificationMode3:
368		return s2av2pb.ValidatePeerCertificateChainReq_RESERVED_CUSTOM_VERIFICATION_MODE_3
369	case ReservedCustomVerificationMode4:
370		return s2av2pb.ValidatePeerCertificateChainReq_RESERVED_CUSTOM_VERIFICATION_MODE_4
371	case ReservedCustomVerificationMode5:
372		return s2av2pb.ValidatePeerCertificateChainReq_RESERVED_CUSTOM_VERIFICATION_MODE_5
373	default:
374		return s2av2pb.ValidatePeerCertificateChainReq_UNSPECIFIED
375	}
376}
377
378// NewS2ADialTLSContextFunc returns a dialer which establishes an MTLS connection using S2A.
379// Example use with http.RoundTripper:
380//
381//		dialTLSContext := s2a.NewS2aDialTLSContextFunc(&s2a.ClientOptions{
382//			S2AAddress:         s2aAddress, // required
383//		})
384//	 	transport := http.DefaultTransport
385//	 	transport.DialTLSContext = dialTLSContext
386func NewS2ADialTLSContextFunc(opts *ClientOptions) func(ctx context.Context, network, addr string) (net.Conn, error) {
387
388	return func(ctx context.Context, network, addr string) (net.Conn, error) {
389
390		fallback := func(err error) (net.Conn, error) {
391			if opts.FallbackOpts != nil && opts.FallbackOpts.FallbackDialer != nil &&
392				opts.FallbackOpts.FallbackDialer.Dialer != nil && opts.FallbackOpts.FallbackDialer.ServerAddr != "" {
393				fbDialer := opts.FallbackOpts.FallbackDialer
394				grpclog.Infof("fall back to dial: %s", fbDialer.ServerAddr)
395				fbConn, fbErr := fbDialer.Dialer.DialContext(ctx, network, fbDialer.ServerAddr)
396				if fbErr != nil {
397					return nil, fmt.Errorf("error fallback to %s: %v; S2A error: %w", fbDialer.ServerAddr, fbErr, err)
398				}
399				return fbConn, nil
400			}
401			return nil, err
402		}
403
404		factory, err := NewTLSClientConfigFactory(opts)
405		if err != nil {
406			grpclog.Infof("error creating S2A client config factory: %v", err)
407			return fallback(err)
408		}
409
410		serverName, _, err := net.SplitHostPort(addr)
411		if err != nil {
412			serverName = addr
413		}
414		timeoutCtx, cancel := context.WithTimeout(ctx, v2.GetS2ATimeout())
415		defer cancel()
416
417		var s2aTLSConfig *tls.Config
418		var c net.Conn
419		retry.Run(timeoutCtx,
420			func() error {
421				s2aTLSConfig, err = factory.Build(timeoutCtx, &TLSClientConfigOptions{
422					ServerName: serverName,
423				})
424				if err != nil {
425					grpclog.Infof("error building S2A TLS config: %v", err)
426					return err
427				}
428
429				s2aDialer := &tls.Dialer{
430					Config: s2aTLSConfig,
431				}
432				c, err = s2aDialer.DialContext(timeoutCtx, network, addr)
433				return err
434			})
435		if err != nil {
436			grpclog.Infof("error dialing with S2A to %s: %v", addr, err)
437			return fallback(err)
438		}
439		grpclog.Infof("success dialing MTLS to %s with S2A", addr)
440		return c, nil
441	}
442}