1// Package shell provides cross-platform shell execution capabilities.
2//
3// WINDOWS COMPATIBILITY:
4// This implementation provides both POSIX shell emulation (mvdan.cc/sh/v3) and
5// native Windows shell support (cmd.exe/PowerShell) for optimal compatibility:
6// - On Windows: Uses native cmd.exe or PowerShell for Windows-specific commands
7// - Cross-platform: Falls back to POSIX emulation for Unix-style commands
8// - Automatic detection: Chooses the best shell based on command and platform
9package shell
10
import (
	"bytes"
	"context"
	"errors"
	"fmt"
	"os"
	"os/exec"
	"path/filepath"
	"runtime"
	"strings"
	"sync"

	"github.com/charmbracelet/crush/internal/logging"
	"mvdan.cc/sh/v3/expand"
	"mvdan.cc/sh/v3/interp"
	"mvdan.cc/sh/v3/syntax"
)
27
// ShellType identifies which shell implementation executes a command.
type ShellType int

// Supported shell types. The zero value, ShellTypePOSIX, is the default choice.
const (
	// ShellTypePOSIX runs commands through the in-process POSIX shell interpreter.
	ShellTypePOSIX ShellType = iota
	// ShellTypeCmd runs commands through the native Windows cmd.exe.
	ShellTypeCmd
	// ShellTypePowerShell runs commands through Windows PowerShell.
	ShellTypePowerShell
)
36
// PersistentShell executes commands while carrying the working directory and
// environment over from one Exec call to the next.
type PersistentShell struct {
	mu sync.Mutex // held for the duration of Exec; guards env and cwd

	env []string // environment passed to commands, in KEY=VALUE form
	cwd string   // directory the next command starts in
}

// once makes GetPersistentShell construct shellInstance exactly one time.
var once sync.Once

