1package shell
2
3import (
4 "cmp"
5 "context"
6 "errors"
7 "fmt"
8 "io"
9 "os"
10 "os/exec"
11 "path/filepath"
12 "strings"
13 "sync"
14 "syscall"
15 "time"
16
17 "github.com/charmbracelet/crush/internal/config"
18 "github.com/charmbracelet/crush/internal/logging"
19 "github.com/shirou/gopsutil/v4/process"
20)
21
// PersistentShell is a single long-running shell process that commands are
// written to over stdin, so state such as the working directory carries over
// from one command to the next.
type PersistentShell struct {
	// cmd is the running shell process and stdin is the pipe that command
	// scripts are written to.
	cmd   *exec.Cmd
	stdin io.WriteCloser

	// isAlive is cleared once the shell has exited or been closed.
	isAlive bool
	// cwd is the shell's last known working directory, refreshed after every
	// command that reports one.
	cwd string

	// mu is held for the whole of execCommand and Close, so commands run one
	// at a time.
	mu sync.Mutex
	// commandQueue feeds requests to the processCommands goroutine; it is
	// closed when the shell process exits.
	commandQueue chan *commandExecution
}

// commandExecution is one request queued by Exec: what to run, how long it may
// take, the context that can cancel it, and where to deliver the result.
type commandExecution struct {
	command    string
	timeout    time.Duration
	resultChan chan commandResult
	ctx        context.Context
}

