tests

Carlos Alexandro Becker created

Signed-off-by: Carlos Alexandro Becker <caarlos0@users.noreply.github.com>

Change summary

internal/agent/coordinator.go      |   2 
internal/agent/coordinator_test.go | 328 ++++++++++++++++++++++++++++++++
2 files changed, 329 insertions(+), 1 deletion(-)

Detailed changes

internal/agent/coordinator.go 🔗

@@ -837,7 +837,7 @@ func (c *coordinator) Summarize(ctx context.Context, sessionID string) error {
 
 func (c *coordinator) RecoverSession(ctx context.Context, sessionID string) error {
 	// Skip recovery if session is currently active
-	if c.currentAgent.IsSessionBusy(sessionID) {
+	if c.currentAgent != nil && c.currentAgent.IsSessionBusy(sessionID) {
 		return nil
 	}
 

internal/agent/coordinator_test.go 🔗

@@ -0,0 +1,328 @@
+package agent
+
+import (
+	"context"
+	"testing"
+
+	"charm.land/fantasy"
+	"github.com/charmbracelet/crush/internal/message"
+	"github.com/stretchr/testify/require"
+)
+
+func TestRecoverSession(t *testing.T) {
+	t.Run("no messages", func(t *testing.T) {
+		env := testEnv(t)
+
+		sess, err := env.sessions.Create(t.Context(), "Test Session")
+		require.NoError(t, err)
+
+		// Create coordinator with mock services
+		coordinator := &coordinator{
+			sessions: env.sessions,
+			messages: env.messages,
+		}
+
+		err = coordinator.RecoverSession(t.Context(), sess.ID)
+		require.NoError(t, err)
+
+		// Verify no messages were modified
+		msgs, err := env.messages.List(t.Context(), sess.ID)
+		require.NoError(t, err)
+		require.Empty(t, msgs)
+	})
+
+	t.Run("already finished messages", func(t *testing.T) {
+		env := testEnv(t)
+
+		sess, err := env.sessions.Create(t.Context(), "Test Session")
+		require.NoError(t, err)
+
+		// Create a finished assistant message (with Finish part)
+		_, err = env.messages.Create(t.Context(), sess.ID, message.CreateMessageParams{
+			Role:  message.Assistant,
+			Parts: []message.ContentPart{message.TextContent{Text: "Hello!"}, message.Finish{Reason: message.FinishReasonEndTurn}},
+		})
+		require.NoError(t, err)
+
+		// Create coordinator with mock services
+		coordinator := &coordinator{
+			sessions: env.sessions,
+			messages: env.messages,
+		}
+
+		err = coordinator.RecoverSession(t.Context(), sess.ID)
+		require.NoError(t, err)
+
+		// Verify the message was not modified
+		msgs, err := env.messages.List(t.Context(), sess.ID)
+		require.NoError(t, err)
+		require.Len(t, msgs, 1)
+		require.True(t, msgs[0].IsFinished())
+	})
+
+	t.Run("incomplete summary message", func(t *testing.T) {
+		env := testEnv(t)
+
+		sess, err := env.sessions.Create(t.Context(), "Test Session")
+		require.NoError(t, err)
+
+		// Create an incomplete summary message (simulating a crash during summarization)
+		summaryMsg, err := env.messages.Create(t.Context(), sess.ID, message.CreateMessageParams{
+			Role:             message.Assistant,
+			Parts:            []message.ContentPart{message.TextContent{Text: "Partial summary..."}},
+			Model:            "test-model",
+			Provider:         "test-provider",
+			IsSummaryMessage: true,
+		})
+		require.NoError(t, err)
+
+		// Verify the message is not finished
+		require.False(t, summaryMsg.IsFinished())
+
+		// Create coordinator with mock services
+		coordinator := &coordinator{
+			sessions: env.sessions,
+			messages: env.messages,
+		}
+
+		err = coordinator.RecoverSession(t.Context(), sess.ID)
+		require.NoError(t, err)
+
+		// Verify the summary message was recovered
+		recoveredMsg, err := env.messages.Get(t.Context(), summaryMsg.ID)
+		require.NoError(t, err)
+		require.True(t, recoveredMsg.IsFinished())
+		require.Equal(t, message.FinishReasonError, recoveredMsg.FinishReason())
+		require.Contains(t, recoveredMsg.FinishPart().Message, "Summarization interrupted")
+	})
+
+	t.Run("incomplete assistant message with tool calls", func(t *testing.T) {
+		env := testEnv(t)
+
+		sess, err := env.sessions.Create(t.Context(), "Test Session")
+		require.NoError(t, err)
+
+		// Create an incomplete assistant message with tool calls
+		// (simulating a crash during tool execution)
+		toolCall := message.ToolCall{
+			ID:               "tc-1",
+			Name:             "bash",
+			Input:            `echo "hello"`,
+			ProviderExecuted: false,
+			Finished:         false,
+		}
+
+		assistantMsg, err := env.messages.Create(t.Context(), sess.ID, message.CreateMessageParams{
+			Role:  message.Assistant,
+			Parts: []message.ContentPart{message.ToolCall(toolCall)},
+			Model: "test-model",
+		})
+		require.NoError(t, err)
+
+		// Verify the message is not finished
+		require.False(t, assistantMsg.IsFinished())
+
+		// Create coordinator with mock services
+		coordinator := &coordinator{
+			sessions: env.sessions,
+			messages: env.messages,
+		}
+
+		err = coordinator.RecoverSession(t.Context(), sess.ID)
+		require.NoError(t, err)
+
+		// Verify the assistant message was recovered
+		recoveredMsg, err := env.messages.Get(t.Context(), assistantMsg.ID)
+		require.NoError(t, err)
+		require.True(t, recoveredMsg.IsFinished())
+		require.Equal(t, message.FinishReasonError, recoveredMsg.FinishReason())
+		require.Contains(t, recoveredMsg.FinishPart().Message, "Response interrupted")
+
+		// Verify the tool call was marked as finished
+		toolCalls := recoveredMsg.ToolCalls()
+		require.Len(t, toolCalls, 1)
+		require.True(t, toolCalls[0].Finished)
+	})
+
+	t.Run("session is busy - skips recovery", func(t *testing.T) {
+		env := testEnv(t)
+
+		sess, err := env.sessions.Create(t.Context(), "Test Session")
+		require.NoError(t, err)
+
+		// Create a dummy agent that reports as busy
+		agent := &dummyAgent{t: t, isBusy: true}
+
+		coordinator := &coordinator{
+			sessions:     env.sessions,
+			messages:     env.messages,
+			currentAgent: agent,
+		}
+
+		// Create an incomplete assistant message
+		_, err = env.messages.Create(t.Context(), sess.ID, message.CreateMessageParams{
+			Role:  message.Assistant,
+			Parts: []message.ContentPart{message.TextContent{Text: "Partial..."}},
+			Model: "test-model",
+		})
+		require.NoError(t, err)
+
+		err = coordinator.RecoverSession(t.Context(), sess.ID)
+		require.NoError(t, err)
+
+		// Message should NOT be recovered since session is "busy"
+		msgs, err := env.messages.List(t.Context(), sess.ID)
+		require.NoError(t, err)
+		require.Len(t, msgs, 1)
+		require.False(t, msgs[0].IsFinished(), "message should not be finished when session is busy")
+	})
+
+	t.Run("multiple incomplete messages", func(t *testing.T) {
+		env := testEnv(t)
+
+		sess, err := env.sessions.Create(t.Context(), "Test Session")
+		require.NoError(t, err)
+
+		// Create an incomplete summary message
+		_, err = env.messages.Create(t.Context(), sess.ID, message.CreateMessageParams{
+			Role:             message.Assistant,
+			Parts:            []message.ContentPart{message.TextContent{Text: "Partial summary..."}},
+			IsSummaryMessage: true,
+		})
+		require.NoError(t, err)
+
+		// Create an incomplete assistant message with tool calls
+		toolCall := message.ToolCall{
+			ID:               "tc-1",
+			Name:             "bash",
+			Input:            `echo "hello"`,
+			ProviderExecuted: false,
+			Finished:         false,
+		}
+		_, err = env.messages.Create(t.Context(), sess.ID, message.CreateMessageParams{
+			Role:  message.Assistant,
+			Parts: []message.ContentPart{message.ToolCall(toolCall)},
+		})
+		require.NoError(t, err)
+
+		coordinator := &coordinator{
+			sessions: env.sessions,
+			messages: env.messages,
+		}
+
+		err = coordinator.RecoverSession(t.Context(), sess.ID)
+		require.NoError(t, err)
+
+		// Verify both messages were recovered
+		msgs, err := env.messages.List(t.Context(), sess.ID)
+		require.NoError(t, err)
+		require.Len(t, msgs, 2)
+
+		for _, msg := range msgs {
+			require.True(t, msg.IsFinished(), "message %s should be finished", msg.ID)
+		}
+	})
+
+	t.Run("mixed finished and unfinished messages", func(t *testing.T) {
+		env := testEnv(t)
+
+		sess, err := env.sessions.Create(t.Context(), "Test Session")
+		require.NoError(t, err)
+
+		// Create a finished user message (Finish part is added automatically)
+		_, err = env.messages.Create(t.Context(), sess.ID, message.CreateMessageParams{
+			Role:  message.User,
+			Parts: []message.ContentPart{message.TextContent{Text: "Hello!"}},
+		})
+		require.NoError(t, err)
+
+		// Create a finished assistant message (with Finish part)
+		_, err = env.messages.Create(t.Context(), sess.ID, message.CreateMessageParams{
+			Role:  message.Assistant,
+			Parts: []message.ContentPart{message.TextContent{Text: "Hi there!"}, message.Finish{Reason: message.FinishReasonEndTurn}},
+		})
+		require.NoError(t, err)
+
+		// Create an incomplete assistant message with tool calls
+		toolCall := message.ToolCall{
+			ID:               "tc-1",
+			Name:             "bash",
+			Input:            `echo "hello"`,
+			ProviderExecuted: false,
+			Finished:         false,
+		}
+		_, err = env.messages.Create(t.Context(), sess.ID, message.CreateMessageParams{
+			Role:  message.Assistant,
+			Parts: []message.ContentPart{message.ToolCall(toolCall)},
+		})
+		require.NoError(t, err)
+
+		coordinator := &coordinator{
+			sessions: env.sessions,
+			messages: env.messages,
+		}
+
+		err = coordinator.RecoverSession(t.Context(), sess.ID)
+		require.NoError(t, err)
+
+		// Verify all messages are now correct
+		msgs, err := env.messages.List(t.Context(), sess.ID)
+		require.NoError(t, err)
+		require.Len(t, msgs, 3)
+
+		// User message should be finished (was already)
+		require.True(t, msgs[0].IsFinished())
+
+		// First assistant message should be finished (was already)
+		require.True(t, msgs[1].IsFinished())
+
+		// Second assistant message should now be finished
+		require.True(t, msgs[2].IsFinished())
+	})
+}
+
+// dummyAgent implements SessionAgent for testing purposes.
+type dummyAgent struct {
+	t        *testing.T
+	isBusy   bool
+}
+
+func (a *dummyAgent) Run(ctx context.Context, call SessionAgentCall) (*fantasy.AgentResult, error) {
+	return nil, nil
+}
+
+func (a *dummyAgent) SetModels(large, small Model) {}
+
+func (a *dummyAgent) SetTools(tools []fantasy.AgentTool) {}
+
+func (a *dummyAgent) SetSystemPrompt(systemPrompt string) {}
+
+func (a *dummyAgent) Cancel(sessionID string) {}
+
+func (a *dummyAgent) CancelAll() {}
+
+func (a *dummyAgent) IsSessionBusy(sessionID string) bool {
+	return a.isBusy
+}
+
+func (a *dummyAgent) IsBusy() bool {
+	return a.isBusy
+}
+
+func (a *dummyAgent) QueuedPrompts(sessionID string) int {
+	return 0
+}
+
+func (a *dummyAgent) QueuedPromptsList(sessionID string) []string {
+	return nil
+}
+
+func (a *dummyAgent) ClearQueue(sessionID string) {}
+
+func (a *dummyAgent) Summarize(ctx context.Context, sessionID string, opts fantasy.ProviderOptions) error {
+	return nil
+}
+
+func (a *dummyAgent) Model() Model {
+	return Model{}
+}