1package server
2
3import (
4 "bufio"
5 "context"
6 "encoding/json"
7 "log/slog"
8 "net/http"
9 "net/http/httptest"
10 "os"
11 "path/filepath"
12 "strings"
13 "testing"
14 "time"
15
16 "shelley.exe.dev/claudetool"
17 "shelley.exe.dev/db"
18 "shelley.exe.dev/db/generated"
19 "shelley.exe.dev/loop"
20)
21
22// TestChangeDirAffectsBash tests that change_dir updates the working directory
23// and subsequent bash commands run in that directory.
24func TestChangeDirAffectsBash(t *testing.T) {
25 // Create a temp directory structure
26 tmpDir := t.TempDir()
27 subDir := filepath.Join(tmpDir, "subdir")
28 if err := os.Mkdir(subDir, 0o755); err != nil {
29 t.Fatal(err)
30 }
31
32 // Create a marker file in subdir
33 markerFile := filepath.Join(subDir, "marker.txt")
34 if err := os.WriteFile(markerFile, []byte("found"), 0o644); err != nil {
35 t.Fatal(err)
36 }
37
38 database, cleanup := setupTestDB(t)
39 defer cleanup()
40
41 predictableService := loop.NewPredictableService()
42 llmManager := &testLLMManager{service: predictableService}
43 logger := slog.Default()
44
45 // Create server with working directory set to tmpDir
46 toolSetConfig := claudetool.ToolSetConfig{
47 WorkingDir: tmpDir,
48 }
49 server := NewServer(database, llmManager, toolSetConfig, logger, true, "", "predictable", "", nil)
50
51 // Create conversation
52 conversation, err := database.CreateConversation(context.Background(), nil, true, nil, nil)
53 if err != nil {
54 t.Fatalf("failed to create conversation: %v", err)
55 }
56 conversationID := conversation.ConversationID
57
58 // Step 1: Send change_dir command to change to subdir
59 changeDirReq := ChatRequest{
60 Message: "change_dir: " + subDir,
61 Model: "predictable",
62 }
63 changeDirBody, _ := json.Marshal(changeDirReq)
64
65 req := httptest.NewRequest("POST", "/api/conversation/"+conversationID+"/chat", strings.NewReader(string(changeDirBody)))
66 req.Header.Set("Content-Type", "application/json")
67 w := httptest.NewRecorder()
68
69 server.handleChatConversation(w, req, conversationID)
70 if w.Code != http.StatusAccepted {
71 t.Fatalf("expected status 202 for change_dir, got %d: %s", w.Code, w.Body.String())
72 }
73
74 // Wait for change_dir to complete - look for the tool result message
75 waitForMessageContaining(t, database, conversationID, "Changed working directory", 5*time.Second)
76
77 // Step 2: Now send pwd command - should show subdir
78 pwdReq := ChatRequest{
79 Message: "bash: pwd",
80 Model: "predictable",
81 }
82 pwdBody, _ := json.Marshal(pwdReq)
83
84 req2 := httptest.NewRequest("POST", "/api/conversation/"+conversationID+"/chat", strings.NewReader(string(pwdBody)))
85 req2.Header.Set("Content-Type", "application/json")
86 w2 := httptest.NewRecorder()
87
88 server.handleChatConversation(w2, req2, conversationID)
89 if w2.Code != http.StatusAccepted {
90 t.Fatalf("expected status 202 for bash pwd, got %d: %s", w2.Code, w2.Body.String())
91 }
92
93 // Wait for bash pwd to complete - the second tool result should contain the subdir
94 // We need to wait for 2 tool results: one from change_dir and one from pwd
95 waitForBashResult(t, database, conversationID, subDir, 5*time.Second)
96}
97
98// waitForBashResult waits for a bash tool result containing the expected text.
99func waitForBashResult(t *testing.T, database *db.DB, conversationID, expectedText string, timeout time.Duration) {
100 t.Helper()
101 deadline := time.Now().Add(timeout)
102 for time.Now().Before(deadline) {
103 messages, err := database.ListMessages(context.Background(), conversationID)
104 if err != nil {
105 t.Fatalf("failed to get messages: %v", err)
106 }
107
108 // Look for a tool result from bash tool that contains the expected text
109 for _, msg := range messages {
110 if msg.LlmData == nil {
111 continue
112 }
113 // The tool result for bash should contain the pwd output
114 // We distinguish it from the change_dir result by looking for the newline at the end
115 // (pwd outputs the path with a newline, change_dir outputs "Changed working directory to: ...")
116 // JSON encodes newline as \n so we check for that
117 if strings.Contains(*msg.LlmData, expectedText+`\n`) {
118 return
119 }
120 }
121 time.Sleep(50 * time.Millisecond)
122 }
123
124 // Print debug info on failure
125 messages, _ := database.ListMessages(context.Background(), conversationID)
126 t.Log("Messages in conversation:")
127 for i, msg := range messages {
128 t.Logf(" Message %d: type=%s", i, msg.Type)
129 if msg.LlmData != nil {
130 t.Logf(" data: %s", truncate(*msg.LlmData, 300))
131 }
132 }
133 t.Fatalf("did not find bash result containing %q within %v", expectedText, timeout)
134}
135
136// waitForMessageContaining waits for a message containing the specified text.
137func waitForMessageContaining(t *testing.T, database *db.DB, conversationID, text string, timeout time.Duration) {
138 t.Helper()
139 deadline := time.Now().Add(timeout)
140 for time.Now().Before(deadline) {
141 messages, err := database.ListMessages(context.Background(), conversationID)
142 if err != nil {
143 t.Fatalf("failed to get messages: %v", err)
144 }
145 for _, msg := range messages {
146 if msg.LlmData != nil && strings.Contains(*msg.LlmData, text) {
147 return
148 }
149 }
150 time.Sleep(50 * time.Millisecond)
151 }
152 t.Fatalf("did not find message containing %q within %v", text, timeout)
153}
154
155// getConversationMessages retrieves all messages for a conversation.
156func getConversationMessages(database *db.DB, conversationID string) ([]generated.Message, error) {
157 return database.ListMessages(context.Background(), conversationID)
158}
159
// truncate shortens s to at most maxLen bytes and appends "..." when anything
// was cut off. Strings that already fit are returned unchanged.
//
// The cut never lands inside a multi-byte UTF-8 sequence: it is moved back to
// the start of the rune that would otherwise be split, so the kept prefix can
// be a few bytes shorter than maxLen. A negative maxLen is treated as 0
// instead of panicking.
func truncate(s string, maxLen int) string {
	if maxLen < 0 {
		maxLen = 0
	}
	if len(s) <= maxLen {
		return s
	}

	// Ranging over a string yields only rune start offsets, so the last
	// offset that does not exceed maxLen is a safe place to cut.
	cut := 0
	for i := range s {
		if i > maxLen {
			break
		}
		cut = i
	}
	return s[:cut] + "..."
}
167
168// TestChangeDirBroadcastsCwdUpdate tests that change_dir broadcasts the updated cwd
169// to SSE subscribers so the UI gets the change immediately.
170func TestChangeDirBroadcastsCwdUpdate(t *testing.T) {
171 // Create a temp directory structure
172 tmpDir := t.TempDir()
173 subDir := filepath.Join(tmpDir, "subdir")
174 if err := os.Mkdir(subDir, 0o755); err != nil {
175 t.Fatal(err)
176 }
177
178 database, cleanup := setupTestDB(t)
179 defer cleanup()
180
181 predictableService := loop.NewPredictableService()
182 llmManager := &testLLMManager{service: predictableService}
183 logger := slog.Default()
184
185 // Create server with working directory set to tmpDir
186 toolSetConfig := claudetool.ToolSetConfig{
187 WorkingDir: tmpDir,
188 }
189 server := NewServer(database, llmManager, toolSetConfig, logger, true, "", "predictable", "", nil)
190
191 // Create test server
192 ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
193 if strings.HasPrefix(r.URL.Path, "/api/conversation/") {
194 parts := strings.Split(r.URL.Path, "/")
195 if len(parts) >= 4 {
196 conversationID := parts[3]
197 if len(parts) >= 5 {
198 switch parts[4] {
199 case "chat":
200 server.handleChatConversation(w, r, conversationID)
201 return
202 case "stream":
203 server.handleStreamConversation(w, r, conversationID)
204 return
205 }
206 }
207 }
208 }
209 http.NotFound(w, r)
210 }))
211 defer ts.Close()
212
213 // Create conversation with initial cwd
214 conversation, err := database.CreateConversation(context.Background(), nil, true, &tmpDir, nil)
215 if err != nil {
216 t.Fatalf("failed to create conversation: %v", err)
217 }
218 conversationID := conversation.ConversationID
219
220 // Verify initial cwd
221 if conversation.Cwd == nil || *conversation.Cwd != tmpDir {
222 t.Fatalf("expected initial cwd %q, got %v", tmpDir, conversation.Cwd)
223 }
224
225 // Connect to SSE stream
226 ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
227 defer cancel()
228
229 req, _ := http.NewRequestWithContext(ctx, "GET", ts.URL+"/api/conversation/"+conversationID+"/stream", nil)
230 req.Header.Set("Accept", "text/event-stream")
231
232 resp, err := http.DefaultClient.Do(req)
233 if err != nil {
234 t.Fatalf("failed to connect to SSE: %v", err)
235 }
236 defer resp.Body.Close()
237
238 // Channel to receive SSE events
239 events := make(chan StreamResponse, 10)
240 go func() {
241 scanner := bufio.NewScanner(resp.Body)
242 for scanner.Scan() {
243 line := scanner.Text()
244 if strings.HasPrefix(line, "data: ") {
245 data := strings.TrimPrefix(line, "data: ")
246 var sr StreamResponse
247 if err := json.Unmarshal([]byte(data), &sr); err == nil {
248 events <- sr
249 }
250 }
251 }
252 }()
253
254 // Wait for initial SSE event
255 select {
256 case <-events:
257 // Got initial event
258 case <-time.After(2 * time.Second):
259 t.Fatal("timeout waiting for initial SSE event")
260 }
261
262 // Send change_dir command
263 changeDirReq := ChatRequest{
264 Message: "change_dir: " + subDir,
265 Model: "predictable",
266 }
267 changeDirBody, _ := json.Marshal(changeDirReq)
268
269 chatReq, _ := http.NewRequest("POST", ts.URL+"/api/conversation/"+conversationID+"/chat", strings.NewReader(string(changeDirBody)))
270 chatReq.Header.Set("Content-Type", "application/json")
271 chatResp, err := http.DefaultClient.Do(chatReq)
272 if err != nil {
273 t.Fatalf("failed to send chat: %v", err)
274 }
275 chatResp.Body.Close()
276
277 // Wait for SSE event with updated cwd
278 deadline := time.Now().Add(5 * time.Second)
279 for time.Now().Before(deadline) {
280 select {
281 case event := <-events:
282 // Check if this event has the updated cwd
283 if event.Conversation.Cwd != nil && *event.Conversation.Cwd == subDir {
284 // Success! The UI would receive this update
285 return
286 }
287 case <-time.After(100 * time.Millisecond):
288 // Continue waiting
289 }
290 }
291
292 t.Error("did not receive SSE event with updated cwd")
293}