1// Copyright 2024 Google LLC
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7// http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15package genai
16
17import (
18 "bytes"
19 "encoding/base64"
20 "encoding/json"
21 "fmt"
22 "io"
23 "math"
24 "net/http"
25 "net/http/httptest"
26 "os"
27 "path/filepath"
28 "reflect"
29 "regexp"
30 "strconv"
31 "strings"
32 "testing"
33 "time"
34
35 "github.com/google/go-cmp/cmp"
36)
37
// replayAPIClient is a mock HTTP backend for tests. It serves the interactions
// recorded in a replay session file from an httptest server, checking each
// incoming request against the recording before returning the recorded
// response.
type replayAPIClient struct {
	// ReplayFile is the session being replayed. It is nil until LoadReplay is
	// called; ServeHTTP fails the test if a request arrives before then.
	ReplayFile *replayFile
	// ReplaysDirectory, when non-empty, is joined in front of the path passed
	// to LoadReplay. newReplayAPIClient fills it from the
	// GOOGLE_GENAI_REPLAYS_DIRECTORY environment variable.
	ReplaysDirectory string
	// currentInteractionIndex is the index in ReplayFile.Interactions of the
	// next interaction ServeHTTP will serve; ServeHTTP increments it.
	currentInteractionIndex int
	// t is the test that failures are reported to.
	t *testing.T
	// server is the httptest server whose handler is this client.
	server *httptest.Server
}
46
47// NewReplayAPIClient creates a new ReplayAPIClient from a replay session file.
48func newReplayAPIClient(t *testing.T) *replayAPIClient {
49 t.Helper()
50 // The replay files are expected to be in a directory specified by the environment variable
51 // GOOGLE_GENAI_REPLAYS_DIRECTORY.
52 replaysDirectory := os.Getenv("GOOGLE_GENAI_REPLAYS_DIRECTORY")
53 rac := &replayAPIClient{
54 ReplayFile: nil,
55 ReplaysDirectory: replaysDirectory,
56 currentInteractionIndex: 0,
57 t: t,
58 }
59 rac.server = httptest.NewServer(rac)
60 rac.t.Cleanup(func() {
61 rac.server.Close()
62 })
63 return rac
64}
65
66// GetBaseURL returns the URL of the mocked HTTP server.
67func (rac *replayAPIClient) GetBaseURL() string {
68 return rac.server.URL
69}
70
71// LoadReplay populates a replay session from a file based on the provided path.
72func (rac *replayAPIClient) LoadReplay(replayFilePath string) {
73 rac.t.Helper()
74 fullReplaysPath := replayFilePath
75 if rac.ReplaysDirectory != "" {
76 fullReplaysPath = filepath.Join(rac.ReplaysDirectory, replayFilePath)
77 }
78 var replayFile replayFile
79 if err := readFileForReplayTest(fullReplaysPath, &replayFile, true); err != nil {
80 rac.t.Errorf("error loading replay file, %v", err)
81 }
82 rac.ReplayFile = &replayFile
83}
84
85// LatestInteraction returns the interaction that was returned by the last call to ServeHTTP.
86func (rac *replayAPIClient) LatestInteraction() *replayInteraction {
87 rac.t.Helper()
88 if rac.currentInteractionIndex == 0 {
89 rac.t.Fatalf("no interactions has been made in replay session so far")
90 }
91 return rac.ReplayFile.Interactions[rac.currentInteractionIndex-1]
92}
93
94// ServeHTTP mocks serving HTTP requests.
95func (rac *replayAPIClient) ServeHTTP(w http.ResponseWriter, req *http.Request) {
96 rac.t.Helper()
97 if rac.ReplayFile == nil {
98 rac.t.Fatalf("no replay file loaded")
99 }
100 if rac.currentInteractionIndex >= len(rac.ReplayFile.Interactions) {
101 rac.t.Fatalf("no more interactions in replay session")
102 }
103 interaction := rac.ReplayFile.Interactions[rac.currentInteractionIndex]
104
105 rac.assertRequest(req, interaction.Request)
106 rac.currentInteractionIndex++
107 var bodySegments []string
108 for i := 0; i < len(interaction.Response.BodySegments); i++ {
109 responseBodySegment, err := json.Marshal(interaction.Response.BodySegments[i])
110 if err != nil {
111 rac.t.Errorf("error marshalling responseBodySegment [%s], err: %+v", rac.ReplayFile.ReplayID, err)
112 }
113 bodySegments = append(bodySegments, string(responseBodySegment))
114 }
115 if interaction.Response.StatusCode != 0 {
116 w.WriteHeader(int(interaction.Response.StatusCode))
117 } else {
118 w.WriteHeader(http.StatusOK)
119 }
120 _, err := w.Write([]byte(strings.Join(bodySegments, "\n")))
121 if err != nil {
122 rac.t.Errorf("error writing response, err: %+v", err)
123 }
124}
125
126func readFileForReplayTest[T any](path string, output *T, omitempty bool) error {
127 dat, err := os.ReadFile(path)
128 if err != nil {
129 return err
130 }
131
132 var m map[string]any
133 if err := json.Unmarshal(dat, &m); err != nil {
134 return fmt.Errorf("error unmarshalling to map: %w", err)
135 }
136
137 if omitempty {
138 omitEmptyValues(m)
139 }
140 m = convertKeysToCamelCase(m, "").(map[string]any)
141
142 // Marshal the modified map back to struct
143 err = mapToStruct(m, output)
144 if err != nil {
145 return fmt.Errorf("error converting map to struct: %w", err)
146 }
147
148 return nil
149}
150
// redactReplayURL removes the "{MLDEV_URL_PREFIX}/" and "{VERTEX_URL_PREFIX}/"
// placeholders from a recorded URL. Requests reaching the test server carry
// neither a scheme nor a host, so only the remainder of the recorded URL is
// compared. The placeholders are removed one after the other, in that order.
func redactReplayURL(url string) string {
	for _, placeholder := range []string{"{MLDEV_URL_PREFIX}/", "{VERTEX_URL_PREFIX}/"} {
		url = strings.ReplaceAll(url, placeholder, "")
	}
	return url
}
157
// Regular expressions used by redactSDKURL. They are compiled once at package
// scope instead of on every request served by the mock server.
var (
	// replaySDKVertexURLPrefixRegexp matches everything up to and including
	// the "/projects/{project}/locations/{location}/" part of a URL.
	replaySDKVertexURLPrefixRegexp = regexp.MustCompile(`.*/projects/[^/]+/locations/[^/]+/`)
	// replaySDKMLDevURLPrefixRegexp matches the leading "/{segment}/" of a URL
	// (presumably the API version).
	replaySDKMLDevURLPrefixRegexp = regexp.MustCompile(`^\/[^/]+\/`)
)

