client.go

  1// Copyright 2024 Google LLC
  2//
  3// Licensed under the Apache License, Version 2.0 (the "License");
  4// you may not use this file except in compliance with the License.
  5// You may obtain a copy of the License at
  6//
  7//      http://www.apache.org/licenses/LICENSE-2.0
  8//
  9// Unless required by applicable law or agreed to in writing, software
 10// distributed under the License is distributed on an "AS IS" BASIS,
 11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 12// See the License for the specific language governing permissions and
 13// limitations under the License.
 14
 15package genai
 16
 17import (
 18	"context"
 19	"fmt"
 20	"net/http"
 21	"os"
 22	"strings"
 23
 24	"cloud.google.com/go/auth"
 25	"cloud.google.com/go/auth/credentials"
 26	"cloud.google.com/go/auth/httptransport"
 27)
 28
// Client is the GenAI client. It provides access to the various GenAI services.
//
// Create a Client with NewClient, which populates every service field below.
type Client struct {
	// clientConfig is the configuration the client was built with. It is
	// exposed, by value, through the ClientConfig method.
	clientConfig ClientConfig
	// Models provides access to the Models service.
	Models *Models
	// Live provides access to the Live service.
	Live *Live
	// Caches provides access to the Caches service.
	Caches *Caches
	// Chats provides utility functions for creating a new chat session.
	Chats *Chats
	// Files provides access to the Files service.
	Files *Files
	// Operations provides access to long-running operations.
	Operations *Operations
}
 45
 46// Backend is the GenAI backend to use for the client.
 47type Backend int
 48
 49const (
 50	// BackendUnspecified causes the backend determined automatically. If the
 51	// GOOGLE_GENAI_USE_VERTEXAI environment variable is set to "1" or "true", then
 52	// the backend is BackendVertexAI. Otherwise, if GOOGLE_GENAI_USE_VERTEXAI
 53	// is unset or set to any other value, then BackendGeminiAPI is used.  Explicitly
 54	// setting the backend in ClientConfig overrides the environment variable.
 55	BackendUnspecified Backend = iota
 56	// BackendGeminiAPI is the Gemini API backend.
 57	BackendGeminiAPI
 58	// BackendVertexAI is the Vertex AI backend.
 59	BackendVertexAI
 60)
 61
 62// The Stringer interface for Backend.
 63func (t Backend) String() string {
 64	switch t {
 65	case BackendGeminiAPI:
 66		return "BackendGeminiAPI"
 67	case BackendVertexAI:
 68		return "BackendVertexAI"
 69	default:
 70		return "BackendUnspecified"
 71	}
 72}
 73
