shell.go

  1package shell
  2
  3import (
  4	"context"
  5	"errors"
  6	"fmt"
  7	"os"
  8	"os/exec"
  9	"path/filepath"
 10	"strings"
 11	"sync"
 12	"syscall"
 13	"time"
 14
 15	"github.com/charmbracelet/crush/internal/config"
 16)
 17
 18type PersistentShell struct {
 19	cmd          *exec.Cmd
 20	stdin        *os.File
 21	isAlive      bool
 22	cwd          string
 23	mu           sync.Mutex
 24	commandQueue chan *commandExecution
 25}
 26
 27type commandExecution struct {
 28	command    string
 29	timeout    time.Duration
 30	resultChan chan commandResult
 31	ctx        context.Context
 32}
 33
 34type commandResult struct {
 35	stdout      string
 36	stderr      string
 37	exitCode    int
 38	interrupted bool
 39	err         error
 40}
 41
 42var (
 43	shellInstance     *PersistentShell
 44	shellInstanceOnce sync.Once
 45)
 46
 47func GetPersistentShell(workingDir string) *PersistentShell {
 48	shellInstanceOnce.Do(func() {
 49		shellInstance = newPersistentShell(workingDir)
 50	})
 51
 52	if shellInstance == nil {
 53		shellInstance = newPersistentShell(workingDir)
 54	} else if !shellInstance.isAlive {
 55		shellInstance = newPersistentShell(shellInstance.cwd)
 56	}
 57
 58	return shellInstance
 59}
 60
 61func newPersistentShell(cwd string) *PersistentShell {
 62	// Get shell configuration from config
 63	cfg := config.Get()
 64
 65	// Default to environment variable if config is not set or nil
 66	var shellPath string
 67	var shellArgs []string
 68
 69	if cfg != nil {
 70		shellPath = cfg.Shell.Path
 71		shellArgs = cfg.Shell.Args
 72	}
 73
 74	if shellPath == "" {
 75		shellPath = os.Getenv("SHELL")
 76		if shellPath == "" {
 77			shellPath = "/bin/bash"
 78		}
 79	}
 80
 81	// Default shell args
 82	if len(shellArgs) == 0 {
 83		shellArgs = []string{"-l"}
 84	}
 85
 86	cmd := exec.Command(shellPath, shellArgs...)
 87	cmd.Dir = cwd
 88
 89	stdinPipe, err := cmd.StdinPipe()
 90	if err != nil {
 91		return nil
 92	}
 93
 94	cmd.Env = append(os.Environ(), "GIT_EDITOR=true")
 95
 96	err = cmd.Start()
 97	if err != nil {
 98		return nil
 99	}
100
101	shell := &PersistentShell{
102		cmd:          cmd,
103		stdin:        stdinPipe.(*os.File),
104		isAlive:      true,
105		cwd:          cwd,
106		commandQueue: make(chan *commandExecution, 10),
107	}
108
109	go func() {
110		defer func() {
111			if r := recover(); r != nil {
112				fmt.Fprintf(os.Stderr, "Panic in shell command processor: %v\n", r)
113				shell.isAlive = false
114				close(shell.commandQueue)
115			}
116		}()
117		shell.processCommands()
118	}()
119
120	go func() {
121		err := cmd.Wait()
122		if err != nil {
123			// Log the error if needed
124		}
125		shell.isAlive = false
126		close(shell.commandQueue)
127	}()
128
129	return shell
130}
131
132func (s *PersistentShell) processCommands() {
133	for cmd := range s.commandQueue {
134		result := s.execCommand(cmd.command, cmd.timeout, cmd.ctx)
135		cmd.resultChan <- result
136	}
137}
138
139func (s *PersistentShell) execCommand(command string, timeout time.Duration, ctx context.Context) commandResult {
140	s.mu.Lock()
141	defer s.mu.Unlock()
142
143	if !s.isAlive {
144		return commandResult{
145			stderr:   "Shell is not alive",
146			exitCode: 1,
147			err:      errors.New("shell is not alive"),
148		}
149	}
150
151	tempDir := os.TempDir()
152	stdoutFile := filepath.Join(tempDir, fmt.Sprintf("crush-stdout-%d", time.Now().UnixNano()))
153	stderrFile := filepath.Join(tempDir, fmt.Sprintf("crush-stderr-%d", time.Now().UnixNano()))
154	statusFile := filepath.Join(tempDir, fmt.Sprintf("crush-status-%d", time.Now().UnixNano()))
155	cwdFile := filepath.Join(tempDir, fmt.Sprintf("crush-cwd-%d", time.Now().UnixNano()))
156
157	defer func() {
158		os.Remove(stdoutFile)
159		os.Remove(stderrFile)
160		os.Remove(statusFile)
161		os.Remove(cwdFile)
162	}()
163
164	fullCommand := fmt.Sprintf(`
165eval %s < /dev/null > %s 2> %s
166EXEC_EXIT_CODE=$?
167pwd > %s
168echo $EXEC_EXIT_CODE > %s
169`,
170		shellQuote(command),
171		shellQuote(stdoutFile),
172		shellQuote(stderrFile),
173		shellQuote(cwdFile),
174		shellQuote(statusFile),
175	)
176
177	_, err := s.stdin.Write([]byte(fullCommand + "\n"))
178	if err != nil {
179		return commandResult{
180			stderr:   fmt.Sprintf("Failed to write command to shell: %v", err),
181			exitCode: 1,
182			err:      err,
183		}
184	}
185
186	interrupted := false
187
188	startTime := time.Now()
189
190	done := make(chan bool)
191	go func() {
192		// Use exponential backoff polling
193		pollInterval := 1 * time.Millisecond
194		maxPollInterval := 100 * time.Millisecond
195		
196		ticker := time.NewTicker(pollInterval)
197		defer ticker.Stop()
198		
199		for {
200			select {
201			case <-ctx.Done():
202				s.killChildren()
203				interrupted = true
204				done <- true
205				return
206
207			case <-ticker.C:
208				if fileExists(statusFile) && fileSize(statusFile) > 0 {
209					done <- true
210					return
211				}
212
213				if timeout > 0 {
214					elapsed := time.Since(startTime)
215					if elapsed > timeout {
216						s.killChildren()
217						interrupted = true
218						done <- true
219						return
220					}
221				}
222				
223				// Exponential backoff to reduce CPU usage for longer-running commands
224				if pollInterval < maxPollInterval {
225					pollInterval = time.Duration(float64(pollInterval) * 1.5)
226					if pollInterval > maxPollInterval {
227						pollInterval = maxPollInterval
228					}
229					ticker.Reset(pollInterval)
230				}
231			}
232		}
233	}()
234
235	<-done
236
237	stdout := readFileOrEmpty(stdoutFile)
238	stderr := readFileOrEmpty(stderrFile)
239	exitCodeStr := readFileOrEmpty(statusFile)
240	newCwd := readFileOrEmpty(cwdFile)
241
242	exitCode := 0
243	if exitCodeStr != "" {
244		fmt.Sscanf(exitCodeStr, "%d", &exitCode)
245	} else if interrupted {
246		exitCode = 143
247		stderr += "\nCommand execution timed out or was interrupted"
248	}
249
250	if newCwd != "" {
251		s.cwd = strings.TrimSpace(newCwd)
252	}
253
254	return commandResult{
255		stdout:      stdout,
256		stderr:      stderr,
257		exitCode:    exitCode,
258		interrupted: interrupted,
259	}
260}
261
262func (s *PersistentShell) killChildren() {
263	if s.cmd == nil || s.cmd.Process == nil {
264		return
265	}
266
267	pgrepCmd := exec.Command("pgrep", "-P", fmt.Sprintf("%d", s.cmd.Process.Pid))
268	output, err := pgrepCmd.Output()
269	if err != nil {
270		return
271	}
272
273	for pidStr := range strings.SplitSeq(string(output), "\n") {
274		if pidStr = strings.TrimSpace(pidStr); pidStr != "" {
275			var pid int
276			fmt.Sscanf(pidStr, "%d", &pid)
277			if pid > 0 {
278				proc, err := os.FindProcess(pid)
279				if err == nil {
280					proc.Signal(syscall.SIGTERM)
281				}
282			}
283		}
284	}
285}
286
287func (s *PersistentShell) Exec(ctx context.Context, command string, timeoutMs int) (string, string, int, bool, error) {
288	if !s.isAlive {
289		return "", "Shell is not alive", 1, false, errors.New("shell is not alive")
290	}
291
292	timeout := time.Duration(timeoutMs) * time.Millisecond
293
294	resultChan := make(chan commandResult)
295	s.commandQueue <- &commandExecution{
296		command:    command,
297		timeout:    timeout,
298		resultChan: resultChan,
299		ctx:        ctx,
300	}
301
302	result := <-resultChan
303	return result.stdout, result.stderr, result.exitCode, result.interrupted, result.err
304}
305
306func (s *PersistentShell) Close() {
307	s.mu.Lock()
308	defer s.mu.Unlock()
309
310	if !s.isAlive {
311		return
312	}
313
314	s.stdin.Write([]byte("exit\n"))
315
316	s.cmd.Process.Kill()
317	s.isAlive = false
318}
319
// shellQuote wraps s in single quotes for a POSIX shell. Each embedded single
// quote is emitted as '\'' (close the quote, add an escaped quote, reopen), so
// the shell reads the result back as exactly s.
func shellQuote(s string) string {
	var b strings.Builder
	b.WriteByte('\'')
	b.WriteString(strings.ReplaceAll(s, `'`, `'\''`))
	b.WriteByte('\'')
	return b.String()
}
323
// readFileOrEmpty returns the contents of path as a string, or "" when the
// file cannot be read for any reason (missing, a directory, permissions, ...).
func readFileOrEmpty(path string) string {
	if data, err := os.ReadFile(path); err == nil {
		return string(data)
	}
	return ""
}
331
// fileExists reports whether os.Stat succeeds for path. Any Stat error, not
// only "does not exist", yields false; directories count as existing.
func fileExists(path string) bool {
	if _, err := os.Stat(path); err != nil {
		return false
	}
	return true
}
336
// fileSize returns the size in bytes of path as reported by os.Stat, or 0 when
// path cannot be stat'ed. A missing file is therefore indistinguishable from
// an empty one.
func fileSize(path string) int64 {
	info, err := os.Stat(path)
	if err == nil {
		return info.Size()
	}
	return 0
}