From cc96c20490392bede61e1d5c86108e9671e4e371 Mon Sep 17 00:00:00 2001 From: Raphael Amorim Date: Wed, 25 Jun 2025 12:46:29 +0200 Subject: [PATCH] feat: full windows support --- README.md | 1 - internal/llm/tools/bash.go | 64 +++++++--- internal/llm/tools/bash_test.go | 125 ++++++++++++++++++ internal/llm/tools/glob.go | 5 +- internal/llm/tools/grep.go | 5 +- internal/llm/tools/shell/shell.go | 169 +++++++++++++++++++++++-- internal/llm/tools/shell/shell_test.go | 143 +++++++++++++++++++++ 7 files changed, 480 insertions(+), 32 deletions(-) create mode 100644 internal/llm/tools/bash_test.go diff --git a/README.md b/README.md index 18767a8fc74fc1583caebb1e66d9c2e9a1feb77f..d94fb2690b7ad370d3096f92504cd203268d08b5 100644 --- a/README.md +++ b/README.md @@ -10,7 +10,6 @@ Crush is a tool for building software with AI. - ## License [MIT](https://github.com/charmbracelet/crush/raw/main/LICENSE) diff --git a/internal/llm/tools/bash.go b/internal/llm/tools/bash.go index 7963e95af24c6bb6257adc0c6341b95f453443b3..d9c19b808d487641f06529036c918e73225be7c8 100644 --- a/internal/llm/tools/bash.go +++ b/internal/llm/tools/bash.go @@ -4,6 +4,7 @@ import ( "context" "encoding/json" "fmt" + "runtime" "strings" "time" @@ -45,27 +46,58 @@ var bannedCommands = []string{ "http-prompt", "chrome", "firefox", "safari", } -var safeReadOnlyCommands = []string{ - "ls", "echo", "pwd", "date", "cal", "uptime", "whoami", "id", "groups", "env", "printenv", "set", "unset", "which", "type", "whereis", - "whatis", "uname", "hostname", "df", "du", "free", "top", "ps", "kill", "killall", "nice", "nohup", "time", "timeout", - - "git status", "git log", "git diff", "git show", "git branch", "git tag", "git remote", "git ls-files", "git ls-remote", - "git rev-parse", "git config --get", "git config --list", "git describe", "git blame", "git grep", "git shortlog", +// getSafeReadOnlyCommands returns platform-appropriate safe commands +func getSafeReadOnlyCommands() []string { + // Base commands that work on all platforms + baseCommands := []string{ + // Cross-platform commands + "echo", "hostname", "whoami", + + // Git commands (cross-platform) + "git status", "git log", "git diff", "git show", "git branch", "git tag", "git remote", "git ls-files", "git ls-remote", + "git rev-parse", "git config --get", "git config --list", "git describe", "git blame", "git grep", "git shortlog", + + // Go commands (cross-platform) + "go version", "go help", "go list", "go env", "go doc", "go vet", "go fmt", "go mod", "go test", "go build", "go run", "go install", "go clean", + } - "go version", "go help", "go list", "go env", "go doc", "go vet", "go fmt", "go mod", "go test", "go build", "go run", "go install", "go clean", + if runtime.GOOS == "windows" { + // Windows-specific commands + windowsCommands := []string{ + "dir", "type", "where", "ver", "systeminfo", "tasklist", "ipconfig", "ping", "nslookup", + "Get-Process", "Get-Location", "Get-ChildItem", "Get-Content", "Get-Date", "Get-Host", "Get-ComputerInfo", + } + return append(baseCommands, windowsCommands...) + } else { + // Unix/Linux commands (including WSL, since WSL reports as Linux) + unixCommands := []string{ + "ls", "pwd", "date", "cal", "uptime", "id", "groups", "env", "printenv", "set", "unset", "which", "type", "whereis", + "whatis", "uname", "df", "du", "free", "top", "ps", "kill", "killall", "nice", "nohup", "time", "timeout", + } + return append(baseCommands, unixCommands...) + } } func bashDescription() string { bannedCommandsStr := strings.Join(bannedCommands, ", ") return fmt.Sprintf(`Executes a given bash command in a persistent shell session with optional timeout, ensuring proper handling and security measures. -IMPORTANT FOR WINDOWS USERS: -- This tool uses a POSIX shell emulator (mvdan.cc/sh/v3) that works cross-platform, including Windows -- On Windows, this provides bash-like functionality without requiring WSL or Git Bash -- Use forward slashes (/) in paths - they work on all platforms and are converted automatically -- Windows-specific commands (like 'dir', 'type', 'copy') are not available - use Unix equivalents ('ls', 'cat', 'cp') -- Environment variables use Unix syntax: $VAR instead of %%VAR%% -- File paths are automatically converted between Windows and Unix formats as needed +CROSS-PLATFORM SHELL SUPPORT: +- Unix/Linux/macOS: Uses native bash/sh shell +- Windows: Intelligent shell selection: + * Windows commands (dir, type, copy, etc.) use cmd.exe + * PowerShell commands (Get-, Set-, etc.) use PowerShell + * Unix-style commands (ls, cat, etc.) use POSIX emulation +- WSL: Automatically treated as Linux (which is correct) +- Automatic detection: Chooses the best shell based on command and platform +- Persistent state: Working directory and environment variables persist between commands + +WINDOWS-SPECIFIC FEATURES: +- Native Windows commands: dir, type, copy, move, del, md, rd, cls, where, tasklist, etc. +- PowerShell support: Get-Process, Set-Location, and other PowerShell cmdlets +- Windows path handling: Supports both forward slashes (/) and backslashes (\) +- Drive letters: Properly handles C:\, D:\, etc. +- Environment variables: Supports both Unix ($VAR) and Windows (%%VAR%%) syntax Before executing the command, please follow these steps: @@ -262,6 +294,8 @@ func (b *bashTool) Run(ctx context.Context, call ToolCall) (ToolResponse, error) isSafeReadOnly := false cmdLower := strings.ToLower(params.Command) + // Get platform-appropriate safe commands + safeReadOnlyCommands := getSafeReadOnlyCommands() for _, safe := range safeReadOnlyCommands { if strings.HasPrefix(cmdLower, strings.ToLower(safe)) { if len(cmdLower) == len(safe) || cmdLower[len(safe)] == ' ' || cmdLower[len(safe)] == '-' { @@ -361,4 +395,4 @@ func countLines(s string) int { return 0 } return len(strings.Split(s, "\n")) -} +} \ No newline at end of file diff --git a/internal/llm/tools/bash_test.go b/internal/llm/tools/bash_test.go new file mode 100644 index 0000000000000000000000000000000000000000..b26c96097d25414567f0c456b88dbf54e6503e12 --- /dev/null +++ b/internal/llm/tools/bash_test.go @@ -0,0 +1,125 @@ +package tools + +import ( + "runtime" + "testing" +) + +func TestGetSafeReadOnlyCommands(t *testing.T) { + commands := getSafeReadOnlyCommands() + + // Check that we have some commands + if len(commands) == 0 { + t.Fatal("Expected some safe commands, got none") + } + + // Check for cross-platform commands that should always be present + crossPlatformCommands := []string{"echo", "hostname", "whoami", "git status", "go version"} + for _, cmd := range crossPlatformCommands { + found := false + for _, safeCmd := range commands { + if safeCmd == cmd { + found = true + break + } + } + if !found { + t.Errorf("Expected cross-platform command %q to be in safe commands", cmd) + } + } + + if runtime.GOOS == "windows" { + // Check for Windows-specific commands + windowsCommands := []string{"dir", "type", "Get-Process"} + for _, cmd := range windowsCommands { + found := false + for _, safeCmd := range commands { + if safeCmd == cmd { + found = true + break + } + } + if !found { + t.Errorf("Expected Windows command %q to be in safe commands on Windows", cmd) + } + } + + // Check that Unix commands are NOT present on Windows + unixCommands := []string{"ls", "pwd", "ps"} + for _, cmd := range unixCommands { + found := false + for _, safeCmd := range commands { + if safeCmd == cmd { + found = true + break + } + } + if found { + t.Errorf("Unix command %q should not be in safe commands on Windows", cmd) + } + } + } else { + // Check for Unix-specific commands + unixCommands := []string{"ls", "pwd", "ps"} + for _, cmd := range unixCommands { + found := false + for _, safeCmd := range commands { + if safeCmd == cmd { + found = true + break + } + } + if !found { + t.Errorf("Expected Unix command %q to be in safe commands on Unix", cmd) + } + } + + // Check that Windows-specific commands are NOT present on Unix + windowsOnlyCommands := []string{"dir", "Get-Process", "systeminfo"} + for _, cmd := range windowsOnlyCommands { + found := false + for _, safeCmd := range commands { + if safeCmd == cmd { + found = true + break + } + } + if found { + t.Errorf("Windows-only command %q should not be in safe commands on Unix", cmd) + } + } + } +} + +func TestPlatformSpecificSafeCommands(t *testing.T) { + // Test that the function returns different results on different platforms + commands := getSafeReadOnlyCommands() + + hasWindowsCommands := false + hasUnixCommands := false + + for _, cmd := range commands { + if cmd == "dir" || cmd == "Get-Process" || cmd == "systeminfo" { + hasWindowsCommands = true + } + if cmd == "ls" || cmd == "ps" || cmd == "df" { + hasUnixCommands = true + } + } + + if runtime.GOOS == "windows" { + if !hasWindowsCommands { + t.Error("Expected Windows commands on Windows platform") + } + if hasUnixCommands { + t.Error("Did not expect Unix commands on Windows platform") + } + } else { + if hasWindowsCommands { + t.Error("Did not expect Windows-only commands on Unix platform") + } + if !hasUnixCommands { + t.Error("Expected Unix commands on Unix platform") + } + } +} \ No newline at end of file diff --git a/internal/llm/tools/glob.go b/internal/llm/tools/glob.go index d59cc801cb50ebbdcb7896c20a140b1fbedfdfcf..39471a8ae81f9c31c4bccccf273f88044184f6ad 100644 --- a/internal/llm/tools/glob.go +++ b/internal/llm/tools/glob.go @@ -48,12 +48,11 @@ LIMITATIONS: - Hidden files (starting with '.') are skipped WINDOWS NOTES: -- Uses ripgrep (rg) command if available, otherwise falls back to built-in Go implementation -- On Windows, install ripgrep via: winget install BurntSushi.ripgrep.MSVC - Path separators are handled automatically (both / and \ work) -- Patterns should use forward slashes (/) for cross-platform compatibility +- Uses ripgrep (rg) command if available, otherwise falls back to built-in Go implementation TIPS: +- Patterns should use forward slashes (/) for cross-platform compatibility - For the most useful results, combine with the Grep tool: first find files with Glob, then search their contents with Grep - When doing iterative exploration that may require multiple rounds of searching, consider using the Agent tool instead - Always check if results are truncated and refine your search pattern if needed` diff --git a/internal/llm/tools/grep.go b/internal/llm/tools/grep.go index fe77adf48dda0c637d63d26fe37cbd348e4ad9de..7321ad4ed9a2d713cb8685b3f4b23dbc62cc6c2a 100644 --- a/internal/llm/tools/grep.go +++ b/internal/llm/tools/grep.go @@ -124,11 +124,10 @@ LIMITATIONS: - Very large binary files may be skipped - Hidden files (starting with '.') are skipped -WINDOWS NOTES: +CROSS-PLATFORM NOTES: - Uses ripgrep (rg) command if available for better performance -- On Windows, install ripgrep via: winget install BurntSushi.ripgrep.MSVC - Falls back to built-in Go implementation if ripgrep is not available -- File paths are normalized automatically for Windows compatibility +- File paths are normalized automatically for cross-platform compatibility TIPS: - For faster, more targeted searches, first use Glob to find relevant files, then use Grep diff --git a/internal/llm/tools/shell/shell.go b/internal/llm/tools/shell/shell.go index 7b75f04b6207fbb40a076dea2926e0846670e738..475746f7a4758a14d44dd59348fb264a2d5461f0 100644 --- a/internal/llm/tools/shell/shell.go +++ b/internal/llm/tools/shell/shell.go @@ -1,14 +1,11 @@ // Package shell provides cross-platform shell execution capabilities. // -// WINDOWS COMPATIBILITY NOTE: -// This implementation uses mvdan.cc/sh/v3 which provides POSIX shell emulation -// on Windows. While this works for basic commands, it has limitations: -// - Windows-specific commands (dir, type, copy) are not available -// - PowerShell and cmd.exe specific features are not supported -// - Some Windows path handling may be inconsistent -// -// For full Windows compatibility, consider adding native Windows shell support -// using os/exec with cmd.exe or PowerShell for Windows-specific commands. +// WINDOWS COMPATIBILITY: +// This implementation provides both POSIX shell emulation (mvdan.cc/sh/v3) and +// native Windows shell support (cmd.exe/PowerShell) for optimal compatibility: +// - On Windows: Uses native cmd.exe or PowerShell for Windows-specific commands +// - Cross-platform: Falls back to POSIX emulation for Unix-style commands +// - Automatic detection: Chooses the best shell based on command and platform package shell import ( @@ -17,6 +14,8 @@ import ( "errors" "fmt" "os" + "os/exec" + "runtime" "strings" "sync" @@ -26,6 +25,15 @@ import ( "mvdan.cc/sh/v3/syntax" ) +// ShellType represents the type of shell to use +type ShellType int + +const ( + ShellTypePOSIX ShellType = iota + ShellTypeCmd + ShellTypePowerShell +) + type PersistentShell struct { env []string cwd string @@ -37,6 +45,27 @@ var ( shellInstance *PersistentShell ) +// Windows-specific commands that should use native shell +var windowsNativeCommands = map[string]bool{ + "dir": true, + "type": true, + "copy": true, + "move": true, + "del": true, + "md": true, + "mkdir": true, + "rd": true, + "rmdir": true, + "cls": true, + "where": true, + "tasklist": true, + "taskkill": true, + "net": true, + "sc": true, + "reg": true, + "wmic": true, +} + func GetPersistentShell(cwd string) *PersistentShell { once.Do(func() { shellInstance = newPersistentShell(cwd) @@ -55,6 +84,126 @@ func (s *PersistentShell) Exec(ctx context.Context, command string) (string, str s.mu.Lock() defer s.mu.Unlock() + // Determine which shell to use based on platform and command + shellType := s.determineShellType(command) + + switch shellType { + case ShellTypeCmd: + return s.execWindows(ctx, command, "cmd") + case ShellTypePowerShell: + return s.execWindows(ctx, command, "powershell") + default: + return s.execPOSIX(ctx, command) + } +} + +// determineShellType decides which shell to use based on platform and command +func (s *PersistentShell) determineShellType(command string) ShellType { + if runtime.GOOS != "windows" { + return ShellTypePOSIX + } + + // Extract the first command from the command line + parts := strings.Fields(command) + if len(parts) == 0 { + return ShellTypePOSIX + } + + firstCmd := strings.ToLower(parts[0]) + + // Check if it's a Windows-specific command + if windowsNativeCommands[firstCmd] { + return ShellTypeCmd + } + + // Check for PowerShell-specific syntax + if strings.Contains(command, "Get-") || strings.Contains(command, "Set-") || + strings.Contains(command, "New-") || strings.Contains(command, "$_") || + strings.Contains(command, "| Where-Object") || strings.Contains(command, "| ForEach-Object") { + return ShellTypePowerShell + } + + // Default to POSIX emulation for cross-platform compatibility + return ShellTypePOSIX +} + +// execWindows executes commands using native Windows shells (cmd.exe or PowerShell) +func (s *PersistentShell) execWindows(ctx context.Context, command string, shell string) (string, string, error) { + var cmd *exec.Cmd + + // Handle directory changes specially to maintain persistent shell behavior + if strings.HasPrefix(strings.TrimSpace(command), "cd ") { + return s.handleWindowsCD(command) + } + + switch shell { + case "cmd": + // Use cmd.exe for Windows commands + // Add current directory context to maintain state + fullCommand := fmt.Sprintf("cd /d \"%s\" && %s", s.cwd, command) + cmd = exec.CommandContext(ctx, "cmd", "/C", fullCommand) + case "powershell": + // Use PowerShell for PowerShell commands + // Add current directory context to maintain state + fullCommand := fmt.Sprintf("Set-Location '%s'; %s", s.cwd, command) + cmd = exec.CommandContext(ctx, "powershell", "-Command", fullCommand) + default: + return "", "", fmt.Errorf("unsupported Windows shell: %s", shell) + } + + // Set environment variables + cmd.Env = s.env + + var stdout, stderr bytes.Buffer + cmd.Stdout = &stdout + cmd.Stderr = &stderr + + err := cmd.Run() + + logging.InfoPersist("Windows command finished", "shell", shell, "command", command, "err", err) + return stdout.String(), stderr.String(), err +} + +// handleWindowsCD handles directory changes for Windows shells +func (s *PersistentShell) handleWindowsCD(command string) (string, string, error) { + // Extract the target directory from the cd command + parts := strings.Fields(command) + if len(parts) < 2 { + return "", "cd: missing directory argument", fmt.Errorf("missing directory argument") + } + + targetDir := parts[1] + + // Handle relative paths + if !strings.Contains(targetDir, ":") && !strings.HasPrefix(targetDir, "\\") { + // Relative path - resolve against current directory + if targetDir == ".." { + // Go up one directory + if len(s.cwd) > 3 { // Don't go above drive root (C:\) + lastSlash := strings.LastIndex(s.cwd, "\\") + if lastSlash > 2 { // Keep drive letter + s.cwd = s.cwd[:lastSlash] + } + } + } else if targetDir != "." { + // Go to subdirectory + s.cwd = s.cwd + "\\" + targetDir + } + } else { + // Absolute path + s.cwd = targetDir + } + + // Verify the directory exists + if _, err := os.Stat(s.cwd); err != nil { + return "", fmt.Sprintf("cd: %s: No such file or directory", targetDir), err + } + + return "", "", nil +} + +// execPOSIX executes commands using POSIX shell emulation (cross-platform) +func (s *PersistentShell) execPOSIX(ctx context.Context, command string) (string, string, error) { line, err := syntax.NewParser().Parse(strings.NewReader(command), "") if err != nil { return "", "", fmt.Errorf("could not parse command: %w", err) @@ -77,7 +226,7 @@ func (s *PersistentShell) Exec(ctx context.Context, command string) (string, str for name, vr := range runner.Vars { s.env = append(s.env, fmt.Sprintf("%s=%s", name, vr.Str)) } - logging.InfoPersist("Command finished", "command", command, "err", err) + logging.InfoPersist("POSIX command finished", "command", command, "err", err) return stdout.String(), stderr.String(), err } diff --git a/internal/llm/tools/shell/shell_test.go b/internal/llm/tools/shell/shell_test.go index 630995aaa756295ff1eb14b05d11a2ee8f634733..d54ab8a051d6aa8123de38f2bfe5be217565e29c 100644 --- a/internal/llm/tools/shell/shell_test.go +++ b/internal/llm/tools/shell/shell_test.go @@ -2,6 +2,8 @@ package shell import ( "context" + "runtime" + "strings" "testing" "time" ) @@ -80,3 +82,144 @@ func TestRunContinuity(t *testing.T) { t.Fatalf("Expected output %q, got %q", expect, out) } } + +// New tests for Windows shell support + +func TestShellTypeDetection(t *testing.T) { + shell := &PersistentShell{} + + tests := []struct { + command string + expected ShellType + windowsOnly bool + }{ + // Windows-specific commands + {"dir", ShellTypeCmd, true}, + {"type file.txt", ShellTypeCmd, true}, + {"copy file1.txt file2.txt", ShellTypeCmd, true}, + {"del file.txt", ShellTypeCmd, true}, + {"md newdir", ShellTypeCmd, true}, + {"tasklist", ShellTypeCmd, true}, + + // PowerShell commands + {"Get-Process", ShellTypePowerShell, true}, + {"Get-ChildItem", ShellTypePowerShell, true}, + {"Set-Location C:\\", ShellTypePowerShell, true}, + {"Get-Content file.txt | Where-Object {$_ -match 'pattern'}", ShellTypePowerShell, true}, + {"$files = Get-ChildItem", ShellTypePowerShell, true}, + + // Unix/cross-platform commands + {"ls -la", ShellTypePOSIX, false}, + {"cat file.txt", ShellTypePOSIX, false}, + {"grep pattern file.txt", ShellTypePOSIX, false}, + {"echo hello", ShellTypePOSIX, false}, + {"git status", ShellTypePOSIX, false}, + {"go build", ShellTypePOSIX, false}, + } + + for _, test := range tests { + t.Run(test.command, func(t *testing.T) { + result := shell.determineShellType(test.command) + + if test.windowsOnly && runtime.GOOS != "windows" { + // On non-Windows systems, everything should use POSIX + if result != ShellTypePOSIX { + t.Errorf("On non-Windows, command %q should use POSIX shell, got %v", test.command, result) + } + } else if runtime.GOOS == "windows" { + // On Windows, check the expected shell type + if result != test.expected { + t.Errorf("Command %q should use %v shell, got %v", test.command, test.expected, result) + } + } + }) + } +} + +func TestWindowsCDHandling(t *testing.T) { + if runtime.GOOS != "windows" { + t.Skip("Windows CD handling test only runs on Windows") + } + + shell := &PersistentShell{ + cwd: "C:\\Users", + env: []string{}, + } + + tests := []struct { + command string + expectedCwd string + shouldError bool + }{ + {"cd ..", "C:\\", false}, + {"cd Documents", "C:\\Users\\Documents", false}, + {"cd C:\\Windows", "C:\\Windows", false}, + {"cd", "", true}, // Missing argument + } + + for _, test := range tests { + t.Run(test.command, func(t *testing.T) { + originalCwd := shell.cwd + stdout, stderr, err := shell.handleWindowsCD(test.command) + + if test.shouldError { + if err == nil { + t.Errorf("Command %q should have failed", test.command) + } + } else { + if err != nil { + t.Errorf("Command %q failed: %v", test.command, err) + } + if shell.cwd != test.expectedCwd { + t.Errorf("Command %q: expected cwd %q, got %q", test.command, test.expectedCwd, shell.cwd) + } + } + + // Reset for next test + shell.cwd = originalCwd + _ = stdout + _ = stderr + }) + } +} + +func TestCrossPlatformExecution(t *testing.T) { + shell := newPersistentShell(".") + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + + // Test a simple command that should work on all platforms + stdout, stderr, err := shell.Exec(ctx, "echo hello") + if err != nil { + t.Fatalf("Echo command failed: %v, stderr: %s", err, stderr) + } + + if stdout == "" { + t.Error("Echo command produced no output") + } + + // The output should contain "hello" regardless of platform + if !strings.Contains(strings.ToLower(stdout), "hello") { + t.Errorf("Echo output should contain 'hello', got: %q", stdout) + } +} + +func TestWindowsNativeCommands(t *testing.T) { + if runtime.GOOS != "windows" { + t.Skip("Windows native command test only runs on Windows") + } + + shell := newPersistentShell(".") + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + + // Test Windows dir command + stdout, stderr, err := shell.Exec(ctx, "dir") + if err != nil { + t.Fatalf("Dir command failed: %v, stderr: %s", err, stderr) + } + + if stdout == "" { + t.Error("Dir command produced no output") + } +}