1// Package shell provides cross-platform shell execution capabilities.
2//
// This package offers two main types:
//   - Shell: a general-purpose shell executor for one-off or managed commands.
//   - PersistentShell: a singleton shell that maintains state across the application.
6//
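// Windows compatibility:
// Commands are interpreted by a built-in POSIX shell emulator (mvdan.cc/sh/v3)
// on every platform. On Windows, commands whose first word is a known cmd.exe
// command are run through cmd.exe instead, and commands that look like
// PowerShell are run through PowerShell.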
10package shell
11
import (
	"bytes"
	"context"
	"errors"
	"fmt"
	"os"
	"os/exec"
	"path/filepath"
	"runtime"
	"strings"
	"sync"

	"mvdan.cc/sh/v3/expand"
	"mvdan.cc/sh/v3/interp"
	"mvdan.cc/sh/v3/syntax"
)
27
// ShellType identifies which interpreter is used to execute a command.
type ShellType int

// Interpreters a Shell can dispatch a command to.
const (
	// ShellTypePOSIX runs the command through the embedded POSIX shell emulator.
	// It is the only type selected on non-Windows platforms.
	ShellTypePOSIX ShellType = 0
	// ShellTypeCmd runs the command through cmd.exe (Windows only).
	ShellTypeCmd ShellType = 1
	// ShellTypePowerShell runs the command through PowerShell (Windows only).
	ShellTypePowerShell ShellType = 2
)
36
// Logger is the optional logging hook used by Shell. InfoPersist receives a
// message followed by alternating key/value pairs that describe it.
type Logger interface {
	InfoPersist(msg string, keysAndValues ...interface{})
}

// noopLogger discards every record; NewShell falls back to it when no Logger
// is supplied.
type noopLogger struct{}

// Compile-time assertion that noopLogger satisfies Logger.
var _ Logger = noopLogger{}

// InfoPersist implements Logger by ignoring all of its arguments.
func (noopLogger) InfoPersist(string, ...interface{}) {}
46
// Shell provides cross-platform shell execution with optional state persistence.
//
// The working directory and environment are carried over from one Exec call
// to the next. All exported methods take mu before touching that state, and
// Exec holds it for the whole command, so commands on one Shell never overlap.
type Shell struct {
	// env is the environment passed to each command, as KEY=value entries.
	// After a POSIX command it is replaced with the interpreter's final variables.
	env []string
	// cwd is the directory commands run in.
	cwd string
	// mu guards env and cwd.
	mu sync.Mutex
	// logger receives a record after each command finishes.
	logger Logger
}
54
// Options for creating a new shell with NewShell. Every field is optional.
type Options struct {
	// WorkingDir is the initial directory for commands. An empty string
	// means the process working directory at construction time.
	WorkingDir string
	// Env is the initial environment as KEY=value entries. nil means
	// os.Environ(); a non-nil slice (even an empty one) is used instead.
	Env []string
	// Logger receives a record after each command. nil means a no-op logger.
	Logger Logger
}
61
62// NewShell creates a new shell instance with the given options
63func NewShell(opts *Options) *Shell {
64 if opts == nil {
65 opts = &Options{}
66 }
67
68 cwd := opts.WorkingDir
69 if cwd == "" {
70 cwd, _ = os.Getwd()
71 }
72
73 env := opts.Env
74 if env == nil {
75 env = os.Environ()
76 }
77
78 logger := opts.Logger
79 if logger == nil {
80 logger = noopLogger{}
81 }
82
83 return &Shell{
84 cwd: cwd,
85 env: env,
86 logger: logger,
87 }
88}
89
90// Exec executes a command in the shell
91func (s *Shell) Exec(ctx context.Context, command string) (string, string, error) {
92 s.mu.Lock()
93 defer s.mu.Unlock()
94
95 // Determine which shell to use based on platform and command
96 shellType := s.determineShellType(command)
97
98 switch shellType {
99 case ShellTypeCmd:
100 return s.execWindows(ctx, command, "cmd")
101 case ShellTypePowerShell:
102 return s.execWindows(ctx, command, "powershell")
103 default:
104 return s.execPOSIX(ctx, command)
105 }
106}
107
108// GetWorkingDir returns the current working directory
109func (s *Shell) GetWorkingDir() string {
110 s.mu.Lock()
111 defer s.mu.Unlock()
112 return s.cwd
113}
114
115// SetWorkingDir sets the working directory
116func (s *Shell) SetWorkingDir(dir string) error {
117 s.mu.Lock()
118 defer s.mu.Unlock()
119
120 // Verify the directory exists
121 if _, err := os.Stat(dir); err != nil {
122 return fmt.Errorf("directory does not exist: %w", err)
123 }
124
125 s.cwd = dir
126 return nil
127}
128
129// GetEnv returns a copy of the environment variables
130func (s *Shell) GetEnv() []string {
131 s.mu.Lock()
132 defer s.mu.Unlock()
133
134 env := make([]string, len(s.env))
135 copy(env, s.env)
136 return env
137}
138
139// SetEnv sets an environment variable
140func (s *Shell) SetEnv(key, value string) {
141 s.mu.Lock()
142 defer s.mu.Unlock()
143
144 // Update or add the environment variable
145 keyPrefix := key + "="
146 for i, env := range s.env {
147 if strings.HasPrefix(env, keyPrefix) {
148 s.env[i] = keyPrefix + value
149 return
150 }
151 }
152 s.env = append(s.env, keyPrefix+value)
153}
154
// windowsNativeCommands holds lower-case command names that, on Windows, are
// sent to cmd.exe rather than the POSIX emulator when they are the first word
// of a command line. Lookups use a plain index expression, so a name that is
// absent reads as false.
var windowsNativeCommands = map[string]bool{
	// File and directory commands.
	"cls": true, "copy": true, "del": true, "dir": true, "md": true,
	"mkdir": true, "move": true, "rd": true, "rmdir": true, "type": true,

	// System administration tools.
	"net": true, "reg": true, "sc": true, "taskkill": true, "tasklist": true,
	"where": true, "wmic": true,
}
175
176// determineShellType decides which shell to use based on platform and command
177func (s *Shell) determineShellType(command string) ShellType {
178 if runtime.GOOS != "windows" {
179 return ShellTypePOSIX
180 }
181
182 // Extract the first command from the command line
183 parts := strings.Fields(command)
184 if len(parts) == 0 {
185 return ShellTypePOSIX
186 }
187
188 firstCmd := strings.ToLower(parts[0])
189
190 // Check if it's a Windows-specific command
191 if windowsNativeCommands[firstCmd] {
192 return ShellTypeCmd
193 }
194
195 // Check for PowerShell-specific syntax
196 if strings.Contains(command, "Get-") || strings.Contains(command, "Set-") ||
197 strings.Contains(command, "New-") || strings.Contains(command, "$_") ||
198 strings.Contains(command, "| Where-Object") || strings.Contains(command, "| ForEach-Object") {
199 return ShellTypePowerShell
200 }
201
202 // Default to POSIX emulation for cross-platform compatibility
203 return ShellTypePOSIX
204}
205
206// execWindows executes commands using native Windows shells (cmd.exe or PowerShell)
207func (s *Shell) execWindows(ctx context.Context, command string, shell string) (string, string, error) {
208 var cmd *exec.Cmd
209
210 // Handle directory changes specially to maintain persistent shell behavior
211 if strings.HasPrefix(strings.TrimSpace(command), "cd ") {
212 return s.handleWindowsCD(command)
213 }
214
215 switch shell {
216 case "cmd":
217 // Use cmd.exe for Windows commands
218 // Add current directory context to maintain state
219 fullCommand := fmt.Sprintf("cd /d \"%s\" && %s", s.cwd, command)
220 cmd = exec.CommandContext(ctx, "cmd", "/C", fullCommand)
221 case "powershell":
222 // Use PowerShell for PowerShell commands
223 // Add current directory context to maintain state
224 fullCommand := fmt.Sprintf("Set-Location '%s'; %s", s.cwd, command)
225 cmd = exec.CommandContext(ctx, "powershell", "-Command", fullCommand)
226 default:
227 return "", "", fmt.Errorf("unsupported Windows shell: %s", shell)
228 }
229
230 // Set environment variables
231 cmd.Env = s.env
232
233 var stdout, stderr bytes.Buffer
234 cmd.Stdout = &stdout
235 cmd.Stderr = &stderr
236
237 err := cmd.Run()
238
239 s.logger.InfoPersist("Windows command finished", "shell", shell, "command", command, "err", err)
240 return stdout.String(), stderr.String(), err
241}
242
243// handleWindowsCD handles directory changes for Windows shells
244func (s *Shell) handleWindowsCD(command string) (string, string, error) {
245 // Extract the target directory from the cd command
246 parts := strings.Fields(command)
247 if len(parts) < 2 {
248 return "", "cd: missing directory argument", fmt.Errorf("missing directory argument")
249 }
250
251 targetDir := parts[1]
252
253 // Handle relative paths
254 if !strings.Contains(targetDir, ":") && !strings.HasPrefix(targetDir, "\\") {
255 // Relative path - resolve against current directory
256 if targetDir == ".." {
257 // Go up one directory
258 if len(s.cwd) > 3 { // Don't go above drive root (C:\)
259 lastSlash := strings.LastIndex(s.cwd, "\\")
260 if lastSlash > 2 { // Keep drive letter
261 s.cwd = s.cwd[:lastSlash]
262 }
263 }
264 } else if targetDir != "." {
265 // Go to subdirectory
266 s.cwd = s.cwd + "\\" + targetDir
267 }
268 } else {
269 // Absolute path
270 s.cwd = targetDir
271 }
272
273 // Verify the directory exists
274 if _, err := os.Stat(s.cwd); err != nil {
275 return "", fmt.Sprintf("cd: %s: No such file or directory", targetDir), err
276 }
277
278 return "", "", nil
279}
280
281// execPOSIX executes commands using POSIX shell emulation (cross-platform)
282func (s *Shell) execPOSIX(ctx context.Context, command string) (string, string, error) {
283 line, err := syntax.NewParser().Parse(strings.NewReader(command), "")
284 if err != nil {
285 return "", "", fmt.Errorf("could not parse command: %w", err)
286 }
287
288 var stdout, stderr bytes.Buffer
289 runner, err := interp.New(
290 interp.StdIO(nil, &stdout, &stderr),
291 interp.Interactive(false),
292 interp.Env(expand.ListEnviron(s.env...)),
293 interp.Dir(s.cwd),
294 )
295 if err != nil {
296 return "", "", fmt.Errorf("could not run command: %w", err)
297 }
298
299 err = runner.Run(ctx, line)
300 s.cwd = runner.Dir
301 s.env = []string{}
302 for name, vr := range runner.Vars {
303 s.env = append(s.env, fmt.Sprintf("%s=%s", name, vr.Str))
304 }
305 s.logger.InfoPersist("POSIX command finished", "command", command, "err", err)
306 return stdout.String(), stderr.String(), err
307}
308
// IsInterrupt reports whether err is, or wraps, context.Canceled or
// context.DeadlineExceeded. In other words, it reports whether a command
// stopped because its context was cancelled or timed out rather than because
// it failed. A nil error is not an interrupt.
func IsInterrupt(err error) bool {
	for _, target := range []error{context.Canceled, context.DeadlineExceeded} {
		if errors.Is(err, target) {
			return true
		}
	}
	return false
}
314
315// ExitCode extracts the exit code from an error
316func ExitCode(err error) int {
317 if err == nil {
318 return 0
319 }
320 status, ok := interp.IsExitStatus(err)
321 if ok {
322 return int(status)
323 }
324 return 1
325}