coordinator_test.go

  1package agent
  2
  3import (
  4	"context"
  5	"errors"
  6	"testing"
  7
  8	"charm.land/catwalk/pkg/catwalk"
  9	"charm.land/fantasy"
 10	"github.com/charmbracelet/crush/internal/config"
 11	"github.com/stretchr/testify/assert"
 12	"github.com/stretchr/testify/require"
 13)
 14
 15// mockSessionAgent is a minimal mock for the SessionAgent interface.
 16type mockSessionAgent struct {
 17	model     Model
 18	runFunc   func(ctx context.Context, call SessionAgentCall) (*fantasy.AgentResult, error)
 19	cancelled []string
 20}
 21
 22func (m *mockSessionAgent) Run(ctx context.Context, call SessionAgentCall) (*fantasy.AgentResult, error) {
 23	return m.runFunc(ctx, call)
 24}
 25
 26func (m *mockSessionAgent) Model() Model                        { return m.model }
 27func (m *mockSessionAgent) SetModels(large, small Model)        {}
 28func (m *mockSessionAgent) SetTools(tools []fantasy.AgentTool)  {}
 29func (m *mockSessionAgent) SetSystemPrompt(systemPrompt string) {}
 30func (m *mockSessionAgent) Cancel(sessionID string) {
 31	m.cancelled = append(m.cancelled, sessionID)
 32}
 33func (m *mockSessionAgent) CancelAll()                                  {}
 34func (m *mockSessionAgent) IsSessionBusy(sessionID string) bool         { return false }
 35func (m *mockSessionAgent) IsBusy() bool                                { return false }
 36func (m *mockSessionAgent) QueuedPrompts(sessionID string) int          { return 0 }
 37func (m *mockSessionAgent) QueuedPromptsList(sessionID string) []string { return nil }
 38func (m *mockSessionAgent) ClearQueue(sessionID string)                 {}
 39func (m *mockSessionAgent) Summarize(context.Context, string, fantasy.ProviderOptions) error {
 40	return nil
 41}
 42
 43// newTestCoordinator creates a minimal coordinator for unit testing runSubAgent.
 44func newTestCoordinator(t *testing.T, env fakeEnv, providerID string, providerCfg config.ProviderConfig) *coordinator {
 45	cfg, err := config.Init(env.workingDir, "", false)
 46	require.NoError(t, err)
 47	cfg.Config().Providers.Set(providerID, providerCfg)
 48	return &coordinator{
 49		cfg:      cfg,
 50		sessions: env.sessions,
 51	}
 52}
 53
 54// newMockAgent creates a mockSessionAgent with the given provider and run function.
 55func newMockAgent(providerID string, maxTokens int64, runFunc func(context.Context, SessionAgentCall) (*fantasy.AgentResult, error)) *mockSessionAgent {
 56	return &mockSessionAgent{
 57		model: Model{
 58			CatwalkCfg: catwalk.Model{
 59				DefaultMaxTokens: maxTokens,
 60			},
 61			ModelCfg: config.SelectedModel{
 62				Provider: providerID,
 63			},
 64		},
 65		runFunc: runFunc,
 66	}
 67}
 68
 69// agentResultWithText creates a minimal AgentResult with the given text response.
 70func agentResultWithText(text string) *fantasy.AgentResult {
 71	return &fantasy.AgentResult{
 72		Response: fantasy.Response{
 73			Content: fantasy.ResponseContent{
 74				fantasy.TextContent{Text: text},
 75			},
 76		},
 77	}
 78}
 79
 80func TestRunSubAgent(t *testing.T) {
 81	const providerID = "test-provider"
 82	providerCfg := config.ProviderConfig{ID: providerID}
 83
 84	t.Run("happy path", func(t *testing.T) {
 85		env := testEnv(t)
 86		coord := newTestCoordinator(t, env, providerID, providerCfg)
 87
 88		parentSession, err := env.sessions.Create(t.Context(), "Parent")
 89		require.NoError(t, err)
 90
 91		agent := newMockAgent(providerID, 4096, func(_ context.Context, call SessionAgentCall) (*fantasy.AgentResult, error) {
 92			assert.Equal(t, "do something", call.Prompt)
 93			assert.Equal(t, int64(4096), call.MaxOutputTokens)
 94			return agentResultWithText("done"), nil
 95		})
 96
 97		resp, err := coord.runSubAgent(t.Context(), subAgentParams{
 98			Agent:          agent,
 99			SessionID:      parentSession.ID,
100			AgentMessageID: "msg-1",
101			ToolCallID:     "call-1",
102			Prompt:         "do something",
103			SessionTitle:   "Test Session",
104		})
105		require.NoError(t, err)
106		assert.Equal(t, "done", resp.Content)
107		assert.False(t, resp.IsError)
108	})
109
110	t.Run("ModelCfg.MaxTokens overrides default", func(t *testing.T) {
111		env := testEnv(t)
112		coord := newTestCoordinator(t, env, providerID, providerCfg)
113
114		parentSession, err := env.sessions.Create(t.Context(), "Parent")
115		require.NoError(t, err)
116
117		agent := &mockSessionAgent{
118			model: Model{
119				CatwalkCfg: catwalk.Model{
120					DefaultMaxTokens: 4096,
121				},
122				ModelCfg: config.SelectedModel{
123					Provider:  providerID,
124					MaxTokens: 8192,
125				},
126			},
127			runFunc: func(_ context.Context, call SessionAgentCall) (*fantasy.AgentResult, error) {
128				assert.Equal(t, int64(8192), call.MaxOutputTokens)
129				return agentResultWithText("ok"), nil
130			},
131		}
132
133		resp, err := coord.runSubAgent(t.Context(), subAgentParams{
134			Agent:          agent,
135			SessionID:      parentSession.ID,
136			AgentMessageID: "msg-1",
137			ToolCallID:     "call-1",
138			Prompt:         "test",
139			SessionTitle:   "Test",
140		})
141		require.NoError(t, err)
142		assert.Equal(t, "ok", resp.Content)
143	})
144
145	t.Run("session creation failure with canceled context", func(t *testing.T) {
146		env := testEnv(t)
147		coord := newTestCoordinator(t, env, providerID, providerCfg)
148
149		parentSession, err := env.sessions.Create(t.Context(), "Parent")
150		require.NoError(t, err)
151
152		agent := newMockAgent(providerID, 4096, nil)
153
154		// Use a canceled context to trigger CreateTaskSession failure.
155		ctx, cancel := context.WithCancel(t.Context())
156		cancel()
157
158		_, err = coord.runSubAgent(ctx, subAgentParams{
159			Agent:          agent,
160			SessionID:      parentSession.ID,
161			AgentMessageID: "msg-1",
162			ToolCallID:     "call-1",
163			Prompt:         "test",
164			SessionTitle:   "Test",
165		})
166		require.Error(t, err)
167	})
168
169	t.Run("provider not configured", func(t *testing.T) {
170		env := testEnv(t)
171		coord := newTestCoordinator(t, env, providerID, providerCfg)
172
173		parentSession, err := env.sessions.Create(t.Context(), "Parent")
174		require.NoError(t, err)
175
176		// Agent references a provider that doesn't exist in config.
177		agent := newMockAgent("unknown-provider", 4096, nil)
178
179		_, err = coord.runSubAgent(t.Context(), subAgentParams{
180			Agent:          agent,
181			SessionID:      parentSession.ID,
182			AgentMessageID: "msg-1",
183			ToolCallID:     "call-1",
184			Prompt:         "test",
185			SessionTitle:   "Test",
186		})
187		require.Error(t, err)
188		assert.Contains(t, err.Error(), "model provider not configured")
189	})
190
191	t.Run("agent run error returns error response", func(t *testing.T) {
192		env := testEnv(t)
193		coord := newTestCoordinator(t, env, providerID, providerCfg)
194
195		parentSession, err := env.sessions.Create(t.Context(), "Parent")
196		require.NoError(t, err)
197
198		agent := newMockAgent(providerID, 4096, func(_ context.Context, _ SessionAgentCall) (*fantasy.AgentResult, error) {
199			return nil, errors.New("agent exploded")
200		})
201
202		resp, err := coord.runSubAgent(t.Context(), subAgentParams{
203			Agent:          agent,
204			SessionID:      parentSession.ID,
205			AgentMessageID: "msg-1",
206			ToolCallID:     "call-1",
207			Prompt:         "test",
208			SessionTitle:   "Test",
209		})
210		// runSubAgent returns (errorResponse, nil) when agent.Run fails — not a Go error.
211		require.NoError(t, err)
212		assert.True(t, resp.IsError)
213		assert.Equal(t, "error generating response", resp.Content)
214	})
215
216	t.Run("session setup callback is invoked", func(t *testing.T) {
217		env := testEnv(t)
218		coord := newTestCoordinator(t, env, providerID, providerCfg)
219
220		parentSession, err := env.sessions.Create(t.Context(), "Parent")
221		require.NoError(t, err)
222
223		var setupCalledWith string
224		agent := newMockAgent(providerID, 4096, func(_ context.Context, _ SessionAgentCall) (*fantasy.AgentResult, error) {
225			return agentResultWithText("ok"), nil
226		})
227
228		_, err = coord.runSubAgent(t.Context(), subAgentParams{
229			Agent:          agent,
230			SessionID:      parentSession.ID,
231			AgentMessageID: "msg-1",
232			ToolCallID:     "call-1",
233			Prompt:         "test",
234			SessionTitle:   "Test",
235			SessionSetup: func(sessionID string) {
236				setupCalledWith = sessionID
237			},
238		})
239		require.NoError(t, err)
240		assert.NotEmpty(t, setupCalledWith, "SessionSetup should have been called")
241	})
242
243	t.Run("cost propagation to parent session", func(t *testing.T) {
244		env := testEnv(t)
245		coord := newTestCoordinator(t, env, providerID, providerCfg)
246
247		parentSession, err := env.sessions.Create(t.Context(), "Parent")
248		require.NoError(t, err)
249
250		agent := newMockAgent(providerID, 4096, func(ctx context.Context, call SessionAgentCall) (*fantasy.AgentResult, error) {
251			// Simulate the agent incurring cost by updating the child session.
252			childSession, err := env.sessions.Get(ctx, call.SessionID)
253			if err != nil {
254				return nil, err
255			}
256			childSession.Cost = 0.05
257			_, err = env.sessions.Save(ctx, childSession)
258			if err != nil {
259				return nil, err
260			}
261			return agentResultWithText("ok"), nil
262		})
263
264		_, err = coord.runSubAgent(t.Context(), subAgentParams{
265			Agent:          agent,
266			SessionID:      parentSession.ID,
267			AgentMessageID: "msg-1",
268			ToolCallID:     "call-1",
269			Prompt:         "test",
270			SessionTitle:   "Test",
271		})
272		require.NoError(t, err)
273
274		updated, err := env.sessions.Get(t.Context(), parentSession.ID)
275		require.NoError(t, err)
276		assert.InDelta(t, 0.05, updated.Cost, 1e-9)
277	})
278}
279
280func TestUpdateParentSessionCost(t *testing.T) {
281	t.Run("accumulates cost correctly", func(t *testing.T) {
282		env := testEnv(t)
283		cfg, err := config.Init(env.workingDir, "", false)
284		require.NoError(t, err)
285		coord := &coordinator{cfg: cfg, sessions: env.sessions}
286
287		parent, err := env.sessions.Create(t.Context(), "Parent")
288		require.NoError(t, err)
289
290		child, err := env.sessions.CreateTaskSession(t.Context(), "tool-1", parent.ID, "Child")
291		require.NoError(t, err)
292
293		// Set child cost.
294		child.Cost = 0.10
295		_, err = env.sessions.Save(t.Context(), child)
296		require.NoError(t, err)
297
298		err = coord.updateParentSessionCost(t.Context(), child.ID, parent.ID)
299		require.NoError(t, err)
300
301		updated, err := env.sessions.Get(t.Context(), parent.ID)
302		require.NoError(t, err)
303		assert.InDelta(t, 0.10, updated.Cost, 1e-9)
304	})
305
306	t.Run("accumulates multiple child costs", func(t *testing.T) {
307		env := testEnv(t)
308		cfg, err := config.Init(env.workingDir, "", false)
309		require.NoError(t, err)
310		coord := &coordinator{cfg: cfg, sessions: env.sessions}
311
312		parent, err := env.sessions.Create(t.Context(), "Parent")
313		require.NoError(t, err)
314
315		child1, err := env.sessions.CreateTaskSession(t.Context(), "tool-1", parent.ID, "Child1")
316		require.NoError(t, err)
317		child1.Cost = 0.05
318		_, err = env.sessions.Save(t.Context(), child1)
319		require.NoError(t, err)
320
321		child2, err := env.sessions.CreateTaskSession(t.Context(), "tool-2", parent.ID, "Child2")
322		require.NoError(t, err)
323		child2.Cost = 0.03
324		_, err = env.sessions.Save(t.Context(), child2)
325		require.NoError(t, err)
326
327		err = coord.updateParentSessionCost(t.Context(), child1.ID, parent.ID)
328		require.NoError(t, err)
329		err = coord.updateParentSessionCost(t.Context(), child2.ID, parent.ID)
330		require.NoError(t, err)
331
332		updated, err := env.sessions.Get(t.Context(), parent.ID)
333		require.NoError(t, err)
334		assert.InDelta(t, 0.08, updated.Cost, 1e-9)
335	})
336
337	t.Run("child session not found", func(t *testing.T) {
338		env := testEnv(t)
339		cfg, err := config.Init(env.workingDir, "", false)
340		require.NoError(t, err)
341		coord := &coordinator{cfg: cfg, sessions: env.sessions}
342
343		parent, err := env.sessions.Create(t.Context(), "Parent")
344		require.NoError(t, err)
345
346		err = coord.updateParentSessionCost(t.Context(), "non-existent", parent.ID)
347		require.Error(t, err)
348		assert.Contains(t, err.Error(), "get child session")
349	})
350
351	t.Run("parent session not found", func(t *testing.T) {
352		env := testEnv(t)
353		cfg, err := config.Init(env.workingDir, "", false)
354		require.NoError(t, err)
355		coord := &coordinator{cfg: cfg, sessions: env.sessions}
356
357		parent, err := env.sessions.Create(t.Context(), "Parent")
358		require.NoError(t, err)
359		child, err := env.sessions.CreateTaskSession(t.Context(), "tool-1", parent.ID, "Child")
360		require.NoError(t, err)
361
362		err = coord.updateParentSessionCost(t.Context(), child.ID, "non-existent")
363		require.Error(t, err)
364		assert.Contains(t, err.Error(), "get parent session")
365	})
366
367	t.Run("zero cost handled correctly", func(t *testing.T) {
368		env := testEnv(t)
369		cfg, err := config.Init(env.workingDir, "", false)
370		require.NoError(t, err)
371		coord := &coordinator{cfg: cfg, sessions: env.sessions}
372
373		parent, err := env.sessions.Create(t.Context(), "Parent")
374		require.NoError(t, err)
375		child, err := env.sessions.CreateTaskSession(t.Context(), "tool-1", parent.ID, "Child")
376		require.NoError(t, err)
377
378		err = coord.updateParentSessionCost(t.Context(), child.ID, parent.ID)
379		require.NoError(t, err)
380
381		updated, err := env.sessions.Get(t.Context(), parent.ID)
382		require.NoError(t, err)
383		assert.InDelta(t, 0.0, updated.Cost, 1e-9)
384	})
385}