@@ -13,8 +13,7 @@ func TestShellPerformanceComparison(t *testing.T) {
// Test quick command
start := time.Now()
- stdout, stderr, err := shell.Exec(t.Context(), "echo 'hello'")
- exitCode := ExitCode(err)
+ stdout, stderr, exitCode, _, err := shell.Exec(t.Context(), "echo 'hello'", 0)
duration := time.Since(start)
require.NoError(t, err)
@@ -33,8 +32,7 @@ func BenchmarkShellPolling(b *testing.B) {
for b.Loop() {
// Use a short sleep to measure polling overhead
- _, _, err := shell.Exec(b.Context(), "sleep 0.02")
- exitCode := ExitCode(err)
+ _, _, exitCode, _, err := shell.Exec(b.Context(), "sleep 0.02", 0)
if err != nil || exitCode != 0 {
b.Fatalf("Command failed: %v, exit code: %d", err, exitCode)
}
@@ -1,87 +1,323 @@
package shell
import (
- "bytes"
"context"
"errors"
"fmt"
"os"
+ "os/exec"
+ "path/filepath"
"strings"
"sync"
-
- "github.com/charmbracelet/crush/internal/logging"
- "mvdan.cc/sh/v3/expand"
- "mvdan.cc/sh/v3/interp"
- "mvdan.cc/sh/v3/syntax"
+ "syscall"
+ "time"
)
type PersistentShell struct {
- env []string
- cwd string
- mu sync.Mutex
+ cmd *exec.Cmd
+ stdin *os.File
+ isAlive bool
+ cwd string
+ mu sync.Mutex
+ commandQueue chan *commandExecution
+}
+
+type commandExecution struct {
+ command string
+ timeout time.Duration
+ resultChan chan commandResult
+ ctx context.Context
+}
+
+type commandResult struct {
+ stdout string
+ stderr string
+ exitCode int
+ interrupted bool
+ err error
}
var (
- once sync.Once
- shellInstance *PersistentShell
+ shellInstance *PersistentShell
+ shellInstanceOnce sync.Once
)
-func GetPersistentShell(cwd string) *PersistentShell {
- once.Do(func() {
- shellInstance = newPersistentShell(cwd)
+func GetPersistentShell(workingDir string) *PersistentShell {
+ shellInstanceOnce.Do(func() {
+ shellInstance = newPersistentShell(workingDir)
})
+
+ if shellInstance == nil {
+ shellInstance = newPersistentShell(workingDir)
+ } else if !shellInstance.isAlive {
+ shellInstance = newPersistentShell(shellInstance.cwd)
+ }
+
return shellInstance
}
func newPersistentShell(cwd string) *PersistentShell {
- return &PersistentShell{
- cwd: cwd,
- env: os.Environ(),
+ // Default to environment variable
+ shellPath := os.Getenv("SHELL")
+ if shellPath == "" {
+ shellPath = "/bin/bash"
+ }
+
+ // Default shell args
+ shellArgs := []string{"-l"}
+
+ cmd := exec.Command(shellPath, shellArgs...)
+ cmd.Dir = cwd
+
+ stdinPipe, err := cmd.StdinPipe()
+ if err != nil {
+ return nil
+ }
+
+ cmd.Env = append(os.Environ(), "GIT_EDITOR=true")
+
+ err = cmd.Start()
+ if err != nil {
+ return nil
+ }
+
+ shell := &PersistentShell{
+ cmd: cmd,
+ stdin: stdinPipe.(*os.File),
+ isAlive: true,
+ cwd: cwd,
+ commandQueue: make(chan *commandExecution, 10),
}
+
+ go func() {
+ defer func() {
+ if r := recover(); r != nil {
+ fmt.Fprintf(os.Stderr, "Panic in shell command processor: %v\n", r)
+ shell.isAlive = false
+ close(shell.commandQueue)
+ }
+ }()
+ shell.processCommands()
+ }()
+
+ go func() {
+ err := cmd.Wait()
+ if err != nil {
+ // Log the error if needed
+ }
+ shell.isAlive = false
+ close(shell.commandQueue)
+ }()
+
+ return shell
}
-func (s *PersistentShell) Exec(ctx context.Context, command string) (string, string, error) {
+func (s *PersistentShell) processCommands() {
+ for cmd := range s.commandQueue {
+ result := s.execCommand(cmd.command, cmd.timeout, cmd.ctx)
+ cmd.resultChan <- result
+ }
+}
+
+func (s *PersistentShell) execCommand(command string, timeout time.Duration, ctx context.Context) commandResult {
s.mu.Lock()
defer s.mu.Unlock()
- line, err := syntax.NewParser().Parse(strings.NewReader(command), "")
- if err != nil {
- return "", "", fmt.Errorf("could not parse command: %w", err)
+ if !s.isAlive {
+ return commandResult{
+ stderr: "Shell is not alive",
+ exitCode: 1,
+ err: errors.New("shell is not alive"),
+ }
}
- var stdout, stderr bytes.Buffer
- runner, err := interp.New(
- interp.StdIO(nil, &stdout, &stderr),
- interp.Interactive(false),
- interp.Env(expand.ListEnviron(s.env...)),
- interp.Dir(s.cwd),
+ tempDir := os.TempDir()
+ stdoutFile := filepath.Join(tempDir, fmt.Sprintf("crush-stdout-%d", time.Now().UnixNano()))
+ stderrFile := filepath.Join(tempDir, fmt.Sprintf("crush-stderr-%d", time.Now().UnixNano()))
+ statusFile := filepath.Join(tempDir, fmt.Sprintf("crush-status-%d", time.Now().UnixNano()))
+ cwdFile := filepath.Join(tempDir, fmt.Sprintf("crush-cwd-%d", time.Now().UnixNano()))
+
+ defer func() {
+ os.Remove(stdoutFile)
+ os.Remove(stderrFile)
+ os.Remove(statusFile)
+ os.Remove(cwdFile)
+ }()
+
+ fullCommand := fmt.Sprintf(`
+eval %s < /dev/null > %s 2> %s
+EXEC_EXIT_CODE=$?
+pwd > %s
+echo $EXEC_EXIT_CODE > %s
+`,
+ shellQuote(command),
+ shellQuote(stdoutFile),
+ shellQuote(stderrFile),
+ shellQuote(cwdFile),
+ shellQuote(statusFile),
)
+
+ _, err := s.stdin.Write([]byte(fullCommand + "\n"))
if err != nil {
- return "", "", fmt.Errorf("could not run command: %w", err)
+ return commandResult{
+ stderr: fmt.Sprintf("Failed to write command to shell: %v", err),
+ exitCode: 1,
+ err: err,
+ }
}
- err = runner.Run(ctx, line)
- s.cwd = runner.Dir
- s.env = []string{}
- for name, vr := range runner.Vars {
- s.env = append(s.env, fmt.Sprintf("%s=%s", name, vr.Str))
+ interrupted := false
+
+ startTime := time.Now()
+
+ done := make(chan bool)
+ go func() {
+ // Use exponential backoff polling
+ pollInterval := 1 * time.Millisecond
+ maxPollInterval := 100 * time.Millisecond
+
+ ticker := time.NewTicker(pollInterval)
+ defer ticker.Stop()
+
+ for {
+ select {
+ case <-ctx.Done():
+ s.killChildren()
+ interrupted = true
+ done <- true
+ return
+
+ case <-ticker.C:
+ if fileExists(statusFile) && fileSize(statusFile) > 0 {
+ done <- true
+ return
+ }
+
+ if timeout > 0 {
+ elapsed := time.Since(startTime)
+ if elapsed > timeout {
+ s.killChildren()
+ interrupted = true
+ done <- true
+ return
+ }
+ }
+
+ // Exponential backoff to reduce CPU usage for longer-running commands
+ if pollInterval < maxPollInterval {
+ pollInterval = min(time.Duration(float64(pollInterval)*1.5), maxPollInterval)
+ ticker.Reset(pollInterval)
+ }
+ }
+ }
+ }()
+
+ <-done
+
+ stdout := readFileOrEmpty(stdoutFile)
+ stderr := readFileOrEmpty(stderrFile)
+ exitCodeStr := readFileOrEmpty(statusFile)
+ newCwd := readFileOrEmpty(cwdFile)
+
+ exitCode := 0
+ if exitCodeStr != "" {
+ fmt.Sscanf(exitCodeStr, "%d", &exitCode)
+ } else if interrupted {
+ exitCode = 143
+ stderr += "\nCommand execution timed out or was interrupted"
+ }
+
+ if newCwd != "" {
+ s.cwd = strings.TrimSpace(newCwd)
+ }
+
+ return commandResult{
+ stdout: stdout,
+ stderr: stderr,
+ exitCode: exitCode,
+ interrupted: interrupted,
}
- logging.InfoPersist("Command finished", "command", command, "err", err)
- return stdout.String(), stderr.String(), err
}
-func IsInterrupt(err error) bool {
- return errors.Is(err, context.Canceled) ||
- errors.Is(err, context.DeadlineExceeded)
+func (s *PersistentShell) killChildren() {
+ if s.cmd == nil || s.cmd.Process == nil {
+ return
+ }
+
+ pgrepCmd := exec.Command("pgrep", "-P", fmt.Sprintf("%d", s.cmd.Process.Pid))
+ output, err := pgrepCmd.Output()
+ if err != nil {
+ return
+ }
+
+ for pidStr := range strings.SplitSeq(string(output), "\n") {
+ if pidStr = strings.TrimSpace(pidStr); pidStr != "" {
+ var pid int
+ fmt.Sscanf(pidStr, "%d", &pid)
+ if pid > 0 {
+ proc, err := os.FindProcess(pid)
+ if err == nil {
+ proc.Signal(syscall.SIGTERM)
+ }
+ }
+ }
+ }
}
-func ExitCode(err error) int {
- if err == nil {
- return 0
+func (s *PersistentShell) Exec(ctx context.Context, command string, timeoutMs int) (string, string, int, bool, error) {
+ if !s.isAlive {
+ return "", "Shell is not alive", 1, false, errors.New("shell is not alive")
+ }
+
+ timeout := time.Duration(timeoutMs) * time.Millisecond
+
+ resultChan := make(chan commandResult)
+ s.commandQueue <- &commandExecution{
+ command: command,
+ timeout: timeout,
+ resultChan: resultChan,
+ ctx: ctx,
+ }
+
+ result := <-resultChan
+ return result.stdout, result.stderr, result.exitCode, result.interrupted, result.err
+}
+
+func (s *PersistentShell) Close() {
+ s.mu.Lock()
+ defer s.mu.Unlock()
+
+ if !s.isAlive {
+ return
}
- status, ok := interp.IsExitStatus(err)
- if ok {
- return int(status)
+
+ s.stdin.Write([]byte("exit\n"))
+
+ s.cmd.Process.Kill()
+ s.isAlive = false
+}
+
+func shellQuote(s string) string {
+ return "'" + strings.ReplaceAll(s, "'", "'\\''") + "'"
+}
+
+func readFileOrEmpty(path string) string {
+ content, err := os.ReadFile(path)
+ if err != nil {
+ return ""
}
- return 1
+ return string(content)
+}
+
+func fileExists(path string) bool {
+ _, err := os.Stat(path)
+ return err == nil
}
+
+func fileSize(path string) int64 {
+ info, err := os.Stat(path)
+ if err != nil {
+ return 0
+ }
+ return info.Size()
+}
@@ -2,81 +2,28 @@ package shell
import (
"context"
+ "os"
"testing"
- "time"
+
+ "github.com/stretchr/testify/require"
)
// Benchmark to measure CPU efficiency
func BenchmarkShellQuickCommands(b *testing.B) {
- shell := newPersistentShell(b.TempDir())
+ tmpDir, err := os.MkdirTemp("", "shell-bench")
+ require.NoError(b, err)
+ defer os.RemoveAll(tmpDir)
+
+ shell := GetPersistentShell(tmpDir)
+ defer shell.Close()
+ b.ResetTimer()
b.ReportAllocs()
for b.Loop() {
- _, _, err := shell.Exec(context.Background(), "echo test")
- exitCode := ExitCode(err)
+ _, _, exitCode, _, err := shell.Exec(context.Background(), "echo test", 0)
if err != nil || exitCode != 0 {
b.Fatalf("Command failed: %v, exit code: %d", err, exitCode)
}
}
-}
-
-func TestTestTimeout(t *testing.T) {
- ctx, cancel := context.WithTimeout(t.Context(), time.Millisecond)
- t.Cleanup(cancel)
-
- shell := newPersistentShell(t.TempDir())
- _, _, err := shell.Exec(ctx, "sleep 10")
- if status := ExitCode(err); status == 0 {
- t.Fatalf("Expected non-zero exit status, got %d", status)
- }
- if !IsInterrupt(err) {
- t.Fatalf("Expected command to be interrupted, but it was not")
- }
- if err == nil {
- t.Fatalf("Expected an error due to timeout, but got none")
- }
-}
-
-func TestTestCancel(t *testing.T) {
- ctx, cancel := context.WithCancel(t.Context())
- cancel() // immediately cancel the context
-
- shell := newPersistentShell(t.TempDir())
- _, _, err := shell.Exec(ctx, "sleep 10")
- if status := ExitCode(err); status == 0 {
- t.Fatalf("Expected non-zero exit status, got %d", status)
- }
- if !IsInterrupt(err) {
- t.Fatalf("Expected command to be interrupted, but it was not")
- }
- if err == nil {
- t.Fatalf("Expected an error due to cancel, but got none")
- }
-}
-
-func TestRunCommandError(t *testing.T) {
- shell := newPersistentShell(t.TempDir())
- _, _, err := shell.Exec(t.Context(), "nopenopenope")
- if status := ExitCode(err); status == 0 {
- t.Fatalf("Expected non-zero exit status, got %d", status)
- }
- if IsInterrupt(err) {
- t.Fatalf("Expected command to not be interrupted, but it was")
- }
- if err == nil {
- t.Fatalf("Expected an error, got nil")
- }
-}
-
-func TestRunContinuity(t *testing.T) {
- shell := newPersistentShell(t.TempDir())
- shell.Exec(t.Context(), "export FOO=bar")
- dst := t.TempDir()
- shell.Exec(t.Context(), "cd "+dst)
- out, _, _ := shell.Exec(t.Context(), "echo $FOO ; pwd")
- expect := "bar\n" + dst + "\n"
- if out != expect {
- t.Fatalf("Expected output %q, got %q", expect, out)
- }
-}
+}
@@ -1,10 +1,12 @@
package permission
import (
+ "context"
"errors"
"path/filepath"
"slices"
"sync"
+ "time"
"github.com/charmbracelet/crush/internal/config"
"github.com/charmbracelet/crush/internal/pubsub"
@@ -44,9 +46,11 @@ type Service interface {
type permissionService struct {
*pubsub.Broker[PermissionRequest]
- sessionPermissions []PermissionRequest
- pendingRequests sync.Map
- autoApproveSessions []string
+ sessionPermissions []PermissionRequest
+ sessionPermissionsMu sync.RWMutex
+ pendingRequests sync.Map
+ autoApproveSessions []string
+ autoApproveSessionsMu sync.RWMutex
}
func (s *permissionService) GrantPersistent(permission PermissionRequest) {
@@ -54,7 +58,10 @@ func (s *permissionService) GrantPersistent(permission PermissionRequest) {
if ok {
respCh.(chan bool) <- true
}
+
+ s.sessionPermissionsMu.Lock()
s.sessionPermissions = append(s.sessionPermissions, permission)
+ s.sessionPermissionsMu.Unlock()
}
func (s *permissionService) Grant(permission PermissionRequest) {
@@ -72,9 +79,14 @@ func (s *permissionService) Deny(permission PermissionRequest) {
}
func (s *permissionService) Request(opts CreatePermissionRequest) bool {
- if slices.Contains(s.autoApproveSessions, opts.SessionID) {
+ s.autoApproveSessionsMu.RLock()
+ autoApprove := slices.Contains(s.autoApproveSessions, opts.SessionID)
+ s.autoApproveSessionsMu.RUnlock()
+
+ if autoApprove {
return true
}
+
dir := filepath.Dir(opts.Path)
if dir == "." {
dir = config.WorkingDirectory()
@@ -89,11 +101,14 @@ func (s *permissionService) Request(opts CreatePermissionRequest) bool {
Params: opts.Params,
}
+ s.sessionPermissionsMu.RLock()
for _, p := range s.sessionPermissions {
if p.ToolName == permission.ToolName && p.Action == permission.Action && p.SessionID == permission.SessionID && p.Path == permission.Path {
+ s.sessionPermissionsMu.RUnlock()
return true
}
}
+ s.sessionPermissionsMu.RUnlock()
respCh := make(chan bool, 1)
@@ -102,13 +117,22 @@ func (s *permissionService) Request(opts CreatePermissionRequest) bool {
s.Publish(pubsub.CreatedEvent, permission)
- // Wait for the response with a timeout
- resp := <-respCh
- return resp
+ // Wait for the response with a timeout to prevent indefinite blocking
+ ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
+ defer cancel()
+
+ select {
+ case resp := <-respCh:
+ return resp
+ case <-ctx.Done():
+ return false // Timeout - deny by default
+ }
}
func (s *permissionService) AutoApproveSession(sessionID string) {
+ s.autoApproveSessionsMu.Lock()
s.autoApproveSessions = append(s.autoApproveSessions, sessionID)
+ s.autoApproveSessionsMu.Unlock()
}
func NewPermissionService() Service {