coordinator_test.go

  1package agent
  2
  3import (
  4	"context"
  5	"testing"
  6
  7	"charm.land/fantasy"
  8	"github.com/charmbracelet/crush/internal/message"
  9	"github.com/stretchr/testify/require"
 10)
 11
 12func TestRecoverSession(t *testing.T) {
 13	t.Run("no messages", func(t *testing.T) {
 14		env := testEnv(t)
 15
 16		sess, err := env.sessions.Create(t.Context(), "Test Session")
 17		require.NoError(t, err)
 18
 19		// Create coordinator with mock services
 20		coordinator := &coordinator{
 21			sessions: env.sessions,
 22			messages: env.messages,
 23		}
 24
 25		err = coordinator.RecoverSession(t.Context(), sess.ID)
 26		require.NoError(t, err)
 27
 28		// Verify no messages were modified
 29		msgs, err := env.messages.List(t.Context(), sess.ID)
 30		require.NoError(t, err)
 31		require.Empty(t, msgs)
 32	})
 33
 34	t.Run("already finished messages", func(t *testing.T) {
 35		env := testEnv(t)
 36
 37		sess, err := env.sessions.Create(t.Context(), "Test Session")
 38		require.NoError(t, err)
 39
 40		// Create a finished assistant message (with Finish part)
 41		_, err = env.messages.Create(t.Context(), sess.ID, message.CreateMessageParams{
 42			Role:  message.Assistant,
 43			Parts: []message.ContentPart{message.TextContent{Text: "Hello!"}, message.Finish{Reason: message.FinishReasonEndTurn}},
 44		})
 45		require.NoError(t, err)
 46
 47		// Create coordinator with mock services
 48		coordinator := &coordinator{
 49			sessions: env.sessions,
 50			messages: env.messages,
 51		}
 52
 53		err = coordinator.RecoverSession(t.Context(), sess.ID)
 54		require.NoError(t, err)
 55
 56		// Verify the message was not modified
 57		msgs, err := env.messages.List(t.Context(), sess.ID)
 58		require.NoError(t, err)
 59		require.Len(t, msgs, 1)
 60		require.True(t, msgs[0].IsFinished())
 61	})
 62
 63	t.Run("incomplete summary message", func(t *testing.T) {
 64		env := testEnv(t)
 65
 66		sess, err := env.sessions.Create(t.Context(), "Test Session")
 67		require.NoError(t, err)
 68
 69		// Create an incomplete summary message (simulating a crash during summarization)
 70		summaryMsg, err := env.messages.Create(t.Context(), sess.ID, message.CreateMessageParams{
 71			Role:             message.Assistant,
 72			Parts:            []message.ContentPart{message.TextContent{Text: "Partial summary..."}},
 73			Model:            "test-model",
 74			Provider:         "test-provider",
 75			IsSummaryMessage: true,
 76		})
 77		require.NoError(t, err)
 78
 79		// Verify the message is not finished
 80		require.False(t, summaryMsg.IsFinished())
 81
 82		// Create coordinator with mock services
 83		coordinator := &coordinator{
 84			sessions: env.sessions,
 85			messages: env.messages,
 86		}
 87
 88		err = coordinator.RecoverSession(t.Context(), sess.ID)
 89		require.NoError(t, err)
 90
 91		// Verify the summary message was recovered
 92		recoveredMsg, err := env.messages.Get(t.Context(), summaryMsg.ID)
 93		require.NoError(t, err)
 94		require.True(t, recoveredMsg.IsFinished())
 95		require.Equal(t, message.FinishReasonError, recoveredMsg.FinishReason())
 96		require.Contains(t, recoveredMsg.FinishPart().Message, "Session interrupted")
 97	})
 98
 99	t.Run("incomplete assistant message with tool calls", func(t *testing.T) {
100		env := testEnv(t)
101
102		sess, err := env.sessions.Create(t.Context(), "Test Session")
103		require.NoError(t, err)
104
105		// Create an incomplete assistant message with tool calls
106		// (simulating a crash during tool execution)
107		toolCall := message.ToolCall{
108			ID:               "tc-1",
109			Name:             "bash",
110			Input:            `echo "hello"`,
111			ProviderExecuted: false,
112			Finished:         false,
113		}
114
115		assistantMsg, err := env.messages.Create(t.Context(), sess.ID, message.CreateMessageParams{
116			Role:  message.Assistant,
117			Parts: []message.ContentPart{message.ToolCall(toolCall)},
118			Model: "test-model",
119		})
120		require.NoError(t, err)
121
122		// Verify the message is not finished
123		require.False(t, assistantMsg.IsFinished())
124
125		// Create coordinator with mock services
126		coordinator := &coordinator{
127			sessions: env.sessions,
128			messages: env.messages,
129		}
130
131		err = coordinator.RecoverSession(t.Context(), sess.ID)
132		require.NoError(t, err)
133
134		// Verify the assistant message was recovered
135		recoveredMsg, err := env.messages.Get(t.Context(), assistantMsg.ID)
136		require.NoError(t, err)
137		require.True(t, recoveredMsg.IsFinished())
138		require.Equal(t, message.FinishReasonError, recoveredMsg.FinishReason())
139		require.Contains(t, recoveredMsg.FinishPart().Message, "Session interrupted")
140
141		// Verify the tool call was marked as finished
142		toolCalls := recoveredMsg.ToolCalls()
143		require.Len(t, toolCalls, 1)
144		require.True(t, toolCalls[0].Finished)
145	})
146
147	t.Run("incomplete assistant message without tool calls", func(t *testing.T) {
148		env := testEnv(t)
149
150		sess, err := env.sessions.Create(t.Context(), "Test Session")
151		require.NoError(t, err)
152
153		// Create an incomplete assistant message with partial content but no tool calls
154		assistantMsg, err := env.messages.Create(t.Context(), sess.ID, message.CreateMessageParams{
155			Role:  message.Assistant,
156			Parts: []message.ContentPart{message.TextContent{Text: "This is a partial response..."}},
157			Model: "test-model",
158		})
159		require.NoError(t, err)
160
161		// Verify the message is not finished
162		require.False(t, assistantMsg.IsFinished())
163
164		// Create coordinator with mock services
165		coordinator := &coordinator{
166			sessions: env.sessions,
167			messages: env.messages,
168		}
169
170		err = coordinator.RecoverSession(t.Context(), sess.ID)
171		require.NoError(t, err)
172
173		// Verify the assistant message was recovered
174		recoveredMsg, err := env.messages.Get(t.Context(), assistantMsg.ID)
175		require.NoError(t, err)
176		require.True(t, recoveredMsg.IsFinished())
177		require.Equal(t, message.FinishReasonError, recoveredMsg.FinishReason())
178		require.Contains(t, recoveredMsg.FinishPart().Message, "Session interrupted")
179		require.Equal(t, "This is a partial response...", recoveredMsg.Content().Text)
180	})
181
182	t.Run("session is busy - skips recovery", func(t *testing.T) {
183		env := testEnv(t)
184
185		sess, err := env.sessions.Create(t.Context(), "Test Session")
186		require.NoError(t, err)
187
188		// Create a dummy agent that reports as busy
189		agent := &dummyAgent{t: t, isBusy: true}
190
191		coordinator := &coordinator{
192			sessions:     env.sessions,
193			messages:     env.messages,
194			currentAgent: agent,
195		}
196
197		// Create an incomplete assistant message
198		_, err = env.messages.Create(t.Context(), sess.ID, message.CreateMessageParams{
199			Role:  message.Assistant,
200			Parts: []message.ContentPart{message.TextContent{Text: "Partial..."}},
201			Model: "test-model",
202		})
203		require.NoError(t, err)
204
205		err = coordinator.RecoverSession(t.Context(), sess.ID)
206		require.NoError(t, err)
207
208		// Message should NOT be recovered since session is "busy"
209		msgs, err := env.messages.List(t.Context(), sess.ID)
210		require.NoError(t, err)
211		require.Len(t, msgs, 1)
212		require.False(t, msgs[0].IsFinished(), "message should not be finished when session is busy")
213	})
214
215	t.Run("multiple incomplete messages", func(t *testing.T) {
216		env := testEnv(t)
217
218		sess, err := env.sessions.Create(t.Context(), "Test Session")
219		require.NoError(t, err)
220
221		// Create an incomplete summary message
222		_, err = env.messages.Create(t.Context(), sess.ID, message.CreateMessageParams{
223			Role:             message.Assistant,
224			Parts:            []message.ContentPart{message.TextContent{Text: "Partial summary..."}},
225			IsSummaryMessage: true,
226		})
227		require.NoError(t, err)
228
229		// Create an incomplete assistant message with tool calls
230		toolCall := message.ToolCall{
231			ID:               "tc-1",
232			Name:             "bash",
233			Input:            `echo "hello"`,
234			ProviderExecuted: false,
235			Finished:         false,
236		}
237		_, err = env.messages.Create(t.Context(), sess.ID, message.CreateMessageParams{
238			Role:  message.Assistant,
239			Parts: []message.ContentPart{message.ToolCall(toolCall)},
240		})
241		require.NoError(t, err)
242
243		coordinator := &coordinator{
244			sessions: env.sessions,
245			messages: env.messages,
246		}
247
248		err = coordinator.RecoverSession(t.Context(), sess.ID)
249		require.NoError(t, err)
250
251		// Verify both messages were recovered
252		msgs, err := env.messages.List(t.Context(), sess.ID)
253		require.NoError(t, err)
254		require.Len(t, msgs, 2)
255
256		for _, msg := range msgs {
257			require.True(t, msg.IsFinished(), "message %s should be finished", msg.ID)
258		}
259	})
260
261	t.Run("mixed finished and unfinished messages", func(t *testing.T) {
262		env := testEnv(t)
263
264		sess, err := env.sessions.Create(t.Context(), "Test Session")
265		require.NoError(t, err)
266
267		// Create a finished user message (Finish part is added automatically)
268		_, err = env.messages.Create(t.Context(), sess.ID, message.CreateMessageParams{
269			Role:  message.User,
270			Parts: []message.ContentPart{message.TextContent{Text: "Hello!"}},
271		})
272		require.NoError(t, err)
273
274		// Create a finished assistant message (with Finish part)
275		_, err = env.messages.Create(t.Context(), sess.ID, message.CreateMessageParams{
276			Role:  message.Assistant,
277			Parts: []message.ContentPart{message.TextContent{Text: "Hi there!"}, message.Finish{Reason: message.FinishReasonEndTurn}},
278		})
279		require.NoError(t, err)
280
281		// Create an incomplete assistant message with tool calls
282		toolCall := message.ToolCall{
283			ID:               "tc-1",
284			Name:             "bash",
285			Input:            `echo "hello"`,
286			ProviderExecuted: false,
287			Finished:         false,
288		}
289		_, err = env.messages.Create(t.Context(), sess.ID, message.CreateMessageParams{
290			Role:  message.Assistant,
291			Parts: []message.ContentPart{message.ToolCall(toolCall)},
292		})
293		require.NoError(t, err)
294
295		coordinator := &coordinator{
296			sessions: env.sessions,
297			messages: env.messages,
298		}
299
300		err = coordinator.RecoverSession(t.Context(), sess.ID)
301		require.NoError(t, err)
302
303		// Verify all messages are now correct
304		msgs, err := env.messages.List(t.Context(), sess.ID)
305		require.NoError(t, err)
306		require.Len(t, msgs, 3)
307
308		// User message should be finished (was already)
309		require.True(t, msgs[0].IsFinished())
310
311		// First assistant message should be finished (was already)
312		require.True(t, msgs[1].IsFinished())
313
314		// Second assistant message should now be finished
315		require.True(t, msgs[2].IsFinished())
316	})
317}
318
319// dummyAgent implements SessionAgent for testing purposes.
320type dummyAgent struct {
321	t      *testing.T
322	isBusy bool
323}
324
325func (a *dummyAgent) Run(ctx context.Context, call SessionAgentCall) (*fantasy.AgentResult, error) {
326	return nil, nil
327}
328
329func (a *dummyAgent) SetModels(large, small Model) {}
330
331func (a *dummyAgent) SetTools(tools []fantasy.AgentTool) {}
332
333func (a *dummyAgent) SetSystemPrompt(systemPrompt string) {}
334
335func (a *dummyAgent) Cancel(sessionID string) {}
336
337func (a *dummyAgent) CancelAll() {}
338
339func (a *dummyAgent) IsSessionBusy(sessionID string) bool {
340	return a.isBusy
341}
342
343func (a *dummyAgent) IsBusy() bool {
344	return a.isBusy
345}
346
347func (a *dummyAgent) QueuedPrompts(sessionID string) int {
348	return 0
349}
350
351func (a *dummyAgent) QueuedPromptsList(sessionID string) []string {
352	return nil
353}
354
355func (a *dummyAgent) ClearQueue(sessionID string) {}
356
357func (a *dummyAgent) Summarize(ctx context.Context, sessionID string, opts fantasy.ProviderOptions) error {
358	return nil
359}
360
361func (a *dummyAgent) Model() Model {
362	return Model{}
363}