1package shell
2
3import (
4 "context"
5 "errors"
6 "fmt"
7 "os"
8 "os/exec"
9 "path/filepath"
10 "strings"
11 "sync"
12 "syscall"
13 "time"
14)
15
// PersistentShell wraps a single long-running shell process that commands are
// fed to over its standard input, so shell state such as the working directory
// carries over from one command to the next.
type PersistentShell struct {
	// cmd is the shell process started by newPersistentShell.
	cmd *exec.Cmd
	// stdin is the write end of the pipe connected to the shell's standard input.
	stdin *os.File
	// isAlive is set to true at construction and cleared when the process
	// exits, when the command processor panics, or when Close is called.
	isAlive bool
	// cwd is the shell's last known working directory; execCommand refreshes
	// it from the output of pwd after each command.
	cwd string
	// mu is held by execCommand for the duration of a command and by Close.
	mu sync.Mutex
	// commandQueue carries pending requests to the processCommands goroutine.
	commandQueue chan *commandExecution
}
24
// commandExecution is one queued request for the shell: what to run, how long
// to wait, and where to deliver the outcome.
type commandExecution struct {
	// command is the shell source text to evaluate.
	command string
	// timeout bounds the wait for the command; execCommand only enforces it
	// when it is greater than zero.
	timeout time.Duration
	// resultChan receives exactly one commandResult from processCommands.
	resultChan chan commandResult
	// ctx is the caller's context; its cancellation interrupts the command.
	ctx context.Context
}
31
// commandResult is the outcome of one command run through the shell.
type commandResult struct {
	// stdout is the captured standard output of the command.
	stdout string
	// stderr is the captured standard error, or a diagnostic message when the
	// command could not be run.
	stderr string
	// exitCode is the command's exit status as reported by the shell.
	exitCode int
	// interrupted is true when the command was cut short by a timeout or by
	// cancellation of the caller's context.
	interrupted bool
	// err is non-nil when the command could not be handed to the shell.
	err error
}
39
40var (
41 shellInstance *PersistentShell
42 shellInstanceOnce sync.Once
43)
44
45func GetPersistentShell(workingDir string) *PersistentShell {
46 shellInstanceOnce.Do(func() {
47 shellInstance = newPersistentShell(workingDir)
48 })
49
50 if shellInstance == nil {
51 shellInstance = newPersistentShell(workingDir)
52 } else if !shellInstance.isAlive {
53 shellInstance = newPersistentShell(shellInstance.cwd)
54 }
55
56 return shellInstance
57}
58
59func newPersistentShell(cwd string) *PersistentShell {
60 shellPath := os.Getenv("SHELL")
61 if shellPath == "" {
62 shellPath = "/bin/bash"
63 }
64
65 cmd := exec.Command(shellPath, "-l")
66 cmd.Dir = cwd
67
68 stdinPipe, err := cmd.StdinPipe()
69 if err != nil {
70 return nil
71 }
72
73 cmd.Env = append(os.Environ(), "GIT_EDITOR=true")
74
75 err = cmd.Start()
76 if err != nil {
77 return nil
78 }
79
80 shell := &PersistentShell{
81 cmd: cmd,
82 stdin: stdinPipe.(*os.File),
83 isAlive: true,
84 cwd: cwd,
85 commandQueue: make(chan *commandExecution, 10),
86 }
87
88 go func() {
89 defer func() {
90 if r := recover(); r != nil {
91 fmt.Fprintf(os.Stderr, "Panic in shell command processor: %v\n", r)
92 shell.isAlive = false
93 close(shell.commandQueue)
94 }
95 }()
96 shell.processCommands()
97 }()
98
99 go func() {
100 err := cmd.Wait()
101 if err != nil {
102 // Log the error if needed
103 }
104 shell.isAlive = false
105 close(shell.commandQueue)
106 }()
107
108 return shell
109}
110
111func (s *PersistentShell) processCommands() {
112 for cmd := range s.commandQueue {
113 result := s.execCommand(cmd.command, cmd.timeout, cmd.ctx)
114 cmd.resultChan <- result
115 }
116}
117
118func (s *PersistentShell) execCommand(command string, timeout time.Duration, ctx context.Context) commandResult {
119 s.mu.Lock()
120 defer s.mu.Unlock()
121
122 if !s.isAlive {
123 return commandResult{
124 stderr: "Shell is not alive",
125 exitCode: 1,
126 err: errors.New("shell is not alive"),
127 }
128 }
129
130 tempDir := os.TempDir()
131 stdoutFile := filepath.Join(tempDir, fmt.Sprintf("opencode-stdout-%d", time.Now().UnixNano()))
132 stderrFile := filepath.Join(tempDir, fmt.Sprintf("opencode-stderr-%d", time.Now().UnixNano()))
133 statusFile := filepath.Join(tempDir, fmt.Sprintf("opencode-status-%d", time.Now().UnixNano()))
134 cwdFile := filepath.Join(tempDir, fmt.Sprintf("opencode-cwd-%d", time.Now().UnixNano()))
135
136 defer func() {
137 os.Remove(stdoutFile)
138 os.Remove(stderrFile)
139 os.Remove(statusFile)
140 os.Remove(cwdFile)
141 }()
142
143 fullCommand := fmt.Sprintf(`
144eval %s < /dev/null > %s 2> %s
145EXEC_EXIT_CODE=$?
146pwd > %s
147echo $EXEC_EXIT_CODE > %s
148`,
149 shellQuote(command),
150 shellQuote(stdoutFile),
151 shellQuote(stderrFile),
152 shellQuote(cwdFile),
153 shellQuote(statusFile),
154 )
155
156 _, err := s.stdin.Write([]byte(fullCommand + "\n"))
157 if err != nil {
158 return commandResult{
159 stderr: fmt.Sprintf("Failed to write command to shell: %v", err),
160 exitCode: 1,
161 err: err,
162 }
163 }
164
165 interrupted := false
166
167 startTime := time.Now()
168
169 done := make(chan bool)
170 go func() {
171 for {
172 select {
173 case <-ctx.Done():
174 s.killChildren()
175 interrupted = true
176 done <- true
177 return
178
179 case <-time.After(10 * time.Millisecond):
180 if fileExists(statusFile) && fileSize(statusFile) > 0 {
181 done <- true
182 return
183 }
184
185 if timeout > 0 {
186 elapsed := time.Since(startTime)
187 if elapsed > timeout {
188 s.killChildren()
189 interrupted = true
190 done <- true
191 return
192 }
193 }
194 }
195 }
196 }()
197
198 <-done
199
200 stdout := readFileOrEmpty(stdoutFile)
201 stderr := readFileOrEmpty(stderrFile)
202 exitCodeStr := readFileOrEmpty(statusFile)
203 newCwd := readFileOrEmpty(cwdFile)
204
205 exitCode := 0
206 if exitCodeStr != "" {
207 fmt.Sscanf(exitCodeStr, "%d", &exitCode)
208 } else if interrupted {
209 exitCode = 143
210 stderr += "\nCommand execution timed out or was interrupted"
211 }
212
213 if newCwd != "" {
214 s.cwd = strings.TrimSpace(newCwd)
215 }
216
217 return commandResult{
218 stdout: stdout,
219 stderr: stderr,
220 exitCode: exitCode,
221 interrupted: interrupted,
222 }
223}
224
225func (s *PersistentShell) killChildren() {
226 if s.cmd == nil || s.cmd.Process == nil {
227 return
228 }
229
230 pgrepCmd := exec.Command("pgrep", "-P", fmt.Sprintf("%d", s.cmd.Process.Pid))
231 output, err := pgrepCmd.Output()
232 if err != nil {
233 return
234 }
235
236 for pidStr := range strings.SplitSeq(string(output), "\n") {
237 if pidStr = strings.TrimSpace(pidStr); pidStr != "" {
238 var pid int
239 fmt.Sscanf(pidStr, "%d", &pid)
240 if pid > 0 {
241 proc, err := os.FindProcess(pid)
242 if err == nil {
243 proc.Signal(syscall.SIGTERM)
244 }
245 }
246 }
247 }
248}
249
250func (s *PersistentShell) Exec(ctx context.Context, command string, timeoutMs int) (string, string, int, bool, error) {
251 if !s.isAlive {
252 return "", "Shell is not alive", 1, false, errors.New("shell is not alive")
253 }
254
255 timeout := time.Duration(timeoutMs) * time.Millisecond
256
257 resultChan := make(chan commandResult)
258 s.commandQueue <- &commandExecution{
259 command: command,
260 timeout: timeout,
261 resultChan: resultChan,
262 ctx: ctx,
263 }
264
265 result := <-resultChan
266 return result.stdout, result.stderr, result.exitCode, result.interrupted, result.err
267}
268
269func (s *PersistentShell) Close() {
270 s.mu.Lock()
271 defer s.mu.Unlock()
272
273 if !s.isAlive {
274 return
275 }
276
277 s.stdin.Write([]byte("exit\n"))
278
279 s.cmd.Process.Kill()
280 s.isAlive = false
281}
282
// shellQuote wraps s in single quotes so a POSIX shell treats it as one
// literal word. Each embedded single quote is rewritten as '\'' (close the
// quote, emit an escaped quote, reopen the quote).
func shellQuote(s string) string {
	const quote = "'"
	escaped := strings.ReplaceAll(s, quote, `'\''`)
	return quote + escaped + quote
}
286
// readFileOrEmpty returns the full contents of the file at path as a string,
// or the empty string when the file cannot be read for any reason.
func readFileOrEmpty(path string) string {
	if data, err := os.ReadFile(path); err == nil {
		return string(data)
	}
	return ""
}
294
// fileExists reports whether os.Stat succeeds for path. Any Stat failure,
// not only "does not exist", yields false.
func fileExists(path string) bool {
	if _, err := os.Stat(path); err != nil {
		return false
	}
	return true
}
299
// fileSize returns the size in bytes reported by os.Stat for path, or 0 when
// the path cannot be stat'ed.
func fileSize(path string) int64 {
	if info, err := os.Stat(path); err == nil {
		return info.Size()
	}
	return 0
}