1package shell
2
3import (
4 "context"
5 "errors"
6 "fmt"
7 "os"
8 "os/exec"
9 "path/filepath"
10 "strings"
11 "sync"
12 "syscall"
13 "time"
14)
15
16type PersistentShell struct {
17 cmd *exec.Cmd
18 stdin *os.File
19 isAlive bool
20 cwd string
21 mu sync.Mutex
22 commandQueue chan *commandExecution
23}
24
25type commandExecution struct {
26 command string
27 timeout time.Duration
28 resultChan chan commandResult
29 ctx context.Context
30}
31
// commandResult is the outcome of running one command in the shell.
type commandResult struct {
	stdout      string // captured standard output
	stderr      string // captured standard error, plus any diagnostic notes
	exitCode    int    // command exit status
	interrupted bool   // true when cancelled or timed out
	err         error  // non-nil when the shell could not run the command
}
39
40var (
41 shellInstance *PersistentShell
42 shellInstanceOnce sync.Once
43)
44
45func GetPersistentShell(workingDir string) *PersistentShell {
46 shellInstanceOnce.Do(func() {
47 shellInstance = newPersistentShell(workingDir)
48 })
49
50 if !shellInstance.isAlive {
51 shellInstance = newPersistentShell(shellInstance.cwd)
52 }
53
54 return shellInstance
55}
56
57func newPersistentShell(cwd string) *PersistentShell {
58 shellPath := os.Getenv("SHELL")
59 if shellPath == "" {
60 shellPath = "/bin/bash"
61 }
62
63 cmd := exec.Command(shellPath, "-l")
64 cmd.Dir = cwd
65
66 stdinPipe, err := cmd.StdinPipe()
67 if err != nil {
68 return nil
69 }
70
71 cmd.Env = append(os.Environ(), "GIT_EDITOR=true")
72
73 err = cmd.Start()
74 if err != nil {
75 return nil
76 }
77
78 shell := &PersistentShell{
79 cmd: cmd,
80 stdin: stdinPipe.(*os.File),
81 isAlive: true,
82 cwd: cwd,
83 commandQueue: make(chan *commandExecution, 10),
84 }
85
86 go shell.processCommands()
87
88 go func() {
89 err := cmd.Wait()
90 if err != nil {
91 }
92 shell.isAlive = false
93 close(shell.commandQueue)
94 }()
95
96 return shell
97}
98
99func (s *PersistentShell) processCommands() {
100 for cmd := range s.commandQueue {
101 result := s.execCommand(cmd.command, cmd.timeout, cmd.ctx)
102 cmd.resultChan <- result
103 }
104}
105
106func (s *PersistentShell) execCommand(command string, timeout time.Duration, ctx context.Context) commandResult {
107 s.mu.Lock()
108 defer s.mu.Unlock()
109
110 if !s.isAlive {
111 return commandResult{
112 stderr: "Shell is not alive",
113 exitCode: 1,
114 err: errors.New("shell is not alive"),
115 }
116 }
117
118 tempDir := os.TempDir()
119 stdoutFile := filepath.Join(tempDir, fmt.Sprintf("termai-stdout-%d", time.Now().UnixNano()))
120 stderrFile := filepath.Join(tempDir, fmt.Sprintf("termai-stderr-%d", time.Now().UnixNano()))
121 statusFile := filepath.Join(tempDir, fmt.Sprintf("termai-status-%d", time.Now().UnixNano()))
122 cwdFile := filepath.Join(tempDir, fmt.Sprintf("termai-cwd-%d", time.Now().UnixNano()))
123
124 defer func() {
125 os.Remove(stdoutFile)
126 os.Remove(stderrFile)
127 os.Remove(statusFile)
128 os.Remove(cwdFile)
129 }()
130
131 fullCommand := fmt.Sprintf(`
132eval %s < /dev/null > %s 2> %s
133EXEC_EXIT_CODE=$?
134pwd > %s
135echo $EXEC_EXIT_CODE > %s
136`,
137 shellQuote(command),
138 shellQuote(stdoutFile),
139 shellQuote(stderrFile),
140 shellQuote(cwdFile),
141 shellQuote(statusFile),
142 )
143
144 _, err := s.stdin.Write([]byte(fullCommand + "\n"))
145 if err != nil {
146 return commandResult{
147 stderr: fmt.Sprintf("Failed to write command to shell: %v", err),
148 exitCode: 1,
149 err: err,
150 }
151 }
152
153 interrupted := false
154
155 startTime := time.Now()
156
157 done := make(chan bool)
158 go func() {
159 for {
160 select {
161 case <-ctx.Done():
162 s.killChildren()
163 interrupted = true
164 done <- true
165 return
166
167 case <-time.After(10 * time.Millisecond):
168 if fileExists(statusFile) && fileSize(statusFile) > 0 {
169 done <- true
170 return
171 }
172
173 if timeout > 0 {
174 elapsed := time.Since(startTime)
175 if elapsed > timeout {
176 s.killChildren()
177 interrupted = true
178 done <- true
179 return
180 }
181 }
182 }
183 }
184 }()
185
186 <-done
187
188 stdout := readFileOrEmpty(stdoutFile)
189 stderr := readFileOrEmpty(stderrFile)
190 exitCodeStr := readFileOrEmpty(statusFile)
191 newCwd := readFileOrEmpty(cwdFile)
192
193 exitCode := 0
194 if exitCodeStr != "" {
195 fmt.Sscanf(exitCodeStr, "%d", &exitCode)
196 } else if interrupted {
197 exitCode = 143
198 stderr += "\nCommand execution timed out or was interrupted"
199 }
200
201 if newCwd != "" {
202 s.cwd = strings.TrimSpace(newCwd)
203 }
204
205 return commandResult{
206 stdout: stdout,
207 stderr: stderr,
208 exitCode: exitCode,
209 interrupted: interrupted,
210 }
211}
212
213func (s *PersistentShell) killChildren() {
214 if s.cmd == nil || s.cmd.Process == nil {
215 return
216 }
217
218 pgrepCmd := exec.Command("pgrep", "-P", fmt.Sprintf("%d", s.cmd.Process.Pid))
219 output, err := pgrepCmd.Output()
220 if err != nil {
221 return
222 }
223
224 for pidStr := range strings.SplitSeq(string(output), "\n") {
225 if pidStr = strings.TrimSpace(pidStr); pidStr != "" {
226 var pid int
227 fmt.Sscanf(pidStr, "%d", &pid)
228 if pid > 0 {
229 proc, err := os.FindProcess(pid)
230 if err == nil {
231 proc.Signal(syscall.SIGTERM)
232 }
233 }
234 }
235 }
236}
237
238func (s *PersistentShell) Exec(ctx context.Context, command string, timeoutMs int) (string, string, int, bool, error) {
239 if !s.isAlive {
240 return "", "Shell is not alive", 1, false, errors.New("shell is not alive")
241 }
242
243 timeout := time.Duration(timeoutMs) * time.Millisecond
244
245 resultChan := make(chan commandResult)
246 s.commandQueue <- &commandExecution{
247 command: command,
248 timeout: timeout,
249 resultChan: resultChan,
250 ctx: ctx,
251 }
252
253 result := <-resultChan
254 return result.stdout, result.stderr, result.exitCode, result.interrupted, result.err
255}
256
257func (s *PersistentShell) Close() {
258 s.mu.Lock()
259 defer s.mu.Unlock()
260
261 if !s.isAlive {
262 return
263 }
264
265 s.stdin.Write([]byte("exit\n"))
266
267 s.cmd.Process.Kill()
268 s.isAlive = false
269}
270
// shellQuote wraps s in single quotes for POSIX shells. Each embedded single
// quote becomes '\'' (close quote, escaped quote, reopen quote).
func shellQuote(s string) string {
	const escapedQuote = `'\''`
	return "'" + strings.Join(strings.Split(s, "'"), escapedQuote) + "'"
}
274
// readFileOrEmpty returns the contents of the file at path. Any read error,
// including a missing file, yields "".
func readFileOrEmpty(path string) string {
	if data, err := os.ReadFile(path); err == nil {
		return string(data)
	}
	return ""
}
282
// fileExists reports whether os.Stat succeeds for path. Any stat error,
// including permission errors, counts as "does not exist".
func fileExists(path string) bool {
	if _, err := os.Stat(path); err != nil {
		return false
	}
	return true
}
287
// fileSize returns the size in bytes reported by os.Stat for path, or 0 when
// the stat fails.
func fileSize(path string) int64 {
	if info, err := os.Stat(path); err == nil {
		return info.Size()
	}
	return 0
}