1package conversation
2
3import (
4 "cmp"
5 "context"
6 "encoding/json"
7 "net/http"
8 "os"
9 "slices"
10 "strings"
11 "testing"
12 "time"
13
14 "shelley.exe.dev/llm"
15 "shelley.exe.dev/llm/ant"
16 "shelley.exe.dev/loop"
17 "sketch.dev/httprr"
18)
19
20func TestBasicConvo(t *testing.T) {
21 ctx := context.Background()
22 rr, err := httprr.Open("testdata/basic_convo.httprr", http.DefaultTransport)
23 if err != nil {
24 t.Fatal(err)
25 }
26 rr.ScrubReq(func(req *http.Request) error {
27 req.Header.Del("x-api-key")
28 req.Header.Del("User-Agent")
29 req.Header.Del("Shelley-Conversation-Id")
30 return nil
31 })
32
33 apiKey := cmp.Or(os.Getenv("OUTER_SKETCH_MODEL_API_KEY"), os.Getenv("ANTHROPIC_API_KEY"))
34 srv := &ant.Service{
35 APIKey: apiKey,
36 Model: ant.Claude4Sonnet, // Use specific model to match cached responses
37 HTTPC: rr.Client(),
38 }
39 convo := New(ctx, srv, nil)
40
41 const name = "Cornelius"
42 res, err := convo.SendUserTextMessage("Hi, my name is " + name)
43 if err != nil {
44 t.Fatal(err)
45 }
46 for _, part := range res.Content {
47 t.Logf("%s", part.Text)
48 }
49 res, err = convo.SendUserTextMessage("What is my name?")
50 if err != nil {
51 t.Fatal(err)
52 }
53 got := ""
54 for _, part := range res.Content {
55 got += part.Text
56 }
57 if !strings.Contains(got, name) {
58 t.Errorf("model does not know the given name %s: %q", name, got)
59 }
60}
61
62// TestCancelToolUse tests the CancelToolUse function of the Convo struct
63func TestCancelToolUse(t *testing.T) {
64 tests := []struct {
65 name string
66 setupToolUse bool
67 toolUseID string
68 cancelErr error
69 expectError bool
70 expectCancel bool
71 }{
72 {
73 name: "Cancel existing tool use",
74 setupToolUse: true,
75 toolUseID: "tool123",
76 cancelErr: nil,
77 expectError: false,
78 expectCancel: true,
79 },
80 {
81 name: "Cancel existing tool use with error",
82 setupToolUse: true,
83 toolUseID: "tool456",
84 cancelErr: context.Canceled,
85 expectError: false,
86 expectCancel: true,
87 },
88 {
89 name: "Cancel non-existent tool use",
90 setupToolUse: false,
91 toolUseID: "tool789",
92 cancelErr: nil,
93 expectError: true,
94 expectCancel: false,
95 },
96 }
97
98 srv := &ant.Service{}
99 for _, tt := range tests {
100 t.Run(tt.name, func(t *testing.T) {
101 convo := New(context.Background(), srv, nil)
102
103 var cancelCalled bool
104 var cancelledWithErr error
105
106 if tt.setupToolUse {
107 // Setup a mock cancel function to track calls
108 mockCancel := func(err error) {
109 cancelCalled = true
110 cancelledWithErr = err
111 }
112
113 convo.toolUseCancelMu.Lock()
114 convo.toolUseCancel[tt.toolUseID] = mockCancel
115 convo.toolUseCancelMu.Unlock()
116 }
117
118 err := convo.CancelToolUse(tt.toolUseID, tt.cancelErr)
119
120 // Check if we got the expected error state
121 if (err != nil) != tt.expectError {
122 t.Errorf("CancelToolUse() error = %v, expectError %v", err, tt.expectError)
123 }
124
125 // Check if the cancel function was called as expected
126 if cancelCalled != tt.expectCancel {
127 t.Errorf("Cancel function called = %v, expectCancel %v", cancelCalled, tt.expectCancel)
128 }
129
130 // If we expected the cancel to be called, verify it was called with the right error
131 if tt.expectCancel && cancelledWithErr != tt.cancelErr {
132 t.Errorf("Cancel function called with error = %v, expected %v", cancelledWithErr, tt.cancelErr)
133 }
134
135 // Verify the toolUseID was removed from the map if it was initially added
136 if tt.setupToolUse {
137 convo.toolUseCancelMu.Lock()
138 _, exists := convo.toolUseCancel[tt.toolUseID]
139 convo.toolUseCancelMu.Unlock()
140
141 if exists {
142 t.Errorf("toolUseID %s still exists in the map after cancellation", tt.toolUseID)
143 }
144 }
145 })
146 }
147}
148
149// TestInsertMissingToolResults tests the insertMissingToolResults function
150// to ensure it doesn't create duplicate tool results when multiple tool uses are missing results.
151func TestInsertMissingToolResults(t *testing.T) {
152 tests := []struct {
153 name string
154 messages []llm.Message
155 currentMsg llm.Message
156 expectedCount int
157 expectedToolIDs []string
158 }{
159 {
160 name: "Single missing tool result",
161 messages: []llm.Message{
162 {
163 Role: llm.MessageRoleAssistant,
164 Content: []llm.Content{
165 {
166 Type: llm.ContentTypeToolUse,
167 ID: "tool1",
168 },
169 },
170 },
171 },
172 currentMsg: llm.Message{
173 Role: llm.MessageRoleUser,
174 Content: []llm.Content{},
175 },
176 expectedCount: 1,
177 expectedToolIDs: []string{"tool1"},
178 },
179 {
180 name: "Multiple missing tool results",
181 messages: []llm.Message{
182 {
183 Role: llm.MessageRoleAssistant,
184 Content: []llm.Content{
185 {
186 Type: llm.ContentTypeToolUse,
187 ID: "tool1",
188 },
189 {
190 Type: llm.ContentTypeToolUse,
191 ID: "tool2",
192 },
193 {
194 Type: llm.ContentTypeToolUse,
195 ID: "tool3",
196 },
197 },
198 },
199 },
200 currentMsg: llm.Message{
201 Role: llm.MessageRoleUser,
202 Content: []llm.Content{},
203 },
204 expectedCount: 3,
205 expectedToolIDs: []string{"tool1", "tool2", "tool3"},
206 },
207 {
208 name: "No missing tool results when results already present",
209 messages: []llm.Message{
210 {
211 Role: llm.MessageRoleAssistant,
212 Content: []llm.Content{
213 {
214 Type: llm.ContentTypeToolUse,
215 ID: "tool1",
216 },
217 },
218 },
219 },
220 currentMsg: llm.Message{
221 Role: llm.MessageRoleUser,
222 Content: []llm.Content{
223 {
224 Type: llm.ContentTypeToolResult,
225 ToolUseID: "tool1",
226 },
227 },
228 },
229 expectedCount: 1, // Only the existing one
230 expectedToolIDs: []string{"tool1"},
231 },
232 {
233 name: "No tool uses in previous message",
234 messages: []llm.Message{
235 {
236 Role: llm.MessageRoleAssistant,
237 Content: []llm.Content{
238 {
239 Type: llm.ContentTypeText,
240 Text: "Just some text",
241 },
242 },
243 },
244 },
245 currentMsg: llm.Message{
246 Role: llm.MessageRoleUser,
247 Content: []llm.Content{},
248 },
249 expectedCount: 0,
250 expectedToolIDs: []string{},
251 },
252 }
253
254 for _, tt := range tests {
255 t.Run(tt.name, func(t *testing.T) {
256 srv := &ant.Service{}
257 convo := New(context.Background(), srv, nil)
258
259 // Create request with messages
260 req := &llm.Request{
261 Messages: append(tt.messages, tt.currentMsg),
262 }
263
264 // Call insertMissingToolResults
265 msg := tt.currentMsg
266 convo.insertMissingToolResults(req, &msg)
267
268 // Count tool results in the message
269 toolResultCount := 0
270 toolIDs := []string{}
271 for _, content := range msg.Content {
272 if content.Type == llm.ContentTypeToolResult {
273 toolResultCount++
274 toolIDs = append(toolIDs, content.ToolUseID)
275 }
276 }
277
278 // Verify count
279 if toolResultCount != tt.expectedCount {
280 t.Errorf("Expected %d tool results, got %d", tt.expectedCount, toolResultCount)
281 }
282
283 // Verify no duplicates by checking unique tool IDs
284 seenIDs := make(map[string]int)
285 for _, id := range toolIDs {
286 seenIDs[id]++
287 }
288
289 // Check for duplicates
290 for id, count := range seenIDs {
291 if count > 1 {
292 t.Errorf("Duplicate tool result for ID %s: found %d times", id, count)
293 }
294 }
295
296 // Verify all expected tool IDs are present
297 for _, expectedID := range tt.expectedToolIDs {
298 if !slices.Contains(toolIDs, expectedID) {
299 t.Errorf("Expected tool ID %s not found in results", expectedID)
300 }
301 }
302 })
303 }
304}
305
306// TestSubConvo tests the SubConvo function
307func TestSubConvo(t *testing.T) {
308 ctx := context.Background()
309 srv := &ant.Service{}
310 parentConvo := New(ctx, srv, nil)
311
312 // Test that SubConvo creates a new conversation with the correct parent relationship
313 subConvo := parentConvo.SubConvo()
314
315 if subConvo == nil {
316 t.Fatal("SubConvo returned nil")
317 }
318
319 if subConvo.Parent != parentConvo {
320 t.Error("SubConvo did not set the correct parent")
321 }
322
323 if subConvo.Service != parentConvo.Service {
324 t.Error("SubConvo did not inherit the service")
325 }
326
327 if subConvo.PromptCaching != parentConvo.PromptCaching {
328 t.Error("SubConvo did not inherit PromptCaching setting")
329 }
330
331 // Check that the sub-convo has a different ID
332 if subConvo.ID == parentConvo.ID {
333 t.Error("SubConvo should have a different ID from parent")
334 }
335
336 // Check that the sub-convo shares tool uses with parent
337 if &subConvo.usage.ToolUses == &parentConvo.usage.ToolUses {
338 t.Error("SubConvo should share tool uses map with parent")
339 }
340
341 // Check that the sub-convo has its own usage instance
342 if subConvo.usage == parentConvo.usage {
343 t.Error("SubConvo should have its own usage instance (but sharing ToolUses)")
344 }
345}
346
// SubConvoWithHistory is covered by TestSubConvoWithHistoryAdditional.

