package agent

import (
	"context"
	"encoding/json"
	"errors"
	"fmt"
	"log/slog"
	"time"

	"github.com/charmbracelet/catwalk/pkg/catwalk"
	"github.com/charmbracelet/crush/internal/agent/prompt"
	"github.com/charmbracelet/crush/internal/agent/tools"
	"github.com/charmbracelet/crush/internal/ai"
	"github.com/charmbracelet/crush/internal/ai/providers"
	"github.com/charmbracelet/crush/internal/config"
	"github.com/charmbracelet/crush/internal/csync"
	"github.com/charmbracelet/crush/internal/history"
	"github.com/charmbracelet/crush/internal/lsp"
	"github.com/charmbracelet/crush/internal/message"
	"github.com/charmbracelet/crush/internal/permission"
	"github.com/charmbracelet/crush/internal/pubsub"
	"github.com/charmbracelet/crush/internal/session"
)

// Common errors
var (
	ErrRequestCancelled = errors.New("request canceled by user")
	ErrSessionBusy      = errors.New("session is currently processing another request")
)

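// AgentEventType identifies the kind of AgentEvent being reported.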
type AgentEventType string

const (
	AgentEventTypeError     AgentEventType = "error"
	AgentEventTypeResponse  AgentEventType = "response"
	AgentEventTypeSummarize AgentEventType = "summarize"
)

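// AgentEvent is published on the agent's broker and delivered to Run callers.
// It carries either a final result, an error, or summarization progress.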
type AgentEvent struct {
	Type   AgentEventType
	Result ai.AgentResult
	Error  error

	// When summarizing
	SessionID string
	Progress  string
	Done      bool
}

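// Service is the coding agent's public API: it runs prompts against the
// configured model, manages per-session cancellation and prompt queueing,
// and publishes AgentEvents to subscribers.
//
// A typical call site might look like this (sketch; the queued-prompt case,
// where Run returns a nil channel, is elided):
//
//	events, err := svc.Run(ctx, sessionID, "describe this repository")
//	if err != nil {
//		return err
//	}
//	for ev := range events {
//		if ev.Type == AgentEventTypeError {
//			return ev.Error
//		}
//	}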
type Service interface {
	pubsub.Suscriber[AgentEvent]
	Model() catwalk.Model
	Run(ctx context.Context, sessionID string, content string, attachments ...message.Attachment) (<-chan AgentEvent, error)
	Cancel(sessionID string)
	CancelAll()
	IsSessionBusy(sessionID string) bool
	IsBusy() bool
	Summarize(ctx context.Context, sessionID string) error
	UpdateModel() error
	QueuedPrompts(sessionID string) int
	ClearQueue(sessionID string)
}

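// agent is the default Service implementation. It wires the configured
// provider, session, message, history, and permission services together and
// tracks in-flight requests and queued prompts per session.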
type agent struct {
	*pubsub.Broker[AgentEvent]
	cfg            *config.Config
	permissions    permission.Service
	sessions       session.Service
	messages       message.Service
	history        history.Service
	lspClients     map[string]*lsp.Client
	activeRequests *csync.Map[string, context.CancelFunc]

	promptQueue *csync.Map[string, []string]
}

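// AgentOption is a function that configures an agent.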
type AgentOption = func(*agent)

// NewAgent constructs the agent service. Still a work in progress.
func NewAgent(
	cfg *config.Config,
	permissions permission.Service,
	sessions session.Service,
	messages message.Service,
	history history.Service,
	lspClients map[string]*lsp.Client,
) Service {
	return &agent{
		cfg:            cfg,
		Broker:         pubsub.NewBroker[AgentEvent](),
		permissions:    permissions,
		sessions:       sessions,
		messages:       messages,
		history:        history,
		lspClients:     lspClients,
		activeRequests: csync.NewMap[string, context.CancelFunc](),
		promptQueue:    csync.NewMap[string, []string](),
	}
}

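// getLanguageModel resolves the requested provider and model from the
// configuration and returns a language model client. Only the OpenAI
// provider is currently wired up.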
func (a *agent) getLanguageModel(providerName, modelID string) (ai.LanguageModel, error) {
	var provider ai.Provider
	providerCfg, ok := a.cfg.Providers.Get(providerName)
	if !ok {
		return nil, fmt.Errorf("provider `%s` not found", providerName)
	}

	models := providerCfg.Models
	foundModel := false
	for _, providerModel := range models {
		if providerModel.ID == modelID {
			foundModel = true
			break
		}
	}
	if !foundModel {
		return nil, fmt.Errorf("model `%s` not found in provider `%s`", modelID, providerName)
	}
	switch providerName {
	case "openai":
		apiKey, err := a.cfg.Resolve(providerCfg.APIKey)
		if err != nil {
			return nil, err
		}
		baseURL, err := a.cfg.Resolve(providerCfg.BaseURL)
		if err != nil {
			return nil, err
		}
		opts := []providers.OpenAiOption{
			providers.WithOpenAiAPIKey(apiKey),
		}
		if baseURL != "" {
			opts = append(opts, providers.WithOpenAiBaseURL(baseURL))
		}
		provider = providers.NewOpenAiProvider(opts...)
	default:
		return nil, fmt.Errorf("provider `%s` is not supported", providerName)
	}
	if provider == nil {
		return nil, fmt.Errorf("failed to initialize provider `%s`", providerName)
	}
	return provider.LanguageModel(modelID)
}

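// tools assembles the tool set exposed to the model: the built-in file,
// shell, and search tools, any configured MCP tools, and a diagnostics tool
// when LSP clients are available.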
func (a *agent) tools(ctx context.Context) []ai.AgentTool {
	cwd := a.cfg.WorkingDir()
	allTools := []ai.AgentTool{
		tools.NewBashTool(a.permissions, cwd),
		tools.NewDownloadTool(a.permissions, cwd),
		tools.NewEditTool(a.lspClients, a.permissions, a.history, cwd),
		tools.NewMultiEditTool(a.lspClients, a.permissions, a.history, cwd),
		tools.NewFetchTool(a.permissions, cwd),
		tools.NewGlobTool(cwd),
		tools.NewGrepTool(cwd),
		tools.NewLSTool(a.permissions, cwd),
		tools.NewSourcegraphTool(),
		tools.NewViewTool(a.lspClients, a.permissions, cwd),
		tools.NewWriteTool(a.lspClients, a.permissions, a.history, cwd),
	}
	mcpTools := tools.GetMCPTools(ctx, a.permissions, a.cfg)

	allTools = append(allTools, mcpTools...)

	if len(a.lspClients) > 0 {
		allTools = append(allTools, tools.NewDiagnosticsTool(a.lspClients))
	}
	// TODO: add agent tool
	return allTools
}

// Run implements Service.
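//
// If the session is already busy, the prompt is queued and Run returns a nil
// channel and a nil error. Otherwise it starts the generation in a goroutine
// and returns a buffered channel that receives a single terminal AgentEvent
// (response or error) before being closed.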
func (a *agent) Run(ctx context.Context, sessionID string, content string, attachments ...message.Attachment) (<-chan AgentEvent, error) {
	// INFO: for now we assume that the agent uses the large model
	configModel := a.cfg.Models[config.SelectedModelTypeLarge]
	model, err := a.getLanguageModel(configModel.Provider, configModel.Model)
	if err != nil {
		return nil, err
	}

	modelCfg := a.Model()
	maxTokens := configModel.MaxTokens
	if maxTokens == 0 {
		maxTokens = modelCfg.DefaultMaxTokens
	}

	if !modelCfg.SupportsImages && attachments != nil {
		attachments = nil
	}

	agent := ai.NewAgent(
		model,
		ai.WithSystemPrompt(
			prompt.CoderPrompt(configModel.Provider, a.cfg.Options.ContextPaths...),
		),
		ai.WithTools(a.tools(ctx)...),
		ai.WithMaxOutputTokens(maxTokens),
	)

	events := make(chan AgentEvent, 1)
	if a.IsSessionBusy(sessionID) {
		existing, ok := a.promptQueue.Get(sessionID)
		if !ok {
			existing = []string{}
		}
		existing = append(existing, content)
		a.promptQueue.Set(sessionID, existing)
		return nil, nil
	}

	genCtx, cancel := context.WithCancel(ctx)
	a.activeRequests.Set(sessionID, cancel)

	go func() {
		slog.Debug("Request started", "sessionID", sessionID)

		result, err := a.makeCall(genCtx, agent, sessionID, content, attachments)
		a.activeRequests.Del(sessionID)
		cancel()
		if err != nil {
			slog.Error(err.Error())
			events <- AgentEvent{
				Type:  AgentEventTypeError,
				Error: err,
			}
		} else {
			event := AgentEvent{
				Type:   AgentEventTypeResponse,
				Result: *result,
			}
			a.Publish(pubsub.CreatedEvent, event)
			events <- event
		}
		slog.Debug("Request completed", "sessionID", sessionID)
		// TODO: implement this
		close(events)
	}()
	return events, nil
}

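// makeCall persists the user message, rebuilds the conversation history
// (truncated at the last summary message, if any), and streams the agent
// run, mirroring each streamed part into a persisted assistant message as it
// arrives. On failure it finalizes the assistant message with an appropriate
// finish reason and synthetic tool results for any unfinished tool calls.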
func (a *agent) makeCall(ctx context.Context, agent ai.Agent, sessionID, prompt string, attachments []message.Attachment) (*ai.AgentResult, error) {
	msgs, err := a.messages.List(ctx, sessionID)
	if err != nil {
		return nil, fmt.Errorf("failed to list messages: %w", err)
	}
	if len(msgs) == 0 {
		go func() {
			// TODO: generate title
			// titleErr := a.generateTitle(context.Background(), sessionID, content)
			// if titleErr != nil && !errors.Is(titleErr, context.Canceled) && !errors.Is(titleErr, context.DeadlineExceeded) {
			// 	slog.Error("failed to generate title", "error", titleErr)
			// }
		}()
	}
	session, err := a.sessions.Get(ctx, sessionID)
	if err != nil {
		return nil, fmt.Errorf("failed to get session: %w", err)
	}
	if session.SummaryMessageID != "" {
		summaryMsgIndex := -1
		for i, msg := range msgs {
			if msg.ID == session.SummaryMessageID {
				summaryMsgIndex = i
				break
			}
		}
		if summaryMsgIndex != -1 {
			msgs = msgs[summaryMsgIndex:]
			msgs[0].Role = message.User
		}
	}

	// Create the user message
	var attachmentParts []message.ContentPart
	for _, attachment := range attachments {
		attachmentParts = append(attachmentParts, message.BinaryContent{Path: attachment.FilePath, MIMEType: attachment.MimeType, Data: attachment.Content})
	}
	parts := []message.ContentPart{message.TextContent{Text: prompt}}
	parts = append(parts, attachmentParts...)
	_, err = a.messages.Create(ctx, sessionID, message.CreateMessageParams{
		Role:  message.User,
		Parts: parts,
	})
	if err != nil {
		return nil, fmt.Errorf("failed to create user message: %w", err)
	}

	var msgHistory []ai.Message
	for _, m := range msgs {
		if len(m.Parts) == 0 {
			continue
		}
		// Assistant message without content or tool calls (canceled before it returned anything)
		if m.Role == message.Assistant && len(m.ToolCalls()) == 0 && m.Content().Text == "" && m.ReasoningContent().String() == "" {
			continue
		}
		msgHistory = append(msgHistory, m.ToAIMessage()...)
	}

	var files []ai.FilePart
	for _, attachment := range attachments {
		files = append(files, ai.FilePart{
			Filename:  attachment.FileName,
			Data:      attachment.Content,
			MediaType: attachment.MimeType,
		})
	}
	ctx = context.WithValue(ctx, tools.SessionIDContextKey, sessionID)
	// TODO: see if this is even needed
	ctx = context.WithValue(ctx, tools.MessageIDContextKey, "mock")

	var currentAssistant *message.Message
	result, err := agent.Stream(ctx, ai.AgentStreamCall{
		Prompt:   prompt,
		Files:    files,
		Messages: msgHistory,
		// Gets called before each step
		PrepareStep: func(options ai.PrepareStepFunctionOptions) (ai.PrepareStepResult, error) {
			prepared := ai.PrepareStepResult{}
			modelCfg := a.cfg.Models[config.SelectedModelTypeLarge]
			// Before each step create the new assistant message
			assistantMsg, createErr := a.messages.Create(ctx, sessionID, message.CreateMessageParams{
				Role:     message.Assistant,
				Parts:    []message.ContentPart{},
				Model:    modelCfg.Model,
				Provider: modelCfg.Provider,
			})
			if createErr != nil {
				return prepared, createErr
			}
			currentAssistant = &assistantMsg
			return prepared, nil
		},
		OnChunk: func(chunk ai.StreamPart) error {
			data, _ := json.Marshal(chunk)
			slog.Info("\n" + string(data) + "\n")
			return nil
		},
		// TODO: see how to not swallow the errors on these handlers
		OnReasoningDelta: func(id string, text string) error {
			currentAssistant.AppendReasoningContent(text)
			return a.messages.Update(ctx, *currentAssistant)
		},
		OnReasoningEnd: func(id string, reasoning ai.ReasoningContent) error {
			// Handle the Anthropic thinking signature, if present
			if anthropicData, ok := reasoning.ProviderMetadata["anthropic"]; ok {
				if signature, ok := anthropicData["signature"]; ok {
					if str, ok := signature.(string); ok {
						currentAssistant.AppendReasoningSignature(str)
					}
				}
			}
			currentAssistant.FinishThinking()
			return a.messages.Update(ctx, *currentAssistant)
		},
		OnTextDelta: func(id string, text string) error {
			currentAssistant.AppendContent(text)
			return a.messages.Update(ctx, *currentAssistant)
		},
		OnToolInputStart: func(id string, toolName string) error {
			toolCall := message.ToolCall{
				ID:               id,
				Name:             toolName,
				ProviderExecuted: false,
				Finished:         false,
			}
			slog.Info("Tool call started", "toolCall", toolName)
			currentAssistant.AddToolCall(toolCall)
			return a.messages.Update(ctx, *currentAssistant)
		},
		OnToolCall: func(tc ai.ToolCallContent) error {
			toolCall := message.ToolCall{
				ID:               tc.ToolCallID,
				Name:             tc.ToolName,
				Input:            tc.Input,
				ProviderExecuted: false,
				Finished:         true,
			}
			currentAssistant.AddToolCall(toolCall)
			return a.messages.Update(ctx, *currentAssistant)
		},
		OnToolResult: func(result ai.ToolResultContent) error {
			var resultContent string
			isError := false
			switch result.Result.GetType() {
			case ai.ToolResultContentTypeText:
				r, ok := ai.AsToolResultOutputType[ai.ToolResultOutputContentText](result.Result)
				if ok {
					resultContent = r.Text
				}
			case ai.ToolResultContentTypeError:
				r, ok := ai.AsToolResultOutputType[ai.ToolResultOutputContentError](result.Result)
				if ok {
					isError = true
					resultContent = r.Error.Error()
				}
			case ai.ToolResultContentTypeMedia:
				// TODO: handle this message type
			}
			toolResult := message.ToolResult{
				ToolCallID: result.ToolCallID,
				Name:       result.ToolName,
				Content:    resultContent,
				IsError:    isError,
				Metadata:   result.ClientMetadata,
			}
			currentAssistant.AddToolResult(toolResult)
			return a.messages.Update(ctx, *currentAssistant)
		},
		OnStepFinish: func(stepResult ai.StepResult) error {
			slog.Info("Step Finished", "result", stepResult)
			finishReason := message.FinishReasonUnknown
			switch stepResult.FinishReason {
			case ai.FinishReasonLength:
				finishReason = message.FinishReasonMaxTokens
			case ai.FinishReasonStop:
				finishReason = message.FinishReasonEndTurn
			case ai.FinishReasonToolCalls:
				finishReason = message.FinishReasonToolUse
			}
			currentAssistant.AddFinish(finishReason, "", "")
			return a.messages.Update(ctx, *currentAssistant)
		},
	})
	if err != nil {
		isCancelErr := errors.Is(err, context.Canceled)
		isPermissionErr := errors.Is(err, permission.ErrorPermissionDenied)
		if currentAssistant == nil {
			return result, err
		}
		toolCalls := currentAssistant.ToolCalls()
		toolResults := currentAssistant.ToolResults()
		for _, tc := range toolCalls {
			if !tc.Finished {
				tc.Finished = true
				tc.Input = "{}"
			}
			currentAssistant.AddToolCall(tc)
			found := false
			for _, tr := range toolResults {
				if tr.ToolCallID == tc.ID {
					found = true
					break
				}
			}
			if !found {
				content := "There was an error while executing the tool"
				if isCancelErr {
					content = "Tool execution canceled by user"
				} else if isPermissionErr {
					content = "Permission denied"
				}
				currentAssistant.AddToolResult(message.ToolResult{
					ToolCallID: tc.ID,
					Name:       tc.Name,
					Content:    content,
					IsError:    true,
				})
			}
		}
		if isCancelErr {
			currentAssistant.AddFinish(message.FinishReasonCanceled, "Request cancelled", "")
		} else if isPermissionErr {
			currentAssistant.AddFinish(message.FinishReasonPermissionDenied, "Permission denied", "")
		} else {
			currentAssistant.AddFinish(message.FinishReasonError, "API Error", err.Error())
		}
		// TODO: handle error?
		_ = a.messages.Update(context.Background(), *currentAssistant)
	}
	return result, err
}

// Summarize implements Service.
func (a *agent) Summarize(ctx context.Context, sessionID string) error {
	// TODO: implement
	return nil
}

// UpdateModel implements Service.
func (a *agent) UpdateModel() error {
	return nil
}

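// Cancel aborts any in-flight request for the session (including a pending
// summarization) and drops its queued prompts.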
func (a *agent) Cancel(sessionID string) {
	// Cancel regular requests
	if cancel, ok := a.activeRequests.Take(sessionID); ok && cancel != nil {
		slog.Info("Request cancellation initiated", "session_id", sessionID)
		cancel()
	}

	// Also check for summarize requests
	if cancel, ok := a.activeRequests.Take(sessionID + "-summarize"); ok && cancel != nil {
		slog.Info("Summarize cancellation initiated", "session_id", sessionID)
		cancel()
	}

	if a.QueuedPrompts(sessionID) > 0 {
		slog.Info("Clearing queued prompts", "session_id", sessionID)
		a.promptQueue.Del(sessionID)
	}
}

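// ClearQueue drops any prompts queued for the session.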
func (a *agent) ClearQueue(sessionID string) {
	if a.QueuedPrompts(sessionID) > 0 {
		slog.Info("Clearing queued prompts", "session_id", sessionID)
		a.promptQueue.Del(sessionID)
	}
}

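// CancelAll cancels every active request and waits up to five seconds for
// all sessions to become idle.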
func (a *agent) CancelAll() {
	if !a.IsBusy() {
		return
	}
	for key := range a.activeRequests.Seq2() {
		a.Cancel(key) // key is sessionID
	}

	timeout := time.After(5 * time.Second)
	for a.IsBusy() {
		select {
		case <-timeout:
			return
		default:
			time.Sleep(200 * time.Millisecond)
		}
	}
}

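// IsBusy reports whether any session currently has an active request.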
func (a *agent) IsBusy() bool {
	var busy bool
	for cancelFunc := range a.activeRequests.Seq() {
		if cancelFunc != nil {
			busy = true
			break
		}
	}
	return busy
}

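// IsSessionBusy reports whether the given session has an in-flight request.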
func (a *agent) IsSessionBusy(sessionID string) bool {
	_, busy := a.activeRequests.Get(sessionID)
	return busy
}

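// Model returns the currently selected large model.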
func (a *agent) Model() catwalk.Model {
	return *a.cfg.GetModelByType(config.SelectedModelTypeLarge)
}

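// QueuedPrompts returns the number of prompts queued for the session.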
func (a *agent) QueuedPrompts(sessionID string) int {
	l, ok := a.promptQueue.Get(sessionID)
	if !ok {
		return 0
	}
	return len(l)
}