loop.go

  1package loop
  2
  3import (
  4	"context"
  5	"fmt"
  6	"log/slog"
  7	"sync"
  8	"time"
  9
 10	"shelley.exe.dev/claudetool"
 11	"shelley.exe.dev/gitstate"
 12	"shelley.exe.dev/llm"
 13)
 14
// MessageRecordFunc is called to record new messages to persistent storage.
// Errors returned here are logged by the Loop rather than aborting the turn.
type MessageRecordFunc func(ctx context.Context, message llm.Message, usage llm.Usage) error

// GitStateChangeFunc is called when the git state changes at the end of a turn.
// This is used to record user-visible notifications about git changes.
type GitStateChangeFunc func(ctx context.Context, state *gitstate.GitState)
 21
// Config contains all configuration needed to create a Loop.
type Config struct {
	LLM              llm.Service         // LLM backend used to generate responses
	History          []llm.Message       // prior conversation history to seed the loop
	Tools            []*llm.Tool         // tools the LLM may invoke during a turn
	RecordMessage    MessageRecordFunc   // callback for persisting new messages
	Logger           *slog.Logger        // optional; slog.Default() is used when nil
	System           []llm.SystemContent // system prompt content sent with each request
	WorkingDir       string              // working directory for tools
	OnGitStateChange GitStateChangeFunc  // optional; invoked when git state changes at end of turn
	// GetWorkingDir returns the current working directory for tools.
	// If set, this is called at end of turn to check for git state changes.
	// If nil, Config.WorkingDir is used as a static value.
	GetWorkingDir func() string
}
 37
// Loop manages a conversation turn with an LLM including tool execution and message recording.
// Notably, when the turn ends, the "Loop" is over. TODO: maybe rename to Turn?
//
// The mutable fields (history, messageQueue, totalUsage, lastGitState) are
// guarded by mu; the remaining fields are assigned once in NewLoop and then
// only read.
type Loop struct {
	llm              llm.Service
	tools            []*llm.Tool
	recordMessage    MessageRecordFunc
	history          []llm.Message // full conversation history (guarded by mu)
	messageQueue     []llm.Message // user messages awaiting processing (guarded by mu)
	totalUsage       llm.Usage     // accumulated usage across all requests (guarded by mu)
	mu               sync.Mutex
	logger           *slog.Logger
	system           []llm.SystemContent
	workingDir       string // static working directory for tools
	onGitStateChange GitStateChangeFunc
	getWorkingDir    func() string      // optional dynamic working-directory source
	lastGitState     *gitstate.GitState // last observed git state (guarded by mu)
}
 55
 56// NewLoop creates a new Loop instance with the provided configuration
 57func NewLoop(config Config) *Loop {
 58	logger := config.Logger
 59	if logger == nil {
 60		logger = slog.Default()
 61	}
 62
 63	// Get initial git state
 64	workingDir := config.WorkingDir
 65	if config.GetWorkingDir != nil {
 66		workingDir = config.GetWorkingDir()
 67	}
 68	initialGitState := gitstate.GetGitState(workingDir)
 69
 70	return &Loop{
 71		llm:              config.LLM,
 72		history:          config.History,
 73		tools:            config.Tools,
 74		recordMessage:    config.RecordMessage,
 75		messageQueue:     make([]llm.Message, 0),
 76		logger:           logger,
 77		system:           config.System,
 78		workingDir:       config.WorkingDir,
 79		onGitStateChange: config.OnGitStateChange,
 80		getWorkingDir:    config.GetWorkingDir,
 81		lastGitState:     initialGitState,
 82	}
 83}
 84
 85// QueueUserMessage adds a user message to the queue to be processed
 86func (l *Loop) QueueUserMessage(message llm.Message) {
 87	l.mu.Lock()
 88	defer l.mu.Unlock()
 89	l.messageQueue = append(l.messageQueue, message)
 90	l.logger.Debug("queued user message", "content_count", len(message.Content))
 91}
 92
 93// GetUsage returns the total usage accumulated by this loop
 94func (l *Loop) GetUsage() llm.Usage {
 95	l.mu.Lock()
 96	defer l.mu.Unlock()
 97	return l.totalUsage
 98}
 99
