agent_test.go

  1package backend
  2
  3import (
  4	"context"
  5	"sync/atomic"
  6	"testing"
  7	"time"
  8
  9	"charm.land/fantasy"
 10	"github.com/charmbracelet/crush/internal/agent"
 11	"github.com/charmbracelet/crush/internal/app"
 12	"github.com/charmbracelet/crush/internal/message"
 13	"github.com/charmbracelet/crush/internal/proto"
 14	"github.com/google/uuid"
 15	"github.com/stretchr/testify/require"
 16)
 17
 18// blockingCoordinator is a minimal agent.Coordinator whose RunAccepted
 19// blocks until release is closed. It records that RunAccepted was
 20// entered so tests can observe the dispatched goroutine. Every other
 21// method returns a zero value.
 22type blockingCoordinator struct {
 23	entered  chan struct{}
 24	release  chan struct{}
 25	runCount atomic.Int32
 26}
 27
 28func newBlockingCoordinator() *blockingCoordinator {
 29	return &blockingCoordinator{
 30		entered: make(chan struct{}, 1),
 31		release: make(chan struct{}),
 32	}
 33}
 34
 35func (c *blockingCoordinator) Run(ctx context.Context, sessionID, prompt string, attachments ...message.Attachment) (*fantasy.AgentResult, error) {
 36	return nil, nil
 37}
 38
 39func (c *blockingCoordinator) RunAccepted(ctx context.Context, accept *agent.AcceptedRun, sessionID, prompt string, attachments ...message.Attachment) (*fantasy.AgentResult, error) {
 40	c.runCount.Add(1)
 41	select {
 42	case c.entered <- struct{}{}:
 43	default:
 44	}
 45	<-c.release
 46	return nil, nil
 47}
 48
 49func (c *blockingCoordinator) BeginAccepted(sessionID string) *agent.AcceptedRun { return nil }
 50func (c *blockingCoordinator) Cancel(string)                                     {}
 51func (c *blockingCoordinator) CancelAll()                                        {}
 52func (c *blockingCoordinator) IsBusy() bool                                      { return false }
 53func (c *blockingCoordinator) IsSessionBusy(string) bool                         { return false }
 54func (c *blockingCoordinator) QueuedPrompts(string) int                          { return 0 }
 55func (c *blockingCoordinator) QueuedPromptsList(string) []string                 { return nil }
 56func (c *blockingCoordinator) ClearQueue(string)                                 {}
 57func (c *blockingCoordinator) Summarize(context.Context, string) error           { return nil }
 58func (c *blockingCoordinator) Model() agent.Model                                { return agent.Model{} }
 59func (c *blockingCoordinator) UpdateModels(context.Context) error                { return nil }
 60
 61// insertAgentWorkspace installs a synthetic workspace with the given
 62// coordinator (or none) and a workspace run context, mirroring the
 63// fields CreateWorkspace initializes.
 64func insertAgentWorkspace(t *testing.T, b *Backend, coord agent.Coordinator) *Workspace {
 65	t.Helper()
 66	ws := &Workspace{
 67		ID:           uuid.New().String(),
 68		Path:         t.TempDir(),
 69		resolvedPath: t.TempDir(),
 70		clients:      make(map[string]*clientState),
 71		shutdownFn:   func() {},
 72	}
 73	ws.App = &app.App{AgentCoordinator: coord}
 74	ws.ctx, ws.cancel = context.WithCancel(b.ctx)
 75	b.mu.Lock()
 76	b.workspaces.Set(ws.ID, ws)
 77	b.pathIndex[ws.resolvedPath] = ws.ID
 78	b.mu.Unlock()
 79	return ws
 80}
 81
 82func TestSendMessage_WorkspaceNotFound(t *testing.T) {
 83	t.Parallel()
 84	b, _ := newTestBackend(t)
 85	err := b.SendMessage("nope", proto.AgentMessage{SessionID: "S1", Prompt: "hi"})
 86	require.ErrorIs(t, err, ErrWorkspaceNotFound)
 87}
 88
 89func TestSendMessage_AgentNotInitialized(t *testing.T) {
 90	t.Parallel()
 91	b, _ := newTestBackend(t)
 92	ws := insertAgentWorkspace(t, b, nil)
 93	err := b.SendMessage(ws.ID, proto.AgentMessage{SessionID: "S1", Prompt: "hi"})
 94	require.ErrorIs(t, err, ErrAgentNotInitialized)
 95}
 96
 97func TestSendMessage_EmptyPrompt(t *testing.T) {
 98	t.Parallel()
 99	b, _ := newTestBackend(t)
100	ws := insertAgentWorkspace(t, b, newBlockingCoordinator())
101	err := b.SendMessage(ws.ID, proto.AgentMessage{SessionID: "S1", Prompt: ""})
102	require.ErrorIs(t, err, agent.ErrEmptyPrompt)
103}
104
105func TestSendMessage_SessionMissing(t *testing.T) {
106	t.Parallel()
107	b, _ := newTestBackend(t)
108	ws := insertAgentWorkspace(t, b, newBlockingCoordinator())
109	err := b.SendMessage(ws.ID, proto.AgentMessage{SessionID: "", Prompt: "hi"})
110	require.ErrorIs(t, err, agent.ErrSessionMissing)
111}
112
113func TestSendMessage_WorkspaceClosing(t *testing.T) {
114	t.Parallel()
115	b, _ := newTestBackend(t)
116	ws := insertAgentWorkspace(t, b, newBlockingCoordinator())
117	ws.runMu.Lock()
118	ws.closing = true
119	ws.runMu.Unlock()
120	err := b.SendMessage(ws.ID, proto.AgentMessage{SessionID: "S1", Prompt: "hi"})
121	require.ErrorIs(t, err, ErrWorkspaceClosing)
122}
123
124// TestSendMessage_SuccessIncrementsRunWG asserts the happy path returns
125// nil synchronously and dispatches a tracked goroutine: while
126// RunAccepted blocks, runWG.Wait must not complete (the ticket is
127// outstanding); after release it drains.
128func TestSendMessage_SuccessIncrementsRunWG(t *testing.T) {
129	t.Parallel()
130	b, _ := newTestBackend(t)
131	coord := newBlockingCoordinator()
132	ws := insertAgentWorkspace(t, b, coord)
133
134	err := b.SendMessage(ws.ID, proto.AgentMessage{SessionID: "S1", Prompt: "hi"})
135	require.NoError(t, err)
136
137	select {
138	case <-coord.entered:
139	case <-time.After(2 * time.Second):
140		t.Fatal("dispatched goroutine never entered RunAccepted")
141	}
142	require.Equal(t, int32(1), coord.runCount.Load())
143
144	waited := make(chan struct{})
145	go func() {
146		ws.runWG.Wait()
147		close(waited)
148	}()
149
150	select {
151	case <-waited:
152		t.Fatal("runWG.Wait completed while the run was still in flight; ticket was not added")
153	case <-time.After(100 * time.Millisecond):
154	}
155
156	close(coord.release)
157
158	select {
159	case <-waited:
160	case <-time.After(2 * time.Second):
161		t.Fatal("runWG.Wait did not complete after the run returned")
162	}
163}