1package gem
2
3import (
4 "bytes"
5 "context"
6 "encoding/json"
7 "io"
8 "net/http"
9 "testing"
10
11 "shelley.exe.dev/llm"
12 "shelley.exe.dev/llm/gem/gemini"
13)
14
15func TestBuildGeminiRequest(t *testing.T) {
16 // Create a service
17 service := &Service{
18 Model: DefaultModel,
19 APIKey: "test-api-key",
20 }
21
22 // Create a simple request
23 req := &llm.Request{
24 Messages: []llm.Message{
25 {
26 Role: llm.MessageRoleUser,
27 Content: []llm.Content{
28 {
29 Type: llm.ContentTypeText,
30 Text: "Hello, world!",
31 },
32 },
33 },
34 },
35 System: []llm.SystemContent{
36 {
37 Text: "You are a helpful assistant.",
38 },
39 },
40 }
41
42 // Build the Gemini request
43 gemReq, err := service.buildGeminiRequest(req)
44 if err != nil {
45 t.Fatalf("Failed to build Gemini request: %v", err)
46 }
47
48 // Verify the system instruction
49 if gemReq.SystemInstruction == nil {
50 t.Fatalf("Expected system instruction, got nil")
51 }
52 if len(gemReq.SystemInstruction.Parts) != 1 {
53 t.Fatalf("Expected 1 system part, got %d", len(gemReq.SystemInstruction.Parts))
54 }
55 if gemReq.SystemInstruction.Parts[0].Text != "You are a helpful assistant." {
56 t.Fatalf("Expected system text 'You are a helpful assistant.', got '%s'", gemReq.SystemInstruction.Parts[0].Text)
57 }
58
59 // Verify the contents
60 if len(gemReq.Contents) != 1 {
61 t.Fatalf("Expected 1 content, got %d", len(gemReq.Contents))
62 }
63 if len(gemReq.Contents[0].Parts) != 1 {
64 t.Fatalf("Expected 1 part, got %d", len(gemReq.Contents[0].Parts))
65 }
66 if gemReq.Contents[0].Parts[0].Text != "Hello, world!" {
67 t.Fatalf("Expected text 'Hello, world!', got '%s'", gemReq.Contents[0].Parts[0].Text)
68 }
69 // Verify the role is set correctly
70 if gemReq.Contents[0].Role != "user" {
71 t.Fatalf("Expected role 'user', got '%s'", gemReq.Contents[0].Role)
72 }
73}
74
75func TestConvertToolSchemas(t *testing.T) {
76 // Create a simple tool with a JSON schema
77 schema := `{
78 "type": "object",
79 "properties": {
80 "name": {
81 "type": "string",
82 "description": "The name of the person"
83 },
84 "age": {
85 "type": "integer",
86 "description": "The age of the person"
87 }
88 },
89 "required": ["name"]
90 }`
91
92 tools := []*llm.Tool{
93 {
94 Name: "get_person",
95 Description: "Get information about a person",
96 InputSchema: json.RawMessage(schema),
97 },
98 }
99
100 // Convert the tools
101 decls, err := convertToolSchemas(tools)
102 if err != nil {
103 t.Fatalf("Failed to convert tool schemas: %v", err)
104 }
105
106 // Verify the result
107 if len(decls) != 1 {
108 t.Fatalf("Expected 1 declaration, got %d", len(decls))
109 }
110 if decls[0].Name != "get_person" {
111 t.Fatalf("Expected name 'get_person', got '%s'", decls[0].Name)
112 }
113 if decls[0].Description != "Get information about a person" {
114 t.Fatalf("Expected description 'Get information about a person', got '%s'", decls[0].Description)
115 }
116
117 // Verify the schema properties
118 if decls[0].Parameters.Type != 6 { // DataTypeOBJECT
119 t.Fatalf("Expected type OBJECT (6), got %d", decls[0].Parameters.Type)
120 }
121 if len(decls[0].Parameters.Properties) != 2 {
122 t.Fatalf("Expected 2 properties, got %d", len(decls[0].Parameters.Properties))
123 }
124 if decls[0].Parameters.Properties["name"].Type != 1 { // DataTypeSTRING
125 t.Fatalf("Expected name type STRING (1), got %d", decls[0].Parameters.Properties["name"].Type)
126 }
127 if decls[0].Parameters.Properties["age"].Type != 3 { // DataTypeINTEGER
128 t.Fatalf("Expected age type INTEGER (3), got %d", decls[0].Parameters.Properties["age"].Type)
129 }
130 if len(decls[0].Parameters.Required) != 1 || decls[0].Parameters.Required[0] != "name" {
131 t.Fatalf("Expected required field 'name', got %v", decls[0].Parameters.Required)
132 }
133}
134
135func TestService_Do_MockResponse(t *testing.T) {
136 // This is a mock test that doesn't make actual API calls
137 // Create a mock HTTP client that returns a predefined response
138
139 // Create a Service with a mock client
140 service := &Service{
141 Model: DefaultModel,
142 APIKey: "test-api-key",
143 // We would use a mock HTTP client here in a real test
144 }
145
146 // Create a sample request
147 ir := &llm.Request{
148 Messages: []llm.Message{
149 {
150 Role: llm.MessageRoleUser,
151 Content: []llm.Content{
152 {
153 Type: llm.ContentTypeText,
154 Text: "Hello",
155 },
156 },
157 },
158 },
159 }
160
161 // In a real test, we would execute service.Do with a mock client
162 // and verify the response structure
163
164 // For now, we'll just test that buildGeminiRequest works correctly
165 _, err := service.buildGeminiRequest(ir)
166 if err != nil {
167 t.Fatalf("Failed to build request: %v", err)
168 }
169}
170
171func TestConvertResponseWithToolCall(t *testing.T) {
172 // Create a mock Gemini response with a function call
173 gemRes := &gemini.Response{
174 Candidates: []gemini.Candidate{
175 {
176 Content: gemini.Content{
177 Parts: []gemini.Part{
178 {
179 FunctionCall: &gemini.FunctionCall{
180 Name: "bash",
181 Args: map[string]any{
182 "command": "cat README.md",
183 },
184 },
185 },
186 },
187 },
188 },
189 },
190 }
191
192 // Convert the response
193 content := convertGeminiResponseToContent(gemRes)
194
195 // Verify that content has a tool use
196 if len(content) != 1 {
197 t.Fatalf("Expected 1 content item, got %d", len(content))
198 }
199
200 if content[0].Type != llm.ContentTypeToolUse {
201 t.Fatalf("Expected content type ToolUse, got %s", content[0].Type)
202 }
203
204 if content[0].ToolName != "bash" {
205 t.Fatalf("Expected tool name 'bash', got '%s'", content[0].ToolName)
206 }
207
208 // Verify the tool input
209 var args map[string]any
210 if err := json.Unmarshal(content[0].ToolInput, &args); err != nil {
211 t.Fatalf("Failed to unmarshal tool input: %v", err)
212 }
213
214 cmd, ok := args["command"]
215 if !ok {
216 t.Fatalf("Expected 'command' argument, not found")
217 }
218
219 if cmd != "cat README.md" {
220 t.Fatalf("Expected command 'cat README.md', got '%s'", cmd)
221 }
222}
223
224func TestGeminiHeaderCapture(t *testing.T) {
225 // Create a mock HTTP client that returns a response with headers
226 mockClient := &http.Client{
227 Transport: &mockRoundTripper{
228 response: &http.Response{
229 StatusCode: http.StatusOK,
230 Header: http.Header{
231 "Content-Type": []string{"application/json"},
232 "Skaband-Cost-Microcents": []string{"123456"},
233 },
234 Body: io.NopCloser(bytes.NewBufferString(`{
235 "candidates": [{
236 "content": {
237 "parts": [{
238 "text": "Hello!"
239 }]
240 }
241 }]
242 }`)),
243 },
244 },
245 }
246
247 // Create a Gemini model with the mock client
248 model := gemini.Model{
249 Model: "models/gemini-test",
250 APIKey: "test-key",
251 HTTPC: mockClient,
252 Endpoint: "https://test.googleapis.com",
253 }
254
255 // Make a request
256 req := &gemini.Request{
257 Contents: []gemini.Content{
258 {
259 Parts: []gemini.Part{{Text: "Hello"}},
260 Role: "user",
261 },
262 },
263 }
264
265 ctx := context.Background()
266 res, err := model.GenerateContent(ctx, req)
267 if err != nil {
268 t.Fatalf("Failed to generate content: %v", err)
269 }
270
271 // Verify that headers were captured
272 headers := res.Header()
273 if headers == nil {
274 t.Fatalf("Expected headers to be captured, got nil")
275 }
276
277 // Check for the cost header
278 costHeader := headers.Get("Skaband-Cost-Microcents")
279 if costHeader != "123456" {
280 t.Fatalf("Expected cost header '123456', got '%s'", costHeader)
281 }
282
283 // Verify that llm.CostUSDFromResponse works with these headers
284 costUSD := llm.CostUSDFromResponse(headers)
285 expectedCost := 0.00123456 // 123456 microcents / 100,000,000
286 if costUSD != expectedCost {
287 t.Fatalf("Expected cost USD %.8f, got %.8f", expectedCost, costUSD)
288 }
289}
290
// mockRoundTripper is an http.RoundTripper test double. Every RoundTrip call
// ignores the outgoing request and returns the same canned response with a
// nil error. The identical *http.Response (including its Body) is handed out
// each time, so a test that sends more than one request through it would
// see an already-consumed body.
type mockRoundTripper struct {
	response *http.Response // returned verbatim from every RoundTrip call
}

