1package loop
2
import (
	"context"
	"errors"
	"fmt"
	"io"
	"log/slog"
	"sort"
	"strings"
	"sync"
	"time"

	"shelley.exe.dev/claudetool"
	"shelley.exe.dev/gitstate"
	"shelley.exe.dev/llm"
)
16
// MessageRecordFunc is called to record new messages to persistent storage.
// usage carries the token accounting associated with the message; callers in
// this package pass a zero llm.Usage{} for messages that did not come from an
// LLM call (tool results, synthetic error messages).
type MessageRecordFunc func(ctx context.Context, message llm.Message, usage llm.Usage) error

// GitStateChangeFunc is called when the git state changes at the end of a turn.
// This is used to record user-visible notifications about git changes.
// It is only invoked for states inside a git repository (state.IsRepo).
type GitStateChangeFunc func(ctx context.Context, state *gitstate.GitState)
23
// Config contains all configuration needed to create a Loop.
type Config struct {
	LLM           llm.Service       // LLM backend; Loop.Go/ProcessOneTurn error if nil
	History       []llm.Message     // initial conversation history
	Tools         []*llm.Tool       // tools offered to the model on every request
	RecordMessage MessageRecordFunc // persists each new message; errors are logged, not fatal
	Logger        *slog.Logger      // defaults to slog.Default() when nil
	System        []llm.SystemContent
	WorkingDir    string // working directory for tools
	OnGitStateChange GitStateChangeFunc // optional end-of-turn git-change callback
	// GetWorkingDir returns the current working directory for tools.
	// If set, this is called at end of turn to check for git state changes.
	// If nil, Config.WorkingDir is used as a static value.
	GetWorkingDir func() string
}
39
// Loop manages a conversation turn with an LLM including tool execution and message recording.
// Notably, when the turn ends, the "Loop" is over. TODO: maybe rename to Turn?
type Loop struct {
	llm           llm.Service
	tools         []*llm.Tool
	recordMessage MessageRecordFunc
	history       []llm.Message // full conversation so far; guarded by mu
	messageQueue  []llm.Message // user messages awaiting processing; guarded by mu
	totalUsage    llm.Usage     // accumulated token usage across requests; guarded by mu
	mu            sync.Mutex    // guards history, messageQueue, totalUsage, lastGitState
	logger        *slog.Logger
	system        []llm.SystemContent
	workingDir    string // static working directory; fallback when getWorkingDir is nil
	onGitStateChange GitStateChangeFunc // optional; see Config.OnGitStateChange
	getWorkingDir    func() string      // optional dynamic working-dir resolver; see Config.GetWorkingDir
	lastGitState     *gitstate.GitState // last observed git state; guarded by mu
}
57
58// NewLoop creates a new Loop instance with the provided configuration
59func NewLoop(config Config) *Loop {
60 logger := config.Logger
61 if logger == nil {
62 logger = slog.Default()
63 }
64
65 // Get initial git state
66 workingDir := config.WorkingDir
67 if config.GetWorkingDir != nil {
68 workingDir = config.GetWorkingDir()
69 }
70 initialGitState := gitstate.GetGitState(workingDir)
71
72 return &Loop{
73 llm: config.LLM,
74 history: config.History,
75 tools: config.Tools,
76 recordMessage: config.RecordMessage,
77 messageQueue: make([]llm.Message, 0),
78 logger: logger,
79 system: config.System,
80 workingDir: config.WorkingDir,
81 onGitStateChange: config.OnGitStateChange,
82 getWorkingDir: config.GetWorkingDir,
83 lastGitState: initialGitState,
84 }
85}
86
87// QueueUserMessage adds a user message to the queue to be processed
88func (l *Loop) QueueUserMessage(message llm.Message) {
89 l.mu.Lock()
90 defer l.mu.Unlock()
91 l.messageQueue = append(l.messageQueue, message)
92 l.logger.Debug("queued user message", "content_count", len(message.Content))
93}
94
95// GetUsage returns the total usage accumulated by this loop
96func (l *Loop) GetUsage() llm.Usage {
97 l.mu.Lock()
98 defer l.mu.Unlock()
99 return l.totalUsage
100}
101
102// GetHistory returns a copy of the current conversation history
103func (l *Loop) GetHistory() []llm.Message {
104 l.mu.Lock()
105 defer l.mu.Unlock()
106 // Deep copy the messages to prevent modifications
107 historyCopy := make([]llm.Message, len(l.history))
108 for i, msg := range l.history {
109 // Copy the message
110 historyCopy[i] = llm.Message{
111 Role: msg.Role,
112 ToolUse: msg.ToolUse, // This is a pointer, but we won't modify it in tests
113 Content: make([]llm.Content, len(msg.Content)),
114 }
115 // Copy content slice
116 copy(historyCopy[i].Content, msg.Content)
117 }
118 return historyCopy
119}
120
121// Go runs the conversation loop until the context is canceled
122func (l *Loop) Go(ctx context.Context) error {
123 if l.llm == nil {
124 return fmt.Errorf("no LLM service configured")
125 }
126
127 l.logger.Info("starting conversation loop", "tools", len(l.tools))
128
129 for {
130 select {
131 case <-ctx.Done():
132 l.logger.Info("conversation loop canceled")
133 return ctx.Err()
134 default:
135 }
136
137 // Process any queued messages
138 l.mu.Lock()
139 hasQueuedMessages := len(l.messageQueue) > 0
140 if hasQueuedMessages {
141 // Add queued messages to history (they are already recorded to DB by ConversationManager)
142 for _, msg := range l.messageQueue {
143 l.history = append(l.history, msg)
144 }
145 l.messageQueue = l.messageQueue[:0] // Clear queue
146 }
147 l.mu.Unlock()
148
149 if hasQueuedMessages {
150 // Send request to LLM
151 l.logger.Debug("processing queued messages", "count", 1)
152 if err := l.processLLMRequest(ctx); err != nil {
153 l.logger.Error("failed to process LLM request", "error", err)
154 time.Sleep(time.Second) // Wait before retrying
155 continue
156 }
157 l.logger.Debug("finished processing queued messages")
158 } else {
159 // No queued messages, wait a bit
160 select {
161 case <-ctx.Done():
162 return ctx.Err()
163 case <-time.After(100 * time.Millisecond):
164 // Continue loop
165 }
166 }
167 }
168}
169
170// ProcessOneTurn processes queued messages through one complete turn (user message + assistant response)
171// It stops after the assistant responds, regardless of whether tools were called
172func (l *Loop) ProcessOneTurn(ctx context.Context) error {
173 if l.llm == nil {
174 return fmt.Errorf("no LLM service configured")
175 }
176
177 // Process any queued messages first
178 l.mu.Lock()
179 if len(l.messageQueue) > 0 {
180 // Add queued messages to history (they are already recorded to DB by ConversationManager)
181 for _, msg := range l.messageQueue {
182 l.history = append(l.history, msg)
183 }
184 l.messageQueue = nil
185 }
186 l.mu.Unlock()
187
188 // Process one LLM request and response
189 return l.processLLMRequest(ctx)
190}
191
// processLLMRequest sends a request to the LLM and handles the response.
//
// Flow: snapshot history/tools/system under the lock, mark prompt-cache
// points, repair missing/orphan tool_results, then issue the request with a
// 5-minute timeout and up to maxRetries attempts for transient network
// errors. On success the assistant message is appended to history and
// recorded; tool_use responses recurse via handleToolCalls, max-token
// truncations get special handling, and any other stop reason ends the
// turn (triggering the git-state check). Returns a non-nil error only when
// the LLM request itself ultimately fails.
func (l *Loop) processLLMRequest(ctx context.Context) error {
	// Snapshot shared state under the lock. messages is a fresh slice so
	// the cache-flag edits below never touch l.history's backing array.
	l.mu.Lock()
	messages := append([]llm.Message(nil), l.history...)
	tools := l.tools
	system := l.system
	llmService := l.llm
	l.mu.Unlock()

	// Enable prompt caching: set cache flag on last tool and last user message content
	// See https://docs.anthropic.com/en/docs/build-with-claude/prompt-caching
	if len(tools) > 0 {
		// Make a copy of tools to avoid modifying the shared slice
		tools = append([]*llm.Tool(nil), tools...)
		// Copy the last tool and enable caching
		lastTool := *tools[len(tools)-1]
		lastTool.Cache = true
		tools[len(tools)-1] = &lastTool
	}

	// Set cache flag on the last content block of the last user message
	// (searching backwards for the most recent user message).
	if len(messages) > 0 {
		for i := len(messages) - 1; i >= 0; i-- {
			if messages[i].Role == llm.MessageRoleUser && len(messages[i].Content) > 0 {
				// Deep copy the message to avoid modifying the shared history
				msg := messages[i]
				msg.Content = append([]llm.Content(nil), msg.Content...)
				msg.Content[len(msg.Content)-1].Cache = true
				messages[i] = msg
				break
			}
		}
	}

	req := &llm.Request{
		Messages: messages,
		Tools:    tools,
		System:   system,
	}

	// Insert missing tool results if the previous message had tool_use blocks
	// without corresponding tool_result blocks. This can happen when a request
	// is cancelled or fails after the LLM responds but before tools execute.
	l.insertMissingToolResults(req)

	// Total system-prompt size, for the debug log below.
	systemLen := 0
	for _, sys := range system {
		systemLen += len(sys.Text)
	}
	l.logger.Debug("sending LLM request", "message_count", len(messages), "tool_count", len(tools), "system_items", len(system), "system_length", systemLen)

	// Add a timeout for the LLM request to prevent indefinite hangs
	llmCtx, cancel := context.WithTimeout(ctx, 5*time.Minute)
	defer cancel()

	// Retry LLM requests that fail with retryable errors (EOF, connection reset).
	// Note: all attempts share the single 5-minute llmCtx deadline above.
	const maxRetries = 2
	var resp *llm.Response
	var err error
	for attempt := 1; attempt <= maxRetries; attempt++ {
		resp, err = llmService.Do(llmCtx, req)
		if err == nil {
			break
		}
		if !isRetryableError(err) || attempt == maxRetries {
			break
		}
		l.logger.Warn("LLM request failed with retryable error, retrying",
			"error", err,
			"attempt", attempt,
			"max_retries", maxRetries)
		time.Sleep(time.Second * time.Duration(attempt)) // Simple backoff
	}
	if err != nil {
		// Record the error as a message so it can be displayed in the UI.
		// EndOfTurn must be true so the agent working state is properly updated.
		errorMessage := llm.Message{
			Role: llm.MessageRoleAssistant,
			Content: []llm.Content{
				{
					Type: llm.ContentTypeText,
					Text: fmt.Sprintf("LLM request failed: %v", err),
				},
			},
			EndOfTurn: true,
			ErrorType: llm.ErrorTypeLLMRequest,
		}
		if recordErr := l.recordMessage(ctx, errorMessage, llm.Usage{}); recordErr != nil {
			l.logger.Error("failed to record error message", "error", recordErr)
		}
		return fmt.Errorf("LLM request failed: %w", err)
	}

	l.logger.Debug("received LLM response", "content_count", len(resp.Content), "stop_reason", resp.StopReason.String(), "usage", resp.Usage.String())

	// Update total usage
	l.mu.Lock()
	l.totalUsage.Add(resp.Usage)
	l.mu.Unlock()

	// Handle max tokens truncation BEFORE adding to history - truncated responses
	// should not be added to history normally (they get special handling)
	if resp.StopReason == llm.StopReasonMaxTokens {
		l.logger.Warn("LLM response truncated due to max tokens")
		return l.handleMaxTokensTruncation(ctx, resp)
	}

	// Convert response to message and add to history
	assistantMessage := resp.ToMessage()
	l.mu.Lock()
	l.history = append(l.history, assistantMessage)
	l.mu.Unlock()

	// Record assistant message with model and timing metadata
	usageWithMeta := resp.Usage
	usageWithMeta.Model = resp.Model
	usageWithMeta.StartTime = resp.StartTime
	usageWithMeta.EndTime = resp.EndTime
	if err := l.recordMessage(ctx, assistantMessage, usageWithMeta); err != nil {
		// Recording failures are logged but do not abort the turn.
		l.logger.Error("failed to record assistant message", "error", err)
	}

	// Handle tool calls if any (handleToolCalls recurses back into
	// processLLMRequest with the tool results).
	if resp.StopReason == llm.StopReasonToolUse {
		l.logger.Debug("handling tool calls", "content_count", len(resp.Content))
		return l.handleToolCalls(ctx, resp.Content)
	}

	// End of turn - check for git state changes
	l.checkGitStateChange(ctx)

	return nil
}
325
326// checkGitStateChange checks if the git state has changed and calls the callback if so.
327// This is called at the end of each turn.
328func (l *Loop) checkGitStateChange(ctx context.Context) {
329 if l.onGitStateChange == nil {
330 return
331 }
332
333 // Get current working directory
334 workingDir := l.workingDir
335 if l.getWorkingDir != nil {
336 workingDir = l.getWorkingDir()
337 }
338
339 // Get current git state
340 currentState := gitstate.GetGitState(workingDir)
341
342 // Compare with last known state
343 l.mu.Lock()
344 lastState := l.lastGitState
345 l.mu.Unlock()
346
347 // Check if state changed
348 if !currentState.Equal(lastState) {
349 l.mu.Lock()
350 l.lastGitState = currentState
351 l.mu.Unlock()
352
353 if currentState.IsRepo {
354 l.logger.Debug("git state changed",
355 "worktree", currentState.Worktree,
356 "branch", currentState.Branch,
357 "commit", currentState.Commit)
358 l.onGitStateChange(ctx, currentState)
359 }
360 }
361}
362
363// handleMaxTokensTruncation handles the case where the LLM response was truncated
364// due to hitting the maximum output token limit. It records the truncated message
365// for cost tracking (excluded from context) and an error message for the user.
366func (l *Loop) handleMaxTokensTruncation(ctx context.Context, resp *llm.Response) error {
367 // Record the truncated message for cost tracking, but mark it as excluded from context.
368 // This preserves billing information without confusing the LLM on future turns.
369 truncatedMessage := resp.ToMessage()
370 truncatedMessage.ExcludedFromContext = true
371
372 // Record the truncated message with usage metadata
373 usageWithMeta := resp.Usage
374 usageWithMeta.Model = resp.Model
375 usageWithMeta.StartTime = resp.StartTime
376 usageWithMeta.EndTime = resp.EndTime
377 if err := l.recordMessage(ctx, truncatedMessage, usageWithMeta); err != nil {
378 l.logger.Error("failed to record truncated message", "error", err)
379 }
380
381 // Record a truncation error message with EndOfTurn=true to properly signal end of turn.
382 errorMessage := llm.Message{
383 Role: llm.MessageRoleAssistant,
384 Content: []llm.Content{
385 {
386 Type: llm.ContentTypeText,
387 Text: "[SYSTEM ERROR: Your previous response was truncated because it exceeded the maximum output token limit. " +
388 "Any tool calls in that response were lost. Please retry with smaller, incremental changes. " +
389 "For file operations, break large changes into multiple smaller patches. " +
390 "The user can ask you to continue if needed.]",
391 },
392 },
393 EndOfTurn: true,
394 ErrorType: llm.ErrorTypeTruncation,
395 }
396
397 l.mu.Lock()
398 l.history = append(l.history, errorMessage)
399 l.mu.Unlock()
400
401 // Record the truncation error message
402 if err := l.recordMessage(ctx, errorMessage, llm.Usage{}); err != nil {
403 l.logger.Error("failed to record truncation error message", "error", err)
404 }
405
406 // End the turn - don't automatically continue
407 l.checkGitStateChange(ctx)
408 return nil
409}
410
411// handleToolCalls processes tool calls from the LLM response
412func (l *Loop) handleToolCalls(ctx context.Context, content []llm.Content) error {
413 var toolResults []llm.Content
414
415 for _, c := range content {
416 if c.Type != llm.ContentTypeToolUse {
417 continue
418 }
419
420 l.logger.Debug("executing tool", "name", c.ToolName, "id", c.ID)
421
422 // Find the tool
423 var tool *llm.Tool
424 for _, t := range l.tools {
425 if t.Name == c.ToolName {
426 tool = t
427 break
428 }
429 }
430
431 if tool == nil {
432 l.logger.Error("tool not found", "name", c.ToolName)
433 toolResults = append(toolResults, llm.Content{
434 Type: llm.ContentTypeToolResult,
435 ToolUseID: c.ID,
436 ToolError: true,
437 ToolResult: []llm.Content{
438 {Type: llm.ContentTypeText, Text: fmt.Sprintf("Tool '%s' not found", c.ToolName)},
439 },
440 })
441 continue
442 }
443
444 // Execute the tool with working directory set in context
445 toolCtx := ctx
446 if l.workingDir != "" {
447 toolCtx = claudetool.WithWorkingDir(ctx, l.workingDir)
448 }
449 startTime := time.Now()
450 result := tool.Run(toolCtx, c.ToolInput)
451 endTime := time.Now()
452
453 var toolResultContent []llm.Content
454 if result.Error != nil {
455 l.logger.Error("tool execution failed", "name", c.ToolName, "error", result.Error)
456 toolResultContent = []llm.Content{
457 {Type: llm.ContentTypeText, Text: result.Error.Error()},
458 }
459 } else {
460 toolResultContent = result.LLMContent
461 l.logger.Debug("tool executed successfully", "name", c.ToolName, "duration", endTime.Sub(startTime))
462 }
463
464 toolResults = append(toolResults, llm.Content{
465 Type: llm.ContentTypeToolResult,
466 ToolUseID: c.ID,
467 ToolError: result.Error != nil,
468 ToolResult: toolResultContent,
469 ToolUseStartTime: &startTime,
470 ToolUseEndTime: &endTime,
471 Display: result.Display,
472 })
473 }
474
475 if len(toolResults) > 0 {
476 // Add tool results to history as a user message
477 toolMessage := llm.Message{
478 Role: llm.MessageRoleUser,
479 Content: toolResults,
480 }
481
482 l.mu.Lock()
483 l.history = append(l.history, toolMessage)
484 // Check for queued user messages (interruptions) before continuing.
485 // This allows user messages to be processed as soon as possible.
486 if len(l.messageQueue) > 0 {
487 for _, msg := range l.messageQueue {
488 l.history = append(l.history, msg)
489 }
490 l.messageQueue = l.messageQueue[:0]
491 l.logger.Info("processing user interruption during tool execution")
492 }
493 l.mu.Unlock()
494
495 // Record tool result message
496 if err := l.recordMessage(ctx, toolMessage, llm.Usage{}); err != nil {
497 l.logger.Error("failed to record tool result message", "error", err)
498 }
499
500 // Process another LLM request with the tool results
501 return l.processLLMRequest(ctx)
502 }
503
504 return nil
505}
506
507// insertMissingToolResults fixes tool_result issues in the conversation history:
508// 1. Adds error results for tool_uses that were requested but not included in the next message.
509// This can happen when a request is cancelled or fails after the LLM responds with tool_use
510// blocks but before the tools execute.
511// 2. Removes orphan tool_results that reference tool_use IDs not present in the immediately
512// preceding assistant message. This can happen when a tool execution completes after
513// CancelConversation has already written cancellation messages.
514//
515// This prevents API errors like:
516// - "tool_use ids were found without tool_result blocks"
517// - "unexpected tool_use_id found in tool_result blocks ... Each tool_result block must have
518// a corresponding tool_use block in the previous message"
519//
520// Mutates the request's Messages slice.
521func (l *Loop) insertMissingToolResults(req *llm.Request) {
522 if len(req.Messages) < 1 {
523 return
524 }
525
526 // Scan through all messages looking for assistant messages with tool_use
527 // that are not immediately followed by a user message with corresponding tool_results.
528 // We may need to insert synthetic user messages with tool_results or filter orphans.
529 var newMessages []llm.Message
530 totalInserted := 0
531 totalRemoved := 0
532
533 // Track the tool_use IDs from the most recent assistant message
534 var prevAssistantToolUseIDs map[string]bool
535
536 for i := 0; i < len(req.Messages); i++ {
537 msg := req.Messages[i]
538
539 if msg.Role == llm.MessageRoleAssistant {
540 // Handle empty assistant messages - add placeholder content if not the last message
541 // The API requires all messages to have non-empty content except for the optional
542 // final assistant message. Empty content can happen when the model ends its turn
543 // without producing any output.
544 if len(msg.Content) == 0 && i < len(req.Messages)-1 {
545 req.Messages[i].Content = []llm.Content{{Type: llm.ContentTypeText, Text: "(no response)"}}
546 msg = req.Messages[i] // update local copy for subsequent processing
547 l.logger.Debug("added placeholder content to empty assistant message", "index", i)
548 }
549
550 // Track all tool_use IDs in this assistant message
551 prevAssistantToolUseIDs = make(map[string]bool)
552 for _, c := range msg.Content {
553 if c.Type == llm.ContentTypeToolUse {
554 prevAssistantToolUseIDs[c.ID] = true
555 }
556 }
557 newMessages = append(newMessages, msg)
558
559 // Check if next message needs synthetic tool_results
560 var toolUseContents []llm.Content
561 for _, c := range msg.Content {
562 if c.Type == llm.ContentTypeToolUse {
563 toolUseContents = append(toolUseContents, c)
564 }
565 }
566
567 if len(toolUseContents) == 0 {
568 continue
569 }
570
571 // Check if next message is a user message with corresponding tool_results
572 var nextMsg *llm.Message
573 if i+1 < len(req.Messages) {
574 nextMsg = &req.Messages[i+1]
575 }
576
577 if nextMsg == nil || nextMsg.Role != llm.MessageRoleUser {
578 // Next message is not a user message (or there is no next message).
579 // Insert a synthetic user message with tool_results for all tool_uses.
580 var toolResultContent []llm.Content
581 for _, tu := range toolUseContents {
582 toolResultContent = append(toolResultContent, llm.Content{
583 Type: llm.ContentTypeToolResult,
584 ToolUseID: tu.ID,
585 ToolError: true,
586 ToolResult: []llm.Content{{
587 Type: llm.ContentTypeText,
588 Text: "not executed; retry possible",
589 }},
590 })
591 }
592 syntheticMsg := llm.Message{
593 Role: llm.MessageRoleUser,
594 Content: toolResultContent,
595 }
596 newMessages = append(newMessages, syntheticMsg)
597 totalInserted += len(toolResultContent)
598 }
599 } else if msg.Role == llm.MessageRoleUser {
600 // Filter out orphan tool_results and add missing ones
601 var filteredContent []llm.Content
602 existingResultIDs := make(map[string]bool)
603
604 for _, c := range msg.Content {
605 if c.Type == llm.ContentTypeToolResult {
606 // Only keep tool_results that match a tool_use in the previous assistant message
607 if prevAssistantToolUseIDs != nil && prevAssistantToolUseIDs[c.ToolUseID] {
608 filteredContent = append(filteredContent, c)
609 existingResultIDs[c.ToolUseID] = true
610 } else {
611 // Orphan tool_result - skip it
612 totalRemoved++
613 l.logger.Debug("removing orphan tool_result", "tool_use_id", c.ToolUseID)
614 }
615 } else {
616 // Keep non-tool_result content
617 filteredContent = append(filteredContent, c)
618 }
619 }
620
621 // Check if we need to add missing tool_results for this user message
622 if prevAssistantToolUseIDs != nil {
623 var prefix []llm.Content
624 for toolUseID := range prevAssistantToolUseIDs {
625 if !existingResultIDs[toolUseID] {
626 prefix = append(prefix, llm.Content{
627 Type: llm.ContentTypeToolResult,
628 ToolUseID: toolUseID,
629 ToolError: true,
630 ToolResult: []llm.Content{{
631 Type: llm.ContentTypeText,
632 Text: "not executed; retry possible",
633 }},
634 })
635 totalInserted++
636 }
637 }
638 if len(prefix) > 0 {
639 filteredContent = append(prefix, filteredContent...)
640 }
641 }
642
643 // Only add the message if it has content
644 if len(filteredContent) > 0 {
645 msg.Content = filteredContent
646 newMessages = append(newMessages, msg)
647 } else {
648 // Message is now empty after filtering - skip it entirely
649 l.logger.Debug("removing empty user message after filtering orphan tool_results")
650 }
651
652 // Reset for next iteration - user message "consumes" the previous tool_uses
653 prevAssistantToolUseIDs = nil
654 } else {
655 newMessages = append(newMessages, msg)
656 }
657 }
658
659 if totalInserted > 0 || totalRemoved > 0 {
660 req.Messages = newMessages
661 if totalInserted > 0 {
662 l.logger.Debug("inserted missing tool results", "count", totalInserted)
663 }
664 if totalRemoved > 0 {
665 l.logger.Debug("removed orphan tool results", "count", totalRemoved)
666 }
667 }
668}
669
670// isRetryableError checks if an error is transient and should be retried.
671// This includes EOF errors (connection closed unexpectedly) and similar network issues.
672func isRetryableError(err error) bool {
673 if err == nil {
674 return false
675 }
676 // Check for io.EOF and io.ErrUnexpectedEOF
677 if err == io.EOF || err == io.ErrUnexpectedEOF {
678 return true
679 }
680 // Check error message for common retryable patterns
681 errStr := err.Error()
682 retryablePatterns := []string{
683 "EOF",
684 "connection reset",
685 "connection refused",
686 "no such host",
687 "network is unreachable",
688 "i/o timeout",
689 }
690 for _, pattern := range retryablePatterns {
691 if strings.Contains(errStr, pattern) {
692 return true
693 }
694 }
695 return false
696}