llm_test.go

  1package llm
  2
  3import (
  4	"context"
  5	"encoding/json"
  6	"fmt"
  7	"net/http"
  8	"testing"
  9)
 10
 11// mockService implements Service interface for testing
 12type mockService struct {
 13	tokenContextWindow   int
 14	maxImageDimension    int
 15	useSimplifiedPatch   bool
 16	implementsSimplified bool
 17}
 18
 19func (m *mockService) Do(ctx context.Context, req *Request) (*Response, error) {
 20	return &Response{}, nil
 21}
 22
 23func (m *mockService) TokenContextWindow() int {
 24	return m.tokenContextWindow
 25}
 26
 27func (m *mockService) MaxImageDimension() int {
 28	return m.maxImageDimension
 29}
 30
 31// mockSimplifiedService implements both Service and SimplifiedPatcher interfaces
 32type mockSimplifiedService struct {
 33	mockService
 34}
 35
 36func (m *mockSimplifiedService) UseSimplifiedPatch() bool {
 37	return m.useSimplifiedPatch
 38}
 39
 40func TestMustSchema(t *testing.T) {
 41	tests := []struct {
 42		name        string
 43		schema      string
 44		expectPanic bool
 45	}{
 46		{
 47			name:        "valid schema",
 48			schema:      `{"type": "object", "properties": {}}`,
 49			expectPanic: false,
 50		},
 51		{
 52			name:        "valid schema with properties",
 53			schema:      `{"type": "object", "properties": {"name": {"type": "string"}}}`,
 54			expectPanic: false,
 55		},
 56		{
 57			name:        "invalid json",
 58			schema:      `{"type": "object", "properties": }`,
 59			expectPanic: true,
 60		},
 61		{
 62			name:        "missing type",
 63			schema:      `{"properties": {}}`,
 64			expectPanic: true,
 65		},
 66		{
 67			name:        "wrong type",
 68			schema:      `{"type": "string", "properties": {}}`,
 69			expectPanic: true,
 70		},
 71		{
 72			name:        "missing properties",
 73			schema:      `{"type": "object"}`,
 74			expectPanic: true,
 75		},
 76	}
 77
 78	for _, tt := range tests {
 79		t.Run(tt.name, func(t *testing.T) {
 80			if tt.expectPanic {
 81				defer func() {
 82					if r := recover(); r == nil {
 83						t.Errorf("Expected panic for schema: %s", tt.schema)
 84					}
 85				}()
 86			}
 87			result := MustSchema(tt.schema)
 88			if !tt.expectPanic {
 89				if string(result) != tt.schema {
 90					t.Errorf("MustSchema() = %s, want %s", string(result), tt.schema)
 91				}
 92			}
 93		})
 94	}
 95}
 96
 97func TestEmptySchema(t *testing.T) {
 98	schema := EmptySchema()
 99	expected := `{"type": "object", "properties": {}}`
100	if string(schema) != expected {
101		t.Errorf("EmptySchema() = %s, want %s", string(schema), expected)
102	}
103}
104
105func TestUseSimplifiedPatch(t *testing.T) {
106	tests := []struct {
107		name     string
108		service  Service
109		expected bool
110	}{
111		{
112			name: "service without SimplifiedPatcher",
113			service: &mockService{
114				implementsSimplified: false,
115				useSimplifiedPatch:   false,
116			},
117			expected: false,
118		},
119		{
120			name: "service with SimplifiedPatcher returning false",
121			service: &mockSimplifiedService{
122				mockService: mockService{
123					implementsSimplified: true,
124					useSimplifiedPatch:   false,
125				},
126			},
127			expected: false,
128		},
129		{
130			name: "service with SimplifiedPatcher returning true",
131			service: &mockSimplifiedService{
132				mockService: mockService{
133					implementsSimplified: true,
134					useSimplifiedPatch:   true,
135				},
136			},
137			expected: true,
138		},
139	}
140
141	for _, tt := range tests {
142		t.Run(tt.name, func(t *testing.T) {
143			result := UseSimplifiedPatch(tt.service)
144			if result != tt.expected {
145				t.Errorf("UseSimplifiedPatch() = %v, want %v", result, tt.expected)
146			}
147		})
148	}
149}
150
151func TestStringContent(t *testing.T) {
152	text := "test content"
153	content := StringContent(text)
154
155	if content.Type != ContentTypeText {
156		t.Errorf("StringContent().Type = %v, want %v", content.Type, ContentTypeText)
157	}
158
159	if content.Text != text {
160		t.Errorf("StringContent().Text = %s, want %s", content.Text, text)
161	}
162}
163
164func TestTextContent(t *testing.T) {
165	text := "test text content"
166	contents := TextContent(text)
167
168	if len(contents) != 1 {
169		t.Errorf("TextContent() returned %d items, want 1", len(contents))
170	}
171
172	if contents[0].Type != ContentTypeText {
173		t.Errorf("TextContent()[0].Type = %v, want %v", contents[0].Type, ContentTypeText)
174	}
175
176	if contents[0].Text != text {
177		t.Errorf("TextContent()[0].Text = %s, want %s", contents[0].Text, text)
178	}
179}
180
181func TestUserStringMessage(t *testing.T) {
182	text := "user message"
183	message := UserStringMessage(text)
184
185	if message.Role != MessageRoleUser {
186		t.Errorf("UserStringMessage().Role = %v, want %v", message.Role, MessageRoleUser)
187	}
188
189	if len(message.Content) != 1 {
190		t.Errorf("UserStringMessage().Content length = %d, want 1", len(message.Content))
191	}
192
193	if message.Content[0].Type != ContentTypeText {
194		t.Errorf("UserStringMessage().Content[0].Type = %v, want %v", message.Content[0].Type, ContentTypeText)
195	}
196
197	if message.Content[0].Text != text {
198		t.Errorf("UserStringMessage().Content[0].Text = %s, want %s", message.Content[0].Text, text)
199	}
200}
201
202func TestErrorToolOut(t *testing.T) {
203	err := fmt.Errorf("test error")
204	toolOut := ErrorToolOut(err)
205
206	if toolOut.Error != err {
207		t.Errorf("ErrorToolOut().Error = %v, want %v", toolOut.Error, err)
208	}
209
210	// Test panic with nil error
211	defer func() {
212		if r := recover(); r == nil {
213			t.Errorf("Expected panic when calling ErrorToolOut with nil error")
214		}
215	}()
216	ErrorToolOut(nil)
217}
218
219func TestErrorfToolOut(t *testing.T) {
220	format := "error: %s"
221	arg := "test"
222	toolOut := ErrorfToolOut(format, arg)
223
224	if toolOut.Error == nil {
225		t.Errorf("ErrorfToolOut().Error = nil, want error")
226	}
227
228	expected := fmt.Sprintf(format, arg)
229	if toolOut.Error.Error() != expected {
230		t.Errorf("ErrorfToolOut().Error = %v, want %v", toolOut.Error.Error(), expected)
231	}
232}
233
234func TestUsageAdd(t *testing.T) {
235	u1 := Usage{
236		InputTokens:              100,
237		CacheCreationInputTokens: 50,
238		CacheReadInputTokens:     25,
239		OutputTokens:             200,
240		CostUSD:                  0.01,
241	}
242
243	u2 := Usage{
244		InputTokens:              150,
245		CacheCreationInputTokens: 75,
246		CacheReadInputTokens:     30,
247		OutputTokens:             100,
248		CostUSD:                  0.02,
249	}
250
251	u1.Add(u2)
252
253	expected := Usage{
254		InputTokens:              250,  // 100 + 150
255		CacheCreationInputTokens: 125,  // 50 + 75
256		CacheReadInputTokens:     55,   // 25 + 30
257		OutputTokens:             300,  // 200 + 100
258		CostUSD:                  0.03, // 0.01 + 0.02
259	}
260
261	if u1 != expected {
262		t.Errorf("Usage.Add() resulted in %v, want %v", u1, expected)
263	}
264}
265
266func TestUsageString(t *testing.T) {
267	tests := []struct {
268		name  string
269		usage Usage
270		want  string
271	}{
272		{
273			name: "normal usage",
274			usage: Usage{
275				InputTokens:  100,
276				OutputTokens: 50,
277			},
278			want: "in: 100, out: 50",
279		},
280		{
281			name: "zero usage",
282			usage: Usage{
283				InputTokens:  0,
284				OutputTokens: 0,
285			},
286			want: "in: 0, out: 0",
287		},
288		{
289			name: "high usage",
290			usage: Usage{
291				InputTokens:  1000000,
292				OutputTokens: 500000,
293			},
294			want: "in: 1000000, out: 500000",
295		},
296	}
297
298	for _, tt := range tests {
299		t.Run(tt.name, func(t *testing.T) {
300			result := tt.usage.String()
301			if result != tt.want {
302				t.Errorf("Usage.String() = %s, want %s", result, tt.want)
303			}
304		})
305	}
306}
307
308func TestUsageIsZero(t *testing.T) {
309	tests := []struct {
310		name  string
311		usage Usage
312		want  bool
313	}{
314		{
315			name:  "zero usage",
316			usage: Usage{},
317			want:  true,
318		},
319		{
320			name: "non-zero input tokens",
321			usage: Usage{
322				InputTokens: 1,
323			},
324			want: false,
325		},
326		{
327			name: "non-zero output tokens",
328			usage: Usage{
329				OutputTokens: 1,
330			},
331			want: false,
332		},
333		{
334			name: "non-zero cost",
335			usage: Usage{
336				CostUSD: 0.01,
337			},
338			want: false,
339		},
340		{
341			name: "all fields zero",
342			usage: Usage{
343				InputTokens:              0,
344				CacheCreationInputTokens: 0,
345				CacheReadInputTokens:     0,
346				OutputTokens:             0,
347				CostUSD:                  0,
348			},
349			want: true,
350		},
351	}
352
353	for _, tt := range tests {
354		t.Run(tt.name, func(t *testing.T) {
355			result := tt.usage.IsZero()
356			if result != tt.want {
357				t.Errorf("Usage.IsZero() = %v, want %v", result, tt.want)
358			}
359		})
360	}
361}
362
363func TestResponseToMessage(t *testing.T) {
364	tests := []struct {
365		name          string
366		response      Response
367		wantRole      MessageRole
368		wantEndOfTurn bool
369	}{
370		{
371			name: "tool use stop reason",
372			response: Response{
373				Role:       MessageRoleAssistant,
374				StopReason: StopReasonToolUse,
375			},
376			wantRole:      MessageRoleAssistant,
377			wantEndOfTurn: false,
378		},
379		{
380			name: "end turn stop reason",
381			response: Response{
382				Role:       MessageRoleAssistant,
383				StopReason: StopReasonEndTurn,
384			},
385			wantRole:      MessageRoleAssistant,
386			wantEndOfTurn: true,
387		},
388		{
389			name: "max tokens stop reason",
390			response: Response{
391				Role:       MessageRoleAssistant,
392				StopReason: StopReasonMaxTokens,
393			},
394			wantRole:      MessageRoleAssistant,
395			wantEndOfTurn: true,
396		},
397	}
398
399	for _, tt := range tests {
400		t.Run(tt.name, func(t *testing.T) {
401			message := tt.response.ToMessage()
402
403			if message.Role != tt.wantRole {
404				t.Errorf("ToMessage().Role = %v, want %v", message.Role, tt.wantRole)
405			}
406
407			if message.EndOfTurn != tt.wantEndOfTurn {
408				t.Errorf("ToMessage().EndOfTurn = %v, want %v", message.EndOfTurn, tt.wantEndOfTurn)
409			}
410		})
411	}
412}
413
414func TestContentsAttr(t *testing.T) {
415	tests := []struct {
416		name     string
417		contents []Content
418	}{
419		{
420			name: "text content",
421			contents: []Content{
422				{
423					ID:   "1",
424					Type: ContentTypeText,
425					Text: "hello world",
426				},
427			},
428		},
429		{
430			name: "tool use content",
431			contents: []Content{
432				{
433					ID:        "2",
434					Type:      ContentTypeToolUse,
435					ToolName:  "test_tool",
436					ToolInput: json.RawMessage(`{"param": "value"}`),
437				},
438			},
439		},
440		{
441			name: "tool result content",
442			contents: []Content{
443				{
444					ID:         "3",
445					Type:       ContentTypeToolResult,
446					ToolResult: []Content{{Type: ContentTypeText, Text: "result"}},
447					ToolError:  false,
448				},
449			},
450		},
451		{
452			name: "thinking content",
453			contents: []Content{
454				{
455					ID:   "4",
456					Type: ContentTypeThinking,
457					Text: "thinking...",
458				},
459			},
460		},
461		{
462			name:     "empty contents",
463			contents: []Content{},
464		},
465	}
466
467	for _, tt := range tests {
468		t.Run(tt.name, func(t *testing.T) {
469			attr := ContentsAttr(tt.contents)
470			if attr.Key != "contents" {
471				t.Errorf("ContentsAttr().Key = %s, want 'contents'", attr.Key)
472			}
473		})
474	}
475}
476
477func TestCostUSDFromResponse(t *testing.T) {
478	tests := []struct {
479		name     string
480		headers  map[string]string
481		wantCost float64
482	}{
483		{
484			name: "valid cost header",
485			headers: map[string]string{
486				"Skaband-Cost-Microcents": "10000000", // 0.1 USD
487			},
488			wantCost: 0.1,
489		},
490		{
491			name: "invalid cost header",
492			headers: map[string]string{
493				"Skaband-Cost-Microcents": "invalid",
494			},
495			wantCost: 0,
496		},
497		{
498			name:     "missing cost header",
499			headers:  map[string]string{},
500			wantCost: 0,
501		},
502		{
503			name: "empty cost header",
504			headers: map[string]string{
505				"Skaband-Cost-Microcents": "",
506			},
507			wantCost: 0,
508		},
509	}
510
511	for _, tt := range tests {
512		t.Run(tt.name, func(t *testing.T) {
513			headers := make(http.Header)
514			for k, v := range tt.headers {
515				headers.Set(k, v)
516			}
517
518			cost := CostUSDFromResponse(headers)
519			if cost != tt.wantCost {
520				t.Errorf("CostUSDFromResponse() = %f, want %f", cost, tt.wantCost)
521			}
522		})
523	}
524}
525
526func TestUsageAttr(t *testing.T) {
527	usage := Usage{
528		InputTokens:              100,
529		OutputTokens:             50,
530		CacheCreationInputTokens: 25,
531		CacheReadInputTokens:     75,
532		CostUSD:                  0.01,
533	}
534
535	attr := usage.Attr()
536	if attr.Key != "usage" {
537		t.Errorf("Attr().Key = %s, want 'usage'", attr.Key)
538	}
539}
540
541func TestDumpToFile(t *testing.T) {
542	// This test just verifies the function exists and can be called
543	// We don't actually want to write files during testing
544	// So we'll just ensure it doesn't panic with valid inputs
545	content := []byte("test content")
546
547	// This might fail due to permissions, but it shouldn't panic
548	_ = DumpToFile("test", "http://example.com", content)
549}