1// Copyright 2024 Google LLC
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7// http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15package genai
16
17import (
18 "encoding/base64"
19 "fmt"
20 "log"
21 "reflect"
22 "strconv"
23 "strings"
24 "testing"
25)
26
27// sanitizeMapWithSourceType sanitizes byte fields within a map based on the provided source type.
28// It converts byte fields encoded with URL Base64 to standard Base64 encoding to prevent SDK unmarshal error.
29//
30// Args:
31//
32// t: The testing.T instance for reporting errors.
33// sourceType: The reflect.Type of the source struct. This is used to determine the paths to byte fields.
34// m: The map containing the data to sanitize. The map will be modified in place.
35func sanitizeMapWithSourceType(t *testing.T, sourceType reflect.Type, m any) {
36 t.Helper()
37 paths := make([]string, 0)
38
39 st := sourceType
40 if sourceType.Kind() == reflect.Slice {
41 st = sourceType.Elem()
42 }
43 visitedTypes := make(map[string]bool)
44 if err := getFieldPath(st, reflect.TypeOf([]byte{}), &paths, "", visitedTypes, false); err != nil {
45 t.Fatal(err)
46 }
47
48 stdBase64Handler := func(data any, path string) any {
49 s := data.(string)
50 b, err := base64.URLEncoding.DecodeString(s)
51 if err != nil {
52 b, err = base64.StdEncoding.DecodeString(s)
53 if err != nil {
54 t.Errorf("invalid base64 string %s at path %s", s, path)
55 }
56 }
57 return base64.StdEncoding.EncodeToString(b)
58 }
59
60 for _, path := range paths {
61 if sourceType.Kind() == reflect.Slice {
62 data := m.([]any)
63 for i := 0; i < len(data); i++ {
64 sanitizeMapByPath(data[i], path, stdBase64Handler, false)
65 }
66 } else {
67 sanitizeMapByPath(m.(map[string]any), path, stdBase64Handler, false)
68 }
69 }
70}
71
// sanitizeMapByPath sanitizes a value within a nested map structure based on the given path.
// It applies the provided sanitizer function to the value found at the specified path and
// stores the result back in place. Data that does not match the path (a non-map value, a
// missing key, or a non-slice value under a "[]" component, including nil) is left untouched.
//
// Args:
//
//	data: The data to sanitize. Only map[string]any values are processed; the map may be modified in place.
//	path: The path to the value to sanitize. This is a dot-separated string, where each component represents a key in the map.
//	  A component prefixed with "[]" denotes a slice whose elements are visited, e.g., "[]sliceField.fieldName".
//	sanitizer: The function to apply to the value found at the specified path. It receives the value and the
//	  remaining path and returns the sanitized value.
//	debug: A boolean indicating whether debug logging should be enabled.
func sanitizeMapByPath(data any, path string, sanitizer func(data any, path string) any, debug bool) {
	m, ok := data.(map[string]any)
	if !ok {
		if debug {
			log.Println("data is not map type", data, path)
		}
		return
	}

	// Split off the first component; nested reports whether more components follow.
	key, rest, nested := strings.Cut(path, ".")
	isSlice := strings.HasPrefix(key, "[]")
	name := strings.TrimPrefix(key, "[]")

	// Handle path not exists.
	v, exists := m[name]
	if !exists {
		if debug {
			log.Println("path doesn't exist", data, path)
		}
		return
	}

	if !isSlice {
		if !nested {
			// We are at the last component of the path.
			m[name] = sanitizer(v, path)
			return
		}
		sanitizeMapByPath(v, rest, sanitizer, debug)
		return
	}

	// Kind is checked on the Value itself: reflect.ValueOf(nil) is the zero Value, whose
	// Kind is Invalid, whereas calling Type on it would panic.
	rv := reflect.ValueOf(v)
	if rv.Kind() != reflect.Slice {
		if debug {
			log.Println("data is not slice type as the path denoted", data, path)
		}
		return
	}

	if !nested {
		// Last component: sanitize every element and replace the slice with the results.
		items := make([]any, 0, rv.Len())
		for i := 0; i < rv.Len(); i++ {
			items = append(items, sanitizer(rv.Index(i).Interface(), path))
		}
		m[name] = items
		return
	}

	for i := 0; i < rv.Len(); i++ {
		sanitizeMapByPath(rv.Index(i).Interface(), rest, sanitizer, debug)
	}
}
146
// convertFloat64ToString walks a decoded JSON object and replaces every float64 it finds
// with its decimal string form. Nested objects and arrays are processed recursively.
// The map is modified in place and also returned.
//
// Floats stored directly as map values are formatted with exactly 5 digits after the
// decimal point; floats stored as slice elements are formatted by
// convertSliceFloat64ToString, which uses the shortest exact representation.
func convertFloat64ToString(data map[string]any) map[string]any {
	for name, raw := range data {
		var converted any
		switch typed := raw.(type) {
		case float64:
			converted = strconv.FormatFloat(typed, 'f', 5, 64)
		case map[string]any:
			converted = convertFloat64ToString(typed)
		case []any:
			converted = convertSliceFloat64ToString(typed)
		default:
			// Every other kind of value is kept as is.
			continue
		}
		data[name] = converted
	}
	return data
}

