@@ -1,11 +1,12 @@
// Package shell provides cross-platform shell execution capabilities.
//
+// This package offers two main types:
+// - Shell: A general-purpose shell executor for one-off or managed commands
+// - PersistentShell: A singleton shell that maintains state across the application
+//
// WINDOWS COMPATIBILITY:
// This implementation provides both POSIX shell emulation (mvdan.cc/sh/v3) and
-// native Windows shell support (cmd.exe/PowerShell) for optimal compatibility:
-// - On Windows: Uses native cmd.exe or PowerShell for Windows-specific commands
-// - Cross-platform: Falls back to POSIX emulation for Unix-style commands
-// - Automatic detection: Chooses the best shell based on command and platform
+// native Windows shell support (cmd.exe/PowerShell) for optimal compatibility.
package shell
import (
@@ -19,7 +20,6 @@ import (
"strings"
"sync"
- "github.com/charmbracelet/crush/internal/logging"
"mvdan.cc/sh/v3/expand"
"mvdan.cc/sh/v3/interp"
"mvdan.cc/sh/v3/syntax"
@@ -34,53 +34,61 @@ const (
ShellTypePowerShell
)
-type PersistentShell struct {
- env []string
- cwd string
- mu sync.Mutex
+// Logger interface for optional logging
+type Logger interface {
+ InfoPersist(msg string, keysAndValues ...interface{})
}
-var (
- once sync.Once
- shellInstance *PersistentShell
-)
+// noopLogger is a logger that does nothing
+type noopLogger struct{}
-// Windows-specific commands that should use native shell
-var windowsNativeCommands = map[string]bool{
- "dir": true,
- "type": true,
- "copy": true,
- "move": true,
- "del": true,
- "md": true,
- "mkdir": true,
- "rd": true,
- "rmdir": true,
- "cls": true,
- "where": true,
- "tasklist": true,
- "taskkill": true,
- "net": true,
- "sc": true,
- "reg": true,
- "wmic": true,
+func (noopLogger) InfoPersist(msg string, keysAndValues ...interface{}) {}
+
+// Shell provides cross-platform shell execution with optional state persistence
+type Shell struct {
+ env []string
+ cwd string
+ mu sync.Mutex
+ logger Logger
}
-func GetPersistentShell(cwd string) *PersistentShell {
- once.Do(func() {
- shellInstance = newPersistentShell(cwd)
- })
- return shellInstance
+// Options for creating a new shell
+type Options struct {
+ WorkingDir string
+ Env []string
+ Logger Logger
}
-func newPersistentShell(cwd string) *PersistentShell {
- return &PersistentShell{
- cwd: cwd,
- env: os.Environ(),
+// NewShell creates a new shell instance with the given options
+func NewShell(opts *Options) *Shell {
+ if opts == nil {
+ opts = &Options{}
+ }
+
+ cwd := opts.WorkingDir
+ if cwd == "" {
+ cwd, _ = os.Getwd()
+ }
+
+ env := opts.Env
+ if env == nil {
+ env = os.Environ()
+ }
+
+ logger := opts.Logger
+ if logger == nil {
+ logger = noopLogger{}
+ }
+
+ return &Shell{
+ cwd: cwd,
+ env: env,
+ logger: logger,
}
}
-func (s *PersistentShell) Exec(ctx context.Context, command string) (string, string, error) {
+// Exec executes a command in the shell
+func (s *Shell) Exec(ctx context.Context, command string) (string, string, error) {
s.mu.Lock()
defer s.mu.Unlock()
@@ -97,8 +105,76 @@ func (s *PersistentShell) Exec(ctx context.Context, command string) (string, str
}
}
+// GetWorkingDir returns the current working directory
+func (s *Shell) GetWorkingDir() string {
+ s.mu.Lock()
+ defer s.mu.Unlock()
+ return s.cwd
+}
+
+// SetWorkingDir sets the working directory
+func (s *Shell) SetWorkingDir(dir string) error {
+ s.mu.Lock()
+ defer s.mu.Unlock()
+
+ // Verify the directory exists
+ if _, err := os.Stat(dir); err != nil {
+ return fmt.Errorf("directory does not exist: %w", err)
+ }
+
+ s.cwd = dir
+ return nil
+}
+
+// GetEnv returns a copy of the environment variables
+func (s *Shell) GetEnv() []string {
+ s.mu.Lock()
+ defer s.mu.Unlock()
+
+ env := make([]string, len(s.env))
+ copy(env, s.env)
+ return env
+}
+
+// SetEnv sets an environment variable
+func (s *Shell) SetEnv(key, value string) {
+ s.mu.Lock()
+ defer s.mu.Unlock()
+
+ // Update or add the environment variable
+ keyPrefix := key + "="
+ for i, env := range s.env {
+ if strings.HasPrefix(env, keyPrefix) {
+ s.env[i] = keyPrefix + value
+ return
+ }
+ }
+ s.env = append(s.env, keyPrefix+value)
+}
+
+// Windows-specific commands that should use native shell
+var windowsNativeCommands = map[string]bool{
+ "dir": true,
+ "type": true,
+ "copy": true,
+ "move": true,
+ "del": true,
+ "md": true,
+ "mkdir": true,
+ "rd": true,
+ "rmdir": true,
+ "cls": true,
+ "where": true,
+ "tasklist": true,
+ "taskkill": true,
+ "net": true,
+ "sc": true,
+ "reg": true,
+ "wmic": true,
+}
+
// determineShellType decides which shell to use based on platform and command
-func (s *PersistentShell) determineShellType(command string) ShellType {
+func (s *Shell) determineShellType(command string) ShellType {
if runtime.GOOS != "windows" {
return ShellTypePOSIX
}
@@ -128,7 +204,7 @@ func (s *PersistentShell) determineShellType(command string) ShellType {
}
// execWindows executes commands using native Windows shells (cmd.exe or PowerShell)
-func (s *PersistentShell) execWindows(ctx context.Context, command string, shell string) (string, string, error) {
+func (s *Shell) execWindows(ctx context.Context, command string, shell string) (string, string, error) {
var cmd *exec.Cmd
// Handle directory changes specially to maintain persistent shell behavior
@@ -160,12 +236,12 @@ func (s *PersistentShell) execWindows(ctx context.Context, command string, shell
err := cmd.Run()
- logging.InfoPersist("Windows command finished", "shell", shell, "command", command, "err", err)
+ s.logger.InfoPersist("Windows command finished", "shell", shell, "command", command, "err", err)
return stdout.String(), stderr.String(), err
}
// handleWindowsCD handles directory changes for Windows shells
-func (s *PersistentShell) handleWindowsCD(command string) (string, string, error) {
+func (s *Shell) handleWindowsCD(command string) (string, string, error) {
// Extract the target directory from the cd command
parts := strings.Fields(command)
if len(parts) < 2 {
@@ -203,7 +279,7 @@ func (s *PersistentShell) handleWindowsCD(command string) (string, string, error
}
// execPOSIX executes commands using POSIX shell emulation (cross-platform)
-func (s *PersistentShell) execPOSIX(ctx context.Context, command string) (string, string, error) {
+func (s *Shell) execPOSIX(ctx context.Context, command string) (string, string, error) {
line, err := syntax.NewParser().Parse(strings.NewReader(command), "")
if err != nil {
return "", "", fmt.Errorf("could not parse command: %w", err)
@@ -226,15 +302,17 @@ func (s *PersistentShell) execPOSIX(ctx context.Context, command string) (string
for name, vr := range runner.Vars {
s.env = append(s.env, fmt.Sprintf("%s=%s", name, vr.Str))
}
- logging.InfoPersist("POSIX command finished", "command", command, "err", err)
+ s.logger.InfoPersist("POSIX command finished", "command", command, "err", err)
return stdout.String(), stderr.String(), err
}
+// IsInterrupt checks if an error is due to interruption
func IsInterrupt(err error) bool {
return errors.Is(err, context.Canceled) ||
errors.Is(err, context.DeadlineExceeded)
}
+// ExitCode extracts the exit code from an error
func ExitCode(err error) int {
if err == nil {
return 0
@@ -245,3 +323,4 @@ func ExitCode(err error) int {
}
return 1
}
+
@@ -10,7 +10,7 @@ import (
// Benchmark to measure CPU efficiency
func BenchmarkShellQuickCommands(b *testing.B) {
- shell := newPersistentShell(b.TempDir())
+ shell := NewShell(&Options{WorkingDir: b.TempDir()})
b.ReportAllocs()
@@ -27,7 +27,7 @@ func TestTestTimeout(t *testing.T) {
ctx, cancel := context.WithTimeout(t.Context(), time.Millisecond)
t.Cleanup(cancel)
- shell := newPersistentShell(t.TempDir())
+ shell := NewShell(&Options{WorkingDir: t.TempDir()})
_, _, err := shell.Exec(ctx, "sleep 10")
if status := ExitCode(err); status == 0 {
t.Fatalf("Expected non-zero exit status, got %d", status)
@@ -44,7 +44,7 @@ func TestTestCancel(t *testing.T) {
ctx, cancel := context.WithCancel(t.Context())
cancel() // immediately cancel the context
- shell := newPersistentShell(t.TempDir())
+ shell := NewShell(&Options{WorkingDir: t.TempDir()})
_, _, err := shell.Exec(ctx, "sleep 10")
if status := ExitCode(err); status == 0 {
t.Fatalf("Expected non-zero exit status, got %d", status)
@@ -58,7 +58,7 @@ func TestTestCancel(t *testing.T) {
}
func TestRunCommandError(t *testing.T) {
- shell := newPersistentShell(t.TempDir())
+ shell := NewShell(&Options{WorkingDir: t.TempDir()})
_, _, err := shell.Exec(t.Context(), "nopenopenope")
if status := ExitCode(err); status == 0 {
t.Fatalf("Expected non-zero exit status, got %d", status)
@@ -72,7 +72,7 @@ func TestRunCommandError(t *testing.T) {
}
func TestRunContinuity(t *testing.T) {
- shell := newPersistentShell(t.TempDir())
+ shell := NewShell(&Options{WorkingDir: t.TempDir()})
shell.Exec(t.Context(), "export FOO=bar")
dst := t.TempDir()
shell.Exec(t.Context(), "cd "+dst)
@@ -141,10 +141,9 @@ func TestWindowsCDHandling(t *testing.T) {
t.Skip("Windows CD handling test only runs on Windows")
}
- shell := &PersistentShell{
- cwd: "C:\\Users",
- env: []string{},
- }
+ shell := NewShell(&Options{
+ WorkingDir: "C:\\Users",
+ })
tests := []struct {
command string
@@ -159,7 +158,7 @@ func TestWindowsCDHandling(t *testing.T) {
for _, test := range tests {
t.Run(test.command, func(t *testing.T) {
- originalCwd := shell.cwd
+ originalCwd := shell.GetWorkingDir()
stdout, stderr, err := shell.handleWindowsCD(test.command)
if test.shouldError {
@@ -170,13 +169,13 @@ func TestWindowsCDHandling(t *testing.T) {
if err != nil {
t.Errorf("Command %q failed: %v", test.command, err)
}
- if shell.cwd != test.expectedCwd {
- t.Errorf("Command %q: expected cwd %q, got %q", test.command, test.expectedCwd, shell.cwd)
+ if shell.GetWorkingDir() != test.expectedCwd {
+ t.Errorf("Command %q: expected cwd %q, got %q", test.command, test.expectedCwd, shell.GetWorkingDir())
}
}
// Reset for next test
- shell.cwd = originalCwd
+ shell.SetWorkingDir(originalCwd)
_ = stdout
_ = stderr
})
@@ -184,7 +183,7 @@ func TestWindowsCDHandling(t *testing.T) {
}
func TestCrossPlatformExecution(t *testing.T) {
- shell := newPersistentShell(".")
+ shell := NewShell(&Options{WorkingDir: "."})
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
@@ -209,7 +208,7 @@ func TestWindowsNativeCommands(t *testing.T) {
t.Skip("Windows native command test only runs on Windows")
}
- shell := newPersistentShell(".")
+ shell := NewShell(&Options{WorkingDir: "."})
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()