replay_api_client.go

  1// Copyright 2024 Google LLC
  2//
  3// Licensed under the Apache License, Version 2.0 (the "License");
  4// you may not use this file except in compliance with the License.
  5// You may obtain a copy of the License at
  6//
  7//      http://www.apache.org/licenses/LICENSE-2.0
  8//
  9// Unless required by applicable law or agreed to in writing, software
 10// distributed under the License is distributed on an "AS IS" BASIS,
 11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 12// See the License for the specific language governing permissions and
 13// limitations under the License.
 14
 15package genai
 16
 17import (
 18	"bytes"
 19	"encoding/base64"
 20	"encoding/json"
 21	"fmt"
 22	"io"
 23	"math"
 24	"net/http"
 25	"net/http/httptest"
 26	"os"
 27	"path/filepath"
 28	"reflect"
 29	"regexp"
 30	"strconv"
 31	"strings"
 32	"testing"
 33	"time"
 34
 35	"github.com/google/go-cmp/cmp"
 36)
 37
// replayAPIClient is a test helper that answers HTTP requests with the
// responses recorded in a replay session file. It acts as the http.Handler
// (see ServeHTTP) of a local httptest.Server.
//
// Fields:
//   - ReplayFile: the loaded replay session; nil until LoadReplay is called.
//   - ReplaysDirectory: base directory that replay file paths are joined to.
//     It is read from the GOOGLE_GENAI_REPLAYS_DIRECTORY environment variable
//     and may be empty, in which case paths are used as given.
//   - currentInteractionIndex: index of the next recorded interaction to
//     serve; incremented by ServeHTTP.
//   - t: the test that owns this client, used to report failures.
//   - server: the local HTTP server that dispatches requests to this client.
type replayAPIClient struct {
	ReplayFile              *replayFile
	ReplaysDirectory        string
	currentInteractionIndex int
	t                       *testing.T
	server                  *httptest.Server
}
 46
 47// NewReplayAPIClient creates a new ReplayAPIClient from a replay session file.
 48func newReplayAPIClient(t *testing.T) *replayAPIClient {
 49	t.Helper()
 50	// The replay files are expected to be in a directory specified by the environment variable
 51	// GOOGLE_GENAI_REPLAYS_DIRECTORY.
 52	replaysDirectory := os.Getenv("GOOGLE_GENAI_REPLAYS_DIRECTORY")
 53	rac := &replayAPIClient{
 54		ReplayFile:              nil,
 55		ReplaysDirectory:        replaysDirectory,
 56		currentInteractionIndex: 0,
 57		t:                       t,
 58	}
 59	rac.server = httptest.NewServer(rac)
 60	rac.t.Cleanup(func() {
 61		rac.server.Close()
 62	})
 63	return rac
 64}
 65
 66// GetBaseURL returns the URL of the mocked HTTP server.
 67func (rac *replayAPIClient) GetBaseURL() string {
 68	return rac.server.URL
 69}
 70
 71// LoadReplay populates a replay session from a file based on the provided path.
 72func (rac *replayAPIClient) LoadReplay(replayFilePath string) {
 73	rac.t.Helper()
 74	fullReplaysPath := replayFilePath
 75	if rac.ReplaysDirectory != "" {
 76		fullReplaysPath = filepath.Join(rac.ReplaysDirectory, replayFilePath)
 77	}
 78	var replayFile replayFile
 79	if err := readFileForReplayTest(fullReplaysPath, &replayFile, true); err != nil {
 80		rac.t.Errorf("error loading replay file, %v", err)
 81	}
 82	rac.ReplayFile = &replayFile
 83}
 84
 85// LatestInteraction returns the interaction that was returned by the last call to ServeHTTP.
 86func (rac *replayAPIClient) LatestInteraction() *replayInteraction {
 87	rac.t.Helper()
 88	if rac.currentInteractionIndex == 0 {
 89		rac.t.Fatalf("no interactions has been made in replay session so far")
 90	}
 91	return rac.ReplayFile.Interactions[rac.currentInteractionIndex-1]
 92}
 93
 94// ServeHTTP mocks serving HTTP requests.
 95func (rac *replayAPIClient) ServeHTTP(w http.ResponseWriter, req *http.Request) {
 96	rac.t.Helper()
 97	if rac.ReplayFile == nil {
 98		rac.t.Fatalf("no replay file loaded")
 99	}
100	if rac.currentInteractionIndex >= len(rac.ReplayFile.Interactions) {
101		rac.t.Fatalf("no more interactions in replay session")
102	}
103	interaction := rac.ReplayFile.Interactions[rac.currentInteractionIndex]
104
105	rac.assertRequest(req, interaction.Request)
106	rac.currentInteractionIndex++
107	var bodySegments []string
108	for i := 0; i < len(interaction.Response.BodySegments); i++ {
109		responseBodySegment, err := json.Marshal(interaction.Response.BodySegments[i])
110		if err != nil {
111			rac.t.Errorf("error marshalling responseBodySegment [%s], err: %+v", rac.ReplayFile.ReplayID, err)
112		}
113		bodySegments = append(bodySegments, string(responseBodySegment))
114	}
115	if interaction.Response.StatusCode != 0 {
116		w.WriteHeader(int(interaction.Response.StatusCode))
117	} else {
118		w.WriteHeader(http.StatusOK)
119	}
120	_, err := w.Write([]byte(strings.Join(bodySegments, "\n")))
121	if err != nil {
122		rac.t.Errorf("error writing response, err: %+v", err)
123	}
124}
125
126func readFileForReplayTest[T any](path string, output *T, omitempty bool) error {
127	dat, err := os.ReadFile(path)
128	if err != nil {
129		return err
130	}
131
132	var m map[string]any
133	if err := json.Unmarshal(dat, &m); err != nil {
134		return fmt.Errorf("error unmarshalling to map: %w", err)
135	}
136
137	if omitempty {
138		omitEmptyValues(m)
139	}
140	m = convertKeysToCamelCase(m, "").(map[string]any)
141
142	// Marshal the modified map back to struct
143	err = mapToStruct(m, output)
144	if err != nil {
145		return fmt.Errorf("error converting map to struct: %w", err)
146	}
147
148	return nil
149}
150
// redactReplayURL removes the "{MLDEV_URL_PREFIX}/" and "{VERTEX_URL_PREFIX}/"
// placeholders from a recorded URL. Requests seen by the test server have an
// empty scheme and host, so the placeholders are dropped rather than expanded.
func redactReplayURL(url string) string {
	for _, placeholder := range []string{"{MLDEV_URL_PREFIX}/", "{VERTEX_URL_PREFIX}/"} {
		url = strings.ReplaceAll(url, placeholder, "")
	}
	return url
}
157
// Patterns used by redactSDKURL, compiled once at package initialisation
// instead of on every call.
var (
	// sdkVertexPrefixRegexp matches everything up to and including the
	// "/projects/<project>/locations/<location>/" part of a URL.
	sdkVertexPrefixRegexp = regexp.MustCompile(`.*/projects/[^/]+/locations/[^/]+/`)
	// sdkMLDevPrefixRegexp matches the first path segment with its
	// surrounding slashes (for example "/v1beta/").
	sdkMLDevPrefixRegexp = regexp.MustCompile(`^\/[^/]+\/`)
)

