1package agent
2
3import (
4 "context"
5 "testing"
6
7 "charm.land/fantasy"
8 "github.com/charmbracelet/crush/internal/message"
9 "github.com/stretchr/testify/require"
10)
11
12func TestRecoverSession(t *testing.T) {
13 t.Run("no messages", func(t *testing.T) {
14 env := testEnv(t)
15
16 sess, err := env.sessions.Create(t.Context(), "Test Session")
17 require.NoError(t, err)
18
19 // Create coordinator with mock services
20 coordinator := &coordinator{
21 sessions: env.sessions,
22 messages: env.messages,
23 }
24
25 err = coordinator.RecoverSession(t.Context(), sess.ID)
26 require.NoError(t, err)
27
28 // Verify no messages were modified
29 msgs, err := env.messages.List(t.Context(), sess.ID)
30 require.NoError(t, err)
31 require.Empty(t, msgs)
32 })
33
34 t.Run("already finished messages", func(t *testing.T) {
35 env := testEnv(t)
36
37 sess, err := env.sessions.Create(t.Context(), "Test Session")
38 require.NoError(t, err)
39
40 // Create a finished assistant message (with Finish part)
41 _, err = env.messages.Create(t.Context(), sess.ID, message.CreateMessageParams{
42 Role: message.Assistant,
43 Parts: []message.ContentPart{message.TextContent{Text: "Hello!"}, message.Finish{Reason: message.FinishReasonEndTurn}},
44 })
45 require.NoError(t, err)
46
47 // Create coordinator with mock services
48 coordinator := &coordinator{
49 sessions: env.sessions,
50 messages: env.messages,
51 }
52
53 err = coordinator.RecoverSession(t.Context(), sess.ID)
54 require.NoError(t, err)
55
56 // Verify the message was not modified
57 msgs, err := env.messages.List(t.Context(), sess.ID)
58 require.NoError(t, err)
59 require.Len(t, msgs, 1)
60 require.True(t, msgs[0].IsFinished())
61 })
62
63 t.Run("incomplete summary message", func(t *testing.T) {
64 env := testEnv(t)
65
66 sess, err := env.sessions.Create(t.Context(), "Test Session")
67 require.NoError(t, err)
68
69 // Create an incomplete summary message (simulating a crash during summarization)
70 summaryMsg, err := env.messages.Create(t.Context(), sess.ID, message.CreateMessageParams{
71 Role: message.Assistant,
72 Parts: []message.ContentPart{message.TextContent{Text: "Partial summary..."}},
73 Model: "test-model",
74 Provider: "test-provider",
75 IsSummaryMessage: true,
76 })
77 require.NoError(t, err)
78
79 // Verify the message is not finished
80 require.False(t, summaryMsg.IsFinished())
81
82 // Create coordinator with mock services
83 coordinator := &coordinator{
84 sessions: env.sessions,
85 messages: env.messages,
86 }
87
88 err = coordinator.RecoverSession(t.Context(), sess.ID)
89 require.NoError(t, err)
90
91 // Verify the summary message was recovered
92 recoveredMsg, err := env.messages.Get(t.Context(), summaryMsg.ID)
93 require.NoError(t, err)
94 require.True(t, recoveredMsg.IsFinished())
95 require.Equal(t, message.FinishReasonError, recoveredMsg.FinishReason())
96 require.Contains(t, recoveredMsg.FinishPart().Message, "Session interrupted")
97 })
98
99 t.Run("incomplete assistant message with tool calls", func(t *testing.T) {
100 env := testEnv(t)
101
102 sess, err := env.sessions.Create(t.Context(), "Test Session")
103 require.NoError(t, err)
104
105 // Create an incomplete assistant message with tool calls
106 // (simulating a crash during tool execution)
107 toolCall := message.ToolCall{
108 ID: "tc-1",
109 Name: "bash",
110 Input: `echo "hello"`,
111 ProviderExecuted: false,
112 Finished: false,
113 }
114
115 assistantMsg, err := env.messages.Create(t.Context(), sess.ID, message.CreateMessageParams{
116 Role: message.Assistant,
117 Parts: []message.ContentPart{message.ToolCall(toolCall)},
118 Model: "test-model",
119 })
120 require.NoError(t, err)
121
122 // Verify the message is not finished
123 require.False(t, assistantMsg.IsFinished())
124
125 // Create coordinator with mock services
126 coordinator := &coordinator{
127 sessions: env.sessions,
128 messages: env.messages,
129 }
130
131 err = coordinator.RecoverSession(t.Context(), sess.ID)
132 require.NoError(t, err)
133
134 // Verify the assistant message was recovered
135 recoveredMsg, err := env.messages.Get(t.Context(), assistantMsg.ID)
136 require.NoError(t, err)
137 require.True(t, recoveredMsg.IsFinished())
138 require.Equal(t, message.FinishReasonError, recoveredMsg.FinishReason())
139 require.Contains(t, recoveredMsg.FinishPart().Message, "Session interrupted")
140
141 // Verify the tool call was marked as finished
142 toolCalls := recoveredMsg.ToolCalls()
143 require.Len(t, toolCalls, 1)
144 require.True(t, toolCalls[0].Finished)
145 })
146
147 t.Run("incomplete assistant message without tool calls", func(t *testing.T) {
148 env := testEnv(t)
149
150 sess, err := env.sessions.Create(t.Context(), "Test Session")
151 require.NoError(t, err)
152
153 // Create an incomplete assistant message with partial content but no tool calls
154 assistantMsg, err := env.messages.Create(t.Context(), sess.ID, message.CreateMessageParams{
155 Role: message.Assistant,
156 Parts: []message.ContentPart{message.TextContent{Text: "This is a partial response..."}},
157 Model: "test-model",
158 })
159 require.NoError(t, err)
160
161 // Verify the message is not finished
162 require.False(t, assistantMsg.IsFinished())
163
164 // Create coordinator with mock services
165 coordinator := &coordinator{
166 sessions: env.sessions,
167 messages: env.messages,
168 }
169
170 err = coordinator.RecoverSession(t.Context(), sess.ID)
171 require.NoError(t, err)
172
173 // Verify the assistant message was recovered
174 recoveredMsg, err := env.messages.Get(t.Context(), assistantMsg.ID)
175 require.NoError(t, err)
176 require.True(t, recoveredMsg.IsFinished())
177 require.Equal(t, message.FinishReasonError, recoveredMsg.FinishReason())
178 require.Contains(t, recoveredMsg.FinishPart().Message, "Session interrupted")
179 require.Equal(t, "This is a partial response...", recoveredMsg.Content().Text)
180 })
181
182 t.Run("session is busy - skips recovery", func(t *testing.T) {
183 env := testEnv(t)
184
185 sess, err := env.sessions.Create(t.Context(), "Test Session")
186 require.NoError(t, err)
187
188 // Create a dummy agent that reports as busy
189 agent := &dummyAgent{t: t, isBusy: true}
190
191 coordinator := &coordinator{
192 sessions: env.sessions,
193 messages: env.messages,
194 currentAgent: agent,
195 }
196
197 // Create an incomplete assistant message
198 _, err = env.messages.Create(t.Context(), sess.ID, message.CreateMessageParams{
199 Role: message.Assistant,
200 Parts: []message.ContentPart{message.TextContent{Text: "Partial..."}},
201 Model: "test-model",
202 })
203 require.NoError(t, err)
204
205 err = coordinator.RecoverSession(t.Context(), sess.ID)
206 require.NoError(t, err)
207
208 // Message should NOT be recovered since session is "busy"
209 msgs, err := env.messages.List(t.Context(), sess.ID)
210 require.NoError(t, err)
211 require.Len(t, msgs, 1)
212 require.False(t, msgs[0].IsFinished(), "message should not be finished when session is busy")
213 })
214
215 t.Run("multiple incomplete messages", func(t *testing.T) {
216 env := testEnv(t)
217
218 sess, err := env.sessions.Create(t.Context(), "Test Session")
219 require.NoError(t, err)
220
221 // Create an incomplete summary message
222 _, err = env.messages.Create(t.Context(), sess.ID, message.CreateMessageParams{
223 Role: message.Assistant,
224 Parts: []message.ContentPart{message.TextContent{Text: "Partial summary..."}},
225 IsSummaryMessage: true,
226 })
227 require.NoError(t, err)
228
229 // Create an incomplete assistant message with tool calls
230 toolCall := message.ToolCall{
231 ID: "tc-1",
232 Name: "bash",
233 Input: `echo "hello"`,
234 ProviderExecuted: false,
235 Finished: false,
236 }
237 _, err = env.messages.Create(t.Context(), sess.ID, message.CreateMessageParams{
238 Role: message.Assistant,
239 Parts: []message.ContentPart{message.ToolCall(toolCall)},
240 })
241 require.NoError(t, err)
242
243 coordinator := &coordinator{
244 sessions: env.sessions,
245 messages: env.messages,
246 }
247
248 err = coordinator.RecoverSession(t.Context(), sess.ID)
249 require.NoError(t, err)
250
251 // Verify both messages were recovered
252 msgs, err := env.messages.List(t.Context(), sess.ID)
253 require.NoError(t, err)
254 require.Len(t, msgs, 2)
255
256 for _, msg := range msgs {
257 require.True(t, msg.IsFinished(), "message %s should be finished", msg.ID)
258 }
259 })
260
261 t.Run("mixed finished and unfinished messages", func(t *testing.T) {
262 env := testEnv(t)
263
264 sess, err := env.sessions.Create(t.Context(), "Test Session")
265 require.NoError(t, err)
266
267 // Create a finished user message (Finish part is added automatically)
268 _, err = env.messages.Create(t.Context(), sess.ID, message.CreateMessageParams{
269 Role: message.User,
270 Parts: []message.ContentPart{message.TextContent{Text: "Hello!"}},
271 })
272 require.NoError(t, err)
273
274 // Create a finished assistant message (with Finish part)
275 _, err = env.messages.Create(t.Context(), sess.ID, message.CreateMessageParams{
276 Role: message.Assistant,
277 Parts: []message.ContentPart{message.TextContent{Text: "Hi there!"}, message.Finish{Reason: message.FinishReasonEndTurn}},
278 })
279 require.NoError(t, err)
280
281 // Create an incomplete assistant message with tool calls
282 toolCall := message.ToolCall{
283 ID: "tc-1",
284 Name: "bash",
285 Input: `echo "hello"`,
286 ProviderExecuted: false,
287 Finished: false,
288 }
289 _, err = env.messages.Create(t.Context(), sess.ID, message.CreateMessageParams{
290 Role: message.Assistant,
291 Parts: []message.ContentPart{message.ToolCall(toolCall)},
292 })
293 require.NoError(t, err)
294
295 coordinator := &coordinator{
296 sessions: env.sessions,
297 messages: env.messages,
298 }
299
300 err = coordinator.RecoverSession(t.Context(), sess.ID)
301 require.NoError(t, err)
302
303 // Verify all messages are now correct
304 msgs, err := env.messages.List(t.Context(), sess.ID)
305 require.NoError(t, err)
306 require.Len(t, msgs, 3)
307
308 // User message should be finished (was already)
309 require.True(t, msgs[0].IsFinished())
310
311 // First assistant message should be finished (was already)
312 require.True(t, msgs[1].IsFinished())
313
314 // Second assistant message should now be finished
315 require.True(t, msgs[2].IsFinished())
316 })
317}
318
319// dummyAgent implements SessionAgent for testing purposes.
320type dummyAgent struct {
321 t *testing.T
322 isBusy bool
323}
324
325func (a *dummyAgent) Run(ctx context.Context, call SessionAgentCall) (*fantasy.AgentResult, error) {
326 return nil, nil
327}
328
329func (a *dummyAgent) SetModels(large, small Model) {}
330
331func (a *dummyAgent) SetTools(tools []fantasy.AgentTool) {}
332
333func (a *dummyAgent) SetSystemPrompt(systemPrompt string) {}
334
335func (a *dummyAgent) Cancel(sessionID string) {}
336
337func (a *dummyAgent) CancelAll() {}
338
339func (a *dummyAgent) IsSessionBusy(sessionID string) bool {
340 return a.isBusy
341}
342
343func (a *dummyAgent) IsBusy() bool {
344 return a.isBusy
345}
346
347func (a *dummyAgent) QueuedPrompts(sessionID string) int {
348 return 0
349}
350
351func (a *dummyAgent) QueuedPromptsList(sessionID string) []string {
352 return nil
353}
354
355func (a *dummyAgent) ClearQueue(sessionID string) {}
356
357func (a *dummyAgent) Summarize(ctx context.Context, sessionID string, opts fantasy.ProviderOptions) error {
358 return nil
359}
360
361func (a *dummyAgent) Model() Model {
362 return Model{}
363}