1/*
2 *
3 * Copyright 2022 Google LLC
4 *
5 * Licensed under the Apache License, Version 2.0 (the "License");
6 * you may not use this file except in compliance with the License.
7 * You may obtain a copy of the License at
8 *
9 * https://www.apache.org/licenses/LICENSE-2.0
10 *
11 * Unless required by applicable law or agreed to in writing, software
12 * distributed under the License is distributed on an "AS IS" BASIS,
13 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 * See the License for the specific language governing permissions and
15 * limitations under the License.
16 *
17 */
18
19// Package v2 provides the S2Av2 transport credentials used by a gRPC
20// application.
21package v2
22
23import (
24 "context"
25 "crypto/tls"
26 "errors"
27 "net"
28 "os"
29 "time"
30
31 "github.com/google/s2a-go/fallback"
32 "github.com/google/s2a-go/internal/handshaker/service"
33 "github.com/google/s2a-go/internal/tokenmanager"
34 "github.com/google/s2a-go/internal/v2/tlsconfigstore"
35 "github.com/google/s2a-go/retry"
36 "github.com/google/s2a-go/stream"
37 "google.golang.org/grpc"
38 "google.golang.org/grpc/credentials"
39 "google.golang.org/grpc/grpclog"
40 "google.golang.org/protobuf/proto"
41
42 commonpb "github.com/google/s2a-go/internal/proto/v2/common_go_proto"
43 s2av2pb "github.com/google/s2a-go/internal/proto/v2/s2a_go_proto"
44)
45
const (
	// s2aSecurityProtocol is the security protocol reported in the
	// ProtocolInfo of the S2Av2 transport credentials.
	s2aSecurityProtocol = "tls"
	// defaultS2ATimeout is the handshake timeout used when the S2A_TIMEOUT
	// environment variable is unset or cannot be parsed.
	defaultS2ATimeout = 6 * time.Second
	// s2aTimeoutEnv names the environment variable which sets the timeout
	// enforced on the connection to the S2A service for handshake.
	s2aTimeoutEnv = "S2A_TIMEOUT"
)
53
// s2av2TransportCreds implements credentials.TransportCredentials on top of
// the S2Av2 service. A single type serves both roles; isClient selects which
// of ClientHandshake / ServerHandshake may be used.
type s2av2TransportCreds struct {
	// info is the protocol info returned (by value) from Info().
	info *credentials.ProtocolInfo
	// isClient is true for credentials built by NewClientCreds and false for
	// those built by NewServerCreds.
	isClient bool
	// serverName, when non-empty, is set by OverrideServerName and takes
	// precedence over the dial authority in ClientHandshake.
	serverName string
	// s2av2Address is the address of the S2Av2 service passed to createStream.
	s2av2Address string
	// transportCreds are the credentials passed to createStream for the
	// connection to S2Av2 itself.
	transportCreds credentials.TransportCredentials
	// tokenManager is left nil by the constructors when the access token
	// manager could not be created.
	tokenManager *tokenmanager.AccessTokenManager
	// localIdentity should only be used by the client.
	localIdentity *commonpb.Identity
	// localIdentities should only be used by the server.
	localIdentities []*commonpb.Identity
	// verificationMode is forwarded to tlsconfigstore when fetching TLS configs.
	verificationMode s2av2pb.ValidatePeerCertificateChainReq_VerificationMode
	// fallbackClientHandshake, if non-nil, is invoked by ClientHandshake when
	// any S2Av2 step fails.
	fallbackClientHandshake fallback.ClientHandshake
	// getS2AStream, if non-nil, is used by createStream instead of dialing S2Av2.
	getS2AStream func(ctx context.Context, s2av2Address string) (stream.S2AStream, error)
	// serverAuthorizationPolicy is forwarded to
	// tlsconfigstore.GetTLSConfigurationForClient (client side only).
	serverAuthorizationPolicy []byte
}
70
71// NewClientCreds returns a client-side transport credentials object that uses
72// the S2Av2 to establish a secure connection with a server.
73func NewClientCreds(s2av2Address string, transportCreds credentials.TransportCredentials, localIdentity *commonpb.Identity, verificationMode s2av2pb.ValidatePeerCertificateChainReq_VerificationMode, fallbackClientHandshakeFunc fallback.ClientHandshake, getS2AStream func(ctx context.Context, s2av2Address string) (stream.S2AStream, error), serverAuthorizationPolicy []byte) (credentials.TransportCredentials, error) {
74 // Create an AccessTokenManager instance to use to authenticate to S2Av2.
75 accessTokenManager, err := tokenmanager.NewSingleTokenAccessTokenManager()
76
77 creds := &s2av2TransportCreds{
78 info: &credentials.ProtocolInfo{
79 SecurityProtocol: s2aSecurityProtocol,
80 },
81 isClient: true,
82 serverName: "",
83 s2av2Address: s2av2Address,
84 transportCreds: transportCreds,
85 localIdentity: localIdentity,
86 verificationMode: verificationMode,
87 fallbackClientHandshake: fallbackClientHandshakeFunc,
88 getS2AStream: getS2AStream,
89 serverAuthorizationPolicy: serverAuthorizationPolicy,
90 }
91 if err != nil {
92 creds.tokenManager = nil
93 } else {
94 creds.tokenManager = &accessTokenManager
95 }
96 if grpclog.V(1) {
97 grpclog.Info("Created client S2Av2 transport credentials.")
98 }
99 return creds, nil
100}
101
102// NewServerCreds returns a server-side transport credentials object that uses
103// the S2Av2 to establish a secure connection with a client.
104func NewServerCreds(s2av2Address string, transportCreds credentials.TransportCredentials, localIdentities []*commonpb.Identity, verificationMode s2av2pb.ValidatePeerCertificateChainReq_VerificationMode, getS2AStream func(ctx context.Context, s2av2Address string) (stream.S2AStream, error)) (credentials.TransportCredentials, error) {
105 // Create an AccessTokenManager instance to use to authenticate to S2Av2.
106 accessTokenManager, err := tokenmanager.NewSingleTokenAccessTokenManager()
107 creds := &s2av2TransportCreds{
108 info: &credentials.ProtocolInfo{
109 SecurityProtocol: s2aSecurityProtocol,
110 },
111 isClient: false,
112 s2av2Address: s2av2Address,
113 transportCreds: transportCreds,
114 localIdentities: localIdentities,
115 verificationMode: verificationMode,
116 getS2AStream: getS2AStream,
117 }
118 if err != nil {
119 creds.tokenManager = nil
120 } else {
121 creds.tokenManager = &accessTokenManager
122 }
123 if grpclog.V(1) {
124 grpclog.Info("Created server S2Av2 transport credentials.")
125 }
126 return creds, nil
127}
128
129// ClientHandshake performs a client-side mTLS handshake using the S2Av2.
130func (c *s2av2TransportCreds) ClientHandshake(ctx context.Context, serverAuthority string, rawConn net.Conn) (net.Conn, credentials.AuthInfo, error) {
131 if !c.isClient {
132 return nil, nil, errors.New("client handshake called using server transport credentials")
133 }
134 // Remove the port from serverAuthority.
135 serverName := removeServerNamePort(serverAuthority)
136 timeoutCtx, cancel := context.WithTimeout(ctx, GetS2ATimeout())
137 defer cancel()
138 var s2AStream stream.S2AStream
139 var err error
140 retry.Run(timeoutCtx,
141 func() error {
142 s2AStream, err = createStream(timeoutCtx, c.s2av2Address, c.transportCreds, c.getS2AStream)
143 return err
144 })
145 if err != nil {
146 grpclog.Infof("Failed to connect to S2Av2: %v", err)
147 if c.fallbackClientHandshake != nil {
148 return c.fallbackClientHandshake(ctx, serverAuthority, rawConn, err)
149 }
150 return nil, nil, err
151 }
152 defer s2AStream.CloseSend()
153 if grpclog.V(1) {
154 grpclog.Infof("Connected to S2Av2.")
155 }
156 var config *tls.Config
157
158 var tokenManager tokenmanager.AccessTokenManager
159 if c.tokenManager == nil {
160 tokenManager = nil
161 } else {
162 tokenManager = *c.tokenManager
163 }
164
165 sn := serverName
166 if c.serverName != "" {
167 sn = c.serverName
168 }
169 retry.Run(timeoutCtx,
170 func() error {
171 config, err = tlsconfigstore.GetTLSConfigurationForClient(sn, s2AStream, tokenManager, c.localIdentity, c.verificationMode, c.serverAuthorizationPolicy)
172 return err
173 })
174 if err != nil {
175 grpclog.Info("Failed to get client TLS config from S2Av2: %v", err)
176 if c.fallbackClientHandshake != nil {
177 return c.fallbackClientHandshake(ctx, serverAuthority, rawConn, err)
178 }
179 return nil, nil, err
180 }
181 if grpclog.V(1) {
182 grpclog.Infof("Got client TLS config from S2Av2.")
183 }
184
185 creds := credentials.NewTLS(config)
186 conn, authInfo, err := creds.ClientHandshake(timeoutCtx, serverName, rawConn)
187 if err != nil {
188 grpclog.Infof("Failed to do client handshake using S2Av2: %v", err)
189 if c.fallbackClientHandshake != nil {
190 return c.fallbackClientHandshake(ctx, serverAuthority, rawConn, err)
191 }
192 return nil, nil, err
193 }
194 grpclog.Infof("client-side handshake is done using S2Av2 to: %s", serverName)
195
196 return conn, authInfo, err
197}
198
199// ServerHandshake performs a server-side mTLS handshake using the S2Av2.
200func (c *s2av2TransportCreds) ServerHandshake(rawConn net.Conn) (net.Conn, credentials.AuthInfo, error) {
201 if c.isClient {
202 return nil, nil, errors.New("server handshake called using client transport credentials")
203 }
204 ctx, cancel := context.WithTimeout(context.Background(), GetS2ATimeout())
205 defer cancel()
206 var s2AStream stream.S2AStream
207 var err error
208 retry.Run(ctx,
209 func() error {
210 s2AStream, err = createStream(ctx, c.s2av2Address, c.transportCreds, c.getS2AStream)
211 return err
212 })
213 if err != nil {
214 grpclog.Infof("Failed to connect to S2Av2: %v", err)
215 return nil, nil, err
216 }
217 defer s2AStream.CloseSend()
218 if grpclog.V(1) {
219 grpclog.Infof("Connected to S2Av2.")
220 }
221
222 var tokenManager tokenmanager.AccessTokenManager
223 if c.tokenManager == nil {
224 tokenManager = nil
225 } else {
226 tokenManager = *c.tokenManager
227 }
228
229 var config *tls.Config
230 retry.Run(ctx,
231 func() error {
232 config, err = tlsconfigstore.GetTLSConfigurationForServer(s2AStream, tokenManager, c.localIdentities, c.verificationMode)
233 return err
234 })
235 if err != nil {
236 grpclog.Infof("Failed to get server TLS config from S2Av2: %v", err)
237 return nil, nil, err
238 }
239 if grpclog.V(1) {
240 grpclog.Infof("Got server TLS config from S2Av2.")
241 }
242
243 creds := credentials.NewTLS(config)
244 conn, authInfo, err := creds.ServerHandshake(rawConn)
245 if err != nil {
246 grpclog.Infof("Failed to do server handshake using S2Av2: %v", err)
247 return nil, nil, err
248 }
249 return conn, authInfo, err
250}
251
252// Info returns protocol info of s2av2TransportCreds.
253func (c *s2av2TransportCreds) Info() credentials.ProtocolInfo {
254 return *c.info
255}
256
257// Clone makes a deep copy of s2av2TransportCreds.
258func (c *s2av2TransportCreds) Clone() credentials.TransportCredentials {
259 info := *c.info
260 serverName := c.serverName
261 fallbackClientHandshake := c.fallbackClientHandshake
262
263 s2av2Address := c.s2av2Address
264 var tokenManager tokenmanager.AccessTokenManager
265 if c.tokenManager == nil {
266 tokenManager = nil
267 } else {
268 tokenManager = *c.tokenManager
269 }
270 verificationMode := c.verificationMode
271 var localIdentity *commonpb.Identity
272 if c.localIdentity != nil {
273 localIdentity = proto.Clone(c.localIdentity).(*commonpb.Identity)
274 }
275 var localIdentities []*commonpb.Identity
276 if c.localIdentities != nil {
277 localIdentities = make([]*commonpb.Identity, len(c.localIdentities))
278 for i, localIdentity := range c.localIdentities {
279 localIdentities[i] = proto.Clone(localIdentity).(*commonpb.Identity)
280 }
281 }
282 creds := &s2av2TransportCreds{
283 info: &info,
284 isClient: c.isClient,
285 serverName: serverName,
286 fallbackClientHandshake: fallbackClientHandshake,
287 s2av2Address: s2av2Address,
288 localIdentity: localIdentity,
289 localIdentities: localIdentities,
290 verificationMode: verificationMode,
291 }
292 if c.tokenManager == nil {
293 creds.tokenManager = nil
294 } else {
295 creds.tokenManager = &tokenManager
296 }
297 return creds
298}
299
300// NewClientTLSConfig returns a tls.Config instance that uses S2Av2 to establish a TLS connection as
301// a client. The tls.Config MUST only be used to establish a single TLS connection.
302func NewClientTLSConfig(
303 ctx context.Context,
304 s2av2Address string,
305 transportCreds credentials.TransportCredentials,
306 tokenManager tokenmanager.AccessTokenManager,
307 verificationMode s2av2pb.ValidatePeerCertificateChainReq_VerificationMode,
308 serverName string,
309 serverAuthorizationPolicy []byte) (*tls.Config, error) {
310 s2AStream, err := createStream(ctx, s2av2Address, transportCreds, nil)
311 if err != nil {
312 grpclog.Infof("Failed to connect to S2Av2: %v", err)
313 return nil, err
314 }
315
316 return tlsconfigstore.GetTLSConfigurationForClient(removeServerNamePort(serverName), s2AStream, tokenManager, nil, verificationMode, serverAuthorizationPolicy)
317}
318
319// OverrideServerName sets the ServerName in the s2av2TransportCreds protocol
320// info. The ServerName MUST be a hostname.
321func (c *s2av2TransportCreds) OverrideServerName(serverNameOverride string) error {
322 serverName := removeServerNamePort(serverNameOverride)
323 c.info.ServerName = serverName
324 c.serverName = serverName
325 return nil
326}
327
// removeServerNamePort strips a trailing ":port" from serverName. When
// net.SplitHostPort cannot split it (e.g. there is no port), serverName is
// returned unchanged. Bracketed IPv6 literals with a port lose their brackets.
func removeServerNamePort(serverName string) string {
	if host, _, err := net.SplitHostPort(serverName); err == nil {
		return host
	}
	return serverName
}
336
337type s2AGrpcStream struct {
338 stream s2av2pb.S2AService_SetUpSessionClient
339}
340
341func (x s2AGrpcStream) Send(m *s2av2pb.SessionReq) error {
342 return x.stream.Send(m)
343}
344
345func (x s2AGrpcStream) Recv() (*s2av2pb.SessionResp, error) {
346 return x.stream.Recv()
347}
348
349func (x s2AGrpcStream) CloseSend() error {
350 return x.stream.CloseSend()
351}
352
353func createStream(ctx context.Context, s2av2Address string, transportCreds credentials.TransportCredentials, getS2AStream func(ctx context.Context, s2av2Address string) (stream.S2AStream, error)) (stream.S2AStream, error) {
354 if getS2AStream != nil {
355 return getS2AStream(ctx, s2av2Address)
356 }
357 // TODO(rmehta19): Consider whether to close the connection to S2Av2.
358 conn, err := service.Dial(ctx, s2av2Address, transportCreds)
359 if err != nil {
360 return nil, err
361 }
362 client := s2av2pb.NewS2AServiceClient(conn)
363 gRPCStream, err := client.SetUpSession(ctx, []grpc.CallOption{}...)
364 if err != nil {
365 return nil, err
366 }
367 return &s2AGrpcStream{
368 stream: gRPCStream,
369 }, nil
370}
371
372// GetS2ATimeout returns the timeout enforced on the connection to the S2A service for handshake.
373func GetS2ATimeout() time.Duration {
374 timeout, err := time.ParseDuration(os.Getenv(s2aTimeoutEnv))
375 if err != nil {
376 return defaultS2ATimeout
377 }
378 return timeout
379}