1package backend
2
3import (
4 "context"
5 "sync/atomic"
6 "testing"
7 "time"
8
9 "charm.land/fantasy"
10 "github.com/charmbracelet/crush/internal/agent"
11 "github.com/charmbracelet/crush/internal/app"
12 "github.com/charmbracelet/crush/internal/message"
13 "github.com/charmbracelet/crush/internal/proto"
14 "github.com/google/uuid"
15 "github.com/stretchr/testify/require"
16)
17
18// blockingCoordinator is a minimal agent.Coordinator whose RunAccepted
19// blocks until release is closed. It records that RunAccepted was
20// entered so tests can observe the dispatched goroutine. Every other
21// method returns a zero value.
22type blockingCoordinator struct {
23 entered chan struct{}
24 release chan struct{}
25 runCount atomic.Int32
26}
27
28func newBlockingCoordinator() *blockingCoordinator {
29 return &blockingCoordinator{
30 entered: make(chan struct{}, 1),
31 release: make(chan struct{}),
32 }
33}
34
// Run ignores its arguments and returns immediately with a nil result and a
// nil error.
func (c *blockingCoordinator) Run(_ context.Context, _, _ string, _ ...message.Attachment) (*fantasy.AgentResult, error) {
	return nil, nil
}
38
// RunAccepted runs three steps, in this order:
//  1. It increments runCount. This happens first so a test that has
//     observed entered also sees the updated count.
//  2. It tries to post a signal on entered without blocking. If a signal is
//     already pending, this one is dropped.
//  3. It waits for release to be closed.
//
// It then returns a nil result and a nil error. ctx and the other arguments
// are not inspected.
func (c *blockingCoordinator) RunAccepted(ctx context.Context, accept *agent.AcceptedRun, sessionID, prompt string, attachments ...message.Attachment) (*fantasy.AgentResult, error) {
	c.runCount.Add(1)

	select {
	case c.entered <- struct{}{}:
		// Signal delivered.
	default:
		// A signal is already buffered; drop this one.
	}

	// Park until the test closes release.
	<-c.release
	return nil, nil
}
48
// BeginAccepted returns a nil *agent.AcceptedRun.
func (c *blockingCoordinator) BeginAccepted(_ string) *agent.AcceptedRun {
	return nil
}

// Cancel does nothing.
func (c *blockingCoordinator) Cancel(_ string) {}

// CancelAll does nothing.
func (c *blockingCoordinator) CancelAll() {}

// IsBusy always reports false.
func (c *blockingCoordinator) IsBusy() bool {
	return false
}

// IsSessionBusy always reports false, whatever the session.
func (c *blockingCoordinator) IsSessionBusy(_ string) bool {
	return false
}

// QueuedPrompts always reports zero queued prompts.
func (c *blockingCoordinator) QueuedPrompts(_ string) int {
	return 0
}

// QueuedPromptsList always returns a nil slice.
func (c *blockingCoordinator) QueuedPromptsList(_ string) []string {
	return nil
}

// ClearQueue does nothing.
func (c *blockingCoordinator) ClearQueue(_ string) {}

// Summarize does nothing and returns a nil error.
func (c *blockingCoordinator) Summarize(_ context.Context, _ string) error {
	return nil
}

// Model returns the zero agent.Model.
func (c *blockingCoordinator) Model() agent.Model {
	return agent.Model{}
}

