convo.go

  1package server
  2
  3import (
  4	"context"
  5	"errors"
  6	"fmt"
  7	"log/slog"
  8	"sync"
  9	"time"
 10
 11	"shelley.exe.dev/claudetool"
 12	"shelley.exe.dev/db"
 13	"shelley.exe.dev/db/generated"
 14	"shelley.exe.dev/gitstate"
 15	"shelley.exe.dev/llm"
 16	"shelley.exe.dev/llm/llmhttp"
 17	"shelley.exe.dev/loop"
 18	"shelley.exe.dev/subpub"
 19)
 20
 21var errConversationModelMismatch = errors.New("conversation model mismatch")
 22
// ConversationManager manages a single active conversation.
//
// It lazily hydrates the conversation's history and system prompt from the
// database (Hydrate), owns the agent loop and its per-loop tool set while the
// conversation is active (ensureLoop, stopLoop, CancelConversation), and tracks
// whether the agent is currently working, reporting changes via onStateChange.
type ConversationManager struct {
	conversationID string
	db             *db.DB
	loop           *loop.Loop         // running agent loop; nil until ensureLoop starts one
	loopCancel     context.CancelFunc // cancels loopCtx, stopping the loop
	loopCtx        context.Context    // context the current loop runs under
	mu             sync.Mutex         // guards loop fields, history/system, modelID, flags and lastActivity
	lastActivity   time.Time          // refreshed by Hydrate, AcceptUserMessage and Touch
	modelID        string             // loaded by Hydrate, set by ensureLoop, cleared when the loop is torn down
	history        []llm.Message      // hydrated LLM history; copied into the loop by ensureLoop, then cleared
	system         []llm.SystemContent // hydrated system prompt; handed off to the loop like history
	recordMessage  loop.MessageRecordFunc
	logger         *slog.Logger
	toolSetConfig  claudetool.ToolSetConfig
	toolSet        *claudetool.ToolSet // created per-conversation when loop starts

	// subpub fans stream updates (messages, conversation metadata) out to subscribers.
	subpub *subpub.SubPub[StreamResponse]

	// hydrated is set once Hydrate has loaded state from the database;
	// CancelConversation clears it so the next use reloads history.
	hydrated bool
	// hasConversationEvents is true once hydrated history is non-empty or a user
	// message has been accepted; AcceptUserMessage uses it to report the first message.
	hasConversationEvents bool
	cwd                   string // working directory for tools

	// agentWorking tracks whether the agent is currently working.
	// This is explicitly managed and broadcast to subscribers when it changes.
	agentWorking bool

	// onStateChange is called when the conversation state changes.
	// This allows the server to broadcast state changes to all subscribers.
	onStateChange func(state ConversationState)
}
 54
 55// NewConversationManager constructs a manager with dependencies but defers hydration until needed.
 56func NewConversationManager(conversationID string, database *db.DB, baseLogger *slog.Logger, toolSetConfig claudetool.ToolSetConfig, recordMessage loop.MessageRecordFunc, onStateChange func(ConversationState)) *ConversationManager {
 57	logger := baseLogger
 58	if logger == nil {
 59		logger = slog.Default()
 60	}
 61	logger = logger.With("conversationID", conversationID)
 62
 63	return &ConversationManager{
 64		conversationID: conversationID,
 65		db:             database,
 66		lastActivity:   time.Now(),
 67		recordMessage:  recordMessage,
 68		logger:         logger,
 69		toolSetConfig:  toolSetConfig,
 70		subpub:         subpub.New[StreamResponse](),
 71		onStateChange:  onStateChange,
 72	}
 73}
 74
 75// SetAgentWorking updates the agent working state and notifies the server to broadcast.
 76func (cm *ConversationManager) SetAgentWorking(working bool) {
 77	cm.mu.Lock()
 78	if cm.agentWorking == working {
 79		cm.mu.Unlock()
 80		return
 81	}
 82	cm.agentWorking = working
 83	onStateChange := cm.onStateChange
 84	convID := cm.conversationID
 85	modelID := cm.modelID
 86	cm.mu.Unlock()
 87
 88	cm.logger.Debug("agent working state changed", "working", working)
 89	if onStateChange != nil {
 90		onStateChange(ConversationState{
 91			ConversationID: convID,
 92			Working:        working,
 93			Model:          modelID,
 94		})
 95	}
 96}
 97
 98// IsAgentWorking returns the current agent working state.
 99func (cm *ConversationManager) IsAgentWorking() bool {
100	cm.mu.Lock()
101	defer cm.mu.Unlock()
102	return cm.agentWorking
103}
104
105// GetModel returns the model ID used by this conversation.
106func (cm *ConversationManager) GetModel() string {
107	cm.mu.Lock()
108	defer cm.mu.Unlock()
109	return cm.modelID
110}
111
112// Hydrate loads conversation state from the database, generating a system prompt if missing.
113func (cm *ConversationManager) Hydrate(ctx context.Context) error {
114	cm.mu.Lock()
115	if cm.hydrated {
116		cm.lastActivity = time.Now()
117		cm.mu.Unlock()
118		return nil
119	}
120	cm.mu.Unlock()
121
122	conversation, err := cm.db.GetConversationByID(ctx, cm.conversationID)
123	if err != nil {
124		return fmt.Errorf("conversation not found: %w", err)
125	}
126
127	var messages []generated.Message
128	err = cm.db.Queries(ctx, func(q *generated.Queries) error {
129		var err error
130		// Use ListMessagesForContext to exclude messages marked as excluded_from_context
131		messages, err = q.ListMessagesForContext(ctx, cm.conversationID)
132		return err
133	})
134	if err != nil {
135		return fmt.Errorf("failed to get conversation history: %w", err)
136	}
137
138	// Load cwd from conversation if available - must happen before generating system prompt
139	// so that the system prompt includes guidance files from the context directory
140	cwd := ""
141	if conversation.Cwd != nil {
142		cwd = *conversation.Cwd
143	}
144	cm.cwd = cwd
145
146	// Load model from conversation if available
147	var modelID string
148	if conversation.Model != nil {
149		modelID = *conversation.Model
150	}
151
152	// Generate system prompt if missing:
153	// - For user-initiated conversations: full system prompt
154	// - For subagent conversations (has parent): minimal subagent prompt
155	if !hasSystemMessage(messages) {
156		var systemMsg *generated.Message
157		var err error
158		if conversation.ParentConversationID != nil {
159			// Subagent conversation - use minimal prompt
160			systemMsg, err = cm.createSubagentSystemPrompt(ctx)
161		} else if conversation.UserInitiated {
162			// User-initiated conversation - use full prompt
163			systemMsg, err = cm.createSystemPrompt(ctx)
164		}
165		if err != nil {
166			return err
167		}
168		if systemMsg != nil {
169			messages = append(messages, *systemMsg)
170		}
171	}
172
173	history, system := cm.partitionMessages(messages)
174
175	cm.mu.Lock()
176	cm.history = history
177	cm.system = system
178	cm.hasConversationEvents = len(history) > 0
179	cm.lastActivity = time.Now()
180	cm.hydrated = true
181	cm.modelID = modelID
182	cm.mu.Unlock()
183
184	if modelID != "" {
185		cm.logger.Info("Loaded model from conversation", "model", modelID)
186	}
187	cm.logSystemPromptState(system, len(messages))
188
189	return nil
190}
191
192// AcceptUserMessage enqueues a user message, ensuring the loop is ready first.
193// The message is recorded to the database immediately so it appears in the UI,
194// even if the loop is busy processing a previous request.
195func (cm *ConversationManager) AcceptUserMessage(ctx context.Context, service llm.Service, modelID string, message llm.Message) (bool, error) {
196	if service == nil {
197		return false, fmt.Errorf("llm service is required")
198	}
199
200	if err := cm.Hydrate(ctx); err != nil {
201		return false, err
202	}
203
204	if err := cm.ensureLoop(service, modelID); err != nil {
205		return false, err
206	}
207
208	cm.mu.Lock()
209	isFirst := !cm.hasConversationEvents
210	cm.hasConversationEvents = true
211	loopInstance := cm.loop
212	cm.lastActivity = time.Now()
213	recordMessage := cm.recordMessage
214	cm.mu.Unlock()
215
216	if loopInstance == nil {
217		return false, fmt.Errorf("conversation loop not initialized")
218	}
219
220	// Record the user message to the database immediately so it appears in the UI,
221	// even if the loop is busy processing a previous request
222	if recordMessage != nil {
223		if err := recordMessage(ctx, message, llm.Usage{}); err != nil {
224			cm.logger.Error("failed to record user message immediately", "error", err)
225			// Continue anyway - the loop will also try to record it
226		}
227	}
228
229	loopInstance.QueueUserMessage(message)
230
231	// Mark agent as working - we just queued work for the loop
232	cm.SetAgentWorking(true)
233
234	return isFirst, nil
235}
236
237// Touch updates last activity timestamp.
238func (cm *ConversationManager) Touch() {
239	cm.mu.Lock()
240	cm.lastActivity = time.Now()
241	cm.mu.Unlock()
242}
243
244func hasSystemMessage(messages []generated.Message) bool {
245	for _, msg := range messages {
246		if msg.Type == string(db.MessageTypeSystem) {
247			return true
248		}
249	}
250	return false
251}
252
253func (cm *ConversationManager) createSystemPrompt(ctx context.Context) (*generated.Message, error) {
254	systemPrompt, err := GenerateSystemPrompt(cm.cwd)
255	if err != nil {
256		return nil, fmt.Errorf("failed to generate system prompt: %w", err)
257	}
258
259	if systemPrompt == "" {
260		cm.logger.Info("Skipping empty system prompt generation")
261		return nil, nil
262	}
263
264	systemMessage := llm.Message{
265		Role:    llm.MessageRoleUser,
266		Content: []llm.Content{{Type: llm.ContentTypeText, Text: systemPrompt}},
267	}
268
269	created, err := cm.db.CreateMessage(ctx, db.CreateMessageParams{
270		ConversationID: cm.conversationID,
271		Type:           db.MessageTypeSystem,
272		LLMData:        systemMessage,
273		UsageData:      llm.Usage{},
274	})
275	if err != nil {
276		return nil, fmt.Errorf("failed to store system prompt: %w", err)
277	}
278
279	if err := cm.db.QueriesTx(ctx, func(q *generated.Queries) error {
280		return q.UpdateConversationTimestamp(ctx, cm.conversationID)
281	}); err != nil {
282		cm.logger.Warn("Failed to update conversation timestamp after system prompt", "error", err)
283	}
284
285	cm.logger.Info("Stored system prompt", "length", len(systemPrompt))
286	return created, nil
287}
288
289func (cm *ConversationManager) createSubagentSystemPrompt(ctx context.Context) (*generated.Message, error) {
290	systemPrompt, err := GenerateSubagentSystemPrompt(cm.cwd)
291	if err != nil {
292		return nil, fmt.Errorf("failed to generate subagent system prompt: %w", err)
293	}
294
295	if systemPrompt == "" {
296		cm.logger.Info("Skipping empty subagent system prompt generation")
297		return nil, nil
298	}
299
300	systemMessage := llm.Message{
301		Role:    llm.MessageRoleUser,
302		Content: []llm.Content{{Type: llm.ContentTypeText, Text: systemPrompt}},
303	}
304
305	created, err := cm.db.CreateMessage(ctx, db.CreateMessageParams{
306		ConversationID: cm.conversationID,
307		Type:           db.MessageTypeSystem,
308		LLMData:        systemMessage,
309		UsageData:      llm.Usage{},
310	})
311	if err != nil {
312		return nil, fmt.Errorf("failed to store subagent system prompt: %w", err)
313	}
314
315	cm.logger.Info("Stored subagent system prompt", "length", len(systemPrompt))
316	return created, nil
317}
318
319func (cm *ConversationManager) partitionMessages(messages []generated.Message) ([]llm.Message, []llm.SystemContent) {
320	var history []llm.Message
321	var system []llm.SystemContent
322
323	for _, msg := range messages {
324		// Skip gitinfo messages - they are user-visible only, not sent to LLM
325		if msg.Type == string(db.MessageTypeGitInfo) {
326			continue
327		}
328
329		// Skip error messages - they are system-generated for user visibility,
330		// but should not be sent to the LLM as they are not part of the conversation
331		if msg.Type == string(db.MessageTypeError) {
332			continue
333		}
334
335		llmMsg, err := convertToLLMMessage(msg)
336		if err != nil {
337			cm.logger.Warn("Failed to convert message to LLM format", "messageID", msg.MessageID, "error", err)
338			continue
339		}
340
341		if msg.Type == string(db.MessageTypeSystem) {
342			for _, content := range llmMsg.Content {
343				if content.Type == llm.ContentTypeText && content.Text != "" {
344					system = append(system, llm.SystemContent{Type: "text", Text: content.Text})
345				}
346			}
347			continue
348		}
349
350		history = append(history, llmMsg)
351	}
352
353	return history, system
354}
355
356func (cm *ConversationManager) logSystemPromptState(system []llm.SystemContent, messageCount int) {
357	if len(system) == 0 {
358		cm.logger.Warn("No system prompt found in database", "message_count", messageCount)
359		return
360	}
361
362	length := 0
363	for _, sys := range system {
364		length += len(sys.Text)
365	}
366	cm.logger.Info("Loaded system prompt from database", "system_items", len(system), "total_length", length)
367}
368
369func (cm *ConversationManager) ensureLoop(service llm.Service, modelID string) error {
370	cm.mu.Lock()
371	if cm.loop != nil {
372		existingModel := cm.modelID
373		cm.mu.Unlock()
374		if existingModel != "" && modelID != "" && existingModel != modelID {
375			return fmt.Errorf("%w: conversation already uses model %s; requested %s", errConversationModelMismatch, existingModel, modelID)
376		}
377		return nil
378	}
379
380	history := append([]llm.Message(nil), cm.history...)
381	system := append([]llm.SystemContent(nil), cm.system...)
382	recordMessage := cm.recordMessage
383	logger := cm.logger
384	cwd := cm.cwd
385	toolSetConfig := cm.toolSetConfig
386	conversationID := cm.conversationID
387	db := cm.db
388	cm.mu.Unlock()
389
390	// Create tools for this conversation with the conversation's working directory
391	toolSetConfig.WorkingDir = cwd
392	toolSetConfig.ModelID = modelID
393	toolSetConfig.ConversationID = conversationID
394	toolSetConfig.ParentConversationID = conversationID // For subagent tool
395	toolSetConfig.OnWorkingDirChange = func(newDir string) {
396		// Persist working directory change to database
397		if err := db.UpdateConversationCwd(context.Background(), conversationID, newDir); err != nil {
398			logger.Error("failed to persist working directory change", "error", err, "newDir", newDir)
399			return
400		}
401
402		// Update local cwd
403		cm.mu.Lock()
404		cm.cwd = newDir
405		cm.mu.Unlock()
406
407		// Broadcast conversation update to subscribers so UI gets the new cwd
408		var conv generated.Conversation
409		err := db.Queries(context.Background(), func(q *generated.Queries) error {
410			var err error
411			conv, err = q.GetConversation(context.Background(), conversationID)
412			return err
413		})
414		if err != nil {
415			logger.Error("failed to get conversation for cwd broadcast", "error", err)
416			return
417		}
418		cm.subpub.Broadcast(StreamResponse{
419			Conversation: conv,
420		})
421	}
422
423	// Create a context with the conversation ID for LLM request recording/prefix dedup
424	baseCtx := llmhttp.WithConversationID(context.Background(), conversationID)
425	processCtx, cancel := context.WithTimeout(baseCtx, 12*time.Hour)
426	toolSet := claudetool.NewToolSet(processCtx, toolSetConfig)
427
428	loopInstance := loop.NewLoop(loop.Config{
429		LLM:           service,
430		History:       history,
431		Tools:         toolSet.Tools(),
432		RecordMessage: recordMessage,
433		Logger:        logger,
434		System:        system,
435		WorkingDir:    cwd,
436		GetWorkingDir: toolSet.WorkingDir().Get,
437		OnGitStateChange: func(ctx context.Context, state *gitstate.GitState) {
438			cm.recordGitStateChange(ctx, state)
439		},
440	})
441
442	cm.mu.Lock()
443	if cm.loop != nil {
444		cm.mu.Unlock()
445		cancel()
446		toolSet.Cleanup()
447		existingModel := cm.modelID
448		if existingModel != "" && modelID != "" && existingModel != modelID {
449			return fmt.Errorf("%w: conversation already uses model %s; requested %s", errConversationModelMismatch, existingModel, modelID)
450		}
451		return nil
452	}
453	// Check if we need to persist the model (for conversations created before model column existed)
454	needsPersist := cm.modelID == "" && modelID != ""
455	cm.loop = loopInstance
456	cm.loopCancel = cancel
457	cm.loopCtx = processCtx
458	cm.modelID = modelID
459	cm.toolSet = toolSet
460	cm.history = nil
461	cm.system = nil
462	cm.mu.Unlock()
463
464	// Persist model for legacy conversations
465	if needsPersist {
466		if err := db.UpdateConversationModel(context.Background(), conversationID, modelID); err != nil {
467			logger.Error("failed to persist model for legacy conversation", "error", err)
468		}
469	}
470
471	go func() {
472		if err := loopInstance.Go(processCtx); err != nil && err != context.DeadlineExceeded && err != context.Canceled {
473			if logger != nil {
474				logger.Error("Conversation loop stopped", "error", err)
475			} else {
476				slog.Default().Error("Conversation loop stopped", "error", err)
477			}
478		}
479	}()
480
481	return nil
482}
483
484func (cm *ConversationManager) stopLoop() {
485	cm.mu.Lock()
486	cancel := cm.loopCancel
487	toolSet := cm.toolSet
488	cm.loopCancel = nil
489	cm.loopCtx = nil
490	cm.loop = nil
491	cm.modelID = ""
492	cm.toolSet = nil
493	cm.mu.Unlock()
494
495	if cancel != nil {
496		cancel()
497	}
498	if toolSet != nil {
499		toolSet.Cleanup()
500	}
501}
502
503// CancelConversation cancels the current conversation loop and records a cancelled tool result if a tool was in progress
504func (cm *ConversationManager) CancelConversation(ctx context.Context) error {
505	cm.mu.Lock()
506	loopInstance := cm.loop
507	loopCtx := cm.loopCtx
508	cancel := cm.loopCancel
509	cm.mu.Unlock()
510
511	if loopInstance == nil {
512		cm.logger.Info("No active loop to cancel")
513		return nil
514	}
515
516	cm.logger.Info("Cancelling conversation")
517
518	// Check if there's an in-progress tool call by examining the history
519	history := loopInstance.GetHistory()
520	var inProgressToolID string
521	var inProgressToolName string
522
523	// Find tool_uses that don't have corresponding tool_results.
524	// Strategy:
525	// 1. Find the last assistant message that contains tool_uses
526	// 2. Collect all tool_result IDs from user messages AFTER that assistant message
527	// 3. Find tool_uses that don't have matching results
528
529	// Step 1: Find the index of the last assistant message with tool_uses
530	lastToolUseAssistantIdx := -1
531	for i := len(history) - 1; i >= 0; i-- {
532		msg := history[i]
533		if msg.Role == llm.MessageRoleAssistant {
534			hasToolUse := false
535			for _, content := range msg.Content {
536				if content.Type == llm.ContentTypeToolUse {
537					hasToolUse = true
538					break
539				}
540			}
541			if hasToolUse {
542				lastToolUseAssistantIdx = i
543				break
544			}
545		}
546	}
547
548	if lastToolUseAssistantIdx >= 0 {
549		// Step 2: Collect all tool_result IDs from messages after the assistant message
550		toolResultIDs := make(map[string]bool)
551		for i := lastToolUseAssistantIdx + 1; i < len(history); i++ {
552			msg := history[i]
553			if msg.Role == llm.MessageRoleUser {
554				for _, content := range msg.Content {
555					if content.Type == llm.ContentTypeToolResult {
556						toolResultIDs[content.ToolUseID] = true
557					}
558				}
559			}
560		}
561
562		// Step 3: Find the first tool_use that doesn't have a result
563		assistantMsg := history[lastToolUseAssistantIdx]
564		for _, content := range assistantMsg.Content {
565			if content.Type == llm.ContentTypeToolUse {
566				if !toolResultIDs[content.ID] {
567					inProgressToolID = content.ID
568					inProgressToolName = content.ToolName
569					break
570				}
571			}
572		}
573	}
574
575	// Cancel the context
576	if cancel != nil {
577		cancel()
578	}
579
580	// Wait briefly for the loop to stop
581	if loopCtx != nil {
582		select {
583		case <-loopCtx.Done():
584		case <-time.After(100 * time.Millisecond):
585		}
586	}
587
588	// Record cancellation messages
589	if inProgressToolID != "" {
590		// If there was an in-progress tool, record a cancelled result
591		cm.logger.Info("Recording cancelled tool result", "tool_id", inProgressToolID, "tool_name", inProgressToolName)
592		cancelTime := time.Now()
593		cancelledMessage := llm.Message{
594			Role: llm.MessageRoleUser,
595			Content: []llm.Content{
596				{
597					Type:             llm.ContentTypeToolResult,
598					ToolUseID:        inProgressToolID,
599					ToolError:        true,
600					ToolResult:       []llm.Content{{Type: llm.ContentTypeText, Text: "Tool execution cancelled by user"}},
601					ToolUseStartTime: &cancelTime,
602					ToolUseEndTime:   &cancelTime,
603				},
604			},
605		}
606
607		if err := cm.recordMessage(ctx, cancelledMessage, llm.Usage{}); err != nil {
608			cm.logger.Error("Failed to record cancelled tool result", "error", err)
609			return fmt.Errorf("failed to record cancelled tool result: %w", err)
610		}
611	}
612
613	// Always record an assistant message with EndOfTurn to properly end the turn
614	// This ensures agentWorking() returns false, even if no tool was executing
615	endTurnMessage := llm.Message{
616		Role:      llm.MessageRoleAssistant,
617		Content:   []llm.Content{{Type: llm.ContentTypeText, Text: "[Operation cancelled]"}},
618		EndOfTurn: true,
619	}
620
621	if err := cm.recordMessage(ctx, endTurnMessage, llm.Usage{}); err != nil {
622		cm.logger.Error("Failed to record end turn message", "error", err)
623		return fmt.Errorf("failed to record end turn message: %w", err)
624	}
625
626	// Mark agent as not working
627	cm.SetAgentWorking(false)
628
629	cm.mu.Lock()
630	cm.loopCancel = nil
631	cm.loopCtx = nil
632	cm.loop = nil
633	cm.modelID = ""
634	// Reset hydrated so that the next AcceptUserMessage will reload history from the database
635	cm.hydrated = false
636	cm.mu.Unlock()
637
638	return nil
639}
640
// GitInfoUserData is the structured data stored in user_data for gitinfo messages.
// recordGitStateChange fills it from a gitstate.GitState snapshot.
type GitInfoUserData struct {
	Worktree string `json:"worktree"` // copied from GitState.Worktree
	Branch   string `json:"branch"`   // copied from GitState.Branch
	Commit   string `json:"commit"`   // copied from GitState.Commit
	Subject  string `json:"subject"`  // copied from GitState.Subject (presumably the commit subject line)
	Text     string `json:"text"`     // Human-readable description (GitState.String())
}
649
650// recordGitStateChange creates a gitinfo message when git state changes.
651// This message is visible to users in the UI but is not sent to the LLM.
652func (cm *ConversationManager) recordGitStateChange(ctx context.Context, state *gitstate.GitState) {
653	if state == nil || !state.IsRepo {
654		return
655	}
656
657	// Create a gitinfo message with the state description
658	message := llm.Message{
659		Role:    llm.MessageRoleAssistant,
660		Content: []llm.Content{{Type: llm.ContentTypeText, Text: state.String()}},
661	}
662
663	userData := GitInfoUserData{
664		Worktree: state.Worktree,
665		Branch:   state.Branch,
666		Commit:   state.Commit,
667		Subject:  state.Subject,
668		Text:     state.String(),
669	}
670
671	createdMsg, err := cm.db.CreateMessage(ctx, db.CreateMessageParams{
672		ConversationID: cm.conversationID,
673		Type:           db.MessageTypeGitInfo,
674		LLMData:        message,
675		UserData:       userData,
676		UsageData:      llm.Usage{},
677	})
678	if err != nil {
679		cm.logger.Error("Failed to record git state change", "error", err)
680		return
681	}
682
683	cm.logger.Debug("Recorded git state change", "state", state.String())
684
685	// Notify subscribers so the UI updates
686	go cm.notifyGitStateChange(context.WithoutCancel(ctx), createdMsg)
687}
688
689// notifyGitStateChange publishes a gitinfo message to subscribers.
690func (cm *ConversationManager) notifyGitStateChange(ctx context.Context, msg *generated.Message) {
691	var conversation generated.Conversation
692	err := cm.db.Queries(ctx, func(q *generated.Queries) error {
693		var err error
694		conversation, err = q.GetConversation(ctx, cm.conversationID)
695		return err
696	})
697	if err != nil {
698		cm.logger.Error("Failed to get conversation for git state notification", "error", err)
699		return
700	}
701
702	apiMessages := toAPIMessages([]generated.Message{*msg})
703	streamData := StreamResponse{
704		Messages:     apiMessages,
705		Conversation: conversation,
706	}
707	cm.subpub.Publish(msg.SequenceID, streamData)
708}