1package server
2
3import (
4 "context"
5 "encoding/json"
6 "log/slog"
7 "net/http"
8 "net/http/httptest"
9 "os"
10 "path/filepath"
11 "strings"
12 "testing"
13 "time"
14
15 "shelley.exe.dev/claudetool"
16 "shelley.exe.dev/db"
17 "shelley.exe.dev/db/generated"
18 "shelley.exe.dev/loop"
19)
20
21// TestChangeDirAffectsBash tests that change_dir updates the working directory
22// and subsequent bash commands run in that directory.
23func TestChangeDirAffectsBash(t *testing.T) {
24 // Create a temp directory structure
25 tmpDir := t.TempDir()
26 subDir := filepath.Join(tmpDir, "subdir")
27 if err := os.Mkdir(subDir, 0o755); err != nil {
28 t.Fatal(err)
29 }
30
31 // Create a marker file in subdir
32 markerFile := filepath.Join(subDir, "marker.txt")
33 if err := os.WriteFile(markerFile, []byte("found"), 0o644); err != nil {
34 t.Fatal(err)
35 }
36
37 database, cleanup := setupTestDB(t)
38 defer cleanup()
39
40 predictableService := loop.NewPredictableService()
41 llmManager := &testLLMManager{service: predictableService}
42 logger := slog.Default()
43
44 // Create server with working directory set to tmpDir
45 toolSetConfig := claudetool.ToolSetConfig{
46 WorkingDir: tmpDir,
47 }
48 server := NewServer(database, llmManager, toolSetConfig, logger, true, "", "predictable", "", nil)
49
50 // Create conversation
51 conversation, err := database.CreateConversation(context.Background(), nil, true, nil, nil)
52 if err != nil {
53 t.Fatalf("failed to create conversation: %v", err)
54 }
55 conversationID := conversation.ConversationID
56
57 // Step 1: Send change_dir command to change to subdir
58 changeDirReq := ChatRequest{
59 Message: "change_dir: " + subDir,
60 Model: "predictable",
61 }
62 changeDirBody, _ := json.Marshal(changeDirReq)
63
64 req := httptest.NewRequest("POST", "/api/conversation/"+conversationID+"/chat", strings.NewReader(string(changeDirBody)))
65 req.Header.Set("Content-Type", "application/json")
66 w := httptest.NewRecorder()
67
68 server.handleChatConversation(w, req, conversationID)
69 if w.Code != http.StatusAccepted {
70 t.Fatalf("expected status 202 for change_dir, got %d: %s", w.Code, w.Body.String())
71 }
72
73 // Wait for change_dir to complete - look for the tool result message
74 waitForMessageContaining(t, database, conversationID, "Changed working directory", 5*time.Second)
75
76 // Step 2: Now send pwd command - should show subdir
77 pwdReq := ChatRequest{
78 Message: "bash: pwd",
79 Model: "predictable",
80 }
81 pwdBody, _ := json.Marshal(pwdReq)
82
83 req2 := httptest.NewRequest("POST", "/api/conversation/"+conversationID+"/chat", strings.NewReader(string(pwdBody)))
84 req2.Header.Set("Content-Type", "application/json")
85 w2 := httptest.NewRecorder()
86
87 server.handleChatConversation(w2, req2, conversationID)
88 if w2.Code != http.StatusAccepted {
89 t.Fatalf("expected status 202 for bash pwd, got %d: %s", w2.Code, w2.Body.String())
90 }
91
92 // Wait for bash pwd to complete - the second tool result should contain the subdir
93 // We need to wait for 2 tool results: one from change_dir and one from pwd
94 waitForBashResult(t, database, conversationID, subDir, 5*time.Second)
95}
96
97// waitForBashResult waits for a bash tool result containing the expected text.
98func waitForBashResult(t *testing.T, database *db.DB, conversationID, expectedText string, timeout time.Duration) {
99 t.Helper()
100 deadline := time.Now().Add(timeout)
101 for time.Now().Before(deadline) {
102 messages, err := database.ListMessages(context.Background(), conversationID)
103 if err != nil {
104 t.Fatalf("failed to get messages: %v", err)
105 }
106
107 // Look for a tool result from bash tool that contains the expected text
108 for _, msg := range messages {
109 if msg.LlmData == nil {
110 continue
111 }
112 // The tool result for bash should contain the pwd output
113 // We distinguish it from the change_dir result by looking for the newline at the end
114 // (pwd outputs the path with a newline, change_dir outputs "Changed working directory to: ...")
115 // JSON encodes newline as \n so we check for that
116 if strings.Contains(*msg.LlmData, expectedText+`\n`) {
117 return
118 }
119 }
120 time.Sleep(50 * time.Millisecond)
121 }
122
123 // Print debug info on failure
124 messages, _ := database.ListMessages(context.Background(), conversationID)
125 t.Log("Messages in conversation:")
126 for i, msg := range messages {
127 t.Logf(" Message %d: type=%s", i, msg.Type)
128 if msg.LlmData != nil {
129 t.Logf(" data: %s", truncate(*msg.LlmData, 300))
130 }
131 }
132 t.Fatalf("did not find bash result containing %q within %v", expectedText, timeout)
133}
134
135// waitForMessageContaining waits for a message containing the specified text.
136func waitForMessageContaining(t *testing.T, database *db.DB, conversationID, text string, timeout time.Duration) {
137 t.Helper()
138 deadline := time.Now().Add(timeout)
139 for time.Now().Before(deadline) {
140 messages, err := database.ListMessages(context.Background(), conversationID)
141 if err != nil {
142 t.Fatalf("failed to get messages: %v", err)
143 }
144 for _, msg := range messages {
145 if msg.LlmData != nil && strings.Contains(*msg.LlmData, text) {
146 return
147 }
148 }
149 time.Sleep(50 * time.Millisecond)
150 }
151 t.Fatalf("did not find message containing %q within %v", text, timeout)
152}
153
154// getConversationMessages retrieves all messages for a conversation.
155func getConversationMessages(database *db.DB, conversationID string) ([]generated.Message, error) {
156 return database.ListMessages(context.Background(), conversationID)
157}
158
// truncate shortens s to at most maxLen bytes and appends "..." when anything
// was cut off. Strings that already fit are returned unchanged.
//
// The cut is moved back to the nearest rune boundary so a multi-byte UTF-8
// sequence is never split, which would otherwise put invalid bytes into log
// output. A negative maxLen is treated like 0 instead of panicking.
func truncate(s string, maxLen int) string {
	if len(s) <= maxLen {
		return s
	}

	// Ranging over a string yields the byte offset at which each rune starts;
	// keep the largest such offset that does not exceed maxLen.
	cut := 0
	for i := range s {
		if i > maxLen {
			break
		}
		cut = i
	}
	return s[:cut] + "..."
}