api_client.go

  1// Copyright 2024 Google LLC
  2//
  3// Licensed under the Apache License, Version 2.0 (the "License");
  4// you may not use this file except in compliance with the License.
  5// You may obtain a copy of the License at
  6//
  7//      http://www.apache.org/licenses/LICENSE-2.0
  8//
  9// Unless required by applicable law or agreed to in writing, software
 10// distributed under the License is distributed on an "AS IS" BASIS,
 11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 12// See the License for the specific language governing permissions and
 13// limitations under the License.
 14
 15package genai
 16
 17import (
 18	"bufio"
 19	"bytes"
 20	"context"
 21	"encoding/json"
 22	"fmt"
 23	"io"
 24	"iter"
 25	"log"
 26	"net/http"
 27	"net/url"
 28	"runtime"
 29	"strconv"
 30	"strings"
 31	"time"
 32)
 33
// maxChunkSize is the size, in bytes, of the buffer uploadFile fills for each
// upload request (8 * 1024 * 1024 bytes).
const maxChunkSize = 8 * 1024 * 1024 // 8 MB chunk size
 35
// apiClient bundles the configuration used by the request helpers in this
// file to build URLs, set headers and send HTTP requests.
type apiClient struct {
	// clientConfig supplies the backend, project, location, API key and HTTP
	// client that the helpers in this file read.
	clientConfig *ClientConfig
}
 39
 40// sendStreamRequest issues an server streaming API request and returns a map of the response contents.
 41func sendStreamRequest[T responseStream[R], R any](ctx context.Context, ac *apiClient, path string, method string, body map[string]any, httpOptions *HTTPOptions, output *responseStream[R]) error {
 42	req, err := buildRequest(ctx, ac, path, body, method, httpOptions)
 43	if err != nil {
 44		return err
 45	}
 46
 47	resp, err := doRequest(ac, req)
 48	if err != nil {
 49		return err
 50	}
 51
 52	// resp.Body will be closed by the iterator
 53	return deserializeStreamResponse(resp, output)
 54}
 55
 56// sendRequest issues an API request and returns a map of the response contents.
 57func sendRequest(ctx context.Context, ac *apiClient, path string, method string, body map[string]any, httpOptions *HTTPOptions) (map[string]any, error) {
 58	req, err := buildRequest(ctx, ac, path, body, method, httpOptions)
 59	if err != nil {
 60		return nil, err
 61	}
 62
 63	resp, err := doRequest(ac, req)
 64	if err != nil {
 65		return nil, err
 66	}
 67	defer resp.Body.Close()
 68
 69	return deserializeUnaryResponse(resp)
 70}
 71
 72func downloadFile(ctx context.Context, ac *apiClient, path string, httpOptions *HTTPOptions) ([]byte, error) {
 73	req, err := buildRequest(ctx, ac, path, nil, http.MethodGet, httpOptions)
 74	if err != nil {
 75		return nil, err
 76	}
 77
 78	resp, err := doRequest(ac, req)
 79	if err != nil {
 80		return nil, err
 81	}
 82	return io.ReadAll(resp.Body)
 83}
 84
 85func mapToStruct[R any](input map[string]any, output *R) error {
 86	b := new(bytes.Buffer)
 87	err := json.NewEncoder(b).Encode(input)
 88	if err != nil {
 89		return fmt.Errorf("mapToStruct: error encoding input %#v: %w", input, err)
 90	}
 91	err = json.Unmarshal(b.Bytes(), output)
 92	if err != nil {
 93		return fmt.Errorf("mapToStruct: error unmarshalling input %#v: %w", input, err)
 94	}
 95	return nil
 96}
 97
 98func (ac *apiClient) createAPIURL(suffix, method string, httpOptions *HTTPOptions) (*url.URL, error) {
 99	if ac.clientConfig.Backend == BackendVertexAI {
100		queryVertexBaseModel := ac.clientConfig.Backend == BackendVertexAI && method == http.MethodGet && strings.HasPrefix(suffix, "publishers/google/models")
101		if !strings.HasPrefix(suffix, "projects/") && !queryVertexBaseModel {
102			suffix = fmt.Sprintf("projects/%s/locations/%s/%s", ac.clientConfig.Project, ac.clientConfig.Location, suffix)
103		}
104		u, err := url.Parse(fmt.Sprintf("%s/%s/%s", httpOptions.BaseURL, httpOptions.APIVersion, suffix))
105		if err != nil {
106			return nil, fmt.Errorf("createAPIURL: error parsing Vertex AI URL: %w", err)
107		}
108		return u, nil
109	} else {
110		if !strings.Contains(suffix, fmt.Sprintf("/%s/", httpOptions.APIVersion)) {
111			suffix = fmt.Sprintf("%s/%s", httpOptions.APIVersion, suffix)
112		}
113		u, err := url.Parse(fmt.Sprintf("%s/%s", httpOptions.BaseURL, suffix))
114		if err != nil {
115			return nil, fmt.Errorf("createAPIURL: error parsing ML Dev URL: %w", err)
116		}
117		return u, nil
118	}
119}
120
121func buildRequest(ctx context.Context, ac *apiClient, path string, body map[string]any, method string, httpOptions *HTTPOptions) (*http.Request, error) {
122	url, err := ac.createAPIURL(path, method, httpOptions)
123	if err != nil {
124		return nil, err
125	}
126	b := new(bytes.Buffer)
127	if len(body) > 0 {
128		if err := json.NewEncoder(b).Encode(body); err != nil {
129			return nil, fmt.Errorf("buildRequest: error encoding body %#v: %w", body, err)
130		}
131	}
132
133	// Create a new HTTP request
134	req, err := http.NewRequestWithContext(ctx, method, url.String(), b)
135	if err != nil {
136		return nil, err
137	}
138	// Set headers
139	doMergeHeaders(httpOptions.Headers, &req.Header)
140	doMergeHeaders(sdkHeader(ctx, ac), &req.Header)
141	return req, nil
142}
143
144func sdkHeader(ctx context.Context, ac *apiClient) http.Header {
145	header := make(http.Header)
146	header.Set("Content-Type", "application/json")
147	if ac.clientConfig.APIKey != "" {
148		header.Set("x-goog-api-key", ac.clientConfig.APIKey)
149	}
150	libraryLabel := fmt.Sprintf("google-genai-sdk/%s", version)
151	languageLabel := fmt.Sprintf("gl-go/%s", runtime.Version())
152	versionHeaderValue := fmt.Sprintf("%s %s", libraryLabel, languageLabel)
153	header.Set("user-agent", versionHeaderValue)
154	header.Set("x-goog-api-client", versionHeaderValue)
155	timeoutSeconds := inferTimeout(ctx, ac).Seconds()
156	if timeoutSeconds > 0 {
157		header.Set("x-server-timeout", strconv.FormatInt(int64(timeoutSeconds), 10))
158	}
159	return header
160}
161
162func inferTimeout(ctx context.Context, ac *apiClient) time.Duration {
163	// ac.clientConfig.HTTPClient is not nil because it's initialized in the NewClient function.
164	requestTimeout := ac.clientConfig.HTTPClient.Timeout
165	contextTimeout := 0 * time.Second
166	if deadline, ok := ctx.Deadline(); ok {
167		contextTimeout = time.Until(deadline)
168	}
169	if requestTimeout != 0 && contextTimeout != 0 {
170		return min(requestTimeout, contextTimeout)
171	}
172	if requestTimeout != 0 {
173		return requestTimeout
174	}
175	return contextTimeout
176}
177
178func doRequest(ac *apiClient, req *http.Request) (*http.Response, error) {
179	// Create a new HTTP client and send the request
180	client := ac.clientConfig.HTTPClient
181	resp, err := client.Do(req)
182	if err != nil {
183		return nil, fmt.Errorf("doRequest: error sending request: %w", err)
184	}
185	return resp, nil
186}
187
188func deserializeUnaryResponse(resp *http.Response) (map[string]any, error) {
189	if !httpStatusOk(resp) {
190		return nil, newAPIError(resp)
191	}
192	respBody, err := io.ReadAll(resp.Body)
193	if err != nil {
194		return nil, err
195	}
196
197	output := make(map[string]any)
198	if len(respBody) > 0 {
199		err = json.Unmarshal(respBody, &output)
200		if err != nil {
201			return nil, fmt.Errorf("deserializeUnaryResponse: error unmarshalling response: %w\n%s", err, respBody)
202		}
203	}
204	output["httpHeaders"] = resp.Header
205	return output, nil
206}
207
208type responseStream[R any] struct {
209	r  *bufio.Scanner
210	rc io.ReadCloser
211}
212
213func iterateResponseStream[R any](rs *responseStream[R], responseConverter func(responseMap map[string]any) (*R, error)) iter.Seq2[*R, error] {
214	return func(yield func(*R, error) bool) {
215		defer func() {
216			// Close the response body range over function is done.
217			if err := rs.rc.Close(); err != nil {
218				log.Printf("Error closing response body: %v", err)
219			}
220		}()
221		for rs.r.Scan() {
222			line := rs.r.Bytes()
223			if len(line) == 0 {
224				continue
225			}
226			prefix, data, _ := bytes.Cut(line, []byte(":"))
227			switch string(prefix) {
228			case "data":
229				// Step 1: Unmarshal the JSON into a map[string]any so that we can call fromConverter
230				// in Step 2.
231				respRaw := make(map[string]any)
232				if err := json.Unmarshal(data, &respRaw); err != nil {
233					err = fmt.Errorf("iterateResponseStream: error unmarshalling data %s:%s. error: %w", string(prefix), string(data), err)
234					if !yield(nil, err) {
235						return
236					}
237				}
238				// Step 2: The toStruct function calls fromConverter(handle Vertex and MLDev schema
239				// difference and get a unified response). Then toStruct function converts the unified
240				// response from map[string]any to struct type.
241				// var resp = new(R)
242				resp, err := responseConverter(respRaw)
243				if err != nil {
244					if !yield(nil, err) {
245						return
246					}
247				}
248
249				// Step 3: yield the response.
250				if !yield(resp, nil) {
251					return
252				}
253			default:
254				// Stream chunk not started with "data" is treated as an error.
255				if !yield(nil, fmt.Errorf("iterateResponseStream: invalid stream chunk: %s:%s", string(prefix), string(data))) {
256					return
257				}
258			}
259		}
260		if rs.r.Err() != nil {
261			if rs.r.Err() == bufio.ErrTooLong {
262				log.Printf("The response is too large to process in streaming mode. Please use a non-streaming method.")
263			}
264			log.Printf("Error %v", rs.r.Err())
265		}
266	}
267}
268
// APIError contains an error response from the server.
type APIError struct {
	// Code is the HTTP response status code.
	Code int `json:"code,omitempty"`
	// Message is the server response message.
	Message string `json:"message,omitempty"`
	// Status is the server response status.
	Status string `json:"status,omitempty"`
	// Details field provides more context to an error.
	Details []map[string]any `json:"details,omitempty"`
}

