1package agent
2
3import (
4 "context"
5 "errors"
6 "testing"
7
8 "charm.land/catwalk/pkg/catwalk"
9 "charm.land/fantasy"
10 "charm.land/fantasy/providers/anthropic"
11 "charm.land/fantasy/providers/bedrock"
12 "github.com/charmbracelet/crush/internal/config"
13 "github.com/stretchr/testify/assert"
14 "github.com/stretchr/testify/require"
15)
16
17// mockSessionAgent is a minimal mock for the SessionAgent interface.
18type mockSessionAgent struct {
19 model Model
20 runFunc func(ctx context.Context, call SessionAgentCall) (*fantasy.AgentResult, error)
21 cancelled []string
22}
23
// Run delegates to the configured runFunc. Several tests build the mock with
// a nil runFunc because they expect runSubAgent to return before reaching the
// agent; if that expectation is ever violated, report it as an error instead
// of panicking on a nil function call.
func (m *mockSessionAgent) Run(ctx context.Context, call SessionAgentCall) (*fantasy.AgentResult, error) {
	if m.runFunc == nil {
		return nil, errors.New("mockSessionAgent: Run called but no runFunc was configured")
	}
	return m.runFunc(ctx, call)
}
27
// Model returns the model the mock was built with.
func (m *mockSessionAgent) Model() Model {
	return m.model
}

// SetModels, SetTools and SetSystemPrompt ignore their arguments.
func (m *mockSessionAgent) SetModels(_, _ Model) {}

func (m *mockSessionAgent) SetTools(_ []fantasy.AgentTool) {}

func (m *mockSessionAgent) SetSystemPrompt(_ string) {}

// Cancel appends sessionID to m.cancelled so tests can inspect it.
func (m *mockSessionAgent) Cancel(sessionID string) {
	m.cancelled = append(m.cancelled, sessionID)
}

// CancelAll is a no-op.
func (m *mockSessionAgent) CancelAll() {}

// The mock never reports busy sessions or queued prompts.
func (m *mockSessionAgent) IsSessionBusy(_ string) bool { return false }

func (m *mockSessionAgent) IsBusy() bool { return false }

func (m *mockSessionAgent) QueuedPrompts(_ string) int { return 0 }

func (m *mockSessionAgent) QueuedPromptsList(_ string) []string { return nil }

func (m *mockSessionAgent) ClearQueue(_ string) {}

