1package agent
2
3import (
4 "context"
5 "errors"
6 "testing"
7
8 "charm.land/catwalk/pkg/catwalk"
9 "charm.land/fantasy"
10 "charm.land/fantasy/providers/anthropic"
11 "charm.land/fantasy/providers/bedrock"
12 "github.com/charmbracelet/crush/internal/config"
13 "github.com/stretchr/testify/assert"
14 "github.com/stretchr/testify/require"
15)
16
17// mockSessionAgent is a minimal mock for the SessionAgent interface.
18type mockSessionAgent struct {
19 model Model
20 runFunc func(ctx context.Context, call SessionAgentCall) (*fantasy.AgentResult, error)
21 cancelled []string
22}
23
24func (m *mockSessionAgent) Run(ctx context.Context, call SessionAgentCall) (*fantasy.AgentResult, error) {
25 return m.runFunc(ctx, call)
26}
27
28func (m *mockSessionAgent) BeginAccepted(sessionID string) *AcceptedRun {
29 return &AcceptedRun{sessionID: sessionID}
30}
31
32func (m *mockSessionAgent) Model() Model { return m.model }
33func (m *mockSessionAgent) SetModels(large, small Model) {}
34func (m *mockSessionAgent) SetTools(tools []fantasy.AgentTool) {}
35func (m *mockSessionAgent) SetSystemPrompt(systemPrompt string) {}
36func (m *mockSessionAgent) Cancel(sessionID string) {
37 m.cancelled = append(m.cancelled, sessionID)
38}
39func (m *mockSessionAgent) CancelAll() {}
40func (m *mockSessionAgent) IsSessionBusy(sessionID string) bool { return false }
41func (m *mockSessionAgent) IsBusy() bool { return false }
42func (m *mockSessionAgent) QueuedPrompts(sessionID string) int { return 0 }
43func (m *mockSessionAgent) QueuedPromptsList(sessionID string) []string { return nil }
44func (m *mockSessionAgent) ClearQueue(sessionID string) {}
45func (m *mockSessionAgent) Summarize(context.Context, string, fantasy.ProviderOptions) error {
46 return nil
47}
48
49// newTestCoordinator creates a minimal coordinator for unit testing runSubAgent.
50func newTestCoordinator(t *testing.T, env fakeEnv, providerID string, providerCfg config.ProviderConfig) *coordinator {
51 cfg, err := config.Init(env.workingDir, "", false)
52 require.NoError(t, err)
53 cfg.Config().Providers.Set(providerID, providerCfg)
54 return &coordinator{
55 cfg: cfg,
56 sessions: env.sessions,
57 }
58}
59
60// newMockAgent creates a mockSessionAgent with the given provider and run function.
61func newMockAgent(providerID string, maxTokens int64, runFunc func(context.Context, SessionAgentCall) (*fantasy.AgentResult, error)) *mockSessionAgent {
62 return &mockSessionAgent{
63 model: Model{
64 CatwalkCfg: catwalk.Model{
65 DefaultMaxTokens: maxTokens,
66 },
67 ModelCfg: config.SelectedModel{
68 Provider: providerID,
69 },
70 },
71 runFunc: runFunc,
72 }
73}
74
75// agentResultWithText creates a minimal AgentResult with the given text response.
76func agentResultWithText(text string) *fantasy.AgentResult {
77 return &fantasy.AgentResult{
78 Response: fantasy.Response{
79 Content: fantasy.ResponseContent{
80 fantasy.TextContent{Text: text},
81 },
82 },
83 }
84}
85
86func TestRunSubAgent(t *testing.T) {
87 const providerID = "test-provider"
88 providerCfg := config.ProviderConfig{ID: providerID}
89
90 t.Run("happy path", func(t *testing.T) {
91 env := testEnv(t)
92 coord := newTestCoordinator(t, env, providerID, providerCfg)
93
94 parentSession, err := env.sessions.Create(t.Context(), "Parent")
95 require.NoError(t, err)
96
97 agent := newMockAgent(providerID, 4096, func(_ context.Context, call SessionAgentCall) (*fantasy.AgentResult, error) {
98 assert.Equal(t, "do something", call.Prompt)
99 assert.Equal(t, int64(4096), call.MaxOutputTokens)
100 return agentResultWithText("done"), nil
101 })
102
103 resp, err := coord.runSubAgent(t.Context(), subAgentParams{
104 Agent: agent,
105 SessionID: parentSession.ID,
106 AgentMessageID: "msg-1",
107 ToolCallID: "call-1",
108 Prompt: "do something",
109 SessionTitle: "Test Session",
110 })
111 require.NoError(t, err)
112 assert.Equal(t, "done", resp.Content)
113 assert.False(t, resp.IsError)
114 })
115
116 t.Run("ModelCfg.MaxTokens overrides default", func(t *testing.T) {
117 env := testEnv(t)
118 coord := newTestCoordinator(t, env, providerID, providerCfg)
119
120 parentSession, err := env.sessions.Create(t.Context(), "Parent")
121 require.NoError(t, err)
122
123 agent := &mockSessionAgent{
124 model: Model{
125 CatwalkCfg: catwalk.Model{
126 DefaultMaxTokens: 4096,
127 },
128 ModelCfg: config.SelectedModel{
129 Provider: providerID,
130 MaxTokens: 8192,
131 },
132 },
133 runFunc: func(_ context.Context, call SessionAgentCall) (*fantasy.AgentResult, error) {
134 assert.Equal(t, int64(8192), call.MaxOutputTokens)
135 return agentResultWithText("ok"), nil
136 },
137 }
138
139 resp, err := coord.runSubAgent(t.Context(), subAgentParams{
140 Agent: agent,
141 SessionID: parentSession.ID,
142 AgentMessageID: "msg-1",
143 ToolCallID: "call-1",
144 Prompt: "test",
145 SessionTitle: "Test",
146 })
147 require.NoError(t, err)
148 assert.Equal(t, "ok", resp.Content)
149 })
150
151 t.Run("session creation failure with canceled context", func(t *testing.T) {
152 env := testEnv(t)
153 coord := newTestCoordinator(t, env, providerID, providerCfg)
154
155 parentSession, err := env.sessions.Create(t.Context(), "Parent")
156 require.NoError(t, err)
157
158 agent := newMockAgent(providerID, 4096, nil)
159
160 // Use a canceled context to trigger CreateTaskSession failure.
161 ctx, cancel := context.WithCancel(t.Context())
162 cancel()
163
164 _, err = coord.runSubAgent(ctx, subAgentParams{
165 Agent: agent,
166 SessionID: parentSession.ID,
167 AgentMessageID: "msg-1",
168 ToolCallID: "call-1",
169 Prompt: "test",
170 SessionTitle: "Test",
171 })
172 require.Error(t, err)
173 })
174
175 t.Run("provider not configured", func(t *testing.T) {
176 env := testEnv(t)
177 coord := newTestCoordinator(t, env, providerID, providerCfg)
178
179 parentSession, err := env.sessions.Create(t.Context(), "Parent")
180 require.NoError(t, err)
181
182 // Agent references a provider that doesn't exist in config.
183 agent := newMockAgent("unknown-provider", 4096, nil)
184
185 _, err = coord.runSubAgent(t.Context(), subAgentParams{
186 Agent: agent,
187 SessionID: parentSession.ID,
188 AgentMessageID: "msg-1",
189 ToolCallID: "call-1",
190 Prompt: "test",
191 SessionTitle: "Test",
192 })
193 require.Error(t, err)
194 assert.Contains(t, err.Error(), "model provider not configured")
195 })
196
197 t.Run("agent run error returns error response", func(t *testing.T) {
198 env := testEnv(t)
199 coord := newTestCoordinator(t, env, providerID, providerCfg)
200
201 parentSession, err := env.sessions.Create(t.Context(), "Parent")
202 require.NoError(t, err)
203
204 agent := newMockAgent(providerID, 4096, func(_ context.Context, _ SessionAgentCall) (*fantasy.AgentResult, error) {
205 return nil, errors.New("provider request failed")
206 })
207
208 resp, err := coord.runSubAgent(t.Context(), subAgentParams{
209 Agent: agent,
210 SessionID: parentSession.ID,
211 AgentMessageID: "msg-1",
212 ToolCallID: "call-1",
213 Prompt: "test",
214 SessionTitle: "Test",
215 })
216 // runSubAgent returns (errorResponse, nil) when agent.Run fails — not a Go error.
217 require.NoError(t, err)
218 assert.True(t, resp.IsError)
219 assert.Equal(t, "Failed to generate response: provider request failed", resp.Content)
220 })
221
222 t.Run("session setup callback is invoked", func(t *testing.T) {
223 env := testEnv(t)
224 coord := newTestCoordinator(t, env, providerID, providerCfg)
225
226 parentSession, err := env.sessions.Create(t.Context(), "Parent")
227 require.NoError(t, err)
228
229 var setupCalledWith string
230 agent := newMockAgent(providerID, 4096, func(_ context.Context, _ SessionAgentCall) (*fantasy.AgentResult, error) {
231 return agentResultWithText("ok"), nil
232 })
233
234 _, err = coord.runSubAgent(t.Context(), subAgentParams{
235 Agent: agent,
236 SessionID: parentSession.ID,
237 AgentMessageID: "msg-1",
238 ToolCallID: "call-1",
239 Prompt: "test",
240 SessionTitle: "Test",
241 SessionSetup: func(sessionID string) {
242 setupCalledWith = sessionID
243 },
244 })
245 require.NoError(t, err)
246 assert.NotEmpty(t, setupCalledWith, "SessionSetup should have been called")
247 })
248
249 t.Run("cost propagation to parent session", func(t *testing.T) {
250 env := testEnv(t)
251 coord := newTestCoordinator(t, env, providerID, providerCfg)
252
253 parentSession, err := env.sessions.Create(t.Context(), "Parent")
254 require.NoError(t, err)
255
256 agent := newMockAgent(providerID, 4096, func(ctx context.Context, call SessionAgentCall) (*fantasy.AgentResult, error) {
257 // Simulate the agent incurring cost by updating the child session.
258 childSession, err := env.sessions.Get(ctx, call.SessionID)
259 if err != nil {
260 return nil, err
261 }
262 childSession.Cost = 0.05
263 _, err = env.sessions.Save(ctx, childSession)
264 if err != nil {
265 return nil, err
266 }
267 return agentResultWithText("ok"), nil
268 })
269
270 _, err = coord.runSubAgent(t.Context(), subAgentParams{
271 Agent: agent,
272 SessionID: parentSession.ID,
273 AgentMessageID: "msg-1",
274 ToolCallID: "call-1",
275 Prompt: "test",
276 SessionTitle: "Test",
277 })
278 require.NoError(t, err)
279
280 updated, err := env.sessions.Get(t.Context(), parentSession.ID)
281 require.NoError(t, err)
282 assert.InDelta(t, 0.05, updated.Cost, 1e-9)
283 })
284}
285
286func TestUpdateParentSessionCost(t *testing.T) {
287 t.Run("accumulates cost correctly", func(t *testing.T) {
288 env := testEnv(t)
289 cfg, err := config.Init(env.workingDir, "", false)
290 require.NoError(t, err)
291 coord := &coordinator{cfg: cfg, sessions: env.sessions}
292
293 parent, err := env.sessions.Create(t.Context(), "Parent")
294 require.NoError(t, err)
295
296 child, err := env.sessions.CreateTaskSession(t.Context(), "tool-1", parent.ID, "Child")
297 require.NoError(t, err)
298
299 // Set child cost.
300 child.Cost = 0.10
301 _, err = env.sessions.Save(t.Context(), child)
302 require.NoError(t, err)
303
304 err = coord.updateParentSessionCost(t.Context(), child.ID, parent.ID)
305 require.NoError(t, err)
306
307 updated, err := env.sessions.Get(t.Context(), parent.ID)
308 require.NoError(t, err)
309 assert.InDelta(t, 0.10, updated.Cost, 1e-9)
310 })
311
312 t.Run("accumulates multiple child costs", func(t *testing.T) {
313 env := testEnv(t)
314 cfg, err := config.Init(env.workingDir, "", false)
315 require.NoError(t, err)
316 coord := &coordinator{cfg: cfg, sessions: env.sessions}
317
318 parent, err := env.sessions.Create(t.Context(), "Parent")
319 require.NoError(t, err)
320
321 child1, err := env.sessions.CreateTaskSession(t.Context(), "tool-1", parent.ID, "Child1")
322 require.NoError(t, err)
323 child1.Cost = 0.05
324 _, err = env.sessions.Save(t.Context(), child1)
325 require.NoError(t, err)
326
327 child2, err := env.sessions.CreateTaskSession(t.Context(), "tool-2", parent.ID, "Child2")
328 require.NoError(t, err)
329 child2.Cost = 0.03
330 _, err = env.sessions.Save(t.Context(), child2)
331 require.NoError(t, err)
332
333 err = coord.updateParentSessionCost(t.Context(), child1.ID, parent.ID)
334 require.NoError(t, err)
335 err = coord.updateParentSessionCost(t.Context(), child2.ID, parent.ID)
336 require.NoError(t, err)
337
338 updated, err := env.sessions.Get(t.Context(), parent.ID)
339 require.NoError(t, err)
340 assert.InDelta(t, 0.08, updated.Cost, 1e-9)
341 })
342
343 t.Run("child session not found", func(t *testing.T) {
344 env := testEnv(t)
345 cfg, err := config.Init(env.workingDir, "", false)
346 require.NoError(t, err)
347 coord := &coordinator{cfg: cfg, sessions: env.sessions}
348
349 parent, err := env.sessions.Create(t.Context(), "Parent")
350 require.NoError(t, err)
351
352 err = coord.updateParentSessionCost(t.Context(), "non-existent", parent.ID)
353 require.Error(t, err)
354 assert.Contains(t, err.Error(), "get child session")
355 })
356
357 t.Run("parent session not found", func(t *testing.T) {
358 env := testEnv(t)
359 cfg, err := config.Init(env.workingDir, "", false)
360 require.NoError(t, err)
361 coord := &coordinator{cfg: cfg, sessions: env.sessions}
362
363 parent, err := env.sessions.Create(t.Context(), "Parent")
364 require.NoError(t, err)
365 child, err := env.sessions.CreateTaskSession(t.Context(), "tool-1", parent.ID, "Child")
366 require.NoError(t, err)
367
368 err = coord.updateParentSessionCost(t.Context(), child.ID, "non-existent")
369 require.Error(t, err)
370 assert.Contains(t, err.Error(), "get parent session")
371 })
372
373 t.Run("zero cost handled correctly", func(t *testing.T) {
374 env := testEnv(t)
375 cfg, err := config.Init(env.workingDir, "", false)
376 require.NoError(t, err)
377 coord := &coordinator{cfg: cfg, sessions: env.sessions}
378
379 parent, err := env.sessions.Create(t.Context(), "Parent")
380 require.NoError(t, err)
381 child, err := env.sessions.CreateTaskSession(t.Context(), "tool-1", parent.ID, "Child")
382 require.NoError(t, err)
383
384 err = coord.updateParentSessionCost(t.Context(), child.ID, parent.ID)
385 require.NoError(t, err)
386
387 updated, err := env.sessions.Get(t.Context(), parent.ID)
388 require.NoError(t, err)
389 assert.InDelta(t, 0.0, updated.Cost, 1e-9)
390 })
391}
392
393func TestGetProviderOptionsReasoningEffort(t *testing.T) {
394 // Bedrock is Fantasy's Anthropic under a different provider name; options
395 // must land under anthropic.Name so the Anthropic language model picks them up.
396 tests := []struct {
397 name string
398 providerType catwalk.Type
399 }{
400 {"anthropic honors reasoning_effort", catwalk.Type(anthropic.Name)},
401 {"bedrock honors reasoning_effort", catwalk.Type(bedrock.Name)},
402 }
403 for _, tc := range tests {
404 t.Run(tc.name, func(t *testing.T) {
405 model := Model{
406 CatwalkCfg: catwalk.Model{
407 ID: "claude-opus-4-7",
408 CanReason: true,
409 ReasoningLevels: []string{"max"},
410 },
411 ModelCfg: config.SelectedModel{
412 Provider: "test",
413 ReasoningEffort: "max",
414 },
415 }
416 providerCfg := config.ProviderConfig{ID: "test", Type: tc.providerType}
417
418 opts := getProviderOptions(model, providerCfg)
419
420 raw, ok := opts[anthropic.Name]
421 require.True(t, ok, "options should be keyed under anthropic.Name for type %q", tc.providerType)
422 parsed, ok := raw.(*anthropic.ProviderOptions)
423 require.True(t, ok)
424 require.NotNil(t, parsed.Effort)
425 assert.Equal(t, anthropic.Effort("max"), *parsed.Effort)
426 })
427 }
428}