1// Package agent contains the implementation of the AI agent service.
2package agent
3
4import (
5 "context"
6 "errors"
7 "fmt"
8 "log/slog"
9 "maps"
10 "slices"
11 "strings"
12 "time"
13
14 "github.com/charmbracelet/catwalk/pkg/catwalk"
15 "github.com/charmbracelet/crush/internal/config"
16 "github.com/charmbracelet/crush/internal/csync"
17 "github.com/charmbracelet/crush/internal/event"
18 "github.com/charmbracelet/crush/internal/history"
19 "github.com/charmbracelet/crush/internal/llm/prompt"
20 "github.com/charmbracelet/crush/internal/llm/provider"
21 "github.com/charmbracelet/crush/internal/llm/tools"
22 "github.com/charmbracelet/crush/internal/log"
23 "github.com/charmbracelet/crush/internal/lsp"
24 "github.com/charmbracelet/crush/internal/message"
25 "github.com/charmbracelet/crush/internal/permission"
26 "github.com/charmbracelet/crush/internal/pubsub"
27 "github.com/charmbracelet/crush/internal/session"
28 "github.com/charmbracelet/crush/internal/shell"
29)
30
// AgentEventType identifies the kind of AgentEvent an agent publishes to its
// subscribers and sends on the channel returned by Run.
type AgentEventType string

