gem_test.go

  1package gem
  2
  3import (
  4	"bytes"
  5	"context"
  6	"encoding/json"
  7	"io"
  8	"net/http"
  9	"testing"
 10
 11	"shelley.exe.dev/llm"
 12	"shelley.exe.dev/llm/gem/gemini"
 13)
 14
 15func TestBuildGeminiRequest(t *testing.T) {
 16	// Create a service
 17	service := &Service{
 18		Model:  DefaultModel,
 19		APIKey: "test-api-key",
 20	}
 21
 22	// Create a simple request
 23	req := &llm.Request{
 24		Messages: []llm.Message{
 25			{
 26				Role: llm.MessageRoleUser,
 27				Content: []llm.Content{
 28					{
 29						Type: llm.ContentTypeText,
 30						Text: "Hello, world!",
 31					},
 32				},
 33			},
 34		},
 35		System: []llm.SystemContent{
 36			{
 37				Text: "You are a helpful assistant.",
 38			},
 39		},
 40	}
 41
 42	// Build the Gemini request
 43	gemReq, err := service.buildGeminiRequest(req)
 44	if err != nil {
 45		t.Fatalf("Failed to build Gemini request: %v", err)
 46	}
 47
 48	// Verify the system instruction
 49	if gemReq.SystemInstruction == nil {
 50		t.Fatalf("Expected system instruction, got nil")
 51	}
 52	if len(gemReq.SystemInstruction.Parts) != 1 {
 53		t.Fatalf("Expected 1 system part, got %d", len(gemReq.SystemInstruction.Parts))
 54	}
 55	if gemReq.SystemInstruction.Parts[0].Text != "You are a helpful assistant." {
 56		t.Fatalf("Expected system text 'You are a helpful assistant.', got '%s'", gemReq.SystemInstruction.Parts[0].Text)
 57	}
 58
 59	// Verify the contents
 60	if len(gemReq.Contents) != 1 {
 61		t.Fatalf("Expected 1 content, got %d", len(gemReq.Contents))
 62	}
 63	if len(gemReq.Contents[0].Parts) != 1 {
 64		t.Fatalf("Expected 1 part, got %d", len(gemReq.Contents[0].Parts))
 65	}
 66	if gemReq.Contents[0].Parts[0].Text != "Hello, world!" {
 67		t.Fatalf("Expected text 'Hello, world!', got '%s'", gemReq.Contents[0].Parts[0].Text)
 68	}
 69	// Verify the role is set correctly
 70	if gemReq.Contents[0].Role != "user" {
 71		t.Fatalf("Expected role 'user', got '%s'", gemReq.Contents[0].Role)
 72	}
 73}
 74
 75func TestConvertToolSchemas(t *testing.T) {
 76	// Create a simple tool with a JSON schema
 77	schema := `{
 78		"type": "object",
 79		"properties": {
 80			"name": {
 81				"type": "string",
 82				"description": "The name of the person"
 83			},
 84			"age": {
 85				"type": "integer",
 86				"description": "The age of the person"
 87			}
 88		},
 89		"required": ["name"]
 90	}`
 91
 92	tools := []*llm.Tool{
 93		{
 94			Name:        "get_person",
 95			Description: "Get information about a person",
 96			InputSchema: json.RawMessage(schema),
 97		},
 98	}
 99
100	// Convert the tools
101	decls, err := convertToolSchemas(tools)
102	if err != nil {
103		t.Fatalf("Failed to convert tool schemas: %v", err)
104	}
105
106	// Verify the result
107	if len(decls) != 1 {
108		t.Fatalf("Expected 1 declaration, got %d", len(decls))
109	}
110	if decls[0].Name != "get_person" {
111		t.Fatalf("Expected name 'get_person', got '%s'", decls[0].Name)
112	}
113	if decls[0].Description != "Get information about a person" {
114		t.Fatalf("Expected description 'Get information about a person', got '%s'", decls[0].Description)
115	}
116
117	// Verify the schema properties
118	if decls[0].Parameters.Type != 6 { // DataTypeOBJECT
119		t.Fatalf("Expected type OBJECT (6), got %d", decls[0].Parameters.Type)
120	}
121	if len(decls[0].Parameters.Properties) != 2 {
122		t.Fatalf("Expected 2 properties, got %d", len(decls[0].Parameters.Properties))
123	}
124	if decls[0].Parameters.Properties["name"].Type != 1 { // DataTypeSTRING
125		t.Fatalf("Expected name type STRING (1), got %d", decls[0].Parameters.Properties["name"].Type)
126	}
127	if decls[0].Parameters.Properties["age"].Type != 3 { // DataTypeINTEGER
128		t.Fatalf("Expected age type INTEGER (3), got %d", decls[0].Parameters.Properties["age"].Type)
129	}
130	if len(decls[0].Parameters.Required) != 1 || decls[0].Parameters.Required[0] != "name" {
131		t.Fatalf("Expected required field 'name', got %v", decls[0].Parameters.Required)
132	}
133}
134
135func TestService_Do_MockResponse(t *testing.T) {
136	// This is a mock test that doesn't make actual API calls
137	// Create a mock HTTP client that returns a predefined response
138
139	// Create a Service with a mock client
140	service := &Service{
141		Model:  DefaultModel,
142		APIKey: "test-api-key",
143		// We would use a mock HTTP client here in a real test
144	}
145
146	// Create a sample request
147	ir := &llm.Request{
148		Messages: []llm.Message{
149			{
150				Role: llm.MessageRoleUser,
151				Content: []llm.Content{
152					{
153						Type: llm.ContentTypeText,
154						Text: "Hello",
155					},
156				},
157			},
158		},
159	}
160
161	// In a real test, we would execute service.Do with a mock client
162	// and verify the response structure
163
164	// For now, we'll just test that buildGeminiRequest works correctly
165	_, err := service.buildGeminiRequest(ir)
166	if err != nil {
167		t.Fatalf("Failed to build request: %v", err)
168	}
169}
170
171func TestConvertResponseWithToolCall(t *testing.T) {
172	// Create a mock Gemini response with a function call
173	gemRes := &gemini.Response{
174		Candidates: []gemini.Candidate{
175			{
176				Content: gemini.Content{
177					Parts: []gemini.Part{
178						{
179							FunctionCall: &gemini.FunctionCall{
180								Name: "bash",
181								Args: map[string]any{
182									"command": "cat README.md",
183								},
184							},
185						},
186					},
187				},
188			},
189		},
190	}
191
192	// Convert the response
193	content := convertGeminiResponseToContent(gemRes)
194
195	// Verify that content has a tool use
196	if len(content) != 1 {
197		t.Fatalf("Expected 1 content item, got %d", len(content))
198	}
199
200	if content[0].Type != llm.ContentTypeToolUse {
201		t.Fatalf("Expected content type ToolUse, got %s", content[0].Type)
202	}
203
204	if content[0].ToolName != "bash" {
205		t.Fatalf("Expected tool name 'bash', got '%s'", content[0].ToolName)
206	}
207
208	// Verify the tool input
209	var args map[string]any
210	if err := json.Unmarshal(content[0].ToolInput, &args); err != nil {
211		t.Fatalf("Failed to unmarshal tool input: %v", err)
212	}
213
214	cmd, ok := args["command"]
215	if !ok {
216		t.Fatalf("Expected 'command' argument, not found")
217	}
218
219	if cmd != "cat README.md" {
220		t.Fatalf("Expected command 'cat README.md', got '%s'", cmd)
221	}
222}
223
224func TestGeminiHeaderCapture(t *testing.T) {
225	// Create a mock HTTP client that returns a response with headers
226	mockClient := &http.Client{
227		Transport: &mockRoundTripper{
228			response: &http.Response{
229				StatusCode: http.StatusOK,
230				Header: http.Header{
231					"Content-Type":            []string{"application/json"},
232					"Skaband-Cost-Microcents": []string{"123456"},
233				},
234				Body: io.NopCloser(bytes.NewBufferString(`{
235					"candidates": [{
236						"content": {
237							"parts": [{
238								"text": "Hello!"
239							}]
240						}
241					}]
242				}`)),
243			},
244		},
245	}
246
247	// Create a Gemini model with the mock client
248	model := gemini.Model{
249		Model:    "models/gemini-test",
250		APIKey:   "test-key",
251		HTTPC:    mockClient,
252		Endpoint: "https://test.googleapis.com",
253	}
254
255	// Make a request
256	req := &gemini.Request{
257		Contents: []gemini.Content{
258			{
259				Parts: []gemini.Part{{Text: "Hello"}},
260				Role:  "user",
261			},
262		},
263	}
264
265	ctx := context.Background()
266	res, err := model.GenerateContent(ctx, req)
267	if err != nil {
268		t.Fatalf("Failed to generate content: %v", err)
269	}
270
271	// Verify that headers were captured
272	headers := res.Header()
273	if headers == nil {
274		t.Fatalf("Expected headers to be captured, got nil")
275	}
276
277	// Check for the cost header
278	costHeader := headers.Get("Skaband-Cost-Microcents")
279	if costHeader != "123456" {
280		t.Fatalf("Expected cost header '123456', got '%s'", costHeader)
281	}
282
283	// Verify that llm.CostUSDFromResponse works with these headers
284	costUSD := llm.CostUSDFromResponse(headers)
285	expectedCost := 0.00123456 // 123456 microcents / 100,000,000
286	if costUSD != expectedCost {
287		t.Fatalf("Expected cost USD %.8f, got %.8f", expectedCost, costUSD)
288	}
289}
290
// mockRoundTripper is a canned HTTP transport for tests: it performs no
// network I/O and hands back the same preconfigured response for every
// request it is given.
type mockRoundTripper struct {
	response *http.Response // returned verbatim by RoundTrip
}

