bash_test.go

  1package tools
  2
  3import (
  4	"context"
  5	"encoding/json"
  6	"os"
  7	"strings"
  8	"testing"
  9	"time"
 10
 11	"github.com/kujtimiihoxha/termai/internal/permission"
 12	"github.com/kujtimiihoxha/termai/internal/pubsub"
 13	"github.com/stretchr/testify/assert"
 14	"github.com/stretchr/testify/require"
 15)
 16
 17func TestBashTool_Info(t *testing.T) {
 18	tool := NewBashTool()
 19	info := tool.Info()
 20
 21	assert.Equal(t, BashToolName, info.Name)
 22	assert.NotEmpty(t, info.Description)
 23	assert.Contains(t, info.Parameters, "command")
 24	assert.Contains(t, info.Parameters, "timeout")
 25	assert.Contains(t, info.Required, "command")
 26}
 27
 28func TestBashTool_Run(t *testing.T) {
 29	// Setup a mock permission handler that always allows
 30	origPermission := permission.Default
 31	defer func() {
 32		permission.Default = origPermission
 33	}()
 34	permission.Default = newMockPermissionService(true)
 35
 36	// Save original working directory
 37	origWd, err := os.Getwd()
 38	require.NoError(t, err)
 39	defer func() {
 40		os.Chdir(origWd)
 41	}()
 42
 43	t.Run("executes command successfully", func(t *testing.T) {
 44		permission.Default = newMockPermissionService(true)
 45		tool := NewBashTool()
 46		params := BashParams{
 47			Command: "echo 'Hello World'",
 48		}
 49
 50		paramsJSON, err := json.Marshal(params)
 51		require.NoError(t, err)
 52
 53		call := ToolCall{
 54			Name:  BashToolName,
 55			Input: string(paramsJSON),
 56		}
 57
 58		response, err := tool.Run(context.Background(), call)
 59		require.NoError(t, err)
 60		assert.Equal(t, "Hello World\n", response.Content)
 61	})
 62
 63	t.Run("handles invalid parameters", func(t *testing.T) {
 64		permission.Default = newMockPermissionService(true)
 65
 66		tool := NewBashTool()
 67		call := ToolCall{
 68			Name:  BashToolName,
 69			Input: "invalid json",
 70		}
 71
 72		response, err := tool.Run(context.Background(), call)
 73		require.NoError(t, err)
 74		assert.Contains(t, response.Content, "invalid parameters")
 75	})
 76
 77	t.Run("handles missing command", func(t *testing.T) {
 78		permission.Default = newMockPermissionService(true)
 79
 80		tool := NewBashTool()
 81		params := BashParams{
 82			Command: "",
 83		}
 84
 85		paramsJSON, err := json.Marshal(params)
 86		require.NoError(t, err)
 87
 88		call := ToolCall{
 89			Name:  BashToolName,
 90			Input: string(paramsJSON),
 91		}
 92
 93		response, err := tool.Run(context.Background(), call)
 94		require.NoError(t, err)
 95		assert.Contains(t, response.Content, "missing command")
 96	})
 97
 98	t.Run("handles banned commands", func(t *testing.T) {
 99		permission.Default = newMockPermissionService(true)
100
101		tool := NewBashTool()
102
103		for _, bannedCmd := range BannedCommands {
104			params := BashParams{
105				Command: bannedCmd + " arg1 arg2",
106			}
107
108			paramsJSON, err := json.Marshal(params)
109			require.NoError(t, err)
110
111			call := ToolCall{
112				Name:  BashToolName,
113				Input: string(paramsJSON),
114			}
115
116			response, err := tool.Run(context.Background(), call)
117			require.NoError(t, err)
118			assert.Contains(t, response.Content, "not allowed", "Command %s should be blocked", bannedCmd)
119		}
120	})
121
122	t.Run("handles multi-word safe commands without permission check", func(t *testing.T) {
123		permission.Default = newMockPermissionService(false)
124
125		tool := NewBashTool()
126
127		// Test with multi-word safe commands
128		multiWordCommands := []string{
129			"git status",
130			"git log -n 5",
131			"docker ps",
132			"go test ./...",
133			"kubectl get pods",
134		}
135
136		for _, cmd := range multiWordCommands {
137			params := BashParams{
138				Command: cmd,
139			}
140
141			paramsJSON, err := json.Marshal(params)
142			require.NoError(t, err)
143
144			call := ToolCall{
145				Name:  BashToolName,
146				Input: string(paramsJSON),
147			}
148
149			response, err := tool.Run(context.Background(), call)
150			require.NoError(t, err)
151			assert.NotContains(t, response.Content, "permission denied", 
152				"Command %s should be allowed without permission", cmd)
153		}
154	})
155
156	t.Run("handles permission denied", func(t *testing.T) {
157		permission.Default = newMockPermissionService(false)
158
159		tool := NewBashTool()
160
161		// Test with a command that requires permission
162		params := BashParams{
163			Command: "mkdir test_dir",
164		}
165
166		paramsJSON, err := json.Marshal(params)
167		require.NoError(t, err)
168
169		call := ToolCall{
170			Name:  BashToolName,
171			Input: string(paramsJSON),
172		}
173
174		response, err := tool.Run(context.Background(), call)
175		require.NoError(t, err)
176		assert.Contains(t, response.Content, "permission denied")
177	})
178
179	t.Run("handles command timeout", func(t *testing.T) {
180		permission.Default = newMockPermissionService(true)
181		tool := NewBashTool()
182		params := BashParams{
183			Command: "sleep 2",
184			Timeout: 100, // 100ms timeout
185		}
186
187		paramsJSON, err := json.Marshal(params)
188		require.NoError(t, err)
189
190		call := ToolCall{
191			Name:  BashToolName,
192			Input: string(paramsJSON),
193		}
194
195		response, err := tool.Run(context.Background(), call)
196		require.NoError(t, err)
197		assert.Contains(t, response.Content, "aborted")
198	})
199
200	t.Run("handles command with stderr output", func(t *testing.T) {
201		permission.Default = newMockPermissionService(true)
202		tool := NewBashTool()
203		params := BashParams{
204			Command: "echo 'error message' >&2",
205		}
206
207		paramsJSON, err := json.Marshal(params)
208		require.NoError(t, err)
209
210		call := ToolCall{
211			Name:  BashToolName,
212			Input: string(paramsJSON),
213		}
214
215		response, err := tool.Run(context.Background(), call)
216		require.NoError(t, err)
217		assert.Contains(t, response.Content, "error message")
218	})
219
220	t.Run("handles command with both stdout and stderr", func(t *testing.T) {
221		permission.Default = newMockPermissionService(true)
222		tool := NewBashTool()
223		params := BashParams{
224			Command: "echo 'stdout message' && echo 'stderr message' >&2",
225		}
226
227		paramsJSON, err := json.Marshal(params)
228		require.NoError(t, err)
229
230		call := ToolCall{
231			Name:  BashToolName,
232			Input: string(paramsJSON),
233		}
234
235		response, err := tool.Run(context.Background(), call)
236		require.NoError(t, err)
237		assert.Contains(t, response.Content, "stdout message")
238		assert.Contains(t, response.Content, "stderr message")
239	})
240
241	t.Run("handles context cancellation", func(t *testing.T) {
242		permission.Default = newMockPermissionService(true)
243		tool := NewBashTool()
244		params := BashParams{
245			Command: "sleep 5",
246		}
247
248		paramsJSON, err := json.Marshal(params)
249		require.NoError(t, err)
250
251		call := ToolCall{
252			Name:  BashToolName,
253			Input: string(paramsJSON),
254		}
255
256		ctx, cancel := context.WithCancel(context.Background())
257
258		// Cancel the context after a short delay
259		go func() {
260			time.Sleep(100 * time.Millisecond)
261			cancel()
262		}()
263
264		response, err := tool.Run(ctx, call)
265		require.NoError(t, err)
266		assert.Contains(t, response.Content, "aborted")
267	})
268
269	t.Run("respects max timeout", func(t *testing.T) {
270		permission.Default = newMockPermissionService(true)
271		tool := NewBashTool()
272		params := BashParams{
273			Command: "echo 'test'",
274			Timeout: MaxTimeout + 1000, // Exceeds max timeout
275		}
276
277		paramsJSON, err := json.Marshal(params)
278		require.NoError(t, err)
279
280		call := ToolCall{
281			Name:  BashToolName,
282			Input: string(paramsJSON),
283		}
284
285		response, err := tool.Run(context.Background(), call)
286		require.NoError(t, err)
287		assert.Equal(t, "test\n", response.Content)
288	})
289
290	t.Run("uses default timeout for zero or negative timeout", func(t *testing.T) {
291		permission.Default = newMockPermissionService(true)
292		tool := NewBashTool()
293		params := BashParams{
294			Command: "echo 'test'",
295			Timeout: -100, // Negative timeout
296		}
297
298		paramsJSON, err := json.Marshal(params)
299		require.NoError(t, err)
300
301		call := ToolCall{
302			Name:  BashToolName,
303			Input: string(paramsJSON),
304		}
305
306		response, err := tool.Run(context.Background(), call)
307		require.NoError(t, err)
308		assert.Equal(t, "test\n", response.Content)
309	})
310}
311
312func TestTruncateOutput(t *testing.T) {
313	t.Run("does not truncate short output", func(t *testing.T) {
314		output := "short output"
315		result := truncateOutput(output)
316		assert.Equal(t, output, result)
317	})
318
319	t.Run("truncates long output", func(t *testing.T) {
320		// Create a string longer than MaxOutputLength
321		longOutput := strings.Repeat("a\n", MaxOutputLength)
322		result := truncateOutput(longOutput)
323
324		// Check that the result is shorter than the original
325		assert.Less(t, len(result), len(longOutput))
326
327		// Check that the truncation message is included
328		assert.Contains(t, result, "lines truncated")
329
330		// Check that we have the beginning and end of the original string
331		assert.True(t, strings.HasPrefix(result, "a\n"))
332		assert.True(t, strings.HasSuffix(result, "a\n"))
333	})
334}
335
336func TestCountLines(t *testing.T) {
337	testCases := []struct {
338		name     string
339		input    string
340		expected int
341	}{
342		{
343			name:     "empty string",
344			input:    "",
345			expected: 0,
346		},
347		{
348			name:     "single line",
349			input:    "line1",
350			expected: 1,
351		},
352		{
353			name:     "multiple lines",
354			input:    "line1\nline2\nline3",
355			expected: 3,
356		},
357		{
358			name:     "trailing newline",
359			input:    "line1\nline2\n",
360			expected: 3, // Empty string after last newline counts as a line
361		},
362	}
363
364	for _, tc := range testCases {
365		t.Run(tc.name, func(t *testing.T) {
366			result := countLines(tc.input)
367			assert.Equal(t, tc.expected, result)
368		})
369	}
370}
371
372// Mock permission service for testing
373type mockPermissionService struct {
374	*pubsub.Broker[permission.PermissionRequest]
375	allow bool
376}
377
378func (m *mockPermissionService) GrantPersistant(permission permission.PermissionRequest) {
379	// Not needed for tests
380}
381
382func (m *mockPermissionService) Grant(permission permission.PermissionRequest) {
383	// Not needed for tests
384}
385
386func (m *mockPermissionService) Deny(permission permission.PermissionRequest) {
387	// Not needed for tests
388}
389
390func (m *mockPermissionService) Request(opts permission.CreatePermissionRequest) bool {
391	return m.allow
392}
393
394func newMockPermissionService(allow bool) permission.Service {
395	return &mockPermissionService{
396		Broker: pubsub.NewBroker[permission.PermissionRequest](),
397		allow:  allow,
398	}
399}
400