coordinator_test.go

  1package agent
  2
  3import (
  4	"context"
  5	"errors"
  6	"testing"
  7
  8	"charm.land/catwalk/pkg/catwalk"
  9	"charm.land/fantasy"
 10	"charm.land/fantasy/providers/anthropic"
 11	"charm.land/fantasy/providers/bedrock"
 12	"github.com/charmbracelet/crush/internal/config"
 13	"github.com/stretchr/testify/assert"
 14	"github.com/stretchr/testify/require"
 15)
 16
 17// mockSessionAgent is a minimal mock for the SessionAgent interface.
 18type mockSessionAgent struct {
 19	model     Model
 20	runFunc   func(ctx context.Context, call SessionAgentCall) (*fantasy.AgentResult, error)
 21	cancelled []string
 22}
 23
 24func (m *mockSessionAgent) Run(ctx context.Context, call SessionAgentCall) (*fantasy.AgentResult, error) {
 25	return m.runFunc(ctx, call)
 26}
 27
 28func (m *mockSessionAgent) Model() Model                        { return m.model }
 29func (m *mockSessionAgent) SetModels(large, small Model)        {}
 30func (m *mockSessionAgent) SetTools(tools []fantasy.AgentTool)  {}
 31func (m *mockSessionAgent) SetSystemPrompt(systemPrompt string) {}
 32func (m *mockSessionAgent) Cancel(sessionID string) {
 33	m.cancelled = append(m.cancelled, sessionID)
 34}
 35func (m *mockSessionAgent) CancelAll()                                  {}
 36func (m *mockSessionAgent) IsSessionBusy(sessionID string) bool         { return false }
 37func (m *mockSessionAgent) IsBusy() bool                                { return false }
 38func (m *mockSessionAgent) QueuedPrompts(sessionID string) int          { return 0 }
 39func (m *mockSessionAgent) QueuedPromptsList(sessionID string) []string { return nil }
 40func (m *mockSessionAgent) ClearQueue(sessionID string)                 {}
 41func (m *mockSessionAgent) Summarize(context.Context, string, fantasy.ProviderOptions) error {
 42	return nil
 43}
 44
 45// newTestCoordinator creates a minimal coordinator for unit testing runSubAgent.
 46func newTestCoordinator(t *testing.T, env fakeEnv, providerID string, providerCfg config.ProviderConfig) *coordinator {
 47	cfg, err := config.Init(env.workingDir, "", false)
 48	require.NoError(t, err)
 49	cfg.Config().Providers.Set(providerID, providerCfg)
 50	return &coordinator{
 51		cfg:      cfg,
 52		sessions: env.sessions,
 53	}
 54}
 55
 56// newMockAgent creates a mockSessionAgent with the given provider and run function.
 57func newMockAgent(providerID string, maxTokens int64, runFunc func(context.Context, SessionAgentCall) (*fantasy.AgentResult, error)) *mockSessionAgent {
 58	return &mockSessionAgent{
 59		model: Model{
 60			CatwalkCfg: catwalk.Model{
 61				DefaultMaxTokens: maxTokens,
 62			},
 63			ModelCfg: config.SelectedModel{
 64				Provider: providerID,
 65			},
 66		},
 67		runFunc: runFunc,
 68	}
 69}
 70
 71// agentResultWithText creates a minimal AgentResult with the given text response.
 72func agentResultWithText(text string) *fantasy.AgentResult {
 73	return &fantasy.AgentResult{
 74		Response: fantasy.Response{
 75			Content: fantasy.ResponseContent{
 76				fantasy.TextContent{Text: text},
 77			},
 78		},
 79	}
 80}
 81
 82func TestRunSubAgent(t *testing.T) {
 83	const providerID = "test-provider"
 84	providerCfg := config.ProviderConfig{ID: providerID}
 85
 86	t.Run("happy path", func(t *testing.T) {
 87		env := testEnv(t)
 88		coord := newTestCoordinator(t, env, providerID, providerCfg)
 89
 90		parentSession, err := env.sessions.Create(t.Context(), "Parent")
 91		require.NoError(t, err)
 92
 93		agent := newMockAgent(providerID, 4096, func(_ context.Context, call SessionAgentCall) (*fantasy.AgentResult, error) {
 94			assert.Equal(t, "do something", call.Prompt)
 95			assert.Equal(t, int64(4096), call.MaxOutputTokens)
 96			return agentResultWithText("done"), nil
 97		})
 98
 99		resp, err := coord.runSubAgent(t.Context(), subAgentParams{
100			Agent:          agent,
101			SessionID:      parentSession.ID,
102			AgentMessageID: "msg-1",
103			ToolCallID:     "call-1",
104			Prompt:         "do something",
105			SessionTitle:   "Test Session",
106		})
107		require.NoError(t, err)
108		assert.Equal(t, "done", resp.Content)
109		assert.False(t, resp.IsError)
110	})
111
112	t.Run("ModelCfg.MaxTokens overrides default", func(t *testing.T) {
113		env := testEnv(t)
114		coord := newTestCoordinator(t, env, providerID, providerCfg)
115
116		parentSession, err := env.sessions.Create(t.Context(), "Parent")
117		require.NoError(t, err)
118
119		agent := &mockSessionAgent{
120			model: Model{
121				CatwalkCfg: catwalk.Model{
122					DefaultMaxTokens: 4096,
123				},
124				ModelCfg: config.SelectedModel{
125					Provider:  providerID,
126					MaxTokens: 8192,
127				},
128			},
129			runFunc: func(_ context.Context, call SessionAgentCall) (*fantasy.AgentResult, error) {
130				assert.Equal(t, int64(8192), call.MaxOutputTokens)
131				return agentResultWithText("ok"), nil
132			},
133		}
134
135		resp, err := coord.runSubAgent(t.Context(), subAgentParams{
136			Agent:          agent,
137			SessionID:      parentSession.ID,
138			AgentMessageID: "msg-1",
139			ToolCallID:     "call-1",
140			Prompt:         "test",
141			SessionTitle:   "Test",
142		})
143		require.NoError(t, err)
144		assert.Equal(t, "ok", resp.Content)
145	})
146
147	t.Run("session creation failure with canceled context", func(t *testing.T) {
148		env := testEnv(t)
149		coord := newTestCoordinator(t, env, providerID, providerCfg)
150
151		parentSession, err := env.sessions.Create(t.Context(), "Parent")
152		require.NoError(t, err)
153
154		agent := newMockAgent(providerID, 4096, nil)
155
156		// Use a canceled context to trigger CreateTaskSession failure.
157		ctx, cancel := context.WithCancel(t.Context())
158		cancel()
159
160		_, err = coord.runSubAgent(ctx, subAgentParams{
161			Agent:          agent,
162			SessionID:      parentSession.ID,
163			AgentMessageID: "msg-1",
164			ToolCallID:     "call-1",
165			Prompt:         "test",
166			SessionTitle:   "Test",
167		})
168		require.Error(t, err)
169	})
170
171	t.Run("provider not configured", func(t *testing.T) {
172		env := testEnv(t)
173		coord := newTestCoordinator(t, env, providerID, providerCfg)
174
175		parentSession, err := env.sessions.Create(t.Context(), "Parent")
176		require.NoError(t, err)
177
178		// Agent references a provider that doesn't exist in config.
179		agent := newMockAgent("unknown-provider", 4096, nil)
180
181		_, err = coord.runSubAgent(t.Context(), subAgentParams{
182			Agent:          agent,
183			SessionID:      parentSession.ID,
184			AgentMessageID: "msg-1",
185			ToolCallID:     "call-1",
186			Prompt:         "test",
187			SessionTitle:   "Test",
188		})
189		require.Error(t, err)
190		assert.Contains(t, err.Error(), "model provider not configured")
191	})
192
193	t.Run("agent run error returns error response", func(t *testing.T) {
194		env := testEnv(t)
195		coord := newTestCoordinator(t, env, providerID, providerCfg)
196
197		parentSession, err := env.sessions.Create(t.Context(), "Parent")
198		require.NoError(t, err)
199
200		agent := newMockAgent(providerID, 4096, func(_ context.Context, _ SessionAgentCall) (*fantasy.AgentResult, error) {
201			return nil, errors.New("provider request failed")
202		})
203
204		resp, err := coord.runSubAgent(t.Context(), subAgentParams{
205			Agent:          agent,
206			SessionID:      parentSession.ID,
207			AgentMessageID: "msg-1",
208			ToolCallID:     "call-1",
209			Prompt:         "test",
210			SessionTitle:   "Test",
211		})
212		// runSubAgent returns (errorResponse, nil) when agent.Run fails — not a Go error.
213		require.NoError(t, err)
214		assert.True(t, resp.IsError)
215		assert.Equal(t, "Failed to generate response: provider request failed", resp.Content)
216	})
217
218	t.Run("session setup callback is invoked", func(t *testing.T) {
219		env := testEnv(t)
220		coord := newTestCoordinator(t, env, providerID, providerCfg)
221
222		parentSession, err := env.sessions.Create(t.Context(), "Parent")
223		require.NoError(t, err)
224
225		var setupCalledWith string
226		agent := newMockAgent(providerID, 4096, func(_ context.Context, _ SessionAgentCall) (*fantasy.AgentResult, error) {
227			return agentResultWithText("ok"), nil
228		})
229
230		_, err = coord.runSubAgent(t.Context(), subAgentParams{
231			Agent:          agent,
232			SessionID:      parentSession.ID,
233			AgentMessageID: "msg-1",
234			ToolCallID:     "call-1",
235			Prompt:         "test",
236			SessionTitle:   "Test",
237			SessionSetup: func(sessionID string) {
238				setupCalledWith = sessionID
239			},
240		})
241		require.NoError(t, err)
242		assert.NotEmpty(t, setupCalledWith, "SessionSetup should have been called")
243	})
244
245	t.Run("cost propagation to parent session", func(t *testing.T) {
246		env := testEnv(t)
247		coord := newTestCoordinator(t, env, providerID, providerCfg)
248
249		parentSession, err := env.sessions.Create(t.Context(), "Parent")
250		require.NoError(t, err)
251
252		agent := newMockAgent(providerID, 4096, func(ctx context.Context, call SessionAgentCall) (*fantasy.AgentResult, error) {
253			// Simulate the agent incurring cost by updating the child session.
254			childSession, err := env.sessions.Get(ctx, call.SessionID)
255			if err != nil {
256				return nil, err
257			}
258			childSession.Cost = 0.05
259			_, err = env.sessions.Save(ctx, childSession)
260			if err != nil {
261				return nil, err
262			}
263			return agentResultWithText("ok"), nil
264		})
265
266		_, err = coord.runSubAgent(t.Context(), subAgentParams{
267			Agent:          agent,
268			SessionID:      parentSession.ID,
269			AgentMessageID: "msg-1",
270			ToolCallID:     "call-1",
271			Prompt:         "test",
272			SessionTitle:   "Test",
273		})
274		require.NoError(t, err)
275
276		updated, err := env.sessions.Get(t.Context(), parentSession.ID)
277		require.NoError(t, err)
278		assert.InDelta(t, 0.05, updated.Cost, 1e-9)
279	})
280}
281
282func TestUpdateParentSessionCost(t *testing.T) {
283	t.Run("accumulates cost correctly", func(t *testing.T) {
284		env := testEnv(t)
285		cfg, err := config.Init(env.workingDir, "", false)
286		require.NoError(t, err)
287		coord := &coordinator{cfg: cfg, sessions: env.sessions}
288
289		parent, err := env.sessions.Create(t.Context(), "Parent")
290		require.NoError(t, err)
291
292		child, err := env.sessions.CreateTaskSession(t.Context(), "tool-1", parent.ID, "Child")
293		require.NoError(t, err)
294
295		// Set child cost.
296		child.Cost = 0.10
297		_, err = env.sessions.Save(t.Context(), child)
298		require.NoError(t, err)
299
300		err = coord.updateParentSessionCost(t.Context(), child.ID, parent.ID)
301		require.NoError(t, err)
302
303		updated, err := env.sessions.Get(t.Context(), parent.ID)
304		require.NoError(t, err)
305		assert.InDelta(t, 0.10, updated.Cost, 1e-9)
306	})
307
308	t.Run("accumulates multiple child costs", func(t *testing.T) {
309		env := testEnv(t)
310		cfg, err := config.Init(env.workingDir, "", false)
311		require.NoError(t, err)
312		coord := &coordinator{cfg: cfg, sessions: env.sessions}
313
314		parent, err := env.sessions.Create(t.Context(), "Parent")
315		require.NoError(t, err)
316
317		child1, err := env.sessions.CreateTaskSession(t.Context(), "tool-1", parent.ID, "Child1")
318		require.NoError(t, err)
319		child1.Cost = 0.05
320		_, err = env.sessions.Save(t.Context(), child1)
321		require.NoError(t, err)
322
323		child2, err := env.sessions.CreateTaskSession(t.Context(), "tool-2", parent.ID, "Child2")
324		require.NoError(t, err)
325		child2.Cost = 0.03
326		_, err = env.sessions.Save(t.Context(), child2)
327		require.NoError(t, err)
328
329		err = coord.updateParentSessionCost(t.Context(), child1.ID, parent.ID)
330		require.NoError(t, err)
331		err = coord.updateParentSessionCost(t.Context(), child2.ID, parent.ID)
332		require.NoError(t, err)
333
334		updated, err := env.sessions.Get(t.Context(), parent.ID)
335		require.NoError(t, err)
336		assert.InDelta(t, 0.08, updated.Cost, 1e-9)
337	})
338
339	t.Run("child session not found", func(t *testing.T) {
340		env := testEnv(t)
341		cfg, err := config.Init(env.workingDir, "", false)
342		require.NoError(t, err)
343		coord := &coordinator{cfg: cfg, sessions: env.sessions}
344
345		parent, err := env.sessions.Create(t.Context(), "Parent")
346		require.NoError(t, err)
347
348		err = coord.updateParentSessionCost(t.Context(), "non-existent", parent.ID)
349		require.Error(t, err)
350		assert.Contains(t, err.Error(), "get child session")
351	})
352
353	t.Run("parent session not found", func(t *testing.T) {
354		env := testEnv(t)
355		cfg, err := config.Init(env.workingDir, "", false)
356		require.NoError(t, err)
357		coord := &coordinator{cfg: cfg, sessions: env.sessions}
358
359		parent, err := env.sessions.Create(t.Context(), "Parent")
360		require.NoError(t, err)
361		child, err := env.sessions.CreateTaskSession(t.Context(), "tool-1", parent.ID, "Child")
362		require.NoError(t, err)
363
364		err = coord.updateParentSessionCost(t.Context(), child.ID, "non-existent")
365		require.Error(t, err)
366		assert.Contains(t, err.Error(), "get parent session")
367	})
368
369	t.Run("zero cost handled correctly", func(t *testing.T) {
370		env := testEnv(t)
371		cfg, err := config.Init(env.workingDir, "", false)
372		require.NoError(t, err)
373		coord := &coordinator{cfg: cfg, sessions: env.sessions}
374
375		parent, err := env.sessions.Create(t.Context(), "Parent")
376		require.NoError(t, err)
377		child, err := env.sessions.CreateTaskSession(t.Context(), "tool-1", parent.ID, "Child")
378		require.NoError(t, err)
379
380		err = coord.updateParentSessionCost(t.Context(), child.ID, parent.ID)
381		require.NoError(t, err)
382
383		updated, err := env.sessions.Get(t.Context(), parent.ID)
384		require.NoError(t, err)
385		assert.InDelta(t, 0.0, updated.Cost, 1e-9)
386	})
387}
388
389func TestGetProviderOptionsReasoningEffort(t *testing.T) {
390	// Bedrock is Fantasy's Anthropic under a different provider name; options
391	// must land under anthropic.Name so the Anthropic language model picks them up.
392	tests := []struct {
393		name         string
394		providerType catwalk.Type
395	}{
396		{"anthropic honors reasoning_effort", catwalk.Type(anthropic.Name)},
397		{"bedrock honors reasoning_effort", catwalk.Type(bedrock.Name)},
398	}
399	for _, tc := range tests {
400		t.Run(tc.name, func(t *testing.T) {
401			model := Model{
402				CatwalkCfg: catwalk.Model{
403					ID:              "claude-opus-4-7",
404					CanReason:       true,
405					ReasoningLevels: []string{"max"},
406				},
407				ModelCfg: config.SelectedModel{
408					Provider:        "test",
409					ReasoningEffort: "max",
410				},
411			}
412			providerCfg := config.ProviderConfig{ID: "test", Type: tc.providerType}
413
414			opts := getProviderOptions(model, providerCfg)
415
416			raw, ok := opts[anthropic.Name]
417			require.True(t, ok, "options should be keyed under anthropic.Name for type %q", tc.providerType)
418			parsed, ok := raw.(*anthropic.ProviderOptions)
419			require.True(t, ok)
420			require.NotNil(t, parsed.Effort)
421			assert.Equal(t, anthropic.Effort("max"), *parsed.Effort)
422		})
423	}
424}