shell.go

  1package shell
  2
  3import (
  4	"context"
  5	"errors"
  6	"fmt"
  7	"os"
  8	"os/exec"
  9	"path/filepath"
 10	"strings"
 11	"sync"
 12	"syscall"
 13	"time"
 14)
 15
 16type PersistentShell struct {
 17	cmd          *exec.Cmd
 18	stdin        *os.File
 19	isAlive      bool
 20	cwd          string
 21	mu           sync.Mutex
 22	commandQueue chan *commandExecution
 23}
 24
 25type commandExecution struct {
 26	command    string
 27	timeout    time.Duration
 28	resultChan chan commandResult
 29	ctx        context.Context
 30}
 31
 32type commandResult struct {
 33	stdout      string
 34	stderr      string
 35	exitCode    int
 36	interrupted bool
 37	err         error
 38}
 39
 40var (
 41	shellInstance     *PersistentShell
 42	shellInstanceOnce sync.Once
 43)
 44
 45func GetPersistentShell(workingDir string) *PersistentShell {
 46	shellInstanceOnce.Do(func() {
 47		shellInstance = newPersistentShell(workingDir)
 48	})
 49
 50	if shellInstance == nil {
 51		shellInstance = newPersistentShell(workingDir)
 52	} else if !shellInstance.isAlive {
 53		shellInstance = newPersistentShell(shellInstance.cwd)
 54	}
 55
 56	return shellInstance
 57}
 58
 59func newPersistentShell(cwd string) *PersistentShell {
 60	// Default to environment variable
 61	shellPath := os.Getenv("SHELL")
 62	if shellPath == "" {
 63		shellPath = "/bin/bash"
 64	}
 65
 66	// Default shell args
 67	shellArgs := []string{"-l"}
 68
 69	cmd := exec.Command(shellPath, shellArgs...)
 70	cmd.Dir = cwd
 71
 72	stdinPipe, err := cmd.StdinPipe()
 73	if err != nil {
 74		return nil
 75	}
 76
 77	cmd.Env = append(os.Environ(), "GIT_EDITOR=true")
 78
 79	err = cmd.Start()
 80	if err != nil {
 81		return nil
 82	}
 83
 84	shell := &PersistentShell{
 85		cmd:          cmd,
 86		stdin:        stdinPipe.(*os.File),
 87		isAlive:      true,
 88		cwd:          cwd,
 89		commandQueue: make(chan *commandExecution, 10),
 90	}
 91
 92	go func() {
 93		defer func() {
 94			if r := recover(); r != nil {
 95				fmt.Fprintf(os.Stderr, "Panic in shell command processor: %v\n", r)
 96				shell.isAlive = false
 97				close(shell.commandQueue)
 98			}
 99		}()
100		shell.processCommands()
101	}()
102
103	go func() {
104		err := cmd.Wait()
105		if err != nil {
106			// Log the error if needed
107		}
108		shell.isAlive = false
109		close(shell.commandQueue)
110	}()
111
112	return shell
113}
114
115func (s *PersistentShell) processCommands() {
116	for cmd := range s.commandQueue {
117		result := s.execCommand(cmd.command, cmd.timeout, cmd.ctx)
118		cmd.resultChan <- result
119	}
120}
121
122func (s *PersistentShell) execCommand(command string, timeout time.Duration, ctx context.Context) commandResult {
123	s.mu.Lock()
124	defer s.mu.Unlock()
125
126	if !s.isAlive {
127		return commandResult{
128			stderr:   "Shell is not alive",
129			exitCode: 1,
130			err:      errors.New("shell is not alive"),
131		}
132	}
133
134	tempDir := os.TempDir()
135	stdoutFile := filepath.Join(tempDir, fmt.Sprintf("crush-stdout-%d", time.Now().UnixNano()))
136	stderrFile := filepath.Join(tempDir, fmt.Sprintf("crush-stderr-%d", time.Now().UnixNano()))
137	statusFile := filepath.Join(tempDir, fmt.Sprintf("crush-status-%d", time.Now().UnixNano()))
138	cwdFile := filepath.Join(tempDir, fmt.Sprintf("crush-cwd-%d", time.Now().UnixNano()))
139
140	defer func() {
141		os.Remove(stdoutFile)
142		os.Remove(stderrFile)
143		os.Remove(statusFile)
144		os.Remove(cwdFile)
145	}()
146
147	fullCommand := fmt.Sprintf(`
148eval %s < /dev/null > %s 2> %s
149EXEC_EXIT_CODE=$?
150pwd > %s
151echo $EXEC_EXIT_CODE > %s
152`,
153		shellQuote(command),
154		shellQuote(stdoutFile),
155		shellQuote(stderrFile),
156		shellQuote(cwdFile),
157		shellQuote(statusFile),
158	)
159
160	_, err := s.stdin.Write([]byte(fullCommand + "\n"))
161	if err != nil {
162		return commandResult{
163			stderr:   fmt.Sprintf("Failed to write command to shell: %v", err),
164			exitCode: 1,
165			err:      err,
166		}
167	}
168
169	interrupted := false
170
171	startTime := time.Now()
172
173	done := make(chan bool)
174	go func() {
175		// Use exponential backoff polling
176		pollInterval := 1 * time.Millisecond
177		maxPollInterval := 100 * time.Millisecond
178
179		ticker := time.NewTicker(pollInterval)
180		defer ticker.Stop()
181
182		for {
183			select {
184			case <-ctx.Done():
185				s.killChildren()
186				interrupted = true
187				done <- true
188				return
189
190			case <-ticker.C:
191				if fileExists(statusFile) && fileSize(statusFile) > 0 {
192					done <- true
193					return
194				}
195
196				if timeout > 0 {
197					elapsed := time.Since(startTime)
198					if elapsed > timeout {
199						s.killChildren()
200						interrupted = true
201						done <- true
202						return
203					}
204				}
205
206				// Exponential backoff to reduce CPU usage for longer-running commands
207				if pollInterval < maxPollInterval {
208					pollInterval = min(time.Duration(float64(pollInterval)*1.5), maxPollInterval)
209					ticker.Reset(pollInterval)
210				}
211			}
212		}
213	}()
214
215	<-done
216
217	stdout := readFileOrEmpty(stdoutFile)
218	stderr := readFileOrEmpty(stderrFile)
219	exitCodeStr := readFileOrEmpty(statusFile)
220	newCwd := readFileOrEmpty(cwdFile)
221
222	exitCode := 0
223	if exitCodeStr != "" {
224		fmt.Sscanf(exitCodeStr, "%d", &exitCode)
225	} else if interrupted {
226		exitCode = 143
227		stderr += "\nCommand execution timed out or was interrupted"
228	}
229
230	if newCwd != "" {
231		s.cwd = strings.TrimSpace(newCwd)
232	}
233
234	return commandResult{
235		stdout:      stdout,
236		stderr:      stderr,
237		exitCode:    exitCode,
238		interrupted: interrupted,
239	}
240}
241
242func (s *PersistentShell) killChildren() {
243	if s.cmd == nil || s.cmd.Process == nil {
244		return
245	}
246
247	pgrepCmd := exec.Command("pgrep", "-P", fmt.Sprintf("%d", s.cmd.Process.Pid))
248	output, err := pgrepCmd.Output()
249	if err != nil {
250		return
251	}
252
253	for pidStr := range strings.SplitSeq(string(output), "\n") {
254		if pidStr = strings.TrimSpace(pidStr); pidStr != "" {
255			var pid int
256			fmt.Sscanf(pidStr, "%d", &pid)
257			if pid > 0 {
258				proc, err := os.FindProcess(pid)
259				if err == nil {
260					proc.Signal(syscall.SIGTERM)
261				}
262			}
263		}
264	}
265}
266
267func (s *PersistentShell) Exec(ctx context.Context, command string, timeoutMs int) (string, string, int, bool, error) {
268	if !s.isAlive {
269		return "", "Shell is not alive", 1, false, errors.New("shell is not alive")
270	}
271
272	timeout := time.Duration(timeoutMs) * time.Millisecond
273
274	resultChan := make(chan commandResult)
275	s.commandQueue <- &commandExecution{
276		command:    command,
277		timeout:    timeout,
278		resultChan: resultChan,
279		ctx:        ctx,
280	}
281
282	result := <-resultChan
283	return result.stdout, result.stderr, result.exitCode, result.interrupted, result.err
284}
285
286func (s *PersistentShell) Close() {
287	s.mu.Lock()
288	defer s.mu.Unlock()
289
290	if !s.isAlive {
291		return
292	}
293
294	s.stdin.Write([]byte("exit\n"))
295
296	s.cmd.Process.Kill()
297	s.isAlive = false
298}
299
// shellQuote wraps s in single quotes for a POSIX shell. Inside single quotes
// only the quote character itself is special. Each ' is therefore written as
// '\'' (close the quote, emit an escaped quote, reopen the quote).
func shellQuote(s string) string {
	var b strings.Builder
	b.Grow(len(s) + 2)
	b.WriteByte('\'')
	for i := 0; i < len(s); i++ {
		if s[i] == '\'' {
			b.WriteString(`'\''`)
			continue
		}
		b.WriteByte(s[i])
	}
	b.WriteByte('\'')
	return b.String()
}
303
// readFileOrEmpty returns the full contents of path as a string. Any read
// error (missing file, no permission, a directory, ...) yields "".
func readFileOrEmpty(path string) string {
	if data, err := os.ReadFile(path); err == nil {
		return string(data)
	}
	return ""
}
311
// fileExists reports whether os.Stat succeeds for path. Every stat error, not
// only "does not exist", is reported as false.
func fileExists(path string) bool {
	if _, err := os.Stat(path); err != nil {
		return false
	}
	return true
}
316
// fileSize returns the size in bytes that os.Stat reports for path, or 0 when
// the stat fails (for example, because the file does not exist yet).
func fileSize(path string) (size int64) {
	if info, err := os.Stat(path); err == nil {
		size = info.Size()
	}
	return size
}