// UpdateModels does nothing and returns a nil error.
func (c *blockingCoordinator) UpdateModels(_ context.Context) error {
	return nil
}
60
// insertAgentWorkspace registers a synthetic workspace on b without going
// through CreateWorkspace, and returns it.
//
// The workspace has these properties:
//   - Its App carries coord as the AgentCoordinator. Pass nil to simulate a
//     workspace whose agent was never initialized.
//   - Its run context is derived from b.ctx and canceled in t.Cleanup, so it
//     does not outlive the test.
//   - It is indexed in b.workspaces by ID and in b.pathIndex by
//     resolvedPath.
//
// The fields set here are intended to mirror the ones CreateWorkspace
// initializes.
func insertAgentWorkspace(t *testing.T, b *Backend, coord agent.Coordinator) *Workspace {
	t.Helper()
	// Use a single directory for both Path and resolvedPath so the two
	// fields agree, instead of naming two unrelated temp dirs.
	dir := t.TempDir()
	ws := &Workspace{
		ID:           uuid.New().String(),
		Path:         dir,
		resolvedPath: dir,
		clients:      make(map[string]*clientState),
		shutdownFn:   func() {},
	}
	ws.App = &app.App{AgentCoordinator: coord}
	ws.ctx, ws.cancel = context.WithCancel(b.ctx)
	// Release the context's resources when the test ends rather than
	// waiting for b.ctx to be canceled. Capture the func value now so a
	// later reassignment of ws.cancel cannot affect cleanup.
	t.Cleanup(ws.cancel)
	b.mu.Lock()
	b.workspaces.Set(ws.ID, ws)
	b.pathIndex[ws.resolvedPath] = ws.ID
	b.mu.Unlock()
	return ws
}
81
// TestSendMessage_WorkspaceNotFound checks that a workspace ID the backend
// has never seen is rejected with ErrWorkspaceNotFound.
func TestSendMessage_WorkspaceNotFound(t *testing.T) {
	t.Parallel()

	b, _ := newTestBackend(t)
	msg := proto.AgentMessage{SessionID: "S1", Prompt: "hi"}

	require.ErrorIs(t, b.SendMessage("nope", msg), ErrWorkspaceNotFound)
}
88
// TestSendMessage_AgentNotInitialized checks that a workspace whose App has
// no agent coordinator is rejected with ErrAgentNotInitialized.
func TestSendMessage_AgentNotInitialized(t *testing.T) {
	t.Parallel()

	b, _ := newTestBackend(t)
	workspace := insertAgentWorkspace(t, b, nil)
	msg := proto.AgentMessage{SessionID: "S1", Prompt: "hi"}

	require.ErrorIs(t, b.SendMessage(workspace.ID, msg), ErrAgentNotInitialized)
}
96
// TestSendMessage_EmptyPrompt checks that an empty prompt is rejected with
// agent.ErrEmptyPrompt.
func TestSendMessage_EmptyPrompt(t *testing.T) {
	t.Parallel()

	b, _ := newTestBackend(t)
	workspace := insertAgentWorkspace(t, b, newBlockingCoordinator())
	msg := proto.AgentMessage{SessionID: "S1", Prompt: ""}

	require.ErrorIs(t, b.SendMessage(workspace.ID, msg), agent.ErrEmptyPrompt)
}
104
// TestSendMessage_SessionMissing checks that a message with an empty session
// ID is rejected with agent.ErrSessionMissing.
func TestSendMessage_SessionMissing(t *testing.T) {
	t.Parallel()

	b, _ := newTestBackend(t)
	workspace := insertAgentWorkspace(t, b, newBlockingCoordinator())
	msg := proto.AgentMessage{SessionID: "", Prompt: "hi"}

	require.ErrorIs(t, b.SendMessage(workspace.ID, msg), agent.ErrSessionMissing)
}
112
// TestSendMessage_WorkspaceClosing checks that a workspace whose closing
// flag is set is rejected with ErrWorkspaceClosing.
func TestSendMessage_WorkspaceClosing(t *testing.T) {
	t.Parallel()

	b, _ := newTestBackend(t)
	workspace := insertAgentWorkspace(t, b, newBlockingCoordinator())

	// Set the closing flag while holding runMu.
	func() {
		workspace.runMu.Lock()
		defer workspace.runMu.Unlock()
		workspace.closing = true
	}()

	msg := proto.AgentMessage{SessionID: "S1", Prompt: "hi"}
	require.ErrorIs(t, b.SendMessage(workspace.ID, msg), ErrWorkspaceClosing)
}
123
// TestSendMessage_SuccessIncrementsRunWG asserts the happy path returns
// nil synchronously and dispatches a tracked goroutine: while
// RunAccepted blocks, runWG.Wait must not complete (the ticket is
// outstanding); after release it drains.
//
// coord.release is also closed from t.Cleanup. Otherwise an early t.Fatal
// would leave the dispatched goroutine parked in RunAccepted forever, with
// runWG never drained and the waiter goroutine leaked.
func TestSendMessage_SuccessIncrementsRunWG(t *testing.T) {
	t.Parallel()
	b, _ := newTestBackend(t)
	coord := newBlockingCoordinator()
	ws := insertAgentWorkspace(t, b, coord)

	// releaseRun closes coord.release at most once. It is called either
	// explicitly below or from cleanup if the test bails out first.
	// Cleanups run LIFO, so this one runs before any cleanup registered by
	// newTestBackend or insertAgentWorkspace.
	released := false
	releaseRun := func() {
		if released {
			return
		}
		released = true
		close(coord.release)
	}
	t.Cleanup(releaseRun)

	err := b.SendMessage(ws.ID, proto.AgentMessage{SessionID: "S1", Prompt: "hi"})
	require.NoError(t, err)

	select {
	case <-coord.entered:
	case <-time.After(2 * time.Second):
		t.Fatal("dispatched goroutine never entered RunAccepted")
	}
	require.Equal(t, int32(1), coord.runCount.Load())

	waited := make(chan struct{})
	go func() {
		ws.runWG.Wait()
		close(waited)
	}()

	select {
	case <-waited:
		t.Fatal("runWG.Wait completed while the run was still in flight; ticket was not added")
	case <-time.After(100 * time.Millisecond):
	}

	releaseRun()

	select {
	case <-waited:
	case <-time.After(2 * time.Second):
		t.Fatal("runWG.Wait did not complete after the run returned")
	}
}