run_test.go

  1package shell
  2
  3import (
  4	"bytes"
  5	"context"
  6	"errors"
  7	"fmt"
  8	"slices"
  9	"strings"
 10	"sync"
 11	"testing"
 12	"time"
 13)
 14
 15func TestRun_Echo(t *testing.T) {
 16	var stdout, stderr bytes.Buffer
 17	err := Run(t.Context(), RunOptions{
 18		Command: "echo hi",
 19		Cwd:     t.TempDir(),
 20		Stdout:  &stdout,
 21		Stderr:  &stderr,
 22	})
 23	if err != nil {
 24		t.Fatalf("Run returned error: %v (stderr=%q)", err, stderr.String())
 25	}
 26	if got := stdout.String(); got != "hi\n" {
 27		t.Fatalf("stdout = %q, want %q", got, "hi\n")
 28	}
 29}
 30
 31func TestRun_ExitCode(t *testing.T) {
 32	err := Run(t.Context(), RunOptions{
 33		Command: "exit 7",
 34		Cwd:     t.TempDir(),
 35	})
 36	if err == nil {
 37		t.Fatal("expected error for exit 7, got nil")
 38	}
 39	if code := ExitCode(err); code != 7 {
 40		t.Fatalf("ExitCode = %d, want 7", code)
 41	}
 42}
 43
 44func TestRun_Stdin(t *testing.T) {
 45	// Use the `read` shell builtin so the test doesn't depend on any
 46	// external binary being on PATH (we pass an empty Env here).
 47	var stdout bytes.Buffer
 48	err := Run(t.Context(), RunOptions{
 49		Command: "read line; echo got:$line",
 50		Cwd:     t.TempDir(),
 51		Stdin:   strings.NewReader("hello\n"),
 52		Stdout:  &stdout,
 53	})
 54	if err != nil {
 55		t.Fatalf("Run returned error: %v", err)
 56	}
 57	if got := stdout.String(); got != "got:hello\n" {
 58		t.Fatalf("stdout = %q, want %q", got, "got:hello\n")
 59	}
 60}
 61
 62func TestRun_Env(t *testing.T) {
 63	var stdout bytes.Buffer
 64	err := Run(t.Context(), RunOptions{
 65		Command: `echo "$FOO"`,
 66		Cwd:     t.TempDir(),
 67		Env:     []string{"FOO=bar"},
 68		Stdout:  &stdout,
 69	})
 70	if err != nil {
 71		t.Fatalf("Run returned error: %v", err)
 72	}
 73	if got := stdout.String(); got != "bar\n" {
 74		t.Fatalf("stdout = %q, want %q", got, "bar\n")
 75	}
 76}
 77
 78func TestRun_Cwd(t *testing.T) {
 79	dir := t.TempDir()
 80	var stdout bytes.Buffer
 81	err := Run(t.Context(), RunOptions{
 82		Command: "pwd",
 83		Cwd:     dir,
 84		Stdout:  &stdout,
 85	})
 86	if err != nil {
 87		t.Fatalf("Run returned error: %v", err)
 88	}
 89	// mvdan's pwd builtin resolves symlinks (e.g. /var -> /private/var on
 90	// macOS). Compare against a suffix so we don't get bitten by that.
 91	got := strings.TrimRight(stdout.String(), "\n")
 92	if !strings.HasSuffix(got, dir) && !strings.HasSuffix(dir, got) {
 93		t.Fatalf("pwd = %q, want it to match %q", got, dir)
 94	}
 95}
 96
 97func TestRun_JqBuiltin(t *testing.T) {
 98	var stdout bytes.Buffer
 99	err := Run(t.Context(), RunOptions{
100		Command: `echo '{"a":1}' | jq .a`,
101		Cwd:     t.TempDir(),
102		Stdout:  &stdout,
103	})
104	if err != nil {
105		t.Fatalf("Run returned error: %v", err)
106	}
107	if got := stdout.String(); got != "1\n" {
108		t.Fatalf("stdout = %q, want %q", got, "1\n")
109	}
110}
111
112func TestRun_ParallelIsolation(t *testing.T) {
113	const n = 10
114	var wg sync.WaitGroup
115	wg.Add(n)
116	errs := make([]error, n)
117	outs := make([]string, n)
118	dirs := make([]string, n)
119	for i := range n {
120		dirs[i] = t.TempDir()
121		go func(i int) {
122			defer wg.Done()
123			var stdout bytes.Buffer
124			errs[i] = Run(t.Context(), RunOptions{
125				Command: `echo "$MARKER"`,
126				Cwd:     dirs[i],
127				Env:     []string{fmt.Sprintf("MARKER=id-%d", i)},
128				Stdout:  &stdout,
129			})
130			outs[i] = stdout.String()
131		}(i)
132	}
133	wg.Wait()
134	for i := range n {
135		if errs[i] != nil {
136			t.Errorf("goroutine %d: err = %v", i, errs[i])
137			continue
138		}
139		want := fmt.Sprintf("id-%d\n", i)
140		if outs[i] != want {
141			t.Errorf("goroutine %d: stdout = %q, want %q", i, outs[i], want)
142		}
143	}
144}
145
146// TestRun_CtxCancel_BusyLoop verifies that a pure-shell loop respects ctx
147// cancellation. mvdan's interpreter checks ctx between statements, so this
148// should return quickly even without any external command. The test bounds
149// its own wait via a select so a regression can't hang CI.
150func TestRun_CtxCancel_BusyLoop(t *testing.T) {
151	ctx, cancel := context.WithTimeout(t.Context(), 500*time.Millisecond)
152	t.Cleanup(cancel)
153
154	done := make(chan error, 1)
155	go func() {
156		done <- Run(ctx, RunOptions{
157			Command: "while true; do :; done",
158			Cwd:     t.TempDir(),
159		})
160	}()
161
162	select {
163	case err := <-done:
164		if !IsInterrupt(err) && !errors.Is(err, context.DeadlineExceeded) {
165			t.Fatalf("expected interrupt/deadline error, got: %v", err)
166		}
167	case <-time.After(1500 * time.Millisecond):
168		t.Fatal("Run did not return within 1.5s after ctx cancel")
169	}
170}
171
172// TestRun_CtxCancel_ExternalSleep verifies ctx cancellation reaches an
173// external process via mvdan's default exec. Uses sleep, which lives in
174// coreutils on Windows and /bin on Unix.
175func TestRun_CtxCancel_ExternalSleep(t *testing.T) {
176	ctx, cancel := context.WithTimeout(t.Context(), 200*time.Millisecond)
177	t.Cleanup(cancel)
178
179	done := make(chan error, 1)
180	start := time.Now()
181	go func() {
182		done <- Run(ctx, RunOptions{
183			Command: "sleep 30",
184			Cwd:     t.TempDir(),
185		})
186	}()
187
188	select {
189	case err := <-done:
190		elapsed := time.Since(start)
191		if elapsed > time.Second {
192			t.Fatalf("sleep took too long to cancel: %v", elapsed)
193		}
194		if err == nil {
195			t.Fatal("expected non-nil error from cancelled sleep")
196		}
197	case <-time.After(time.Second):
198		t.Fatal("Run did not return within 1s after ctx cancel")
199	}
200}
201
202func TestRun_ParseError(t *testing.T) {
203	err := Run(t.Context(), RunOptions{
204		Command: "echo 'unterminated",
205		Cwd:     t.TempDir(),
206	})
207	if err == nil {
208		t.Fatal("expected parse error, got nil")
209	}
210	if !strings.Contains(err.Error(), "parse") {
211		t.Fatalf("error should mention parse: %v", err)
212	}
213}
214
215func TestRun_BlockFuncs(t *testing.T) {
216	block := CommandsBlocker([]string{"forbidden"})
217	var stderr bytes.Buffer
218	err := Run(t.Context(), RunOptions{
219		Command:    "forbidden",
220		Cwd:        t.TempDir(),
221		Stderr:     &stderr,
222		BlockFuncs: []BlockFunc{block},
223	})
224	if err == nil {
225		t.Fatal("expected error when running blocked command")
226	}
227	if !strings.Contains(err.Error(), "not allowed") {
228		t.Fatalf("expected 'not allowed' error, got: %v", err)
229	}
230}
231
232func TestRun_RequiresCwd(t *testing.T) {
233	err := Run(t.Context(), RunOptions{
234		Command: "echo hi",
235	})
236	if err == nil {
237		t.Fatal("expected error when Cwd is empty, got nil")
238	}
239	if !strings.Contains(err.Error(), "Cwd is required") {
240		t.Fatalf("error should mention Cwd requirement: %v", err)
241	}
242}
243
244func TestWithNonInteractiveEnv_Empty(t *testing.T) {
245	t.Parallel()
246	result := withNonInteractiveEnv(nil)
247	// All defaults must be present.
248	for _, want := range nonInteractiveEnvVars {
249		if !slices.Contains(result, want) {
250			t.Errorf("missing default %q in result", want)
251		}
252	}
253}
254
255func TestWithNonInteractiveEnv_OverridesExisting(t *testing.T) {
256	t.Parallel()
257	env := []string{"EDITOR=nvim", "PAGER=less", "FOO=bar"}
258	result := withNonInteractiveEnv(env)
259
260	// EDITOR and PAGER must be overridden, not preserved.
261	for _, e := range result {
262		if e == "EDITOR=nvim" {
263			t.Error("EDITOR=nvim should have been overridden")
264		}
265		if e == "PAGER=less" {
266			t.Error("PAGER=less should have been overridden")
267		}
268	}
269	// FOO must survive.
270	if !slices.Contains(result, "FOO=bar") {
271		t.Error("FOO=bar should be preserved")
272	}
273}
274
275func TestWithNonInteractiveEnv_NoPrefixCollision(t *testing.T) {
276	t.Parallel()
277	// EDITORIAL should NOT match EDITOR.
278	env := []string{"EDITORIAL=yes", "GITHUB_TOKEN=secret"}
279	result := withNonInteractiveEnv(env)
280
281	foundEditorial := false
282	foundGithub := false
283	for _, e := range result {
284		if e == "EDITORIAL=yes" {
285			foundEditorial = true
286		}
287		if e == "GITHUB_TOKEN=secret" {
288			foundGithub = true
289		}
290	}
291	if !foundEditorial {
292		t.Error("EDITORIAL=yes should not be removed by EDITOR override")
293	}
294	if !foundGithub {
295		t.Error("GITHUB_TOKEN=secret should not be removed")
296	}
297}
298
299func TestWithNonInteractiveEnv_SliceIndependence(t *testing.T) {
300	t.Parallel()
301	env := []string{"FOO=bar"}
302	result := withNonInteractiveEnv(env)
303	// Mutating the input must not affect the result.
304	env[0] = "FOO=baz"
305	for _, e := range result {
306		if e == "FOO=baz" {
307			t.Error("result shares backing array with input")
308		}
309	}
310}
311
312func TestRun_DiscardsNilWriters(t *testing.T) {
313	// No panic when Stdout/Stderr are nil.
314	err := Run(t.Context(), RunOptions{
315		Command: "echo hi; echo err >&2",
316		Cwd:     t.TempDir(),
317	})
318	if err != nil {
319		t.Fatalf("Run returned error: %v", err)
320	}
321}