shell.go

  1// Package shell provides cross-platform shell execution capabilities.
  2//
// This package provides Shell instances for executing commands with their own
// working directory and environment. The working directory and variables left
// behind by one Exec call are carried over to the next call on the same Shell.
  5//
  6// WINDOWS COMPATIBILITY:
  7// This implementation provides POSIX shell emulation (mvdan.cc/sh/v3) even on
  8// Windows. Commands should use forward slashes (/) as path separators to work
  9// correctly on all platforms.
 10package shell
 11
 12import (
 13	"bytes"
 14	"context"
 15	"errors"
 16	"fmt"
 17	"os"
 18	"slices"
 19	"strings"
 20	"sync"
 21
 22	"github.com/charmbracelet/x/exp/slice"
 23	"mvdan.cc/sh/moreinterp/coreutils"
 24	"mvdan.cc/sh/v3/expand"
 25	"mvdan.cc/sh/v3/interp"
 26	"mvdan.cc/sh/v3/syntax"
 27)
 28
// ShellType identifies a kind of shell.
//
// NOTE(review): in the code visible here nothing reads a ShellType; Exec always
// runs commands through the POSIX emulator. Confirm whether other files use it.
type ShellType int

const (
	// ShellTypePOSIX is the zero value of ShellType.
	ShellTypePOSIX ShellType = iota
	// ShellTypeCmd presumably refers to the Windows cmd.exe shell; verify usage.
	ShellTypeCmd
	// ShellTypePowerShell presumably refers to PowerShell; verify usage.
	ShellTypePowerShell
)
 37
// Logger is the optional logging hook used by Shell. When Options.Logger is
// nil, NewShell substitutes a logger that discards everything.
type Logger interface {
	// InfoPersist records msg followed by alternating key/value pairs, e.g.
	// InfoPersist("POSIX command finished", "command", cmd, "err", err).
	InfoPersist(msg string, keysAndValues ...any)
}
 42
// noopLogger is the Logger installed when the caller does not supply one; it
// discards everything passed to it.
type noopLogger struct{}

// InfoPersist implements Logger and intentionally does nothing.
func (noopLogger) InfoPersist(msg string, keysAndValues ...any) {}
 47
// BlockFunc decides whether a command may run. It receives the argument vector
// of the command about to be executed (args[0] is the command name) and
// returns true when that command must be blocked.
type BlockFunc func(args []string) bool
 50
// Shell executes commands through a POSIX shell emulator while keeping its own
// working directory and environment, both of which are updated after each
// execution so that later commands observe the changes.
type Shell struct {
	// env holds the environment as "KEY=value" entries.
	env []string
	// cwd is the directory in which the next command runs.
	cwd string
	// mu is held by every exported method while it reads or writes the
	// fields of this struct, including for the full duration of Exec.
	mu sync.Mutex
	// logger is never nil; NewShell installs a no-op logger by default.
	logger Logger
	// blockFuncs are consulted before each command is executed; any one of
	// them returning true rejects the command.
	blockFuncs []BlockFunc
}
 59