// convertSliceFloat64ToString walks a decoded JSON array and replaces every float64 element
// with its decimal string form, using the fewest digits that represent the value exactly.
// Nested objects and arrays are processed recursively. The slice is modified in place and
// also returned.
func convertSliceFloat64ToString(data []any) []any {
	for idx := range data {
		var converted any
		switch typed := data[idx].(type) {
		case float64:
			converted = strconv.FormatFloat(typed, 'f', -1, 64)
		case map[string]any:
			converted = convertFloat64ToString(typed)
		case []any:
			converted = convertSliceFloat64ToString(typed)
		default:
			// Every other kind of value is kept as is.
			continue
		}
		data[idx] = converted
	}
	return data
}
182
183// getFieldPath retrieves the paths to all fields within a nested struct that match a given target type.
184// It uses reflection to traverse the struct and its nested fields.
185//
186// Args:
187//
188// sourceType: The reflect.Type of the source struct to traverse.
189// targetType: The reflect.Type of the target field to search for.
190// outputPaths: A pointer to a string slice where the resulting paths will be stored.
191// prefix: The current path prefix, used during recursive calls.
192// visitedTypes: Serves to prevent infinite recursion when dealing with recursive data structures
193// debug: A boolean indicating whether debug logging should be enabled.
194//
195// Returns:
196//
197// An error if the targetType is a pointer or a struct.
198func getFieldPath(sourceType reflect.Type, targetType reflect.Type, outputPaths *[]string, prefix string, visitedTypes map[string]bool, debug bool) error {
199 if targetType.Kind() == reflect.Ptr {
200 return fmt.Errorf("targetType cannot be a pointer")
201 }
202 if targetType.Kind() == reflect.Struct {
203 return fmt.Errorf("targetType cannot be a struct")
204 }
205 if sourceType.Kind() == reflect.Ptr {
206 _ = getFieldPath(sourceType.Elem(), targetType, outputPaths, prefix, visitedTypes, debug) // handle pointer nested field
207 } else if sourceType.Kind() == reflect.Struct {
208 for i := 0; i < sourceType.NumField(); i++ {
209 field := sourceType.Field(i)
210 if debug {
211 log.Println("field name:", field.Name, "field type:", field.Type.String(), "field tag:", field.Tag.Get("json"))
212 }
213 if visitedTypes[sourceType.String()+"."+fieldJSONName(field)] {
214 continue
215 }
216 visitedTypes[sourceType.String()+"."+fieldJSONName(field)] = true
217
218 if field.Type == targetType {
219 *outputPaths = append(*outputPaths, prefix+fieldJSONName(field))
220 } else if field.Type.Kind() == reflect.Struct {
221 _ = getFieldPath(field.Type, targetType, outputPaths, prefix+fieldJSONName(field)+".", visitedTypes, debug)
222 } else if field.Type.Kind() == reflect.Ptr {
223 _ = getFieldPath(field.Type.Elem(), targetType, outputPaths, prefix+fieldJSONName(field)+".", visitedTypes, debug)
224 } else if field.Type.Kind() == reflect.Slice {
225 elementType := field.Type.Elem() // Get the type of elements in the array
226 _ = getFieldPath(elementType, targetType, outputPaths, prefix+"[]"+fieldJSONName(field)+".", visitedTypes, debug)
227 }
228 visitedTypes[sourceType.String()+"."+fieldJSONName(field)] = false
229 // TODO(b/390425822): support map type.
230 }
231 if debug {
232 log.Printf("field of type %s not found\n", targetType.String())
233 }
234 }
235 if debug {
236 log.Printf("field of type %s not found\n", targetType.String())
237 }
238 return nil
239}
240
// fieldJSONName returns the name part of the field's `json` struct tag, i.e. everything
// before the first comma. When the tag is absent or carries only options (for example
// `json:",omitempty"`), it returns the Go field name, which is the key encoding/json
// uses in that case.
func fieldJSONName(field reflect.StructField) string {
	name, _, _ := strings.Cut(field.Tag.Get("json"), ",")
	if name == "" {
		return field.Name
	}
	return name
}