loop: process user messages immediately after tool completion

Philip Zeyliger and Shelley created this change

Prompt: Somewhere in history we fixed how Shelley deals with interruptions. They should enter the conversation as soon as possible. Find the old conversation and reimplement the same. (Or pull the change from git if that makes sense.)

Previously, user messages queued during tool execution would wait until the
entire turn completed (all tool calls finished). This could be frustrating
when an agent was in a long chain of tool calls and the user wanted to
interrupt or provide additional guidance.

Now, after each tool completes, the loop checks for queued user messages and
adds them to history before the next LLM request. This means:

- User messages are visible to the LLM after the current tool completes
- The LLM can choose to stop the chain or adjust based on user input
- Tools that were already invoked still complete (no mid-tool cancellation)
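
From the caller's side the API is unchanged; only the timing of queued
messages improves. A minimal usage sketch, assuming the loop API that the
tests below exercise (NewLoop, QueueUserMessage, Go); service, tools, and
recordMessage are placeholders for real wiring:

	l := NewLoop(Config{
		LLM:           service,
		History:       []llm.Message{},
		Tools:         tools,
		RecordMessage: recordMessage,
	})

	// Start a turn that will run a long chain of tool calls.
	l.QueueUserMessage(llm.Message{
		Role:    llm.MessageRoleUser,
		Content: []llm.Content{{Type: llm.ContentTypeText, Text: "start the long task"}},
	})
	go l.Go(ctx)

	// Queued while a tool is still running: this message is appended to
	// history as soon as the current tool finishes, so the very next LLM
	// request sees it instead of waiting for the whole turn to end.
	l.QueueUserMessage(llm.Message{
		Role:    llm.MessageRoleUser,
		Content: []llm.Content{{Type: llm.ContentTypeText, Text: "change of plans, stop here"}},
	})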

Added tests:
- TestInterruptionDuringToolExecution: verifies interruption is seen after tool
- TestInterruptionDuringMultiToolChain: verifies chain can be stopped early
- TestNoInterruptionNormalFlow: verifies normal behavior is unchanged

Co-authored-by: Shelley <shelley@exe.dev>

Change summary

loop/interruption_test.go | 446 +++++++++++++++++++++++++++++++++++++++++
loop/loop.go              |   9 +
2 files changed, 455 insertions(+)

Detailed changes

loop/interruption_test.go

@@ -0,0 +1,446 @@
+package loop
+
+import (
+	"context"
+	"encoding/json"
+	"fmt"
+	"sync"
+	"sync/atomic"
+	"testing"
+	"time"
+
+	"shelley.exe.dev/llm"
+)
+
+// TestInterruptionDuringToolExecution tests that user messages queued during
+// tool execution are processed after the tool completes but before the next
+// tool starts (not at the end of the entire turn).
+func TestInterruptionDuringToolExecution(t *testing.T) {
+	// Track when the tool is called and when it completes
+	var toolStarted atomic.Bool
+	var toolCompleted atomic.Bool
+	var interruptionSeen atomic.Bool
+
+	// Create a slow tool
+	slowTool := &llm.Tool{
+		Name:        "slow_tool",
+		Description: "A tool that takes time to execute",
+		InputSchema: llm.MustSchema(`{"type": "object", "properties": {"input": {"type": "string"}}}`),
+		Run: func(ctx context.Context, input json.RawMessage) llm.ToolOut {
+			toolStarted.Store(true)
+			// Sleep to simulate slow tool execution
+			time.Sleep(200 * time.Millisecond)
+			toolCompleted.Store(true)
+			return llm.ToolOut{
+				LLMContent: []llm.Content{
+					{Type: llm.ContentTypeText, Text: "Tool completed"},
+				},
+			}
+		},
+	}
+
+	recordMessage := func(ctx context.Context, message llm.Message, usage llm.Usage) error {
+		return nil
+	}
+
+	// Create a service that detects the interruption
+	service := &customPredictableService{
+		responseFunc: func(req *llm.Request) (*llm.Response, error) {
+			// Check if we've seen the interruption
+			toolResults := 0
+			for _, msg := range req.Messages {
+				for _, c := range msg.Content {
+					if c.Type == llm.ContentTypeToolResult {
+						toolResults++
+					}
+					if c.Type == llm.ContentTypeText && c.Text == "INTERRUPTION" {
+						interruptionSeen.Store(true)
+						return &llm.Response{
+							Role:       llm.MessageRoleAssistant,
+							StopReason: llm.StopReasonEndTurn,
+							Content: []llm.Content{
+								{Type: llm.ContentTypeText, Text: "Acknowledged interruption"},
+							},
+						}, nil
+					}
+				}
+			}
+
+			// First call: use the slow tool
+			if toolResults == 0 {
+				return &llm.Response{
+					Role:       llm.MessageRoleAssistant,
+					StopReason: llm.StopReasonToolUse,
+					Content: []llm.Content{
+						{Type: llm.ContentTypeText, Text: "I'll use the slow tool"},
+						{
+							Type:      llm.ContentTypeToolUse,
+							ID:        "tool_1",
+							ToolName:  "slow_tool",
+							ToolInput: json.RawMessage(`{"input":"test"}`),
+						},
+					},
+				}, nil
+			}
+
+			// After tool result, continue with more work
+			return &llm.Response{
+				Role:       llm.MessageRoleAssistant,
+				StopReason: llm.StopReasonEndTurn,
+				Content: []llm.Content{
+					{Type: llm.ContentTypeText, Text: "Done with tool"},
+				},
+			}, nil
+		},
+	}
+
+	loop := NewLoop(Config{
+		LLM:           service,
+		History:       []llm.Message{},
+		Tools:         []*llm.Tool{slowTool},
+		RecordMessage: recordMessage,
+	})
+
+	// Queue initial user message that will trigger tool use
+	loop.QueueUserMessage(llm.Message{
+		Role:    llm.MessageRoleUser,
+		Content: []llm.Content{{Type: llm.ContentTypeText, Text: "use the tool"}},
+	})
+
+	// Run the loop in background
+	ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
+	defer cancel()
+
+	var loopDone sync.WaitGroup
+	loopDone.Add(1)
+	go func() {
+		defer loopDone.Done()
+		loop.Go(ctx)
+	}()
+
+	// Wait for tool to start
+	for !toolStarted.Load() {
+		time.Sleep(10 * time.Millisecond)
+	}
+
+	// Queue an interruption message while tool is executing
+	loop.QueueUserMessage(llm.Message{
+		Role:    llm.MessageRoleUser,
+		Content: []llm.Content{{Type: llm.ContentTypeText, Text: "INTERRUPTION"}},
+	})
+	t.Log("Queued interruption message while tool is executing")
+
+	// The message should remain in the queue while the tool is executing
+	time.Sleep(50 * time.Millisecond)
+	if !toolCompleted.Load() {
+		loop.mu.Lock()
+		queueLen := len(loop.messageQueue)
+		loop.mu.Unlock()
+		if queueLen > 0 {
+			t.Log("Message is waiting in queue during tool execution (expected)")
+		}
+	}
+
+	// Wait for loop to finish
+	time.Sleep(500 * time.Millisecond)
+	cancel()
+	loopDone.Wait()
+
+	// Verify the interruption was seen by the LLM
+	if interruptionSeen.Load() {
+		t.Log("SUCCESS: Interruption was seen by LLM after tool completed")
+	} else {
+		t.Error("Interruption was never seen by the LLM")
+	}
+}
+
+// TestInterruptionDuringMultiToolChain tests interruption during a chain of tool calls.
+// With the fix, the interruption should be visible to the LLM after the first tool completes.
+func TestInterruptionDuringMultiToolChain(t *testing.T) {
+	var toolCallCount atomic.Int32
+	var interruptionSeenAtToolResult atomic.Int32 // -1 means not seen
+
+	// Create a tool that's called multiple times
+	multiTool := &llm.Tool{
+		Name:        "multi_tool",
+		Description: "A tool that might be called multiple times",
+		InputSchema: llm.MustSchema(`{"type": "object", "properties": {"step": {"type": "integer"}}}`),
+		Run: func(ctx context.Context, input json.RawMessage) llm.ToolOut {
+			toolCallCount.Add(1)
+			// Simulate some work before returning.
+			time.Sleep(100 * time.Millisecond)
+			return llm.ToolOut{
+				LLMContent: []llm.Content{
+					{Type: llm.ContentTypeText, Text: "Tool step completed"},
+				},
+			}
+		},
+	}
+
+	recordMessage := func(ctx context.Context, message llm.Message, usage llm.Usage) error {
+		return nil
+	}
+
+	// Service that makes multiple tool calls but stops when it sees "STOP"
+	interruptionSeenAtToolResult.Store(-1)
+	service := &customPredictableService{
+		responseFunc: func(req *llm.Request) (*llm.Response, error) {
+			// Check if we've seen the STOP message
+			toolResults := 0
+			for _, msg := range req.Messages {
+				for _, c := range msg.Content {
+					if c.Type == llm.ContentTypeToolResult {
+						toolResults++
+					}
+					if c.Type == llm.ContentTypeText && c.Text == "STOP" {
+						// Record when we first saw the interruption
+						interruptionSeenAtToolResult.CompareAndSwap(-1, int32(toolResults))
+						// Stop immediately when we see the interruption
+						return &llm.Response{
+							Role:       llm.MessageRoleAssistant,
+							StopReason: llm.StopReasonEndTurn,
+							Content: []llm.Content{
+								{Type: llm.ContentTypeText, Text: "Stopped due to user interruption"},
+							},
+						}, nil
+					}
+				}
+			}
+
+			if toolResults < 5 {
+				// Keep calling the tool (would do 5 if not interrupted)
+				return &llm.Response{
+					Role:       llm.MessageRoleAssistant,
+					StopReason: llm.StopReasonToolUse,
+					Content: []llm.Content{
+						{Type: llm.ContentTypeText, Text: "Calling tool again"},
+						{
+							Type:      llm.ContentTypeToolUse,
+							ID:        fmt.Sprintf("tool_%d", toolResults+1),
+							ToolName:  "multi_tool",
+							ToolInput: json.RawMessage(fmt.Sprintf(`{"step":%d}`, toolResults+1)),
+						},
+					},
+				}, nil
+			}
+
+			// Done with tools
+			return &llm.Response{
+				Role:       llm.MessageRoleAssistant,
+				StopReason: llm.StopReasonEndTurn,
+				Content: []llm.Content{
+					{Type: llm.ContentTypeText, Text: "All tools completed"},
+				},
+			}, nil
+		},
+	}
+
+	loop := NewLoop(Config{
+		LLM:           service,
+		History:       []llm.Message{},
+		Tools:         []*llm.Tool{multiTool},
+		RecordMessage: recordMessage,
+	})
+
+	// Queue initial user message
+	loop.QueueUserMessage(llm.Message{
+		Role:    llm.MessageRoleUser,
+		Content: []llm.Content{{Type: llm.ContentTypeText, Text: "run the tool 5 times"}},
+	})
+
+	ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
+	defer cancel()
+
+	var loopDone sync.WaitGroup
+	loopDone.Add(1)
+	go func() {
+		defer loopDone.Done()
+		loop.Go(ctx)
+	}()
+
+	// Wait for first tool call to complete
+	for toolCallCount.Load() < 1 {
+		time.Sleep(10 * time.Millisecond)
+	}
+
+	// Queue interruption after first tool
+	loop.QueueUserMessage(llm.Message{
+		Role:    llm.MessageRoleUser,
+		Content: []llm.Content{{Type: llm.ContentTypeText, Text: "STOP"}},
+	})
+	t.Logf("Queued STOP message after tool call %d", toolCallCount.Load())
+
+	// Wait for loop to process and stop
+	time.Sleep(500 * time.Millisecond)
+
+	cancel()
+	loopDone.Wait()
+
+	finalToolCount := toolCallCount.Load()
+	seenAt := interruptionSeenAtToolResult.Load()
+
+	t.Logf("Final tool call count: %d (would be 5 without interruption)", finalToolCount)
+	t.Logf("Interruption was seen by LLM after tool result %d", seenAt)
+
+	// With the fix, the interruption should be seen after just 1 tool result
+	// (the tool that was running when we queued the STOP message)
+	if seenAt == 1 {
+		t.Log("SUCCESS: Interruption was processed immediately after first tool completed")
+	} else if seenAt > 1 {
+		t.Errorf("Interruption was delayed: seen after %d tool results, expected 1", seenAt)
+	} else if seenAt == -1 {
+		t.Error("Interruption was never seen by the LLM")
+	}
+
+	// The tool should only be called a small number of times since we interrupted
+	if finalToolCount > 2 {
+		t.Errorf("Too many tool calls (%d): interruption should have stopped the chain earlier", finalToolCount)
+	}
+}
+
+// customPredictableService allows custom response logic for testing
+type customPredictableService struct {
+	responses    []customResponse
+	responseFunc func(req *llm.Request) (*llm.Response, error)
+	callIndex    int
+	mu           sync.Mutex
+}
+
+type customResponse struct {
+	response *llm.Response
+	err      error
+}
+
+func (s *customPredictableService) Do(ctx context.Context, req *llm.Request) (*llm.Response, error) {
+	s.mu.Lock()
+	defer s.mu.Unlock()
+
+	if s.responseFunc != nil {
+		return s.responseFunc(req)
+	}
+
+	if s.callIndex >= len(s.responses) {
+		// Default response
+		return &llm.Response{
+			Role:       llm.MessageRoleAssistant,
+			StopReason: llm.StopReasonEndTurn,
+			Content: []llm.Content{
+				{Type: llm.ContentTypeText, Text: "No more responses configured"},
+			},
+		}, nil
+	}
+
+	resp := s.responses[s.callIndex]
+	s.callIndex++
+	return resp.response, resp.err
+}
+
+func (s *customPredictableService) GetDefaultModel() string {
+	return "custom-test"
+}
+
+func (s *customPredictableService) TokenContextWindow() int {
+	return 100000
+}
+
+func (s *customPredictableService) MaxImageDimension() int {
+	return 8000
+}
+
+// TestNoInterruptionNormalFlow verifies that normal tool chains work correctly
+// when no interruption is queued.
+func TestNoInterruptionNormalFlow(t *testing.T) {
+	var toolCallCount atomic.Int32
+
+	// Create a tool that tracks calls
+	multiTool := &llm.Tool{
+		Name:        "multi_tool",
+		Description: "A tool",
+		InputSchema: llm.MustSchema(`{"type": "object", "properties": {"step": {"type": "integer"}}}`),
+		Run: func(ctx context.Context, input json.RawMessage) llm.ToolOut {
+			toolCallCount.Add(1)
+			return llm.ToolOut{
+				LLMContent: []llm.Content{
+					{Type: llm.ContentTypeText, Text: "done"},
+				},
+			}
+		},
+	}
+
+	recordMessage := func(ctx context.Context, message llm.Message, usage llm.Usage) error {
+		return nil
+	}
+
+	// Service that makes 3 tool calls then finishes
+	service := &customPredictableService{
+		responseFunc: func(req *llm.Request) (*llm.Response, error) {
+			toolResults := 0
+			for _, msg := range req.Messages {
+				for _, c := range msg.Content {
+					if c.Type == llm.ContentTypeToolResult {
+						toolResults++
+					}
+				}
+			}
+
+			if toolResults < 3 {
+				return &llm.Response{
+					Role:       llm.MessageRoleAssistant,
+					StopReason: llm.StopReasonToolUse,
+					Content: []llm.Content{
+						{Type: llm.ContentTypeText, Text: "Calling tool"},
+						{
+							Type:      llm.ContentTypeToolUse,
+							ID:        fmt.Sprintf("tool_%d", toolResults+1),
+							ToolName:  "multi_tool",
+							ToolInput: json.RawMessage(fmt.Sprintf(`{"step":%d}`, toolResults+1)),
+						},
+					},
+				}, nil
+			}
+
+			return &llm.Response{
+				Role:       llm.MessageRoleAssistant,
+				StopReason: llm.StopReasonEndTurn,
+				Content: []llm.Content{
+					{Type: llm.ContentTypeText, Text: "All done"},
+				},
+			}, nil
+		},
+	}
+
+	loop := NewLoop(Config{
+		LLM:           service,
+		History:       []llm.Message{},
+		Tools:         []*llm.Tool{multiTool},
+		RecordMessage: recordMessage,
+	})
+
+	// Queue initial user message (no interruption)
+	loop.QueueUserMessage(llm.Message{
+		Role:    llm.MessageRoleUser,
+		Content: []llm.Content{{Type: llm.ContentTypeText, Text: "run tools"}},
+	})
+
+	ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
+	defer cancel()
+
+	var loopDone sync.WaitGroup
+	loopDone.Add(1)
+	go func() {
+		defer loopDone.Done()
+		loop.Go(ctx)
+	}()
+
+	// Wait for completion
+	time.Sleep(500 * time.Millisecond)
+	cancel()
+	loopDone.Wait()
+
+	finalCount := toolCallCount.Load()
+	if finalCount != 3 {
+		t.Errorf("Expected 3 tool calls, got %d", finalCount)
+	} else {
+		t.Log("SUCCESS: Normal flow completed 3 tool calls as expected")
+	}
+}

loop/loop.go

@@ -441,6 +441,15 @@ func (l *Loop) handleToolCalls(ctx context.Context, content []llm.Content) error
 
 		l.mu.Lock()
 		l.history = append(l.history, toolMessage)
+		// Check for queued user messages (interruptions) before continuing.
+		// This allows user messages to be processed as soon as possible.
+		if len(l.messageQueue) > 0 {
+			for _, msg := range l.messageQueue {
+				l.history = append(l.history, msg)
+			}
+		l.messageQueue = l.messageQueue[:0] // drained: these messages now live in history
+			l.logger.Info("processing user interruption during tool execution")
+		}
 		l.mu.Unlock()
 
 		// Record tool result message