diff --git a/loop/interruption_test.go b/loop/interruption_test.go
new file mode 100644
index 0000000000000000000000000000000000000000..93c3627ae0e49770cb2a62b0613395af90acf227
--- /dev/null
+++ b/loop/interruption_test.go
@@ -0,0 +1,446 @@
+package loop
+
+import (
+	"context"
+	"encoding/json"
+	"fmt"
+	"sync"
+	"sync/atomic"
+	"testing"
+	"time"
+
+	"shelley.exe.dev/llm"
+)
+
+// TestInterruptionDuringToolExecution tests that user messages queued during
+// tool execution are processed after the tool completes but before the next
+// tool starts (not at the end of the entire turn).
+func TestInterruptionDuringToolExecution(t *testing.T) {
+	// Track when the tool is called and when it completes
+	var toolStarted atomic.Bool
+	var toolCompleted atomic.Bool
+	var interruptionSeen atomic.Bool
+
+	// Create a slow tool
+	slowTool := &llm.Tool{
+		Name:        "slow_tool",
+		Description: "A tool that takes time to execute",
+		InputSchema: llm.MustSchema(`{"type": "object", "properties": {"input": {"type": "string"}}}`),
+		Run: func(ctx context.Context, input json.RawMessage) llm.ToolOut {
+			toolStarted.Store(true)
+			// Sleep to simulate slow tool execution
+			time.Sleep(200 * time.Millisecond)
+			toolCompleted.Store(true)
+			return llm.ToolOut{
+				LLMContent: []llm.Content{
+					{Type: llm.ContentTypeText, Text: "Tool completed"},
+				},
+			}
+		},
+	}
+
+	recordMessage := func(ctx context.Context, message llm.Message, usage llm.Usage) error {
+		return nil
+	}
+
+	// Create a service that detects the interruption
+	service := &customPredictableService{
+		responseFunc: func(req *llm.Request) (*llm.Response, error) {
+			// Check if we've seen the interruption
+			toolResults := 0
+			for _, msg := range req.Messages {
+				for _, c := range msg.Content {
+					if c.Type == llm.ContentTypeToolResult {
+						toolResults++
+					}
+					if c.Type == llm.ContentTypeText && c.Text == "INTERRUPTION" {
+						interruptionSeen.Store(true)
+						return &llm.Response{
+							Role:       llm.MessageRoleAssistant,
+							StopReason: llm.StopReasonEndTurn,
+							Content: []llm.Content{
+								{Type: llm.ContentTypeText, Text: "Acknowledged interruption"},
+							},
+						}, nil
+					}
+				}
+			}
+
+			// First call: use the slow tool
+			if toolResults == 0 {
+				return &llm.Response{
+					Role:       llm.MessageRoleAssistant,
+					StopReason: llm.StopReasonToolUse,
+					Content: []llm.Content{
+						{Type: llm.ContentTypeText, Text: "I'll use the slow tool"},
+						{
+							Type:      llm.ContentTypeToolUse,
+							ID:        "tool_1",
+							ToolName:  "slow_tool",
+							ToolInput: json.RawMessage(`{"input":"test"}`),
+						},
+					},
+				}, nil
+			}
+
+			// After tool result, continue with more work
+			return &llm.Response{
+				Role:       llm.MessageRoleAssistant,
+				StopReason: llm.StopReasonEndTurn,
+				Content: []llm.Content{
+					{Type: llm.ContentTypeText, Text: "Done with tool"},
+				},
+			}, nil
+		},
+	}
+
+	loop := NewLoop(Config{
+		LLM:           service,
+		History:       []llm.Message{},
+		Tools:         []*llm.Tool{slowTool},
+		RecordMessage: recordMessage,
+	})
+
+	// Queue initial user message that will trigger tool use
+	loop.QueueUserMessage(llm.Message{
+		Role:    llm.MessageRoleUser,
+		Content: []llm.Content{{Type: llm.ContentTypeText, Text: "use the tool"}},
+	})
+
+	// Run the loop in background
+	ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
+	defer cancel()
+
+	var loopDone sync.WaitGroup
+	loopDone.Add(1)
+	go func() {
+		defer loopDone.Done()
+		loop.Go(ctx)
+	}()
+
+	// Wait for tool to start
+	for !toolStarted.Load() {
+		time.Sleep(10 * time.Millisecond)
+	}
+
+	// Queue an interruption message while tool is executing
+	loop.QueueUserMessage(llm.Message{
+		Role:    llm.MessageRoleUser,
+		Content: []llm.Content{{Type: llm.ContentTypeText, Text: "INTERRUPTION"}},
+	})
+	t.Log("Queued interruption message while tool is executing")
+
+	// The message should remain in queue while tool is executing
+	time.Sleep(50 * time.Millisecond)
+	if !toolCompleted.Load() {
+		loop.mu.Lock()
+		queueLen := len(loop.messageQueue)
+		loop.mu.Unlock()
+		if queueLen > 0 {
+			t.Log("Message is waiting in queue during tool execution (expected)")
+		}
+	}
+
+	// Wait for loop to finish
+	time.Sleep(500 * time.Millisecond)
+	cancel()
+	loopDone.Wait()
+
+	// Verify the interruption was seen by the LLM
+	if interruptionSeen.Load() {
+		t.Log("SUCCESS: Interruption was seen by LLM after tool completed")
+	} else {
+		t.Error("Interruption was never seen by the LLM")
+	}
+}
+
+// TestInterruptionDuringMultiToolChain tests interruption during a chain of tool calls.
+// With the fix, the interruption should be visible to the LLM after the first tool completes.
+func TestInterruptionDuringMultiToolChain(t *testing.T) {
+	var toolCallCount atomic.Int32
+	var interruptionSeenAtToolResult atomic.Int32 // -1 means not seen
+
+	// Create a tool that's called multiple times
+	multiTool := &llm.Tool{
+		Name:        "multi_tool",
+		Description: "A tool that might be called multiple times",
+		InputSchema: llm.MustSchema(`{"type": "object", "properties": {"step": {"type": "integer"}}}`),
+		Run: func(ctx context.Context, input json.RawMessage) llm.ToolOut {
+			count := toolCallCount.Add(1)
+			time.Sleep(100 * time.Millisecond) // Simulate some work
+			_ = count
+			return llm.ToolOut{
+				LLMContent: []llm.Content{
+					{Type: llm.ContentTypeText, Text: "Tool step completed"},
+				},
+			}
+		},
+	}
+
+	recordMessage := func(ctx context.Context, message llm.Message, usage llm.Usage) error {
+		return nil
+	}
+
+	// Service that makes multiple tool calls but stops when it sees "STOP"
+	interruptionSeenAtToolResult.Store(-1)
+	service := &customPredictableService{
+		responseFunc: func(req *llm.Request) (*llm.Response, error) {
+			// Check if we've seen the STOP message
+			toolResults := 0
+			for _, msg := range req.Messages {
+				for _, c := range msg.Content {
+					if c.Type == llm.ContentTypeToolResult {
+						toolResults++
+					}
+					if c.Type == llm.ContentTypeText && c.Text == "STOP" {
+						// Record when we first saw the interruption
+						interruptionSeenAtToolResult.CompareAndSwap(-1, int32(toolResults))
+						// Stop immediately when we see the interruption
+						return &llm.Response{
+							Role:       llm.MessageRoleAssistant,
+							StopReason: llm.StopReasonEndTurn,
+							Content: []llm.Content{
+								{Type: llm.ContentTypeText, Text: "Stopped due to user interruption"},
+							},
+						}, nil
+					}
+				}
+			}
+
+			if toolResults < 5 {
+				// Keep calling the tool (would do 5 if not interrupted)
+				return &llm.Response{
+					Role:       llm.MessageRoleAssistant,
+					StopReason: llm.StopReasonToolUse,
+					Content: []llm.Content{
+						{Type: llm.ContentTypeText, Text: "Calling tool again"},
+						{
+							Type:      llm.ContentTypeToolUse,
+							ID:        fmt.Sprintf("tool_%d", toolResults+1),
+							ToolName:  "multi_tool",
+							ToolInput: json.RawMessage(fmt.Sprintf(`{"step":%d}`, toolResults+1)),
+						},
+					},
+				}, nil
+			}
+
+			// Done with tools
+			return &llm.Response{
+				Role:       llm.MessageRoleAssistant,
+				StopReason: llm.StopReasonEndTurn,
+				Content: []llm.Content{
+					{Type: llm.ContentTypeText, Text: "All tools completed"},
+				},
+			}, nil
+		},
+	}
+
+	loop := NewLoop(Config{
+		LLM:           service,
+		History:       []llm.Message{},
+		Tools:         []*llm.Tool{multiTool},
+		RecordMessage: recordMessage,
+	})
+
+	// Queue initial user message
+	loop.QueueUserMessage(llm.Message{
+		Role:    llm.MessageRoleUser,
+		Content: []llm.Content{{Type: llm.ContentTypeText, Text: "run the tool 5 times"}},
+	})
+
+	ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
+	defer cancel()
+
+	var loopDone sync.WaitGroup
+	loopDone.Add(1)
+	go func() {
+		defer loopDone.Done()
+		loop.Go(ctx)
+	}()
+
+	// Wait for first tool call to complete
+	for toolCallCount.Load() < 1 {
+		time.Sleep(10 * time.Millisecond)
+	}
+
+	// Queue interruption after first tool
+	loop.QueueUserMessage(llm.Message{
+		Role:    llm.MessageRoleUser,
+		Content: []llm.Content{{Type: llm.ContentTypeText, Text: "STOP"}},
+	})
+	t.Logf("Queued STOP message after tool call %d", toolCallCount.Load())
+
+	// Wait for loop to process and stop
+	time.Sleep(500 * time.Millisecond)
+
+	cancel()
+	loopDone.Wait()
+
+	finalToolCount := toolCallCount.Load()
+	seenAt := interruptionSeenAtToolResult.Load()
+
+	t.Logf("Final tool call count: %d (would be 5 without interruption)", finalToolCount)
+	t.Logf("Interruption was seen by LLM after tool result %d", seenAt)
+
+	// With the fix, the interruption should be seen after just 1 tool result
+	// (the tool that was running when we queued the STOP message)
+	if seenAt == 1 {
+		t.Log("SUCCESS: Interruption was processed immediately after first tool completed")
+	} else if seenAt > 1 {
+		t.Errorf("Interruption was delayed: seen after %d tool results, expected 1", seenAt)
+	} else if seenAt == -1 {
+		t.Error("Interruption was never seen by the LLM")
+	}
+
+	// The tool should only be called a small number of times since we interrupted
+	if finalToolCount > 2 {
+		t.Errorf("Too many tool calls (%d): interruption should have stopped the chain earlier", finalToolCount)
+	}
+}
+
+// customPredictableService allows custom response logic for testing
+type customPredictableService struct {
+	responses    []customResponse
+	responseFunc func(req *llm.Request) (*llm.Response, error)
+	callIndex    int
+	mu           sync.Mutex
+}
+
+type customResponse struct {
+	response *llm.Response
+	err      error
+}
+
+func (s *customPredictableService) Do(ctx context.Context, req *llm.Request) (*llm.Response, error) {
+	s.mu.Lock()
+	defer s.mu.Unlock()
+
+	if s.responseFunc != nil {
+		return s.responseFunc(req)
+	}
+
+	if s.callIndex >= len(s.responses) {
+		// Default response
+		return &llm.Response{
+			Role:       llm.MessageRoleAssistant,
+			StopReason: llm.StopReasonEndTurn,
+			Content: []llm.Content{
+				{Type: llm.ContentTypeText, Text: "No more responses configured"},
+			},
+		}, nil
+	}
+
+	resp := s.responses[s.callIndex]
+	s.callIndex++
+	return resp.response, resp.err
+}
+
+func (s *customPredictableService) GetDefaultModel() string {
+	return "custom-test"
+}
+
+func (s *customPredictableService) TokenContextWindow() int {
+	return 100000
+}
+
+func (s *customPredictableService) MaxImageDimension() int {
+	return 8000
+}
+
+// TestNoInterruptionNormalFlow verifies that normal tool chains work correctly
+// when no interruption is queued.
+func TestNoInterruptionNormalFlow(t *testing.T) {
+	var toolCallCount atomic.Int32
+
+	// Create a tool that tracks calls
+	multiTool := &llm.Tool{
+		Name:        "multi_tool",
+		Description: "A tool",
+		InputSchema: llm.MustSchema(`{"type": "object", "properties": {"step": {"type": "integer"}}}`),
+		Run: func(ctx context.Context, input json.RawMessage) llm.ToolOut {
+			toolCallCount.Add(1)
+			return llm.ToolOut{
+				LLMContent: []llm.Content{
+					{Type: llm.ContentTypeText, Text: "done"},
+				},
+			}
+		},
+	}
+
+	recordMessage := func(ctx context.Context, message llm.Message, usage llm.Usage) error {
+		return nil
+	}
+
+	// Service that makes 3 tool calls then finishes
+	service := &customPredictableService{
+		responseFunc: func(req *llm.Request) (*llm.Response, error) {
+			toolResults := 0
+			for _, msg := range req.Messages {
+				for _, c := range msg.Content {
+					if c.Type == llm.ContentTypeToolResult {
+						toolResults++
+					}
+				}
+			}
+
+			if toolResults < 3 {
+				return &llm.Response{
+					Role:       llm.MessageRoleAssistant,
+					StopReason: llm.StopReasonToolUse,
+					Content: []llm.Content{
+						{Type: llm.ContentTypeText, Text: "Calling tool"},
+						{
+							Type:      llm.ContentTypeToolUse,
+							ID:        fmt.Sprintf("tool_%d", toolResults+1),
+							ToolName:  "multi_tool",
+							ToolInput: json.RawMessage(fmt.Sprintf(`{"step":%d}`, toolResults+1)),
+						},
+					},
+				}, nil
+			}
+
+			return &llm.Response{
+				Role:       llm.MessageRoleAssistant,
+				StopReason: llm.StopReasonEndTurn,
+				Content: []llm.Content{
+					{Type: llm.ContentTypeText, Text: "All done"},
+				},
+			}, nil
+		},
+	}
+
+	loop := NewLoop(Config{
+		LLM:           service,
+		History:       []llm.Message{},
+		Tools:         []*llm.Tool{multiTool},
+		RecordMessage: recordMessage,
+	})
+
+	// Queue initial user message (no interruption)
+	loop.QueueUserMessage(llm.Message{
+		Role:    llm.MessageRoleUser,
+		Content: []llm.Content{{Type: llm.ContentTypeText, Text: "run tools"}},
+	})
+
+	ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
+	defer cancel()
+
+	var loopDone sync.WaitGroup
+	loopDone.Add(1)
+	go func() {
+		defer loopDone.Done()
+		loop.Go(ctx)
+	}()
+
+	// Wait for completion
+	time.Sleep(500 * time.Millisecond)
+	cancel()
+	loopDone.Wait()
+
+	finalCount := toolCallCount.Load()
+	if finalCount != 3 {
+		t.Errorf("Expected 3 tool calls, got %d", finalCount)
+	} else {
+		t.Log("SUCCESS: Normal flow completed 3 tool calls as expected")
+	}
+}
diff --git a/loop/loop.go b/loop/loop.go
index 246a96ba7bb4744b8e849a3ac4e41d208bee69b1..e115287a2b97683e179ceb6791a4303d1714228c 100644
--- a/loop/loop.go
+++ b/loop/loop.go
@@ -441,6 +441,15 @@ func (l *Loop) handleToolCalls(ctx context.Context, content []llm.Content) error
 
 	l.mu.Lock()
 	l.history = append(l.history, toolMessage)
+	// Check for queued user messages (interruptions) before continuing.
+	// This allows user messages to be processed as soon as possible.
+	if len(l.messageQueue) > 0 {
+		for _, msg := range l.messageQueue {
+			l.history = append(l.history, msg)
+		}
+		l.messageQueue = l.messageQueue[:0]
+		l.logger.Info("processing user interruption during tool execution")
+	}
 	l.mu.Unlock()
 
 	// Record tool result message