1package conversation
2
3import (
4 "cmp"
5 "context"
6 "net/http"
7 "os"
8 "slices"
9 "strings"
10 "testing"
11
12 "shelley.exe.dev/llm"
13 "shelley.exe.dev/llm/ant"
14 "sketch.dev/httprr"
15)
16
17func TestBasicConvo(t *testing.T) {
18 ctx := context.Background()
19 rr, err := httprr.Open("testdata/basic_convo.httprr", http.DefaultTransport)
20 if err != nil {
21 t.Fatal(err)
22 }
23 rr.ScrubReq(func(req *http.Request) error {
24 req.Header.Del("x-api-key")
25 return nil
26 })
27
28 apiKey := cmp.Or(os.Getenv("OUTER_SKETCH_MODEL_API_KEY"), os.Getenv("ANTHROPIC_API_KEY"))
29 srv := &ant.Service{
30 APIKey: apiKey,
31 Model: ant.Claude4Sonnet, // Use specific model to match cached responses
32 HTTPC: rr.Client(),
33 }
34 convo := New(ctx, srv, nil)
35
36 const name = "Cornelius"
37 res, err := convo.SendUserTextMessage("Hi, my name is " + name)
38 if err != nil {
39 t.Fatal(err)
40 }
41 for _, part := range res.Content {
42 t.Logf("%s", part.Text)
43 }
44 res, err = convo.SendUserTextMessage("What is my name?")
45 if err != nil {
46 t.Fatal(err)
47 }
48 got := ""
49 for _, part := range res.Content {
50 got += part.Text
51 }
52 if !strings.Contains(got, name) {
53 t.Errorf("model does not know the given name %s: %q", name, got)
54 }
55}
56
57// TestCancelToolUse tests the CancelToolUse function of the Convo struct
58func TestCancelToolUse(t *testing.T) {
59 tests := []struct {
60 name string
61 setupToolUse bool
62 toolUseID string
63 cancelErr error
64 expectError bool
65 expectCancel bool
66 }{
67 {
68 name: "Cancel existing tool use",
69 setupToolUse: true,
70 toolUseID: "tool123",
71 cancelErr: nil,
72 expectError: false,
73 expectCancel: true,
74 },
75 {
76 name: "Cancel existing tool use with error",
77 setupToolUse: true,
78 toolUseID: "tool456",
79 cancelErr: context.Canceled,
80 expectError: false,
81 expectCancel: true,
82 },
83 {
84 name: "Cancel non-existent tool use",
85 setupToolUse: false,
86 toolUseID: "tool789",
87 cancelErr: nil,
88 expectError: true,
89 expectCancel: false,
90 },
91 }
92
93 srv := &ant.Service{}
94 for _, tt := range tests {
95 t.Run(tt.name, func(t *testing.T) {
96 convo := New(context.Background(), srv, nil)
97
98 var cancelCalled bool
99 var cancelledWithErr error
100
101 if tt.setupToolUse {
102 // Setup a mock cancel function to track calls
103 mockCancel := func(err error) {
104 cancelCalled = true
105 cancelledWithErr = err
106 }
107
108 convo.toolUseCancelMu.Lock()
109 convo.toolUseCancel[tt.toolUseID] = mockCancel
110 convo.toolUseCancelMu.Unlock()
111 }
112
113 err := convo.CancelToolUse(tt.toolUseID, tt.cancelErr)
114
115 // Check if we got the expected error state
116 if (err != nil) != tt.expectError {
117 t.Errorf("CancelToolUse() error = %v, expectError %v", err, tt.expectError)
118 }
119
120 // Check if the cancel function was called as expected
121 if cancelCalled != tt.expectCancel {
122 t.Errorf("Cancel function called = %v, expectCancel %v", cancelCalled, tt.expectCancel)
123 }
124
125 // If we expected the cancel to be called, verify it was called with the right error
126 if tt.expectCancel && cancelledWithErr != tt.cancelErr {
127 t.Errorf("Cancel function called with error = %v, expected %v", cancelledWithErr, tt.cancelErr)
128 }
129
130 // Verify the toolUseID was removed from the map if it was initially added
131 if tt.setupToolUse {
132 convo.toolUseCancelMu.Lock()
133 _, exists := convo.toolUseCancel[tt.toolUseID]
134 convo.toolUseCancelMu.Unlock()
135
136 if exists {
137 t.Errorf("toolUseID %s still exists in the map after cancellation", tt.toolUseID)
138 }
139 }
140 })
141 }
142}
143
144// TestInsertMissingToolResults tests the insertMissingToolResults function
145// to ensure it doesn't create duplicate tool results when multiple tool uses are missing results.
146func TestInsertMissingToolResults(t *testing.T) {
147 tests := []struct {
148 name string
149 messages []llm.Message
150 currentMsg llm.Message
151 expectedCount int
152 expectedToolIDs []string
153 }{
154 {
155 name: "Single missing tool result",
156 messages: []llm.Message{
157 {
158 Role: llm.MessageRoleAssistant,
159 Content: []llm.Content{
160 {
161 Type: llm.ContentTypeToolUse,
162 ID: "tool1",
163 },
164 },
165 },
166 },
167 currentMsg: llm.Message{
168 Role: llm.MessageRoleUser,
169 Content: []llm.Content{},
170 },
171 expectedCount: 1,
172 expectedToolIDs: []string{"tool1"},
173 },
174 {
175 name: "Multiple missing tool results",
176 messages: []llm.Message{
177 {
178 Role: llm.MessageRoleAssistant,
179 Content: []llm.Content{
180 {
181 Type: llm.ContentTypeToolUse,
182 ID: "tool1",
183 },
184 {
185 Type: llm.ContentTypeToolUse,
186 ID: "tool2",
187 },
188 {
189 Type: llm.ContentTypeToolUse,
190 ID: "tool3",
191 },
192 },
193 },
194 },
195 currentMsg: llm.Message{
196 Role: llm.MessageRoleUser,
197 Content: []llm.Content{},
198 },
199 expectedCount: 3,
200 expectedToolIDs: []string{"tool1", "tool2", "tool3"},
201 },
202 {
203 name: "No missing tool results when results already present",
204 messages: []llm.Message{
205 {
206 Role: llm.MessageRoleAssistant,
207 Content: []llm.Content{
208 {
209 Type: llm.ContentTypeToolUse,
210 ID: "tool1",
211 },
212 },
213 },
214 },
215 currentMsg: llm.Message{
216 Role: llm.MessageRoleUser,
217 Content: []llm.Content{
218 {
219 Type: llm.ContentTypeToolResult,
220 ToolUseID: "tool1",
221 },
222 },
223 },
224 expectedCount: 1, // Only the existing one
225 expectedToolIDs: []string{"tool1"},
226 },
227 {
228 name: "No tool uses in previous message",
229 messages: []llm.Message{
230 {
231 Role: llm.MessageRoleAssistant,
232 Content: []llm.Content{
233 {
234 Type: llm.ContentTypeText,
235 Text: "Just some text",
236 },
237 },
238 },
239 },
240 currentMsg: llm.Message{
241 Role: llm.MessageRoleUser,
242 Content: []llm.Content{},
243 },
244 expectedCount: 0,
245 expectedToolIDs: []string{},
246 },
247 }
248
249 for _, tt := range tests {
250 t.Run(tt.name, func(t *testing.T) {
251 srv := &ant.Service{}
252 convo := New(context.Background(), srv, nil)
253
254 // Create request with messages
255 req := &llm.Request{
256 Messages: append(tt.messages, tt.currentMsg),
257 }
258
259 // Call insertMissingToolResults
260 msg := tt.currentMsg
261 convo.insertMissingToolResults(req, &msg)
262
263 // Count tool results in the message
264 toolResultCount := 0
265 toolIDs := []string{}
266 for _, content := range msg.Content {
267 if content.Type == llm.ContentTypeToolResult {
268 toolResultCount++
269 toolIDs = append(toolIDs, content.ToolUseID)
270 }
271 }
272
273 // Verify count
274 if toolResultCount != tt.expectedCount {
275 t.Errorf("Expected %d tool results, got %d", tt.expectedCount, toolResultCount)
276 }
277
278 // Verify no duplicates by checking unique tool IDs
279 seenIDs := make(map[string]int)
280 for _, id := range toolIDs {
281 seenIDs[id]++
282 }
283
284 // Check for duplicates
285 for id, count := range seenIDs {
286 if count > 1 {
287 t.Errorf("Duplicate tool result for ID %s: found %d times", id, count)
288 }
289 }
290
291 // Verify all expected tool IDs are present
292 for _, expectedID := range tt.expectedToolIDs {
293 if !slices.Contains(toolIDs, expectedID) {
294 t.Errorf("Expected tool ID %s not found in results", expectedID)
295 }
296 }
297 })
298 }
299}