bash_test.go

  1package tools
  2
  3import (
  4	"context"
  5	"encoding/json"
  6	"os"
  7	"strings"
  8	"testing"
  9	"time"
 10
 11	"github.com/kujtimiihoxha/termai/internal/permission"
 12	"github.com/kujtimiihoxha/termai/internal/pubsub"
 13	"github.com/stretchr/testify/assert"
 14	"github.com/stretchr/testify/require"
 15)
 16
 17func TestBashTool_Info(t *testing.T) {
 18	tool := NewBashTool()
 19	info := tool.Info()
 20
 21	assert.Equal(t, BashToolName, info.Name)
 22	assert.NotEmpty(t, info.Description)
 23	assert.Contains(t, info.Parameters, "command")
 24	assert.Contains(t, info.Parameters, "timeout")
 25	assert.Contains(t, info.Required, "command")
 26}
 27
 28func TestBashTool_Run(t *testing.T) {
 29	// Setup a mock permission handler that always allows
 30	origPermission := permission.Default
 31	defer func() {
 32		permission.Default = origPermission
 33	}()
 34	permission.Default = newMockPermissionService(true)
 35
 36	// Save original working directory
 37	origWd, err := os.Getwd()
 38	require.NoError(t, err)
 39	defer func() {
 40		os.Chdir(origWd)
 41	}()
 42
 43	t.Run("executes command successfully", func(t *testing.T) {
 44		permission.Default = newMockPermissionService(true)
 45		tool := NewBashTool()
 46		params := BashParams{
 47			Command: "echo 'Hello World'",
 48		}
 49
 50		paramsJSON, err := json.Marshal(params)
 51		require.NoError(t, err)
 52
 53		call := ToolCall{
 54			Name:  BashToolName,
 55			Input: string(paramsJSON),
 56		}
 57
 58		response, err := tool.Run(context.Background(), call)
 59		require.NoError(t, err)
 60		assert.Equal(t, "Hello World\n", response.Content)
 61	})
 62
 63	t.Run("handles invalid parameters", func(t *testing.T) {
 64		permission.Default = newMockPermissionService(true)
 65
 66		tool := NewBashTool()
 67		call := ToolCall{
 68			Name:  BashToolName,
 69			Input: "invalid json",
 70		}
 71
 72		response, err := tool.Run(context.Background(), call)
 73		require.NoError(t, err)
 74		assert.Contains(t, response.Content, "invalid parameters")
 75	})
 76
 77	t.Run("handles missing command", func(t *testing.T) {
 78		permission.Default = newMockPermissionService(true)
 79
 80		tool := NewBashTool()
 81		params := BashParams{
 82			Command: "",
 83		}
 84
 85		paramsJSON, err := json.Marshal(params)
 86		require.NoError(t, err)
 87
 88		call := ToolCall{
 89			Name:  BashToolName,
 90			Input: string(paramsJSON),
 91		}
 92
 93		response, err := tool.Run(context.Background(), call)
 94		require.NoError(t, err)
 95		assert.Contains(t, response.Content, "missing command")
 96	})
 97
 98	t.Run("handles banned commands", func(t *testing.T) {
 99		permission.Default = newMockPermissionService(true)
100
101		tool := NewBashTool()
102
103		for _, bannedCmd := range BannedCommands {
104			params := BashParams{
105				Command: bannedCmd + " arg1 arg2",
106			}
107
108			paramsJSON, err := json.Marshal(params)
109			require.NoError(t, err)
110
111			call := ToolCall{
112				Name:  BashToolName,
113				Input: string(paramsJSON),
114			}
115
116			response, err := tool.Run(context.Background(), call)
117			require.NoError(t, err)
118			assert.Contains(t, response.Content, "not allowed", "Command %s should be blocked", bannedCmd)
119		}
120	})
121
122	t.Run("handles safe read-only commands without permission check", func(t *testing.T) {
123		permission.Default = newMockPermissionService(false)
124
125		tool := NewBashTool()
126
127		// Test with a safe read-only command
128		params := BashParams{
129			Command: "echo 'test'",
130		}
131
132		paramsJSON, err := json.Marshal(params)
133		require.NoError(t, err)
134
135		call := ToolCall{
136			Name:  BashToolName,
137			Input: string(paramsJSON),
138		}
139
140		response, err := tool.Run(context.Background(), call)
141		require.NoError(t, err)
142		assert.Equal(t, "test\n", response.Content)
143	})
144
145	t.Run("handles permission denied", func(t *testing.T) {
146		permission.Default = newMockPermissionService(false)
147
148		tool := NewBashTool()
149
150		// Test with a command that requires permission
151		params := BashParams{
152			Command: "mkdir test_dir",
153		}
154
155		paramsJSON, err := json.Marshal(params)
156		require.NoError(t, err)
157
158		call := ToolCall{
159			Name:  BashToolName,
160			Input: string(paramsJSON),
161		}
162
163		response, err := tool.Run(context.Background(), call)
164		require.NoError(t, err)
165		assert.Contains(t, response.Content, "permission denied")
166	})
167
168	t.Run("handles command timeout", func(t *testing.T) {
169		permission.Default = newMockPermissionService(true)
170		tool := NewBashTool()
171		params := BashParams{
172			Command: "sleep 2",
173			Timeout: 100, // 100ms timeout
174		}
175
176		paramsJSON, err := json.Marshal(params)
177		require.NoError(t, err)
178
179		call := ToolCall{
180			Name:  BashToolName,
181			Input: string(paramsJSON),
182		}
183
184		response, err := tool.Run(context.Background(), call)
185		require.NoError(t, err)
186		assert.Contains(t, response.Content, "aborted")
187	})
188
189	t.Run("handles command with stderr output", func(t *testing.T) {
190		permission.Default = newMockPermissionService(true)
191		tool := NewBashTool()
192		params := BashParams{
193			Command: "echo 'error message' >&2",
194		}
195
196		paramsJSON, err := json.Marshal(params)
197		require.NoError(t, err)
198
199		call := ToolCall{
200			Name:  BashToolName,
201			Input: string(paramsJSON),
202		}
203
204		response, err := tool.Run(context.Background(), call)
205		require.NoError(t, err)
206		assert.Contains(t, response.Content, "error message")
207	})
208
209	t.Run("handles command with both stdout and stderr", func(t *testing.T) {
210		permission.Default = newMockPermissionService(true)
211		tool := NewBashTool()
212		params := BashParams{
213			Command: "echo 'stdout message' && echo 'stderr message' >&2",
214		}
215
216		paramsJSON, err := json.Marshal(params)
217		require.NoError(t, err)
218
219		call := ToolCall{
220			Name:  BashToolName,
221			Input: string(paramsJSON),
222		}
223
224		response, err := tool.Run(context.Background(), call)
225		require.NoError(t, err)
226		assert.Contains(t, response.Content, "stdout message")
227		assert.Contains(t, response.Content, "stderr message")
228	})
229
230	t.Run("handles context cancellation", func(t *testing.T) {
231		permission.Default = newMockPermissionService(true)
232		tool := NewBashTool()
233		params := BashParams{
234			Command: "sleep 5",
235		}
236
237		paramsJSON, err := json.Marshal(params)
238		require.NoError(t, err)
239
240		call := ToolCall{
241			Name:  BashToolName,
242			Input: string(paramsJSON),
243		}
244
245		ctx, cancel := context.WithCancel(context.Background())
246
247		// Cancel the context after a short delay
248		go func() {
249			time.Sleep(100 * time.Millisecond)
250			cancel()
251		}()
252
253		response, err := tool.Run(ctx, call)
254		require.NoError(t, err)
255		assert.Contains(t, response.Content, "aborted")
256	})
257
258	t.Run("respects max timeout", func(t *testing.T) {
259		permission.Default = newMockPermissionService(true)
260		tool := NewBashTool()
261		params := BashParams{
262			Command: "echo 'test'",
263			Timeout: MaxTimeout + 1000, // Exceeds max timeout
264		}
265
266		paramsJSON, err := json.Marshal(params)
267		require.NoError(t, err)
268
269		call := ToolCall{
270			Name:  BashToolName,
271			Input: string(paramsJSON),
272		}
273
274		response, err := tool.Run(context.Background(), call)
275		require.NoError(t, err)
276		assert.Equal(t, "test\n", response.Content)
277	})
278
279	t.Run("uses default timeout for zero or negative timeout", func(t *testing.T) {
280		permission.Default = newMockPermissionService(true)
281		tool := NewBashTool()
282		params := BashParams{
283			Command: "echo 'test'",
284			Timeout: -100, // Negative timeout
285		}
286
287		paramsJSON, err := json.Marshal(params)
288		require.NoError(t, err)
289
290		call := ToolCall{
291			Name:  BashToolName,
292			Input: string(paramsJSON),
293		}
294
295		response, err := tool.Run(context.Background(), call)
296		require.NoError(t, err)
297		assert.Equal(t, "test\n", response.Content)
298	})
299}
300
301func TestTruncateOutput(t *testing.T) {
302	t.Run("does not truncate short output", func(t *testing.T) {
303		output := "short output"
304		result := truncateOutput(output)
305		assert.Equal(t, output, result)
306	})
307
308	t.Run("truncates long output", func(t *testing.T) {
309		// Create a string longer than MaxOutputLength
310		longOutput := strings.Repeat("a\n", MaxOutputLength)
311		result := truncateOutput(longOutput)
312
313		// Check that the result is shorter than the original
314		assert.Less(t, len(result), len(longOutput))
315
316		// Check that the truncation message is included
317		assert.Contains(t, result, "lines truncated")
318
319		// Check that we have the beginning and end of the original string
320		assert.True(t, strings.HasPrefix(result, "a\n"))
321		assert.True(t, strings.HasSuffix(result, "a\n"))
322	})
323}
324
325func TestCountLines(t *testing.T) {
326	testCases := []struct {
327		name     string
328		input    string
329		expected int
330	}{
331		{
332			name:     "empty string",
333			input:    "",
334			expected: 0,
335		},
336		{
337			name:     "single line",
338			input:    "line1",
339			expected: 1,
340		},
341		{
342			name:     "multiple lines",
343			input:    "line1\nline2\nline3",
344			expected: 3,
345		},
346		{
347			name:     "trailing newline",
348			input:    "line1\nline2\n",
349			expected: 3, // Empty string after last newline counts as a line
350		},
351	}
352
353	for _, tc := range testCases {
354		t.Run(tc.name, func(t *testing.T) {
355			result := countLines(tc.input)
356			assert.Equal(t, tc.expected, result)
357		})
358	}
359}
360
361// Mock permission service for testing
362type mockPermissionService struct {
363	*pubsub.Broker[permission.PermissionRequest]
364	allow bool
365}
366
367func (m *mockPermissionService) GrantPersistant(permission permission.PermissionRequest) {
368	// Not needed for tests
369}
370
371func (m *mockPermissionService) Grant(permission permission.PermissionRequest) {
372	// Not needed for tests
373}
374
375func (m *mockPermissionService) Deny(permission permission.PermissionRequest) {
376	// Not needed for tests
377}
378
379func (m *mockPermissionService) Request(opts permission.CreatePermissionRequest) bool {
380	return m.allow
381}
382
383func newMockPermissionService(allow bool) permission.Service {
384	return &mockPermissionService{
385		Broker: pubsub.NewBroker[permission.PermissionRequest](),
386		allow:  allow,
387	}
388}
389