@@ -13,7 +13,8 @@ func TestShellPerformanceComparison(t *testing.T) {
// Test quick command
start := time.Now()
- stdout, stderr, exitCode, _, err := shell.Exec(t.Context(), "echo 'hello'", 0)
+ stdout, stderr, err := shell.Exec(t.Context(), "echo 'hello'")
+ exitCode := ExitCode(err)
duration := time.Since(start)
require.NoError(t, err)
@@ -32,7 +33,8 @@ func BenchmarkShellPolling(b *testing.B) {
for b.Loop() {
// Use a short sleep to measure polling overhead
- _, _, exitCode, _, err := shell.Exec(b.Context(), "sleep 0.02", 0)
+ _, _, err := shell.Exec(b.Context(), "sleep 0.02")
+ exitCode := ExitCode(err)
if err != nil || exitCode != 0 {
b.Fatalf("Command failed: %v, exit code: %d", err, exitCode)
}
@@ -1,323 +1,87 @@
package shell
import (
+ "bytes"
"context"
"errors"
"fmt"
"os"
- "os/exec"
- "path/filepath"
"strings"
"sync"
- "syscall"
- "time"
+
+ "github.com/charmbracelet/crush/internal/logging"
+ "mvdan.cc/sh/v3/expand"
+ "mvdan.cc/sh/v3/interp"
+ "mvdan.cc/sh/v3/syntax"
)
type PersistentShell struct {
- cmd *exec.Cmd
- stdin *os.File
- isAlive bool
- cwd string
- mu sync.Mutex
- commandQueue chan *commandExecution
-}
-
-type commandExecution struct {
- command string
- timeout time.Duration
- resultChan chan commandResult
- ctx context.Context
-}
-
-type commandResult struct {
- stdout string
- stderr string
- exitCode int
- interrupted bool
- err error
+ env []string
+ cwd string
+ mu sync.Mutex
}
var (
- shellInstance *PersistentShell
- shellInstanceOnce sync.Once
+ once sync.Once
+ shellInstance *PersistentShell
)
-func GetPersistentShell(workingDir string) *PersistentShell {
- shellInstanceOnce.Do(func() {
- shellInstance = newPersistentShell(workingDir)
+func GetPersistentShell(cwd string) *PersistentShell {
+ once.Do(func() {
+ shellInstance = newPersistentShell(cwd)
})
-
- if shellInstance == nil {
- shellInstance = newPersistentShell(workingDir)
- } else if !shellInstance.isAlive {
- shellInstance = newPersistentShell(shellInstance.cwd)
- }
-
return shellInstance
}
func newPersistentShell(cwd string) *PersistentShell {
- // Default to environment variable
- shellPath := os.Getenv("SHELL")
- if shellPath == "" {
- shellPath = "/bin/bash"
- }
-
- // Default shell args
- shellArgs := []string{"-l"}
-
- cmd := exec.Command(shellPath, shellArgs...)
- cmd.Dir = cwd
-
- stdinPipe, err := cmd.StdinPipe()
- if err != nil {
- return nil
- }
-
- cmd.Env = append(os.Environ(), "GIT_EDITOR=true")
-
- err = cmd.Start()
- if err != nil {
- return nil
- }
-
- shell := &PersistentShell{
- cmd: cmd,
- stdin: stdinPipe.(*os.File),
- isAlive: true,
- cwd: cwd,
- commandQueue: make(chan *commandExecution, 10),
+ return &PersistentShell{
+ cwd: cwd,
+ env: os.Environ(),
}
-
- go func() {
- defer func() {
- if r := recover(); r != nil {
- fmt.Fprintf(os.Stderr, "Panic in shell command processor: %v\n", r)
- shell.isAlive = false
- close(shell.commandQueue)
- }
- }()
- shell.processCommands()
- }()
-
- go func() {
- err := cmd.Wait()
- if err != nil {
- // Log the error if needed
- }
- shell.isAlive = false
- close(shell.commandQueue)
- }()
-
- return shell
}
-func (s *PersistentShell) processCommands() {
- for cmd := range s.commandQueue {
- result := s.execCommand(cmd.command, cmd.timeout, cmd.ctx)
- cmd.resultChan <- result
- }
-}
-
-func (s *PersistentShell) execCommand(command string, timeout time.Duration, ctx context.Context) commandResult {
+func (s *PersistentShell) Exec(ctx context.Context, command string) (string, string, error) {
s.mu.Lock()
defer s.mu.Unlock()
- if !s.isAlive {
- return commandResult{
- stderr: "Shell is not alive",
- exitCode: 1,
- err: errors.New("shell is not alive"),
- }
- }
-
- tempDir := os.TempDir()
- stdoutFile := filepath.Join(tempDir, fmt.Sprintf("crush-stdout-%d", time.Now().UnixNano()))
- stderrFile := filepath.Join(tempDir, fmt.Sprintf("crush-stderr-%d", time.Now().UnixNano()))
- statusFile := filepath.Join(tempDir, fmt.Sprintf("crush-status-%d", time.Now().UnixNano()))
- cwdFile := filepath.Join(tempDir, fmt.Sprintf("crush-cwd-%d", time.Now().UnixNano()))
-
- defer func() {
- os.Remove(stdoutFile)
- os.Remove(stderrFile)
- os.Remove(statusFile)
- os.Remove(cwdFile)
- }()
-
- fullCommand := fmt.Sprintf(`
-eval %s < /dev/null > %s 2> %s
-EXEC_EXIT_CODE=$?
-pwd > %s
-echo $EXEC_EXIT_CODE > %s
-`,
- shellQuote(command),
- shellQuote(stdoutFile),
- shellQuote(stderrFile),
- shellQuote(cwdFile),
- shellQuote(statusFile),
- )
-
- _, err := s.stdin.Write([]byte(fullCommand + "\n"))
+ line, err := syntax.NewParser().Parse(strings.NewReader(command), "")
if err != nil {
- return commandResult{
- stderr: fmt.Sprintf("Failed to write command to shell: %v", err),
- exitCode: 1,
- err: err,
- }
+ return "", "", fmt.Errorf("could not parse command: %w", err)
}
- interrupted := false
-
- startTime := time.Now()
-
- done := make(chan bool)
- go func() {
- // Use exponential backoff polling
- pollInterval := 1 * time.Millisecond
- maxPollInterval := 100 * time.Millisecond
-
- ticker := time.NewTicker(pollInterval)
- defer ticker.Stop()
-
- for {
- select {
- case <-ctx.Done():
- s.killChildren()
- interrupted = true
- done <- true
- return
-
- case <-ticker.C:
- if fileExists(statusFile) && fileSize(statusFile) > 0 {
- done <- true
- return
- }
-
- if timeout > 0 {
- elapsed := time.Since(startTime)
- if elapsed > timeout {
- s.killChildren()
- interrupted = true
- done <- true
- return
- }
- }
-
- // Exponential backoff to reduce CPU usage for longer-running commands
- if pollInterval < maxPollInterval {
- pollInterval = min(time.Duration(float64(pollInterval)*1.5), maxPollInterval)
- ticker.Reset(pollInterval)
- }
- }
- }
- }()
-
- <-done
-
- stdout := readFileOrEmpty(stdoutFile)
- stderr := readFileOrEmpty(stderrFile)
- exitCodeStr := readFileOrEmpty(statusFile)
- newCwd := readFileOrEmpty(cwdFile)
-
- exitCode := 0
- if exitCodeStr != "" {
- fmt.Sscanf(exitCodeStr, "%d", &exitCode)
- } else if interrupted {
- exitCode = 143
- stderr += "\nCommand execution timed out or was interrupted"
- }
-
- if newCwd != "" {
- s.cwd = strings.TrimSpace(newCwd)
- }
-
- return commandResult{
- stdout: stdout,
- stderr: stderr,
- exitCode: exitCode,
- interrupted: interrupted,
- }
-}
-
-func (s *PersistentShell) killChildren() {
- if s.cmd == nil || s.cmd.Process == nil {
- return
- }
-
- pgrepCmd := exec.Command("pgrep", "-P", fmt.Sprintf("%d", s.cmd.Process.Pid))
- output, err := pgrepCmd.Output()
+ var stdout, stderr bytes.Buffer
+ runner, err := interp.New(
+ interp.StdIO(nil, &stdout, &stderr),
+ interp.Interactive(false),
+ interp.Env(expand.ListEnviron(s.env...)),
+ interp.Dir(s.cwd),
+ )
if err != nil {
- return
+ return "", "", fmt.Errorf("could not run command: %w", err)
}
- for pidStr := range strings.SplitSeq(string(output), "\n") {
- if pidStr = strings.TrimSpace(pidStr); pidStr != "" {
- var pid int
- fmt.Sscanf(pidStr, "%d", &pid)
- if pid > 0 {
- proc, err := os.FindProcess(pid)
- if err == nil {
- proc.Signal(syscall.SIGTERM)
- }
- }
- }
+ err = runner.Run(ctx, line)
+ s.cwd = runner.Dir
+ s.env = []string{}
+ for name, vr := range runner.Vars {
+ s.env = append(s.env, fmt.Sprintf("%s=%s", name, vr.Str))
}
+ logging.InfoPersist("Command finished", "command", command, "err", err)
+ return stdout.String(), stderr.String(), err
}
-func (s *PersistentShell) Exec(ctx context.Context, command string, timeoutMs int) (string, string, int, bool, error) {
- if !s.isAlive {
- return "", "Shell is not alive", 1, false, errors.New("shell is not alive")
- }
-
- timeout := time.Duration(timeoutMs) * time.Millisecond
-
- resultChan := make(chan commandResult)
- s.commandQueue <- &commandExecution{
- command: command,
- timeout: timeout,
- resultChan: resultChan,
- ctx: ctx,
- }
-
- result := <-resultChan
- return result.stdout, result.stderr, result.exitCode, result.interrupted, result.err
+func IsInterrupt(err error) bool {
+ return errors.Is(err, context.Canceled) ||
+ errors.Is(err, context.DeadlineExceeded)
}
-func (s *PersistentShell) Close() {
- s.mu.Lock()
- defer s.mu.Unlock()
-
- if !s.isAlive {
- return
+func ExitCode(err error) int {
+ if err == nil {
+ return 0
}
-
- s.stdin.Write([]byte("exit\n"))
-
- s.cmd.Process.Kill()
- s.isAlive = false
-}
-
-func shellQuote(s string) string {
- return "'" + strings.ReplaceAll(s, "'", "'\\''") + "'"
-}
-
-func readFileOrEmpty(path string) string {
- content, err := os.ReadFile(path)
- if err != nil {
- return ""
+ status, ok := interp.IsExitStatus(err)
+ if ok {
+ return int(status)
}
- return string(content)
-}
-
-func fileExists(path string) bool {
- _, err := os.Stat(path)
- return err == nil
+ return 1
}
-
-func fileSize(path string) int64 {
- info, err := os.Stat(path)
- if err != nil {
- return 0
- }
- return info.Size()
-}
@@ -2,28 +2,81 @@ package shell
import (
"context"
- "os"
"testing"
-
- "github.com/stretchr/testify/require"
+ "time"
)
// Benchmark to measure CPU efficiency
func BenchmarkShellQuickCommands(b *testing.B) {
- tmpDir, err := os.MkdirTemp("", "shell-bench")
- require.NoError(b, err)
- defer os.RemoveAll(tmpDir)
-
- shell := GetPersistentShell(tmpDir)
- defer shell.Close()
+ shell := newPersistentShell(b.TempDir())
- b.ResetTimer()
b.ReportAllocs()
for b.Loop() {
- _, _, exitCode, _, err := shell.Exec(context.Background(), "echo test", 0)
+ _, _, err := shell.Exec(context.Background(), "echo test")
+ exitCode := ExitCode(err)
if err != nil || exitCode != 0 {
b.Fatalf("Command failed: %v, exit code: %d", err, exitCode)
}
}
-}
+}
+
+func TestTestTimeout(t *testing.T) {
+ ctx, cancel := context.WithTimeout(t.Context(), time.Millisecond)
+ t.Cleanup(cancel)
+
+ shell := newPersistentShell(t.TempDir())
+ _, _, err := shell.Exec(ctx, "sleep 10")
+ if status := ExitCode(err); status == 0 {
+ t.Fatalf("Expected non-zero exit status, got %d", status)
+ }
+ if !IsInterrupt(err) {
+ t.Fatalf("Expected command to be interrupted, but it was not")
+ }
+ if err == nil {
+ t.Fatalf("Expected an error due to timeout, but got none")
+ }
+}
+
+func TestTestCancel(t *testing.T) {
+ ctx, cancel := context.WithCancel(t.Context())
+ cancel() // immediately cancel the context
+
+ shell := newPersistentShell(t.TempDir())
+ _, _, err := shell.Exec(ctx, "sleep 10")
+ if status := ExitCode(err); status == 0 {
+ t.Fatalf("Expected non-zero exit status, got %d", status)
+ }
+ if !IsInterrupt(err) {
+ t.Fatalf("Expected command to be interrupted, but it was not")
+ }
+ if err == nil {
+ t.Fatalf("Expected an error due to cancel, but got none")
+ }
+}
+
+func TestRunCommandError(t *testing.T) {
+ shell := newPersistentShell(t.TempDir())
+ _, _, err := shell.Exec(t.Context(), "nopenopenope")
+ if status := ExitCode(err); status == 0 {
+ t.Fatalf("Expected non-zero exit status, got %d", status)
+ }
+ if IsInterrupt(err) {
+ t.Fatalf("Expected command to not be interrupted, but it was")
+ }
+ if err == nil {
+ t.Fatalf("Expected an error, got nil")
+ }
+}
+
+func TestRunContinuity(t *testing.T) {
+ shell := newPersistentShell(t.TempDir())
+ shell.Exec(t.Context(), "export FOO=bar")
+ dst := t.TempDir()
+ shell.Exec(t.Context(), "cd "+dst)
+ out, _, _ := shell.Exec(t.Context(), "echo $FOO ; pwd")
+ expect := "bar\n" + dst + "\n"
+ if out != expect {
+ t.Fatalf("Expected output %q, got %q", expect, out)
+ }
+}