shell.go

  1// Package shell provides cross-platform shell execution capabilities.
  2//
  3// This package offers two main types:
  4// - Shell: A general-purpose shell executor for one-off or managed commands
  5// - PersistentShell: A singleton shell that maintains state across the application
  6//
  7// WINDOWS COMPATIBILITY:
// This implementation provides POSIX shell emulation (mvdan.cc/sh/v3) on all
// platforms, including Windows. Some caution has to be taken: commands should
// use forward slashes (/) as path separators to work, even on Windows.
 11package shell
 12
 13import (
 14	"bytes"
 15	"context"
 16	"errors"
 17	"fmt"
 18	"os"
 19	"slices"
 20	"strings"
 21	"sync"
 22
 23	"github.com/charmbracelet/crush/internal/slicesext"
 24	"mvdan.cc/sh/moreinterp/coreutils"
 25	"mvdan.cc/sh/v3/expand"
 26	"mvdan.cc/sh/v3/interp"
 27	"mvdan.cc/sh/v3/syntax"
 28)
 29
 30// ShellType represents the type of shell to use
 31type ShellType int
 32
 33const (
 34	ShellTypePOSIX ShellType = iota
 35	ShellTypeCmd
 36	ShellTypePowerShell
 37)
 38
 39// Logger interface for optional logging
 40type Logger interface {
 41	InfoPersist(msg string, keysAndValues ...any)
 42}
 43
 44// noopLogger is a logger that does nothing
 45type noopLogger struct{}
 46
 47func (noopLogger) InfoPersist(msg string, keysAndValues ...any) {}
 48
 49// BlockFunc is a function that determines if a command should be blocked
 50type BlockFunc func(args []string) bool
 51
 52// Shell provides cross-platform shell execution with optional state persistence
 53type Shell struct {
 54	env        []string
 55	cwd        string
 56	mu         sync.Mutex
 57	logger     Logger
 58	blockFuncs []BlockFunc
 59}
 60
 61// Options for creating a new shell
 62type Options struct {
 63	WorkingDir string
 64	Env        []string
 65	Logger     Logger
 66	BlockFuncs []BlockFunc
 67}
 68
 69// NewShell creates a new shell instance with the given options
 70func NewShell(opts *Options) *Shell {
 71	if opts == nil {
 72		opts = &Options{}
 73	}
 74
 75	cwd := opts.WorkingDir
 76	if cwd == "" {
 77		cwd, _ = os.Getwd()
 78	}
 79
 80	env := opts.Env
 81	if env == nil {
 82		env = os.Environ()
 83	}
 84
 85	logger := opts.Logger
 86	if logger == nil {
 87		logger = noopLogger{}
 88	}
 89
 90	return &Shell{
 91		cwd:        cwd,
 92		env:        env,
 93		logger:     logger,
 94		blockFuncs: opts.BlockFuncs,
 95	}
 96}
 97
 98// Exec executes a command in the shell
 99func (s *Shell) Exec(ctx context.Context, command string) (string, string, error) {
100	s.mu.Lock()
101	defer s.mu.Unlock()
102
103	return s.execPOSIX(ctx, command)
104}
105
106// GetWorkingDir returns the current working directory
107func (s *Shell) GetWorkingDir() string {
108	s.mu.Lock()
109	defer s.mu.Unlock()
110	return s.cwd
111}
112
113// SetWorkingDir sets the working directory
114func (s *Shell) SetWorkingDir(dir string) error {
115	s.mu.Lock()
116	defer s.mu.Unlock()
117
118	// Verify the directory exists
119	if _, err := os.Stat(dir); err != nil {
120		return fmt.Errorf("directory does not exist: %w", err)
121	}
122
123	s.cwd = dir
124	return nil
125}
126
127// GetEnv returns a copy of the environment variables
128func (s *Shell) GetEnv() []string {
129	s.mu.Lock()
130	defer s.mu.Unlock()
131
132	env := make([]string, len(s.env))
133	copy(env, s.env)
134	return env
135}
136
137// SetEnv sets an environment variable
138func (s *Shell) SetEnv(key, value string) {
139	s.mu.Lock()
140	defer s.mu.Unlock()
141
142	// Update or add the environment variable
143	keyPrefix := key + "="
144	for i, env := range s.env {
145		if strings.HasPrefix(env, keyPrefix) {
146			s.env[i] = keyPrefix + value
147			return
148		}
149	}
150	s.env = append(s.env, keyPrefix+value)
151}
152
153// SetBlockFuncs sets the command block functions for the shell
154func (s *Shell) SetBlockFuncs(blockFuncs []BlockFunc) {
155	s.mu.Lock()
156	defer s.mu.Unlock()
157	s.blockFuncs = blockFuncs
158}
159
160// CommandsBlocker creates a BlockFunc that blocks exact command matches
161func CommandsBlocker(cmds []string) BlockFunc {
162	bannedSet := make(map[string]struct{})
163	for _, cmd := range cmds {
164		bannedSet[cmd] = struct{}{}
165	}
166
167	return func(args []string) bool {
168		if len(args) == 0 {
169			return false
170		}
171		_, ok := bannedSet[args[0]]
172		return ok
173	}
174}
175
176// ArgumentsBlocker creates a BlockFunc that blocks specific subcommand
177func ArgumentsBlocker(cmd string, args []string, flags []string) BlockFunc {
178	return func(parts []string) bool {
179		if len(parts) == 0 || parts[0] != cmd {
180			return false
181		}
182
183		argParts, flagParts := splitArgsFlags(parts[1:])
184		if len(argParts) < len(args) || len(flagParts) < len(flags) {
185			return false
186		}
187
188		argsMatch := slices.Equal(argParts[:len(args)], args)
189		flagsMatch := slicesext.IsSubset(flags, flagParts)
190
191		return argsMatch && flagsMatch
192	}
193}
194
// splitArgsFlags partitions parts into positional arguments and flags,
// preserving the original order within each group. Any element starting
// with "-" is a flag, including a bare "-" or "--". Both results are non-nil.
func splitArgsFlags(parts []string) ([]string, []string) {
	positional := make([]string, 0, len(parts))
	flagged := make([]string, 0, len(parts))
	for i := range parts {
		part := parts[i]
		if !strings.HasPrefix(part, "-") {
			positional = append(positional, part)
			continue
		}
		flagged = append(flagged, part)
	}
	return positional, flagged
}
207
208func (s *Shell) blockHandler() func(next interp.ExecHandlerFunc) interp.ExecHandlerFunc {
209	return func(next interp.ExecHandlerFunc) interp.ExecHandlerFunc {
210		return func(ctx context.Context, args []string) error {
211			if len(args) == 0 {
212				return next(ctx, args)
213			}
214
215			for _, blockFunc := range s.blockFuncs {
216				if blockFunc(args) {
217					return fmt.Errorf("command is not allowed for security reasons: %s", strings.Join(args, " "))
218				}
219			}
220
221			return next(ctx, args)
222		}
223	}
224}
225
226// execPOSIX executes commands using POSIX shell emulation (cross-platform)
227func (s *Shell) execPOSIX(ctx context.Context, command string) (string, string, error) {
228	line, err := syntax.NewParser().Parse(strings.NewReader(command), "")
229	if err != nil {
230		return "", "", fmt.Errorf("could not parse command: %w", err)
231	}
232
233	var stdout, stderr bytes.Buffer
234	runner, err := interp.New(
235		interp.StdIO(nil, &stdout, &stderr),
236		interp.Interactive(false),
237		interp.Env(expand.ListEnviron(s.env...)),
238		interp.Dir(s.cwd),
239		interp.ExecHandlers(s.blockHandler(), coreutils.ExecHandler),
240	)
241	if err != nil {
242		return "", "", fmt.Errorf("could not run command: %w", err)
243	}
244
245	err = runner.Run(ctx, line)
246	s.cwd = runner.Dir
247	s.env = []string{}
248	for name, vr := range runner.Vars {
249		s.env = append(s.env, fmt.Sprintf("%s=%s", name, vr.Str))
250	}
251	s.logger.InfoPersist("POSIX command finished", "command", command, "err", err)
252	return stdout.String(), stderr.String(), err
253}
254
// IsInterrupt reports whether err, or any error it wraps, is
// context.Canceled or context.DeadlineExceeded.
func IsInterrupt(err error) bool {
	for _, target := range []error{context.Canceled, context.DeadlineExceeded} {
		if errors.Is(err, target) {
			return true
		}
	}
	return false
}
260
261// ExitCode extracts the exit code from an error
262func ExitCode(err error) int {
263	if err == nil {
264		return 0
265	}
266	var exitErr interp.ExitStatus
267	if errors.As(err, &exitErr) {
268		return int(exitErr)
269	}
270	return 1
271}