shell.go

  1// Package shell provides cross-platform shell execution capabilities.
  2//
  3// This package offers two main types:
  4// - Shell: A general-purpose shell executor for one-off or managed commands
  5// - PersistentShell: A singleton shell that maintains state across the application
  6//
  7// WINDOWS COMPATIBILITY:
// This implementation uses POSIX shell emulation (mvdan.cc/sh/v3) on every
// platform, including Windows. Take care that commands use forward slashes (/)
// as path separators; they are needed even on Windows.
 11package shell
 12
 13import (
 14	"bytes"
 15	"context"
 16	"errors"
 17	"fmt"
 18	"io"
 19	"os"
 20	"slices"
 21	"strings"
 22	"sync"
 23
 24	"github.com/charmbracelet/x/exp/slice"
 25	"mvdan.cc/sh/moreinterp/coreutils"
 26	"mvdan.cc/sh/v3/expand"
 27	"mvdan.cc/sh/v3/interp"
 28	"mvdan.cc/sh/v3/syntax"
 29)
 30
 31// ShellType represents the type of shell to use
 32type ShellType int
 33
 34const (
 35	ShellTypePOSIX ShellType = iota
 36	ShellTypeCmd
 37	ShellTypePowerShell
 38)
 39
 40// Logger interface for optional logging
 41type Logger interface {
 42	InfoPersist(msg string, keysAndValues ...any)
 43}
 44
 45// noopLogger is a logger that does nothing
 46type noopLogger struct{}
 47
 48func (noopLogger) InfoPersist(msg string, keysAndValues ...any) {}
 49
 50// BlockFunc is a function that determines if a command should be blocked
 51type BlockFunc func(args []string) bool
 52
// Shell runs commands through the mvdan.cc/sh POSIX interpreter and carries
// the working directory and environment over from one command to the next.
// Every exported method locks mu, so a Shell may be shared between
// goroutines; commands run one at a time.
//
// Construct it with NewShell: the zero value has a nil logger, and execPOSIX
// calls logger.InfoPersist unconditionally after every run.
type Shell struct {
	// env is the KEY=VALUE environment given to the next command. It is
	// rebuilt from the interpreter's variables after every run.
	env []string
	// cwd is the directory the next command starts in; it is updated from
	// the interpreter after every run.
	cwd string
	// mu guards env, cwd and blockFuncs and is held for the whole of a
	// command's execution.
	mu sync.Mutex
	// logger receives a message each time the interpreter finishes a command.
	logger Logger
	// blockFuncs are checked against every command the interpreter hands to
	// its exec handlers; if any returns true the command is refused.
	blockFuncs []BlockFunc
}
 61
// Options configures NewShell. Every field is optional; NewShell also
// accepts a nil *Options.
type Options struct {
	// WorkingDir is the initial working directory. Empty means the process's
	// current directory (os.Getwd).
	WorkingDir string
	// Env is the initial environment as KEY=VALUE strings. nil means
	// os.Environ(); a non-nil empty slice starts with an empty environment.
	Env []string
	// Logger receives command-completion messages. nil disables logging.
	Logger Logger
	// BlockFuncs decide which commands the shell refuses to run.
	BlockFuncs []BlockFunc
}
 69
 70// NewShell creates a new shell instance with the given options
 71func NewShell(opts *Options) *Shell {
 72	if opts == nil {
 73		opts = &Options{}
 74	}
 75
 76	cwd := opts.WorkingDir
 77	if cwd == "" {
 78		cwd, _ = os.Getwd()
 79	}
 80
 81	env := opts.Env
 82	if env == nil {
 83		env = os.Environ()
 84	}
 85
 86	logger := opts.Logger
 87	if logger == nil {
 88		logger = noopLogger{}
 89	}
 90
 91	return &Shell{
 92		cwd:        cwd,
 93		env:        env,
 94		logger:     logger,
 95		blockFuncs: opts.BlockFuncs,
 96	}
 97}
 98
 99// Exec executes a command in the shell
