loop.go

  1package loop
  2
  3import (
  4	"context"
  5	"fmt"
  6	"log/slog"
  7	"sync"
  8	"time"
  9
 10	"shelley.exe.dev/claudetool"
 11	"shelley.exe.dev/gitstate"
 12	"shelley.exe.dev/llm"
 13)
 14
// MessageRecordFunc is called to record new messages to persistent storage.
// Implementations receive the message plus the usage metadata associated with
// it (callers in this file pass a zero llm.Usage for user/tool messages) and
// return an error if the write fails.
type MessageRecordFunc func(ctx context.Context, message llm.Message, usage llm.Usage) error
 17
// GitStateChangeFunc is called when the git state changes at the end of a turn.
// This is used to record user-visible notifications about git changes.
// It is only invoked for repositories (state.IsRepo); see checkGitStateChange.
type GitStateChangeFunc func(ctx context.Context, state *gitstate.GitState)
 21
// Config contains all configuration needed to create a Loop
type Config struct {
	LLM              llm.Service         // LLM backend used for all requests
	History          []llm.Message       // prior conversation used to seed the loop
	Tools            []*llm.Tool         // tools offered to the model
	RecordMessage    MessageRecordFunc   // persists new messages; must be non-nil (called unconditionally)
	Logger           *slog.Logger        // optional; slog.Default() is used when nil
	System           []llm.SystemContent // system prompt content blocks
	WorkingDir       string // working directory for tools
	OnGitStateChange GitStateChangeFunc // optional end-of-turn git change callback
	// GetWorkingDir returns the current working directory for tools.
	// If set, this is called at end of turn to check for git state changes.
	// If nil, Config.WorkingDir is used as a static value.
	GetWorkingDir func() string
}
 37
// Loop manages a conversation turn with an LLM including tool execution and message recording.
// Notably, when the turn ends, the "Loop" is over. TODO: maybe rename to Turn?
type Loop struct {
	llm              llm.Service       // backend service for requests
	tools            []*llm.Tool       // tools offered to the model; read without mu after construction
	recordMessage    MessageRecordFunc // persistence callback for new messages
	history          []llm.Message     // full conversation so far; guarded by mu
	messageQueue     []llm.Message     // user messages awaiting processing; guarded by mu
	totalUsage       llm.Usage         // accumulated token usage; guarded by mu
	mu               sync.Mutex        // guards history, messageQueue, totalUsage, lastGitState
	logger           *slog.Logger
	system           []llm.SystemContent // system prompt blocks sent with every request
	workingDir       string              // static working directory for tools
	onGitStateChange GitStateChangeFunc  // end-of-turn git callback; may be nil
	getWorkingDir    func() string       // dynamic working directory; may be nil
	lastGitState     *gitstate.GitState  // last observed git state; guarded by mu
}
 55
 56// NewLoop creates a new Loop instance with the provided configuration
 57func NewLoop(config Config) *Loop {
 58	logger := config.Logger
 59	if logger == nil {
 60		logger = slog.Default()
 61	}
 62
 63	// Get initial git state
 64	workingDir := config.WorkingDir
 65	if config.GetWorkingDir != nil {
 66		workingDir = config.GetWorkingDir()
 67	}
 68	initialGitState := gitstate.GetGitState(workingDir)
 69
 70	return &Loop{
 71		llm:              config.LLM,
 72		history:          config.History,
 73		tools:            config.Tools,
 74		recordMessage:    config.RecordMessage,
 75		messageQueue:     make([]llm.Message, 0),
 76		logger:           logger,
 77		system:           config.System,
 78		workingDir:       config.WorkingDir,
 79		onGitStateChange: config.OnGitStateChange,
 80		getWorkingDir:    config.GetWorkingDir,
 81		lastGitState:     initialGitState,
 82	}
 83}
 84
 85// QueueUserMessage adds a user message to the queue to be processed
 86func (l *Loop) QueueUserMessage(message llm.Message) {
 87	l.mu.Lock()
 88	defer l.mu.Unlock()
 89	l.messageQueue = append(l.messageQueue, message)
 90	l.logger.Debug("queued user message", "content_count", len(message.Content))
 91}
 92
 93// GetUsage returns the total usage accumulated by this loop
 94func (l *Loop) GetUsage() llm.Usage {
 95	l.mu.Lock()
 96	defer l.mu.Unlock()
 97	return l.totalUsage
 98}
 99
