diff --git a/server/change_dir_test.go b/server/change_dir_test.go index 15647ae5149002b5b96f4b448b337f1cc5ca0e77..bee54acf5e542019df03bae584e552dcdb3acb80 100644 --- a/server/change_dir_test.go +++ b/server/change_dir_test.go @@ -1,6 +1,7 @@ package server import ( + "bufio" "context" "encoding/json" "log/slog" @@ -163,3 +164,130 @@ func truncate(s string, maxLen int) string { } return s[:maxLen] + "..." } + +// TestChangeDirBroadcastsCwdUpdate tests that change_dir broadcasts the updated cwd +// to SSE subscribers so the UI gets the change immediately. +func TestChangeDirBroadcastsCwdUpdate(t *testing.T) { + // Create a temp directory structure + tmpDir := t.TempDir() + subDir := filepath.Join(tmpDir, "subdir") + if err := os.Mkdir(subDir, 0o755); err != nil { + t.Fatal(err) + } + + database, cleanup := setupTestDB(t) + defer cleanup() + + predictableService := loop.NewPredictableService() + llmManager := &testLLMManager{service: predictableService} + logger := slog.Default() + + // Create server with working directory set to tmpDir + toolSetConfig := claudetool.ToolSetConfig{ + WorkingDir: tmpDir, + } + server := NewServer(database, llmManager, toolSetConfig, logger, true, "", "predictable", "", nil) + + // Create test server + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if strings.HasPrefix(r.URL.Path, "/api/conversation/") { + parts := strings.Split(r.URL.Path, "/") + if len(parts) >= 4 { + conversationID := parts[3] + if len(parts) >= 5 { + switch parts[4] { + case "chat": + server.handleChatConversation(w, r, conversationID) + return + case "stream": + server.handleStreamConversation(w, r, conversationID) + return + } + } + } + } + http.NotFound(w, r) + })) + defer ts.Close() + + // Create conversation with initial cwd + conversation, err := database.CreateConversation(context.Background(), nil, true, &tmpDir, nil) + if err != nil { + t.Fatalf("failed to create conversation: %v", err) + } + conversationID := conversation.ConversationID + + // Verify initial cwd + if conversation.Cwd == nil || *conversation.Cwd != tmpDir { + t.Fatalf("expected initial cwd %q, got %v", tmpDir, conversation.Cwd) + } + + // Connect to SSE stream + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + + req, _ := http.NewRequestWithContext(ctx, "GET", ts.URL+"/api/conversation/"+conversationID+"/stream", nil) + req.Header.Set("Accept", "text/event-stream") + + resp, err := http.DefaultClient.Do(req) + if err != nil { + t.Fatalf("failed to connect to SSE: %v", err) + } + defer resp.Body.Close() + + // Channel to receive SSE events + events := make(chan StreamResponse, 10) + go func() { + scanner := bufio.NewScanner(resp.Body) + for scanner.Scan() { + line := scanner.Text() + if strings.HasPrefix(line, "data: ") { + data := strings.TrimPrefix(line, "data: ") + var sr StreamResponse + if err := json.Unmarshal([]byte(data), &sr); err == nil { + events <- sr + } + } + } + }() + + // Wait for initial SSE event + select { + case <-events: + // Got initial event + case <-time.After(2 * time.Second): + t.Fatal("timeout waiting for initial SSE event") + } + + // Send change_dir command + changeDirReq := ChatRequest{ + Message: "change_dir: " + subDir, + Model: "predictable", + } + changeDirBody, _ := json.Marshal(changeDirReq) + + chatReq, _ := http.NewRequest("POST", ts.URL+"/api/conversation/"+conversationID+"/chat", strings.NewReader(string(changeDirBody))) + chatReq.Header.Set("Content-Type", "application/json") + chatResp, err := http.DefaultClient.Do(chatReq) + if err != nil { + t.Fatalf("failed to send chat: %v", err) + } + chatResp.Body.Close() + + // Wait for SSE event with updated cwd + deadline := time.Now().Add(5 * time.Second) + for time.Now().Before(deadline) { + select { + case event := <-events: + // Check if this event has the updated cwd + if event.Conversation.Cwd != nil && *event.Conversation.Cwd == subDir { + // Success! The UI would receive this update + return + } + case <-time.After(100 * time.Millisecond): + // Continue waiting + } + } + + t.Error("did not receive SSE event with updated cwd") +} diff --git a/server/convo.go b/server/convo.go index ec17267956d3fe8bbe5dd9e96bfc4ff2dff72160..0207f2f845739e7383209478b8924c0c9daa375a 100644 --- a/server/convo.go +++ b/server/convo.go @@ -396,7 +396,28 @@ func (cm *ConversationManager) ensureLoop(service llm.Service, modelID string) e // Persist working directory change to database if err := db.UpdateConversationCwd(context.Background(), conversationID, newDir); err != nil { logger.Error("failed to persist working directory change", "error", err, "newDir", newDir) + return } + + // Update local cwd + cm.mu.Lock() + cm.cwd = newDir + cm.mu.Unlock() + + // Broadcast conversation update to subscribers so UI gets the new cwd + var conv generated.Conversation + err := db.Queries(context.Background(), func(q *generated.Queries) error { + var err error + conv, err = q.GetConversation(context.Background(), conversationID) + return err + }) + if err != nil { + logger.Error("failed to get conversation for cwd broadcast", "error", err) + return + } + cm.subpub.Broadcast(StreamResponse{ + Conversation: conv, + }) } // Create a context with the conversation ID for LLM request recording/prefix dedup