100func (s *Shell) Exec(ctx context.Context, command string) (string, string, error) {
101	s.mu.Lock()
102	defer s.mu.Unlock()
103
104	return s.execPOSIX(ctx, command, nil)
105}
106
107// ExecWithStdin executes a command in the shell with the given stdin
108func (s *Shell) ExecWithStdin(ctx context.Context, command string, stdin string) (string, string, error) {
109	s.mu.Lock()
110	defer s.mu.Unlock()
111
112	return s.execPOSIX(ctx, command, strings.NewReader(stdin))
113}
114
115// GetWorkingDir returns the current working directory
116func (s *Shell) GetWorkingDir() string {
117	s.mu.Lock()
118	defer s.mu.Unlock()
119	return s.cwd
120}
121
122// SetWorkingDir sets the working directory
123func (s *Shell) SetWorkingDir(dir string) error {
124	s.mu.Lock()
125	defer s.mu.Unlock()
126
127	// Verify the directory exists
128	if _, err := os.Stat(dir); err != nil {
129		return fmt.Errorf("directory does not exist: %w", err)
130	}
131
132	s.cwd = dir
133	return nil
134}
135
136// GetEnv returns a copy of the environment variables
137func (s *Shell) GetEnv() []string {
138	s.mu.Lock()
139	defer s.mu.Unlock()
140
141	env := make([]string, len(s.env))
142	copy(env, s.env)
143	return env
144}
145
146// SetEnv sets an environment variable
147func (s *Shell) SetEnv(key, value string) {
148	s.mu.Lock()
149	defer s.mu.Unlock()
150
151	// Update or add the environment variable
152	keyPrefix := key + "="
153	for i, env := range s.env {
154		if strings.HasPrefix(env, keyPrefix) {
155			s.env[i] = keyPrefix + value
156			return
157		}
158	}
159	s.env = append(s.env, keyPrefix+value)
160}
161
162// SetBlockFuncs sets the command block functions for the shell
163func (s *Shell) SetBlockFuncs(blockFuncs []BlockFunc) {
164	s.mu.Lock()
165	defer s.mu.Unlock()
166	s.blockFuncs = blockFuncs
167}
168
169// CommandsBlocker creates a BlockFunc that blocks exact command matches
170func CommandsBlocker(cmds []string) BlockFunc {
171	bannedSet := make(map[string]struct{})
172	for _, cmd := range cmds {
173		bannedSet[cmd] = struct{}{}
174	}
175
176	return func(args []string) bool {
177		if len(args) == 0 {
178			return false
179		}
180		_, ok := bannedSet[args[0]]
181		return ok
182	}
183}
184
185// ArgumentsBlocker creates a BlockFunc that blocks specific subcommand
186func ArgumentsBlocker(cmd string, args []string, flags []string) BlockFunc {
187	return func(parts []string) bool {
188		if len(parts) == 0 || parts[0] != cmd {
189			return false
190		}
191
192		argParts, flagParts := splitArgsFlags(parts[1:])
193		if len(argParts) < len(args) || len(flagParts) < len(flags) {
194			return false
195		}
196
197		argsMatch := slices.Equal(argParts[:len(args)], args)
198		flagsMatch := slice.IsSubset(flags, flagParts)
199
200		return argsMatch && flagsMatch
201	}
202}
203
// splitArgsFlags partitions command arguments into positional arguments and
// flags. Any part starting with "-" is a flag. For "--name=value" style flags,
// only the text before the first "=" is kept. Relative order is preserved
// within each result, and both results are non-nil.
func splitArgsFlags(parts []string) (args []string, flags []string) {
	args = make([]string, 0, len(parts))
	flags = make([]string, 0, len(parts))
	for _, p := range parts {
		if !strings.HasPrefix(p, "-") {
			args = append(args, p)
			continue
		}
		name, _, _ := strings.Cut(p, "=")
		flags = append(flags, name)
	}
	return args, flags
}
221
222func (s *Shell) blockHandler() func(next interp.ExecHandlerFunc) interp.ExecHandlerFunc {
223	return func(next interp.ExecHandlerFunc) interp.ExecHandlerFunc {
224		return func(ctx context.Context, args []string) error {
225			if len(args) == 0 {
226				return next(ctx, args)
227			}
228
229			for _, blockFunc := range s.blockFuncs {
230				if blockFunc(args) {
231					return fmt.Errorf("command is not allowed for security reasons: %s", strings.Join(args, " "))
232				}
233			}
234
235			return next(ctx, args)
236		}
237	}
238}
239
240// execPOSIX executes commands using POSIX shell emulation (cross-platform)
241func (s *Shell) execPOSIX(ctx context.Context, command string, stdin *strings.Reader) (string, string, error) {
242	line, err := syntax.NewParser().Parse(strings.NewReader(command), "")
243	if err != nil {
244		return "", "", fmt.Errorf("could not parse command: %w", err)
245	}
246
247	var stdout, stderr bytes.Buffer
248	var stdinReader io.Reader
249	if stdin != nil {
250		stdinReader = stdin
251	}
252	runner, err := interp.New(
253		interp.StdIO(stdinReader, &stdout, &stderr),
254		interp.Interactive(false),
255		interp.Env(expand.ListEnviron(s.env...)),
256		interp.Dir(s.cwd),
257		interp.ExecHandlers(s.blockHandler(), coreutils.ExecHandler),
258	)
259	if err != nil {
260		return "", "", fmt.Errorf("could not run command: %w", err)
261	}
262
263	err = runner.Run(ctx, line)
264	s.cwd = runner.Dir
265	s.env = []string{}
266	for name, vr := range runner.Vars {
267		s.env = append(s.env, fmt.Sprintf("%s=%s", name, vr.Str))
268	}
269	s.logger.InfoPersist("POSIX command finished", "command", command, "err", err)
270	return stdout.String(), stderr.String(), err
271}
272
// IsInterrupt reports whether err, or any error it wraps, is
// context.Canceled or context.DeadlineExceeded.
func IsInterrupt(err error) bool {
	for _, target := range [...]error{context.Canceled, context.DeadlineExceeded} {
		if errors.Is(err, target) {
			return true
		}
	}
	return false
}
278
279// ExitCode extracts the exit code from an error
280func ExitCode(err error) int {
281	if err == nil {
282		return 0
283	}
284	var exitErr interp.ExitStatus
285	if errors.As(err, &exitErr) {
286		return int(exitErr)
287	}
288	return 1
289}