@@ -1,6 +1,7 @@
package shell
import (
+ "cmp"
"context"
"errors"
"fmt"
@@ -67,16 +68,15 @@ func newPersistentShell(cwd string) *PersistentShell {
shellArgs = cfg.Shell.Args
}
- if shellPath == "" {
- shellPath = os.Getenv("SHELL")
- if shellPath == "" {
- shellPath = "/bin/bash"
- }
+ shellPath = cmp.Or(shellPath, os.Getenv("SHELL"), "/bin/bash")
+ if !strings.HasSuffix(shellPath, "bash") && !strings.HasSuffix(shellPath, "zsh") {
+ logging.Warn("only bash and zsh are supported at this time", "shell", shellPath)
+ shellPath = "/bin/bash"
}
// Default shell args
if len(shellArgs) == 0 {
- shellArgs = []string{"-l"}
+ shellArgs = []string{"--login"}
}
cmd := exec.Command(shellPath, shellArgs...)
@@ -127,12 +127,15 @@ func newPersistentShell(cwd string) *PersistentShell {
func (s *PersistentShell) processCommands() {
for cmd := range s.commandQueue {
- result := s.execCommand(cmd.command, cmd.timeout, cmd.ctx)
- cmd.resultChan <- result
+ cmd.resultChan <- s.execCommand(cmd.ctx, cmd.command, cmd.timeout)
}
}
-func (s *PersistentShell) execCommand(command string, timeout time.Duration, ctx context.Context) commandResult {
+const runBashCommandFormat = `%s </dev/null >%q 2>%q
+echo $? >%q
+pwd >%q`
+
+func (s *PersistentShell) execCommand(ctx context.Context, command string, timeout time.Duration) commandResult {
s.mu.Lock()
defer s.mu.Unlock()
@@ -144,34 +147,22 @@ func (s *PersistentShell) execCommand(command string, timeout time.Duration, ctx
}
}
- tempDir := os.TempDir()
- stdoutFile := filepath.Join(tempDir, fmt.Sprintf("crush-stdout-%d", time.Now().UnixNano()))
- stderrFile := filepath.Join(tempDir, fmt.Sprintf("crush-stderr-%d", time.Now().UnixNano()))
- statusFile := filepath.Join(tempDir, fmt.Sprintf("crush-status-%d", time.Now().UnixNano()))
- cwdFile := filepath.Join(tempDir, fmt.Sprintf("crush-cwd-%d", time.Now().UnixNano()))
+ tmp := os.TempDir()
+ now := time.Now().UnixNano()
+ stdoutFile := filepath.Join(tmp, fmt.Sprintf("crush-stdout-%d", now))
+ stderrFile := filepath.Join(tmp, fmt.Sprintf("crush-stderr-%d", now))
+ statusFile := filepath.Join(tmp, fmt.Sprintf("crush-status-%d", now))
+ cwdFile := filepath.Join(tmp, fmt.Sprintf("crush-cwd-%d", now))
defer func() {
- os.Remove(stdoutFile)
- os.Remove(stderrFile)
- os.Remove(statusFile)
- os.Remove(cwdFile)
+ _ = os.Remove(stdoutFile)
+ _ = os.Remove(stderrFile)
+ _ = os.Remove(statusFile)
+ _ = os.Remove(cwdFile)
}()
- fullCommand := fmt.Sprintf(`
-eval %s < /dev/null > %s 2> %s
-EXEC_EXIT_CODE=$?
-pwd > %s
-echo $EXEC_EXIT_CODE > %s
-`,
- shellQuote(command),
- shellQuote(stdoutFile),
- shellQuote(stderrFile),
- shellQuote(cwdFile),
- shellQuote(statusFile),
- )
-
- _, err := s.stdin.Write([]byte(fullCommand + "\n"))
- if err != nil {
+ script := fmt.Sprintf(runBashCommandFormat, command, stdoutFile, stderrFile, statusFile, cwdFile)
+ if _, err := s.stdin.Write([]byte(script + "\n")); err != nil {
return commandResult{
stderr: fmt.Sprintf("Failed to write command to shell: %v", err),
exitCode: 1,
@@ -180,18 +171,18 @@ echo $EXEC_EXIT_CODE > %s
}
interrupted := false
-
- startTime := time.Now()
-
done := make(chan bool)
go func() {
// Use exponential backoff polling
- pollInterval := 1 * time.Millisecond
- maxPollInterval := 100 * time.Millisecond
+ pollInterval := 10 * time.Millisecond
+ maxPollInterval := time.Second
ticker := time.NewTicker(pollInterval)
defer ticker.Stop()
+ timeoutTicker := time.NewTicker(cmp.Or(timeout, time.Hour*99999))
+ defer timeoutTicker.Stop()
+
for {
select {
case <-ctx.Done():
@@ -200,22 +191,18 @@ echo $EXEC_EXIT_CODE > %s
done <- true
return
+ case <-timeoutTicker.C:
+ s.killChildren()
+ interrupted = true
+ done <- true
+ return
+
case <-ticker.C:
- if fileExists(statusFile) && fileSize(statusFile) > 0 {
+ if fileSize(statusFile) > 0 {
done <- true
return
}
- if timeout > 0 {
- elapsed := time.Since(startTime)
- if elapsed > timeout {
- s.killChildren()
- interrupted = true
- done <- true
- return
- }
- }
-
// Exponential backoff to reduce CPU usage for longer-running commands
if pollInterval < maxPollInterval {
pollInterval = min(time.Duration(float64(pollInterval)*1.5), maxPollInterval)
@@ -280,12 +267,10 @@ func (s *PersistentShell) Exec(ctx context.Context, command string, timeoutMs in
return "", "Shell is not alive", 1, false, errors.New("shell is not alive")
}
- timeout := time.Duration(timeoutMs) * time.Millisecond
-
resultChan := make(chan commandResult)
s.commandQueue <- &commandExecution{
command: command,
- timeout: timeout,
+ timeout: time.Duration(timeoutMs) * time.Millisecond,
resultChan: resultChan,
ctx: ctx,
}
@@ -310,10 +295,6 @@ func (s *PersistentShell) Close() {
s.isAlive = false
}
-func shellQuote(s string) string {
- return "'" + strings.ReplaceAll(s, "'", "'\\''") + "'"
-}
-
func readFileOrEmpty(path string) string {
content, err := os.ReadFile(path)
if err != nil {
@@ -322,11 +303,6 @@ func readFileOrEmpty(path string) string {
return string(content)
}
-func fileExists(path string) bool {
- _, err := os.Stat(path)
- return err == nil
-}
-
func fileSize(path string) int64 {
info, err := os.Stat(path)
if err != nil {