1package shell
2
3import (
4 "context"
5 "errors"
6 "fmt"
7 "io"
8 "os"
9 "os/exec"
10 "path/filepath"
11 "strings"
12 "sync"
13 "syscall"
14 "time"
15
16 "github.com/charmbracelet/crush/internal/config"
17 "github.com/charmbracelet/crush/internal/logging"
18 "github.com/shirou/gopsutil/v4/process"
19)
20
// PersistentShell is a single long-lived shell process. Commands are written
// to its stdin one at a time, so the working directory reached by one command
// carries over to the next.
type PersistentShell struct {
	// cmd is the running shell process.
	cmd *exec.Cmd
	// stdin is the pipe to the shell's standard input; execCommand writes a
	// wrapper script per command into it.
	stdin io.WriteCloser
	// isAlive becomes false once the process has exited, Close was called, or
	// the command processor panicked.
	// NOTE(review): it is written by goroutines started in newPersistentShell
	// and read in Exec without holding mu — consider atomic.Bool.
	isAlive bool
	// cwd is the working directory the shell reported after the last command;
	// GetPersistentShell uses it to restart a dead shell in the same place.
	cwd string
	// mu serializes execCommand and Close so commands never interleave.
	mu sync.Mutex
	// commandQueue carries pending Exec requests to processCommands. It is
	// closed when the shell dies.
	commandQueue chan *commandExecution
}
29
// commandExecution is one queued Exec request.
type commandExecution struct {
	// command is the shell source to run (passed to eval).
	command string
	// timeout bounds the run; a value <= 0 means no timeout.
	timeout time.Duration
	// resultChan receives exactly one result from processCommands.
	resultChan chan commandResult
	// ctx cancellation makes execCommand SIGTERM the shell's children and
	// report the command as interrupted.
	ctx context.Context
}
36
// commandResult is the outcome of one command run by execCommand.
type commandResult struct {
	// stdout is the command's captured standard output.
	stdout string
	// stderr is the command's captured standard error, possibly followed by an
	// explanatory message added by execCommand.
	stderr string
	// exitCode is the command's exit status; execCommand uses 143 when the
	// command was interrupted before a status was recorded and 1 for failures
	// of the shell machinery.
	exitCode int
	// interrupted is true when the command hit its timeout or its ctx was
	// cancelled.
	interrupted bool
	// err is set only for failures of the shell itself (e.g. the shell is not
	// alive or stdin could not be written), not for non-zero exit codes.
	err error
}
44
45var shellInstance *PersistentShell
46
47func GetPersistentShell(workingDir string) *PersistentShell {
48 if shellInstance == nil {
49 shellInstance = newPersistentShell(workingDir)
50 }
51 if !shellInstance.isAlive {
52 shellInstance = newPersistentShell(shellInstance.cwd)
53 }
54 return shellInstance
55}
56
// newPersistentShell starts a long-lived shell process in cwd, plus the two
// goroutines that feed it queued commands and watch for its exit.
//
// The shell binary and arguments come from the crush config (Shell.Path and
// Shell.Args). An empty path falls back to $SHELL and then to /bin/bash. Empty
// args fall back to "-l". GIT_EDITOR=true is appended to the inherited
// environment (presumably so git never blocks waiting on an interactive
// editor).
//
// It logs a warning and returns nil if the stdin pipe cannot be created or the
// process cannot be started.
func newPersistentShell(cwd string) *PersistentShell {
	cfg := config.Get()

	var shellPath string
	var shellArgs []string
	if cfg != nil {
		shellPath = cfg.Shell.Path
		shellArgs = cfg.Shell.Args
	}
	if shellPath == "" {
		shellPath = os.Getenv("SHELL")
	}
	if shellPath == "" {
		shellPath = "/bin/bash"
	}
	if len(shellArgs) == 0 {
		shellArgs = []string{"-l"}
	}

	cmd := exec.Command(shellPath, shellArgs...)
	cmd.Dir = cwd
	cmd.Env = append(os.Environ(), "GIT_EDITOR=true")

	stdinPipe, err := cmd.StdinPipe()
	if err != nil {
		logging.WarnPersist("could not create persistent shell stdin pipe", "err", err)
		return nil
	}
	if err := cmd.Start(); err != nil {
		logging.WarnPersist("could not start persistent shell", "err", err, "shell", shellPath)
		return nil
	}

	shell := &PersistentShell{
		cmd:          cmd,
		stdin:        stdinPipe,
		isAlive:      true,
		cwd:          cwd,
		commandQueue: make(chan *commandExecution, 10),
	}

	// Both goroutines below can mark the shell dead. Closing the queue twice
	// panics, so the close is guarded by a Once shared between them.
	var closeOnce sync.Once
	markDead := func() {
		shell.isAlive = false
		closeOnce.Do(func() { close(shell.commandQueue) })
	}

	go func() {
		defer func() {
			if r := recover(); r != nil {
				logging.WarnPersist("panic in shell command processor", "panic", r)
				markDead()
			}
		}()
		shell.processCommands()
	}()

	go func() {
		// The exit status is deliberately ignored: whether the shell was killed
		// by Close or died on its own, it is no longer usable.
		_ = cmd.Wait()
		markDead()
	}()

	return shell
}
127
128func (s *PersistentShell) processCommands() {
129 for cmd := range s.commandQueue {
130 result := s.execCommand(cmd.command, cmd.timeout, cmd.ctx)
131 cmd.resultChan <- result
132 }
133}
134
135func (s *PersistentShell) execCommand(command string, timeout time.Duration, ctx context.Context) commandResult {
136 s.mu.Lock()
137 defer s.mu.Unlock()
138
139 if !s.isAlive {
140 return commandResult{
141 stderr: "Shell is not alive",
142 exitCode: 1,
143 err: errors.New("shell is not alive"),
144 }
145 }
146
147 tempDir := os.TempDir()
148 stdoutFile := filepath.Join(tempDir, fmt.Sprintf("crush-stdout-%d", time.Now().UnixNano()))
149 stderrFile := filepath.Join(tempDir, fmt.Sprintf("crush-stderr-%d", time.Now().UnixNano()))
150 statusFile := filepath.Join(tempDir, fmt.Sprintf("crush-status-%d", time.Now().UnixNano()))
151 cwdFile := filepath.Join(tempDir, fmt.Sprintf("crush-cwd-%d", time.Now().UnixNano()))
152
153 defer func() {
154 os.Remove(stdoutFile)
155 os.Remove(stderrFile)
156 os.Remove(statusFile)
157 os.Remove(cwdFile)
158 }()
159
160 fullCommand := fmt.Sprintf(`
161eval %s < /dev/null > %s 2> %s
162EXEC_EXIT_CODE=$?
163pwd > %s
164echo $EXEC_EXIT_CODE > %s
165`,
166 shellQuote(command),
167 shellQuote(stdoutFile),
168 shellQuote(stderrFile),
169 shellQuote(cwdFile),
170 shellQuote(statusFile),
171 )
172
173 _, err := s.stdin.Write([]byte(fullCommand + "\n"))
174 if err != nil {
175 return commandResult{
176 stderr: fmt.Sprintf("Failed to write command to shell: %v", err),
177 exitCode: 1,
178 err: err,
179 }
180 }
181
182 interrupted := false
183
184 startTime := time.Now()
185
186 done := make(chan bool)
187 go func() {
188 // Use exponential backoff polling
189 pollInterval := 1 * time.Millisecond
190 maxPollInterval := 100 * time.Millisecond
191
192 ticker := time.NewTicker(pollInterval)
193 defer ticker.Stop()
194
195 for {
196 select {
197 case <-ctx.Done():
198 s.killChildren()
199 interrupted = true
200 done <- true
201 return
202
203 case <-ticker.C:
204 if fileExists(statusFile) && fileSize(statusFile) > 0 {
205 done <- true
206 return
207 }
208
209 if timeout > 0 {
210 elapsed := time.Since(startTime)
211 if elapsed > timeout {
212 s.killChildren()
213 interrupted = true
214 done <- true
215 return
216 }
217 }
218
219 // Exponential backoff to reduce CPU usage for longer-running commands
220 if pollInterval < maxPollInterval {
221 pollInterval = min(time.Duration(float64(pollInterval)*1.5), maxPollInterval)
222 ticker.Reset(pollInterval)
223 }
224 }
225 }
226 }()
227
228 <-done
229
230 stdout := readFileOrEmpty(stdoutFile)
231 stderr := readFileOrEmpty(stderrFile)
232 exitCodeStr := readFileOrEmpty(statusFile)
233 newCwd := readFileOrEmpty(cwdFile)
234
235 exitCode := 0
236 if exitCodeStr != "" {
237 fmt.Sscanf(exitCodeStr, "%d", &exitCode)
238 } else if interrupted {
239 exitCode = 143
240 stderr += "\nCommand execution timed out or was interrupted"
241 }
242
243 if newCwd != "" {
244 s.cwd = strings.TrimSpace(newCwd)
245 }
246
247 return commandResult{
248 stdout: stdout,
249 stderr: stderr,
250 exitCode: exitCode,
251 interrupted: interrupted,
252 }
253}
254
// killChildren sends SIGTERM to every direct child process of the shell. This
// interrupts a running (or timed-out) command without killing the shell
// itself. Grandchildren are not walked. Failures are logged as warnings and
// never returned.
//
// NOTE(review): this presumably also signals background jobs left running by
// earlier commands, since they are children of the same shell — confirm that
// is intended.
func (s *PersistentShell) killChildren() {
	if s.cmd == nil || s.cmd.Process == nil {
		return
	}

	p, err := process.NewProcess(int32(s.cmd.Process.Pid))
	if err != nil {
		logging.WarnPersist("could not kill persistent shell child processes", "err", err)
		return
	}

	children, err := p.Children()
	if err != nil {
		logging.WarnPersist("could not kill persistent shell child processes", "err", err)
		return
	}

	for _, child := range children {
		err := child.SendSignal(syscall.SIGTERM)
		// A child that exited on its own between listing and signalling is
		// the desired outcome, not a failure worth warning about.
		if err != nil && !errors.Is(err, os.ErrProcessDone) {
			logging.WarnPersist("could not kill persistent shell child processes", "err", err, "pid", child.Pid)
		}
	}
}
277
// Exec runs command in the persistent shell and waits for its result.
// timeoutMs <= 0 disables the timeout. It returns the command's stdout and
// stderr, its exit code, whether it was interrupted (timeout or ctx
// cancellation), and an error for failures of the shell itself.
func (s *PersistentShell) Exec(ctx context.Context, command string, timeoutMs int) (string, string, int, bool, error) {
	if !s.isAlive {
		return "", "Shell is not alive", 1, false, errors.New("shell is not alive")
	}

	req := &commandExecution{
		command:    command,
		timeout:    time.Millisecond * time.Duration(timeoutMs),
		resultChan: make(chan commandResult),
		ctx:        ctx,
	}
	s.commandQueue <- req

	res := <-req.resultChan
	return res.stdout, res.stderr, res.exitCode, res.interrupted, res.err
}
296
// Close terminates the shell process and marks the shell dead. It is a no-op
// if the shell is already dead. Because it takes s.mu first, it waits for any
// command currently running in execCommand to finish.
func (s *PersistentShell) Close() {
	s.mu.Lock()
	defer s.mu.Unlock()

	if !s.isAlive {
		return
	}

	// Best effort: ask the shell to exit cleanly. The process is killed right
	// after anyway, so a failed write (e.g. a broken pipe because the shell
	// already exited) is not worth reporting.
	_, _ = s.stdin.Write([]byte("exit\n"))

	// os.ErrProcessDone means the shell had already exited, which is exactly
	// what we want here. Only warn about real failures.
	if err := s.cmd.Process.Kill(); err != nil && !errors.Is(err, os.ErrProcessDone) {
		logging.WarnPersist("could not kill persistent shell", "err", err)
	}
	s.isAlive = false
}
312
313func shellQuote(s string) string {
314 return "'" + strings.ReplaceAll(s, "'", "'\\''") + "'"
315}
316
// readFileOrEmpty returns the full contents of the file at path, or "" if it
// cannot be read for any reason (missing, unreadable, a directory, ...).
func readFileOrEmpty(path string) string {
	if data, err := os.ReadFile(path); err == nil {
		return string(data)
	}
	return ""
}
324
// fileExists reports whether path can be stat'ed. Any error, including a
// permission error, counts as "does not exist".
func fileExists(path string) bool {
	if _, err := os.Stat(path); err != nil {
		return false
	}
	return true
}
329
// fileSize returns the size in bytes of the file at path, or 0 when it cannot
// be stat'ed.
func fileSize(path string) int64 {
	if info, err := os.Stat(path); err == nil {
		return info.Size()
	}
	return 0
}