feat: full windows support

Raphael Amorim created

Change summary

README.md                              |   1 
internal/llm/tools/bash.go             |  64 ++++++++--
internal/llm/tools/bash_test.go        | 125 ++++++++++++++++++++
internal/llm/tools/glob.go             |   5 
internal/llm/tools/grep.go             |   5 
internal/llm/tools/shell/shell.go      | 169 ++++++++++++++++++++++++++-
internal/llm/tools/shell/shell_test.go | 143 +++++++++++++++++++++++
7 files changed, 480 insertions(+), 32 deletions(-)

Detailed changes

README.md 🔗

@@ -10,7 +10,6 @@
 
 Crush is a tool for building software with AI.
 
-
 ## License
 
 [MIT](https://github.com/charmbracelet/crush/raw/main/LICENSE)

internal/llm/tools/bash.go 🔗

@@ -4,6 +4,7 @@ import (
 	"context"
 	"encoding/json"
 	"fmt"
+	"runtime"
 	"strings"
 	"time"
 
@@ -45,27 +46,58 @@ var bannedCommands = []string{
 	"http-prompt", "chrome", "firefox", "safari",
 }
 
-var safeReadOnlyCommands = []string{
-	"ls", "echo", "pwd", "date", "cal", "uptime", "whoami", "id", "groups", "env", "printenv", "set", "unset", "which", "type", "whereis",
-	"whatis", "uname", "hostname", "df", "du", "free", "top", "ps", "kill", "killall", "nice", "nohup", "time", "timeout",
-
-	"git status", "git log", "git diff", "git show", "git branch", "git tag", "git remote", "git ls-files", "git ls-remote",
-	"git rev-parse", "git config --get", "git config --list", "git describe", "git blame", "git grep", "git shortlog",
+// getSafeReadOnlyCommands returns platform-appropriate safe commands
+func getSafeReadOnlyCommands() []string {
+	// Base commands that work on all platforms
+	baseCommands := []string{
+		// Cross-platform commands
+		"echo", "hostname", "whoami",
+		
+		// Git commands (cross-platform)
+		"git status", "git log", "git diff", "git show", "git branch", "git tag", "git remote", "git ls-files", "git ls-remote",
+		"git rev-parse", "git config --get", "git config --list", "git describe", "git blame", "git grep", "git shortlog",
+
+		// Go commands (cross-platform)
+		"go version", "go help", "go list", "go env", "go doc", "go vet", "go fmt", "go mod", "go test", "go build", "go run", "go install", "go clean",
+	}
 
-	"go version", "go help", "go list", "go env", "go doc", "go vet", "go fmt", "go mod", "go test", "go build", "go run", "go install", "go clean",
+	if runtime.GOOS == "windows" {
+		// Windows-specific commands
+		windowsCommands := []string{
+			"dir", "type", "where", "ver", "systeminfo", "tasklist", "ipconfig", "ping", "nslookup",
+			"Get-Process", "Get-Location", "Get-ChildItem", "Get-Content", "Get-Date", "Get-Host", "Get-ComputerInfo",
+		}
+		return append(baseCommands, windowsCommands...)
+	} else {
+		// Unix/Linux commands (including WSL, since WSL reports as Linux)
+		unixCommands := []string{
+			"ls", "pwd", "date", "cal", "uptime", "id", "groups", "env", "printenv", "set", "unset", "which", "type", "whereis",
+			"whatis", "uname", "df", "du", "free", "top", "ps", "kill", "killall", "nice", "nohup", "time", "timeout",
+		}
+		return append(baseCommands, unixCommands...)
+	}
 }
 
 func bashDescription() string {
 	bannedCommandsStr := strings.Join(bannedCommands, ", ")
 	return fmt.Sprintf(`Executes a given bash command in a persistent shell session with optional timeout, ensuring proper handling and security measures.
 
-IMPORTANT FOR WINDOWS USERS:
-- This tool uses a POSIX shell emulator (mvdan.cc/sh/v3) that works cross-platform, including Windows
-- On Windows, this provides bash-like functionality without requiring WSL or Git Bash
-- Use forward slashes (/) in paths - they work on all platforms and are converted automatically
-- Windows-specific commands (like 'dir', 'type', 'copy') are not available - use Unix equivalents ('ls', 'cat', 'cp')
-- Environment variables use Unix syntax: $VAR instead of %%VAR%%
-- File paths are automatically converted between Windows and Unix formats as needed
+CROSS-PLATFORM SHELL SUPPORT:
+- Unix/Linux/macOS: Uses native bash/sh shell
+- Windows: Intelligent shell selection:
+  * Windows commands (dir, type, copy, etc.) use cmd.exe
+  * PowerShell commands (Get-, Set-, etc.) use PowerShell
+  * Unix-style commands (ls, cat, etc.) use POSIX emulation
+- WSL: Automatically treated as Linux (which is correct)
+- Automatic detection: Chooses the best shell based on command and platform
+- Persistent state: Working directory and environment variables persist between commands
+
+WINDOWS-SPECIFIC FEATURES:
+- Native Windows commands: dir, type, copy, move, del, md, rd, cls, where, tasklist, etc.
+- PowerShell support: Get-Process, Set-Location, and other PowerShell cmdlets
+- Windows path handling: Supports both forward slashes (/) and backslashes (\)
+- Drive letters: Properly handles C:\, D:\, etc.
+- Environment variables: Supports both Unix ($VAR) and Windows (%%VAR%%) syntax
 
 Before executing the command, please follow these steps:
 
@@ -262,6 +294,8 @@ func (b *bashTool) Run(ctx context.Context, call ToolCall) (ToolResponse, error)
 	isSafeReadOnly := false
 	cmdLower := strings.ToLower(params.Command)
 
+	// Get platform-appropriate safe commands
+	safeReadOnlyCommands := getSafeReadOnlyCommands()
 	for _, safe := range safeReadOnlyCommands {
 		if strings.HasPrefix(cmdLower, strings.ToLower(safe)) {
 			if len(cmdLower) == len(safe) || cmdLower[len(safe)] == ' ' || cmdLower[len(safe)] == '-' {
@@ -361,4 +395,4 @@ func countLines(s string) int {
 		return 0
 	}
 	return len(strings.Split(s, "\n"))
-}
+}

internal/llm/tools/bash_test.go 🔗

@@ -0,0 +1,125 @@
+package tools
+
+import (
+	"runtime"
+	"testing"
+)
+
+func TestGetSafeReadOnlyCommands(t *testing.T) {
+	commands := getSafeReadOnlyCommands()
+	
+	// Check that we have some commands
+	if len(commands) == 0 {
+		t.Fatal("Expected some safe commands, got none")
+	}
+	
+	// Check for cross-platform commands that should always be present
+	crossPlatformCommands := []string{"echo", "hostname", "whoami", "git status", "go version"}
+	for _, cmd := range crossPlatformCommands {
+		found := false
+		for _, safeCmd := range commands {
+			if safeCmd == cmd {
+				found = true
+				break
+			}
+		}
+		if !found {
+			t.Errorf("Expected cross-platform command %q to be in safe commands", cmd)
+		}
+	}
+	
+	if runtime.GOOS == "windows" {
+		// Check for Windows-specific commands
+		windowsCommands := []string{"dir", "type", "Get-Process"}
+		for _, cmd := range windowsCommands {
+			found := false
+			for _, safeCmd := range commands {
+				if safeCmd == cmd {
+					found = true
+					break
+				}
+			}
+			if !found {
+				t.Errorf("Expected Windows command %q to be in safe commands on Windows", cmd)
+			}
+		}
+		
+		// Check that Unix commands are NOT present on Windows
+		unixCommands := []string{"ls", "pwd", "ps"}
+		for _, cmd := range unixCommands {
+			found := false
+			for _, safeCmd := range commands {
+				if safeCmd == cmd {
+					found = true
+					break
+				}
+			}
+			if found {
+				t.Errorf("Unix command %q should not be in safe commands on Windows", cmd)
+			}
+		}
+	} else {
+		// Check for Unix-specific commands
+		unixCommands := []string{"ls", "pwd", "ps"}
+		for _, cmd := range unixCommands {
+			found := false
+			for _, safeCmd := range commands {
+				if safeCmd == cmd {
+					found = true
+					break
+				}
+			}
+			if !found {
+				t.Errorf("Expected Unix command %q to be in safe commands on Unix", cmd)
+			}
+		}
+		
+		// Check that Windows-specific commands are NOT present on Unix
+		windowsOnlyCommands := []string{"dir", "Get-Process", "systeminfo"}
+		for _, cmd := range windowsOnlyCommands {
+			found := false
+			for _, safeCmd := range commands {
+				if safeCmd == cmd {
+					found = true
+					break
+				}
+			}
+			if found {
+				t.Errorf("Windows-only command %q should not be in safe commands on Unix", cmd)
+			}
+		}
+	}
+}
+
+func TestPlatformSpecificSafeCommands(t *testing.T) {
+	// Test that the function returns different results on different platforms
+	commands := getSafeReadOnlyCommands()
+	
+	hasWindowsCommands := false
+	hasUnixCommands := false
+	
+	for _, cmd := range commands {
+		if cmd == "dir" || cmd == "Get-Process" || cmd == "systeminfo" {
+			hasWindowsCommands = true
+		}
+		if cmd == "ls" || cmd == "ps" || cmd == "df" {
+			hasUnixCommands = true
+		}
+	}
+	
+	if runtime.GOOS == "windows" {
+		if !hasWindowsCommands {
+			t.Error("Expected Windows commands on Windows platform")
+		}
+		if hasUnixCommands {
+			t.Error("Did not expect Unix commands on Windows platform")
+		}
+	} else {
+		if hasWindowsCommands {
+			t.Error("Did not expect Windows-only commands on Unix platform")
+		}
+		if !hasUnixCommands {
+			t.Error("Expected Unix commands on Unix platform")
+		}
+	}
+}

internal/llm/tools/glob.go 🔗

@@ -48,12 +48,11 @@ LIMITATIONS:
 - Hidden files (starting with '.') are skipped
 
 WINDOWS NOTES:
-- Uses ripgrep (rg) command if available, otherwise falls back to built-in Go implementation
-- On Windows, install ripgrep via: winget install BurntSushi.ripgrep.MSVC
 - Path separators are handled automatically (both / and \ work)
-- Patterns should use forward slashes (/) for cross-platform compatibility
+- Uses ripgrep (rg) command if available, otherwise falls back to built-in Go implementation
 
 TIPS:
+- Patterns should use forward slashes (/) for cross-platform compatibility
 - For the most useful results, combine with the Grep tool: first find files with Glob, then search their contents with Grep
 - When doing iterative exploration that may require multiple rounds of searching, consider using the Agent tool instead
 - Always check if results are truncated and refine your search pattern if needed`

internal/llm/tools/grep.go 🔗

@@ -124,11 +124,10 @@ LIMITATIONS:
 - Very large binary files may be skipped
 - Hidden files (starting with '.') are skipped
 
-WINDOWS NOTES:
+CROSS-PLATFORM NOTES:
 - Uses ripgrep (rg) command if available for better performance
-- On Windows, install ripgrep via: winget install BurntSushi.ripgrep.MSVC
 - Falls back to built-in Go implementation if ripgrep is not available
-- File paths are normalized automatically for Windows compatibility
+- File paths are normalized automatically for cross-platform compatibility
 
 TIPS:
 - For faster, more targeted searches, first use Glob to find relevant files, then use Grep

internal/llm/tools/shell/shell.go 🔗

@@ -1,14 +1,11 @@
 // Package shell provides cross-platform shell execution capabilities.
 // 
-// WINDOWS COMPATIBILITY NOTE:
-// This implementation uses mvdan.cc/sh/v3 which provides POSIX shell emulation
-// on Windows. While this works for basic commands, it has limitations:
-// - Windows-specific commands (dir, type, copy) are not available
-// - PowerShell and cmd.exe specific features are not supported
-// - Some Windows path handling may be inconsistent
-// 
-// For full Windows compatibility, consider adding native Windows shell support
-// using os/exec with cmd.exe or PowerShell for Windows-specific commands.
+// WINDOWS COMPATIBILITY:
+// This implementation provides both POSIX shell emulation (mvdan.cc/sh/v3) and 
+// native Windows shell support (cmd.exe/PowerShell) for optimal compatibility:
+// - On Windows: Uses native cmd.exe or PowerShell for Windows-specific commands
+// - Cross-platform: Falls back to POSIX emulation for Unix-style commands
+// - Automatic detection: Chooses the best shell based on command and platform
 package shell
 
 import (
@@ -17,6 +14,8 @@ import (
 	"errors"
 	"fmt"
 	"os"
+	"os/exec"
+	"runtime"
 	"strings"
 	"sync"
 
@@ -26,6 +25,15 @@ import (
 	"mvdan.cc/sh/v3/syntax"
 )
 
+// ShellType represents the type of shell to use
+type ShellType int
+
+const (
+	ShellTypePOSIX ShellType = iota
+	ShellTypeCmd
+	ShellTypePowerShell
+)
+
 type PersistentShell struct {
 	env []string
 	cwd string
@@ -37,6 +45,27 @@ var (
 	shellInstance *PersistentShell
 )
 
+// Windows-specific commands that should use native shell
+var windowsNativeCommands = map[string]bool{
+	"dir":      true,
+	"type":     true,
+	"copy":     true,
+	"move":     true,
+	"del":      true,
+	"md":       true,
+	"mkdir":    true,
+	"rd":       true,
+	"rmdir":    true,
+	"cls":      true,
+	"where":    true,
+	"tasklist": true,
+	"taskkill": true,
+	"net":      true,
+	"sc":       true,
+	"reg":      true,
+	"wmic":     true,
+}
+
 func GetPersistentShell(cwd string) *PersistentShell {
 	once.Do(func() {
 		shellInstance = newPersistentShell(cwd)
@@ -55,6 +84,126 @@ func (s *PersistentShell) Exec(ctx context.Context, command string) (string, str
 	s.mu.Lock()
 	defer s.mu.Unlock()
 
+	// Determine which shell to use based on platform and command
+	shellType := s.determineShellType(command)
+	
+	switch shellType {
+	case ShellTypeCmd:
+		return s.execWindows(ctx, command, "cmd")
+	case ShellTypePowerShell:
+		return s.execWindows(ctx, command, "powershell")
+	default:
+		return s.execPOSIX(ctx, command)
+	}
+}
+
+// determineShellType decides which shell to use based on platform and command
+func (s *PersistentShell) determineShellType(command string) ShellType {
+	if runtime.GOOS != "windows" {
+		return ShellTypePOSIX
+	}
+
+	// Extract the first command from the command line
+	parts := strings.Fields(command)
+	if len(parts) == 0 {
+		return ShellTypePOSIX
+	}
+
+	firstCmd := strings.ToLower(parts[0])
+	
+	// Check if it's a Windows-specific command
+	if windowsNativeCommands[firstCmd] {
+		return ShellTypeCmd
+	}
+	
+	// Check for PowerShell-specific syntax
+	if strings.Contains(command, "Get-") || strings.Contains(command, "Set-") || 
+	   strings.Contains(command, "New-") || strings.Contains(command, "$_") ||
+	   strings.Contains(command, "| Where-Object") || strings.Contains(command, "| ForEach-Object") {
+		return ShellTypePowerShell
+	}
+	
+	// Default to POSIX emulation for cross-platform compatibility
+	return ShellTypePOSIX
+}
+
+// execWindows executes commands using native Windows shells (cmd.exe or PowerShell)
+func (s *PersistentShell) execWindows(ctx context.Context, command string, shell string) (string, string, error) {
+	var cmd *exec.Cmd
+	
+	// Handle directory changes specially to maintain persistent shell behavior
+	if strings.HasPrefix(strings.TrimSpace(command), "cd ") {
+		return s.handleWindowsCD(command)
+	}
+	
+	switch shell {
+	case "cmd":
+		// Use cmd.exe for Windows commands
+		// Add current directory context to maintain state
+		fullCommand := fmt.Sprintf("cd /d \"%s\" && %s", s.cwd, command)
+		cmd = exec.CommandContext(ctx, "cmd", "/C", fullCommand)
+	case "powershell":
+		// Use PowerShell for PowerShell commands
+		// Add current directory context to maintain state
+		fullCommand := fmt.Sprintf("Set-Location '%s'; %s", s.cwd, command)
+		cmd = exec.CommandContext(ctx, "powershell", "-Command", fullCommand)
+	default:
+		return "", "", fmt.Errorf("unsupported Windows shell: %s", shell)
+	}
+	
+	// Set environment variables
+	cmd.Env = s.env
+	
+	var stdout, stderr bytes.Buffer
+	cmd.Stdout = &stdout
+	cmd.Stderr = &stderr
+	
+	err := cmd.Run()
+	
+	logging.InfoPersist("Windows command finished", "shell", shell, "command", command, "err", err)
+	return stdout.String(), stderr.String(), err
+}
+
+// handleWindowsCD handles directory changes for Windows shells
+func (s *PersistentShell) handleWindowsCD(command string) (string, string, error) {
+	// Extract the target directory from the cd command
+	parts := strings.Fields(command)
+	if len(parts) < 2 {
+		return "", "cd: missing directory argument", fmt.Errorf("missing directory argument")
+	}
+	
+	targetDir := parts[1]
+	
+	// Handle relative paths
+	if !strings.Contains(targetDir, ":") && !strings.HasPrefix(targetDir, "\\") {
+		// Relative path - resolve against current directory
+		if targetDir == ".." {
+			// Go up one directory
+			if len(s.cwd) > 3 { // Don't go above drive root (C:\)
+				lastSlash := strings.LastIndex(s.cwd, "\\")
+				if lastSlash > 2 { // Keep drive letter
+					s.cwd = s.cwd[:lastSlash]
+				}
+			}
+		} else if targetDir != "." {
+			// Go to subdirectory
+			s.cwd = s.cwd + "\\" + targetDir
+		}
+	} else {
+		// Absolute path
+		s.cwd = targetDir
+	}
+	
+	// Verify the directory exists
+	if _, err := os.Stat(s.cwd); err != nil {
+		return "", fmt.Sprintf("cd: %s: No such file or directory", targetDir), err
+	}
+	
+	return "", "", nil
+}
+
+// execPOSIX executes commands using POSIX shell emulation (cross-platform)
+func (s *PersistentShell) execPOSIX(ctx context.Context, command string) (string, string, error) {
 	line, err := syntax.NewParser().Parse(strings.NewReader(command), "")
 	if err != nil {
 		return "", "", fmt.Errorf("could not parse command: %w", err)
@@ -77,7 +226,7 @@ func (s *PersistentShell) Exec(ctx context.Context, command string) (string, str
 	for name, vr := range runner.Vars {
 		s.env = append(s.env, fmt.Sprintf("%s=%s", name, vr.Str))
 	}
-	logging.InfoPersist("Command finished", "command", command, "err", err)
+	logging.InfoPersist("POSIX command finished", "command", command, "err", err)
 	return stdout.String(), stderr.String(), err
 }
 

internal/llm/tools/shell/shell_test.go 🔗

@@ -2,6 +2,8 @@ package shell
 
 import (
 	"context"
+	"runtime"
+	"strings"
 	"testing"
 	"time"
 )
@@ -80,3 +82,144 @@ func TestRunContinuity(t *testing.T) {
 		t.Fatalf("Expected output %q, got %q", expect, out)
 	}
 }
+
+// New tests for Windows shell support
+
+func TestShellTypeDetection(t *testing.T) {
+	shell := &PersistentShell{}
+	
+	tests := []struct {
+		command    string
+		expected   ShellType
+		windowsOnly bool
+	}{
+		// Windows-specific commands
+		{"dir", ShellTypeCmd, true},
+		{"type file.txt", ShellTypeCmd, true},
+		{"copy file1.txt file2.txt", ShellTypeCmd, true},
+		{"del file.txt", ShellTypeCmd, true},
+		{"md newdir", ShellTypeCmd, true},
+		{"tasklist", ShellTypeCmd, true},
+		
+		// PowerShell commands
+		{"Get-Process", ShellTypePowerShell, true},
+		{"Get-ChildItem", ShellTypePowerShell, true},
+		{"Set-Location C:\\", ShellTypePowerShell, true},
+		{"Get-Content file.txt | Where-Object {$_ -match 'pattern'}", ShellTypePowerShell, true},
+		{"$files = Get-ChildItem", ShellTypePowerShell, true},
+		
+		// Unix/cross-platform commands
+		{"ls -la", ShellTypePOSIX, false},
+		{"cat file.txt", ShellTypePOSIX, false},
+		{"grep pattern file.txt", ShellTypePOSIX, false},
+		{"echo hello", ShellTypePOSIX, false},
+		{"git status", ShellTypePOSIX, false},
+		{"go build", ShellTypePOSIX, false},
+	}
+	
+	for _, test := range tests {
+		t.Run(test.command, func(t *testing.T) {
+			result := shell.determineShellType(test.command)
+			
+			if test.windowsOnly && runtime.GOOS != "windows" {
+				// On non-Windows systems, everything should use POSIX
+				if result != ShellTypePOSIX {
+					t.Errorf("On non-Windows, command %q should use POSIX shell, got %v", test.command, result)
+				}
+			} else if runtime.GOOS == "windows" {
+				// On Windows, check the expected shell type
+				if result != test.expected {
+					t.Errorf("Command %q should use %v shell, got %v", test.command, test.expected, result)
+				}
+			}
+		})
+	}
+}
+
+func TestWindowsCDHandling(t *testing.T) {
+	if runtime.GOOS != "windows" {
+		t.Skip("Windows CD handling test only runs on Windows")
+	}
+	
+	shell := &PersistentShell{
+		cwd: "C:\\Users",
+		env: []string{},
+	}
+	
+	tests := []struct {
+		command     string
+		expectedCwd string
+		shouldError bool
+	}{
+		{"cd ..", "C:\\", false},
+		{"cd Documents", "C:\\Users\\Documents", false},
+		{"cd C:\\Windows", "C:\\Windows", false},
+		{"cd", "", true}, // Missing argument
+	}
+	
+	for _, test := range tests {
+		t.Run(test.command, func(t *testing.T) {
+			originalCwd := shell.cwd
+			stdout, stderr, err := shell.handleWindowsCD(test.command)
+			
+			if test.shouldError {
+				if err == nil {
+					t.Errorf("Command %q should have failed", test.command)
+				}
+			} else {
+				if err != nil {
+					t.Errorf("Command %q failed: %v", test.command, err)
+				}
+				if shell.cwd != test.expectedCwd {
+					t.Errorf("Command %q: expected cwd %q, got %q", test.command, test.expectedCwd, shell.cwd)
+				}
+			}
+			
+			// Reset for next test
+			shell.cwd = originalCwd
+			_ = stdout
+			_ = stderr
+		})
+	}
+}
+
+func TestCrossPlatformExecution(t *testing.T) {
+	shell := newPersistentShell(".")
+	ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
+	defer cancel()
+	
+	// Test a simple command that should work on all platforms
+	stdout, stderr, err := shell.Exec(ctx, "echo hello")
+	if err != nil {
+		t.Fatalf("Echo command failed: %v, stderr: %s", err, stderr)
+	}
+	
+	if stdout == "" {
+		t.Error("Echo command produced no output")
+	}
+	
+	// The output should contain "hello" regardless of platform
+	if !strings.Contains(strings.ToLower(stdout), "hello") {
+		t.Errorf("Echo output should contain 'hello', got: %q", stdout)
+	}
+}
+
+func TestWindowsNativeCommands(t *testing.T) {
+	if runtime.GOOS != "windows" {
+		t.Skip("Windows native command test only runs on Windows")
+	}
+	
+	shell := newPersistentShell(".")
+	ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
+	defer cancel()
+	
+	// Test Windows dir command
+	stdout, stderr, err := shell.Exec(ctx, "dir")
+	if err != nil {
+		t.Fatalf("Dir command failed: %v, stderr: %s", err, stderr)
+	}
+	
+	if stdout == "" {
+		t.Error("Dir command produced no output")
+	}
+}