diff --git a/internal/llm/tools/shell/comparison_test.go b/internal/llm/tools/shell/comparison_test.go index 8fe2159b85d392d87b2602f77f105375cceb07c3..4906a5c8acb8136d7b7947177f999f7019080f5f 100644 --- a/internal/llm/tools/shell/comparison_test.go +++ b/internal/llm/tools/shell/comparison_test.go @@ -13,7 +13,8 @@ func TestShellPerformanceComparison(t *testing.T) { // Test quick command start := time.Now() - stdout, stderr, exitCode, _, err := shell.Exec(t.Context(), "echo 'hello'", 0) + stdout, stderr, err := shell.Exec(t.Context(), "echo 'hello'") + exitCode := ExitCode(err) duration := time.Since(start) require.NoError(t, err) @@ -32,7 +33,8 @@ func BenchmarkShellPolling(b *testing.B) { for b.Loop() { // Use a short sleep to measure polling overhead - _, _, exitCode, _, err := shell.Exec(b.Context(), "sleep 0.02", 0) + _, _, err := shell.Exec(b.Context(), "sleep 0.02") + exitCode := ExitCode(err) if err != nil || exitCode != 0 { b.Fatalf("Command failed: %v, exit code: %d", err, exitCode) } diff --git a/internal/llm/tools/shell/shell.go b/internal/llm/tools/shell/shell.go index c9d6826a7f4f17f3df29e1871790222c5df55fa6..cbf33e700c15ae06b66a4cbafa620bbc7ceb1405 100644 --- a/internal/llm/tools/shell/shell.go +++ b/internal/llm/tools/shell/shell.go @@ -1,323 +1,87 @@ package shell import ( + "bytes" "context" "errors" "fmt" "os" - "os/exec" - "path/filepath" "strings" "sync" - "syscall" - "time" + + "github.com/charmbracelet/crush/internal/logging" + "mvdan.cc/sh/v3/expand" + "mvdan.cc/sh/v3/interp" + "mvdan.cc/sh/v3/syntax" ) type PersistentShell struct { - cmd *exec.Cmd - stdin *os.File - isAlive bool - cwd string - mu sync.Mutex - commandQueue chan *commandExecution -} - -type commandExecution struct { - command string - timeout time.Duration - resultChan chan commandResult - ctx context.Context -} - -type commandResult struct { - stdout string - stderr string - exitCode int - interrupted bool - err error + env []string + cwd string + mu sync.Mutex } var ( - shellInstance *PersistentShell - shellInstanceOnce sync.Once + once sync.Once + shellInstance *PersistentShell ) -func GetPersistentShell(workingDir string) *PersistentShell { - shellInstanceOnce.Do(func() { - shellInstance = newPersistentShell(workingDir) +func GetPersistentShell(cwd string) *PersistentShell { + once.Do(func() { + shellInstance = newPersistentShell(cwd) }) - - if shellInstance == nil { - shellInstance = newPersistentShell(workingDir) - } else if !shellInstance.isAlive { - shellInstance = newPersistentShell(shellInstance.cwd) - } - return shellInstance } func newPersistentShell(cwd string) *PersistentShell { - // Default to environment variable - shellPath := os.Getenv("SHELL") - if shellPath == "" { - shellPath = "/bin/bash" - } - - // Default shell args - shellArgs := []string{"-l"} - - cmd := exec.Command(shellPath, shellArgs...) - cmd.Dir = cwd - - stdinPipe, err := cmd.StdinPipe() - if err != nil { - return nil - } - - cmd.Env = append(os.Environ(), "GIT_EDITOR=true") - - err = cmd.Start() - if err != nil { - return nil - } - - shell := &PersistentShell{ - cmd: cmd, - stdin: stdinPipe.(*os.File), - isAlive: true, - cwd: cwd, - commandQueue: make(chan *commandExecution, 10), + return &PersistentShell{ + cwd: cwd, + env: os.Environ(), } - - go func() { - defer func() { - if r := recover(); r != nil { - fmt.Fprintf(os.Stderr, "Panic in shell command processor: %v\n", r) - shell.isAlive = false - close(shell.commandQueue) - } - }() - shell.processCommands() - }() - - go func() { - err := cmd.Wait() - if err != nil { - // Log the error if needed - } - shell.isAlive = false - close(shell.commandQueue) - }() - - return shell } -func (s *PersistentShell) processCommands() { - for cmd := range s.commandQueue { - result := s.execCommand(cmd.command, cmd.timeout, cmd.ctx) - cmd.resultChan <- result - } -} - -func (s *PersistentShell) execCommand(command string, timeout time.Duration, ctx context.Context) commandResult { +func (s *PersistentShell) Exec(ctx context.Context, command string) (string, string, error) { s.mu.Lock() defer s.mu.Unlock() - if !s.isAlive { - return commandResult{ - stderr: "Shell is not alive", - exitCode: 1, - err: errors.New("shell is not alive"), - } - } - - tempDir := os.TempDir() - stdoutFile := filepath.Join(tempDir, fmt.Sprintf("crush-stdout-%d", time.Now().UnixNano())) - stderrFile := filepath.Join(tempDir, fmt.Sprintf("crush-stderr-%d", time.Now().UnixNano())) - statusFile := filepath.Join(tempDir, fmt.Sprintf("crush-status-%d", time.Now().UnixNano())) - cwdFile := filepath.Join(tempDir, fmt.Sprintf("crush-cwd-%d", time.Now().UnixNano())) - - defer func() { - os.Remove(stdoutFile) - os.Remove(stderrFile) - os.Remove(statusFile) - os.Remove(cwdFile) - }() - - fullCommand := fmt.Sprintf(` -eval %s < /dev/null > %s 2> %s -EXEC_EXIT_CODE=$? -pwd > %s -echo $EXEC_EXIT_CODE > %s -`, - shellQuote(command), - shellQuote(stdoutFile), - shellQuote(stderrFile), - shellQuote(cwdFile), - shellQuote(statusFile), - ) - - _, err := s.stdin.Write([]byte(fullCommand + "\n")) + line, err := syntax.NewParser().Parse(strings.NewReader(command), "") if err != nil { - return commandResult{ - stderr: fmt.Sprintf("Failed to write command to shell: %v", err), - exitCode: 1, - err: err, - } + return "", "", fmt.Errorf("could not parse command: %w", err) } - interrupted := false - - startTime := time.Now() - - done := make(chan bool) - go func() { - // Use exponential backoff polling - pollInterval := 1 * time.Millisecond - maxPollInterval := 100 * time.Millisecond - - ticker := time.NewTicker(pollInterval) - defer ticker.Stop() - - for { - select { - case <-ctx.Done(): - s.killChildren() - interrupted = true - done <- true - return - - case <-ticker.C: - if fileExists(statusFile) && fileSize(statusFile) > 0 { - done <- true - return - } - - if timeout > 0 { - elapsed := time.Since(startTime) - if elapsed > timeout { - s.killChildren() - interrupted = true - done <- true - return - } - } - - // Exponential backoff to reduce CPU usage for longer-running commands - if pollInterval < maxPollInterval { - pollInterval = min(time.Duration(float64(pollInterval)*1.5), maxPollInterval) - ticker.Reset(pollInterval) - } - } - } - }() - - <-done - - stdout := readFileOrEmpty(stdoutFile) - stderr := readFileOrEmpty(stderrFile) - exitCodeStr := readFileOrEmpty(statusFile) - newCwd := readFileOrEmpty(cwdFile) - - exitCode := 0 - if exitCodeStr != "" { - fmt.Sscanf(exitCodeStr, "%d", &exitCode) - } else if interrupted { - exitCode = 143 - stderr += "\nCommand execution timed out or was interrupted" - } - - if newCwd != "" { - s.cwd = strings.TrimSpace(newCwd) - } - - return commandResult{ - stdout: stdout, - stderr: stderr, - exitCode: exitCode, - interrupted: interrupted, - } -} - -func (s *PersistentShell) killChildren() { - if s.cmd == nil || s.cmd.Process == nil { - return - } - - pgrepCmd := exec.Command("pgrep", "-P", fmt.Sprintf("%d", s.cmd.Process.Pid)) - output, err := pgrepCmd.Output() + var stdout, stderr bytes.Buffer + runner, err := interp.New( + interp.StdIO(nil, &stdout, &stderr), + interp.Interactive(false), + interp.Env(expand.ListEnviron(s.env...)), + interp.Dir(s.cwd), + ) if err != nil { - return + return "", "", fmt.Errorf("could not run command: %w", err) } - for pidStr := range strings.SplitSeq(string(output), "\n") { - if pidStr = strings.TrimSpace(pidStr); pidStr != "" { - var pid int - fmt.Sscanf(pidStr, "%d", &pid) - if pid > 0 { - proc, err := os.FindProcess(pid) - if err == nil { - proc.Signal(syscall.SIGTERM) - } - } - } + err = runner.Run(ctx, line) + s.cwd = runner.Dir + s.env = []string{} + for name, vr := range runner.Vars { + s.env = append(s.env, fmt.Sprintf("%s=%s", name, vr.Str)) } + logging.InfoPersist("Command finished", "command", command, "err", err) + return stdout.String(), stderr.String(), err } -func (s *PersistentShell) Exec(ctx context.Context, command string, timeoutMs int) (string, string, int, bool, error) { - if !s.isAlive { - return "", "Shell is not alive", 1, false, errors.New("shell is not alive") - } - - timeout := time.Duration(timeoutMs) * time.Millisecond - - resultChan := make(chan commandResult) - s.commandQueue <- &commandExecution{ - command: command, - timeout: timeout, - resultChan: resultChan, - ctx: ctx, - } - - result := <-resultChan - return result.stdout, result.stderr, result.exitCode, result.interrupted, result.err +func IsInterrupt(err error) bool { + return errors.Is(err, context.Canceled) || + errors.Is(err, context.DeadlineExceeded) } -func (s *PersistentShell) Close() { - s.mu.Lock() - defer s.mu.Unlock() - - if !s.isAlive { - return +func ExitCode(err error) int { + if err == nil { + return 0 } - - s.stdin.Write([]byte("exit\n")) - - s.cmd.Process.Kill() - s.isAlive = false -} - -func shellQuote(s string) string { - return "'" + strings.ReplaceAll(s, "'", "'\\''") + "'" -} - -func readFileOrEmpty(path string) string { - content, err := os.ReadFile(path) - if err != nil { - return "" + status, ok := interp.IsExitStatus(err) + if ok { + return int(status) } - return string(content) -} - -func fileExists(path string) bool { - _, err := os.Stat(path) - return err == nil + return 1 } - -func fileSize(path string) int64 { - info, err := os.Stat(path) - if err != nil { - return 0 - } - return info.Size() -} \ No newline at end of file diff --git a/internal/llm/tools/shell/shell_test.go b/internal/llm/tools/shell/shell_test.go index 329e6e3fa6ba65b8a5784e6c699bd0053f171888..630995aaa756295ff1eb14b05d11a2ee8f634733 100644 --- a/internal/llm/tools/shell/shell_test.go +++ b/internal/llm/tools/shell/shell_test.go @@ -2,28 +2,81 @@ package shell import ( "context" - "os" "testing" - - "github.com/stretchr/testify/require" + "time" ) // Benchmark to measure CPU efficiency func BenchmarkShellQuickCommands(b *testing.B) { - tmpDir, err := os.MkdirTemp("", "shell-bench") - require.NoError(b, err) - defer os.RemoveAll(tmpDir) - - shell := GetPersistentShell(tmpDir) - defer shell.Close() + shell := newPersistentShell(b.TempDir()) - b.ResetTimer() b.ReportAllocs() for b.Loop() { - _, _, exitCode, _, err := shell.Exec(context.Background(), "echo test", 0) + _, _, err := shell.Exec(context.Background(), "echo test") + exitCode := ExitCode(err) if err != nil || exitCode != 0 { b.Fatalf("Command failed: %v, exit code: %d", err, exitCode) } } -} \ No newline at end of file +} + +func TestTestTimeout(t *testing.T) { + ctx, cancel := context.WithTimeout(t.Context(), time.Millisecond) + t.Cleanup(cancel) + + shell := newPersistentShell(t.TempDir()) + _, _, err := shell.Exec(ctx, "sleep 10") + if status := ExitCode(err); status == 0 { + t.Fatalf("Expected non-zero exit status, got %d", status) + } + if !IsInterrupt(err) { + t.Fatalf("Expected command to be interrupted, but it was not") + } + if err == nil { + t.Fatalf("Expected an error due to timeout, but got none") + } +} + +func TestTestCancel(t *testing.T) { + ctx, cancel := context.WithCancel(t.Context()) + cancel() // immediately cancel the context + + shell := newPersistentShell(t.TempDir()) + _, _, err := shell.Exec(ctx, "sleep 10") + if status := ExitCode(err); status == 0 { + t.Fatalf("Expected non-zero exit status, got %d", status) + } + if !IsInterrupt(err) { + t.Fatalf("Expected command to be interrupted, but it was not") + } + if err == nil { + t.Fatalf("Expected an error due to cancel, but got none") + } +} + +func TestRunCommandError(t *testing.T) { + shell := newPersistentShell(t.TempDir()) + _, _, err := shell.Exec(t.Context(), "nopenopenope") + if status := ExitCode(err); status == 0 { + t.Fatalf("Expected non-zero exit status, got %d", status) + } + if IsInterrupt(err) { + t.Fatalf("Expected command to not be interrupted, but it was") + } + if err == nil { + t.Fatalf("Expected an error, got nil") + } +} + +func TestRunContinuity(t *testing.T) { + shell := newPersistentShell(t.TempDir()) + shell.Exec(t.Context(), "export FOO=bar") + dst := t.TempDir() + shell.Exec(t.Context(), "cd "+dst) + out, _, _ := shell.Exec(t.Context(), "echo $FOO ; pwd") + expect := "bar\n" + dst + "\n" + if out != expect { + t.Fatalf("Expected output %q, got %q", expect, out) + } +}