s2av2.go

  1/*
  2 *
  3 * Copyright 2022 Google LLC
  4 *
  5 * Licensed under the Apache License, Version 2.0 (the "License");
  6 * you may not use this file except in compliance with the License.
  7 * You may obtain a copy of the License at
  8 *
  9 *     https://www.apache.org/licenses/LICENSE-2.0
 10 *
 11 * Unless required by applicable law or agreed to in writing, software
 12 * distributed under the License is distributed on an "AS IS" BASIS,
 13 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 14 * See the License for the specific language governing permissions and
 15 * limitations under the License.
 16 *
 17 */
 18
 19// Package v2 provides the S2Av2 transport credentials used by a gRPC
 20// application.
 21package v2
 22
 23import (
 24	"context"
 25	"crypto/tls"
 26	"errors"
 27	"net"
 28	"os"
 29	"time"
 30
 31	"github.com/google/s2a-go/fallback"
 32	"github.com/google/s2a-go/internal/handshaker/service"
 33	"github.com/google/s2a-go/internal/tokenmanager"
 34	"github.com/google/s2a-go/internal/v2/tlsconfigstore"
 35	"github.com/google/s2a-go/retry"
 36	"github.com/google/s2a-go/stream"
 37	"google.golang.org/grpc"
 38	"google.golang.org/grpc/credentials"
 39	"google.golang.org/grpc/grpclog"
 40	"google.golang.org/protobuf/proto"
 41
 42	commonpb "github.com/google/s2a-go/internal/proto/v2/common_go_proto"
 43	s2av2pb "github.com/google/s2a-go/internal/proto/v2/s2a_go_proto"
 44)
 45
 46const (
 47	s2aSecurityProtocol = "tls"
 48	defaultS2ATimeout   = 6 * time.Second
 49)
 50
 51// An environment variable, which sets the timeout enforced on the connection to the S2A service for handshake.
 52const s2aTimeoutEnv = "S2A_TIMEOUT"
 53
 54type s2av2TransportCreds struct {
 55	info           *credentials.ProtocolInfo
 56	isClient       bool
 57	serverName     string
 58	s2av2Address   string
 59	transportCreds credentials.TransportCredentials
 60	tokenManager   *tokenmanager.AccessTokenManager
 61	// localIdentity should only be used by the client.
 62	localIdentity *commonpb.Identity
 63	// localIdentities should only be used by the server.
 64	localIdentities           []*commonpb.Identity
 65	verificationMode          s2av2pb.ValidatePeerCertificateChainReq_VerificationMode
 66	fallbackClientHandshake   fallback.ClientHandshake
 67	getS2AStream              func(ctx context.Context, s2av2Address string) (stream.S2AStream, error)
 68	serverAuthorizationPolicy []byte
 69}
 70
 71// NewClientCreds returns a client-side transport credentials object that uses
 72// the S2Av2 to establish a secure connection with a server.
 73func NewClientCreds(s2av2Address string, transportCreds credentials.TransportCredentials, localIdentity *commonpb.Identity, verificationMode s2av2pb.ValidatePeerCertificateChainReq_VerificationMode, fallbackClientHandshakeFunc fallback.ClientHandshake, getS2AStream func(ctx context.Context, s2av2Address string) (stream.S2AStream, error), serverAuthorizationPolicy []byte) (credentials.TransportCredentials, error) {
 74	// Create an AccessTokenManager instance to use to authenticate to S2Av2.
 75	accessTokenManager, err := tokenmanager.NewSingleTokenAccessTokenManager()
 76
 77	creds := &s2av2TransportCreds{
 78		info: &credentials.ProtocolInfo{
 79			SecurityProtocol: s2aSecurityProtocol,
 80		},
 81		isClient:                  true,
 82		serverName:                "",
 83		s2av2Address:              s2av2Address,
 84		transportCreds:            transportCreds,
 85		localIdentity:             localIdentity,
 86		verificationMode:          verificationMode,
 87		fallbackClientHandshake:   fallbackClientHandshakeFunc,
 88		getS2AStream:              getS2AStream,
 89		serverAuthorizationPolicy: serverAuthorizationPolicy,
 90	}
 91	if err != nil {
 92		creds.tokenManager = nil
 93	} else {
 94		creds.tokenManager = &accessTokenManager
 95	}
 96	if grpclog.V(1) {
 97		grpclog.Info("Created client S2Av2 transport credentials.")
 98	}
 99	return creds, nil
100}
101
102// NewServerCreds returns a server-side transport credentials object that uses
103// the S2Av2 to establish a secure connection with a client.
104func NewServerCreds(s2av2Address string, transportCreds credentials.TransportCredentials, localIdentities []*commonpb.Identity, verificationMode s2av2pb.ValidatePeerCertificateChainReq_VerificationMode, getS2AStream func(ctx context.Context, s2av2Address string) (stream.S2AStream, error)) (credentials.TransportCredentials, error) {
105	// Create an AccessTokenManager instance to use to authenticate to S2Av2.
106	accessTokenManager, err := tokenmanager.NewSingleTokenAccessTokenManager()
107	creds := &s2av2TransportCreds{
108		info: &credentials.ProtocolInfo{
109			SecurityProtocol: s2aSecurityProtocol,
110		},
111		isClient:         false,
112		s2av2Address:     s2av2Address,
113		transportCreds:   transportCreds,
114		localIdentities:  localIdentities,
115		verificationMode: verificationMode,
116		getS2AStream:     getS2AStream,
117	}
118	if err != nil {
119		creds.tokenManager = nil
120	} else {
121		creds.tokenManager = &accessTokenManager
122	}
123	if grpclog.V(1) {
124		grpclog.Info("Created server S2Av2 transport credentials.")
125	}
126	return creds, nil
127}
128
129// ClientHandshake performs a client-side mTLS handshake using the S2Av2.
130func (c *s2av2TransportCreds) ClientHandshake(ctx context.Context, serverAuthority string, rawConn net.Conn) (net.Conn, credentials.AuthInfo, error) {
131	if !c.isClient {
132		return nil, nil, errors.New("client handshake called using server transport credentials")
133	}
134	// Remove the port from serverAuthority.
135	serverName := removeServerNamePort(serverAuthority)
136	timeoutCtx, cancel := context.WithTimeout(ctx, GetS2ATimeout())
137	defer cancel()
138	var s2AStream stream.S2AStream
139	var err error
140	retry.Run(timeoutCtx,
141		func() error {
142			s2AStream, err = createStream(timeoutCtx, c.s2av2Address, c.transportCreds, c.getS2AStream)
143			return err
144		})
145	if err != nil {
146		grpclog.Infof("Failed to connect to S2Av2: %v", err)
147		if c.fallbackClientHandshake != nil {
148			return c.fallbackClientHandshake(ctx, serverAuthority, rawConn, err)
149		}
150		return nil, nil, err
151	}
152	defer s2AStream.CloseSend()
153	if grpclog.V(1) {
154		grpclog.Infof("Connected to S2Av2.")
155	}
156	var config *tls.Config
157
158	var tokenManager tokenmanager.AccessTokenManager
159	if c.tokenManager == nil {
160		tokenManager = nil
161	} else {
162		tokenManager = *c.tokenManager
163	}
164
165	sn := serverName
166	if c.serverName != "" {
167		sn = c.serverName
168	}
169	retry.Run(timeoutCtx,
170		func() error {
171			config, err = tlsconfigstore.GetTLSConfigurationForClient(sn, s2AStream, tokenManager, c.localIdentity, c.verificationMode, c.serverAuthorizationPolicy)
172			return err
173		})
174	if err != nil {
175		grpclog.Info("Failed to get client TLS config from S2Av2: %v", err)
176		if c.fallbackClientHandshake != nil {
177			return c.fallbackClientHandshake(ctx, serverAuthority, rawConn, err)
178		}
179		return nil, nil, err
180	}
181	if grpclog.V(1) {
182		grpclog.Infof("Got client TLS config from S2Av2.")
183	}
184
185	creds := credentials.NewTLS(config)
186	conn, authInfo, err := creds.ClientHandshake(timeoutCtx, serverName, rawConn)
187	if err != nil {
188		grpclog.Infof("Failed to do client handshake using S2Av2: %v", err)
189		if c.fallbackClientHandshake != nil {
190			return c.fallbackClientHandshake(ctx, serverAuthority, rawConn, err)
191		}
192		return nil, nil, err
193	}
194	grpclog.Infof("client-side handshake is done using S2Av2 to: %s", serverName)
195
196	return conn, authInfo, err
197}
198
199// ServerHandshake performs a server-side mTLS handshake using the S2Av2.
200func (c *s2av2TransportCreds) ServerHandshake(rawConn net.Conn) (net.Conn, credentials.AuthInfo, error) {
201	if c.isClient {
202		return nil, nil, errors.New("server handshake called using client transport credentials")
203	}
204	ctx, cancel := context.WithTimeout(context.Background(), GetS2ATimeout())
205	defer cancel()
206	var s2AStream stream.S2AStream
207	var err error
208	retry.Run(ctx,
209		func() error {
210			s2AStream, err = createStream(ctx, c.s2av2Address, c.transportCreds, c.getS2AStream)
211			return err
212		})
213	if err != nil {
214		grpclog.Infof("Failed to connect to S2Av2: %v", err)
215		return nil, nil, err
216	}
217	defer s2AStream.CloseSend()
218	if grpclog.V(1) {
219		grpclog.Infof("Connected to S2Av2.")
220	}
221
222	var tokenManager tokenmanager.AccessTokenManager
223	if c.tokenManager == nil {
224		tokenManager = nil
225	} else {
226		tokenManager = *c.tokenManager
227	}
228
229	var config *tls.Config
230	retry.Run(ctx,
231		func() error {
232			config, err = tlsconfigstore.GetTLSConfigurationForServer(s2AStream, tokenManager, c.localIdentities, c.verificationMode)
233			return err
234		})
235	if err != nil {
236		grpclog.Infof("Failed to get server TLS config from S2Av2: %v", err)
237		return nil, nil, err
238	}
239	if grpclog.V(1) {
240		grpclog.Infof("Got server TLS config from S2Av2.")
241	}
242
243	creds := credentials.NewTLS(config)
244	conn, authInfo, err := creds.ServerHandshake(rawConn)
245	if err != nil {
246		grpclog.Infof("Failed to do server handshake using S2Av2: %v", err)
247		return nil, nil, err
248	}
249	return conn, authInfo, err
250}
251
252// Info returns protocol info of s2av2TransportCreds.
253func (c *s2av2TransportCreds) Info() credentials.ProtocolInfo {
254	return *c.info
255}
256
257// Clone makes a deep copy of s2av2TransportCreds.
258func (c *s2av2TransportCreds) Clone() credentials.TransportCredentials {
259	info := *c.info
260	serverName := c.serverName
261	fallbackClientHandshake := c.fallbackClientHandshake
262
263	s2av2Address := c.s2av2Address
264	var tokenManager tokenmanager.AccessTokenManager
265	if c.tokenManager == nil {
266		tokenManager = nil
267	} else {
268		tokenManager = *c.tokenManager
269	}
270	verificationMode := c.verificationMode
271	var localIdentity *commonpb.Identity
272	if c.localIdentity != nil {
273		localIdentity = proto.Clone(c.localIdentity).(*commonpb.Identity)
274	}
275	var localIdentities []*commonpb.Identity
276	if c.localIdentities != nil {
277		localIdentities = make([]*commonpb.Identity, len(c.localIdentities))
278		for i, localIdentity := range c.localIdentities {
279			localIdentities[i] = proto.Clone(localIdentity).(*commonpb.Identity)
280		}
281	}
282	creds := &s2av2TransportCreds{
283		info:                    &info,
284		isClient:                c.isClient,
285		serverName:              serverName,
286		fallbackClientHandshake: fallbackClientHandshake,
287		s2av2Address:            s2av2Address,
288		localIdentity:           localIdentity,
289		localIdentities:         localIdentities,
290		verificationMode:        verificationMode,
291	}
292	if c.tokenManager == nil {
293		creds.tokenManager = nil
294	} else {
295		creds.tokenManager = &tokenManager
296	}
297	return creds
298}
299
300// NewClientTLSConfig returns a tls.Config instance that uses S2Av2 to establish a TLS connection as
301// a client. The tls.Config MUST only be used to establish a single TLS connection.
302func NewClientTLSConfig(
303	ctx context.Context,
304	s2av2Address string,
305	transportCreds credentials.TransportCredentials,
306	tokenManager tokenmanager.AccessTokenManager,
307	verificationMode s2av2pb.ValidatePeerCertificateChainReq_VerificationMode,
308	serverName string,
309	serverAuthorizationPolicy []byte) (*tls.Config, error) {
310	s2AStream, err := createStream(ctx, s2av2Address, transportCreds, nil)
311	if err != nil {
312		grpclog.Infof("Failed to connect to S2Av2: %v", err)
313		return nil, err
314	}
315
316	return tlsconfigstore.GetTLSConfigurationForClient(removeServerNamePort(serverName), s2AStream, tokenManager, nil, verificationMode, serverAuthorizationPolicy)
317}
318
319// OverrideServerName sets the ServerName in the s2av2TransportCreds protocol
320// info. The ServerName MUST be a hostname.
321func (c *s2av2TransportCreds) OverrideServerName(serverNameOverride string) error {
322	serverName := removeServerNamePort(serverNameOverride)
323	c.info.ServerName = serverName
324	c.serverName = serverName
325	return nil
326}
327
// removeServerNamePort strips the trailing port from serverName. When
// serverName cannot be split into host and port it is returned unchanged.
func removeServerNamePort(serverName string) string {
	if host, _, err := net.SplitHostPort(serverName); err == nil {
		return host
	}
	return serverName
}
336
337type s2AGrpcStream struct {
338	stream s2av2pb.S2AService_SetUpSessionClient
339}
340
341func (x s2AGrpcStream) Send(m *s2av2pb.SessionReq) error {
342	return x.stream.Send(m)
343}
344
345func (x s2AGrpcStream) Recv() (*s2av2pb.SessionResp, error) {
346	return x.stream.Recv()
347}
348
349func (x s2AGrpcStream) CloseSend() error {
350	return x.stream.CloseSend()
351}
352
353func createStream(ctx context.Context, s2av2Address string, transportCreds credentials.TransportCredentials, getS2AStream func(ctx context.Context, s2av2Address string) (stream.S2AStream, error)) (stream.S2AStream, error) {
354	if getS2AStream != nil {
355		return getS2AStream(ctx, s2av2Address)
356	}
357	// TODO(rmehta19): Consider whether to close the connection to S2Av2.
358	conn, err := service.Dial(ctx, s2av2Address, transportCreds)
359	if err != nil {
360		return nil, err
361	}
362	client := s2av2pb.NewS2AServiceClient(conn)
363	gRPCStream, err := client.SetUpSession(ctx, []grpc.CallOption{}...)
364	if err != nil {
365		return nil, err
366	}
367	return &s2AGrpcStream{
368		stream: gRPCStream,
369	}, nil
370}
371
372// GetS2ATimeout returns the timeout enforced on the connection to the S2A service for handshake.
373func GetS2ATimeout() time.Duration {
374	timeout, err := time.ParseDuration(os.Getenv(s2aTimeoutEnv))
375	if err != nil {
376		return defaultS2ATimeout
377	}
378	return timeout
379}