// Summarize does nothing and always reports success.
func (m *mockSessionAgent) Summarize(_ context.Context, _ string, _ fantasy.ProviderOptions) error {
	return nil
}
44
// newTestCoordinator builds a minimal coordinator for unit-testing runSubAgent.
// It initializes config from env.workingDir, registers providerCfg under
// providerID, and wires in env's session service. Only the cfg and sessions
// fields are populated; every other coordinator field is left at its zero value.
func newTestCoordinator(t *testing.T, env fakeEnv, providerID string, providerCfg config.ProviderConfig) *coordinator {
	// Report failures at the caller's line, not inside this helper.
	t.Helper()

	cfg, err := config.Init(env.workingDir, "", false)
	require.NoError(t, err)
	cfg.Config().Providers.Set(providerID, providerCfg)

	return &coordinator{
		cfg:      cfg,
		sessions: env.sessions,
	}
}
55
56// newMockAgent creates a mockSessionAgent with the given provider and run function.
57func newMockAgent(providerID string, maxTokens int64, runFunc func(context.Context, SessionAgentCall) (*fantasy.AgentResult, error)) *mockSessionAgent {
58 return &mockSessionAgent{
59 model: Model{
60 CatwalkCfg: catwalk.Model{
61 DefaultMaxTokens: maxTokens,
62 },
63 ModelCfg: config.SelectedModel{
64 Provider: providerID,
65 },
66 },
67 runFunc: runFunc,
68 }
69}
70
71// agentResultWithText creates a minimal AgentResult with the given text response.
72func agentResultWithText(text string) *fantasy.AgentResult {
73 return &fantasy.AgentResult{
74 Response: fantasy.Response{
75 Content: fantasy.ResponseContent{
76 fantasy.TextContent{Text: text},
77 },
78 },
79 }
80}
81
// TestRunSubAgent exercises coordinator.runSubAgent against a mockSessionAgent
// backed by the session service from testEnv.
func TestRunSubAgent(t *testing.T) {
	const providerID = "test-provider"
	providerCfg := config.ProviderConfig{ID: providerID}

	// setup creates a fresh environment, a coordinator that knows providerID,
	// and a parent session, returning the parent by ID.
	setup := func(t *testing.T) (fakeEnv, *coordinator, string) {
		t.Helper()
		env := testEnv(t)
		coord := newTestCoordinator(t, env, providerID, providerCfg)
		parent, err := env.sessions.Create(t.Context(), "Parent")
		require.NoError(t, err)
		return env, coord, parent.ID
	}

	// newParams returns the subAgentParams shared by the subtests; subtests
	// override individual fields where they need something different.
	newParams := func(agent *mockSessionAgent, parentID string) subAgentParams {
		return subAgentParams{
			Agent:          agent,
			SessionID:      parentID,
			AgentMessageID: "msg-1",
			ToolCallID:     "call-1",
			Prompt:         "test",
			SessionTitle:   "Test",
		}
	}

	// failIfRun returns a runFunc for subtests that expect runSubAgent to
	// return before invoking the agent, so an unexpected call is reported
	// explicitly instead of depending on a nil runFunc.
	failIfRun := func(t *testing.T) func(context.Context, SessionAgentCall) (*fantasy.AgentResult, error) {
		return func(context.Context, SessionAgentCall) (*fantasy.AgentResult, error) {
			t.Error("agent Run should not have been called")
			return nil, errors.New("unexpected Run call")
		}
	}

	t.Run("happy path", func(t *testing.T) {
		_, coord, parentID := setup(t)

		// Without this flag the assertions inside runFunc would be vacuous if
		// Run were never invoked.
		var ran bool
		agent := newMockAgent(providerID, 4096, func(_ context.Context, call SessionAgentCall) (*fantasy.AgentResult, error) {
			ran = true
			assert.Equal(t, "do something", call.Prompt)
			assert.Equal(t, int64(4096), call.MaxOutputTokens)
			// The sub-agent must run in its own child session, not the parent.
			assert.NotEqual(t, parentID, call.SessionID)
			return agentResultWithText("done"), nil
		})

		params := newParams(agent, parentID)
		params.Prompt = "do something"
		params.SessionTitle = "Test Session"

		resp, err := coord.runSubAgent(t.Context(), params)
		require.NoError(t, err)
		assert.True(t, ran, "agent Run should have been called")
		assert.Equal(t, "done", resp.Content)
		assert.False(t, resp.IsError)
	})

	t.Run("ModelCfg.MaxTokens overrides default", func(t *testing.T) {
		_, coord, parentID := setup(t)

		var ran bool
		agent := &mockSessionAgent{
			model: Model{
				CatwalkCfg: catwalk.Model{
					DefaultMaxTokens: 4096,
				},
				ModelCfg: config.SelectedModel{
					Provider:  providerID,
					MaxTokens: 8192,
				},
			},
			runFunc: func(_ context.Context, call SessionAgentCall) (*fantasy.AgentResult, error) {
				ran = true
				assert.Equal(t, int64(8192), call.MaxOutputTokens)
				return agentResultWithText("ok"), nil
			},
		}

		resp, err := coord.runSubAgent(t.Context(), newParams(agent, parentID))
		require.NoError(t, err)
		assert.True(t, ran, "agent Run should have been called")
		assert.Equal(t, "ok", resp.Content)
	})

	t.Run("session creation failure with canceled context", func(t *testing.T) {
		_, coord, parentID := setup(t)

		agent := newMockAgent(providerID, 4096, failIfRun(t))

		// Use a canceled context to trigger CreateTaskSession failure.
		ctx, cancel := context.WithCancel(t.Context())
		cancel()

		_, err := coord.runSubAgent(ctx, newParams(agent, parentID))
		require.Error(t, err)
	})

	t.Run("provider not configured", func(t *testing.T) {
		_, coord, parentID := setup(t)

		// Agent references a provider that doesn't exist in config.
		agent := newMockAgent("unknown-provider", 4096, failIfRun(t))

		_, err := coord.runSubAgent(t.Context(), newParams(agent, parentID))
		require.Error(t, err)
		assert.Contains(t, err.Error(), "model provider not configured")
	})

	t.Run("agent run error returns error response", func(t *testing.T) {
		_, coord, parentID := setup(t)

		agent := newMockAgent(providerID, 4096, func(_ context.Context, _ SessionAgentCall) (*fantasy.AgentResult, error) {
			return nil, errors.New("provider request failed")
		})

		resp, err := coord.runSubAgent(t.Context(), newParams(agent, parentID))
		// runSubAgent returns (errorResponse, nil) when agent.Run fails — not a Go error.
		require.NoError(t, err)
		assert.True(t, resp.IsError)
		assert.Equal(t, "Failed to generate response: provider request failed", resp.Content)
	})

	t.Run("session setup callback is invoked", func(t *testing.T) {
		_, coord, parentID := setup(t)

		var setupCalledWith, runSessionID string
		agent := newMockAgent(providerID, 4096, func(_ context.Context, call SessionAgentCall) (*fantasy.AgentResult, error) {
			runSessionID = call.SessionID
			return agentResultWithText("ok"), nil
		})

		params := newParams(agent, parentID)
		params.SessionSetup = func(sessionID string) {
			setupCalledWith = sessionID
		}

		_, err := coord.runSubAgent(t.Context(), params)
		require.NoError(t, err)
		assert.NotEmpty(t, setupCalledWith, "SessionSetup should have been called")
		// The callback is meant to prepare the sub-agent's own session.
		assert.NotEqual(t, parentID, setupCalledWith, "SessionSetup should receive the child session, not the parent")
		assert.Equal(t, runSessionID, setupCalledWith, "SessionSetup and Run should target the same session")
	})

	t.Run("cost propagation to parent session", func(t *testing.T) {
		env, coord, parentID := setup(t)

		agent := newMockAgent(providerID, 4096, func(ctx context.Context, call SessionAgentCall) (*fantasy.AgentResult, error) {
			// Simulate the agent incurring cost by updating the child session.
			childSession, err := env.sessions.Get(ctx, call.SessionID)
			if err != nil {
				return nil, err
			}
			childSession.Cost = 0.05
			if _, err := env.sessions.Save(ctx, childSession); err != nil {
				return nil, err
			}
			return agentResultWithText("ok"), nil
		})

		resp, err := coord.runSubAgent(t.Context(), newParams(agent, parentID))
		require.NoError(t, err)
		// A failing Run is reported through resp rather than err; make sure the
		// simulated cost update above actually succeeded before checking totals.
		require.False(t, resp.IsError, "unexpected error response: %v", resp.Content)

		updated, err := env.sessions.Get(t.Context(), parentID)
		require.NoError(t, err)
		assert.InDelta(t, 0.05, updated.Cost, 1e-9)
	})
}
281
// TestUpdateParentSessionCost checks that coordinator.updateParentSessionCost
// adds a child session's cost to its parent and reports lookup failures.
func TestUpdateParentSessionCost(t *testing.T) {
	// setup returns a fresh test environment and a coordinator wired to its
	// session service; no provider configuration is needed for these cases.
	setup := func(t *testing.T) (fakeEnv, *coordinator) {
		t.Helper()
		env := testEnv(t)
		cfg, err := config.Init(env.workingDir, "", false)
		require.NoError(t, err)
		return env, &coordinator{cfg: cfg, sessions: env.sessions}
	}

	t.Run("accumulates cost correctly", func(t *testing.T) {
		env, coord := setup(t)

		parent, err := env.sessions.Create(t.Context(), "Parent")
		require.NoError(t, err)

		child, err := env.sessions.CreateTaskSession(t.Context(), "tool-1", parent.ID, "Child")
		require.NoError(t, err)

		// Set child cost.
		child.Cost = 0.10
		_, err = env.sessions.Save(t.Context(), child)
		require.NoError(t, err)

		err = coord.updateParentSessionCost(t.Context(), child.ID, parent.ID)
		require.NoError(t, err)

		updated, err := env.sessions.Get(t.Context(), parent.ID)
		require.NoError(t, err)
		assert.InDelta(t, 0.10, updated.Cost, 1e-9)
	})

	t.Run("accumulates multiple child costs", func(t *testing.T) {
		env, coord := setup(t)

		parent, err := env.sessions.Create(t.Context(), "Parent")
		require.NoError(t, err)

		child1, err := env.sessions.CreateTaskSession(t.Context(), "tool-1", parent.ID, "Child1")
		require.NoError(t, err)
		child1.Cost = 0.05
		_, err = env.sessions.Save(t.Context(), child1)
		require.NoError(t, err)

		child2, err := env.sessions.CreateTaskSession(t.Context(), "tool-2", parent.ID, "Child2")
		require.NoError(t, err)
		child2.Cost = 0.03
		_, err = env.sessions.Save(t.Context(), child2)
		require.NoError(t, err)

		err = coord.updateParentSessionCost(t.Context(), child1.ID, parent.ID)
		require.NoError(t, err)
		err = coord.updateParentSessionCost(t.Context(), child2.ID, parent.ID)
		require.NoError(t, err)

		updated, err := env.sessions.Get(t.Context(), parent.ID)
		require.NoError(t, err)
		assert.InDelta(t, 0.08, updated.Cost, 1e-9)
	})

	t.Run("child session not found", func(t *testing.T) {
		env, coord := setup(t)

		parent, err := env.sessions.Create(t.Context(), "Parent")
		require.NoError(t, err)

		err = coord.updateParentSessionCost(t.Context(), "non-existent", parent.ID)
		require.Error(t, err)
		assert.Contains(t, err.Error(), "get child session")
	})

	t.Run("parent session not found", func(t *testing.T) {
		env, coord := setup(t)

		parent, err := env.sessions.Create(t.Context(), "Parent")
		require.NoError(t, err)
		child, err := env.sessions.CreateTaskSession(t.Context(), "tool-1", parent.ID, "Child")
		require.NoError(t, err)

		err = coord.updateParentSessionCost(t.Context(), child.ID, "non-existent")
		require.Error(t, err)
		assert.Contains(t, err.Error(), "get parent session")
	})

	t.Run("zero cost handled correctly", func(t *testing.T) {
		env, coord := setup(t)

		parent, err := env.sessions.Create(t.Context(), "Parent")
		require.NoError(t, err)
		child, err := env.sessions.CreateTaskSession(t.Context(), "tool-1", parent.ID, "Child")
		require.NoError(t, err)

		// Give the parent a pre-existing cost: with a zero starting cost this
		// case would pass even if the parent were reset or never updated.
		parent.Cost = 0.20
		_, err = env.sessions.Save(t.Context(), parent)
		require.NoError(t, err)

		err = coord.updateParentSessionCost(t.Context(), child.ID, parent.ID)
		require.NoError(t, err)

		updated, err := env.sessions.Get(t.Context(), parent.ID)
		require.NoError(t, err)
		assert.InDelta(t, 0.20, updated.Cost, 1e-9)
	})
}
388
// TestGetProviderOptionsReasoningEffort verifies that a configured
// reasoning_effort ends up in Anthropic provider options for each provider
// type that Fantasy serves through its Anthropic implementation.
func TestGetProviderOptionsReasoningEffort(t *testing.T) {
	// Bedrock is Fantasy's Anthropic under a different provider name; options
	// must land under anthropic.Name so the Anthropic language model picks them up.
	cases := []struct {
		name         string
		providerType catwalk.Type
	}{
		{name: "anthropic honors reasoning_effort", providerType: catwalk.Type(anthropic.Name)},
		{name: "bedrock honors reasoning_effort", providerType: catwalk.Type(bedrock.Name)},
	}

	for _, c := range cases {
		t.Run(c.name, func(t *testing.T) {
			providerCfg := config.ProviderConfig{ID: "test", Type: c.providerType}
			model := Model{
				CatwalkCfg: catwalk.Model{
					ID:              "claude-opus-4-7",
					CanReason:       true,
					ReasoningLevels: []string{"max"},
				},
				ModelCfg: config.SelectedModel{
					Provider:        "test",
					ReasoningEffort: "max",
				},
			}

			options := getProviderOptions(model, providerCfg)

			entry, found := options[anthropic.Name]
			require.True(t, found, "options should be keyed under anthropic.Name for type %q", c.providerType)

			anthropicOpts, isAnthropic := entry.(*anthropic.ProviderOptions)
			require.True(t, isAnthropic)
			require.NotNil(t, anthropicOpts.Effort)
			assert.Equal(t, anthropic.Effort("max"), *anthropicOpts.Effort)
		})
	}
}