100// GetHistory returns a copy of the current conversation history
101func (l *Loop) GetHistory() []llm.Message {
102	l.mu.Lock()
103	defer l.mu.Unlock()
104	// Deep copy the messages to prevent modifications
105	historyCopy := make([]llm.Message, len(l.history))
106	for i, msg := range l.history {
107		// Copy the message
108		historyCopy[i] = llm.Message{
109			Role:    msg.Role,
110			ToolUse: msg.ToolUse, // This is a pointer, but we won't modify it in tests
111			Content: make([]llm.Content, len(msg.Content)),
112		}
113		// Copy content slice
114		copy(historyCopy[i].Content, msg.Content)
115	}
116	return historyCopy
117}
118
119// Go runs the conversation loop until the context is canceled
120func (l *Loop) Go(ctx context.Context) error {
121	if l.llm == nil {
122		return fmt.Errorf("no LLM service configured")
123	}
124
125	l.logger.Info("starting conversation loop", "tools", len(l.tools))
126
127	for {
128		select {
129		case <-ctx.Done():
130			l.logger.Info("conversation loop canceled")
131			return ctx.Err()
132		default:
133		}
134
135		// Process any queued messages
136		l.mu.Lock()
137		hasQueuedMessages := len(l.messageQueue) > 0
138		if hasQueuedMessages {
139			// Add queued messages to history (they are already recorded to DB by ConversationManager)
140			for _, msg := range l.messageQueue {
141				l.history = append(l.history, msg)
142			}
143			l.messageQueue = l.messageQueue[:0] // Clear queue
144		}
145		l.mu.Unlock()
146
147		if hasQueuedMessages {
148			// Send request to LLM
149			l.logger.Debug("processing queued messages", "count", 1)
150			if err := l.processLLMRequest(ctx); err != nil {
151				l.logger.Error("failed to process LLM request", "error", err)
152				time.Sleep(time.Second) // Wait before retrying
153				continue
154			}
155			l.logger.Debug("finished processing queued messages")
156		} else {
157			// No queued messages, wait a bit
158			select {
159			case <-ctx.Done():
160				return ctx.Err()
161			case <-time.After(100 * time.Millisecond):
162				// Continue loop
163			}
164		}
165	}
166}
167
168// ProcessOneTurn processes queued messages through one complete turn (user message + assistant response)
169// It stops after the assistant responds, regardless of whether tools were called
170func (l *Loop) ProcessOneTurn(ctx context.Context) error {
171	if l.llm == nil {
172		return fmt.Errorf("no LLM service configured")
173	}
174
175	// Process any queued messages first
176	l.mu.Lock()
177	if len(l.messageQueue) > 0 {
178		// Add queued messages to history (they are already recorded to DB by ConversationManager)
179		for _, msg := range l.messageQueue {
180			l.history = append(l.history, msg)
181		}
182		l.messageQueue = nil
183	}
184	l.mu.Unlock()
185
186	// Process one LLM request and response
187	return l.processLLMRequest(ctx)
188}
189
// processLLMRequest sends a request to the LLM and handles the response.
//
// Flow: snapshot shared state under the lock, set prompt-cache flags on a
// copy of the request data, repair missing/orphan tool_results, then call
// the service with a 5-minute timeout. On success the assistant message is
// appended to history and recorded; a tool_use stop reason recurses via
// handleToolCalls (so one "turn" may involve several round-trips), a
// max_tokens stop reason records a truncation notice, and any other stop
// reason ends the turn with a git-state check.
func (l *Loop) processLLMRequest(ctx context.Context) error {
	// Snapshot under the lock. messages gets a fresh slice header so element
	// replacement below does not affect l.history's slice.
	l.mu.Lock()
	messages := append([]llm.Message(nil), l.history...)
	tools := l.tools
	system := l.system
	llmService := l.llm
	l.mu.Unlock()

	// Enable prompt caching: set cache flag on last tool and last user message content
	// See https://docs.anthropic.com/en/docs/build-with-claude/prompt-caching
	if len(tools) > 0 {
		// Make a copy of tools to avoid modifying the shared slice
		tools = append([]*llm.Tool(nil), tools...)
		// Copy the last tool and enable caching
		lastTool := *tools[len(tools)-1]
		lastTool.Cache = true
		tools[len(tools)-1] = &lastTool
	}

	// Set cache flag on the last content block of the last user message
	if len(messages) > 0 {
		for i := len(messages) - 1; i >= 0; i-- {
			if messages[i].Role == llm.MessageRoleUser && len(messages[i].Content) > 0 {
				// Deep copy the message to avoid modifying the shared history
				msg := messages[i]
				msg.Content = append([]llm.Content(nil), msg.Content...)
				msg.Content[len(msg.Content)-1].Cache = true
				messages[i] = msg
				break
			}
		}
	}

	req := &llm.Request{
		Messages: messages,
		Tools:    tools,
		System:   system,
	}

	// Insert missing tool results if the previous message had tool_use blocks
	// without corresponding tool_result blocks. This can happen when a request
	// is cancelled or fails after the LLM responds but before tools execute.
	// NOTE(review): this runs after the cache flag was set above, so a user
	// message synthesized here never carries the cache flag — presumably
	// acceptable, but worth confirming.
	l.insertMissingToolResults(req)

	// Log the total system-prompt byte size to help debug context growth.
	systemLen := 0
	for _, sys := range system {
		systemLen += len(sys.Text)
	}
	l.logger.Debug("sending LLM request", "message_count", len(messages), "tool_count", len(tools), "system_items", len(system), "system_length", systemLen)

	// Add a timeout for the LLM request to prevent indefinite hangs
	llmCtx, cancel := context.WithTimeout(ctx, 5*time.Minute)
	defer cancel()

	resp, err := llmService.Do(llmCtx, req)
	if err != nil {
		// Record the error as a message so it can be displayed in the UI.
		// Recording uses the original ctx (not llmCtx) so an LLM timeout
		// doesn't also block persisting the failure notice.
		errorMessage := llm.Message{
			Role: llm.MessageRoleAssistant,
			Content: []llm.Content{
				{
					Type: llm.ContentTypeText,
					Text: fmt.Sprintf("LLM request failed: %v", err),
				},
			},
		}
		if recordErr := l.recordMessage(ctx, errorMessage, llm.Usage{}); recordErr != nil {
			l.logger.Error("failed to record error message", "error", recordErr)
		}
		return fmt.Errorf("LLM request failed: %w", err)
	}

	l.logger.Debug("received LLM response", "content_count", len(resp.Content), "stop_reason", resp.StopReason.String(), "usage", resp.Usage.String())

	// Update total usage
	l.mu.Lock()
	l.totalUsage.Add(resp.Usage)
	l.mu.Unlock()

	// Convert response to message and add to history
	assistantMessage := resp.ToMessage()
	l.mu.Lock()
	l.history = append(l.history, assistantMessage)
	l.mu.Unlock()

	// Record assistant message with model and timing metadata
	usageWithMeta := resp.Usage
	usageWithMeta.Model = resp.Model
	usageWithMeta.StartTime = resp.StartTime
	usageWithMeta.EndTime = resp.EndTime
	if err := l.recordMessage(ctx, assistantMessage, usageWithMeta); err != nil {
		// Log but don't fail the turn; in-memory history is still correct.
		l.logger.Error("failed to record assistant message", "error", err)
	}

	// Handle tool calls if any
	if resp.StopReason == llm.StopReasonToolUse {
		l.logger.Debug("handling tool calls", "content_count", len(resp.Content))
		return l.handleToolCalls(ctx, resp.Content)
	}

	// Handle max tokens truncation - record error message for the user
	if resp.StopReason == llm.StopReasonMaxTokens {
		l.logger.Warn("LLM response truncated due to max tokens")
		return l.handleMaxTokensTruncation(ctx)
	}

	// End of turn - check for git state changes
	l.checkGitStateChange(ctx)

	return nil
}
302
303// checkGitStateChange checks if the git state has changed and calls the callback if so.
304// This is called at the end of each turn.
305func (l *Loop) checkGitStateChange(ctx context.Context) {
306	if l.onGitStateChange == nil {
307		return
308	}
309
310	// Get current working directory
311	workingDir := l.workingDir
312	if l.getWorkingDir != nil {
313		workingDir = l.getWorkingDir()
314	}
315
316	// Get current git state
317	currentState := gitstate.GetGitState(workingDir)
318
319	// Compare with last known state
320	l.mu.Lock()
321	lastState := l.lastGitState
322	l.mu.Unlock()
323
324	// Check if state changed
325	if !currentState.Equal(lastState) {
326		l.mu.Lock()
327		l.lastGitState = currentState
328		l.mu.Unlock()
329
330		if currentState.IsRepo {
331			l.logger.Debug("git state changed",
332				"worktree", currentState.Worktree,
333				"branch", currentState.Branch,
334				"commit", currentState.Commit)
335			l.onGitStateChange(ctx, currentState)
336		}
337	}
338}
339
340// handleMaxTokensTruncation handles the case where the LLM response was truncated
341// due to hitting the maximum output token limit. It records an error message
342// informing the user and instructing the LLM to use smaller outputs.
343func (l *Loop) handleMaxTokensTruncation(ctx context.Context) error {
344	errorMessage := llm.Message{
345		Role: llm.MessageRoleUser,
346		Content: []llm.Content{
347			{
348				Type: llm.ContentTypeText,
349				Text: "[SYSTEM ERROR: Your previous response was truncated because it exceeded the maximum output token limit. " +
350					"Any tool calls in that response were lost. Please retry with smaller, incremental changes. " +
351					"For file operations, break large changes into multiple smaller patches. " +
352					"The user can ask you to continue if needed.]",
353			},
354		},
355	}
356
357	l.mu.Lock()
358	l.history = append(l.history, errorMessage)
359	l.mu.Unlock()
360
361	// Record the error message
362	if err := l.recordMessage(ctx, errorMessage, llm.Usage{}); err != nil {
363		l.logger.Error("failed to record truncation error message", "error", err)
364	}
365
366	// End the turn - don't automatically continue
367	l.checkGitStateChange(ctx)
368	return nil
369}
370
371// handleToolCalls processes tool calls from the LLM response
372func (l *Loop) handleToolCalls(ctx context.Context, content []llm.Content) error {
373	var toolResults []llm.Content
374
375	for _, c := range content {
376		if c.Type != llm.ContentTypeToolUse {
377			continue
378		}
379
380		l.logger.Debug("executing tool", "name", c.ToolName, "id", c.ID)
381
382		// Find the tool
383		var tool *llm.Tool
384		for _, t := range l.tools {
385			if t.Name == c.ToolName {
386				tool = t
387				break
388			}
389		}
390
391		if tool == nil {
392			l.logger.Error("tool not found", "name", c.ToolName)
393			toolResults = append(toolResults, llm.Content{
394				Type:      llm.ContentTypeToolResult,
395				ToolUseID: c.ID,
396				ToolError: true,
397				ToolResult: []llm.Content{
398					{Type: llm.ContentTypeText, Text: fmt.Sprintf("Tool '%s' not found", c.ToolName)},
399				},
400			})
401			continue
402		}
403
404		// Execute the tool with working directory set in context
405		toolCtx := ctx
406		if l.workingDir != "" {
407			toolCtx = claudetool.WithWorkingDir(ctx, l.workingDir)
408		}
409		startTime := time.Now()
410		result := tool.Run(toolCtx, c.ToolInput)
411		endTime := time.Now()
412
413		var toolResultContent []llm.Content
414		if result.Error != nil {
415			l.logger.Error("tool execution failed", "name", c.ToolName, "error", result.Error)
416			toolResultContent = []llm.Content{
417				{Type: llm.ContentTypeText, Text: result.Error.Error()},
418			}
419		} else {
420			toolResultContent = result.LLMContent
421			l.logger.Debug("tool executed successfully", "name", c.ToolName, "duration", endTime.Sub(startTime))
422		}
423
424		toolResults = append(toolResults, llm.Content{
425			Type:             llm.ContentTypeToolResult,
426			ToolUseID:        c.ID,
427			ToolError:        result.Error != nil,
428			ToolResult:       toolResultContent,
429			ToolUseStartTime: &startTime,
430			ToolUseEndTime:   &endTime,
431			Display:          result.Display,
432		})
433	}
434
435	if len(toolResults) > 0 {
436		// Add tool results to history as a user message
437		toolMessage := llm.Message{
438			Role:    llm.MessageRoleUser,
439			Content: toolResults,
440		}
441
442		l.mu.Lock()
443		l.history = append(l.history, toolMessage)
444		l.mu.Unlock()
445
446		// Record tool result message
447		if err := l.recordMessage(ctx, toolMessage, llm.Usage{}); err != nil {
448			l.logger.Error("failed to record tool result message", "error", err)
449		}
450
451		// Process another LLM request with the tool results
452		return l.processLLMRequest(ctx)
453	}
454
455	return nil
456}
457
458// insertMissingToolResults fixes tool_result issues in the conversation history:
459//  1. Adds error results for tool_uses that were requested but not included in the next message.
460//     This can happen when a request is cancelled or fails after the LLM responds with tool_use
461//     blocks but before the tools execute.
462//  2. Removes orphan tool_results that reference tool_use IDs not present in the immediately
463//     preceding assistant message. This can happen when a tool execution completes after
464//     CancelConversation has already written cancellation messages.
465//
466// This prevents API errors like:
467//   - "tool_use ids were found without tool_result blocks"
468//   - "unexpected tool_use_id found in tool_result blocks ... Each tool_result block must have
469//     a corresponding tool_use block in the previous message"
470//
471// Mutates the request's Messages slice.
472func (l *Loop) insertMissingToolResults(req *llm.Request) {
473	if len(req.Messages) < 1 {
474		return
475	}
476
477	// Scan through all messages looking for assistant messages with tool_use
478	// that are not immediately followed by a user message with corresponding tool_results.
479	// We may need to insert synthetic user messages with tool_results or filter orphans.
480	var newMessages []llm.Message
481	totalInserted := 0
482	totalRemoved := 0
483
484	// Track the tool_use IDs from the most recent assistant message
485	var prevAssistantToolUseIDs map[string]bool
486
487	for i := 0; i < len(req.Messages); i++ {
488		msg := req.Messages[i]
489
490		if msg.Role == llm.MessageRoleAssistant {
491			// Handle empty assistant messages - add placeholder content if not the last message
492			// The API requires all messages to have non-empty content except for the optional
493			// final assistant message. Empty content can happen when the model ends its turn
494			// without producing any output.
495			if len(msg.Content) == 0 && i < len(req.Messages)-1 {
496				req.Messages[i].Content = []llm.Content{{Type: llm.ContentTypeText, Text: "(no response)"}}
497				msg = req.Messages[i] // update local copy for subsequent processing
498				l.logger.Debug("added placeholder content to empty assistant message", "index", i)
499			}
500
501			// Track all tool_use IDs in this assistant message
502			prevAssistantToolUseIDs = make(map[string]bool)
503			for _, c := range msg.Content {
504				if c.Type == llm.ContentTypeToolUse {
505					prevAssistantToolUseIDs[c.ID] = true
506				}
507			}
508			newMessages = append(newMessages, msg)
509
510			// Check if next message needs synthetic tool_results
511			var toolUseContents []llm.Content
512			for _, c := range msg.Content {
513				if c.Type == llm.ContentTypeToolUse {
514					toolUseContents = append(toolUseContents, c)
515				}
516			}
517
518			if len(toolUseContents) == 0 {
519				continue
520			}
521
522			// Check if next message is a user message with corresponding tool_results
523			var nextMsg *llm.Message
524			if i+1 < len(req.Messages) {
525				nextMsg = &req.Messages[i+1]
526			}
527
528			if nextMsg == nil || nextMsg.Role != llm.MessageRoleUser {
529				// Next message is not a user message (or there is no next message).
530				// Insert a synthetic user message with tool_results for all tool_uses.
531				var toolResultContent []llm.Content
532				for _, tu := range toolUseContents {
533					toolResultContent = append(toolResultContent, llm.Content{
534						Type:      llm.ContentTypeToolResult,
535						ToolUseID: tu.ID,
536						ToolError: true,
537						ToolResult: []llm.Content{{
538							Type: llm.ContentTypeText,
539							Text: "not executed; retry possible",
540						}},
541					})
542				}
543				syntheticMsg := llm.Message{
544					Role:    llm.MessageRoleUser,
545					Content: toolResultContent,
546				}
547				newMessages = append(newMessages, syntheticMsg)
548				totalInserted += len(toolResultContent)
549			}
550		} else if msg.Role == llm.MessageRoleUser {
551			// Filter out orphan tool_results and add missing ones
552			var filteredContent []llm.Content
553			existingResultIDs := make(map[string]bool)
554
555			for _, c := range msg.Content {
556				if c.Type == llm.ContentTypeToolResult {
557					// Only keep tool_results that match a tool_use in the previous assistant message
558					if prevAssistantToolUseIDs != nil && prevAssistantToolUseIDs[c.ToolUseID] {
559						filteredContent = append(filteredContent, c)
560						existingResultIDs[c.ToolUseID] = true
561					} else {
562						// Orphan tool_result - skip it
563						totalRemoved++
564						l.logger.Debug("removing orphan tool_result", "tool_use_id", c.ToolUseID)
565					}
566				} else {
567					// Keep non-tool_result content
568					filteredContent = append(filteredContent, c)
569				}
570			}
571
572			// Check if we need to add missing tool_results for this user message
573			if prevAssistantToolUseIDs != nil {
574				var prefix []llm.Content
575				for toolUseID := range prevAssistantToolUseIDs {
576					if !existingResultIDs[toolUseID] {
577						prefix = append(prefix, llm.Content{
578							Type:      llm.ContentTypeToolResult,
579							ToolUseID: toolUseID,
580							ToolError: true,
581							ToolResult: []llm.Content{{
582								Type: llm.ContentTypeText,
583								Text: "not executed; retry possible",
584							}},
585						})
586						totalInserted++
587					}
588				}
589				if len(prefix) > 0 {
590					filteredContent = append(prefix, filteredContent...)
591				}
592			}
593
594			// Only add the message if it has content
595			if len(filteredContent) > 0 {
596				msg.Content = filteredContent
597				newMessages = append(newMessages, msg)
598			} else {
599				// Message is now empty after filtering - skip it entirely
600				l.logger.Debug("removing empty user message after filtering orphan tool_results")
601			}
602
603			// Reset for next iteration - user message "consumes" the previous tool_uses
604			prevAssistantToolUseIDs = nil
605		} else {
606			newMessages = append(newMessages, msg)
607		}
608	}
609
610	if totalInserted > 0 || totalRemoved > 0 {
611		req.Messages = newMessages
612		if totalInserted > 0 {
613			l.logger.Debug("inserted missing tool results", "count", totalInserted)
614		}
615		if totalRemoved > 0 {
616			l.logger.Debug("removed orphan tool results", "count", totalRemoved)
617		}
618	}
619}