From baad28ac9012967020558d4c1d388cb08abbd97d Mon Sep 17 00:00:00 2001 From: Kujtim Hoxha Date: Sat, 28 Jun 2025 13:35:50 +0200 Subject: [PATCH] chore: move shell to its own package --- internal/llm/tools/bash.go | 2 +- .../{llm/tools => }/shell/comparison_test.go | 4 +- internal/shell/doc.go | 30 +++ internal/shell/persistent.go | 38 ++++ internal/{llm/tools => }/shell/shell.go | 175 +++++++++++++----- internal/{llm/tools => }/shell/shell_test.go | 29 ++- 6 files changed, 212 insertions(+), 66 deletions(-) rename internal/{llm/tools => }/shell/comparison_test.go (88%) create mode 100644 internal/shell/doc.go create mode 100644 internal/shell/persistent.go rename internal/{llm/tools => }/shell/shell.go (64%) rename internal/{llm/tools => }/shell/shell_test.go (90%) diff --git a/internal/llm/tools/bash.go b/internal/llm/tools/bash.go index 194632742d1f9916172cb0933c37d1e7a42adbf3..abbd19113db746cd8e82c5cdebc02c4b8fc28b99 100644 --- a/internal/llm/tools/bash.go +++ b/internal/llm/tools/bash.go @@ -9,8 +9,8 @@ import ( "time" "github.com/charmbracelet/crush/internal/config" - "github.com/charmbracelet/crush/internal/llm/tools/shell" "github.com/charmbracelet/crush/internal/permission" + "github.com/charmbracelet/crush/internal/shell" ) type BashParams struct { diff --git a/internal/llm/tools/shell/comparison_test.go b/internal/shell/comparison_test.go similarity index 88% rename from internal/llm/tools/shell/comparison_test.go rename to internal/shell/comparison_test.go index 4906a5c8acb8136d7b7947177f999f7019080f5f..d92e00c17edf5a6b13c2c0d10ce9f52edd0a14ff 100644 --- a/internal/llm/tools/shell/comparison_test.go +++ b/internal/shell/comparison_test.go @@ -9,7 +9,7 @@ import ( ) func TestShellPerformanceComparison(t *testing.T) { - shell := newPersistentShell(t.TempDir()) + shell := NewShell(&Options{WorkingDir: t.TempDir()}) // Test quick command start := time.Now() @@ -27,7 +27,7 @@ func TestShellPerformanceComparison(t *testing.T) { // Benchmark CPU usage during polling func BenchmarkShellPolling(b *testing.B) { - shell := newPersistentShell(b.TempDir()) + shell := NewShell(&Options{WorkingDir: b.TempDir()}) b.ReportAllocs() diff --git a/internal/shell/doc.go b/internal/shell/doc.go new file mode 100644 index 0000000000000000000000000000000000000000..67c93de47625692341e53951c1b85c7ccc272cef --- /dev/null +++ b/internal/shell/doc.go @@ -0,0 +1,30 @@ +package shell + +// Example usage of the shell package: +// +// 1. For one-off commands: +// +// shell := shell.NewShell(nil) +// stdout, stderr, err := shell.Exec(context.Background(), "echo hello") +// +// 2. For maintaining state across commands: +// +// shell := shell.NewShell(&shell.Options{ +// WorkingDir: "/tmp", +// Logger: myLogger, +// }) +// shell.Exec(ctx, "export FOO=bar") +// shell.Exec(ctx, "echo $FOO") // Will print "bar" +// +// 3. For the singleton persistent shell (used by tools): +// +// shell := shell.GetPersistentShell("/path/to/cwd") +// stdout, stderr, err := shell.Exec(ctx, "ls -la") +// +// 4. Managing environment and working directory: +// +// shell := shell.NewShell(nil) +// shell.SetEnv("MY_VAR", "value") +// shell.SetWorkingDir("/tmp") +// cwd := shell.GetWorkingDir() +// env := shell.GetEnv() \ No newline at end of file diff --git a/internal/shell/persistent.go b/internal/shell/persistent.go new file mode 100644 index 0000000000000000000000000000000000000000..9038caaad67c46427e1852cd4ff68a5faa5b14b3 --- /dev/null +++ b/internal/shell/persistent.go @@ -0,0 +1,38 @@ +package shell + +import ( + "sync" + + "github.com/charmbracelet/crush/internal/logging" +) + +// PersistentShell is a singleton shell instance that maintains state across the application +type PersistentShell struct { + *Shell +} + +var ( + once sync.Once + shellInstance *PersistentShell +) + +// GetPersistentShell returns the singleton persistent shell instance +// This maintains backward compatibility with the existing API +func GetPersistentShell(cwd string) *PersistentShell { + once.Do(func() { + shellInstance = &PersistentShell{ + Shell: NewShell(&Options{ + WorkingDir: cwd, + Logger: &loggingAdapter{}, + }), + } + }) + return shellInstance +} + +// loggingAdapter adapts the internal logging package to the Logger interface +type loggingAdapter struct{} + +func (l *loggingAdapter) InfoPersist(msg string, keysAndValues ...interface{}) { + logging.InfoPersist(msg, keysAndValues...) +} \ No newline at end of file diff --git a/internal/llm/tools/shell/shell.go b/internal/shell/shell.go similarity index 64% rename from internal/llm/tools/shell/shell.go rename to internal/shell/shell.go index 81fa33085a40bf2bccd8814a5b179d3ca4453a8c..a80ce4a0237d82a6fdce08b9bf4752a2b6e1dcf5 100644 --- a/internal/llm/tools/shell/shell.go +++ b/internal/shell/shell.go @@ -1,11 +1,12 @@ // Package shell provides cross-platform shell execution capabilities. // +// This package offers two main types: +// - Shell: A general-purpose shell executor for one-off or managed commands +// - PersistentShell: A singleton shell that maintains state across the application +// // WINDOWS COMPATIBILITY: // This implementation provides both POSIX shell emulation (mvdan.cc/sh/v3) and -// native Windows shell support (cmd.exe/PowerShell) for optimal compatibility: -// - On Windows: Uses native cmd.exe or PowerShell for Windows-specific commands -// - Cross-platform: Falls back to POSIX emulation for Unix-style commands -// - Automatic detection: Chooses the best shell based on command and platform +// native Windows shell support (cmd.exe/PowerShell) for optimal compatibility. package shell import ( @@ -19,7 +20,6 @@ import ( "strings" "sync" - "github.com/charmbracelet/crush/internal/logging" "mvdan.cc/sh/v3/expand" "mvdan.cc/sh/v3/interp" "mvdan.cc/sh/v3/syntax" @@ -34,53 +34,61 @@ const ( ShellTypePowerShell ) -type PersistentShell struct { - env []string - cwd string - mu sync.Mutex +// Logger interface for optional logging +type Logger interface { + InfoPersist(msg string, keysAndValues ...interface{}) } -var ( - once sync.Once - shellInstance *PersistentShell -) +// noopLogger is a logger that does nothing +type noopLogger struct{} -// Windows-specific commands that should use native shell -var windowsNativeCommands = map[string]bool{ - "dir": true, - "type": true, - "copy": true, - "move": true, - "del": true, - "md": true, - "mkdir": true, - "rd": true, - "rmdir": true, - "cls": true, - "where": true, - "tasklist": true, - "taskkill": true, - "net": true, - "sc": true, - "reg": true, - "wmic": true, +func (noopLogger) InfoPersist(msg string, keysAndValues ...interface{}) {} + +// Shell provides cross-platform shell execution with optional state persistence +type Shell struct { + env []string + cwd string + mu sync.Mutex + logger Logger } -func GetPersistentShell(cwd string) *PersistentShell { - once.Do(func() { - shellInstance = newPersistentShell(cwd) - }) - return shellInstance +// Options for creating a new shell +type Options struct { + WorkingDir string + Env []string + Logger Logger } -func newPersistentShell(cwd string) *PersistentShell { - return &PersistentShell{ - cwd: cwd, - env: os.Environ(), +// NewShell creates a new shell instance with the given options +func NewShell(opts *Options) *Shell { + if opts == nil { + opts = &Options{} + } + + cwd := opts.WorkingDir + if cwd == "" { + cwd, _ = os.Getwd() + } + + env := opts.Env + if env == nil { + env = os.Environ() + } + + logger := opts.Logger + if logger == nil { + logger = noopLogger{} + } + + return &Shell{ + cwd: cwd, + env: env, + logger: logger, } } -func (s *PersistentShell) Exec(ctx context.Context, command string) (string, string, error) { +// Exec executes a command in the shell +func (s *Shell) Exec(ctx context.Context, command string) (string, string, error) { s.mu.Lock() defer s.mu.Unlock() @@ -97,8 +105,76 @@ func (s *PersistentShell) Exec(ctx context.Context, command string) (string, str } } +// GetWorkingDir returns the current working directory +func (s *Shell) GetWorkingDir() string { + s.mu.Lock() + defer s.mu.Unlock() + return s.cwd +} + +// SetWorkingDir sets the working directory +func (s *Shell) SetWorkingDir(dir string) error { + s.mu.Lock() + defer s.mu.Unlock() + + // Verify the directory exists + if _, err := os.Stat(dir); err != nil { + return fmt.Errorf("directory does not exist: %w", err) + } + + s.cwd = dir + return nil +} + +// GetEnv returns a copy of the environment variables +func (s *Shell) GetEnv() []string { + s.mu.Lock() + defer s.mu.Unlock() + + env := make([]string, len(s.env)) + copy(env, s.env) + return env +} + +// SetEnv sets an environment variable +func (s *Shell) SetEnv(key, value string) { + s.mu.Lock() + defer s.mu.Unlock() + + // Update or add the environment variable + keyPrefix := key + "=" + for i, env := range s.env { + if strings.HasPrefix(env, keyPrefix) { + s.env[i] = keyPrefix + value + return + } + } + s.env = append(s.env, keyPrefix+value) +} + +// Windows-specific commands that should use native shell +var windowsNativeCommands = map[string]bool{ + "dir": true, + "type": true, + "copy": true, + "move": true, + "del": true, + "md": true, + "mkdir": true, + "rd": true, + "rmdir": true, + "cls": true, + "where": true, + "tasklist": true, + "taskkill": true, + "net": true, + "sc": true, + "reg": true, + "wmic": true, +} + // determineShellType decides which shell to use based on platform and command -func (s *PersistentShell) determineShellType(command string) ShellType { +func (s *Shell) determineShellType(command string) ShellType { if runtime.GOOS != "windows" { return ShellTypePOSIX } @@ -128,7 +204,7 @@ func (s *PersistentShell) determineShellType(command string) ShellType { } // execWindows executes commands using native Windows shells (cmd.exe or PowerShell) -func (s *PersistentShell) execWindows(ctx context.Context, command string, shell string) (string, string, error) { +func (s *Shell) execWindows(ctx context.Context, command string, shell string) (string, string, error) { var cmd *exec.Cmd // Handle directory changes specially to maintain persistent shell behavior @@ -160,12 +236,12 @@ func (s *PersistentShell) execWindows(ctx context.Context, command string, shell err := cmd.Run() - logging.InfoPersist("Windows command finished", "shell", shell, "command", command, "err", err) + s.logger.InfoPersist("Windows command finished", "shell", shell, "command", command, "err", err) return stdout.String(), stderr.String(), err } // handleWindowsCD handles directory changes for Windows shells -func (s *PersistentShell) handleWindowsCD(command string) (string, string, error) { +func (s *Shell) handleWindowsCD(command string) (string, string, error) { // Extract the target directory from the cd command parts := strings.Fields(command) if len(parts) < 2 { @@ -203,7 +279,7 @@ func (s *PersistentShell) handleWindowsCD(command string) (string, string, error } // execPOSIX executes commands using POSIX shell emulation (cross-platform) -func (s *PersistentShell) execPOSIX(ctx context.Context, command string) (string, string, error) { +func (s *Shell) execPOSIX(ctx context.Context, command string) (string, string, error) { line, err := syntax.NewParser().Parse(strings.NewReader(command), "") if err != nil { return "", "", fmt.Errorf("could not parse command: %w", err) @@ -226,15 +302,17 @@ func (s *PersistentShell) execPOSIX(ctx context.Context, command string) (string for name, vr := range runner.Vars { s.env = append(s.env, fmt.Sprintf("%s=%s", name, vr.Str)) } - logging.InfoPersist("POSIX command finished", "command", command, "err", err) + s.logger.InfoPersist("POSIX command finished", "command", command, "err", err) return stdout.String(), stderr.String(), err } +// IsInterrupt checks if an error is due to interruption func IsInterrupt(err error) bool { return errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) } +// ExitCode extracts the exit code from an error func ExitCode(err error) int { if err == nil { return 0 @@ -245,3 +323,4 @@ func ExitCode(err error) int { } return 1 } + diff --git a/internal/llm/tools/shell/shell_test.go b/internal/shell/shell_test.go similarity index 90% rename from internal/llm/tools/shell/shell_test.go rename to internal/shell/shell_test.go index e4273a7a60a2ea96069b97b9d111e4a5b7a4c73a..417743caef7fa386f8c23d418682ab6a364e8e3e 100644 --- a/internal/llm/tools/shell/shell_test.go +++ b/internal/shell/shell_test.go @@ -10,7 +10,7 @@ import ( // Benchmark to measure CPU efficiency func BenchmarkShellQuickCommands(b *testing.B) { - shell := newPersistentShell(b.TempDir()) + shell := NewShell(&Options{WorkingDir: b.TempDir()}) b.ReportAllocs() @@ -27,7 +27,7 @@ func TestTestTimeout(t *testing.T) { ctx, cancel := context.WithTimeout(t.Context(), time.Millisecond) t.Cleanup(cancel) - shell := newPersistentShell(t.TempDir()) + shell := NewShell(&Options{WorkingDir: t.TempDir()}) _, _, err := shell.Exec(ctx, "sleep 10") if status := ExitCode(err); status == 0 { t.Fatalf("Expected non-zero exit status, got %d", status) @@ -44,7 +44,7 @@ func TestTestCancel(t *testing.T) { ctx, cancel := context.WithCancel(t.Context()) cancel() // immediately cancel the context - shell := newPersistentShell(t.TempDir()) + shell := NewShell(&Options{WorkingDir: t.TempDir()}) _, _, err := shell.Exec(ctx, "sleep 10") if status := ExitCode(err); status == 0 { t.Fatalf("Expected non-zero exit status, got %d", status) @@ -58,7 +58,7 @@ func TestTestCancel(t *testing.T) { } func TestRunCommandError(t *testing.T) { - shell := newPersistentShell(t.TempDir()) + shell := NewShell(&Options{WorkingDir: t.TempDir()}) _, _, err := shell.Exec(t.Context(), "nopenopenope") if status := ExitCode(err); status == 0 { t.Fatalf("Expected non-zero exit status, got %d", status) @@ -72,7 +72,7 @@ func TestRunCommandError(t *testing.T) { } func TestRunContinuity(t *testing.T) { - shell := newPersistentShell(t.TempDir()) + shell := NewShell(&Options{WorkingDir: t.TempDir()}) shell.Exec(t.Context(), "export FOO=bar") dst := t.TempDir() shell.Exec(t.Context(), "cd "+dst) @@ -141,10 +141,9 @@ func TestWindowsCDHandling(t *testing.T) { t.Skip("Windows CD handling test only runs on Windows") } - shell := &PersistentShell{ - cwd: "C:\\Users", - env: []string{}, - } + shell := NewShell(&Options{ + WorkingDir: "C:\\Users", + }) tests := []struct { command string @@ -159,7 +158,7 @@ func TestWindowsCDHandling(t *testing.T) { for _, test := range tests { t.Run(test.command, func(t *testing.T) { - originalCwd := shell.cwd + originalCwd := shell.GetWorkingDir() stdout, stderr, err := shell.handleWindowsCD(test.command) if test.shouldError { @@ -170,13 +169,13 @@ func TestWindowsCDHandling(t *testing.T) { if err != nil { t.Errorf("Command %q failed: %v", test.command, err) } - if shell.cwd != test.expectedCwd { - t.Errorf("Command %q: expected cwd %q, got %q", test.command, test.expectedCwd, shell.cwd) + if shell.GetWorkingDir() != test.expectedCwd { + t.Errorf("Command %q: expected cwd %q, got %q", test.command, test.expectedCwd, shell.GetWorkingDir()) } } // Reset for next test - shell.cwd = originalCwd + shell.SetWorkingDir(originalCwd) _ = stdout _ = stderr }) @@ -184,7 +183,7 @@ func TestWindowsCDHandling(t *testing.T) { } func TestCrossPlatformExecution(t *testing.T) { - shell := newPersistentShell(".") + shell := NewShell(&Options{WorkingDir: "."}) ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) defer cancel() @@ -209,7 +208,7 @@ func TestWindowsNativeCommands(t *testing.T) { t.Skip("Windows native command test only runs on Windows") } - shell := newPersistentShell(".") + shell := NewShell(&Options{WorkingDir: "."}) ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) defer cancel()