1// Copyright 2024 Google LLC
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7// http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15package genai
16
17import (
18 "bytes"
19 "encoding/json"
20 "errors"
21 "fmt"
22 "iter"
23 "net/http"
24 "net/url"
25 "reflect"
26 "sort"
27 "strconv"
28 "strings"
29)
30
// Ptr copies v into a fresh variable and returns that variable's address.
// It is handy for populating optional pointer fields inline:
//
//	genai.GenerateContentConfig{Temperature: genai.Ptr(0.5)}
//
// Every call returns a distinct pointer.
func Ptr[T any](t T) *T {
	p := new(T)
	*p = t
	return p
}
36
37type converterFunc func(*apiClient, map[string]any, map[string]any) (map[string]any, error)
38
39type transformerFunc[T any] func(*apiClient, T) (T, error)
40
// setValueByPath writes value into the nested map data at the location
// described by keys, creating intermediate maps and slices as needed.
//
// Path segments are interpreted as follows:
//   - "name"    descends into (or creates) the map stored under "name"; an
//     existing value that is not a map[string]any is replaced by a new map.
//   - "name[]"  fans out over the []map[string]any stored under "name". When
//     value is a slice, element j of value is written into element j of the
//     list; otherwise the same value is written into every element. A missing
//     list is created with one element per value element (or a single element
//     for a non-slice value).
//   - "name[0]" writes only into the first element of the list under "name",
//     creating a one-element list if the key is missing.
//
// The final segment is always used verbatim as a map key.
//
// A nil value or an empty keys slice is a no-op. The function panics if an
// existing value under a "name[]" or "name[0]" segment is not a
// []map[string]any.
//
// Examples:
//
//	setValueByPath(map[string]any{}, []string{"a", "b"}, v)
//	  -> {"a": {"b": v}}
//
//	setValueByPath(map[string]any{}, []string{"a", "b[]", "c"}, []any{v1, v2})
//	  -> {"a": {"b": [{"c": v1}, {"c": v2}]}}
//
//	setValueByPath(map[string]any{"a": {"b": [{"c": v1}, {"c": v2}]}}, []string{"a", "b[]", "d"}, v3)
//	  -> {"a": {"b": [{"c": v1, "d": v3}, {"c": v2, "d": v3}]}}
func setValueByPath(data map[string]any, keys []string, value any) {
	// An empty path has no final key to assign to; treat it like a nil value
	// instead of panicking on keys[:len(keys)-1].
	if value == nil || len(keys) == 0 {
		return
	}

	last := len(keys) - 1
	for i, key := range keys[:last] {
		rest := keys[i+1:]
		switch {
		case strings.HasSuffix(key, "[]"):
			keyName := strings.TrimSuffix(key, "[]")
			rv := reflect.ValueOf(value)
			isSlice := rv.Kind() == reflect.Slice
			if _, ok := data[keyName]; !ok {
				// Size the new list after the value so each element gets a slot.
				n := 1
				if isSlice {
					n = rv.Len()
				}
				created := make([]map[string]any, n)
				for k := range created {
					created[k] = make(map[string]any)
				}
				data[keyName] = created
			}
			for j, d := range data[keyName].([]map[string]any) {
				if !isSlice {
					// Scalar value: broadcast to every list element.
					setValueByPath(d, rest, value)
					continue
				}
				// Slice value: pair elements by index; extra list entries are
				// left untouched.
				if j < rv.Len() {
					setValueByPath(d, rest, rv.Index(j).Interface())
				}
			}
			// The recursive calls consumed the remainder of the path.
			return
		case strings.HasSuffix(key, "[0]"):
			keyName := strings.TrimSuffix(key, "[0]")
			if _, ok := data[keyName]; !ok {
				data[keyName] = []map[string]any{make(map[string]any)}
			}
			setValueByPath(data[keyName].([]map[string]any)[0], rest, value)
			return
		default:
			child, ok := data[key].(map[string]any)
			if !ok {
				// Missing or not a map: start a fresh map at this level.
				child = make(map[string]any)
				data[key] = child
			}
			data = child
		}
	}

	data[keys[last]] = value
}
105
// getValueByPath walks data along keys and returns whatever is found at the
// end of the path, or nil if the path cannot be followed.
//
// Only map[string]any nodes can be descended into. A segment of the form
// "name[]" expects a []map[string]any or []any under "name" and returns a
// []any holding the result of resolving the remaining path against every
// element. The single-segment path {"_self"} returns data itself, and an empty
// path returns nil.
//
// Examples:
//
//	getValueByPath(map[string]any{"a": {"b": "v"}}, []string{"a", "b"})
//	  -> "v"
//	getValueByPath(map[string]any{"a": {"b": [{"c": "v1"}, {"c": "v2"}]}}, []string{"a", "b[]", "c"})
//	  -> []any{"v1", "v2"}
func getValueByPath(data any, keys []string) any {
	switch {
	case len(keys) == 1 && keys[0] == "_self":
		return data
	case len(keys) == 0:
		return nil
	}

	node := data
	for idx, segment := range keys {
		// Every segment, list or not, is looked up in a map.
		obj, isMap := node.(map[string]any)
		if !isMap {
			return nil
		}

		if !strings.HasSuffix(segment, "[]") {
			node = obj[segment]
			continue
		}

		raw, found := obj[strings.TrimSuffix(segment, "[]")]
		if !found {
			return nil
		}

		// Resolve the rest of the path against each element; the recursion
		// consumes the remaining segments, so we return right after.
		remaining := keys[idx+1:]
		var collected []any
		switch elems := raw.(type) {
		case []map[string]any:
			for _, elem := range elems {
				collected = append(collected, getValueByPath(elem, remaining))
			}
		case []any:
			for _, elem := range elems {
				collected = append(collected, getValueByPath(elem, remaining))
			}
		default:
			return nil
		}
		return collected
	}
	return node
}
159
// formatMap expands "{name}" placeholders in template using variables.
//
// For each placeholder, the text between the braces is looked up in variables:
//   - a string value is substituted in place of the placeholder;
//   - a missing key makes the placeholder expand to the empty string;
//   - a value of any other type yields an error.
//
// A '{' with no matching '}' later in the template is not a placeholder; it
// and everything after it are copied to the output unchanged.
func formatMap(template string, variables map[string]any) (string, error) {
	var buffer bytes.Buffer
	for i := 0; i < len(template); i++ {
		if template[i] != '{' {
			buffer.WriteByte(template[i])
			continue
		}

		end := strings.IndexByte(template[i+1:], '}')
		if end < 0 {
			// No closing brace anywhere ahead: keep the remainder literally
			// rather than silently dropping the '{'.
			buffer.WriteString(template[i:])
			break
		}

		key := template[i+1 : i+1+end]
		if value, ok := variables[key]; ok {
			str, isString := value.(string)
			if !isString {
				return "", errors.New("formatMap: nested interface or unsupported type found")
			}
			buffer.WriteString(str)
		}
		// Jump to the closing brace; the loop increment steps past it.
		i += end + 1
	}
	return buffer.String(), nil
}
186
187// applyConverterToSlice calls converter function to each element of the slice.
188func applyConverterToSlice(ac *apiClient, inputs []any, converter converterFunc) ([]map[string]any, error) {
189 var outputs []map[string]any
190 for _, object := range inputs {
191 object, err := converter(ac, object.(map[string]any), nil)
192 if err != nil {
193 return nil, err
194 }
195 outputs = append(outputs, object)
196 }
197 return outputs, nil
198}
199
200// applyItemTransformerToSlice calls item transformer function to each element of the slice.
201func applyItemTransformerToSlice[T any](ac *apiClient, inputs []T, itemTransformer transformerFunc[T]) ([]T, error) {
202 var outputs []T
203 for _, input := range inputs {
204 object, err := itemTransformer(ac, input)
205 if err != nil {
206 return nil, err
207 }
208 outputs = append(outputs, object)
209 }
210 return outputs, nil
211}
212
// deepMarshal converts input into a generic map by encoding it as JSON and
// decoding the result into *output. Errors from either step are wrapped with
// a "deepMarshal:" prefix.
func deepMarshal(input any, output *map[string]any) error {
	encoded, err := json.Marshal(input)
	if err != nil {
		return fmt.Errorf("deepMarshal: unable to marshal input: %w", err)
	}
	if err = json.Unmarshal(encoded, output); err != nil {
		return fmt.Errorf("deepMarshal: unable to unmarshal input: %w", err)
	}
	return nil
}
221
// deepCopy duplicates original into *copied by round-tripping it through
// JSON. Any error from encoding or decoding is returned unchanged. Only data
// that survives JSON encoding is carried over to the copy.
func deepCopy[T any](original T, copied *T) error {
	encoded, err := json.Marshal(original)
	if err != nil {
		return err
	}
	return json.Unmarshal(encoded, copied)
}
231
// createURLQuery encodes query as a URL query string.
//
// Keys are processed in ascending order. Values may be string, int, float64,
// bool, or []string; a []string contributes one parameter per element. Any
// other value type causes an "unsupported type" error and an empty result.
func createURLQuery(query map[string]any) (string, error) {
	names := make([]string, 0, len(query))
	for name := range query {
		names = append(names, name)
	}
	sort.Strings(names)

	params := make(url.Values, len(names))
	for _, name := range names {
		// Render the value to its textual form(s) first, then add them.
		var rendered []string
		switch typed := query[name].(type) {
		case string:
			rendered = []string{typed}
		case int:
			rendered = []string{strconv.Itoa(typed)}
		case float64:
			rendered = []string{strconv.FormatFloat(typed, 'f', -1, 64)}
		case bool:
			rendered = []string{strconv.FormatBool(typed)}
		case []string:
			rendered = typed
		default:
			return "", fmt.Errorf("unsupported type: %T", typed)
		}
		for _, text := range rendered {
			params.Add(name, text)
		}
	}
	return params.Encode(), nil
}
264
265func yieldErrorAndEndIterator[T any](err error) iter.Seq2[*T, error] {
266 return func(yield func(*T, error) bool) {
267 if !yield(nil, err) {
268 return
269 }
270 }
271}
272
273func mergeHTTPOptions(clientConfig *ClientConfig, configHTTPOptions *HTTPOptions) *HTTPOptions {
274 var clientHTTPOptions *HTTPOptions
275 if clientConfig != nil {
276 clientHTTPOptions = &(clientConfig.HTTPOptions)
277 }
278
279 result := HTTPOptions{}
280 if clientHTTPOptions == nil && configHTTPOptions == nil {
281 return nil
282 } else if clientHTTPOptions == nil {
283 result = HTTPOptions{
284 BaseURL: configHTTPOptions.BaseURL,
285 APIVersion: configHTTPOptions.APIVersion,
286 }
287 } else {
288 result = HTTPOptions{
289 BaseURL: clientHTTPOptions.BaseURL,
290 APIVersion: clientHTTPOptions.APIVersion,
291 }
292 }
293
294 if configHTTPOptions != nil {
295 if configHTTPOptions.BaseURL != "" {
296 result.BaseURL = configHTTPOptions.BaseURL
297 }
298 if configHTTPOptions.APIVersion != "" {
299 result.APIVersion = configHTTPOptions.APIVersion
300 }
301 }
302 result.Headers = mergeHeaders(clientHTTPOptions, configHTTPOptions)
303 return &result
304}
305
306func mergeHeaders(clientHTTPOptions *HTTPOptions, configHTTPOptions *HTTPOptions) http.Header {
307 result := http.Header{}
308 if clientHTTPOptions == nil && configHTTPOptions == nil {
309 return result
310 }
311
312 if clientHTTPOptions != nil {
313 doMergeHeaders(clientHTTPOptions.Headers, &result)
314 }
315 // configHTTPOptions takes precedence over clientHTTPOptions.
316 if configHTTPOptions != nil {
317 doMergeHeaders(configHTTPOptions.Headers, &result)
318 }
319 return result
320}
321
// doMergeHeaders appends every value of every header in input to *output.
// Existing values in *output are kept; new values are added after them under
// the canonical form of the header name, as done by http.Header.Add.
func doMergeHeaders(input http.Header, output *http.Header) {
	for name, values := range input {
		for idx := range values {
			(*output).Add(name, values[idx])
		}
	}
}