1/*
2 *
3 * Copyright 2021 Google LLC
4 *
5 * Licensed under the Apache License, Version 2.0 (the "License");
6 * you may not use this file except in compliance with the License.
7 * You may obtain a copy of the License at
8 *
9 * https://www.apache.org/licenses/LICENSE-2.0
10 *
11 * Unless required by applicable law or agreed to in writing, software
12 * distributed under the License is distributed on an "AS IS" BASIS,
13 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 * See the License for the specific language governing permissions and
15 * limitations under the License.
16 *
17 */
18
19// Package s2a provides the S2A transport credentials used by a gRPC
20// application.
21package s2a
22
23import (
24 "context"
25 "crypto/tls"
26 "errors"
27 "fmt"
28 "net"
29 "sync"
30 "time"
31
32 "github.com/google/s2a-go/fallback"
33 "github.com/google/s2a-go/internal/handshaker"
34 "github.com/google/s2a-go/internal/handshaker/service"
35 "github.com/google/s2a-go/internal/tokenmanager"
36 "github.com/google/s2a-go/internal/v2"
37 "github.com/google/s2a-go/retry"
38 "google.golang.org/grpc/credentials"
39 "google.golang.org/grpc/grpclog"
40 "google.golang.org/protobuf/proto"
41
42 commonpbv1 "github.com/google/s2a-go/internal/proto/common_go_proto"
43 commonpb "github.com/google/s2a-go/internal/proto/v2/common_go_proto"
44 s2av2pb "github.com/google/s2a-go/internal/proto/v2/s2a_go_proto"
45)
46
const (
	// s2aSecurityProtocol is the SecurityProtocol value placed in the
	// ProtocolInfo of the legacy S2A transport credentials.
	s2aSecurityProtocol = "tls"
	// defaultTimeout is the deadline applied to the context used by the
	// server-side handshake setup.
	defaultTimeout = 30 * time.Second
)
52
// s2aTransportCreds are the transport credentials required for establishing
// a secure connection using the S2A. They implement the
// credentials.TransportCredentials interface.
//
// In this file, NewClientCreds and NewServerCreds construct this type only
// when EnableLegacyMode is set; otherwise they return credentials built by the
// v2 package.
type s2aTransportCreds struct {
	// info is returned (by value) from Info and updated by OverrideServerName.
	info *credentials.ProtocolInfo
	// minTLSVersion and maxTLSVersion are forwarded to the handshaker options;
	// both constructors in this file set them to TLS 1.3.
	minTLSVersion commonpbv1.TLSVersion
	maxTLSVersion commonpbv1.TLSVersion
	// tlsCiphersuites contains the ciphersuites used in the S2A connection.
	// Note that these are currently unconfigurable.
	tlsCiphersuites []commonpbv1.Ciphersuite
	// localIdentity should only be used by the client.
	localIdentity *commonpbv1.Identity
	// localIdentities should only be used by the server.
	localIdentities []*commonpbv1.Identity
	// targetIdentities should only be used by the client.
	targetIdentities []*commonpbv1.Identity
	// isClient records which side these credentials were built for:
	// ClientHandshake rejects server credentials and ServerHandshake rejects
	// client credentials.
	isClient bool
	// s2aAddr is the S2A address passed to service.Dial on every handshake.
	s2aAddr string
	// ensureProcessSessionTickets is set only for clients (from
	// ClientOptions.EnsureProcessSessionTickets) and forwarded to the client
	// handshaker options. NOTE(review): its exact semantics are defined in the
	// handshaker package — confirm there.
	ensureProcessSessionTickets *sync.WaitGroup
}
73
74// NewClientCreds returns a client-side transport credentials object that uses
75// the S2A to establish a secure connection with a server.
76func NewClientCreds(opts *ClientOptions) (credentials.TransportCredentials, error) {
77 if opts == nil {
78 return nil, errors.New("nil client options")
79 }
80 var targetIdentities []*commonpbv1.Identity
81 for _, targetIdentity := range opts.TargetIdentities {
82 protoTargetIdentity, err := toProtoIdentity(targetIdentity)
83 if err != nil {
84 return nil, err
85 }
86 targetIdentities = append(targetIdentities, protoTargetIdentity)
87 }
88 localIdentity, err := toProtoIdentity(opts.LocalIdentity)
89 if err != nil {
90 return nil, err
91 }
92 if opts.EnableLegacyMode {
93 return &s2aTransportCreds{
94 info: &credentials.ProtocolInfo{
95 SecurityProtocol: s2aSecurityProtocol,
96 },
97 minTLSVersion: commonpbv1.TLSVersion_TLS1_3,
98 maxTLSVersion: commonpbv1.TLSVersion_TLS1_3,
99 tlsCiphersuites: []commonpbv1.Ciphersuite{
100 commonpbv1.Ciphersuite_AES_128_GCM_SHA256,
101 commonpbv1.Ciphersuite_AES_256_GCM_SHA384,
102 commonpbv1.Ciphersuite_CHACHA20_POLY1305_SHA256,
103 },
104 localIdentity: localIdentity,
105 targetIdentities: targetIdentities,
106 isClient: true,
107 s2aAddr: opts.S2AAddress,
108 ensureProcessSessionTickets: opts.EnsureProcessSessionTickets,
109 }, nil
110 }
111 verificationMode := getVerificationMode(opts.VerificationMode)
112 var fallbackFunc fallback.ClientHandshake
113 if opts.FallbackOpts != nil && opts.FallbackOpts.FallbackClientHandshakeFunc != nil {
114 fallbackFunc = opts.FallbackOpts.FallbackClientHandshakeFunc
115 }
116 v2LocalIdentity, err := toV2ProtoIdentity(opts.LocalIdentity)
117 if err != nil {
118 return nil, err
119 }
120 return v2.NewClientCreds(opts.S2AAddress, opts.TransportCreds, v2LocalIdentity, verificationMode, fallbackFunc, opts.getS2AStream, opts.serverAuthorizationPolicy)
121}
122
123// NewServerCreds returns a server-side transport credentials object that uses
124// the S2A to establish a secure connection with a client.
125func NewServerCreds(opts *ServerOptions) (credentials.TransportCredentials, error) {
126 if opts == nil {
127 return nil, errors.New("nil server options")
128 }
129 var localIdentities []*commonpbv1.Identity
130 for _, localIdentity := range opts.LocalIdentities {
131 protoLocalIdentity, err := toProtoIdentity(localIdentity)
132 if err != nil {
133 return nil, err
134 }
135 localIdentities = append(localIdentities, protoLocalIdentity)
136 }
137 if opts.EnableLegacyMode {
138 return &s2aTransportCreds{
139 info: &credentials.ProtocolInfo{
140 SecurityProtocol: s2aSecurityProtocol,
141 },
142 minTLSVersion: commonpbv1.TLSVersion_TLS1_3,
143 maxTLSVersion: commonpbv1.TLSVersion_TLS1_3,
144 tlsCiphersuites: []commonpbv1.Ciphersuite{
145 commonpbv1.Ciphersuite_AES_128_GCM_SHA256,
146 commonpbv1.Ciphersuite_AES_256_GCM_SHA384,
147 commonpbv1.Ciphersuite_CHACHA20_POLY1305_SHA256,
148 },
149 localIdentities: localIdentities,
150 isClient: false,
151 s2aAddr: opts.S2AAddress,
152 }, nil
153 }
154 verificationMode := getVerificationMode(opts.VerificationMode)
155 var v2LocalIdentities []*commonpb.Identity
156 for _, localIdentity := range opts.LocalIdentities {
157 protoLocalIdentity, err := toV2ProtoIdentity(localIdentity)
158 if err != nil {
159 return nil, err
160 }
161 v2LocalIdentities = append(v2LocalIdentities, protoLocalIdentity)
162 }
163 return v2.NewServerCreds(opts.S2AAddress, opts.TransportCreds, v2LocalIdentities, verificationMode, opts.getS2AStream)
164}
165
166// ClientHandshake initiates a client-side TLS handshake using the S2A.
167func (c *s2aTransportCreds) ClientHandshake(ctx context.Context, serverAuthority string, rawConn net.Conn) (net.Conn, credentials.AuthInfo, error) {
168 if !c.isClient {
169 return nil, nil, errors.New("client handshake called using server transport credentials")
170 }
171
172 var cancel context.CancelFunc
173 ctx, cancel = context.WithCancel(ctx)
174 defer cancel()
175
176 // Connect to the S2A.
177 hsConn, err := service.Dial(ctx, c.s2aAddr, nil)
178 if err != nil {
179 grpclog.Infof("Failed to connect to S2A: %v", err)
180 return nil, nil, err
181 }
182
183 opts := &handshaker.ClientHandshakerOptions{
184 MinTLSVersion: c.minTLSVersion,
185 MaxTLSVersion: c.maxTLSVersion,
186 TLSCiphersuites: c.tlsCiphersuites,
187 TargetIdentities: c.targetIdentities,
188 LocalIdentity: c.localIdentity,
189 TargetName: serverAuthority,
190 EnsureProcessSessionTickets: c.ensureProcessSessionTickets,
191 }
192 chs, err := handshaker.NewClientHandshaker(ctx, hsConn, rawConn, c.s2aAddr, opts)
193 if err != nil {
194 grpclog.Infof("Call to handshaker.NewClientHandshaker failed: %v", err)
195 return nil, nil, err
196 }
197 defer func() {
198 if err != nil {
199 if closeErr := chs.Close(); closeErr != nil {
200 grpclog.Infof("Close failed unexpectedly: %v", err)
201 err = fmt.Errorf("%v: close unexpectedly failed: %v", err, closeErr)
202 }
203 }
204 }()
205
206 secConn, authInfo, err := chs.ClientHandshake(context.Background())
207 if err != nil {
208 grpclog.Infof("Handshake failed: %v", err)
209 return nil, nil, err
210 }
211 return secConn, authInfo, nil
212}
213
214// ServerHandshake initiates a server-side TLS handshake using the S2A.
215func (c *s2aTransportCreds) ServerHandshake(rawConn net.Conn) (net.Conn, credentials.AuthInfo, error) {
216 if c.isClient {
217 return nil, nil, errors.New("server handshake called using client transport credentials")
218 }
219
220 ctx, cancel := context.WithTimeout(context.Background(), defaultTimeout)
221 defer cancel()
222
223 // Connect to the S2A.
224 hsConn, err := service.Dial(ctx, c.s2aAddr, nil)
225 if err != nil {
226 grpclog.Infof("Failed to connect to S2A: %v", err)
227 return nil, nil, err
228 }
229
230 opts := &handshaker.ServerHandshakerOptions{
231 MinTLSVersion: c.minTLSVersion,
232 MaxTLSVersion: c.maxTLSVersion,
233 TLSCiphersuites: c.tlsCiphersuites,
234 LocalIdentities: c.localIdentities,
235 }
236 shs, err := handshaker.NewServerHandshaker(ctx, hsConn, rawConn, c.s2aAddr, opts)
237 if err != nil {
238 grpclog.Infof("Call to handshaker.NewServerHandshaker failed: %v", err)
239 return nil, nil, err
240 }
241 defer func() {
242 if err != nil {
243 if closeErr := shs.Close(); closeErr != nil {
244 grpclog.Infof("Close failed unexpectedly: %v", err)
245 err = fmt.Errorf("%v: close unexpectedly failed: %v", err, closeErr)
246 }
247 }
248 }()
249
250 secConn, authInfo, err := shs.ServerHandshake(context.Background())
251 if err != nil {
252 grpclog.Infof("Handshake failed: %v", err)
253 return nil, nil, err
254 }
255 return secConn, authInfo, nil
256}
257
258func (c *s2aTransportCreds) Info() credentials.ProtocolInfo {
259 return *c.info
260}
261
262func (c *s2aTransportCreds) Clone() credentials.TransportCredentials {
263 info := *c.info
264 var localIdentity *commonpbv1.Identity
265 if c.localIdentity != nil {
266 localIdentity = proto.Clone(c.localIdentity).(*commonpbv1.Identity)
267 }
268 var localIdentities []*commonpbv1.Identity
269 if c.localIdentities != nil {
270 localIdentities = make([]*commonpbv1.Identity, len(c.localIdentities))
271 for i, localIdentity := range c.localIdentities {
272 localIdentities[i] = proto.Clone(localIdentity).(*commonpbv1.Identity)
273 }
274 }
275 var targetIdentities []*commonpbv1.Identity
276 if c.targetIdentities != nil {
277 targetIdentities = make([]*commonpbv1.Identity, len(c.targetIdentities))
278 for i, targetIdentity := range c.targetIdentities {
279 targetIdentities[i] = proto.Clone(targetIdentity).(*commonpbv1.Identity)
280 }
281 }
282 return &s2aTransportCreds{
283 info: &info,
284 minTLSVersion: c.minTLSVersion,
285 maxTLSVersion: c.maxTLSVersion,
286 tlsCiphersuites: c.tlsCiphersuites,
287 localIdentity: localIdentity,
288 localIdentities: localIdentities,
289 targetIdentities: targetIdentities,
290 isClient: c.isClient,
291 s2aAddr: c.s2aAddr,
292 }
293}
294
295func (c *s2aTransportCreds) OverrideServerName(serverNameOverride string) error {
296 c.info.ServerName = serverNameOverride
297 return nil
298}
299
// TLSClientConfigOptions specifies parameters for creating client TLS config.
type TLSClientConfigOptions struct {
	// ServerName is required by s2a as the expected name when verifying the
	// hostname found in the server's certificate. For example:
	//
	//	tlsConfig, _ := factory.Build(ctx, &s2a.TLSClientConfigOptions{
	//		ServerName: "example.com",
	//	})
	//
	// The S2A factory in this package does not reject an empty ServerName; it
	// passes the empty string through to the v2 TLS config.
	ServerName string
}
308
// TLSClientConfigFactory defines the interface for a client TLS config factory.
type TLSClientConfigFactory interface {
	// Build returns a *tls.Config for a client connection, configured from
	// opts. The S2A implementation returned by NewTLSClientConfigFactory also
	// accepts a nil opts.
	Build(ctx context.Context, opts *TLSClientConfigOptions) (*tls.Config, error)
}
313
314// NewTLSClientConfigFactory returns an instance of s2aTLSClientConfigFactory.
315func NewTLSClientConfigFactory(opts *ClientOptions) (TLSClientConfigFactory, error) {
316 if opts == nil {
317 return nil, fmt.Errorf("opts must be non-nil")
318 }
319 if opts.EnableLegacyMode {
320 return nil, fmt.Errorf("NewTLSClientConfigFactory only supports S2Av2")
321 }
322 tokenManager, err := tokenmanager.NewSingleTokenAccessTokenManager()
323 if err != nil {
324 // The only possible error is: access token not set in the environment,
325 // which is okay in environments other than serverless.
326 grpclog.Infof("Access token manager not initialized: %v", err)
327 return &s2aTLSClientConfigFactory{
328 s2av2Address: opts.S2AAddress,
329 transportCreds: opts.TransportCreds,
330 tokenManager: nil,
331 verificationMode: getVerificationMode(opts.VerificationMode),
332 serverAuthorizationPolicy: opts.serverAuthorizationPolicy,
333 }, nil
334 }
335 return &s2aTLSClientConfigFactory{
336 s2av2Address: opts.S2AAddress,
337 transportCreds: opts.TransportCreds,
338 tokenManager: tokenManager,
339 verificationMode: getVerificationMode(opts.VerificationMode),
340 serverAuthorizationPolicy: opts.serverAuthorizationPolicy,
341 }, nil
342}
343
// s2aTLSClientConfigFactory is the S2Av2-backed TLSClientConfigFactory
// constructed by NewTLSClientConfigFactory. Its fields are copied from the
// ClientOptions given to that constructor, and Build forwards them to
// v2.NewClientTLSConfig.
type s2aTLSClientConfigFactory struct {
	// s2av2Address is ClientOptions.S2AAddress.
	s2av2Address string
	// transportCreds is ClientOptions.TransportCreds. NOTE(review): presumably
	// used for the connection to the S2A itself — confirm in the v2 package.
	transportCreds credentials.TransportCredentials
	// tokenManager is nil when no access token manager could be created.
	tokenManager tokenmanager.AccessTokenManager
	// verificationMode is ClientOptions.VerificationMode mapped through
	// getVerificationMode.
	verificationMode s2av2pb.ValidatePeerCertificateChainReq_VerificationMode
	// serverAuthorizationPolicy is ClientOptions.serverAuthorizationPolicy.
	serverAuthorizationPolicy []byte
}
351
352func (f *s2aTLSClientConfigFactory) Build(
353 ctx context.Context, opts *TLSClientConfigOptions) (*tls.Config, error) {
354 serverName := ""
355 if opts != nil && opts.ServerName != "" {
356 serverName = opts.ServerName
357 }
358 return v2.NewClientTLSConfig(ctx, f.s2av2Address, f.transportCreds, f.tokenManager, f.verificationMode, serverName, f.serverAuthorizationPolicy)
359}
360
361func getVerificationMode(verificationMode VerificationModeType) s2av2pb.ValidatePeerCertificateChainReq_VerificationMode {
362 switch verificationMode {
363 case ConnectToGoogle:
364 return s2av2pb.ValidatePeerCertificateChainReq_CONNECT_TO_GOOGLE
365 case Spiffe:
366 return s2av2pb.ValidatePeerCertificateChainReq_SPIFFE
367 case ReservedCustomVerificationMode3:
368 return s2av2pb.ValidatePeerCertificateChainReq_RESERVED_CUSTOM_VERIFICATION_MODE_3
369 case ReservedCustomVerificationMode4:
370 return s2av2pb.ValidatePeerCertificateChainReq_RESERVED_CUSTOM_VERIFICATION_MODE_4
371 case ReservedCustomVerificationMode5:
372 return s2av2pb.ValidatePeerCertificateChainReq_RESERVED_CUSTOM_VERIFICATION_MODE_5
373 default:
374 return s2av2pb.ValidatePeerCertificateChainReq_UNSPECIFIED
375 }
376}
377
378// NewS2ADialTLSContextFunc returns a dialer which establishes an MTLS connection using S2A.
379// Example use with http.RoundTripper:
380//
381// dialTLSContext := s2a.NewS2aDialTLSContextFunc(&s2a.ClientOptions{
382// S2AAddress: s2aAddress, // required
383// })
384// transport := http.DefaultTransport
385// transport.DialTLSContext = dialTLSContext
386func NewS2ADialTLSContextFunc(opts *ClientOptions) func(ctx context.Context, network, addr string) (net.Conn, error) {
387
388 return func(ctx context.Context, network, addr string) (net.Conn, error) {
389
390 fallback := func(err error) (net.Conn, error) {
391 if opts.FallbackOpts != nil && opts.FallbackOpts.FallbackDialer != nil &&
392 opts.FallbackOpts.FallbackDialer.Dialer != nil && opts.FallbackOpts.FallbackDialer.ServerAddr != "" {
393 fbDialer := opts.FallbackOpts.FallbackDialer
394 grpclog.Infof("fall back to dial: %s", fbDialer.ServerAddr)
395 fbConn, fbErr := fbDialer.Dialer.DialContext(ctx, network, fbDialer.ServerAddr)
396 if fbErr != nil {
397 return nil, fmt.Errorf("error fallback to %s: %v; S2A error: %w", fbDialer.ServerAddr, fbErr, err)
398 }
399 return fbConn, nil
400 }
401 return nil, err
402 }
403
404 factory, err := NewTLSClientConfigFactory(opts)
405 if err != nil {
406 grpclog.Infof("error creating S2A client config factory: %v", err)
407 return fallback(err)
408 }
409
410 serverName, _, err := net.SplitHostPort(addr)
411 if err != nil {
412 serverName = addr
413 }
414 timeoutCtx, cancel := context.WithTimeout(ctx, v2.GetS2ATimeout())
415 defer cancel()
416
417 var s2aTLSConfig *tls.Config
418 var c net.Conn
419 retry.Run(timeoutCtx,
420 func() error {
421 s2aTLSConfig, err = factory.Build(timeoutCtx, &TLSClientConfigOptions{
422 ServerName: serverName,
423 })
424 if err != nil {
425 grpclog.Infof("error building S2A TLS config: %v", err)
426 return err
427 }
428
429 s2aDialer := &tls.Dialer{
430 Config: s2aTLSConfig,
431 }
432 c, err = s2aDialer.DialContext(timeoutCtx, network, addr)
433 return err
434 })
435 if err != nil {
436 grpclog.Infof("error dialing with S2A to %s: %v", addr, err)
437 return fallback(err)
438 }
439 grpclog.Infof("success dialing MTLS to %s with S2A", addr)
440 return c, nil
441 }
442}