// responseWithError mirrors the JSON envelope {"error": {...}} that wraps an
// APIError in a failed response body.
type responseWithError struct {
	ErrorInfo *APIError `json:"error,omitempty"`
}

// newAPIError builds an error from a failed HTTP response by reading its body.
//
//   - A JSON body with an "error" object yields that object as an APIError.
//   - A body that is not valid JSON, or is JSON without an "error" object,
//     yields an APIError carrying the HTTP status code, status line and the
//     raw body as Message.
//   - An empty body yields an APIError with only the HTTP status code and
//     status line.
//
// If the body cannot be read, a plain wrapped error is returned instead. The
// body is not closed here.
func newAPIError(resp *http.Response) error {
	var respWithError = new(responseWithError)
	body, err := io.ReadAll(resp.Body)
	if err != nil {
		return fmt.Errorf("newAPIError: error reading response body: %w. Response: %v", err, string(body))
	}

	if len(body) == 0 {
		return APIError{Code: resp.StatusCode, Status: resp.Status}
	}
	if err := json.Unmarshal(body, respWithError); err != nil {
		// Handle plain text error message. File upload backend doesn't return json error message.
		return APIError{Code: resp.StatusCode, Status: resp.Status, Message: string(body)}
	}
	if respWithError.ErrorInfo == nil {
		// Valid JSON without an "error" object: dereferencing ErrorInfo would
		// panic, so fall back to the raw body like the plain-text case.
		return APIError{Code: resp.StatusCode, Status: resp.Status, Message: string(body)}
	}
	return *respWithError.ErrorInfo
}

