replay_sanitizer.go

  1// Copyright 2024 Google LLC
  2//
  3// Licensed under the Apache License, Version 2.0 (the "License");
  4// you may not use this file except in compliance with the License.
  5// You may obtain a copy of the License at
  6//
  7//      http://www.apache.org/licenses/LICENSE-2.0
  8//
  9// Unless required by applicable law or agreed to in writing, software
 10// distributed under the License is distributed on an "AS IS" BASIS,
 11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 12// See the License for the specific language governing permissions and
 13// limitations under the License.
 14
 15package genai
 16
 17import (
 18	"encoding/base64"
 19	"fmt"
 20	"log"
 21	"reflect"
 22	"strconv"
 23	"strings"
 24	"testing"
 25)
 26
 27// sanitizeMapWithSourceType sanitizes byte fields within a map based on the provided source type.
 28// It converts byte fields encoded with URL Base64 to standard Base64 encoding to prevent SDK unmarshal error.
 29//
 30// Args:
 31//
 32//	t: The testing.T instance for reporting errors.
 33//	sourceType: The reflect.Type of the source struct. This is used to determine the paths to byte fields.
 34//	m: The map containing the data to sanitize.  The map will be modified in place.
 35func sanitizeMapWithSourceType(t *testing.T, sourceType reflect.Type, m any) {
 36	t.Helper()
 37	paths := make([]string, 0)
 38
 39	st := sourceType
 40	if sourceType.Kind() == reflect.Slice {
 41		st = sourceType.Elem()
 42	}
 43	visitedTypes := make(map[string]bool)
 44	if err := getFieldPath(st, reflect.TypeOf([]byte{}), &paths, "", visitedTypes, false); err != nil {
 45		t.Fatal(err)
 46	}
 47
 48	stdBase64Handler := func(data any, path string) any {
 49		s := data.(string)
 50		b, err := base64.URLEncoding.DecodeString(s)
 51		if err != nil {
 52			b, err = base64.StdEncoding.DecodeString(s)
 53			if err != nil {
 54				t.Errorf("invalid base64 string %s at path %s", s, path)
 55			}
 56		}
 57		return base64.StdEncoding.EncodeToString(b)
 58	}
 59
 60	for _, path := range paths {
 61		if sourceType.Kind() == reflect.Slice {
 62			data := m.([]any)
 63			for i := 0; i < len(data); i++ {
 64				sanitizeMapByPath(data[i], path, stdBase64Handler, false)
 65			}
 66		} else {
 67			sanitizeMapByPath(m.(map[string]any), path, stdBase64Handler, false)
 68		}
 69	}
 70}
 71
 72// sanitizeMapByPath sanitizes a value within a nested map structure based on the given path.
 73// It applies the provided sanitizer function to the value found at the specified path.
 74//
 75// Args:
 76//
 77//	data: The map containing the data to sanitize. This can be a nested map structure. The map may be modified in place.
 78//	path: The path to the value to sanitize. This is a dot-separated string, where each component represents a key in the map.
 79//	      Array elements can be accessed using the "[]" prefix, e.g., "[]sliceField.fieldName".
 80//	sanitizer: The function to apply to the value found at the specified path. The function should take the value and the path as input and return the sanitized value.
 81//	debug: A boolean indicating whether debug logging should be enabled.
 82func sanitizeMapByPath(data any, path string, sanitizer func(data any, path string) any, debug bool) {
 83	if _, ok := data.(map[string]any); !ok {
 84		if debug {
 85			log.Println("data is not map type", data, path)
 86		}
 87		return
 88	}
 89	m := data.(map[string]any)
 90
 91	keys := strings.Split(path, ".")
 92	key := keys[0]
 93
 94	// Handle path not exists.
 95	if strings.HasPrefix(key, "[]") {
 96		if _, ok := m[key[2:]]; !ok {
 97			if debug {
 98				log.Println("path doesn't exist", data, path)
 99			}
100			return
101		}
102	} else if _, ok := m[key]; !ok {
103		if debug {
104			log.Println("path doesn't exist", data, path)
105		}
106		return
107	}
108
109	// We are at the last component of the path.
110	if strings.HasPrefix(key, "[]") && len(keys) == 1 {
111		items := []any{}
112		v := m[key[2:]]
113		if reflect.ValueOf(v).Type().Kind() != reflect.Slice {
114			if debug {
115				log.Println("data is not slice type as the path denoted", data, path)
116			}
117			return
118		}
119		for i := 0; i < reflect.ValueOf(v).Len(); i++ {
120			items = append(items, sanitizer(reflect.ValueOf(v).Index(i).Interface(), key))
121		}
122		m[key[2:]] = items
123		return
124	} else if len(keys) == 1 {
125		m[key] = sanitizer(m[key], path)
126		return
127	}
128
129	if strings.HasPrefix(key, "[]") {
130		v := m[key[2:]]
131		if reflect.ValueOf(v).Type().Kind() != reflect.Slice {
132			if debug {
133				log.Println("data is not slice type as the path denoted", data, path)
134			}
135			return
136		}
137		s := reflect.ValueOf(v)
138		for i := 0; i < s.Len(); i++ {
139			element := s.Index(i).Interface()
140			sanitizeMapByPath(element, strings.Join(keys[1:], "."), sanitizer, debug)
141		}
142	} else {
143		sanitizeMapByPath(m[key], strings.Join(keys[1:], "."), sanitizer, debug)
144	}
145}
146
147// convertFloat64ToString recursively converts float64 values within a map[string]any to strings.
148func convertFloat64ToString(data map[string]any) map[string]any {
149	for key, value := range data {
150		switch v := value.(type) {
151		case float64:
152			// Convert float64 to string
153			data[key] = strconv.FormatFloat(v, 'f', 5, 64)
154		case map[string]any:
155			// Recursively process nested maps
156			data[key] = convertFloat64ToString(v)
157		case []any:
158			// Recursively process slices
159			data[key] = convertSliceFloat64ToString(v)
160		}
161	}
162	return data
163}
164
165// convertSliceFloat64ToString recursively converts float64 values within a []any to strings.
166func convertSliceFloat64ToString(data []any) []any {
167	for i, value := range data {
168		switch v := value.(type) {
169		case float64:
170			// Convert float64 to string
171			data[i] = strconv.FormatFloat(v, 'f', -1, 64)
172		case map[string]any:
173			// Recursively process nested maps
174			data[i] = convertFloat64ToString(v)
175		case []any:
176			// Recursively process nested slices
177			data[i] = convertSliceFloat64ToString(v)
178		}
179	}
180	return data
181}
182
183// getFieldPath retrieves the paths to all fields within a nested struct that match a given target type.
184// It uses reflection to traverse the struct and its nested fields.
185//
186// Args:
187//
188//		sourceType: The reflect.Type of the source struct to traverse.
189//		targetType: The reflect.Type of the target field to search for.
190//		outputPaths: A pointer to a string slice where the resulting paths will be stored.
191//		prefix: The current path prefix, used during recursive calls.
192//	 	visitedTypes: Serves to prevent infinite recursion when dealing with recursive data structures
193//		debug: A boolean indicating whether debug logging should be enabled.
194//
195// Returns:
196//
197//	An error if the targetType is a pointer or a struct.
198func getFieldPath(sourceType reflect.Type, targetType reflect.Type, outputPaths *[]string, prefix string, visitedTypes map[string]bool, debug bool) error {
199	if targetType.Kind() == reflect.Ptr {
200		return fmt.Errorf("targetType cannot be a pointer")
201	}
202	if targetType.Kind() == reflect.Struct {
203		return fmt.Errorf("targetType cannot be a struct")
204	}
205	if sourceType.Kind() == reflect.Ptr {
206		_ = getFieldPath(sourceType.Elem(), targetType, outputPaths, prefix, visitedTypes, debug) // handle pointer nested field
207	} else if sourceType.Kind() == reflect.Struct {
208		for i := 0; i < sourceType.NumField(); i++ {
209			field := sourceType.Field(i)
210			if debug {
211				log.Println("field name:", field.Name, "field type:", field.Type.String(), "field tag:", field.Tag.Get("json"))
212			}
213			if visitedTypes[sourceType.String()+"."+fieldJSONName(field)] {
214				continue
215			}
216			visitedTypes[sourceType.String()+"."+fieldJSONName(field)] = true
217
218			if field.Type == targetType {
219				*outputPaths = append(*outputPaths, prefix+fieldJSONName(field))
220			} else if field.Type.Kind() == reflect.Struct {
221				_ = getFieldPath(field.Type, targetType, outputPaths, prefix+fieldJSONName(field)+".", visitedTypes, debug)
222			} else if field.Type.Kind() == reflect.Ptr {
223				_ = getFieldPath(field.Type.Elem(), targetType, outputPaths, prefix+fieldJSONName(field)+".", visitedTypes, debug)
224			} else if field.Type.Kind() == reflect.Slice {
225				elementType := field.Type.Elem() // Get the type of elements in the array
226				_ = getFieldPath(elementType, targetType, outputPaths, prefix+"[]"+fieldJSONName(field)+".", visitedTypes, debug)
227			}
228			visitedTypes[sourceType.String()+"."+fieldJSONName(field)] = false
229			// TODO(b/390425822): support map type.
230		}
231		if debug {
232			log.Printf("field of type %s not found\n", targetType.String())
233		}
234	}
235	if debug {
236		log.Printf("field of type %s not found\n", targetType.String())
237	}
238	return nil
239}
240
// fieldJSONName returns the name portion of the field's `json` struct tag, i.e. the text
// before the first comma. It returns "" when the field has no json tag or the tag's name
// portion is empty.
func fieldJSONName(field reflect.StructField) string {
	name, _, _ := strings.Cut(field.Tag.Get("json"), ",")
	return name
}