1package llm
2
3import (
4 "context"
5 "encoding/json"
6 "fmt"
7 "net/http"
8 "testing"
9)
10
11// mockService implements Service interface for testing
12type mockService struct {
13 tokenContextWindow int
14 maxImageDimension int
15 useSimplifiedPatch bool
16 implementsSimplified bool
17}
18
19func (m *mockService) Do(ctx context.Context, req *Request) (*Response, error) {
20 return &Response{}, nil
21}
22
23func (m *mockService) TokenContextWindow() int {
24 return m.tokenContextWindow
25}
26
27func (m *mockService) MaxImageDimension() int {
28 return m.maxImageDimension
29}
30
31// mockSimplifiedService implements both Service and SimplifiedPatcher interfaces
32type mockSimplifiedService struct {
33 mockService
34}
35
36func (m *mockSimplifiedService) UseSimplifiedPatch() bool {
37 return m.useSimplifiedPatch
38}
39
40func TestMustSchema(t *testing.T) {
41 tests := []struct {
42 name string
43 schema string
44 expectPanic bool
45 }{
46 {
47 name: "valid schema",
48 schema: `{"type": "object", "properties": {}}`,
49 expectPanic: false,
50 },
51 {
52 name: "valid schema with properties",
53 schema: `{"type": "object", "properties": {"name": {"type": "string"}}}`,
54 expectPanic: false,
55 },
56 {
57 name: "invalid json",
58 schema: `{"type": "object", "properties": }`,
59 expectPanic: true,
60 },
61 {
62 name: "missing type",
63 schema: `{"properties": {}}`,
64 expectPanic: true,
65 },
66 {
67 name: "wrong type",
68 schema: `{"type": "string", "properties": {}}`,
69 expectPanic: true,
70 },
71 {
72 name: "missing properties",
73 schema: `{"type": "object"}`,
74 expectPanic: true,
75 },
76 }
77
78 for _, tt := range tests {
79 t.Run(tt.name, func(t *testing.T) {
80 if tt.expectPanic {
81 defer func() {
82 if r := recover(); r == nil {
83 t.Errorf("Expected panic for schema: %s", tt.schema)
84 }
85 }()
86 }
87 result := MustSchema(tt.schema)
88 if !tt.expectPanic {
89 if string(result) != tt.schema {
90 t.Errorf("MustSchema() = %s, want %s", string(result), tt.schema)
91 }
92 }
93 })
94 }
95}
96
97func TestEmptySchema(t *testing.T) {
98 schema := EmptySchema()
99 expected := `{"type": "object", "properties": {}}`
100 if string(schema) != expected {
101 t.Errorf("EmptySchema() = %s, want %s", string(schema), expected)
102 }
103}
104
105func TestUseSimplifiedPatch(t *testing.T) {
106 tests := []struct {
107 name string
108 service Service
109 expected bool
110 }{
111 {
112 name: "service without SimplifiedPatcher",
113 service: &mockService{
114 implementsSimplified: false,
115 useSimplifiedPatch: false,
116 },
117 expected: false,
118 },
119 {
120 name: "service with SimplifiedPatcher returning false",
121 service: &mockSimplifiedService{
122 mockService: mockService{
123 implementsSimplified: true,
124 useSimplifiedPatch: false,
125 },
126 },
127 expected: false,
128 },
129 {
130 name: "service with SimplifiedPatcher returning true",
131 service: &mockSimplifiedService{
132 mockService: mockService{
133 implementsSimplified: true,
134 useSimplifiedPatch: true,
135 },
136 },
137 expected: true,
138 },
139 }
140
141 for _, tt := range tests {
142 t.Run(tt.name, func(t *testing.T) {
143 result := UseSimplifiedPatch(tt.service)
144 if result != tt.expected {
145 t.Errorf("UseSimplifiedPatch() = %v, want %v", result, tt.expected)
146 }
147 })
148 }
149}
150
151func TestStringContent(t *testing.T) {
152 text := "test content"
153 content := StringContent(text)
154
155 if content.Type != ContentTypeText {
156 t.Errorf("StringContent().Type = %v, want %v", content.Type, ContentTypeText)
157 }
158
159 if content.Text != text {
160 t.Errorf("StringContent().Text = %s, want %s", content.Text, text)
161 }
162}
163
164func TestTextContent(t *testing.T) {
165 text := "test text content"
166 contents := TextContent(text)
167
168 if len(contents) != 1 {
169 t.Errorf("TextContent() returned %d items, want 1", len(contents))
170 }
171
172 if contents[0].Type != ContentTypeText {
173 t.Errorf("TextContent()[0].Type = %v, want %v", contents[0].Type, ContentTypeText)
174 }
175
176 if contents[0].Text != text {
177 t.Errorf("TextContent()[0].Text = %s, want %s", contents[0].Text, text)
178 }
179}
180
181func TestUserStringMessage(t *testing.T) {
182 text := "user message"
183 message := UserStringMessage(text)
184
185 if message.Role != MessageRoleUser {
186 t.Errorf("UserStringMessage().Role = %v, want %v", message.Role, MessageRoleUser)
187 }
188
189 if len(message.Content) != 1 {
190 t.Errorf("UserStringMessage().Content length = %d, want 1", len(message.Content))
191 }
192
193 if message.Content[0].Type != ContentTypeText {
194 t.Errorf("UserStringMessage().Content[0].Type = %v, want %v", message.Content[0].Type, ContentTypeText)
195 }
196
197 if message.Content[0].Text != text {
198 t.Errorf("UserStringMessage().Content[0].Text = %s, want %s", message.Content[0].Text, text)
199 }
200}
201
202func TestErrorToolOut(t *testing.T) {
203 err := fmt.Errorf("test error")
204 toolOut := ErrorToolOut(err)
205
206 if toolOut.Error != err {
207 t.Errorf("ErrorToolOut().Error = %v, want %v", toolOut.Error, err)
208 }
209
210 // Test panic with nil error
211 defer func() {
212 if r := recover(); r == nil {
213 t.Errorf("Expected panic when calling ErrorToolOut with nil error")
214 }
215 }()
216 ErrorToolOut(nil)
217}
218
219func TestErrorfToolOut(t *testing.T) {
220 format := "error: %s"
221 arg := "test"
222 toolOut := ErrorfToolOut(format, arg)
223
224 if toolOut.Error == nil {
225 t.Errorf("ErrorfToolOut().Error = nil, want error")
226 }
227
228 expected := fmt.Sprintf(format, arg)
229 if toolOut.Error.Error() != expected {
230 t.Errorf("ErrorfToolOut().Error = %v, want %v", toolOut.Error.Error(), expected)
231 }
232}
233
234func TestUsageAdd(t *testing.T) {
235 u1 := Usage{
236 InputTokens: 100,
237 CacheCreationInputTokens: 50,
238 CacheReadInputTokens: 25,
239 OutputTokens: 200,
240 CostUSD: 0.01,
241 }
242
243 u2 := Usage{
244 InputTokens: 150,
245 CacheCreationInputTokens: 75,
246 CacheReadInputTokens: 30,
247 OutputTokens: 100,
248 CostUSD: 0.02,
249 }
250
251 u1.Add(u2)
252
253 expected := Usage{
254 InputTokens: 250, // 100 + 150
255 CacheCreationInputTokens: 125, // 50 + 75
256 CacheReadInputTokens: 55, // 25 + 30
257 OutputTokens: 300, // 200 + 100
258 CostUSD: 0.03, // 0.01 + 0.02
259 }
260
261 if u1 != expected {
262 t.Errorf("Usage.Add() resulted in %v, want %v", u1, expected)
263 }
264}
265
266func TestUsageString(t *testing.T) {
267 tests := []struct {
268 name string
269 usage Usage
270 want string
271 }{
272 {
273 name: "normal usage",
274 usage: Usage{
275 InputTokens: 100,
276 OutputTokens: 50,
277 },
278 want: "in: 100, out: 50",
279 },
280 {
281 name: "zero usage",
282 usage: Usage{
283 InputTokens: 0,
284 OutputTokens: 0,
285 },
286 want: "in: 0, out: 0",
287 },
288 {
289 name: "high usage",
290 usage: Usage{
291 InputTokens: 1000000,
292 OutputTokens: 500000,
293 },
294 want: "in: 1000000, out: 500000",
295 },
296 }
297
298 for _, tt := range tests {
299 t.Run(tt.name, func(t *testing.T) {
300 result := tt.usage.String()
301 if result != tt.want {
302 t.Errorf("Usage.String() = %s, want %s", result, tt.want)
303 }
304 })
305 }
306}
307
308func TestUsageIsZero(t *testing.T) {
309 tests := []struct {
310 name string
311 usage Usage
312 want bool
313 }{
314 {
315 name: "zero usage",
316 usage: Usage{},
317 want: true,
318 },
319 {
320 name: "non-zero input tokens",
321 usage: Usage{
322 InputTokens: 1,
323 },
324 want: false,
325 },
326 {
327 name: "non-zero output tokens",
328 usage: Usage{
329 OutputTokens: 1,
330 },
331 want: false,
332 },
333 {
334 name: "non-zero cost",
335 usage: Usage{
336 CostUSD: 0.01,
337 },
338 want: false,
339 },
340 {
341 name: "all fields zero",
342 usage: Usage{
343 InputTokens: 0,
344 CacheCreationInputTokens: 0,
345 CacheReadInputTokens: 0,
346 OutputTokens: 0,
347 CostUSD: 0,
348 },
349 want: true,
350 },
351 }
352
353 for _, tt := range tests {
354 t.Run(tt.name, func(t *testing.T) {
355 result := tt.usage.IsZero()
356 if result != tt.want {
357 t.Errorf("Usage.IsZero() = %v, want %v", result, tt.want)
358 }
359 })
360 }
361}
362
363func TestResponseToMessage(t *testing.T) {
364 tests := []struct {
365 name string
366 response Response
367 wantRole MessageRole
368 wantEndOfTurn bool
369 }{
370 {
371 name: "tool use stop reason",
372 response: Response{
373 Role: MessageRoleAssistant,
374 StopReason: StopReasonToolUse,
375 },
376 wantRole: MessageRoleAssistant,
377 wantEndOfTurn: false,
378 },
379 {
380 name: "end turn stop reason",
381 response: Response{
382 Role: MessageRoleAssistant,
383 StopReason: StopReasonEndTurn,
384 },
385 wantRole: MessageRoleAssistant,
386 wantEndOfTurn: true,
387 },
388 {
389 name: "max tokens stop reason",
390 response: Response{
391 Role: MessageRoleAssistant,
392 StopReason: StopReasonMaxTokens,
393 },
394 wantRole: MessageRoleAssistant,
395 wantEndOfTurn: true,
396 },
397 }
398
399 for _, tt := range tests {
400 t.Run(tt.name, func(t *testing.T) {
401 message := tt.response.ToMessage()
402
403 if message.Role != tt.wantRole {
404 t.Errorf("ToMessage().Role = %v, want %v", message.Role, tt.wantRole)
405 }
406
407 if message.EndOfTurn != tt.wantEndOfTurn {
408 t.Errorf("ToMessage().EndOfTurn = %v, want %v", message.EndOfTurn, tt.wantEndOfTurn)
409 }
410 })
411 }
412}
413
414func TestContentsAttr(t *testing.T) {
415 tests := []struct {
416 name string
417 contents []Content
418 }{
419 {
420 name: "text content",
421 contents: []Content{
422 {
423 ID: "1",
424 Type: ContentTypeText,
425 Text: "hello world",
426 },
427 },
428 },
429 {
430 name: "tool use content",
431 contents: []Content{
432 {
433 ID: "2",
434 Type: ContentTypeToolUse,
435 ToolName: "test_tool",
436 ToolInput: json.RawMessage(`{"param": "value"}`),
437 },
438 },
439 },
440 {
441 name: "tool result content",
442 contents: []Content{
443 {
444 ID: "3",
445 Type: ContentTypeToolResult,
446 ToolResult: []Content{{Type: ContentTypeText, Text: "result"}},
447 ToolError: false,
448 },
449 },
450 },
451 {
452 name: "thinking content",
453 contents: []Content{
454 {
455 ID: "4",
456 Type: ContentTypeThinking,
457 Text: "thinking...",
458 },
459 },
460 },
461 {
462 name: "empty contents",
463 contents: []Content{},
464 },
465 }
466
467 for _, tt := range tests {
468 t.Run(tt.name, func(t *testing.T) {
469 attr := ContentsAttr(tt.contents)
470 if attr.Key != "contents" {
471 t.Errorf("ContentsAttr().Key = %s, want 'contents'", attr.Key)
472 }
473 })
474 }
475}
476
477func TestCostUSDFromResponse(t *testing.T) {
478 tests := []struct {
479 name string
480 headers map[string]string
481 wantCost float64
482 }{
483 {
484 name: "valid cost header",
485 headers: map[string]string{
486 "Skaband-Cost-Microcents": "10000000", // 0.1 USD
487 },
488 wantCost: 0.1,
489 },
490 {
491 name: "invalid cost header",
492 headers: map[string]string{
493 "Skaband-Cost-Microcents": "invalid",
494 },
495 wantCost: 0,
496 },
497 {
498 name: "missing cost header",
499 headers: map[string]string{},
500 wantCost: 0,
501 },
502 {
503 name: "empty cost header",
504 headers: map[string]string{
505 "Skaband-Cost-Microcents": "",
506 },
507 wantCost: 0,
508 },
509 }
510
511 for _, tt := range tests {
512 t.Run(tt.name, func(t *testing.T) {
513 headers := make(http.Header)
514 for k, v := range tt.headers {
515 headers.Set(k, v)
516 }
517
518 cost := CostUSDFromResponse(headers)
519 if cost != tt.wantCost {
520 t.Errorf("CostUSDFromResponse() = %f, want %f", cost, tt.wantCost)
521 }
522 })
523 }
524}
525
526func TestUsageAttr(t *testing.T) {
527 usage := Usage{
528 InputTokens: 100,
529 OutputTokens: 50,
530 CacheCreationInputTokens: 25,
531 CacheReadInputTokens: 75,
532 CostUSD: 0.01,
533 }
534
535 attr := usage.Attr()
536 if attr.Key != "usage" {
537 t.Errorf("Attr().Key = %s, want 'usage'", attr.Key)
538 }
539}
540
541func TestDumpToFile(t *testing.T) {
542 // This test just verifies the function exists and can be called
543 // We don't actually want to write files during testing
544 // So we'll just ensure it doesn't panic with valid inputs
545 content := []byte("test content")
546
547 // This might fail due to permissions, but it shouldn't panic
548 _ = DumpToFile("test", "http://example.com", content)
549}