// commandResult is the outcome of running one command in the shell.
type commandResult struct {
	stdout, stderr string // captured output of the command
	exitCode       int    // exit status reported by the shell
	interrupted    bool   // stopped by timeout or context cancellation
	err            error  // set when the command could not be run at all
}
45
46var shellInstance *PersistentShell
47
48func GetPersistentShell(workingDir string) *PersistentShell {
49 if shellInstance == nil {
50 shellInstance = newPersistentShell(workingDir)
51 }
52 if !shellInstance.isAlive {
53 shellInstance = newPersistentShell(shellInstance.cwd)
54 }
55 return shellInstance
56}
57
58func newPersistentShell(cwd string) *PersistentShell {
59 // Get shell configuration from config
60 cfg := config.Get()
61
62 // Default to environment variable if config is not set or nil
63 var shellPath string
64 var shellArgs []string
65
66 if cfg != nil {
67 shellPath = cfg.Shell.Path
68 shellArgs = cfg.Shell.Args
69 }
70
71 shellPath = cmp.Or(shellPath, os.Getenv("SHELL"), "/bin/bash")
72 if !strings.HasSuffix(shellPath, "bash") && !strings.HasSuffix(shellPath, "zsh") {
73 logging.Warn("only bash and zsh are supported at this time", "shell", shellPath)
74 shellPath = "/bin/bash"
75 }
76
77 // Default shell args
78 if len(shellArgs) == 0 {
79 shellArgs = []string{"--login"}
80 }
81
82 cmd := exec.Command(shellPath, shellArgs...)
83 cmd.Dir = cwd
84
85 stdinPipe, err := cmd.StdinPipe()
86 if err != nil {
87 return nil
88 }
89
90 cmd.Env = append(os.Environ(), "GIT_EDITOR=true")
91
92 err = cmd.Start()
93 if err != nil {
94 return nil
95 }
96
97 shell := &PersistentShell{
98 cmd: cmd,
99 stdin: stdinPipe,
100 isAlive: true,
101 cwd: cwd,
102 commandQueue: make(chan *commandExecution, 10),
103 }
104
105 go func() {
106 defer func() {
107 if r := recover(); r != nil {
108 fmt.Fprintf(os.Stderr, "Panic in shell command processor: %v\n", r)
109 shell.isAlive = false
110 close(shell.commandQueue)
111 }
112 }()
113 shell.processCommands()
114 }()
115
116 go func() {
117 err := cmd.Wait()
118 if err != nil {
119 // Log the error if needed
120 }
121 shell.isAlive = false
122 close(shell.commandQueue)
123 }()
124
125 return shell
126}
127
128func (s *PersistentShell) processCommands() {
129 for cmd := range s.commandQueue {
130 cmd.resultChan <- s.execCommand(cmd.ctx, cmd.command, cmd.timeout)
131 }
132}
133
// runBashCommandFormat is the script written to the shell for each command.
// Its arguments are, in order: the command, the stdout file, the stderr file,
// the exit-status file and the working-directory file.
//
// The command sits inside a { ...; } group, on its own line, so that:
//   - the redirections apply to every command in a list (`a && b`, `a; b`),
//     not just the last one;
//   - a trailing `# comment` in the command cannot swallow the redirections;
//   - a heredoc's terminator line is left untouched and still ends the heredoc.
//
// The leading `:` keeps the group non-empty when the command is empty or only
// a comment. An empty group would be a syntax error and kill the shell.
//
// The exit status is saved in a variable and written to its file last, after
// pwd. Callers that wait for the status file therefore also see the cwd file.
const runBashCommandFormat = `{ :
%[1]s
} </dev/null >%[2]q 2>%[3]q
__crush_exit=$?
pwd >%[5]q
echo $__crush_exit >%[4]q`
137
138func (s *PersistentShell) execCommand(ctx context.Context, command string, timeout time.Duration) commandResult {
139 s.mu.Lock()
140 defer s.mu.Unlock()
141
142 if !s.isAlive {
143 return commandResult{
144 stderr: "Shell is not alive",
145 exitCode: 1,
146 err: errors.New("shell is not alive"),
147 }
148 }
149
150 tmp := os.TempDir()
151 now := time.Now().UnixNano()
152 stdoutFile := filepath.Join(tmp, fmt.Sprintf("crush-stdout-%d", now))
153 stderrFile := filepath.Join(tmp, fmt.Sprintf("crush-stderr-%d", now))
154 statusFile := filepath.Join(tmp, fmt.Sprintf("crush-status-%d", now))
155 cwdFile := filepath.Join(tmp, fmt.Sprintf("crush-cwd-%d", now))
156
157 defer func() {
158 _ = os.Remove(stdoutFile)
159 _ = os.Remove(stderrFile)
160 _ = os.Remove(statusFile)
161 _ = os.Remove(cwdFile)
162 }()
163
164 script := fmt.Sprintf(runBashCommandFormat, command, stdoutFile, stderrFile, statusFile, cwdFile)
165 if _, err := s.stdin.Write([]byte(script + "\n")); err != nil {
166 return commandResult{
167 stderr: fmt.Sprintf("Failed to write command to shell: %v", err),
168 exitCode: 1,
169 err: err,
170 }
171 }
172
173 interrupted := false
174 done := make(chan bool)
175 go func() {
176 // Use exponential backoff polling
177 pollInterval := 10 * time.Millisecond
178 maxPollInterval := time.Second
179
180 ticker := time.NewTicker(pollInterval)
181 defer ticker.Stop()
182
183 timeoutTicker := time.NewTicker(cmp.Or(timeout, time.Hour*99999))
184 defer timeoutTicker.Stop()
185
186 for {
187 select {
188 case <-ctx.Done():
189 s.killChildren()
190 interrupted = true
191 done <- true
192 return
193
194 case <-timeoutTicker.C:
195 s.killChildren()
196 interrupted = true
197 done <- true
198 return
199
200 case <-ticker.C:
201 if fileSize(statusFile) > 0 {
202 done <- true
203 return
204 }
205
206 // Exponential backoff to reduce CPU usage for longer-running commands
207 if pollInterval < maxPollInterval {
208 pollInterval = min(time.Duration(float64(pollInterval)*1.5), maxPollInterval)
209 ticker.Reset(pollInterval)
210 }
211 }
212 }
213 }()
214
215 <-done
216
217 stdout := readFileOrEmpty(stdoutFile)
218 stderr := readFileOrEmpty(stderrFile)
219 exitCodeStr := readFileOrEmpty(statusFile)
220 newCwd := readFileOrEmpty(cwdFile)
221
222 exitCode := 0
223 if exitCodeStr != "" {
224 fmt.Sscanf(exitCodeStr, "%d", &exitCode)
225 } else if interrupted {
226 exitCode = 143
227 stderr += "\nCommand execution timed out or was interrupted"
228 }
229
230 if newCwd != "" {
231 s.cwd = strings.TrimSpace(newCwd)
232 }
233
234 return commandResult{
235 stdout: stdout,
236 stderr: stderr,
237 exitCode: exitCode,
238 interrupted: interrupted,
239 }
240}
241
242func (s *PersistentShell) killChildren() {
243 if s.cmd == nil || s.cmd.Process == nil {
244 return
245 }
246 p, err := process.NewProcess(int32(s.cmd.Process.Pid))
247 if err != nil {
248 logging.WarnPersist("could not kill persistent shell child processes", "err", err)
249 return
250 }
251
252 children, err := p.Children()
253 if err != nil {
254 logging.WarnPersist("could not kill persistent shell child processes", "err", err)
255 return
256 }
257
258 for _, child := range children {
259 if err := child.SendSignal(syscall.SIGTERM); err != nil {
260 logging.WarnPersist("could not kill persistent shell child processes", "err", err, "pid", child.Pid)
261 }
262 }
263}
264
265func (s *PersistentShell) Exec(ctx context.Context, command string, timeoutMs int) (string, string, int, bool, error) {
266 if !s.isAlive {
267 return "", "Shell is not alive", 1, false, errors.New("shell is not alive")
268 }
269
270 resultChan := make(chan commandResult)
271 s.commandQueue <- &commandExecution{
272 command: command,
273 timeout: time.Duration(timeoutMs) * time.Millisecond,
274 resultChan: resultChan,
275 ctx: ctx,
276 }
277
278 result := <-resultChan
279 return result.stdout, result.stderr, result.exitCode, result.interrupted, result.err
280}
281
282func (s *PersistentShell) Close() {
283 s.mu.Lock()
284 defer s.mu.Unlock()
285
286 if !s.isAlive {
287 return
288 }
289
290 s.stdin.Write([]byte("exit\n"))
291
292 if err := s.cmd.Process.Kill(); err != nil {
293 logging.WarnPersist("could not kill persistent shell", "err", err)
294 }
295 s.isAlive = false
296}
297
// readFileOrEmpty returns the whole contents of the file at path as a string.
// Any read error (missing file, permission denied, directory, ...) yields "".
func readFileOrEmpty(path string) string {
	if data, err := os.ReadFile(path); err == nil {
		return string(data)
	}
	return ""
}
305
// fileSize reports the size in bytes of the file at path. If the file cannot
// be stat'ed (for example, it does not exist yet), it returns 0.
func fileSize(path string) int64 {
	var size int64
	if info, err := os.Stat(path); err == nil {
		size = info.Size()
	}
	return size
}