shell.go

  1package shell
  2
  3import (
  4	"context"
  5	"errors"
  6	"fmt"
  7	"os"
  8	"os/exec"
  9	"path/filepath"
 10	"strings"
 11	"sync"
 12	"syscall"
 13	"time"
 14)
 15
 16type PersistentShell struct {
 17	cmd          *exec.Cmd
 18	stdin        *os.File
 19	isAlive      bool
 20	cwd          string
 21	mu           sync.Mutex
 22	commandQueue chan *commandExecution
 23}
 24
 25type commandExecution struct {
 26	command    string
 27	timeout    time.Duration
 28	resultChan chan commandResult
 29	ctx        context.Context
 30}
 31
 32type commandResult struct {
 33	stdout      string
 34	stderr      string
 35	exitCode    int
 36	interrupted bool
 37	err         error
 38}
 39
 40var (
 41	shellInstance     *PersistentShell
 42	shellInstanceOnce sync.Once
 43)
 44
 45func GetPersistentShell(workingDir string) *PersistentShell {
 46	shellInstanceOnce.Do(func() {
 47		shellInstance = newPersistentShell(workingDir)
 48	})
 49
 50	if !shellInstance.isAlive {
 51		shellInstance = newPersistentShell(shellInstance.cwd)
 52	}
 53
 54	return shellInstance
 55}
 56
 57func newPersistentShell(cwd string) *PersistentShell {
 58	shellPath := os.Getenv("SHELL")
 59	if shellPath == "" {
 60		shellPath = "/bin/bash"
 61	}
 62
 63	cmd := exec.Command(shellPath, "-l")
 64	cmd.Dir = cwd
 65
 66	stdinPipe, err := cmd.StdinPipe()
 67	if err != nil {
 68		return nil
 69	}
 70
 71	cmd.Env = append(os.Environ(), "GIT_EDITOR=true")
 72
 73	err = cmd.Start()
 74	if err != nil {
 75		return nil
 76	}
 77
 78	shell := &PersistentShell{
 79		cmd:          cmd,
 80		stdin:        stdinPipe.(*os.File),
 81		isAlive:      true,
 82		cwd:          cwd,
 83		commandQueue: make(chan *commandExecution, 10),
 84	}
 85
 86	go shell.processCommands()
 87
 88	go func() {
 89		err := cmd.Wait()
 90		if err != nil {
 91		}
 92		shell.isAlive = false
 93		close(shell.commandQueue)
 94	}()
 95
 96	return shell
 97}
 98
 99func (s *PersistentShell) processCommands() {
100	for cmd := range s.commandQueue {
101		result := s.execCommand(cmd.command, cmd.timeout, cmd.ctx)
102		cmd.resultChan <- result
103	}
104}
105
106func (s *PersistentShell) execCommand(command string, timeout time.Duration, ctx context.Context) commandResult {
107	s.mu.Lock()
108	defer s.mu.Unlock()
109
110	if !s.isAlive {
111		return commandResult{
112			stderr:   "Shell is not alive",
113			exitCode: 1,
114			err:      errors.New("shell is not alive"),
115		}
116	}
117
118	tempDir := os.TempDir()
119	stdoutFile := filepath.Join(tempDir, fmt.Sprintf("termai-stdout-%d", time.Now().UnixNano()))
120	stderrFile := filepath.Join(tempDir, fmt.Sprintf("termai-stderr-%d", time.Now().UnixNano()))
121	statusFile := filepath.Join(tempDir, fmt.Sprintf("termai-status-%d", time.Now().UnixNano()))
122	cwdFile := filepath.Join(tempDir, fmt.Sprintf("termai-cwd-%d", time.Now().UnixNano()))
123
124	defer func() {
125		os.Remove(stdoutFile)
126		os.Remove(stderrFile)
127		os.Remove(statusFile)
128		os.Remove(cwdFile)
129	}()
130
131	fullCommand := fmt.Sprintf(`
132eval %s < /dev/null > %s 2> %s
133EXEC_EXIT_CODE=$?
134pwd > %s
135echo $EXEC_EXIT_CODE > %s
136`,
137		shellQuote(command),
138		shellQuote(stdoutFile),
139		shellQuote(stderrFile),
140		shellQuote(cwdFile),
141		shellQuote(statusFile),
142	)
143
144	_, err := s.stdin.Write([]byte(fullCommand + "\n"))
145	if err != nil {
146		return commandResult{
147			stderr:   fmt.Sprintf("Failed to write command to shell: %v", err),
148			exitCode: 1,
149			err:      err,
150		}
151	}
152
153	interrupted := false
154
155	startTime := time.Now()
156
157	done := make(chan bool)
158	go func() {
159		for {
160			select {
161			case <-ctx.Done():
162				s.killChildren()
163				interrupted = true
164				done <- true
165				return
166
167			case <-time.After(10 * time.Millisecond):
168				if fileExists(statusFile) && fileSize(statusFile) > 0 {
169					done <- true
170					return
171				}
172
173				if timeout > 0 {
174					elapsed := time.Since(startTime)
175					if elapsed > timeout {
176						s.killChildren()
177						interrupted = true
178						done <- true
179						return
180					}
181				}
182			}
183		}
184	}()
185
186	<-done
187
188	stdout := readFileOrEmpty(stdoutFile)
189	stderr := readFileOrEmpty(stderrFile)
190	exitCodeStr := readFileOrEmpty(statusFile)
191	newCwd := readFileOrEmpty(cwdFile)
192
193	exitCode := 0
194	if exitCodeStr != "" {
195		fmt.Sscanf(exitCodeStr, "%d", &exitCode)
196	} else if interrupted {
197		exitCode = 143
198		stderr += "\nCommand execution timed out or was interrupted"
199	}
200
201	if newCwd != "" {
202		s.cwd = strings.TrimSpace(newCwd)
203	}
204
205	return commandResult{
206		stdout:      stdout,
207		stderr:      stderr,
208		exitCode:    exitCode,
209		interrupted: interrupted,
210	}
211}
212
213func (s *PersistentShell) killChildren() {
214	if s.cmd == nil || s.cmd.Process == nil {
215		return
216	}
217
218	pgrepCmd := exec.Command("pgrep", "-P", fmt.Sprintf("%d", s.cmd.Process.Pid))
219	output, err := pgrepCmd.Output()
220	if err != nil {
221		return
222	}
223
224	for pidStr := range strings.SplitSeq(string(output), "\n") {
225		if pidStr = strings.TrimSpace(pidStr); pidStr != "" {
226			var pid int
227			fmt.Sscanf(pidStr, "%d", &pid)
228			if pid > 0 {
229				proc, err := os.FindProcess(pid)
230				if err == nil {
231					proc.Signal(syscall.SIGTERM)
232				}
233			}
234		}
235	}
236}
237
238func (s *PersistentShell) Exec(ctx context.Context, command string, timeoutMs int) (string, string, int, bool, error) {
239	if !s.isAlive {
240		return "", "Shell is not alive", 1, false, errors.New("shell is not alive")
241	}
242
243	timeout := time.Duration(timeoutMs) * time.Millisecond
244
245	resultChan := make(chan commandResult)
246	s.commandQueue <- &commandExecution{
247		command:    command,
248		timeout:    timeout,
249		resultChan: resultChan,
250		ctx:        ctx,
251	}
252
253	result := <-resultChan
254	return result.stdout, result.stderr, result.exitCode, result.interrupted, result.err
255}
256
257func (s *PersistentShell) Close() {
258	s.mu.Lock()
259	defer s.mu.Unlock()
260
261	if !s.isAlive {
262		return
263	}
264
265	s.stdin.Write([]byte("exit\n"))
266
267	s.cmd.Process.Kill()
268	s.isAlive = false
269}
270
// shellQuote wraps s in single quotes for POSIX shells. Each embedded single
// quote becomes '\'' (close the quote, add an escaped quote, reopen), so the
// shell reads the result back as exactly s.
func shellQuote(s string) string {
	var b strings.Builder
	b.Grow(len(s) + 2)
	b.WriteByte('\'')
	// Working byte by byte is UTF-8 safe: '\'' is ASCII and never occurs
	// inside a multi-byte sequence.
	for i := 0; i < len(s); i++ {
		if s[i] == '\'' {
			b.WriteString(`'\''`)
			continue
		}
		b.WriteByte(s[i])
	}
	b.WriteByte('\'')
	return b.String()
}
274
// readFileOrEmpty returns the whole contents of the file at path. It returns ""
// if the file cannot be read for any reason (missing, unreadable, a directory).
func readFileOrEmpty(path string) string {
	if data, err := os.ReadFile(path); err == nil {
		return string(data)
	}
	return ""
}
282
// fileExists reports whether os.Stat succeeds for path. Any stat error,
// including a permission error, counts as "does not exist".
func fileExists(path string) bool {
	if _, err := os.Stat(path); err != nil {
		return false
	}
	return true
}
287
// fileSize returns the size in bytes reported by os.Stat for path, or 0 when
// the path cannot be stat'ed.
func fileSize(path string) int64 {
	if info, err := os.Stat(path); err == nil {
		return info.Size()
	}
	return 0
}