// RoundTrip implements http.RoundTripper by returning the canned response.
func (m *mockRoundTripper) RoundTrip(_ *http.Request) (*http.Response, error) {
	canned := m.response
	return canned, nil
}
299
300func TestHeaderCostIntegration(t *testing.T) {
301 // Create a mock HTTP client that returns a response with cost headers
302 mockClient := &http.Client{
303 Transport: &mockRoundTripper{
304 response: &http.Response{
305 StatusCode: http.StatusOK,
306 Header: http.Header{
307 "Content-Type": []string{"application/json"},
308 "Skaband-Cost-Microcents": []string{"50000"}, // 0.5 USD
309 },
310 Body: io.NopCloser(bytes.NewBufferString(`{
311 "candidates": [{
312 "content": {
313 "parts": [{
314 "text": "Test response"
315 }]
316 }
317 }]
318 }`)),
319 },
320 },
321 }
322
323 // Create a Gem service with the mock client
324 service := &Service{
325 Model: "gemini-test",
326 APIKey: "test-key",
327 HTTPC: mockClient,
328 URL: "https://test.googleapis.com",
329 }
330
331 // Create a request
332 ir := &llm.Request{
333 Messages: []llm.Message{
334 {
335 Role: llm.MessageRoleUser,
336 Content: []llm.Content{
337 {
338 Type: llm.ContentTypeText,
339 Text: "Hello",
340 },
341 },
342 },
343 },
344 }
345
346 // Make the request
347 ctx := context.Background()
348 res, err := service.Do(ctx, ir)
349 if err != nil {
350 t.Fatalf("Failed to make request: %v", err)
351 }
352
353 // Verify that the cost was captured from headers
354 expectedCost := 0.0005 // 50000 microcents / 100,000,000
355 if res.Usage.CostUSD != expectedCost {
356 t.Fatalf("Expected cost USD %.8f, got %.8f", expectedCost, res.Usage.CostUSD)
357 }
358
359 // Verify token counts are still estimated
360 if res.Usage.InputTokens == 0 {
361 t.Fatalf("Expected input tokens to be estimated, got 0")
362 }
363 if res.Usage.OutputTokens == 0 {
364 t.Fatalf("Expected output tokens to be estimated, got 0")
365 }
366}
367
368func TestTokenContextWindow(t *testing.T) {
369 tests := []struct {
370 name string
371 model string
372 expected int
373 }{
374 {
375 name: "gemini-3-pro-preview",
376 model: "gemini-3-pro-preview",
377 expected: 1000000,
378 },
379 {
380 name: "gemini-3-flash-preview",
381 model: "gemini-3-flash-preview",
382 expected: 1000000,
383 },
384 {
385 name: "gemini-2.5-pro",
386 model: "gemini-2.5-pro",
387 expected: 1000000,
388 },
389 {
390 name: "gemini-2.5-flash",
391 model: "gemini-2.5-flash",
392 expected: 1000000,
393 },
394 {
395 name: "gemini-2.0-flash-exp",
396 model: "gemini-2.0-flash-exp",
397 expected: 1000000,
398 },
399 {
400 name: "gemini-2.0-flash",
401 model: "gemini-2.0-flash",
402 expected: 1000000,
403 },
404 {
405 name: "gemini-1.5-pro",
406 model: "gemini-1.5-pro",
407 expected: 2000000,
408 },
409 {
410 name: "gemini-1.5-pro-latest",
411 model: "gemini-1.5-pro-latest",
412 expected: 2000000,
413 },
414 {
415 name: "gemini-1.5-flash",
416 model: "gemini-1.5-flash",
417 expected: 1000000,
418 },
419 {
420 name: "gemini-1.5-flash-latest",
421 model: "gemini-1.5-flash-latest",
422 expected: 1000000,
423 },
424 {
425 name: "default model",
426 model: "",
427 expected: 1000000,
428 },
429 {
430 name: "unknown model",
431 model: "unknown-model",
432 expected: 1000000,
433 },
434 }
435
436 for _, tt := range tests {
437 t.Run(tt.name, func(t *testing.T) {
438 service := &Service{
439 Model: tt.model,
440 }
441 got := service.TokenContextWindow()
442 if got != tt.expected {
443 t.Errorf("TokenContextWindow() = %v, want %v", got, tt.expected)
444 }
445 })
446 }
447}
448
449func TestMaxImageDimension(t *testing.T) {
450 service := &Service{}
451 got := service.MaxImageDimension()
452 // Currently returns 0 as per implementation
453 expected := 0
454 if got != expected {
455 t.Errorf("MaxImageDimension() = %v, want %v", got, expected)
456 }
457}
458
459func TestEnsureToolIDs(t *testing.T) {
460 tests := []struct {
461 name string
462 contents []llm.Content
463 wantIDs bool
464 }{
465 {
466 name: "no tool uses",
467 contents: []llm.Content{
468 {
469 Type: llm.ContentTypeText,
470 Text: "Hello",
471 },
472 },
473 wantIDs: false,
474 },
475 {
476 name: "tool use with existing ID",
477 contents: []llm.Content{
478 {
479 ID: "existing-id",
480 Type: llm.ContentTypeToolUse,
481 ToolName: "test-tool",
482 },
483 },
484 wantIDs: true,
485 },
486 {
487 name: "tool use without ID",
488 contents: []llm.Content{
489 {
490 Type: llm.ContentTypeToolUse,
491 ToolName: "test-tool",
492 },
493 },
494 wantIDs: true,
495 },
496 {
497 name: "mixed content",
498 contents: []llm.Content{
499 {
500 Type: llm.ContentTypeText,
501 Text: "Hello",
502 },
503 {
504 Type: llm.ContentTypeToolUse,
505 ToolName: "test-tool",
506 },
507 {
508 ID: "existing-id",
509 Type: llm.ContentTypeToolUse,
510 ToolName: "test-tool-2",
511 },
512 },
513 wantIDs: true,
514 },
515 }
516
517 for _, tt := range tests {
518 t.Run(tt.name, func(t *testing.T) {
519 // Make a copy to avoid modifying the test data
520 contents := make([]llm.Content, len(tt.contents))
521 copy(contents, tt.contents)
522
523 ensureToolIDs(contents)
524
525 // Check if tool uses have IDs
526 hasGeneratedIDs := false
527 for _, content := range contents {
528 if content.Type == llm.ContentTypeToolUse {
529 if content.ID == "" {
530 t.Errorf("Tool use missing ID")
531 } else if content.ID != "existing-id" {
532 // This is a generated ID
533 hasGeneratedIDs = true
534 }
535 }
536 }
537
538 // If we expected IDs to be generated, check that at least one was
539 if tt.wantIDs && !hasGeneratedIDs {
540 // Check if all tool uses had existing IDs
541 hasExistingIDs := false
542 for _, content := range tt.contents {
543 if content.Type == llm.ContentTypeToolUse && content.ID != "" {
544 hasExistingIDs = true
545 }
546 }
547 if !hasExistingIDs {
548 t.Errorf("Expected generated IDs but none were found")
549 }
550 }
551 })
552 }
553}
554
555func TestCalculateUsage(t *testing.T) {
556 // Test with a simple request and response
557 req := &gemini.Request{
558 SystemInstruction: &gemini.Content{
559 Parts: []gemini.Part{
560 {Text: "You are a helpful assistant."},
561 },
562 },
563 Contents: []gemini.Content{
564 {
565 Parts: []gemini.Part{
566 {Text: "Hello, how are you?"},
567 },
568 Role: "user",
569 },
570 },
571 }
572
573 res := &gemini.Response{
574 Candidates: []gemini.Candidate{
575 {
576 Content: gemini.Content{
577 Parts: []gemini.Part{
578 {Text: "I'm doing well, thank you for asking!"},
579 },
580 },
581 },
582 },
583 }
584
585 usage := calculateUsage(req, res)
586
587 // Verify that we got some token counts (they'll be estimated)
588 if usage.InputTokens == 0 {
589 t.Errorf("Expected input tokens to be greater than 0, got %d", usage.InputTokens)
590 }
591 if usage.OutputTokens == 0 {
592 t.Errorf("Expected output tokens to be greater than 0, got %d", usage.OutputTokens)
593 }
594
595 // Test with nil response
596 usageNil := calculateUsage(req, nil)
597 if usageNil.InputTokens == 0 {
598 t.Errorf("Expected input tokens with nil response to be greater than 0, got %d", usageNil.InputTokens)
599 }
600 if usageNil.OutputTokens != 0 {
601 t.Errorf("Expected output tokens with nil response to be 0, got %d", usageNil.OutputTokens)
602 }
603
604 // Test with function calls
605 reqWithFunction := &gemini.Request{
606 Contents: []gemini.Content{
607 {
608 Parts: []gemini.Part{
609 {
610 FunctionCall: &gemini.FunctionCall{
611 Name: "test_function",
612 Args: map[string]any{
613 "param1": "value1",
614 },
615 },
616 },
617 },
618 Role: "user",
619 },
620 },
621 }
622
623 resWithFunction := &gemini.Response{
624 Candidates: []gemini.Candidate{
625 {
626 Content: gemini.Content{
627 Parts: []gemini.Part{
628 {
629 FunctionCall: &gemini.FunctionCall{
630 Name: "response_function",
631 Args: map[string]any{
632 "result": "success",
633 },
634 },
635 },
636 },
637 },
638 },
639 },
640 }
641
642 usageWithFunction := calculateUsage(reqWithFunction, resWithFunction)
643 if usageWithFunction.InputTokens == 0 {
644 t.Errorf("Expected input tokens with function calls to be greater than 0, got %d", usageWithFunction.InputTokens)
645 }
646 if usageWithFunction.OutputTokens == 0 {
647 t.Errorf("Expected output tokens with function calls to be greater than 0, got %d", usageWithFunction.OutputTokens)
648 }
649}
650
651func TestCalculateUsageWithFunctionResponse(t *testing.T) {
652 // Test with function response in input (tool result)
653 reqWithFunctionResponse := &gemini.Request{
654 Contents: []gemini.Content{
655 {
656 Parts: []gemini.Part{
657 {
658 FunctionResponse: &gemini.FunctionResponse{
659 Name: "test_function",
660 Response: map[string]any{
661 "result": "success",
662 "error": nil,
663 },
664 },
665 },
666 },
667 Role: "user",
668 },
669 },
670 }
671
672 res := &gemini.Response{
673 Candidates: []gemini.Candidate{
674 {
675 Content: gemini.Content{
676 Parts: []gemini.Part{
677 {Text: "Hello"},
678 },
679 },
680 },
681 },
682 }
683
684 usage := calculateUsage(reqWithFunctionResponse, res)
685 // Should have some input tokens from the function response
686 if usage.InputTokens == 0 {
687 t.Errorf("Expected input tokens with function response to be greater than 0, got %d", usage.InputTokens)
688 }
689 if usage.OutputTokens == 0 {
690 t.Errorf("Expected output tokens to be greater than 0, got %d", usage.OutputTokens)
691 }
692}
693
694func TestCalculateUsageWithEmptyText(t *testing.T) {
695 // Test with empty text parts
696 req := &gemini.Request{
697 Contents: []gemini.Content{
698 {
699 Parts: []gemini.Part{
700 {Text: ""}, // Empty text
701 },
702 Role: "user",
703 },
704 },
705 }
706
707 res := &gemini.Response{
708 Candidates: []gemini.Candidate{
709 {
710 Content: gemini.Content{
711 Parts: []gemini.Part{
712 {Text: ""}, // Empty text
713 },
714 },
715 },
716 },
717 }
718
719 usage := calculateUsage(req, res)
720 // Should have 0 tokens for empty text
721 if usage.InputTokens != 0 {
722 t.Errorf("Expected input tokens to be 0 for empty text, got %d", usage.InputTokens)
723 }
724 if usage.OutputTokens != 0 {
725 t.Errorf("Expected output tokens to be 0 for empty text, got %d", usage.OutputTokens)
726 }
727}
728
729func TestCalculateUsageWithComplexFunctionCall(t *testing.T) {
730 // Test with complex function call arguments
731 req := &gemini.Request{
732 Contents: []gemini.Content{
733 {
734 Parts: []gemini.Part{
735 {
736 FunctionCall: &gemini.FunctionCall{
737 Name: "complex_function",
738 Args: map[string]any{
739 "string_param": "value",
740 "int_param": 42,
741 "array_param": []any{"item1", "item2"},
742 "object_param": map[string]any{
743 "nested": "value",
744 },
745 },
746 },
747 },
748 },
749 Role: "user",
750 },
751 },
752 }
753
754 res := &gemini.Response{
755 Candidates: []gemini.Candidate{
756 {
757 Content: gemini.Content{
758 Parts: []gemini.Part{
759 {
760 FunctionCall: &gemini.FunctionCall{
761 Name: "response_function",
762 Args: map[string]any{
763 "complex_result": map[string]any{
764 "status": "success",
765 "data": []any{1, 2, 3},
766 },
767 },
768 },
769 },
770 },
771 },
772 },
773 },
774 }
775
776 usage := calculateUsage(req, res)
777 if usage.InputTokens == 0 {
778 t.Errorf("Expected input tokens with complex function call to be greater than 0, got %d", usage.InputTokens)
779 }
780 if usage.OutputTokens == 0 {
781 t.Errorf("Expected output tokens with complex function call to be greater than 0, got %d", usage.OutputTokens)
782 }
783}