common.go

  1// Copyright 2024 Google LLC
  2//
  3// Licensed under the Apache License, Version 2.0 (the "License");
  4// you may not use this file except in compliance with the License.
  5// You may obtain a copy of the License at
  6//
  7//      http://www.apache.org/licenses/LICENSE-2.0
  8//
  9// Unless required by applicable law or agreed to in writing, software
 10// distributed under the License is distributed on an "AS IS" BASIS,
 11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 12// See the License for the specific language governing permissions and
 13// limitations under the License.
 14
 15package genai
 16
 17import (
 18	"bytes"
 19	"encoding/json"
 20	"errors"
 21	"fmt"
 22	"iter"
 23	"net/http"
 24	"net/url"
 25	"reflect"
 26	"sort"
 27	"strconv"
 28	"strings"
 29)
 30
 31// Ptr returns a pointer to its argument.
 32// It can be used to initialize pointer fields:
 33//
 34//	genai.GenerateContentConfig{Temperature: genai.Ptr(0.5)}
 35func Ptr[T any](t T) *T { return &t }
 36
 37type converterFunc func(*apiClient, map[string]any, map[string]any) (map[string]any, error)
 38
 39type transformerFunc[T any] func(*apiClient, T) (T, error)
 40
 41// setValueByPath handles setting values within nested maps, including handling array-like structures.
 42//
 43// Examples:
 44//
 45//	setValueByPath(map[string]any{}, []string{"a", "b"}, v)
 46//	  -> {"a": {"b": v}}
 47//
 48//	setValueByPath(map[string]any{}, []string{"a", "b[]", "c"}, []any{v1, v2})
 49//	  -> {"a": {"b": [{"c": v1}, {"c": v2}]}}
 50//
 51//	setValueByPath(map[string]any{"a": {"b": [{"c": v1}, {"c": v2}]}}, []string{"a", "b[]", "d"}, v3)
 52//	  -> {"a": {"b": [{"c": v1, "d": v3}, {"c": v2, "d": v3}]}}
 53func setValueByPath(data map[string]any, keys []string, value any) {
 54	if value == nil {
 55		return
 56	}
 57	for i, key := range keys[:len(keys)-1] {
 58		if strings.HasSuffix(key, "[]") {
 59			keyName := key[:len(key)-2]
 60			if _, ok := data[keyName]; !ok {
 61				if reflect.ValueOf(value).Kind() == reflect.Slice {
 62					data[keyName] = make([]map[string]any, reflect.ValueOf(value).Len())
 63				} else {
 64					data[keyName] = make([]map[string]any, 1)
 65				}
 66				for k := range data[keyName].([]map[string]any) {
 67					data[keyName].([]map[string]any)[k] = make(map[string]any)
 68				}
 69			}
 70
 71			if reflect.ValueOf(value).Kind() == reflect.Slice {
 72				for j, d := range data[keyName].([]map[string]any) {
 73					if j >= reflect.ValueOf(value).Len() {
 74						continue
 75					}
 76					setValueByPath(d, keys[i+1:], reflect.ValueOf(value).Index(j).Interface())
 77				}
 78			} else {
 79				for _, d := range data[keyName].([]map[string]any) {
 80					setValueByPath(d, keys[i+1:], value)
 81				}
 82			}
 83			return
 84		} else if strings.HasSuffix(key, "[0]") {
 85			keyName := key[:len(key)-3]
 86			if _, ok := data[keyName]; !ok {
 87				data[keyName] = make([]map[string]any, 1)
 88				data[keyName].([]map[string]any)[0] = make(map[string]any)
 89			}
 90			setValueByPath(data[keyName].([]map[string]any)[0], keys[i+1:], value)
 91			return
 92		} else {
 93			if _, ok := data[key]; !ok {
 94				data[key] = make(map[string]any)
 95			}
 96			if _, ok := data[key].(map[string]any); !ok {
 97				data[key] = make(map[string]any)
 98			}
 99			data = data[key].(map[string]any)
100		}
101	}
102
103	data[keys[len(keys)-1]] = value
104}
105
// getValueByPath retrieves a value from nested maps based on a path of keys.
//
// Each key selects an entry of a map[string]any. A key ending in "[]" selects
// the []map[string]any or []any stored under the key name and applies the rest
// of the path to each element, collecting the results into a []any. If that key
// is the last one, the elements themselves are collected. The single-key path
// {"_self"} returns data unchanged. nil is returned in these cases:
//   - the path is empty;
//   - a key is missing;
//   - the walk reaches a value that is not a map;
//   - a "[]" key finds an unsupported slice type.
//
// Examples:
//
//	getValueByPath(map[string]any{"a": {"b": "v"}}, []string{"a", "b"})
//	  -> "v"
//	getValueByPath(map[string]any{"a": {"b": [{"c": "v1"}, {"c": "v2"}]}}, []string{"a", "b[]", "c"})
//	  -> []any{"v1", "v2"}
func getValueByPath(data any, keys []string) any {
	if len(keys) == 1 && keys[0] == "_self" {
		return data
	}
	if len(keys) == 0 {
		return nil
	}
	current := data
	for i, key := range keys {
		m, ok := current.(map[string]any)
		if !ok {
			return nil
		}
		keyName, isArray := strings.CutSuffix(key, "[]")
		if !isArray {
			current = m[key]
			continue
		}
		raw, ok := m[keyName]
		if !ok {
			return nil
		}
		rest := keys[i+1:]
		elem := func(d any) any {
			if len(rest) == 0 {
				// A trailing "[]" selects the elements themselves; recursing with an
				// empty path would turn every element into nil.
				return d
			}
			return getValueByPath(d, rest)
		}
		var result []any
		switch s := raw.(type) {
		case []map[string]any:
			for _, d := range s {
				result = append(result, elem(d))
			}
		case []any:
			for _, d := range s {
				result = append(result, elem(d))
			}
		default:
			return nil
		}
		return result
	}
	return current
}
159
// formatMap substitutes "{name}" placeholders in template with the matching
// string values from variables.
//
// A placeholder runs from a '{' to the next '}', so "{a{b}" names the key "a{b".
// Placeholders whose key is absent from variables are replaced with the empty
// string. A '{' with no closing '}' is copied through literally, as is a stray
// '}'. An error is returned if a referenced variable is not a string.
func formatMap(template string, variables map[string]any) (string, error) {
	var buffer bytes.Buffer
	buffer.Grow(len(template))
	rest := template
	for {
		open := strings.IndexByte(rest, '{')
		if open < 0 {
			break
		}
		length := strings.IndexByte(rest[open+1:], '}')
		if length < 0 {
			// Unterminated placeholder: keep the remainder, '{' included, verbatim.
			break
		}
		buffer.WriteString(rest[:open])
		key := rest[open+1 : open+1+length]
		if value, ok := variables[key]; ok {
			s, isString := value.(string)
			if !isString {
				return "", errors.New("formatMap: nested interface or unsupported type found")
			}
			buffer.WriteString(s)
		}
		rest = rest[open+length+2:]
	}
	buffer.WriteString(rest)
	return buffer.String(), nil
}
186
187// applyConverterToSlice calls converter function to each element of the slice.
188func applyConverterToSlice(ac *apiClient, inputs []any, converter converterFunc) ([]map[string]any, error) {
189	var outputs []map[string]any
190	for _, object := range inputs {
191		object, err := converter(ac, object.(map[string]any), nil)
192		if err != nil {
193			return nil, err
194		}
195		outputs = append(outputs, object)
196	}
197	return outputs, nil
198}
199
200// applyItemTransformerToSlice calls item transformer function to each element of the slice.
201func applyItemTransformerToSlice[T any](ac *apiClient, inputs []T, itemTransformer transformerFunc[T]) ([]T, error) {
202	var outputs []T
203	for _, input := range inputs {
204		object, err := itemTransformer(ac, input)
205		if err != nil {
206			return nil, err
207		}
208		outputs = append(outputs, object)
209	}
210	return outputs, nil
211}
212
// deepMarshal converts input into a generic JSON object. It encodes input to
// JSON and decodes the result into *output. Decoding follows encoding/json
// rules: numbers become float64 and arrays become []any. A pre-populated map in
// *output is merged into rather than replaced. Errors from either step are
// wrapped with context.
func deepMarshal(input any, output *map[string]any) error {
	encoded, err := json.Marshal(input)
	if err != nil {
		return fmt.Errorf("deepMarshal: unable to marshal input: %w", err)
	}
	if err = json.Unmarshal(encoded, output); err != nil {
		return fmt.Errorf("deepMarshal: unable to unmarshal input: %w", err)
	}
	return nil
}
221
// deepCopy copies original into *copied by round-tripping it through JSON, so
// the copy shares no memory with original. Only JSON-visible state is copied:
// encoding/json skips unexported fields and those tagged `json:"-"`. The
// decoding is done in place, so a pre-populated map inside *copied is merged
// into rather than replaced. Errors are wrapped with context, matching
// deepMarshal.
func deepCopy[T any](original T, copied *T) error {
	// Named "encoded" rather than "bytes" to avoid shadowing the bytes package.
	encoded, err := json.Marshal(original)
	if err != nil {
		return fmt.Errorf("deepCopy: unable to marshal input: %w", err)
	}
	if err := json.Unmarshal(encoded, copied); err != nil {
		return fmt.Errorf("deepCopy: unable to unmarshal input: %w", err)
	}
	return nil
}
231
// createURLQuery encodes query as a URL query string without a leading "?".
//
// Keys are processed in sorted order, so when several values are unsupported
// the reported error is for the alphabetically first such key.
// url.Values.Encode also sorts keys in its output. Supported value types:
//   - string;
//   - int;
//   - float64, written as plain decimal with the fewest digits that round-trip;
//   - bool;
//   - []string, whose elements become repeated parameters, so an empty slice
//     contributes nothing.
//
// Any other type, including nil, yields an error.
func createURLQuery(query map[string]any) (string, error) {
	names := make([]string, 0, len(query))
	for name := range query {
		names = append(names, name)
	}
	sort.Strings(names)

	params := make(url.Values, len(query))
	for _, name := range names {
		switch typed := query[name].(type) {
		case []string:
			for _, item := range typed {
				params.Add(name, item)
			}
		case string:
			params.Add(name, typed)
		case bool:
			params.Add(name, strconv.FormatBool(typed))
		case int:
			params.Add(name, strconv.Itoa(typed))
		case float64:
			params.Add(name, strconv.FormatFloat(typed, 'f', -1, 64))
		default:
			return "", fmt.Errorf("unsupported type: %T", typed)
		}
	}
	return params.Encode(), nil
}
264
265func yieldErrorAndEndIterator[T any](err error) iter.Seq2[*T, error] {
266	return func(yield func(*T, error) bool) {
267		if !yield(nil, err) {
268			return
269		}
270	}
271}
272
273func mergeHTTPOptions(clientConfig *ClientConfig, configHTTPOptions *HTTPOptions) *HTTPOptions {
274	var clientHTTPOptions *HTTPOptions
275	if clientConfig != nil {
276		clientHTTPOptions = &(clientConfig.HTTPOptions)
277	}
278
279	result := HTTPOptions{}
280	if clientHTTPOptions == nil && configHTTPOptions == nil {
281		return nil
282	} else if clientHTTPOptions == nil {
283		result = HTTPOptions{
284			BaseURL:    configHTTPOptions.BaseURL,
285			APIVersion: configHTTPOptions.APIVersion,
286		}
287	} else {
288		result = HTTPOptions{
289			BaseURL:    clientHTTPOptions.BaseURL,
290			APIVersion: clientHTTPOptions.APIVersion,
291		}
292	}
293
294	if configHTTPOptions != nil {
295		if configHTTPOptions.BaseURL != "" {
296			result.BaseURL = configHTTPOptions.BaseURL
297		}
298		if configHTTPOptions.APIVersion != "" {
299			result.APIVersion = configHTTPOptions.APIVersion
300		}
301	}
302	result.Headers = mergeHeaders(clientHTTPOptions, configHTTPOptions)
303	return &result
304}
305
306func mergeHeaders(clientHTTPOptions *HTTPOptions, configHTTPOptions *HTTPOptions) http.Header {
307	result := http.Header{}
308	if clientHTTPOptions == nil && configHTTPOptions == nil {
309		return result
310	}
311
312	if clientHTTPOptions != nil {
313		doMergeHeaders(clientHTTPOptions.Headers, &result)
314	}
315	// configHTTPOptions takes precedence over clientHTTPOptions.
316	if configHTTPOptions != nil {
317		doMergeHeaders(configHTTPOptions.Headers, &result)
318	}
319	return result
320}
321
// doMergeHeaders appends every value of input to *output using http.Header.Add.
// Add canonicalizes header names and keeps any values already present. output
// is dereferenced only when input has values. In that case *output must be a
// non-nil map.
func doMergeHeaders(input http.Header, output *http.Header) {
	for name := range input {
		for i := range input[name] {
			(*output).Add(name, input[name][i])
		}
	}
}