bash_test.go

  1package tools
  2
  3import (
  4	"context"
  5	"encoding/json"
  6	"os"
  7	"strings"
  8	"testing"
  9	"time"
 10
 11	"github.com/kujtimiihoxha/termai/internal/permission"
 12	"github.com/kujtimiihoxha/termai/internal/pubsub"
 13	"github.com/stretchr/testify/assert"
 14	"github.com/stretchr/testify/require"
 15)
 16
 17func TestBashTool_Info(t *testing.T) {
 18	tool := NewBashTool(newMockPermissionService(true))
 19	info := tool.Info()
 20
 21	assert.Equal(t, BashToolName, info.Name)
 22	assert.NotEmpty(t, info.Description)
 23	assert.Contains(t, info.Parameters, "command")
 24	assert.Contains(t, info.Parameters, "timeout")
 25	assert.Contains(t, info.Required, "command")
 26}
 27
 28func TestBashTool_Run(t *testing.T) {
 29	// Save original working directory
 30	origWd, err := os.Getwd()
 31	require.NoError(t, err)
 32	defer func() {
 33		os.Chdir(origWd)
 34	}()
 35
 36	t.Run("executes command successfully", func(t *testing.T) {
 37		tool := NewBashTool(newMockPermissionService(true))
 38		params := BashParams{
 39			Command: "echo 'Hello World'",
 40		}
 41
 42		paramsJSON, err := json.Marshal(params)
 43		require.NoError(t, err)
 44
 45		call := ToolCall{
 46			Name:  BashToolName,
 47			Input: string(paramsJSON),
 48		}
 49
 50		response, err := tool.Run(context.Background(), call)
 51		require.NoError(t, err)
 52		assert.Equal(t, "Hello World\n", response.Content)
 53	})
 54
 55	t.Run("handles invalid parameters", func(t *testing.T) {
 56		tool := NewBashTool(newMockPermissionService(true))
 57		call := ToolCall{
 58			Name:  BashToolName,
 59			Input: "invalid json",
 60		}
 61
 62		response, err := tool.Run(context.Background(), call)
 63		require.NoError(t, err)
 64		assert.Contains(t, response.Content, "invalid parameters")
 65	})
 66
 67	t.Run("handles missing command", func(t *testing.T) {
 68		tool := NewBashTool(newMockPermissionService(true))
 69		params := BashParams{
 70			Command: "",
 71		}
 72
 73		paramsJSON, err := json.Marshal(params)
 74		require.NoError(t, err)
 75
 76		call := ToolCall{
 77			Name:  BashToolName,
 78			Input: string(paramsJSON),
 79		}
 80
 81		response, err := tool.Run(context.Background(), call)
 82		require.NoError(t, err)
 83		assert.Contains(t, response.Content, "missing command")
 84	})
 85
 86	t.Run("handles banned commands", func(t *testing.T) {
 87		tool := NewBashTool(newMockPermissionService(true))
 88
 89		for _, bannedCmd := range bannedCommands {
 90			params := BashParams{
 91				Command: bannedCmd + " arg1 arg2",
 92			}
 93
 94			paramsJSON, err := json.Marshal(params)
 95			require.NoError(t, err)
 96
 97			call := ToolCall{
 98				Name:  BashToolName,
 99				Input: string(paramsJSON),
100			}
101
102			response, err := tool.Run(context.Background(), call)
103			require.NoError(t, err)
104			assert.Contains(t, response.Content, "not allowed", "Command %s should be blocked", bannedCmd)
105		}
106	})
107
108	t.Run("handles multi-word safe commands without permission check", func(t *testing.T) {
109		tool := NewBashTool(newMockPermissionService(false))
110
111		// Test with multi-word safe commands
112		multiWordCommands := []string{
113			"go env",
114		}
115
116		for _, cmd := range multiWordCommands {
117			params := BashParams{
118				Command: cmd,
119			}
120
121			paramsJSON, err := json.Marshal(params)
122			require.NoError(t, err)
123
124			call := ToolCall{
125				Name:  BashToolName,
126				Input: string(paramsJSON),
127			}
128
129			response, err := tool.Run(context.Background(), call)
130			require.NoError(t, err)
131			assert.NotContains(t, response.Content, "permission denied",
132				"Command %s should be allowed without permission", cmd)
133		}
134	})
135
136	t.Run("handles permission denied", func(t *testing.T) {
137		tool := NewBashTool(newMockPermissionService(false))
138
139		// Test with a command that requires permission
140		params := BashParams{
141			Command: "mkdir test_dir",
142		}
143
144		paramsJSON, err := json.Marshal(params)
145		require.NoError(t, err)
146
147		call := ToolCall{
148			Name:  BashToolName,
149			Input: string(paramsJSON),
150		}
151
152		response, err := tool.Run(context.Background(), call)
153		require.NoError(t, err)
154		assert.Contains(t, response.Content, "permission denied")
155	})
156
157	t.Run("handles command timeout", func(t *testing.T) {
158		tool := NewBashTool(newMockPermissionService(true))
159		params := BashParams{
160			Command: "sleep 2",
161			Timeout: 100, // 100ms timeout
162		}
163
164		paramsJSON, err := json.Marshal(params)
165		require.NoError(t, err)
166
167		call := ToolCall{
168			Name:  BashToolName,
169			Input: string(paramsJSON),
170		}
171
172		response, err := tool.Run(context.Background(), call)
173		require.NoError(t, err)
174		assert.Contains(t, response.Content, "aborted")
175	})
176
177	t.Run("handles command with stderr output", func(t *testing.T) {
178		tool := NewBashTool(newMockPermissionService(true))
179		params := BashParams{
180			Command: "echo 'error message' >&2",
181		}
182
183		paramsJSON, err := json.Marshal(params)
184		require.NoError(t, err)
185
186		call := ToolCall{
187			Name:  BashToolName,
188			Input: string(paramsJSON),
189		}
190
191		response, err := tool.Run(context.Background(), call)
192		require.NoError(t, err)
193		assert.Contains(t, response.Content, "error message")
194	})
195
196	t.Run("handles command with both stdout and stderr", func(t *testing.T) {
197		tool := NewBashTool(newMockPermissionService(true))
198		params := BashParams{
199			Command: "echo 'stdout message' && echo 'stderr message' >&2",
200		}
201
202		paramsJSON, err := json.Marshal(params)
203		require.NoError(t, err)
204
205		call := ToolCall{
206			Name:  BashToolName,
207			Input: string(paramsJSON),
208		}
209
210		response, err := tool.Run(context.Background(), call)
211		require.NoError(t, err)
212		assert.Contains(t, response.Content, "stdout message")
213		assert.Contains(t, response.Content, "stderr message")
214	})
215
216	t.Run("handles context cancellation", func(t *testing.T) {
217		tool := NewBashTool(newMockPermissionService(true))
218		params := BashParams{
219			Command: "sleep 5",
220		}
221
222		paramsJSON, err := json.Marshal(params)
223		require.NoError(t, err)
224
225		call := ToolCall{
226			Name:  BashToolName,
227			Input: string(paramsJSON),
228		}
229
230		ctx, cancel := context.WithCancel(context.Background())
231
232		// Cancel the context after a short delay
233		go func() {
234			time.Sleep(100 * time.Millisecond)
235			cancel()
236		}()
237
238		response, err := tool.Run(ctx, call)
239		require.NoError(t, err)
240		assert.Contains(t, response.Content, "aborted")
241	})
242
243	t.Run("respects max timeout", func(t *testing.T) {
244		tool := NewBashTool(newMockPermissionService(true))
245		params := BashParams{
246			Command: "echo 'test'",
247			Timeout: MaxTimeout + 1000, // Exceeds max timeout
248		}
249
250		paramsJSON, err := json.Marshal(params)
251		require.NoError(t, err)
252
253		call := ToolCall{
254			Name:  BashToolName,
255			Input: string(paramsJSON),
256		}
257
258		response, err := tool.Run(context.Background(), call)
259		require.NoError(t, err)
260		assert.Equal(t, "test\n", response.Content)
261	})
262
263	t.Run("uses default timeout for zero or negative timeout", func(t *testing.T) {
264		tool := NewBashTool(newMockPermissionService(true))
265		params := BashParams{
266			Command: "echo 'test'",
267			Timeout: -100, // Negative timeout
268		}
269
270		paramsJSON, err := json.Marshal(params)
271		require.NoError(t, err)
272
273		call := ToolCall{
274			Name:  BashToolName,
275			Input: string(paramsJSON),
276		}
277
278		response, err := tool.Run(context.Background(), call)
279		require.NoError(t, err)
280		assert.Equal(t, "test\n", response.Content)
281	})
282}
283
284func TestTruncateOutput(t *testing.T) {
285	t.Run("does not truncate short output", func(t *testing.T) {
286		output := "short output"
287		result := truncateOutput(output)
288		assert.Equal(t, output, result)
289	})
290
291	t.Run("truncates long output", func(t *testing.T) {
292		// Create a string longer than MaxOutputLength
293		longOutput := strings.Repeat("a\n", MaxOutputLength)
294		result := truncateOutput(longOutput)
295
296		// Check that the result is shorter than the original
297		assert.Less(t, len(result), len(longOutput))
298
299		// Check that the truncation message is included
300		assert.Contains(t, result, "lines truncated")
301
302		// Check that we have the beginning and end of the original string
303		assert.True(t, strings.HasPrefix(result, "a\n"))
304		assert.True(t, strings.HasSuffix(result, "a\n"))
305	})
306}
307
308func TestCountLines(t *testing.T) {
309	testCases := []struct {
310		name     string
311		input    string
312		expected int
313	}{
314		{
315			name:     "empty string",
316			input:    "",
317			expected: 0,
318		},
319		{
320			name:     "single line",
321			input:    "line1",
322			expected: 1,
323		},
324		{
325			name:     "multiple lines",
326			input:    "line1\nline2\nline3",
327			expected: 3,
328		},
329		{
330			name:     "trailing newline",
331			input:    "line1\nline2\n",
332			expected: 3, // Empty string after last newline counts as a line
333		},
334	}
335
336	for _, tc := range testCases {
337		t.Run(tc.name, func(t *testing.T) {
338			result := countLines(tc.input)
339			assert.Equal(t, tc.expected, result)
340		})
341	}
342}
343
344// Mock permission service for testing
345type mockPermissionService struct {
346	*pubsub.Broker[permission.PermissionRequest]
347	allow bool
348}
349
350func (m *mockPermissionService) GrantPersistant(permission permission.PermissionRequest) {
351	// Not needed for tests
352}
353
354func (m *mockPermissionService) Grant(permission permission.PermissionRequest) {
355	// Not needed for tests
356}
357
358func (m *mockPermissionService) Deny(permission permission.PermissionRequest) {
359	// Not needed for tests
360}
361
362func (m *mockPermissionService) Request(opts permission.CreatePermissionRequest) bool {
363	return m.allow
364}
365
366func newMockPermissionService(allow bool) permission.Service {
367	return &mockPermissionService{
368		Broker: pubsub.NewBroker[permission.PermissionRequest](),
369		allow:  allow,
370	}
371}