tlsconfigstore.go

  1/*
  2 *
  3 * Copyright 2022 Google LLC
  4 *
  5 * Licensed under the Apache License, Version 2.0 (the "License");
  6 * you may not use this file except in compliance with the License.
  7 * You may obtain a copy of the License at
  8 *
  9 *     https://www.apache.org/licenses/LICENSE-2.0
 10 *
 11 * Unless required by applicable law or agreed to in writing, software
 12 * distributed under the License is distributed on an "AS IS" BASIS,
 13 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 14 * See the License for the specific language governing permissions and
 15 * limitations under the License.
 16 *
 17 */
 18
 19// Package tlsconfigstore offloads operations to S2Av2.
 20package tlsconfigstore
 21
 22import (
 23	"crypto/tls"
 24	"crypto/x509"
 25	"encoding/pem"
 26	"errors"
 27	"fmt"
 28
 29	"github.com/google/s2a-go/internal/tokenmanager"
 30	"github.com/google/s2a-go/internal/v2/certverifier"
 31	"github.com/google/s2a-go/internal/v2/remotesigner"
 32	"github.com/google/s2a-go/stream"
 33	"google.golang.org/grpc/codes"
 34	"google.golang.org/grpc/grpclog"
 35
 36	commonpb "github.com/google/s2a-go/internal/proto/v2/common_go_proto"
 37	s2av2pb "github.com/google/s2a-go/internal/proto/v2/s2a_go_proto"
 38)
 39
 40const (
 41	// HTTP/2
 42	h2 = "h2"
 43)
 44
 45// GetTLSConfigurationForClient returns a tls.Config instance for use by a client application.
 46func GetTLSConfigurationForClient(serverHostname string, s2AStream stream.S2AStream, tokenManager tokenmanager.AccessTokenManager, localIdentity *commonpb.Identity, verificationMode s2av2pb.ValidatePeerCertificateChainReq_VerificationMode, serverAuthorizationPolicy []byte) (*tls.Config, error) {
 47	authMechanisms := getAuthMechanisms(tokenManager, []*commonpb.Identity{localIdentity})
 48
 49	if grpclog.V(1) {
 50		grpclog.Infof("Sending request to S2Av2 for client TLS config.")
 51	}
 52	// Send request to S2Av2 for config.
 53	if err := s2AStream.Send(&s2av2pb.SessionReq{
 54		LocalIdentity:            localIdentity,
 55		AuthenticationMechanisms: authMechanisms,
 56		ReqOneof: &s2av2pb.SessionReq_GetTlsConfigurationReq{
 57			GetTlsConfigurationReq: &s2av2pb.GetTlsConfigurationReq{
 58				ConnectionSide: commonpb.ConnectionSide_CONNECTION_SIDE_CLIENT,
 59			},
 60		},
 61	}); err != nil {
 62		grpclog.Infof("Failed to send request to S2Av2 for client TLS config")
 63		return nil, err
 64	}
 65
 66	// Get the response containing config from S2Av2.
 67	resp, err := s2AStream.Recv()
 68	if err != nil {
 69		grpclog.Infof("Failed to receive client TLS config response from S2Av2.")
 70		return nil, err
 71	}
 72
 73	// TODO(rmehta19): Add unit test for this if statement.
 74	if (resp.GetStatus() != nil) && (resp.GetStatus().Code != uint32(codes.OK)) {
 75		return nil, fmt.Errorf("failed to get TLS configuration from S2A: %d, %v", resp.GetStatus().Code, resp.GetStatus().Details)
 76	}
 77
 78	// Extract TLS configiguration from SessionResp.
 79	tlsConfig := resp.GetGetTlsConfigurationResp().GetClientTlsConfiguration()
 80
 81	var cert tls.Certificate
 82	for i, v := range tlsConfig.CertificateChain {
 83		// Populate Certificates field.
 84		block, _ := pem.Decode([]byte(v))
 85		if block == nil {
 86			return nil, errors.New("certificate in CertificateChain obtained from S2Av2 is empty")
 87		}
 88		x509Cert, err := x509.ParseCertificate(block.Bytes)
 89		if err != nil {
 90			return nil, err
 91		}
 92		cert.Certificate = append(cert.Certificate, x509Cert.Raw)
 93		if i == 0 {
 94			cert.Leaf = x509Cert
 95		}
 96	}
 97
 98	if len(tlsConfig.CertificateChain) > 0 {
 99		cert.PrivateKey = remotesigner.New(cert.Leaf, s2AStream)
100		if cert.PrivateKey == nil {
101			return nil, errors.New("failed to retrieve Private Key from Remote Signer Library")
102		}
103	}
104
105	minVersion, maxVersion, err := getTLSMinMaxVersionsClient(tlsConfig)
106	if err != nil {
107		return nil, err
108	}
109
110	// Create mTLS credentials for client.
111	config := &tls.Config{
112		VerifyPeerCertificate:  certverifier.VerifyServerCertificateChain(serverHostname, verificationMode, s2AStream, serverAuthorizationPolicy),
113		ServerName:             serverHostname,
114		InsecureSkipVerify:     true, // NOLINT
115		ClientSessionCache:     nil,
116		SessionTicketsDisabled: true,
117		MinVersion:             minVersion,
118		MaxVersion:             maxVersion,
119		NextProtos:             []string{h2},
120	}
121	if len(tlsConfig.CertificateChain) > 0 {
122		config.Certificates = []tls.Certificate{cert}
123	}
124	return config, nil
125}
126
// GetTLSConfigurationForServer returns a tls.Config instance for use by a
// server application.
//
// The returned config has no static TLS settings of its own. Each handshake
// is configured lazily through GetConfigForClient, which is built by
// ClientConfig and consults S2Av2 using the ClientHello's server name. The
// returned error is always nil.
func GetTLSConfigurationForServer(s2AStream stream.S2AStream, tokenManager tokenmanager.AccessTokenManager, localIdentities []*commonpb.Identity, verificationMode s2av2pb.ValidatePeerCertificateChainReq_VerificationMode) (*tls.Config, error) {
	perHandshake := ClientConfig(tokenManager, localIdentities, verificationMode, s2AStream)
	serverConfig := new(tls.Config)
	serverConfig.GetConfigForClient = perHandshake
	return serverConfig, nil
}
133
// ClientConfig builds a TLS config for a server to establish a secure
// connection with a client, based on SNI communicated during ClientHello.
// Ensures that server presents the correct certificate to establish a TLS
// connection.
//
// The returned function has the tls.Config.GetConfigForClient signature. For
// each ClientHello it:
//   - asks S2Av2 (over s2AStream) for the server TLS configuration matching
//     chi.ServerName;
//   - presents the certificate chain S2Av2 returns, delegating private-key
//     operations to the remote signer;
//   - installs the certverifier client-chain callback for verificationMode.
//
// The function returns an error if:
//   - the S2Av2 exchange fails or returns no configuration;
//   - the certificate chain is empty or cannot be parsed;
//   - the remote signer cannot be created;
//   - the reported TLS versions are invalid.
func ClientConfig(tokenManager tokenmanager.AccessTokenManager, localIdentities []*commonpb.Identity, verificationMode s2av2pb.ValidatePeerCertificateChainReq_VerificationMode, s2AStream stream.S2AStream) func(chi *tls.ClientHelloInfo) (*tls.Config, error) {
	return func(chi *tls.ClientHelloInfo) (*tls.Config, error) {
		tlsConfig, err := getServerConfigFromS2Av2(tokenManager, localIdentities, chi.ServerName, s2AStream)
		if err != nil {
			return nil, err
		}
		// Field accesses below would panic on a nil configuration.
		if tlsConfig == nil {
			return nil, errors.New("S2Av2 response did not contain a server TLS configuration")
		}
		// Unlike the client side, the server always presents a certificate;
		// with an empty chain there is no leaf to hand to the remote signer.
		if len(tlsConfig.CertificateChain) == 0 {
			return nil, errors.New("S2Av2 provided an empty server certificate chain")
		}

		var cert tls.Certificate
		for i, v := range tlsConfig.CertificateChain {
			// Populate Certificates field.
			block, _ := pem.Decode([]byte(v))
			if block == nil {
				return nil, errors.New("certificate in CertificateChain obtained from S2Av2 is empty")
			}
			x509Cert, err := x509.ParseCertificate(block.Bytes)
			if err != nil {
				return nil, err
			}
			cert.Certificate = append(cert.Certificate, x509Cert.Raw)
			// The first certificate in the chain is treated as the leaf.
			if i == 0 {
				cert.Leaf = x509Cert
			}
		}

		// Private-key operations are delegated to the remote signer, which is
		// given the S2Av2 stream.
		cert.PrivateKey = remotesigner.New(cert.Leaf, s2AStream)
		if cert.PrivateKey == nil {
			return nil, errors.New("failed to retrieve Private Key from Remote Signer Library")
		}

		minVersion, maxVersion, err := getTLSMinMaxVersionsServer(tlsConfig)
		if err != nil {
			return nil, err
		}

		clientAuth := getTLSClientAuthType(tlsConfig)

		// getCipherSuites returns nil when no suite maps to crypto/tls. In
		// that case crypto/tls falls back to its default cipher suite list.
		cipherSuites := getCipherSuites(tlsConfig.Ciphersuites)

		// Create mTLS credentials for server.
		return &tls.Config{
			Certificates:           []tls.Certificate{cert},
			VerifyPeerCertificate:  certverifier.VerifyClientCertificateChain(verificationMode, s2AStream),
			ClientAuth:             clientAuth,
			CipherSuites:           cipherSuites,
			SessionTicketsDisabled: true,
			MinVersion:             minVersion,
			MaxVersion:             maxVersion,
			NextProtos:             []string{h2},
		}, nil
	}
}
190
191func getCipherSuites(tlsConfigCipherSuites []commonpb.Ciphersuite) []uint16 {
192	var tlsGoCipherSuites []uint16
193	for _, v := range tlsConfigCipherSuites {
194		s := getTLSCipherSuite(v)
195		if s != 0xffff {
196			tlsGoCipherSuites = append(tlsGoCipherSuites, s)
197		}
198	}
199	return tlsGoCipherSuites
200}
201
202func getTLSCipherSuite(tlsCipherSuite commonpb.Ciphersuite) uint16 {
203	switch tlsCipherSuite {
204	case commonpb.Ciphersuite_CIPHERSUITE_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256:
205		return tls.TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256
206	case commonpb.Ciphersuite_CIPHERSUITE_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384:
207		return tls.TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384
208	case commonpb.Ciphersuite_CIPHERSUITE_ECDHE_ECDSA_WITH_CHACHA20_POLY1305_SHA256:
209		return tls.TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305_SHA256
210	case commonpb.Ciphersuite_CIPHERSUITE_ECDHE_RSA_WITH_AES_128_GCM_SHA256:
211		return tls.TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256
212	case commonpb.Ciphersuite_CIPHERSUITE_ECDHE_RSA_WITH_AES_256_GCM_SHA384:
213		return tls.TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384
214	case commonpb.Ciphersuite_CIPHERSUITE_ECDHE_RSA_WITH_CHACHA20_POLY1305_SHA256:
215		return tls.TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305_SHA256
216	default:
217		return 0xffff
218	}
219}
220
// getServerConfigFromS2Av2 asks S2Av2, over s2AStream, for the server-side TLS
// configuration to use for a handshake whose ClientHello carried sni.
//
// The first entry of localIdentities, if there is one, is sent as the
// request's LocalIdentity. All entries are passed to getAuthMechanisms to
// build the request's authentication mechanisms.
//
// An error is returned if:
//   - the request cannot be sent or the response cannot be received;
//   - S2Av2 reports a non-OK status;
//   - the response carries no server TLS configuration.
func getServerConfigFromS2Av2(tokenManager tokenmanager.AccessTokenManager, localIdentities []*commonpb.Identity, sni string, s2AStream stream.S2AStream) (*s2av2pb.GetTlsConfigurationResp_ServerTlsConfiguration, error) {
	authMechanisms := getAuthMechanisms(tokenManager, localIdentities)

	// Check the length, not nil-ness: indexing a non-nil but empty slice
	// would panic.
	var locID *commonpb.Identity
	if len(localIdentities) > 0 {
		locID = localIdentities[0]
	}

	if err := s2AStream.Send(&s2av2pb.SessionReq{
		LocalIdentity:            locID,
		AuthenticationMechanisms: authMechanisms,
		ReqOneof: &s2av2pb.SessionReq_GetTlsConfigurationReq{
			GetTlsConfigurationReq: &s2av2pb.GetTlsConfigurationReq{
				ConnectionSide: commonpb.ConnectionSide_CONNECTION_SIDE_SERVER,
				Sni:            sni,
			},
		},
	}); err != nil {
		return nil, err
	}

	resp, err := s2AStream.Recv()
	if err != nil {
		return nil, err
	}

	// TODO(rmehta19): Add unit test for this if statement.
	if status := resp.GetStatus(); status != nil && status.Code != uint32(codes.OK) {
		return nil, fmt.Errorf("failed to get TLS configuration from S2A: %d, %v", status.Code, status.Details)
	}

	// The getter chain yields nil when the response has no server
	// configuration. Report that here rather than handing callers a nil
	// pointer with a nil error.
	serverConfig := resp.GetGetTlsConfigurationResp().GetServerTlsConfiguration()
	if serverConfig == nil {
		return nil, errors.New("S2Av2 response did not contain a server TLS configuration")
	}
	return serverConfig, nil
}
253
254func getTLSClientAuthType(tlsConfig *s2av2pb.GetTlsConfigurationResp_ServerTlsConfiguration) tls.ClientAuthType {
255	var clientAuth tls.ClientAuthType
256	switch x := tlsConfig.RequestClientCertificate; x {
257	case s2av2pb.GetTlsConfigurationResp_ServerTlsConfiguration_DONT_REQUEST_CLIENT_CERTIFICATE:
258		clientAuth = tls.NoClientCert
259	case s2av2pb.GetTlsConfigurationResp_ServerTlsConfiguration_REQUEST_CLIENT_CERTIFICATE_BUT_DONT_VERIFY:
260		clientAuth = tls.RequestClientCert
261	case s2av2pb.GetTlsConfigurationResp_ServerTlsConfiguration_REQUEST_CLIENT_CERTIFICATE_AND_VERIFY:
262		// This case actually maps to tls.VerifyClientCertIfGiven. However this
263		// mapping triggers normal verification, followed by custom verification,
264		// specified in VerifyPeerCertificate. To bypass normal verification, and
265		// only do custom verification we set clientAuth to RequireAnyClientCert or
266		// RequestClientCert. See https://github.com/google/s2a-go/pull/43 for full
267		// discussion.
268		clientAuth = tls.RequireAnyClientCert
269	case s2av2pb.GetTlsConfigurationResp_ServerTlsConfiguration_REQUEST_AND_REQUIRE_CLIENT_CERTIFICATE_BUT_DONT_VERIFY:
270		clientAuth = tls.RequireAnyClientCert
271	case s2av2pb.GetTlsConfigurationResp_ServerTlsConfiguration_REQUEST_AND_REQUIRE_CLIENT_CERTIFICATE_AND_VERIFY:
272		// This case actually maps to tls.RequireAndVerifyClientCert. However this
273		// mapping triggers normal verification, followed by custom verification,
274		// specified in VerifyPeerCertificate. To bypass normal verification, and
275		// only do custom verification we set clientAuth to RequireAnyClientCert or
276		// RequestClientCert. See https://github.com/google/s2a-go/pull/43 for full
277		// discussion.
278		clientAuth = tls.RequireAnyClientCert
279	default:
280		clientAuth = tls.RequireAnyClientCert
281	}
282	return clientAuth
283}
284
285func getAuthMechanisms(tokenManager tokenmanager.AccessTokenManager, localIdentities []*commonpb.Identity) []*s2av2pb.AuthenticationMechanism {
286	if tokenManager == nil {
287		return nil
288	}
289	if len(localIdentities) == 0 {
290		token, err := tokenManager.DefaultToken()
291		if err != nil {
292			grpclog.Infof("Unable to get token for empty local identity: %v", err)
293			return nil
294		}
295		return []*s2av2pb.AuthenticationMechanism{
296			{
297				MechanismOneof: &s2av2pb.AuthenticationMechanism_Token{
298					Token: token,
299				},
300			},
301		}
302	}
303	var authMechanisms []*s2av2pb.AuthenticationMechanism
304	for _, localIdentity := range localIdentities {
305		if localIdentity == nil {
306			token, err := tokenManager.DefaultToken()
307			if err != nil {
308				grpclog.Infof("Unable to get default token for local identity %v: %v", localIdentity, err)
309				continue
310			}
311			authMechanisms = append(authMechanisms, &s2av2pb.AuthenticationMechanism{
312				Identity: localIdentity,
313				MechanismOneof: &s2av2pb.AuthenticationMechanism_Token{
314					Token: token,
315				},
316			})
317		} else {
318			token, err := tokenManager.Token(localIdentity)
319			if err != nil {
320				grpclog.Infof("Unable to get token for local identity %v: %v", localIdentity, err)
321				continue
322			}
323			authMechanisms = append(authMechanisms, &s2av2pb.AuthenticationMechanism{
324				Identity: localIdentity,
325				MechanismOneof: &s2av2pb.AuthenticationMechanism_Token{
326					Token: token,
327				},
328			})
329		}
330	}
331	return authMechanisms
332}
333
334// TODO(rmehta19): refactor switch statements into a helper function.
335func getTLSMinMaxVersionsClient(tlsConfig *s2av2pb.GetTlsConfigurationResp_ClientTlsConfiguration) (uint16, uint16, error) {
336	// Map S2Av2 TLSVersion to consts defined in tls package.
337	var minVersion uint16
338	var maxVersion uint16
339	switch x := tlsConfig.MinTlsVersion; x {
340	case commonpb.TLSVersion_TLS_VERSION_1_0:
341		minVersion = tls.VersionTLS10
342	case commonpb.TLSVersion_TLS_VERSION_1_1:
343		minVersion = tls.VersionTLS11
344	case commonpb.TLSVersion_TLS_VERSION_1_2:
345		minVersion = tls.VersionTLS12
346	case commonpb.TLSVersion_TLS_VERSION_1_3:
347		minVersion = tls.VersionTLS13
348	default:
349		return minVersion, maxVersion, fmt.Errorf("S2Av2 provided invalid MinTlsVersion: %v", x)
350	}
351
352	switch x := tlsConfig.MaxTlsVersion; x {
353	case commonpb.TLSVersion_TLS_VERSION_1_0:
354		maxVersion = tls.VersionTLS10
355	case commonpb.TLSVersion_TLS_VERSION_1_1:
356		maxVersion = tls.VersionTLS11
357	case commonpb.TLSVersion_TLS_VERSION_1_2:
358		maxVersion = tls.VersionTLS12
359	case commonpb.TLSVersion_TLS_VERSION_1_3:
360		maxVersion = tls.VersionTLS13
361	default:
362		return minVersion, maxVersion, fmt.Errorf("S2Av2 provided invalid MaxTlsVersion: %v", x)
363	}
364	if minVersion > maxVersion {
365		return minVersion, maxVersion, errors.New("S2Av2 provided minVersion > maxVersion")
366	}
367	return minVersion, maxVersion, nil
368}
369
370func getTLSMinMaxVersionsServer(tlsConfig *s2av2pb.GetTlsConfigurationResp_ServerTlsConfiguration) (uint16, uint16, error) {
371	// Map S2Av2 TLSVersion to consts defined in tls package.
372	var minVersion uint16
373	var maxVersion uint16
374	switch x := tlsConfig.MinTlsVersion; x {
375	case commonpb.TLSVersion_TLS_VERSION_1_0:
376		minVersion = tls.VersionTLS10
377	case commonpb.TLSVersion_TLS_VERSION_1_1:
378		minVersion = tls.VersionTLS11
379	case commonpb.TLSVersion_TLS_VERSION_1_2:
380		minVersion = tls.VersionTLS12
381	case commonpb.TLSVersion_TLS_VERSION_1_3:
382		minVersion = tls.VersionTLS13
383	default:
384		return minVersion, maxVersion, fmt.Errorf("S2Av2 provided invalid MinTlsVersion: %v", x)
385	}
386
387	switch x := tlsConfig.MaxTlsVersion; x {
388	case commonpb.TLSVersion_TLS_VERSION_1_0:
389		maxVersion = tls.VersionTLS10
390	case commonpb.TLSVersion_TLS_VERSION_1_1:
391		maxVersion = tls.VersionTLS11
392	case commonpb.TLSVersion_TLS_VERSION_1_2:
393		maxVersion = tls.VersionTLS12
394	case commonpb.TLSVersion_TLS_VERSION_1_3:
395		maxVersion = tls.VersionTLS13
396	default:
397		return minVersion, maxVersion, fmt.Errorf("S2Av2 provided invalid MaxTlsVersion: %v", x)
398	}
399	if minVersion > maxVersion {
400		return minVersion, maxVersion, errors.New("S2Av2 provided minVersion > maxVersion")
401	}
402	return minVersion, maxVersion, nil
403}