1package agent
2
3import (
4 "context"
5 "errors"
6 "fmt"
7 "log/slog"
8 "maps"
9 "slices"
10 "strings"
11 "time"
12
13 "github.com/charmbracelet/catwalk/pkg/catwalk"
14 "github.com/charmbracelet/crush/internal/config"
15 "github.com/charmbracelet/crush/internal/csync"
16 "github.com/charmbracelet/crush/internal/event"
17 "github.com/charmbracelet/crush/internal/history"
18 "github.com/charmbracelet/crush/internal/llm/prompt"
19 "github.com/charmbracelet/crush/internal/llm/provider"
20 "github.com/charmbracelet/crush/internal/llm/tools"
21 "github.com/charmbracelet/crush/internal/log"
22 "github.com/charmbracelet/crush/internal/lsp"
23 "github.com/charmbracelet/crush/internal/message"
24 "github.com/charmbracelet/crush/internal/permission"
25 "github.com/charmbracelet/crush/internal/pubsub"
26 "github.com/charmbracelet/crush/internal/session"
27 "github.com/charmbracelet/crush/internal/shell"
28)
29
// AgentEventType identifies the kind of AgentEvent being emitted.
type AgentEventType string

const (
	// AgentEventTypeError marks an event whose Error field describes a failure.
	AgentEventTypeError     AgentEventType = "error"
	// AgentEventTypeResponse marks an event that carries the agent's response message.
	AgentEventTypeResponse  AgentEventType = "response"
	// AgentEventTypeSummarize marks a progress or completion event emitted by Summarize.
	AgentEventTypeSummarize AgentEventType = "summarize"
)
37
// AgentEvent is the payload the agent publishes through its broker and sends
// on the channel returned by Run.
type AgentEvent struct {
	// Type says which kind of event this is.
	Type    AgentEventType
	// Message is the assistant message; it is set on response events.
	Message message.Message
	// Error is the failure; it is set on error events.
	Error   error

	// When summarizing
	// SessionID is the ID of the summarized session; it is set on the final
	// summarize event.
	SessionID string
	// Progress is a human-readable status line for summarize events.
	Progress  string
	// Done is set to true on terminal events (final response, summarize
	// errors, and summarize completion).
	Done      bool
}
48
// Service is the public surface of an agent. The behavior notes below
// describe the agent implementation in this file.
type Service interface {
	pubsub.Suscriber[AgentEvent]
	// Model returns the model configured for the agent's model type.
	Model() catwalk.Model
	// Run starts processing content for the session and returns a channel that
	// receives the resulting event. If the session is already busy the prompt
	// is queued instead and both return values are nil.
	Run(ctx context.Context, sessionID string, content string, attachments ...message.Attachment) (<-chan AgentEvent, error)
	// Cancel cancels the running request and any running summarization for
	// the session, and drops its queued prompts.
	Cancel(sessionID string)
	// CancelAll cancels every active request and waits (bounded) for the
	// agent to become idle.
	CancelAll()
	// IsSessionBusy reports whether a request is active for the session.
	IsSessionBusy(sessionID string) bool
	// IsBusy reports whether any request is active.
	IsBusy() bool
	// Summarize starts an asynchronous summarization of the session; progress
	// and results are delivered as published events.
	Summarize(ctx context.Context, sessionID string) error
	// UpdateModel refreshes the agent's providers from the current config.
	UpdateModel() error
	// QueuedPrompts returns how many prompts are queued for the session.
	QueuedPrompts(sessionID string) int
	// ClearQueue drops the prompts queued for the session.
	ClearQueue(sessionID string)
}
62
// agent is the Service implementation defined in this file.
type agent struct {
	// cleanupFuncs are invoked by CancelAll after the active requests have
	// been cancelled.
	cleanupFuncs []func()

	// Broker publishes AgentEvents to subscribers.
	*pubsub.Broker[AgentEvent]
	agentCfg config.Agent
	sessions session.Service
	messages message.Service

	permissions permission.Service
	// baseTools and mcpTools are created as lazy maps in NewAgent.
	baseTools   *csync.Map[string, tools.BaseTool]
	mcpTools    *csync.Map[string, tools.BaseTool]
	lspClients  *csync.Map[string, *lsp.Client]

	// We need this to be able to update it when model changes
	agentToolFn func() (tools.BaseTool, error)

	// provider serves the main conversation; providerID is recorded on the
	// messages the agent creates.
	provider   provider.Provider
	providerID string

	// titleProvider generates session titles; summarizeProvider generates
	// conversation summaries.
	titleProvider       provider.Provider
	summarizeProvider   provider.Provider
	summarizeProviderID string

	// activeRequests maps a session ID (or session ID + "-summarize") to the
	// cancel function of the in-flight request.
	activeRequests *csync.Map[string, context.CancelFunc]
	// promptQueue holds prompts received while a session was busy.
	promptQueue    *csync.Map[string, []string]
}
89
// agentPromptMap maps an agent ID to the system prompt it uses. Agents whose
// ID is not listed here fall back to prompt.PromptDefault in NewAgent.
var agentPromptMap = map[string]prompt.PromptID{
	"coder": prompt.PromptCoder,
	"task":  prompt.PromptTask,
}
94
95func NewAgent(
96 ctx context.Context,
97 agentCfg config.Agent,
98 // These services are needed in the tools
99 permissions permission.Service,
100 sessions session.Service,
101 messages message.Service,
102 history history.Service,
103 lspClients *csync.Map[string, *lsp.Client],
104) (Service, error) {
105 cfg := config.Get()
106
107 var agentToolFn func() (tools.BaseTool, error)
108 if agentCfg.ID == "coder" && slices.Contains(agentCfg.AllowedTools, AgentToolName) {
109 agentToolFn = func() (tools.BaseTool, error) {
110 taskAgentCfg := config.Get().Agents["task"]
111 if taskAgentCfg.ID == "" {
112 return nil, fmt.Errorf("task agent not found in config")
113 }
114 taskAgent, err := NewAgent(ctx, taskAgentCfg, permissions, sessions, messages, history, lspClients)
115 if err != nil {
116 return nil, fmt.Errorf("failed to create task agent: %w", err)
117 }
118 return NewAgentTool(taskAgent, sessions, messages), nil
119 }
120 }
121
122 providerCfg := config.Get().GetProviderForModel(agentCfg.Model)
123 if providerCfg == nil {
124 return nil, fmt.Errorf("provider for agent %s not found in config", agentCfg.Name)
125 }
126 model := config.Get().GetModelByType(agentCfg.Model)
127
128 if model == nil {
129 return nil, fmt.Errorf("model not found for agent %s", agentCfg.Name)
130 }
131
132 promptID := agentPromptMap[agentCfg.ID]
133 if promptID == "" {
134 promptID = prompt.PromptDefault
135 }
136 opts := []provider.ProviderClientOption{
137 provider.WithModel(agentCfg.Model),
138 provider.WithSystemMessage(prompt.GetPrompt(promptID, providerCfg.ID, config.Get().Options.ContextPaths...)),
139 }
140 agentProvider, err := provider.NewProvider(*providerCfg, opts...)
141 if err != nil {
142 return nil, err
143 }
144
145 smallModelCfg := cfg.Models[config.SelectedModelTypeSmall]
146 var smallModelProviderCfg *config.ProviderConfig
147 if smallModelCfg.Provider == providerCfg.ID {
148 smallModelProviderCfg = providerCfg
149 } else {
150 smallModelProviderCfg = cfg.GetProviderForModel(config.SelectedModelTypeSmall)
151
152 if smallModelProviderCfg.ID == "" {
153 return nil, fmt.Errorf("provider %s not found in config", smallModelCfg.Provider)
154 }
155 }
156 smallModel := cfg.GetModelByType(config.SelectedModelTypeSmall)
157 if smallModel.ID == "" {
158 return nil, fmt.Errorf("model %s not found in provider %s", smallModelCfg.Model, smallModelProviderCfg.ID)
159 }
160
161 titleOpts := []provider.ProviderClientOption{
162 provider.WithModel(config.SelectedModelTypeSmall),
163 provider.WithSystemMessage(prompt.GetPrompt(prompt.PromptTitle, smallModelProviderCfg.ID)),
164 }
165 titleProvider, err := provider.NewProvider(*smallModelProviderCfg, titleOpts...)
166 if err != nil {
167 return nil, err
168 }
169
170 summarizeOpts := []provider.ProviderClientOption{
171 provider.WithModel(config.SelectedModelTypeLarge),
172 provider.WithSystemMessage(prompt.GetPrompt(prompt.PromptSummarizer, providerCfg.ID)),
173 }
174 summarizeProvider, err := provider.NewProvider(*providerCfg, summarizeOpts...)
175 if err != nil {
176 return nil, err
177 }
178
179 baseToolsFn := func() map[string]tools.BaseTool {
180 slog.Info("Initializing agent base tools", "agent", agentCfg.ID)
181 defer func() {
182 slog.Info("Initialized agent base tools", "agent", agentCfg.ID)
183 }()
184
185 // Base tools available to all agents
186 cwd := cfg.WorkingDir()
187 result := make(map[string]tools.BaseTool)
188 for _, tool := range []tools.BaseTool{
189 tools.NewBashTool(permissions, cwd, cfg.Options.Attribution),
190 tools.NewDownloadTool(permissions, cwd),
191 tools.NewEditTool(lspClients, permissions, history, cwd),
192 tools.NewMultiEditTool(lspClients, permissions, history, cwd),
193 tools.NewFetchTool(permissions, cwd),
194 tools.NewGlobTool(cwd),
195 tools.NewGrepTool(cwd),
196 tools.NewLsTool(permissions, cwd),
197 tools.NewSourcegraphTool(),
198 tools.NewViewTool(lspClients, permissions, cwd),
199 tools.NewWriteTool(lspClients, permissions, history, cwd),
200 } {
201 result[tool.Name()] = tool
202 }
203 return result
204 }
205 mcpToolsFn := func() map[string]tools.BaseTool {
206 slog.Info("Initializing agent mcp tools", "agent", agentCfg.ID)
207 defer func() {
208 slog.Info("Initialized agent mcp tools", "agent", agentCfg.ID)
209 }()
210
211 mcpToolsOnce.Do(func() {
212 doGetMCPTools(ctx, permissions, cfg)
213 })
214
215 return maps.Collect(mcpTools.Seq2())
216 }
217
218 a := &agent{
219 Broker: pubsub.NewBroker[AgentEvent](),
220 agentCfg: agentCfg,
221 provider: agentProvider,
222 providerID: string(providerCfg.ID),
223 messages: messages,
224 sessions: sessions,
225 titleProvider: titleProvider,
226 summarizeProvider: summarizeProvider,
227 summarizeProviderID: string(providerCfg.ID),
228 agentToolFn: agentToolFn,
229 activeRequests: csync.NewMap[string, context.CancelFunc](),
230 mcpTools: csync.NewLazyMap(mcpToolsFn),
231 baseTools: csync.NewLazyMap(baseToolsFn),
232 promptQueue: csync.NewMap[string, []string](),
233 permissions: permissions,
234 lspClients: lspClients,
235 }
236 a.setupEvents(ctx)
237 return a, nil
238}
239
240func (a *agent) Model() catwalk.Model {
241 return *config.Get().GetModelByType(a.agentCfg.Model)
242}
243
244func (a *agent) Cancel(sessionID string) {
245 // Cancel regular requests
246 if cancel, ok := a.activeRequests.Take(sessionID); ok && cancel != nil {
247 slog.Info("Request cancellation initiated", "session_id", sessionID)
248 cancel()
249 }
250
251 // Also check for summarize requests
252 if cancel, ok := a.activeRequests.Take(sessionID + "-summarize"); ok && cancel != nil {
253 slog.Info("Summarize cancellation initiated", "session_id", sessionID)
254 cancel()
255 }
256
257 if a.QueuedPrompts(sessionID) > 0 {
258 slog.Info("Clearing queued prompts", "session_id", sessionID)
259 a.promptQueue.Del(sessionID)
260 }
261}
262
263func (a *agent) IsBusy() bool {
264 var busy bool
265 for cancelFunc := range a.activeRequests.Seq() {
266 if cancelFunc != nil {
267 busy = true
268 break
269 }
270 }
271 return busy
272}
273
274func (a *agent) IsSessionBusy(sessionID string) bool {
275 _, busy := a.activeRequests.Get(sessionID)
276 return busy
277}
278
279func (a *agent) QueuedPrompts(sessionID string) int {
280 l, ok := a.promptQueue.Get(sessionID)
281 if !ok {
282 return 0
283 }
284 return len(l)
285}
286
287func (a *agent) generateTitle(ctx context.Context, sessionID string, content string) error {
288 if content == "" {
289 return nil
290 }
291 if a.titleProvider == nil {
292 return nil
293 }
294 session, err := a.sessions.Get(ctx, sessionID)
295 if err != nil {
296 return err
297 }
298 parts := []message.ContentPart{message.TextContent{
299 Text: fmt.Sprintf("Generate a concise title for the following content:\n\n%s", content),
300 }}
301
302 // Use streaming approach like summarization
303 response := a.titleProvider.StreamResponse(
304 ctx,
305 []message.Message{
306 {
307 Role: message.User,
308 Parts: parts,
309 },
310 },
311 nil,
312 )
313
314 var finalResponse *provider.ProviderResponse
315 for r := range response {
316 if r.Error != nil {
317 return r.Error
318 }
319 finalResponse = r.Response
320 }
321
322 if finalResponse == nil {
323 return fmt.Errorf("no response received from title provider")
324 }
325
326 title := strings.ReplaceAll(finalResponse.Content, "\n", " ")
327
328 if idx := strings.Index(title, "</think>"); idx > 0 {
329 title = title[idx+len("</think>"):]
330 }
331
332 title = strings.TrimSpace(title)
333 if title == "" {
334 return nil
335 }
336
337 session.Title = title
338 _, err = a.sessions.Save(ctx, session)
339 return err
340}
341
342func (a *agent) err(err error) AgentEvent {
343 return AgentEvent{
344 Type: AgentEventTypeError,
345 Error: err,
346 }
347}
348
349func (a *agent) Run(ctx context.Context, sessionID string, content string, attachments ...message.Attachment) (<-chan AgentEvent, error) {
350 if !a.Model().SupportsImages && attachments != nil {
351 attachments = nil
352 }
353 events := make(chan AgentEvent, 1)
354 if a.IsSessionBusy(sessionID) {
355 existing, ok := a.promptQueue.Get(sessionID)
356 if !ok {
357 existing = []string{}
358 }
359 existing = append(existing, content)
360 a.promptQueue.Set(sessionID, existing)
361 return nil, nil
362 }
363
364 genCtx, cancel := context.WithCancel(ctx)
365 a.activeRequests.Set(sessionID, cancel)
366 startTime := time.Now()
367
368 go func() {
369 slog.Debug("Request started", "sessionID", sessionID)
370 defer log.RecoverPanic("agent.Run", func() {
371 events <- a.err(fmt.Errorf("panic while running the agent"))
372 })
373 var attachmentParts []message.ContentPart
374 for _, attachment := range attachments {
375 attachmentParts = append(attachmentParts, message.BinaryContent{Path: attachment.FilePath, MIMEType: attachment.MimeType, Data: attachment.Content})
376 }
377 result := a.processGeneration(genCtx, sessionID, content, attachmentParts)
378 if result.Error != nil {
379 if isCancelledErr(result.Error) {
380 slog.Error("Request canceled", "sessionID", sessionID)
381 } else {
382 slog.Error("Request errored", "sessionID", sessionID, "error", result.Error.Error())
383 event.Error(result.Error)
384 }
385 } else {
386 slog.Debug("Request completed", "sessionID", sessionID)
387 }
388 a.eventPromptResponded(sessionID, time.Since(startTime).Truncate(time.Second))
389 a.activeRequests.Del(sessionID)
390 cancel()
391 a.Publish(pubsub.CreatedEvent, result)
392 events <- result
393 close(events)
394 }()
395 a.eventPromptSent(sessionID)
396 return events, nil
397}
398
399func (a *agent) processGeneration(ctx context.Context, sessionID, content string, attachmentParts []message.ContentPart) AgentEvent {
400 cfg := config.Get()
401 // List existing messages; if none, start title generation asynchronously.
402 msgs, err := a.messages.List(ctx, sessionID)
403 if err != nil {
404 return a.err(fmt.Errorf("failed to list messages: %w", err))
405 }
406 if len(msgs) == 0 {
407 go func() {
408 defer log.RecoverPanic("agent.Run", func() {
409 slog.Error("panic while generating title")
410 })
411 titleErr := a.generateTitle(ctx, sessionID, content)
412 if titleErr != nil && !errors.Is(titleErr, context.Canceled) && !errors.Is(titleErr, context.DeadlineExceeded) {
413 slog.Error("failed to generate title", "error", titleErr)
414 }
415 }()
416 }
417 session, err := a.sessions.Get(ctx, sessionID)
418 if err != nil {
419 return a.err(fmt.Errorf("failed to get session: %w", err))
420 }
421 if session.SummaryMessageID != "" {
422 summaryMsgInex := -1
423 for i, msg := range msgs {
424 if msg.ID == session.SummaryMessageID {
425 summaryMsgInex = i
426 break
427 }
428 }
429 if summaryMsgInex != -1 {
430 msgs = msgs[summaryMsgInex:]
431 msgs[0].Role = message.User
432 }
433 }
434
435 userMsg, err := a.createUserMessage(ctx, sessionID, content, attachmentParts)
436 if err != nil {
437 return a.err(fmt.Errorf("failed to create user message: %w", err))
438 }
439 // Append the new user message to the conversation history.
440 msgHistory := append(msgs, userMsg)
441
442 for {
443 // Check for cancellation before each iteration
444 select {
445 case <-ctx.Done():
446 return a.err(ctx.Err())
447 default:
448 // Continue processing
449 }
450 agentMessage, toolResults, err := a.streamAndHandleEvents(ctx, sessionID, msgHistory)
451 if err != nil {
452 if errors.Is(err, context.Canceled) {
453 agentMessage.AddFinish(message.FinishReasonCanceled, "Request cancelled", "")
454 a.messages.Update(context.Background(), agentMessage)
455 return a.err(ErrRequestCancelled)
456 }
457 return a.err(fmt.Errorf("failed to process events: %w", err))
458 }
459 if cfg.Options.Debug {
460 slog.Info("Result", "message", agentMessage.FinishReason(), "toolResults", toolResults)
461 }
462 if (agentMessage.FinishReason() == message.FinishReasonToolUse) && toolResults != nil {
463 // We are not done, we need to respond with the tool response
464 msgHistory = append(msgHistory, agentMessage, *toolResults)
465 // If there are queued prompts, process the next one
466 nextPrompt, ok := a.promptQueue.Take(sessionID)
467 if ok {
468 for _, prompt := range nextPrompt {
469 // Create a new user message for the queued prompt
470 userMsg, err := a.createUserMessage(ctx, sessionID, prompt, nil)
471 if err != nil {
472 return a.err(fmt.Errorf("failed to create user message for queued prompt: %w", err))
473 }
474 // Append the new user message to the conversation history
475 msgHistory = append(msgHistory, userMsg)
476 }
477 }
478
479 continue
480 } else if agentMessage.FinishReason() == message.FinishReasonEndTurn {
481 queuePrompts, ok := a.promptQueue.Take(sessionID)
482 if ok {
483 for _, prompt := range queuePrompts {
484 if prompt == "" {
485 continue
486 }
487 userMsg, err := a.createUserMessage(ctx, sessionID, prompt, nil)
488 if err != nil {
489 return a.err(fmt.Errorf("failed to create user message for queued prompt: %w", err))
490 }
491 msgHistory = append(msgHistory, userMsg)
492 }
493 continue
494 }
495 }
496 if agentMessage.FinishReason() == "" {
497 // Kujtim: could not track down where this is happening but this means its cancelled
498 agentMessage.AddFinish(message.FinishReasonCanceled, "Request cancelled", "")
499 _ = a.messages.Update(context.Background(), agentMessage)
500 return a.err(ErrRequestCancelled)
501 }
502 return AgentEvent{
503 Type: AgentEventTypeResponse,
504 Message: agentMessage,
505 Done: true,
506 }
507 }
508}
509
510func (a *agent) createUserMessage(ctx context.Context, sessionID, content string, attachmentParts []message.ContentPart) (message.Message, error) {
511 parts := []message.ContentPart{message.TextContent{Text: content}}
512 parts = append(parts, attachmentParts...)
513 return a.messages.Create(ctx, sessionID, message.CreateMessageParams{
514 Role: message.User,
515 Parts: parts,
516 })
517}
518
519func (a *agent) getAllTools() ([]tools.BaseTool, error) {
520 allTools := slices.Collect(a.baseTools.Seq())
521
522 withCoderTools := func(t []tools.BaseTool) []tools.BaseTool {
523 if a.agentCfg.ID == "coder" {
524 t = append(t, slices.Collect(a.mcpTools.Seq())...)
525 if a.lspClients.Len() > 0 {
526 t = append(t, tools.NewDiagnosticsTool(a.lspClients))
527 }
528 }
529 return t
530 }
531
532 if a.agentCfg.AllowedTools == nil {
533 allTools = withCoderTools(allTools)
534 } else {
535 var filteredTools []tools.BaseTool
536 for _, tool := range allTools {
537 if slices.Contains(a.agentCfg.AllowedTools, tool.Name()) {
538 filteredTools = append(filteredTools, tool)
539 }
540 }
541 allTools = withCoderTools(filteredTools)
542 }
543
544 if a.agentToolFn != nil {
545 agentTool, agentToolErr := a.agentToolFn()
546 if agentToolErr != nil {
547 return nil, agentToolErr
548 }
549 allTools = append(allTools, agentTool)
550 }
551 return allTools, nil
552}
553
554func (a *agent) streamAndHandleEvents(ctx context.Context, sessionID string, msgHistory []message.Message) (message.Message, *message.Message, error) {
555 ctx = context.WithValue(ctx, tools.SessionIDContextKey, sessionID)
556
557 // Create the assistant message first so the spinner shows immediately
558 assistantMsg, err := a.messages.Create(ctx, sessionID, message.CreateMessageParams{
559 Role: message.Assistant,
560 Parts: []message.ContentPart{},
561 Model: a.Model().ID,
562 Provider: a.providerID,
563 })
564 if err != nil {
565 return assistantMsg, nil, fmt.Errorf("failed to create assistant message: %w", err)
566 }
567
568 allTools, toolsErr := a.getAllTools()
569 if toolsErr != nil {
570 return assistantMsg, nil, toolsErr
571 }
572 // Now collect tools (which may block on MCP initialization)
573 eventChan := a.provider.StreamResponse(ctx, msgHistory, allTools)
574
575 // Add the session and message ID into the context if needed by tools.
576 ctx = context.WithValue(ctx, tools.MessageIDContextKey, assistantMsg.ID)
577
578loop:
579 for {
580 select {
581 case event, ok := <-eventChan:
582 if !ok {
583 break loop
584 }
585 if processErr := a.processEvent(ctx, sessionID, &assistantMsg, event); processErr != nil {
586 if errors.Is(processErr, context.Canceled) {
587 a.finishMessage(context.Background(), &assistantMsg, message.FinishReasonCanceled, "Request cancelled", "")
588 } else {
589 a.finishMessage(ctx, &assistantMsg, message.FinishReasonError, "API Error", processErr.Error())
590 }
591 return assistantMsg, nil, processErr
592 }
593 case <-ctx.Done():
594 a.finishMessage(context.Background(), &assistantMsg, message.FinishReasonCanceled, "Request cancelled", "")
595 return assistantMsg, nil, ctx.Err()
596 }
597 }
598
599 toolResults := make([]message.ToolResult, len(assistantMsg.ToolCalls()))
600 toolCalls := assistantMsg.ToolCalls()
601 for i, toolCall := range toolCalls {
602 select {
603 case <-ctx.Done():
604 a.finishMessage(context.Background(), &assistantMsg, message.FinishReasonCanceled, "Request cancelled", "")
605 // Make all future tool calls cancelled
606 for j := i; j < len(toolCalls); j++ {
607 toolResults[j] = message.ToolResult{
608 ToolCallID: toolCalls[j].ID,
609 Content: "Tool execution canceled by user",
610 IsError: true,
611 }
612 }
613 goto out
614 default:
615 // Continue processing
616 var tool tools.BaseTool
617 allTools, _ = a.getAllTools()
618 for _, availableTool := range allTools {
619 if availableTool.Info().Name == toolCall.Name {
620 tool = availableTool
621 break
622 }
623 }
624
625 // Tool not found
626 if tool == nil {
627 toolResults[i] = message.ToolResult{
628 ToolCallID: toolCall.ID,
629 Content: fmt.Sprintf("Tool not found: %s", toolCall.Name),
630 IsError: true,
631 }
632 continue
633 }
634
635 // Run tool in goroutine to allow cancellation
636 type toolExecResult struct {
637 response tools.ToolResponse
638 err error
639 }
640 resultChan := make(chan toolExecResult, 1)
641
642 go func() {
643 response, err := tool.Run(ctx, tools.ToolCall{
644 ID: toolCall.ID,
645 Name: toolCall.Name,
646 Input: toolCall.Input,
647 })
648 resultChan <- toolExecResult{response: response, err: err}
649 }()
650
651 var toolResponse tools.ToolResponse
652 var toolErr error
653
654 select {
655 case <-ctx.Done():
656 a.finishMessage(context.Background(), &assistantMsg, message.FinishReasonCanceled, "Request cancelled", "")
657 // Mark remaining tool calls as cancelled
658 for j := i; j < len(toolCalls); j++ {
659 toolResults[j] = message.ToolResult{
660 ToolCallID: toolCalls[j].ID,
661 Content: "Tool execution canceled by user",
662 IsError: true,
663 }
664 }
665 goto out
666 case result := <-resultChan:
667 toolResponse = result.response
668 toolErr = result.err
669 }
670
671 if toolErr != nil {
672 slog.Error("Tool execution error", "toolCall", toolCall.ID, "error", toolErr)
673 if errors.Is(toolErr, permission.ErrorPermissionDenied) {
674 toolResults[i] = message.ToolResult{
675 ToolCallID: toolCall.ID,
676 Content: "Permission denied",
677 IsError: true,
678 }
679 for j := i + 1; j < len(toolCalls); j++ {
680 toolResults[j] = message.ToolResult{
681 ToolCallID: toolCalls[j].ID,
682 Content: "Tool execution canceled by user",
683 IsError: true,
684 }
685 }
686 a.finishMessage(ctx, &assistantMsg, message.FinishReasonPermissionDenied, "Permission denied", "")
687 break
688 }
689 }
690 toolResults[i] = message.ToolResult{
691 ToolCallID: toolCall.ID,
692 Content: toolResponse.Content,
693 Metadata: toolResponse.Metadata,
694 IsError: toolResponse.IsError,
695 }
696 }
697 }
698out:
699 if len(toolResults) == 0 {
700 return assistantMsg, nil, nil
701 }
702 parts := make([]message.ContentPart, 0)
703 for _, tr := range toolResults {
704 parts = append(parts, tr)
705 }
706 msg, err := a.messages.Create(context.Background(), assistantMsg.SessionID, message.CreateMessageParams{
707 Role: message.Tool,
708 Parts: parts,
709 Provider: a.providerID,
710 })
711 if err != nil {
712 return assistantMsg, nil, fmt.Errorf("failed to create cancelled tool message: %w", err)
713 }
714
715 return assistantMsg, &msg, err
716}
717
718func (a *agent) finishMessage(ctx context.Context, msg *message.Message, finishReason message.FinishReason, message, details string) {
719 msg.AddFinish(finishReason, message, details)
720 _ = a.messages.Update(ctx, *msg)
721}
722
723func (a *agent) processEvent(ctx context.Context, sessionID string, assistantMsg *message.Message, event provider.ProviderEvent) error {
724 select {
725 case <-ctx.Done():
726 return ctx.Err()
727 default:
728 // Continue processing.
729 }
730
731 switch event.Type {
732 case provider.EventThinkingDelta:
733 assistantMsg.AppendReasoningContent(event.Thinking)
734 return a.messages.Update(ctx, *assistantMsg)
735 case provider.EventSignatureDelta:
736 assistantMsg.AppendReasoningSignature(event.Signature)
737 return a.messages.Update(ctx, *assistantMsg)
738 case provider.EventContentDelta:
739 assistantMsg.FinishThinking()
740 assistantMsg.AppendContent(event.Content)
741 return a.messages.Update(ctx, *assistantMsg)
742 case provider.EventToolUseStart:
743 assistantMsg.FinishThinking()
744 slog.Info("Tool call started", "toolCall", event.ToolCall)
745 assistantMsg.AddToolCall(*event.ToolCall)
746 return a.messages.Update(ctx, *assistantMsg)
747 case provider.EventToolUseDelta:
748 assistantMsg.AppendToolCallInput(event.ToolCall.ID, event.ToolCall.Input)
749 return a.messages.Update(ctx, *assistantMsg)
750 case provider.EventToolUseStop:
751 slog.Info("Finished tool call", "toolCall", event.ToolCall)
752 assistantMsg.FinishToolCall(event.ToolCall.ID)
753 return a.messages.Update(ctx, *assistantMsg)
754 case provider.EventError:
755 return event.Error
756 case provider.EventComplete:
757 assistantMsg.FinishThinking()
758 assistantMsg.SetToolCalls(event.Response.ToolCalls)
759 assistantMsg.AddFinish(event.Response.FinishReason, "", "")
760 if err := a.messages.Update(ctx, *assistantMsg); err != nil {
761 return fmt.Errorf("failed to update message: %w", err)
762 }
763 return a.trackUsage(ctx, sessionID, a.Model(), event.Response.Usage)
764 }
765
766 return nil
767}
768
769func (a *agent) trackUsage(ctx context.Context, sessionID string, model catwalk.Model, usage provider.TokenUsage) error {
770 sess, err := a.sessions.Get(ctx, sessionID)
771 if err != nil {
772 return fmt.Errorf("failed to get session: %w", err)
773 }
774
775 cost := model.CostPer1MInCached/1e6*float64(usage.CacheCreationTokens) +
776 model.CostPer1MOutCached/1e6*float64(usage.CacheReadTokens) +
777 model.CostPer1MIn/1e6*float64(usage.InputTokens) +
778 model.CostPer1MOut/1e6*float64(usage.OutputTokens)
779
780 a.eventTokensUsed(sessionID, usage, cost)
781
782 sess.Cost += cost
783 sess.CompletionTokens = usage.OutputTokens + usage.CacheReadTokens
784 sess.PromptTokens = usage.InputTokens + usage.CacheCreationTokens
785
786 _, err = a.sessions.Save(ctx, sess)
787 if err != nil {
788 return fmt.Errorf("failed to save session: %w", err)
789 }
790 return nil
791}
792
793func (a *agent) Summarize(ctx context.Context, sessionID string) error {
794 if a.summarizeProvider == nil {
795 return fmt.Errorf("summarize provider not available")
796 }
797
798 // Check if session is busy
799 if a.IsSessionBusy(sessionID) {
800 return ErrSessionBusy
801 }
802
803 // Create a new context with cancellation
804 summarizeCtx, cancel := context.WithCancel(ctx)
805
806 // Store the cancel function in activeRequests to allow cancellation
807 a.activeRequests.Set(sessionID+"-summarize", cancel)
808
809 go func() {
810 defer a.activeRequests.Del(sessionID + "-summarize")
811 defer cancel()
812 event := AgentEvent{
813 Type: AgentEventTypeSummarize,
814 Progress: "Starting summarization...",
815 }
816
817 a.Publish(pubsub.CreatedEvent, event)
818 // Get all messages from the session
819 msgs, err := a.messages.List(summarizeCtx, sessionID)
820 if err != nil {
821 event = AgentEvent{
822 Type: AgentEventTypeError,
823 Error: fmt.Errorf("failed to list messages: %w", err),
824 Done: true,
825 }
826 a.Publish(pubsub.CreatedEvent, event)
827 return
828 }
829 summarizeCtx = context.WithValue(summarizeCtx, tools.SessionIDContextKey, sessionID)
830
831 if len(msgs) == 0 {
832 event = AgentEvent{
833 Type: AgentEventTypeError,
834 Error: fmt.Errorf("no messages to summarize"),
835 Done: true,
836 }
837 a.Publish(pubsub.CreatedEvent, event)
838 return
839 }
840
841 event = AgentEvent{
842 Type: AgentEventTypeSummarize,
843 Progress: "Analyzing conversation...",
844 }
845 a.Publish(pubsub.CreatedEvent, event)
846
847 // Add a system message to guide the summarization
848 summarizePrompt := "Provide a detailed but concise summary of our conversation above. Focus on information that would be helpful for continuing the conversation, including what we did, what we're doing, which files we're working on, and what we're going to do next."
849
850 // Create a new message with the summarize prompt
851 promptMsg := message.Message{
852 Role: message.User,
853 Parts: []message.ContentPart{message.TextContent{Text: summarizePrompt}},
854 }
855
856 // Append the prompt to the messages
857 msgsWithPrompt := append(msgs, promptMsg)
858
859 event = AgentEvent{
860 Type: AgentEventTypeSummarize,
861 Progress: "Generating summary...",
862 }
863
864 a.Publish(pubsub.CreatedEvent, event)
865
866 // Send the messages to the summarize provider
867 response := a.summarizeProvider.StreamResponse(
868 summarizeCtx,
869 msgsWithPrompt,
870 nil,
871 )
872 var finalResponse *provider.ProviderResponse
873 for r := range response {
874 if r.Error != nil {
875 event = AgentEvent{
876 Type: AgentEventTypeError,
877 Error: fmt.Errorf("failed to summarize: %w", r.Error),
878 Done: true,
879 }
880 a.Publish(pubsub.CreatedEvent, event)
881 return
882 }
883 finalResponse = r.Response
884 }
885
886 summary := strings.TrimSpace(finalResponse.Content)
887 if summary == "" {
888 event = AgentEvent{
889 Type: AgentEventTypeError,
890 Error: fmt.Errorf("empty summary returned"),
891 Done: true,
892 }
893 a.Publish(pubsub.CreatedEvent, event)
894 return
895 }
896 shell := shell.GetPersistentShell(config.Get().WorkingDir())
897 summary += "\n\n**Current working directory of the persistent shell**\n\n" + shell.GetWorkingDir()
898 event = AgentEvent{
899 Type: AgentEventTypeSummarize,
900 Progress: "Creating new session...",
901 }
902
903 a.Publish(pubsub.CreatedEvent, event)
904 oldSession, err := a.sessions.Get(summarizeCtx, sessionID)
905 if err != nil {
906 event = AgentEvent{
907 Type: AgentEventTypeError,
908 Error: fmt.Errorf("failed to get session: %w", err),
909 Done: true,
910 }
911
912 a.Publish(pubsub.CreatedEvent, event)
913 return
914 }
915 // Create a message in the new session with the summary
916 msg, err := a.messages.Create(summarizeCtx, oldSession.ID, message.CreateMessageParams{
917 Role: message.Assistant,
918 Parts: []message.ContentPart{
919 message.TextContent{Text: summary},
920 message.Finish{
921 Reason: message.FinishReasonEndTurn,
922 Time: time.Now().Unix(),
923 },
924 },
925 Model: a.summarizeProvider.Model().ID,
926 Provider: a.summarizeProviderID,
927 })
928 if err != nil {
929 event = AgentEvent{
930 Type: AgentEventTypeError,
931 Error: fmt.Errorf("failed to create summary message: %w", err),
932 Done: true,
933 }
934
935 a.Publish(pubsub.CreatedEvent, event)
936 return
937 }
938 oldSession.SummaryMessageID = msg.ID
939 oldSession.CompletionTokens = finalResponse.Usage.OutputTokens
940 oldSession.PromptTokens = 0
941 model := a.summarizeProvider.Model()
942 usage := finalResponse.Usage
943 cost := model.CostPer1MInCached/1e6*float64(usage.CacheCreationTokens) +
944 model.CostPer1MOutCached/1e6*float64(usage.CacheReadTokens) +
945 model.CostPer1MIn/1e6*float64(usage.InputTokens) +
946 model.CostPer1MOut/1e6*float64(usage.OutputTokens)
947 oldSession.Cost += cost
948 _, err = a.sessions.Save(summarizeCtx, oldSession)
949 if err != nil {
950 event = AgentEvent{
951 Type: AgentEventTypeError,
952 Error: fmt.Errorf("failed to save session: %w", err),
953 Done: true,
954 }
955 a.Publish(pubsub.CreatedEvent, event)
956 }
957
958 event = AgentEvent{
959 Type: AgentEventTypeSummarize,
960 SessionID: oldSession.ID,
961 Progress: "Summary complete",
962 Done: true,
963 }
964 a.Publish(pubsub.CreatedEvent, event)
965 // Send final success event with the new session ID
966 }()
967
968 return nil
969}
970
971func (a *agent) ClearQueue(sessionID string) {
972 if a.QueuedPrompts(sessionID) > 0 {
973 slog.Info("Clearing queued prompts", "session_id", sessionID)
974 a.promptQueue.Del(sessionID)
975 }
976}
977
978func (a *agent) CancelAll() {
979 if !a.IsBusy() {
980 return
981 }
982 for key := range a.activeRequests.Seq2() {
983 a.Cancel(key) // key is sessionID
984 }
985
986 for _, cleanup := range a.cleanupFuncs {
987 if cleanup != nil {
988 cleanup()
989 }
990 }
991
992 timeout := time.After(5 * time.Second)
993 for a.IsBusy() {
994 select {
995 case <-timeout:
996 return
997 default:
998 time.Sleep(200 * time.Millisecond)
999 }
1000 }
1001}
1002
1003func (a *agent) UpdateModel() error {
1004 cfg := config.Get()
1005
1006 // Get current provider configuration
1007 currentProviderCfg := cfg.GetProviderForModel(a.agentCfg.Model)
1008 if currentProviderCfg == nil || currentProviderCfg.ID == "" {
1009 return fmt.Errorf("provider for agent %s not found in config", a.agentCfg.Name)
1010 }
1011
1012 // Check if provider has changed
1013 if string(currentProviderCfg.ID) != a.providerID {
1014 // Provider changed, need to recreate the main provider
1015 model := cfg.GetModelByType(a.agentCfg.Model)
1016 if model.ID == "" {
1017 return fmt.Errorf("model not found for agent %s", a.agentCfg.Name)
1018 }
1019
1020 promptID := agentPromptMap[a.agentCfg.ID]
1021 if promptID == "" {
1022 promptID = prompt.PromptDefault
1023 }
1024
1025 opts := []provider.ProviderClientOption{
1026 provider.WithModel(a.agentCfg.Model),
1027 provider.WithSystemMessage(prompt.GetPrompt(promptID, currentProviderCfg.ID, cfg.Options.ContextPaths...)),
1028 }
1029
1030 newProvider, err := provider.NewProvider(*currentProviderCfg, opts...)
1031 if err != nil {
1032 return fmt.Errorf("failed to create new provider: %w", err)
1033 }
1034
1035 // Update the provider and provider ID
1036 a.provider = newProvider
1037 a.providerID = string(currentProviderCfg.ID)
1038 }
1039
1040 // Check if providers have changed for title (small) and summarize (large)
1041 smallModelCfg := cfg.Models[config.SelectedModelTypeSmall]
1042 var smallModelProviderCfg config.ProviderConfig
1043 for p := range cfg.Providers.Seq() {
1044 if p.ID == smallModelCfg.Provider {
1045 smallModelProviderCfg = p
1046 break
1047 }
1048 }
1049 if smallModelProviderCfg.ID == "" {
1050 return fmt.Errorf("provider %s not found in config", smallModelCfg.Provider)
1051 }
1052
1053 largeModelCfg := cfg.Models[config.SelectedModelTypeLarge]
1054 var largeModelProviderCfg config.ProviderConfig
1055 for p := range cfg.Providers.Seq() {
1056 if p.ID == largeModelCfg.Provider {
1057 largeModelProviderCfg = p
1058 break
1059 }
1060 }
1061 if largeModelProviderCfg.ID == "" {
1062 return fmt.Errorf("provider %s not found in config", largeModelCfg.Provider)
1063 }
1064
1065 var maxTitleTokens int64 = 40
1066
1067 // if the max output is too low for the gemini provider it won't return anything
1068 if smallModelCfg.Provider == "gemini" {
1069 maxTitleTokens = 1000
1070 }
1071 // Recreate title provider
1072 titleOpts := []provider.ProviderClientOption{
1073 provider.WithModel(config.SelectedModelTypeSmall),
1074 provider.WithSystemMessage(prompt.GetPrompt(prompt.PromptTitle, smallModelProviderCfg.ID)),
1075 provider.WithMaxTokens(maxTitleTokens),
1076 }
1077 newTitleProvider, err := provider.NewProvider(smallModelProviderCfg, titleOpts...)
1078 if err != nil {
1079 return fmt.Errorf("failed to create new title provider: %w", err)
1080 }
1081 a.titleProvider = newTitleProvider
1082
1083 // Recreate summarize provider if provider changed (now large model)
1084 if string(largeModelProviderCfg.ID) != a.summarizeProviderID {
1085 largeModel := cfg.GetModelByType(config.SelectedModelTypeLarge)
1086 if largeModel == nil {
1087 return fmt.Errorf("model %s not found in provider %s", largeModelCfg.Model, largeModelProviderCfg.ID)
1088 }
1089 summarizeOpts := []provider.ProviderClientOption{
1090 provider.WithModel(config.SelectedModelTypeLarge),
1091 provider.WithSystemMessage(prompt.GetPrompt(prompt.PromptSummarizer, largeModelProviderCfg.ID)),
1092 }
1093 newSummarizeProvider, err := provider.NewProvider(largeModelProviderCfg, summarizeOpts...)
1094 if err != nil {
1095 return fmt.Errorf("failed to create new summarize provider: %w", err)
1096 }
1097 a.summarizeProvider = newSummarizeProvider
1098 a.summarizeProviderID = string(largeModelProviderCfg.ID)
1099 }
1100
1101 return nil
1102}
1103
1104func (a *agent) setupEvents(ctx context.Context) {
1105 ctx, cancel := context.WithCancel(ctx)
1106
1107 go func() {
1108 subCh := SubscribeMCPEvents(ctx)
1109
1110 for {
1111 select {
1112 case event, ok := <-subCh:
1113 if !ok {
1114 slog.Debug("MCPEvents subscription channel closed")
1115 return
1116 }
1117 switch event.Payload.Type {
1118 case MCPEventToolsListChanged:
1119 name := event.Payload.Name
1120 c, ok := mcpClients.Get(name)
1121 if !ok {
1122 slog.Warn("MCP client not found for tools update", "name", name)
1123 continue
1124 }
1125 cfg := config.Get()
1126 tools, err := getTools(ctx, name, a.permissions, c, cfg.WorkingDir())
1127 if err != nil {
1128 slog.Error("error listing tools", "error", err)
1129 updateMCPState(name, MCPStateError, err, nil, 0)
1130 _ = c.Close()
1131 continue
1132 }
1133 updateMcpTools(name, tools)
1134 // Update the lazy map with the new tools
1135 a.mcpTools = csync.NewMapFrom(maps.Collect(mcpTools.Seq2()))
1136 updateMCPState(name, MCPStateConnected, nil, c, a.mcpTools.Len())
1137 default:
1138 continue
1139 }
1140 case <-ctx.Done():
1141 slog.Debug("MCPEvents subscription cancelled")
1142 return
1143 }
1144 }
1145 }()
1146
1147 a.cleanupFuncs = append(a.cleanupFuncs, cancel)
1148}