diff --git a/cmd/go2ts.go b/cmd/go2ts.go index a27e8776b487bd269f0b5393a91b1c0dcb1bed65..da90c50e3a9baea5655eeb482c55302f60e5af5b 100644 --- a/cmd/go2ts.go +++ b/cmd/go2ts.go @@ -66,6 +66,7 @@ func TS() *go2ts.Go2TS { generator.AddMultiple( apiMessageForTS{}, streamResponseForTS{}, + conversationWithStateForTS{}, ) // Generate clean nominal types @@ -87,8 +88,24 @@ type apiMessageForTS struct { EndOfTurn *bool `json:"end_of_turn,omitempty"` } +type conversationStateForTS struct { + ConversationID string `json:"conversation_id"` + Working bool `json:"working"` +} + +type conversationWithStateForTS struct { + ConversationID string `json:"conversation_id"` + Slug *string `json:"slug"` + UserInitiated bool `json:"user_initiated"` + CreatedAt string `json:"created_at"` + UpdatedAt string `json:"updated_at"` + Cwd *string `json:"cwd"` + Archived bool `json:"archived"` + Working bool `json:"working"` +} + type streamResponseForTS struct { - Messages []apiMessageForTS `json:"messages"` - Conversation generated.Conversation `json:"conversation"` - AgentWorking *bool `json:"agent_working,omitempty"` + Messages []apiMessageForTS `json:"messages"` + Conversation generated.Conversation `json:"conversation"` + ConversationState *conversationStateForTS `json:"conversation_state,omitempty"` } diff --git a/server/agent_working_test.go b/server/agent_working_test.go deleted file mode 100644 index 15a9bc70b212b834cc1bd80f934e9ebe5cc42ab0..0000000000000000000000000000000000000000 --- a/server/agent_working_test.go +++ /dev/null @@ -1,104 +0,0 @@ -package server - -import ( - "fmt" - "testing" - - "shelley.exe.dev/db" -) - -func TestAgentWorking(t *testing.T) { - tests := []struct { - name string - messages []APIMessage - want bool - }{ - { - name: "empty messages", - messages: []APIMessage{}, - want: false, - }, - { - name: "agent with end_of_turn true", - messages: []APIMessage{ - {Type: string(db.MessageTypeAgent), EndOfTurn: truePtr}, - }, - want: false, - }, - { - name: "agent with end_of_turn false", - messages: []APIMessage{ - {Type: string(db.MessageTypeAgent), EndOfTurn: falsePtr}, - }, - want: true, - }, - { - name: "agent with end_of_turn nil", - messages: []APIMessage{ - {Type: string(db.MessageTypeAgent), EndOfTurn: nil}, - }, - want: true, - }, - { - name: "error message", - messages: []APIMessage{ - {Type: string(db.MessageTypeError)}, - }, - want: false, - }, - { - name: "agent end_of_turn then tool message means working", - messages: []APIMessage{ - {Type: string(db.MessageTypeAgent), EndOfTurn: truePtr}, - {Type: string(db.MessageTypeTool)}, - }, - want: true, - }, - { - name: "gitinfo after agent end_of_turn should NOT indicate working", - messages: []APIMessage{ - {Type: string(db.MessageTypeAgent), EndOfTurn: truePtr}, - {Type: string(db.MessageTypeGitInfo)}, - }, - want: false, - }, - { - name: "multiple gitinfo after agent end_of_turn should NOT indicate working", - messages: []APIMessage{ - {Type: string(db.MessageTypeAgent), EndOfTurn: truePtr}, - {Type: string(db.MessageTypeGitInfo)}, - {Type: string(db.MessageTypeGitInfo)}, - }, - want: false, - }, - { - name: "gitinfo after agent not end_of_turn should indicate working", - messages: []APIMessage{ - {Type: string(db.MessageTypeAgent), EndOfTurn: falsePtr}, - {Type: string(db.MessageTypeGitInfo)}, - }, - want: true, - }, - { - name: "only gitinfo messages", - messages: []APIMessage{ - {Type: string(db.MessageTypeGitInfo)}, - {Type: string(db.MessageTypeGitInfo)}, - }, - want: false, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - got := agentWorking(tt.messages) - if got == nil || *got != tt.want { - gotVal := "nil" - if got != nil { - gotVal = fmt.Sprintf("%v", *got) - } - t.Errorf("agentWorking() = %v, want %v", gotVal, tt.want) - } - }) - } -} diff --git a/server/conversation_state_test.go b/server/conversation_state_test.go new file mode 100644 index 0000000000000000000000000000000000000000..2ac5f24f004b97fc3cc463c925f51d9bd2364e4f --- /dev/null +++ b/server/conversation_state_test.go @@ -0,0 +1,147 @@ +package server + +import ( + "context" + "encoding/json" + "net/http" + "net/http/httptest" + "strings" + "testing" + "time" + + "shelley.exe.dev/claudetool" + "shelley.exe.dev/db" + "shelley.exe.dev/llm" + "shelley.exe.dev/loop" +) + +// responseRecorderWithClose wraps httptest.ResponseRecorder to support CloseNotify +type responseRecorderWithClose struct { + *httptest.ResponseRecorder + closeNotify chan bool +} + +func newResponseRecorderWithClose() *responseRecorderWithClose { + return &responseRecorderWithClose{ + ResponseRecorder: httptest.NewRecorder(), + closeNotify: make(chan bool, 1), + } +} + +func (r *responseRecorderWithClose) CloseNotify() <-chan bool { + return r.closeNotify +} + +func (r *responseRecorderWithClose) Close() { + select { + case r.closeNotify <- true: + default: + } +} + +// TestConversationStateAfterServerRestart verifies that when a conversation is +// loaded after a server restart (new manager created), the agent is correctly +// reported as not working since the loop isn't running. +func TestConversationStateAfterServerRestart(t *testing.T) { + database, cleanup := setupTestDB(t) + defer cleanup() + + ctx := context.Background() + + // Create a conversation with some messages (simulating previous activity) + conv, err := database.CreateConversation(ctx, nil, true, nil) + if err != nil { + t.Fatalf("Failed to create conversation: %v", err) + } + + // Add a user message + userMsg := llm.Message{ + Role: llm.MessageRoleUser, + Content: []llm.Content{{Type: llm.ContentTypeText, Text: "Hello"}}, + } + _, err = database.CreateMessage(ctx, db.CreateMessageParams{ + ConversationID: conv.ConversationID, + Type: db.MessageTypeUser, + LLMData: userMsg, + }) + if err != nil { + t.Fatalf("Failed to create user message: %v", err) + } + + // Add an agent message (without end_of_turn to simulate mid-conversation) + agentMsg := llm.Message{ + Role: llm.MessageRoleAssistant, + Content: []llm.Content{{Type: llm.ContentTypeText, Text: "Hi there!"}}, + EndOfTurn: false, + } + _, err = database.CreateMessage(ctx, db.CreateMessageParams{ + ConversationID: conv.ConversationID, + Type: db.MessageTypeAgent, + LLMData: agentMsg, + }) + if err != nil { + t.Fatalf("Failed to create agent message: %v", err) + } + + // Create a NEW server (simulating server restart - no active managers) + predictableService := loop.NewPredictableService() + llmManager := &testLLMManager{service: predictableService} + toolSetConfig := claudetool.ToolSetConfig{EnableBrowser: false} + server := NewServer(database, llmManager, toolSetConfig, nil, true, "", "predictable", "", nil) + + mux := http.NewServeMux() + server.RegisterRoutes(mux) + + // Make a streaming request with a context that cancels after we read the first message + ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) + defer cancel() + + req := httptest.NewRequest("GET", "/api/conversation/"+conv.ConversationID+"/stream", nil).WithContext(ctx) + req.Header.Set("Accept", "text/event-stream") + + w := newResponseRecorderWithClose() + + // Run handler in goroutine and close connection after getting first response + done := make(chan struct{}) + go func() { + defer close(done) + mux.ServeHTTP(w, req) + }() + + // Wait for some data or timeout + time.Sleep(500 * time.Millisecond) + w.Close() + cancel() + + // Wait for handler to finish + <-done + + // Parse the first SSE message + body := w.Body.String() + if !strings.HasPrefix(body, "data: ") { + t.Fatalf("Expected SSE data, got: %s", body) + } + + jsonData := strings.TrimPrefix(strings.Split(body, "\n")[0], "data: ") + var response StreamResponse + if err := json.Unmarshal([]byte(jsonData), &response); err != nil { + t.Fatalf("Failed to parse response: %v", err) + } + + // Verify conversation state shows agent is NOT working + // (because after server restart, no loop is running) + if response.ConversationState == nil { + t.Fatal("Expected ConversationState in response") + } + if response.ConversationState.ConversationID != conv.ConversationID { + t.Errorf("Expected ConversationID %s, got %s", conv.ConversationID, response.ConversationState.ConversationID) + } + if response.ConversationState.Working { + t.Error("Expected Working=false after server restart (no active loop)") + } + + // Verify messages were loaded + if len(response.Messages) != 2 { + t.Errorf("Expected 2 messages, got %d", len(response.Messages)) + } +} diff --git a/server/convo.go b/server/convo.go index f4c528d6c5096f8e14caba6b06aa04538bc3169a..c26d4983c12e0571ca6c5a2f72c75ec4b5b1862b 100644 --- a/server/convo.go +++ b/server/convo.go @@ -41,10 +41,18 @@ type ConversationManager struct { hydrated bool hasConversationEvents bool cwd string // working directory for tools + + // agentWorking tracks whether the agent is currently working. + // This is explicitly managed and broadcast to subscribers when it changes. + agentWorking bool + + // onStateChange is called when the conversation state changes. + // This allows the server to broadcast state changes to all subscribers. + onStateChange func(state ConversationState) } // NewConversationManager constructs a manager with dependencies but defers hydration until needed. -func NewConversationManager(conversationID string, database *db.DB, baseLogger *slog.Logger, toolSetConfig claudetool.ToolSetConfig, recordMessage loop.MessageRecordFunc) *ConversationManager { +func NewConversationManager(conversationID string, database *db.DB, baseLogger *slog.Logger, toolSetConfig claudetool.ToolSetConfig, recordMessage loop.MessageRecordFunc, onStateChange func(ConversationState)) *ConversationManager { logger := baseLogger if logger == nil { logger = slog.Default() @@ -59,7 +67,36 @@ func NewConversationManager(conversationID string, database *db.DB, baseLogger * logger: logger, toolSetConfig: toolSetConfig, subpub: subpub.New[StreamResponse](), + onStateChange: onStateChange, + } +} + +// SetAgentWorking updates the agent working state and notifies the server to broadcast. +func (cm *ConversationManager) SetAgentWorking(working bool) { + cm.mu.Lock() + if cm.agentWorking == working { + cm.mu.Unlock() + return } + cm.agentWorking = working + onStateChange := cm.onStateChange + convID := cm.conversationID + cm.mu.Unlock() + + cm.logger.Debug("agent working state changed", "working", working) + if onStateChange != nil { + onStateChange(ConversationState{ + ConversationID: convID, + Working: working, + }) + } +} + +// IsAgentWorking returns the current agent working state. +func (cm *ConversationManager) IsAgentWorking() bool { + cm.mu.Lock() + defer cm.mu.Unlock() + return cm.agentWorking } // Hydrate loads conversation state from the database, generating a system prompt if missing. @@ -158,6 +195,9 @@ func (cm *ConversationManager) AcceptUserMessage(ctx context.Context, service ll loopInstance.QueueUserMessage(message) + // Mark agent as working - we just queued work for the loop + cm.SetAgentWorking(true) + return isFirst, nil } @@ -480,6 +520,9 @@ func (cm *ConversationManager) CancelConversation(ctx context.Context) error { return fmt.Errorf("failed to record end turn message: %w", err) } + // Mark agent as not working + cm.SetAgentWorking(false) + cm.mu.Lock() cm.loopCancel = nil cm.loopCtx = nil @@ -557,7 +600,6 @@ func (cm *ConversationManager) notifyGitStateChange(ctx context.Context, msg *ge streamData := StreamResponse{ Messages: apiMessages, Conversation: conversation, - AgentWorking: falsePtr, // Gitinfo is recorded at end of turn, agent is done } cm.subpub.Publish(msg.SequenceID, streamData) } diff --git a/server/handlers.go b/server/handlers.go index b4894df100a0883bab64c58a1d3cfb06d9898bed..397a65419acfcd3018434ac86916f46f59d23b12 100644 --- a/server/handlers.go +++ b/server/handlers.go @@ -520,8 +520,20 @@ func (s *Server) handleConversations(w http.ResponseWriter, r *http.Request) { return } + // Get working states for all active conversations + workingStates := s.getWorkingConversations() + + // Build response with working state included + result := make([]ConversationWithState, len(conversations)) + for i, conv := range conversations { + result[i] = ConversationWithState{ + Conversation: conv, + Working: workingStates[conv.ConversationID], + } + } + w.Header().Set("Content-Type", "application/json") - json.NewEncoder(w).Encode(conversations) + json.NewEncoder(w).Encode(result) } // conversationMux returns a mux for /api/conversation//* routes @@ -593,9 +605,9 @@ func (s *Server) handleGetConversation(w http.ResponseWriter, r *http.Request, c w.Header().Set("Content-Type", "application/json") apiMessages := toAPIMessages(messages) json.NewEncoder(w).Encode(StreamResponse{ - Messages: apiMessages, - Conversation: conversation, - AgentWorking: agentWorking(apiMessages), + Messages: apiMessages, + Conversation: conversation, + // ConversationState is sent via the streaming endpoint, not on initial load ContextWindowSize: calculateContextWindowSize(apiMessages), }) } @@ -863,25 +875,28 @@ func (s *Server) handleStreamConversation(w http.ResponseWriter, r *http.Request return } - // Send current messages and conversation data + // Get or create conversation manager to access working state + manager, err := s.getOrCreateConversationManager(ctx, conversationID) + if err != nil { + s.logger.Error("Failed to get conversation manager", "conversationID", conversationID, "error", err) + return + } + + // Send current messages, conversation data, and conversation state apiMessages := toAPIMessages(messages) streamData := StreamResponse{ - Messages: apiMessages, - Conversation: conversation, - AgentWorking: agentWorking(apiMessages), + Messages: apiMessages, + Conversation: conversation, + ConversationState: &ConversationState{ + ConversationID: conversationID, + Working: manager.IsAgentWorking(), + }, ContextWindowSize: calculateContextWindowSize(apiMessages), } data, _ := json.Marshal(streamData) fmt.Fprintf(w, "data: %s\n\n", data) w.(http.Flusher).Flush() - // Get or create conversation manager - manager, err := s.getOrCreateConversationManager(ctx, conversationID) - if err != nil { - s.logger.Error("Failed to get conversation manager", "conversationID", conversationID, "error", err) - return - } - // Subscribe to new messages after the last one we sent last := int64(-1) if len(messages) > 0 { diff --git a/server/server.go b/server/server.go index bcaa58fb96dafbbdd546848986841fc82d3eae4d..f64db957136eb02ba5d8eff08198e8e260a4635b 100644 --- a/server/server.go +++ b/server/server.go @@ -41,11 +41,24 @@ type APIMessage struct { EndOfTurn *bool `json:"end_of_turn,omitempty"` } +// ConversationState represents the current state of a conversation. +// This is broadcast to all subscribers whenever the state changes. +type ConversationState struct { + ConversationID string `json:"conversation_id"` + Working bool `json:"working"` +} + +// ConversationWithState combines a conversation with its working state. +type ConversationWithState struct { + generated.Conversation + Working bool `json:"working"` +} + // StreamResponse represents the response format for conversation streaming type StreamResponse struct { Messages []APIMessage `json:"messages"` Conversation generated.Conversation `json:"conversation"` - AgentWorking *bool `json:"agent_working,omitempty"` + ConversationState *ConversationState `json:"conversation_state,omitempty"` ContextWindowSize uint64 `json:"context_window_size,omitempty"` // ConversationListUpdate is set when another conversation in the list changed ConversationListUpdate *ConversationListUpdate `json:"conversation_list_update,omitempty"` @@ -149,69 +162,12 @@ func calculateContextWindowSize(messages []APIMessage) uint64 { return 0 } -var ( - truePtr = ptr(true) - falsePtr = ptr(false) -) - -func ptr[T any](v T) *T { return &v } - -func agentWorking(messages []APIMessage) *bool { - if len(messages) == 0 { - return falsePtr - } - - // Find the last non-gitinfo message (gitinfo messages are passive notifications) - lastIdx := len(messages) - 1 - for lastIdx >= 0 && messages[lastIdx].Type == string(db.MessageTypeGitInfo) { - lastIdx-- - } - if lastIdx < 0 { - return falsePtr - } - last := messages[lastIdx] - - // If the last message is an error, agent is not working - if last.Type == string(db.MessageTypeError) { - return falsePtr - } - - if last.Type == string(db.MessageTypeAgent) { - if last.EndOfTurn == nil { - return truePtr - } - if *last.EndOfTurn { - return falsePtr - } - return truePtr - } - - for i := lastIdx; i >= 0; i-- { - msg := messages[i] - if msg.Type != string(db.MessageTypeAgent) { - continue - } - // Agent ended turn, but newer non-agent messages exist, so agent is working again. - return truePtr - } - - // No agent message found yet but conversation has activity, assume agent is working. - return truePtr -} - -// isEndOfTurn checks if a database message represents end of turn -func isEndOfTurn(msg *generated.Message) bool { +// isAgentEndOfTurn checks if a message is an agent message with end_of_turn=true. +// This indicates the agent loop has finished processing. +func isAgentEndOfTurn(msg *generated.Message) bool { if msg == nil { return false } - // Error messages end the turn - if msg.Type == string(db.MessageTypeError) { - return true - } - // Gitinfo messages always come at end of turn (after a commit) - if msg.Type == string(db.MessageTypeGitInfo) { - return true - } // Only agent messages can have end_of_turn if msg.Type != string(db.MessageTypeAgent) { return false @@ -480,7 +436,11 @@ func (s *Server) getOrCreateConversationManager(ctx context.Context, conversatio return s.recordMessage(ctx, conversationID, message, usage) } - manager := NewConversationManager(conversationID, s.db, s.logger, s.toolSetConfig, recordMessage) + onStateChange := func(state ConversationState) { + s.publishConversationState(state) + } + + manager := NewConversationManager(conversationID, s.db, s.logger, s.toolSetConfig, recordMessage, onStateChange) if err := manager.Hydrate(ctx); err != nil { return nil, err } @@ -682,15 +642,15 @@ func (s *Server) notifySubscribersNewMessage(ctx context.Context, conversationID // Convert the single new message to API format apiMessages := toAPIMessages([]generated.Message{*newMsg}) - // Publish only the new message - agentWorking := falsePtr - if !isEndOfTurn(newMsg) { - agentWorking = truePtr + // Update agent working state based on message type + if isAgentEndOfTurn(newMsg) { + manager.SetAgentWorking(false) } + + // Publish only the new message streamData := StreamResponse{ Messages: apiMessages, Conversation: conversation, - AgentWorking: agentWorking, // ContextWindowSize: 0 for messages without usage data (user/tool messages). // With omitempty, 0 is omitted from JSON, so the UI keeps its cached value. // Only agent messages have usage data, so context window updates when they arrive. @@ -721,6 +681,35 @@ func (s *Server) publishConversationListUpdate(update ConversationListUpdate) { } } +// publishConversationState broadcasts a conversation state update to ALL active +// conversation streams. This allows clients to see the working state of other conversations. +func (s *Server) publishConversationState(state ConversationState) { + s.mu.Lock() + defer s.mu.Unlock() + + // Broadcast to all active conversation managers + for _, manager := range s.activeConversations { + streamData := StreamResponse{ + ConversationState: &state, + } + manager.subpub.Broadcast(streamData) + } +} + +// getWorkingConversations returns a map of conversation IDs that are currently working. +func (s *Server) getWorkingConversations() map[string]bool { + s.mu.Lock() + defer s.mu.Unlock() + + working := make(map[string]bool) + for id, manager := range s.activeConversations { + if manager.IsAgentWorking() { + working[id] = true + } + } + return working +} + // Cleanup removes inactive conversation managers func (s *Server) Cleanup() { s.mu.Lock() diff --git a/test/server_test.go b/test/server_test.go index f2eb2cb9db795ab90648dc2a329a4a60d7bbc251..583f21b64e18233a07060702d76538c424037e63 100644 --- a/test/server_test.go +++ b/test/server_test.go @@ -644,7 +644,8 @@ func TestSSEIncrementalUpdates(t *testing.T) { defer client1.Body.Close() // Read initial response from client1 (should contain the first message) - buf1 := make([]byte, 2048) + // Buffer must be large enough to hold the full response including system prompt + buf1 := make([]byte, 32768) n1, err := client1.Body.Read(buf1) if err != nil && err != io.EOF { t.Fatalf("Failed to read from client1: %v", err) @@ -678,7 +679,7 @@ func TestSSEIncrementalUpdates(t *testing.T) { defer client2.Body.Close() // Read response from client2 (should contain both messages since it's a new client) - buf2 := make([]byte, 2048) + buf2 := make([]byte, 32768) n2, err := client2.Body.Read(buf2) if err != nil && err != io.EOF { t.Fatalf("Failed to read from client2: %v", err) diff --git a/ui/src/App.tsx b/ui/src/App.tsx index 7dbe0a29d69059799a39c9b51e760291bcdcbeef..9c864bdb9ef3eda78ead59add8daed13a4f6994b 100644 --- a/ui/src/App.tsx +++ b/ui/src/App.tsx @@ -2,7 +2,7 @@ import React, { useState, useEffect, useCallback, useRef } from "react"; import ChatInterface from "./components/ChatInterface"; import ConversationDrawer from "./components/ConversationDrawer"; import CommandPalette from "./components/CommandPalette"; -import { Conversation, ConversationListUpdate } from "./types"; +import { Conversation, ConversationWithState, ConversationListUpdate } from "./types"; import { api } from "./services/api"; // Check if a slug is a generated ID (format: cXXXX where X is alphanumeric) @@ -59,7 +59,7 @@ function updatePageTitle(conversation: Conversation | undefined) { } function App() { - const [conversations, setConversations] = useState([]); + const [conversations, setConversations] = useState([]); const [currentConversationId, setCurrentConversationId] = useState(null); const [drawerOpen, setDrawerOpen] = useState(false); const [drawerCollapsed, setDrawerCollapsed] = useState(false); @@ -127,13 +127,17 @@ function App() { ); if (existingIndex >= 0) { - // Update existing conversation in place (don't re-sort to avoid distracting jumps) + // Update existing conversation in place, preserving working state + // (working state is updated separately via conversation_state) const updated = [...prev]; - updated[existingIndex] = update.conversation!; + updated[existingIndex] = { + ...update.conversation!, + working: prev[existingIndex].working, + }; return updated; } else { - // Add new conversation at the top - return [update.conversation!, ...prev]; + // Add new conversation at the top (not working by default) + return [{ ...update.conversation!, working: false }, ...prev]; } }); } else if (update.type === "delete" && update.conversation_id) { @@ -141,6 +145,20 @@ function App() { } }, []); + // Handle conversation state updates (working state changes) + const handleConversationStateUpdate = useCallback( + (state: { conversation_id: string; working: boolean }) => { + setConversations((prev) => + prev.map((conv) => + conv.conversation_id === state.conversation_id + ? { ...conv, working: state.working } + : conv, + ), + ); + }, + [], + ); + // Update page title and URL when conversation changes useEffect(() => { const currentConv = conversations.find( @@ -193,7 +211,9 @@ function App() { const updateConversation = (updatedConversation: Conversation) => { setConversations((prev) => prev.map((conv) => - conv.conversation_id === updatedConversation.conversation_id ? updatedConversation : conv, + conv.conversation_id === updatedConversation.conversation_id + ? { ...updatedConversation, working: conv.working } + : conv, ), ); }; @@ -208,14 +228,18 @@ function App() { }; const handleConversationUnarchived = (conversation: Conversation) => { - // Add the unarchived conversation back to the list - setConversations((prev) => [conversation, ...prev]); + // Add the unarchived conversation back to the list (not working by default) + setConversations((prev) => [{ ...conversation, working: false }, ...prev]); }; const handleConversationRenamed = (conversation: Conversation) => { - // Update the conversation in the list with the new slug + // Update the conversation in the list with the new slug, preserving working state setConversations((prev) => - prev.map((c) => (c.conversation_id === conversation.conversation_id ? conversation : c)), + prev.map((c) => + c.conversation_id === conversation.conversation_id + ? { ...conversation, working: c.working } + : c, + ), ); }; @@ -294,6 +318,7 @@ function App() { currentConversation={currentConversation} onConversationUpdate={updateConversation} onConversationListUpdate={handleConversationListUpdate} + onConversationStateUpdate={handleConversationStateUpdate} onFirstMessage={handleFirstMessage} mostRecentCwd={mostRecentCwd} isDrawerCollapsed={drawerCollapsed} diff --git a/ui/src/components/ChatInterface.tsx b/ui/src/components/ChatInterface.tsx index af7ad6cafa7cb16934bbaebda068050165d2c54f..a193b17e66099a263bd40dcc8fc902c3c3fca461 100644 --- a/ui/src/components/ChatInterface.tsx +++ b/ui/src/components/ChatInterface.tsx @@ -353,6 +353,11 @@ function AnimatedWorkingStatus() { ); } +interface ConversationStateUpdate { + conversation_id: string; + working: boolean; +} + interface ChatInterfaceProps { conversationId: string | null; onOpenDrawer: () => void; @@ -360,6 +365,7 @@ interface ChatInterfaceProps { currentConversation?: Conversation; onConversationUpdate?: (conversation: Conversation) => void; onConversationListUpdate?: (update: ConversationListUpdate) => void; + onConversationStateUpdate?: (state: ConversationStateUpdate) => void; onFirstMessage?: (message: string, model: string, cwd?: string) => Promise; mostRecentCwd?: string | null; isDrawerCollapsed?: boolean; @@ -374,6 +380,7 @@ function ChatInterface({ currentConversation, onConversationUpdate, onConversationListUpdate, + onConversationStateUpdate, onFirstMessage, mostRecentCwd, isDrawerCollapsed, @@ -578,7 +585,8 @@ function ChatInterface({ setError(null); const response = await api.getConversation(conversationId); setMessages(response.messages ?? []); - setAgentWorking(Boolean(response.agent_working)); + // ConversationState is sent via the streaming endpoint, not on initial load + // We don't update agentWorking here - the stream will provide the current state // Always update context window size when loading a conversation. // If omitted from response (due to omitempty when 0), default to 0. setContextWindowSize(response.context_window_size ?? 0); @@ -638,8 +646,16 @@ function ChatInterface({ onConversationListUpdate(streamResponse.conversation_list_update); } - if (typeof streamResponse.agent_working === "boolean") { - setAgentWorking(streamResponse.agent_working); + // Handle conversation state updates (explicit from server) + if (streamResponse.conversation_state) { + // Update the conversations list with new working state + if (onConversationStateUpdate) { + onConversationStateUpdate(streamResponse.conversation_state); + } + // Update local state if this is for our conversation + if (streamResponse.conversation_state.conversation_id === conversationId) { + setAgentWorking(streamResponse.conversation_state.working); + } } if (typeof streamResponse.context_window_size === "number") { diff --git a/ui/src/components/ConversationDrawer.tsx b/ui/src/components/ConversationDrawer.tsx index 4f16da4e605be88f2c7f433aa45e783b8d21b1b7..b16194deb2a4f603f90f31f42bd46be746ad9afc 100644 --- a/ui/src/components/ConversationDrawer.tsx +++ b/ui/src/components/ConversationDrawer.tsx @@ -1,5 +1,5 @@ import React, { useState, useEffect } from "react"; -import { Conversation } from "../types"; +import { Conversation, ConversationWithState } from "../types"; import { api } from "../services/api"; interface ConversationDrawerProps { @@ -7,7 +7,7 @@ interface ConversationDrawerProps { isCollapsed: boolean; onClose: () => void; onToggleCollapse: () => void; - conversations: Conversation[]; + conversations: ConversationWithState[]; currentConversationId: string | null; onSelectConversation: (id: string) => void; onNewConversation: () => void; @@ -279,33 +279,53 @@ function ConversationDrawer({ style={{ cursor: showArchived ? "default" : "pointer" }} >
- {editingId === conversation.conversation_id ? ( - setEditingSlug(e.target.value)} - onBlur={() => handleRename(conversation.conversation_id)} - onKeyDown={(e) => handleRenameKeyDown(e, conversation.conversation_id)} - onClick={(e) => e.stopPropagation()} - autoFocus - className="conversation-title" - style={{ - width: "100%", - background: "transparent", - border: "none", - borderBottom: "1px solid var(--text-secondary)", - outline: "none", - padding: 0, - font: "inherit", - color: "inherit", - }} - /> - ) : ( -
- {getConversationPreview(conversation)} +
+
+ {editingId === conversation.conversation_id ? ( + setEditingSlug(e.target.value)} + onBlur={() => handleRename(conversation.conversation_id)} + onKeyDown={(e) => + handleRenameKeyDown(e, conversation.conversation_id) + } + onClick={(e) => e.stopPropagation()} + autoFocus + className="conversation-title" + style={{ + width: "100%", + background: "transparent", + border: "none", + borderBottom: "1px solid var(--text-secondary)", + outline: "none", + padding: 0, + font: "inherit", + color: "inherit", + }} + /> + ) : ( +
+ {getConversationPreview(conversation)} +
+ )}
- )} + {(conversation as ConversationWithState).working && ( + + )} +
{formatDate(conversation.updated_at)} @@ -315,100 +335,102 @@ function ConversationDrawer({ {formatCwdForDisplay(conversation.cwd)} )} -
-
-
- {showArchived ? ( - <> - - + - - ) : ( - <> - +
+ )} +
+ + {showArchived && ( +
+ - + - - )} -
+ + + + + )} ); })} diff --git a/ui/src/generated-types.ts b/ui/src/generated-types.ts index a75f08fc39cfbb1c055ce469cca97c63ba23516c..6666f86fdfeec35a473473e222396f1a11c17104 100644 --- a/ui/src/generated-types.ts +++ b/ui/src/generated-types.ts @@ -37,10 +37,26 @@ export interface ApiMessageForTS { end_of_turn?: boolean | null; } +export interface ConversationStateForTS { + conversation_id: string; + working: boolean; +} + export interface StreamResponseForTS { messages: ApiMessageForTS[] | null; conversation: Conversation; - agent_working?: boolean | null; + conversation_state?: ConversationStateForTS | null; +} + +export interface ConversationWithStateForTS { + conversation_id: string; + slug: string | null; + user_initiated: boolean; + created_at: string; + updated_at: string; + cwd: string | null; + archived: boolean; + working: boolean; } export type MessageType = "user" | "agent" | "tool" | "error" | "system" | "gitinfo"; diff --git a/ui/src/services/api.ts b/ui/src/services/api.ts index 977a7da5b4fc482f0b8992aec3561e991fbce85f..b1d3ba14605db7cee119b0affecd149b48de06c7 100644 --- a/ui/src/services/api.ts +++ b/ui/src/services/api.ts @@ -1,5 +1,6 @@ import { Conversation, + ConversationWithState, StreamResponse, ChatRequest, GitDiffInfo, @@ -16,7 +17,7 @@ class ApiService { "X-Shelley-Request": "1", }; - async getConversations(): Promise { + async getConversations(): Promise { const response = await fetch(`${this.baseUrl}/conversations`); if (!response.ok) { throw new Error(`Failed to get conversations: ${response.statusText}`); @@ -24,7 +25,7 @@ class ApiService { return response.json(); } - async searchConversations(query: string): Promise { + async searchConversations(query: string): Promise { const params = new URLSearchParams({ q: query, search_content: "true", diff --git a/ui/src/types.ts b/ui/src/types.ts index 5585cc322dfa735bc35aae0ba9a7a018acd121e2..31173186805650e37d50fb78a8db81dc8834d8e8 100644 --- a/ui/src/types.ts +++ b/ui/src/types.ts @@ -1,6 +1,7 @@ // Types for Shelley UI import { Conversation as GeneratedConversation, + ConversationWithStateForTS, ApiMessageForTS, StreamResponseForTS, Usage as GeneratedUsage, @@ -9,6 +10,7 @@ import { // Re-export generated types export type Conversation = GeneratedConversation; +export type ConversationWithState = ConversationWithStateForTS; export type Usage = GeneratedUsage; export type MessageType = GeneratedMessageType;