1package agent
2
3import (
4 "context"
5 "errors"
6 "testing"
7
8 "charm.land/catwalk/pkg/catwalk"
9 "charm.land/fantasy"
10 "github.com/charmbracelet/crush/internal/config"
11 "github.com/stretchr/testify/assert"
12 "github.com/stretchr/testify/require"
13)
14
15// mockSessionAgent is a minimal mock for the SessionAgent interface.
16type mockSessionAgent struct {
17 model Model
18 runFunc func(ctx context.Context, call SessionAgentCall) (*fantasy.AgentResult, error)
19 cancelled []string
20}
21
22func (m *mockSessionAgent) Run(ctx context.Context, call SessionAgentCall) (*fantasy.AgentResult, error) {
23 return m.runFunc(ctx, call)
24}
25
26func (m *mockSessionAgent) Model() Model { return m.model }
27func (m *mockSessionAgent) SetModels(large, small Model) {}
28func (m *mockSessionAgent) SetTools(tools []fantasy.AgentTool) {}
29func (m *mockSessionAgent) SetSystemPrompt(systemPrompt string) {}
30func (m *mockSessionAgent) Cancel(sessionID string) {
31 m.cancelled = append(m.cancelled, sessionID)
32}
33func (m *mockSessionAgent) CancelAll() {}
34func (m *mockSessionAgent) IsSessionBusy(sessionID string) bool { return false }
35func (m *mockSessionAgent) IsBusy() bool { return false }
36func (m *mockSessionAgent) QueuedPrompts(sessionID string) int { return 0 }
37func (m *mockSessionAgent) QueuedPromptsList(sessionID string) []string { return nil }
38func (m *mockSessionAgent) ClearQueue(sessionID string) {}
39func (m *mockSessionAgent) Summarize(context.Context, string, fantasy.ProviderOptions) error {
40 return nil
41}
42
43// newTestCoordinator creates a minimal coordinator for unit testing runSubAgent.
44func newTestCoordinator(t *testing.T, env fakeEnv, providerID string, providerCfg config.ProviderConfig) *coordinator {
45 cfg, err := config.Init(env.workingDir, "", false)
46 require.NoError(t, err)
47 cfg.Config().Providers.Set(providerID, providerCfg)
48 return &coordinator{
49 cfg: cfg,
50 sessions: env.sessions,
51 }
52}
53
54// newMockAgent creates a mockSessionAgent with the given provider and run function.
55func newMockAgent(providerID string, maxTokens int64, runFunc func(context.Context, SessionAgentCall) (*fantasy.AgentResult, error)) *mockSessionAgent {
56 return &mockSessionAgent{
57 model: Model{
58 CatwalkCfg: catwalk.Model{
59 DefaultMaxTokens: maxTokens,
60 },
61 ModelCfg: config.SelectedModel{
62 Provider: providerID,
63 },
64 },
65 runFunc: runFunc,
66 }
67}
68
69// agentResultWithText creates a minimal AgentResult with the given text response.
70func agentResultWithText(text string) *fantasy.AgentResult {
71 return &fantasy.AgentResult{
72 Response: fantasy.Response{
73 Content: fantasy.ResponseContent{
74 fantasy.TextContent{Text: text},
75 },
76 },
77 }
78}
79
80func TestRunSubAgent(t *testing.T) {
81 const providerID = "test-provider"
82 providerCfg := config.ProviderConfig{ID: providerID}
83
84 t.Run("happy path", func(t *testing.T) {
85 env := testEnv(t)
86 coord := newTestCoordinator(t, env, providerID, providerCfg)
87
88 parentSession, err := env.sessions.Create(t.Context(), "Parent")
89 require.NoError(t, err)
90
91 agent := newMockAgent(providerID, 4096, func(_ context.Context, call SessionAgentCall) (*fantasy.AgentResult, error) {
92 assert.Equal(t, "do something", call.Prompt)
93 assert.Equal(t, int64(4096), call.MaxOutputTokens)
94 return agentResultWithText("done"), nil
95 })
96
97 resp, err := coord.runSubAgent(t.Context(), subAgentParams{
98 Agent: agent,
99 SessionID: parentSession.ID,
100 AgentMessageID: "msg-1",
101 ToolCallID: "call-1",
102 Prompt: "do something",
103 SessionTitle: "Test Session",
104 })
105 require.NoError(t, err)
106 assert.Equal(t, "done", resp.Content)
107 assert.False(t, resp.IsError)
108 })
109
110 t.Run("ModelCfg.MaxTokens overrides default", func(t *testing.T) {
111 env := testEnv(t)
112 coord := newTestCoordinator(t, env, providerID, providerCfg)
113
114 parentSession, err := env.sessions.Create(t.Context(), "Parent")
115 require.NoError(t, err)
116
117 agent := &mockSessionAgent{
118 model: Model{
119 CatwalkCfg: catwalk.Model{
120 DefaultMaxTokens: 4096,
121 },
122 ModelCfg: config.SelectedModel{
123 Provider: providerID,
124 MaxTokens: 8192,
125 },
126 },
127 runFunc: func(_ context.Context, call SessionAgentCall) (*fantasy.AgentResult, error) {
128 assert.Equal(t, int64(8192), call.MaxOutputTokens)
129 return agentResultWithText("ok"), nil
130 },
131 }
132
133 resp, err := coord.runSubAgent(t.Context(), subAgentParams{
134 Agent: agent,
135 SessionID: parentSession.ID,
136 AgentMessageID: "msg-1",
137 ToolCallID: "call-1",
138 Prompt: "test",
139 SessionTitle: "Test",
140 })
141 require.NoError(t, err)
142 assert.Equal(t, "ok", resp.Content)
143 })
144
145 t.Run("session creation failure with canceled context", func(t *testing.T) {
146 env := testEnv(t)
147 coord := newTestCoordinator(t, env, providerID, providerCfg)
148
149 parentSession, err := env.sessions.Create(t.Context(), "Parent")
150 require.NoError(t, err)
151
152 agent := newMockAgent(providerID, 4096, nil)
153
154 // Use a canceled context to trigger CreateTaskSession failure.
155 ctx, cancel := context.WithCancel(t.Context())
156 cancel()
157
158 _, err = coord.runSubAgent(ctx, subAgentParams{
159 Agent: agent,
160 SessionID: parentSession.ID,
161 AgentMessageID: "msg-1",
162 ToolCallID: "call-1",
163 Prompt: "test",
164 SessionTitle: "Test",
165 })
166 require.Error(t, err)
167 })
168
169 t.Run("provider not configured", func(t *testing.T) {
170 env := testEnv(t)
171 coord := newTestCoordinator(t, env, providerID, providerCfg)
172
173 parentSession, err := env.sessions.Create(t.Context(), "Parent")
174 require.NoError(t, err)
175
176 // Agent references a provider that doesn't exist in config.
177 agent := newMockAgent("unknown-provider", 4096, nil)
178
179 _, err = coord.runSubAgent(t.Context(), subAgentParams{
180 Agent: agent,
181 SessionID: parentSession.ID,
182 AgentMessageID: "msg-1",
183 ToolCallID: "call-1",
184 Prompt: "test",
185 SessionTitle: "Test",
186 })
187 require.Error(t, err)
188 assert.Contains(t, err.Error(), "model provider not configured")
189 })
190
191 t.Run("agent run error returns error response", func(t *testing.T) {
192 env := testEnv(t)
193 coord := newTestCoordinator(t, env, providerID, providerCfg)
194
195 parentSession, err := env.sessions.Create(t.Context(), "Parent")
196 require.NoError(t, err)
197
198 agent := newMockAgent(providerID, 4096, func(_ context.Context, _ SessionAgentCall) (*fantasy.AgentResult, error) {
199 return nil, errors.New("agent exploded")
200 })
201
202 resp, err := coord.runSubAgent(t.Context(), subAgentParams{
203 Agent: agent,
204 SessionID: parentSession.ID,
205 AgentMessageID: "msg-1",
206 ToolCallID: "call-1",
207 Prompt: "test",
208 SessionTitle: "Test",
209 })
210 // runSubAgent returns (errorResponse, nil) when agent.Run fails — not a Go error.
211 require.NoError(t, err)
212 assert.True(t, resp.IsError)
213 assert.Equal(t, "error generating response", resp.Content)
214 })
215
216 t.Run("session setup callback is invoked", func(t *testing.T) {
217 env := testEnv(t)
218 coord := newTestCoordinator(t, env, providerID, providerCfg)
219
220 parentSession, err := env.sessions.Create(t.Context(), "Parent")
221 require.NoError(t, err)
222
223 var setupCalledWith string
224 agent := newMockAgent(providerID, 4096, func(_ context.Context, _ SessionAgentCall) (*fantasy.AgentResult, error) {
225 return agentResultWithText("ok"), nil
226 })
227
228 _, err = coord.runSubAgent(t.Context(), subAgentParams{
229 Agent: agent,
230 SessionID: parentSession.ID,
231 AgentMessageID: "msg-1",
232 ToolCallID: "call-1",
233 Prompt: "test",
234 SessionTitle: "Test",
235 SessionSetup: func(sessionID string) {
236 setupCalledWith = sessionID
237 },
238 })
239 require.NoError(t, err)
240 assert.NotEmpty(t, setupCalledWith, "SessionSetup should have been called")
241 })
242
243 t.Run("cost propagation to parent session", func(t *testing.T) {
244 env := testEnv(t)
245 coord := newTestCoordinator(t, env, providerID, providerCfg)
246
247 parentSession, err := env.sessions.Create(t.Context(), "Parent")
248 require.NoError(t, err)
249
250 agent := newMockAgent(providerID, 4096, func(ctx context.Context, call SessionAgentCall) (*fantasy.AgentResult, error) {
251 // Simulate the agent incurring cost by updating the child session.
252 childSession, err := env.sessions.Get(ctx, call.SessionID)
253 if err != nil {
254 return nil, err
255 }
256 childSession.Cost = 0.05
257 _, err = env.sessions.Save(ctx, childSession)
258 if err != nil {
259 return nil, err
260 }
261 return agentResultWithText("ok"), nil
262 })
263
264 _, err = coord.runSubAgent(t.Context(), subAgentParams{
265 Agent: agent,
266 SessionID: parentSession.ID,
267 AgentMessageID: "msg-1",
268 ToolCallID: "call-1",
269 Prompt: "test",
270 SessionTitle: "Test",
271 })
272 require.NoError(t, err)
273
274 updated, err := env.sessions.Get(t.Context(), parentSession.ID)
275 require.NoError(t, err)
276 assert.InDelta(t, 0.05, updated.Cost, 1e-9)
277 })
278}
279
280func TestUpdateParentSessionCost(t *testing.T) {
281 t.Run("accumulates cost correctly", func(t *testing.T) {
282 env := testEnv(t)
283 cfg, err := config.Init(env.workingDir, "", false)
284 require.NoError(t, err)
285 coord := &coordinator{cfg: cfg, sessions: env.sessions}
286
287 parent, err := env.sessions.Create(t.Context(), "Parent")
288 require.NoError(t, err)
289
290 child, err := env.sessions.CreateTaskSession(t.Context(), "tool-1", parent.ID, "Child")
291 require.NoError(t, err)
292
293 // Set child cost.
294 child.Cost = 0.10
295 _, err = env.sessions.Save(t.Context(), child)
296 require.NoError(t, err)
297
298 err = coord.updateParentSessionCost(t.Context(), child.ID, parent.ID)
299 require.NoError(t, err)
300
301 updated, err := env.sessions.Get(t.Context(), parent.ID)
302 require.NoError(t, err)
303 assert.InDelta(t, 0.10, updated.Cost, 1e-9)
304 })
305
306 t.Run("accumulates multiple child costs", func(t *testing.T) {
307 env := testEnv(t)
308 cfg, err := config.Init(env.workingDir, "", false)
309 require.NoError(t, err)
310 coord := &coordinator{cfg: cfg, sessions: env.sessions}
311
312 parent, err := env.sessions.Create(t.Context(), "Parent")
313 require.NoError(t, err)
314
315 child1, err := env.sessions.CreateTaskSession(t.Context(), "tool-1", parent.ID, "Child1")
316 require.NoError(t, err)
317 child1.Cost = 0.05
318 _, err = env.sessions.Save(t.Context(), child1)
319 require.NoError(t, err)
320
321 child2, err := env.sessions.CreateTaskSession(t.Context(), "tool-2", parent.ID, "Child2")
322 require.NoError(t, err)
323 child2.Cost = 0.03
324 _, err = env.sessions.Save(t.Context(), child2)
325 require.NoError(t, err)
326
327 err = coord.updateParentSessionCost(t.Context(), child1.ID, parent.ID)
328 require.NoError(t, err)
329 err = coord.updateParentSessionCost(t.Context(), child2.ID, parent.ID)
330 require.NoError(t, err)
331
332 updated, err := env.sessions.Get(t.Context(), parent.ID)
333 require.NoError(t, err)
334 assert.InDelta(t, 0.08, updated.Cost, 1e-9)
335 })
336
337 t.Run("child session not found", func(t *testing.T) {
338 env := testEnv(t)
339 cfg, err := config.Init(env.workingDir, "", false)
340 require.NoError(t, err)
341 coord := &coordinator{cfg: cfg, sessions: env.sessions}
342
343 parent, err := env.sessions.Create(t.Context(), "Parent")
344 require.NoError(t, err)
345
346 err = coord.updateParentSessionCost(t.Context(), "non-existent", parent.ID)
347 require.Error(t, err)
348 assert.Contains(t, err.Error(), "get child session")
349 })
350
351 t.Run("parent session not found", func(t *testing.T) {
352 env := testEnv(t)
353 cfg, err := config.Init(env.workingDir, "", false)
354 require.NoError(t, err)
355 coord := &coordinator{cfg: cfg, sessions: env.sessions}
356
357 parent, err := env.sessions.Create(t.Context(), "Parent")
358 require.NoError(t, err)
359 child, err := env.sessions.CreateTaskSession(t.Context(), "tool-1", parent.ID, "Child")
360 require.NoError(t, err)
361
362 err = coord.updateParentSessionCost(t.Context(), child.ID, "non-existent")
363 require.Error(t, err)
364 assert.Contains(t, err.Error(), "get parent session")
365 })
366
367 t.Run("zero cost handled correctly", func(t *testing.T) {
368 env := testEnv(t)
369 cfg, err := config.Init(env.workingDir, "", false)
370 require.NoError(t, err)
371 coord := &coordinator{cfg: cfg, sessions: env.sessions}
372
373 parent, err := env.sessions.Create(t.Context(), "Parent")
374 require.NoError(t, err)
375 child, err := env.sessions.CreateTaskSession(t.Context(), "tool-1", parent.ID, "Child")
376 require.NoError(t, err)
377
378 err = coord.updateParentSessionCost(t.Context(), child.ID, parent.ID)
379 require.NoError(t, err)
380
381 updated, err := env.sessions.Get(t.Context(), parent.ID)
382 require.NoError(t, err)
383 assert.InDelta(t, 0.0, updated.Cost, 1e-9)
384 })
385}