interruption_test.go

  1package loop
  2
  3import (
  4	"context"
  5	"encoding/json"
  6	"fmt"
  7	"sync"
  8	"sync/atomic"
  9	"testing"
 10	"time"
 11
 12	"shelley.exe.dev/llm"
 13)
 14
 15// TestInterruptionDuringToolExecution tests that user messages queued during
 16// tool execution are processed after the tool completes but before the next
 17// tool starts (not at the end of the entire turn).
 18func TestInterruptionDuringToolExecution(t *testing.T) {
 19	// Track when the tool is called and when it completes
 20	var toolStarted atomic.Bool
 21	var toolCompleted atomic.Bool
 22	var interruptionSeen atomic.Bool
 23
 24	// Create a slow tool
 25	slowTool := &llm.Tool{
 26		Name:        "slow_tool",
 27		Description: "A tool that takes time to execute",
 28		InputSchema: llm.MustSchema(`{"type": "object", "properties": {"input": {"type": "string"}}}`),
 29		Run: func(ctx context.Context, input json.RawMessage) llm.ToolOut {
 30			toolStarted.Store(true)
 31			// Sleep to simulate slow tool execution
 32			time.Sleep(200 * time.Millisecond)
 33			toolCompleted.Store(true)
 34			return llm.ToolOut{
 35				LLMContent: []llm.Content{
 36					{Type: llm.ContentTypeText, Text: "Tool completed"},
 37				},
 38			}
 39		},
 40	}
 41
 42	recordMessage := func(ctx context.Context, message llm.Message, usage llm.Usage) error {
 43		return nil
 44	}
 45
 46	// Create a service that detects the interruption
 47	service := &customPredictableService{
 48		responseFunc: func(req *llm.Request) (*llm.Response, error) {
 49			// Check if we've seen the interruption
 50			toolResults := 0
 51			for _, msg := range req.Messages {
 52				for _, c := range msg.Content {
 53					if c.Type == llm.ContentTypeToolResult {
 54						toolResults++
 55					}
 56					if c.Type == llm.ContentTypeText && c.Text == "INTERRUPTION" {
 57						interruptionSeen.Store(true)
 58						return &llm.Response{
 59							Role:       llm.MessageRoleAssistant,
 60							StopReason: llm.StopReasonEndTurn,
 61							Content: []llm.Content{
 62								{Type: llm.ContentTypeText, Text: "Acknowledged interruption"},
 63							},
 64						}, nil
 65					}
 66				}
 67			}
 68
 69			// First call: use the slow tool
 70			if toolResults == 0 {
 71				return &llm.Response{
 72					Role:       llm.MessageRoleAssistant,
 73					StopReason: llm.StopReasonToolUse,
 74					Content: []llm.Content{
 75						{Type: llm.ContentTypeText, Text: "I'll use the slow tool"},
 76						{
 77							Type:      llm.ContentTypeToolUse,
 78							ID:        "tool_1",
 79							ToolName:  "slow_tool",
 80							ToolInput: json.RawMessage(`{"input":"test"}`),
 81						},
 82					},
 83				}, nil
 84			}
 85
 86			// After tool result, continue with more work
 87			return &llm.Response{
 88				Role:       llm.MessageRoleAssistant,
 89				StopReason: llm.StopReasonEndTurn,
 90				Content: []llm.Content{
 91					{Type: llm.ContentTypeText, Text: "Done with tool"},
 92				},
 93			}, nil
 94		},
 95	}
 96
 97	loop := NewLoop(Config{
 98		LLM:           service,
 99		History:       []llm.Message{},
100		Tools:         []*llm.Tool{slowTool},
101		RecordMessage: recordMessage,
102	})
103
104	// Queue initial user message that will trigger tool use
105	loop.QueueUserMessage(llm.Message{
106		Role:    llm.MessageRoleUser,
107		Content: []llm.Content{{Type: llm.ContentTypeText, Text: "use the tool"}},
108	})
109
110	// Run the loop in background
111	ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
112	defer cancel()
113
114	var loopDone sync.WaitGroup
115	loopDone.Add(1)
116	go func() {
117		defer loopDone.Done()
118		loop.Go(ctx)
119	}()
120
121	// Wait for tool to start
122	for !toolStarted.Load() {
123		time.Sleep(10 * time.Millisecond)
124	}
125
126	// Queue an interruption message while tool is executing
127	loop.QueueUserMessage(llm.Message{
128		Role:    llm.MessageRoleUser,
129		Content: []llm.Content{{Type: llm.ContentTypeText, Text: "INTERRUPTION"}},
130	})
131	t.Log("Queued interruption message while tool is executing")
132
133	// The message should remain in queue while tool is executing
134	time.Sleep(50 * time.Millisecond)
135	if !toolCompleted.Load() {
136		loop.mu.Lock()
137		queueLen := len(loop.messageQueue)
138		loop.mu.Unlock()
139		if queueLen > 0 {
140			t.Log("Message is waiting in queue during tool execution (expected)")
141		}
142	}
143
144	// Wait for loop to finish
145	time.Sleep(500 * time.Millisecond)
146	cancel()
147	loopDone.Wait()
148
149	// Verify the interruption was seen by the LLM
150	if interruptionSeen.Load() {
151		t.Log("SUCCESS: Interruption was seen by LLM after tool completed")
152	} else {
153		t.Error("Interruption was never seen by the LLM")
154	}
155}
156
157// TestInterruptionDuringMultiToolChain tests interruption during a chain of tool calls.
158// With the fix, the interruption should be visible to the LLM after the first tool completes.
159func TestInterruptionDuringMultiToolChain(t *testing.T) {
160	var toolCallCount atomic.Int32
161	var interruptionSeenAtToolResult atomic.Int32 // -1 means not seen
162
163	// Create a tool that's called multiple times
164	multiTool := &llm.Tool{
165		Name:        "multi_tool",
166		Description: "A tool that might be called multiple times",
167		InputSchema: llm.MustSchema(`{"type": "object", "properties": {"step": {"type": "integer"}}}`),
168		Run: func(ctx context.Context, input json.RawMessage) llm.ToolOut {
169			count := toolCallCount.Add(1)
170			time.Sleep(100 * time.Millisecond) // Simulate some work
171			_ = count
172			return llm.ToolOut{
173				LLMContent: []llm.Content{
174					{Type: llm.ContentTypeText, Text: "Tool step completed"},
175				},
176			}
177		},
178	}
179
180	recordMessage := func(ctx context.Context, message llm.Message, usage llm.Usage) error {
181		return nil
182	}
183
184	// Service that makes multiple tool calls but stops when it sees "STOP"
185	interruptionSeenAtToolResult.Store(-1)
186	service := &customPredictableService{
187		responseFunc: func(req *llm.Request) (*llm.Response, error) {
188			// Check if we've seen the STOP message
189			toolResults := 0
190			for _, msg := range req.Messages {
191				for _, c := range msg.Content {
192					if c.Type == llm.ContentTypeToolResult {
193						toolResults++
194					}
195					if c.Type == llm.ContentTypeText && c.Text == "STOP" {
196						// Record when we first saw the interruption
197						interruptionSeenAtToolResult.CompareAndSwap(-1, int32(toolResults))
198						// Stop immediately when we see the interruption
199						return &llm.Response{
200							Role:       llm.MessageRoleAssistant,
201							StopReason: llm.StopReasonEndTurn,
202							Content: []llm.Content{
203								{Type: llm.ContentTypeText, Text: "Stopped due to user interruption"},
204							},
205						}, nil
206					}
207				}
208			}
209
210			if toolResults < 5 {
211				// Keep calling the tool (would do 5 if not interrupted)
212				return &llm.Response{
213					Role:       llm.MessageRoleAssistant,
214					StopReason: llm.StopReasonToolUse,
215					Content: []llm.Content{
216						{Type: llm.ContentTypeText, Text: "Calling tool again"},
217						{
218							Type:      llm.ContentTypeToolUse,
219							ID:        fmt.Sprintf("tool_%d", toolResults+1),
220							ToolName:  "multi_tool",
221							ToolInput: json.RawMessage(fmt.Sprintf(`{"step":%d}`, toolResults+1)),
222						},
223					},
224				}, nil
225			}
226
227			// Done with tools
228			return &llm.Response{
229				Role:       llm.MessageRoleAssistant,
230				StopReason: llm.StopReasonEndTurn,
231				Content: []llm.Content{
232					{Type: llm.ContentTypeText, Text: "All tools completed"},
233				},
234			}, nil
235		},
236	}
237
238	loop := NewLoop(Config{
239		LLM:           service,
240		History:       []llm.Message{},
241		Tools:         []*llm.Tool{multiTool},
242		RecordMessage: recordMessage,
243	})
244
245	// Queue initial user message
246	loop.QueueUserMessage(llm.Message{
247		Role:    llm.MessageRoleUser,
248		Content: []llm.Content{{Type: llm.ContentTypeText, Text: "run the tool 5 times"}},
249	})
250
251	ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
252	defer cancel()
253
254	var loopDone sync.WaitGroup
255	loopDone.Add(1)
256	go func() {
257		defer loopDone.Done()
258		loop.Go(ctx)
259	}()
260
261	// Wait for first tool call to complete
262	for toolCallCount.Load() < 1 {
263		time.Sleep(10 * time.Millisecond)
264	}
265
266	// Queue interruption after first tool
267	loop.QueueUserMessage(llm.Message{
268		Role:    llm.MessageRoleUser,
269		Content: []llm.Content{{Type: llm.ContentTypeText, Text: "STOP"}},
270	})
271	t.Logf("Queued STOP message after tool call %d", toolCallCount.Load())
272
273	// Wait for loop to process and stop
274	time.Sleep(500 * time.Millisecond)
275
276	cancel()
277	loopDone.Wait()
278
279	finalToolCount := toolCallCount.Load()
280	seenAt := interruptionSeenAtToolResult.Load()
281
282	t.Logf("Final tool call count: %d (would be 5 without interruption)", finalToolCount)
283	t.Logf("Interruption was seen by LLM after tool result %d", seenAt)
284
285	// With the fix, the interruption should be seen after just 1 tool result
286	// (the tool that was running when we queued the STOP message)
287	if seenAt == 1 {
288		t.Log("SUCCESS: Interruption was processed immediately after first tool completed")
289	} else if seenAt > 1 {
290		t.Errorf("Interruption was delayed: seen after %d tool results, expected 1", seenAt)
291	} else if seenAt == -1 {
292		t.Error("Interruption was never seen by the LLM")
293	}
294
295	// The tool should only be called a small number of times since we interrupted
296	if finalToolCount > 2 {
297		t.Errorf("Too many tool calls (%d): interruption should have stopped the chain earlier", finalToolCount)
298	}
299}
300
// customPredictableService allows custom response logic for testing.
// When responseFunc is non-nil it handles every request; otherwise the
// canned responses are returned in order, indexed by callIndex.
type customPredictableService struct {
	responses    []customResponse                              // canned responses, consumed in order
	responseFunc func(req *llm.Request) (*llm.Response, error) // optional override; takes precedence over responses
	callIndex    int                                           // next index into responses; guarded by mu
	mu           sync.Mutex                                    // serializes Do calls
}

