shell.go

  1package shell
  2
  3import (
  4	"context"
  5	"errors"
  6	"fmt"
  7	"os"
  8	"os/exec"
  9	"path/filepath"
 10	"strings"
 11	"sync"
 12	"syscall"
 13	"time"
 14)
 15
 16type PersistentShell struct {
 17	cmd          *exec.Cmd
 18	stdin        *os.File
 19	isAlive      bool
 20	cwd          string
 21	mu           sync.Mutex
 22	commandQueue chan *commandExecution
 23}
 24
 25type commandExecution struct {
 26	command    string
 27	timeout    time.Duration
 28	resultChan chan commandResult
 29	ctx        context.Context
 30}
 31
 32type commandResult struct {
 33	stdout      string
 34	stderr      string
 35	exitCode    int
 36	interrupted bool
 37	err         error
 38}
 39
 40var (
 41	shellInstance     *PersistentShell
 42	shellInstanceOnce sync.Once
 43)
 44
 45func GetPersistentShell(workingDir string) *PersistentShell {
 46	shellInstanceOnce.Do(func() {
 47		shellInstance = newPersistentShell(workingDir)
 48	})
 49
 50	if shellInstance == nil {
 51		shellInstance = newPersistentShell(workingDir)
 52	} else if !shellInstance.isAlive {
 53		shellInstance = newPersistentShell(shellInstance.cwd)
 54	}
 55
 56	return shellInstance
 57}
 58
 59func newPersistentShell(cwd string) *PersistentShell {
 60	shellPath := os.Getenv("SHELL")
 61	if shellPath == "" {
 62		shellPath = "/bin/bash"
 63	}
 64
 65	cmd := exec.Command(shellPath, "-l")
 66	cmd.Dir = cwd
 67
 68	stdinPipe, err := cmd.StdinPipe()
 69	if err != nil {
 70		return nil
 71	}
 72
 73	cmd.Env = append(os.Environ(), "GIT_EDITOR=true")
 74
 75	err = cmd.Start()
 76	if err != nil {
 77		return nil
 78	}
 79
 80	shell := &PersistentShell{
 81		cmd:          cmd,
 82		stdin:        stdinPipe.(*os.File),
 83		isAlive:      true,
 84		cwd:          cwd,
 85		commandQueue: make(chan *commandExecution, 10),
 86	}
 87
 88	go func() {
 89		defer func() {
 90			if r := recover(); r != nil {
 91				fmt.Fprintf(os.Stderr, "Panic in shell command processor: %v\n", r)
 92				shell.isAlive = false
 93				close(shell.commandQueue)
 94			}
 95		}()
 96		shell.processCommands()
 97	}()
 98
 99	go func() {
100		err := cmd.Wait()
101		if err != nil {
102			// Log the error if needed
103		}
104		shell.isAlive = false
105		close(shell.commandQueue)
106	}()
107
108	return shell
109}
110
111func (s *PersistentShell) processCommands() {
112	for cmd := range s.commandQueue {
113		result := s.execCommand(cmd.command, cmd.timeout, cmd.ctx)
114		cmd.resultChan <- result
115	}
116}
117
118func (s *PersistentShell) execCommand(command string, timeout time.Duration, ctx context.Context) commandResult {
119	s.mu.Lock()
120	defer s.mu.Unlock()
121
122	if !s.isAlive {
123		return commandResult{
124			stderr:   "Shell is not alive",
125			exitCode: 1,
126			err:      errors.New("shell is not alive"),
127		}
128	}
129
130	tempDir := os.TempDir()
131	stdoutFile := filepath.Join(tempDir, fmt.Sprintf("opencode-stdout-%d", time.Now().UnixNano()))
132	stderrFile := filepath.Join(tempDir, fmt.Sprintf("opencode-stderr-%d", time.Now().UnixNano()))
133	statusFile := filepath.Join(tempDir, fmt.Sprintf("opencode-status-%d", time.Now().UnixNano()))
134	cwdFile := filepath.Join(tempDir, fmt.Sprintf("opencode-cwd-%d", time.Now().UnixNano()))
135
136	defer func() {
137		os.Remove(stdoutFile)
138		os.Remove(stderrFile)
139		os.Remove(statusFile)
140		os.Remove(cwdFile)
141	}()
142
143	fullCommand := fmt.Sprintf(`
144eval %s < /dev/null > %s 2> %s
145EXEC_EXIT_CODE=$?
146pwd > %s
147echo $EXEC_EXIT_CODE > %s
148`,
149		shellQuote(command),
150		shellQuote(stdoutFile),
151		shellQuote(stderrFile),
152		shellQuote(cwdFile),
153		shellQuote(statusFile),
154	)
155
156	_, err := s.stdin.Write([]byte(fullCommand + "\n"))
157	if err != nil {
158		return commandResult{
159			stderr:   fmt.Sprintf("Failed to write command to shell: %v", err),
160			exitCode: 1,
161			err:      err,
162		}
163	}
164
165	interrupted := false
166
167	startTime := time.Now()
168
169	done := make(chan bool)
170	go func() {
171		for {
172			select {
173			case <-ctx.Done():
174				s.killChildren()
175				interrupted = true
176				done <- true
177				return
178
179			case <-time.After(10 * time.Millisecond):
180				if fileExists(statusFile) && fileSize(statusFile) > 0 {
181					done <- true
182					return
183				}
184
185				if timeout > 0 {
186					elapsed := time.Since(startTime)
187					if elapsed > timeout {
188						s.killChildren()
189						interrupted = true
190						done <- true
191						return
192					}
193				}
194			}
195		}
196	}()
197
198	<-done
199
200	stdout := readFileOrEmpty(stdoutFile)
201	stderr := readFileOrEmpty(stderrFile)
202	exitCodeStr := readFileOrEmpty(statusFile)
203	newCwd := readFileOrEmpty(cwdFile)
204
205	exitCode := 0
206	if exitCodeStr != "" {
207		fmt.Sscanf(exitCodeStr, "%d", &exitCode)
208	} else if interrupted {
209		exitCode = 143
210		stderr += "\nCommand execution timed out or was interrupted"
211	}
212
213	if newCwd != "" {
214		s.cwd = strings.TrimSpace(newCwd)
215	}
216
217	return commandResult{
218		stdout:      stdout,
219		stderr:      stderr,
220		exitCode:    exitCode,
221		interrupted: interrupted,
222	}
223}
224
225func (s *PersistentShell) killChildren() {
226	if s.cmd == nil || s.cmd.Process == nil {
227		return
228	}
229
230	pgrepCmd := exec.Command("pgrep", "-P", fmt.Sprintf("%d", s.cmd.Process.Pid))
231	output, err := pgrepCmd.Output()
232	if err != nil {
233		return
234	}
235
236	for pidStr := range strings.SplitSeq(string(output), "\n") {
237		if pidStr = strings.TrimSpace(pidStr); pidStr != "" {
238			var pid int
239			fmt.Sscanf(pidStr, "%d", &pid)
240			if pid > 0 {
241				proc, err := os.FindProcess(pid)
242				if err == nil {
243					proc.Signal(syscall.SIGTERM)
244				}
245			}
246		}
247	}
248}
249
250func (s *PersistentShell) Exec(ctx context.Context, command string, timeoutMs int) (string, string, int, bool, error) {
251	if !s.isAlive {
252		return "", "Shell is not alive", 1, false, errors.New("shell is not alive")
253	}
254
255	timeout := time.Duration(timeoutMs) * time.Millisecond
256
257	resultChan := make(chan commandResult)
258	s.commandQueue <- &commandExecution{
259		command:    command,
260		timeout:    timeout,
261		resultChan: resultChan,
262		ctx:        ctx,
263	}
264
265	result := <-resultChan
266	return result.stdout, result.stderr, result.exitCode, result.interrupted, result.err
267}
268
269func (s *PersistentShell) Close() {
270	s.mu.Lock()
271	defer s.mu.Unlock()
272
273	if !s.isAlive {
274		return
275	}
276
277	s.stdin.Write([]byte("exit\n"))
278
279	s.cmd.Process.Kill()
280	s.isAlive = false
281}
282
// shellQuote wraps s in single quotes for safe use as one POSIX shell word.
// Each embedded single quote becomes '\'' : close the quote, emit an escaped
// quote, then reopen.
func shellQuote(s string) string {
	const escapedQuote = `'\''`

	var b strings.Builder
	b.WriteByte('\'')
	b.WriteString(strings.ReplaceAll(s, "'", escapedQuote))
	b.WriteByte('\'')
	return b.String()
}
286
// readFileOrEmpty returns the contents of path, or "" if it cannot be read
// (missing file, directory, permission error, ...).
func readFileOrEmpty(path string) string {
	if data, err := os.ReadFile(path); err == nil {
		return string(data)
	}
	return ""
}
294
// fileExists reports whether os.Stat succeeds for path. Any stat error,
// including permission errors, counts as "does not exist".
func fileExists(path string) bool {
	if _, err := os.Stat(path); err != nil {
		return false
	}
	return true
}
299
// fileSize returns the size in bytes that os.Stat reports for path, or 0 if
// path cannot be stat'ed.
func fileSize(path string) int64 {
	var size int64
	if info, err := os.Stat(path); err == nil {
		size = info.Size()
	}
	return size
}