service.go

 1/*
 2 *
 3 * Copyright 2021 Google LLC
 4 *
 5 * Licensed under the Apache License, Version 2.0 (the "License");
 6 * you may not use this file except in compliance with the License.
 7 * You may obtain a copy of the License at
 8 *
 9 *     https://www.apache.org/licenses/LICENSE-2.0
10 *
11 * Unless required by applicable law or agreed to in writing, software
12 * distributed under the License is distributed on an "AS IS" BASIS,
13 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 * See the License for the specific language governing permissions and
15 * limitations under the License.
16 *
17 */
18
19// Package service is a utility for calling the S2A handshaker service.
20package service
21
22import (
23	"context"
24	"sync"
25
26	grpc "google.golang.org/grpc"
27	"google.golang.org/grpc/credentials"
28	"google.golang.org/grpc/credentials/insecure"
29)
30
31var (
32	// mu guards hsConnMap and hsDialer.
33	mu sync.Mutex
34	// hsConnMap represents a mapping from an S2A handshaker service address
35	// to a corresponding connection to an S2A handshaker service instance.
36	hsConnMap = make(map[string]*grpc.ClientConn)
37	// hsDialer will be reassigned in tests.
38	hsDialer = grpc.DialContext
39)
40
41// Dial dials the S2A handshaker service. If a connection has already been
42// established, this function returns it. Otherwise, a new connection is
43// created.
44func Dial(ctx context.Context, handshakerServiceAddress string, transportCreds credentials.TransportCredentials) (*grpc.ClientConn, error) {
45	mu.Lock()
46	defer mu.Unlock()
47
48	hsConn, ok := hsConnMap[handshakerServiceAddress]
49	if !ok {
50		// Create a new connection to the S2A handshaker service. Note that
51		// this connection stays open until the application is closed.
52		var grpcOpts []grpc.DialOption
53		if transportCreds != nil {
54			grpcOpts = append(grpcOpts, grpc.WithTransportCredentials(transportCreds))
55		} else {
56			grpcOpts = append(grpcOpts, grpc.WithTransportCredentials(insecure.NewCredentials()))
57		}
58		var err error
59		hsConn, err = hsDialer(ctx, handshakerServiceAddress, grpcOpts...)
60		if err != nil {
61			return nil, err
62		}
63		hsConnMap[handshakerServiceAddress] = hsConn
64	}
65	return hsConn, nil
66}