1// Copyright 2024 Google LLC
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7// http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15package genai
16
17import (
18 "bufio"
19 "bytes"
20 "context"
21 "encoding/json"
22 "fmt"
23 "io"
24 "iter"
25 "log"
26 "net/http"
27 "net/url"
28 "runtime"
29 "strconv"
30 "strings"
31 "time"
32)
33
// maxChunkSize is the number of bytes uploadFile reads from its input and
// sends in a single upload request: 8 MiB.
const maxChunkSize = 8 << 20
35
36type apiClient struct {
37 clientConfig *ClientConfig
38}
39
40// sendStreamRequest issues an server streaming API request and returns a map of the response contents.
41func sendStreamRequest[T responseStream[R], R any](ctx context.Context, ac *apiClient, path string, method string, body map[string]any, httpOptions *HTTPOptions, output *responseStream[R]) error {
42 req, err := buildRequest(ctx, ac, path, body, method, httpOptions)
43 if err != nil {
44 return err
45 }
46
47 resp, err := doRequest(ac, req)
48 if err != nil {
49 return err
50 }
51
52 // resp.Body will be closed by the iterator
53 return deserializeStreamResponse(resp, output)
54}
55
56// sendRequest issues an API request and returns a map of the response contents.
57func sendRequest(ctx context.Context, ac *apiClient, path string, method string, body map[string]any, httpOptions *HTTPOptions) (map[string]any, error) {
58 req, err := buildRequest(ctx, ac, path, body, method, httpOptions)
59 if err != nil {
60 return nil, err
61 }
62
63 resp, err := doRequest(ac, req)
64 if err != nil {
65 return nil, err
66 }
67 defer resp.Body.Close()
68
69 return deserializeUnaryResponse(resp)
70}
71
72func downloadFile(ctx context.Context, ac *apiClient, path string, httpOptions *HTTPOptions) ([]byte, error) {
73 req, err := buildRequest(ctx, ac, path, nil, http.MethodGet, httpOptions)
74 if err != nil {
75 return nil, err
76 }
77
78 resp, err := doRequest(ac, req)
79 if err != nil {
80 return nil, err
81 }
82 return io.ReadAll(resp.Body)
83}
84
// mapToStruct populates the value pointed to by output from input by encoding
// the map as JSON and then decoding that JSON into output. Field matching
// therefore follows the usual encoding/json rules (struct tags included).
//
// It returns a wrapped error if either the encoding or the decoding step fails.
func mapToStruct[R any](input map[string]any, output *R) error {
	var encoded bytes.Buffer
	if err := json.NewEncoder(&encoded).Encode(input); err != nil {
		return fmt.Errorf("mapToStruct: error encoding input %#v: %w", input, err)
	}
	if err := json.Unmarshal(encoded.Bytes(), output); err != nil {
		return fmt.Errorf("mapToStruct: error unmarshalling input %#v: %w", input, err)
	}
	return nil
}
97
98func (ac *apiClient) createAPIURL(suffix, method string, httpOptions *HTTPOptions) (*url.URL, error) {
99 if ac.clientConfig.Backend == BackendVertexAI {
100 queryVertexBaseModel := ac.clientConfig.Backend == BackendVertexAI && method == http.MethodGet && strings.HasPrefix(suffix, "publishers/google/models")
101 if !strings.HasPrefix(suffix, "projects/") && !queryVertexBaseModel {
102 suffix = fmt.Sprintf("projects/%s/locations/%s/%s", ac.clientConfig.Project, ac.clientConfig.Location, suffix)
103 }
104 u, err := url.Parse(fmt.Sprintf("%s/%s/%s", httpOptions.BaseURL, httpOptions.APIVersion, suffix))
105 if err != nil {
106 return nil, fmt.Errorf("createAPIURL: error parsing Vertex AI URL: %w", err)
107 }
108 return u, nil
109 } else {
110 if !strings.Contains(suffix, fmt.Sprintf("/%s/", httpOptions.APIVersion)) {
111 suffix = fmt.Sprintf("%s/%s", httpOptions.APIVersion, suffix)
112 }
113 u, err := url.Parse(fmt.Sprintf("%s/%s", httpOptions.BaseURL, suffix))
114 if err != nil {
115 return nil, fmt.Errorf("createAPIURL: error parsing ML Dev URL: %w", err)
116 }
117 return u, nil
118 }
119}
120
121func buildRequest(ctx context.Context, ac *apiClient, path string, body map[string]any, method string, httpOptions *HTTPOptions) (*http.Request, error) {
122 url, err := ac.createAPIURL(path, method, httpOptions)
123 if err != nil {
124 return nil, err
125 }
126 b := new(bytes.Buffer)
127 if len(body) > 0 {
128 if err := json.NewEncoder(b).Encode(body); err != nil {
129 return nil, fmt.Errorf("buildRequest: error encoding body %#v: %w", body, err)
130 }
131 }
132
133 // Create a new HTTP request
134 req, err := http.NewRequestWithContext(ctx, method, url.String(), b)
135 if err != nil {
136 return nil, err
137 }
138 // Set headers
139 doMergeHeaders(httpOptions.Headers, &req.Header)
140 doMergeHeaders(sdkHeader(ctx, ac), &req.Header)
141 return req, nil
142}
143
144func sdkHeader(ctx context.Context, ac *apiClient) http.Header {
145 header := make(http.Header)
146 header.Set("Content-Type", "application/json")
147 if ac.clientConfig.APIKey != "" {
148 header.Set("x-goog-api-key", ac.clientConfig.APIKey)
149 }
150 libraryLabel := fmt.Sprintf("google-genai-sdk/%s", version)
151 languageLabel := fmt.Sprintf("gl-go/%s", runtime.Version())
152 versionHeaderValue := fmt.Sprintf("%s %s", libraryLabel, languageLabel)
153 header.Set("user-agent", versionHeaderValue)
154 header.Set("x-goog-api-client", versionHeaderValue)
155 timeoutSeconds := inferTimeout(ctx, ac).Seconds()
156 if timeoutSeconds > 0 {
157 header.Set("x-server-timeout", strconv.FormatInt(int64(timeoutSeconds), 10))
158 }
159 return header
160}
161
162func inferTimeout(ctx context.Context, ac *apiClient) time.Duration {
163 // ac.clientConfig.HTTPClient is not nil because it's initialized in the NewClient function.
164 requestTimeout := ac.clientConfig.HTTPClient.Timeout
165 contextTimeout := 0 * time.Second
166 if deadline, ok := ctx.Deadline(); ok {
167 contextTimeout = time.Until(deadline)
168 }
169 if requestTimeout != 0 && contextTimeout != 0 {
170 return min(requestTimeout, contextTimeout)
171 }
172 if requestTimeout != 0 {
173 return requestTimeout
174 }
175 return contextTimeout
176}
177
178func doRequest(ac *apiClient, req *http.Request) (*http.Response, error) {
179 // Create a new HTTP client and send the request
180 client := ac.clientConfig.HTTPClient
181 resp, err := client.Do(req)
182 if err != nil {
183 return nil, fmt.Errorf("doRequest: error sending request: %w", err)
184 }
185 return resp, nil
186}
187
188func deserializeUnaryResponse(resp *http.Response) (map[string]any, error) {
189 if !httpStatusOk(resp) {
190 return nil, newAPIError(resp)
191 }
192 respBody, err := io.ReadAll(resp.Body)
193 if err != nil {
194 return nil, err
195 }
196
197 output := make(map[string]any)
198 if len(respBody) > 0 {
199 err = json.Unmarshal(respBody, &output)
200 if err != nil {
201 return nil, fmt.Errorf("deserializeUnaryResponse: error unmarshalling response: %w\n%s", err, respBody)
202 }
203 }
204 output["httpHeaders"] = resp.Header
205 return output, nil
206}
207
208type responseStream[R any] struct {
209 r *bufio.Scanner
210 rc io.ReadCloser
211}
212
213func iterateResponseStream[R any](rs *responseStream[R], responseConverter func(responseMap map[string]any) (*R, error)) iter.Seq2[*R, error] {
214 return func(yield func(*R, error) bool) {
215 defer func() {
216 // Close the response body range over function is done.
217 if err := rs.rc.Close(); err != nil {
218 log.Printf("Error closing response body: %v", err)
219 }
220 }()
221 for rs.r.Scan() {
222 line := rs.r.Bytes()
223 if len(line) == 0 {
224 continue
225 }
226 prefix, data, _ := bytes.Cut(line, []byte(":"))
227 switch string(prefix) {
228 case "data":
229 // Step 1: Unmarshal the JSON into a map[string]any so that we can call fromConverter
230 // in Step 2.
231 respRaw := make(map[string]any)
232 if err := json.Unmarshal(data, &respRaw); err != nil {
233 err = fmt.Errorf("iterateResponseStream: error unmarshalling data %s:%s. error: %w", string(prefix), string(data), err)
234 if !yield(nil, err) {
235 return
236 }
237 }
238 // Step 2: The toStruct function calls fromConverter(handle Vertex and MLDev schema
239 // difference and get a unified response). Then toStruct function converts the unified
240 // response from map[string]any to struct type.
241 // var resp = new(R)
242 resp, err := responseConverter(respRaw)
243 if err != nil {
244 if !yield(nil, err) {
245 return
246 }
247 }
248
249 // Step 3: yield the response.
250 if !yield(resp, nil) {
251 return
252 }
253 default:
254 // Stream chunk not started with "data" is treated as an error.
255 if !yield(nil, fmt.Errorf("iterateResponseStream: invalid stream chunk: %s:%s", string(prefix), string(data))) {
256 return
257 }
258 }
259 }
260 if rs.r.Err() != nil {
261 if rs.r.Err() == bufio.ErrTooLong {
262 log.Printf("The response is too large to process in streaming mode. Please use a non-streaming method.")
263 }
264 log.Printf("Error %v", rs.r.Err())
265 }
266 }
267}
268
// APIError contains an error response from the server.
type APIError struct {
	// Code is the HTTP response status code.
	Code int `json:"code,omitempty"`
	// Message is the server response message.
	Message string `json:"message,omitempty"`
	// Status is the server response status.
	Status string `json:"status,omitempty"`
	// Details field provides more context to an error.
	Details []map[string]any `json:"details,omitempty"`
}

