shell.go

// Package shell provides cross-platform shell execution capabilities.
//
// A Shell runs commands with its own working directory and environment.
// Every command is executed by a fresh interpreter, but the working directory
// and the variables a command leaves behind (for example after `cd` or
// `export`) are carried over to the next command run on the same Shell.
//
// WINDOWS COMPATIBILITY:
// This implementation uses POSIX shell emulation (mvdan.cc/sh/v3) on every
// platform, including Windows. Commands should use forward slashes (/) as
// path separators so that they work correctly on all platforms.
 10package shell
 11
 12import (
 13	"bytes"
 14	"context"
 15	"errors"
 16	"fmt"
 17	"io"
 18	"os"
 19	"slices"
 20	"strings"
 21	"sync"
 22
 23	"github.com/charmbracelet/x/exp/slice"
 24	"mvdan.cc/sh/moreinterp/coreutils"
 25	"mvdan.cc/sh/v3/expand"
 26	"mvdan.cc/sh/v3/interp"
 27	"mvdan.cc/sh/v3/syntax"
 28)
 29
 30// ShellType represents the type of shell to use
 31type ShellType int
 32
 33const (
 34	ShellTypePOSIX ShellType = iota
 35	ShellTypeCmd
 36	ShellTypePowerShell
 37)
 38
 39// Logger interface for optional logging
 40type Logger interface {
 41	InfoPersist(msg string, keysAndValues ...any)
 42}
 43
 44// noopLogger is a logger that does nothing
 45type noopLogger struct{}
 46
 47func (noopLogger) InfoPersist(msg string, keysAndValues ...any) {}
 48
 49// BlockFunc is a function that determines if a command should be blocked
 50type BlockFunc func(args []string) bool
 51
 52// Shell provides cross-platform shell execution with optional state persistence
 53type Shell struct {
 54	env        []string
 55	cwd        string
 56	mu         sync.Mutex
 57	logger     Logger
 58	blockFuncs []BlockFunc
 59}
 60
