coordinator_test.go

  1package agent
  2
  3import (
  4	"context"
  5	"errors"
  6	"testing"
  7
  8	"charm.land/catwalk/pkg/catwalk"
  9	"charm.land/fantasy"
 10	"charm.land/fantasy/providers/anthropic"
 11	"charm.land/fantasy/providers/bedrock"
 12	"github.com/charmbracelet/crush/internal/config"
 13	"github.com/stretchr/testify/assert"
 14	"github.com/stretchr/testify/require"
 15)
 16
 17// mockSessionAgent is a minimal mock for the SessionAgent interface.
 18type mockSessionAgent struct {
 19	model     Model
 20	runFunc   func(ctx context.Context, call SessionAgentCall) (*fantasy.AgentResult, error)
 21	cancelled []string
 22}
 23
 24func (m *mockSessionAgent) Run(ctx context.Context, call SessionAgentCall) (*fantasy.AgentResult, error) {
 25	return m.runFunc(ctx, call)
 26}
 27
 28func (m *mockSessionAgent) BeginAccepted(sessionID string) *AcceptedRun {
 29	return &AcceptedRun{sessionID: sessionID}
 30}
 31
 32func (m *mockSessionAgent) Model() Model                        { return m.model }
 33func (m *mockSessionAgent) SetModels(large, small Model)        {}
 34func (m *mockSessionAgent) SetTools(tools []fantasy.AgentTool)  {}
 35func (m *mockSessionAgent) SetSystemPrompt(systemPrompt string) {}
 36func (m *mockSessionAgent) Cancel(sessionID string) {
 37	m.cancelled = append(m.cancelled, sessionID)
 38}
 39func (m *mockSessionAgent) CancelAll()                                  {}
 40func (m *mockSessionAgent) IsSessionBusy(sessionID string) bool         { return false }
 41func (m *mockSessionAgent) IsBusy() bool                                { return false }
 42func (m *mockSessionAgent) QueuedPrompts(sessionID string) int          { return 0 }
 43func (m *mockSessionAgent) QueuedPromptsList(sessionID string) []string { return nil }
 44func (m *mockSessionAgent) ClearQueue(sessionID string)                 {}
 45func (m *mockSessionAgent) Summarize(context.Context, string, fantasy.ProviderOptions) error {
 46	return nil
 47}
 48
 49// newTestCoordinator creates a minimal coordinator for unit testing runSubAgent.
 50func newTestCoordinator(t *testing.T, env fakeEnv, providerID string, providerCfg config.ProviderConfig) *coordinator {
 51	cfg, err := config.Init(env.workingDir, "", false)
 52	require.NoError(t, err)
 53	cfg.Config().Providers.Set(providerID, providerCfg)
 54	return &coordinator{
 55		cfg:      cfg,
 56		sessions: env.sessions,
 57	}
 58}
 59
 60// newMockAgent creates a mockSessionAgent with the given provider and run function.
 61func newMockAgent(providerID string, maxTokens int64, runFunc func(context.Context, SessionAgentCall) (*fantasy.AgentResult, error)) *mockSessionAgent {
 62	return &mockSessionAgent{
 63		model: Model{
 64			CatwalkCfg: catwalk.Model{
 65				DefaultMaxTokens: maxTokens,
 66			},
 67			ModelCfg: config.SelectedModel{
 68				Provider: providerID,
 69			},
 70		},
 71		runFunc: runFunc,
 72	}
 73}
 74
 75// agentResultWithText creates a minimal AgentResult with the given text response.
 76func agentResultWithText(text string) *fantasy.AgentResult {
 77	return &fantasy.AgentResult{
 78		Response: fantasy.Response{
 79			Content: fantasy.ResponseContent{
 80				fantasy.TextContent{Text: text},
 81			},
 82		},
 83	}
 84}
 85
 86func TestRunSubAgent(t *testing.T) {
 87	const providerID = "test-provider"
 88	providerCfg := config.ProviderConfig{ID: providerID}
 89
 90	t.Run("happy path", func(t *testing.T) {
 91		env := testEnv(t)
 92		coord := newTestCoordinator(t, env, providerID, providerCfg)
 93
 94		parentSession, err := env.sessions.Create(t.Context(), "Parent")
 95		require.NoError(t, err)
 96
 97		agent := newMockAgent(providerID, 4096, func(_ context.Context, call SessionAgentCall) (*fantasy.AgentResult, error) {
 98			assert.Equal(t, "do something", call.Prompt)
 99			assert.Equal(t, int64(4096), call.MaxOutputTokens)
100			return agentResultWithText("done"), nil
101		})
102
103		resp, err := coord.runSubAgent(t.Context(), subAgentParams{
104			Agent:          agent,
105			SessionID:      parentSession.ID,
106			AgentMessageID: "msg-1",
107			ToolCallID:     "call-1",
108			Prompt:         "do something",
109			SessionTitle:   "Test Session",
110		})
111		require.NoError(t, err)
112		assert.Equal(t, "done", resp.Content)
113		assert.False(t, resp.IsError)
114	})
115
116	t.Run("ModelCfg.MaxTokens overrides default", func(t *testing.T) {
117		env := testEnv(t)
118		coord := newTestCoordinator(t, env, providerID, providerCfg)
119
120		parentSession, err := env.sessions.Create(t.Context(), "Parent")
121		require.NoError(t, err)
122
123		agent := &mockSessionAgent{
124			model: Model{
125				CatwalkCfg: catwalk.Model{
126					DefaultMaxTokens: 4096,
127				},
128				ModelCfg: config.SelectedModel{
129					Provider:  providerID,
130					MaxTokens: 8192,
131				},
132			},
133			runFunc: func(_ context.Context, call SessionAgentCall) (*fantasy.AgentResult, error) {
134				assert.Equal(t, int64(8192), call.MaxOutputTokens)
135				return agentResultWithText("ok"), nil
136			},
137		}
138
139		resp, err := coord.runSubAgent(t.Context(), subAgentParams{
140			Agent:          agent,
141			SessionID:      parentSession.ID,
142			AgentMessageID: "msg-1",
143			ToolCallID:     "call-1",
144			Prompt:         "test",
145			SessionTitle:   "Test",
146		})
147		require.NoError(t, err)
148		assert.Equal(t, "ok", resp.Content)
149	})
150
151	t.Run("session creation failure with canceled context", func(t *testing.T) {
152		env := testEnv(t)
153		coord := newTestCoordinator(t, env, providerID, providerCfg)
154
155		parentSession, err := env.sessions.Create(t.Context(), "Parent")
156		require.NoError(t, err)
157
158		agent := newMockAgent(providerID, 4096, nil)
159
160		// Use a canceled context to trigger CreateTaskSession failure.
161		ctx, cancel := context.WithCancel(t.Context())
162		cancel()
163
164		_, err = coord.runSubAgent(ctx, subAgentParams{
165			Agent:          agent,
166			SessionID:      parentSession.ID,
167			AgentMessageID: "msg-1",
168			ToolCallID:     "call-1",
169			Prompt:         "test",
170			SessionTitle:   "Test",
171		})
172		require.Error(t, err)
173	})
174
175	t.Run("provider not configured", func(t *testing.T) {
176		env := testEnv(t)
177		coord := newTestCoordinator(t, env, providerID, providerCfg)
178
179		parentSession, err := env.sessions.Create(t.Context(), "Parent")
180		require.NoError(t, err)
181
182		// Agent references a provider that doesn't exist in config.
183		agent := newMockAgent("unknown-provider", 4096, nil)
184
185		_, err = coord.runSubAgent(t.Context(), subAgentParams{
186			Agent:          agent,
187			SessionID:      parentSession.ID,
188			AgentMessageID: "msg-1",
189			ToolCallID:     "call-1",
190			Prompt:         "test",
191			SessionTitle:   "Test",
192		})
193		require.Error(t, err)
194		assert.Contains(t, err.Error(), "model provider not configured")
195	})
196
197	t.Run("agent run error returns error response", func(t *testing.T) {
198		env := testEnv(t)
199		coord := newTestCoordinator(t, env, providerID, providerCfg)
200
201		parentSession, err := env.sessions.Create(t.Context(), "Parent")
202		require.NoError(t, err)
203
204		agent := newMockAgent(providerID, 4096, func(_ context.Context, _ SessionAgentCall) (*fantasy.AgentResult, error) {
205			return nil, errors.New("provider request failed")
206		})
207
208		resp, err := coord.runSubAgent(t.Context(), subAgentParams{
209			Agent:          agent,
210			SessionID:      parentSession.ID,
211			AgentMessageID: "msg-1",
212			ToolCallID:     "call-1",
213			Prompt:         "test",
214			SessionTitle:   "Test",
215		})
216		// runSubAgent returns (errorResponse, nil) when agent.Run fails — not a Go error.
217		require.NoError(t, err)
218		assert.True(t, resp.IsError)
219		assert.Equal(t, "Failed to generate response: provider request failed", resp.Content)
220	})
221
222	t.Run("session setup callback is invoked", func(t *testing.T) {
223		env := testEnv(t)
224		coord := newTestCoordinator(t, env, providerID, providerCfg)
225
226		parentSession, err := env.sessions.Create(t.Context(), "Parent")
227		require.NoError(t, err)
228
229		var setupCalledWith string
230		agent := newMockAgent(providerID, 4096, func(_ context.Context, _ SessionAgentCall) (*fantasy.AgentResult, error) {
231			return agentResultWithText("ok"), nil
232		})
233
234		_, err = coord.runSubAgent(t.Context(), subAgentParams{
235			Agent:          agent,
236			SessionID:      parentSession.ID,
237			AgentMessageID: "msg-1",
238			ToolCallID:     "call-1",
239			Prompt:         "test",
240			SessionTitle:   "Test",
241			SessionSetup: func(sessionID string) {
242				setupCalledWith = sessionID
243			},
244		})
245		require.NoError(t, err)
246		assert.NotEmpty(t, setupCalledWith, "SessionSetup should have been called")
247	})
248
249	t.Run("cost propagation to parent session", func(t *testing.T) {
250		env := testEnv(t)
251		coord := newTestCoordinator(t, env, providerID, providerCfg)
252
253		parentSession, err := env.sessions.Create(t.Context(), "Parent")
254		require.NoError(t, err)
255
256		agent := newMockAgent(providerID, 4096, func(ctx context.Context, call SessionAgentCall) (*fantasy.AgentResult, error) {
257			// Simulate the agent incurring cost by updating the child session.
258			childSession, err := env.sessions.Get(ctx, call.SessionID)
259			if err != nil {
260				return nil, err
261			}
262			childSession.Cost = 0.05
263			_, err = env.sessions.Save(ctx, childSession)
264			if err != nil {
265				return nil, err
266			}
267			return agentResultWithText("ok"), nil
268		})
269
270		_, err = coord.runSubAgent(t.Context(), subAgentParams{
271			Agent:          agent,
272			SessionID:      parentSession.ID,
273			AgentMessageID: "msg-1",
274			ToolCallID:     "call-1",
275			Prompt:         "test",
276			SessionTitle:   "Test",
277		})
278		require.NoError(t, err)
279
280		updated, err := env.sessions.Get(t.Context(), parentSession.ID)
281		require.NoError(t, err)
282		assert.InDelta(t, 0.05, updated.Cost, 1e-9)
283	})
284}
285
286func TestUpdateParentSessionCost(t *testing.T) {
287	t.Run("accumulates cost correctly", func(t *testing.T) {
288		env := testEnv(t)
289		cfg, err := config.Init(env.workingDir, "", false)
290		require.NoError(t, err)
291		coord := &coordinator{cfg: cfg, sessions: env.sessions}
292
293		parent, err := env.sessions.Create(t.Context(), "Parent")
294		require.NoError(t, err)
295
296		child, err := env.sessions.CreateTaskSession(t.Context(), "tool-1", parent.ID, "Child")
297		require.NoError(t, err)
298
299		// Set child cost.
300		child.Cost = 0.10
301		_, err = env.sessions.Save(t.Context(), child)
302		require.NoError(t, err)
303
304		err = coord.updateParentSessionCost(t.Context(), child.ID, parent.ID)
305		require.NoError(t, err)
306
307		updated, err := env.sessions.Get(t.Context(), parent.ID)
308		require.NoError(t, err)
309		assert.InDelta(t, 0.10, updated.Cost, 1e-9)
310	})
311
312	t.Run("accumulates multiple child costs", func(t *testing.T) {
313		env := testEnv(t)
314		cfg, err := config.Init(env.workingDir, "", false)
315		require.NoError(t, err)
316		coord := &coordinator{cfg: cfg, sessions: env.sessions}
317
318		parent, err := env.sessions.Create(t.Context(), "Parent")
319		require.NoError(t, err)
320
321		child1, err := env.sessions.CreateTaskSession(t.Context(), "tool-1", parent.ID, "Child1")
322		require.NoError(t, err)
323		child1.Cost = 0.05
324		_, err = env.sessions.Save(t.Context(), child1)
325		require.NoError(t, err)
326
327		child2, err := env.sessions.CreateTaskSession(t.Context(), "tool-2", parent.ID, "Child2")
328		require.NoError(t, err)
329		child2.Cost = 0.03
330		_, err = env.sessions.Save(t.Context(), child2)
331		require.NoError(t, err)
332
333		err = coord.updateParentSessionCost(t.Context(), child1.ID, parent.ID)
334		require.NoError(t, err)
335		err = coord.updateParentSessionCost(t.Context(), child2.ID, parent.ID)
336		require.NoError(t, err)
337
338		updated, err := env.sessions.Get(t.Context(), parent.ID)
339		require.NoError(t, err)
340		assert.InDelta(t, 0.08, updated.Cost, 1e-9)
341	})
342
343	t.Run("child session not found", func(t *testing.T) {
344		env := testEnv(t)
345		cfg, err := config.Init(env.workingDir, "", false)
346		require.NoError(t, err)
347		coord := &coordinator{cfg: cfg, sessions: env.sessions}
348
349		parent, err := env.sessions.Create(t.Context(), "Parent")
350		require.NoError(t, err)
351
352		err = coord.updateParentSessionCost(t.Context(), "non-existent", parent.ID)
353		require.Error(t, err)
354		assert.Contains(t, err.Error(), "get child session")
355	})
356
357	t.Run("parent session not found", func(t *testing.T) {
358		env := testEnv(t)
359		cfg, err := config.Init(env.workingDir, "", false)
360		require.NoError(t, err)
361		coord := &coordinator{cfg: cfg, sessions: env.sessions}
362
363		parent, err := env.sessions.Create(t.Context(), "Parent")
364		require.NoError(t, err)
365		child, err := env.sessions.CreateTaskSession(t.Context(), "tool-1", parent.ID, "Child")
366		require.NoError(t, err)
367
368		err = coord.updateParentSessionCost(t.Context(), child.ID, "non-existent")
369		require.Error(t, err)
370		assert.Contains(t, err.Error(), "get parent session")
371	})
372
373	t.Run("zero cost handled correctly", func(t *testing.T) {
374		env := testEnv(t)
375		cfg, err := config.Init(env.workingDir, "", false)
376		require.NoError(t, err)
377		coord := &coordinator{cfg: cfg, sessions: env.sessions}
378
379		parent, err := env.sessions.Create(t.Context(), "Parent")
380		require.NoError(t, err)
381		child, err := env.sessions.CreateTaskSession(t.Context(), "tool-1", parent.ID, "Child")
382		require.NoError(t, err)
383
384		err = coord.updateParentSessionCost(t.Context(), child.ID, parent.ID)
385		require.NoError(t, err)
386
387		updated, err := env.sessions.Get(t.Context(), parent.ID)
388		require.NoError(t, err)
389		assert.InDelta(t, 0.0, updated.Cost, 1e-9)
390	})
391}
392
393func TestGetProviderOptionsReasoningEffort(t *testing.T) {
394	// Bedrock is Fantasy's Anthropic under a different provider name; options
395	// must land under anthropic.Name so the Anthropic language model picks them up.
396	tests := []struct {
397		name         string
398		providerType catwalk.Type
399	}{
400		{"anthropic honors reasoning_effort", catwalk.Type(anthropic.Name)},
401		{"bedrock honors reasoning_effort", catwalk.Type(bedrock.Name)},
402	}
403	for _, tc := range tests {
404		t.Run(tc.name, func(t *testing.T) {
405			model := Model{
406				CatwalkCfg: catwalk.Model{
407					ID:              "claude-opus-4-7",
408					CanReason:       true,
409					ReasoningLevels: []string{"max"},
410				},
411				ModelCfg: config.SelectedModel{
412					Provider:        "test",
413					ReasoningEffort: "max",
414				},
415			}
416			providerCfg := config.ProviderConfig{ID: "test", Type: tc.providerType}
417
418			opts := getProviderOptions(model, providerCfg)
419
420			raw, ok := opts[anthropic.Name]
421			require.True(t, ok, "options should be keyed under anthropic.Name for type %q", tc.providerType)
422			parsed, ok := raw.(*anthropic.ProviderOptions)
423			require.True(t, ok)
424			require.NotNil(t, parsed.Effort)
425			assert.Equal(t, anthropic.Effort("max"), *parsed.Effort)
426		})
427	}
428}