agent.go

  1package backend
  2
  3import (
  4	"context"
  5	"errors"
  6
  7	"github.com/charmbracelet/crush/internal/agent"
  8	"github.com/charmbracelet/crush/internal/agent/notify"
  9	"github.com/charmbracelet/crush/internal/config"
 10	"github.com/charmbracelet/crush/internal/proto"
 11	"github.com/charmbracelet/crush/internal/pubsub"
 12)
 13
 14// SendMessage validates and accepts a prompt for the workspace's agent,
 15// then dispatches the run on a goroutine bound to the workspace context
 16// and returns immediately. It does not wait for the LLM turn to
 17// complete: the run's lifetime is owned by the workspace, not by the
 18// caller. Errors from the dispatched run reach observers through the
 19// agent event channels (a notify.TypeAgentError notification), not
 20// through this return value.
 21//
 22// SendMessage returns synchronously when the request cannot be accepted:
 23// ErrWorkspaceNotFound if the workspace is missing, ErrAgentNotInitialized
 24// if its coordinator is nil, the structural validation errors from
 25// agent.ValidateCall (ErrEmptyPrompt, ErrSessionMissing) when the prompt
 26// or session is missing, and ErrWorkspaceClosing if the workspace is
 27// being torn down.
 28func (b *Backend) SendMessage(workspaceID string, msg proto.AgentMessage) error {
 29	ws, err := b.GetWorkspace(workspaceID)
 30	if err != nil {
 31		return err
 32	}
 33
 34	if ws.AgentCoordinator == nil {
 35		return ErrAgentNotInitialized
 36	}
 37
 38	if err := agent.ValidateCall(agent.SessionAgentCall{
 39		SessionID:   msg.SessionID,
 40		Prompt:      msg.Prompt,
 41		Attachments: proto.AttachmentsToMessage(msg.Attachments),
 42	}); err != nil {
 43		return err
 44	}
 45
 46	accept := ws.AgentCoordinator.BeginAccepted(msg.SessionID)
 47
 48	ws.runMu.Lock()
 49	if ws.closing {
 50		ws.runMu.Unlock()
 51		accept.Close()
 52		return ErrWorkspaceClosing
 53	}
 54	ws.runWG.Add(1)
 55	ws.runMu.Unlock()
 56
 57	go b.runAgent(ws, msg, accept)
 58	return nil
 59}
 60
 61// runAgent executes an accepted agent run for the workspace. It owns the
 62// accept reservation (releasing it on return) and the runWG ticket added
 63// by SendMessage. The run is bound to the workspace context so its
 64// lifetime is independent of any client's HTTP request. On a non-cancel
 65// error it surfaces the failure to observers via a notify.TypeAgentError
 66// notification; context.Canceled is expected (the FinishReasonCanceled
 67// marker is already published by sessionAgent.Run) and swallowed.
 68//
 69// When msg.RunID is non-empty it is attached to the context via
 70// agent.WithRunID so the coordinator can stamp the terminal
 71// notify.RunComplete event with that correlator.
 72func (b *Backend) runAgent(ws *Workspace, msg proto.AgentMessage, accept *agent.AcceptedRun) {
 73	defer ws.runWG.Done()
 74	defer accept.Close()
 75
 76	ctx := ws.ctx
 77	if msg.RunID != "" {
 78		ctx = agent.WithRunID(ctx, msg.RunID)
 79	}
 80
 81	_, err := ws.AgentCoordinator.RunAccepted(ctx, accept, msg.SessionID, msg.Prompt, proto.AttachmentsToMessage(msg.Attachments)...)
 82	if err == nil || errors.Is(err, context.Canceled) {
 83		return
 84	}
 85
 86	ws.AgentNotifications().Publish(pubsub.CreatedEvent, notify.Notification{
 87		SessionID: msg.SessionID,
 88		RunID:     msg.RunID,
 89		Type:      notify.TypeAgentError,
 90		Message:   err.Error(),
 91	})
 92}
 93
 94// GetAgentInfo returns the agent's model and busy status.
 95func (b *Backend) GetAgentInfo(workspaceID string) (proto.AgentInfo, error) {
 96	ws, err := b.GetWorkspace(workspaceID)
 97	if err != nil {
 98		return proto.AgentInfo{}, err
 99	}
100
101	var agentInfo proto.AgentInfo
102	if ws.AgentCoordinator != nil {
103		m := ws.AgentCoordinator.Model()
104		agentInfo = proto.AgentInfo{
105			Model:    m.CatwalkCfg,
106			ModelCfg: m.ModelCfg,
107			IsBusy:   ws.AgentCoordinator.IsBusy(),
108			IsReady:  true,
109		}
110	}
111	return agentInfo, nil
112}
113
114// InitAgent initializes the coder agent for the workspace.
115func (b *Backend) InitAgent(ctx context.Context, workspaceID string) error {
116	ws, err := b.GetWorkspace(workspaceID)
117	if err != nil {
118		return err
119	}
120
121	return ws.InitCoderAgent(ctx)
122}
123
124// UpdateAgent reloads the agent model configuration.
125func (b *Backend) UpdateAgent(ctx context.Context, workspaceID string) error {
126	ws, err := b.GetWorkspace(workspaceID)
127	if err != nil {
128		return err
129	}
130
131	return ws.UpdateAgentModel(ctx)
132}
133
134// CancelSession cancels an ongoing agent operation for the given
135// session.
136func (b *Backend) CancelSession(workspaceID, sessionID string) error {
137	ws, err := b.GetWorkspace(workspaceID)
138	if err != nil {
139		return err
140	}
141
142	if ws.AgentCoordinator != nil {
143		ws.AgentCoordinator.Cancel(sessionID)
144	}
145	return nil
146}
147
148// SummarizeSession triggers a session summarization.
149func (b *Backend) SummarizeSession(ctx context.Context, workspaceID, sessionID string) error {
150	ws, err := b.GetWorkspace(workspaceID)
151	if err != nil {
152		return err
153	}
154
155	if ws.AgentCoordinator == nil {
156		return ErrAgentNotInitialized
157	}
158
159	return ws.AgentCoordinator.Summarize(ctx, sessionID)
160}
161
162// QueuedPrompts returns the number of queued prompts for the session.
163func (b *Backend) QueuedPrompts(workspaceID, sessionID string) (int, error) {
164	ws, err := b.GetWorkspace(workspaceID)
165	if err != nil {
166		return 0, err
167	}
168
169	if ws.AgentCoordinator == nil {
170		return 0, nil
171	}
172
173	return ws.AgentCoordinator.QueuedPrompts(sessionID), nil
174}
175
176// ClearQueue clears the prompt queue for the session.
177func (b *Backend) ClearQueue(workspaceID, sessionID string) error {
178	ws, err := b.GetWorkspace(workspaceID)
179	if err != nil {
180		return err
181	}
182
183	if ws.AgentCoordinator != nil {
184		ws.AgentCoordinator.ClearQueue(sessionID)
185	}
186	return nil
187}
188
189// QueuedPromptsList returns the list of queued prompt strings for a
190// session.
191func (b *Backend) QueuedPromptsList(workspaceID, sessionID string) ([]string, error) {
192	ws, err := b.GetWorkspace(workspaceID)
193	if err != nil {
194		return nil, err
195	}
196
197	if ws.AgentCoordinator == nil {
198		return nil, nil
199	}
200
201	return ws.AgentCoordinator.QueuedPromptsList(sessionID), nil
202}
203
204// GetDefaultSmallModel returns the default small model for a provider.
205func (b *Backend) GetDefaultSmallModel(workspaceID, providerID string) (config.SelectedModel, error) {
206	ws, err := b.GetWorkspace(workspaceID)
207	if err != nil {
208		return config.SelectedModel{}, err
209	}
210
211	return ws.GetDefaultSmallModel(providerID), nil
212}