1/*
2 *
3 * Copyright 2021 Google LLC
4 *
5 * Licensed under the Apache License, Version 2.0 (the "License");
6 * you may not use this file except in compliance with the License.
7 * You may obtain a copy of the License at
8 *
9 * https://www.apache.org/licenses/LICENSE-2.0
10 *
11 * Unless required by applicable law or agreed to in writing, software
12 * distributed under the License is distributed on an "AS IS" BASIS,
13 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 * See the License for the specific language governing permissions and
15 * limitations under the License.
16 *
17 */
18
19package record
20
21import (
22 "context"
23 "fmt"
24 "sync"
25 "time"
26
27 "github.com/google/s2a-go/internal/handshaker/service"
28 commonpb "github.com/google/s2a-go/internal/proto/common_go_proto"
29 s2apb "github.com/google/s2a-go/internal/proto/s2a_go_proto"
30 "github.com/google/s2a-go/internal/tokenmanager"
31 "google.golang.org/grpc/codes"
32 "google.golang.org/grpc/grpclog"
33)
34
// sessionTimeout is the deadline applied to each attempt to set up a session
// with the S2A handshaker service and deliver session tickets to it.
const sessionTimeout = 5 * time.Second
38
39// s2aTicketSender sends session tickets to the S2A handshaker service.
40type s2aTicketSender interface {
41 // sendTicketsToS2A sends the given session tickets to the S2A handshaker
42 // service.
43 sendTicketsToS2A(sessionTickets [][]byte, callComplete chan bool)
44}
45
46// ticketStream is the stream used to send and receive session information.
47type ticketStream interface {
48 Send(*s2apb.SessionReq) error
49 Recv() (*s2apb.SessionResp, error)
50}
51
// ticketSender implements s2aTicketSender by sending session tickets over gRPC
// to the S2A handshaker service at hsAddr, tagged with the connection ID and
// local identity from the handshake.
type ticketSender struct {
	// hsAddr stores the address of the S2A handshaker service.
	hsAddr string
	// connectionID is the connection identifier that was created and sent by
	// S2A at the end of a handshake.
	connectionID uint64
	// localIdentity is the local identity that was used by S2A during session
	// setup and included in the session result. It may be nil, in which case
	// the token manager's default token is used for authentication.
	localIdentity *commonpb.Identity
	// tokenManager manages access tokens for authenticating to S2A. If nil, no
	// authentication mechanisms are attached to requests.
	tokenManager tokenmanager.AccessTokenManager
	// ensureProcessSessionTickets allows users to wait and ensure that all
	// available session tickets are sent to S2A before a process completes.
	// It may be nil, in which case no WaitGroup accounting is performed.
	ensureProcessSessionTickets *sync.WaitGroup
}
67
68// sendTicketsToS2A sends the given sessionTickets to the S2A handshaker
69// service. This is done asynchronously and writes to the error logs if an error
70// occurs.
71func (t *ticketSender) sendTicketsToS2A(sessionTickets [][]byte, callComplete chan bool) {
72 // Note that the goroutine is in the function rather than at the caller
73 // because the fake ticket sender used for testing must run synchronously
74 // so that the session tickets can be accessed from it after the tests have
75 // been run.
76 if t.ensureProcessSessionTickets != nil {
77 t.ensureProcessSessionTickets.Add(1)
78 }
79 go func() {
80 if err := func() error {
81 defer func() {
82 if t.ensureProcessSessionTickets != nil {
83 t.ensureProcessSessionTickets.Done()
84 }
85 }()
86 ctx, cancel := context.WithTimeout(context.Background(), sessionTimeout)
87 defer cancel()
88 // The transportCreds only needs to be set when talking to S2AV2 and also
89 // if mTLS is required.
90 hsConn, err := service.Dial(ctx, t.hsAddr, nil)
91 if err != nil {
92 return err
93 }
94 client := s2apb.NewS2AServiceClient(hsConn)
95 session, err := client.SetUpSession(ctx)
96 if err != nil {
97 return err
98 }
99 defer func() {
100 if err := session.CloseSend(); err != nil {
101 grpclog.Error(err)
102 }
103 }()
104 return t.writeTicketsToStream(session, sessionTickets)
105 }(); err != nil {
106 grpclog.Errorf("failed to send resumption tickets to S2A with identity: %v, %v",
107 t.localIdentity, err)
108 }
109 callComplete <- true
110 close(callComplete)
111 }()
112}
113
114// writeTicketsToStream writes the given session tickets to the given stream.
115func (t *ticketSender) writeTicketsToStream(stream ticketStream, sessionTickets [][]byte) error {
116 if err := stream.Send(
117 &s2apb.SessionReq{
118 ReqOneof: &s2apb.SessionReq_ResumptionTicket{
119 ResumptionTicket: &s2apb.ResumptionTicketReq{
120 InBytes: sessionTickets,
121 ConnectionId: t.connectionID,
122 LocalIdentity: t.localIdentity,
123 },
124 },
125 AuthMechanisms: t.getAuthMechanisms(),
126 },
127 ); err != nil {
128 return err
129 }
130 sessionResp, err := stream.Recv()
131 if err != nil {
132 return err
133 }
134 if sessionResp.GetStatus().GetCode() != uint32(codes.OK) {
135 return fmt.Errorf("s2a session ticket response had error status: %v, %v",
136 sessionResp.GetStatus().GetCode(), sessionResp.GetStatus().GetDetails())
137 }
138 return nil
139}
140
141func (t *ticketSender) getAuthMechanisms() []*s2apb.AuthenticationMechanism {
142 if t.tokenManager == nil {
143 return nil
144 }
145 // First handle the special case when no local identity has been provided
146 // by the application. In this case, an AuthenticationMechanism with no local
147 // identity will be sent.
148 if t.localIdentity == nil {
149 token, err := t.tokenManager.DefaultToken()
150 if err != nil {
151 grpclog.Infof("unable to get token for empty local identity: %v", err)
152 return nil
153 }
154 return []*s2apb.AuthenticationMechanism{
155 {
156 MechanismOneof: &s2apb.AuthenticationMechanism_Token{
157 Token: token,
158 },
159 },
160 }
161 }
162
163 // Next, handle the case where the application (or the S2A) has specified
164 // a local identity.
165 token, err := t.tokenManager.Token(t.localIdentity)
166 if err != nil {
167 grpclog.Infof("unable to get token for local identity %v: %v", t.localIdentity, err)
168 return nil
169 }
170 return []*s2apb.AuthenticationMechanism{
171 {
172 Identity: t.localIdentity,
173 MechanismOneof: &s2apb.AuthenticationMechanism_Token{
174 Token: token,
175 },
176 },
177 }
178}