shell.go

  1// Package shell provides cross-platform shell execution capabilities.
  2//
  3// This package offers two main types:
  4// - Shell: A general-purpose shell executor for one-off or managed commands
  5// - PersistentShell: A singleton shell that maintains state across the application
  6//
  7// WINDOWS COMPATIBILITY:
// This implementation provides POSIX shell emulation (mvdan.cc/sh/v3) on every
// platform, including Windows. Some caution is required: commands must use
// forward slashes (/) as path separators to work, even on Windows.
 11package shell
 12
 13import (
 14	"bytes"
 15	"context"
 16	"errors"
 17	"fmt"
 18	"os"
 19	"strings"
 20	"sync"
 21
 22	"mvdan.cc/sh/v3/expand"
 23	"mvdan.cc/sh/v3/interp"
 24	"mvdan.cc/sh/v3/syntax"
 25)
 26
 27// ShellType represents the type of shell to use
 28type ShellType int
 29
 30const (
 31	ShellTypePOSIX ShellType = iota
 32	ShellTypeCmd
 33	ShellTypePowerShell
 34)
 35
 36// Logger interface for optional logging
 37type Logger interface {
 38	InfoPersist(msg string, keysAndValues ...interface{})
 39}
 40
 41// noopLogger is a logger that does nothing
 42type noopLogger struct{}
 43
 44func (noopLogger) InfoPersist(msg string, keysAndValues ...interface{}) {}
 45
 46// BlockFunc is a function that determines if a command should be blocked
 47type BlockFunc func(args []string) bool
 48
 49// Shell provides cross-platform shell execution with optional state persistence
 50type Shell struct {
 51	env        []string
 52	cwd        string
 53	mu         sync.Mutex
 54	logger     Logger
 55	blockFuncs []BlockFunc
 56}
 57
 58// Options for creating a new shell
 59type Options struct {
 60	WorkingDir string
 61	Env        []string
 62	Logger     Logger
 63	BlockFuncs []BlockFunc
 64}
 65
 66// NewShell creates a new shell instance with the given options
 67func NewShell(opts *Options) *Shell {
 68	if opts == nil {
 69		opts = &Options{}
 70	}
 71
 72	cwd := opts.WorkingDir
 73	if cwd == "" {
 74		cwd, _ = os.Getwd()
 75	}
 76
 77	env := opts.Env
 78	if env == nil {
 79		env = os.Environ()
 80	}
 81
 82	logger := opts.Logger
 83	if logger == nil {
 84		logger = noopLogger{}
 85	}
 86
 87	return &Shell{
 88		cwd:        cwd,
 89		env:        env,
 90		logger:     logger,
 91		blockFuncs: opts.BlockFuncs,
 92	}
 93}
 94
 95// Exec executes a command in the shell
 96func (s *Shell) Exec(ctx context.Context, command string) (string, string, error) {
 97	s.mu.Lock()
 98	defer s.mu.Unlock()
 99
100	return s.execPOSIX(ctx, command)
101}
102
103// GetWorkingDir returns the current working directory
104func (s *Shell) GetWorkingDir() string {
105	s.mu.Lock()
106	defer s.mu.Unlock()
107	return s.cwd
108}
109
110// SetWorkingDir sets the working directory
111func (s *Shell) SetWorkingDir(dir string) error {
112	s.mu.Lock()
113	defer s.mu.Unlock()
114
115	// Verify the directory exists
116	if _, err := os.Stat(dir); err != nil {
117		return fmt.Errorf("directory does not exist: %w", err)
118	}
119
120	s.cwd = dir
121	return nil
122}
123
124// GetEnv returns a copy of the environment variables
125func (s *Shell) GetEnv() []string {
126	s.mu.Lock()
127	defer s.mu.Unlock()
128
129	env := make([]string, len(s.env))
130	copy(env, s.env)
131	return env
132}
133
134// SetEnv sets an environment variable
135func (s *Shell) SetEnv(key, value string) {
136	s.mu.Lock()
137	defer s.mu.Unlock()
138
139	// Update or add the environment variable
140	keyPrefix := key + "="
141	for i, env := range s.env {
142		if strings.HasPrefix(env, keyPrefix) {
143			s.env[i] = keyPrefix + value
144			return
145		}
146	}
147	s.env = append(s.env, keyPrefix+value)
148}
149
150// SetBlockFuncs sets the command block functions for the shell
151func (s *Shell) SetBlockFuncs(blockFuncs []BlockFunc) {
152	s.mu.Lock()
153	defer s.mu.Unlock()
154	s.blockFuncs = blockFuncs
155}
156
157// CommandsBlocker creates a BlockFunc that blocks exact command matches
158func CommandsBlocker(bannedCommands []string) BlockFunc {
159	bannedSet := make(map[string]bool)
160	for _, cmd := range bannedCommands {
161		bannedSet[cmd] = true
162	}
163
164	return func(args []string) bool {
165		if len(args) == 0 {
166			return false
167		}
168		return bannedSet[args[0]]
169	}
170}
171
172// ArgumentsBlocker creates a BlockFunc that blocks specific subcommands
173func ArgumentsBlocker(blockedSubCommands [][]string) BlockFunc {
174	return func(args []string) bool {
175		for _, blocked := range blockedSubCommands {
176			if len(args) >= len(blocked) {
177				match := true
178				for i, part := range blocked {
179					if args[i] != part {
180						match = false
181						break
182					}
183				}
184				if match {
185					return true
186				}
187			}
188		}
189		return false
190	}
191}
192
193func (s *Shell) blockHandler() func(next interp.ExecHandlerFunc) interp.ExecHandlerFunc {
194	return func(next interp.ExecHandlerFunc) interp.ExecHandlerFunc {
195		return func(ctx context.Context, args []string) error {
196			if len(args) == 0 {
197				return next(ctx, args)
198			}
199
200			for _, blockFunc := range s.blockFuncs {
201				if blockFunc(args) {
202					return fmt.Errorf("command is not allowed for security reasons: %s", strings.Join(args, " "))
203				}
204			}
205
206			return next(ctx, args)
207		}
208	}
209}
210
211// execPOSIX executes commands using POSIX shell emulation (cross-platform)
212func (s *Shell) execPOSIX(ctx context.Context, command string) (string, string, error) {
213	line, err := syntax.NewParser().Parse(strings.NewReader(command), "")
214	if err != nil {
215		return "", "", fmt.Errorf("could not parse command: %w", err)
216	}
217
218	var stdout, stderr bytes.Buffer
219	runner, err := interp.New(
220		interp.StdIO(nil, &stdout, &stderr),
221		interp.Interactive(false),
222		interp.Env(expand.ListEnviron(s.env...)),
223		interp.Dir(s.cwd),
224		interp.ExecHandlers(s.blockHandler()),
225	)
226	if err != nil {
227		return "", "", fmt.Errorf("could not run command: %w", err)
228	}
229
230	err = runner.Run(ctx, line)
231	s.cwd = runner.Dir
232	s.env = []string{}
233	for name, vr := range runner.Vars {
234		s.env = append(s.env, fmt.Sprintf("%s=%s", name, vr.Str))
235	}
236	s.logger.InfoPersist("POSIX command finished", "command", command, "err", err)
237	return stdout.String(), stderr.String(), err
238}
239
// IsInterrupt reports whether err is, or wraps, context.Canceled or
// context.DeadlineExceeded.
func IsInterrupt(err error) bool {
	if errors.Is(err, context.Canceled) {
		return true
	}
	return errors.Is(err, context.DeadlineExceeded)
}
245
246// ExitCode extracts the exit code from an error
247func ExitCode(err error) int {
248	if err == nil {
249		return 0
250	}
251	status, ok := interp.IsExitStatus(err)
252	if ok {
253		return int(status)
254	}
255	return 1
256}