expand_test.go

  1package shell
  2
  3import (
  4	"strings"
  5	"testing"
  6
  7	"github.com/stretchr/testify/require"
  8)
  9
 10func TestExpandValue_Success(t *testing.T) {
 11	t.Parallel()
 12
 13	tests := []struct {
 14		name  string
 15		value string
 16		env   []string
 17		want  string
 18	}{
 19		{
 20			name:  "plain string round trip",
 21			value: "hello world",
 22			want:  "hello world",
 23		},
 24		{
 25			name:  "plain var from env",
 26			value: "$FOO",
 27			env:   []string{"FOO=bar"},
 28			want:  "bar",
 29		},
 30		{
 31			name:  "braced var from env",
 32			value: "pre-${FOO}-post",
 33			env:   []string{"FOO=bar"},
 34			want:  "pre-bar-post",
 35		},
 36		{
 37			name:  "default syntax on unset",
 38			value: "${MISSING:-fallback}",
 39			want:  "fallback",
 40		},
 41		{
 42			name:  "default syntax on set preserves value",
 43			value: "${SET:-fallback}",
 44			env:   []string{"SET=used"},
 45			want:  "used",
 46		},
 47		{
 48			name:  "command substitution",
 49			value: "$(echo hi)",
 50			want:  "hi",
 51		},
 52		{
 53			name:  "command substitution preserves internal spaces",
 54			value: `$(echo "a b")`,
 55			want:  "a b",
 56		},
 57		{
 58			name:  "command substitution strips only trailing newline",
 59			value: "$(printf 'a\\nb\\n')",
 60			want:  "a\nb",
 61		},
 62		{
 63			name:  "literal spaces around cmdsubst are preserved",
 64			value: "  $(echo v)  ",
 65			want:  "  v  ",
 66		},
 67		{
 68			name:  "paren inside quoted arg to echo",
 69			value: `$(echo ")")`,
 70			want:  ")",
 71		},
 72		{
 73			name:  "nested command substitution",
 74			value: "$(echo $(echo hi))",
 75			want:  "hi",
 76		},
 77		{
 78			name:  "glob-like input round trips unchanged",
 79			value: "*.go",
 80			want:  "*.go",
 81		},
 82		{
 83			name:  "tilde round trips unchanged",
 84			value: "~",
 85			want:  "~",
 86		},
 87		{
 88			name:  "env var inside cmdsubst",
 89			value: `$(printf '%s' "$FOO")`,
 90			env:   []string{"FOO=bar"},
 91			want:  "bar",
 92		},
 93	}
 94
 95	for _, tc := range tests {
 96		t.Run(tc.name, func(t *testing.T) {
 97			t.Parallel()
 98			got, err := ExpandValue(t.Context(), tc.value, tc.env)
 99			require.NoError(t, err)
100			require.Equal(t, tc.want, got)
101		})
102	}
103}
104
// TestExpandValue_Errors checks that ExpandValue returns an error, rather than
// an empty or partial expansion, for unset variables (bare, braced, and inside
// a command substitution), malformed input, and command substitutions that
// exit non-zero. For failing substitutions it also checks that the error text
// carries the exit status and the inner command's stderr, and that the
// surfaced stderr is bounded by maxInnerStderrBytes.
func TestExpandValue_Errors(t *testing.T) {
	t.Parallel()

	t.Run("unset var is an error, not empty", func(t *testing.T) {
		t.Parallel()
		_, err := ExpandValue(t.Context(), "$MISSING", nil)
		require.Error(t, err)
	})

	t.Run("unset var inside braces is an error", func(t *testing.T) {
		t.Parallel()
		_, err := ExpandValue(t.Context(), "${MISSING}", nil)
		require.Error(t, err)
	})

	t.Run("unset var inside cmdsubst is an error", func(t *testing.T) {
		t.Parallel()
		_, err := ExpandValue(t.Context(), `$(printf '%s' "$MISSING")`, nil)
		require.Error(t, err)
	})

	t.Run("bad syntax returns error", func(t *testing.T) {
		t.Parallel()
		_, err := ExpandValue(t.Context(), "$(", nil)
		require.Error(t, err)
	})

	t.Run("inner non-zero exit returns error with exit code", func(t *testing.T) {
		t.Parallel()
		_, err := ExpandValue(t.Context(), "$(false)", nil)
		require.Error(t, err)
		require.Contains(t, err.Error(), "exit status 1")
	})

	t.Run("inner explicit exit code is surfaced", func(t *testing.T) {
		t.Parallel()
		_, err := ExpandValue(t.Context(), "$(exit 7)", nil)
		require.Error(t, err)
		require.Contains(t, err.Error(), "exit status 7")
	})

	t.Run("inner stderr is surfaced", func(t *testing.T) {
		t.Parallel()
		_, err := ExpandValue(
			t.Context(),
			`$(printf 'boom\n' 1>&2; exit 1)`,
			nil,
		)
		require.Error(t, err)
		require.Contains(t, err.Error(), "boom")
	})

	t.Run("inner stderr is truncated to byte budget", func(t *testing.T) {
		t.Parallel()
		// Emit twice maxInnerStderrBytes bytes of 'X' on stderr so the budget
		// is guaranteed to be exceeded.
		long := strings.Repeat("X", maxInnerStderrBytes*2)
		_, err := ExpandValue(
			t.Context(),
			`$(printf '`+long+`' 1>&2; exit 1)`,
			nil,
		)
		require.Error(t, err)
		// Without this positive check, the bound assertion below would pass
		// vacuously if stderr were dropped from the error altogether.
		require.Contains(t, err.Error(), "X", "stderr should still be surfaced")
		require.NotContains(
			t,
			err.Error(),
			strings.Repeat("X", maxInnerStderrBytes+1),
			"stderr should be bounded",
		)
	})
}
175
// TestSanitizeStderr pins how sanitizeStderr turns raw stderr bytes into text:
// trailing newlines are trimmed, tabs and interior newlines are kept, a
// control byte such as 0x01 is replaced with '?', and oversized input is cut
// down to exactly maxInnerStderrBytes.
func TestSanitizeStderr(t *testing.T) {
	t.Parallel()

	exact := []struct {
		label string
		raw   []byte
		want  string
	}{
		{label: "trims trailing newlines", raw: []byte("hi\n\n"), want: "hi"},
		{label: "preserves tabs and embedded newlines", raw: []byte("a\tb\nc"), want: "a\tb\nc"},
		{label: "replaces control characters", raw: []byte{'a', 0x01, 'b'}, want: "a?b"},
	}
	for _, tc := range exact {
		t.Run(tc.label, func(t *testing.T) {
			t.Parallel()
			require.Equal(t, tc.want, sanitizeStderr(tc.raw))
		})
	}

	t.Run("bounds output", func(t *testing.T) {
		t.Parallel()
		oversized := []byte(strings.Repeat("x", maxInnerStderrBytes*2))
		require.Len(t, sanitizeStderr(oversized), maxInnerStderrBytes)
	})
}