shell.go

  1// Package shell provides cross-platform shell execution capabilities.
  2//
  3// This package offers two main types:
  4// - Shell: A general-purpose shell executor for one-off or managed commands
  5// - PersistentShell: A singleton shell that maintains state across the application
  6//
  7// WINDOWS COMPATIBILITY:
  8// This implementation provides both POSIX shell emulation (mvdan.cc/sh/v3) and
  9// native Windows shell support (cmd.exe/PowerShell) for optimal compatibility.
 10package shell
 11
import (
	"bytes"
	"context"
	"errors"
	"fmt"
	"os"
	"os/exec"
	"path/filepath"
	"runtime"
	"strings"
	"sync"

	"mvdan.cc/sh/v3/expand"
	"mvdan.cc/sh/v3/interp"
	"mvdan.cc/sh/v3/syntax"
)
 27
 28// ShellType represents the type of shell to use
 29type ShellType int
 30
 31const (
 32	ShellTypePOSIX ShellType = iota
 33	ShellTypeCmd
 34	ShellTypePowerShell
 35)
 36
 37// Logger interface for optional logging
 38type Logger interface {
 39	InfoPersist(msg string, keysAndValues ...interface{})
 40}
 41
 42// noopLogger is a logger that does nothing
 43type noopLogger struct{}
 44
 45func (noopLogger) InfoPersist(msg string, keysAndValues ...interface{}) {}
 46
 47// Shell provides cross-platform shell execution with optional state persistence
 48type Shell struct {
 49	env    []string
 50	cwd    string
 51	mu     sync.Mutex
 52	logger Logger
 53}
 54
 55// Options for creating a new shell
 56type Options struct {
 57	WorkingDir string
 58	Env        []string
 59	Logger     Logger
 60}
 61
 62// NewShell creates a new shell instance with the given options
 63func NewShell(opts *Options) *Shell {
 64	if opts == nil {
 65		opts = &Options{}
 66	}
 67
 68	cwd := opts.WorkingDir
 69	if cwd == "" {
 70		cwd, _ = os.Getwd()
 71	}
 72
 73	env := opts.Env
 74	if env == nil {
 75		env = os.Environ()
 76	}
 77
 78	logger := opts.Logger
 79	if logger == nil {
 80		logger = noopLogger{}
 81	}
 82
 83	return &Shell{
 84		cwd:    cwd,
 85		env:    env,
 86		logger: logger,
 87	}
 88}
 89
 90// Exec executes a command in the shell
 91func (s *Shell) Exec(ctx context.Context, command string) (string, string, error) {
 92	s.mu.Lock()
 93	defer s.mu.Unlock()
 94
 95	// Determine which shell to use based on platform and command
 96	shellType := s.determineShellType(command)
 97
 98	switch shellType {
 99	case ShellTypeCmd:
100		return s.execWindows(ctx, command, "cmd")
101	case ShellTypePowerShell:
102		return s.execWindows(ctx, command, "powershell")
103	default:
104		return s.execPOSIX(ctx, command)
105	}
106}
107
108// GetWorkingDir returns the current working directory
109func (s *Shell) GetWorkingDir() string {
110	s.mu.Lock()
111	defer s.mu.Unlock()
112	return s.cwd
113}
114
115// SetWorkingDir sets the working directory
116func (s *Shell) SetWorkingDir(dir string) error {
117	s.mu.Lock()
118	defer s.mu.Unlock()
119
120	// Verify the directory exists
121	if _, err := os.Stat(dir); err != nil {
122		return fmt.Errorf("directory does not exist: %w", err)
123	}
124
125	s.cwd = dir
126	return nil
127}
128
129// GetEnv returns a copy of the environment variables
130func (s *Shell) GetEnv() []string {
131	s.mu.Lock()
132	defer s.mu.Unlock()
133
134	env := make([]string, len(s.env))
135	copy(env, s.env)
136	return env
137}
138
139// SetEnv sets an environment variable
140func (s *Shell) SetEnv(key, value string) {
141	s.mu.Lock()
142	defer s.mu.Unlock()
143
144	// Update or add the environment variable
145	keyPrefix := key + "="
146	for i, env := range s.env {
147		if strings.HasPrefix(env, keyPrefix) {
148			s.env[i] = keyPrefix + value
149			return
150		}
151	}
152	s.env = append(s.env, keyPrefix+value)
153}
154
// windowsNativeCommands is the set of first words that, on Windows, send a
// command to cmd.exe instead of the POSIX emulator. Keys are lowercase
// because determineShellType lowercases the first word before the lookup.
var windowsNativeCommands = map[string]bool{
	// Files and directories.
	"dir": true, "type": true, "copy": true, "move": true, "del": true,
	"md": true, "mkdir": true, "rd": true, "rmdir": true,

	// Console.
	"cls": true,

	// System inspection and administration.
	"where": true, "tasklist": true, "taskkill": true,
	"net": true, "sc": true, "reg": true, "wmic": true,
}
175
176// determineShellType decides which shell to use based on platform and command
177func (s *Shell) determineShellType(command string) ShellType {
178	if runtime.GOOS != "windows" {
179		return ShellTypePOSIX
180	}
181
182	// Extract the first command from the command line
183	parts := strings.Fields(command)
184	if len(parts) == 0 {
185		return ShellTypePOSIX
186	}
187
188	firstCmd := strings.ToLower(parts[0])
189
190	// Check if it's a Windows-specific command
191	if windowsNativeCommands[firstCmd] {
192		return ShellTypeCmd
193	}
194
195	// Check for PowerShell-specific syntax
196	if strings.Contains(command, "Get-") || strings.Contains(command, "Set-") ||
197		strings.Contains(command, "New-") || strings.Contains(command, "$_") ||
198		strings.Contains(command, "| Where-Object") || strings.Contains(command, "| ForEach-Object") {
199		return ShellTypePowerShell
200	}
201
202	// Default to POSIX emulation for cross-platform compatibility
203	return ShellTypePOSIX
204}
205
206// execWindows executes commands using native Windows shells (cmd.exe or PowerShell)
207func (s *Shell) execWindows(ctx context.Context, command string, shell string) (string, string, error) {
208	var cmd *exec.Cmd
209
210	// Handle directory changes specially to maintain persistent shell behavior
211	if strings.HasPrefix(strings.TrimSpace(command), "cd ") {
212		return s.handleWindowsCD(command)
213	}
214
215	switch shell {
216	case "cmd":
217		// Use cmd.exe for Windows commands
218		// Add current directory context to maintain state
219		fullCommand := fmt.Sprintf("cd /d \"%s\" && %s", s.cwd, command)
220		cmd = exec.CommandContext(ctx, "cmd", "/C", fullCommand)
221	case "powershell":
222		// Use PowerShell for PowerShell commands
223		// Add current directory context to maintain state
224		fullCommand := fmt.Sprintf("Set-Location '%s'; %s", s.cwd, command)
225		cmd = exec.CommandContext(ctx, "powershell", "-Command", fullCommand)
226	default:
227		return "", "", fmt.Errorf("unsupported Windows shell: %s", shell)
228	}
229
230	// Set environment variables
231	cmd.Env = s.env
232
233	var stdout, stderr bytes.Buffer
234	cmd.Stdout = &stdout
235	cmd.Stderr = &stderr
236
237	err := cmd.Run()
238
239	s.logger.InfoPersist("Windows command finished", "shell", shell, "command", command, "err", err)
240	return stdout.String(), stderr.String(), err
241}
242
243// handleWindowsCD handles directory changes for Windows shells
244func (s *Shell) handleWindowsCD(command string) (string, string, error) {
245	// Extract the target directory from the cd command
246	parts := strings.Fields(command)
247	if len(parts) < 2 {
248		return "", "cd: missing directory argument", fmt.Errorf("missing directory argument")
249	}
250
251	targetDir := parts[1]
252
253	// Handle relative paths
254	if !strings.Contains(targetDir, ":") && !strings.HasPrefix(targetDir, "\\") {
255		// Relative path - resolve against current directory
256		if targetDir == ".." {
257			// Go up one directory
258			if len(s.cwd) > 3 { // Don't go above drive root (C:\)
259				lastSlash := strings.LastIndex(s.cwd, "\\")
260				if lastSlash > 2 { // Keep drive letter
261					s.cwd = s.cwd[:lastSlash]
262				}
263			}
264		} else if targetDir != "." {
265			// Go to subdirectory
266			s.cwd = s.cwd + "\\" + targetDir
267		}
268	} else {
269		// Absolute path
270		s.cwd = targetDir
271	}
272
273	// Verify the directory exists
274	if _, err := os.Stat(s.cwd); err != nil {
275		return "", fmt.Sprintf("cd: %s: No such file or directory", targetDir), err
276	}
277
278	return "", "", nil
279}
280
281// execPOSIX executes commands using POSIX shell emulation (cross-platform)
282func (s *Shell) execPOSIX(ctx context.Context, command string) (string, string, error) {
283	line, err := syntax.NewParser().Parse(strings.NewReader(command), "")
284	if err != nil {
285		return "", "", fmt.Errorf("could not parse command: %w", err)
286	}
287
288	var stdout, stderr bytes.Buffer
289	runner, err := interp.New(
290		interp.StdIO(nil, &stdout, &stderr),
291		interp.Interactive(false),
292		interp.Env(expand.ListEnviron(s.env...)),
293		interp.Dir(s.cwd),
294	)
295	if err != nil {
296		return "", "", fmt.Errorf("could not run command: %w", err)
297	}
298
299	err = runner.Run(ctx, line)
300	s.cwd = runner.Dir
301	s.env = []string{}
302	for name, vr := range runner.Vars {
303		s.env = append(s.env, fmt.Sprintf("%s=%s", name, vr.Str))
304	}
305	s.logger.InfoPersist("POSIX command finished", "command", command, "err", err)
306	return stdout.String(), stderr.String(), err
307}
308
// IsInterrupt reports whether err comes from context cancellation or an
// expired deadline, looking through wrapped errors. A nil err is never an
// interrupt.
func IsInterrupt(err error) bool {
	switch {
	case errors.Is(err, context.Canceled):
		return true
	case errors.Is(err, context.DeadlineExceeded):
		return true
	default:
		return false
	}
}
314
315// ExitCode extracts the exit code from an error
316func ExitCode(err error) int {
317	if err == nil {
318		return 0
319	}
320	status, ok := interp.IsExitStatus(err)
321	if ok {
322		return int(status)
323	}
324	return 1
325}