shell.go

  1package shell
  2
  3import (
  4	"context"
  5	"errors"
  6	"fmt"
  7	"io"
  8	"os"
  9	"os/exec"
 10	"path/filepath"
 11	"strings"
 12	"sync"
 13	"syscall"
 14	"time"
 15
 16	"github.com/charmbracelet/crush/internal/config"
 17	"github.com/charmbracelet/crush/internal/logging"
 18	"github.com/shirou/gopsutil/v4/process"
 19)
 20
 21type PersistentShell struct {
 22	cmd          *exec.Cmd
 23	stdin        io.WriteCloser
 24	isAlive      bool
 25	cwd          string
 26	mu           sync.Mutex
 27	commandQueue chan *commandExecution
 28}
 29
 30type commandExecution struct {
 31	command    string
 32	timeout    time.Duration
 33	resultChan chan commandResult
 34	ctx        context.Context
 35}
 36
 37type commandResult struct {
 38	stdout      string
 39	stderr      string
 40	exitCode    int
 41	interrupted bool
 42	err         error
 43}
 44
 45var shellInstance *PersistentShell
 46
 47func GetPersistentShell(workingDir string) *PersistentShell {
 48	if shellInstance == nil {
 49		shellInstance = newPersistentShell(workingDir)
 50	}
 51	if !shellInstance.isAlive {
 52		shellInstance = newPersistentShell(shellInstance.cwd)
 53	}
 54	return shellInstance
 55}
 56
 57func newPersistentShell(cwd string) *PersistentShell {
 58	// Get shell configuration from config
 59	cfg := config.Get()
 60
 61	// Default to environment variable if config is not set or nil
 62	var shellPath string
 63	var shellArgs []string
 64
 65	if cfg != nil {
 66		shellPath = cfg.Shell.Path
 67		shellArgs = cfg.Shell.Args
 68	}
 69
 70	if shellPath == "" {
 71		shellPath = os.Getenv("SHELL")
 72		if shellPath == "" {
 73			shellPath = "/bin/bash"
 74		}
 75	}
 76
 77	// Default shell args
 78	if len(shellArgs) == 0 {
 79		shellArgs = []string{"-l"}
 80	}
 81
 82	cmd := exec.Command(shellPath, shellArgs...)
 83	cmd.Dir = cwd
 84
 85	stdinPipe, err := cmd.StdinPipe()
 86	if err != nil {
 87		return nil
 88	}
 89
 90	cmd.Env = append(os.Environ(), "GIT_EDITOR=true")
 91
 92	err = cmd.Start()
 93	if err != nil {
 94		return nil
 95	}
 96
 97	shell := &PersistentShell{
 98		cmd:          cmd,
 99		stdin:        stdinPipe,
100		isAlive:      true,
101		cwd:          cwd,
102		commandQueue: make(chan *commandExecution, 10),
103	}
104
105	go func() {
106		defer func() {
107			if r := recover(); r != nil {
108				fmt.Fprintf(os.Stderr, "Panic in shell command processor: %v\n", r)
109				shell.isAlive = false
110				close(shell.commandQueue)
111			}
112		}()
113		shell.processCommands()
114	}()
115
116	go func() {
117		err := cmd.Wait()
118		if err != nil {
119			// Log the error if needed
120		}
121		shell.isAlive = false
122		close(shell.commandQueue)
123	}()
124
125	return shell
126}
127
128func (s *PersistentShell) processCommands() {
129	for cmd := range s.commandQueue {
130		result := s.execCommand(cmd.command, cmd.timeout, cmd.ctx)
131		cmd.resultChan <- result
132	}
133}
134
135func (s *PersistentShell) execCommand(command string, timeout time.Duration, ctx context.Context) commandResult {
136	s.mu.Lock()
137	defer s.mu.Unlock()
138
139	if !s.isAlive {
140		return commandResult{
141			stderr:   "Shell is not alive",
142			exitCode: 1,
143			err:      errors.New("shell is not alive"),
144		}
145	}
146
147	tempDir := os.TempDir()
148	stdoutFile := filepath.Join(tempDir, fmt.Sprintf("crush-stdout-%d", time.Now().UnixNano()))
149	stderrFile := filepath.Join(tempDir, fmt.Sprintf("crush-stderr-%d", time.Now().UnixNano()))
150	statusFile := filepath.Join(tempDir, fmt.Sprintf("crush-status-%d", time.Now().UnixNano()))
151	cwdFile := filepath.Join(tempDir, fmt.Sprintf("crush-cwd-%d", time.Now().UnixNano()))
152
153	defer func() {
154		os.Remove(stdoutFile)
155		os.Remove(stderrFile)
156		os.Remove(statusFile)
157		os.Remove(cwdFile)
158	}()
159
160	fullCommand := fmt.Sprintf(`
161eval %s < /dev/null > %s 2> %s
162EXEC_EXIT_CODE=$?
163pwd > %s
164echo $EXEC_EXIT_CODE > %s
165`,
166		shellQuote(command),
167		shellQuote(stdoutFile),
168		shellQuote(stderrFile),
169		shellQuote(cwdFile),
170		shellQuote(statusFile),
171	)
172
173	_, err := s.stdin.Write([]byte(fullCommand + "\n"))
174	if err != nil {
175		return commandResult{
176			stderr:   fmt.Sprintf("Failed to write command to shell: %v", err),
177			exitCode: 1,
178			err:      err,
179		}
180	}
181
182	interrupted := false
183
184	startTime := time.Now()
185
186	done := make(chan bool)
187	go func() {
188		// Use exponential backoff polling
189		pollInterval := 1 * time.Millisecond
190		maxPollInterval := 100 * time.Millisecond
191
192		ticker := time.NewTicker(pollInterval)
193		defer ticker.Stop()
194
195		for {
196			select {
197			case <-ctx.Done():
198				s.killChildren()
199				interrupted = true
200				done <- true
201				return
202
203			case <-ticker.C:
204				if fileExists(statusFile) && fileSize(statusFile) > 0 {
205					done <- true
206					return
207				}
208
209				if timeout > 0 {
210					elapsed := time.Since(startTime)
211					if elapsed > timeout {
212						s.killChildren()
213						interrupted = true
214						done <- true
215						return
216					}
217				}
218
219				// Exponential backoff to reduce CPU usage for longer-running commands
220				if pollInterval < maxPollInterval {
221					pollInterval = min(time.Duration(float64(pollInterval)*1.5), maxPollInterval)
222					ticker.Reset(pollInterval)
223				}
224			}
225		}
226	}()
227
228	<-done
229
230	stdout := readFileOrEmpty(stdoutFile)
231	stderr := readFileOrEmpty(stderrFile)
232	exitCodeStr := readFileOrEmpty(statusFile)
233	newCwd := readFileOrEmpty(cwdFile)
234
235	exitCode := 0
236	if exitCodeStr != "" {
237		fmt.Sscanf(exitCodeStr, "%d", &exitCode)
238	} else if interrupted {
239		exitCode = 143
240		stderr += "\nCommand execution timed out or was interrupted"
241	}
242
243	if newCwd != "" {
244		s.cwd = strings.TrimSpace(newCwd)
245	}
246
247	return commandResult{
248		stdout:      stdout,
249		stderr:      stderr,
250		exitCode:    exitCode,
251		interrupted: interrupted,
252	}
253}
254
255func (s *PersistentShell) killChildren() {
256	if s.cmd == nil || s.cmd.Process == nil {
257		return
258	}
259	p, err := process.NewProcess(int32(s.cmd.Process.Pid))
260	if err != nil {
261		logging.WarnPersist("could not kill persistent shell child processes", "err", err)
262		return
263	}
264
265	children, err := p.Children()
266	if err != nil {
267		logging.WarnPersist("could not kill persistent shell child processes", "err", err)
268		return
269	}
270
271	for _, child := range children {
272		if err := child.SendSignal(syscall.SIGTERM); err != nil {
273			logging.WarnPersist("could not kill persistent shell child processes", "err", err, "pid", child.Pid)
274		}
275	}
276}
277
278func (s *PersistentShell) Exec(ctx context.Context, command string, timeoutMs int) (string, string, int, bool, error) {
279	if !s.isAlive {
280		return "", "Shell is not alive", 1, false, errors.New("shell is not alive")
281	}
282
283	timeout := time.Duration(timeoutMs) * time.Millisecond
284
285	resultChan := make(chan commandResult)
286	s.commandQueue <- &commandExecution{
287		command:    command,
288		timeout:    timeout,
289		resultChan: resultChan,
290		ctx:        ctx,
291	}
292
293	result := <-resultChan
294	return result.stdout, result.stderr, result.exitCode, result.interrupted, result.err
295}
296
297func (s *PersistentShell) Close() {
298	s.mu.Lock()
299	defer s.mu.Unlock()
300
301	if !s.isAlive {
302		return
303	}
304
305	s.stdin.Write([]byte("exit\n"))
306
307	if err := s.cmd.Process.Kill(); err != nil {
308		logging.WarnPersist("could not kill persistent shell", "err", err)
309	}
310	s.isAlive = false
311}
312
// shellQuote wraps s in single quotes for a POSIX shell. Inside single quotes
// every character is literal except the quote itself, so each embedded ' is
// emitted as '\'': close the quote, add an escaped quote, reopen the quote.
func shellQuote(s string) string {
	const escapedQuote = `'\''`
	return "'" + strings.Join(strings.Split(s, "'"), escapedQuote) + "'"
}
316
// readFileOrEmpty returns the contents of the file at path, or "" if the file
// cannot be read for any reason (missing, a directory, permission denied).
func readFileOrEmpty(path string) string {
	if data, err := os.ReadFile(path); err == nil {
		return string(data)
	}
	return ""
}
324
// fileExists reports whether path can be stat'ed. Any stat error, not only
// "does not exist", makes it report false.
func fileExists(path string) bool {
	if _, err := os.Stat(path); err != nil {
		return false
	}
	return true
}
329
// fileSize returns the size in bytes reported by os.Stat for path, or 0 when
// path cannot be stat'ed.
func fileSize(path string) (size int64) {
	if info, err := os.Stat(path); err == nil {
		size = info.Size()
	}
	return size
}