// Error returns a string representation of the APIError.
func (e APIError) Error() string {
	return fmt.Sprintf(
		"Error %d, Message: %s, Status: %s, Details: %v",
		e.Code, e.Message, e.Status, e.Details,
	)
}
309
// httpStatusOk reports whether the response carries a 2xx status code.
func httpStatusOk(resp *http.Response) bool {
	code := resp.StatusCode
	if code < http.StatusOK {
		return false
	}
	return code < http.StatusMultipleChoices
}
313
314func deserializeStreamResponse[T responseStream[R], R any](resp *http.Response, output *responseStream[R]) error {
315	if !httpStatusOk(resp) {
316		return newAPIError(resp)
317	}
318	output.r = bufio.NewScanner(resp.Body)
319	// Scanner default buffer max size is 64*1024 (64KB).
320	// We provide 1KB byte buffer to the scanner and set max to 256MB.
321	// When data exceed 1KB, then scanner will allocate new memory up to 256MB.
322	// When data exceed 256MB, scanner will stop and returns err: bufio.ErrTooLong.
323	output.r.Buffer(make([]byte, 1024), 268435456)
324
325	output.r.Split(scan)
326	output.rc = resp.Body
327	return nil
328}
329
// dropCR drops a terminal \r from the data.
func dropCR(data []byte) []byte {
	if n := len(data); n > 0 && data[n-1] == '\r' {
		return data[:n-1]
	}
	return data
}

