diff --git a/internal/llm/tools/shell/shell.go b/internal/llm/tools/shell/shell.go index da23a03d845c2727765716a56c25782f1c20282a..25ba75a7877a7e4009991c78e6f09774c0d1bb79 100644 --- a/internal/llm/tools/shell/shell.go +++ b/internal/llm/tools/shell/shell.go @@ -1,6 +1,7 @@ package shell import ( + "cmp" "context" "errors" "fmt" @@ -67,16 +68,15 @@ func newPersistentShell(cwd string) *PersistentShell { shellArgs = cfg.Shell.Args } - if shellPath == "" { - shellPath = os.Getenv("SHELL") - if shellPath == "" { - shellPath = "/bin/bash" - } + shellPath = cmp.Or(shellPath, os.Getenv("SHELL"), "/bin/bash") + if !strings.HasSuffix(shellPath, "bash") && !strings.HasSuffix(shellPath, "zsh") { + logging.Warn("only bash and zsh are supported at this time", "shell", shellPath) + shellPath = "/bin/bash" } // Default shell args if len(shellArgs) == 0 { - shellArgs = []string{"-l"} + shellArgs = []string{"--login"} } cmd := exec.Command(shellPath, shellArgs...) @@ -127,12 +127,15 @@ func newPersistentShell(cwd string) *PersistentShell { func (s *PersistentShell) processCommands() { for cmd := range s.commandQueue { - result := s.execCommand(cmd.command, cmd.timeout, cmd.ctx) - cmd.resultChan <- result + cmd.resultChan <- s.execCommand(cmd.ctx, cmd.command, cmd.timeout) } } -func (s *PersistentShell) execCommand(command string, timeout time.Duration, ctx context.Context) commandResult { +const runBashCommandFormat = `%s %q 2>%q +echo $? >%q +pwd >%q` + +func (s *PersistentShell) execCommand(ctx context.Context, command string, timeout time.Duration) commandResult { s.mu.Lock() defer s.mu.Unlock() @@ -144,34 +147,22 @@ func (s *PersistentShell) execCommand(command string, timeout time.Duration, ctx } } - tempDir := os.TempDir() - stdoutFile := filepath.Join(tempDir, fmt.Sprintf("crush-stdout-%d", time.Now().UnixNano())) - stderrFile := filepath.Join(tempDir, fmt.Sprintf("crush-stderr-%d", time.Now().UnixNano())) - statusFile := filepath.Join(tempDir, fmt.Sprintf("crush-status-%d", time.Now().UnixNano())) - cwdFile := filepath.Join(tempDir, fmt.Sprintf("crush-cwd-%d", time.Now().UnixNano())) + tmp := os.TempDir() + now := time.Now().UnixNano() + stdoutFile := filepath.Join(tmp, fmt.Sprintf("crush-stdout-%d", now)) + stderrFile := filepath.Join(tmp, fmt.Sprintf("crush-stderr-%d", now)) + statusFile := filepath.Join(tmp, fmt.Sprintf("crush-status-%d", now)) + cwdFile := filepath.Join(tmp, fmt.Sprintf("crush-cwd-%d", now)) defer func() { - os.Remove(stdoutFile) - os.Remove(stderrFile) - os.Remove(statusFile) - os.Remove(cwdFile) + _ = os.Remove(stdoutFile) + _ = os.Remove(stderrFile) + _ = os.Remove(statusFile) + _ = os.Remove(cwdFile) }() - fullCommand := fmt.Sprintf(` -eval %s < /dev/null > %s 2> %s -EXEC_EXIT_CODE=$? -pwd > %s -echo $EXEC_EXIT_CODE > %s -`, - shellQuote(command), - shellQuote(stdoutFile), - shellQuote(stderrFile), - shellQuote(cwdFile), - shellQuote(statusFile), - ) - - _, err := s.stdin.Write([]byte(fullCommand + "\n")) - if err != nil { + script := fmt.Sprintf(runBashCommandFormat, command, stdoutFile, stderrFile, statusFile, cwdFile) + if _, err := s.stdin.Write([]byte(script + "\n")); err != nil { return commandResult{ stderr: fmt.Sprintf("Failed to write command to shell: %v", err), exitCode: 1, @@ -180,18 +171,18 @@ echo $EXEC_EXIT_CODE > %s } interrupted := false - - startTime := time.Now() - done := make(chan bool) go func() { // Use exponential backoff polling - pollInterval := 1 * time.Millisecond - maxPollInterval := 100 * time.Millisecond + pollInterval := 10 * time.Millisecond + maxPollInterval := time.Second ticker := time.NewTicker(pollInterval) defer ticker.Stop() + timeoutTicker := time.NewTicker(cmp.Or(timeout, time.Hour*99999)) + defer timeoutTicker.Stop() + for { select { case <-ctx.Done(): @@ -200,22 +191,18 @@ echo $EXEC_EXIT_CODE > %s done <- true return + case <-timeoutTicker.C: + s.killChildren() + interrupted = true + done <- true + return + case <-ticker.C: - if fileExists(statusFile) && fileSize(statusFile) > 0 { + if fileSize(statusFile) > 0 { done <- true return } - if timeout > 0 { - elapsed := time.Since(startTime) - if elapsed > timeout { - s.killChildren() - interrupted = true - done <- true - return - } - } - // Exponential backoff to reduce CPU usage for longer-running commands if pollInterval < maxPollInterval { pollInterval = min(time.Duration(float64(pollInterval)*1.5), maxPollInterval) @@ -280,12 +267,10 @@ func (s *PersistentShell) Exec(ctx context.Context, command string, timeoutMs in return "", "Shell is not alive", 1, false, errors.New("shell is not alive") } - timeout := time.Duration(timeoutMs) * time.Millisecond - resultChan := make(chan commandResult) s.commandQueue <- &commandExecution{ command: command, - timeout: timeout, + timeout: time.Duration(timeoutMs) * time.Millisecond, resultChan: resultChan, ctx: ctx, } @@ -310,10 +295,6 @@ func (s *PersistentShell) Close() { s.isAlive = false } -func shellQuote(s string) string { - return "'" + strings.ReplaceAll(s, "'", "'\\''") + "'" -} - func readFileOrEmpty(path string) string { content, err := os.ReadFile(path) if err != nil { @@ -322,11 +303,6 @@ func readFileOrEmpty(path string) string { return string(content) } -func fileExists(path string) bool { - _, err := os.Stat(path) - return err == nil -} - func fileSize(path string) int64 { info, err := os.Stat(path) if err != nil {