From 2937e194c746b5518bfe15746da8810b0864d1c4 Mon Sep 17 00:00:00 2001 From: Raphael Amorim Date: Tue, 17 Jun 2025 17:33:22 +0200 Subject: [PATCH] perf: concurrency improvements --- internal/llm/tools/shell/comparison_test.go | 6 +- internal/llm/tools/shell/shell.go | 326 +++++++++++++++++--- internal/llm/tools/shell/shell_test.go | 77 +---- internal/permission/permission.go | 38 ++- internal/pubsub/broker.go | 2 + 5 files changed, 328 insertions(+), 121 deletions(-) diff --git a/internal/llm/tools/shell/comparison_test.go b/internal/llm/tools/shell/comparison_test.go index 4906a5c8acb8136d7b7947177f999f7019080f5f..8fe2159b85d392d87b2602f77f105375cceb07c3 100644 --- a/internal/llm/tools/shell/comparison_test.go +++ b/internal/llm/tools/shell/comparison_test.go @@ -13,8 +13,7 @@ func TestShellPerformanceComparison(t *testing.T) { // Test quick command start := time.Now() - stdout, stderr, err := shell.Exec(t.Context(), "echo 'hello'") - exitCode := ExitCode(err) + stdout, stderr, exitCode, _, err := shell.Exec(t.Context(), "echo 'hello'", 0) duration := time.Since(start) require.NoError(t, err) @@ -33,8 +32,7 @@ func BenchmarkShellPolling(b *testing.B) { for b.Loop() { // Use a short sleep to measure polling overhead - _, _, err := shell.Exec(b.Context(), "sleep 0.02") - exitCode := ExitCode(err) + _, _, exitCode, _, err := shell.Exec(b.Context(), "sleep 0.02", 0) if err != nil || exitCode != 0 { b.Fatalf("Command failed: %v, exit code: %d", err, exitCode) } diff --git a/internal/llm/tools/shell/shell.go b/internal/llm/tools/shell/shell.go index cbf33e700c15ae06b66a4cbafa620bbc7ceb1405..c9d6826a7f4f17f3df29e1871790222c5df55fa6 100644 --- a/internal/llm/tools/shell/shell.go +++ b/internal/llm/tools/shell/shell.go @@ -1,87 +1,323 @@ package shell import ( - "bytes" "context" "errors" "fmt" "os" + "os/exec" + "path/filepath" "strings" "sync" - - "github.com/charmbracelet/crush/internal/logging" - "mvdan.cc/sh/v3/expand" - "mvdan.cc/sh/v3/interp" - "mvdan.cc/sh/v3/syntax" + "syscall" + "time" ) type PersistentShell struct { - env []string - cwd string - mu sync.Mutex + cmd *exec.Cmd + stdin *os.File + isAlive bool + cwd string + mu sync.Mutex + commandQueue chan *commandExecution +} + +type commandExecution struct { + command string + timeout time.Duration + resultChan chan commandResult + ctx context.Context +} + +type commandResult struct { + stdout string + stderr string + exitCode int + interrupted bool + err error } var ( - once sync.Once - shellInstance *PersistentShell + shellInstance *PersistentShell + shellInstanceOnce sync.Once ) -func GetPersistentShell(cwd string) *PersistentShell { - once.Do(func() { - shellInstance = newPersistentShell(cwd) +func GetPersistentShell(workingDir string) *PersistentShell { + shellInstanceOnce.Do(func() { + shellInstance = newPersistentShell(workingDir) }) + + if shellInstance == nil { + shellInstance = newPersistentShell(workingDir) + } else if !shellInstance.isAlive { + shellInstance = newPersistentShell(shellInstance.cwd) + } + return shellInstance } func newPersistentShell(cwd string) *PersistentShell { - return &PersistentShell{ - cwd: cwd, - env: os.Environ(), + // Default to environment variable + shellPath := os.Getenv("SHELL") + if shellPath == "" { + shellPath = "/bin/bash" + } + + // Default shell args + shellArgs := []string{"-l"} + + cmd := exec.Command(shellPath, shellArgs...) + cmd.Dir = cwd + + stdinPipe, err := cmd.StdinPipe() + if err != nil { + return nil + } + + cmd.Env = append(os.Environ(), "GIT_EDITOR=true") + + err = cmd.Start() + if err != nil { + return nil + } + + shell := &PersistentShell{ + cmd: cmd, + stdin: stdinPipe.(*os.File), + isAlive: true, + cwd: cwd, + commandQueue: make(chan *commandExecution, 10), } + + go func() { + defer func() { + if r := recover(); r != nil { + fmt.Fprintf(os.Stderr, "Panic in shell command processor: %v\n", r) + shell.isAlive = false + close(shell.commandQueue) + } + }() + shell.processCommands() + }() + + go func() { + err := cmd.Wait() + if err != nil { + // Log the error if needed + } + shell.isAlive = false + close(shell.commandQueue) + }() + + return shell } -func (s *PersistentShell) Exec(ctx context.Context, command string) (string, string, error) { +func (s *PersistentShell) processCommands() { + for cmd := range s.commandQueue { + result := s.execCommand(cmd.command, cmd.timeout, cmd.ctx) + cmd.resultChan <- result + } +} + +func (s *PersistentShell) execCommand(command string, timeout time.Duration, ctx context.Context) commandResult { s.mu.Lock() defer s.mu.Unlock() - line, err := syntax.NewParser().Parse(strings.NewReader(command), "") - if err != nil { - return "", "", fmt.Errorf("could not parse command: %w", err) + if !s.isAlive { + return commandResult{ + stderr: "Shell is not alive", + exitCode: 1, + err: errors.New("shell is not alive"), + } } - var stdout, stderr bytes.Buffer - runner, err := interp.New( - interp.StdIO(nil, &stdout, &stderr), - interp.Interactive(false), - interp.Env(expand.ListEnviron(s.env...)), - interp.Dir(s.cwd), + tempDir := os.TempDir() + stdoutFile := filepath.Join(tempDir, fmt.Sprintf("crush-stdout-%d", time.Now().UnixNano())) + stderrFile := filepath.Join(tempDir, fmt.Sprintf("crush-stderr-%d", time.Now().UnixNano())) + statusFile := filepath.Join(tempDir, fmt.Sprintf("crush-status-%d", time.Now().UnixNano())) + cwdFile := filepath.Join(tempDir, fmt.Sprintf("crush-cwd-%d", time.Now().UnixNano())) + + defer func() { + os.Remove(stdoutFile) + os.Remove(stderrFile) + os.Remove(statusFile) + os.Remove(cwdFile) + }() + + fullCommand := fmt.Sprintf(` +eval %s < /dev/null > %s 2> %s +EXEC_EXIT_CODE=$? +pwd > %s +echo $EXEC_EXIT_CODE > %s +`, + shellQuote(command), + shellQuote(stdoutFile), + shellQuote(stderrFile), + shellQuote(cwdFile), + shellQuote(statusFile), ) + + _, err := s.stdin.Write([]byte(fullCommand + "\n")) if err != nil { - return "", "", fmt.Errorf("could not run command: %w", err) + return commandResult{ + stderr: fmt.Sprintf("Failed to write command to shell: %v", err), + exitCode: 1, + err: err, + } } - err = runner.Run(ctx, line) - s.cwd = runner.Dir - s.env = []string{} - for name, vr := range runner.Vars { - s.env = append(s.env, fmt.Sprintf("%s=%s", name, vr.Str)) + interrupted := false + + startTime := time.Now() + + done := make(chan bool) + go func() { + // Use exponential backoff polling + pollInterval := 1 * time.Millisecond + maxPollInterval := 100 * time.Millisecond + + ticker := time.NewTicker(pollInterval) + defer ticker.Stop() + + for { + select { + case <-ctx.Done(): + s.killChildren() + interrupted = true + done <- true + return + + case <-ticker.C: + if fileExists(statusFile) && fileSize(statusFile) > 0 { + done <- true + return + } + + if timeout > 0 { + elapsed := time.Since(startTime) + if elapsed > timeout { + s.killChildren() + interrupted = true + done <- true + return + } + } + + // Exponential backoff to reduce CPU usage for longer-running commands + if pollInterval < maxPollInterval { + pollInterval = min(time.Duration(float64(pollInterval)*1.5), maxPollInterval) + ticker.Reset(pollInterval) + } + } + } + }() + + <-done + + stdout := readFileOrEmpty(stdoutFile) + stderr := readFileOrEmpty(stderrFile) + exitCodeStr := readFileOrEmpty(statusFile) + newCwd := readFileOrEmpty(cwdFile) + + exitCode := 0 + if exitCodeStr != "" { + fmt.Sscanf(exitCodeStr, "%d", &exitCode) + } else if interrupted { + exitCode = 143 + stderr += "\nCommand execution timed out or was interrupted" + } + + if newCwd != "" { + s.cwd = strings.TrimSpace(newCwd) + } + + return commandResult{ + stdout: stdout, + stderr: stderr, + exitCode: exitCode, + interrupted: interrupted, } - logging.InfoPersist("Command finished", "command", command, "err", err) - return stdout.String(), stderr.String(), err } -func IsInterrupt(err error) bool { - return errors.Is(err, context.Canceled) || - errors.Is(err, context.DeadlineExceeded) +func (s *PersistentShell) killChildren() { + if s.cmd == nil || s.cmd.Process == nil { + return + } + + pgrepCmd := exec.Command("pgrep", "-P", fmt.Sprintf("%d", s.cmd.Process.Pid)) + output, err := pgrepCmd.Output() + if err != nil { + return + } + + for pidStr := range strings.SplitSeq(string(output), "\n") { + if pidStr = strings.TrimSpace(pidStr); pidStr != "" { + var pid int + fmt.Sscanf(pidStr, "%d", &pid) + if pid > 0 { + proc, err := os.FindProcess(pid) + if err == nil { + proc.Signal(syscall.SIGTERM) + } + } + } + } } -func ExitCode(err error) int { - if err == nil { - return 0 +func (s *PersistentShell) Exec(ctx context.Context, command string, timeoutMs int) (string, string, int, bool, error) { + if !s.isAlive { + return "", "Shell is not alive", 1, false, errors.New("shell is not alive") + } + + timeout := time.Duration(timeoutMs) * time.Millisecond + + resultChan := make(chan commandResult) + s.commandQueue <- &commandExecution{ + command: command, + timeout: timeout, + resultChan: resultChan, + ctx: ctx, + } + + result := <-resultChan + return result.stdout, result.stderr, result.exitCode, result.interrupted, result.err +} + +func (s *PersistentShell) Close() { + s.mu.Lock() + defer s.mu.Unlock() + + if !s.isAlive { + return } - status, ok := interp.IsExitStatus(err) - if ok { - return int(status) + + s.stdin.Write([]byte("exit\n")) + + s.cmd.Process.Kill() + s.isAlive = false +} + +func shellQuote(s string) string { + return "'" + strings.ReplaceAll(s, "'", "'\\''") + "'" +} + +func readFileOrEmpty(path string) string { + content, err := os.ReadFile(path) + if err != nil { + return "" } - return 1 + return string(content) +} + +func fileExists(path string) bool { + _, err := os.Stat(path) + return err == nil } + +func fileSize(path string) int64 { + info, err := os.Stat(path) + if err != nil { + return 0 + } + return info.Size() +} \ No newline at end of file diff --git a/internal/llm/tools/shell/shell_test.go b/internal/llm/tools/shell/shell_test.go index 630995aaa756295ff1eb14b05d11a2ee8f634733..329e6e3fa6ba65b8a5784e6c699bd0053f171888 100644 --- a/internal/llm/tools/shell/shell_test.go +++ b/internal/llm/tools/shell/shell_test.go @@ -2,81 +2,28 @@ package shell import ( "context" + "os" "testing" - "time" + + "github.com/stretchr/testify/require" ) // Benchmark to measure CPU efficiency func BenchmarkShellQuickCommands(b *testing.B) { - shell := newPersistentShell(b.TempDir()) + tmpDir, err := os.MkdirTemp("", "shell-bench") + require.NoError(b, err) + defer os.RemoveAll(tmpDir) + + shell := GetPersistentShell(tmpDir) + defer shell.Close() + b.ResetTimer() b.ReportAllocs() for b.Loop() { - _, _, err := shell.Exec(context.Background(), "echo test") - exitCode := ExitCode(err) + _, _, exitCode, _, err := shell.Exec(context.Background(), "echo test", 0) if err != nil || exitCode != 0 { b.Fatalf("Command failed: %v, exit code: %d", err, exitCode) } } -} - -func TestTestTimeout(t *testing.T) { - ctx, cancel := context.WithTimeout(t.Context(), time.Millisecond) - t.Cleanup(cancel) - - shell := newPersistentShell(t.TempDir()) - _, _, err := shell.Exec(ctx, "sleep 10") - if status := ExitCode(err); status == 0 { - t.Fatalf("Expected non-zero exit status, got %d", status) - } - if !IsInterrupt(err) { - t.Fatalf("Expected command to be interrupted, but it was not") - } - if err == nil { - t.Fatalf("Expected an error due to timeout, but got none") - } -} - -func TestTestCancel(t *testing.T) { - ctx, cancel := context.WithCancel(t.Context()) - cancel() // immediately cancel the context - - shell := newPersistentShell(t.TempDir()) - _, _, err := shell.Exec(ctx, "sleep 10") - if status := ExitCode(err); status == 0 { - t.Fatalf("Expected non-zero exit status, got %d", status) - } - if !IsInterrupt(err) { - t.Fatalf("Expected command to be interrupted, but it was not") - } - if err == nil { - t.Fatalf("Expected an error due to cancel, but got none") - } -} - -func TestRunCommandError(t *testing.T) { - shell := newPersistentShell(t.TempDir()) - _, _, err := shell.Exec(t.Context(), "nopenopenope") - if status := ExitCode(err); status == 0 { - t.Fatalf("Expected non-zero exit status, got %d", status) - } - if IsInterrupt(err) { - t.Fatalf("Expected command to not be interrupted, but it was") - } - if err == nil { - t.Fatalf("Expected an error, got nil") - } -} - -func TestRunContinuity(t *testing.T) { - shell := newPersistentShell(t.TempDir()) - shell.Exec(t.Context(), "export FOO=bar") - dst := t.TempDir() - shell.Exec(t.Context(), "cd "+dst) - out, _, _ := shell.Exec(t.Context(), "echo $FOO ; pwd") - expect := "bar\n" + dst + "\n" - if out != expect { - t.Fatalf("Expected output %q, got %q", expect, out) - } -} +} \ No newline at end of file diff --git a/internal/permission/permission.go b/internal/permission/permission.go index 6790e1d208c02f24a9640b464f0253ef69cfcc77..4dec691e9e09507f4b86eddfa14fca5a6ef8a2ed 100644 --- a/internal/permission/permission.go +++ b/internal/permission/permission.go @@ -1,10 +1,12 @@ package permission import ( + "context" "errors" "path/filepath" "slices" "sync" + "time" "github.com/charmbracelet/crush/internal/config" "github.com/charmbracelet/crush/internal/pubsub" @@ -44,9 +46,11 @@ type Service interface { type permissionService struct { *pubsub.Broker[PermissionRequest] - sessionPermissions []PermissionRequest - pendingRequests sync.Map - autoApproveSessions []string + sessionPermissions []PermissionRequest + sessionPermissionsMu sync.RWMutex + pendingRequests sync.Map + autoApproveSessions []string + autoApproveSessionsMu sync.RWMutex } func (s *permissionService) GrantPersistent(permission PermissionRequest) { @@ -54,7 +58,10 @@ func (s *permissionService) GrantPersistent(permission PermissionRequest) { if ok { respCh.(chan bool) <- true } + + s.sessionPermissionsMu.Lock() s.sessionPermissions = append(s.sessionPermissions, permission) + s.sessionPermissionsMu.Unlock() } func (s *permissionService) Grant(permission PermissionRequest) { @@ -72,9 +79,14 @@ func (s *permissionService) Deny(permission PermissionRequest) { } func (s *permissionService) Request(opts CreatePermissionRequest) bool { - if slices.Contains(s.autoApproveSessions, opts.SessionID) { + s.autoApproveSessionsMu.RLock() + autoApprove := slices.Contains(s.autoApproveSessions, opts.SessionID) + s.autoApproveSessionsMu.RUnlock() + + if autoApprove { return true } + dir := filepath.Dir(opts.Path) if dir == "." { dir = config.WorkingDirectory() @@ -89,11 +101,14 @@ func (s *permissionService) Request(opts CreatePermissionRequest) bool { Params: opts.Params, } + s.sessionPermissionsMu.RLock() for _, p := range s.sessionPermissions { if p.ToolName == permission.ToolName && p.Action == permission.Action && p.SessionID == permission.SessionID && p.Path == permission.Path { + s.sessionPermissionsMu.RUnlock() return true } } + s.sessionPermissionsMu.RUnlock() respCh := make(chan bool, 1) @@ -102,13 +117,22 @@ func (s *permissionService) Request(opts CreatePermissionRequest) bool { s.Publish(pubsub.CreatedEvent, permission) - // Wait for the response with a timeout - resp := <-respCh - return resp + // Wait for the response with a timeout to prevent indefinite blocking + ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) + defer cancel() + + select { + case resp := <-respCh: + return resp + case <-ctx.Done(): + return false // Timeout - deny by default + } } func (s *permissionService) AutoApproveSession(sessionID string) { + s.autoApproveSessionsMu.Lock() s.autoApproveSessions = append(s.autoApproveSessions, sessionID) + s.autoApproveSessionsMu.Unlock() } func NewPermissionService() Service { diff --git a/internal/pubsub/broker.go b/internal/pubsub/broker.go index 0de1be063b05e522c951ee9fe25c9358cf44ef52..80948d3d515a4fb5dad0d4dc36adbbff4e502993 100644 --- a/internal/pubsub/broker.go +++ b/internal/pubsub/broker.go @@ -111,6 +111,8 @@ func (b *Broker[T]) Publish(t EventType, payload T) { select { case sub <- event: default: + // Channel is full, subscriber is slow - skip this event + // This prevents blocking the publisher } } }