// shellInstance is the process-wide shell handed out by GetPersistentShell.
var shellInstance *PersistentShell
47
// windowsNativeCommands holds lower-case command names that, on Windows, are
// routed to cmd.exe instead of the POSIX emulator when they appear as the
// first word of a command (determineShellType lower-cases that word before
// the lookup).
var windowsNativeCommands = map[string]bool{
	// File and directory commands.
	"dir": true, "type": true, "copy": true, "move": true, "del": true,
	"md": true, "mkdir": true, "rd": true, "rmdir": true,

	// Console.
	"cls": true,

	// Lookup, process, and system administration tools.
	"where": true, "tasklist": true, "taskkill": true,
	"net": true, "sc": true, "reg": true, "wmic": true,
}
68
69func GetPersistentShell(cwd string) *PersistentShell {
70 once.Do(func() {
71 shellInstance = newPersistentShell(cwd)
72 })
73 return shellInstance
74}
75
76func newPersistentShell(cwd string) *PersistentShell {
77 return &PersistentShell{
78 cwd: cwd,
79 env: os.Environ(),
80 }
81}
82
83func (s *PersistentShell) Exec(ctx context.Context, command string) (string, string, error) {
84 s.mu.Lock()
85 defer s.mu.Unlock()
86
87 // Determine which shell to use based on platform and command
88 shellType := s.determineShellType(command)
89
90 switch shellType {
91 case ShellTypeCmd:
92 return s.execWindows(ctx, command, "cmd")
93 case ShellTypePowerShell:
94 return s.execWindows(ctx, command, "powershell")
95 default:
96 return s.execPOSIX(ctx, command)
97 }
98}
99
100// determineShellType decides which shell to use based on platform and command
101func (s *PersistentShell) determineShellType(command string) ShellType {
102 if runtime.GOOS != "windows" {
103 return ShellTypePOSIX
104 }
105
106 // Extract the first command from the command line
107 parts := strings.Fields(command)
108 if len(parts) == 0 {
109 return ShellTypePOSIX
110 }
111
112 firstCmd := strings.ToLower(parts[0])
113
114 // Check if it's a Windows-specific command
115 if windowsNativeCommands[firstCmd] {
116 return ShellTypeCmd
117 }
118
119 // Check for PowerShell-specific syntax
120 if strings.Contains(command, "Get-") || strings.Contains(command, "Set-") ||
121 strings.Contains(command, "New-") || strings.Contains(command, "$_") ||
122 strings.Contains(command, "| Where-Object") || strings.Contains(command, "| ForEach-Object") {
123 return ShellTypePowerShell
124 }
125
126 // Default to POSIX emulation for cross-platform compatibility
127 return ShellTypePOSIX
128}
129
130// execWindows executes commands using native Windows shells (cmd.exe or PowerShell)
131func (s *PersistentShell) execWindows(ctx context.Context, command string, shell string) (string, string, error) {
132 var cmd *exec.Cmd
133
134 // Handle directory changes specially to maintain persistent shell behavior
135 if strings.HasPrefix(strings.TrimSpace(command), "cd ") {
136 return s.handleWindowsCD(command)
137 }
138
139 switch shell {
140 case "cmd":
141 // Use cmd.exe for Windows commands
142 // Add current directory context to maintain state
143 fullCommand := fmt.Sprintf("cd /d \"%s\" && %s", s.cwd, command)
144 cmd = exec.CommandContext(ctx, "cmd", "/C", fullCommand)
145 case "powershell":
146 // Use PowerShell for PowerShell commands
147 // Add current directory context to maintain state
148 fullCommand := fmt.Sprintf("Set-Location '%s'; %s", s.cwd, command)
149 cmd = exec.CommandContext(ctx, "powershell", "-Command", fullCommand)
150 default:
151 return "", "", fmt.Errorf("unsupported Windows shell: %s", shell)
152 }
153
154 // Set environment variables
155 cmd.Env = s.env
156
157 var stdout, stderr bytes.Buffer
158 cmd.Stdout = &stdout
159 cmd.Stderr = &stderr
160
161 err := cmd.Run()
162
163 logging.InfoPersist("Windows command finished", "shell", shell, "command", command, "err", err)
164 return stdout.String(), stderr.String(), err
165}
166
167// handleWindowsCD handles directory changes for Windows shells
168func (s *PersistentShell) handleWindowsCD(command string) (string, string, error) {
169 // Extract the target directory from the cd command
170 parts := strings.Fields(command)
171 if len(parts) < 2 {
172 return "", "cd: missing directory argument", fmt.Errorf("missing directory argument")
173 }
174
175 targetDir := parts[1]
176
177 // Handle relative paths
178 if !strings.Contains(targetDir, ":") && !strings.HasPrefix(targetDir, "\\") {
179 // Relative path - resolve against current directory
180 if targetDir == ".." {
181 // Go up one directory
182 if len(s.cwd) > 3 { // Don't go above drive root (C:\)
183 lastSlash := strings.LastIndex(s.cwd, "\\")
184 if lastSlash > 2 { // Keep drive letter
185 s.cwd = s.cwd[:lastSlash]
186 }
187 }
188 } else if targetDir != "." {
189 // Go to subdirectory
190 s.cwd = s.cwd + "\\" + targetDir
191 }
192 } else {
193 // Absolute path
194 s.cwd = targetDir
195 }
196
197 // Verify the directory exists
198 if _, err := os.Stat(s.cwd); err != nil {
199 return "", fmt.Sprintf("cd: %s: No such file or directory", targetDir), err
200 }
201
202 return "", "", nil
203}
204
205// execPOSIX executes commands using POSIX shell emulation (cross-platform)
206func (s *PersistentShell) execPOSIX(ctx context.Context, command string) (string, string, error) {
207 line, err := syntax.NewParser().Parse(strings.NewReader(command), "")
208 if err != nil {
209 return "", "", fmt.Errorf("could not parse command: %w", err)
210 }
211
212 var stdout, stderr bytes.Buffer
213 runner, err := interp.New(
214 interp.StdIO(nil, &stdout, &stderr),
215 interp.Interactive(false),
216 interp.Env(expand.ListEnviron(s.env...)),
217 interp.Dir(s.cwd),
218 )
219 if err != nil {
220 return "", "", fmt.Errorf("could not run command: %w", err)
221 }
222
223 err = runner.Run(ctx, line)
224 s.cwd = runner.Dir
225 s.env = []string{}
226 for name, vr := range runner.Vars {
227 s.env = append(s.env, fmt.Sprintf("%s=%s", name, vr.Str))
228 }
229 logging.InfoPersist("POSIX command finished", "command", command, "err", err)
230 return stdout.String(), stderr.String(), err
231}
232
// IsInterrupt reports whether err is, or wraps, context.Canceled or
// context.DeadlineExceeded.
func IsInterrupt(err error) bool {
	for _, target := range []error{context.Canceled, context.DeadlineExceeded} {
		if errors.Is(err, target) {
			return true
		}
	}
	return false
}
237
238func ExitCode(err error) int {
239 if err == nil {
240 return 0
241 }
242 status, ok := interp.IsExitStatus(err)
243 if ok {
244 return int(status)
245 }
246 return 1
247}