1/*
2 *
3 * Copyright 2021 Google LLC
4 *
5 * Licensed under the Apache License, Version 2.0 (the "License");
6 * you may not use this file except in compliance with the License.
7 * You may obtain a copy of the License at
8 *
9 * https://www.apache.org/licenses/LICENSE-2.0
10 *
11 * Unless required by applicable law or agreed to in writing, software
12 * distributed under the License is distributed on an "AS IS" BASIS,
13 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 * See the License for the specific language governing permissions and
15 * limitations under the License.
16 *
17 */
18
19// Package handshaker communicates with the S2A handshaker service.
20package handshaker
21
22import (
23 "context"
24 "errors"
25 "fmt"
26 "io"
27 "net"
28 "sync"
29
30 "github.com/google/s2a-go/internal/authinfo"
31 commonpb "github.com/google/s2a-go/internal/proto/common_go_proto"
32 s2apb "github.com/google/s2a-go/internal/proto/s2a_go_proto"
33 "github.com/google/s2a-go/internal/record"
34 "github.com/google/s2a-go/internal/tokenmanager"
35 grpc "google.golang.org/grpc"
36 "google.golang.org/grpc/codes"
37 "google.golang.org/grpc/credentials"
38 "google.golang.org/grpc/grpclog"
39)
40
var (
	// appProtocol is sent as the sole entry of ApplicationProtocols in the
	// client and server start requests sent to the S2A handshaker service.
	appProtocol = "grpc"
	// frameLimit is the maximum size of a frame in bytes. It is also the size
	// of each buffer used to read handshake bytes from the peer.
	frameLimit = 1024 * 64
	// errPeerNotResponding is returned when, during the handshake, there is
	// nothing to send to the peer and the peer returns no bytes.
	errPeerNotResponding = errors.New("peer is not responding and re-connection should be attempted")
)
49
50// Handshaker defines a handshaker interface.
51type Handshaker interface {
52 // ClientHandshake starts and completes a TLS handshake from the client side,
53 // and returns a secure connection along with additional auth information.
54 ClientHandshake(ctx context.Context) (net.Conn, credentials.AuthInfo, error)
55 // ServerHandshake starts and completes a TLS handshake from the server side,
56 // and returns a secure connection along with additional auth information.
57 ServerHandshake(ctx context.Context) (net.Conn, credentials.AuthInfo, error)
58 // Close terminates the Handshaker. It should be called when the handshake
59 // is complete.
60 Close() error
61}
62
// ClientHandshakerOptions contains the options needed to configure the S2A
// handshaker service on the client-side. The TLS version, ciphersuite,
// identity and target-name fields are copied into the ClientSessionStartReq
// that is sent to S2A when the client handshake begins.
type ClientHandshakerOptions struct {
	// MinTLSVersion specifies the min TLS version supported by the client.
	MinTLSVersion commonpb.TLSVersion
	// MaxTLSVersion specifies the max TLS version supported by the client.
	MaxTLSVersion commonpb.TLSVersion
	// TLSCiphersuites is the ordered list of ciphersuites supported by the
	// client.
	TLSCiphersuites []commonpb.Ciphersuite
	// TargetIdentities contains a list of allowed server identities. One of the
	// target identities should match the peer identity in the handshake
	// result; otherwise, the handshake fails.
	TargetIdentities []*commonpb.Identity
	// LocalIdentity is the local identity of the client application. If none is
	// provided, then the S2A will choose the default identity. When set, it is
	// also the identity for which an access token is requested when
	// authenticating to S2A.
	LocalIdentity *commonpb.Identity
	// TargetName is the allowed server name, which may be used for server
	// authorization check by the S2A if it is provided. It may be given as an
	// authority ("host:port"); the host part is extracted with
	// net.SplitHostPort before sending, falling back to the raw value when it
	// cannot be split.
	TargetName string
	// EnsureProcessSessionTickets allows users to wait and ensure that all
	// available session tickets are sent to S2A before a process completes.
	// It is passed through to the record layer of the resulting connection.
	EnsureProcessSessionTickets *sync.WaitGroup
}
87
// ServerHandshakerOptions contains the options needed to configure the S2A
// handshaker service on the server-side. Its fields are copied into the
// ServerSessionStartReq that is sent to S2A when the server handshake begins.
type ServerHandshakerOptions struct {
	// MinTLSVersion specifies the min TLS version supported by the server.
	MinTLSVersion commonpb.TLSVersion
	// MaxTLSVersion specifies the max TLS version supported by the server.
	MaxTLSVersion commonpb.TLSVersion
	// TLSCiphersuites is the ordered list of ciphersuites supported by the
	// server.
	TLSCiphersuites []commonpb.Ciphersuite
	// LocalIdentities is the list of local identities that may be assumed by
	// the server. If no local identity is specified, then the S2A chooses a
	// default local identity. These are also the identities for which access
	// tokens are requested when authenticating to S2A.
	LocalIdentities []*commonpb.Identity
}
103
// s2aHandshaker performs a TLS handshake using the S2A handshaker service.
// It relays handshake bytes between the peer connection and a SetUpSession
// stream to S2A, then wraps the peer connection in a record-protocol
// connection built from the session result.
type s2aHandshaker struct {
	// stream is used to communicate with the S2A handshaker service.
	stream s2apb.S2AService_SetUpSessionClient
	// conn is the connection to the peer.
	conn net.Conn
	// clientOpts should be non-nil iff the handshaker is client-side.
	clientOpts *ClientHandshakerOptions
	// serverOpts should be non-nil iff the handshaker is server-side.
	serverOpts *ServerHandshakerOptions
	// isClient determines if the handshaker is client or server side.
	isClient bool
	// hsAddr stores the address of the S2A handshaker service.
	hsAddr string
	// tokenManager manages access tokens for authenticating to S2A. When it
	// is nil, requests are sent without authentication mechanisms.
	tokenManager tokenmanager.AccessTokenManager
	// localIdentities is the set of local identities for whom the
	// tokenManager should fetch a token when preparing a request to be
	// sent to S2A. It is replaced during the handshake once S2A reports the
	// local identity it selected.
	localIdentities []*commonpb.Identity
}
125
126// NewClientHandshaker creates an s2aHandshaker instance that performs a
127// client-side TLS handshake using the S2A handshaker service.
128func NewClientHandshaker(ctx context.Context, conn *grpc.ClientConn, c net.Conn, hsAddr string, opts *ClientHandshakerOptions) (Handshaker, error) {
129 stream, err := s2apb.NewS2AServiceClient(conn).SetUpSession(ctx, grpc.WaitForReady(true))
130 if err != nil {
131 return nil, err
132 }
133 tokenManager, err := tokenmanager.NewSingleTokenAccessTokenManager()
134 if err != nil {
135 grpclog.Infof("failed to create single token access token manager: %v", err)
136 }
137 return newClientHandshaker(stream, c, hsAddr, opts, tokenManager), nil
138}
139
140func newClientHandshaker(stream s2apb.S2AService_SetUpSessionClient, c net.Conn, hsAddr string, opts *ClientHandshakerOptions, tokenManager tokenmanager.AccessTokenManager) *s2aHandshaker {
141 var localIdentities []*commonpb.Identity
142 if opts != nil {
143 localIdentities = []*commonpb.Identity{opts.LocalIdentity}
144 }
145 return &s2aHandshaker{
146 stream: stream,
147 conn: c,
148 clientOpts: opts,
149 isClient: true,
150 hsAddr: hsAddr,
151 tokenManager: tokenManager,
152 localIdentities: localIdentities,
153 }
154}
155
156// NewServerHandshaker creates an s2aHandshaker instance that performs a
157// server-side TLS handshake using the S2A handshaker service.
158func NewServerHandshaker(ctx context.Context, conn *grpc.ClientConn, c net.Conn, hsAddr string, opts *ServerHandshakerOptions) (Handshaker, error) {
159 stream, err := s2apb.NewS2AServiceClient(conn).SetUpSession(ctx, grpc.WaitForReady(true))
160 if err != nil {
161 return nil, err
162 }
163 tokenManager, err := tokenmanager.NewSingleTokenAccessTokenManager()
164 if err != nil {
165 grpclog.Infof("failed to create single token access token manager: %v", err)
166 }
167 return newServerHandshaker(stream, c, hsAddr, opts, tokenManager), nil
168}
169
170func newServerHandshaker(stream s2apb.S2AService_SetUpSessionClient, c net.Conn, hsAddr string, opts *ServerHandshakerOptions, tokenManager tokenmanager.AccessTokenManager) *s2aHandshaker {
171 var localIdentities []*commonpb.Identity
172 if opts != nil {
173 localIdentities = opts.LocalIdentities
174 }
175 return &s2aHandshaker{
176 stream: stream,
177 conn: c,
178 serverOpts: opts,
179 isClient: false,
180 hsAddr: hsAddr,
181 tokenManager: tokenManager,
182 localIdentities: localIdentities,
183 }
184}
185
186// ClientHandshake performs a client-side TLS handshake using the S2A handshaker
187// service. When complete, returns a TLS connection.
188func (h *s2aHandshaker) ClientHandshake(_ context.Context) (net.Conn, credentials.AuthInfo, error) {
189 if !h.isClient {
190 return nil, nil, errors.New("only handshakers created using NewClientHandshaker can perform a client-side handshake")
191 }
192 // Extract the hostname from the target name. The target name is assumed to be an authority.
193 hostname, _, err := net.SplitHostPort(h.clientOpts.TargetName)
194 if err != nil {
195 // If the target name had no host port or could not be parsed, use it as is.
196 hostname = h.clientOpts.TargetName
197 }
198
199 // Prepare a client start message to send to the S2A handshaker service.
200 req := &s2apb.SessionReq{
201 ReqOneof: &s2apb.SessionReq_ClientStart{
202 ClientStart: &s2apb.ClientSessionStartReq{
203 ApplicationProtocols: []string{appProtocol},
204 MinTlsVersion: h.clientOpts.MinTLSVersion,
205 MaxTlsVersion: h.clientOpts.MaxTLSVersion,
206 TlsCiphersuites: h.clientOpts.TLSCiphersuites,
207 TargetIdentities: h.clientOpts.TargetIdentities,
208 LocalIdentity: h.clientOpts.LocalIdentity,
209 TargetName: hostname,
210 },
211 },
212 AuthMechanisms: h.getAuthMechanisms(),
213 }
214 conn, result, err := h.setUpSession(req)
215 if err != nil {
216 return nil, nil, err
217 }
218 authInfo, err := authinfo.NewS2AAuthInfo(result)
219 if err != nil {
220 return nil, nil, err
221 }
222 return conn, authInfo, nil
223}
224
225// ServerHandshake performs a server-side TLS handshake using the S2A handshaker
226// service. When complete, returns a TLS connection.
227func (h *s2aHandshaker) ServerHandshake(_ context.Context) (net.Conn, credentials.AuthInfo, error) {
228 if h.isClient {
229 return nil, nil, errors.New("only handshakers created using NewServerHandshaker can perform a server-side handshake")
230 }
231 p := make([]byte, frameLimit)
232 n, err := h.conn.Read(p)
233 if err != nil {
234 return nil, nil, err
235 }
236 // Prepare a server start message to send to the S2A handshaker service.
237 req := &s2apb.SessionReq{
238 ReqOneof: &s2apb.SessionReq_ServerStart{
239 ServerStart: &s2apb.ServerSessionStartReq{
240 ApplicationProtocols: []string{appProtocol},
241 MinTlsVersion: h.serverOpts.MinTLSVersion,
242 MaxTlsVersion: h.serverOpts.MaxTLSVersion,
243 TlsCiphersuites: h.serverOpts.TLSCiphersuites,
244 LocalIdentities: h.serverOpts.LocalIdentities,
245 InBytes: p[:n],
246 },
247 },
248 AuthMechanisms: h.getAuthMechanisms(),
249 }
250 conn, result, err := h.setUpSession(req)
251 if err != nil {
252 return nil, nil, err
253 }
254 authInfo, err := authinfo.NewS2AAuthInfo(result)
255 if err != nil {
256 return nil, nil, err
257 }
258 return conn, authInfo, nil
259}
260
261// setUpSession proxies messages between the peer and the S2A handshaker
262// service.
263func (h *s2aHandshaker) setUpSession(req *s2apb.SessionReq) (net.Conn, *s2apb.SessionResult, error) {
264 resp, err := h.accessHandshakerService(req)
265 if err != nil {
266 return nil, nil, err
267 }
268 // Check if the returned status is an error.
269 if resp.GetStatus() != nil {
270 if got, want := resp.GetStatus().Code, uint32(codes.OK); got != want {
271 return nil, nil, fmt.Errorf("%v", resp.GetStatus().Details)
272 }
273 }
274 // Calculate the extra unread bytes from the Session. Attempting to consume
275 // more than the bytes sent will throw an error.
276 var extra []byte
277 if req.GetServerStart() != nil {
278 if resp.GetBytesConsumed() > uint32(len(req.GetServerStart().GetInBytes())) {
279 return nil, nil, errors.New("handshaker service consumed bytes value is out-of-bounds")
280 }
281 extra = req.GetServerStart().GetInBytes()[resp.GetBytesConsumed():]
282 }
283 result, extra, err := h.processUntilDone(resp, extra)
284 if err != nil {
285 return nil, nil, err
286 }
287 if result.GetLocalIdentity() == nil {
288 return nil, nil, errors.New("local identity must be populated in session result")
289 }
290
291 // Create a new TLS record protocol using the Session Result.
292 newConn, err := record.NewConn(&record.ConnParameters{
293 NetConn: h.conn,
294 Ciphersuite: result.GetState().GetTlsCiphersuite(),
295 TLSVersion: result.GetState().GetTlsVersion(),
296 InTrafficSecret: result.GetState().GetInKey(),
297 OutTrafficSecret: result.GetState().GetOutKey(),
298 UnusedBuf: extra,
299 InSequence: result.GetState().GetInSequence(),
300 OutSequence: result.GetState().GetOutSequence(),
301 HSAddr: h.hsAddr,
302 ConnectionID: result.GetState().GetConnectionId(),
303 LocalIdentity: result.GetLocalIdentity(),
304 EnsureProcessSessionTickets: h.ensureProcessSessionTickets(),
305 })
306 if err != nil {
307 return nil, nil, err
308 }
309 return newConn, result, nil
310}
311
312func (h *s2aHandshaker) ensureProcessSessionTickets() *sync.WaitGroup {
313 if h.clientOpts == nil {
314 return nil
315 }
316 return h.clientOpts.EnsureProcessSessionTickets
317}
318
319// accessHandshakerService sends the session request to the S2A handshaker
320// service and returns the session response.
321func (h *s2aHandshaker) accessHandshakerService(req *s2apb.SessionReq) (*s2apb.SessionResp, error) {
322 if err := h.stream.Send(req); err != nil {
323 return nil, err
324 }
325 resp, err := h.stream.Recv()
326 if err != nil {
327 return nil, err
328 }
329 return resp, nil
330}
331
332// processUntilDone continues proxying messages between the peer and the S2A
333// handshaker service until the handshaker service returns the SessionResult at
334// the end of the handshake or an error occurs.
335func (h *s2aHandshaker) processUntilDone(resp *s2apb.SessionResp, unusedBytes []byte) (*s2apb.SessionResult, []byte, error) {
336 for {
337 if len(resp.OutFrames) > 0 {
338 if _, err := h.conn.Write(resp.OutFrames); err != nil {
339 return nil, nil, err
340 }
341 }
342 if resp.Result != nil {
343 return resp.Result, unusedBytes, nil
344 }
345 buf := make([]byte, frameLimit)
346 n, err := h.conn.Read(buf)
347 if err != nil && err != io.EOF {
348 return nil, nil, err
349 }
350 // If there is nothing to send to the handshaker service and nothing is
351 // received from the peer, then we are stuck. This covers the case when
352 // the peer is not responding. Note that handshaker service connection
353 // issues are caught in accessHandshakerService before we even get
354 // here.
355 if len(resp.OutFrames) == 0 && n == 0 {
356 return nil, nil, errPeerNotResponding
357 }
358 // Append extra bytes from the previous interaction with the handshaker
359 // service with the current buffer read from conn.
360 p := append(unusedBytes, buf[:n]...)
361 // From here on, p and unusedBytes point to the same slice.
362 resp, err = h.accessHandshakerService(&s2apb.SessionReq{
363 ReqOneof: &s2apb.SessionReq_Next{
364 Next: &s2apb.SessionNextReq{
365 InBytes: p,
366 },
367 },
368 AuthMechanisms: h.getAuthMechanisms(),
369 })
370 if err != nil {
371 return nil, nil, err
372 }
373
374 // Cache the local identity returned by S2A, if it is populated. This
375 // overwrites any existing local identities. This is done because, once the
376 // S2A has selected a local identity, then only that local identity should
377 // be asserted in future requests until the end of the current handshake.
378 if resp.GetLocalIdentity() != nil {
379 h.localIdentities = []*commonpb.Identity{resp.GetLocalIdentity()}
380 }
381
382 // Set unusedBytes based on the handshaker service response.
383 if resp.GetBytesConsumed() > uint32(len(p)) {
384 return nil, nil, errors.New("handshaker service consumed bytes value is out-of-bounds")
385 }
386 unusedBytes = p[resp.GetBytesConsumed():]
387 }
388}
389
390// Close shuts down the handshaker and the stream to the S2A handshaker service
391// when the handshake is complete. It should be called when the caller obtains
392// the secure connection at the end of the handshake.
393func (h *s2aHandshaker) Close() error {
394 return h.stream.CloseSend()
395}
396
397func (h *s2aHandshaker) getAuthMechanisms() []*s2apb.AuthenticationMechanism {
398 if h.tokenManager == nil {
399 return nil
400 }
401 // First handle the special case when no local identities have been provided
402 // by the application. In this case, an AuthenticationMechanism with no local
403 // identity will be sent.
404 if len(h.localIdentities) == 0 {
405 token, err := h.tokenManager.DefaultToken()
406 if err != nil {
407 grpclog.Infof("unable to get token for empty local identity: %v", err)
408 return nil
409 }
410 return []*s2apb.AuthenticationMechanism{
411 {
412 MechanismOneof: &s2apb.AuthenticationMechanism_Token{
413 Token: token,
414 },
415 },
416 }
417 }
418
419 // Next, handle the case where the application (or the S2A) has provided
420 // one or more local identities.
421 var authMechanisms []*s2apb.AuthenticationMechanism
422 for _, localIdentity := range h.localIdentities {
423 token, err := h.tokenManager.Token(localIdentity)
424 if err != nil {
425 grpclog.Infof("unable to get token for local identity %v: %v", localIdentity, err)
426 continue
427 }
428
429 authMechanism := &s2apb.AuthenticationMechanism{
430 Identity: localIdentity,
431 MechanismOneof: &s2apb.AuthenticationMechanism_Token{
432 Token: token,
433 },
434 }
435 authMechanisms = append(authMechanisms, authMechanism)
436 }
437 return authMechanisms
438}