// ClientConfig is the configuration for the GenAI client.
//
// Fields left at their zero value are filled in by NewClient, either from
// environment variables or from built-in defaults.
type ClientConfig struct {
	// Optional. API Key for GenAI. Required for BackendGeminiAPI.
	// Can also be set via the GOOGLE_API_KEY environment variable.
	// NewClient rejects a config that sets APIKey together with Project or
	// Location.
	// Get a Gemini API key: https://ai.google.dev/gemini-api/docs/api-key
	APIKey string

	// Optional. Backend for GenAI. See Backend constants. Defaults to
	// BackendGeminiAPI unless explicitly set to BackendVertexAI, or the
	// environment variable GOOGLE_GENAI_USE_VERTEXAI is set to "1" or "true".
	Backend Backend

	// Optional. GCP Project ID for Vertex AI. Required for BackendVertexAI.
	// Can also be set via the GOOGLE_CLOUD_PROJECT environment variable.
	// Find your Project ID: https://cloud.google.com/resource-manager/docs/creating-managing-projects#identifying_projects
	Project string

	// Optional. GCP Location/Region for Vertex AI. Required for BackendVertexAI.
	// Can also be set via the GOOGLE_CLOUD_LOCATION or GOOGLE_CLOUD_REGION
	// environment variable; GOOGLE_CLOUD_LOCATION is consulted first.
	// Generative AI locations: https://cloud.google.com/vertex-ai/generative-ai/docs/learn/locations.
	Location string

	// Optional. Google credentials. If nil and the backend is BackendVertexAI,
	// NewClient looks up [Application Default Credentials].
	//
	// [Application Default Credentials]: https://developers.google.com/accounts/docs/application-default-credentials
	Credentials *auth.Credentials

	// Optional HTTP client to use. If nil, a default client will be created.
	// For Vertex AI, this client must handle authentication appropriately.
	HTTPClient *http.Client

	// Optional HTTP options to override. NewClient fills in a default BaseURL
	// and APIVersion when they are left empty.
	HTTPOptions HTTPOptions

	// envVarProvider returns the environment variables NewClient reads. When
	// it is nil, NewClient falls back to defaultEnvVarProvider.
	envVarProvider func() map[string]string
}
109
// defaultEnvVarProvider snapshots the GenAI-related variables of the process
// environment. Only variables that are actually set appear in the returned
// map; a variable that is set to the empty string is included with an empty
// value.
func defaultEnvVarProvider() map[string]string {
	names := [...]string{
		"GOOGLE_GENAI_USE_VERTEXAI",
		"GOOGLE_API_KEY",
		"GOOGLE_CLOUD_PROJECT",
		"GOOGLE_CLOUD_LOCATION",
		"GOOGLE_CLOUD_REGION",
		"GOOGLE_GEMINI_BASE_URL",
		"GOOGLE_VERTEX_BASE_URL",
	}

	snapshot := make(map[string]string)
	for _, name := range names {
		value, present := os.LookupEnv(name)
		if !present {
			continue
		}
		snapshot[name] = value
	}
	return snapshot
}
135
136// NewClient creates a new GenAI client.
137//
138// You can configure the client by passing in a ClientConfig struct.
139//
140// If a nil ClientConfig is provided, the client will be configured using
141// default settings and environment variables:
142//
143//   - Environment Variables for BackendGeminiAPI:
144//
145//   - GOOGLE_API_KEY: Required. Specifies the API key for the Gemini API.
146//
147//   - Environment Variables for BackendVertexAI:
148//
149//   - GOOGLE_GENAI_USE_VERTEXAI: Must be set to "1" or "true" to use the Vertex AI
150//     backend.
151//
152//   - GOOGLE_CLOUD_PROJECT: Required. Specifies the GCP project ID.
153//
154//   - GOOGLE_CLOUD_LOCATION or GOOGLE_CLOUD_REGION: Required. Specifies the GCP
155//     location/region.
156//
157// If using the Vertex AI backend and no credentials are provided in the
158// ClientConfig, the client will attempt to use application default credentials.
159func NewClient(ctx context.Context, cc *ClientConfig) (*Client, error) {
160	if cc == nil {
161		cc = &ClientConfig{}
162	}
163
164	if cc.envVarProvider == nil {
165		cc.envVarProvider = defaultEnvVarProvider
166	}
167	envVars := cc.envVarProvider()
168
169	if cc.Project != "" && cc.APIKey != "" {
170		return nil, fmt.Errorf("project and API key are mutually exclusive in the client initializer. ClientConfig: %#v", cc)
171	}
172	if cc.Location != "" && cc.APIKey != "" {
173		return nil, fmt.Errorf("location and API key are mutually exclusive in the client initializer. ClientConfig: %#v", cc)
174	}
175
176	if cc.Backend == BackendUnspecified {
177		if v, ok := envVars["GOOGLE_GENAI_USE_VERTEXAI"]; ok {
178			v = strings.ToLower(v)
179			if v == "1" || v == "true" {
180				cc.Backend = BackendVertexAI
181			} else {
182				cc.Backend = BackendGeminiAPI
183			}
184		} else {
185			cc.Backend = BackendGeminiAPI
186		}
187	}
188
189	// Only set the API key for MLDev API.
190	if cc.APIKey == "" && cc.Backend == BackendGeminiAPI {
191		cc.APIKey = envVars["GOOGLE_API_KEY"]
192	}
193	if cc.Project == "" {
194		cc.Project = envVars["GOOGLE_CLOUD_PROJECT"]
195	}
196	if cc.Location == "" {
197		if location, ok := envVars["GOOGLE_CLOUD_LOCATION"]; ok {
198			cc.Location = location
199		} else if location, ok := envVars["GOOGLE_CLOUD_REGION"]; ok {
200			cc.Location = location
201		}
202	}
203
204	if cc.Backend == BackendVertexAI {
205		if cc.Project == "" {
206			return nil, fmt.Errorf("project is required for Vertex AI backend. ClientConfig: %#v", cc)
207		}
208		if cc.Location == "" {
209			return nil, fmt.Errorf("location is required for Vertex AI backend. ClientConfig: %#v", cc)
210		}
211	} else {
212		if cc.APIKey == "" {
213			return nil, fmt.Errorf("api key is required for Google AI backend. ClientConfig: %#v.\nYou can get the API key from https://ai.google.dev/gemini-api/docs/api-key", cc)
214		}
215	}
216
217	if cc.Backend == BackendVertexAI && cc.Credentials == nil {
218		cred, err := credentials.DetectDefault(&credentials.DetectOptions{
219			Scopes: []string{"https://www.googleapis.com/auth/cloud-platform"},
220		})
221		if err != nil {
222			return nil, fmt.Errorf("failed to find default credentials: %w", err)
223		}
224		cc.Credentials = cred
225	}
226
227	baseURL := getBaseURL(cc.Backend, &cc.HTTPOptions, envVars)
228	if baseURL != "" {
229		cc.HTTPOptions.BaseURL = baseURL
230	}
231	if cc.HTTPOptions.BaseURL == "" && cc.Backend == BackendVertexAI {
232		if cc.Location == "global" {
233			cc.HTTPOptions.BaseURL = "https://aiplatform.googleapis.com/"
234		} else {
235			cc.HTTPOptions.BaseURL = fmt.Sprintf("https://%s-aiplatform.googleapis.com/", cc.Location)
236		}
237	} else if cc.HTTPOptions.BaseURL == "" {
238		cc.HTTPOptions.BaseURL = "https://generativelanguage.googleapis.com/"
239	}
240
241	if cc.HTTPOptions.APIVersion == "" && cc.Backend == BackendVertexAI {
242		cc.HTTPOptions.APIVersion = "v1beta1"
243	} else if cc.HTTPOptions.APIVersion == "" {
244		cc.HTTPOptions.APIVersion = "v1beta"
245	}
246
247	if cc.HTTPClient == nil {
248		if cc.Backend == BackendVertexAI {
249			quotaProjectID, err := cc.Credentials.QuotaProjectID(ctx)
250			if err != nil {
251				return nil, fmt.Errorf("failed to get quota project ID: %w", err)
252			}
253			client, err := httptransport.NewClient(&httptransport.Options{
254				Credentials: cc.Credentials,
255				Headers: http.Header{
256					"X-Goog-User-Project": []string{quotaProjectID},
257				},
258			})
259			if err != nil {
260				return nil, fmt.Errorf("failed to create HTTP client: %w", err)
261			}
262			cc.HTTPClient = client
263		} else {
264			cc.HTTPClient = &http.Client{}
265		}
266	}
267
268	ac := &apiClient{clientConfig: cc}
269	c := &Client{
270		clientConfig: *cc,
271		Models:       &Models{apiClient: ac},
272		Live:         &Live{apiClient: ac},
273		Caches:       &Caches{apiClient: ac},
274		Chats:        &Chats{apiClient: ac},
275		Operations:   &Operations{apiClient: ac},
276		Files:        &Files{apiClient: ac},
277	}
278	return c, nil
279}
280
281// ClientConfig returns the ClientConfig for the client.
282//
283// The returned ClientConfig is a copy of the ClientConfig used to create the client.
284func (c Client) ClientConfig() ClientConfig {
285	return c.clientConfig
286}