fix: improve shell

Carlos Alexandro Becker created

Change summary

internal/llm/tools/shell/shell.go | 98 ++++++++++++--------------------
1 file changed, 37 insertions(+), 61 deletions(-)

Detailed changes

internal/llm/tools/shell/shell.go 🔗

@@ -1,6 +1,7 @@
 package shell
 
 import (
+	"cmp"
 	"context"
 	"errors"
 	"fmt"
@@ -67,16 +68,15 @@ func newPersistentShell(cwd string) *PersistentShell {
 		shellArgs = cfg.Shell.Args
 	}
 
-	if shellPath == "" {
-		shellPath = os.Getenv("SHELL")
-		if shellPath == "" {
-			shellPath = "/bin/bash"
-		}
+	shellPath = cmp.Or(shellPath, os.Getenv("SHELL"), "/bin/bash")
+	if !strings.HasSuffix(shellPath, "bash") && !strings.HasSuffix(shellPath, "zsh") {
+		logging.Warn("only bash and zsh are supported at this time", "shell", shellPath)
+		shellPath = "/bin/bash"
 	}
 
 	// Default shell args
 	if len(shellArgs) == 0 {
-		shellArgs = []string{"-l"}
+		shellArgs = []string{"--login"}
 	}
 
 	cmd := exec.Command(shellPath, shellArgs...)
@@ -127,12 +127,15 @@ func newPersistentShell(cwd string) *PersistentShell {
 
 func (s *PersistentShell) processCommands() {
 	for cmd := range s.commandQueue {
-		result := s.execCommand(cmd.command, cmd.timeout, cmd.ctx)
-		cmd.resultChan <- result
+		cmd.resultChan <- s.execCommand(cmd.ctx, cmd.command, cmd.timeout)
 	}
 }
 
-func (s *PersistentShell) execCommand(command string, timeout time.Duration, ctx context.Context) commandResult {
+const runBashCommandFormat = `%s </dev/null >%q 2>%q
+echo $? >%q
+pwd >%q`
+
+func (s *PersistentShell) execCommand(ctx context.Context, command string, timeout time.Duration) commandResult {
 	s.mu.Lock()
 	defer s.mu.Unlock()
 
@@ -144,34 +147,22 @@ func (s *PersistentShell) execCommand(command string, timeout time.Duration, ctx
 		}
 	}
 
-	tempDir := os.TempDir()
-	stdoutFile := filepath.Join(tempDir, fmt.Sprintf("crush-stdout-%d", time.Now().UnixNano()))
-	stderrFile := filepath.Join(tempDir, fmt.Sprintf("crush-stderr-%d", time.Now().UnixNano()))
-	statusFile := filepath.Join(tempDir, fmt.Sprintf("crush-status-%d", time.Now().UnixNano()))
-	cwdFile := filepath.Join(tempDir, fmt.Sprintf("crush-cwd-%d", time.Now().UnixNano()))
+	tmp := os.TempDir()
+	now := time.Now().UnixNano()
+	stdoutFile := filepath.Join(tmp, fmt.Sprintf("crush-stdout-%d", now))
+	stderrFile := filepath.Join(tmp, fmt.Sprintf("crush-stderr-%d", now))
+	statusFile := filepath.Join(tmp, fmt.Sprintf("crush-status-%d", now))
+	cwdFile := filepath.Join(tmp, fmt.Sprintf("crush-cwd-%d", now))
 
 	defer func() {
-		os.Remove(stdoutFile)
-		os.Remove(stderrFile)
-		os.Remove(statusFile)
-		os.Remove(cwdFile)
+		_ = os.Remove(stdoutFile)
+		_ = os.Remove(stderrFile)
+		_ = os.Remove(statusFile)
+		_ = os.Remove(cwdFile)
 	}()
 
-	fullCommand := fmt.Sprintf(`
-eval %s < /dev/null > %s 2> %s
-EXEC_EXIT_CODE=$?
-pwd > %s
-echo $EXEC_EXIT_CODE > %s
-`,
-		shellQuote(command),
-		shellQuote(stdoutFile),
-		shellQuote(stderrFile),
-		shellQuote(cwdFile),
-		shellQuote(statusFile),
-	)
-
-	_, err := s.stdin.Write([]byte(fullCommand + "\n"))
-	if err != nil {
+	script := fmt.Sprintf(runBashCommandFormat, command, stdoutFile, stderrFile, statusFile, cwdFile)
+	if _, err := s.stdin.Write([]byte(script + "\n")); err != nil {
 		return commandResult{
 			stderr:   fmt.Sprintf("Failed to write command to shell: %v", err),
 			exitCode: 1,
@@ -180,18 +171,18 @@ echo $EXEC_EXIT_CODE > %s
 	}
 
 	interrupted := false
-
-	startTime := time.Now()
-
 	done := make(chan bool)
 	go func() {
 		// Use exponential backoff polling
-		pollInterval := 1 * time.Millisecond
-		maxPollInterval := 100 * time.Millisecond
+		pollInterval := 10 * time.Millisecond
+		maxPollInterval := time.Second
 
 		ticker := time.NewTicker(pollInterval)
 		defer ticker.Stop()
 
+		timeoutTicker := time.NewTicker(cmp.Or(timeout, time.Hour*99999))
+		defer timeoutTicker.Stop()
+
 		for {
 			select {
 			case <-ctx.Done():
@@ -200,22 +191,18 @@ echo $EXEC_EXIT_CODE > %s
 				done <- true
 				return
 
+			case <-timeoutTicker.C:
+				s.killChildren()
+				interrupted = true
+				done <- true
+				return
+
 			case <-ticker.C:
-				if fileExists(statusFile) && fileSize(statusFile) > 0 {
+				if fileSize(statusFile) > 0 {
 					done <- true
 					return
 				}
 
-				if timeout > 0 {
-					elapsed := time.Since(startTime)
-					if elapsed > timeout {
-						s.killChildren()
-						interrupted = true
-						done <- true
-						return
-					}
-				}
-
 				// Exponential backoff to reduce CPU usage for longer-running commands
 				if pollInterval < maxPollInterval {
 					pollInterval = min(time.Duration(float64(pollInterval)*1.5), maxPollInterval)
@@ -280,12 +267,10 @@ func (s *PersistentShell) Exec(ctx context.Context, command string, timeoutMs in
 		return "", "Shell is not alive", 1, false, errors.New("shell is not alive")
 	}
 
-	timeout := time.Duration(timeoutMs) * time.Millisecond
-
 	resultChan := make(chan commandResult)
 	s.commandQueue <- &commandExecution{
 		command:    command,
-		timeout:    timeout,
+		timeout:    time.Duration(timeoutMs) * time.Millisecond,
 		resultChan: resultChan,
 		ctx:        ctx,
 	}
@@ -310,10 +295,6 @@ func (s *PersistentShell) Close() {
 	s.isAlive = false
 }
 
-func shellQuote(s string) string {
-	return "'" + strings.ReplaceAll(s, "'", "'\\''") + "'"
-}
-
 func readFileOrEmpty(path string) string {
 	content, err := os.ReadFile(path)
 	if err != nil {
@@ -322,11 +303,6 @@ func readFileOrEmpty(path string) string {
 	return string(content)
 }
 
-func fileExists(path string) bool {
-	_, err := os.Stat(path)
-	return err == nil
-}
-
 func fileSize(path string) int64 {
 	info, err := os.Stat(path)
 	if err != nil {