convo_test.go

  1package conversation
  2
  3import (
  4	"cmp"
  5	"context"
  6	"net/http"
  7	"os"
  8	"slices"
  9	"strings"
 10	"testing"
 11
 12	"shelley.exe.dev/llm"
 13	"shelley.exe.dev/llm/ant"
 14	"sketch.dev/httprr"
 15)
 16
 17func TestBasicConvo(t *testing.T) {
 18	ctx := context.Background()
 19	rr, err := httprr.Open("testdata/basic_convo.httprr", http.DefaultTransport)
 20	if err != nil {
 21		t.Fatal(err)
 22	}
 23	rr.ScrubReq(func(req *http.Request) error {
 24		req.Header.Del("x-api-key")
 25		return nil
 26	})
 27
 28	apiKey := cmp.Or(os.Getenv("OUTER_SKETCH_MODEL_API_KEY"), os.Getenv("ANTHROPIC_API_KEY"))
 29	srv := &ant.Service{
 30		APIKey: apiKey,
 31		Model:  ant.Claude4Sonnet, // Use specific model to match cached responses
 32		HTTPC:  rr.Client(),
 33	}
 34	convo := New(ctx, srv, nil)
 35
 36	const name = "Cornelius"
 37	res, err := convo.SendUserTextMessage("Hi, my name is " + name)
 38	if err != nil {
 39		t.Fatal(err)
 40	}
 41	for _, part := range res.Content {
 42		t.Logf("%s", part.Text)
 43	}
 44	res, err = convo.SendUserTextMessage("What is my name?")
 45	if err != nil {
 46		t.Fatal(err)
 47	}
 48	got := ""
 49	for _, part := range res.Content {
 50		got += part.Text
 51	}
 52	if !strings.Contains(got, name) {
 53		t.Errorf("model does not know the given name %s: %q", name, got)
 54	}
 55}
 56
 57// TestCancelToolUse tests the CancelToolUse function of the Convo struct
 58func TestCancelToolUse(t *testing.T) {
 59	tests := []struct {
 60		name         string
 61		setupToolUse bool
 62		toolUseID    string
 63		cancelErr    error
 64		expectError  bool
 65		expectCancel bool
 66	}{
 67		{
 68			name:         "Cancel existing tool use",
 69			setupToolUse: true,
 70			toolUseID:    "tool123",
 71			cancelErr:    nil,
 72			expectError:  false,
 73			expectCancel: true,
 74		},
 75		{
 76			name:         "Cancel existing tool use with error",
 77			setupToolUse: true,
 78			toolUseID:    "tool456",
 79			cancelErr:    context.Canceled,
 80			expectError:  false,
 81			expectCancel: true,
 82		},
 83		{
 84			name:         "Cancel non-existent tool use",
 85			setupToolUse: false,
 86			toolUseID:    "tool789",
 87			cancelErr:    nil,
 88			expectError:  true,
 89			expectCancel: false,
 90		},
 91	}
 92
 93	srv := &ant.Service{}
 94	for _, tt := range tests {
 95		t.Run(tt.name, func(t *testing.T) {
 96			convo := New(context.Background(), srv, nil)
 97
 98			var cancelCalled bool
 99			var cancelledWithErr error
100
101			if tt.setupToolUse {
102				// Setup a mock cancel function to track calls
103				mockCancel := func(err error) {
104					cancelCalled = true
105					cancelledWithErr = err
106				}
107
108				convo.toolUseCancelMu.Lock()
109				convo.toolUseCancel[tt.toolUseID] = mockCancel
110				convo.toolUseCancelMu.Unlock()
111			}
112
113			err := convo.CancelToolUse(tt.toolUseID, tt.cancelErr)
114
115			// Check if we got the expected error state
116			if (err != nil) != tt.expectError {
117				t.Errorf("CancelToolUse() error = %v, expectError %v", err, tt.expectError)
118			}
119
120			// Check if the cancel function was called as expected
121			if cancelCalled != tt.expectCancel {
122				t.Errorf("Cancel function called = %v, expectCancel %v", cancelCalled, tt.expectCancel)
123			}
124
125			// If we expected the cancel to be called, verify it was called with the right error
126			if tt.expectCancel && cancelledWithErr != tt.cancelErr {
127				t.Errorf("Cancel function called with error = %v, expected %v", cancelledWithErr, tt.cancelErr)
128			}
129
130			// Verify the toolUseID was removed from the map if it was initially added
131			if tt.setupToolUse {
132				convo.toolUseCancelMu.Lock()
133				_, exists := convo.toolUseCancel[tt.toolUseID]
134				convo.toolUseCancelMu.Unlock()
135
136				if exists {
137					t.Errorf("toolUseID %s still exists in the map after cancellation", tt.toolUseID)
138				}
139			}
140		})
141	}
142}
143
144// TestInsertMissingToolResults tests the insertMissingToolResults function
145// to ensure it doesn't create duplicate tool results when multiple tool uses are missing results.
146func TestInsertMissingToolResults(t *testing.T) {
147	tests := []struct {
148		name            string
149		messages        []llm.Message
150		currentMsg      llm.Message
151		expectedCount   int
152		expectedToolIDs []string
153	}{
154		{
155			name: "Single missing tool result",
156			messages: []llm.Message{
157				{
158					Role: llm.MessageRoleAssistant,
159					Content: []llm.Content{
160						{
161							Type: llm.ContentTypeToolUse,
162							ID:   "tool1",
163						},
164					},
165				},
166			},
167			currentMsg: llm.Message{
168				Role:    llm.MessageRoleUser,
169				Content: []llm.Content{},
170			},
171			expectedCount:   1,
172			expectedToolIDs: []string{"tool1"},
173		},
174		{
175			name: "Multiple missing tool results",
176			messages: []llm.Message{
177				{
178					Role: llm.MessageRoleAssistant,
179					Content: []llm.Content{
180						{
181							Type: llm.ContentTypeToolUse,
182							ID:   "tool1",
183						},
184						{
185							Type: llm.ContentTypeToolUse,
186							ID:   "tool2",
187						},
188						{
189							Type: llm.ContentTypeToolUse,
190							ID:   "tool3",
191						},
192					},
193				},
194			},
195			currentMsg: llm.Message{
196				Role:    llm.MessageRoleUser,
197				Content: []llm.Content{},
198			},
199			expectedCount:   3,
200			expectedToolIDs: []string{"tool1", "tool2", "tool3"},
201		},
202		{
203			name: "No missing tool results when results already present",
204			messages: []llm.Message{
205				{
206					Role: llm.MessageRoleAssistant,
207					Content: []llm.Content{
208						{
209							Type: llm.ContentTypeToolUse,
210							ID:   "tool1",
211						},
212					},
213				},
214			},
215			currentMsg: llm.Message{
216				Role: llm.MessageRoleUser,
217				Content: []llm.Content{
218					{
219						Type:      llm.ContentTypeToolResult,
220						ToolUseID: "tool1",
221					},
222				},
223			},
224			expectedCount:   1, // Only the existing one
225			expectedToolIDs: []string{"tool1"},
226		},
227		{
228			name: "No tool uses in previous message",
229			messages: []llm.Message{
230				{
231					Role: llm.MessageRoleAssistant,
232					Content: []llm.Content{
233						{
234							Type: llm.ContentTypeText,
235							Text: "Just some text",
236						},
237					},
238				},
239			},
240			currentMsg: llm.Message{
241				Role:    llm.MessageRoleUser,
242				Content: []llm.Content{},
243			},
244			expectedCount:   0,
245			expectedToolIDs: []string{},
246		},
247	}
248
249	for _, tt := range tests {
250		t.Run(tt.name, func(t *testing.T) {
251			srv := &ant.Service{}
252			convo := New(context.Background(), srv, nil)
253
254			// Create request with messages
255			req := &llm.Request{
256				Messages: append(tt.messages, tt.currentMsg),
257			}
258
259			// Call insertMissingToolResults
260			msg := tt.currentMsg
261			convo.insertMissingToolResults(req, &msg)
262
263			// Count tool results in the message
264			toolResultCount := 0
265			toolIDs := []string{}
266			for _, content := range msg.Content {
267				if content.Type == llm.ContentTypeToolResult {
268					toolResultCount++
269					toolIDs = append(toolIDs, content.ToolUseID)
270				}
271			}
272
273			// Verify count
274			if toolResultCount != tt.expectedCount {
275				t.Errorf("Expected %d tool results, got %d", tt.expectedCount, toolResultCount)
276			}
277
278			// Verify no duplicates by checking unique tool IDs
279			seenIDs := make(map[string]int)
280			for _, id := range toolIDs {
281				seenIDs[id]++
282			}
283
284			// Check for duplicates
285			for id, count := range seenIDs {
286				if count > 1 {
287					t.Errorf("Duplicate tool result for ID %s: found %d times", id, count)
288				}
289			}
290
291			// Verify all expected tool IDs are present
292			for _, expectedID := range tt.expectedToolIDs {
293				if !slices.Contains(toolIDs, expectedID) {
294					t.Errorf("Expected tool ID %s not found in results", expectedID)
295				}
296			}
297		})
298	}
299}