ticketsender.go

  1/*
  2 *
  3 * Copyright 2021 Google LLC
  4 *
  5 * Licensed under the Apache License, Version 2.0 (the "License");
  6 * you may not use this file except in compliance with the License.
  7 * You may obtain a copy of the License at
  8 *
  9 *     https://www.apache.org/licenses/LICENSE-2.0
 10 *
 11 * Unless required by applicable law or agreed to in writing, software
 12 * distributed under the License is distributed on an "AS IS" BASIS,
 13 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 14 * See the License for the specific language governing permissions and
 15 * limitations under the License.
 16 *
 17 */
 18
 19package record
 20
 21import (
 22	"context"
 23	"fmt"
 24	"sync"
 25	"time"
 26
 27	"github.com/google/s2a-go/internal/handshaker/service"
 28	commonpb "github.com/google/s2a-go/internal/proto/common_go_proto"
 29	s2apb "github.com/google/s2a-go/internal/proto/s2a_go_proto"
 30	"github.com/google/s2a-go/internal/tokenmanager"
 31	"google.golang.org/grpc/codes"
 32	"google.golang.org/grpc/grpclog"
 33)
 34
 35// sessionTimeout is the timeout for creating a session with the S2A handshaker
 36// service.
 37const sessionTimeout = time.Second * 5
 38
 39// s2aTicketSender sends session tickets to the S2A handshaker service.
 40type s2aTicketSender interface {
 41	// sendTicketsToS2A sends the given session tickets to the S2A handshaker
 42	// service.
 43	sendTicketsToS2A(sessionTickets [][]byte, callComplete chan bool)
 44}
 45
 46// ticketStream is the stream used to send and receive session information.
 47type ticketStream interface {
 48	Send(*s2apb.SessionReq) error
 49	Recv() (*s2apb.SessionResp, error)
 50}
 51
 52type ticketSender struct {
 53	// hsAddr stores the address of the S2A handshaker service.
 54	hsAddr string
 55	// connectionID is the connection identifier that was created and sent by
 56	// S2A at the end of a handshake.
 57	connectionID uint64
 58	// localIdentity is the local identity that was used by S2A during session
 59	// setup and included in the session result.
 60	localIdentity *commonpb.Identity
 61	// tokenManager manages access tokens for authenticating to S2A.
 62	tokenManager tokenmanager.AccessTokenManager
 63	// ensureProcessSessionTickets allows users to wait and ensure that all
 64	// available session tickets are sent to S2A before a process completes.
 65	ensureProcessSessionTickets *sync.WaitGroup
 66}
 67
 68// sendTicketsToS2A sends the given sessionTickets to the S2A handshaker
 69// service. This is done asynchronously and writes to the error logs if an error
 70// occurs.
 71func (t *ticketSender) sendTicketsToS2A(sessionTickets [][]byte, callComplete chan bool) {
 72	// Note that the goroutine is in the function rather than at the caller
 73	// because the fake ticket sender used for testing must run synchronously
 74	// so that the session tickets can be accessed from it after the tests have
 75	// been run.
 76	if t.ensureProcessSessionTickets != nil {
 77		t.ensureProcessSessionTickets.Add(1)
 78	}
 79	go func() {
 80		if err := func() error {
 81			defer func() {
 82				if t.ensureProcessSessionTickets != nil {
 83					t.ensureProcessSessionTickets.Done()
 84				}
 85			}()
 86			ctx, cancel := context.WithTimeout(context.Background(), sessionTimeout)
 87			defer cancel()
 88			// The transportCreds only needs to be set when talking to S2AV2 and also
 89			// if mTLS is required.
 90			hsConn, err := service.Dial(ctx, t.hsAddr, nil)
 91			if err != nil {
 92				return err
 93			}
 94			client := s2apb.NewS2AServiceClient(hsConn)
 95			session, err := client.SetUpSession(ctx)
 96			if err != nil {
 97				return err
 98			}
 99			defer func() {
100				if err := session.CloseSend(); err != nil {
101					grpclog.Error(err)
102				}
103			}()
104			return t.writeTicketsToStream(session, sessionTickets)
105		}(); err != nil {
106			grpclog.Errorf("failed to send resumption tickets to S2A with identity: %v, %v",
107				t.localIdentity, err)
108		}
109		callComplete <- true
110		close(callComplete)
111	}()
112}
113
// writeTicketsToStream performs one request/response round trip on stream.
// It sends a single ResumptionTicket SessionReq that carries sessionTickets,
// t.connectionID, t.localIdentity and the mechanisms from getAuthMechanisms.
// It then reads exactly one response.
//
// Errors from Send or Recv are returned unchanged. A response whose status
// code is not codes.OK yields an error containing the code and the details.
//
// NOTE(review): if the response carries no status, the nil-safe protobuf
// getters read as code 0 (codes.OK), so this is treated as success — confirm
// that this is intended.
func (t *ticketSender) writeTicketsToStream(stream ticketStream, sessionTickets [][]byte) error {
	ticketReq := &s2apb.ResumptionTicketReq{
		InBytes:       sessionTickets,
		ConnectionId:  t.connectionID,
		LocalIdentity: t.localIdentity,
	}
	req := &s2apb.SessionReq{
		ReqOneof: &s2apb.SessionReq_ResumptionTicket{
			ResumptionTicket: ticketReq,
		},
		AuthMechanisms: t.getAuthMechanisms(),
	}
	if err := stream.Send(req); err != nil {
		return err
	}

	resp, err := stream.Recv()
	if err != nil {
		return err
	}
	st := resp.GetStatus()
	if code := st.GetCode(); code != uint32(codes.OK) {
		return fmt.Errorf("s2a session ticket response had error status: %v, %v",
			code, st.GetDetails())
	}
	return nil
}
140
141func (t *ticketSender) getAuthMechanisms() []*s2apb.AuthenticationMechanism {
142	if t.tokenManager == nil {
143		return nil
144	}
145	// First handle the special case when no local identity has been provided
146	// by the application. In this case, an AuthenticationMechanism with no local
147	// identity will be sent.
148	if t.localIdentity == nil {
149		token, err := t.tokenManager.DefaultToken()
150		if err != nil {
151			grpclog.Infof("unable to get token for empty local identity: %v", err)
152			return nil
153		}
154		return []*s2apb.AuthenticationMechanism{
155			{
156				MechanismOneof: &s2apb.AuthenticationMechanism_Token{
157					Token: token,
158				},
159			},
160		}
161	}
162
163	// Next, handle the case where the application (or the S2A) has specified
164	// a local identity.
165	token, err := t.tokenManager.Token(t.localIdentity)
166	if err != nil {
167		grpclog.Infof("unable to get token for local identity %v: %v", t.localIdentity, err)
168		return nil
169	}
170	return []*s2apb.AuthenticationMechanism{
171		{
172			Identity: t.localIdentity,
173			MechanismOneof: &s2apb.AuthenticationMechanism_Token{
174				Token: token,
175			},
176		},
177	}
178}