// responseWithError mirrors a JSON error envelope of the form
// {"error": {...}}.
type responseWithError struct {
	ErrorInfo *APIError `json:"error,omitempty"`
}

// newAPIError converts a non-successful HTTP response into an error.
//
//   - If the body is a JSON document containing an "error" object, that object
//     is returned as an APIError.
//   - If the body is not valid JSON (for example a plain-text message), or is
//     JSON without an "error" object, an APIError is built from the HTTP
//     status with the raw body as its message.
//   - If the body is empty, an APIError is built from the HTTP status alone.
//   - If the body cannot be read, a wrapped read error is returned instead.
//
// The body is read but not closed.
func newAPIError(resp *http.Response) error {
	body, err := io.ReadAll(resp.Body)
	if err != nil {
		return fmt.Errorf("newAPIError: error reading response body: %w. Response: %v", err, string(body))
	}
	if len(body) == 0 {
		return APIError{Code: resp.StatusCode, Status: resp.Status}
	}

	var parsed responseWithError
	if err := json.Unmarshal(body, &parsed); err != nil || parsed.ErrorInfo == nil {
		// Either a plain-text error message (the file upload backend does not
		// return JSON errors) or JSON that carries no "error" object.
		// Dereferencing ErrorInfo in the latter case would panic.
		return APIError{Code: resp.StatusCode, Status: resp.Status, Message: string(body)}
	}
	return *parsed.ErrorInfo
}