const (
	// AgentEventTypeError reports a failure; AgentEvent.Error holds the cause.
	AgentEventTypeError AgentEventType = "error"
	// AgentEventTypeResponse carries the assistant message produced by Run.
	AgentEventTypeResponse AgentEventType = "response"
	// AgentEventTypeSummarize reports summarization progress; the last one
	// has Done set.
	AgentEventTypeSummarize AgentEventType = "summarize"
)
38
// AgentEvent is the payload published by the agent and delivered on the
// channel returned by Run.
type AgentEvent struct {
	// Type selects which of the remaining fields are meaningful.
	Type AgentEventType
	// Message is the assistant message for AgentEventTypeResponse events.
	Message message.Message
	// Error is the failure cause for AgentEventTypeError events.
	Error error

	// SessionID and Progress are filled in by Summarize. Done marks the final
	// event of a Summarize run (success or error) and is also set on the
	// response event returned by Run.
	SessionID string
	Progress  string
	Done      bool
}
49
// Service is the interface through which the rest of the application drives
// an AI agent.
type Service interface {
	// Subscribers receive the AgentEvents the agent publishes.
	pubsub.Suscriber[AgentEvent]
	// Model returns the model currently configured for the agent.
	Model() catwalk.Model
	// Run sends content (and optional attachments) to the model in the given
	// session and returns a channel carrying the request's final event. In
	// the implementation in this package, when the session is already busy
	// the content is queued instead and (nil, nil) is returned.
	Run(ctx context.Context, sessionID string, content string, attachments ...message.Attachment) (<-chan AgentEvent, error)
	// Cancel stops the session's in-flight request and summarization and
	// discards its queued prompts.
	Cancel(sessionID string)
	// CancelAll cancels every in-flight request and waits briefly for them
	// to stop.
	CancelAll()
	// IsSessionBusy reports whether a request is in flight for the session.
	IsSessionBusy(sessionID string) bool
	// IsBusy reports whether any request is in flight.
	IsBusy() bool
	// Summarize starts summarizing the session in the background; progress
	// and the outcome are published as AgentEvents.
	Summarize(ctx context.Context, sessionID string) error
	// UpdateModel refreshes the agent's providers from the current
	// configuration.
	UpdateModel() error
	// QueuedPrompts returns how many prompts are queued for the session.
	QueuedPrompts(sessionID string) int
	// ClearQueue discards the prompts queued for the session.
	ClearQueue(sessionID string)
}
63
// agent is the Service implementation. The embedded broker is used to
// publish AgentEvents to subscribers.
type agent struct {
	*pubsub.Broker[AgentEvent]
	agentCfg    config.Agent
	sessions    session.Service
	messages    message.Service
	permissions permission.Service
	// baseTools and mcpTools are created with csync.NewLazyMap in NewAgent.
	baseTools  *csync.Map[string, tools.BaseTool]
	mcpTools   *csync.Map[string, tools.BaseTool]
	lspClients *csync.Map[string, *lsp.Client]

	// We need this to be able to update it when model changes
	agentToolFn func() (tools.BaseTool, error)
	// cleanupFuncs are invoked by CancelAll.
	// NOTE(review): nothing in this part of the file appends to it — confirm
	// where it is populated.
	cleanupFuncs []func()

	// provider is the main model client; providerID is its provider config ID.
	provider   provider.Provider
	providerID string

	// titleProvider uses the small model; summarizeProvider uses the large
	// model on the main provider's configuration.
	titleProvider       provider.Provider
	summarizeProvider   provider.Provider
	summarizeProviderID string

	// activeRequests maps a session ID (or "<sessionID>-summarize") to the
	// cancel func of its in-flight work.
	activeRequests *csync.Map[string, context.CancelFunc]
	// promptQueue holds prompts submitted via Run while the session was busy.
	promptQueue *csync.Map[string, []string]
}
88
// agentPromptMap selects the system prompt for an agent by its config ID.
// IDs missing from the map fall back to prompt.PromptDefault.
var agentPromptMap = map[string]prompt.PromptID{
	"coder": prompt.PromptCoder,
	"task":  prompt.PromptTask,
}
93
94func NewAgent(
95 ctx context.Context,
96 agentCfg config.Agent,
97 // These services are needed in the tools
98 permissions permission.Service,
99 sessions session.Service,
100 messages message.Service,
101 history history.Service,
102 lspClients *csync.Map[string, *lsp.Client],
103) (Service, error) {
104 cfg := config.Get()
105
106 var agentToolFn func() (tools.BaseTool, error)
107 if agentCfg.ID == "coder" && slices.Contains(agentCfg.AllowedTools, AgentToolName) {
108 agentToolFn = func() (tools.BaseTool, error) {
109 taskAgentCfg := config.Get().Agents["task"]
110 if taskAgentCfg.ID == "" {
111 return nil, fmt.Errorf("task agent not found in config")
112 }
113 taskAgent, err := NewAgent(ctx, taskAgentCfg, permissions, sessions, messages, history, lspClients)
114 if err != nil {
115 return nil, fmt.Errorf("failed to create task agent: %w", err)
116 }
117 return NewAgentTool(taskAgent, sessions, messages), nil
118 }
119 }
120
121 providerCfg := config.Get().GetProviderForModel(agentCfg.Model)
122 if providerCfg == nil {
123 return nil, fmt.Errorf("provider for agent %s not found in config", agentCfg.Name)
124 }
125 model := config.Get().GetModelByType(agentCfg.Model)
126
127 if model == nil {
128 return nil, fmt.Errorf("model not found for agent %s", agentCfg.Name)
129 }
130
131 promptID := agentPromptMap[agentCfg.ID]
132 if promptID == "" {
133 promptID = prompt.PromptDefault
134 }
135 opts := []provider.ProviderClientOption{
136 provider.WithModel(agentCfg.Model),
137 provider.WithSystemMessage(prompt.GetPrompt(promptID, providerCfg.ID, config.Get().Options.ContextPaths...)),
138 }
139 agentProvider, err := provider.NewProvider(*providerCfg, opts...)
140 if err != nil {
141 return nil, err
142 }
143
144 smallModelCfg := cfg.Models[config.SelectedModelTypeSmall]
145 var smallModelProviderCfg *config.ProviderConfig
146 if smallModelCfg.Provider == providerCfg.ID {
147 smallModelProviderCfg = providerCfg
148 } else {
149 smallModelProviderCfg = cfg.GetProviderForModel(config.SelectedModelTypeSmall)
150
151 if smallModelProviderCfg.ID == "" {
152 return nil, fmt.Errorf("provider %s not found in config", smallModelCfg.Provider)
153 }
154 }
155 smallModel := cfg.GetModelByType(config.SelectedModelTypeSmall)
156 if smallModel.ID == "" {
157 return nil, fmt.Errorf("model %s not found in provider %s", smallModelCfg.Model, smallModelProviderCfg.ID)
158 }
159
160 titleOpts := []provider.ProviderClientOption{
161 provider.WithModel(config.SelectedModelTypeSmall),
162 provider.WithSystemMessage(prompt.GetPrompt(prompt.PromptTitle, smallModelProviderCfg.ID)),
163 }
164 titleProvider, err := provider.NewProvider(*smallModelProviderCfg, titleOpts...)
165 if err != nil {
166 return nil, err
167 }
168
169 summarizeOpts := []provider.ProviderClientOption{
170 provider.WithModel(config.SelectedModelTypeLarge),
171 provider.WithSystemMessage(prompt.GetPrompt(prompt.PromptSummarizer, providerCfg.ID)),
172 }
173 summarizeProvider, err := provider.NewProvider(*providerCfg, summarizeOpts...)
174 if err != nil {
175 return nil, err
176 }
177
178 baseToolsFn := func() map[string]tools.BaseTool {
179 slog.Debug("Initializing agent base tools", "agent", agentCfg.ID)
180 defer func() {
181 slog.Debug("Initialized agent base tools", "agent", agentCfg.ID)
182 }()
183
184 // Base tools available to all agents
185 cwd := cfg.WorkingDir()
186 result := make(map[string]tools.BaseTool)
187 for _, tool := range []tools.BaseTool{
188 tools.NewBashTool(permissions, cwd, cfg.Options.Attribution),
189 tools.NewDownloadTool(permissions, cwd),
190 tools.NewEditTool(lspClients, permissions, history, cwd),
191 tools.NewMultiEditTool(lspClients, permissions, history, cwd),
192 tools.NewFetchTool(permissions, cwd),
193 tools.NewGlobTool(cwd),
194 tools.NewGrepTool(cwd),
195 tools.NewLsTool(permissions, cwd),
196 tools.NewSourcegraphTool(),
197 tools.NewViewTool(lspClients, permissions, cwd),
198 tools.NewWriteTool(lspClients, permissions, history, cwd),
199 } {
200 result[tool.Name()] = tool
201 }
202 return result
203 }
204 mcpToolsFn := func() map[string]tools.BaseTool {
205 slog.Debug("Initializing agent mcp tools", "agent", agentCfg.ID)
206 defer func() {
207 slog.Debug("Initialized agent mcp tools", "agent", agentCfg.ID)
208 }()
209
210 mcpToolsOnce.Do(func() {
211 doGetMCPTools(ctx, permissions, cfg)
212 })
213
214 return maps.Collect(mcpTools.Seq2())
215 }
216
217 a := &agent{
218 Broker: pubsub.NewBroker[AgentEvent](),
219 agentCfg: agentCfg,
220 provider: agentProvider,
221 providerID: string(providerCfg.ID),
222 messages: messages,
223 sessions: sessions,
224 titleProvider: titleProvider,
225 summarizeProvider: summarizeProvider,
226 summarizeProviderID: string(providerCfg.ID),
227 agentToolFn: agentToolFn,
228 activeRequests: csync.NewMap[string, context.CancelFunc](),
229 mcpTools: csync.NewLazyMap(mcpToolsFn),
230 baseTools: csync.NewLazyMap(baseToolsFn),
231 promptQueue: csync.NewMap[string, []string](),
232 permissions: permissions,
233 lspClients: lspClients,
234 }
235 a.setupEvents(ctx)
236 return a, nil
237}
238
239func (a *agent) Model() catwalk.Model {
240 return *config.Get().GetModelByType(a.agentCfg.Model)
241}
242
243func (a *agent) Cancel(sessionID string) {
244 // Cancel regular requests
245 if cancel, ok := a.activeRequests.Take(sessionID); ok && cancel != nil {
246 slog.Info("Request cancellation initiated", "session_id", sessionID)
247 cancel()
248 }
249
250 // Also check for summarize requests
251 if cancel, ok := a.activeRequests.Take(sessionID + "-summarize"); ok && cancel != nil {
252 slog.Info("Summarize cancellation initiated", "session_id", sessionID)
253 cancel()
254 }
255
256 if a.QueuedPrompts(sessionID) > 0 {
257 slog.Info("Clearing queued prompts", "session_id", sessionID)
258 a.promptQueue.Del(sessionID)
259 }
260}
261
262func (a *agent) IsBusy() bool {
263 var busy bool
264 for cancelFunc := range a.activeRequests.Seq() {
265 if cancelFunc != nil {
266 busy = true
267 break
268 }
269 }
270 return busy
271}
272
273func (a *agent) IsSessionBusy(sessionID string) bool {
274 _, busy := a.activeRequests.Get(sessionID)
275 return busy
276}
277
278func (a *agent) QueuedPrompts(sessionID string) int {
279 l, ok := a.promptQueue.Get(sessionID)
280 if !ok {
281 return 0
282 }
283 return len(l)
284}
285
286func (a *agent) generateTitle(ctx context.Context, sessionID string, content string) error {
287 if content == "" {
288 return nil
289 }
290 if a.titleProvider == nil {
291 return nil
292 }
293 session, err := a.sessions.Get(ctx, sessionID)
294 if err != nil {
295 return err
296 }
297 parts := []message.ContentPart{message.TextContent{
298 Text: fmt.Sprintf("Generate a concise title for the following content:\n\n%s", content),
299 }}
300
301 // Use streaming approach like summarization
302 response := a.titleProvider.StreamResponse(
303 ctx,
304 []message.Message{
305 {
306 Role: message.User,
307 Parts: parts,
308 },
309 },
310 nil,
311 )
312
313 var finalResponse *provider.ProviderResponse
314 for r := range response {
315 if r.Error != nil {
316 return r.Error
317 }
318 finalResponse = r.Response
319 }
320
321 if finalResponse == nil {
322 return fmt.Errorf("no response received from title provider")
323 }
324
325 title := strings.ReplaceAll(finalResponse.Content, "\n", " ")
326
327 if idx := strings.Index(title, "</think>"); idx > 0 {
328 title = title[idx+len("</think>"):]
329 }
330
331 title = strings.TrimSpace(title)
332 if title == "" {
333 return nil
334 }
335
336 session.Title = title
337 _, err = a.sessions.Save(ctx, session)
338 return err
339}
340
341func (a *agent) err(err error) AgentEvent {
342 return AgentEvent{
343 Type: AgentEventTypeError,
344 Error: err,
345 }
346}
347
348func (a *agent) Run(ctx context.Context, sessionID string, content string, attachments ...message.Attachment) (<-chan AgentEvent, error) {
349 events := make(chan AgentEvent, 1)
350 if a.IsSessionBusy(sessionID) {
351 existing, ok := a.promptQueue.Get(sessionID)
352 if !ok {
353 existing = []string{}
354 }
355 existing = append(existing, content)
356 a.promptQueue.Set(sessionID, existing)
357 return nil, nil
358 }
359
360 genCtx, cancel := context.WithCancel(ctx)
361 a.activeRequests.Set(sessionID, cancel)
362 startTime := time.Now()
363
364 go func() {
365 slog.Debug("Request started", "sessionID", sessionID)
366 defer log.RecoverPanic("agent.Run", func() {
367 events <- a.err(fmt.Errorf("panic while running the agent"))
368 })
369 var attachmentParts []message.ContentPart
370 for _, attachment := range attachments {
371 if !a.Model().SupportsImages && strings.HasPrefix(attachment.MimeType, "image/") {
372 slog.Warn("Model does not support images, skipping attachment", "mimeType", attachment.MimeType, "fileName", attachment.FileName)
373 continue
374 }
375 attachmentParts = append(attachmentParts, message.BinaryContent{
376 Path: attachment.FilePath,
377 MIMEType: attachment.MimeType,
378 Data: attachment.Content,
379 })
380 }
381 result := a.processGeneration(genCtx, sessionID, content, attachmentParts)
382 if result.Error != nil {
383 if isCancelledErr(result.Error) {
384 slog.Error("Request canceled", "sessionID", sessionID)
385 } else {
386 slog.Error("Request errored", "sessionID", sessionID, "error", result.Error.Error())
387 event.Error(result.Error)
388 }
389 } else {
390 slog.Debug("Request completed", "sessionID", sessionID)
391 }
392 a.eventPromptResponded(sessionID, time.Since(startTime).Truncate(time.Second))
393 a.activeRequests.Del(sessionID)
394 cancel()
395 a.Publish(pubsub.CreatedEvent, result)
396 events <- result
397 close(events)
398 }()
399 a.eventPromptSent(sessionID)
400 return events, nil
401}
402
403func (a *agent) processGeneration(ctx context.Context, sessionID, content string, attachmentParts []message.ContentPart) AgentEvent {
404 cfg := config.Get()
405 // List existing messages; if none, start title generation asynchronously.
406 msgs, err := a.messages.List(ctx, sessionID)
407 if err != nil {
408 return a.err(fmt.Errorf("failed to list messages: %w", err))
409 }
410 if len(msgs) == 0 {
411 go func() {
412 defer log.RecoverPanic("agent.Run", func() {
413 slog.Error("panic while generating title")
414 })
415 titleErr := a.generateTitle(ctx, sessionID, content)
416 if titleErr != nil && !errors.Is(titleErr, context.Canceled) && !errors.Is(titleErr, context.DeadlineExceeded) {
417 slog.Error("failed to generate title", "error", titleErr)
418 }
419 }()
420 }
421 session, err := a.sessions.Get(ctx, sessionID)
422 if err != nil {
423 return a.err(fmt.Errorf("failed to get session: %w", err))
424 }
425 if session.SummaryMessageID != "" {
426 summaryMsgInex := -1
427 for i, msg := range msgs {
428 if msg.ID == session.SummaryMessageID {
429 summaryMsgInex = i
430 break
431 }
432 }
433 if summaryMsgInex != -1 {
434 msgs = msgs[summaryMsgInex:]
435 msgs[0].Role = message.User
436 }
437 }
438
439 userMsg, err := a.createUserMessage(ctx, sessionID, content, attachmentParts)
440 if err != nil {
441 return a.err(fmt.Errorf("failed to create user message: %w", err))
442 }
443 // Append the new user message to the conversation history.
444 msgHistory := append(msgs, userMsg)
445
446 for {
447 // Check for cancellation before each iteration
448 select {
449 case <-ctx.Done():
450 return a.err(ctx.Err())
451 default:
452 // Continue processing
453 }
454 agentMessage, toolResults, err := a.streamAndHandleEvents(ctx, sessionID, msgHistory)
455 if err != nil {
456 if errors.Is(err, context.Canceled) {
457 agentMessage.AddFinish(message.FinishReasonCanceled, "Request cancelled", "")
458 a.messages.Update(context.Background(), agentMessage)
459 return a.err(ErrRequestCancelled)
460 }
461 return a.err(fmt.Errorf("failed to process events: %w", err))
462 }
463 if cfg.Options.Debug {
464 slog.Info("Result", "message", agentMessage.FinishReason(), "toolResults", toolResults)
465 }
466 if (agentMessage.FinishReason() == message.FinishReasonToolUse) && toolResults != nil {
467 // We are not done, we need to respond with the tool response
468 msgHistory = append(msgHistory, agentMessage, *toolResults)
469 // If there are queued prompts, process the next one
470 nextPrompt, ok := a.promptQueue.Take(sessionID)
471 if ok {
472 for _, prompt := range nextPrompt {
473 // Create a new user message for the queued prompt
474 userMsg, err := a.createUserMessage(ctx, sessionID, prompt, nil)
475 if err != nil {
476 return a.err(fmt.Errorf("failed to create user message for queued prompt: %w", err))
477 }
478 // Append the new user message to the conversation history
479 msgHistory = append(msgHistory, userMsg)
480 }
481 }
482
483 continue
484 } else if agentMessage.FinishReason() == message.FinishReasonEndTurn {
485 queuePrompts, ok := a.promptQueue.Take(sessionID)
486 if ok {
487 for _, prompt := range queuePrompts {
488 if prompt == "" {
489 continue
490 }
491 userMsg, err := a.createUserMessage(ctx, sessionID, prompt, nil)
492 if err != nil {
493 return a.err(fmt.Errorf("failed to create user message for queued prompt: %w", err))
494 }
495 msgHistory = append(msgHistory, userMsg)
496 }
497 continue
498 }
499 }
500 if agentMessage.FinishReason() == "" {
501 // Kujtim: could not track down where this is happening but this means its cancelled
502 agentMessage.AddFinish(message.FinishReasonCanceled, "Request cancelled", "")
503 _ = a.messages.Update(context.Background(), agentMessage)
504 return a.err(ErrRequestCancelled)
505 }
506 return AgentEvent{
507 Type: AgentEventTypeResponse,
508 Message: agentMessage,
509 Done: true,
510 }
511 }
512}
513
514func (a *agent) createUserMessage(ctx context.Context, sessionID, content string, attachmentParts []message.ContentPart) (message.Message, error) {
515 parts := []message.ContentPart{message.TextContent{Text: content}}
516 parts = append(parts, attachmentParts...)
517 return a.messages.Create(ctx, sessionID, message.CreateMessageParams{
518 Role: message.User,
519 Parts: parts,
520 })
521}
522
523func (a *agent) getAllTools() ([]tools.BaseTool, error) {
524 var allTools []tools.BaseTool
525 for tool := range a.baseTools.Seq() {
526 if a.agentCfg.AllowedTools == nil || slices.Contains(a.agentCfg.AllowedTools, tool.Name()) {
527 allTools = append(allTools, tool)
528 }
529 }
530 if a.agentCfg.ID == "coder" {
531 allTools = slices.AppendSeq(allTools, a.mcpTools.Seq())
532 if a.lspClients.Len() > 0 {
533 allTools = append(allTools, tools.NewDiagnosticsTool(a.lspClients))
534 }
535 }
536 if a.agentToolFn != nil {
537 agentTool, agentToolErr := a.agentToolFn()
538 if agentToolErr != nil {
539 return nil, agentToolErr
540 }
541 allTools = append(allTools, agentTool)
542 }
543 return allTools, nil
544}
545
546func (a *agent) streamAndHandleEvents(ctx context.Context, sessionID string, msgHistory []message.Message) (message.Message, *message.Message, error) {
547 ctx = context.WithValue(ctx, tools.SessionIDContextKey, sessionID)
548
549 // Create the assistant message first so the spinner shows immediately
550 assistantMsg, err := a.messages.Create(ctx, sessionID, message.CreateMessageParams{
551 Role: message.Assistant,
552 Parts: []message.ContentPart{},
553 Model: a.Model().ID,
554 Provider: a.providerID,
555 })
556 if err != nil {
557 return assistantMsg, nil, fmt.Errorf("failed to create assistant message: %w", err)
558 }
559
560 allTools, toolsErr := a.getAllTools()
561 if toolsErr != nil {
562 return assistantMsg, nil, toolsErr
563 }
564 // Now collect tools (which may block on MCP initialization)
565 eventChan := a.provider.StreamResponse(ctx, msgHistory, allTools)
566
567 // Add the session and message ID into the context if needed by tools.
568 ctx = context.WithValue(ctx, tools.MessageIDContextKey, assistantMsg.ID)
569
570loop:
571 for {
572 select {
573 case event, ok := <-eventChan:
574 if !ok {
575 break loop
576 }
577 if processErr := a.processEvent(ctx, sessionID, &assistantMsg, event); processErr != nil {
578 if errors.Is(processErr, context.Canceled) {
579 a.finishMessage(context.Background(), &assistantMsg, message.FinishReasonCanceled, "Request cancelled", "")
580 } else {
581 a.finishMessage(ctx, &assistantMsg, message.FinishReasonError, "API Error", processErr.Error())
582 }
583 return assistantMsg, nil, processErr
584 }
585 case <-ctx.Done():
586 a.finishMessage(context.Background(), &assistantMsg, message.FinishReasonCanceled, "Request cancelled", "")
587 return assistantMsg, nil, ctx.Err()
588 }
589 }
590
591 toolResults := make([]message.ToolResult, len(assistantMsg.ToolCalls()))
592 toolCalls := assistantMsg.ToolCalls()
593 for i, toolCall := range toolCalls {
594 select {
595 case <-ctx.Done():
596 a.finishMessage(context.Background(), &assistantMsg, message.FinishReasonCanceled, "Request cancelled", "")
597 // Make all future tool calls cancelled
598 for j := i; j < len(toolCalls); j++ {
599 toolResults[j] = message.ToolResult{
600 ToolCallID: toolCalls[j].ID,
601 Content: "Tool execution canceled by user",
602 IsError: true,
603 }
604 }
605 goto out
606 default:
607 // Continue processing
608 var tool tools.BaseTool
609 allTools, _ = a.getAllTools()
610 for _, availableTool := range allTools {
611 if availableTool.Info().Name == toolCall.Name {
612 tool = availableTool
613 break
614 }
615 }
616
617 // Tool not found
618 if tool == nil {
619 toolResults[i] = message.ToolResult{
620 ToolCallID: toolCall.ID,
621 Content: fmt.Sprintf("Tool not found: %s", toolCall.Name),
622 IsError: true,
623 }
624 continue
625 }
626
627 // Run tool in goroutine to allow cancellation
628 type toolExecResult struct {
629 response tools.ToolResponse
630 err error
631 }
632 resultChan := make(chan toolExecResult, 1)
633
634 go func() {
635 response, err := tool.Run(ctx, tools.ToolCall{
636 ID: toolCall.ID,
637 Name: toolCall.Name,
638 Input: toolCall.Input,
639 })
640 resultChan <- toolExecResult{response: response, err: err}
641 }()
642
643 var toolResponse tools.ToolResponse
644 var toolErr error
645
646 select {
647 case <-ctx.Done():
648 a.finishMessage(context.Background(), &assistantMsg, message.FinishReasonCanceled, "Request cancelled", "")
649 // Mark remaining tool calls as cancelled
650 for j := i; j < len(toolCalls); j++ {
651 toolResults[j] = message.ToolResult{
652 ToolCallID: toolCalls[j].ID,
653 Content: "Tool execution canceled by user",
654 IsError: true,
655 }
656 }
657 goto out
658 case result := <-resultChan:
659 toolResponse = result.response
660 toolErr = result.err
661 }
662
663 if toolErr != nil {
664 slog.Error("Tool execution error", "toolCall", toolCall.ID, "error", toolErr)
665 if errors.Is(toolErr, permission.ErrorPermissionDenied) {
666 toolResults[i] = message.ToolResult{
667 ToolCallID: toolCall.ID,
668 Content: "Permission denied",
669 IsError: true,
670 }
671 for j := i + 1; j < len(toolCalls); j++ {
672 toolResults[j] = message.ToolResult{
673 ToolCallID: toolCalls[j].ID,
674 Content: "Tool execution canceled by user",
675 IsError: true,
676 }
677 }
678 a.finishMessage(ctx, &assistantMsg, message.FinishReasonPermissionDenied, "Permission denied", "")
679 break
680 }
681 }
682 toolResults[i] = message.ToolResult{
683 ToolCallID: toolCall.ID,
684 Content: toolResponse.Content,
685 Metadata: toolResponse.Metadata,
686 IsError: toolResponse.IsError,
687 }
688 }
689 }
690out:
691 if len(toolResults) == 0 {
692 return assistantMsg, nil, nil
693 }
694 parts := make([]message.ContentPart, 0)
695 for _, tr := range toolResults {
696 parts = append(parts, tr)
697 }
698 msg, err := a.messages.Create(context.Background(), assistantMsg.SessionID, message.CreateMessageParams{
699 Role: message.Tool,
700 Parts: parts,
701 Provider: a.providerID,
702 })
703 if err != nil {
704 return assistantMsg, nil, fmt.Errorf("failed to create cancelled tool message: %w", err)
705 }
706
707 return assistantMsg, &msg, err
708}
709
710func (a *agent) finishMessage(ctx context.Context, msg *message.Message, finishReason message.FinishReason, message, details string) {
711 msg.AddFinish(finishReason, message, details)
712 _ = a.messages.Update(ctx, *msg)
713}
714
715func (a *agent) processEvent(ctx context.Context, sessionID string, assistantMsg *message.Message, event provider.ProviderEvent) error {
716 select {
717 case <-ctx.Done():
718 return ctx.Err()
719 default:
720 // Continue processing.
721 }
722
723 switch event.Type {
724 case provider.EventThinkingDelta:
725 assistantMsg.AppendReasoningContent(event.Thinking)
726 return a.messages.Update(ctx, *assistantMsg)
727 case provider.EventSignatureDelta:
728 assistantMsg.AppendReasoningSignature(event.Signature)
729 return a.messages.Update(ctx, *assistantMsg)
730 case provider.EventContentDelta:
731 assistantMsg.FinishThinking()
732 assistantMsg.AppendContent(event.Content)
733 return a.messages.Update(ctx, *assistantMsg)
734 case provider.EventToolUseStart:
735 assistantMsg.FinishThinking()
736 slog.Info("Tool call started", "toolCall", event.ToolCall)
737 assistantMsg.AddToolCall(*event.ToolCall)
738 return a.messages.Update(ctx, *assistantMsg)
739 case provider.EventToolUseDelta:
740 assistantMsg.AppendToolCallInput(event.ToolCall.ID, event.ToolCall.Input)
741 return a.messages.Update(ctx, *assistantMsg)
742 case provider.EventToolUseStop:
743 slog.Info("Finished tool call", "toolCall", event.ToolCall)
744 assistantMsg.FinishToolCall(event.ToolCall.ID)
745 return a.messages.Update(ctx, *assistantMsg)
746 case provider.EventError:
747 return event.Error
748 case provider.EventComplete:
749 assistantMsg.FinishThinking()
750 assistantMsg.SetToolCalls(event.Response.ToolCalls)
751 assistantMsg.AddFinish(event.Response.FinishReason, "", "")
752 if err := a.messages.Update(ctx, *assistantMsg); err != nil {
753 return fmt.Errorf("failed to update message: %w", err)
754 }
755 return a.trackUsage(ctx, sessionID, a.Model(), event.Response.Usage)
756 }
757
758 return nil
759}
760
761func (a *agent) trackUsage(ctx context.Context, sessionID string, model catwalk.Model, usage provider.TokenUsage) error {
762 sess, err := a.sessions.Get(ctx, sessionID)
763 if err != nil {
764 return fmt.Errorf("failed to get session: %w", err)
765 }
766
767 cost := model.CostPer1MInCached/1e6*float64(usage.CacheCreationTokens) +
768 model.CostPer1MOutCached/1e6*float64(usage.CacheReadTokens) +
769 model.CostPer1MIn/1e6*float64(usage.InputTokens) +
770 model.CostPer1MOut/1e6*float64(usage.OutputTokens)
771
772 a.eventTokensUsed(sessionID, usage, cost)
773
774 sess.Cost += cost
775 sess.CompletionTokens = usage.OutputTokens + usage.CacheReadTokens
776 sess.PromptTokens = usage.InputTokens + usage.CacheCreationTokens
777
778 _, err = a.sessions.Save(ctx, sess)
779 if err != nil {
780 return fmt.Errorf("failed to save session: %w", err)
781 }
782 return nil
783}
784
785func (a *agent) Summarize(ctx context.Context, sessionID string) error {
786 if a.summarizeProvider == nil {
787 return fmt.Errorf("summarize provider not available")
788 }
789
790 // Check if session is busy
791 if a.IsSessionBusy(sessionID) {
792 return ErrSessionBusy
793 }
794
795 // Create a new context with cancellation
796 summarizeCtx, cancel := context.WithCancel(ctx)
797
798 // Store the cancel function in activeRequests to allow cancellation
799 a.activeRequests.Set(sessionID+"-summarize", cancel)
800
801 go func() {
802 defer a.activeRequests.Del(sessionID + "-summarize")
803 defer cancel()
804 event := AgentEvent{
805 Type: AgentEventTypeSummarize,
806 Progress: "Starting summarization...",
807 }
808
809 a.Publish(pubsub.CreatedEvent, event)
810 // Get all messages from the session
811 msgs, err := a.messages.List(summarizeCtx, sessionID)
812 if err != nil {
813 event = AgentEvent{
814 Type: AgentEventTypeError,
815 Error: fmt.Errorf("failed to list messages: %w", err),
816 Done: true,
817 }
818 a.Publish(pubsub.CreatedEvent, event)
819 return
820 }
821 summarizeCtx = context.WithValue(summarizeCtx, tools.SessionIDContextKey, sessionID)
822
823 if len(msgs) == 0 {
824 event = AgentEvent{
825 Type: AgentEventTypeError,
826 Error: fmt.Errorf("no messages to summarize"),
827 Done: true,
828 }
829 a.Publish(pubsub.CreatedEvent, event)
830 return
831 }
832
833 event = AgentEvent{
834 Type: AgentEventTypeSummarize,
835 Progress: "Analyzing conversation...",
836 }
837 a.Publish(pubsub.CreatedEvent, event)
838
839 // Add a system message to guide the summarization
840 summarizePrompt := "Provide a detailed but concise summary of our conversation above. Focus on information that would be helpful for continuing the conversation, including what we did, what we're doing, which files we're working on, and what we're going to do next."
841
842 // Create a new message with the summarize prompt
843 promptMsg := message.Message{
844 Role: message.User,
845 Parts: []message.ContentPart{message.TextContent{Text: summarizePrompt}},
846 }
847
848 // Append the prompt to the messages
849 msgsWithPrompt := append(msgs, promptMsg)
850
851 event = AgentEvent{
852 Type: AgentEventTypeSummarize,
853 Progress: "Generating summary...",
854 }
855
856 a.Publish(pubsub.CreatedEvent, event)
857
858 // Send the messages to the summarize provider
859 response := a.summarizeProvider.StreamResponse(
860 summarizeCtx,
861 msgsWithPrompt,
862 nil,
863 )
864 var finalResponse *provider.ProviderResponse
865 for r := range response {
866 if r.Error != nil {
867 event = AgentEvent{
868 Type: AgentEventTypeError,
869 Error: fmt.Errorf("failed to summarize: %w", r.Error),
870 Done: true,
871 }
872 a.Publish(pubsub.CreatedEvent, event)
873 return
874 }
875 finalResponse = r.Response
876 }
877
878 summary := strings.TrimSpace(finalResponse.Content)
879 if summary == "" {
880 event = AgentEvent{
881 Type: AgentEventTypeError,
882 Error: fmt.Errorf("empty summary returned"),
883 Done: true,
884 }
885 a.Publish(pubsub.CreatedEvent, event)
886 return
887 }
888 shell := shell.GetPersistentShell(config.Get().WorkingDir())
889 summary += "\n\n**Current working directory of the persistent shell**\n\n" + shell.GetWorkingDir()
890 event = AgentEvent{
891 Type: AgentEventTypeSummarize,
892 Progress: "Creating new session...",
893 }
894
895 a.Publish(pubsub.CreatedEvent, event)
896 oldSession, err := a.sessions.Get(summarizeCtx, sessionID)
897 if err != nil {
898 event = AgentEvent{
899 Type: AgentEventTypeError,
900 Error: fmt.Errorf("failed to get session: %w", err),
901 Done: true,
902 }
903
904 a.Publish(pubsub.CreatedEvent, event)
905 return
906 }
907 // Create a message in the new session with the summary
908 msg, err := a.messages.Create(summarizeCtx, oldSession.ID, message.CreateMessageParams{
909 Role: message.Assistant,
910 Parts: []message.ContentPart{
911 message.TextContent{Text: summary},
912 message.Finish{
913 Reason: message.FinishReasonEndTurn,
914 Time: time.Now().Unix(),
915 },
916 },
917 Model: a.summarizeProvider.Model().ID,
918 Provider: a.summarizeProviderID,
919 })
920 if err != nil {
921 event = AgentEvent{
922 Type: AgentEventTypeError,
923 Error: fmt.Errorf("failed to create summary message: %w", err),
924 Done: true,
925 }
926
927 a.Publish(pubsub.CreatedEvent, event)
928 return
929 }
930 oldSession.SummaryMessageID = msg.ID
931 oldSession.CompletionTokens = finalResponse.Usage.OutputTokens
932 oldSession.PromptTokens = 0
933 model := a.summarizeProvider.Model()
934 usage := finalResponse.Usage
935 cost := model.CostPer1MInCached/1e6*float64(usage.CacheCreationTokens) +
936 model.CostPer1MOutCached/1e6*float64(usage.CacheReadTokens) +
937 model.CostPer1MIn/1e6*float64(usage.InputTokens) +
938 model.CostPer1MOut/1e6*float64(usage.OutputTokens)
939 oldSession.Cost += cost
940 _, err = a.sessions.Save(summarizeCtx, oldSession)
941 if err != nil {
942 event = AgentEvent{
943 Type: AgentEventTypeError,
944 Error: fmt.Errorf("failed to save session: %w", err),
945 Done: true,
946 }
947 a.Publish(pubsub.CreatedEvent, event)
948 }
949
950 event = AgentEvent{
951 Type: AgentEventTypeSummarize,
952 SessionID: oldSession.ID,
953 Progress: "Summary complete",
954 Done: true,
955 }
956 a.Publish(pubsub.CreatedEvent, event)
957 // Send final success event with the new session ID
958 }()
959
960 return nil
961}
962
963func (a *agent) ClearQueue(sessionID string) {
964 if a.QueuedPrompts(sessionID) > 0 {
965 slog.Info("Clearing queued prompts", "session_id", sessionID)
966 a.promptQueue.Del(sessionID)
967 }
968}
969
970func (a *agent) CancelAll() {
971 if !a.IsBusy() {
972 return
973 }
974 for key := range a.activeRequests.Seq2() {
975 a.Cancel(key) // key is sessionID
976 }
977
978 for _, cleanup := range a.cleanupFuncs {
979 if cleanup != nil {
980 cleanup()
981 }
982 }
983
984 timeout := time.After(5 * time.Second)
985 for a.IsBusy() {
986 select {
987 case <-timeout:
988 return
989 default:
990 time.Sleep(200 * time.Millisecond)
991 }
992 }
993}
994
995func (a *agent) UpdateModel() error {
996 cfg := config.Get()
997
998 // Get current provider configuration
999 currentProviderCfg := cfg.GetProviderForModel(a.agentCfg.Model)
1000 if currentProviderCfg == nil || currentProviderCfg.ID == "" {
1001 return fmt.Errorf("provider for agent %s not found in config", a.agentCfg.Name)
1002 }
1003
1004 // Check if provider has changed
1005 if string(currentProviderCfg.ID) != a.providerID {
1006 // Provider changed, need to recreate the main provider
1007 model := cfg.GetModelByType(a.agentCfg.Model)
1008 if model.ID == "" {
1009 return fmt.Errorf("model not found for agent %s", a.agentCfg.Name)
1010 }
1011
1012 promptID := agentPromptMap[a.agentCfg.ID]
1013 if promptID == "" {
1014 promptID = prompt.PromptDefault
1015 }
1016
1017 opts := []provider.ProviderClientOption{
1018 provider.WithModel(a.agentCfg.Model),
1019 provider.WithSystemMessage(prompt.GetPrompt(promptID, currentProviderCfg.ID, cfg.Options.ContextPaths...)),
1020 }
1021
1022 newProvider, err := provider.NewProvider(*currentProviderCfg, opts...)
1023 if err != nil {
1024 return fmt.Errorf("failed to create new provider: %w", err)
1025 }
1026
1027 // Update the provider and provider ID
1028 a.provider = newProvider
1029 a.providerID = string(currentProviderCfg.ID)
1030 }
1031
1032 // Check if providers have changed for title (small) and summarize (large)
1033 smallModelCfg := cfg.Models[config.SelectedModelTypeSmall]
1034 var smallModelProviderCfg config.ProviderConfig
1035 for p := range cfg.Providers.Seq() {
1036 if p.ID == smallModelCfg.Provider {
1037 smallModelProviderCfg = p
1038 break
1039 }
1040 }
1041 if smallModelProviderCfg.ID == "" {
1042 return fmt.Errorf("provider %s not found in config", smallModelCfg.Provider)
1043 }
1044
1045 largeModelCfg := cfg.Models[config.SelectedModelTypeLarge]
1046 var largeModelProviderCfg config.ProviderConfig
1047 for p := range cfg.Providers.Seq() {
1048 if p.ID == largeModelCfg.Provider {
1049 largeModelProviderCfg = p
1050 break
1051 }
1052 }
1053 if largeModelProviderCfg.ID == "" {
1054 return fmt.Errorf("provider %s not found in config", largeModelCfg.Provider)
1055 }
1056
1057 var maxTitleTokens int64 = 40
1058
1059 // if the max output is too low for the gemini provider it won't return anything
1060 if smallModelCfg.Provider == "gemini" {
1061 maxTitleTokens = 1000
1062 }
1063 // Recreate title provider
1064 titleOpts := []provider.ProviderClientOption{
1065 provider.WithModel(config.SelectedModelTypeSmall),
1066 provider.WithSystemMessage(prompt.GetPrompt(prompt.PromptTitle, smallModelProviderCfg.ID)),
1067 provider.WithMaxTokens(maxTitleTokens),
1068 }
1069 newTitleProvider, err := provider.NewProvider(smallModelProviderCfg, titleOpts...)
1070 if err != nil {
1071 return fmt.Errorf("failed to create new title provider: %w", err)
1072 }
1073 a.titleProvider = newTitleProvider
1074
1075 // Recreate summarize provider if provider changed (now large model)
1076 if string(largeModelProviderCfg.ID) != a.summarizeProviderID {
1077 largeModel := cfg.GetModelByType(config.SelectedModelTypeLarge)
1078 if largeModel == nil {
1079 return fmt.Errorf("model %s not found in provider %s", largeModelCfg.Model, largeModelProviderCfg.ID)
1080 }
1081 summarizeOpts := []provider.ProviderClientOption{
1082 provider.WithModel(config.SelectedModelTypeLarge),
1083 provider.WithSystemMessage(prompt.GetPrompt(prompt.PromptSummarizer, largeModelProviderCfg.ID)),
1084 }
1085 newSummarizeProvider, err := provider.NewProvider(largeModelProviderCfg, summarizeOpts...)
1086 if err != nil {
1087 return fmt.Errorf("failed to create new summarize provider: %w", err)
1088 }
1089 a.summarizeProvider = newSummarizeProvider
1090 a.summarizeProviderID = string(largeModelProviderCfg.ID)
1091 }
1092
1093 return nil
1094}
1095
1096func (a *agent) setupEvents(ctx context.Context) {
1097 ctx, cancel := context.WithCancel(ctx)
1098
1099 go func() {
1100 subCh := SubscribeMCPEvents(ctx)
1101
1102 for {
1103 select {
1104 case event, ok := <-subCh:
1105 if !ok {
1106 slog.Debug("MCPEvents subscription channel closed")
1107 return
1108 }
1109 name := event.Payload.Name
1110 c, ok := mcpClients.Get(name)
1111 if !ok {
1112 slog.Warn("MCP client not found for tools/prompts update", "name", name)
1113 continue
1114 }
1115 switch event.Payload.Type {
1116 case MCPEventToolsListChanged:
1117 cfg := config.Get()
1118 tools, err := getTools(ctx, name, a.permissions, c, cfg.WorkingDir())
1119 if err != nil {
1120 slog.Error("error listing tools", "error", err)
1121 updateMCPState(name, MCPStateError, err, nil, MCPCounts{})
1122 _ = c.Close()
1123 continue
1124 }
1125 updateMcpTools(name, tools)
1126 a.mcpTools.Reset(maps.Collect(mcpTools.Seq2()))
1127 case MCPEventPromptsListChanged:
1128 prompts, err := getPrompts(ctx, c)
1129 if err != nil {
1130 slog.Error("error listing prompts", "error", err)
1131 updateMCPState(name, MCPStateError, err, nil, MCPCounts{})
1132 _ = c.Close()
1133 continue
1134 }
1135 updateMcpPrompts(name, prompts)
1136 default:
1137 continue
1138 }
1139 updateMCPState(name, MCPStateConnected, nil, c, MCPCounts{
1140 Tools: mcpTools.Len(),
1141 Prompts: mcpPrompts.Len(),
1142 })
1143 case <-ctx.Done():
1144 slog.Debug("MCPEvents subscription cancelled")
1145 return
1146 }
1147 }
1148 }()
1149
1150 a.cleanupFuncs = append(a.cleanupFuncs, cancel)
1151}