// Options for creating a new shell. Every field is optional; NewShell fills
// in defaults for the zero values noted below.
type Options struct {
	// WorkingDir is the initial working directory. When empty, NewShell uses
	// the process's current directory (os.Getwd).
	WorkingDir string
	// Env is the initial environment as "KEY=value" entries. When nil,
	// NewShell uses os.Environ().
	Env []string
	// Logger receives a "command finished" entry for every command that is
	// run by the interpreter. When nil, nothing is logged.
	Logger Logger
	// BlockFuncs are checked, via the interpreter's exec handler, against the
	// argv of each command before it is executed; if any returns true the
	// command is refused.
	BlockFuncs []BlockFunc
}
 68
 69// NewShell creates a new shell instance with the given options
 70func NewShell(opts *Options) *Shell {
 71	if opts == nil {
 72		opts = &Options{}
 73	}
 74
 75	cwd := opts.WorkingDir
 76	if cwd == "" {
 77		cwd, _ = os.Getwd()
 78	}
 79
 80	env := opts.Env
 81	if env == nil {
 82		env = os.Environ()
 83	}
 84
 85	logger := opts.Logger
 86	if logger == nil {
 87		logger = noopLogger{}
 88	}
 89
 90	return &Shell{
 91		cwd:        cwd,
 92		env:        env,
 93		logger:     logger,
 94		blockFuncs: opts.BlockFuncs,
 95	}
 96}
 97
 98// Exec executes a command in the shell
 99func (s *Shell) Exec(ctx context.Context, command string) (string, string, error) {
100	s.mu.Lock()
101	defer s.mu.Unlock()
102
103	return s.exec(ctx, command, nil)
104}
105
106// ExecWithStdin executes a command in the shell with provided stdin
107func (s *Shell) ExecWithStdin(ctx context.Context, command string, stdin io.Reader) (string, string, error) {
108	s.mu.Lock()
109	defer s.mu.Unlock()
110
111	return s.exec(ctx, command, stdin)
112}
113
114// ExecStream executes a command in the shell with streaming output to provided writers
115func (s *Shell) ExecStream(ctx context.Context, command string, stdout, stderr io.Writer) error {
116	s.mu.Lock()
117	defer s.mu.Unlock()
118
119	return s.execStream(ctx, command, stdout, stderr)
120}
121
122// GetWorkingDir returns the current working directory
123func (s *Shell) GetWorkingDir() string {
124	s.mu.Lock()
125	defer s.mu.Unlock()
126	return s.cwd
127}
128
129// SetWorkingDir sets the working directory
130func (s *Shell) SetWorkingDir(dir string) error {
131	s.mu.Lock()
132	defer s.mu.Unlock()
133
134	// Verify the directory exists
135	if _, err := os.Stat(dir); err != nil {
136		return fmt.Errorf("directory does not exist: %w", err)
137	}
138
139	s.cwd = dir
140	return nil
141}
142
143// GetEnv returns a copy of the environment variables
144func (s *Shell) GetEnv() []string {
145	s.mu.Lock()
146	defer s.mu.Unlock()
147
148	env := make([]string, len(s.env))
149	copy(env, s.env)
150	return env
151}
152
153// SetEnv sets an environment variable
154func (s *Shell) SetEnv(key, value string) {
155	s.mu.Lock()
156	defer s.mu.Unlock()
157
158	// Update or add the environment variable
159	keyPrefix := key + "="
160	for i, env := range s.env {
161		if strings.HasPrefix(env, keyPrefix) {
162			s.env[i] = keyPrefix + value
163			return
164		}
165	}
166	s.env = append(s.env, keyPrefix+value)
167}
168
169// SetBlockFuncs sets the command block functions for the shell
170func (s *Shell) SetBlockFuncs(blockFuncs []BlockFunc) {
171	s.mu.Lock()
172	defer s.mu.Unlock()
173	s.blockFuncs = blockFuncs
174}
175
176// CommandsBlocker creates a BlockFunc that blocks exact command matches
177func CommandsBlocker(cmds []string) BlockFunc {
178	bannedSet := make(map[string]struct{})
179	for _, cmd := range cmds {
180		bannedSet[cmd] = struct{}{}
181	}
182
183	return func(args []string) bool {
184		if len(args) == 0 {
185			return false
186		}
187		_, ok := bannedSet[args[0]]
188		return ok
189	}
190}
191
192// ArgumentsBlocker creates a BlockFunc that blocks specific subcommand
193func ArgumentsBlocker(cmd string, args []string, flags []string) BlockFunc {
194	return func(parts []string) bool {
195		if len(parts) == 0 || parts[0] != cmd {
196			return false
197		}
198
199		argParts, flagParts := splitArgsFlags(parts[1:])
200		if len(argParts) < len(args) || len(flagParts) < len(flags) {
201			return false
202		}
203
204		argsMatch := slices.Equal(argParts[:len(args)], args)
205		flagsMatch := slice.IsSubset(flags, flagParts)
206
207		return argsMatch && flagsMatch
208	}
209}
210
// splitArgsFlags partitions parts into positional arguments and flags,
// preserving input order. An element starting with "-" is a flag; for a
// flag like "--name=value" only the text before the first "=" is kept.
// A flag's value given as a separate element is classified as an argument.
// Both results are always non-nil.
func splitArgsFlags(parts []string) (args []string, flags []string) {
	args = make([]string, 0, len(parts))
	flags = make([]string, 0, len(parts))
	for _, p := range parts {
		if !strings.HasPrefix(p, "-") {
			args = append(args, p)
			continue
		}
		name, _, _ := strings.Cut(p, "=")
		flags = append(flags, name)
	}
	return args, flags
}
228
229func (s *Shell) blockHandler() func(next interp.ExecHandlerFunc) interp.ExecHandlerFunc {
230	return func(next interp.ExecHandlerFunc) interp.ExecHandlerFunc {
231		return func(ctx context.Context, args []string) error {
232			if len(args) == 0 {
233				return next(ctx, args)
234			}
235
236			for _, blockFunc := range s.blockFuncs {
237				if blockFunc(args) {
238					return fmt.Errorf("command is not allowed for security reasons: %s", strings.Join(args, " "))
239				}
240			}
241
242			return next(ctx, args)
243		}
244	}
245}
246
247// newInterp creates a new interpreter with the current shell state
248func (s *Shell) newInterp(stdin io.Reader, stdout, stderr io.Writer) (*interp.Runner, error) {
249	return interp.New(
250		interp.StdIO(stdin, stdout, stderr),
251		interp.Interactive(false),
252		interp.Env(expand.ListEnviron(s.env...)),
253		interp.Dir(s.cwd),
254		interp.ExecHandlers(s.execHandlers()...),
255	)
256}
257
258// updateShellFromRunner updates the shell from the interpreter after execution
259func (s *Shell) updateShellFromRunner(runner *interp.Runner) {
260	s.cwd = runner.Dir
261	s.env = nil
262	for name, vr := range runner.Vars {
263		s.env = append(s.env, fmt.Sprintf("%s=%s", name, vr.Str))
264	}
265}
266
267// execCommon is the shared implementation for executing commands
268func (s *Shell) execCommon(ctx context.Context, command string, stdin io.Reader, stdout, stderr io.Writer) error {
269	line, err := syntax.NewParser().Parse(strings.NewReader(command), "")
270	if err != nil {
271		return fmt.Errorf("could not parse command: %w", err)
272	}
273
274	runner, err := s.newInterp(stdin, stdout, stderr)
275	if err != nil {
276		return fmt.Errorf("could not run command: %w", err)
277	}
278
279	err = runner.Run(ctx, line)
280	s.updateShellFromRunner(runner)
281	s.logger.InfoPersist("command finished", "command", command, "err", err)
282	return err
283}
284
285// exec executes commands using a cross-platform shell interpreter.
286func (s *Shell) exec(ctx context.Context, command string, stdin io.Reader) (string, string, error) {
287	var stdout, stderr bytes.Buffer
288	err := s.execCommon(ctx, command, stdin, &stdout, &stderr)
289	return stdout.String(), stderr.String(), err
290}
291
292// execStream executes commands using POSIX shell emulation with streaming output
293func (s *Shell) execStream(ctx context.Context, command string, stdout, stderr io.Writer) error {
294	return s.execCommon(ctx, command, nil, stdout, stderr)
295}
296
297func (s *Shell) execHandlers() []func(next interp.ExecHandlerFunc) interp.ExecHandlerFunc {
298	handlers := []func(next interp.ExecHandlerFunc) interp.ExecHandlerFunc{
299		s.blockHandler(),
300	}
301	if useGoCoreUtils {
302		handlers = append(handlers, coreutils.ExecHandler)
303	}
304	return handlers
305}
306
// IsInterrupt reports whether err is, or wraps, context cancellation or a
// context deadline expiry.
func IsInterrupt(err error) bool {
	switch {
	case errors.Is(err, context.Canceled):
		return true
	case errors.Is(err, context.DeadlineExceeded):
		return true
	default:
		return false
	}
}
312
313// ExitCode extracts the exit code from an error
314func ExitCode(err error) int {
315	if err == nil {
316		return 0
317	}
318	var exitErr interp.ExitStatus
319	if errors.As(err, &exitErr) {
320		return int(exitErr)
321	}
322	return 1
323}