Merge pull request #87 from charmbracelet/crush-shell

Kujtim Hoxha created

Move shell to its own package

Change summary

internal/llm/tools/bash.go        |   2 
internal/shell/comparison_test.go |   4 
internal/shell/doc.go             |  30 +++++
internal/shell/persistent.go      |  38 +++++++
internal/shell/shell.go           | 174 +++++++++++++++++++++++---------
internal/shell/shell_test.go      |  29 ++--
6 files changed, 211 insertions(+), 66 deletions(-)

Detailed changes

internal/llm/tools/bash.go 🔗

@@ -9,8 +9,8 @@ import (
 	"time"
 
 	"github.com/charmbracelet/crush/internal/config"
-	"github.com/charmbracelet/crush/internal/llm/tools/shell"
 	"github.com/charmbracelet/crush/internal/permission"
+	"github.com/charmbracelet/crush/internal/shell"
 )
 
 type BashParams struct {

internal/llm/tools/shell/comparison_test.go → internal/shell/comparison_test.go 🔗

@@ -9,7 +9,7 @@ import (
 )
 
 func TestShellPerformanceComparison(t *testing.T) {
-	shell := newPersistentShell(t.TempDir())
+	shell := NewShell(&Options{WorkingDir: t.TempDir()})
 
 	// Test quick command
 	start := time.Now()
@@ -27,7 +27,7 @@ func TestShellPerformanceComparison(t *testing.T) {
 
 // Benchmark CPU usage during polling
 func BenchmarkShellPolling(b *testing.B) {
-	shell := newPersistentShell(b.TempDir())
+	shell := NewShell(&Options{WorkingDir: b.TempDir()})
 
 	b.ReportAllocs()
 

internal/shell/doc.go 🔗

@@ -0,0 +1,30 @@
+package shell
+
+// Example usage of the shell package:
+//
+// 1. For one-off commands:
+//
+//	shell := shell.NewShell(nil)
+//	stdout, stderr, err := shell.Exec(context.Background(), "echo hello")
+//
+// 2. For maintaining state across commands:
+//
+//	shell := shell.NewShell(&shell.Options{
+//	    WorkingDir: "/tmp",
+//	    Logger: myLogger,
+//	})
+//	shell.Exec(ctx, "export FOO=bar")
+//	shell.Exec(ctx, "echo $FOO")  // Will print "bar"
+//
+// 3. For the singleton persistent shell (used by tools):
+//
+//	shell := shell.GetPersistentShell("/path/to/cwd")
+//	stdout, stderr, err := shell.Exec(ctx, "ls -la")
+//
+// 4. Managing environment and working directory:
+//
+//	shell := shell.NewShell(nil)
+//	shell.SetEnv("MY_VAR", "value")
+//	shell.SetWorkingDir("/tmp")
+//	cwd := shell.GetWorkingDir()
+//	env := shell.GetEnv()

internal/shell/persistent.go 🔗

@@ -0,0 +1,38 @@
+package shell
+
+import (
+	"sync"
+
+	"github.com/charmbracelet/crush/internal/logging"
+)
+
+// PersistentShell is a singleton shell instance that maintains state across the application
+type PersistentShell struct {
+	*Shell
+}
+
+var (
+	once          sync.Once
+	shellInstance *PersistentShell
+)
+
+// GetPersistentShell returns the singleton persistent shell instance
+// This maintains backward compatibility with the existing API
+func GetPersistentShell(cwd string) *PersistentShell {
+	once.Do(func() {
+		shellInstance = &PersistentShell{
+			Shell: NewShell(&Options{
+				WorkingDir: cwd,
+				Logger:     &loggingAdapter{},
+			}),
+		}
+	})
+	return shellInstance
+}
+
+// loggingAdapter adapts the internal logging package to the Logger interface
+type loggingAdapter struct{}
+
+func (l *loggingAdapter) InfoPersist(msg string, keysAndValues ...interface{}) {
+	logging.InfoPersist(msg, keysAndValues...)
+}

internal/llm/tools/shell/shell.go → internal/shell/shell.go 🔗

@@ -1,11 +1,12 @@
 // Package shell provides cross-platform shell execution capabilities.
 //
+// This package offers two main types:
+// - Shell: A general-purpose shell executor for one-off or managed commands
+// - PersistentShell: A singleton shell that maintains state across the application
+//
 // WINDOWS COMPATIBILITY:
 // This implementation provides both POSIX shell emulation (mvdan.cc/sh/v3) and
-// native Windows shell support (cmd.exe/PowerShell) for optimal compatibility:
-// - On Windows: Uses native cmd.exe or PowerShell for Windows-specific commands
-// - Cross-platform: Falls back to POSIX emulation for Unix-style commands
-// - Automatic detection: Chooses the best shell based on command and platform
+// native Windows shell support (cmd.exe/PowerShell) for optimal compatibility.
 package shell
 
 import (
@@ -19,7 +20,6 @@ import (
 	"strings"
 	"sync"
 
-	"github.com/charmbracelet/crush/internal/logging"
 	"mvdan.cc/sh/v3/expand"
 	"mvdan.cc/sh/v3/interp"
 	"mvdan.cc/sh/v3/syntax"
@@ -34,53 +34,61 @@ const (
 	ShellTypePowerShell
 )
 
-type PersistentShell struct {
-	env []string
-	cwd string
-	mu  sync.Mutex
+// Logger interface for optional logging
+type Logger interface {
+	InfoPersist(msg string, keysAndValues ...interface{})
 }
 
-var (
-	once          sync.Once
-	shellInstance *PersistentShell
-)
+// noopLogger is a logger that does nothing
+type noopLogger struct{}
 
-// Windows-specific commands that should use native shell
-var windowsNativeCommands = map[string]bool{
-	"dir":      true,
-	"type":     true,
-	"copy":     true,
-	"move":     true,
-	"del":      true,
-	"md":       true,
-	"mkdir":    true,
-	"rd":       true,
-	"rmdir":    true,
-	"cls":      true,
-	"where":    true,
-	"tasklist": true,
-	"taskkill": true,
-	"net":      true,
-	"sc":       true,
-	"reg":      true,
-	"wmic":     true,
+func (noopLogger) InfoPersist(msg string, keysAndValues ...interface{}) {}
+
+// Shell provides cross-platform shell execution with optional state persistence
+type Shell struct {
+	env    []string
+	cwd    string
+	mu     sync.Mutex
+	logger Logger
 }
 
-func GetPersistentShell(cwd string) *PersistentShell {
-	once.Do(func() {
-		shellInstance = newPersistentShell(cwd)
-	})
-	return shellInstance
+// Options for creating a new shell
+type Options struct {
+	WorkingDir string
+	Env        []string
+	Logger     Logger
 }
 
-func newPersistentShell(cwd string) *PersistentShell {
-	return &PersistentShell{
-		cwd: cwd,
-		env: os.Environ(),
+// NewShell creates a new shell instance with the given options
+func NewShell(opts *Options) *Shell {
+	if opts == nil {
+		opts = &Options{}
+	}
+
+	cwd := opts.WorkingDir
+	if cwd == "" {
+		cwd, _ = os.Getwd()
+	}
+
+	env := opts.Env
+	if env == nil {
+		env = os.Environ()
+	}
+
+	logger := opts.Logger
+	if logger == nil {
+		logger = noopLogger{}
+	}
+
+	return &Shell{
+		cwd:    cwd,
+		env:    env,
+		logger: logger,
 	}
 }
 
-func (s *PersistentShell) Exec(ctx context.Context, command string) (string, string, error) {
+// Exec executes a command in the shell
+func (s *Shell) Exec(ctx context.Context, command string) (string, string, error) {
 	s.mu.Lock()
 	defer s.mu.Unlock()
 
@@ -97,8 +105,76 @@ func (s *PersistentShell) Exec(ctx context.Context, command string) (string, str
 	}
 }
 
+// GetWorkingDir returns the current working directory
+func (s *Shell) GetWorkingDir() string {
+	s.mu.Lock()
+	defer s.mu.Unlock()
+	return s.cwd
+}
+
+// SetWorkingDir sets the working directory
+func (s *Shell) SetWorkingDir(dir string) error {
+	s.mu.Lock()
+	defer s.mu.Unlock()
+
+	// Verify the directory exists
+	if _, err := os.Stat(dir); err != nil {
+		return fmt.Errorf("directory does not exist: %w", err)
+	}
+
+	s.cwd = dir
+	return nil
+}
+
+// GetEnv returns a copy of the environment variables
+func (s *Shell) GetEnv() []string {
+	s.mu.Lock()
+	defer s.mu.Unlock()
+
+	env := make([]string, len(s.env))
+	copy(env, s.env)
+	return env
+}
+
+// SetEnv sets an environment variable
+func (s *Shell) SetEnv(key, value string) {
+	s.mu.Lock()
+	defer s.mu.Unlock()
+
+	// Update or add the environment variable
+	keyPrefix := key + "="
+	for i, env := range s.env {
+		if strings.HasPrefix(env, keyPrefix) {
+			s.env[i] = keyPrefix + value
+			return
+		}
+	}
+	s.env = append(s.env, keyPrefix+value)
+}
+
+// Windows-specific commands that should use native shell
+var windowsNativeCommands = map[string]bool{
+	"dir":      true,
+	"type":     true,
+	"copy":     true,
+	"move":     true,
+	"del":      true,
+	"md":       true,
+	"mkdir":    true,
+	"rd":       true,
+	"rmdir":    true,
+	"cls":      true,
+	"where":    true,
+	"tasklist": true,
+	"taskkill": true,
+	"net":      true,
+	"sc":       true,
+	"reg":      true,
+	"wmic":     true,
+}
+
 // determineShellType decides which shell to use based on platform and command
-func (s *PersistentShell) determineShellType(command string) ShellType {
+func (s *Shell) determineShellType(command string) ShellType {
 	if runtime.GOOS != "windows" {
 		return ShellTypePOSIX
 	}
@@ -128,7 +204,7 @@ func (s *PersistentShell) determineShellType(command string) ShellType {
 }
 
 // execWindows executes commands using native Windows shells (cmd.exe or PowerShell)
-func (s *PersistentShell) execWindows(ctx context.Context, command string, shell string) (string, string, error) {
+func (s *Shell) execWindows(ctx context.Context, command string, shell string) (string, string, error) {
 	var cmd *exec.Cmd
 
 	// Handle directory changes specially to maintain persistent shell behavior
@@ -160,12 +236,12 @@ func (s *PersistentShell) execWindows(ctx context.Context, command string, shell
 
 	err := cmd.Run()
 
-	logging.InfoPersist("Windows command finished", "shell", shell, "command", command, "err", err)
+	s.logger.InfoPersist("Windows command finished", "shell", shell, "command", command, "err", err)
 	return stdout.String(), stderr.String(), err
 }
 
 // handleWindowsCD handles directory changes for Windows shells
-func (s *PersistentShell) handleWindowsCD(command string) (string, string, error) {
+func (s *Shell) handleWindowsCD(command string) (string, string, error) {
 	// Extract the target directory from the cd command
 	parts := strings.Fields(command)
 	if len(parts) < 2 {
@@ -203,7 +279,7 @@ func (s *PersistentShell) handleWindowsCD(command string) (string, string, error
 }
 
 // execPOSIX executes commands using POSIX shell emulation (cross-platform)
-func (s *PersistentShell) execPOSIX(ctx context.Context, command string) (string, string, error) {
+func (s *Shell) execPOSIX(ctx context.Context, command string) (string, string, error) {
 	line, err := syntax.NewParser().Parse(strings.NewReader(command), "")
 	if err != nil {
 		return "", "", fmt.Errorf("could not parse command: %w", err)
@@ -226,15 +302,17 @@ func (s *PersistentShell) execPOSIX(ctx context.Context, command string) (string
 	for name, vr := range runner.Vars {
 		s.env = append(s.env, fmt.Sprintf("%s=%s", name, vr.Str))
 	}
-	logging.InfoPersist("POSIX command finished", "command", command, "err", err)
+	s.logger.InfoPersist("POSIX command finished", "command", command, "err", err)
 	return stdout.String(), stderr.String(), err
 }
 
+// IsInterrupt checks if an error is due to interruption
 func IsInterrupt(err error) bool {
 	return errors.Is(err, context.Canceled) ||
 		errors.Is(err, context.DeadlineExceeded)
 }
 
+// ExitCode extracts the exit code from an error
 func ExitCode(err error) int {
 	if err == nil {
 		return 0

internal/llm/tools/shell/shell_test.go → internal/shell/shell_test.go 🔗

@@ -10,7 +10,7 @@ import (
 
 // Benchmark to measure CPU efficiency
 func BenchmarkShellQuickCommands(b *testing.B) {
-	shell := newPersistentShell(b.TempDir())
+	shell := NewShell(&Options{WorkingDir: b.TempDir()})
 
 	b.ReportAllocs()
 
@@ -27,7 +27,7 @@ func TestTestTimeout(t *testing.T) {
 	ctx, cancel := context.WithTimeout(t.Context(), time.Millisecond)
 	t.Cleanup(cancel)
 
-	shell := newPersistentShell(t.TempDir())
+	shell := NewShell(&Options{WorkingDir: t.TempDir()})
 	_, _, err := shell.Exec(ctx, "sleep 10")
 	if status := ExitCode(err); status == 0 {
 		t.Fatalf("Expected non-zero exit status, got %d", status)
@@ -44,7 +44,7 @@ func TestTestCancel(t *testing.T) {
 	ctx, cancel := context.WithCancel(t.Context())
 	cancel() // immediately cancel the context
 
-	shell := newPersistentShell(t.TempDir())
+	shell := NewShell(&Options{WorkingDir: t.TempDir()})
 	_, _, err := shell.Exec(ctx, "sleep 10")
 	if status := ExitCode(err); status == 0 {
 		t.Fatalf("Expected non-zero exit status, got %d", status)
@@ -58,7 +58,7 @@ func TestTestCancel(t *testing.T) {
 }
 
 func TestRunCommandError(t *testing.T) {
-	shell := newPersistentShell(t.TempDir())
+	shell := NewShell(&Options{WorkingDir: t.TempDir()})
 	_, _, err := shell.Exec(t.Context(), "nopenopenope")
 	if status := ExitCode(err); status == 0 {
 		t.Fatalf("Expected non-zero exit status, got %d", status)
@@ -72,7 +72,7 @@ func TestRunCommandError(t *testing.T) {
 }
 
 func TestRunContinuity(t *testing.T) {
-	shell := newPersistentShell(t.TempDir())
+	shell := NewShell(&Options{WorkingDir: t.TempDir()})
 	shell.Exec(t.Context(), "export FOO=bar")
 	dst := t.TempDir()
 	shell.Exec(t.Context(), "cd "+dst)
@@ -141,10 +141,9 @@ func TestWindowsCDHandling(t *testing.T) {
 		t.Skip("Windows CD handling test only runs on Windows")
 	}
 
-	shell := &PersistentShell{
-		cwd: "C:\\Users",
-		env: []string{},
-	}
+	shell := NewShell(&Options{
+		WorkingDir: "C:\\Users",
+	})
 
 	tests := []struct {
 		command     string
@@ -159,7 +158,7 @@ func TestWindowsCDHandling(t *testing.T) {
 
 	for _, test := range tests {
 		t.Run(test.command, func(t *testing.T) {
-			originalCwd := shell.cwd
+			originalCwd := shell.GetWorkingDir()
 			stdout, stderr, err := shell.handleWindowsCD(test.command)
 
 			if test.shouldError {
@@ -170,13 +169,13 @@ func TestWindowsCDHandling(t *testing.T) {
 				if err != nil {
 					t.Errorf("Command %q failed: %v", test.command, err)
 				}
-				if shell.cwd != test.expectedCwd {
-					t.Errorf("Command %q: expected cwd %q, got %q", test.command, test.expectedCwd, shell.cwd)
+				if shell.GetWorkingDir() != test.expectedCwd {
+					t.Errorf("Command %q: expected cwd %q, got %q", test.command, test.expectedCwd, shell.GetWorkingDir())
 				}
 			}
 
 			// Reset for next test
-			shell.cwd = originalCwd
+			shell.SetWorkingDir(originalCwd)
 			_ = stdout
 			_ = stderr
 		})
@@ -184,7 +183,7 @@ func TestWindowsCDHandling(t *testing.T) {
 }
 
 func TestCrossPlatformExecution(t *testing.T) {
-	shell := newPersistentShell(".")
+	shell := NewShell(&Options{WorkingDir: "."})
 	ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
 	defer cancel()
 
@@ -209,7 +208,7 @@ func TestWindowsNativeCommands(t *testing.T) {
 		t.Skip("Windows native command test only runs on Windows")
 	}
 
-	shell := newPersistentShell(".")
+	shell := NewShell(&Options{WorkingDir: "."})
 	ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
 	defer cancel()