1package shell
2
3import (
4 "context"
5 "errors"
6 "fmt"
7 "os"
8 "os/exec"
9 "path/filepath"
10 "strings"
11 "sync"
12 "syscall"
13 "time"
14)
15
16type PersistentShell struct {
17 cmd *exec.Cmd
18 stdin *os.File
19 isAlive bool
20 cwd string
21 mu sync.Mutex
22 commandQueue chan *commandExecution
23}
24
25type commandExecution struct {
26 command string
27 timeout time.Duration
28 resultChan chan commandResult
29 ctx context.Context
30}
31
// commandResult is the outcome of running one command in the shell.
type commandResult struct {
	// stdout and stderr hold the captured output streams.
	stdout, stderr string

	// exitCode is the status reported by the shell for the command.
	exitCode int

	// interrupted is true when the command was cut short by context
	// cancellation or by its timeout.
	interrupted bool

	// err is set when the command could not be handed to the shell.
	err error
}
39
40var (
41 shellInstance *PersistentShell
42 shellInstanceOnce sync.Once
43)
44
45func GetPersistentShell(workingDir string) *PersistentShell {
46 shellInstanceOnce.Do(func() {
47 shellInstance = newPersistentShell(workingDir)
48 })
49
50 if shellInstance == nil || !shellInstance.isAlive {
51 shellInstance = newPersistentShell(shellInstance.cwd)
52 }
53
54 return shellInstance
55}
56
57func newPersistentShell(cwd string) *PersistentShell {
58 shellPath := os.Getenv("SHELL")
59 if shellPath == "" {
60 shellPath = "/bin/bash"
61 }
62
63 cmd := exec.Command(shellPath, "-l")
64 cmd.Dir = cwd
65
66 stdinPipe, err := cmd.StdinPipe()
67 if err != nil {
68 return nil
69 }
70
71 cmd.Env = append(os.Environ(), "GIT_EDITOR=true")
72
73 err = cmd.Start()
74 if err != nil {
75 return nil
76 }
77
78 shell := &PersistentShell{
79 cmd: cmd,
80 stdin: stdinPipe.(*os.File),
81 isAlive: true,
82 cwd: cwd,
83 commandQueue: make(chan *commandExecution, 10),
84 }
85
86 go func() {
87 defer func() {
88 if r := recover(); r != nil {
89 fmt.Fprintf(os.Stderr, "Panic in shell command processor: %v\n", r)
90 shell.isAlive = false
91 close(shell.commandQueue)
92 }
93 }()
94 shell.processCommands()
95 }()
96
97 go func() {
98 err := cmd.Wait()
99 if err != nil {
100 // Log the error if needed
101 }
102 shell.isAlive = false
103 close(shell.commandQueue)
104 }()
105
106 return shell
107}
108
109func (s *PersistentShell) processCommands() {
110 for cmd := range s.commandQueue {
111 result := s.execCommand(cmd.command, cmd.timeout, cmd.ctx)
112 cmd.resultChan <- result
113 }
114}
115
116func (s *PersistentShell) execCommand(command string, timeout time.Duration, ctx context.Context) commandResult {
117 s.mu.Lock()
118 defer s.mu.Unlock()
119
120 if !s.isAlive {
121 return commandResult{
122 stderr: "Shell is not alive",
123 exitCode: 1,
124 err: errors.New("shell is not alive"),
125 }
126 }
127
128 tempDir := os.TempDir()
129 stdoutFile := filepath.Join(tempDir, fmt.Sprintf("opencode-stdout-%d", time.Now().UnixNano()))
130 stderrFile := filepath.Join(tempDir, fmt.Sprintf("opencode-stderr-%d", time.Now().UnixNano()))
131 statusFile := filepath.Join(tempDir, fmt.Sprintf("opencode-status-%d", time.Now().UnixNano()))
132 cwdFile := filepath.Join(tempDir, fmt.Sprintf("opencode-cwd-%d", time.Now().UnixNano()))
133
134 defer func() {
135 os.Remove(stdoutFile)
136 os.Remove(stderrFile)
137 os.Remove(statusFile)
138 os.Remove(cwdFile)
139 }()
140
141 fullCommand := fmt.Sprintf(`
142eval %s < /dev/null > %s 2> %s
143EXEC_EXIT_CODE=$?
144pwd > %s
145echo $EXEC_EXIT_CODE > %s
146`,
147 shellQuote(command),
148 shellQuote(stdoutFile),
149 shellQuote(stderrFile),
150 shellQuote(cwdFile),
151 shellQuote(statusFile),
152 )
153
154 _, err := s.stdin.Write([]byte(fullCommand + "\n"))
155 if err != nil {
156 return commandResult{
157 stderr: fmt.Sprintf("Failed to write command to shell: %v", err),
158 exitCode: 1,
159 err: err,
160 }
161 }
162
163 interrupted := false
164
165 startTime := time.Now()
166
167 done := make(chan bool)
168 go func() {
169 for {
170 select {
171 case <-ctx.Done():
172 s.killChildren()
173 interrupted = true
174 done <- true
175 return
176
177 case <-time.After(10 * time.Millisecond):
178 if fileExists(statusFile) && fileSize(statusFile) > 0 {
179 done <- true
180 return
181 }
182
183 if timeout > 0 {
184 elapsed := time.Since(startTime)
185 if elapsed > timeout {
186 s.killChildren()
187 interrupted = true
188 done <- true
189 return
190 }
191 }
192 }
193 }
194 }()
195
196 <-done
197
198 stdout := readFileOrEmpty(stdoutFile)
199 stderr := readFileOrEmpty(stderrFile)
200 exitCodeStr := readFileOrEmpty(statusFile)
201 newCwd := readFileOrEmpty(cwdFile)
202
203 exitCode := 0
204 if exitCodeStr != "" {
205 fmt.Sscanf(exitCodeStr, "%d", &exitCode)
206 } else if interrupted {
207 exitCode = 143
208 stderr += "\nCommand execution timed out or was interrupted"
209 }
210
211 if newCwd != "" {
212 s.cwd = strings.TrimSpace(newCwd)
213 }
214
215 return commandResult{
216 stdout: stdout,
217 stderr: stderr,
218 exitCode: exitCode,
219 interrupted: interrupted,
220 }
221}
222
223func (s *PersistentShell) killChildren() {
224 if s.cmd == nil || s.cmd.Process == nil {
225 return
226 }
227
228 pgrepCmd := exec.Command("pgrep", "-P", fmt.Sprintf("%d", s.cmd.Process.Pid))
229 output, err := pgrepCmd.Output()
230 if err != nil {
231 return
232 }
233
234 for pidStr := range strings.SplitSeq(string(output), "\n") {
235 if pidStr = strings.TrimSpace(pidStr); pidStr != "" {
236 var pid int
237 fmt.Sscanf(pidStr, "%d", &pid)
238 if pid > 0 {
239 proc, err := os.FindProcess(pid)
240 if err == nil {
241 proc.Signal(syscall.SIGTERM)
242 }
243 }
244 }
245 }
246}
247
248func (s *PersistentShell) Exec(ctx context.Context, command string, timeoutMs int) (string, string, int, bool, error) {
249 if !s.isAlive {
250 return "", "Shell is not alive", 1, false, errors.New("shell is not alive")
251 }
252
253 timeout := time.Duration(timeoutMs) * time.Millisecond
254
255 resultChan := make(chan commandResult)
256 s.commandQueue <- &commandExecution{
257 command: command,
258 timeout: timeout,
259 resultChan: resultChan,
260 ctx: ctx,
261 }
262
263 result := <-resultChan
264 return result.stdout, result.stderr, result.exitCode, result.interrupted, result.err
265}
266
267func (s *PersistentShell) Close() {
268 s.mu.Lock()
269 defer s.mu.Unlock()
270
271 if !s.isAlive {
272 return
273 }
274
275 s.stdin.Write([]byte("exit\n"))
276
277 s.cmd.Process.Kill()
278 s.isAlive = false
279}
280
// shellQuote wraps s in single quotes for use as one shell word. Each single
// quote inside s is rewritten as '\'' (close the quote, emit an escaped
// quote, reopen the quote).
func shellQuote(s string) string {
	const escapedQuote = `'\''`
	pieces := strings.Split(s, "'")
	return "'" + strings.Join(pieces, escapedQuote) + "'"
}
284
// readFileOrEmpty returns the contents of the file at path as a string, or
// the empty string if the file cannot be read.
func readFileOrEmpty(path string) string {
	data, err := os.ReadFile(path)
	if err == nil {
		return string(data)
	}
	return ""
}
292
// fileExists reports whether os.Stat succeeds for path.
func fileExists(path string) bool {
	if _, err := os.Stat(path); err != nil {
		return false
	}
	return true
}
297
// fileSize returns the size in bytes of the file at path, or 0 if os.Stat
// fails for it.
func fileSize(path string) int64 {
	if info, err := os.Stat(path); err == nil {
		return info.Size()
	}
	return 0
}