bash_test.go

  1package tools
  2
  3import (
  4	"context"
  5	"encoding/json"
  6	"os"
  7	"strings"
  8	"testing"
  9	"time"
 10
 11	"github.com/stretchr/testify/assert"
 12	"github.com/stretchr/testify/require"
 13)
 14
 15func TestBashTool_Info(t *testing.T) {
 16	tool := NewBashTool(newMockPermissionService(true))
 17	info := tool.Info()
 18
 19	assert.Equal(t, BashToolName, info.Name)
 20	assert.NotEmpty(t, info.Description)
 21	assert.Contains(t, info.Parameters, "command")
 22	assert.Contains(t, info.Parameters, "timeout")
 23	assert.Contains(t, info.Required, "command")
 24}
 25
 26func TestBashTool_Run(t *testing.T) {
 27	// Save original working directory
 28	origWd, err := os.Getwd()
 29	require.NoError(t, err)
 30	defer func() {
 31		os.Chdir(origWd)
 32	}()
 33
 34	t.Run("executes command successfully", func(t *testing.T) {
 35		tool := NewBashTool(newMockPermissionService(true))
 36		params := BashParams{
 37			Command: "echo 'Hello World'",
 38		}
 39
 40		paramsJSON, err := json.Marshal(params)
 41		require.NoError(t, err)
 42
 43		call := ToolCall{
 44			Name:  BashToolName,
 45			Input: string(paramsJSON),
 46		}
 47
 48		response, err := tool.Run(context.Background(), call)
 49		require.NoError(t, err)
 50		assert.Equal(t, "Hello World\n", response.Content)
 51	})
 52
 53	t.Run("handles invalid parameters", func(t *testing.T) {
 54		tool := NewBashTool(newMockPermissionService(true))
 55		call := ToolCall{
 56			Name:  BashToolName,
 57			Input: "invalid json",
 58		}
 59
 60		response, err := tool.Run(context.Background(), call)
 61		require.NoError(t, err)
 62		assert.Contains(t, response.Content, "invalid parameters")
 63	})
 64
 65	t.Run("handles missing command", func(t *testing.T) {
 66		tool := NewBashTool(newMockPermissionService(true))
 67		params := BashParams{
 68			Command: "",
 69		}
 70
 71		paramsJSON, err := json.Marshal(params)
 72		require.NoError(t, err)
 73
 74		call := ToolCall{
 75			Name:  BashToolName,
 76			Input: string(paramsJSON),
 77		}
 78
 79		response, err := tool.Run(context.Background(), call)
 80		require.NoError(t, err)
 81		assert.Contains(t, response.Content, "missing command")
 82	})
 83
 84	t.Run("handles banned commands", func(t *testing.T) {
 85		tool := NewBashTool(newMockPermissionService(true))
 86
 87		for _, bannedCmd := range bannedCommands {
 88			params := BashParams{
 89				Command: bannedCmd + " arg1 arg2",
 90			}
 91
 92			paramsJSON, err := json.Marshal(params)
 93			require.NoError(t, err)
 94
 95			call := ToolCall{
 96				Name:  BashToolName,
 97				Input: string(paramsJSON),
 98			}
 99
100			response, err := tool.Run(context.Background(), call)
101			require.NoError(t, err)
102			assert.Contains(t, response.Content, "not allowed", "Command %s should be blocked", bannedCmd)
103		}
104	})
105
106	t.Run("handles multi-word safe commands without permission check", func(t *testing.T) {
107		tool := NewBashTool(newMockPermissionService(false))
108
109		// Test with multi-word safe commands
110		multiWordCommands := []string{
111			"go env",
112		}
113
114		for _, cmd := range multiWordCommands {
115			params := BashParams{
116				Command: cmd,
117			}
118
119			paramsJSON, err := json.Marshal(params)
120			require.NoError(t, err)
121
122			call := ToolCall{
123				Name:  BashToolName,
124				Input: string(paramsJSON),
125			}
126
127			response, err := tool.Run(context.Background(), call)
128			require.NoError(t, err)
129			assert.NotContains(t, response.Content, "permission denied",
130				"Command %s should be allowed without permission", cmd)
131		}
132	})
133
134	t.Run("handles permission denied", func(t *testing.T) {
135		tool := NewBashTool(newMockPermissionService(false))
136
137		// Test with a command that requires permission
138		params := BashParams{
139			Command: "mkdir test_dir",
140		}
141
142		paramsJSON, err := json.Marshal(params)
143		require.NoError(t, err)
144
145		call := ToolCall{
146			Name:  BashToolName,
147			Input: string(paramsJSON),
148		}
149
150		response, err := tool.Run(context.Background(), call)
151		require.NoError(t, err)
152		assert.Contains(t, response.Content, "permission denied")
153	})
154
155	t.Run("handles command timeout", func(t *testing.T) {
156		tool := NewBashTool(newMockPermissionService(true))
157		params := BashParams{
158			Command: "sleep 2",
159			Timeout: 100, // 100ms timeout
160		}
161
162		paramsJSON, err := json.Marshal(params)
163		require.NoError(t, err)
164
165		call := ToolCall{
166			Name:  BashToolName,
167			Input: string(paramsJSON),
168		}
169
170		response, err := tool.Run(context.Background(), call)
171		require.NoError(t, err)
172		assert.Contains(t, response.Content, "aborted")
173	})
174
175	t.Run("handles command with stderr output", func(t *testing.T) {
176		tool := NewBashTool(newMockPermissionService(true))
177		params := BashParams{
178			Command: "echo 'error message' >&2",
179		}
180
181		paramsJSON, err := json.Marshal(params)
182		require.NoError(t, err)
183
184		call := ToolCall{
185			Name:  BashToolName,
186			Input: string(paramsJSON),
187		}
188
189		response, err := tool.Run(context.Background(), call)
190		require.NoError(t, err)
191		assert.Contains(t, response.Content, "error message")
192	})
193
194	t.Run("handles command with both stdout and stderr", func(t *testing.T) {
195		tool := NewBashTool(newMockPermissionService(true))
196		params := BashParams{
197			Command: "echo 'stdout message' && echo 'stderr message' >&2",
198		}
199
200		paramsJSON, err := json.Marshal(params)
201		require.NoError(t, err)
202
203		call := ToolCall{
204			Name:  BashToolName,
205			Input: string(paramsJSON),
206		}
207
208		response, err := tool.Run(context.Background(), call)
209		require.NoError(t, err)
210		assert.Contains(t, response.Content, "stdout message")
211		assert.Contains(t, response.Content, "stderr message")
212	})
213
214	t.Run("handles context cancellation", func(t *testing.T) {
215		tool := NewBashTool(newMockPermissionService(true))
216		params := BashParams{
217			Command: "sleep 5",
218		}
219
220		paramsJSON, err := json.Marshal(params)
221		require.NoError(t, err)
222
223		call := ToolCall{
224			Name:  BashToolName,
225			Input: string(paramsJSON),
226		}
227
228		ctx, cancel := context.WithCancel(context.Background())
229
230		// Cancel the context after a short delay
231		go func() {
232			time.Sleep(100 * time.Millisecond)
233			cancel()
234		}()
235
236		response, err := tool.Run(ctx, call)
237		require.NoError(t, err)
238		assert.Contains(t, response.Content, "aborted")
239	})
240
241	t.Run("respects max timeout", func(t *testing.T) {
242		tool := NewBashTool(newMockPermissionService(true))
243		params := BashParams{
244			Command: "echo 'test'",
245			Timeout: MaxTimeout + 1000, // Exceeds max timeout
246		}
247
248		paramsJSON, err := json.Marshal(params)
249		require.NoError(t, err)
250
251		call := ToolCall{
252			Name:  BashToolName,
253			Input: string(paramsJSON),
254		}
255
256		response, err := tool.Run(context.Background(), call)
257		require.NoError(t, err)
258		assert.Equal(t, "test\n", response.Content)
259	})
260
261	t.Run("uses default timeout for zero or negative timeout", func(t *testing.T) {
262		tool := NewBashTool(newMockPermissionService(true))
263		params := BashParams{
264			Command: "echo 'test'",
265			Timeout: -100, // Negative timeout
266		}
267
268		paramsJSON, err := json.Marshal(params)
269		require.NoError(t, err)
270
271		call := ToolCall{
272			Name:  BashToolName,
273			Input: string(paramsJSON),
274		}
275
276		response, err := tool.Run(context.Background(), call)
277		require.NoError(t, err)
278		assert.Equal(t, "test\n", response.Content)
279	})
280}
281
282func TestTruncateOutput(t *testing.T) {
283	t.Run("does not truncate short output", func(t *testing.T) {
284		output := "short output"
285		result := truncateOutput(output)
286		assert.Equal(t, output, result)
287	})
288
289	t.Run("truncates long output", func(t *testing.T) {
290		// Create a string longer than MaxOutputLength
291		longOutput := strings.Repeat("a\n", MaxOutputLength)
292		result := truncateOutput(longOutput)
293
294		// Check that the result is shorter than the original
295		assert.Less(t, len(result), len(longOutput))
296
297		// Check that the truncation message is included
298		assert.Contains(t, result, "lines truncated")
299
300		// Check that we have the beginning and end of the original string
301		assert.True(t, strings.HasPrefix(result, "a\n"))
302		assert.True(t, strings.HasSuffix(result, "a\n"))
303	})
304}
305
306func TestCountLines(t *testing.T) {
307	testCases := []struct {
308		name     string
309		input    string
310		expected int
311	}{
312		{
313			name:     "empty string",
314			input:    "",
315			expected: 0,
316		},
317		{
318			name:     "single line",
319			input:    "line1",
320			expected: 1,
321		},
322		{
323			name:     "multiple lines",
324			input:    "line1\nline2\nline3",
325			expected: 3,
326		},
327		{
328			name:     "trailing newline",
329			input:    "line1\nline2\n",
330			expected: 3, // Empty string after last newline counts as a line
331		},
332	}
333
334	for _, tc := range testCases {
335		t.Run(tc.name, func(t *testing.T) {
336			result := countLines(tc.input)
337			assert.Equal(t, tc.expected, result)
338		})
339	}
340}