// Depth is covered by TestDepthAdditional.
350
351// TestFindTool tests the findTool function
352func TestFindTool(t *testing.T) {
353 ctx := context.Background()
354 srv := &ant.Service{}
355 convo := New(ctx, srv, nil)
356
357 // Add some tools to the conversation
358 tool1 := &llm.Tool{Name: "tool1"}
359 tool2 := &llm.Tool{Name: "tool2"}
360 convo.Tools = append(convo.Tools, tool1, tool2)
361
362 // Test finding an existing tool
363 foundTool, err := convo.findTool("tool1")
364 if err != nil {
365 t.Errorf("findTool returned error for existing tool: %v", err)
366 }
367 if foundTool != tool1 {
368 t.Error("findTool did not return the correct tool")
369 }
370
371 // Test finding another existing tool
372 foundTool, err = convo.findTool("tool2")
373 if err != nil {
374 t.Errorf("findTool returned error for existing tool: %v", err)
375 }
376 if foundTool != tool2 {
377 t.Error("findTool did not return the correct tool")
378 }
379
380 // Test finding a non-existent tool
381 _, err = convo.findTool("nonexistent")
382 if err == nil {
383 t.Error("findTool should return error for non-existent tool")
384 }
385 expectedErr := `tool "nonexistent" not found`
386 if err.Error() != expectedErr {
387 t.Errorf("Expected error %q, got %q", expectedErr, err.Error())
388 }
389}
390
391// TestToolCallInfoFromContext tests the ToolCallInfoFromContext function
392func TestToolCallInfoFromContext(t *testing.T) {
393 // Test with no tool call info in context
394 ctx := context.Background()
395 info := ToolCallInfoFromContext(ctx)
396 if info.ToolUseID != "" {
397 t.Error("ToolCallInfoFromContext should return empty info when no tool call info is in context")
398 }
399
400 // Test with tool call info in context
401 toolInfo := ToolCallInfo{
402 ToolUseID: "testID",
403 }
404 ctxWithInfo := context.WithValue(ctx, toolCallInfoKey, toolInfo)
405 info = ToolCallInfoFromContext(ctxWithInfo)
406 if info.ToolUseID != "testID" {
407 t.Errorf("Expected ToolUseID 'testID', got %q", info.ToolUseID)
408 }
409}
410
411// TestCumulativeUsageMethods tests CumulativeUsage methods
412func TestCumulativeUsageMethods(t *testing.T) {
413 // Test Clone method
414 original := &CumulativeUsage{
415 StartTime: time.Now(),
416 Responses: 5,
417 InputTokens: 100,
418 OutputTokens: 200,
419 CacheReadInputTokens: 50,
420 CacheCreationInputTokens: 30,
421 TotalCostUSD: 1.23,
422 ToolUses: map[string]int{
423 "tool1": 3,
424 "tool2": 2,
425 },
426 }
427
428 clone := original.Clone()
429
430 // Check that values are copied correctly
431 if clone.StartTime != original.StartTime {
432 t.Error("Clone did not copy StartTime correctly")
433 }
434 if clone.Responses != original.Responses {
435 t.Error("Clone did not copy Responses correctly")
436 }
437 if clone.InputTokens != original.InputTokens {
438 t.Error("Clone did not copy InputTokens correctly")
439 }
440 if clone.OutputTokens != original.OutputTokens {
441 t.Error("Clone did not copy OutputTokens correctly")
442 }
443 if clone.CacheReadInputTokens != original.CacheReadInputTokens {
444 t.Error("Clone did not copy CacheReadInputTokens correctly")
445 }
446 if clone.CacheCreationInputTokens != original.CacheCreationInputTokens {
447 t.Error("Clone did not copy CacheCreationInputTokens correctly")
448 }
449 if clone.TotalCostUSD != original.TotalCostUSD {
450 t.Error("Clone did not copy TotalCostUSD correctly")
451 }
452 if len(clone.ToolUses) != len(original.ToolUses) {
453 t.Error("Clone did not copy ToolUses correctly")
454 }
455 for k, v := range original.ToolUses {
456 if clone.ToolUses[k] != v {
457 t.Errorf("Clone did not copy ToolUses correctly for key %s", k)
458 }
459 }
460
461 // Check that maps are separate instances
462 clone.ToolUses["tool3"] = 1
463 if _, exists := original.ToolUses["tool3"]; exists {
464 t.Error("Clone should have separate ToolUses map")
465 }
466}
467
468// TestUsageMethods tests various usage calculation methods
469func TestUsageMethods(t *testing.T) {
470 ctx := context.Background()
471 srv := loop.NewPredictableService()
472 convo := New(ctx, srv, nil)
473
474 // Test CumulativeUsage on empty conversation
475 usage := convo.CumulativeUsage()
476 if usage.Responses != 0 {
477 t.Error("CumulativeUsage should be empty for new conversation")
478 }
479
480 // Test WallTime method
481 wallTime := usage.WallTime()
482 if wallTime <= 0 {
483 t.Error("WallTime should be positive")
484 }
485
486 // Test DollarsPerHour method
487 dollarsPerHour := usage.DollarsPerHour()
488 if dollarsPerHour != 0 {
489 t.Error("DollarsPerHour should be 0 for empty usage")
490 }
491
492 // Test TotalInputTokens method
493 totalInputTokens := usage.TotalInputTokens()
494 if totalInputTokens != 0 {
495 t.Error("TotalInputTokens should be 0 for empty usage")
496 }
497
498 // Test Attr method
499 attr := usage.Attr()
500 if attr.Key != "usage" {
501 t.Error("Attr should have key 'usage'")
502 }
503}
504
505// TestLastUsage tests the LastUsage function
506func TestLastUsage(t *testing.T) {
507 ctx := context.Background()
508 srv := loop.NewPredictableService()
509 convo := New(ctx, srv, nil)
510
511 // Test LastUsage on empty conversation
512 lastUsage := convo.LastUsage()
513 if lastUsage.InputTokens != 0 {
514 t.Error("LastUsage should be empty for new conversation")
515 }
516
517 // Send a message to generate some usage
518 _, err := convo.SendUserTextMessage("echo: hello")
519 if err != nil {
520 t.Fatalf("SendUserTextMessage failed: %v", err)
521 }
522
523 // Test LastUsage after sending a message
524 lastUsage = convo.LastUsage()
525 if lastUsage.InputTokens == 0 {
526 t.Error("LastUsage should have input tokens after sending a message")
527 }
528}
529
530// TestOverBudget tests the OverBudget function
531func TestOverBudget(t *testing.T) {
532 ctx := context.Background()
533 srv := loop.NewPredictableService()
534 convo := New(ctx, srv, nil)
535
536 // Test OverBudget with no budget set
537 err := convo.OverBudget()
538 if err != nil {
539 t.Errorf("OverBudget should return nil when no budget is set, got %v", err)
540 }
541
542 // Set a budget
543 convo.Budget.MaxDollars = 10.0
544
545 // Test OverBudget with budget not exceeded
546 err = convo.OverBudget()
547 if err != nil {
548 t.Errorf("OverBudget should return nil when budget is not exceeded, got %v", err)
549 }
550
551 // Test with sub-conversation
552 subConvo := convo.SubConvo()
553 err = subConvo.OverBudget()
554 if err != nil {
555 t.Errorf("OverBudget should return nil for sub-conversation when budget is not exceeded, got %v", err)
556 }
557}
558
559// TestResetBudget tests the ResetBudget function
560func TestResetBudget(t *testing.T) {
561 ctx := context.Background()
562 srv := loop.NewPredictableService()
563 convo := New(ctx, srv, nil)
564
565 // Set initial budget
566 initialBudget := Budget{MaxDollars: 5.0}
567 convo.ResetBudget(initialBudget)
568
569 // Check that budget was set
570 if convo.Budget.MaxDollars != 5.0 {
571 t.Errorf("Expected budget MaxDollars to be 5.0, got %f", convo.Budget.MaxDollars)
572 }
573
574 // Send a message to accumulate some usage
575 _, err := convo.SendUserTextMessage("echo: hello")
576 if err != nil {
577 t.Fatalf("SendUserTextMessage failed: %v", err)
578 }
579
580 // Get current usage
581 usage := convo.CumulativeUsage()
582 usedAmount := usage.TotalCostUSD
583
584 // Reset budget again
585 newBudget := Budget{MaxDollars: 10.0}
586 convo.ResetBudget(newBudget)
587
588 // Check that budget was adjusted by usage
589 expectedBudget := 10.0 + usedAmount
590 if convo.Budget.MaxDollars != expectedBudget {
591 t.Errorf("Expected adjusted budget MaxDollars to be %f, got %f", expectedBudget, convo.Budget.MaxDollars)
592 }
593}
594
595// TestOverBudgetFunction tests the overBudget function
596func TestOverBudgetFunction(t *testing.T) {
597 ctx := context.Background()
598 srv := loop.NewPredictableService()
599 convo := New(ctx, srv, nil)
600
601 // Test overBudget with no budget set
602 err := convo.overBudget()
603 if err != nil {
604 t.Errorf("overBudget should return nil when no budget is set, got %v", err)
605 }
606
607 // Set a budget
608 convo.Budget.MaxDollars = 5.0
609
610 // Test overBudget with budget not exceeded
611 err = convo.overBudget()
612 if err != nil {
613 t.Errorf("overBudget should return nil when budget is not exceeded, got %v", err)
614 }
615}
616
// GetID is covered by TestGetIDAdditional.
618
619// TestListenerMethods tests the listener methods
620func TestListenerMethods(t *testing.T) {
621 listener := &NoopListener{}
622 ctx := context.Background()
623 convo := &Convo{}
624
625 // Test that noop listener methods don't panic
626 listener.OnToolCall(ctx, convo, "id", "toolName", json.RawMessage(`{"key":"value"}`), llm.Content{})
627 listener.OnToolResult(ctx, convo, "id", "toolName", json.RawMessage(`{"key":"value"}`), llm.Content{}, nil, nil)
628 listener.OnResponse(ctx, convo, "id", &llm.Response{})
629 listener.OnRequest(ctx, convo, "id", &llm.Message{})
630
631 t.Log("NoopListener methods executed without panic")
632}
633
634// TestIncrementToolUse tests the incrementToolUse function
635func TestIncrementToolUse(t *testing.T) {
636 ctx := context.Background()
637 srv := loop.NewPredictableService()
638 convo := New(ctx, srv, nil)
639
640 // Check initial state
641 usage := convo.CumulativeUsage()
642 if usage.ToolUses["testTool"] != 0 {
643 t.Errorf("Expected 0 uses of testTool, got %d", usage.ToolUses["testTool"])
644 }
645
646 // Increment tool use
647 convo.incrementToolUse("testTool")
648
649 // Check that tool use was incremented
650 usage = convo.CumulativeUsage()
651 if usage.ToolUses["testTool"] != 1 {
652 t.Errorf("Expected 1 use of testTool, got %d", usage.ToolUses["testTool"])
653 }
654
655 // Increment again
656 convo.incrementToolUse("testTool")
657
658 // Check that tool use was incremented again
659 usage = convo.CumulativeUsage()
660 if usage.ToolUses["testTool"] != 2 {
661 t.Errorf("Expected 2 uses of testTool, got %d", usage.ToolUses["testTool"])
662 }
663
664 // Test with different tool
665 convo.incrementToolUse("anotherTool")
666 usage = convo.CumulativeUsage()
667 if usage.ToolUses["anotherTool"] != 1 {
668 t.Errorf("Expected 1 use of anotherTool, got %d", usage.ToolUses["anotherTool"])
669 }
670}
671
672// TestDebugJSON tests the DebugJSON function
673// TestToolResultCancelContents tests the ToolResultCancelContents function
674func TestToolResultCancelContents(t *testing.T) {
675 ctx := context.Background()
676 srv := &ant.Service{}
677 convo := New(ctx, srv, nil)
678
679 // Test with response that doesn't have tool use stop reason
680 resp := &llm.Response{
681 StopReason: llm.StopReasonEndTurn,
682 }
683 contents, err := convo.ToolResultCancelContents(resp)
684 if err != nil {
685 t.Errorf("ToolResultCancelContents should not error with non-tool-use response: %v", err)
686 }
687 if contents != nil {
688 t.Error("ToolResultCancelContents should return nil with non-tool-use response")
689 }
690
691 // Test with response that has tool use stop reason but no tool use content
692 resp = &llm.Response{
693 StopReason: llm.StopReasonToolUse,
694 Content: []llm.Content{
695 {Type: llm.ContentTypeText, Text: "Hello"},
696 },
697 }
698 contents, err = convo.ToolResultCancelContents(resp)
699 if err != nil {
700 t.Errorf("ToolResultCancelContents should not error with tool use response but no tool content: %v", err)
701 }
702 // Check if contents is nil (this is expected when no tool uses are found)
703 if contents != nil && len(contents) != 0 {
704 t.Errorf("ToolResultCancelContents should return nil or empty slice with tool use response but no tool content, got length %d", len(contents))
705 }
706
707 // Test with response that has tool use stop reason and actual tool use content
708 resp = &llm.Response{
709 StopReason: llm.StopReasonToolUse,
710 Content: []llm.Content{
711 {Type: llm.ContentTypeToolUse, ID: "tool1", ToolName: "testTool"},
712 },
713 }
714 contents, err = convo.ToolResultCancelContents(resp)
715 if err != nil {
716 t.Errorf("ToolResultCancelContents should not error with tool use response and tool content: %v", err)
717 }
718 if contents == nil {
719 t.Error("ToolResultCancelContents should return non-nil slice with tool use response and tool content")
720 } else if len(contents) != 1 {
721 t.Errorf("ToolResultCancelContents should return slice with one element with tool use response and tool content, got length %d", len(contents))
722 } else {
723 // Check that the returned content has the correct properties
724 if contents[0].Type != llm.ContentTypeToolResult {
725 t.Errorf("ToolResultCancelContents should return tool result content, got type %v", contents[0].Type)
726 }
727 if contents[0].ToolUseID != "tool1" {
728 t.Errorf("ToolResultCancelContents should return content with correct ToolUseID, got %v", contents[0].ToolUseID)
729 }
730 if !contents[0].ToolError {
731 t.Error("ToolResultCancelContents should return content with ToolError set to true")
732 }
733 }
734}
735
736// TestNewToolUseContext tests the newToolUseContext function
737func TestNewToolUseContext(t *testing.T) {
738 ctx := context.Background()
739 srv := &ant.Service{}
740 convo := New(ctx, srv, nil)
741
742 // Test creating a new tool use context
743 toolUseID := "test-tool-use-id"
744 toolCtx, cancel := convo.newToolUseContext(ctx, toolUseID)
745
746 if toolCtx == nil {
747 t.Error("newToolUseContext should return a valid context")
748 }
749
750 if cancel == nil {
751 t.Error("newToolUseContext should return a valid cancel function")
752 }
753
754 // Check that the tool use was registered
755 convo.toolUseCancelMu.Lock()
756 _, exists := convo.toolUseCancel[toolUseID]
757 convo.toolUseCancelMu.Unlock()
758
759 if !exists {
760 t.Error("newToolUseContext should register the tool use cancel function")
761 }
762
763 // Test that cancel function works
764 cancel()
765
766 // Check that the tool use was unregistered
767 convo.toolUseCancelMu.Lock()
768 _, exists = convo.toolUseCancel[toolUseID]
769 convo.toolUseCancelMu.Unlock()
770
771 if exists {
772 t.Error("Cancel function should unregister the tool use")
773 }
774}
775
776// TestToolResultContents tests the ToolResultContents function
777func TestToolResultContents(t *testing.T) {
778 ctx := context.Background()
779 srv := &ant.Service{}
780 convo := New(ctx, srv, nil)
781
782 // Skip nil response test as the function doesn't handle nil properly
783 // This would cause a nil pointer dereference in the actual function
784
785 // Test with response that doesn't have tool use stop reason
786 resp := &llm.Response{
787 StopReason: llm.StopReasonEndTurn,
788 }
789 contents, endsTurn, err := convo.ToolResultContents(ctx, resp)
790 if err != nil {
791 t.Errorf("ToolResultContents should not error with non-tool-use response: %v", err)
792 }
793 if contents != nil {
794 t.Error("ToolResultContents should return nil with non-tool-use response")
795 }
796 if endsTurn {
797 t.Error("ToolResultContents should return false for endsTurn with non-tool-use response")
798 }
799}
800
801// testListener is a custom listener implementation for testing
802type testListener struct {
803 events []string
804}
805
806func (tl *testListener) OnToolCall(ctx context.Context, convo *Convo, id, toolName string, toolInput json.RawMessage, content llm.Content) {
807 tl.events = append(tl.events, "OnToolCall")
808}
809
810func (tl *testListener) OnToolResult(ctx context.Context, convo *Convo, id, toolName string, toolInput json.RawMessage, content llm.Content, result *string, err error) {
811 tl.events = append(tl.events, "OnToolResult")
812}
813
814func (tl *testListener) OnResponse(ctx context.Context, convo *Convo, id string, resp *llm.Response) {
815 tl.events = append(tl.events, "OnResponse")
816}
817
818func (tl *testListener) OnRequest(ctx context.Context, convo *Convo, id string, msg *llm.Message) {
819 tl.events = append(tl.events, "OnRequest")
820}
821
822// TestListenerInterface tests that the Listener interface methods are called
823func TestListenerInterface(t *testing.T) {
824 listener := &testListener{}
825 ctx := context.Background()
826 convo := &Convo{}
827
828 // Test that all listener methods can be called without panicking
829 listener.OnToolCall(ctx, convo, "id", "toolName", json.RawMessage(`{"key":"value"}`), llm.Content{})
830 listener.OnToolResult(ctx, convo, "id", "toolName", json.RawMessage(`{"key":"value"}`), llm.Content{}, nil, nil)
831 listener.OnResponse(ctx, convo, "id", &llm.Response{})
832 listener.OnRequest(ctx, convo, "id", &llm.Message{})
833
834 // Check that events were recorded
835 if len(listener.events) != 4 {
836 t.Errorf("Expected 4 events, got %d", len(listener.events))
837 }
838
839 expectedEvents := []string{"OnToolCall", "OnToolResult", "OnResponse", "OnRequest"}
840 for i, expected := range expectedEvents {
841 if listener.events[i] != expected {
842 t.Errorf("Expected event %s, got %s", expected, listener.events[i])
843 }
844 }
845}
846
847// TestToolResultContentsWithToolUse tests ToolResultContents with actual tool use
848func TestToolResultContentsWithToolUse(t *testing.T) {
849 ctx := context.Background()
850 srv := loop.NewPredictableService()
851 convo := New(ctx, srv, nil)
852
853 // Add a simple echo tool
854 convo.Tools = append(convo.Tools, &llm.Tool{
855 Name: "echo",
856 Description: "Echo tool for testing",
857 InputSchema: json.RawMessage(`{"type": "object", "properties": {"message": {"type": "string"}}}`),
858 Run: func(ctx context.Context, input json.RawMessage) llm.ToolOut {
859 return llm.ToolOut{
860 LLMContent: []llm.Content{{Type: llm.ContentTypeText, Text: "echo response"}},
861 }
862 },
863 })
864
865 // Create a response with tool use stop reason
866 resp := &llm.Response{
867 StopReason: llm.StopReasonToolUse,
868 Content: []llm.Content{
869 {
870 Type: llm.ContentTypeToolUse,
871 ID: "test-tool-call",
872 ToolName: "echo",
873 ToolInput: json.RawMessage(`{"message": "test"}`),
874 },
875 },
876 }
877
878 // Test ToolResultContents with tool use
879 contents, endsTurn, err := convo.ToolResultContents(ctx, resp)
880 if err != nil {
881 t.Fatalf("ToolResultContents failed: %v", err)
882 }
883
884 // Should return tool results
885 if len(contents) == 0 {
886 t.Error("ToolResultContents should return tool results")
887 }
888
889 // Check the content type
890 if contents[0].Type != llm.ContentTypeToolResult {
891 t.Errorf("Expected ContentTypeToolResult, got %s", contents[0].Type)
892 }
893
894 // For our echo tool, endsTurn should be false
895 if endsTurn {
896 t.Error("Expected endsTurn to be false for echo tool")
897 }
898}
899
900// TestOverBudgetWithExceeded tests OverBudget when budget is exceeded
901func TestOverBudgetWithExceeded(t *testing.T) {
902 ctx := context.Background()
903 srv := loop.NewPredictableService()
904 convo := New(ctx, srv, nil)
905
906 // Set a tiny budget
907 convo.Budget.MaxDollars = 0.0000001
908
909 // Send a message to accumulate usage
910 _, err := convo.SendUserTextMessage("test message")
911 if err != nil {
912 t.Fatalf("SendUserTextMessage failed: %v", err)
913 }
914
915 // Test that OverBudget returns an error
916 err = convo.OverBudget()
917 if err == nil {
918 t.Error("OverBudget should return an error when budget is exceeded")
919 }
920}
921
922// TestResetBudgetWithUsage tests ResetBudget with existing usage
923func TestResetBudgetWithUsage(t *testing.T) {
924 ctx := context.Background()
925 srv := loop.NewPredictableService()
926 convo := New(ctx, srv, nil)
927
928 // Send a message to accumulate usage
929 _, err := convo.SendUserTextMessage("test message")
930 if err != nil {
931 t.Fatalf("SendUserTextMessage failed: %v", err)
932 }
933
934 // Get current usage
935 initialUsage := convo.CumulativeUsage()
936 initialCost := initialUsage.TotalCostUSD
937
938 // Reset budget
939 newBudget := Budget{MaxDollars: 10.0}
940 convo.ResetBudget(newBudget)
941
942 // Check that budget was adjusted
943 expectedBudget := 10.0 + initialCost
944 if convo.Budget.MaxDollars != expectedBudget {
945 t.Errorf("Expected budget to be %f, got %f", expectedBudget, convo.Budget.MaxDollars)
946 }
947}
948
// SubConvoWithHistory, Depth, GetID and DebugJSON are covered by
// TestSubConvoWithHistoryAdditional, TestDepthAdditional, TestGetIDAdditional
// and TestDebugJSONAdditional respectively.
956
957// recordingListener is a listener that records all calls for testing
958type recordingListener struct {
959 calls []string
960}
961
962func (rl *recordingListener) OnToolCall(ctx context.Context, convo *Convo, id, toolName string, toolInput json.RawMessage, content llm.Content) {
963 rl.calls = append(rl.calls, "OnToolCall")
964}
965
966func (rl *recordingListener) OnToolResult(ctx context.Context, convo *Convo, id, toolName string, toolInput json.RawMessage, content llm.Content, result *string, err error) {
967 rl.calls = append(rl.calls, "OnToolResult")
968}
969
970func (rl *recordingListener) OnResponse(ctx context.Context, convo *Convo, id string, resp *llm.Response) {
971 rl.calls = append(rl.calls, "OnResponse")
972}
973
974func (rl *recordingListener) OnRequest(ctx context.Context, convo *Convo, id string, msg *llm.Message) {
975 rl.calls = append(rl.calls, "OnRequest")
976}
977
978// TestConvoListenerIntegration tests that Convo actually calls listener methods during operation
979func TestConvoListenerIntegration(t *testing.T) {
980 ctx := context.Background()
981 srv := loop.NewPredictableService()
982 convo := New(ctx, srv, nil)
983
984 // Set up recording listener
985 listener := &recordingListener{}
986 convo.Listener = listener
987
988 // Send a message to trigger listener calls
989 _, err := convo.SendUserTextMessage("Hello")
990 if err != nil {
991 t.Fatalf("SendUserTextMessage failed: %v", err)
992 }
993
994 // Check that we recorded some calls
995 if len(listener.calls) == 0 {
996 t.Error("Expected listener methods to be called during conversation, but no calls were recorded")
997 }
998
999 // Verify that request and response events were recorded
1000 requestFound := false
1001 responseFound := false
1002 for _, call := range listener.calls {
1003 if call == "OnRequest" {
1004 requestFound = true
1005 }
1006 if call == "OnResponse" {
1007 responseFound = true
1008 }
1009 }
1010
1011 if !requestFound {
1012 t.Error("Expected OnRequest to be called during conversation")
1013 }
1014 if !responseFound {
1015 t.Error("Expected OnResponse to be called during conversation")
1016 }
1017}
1018
1019// TestSubConvoWithHistory tests SubConvoWithHistory method
1020func TestSubConvoWithHistoryAdditional(t *testing.T) {
1021 ctx := context.Background()
1022 srv := loop.NewPredictableService()
1023 convo := New(ctx, srv, nil)
1024
1025 // Send a message to create some history
1026 _, err := convo.SendUserTextMessage("Hello")
1027 if err != nil {
1028 t.Fatalf("SendUserTextMessage failed: %v", err)
1029 }
1030
1031 // Create sub-conversation with history
1032 subConvo := convo.SubConvoWithHistory()
1033 if subConvo == nil {
1034 t.Fatal("SubConvoWithHistory should return a valid conversation")
1035 }
1036
1037 // Check that sub-conversation has parent
1038 if subConvo.Parent != convo {
1039 t.Error("Sub-conversation should have parent set")
1040 }
1041
1042 // Check that sub-conversation has messages (history)
1043 if len(subConvo.messages) == 0 {
1044 t.Error("Sub-conversation should have messages from parent")
1045 }
1046
1047 // Check that the first message is from the parent conversation
1048 if len(subConvo.messages) < 1 {
1049 t.Error("Sub-conversation should have at least one message")
1050 }
1051}
1052
1053// TestDepthAdditional tests Depth method
1054func TestDepthAdditional(t *testing.T) {
1055 ctx := context.Background()
1056 srv := loop.NewPredictableService()
1057 convo := New(ctx, srv, nil)
1058
1059 // Root conversation should have depth 0
1060 if convo.Depth() != 0 {
1061 t.Errorf("Expected depth 0, got %d", convo.Depth())
1062 }
1063
1064 // Sub-conversation should have depth 1
1065 subConvo := convo.SubConvo()
1066 if subConvo.Depth() != 1 {
1067 t.Errorf("Expected depth 1, got %d", subConvo.Depth())
1068 }
1069
1070 // Sub-sub-conversation should have depth 2
1071 subSubConvo := subConvo.SubConvo()
1072 if subSubConvo.Depth() != 2 {
1073 t.Errorf("Expected depth 2, got %d", subSubConvo.Depth())
1074 }
1075}
1076
1077// TestGetIDAdditional tests GetID method
1078func TestGetIDAdditional(t *testing.T) {
1079 ctx := context.Background()
1080 srv := loop.NewPredictableService()
1081 convo := New(ctx, srv, nil)
1082
1083 id := convo.GetID()
1084 if id == "" {
1085 t.Error("GetID should return a non-empty ID")
1086 }
1087 if id != convo.ID {
1088 t.Error("GetID should return the conversation ID")
1089 }
1090}
1091
1092// TestDebugJSONAdditional tests DebugJSON method
1093func TestDebugJSONAdditional(t *testing.T) {
1094 ctx := context.Background()
1095 srv := loop.NewPredictableService()
1096 convo := New(ctx, srv, nil)
1097
1098 // Test with empty conversation
1099 jsonData, err := convo.DebugJSON()
1100 if err != nil {
1101 t.Errorf("DebugJSON failed: %v", err)
1102 }
1103 if len(jsonData) == 0 {
1104 t.Error("DebugJSON should return non-empty data")
1105 }
1106
1107 // Test with conversation that has messages
1108 _, err = convo.SendUserTextMessage("Hello")
1109 if err != nil {
1110 t.Fatalf("SendUserTextMessage failed: %v", err)
1111 }
1112
1113 jsonData, err = convo.DebugJSON()
1114 if err != nil {
1115 t.Errorf("DebugJSON failed: %v", err)
1116 }
1117 if len(jsonData) == 0 {
1118 t.Error("DebugJSON should return non-empty data")
1119 }
1120
1121 // Verify it's valid JSON by trying to unmarshal it
1122 var parsed interface{}
1123 err = json.Unmarshal(jsonData, &parsed)
1124 if err != nil {
1125 t.Errorf("DebugJSON should return valid JSON: %v", err)
1126 }
1127}