cwd_test.go

  1package server
  2
  3import (
  4	"context"
  5	"encoding/json"
  6	"net/http"
  7	"net/http/httptest"
  8	"os"
  9	"os/exec"
 10	"path/filepath"
 11	"strings"
 12	"testing"
 13
 14	"shelley.exe.dev/db/generated"
 15	"shelley.exe.dev/llm"
 16)
 17
 18// TestWorkingDirectoryConfiguration tests that the working directory (cwd) setting
 19// is properly passed through from HTTP requests to tool execution.
 20func TestWorkingDirectoryConfiguration(t *testing.T) {
 21	h := NewTestHarness(t)
 22	defer h.Close()
 23
 24	t.Run("cwd_tmp", func(t *testing.T) {
 25		h.NewConversation("bash: pwd", "/tmp")
 26		result := strings.TrimSpace(h.WaitToolResult())
 27		// Resolve symlinks for comparison (on macOS, /tmp -> /private/tmp)
 28		expected, _ := filepath.EvalSymlinks("/tmp")
 29		if result != expected {
 30			t.Errorf("expected %q, got: %s", expected, result)
 31		}
 32	})
 33
 34	t.Run("cwd_root", func(t *testing.T) {
 35		h.NewConversation("bash: pwd", "/")
 36		result := strings.TrimSpace(h.WaitToolResult())
 37		if result != "/" {
 38			t.Errorf("expected '/', got: %s", result)
 39		}
 40	})
 41}
 42
 43// TestListDirectory tests the list-directory API endpoint used by the directory picker.
 44func TestListDirectory(t *testing.T) {
 45	h := NewTestHarness(t)
 46	defer h.Close()
 47
 48	t.Run("list_tmp", func(t *testing.T) {
 49		req := httptest.NewRequest("GET", "/api/list-directory?path=/tmp", nil)
 50		w := httptest.NewRecorder()
 51		h.server.handleListDirectory(w, req)
 52
 53		if w.Code != http.StatusOK {
 54			t.Fatalf("expected status 200, got %d: %s", w.Code, w.Body.String())
 55		}
 56
 57		var resp ListDirectoryResponse
 58		if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil {
 59			t.Fatalf("failed to parse response: %v", err)
 60		}
 61
 62		if resp.Path != "/tmp" {
 63			t.Errorf("expected path '/tmp', got: %s", resp.Path)
 64		}
 65
 66		if resp.Parent != "/" {
 67			t.Errorf("expected parent '/', got: %s", resp.Parent)
 68		}
 69	})
 70
 71	t.Run("list_root", func(t *testing.T) {
 72		req := httptest.NewRequest("GET", "/api/list-directory?path=/", nil)
 73		w := httptest.NewRecorder()
 74		h.server.handleListDirectory(w, req)
 75
 76		if w.Code != http.StatusOK {
 77			t.Fatalf("expected status 200, got %d: %s", w.Code, w.Body.String())
 78		}
 79
 80		var resp ListDirectoryResponse
 81		if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil {
 82			t.Fatalf("failed to parse response: %v", err)
 83		}
 84
 85		if resp.Path != "/" {
 86			t.Errorf("expected path '/', got: %s", resp.Path)
 87		}
 88
 89		// Root should have no parent
 90		if resp.Parent != "" {
 91			t.Errorf("expected no parent, got: %s", resp.Parent)
 92		}
 93
 94		// Root should have at least some directories (tmp, etc, home, etc.)
 95		if len(resp.Entries) == 0 {
 96			t.Error("expected at least some entries in root")
 97		}
 98	})
 99
100	t.Run("list_default_path", func(t *testing.T) {
101		req := httptest.NewRequest("GET", "/api/list-directory", nil)
102		w := httptest.NewRecorder()
103		h.server.handleListDirectory(w, req)
104
105		if w.Code != http.StatusOK {
106			t.Fatalf("expected status 200, got %d: %s", w.Code, w.Body.String())
107		}
108
109		var resp ListDirectoryResponse
110		if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil {
111			t.Fatalf("failed to parse response: %v", err)
112		}
113
114		// Should default to home directory
115		homeDir, _ := os.UserHomeDir()
116		if homeDir != "" && resp.Path != homeDir {
117			t.Errorf("expected path '%s', got: %s", homeDir, resp.Path)
118		}
119	})
120
121	t.Run("list_nonexistent", func(t *testing.T) {
122		req := httptest.NewRequest("GET", "/api/list-directory?path=/nonexistent/path/123456", nil)
123		w := httptest.NewRecorder()
124		h.server.handleListDirectory(w, req)
125
126		if w.Code != http.StatusOK {
127			t.Fatalf("expected status 200, got %d: %s", w.Code, w.Body.String())
128		}
129
130		var resp map[string]interface{}
131		if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil {
132			t.Fatalf("failed to parse response: %v", err)
133		}
134
135		if resp["error"] == nil {
136			t.Error("expected error field in response")
137		}
138	})
139
140	t.Run("list_file_not_directory", func(t *testing.T) {
141		// Create a temp file
142		f, err := os.CreateTemp("", "test")
143		if err != nil {
144			t.Fatalf("failed to create temp file: %v", err)
145		}
146		defer os.Remove(f.Name())
147		f.Close()
148
149		req := httptest.NewRequest("GET", "/api/list-directory?path="+f.Name(), nil)
150		w := httptest.NewRecorder()
151		h.server.handleListDirectory(w, req)
152
153		if w.Code != http.StatusOK {
154			t.Fatalf("expected status 200, got %d: %s", w.Code, w.Body.String())
155		}
156
157		var resp map[string]interface{}
158		if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil {
159			t.Fatalf("failed to parse response: %v", err)
160		}
161
162		errMsg, ok := resp["error"].(string)
163		if !ok || errMsg != "path is not a directory" {
164			t.Errorf("expected error 'path is not a directory', got: %v", resp["error"])
165		}
166	})
167
168	t.Run("only_directories_returned", func(t *testing.T) {
169		// Create a temp directory with both files and directories
170		tmpDir, err := os.MkdirTemp("", "listdir_test")
171		if err != nil {
172			t.Fatalf("failed to create temp dir: %v", err)
173		}
174		defer os.RemoveAll(tmpDir)
175
176		// Create a subdirectory
177		subDir := tmpDir + "/subdir"
178		if err := os.Mkdir(subDir, 0o755); err != nil {
179			t.Fatalf("failed to create subdir: %v", err)
180		}
181
182		// Create a file
183		file := tmpDir + "/file.txt"
184		if err := os.WriteFile(file, []byte("test"), 0o644); err != nil {
185			t.Fatalf("failed to create file: %v", err)
186		}
187
188		req := httptest.NewRequest("GET", "/api/list-directory?path="+tmpDir, nil)
189		w := httptest.NewRecorder()
190		h.server.handleListDirectory(w, req)
191
192		if w.Code != http.StatusOK {
193			t.Fatalf("expected status 200, got %d: %s", w.Code, w.Body.String())
194		}
195
196		var resp ListDirectoryResponse
197		if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil {
198			t.Fatalf("failed to parse response: %v", err)
199		}
200
201		// Should only include the directory, not the file
202		if len(resp.Entries) != 1 {
203			t.Errorf("expected 1 entry, got: %d", len(resp.Entries))
204		}
205
206		if len(resp.Entries) > 0 && resp.Entries[0].Name != "subdir" {
207			t.Errorf("expected entry 'subdir', got: %s", resp.Entries[0].Name)
208		}
209	})
210
211	t.Run("hidden_directories_included", func(t *testing.T) {
212		// Create a temp directory with a hidden directory
213		tmpDir, err := os.MkdirTemp("", "listdir_hidden_test")
214		if err != nil {
215			t.Fatalf("failed to create temp dir: %v", err)
216		}
217		defer os.RemoveAll(tmpDir)
218
219		// Create a visible subdirectory
220		visibleDir := tmpDir + "/visible"
221		if err := os.Mkdir(visibleDir, 0o755); err != nil {
222			t.Fatalf("failed to create visible dir: %v", err)
223		}
224
225		// Create a hidden subdirectory
226		hiddenDir := tmpDir + "/.hidden"
227		if err := os.Mkdir(hiddenDir, 0o755); err != nil {
228			t.Fatalf("failed to create hidden dir: %v", err)
229		}
230
231		req := httptest.NewRequest("GET", "/api/list-directory?path="+tmpDir, nil)
232		w := httptest.NewRecorder()
233		h.server.handleListDirectory(w, req)
234
235		if w.Code != http.StatusOK {
236			t.Fatalf("expected status 200, got %d: %s", w.Code, w.Body.String())
237		}
238
239		var resp ListDirectoryResponse
240		if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil {
241			t.Fatalf("failed to parse response: %v", err)
242		}
243
244		// Should include both visible and hidden directories
245		if len(resp.Entries) != 2 {
246			t.Errorf("expected 2 entries, got: %d", len(resp.Entries))
247		}
248
249		// Check that both directories are present (sorted alphabetically, hidden first)
250		names := make(map[string]bool)
251		for _, e := range resp.Entries {
252			names[e.Name] = true
253		}
254		if !names[".hidden"] {
255			t.Errorf("expected .hidden to be included")
256		}
257		if !names["visible"] {
258			t.Errorf("expected visible to be included")
259		}
260	})
261
262	t.Run("git_repo_head_subject", func(t *testing.T) {
263		// Create a temp directory containing a git repo
264		tmpDir, err := os.MkdirTemp("", "listdir_git_test")
265		if err != nil {
266			t.Fatalf("failed to create temp dir: %v", err)
267		}
268		defer os.RemoveAll(tmpDir)
269
270		// Create a subdirectory that will be a git repo
271		repoDir := tmpDir + "/myrepo"
272		if err := os.Mkdir(repoDir, 0o755); err != nil {
273			t.Fatalf("failed to create repo dir: %v", err)
274		}
275
276		// Initialize git repo and create a commit
277		cmd := exec.Command("git", "init")
278		cmd.Dir = repoDir
279		if err := cmd.Run(); err != nil {
280			t.Fatalf("failed to init git: %v", err)
281		}
282
283		cmd = exec.Command("git", "config", "user.email", "test@example.com")
284		cmd.Dir = repoDir
285		if err := cmd.Run(); err != nil {
286			t.Fatalf("failed to config git email: %v", err)
287		}
288
289		cmd = exec.Command("git", "config", "user.name", "Test User")
290		cmd.Dir = repoDir
291		if err := cmd.Run(); err != nil {
292			t.Fatalf("failed to config git name: %v", err)
293		}
294
295		// Create a file and commit it
296		if err := os.WriteFile(repoDir+"/README.md", []byte("# Hello"), 0o644); err != nil {
297			t.Fatalf("failed to write file: %v", err)
298		}
299
300		cmd = exec.Command("git", "add", "README.md")
301		cmd.Dir = repoDir
302		if err := cmd.Run(); err != nil {
303			t.Fatalf("failed to git add: %v", err)
304		}
305
306		cmd = exec.Command("git", "commit", "-m", "Test commit subject line\n\nPrompt: test")
307		cmd.Dir = repoDir
308		if err := cmd.Run(); err != nil {
309			t.Fatalf("failed to git commit: %v", err)
310		}
311
312		// Create another directory that is not a git repo
313		nonRepoDir := tmpDir + "/notarepo"
314		if err := os.Mkdir(nonRepoDir, 0o755); err != nil {
315			t.Fatalf("failed to create non-repo dir: %v", err)
316		}
317
318		req := httptest.NewRequest("GET", "/api/list-directory?path="+tmpDir, nil)
319		w := httptest.NewRecorder()
320		h.server.handleListDirectory(w, req)
321
322		if w.Code != http.StatusOK {
323			t.Fatalf("expected status 200, got %d: %s", w.Code, w.Body.String())
324		}
325
326		var resp ListDirectoryResponse
327		if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil {
328			t.Fatalf("failed to parse response: %v", err)
329		}
330
331		if len(resp.Entries) != 2 {
332			t.Fatalf("expected 2 entries, got: %d", len(resp.Entries))
333		}
334
335		// Find the git repo entry and verify it has the commit subject
336		var gitEntry, nonGitEntry *DirectoryEntry
337		for i := range resp.Entries {
338			if resp.Entries[i].Name == "myrepo" {
339				gitEntry = &resp.Entries[i]
340			} else if resp.Entries[i].Name == "notarepo" {
341				nonGitEntry = &resp.Entries[i]
342			}
343		}
344
345		if gitEntry == nil {
346			t.Fatal("expected to find myrepo entry")
347		}
348		if nonGitEntry == nil {
349			t.Fatal("expected to find notarepo entry")
350		}
351
352		// Git repo should have the HEAD commit subject
353		if gitEntry.GitHeadSubject != "Test commit subject line" {
354			t.Errorf("expected git_head_subject 'Test commit subject line', got: %q", gitEntry.GitHeadSubject)
355		}
356
357		// Non-git dir should not have a subject
358		if nonGitEntry.GitHeadSubject != "" {
359			t.Errorf("expected empty git_head_subject for non-git dir, got: %q", nonGitEntry.GitHeadSubject)
360		}
361	})
362
363	t.Run("git_worktree_head_subject", func(t *testing.T) {
364		// Create a temp directory containing a git repo and a worktree
365		tmpDir, err := os.MkdirTemp("", "listdir_worktree_test")
366		if err != nil {
367			t.Fatalf("failed to create temp dir: %v", err)
368		}
369		defer os.RemoveAll(tmpDir)
370
371		// Create a main git repo
372		mainRepo := tmpDir + "/main-repo"
373		if err := os.Mkdir(mainRepo, 0o755); err != nil {
374			t.Fatalf("failed to create main repo dir: %v", err)
375		}
376
377		// Initialize git repo and create a commit
378		cmd := exec.Command("git", "init")
379		cmd.Dir = mainRepo
380		if err := cmd.Run(); err != nil {
381			t.Fatalf("failed to init git: %v", err)
382		}
383
384		cmd = exec.Command("git", "config", "user.email", "test@example.com")
385		cmd.Dir = mainRepo
386		if err := cmd.Run(); err != nil {
387			t.Fatalf("failed to config git email: %v", err)
388		}
389
390		cmd = exec.Command("git", "config", "user.name", "Test User")
391		cmd.Dir = mainRepo
392		if err := cmd.Run(); err != nil {
393			t.Fatalf("failed to config git name: %v", err)
394		}
395
396		// Create a file and commit it
397		if err := os.WriteFile(mainRepo+"/README.md", []byte("# Hello"), 0o644); err != nil {
398			t.Fatalf("failed to write file: %v", err)
399		}
400
401		cmd = exec.Command("git", "add", "README.md")
402		cmd.Dir = mainRepo
403		if err := cmd.Run(); err != nil {
404			t.Fatalf("failed to git add: %v", err)
405		}
406
407		cmd = exec.Command("git", "commit", "-m", "Main repo commit\n\nPrompt: test")
408		cmd.Dir = mainRepo
409		if err := cmd.Run(); err != nil {
410			t.Fatalf("failed to git commit: %v", err)
411		}
412
413		// Create a branch and worktree
414		cmd = exec.Command("git", "branch", "feature-branch")
415		cmd.Dir = mainRepo
416		if err := cmd.Run(); err != nil {
417			t.Fatalf("failed to create branch: %v", err)
418		}
419
420		worktreePath := tmpDir + "/worktree-dir"
421		cmd = exec.Command("git", "worktree", "add", worktreePath, "feature-branch")
422		cmd.Dir = mainRepo
423		if err := cmd.Run(); err != nil {
424			t.Fatalf("failed to create worktree: %v", err)
425		}
426
427		// Verify the worktree has a .git file (not directory)
428		gitPath := worktreePath + "/.git"
429		fi, err := os.Stat(gitPath)
430		if err != nil {
431			t.Fatalf("failed to stat worktree .git: %v", err)
432		}
433		if fi.IsDir() {
434			t.Fatalf("expected .git to be a file for worktree, got directory")
435		}
436
437		req := httptest.NewRequest("GET", "/api/list-directory?path="+tmpDir, nil)
438		w := httptest.NewRecorder()
439		h.server.handleListDirectory(w, req)
440
441		if w.Code != http.StatusOK {
442			t.Fatalf("expected status 200, got %d: %s", w.Code, w.Body.String())
443		}
444
445		var resp ListDirectoryResponse
446		if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil {
447			t.Fatalf("failed to parse response: %v", err)
448		}
449
450		// Find the worktree entry and verify it has the commit subject
451		var worktreeEntry *DirectoryEntry
452		for i := range resp.Entries {
453			if resp.Entries[i].Name == "worktree-dir" {
454				worktreeEntry = &resp.Entries[i]
455			}
456		}
457
458		if worktreeEntry == nil {
459			t.Fatal("expected to find worktree-dir entry")
460		}
461
462		// Worktree should have the HEAD commit subject
463		if worktreeEntry.GitHeadSubject != "Main repo commit" {
464			t.Errorf("expected git_head_subject 'Main repo commit', got: %q", worktreeEntry.GitHeadSubject)
465		}
466	})
467}
468
469// TestConversationCwdReturnedInList tests that CWD is returned in the conversations list.
470func TestConversationCwdReturnedInList(t *testing.T) {
471	h := NewTestHarness(t)
472	defer h.Close()
473
474	// Create a conversation with a specific CWD
475	h.NewConversation("bash: pwd", "/tmp")
476	h.WaitToolResult() // Wait for the conversation to complete
477
478	// Get the conversations list
479	req := httptest.NewRequest("GET", "/api/conversations", nil)
480	w := httptest.NewRecorder()
481	h.server.handleConversations(w, req)
482
483	if w.Code != http.StatusOK {
484		t.Fatalf("expected status 200, got %d: %s", w.Code, w.Body.String())
485	}
486
487	var convs []map[string]interface{}
488	if err := json.Unmarshal(w.Body.Bytes(), &convs); err != nil {
489		t.Fatalf("failed to parse response: %v", err)
490	}
491
492	if len(convs) == 0 {
493		t.Fatal("expected at least one conversation")
494	}
495
496	// Find our conversation
497	found := false
498	for _, conv := range convs {
499		if conv["conversation_id"] == h.ConversationID() {
500			found = true
501			cwd, ok := conv["cwd"].(string)
502			if !ok {
503				t.Errorf("expected cwd to be a string, got: %T", conv["cwd"])
504			}
505			if cwd != "/tmp" {
506				t.Errorf("expected cwd '/tmp', got: %s", cwd)
507			}
508			break
509		}
510	}
511
512	if !found {
513		t.Error("conversation not found in list")
514	}
515}
516
517// TestSystemPromptUsesCwdFromConversation verifies that when a conversation
518// is created with a specific cwd, the system prompt is generated using that
519// directory (not the server's working directory). This tests the fix for
520// https://github.com/boldsoftware/shelley/issues/30
521func TestSystemPromptUsesCwdFromConversation(t *testing.T) {
522	// Create a temp directory with an AGENTS.md file
523	tmpDir, err := os.MkdirTemp("", "shelley_cwd_test")
524	if err != nil {
525		t.Fatalf("failed to create temp dir: %v", err)
526	}
527	defer os.RemoveAll(tmpDir)
528
529	// Create an AGENTS.md file with unique content we can search for
530	agentsContent := "UNIQUE_MARKER_FOR_CWD_TEST_XYZ123: This is test guidance."
531	agentsFile := filepath.Join(tmpDir, "AGENTS.md")
532	if err := os.WriteFile(agentsFile, []byte(agentsContent), 0o644); err != nil {
533		t.Fatalf("failed to write AGENTS.md: %v", err)
534	}
535
536	h := NewTestHarness(t)
537	defer h.Close()
538
539	// Create a conversation with the temp directory as cwd
540	h.NewConversation("bash: echo hello", tmpDir)
541	h.WaitToolResult()
542
543	// Get the system prompt from the database
544	var messages []generated.Message
545	err = h.db.Queries(context.Background(), func(q *generated.Queries) error {
546		var qerr error
547		messages, qerr = q.ListMessages(context.Background(), h.ConversationID())
548		return qerr
549	})
550	if err != nil {
551		t.Fatalf("failed to get messages: %v", err)
552	}
553
554	// Find the system message
555	var systemPrompt string
556	for _, msg := range messages {
557		if msg.Type == "system" && msg.LlmData != nil {
558			var llmMsg llm.Message
559			if err := json.Unmarshal([]byte(*msg.LlmData), &llmMsg); err == nil {
560				for _, content := range llmMsg.Content {
561					if content.Type == llm.ContentTypeText {
562						systemPrompt = content.Text
563						break
564					}
565				}
566			}
567			break
568		}
569	}
570
571	if systemPrompt == "" {
572		t.Fatal("no system prompt found in messages")
573	}
574
575	// Verify the system prompt contains our unique marker from AGENTS.md
576	if !strings.Contains(systemPrompt, "UNIQUE_MARKER_FOR_CWD_TEST_XYZ123") {
577		t.Errorf("system prompt should contain content from AGENTS.md in the cwd directory")
578		// Log first 1000 chars to help debug
579		if len(systemPrompt) > 1000 {
580			t.Logf("system prompt (first 1000 chars): %s...", systemPrompt[:1000])
581		} else {
582			t.Logf("system prompt: %s", systemPrompt)
583		}
584	}
585
586	// Verify the working directory in the prompt is our temp directory
587	if !strings.Contains(systemPrompt, tmpDir) {
588		t.Errorf("system prompt should reference the cwd directory: %s", tmpDir)
589	}
590}