1package shell
2
3import (
4 "context"
5 "errors"
6 "fmt"
7 "os"
8 "os/exec"
9 "path/filepath"
10 "strings"
11 "sync"
12 "syscall"
13 "time"
14)
15
// PersistentShell wraps one long-running shell process. Commands are fed to
// the process through its standard input and are executed one at a time.
type PersistentShell struct {
	// cmd is the shell process itself.
	cmd *exec.Cmd
	// stdin is the pipe connected to the shell's standard input; command
	// scripts are written to it.
	stdin *os.File
	// isAlive is true from start-up until the process exits, Close is called,
	// or the command processor recovers from a panic.
	isAlive bool
	// cwd is the directory the shell was started in; it is updated from the
	// shell's `pwd` output after a command when that output is available.
	cwd string
	// mu is held by execCommand and Close for their whole duration.
	mu sync.Mutex
	// commandQueue carries pending commands to processCommands.
	commandQueue chan *commandExecution
}
24
// commandExecution is one queued request to run a command in the shell.
type commandExecution struct {
	// command is the shell source to evaluate.
	command string
	// timeout bounds the run time; the timeout check is skipped unless it is
	// greater than zero.
	timeout time.Duration
	// resultChan receives the outcome once the command has been processed.
	resultChan chan commandResult
	// ctx cancels the command while it is running.
	ctx context.Context
}
31
// commandResult is the outcome of one command run through the shell.
type commandResult struct {
	// stdout is the captured standard output of the command.
	stdout string
	// stderr is the captured standard error, possibly followed by a note
	// appended by this package (for example on interruption).
	stderr string
	// exitCode is the command's exit status, or a synthetic value when the
	// command could not be run or was interrupted before reporting one.
	exitCode int
	// interrupted reports that the command was cancelled or timed out.
	interrupted bool
	// err is set when the command could not be handed to the shell.
	err error
}
39
var (
	// shellInstance is the process-wide shell returned by GetPersistentShell.
	shellInstance *PersistentShell
	// shellInstanceOnce guards the first creation of shellInstance.
	shellInstanceOnce sync.Once
)
44
45func GetPersistentShell(workingDir string) *PersistentShell {
46 shellInstanceOnce.Do(func() {
47 shellInstance = newPersistentShell(workingDir)
48 })
49
50 if shellInstance == nil {
51 shellInstance = newPersistentShell(workingDir)
52 } else if !shellInstance.isAlive {
53 shellInstance = newPersistentShell(shellInstance.cwd)
54 }
55
56 return shellInstance
57}
58
59func newPersistentShell(cwd string) *PersistentShell {
60 // Default to environment variable
61 shellPath := os.Getenv("SHELL")
62 if shellPath == "" {
63 shellPath = "/bin/bash"
64 }
65
66 // Default shell args
67 shellArgs := []string{"-l"}
68
69 cmd := exec.Command(shellPath, shellArgs...)
70 cmd.Dir = cwd
71
72 stdinPipe, err := cmd.StdinPipe()
73 if err != nil {
74 return nil
75 }
76
77 cmd.Env = append(os.Environ(), "GIT_EDITOR=true")
78
79 err = cmd.Start()
80 if err != nil {
81 return nil
82 }
83
84 shell := &PersistentShell{
85 cmd: cmd,
86 stdin: stdinPipe.(*os.File),
87 isAlive: true,
88 cwd: cwd,
89 commandQueue: make(chan *commandExecution, 10),
90 }
91
92 go func() {
93 defer func() {
94 if r := recover(); r != nil {
95 fmt.Fprintf(os.Stderr, "Panic in shell command processor: %v\n", r)
96 shell.isAlive = false
97 close(shell.commandQueue)
98 }
99 }()
100 shell.processCommands()
101 }()
102
103 go func() {
104 err := cmd.Wait()
105 if err != nil {
106 // Log the error if needed
107 }
108 shell.isAlive = false
109 close(shell.commandQueue)
110 }()
111
112 return shell
113}
114
115func (s *PersistentShell) processCommands() {
116 for cmd := range s.commandQueue {
117 result := s.execCommand(cmd.command, cmd.timeout, cmd.ctx)
118 cmd.resultChan <- result
119 }
120}
121
122func (s *PersistentShell) execCommand(command string, timeout time.Duration, ctx context.Context) commandResult {
123 s.mu.Lock()
124 defer s.mu.Unlock()
125
126 if !s.isAlive {
127 return commandResult{
128 stderr: "Shell is not alive",
129 exitCode: 1,
130 err: errors.New("shell is not alive"),
131 }
132 }
133
134 tempDir := os.TempDir()
135 stdoutFile := filepath.Join(tempDir, fmt.Sprintf("crush-stdout-%d", time.Now().UnixNano()))
136 stderrFile := filepath.Join(tempDir, fmt.Sprintf("crush-stderr-%d", time.Now().UnixNano()))
137 statusFile := filepath.Join(tempDir, fmt.Sprintf("crush-status-%d", time.Now().UnixNano()))
138 cwdFile := filepath.Join(tempDir, fmt.Sprintf("crush-cwd-%d", time.Now().UnixNano()))
139
140 defer func() {
141 os.Remove(stdoutFile)
142 os.Remove(stderrFile)
143 os.Remove(statusFile)
144 os.Remove(cwdFile)
145 }()
146
147 fullCommand := fmt.Sprintf(`
148eval %s < /dev/null > %s 2> %s
149EXEC_EXIT_CODE=$?
150pwd > %s
151echo $EXEC_EXIT_CODE > %s
152`,
153 shellQuote(command),
154 shellQuote(stdoutFile),
155 shellQuote(stderrFile),
156 shellQuote(cwdFile),
157 shellQuote(statusFile),
158 )
159
160 _, err := s.stdin.Write([]byte(fullCommand + "\n"))
161 if err != nil {
162 return commandResult{
163 stderr: fmt.Sprintf("Failed to write command to shell: %v", err),
164 exitCode: 1,
165 err: err,
166 }
167 }
168
169 interrupted := false
170
171 startTime := time.Now()
172
173 done := make(chan bool)
174 go func() {
175 // Use exponential backoff polling
176 pollInterval := 1 * time.Millisecond
177 maxPollInterval := 100 * time.Millisecond
178
179 ticker := time.NewTicker(pollInterval)
180 defer ticker.Stop()
181
182 for {
183 select {
184 case <-ctx.Done():
185 s.killChildren()
186 interrupted = true
187 done <- true
188 return
189
190 case <-ticker.C:
191 if fileExists(statusFile) && fileSize(statusFile) > 0 {
192 done <- true
193 return
194 }
195
196 if timeout > 0 {
197 elapsed := time.Since(startTime)
198 if elapsed > timeout {
199 s.killChildren()
200 interrupted = true
201 done <- true
202 return
203 }
204 }
205
206 // Exponential backoff to reduce CPU usage for longer-running commands
207 if pollInterval < maxPollInterval {
208 pollInterval = min(time.Duration(float64(pollInterval)*1.5), maxPollInterval)
209 ticker.Reset(pollInterval)
210 }
211 }
212 }
213 }()
214
215 <-done
216
217 stdout := readFileOrEmpty(stdoutFile)
218 stderr := readFileOrEmpty(stderrFile)
219 exitCodeStr := readFileOrEmpty(statusFile)
220 newCwd := readFileOrEmpty(cwdFile)
221
222 exitCode := 0
223 if exitCodeStr != "" {
224 fmt.Sscanf(exitCodeStr, "%d", &exitCode)
225 } else if interrupted {
226 exitCode = 143
227 stderr += "\nCommand execution timed out or was interrupted"
228 }
229
230 if newCwd != "" {
231 s.cwd = strings.TrimSpace(newCwd)
232 }
233
234 return commandResult{
235 stdout: stdout,
236 stderr: stderr,
237 exitCode: exitCode,
238 interrupted: interrupted,
239 }
240}
241
242func (s *PersistentShell) killChildren() {
243 if s.cmd == nil || s.cmd.Process == nil {
244 return
245 }
246
247 pgrepCmd := exec.Command("pgrep", "-P", fmt.Sprintf("%d", s.cmd.Process.Pid))
248 output, err := pgrepCmd.Output()
249 if err != nil {
250 return
251 }
252
253 for pidStr := range strings.SplitSeq(string(output), "\n") {
254 if pidStr = strings.TrimSpace(pidStr); pidStr != "" {
255 var pid int
256 fmt.Sscanf(pidStr, "%d", &pid)
257 if pid > 0 {
258 proc, err := os.FindProcess(pid)
259 if err == nil {
260 proc.Signal(syscall.SIGTERM)
261 }
262 }
263 }
264 }
265}
266
267func (s *PersistentShell) Exec(ctx context.Context, command string, timeoutMs int) (string, string, int, bool, error) {
268 if !s.isAlive {
269 return "", "Shell is not alive", 1, false, errors.New("shell is not alive")
270 }
271
272 timeout := time.Duration(timeoutMs) * time.Millisecond
273
274 resultChan := make(chan commandResult)
275 s.commandQueue <- &commandExecution{
276 command: command,
277 timeout: timeout,
278 resultChan: resultChan,
279 ctx: ctx,
280 }
281
282 result := <-resultChan
283 return result.stdout, result.stderr, result.exitCode, result.interrupted, result.err
284}
285
286func (s *PersistentShell) Close() {
287 s.mu.Lock()
288 defer s.mu.Unlock()
289
290 if !s.isAlive {
291 return
292 }
293
294 s.stdin.Write([]byte("exit\n"))
295
296 s.cmd.Process.Kill()
297 s.isAlive = false
298}
299
// shellQuote wraps s in single quotes for use as one shell word. Each
// embedded single quote is emitted as '\'' (close the quoted section, add an
// escaped quote, reopen the quoted section).
func shellQuote(s string) string {
	const escapedQuote = `'\''`
	segments := strings.Split(s, "'")
	return "'" + strings.Join(segments, escapedQuote) + "'"
}
303
// readFileOrEmpty returns the full contents of the file at path as a string,
// or the empty string if the file cannot be read.
func readFileOrEmpty(path string) string {
	if data, readErr := os.ReadFile(path); readErr == nil {
		return string(data)
	}
	return ""
}
311
// fileExists reports whether os.Stat succeeds for path. Any stat failure,
// not only "does not exist", yields false.
func fileExists(path string) bool {
	if _, statErr := os.Stat(path); statErr != nil {
		return false
	}
	return true
}
316
// fileSize returns the size in bytes of the file at path, or 0 if the file
// cannot be stat'ed.
func fileSize(path string) int64 {
	if st, statErr := os.Stat(path); statErr == nil {
		return st.Size()
	}
	return 0
}