change_dir_test.go

  1package server
  2
  3import (
  4	"context"
  5	"encoding/json"
  6	"log/slog"
  7	"net/http"
  8	"net/http/httptest"
  9	"os"
 10	"path/filepath"
 11	"strings"
 12	"testing"
 13	"time"
 14
 15	"shelley.exe.dev/claudetool"
 16	"shelley.exe.dev/db"
 17	"shelley.exe.dev/db/generated"
 18	"shelley.exe.dev/loop"
 19)
 20
 21// TestChangeDirAffectsBash tests that change_dir updates the working directory
 22// and subsequent bash commands run in that directory.
 23func TestChangeDirAffectsBash(t *testing.T) {
 24	// Create a temp directory structure
 25	tmpDir := t.TempDir()
 26	subDir := filepath.Join(tmpDir, "subdir")
 27	if err := os.Mkdir(subDir, 0o755); err != nil {
 28		t.Fatal(err)
 29	}
 30
 31	// Create a marker file in subdir
 32	markerFile := filepath.Join(subDir, "marker.txt")
 33	if err := os.WriteFile(markerFile, []byte("found"), 0o644); err != nil {
 34		t.Fatal(err)
 35	}
 36
 37	database, cleanup := setupTestDB(t)
 38	defer cleanup()
 39
 40	predictableService := loop.NewPredictableService()
 41	llmManager := &testLLMManager{service: predictableService}
 42	logger := slog.Default()
 43
 44	// Create server with working directory set to tmpDir
 45	toolSetConfig := claudetool.ToolSetConfig{
 46		WorkingDir: tmpDir,
 47	}
 48	server := NewServer(database, llmManager, toolSetConfig, logger, true, "", "predictable", "", nil)
 49
 50	// Create conversation
 51	conversation, err := database.CreateConversation(context.Background(), nil, true, nil, nil)
 52	if err != nil {
 53		t.Fatalf("failed to create conversation: %v", err)
 54	}
 55	conversationID := conversation.ConversationID
 56
 57	// Step 1: Send change_dir command to change to subdir
 58	changeDirReq := ChatRequest{
 59		Message: "change_dir: " + subDir,
 60		Model:   "predictable",
 61	}
 62	changeDirBody, _ := json.Marshal(changeDirReq)
 63
 64	req := httptest.NewRequest("POST", "/api/conversation/"+conversationID+"/chat", strings.NewReader(string(changeDirBody)))
 65	req.Header.Set("Content-Type", "application/json")
 66	w := httptest.NewRecorder()
 67
 68	server.handleChatConversation(w, req, conversationID)
 69	if w.Code != http.StatusAccepted {
 70		t.Fatalf("expected status 202 for change_dir, got %d: %s", w.Code, w.Body.String())
 71	}
 72
 73	// Wait for change_dir to complete - look for the tool result message
 74	waitForMessageContaining(t, database, conversationID, "Changed working directory", 5*time.Second)
 75
 76	// Step 2: Now send pwd command - should show subdir
 77	pwdReq := ChatRequest{
 78		Message: "bash: pwd",
 79		Model:   "predictable",
 80	}
 81	pwdBody, _ := json.Marshal(pwdReq)
 82
 83	req2 := httptest.NewRequest("POST", "/api/conversation/"+conversationID+"/chat", strings.NewReader(string(pwdBody)))
 84	req2.Header.Set("Content-Type", "application/json")
 85	w2 := httptest.NewRecorder()
 86
 87	server.handleChatConversation(w2, req2, conversationID)
 88	if w2.Code != http.StatusAccepted {
 89		t.Fatalf("expected status 202 for bash pwd, got %d: %s", w2.Code, w2.Body.String())
 90	}
 91
 92	// Wait for bash pwd to complete - the second tool result should contain the subdir
 93	// We need to wait for 2 tool results: one from change_dir and one from pwd
 94	waitForBashResult(t, database, conversationID, subDir, 5*time.Second)
 95}
 96
 97// waitForBashResult waits for a bash tool result containing the expected text.
 98func waitForBashResult(t *testing.T, database *db.DB, conversationID, expectedText string, timeout time.Duration) {
 99	t.Helper()
100	deadline := time.Now().Add(timeout)
101	for time.Now().Before(deadline) {
102		messages, err := database.ListMessages(context.Background(), conversationID)
103		if err != nil {
104			t.Fatalf("failed to get messages: %v", err)
105		}
106
107		// Look for a tool result from bash tool that contains the expected text
108		for _, msg := range messages {
109			if msg.LlmData == nil {
110				continue
111			}
112			// The tool result for bash should contain the pwd output
113			// We distinguish it from the change_dir result by looking for the newline at the end
114			// (pwd outputs the path with a newline, change_dir outputs "Changed working directory to: ...")
115			// JSON encodes newline as \n so we check for that
116			if strings.Contains(*msg.LlmData, expectedText+`\n`) {
117				return
118			}
119		}
120		time.Sleep(50 * time.Millisecond)
121	}
122
123	// Print debug info on failure
124	messages, _ := database.ListMessages(context.Background(), conversationID)
125	t.Log("Messages in conversation:")
126	for i, msg := range messages {
127		t.Logf("  Message %d: type=%s", i, msg.Type)
128		if msg.LlmData != nil {
129			t.Logf("    data: %s", truncate(*msg.LlmData, 300))
130		}
131	}
132	t.Fatalf("did not find bash result containing %q within %v", expectedText, timeout)
133}
134
135// waitForMessageContaining waits for a message containing the specified text.
136func waitForMessageContaining(t *testing.T, database *db.DB, conversationID, text string, timeout time.Duration) {
137	t.Helper()
138	deadline := time.Now().Add(timeout)
139	for time.Now().Before(deadline) {
140		messages, err := database.ListMessages(context.Background(), conversationID)
141		if err != nil {
142			t.Fatalf("failed to get messages: %v", err)
143		}
144		for _, msg := range messages {
145			if msg.LlmData != nil && strings.Contains(*msg.LlmData, text) {
146				return
147			}
148		}
149		time.Sleep(50 * time.Millisecond)
150	}
151	t.Fatalf("did not find message containing %q within %v", text, timeout)
152}
153
154// getConversationMessages retrieves all messages for a conversation.
155func getConversationMessages(database *db.DB, conversationID string) ([]generated.Message, error) {
156	return database.ListMessages(context.Background(), conversationID)
157}
158
// truncate shortens s to at most maxLen characters (runes) and appends "..."
// when anything was cut. Strings of maxLen runes or fewer are returned
// unchanged.
//
// Cutting happens on rune boundaries, so multi-byte UTF-8 sequences are never
// split and the result stays valid UTF-8. A negative maxLen is treated as 0.
func truncate(s string, maxLen int) string {
	if maxLen < 0 {
		maxLen = 0
	}
	// Ranging over a string yields the byte offset of each rune; once
	// maxLen runes have been seen, i is the byte offset at which to cut.
	count := 0
	for i := range s {
		if count == maxLen {
			return s[:i] + "..."
		}
		count++
	}
	return s
}