shell.go

// Package shell provides cross-platform shell execution capabilities.
//
// This package offers two main types:
//   - Shell: A general-purpose shell executor for one-off or managed commands
//   - PersistentShell: A singleton shell that maintains state across the application
//
// WINDOWS COMPATIBILITY:
// By default, commands run through an in-process POSIX shell emulator
// (mvdan.cc/sh/v3). On Windows, commands recognized as cmd.exe built-ins or
// PowerShell cmdlets are instead run through the native cmd.exe or PowerShell.
 10package shell
 11
import (
	"bytes"
	"context"
	"errors"
	"fmt"
	"os"
	"os/exec"
	"path/filepath"
	"runtime"
	"strings"
	"sync"

	"mvdan.cc/sh/v3/expand"
	"mvdan.cc/sh/v3/interp"
	"mvdan.cc/sh/v3/syntax"
)
 27
// ShellType identifies which shell implementation Exec uses for a command.
type ShellType int

const (
	// ShellTypePOSIX runs the command through the in-process POSIX shell
	// interpreter. It is always used off Windows and is the fallback on Windows.
	ShellTypePOSIX ShellType = iota
	// ShellTypeCmd runs the command through the native cmd.exe (Windows only).
	ShellTypeCmd
	// ShellTypePowerShell runs the command through native PowerShell (Windows only).
	ShellTypePowerShell
)
 36
// Logger is the optional logging sink used by Shell. Shell calls InfoPersist
// once after each command finishes, passing alternating key/value pairs.
type Logger interface {
	InfoPersist(msg string, keysAndValues ...interface{})
}
 41
// noopLogger is a Logger that discards everything. NewShell installs it when
// Options.Logger is nil so the rest of the code never needs a nil check.
type noopLogger struct{}

// InfoPersist implements Logger and does nothing.
func (noopLogger) InfoPersist(msg string, keysAndValues ...interface{}) {}
 46
// CommandBlockFunc reports whether a command must be blocked. args is the
// command's argument list with args[0] being the command name; Shell only
// calls it with a non-empty list. Returning true rejects the command.
type CommandBlockFunc func(args []string) bool
 49
// Shell runs commands while keeping a working directory and environment that
// persist between calls. Create one with NewShell. Every method locks mu, so
// a Shell may be shared between goroutines; commands run one at a time.
type Shell struct {
	env        []string           // environment as KEY=VALUE entries
	cwd        string             // directory the next command starts in
	mu         sync.Mutex         // guards all other fields
	logger     Logger             // NewShell substitutes a no-op logger when none is given
	blockFuncs []CommandBlockFunc // command blockers consulted before commands run
}
 58