// Error returns a string representation of the APIError.
func (e APIError) Error() string {
	return fmt.Sprintf(
		"Error %d, Message: %s, Status: %s, Details: %v",
		e.Code, e.Message, e.Status, e.Details,
	)
}
309
// httpStatusOk reports whether the response carries a 2xx status code.
func httpStatusOk(resp *http.Response) bool {
	code := resp.StatusCode
	return code >= http.StatusOK && code < http.StatusMultipleChoices
}
313
314func deserializeStreamResponse[T responseStream[R], R any](resp *http.Response, output *responseStream[R]) error {
315 if !httpStatusOk(resp) {
316 return newAPIError(resp)
317 }
318 output.r = bufio.NewScanner(resp.Body)
319 // Scanner default buffer max size is 64*1024 (64KB).
320 // We provide 1KB byte buffer to the scanner and set max to 256MB.
321 // When data exceed 1KB, then scanner will allocate new memory up to 256MB.
322 // When data exceed 256MB, scanner will stop and returns err: bufio.ErrTooLong.
323 output.r.Buffer(make([]byte, 1024), 268435456)
324
325 output.r.Split(scan)
326 output.rc = resp.Body
327 return nil
328}
329
330// dropCR drops a terminal \r from the data.
331func dropCR(data []byte) []byte {
332 if len(data) > 0 && data[len(data)-1] == '\r' {
333 return data[0 : len(data)-1]
334 }
335 return data
336}
337
338func scan(data []byte, atEOF bool) (advance int, token []byte, err error) {
339 if atEOF && len(data) == 0 {
340 return 0, nil, nil
341 }
342 // Look for two consecutive newlines in the data
343 if i := bytes.Index(data, []byte("\n\n")); i >= 0 {
344 // We have a full two-newline-terminated token.
345 return i + 2, dropCR(data[0:i]), nil
346 }
347
348 // Handle the case of Windows-style newlines (\r\n\r\n)
349 if i := bytes.Index(data, []byte("\r\n\r\n")); i >= 0 {
350 // We have a full Windows-style two-newline-terminated token.
351 return i + 4, dropCR(data[0:i]), nil
352 }
353
354 // If we're at EOF, we have a final, non-terminated line. Return it.
355 if atEOF {
356 return len(data), dropCR(data), nil
357 }
358 // Request more data.
359 return 0, nil, nil
360}
361
362func (ac *apiClient) uploadFile(ctx context.Context, r io.Reader, uploadURL string, httpOptions *HTTPOptions) (*File, error) {
363 var offset int64 = 0
364 var resp *http.Response
365 var respBody map[string]any
366 var uploadCommand = "upload"
367
368 buffer := make([]byte, maxChunkSize)
369 for {
370 bytesRead, err := io.ReadFull(r, buffer)
371 // Check both EOF and UnexpectedEOF errors.
372 // ErrUnexpectedEOF: Reading a file file_size%maxChunkSize<len(buffer).
373 // EOF: Reading a file file_size%maxChunkSize==0. The underlying reader return 0 bytes buffer and EOF at next call.
374 if err == io.EOF || err == io.ErrUnexpectedEOF {
375 uploadCommand += ", finalize"
376 } else if err != nil {
377 return nil, fmt.Errorf("Failed to read bytes from file at offset %d: %w. Bytes actually read: %d", offset, err, bytesRead)
378 }
379
380 req, err := http.NewRequestWithContext(ctx, http.MethodPost, uploadURL, bytes.NewReader(buffer[:bytesRead]))
381 if err != nil {
382 return nil, fmt.Errorf("Failed to create upload request for chunk at offset %d: %w", offset, err)
383 }
384 doMergeHeaders(httpOptions.Headers, &req.Header)
385 doMergeHeaders(sdkHeader(ctx, ac), &req.Header)
386
387 req.Header.Set("X-Goog-Upload-Command", uploadCommand)
388 req.Header.Set("X-Goog-Upload-Offset", strconv.FormatInt(offset, 10))
389 req.Header.Set("Content-Length", strconv.FormatInt(int64(bytesRead), 10))
390
391 resp, err = doRequest(ac, req)
392 if err != nil {
393 return nil, fmt.Errorf("upload request failed for chunk at offset %d: %w", offset, err)
394 }
395 defer resp.Body.Close()
396
397 respBody, err = deserializeUnaryResponse(resp)
398 if err != nil {
399 return nil, fmt.Errorf("response body is invalid for chunk at offset %d: %w", offset, err)
400 }
401
402 offset += int64(bytesRead)
403
404 uploadStatus := resp.Header.Get("X-Goog-Upload-Status")
405
406 if uploadStatus != "final" && strings.Contains(uploadCommand, "finalize") {
407 return nil, fmt.Errorf("send finalize command but doesn't receive final status. Offset %d, Bytes read: %d, Upload status: %s", offset, bytesRead, uploadStatus)
408 }
409 if uploadStatus != "active" {
410 // Upload is complete ('final') or interrupted ('cancelled', etc.)
411 break
412 }
413 }
414
415 if resp == nil {
416 return nil, fmt.Errorf("Upload request failed. No response received")
417 }
418
419 finalUploadStatus := resp.Header.Get("X-Goog-Upload-Status")
420 if finalUploadStatus != "final" {
421 return nil, fmt.Errorf("Failed to upload file: Upload status is not finalized")
422 }
423
424 var response = new(File)
425 err := mapToStruct(respBody["file"].(map[string]any), &response)
426 if err != nil {
427 return nil, err
428 }
429 return response, nil
430}