// RoundTrip ignores the incoming request and returns the stored response
// with a nil error.
func (rt *mockRoundTripper) RoundTrip(_ *http.Request) (*http.Response, error) {
	canned := rt.response
	return canned, nil
}
299
300func TestHeaderCostIntegration(t *testing.T) {
301	// Create a mock HTTP client that returns a response with cost headers
302	mockClient := &http.Client{
303		Transport: &mockRoundTripper{
304			response: &http.Response{
305				StatusCode: http.StatusOK,
306				Header: http.Header{
307					"Content-Type":            []string{"application/json"},
308					"Skaband-Cost-Microcents": []string{"50000"}, // 0.5 USD
309				},
310				Body: io.NopCloser(bytes.NewBufferString(`{
311					"candidates": [{
312						"content": {
313							"parts": [{
314								"text": "Test response"
315							}]
316						}
317					}]
318				}`)),
319			},
320		},
321	}
322
323	// Create a Gem service with the mock client
324	service := &Service{
325		Model:  "gemini-test",
326		APIKey: "test-key",
327		HTTPC:  mockClient,
328		URL:    "https://test.googleapis.com",
329	}
330
331	// Create a request
332	ir := &llm.Request{
333		Messages: []llm.Message{
334			{
335				Role: llm.MessageRoleUser,
336				Content: []llm.Content{
337					{
338						Type: llm.ContentTypeText,
339						Text: "Hello",
340					},
341				},
342			},
343		},
344	}
345
346	// Make the request
347	ctx := context.Background()
348	res, err := service.Do(ctx, ir)
349	if err != nil {
350		t.Fatalf("Failed to make request: %v", err)
351	}
352
353	// Verify that the cost was captured from headers
354	expectedCost := 0.0005 // 50000 microcents / 100,000,000
355	if res.Usage.CostUSD != expectedCost {
356		t.Fatalf("Expected cost USD %.8f, got %.8f", expectedCost, res.Usage.CostUSD)
357	}
358
359	// Verify token counts are still estimated
360	if res.Usage.InputTokens == 0 {
361		t.Fatalf("Expected input tokens to be estimated, got 0")
362	}
363	if res.Usage.OutputTokens == 0 {
364		t.Fatalf("Expected output tokens to be estimated, got 0")
365	}
366}
367
368func TestTokenContextWindow(t *testing.T) {
369	tests := []struct {
370		name     string
371		model    string
372		expected int
373	}{
374		{
375			name:     "gemini-3-pro-preview",
376			model:    "gemini-3-pro-preview",
377			expected: 1000000,
378		},
379		{
380			name:     "gemini-3-flash-preview",
381			model:    "gemini-3-flash-preview",
382			expected: 1000000,
383		},
384		{
385			name:     "gemini-2.5-pro",
386			model:    "gemini-2.5-pro",
387			expected: 1000000,
388		},
389		{
390			name:     "gemini-2.5-flash",
391			model:    "gemini-2.5-flash",
392			expected: 1000000,
393		},
394		{
395			name:     "gemini-2.0-flash-exp",
396			model:    "gemini-2.0-flash-exp",
397			expected: 1000000,
398		},
399		{
400			name:     "gemini-2.0-flash",
401			model:    "gemini-2.0-flash",
402			expected: 1000000,
403		},
404		{
405			name:     "gemini-1.5-pro",
406			model:    "gemini-1.5-pro",
407			expected: 2000000,
408		},
409		{
410			name:     "gemini-1.5-pro-latest",
411			model:    "gemini-1.5-pro-latest",
412			expected: 2000000,
413		},
414		{
415			name:     "gemini-1.5-flash",
416			model:    "gemini-1.5-flash",
417			expected: 1000000,
418		},
419		{
420			name:     "gemini-1.5-flash-latest",
421			model:    "gemini-1.5-flash-latest",
422			expected: 1000000,
423		},
424		{
425			name:     "default model",
426			model:    "",
427			expected: 1000000,
428		},
429		{
430			name:     "unknown model",
431			model:    "unknown-model",
432			expected: 1000000,
433		},
434	}
435
436	for _, tt := range tests {
437		t.Run(tt.name, func(t *testing.T) {
438			service := &Service{
439				Model: tt.model,
440			}
441			got := service.TokenContextWindow()
442			if got != tt.expected {
443				t.Errorf("TokenContextWindow() = %v, want %v", got, tt.expected)
444			}
445		})
446	}
447}
448
449func TestMaxImageDimension(t *testing.T) {
450	service := &Service{}
451	got := service.MaxImageDimension()
452	// Currently returns 0 as per implementation
453	expected := 0
454	if got != expected {
455		t.Errorf("MaxImageDimension() = %v, want %v", got, expected)
456	}
457}
458
459func TestEnsureToolIDs(t *testing.T) {
460	tests := []struct {
461		name     string
462		contents []llm.Content
463		wantIDs  bool
464	}{
465		{
466			name: "no tool uses",
467			contents: []llm.Content{
468				{
469					Type: llm.ContentTypeText,
470					Text: "Hello",
471				},
472			},
473			wantIDs: false,
474		},
475		{
476			name: "tool use with existing ID",
477			contents: []llm.Content{
478				{
479					ID:       "existing-id",
480					Type:     llm.ContentTypeToolUse,
481					ToolName: "test-tool",
482				},
483			},
484			wantIDs: true,
485		},
486		{
487			name: "tool use without ID",
488			contents: []llm.Content{
489				{
490					Type:     llm.ContentTypeToolUse,
491					ToolName: "test-tool",
492				},
493			},
494			wantIDs: true,
495		},
496		{
497			name: "mixed content",
498			contents: []llm.Content{
499				{
500					Type: llm.ContentTypeText,
501					Text: "Hello",
502				},
503				{
504					Type:     llm.ContentTypeToolUse,
505					ToolName: "test-tool",
506				},
507				{
508					ID:       "existing-id",
509					Type:     llm.ContentTypeToolUse,
510					ToolName: "test-tool-2",
511				},
512			},
513			wantIDs: true,
514		},
515	}
516
517	for _, tt := range tests {
518		t.Run(tt.name, func(t *testing.T) {
519			// Make a copy to avoid modifying the test data
520			contents := make([]llm.Content, len(tt.contents))
521			copy(contents, tt.contents)
522
523			ensureToolIDs(contents)
524
525			// Check if tool uses have IDs
526			hasGeneratedIDs := false
527			for _, content := range contents {
528				if content.Type == llm.ContentTypeToolUse {
529					if content.ID == "" {
530						t.Errorf("Tool use missing ID")
531					} else if content.ID != "existing-id" {
532						// This is a generated ID
533						hasGeneratedIDs = true
534					}
535				}
536			}
537
538			// If we expected IDs to be generated, check that at least one was
539			if tt.wantIDs && !hasGeneratedIDs {
540				// Check if all tool uses had existing IDs
541				hasExistingIDs := false
542				for _, content := range tt.contents {
543					if content.Type == llm.ContentTypeToolUse && content.ID != "" {
544						hasExistingIDs = true
545					}
546				}
547				if !hasExistingIDs {
548					t.Errorf("Expected generated IDs but none were found")
549				}
550			}
551		})
552	}
553}
554
555func TestCalculateUsage(t *testing.T) {
556	// Test with a simple request and response
557	req := &gemini.Request{
558		SystemInstruction: &gemini.Content{
559			Parts: []gemini.Part{
560				{Text: "You are a helpful assistant."},
561			},
562		},
563		Contents: []gemini.Content{
564			{
565				Parts: []gemini.Part{
566					{Text: "Hello, how are you?"},
567				},
568				Role: "user",
569			},
570		},
571	}
572
573	res := &gemini.Response{
574		Candidates: []gemini.Candidate{
575			{
576				Content: gemini.Content{
577					Parts: []gemini.Part{
578						{Text: "I'm doing well, thank you for asking!"},
579					},
580				},
581			},
582		},
583	}
584
585	usage := calculateUsage(req, res)
586
587	// Verify that we got some token counts (they'll be estimated)
588	if usage.InputTokens == 0 {
589		t.Errorf("Expected input tokens to be greater than 0, got %d", usage.InputTokens)
590	}
591	if usage.OutputTokens == 0 {
592		t.Errorf("Expected output tokens to be greater than 0, got %d", usage.OutputTokens)
593	}
594
595	// Test with nil response
596	usageNil := calculateUsage(req, nil)
597	if usageNil.InputTokens == 0 {
598		t.Errorf("Expected input tokens with nil response to be greater than 0, got %d", usageNil.InputTokens)
599	}
600	if usageNil.OutputTokens != 0 {
601		t.Errorf("Expected output tokens with nil response to be 0, got %d", usageNil.OutputTokens)
602	}
603
604	// Test with function calls
605	reqWithFunction := &gemini.Request{
606		Contents: []gemini.Content{
607			{
608				Parts: []gemini.Part{
609					{
610						FunctionCall: &gemini.FunctionCall{
611							Name: "test_function",
612							Args: map[string]any{
613								"param1": "value1",
614							},
615						},
616					},
617				},
618				Role: "user",
619			},
620		},
621	}
622
623	resWithFunction := &gemini.Response{
624		Candidates: []gemini.Candidate{
625			{
626				Content: gemini.Content{
627					Parts: []gemini.Part{
628						{
629							FunctionCall: &gemini.FunctionCall{
630								Name: "response_function",
631								Args: map[string]any{
632									"result": "success",
633								},
634							},
635						},
636					},
637				},
638			},
639		},
640	}
641
642	usageWithFunction := calculateUsage(reqWithFunction, resWithFunction)
643	if usageWithFunction.InputTokens == 0 {
644		t.Errorf("Expected input tokens with function calls to be greater than 0, got %d", usageWithFunction.InputTokens)
645	}
646	if usageWithFunction.OutputTokens == 0 {
647		t.Errorf("Expected output tokens with function calls to be greater than 0, got %d", usageWithFunction.OutputTokens)
648	}
649}
650
651func TestCalculateUsageWithFunctionResponse(t *testing.T) {
652	// Test with function response in input (tool result)
653	reqWithFunctionResponse := &gemini.Request{
654		Contents: []gemini.Content{
655			{
656				Parts: []gemini.Part{
657					{
658						FunctionResponse: &gemini.FunctionResponse{
659							Name: "test_function",
660							Response: map[string]any{
661								"result": "success",
662								"error":  nil,
663							},
664						},
665					},
666				},
667				Role: "user",
668			},
669		},
670	}
671
672	res := &gemini.Response{
673		Candidates: []gemini.Candidate{
674			{
675				Content: gemini.Content{
676					Parts: []gemini.Part{
677						{Text: "Hello"},
678					},
679				},
680			},
681		},
682	}
683
684	usage := calculateUsage(reqWithFunctionResponse, res)
685	// Should have some input tokens from the function response
686	if usage.InputTokens == 0 {
687		t.Errorf("Expected input tokens with function response to be greater than 0, got %d", usage.InputTokens)
688	}
689	if usage.OutputTokens == 0 {
690		t.Errorf("Expected output tokens to be greater than 0, got %d", usage.OutputTokens)
691	}
692}
693
694func TestCalculateUsageWithEmptyText(t *testing.T) {
695	// Test with empty text parts
696	req := &gemini.Request{
697		Contents: []gemini.Content{
698			{
699				Parts: []gemini.Part{
700					{Text: ""}, // Empty text
701				},
702				Role: "user",
703			},
704		},
705	}
706
707	res := &gemini.Response{
708		Candidates: []gemini.Candidate{
709			{
710				Content: gemini.Content{
711					Parts: []gemini.Part{
712						{Text: ""}, // Empty text
713					},
714				},
715			},
716		},
717	}
718
719	usage := calculateUsage(req, res)
720	// Should have 0 tokens for empty text
721	if usage.InputTokens != 0 {
722		t.Errorf("Expected input tokens to be 0 for empty text, got %d", usage.InputTokens)
723	}
724	if usage.OutputTokens != 0 {
725		t.Errorf("Expected output tokens to be 0 for empty text, got %d", usage.OutputTokens)
726	}
727}
728
729func TestCalculateUsageWithComplexFunctionCall(t *testing.T) {
730	// Test with complex function call arguments
731	req := &gemini.Request{
732		Contents: []gemini.Content{
733			{
734				Parts: []gemini.Part{
735					{
736						FunctionCall: &gemini.FunctionCall{
737							Name: "complex_function",
738							Args: map[string]any{
739								"string_param": "value",
740								"int_param":    42,
741								"array_param":  []any{"item1", "item2"},
742								"object_param": map[string]any{
743									"nested": "value",
744								},
745							},
746						},
747					},
748				},
749				Role: "user",
750			},
751		},
752	}
753
754	res := &gemini.Response{
755		Candidates: []gemini.Candidate{
756			{
757				Content: gemini.Content{
758					Parts: []gemini.Part{
759						{
760							FunctionCall: &gemini.FunctionCall{
761								Name: "response_function",
762								Args: map[string]any{
763									"complex_result": map[string]any{
764										"status": "success",
765										"data":   []any{1, 2, 3},
766									},
767								},
768							},
769						},
770					},
771				},
772			},
773		},
774	}
775
776	usage := calculateUsage(req, res)
777	if usage.InputTokens == 0 {
778		t.Errorf("Expected input tokens with complex function call to be greater than 0, got %d", usage.InputTokens)
779	}
780	if usage.OutputTokens == 0 {
781		t.Errorf("Expected output tokens with complex function call to be greater than 0, got %d", usage.OutputTokens)
782	}
783}