revert changes on shell

Raphael Amorim created

Change summary

internal/llm/tools/shell/comparison_test.go |   6 
internal/llm/tools/shell/shell.go           | 326 +++-------------------
internal/llm/tools/shell/shell_test.go      |  77 ++++
3 files changed, 114 insertions(+), 295 deletions(-)

Detailed changes

internal/llm/tools/shell/comparison_test.go 🔗

@@ -13,7 +13,8 @@ func TestShellPerformanceComparison(t *testing.T) {
 
 	// Test quick command
 	start := time.Now()
-	stdout, stderr, exitCode, _, err := shell.Exec(t.Context(), "echo 'hello'", 0)
+	stdout, stderr, err := shell.Exec(t.Context(), "echo 'hello'")
+	exitCode := ExitCode(err)
 	duration := time.Since(start)
 
 	require.NoError(t, err)
@@ -32,7 +33,8 @@ func BenchmarkShellPolling(b *testing.B) {
 
 	for b.Loop() {
 		// Use a short sleep to measure polling overhead
-		_, _, exitCode, _, err := shell.Exec(b.Context(), "sleep 0.02", 0)
+		_, _, err := shell.Exec(b.Context(), "sleep 0.02")
+		exitCode := ExitCode(err)
 		if err != nil || exitCode != 0 {
 			b.Fatalf("Command failed: %v, exit code: %d", err, exitCode)
 		}

internal/llm/tools/shell/shell.go 🔗

@@ -1,323 +1,87 @@
 package shell
 
 import (
+	"bytes"
 	"context"
 	"errors"
 	"fmt"
 	"os"
-	"os/exec"
-	"path/filepath"
 	"strings"
 	"sync"
-	"syscall"
-	"time"
+
+	"github.com/charmbracelet/crush/internal/logging"
+	"mvdan.cc/sh/v3/expand"
+	"mvdan.cc/sh/v3/interp"
+	"mvdan.cc/sh/v3/syntax"
 )
 
 type PersistentShell struct {
-	cmd          *exec.Cmd
-	stdin        *os.File
-	isAlive      bool
-	cwd          string
-	mu           sync.Mutex
-	commandQueue chan *commandExecution
-}
-
-type commandExecution struct {
-	command    string
-	timeout    time.Duration
-	resultChan chan commandResult
-	ctx        context.Context
-}
-
-type commandResult struct {
-	stdout      string
-	stderr      string
-	exitCode    int
-	interrupted bool
-	err         error
+	env []string
+	cwd string
+	mu  sync.Mutex
 }
 
 var (
-	shellInstance     *PersistentShell
-	shellInstanceOnce sync.Once
+	once          sync.Once
+	shellInstance *PersistentShell
 )
 
-func GetPersistentShell(workingDir string) *PersistentShell {
-	shellInstanceOnce.Do(func() {
-		shellInstance = newPersistentShell(workingDir)
+func GetPersistentShell(cwd string) *PersistentShell {
+	once.Do(func() {
+		shellInstance = newPersistentShell(cwd)
 	})
-
-	if shellInstance == nil {
-		shellInstance = newPersistentShell(workingDir)
-	} else if !shellInstance.isAlive {
-		shellInstance = newPersistentShell(shellInstance.cwd)
-	}
-
 	return shellInstance
 }
 
 func newPersistentShell(cwd string) *PersistentShell {
-	// Default to environment variable
-	shellPath := os.Getenv("SHELL")
-	if shellPath == "" {
-		shellPath = "/bin/bash"
-	}
-
-	// Default shell args
-	shellArgs := []string{"-l"}
-
-	cmd := exec.Command(shellPath, shellArgs...)
-	cmd.Dir = cwd
-
-	stdinPipe, err := cmd.StdinPipe()
-	if err != nil {
-		return nil
-	}
-
-	cmd.Env = append(os.Environ(), "GIT_EDITOR=true")
-
-	err = cmd.Start()
-	if err != nil {
-		return nil
-	}
-
-	shell := &PersistentShell{
-		cmd:          cmd,
-		stdin:        stdinPipe.(*os.File),
-		isAlive:      true,
-		cwd:          cwd,
-		commandQueue: make(chan *commandExecution, 10),
+	return &PersistentShell{
+		cwd: cwd,
+		env: os.Environ(),
 	}
-
-	go func() {
-		defer func() {
-			if r := recover(); r != nil {
-				fmt.Fprintf(os.Stderr, "Panic in shell command processor: %v\n", r)
-				shell.isAlive = false
-				close(shell.commandQueue)
-			}
-		}()
-		shell.processCommands()
-	}()
-
-	go func() {
-		err := cmd.Wait()
-		if err != nil {
-			// Log the error if needed
-		}
-		shell.isAlive = false
-		close(shell.commandQueue)
-	}()
-
-	return shell
 }
 
-func (s *PersistentShell) processCommands() {
-	for cmd := range s.commandQueue {
-		result := s.execCommand(cmd.command, cmd.timeout, cmd.ctx)
-		cmd.resultChan <- result
-	}
-}
-
-func (s *PersistentShell) execCommand(command string, timeout time.Duration, ctx context.Context) commandResult {
+func (s *PersistentShell) Exec(ctx context.Context, command string) (string, string, error) {
 	s.mu.Lock()
 	defer s.mu.Unlock()
 
-	if !s.isAlive {
-		return commandResult{
-			stderr:   "Shell is not alive",
-			exitCode: 1,
-			err:      errors.New("shell is not alive"),
-		}
-	}
-
-	tempDir := os.TempDir()
-	stdoutFile := filepath.Join(tempDir, fmt.Sprintf("crush-stdout-%d", time.Now().UnixNano()))
-	stderrFile := filepath.Join(tempDir, fmt.Sprintf("crush-stderr-%d", time.Now().UnixNano()))
-	statusFile := filepath.Join(tempDir, fmt.Sprintf("crush-status-%d", time.Now().UnixNano()))
-	cwdFile := filepath.Join(tempDir, fmt.Sprintf("crush-cwd-%d", time.Now().UnixNano()))
-
-	defer func() {
-		os.Remove(stdoutFile)
-		os.Remove(stderrFile)
-		os.Remove(statusFile)
-		os.Remove(cwdFile)
-	}()
-
-	fullCommand := fmt.Sprintf(`
-eval %s < /dev/null > %s 2> %s
-EXEC_EXIT_CODE=$?
-pwd > %s
-echo $EXEC_EXIT_CODE > %s
-`,
-		shellQuote(command),
-		shellQuote(stdoutFile),
-		shellQuote(stderrFile),
-		shellQuote(cwdFile),
-		shellQuote(statusFile),
-	)
-
-	_, err := s.stdin.Write([]byte(fullCommand + "\n"))
+	line, err := syntax.NewParser().Parse(strings.NewReader(command), "")
 	if err != nil {
-		return commandResult{
-			stderr:   fmt.Sprintf("Failed to write command to shell: %v", err),
-			exitCode: 1,
-			err:      err,
-		}
+		return "", "", fmt.Errorf("could not parse command: %w", err)
 	}
 
-	interrupted := false
-
-	startTime := time.Now()
-
-	done := make(chan bool)
-	go func() {
-		// Use exponential backoff polling
-		pollInterval := 1 * time.Millisecond
-		maxPollInterval := 100 * time.Millisecond
-
-		ticker := time.NewTicker(pollInterval)
-		defer ticker.Stop()
-
-		for {
-			select {
-			case <-ctx.Done():
-				s.killChildren()
-				interrupted = true
-				done <- true
-				return
-
-			case <-ticker.C:
-				if fileExists(statusFile) && fileSize(statusFile) > 0 {
-					done <- true
-					return
-				}
-
-				if timeout > 0 {
-					elapsed := time.Since(startTime)
-					if elapsed > timeout {
-						s.killChildren()
-						interrupted = true
-						done <- true
-						return
-					}
-				}
-
-				// Exponential backoff to reduce CPU usage for longer-running commands
-				if pollInterval < maxPollInterval {
-					pollInterval = min(time.Duration(float64(pollInterval)*1.5), maxPollInterval)
-					ticker.Reset(pollInterval)
-				}
-			}
-		}
-	}()
-
-	<-done
-
-	stdout := readFileOrEmpty(stdoutFile)
-	stderr := readFileOrEmpty(stderrFile)
-	exitCodeStr := readFileOrEmpty(statusFile)
-	newCwd := readFileOrEmpty(cwdFile)
-
-	exitCode := 0
-	if exitCodeStr != "" {
-		fmt.Sscanf(exitCodeStr, "%d", &exitCode)
-	} else if interrupted {
-		exitCode = 143
-		stderr += "\nCommand execution timed out or was interrupted"
-	}
-
-	if newCwd != "" {
-		s.cwd = strings.TrimSpace(newCwd)
-	}
-
-	return commandResult{
-		stdout:      stdout,
-		stderr:      stderr,
-		exitCode:    exitCode,
-		interrupted: interrupted,
-	}
-}
-
-func (s *PersistentShell) killChildren() {
-	if s.cmd == nil || s.cmd.Process == nil {
-		return
-	}
-
-	pgrepCmd := exec.Command("pgrep", "-P", fmt.Sprintf("%d", s.cmd.Process.Pid))
-	output, err := pgrepCmd.Output()
+	var stdout, stderr bytes.Buffer
+	runner, err := interp.New(
+		interp.StdIO(nil, &stdout, &stderr),
+		interp.Interactive(false),
+		interp.Env(expand.ListEnviron(s.env...)),
+		interp.Dir(s.cwd),
+	)
 	if err != nil {
-		return
+		return "", "", fmt.Errorf("could not run command: %w", err)
 	}
 
-	for pidStr := range strings.SplitSeq(string(output), "\n") {
-		if pidStr = strings.TrimSpace(pidStr); pidStr != "" {
-			var pid int
-			fmt.Sscanf(pidStr, "%d", &pid)
-			if pid > 0 {
-				proc, err := os.FindProcess(pid)
-				if err == nil {
-					proc.Signal(syscall.SIGTERM)
-				}
-			}
-		}
+	err = runner.Run(ctx, line)
+	s.cwd = runner.Dir
+	s.env = []string{}
+	for name, vr := range runner.Vars {
+		s.env = append(s.env, fmt.Sprintf("%s=%s", name, vr.Str))
 	}
+	logging.InfoPersist("Command finished", "command", command, "err", err)
+	return stdout.String(), stderr.String(), err
 }
 
-func (s *PersistentShell) Exec(ctx context.Context, command string, timeoutMs int) (string, string, int, bool, error) {
-	if !s.isAlive {
-		return "", "Shell is not alive", 1, false, errors.New("shell is not alive")
-	}
-
-	timeout := time.Duration(timeoutMs) * time.Millisecond
-
-	resultChan := make(chan commandResult)
-	s.commandQueue <- &commandExecution{
-		command:    command,
-		timeout:    timeout,
-		resultChan: resultChan,
-		ctx:        ctx,
-	}
-
-	result := <-resultChan
-	return result.stdout, result.stderr, result.exitCode, result.interrupted, result.err
+func IsInterrupt(err error) bool {
+	return errors.Is(err, context.Canceled) ||
+		errors.Is(err, context.DeadlineExceeded)
 }
 
-func (s *PersistentShell) Close() {
-	s.mu.Lock()
-	defer s.mu.Unlock()
-
-	if !s.isAlive {
-		return
+func ExitCode(err error) int {
+	if err == nil {
+		return 0
 	}
-
-	s.stdin.Write([]byte("exit\n"))
-
-	s.cmd.Process.Kill()
-	s.isAlive = false
-}
-
-func shellQuote(s string) string {
-	return "'" + strings.ReplaceAll(s, "'", "'\\''") + "'"
-}
-
-func readFileOrEmpty(path string) string {
-	content, err := os.ReadFile(path)
-	if err != nil {
-		return ""
+	status, ok := interp.IsExitStatus(err)
+	if ok {
+		return int(status)
 	}
-	return string(content)
-}
-
-func fileExists(path string) bool {
-	_, err := os.Stat(path)
-	return err == nil
+	return 1
 }
-
-func fileSize(path string) int64 {
-	info, err := os.Stat(path)
-	if err != nil {
-		return 0
-	}
-	return info.Size()
-}

internal/llm/tools/shell/shell_test.go 🔗

@@ -2,28 +2,81 @@ package shell
 
 import (
 	"context"
-	"os"
 	"testing"
-
-	"github.com/stretchr/testify/require"
+	"time"
 )
 
 // Benchmark to measure CPU efficiency
 func BenchmarkShellQuickCommands(b *testing.B) {
-	tmpDir, err := os.MkdirTemp("", "shell-bench")
-	require.NoError(b, err)
-	defer os.RemoveAll(tmpDir)
-
-	shell := GetPersistentShell(tmpDir)
-	defer shell.Close()
+	shell := newPersistentShell(b.TempDir())
 
-	b.ResetTimer()
 	b.ReportAllocs()
 
 	for b.Loop() {
-		_, _, exitCode, _, err := shell.Exec(context.Background(), "echo test", 0)
+		_, _, err := shell.Exec(context.Background(), "echo test")
+		exitCode := ExitCode(err)
 		if err != nil || exitCode != 0 {
 			b.Fatalf("Command failed: %v, exit code: %d", err, exitCode)
 		}
 	}
-}
+}
+
+func TestTestTimeout(t *testing.T) {
+	ctx, cancel := context.WithTimeout(t.Context(), time.Millisecond)
+	t.Cleanup(cancel)
+
+	shell := newPersistentShell(t.TempDir())
+	_, _, err := shell.Exec(ctx, "sleep 10")
+	if status := ExitCode(err); status == 0 {
+		t.Fatalf("Expected non-zero exit status, got %d", status)
+	}
+	if !IsInterrupt(err) {
+		t.Fatalf("Expected command to be interrupted, but it was not")
+	}
+	if err == nil {
+		t.Fatalf("Expected an error due to timeout, but got none")
+	}
+}
+
+func TestTestCancel(t *testing.T) {
+	ctx, cancel := context.WithCancel(t.Context())
+	cancel() // immediately cancel the context
+
+	shell := newPersistentShell(t.TempDir())
+	_, _, err := shell.Exec(ctx, "sleep 10")
+	if status := ExitCode(err); status == 0 {
+		t.Fatalf("Expected non-zero exit status, got %d", status)
+	}
+	if !IsInterrupt(err) {
+		t.Fatalf("Expected command to be interrupted, but it was not")
+	}
+	if err == nil {
+		t.Fatalf("Expected an error due to cancel, but got none")
+	}
+}
+
+func TestRunCommandError(t *testing.T) {
+	shell := newPersistentShell(t.TempDir())
+	_, _, err := shell.Exec(t.Context(), "nopenopenope")
+	if status := ExitCode(err); status == 0 {
+		t.Fatalf("Expected non-zero exit status, got %d", status)
+	}
+	if IsInterrupt(err) {
+		t.Fatalf("Expected command to not be interrupted, but it was")
+	}
+	if err == nil {
+		t.Fatalf("Expected an error, got nil")
+	}
+}
+
+func TestRunContinuity(t *testing.T) {
+	shell := newPersistentShell(t.TempDir())
+	shell.Exec(t.Context(), "export FOO=bar")
+	dst := t.TempDir()
+	shell.Exec(t.Context(), "cd "+dst)
+	out, _, _ := shell.Exec(t.Context(), "echo $FOO ; pwd")
+	expect := "bar\n" + dst + "\n"
+	if out != expect {
+		t.Fatalf("Expected output %q, got %q", expect, out)
+	}
+}