shell.go

  1// Package shell provides cross-platform shell execution capabilities.
  2//
  3// This package offers two main types:
  4// - Shell: A general-purpose shell executor for one-off or managed commands
  5// - PersistentShell: A singleton shell that maintains state across the application
  6//
  7// WINDOWS COMPATIBILITY:
// This implementation uses POSIX shell emulation (mvdan.cc/sh/v3) on every
// platform, including Windows. Some caution has to be taken: commands should
// use forward slashes (/) as path separators to work, even on Windows.
 11package shell
 12
 13import (
 14	"bytes"
 15	"context"
 16	"errors"
 17	"fmt"
 18	"os"
 19	"strings"
 20	"sync"
 21
 22	"mvdan.cc/sh/moreinterp/coreutils"
 23	"mvdan.cc/sh/v3/expand"
 24	"mvdan.cc/sh/v3/interp"
 25	"mvdan.cc/sh/v3/syntax"
 26)
 27
 28// ShellType represents the type of shell to use
 29type ShellType int
 30
 31const (
 32	ShellTypePOSIX ShellType = iota
 33	ShellTypeCmd
 34	ShellTypePowerShell
 35)
 36
// Logger is the optional logging sink used by Shell. After every command run
// through the POSIX interpreter, the shell calls InfoPersist with a message
// followed by alternating key/value pairs ("command", ..., "err", ...).
type Logger interface {
	InfoPersist(msg string, keysAndValues ...any)
}
 41
 42// noopLogger is a logger that does nothing
 43type noopLogger struct{}
 44
 45func (noopLogger) InfoPersist(msg string, keysAndValues ...any) {}
 46
// BlockFunc decides whether a command must be refused. It receives the
// command's argument list, where args[0] is the command name, and returns
// true to block the command from running.
type BlockFunc func(args []string) bool
 49
// Shell runs commands through the mvdan.cc/sh interpreter and carries the
// working directory and environment from one Exec call to the next.
// Construct it with NewShell; the methods guard all fields with mu.
type Shell struct {
	// env holds "KEY=value" pairs fed to each run. After every run it is
	// replaced with the interpreter's variables.
	env        []string
	// cwd is the directory each run starts in. It is updated from the
	// runner afterwards, so a `cd` persists.
	cwd        string
	// mu serializes Exec and the accessor methods.
	mu         sync.Mutex
	// logger gets one record per executed command (never nil when the
	// Shell is built by NewShell).
	logger     Logger
	// blockFuncs are consulted by the exec-handler middleware; any one
	// returning true refuses the command.
	blockFuncs []BlockFunc
}
 58
