agent.go

  1package backend
  2
  3import (
  4	"context"
  5
  6	"github.com/charmbracelet/crush/internal/agent"
  7	"github.com/charmbracelet/crush/internal/config"
  8	"github.com/charmbracelet/crush/internal/proto"
  9)
 10
 11// SendMessage sends a prompt to the agent coordinator for the given
 12// workspace and session.
 13//
 14// When msg.RunID is non-empty it is attached to the context via
 15// agent.WithRunID so the coordinator can stamp the resulting
 16// SessionAgentCall (and therefore the terminal notify.RunComplete
 17// event) with that correlator. This is the only way for the
 18// originating client to distinguish its own turn's RunComplete from
 19// any concurrent turn that finishes on the same session.
 20func (b *Backend) SendMessage(ctx context.Context, workspaceID string, msg proto.AgentMessage) error {
 21	ws, err := b.GetWorkspace(workspaceID)
 22	if err != nil {
 23		return err
 24	}
 25
 26	if ws.AgentCoordinator == nil {
 27		return ErrAgentNotInitialized
 28	}
 29
 30	if msg.RunID != "" {
 31		ctx = agent.WithRunID(ctx, msg.RunID)
 32	}
 33	_, err = ws.AgentCoordinator.Run(ctx, msg.SessionID, msg.Prompt, proto.AttachmentsToMessage(msg.Attachments)...)
 34	return err
 35}
 36
 37// GetAgentInfo returns the agent's model and busy status.
 38func (b *Backend) GetAgentInfo(workspaceID string) (proto.AgentInfo, error) {
 39	ws, err := b.GetWorkspace(workspaceID)
 40	if err != nil {
 41		return proto.AgentInfo{}, err
 42	}
 43
 44	var agentInfo proto.AgentInfo
 45	if ws.AgentCoordinator != nil {
 46		m := ws.AgentCoordinator.Model()
 47		agentInfo = proto.AgentInfo{
 48			Model:    m.CatwalkCfg,
 49			ModelCfg: m.ModelCfg,
 50			IsBusy:   ws.AgentCoordinator.IsBusy(),
 51			IsReady:  true,
 52		}
 53	}
 54	return agentInfo, nil
 55}
 56
 57// InitAgent initializes the coder agent for the workspace.
 58func (b *Backend) InitAgent(ctx context.Context, workspaceID string) error {
 59	ws, err := b.GetWorkspace(workspaceID)
 60	if err != nil {
 61		return err
 62	}
 63
 64	return ws.InitCoderAgent(ctx)
 65}
 66
 67// UpdateAgent reloads the agent model configuration.
 68func (b *Backend) UpdateAgent(ctx context.Context, workspaceID string) error {
 69	ws, err := b.GetWorkspace(workspaceID)
 70	if err != nil {
 71		return err
 72	}
 73
 74	return ws.UpdateAgentModel(ctx)
 75}
 76
 77// CancelSession cancels an ongoing agent operation for the given
 78// session.
 79func (b *Backend) CancelSession(workspaceID, sessionID string) error {
 80	ws, err := b.GetWorkspace(workspaceID)
 81	if err != nil {
 82		return err
 83	}
 84
 85	if ws.AgentCoordinator != nil {
 86		ws.AgentCoordinator.Cancel(sessionID)
 87	}
 88	return nil
 89}
 90
 91// SummarizeSession triggers a session summarization.
 92func (b *Backend) SummarizeSession(ctx context.Context, workspaceID, sessionID string) error {
 93	ws, err := b.GetWorkspace(workspaceID)
 94	if err != nil {
 95		return err
 96	}
 97
 98	if ws.AgentCoordinator == nil {
 99		return ErrAgentNotInitialized
100	}
101
102	return ws.AgentCoordinator.Summarize(ctx, sessionID)
103}
104
105// QueuedPrompts returns the number of queued prompts for the session.
106func (b *Backend) QueuedPrompts(workspaceID, sessionID string) (int, error) {
107	ws, err := b.GetWorkspace(workspaceID)
108	if err != nil {
109		return 0, err
110	}
111
112	if ws.AgentCoordinator == nil {
113		return 0, nil
114	}
115
116	return ws.AgentCoordinator.QueuedPrompts(sessionID), nil
117}
118
119// ClearQueue clears the prompt queue for the session.
120func (b *Backend) ClearQueue(workspaceID, sessionID string) error {
121	ws, err := b.GetWorkspace(workspaceID)
122	if err != nil {
123		return err
124	}
125
126	if ws.AgentCoordinator != nil {
127		ws.AgentCoordinator.ClearQueue(sessionID)
128	}
129	return nil
130}
131
132// QueuedPromptsList returns the list of queued prompt strings for a
133// session.
134func (b *Backend) QueuedPromptsList(workspaceID, sessionID string) ([]string, error) {
135	ws, err := b.GetWorkspace(workspaceID)
136	if err != nil {
137		return nil, err
138	}
139
140	if ws.AgentCoordinator == nil {
141		return nil, nil
142	}
143
144	return ws.AgentCoordinator.QueuedPromptsList(sessionID), nil
145}
146
147// GetDefaultSmallModel returns the default small model for a provider.
148func (b *Backend) GetDefaultSmallModel(workspaceID, providerID string) (config.SelectedModel, error) {
149	ws, err := b.GetWorkspace(workspaceID)
150	if err != nil {
151		return config.SelectedModel{}, err
152	}
153
154	return ws.GetDefaultSmallModel(providerID), nil
155}