loop.go

  1package loop
  2
import (
	"context"
	"errors"
	"fmt"
	"io"
	"log/slog"
	"strings"
	"sync"
	"time"

	"shelley.exe.dev/claudetool"
	"shelley.exe.dev/gitstate"
	"shelley.exe.dev/llm"
)
 16
// MessageRecordFunc is called to record new messages to persistent storage.
// It receives the message together with the usage accrued while producing it.
// A non-nil error is logged by the Loop but does not abort the turn.
type MessageRecordFunc func(ctx context.Context, message llm.Message, usage llm.Usage) error

// GitStateChangeFunc is called when the git state changes at the end of a turn.
// This is used to record user-visible notifications about git changes.
type GitStateChangeFunc func(ctx context.Context, state *gitstate.GitState)
 23
// Config contains all configuration needed to create a Loop.
type Config struct {
	LLM              llm.Service         // service that executes requests; Go/ProcessOneTurn error if nil
	History          []llm.Message       // initial conversation history (not copied; caller should not mutate)
	Tools            []*llm.Tool         // tools offered to the LLM and executed on tool_use
	RecordMessage    MessageRecordFunc   // persists every new message (assistant, tool result, error)
	Logger           *slog.Logger        // optional; slog.Default() is used when nil
	System           []llm.SystemContent // system prompt content sent with every request
	WorkingDir       string              // working directory for tools
	OnGitStateChange GitStateChangeFunc  // optional end-of-turn notification; never called when nil
	// GetWorkingDir returns the current working directory for tools.
	// If set, this is called at end of turn to check for git state changes.
	// If nil, Config.WorkingDir is used as a static value.
	GetWorkingDir func() string
}
 39