// customResponse pairs a canned LLM response with an optional error to return.
type customResponse struct {
	response *llm.Response
	err      error
}
313
314func (s *customPredictableService) Do(ctx context.Context, req *llm.Request) (*llm.Response, error) {
315	s.mu.Lock()
316	defer s.mu.Unlock()
317
318	if s.responseFunc != nil {
319		return s.responseFunc(req)
320	}
321
322	if s.callIndex >= len(s.responses) {
323		// Default response
324		return &llm.Response{
325			Role:       llm.MessageRoleAssistant,
326			StopReason: llm.StopReasonEndTurn,
327			Content: []llm.Content{
328				{Type: llm.ContentTypeText, Text: "No more responses configured"},
329			},
330		}, nil
331	}
332
333	resp := s.responses[s.callIndex]
334	s.callIndex++
335	return resp.response, resp.err
336}
337
// GetDefaultModel returns a fixed model identifier for tests.
func (s *customPredictableService) GetDefaultModel() string {
	return "custom-test"
}
341
// TokenContextWindow returns a fixed context-window size for tests.
func (s *customPredictableService) TokenContextWindow() int {
	return 100000
}
345
// MaxImageDimension returns a fixed maximum image dimension for tests.
func (s *customPredictableService) MaxImageDimension() int {
	return 8000
}
349
350// TestNoInterruptionNormalFlow verifies that normal tool chains work correctly
351// when no interruption is queued.
352func TestNoInterruptionNormalFlow(t *testing.T) {
353	var toolCallCount atomic.Int32
354
355	// Create a tool that tracks calls
356	multiTool := &llm.Tool{
357		Name:        "multi_tool",
358		Description: "A tool",
359		InputSchema: llm.MustSchema(`{"type": "object", "properties": {"step": {"type": "integer"}}}`),
360		Run: func(ctx context.Context, input json.RawMessage) llm.ToolOut {
361			toolCallCount.Add(1)
362			return llm.ToolOut{
363				LLMContent: []llm.Content{
364					{Type: llm.ContentTypeText, Text: "done"},
365				},
366			}
367		},
368	}
369
370	recordMessage := func(ctx context.Context, message llm.Message, usage llm.Usage) error {
371		return nil
372	}
373
374	// Service that makes 3 tool calls then finishes
375	service := &customPredictableService{
376		responseFunc: func(req *llm.Request) (*llm.Response, error) {
377			toolResults := 0
378			for _, msg := range req.Messages {
379				for _, c := range msg.Content {
380					if c.Type == llm.ContentTypeToolResult {
381						toolResults++
382					}
383				}
384			}
385
386			if toolResults < 3 {
387				return &llm.Response{
388					Role:       llm.MessageRoleAssistant,
389					StopReason: llm.StopReasonToolUse,
390					Content: []llm.Content{
391						{Type: llm.ContentTypeText, Text: "Calling tool"},
392						{
393							Type:      llm.ContentTypeToolUse,
394							ID:        fmt.Sprintf("tool_%d", toolResults+1),
395							ToolName:  "multi_tool",
396							ToolInput: json.RawMessage(fmt.Sprintf(`{"step":%d}`, toolResults+1)),
397						},
398					},
399				}, nil
400			}
401
402			return &llm.Response{
403				Role:       llm.MessageRoleAssistant,
404				StopReason: llm.StopReasonEndTurn,
405				Content: []llm.Content{
406					{Type: llm.ContentTypeText, Text: "All done"},
407				},
408			}, nil
409		},
410	}
411
412	loop := NewLoop(Config{
413		LLM:           service,
414		History:       []llm.Message{},
415		Tools:         []*llm.Tool{multiTool},
416		RecordMessage: recordMessage,
417	})
418
419	// Queue initial user message (no interruption)
420	loop.QueueUserMessage(llm.Message{
421		Role:    llm.MessageRoleUser,
422		Content: []llm.Content{{Type: llm.ContentTypeText, Text: "run tools"}},
423	})
424
425	ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
426	defer cancel()
427
428	var loopDone sync.WaitGroup
429	loopDone.Add(1)
430	go func() {
431		defer loopDone.Done()
432		loop.Go(ctx)
433	}()
434
435	// Wait for completion
436	time.Sleep(500 * time.Millisecond)
437	cancel()
438	loopDone.Wait()
439
440	finalCount := toolCallCount.Load()
441	if finalCount != 3 {
442		t.Errorf("Expected 3 tool calls, got %d", finalCount)
443	} else {
444		t.Log("SUCCESS: Normal flow completed 3 tool calls as expected")
445	}
446}