@@ -22,6 +22,7 @@ import (
// - "bash: <command>" - triggers bash tool with command
// - "think: <thoughts>" - triggers think tool
// - "subagent: <slug> <prompt>" - triggers subagent tool
+// - "change_dir: <path>" - triggers change_dir tool
// - "delay: <seconds>" - delays response by specified seconds
// - See Do() method for complete list of supported patterns
type PredictableService struct {
@@ -83,18 +84,25 @@ func (s *PredictableService) Do(ctx context.Context, req *llm.Request) (*llm.Res
// Extract the text content from the last user message
var inputText string
+ var hasToolResult bool
if len(req.Messages) > 0 {
lastMessage := req.Messages[len(req.Messages)-1]
if lastMessage.Role == llm.MessageRoleUser {
for _, content := range lastMessage.Content {
if content.Type == llm.ContentTypeText {
inputText = strings.TrimSpace(content.Text)
- break
+ } else if content.Type == llm.ContentTypeToolResult {
+ hasToolResult = true
}
}
}
}
+ // If the message is purely a tool result (no text), acknowledge it and end turn
+ if hasToolResult && inputText == "" {
+ return s.makeResponse("Done.", inputTokens), nil
+ }
+
// Handle input using case statements
switch inputText {
case "hello":
@@ -176,6 +184,11 @@ func (s *PredictableService) Do(ctx context.Context, req *llm.Request) (*llm.Res
return s.makeSubagentToolResponse(slug, prompt, inputTokens), nil
}
+ if strings.HasPrefix(inputText, "change_dir: ") {
+ path := strings.TrimPrefix(inputText, "change_dir: ")
+ return s.makeChangeDirToolResponse(path, inputTokens), nil
+ }
+
if strings.HasPrefix(inputText, "delay: ") {
delayStr := strings.TrimPrefix(inputText, "delay: ")
delaySeconds, err := strconv.ParseFloat(delayStr, 64)
@@ -532,6 +545,39 @@ func (s *PredictableService) makeScreenshotToolResponse(selector string, inputTo
}
}
+// makeChangeDirToolResponse creates a response that calls the change_dir tool
+func (s *PredictableService) makeChangeDirToolResponse(path string, inputTokens uint64) *llm.Response {
+ toolInputData := map[string]string{"path": path}
+ toolInputBytes, _ := json.Marshal(toolInputData)
+ toolInput := json.RawMessage(toolInputBytes)
+ responseText := fmt.Sprintf("I'll change to directory: %s", path)
+ outputTokens := uint64(len(responseText)/4 + len(toolInputBytes)/4)
+ if outputTokens == 0 {
+ outputTokens = 1
+ }
+ return &llm.Response{
+ ID: fmt.Sprintf("pred-change_dir-%d", time.Now().UnixNano()),
+ Type: "message",
+ Role: llm.MessageRoleAssistant,
+ Model: "predictable-v1",
+ Content: []llm.Content{
+ {Type: llm.ContentTypeText, Text: responseText},
+ {
+ ID: fmt.Sprintf("tool_%d", time.Now().UnixNano()%1000),
+ Type: llm.ContentTypeToolUse,
+ ToolName: "change_dir",
+ ToolInput: toolInput,
+ },
+ },
+ StopReason: llm.StopReasonToolUse,
+ Usage: llm.Usage{
+ InputTokens: inputTokens,
+ OutputTokens: outputTokens,
+ CostUSD: 0.001,
+ },
+ }
+}
+
func (s *PredictableService) makeSubagentToolResponse(slug, prompt string, inputTokens uint64) *llm.Response {
toolInputData := map[string]any{
"slug": slug,
@@ -0,0 +1,165 @@
+package server
+
+import (
+ "context"
+ "encoding/json"
+ "log/slog"
+ "net/http"
+ "net/http/httptest"
+ "os"
+ "path/filepath"
+ "strings"
+ "testing"
+ "time"
+
+ "shelley.exe.dev/claudetool"
+ "shelley.exe.dev/db"
+ "shelley.exe.dev/db/generated"
+ "shelley.exe.dev/loop"
+)
+
+// TestChangeDirAffectsBash tests that change_dir updates the working directory
+// and subsequent bash commands run in that directory.
+func TestChangeDirAffectsBash(t *testing.T) {
+ // Create a temp directory structure
+ tmpDir := t.TempDir()
+ subDir := filepath.Join(tmpDir, "subdir")
+ if err := os.Mkdir(subDir, 0o755); err != nil {
+ t.Fatal(err)
+ }
+
+ // Create a marker file in subdir
+ markerFile := filepath.Join(subDir, "marker.txt")
+ if err := os.WriteFile(markerFile, []byte("found"), 0o644); err != nil {
+ t.Fatal(err)
+ }
+
+ database, cleanup := setupTestDB(t)
+ defer cleanup()
+
+ predictableService := loop.NewPredictableService()
+ llmManager := &testLLMManager{service: predictableService}
+ logger := slog.Default()
+
+ // Create server with working directory set to tmpDir
+ toolSetConfig := claudetool.ToolSetConfig{
+ WorkingDir: tmpDir,
+ }
+ server := NewServer(database, llmManager, toolSetConfig, logger, true, "", "predictable", "", nil)
+
+ // Create conversation
+ conversation, err := database.CreateConversation(context.Background(), nil, true, nil, nil)
+ if err != nil {
+ t.Fatalf("failed to create conversation: %v", err)
+ }
+ conversationID := conversation.ConversationID
+
+ // Step 1: Send change_dir command to change to subdir
+ changeDirReq := ChatRequest{
+ Message: "change_dir: " + subDir,
+ Model: "predictable",
+ }
+ changeDirBody, _ := json.Marshal(changeDirReq)
+
+ req := httptest.NewRequest("POST", "/api/conversation/"+conversationID+"/chat", strings.NewReader(string(changeDirBody)))
+ req.Header.Set("Content-Type", "application/json")
+ w := httptest.NewRecorder()
+
+ server.handleChatConversation(w, req, conversationID)
+ if w.Code != http.StatusAccepted {
+ t.Fatalf("expected status 202 for change_dir, got %d: %s", w.Code, w.Body.String())
+ }
+
+ // Wait for change_dir to complete - look for the tool result message
+ waitForMessageContaining(t, database, conversationID, "Changed working directory", 5*time.Second)
+
+ // Step 2: Now send pwd command - should show subdir
+ pwdReq := ChatRequest{
+ Message: "bash: pwd",
+ Model: "predictable",
+ }
+ pwdBody, _ := json.Marshal(pwdReq)
+
+ req2 := httptest.NewRequest("POST", "/api/conversation/"+conversationID+"/chat", strings.NewReader(string(pwdBody)))
+ req2.Header.Set("Content-Type", "application/json")
+ w2 := httptest.NewRecorder()
+
+ server.handleChatConversation(w2, req2, conversationID)
+ if w2.Code != http.StatusAccepted {
+ t.Fatalf("expected status 202 for bash pwd, got %d: %s", w2.Code, w2.Body.String())
+ }
+
+ // Wait for bash pwd to complete - the second tool result should contain the subdir
+ // We need to wait for 2 tool results: one from change_dir and one from pwd
+ waitForBashResult(t, database, conversationID, subDir, 5*time.Second)
+}
+
+// waitForBashResult waits for a bash tool result containing the expected text.
+func waitForBashResult(t *testing.T, database *db.DB, conversationID, expectedText string, timeout time.Duration) {
+ t.Helper()
+ deadline := time.Now().Add(timeout)
+ for time.Now().Before(deadline) {
+ messages, err := database.ListMessages(context.Background(), conversationID)
+ if err != nil {
+ t.Fatalf("failed to get messages: %v", err)
+ }
+
+ // Look for a tool result from bash tool that contains the expected text
+ for _, msg := range messages {
+ if msg.LlmData == nil {
+ continue
+ }
+ // The tool result for bash should contain the pwd output
+ // We distinguish it from the change_dir result by looking for the newline at the end
+ // (pwd outputs the path with a newline, change_dir outputs "Changed working directory to: ...")
+ // JSON encodes newline as \n so we check for that
+ if strings.Contains(*msg.LlmData, expectedText+`\n`) {
+ return
+ }
+ }
+ time.Sleep(50 * time.Millisecond)
+ }
+
+ // Print debug info on failure
+ messages, _ := database.ListMessages(context.Background(), conversationID)
+ t.Log("Messages in conversation:")
+ for i, msg := range messages {
+ t.Logf(" Message %d: type=%s", i, msg.Type)
+ if msg.LlmData != nil {
+ t.Logf(" data: %s", truncate(*msg.LlmData, 300))
+ }
+ }
+ t.Fatalf("did not find bash result containing %q within %v", expectedText, timeout)
+}
+
+// waitForMessageContaining waits for a message containing the specified text.
+func waitForMessageContaining(t *testing.T, database *db.DB, conversationID, text string, timeout time.Duration) {
+ t.Helper()
+ deadline := time.Now().Add(timeout)
+ for time.Now().Before(deadline) {
+ messages, err := database.ListMessages(context.Background(), conversationID)
+ if err != nil {
+ t.Fatalf("failed to get messages: %v", err)
+ }
+ for _, msg := range messages {
+ if msg.LlmData != nil && strings.Contains(*msg.LlmData, text) {
+ return
+ }
+ }
+ time.Sleep(50 * time.Millisecond)
+ }
+ t.Fatalf("did not find message containing %q within %v", text, timeout)
+}
+
+// getConversationMessages retrieves all messages for a conversation.
+func getConversationMessages(database *db.DB, conversationID string) ([]generated.Message, error) {
+ return database.ListMessages(context.Background(), conversationID)
+}
+
+// truncate truncates a string to maxLen characters.
+func truncate(s string, maxLen int) string {
+ if len(s) <= maxLen {
+ return s
+ }
+ return s[:maxLen] + "..."
+}