// Loop manages a conversation turn with an LLM including tool execution and message recording.
// Notably, when the turn ends, the "Loop" is over. TODO: maybe rename to Turn?
type Loop struct {
	llm           llm.Service
	tools         []*llm.Tool
	recordMessage MessageRecordFunc
	// history, messageQueue, totalUsage, and lastGitState are shared with
	// callers on other goroutines (QueueUserMessage, GetUsage, GetHistory)
	// and are guarded by mu. The remaining fields are set once in NewLoop
	// and read without locking.
	history          []llm.Message
	messageQueue     []llm.Message
	totalUsage       llm.Usage
	mu               sync.Mutex
	logger           *slog.Logger
	system           []llm.SystemContent
	workingDir       string
	onGitStateChange GitStateChangeFunc
	getWorkingDir    func() string
	lastGitState     *gitstate.GitState
}
 57
 58// NewLoop creates a new Loop instance with the provided configuration
 59func NewLoop(config Config) *Loop {
 60	logger := config.Logger
 61	if logger == nil {
 62		logger = slog.Default()
 63	}
 64
 65	// Get initial git state
 66	workingDir := config.WorkingDir
 67	if config.GetWorkingDir != nil {
 68		workingDir = config.GetWorkingDir()
 69	}
 70	initialGitState := gitstate.GetGitState(workingDir)
 71
 72	return &Loop{
 73		llm:              config.LLM,
 74		history:          config.History,
 75		tools:            config.Tools,
 76		recordMessage:    config.RecordMessage,
 77		messageQueue:     make([]llm.Message, 0),
 78		logger:           logger,
 79		system:           config.System,
 80		workingDir:       config.WorkingDir,
 81		onGitStateChange: config.OnGitStateChange,
 82		getWorkingDir:    config.GetWorkingDir,
 83		lastGitState:     initialGitState,
 84	}
 85}
 86
 87// QueueUserMessage adds a user message to the queue to be processed
 88func (l *Loop) QueueUserMessage(message llm.Message) {
 89	l.mu.Lock()
 90	defer l.mu.Unlock()
 91	l.messageQueue = append(l.messageQueue, message)
 92	l.logger.Debug("queued user message", "content_count", len(message.Content))
 93}
 94
 95// GetUsage returns the total usage accumulated by this loop
 96func (l *Loop) GetUsage() llm.Usage {
 97	l.mu.Lock()
 98	defer l.mu.Unlock()
 99	return l.totalUsage
100}
101
102// GetHistory returns a copy of the current conversation history
103func (l *Loop) GetHistory() []llm.Message {
104	l.mu.Lock()
105	defer l.mu.Unlock()
106	// Deep copy the messages to prevent modifications
107	historyCopy := make([]llm.Message, len(l.history))
108	for i, msg := range l.history {
109		// Copy the message
110		historyCopy[i] = llm.Message{
111			Role:    msg.Role,
112			ToolUse: msg.ToolUse, // This is a pointer, but we won't modify it in tests
113			Content: make([]llm.Content, len(msg.Content)),
114		}
115		// Copy content slice
116		copy(historyCopy[i].Content, msg.Content)
117	}
118	return historyCopy
119}
120
121// Go runs the conversation loop until the context is canceled
122func (l *Loop) Go(ctx context.Context) error {
123	if l.llm == nil {
124		return fmt.Errorf("no LLM service configured")
125	}
126
127	l.logger.Info("starting conversation loop", "tools", len(l.tools))
128
129	for {
130		select {
131		case <-ctx.Done():
132			l.logger.Info("conversation loop canceled")
133			return ctx.Err()
134		default:
135		}
136
137		// Process any queued messages
138		l.mu.Lock()
139		hasQueuedMessages := len(l.messageQueue) > 0
140		if hasQueuedMessages {
141			// Add queued messages to history (they are already recorded to DB by ConversationManager)
142			for _, msg := range l.messageQueue {
143				l.history = append(l.history, msg)
144			}
145			l.messageQueue = l.messageQueue[:0] // Clear queue
146		}
147		l.mu.Unlock()
148
149		if hasQueuedMessages {
150			// Send request to LLM
151			l.logger.Debug("processing queued messages", "count", 1)
152			if err := l.processLLMRequest(ctx); err != nil {
153				l.logger.Error("failed to process LLM request", "error", err)
154				time.Sleep(time.Second) // Wait before retrying
155				continue
156			}
157			l.logger.Debug("finished processing queued messages")
158		} else {
159			// No queued messages, wait a bit
160			select {
161			case <-ctx.Done():
162				return ctx.Err()
163			case <-time.After(100 * time.Millisecond):
164				// Continue loop
165			}
166		}
167	}
168}
169
170// ProcessOneTurn processes queued messages through one complete turn (user message + assistant response)
171// It stops after the assistant responds, regardless of whether tools were called
172func (l *Loop) ProcessOneTurn(ctx context.Context) error {
173	if l.llm == nil {
174		return fmt.Errorf("no LLM service configured")
175	}
176
177	// Process any queued messages first
178	l.mu.Lock()
179	if len(l.messageQueue) > 0 {
180		// Add queued messages to history (they are already recorded to DB by ConversationManager)
181		for _, msg := range l.messageQueue {
182			l.history = append(l.history, msg)
183		}
184		l.messageQueue = nil
185	}
186	l.mu.Unlock()
187
188	// Process one LLM request and response
189	return l.processLLMRequest(ctx)
190}
191
// processLLMRequest sends a request to the LLM and handles the response.
// It snapshots the conversation state under lock, enables prompt caching,
// repairs missing/orphaned tool_results, retries transient failures, and
// then either recurses via handleToolCalls (tool_use stop) or ends the
// turn. Errors are also recorded as assistant messages so the UI can
// display them.
func (l *Loop) processLLMRequest(ctx context.Context) error {
	// Snapshot shared state under the lock; messages is a fresh slice so
	// later cache-flag edits don't touch l.history.
	l.mu.Lock()
	messages := append([]llm.Message(nil), l.history...)
	tools := l.tools
	system := l.system
	llmService := l.llm
	l.mu.Unlock()

	// Enable prompt caching: set cache flag on last tool and last user message content
	// See https://docs.anthropic.com/en/docs/build-with-claude/prompt-caching
	if len(tools) > 0 {
		// Make a copy of tools to avoid modifying the shared slice
		tools = append([]*llm.Tool(nil), tools...)
		// Copy the last tool and enable caching
		lastTool := *tools[len(tools)-1]
		lastTool.Cache = true
		tools[len(tools)-1] = &lastTool
	}

	// Set cache flag on the last content block of the last user message
	if len(messages) > 0 {
		for i := len(messages) - 1; i >= 0; i-- {
			if messages[i].Role == llm.MessageRoleUser && len(messages[i].Content) > 0 {
				// Deep copy the message to avoid modifying the shared history
				msg := messages[i]
				msg.Content = append([]llm.Content(nil), msg.Content...)
				msg.Content[len(msg.Content)-1].Cache = true
				messages[i] = msg
				break
			}
		}
	}

	req := &llm.Request{
		Messages: messages,
		Tools:    tools,
		System:   system,
	}

	// Insert missing tool results if the previous message had tool_use blocks
	// without corresponding tool_result blocks. This can happen when a request
	// is cancelled or fails after the LLM responds but before tools execute.
	l.insertMissingToolResults(req)

	// Sum system-prompt text length purely for the debug log below.
	systemLen := 0
	for _, sys := range system {
		systemLen += len(sys.Text)
	}
	l.logger.Debug("sending LLM request", "message_count", len(messages), "tool_count", len(tools), "system_items", len(system), "system_length", systemLen)

	// Add a timeout for the LLM request to prevent indefinite hangs
	llmCtx, cancel := context.WithTimeout(ctx, 5*time.Minute)
	defer cancel()

	// Retry LLM requests that fail with retryable errors (EOF, connection reset)
	const maxRetries = 2
	var resp *llm.Response
	var err error
	for attempt := 1; attempt <= maxRetries; attempt++ {
		resp, err = llmService.Do(llmCtx, req)
		if err == nil {
			break
		}
		if !isRetryableError(err) || attempt == maxRetries {
			break
		}
		l.logger.Warn("LLM request failed with retryable error, retrying",
			"error", err,
			"attempt", attempt,
			"max_retries", maxRetries)
		time.Sleep(time.Second * time.Duration(attempt)) // Simple backoff
	}
	if err != nil {
		// Record the error as a message so it can be displayed in the UI
		// EndOfTurn must be true so the agent working state is properly updated
		errorMessage := llm.Message{
			Role: llm.MessageRoleAssistant,
			Content: []llm.Content{
				{
					Type: llm.ContentTypeText,
					Text: fmt.Sprintf("LLM request failed: %v", err),
				},
			},
			EndOfTurn: true,
			ErrorType: llm.ErrorTypeLLMRequest,
		}
		if recordErr := l.recordMessage(ctx, errorMessage, llm.Usage{}); recordErr != nil {
			l.logger.Error("failed to record error message", "error", recordErr)
		}
		return fmt.Errorf("LLM request failed: %w", err)
	}

	l.logger.Debug("received LLM response", "content_count", len(resp.Content), "stop_reason", resp.StopReason.String(), "usage", resp.Usage.String())

	// Update total usage
	l.mu.Lock()
	l.totalUsage.Add(resp.Usage)
	l.mu.Unlock()

	// Handle max tokens truncation BEFORE adding to history - truncated responses
	// should not be added to history normally (they get special handling)
	if resp.StopReason == llm.StopReasonMaxTokens {
		l.logger.Warn("LLM response truncated due to max tokens")
		return l.handleMaxTokensTruncation(ctx, resp)
	}

	// Convert response to message and add to history
	assistantMessage := resp.ToMessage()
	l.mu.Lock()
	l.history = append(l.history, assistantMessage)
	l.mu.Unlock()

	// Record assistant message with model and timing metadata
	usageWithMeta := resp.Usage
	usageWithMeta.Model = resp.Model
	usageWithMeta.StartTime = resp.StartTime
	usageWithMeta.EndTime = resp.EndTime
	if err := l.recordMessage(ctx, assistantMessage, usageWithMeta); err != nil {
		// Recording failures are logged but do not abort the turn.
		l.logger.Error("failed to record assistant message", "error", err)
	}

	// Handle tool calls if any
	if resp.StopReason == llm.StopReasonToolUse {
		l.logger.Debug("handling tool calls", "content_count", len(resp.Content))
		return l.handleToolCalls(ctx, resp.Content)
	}

	// End of turn - check for git state changes
	l.checkGitStateChange(ctx)

	return nil
}
325
326// checkGitStateChange checks if the git state has changed and calls the callback if so.
327// This is called at the end of each turn.
328func (l *Loop) checkGitStateChange(ctx context.Context) {
329	if l.onGitStateChange == nil {
330		return
331	}
332
333	// Get current working directory
334	workingDir := l.workingDir
335	if l.getWorkingDir != nil {
336		workingDir = l.getWorkingDir()
337	}
338
339	// Get current git state
340	currentState := gitstate.GetGitState(workingDir)
341
342	// Compare with last known state
343	l.mu.Lock()
344	lastState := l.lastGitState
345	l.mu.Unlock()
346
347	// Check if state changed
348	if !currentState.Equal(lastState) {
349		l.mu.Lock()
350		l.lastGitState = currentState
351		l.mu.Unlock()
352
353		if currentState.IsRepo {
354			l.logger.Debug("git state changed",
355				"worktree", currentState.Worktree,
356				"branch", currentState.Branch,
357				"commit", currentState.Commit)
358			l.onGitStateChange(ctx, currentState)
359		}
360	}
361}
362
363// handleMaxTokensTruncation handles the case where the LLM response was truncated
364// due to hitting the maximum output token limit. It records the truncated message
365// for cost tracking (excluded from context) and an error message for the user.
366func (l *Loop) handleMaxTokensTruncation(ctx context.Context, resp *llm.Response) error {
367	// Record the truncated message for cost tracking, but mark it as excluded from context.
368	// This preserves billing information without confusing the LLM on future turns.
369	truncatedMessage := resp.ToMessage()
370	truncatedMessage.ExcludedFromContext = true
371
372	// Record the truncated message with usage metadata
373	usageWithMeta := resp.Usage
374	usageWithMeta.Model = resp.Model
375	usageWithMeta.StartTime = resp.StartTime
376	usageWithMeta.EndTime = resp.EndTime
377	if err := l.recordMessage(ctx, truncatedMessage, usageWithMeta); err != nil {
378		l.logger.Error("failed to record truncated message", "error", err)
379	}
380
381	// Record a truncation error message with EndOfTurn=true to properly signal end of turn.
382	errorMessage := llm.Message{
383		Role: llm.MessageRoleAssistant,
384		Content: []llm.Content{
385			{
386				Type: llm.ContentTypeText,
387				Text: "[SYSTEM ERROR: Your previous response was truncated because it exceeded the maximum output token limit. " +
388					"Any tool calls in that response were lost. Please retry with smaller, incremental changes. " +
389					"For file operations, break large changes into multiple smaller patches. " +
390					"The user can ask you to continue if needed.]",
391			},
392		},
393		EndOfTurn: true,
394		ErrorType: llm.ErrorTypeTruncation,
395	}
396
397	l.mu.Lock()
398	l.history = append(l.history, errorMessage)
399	l.mu.Unlock()
400
401	// Record the truncation error message
402	if err := l.recordMessage(ctx, errorMessage, llm.Usage{}); err != nil {
403		l.logger.Error("failed to record truncation error message", "error", err)
404	}
405
406	// End the turn - don't automatically continue
407	l.checkGitStateChange(ctx)
408	return nil
409}
410
411// handleToolCalls processes tool calls from the LLM response
412func (l *Loop) handleToolCalls(ctx context.Context, content []llm.Content) error {
413	var toolResults []llm.Content
414
415	for _, c := range content {
416		if c.Type != llm.ContentTypeToolUse {
417			continue
418		}
419
420		l.logger.Debug("executing tool", "name", c.ToolName, "id", c.ID)
421
422		// Find the tool
423		var tool *llm.Tool
424		for _, t := range l.tools {
425			if t.Name == c.ToolName {
426				tool = t
427				break
428			}
429		}
430
431		if tool == nil {
432			l.logger.Error("tool not found", "name", c.ToolName)
433			toolResults = append(toolResults, llm.Content{
434				Type:      llm.ContentTypeToolResult,
435				ToolUseID: c.ID,
436				ToolError: true,
437				ToolResult: []llm.Content{
438					{Type: llm.ContentTypeText, Text: fmt.Sprintf("Tool '%s' not found", c.ToolName)},
439				},
440			})
441			continue
442		}
443
444		// Execute the tool with working directory set in context
445		toolCtx := ctx
446		if l.workingDir != "" {
447			toolCtx = claudetool.WithWorkingDir(ctx, l.workingDir)
448		}
449		startTime := time.Now()
450		result := tool.Run(toolCtx, c.ToolInput)
451		endTime := time.Now()
452
453		var toolResultContent []llm.Content
454		if result.Error != nil {
455			l.logger.Error("tool execution failed", "name", c.ToolName, "error", result.Error)
456			toolResultContent = []llm.Content{
457				{Type: llm.ContentTypeText, Text: result.Error.Error()},
458			}
459		} else {
460			toolResultContent = result.LLMContent
461			l.logger.Debug("tool executed successfully", "name", c.ToolName, "duration", endTime.Sub(startTime))
462		}
463
464		toolResults = append(toolResults, llm.Content{
465			Type:             llm.ContentTypeToolResult,
466			ToolUseID:        c.ID,
467			ToolError:        result.Error != nil,
468			ToolResult:       toolResultContent,
469			ToolUseStartTime: &startTime,
470			ToolUseEndTime:   &endTime,
471			Display:          result.Display,
472		})
473	}
474
475	if len(toolResults) > 0 {
476		// Add tool results to history as a user message
477		toolMessage := llm.Message{
478			Role:    llm.MessageRoleUser,
479			Content: toolResults,
480		}
481
482		l.mu.Lock()
483		l.history = append(l.history, toolMessage)
484		// Check for queued user messages (interruptions) before continuing.
485		// This allows user messages to be processed as soon as possible.
486		if len(l.messageQueue) > 0 {
487			for _, msg := range l.messageQueue {
488				l.history = append(l.history, msg)
489			}
490			l.messageQueue = l.messageQueue[:0]
491			l.logger.Info("processing user interruption during tool execution")
492		}
493		l.mu.Unlock()
494
495		// Record tool result message
496		if err := l.recordMessage(ctx, toolMessage, llm.Usage{}); err != nil {
497			l.logger.Error("failed to record tool result message", "error", err)
498		}
499
500		// Process another LLM request with the tool results
501		return l.processLLMRequest(ctx)
502	}
503
504	return nil
505}
506
// insertMissingToolResults fixes tool_result issues in the conversation history:
//  1. Adds error results for tool_uses that were requested but not included in the next message.
//     This can happen when a request is cancelled or fails after the LLM responds with tool_use
//     blocks but before the tools execute.
//  2. Removes orphan tool_results that reference tool_use IDs not present in the immediately
//     preceding assistant message. This can happen when a tool execution completes after
//     CancelConversation has already written cancellation messages.
//
// This prevents API errors like:
//   - "tool_use ids were found without tool_result blocks"
//   - "unexpected tool_use_id found in tool_result blocks ... Each tool_result block must have
//     a corresponding tool_use block in the previous message"
//
// Mutates the request's Messages slice.
func (l *Loop) insertMissingToolResults(req *llm.Request) {
	if len(req.Messages) < 1 {
		return
	}

	// Scan through all messages looking for assistant messages with tool_use
	// that are not immediately followed by a user message with corresponding tool_results.
	// We may need to insert synthetic user messages with tool_results or filter orphans.
	var newMessages []llm.Message
	totalInserted := 0
	totalRemoved := 0

	// Track the tool_use IDs from the most recent assistant message
	var prevAssistantToolUseIDs map[string]bool

	for i := 0; i < len(req.Messages); i++ {
		msg := req.Messages[i]

		if msg.Role == llm.MessageRoleAssistant {
			// Handle empty assistant messages - add placeholder content if not the last message
			// The API requires all messages to have non-empty content except for the optional
			// final assistant message. Empty content can happen when the model ends its turn
			// without producing any output.
			if len(msg.Content) == 0 && i < len(req.Messages)-1 {
				req.Messages[i].Content = []llm.Content{{Type: llm.ContentTypeText, Text: "(no response)"}}
				msg = req.Messages[i] // update local copy for subsequent processing
				l.logger.Debug("added placeholder content to empty assistant message", "index", i)
			}

			// Track all tool_use IDs in this assistant message
			prevAssistantToolUseIDs = make(map[string]bool)
			for _, c := range msg.Content {
				if c.Type == llm.ContentTypeToolUse {
					prevAssistantToolUseIDs[c.ID] = true
				}
			}
			newMessages = append(newMessages, msg)

			// Check if next message needs synthetic tool_results
			var toolUseContents []llm.Content
			for _, c := range msg.Content {
				if c.Type == llm.ContentTypeToolUse {
					toolUseContents = append(toolUseContents, c)
				}
			}

			if len(toolUseContents) == 0 {
				continue
			}

			// Check if next message is a user message with corresponding tool_results
			var nextMsg *llm.Message
			if i+1 < len(req.Messages) {
				nextMsg = &req.Messages[i+1]
			}

			if nextMsg == nil || nextMsg.Role != llm.MessageRoleUser {
				// Next message is not a user message (or there is no next message).
				// Insert a synthetic user message with tool_results for all tool_uses.
				var toolResultContent []llm.Content
				for _, tu := range toolUseContents {
					toolResultContent = append(toolResultContent, llm.Content{
						Type:      llm.ContentTypeToolResult,
						ToolUseID: tu.ID,
						ToolError: true,
						ToolResult: []llm.Content{{
							Type: llm.ContentTypeText,
							Text: "not executed; retry possible",
						}},
					})
				}
				syntheticMsg := llm.Message{
					Role:    llm.MessageRoleUser,
					Content: toolResultContent,
				}
				newMessages = append(newMessages, syntheticMsg)
				totalInserted += len(toolResultContent)
			}
		} else if msg.Role == llm.MessageRoleUser {
			// Filter out orphan tool_results and add missing ones
			var filteredContent []llm.Content
			existingResultIDs := make(map[string]bool)

			for _, c := range msg.Content {
				if c.Type == llm.ContentTypeToolResult {
					// Only keep tool_results that match a tool_use in the previous assistant message
					if prevAssistantToolUseIDs != nil && prevAssistantToolUseIDs[c.ToolUseID] {
						filteredContent = append(filteredContent, c)
						existingResultIDs[c.ToolUseID] = true
					} else {
						// Orphan tool_result - skip it
						totalRemoved++
						l.logger.Debug("removing orphan tool_result", "tool_use_id", c.ToolUseID)
					}
				} else {
					// Keep non-tool_result content
					filteredContent = append(filteredContent, c)
				}
			}

			// Check if we need to add missing tool_results for this user message
			// NOTE(review): map iteration order is random, so when several
			// results are missing the synthetic prefix order is
			// nondeterministic across runs — confirm the API does not care
			// about tool_result ordering.
			if prevAssistantToolUseIDs != nil {
				var prefix []llm.Content
				for toolUseID := range prevAssistantToolUseIDs {
					if !existingResultIDs[toolUseID] {
						prefix = append(prefix, llm.Content{
							Type:      llm.ContentTypeToolResult,
							ToolUseID: toolUseID,
							ToolError: true,
							ToolResult: []llm.Content{{
								Type: llm.ContentTypeText,
								Text: "not executed; retry possible",
							}},
						})
						totalInserted++
					}
				}
				if len(prefix) > 0 {
					filteredContent = append(prefix, filteredContent...)
				}
			}

			// Only add the message if it has content
			if len(filteredContent) > 0 {
				msg.Content = filteredContent
				newMessages = append(newMessages, msg)
			} else {
				// Message is now empty after filtering - skip it entirely
				l.logger.Debug("removing empty user message after filtering orphan tool_results")
			}

			// Reset for next iteration - user message "consumes" the previous tool_uses
			prevAssistantToolUseIDs = nil
		} else {
			newMessages = append(newMessages, msg)
		}
	}

	// Only swap in the rebuilt slice when something actually changed.
	if totalInserted > 0 || totalRemoved > 0 {
		req.Messages = newMessages
		if totalInserted > 0 {
			l.logger.Debug("inserted missing tool results", "count", totalInserted)
		}
		if totalRemoved > 0 {
			l.logger.Debug("removed orphan tool results", "count", totalRemoved)
		}
	}
}
669
670// isRetryableError checks if an error is transient and should be retried.
671// This includes EOF errors (connection closed unexpectedly) and similar network issues.
672func isRetryableError(err error) bool {
673	if err == nil {
674		return false
675	}
676	// Check for io.EOF and io.ErrUnexpectedEOF
677	if err == io.EOF || err == io.ErrUnexpectedEOF {
678		return true
679	}
680	// Check error message for common retryable patterns
681	errStr := err.Error()
682	retryablePatterns := []string{
683		"EOF",
684		"connection reset",
685		"connection refused",
686		"no such host",
687		"network is unreachable",
688		"i/o timeout",
689	}
690	for _, pattern := range retryablePatterns {
691		if strings.Contains(errStr, pattern) {
692			return true
693		}
694	}
695	return false
696}