// redactSDKURL strips the deployment-specific prefix from the URL of a
// request received by the mock server. The result can then be compared with
// a recorded URL cleaned by redactReplayURL.
//   - URLs containing "project" are treated as Vertex URLs and lose
//     everything through ".../projects/{project}/locations/{location}/".
//   - All other URLs lose only their leading "/{segment}/".
func redactSDKURL(url string) string {
	if strings.Contains(url, "project") {
		return replaySDKVertexURLPrefixRegexp.ReplaceAllString(url, "")
	}
	return replaySDKMLDevURLPrefixRegexp.ReplaceAllString(url, "")
}
169
// replayProjectLocationPathRegexp matches a
// "projects/{project}/locations/{location}" path fragment. It is compiled
// once at package scope rather than on every call.
var replayProjectLocationPathRegexp = regexp.MustCompile(`projects/[^/]+/locations/[^/]+`)

// redactProjectLocationPath replaces every "projects/{project}/locations/{location}"
// fragment in path with "{PROJECT_AND_LOCATION_PATH}". Request fields
// containing such fragments vary with the project and location the tests run
// against, so they are normalized before comparison.
func redactProjectLocationPath(path string) string {
	return replayProjectLocationPathRegexp.ReplaceAllString(path, "{PROJECT_AND_LOCATION_PATH}")
}
176
177func redactRequestBody(body map[string]any) map[string]any {
178 for key, value := range body {
179 if _, ok := value.(string); ok {
180 body[key] = redactProjectLocationPath(value.(string))
181 }
182 }
183 return body
184}
185
186func (rac *replayAPIClient) assertRequest(sdkRequest *http.Request, replayRequest *replayRequest) {
187 rac.t.Helper()
188 sdkRequestBody, err := io.ReadAll(sdkRequest.Body)
189 if err != nil {
190 rac.t.Errorf("Error reading request body, err: %+v", err)
191 }
192 bodySegment := make(map[string]any)
193 if len(sdkRequestBody) > 0 {
194 if err := json.Unmarshal(sdkRequestBody, &bodySegment); err != nil {
195 rac.t.Errorf("Error unmarshalling body, err: %+v", err)
196 }
197 }
198 bodySegment = redactRequestBody(bodySegment)
199 bodySegment = convertKeysToCamelCase(bodySegment, "").(map[string]any)
200 omitEmptyValues(bodySegment)
201
202 headers := make(map[string]string)
203 for k, v := range sdkRequest.Header {
204 headers[k] = strings.Join(v, ",")
205 }
206 // TODO(b/390425822): support headers validation.
207 got := map[string]any{
208 "method": strings.ToLower(sdkRequest.Method),
209 "url": redactSDKURL(sdkRequest.URL.String()),
210 "bodySegments": []map[string]any{bodySegment},
211 }
212
213 want := map[string]any{
214 "method": replayRequest.Method,
215 "url": redactReplayURL(replayRequest.URL),
216 "bodySegments": replayRequest.BodySegments,
217 }
218
219 opts := cmp.Options{
220 stringComparator,
221 }
222
223 if diff := cmp.Diff(got, want, opts); diff != "" {
224 rac.t.Errorf("Requests had diffs (-got +want):\n%v", diff)
225 }
226}
227
// omitEmptyValues deletes "empty" entries, in place, from every map reachable
// from v through map[string]any, []any and []map[string]any values.
//
// An entry is empty when its value is:
//   - JSON null (nil);
//   - the zero value of its dynamic type (e.g. "", 0, false, or a nil
//     map/slice);
//   - the zero-time string "0001-01-01T00:00:00Z".
//
// Non-empty map values and all slice elements are processed recursively.
// Slices are never shortened, and a map that becomes empty only after its own
// entries are removed is kept. Any other value of v is ignored.
func omitEmptyValues(v any) {
	switch node := v.(type) {
	case map[string]any:
		for key, value := range node {
			// reflect.ValueOf(nil) is the invalid Value, on which IsZero
			// panics, so JSON nulls must be handled before calling it.
			if value == nil || reflect.ValueOf(value).IsZero() || value == "0001-01-01T00:00:00Z" {
				// Deleting during range is permitted in Go.
				delete(node, key)
				continue
			}
			omitEmptyValues(value)
		}
	case []any:
		for _, item := range node {
			omitEmptyValues(item)
		}
	case []map[string]any:
		for _, item := range node {
			omitEmptyValues(item)
		}
	}
}
254
255func convertKeysToCamelCase(v any, parentKey string) any {
256 if v == nil {
257 return nil
258 }
259 switch m := v.(type) {
260 case map[string]any:
261 newMap := make(map[string]any)
262 for key, value := range m {
263 camelCaseKey := toCamelCase(key)
264 if parentKey == "response" && key == "body_segments" {
265 newMap[camelCaseKey] = value
266 } else {
267 newMap[camelCaseKey] = convertKeysToCamelCase(value, key)
268 }
269 }
270 return newMap
271 case []any:
272 newSlice := make([]any, len(m))
273 for i, item := range m {
274 newSlice[i] = convertKeysToCamelCase(item, parentKey)
275 }
276 return newSlice
277 default:
278 return v
279 }
280}
281
// toCamelCase converts a string from snake case to camel case.
//   - The first underscore-separated word is kept as is.
//   - Every following word has its first character (rune) upper-cased.
//   - Strings without underscores are returned unchanged.
//   - Empty words, produced by repeated or trailing underscores, are dropped.
//
// Examples:
//
//	"foo"         -> "foo"
//	"fooBar"      -> "fooBar"
//	"foo_bar"     -> "fooBar"
//	"foo_bar_baz" -> "fooBarBaz"
//	"foo__bar"    -> "fooBar"
//	"foo_"        -> "foo"
func toCamelCase(s string) string {
	parts := strings.Split(s, "_")
	if len(parts) == 1 {
		// There is no underscore, so no need to modify the string.
		return s
	}
	var b strings.Builder
	b.Grow(len(s))
	// The first word keeps its original casing.
	b.WriteString(parts[0])
	for _, part := range parts[1:] {
		if part == "" {
			// Produced by "__" or a trailing "_"; slicing part[:1] here
			// would panic.
			continue
		}
		// Upper-case the first rune, not the first byte, so multi-byte UTF-8
		// characters are not split. The loop finds the byte offset where the
		// second rune starts (or the end of part).
		firstLen := len(part)
		for i := range part {
			if i > 0 {
				firstLen = i
				break
			}
		}
		b.WriteString(strings.ToUpper(part[:firstLen]))
		b.WriteString(part[firstLen:])
	}
	return b.String()
}
302
303var stringComparator = cmp.Comparer(func(x, y string) bool {
304 if timeStringComparator(x, y) || base64StringComparator(x, y) || floatStringComparator(x, y) {
305 return true
306 }
307 return x == y
308})
309
310var floatComparator = cmp.Comparer(func(x, y float64) bool {
311 return math.Abs(x-y) < 1e-6
312})
313
// floatStringComparator compares x and y as numbers when both parse as
// float64 (strconv.ParseFloat, which also accepts "Inf", "Infinity" and "NaN"
// in any case). Identical values match, which covers equal infinities, and so
// do values within an absolute tolerance of 1e-6. If either string does not
// parse, it falls back to exact string equality.
var floatStringComparator = func(x, y string) bool {
	vx, err := strconv.ParseFloat(x, 64)
	if err != nil {
		return x == y
	}
	vy, err := strconv.ParseFloat(y, 64)
	if err != nil {
		return x == y
	}
	// +Inf - +Inf is NaN, so test exact equality before the tolerance.
	return vx == vy || math.Abs(vx-vy) < 1e-6
}
325
// timeStringComparator reports whether x and y denote the same instant when
// both parse as RFC 3339 timestamps. Each instant is first truncated to whole
// microseconds, so sub-microsecond digits do not matter. If either string
// does not parse, it falls back to exact string equality.
var timeStringComparator = func(x, y string) bool {
	toMicros := func(s string) (time.Time, error) {
		ts, err := time.Parse(time.RFC3339, s)
		return ts.Truncate(time.Microsecond), err
	}
	xt, err := toMicros(x)
	if err != nil {
		return x == y
	}
	yt, err := toMicros(y)
	if err != nil {
		return x == y
	}
	return xt.Equal(yt)
}
337
// base64StringComparator reports whether x and y decode to the same bytes.
// Each string is tried first as padded URL-safe base64, then as padded
// standard base64. If either string is not valid base64 in both alphabets, it
// falls back to exact string equality.
var base64StringComparator = func(x, y string) bool {
	decode := func(s string) ([]byte, error) {
		if b, err := base64.URLEncoding.DecodeString(s); err == nil {
			return b, nil
		}
		if b, err := base64.StdEncoding.DecodeString(s); err == nil {
			return b, nil
		}
		return nil, fmt.Errorf("invalid base64 string %s", s)
	}

	xBytes, err := decode(x)
	if err != nil {
		return x == y
	}
	yBytes, err := decode(y)
	if err != nil {
		return x == y
	}
	return bytes.Equal(xBytes, yBytes)
}