100// GetHistory returns a copy of the current conversation history
101func (l *Loop) GetHistory() []llm.Message {
102	l.mu.Lock()
103	defer l.mu.Unlock()
104	// Deep copy the messages to prevent modifications
105	historyCopy := make([]llm.Message, len(l.history))
106	for i, msg := range l.history {
107		// Copy the message
108		historyCopy[i] = llm.Message{
109			Role:    msg.Role,
110			ToolUse: msg.ToolUse, // This is a pointer, but we won't modify it in tests
111			Content: make([]llm.Content, len(msg.Content)),
112		}
113		// Copy content slice
114		copy(historyCopy[i].Content, msg.Content)
115	}
116	return historyCopy
117}
118
119// Go runs the conversation loop until the context is canceled
120func (l *Loop) Go(ctx context.Context) error {
121	if l.llm == nil {
122		return fmt.Errorf("no LLM service configured")
123	}
124
125	l.logger.Info("starting conversation loop", "tools", len(l.tools))
126
127	for {
128		select {
129		case <-ctx.Done():
130			l.logger.Info("conversation loop canceled")
131			return ctx.Err()
132		default:
133		}
134
135		// Process any queued messages
136		l.mu.Lock()
137		hasQueuedMessages := len(l.messageQueue) > 0
138		if hasQueuedMessages {
139			// Add queued messages to history (they are already recorded to DB by ConversationManager)
140			for _, msg := range l.messageQueue {
141				l.history = append(l.history, msg)
142			}
143			l.messageQueue = l.messageQueue[:0] // Clear queue
144		}
145		l.mu.Unlock()
146
147		if hasQueuedMessages {
148			// Send request to LLM
149			l.logger.Debug("processing queued messages", "count", 1)
150			if err := l.processLLMRequest(ctx); err != nil {
151				l.logger.Error("failed to process LLM request", "error", err)
152				time.Sleep(time.Second) // Wait before retrying
153				continue
154			}
155			l.logger.Debug("finished processing queued messages")
156		} else {
157			// No queued messages, wait a bit
158			select {
159			case <-ctx.Done():
160				return ctx.Err()
161			case <-time.After(100 * time.Millisecond):
162				// Continue loop
163			}
164		}
165	}
166}
167
168// ProcessOneTurn processes queued messages through one complete turn (user message + assistant response)
169// It stops after the assistant responds, regardless of whether tools were called
170func (l *Loop) ProcessOneTurn(ctx context.Context) error {
171	if l.llm == nil {
172		return fmt.Errorf("no LLM service configured")
173	}
174
175	// Process any queued messages first
176	l.mu.Lock()
177	if len(l.messageQueue) > 0 {
178		// Add queued messages to history (they are already recorded to DB by ConversationManager)
179		for _, msg := range l.messageQueue {
180			l.history = append(l.history, msg)
181		}
182		l.messageQueue = nil
183	}
184	l.mu.Unlock()
185
186	// Process one LLM request and response
187	return l.processLLMRequest(ctx)
188}
189
// processLLMRequest sends a request to the LLM and handles the response.
//
// It snapshots the conversation state under the mutex, enables prompt caching
// on the outgoing request, repairs tool_use/tool_result mismatches in the
// history, and then dispatches on the response: max-token truncation gets
// special handling, tool_use triggers tool execution (which recurses back
// into this function), and anything else ends the turn. Failures to record
// messages are logged but do not fail the request.
func (l *Loop) processLLMRequest(ctx context.Context) error {
	// Snapshot shared state under the lock. messages is a fresh slice so the
	// cache-flag mutations below never touch l.history.
	l.mu.Lock()
	messages := append([]llm.Message(nil), l.history...)
	tools := l.tools
	system := l.system
	llmService := l.llm
	l.mu.Unlock()

	// Enable prompt caching: set cache flag on last tool and last user message content
	// See https://docs.anthropic.com/en/docs/build-with-claude/prompt-caching
	if len(tools) > 0 {
		// Make a copy of tools to avoid modifying the shared slice
		tools = append([]*llm.Tool(nil), tools...)
		// Copy the last tool and enable caching
		lastTool := *tools[len(tools)-1]
		lastTool.Cache = true
		tools[len(tools)-1] = &lastTool
	}

	// Set cache flag on the last content block of the last user message
	if len(messages) > 0 {
		for i := len(messages) - 1; i >= 0; i-- {
			if messages[i].Role == llm.MessageRoleUser && len(messages[i].Content) > 0 {
				// Deep copy the message to avoid modifying the shared history
				msg := messages[i]
				msg.Content = append([]llm.Content(nil), msg.Content...)
				msg.Content[len(msg.Content)-1].Cache = true
				messages[i] = msg
				break
			}
		}
	}

	req := &llm.Request{
		Messages: messages,
		Tools:    tools,
		System:   system,
	}

	// Insert missing tool results if the previous message had tool_use blocks
	// without corresponding tool_result blocks. This can happen when a request
	// is cancelled or fails after the LLM responds but before tools execute.
	l.insertMissingToolResults(req)

	systemLen := 0
	for _, sys := range system {
		systemLen += len(sys.Text)
	}
	l.logger.Debug("sending LLM request", "message_count", len(messages), "tool_count", len(tools), "system_items", len(system), "system_length", systemLen)

	// Add a timeout for the LLM request to prevent indefinite hangs
	llmCtx, cancel := context.WithTimeout(ctx, 5*time.Minute)
	defer cancel()

	resp, err := llmService.Do(llmCtx, req)
	if err != nil {
		// Record the error as a message so it can be displayed in the UI
		// EndOfTurn must be true so the agent working state is properly updated
		errorMessage := llm.Message{
			Role: llm.MessageRoleAssistant,
			Content: []llm.Content{
				{
					Type: llm.ContentTypeText,
					Text: fmt.Sprintf("LLM request failed: %v", err),
				},
			},
			EndOfTurn: true,
			ErrorType: llm.ErrorTypeLLMRequest,
		}
		if recordErr := l.recordMessage(ctx, errorMessage, llm.Usage{}); recordErr != nil {
			l.logger.Error("failed to record error message", "error", recordErr)
		}
		return fmt.Errorf("LLM request failed: %w", err)
	}

	l.logger.Debug("received LLM response", "content_count", len(resp.Content), "stop_reason", resp.StopReason.String(), "usage", resp.Usage.String())

	// Update total usage
	l.mu.Lock()
	l.totalUsage.Add(resp.Usage)
	l.mu.Unlock()

	// Handle max tokens truncation BEFORE adding to history - truncated responses
	// should not be added to history normally (they get special handling)
	if resp.StopReason == llm.StopReasonMaxTokens {
		l.logger.Warn("LLM response truncated due to max tokens")
		return l.handleMaxTokensTruncation(ctx, resp)
	}

	// Convert response to message and add to history
	assistantMessage := resp.ToMessage()
	l.mu.Lock()
	l.history = append(l.history, assistantMessage)
	l.mu.Unlock()

	// Record assistant message with model and timing metadata
	usageWithMeta := resp.Usage
	usageWithMeta.Model = resp.Model
	usageWithMeta.StartTime = resp.StartTime
	usageWithMeta.EndTime = resp.EndTime
	if err := l.recordMessage(ctx, assistantMessage, usageWithMeta); err != nil {
		l.logger.Error("failed to record assistant message", "error", err)
	}

	// Handle tool calls if any; handleToolCalls recurses into this function
	// after executing the tools, continuing the turn.
	if resp.StopReason == llm.StopReasonToolUse {
		l.logger.Debug("handling tool calls", "content_count", len(resp.Content))
		return l.handleToolCalls(ctx, resp.Content)
	}

	// End of turn - check for git state changes
	l.checkGitStateChange(ctx)

	return nil
}
306
307// checkGitStateChange checks if the git state has changed and calls the callback if so.
308// This is called at the end of each turn.
309func (l *Loop) checkGitStateChange(ctx context.Context) {
310	if l.onGitStateChange == nil {
311		return
312	}
313
314	// Get current working directory
315	workingDir := l.workingDir
316	if l.getWorkingDir != nil {
317		workingDir = l.getWorkingDir()
318	}
319
320	// Get current git state
321	currentState := gitstate.GetGitState(workingDir)
322
323	// Compare with last known state
324	l.mu.Lock()
325	lastState := l.lastGitState
326	l.mu.Unlock()
327
328	// Check if state changed
329	if !currentState.Equal(lastState) {
330		l.mu.Lock()
331		l.lastGitState = currentState
332		l.mu.Unlock()
333
334		if currentState.IsRepo {
335			l.logger.Debug("git state changed",
336				"worktree", currentState.Worktree,
337				"branch", currentState.Branch,
338				"commit", currentState.Commit)
339			l.onGitStateChange(ctx, currentState)
340		}
341	}
342}
343
344// handleMaxTokensTruncation handles the case where the LLM response was truncated
345// due to hitting the maximum output token limit. It records the truncated message
346// for cost tracking (excluded from context) and an error message for the user.
347func (l *Loop) handleMaxTokensTruncation(ctx context.Context, resp *llm.Response) error {
348	// Record the truncated message for cost tracking, but mark it as excluded from context.
349	// This preserves billing information without confusing the LLM on future turns.
350	truncatedMessage := resp.ToMessage()
351	truncatedMessage.ExcludedFromContext = true
352
353	// Record the truncated message with usage metadata
354	usageWithMeta := resp.Usage
355	usageWithMeta.Model = resp.Model
356	usageWithMeta.StartTime = resp.StartTime
357	usageWithMeta.EndTime = resp.EndTime
358	if err := l.recordMessage(ctx, truncatedMessage, usageWithMeta); err != nil {
359		l.logger.Error("failed to record truncated message", "error", err)
360	}
361
362	// Record a truncation error message with EndOfTurn=true to properly signal end of turn.
363	errorMessage := llm.Message{
364		Role: llm.MessageRoleAssistant,
365		Content: []llm.Content{
366			{
367				Type: llm.ContentTypeText,
368				Text: "[SYSTEM ERROR: Your previous response was truncated because it exceeded the maximum output token limit. " +
369					"Any tool calls in that response were lost. Please retry with smaller, incremental changes. " +
370					"For file operations, break large changes into multiple smaller patches. " +
371					"The user can ask you to continue if needed.]",
372			},
373		},
374		EndOfTurn: true,
375		ErrorType: llm.ErrorTypeTruncation,
376	}
377
378	l.mu.Lock()
379	l.history = append(l.history, errorMessage)
380	l.mu.Unlock()
381
382	// Record the truncation error message
383	if err := l.recordMessage(ctx, errorMessage, llm.Usage{}); err != nil {
384		l.logger.Error("failed to record truncation error message", "error", err)
385	}
386
387	// End the turn - don't automatically continue
388	l.checkGitStateChange(ctx)
389	return nil
390}
391
392// handleToolCalls processes tool calls from the LLM response
393func (l *Loop) handleToolCalls(ctx context.Context, content []llm.Content) error {
394	var toolResults []llm.Content
395
396	for _, c := range content {
397		if c.Type != llm.ContentTypeToolUse {
398			continue
399		}
400
401		l.logger.Debug("executing tool", "name", c.ToolName, "id", c.ID)
402
403		// Find the tool
404		var tool *llm.Tool
405		for _, t := range l.tools {
406			if t.Name == c.ToolName {
407				tool = t
408				break
409			}
410		}
411
412		if tool == nil {
413			l.logger.Error("tool not found", "name", c.ToolName)
414			toolResults = append(toolResults, llm.Content{
415				Type:      llm.ContentTypeToolResult,
416				ToolUseID: c.ID,
417				ToolError: true,
418				ToolResult: []llm.Content{
419					{Type: llm.ContentTypeText, Text: fmt.Sprintf("Tool '%s' not found", c.ToolName)},
420				},
421			})
422			continue
423		}
424
425		// Execute the tool with working directory set in context
426		toolCtx := ctx
427		if l.workingDir != "" {
428			toolCtx = claudetool.WithWorkingDir(ctx, l.workingDir)
429		}
430		startTime := time.Now()
431		result := tool.Run(toolCtx, c.ToolInput)
432		endTime := time.Now()
433
434		var toolResultContent []llm.Content
435		if result.Error != nil {
436			l.logger.Error("tool execution failed", "name", c.ToolName, "error", result.Error)
437			toolResultContent = []llm.Content{
438				{Type: llm.ContentTypeText, Text: result.Error.Error()},
439			}
440		} else {
441			toolResultContent = result.LLMContent
442			l.logger.Debug("tool executed successfully", "name", c.ToolName, "duration", endTime.Sub(startTime))
443		}
444
445		toolResults = append(toolResults, llm.Content{
446			Type:             llm.ContentTypeToolResult,
447			ToolUseID:        c.ID,
448			ToolError:        result.Error != nil,
449			ToolResult:       toolResultContent,
450			ToolUseStartTime: &startTime,
451			ToolUseEndTime:   &endTime,
452			Display:          result.Display,
453		})
454	}
455
456	if len(toolResults) > 0 {
457		// Add tool results to history as a user message
458		toolMessage := llm.Message{
459			Role:    llm.MessageRoleUser,
460			Content: toolResults,
461		}
462
463		l.mu.Lock()
464		l.history = append(l.history, toolMessage)
465		// Check for queued user messages (interruptions) before continuing.
466		// This allows user messages to be processed as soon as possible.
467		if len(l.messageQueue) > 0 {
468			for _, msg := range l.messageQueue {
469				l.history = append(l.history, msg)
470			}
471			l.messageQueue = l.messageQueue[:0]
472			l.logger.Info("processing user interruption during tool execution")
473		}
474		l.mu.Unlock()
475
476		// Record tool result message
477		if err := l.recordMessage(ctx, toolMessage, llm.Usage{}); err != nil {
478			l.logger.Error("failed to record tool result message", "error", err)
479		}
480
481		// Process another LLM request with the tool results
482		return l.processLLMRequest(ctx)
483	}
484
485	return nil
486}
487
488// insertMissingToolResults fixes tool_result issues in the conversation history:
489//  1. Adds error results for tool_uses that were requested but not included in the next message.
490//     This can happen when a request is cancelled or fails after the LLM responds with tool_use
491//     blocks but before the tools execute.
492//  2. Removes orphan tool_results that reference tool_use IDs not present in the immediately
493//     preceding assistant message. This can happen when a tool execution completes after
494//     CancelConversation has already written cancellation messages.
495//
496// This prevents API errors like:
497//   - "tool_use ids were found without tool_result blocks"
498//   - "unexpected tool_use_id found in tool_result blocks ... Each tool_result block must have
499//     a corresponding tool_use block in the previous message"
500//
501// Mutates the request's Messages slice.
502func (l *Loop) insertMissingToolResults(req *llm.Request) {
503	if len(req.Messages) < 1 {
504		return
505	}
506
507	// Scan through all messages looking for assistant messages with tool_use
508	// that are not immediately followed by a user message with corresponding tool_results.
509	// We may need to insert synthetic user messages with tool_results or filter orphans.
510	var newMessages []llm.Message
511	totalInserted := 0
512	totalRemoved := 0
513
514	// Track the tool_use IDs from the most recent assistant message
515	var prevAssistantToolUseIDs map[string]bool
516
517	for i := 0; i < len(req.Messages); i++ {
518		msg := req.Messages[i]
519
520		if msg.Role == llm.MessageRoleAssistant {
521			// Handle empty assistant messages - add placeholder content if not the last message
522			// The API requires all messages to have non-empty content except for the optional
523			// final assistant message. Empty content can happen when the model ends its turn
524			// without producing any output.
525			if len(msg.Content) == 0 && i < len(req.Messages)-1 {
526				req.Messages[i].Content = []llm.Content{{Type: llm.ContentTypeText, Text: "(no response)"}}
527				msg = req.Messages[i] // update local copy for subsequent processing
528				l.logger.Debug("added placeholder content to empty assistant message", "index", i)
529			}
530
531			// Track all tool_use IDs in this assistant message
532			prevAssistantToolUseIDs = make(map[string]bool)
533			for _, c := range msg.Content {
534				if c.Type == llm.ContentTypeToolUse {
535					prevAssistantToolUseIDs[c.ID] = true
536				}
537			}
538			newMessages = append(newMessages, msg)
539
540			// Check if next message needs synthetic tool_results
541			var toolUseContents []llm.Content
542			for _, c := range msg.Content {
543				if c.Type == llm.ContentTypeToolUse {
544					toolUseContents = append(toolUseContents, c)
545				}
546			}
547
548			if len(toolUseContents) == 0 {
549				continue
550			}
551
552			// Check if next message is a user message with corresponding tool_results
553			var nextMsg *llm.Message
554			if i+1 < len(req.Messages) {
555				nextMsg = &req.Messages[i+1]
556			}
557
558			if nextMsg == nil || nextMsg.Role != llm.MessageRoleUser {
559				// Next message is not a user message (or there is no next message).
560				// Insert a synthetic user message with tool_results for all tool_uses.
561				var toolResultContent []llm.Content
562				for _, tu := range toolUseContents {
563					toolResultContent = append(toolResultContent, llm.Content{
564						Type:      llm.ContentTypeToolResult,
565						ToolUseID: tu.ID,
566						ToolError: true,
567						ToolResult: []llm.Content{{
568							Type: llm.ContentTypeText,
569							Text: "not executed; retry possible",
570						}},
571					})
572				}
573				syntheticMsg := llm.Message{
574					Role:    llm.MessageRoleUser,
575					Content: toolResultContent,
576				}
577				newMessages = append(newMessages, syntheticMsg)
578				totalInserted += len(toolResultContent)
579			}
580		} else if msg.Role == llm.MessageRoleUser {
581			// Filter out orphan tool_results and add missing ones
582			var filteredContent []llm.Content
583			existingResultIDs := make(map[string]bool)
584
585			for _, c := range msg.Content {
586				if c.Type == llm.ContentTypeToolResult {
587					// Only keep tool_results that match a tool_use in the previous assistant message
588					if prevAssistantToolUseIDs != nil && prevAssistantToolUseIDs[c.ToolUseID] {
589						filteredContent = append(filteredContent, c)
590						existingResultIDs[c.ToolUseID] = true
591					} else {
592						// Orphan tool_result - skip it
593						totalRemoved++
594						l.logger.Debug("removing orphan tool_result", "tool_use_id", c.ToolUseID)
595					}
596				} else {
597					// Keep non-tool_result content
598					filteredContent = append(filteredContent, c)
599				}
600			}
601
602			// Check if we need to add missing tool_results for this user message
603			if prevAssistantToolUseIDs != nil {
604				var prefix []llm.Content
605				for toolUseID := range prevAssistantToolUseIDs {
606					if !existingResultIDs[toolUseID] {
607						prefix = append(prefix, llm.Content{
608							Type:      llm.ContentTypeToolResult,
609							ToolUseID: toolUseID,
610							ToolError: true,
611							ToolResult: []llm.Content{{
612								Type: llm.ContentTypeText,
613								Text: "not executed; retry possible",
614							}},
615						})
616						totalInserted++
617					}
618				}
619				if len(prefix) > 0 {
620					filteredContent = append(prefix, filteredContent...)
621				}
622			}
623
624			// Only add the message if it has content
625			if len(filteredContent) > 0 {
626				msg.Content = filteredContent
627				newMessages = append(newMessages, msg)
628			} else {
629				// Message is now empty after filtering - skip it entirely
630				l.logger.Debug("removing empty user message after filtering orphan tool_results")
631			}
632
633			// Reset for next iteration - user message "consumes" the previous tool_uses
634			prevAssistantToolUseIDs = nil
635		} else {
636			newMessages = append(newMessages, msg)
637		}
638	}
639
640	if totalInserted > 0 || totalRemoved > 0 {
641		req.Messages = newMessages
642		if totalInserted > 0 {
643			l.logger.Debug("inserted missing tool results", "count", totalInserted)
644		}
645		if totalRemoved > 0 {
646			l.logger.Debug("removed orphan tool results", "count", totalRemoved)
647		}
648	}
649}