shell.go

  1// Package shell provides cross-platform shell execution capabilities.
  2//
  3// This package offers two main types:
  4// - Shell: A general-purpose shell executor for one-off or managed commands
  5// - PersistentShell: A singleton shell that maintains state across the application
  6//
  7// WINDOWS COMPATIBILITY:
  8// This implementation provides both POSIX shell emulation (mvdan.cc/sh/v3) and
  9// native Windows shell support (cmd.exe/PowerShell) for optimal compatibility.
 10package shell
 11
import (
	"bytes"
	"context"
	"errors"
	"fmt"
	"os"
	"os/exec"
	"path/filepath"
	"runtime"
	"strings"
	"sync"

	"mvdan.cc/sh/v3/expand"
	"mvdan.cc/sh/v3/interp"
	"mvdan.cc/sh/v3/syntax"
)
 27
 28// ShellType represents the type of shell to use
 29type ShellType int
 30
 31const (
 32	ShellTypePOSIX ShellType = iota
 33	ShellTypeCmd
 34	ShellTypePowerShell
 35)
 36
 37// Logger interface for optional logging
 38type Logger interface {
 39	InfoPersist(msg string, keysAndValues ...interface{})
 40}
 41
 42// noopLogger is a logger that does nothing
 43type noopLogger struct{}
 44
 45func (noopLogger) InfoPersist(msg string, keysAndValues ...interface{}) {}
 46
 47// BlockFunc is a function that determines if a command should be blocked
 48type BlockFunc func(args []string) bool
 49
 50// Shell provides cross-platform shell execution with optional state persistence
 51type Shell struct {
 52	env        []string
 53	cwd        string
 54	mu         sync.Mutex
 55	logger     Logger
 56	blockFuncs []BlockFunc
 57}
 58
 59// Options for creating a new shell
 60type Options struct {
 61	WorkingDir string
 62	Env        []string
 63	Logger     Logger
 64	BlockFuncs []BlockFunc
 65}
 66
 67// NewShell creates a new shell instance with the given options
 68func NewShell(opts *Options) *Shell {
 69	if opts == nil {
 70		opts = &Options{}
 71	}
 72
 73	cwd := opts.WorkingDir
 74	if cwd == "" {
 75		cwd, _ = os.Getwd()
 76	}
 77
 78	env := opts.Env
 79	if env == nil {
 80		env = os.Environ()
 81	}
 82
 83	logger := opts.Logger
 84	if logger == nil {
 85		logger = noopLogger{}
 86	}
 87
 88	return &Shell{
 89		cwd:        cwd,
 90		env:        env,
 91		logger:     logger,
 92		blockFuncs: opts.BlockFuncs,
 93	}
 94}
 95
 96// Exec executes a command in the shell
 97func (s *Shell) Exec(ctx context.Context, command string) (string, string, error) {
 98	s.mu.Lock()
 99	defer s.mu.Unlock()
100
101	// Determine which shell to use based on platform and command
102	shellType := s.determineShellType(command)
103
104	switch shellType {
105	case ShellTypeCmd:
106		return s.execWindows(ctx, command, "cmd")
107	case ShellTypePowerShell:
108		return s.execWindows(ctx, command, "powershell")
109	default:
110		return s.execPOSIX(ctx, command)
111	}
112}
113
114// GetWorkingDir returns the current working directory
115func (s *Shell) GetWorkingDir() string {
116	s.mu.Lock()
117	defer s.mu.Unlock()
118	return s.cwd
119}
120
121// SetWorkingDir sets the working directory
122func (s *Shell) SetWorkingDir(dir string) error {
123	s.mu.Lock()
124	defer s.mu.Unlock()
125
126	// Verify the directory exists
127	if _, err := os.Stat(dir); err != nil {
128		return fmt.Errorf("directory does not exist: %w", err)
129	}
130
131	s.cwd = dir
132	return nil
133}
134
135// GetEnv returns a copy of the environment variables
136func (s *Shell) GetEnv() []string {
137	s.mu.Lock()
138	defer s.mu.Unlock()
139
140	env := make([]string, len(s.env))
141	copy(env, s.env)
142	return env
143}
144
145// SetEnv sets an environment variable
146func (s *Shell) SetEnv(key, value string) {
147	s.mu.Lock()
148	defer s.mu.Unlock()
149
150	// Update or add the environment variable
151	keyPrefix := key + "="
152	for i, env := range s.env {
153		if strings.HasPrefix(env, keyPrefix) {
154			s.env[i] = keyPrefix + value
155			return
156		}
157	}
158	s.env = append(s.env, keyPrefix+value)
159}
160
161// SetBlockFuncs sets the command block functions for the shell
162func (s *Shell) SetBlockFuncs(blockFuncs []BlockFunc) {
163	s.mu.Lock()
164	defer s.mu.Unlock()
165	s.blockFuncs = blockFuncs
166}
167
// windowsNativeCommands is the set of command names that determineShellType
// sends to cmd.exe on Windows. determineShellType lower-cases the first word
// of the command line before looking it up, so keys are lower-case.
var windowsNativeCommands = map[string]bool{
	// File and directory commands.
	"copy": true, "del": true, "dir": true, "md": true, "mkdir": true,
	"move": true, "rd": true, "rmdir": true, "type": true,

	// Console, process and system administration commands.
	"cls": true, "net": true, "reg": true, "sc": true,
	"taskkill": true, "tasklist": true, "where": true, "wmic": true,
}
188
189// determineShellType decides which shell to use based on platform and command
190func (s *Shell) determineShellType(command string) ShellType {
191	if runtime.GOOS != "windows" {
192		return ShellTypePOSIX
193	}
194
195	// Extract the first command from the command line
196	parts := strings.Fields(command)
197	if len(parts) == 0 {
198		return ShellTypePOSIX
199	}
200
201	firstCmd := strings.ToLower(parts[0])
202
203	// Check if it's a Windows-specific command
204	if windowsNativeCommands[firstCmd] {
205		return ShellTypeCmd
206	}
207
208	// Check for PowerShell-specific syntax
209	if strings.Contains(command, "Get-") || strings.Contains(command, "Set-") ||
210		strings.Contains(command, "New-") || strings.Contains(command, "$_") ||
211		strings.Contains(command, "| Where-Object") || strings.Contains(command, "| ForEach-Object") {
212		return ShellTypePowerShell
213	}
214
215	// Default to POSIX emulation for cross-platform compatibility
216	return ShellTypePOSIX
217}
218
219// CommandsBlocker creates a BlockFunc that blocks exact command matches
220func CommandsBlocker(bannedCommands []string) BlockFunc {
221	bannedSet := make(map[string]bool)
222	for _, cmd := range bannedCommands {
223		bannedSet[cmd] = true
224	}
225
226	return func(args []string) bool {
227		if len(args) == 0 {
228			return false
229		}
230		return bannedSet[args[0]]
231	}
232}
233
234// ArgumentsBlocker creates a BlockFunc that blocks specific subcommands
235func ArgumentsBlocker(blockedSubCommands [][]string) BlockFunc {
236	return func(args []string) bool {
237		for _, blocked := range blockedSubCommands {
238			if len(args) >= len(blocked) {
239				match := true
240				for i, part := range blocked {
241					if args[i] != part {
242						match = false
243						break
244					}
245				}
246				if match {
247					return true
248				}
249			}
250		}
251		return false
252	}
253}
254
255func (s *Shell) blockHandler() func(next interp.ExecHandlerFunc) interp.ExecHandlerFunc {
256	return func(next interp.ExecHandlerFunc) interp.ExecHandlerFunc {
257		return func(ctx context.Context, args []string) error {
258			if len(args) == 0 {
259				return next(ctx, args)
260			}
261
262			for _, blockFunc := range s.blockFuncs {
263				if blockFunc(args) {
264					return fmt.Errorf("command '%s' is not allowed for security reasons", strings.Join(args, " "))
265				}
266			}
267
268			return next(ctx, args)
269		}
270	}
271}
272
273// execWindows executes commands using native Windows shells (cmd.exe or PowerShell)
274func (s *Shell) execWindows(ctx context.Context, command string, shell string) (string, string, error) {
275	var cmd *exec.Cmd
276
277	// Handle directory changes specially to maintain persistent shell behavior
278	if strings.HasPrefix(strings.TrimSpace(command), "cd ") {
279		return s.handleWindowsCD(command)
280	}
281
282	switch shell {
283	case "cmd":
284		// Use cmd.exe for Windows commands
285		// Add current directory context to maintain state
286		fullCommand := fmt.Sprintf("cd /d \"%s\" && %s", s.cwd, command)
287		cmd = exec.CommandContext(ctx, "cmd", "/C", fullCommand)
288	case "powershell":
289		// Use PowerShell for PowerShell commands
290		// Add current directory context to maintain state
291		fullCommand := fmt.Sprintf("Set-Location '%s'; %s", s.cwd, command)
292		cmd = exec.CommandContext(ctx, "powershell", "-Command", fullCommand)
293	default:
294		return "", "", fmt.Errorf("unsupported Windows shell: %s", shell)
295	}
296
297	// Set environment variables
298	cmd.Env = s.env
299
300	var stdout, stderr bytes.Buffer
301	cmd.Stdout = &stdout
302	cmd.Stderr = &stderr
303
304	err := cmd.Run()
305
306	s.logger.InfoPersist("Windows command finished", "shell", shell, "command", command, "err", err)
307	return stdout.String(), stderr.String(), err
308}
309
310// handleWindowsCD handles directory changes for Windows shells
311func (s *Shell) handleWindowsCD(command string) (string, string, error) {
312	// Extract the target directory from the cd command
313	parts := strings.Fields(command)
314	if len(parts) < 2 {
315		return "", "cd: missing directory argument", fmt.Errorf("missing directory argument")
316	}
317
318	targetDir := parts[1]
319
320	// Handle relative paths
321	if !strings.Contains(targetDir, ":") && !strings.HasPrefix(targetDir, "\\") {
322		// Relative path - resolve against current directory
323		if targetDir == ".." {
324			// Go up one directory
325			if len(s.cwd) > 3 { // Don't go above drive root (C:\)
326				lastSlash := strings.LastIndex(s.cwd, "\\")
327				if lastSlash > 2 { // Keep drive letter
328					s.cwd = s.cwd[:lastSlash]
329				}
330			}
331		} else if targetDir != "." {
332			// Go to subdirectory
333			s.cwd = s.cwd + "\\" + targetDir
334		}
335	} else {
336		// Absolute path
337		s.cwd = targetDir
338	}
339
340	// Verify the directory exists
341	if _, err := os.Stat(s.cwd); err != nil {
342		return "", fmt.Sprintf("cd: %s: No such file or directory", targetDir), err
343	}
344
345	return "", "", nil
346}
347
348// execPOSIX executes commands using POSIX shell emulation (cross-platform)
349func (s *Shell) execPOSIX(ctx context.Context, command string) (string, string, error) {
350	line, err := syntax.NewParser().Parse(strings.NewReader(command), "")
351	if err != nil {
352		return "", "", fmt.Errorf("could not parse command: %w", err)
353	}
354
355	var stdout, stderr bytes.Buffer
356	runner, err := interp.New(
357		interp.StdIO(nil, &stdout, &stderr),
358		interp.Interactive(false),
359		interp.Env(expand.ListEnviron(s.env...)),
360		interp.Dir(s.cwd),
361		interp.ExecHandlers(s.blockHandler()),
362	)
363	if err != nil {
364		return "", "", fmt.Errorf("could not run command: %w", err)
365	}
366
367	err = runner.Run(ctx, line)
368	s.cwd = runner.Dir
369	s.env = []string{}
370	for name, vr := range runner.Vars {
371		s.env = append(s.env, fmt.Sprintf("%s=%s", name, vr.Str))
372	}
373	s.logger.InfoPersist("POSIX command finished", "command", command, "err", err)
374	return stdout.String(), stderr.String(), err
375}
376
// IsInterrupt reports whether err is, or wraps, context.Canceled or
// context.DeadlineExceeded, i.e. whether a command stopped because its
// context ended.
func IsInterrupt(err error) bool {
	for _, target := range []error{context.Canceled, context.DeadlineExceeded} {
		if errors.Is(err, target) {
			return true
		}
	}
	return false
}
382
383// ExitCode extracts the exit code from an error
384func ExitCode(err error) int {
385	if err == nil {
386		return 0
387	}
388	status, ok := interp.IsExitStatus(err)
389	if ok {
390		return int(status)
391	}
392	return 1
393}