1// Package agent contains the implementation of the AI agent service.
2package agent
3
4import (
5 "context"
6 "errors"
7 "fmt"
8 "log/slog"
9 "maps"
10 "slices"
11 "strings"
12 "time"
13
14 "github.com/charmbracelet/catwalk/pkg/catwalk"
15 "github.com/charmbracelet/crush/internal/config"
16 "github.com/charmbracelet/crush/internal/csync"
17 "github.com/charmbracelet/crush/internal/event"
18 "github.com/charmbracelet/crush/internal/history"
19 "github.com/charmbracelet/crush/internal/llm/prompt"
20 "github.com/charmbracelet/crush/internal/llm/provider"
21 "github.com/charmbracelet/crush/internal/llm/tools"
22 "github.com/charmbracelet/crush/internal/log"
23 "github.com/charmbracelet/crush/internal/lsp"
24 "github.com/charmbracelet/crush/internal/message"
25 "github.com/charmbracelet/crush/internal/permission"
26 "github.com/charmbracelet/crush/internal/pubsub"
27 "github.com/charmbracelet/crush/internal/session"
28 "github.com/charmbracelet/crush/internal/shell"
29)
30
// AgentEventType identifies which kind of AgentEvent is being delivered.
type AgentEventType string

// Kinds of AgentEvent emitted by the agent.
const (
	// AgentEventTypeError carries a failure in AgentEvent.Error.
	AgentEventTypeError AgentEventType = "error"
	// AgentEventTypeResponse carries a finished assistant message in
	// AgentEvent.Message.
	AgentEventTypeResponse AgentEventType = "response"
	// AgentEventTypeSummarize reports summarization progress or completion.
	AgentEventTypeSummarize AgentEventType = "summarize"
)
38
// AgentEvent is what the agent reports to callers of Run (through the
// returned channel) and to pubsub subscribers.
type AgentEvent struct {
	// Type says which of the fields below are meaningful.
	Type AgentEventType
	// Message is the final assistant message of an AgentEventTypeResponse event.
	Message message.Message
	// Error is set on AgentEventTypeError events.
	Error error

	// SessionID and Progress are used by summarization events. Done marks the
	// last event of a run or of a summarization (response, summary
	// completion, or a terminal summarization error).
	SessionID string
	Progress  string
	Done      bool
}
49
// Service is the public interface of an AI agent. Besides the values returned
// from its methods, the agent publishes AgentEvents to subscribers.
type Service interface {
	pubsub.Suscriber[AgentEvent]
	// Model returns the model currently configured for the agent.
	Model() catwalk.Model
	// Run processes a prompt for the session in the background. When the
	// session already has a request in flight, the prompt is queued instead
	// and both return values are nil.
	Run(ctx context.Context, sessionID string, content string, attachments ...message.Attachment) (<-chan AgentEvent, error)
	// Cancel aborts the session's in-flight request and summarization and
	// discards its queued prompts.
	Cancel(sessionID string)
	// CancelAll cancels every active request and waits a bounded time for
	// them to finish.
	CancelAll()
	// IsSessionBusy reports whether a request is running for the session.
	IsSessionBusy(sessionID string) bool
	// IsBusy reports whether any request is running.
	IsBusy() bool
	// Summarize starts an asynchronous summarization of the session;
	// progress is delivered as AgentEventTypeSummarize events.
	Summarize(ctx context.Context, sessionID string) error
	// UpdateModel re-reads the model configuration and recreates the
	// agent's providers as needed.
	UpdateModel() error
	// QueuedPrompts returns how many prompts are waiting for the session.
	QueuedPrompts(sessionID string) int
	// ClearQueue discards the session's queued prompts.
	ClearQueue(sessionID string)
}
63
// agent is the Service implementation. The embedded broker is used to
// publish AgentEvents to subscribers.
type agent struct {
	*pubsub.Broker[AgentEvent]
	agentCfg    config.Agent
	sessions    session.Service
	messages    message.Service
	permissions permission.Service
	// baseTools is a lazily built map keyed by tool name; mcpTools is also
	// built lazily on first access (see NewAgent).
	baseTools  *csync.Map[string, tools.BaseTool]
	mcpTools   *csync.Map[string, tools.BaseTool]
	lspClients *csync.Map[string, *lsp.Client]

	// We need this to be able to update it when model changes.
	// agentToolFn builds the task sub-agent tool; nil when that tool is not
	// enabled for this agent.
	agentToolFn func() (tools.BaseTool, error)
	// cleanupFuncs are invoked by CancelAll after requests are cancelled.
	cleanupFuncs []func()

	// provider serves the agent's main conversation; providerID is the ID of
	// its provider configuration.
	provider   provider.Provider
	providerID string

	// titleProvider generates session titles using the small model.
	titleProvider provider.Provider
	// summarizeProvider produces conversation summaries for Summarize.
	summarizeProvider   provider.Provider
	summarizeProviderID string

	// activeRequests holds cancel functions keyed by session ID (regular
	// requests) or session ID + "-summarize" (summarizations).
	activeRequests *csync.Map[string, context.CancelFunc]
	// promptQueue holds prompts submitted while a session was busy, keyed by
	// session ID.
	promptQueue *csync.Map[string, []string]
}
88
89var agentPromptMap = map[string]prompt.PromptID{
90 "coder": prompt.PromptCoder,
91 "task": prompt.PromptTask,
92}
93
94func NewAgent(
95 ctx context.Context,
96 agentCfg config.Agent,
97 // These services are needed in the tools
98 permissions permission.Service,
99 sessions session.Service,
100 messages message.Service,
101 history history.Service,
102 lspClients *csync.Map[string, *lsp.Client],
103) (Service, error) {
104 cfg := config.Get()
105
106 var agentToolFn func() (tools.BaseTool, error)
107 if agentCfg.ID == "coder" && slices.Contains(agentCfg.AllowedTools, AgentToolName) {
108 agentToolFn = func() (tools.BaseTool, error) {
109 taskAgentCfg := config.Get().Agents["task"]
110 if taskAgentCfg.ID == "" {
111 return nil, fmt.Errorf("task agent not found in config")
112 }
113 taskAgent, err := NewAgent(ctx, taskAgentCfg, permissions, sessions, messages, history, lspClients)
114 if err != nil {
115 return nil, fmt.Errorf("failed to create task agent: %w", err)
116 }
117 return NewAgentTool(taskAgent, sessions, messages), nil
118 }
119 }
120
121 providerCfg := config.Get().GetProviderForModel(agentCfg.Model)
122 if providerCfg == nil {
123 return nil, fmt.Errorf("provider for agent %s not found in config", agentCfg.Name)
124 }
125 model := config.Get().GetModelByType(agentCfg.Model)
126
127 if model == nil {
128 return nil, fmt.Errorf("model not found for agent %s", agentCfg.Name)
129 }
130
131 promptID := agentPromptMap[agentCfg.ID]
132 if promptID == "" {
133 promptID = prompt.PromptDefault
134 }
135 opts := []provider.ProviderClientOption{
136 provider.WithModel(agentCfg.Model),
137 provider.WithSystemMessage(prompt.GetPrompt(promptID, providerCfg.ID, config.Get().Options.ContextPaths...)),
138 }
139 agentProvider, err := provider.NewProvider(*providerCfg, opts...)
140 if err != nil {
141 return nil, err
142 }
143
144 smallModelCfg := cfg.Models[config.SelectedModelTypeSmall]
145 var smallModelProviderCfg *config.ProviderConfig
146 if smallModelCfg.Provider == providerCfg.ID {
147 smallModelProviderCfg = providerCfg
148 } else {
149 smallModelProviderCfg = cfg.GetProviderForModel(config.SelectedModelTypeSmall)
150
151 if smallModelProviderCfg.ID == "" {
152 return nil, fmt.Errorf("provider %s not found in config", smallModelCfg.Provider)
153 }
154 }
155 smallModel := cfg.GetModelByType(config.SelectedModelTypeSmall)
156 if smallModel.ID == "" {
157 return nil, fmt.Errorf("model %s not found in provider %s", smallModelCfg.Model, smallModelProviderCfg.ID)
158 }
159
160 titleOpts := []provider.ProviderClientOption{
161 provider.WithModel(config.SelectedModelTypeSmall),
162 provider.WithSystemMessage(prompt.GetPrompt(prompt.PromptTitle, smallModelProviderCfg.ID)),
163 }
164 titleProvider, err := provider.NewProvider(*smallModelProviderCfg, titleOpts...)
165 if err != nil {
166 return nil, err
167 }
168
169 summarizeOpts := []provider.ProviderClientOption{
170 provider.WithModel(config.SelectedModelTypeLarge),
171 provider.WithSystemMessage(prompt.GetPrompt(prompt.PromptSummarizer, providerCfg.ID)),
172 }
173 summarizeProvider, err := provider.NewProvider(*providerCfg, summarizeOpts...)
174 if err != nil {
175 return nil, err
176 }
177
178 baseToolsFn := func() map[string]tools.BaseTool {
179 slog.Debug("Initializing agent base tools", "agent", agentCfg.ID)
180 defer func() {
181 slog.Debug("Initialized agent base tools", "agent", agentCfg.ID)
182 }()
183
184 // Base tools available to all agents
185 cwd := cfg.WorkingDir()
186 result := make(map[string]tools.BaseTool)
187 for _, tool := range []tools.BaseTool{
188 tools.NewBashTool(permissions, cwd, cfg.Options.Attribution),
189 tools.NewDownloadTool(permissions, cwd),
190 tools.NewEditTool(lspClients, permissions, history, cwd),
191 tools.NewMultiEditTool(lspClients, permissions, history, cwd),
192 tools.NewFetchTool(permissions, cwd),
193 tools.NewGlobTool(cwd),
194 tools.NewGrepTool(cwd),
195 tools.NewLsTool(permissions, cwd),
196 tools.NewSourcegraphTool(),
197 tools.NewViewTool(lspClients, permissions, cwd),
198 tools.NewWriteTool(lspClients, permissions, history, cwd),
199 } {
200 result[tool.Name()] = tool
201 }
202 return result
203 }
204 mcpToolsFn := func() map[string]tools.BaseTool {
205 slog.Debug("Initializing agent mcp tools", "agent", agentCfg.ID)
206 defer func() {
207 slog.Debug("Initialized agent mcp tools", "agent", agentCfg.ID)
208 }()
209
210 mcpToolsOnce.Do(func() {
211 doGetMCPTools(ctx, permissions, cfg)
212 })
213
214 return maps.Collect(mcpTools.Seq2())
215 }
216
217 a := &agent{
218 Broker: pubsub.NewBroker[AgentEvent](),
219 agentCfg: agentCfg,
220 provider: agentProvider,
221 providerID: string(providerCfg.ID),
222 messages: messages,
223 sessions: sessions,
224 titleProvider: titleProvider,
225 summarizeProvider: summarizeProvider,
226 summarizeProviderID: string(providerCfg.ID),
227 agentToolFn: agentToolFn,
228 activeRequests: csync.NewMap[string, context.CancelFunc](),
229 mcpTools: csync.NewLazyMap(mcpToolsFn),
230 baseTools: csync.NewLazyMap(baseToolsFn),
231 promptQueue: csync.NewMap[string, []string](),
232 permissions: permissions,
233 lspClients: lspClients,
234 }
235 a.setupEvents(ctx)
236 return a, nil
237}
238
239func (a *agent) Model() catwalk.Model {
240 return *config.Get().GetModelByType(a.agentCfg.Model)
241}
242
243func (a *agent) Cancel(sessionID string) {
244 // Cancel regular requests
245 if cancel, ok := a.activeRequests.Take(sessionID); ok && cancel != nil {
246 slog.Info("Request cancellation initiated", "session_id", sessionID)
247 cancel()
248 }
249
250 // Also check for summarize requests
251 if cancel, ok := a.activeRequests.Take(sessionID + "-summarize"); ok && cancel != nil {
252 slog.Info("Summarize cancellation initiated", "session_id", sessionID)
253 cancel()
254 }
255
256 if a.QueuedPrompts(sessionID) > 0 {
257 slog.Info("Clearing queued prompts", "session_id", sessionID)
258 a.promptQueue.Del(sessionID)
259 }
260}
261
262func (a *agent) IsBusy() bool {
263 var busy bool
264 for cancelFunc := range a.activeRequests.Seq() {
265 if cancelFunc != nil {
266 busy = true
267 break
268 }
269 }
270 return busy
271}
272
273func (a *agent) IsSessionBusy(sessionID string) bool {
274 _, busy := a.activeRequests.Get(sessionID)
275 return busy
276}
277
278func (a *agent) QueuedPrompts(sessionID string) int {
279 l, ok := a.promptQueue.Get(sessionID)
280 if !ok {
281 return 0
282 }
283 return len(l)
284}
285
286func (a *agent) generateTitle(ctx context.Context, sessionID string, content string) error {
287 if content == "" {
288 return nil
289 }
290 if a.titleProvider == nil {
291 return nil
292 }
293 session, err := a.sessions.Get(ctx, sessionID)
294 if err != nil {
295 return err
296 }
297 parts := []message.ContentPart{message.TextContent{
298 Text: fmt.Sprintf("Generate a concise title for the following content:\n\n%s", content),
299 }}
300
301 // Use streaming approach like summarization
302 response := a.titleProvider.StreamResponse(
303 ctx,
304 []message.Message{
305 {
306 Role: message.User,
307 Parts: parts,
308 },
309 },
310 nil,
311 )
312
313 var finalResponse *provider.ProviderResponse
314 for r := range response {
315 if r.Error != nil {
316 return r.Error
317 }
318 finalResponse = r.Response
319 }
320
321 if finalResponse == nil {
322 return fmt.Errorf("no response received from title provider")
323 }
324
325 title := strings.ReplaceAll(finalResponse.Content, "\n", " ")
326
327 if idx := strings.Index(title, "</think>"); idx > 0 {
328 title = title[idx+len("</think>"):]
329 }
330
331 title = strings.TrimSpace(title)
332 if title == "" {
333 return nil
334 }
335
336 session.Title = title
337 _, err = a.sessions.Save(ctx, session)
338 return err
339}
340
341func (a *agent) err(err error) AgentEvent {
342 return AgentEvent{
343 Type: AgentEventTypeError,
344 Error: err,
345 }
346}
347
348func (a *agent) Run(ctx context.Context, sessionID string, content string, attachments ...message.Attachment) (<-chan AgentEvent, error) {
349 if !a.Model().SupportsImages && attachments != nil {
350 attachments = nil
351 }
352 events := make(chan AgentEvent, 1)
353 if a.IsSessionBusy(sessionID) {
354 existing, ok := a.promptQueue.Get(sessionID)
355 if !ok {
356 existing = []string{}
357 }
358 existing = append(existing, content)
359 a.promptQueue.Set(sessionID, existing)
360 return nil, nil
361 }
362
363 genCtx, cancel := context.WithCancel(ctx)
364 a.activeRequests.Set(sessionID, cancel)
365 startTime := time.Now()
366
367 go func() {
368 slog.Debug("Request started", "sessionID", sessionID)
369 defer log.RecoverPanic("agent.Run", func() {
370 events <- a.err(fmt.Errorf("panic while running the agent"))
371 })
372 var attachmentParts []message.ContentPart
373 for _, attachment := range attachments {
374 attachmentParts = append(attachmentParts, message.BinaryContent{Path: attachment.FilePath, MIMEType: attachment.MimeType, Data: attachment.Content})
375 }
376 result := a.processGeneration(genCtx, sessionID, content, attachmentParts)
377 if result.Error != nil {
378 if isCancelledErr(result.Error) {
379 slog.Error("Request canceled", "sessionID", sessionID)
380 } else {
381 slog.Error("Request errored", "sessionID", sessionID, "error", result.Error.Error())
382 event.Error(result.Error)
383 }
384 } else {
385 slog.Debug("Request completed", "sessionID", sessionID)
386 }
387 a.eventPromptResponded(sessionID, time.Since(startTime).Truncate(time.Second))
388 a.activeRequests.Del(sessionID)
389 cancel()
390 a.Publish(pubsub.CreatedEvent, result)
391 events <- result
392 close(events)
393 }()
394 a.eventPromptSent(sessionID)
395 return events, nil
396}
397
398func (a *agent) processGeneration(ctx context.Context, sessionID, content string, attachmentParts []message.ContentPart) AgentEvent {
399 cfg := config.Get()
400 // List existing messages; if none, start title generation asynchronously.
401 msgs, err := a.messages.List(ctx, sessionID)
402 if err != nil {
403 return a.err(fmt.Errorf("failed to list messages: %w", err))
404 }
405 if len(msgs) == 0 {
406 go func() {
407 defer log.RecoverPanic("agent.Run", func() {
408 slog.Error("panic while generating title")
409 })
410 titleErr := a.generateTitle(ctx, sessionID, content)
411 if titleErr != nil && !errors.Is(titleErr, context.Canceled) && !errors.Is(titleErr, context.DeadlineExceeded) {
412 slog.Error("failed to generate title", "error", titleErr)
413 }
414 }()
415 }
416 session, err := a.sessions.Get(ctx, sessionID)
417 if err != nil {
418 return a.err(fmt.Errorf("failed to get session: %w", err))
419 }
420 if session.SummaryMessageID != "" {
421 summaryMsgInex := -1
422 for i, msg := range msgs {
423 if msg.ID == session.SummaryMessageID {
424 summaryMsgInex = i
425 break
426 }
427 }
428 if summaryMsgInex != -1 {
429 msgs = msgs[summaryMsgInex:]
430 msgs[0].Role = message.User
431 }
432 }
433
434 userMsg, err := a.createUserMessage(ctx, sessionID, content, attachmentParts)
435 if err != nil {
436 return a.err(fmt.Errorf("failed to create user message: %w", err))
437 }
438 // Append the new user message to the conversation history.
439 msgHistory := append(msgs, userMsg)
440
441 for {
442 // Check for cancellation before each iteration
443 select {
444 case <-ctx.Done():
445 return a.err(ctx.Err())
446 default:
447 // Continue processing
448 }
449 agentMessage, toolResults, err := a.streamAndHandleEvents(ctx, sessionID, msgHistory)
450 if err != nil {
451 if errors.Is(err, context.Canceled) {
452 agentMessage.AddFinish(message.FinishReasonCanceled, "Request cancelled", "")
453 a.messages.Update(context.Background(), agentMessage)
454 return a.err(ErrRequestCancelled)
455 }
456 return a.err(fmt.Errorf("failed to process events: %w", err))
457 }
458 if cfg.Options.Debug {
459 slog.Info("Result", "message", agentMessage.FinishReason(), "toolResults", toolResults)
460 }
461 if (agentMessage.FinishReason() == message.FinishReasonToolUse) && toolResults != nil {
462 // We are not done, we need to respond with the tool response
463 msgHistory = append(msgHistory, agentMessage, *toolResults)
464 // If there are queued prompts, process the next one
465 nextPrompt, ok := a.promptQueue.Take(sessionID)
466 if ok {
467 for _, prompt := range nextPrompt {
468 // Create a new user message for the queued prompt
469 userMsg, err := a.createUserMessage(ctx, sessionID, prompt, nil)
470 if err != nil {
471 return a.err(fmt.Errorf("failed to create user message for queued prompt: %w", err))
472 }
473 // Append the new user message to the conversation history
474 msgHistory = append(msgHistory, userMsg)
475 }
476 }
477
478 continue
479 } else if agentMessage.FinishReason() == message.FinishReasonEndTurn {
480 queuePrompts, ok := a.promptQueue.Take(sessionID)
481 if ok {
482 for _, prompt := range queuePrompts {
483 if prompt == "" {
484 continue
485 }
486 userMsg, err := a.createUserMessage(ctx, sessionID, prompt, nil)
487 if err != nil {
488 return a.err(fmt.Errorf("failed to create user message for queued prompt: %w", err))
489 }
490 msgHistory = append(msgHistory, userMsg)
491 }
492 continue
493 }
494 }
495 if agentMessage.FinishReason() == "" {
496 // Kujtim: could not track down where this is happening but this means its cancelled
497 agentMessage.AddFinish(message.FinishReasonCanceled, "Request cancelled", "")
498 _ = a.messages.Update(context.Background(), agentMessage)
499 return a.err(ErrRequestCancelled)
500 }
501 return AgentEvent{
502 Type: AgentEventTypeResponse,
503 Message: agentMessage,
504 Done: true,
505 }
506 }
507}
508
509func (a *agent) createUserMessage(ctx context.Context, sessionID, content string, attachmentParts []message.ContentPart) (message.Message, error) {
510 parts := []message.ContentPart{message.TextContent{Text: content}}
511 parts = append(parts, attachmentParts...)
512 return a.messages.Create(ctx, sessionID, message.CreateMessageParams{
513 Role: message.User,
514 Parts: parts,
515 })
516}
517
518func (a *agent) getAllTools() ([]tools.BaseTool, error) {
519 var allTools []tools.BaseTool
520 for tool := range a.baseTools.Seq() {
521 if a.agentCfg.AllowedTools == nil || slices.Contains(a.agentCfg.AllowedTools, tool.Name()) {
522 allTools = append(allTools, tool)
523 }
524 }
525 if a.agentCfg.ID == "coder" {
526 allTools = slices.AppendSeq(allTools, a.mcpTools.Seq())
527 if a.lspClients.Len() > 0 {
528 allTools = append(allTools, tools.NewDiagnosticsTool(a.lspClients), tools.NewReferencesTool(a.lspClients))
529 }
530 }
531 if a.agentToolFn != nil {
532 agentTool, agentToolErr := a.agentToolFn()
533 if agentToolErr != nil {
534 return nil, agentToolErr
535 }
536 allTools = append(allTools, agentTool)
537 }
538
539 slices.SortFunc(allTools, func(a, b tools.BaseTool) int {
540 return strings.Compare(a.Name(), b.Name())
541 })
542 return allTools, nil
543}
544
545func (a *agent) streamAndHandleEvents(ctx context.Context, sessionID string, msgHistory []message.Message) (message.Message, *message.Message, error) {
546 ctx = context.WithValue(ctx, tools.SessionIDContextKey, sessionID)
547
548 // Create the assistant message first so the spinner shows immediately
549 assistantMsg, err := a.messages.Create(ctx, sessionID, message.CreateMessageParams{
550 Role: message.Assistant,
551 Parts: []message.ContentPart{},
552 Model: a.Model().ID,
553 Provider: a.providerID,
554 })
555 if err != nil {
556 return assistantMsg, nil, fmt.Errorf("failed to create assistant message: %w", err)
557 }
558
559 allTools, toolsErr := a.getAllTools()
560 if toolsErr != nil {
561 return assistantMsg, nil, toolsErr
562 }
563 // Now collect tools (which may block on MCP initialization)
564 eventChan := a.provider.StreamResponse(ctx, msgHistory, allTools)
565
566 // Add the session and message ID into the context if needed by tools.
567 ctx = context.WithValue(ctx, tools.MessageIDContextKey, assistantMsg.ID)
568
569loop:
570 for {
571 select {
572 case event, ok := <-eventChan:
573 if !ok {
574 break loop
575 }
576 if processErr := a.processEvent(ctx, sessionID, &assistantMsg, event); processErr != nil {
577 if errors.Is(processErr, context.Canceled) {
578 a.finishMessage(context.Background(), &assistantMsg, message.FinishReasonCanceled, "Request cancelled", "")
579 } else {
580 a.finishMessage(ctx, &assistantMsg, message.FinishReasonError, "API Error", processErr.Error())
581 }
582 return assistantMsg, nil, processErr
583 }
584 case <-ctx.Done():
585 a.finishMessage(context.Background(), &assistantMsg, message.FinishReasonCanceled, "Request cancelled", "")
586 return assistantMsg, nil, ctx.Err()
587 }
588 }
589
590 toolResults := make([]message.ToolResult, len(assistantMsg.ToolCalls()))
591 toolCalls := assistantMsg.ToolCalls()
592 for i, toolCall := range toolCalls {
593 select {
594 case <-ctx.Done():
595 a.finishMessage(context.Background(), &assistantMsg, message.FinishReasonCanceled, "Request cancelled", "")
596 // Make all future tool calls cancelled
597 for j := i; j < len(toolCalls); j++ {
598 toolResults[j] = message.ToolResult{
599 ToolCallID: toolCalls[j].ID,
600 Content: "Tool execution canceled by user",
601 IsError: true,
602 }
603 }
604 goto out
605 default:
606 // Continue processing
607 var tool tools.BaseTool
608 allTools, _ = a.getAllTools()
609 for _, availableTool := range allTools {
610 if availableTool.Info().Name == toolCall.Name {
611 tool = availableTool
612 break
613 }
614 }
615
616 // Tool not found
617 if tool == nil {
618 toolResults[i] = message.ToolResult{
619 ToolCallID: toolCall.ID,
620 Content: fmt.Sprintf("Tool not found: %s", toolCall.Name),
621 IsError: true,
622 }
623 continue
624 }
625
626 // Run tool in goroutine to allow cancellation
627 type toolExecResult struct {
628 response tools.ToolResponse
629 err error
630 }
631 resultChan := make(chan toolExecResult, 1)
632
633 go func() {
634 response, err := tool.Run(ctx, tools.ToolCall{
635 ID: toolCall.ID,
636 Name: toolCall.Name,
637 Input: toolCall.Input,
638 })
639 resultChan <- toolExecResult{response: response, err: err}
640 }()
641
642 var toolResponse tools.ToolResponse
643 var toolErr error
644
645 select {
646 case <-ctx.Done():
647 a.finishMessage(context.Background(), &assistantMsg, message.FinishReasonCanceled, "Request cancelled", "")
648 // Mark remaining tool calls as cancelled
649 for j := i; j < len(toolCalls); j++ {
650 toolResults[j] = message.ToolResult{
651 ToolCallID: toolCalls[j].ID,
652 Content: "Tool execution canceled by user",
653 IsError: true,
654 }
655 }
656 goto out
657 case result := <-resultChan:
658 toolResponse = result.response
659 toolErr = result.err
660 }
661
662 if toolErr != nil {
663 slog.Error("Tool execution error", "toolCall", toolCall.ID, "error", toolErr)
664 if errors.Is(toolErr, permission.ErrorPermissionDenied) {
665 toolResults[i] = message.ToolResult{
666 ToolCallID: toolCall.ID,
667 Content: "Permission denied",
668 IsError: true,
669 }
670 for j := i + 1; j < len(toolCalls); j++ {
671 toolResults[j] = message.ToolResult{
672 ToolCallID: toolCalls[j].ID,
673 Content: "Tool execution canceled by user",
674 IsError: true,
675 }
676 }
677 a.finishMessage(ctx, &assistantMsg, message.FinishReasonPermissionDenied, "Permission denied", "")
678 break
679 }
680 }
681 toolResults[i] = message.ToolResult{
682 ToolCallID: toolCall.ID,
683 Content: toolResponse.Content,
684 Metadata: toolResponse.Metadata,
685 IsError: toolResponse.IsError,
686 }
687 }
688 }
689out:
690 if len(toolResults) == 0 {
691 return assistantMsg, nil, nil
692 }
693 parts := make([]message.ContentPart, 0)
694 for _, tr := range toolResults {
695 parts = append(parts, tr)
696 }
697 msg, err := a.messages.Create(context.Background(), assistantMsg.SessionID, message.CreateMessageParams{
698 Role: message.Tool,
699 Parts: parts,
700 Provider: a.providerID,
701 })
702 if err != nil {
703 return assistantMsg, nil, fmt.Errorf("failed to create cancelled tool message: %w", err)
704 }
705
706 return assistantMsg, &msg, err
707}
708
709func (a *agent) finishMessage(ctx context.Context, msg *message.Message, finishReason message.FinishReason, message, details string) {
710 msg.AddFinish(finishReason, message, details)
711 _ = a.messages.Update(ctx, *msg)
712}
713
714func (a *agent) processEvent(ctx context.Context, sessionID string, assistantMsg *message.Message, event provider.ProviderEvent) error {
715 select {
716 case <-ctx.Done():
717 return ctx.Err()
718 default:
719 // Continue processing.
720 }
721
722 switch event.Type {
723 case provider.EventThinkingDelta:
724 assistantMsg.AppendReasoningContent(event.Thinking)
725 return a.messages.Update(ctx, *assistantMsg)
726 case provider.EventSignatureDelta:
727 assistantMsg.AppendReasoningSignature(event.Signature)
728 return a.messages.Update(ctx, *assistantMsg)
729 case provider.EventContentDelta:
730 assistantMsg.FinishThinking()
731 assistantMsg.AppendContent(event.Content)
732 return a.messages.Update(ctx, *assistantMsg)
733 case provider.EventToolUseStart:
734 assistantMsg.FinishThinking()
735 slog.Info("Tool call started", "toolCall", event.ToolCall)
736 assistantMsg.AddToolCall(*event.ToolCall)
737 return a.messages.Update(ctx, *assistantMsg)
738 case provider.EventToolUseDelta:
739 assistantMsg.AppendToolCallInput(event.ToolCall.ID, event.ToolCall.Input)
740 return a.messages.Update(ctx, *assistantMsg)
741 case provider.EventToolUseStop:
742 slog.Info("Finished tool call", "toolCall", event.ToolCall)
743 assistantMsg.FinishToolCall(event.ToolCall.ID)
744 return a.messages.Update(ctx, *assistantMsg)
745 case provider.EventError:
746 return event.Error
747 case provider.EventComplete:
748 assistantMsg.FinishThinking()
749 assistantMsg.SetToolCalls(event.Response.ToolCalls)
750 assistantMsg.AddFinish(event.Response.FinishReason, "", "")
751 if err := a.messages.Update(ctx, *assistantMsg); err != nil {
752 return fmt.Errorf("failed to update message: %w", err)
753 }
754 return a.trackUsage(ctx, sessionID, a.Model(), event.Response.Usage)
755 }
756
757 return nil
758}
759
760func (a *agent) trackUsage(ctx context.Context, sessionID string, model catwalk.Model, usage provider.TokenUsage) error {
761 sess, err := a.sessions.Get(ctx, sessionID)
762 if err != nil {
763 return fmt.Errorf("failed to get session: %w", err)
764 }
765
766 cost := model.CostPer1MInCached/1e6*float64(usage.CacheCreationTokens) +
767 model.CostPer1MOutCached/1e6*float64(usage.CacheReadTokens) +
768 model.CostPer1MIn/1e6*float64(usage.InputTokens) +
769 model.CostPer1MOut/1e6*float64(usage.OutputTokens)
770
771 a.eventTokensUsed(sessionID, usage, cost)
772
773 sess.Cost += cost
774 sess.CompletionTokens = usage.OutputTokens + usage.CacheReadTokens
775 sess.PromptTokens = usage.InputTokens + usage.CacheCreationTokens
776
777 _, err = a.sessions.Save(ctx, sess)
778 if err != nil {
779 return fmt.Errorf("failed to save session: %w", err)
780 }
781 return nil
782}
783
784func (a *agent) Summarize(ctx context.Context, sessionID string) error {
785 if a.summarizeProvider == nil {
786 return fmt.Errorf("summarize provider not available")
787 }
788
789 // Check if session is busy
790 if a.IsSessionBusy(sessionID) {
791 return ErrSessionBusy
792 }
793
794 // Create a new context with cancellation
795 summarizeCtx, cancel := context.WithCancel(ctx)
796
797 // Store the cancel function in activeRequests to allow cancellation
798 a.activeRequests.Set(sessionID+"-summarize", cancel)
799
800 go func() {
801 defer a.activeRequests.Del(sessionID + "-summarize")
802 defer cancel()
803 event := AgentEvent{
804 Type: AgentEventTypeSummarize,
805 Progress: "Starting summarization...",
806 }
807
808 a.Publish(pubsub.CreatedEvent, event)
809 // Get all messages from the session
810 msgs, err := a.messages.List(summarizeCtx, sessionID)
811 if err != nil {
812 event = AgentEvent{
813 Type: AgentEventTypeError,
814 Error: fmt.Errorf("failed to list messages: %w", err),
815 Done: true,
816 }
817 a.Publish(pubsub.CreatedEvent, event)
818 return
819 }
820 summarizeCtx = context.WithValue(summarizeCtx, tools.SessionIDContextKey, sessionID)
821
822 if len(msgs) == 0 {
823 event = AgentEvent{
824 Type: AgentEventTypeError,
825 Error: fmt.Errorf("no messages to summarize"),
826 Done: true,
827 }
828 a.Publish(pubsub.CreatedEvent, event)
829 return
830 }
831
832 event = AgentEvent{
833 Type: AgentEventTypeSummarize,
834 Progress: "Analyzing conversation...",
835 }
836 a.Publish(pubsub.CreatedEvent, event)
837
838 // Add a system message to guide the summarization
839 summarizePrompt := "Provide a detailed but concise summary of our conversation above. Focus on information that would be helpful for continuing the conversation, including what we did, what we're doing, which files we're working on, and what we're going to do next."
840
841 // Create a new message with the summarize prompt
842 promptMsg := message.Message{
843 Role: message.User,
844 Parts: []message.ContentPart{message.TextContent{Text: summarizePrompt}},
845 }
846
847 // Append the prompt to the messages
848 msgsWithPrompt := append(msgs, promptMsg)
849
850 event = AgentEvent{
851 Type: AgentEventTypeSummarize,
852 Progress: "Generating summary...",
853 }
854
855 a.Publish(pubsub.CreatedEvent, event)
856
857 // Send the messages to the summarize provider
858 response := a.summarizeProvider.StreamResponse(
859 summarizeCtx,
860 msgsWithPrompt,
861 nil,
862 )
863 var finalResponse *provider.ProviderResponse
864 for r := range response {
865 if r.Error != nil {
866 event = AgentEvent{
867 Type: AgentEventTypeError,
868 Error: fmt.Errorf("failed to summarize: %w", r.Error),
869 Done: true,
870 }
871 a.Publish(pubsub.CreatedEvent, event)
872 return
873 }
874 finalResponse = r.Response
875 }
876
877 summary := strings.TrimSpace(finalResponse.Content)
878 if summary == "" {
879 event = AgentEvent{
880 Type: AgentEventTypeError,
881 Error: fmt.Errorf("empty summary returned"),
882 Done: true,
883 }
884 a.Publish(pubsub.CreatedEvent, event)
885 return
886 }
887 shell := shell.GetPersistentShell(config.Get().WorkingDir())
888 summary += "\n\n**Current working directory of the persistent shell**\n\n" + shell.GetWorkingDir()
889 event = AgentEvent{
890 Type: AgentEventTypeSummarize,
891 Progress: "Creating new session...",
892 }
893
894 a.Publish(pubsub.CreatedEvent, event)
895 oldSession, err := a.sessions.Get(summarizeCtx, sessionID)
896 if err != nil {
897 event = AgentEvent{
898 Type: AgentEventTypeError,
899 Error: fmt.Errorf("failed to get session: %w", err),
900 Done: true,
901 }
902
903 a.Publish(pubsub.CreatedEvent, event)
904 return
905 }
906 // Create a message in the new session with the summary
907 msg, err := a.messages.Create(summarizeCtx, oldSession.ID, message.CreateMessageParams{
908 Role: message.Assistant,
909 Parts: []message.ContentPart{
910 message.TextContent{Text: summary},
911 message.Finish{
912 Reason: message.FinishReasonEndTurn,
913 Time: time.Now().Unix(),
914 },
915 },
916 Model: a.summarizeProvider.Model().ID,
917 Provider: a.summarizeProviderID,
918 })
919 if err != nil {
920 event = AgentEvent{
921 Type: AgentEventTypeError,
922 Error: fmt.Errorf("failed to create summary message: %w", err),
923 Done: true,
924 }
925
926 a.Publish(pubsub.CreatedEvent, event)
927 return
928 }
929 oldSession.SummaryMessageID = msg.ID
930 oldSession.CompletionTokens = finalResponse.Usage.OutputTokens
931 oldSession.PromptTokens = 0
932 model := a.summarizeProvider.Model()
933 usage := finalResponse.Usage
934 cost := model.CostPer1MInCached/1e6*float64(usage.CacheCreationTokens) +
935 model.CostPer1MOutCached/1e6*float64(usage.CacheReadTokens) +
936 model.CostPer1MIn/1e6*float64(usage.InputTokens) +
937 model.CostPer1MOut/1e6*float64(usage.OutputTokens)
938 oldSession.Cost += cost
939 _, err = a.sessions.Save(summarizeCtx, oldSession)
940 if err != nil {
941 event = AgentEvent{
942 Type: AgentEventTypeError,
943 Error: fmt.Errorf("failed to save session: %w", err),
944 Done: true,
945 }
946 a.Publish(pubsub.CreatedEvent, event)
947 }
948
949 event = AgentEvent{
950 Type: AgentEventTypeSummarize,
951 SessionID: oldSession.ID,
952 Progress: "Summary complete",
953 Done: true,
954 }
955 a.Publish(pubsub.CreatedEvent, event)
956 // Send final success event with the new session ID
957 }()
958
959 return nil
960}
961
962func (a *agent) ClearQueue(sessionID string) {
963 if a.QueuedPrompts(sessionID) > 0 {
964 slog.Info("Clearing queued prompts", "session_id", sessionID)
965 a.promptQueue.Del(sessionID)
966 }
967}
968
969func (a *agent) CancelAll() {
970 if !a.IsBusy() {
971 return
972 }
973 for key := range a.activeRequests.Seq2() {
974 a.Cancel(key) // key is sessionID
975 }
976
977 for _, cleanup := range a.cleanupFuncs {
978 if cleanup != nil {
979 cleanup()
980 }
981 }
982
983 timeout := time.After(5 * time.Second)
984 for a.IsBusy() {
985 select {
986 case <-timeout:
987 return
988 default:
989 time.Sleep(200 * time.Millisecond)
990 }
991 }
992}
993
994func (a *agent) UpdateModel() error {
995 cfg := config.Get()
996
997 // Get current provider configuration
998 currentProviderCfg := cfg.GetProviderForModel(a.agentCfg.Model)
999 if currentProviderCfg == nil || currentProviderCfg.ID == "" {
1000 return fmt.Errorf("provider for agent %s not found in config", a.agentCfg.Name)
1001 }
1002
1003 // Check if provider has changed
1004 if string(currentProviderCfg.ID) != a.providerID {
1005 // Provider changed, need to recreate the main provider
1006 model := cfg.GetModelByType(a.agentCfg.Model)
1007 if model.ID == "" {
1008 return fmt.Errorf("model not found for agent %s", a.agentCfg.Name)
1009 }
1010
1011 promptID := agentPromptMap[a.agentCfg.ID]
1012 if promptID == "" {
1013 promptID = prompt.PromptDefault
1014 }
1015
1016 opts := []provider.ProviderClientOption{
1017 provider.WithModel(a.agentCfg.Model),
1018 provider.WithSystemMessage(prompt.GetPrompt(promptID, currentProviderCfg.ID, cfg.Options.ContextPaths...)),
1019 }
1020
1021 newProvider, err := provider.NewProvider(*currentProviderCfg, opts...)
1022 if err != nil {
1023 return fmt.Errorf("failed to create new provider: %w", err)
1024 }
1025
1026 // Update the provider and provider ID
1027 a.provider = newProvider
1028 a.providerID = string(currentProviderCfg.ID)
1029 }
1030
1031 // Check if providers have changed for title (small) and summarize (large)
1032 smallModelCfg := cfg.Models[config.SelectedModelTypeSmall]
1033 var smallModelProviderCfg config.ProviderConfig
1034 for p := range cfg.Providers.Seq() {
1035 if p.ID == smallModelCfg.Provider {
1036 smallModelProviderCfg = p
1037 break
1038 }
1039 }
1040 if smallModelProviderCfg.ID == "" {
1041 return fmt.Errorf("provider %s not found in config", smallModelCfg.Provider)
1042 }
1043
1044 largeModelCfg := cfg.Models[config.SelectedModelTypeLarge]
1045 var largeModelProviderCfg config.ProviderConfig
1046 for p := range cfg.Providers.Seq() {
1047 if p.ID == largeModelCfg.Provider {
1048 largeModelProviderCfg = p
1049 break
1050 }
1051 }
1052 if largeModelProviderCfg.ID == "" {
1053 return fmt.Errorf("provider %s not found in config", largeModelCfg.Provider)
1054 }
1055
1056 var maxTitleTokens int64 = 40
1057
1058 // if the max output is too low for the gemini provider it won't return anything
1059 if smallModelCfg.Provider == "gemini" {
1060 maxTitleTokens = 1000
1061 }
1062 // Recreate title provider
1063 titleOpts := []provider.ProviderClientOption{
1064 provider.WithModel(config.SelectedModelTypeSmall),
1065 provider.WithSystemMessage(prompt.GetPrompt(prompt.PromptTitle, smallModelProviderCfg.ID)),
1066 provider.WithMaxTokens(maxTitleTokens),
1067 }
1068 newTitleProvider, err := provider.NewProvider(smallModelProviderCfg, titleOpts...)
1069 if err != nil {
1070 return fmt.Errorf("failed to create new title provider: %w", err)
1071 }
1072 a.titleProvider = newTitleProvider
1073
1074 // Recreate summarize provider if provider changed (now large model)
1075 if string(largeModelProviderCfg.ID) != a.summarizeProviderID {
1076 largeModel := cfg.GetModelByType(config.SelectedModelTypeLarge)
1077 if largeModel == nil {
1078 return fmt.Errorf("model %s not found in provider %s", largeModelCfg.Model, largeModelProviderCfg.ID)
1079 }
1080 summarizeOpts := []provider.ProviderClientOption{
1081 provider.WithModel(config.SelectedModelTypeLarge),
1082 provider.WithSystemMessage(prompt.GetPrompt(prompt.PromptSummarizer, largeModelProviderCfg.ID)),
1083 }
1084 newSummarizeProvider, err := provider.NewProvider(largeModelProviderCfg, summarizeOpts...)
1085 if err != nil {
1086 return fmt.Errorf("failed to create new summarize provider: %w", err)
1087 }
1088 a.summarizeProvider = newSummarizeProvider
1089 a.summarizeProviderID = string(largeModelProviderCfg.ID)
1090 }
1091
1092 return nil
1093}
1094
1095func (a *agent) setupEvents(ctx context.Context) {
1096 ctx, cancel := context.WithCancel(ctx)
1097
1098 go func() {
1099 subCh := SubscribeMCPEvents(ctx)
1100
1101 for {
1102 select {
1103 case event, ok := <-subCh:
1104 if !ok {
1105 slog.Debug("MCPEvents subscription channel closed")
1106 return
1107 }
1108 switch event.Payload.Type {
1109 case MCPEventToolsListChanged:
1110 name := event.Payload.Name
1111 c, ok := mcpClients.Get(name)
1112 if !ok {
1113 slog.Warn("MCP client not found for tools update", "name", name)
1114 continue
1115 }
1116 cfg := config.Get()
1117 tools, err := getTools(ctx, name, a.permissions, c, cfg.WorkingDir())
1118 if err != nil {
1119 slog.Error("error listing tools", "error", err)
1120 updateMCPState(name, MCPStateError, err, nil, 0)
1121 _ = c.Close()
1122 continue
1123 }
1124 updateMcpTools(name, tools)
1125 a.mcpTools.Reset(maps.Collect(mcpTools.Seq2()))
1126 updateMCPState(name, MCPStateConnected, nil, c, a.mcpTools.Len())
1127 default:
1128 continue
1129 }
1130 case <-ctx.Done():
1131 slog.Debug("MCPEvents subscription cancelled")
1132 return
1133 }
1134 }
1135 }()
1136
1137 a.cleanupFuncs = append(a.cleanupFuncs, cancel)
1138}