1package agent
2
3import (
4 "context"
5 "errors"
6 "fmt"
7 "log/slog"
8 "maps"
9 "slices"
10 "strings"
11 "time"
12
13 "github.com/charmbracelet/catwalk/pkg/catwalk"
14 "github.com/charmbracelet/crush/internal/config"
15 "github.com/charmbracelet/crush/internal/csync"
16 "github.com/charmbracelet/crush/internal/event"
17 "github.com/charmbracelet/crush/internal/history"
18 "github.com/charmbracelet/crush/internal/llm/prompt"
19 "github.com/charmbracelet/crush/internal/llm/provider"
20 "github.com/charmbracelet/crush/internal/llm/tools"
21 "github.com/charmbracelet/crush/internal/log"
22 "github.com/charmbracelet/crush/internal/lsp"
23 "github.com/charmbracelet/crush/internal/message"
24 "github.com/charmbracelet/crush/internal/permission"
25 "github.com/charmbracelet/crush/internal/pubsub"
26 "github.com/charmbracelet/crush/internal/session"
27 "github.com/charmbracelet/crush/internal/shell"
28)
29
// AgentEventType identifies the kind of event an agent emits.
type AgentEventType string

// Kinds of AgentEvent produced by the agent.
const (
	// AgentEventTypeError marks an event whose Error field holds a failure.
	AgentEventTypeError AgentEventType = "error"
	// AgentEventTypeResponse marks an event carrying the assistant's message.
	AgentEventTypeResponse AgentEventType = "response"
	// AgentEventTypeSummarize marks a progress or completion event emitted
	// while a session is being summarized.
	AgentEventTypeSummarize AgentEventType = "summarize"
)
37
38type AgentEvent struct {
39 Type AgentEventType
40 Message message.Message
41 Error error
42
43 // When summarizing
44 SessionID string
45 Progress string
46 Done bool
47}
48
49type Service interface {
50 pubsub.Suscriber[AgentEvent]
51 Model() catwalk.Model
52 Run(ctx context.Context, sessionID string, content string, attachments ...message.Attachment) (<-chan AgentEvent, error)
53 Cancel(sessionID string)
54 CancelAll()
55 IsSessionBusy(sessionID string) bool
56 IsBusy() bool
57 Summarize(ctx context.Context, sessionID string) error
58 UpdateModel() error
59 QueuedPrompts(sessionID string) int
60 ClearQueue(sessionID string)
61}
62
63type agent struct {
64 *pubsub.Broker[AgentEvent]
65 agentCfg config.Agent
66 sessions session.Service
67 messages message.Service
68 permissions permission.Service
69 baseTools *csync.Map[string, tools.BaseTool]
70 mcpTools *csync.Map[string, tools.BaseTool]
71 lspClients *csync.Map[string, *lsp.Client]
72
73 // We need this to be able to update it when model changes
74 agentToolFn func() (tools.BaseTool, error)
75 cleanupFuncs []func()
76
77 provider provider.Provider
78 providerID string
79
80 titleProvider provider.Provider
81 summarizeProvider provider.Provider
82 summarizeProviderID string
83
84 activeRequests *csync.Map[string, context.CancelFunc]
85 promptQueue *csync.Map[string, []string]
86}
87
88var agentPromptMap = map[string]prompt.PromptID{
89 "coder": prompt.PromptCoder,
90 "task": prompt.PromptTask,
91}
92
93func NewAgent(
94 ctx context.Context,
95 agentCfg config.Agent,
96 // These services are needed in the tools
97 permissions permission.Service,
98 sessions session.Service,
99 messages message.Service,
100 history history.Service,
101 lspClients *csync.Map[string, *lsp.Client],
102) (Service, error) {
103 cfg := config.Get()
104
105 var agentToolFn func() (tools.BaseTool, error)
106 if agentCfg.ID == "coder" && slices.Contains(agentCfg.AllowedTools, AgentToolName) {
107 agentToolFn = func() (tools.BaseTool, error) {
108 taskAgentCfg := config.Get().Agents["task"]
109 if taskAgentCfg.ID == "" {
110 return nil, fmt.Errorf("task agent not found in config")
111 }
112 taskAgent, err := NewAgent(ctx, taskAgentCfg, permissions, sessions, messages, history, lspClients)
113 if err != nil {
114 return nil, fmt.Errorf("failed to create task agent: %w", err)
115 }
116 return NewAgentTool(taskAgent, sessions, messages), nil
117 }
118 }
119
120 providerCfg := config.Get().GetProviderForModel(agentCfg.Model)
121 if providerCfg == nil {
122 return nil, fmt.Errorf("provider for agent %s not found in config", agentCfg.Name)
123 }
124 model := config.Get().GetModelByType(agentCfg.Model)
125
126 if model == nil {
127 return nil, fmt.Errorf("model not found for agent %s", agentCfg.Name)
128 }
129
130 promptID := agentPromptMap[agentCfg.ID]
131 if promptID == "" {
132 promptID = prompt.PromptDefault
133 }
134 opts := []provider.ProviderClientOption{
135 provider.WithModel(agentCfg.Model),
136 provider.WithSystemMessage(prompt.GetPrompt(promptID, providerCfg.ID, config.Get().Options.ContextPaths...)),
137 }
138 agentProvider, err := provider.NewProvider(*providerCfg, opts...)
139 if err != nil {
140 return nil, err
141 }
142
143 smallModelCfg := cfg.Models[config.SelectedModelTypeSmall]
144 var smallModelProviderCfg *config.ProviderConfig
145 if smallModelCfg.Provider == providerCfg.ID {
146 smallModelProviderCfg = providerCfg
147 } else {
148 smallModelProviderCfg = cfg.GetProviderForModel(config.SelectedModelTypeSmall)
149
150 if smallModelProviderCfg.ID == "" {
151 return nil, fmt.Errorf("provider %s not found in config", smallModelCfg.Provider)
152 }
153 }
154 smallModel := cfg.GetModelByType(config.SelectedModelTypeSmall)
155 if smallModel.ID == "" {
156 return nil, fmt.Errorf("model %s not found in provider %s", smallModelCfg.Model, smallModelProviderCfg.ID)
157 }
158
159 titleOpts := []provider.ProviderClientOption{
160 provider.WithModel(config.SelectedModelTypeSmall),
161 provider.WithSystemMessage(prompt.GetPrompt(prompt.PromptTitle, smallModelProviderCfg.ID)),
162 }
163 titleProvider, err := provider.NewProvider(*smallModelProviderCfg, titleOpts...)
164 if err != nil {
165 return nil, err
166 }
167
168 summarizeOpts := []provider.ProviderClientOption{
169 provider.WithModel(config.SelectedModelTypeLarge),
170 provider.WithSystemMessage(prompt.GetPrompt(prompt.PromptSummarizer, providerCfg.ID)),
171 }
172 summarizeProvider, err := provider.NewProvider(*providerCfg, summarizeOpts...)
173 if err != nil {
174 return nil, err
175 }
176
177 baseToolsFn := func() map[string]tools.BaseTool {
178 slog.Debug("Initializing agent base tools", "agent", agentCfg.ID)
179 defer func() {
180 slog.Debug("Initialized agent base tools", "agent", agentCfg.ID)
181 }()
182
183 // Base tools available to all agents
184 cwd := cfg.WorkingDir()
185 result := make(map[string]tools.BaseTool)
186 for _, tool := range []tools.BaseTool{
187 tools.NewBashTool(permissions, cwd, cfg.Options.Attribution),
188 tools.NewDownloadTool(permissions, cwd),
189 tools.NewEditTool(lspClients, permissions, history, cwd),
190 tools.NewMultiEditTool(lspClients, permissions, history, cwd),
191 tools.NewFetchTool(permissions, cwd),
192 tools.NewGlobTool(cwd),
193 tools.NewGrepTool(cwd),
194 tools.NewLsTool(permissions, cwd),
195 tools.NewSourcegraphTool(),
196 tools.NewViewTool(lspClients, permissions, cwd),
197 tools.NewWriteTool(lspClients, permissions, history, cwd),
198 } {
199 result[tool.Name()] = tool
200 }
201 return result
202 }
203 mcpToolsFn := func() map[string]tools.BaseTool {
204 slog.Debug("Initializing agent mcp tools", "agent", agentCfg.ID)
205 defer func() {
206 slog.Debug("Initialized agent mcp tools", "agent", agentCfg.ID)
207 }()
208
209 mcpToolsOnce.Do(func() {
210 doGetMCPTools(ctx, permissions, cfg)
211 })
212
213 return maps.Collect(mcpTools.Seq2())
214 }
215
216 a := &agent{
217 Broker: pubsub.NewBroker[AgentEvent](),
218 agentCfg: agentCfg,
219 provider: agentProvider,
220 providerID: string(providerCfg.ID),
221 messages: messages,
222 sessions: sessions,
223 titleProvider: titleProvider,
224 summarizeProvider: summarizeProvider,
225 summarizeProviderID: string(providerCfg.ID),
226 agentToolFn: agentToolFn,
227 activeRequests: csync.NewMap[string, context.CancelFunc](),
228 mcpTools: csync.NewLazyMap(mcpToolsFn),
229 baseTools: csync.NewLazyMap(baseToolsFn),
230 promptQueue: csync.NewMap[string, []string](),
231 permissions: permissions,
232 lspClients: lspClients,
233 }
234 a.setupEvents(ctx)
235 return a, nil
236}
237
238func (a *agent) Model() catwalk.Model {
239 return *config.Get().GetModelByType(a.agentCfg.Model)
240}
241
242func (a *agent) Cancel(sessionID string) {
243 // Cancel regular requests
244 if cancel, ok := a.activeRequests.Take(sessionID); ok && cancel != nil {
245 slog.Info("Request cancellation initiated", "session_id", sessionID)
246 cancel()
247 }
248
249 // Also check for summarize requests
250 if cancel, ok := a.activeRequests.Take(sessionID + "-summarize"); ok && cancel != nil {
251 slog.Info("Summarize cancellation initiated", "session_id", sessionID)
252 cancel()
253 }
254
255 if a.QueuedPrompts(sessionID) > 0 {
256 slog.Info("Clearing queued prompts", "session_id", sessionID)
257 a.promptQueue.Del(sessionID)
258 }
259}
260
261func (a *agent) IsBusy() bool {
262 var busy bool
263 for cancelFunc := range a.activeRequests.Seq() {
264 if cancelFunc != nil {
265 busy = true
266 break
267 }
268 }
269 return busy
270}
271
272func (a *agent) IsSessionBusy(sessionID string) bool {
273 _, busy := a.activeRequests.Get(sessionID)
274 return busy
275}
276
277func (a *agent) QueuedPrompts(sessionID string) int {
278 l, ok := a.promptQueue.Get(sessionID)
279 if !ok {
280 return 0
281 }
282 return len(l)
283}
284
285func (a *agent) generateTitle(ctx context.Context, sessionID string, content string) error {
286 if content == "" {
287 return nil
288 }
289 if a.titleProvider == nil {
290 return nil
291 }
292 session, err := a.sessions.Get(ctx, sessionID)
293 if err != nil {
294 return err
295 }
296 parts := []message.ContentPart{message.TextContent{
297 Text: fmt.Sprintf("Generate a concise title for the following content:\n\n%s", content),
298 }}
299
300 // Use streaming approach like summarization
301 response := a.titleProvider.StreamResponse(
302 ctx,
303 []message.Message{
304 {
305 Role: message.User,
306 Parts: parts,
307 },
308 },
309 nil,
310 )
311
312 var finalResponse *provider.ProviderResponse
313 for r := range response {
314 if r.Error != nil {
315 return r.Error
316 }
317 finalResponse = r.Response
318 }
319
320 if finalResponse == nil {
321 return fmt.Errorf("no response received from title provider")
322 }
323
324 title := strings.ReplaceAll(finalResponse.Content, "\n", " ")
325
326 if idx := strings.Index(title, "</think>"); idx > 0 {
327 title = title[idx+len("</think>"):]
328 }
329
330 title = strings.TrimSpace(title)
331 if title == "" {
332 return nil
333 }
334
335 session.Title = title
336 _, err = a.sessions.Save(ctx, session)
337 return err
338}
339
340func (a *agent) err(err error) AgentEvent {
341 return AgentEvent{
342 Type: AgentEventTypeError,
343 Error: err,
344 }
345}
346
347func (a *agent) Run(ctx context.Context, sessionID string, content string, attachments ...message.Attachment) (<-chan AgentEvent, error) {
348 if !a.Model().SupportsImages && attachments != nil {
349 attachments = nil
350 }
351 events := make(chan AgentEvent, 1)
352 if a.IsSessionBusy(sessionID) {
353 existing, ok := a.promptQueue.Get(sessionID)
354 if !ok {
355 existing = []string{}
356 }
357 existing = append(existing, content)
358 a.promptQueue.Set(sessionID, existing)
359 return nil, nil
360 }
361
362 genCtx, cancel := context.WithCancel(ctx)
363 a.activeRequests.Set(sessionID, cancel)
364 startTime := time.Now()
365
366 go func() {
367 slog.Debug("Request started", "sessionID", sessionID)
368 defer log.RecoverPanic("agent.Run", func() {
369 events <- a.err(fmt.Errorf("panic while running the agent"))
370 })
371 var attachmentParts []message.ContentPart
372 for _, attachment := range attachments {
373 attachmentParts = append(attachmentParts, message.BinaryContent{Path: attachment.FilePath, MIMEType: attachment.MimeType, Data: attachment.Content})
374 }
375 result := a.processGeneration(genCtx, sessionID, content, attachmentParts)
376 if result.Error != nil {
377 if isCancelledErr(result.Error) {
378 slog.Error("Request canceled", "sessionID", sessionID)
379 } else {
380 slog.Error("Request errored", "sessionID", sessionID, "error", result.Error.Error())
381 event.Error(result.Error)
382 }
383 } else {
384 slog.Debug("Request completed", "sessionID", sessionID)
385 }
386 a.eventPromptResponded(sessionID, time.Since(startTime).Truncate(time.Second))
387 a.activeRequests.Del(sessionID)
388 cancel()
389 a.Publish(pubsub.CreatedEvent, result)
390 events <- result
391 close(events)
392 }()
393 a.eventPromptSent(sessionID)
394 return events, nil
395}
396
397func (a *agent) processGeneration(ctx context.Context, sessionID, content string, attachmentParts []message.ContentPart) AgentEvent {
398 cfg := config.Get()
399 // List existing messages; if none, start title generation asynchronously.
400 msgs, err := a.messages.List(ctx, sessionID)
401 if err != nil {
402 return a.err(fmt.Errorf("failed to list messages: %w", err))
403 }
404 if len(msgs) == 0 {
405 go func() {
406 defer log.RecoverPanic("agent.Run", func() {
407 slog.Error("panic while generating title")
408 })
409 titleErr := a.generateTitle(ctx, sessionID, content)
410 if titleErr != nil && !errors.Is(titleErr, context.Canceled) && !errors.Is(titleErr, context.DeadlineExceeded) {
411 slog.Error("failed to generate title", "error", titleErr)
412 }
413 }()
414 }
415 session, err := a.sessions.Get(ctx, sessionID)
416 if err != nil {
417 return a.err(fmt.Errorf("failed to get session: %w", err))
418 }
419 if session.SummaryMessageID != "" {
420 summaryMsgInex := -1
421 for i, msg := range msgs {
422 if msg.ID == session.SummaryMessageID {
423 summaryMsgInex = i
424 break
425 }
426 }
427 if summaryMsgInex != -1 {
428 msgs = msgs[summaryMsgInex:]
429 msgs[0].Role = message.User
430 }
431 }
432
433 userMsg, err := a.createUserMessage(ctx, sessionID, content, attachmentParts)
434 if err != nil {
435 return a.err(fmt.Errorf("failed to create user message: %w", err))
436 }
437 // Append the new user message to the conversation history.
438 msgHistory := append(msgs, userMsg)
439
440 for {
441 // Check for cancellation before each iteration
442 select {
443 case <-ctx.Done():
444 return a.err(ctx.Err())
445 default:
446 // Continue processing
447 }
448 agentMessage, toolResults, err := a.streamAndHandleEvents(ctx, sessionID, msgHistory)
449 if err != nil {
450 if errors.Is(err, context.Canceled) {
451 agentMessage.AddFinish(message.FinishReasonCanceled, "Request cancelled", "")
452 a.messages.Update(context.Background(), agentMessage)
453 return a.err(ErrRequestCancelled)
454 }
455 return a.err(fmt.Errorf("failed to process events: %w", err))
456 }
457 if cfg.Options.Debug {
458 slog.Info("Result", "message", agentMessage.FinishReason(), "toolResults", toolResults)
459 }
460 if (agentMessage.FinishReason() == message.FinishReasonToolUse) && toolResults != nil {
461 // We are not done, we need to respond with the tool response
462 msgHistory = append(msgHistory, agentMessage, *toolResults)
463 // If there are queued prompts, process the next one
464 nextPrompt, ok := a.promptQueue.Take(sessionID)
465 if ok {
466 for _, prompt := range nextPrompt {
467 // Create a new user message for the queued prompt
468 userMsg, err := a.createUserMessage(ctx, sessionID, prompt, nil)
469 if err != nil {
470 return a.err(fmt.Errorf("failed to create user message for queued prompt: %w", err))
471 }
472 // Append the new user message to the conversation history
473 msgHistory = append(msgHistory, userMsg)
474 }
475 }
476
477 continue
478 } else if agentMessage.FinishReason() == message.FinishReasonEndTurn {
479 queuePrompts, ok := a.promptQueue.Take(sessionID)
480 if ok {
481 for _, prompt := range queuePrompts {
482 if prompt == "" {
483 continue
484 }
485 userMsg, err := a.createUserMessage(ctx, sessionID, prompt, nil)
486 if err != nil {
487 return a.err(fmt.Errorf("failed to create user message for queued prompt: %w", err))
488 }
489 msgHistory = append(msgHistory, userMsg)
490 }
491 continue
492 }
493 }
494 if agentMessage.FinishReason() == "" {
495 // Kujtim: could not track down where this is happening but this means its cancelled
496 agentMessage.AddFinish(message.FinishReasonCanceled, "Request cancelled", "")
497 _ = a.messages.Update(context.Background(), agentMessage)
498 return a.err(ErrRequestCancelled)
499 }
500 return AgentEvent{
501 Type: AgentEventTypeResponse,
502 Message: agentMessage,
503 Done: true,
504 }
505 }
506}
507
508func (a *agent) createUserMessage(ctx context.Context, sessionID, content string, attachmentParts []message.ContentPart) (message.Message, error) {
509 parts := []message.ContentPart{message.TextContent{Text: content}}
510 parts = append(parts, attachmentParts...)
511 return a.messages.Create(ctx, sessionID, message.CreateMessageParams{
512 Role: message.User,
513 Parts: parts,
514 })
515}
516
517func (a *agent) getAllTools() ([]tools.BaseTool, error) {
518 var allTools []tools.BaseTool
519 for tool := range a.baseTools.Seq() {
520 if a.agentCfg.AllowedTools == nil || slices.Contains(a.agentCfg.AllowedTools, tool.Name()) {
521 allTools = append(allTools, tool)
522 }
523 }
524 if a.agentCfg.ID == "coder" {
525 allTools = slices.AppendSeq(allTools, a.mcpTools.Seq())
526 if a.lspClients.Len() > 0 {
527 allTools = append(allTools, tools.NewDiagnosticsTool(a.lspClients))
528 }
529 }
530 if a.agentToolFn != nil {
531 agentTool, agentToolErr := a.agentToolFn()
532 if agentToolErr != nil {
533 return nil, agentToolErr
534 }
535 allTools = append(allTools, agentTool)
536 }
537 return allTools, nil
538}
539
540func (a *agent) streamAndHandleEvents(ctx context.Context, sessionID string, msgHistory []message.Message) (message.Message, *message.Message, error) {
541 ctx = context.WithValue(ctx, tools.SessionIDContextKey, sessionID)
542
543 // Create the assistant message first so the spinner shows immediately
544 assistantMsg, err := a.messages.Create(ctx, sessionID, message.CreateMessageParams{
545 Role: message.Assistant,
546 Parts: []message.ContentPart{},
547 Model: a.Model().ID,
548 Provider: a.providerID,
549 })
550 if err != nil {
551 return assistantMsg, nil, fmt.Errorf("failed to create assistant message: %w", err)
552 }
553
554 allTools, toolsErr := a.getAllTools()
555 if toolsErr != nil {
556 return assistantMsg, nil, toolsErr
557 }
558 // Now collect tools (which may block on MCP initialization)
559 eventChan := a.provider.StreamResponse(ctx, msgHistory, allTools)
560
561 // Add the session and message ID into the context if needed by tools.
562 ctx = context.WithValue(ctx, tools.MessageIDContextKey, assistantMsg.ID)
563
564loop:
565 for {
566 select {
567 case event, ok := <-eventChan:
568 if !ok {
569 break loop
570 }
571 if processErr := a.processEvent(ctx, sessionID, &assistantMsg, event); processErr != nil {
572 if errors.Is(processErr, context.Canceled) {
573 a.finishMessage(context.Background(), &assistantMsg, message.FinishReasonCanceled, "Request cancelled", "")
574 } else {
575 a.finishMessage(ctx, &assistantMsg, message.FinishReasonError, "API Error", processErr.Error())
576 }
577 return assistantMsg, nil, processErr
578 }
579 case <-ctx.Done():
580 a.finishMessage(context.Background(), &assistantMsg, message.FinishReasonCanceled, "Request cancelled", "")
581 return assistantMsg, nil, ctx.Err()
582 }
583 }
584
585 toolResults := make([]message.ToolResult, len(assistantMsg.ToolCalls()))
586 toolCalls := assistantMsg.ToolCalls()
587 for i, toolCall := range toolCalls {
588 select {
589 case <-ctx.Done():
590 a.finishMessage(context.Background(), &assistantMsg, message.FinishReasonCanceled, "Request cancelled", "")
591 // Make all future tool calls cancelled
592 for j := i; j < len(toolCalls); j++ {
593 toolResults[j] = message.ToolResult{
594 ToolCallID: toolCalls[j].ID,
595 Content: "Tool execution canceled by user",
596 IsError: true,
597 }
598 }
599 goto out
600 default:
601 // Continue processing
602 var tool tools.BaseTool
603 allTools, _ = a.getAllTools()
604 for _, availableTool := range allTools {
605 if availableTool.Info().Name == toolCall.Name {
606 tool = availableTool
607 break
608 }
609 }
610
611 // Tool not found
612 if tool == nil {
613 toolResults[i] = message.ToolResult{
614 ToolCallID: toolCall.ID,
615 Content: fmt.Sprintf("Tool not found: %s", toolCall.Name),
616 IsError: true,
617 }
618 continue
619 }
620
621 // Run tool in goroutine to allow cancellation
622 type toolExecResult struct {
623 response tools.ToolResponse
624 err error
625 }
626 resultChan := make(chan toolExecResult, 1)
627
628 go func() {
629 response, err := tool.Run(ctx, tools.ToolCall{
630 ID: toolCall.ID,
631 Name: toolCall.Name,
632 Input: toolCall.Input,
633 })
634 resultChan <- toolExecResult{response: response, err: err}
635 }()
636
637 var toolResponse tools.ToolResponse
638 var toolErr error
639
640 select {
641 case <-ctx.Done():
642 a.finishMessage(context.Background(), &assistantMsg, message.FinishReasonCanceled, "Request cancelled", "")
643 // Mark remaining tool calls as cancelled
644 for j := i; j < len(toolCalls); j++ {
645 toolResults[j] = message.ToolResult{
646 ToolCallID: toolCalls[j].ID,
647 Content: "Tool execution canceled by user",
648 IsError: true,
649 }
650 }
651 goto out
652 case result := <-resultChan:
653 toolResponse = result.response
654 toolErr = result.err
655 }
656
657 if toolErr != nil {
658 slog.Error("Tool execution error", "toolCall", toolCall.ID, "error", toolErr)
659 if errors.Is(toolErr, permission.ErrorPermissionDenied) {
660 toolResults[i] = message.ToolResult{
661 ToolCallID: toolCall.ID,
662 Content: "Permission denied",
663 IsError: true,
664 }
665 for j := i + 1; j < len(toolCalls); j++ {
666 toolResults[j] = message.ToolResult{
667 ToolCallID: toolCalls[j].ID,
668 Content: "Tool execution canceled by user",
669 IsError: true,
670 }
671 }
672 a.finishMessage(ctx, &assistantMsg, message.FinishReasonPermissionDenied, "Permission denied", "")
673 break
674 }
675 }
676 toolResults[i] = message.ToolResult{
677 ToolCallID: toolCall.ID,
678 Content: toolResponse.Content,
679 Metadata: toolResponse.Metadata,
680 IsError: toolResponse.IsError,
681 }
682 }
683 }
684out:
685 if len(toolResults) == 0 {
686 return assistantMsg, nil, nil
687 }
688 parts := make([]message.ContentPart, 0)
689 for _, tr := range toolResults {
690 parts = append(parts, tr)
691 }
692 msg, err := a.messages.Create(context.Background(), assistantMsg.SessionID, message.CreateMessageParams{
693 Role: message.Tool,
694 Parts: parts,
695 Provider: a.providerID,
696 })
697 if err != nil {
698 return assistantMsg, nil, fmt.Errorf("failed to create cancelled tool message: %w", err)
699 }
700
701 return assistantMsg, &msg, err
702}
703
704func (a *agent) finishMessage(ctx context.Context, msg *message.Message, finishReason message.FinishReason, message, details string) {
705 msg.AddFinish(finishReason, message, details)
706 _ = a.messages.Update(ctx, *msg)
707}
708
709func (a *agent) processEvent(ctx context.Context, sessionID string, assistantMsg *message.Message, event provider.ProviderEvent) error {
710 select {
711 case <-ctx.Done():
712 return ctx.Err()
713 default:
714 // Continue processing.
715 }
716
717 switch event.Type {
718 case provider.EventThinkingDelta:
719 assistantMsg.AppendReasoningContent(event.Thinking)
720 return a.messages.Update(ctx, *assistantMsg)
721 case provider.EventSignatureDelta:
722 assistantMsg.AppendReasoningSignature(event.Signature)
723 return a.messages.Update(ctx, *assistantMsg)
724 case provider.EventContentDelta:
725 assistantMsg.FinishThinking()
726 assistantMsg.AppendContent(event.Content)
727 return a.messages.Update(ctx, *assistantMsg)
728 case provider.EventToolUseStart:
729 assistantMsg.FinishThinking()
730 slog.Info("Tool call started", "toolCall", event.ToolCall)
731 assistantMsg.AddToolCall(*event.ToolCall)
732 return a.messages.Update(ctx, *assistantMsg)
733 case provider.EventToolUseDelta:
734 assistantMsg.AppendToolCallInput(event.ToolCall.ID, event.ToolCall.Input)
735 return a.messages.Update(ctx, *assistantMsg)
736 case provider.EventToolUseStop:
737 slog.Info("Finished tool call", "toolCall", event.ToolCall)
738 assistantMsg.FinishToolCall(event.ToolCall.ID)
739 return a.messages.Update(ctx, *assistantMsg)
740 case provider.EventError:
741 return event.Error
742 case provider.EventComplete:
743 assistantMsg.FinishThinking()
744 assistantMsg.SetToolCalls(event.Response.ToolCalls)
745 assistantMsg.AddFinish(event.Response.FinishReason, "", "")
746 if err := a.messages.Update(ctx, *assistantMsg); err != nil {
747 return fmt.Errorf("failed to update message: %w", err)
748 }
749 return a.trackUsage(ctx, sessionID, a.Model(), event.Response.Usage)
750 }
751
752 return nil
753}
754
755func (a *agent) trackUsage(ctx context.Context, sessionID string, model catwalk.Model, usage provider.TokenUsage) error {
756 sess, err := a.sessions.Get(ctx, sessionID)
757 if err != nil {
758 return fmt.Errorf("failed to get session: %w", err)
759 }
760
761 cost := model.CostPer1MInCached/1e6*float64(usage.CacheCreationTokens) +
762 model.CostPer1MOutCached/1e6*float64(usage.CacheReadTokens) +
763 model.CostPer1MIn/1e6*float64(usage.InputTokens) +
764 model.CostPer1MOut/1e6*float64(usage.OutputTokens)
765
766 a.eventTokensUsed(sessionID, usage, cost)
767
768 sess.Cost += cost
769 sess.CompletionTokens = usage.OutputTokens + usage.CacheReadTokens
770 sess.PromptTokens = usage.InputTokens + usage.CacheCreationTokens
771
772 _, err = a.sessions.Save(ctx, sess)
773 if err != nil {
774 return fmt.Errorf("failed to save session: %w", err)
775 }
776 return nil
777}
778
779func (a *agent) Summarize(ctx context.Context, sessionID string) error {
780 if a.summarizeProvider == nil {
781 return fmt.Errorf("summarize provider not available")
782 }
783
784 // Check if session is busy
785 if a.IsSessionBusy(sessionID) {
786 return ErrSessionBusy
787 }
788
789 // Create a new context with cancellation
790 summarizeCtx, cancel := context.WithCancel(ctx)
791
792 // Store the cancel function in activeRequests to allow cancellation
793 a.activeRequests.Set(sessionID+"-summarize", cancel)
794
795 go func() {
796 defer a.activeRequests.Del(sessionID + "-summarize")
797 defer cancel()
798 event := AgentEvent{
799 Type: AgentEventTypeSummarize,
800 Progress: "Starting summarization...",
801 }
802
803 a.Publish(pubsub.CreatedEvent, event)
804 // Get all messages from the session
805 msgs, err := a.messages.List(summarizeCtx, sessionID)
806 if err != nil {
807 event = AgentEvent{
808 Type: AgentEventTypeError,
809 Error: fmt.Errorf("failed to list messages: %w", err),
810 Done: true,
811 }
812 a.Publish(pubsub.CreatedEvent, event)
813 return
814 }
815 summarizeCtx = context.WithValue(summarizeCtx, tools.SessionIDContextKey, sessionID)
816
817 if len(msgs) == 0 {
818 event = AgentEvent{
819 Type: AgentEventTypeError,
820 Error: fmt.Errorf("no messages to summarize"),
821 Done: true,
822 }
823 a.Publish(pubsub.CreatedEvent, event)
824 return
825 }
826
827 event = AgentEvent{
828 Type: AgentEventTypeSummarize,
829 Progress: "Analyzing conversation...",
830 }
831 a.Publish(pubsub.CreatedEvent, event)
832
833 // Add a system message to guide the summarization
834 summarizePrompt := "Provide a detailed but concise summary of our conversation above. Focus on information that would be helpful for continuing the conversation, including what we did, what we're doing, which files we're working on, and what we're going to do next."
835
836 // Create a new message with the summarize prompt
837 promptMsg := message.Message{
838 Role: message.User,
839 Parts: []message.ContentPart{message.TextContent{Text: summarizePrompt}},
840 }
841
842 // Append the prompt to the messages
843 msgsWithPrompt := append(msgs, promptMsg)
844
845 event = AgentEvent{
846 Type: AgentEventTypeSummarize,
847 Progress: "Generating summary...",
848 }
849
850 a.Publish(pubsub.CreatedEvent, event)
851
852 // Send the messages to the summarize provider
853 response := a.summarizeProvider.StreamResponse(
854 summarizeCtx,
855 msgsWithPrompt,
856 nil,
857 )
858 var finalResponse *provider.ProviderResponse
859 for r := range response {
860 if r.Error != nil {
861 event = AgentEvent{
862 Type: AgentEventTypeError,
863 Error: fmt.Errorf("failed to summarize: %w", r.Error),
864 Done: true,
865 }
866 a.Publish(pubsub.CreatedEvent, event)
867 return
868 }
869 finalResponse = r.Response
870 }
871
872 summary := strings.TrimSpace(finalResponse.Content)
873 if summary == "" {
874 event = AgentEvent{
875 Type: AgentEventTypeError,
876 Error: fmt.Errorf("empty summary returned"),
877 Done: true,
878 }
879 a.Publish(pubsub.CreatedEvent, event)
880 return
881 }
882 shell := shell.GetPersistentShell(config.Get().WorkingDir())
883 summary += "\n\n**Current working directory of the persistent shell**\n\n" + shell.GetWorkingDir()
884 event = AgentEvent{
885 Type: AgentEventTypeSummarize,
886 Progress: "Creating new session...",
887 }
888
889 a.Publish(pubsub.CreatedEvent, event)
890 oldSession, err := a.sessions.Get(summarizeCtx, sessionID)
891 if err != nil {
892 event = AgentEvent{
893 Type: AgentEventTypeError,
894 Error: fmt.Errorf("failed to get session: %w", err),
895 Done: true,
896 }
897
898 a.Publish(pubsub.CreatedEvent, event)
899 return
900 }
901 // Create a message in the new session with the summary
902 msg, err := a.messages.Create(summarizeCtx, oldSession.ID, message.CreateMessageParams{
903 Role: message.Assistant,
904 Parts: []message.ContentPart{
905 message.TextContent{Text: summary},
906 message.Finish{
907 Reason: message.FinishReasonEndTurn,
908 Time: time.Now().Unix(),
909 },
910 },
911 Model: a.summarizeProvider.Model().ID,
912 Provider: a.summarizeProviderID,
913 })
914 if err != nil {
915 event = AgentEvent{
916 Type: AgentEventTypeError,
917 Error: fmt.Errorf("failed to create summary message: %w", err),
918 Done: true,
919 }
920
921 a.Publish(pubsub.CreatedEvent, event)
922 return
923 }
924 oldSession.SummaryMessageID = msg.ID
925 oldSession.CompletionTokens = finalResponse.Usage.OutputTokens
926 oldSession.PromptTokens = 0
927 model := a.summarizeProvider.Model()
928 usage := finalResponse.Usage
929 cost := model.CostPer1MInCached/1e6*float64(usage.CacheCreationTokens) +
930 model.CostPer1MOutCached/1e6*float64(usage.CacheReadTokens) +
931 model.CostPer1MIn/1e6*float64(usage.InputTokens) +
932 model.CostPer1MOut/1e6*float64(usage.OutputTokens)
933 oldSession.Cost += cost
934 _, err = a.sessions.Save(summarizeCtx, oldSession)
935 if err != nil {
936 event = AgentEvent{
937 Type: AgentEventTypeError,
938 Error: fmt.Errorf("failed to save session: %w", err),
939 Done: true,
940 }
941 a.Publish(pubsub.CreatedEvent, event)
942 }
943
944 event = AgentEvent{
945 Type: AgentEventTypeSummarize,
946 SessionID: oldSession.ID,
947 Progress: "Summary complete",
948 Done: true,
949 }
950 a.Publish(pubsub.CreatedEvent, event)
951 // Send final success event with the new session ID
952 }()
953
954 return nil
955}
956
957func (a *agent) ClearQueue(sessionID string) {
958 if a.QueuedPrompts(sessionID) > 0 {
959 slog.Info("Clearing queued prompts", "session_id", sessionID)
960 a.promptQueue.Del(sessionID)
961 }
962}
963
964func (a *agent) CancelAll() {
965 if !a.IsBusy() {
966 return
967 }
968 for key := range a.activeRequests.Seq2() {
969 a.Cancel(key) // key is sessionID
970 }
971
972 for _, cleanup := range a.cleanupFuncs {
973 if cleanup != nil {
974 cleanup()
975 }
976 }
977
978 timeout := time.After(5 * time.Second)
979 for a.IsBusy() {
980 select {
981 case <-timeout:
982 return
983 default:
984 time.Sleep(200 * time.Millisecond)
985 }
986 }
987}
988
989func (a *agent) UpdateModel() error {
990 cfg := config.Get()
991
992 // Get current provider configuration
993 currentProviderCfg := cfg.GetProviderForModel(a.agentCfg.Model)
994 if currentProviderCfg == nil || currentProviderCfg.ID == "" {
995 return fmt.Errorf("provider for agent %s not found in config", a.agentCfg.Name)
996 }
997
998 // Check if provider has changed
999 if string(currentProviderCfg.ID) != a.providerID {
1000 // Provider changed, need to recreate the main provider
1001 model := cfg.GetModelByType(a.agentCfg.Model)
1002 if model.ID == "" {
1003 return fmt.Errorf("model not found for agent %s", a.agentCfg.Name)
1004 }
1005
1006 promptID := agentPromptMap[a.agentCfg.ID]
1007 if promptID == "" {
1008 promptID = prompt.PromptDefault
1009 }
1010
1011 opts := []provider.ProviderClientOption{
1012 provider.WithModel(a.agentCfg.Model),
1013 provider.WithSystemMessage(prompt.GetPrompt(promptID, currentProviderCfg.ID, cfg.Options.ContextPaths...)),
1014 }
1015
1016 newProvider, err := provider.NewProvider(*currentProviderCfg, opts...)
1017 if err != nil {
1018 return fmt.Errorf("failed to create new provider: %w", err)
1019 }
1020
1021 // Update the provider and provider ID
1022 a.provider = newProvider
1023 a.providerID = string(currentProviderCfg.ID)
1024 }
1025
1026 // Check if providers have changed for title (small) and summarize (large)
1027 smallModelCfg := cfg.Models[config.SelectedModelTypeSmall]
1028 var smallModelProviderCfg config.ProviderConfig
1029 for p := range cfg.Providers.Seq() {
1030 if p.ID == smallModelCfg.Provider {
1031 smallModelProviderCfg = p
1032 break
1033 }
1034 }
1035 if smallModelProviderCfg.ID == "" {
1036 return fmt.Errorf("provider %s not found in config", smallModelCfg.Provider)
1037 }
1038
1039 largeModelCfg := cfg.Models[config.SelectedModelTypeLarge]
1040 var largeModelProviderCfg config.ProviderConfig
1041 for p := range cfg.Providers.Seq() {
1042 if p.ID == largeModelCfg.Provider {
1043 largeModelProviderCfg = p
1044 break
1045 }
1046 }
1047 if largeModelProviderCfg.ID == "" {
1048 return fmt.Errorf("provider %s not found in config", largeModelCfg.Provider)
1049 }
1050
1051 var maxTitleTokens int64 = 40
1052
1053 // if the max output is too low for the gemini provider it won't return anything
1054 if smallModelCfg.Provider == "gemini" {
1055 maxTitleTokens = 1000
1056 }
1057 // Recreate title provider
1058 titleOpts := []provider.ProviderClientOption{
1059 provider.WithModel(config.SelectedModelTypeSmall),
1060 provider.WithSystemMessage(prompt.GetPrompt(prompt.PromptTitle, smallModelProviderCfg.ID)),
1061 provider.WithMaxTokens(maxTitleTokens),
1062 }
1063 newTitleProvider, err := provider.NewProvider(smallModelProviderCfg, titleOpts...)
1064 if err != nil {
1065 return fmt.Errorf("failed to create new title provider: %w", err)
1066 }
1067 a.titleProvider = newTitleProvider
1068
1069 // Recreate summarize provider if provider changed (now large model)
1070 if string(largeModelProviderCfg.ID) != a.summarizeProviderID {
1071 largeModel := cfg.GetModelByType(config.SelectedModelTypeLarge)
1072 if largeModel == nil {
1073 return fmt.Errorf("model %s not found in provider %s", largeModelCfg.Model, largeModelProviderCfg.ID)
1074 }
1075 summarizeOpts := []provider.ProviderClientOption{
1076 provider.WithModel(config.SelectedModelTypeLarge),
1077 provider.WithSystemMessage(prompt.GetPrompt(prompt.PromptSummarizer, largeModelProviderCfg.ID)),
1078 }
1079 newSummarizeProvider, err := provider.NewProvider(largeModelProviderCfg, summarizeOpts...)
1080 if err != nil {
1081 return fmt.Errorf("failed to create new summarize provider: %w", err)
1082 }
1083 a.summarizeProvider = newSummarizeProvider
1084 a.summarizeProviderID = string(largeModelProviderCfg.ID)
1085 }
1086
1087 return nil
1088}
1089
1090func (a *agent) setupEvents(ctx context.Context) {
1091 ctx, cancel := context.WithCancel(ctx)
1092
1093 go func() {
1094 subCh := SubscribeMCPEvents(ctx)
1095
1096 for {
1097 select {
1098 case event, ok := <-subCh:
1099 if !ok {
1100 slog.Debug("MCPEvents subscription channel closed")
1101 return
1102 }
1103 switch event.Payload.Type {
1104 case MCPEventToolsListChanged:
1105 name := event.Payload.Name
1106 c, ok := mcpClients.Get(name)
1107 if !ok {
1108 slog.Warn("MCP client not found for tools update", "name", name)
1109 continue
1110 }
1111 cfg := config.Get()
1112 tools, err := getTools(ctx, name, a.permissions, c, cfg.WorkingDir())
1113 if err != nil {
1114 slog.Error("error listing tools", "error", err)
1115 updateMCPState(name, MCPStateError, err, nil, 0)
1116 _ = c.Close()
1117 continue
1118 }
1119 updateMcpTools(name, tools)
1120 a.mcpTools.Reset(maps.Collect(mcpTools.Seq2()))
1121 updateMCPState(name, MCPStateConnected, nil, c, a.mcpTools.Len())
1122 default:
1123 continue
1124 }
1125 case <-ctx.Done():
1126 slog.Debug("MCPEvents subscription cancelled")
1127 return
1128 }
1129 }
1130 }()
1131
1132 a.cleanupFuncs = append(a.cleanupFuncs, cancel)
1133}