1package shell
2
3import (
4 "bytes"
5 "context"
6 "errors"
7 "fmt"
8 "slices"
9 "strings"
10 "sync"
11 "testing"
12 "time"
13)
14
15func TestRun_Echo(t *testing.T) {
16 var stdout, stderr bytes.Buffer
17 err := Run(t.Context(), RunOptions{
18 Command: "echo hi",
19 Cwd: t.TempDir(),
20 Stdout: &stdout,
21 Stderr: &stderr,
22 })
23 if err != nil {
24 t.Fatalf("Run returned error: %v (stderr=%q)", err, stderr.String())
25 }
26 if got := stdout.String(); got != "hi\n" {
27 t.Fatalf("stdout = %q, want %q", got, "hi\n")
28 }
29}
30
31func TestRun_ExitCode(t *testing.T) {
32 err := Run(t.Context(), RunOptions{
33 Command: "exit 7",
34 Cwd: t.TempDir(),
35 })
36 if err == nil {
37 t.Fatal("expected error for exit 7, got nil")
38 }
39 if code := ExitCode(err); code != 7 {
40 t.Fatalf("ExitCode = %d, want 7", code)
41 }
42}
43
44func TestRun_Stdin(t *testing.T) {
45 // Use the `read` shell builtin so the test doesn't depend on any
46 // external binary being on PATH (we pass an empty Env here).
47 var stdout bytes.Buffer
48 err := Run(t.Context(), RunOptions{
49 Command: "read line; echo got:$line",
50 Cwd: t.TempDir(),
51 Stdin: strings.NewReader("hello\n"),
52 Stdout: &stdout,
53 })
54 if err != nil {
55 t.Fatalf("Run returned error: %v", err)
56 }
57 if got := stdout.String(); got != "got:hello\n" {
58 t.Fatalf("stdout = %q, want %q", got, "got:hello\n")
59 }
60}
61
62func TestRun_Env(t *testing.T) {
63 var stdout bytes.Buffer
64 err := Run(t.Context(), RunOptions{
65 Command: `echo "$FOO"`,
66 Cwd: t.TempDir(),
67 Env: []string{"FOO=bar"},
68 Stdout: &stdout,
69 })
70 if err != nil {
71 t.Fatalf("Run returned error: %v", err)
72 }
73 if got := stdout.String(); got != "bar\n" {
74 t.Fatalf("stdout = %q, want %q", got, "bar\n")
75 }
76}
77
78func TestRun_Cwd(t *testing.T) {
79 dir := t.TempDir()
80 var stdout bytes.Buffer
81 err := Run(t.Context(), RunOptions{
82 Command: "pwd",
83 Cwd: dir,
84 Stdout: &stdout,
85 })
86 if err != nil {
87 t.Fatalf("Run returned error: %v", err)
88 }
89 // mvdan's pwd builtin resolves symlinks (e.g. /var -> /private/var on
90 // macOS). Compare against a suffix so we don't get bitten by that.
91 got := strings.TrimRight(stdout.String(), "\n")
92 if !strings.HasSuffix(got, dir) && !strings.HasSuffix(dir, got) {
93 t.Fatalf("pwd = %q, want it to match %q", got, dir)
94 }
95}
96
97func TestRun_JqBuiltin(t *testing.T) {
98 var stdout bytes.Buffer
99 err := Run(t.Context(), RunOptions{
100 Command: `echo '{"a":1}' | jq .a`,
101 Cwd: t.TempDir(),
102 Stdout: &stdout,
103 })
104 if err != nil {
105 t.Fatalf("Run returned error: %v", err)
106 }
107 if got := stdout.String(); got != "1\n" {
108 t.Fatalf("stdout = %q, want %q", got, "1\n")
109 }
110}
111
112func TestRun_ParallelIsolation(t *testing.T) {
113 const n = 10
114 var wg sync.WaitGroup
115 wg.Add(n)
116 errs := make([]error, n)
117 outs := make([]string, n)
118 dirs := make([]string, n)
119 for i := range n {
120 dirs[i] = t.TempDir()
121 go func(i int) {
122 defer wg.Done()
123 var stdout bytes.Buffer
124 errs[i] = Run(t.Context(), RunOptions{
125 Command: `echo "$MARKER"`,
126 Cwd: dirs[i],
127 Env: []string{fmt.Sprintf("MARKER=id-%d", i)},
128 Stdout: &stdout,
129 })
130 outs[i] = stdout.String()
131 }(i)
132 }
133 wg.Wait()
134 for i := range n {
135 if errs[i] != nil {
136 t.Errorf("goroutine %d: err = %v", i, errs[i])
137 continue
138 }
139 want := fmt.Sprintf("id-%d\n", i)
140 if outs[i] != want {
141 t.Errorf("goroutine %d: stdout = %q, want %q", i, outs[i], want)
142 }
143 }
144}
145
146// TestRun_CtxCancel_BusyLoop verifies that a pure-shell loop respects ctx
147// cancellation. mvdan's interpreter checks ctx between statements, so this
148// should return quickly even without any external command. The test bounds
149// its own wait via a select so a regression can't hang CI.
150func TestRun_CtxCancel_BusyLoop(t *testing.T) {
151 ctx, cancel := context.WithTimeout(t.Context(), 500*time.Millisecond)
152 t.Cleanup(cancel)
153
154 done := make(chan error, 1)
155 go func() {
156 done <- Run(ctx, RunOptions{
157 Command: "while true; do :; done",
158 Cwd: t.TempDir(),
159 })
160 }()
161
162 select {
163 case err := <-done:
164 if !IsInterrupt(err) && !errors.Is(err, context.DeadlineExceeded) {
165 t.Fatalf("expected interrupt/deadline error, got: %v", err)
166 }
167 case <-time.After(1500 * time.Millisecond):
168 t.Fatal("Run did not return within 1.5s after ctx cancel")
169 }
170}
171
172// TestRun_CtxCancel_ExternalSleep verifies ctx cancellation reaches an
173// external process via mvdan's default exec. Uses sleep, which lives in
174// coreutils on Windows and /bin on Unix.
175func TestRun_CtxCancel_ExternalSleep(t *testing.T) {
176 ctx, cancel := context.WithTimeout(t.Context(), 200*time.Millisecond)
177 t.Cleanup(cancel)
178
179 done := make(chan error, 1)
180 start := time.Now()
181 go func() {
182 done <- Run(ctx, RunOptions{
183 Command: "sleep 30",
184 Cwd: t.TempDir(),
185 })
186 }()
187
188 select {
189 case err := <-done:
190 elapsed := time.Since(start)
191 if elapsed > time.Second {
192 t.Fatalf("sleep took too long to cancel: %v", elapsed)
193 }
194 if err == nil {
195 t.Fatal("expected non-nil error from cancelled sleep")
196 }
197 case <-time.After(time.Second):
198 t.Fatal("Run did not return within 1s after ctx cancel")
199 }
200}
201
202func TestRun_ParseError(t *testing.T) {
203 err := Run(t.Context(), RunOptions{
204 Command: "echo 'unterminated",
205 Cwd: t.TempDir(),
206 })
207 if err == nil {
208 t.Fatal("expected parse error, got nil")
209 }
210 if !strings.Contains(err.Error(), "parse") {
211 t.Fatalf("error should mention parse: %v", err)
212 }
213}
214
215func TestRun_BlockFuncs(t *testing.T) {
216 block := CommandsBlocker([]string{"forbidden"})
217 var stderr bytes.Buffer
218 err := Run(t.Context(), RunOptions{
219 Command: "forbidden",
220 Cwd: t.TempDir(),
221 Stderr: &stderr,
222 BlockFuncs: []BlockFunc{block},
223 })
224 if err == nil {
225 t.Fatal("expected error when running blocked command")
226 }
227 if !strings.Contains(err.Error(), "not allowed") {
228 t.Fatalf("expected 'not allowed' error, got: %v", err)
229 }
230}
231
232func TestRun_RequiresCwd(t *testing.T) {
233 err := Run(t.Context(), RunOptions{
234 Command: "echo hi",
235 })
236 if err == nil {
237 t.Fatal("expected error when Cwd is empty, got nil")
238 }
239 if !strings.Contains(err.Error(), "Cwd is required") {
240 t.Fatalf("error should mention Cwd requirement: %v", err)
241 }
242}
243
244func TestWithNonInteractiveEnv_Empty(t *testing.T) {
245 t.Parallel()
246 result := withNonInteractiveEnv(nil)
247 // All defaults must be present.
248 for _, want := range nonInteractiveEnvVars {
249 if !slices.Contains(result, want) {
250 t.Errorf("missing default %q in result", want)
251 }
252 }
253}
254
255func TestWithNonInteractiveEnv_OverridesExisting(t *testing.T) {
256 t.Parallel()
257 env := []string{"EDITOR=nvim", "PAGER=less", "FOO=bar"}
258 result := withNonInteractiveEnv(env)
259
260 // EDITOR and PAGER must be overridden, not preserved.
261 for _, e := range result {
262 if e == "EDITOR=nvim" {
263 t.Error("EDITOR=nvim should have been overridden")
264 }
265 if e == "PAGER=less" {
266 t.Error("PAGER=less should have been overridden")
267 }
268 }
269 // FOO must survive.
270 if !slices.Contains(result, "FOO=bar") {
271 t.Error("FOO=bar should be preserved")
272 }
273}
274
275func TestWithNonInteractiveEnv_NoPrefixCollision(t *testing.T) {
276 t.Parallel()
277 // EDITORIAL should NOT match EDITOR.
278 env := []string{"EDITORIAL=yes", "GITHUB_TOKEN=secret"}
279 result := withNonInteractiveEnv(env)
280
281 foundEditorial := false
282 foundGithub := false
283 for _, e := range result {
284 if e == "EDITORIAL=yes" {
285 foundEditorial = true
286 }
287 if e == "GITHUB_TOKEN=secret" {
288 foundGithub = true
289 }
290 }
291 if !foundEditorial {
292 t.Error("EDITORIAL=yes should not be removed by EDITOR override")
293 }
294 if !foundGithub {
295 t.Error("GITHUB_TOKEN=secret should not be removed")
296 }
297}
298
299func TestWithNonInteractiveEnv_SliceIndependence(t *testing.T) {
300 t.Parallel()
301 env := []string{"FOO=bar"}
302 result := withNonInteractiveEnv(env)
303 // Mutating the input must not affect the result.
304 env[0] = "FOO=baz"
305 for _, e := range result {
306 if e == "FOO=baz" {
307 t.Error("result shares backing array with input")
308 }
309 }
310}
311
312func TestRun_DiscardsNilWriters(t *testing.T) {
313 // No panic when Stdout/Stderr are nil.
314 err := Run(t.Context(), RunOptions{
315 Command: "echo hi; echo err >&2",
316 Cwd: t.TempDir(),
317 })
318 if err != nil {
319 t.Fatalf("Run returned error: %v", err)
320 }
321}