// Options configures NewShell. The zero value is valid; every field has a
// default.
type Options struct {
	WorkingDir string             // initial working directory; empty means the process's current directory
	Env        []string           // initial KEY=VALUE environment; nil means os.Environ()
	Logger     Logger             // optional; nil means log output is discarded
	BlockFuncs []CommandBlockFunc // optional command blockers
}
 66
 67// NewShell creates a new shell instance with the given options
 68func NewShell(opts *Options) *Shell {
 69	if opts == nil {
 70		opts = &Options{}
 71	}
 72
 73	cwd := opts.WorkingDir
 74	if cwd == "" {
 75		cwd, _ = os.Getwd()
 76	}
 77
 78	env := opts.Env
 79	if env == nil {
 80		env = os.Environ()
 81	}
 82
 83	logger := opts.Logger
 84	if logger == nil {
 85		logger = noopLogger{}
 86	}
 87
 88	return &Shell{
 89		cwd:        cwd,
 90		env:        env,
 91		logger:     logger,
 92		blockFuncs: opts.BlockFuncs,
 93	}
 94}
 95
 96// Exec executes a command in the shell
 97func (s *Shell) Exec(ctx context.Context, command string) (string, string, error) {
 98	s.mu.Lock()
 99	defer s.mu.Unlock()
100
101	// Determine which shell to use based on platform and command
102	shellType := s.determineShellType(command)
103
104	switch shellType {
105	case ShellTypeCmd:
106		return s.execWindows(ctx, command, "cmd")
107	case ShellTypePowerShell:
108		return s.execWindows(ctx, command, "powershell")
109	default:
110		return s.execPOSIX(ctx, command)
111	}
112}
113
114// GetWorkingDir returns the current working directory
115func (s *Shell) GetWorkingDir() string {
116	s.mu.Lock()
117	defer s.mu.Unlock()
118	return s.cwd
119}
120
121// SetWorkingDir sets the working directory
122func (s *Shell) SetWorkingDir(dir string) error {
123	s.mu.Lock()
124	defer s.mu.Unlock()
125
126	// Verify the directory exists
127	if _, err := os.Stat(dir); err != nil {
128		return fmt.Errorf("directory does not exist: %w", err)
129	}
130
131	s.cwd = dir
132	return nil
133}
134
135// GetEnv returns a copy of the environment variables
136func (s *Shell) GetEnv() []string {
137	s.mu.Lock()
138	defer s.mu.Unlock()
139
140	env := make([]string, len(s.env))
141	copy(env, s.env)
142	return env
143}
144
145// SetEnv sets an environment variable
146func (s *Shell) SetEnv(key, value string) {
147	s.mu.Lock()
148	defer s.mu.Unlock()
149
150	// Update or add the environment variable
151	keyPrefix := key + "="
152	for i, env := range s.env {
153		if strings.HasPrefix(env, keyPrefix) {
154			s.env[i] = keyPrefix + value
155			return
156		}
157	}
158	s.env = append(s.env, keyPrefix+value)
159}
160
161// SetBlockFuncs sets the command block functions for the shell
162func (s *Shell) SetBlockFuncs(blockFuncs []CommandBlockFunc) {
163	s.mu.Lock()
164	defer s.mu.Unlock()
165	s.blockFuncs = blockFuncs
166}
167
// windowsNativeCommands holds lower-case command names that are cmd.exe
// built-ins or Windows-only utilities. On Windows, a command line whose first
// word (compared case-insensitively) is in this set runs through cmd.exe
// instead of the POSIX emulator.
var windowsNativeCommands = map[string]bool{
	// cmd.exe file/console built-ins.
	"cls":   true,
	"copy":  true,
	"del":   true,
	"dir":   true,
	"md":    true,
	"mkdir": true,
	"move":  true,
	"rd":    true,
	"rmdir": true,
	"type":  true,
	// Windows system utilities.
	"net":      true,
	"reg":      true,
	"sc":       true,
	"taskkill": true,
	"tasklist": true,
	"where":    true,
	"wmic":     true,
}
188
189// determineShellType decides which shell to use based on platform and command
190func (s *Shell) determineShellType(command string) ShellType {
191	if runtime.GOOS != "windows" {
192		return ShellTypePOSIX
193	}
194
195	// Extract the first command from the command line
196	parts := strings.Fields(command)
197	if len(parts) == 0 {
198		return ShellTypePOSIX
199	}
200
201	firstCmd := strings.ToLower(parts[0])
202
203	// Check if it's a Windows-specific command
204	if windowsNativeCommands[firstCmd] {
205		return ShellTypeCmd
206	}
207
208	// Check for PowerShell-specific syntax
209	if strings.Contains(command, "Get-") || strings.Contains(command, "Set-") ||
210		strings.Contains(command, "New-") || strings.Contains(command, "$_") ||
211		strings.Contains(command, "| Where-Object") || strings.Contains(command, "| ForEach-Object") {
212		return ShellTypePowerShell
213	}
214
215	// Default to POSIX emulation for cross-platform compatibility
216	return ShellTypePOSIX
217}
218
219// CreateSimpleCommandBlocker creates a CommandBlockFunc that blocks exact command matches
220func CreateSimpleCommandBlocker(bannedCommands []string) CommandBlockFunc {
221	bannedSet := make(map[string]bool)
222	for _, cmd := range bannedCommands {
223		bannedSet[cmd] = true
224	}
225	
226	return func(args []string) bool {
227		if len(args) == 0 {
228			return false
229		}
230		return bannedSet[args[0]]
231	}
232}
233
234// CreateSubCommandBlocker creates a CommandBlockFunc that blocks specific subcommands
235func CreateSubCommandBlocker(blockedSubCommands [][]string) CommandBlockFunc {
236	return func(args []string) bool {
237		for _, blocked := range blockedSubCommands {
238			if len(args) >= len(blocked) {
239				match := true
240				for i, part := range blocked {
241					if args[i] != part {
242						match = false
243						break
244					}
245				}
246				if match {
247					return true
248				}
249			}
250		}
251		return false
252	}
253}
254func (s *Shell) createCommandBlockHandler() func(next interp.ExecHandlerFunc) interp.ExecHandlerFunc {
255	return func(next interp.ExecHandlerFunc) interp.ExecHandlerFunc {
256		return func(ctx context.Context, args []string) error {
257			if len(args) == 0 {
258				return next(ctx, args)
259			}
260
261			for _, blockFunc := range s.blockFuncs {
262				if blockFunc(args) {
263					return fmt.Errorf("command '%s' is not allowed for security reasons", strings.Join(args, " "))
264				}
265			}
266
267			return next(ctx, args)
268		}
269	}
270}
271
272// execWindows executes commands using native Windows shells (cmd.exe or PowerShell)
273func (s *Shell) execWindows(ctx context.Context, command string, shell string) (string, string, error) {
274	var cmd *exec.Cmd
275
276	// Handle directory changes specially to maintain persistent shell behavior
277	if strings.HasPrefix(strings.TrimSpace(command), "cd ") {
278		return s.handleWindowsCD(command)
279	}
280
281	switch shell {
282	case "cmd":
283		// Use cmd.exe for Windows commands
284		// Add current directory context to maintain state
285		fullCommand := fmt.Sprintf("cd /d \"%s\" && %s", s.cwd, command)
286		cmd = exec.CommandContext(ctx, "cmd", "/C", fullCommand)
287	case "powershell":
288		// Use PowerShell for PowerShell commands
289		// Add current directory context to maintain state
290		fullCommand := fmt.Sprintf("Set-Location '%s'; %s", s.cwd, command)
291		cmd = exec.CommandContext(ctx, "powershell", "-Command", fullCommand)
292	default:
293		return "", "", fmt.Errorf("unsupported Windows shell: %s", shell)
294	}
295
296	// Set environment variables
297	cmd.Env = s.env
298
299	var stdout, stderr bytes.Buffer
300	cmd.Stdout = &stdout
301	cmd.Stderr = &stderr
302
303	err := cmd.Run()
304
305	s.logger.InfoPersist("Windows command finished", "shell", shell, "command", command, "err", err)
306	return stdout.String(), stderr.String(), err
307}
308
309// handleWindowsCD handles directory changes for Windows shells
310func (s *Shell) handleWindowsCD(command string) (string, string, error) {
311	// Extract the target directory from the cd command
312	parts := strings.Fields(command)
313	if len(parts) < 2 {
314		return "", "cd: missing directory argument", fmt.Errorf("missing directory argument")
315	}
316
317	targetDir := parts[1]
318
319	// Handle relative paths
320	if !strings.Contains(targetDir, ":") && !strings.HasPrefix(targetDir, "\\") {
321		// Relative path - resolve against current directory
322		if targetDir == ".." {
323			// Go up one directory
324			if len(s.cwd) > 3 { // Don't go above drive root (C:\)
325				lastSlash := strings.LastIndex(s.cwd, "\\")
326				if lastSlash > 2 { // Keep drive letter
327					s.cwd = s.cwd[:lastSlash]
328				}
329			}
330		} else if targetDir != "." {
331			// Go to subdirectory
332			s.cwd = s.cwd + "\\" + targetDir
333		}
334	} else {
335		// Absolute path
336		s.cwd = targetDir
337	}
338
339	// Verify the directory exists
340	if _, err := os.Stat(s.cwd); err != nil {
341		return "", fmt.Sprintf("cd: %s: No such file or directory", targetDir), err
342	}
343
344	return "", "", nil
345}
346
347// execPOSIX executes commands using POSIX shell emulation (cross-platform)
348func (s *Shell) execPOSIX(ctx context.Context, command string) (string, string, error) {
349	line, err := syntax.NewParser().Parse(strings.NewReader(command), "")
350	if err != nil {
351		return "", "", fmt.Errorf("could not parse command: %w", err)
352	}
353
354	var stdout, stderr bytes.Buffer
355	runner, err := interp.New(
356		interp.StdIO(nil, &stdout, &stderr),
357		interp.Interactive(false),
358		interp.Env(expand.ListEnviron(s.env...)),
359		interp.Dir(s.cwd),
360		interp.ExecHandlers(s.createCommandBlockHandler()),
361	)
362	if err != nil {
363		return "", "", fmt.Errorf("could not run command: %w", err)
364	}
365
366	err = runner.Run(ctx, line)
367	s.cwd = runner.Dir
368	s.env = []string{}
369	for name, vr := range runner.Vars {
370		s.env = append(s.env, fmt.Sprintf("%s=%s", name, vr.Str))
371	}
372	s.logger.InfoPersist("POSIX command finished", "command", command, "err", err)
373	return stdout.String(), stderr.String(), err
374}
375
// IsInterrupt reports whether err, or any error it wraps, stems from a
// context being cancelled or reaching its deadline.
func IsInterrupt(err error) bool {
	for _, cause := range [...]error{context.Canceled, context.DeadlineExceeded} {
		if errors.Is(err, cause) {
			return true
		}
	}
	return false
}
381
382// ExitCode extracts the exit code from an error
383func ExitCode(err error) int {
384	if err == nil {
385		return 0
386	}
387	status, ok := interp.IsExitStatus(err)
388	if ok {
389		return int(status)
390	}
391	return 1
392}