handshaker.go

  1/*
  2 *
  3 * Copyright 2021 Google LLC
  4 *
  5 * Licensed under the Apache License, Version 2.0 (the "License");
  6 * you may not use this file except in compliance with the License.
  7 * You may obtain a copy of the License at
  8 *
  9 *     https://www.apache.org/licenses/LICENSE-2.0
 10 *
 11 * Unless required by applicable law or agreed to in writing, software
 12 * distributed under the License is distributed on an "AS IS" BASIS,
 13 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 14 * See the License for the specific language governing permissions and
 15 * limitations under the License.
 16 *
 17 */
 18
 19// Package handshaker communicates with the S2A handshaker service.
 20package handshaker
 21
 22import (
 23	"context"
 24	"errors"
 25	"fmt"
 26	"io"
 27	"net"
 28	"sync"
 29
 30	"github.com/google/s2a-go/internal/authinfo"
 31	commonpb "github.com/google/s2a-go/internal/proto/common_go_proto"
 32	s2apb "github.com/google/s2a-go/internal/proto/s2a_go_proto"
 33	"github.com/google/s2a-go/internal/record"
 34	"github.com/google/s2a-go/internal/tokenmanager"
 35	grpc "google.golang.org/grpc"
 36	"google.golang.org/grpc/codes"
 37	"google.golang.org/grpc/credentials"
 38	"google.golang.org/grpc/grpclog"
 39)
 40
 41var (
 42	// appProtocol contains the application protocol accepted by the handshaker.
 43	appProtocol = "grpc"
 44	// frameLimit is the maximum size of a frame in bytes.
 45	frameLimit = 1024 * 64
 46	// peerNotRespondingError is the error thrown when the peer doesn't respond.
 47	errPeerNotResponding = errors.New("peer is not responding and re-connection should be attempted")
 48)
 49
 50// Handshaker defines a handshaker interface.
 51type Handshaker interface {
 52	// ClientHandshake starts and completes a TLS handshake from the client side,
 53	// and returns a secure connection along with additional auth information.
 54	ClientHandshake(ctx context.Context) (net.Conn, credentials.AuthInfo, error)
 55	// ServerHandshake starts and completes a TLS handshake from the server side,
 56	// and returns a secure connection along with additional auth information.
 57	ServerHandshake(ctx context.Context) (net.Conn, credentials.AuthInfo, error)
 58	// Close terminates the Handshaker. It should be called when the handshake
 59	// is complete.
 60	Close() error
 61}
 62
 63// ClientHandshakerOptions contains the options needed to configure the S2A
 64// handshaker service on the client-side.
 65type ClientHandshakerOptions struct {
 66	// MinTLSVersion specifies the min TLS version supported by the client.
 67	MinTLSVersion commonpb.TLSVersion
 68	// MaxTLSVersion specifies the max TLS version supported by the client.
 69	MaxTLSVersion commonpb.TLSVersion
 70	// TLSCiphersuites is the ordered list of ciphersuites supported by the
 71	// client.
 72	TLSCiphersuites []commonpb.Ciphersuite
 73	// TargetIdentities contains a list of allowed server identities. One of the
 74	// target identities should match the peer identity in the handshake
 75	// result; otherwise, the handshake fails.
 76	TargetIdentities []*commonpb.Identity
 77	// LocalIdentity is the local identity of the client application. If none is
 78	// provided, then the S2A will choose the default identity.
 79	LocalIdentity *commonpb.Identity
 80	// TargetName is the allowed server name, which may be used for server
 81	// authorization check by the S2A if it is provided.
 82	TargetName string
 83	// EnsureProcessSessionTickets allows users to wait and ensure that all
 84	// available session tickets are sent to S2A before a process completes.
 85	EnsureProcessSessionTickets *sync.WaitGroup
 86}
 87
 88// ServerHandshakerOptions contains the options needed to configure the S2A
 89// handshaker service on the server-side.
 90type ServerHandshakerOptions struct {
 91	// MinTLSVersion specifies the min TLS version supported by the server.
 92	MinTLSVersion commonpb.TLSVersion
 93	// MaxTLSVersion specifies the max TLS version supported by the server.
 94	MaxTLSVersion commonpb.TLSVersion
 95	// TLSCiphersuites is the ordered list of ciphersuites supported by the
 96	// server.
 97	TLSCiphersuites []commonpb.Ciphersuite
 98	// LocalIdentities is the list of local identities that may be assumed by
 99	// the server. If no local identity is specified, then the S2A chooses a
100	// default local identity.
101	LocalIdentities []*commonpb.Identity
102}
103
104// s2aHandshaker performs a TLS handshake using the S2A handshaker service.
105type s2aHandshaker struct {
106	// stream is used to communicate with the S2A handshaker service.
107	stream s2apb.S2AService_SetUpSessionClient
108	// conn is the connection to the peer.
109	conn net.Conn
110	// clientOpts should be non-nil iff the handshaker is client-side.
111	clientOpts *ClientHandshakerOptions
112	// serverOpts should be non-nil iff the handshaker is server-side.
113	serverOpts *ServerHandshakerOptions
114	// isClient determines if the handshaker is client or server side.
115	isClient bool
116	// hsAddr stores the address of the S2A handshaker service.
117	hsAddr string
118	// tokenManager manages access tokens for authenticating to S2A.
119	tokenManager tokenmanager.AccessTokenManager
120	// localIdentities is the set of local identities for whom the
121	// tokenManager should fetch a token when preparing a request to be
122	// sent to S2A.
123	localIdentities []*commonpb.Identity
124}
125
126// NewClientHandshaker creates an s2aHandshaker instance that performs a
127// client-side TLS handshake using the S2A handshaker service.
128func NewClientHandshaker(ctx context.Context, conn *grpc.ClientConn, c net.Conn, hsAddr string, opts *ClientHandshakerOptions) (Handshaker, error) {
129	stream, err := s2apb.NewS2AServiceClient(conn).SetUpSession(ctx, grpc.WaitForReady(true))
130	if err != nil {
131		return nil, err
132	}
133	tokenManager, err := tokenmanager.NewSingleTokenAccessTokenManager()
134	if err != nil {
135		grpclog.Infof("failed to create single token access token manager: %v", err)
136	}
137	return newClientHandshaker(stream, c, hsAddr, opts, tokenManager), nil
138}
139
140func newClientHandshaker(stream s2apb.S2AService_SetUpSessionClient, c net.Conn, hsAddr string, opts *ClientHandshakerOptions, tokenManager tokenmanager.AccessTokenManager) *s2aHandshaker {
141	var localIdentities []*commonpb.Identity
142	if opts != nil {
143		localIdentities = []*commonpb.Identity{opts.LocalIdentity}
144	}
145	return &s2aHandshaker{
146		stream:          stream,
147		conn:            c,
148		clientOpts:      opts,
149		isClient:        true,
150		hsAddr:          hsAddr,
151		tokenManager:    tokenManager,
152		localIdentities: localIdentities,
153	}
154}
155
156// NewServerHandshaker creates an s2aHandshaker instance that performs a
157// server-side TLS handshake using the S2A handshaker service.
158func NewServerHandshaker(ctx context.Context, conn *grpc.ClientConn, c net.Conn, hsAddr string, opts *ServerHandshakerOptions) (Handshaker, error) {
159	stream, err := s2apb.NewS2AServiceClient(conn).SetUpSession(ctx, grpc.WaitForReady(true))
160	if err != nil {
161		return nil, err
162	}
163	tokenManager, err := tokenmanager.NewSingleTokenAccessTokenManager()
164	if err != nil {
165		grpclog.Infof("failed to create single token access token manager: %v", err)
166	}
167	return newServerHandshaker(stream, c, hsAddr, opts, tokenManager), nil
168}
169
170func newServerHandshaker(stream s2apb.S2AService_SetUpSessionClient, c net.Conn, hsAddr string, opts *ServerHandshakerOptions, tokenManager tokenmanager.AccessTokenManager) *s2aHandshaker {
171	var localIdentities []*commonpb.Identity
172	if opts != nil {
173		localIdentities = opts.LocalIdentities
174	}
175	return &s2aHandshaker{
176		stream:          stream,
177		conn:            c,
178		serverOpts:      opts,
179		isClient:        false,
180		hsAddr:          hsAddr,
181		tokenManager:    tokenManager,
182		localIdentities: localIdentities,
183	}
184}
185
186// ClientHandshake performs a client-side TLS handshake using the S2A handshaker
187// service. When complete, returns a TLS connection.
188func (h *s2aHandshaker) ClientHandshake(_ context.Context) (net.Conn, credentials.AuthInfo, error) {
189	if !h.isClient {
190		return nil, nil, errors.New("only handshakers created using NewClientHandshaker can perform a client-side handshake")
191	}
192	// Extract the hostname from the target name. The target name is assumed to be an authority.
193	hostname, _, err := net.SplitHostPort(h.clientOpts.TargetName)
194	if err != nil {
195		// If the target name had no host port or could not be parsed, use it as is.
196		hostname = h.clientOpts.TargetName
197	}
198
199	// Prepare a client start message to send to the S2A handshaker service.
200	req := &s2apb.SessionReq{
201		ReqOneof: &s2apb.SessionReq_ClientStart{
202			ClientStart: &s2apb.ClientSessionStartReq{
203				ApplicationProtocols: []string{appProtocol},
204				MinTlsVersion:        h.clientOpts.MinTLSVersion,
205				MaxTlsVersion:        h.clientOpts.MaxTLSVersion,
206				TlsCiphersuites:      h.clientOpts.TLSCiphersuites,
207				TargetIdentities:     h.clientOpts.TargetIdentities,
208				LocalIdentity:        h.clientOpts.LocalIdentity,
209				TargetName:           hostname,
210			},
211		},
212		AuthMechanisms: h.getAuthMechanisms(),
213	}
214	conn, result, err := h.setUpSession(req)
215	if err != nil {
216		return nil, nil, err
217	}
218	authInfo, err := authinfo.NewS2AAuthInfo(result)
219	if err != nil {
220		return nil, nil, err
221	}
222	return conn, authInfo, nil
223}
224
225// ServerHandshake performs a server-side TLS handshake using the S2A handshaker
226// service. When complete, returns a TLS connection.
227func (h *s2aHandshaker) ServerHandshake(_ context.Context) (net.Conn, credentials.AuthInfo, error) {
228	if h.isClient {
229		return nil, nil, errors.New("only handshakers created using NewServerHandshaker can perform a server-side handshake")
230	}
231	p := make([]byte, frameLimit)
232	n, err := h.conn.Read(p)
233	if err != nil {
234		return nil, nil, err
235	}
236	// Prepare a server start message to send to the S2A handshaker service.
237	req := &s2apb.SessionReq{
238		ReqOneof: &s2apb.SessionReq_ServerStart{
239			ServerStart: &s2apb.ServerSessionStartReq{
240				ApplicationProtocols: []string{appProtocol},
241				MinTlsVersion:        h.serverOpts.MinTLSVersion,
242				MaxTlsVersion:        h.serverOpts.MaxTLSVersion,
243				TlsCiphersuites:      h.serverOpts.TLSCiphersuites,
244				LocalIdentities:      h.serverOpts.LocalIdentities,
245				InBytes:              p[:n],
246			},
247		},
248		AuthMechanisms: h.getAuthMechanisms(),
249	}
250	conn, result, err := h.setUpSession(req)
251	if err != nil {
252		return nil, nil, err
253	}
254	authInfo, err := authinfo.NewS2AAuthInfo(result)
255	if err != nil {
256		return nil, nil, err
257	}
258	return conn, authInfo, nil
259}
260
261// setUpSession proxies messages between the peer and the S2A handshaker
262// service.
263func (h *s2aHandshaker) setUpSession(req *s2apb.SessionReq) (net.Conn, *s2apb.SessionResult, error) {
264	resp, err := h.accessHandshakerService(req)
265	if err != nil {
266		return nil, nil, err
267	}
268	// Check if the returned status is an error.
269	if resp.GetStatus() != nil {
270		if got, want := resp.GetStatus().Code, uint32(codes.OK); got != want {
271			return nil, nil, fmt.Errorf("%v", resp.GetStatus().Details)
272		}
273	}
274	// Calculate the extra unread bytes from the Session. Attempting to consume
275	// more than the bytes sent will throw an error.
276	var extra []byte
277	if req.GetServerStart() != nil {
278		if resp.GetBytesConsumed() > uint32(len(req.GetServerStart().GetInBytes())) {
279			return nil, nil, errors.New("handshaker service consumed bytes value is out-of-bounds")
280		}
281		extra = req.GetServerStart().GetInBytes()[resp.GetBytesConsumed():]
282	}
283	result, extra, err := h.processUntilDone(resp, extra)
284	if err != nil {
285		return nil, nil, err
286	}
287	if result.GetLocalIdentity() == nil {
288		return nil, nil, errors.New("local identity must be populated in session result")
289	}
290
291	// Create a new TLS record protocol using the Session Result.
292	newConn, err := record.NewConn(&record.ConnParameters{
293		NetConn:                     h.conn,
294		Ciphersuite:                 result.GetState().GetTlsCiphersuite(),
295		TLSVersion:                  result.GetState().GetTlsVersion(),
296		InTrafficSecret:             result.GetState().GetInKey(),
297		OutTrafficSecret:            result.GetState().GetOutKey(),
298		UnusedBuf:                   extra,
299		InSequence:                  result.GetState().GetInSequence(),
300		OutSequence:                 result.GetState().GetOutSequence(),
301		HSAddr:                      h.hsAddr,
302		ConnectionID:                result.GetState().GetConnectionId(),
303		LocalIdentity:               result.GetLocalIdentity(),
304		EnsureProcessSessionTickets: h.ensureProcessSessionTickets(),
305	})
306	if err != nil {
307		return nil, nil, err
308	}
309	return newConn, result, nil
310}
311
312func (h *s2aHandshaker) ensureProcessSessionTickets() *sync.WaitGroup {
313	if h.clientOpts == nil {
314		return nil
315	}
316	return h.clientOpts.EnsureProcessSessionTickets
317}
318
319// accessHandshakerService sends the session request to the S2A handshaker
320// service and returns the session response.
321func (h *s2aHandshaker) accessHandshakerService(req *s2apb.SessionReq) (*s2apb.SessionResp, error) {
322	if err := h.stream.Send(req); err != nil {
323		return nil, err
324	}
325	resp, err := h.stream.Recv()
326	if err != nil {
327		return nil, err
328	}
329	return resp, nil
330}
331
332// processUntilDone continues proxying messages between the peer and the S2A
333// handshaker service until the handshaker service returns the SessionResult at
334// the end of the handshake or an error occurs.
335func (h *s2aHandshaker) processUntilDone(resp *s2apb.SessionResp, unusedBytes []byte) (*s2apb.SessionResult, []byte, error) {
336	for {
337		if len(resp.OutFrames) > 0 {
338			if _, err := h.conn.Write(resp.OutFrames); err != nil {
339				return nil, nil, err
340			}
341		}
342		if resp.Result != nil {
343			return resp.Result, unusedBytes, nil
344		}
345		buf := make([]byte, frameLimit)
346		n, err := h.conn.Read(buf)
347		if err != nil && err != io.EOF {
348			return nil, nil, err
349		}
350		// If there is nothing to send to the handshaker service and nothing is
351		// received from the peer, then we are stuck. This covers the case when
352		// the peer is not responding. Note that handshaker service connection
353		// issues are caught in accessHandshakerService before we even get
354		// here.
355		if len(resp.OutFrames) == 0 && n == 0 {
356			return nil, nil, errPeerNotResponding
357		}
358		// Append extra bytes from the previous interaction with the handshaker
359		// service with the current buffer read from conn.
360		p := append(unusedBytes, buf[:n]...)
361		// From here on, p and unusedBytes point to the same slice.
362		resp, err = h.accessHandshakerService(&s2apb.SessionReq{
363			ReqOneof: &s2apb.SessionReq_Next{
364				Next: &s2apb.SessionNextReq{
365					InBytes: p,
366				},
367			},
368			AuthMechanisms: h.getAuthMechanisms(),
369		})
370		if err != nil {
371			return nil, nil, err
372		}
373
374		// Cache the local identity returned by S2A, if it is populated. This
375		// overwrites any existing local identities. This is done because, once the
376		// S2A has selected a local identity, then only that local identity should
377		// be asserted in future requests until the end of the current handshake.
378		if resp.GetLocalIdentity() != nil {
379			h.localIdentities = []*commonpb.Identity{resp.GetLocalIdentity()}
380		}
381
382		// Set unusedBytes based on the handshaker service response.
383		if resp.GetBytesConsumed() > uint32(len(p)) {
384			return nil, nil, errors.New("handshaker service consumed bytes value is out-of-bounds")
385		}
386		unusedBytes = p[resp.GetBytesConsumed():]
387	}
388}
389
390// Close shuts down the handshaker and the stream to the S2A handshaker service
391// when the handshake is complete. It should be called when the caller obtains
392// the secure connection at the end of the handshake.
393func (h *s2aHandshaker) Close() error {
394	return h.stream.CloseSend()
395}
396
397func (h *s2aHandshaker) getAuthMechanisms() []*s2apb.AuthenticationMechanism {
398	if h.tokenManager == nil {
399		return nil
400	}
401	// First handle the special case when no local identities have been provided
402	// by the application. In this case, an AuthenticationMechanism with no local
403	// identity will be sent.
404	if len(h.localIdentities) == 0 {
405		token, err := h.tokenManager.DefaultToken()
406		if err != nil {
407			grpclog.Infof("unable to get token for empty local identity: %v", err)
408			return nil
409		}
410		return []*s2apb.AuthenticationMechanism{
411			{
412				MechanismOneof: &s2apb.AuthenticationMechanism_Token{
413					Token: token,
414				},
415			},
416		}
417	}
418
419	// Next, handle the case where the application (or the S2A) has provided
420	// one or more local identities.
421	var authMechanisms []*s2apb.AuthenticationMechanism
422	for _, localIdentity := range h.localIdentities {
423		token, err := h.tokenManager.Token(localIdentity)
424		if err != nil {
425			grpclog.Infof("unable to get token for local identity %v: %v", localIdentity, err)
426			continue
427		}
428
429		authMechanism := &s2apb.AuthenticationMechanism{
430			Identity: localIdentity,
431			MechanismOneof: &s2apb.AuthenticationMechanism_Token{
432				Token: token,
433			},
434		}
435		authMechanisms = append(authMechanisms, authMechanism)
436	}
437	return authMechanisms
438}