// Options configures a Shell created by NewShell. Every field is optional.
type Options struct {
	// WorkingDir is the initial working directory. When empty, NewShell uses
	// the working directory of the current process.
	WorkingDir string
	// Env is the initial environment as "KEY=value" entries. When nil,
	// NewShell uses the environment of the current process.
	Env []string
	// Logger receives a message after each executed command. When nil,
	// nothing is logged.
	Logger Logger
	// BlockFuncs are checked before each command is executed; a command is
	// rejected as soon as one of them returns true.
	BlockFuncs []BlockFunc
}
 67
 68// NewShell creates a new shell instance with the given options
 69func NewShell(opts *Options) *Shell {
 70	if opts == nil {
 71		opts = &Options{}
 72	}
 73
 74	cwd := opts.WorkingDir
 75	if cwd == "" {
 76		cwd, _ = os.Getwd()
 77	}
 78
 79	env := opts.Env
 80	if env == nil {
 81		env = os.Environ()
 82	}
 83
 84	logger := opts.Logger
 85	if logger == nil {
 86		logger = noopLogger{}
 87	}
 88
 89	return &Shell{
 90		cwd:        cwd,
 91		env:        env,
 92		logger:     logger,
 93		blockFuncs: opts.BlockFuncs,
 94	}
 95}
 96
 97// Exec executes a command in the shell
 98func (s *Shell) Exec(ctx context.Context, command string) (string, string, error) {
 99	s.mu.Lock()
100	defer s.mu.Unlock()
101
102	return s.execPOSIX(ctx, command)
103}
104
105// GetWorkingDir returns the current working directory
106func (s *Shell) GetWorkingDir() string {
107	s.mu.Lock()
108	defer s.mu.Unlock()
109	return s.cwd
110}
111
112// SetWorkingDir sets the working directory
113func (s *Shell) SetWorkingDir(dir string) error {
114	s.mu.Lock()
115	defer s.mu.Unlock()
116
117	// Verify the directory exists
118	if _, err := os.Stat(dir); err != nil {
119		return fmt.Errorf("directory does not exist: %w", err)
120	}
121
122	s.cwd = dir
123	return nil
124}
125
126// GetEnv returns a copy of the environment variables
127func (s *Shell) GetEnv() []string {
128	s.mu.Lock()
129	defer s.mu.Unlock()
130
131	env := make([]string, len(s.env))
132	copy(env, s.env)
133	return env
134}
135
136// SetEnv sets an environment variable
137func (s *Shell) SetEnv(key, value string) {
138	s.mu.Lock()
139	defer s.mu.Unlock()
140
141	// Update or add the environment variable
142	keyPrefix := key + "="
143	for i, env := range s.env {
144		if strings.HasPrefix(env, keyPrefix) {
145			s.env[i] = keyPrefix + value
146			return
147		}
148	}
149	s.env = append(s.env, keyPrefix+value)
150}
151
152// SetBlockFuncs sets the command block functions for the shell
153func (s *Shell) SetBlockFuncs(blockFuncs []BlockFunc) {
154	s.mu.Lock()
155	defer s.mu.Unlock()
156	s.blockFuncs = blockFuncs
157}
158
159// CommandsBlocker creates a BlockFunc that blocks exact command matches
160func CommandsBlocker(cmds []string) BlockFunc {
161	bannedSet := make(map[string]struct{})
162	for _, cmd := range cmds {
163		bannedSet[cmd] = struct{}{}
164	}
165
166	return func(args []string) bool {
167		if len(args) == 0 {
168			return false
169		}
170		_, ok := bannedSet[args[0]]
171		return ok
172	}
173}
174
175// ArgumentsBlocker creates a BlockFunc that blocks specific subcommand
176func ArgumentsBlocker(cmd string, args []string, flags []string) BlockFunc {
177	return func(parts []string) bool {
178		if len(parts) == 0 || parts[0] != cmd {
179			return false
180		}
181
182		argParts, flagParts := splitArgsFlags(parts[1:])
183		if len(argParts) < len(args) || len(flagParts) < len(flags) {
184			return false
185		}
186
187		argsMatch := slices.Equal(argParts[:len(args)], args)
188		flagsMatch := slice.IsSubset(flags, flagParts)
189
190		return argsMatch && flagsMatch
191	}
192}
193
// splitArgsFlags partitions parts into plain arguments and flags, preserving
// the relative order within each group. Anything starting with "-" counts as a
// flag and is reduced to the text before its first "=" (so "-o=out" yields
// "-o"); everything else is returned unchanged as an argument. Both results
// are non-nil even when empty.
func splitArgsFlags(parts []string) (args []string, flags []string) {
	args = make([]string, 0, len(parts))
	flags = make([]string, 0, len(parts))
	for _, tok := range parts {
		if !strings.HasPrefix(tok, "-") {
			args = append(args, tok)
			continue
		}
		// Cut returns the whole token as name when there is no "=".
		name, _, _ := strings.Cut(tok, "=")
		flags = append(flags, name)
	}
	return args, flags
}
211
212func (s *Shell) blockHandler() func(next interp.ExecHandlerFunc) interp.ExecHandlerFunc {
213	return func(next interp.ExecHandlerFunc) interp.ExecHandlerFunc {
214		return func(ctx context.Context, args []string) error {
215			if len(args) == 0 {
216				return next(ctx, args)
217			}
218
219			for _, blockFunc := range s.blockFuncs {
220				if blockFunc(args) {
221					return fmt.Errorf("command is not allowed for security reasons: %s", strings.Join(args, " "))
222				}
223			}
224
225			return next(ctx, args)
226		}
227	}
228}
229
230// execPOSIX executes commands using POSIX shell emulation (cross-platform)
231func (s *Shell) execPOSIX(ctx context.Context, command string) (string, string, error) {
232	line, err := syntax.NewParser().Parse(strings.NewReader(command), "")
233	if err != nil {
234		return "", "", fmt.Errorf("could not parse command: %w", err)
235	}
236
237	var stdout, stderr bytes.Buffer
238	runner, err := interp.New(
239		interp.StdIO(nil, &stdout, &stderr),
240		interp.Interactive(false),
241		interp.Env(expand.ListEnviron(s.env...)),
242		interp.Dir(s.cwd),
243		interp.ExecHandlers(s.blockHandler(), coreutils.ExecHandler),
244	)
245	if err != nil {
246		return "", "", fmt.Errorf("could not run command: %w", err)
247	}
248
249	err = runner.Run(ctx, line)
250	s.cwd = runner.Dir
251	s.env = []string{}
252	for name, vr := range runner.Vars {
253		s.env = append(s.env, fmt.Sprintf("%s=%s", name, vr.Str))
254	}
255	s.logger.InfoPersist("POSIX command finished", "command", command, "err", err)
256	return stdout.String(), stderr.String(), err
257}
258
// IsInterrupt reports whether err is, or wraps, a context cancellation or a
// context deadline expiry. A nil error is not an interrupt.
func IsInterrupt(err error) bool {
	switch {
	case errors.Is(err, context.Canceled):
		return true
	case errors.Is(err, context.DeadlineExceeded):
		return true
	default:
		return false
	}
}
264
265// ExitCode extracts the exit code from an error
266func ExitCode(err error) int {
267	if err == nil {
268		return 0
269	}
270	var exitErr interp.ExitStatus
271	if errors.As(err, &exitErr) {
272		return int(exitErr)
273	}
274	return 1
275}