// redactSDKURL strips the environment-specific prefix from the URL of a
// request issued by the SDK.
//
// URLs containing the substring "project" have everything through
// ".../projects/<project>/locations/<location>/" removed; such a URL is
// returned unchanged when that pattern does not occur. All other URLs lose
// their first path segment.
func redactSDKURL(url string) string {
	if strings.Contains(url, "project") {
		return sdkVertexPrefixRegexp.ReplaceAllString(url, "")
	}
	return sdkMLDevPrefixRegexp.ReplaceAllString(url, "")
}
169
// projectLocationPathRegexp matches a
// "projects/<project>/locations/<location>" resource prefix. It is compiled
// once at package initialisation instead of on every call.
var projectLocationPathRegexp = regexp.MustCompile(`projects/[^/]+/locations/[^/]+`)

// redactProjectLocationPath replaces every
// "projects/<project>/locations/<location>" occurrence in path with the
// placeholder "{PROJECT_AND_LOCATION_PATH}", so that values which vary with
// the project and location compare equal across environments. Strings without
// such an occurrence are returned unchanged.
func redactProjectLocationPath(path string) string {
	return projectLocationPathRegexp.ReplaceAllString(path, "{PROJECT_AND_LOCATION_PATH}")
}
176
177func redactRequestBody(body map[string]any) map[string]any {
178	for key, value := range body {
179		if _, ok := value.(string); ok {
180			body[key] = redactProjectLocationPath(value.(string))
181		}
182	}
183	return body
184}
185
186func (rac *replayAPIClient) assertRequest(sdkRequest *http.Request, replayRequest *replayRequest) {
187	rac.t.Helper()
188	sdkRequestBody, err := io.ReadAll(sdkRequest.Body)
189	if err != nil {
190		rac.t.Errorf("Error reading request body, err: %+v", err)
191	}
192	bodySegment := make(map[string]any)
193	if len(sdkRequestBody) > 0 {
194		if err := json.Unmarshal(sdkRequestBody, &bodySegment); err != nil {
195			rac.t.Errorf("Error unmarshalling body, err: %+v", err)
196		}
197	}
198	bodySegment = redactRequestBody(bodySegment)
199	bodySegment = convertKeysToCamelCase(bodySegment, "").(map[string]any)
200	omitEmptyValues(bodySegment)
201
202	headers := make(map[string]string)
203	for k, v := range sdkRequest.Header {
204		headers[k] = strings.Join(v, ",")
205	}
206	// TODO(b/390425822): support headers validation.
207	got := map[string]any{
208		"method":       strings.ToLower(sdkRequest.Method),
209		"url":          redactSDKURL(sdkRequest.URL.String()),
210		"bodySegments": []map[string]any{bodySegment},
211	}
212
213	want := map[string]any{
214		"method":       replayRequest.Method,
215		"url":          redactReplayURL(replayRequest.URL),
216		"bodySegments": replayRequest.BodySegments,
217	}
218
219	opts := cmp.Options{
220		stringComparator,
221	}
222
223	if diff := cmp.Diff(got, want, opts); diff != "" {
224		rac.t.Errorf("Requests had diffs (-got +want):\n%v", diff)
225	}
226}
227
// omitEmptyValues recursively walks v and, for every `map[string]any` it
// reaches (directly, or through `[]any` / `[]map[string]any`), deletes in
// place the entries whose value is empty. A value is empty when it is nil
// (JSON null), the zero value of its dynamic type ("" / 0 / false / nil map or
// slice), or the zero timestamp string "0001-01-01T00:00:00Z". Empty but
// non-nil maps and slices are kept. Values of any other type are ignored.
func omitEmptyValues(v any) {
	if v == nil {
		return
	}
	switch node := v.(type) {
	case map[string]any:
		for key, value := range node {
			// reflect.ValueOf(nil) yields the invalid Value, on which IsZero
			// panics, so nil has to be recognised before reflecting.
			if value == nil || value == "0001-01-01T00:00:00Z" || reflect.ValueOf(value).IsZero() {
				delete(node, key)
				continue
			}
			omitEmptyValues(value)
		}
	case []any:
		for _, item := range node {
			omitEmptyValues(item)
		}
	case []map[string]any:
		for _, item := range node {
			omitEmptyValues(item)
		}
	}
}
254
255func convertKeysToCamelCase(v any, parentKey string) any {
256	if v == nil {
257		return nil
258	}
259	switch m := v.(type) {
260	case map[string]any:
261		newMap := make(map[string]any)
262		for key, value := range m {
263			camelCaseKey := toCamelCase(key)
264			if parentKey == "response" && key == "body_segments" {
265				newMap[camelCaseKey] = value
266			} else {
267				newMap[camelCaseKey] = convertKeysToCamelCase(value, key)
268			}
269		}
270		return newMap
271	case []any:
272		newSlice := make([]any, len(m))
273		for i, item := range m {
274			newSlice[i] = convertKeysToCamelCase(item, parentKey)
275		}
276		return newSlice
277	default:
278		return v
279	}
280}
281
// toCamelCase converts a string from snake case to camel case.
//
// The first underscore-separated word is kept as is; the first byte of every
// following word is upper-cased and the underscores are dropped. Empty words,
// produced by consecutive or trailing underscores, are skipped.
// Examples:
//
//	"foo" -> "foo"
//	"fooBar" -> "fooBar"
//	"foo_bar" -> "fooBar"
//	"foo_bar_baz" -> "fooBarBaz"
//	"foo__bar" -> "fooBar"
//	"foo_" -> "foo"
func toCamelCase(s string) string {
	parts := strings.Split(s, "_")
	if len(parts) == 1 {
		// There is no underscore, so no need to modify the string.
		return s
	}
	var b strings.Builder
	b.Grow(len(s))
	b.WriteString(parts[0])
	for _, part := range parts[1:] {
		if part == "" {
			// Slicing an empty word with part[:1] would panic.
			continue
		}
		b.WriteString(strings.ToUpper(part[:1]))
		b.WriteString(part[1:])
	}
	return b.String()
}
302
303var stringComparator = cmp.Comparer(func(x, y string) bool {
304	if timeStringComparator(x, y) || base64StringComparator(x, y) || floatStringComparator(x, y) {
305		return true
306	}
307	return x == y
308})
309
310var floatComparator = cmp.Comparer(func(x, y float64) bool {
311	return math.Abs(x-y) < 1e-6
312})
313
// floatStringComparator reports whether x and y are the same string or both
// parse as floating-point numbers that are equal or differ by less than 1e-6.
// If either string is not a number, only identical strings are equal.
var floatStringComparator = func(x, y string) bool {
	// Identical strings are always equal. This also covers "NaN" and "Inf",
	// which the numeric test below would reject (NaN never compares equal,
	// and Inf-Inf is NaN).
	if x == y {
		return true
	}
	vx, errX := strconv.ParseFloat(x, 64)
	vy, errY := strconv.ParseFloat(y, 64)
	if errX != nil || errY != nil {
		// The strings differ and at least one is not a number.
		return false
	}
	// vx == vy handles equal infinities spelled differently ("Inf", "+Inf").
	return vx == vy || math.Abs(vx-vy) < 1e-6
}
325
// timeStringComparator reports whether x and y denote the same instant when
// both are RFC 3339 timestamps, compared at microsecond precision. If either
// string is not a valid RFC 3339 timestamp, it falls back to plain string
// equality.
var timeStringComparator = func(x, y string) bool {
	xTime, xErr := time.Parse(time.RFC3339, x)
	yTime, yErr := time.Parse(time.RFC3339, y)
	if xErr != nil || yErr != nil {
		return x == y
	}
	// Anything finer than a microsecond is discarded before comparing.
	const precision = time.Microsecond
	return xTime.Truncate(precision).Equal(yTime.Truncate(precision))
}
337
// base64StringComparator reports whether x and y decode to the same bytes when
// both are valid base64. Each string is tried as URL-safe base64 first and as
// standard base64 second. If either string cannot be decoded, it falls back to
// plain string equality.
var base64StringComparator = func(x, y string) bool {
	decode := func(s string) ([]byte, error) {
		for _, encoding := range []*base64.Encoding{base64.URLEncoding, base64.StdEncoding} {
			if raw, err := encoding.DecodeString(s); err == nil {
				return raw, nil
			}
		}
		return nil, fmt.Errorf("invalid base64 string %s", s)
	}

	xBytes, err := decode(x)
	if err != nil {
		return x == y
	}
	yBytes, err := decode(y)
	if err != nil {
		return x == y
	}
	return bytes.Equal(xBytes, yBytes)
}