change_dir_test.go

  1package server
  2
  3import (
  4	"bufio"
  5	"context"
  6	"encoding/json"
  7	"log/slog"
  8	"net/http"
  9	"net/http/httptest"
 10	"os"
 11	"path/filepath"
 12	"strings"
 13	"testing"
 14	"time"
 15
 16	"shelley.exe.dev/claudetool"
 17	"shelley.exe.dev/db"
 18	"shelley.exe.dev/db/generated"
 19	"shelley.exe.dev/loop"
 20)
 21
 22// TestChangeDirAffectsBash tests that change_dir updates the working directory
 23// and subsequent bash commands run in that directory.
 24func TestChangeDirAffectsBash(t *testing.T) {
 25	// Create a temp directory structure
 26	tmpDir := t.TempDir()
 27	subDir := filepath.Join(tmpDir, "subdir")
 28	if err := os.Mkdir(subDir, 0o755); err != nil {
 29		t.Fatal(err)
 30	}
 31
 32	// Create a marker file in subdir
 33	markerFile := filepath.Join(subDir, "marker.txt")
 34	if err := os.WriteFile(markerFile, []byte("found"), 0o644); err != nil {
 35		t.Fatal(err)
 36	}
 37
 38	database, cleanup := setupTestDB(t)
 39	defer cleanup()
 40
 41	predictableService := loop.NewPredictableService()
 42	llmManager := &testLLMManager{service: predictableService}
 43	logger := slog.Default()
 44
 45	// Create server with working directory set to tmpDir
 46	toolSetConfig := claudetool.ToolSetConfig{
 47		WorkingDir: tmpDir,
 48	}
 49	server := NewServer(database, llmManager, toolSetConfig, logger, true, "", "predictable", "", nil)
 50
 51	// Create conversation
 52	conversation, err := database.CreateConversation(context.Background(), nil, true, nil, nil)
 53	if err != nil {
 54		t.Fatalf("failed to create conversation: %v", err)
 55	}
 56	conversationID := conversation.ConversationID
 57
 58	// Step 1: Send change_dir command to change to subdir
 59	changeDirReq := ChatRequest{
 60		Message: "change_dir: " + subDir,
 61		Model:   "predictable",
 62	}
 63	changeDirBody, _ := json.Marshal(changeDirReq)
 64
 65	req := httptest.NewRequest("POST", "/api/conversation/"+conversationID+"/chat", strings.NewReader(string(changeDirBody)))
 66	req.Header.Set("Content-Type", "application/json")
 67	w := httptest.NewRecorder()
 68
 69	server.handleChatConversation(w, req, conversationID)
 70	if w.Code != http.StatusAccepted {
 71		t.Fatalf("expected status 202 for change_dir, got %d: %s", w.Code, w.Body.String())
 72	}
 73
 74	// Wait for change_dir to complete - look for the tool result message
 75	waitForMessageContaining(t, database, conversationID, "Changed working directory", 5*time.Second)
 76
 77	// Step 2: Now send pwd command - should show subdir
 78	pwdReq := ChatRequest{
 79		Message: "bash: pwd",
 80		Model:   "predictable",
 81	}
 82	pwdBody, _ := json.Marshal(pwdReq)
 83
 84	req2 := httptest.NewRequest("POST", "/api/conversation/"+conversationID+"/chat", strings.NewReader(string(pwdBody)))
 85	req2.Header.Set("Content-Type", "application/json")
 86	w2 := httptest.NewRecorder()
 87
 88	server.handleChatConversation(w2, req2, conversationID)
 89	if w2.Code != http.StatusAccepted {
 90		t.Fatalf("expected status 202 for bash pwd, got %d: %s", w2.Code, w2.Body.String())
 91	}
 92
 93	// Wait for bash pwd to complete - the second tool result should contain the subdir
 94	// We need to wait for 2 tool results: one from change_dir and one from pwd
 95	waitForBashResult(t, database, conversationID, subDir, 5*time.Second)
 96}
 97
 98// waitForBashResult waits for a bash tool result containing the expected text.
 99func waitForBashResult(t *testing.T, database *db.DB, conversationID, expectedText string, timeout time.Duration) {
100	t.Helper()
101	deadline := time.Now().Add(timeout)
102	for time.Now().Before(deadline) {
103		messages, err := database.ListMessages(context.Background(), conversationID)
104		if err != nil {
105			t.Fatalf("failed to get messages: %v", err)
106		}
107
108		// Look for a tool result from bash tool that contains the expected text
109		for _, msg := range messages {
110			if msg.LlmData == nil {
111				continue
112			}
113			// The tool result for bash should contain the pwd output
114			// We distinguish it from the change_dir result by looking for the newline at the end
115			// (pwd outputs the path with a newline, change_dir outputs "Changed working directory to: ...")
116			// JSON encodes newline as \n so we check for that
117			if strings.Contains(*msg.LlmData, expectedText+`\n`) {
118				return
119			}
120		}
121		time.Sleep(50 * time.Millisecond)
122	}
123
124	// Print debug info on failure
125	messages, _ := database.ListMessages(context.Background(), conversationID)
126	t.Log("Messages in conversation:")
127	for i, msg := range messages {
128		t.Logf("  Message %d: type=%s", i, msg.Type)
129		if msg.LlmData != nil {
130			t.Logf("    data: %s", truncate(*msg.LlmData, 300))
131		}
132	}
133	t.Fatalf("did not find bash result containing %q within %v", expectedText, timeout)
134}
135
136// waitForMessageContaining waits for a message containing the specified text.
137func waitForMessageContaining(t *testing.T, database *db.DB, conversationID, text string, timeout time.Duration) {
138	t.Helper()
139	deadline := time.Now().Add(timeout)
140	for time.Now().Before(deadline) {
141		messages, err := database.ListMessages(context.Background(), conversationID)
142		if err != nil {
143			t.Fatalf("failed to get messages: %v", err)
144		}
145		for _, msg := range messages {
146			if msg.LlmData != nil && strings.Contains(*msg.LlmData, text) {
147				return
148			}
149		}
150		time.Sleep(50 * time.Millisecond)
151	}
152	t.Fatalf("did not find message containing %q within %v", text, timeout)
153}
154
155// getConversationMessages retrieves all messages for a conversation.
156func getConversationMessages(database *db.DB, conversationID string) ([]generated.Message, error) {
157	return database.ListMessages(context.Background(), conversationID)
158}
159
// truncate shortens s for display. If s is at most maxLen bytes long it is
// returned unchanged; otherwise the result is a prefix of at most maxLen bytes
// followed by "...".
//
// The cut never lands inside a multi-byte UTF-8 sequence: it backs off to the
// nearest rune start at or before maxLen, so valid input always yields valid
// output. A negative maxLen is treated as 0 instead of panicking.
func truncate(s string, maxLen int) string {
	if maxLen < 0 {
		maxLen = 0
	}
	if len(s) <= maxLen {
		return s
	}

	// Ranging over a string yields the byte offset of each rune start; keep
	// the largest such offset that does not exceed maxLen.
	cut := 0
	for i := range s {
		if i > maxLen {
			break
		}
		cut = i
	}
	return s[:cut] + "..."
}
167
168// TestChangeDirBroadcastsCwdUpdate tests that change_dir broadcasts the updated cwd
169// to SSE subscribers so the UI gets the change immediately.
170func TestChangeDirBroadcastsCwdUpdate(t *testing.T) {
171	// Create a temp directory structure
172	tmpDir := t.TempDir()
173	subDir := filepath.Join(tmpDir, "subdir")
174	if err := os.Mkdir(subDir, 0o755); err != nil {
175		t.Fatal(err)
176	}
177
178	database, cleanup := setupTestDB(t)
179	defer cleanup()
180
181	predictableService := loop.NewPredictableService()
182	llmManager := &testLLMManager{service: predictableService}
183	logger := slog.Default()
184
185	// Create server with working directory set to tmpDir
186	toolSetConfig := claudetool.ToolSetConfig{
187		WorkingDir: tmpDir,
188	}
189	server := NewServer(database, llmManager, toolSetConfig, logger, true, "", "predictable", "", nil)
190
191	// Create test server
192	ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
193		if strings.HasPrefix(r.URL.Path, "/api/conversation/") {
194			parts := strings.Split(r.URL.Path, "/")
195			if len(parts) >= 4 {
196				conversationID := parts[3]
197				if len(parts) >= 5 {
198					switch parts[4] {
199					case "chat":
200						server.handleChatConversation(w, r, conversationID)
201						return
202					case "stream":
203						server.handleStreamConversation(w, r, conversationID)
204						return
205					}
206				}
207			}
208		}
209		http.NotFound(w, r)
210	}))
211	defer ts.Close()
212
213	// Create conversation with initial cwd
214	conversation, err := database.CreateConversation(context.Background(), nil, true, &tmpDir, nil)
215	if err != nil {
216		t.Fatalf("failed to create conversation: %v", err)
217	}
218	conversationID := conversation.ConversationID
219
220	// Verify initial cwd
221	if conversation.Cwd == nil || *conversation.Cwd != tmpDir {
222		t.Fatalf("expected initial cwd %q, got %v", tmpDir, conversation.Cwd)
223	}
224
225	// Connect to SSE stream
226	ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
227	defer cancel()
228
229	req, _ := http.NewRequestWithContext(ctx, "GET", ts.URL+"/api/conversation/"+conversationID+"/stream", nil)
230	req.Header.Set("Accept", "text/event-stream")
231
232	resp, err := http.DefaultClient.Do(req)
233	if err != nil {
234		t.Fatalf("failed to connect to SSE: %v", err)
235	}
236	defer resp.Body.Close()
237
238	// Channel to receive SSE events
239	events := make(chan StreamResponse, 10)
240	go func() {
241		scanner := bufio.NewScanner(resp.Body)
242		for scanner.Scan() {
243			line := scanner.Text()
244			if strings.HasPrefix(line, "data: ") {
245				data := strings.TrimPrefix(line, "data: ")
246				var sr StreamResponse
247				if err := json.Unmarshal([]byte(data), &sr); err == nil {
248					events <- sr
249				}
250			}
251		}
252	}()
253
254	// Wait for initial SSE event
255	select {
256	case <-events:
257		// Got initial event
258	case <-time.After(2 * time.Second):
259		t.Fatal("timeout waiting for initial SSE event")
260	}
261
262	// Send change_dir command
263	changeDirReq := ChatRequest{
264		Message: "change_dir: " + subDir,
265		Model:   "predictable",
266	}
267	changeDirBody, _ := json.Marshal(changeDirReq)
268
269	chatReq, _ := http.NewRequest("POST", ts.URL+"/api/conversation/"+conversationID+"/chat", strings.NewReader(string(changeDirBody)))
270	chatReq.Header.Set("Content-Type", "application/json")
271	chatResp, err := http.DefaultClient.Do(chatReq)
272	if err != nil {
273		t.Fatalf("failed to send chat: %v", err)
274	}
275	chatResp.Body.Close()
276
277	// Wait for SSE event with updated cwd
278	deadline := time.Now().Add(5 * time.Second)
279	for time.Now().Before(deadline) {
280		select {
281		case event := <-events:
282			// Check if this event has the updated cwd
283			if event.Conversation.Cwd != nil && *event.Conversation.Cwd == subDir {
284				// Success! The UI would receive this update
285				return
286			}
287		case <-time.After(100 * time.Millisecond):
288			// Continue waiting
289		}
290	}
291
292	t.Error("did not receive SSE event with updated cwd")
293}