1package loop
2
3import (
4 "context"
5 "encoding/json"
6 "fmt"
7 "sync"
8 "sync/atomic"
9 "testing"
10 "time"
11
12 "shelley.exe.dev/llm"
13)
14
15// TestInterruptionDuringToolExecution tests that user messages queued during
16// tool execution are processed after the tool completes but before the next
17// tool starts (not at the end of the entire turn).
18func TestInterruptionDuringToolExecution(t *testing.T) {
19 // Track when the tool is called and when it completes
20 var toolStarted atomic.Bool
21 var toolCompleted atomic.Bool
22 var interruptionSeen atomic.Bool
23
24 // Create a slow tool
25 slowTool := &llm.Tool{
26 Name: "slow_tool",
27 Description: "A tool that takes time to execute",
28 InputSchema: llm.MustSchema(`{"type": "object", "properties": {"input": {"type": "string"}}}`),
29 Run: func(ctx context.Context, input json.RawMessage) llm.ToolOut {
30 toolStarted.Store(true)
31 // Sleep to simulate slow tool execution
32 time.Sleep(200 * time.Millisecond)
33 toolCompleted.Store(true)
34 return llm.ToolOut{
35 LLMContent: []llm.Content{
36 {Type: llm.ContentTypeText, Text: "Tool completed"},
37 },
38 }
39 },
40 }
41
42 recordMessage := func(ctx context.Context, message llm.Message, usage llm.Usage) error {
43 return nil
44 }
45
46 // Create a service that detects the interruption
47 service := &customPredictableService{
48 responseFunc: func(req *llm.Request) (*llm.Response, error) {
49 // Check if we've seen the interruption
50 toolResults := 0
51 for _, msg := range req.Messages {
52 for _, c := range msg.Content {
53 if c.Type == llm.ContentTypeToolResult {
54 toolResults++
55 }
56 if c.Type == llm.ContentTypeText && c.Text == "INTERRUPTION" {
57 interruptionSeen.Store(true)
58 return &llm.Response{
59 Role: llm.MessageRoleAssistant,
60 StopReason: llm.StopReasonEndTurn,
61 Content: []llm.Content{
62 {Type: llm.ContentTypeText, Text: "Acknowledged interruption"},
63 },
64 }, nil
65 }
66 }
67 }
68
69 // First call: use the slow tool
70 if toolResults == 0 {
71 return &llm.Response{
72 Role: llm.MessageRoleAssistant,
73 StopReason: llm.StopReasonToolUse,
74 Content: []llm.Content{
75 {Type: llm.ContentTypeText, Text: "I'll use the slow tool"},
76 {
77 Type: llm.ContentTypeToolUse,
78 ID: "tool_1",
79 ToolName: "slow_tool",
80 ToolInput: json.RawMessage(`{"input":"test"}`),
81 },
82 },
83 }, nil
84 }
85
86 // After tool result, continue with more work
87 return &llm.Response{
88 Role: llm.MessageRoleAssistant,
89 StopReason: llm.StopReasonEndTurn,
90 Content: []llm.Content{
91 {Type: llm.ContentTypeText, Text: "Done with tool"},
92 },
93 }, nil
94 },
95 }
96
97 loop := NewLoop(Config{
98 LLM: service,
99 History: []llm.Message{},
100 Tools: []*llm.Tool{slowTool},
101 RecordMessage: recordMessage,
102 })
103
104 // Queue initial user message that will trigger tool use
105 loop.QueueUserMessage(llm.Message{
106 Role: llm.MessageRoleUser,
107 Content: []llm.Content{{Type: llm.ContentTypeText, Text: "use the tool"}},
108 })
109
110 // Run the loop in background
111 ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
112 defer cancel()
113
114 var loopDone sync.WaitGroup
115 loopDone.Add(1)
116 go func() {
117 defer loopDone.Done()
118 loop.Go(ctx)
119 }()
120
121 // Wait for tool to start
122 for !toolStarted.Load() {
123 time.Sleep(10 * time.Millisecond)
124 }
125
126 // Queue an interruption message while tool is executing
127 loop.QueueUserMessage(llm.Message{
128 Role: llm.MessageRoleUser,
129 Content: []llm.Content{{Type: llm.ContentTypeText, Text: "INTERRUPTION"}},
130 })
131 t.Log("Queued interruption message while tool is executing")
132
133 // The message should remain in queue while tool is executing
134 time.Sleep(50 * time.Millisecond)
135 if !toolCompleted.Load() {
136 loop.mu.Lock()
137 queueLen := len(loop.messageQueue)
138 loop.mu.Unlock()
139 if queueLen > 0 {
140 t.Log("Message is waiting in queue during tool execution (expected)")
141 }
142 }
143
144 // Wait for loop to finish
145 time.Sleep(500 * time.Millisecond)
146 cancel()
147 loopDone.Wait()
148
149 // Verify the interruption was seen by the LLM
150 if interruptionSeen.Load() {
151 t.Log("SUCCESS: Interruption was seen by LLM after tool completed")
152 } else {
153 t.Error("Interruption was never seen by the LLM")
154 }
155}
156
157// TestInterruptionDuringMultiToolChain tests interruption during a chain of tool calls.
158// With the fix, the interruption should be visible to the LLM after the first tool completes.
159func TestInterruptionDuringMultiToolChain(t *testing.T) {
160 var toolCallCount atomic.Int32
161 var interruptionSeenAtToolResult atomic.Int32 // -1 means not seen
162
163 // Create a tool that's called multiple times
164 multiTool := &llm.Tool{
165 Name: "multi_tool",
166 Description: "A tool that might be called multiple times",
167 InputSchema: llm.MustSchema(`{"type": "object", "properties": {"step": {"type": "integer"}}}`),
168 Run: func(ctx context.Context, input json.RawMessage) llm.ToolOut {
169 count := toolCallCount.Add(1)
170 time.Sleep(100 * time.Millisecond) // Simulate some work
171 _ = count
172 return llm.ToolOut{
173 LLMContent: []llm.Content{
174 {Type: llm.ContentTypeText, Text: "Tool step completed"},
175 },
176 }
177 },
178 }
179
180 recordMessage := func(ctx context.Context, message llm.Message, usage llm.Usage) error {
181 return nil
182 }
183
184 // Service that makes multiple tool calls but stops when it sees "STOP"
185 interruptionSeenAtToolResult.Store(-1)
186 service := &customPredictableService{
187 responseFunc: func(req *llm.Request) (*llm.Response, error) {
188 // Check if we've seen the STOP message
189 toolResults := 0
190 for _, msg := range req.Messages {
191 for _, c := range msg.Content {
192 if c.Type == llm.ContentTypeToolResult {
193 toolResults++
194 }
195 if c.Type == llm.ContentTypeText && c.Text == "STOP" {
196 // Record when we first saw the interruption
197 interruptionSeenAtToolResult.CompareAndSwap(-1, int32(toolResults))
198 // Stop immediately when we see the interruption
199 return &llm.Response{
200 Role: llm.MessageRoleAssistant,
201 StopReason: llm.StopReasonEndTurn,
202 Content: []llm.Content{
203 {Type: llm.ContentTypeText, Text: "Stopped due to user interruption"},
204 },
205 }, nil
206 }
207 }
208 }
209
210 if toolResults < 5 {
211 // Keep calling the tool (would do 5 if not interrupted)
212 return &llm.Response{
213 Role: llm.MessageRoleAssistant,
214 StopReason: llm.StopReasonToolUse,
215 Content: []llm.Content{
216 {Type: llm.ContentTypeText, Text: "Calling tool again"},
217 {
218 Type: llm.ContentTypeToolUse,
219 ID: fmt.Sprintf("tool_%d", toolResults+1),
220 ToolName: "multi_tool",
221 ToolInput: json.RawMessage(fmt.Sprintf(`{"step":%d}`, toolResults+1)),
222 },
223 },
224 }, nil
225 }
226
227 // Done with tools
228 return &llm.Response{
229 Role: llm.MessageRoleAssistant,
230 StopReason: llm.StopReasonEndTurn,
231 Content: []llm.Content{
232 {Type: llm.ContentTypeText, Text: "All tools completed"},
233 },
234 }, nil
235 },
236 }
237
238 loop := NewLoop(Config{
239 LLM: service,
240 History: []llm.Message{},
241 Tools: []*llm.Tool{multiTool},
242 RecordMessage: recordMessage,
243 })
244
245 // Queue initial user message
246 loop.QueueUserMessage(llm.Message{
247 Role: llm.MessageRoleUser,
248 Content: []llm.Content{{Type: llm.ContentTypeText, Text: "run the tool 5 times"}},
249 })
250
251 ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
252 defer cancel()
253
254 var loopDone sync.WaitGroup
255 loopDone.Add(1)
256 go func() {
257 defer loopDone.Done()
258 loop.Go(ctx)
259 }()
260
261 // Wait for first tool call to complete
262 for toolCallCount.Load() < 1 {
263 time.Sleep(10 * time.Millisecond)
264 }
265
266 // Queue interruption after first tool
267 loop.QueueUserMessage(llm.Message{
268 Role: llm.MessageRoleUser,
269 Content: []llm.Content{{Type: llm.ContentTypeText, Text: "STOP"}},
270 })
271 t.Logf("Queued STOP message after tool call %d", toolCallCount.Load())
272
273 // Wait for loop to process and stop
274 time.Sleep(500 * time.Millisecond)
275
276 cancel()
277 loopDone.Wait()
278
279 finalToolCount := toolCallCount.Load()
280 seenAt := interruptionSeenAtToolResult.Load()
281
282 t.Logf("Final tool call count: %d (would be 5 without interruption)", finalToolCount)
283 t.Logf("Interruption was seen by LLM after tool result %d", seenAt)
284
285 // With the fix, the interruption should be seen after just 1 tool result
286 // (the tool that was running when we queued the STOP message)
287 if seenAt == 1 {
288 t.Log("SUCCESS: Interruption was processed immediately after first tool completed")
289 } else if seenAt > 1 {
290 t.Errorf("Interruption was delayed: seen after %d tool results, expected 1", seenAt)
291 } else if seenAt == -1 {
292 t.Error("Interruption was never seen by the LLM")
293 }
294
295 // The tool should only be called a small number of times since we interrupted
296 if finalToolCount > 2 {
297 t.Errorf("Too many tool calls (%d): interruption should have stopped the chain earlier", finalToolCount)
298 }
299}
300
// customPredictableService allows custom response logic for testing.
// It implements the LLM service interface consumed by NewLoop: if
// responseFunc is set it is consulted for every request; otherwise the
// scripted responses slice is replayed in order.
type customPredictableService struct {
	responses    []customResponse                               // scripted replies, consumed in order when responseFunc is nil
	responseFunc func(req *llm.Request) (*llm.Response, error) // optional dynamic handler; takes precedence over responses
	callIndex    int                                            // next index into responses; guarded by mu
	mu           sync.Mutex                                     // serializes Do calls and protects callIndex
}
308
// customResponse is one scripted reply for customPredictableService:
// either a response or an error to return from Do.
type customResponse struct {
	response *llm.Response // reply returned to the loop (may be nil when err is set)
	err      error         // error returned instead of a response
}
313
314func (s *customPredictableService) Do(ctx context.Context, req *llm.Request) (*llm.Response, error) {
315 s.mu.Lock()
316 defer s.mu.Unlock()
317
318 if s.responseFunc != nil {
319 return s.responseFunc(req)
320 }
321
322 if s.callIndex >= len(s.responses) {
323 // Default response
324 return &llm.Response{
325 Role: llm.MessageRoleAssistant,
326 StopReason: llm.StopReasonEndTurn,
327 Content: []llm.Content{
328 {Type: llm.ContentTypeText, Text: "No more responses configured"},
329 },
330 }, nil
331 }
332
333 resp := s.responses[s.callIndex]
334 s.callIndex++
335 return resp.response, resp.err
336}
337
338func (s *customPredictableService) GetDefaultModel() string {
339 return "custom-test"
340}
341
342func (s *customPredictableService) TokenContextWindow() int {
343 return 100000
344}
345
346func (s *customPredictableService) MaxImageDimension() int {
347 return 8000
348}
349
350// TestNoInterruptionNormalFlow verifies that normal tool chains work correctly
351// when no interruption is queued.
352func TestNoInterruptionNormalFlow(t *testing.T) {
353 var toolCallCount atomic.Int32
354
355 // Create a tool that tracks calls
356 multiTool := &llm.Tool{
357 Name: "multi_tool",
358 Description: "A tool",
359 InputSchema: llm.MustSchema(`{"type": "object", "properties": {"step": {"type": "integer"}}}`),
360 Run: func(ctx context.Context, input json.RawMessage) llm.ToolOut {
361 toolCallCount.Add(1)
362 return llm.ToolOut{
363 LLMContent: []llm.Content{
364 {Type: llm.ContentTypeText, Text: "done"},
365 },
366 }
367 },
368 }
369
370 recordMessage := func(ctx context.Context, message llm.Message, usage llm.Usage) error {
371 return nil
372 }
373
374 // Service that makes 3 tool calls then finishes
375 service := &customPredictableService{
376 responseFunc: func(req *llm.Request) (*llm.Response, error) {
377 toolResults := 0
378 for _, msg := range req.Messages {
379 for _, c := range msg.Content {
380 if c.Type == llm.ContentTypeToolResult {
381 toolResults++
382 }
383 }
384 }
385
386 if toolResults < 3 {
387 return &llm.Response{
388 Role: llm.MessageRoleAssistant,
389 StopReason: llm.StopReasonToolUse,
390 Content: []llm.Content{
391 {Type: llm.ContentTypeText, Text: "Calling tool"},
392 {
393 Type: llm.ContentTypeToolUse,
394 ID: fmt.Sprintf("tool_%d", toolResults+1),
395 ToolName: "multi_tool",
396 ToolInput: json.RawMessage(fmt.Sprintf(`{"step":%d}`, toolResults+1)),
397 },
398 },
399 }, nil
400 }
401
402 return &llm.Response{
403 Role: llm.MessageRoleAssistant,
404 StopReason: llm.StopReasonEndTurn,
405 Content: []llm.Content{
406 {Type: llm.ContentTypeText, Text: "All done"},
407 },
408 }, nil
409 },
410 }
411
412 loop := NewLoop(Config{
413 LLM: service,
414 History: []llm.Message{},
415 Tools: []*llm.Tool{multiTool},
416 RecordMessage: recordMessage,
417 })
418
419 // Queue initial user message (no interruption)
420 loop.QueueUserMessage(llm.Message{
421 Role: llm.MessageRoleUser,
422 Content: []llm.Content{{Type: llm.ContentTypeText, Text: "run tools"}},
423 })
424
425 ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
426 defer cancel()
427
428 var loopDone sync.WaitGroup
429 loopDone.Add(1)
430 go func() {
431 defer loopDone.Done()
432 loop.Go(ctx)
433 }()
434
435 // Wait for completion
436 time.Sleep(500 * time.Millisecond)
437 cancel()
438 loopDone.Wait()
439
440 finalCount := toolCallCount.Load()
441 if finalCount != 3 {
442 t.Errorf("Expected 3 tool calls, got %d", finalCount)
443 } else {
444 t.Log("SUCCESS: Normal flow completed 3 tool calls as expected")
445 }
446}