// scan is a bufio.SplitFunc that splits a stream into chunks separated by a
// blank line, written either as "\n\n" or as "\r\n\r\n". The separator is not
// part of the token, and one trailing '\r' is stripped from the token.
//
// The earliest separator in the buffer wins regardless of its style, so a
// stream that mixes line endings is still cut at every chunk boundary. At
// EOF, any remaining unterminated data is returned as a final token. When no
// separator is present and more input may follow, (0, nil, nil) asks the
// scanner for more data.
func scan(data []byte, atEOF bool) (advance int, token []byte, err error) {
	if atEOF && len(data) == 0 {
		return 0, nil, nil
	}

	lf := bytes.Index(data, []byte("\n\n"))

	// A "\r\n\r\n" separator only matters if it starts before the first
	// "\n\n". Such a separator ends no later than index lf, so the search
	// can be limited to data[:lf+1] instead of rescanning the whole buffer.
	window := data
	if lf >= 0 {
		window = data[:lf+1]
	}
	if crlf := bytes.Index(window, []byte("\r\n\r\n")); crlf >= 0 {
		// We have a full Windows-style two-newline-terminated token.
		return crlf + 4, dropCR(data[:crlf]), nil
	}
	if lf >= 0 {
		// We have a full two-newline-terminated token.
		return lf + 2, dropCR(data[:lf]), nil
	}

	// If we're at EOF, we have a final, non-terminated line. Return it.
	if atEOF {
		return len(data), dropCR(data), nil
	}
	// Request more data.
	return 0, nil, nil
}
361
362func (ac *apiClient) uploadFile(ctx context.Context, r io.Reader, uploadURL string, httpOptions *HTTPOptions) (*File, error) {
363	var offset int64 = 0
364	var resp *http.Response
365	var respBody map[string]any
366	var uploadCommand = "upload"
367
368	buffer := make([]byte, maxChunkSize)
369	for {
370		bytesRead, err := io.ReadFull(r, buffer)
371		// Check both EOF and UnexpectedEOF errors.
372		// ErrUnexpectedEOF: Reading a file file_size%maxChunkSize<len(buffer).
373		// EOF: Reading a file file_size%maxChunkSize==0. The underlying reader return 0 bytes buffer and EOF at next call.
374		if err == io.EOF || err == io.ErrUnexpectedEOF {
375			uploadCommand += ", finalize"
376		} else if err != nil {
377			return nil, fmt.Errorf("Failed to read bytes from file at offset %d: %w. Bytes actually read: %d", offset, err, bytesRead)
378		}
379
380		req, err := http.NewRequestWithContext(ctx, http.MethodPost, uploadURL, bytes.NewReader(buffer[:bytesRead]))
381		if err != nil {
382			return nil, fmt.Errorf("Failed to create upload request for chunk at offset %d: %w", offset, err)
383		}
384		doMergeHeaders(httpOptions.Headers, &req.Header)
385		doMergeHeaders(sdkHeader(ctx, ac), &req.Header)
386
387		req.Header.Set("X-Goog-Upload-Command", uploadCommand)
388		req.Header.Set("X-Goog-Upload-Offset", strconv.FormatInt(offset, 10))
389		req.Header.Set("Content-Length", strconv.FormatInt(int64(bytesRead), 10))
390
391		resp, err = doRequest(ac, req)
392		if err != nil {
393			return nil, fmt.Errorf("upload request failed for chunk at offset %d: %w", offset, err)
394		}
395		defer resp.Body.Close()
396
397		respBody, err = deserializeUnaryResponse(resp)
398		if err != nil {
399			return nil, fmt.Errorf("response body is invalid for chunk at offset %d: %w", offset, err)
400		}
401
402		offset += int64(bytesRead)
403
404		uploadStatus := resp.Header.Get("X-Goog-Upload-Status")
405
406		if uploadStatus != "final" && strings.Contains(uploadCommand, "finalize") {
407			return nil, fmt.Errorf("send finalize command but doesn't receive final status. Offset %d, Bytes read: %d, Upload status: %s", offset, bytesRead, uploadStatus)
408		}
409		if uploadStatus != "active" {
410			// Upload is complete ('final') or interrupted ('cancelled', etc.)
411			break
412		}
413	}
414
415	if resp == nil {
416		return nil, fmt.Errorf("Upload request failed. No response received")
417	}
418
419	finalUploadStatus := resp.Header.Get("X-Goog-Upload-Status")
420	if finalUploadStatus != "final" {
421		return nil, fmt.Errorf("Failed to upload file: Upload status is not finalized")
422	}
423
424	var response = new(File)
425	err := mapToStruct(respBody["file"].(map[string]any), &response)
426	if err != nil {
427		return nil, err
428	}
429	return response, nil
430}