shell.go

  1package shell
  2
  3import (
  4	"context"
  5	"errors"
  6	"fmt"
  7	"os"
  8	"os/exec"
  9	"path/filepath"
 10	"strings"
 11	"sync"
 12	"syscall"
 13	"time"
 14)
 15
 16type PersistentShell struct {
 17	cmd          *exec.Cmd
 18	stdin        *os.File
 19	isAlive      bool
 20	cwd          string
 21	mu           sync.Mutex
 22	commandQueue chan *commandExecution
 23}
 24
 25type commandExecution struct {
 26	command    string
 27	timeout    time.Duration
 28	resultChan chan commandResult
 29	ctx        context.Context
 30}
 31
 32type commandResult struct {
 33	stdout      string
 34	stderr      string
 35	exitCode    int
 36	interrupted bool
 37	err         error
 38}
 39
 40var (
 41	shellInstance     *PersistentShell
 42	shellInstanceOnce sync.Once
 43)
 44
 45func GetPersistentShell(workingDir string) *PersistentShell {
 46	shellInstanceOnce.Do(func() {
 47		shellInstance = newPersistentShell(workingDir)
 48	})
 49
 50	if !shellInstance.isAlive {
 51		shellInstance = newPersistentShell(shellInstance.cwd)
 52	}
 53
 54	return shellInstance
 55}
 56
 57func newPersistentShell(cwd string) *PersistentShell {
 58	shellPath := os.Getenv("SHELL")
 59	if shellPath == "" {
 60		shellPath = "/bin/bash"
 61	}
 62
 63	cmd := exec.Command(shellPath, "-l")
 64	cmd.Dir = cwd
 65
 66	stdinPipe, err := cmd.StdinPipe()
 67	if err != nil {
 68		return nil
 69	}
 70
 71	cmd.Env = append(os.Environ(), "GIT_EDITOR=true")
 72
 73	err = cmd.Start()
 74	if err != nil {
 75		return nil
 76	}
 77
 78	shell := &PersistentShell{
 79		cmd:          cmd,
 80		stdin:        stdinPipe.(*os.File),
 81		isAlive:      true,
 82		cwd:          cwd,
 83		commandQueue: make(chan *commandExecution, 10),
 84	}
 85
 86	go func() {
 87		defer func() {
 88			if r := recover(); r != nil {
 89				fmt.Fprintf(os.Stderr, "Panic in shell command processor: %v\n", r)
 90				shell.isAlive = false
 91				close(shell.commandQueue)
 92			}
 93		}()
 94		shell.processCommands()
 95	}()
 96
 97	go func() {
 98		err := cmd.Wait()
 99		if err != nil {
100			// Log the error if needed
101		}
102		shell.isAlive = false
103		close(shell.commandQueue)
104	}()
105
106	return shell
107}
108
109func (s *PersistentShell) processCommands() {
110	for cmd := range s.commandQueue {
111		result := s.execCommand(cmd.command, cmd.timeout, cmd.ctx)
112		cmd.resultChan <- result
113	}
114}
115
116func (s *PersistentShell) execCommand(command string, timeout time.Duration, ctx context.Context) commandResult {
117	s.mu.Lock()
118	defer s.mu.Unlock()
119
120	if !s.isAlive {
121		return commandResult{
122			stderr:   "Shell is not alive",
123			exitCode: 1,
124			err:      errors.New("shell is not alive"),
125		}
126	}
127
128	tempDir := os.TempDir()
129	stdoutFile := filepath.Join(tempDir, fmt.Sprintf("opencode-stdout-%d", time.Now().UnixNano()))
130	stderrFile := filepath.Join(tempDir, fmt.Sprintf("opencode-stderr-%d", time.Now().UnixNano()))
131	statusFile := filepath.Join(tempDir, fmt.Sprintf("opencode-status-%d", time.Now().UnixNano()))
132	cwdFile := filepath.Join(tempDir, fmt.Sprintf("opencode-cwd-%d", time.Now().UnixNano()))
133
134	defer func() {
135		os.Remove(stdoutFile)
136		os.Remove(stderrFile)
137		os.Remove(statusFile)
138		os.Remove(cwdFile)
139	}()
140
141	fullCommand := fmt.Sprintf(`
142eval %s < /dev/null > %s 2> %s
143EXEC_EXIT_CODE=$?
144pwd > %s
145echo $EXEC_EXIT_CODE > %s
146`,
147		shellQuote(command),
148		shellQuote(stdoutFile),
149		shellQuote(stderrFile),
150		shellQuote(cwdFile),
151		shellQuote(statusFile),
152	)
153
154	_, err := s.stdin.Write([]byte(fullCommand + "\n"))
155	if err != nil {
156		return commandResult{
157			stderr:   fmt.Sprintf("Failed to write command to shell: %v", err),
158			exitCode: 1,
159			err:      err,
160		}
161	}
162
163	interrupted := false
164
165	startTime := time.Now()
166
167	done := make(chan bool)
168	go func() {
169		for {
170			select {
171			case <-ctx.Done():
172				s.killChildren()
173				interrupted = true
174				done <- true
175				return
176
177			case <-time.After(10 * time.Millisecond):
178				if fileExists(statusFile) && fileSize(statusFile) > 0 {
179					done <- true
180					return
181				}
182
183				if timeout > 0 {
184					elapsed := time.Since(startTime)
185					if elapsed > timeout {
186						s.killChildren()
187						interrupted = true
188						done <- true
189						return
190					}
191				}
192			}
193		}
194	}()
195
196	<-done
197
198	stdout := readFileOrEmpty(stdoutFile)
199	stderr := readFileOrEmpty(stderrFile)
200	exitCodeStr := readFileOrEmpty(statusFile)
201	newCwd := readFileOrEmpty(cwdFile)
202
203	exitCode := 0
204	if exitCodeStr != "" {
205		fmt.Sscanf(exitCodeStr, "%d", &exitCode)
206	} else if interrupted {
207		exitCode = 143
208		stderr += "\nCommand execution timed out or was interrupted"
209	}
210
211	if newCwd != "" {
212		s.cwd = strings.TrimSpace(newCwd)
213	}
214
215	return commandResult{
216		stdout:      stdout,
217		stderr:      stderr,
218		exitCode:    exitCode,
219		interrupted: interrupted,
220	}
221}
222
223func (s *PersistentShell) killChildren() {
224	if s.cmd == nil || s.cmd.Process == nil {
225		return
226	}
227
228	pgrepCmd := exec.Command("pgrep", "-P", fmt.Sprintf("%d", s.cmd.Process.Pid))
229	output, err := pgrepCmd.Output()
230	if err != nil {
231		return
232	}
233
234	for pidStr := range strings.SplitSeq(string(output), "\n") {
235		if pidStr = strings.TrimSpace(pidStr); pidStr != "" {
236			var pid int
237			fmt.Sscanf(pidStr, "%d", &pid)
238			if pid > 0 {
239				proc, err := os.FindProcess(pid)
240				if err == nil {
241					proc.Signal(syscall.SIGTERM)
242				}
243			}
244		}
245	}
246}
247
248func (s *PersistentShell) Exec(ctx context.Context, command string, timeoutMs int) (string, string, int, bool, error) {
249	if !s.isAlive {
250		return "", "Shell is not alive", 1, false, errors.New("shell is not alive")
251	}
252
253	timeout := time.Duration(timeoutMs) * time.Millisecond
254
255	resultChan := make(chan commandResult)
256	s.commandQueue <- &commandExecution{
257		command:    command,
258		timeout:    timeout,
259		resultChan: resultChan,
260		ctx:        ctx,
261	}
262
263	result := <-resultChan
264	return result.stdout, result.stderr, result.exitCode, result.interrupted, result.err
265}
266
267func (s *PersistentShell) Close() {
268	s.mu.Lock()
269	defer s.mu.Unlock()
270
271	if !s.isAlive {
272		return
273	}
274
275	s.stdin.Write([]byte("exit\n"))
276
277	s.cmd.Process.Kill()
278	s.isAlive = false
279}
280
// shellQuote wraps s in single quotes for POSIX shells. Each embedded single
// quote is written as '\'' : close the quote, add an escaped quote, reopen.
func shellQuote(s string) string {
	const escapedQuote = `'\''`

	var b strings.Builder
	b.Grow(len(s) + 2)
	b.WriteByte('\'')
	b.WriteString(strings.ReplaceAll(s, "'", escapedQuote))
	b.WriteByte('\'')
	return b.String()
}
284
// readFileOrEmpty returns the whole contents of the file at path. It returns
// "" if the file cannot be read for any reason (missing, unreadable, ...).
func readFileOrEmpty(path string) string {
	if data, err := os.ReadFile(path); err == nil {
		return string(data)
	}
	return ""
}
292
// fileExists reports whether os.Stat succeeds for path. Directories count,
// and a stat error of any kind (not only "does not exist") yields false.
func fileExists(path string) bool {
	if _, err := os.Stat(path); err != nil {
		return false
	}
	return true
}
297
// fileSize returns the size in bytes that os.Stat reports for path, or 0
// when path cannot be stat'ed.
func fileSize(path string) int64 {
	var size int64
	if info, err := os.Stat(path); err == nil {
		size = info.Size()
	}
	return size
}