shell.go

  1// Package shell provides cross-platform shell execution capabilities.
  2//
// This package provides Shell instances for executing commands, each with its
// own working directory and environment. A Shell carries its working directory
// and environment over from one execution to the next.
  5//
  6// WINDOWS COMPATIBILITY:
  7// This implementation provides POSIX shell emulation (mvdan.cc/sh/v3) even on
  8// Windows. Commands should use forward slashes (/) as path separators to work
  9// correctly on all platforms.
 10package shell
 11
 12import (
 13	"bytes"
 14	"context"
 15	"errors"
 16	"fmt"
 17	"io"
 18	"os"
 19	"slices"
 20	"strings"
 21	"sync"
 22
 23	"github.com/charmbracelet/x/exp/slice"
 24	"mvdan.cc/sh/moreinterp/coreutils"
 25	"mvdan.cc/sh/v3/expand"
 26	"mvdan.cc/sh/v3/interp"
 27	"mvdan.cc/sh/v3/syntax"
 28)
 29
 30// ShellType represents the type of shell to use
 31type ShellType int
 32
 33const (
 34	ShellTypePOSIX ShellType = iota
 35	ShellTypeCmd
 36	ShellTypePowerShell
 37)
 38
 39// Logger interface for optional logging
 40type Logger interface {
 41	InfoPersist(msg string, keysAndValues ...any)
 42}
 43
 44// noopLogger is a logger that does nothing
 45type noopLogger struct{}
 46
 47func (noopLogger) InfoPersist(msg string, keysAndValues ...any) {}
 48
 49// BlockFunc is a function that determines if a command should be blocked
 50type BlockFunc func(args []string) bool
 51
 52// Shell provides cross-platform shell execution with optional state persistence
 53type Shell struct {
 54	env        []string
 55	cwd        string
 56	mu         sync.Mutex
 57	logger     Logger
 58	blockFuncs []BlockFunc
 59}
 60
 61// Options for creating a new shell
 62type Options struct {
 63	WorkingDir string
 64	Env        []string
 65	Logger     Logger
 66	BlockFuncs []BlockFunc
 67}
 68
 69// NewShell creates a new shell instance with the given options
 70func NewShell(opts *Options) *Shell {
 71	if opts == nil {
 72		opts = &Options{}
 73	}
 74
 75	cwd := opts.WorkingDir
 76	if cwd == "" {
 77		cwd, _ = os.Getwd()
 78	}
 79
 80	env := opts.Env
 81	if env == nil {
 82		env = os.Environ()
 83	}
 84
 85	logger := opts.Logger
 86	if logger == nil {
 87		logger = noopLogger{}
 88	}
 89
 90	return &Shell{
 91		cwd:        cwd,
 92		env:        env,
 93		logger:     logger,
 94		blockFuncs: opts.BlockFuncs,
 95	}
 96}
 97
 98// Exec executes a command in the shell
 99func (s *Shell) Exec(ctx context.Context, command string) (string, string, error) {
100	s.mu.Lock()
101	defer s.mu.Unlock()
102
103	return s.exec(ctx, command)
104}
105
106// ExecStream executes a command in the shell with streaming output to provided writers
107func (s *Shell) ExecStream(ctx context.Context, command string, stdout, stderr io.Writer) error {
108	s.mu.Lock()
109	defer s.mu.Unlock()
110
111	return s.execPOSIXStream(ctx, command, stdout, stderr)
112}
113
114// GetWorkingDir returns the current working directory
115func (s *Shell) GetWorkingDir() string {
116	s.mu.Lock()
117	defer s.mu.Unlock()
118	return s.cwd
119}
120
121// SetWorkingDir sets the working directory
122func (s *Shell) SetWorkingDir(dir string) error {
123	s.mu.Lock()
124	defer s.mu.Unlock()
125
126	// Verify the directory exists
127	if _, err := os.Stat(dir); err != nil {
128		return fmt.Errorf("directory does not exist: %w", err)
129	}
130
131	s.cwd = dir
132	return nil
133}
134
135// GetEnv returns a copy of the environment variables
136func (s *Shell) GetEnv() []string {
137	s.mu.Lock()
138	defer s.mu.Unlock()
139
140	env := make([]string, len(s.env))
141	copy(env, s.env)
142	return env
143}
144
145// SetEnv sets an environment variable
146func (s *Shell) SetEnv(key, value string) {
147	s.mu.Lock()
148	defer s.mu.Unlock()
149
150	// Update or add the environment variable
151	keyPrefix := key + "="
152	for i, env := range s.env {
153		if strings.HasPrefix(env, keyPrefix) {
154			s.env[i] = keyPrefix + value
155			return
156		}
157	}
158	s.env = append(s.env, keyPrefix+value)
159}
160
161// SetBlockFuncs sets the command block functions for the shell
162func (s *Shell) SetBlockFuncs(blockFuncs []BlockFunc) {
163	s.mu.Lock()
164	defer s.mu.Unlock()
165	s.blockFuncs = blockFuncs
166}
167
168// CommandsBlocker creates a BlockFunc that blocks exact command matches
169func CommandsBlocker(cmds []string) BlockFunc {
170	bannedSet := make(map[string]struct{})
171	for _, cmd := range cmds {
172		bannedSet[cmd] = struct{}{}
173	}
174
175	return func(args []string) bool {
176		if len(args) == 0 {
177			return false
178		}
179		_, ok := bannedSet[args[0]]
180		return ok
181	}
182}
183
184// ArgumentsBlocker creates a BlockFunc that blocks specific subcommand
185func ArgumentsBlocker(cmd string, args []string, flags []string) BlockFunc {
186	return func(parts []string) bool {
187		if len(parts) == 0 || parts[0] != cmd {
188			return false
189		}
190
191		argParts, flagParts := splitArgsFlags(parts[1:])
192		if len(argParts) < len(args) || len(flagParts) < len(flags) {
193			return false
194		}
195
196		argsMatch := slices.Equal(argParts[:len(args)], args)
197		flagsMatch := slice.IsSubset(flags, flagParts)
198
199		return argsMatch && flagsMatch
200	}
201}
202
// splitArgsFlags partitions parts into positional arguments and flags. It
// keeps the relative order within each group.
//
// Any token that begins with "-" is a flag, including a lone "-". For a flag
// of the form "name=value", only the text before the first "=" is kept. Both
// results are non-nil, even when parts is empty.
func splitArgsFlags(parts []string) (args []string, flags []string) {
	args = make([]string, 0, len(parts))
	flags = make([]string, 0, len(parts))
	for _, tok := range parts {
		if !strings.HasPrefix(tok, "-") {
			args = append(args, tok)
			continue
		}
		name, _, _ := strings.Cut(tok, "=")
		flags = append(flags, name)
	}
	return args, flags
}
220
221func (s *Shell) blockHandler() func(next interp.ExecHandlerFunc) interp.ExecHandlerFunc {
222	return func(next interp.ExecHandlerFunc) interp.ExecHandlerFunc {
223		return func(ctx context.Context, args []string) error {
224			if len(args) == 0 {
225				return next(ctx, args)
226			}
227
228			for _, blockFunc := range s.blockFuncs {
229				if blockFunc(args) {
230					return fmt.Errorf("command is not allowed for security reasons: %s", strings.Join(args, " "))
231				}
232			}
233
234			return next(ctx, args)
235		}
236	}
237}
238
239// exec executes commands using a cross-platform shell interpreter.
240func (s *Shell) exec(ctx context.Context, command string) (string, string, error) {
241	line, err := syntax.NewParser().Parse(strings.NewReader(command), "")
242	if err != nil {
243		return "", "", fmt.Errorf("could not parse command: %w", err)
244	}
245
246	var stdout, stderr bytes.Buffer
247	runner, err := interp.New(
248		interp.StdIO(nil, &stdout, &stderr),
249		interp.Interactive(false),
250		interp.Env(expand.ListEnviron(s.env...)),
251		interp.Dir(s.cwd),
252		interp.ExecHandlers(s.execHandlers()...),
253	)
254	if err != nil {
255		return "", "", fmt.Errorf("could not run command: %w", err)
256	}
257
258	err = runner.Run(ctx, line)
259	s.cwd = runner.Dir
260	for name, vr := range runner.Vars {
261		s.env = append(s.env, fmt.Sprintf("%s=%s", name, vr.Str))
262	}
263	s.logger.InfoPersist("command finished", "command", command, "err", err)
264	return stdout.String(), stderr.String(), err
265}
266
267// execPOSIXStream executes commands using POSIX shell emulation with streaming output
268func (s *Shell) execPOSIXStream(ctx context.Context, command string, stdout, stderr io.Writer) error {
269	line, err := syntax.NewParser().Parse(strings.NewReader(command), "")
270	if err != nil {
271		return fmt.Errorf("could not parse command: %w", err)
272	}
273
274	runner, err := interp.New(
275		interp.StdIO(nil, stdout, stderr),
276		interp.Interactive(false),
277		interp.Env(expand.ListEnviron(s.env...)),
278		interp.Dir(s.cwd),
279		interp.ExecHandlers(s.execHandlers()...),
280	)
281	if err != nil {
282		return fmt.Errorf("could not run command: %w", err)
283	}
284
285	err = runner.Run(ctx, line)
286	s.cwd = runner.Dir
287	s.env = []string{}
288	for name, vr := range runner.Vars {
289		s.env = append(s.env, fmt.Sprintf("%s=%s", name, vr.Str))
290	}
291	s.logger.InfoPersist("POSIX command finished", "command", command, "err", err)
292	return err
293}
294
295func (s *Shell) execHandlers() []func(next interp.ExecHandlerFunc) interp.ExecHandlerFunc {
296	handlers := []func(next interp.ExecHandlerFunc) interp.ExecHandlerFunc{
297		s.blockHandler(),
298	}
299	if useGoCoreUtils {
300		handlers = append(handlers, coreutils.ExecHandler)
301	}
302	return handlers
303}
304
// IsInterrupt reports whether err is, or wraps, context.Canceled or
// context.DeadlineExceeded. A nil error is not an interrupt.
func IsInterrupt(err error) bool {
	switch {
	case errors.Is(err, context.Canceled), errors.Is(err, context.DeadlineExceeded):
		return true
	default:
		return false
	}
}
310
311// ExitCode extracts the exit code from an error
312func ExitCode(err error) int {
313	if err == nil {
314		return 0
315	}
316	var exitErr interp.ExitStatus
317	if errors.As(err, &exitErr) {
318		return int(exitErr)
319	}
320	return 1
321}