shell.go

  1package shell
  2
  3import (
  4	"cmp"
  5	"context"
  6	"errors"
  7	"fmt"
  8	"io"
  9	"os"
 10	"os/exec"
 11	"path/filepath"
 12	"strings"
 13	"sync"
 14	"syscall"
 15	"time"
 16
 17	"github.com/charmbracelet/crush/internal/config"
 18	"github.com/charmbracelet/crush/internal/logging"
 19	"github.com/shirou/gopsutil/v4/process"
 20)
 21
 22type PersistentShell struct {
 23	cmd          *exec.Cmd
 24	stdin        io.WriteCloser
 25	isAlive      bool
 26	cwd          string
 27	mu           sync.Mutex
 28	commandQueue chan *commandExecution
 29}
 30
 31type commandExecution struct {
 32	command    string
 33	timeout    time.Duration
 34	resultChan chan commandResult
 35	ctx        context.Context
 36}
 37
 38type commandResult struct {
 39	stdout      string
 40	stderr      string
 41	exitCode    int
 42	interrupted bool
 43	err         error
 44}
 45
 46var shellInstance *PersistentShell
 47
 48func GetPersistentShell(workingDir string) *PersistentShell {
 49	if shellInstance == nil {
 50		shellInstance = newPersistentShell(workingDir)
 51	}
 52	if !shellInstance.isAlive {
 53		shellInstance = newPersistentShell(shellInstance.cwd)
 54	}
 55	return shellInstance
 56}
 57
 58func newPersistentShell(cwd string) *PersistentShell {
 59	// Get shell configuration from config
 60	cfg := config.Get()
 61
 62	// Default to environment variable if config is not set or nil
 63	var shellPath string
 64	var shellArgs []string
 65
 66	if cfg != nil {
 67		shellPath = cfg.Shell.Path
 68		shellArgs = cfg.Shell.Args
 69	}
 70
 71	shellPath = cmp.Or(shellPath, os.Getenv("SHELL"), "/bin/bash")
 72	if !strings.HasSuffix(shellPath, "bash") && !strings.HasSuffix(shellPath, "zsh") {
 73		logging.Warn("only bash and zsh are supported at this time", "shell", shellPath)
 74		shellPath = "/bin/bash"
 75	}
 76
 77	// Default shell args
 78	if len(shellArgs) == 0 {
 79		shellArgs = []string{"--login"}
 80	}
 81
 82	cmd := exec.Command(shellPath, shellArgs...)
 83	cmd.Dir = cwd
 84
 85	stdinPipe, err := cmd.StdinPipe()
 86	if err != nil {
 87		return nil
 88	}
 89
 90	cmd.Env = append(os.Environ(), "GIT_EDITOR=true")
 91
 92	err = cmd.Start()
 93	if err != nil {
 94		return nil
 95	}
 96
 97	shell := &PersistentShell{
 98		cmd:          cmd,
 99		stdin:        stdinPipe,
100		isAlive:      true,
101		cwd:          cwd,
102		commandQueue: make(chan *commandExecution, 10),
103	}
104
105	go func() {
106		defer func() {
107			if r := recover(); r != nil {
108				fmt.Fprintf(os.Stderr, "Panic in shell command processor: %v\n", r)
109				shell.isAlive = false
110				close(shell.commandQueue)
111			}
112		}()
113		shell.processCommands()
114	}()
115
116	go func() {
117		err := cmd.Wait()
118		if err != nil {
119			// Log the error if needed
120		}
121		shell.isAlive = false
122		close(shell.commandQueue)
123	}()
124
125	return shell
126}
127
128func (s *PersistentShell) processCommands() {
129	for cmd := range s.commandQueue {
130		cmd.resultChan <- s.execCommand(cmd.ctx, cmd.command, cmd.timeout)
131	}
132}
133
// runBashCommandFormat wraps a command for the persistent shell. Its verbs
// take, in order: the command, the stdout file, the stderr file, the status
// file and the cwd file.
//
// The command reads stdin from /dev/null, and its output goes to the two
// files. Its exit code is saved in a shell variable, then the working
// directory is written. The status file is written last: callers treat a
// non-empty status file as "command finished", so the cwd file must already
// be complete by then.
const runBashCommandFormat = `%[1]s </dev/null >%[2]q 2>%[3]q
__crush_exit_code=$?
pwd >%[5]q
echo $__crush_exit_code >%[4]q`
137
138func (s *PersistentShell) execCommand(ctx context.Context, command string, timeout time.Duration) commandResult {
139	s.mu.Lock()
140	defer s.mu.Unlock()
141
142	if !s.isAlive {
143		return commandResult{
144			stderr:   "Shell is not alive",
145			exitCode: 1,
146			err:      errors.New("shell is not alive"),
147		}
148	}
149
150	tmp := os.TempDir()
151	now := time.Now().UnixNano()
152	stdoutFile := filepath.Join(tmp, fmt.Sprintf("crush-stdout-%d", now))
153	stderrFile := filepath.Join(tmp, fmt.Sprintf("crush-stderr-%d", now))
154	statusFile := filepath.Join(tmp, fmt.Sprintf("crush-status-%d", now))
155	cwdFile := filepath.Join(tmp, fmt.Sprintf("crush-cwd-%d", now))
156
157	defer func() {
158		_ = os.Remove(stdoutFile)
159		_ = os.Remove(stderrFile)
160		_ = os.Remove(statusFile)
161		_ = os.Remove(cwdFile)
162	}()
163
164	script := fmt.Sprintf(runBashCommandFormat, command, stdoutFile, stderrFile, statusFile, cwdFile)
165	if _, err := s.stdin.Write([]byte(script + "\n")); err != nil {
166		return commandResult{
167			stderr:   fmt.Sprintf("Failed to write command to shell: %v", err),
168			exitCode: 1,
169			err:      err,
170		}
171	}
172
173	interrupted := false
174	done := make(chan bool)
175	go func() {
176		// Use exponential backoff polling
177		pollInterval := 10 * time.Millisecond
178		maxPollInterval := time.Second
179
180		ticker := time.NewTicker(pollInterval)
181		defer ticker.Stop()
182
183		timeoutTicker := time.NewTicker(cmp.Or(timeout, time.Hour*99999))
184		defer timeoutTicker.Stop()
185
186		for {
187			select {
188			case <-ctx.Done():
189				s.killChildren()
190				interrupted = true
191				done <- true
192				return
193
194			case <-timeoutTicker.C:
195				s.killChildren()
196				interrupted = true
197				done <- true
198				return
199
200			case <-ticker.C:
201				if fileSize(statusFile) > 0 {
202					done <- true
203					return
204				}
205
206				// Exponential backoff to reduce CPU usage for longer-running commands
207				if pollInterval < maxPollInterval {
208					pollInterval = min(time.Duration(float64(pollInterval)*1.5), maxPollInterval)
209					ticker.Reset(pollInterval)
210				}
211			}
212		}
213	}()
214
215	<-done
216
217	stdout := readFileOrEmpty(stdoutFile)
218	stderr := readFileOrEmpty(stderrFile)
219	exitCodeStr := readFileOrEmpty(statusFile)
220	newCwd := readFileOrEmpty(cwdFile)
221
222	exitCode := 0
223	if exitCodeStr != "" {
224		fmt.Sscanf(exitCodeStr, "%d", &exitCode)
225	} else if interrupted {
226		exitCode = 143
227		stderr += "\nCommand execution timed out or was interrupted"
228	}
229
230	if newCwd != "" {
231		s.cwd = strings.TrimSpace(newCwd)
232	}
233
234	return commandResult{
235		stdout:      stdout,
236		stderr:      stderr,
237		exitCode:    exitCode,
238		interrupted: interrupted,
239	}
240}
241
242func (s *PersistentShell) killChildren() {
243	if s.cmd == nil || s.cmd.Process == nil {
244		return
245	}
246	p, err := process.NewProcess(int32(s.cmd.Process.Pid))
247	if err != nil {
248		logging.WarnPersist("could not kill persistent shell child processes", "err", err)
249		return
250	}
251
252	children, err := p.Children()
253	if err != nil {
254		logging.WarnPersist("could not kill persistent shell child processes", "err", err)
255		return
256	}
257
258	for _, child := range children {
259		if err := child.SendSignal(syscall.SIGTERM); err != nil {
260			logging.WarnPersist("could not kill persistent shell child processes", "err", err, "pid", child.Pid)
261		}
262	}
263}
264
265func (s *PersistentShell) Exec(ctx context.Context, command string, timeoutMs int) (string, string, int, bool, error) {
266	if !s.isAlive {
267		return "", "Shell is not alive", 1, false, errors.New("shell is not alive")
268	}
269
270	resultChan := make(chan commandResult)
271	s.commandQueue <- &commandExecution{
272		command:    command,
273		timeout:    time.Duration(timeoutMs) * time.Millisecond,
274		resultChan: resultChan,
275		ctx:        ctx,
276	}
277
278	result := <-resultChan
279	return result.stdout, result.stderr, result.exitCode, result.interrupted, result.err
280}
281
282func (s *PersistentShell) Close() {
283	s.mu.Lock()
284	defer s.mu.Unlock()
285
286	if !s.isAlive {
287		return
288	}
289
290	s.stdin.Write([]byte("exit\n"))
291
292	if err := s.cmd.Process.Kill(); err != nil {
293		logging.WarnPersist("could not kill persistent shell", "err", err)
294	}
295	s.isAlive = false
296}
297
// readFileOrEmpty returns the full contents of the file at path, or the empty
// string if it cannot be read for any reason.
func readFileOrEmpty(path string) string {
	if data, err := os.ReadFile(path); err == nil {
		return string(data)
	}
	return ""
}
305
// fileSize reports the size in bytes of the file at path, or 0 if it cannot
// be stat'ed (for example because it does not exist yet).
func fileSize(path string) int64 {
	var size int64
	if fi, err := os.Stat(path); err == nil {
		size = fi.Size()
	}
	return size
}