shell.go

  1// Package shell provides cross-platform shell execution capabilities.
  2//
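// This package provides Shell instances that execute commands with their own
// working directory and environment. After each execution the Shell keeps the
// resulting working directory and exported variables, so later commands on the
// same Shell start from that state.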
  5//
  6// WINDOWS COMPATIBILITY:
  7// This implementation provides POSIX shell emulation (mvdan.cc/sh/v3) even on
  8// Windows. Commands should use forward slashes (/) as path separators to work
  9// correctly on all platforms.
 10package shell
 11
 12import (
 13	"bytes"
 14	"context"
 15	"errors"
 16	"fmt"
 17	"io"
 18	"os"
 19	"slices"
 20	"strings"
 21	"sync"
 22
 23	"github.com/charmbracelet/x/exp/slice"
 24	"mvdan.cc/sh/v3/interp"
 25	"mvdan.cc/sh/v3/syntax"
 26)
 27
 28// ShellType represents the type of shell to use
 29type ShellType int
 30
 31const (
 32	ShellTypePOSIX ShellType = iota
 33	ShellTypeCmd
 34	ShellTypePowerShell
 35)
 36
 37// Logger interface for optional logging
 38type Logger interface {
 39	InfoPersist(msg string, keysAndValues ...any)
 40}
 41
 42// noopLogger is a logger that does nothing
 43type noopLogger struct{}
 44
 45func (noopLogger) InfoPersist(msg string, keysAndValues ...any) {}
 46
 47// BlockFunc is a function that determines if a command should be blocked
 48type BlockFunc func(args []string) bool
 49
 50// Shell provides cross-platform shell execution with optional state persistence
 51type Shell struct {
 52	env        []string
 53	cwd        string
 54	mu         sync.Mutex
 55	logger     Logger
 56	blockFuncs []BlockFunc
 57}
 58
 59// Options for creating a new shell
 60type Options struct {
 61	WorkingDir string
 62	Env        []string
 63	Logger     Logger
 64	BlockFuncs []BlockFunc
 65}
 66
 67// NewShell creates a new shell instance with the given options
 68func NewShell(opts *Options) *Shell {
 69	if opts == nil {
 70		opts = &Options{}
 71	}
 72
 73	cwd := opts.WorkingDir
 74	if cwd == "" {
 75		cwd, _ = os.Getwd()
 76	}
 77
 78	env := opts.Env
 79	if env == nil {
 80		env = os.Environ()
 81	}
 82
 83	// Allow tools to detect execution by Crush.
 84	env = append(
 85		env,
 86		"CRUSH=1",
 87		"AGENT=crush",
 88		"AI_AGENT=crush",
 89	)
 90
 91	logger := opts.Logger
 92	if logger == nil {
 93		logger = noopLogger{}
 94	}
 95
 96	return &Shell{
 97		cwd:        cwd,
 98		env:        env,
 99		logger:     logger,
100		blockFuncs: opts.BlockFuncs,
101	}
102}
103
104// Exec executes a command in the shell
105func (s *Shell) Exec(ctx context.Context, command string) (string, string, error) {
106	s.mu.Lock()
107	defer s.mu.Unlock()
108
109	return s.exec(ctx, command)
110}
111
112// ExecStream executes a command in the shell with streaming output to provided writers
113func (s *Shell) ExecStream(ctx context.Context, command string, stdout, stderr io.Writer) error {
114	s.mu.Lock()
115	defer s.mu.Unlock()
116
117	return s.execStream(ctx, command, stdout, stderr)
118}
119
120// GetWorkingDir returns the current working directory
121func (s *Shell) GetWorkingDir() string {
122	s.mu.Lock()
123	defer s.mu.Unlock()
124	return s.cwd
125}
126
127// SetWorkingDir sets the working directory
128func (s *Shell) SetWorkingDir(dir string) error {
129	s.mu.Lock()
130	defer s.mu.Unlock()
131
132	// Verify the directory exists
133	if _, err := os.Stat(dir); err != nil {
134		return fmt.Errorf("directory does not exist: %w", err)
135	}
136
137	s.cwd = dir
138	return nil
139}
140
141// GetEnv returns a copy of the environment variables
142func (s *Shell) GetEnv() []string {
143	s.mu.Lock()
144	defer s.mu.Unlock()
145
146	env := make([]string, len(s.env))
147	copy(env, s.env)
148	return env
149}
150
151// SetEnv sets an environment variable
152func (s *Shell) SetEnv(key, value string) {
153	s.mu.Lock()
154	defer s.mu.Unlock()
155
156	// Update or add the environment variable
157	keyPrefix := key + "="
158	for i, env := range s.env {
159		if strings.HasPrefix(env, keyPrefix) {
160			s.env[i] = keyPrefix + value
161			return
162		}
163	}
164	s.env = append(s.env, keyPrefix+value)
165}
166
167// SetBlockFuncs sets the command block functions for the shell
168func (s *Shell) SetBlockFuncs(blockFuncs []BlockFunc) {
169	s.mu.Lock()
170	defer s.mu.Unlock()
171	s.blockFuncs = blockFuncs
172}
173
174// CommandsBlocker creates a BlockFunc that blocks exact command matches
175func CommandsBlocker(cmds []string) BlockFunc {
176	bannedSet := make(map[string]struct{})
177	for _, cmd := range cmds {
178		bannedSet[cmd] = struct{}{}
179	}
180
181	return func(args []string) bool {
182		if len(args) == 0 {
183			return false
184		}
185		_, ok := bannedSet[args[0]]
186		return ok
187	}
188}
189
190// ArgumentsBlocker creates a BlockFunc that blocks specific subcommand
191func ArgumentsBlocker(cmd string, args []string, flags []string) BlockFunc {
192	return func(parts []string) bool {
193		if len(parts) == 0 || parts[0] != cmd {
194			return false
195		}
196
197		argParts, flagParts := splitArgsFlags(parts[1:])
198		if len(argParts) < len(args) || len(flagParts) < len(flags) {
199			return false
200		}
201
202		argsMatch := slices.Equal(argParts[:len(args)], args)
203		flagsMatch := slice.IsSubset(flags, flagParts)
204
205		return argsMatch && flagsMatch
206	}
207}
208
// splitArgsFlags partitions parts into positional arguments and flags,
// preserving relative order within each group:
//   - any token starting with "-" (including "-" and "--") is a flag;
//   - a flag is recorded by its name only, i.e. the text before the first
//     "=" if one is present;
//   - everything else is an argument.
//
// Both results are non-nil, even for empty input.
func splitArgsFlags(parts []string) (args []string, flags []string) {
	args = make([]string, 0, len(parts))
	flags = make([]string, 0, len(parts))
	for _, p := range parts {
		if !strings.HasPrefix(p, "-") {
			args = append(args, p)
			continue
		}
		// Cut returns the whole string as "before" when "=" is absent.
		name, _, _ := strings.Cut(p, "=")
		flags = append(flags, name)
	}
	return args, flags
}
226
227// newInterp creates a new interpreter with the current shell state. A nil
228// stdin is equivalent to an empty input stream.
229func (s *Shell) newInterp(stdin io.Reader, stdout, stderr io.Writer) (*interp.Runner, error) {
230	return newRunner(s.cwd, s.env, stdin, stdout, stderr, s.blockFuncs)
231}
232
233// updateShellFromRunner updates the shell from the interpreter after execution.
234func (s *Shell) updateShellFromRunner(runner *interp.Runner) {
235	s.cwd = runner.Dir
236	s.env = s.env[:0]
237	for name, vr := range runner.Vars {
238		if vr.Exported {
239			s.env = append(s.env, name+"="+vr.Str)
240		}
241	}
242}
243
244// execCommon is the shared implementation for executing commands
245func (s *Shell) execCommon(ctx context.Context, command string, stdout, stderr io.Writer) (err error) {
246	var runner *interp.Runner
247	defer func() {
248		if r := recover(); r != nil {
249			err = fmt.Errorf("command execution panic: %v", r)
250		}
251		if runner != nil {
252			s.updateShellFromRunner(runner)
253		}
254		s.logger.InfoPersist("command finished", "command", command, "err", err)
255	}()
256
257	line, err := syntax.NewParser().Parse(strings.NewReader(command), "")
258	if err != nil {
259		return fmt.Errorf("could not parse command: %w", err)
260	}
261
262	runner, err = s.newInterp(nil, stdout, stderr)
263	if err != nil {
264		return fmt.Errorf("could not run command: %w", err)
265	}
266
267	err = runner.Run(ctx, line)
268	return err
269}
270
271// exec executes commands using a cross-platform shell interpreter.
272func (s *Shell) exec(ctx context.Context, command string) (string, string, error) {
273	var stdout, stderr bytes.Buffer
274	err := s.execCommon(ctx, command, &stdout, &stderr)
275	return stdout.String(), stderr.String(), err
276}
277
278// execStream executes commands using POSIX shell emulation with streaming output
279func (s *Shell) execStream(ctx context.Context, command string, stdout, stderr io.Writer) error {
280	return s.execCommon(ctx, command, stdout, stderr)
281}
282
// IsInterrupt reports whether err is, or wraps, context.Canceled or
// context.DeadlineExceeded. A nil error is never an interrupt.
func IsInterrupt(err error) bool {
	for _, target := range []error{context.Canceled, context.DeadlineExceeded} {
		if errors.Is(err, target) {
			return true
		}
	}
	return false
}
288
289// ExitCode extracts the exit code from an error
290func ExitCode(err error) int {
291	if err == nil {
292		return 0
293	}
294	var exitErr interp.ExitStatus
295	if errors.As(err, &exitErr) {
296		return int(exitErr)
297	}
298	return 1
299}