// Options configures NewShell. Every field is optional; zero values select
// NewShell's defaults.
type Options struct {
	// WorkingDir is the starting directory; empty means the process's
	// current directory.
	WorkingDir string
	// Env is the starting environment as "KEY=value" pairs. nil means
	// os.Environ(); a non-nil empty slice starts with no variables.
	Env        []string
	// Logger receives per-command log records; nil discards them.
	Logger     Logger
	// BlockFuncs are checked before commands are executed; see BlockFunc.
	BlockFuncs []BlockFunc
}
 66
 67// NewShell creates a new shell instance with the given options
 68func NewShell(opts *Options) *Shell {
 69	if opts == nil {
 70		opts = &Options{}
 71	}
 72
 73	cwd := opts.WorkingDir
 74	if cwd == "" {
 75		cwd, _ = os.Getwd()
 76	}
 77
 78	env := opts.Env
 79	if env == nil {
 80		env = os.Environ()
 81	}
 82
 83	logger := opts.Logger
 84	if logger == nil {
 85		logger = noopLogger{}
 86	}
 87
 88	return &Shell{
 89		cwd:        cwd,
 90		env:        env,
 91		logger:     logger,
 92		blockFuncs: opts.BlockFuncs,
 93	}
 94}
 95
 96// Exec executes a command in the shell
 97func (s *Shell) Exec(ctx context.Context, command string) (string, string, error) {
 98	s.mu.Lock()
 99	defer s.mu.Unlock()
100
101	return s.execPOSIX(ctx, command)
102}
103
104// GetWorkingDir returns the current working directory
105func (s *Shell) GetWorkingDir() string {
106	s.mu.Lock()
107	defer s.mu.Unlock()
108	return s.cwd
109}
110
111// SetWorkingDir sets the working directory
112func (s *Shell) SetWorkingDir(dir string) error {
113	s.mu.Lock()
114	defer s.mu.Unlock()
115
116	// Verify the directory exists
117	if _, err := os.Stat(dir); err != nil {
118		return fmt.Errorf("directory does not exist: %w", err)
119	}
120
121	s.cwd = dir
122	return nil
123}
124
125// GetEnv returns a copy of the environment variables
126func (s *Shell) GetEnv() []string {
127	s.mu.Lock()
128	defer s.mu.Unlock()
129
130	env := make([]string, len(s.env))
131	copy(env, s.env)
132	return env
133}
134
135// SetEnv sets an environment variable
136func (s *Shell) SetEnv(key, value string) {
137	s.mu.Lock()
138	defer s.mu.Unlock()
139
140	// Update or add the environment variable
141	keyPrefix := key + "="
142	for i, env := range s.env {
143		if strings.HasPrefix(env, keyPrefix) {
144			s.env[i] = keyPrefix + value
145			return
146		}
147	}
148	s.env = append(s.env, keyPrefix+value)
149}
150
151// SetBlockFuncs sets the command block functions for the shell
152func (s *Shell) SetBlockFuncs(blockFuncs []BlockFunc) {
153	s.mu.Lock()
154	defer s.mu.Unlock()
155	s.blockFuncs = blockFuncs
156}
157
158// CommandsBlocker creates a BlockFunc that blocks exact command matches
159func CommandsBlocker(bannedCommands []string) BlockFunc {
160	bannedSet := make(map[string]struct{})
161	for _, cmd := range bannedCommands {
162		bannedSet[cmd] = struct{}{}
163	}
164
165	return func(args []string) bool {
166		if len(args) == 0 {
167			return false
168		}
169		_, ok := bannedSet[args[0]]
170		return ok
171	}
172}
173
174// ArgumentsBlocker creates a BlockFunc that blocks specific subcommands
175func ArgumentsBlocker(blockedSubCommands [][]string) BlockFunc {
176	return func(args []string) bool {
177		for _, blocked := range blockedSubCommands {
178			if len(args) >= len(blocked) {
179				match := true
180				for i, part := range blocked {
181					if args[i] != part {
182						match = false
183						break
184					}
185				}
186				if match {
187					return true
188				}
189			}
190		}
191		return false
192	}
193}
194
195func (s *Shell) blockHandler() func(next interp.ExecHandlerFunc) interp.ExecHandlerFunc {
196	return func(next interp.ExecHandlerFunc) interp.ExecHandlerFunc {
197		return func(ctx context.Context, args []string) error {
198			if len(args) == 0 {
199				return next(ctx, args)
200			}
201
202			for _, blockFunc := range s.blockFuncs {
203				if blockFunc(args) {
204					return fmt.Errorf("command is not allowed for security reasons: %s", strings.Join(args, " "))
205				}
206			}
207
208			return next(ctx, args)
209		}
210	}
211}
212
213// execPOSIX executes commands using POSIX shell emulation (cross-platform)
214func (s *Shell) execPOSIX(ctx context.Context, command string) (string, string, error) {
215	line, err := syntax.NewParser().Parse(strings.NewReader(command), "")
216	if err != nil {
217		return "", "", fmt.Errorf("could not parse command: %w", err)
218	}
219
220	var stdout, stderr bytes.Buffer
221	runner, err := interp.New(
222		interp.StdIO(nil, &stdout, &stderr),
223		interp.Interactive(false),
224		interp.Env(expand.ListEnviron(s.env...)),
225		interp.Dir(s.cwd),
226		interp.ExecHandlers(s.blockHandler(), coreutils.ExecHandler),
227	)
228	if err != nil {
229		return "", "", fmt.Errorf("could not run command: %w", err)
230	}
231
232	err = runner.Run(ctx, line)
233	s.cwd = runner.Dir
234	s.env = []string{}
235	for name, vr := range runner.Vars {
236		s.env = append(s.env, fmt.Sprintf("%s=%s", name, vr.Str))
237	}
238	s.logger.InfoPersist("POSIX command finished", "command", command, "err", err)
239	return stdout.String(), stderr.String(), err
240}
241
// IsInterrupt reports whether err, or any error it wraps, is
// context.Canceled or context.DeadlineExceeded.
func IsInterrupt(err error) bool {
	switch {
	case errors.Is(err, context.Canceled):
		return true
	case errors.Is(err, context.DeadlineExceeded):
		return true
	default:
		return false
	}
}
247
248// ExitCode extracts the exit code from an error
249func ExitCode(err error) int {
250	if err == nil {
251		return 0
252	}
253	var exitErr interp.ExitStatus
254	if errors.As(err, &exitErr) {
255		return int(exitErr)
256	}
257	return 1
258}