convo.go

  1package server
  2
  3import (
  4	"context"
  5	"errors"
  6	"fmt"
  7	"log/slog"
  8	"sync"
  9	"time"
 10
 11	"shelley.exe.dev/claudetool"
 12	"shelley.exe.dev/db"
 13	"shelley.exe.dev/db/generated"
 14	"shelley.exe.dev/gitstate"
 15	"shelley.exe.dev/llm"
 16	"shelley.exe.dev/llm/llmhttp"
 17	"shelley.exe.dev/loop"
 18	"shelley.exe.dev/subpub"
 19)
 20
 21var errConversationModelMismatch = errors.New("conversation model mismatch")
 22
 23// ConversationManager manages a single active conversation
 24type ConversationManager struct {
 25	conversationID string
 26	db             *db.DB
 27	loop           *loop.Loop
 28	loopCancel     context.CancelFunc
 29	loopCtx        context.Context
 30	mu             sync.Mutex
 31	lastActivity   time.Time
 32	modelID        string
 33	history        []llm.Message
 34	system         []llm.SystemContent
 35	recordMessage  loop.MessageRecordFunc
 36	logger         *slog.Logger
 37	toolSetConfig  claudetool.ToolSetConfig
 38	toolSet        *claudetool.ToolSet // created per-conversation when loop starts
 39
 40	subpub *subpub.SubPub[StreamResponse]
 41
 42	hydrated              bool
 43	hasConversationEvents bool
 44	cwd                   string // working directory for tools
 45
 46	// agentWorking tracks whether the agent is currently working.
 47	// This is explicitly managed and broadcast to subscribers when it changes.
 48	agentWorking bool
 49
 50	// onStateChange is called when the conversation state changes.
 51	// This allows the server to broadcast state changes to all subscribers.
 52	onStateChange func(state ConversationState)
 53}
 54
 55// NewConversationManager constructs a manager with dependencies but defers hydration until needed.
 56func NewConversationManager(conversationID string, database *db.DB, baseLogger *slog.Logger, toolSetConfig claudetool.ToolSetConfig, recordMessage loop.MessageRecordFunc, onStateChange func(ConversationState)) *ConversationManager {
 57	logger := baseLogger
 58	if logger == nil {
 59		logger = slog.Default()
 60	}
 61	logger = logger.With("conversationID", conversationID)
 62
 63	return &ConversationManager{
 64		conversationID: conversationID,
 65		db:             database,
 66		lastActivity:   time.Now(),
 67		recordMessage:  recordMessage,
 68		logger:         logger,
 69		toolSetConfig:  toolSetConfig,
 70		subpub:         subpub.New[StreamResponse](),
 71		onStateChange:  onStateChange,
 72	}
 73}
 74
 75// SetAgentWorking updates the agent working state and notifies the server to broadcast.
 76func (cm *ConversationManager) SetAgentWorking(working bool) {
 77	cm.mu.Lock()
 78	if cm.agentWorking == working {
 79		cm.mu.Unlock()
 80		return
 81	}
 82	cm.agentWorking = working
 83	onStateChange := cm.onStateChange
 84	convID := cm.conversationID
 85	modelID := cm.modelID
 86	cm.mu.Unlock()
 87
 88	cm.logger.Debug("agent working state changed", "working", working)
 89	if onStateChange != nil {
 90		onStateChange(ConversationState{
 91			ConversationID: convID,
 92			Working:        working,
 93			Model:          modelID,
 94		})
 95	}
 96}
 97
 98// IsAgentWorking returns the current agent working state.
 99func (cm *ConversationManager) IsAgentWorking() bool {
100	cm.mu.Lock()
101	defer cm.mu.Unlock()
102	return cm.agentWorking
103}
104
105// GetModel returns the model ID used by this conversation.
106func (cm *ConversationManager) GetModel() string {
107	cm.mu.Lock()
108	defer cm.mu.Unlock()
109	return cm.modelID
110}
111
112// Hydrate loads conversation state from the database, generating a system prompt if missing.
113func (cm *ConversationManager) Hydrate(ctx context.Context) error {
114	cm.mu.Lock()
115	if cm.hydrated {
116		cm.lastActivity = time.Now()
117		cm.mu.Unlock()
118		return nil
119	}
120	cm.mu.Unlock()
121
122	conversation, err := cm.db.GetConversationByID(ctx, cm.conversationID)
123	if err != nil {
124		return fmt.Errorf("conversation not found: %w", err)
125	}
126
127	var messages []generated.Message
128	err = cm.db.Queries(ctx, func(q *generated.Queries) error {
129		var err error
130		// Use ListMessagesForContext to exclude messages marked as excluded_from_context
131		messages, err = q.ListMessagesForContext(ctx, cm.conversationID)
132		return err
133	})
134	if err != nil {
135		return fmt.Errorf("failed to get conversation history: %w", err)
136	}
137
138	// Load cwd from conversation if available - must happen before generating system prompt
139	// so that the system prompt includes guidance files from the context directory
140	cwd := ""
141	if conversation.Cwd != nil {
142		cwd = *conversation.Cwd
143	}
144	cm.cwd = cwd
145
146	// Load model from conversation if available
147	var modelID string
148	if conversation.Model != nil {
149		modelID = *conversation.Model
150	}
151
152	// Generate system prompt if missing:
153	// - For user-initiated conversations: full system prompt
154	// - For subagent conversations (has parent): minimal subagent prompt
155	if !hasSystemMessage(messages) {
156		var systemMsg *generated.Message
157		var err error
158		if conversation.ParentConversationID != nil {
159			// Subagent conversation - use minimal prompt
160			systemMsg, err = cm.createSubagentSystemPrompt(ctx)
161		} else if conversation.UserInitiated {
162			// User-initiated conversation - use full prompt
163			systemMsg, err = cm.createSystemPrompt(ctx)
164		}
165		if err != nil {
166			return err
167		}
168		if systemMsg != nil {
169			messages = append(messages, *systemMsg)
170		}
171	}
172
173	history, system := cm.partitionMessages(messages)
174
175	cm.mu.Lock()
176	cm.history = history
177	cm.system = system
178	cm.hasConversationEvents = len(history) > 0
179	cm.lastActivity = time.Now()
180	cm.hydrated = true
181	cm.modelID = modelID
182	cm.mu.Unlock()
183
184	if modelID != "" {
185		cm.logger.Info("Loaded model from conversation", "model", modelID)
186	}
187	cm.logSystemPromptState(system, len(messages))
188
189	return nil
190}
191
192// AcceptUserMessage enqueues a user message, ensuring the loop is ready first.
193// The message is recorded to the database immediately so it appears in the UI,
194// even if the loop is busy processing a previous request.
195func (cm *ConversationManager) AcceptUserMessage(ctx context.Context, service llm.Service, modelID string, message llm.Message) (bool, error) {
196	if service == nil {
197		return false, fmt.Errorf("llm service is required")
198	}
199
200	if err := cm.Hydrate(ctx); err != nil {
201		return false, err
202	}
203
204	if err := cm.ensureLoop(service, modelID); err != nil {
205		return false, err
206	}
207
208	cm.mu.Lock()
209	isFirst := !cm.hasConversationEvents
210	cm.hasConversationEvents = true
211	loopInstance := cm.loop
212	cm.lastActivity = time.Now()
213	recordMessage := cm.recordMessage
214	cm.mu.Unlock()
215
216	if loopInstance == nil {
217		return false, fmt.Errorf("conversation loop not initialized")
218	}
219
220	// Record the user message to the database immediately so it appears in the UI,
221	// even if the loop is busy processing a previous request
222	if recordMessage != nil {
223		if err := recordMessage(ctx, message, llm.Usage{}); err != nil {
224			cm.logger.Error("failed to record user message immediately", "error", err)
225			// Continue anyway - the loop will also try to record it
226		}
227	}
228
229	loopInstance.QueueUserMessage(message)
230
231	// Mark agent as working - we just queued work for the loop
232	cm.SetAgentWorking(true)
233
234	return isFirst, nil
235}
236
237// Touch updates last activity timestamp.
238func (cm *ConversationManager) Touch() {
239	cm.mu.Lock()
240	cm.lastActivity = time.Now()
241	cm.mu.Unlock()
242}
243
244func hasSystemMessage(messages []generated.Message) bool {
245	for _, msg := range messages {
246		if msg.Type == string(db.MessageTypeSystem) {
247			return true
248		}
249	}
250	return false
251}
252
253func (cm *ConversationManager) createSystemPrompt(ctx context.Context) (*generated.Message, error) {
254	systemPrompt, err := GenerateSystemPrompt(cm.cwd)
255	if err != nil {
256		return nil, fmt.Errorf("failed to generate system prompt: %w", err)
257	}
258
259	if systemPrompt == "" {
260		cm.logger.Info("Skipping empty system prompt generation")
261		return nil, nil
262	}
263
264	systemMessage := llm.Message{
265		Role:    llm.MessageRoleUser,
266		Content: []llm.Content{{Type: llm.ContentTypeText, Text: systemPrompt}},
267	}
268
269	created, err := cm.db.CreateMessage(ctx, db.CreateMessageParams{
270		ConversationID: cm.conversationID,
271		Type:           db.MessageTypeSystem,
272		LLMData:        systemMessage,
273		UsageData:      llm.Usage{},
274	})
275	if err != nil {
276		return nil, fmt.Errorf("failed to store system prompt: %w", err)
277	}
278
279	if err := cm.db.QueriesTx(ctx, func(q *generated.Queries) error {
280		return q.UpdateConversationTimestamp(ctx, cm.conversationID)
281	}); err != nil {
282		cm.logger.Warn("Failed to update conversation timestamp after system prompt", "error", err)
283	}
284
285	cm.logger.Info("Stored system prompt", "length", len(systemPrompt))
286	return created, nil
287}
288
289func (cm *ConversationManager) createSubagentSystemPrompt(ctx context.Context) (*generated.Message, error) {
290	systemPrompt, err := GenerateSubagentSystemPrompt(cm.cwd)
291	if err != nil {
292		return nil, fmt.Errorf("failed to generate subagent system prompt: %w", err)
293	}
294
295	if systemPrompt == "" {
296		cm.logger.Info("Skipping empty subagent system prompt generation")
297		return nil, nil
298	}
299
300	systemMessage := llm.Message{
301		Role:    llm.MessageRoleUser,
302		Content: []llm.Content{{Type: llm.ContentTypeText, Text: systemPrompt}},
303	}
304
305	created, err := cm.db.CreateMessage(ctx, db.CreateMessageParams{
306		ConversationID: cm.conversationID,
307		Type:           db.MessageTypeSystem,
308		LLMData:        systemMessage,
309		UsageData:      llm.Usage{},
310	})
311	if err != nil {
312		return nil, fmt.Errorf("failed to store subagent system prompt: %w", err)
313	}
314
315	cm.logger.Info("Stored subagent system prompt", "length", len(systemPrompt))
316	return created, nil
317}
318
319func (cm *ConversationManager) partitionMessages(messages []generated.Message) ([]llm.Message, []llm.SystemContent) {
320	var history []llm.Message
321	var system []llm.SystemContent
322
323	for _, msg := range messages {
324		// Skip gitinfo messages - they are user-visible only, not sent to LLM
325		if msg.Type == string(db.MessageTypeGitInfo) {
326			continue
327		}
328
329		// Skip error messages - they are system-generated for user visibility,
330		// but should not be sent to the LLM as they are not part of the conversation
331		if msg.Type == string(db.MessageTypeError) {
332			continue
333		}
334
335		llmMsg, err := convertToLLMMessage(msg)
336		if err != nil {
337			cm.logger.Warn("Failed to convert message to LLM format", "messageID", msg.MessageID, "error", err)
338			continue
339		}
340
341		if msg.Type == string(db.MessageTypeSystem) {
342			for _, content := range llmMsg.Content {
343				if content.Type == llm.ContentTypeText && content.Text != "" {
344					system = append(system, llm.SystemContent{Type: "text", Text: content.Text})
345				}
346			}
347			continue
348		}
349
350		history = append(history, llmMsg)
351	}
352
353	return history, system
354}
355
356func (cm *ConversationManager) logSystemPromptState(system []llm.SystemContent, messageCount int) {
357	if len(system) == 0 {
358		cm.logger.Warn("No system prompt found in database", "message_count", messageCount)
359		return
360	}
361
362	length := 0
363	for _, sys := range system {
364		length += len(sys.Text)
365	}
366	cm.logger.Info("Loaded system prompt from database", "system_items", len(system), "total_length", length)
367}
368
369func (cm *ConversationManager) ensureLoop(service llm.Service, modelID string) error {
370	cm.mu.Lock()
371	if cm.loop != nil {
372		existingModel := cm.modelID
373		cm.mu.Unlock()
374		if existingModel != "" && modelID != "" && existingModel != modelID {
375			return fmt.Errorf("%w: conversation already uses model %s; requested %s", errConversationModelMismatch, existingModel, modelID)
376		}
377		return nil
378	}
379
380	history := append([]llm.Message(nil), cm.history...)
381	system := append([]llm.SystemContent(nil), cm.system...)
382	recordMessage := cm.recordMessage
383	logger := cm.logger
384	cwd := cm.cwd
385	toolSetConfig := cm.toolSetConfig
386	conversationID := cm.conversationID
387	db := cm.db
388	cm.mu.Unlock()
389
390	// Create tools for this conversation with the conversation's working directory
391	toolSetConfig.WorkingDir = cwd
392	toolSetConfig.ModelID = modelID
393	toolSetConfig.ConversationID = conversationID
394	toolSetConfig.ParentConversationID = conversationID // For subagent tool
395	toolSetConfig.OnWorkingDirChange = func(newDir string) {
396		// Persist working directory change to database
397		if err := db.UpdateConversationCwd(context.Background(), conversationID, newDir); err != nil {
398			logger.Error("failed to persist working directory change", "error", err, "newDir", newDir)
399		}
400	}
401
402	// Create a context with the conversation ID for LLM request recording/prefix dedup
403	baseCtx := llmhttp.WithConversationID(context.Background(), conversationID)
404	processCtx, cancel := context.WithTimeout(baseCtx, 12*time.Hour)
405	toolSet := claudetool.NewToolSet(processCtx, toolSetConfig)
406
407	loopInstance := loop.NewLoop(loop.Config{
408		LLM:           service,
409		History:       history,
410		Tools:         toolSet.Tools(),
411		RecordMessage: recordMessage,
412		Logger:        logger,
413		System:        system,
414		WorkingDir:    cwd,
415		GetWorkingDir: toolSet.WorkingDir().Get,
416		OnGitStateChange: func(ctx context.Context, state *gitstate.GitState) {
417			cm.recordGitStateChange(ctx, state)
418		},
419	})
420
421	cm.mu.Lock()
422	if cm.loop != nil {
423		cm.mu.Unlock()
424		cancel()
425		toolSet.Cleanup()
426		existingModel := cm.modelID
427		if existingModel != "" && modelID != "" && existingModel != modelID {
428			return fmt.Errorf("%w: conversation already uses model %s; requested %s", errConversationModelMismatch, existingModel, modelID)
429		}
430		return nil
431	}
432	// Check if we need to persist the model (for conversations created before model column existed)
433	needsPersist := cm.modelID == "" && modelID != ""
434	cm.loop = loopInstance
435	cm.loopCancel = cancel
436	cm.loopCtx = processCtx
437	cm.modelID = modelID
438	cm.toolSet = toolSet
439	cm.history = nil
440	cm.system = nil
441	cm.mu.Unlock()
442
443	// Persist model for legacy conversations
444	if needsPersist {
445		if err := db.UpdateConversationModel(context.Background(), conversationID, modelID); err != nil {
446			logger.Error("failed to persist model for legacy conversation", "error", err)
447		}
448	}
449
450	go func() {
451		if err := loopInstance.Go(processCtx); err != nil && err != context.DeadlineExceeded && err != context.Canceled {
452			if logger != nil {
453				logger.Error("Conversation loop stopped", "error", err)
454			} else {
455				slog.Default().Error("Conversation loop stopped", "error", err)
456			}
457		}
458	}()
459
460	return nil
461}
462
463func (cm *ConversationManager) stopLoop() {
464	cm.mu.Lock()
465	cancel := cm.loopCancel
466	toolSet := cm.toolSet
467	cm.loopCancel = nil
468	cm.loopCtx = nil
469	cm.loop = nil
470	cm.modelID = ""
471	cm.toolSet = nil
472	cm.mu.Unlock()
473
474	if cancel != nil {
475		cancel()
476	}
477	if toolSet != nil {
478		toolSet.Cleanup()
479	}
480}
481
482// CancelConversation cancels the current conversation loop and records a cancelled tool result if a tool was in progress
483func (cm *ConversationManager) CancelConversation(ctx context.Context) error {
484	cm.mu.Lock()
485	loopInstance := cm.loop
486	loopCtx := cm.loopCtx
487	cancel := cm.loopCancel
488	cm.mu.Unlock()
489
490	if loopInstance == nil {
491		cm.logger.Info("No active loop to cancel")
492		return nil
493	}
494
495	cm.logger.Info("Cancelling conversation")
496
497	// Check if there's an in-progress tool call by examining the history
498	history := loopInstance.GetHistory()
499	var inProgressToolID string
500	var inProgressToolName string
501
502	// Find tool_uses that don't have corresponding tool_results.
503	// Strategy:
504	// 1. Find the last assistant message that contains tool_uses
505	// 2. Collect all tool_result IDs from user messages AFTER that assistant message
506	// 3. Find tool_uses that don't have matching results
507
508	// Step 1: Find the index of the last assistant message with tool_uses
509	lastToolUseAssistantIdx := -1
510	for i := len(history) - 1; i >= 0; i-- {
511		msg := history[i]
512		if msg.Role == llm.MessageRoleAssistant {
513			hasToolUse := false
514			for _, content := range msg.Content {
515				if content.Type == llm.ContentTypeToolUse {
516					hasToolUse = true
517					break
518				}
519			}
520			if hasToolUse {
521				lastToolUseAssistantIdx = i
522				break
523			}
524		}
525	}
526
527	if lastToolUseAssistantIdx >= 0 {
528		// Step 2: Collect all tool_result IDs from messages after the assistant message
529		toolResultIDs := make(map[string]bool)
530		for i := lastToolUseAssistantIdx + 1; i < len(history); i++ {
531			msg := history[i]
532			if msg.Role == llm.MessageRoleUser {
533				for _, content := range msg.Content {
534					if content.Type == llm.ContentTypeToolResult {
535						toolResultIDs[content.ToolUseID] = true
536					}
537				}
538			}
539		}
540
541		// Step 3: Find the first tool_use that doesn't have a result
542		assistantMsg := history[lastToolUseAssistantIdx]
543		for _, content := range assistantMsg.Content {
544			if content.Type == llm.ContentTypeToolUse {
545				if !toolResultIDs[content.ID] {
546					inProgressToolID = content.ID
547					inProgressToolName = content.ToolName
548					break
549				}
550			}
551		}
552	}
553
554	// Cancel the context
555	if cancel != nil {
556		cancel()
557	}
558
559	// Wait briefly for the loop to stop
560	if loopCtx != nil {
561		select {
562		case <-loopCtx.Done():
563		case <-time.After(100 * time.Millisecond):
564		}
565	}
566
567	// Record cancellation messages
568	if inProgressToolID != "" {
569		// If there was an in-progress tool, record a cancelled result
570		cm.logger.Info("Recording cancelled tool result", "tool_id", inProgressToolID, "tool_name", inProgressToolName)
571		cancelTime := time.Now()
572		cancelledMessage := llm.Message{
573			Role: llm.MessageRoleUser,
574			Content: []llm.Content{
575				{
576					Type:             llm.ContentTypeToolResult,
577					ToolUseID:        inProgressToolID,
578					ToolError:        true,
579					ToolResult:       []llm.Content{{Type: llm.ContentTypeText, Text: "Tool execution cancelled by user"}},
580					ToolUseStartTime: &cancelTime,
581					ToolUseEndTime:   &cancelTime,
582				},
583			},
584		}
585
586		if err := cm.recordMessage(ctx, cancelledMessage, llm.Usage{}); err != nil {
587			cm.logger.Error("Failed to record cancelled tool result", "error", err)
588			return fmt.Errorf("failed to record cancelled tool result: %w", err)
589		}
590	}
591
592	// Always record an assistant message with EndOfTurn to properly end the turn
593	// This ensures agentWorking() returns false, even if no tool was executing
594	endTurnMessage := llm.Message{
595		Role:      llm.MessageRoleAssistant,
596		Content:   []llm.Content{{Type: llm.ContentTypeText, Text: "[Operation cancelled]"}},
597		EndOfTurn: true,
598	}
599
600	if err := cm.recordMessage(ctx, endTurnMessage, llm.Usage{}); err != nil {
601		cm.logger.Error("Failed to record end turn message", "error", err)
602		return fmt.Errorf("failed to record end turn message: %w", err)
603	}
604
605	// Mark agent as not working
606	cm.SetAgentWorking(false)
607
608	cm.mu.Lock()
609	cm.loopCancel = nil
610	cm.loopCtx = nil
611	cm.loop = nil
612	cm.modelID = ""
613	// Reset hydrated so that the next AcceptUserMessage will reload history from the database
614	cm.hydrated = false
615	cm.mu.Unlock()
616
617	return nil
618}
619
// GitInfoUserData is the structured data stored in user_data for gitinfo messages.
// recordGitStateChange fills it from a *gitstate.GitState.
type GitInfoUserData struct {
	Worktree string `json:"worktree"` // copied from GitState.Worktree
	Branch   string `json:"branch"`   // copied from GitState.Branch
	Commit   string `json:"commit"`   // copied from GitState.Commit
	Subject  string `json:"subject"`  // copied from GitState.Subject
	Text     string `json:"text"`     // Human-readable description (GitState.String())
}
628
629// recordGitStateChange creates a gitinfo message when git state changes.
630// This message is visible to users in the UI but is not sent to the LLM.
631func (cm *ConversationManager) recordGitStateChange(ctx context.Context, state *gitstate.GitState) {
632	if state == nil || !state.IsRepo {
633		return
634	}
635
636	// Create a gitinfo message with the state description
637	message := llm.Message{
638		Role:    llm.MessageRoleAssistant,
639		Content: []llm.Content{{Type: llm.ContentTypeText, Text: state.String()}},
640	}
641
642	userData := GitInfoUserData{
643		Worktree: state.Worktree,
644		Branch:   state.Branch,
645		Commit:   state.Commit,
646		Subject:  state.Subject,
647		Text:     state.String(),
648	}
649
650	createdMsg, err := cm.db.CreateMessage(ctx, db.CreateMessageParams{
651		ConversationID: cm.conversationID,
652		Type:           db.MessageTypeGitInfo,
653		LLMData:        message,
654		UserData:       userData,
655		UsageData:      llm.Usage{},
656	})
657	if err != nil {
658		cm.logger.Error("Failed to record git state change", "error", err)
659		return
660	}
661
662	cm.logger.Debug("Recorded git state change", "state", state.String())
663
664	// Notify subscribers so the UI updates
665	go cm.notifyGitStateChange(context.WithoutCancel(ctx), createdMsg)
666}
667
668// notifyGitStateChange publishes a gitinfo message to subscribers.
669func (cm *ConversationManager) notifyGitStateChange(ctx context.Context, msg *generated.Message) {
670	var conversation generated.Conversation
671	err := cm.db.Queries(ctx, func(q *generated.Queries) error {
672		var err error
673		conversation, err = q.GetConversation(ctx, cm.conversationID)
674		return err
675	})
676	if err != nil {
677		cm.logger.Error("Failed to get conversation for git state notification", "error", err)
678		return
679	}
680
681	apiMessages := toAPIMessages([]generated.Message{*msg})
682	streamData := StreamResponse{
683		Messages:     apiMessages,
684		Conversation: conversation,
685	}
686	cm.subpub.Publish(msg.SequenceID, streamData)
687}