hooks_test.go

  1package hooks
  2
  3import (
  4	"context"
  5	"os"
  6	"path/filepath"
  7	"strings"
  8	"testing"
  9	"time"
 10
 11	"github.com/charmbracelet/crush/internal/config"
 12	"github.com/stretchr/testify/require"
 13)
 14
 15func TestHookExecutor_Execute(t *testing.T) {
 16	t.Parallel()
 17
 18	tempDir := t.TempDir()
 19
 20	tests := []struct {
 21		name    string
 22		config  config.HookConfig
 23		hookCtx HookContext
 24		wantErr bool
 25	}{
 26		{
 27			name: "simple command hook",
 28			config: config.HookConfig{
 29				config.PreToolUse: []config.HookMatcher{
 30					{
 31						Matcher: "bash",
 32						Hooks: []config.Hook{
 33							{
 34								Type:    "command",
 35								Command: "echo 'hook executed'",
 36							},
 37						},
 38					},
 39				},
 40			},
 41			hookCtx: HookContext{
 42				EventType: config.PreToolUse,
 43				ToolName:  "bash",
 44			},
 45		},
 46		{
 47			name: "hook with jq processing",
 48			config: config.HookConfig{
 49				config.PreToolUse: []config.HookMatcher{
 50					{
 51						Matcher: "bash",
 52						Hooks: []config.Hook{
 53							{
 54								Type:    "command",
 55								Command: `jq -r '.tool_name'`,
 56							},
 57						},
 58					},
 59				},
 60			},
 61			hookCtx: HookContext{
 62				EventType: config.PreToolUse,
 63				ToolName:  "bash",
 64			},
 65		},
 66		{
 67			name: "hook that writes to file",
 68			config: config.HookConfig{
 69				config.PostToolUse: []config.HookMatcher{
 70					{
 71						Matcher: "*",
 72						Hooks: []config.Hook{
 73							{
 74								Type:    "command",
 75								Command: `jq -r '"\(.tool_name): \(.tool_result)"' >> ` + filepath.Join(tempDir, "hook-log.txt"),
 76							},
 77						},
 78					},
 79				},
 80			},
 81			hookCtx: HookContext{
 82				EventType:  config.PostToolUse,
 83				ToolName:   "edit",
 84				ToolResult: "file edited successfully",
 85			},
 86		},
 87		{
 88			name: "hook with timeout",
 89			config: config.HookConfig{
 90				config.Stop: []config.HookMatcher{
 91					{
 92						Hooks: []config.Hook{
 93							{
 94								Type:    "command",
 95								Command: "sleep 0.1 && echo 'done'",
 96								Timeout: ptrInt(1),
 97							},
 98						},
 99					},
100				},
101			},
102			hookCtx: HookContext{
103				EventType: config.Stop,
104			},
105		},
106		{
107			name: "failed hook command",
108			config: config.HookConfig{
109				config.PreToolUse: []config.HookMatcher{
110					{
111						Matcher: "bash",
112						Hooks: []config.Hook{
113							{
114								Type:    "command",
115								Command: "exit 1",
116							},
117						},
118					},
119				},
120			},
121			hookCtx: HookContext{
122				EventType: config.PreToolUse,
123				ToolName:  "bash",
124			},
125			wantErr: true,
126		},
127		{
128			name: "hook with single quote in JSON",
129			config: config.HookConfig{
130				config.PostToolUse: []config.HookMatcher{
131					{
132						Matcher: "edit",
133						Hooks: []config.Hook{
134							{
135								Type:    "command",
136								Command: `jq -r '.tool_result'`,
137							},
138						},
139					},
140				},
141			},
142			hookCtx: HookContext{
143				EventType:  config.PostToolUse,
144				ToolName:   "edit",
145				ToolResult: "it's a test with 'quotes'",
146			},
147		},
148	}
149
150	for _, tt := range tests {
151		t.Run(tt.name, func(t *testing.T) {
152			t.Parallel()
153
154			executor := NewExecutor(tt.config, tempDir)
155			require.NotNil(t, executor)
156
157			ctx := context.Background()
158			err := executor.Execute(ctx, tt.hookCtx)
159
160			if tt.wantErr {
161				require.Error(t, err)
162			} else {
163				require.NoError(t, err)
164			}
165		})
166	}
167}
168
169func TestHookExecutor_MatcherApplies(t *testing.T) {
170	t.Parallel()
171
172	tempDir := t.TempDir()
173	executor := NewExecutor(config.HookConfig{}, tempDir)
174
175	tests := []struct {
176		name    string
177		matcher config.HookMatcher
178		ctx     HookContext
179		want    bool
180	}{
181		{
182			name: "empty matcher matches all",
183			matcher: config.HookMatcher{
184				Matcher: "",
185			},
186			ctx: HookContext{
187				EventType: config.PreToolUse,
188				ToolName:  "bash",
189			},
190			want: true,
191		},
192		{
193			name: "wildcard matcher matches all",
194			matcher: config.HookMatcher{
195				Matcher: "*",
196			},
197			ctx: HookContext{
198				EventType: config.PreToolUse,
199				ToolName:  "edit",
200			},
201			want: true,
202		},
203		{
204			name: "specific tool matcher matches",
205			matcher: config.HookMatcher{
206				Matcher: "bash",
207			},
208			ctx: HookContext{
209				EventType: config.PreToolUse,
210				ToolName:  "bash",
211			},
212			want: true,
213		},
214		{
215			name: "specific tool matcher doesn't match different tool",
216			matcher: config.HookMatcher{
217				Matcher: "bash",
218			},
219			ctx: HookContext{
220				EventType: config.PreToolUse,
221				ToolName:  "edit",
222			},
223			want: false,
224		},
225		{
226			name: "non-tool event matches empty matcher",
227			matcher: config.HookMatcher{
228				Matcher: "",
229			},
230			ctx: HookContext{
231				EventType: config.Stop,
232			},
233			want: true,
234		},
235	}
236
237	for _, tt := range tests {
238		t.Run(tt.name, func(t *testing.T) {
239			t.Parallel()
240
241			got := executor.matcherApplies(tt.matcher, tt.ctx)
242			require.Equal(t, tt.want, got)
243		})
244	}
245}
246
247func TestHookExecutor_Timeout(t *testing.T) {
248	t.Parallel()
249
250	tempDir := t.TempDir()
251	shortTimeout := 1
252
253	hookConfig := config.HookConfig{
254		config.Stop: []config.HookMatcher{
255			{
256				Hooks: []config.Hook{
257					{
258						Type:    "command",
259						Command: "sleep 10",
260						Timeout: &shortTimeout,
261					},
262				},
263			},
264		},
265	}
266
267	executor := NewExecutor(hookConfig, tempDir)
268	ctx := context.Background()
269
270	start := time.Now()
271	err := executor.Execute(ctx, HookContext{
272		EventType: config.Stop,
273	})
274	duration := time.Since(start)
275
276	require.Error(t, err)
277	require.Less(t, duration, 2*time.Second)
278}
279
280func TestHookExecutor_MultipleHooks(t *testing.T) {
281	t.Parallel()
282
283	tempDir := t.TempDir()
284	logFile := filepath.Join(tempDir, "multi-hook-log.txt")
285
286	hookConfig := config.HookConfig{
287		config.PreToolUse: []config.HookMatcher{
288			{
289				Matcher: "bash",
290				Hooks: []config.Hook{
291					{
292						Type:    "command",
293						Command: "echo 'hook1' >> " + logFile,
294					},
295					{
296						Type:    "command",
297						Command: "echo 'hook2' >> " + logFile,
298					},
299					{
300						Type:    "command",
301						Command: "echo 'hook3' >> " + logFile,
302					},
303				},
304			},
305		},
306	}
307
308	executor := NewExecutor(hookConfig, tempDir)
309	ctx := context.Background()
310
311	err := executor.Execute(ctx, HookContext{
312		EventType: config.PreToolUse,
313		ToolName:  "bash",
314	})
315
316	require.NoError(t, err)
317
318	content, err := os.ReadFile(logFile)
319	require.NoError(t, err)
320
321	lines := strings.Split(strings.TrimSpace(string(content)), "\n")
322	require.Len(t, lines, 3)
323	require.Equal(t, "hook1", lines[0])
324	require.Equal(t, "hook2", lines[1])
325	require.Equal(t, "hook3", lines[2])
326}
327
328func TestHookExecutor_ContextCancellation(t *testing.T) {
329	t.Parallel()
330
331	tempDir := t.TempDir()
332	logFile := filepath.Join(tempDir, "cancel-log.txt")
333
334	hookConfig := config.HookConfig{
335		config.PreToolUse: []config.HookMatcher{
336			{
337				Matcher: "bash",
338				Hooks: []config.Hook{
339					{
340						Type:    "command",
341						Command: "echo 'hook1' >> " + logFile,
342					},
343					{
344						Type:    "command",
345						Command: "sleep 10 && echo 'hook2' >> " + logFile,
346					},
347				},
348			},
349		},
350	}
351
352	executor := NewExecutor(hookConfig, tempDir)
353	ctx, cancel := context.WithCancel(context.Background())
354
355	go func() {
356		time.Sleep(100 * time.Millisecond)
357		cancel()
358	}()
359
360	err := executor.Execute(ctx, HookContext{
361		EventType: config.PreToolUse,
362		ToolName:  "bash",
363	})
364
365	require.Error(t, err)
366	require.ErrorIs(t, err, context.Canceled)
367}
368
// ptrInt returns a pointer to a freshly allocated int holding i. Table-driven
// tests use it to fill optional *int fields such as config.Hook.Timeout
// inline. Every call yields a distinct pointer.
func ptrInt(i int) *int {
	p := new(int)
	*p = i
	return p
}