1// Package agent contains the implementation of the AI agent service.
2package agent
3
4import (
5 "context"
6 "errors"
7 "fmt"
8 "log/slog"
9 "maps"
10 "slices"
11 "strings"
12 "time"
13
14 "github.com/charmbracelet/catwalk/pkg/catwalk"
15 "github.com/charmbracelet/crush/internal/config"
16 "github.com/charmbracelet/crush/internal/csync"
17 "github.com/charmbracelet/crush/internal/event"
18 "github.com/charmbracelet/crush/internal/history"
19 "github.com/charmbracelet/crush/internal/llm/prompt"
20 "github.com/charmbracelet/crush/internal/llm/provider"
21 "github.com/charmbracelet/crush/internal/llm/tools"
22 "github.com/charmbracelet/crush/internal/log"
23 "github.com/charmbracelet/crush/internal/lsp"
24 "github.com/charmbracelet/crush/internal/message"
25 "github.com/charmbracelet/crush/internal/permission"
26 "github.com/charmbracelet/crush/internal/pubsub"
27 "github.com/charmbracelet/crush/internal/session"
28 "github.com/charmbracelet/crush/internal/shell"
29)
30
// AgentEventType identifies the kind of an AgentEvent.
type AgentEventType string

const (
	// AgentEventTypeError marks an event that reports a failure in Error.
	AgentEventTypeError AgentEventType = "error"
	// AgentEventTypeResponse marks an event carrying the assistant's final Message.
	AgentEventTypeResponse AgentEventType = "response"
	// AgentEventTypeSummarize marks a progress or completion event emitted by Summarize.
	AgentEventTypeSummarize AgentEventType = "summarize"
)

// AgentEvent is sent on the channel returned by Run and published to the
// agent's subscribers.
type AgentEvent struct {
	Type    AgentEventType
	Message message.Message
	Error   error

	// SessionID and Progress are filled in by summarization events. Done is
	// set on a final response, on errors reported by Summarize, and on the
	// completed-summary event.
	SessionID string
	Progress  string
	Done      bool
}
49
// Service is the interface of an AI agent. It runs prompts against sessions,
// manages cancellation and prompts queued while a session is busy, and
// publishes AgentEvents to subscribers.
type Service interface {
	pubsub.Suscriber[AgentEvent]
	// Model returns the model currently configured for the agent.
	Model() catwalk.Model
	// Run processes content (plus optional attachments) for sessionID and
	// returns a channel that receives the resulting event. If the session is
	// already busy, the prompt text is queued instead and Run returns a nil
	// channel and a nil error.
	Run(ctx context.Context, sessionID string, content string, attachments ...message.Attachment) (<-chan AgentEvent, error)
	// Cancel stops the running request and any summarization for sessionID
	// and discards its queued prompts.
	Cancel(sessionID string)
	// CancelAll cancels every active request and waits briefly for them to stop.
	CancelAll()
	// IsSessionBusy reports whether a regular request is running for sessionID.
	IsSessionBusy(sessionID string) bool
	// IsBusy reports whether any request is running.
	IsBusy() bool
	// Summarize starts summarizing the session in the background; progress
	// and the outcome are reported as published events.
	Summarize(ctx context.Context, sessionID string) error
	// UpdateModel refreshes the agent's providers from the current configuration.
	UpdateModel() error
	// QueuedPrompts returns how many prompts are waiting for sessionID.
	QueuedPrompts(sessionID string) int
	// ClearQueue discards the prompts queued for sessionID.
	ClearQueue(sessionID string)
}
63
// agent implements Service. It embeds a pubsub broker for AgentEvents and
// holds the services, tools and LLM providers used to serve requests.
type agent struct {
	*pubsub.Broker[AgentEvent]
	agentCfg    config.Agent
	sessions    session.Service
	messages    message.Service
	permissions permission.Service
	// baseTools and mcpTools are created with csync.NewLazyMap in NewAgent.
	baseTools  *csync.Map[string, tools.BaseTool]
	mcpTools   *csync.Map[string, tools.BaseTool]
	lspClients *csync.Map[string, *lsp.Client]

	// agentToolFn builds the task sub-agent tool. It is invoked on every
	// getAllTools call rather than cached, so the tool can follow model changes.
	agentToolFn func() (tools.BaseTool, error)
	// cleanupFuncs are run by CancelAll.
	cleanupFuncs []func()

	// provider serves the agent's main model; providerID is its config ID.
	provider   provider.Provider
	providerID string

	// titleProvider generates session titles; summarizeProvider writes
	// conversation summaries and summarizeProviderID is its config ID.
	titleProvider       provider.Provider
	summarizeProvider   provider.Provider
	summarizeProviderID string

	// activeRequests maps a session ID (or "<sessionID>-summarize" for a
	// summarization) to the cancel func of its in-flight request.
	activeRequests *csync.Map[string, context.CancelFunc]
	// promptQueue holds prompt texts submitted while a session was busy.
	promptQueue *csync.Map[string, []string]
}

// agentPromptMap selects the system prompt for an agent ID; IDs not listed
// here fall back to prompt.PromptDefault.
var agentPromptMap = map[string]prompt.PromptID{
	"coder": prompt.PromptCoder,
	"task":  prompt.PromptTask,
}
93
94func NewAgent(
95 ctx context.Context,
96 agentCfg config.Agent,
97 // These services are needed in the tools
98 permissions permission.Service,
99 sessions session.Service,
100 messages message.Service,
101 history history.Service,
102 lspClients *csync.Map[string, *lsp.Client],
103) (Service, error) {
104 cfg := config.Get()
105
106 var agentToolFn func() (tools.BaseTool, error)
107 if agentCfg.ID == "coder" && slices.Contains(agentCfg.AllowedTools, AgentToolName) {
108 agentToolFn = func() (tools.BaseTool, error) {
109 taskAgentCfg := config.Get().Agents["task"]
110 if taskAgentCfg.ID == "" {
111 return nil, fmt.Errorf("task agent not found in config")
112 }
113 taskAgent, err := NewAgent(ctx, taskAgentCfg, permissions, sessions, messages, history, lspClients)
114 if err != nil {
115 return nil, fmt.Errorf("failed to create task agent: %w", err)
116 }
117 return NewAgentTool(taskAgent, sessions, messages), nil
118 }
119 }
120
121 providerCfg := config.Get().GetProviderForModel(agentCfg.Model)
122 if providerCfg == nil {
123 return nil, fmt.Errorf("provider for agent %s not found in config", agentCfg.Name)
124 }
125 model := config.Get().GetModelByType(agentCfg.Model)
126
127 if model == nil {
128 return nil, fmt.Errorf("model not found for agent %s", agentCfg.Name)
129 }
130
131 promptID := agentPromptMap[agentCfg.ID]
132 if promptID == "" {
133 promptID = prompt.PromptDefault
134 }
135 opts := []provider.ProviderClientOption{
136 provider.WithModel(agentCfg.Model),
137 provider.WithSystemMessage(prompt.GetPrompt(promptID, providerCfg.ID, config.Get().Options.ContextPaths...)),
138 }
139 agentProvider, err := provider.NewProvider(*providerCfg, opts...)
140 if err != nil {
141 return nil, err
142 }
143
144 smallModelCfg := cfg.Models[config.SelectedModelTypeSmall]
145 var smallModelProviderCfg *config.ProviderConfig
146 if smallModelCfg.Provider == providerCfg.ID {
147 smallModelProviderCfg = providerCfg
148 } else {
149 smallModelProviderCfg = cfg.GetProviderForModel(config.SelectedModelTypeSmall)
150
151 if smallModelProviderCfg.ID == "" {
152 return nil, fmt.Errorf("provider %s not found in config", smallModelCfg.Provider)
153 }
154 }
155 smallModel := cfg.GetModelByType(config.SelectedModelTypeSmall)
156 if smallModel.ID == "" {
157 return nil, fmt.Errorf("model %s not found in provider %s", smallModelCfg.Model, smallModelProviderCfg.ID)
158 }
159
160 titleOpts := []provider.ProviderClientOption{
161 provider.WithModel(config.SelectedModelTypeSmall),
162 provider.WithSystemMessage(prompt.GetPrompt(prompt.PromptTitle, smallModelProviderCfg.ID)),
163 }
164 titleProvider, err := provider.NewProvider(*smallModelProviderCfg, titleOpts...)
165 if err != nil {
166 return nil, err
167 }
168
169 summarizeOpts := []provider.ProviderClientOption{
170 provider.WithModel(config.SelectedModelTypeLarge),
171 provider.WithSystemMessage(prompt.GetPrompt(prompt.PromptSummarizer, providerCfg.ID)),
172 }
173 summarizeProvider, err := provider.NewProvider(*providerCfg, summarizeOpts...)
174 if err != nil {
175 return nil, err
176 }
177
178 baseToolsFn := func() map[string]tools.BaseTool {
179 slog.Info("Initializing agent base tools", "agent", agentCfg.ID)
180 defer func() {
181 slog.Info("Initialized agent base tools", "agent", agentCfg.ID)
182 }()
183
184 // Base tools available to all agents
185 cwd := cfg.WorkingDir()
186 result := make(map[string]tools.BaseTool)
187 for _, tool := range []tools.BaseTool{
188 tools.NewBashTool(permissions, cwd, cfg.Options.Attribution),
189 tools.NewDownloadTool(permissions, cwd),
190 tools.NewEditTool(lspClients, permissions, history, cwd),
191 tools.NewMultiEditTool(lspClients, permissions, history, cwd),
192 tools.NewFetchTool(permissions, cwd),
193 tools.NewGlobTool(cwd),
194 tools.NewGrepTool(cwd),
195 tools.NewLsTool(permissions, cwd),
196 tools.NewSourcegraphTool(),
197 tools.NewViewTool(lspClients, permissions, cwd),
198 tools.NewWriteTool(lspClients, permissions, history, cwd),
199 } {
200 result[tool.Name()] = tool
201 }
202 return result
203 }
204 mcpToolsFn := func() map[string]tools.BaseTool {
205 slog.Info("Initializing agent mcp tools", "agent", agentCfg.ID)
206 defer func() {
207 slog.Info("Initialized agent mcp tools", "agent", agentCfg.ID)
208 }()
209
210 mcpToolsOnce.Do(func() {
211 doGetMCPTools(ctx, permissions, cfg)
212 })
213
214 return maps.Collect(mcpTools.Seq2())
215 }
216
217 a := &agent{
218 Broker: pubsub.NewBroker[AgentEvent](),
219 agentCfg: agentCfg,
220 provider: agentProvider,
221 providerID: string(providerCfg.ID),
222 messages: messages,
223 sessions: sessions,
224 titleProvider: titleProvider,
225 summarizeProvider: summarizeProvider,
226 summarizeProviderID: string(providerCfg.ID),
227 agentToolFn: agentToolFn,
228 activeRequests: csync.NewMap[string, context.CancelFunc](),
229 mcpTools: csync.NewLazyMap(mcpToolsFn),
230 baseTools: csync.NewLazyMap(baseToolsFn),
231 promptQueue: csync.NewMap[string, []string](),
232 permissions: permissions,
233 lspClients: lspClients,
234 }
235 a.setupEvents(ctx)
236 return a, nil
237}
238
239func (a *agent) Model() catwalk.Model {
240 return *config.Get().GetModelByType(a.agentCfg.Model)
241}
242
243func (a *agent) Cancel(sessionID string) {
244 // Cancel regular requests
245 if cancel, ok := a.activeRequests.Take(sessionID); ok && cancel != nil {
246 slog.Info("Request cancellation initiated", "session_id", sessionID)
247 cancel()
248 }
249
250 // Also check for summarize requests
251 if cancel, ok := a.activeRequests.Take(sessionID + "-summarize"); ok && cancel != nil {
252 slog.Info("Summarize cancellation initiated", "session_id", sessionID)
253 cancel()
254 }
255
256 if a.QueuedPrompts(sessionID) > 0 {
257 slog.Info("Clearing queued prompts", "session_id", sessionID)
258 a.promptQueue.Del(sessionID)
259 }
260}
261
262func (a *agent) IsBusy() bool {
263 var busy bool
264 for cancelFunc := range a.activeRequests.Seq() {
265 if cancelFunc != nil {
266 busy = true
267 break
268 }
269 }
270 return busy
271}
272
273func (a *agent) IsSessionBusy(sessionID string) bool {
274 _, busy := a.activeRequests.Get(sessionID)
275 return busy
276}
277
278func (a *agent) QueuedPrompts(sessionID string) int {
279 l, ok := a.promptQueue.Get(sessionID)
280 if !ok {
281 return 0
282 }
283 return len(l)
284}
285
286func (a *agent) generateTitle(ctx context.Context, sessionID string, content string) error {
287 if content == "" {
288 return nil
289 }
290 if a.titleProvider == nil {
291 return nil
292 }
293 session, err := a.sessions.Get(ctx, sessionID)
294 if err != nil {
295 return err
296 }
297 parts := []message.ContentPart{message.TextContent{
298 Text: fmt.Sprintf("Generate a concise title for the following content:\n\n%s", content),
299 }}
300
301 // Use streaming approach like summarization
302 response := a.titleProvider.StreamResponse(
303 ctx,
304 []message.Message{
305 {
306 Role: message.User,
307 Parts: parts,
308 },
309 },
310 nil,
311 )
312
313 var finalResponse *provider.ProviderResponse
314 for r := range response {
315 if r.Error != nil {
316 return r.Error
317 }
318 finalResponse = r.Response
319 }
320
321 if finalResponse == nil {
322 return fmt.Errorf("no response received from title provider")
323 }
324
325 title := strings.ReplaceAll(finalResponse.Content, "\n", " ")
326
327 if idx := strings.Index(title, "</think>"); idx > 0 {
328 title = title[idx+len("</think>"):]
329 }
330
331 title = strings.TrimSpace(title)
332 if title == "" {
333 return nil
334 }
335
336 session.Title = title
337 _, err = a.sessions.Save(ctx, session)
338 return err
339}
340
341func (a *agent) err(err error) AgentEvent {
342 return AgentEvent{
343 Type: AgentEventTypeError,
344 Error: err,
345 }
346}
347
348func (a *agent) Run(ctx context.Context, sessionID string, content string, attachments ...message.Attachment) (<-chan AgentEvent, error) {
349 if !a.Model().SupportsImages && attachments != nil {
350 attachments = nil
351 }
352 events := make(chan AgentEvent, 1)
353 if a.IsSessionBusy(sessionID) {
354 existing, ok := a.promptQueue.Get(sessionID)
355 if !ok {
356 existing = []string{}
357 }
358 existing = append(existing, content)
359 a.promptQueue.Set(sessionID, existing)
360 return nil, nil
361 }
362
363 genCtx, cancel := context.WithCancel(ctx)
364 a.activeRequests.Set(sessionID, cancel)
365 startTime := time.Now()
366
367 go func() {
368 slog.Debug("Request started", "sessionID", sessionID)
369 defer log.RecoverPanic("agent.Run", func() {
370 events <- a.err(fmt.Errorf("panic while running the agent"))
371 })
372 var attachmentParts []message.ContentPart
373 for _, attachment := range attachments {
374 attachmentParts = append(attachmentParts, message.BinaryContent{Path: attachment.FilePath, MIMEType: attachment.MimeType, Data: attachment.Content})
375 }
376 result := a.processGeneration(genCtx, sessionID, content, attachmentParts)
377 if result.Error != nil {
378 if isCancelledErr(result.Error) {
379 slog.Error("Request canceled", "sessionID", sessionID)
380 } else {
381 slog.Error("Request errored", "sessionID", sessionID, "error", result.Error.Error())
382 event.Error(result.Error)
383 }
384 } else {
385 slog.Debug("Request completed", "sessionID", sessionID)
386 }
387 a.eventPromptResponded(sessionID, time.Since(startTime).Truncate(time.Second))
388 a.activeRequests.Del(sessionID)
389 cancel()
390 a.Publish(pubsub.CreatedEvent, result)
391 events <- result
392 close(events)
393 }()
394 a.eventPromptSent(sessionID)
395 return events, nil
396}
397
398func (a *agent) processGeneration(ctx context.Context, sessionID, content string, attachmentParts []message.ContentPart) AgentEvent {
399 cfg := config.Get()
400 // List existing messages; if none, start title generation asynchronously.
401 msgs, err := a.messages.List(ctx, sessionID)
402 if err != nil {
403 return a.err(fmt.Errorf("failed to list messages: %w", err))
404 }
405 if len(msgs) == 0 {
406 go func() {
407 defer log.RecoverPanic("agent.Run", func() {
408 slog.Error("panic while generating title")
409 })
410 titleErr := a.generateTitle(ctx, sessionID, content)
411 if titleErr != nil && !errors.Is(titleErr, context.Canceled) && !errors.Is(titleErr, context.DeadlineExceeded) {
412 slog.Error("failed to generate title", "error", titleErr)
413 }
414 }()
415 }
416 session, err := a.sessions.Get(ctx, sessionID)
417 if err != nil {
418 return a.err(fmt.Errorf("failed to get session: %w", err))
419 }
420 if session.SummaryMessageID != "" {
421 summaryMsgInex := -1
422 for i, msg := range msgs {
423 if msg.ID == session.SummaryMessageID {
424 summaryMsgInex = i
425 break
426 }
427 }
428 if summaryMsgInex != -1 {
429 msgs = msgs[summaryMsgInex:]
430 msgs[0].Role = message.User
431 }
432 }
433
434 userMsg, err := a.createUserMessage(ctx, sessionID, content, attachmentParts)
435 if err != nil {
436 return a.err(fmt.Errorf("failed to create user message: %w", err))
437 }
438 // Append the new user message to the conversation history.
439 msgHistory := append(msgs, userMsg)
440
441 for {
442 // Check for cancellation before each iteration
443 select {
444 case <-ctx.Done():
445 return a.err(ctx.Err())
446 default:
447 // Continue processing
448 }
449 agentMessage, toolResults, err := a.streamAndHandleEvents(ctx, sessionID, msgHistory)
450 if err != nil {
451 if errors.Is(err, context.Canceled) {
452 agentMessage.AddFinish(message.FinishReasonCanceled, "Request cancelled", "")
453 a.messages.Update(context.Background(), agentMessage)
454 return a.err(ErrRequestCancelled)
455 }
456 return a.err(fmt.Errorf("failed to process events: %w", err))
457 }
458 if cfg.Options.Debug {
459 slog.Info("Result", "message", agentMessage.FinishReason(), "toolResults", toolResults)
460 }
461 if (agentMessage.FinishReason() == message.FinishReasonToolUse) && toolResults != nil {
462 // We are not done, we need to respond with the tool response
463 msgHistory = append(msgHistory, agentMessage, *toolResults)
464 // If there are queued prompts, process the next one
465 nextPrompt, ok := a.promptQueue.Take(sessionID)
466 if ok {
467 for _, prompt := range nextPrompt {
468 // Create a new user message for the queued prompt
469 userMsg, err := a.createUserMessage(ctx, sessionID, prompt, nil)
470 if err != nil {
471 return a.err(fmt.Errorf("failed to create user message for queued prompt: %w", err))
472 }
473 // Append the new user message to the conversation history
474 msgHistory = append(msgHistory, userMsg)
475 }
476 }
477
478 continue
479 } else if agentMessage.FinishReason() == message.FinishReasonEndTurn {
480 queuePrompts, ok := a.promptQueue.Take(sessionID)
481 if ok {
482 for _, prompt := range queuePrompts {
483 if prompt == "" {
484 continue
485 }
486 userMsg, err := a.createUserMessage(ctx, sessionID, prompt, nil)
487 if err != nil {
488 return a.err(fmt.Errorf("failed to create user message for queued prompt: %w", err))
489 }
490 msgHistory = append(msgHistory, userMsg)
491 }
492 continue
493 }
494 }
495 if agentMessage.FinishReason() == "" {
496 // Kujtim: could not track down where this is happening but this means its cancelled
497 agentMessage.AddFinish(message.FinishReasonCanceled, "Request cancelled", "")
498 _ = a.messages.Update(context.Background(), agentMessage)
499 return a.err(ErrRequestCancelled)
500 }
501 return AgentEvent{
502 Type: AgentEventTypeResponse,
503 Message: agentMessage,
504 Done: true,
505 }
506 }
507}
508
509func (a *agent) createUserMessage(ctx context.Context, sessionID, content string, attachmentParts []message.ContentPart) (message.Message, error) {
510 parts := []message.ContentPart{message.TextContent{Text: content}}
511 parts = append(parts, attachmentParts...)
512 return a.messages.Create(ctx, sessionID, message.CreateMessageParams{
513 Role: message.User,
514 Parts: parts,
515 })
516}
517
518func (a *agent) getAllTools() ([]tools.BaseTool, error) {
519 var allTools []tools.BaseTool
520 for tool := range a.baseTools.Seq() {
521 if a.agentCfg.AllowedTools == nil || slices.Contains(a.agentCfg.AllowedTools, tool.Name()) {
522 allTools = append(allTools, tool)
523 }
524 }
525 if a.agentCfg.ID == "coder" {
526 allTools = slices.AppendSeq(allTools, a.mcpTools.Seq())
527 if a.lspClients.Len() > 0 {
528 allTools = append(allTools, tools.NewDiagnosticsTool(a.lspClients))
529 }
530 }
531 if a.agentToolFn != nil {
532 agentTool, agentToolErr := a.agentToolFn()
533 if agentToolErr != nil {
534 return nil, agentToolErr
535 }
536 allTools = append(allTools, agentTool)
537 }
538 return allTools, nil
539}
540
541func (a *agent) streamAndHandleEvents(ctx context.Context, sessionID string, msgHistory []message.Message) (message.Message, *message.Message, error) {
542 ctx = context.WithValue(ctx, tools.SessionIDContextKey, sessionID)
543
544 // Create the assistant message first so the spinner shows immediately
545 assistantMsg, err := a.messages.Create(ctx, sessionID, message.CreateMessageParams{
546 Role: message.Assistant,
547 Parts: []message.ContentPart{},
548 Model: a.Model().ID,
549 Provider: a.providerID,
550 })
551 if err != nil {
552 return assistantMsg, nil, fmt.Errorf("failed to create assistant message: %w", err)
553 }
554
555 allTools, toolsErr := a.getAllTools()
556 if toolsErr != nil {
557 return assistantMsg, nil, toolsErr
558 }
559 // Now collect tools (which may block on MCP initialization)
560 eventChan := a.provider.StreamResponse(ctx, msgHistory, allTools)
561
562 // Add the session and message ID into the context if needed by tools.
563 ctx = context.WithValue(ctx, tools.MessageIDContextKey, assistantMsg.ID)
564
565loop:
566 for {
567 select {
568 case event, ok := <-eventChan:
569 if !ok {
570 break loop
571 }
572 if processErr := a.processEvent(ctx, sessionID, &assistantMsg, event); processErr != nil {
573 if errors.Is(processErr, context.Canceled) {
574 a.finishMessage(context.Background(), &assistantMsg, message.FinishReasonCanceled, "Request cancelled", "")
575 } else {
576 a.finishMessage(ctx, &assistantMsg, message.FinishReasonError, "API Error", processErr.Error())
577 }
578 return assistantMsg, nil, processErr
579 }
580 case <-ctx.Done():
581 a.finishMessage(context.Background(), &assistantMsg, message.FinishReasonCanceled, "Request cancelled", "")
582 return assistantMsg, nil, ctx.Err()
583 }
584 }
585
586 toolResults := make([]message.ToolResult, len(assistantMsg.ToolCalls()))
587 toolCalls := assistantMsg.ToolCalls()
588 for i, toolCall := range toolCalls {
589 select {
590 case <-ctx.Done():
591 a.finishMessage(context.Background(), &assistantMsg, message.FinishReasonCanceled, "Request cancelled", "")
592 // Make all future tool calls cancelled
593 for j := i; j < len(toolCalls); j++ {
594 toolResults[j] = message.ToolResult{
595 ToolCallID: toolCalls[j].ID,
596 Content: "Tool execution canceled by user",
597 IsError: true,
598 }
599 }
600 goto out
601 default:
602 // Continue processing
603 var tool tools.BaseTool
604 allTools, _ = a.getAllTools()
605 for _, availableTool := range allTools {
606 if availableTool.Info().Name == toolCall.Name {
607 tool = availableTool
608 break
609 }
610 }
611
612 // Tool not found
613 if tool == nil {
614 toolResults[i] = message.ToolResult{
615 ToolCallID: toolCall.ID,
616 Content: fmt.Sprintf("Tool not found: %s", toolCall.Name),
617 IsError: true,
618 }
619 continue
620 }
621
622 // Run tool in goroutine to allow cancellation
623 type toolExecResult struct {
624 response tools.ToolResponse
625 err error
626 }
627 resultChan := make(chan toolExecResult, 1)
628
629 go func() {
630 response, err := tool.Run(ctx, tools.ToolCall{
631 ID: toolCall.ID,
632 Name: toolCall.Name,
633 Input: toolCall.Input,
634 })
635 resultChan <- toolExecResult{response: response, err: err}
636 }()
637
638 var toolResponse tools.ToolResponse
639 var toolErr error
640
641 select {
642 case <-ctx.Done():
643 a.finishMessage(context.Background(), &assistantMsg, message.FinishReasonCanceled, "Request cancelled", "")
644 // Mark remaining tool calls as cancelled
645 for j := i; j < len(toolCalls); j++ {
646 toolResults[j] = message.ToolResult{
647 ToolCallID: toolCalls[j].ID,
648 Content: "Tool execution canceled by user",
649 IsError: true,
650 }
651 }
652 goto out
653 case result := <-resultChan:
654 toolResponse = result.response
655 toolErr = result.err
656 }
657
658 if toolErr != nil {
659 slog.Error("Tool execution error", "toolCall", toolCall.ID, "error", toolErr)
660 if errors.Is(toolErr, permission.ErrorPermissionDenied) {
661 toolResults[i] = message.ToolResult{
662 ToolCallID: toolCall.ID,
663 Content: "Permission denied",
664 IsError: true,
665 }
666 for j := i + 1; j < len(toolCalls); j++ {
667 toolResults[j] = message.ToolResult{
668 ToolCallID: toolCalls[j].ID,
669 Content: "Tool execution canceled by user",
670 IsError: true,
671 }
672 }
673 a.finishMessage(ctx, &assistantMsg, message.FinishReasonPermissionDenied, "Permission denied", "")
674 break
675 }
676 }
677 toolResults[i] = message.ToolResult{
678 ToolCallID: toolCall.ID,
679 Content: toolResponse.Content,
680 Metadata: toolResponse.Metadata,
681 IsError: toolResponse.IsError,
682 }
683 }
684 }
685out:
686 if len(toolResults) == 0 {
687 return assistantMsg, nil, nil
688 }
689 parts := make([]message.ContentPart, 0)
690 for _, tr := range toolResults {
691 parts = append(parts, tr)
692 }
693 msg, err := a.messages.Create(context.Background(), assistantMsg.SessionID, message.CreateMessageParams{
694 Role: message.Tool,
695 Parts: parts,
696 Provider: a.providerID,
697 })
698 if err != nil {
699 return assistantMsg, nil, fmt.Errorf("failed to create cancelled tool message: %w", err)
700 }
701
702 return assistantMsg, &msg, err
703}
704
705func (a *agent) finishMessage(ctx context.Context, msg *message.Message, finishReason message.FinishReason, message, details string) {
706 msg.AddFinish(finishReason, message, details)
707 _ = a.messages.Update(ctx, *msg)
708}
709
710func (a *agent) processEvent(ctx context.Context, sessionID string, assistantMsg *message.Message, event provider.ProviderEvent) error {
711 select {
712 case <-ctx.Done():
713 return ctx.Err()
714 default:
715 // Continue processing.
716 }
717
718 switch event.Type {
719 case provider.EventThinkingDelta:
720 assistantMsg.AppendReasoningContent(event.Thinking)
721 return a.messages.Update(ctx, *assistantMsg)
722 case provider.EventSignatureDelta:
723 assistantMsg.AppendReasoningSignature(event.Signature)
724 return a.messages.Update(ctx, *assistantMsg)
725 case provider.EventContentDelta:
726 assistantMsg.FinishThinking()
727 assistantMsg.AppendContent(event.Content)
728 return a.messages.Update(ctx, *assistantMsg)
729 case provider.EventToolUseStart:
730 assistantMsg.FinishThinking()
731 slog.Info("Tool call started", "toolCall", event.ToolCall)
732 assistantMsg.AddToolCall(*event.ToolCall)
733 return a.messages.Update(ctx, *assistantMsg)
734 case provider.EventToolUseDelta:
735 assistantMsg.AppendToolCallInput(event.ToolCall.ID, event.ToolCall.Input)
736 return a.messages.Update(ctx, *assistantMsg)
737 case provider.EventToolUseStop:
738 slog.Info("Finished tool call", "toolCall", event.ToolCall)
739 assistantMsg.FinishToolCall(event.ToolCall.ID)
740 return a.messages.Update(ctx, *assistantMsg)
741 case provider.EventError:
742 return event.Error
743 case provider.EventComplete:
744 assistantMsg.FinishThinking()
745 assistantMsg.SetToolCalls(event.Response.ToolCalls)
746 assistantMsg.AddFinish(event.Response.FinishReason, "", "")
747 if err := a.messages.Update(ctx, *assistantMsg); err != nil {
748 return fmt.Errorf("failed to update message: %w", err)
749 }
750 return a.trackUsage(ctx, sessionID, a.Model(), event.Response.Usage)
751 }
752
753 return nil
754}
755
756func (a *agent) trackUsage(ctx context.Context, sessionID string, model catwalk.Model, usage provider.TokenUsage) error {
757 sess, err := a.sessions.Get(ctx, sessionID)
758 if err != nil {
759 return fmt.Errorf("failed to get session: %w", err)
760 }
761
762 cost := model.CostPer1MInCached/1e6*float64(usage.CacheCreationTokens) +
763 model.CostPer1MOutCached/1e6*float64(usage.CacheReadTokens) +
764 model.CostPer1MIn/1e6*float64(usage.InputTokens) +
765 model.CostPer1MOut/1e6*float64(usage.OutputTokens)
766
767 a.eventTokensUsed(sessionID, usage, cost)
768
769 sess.Cost += cost
770 sess.CompletionTokens = usage.OutputTokens + usage.CacheReadTokens
771 sess.PromptTokens = usage.InputTokens + usage.CacheCreationTokens
772
773 _, err = a.sessions.Save(ctx, sess)
774 if err != nil {
775 return fmt.Errorf("failed to save session: %w", err)
776 }
777 return nil
778}
779
780func (a *agent) Summarize(ctx context.Context, sessionID string) error {
781 if a.summarizeProvider == nil {
782 return fmt.Errorf("summarize provider not available")
783 }
784
785 // Check if session is busy
786 if a.IsSessionBusy(sessionID) {
787 return ErrSessionBusy
788 }
789
790 // Create a new context with cancellation
791 summarizeCtx, cancel := context.WithCancel(ctx)
792
793 // Store the cancel function in activeRequests to allow cancellation
794 a.activeRequests.Set(sessionID+"-summarize", cancel)
795
796 go func() {
797 defer a.activeRequests.Del(sessionID + "-summarize")
798 defer cancel()
799 event := AgentEvent{
800 Type: AgentEventTypeSummarize,
801 Progress: "Starting summarization...",
802 }
803
804 a.Publish(pubsub.CreatedEvent, event)
805 // Get all messages from the session
806 msgs, err := a.messages.List(summarizeCtx, sessionID)
807 if err != nil {
808 event = AgentEvent{
809 Type: AgentEventTypeError,
810 Error: fmt.Errorf("failed to list messages: %w", err),
811 Done: true,
812 }
813 a.Publish(pubsub.CreatedEvent, event)
814 return
815 }
816 summarizeCtx = context.WithValue(summarizeCtx, tools.SessionIDContextKey, sessionID)
817
818 if len(msgs) == 0 {
819 event = AgentEvent{
820 Type: AgentEventTypeError,
821 Error: fmt.Errorf("no messages to summarize"),
822 Done: true,
823 }
824 a.Publish(pubsub.CreatedEvent, event)
825 return
826 }
827
828 event = AgentEvent{
829 Type: AgentEventTypeSummarize,
830 Progress: "Analyzing conversation...",
831 }
832 a.Publish(pubsub.CreatedEvent, event)
833
834 // Add a system message to guide the summarization
835 summarizePrompt := "Provide a detailed but concise summary of our conversation above. Focus on information that would be helpful for continuing the conversation, including what we did, what we're doing, which files we're working on, and what we're going to do next."
836
837 // Create a new message with the summarize prompt
838 promptMsg := message.Message{
839 Role: message.User,
840 Parts: []message.ContentPart{message.TextContent{Text: summarizePrompt}},
841 }
842
843 // Append the prompt to the messages
844 msgsWithPrompt := append(msgs, promptMsg)
845
846 event = AgentEvent{
847 Type: AgentEventTypeSummarize,
848 Progress: "Generating summary...",
849 }
850
851 a.Publish(pubsub.CreatedEvent, event)
852
853 // Send the messages to the summarize provider
854 response := a.summarizeProvider.StreamResponse(
855 summarizeCtx,
856 msgsWithPrompt,
857 nil,
858 )
859 var finalResponse *provider.ProviderResponse
860 for r := range response {
861 if r.Error != nil {
862 event = AgentEvent{
863 Type: AgentEventTypeError,
864 Error: fmt.Errorf("failed to summarize: %w", r.Error),
865 Done: true,
866 }
867 a.Publish(pubsub.CreatedEvent, event)
868 return
869 }
870 finalResponse = r.Response
871 }
872
873 summary := strings.TrimSpace(finalResponse.Content)
874 if summary == "" {
875 event = AgentEvent{
876 Type: AgentEventTypeError,
877 Error: fmt.Errorf("empty summary returned"),
878 Done: true,
879 }
880 a.Publish(pubsub.CreatedEvent, event)
881 return
882 }
883 shell := shell.GetPersistentShell(config.Get().WorkingDir())
884 summary += "\n\n**Current working directory of the persistent shell**\n\n" + shell.GetWorkingDir()
885 event = AgentEvent{
886 Type: AgentEventTypeSummarize,
887 Progress: "Creating new session...",
888 }
889
890 a.Publish(pubsub.CreatedEvent, event)
891 oldSession, err := a.sessions.Get(summarizeCtx, sessionID)
892 if err != nil {
893 event = AgentEvent{
894 Type: AgentEventTypeError,
895 Error: fmt.Errorf("failed to get session: %w", err),
896 Done: true,
897 }
898
899 a.Publish(pubsub.CreatedEvent, event)
900 return
901 }
902 // Create a message in the new session with the summary
903 msg, err := a.messages.Create(summarizeCtx, oldSession.ID, message.CreateMessageParams{
904 Role: message.Assistant,
905 Parts: []message.ContentPart{
906 message.TextContent{Text: summary},
907 message.Finish{
908 Reason: message.FinishReasonEndTurn,
909 Time: time.Now().Unix(),
910 },
911 },
912 Model: a.summarizeProvider.Model().ID,
913 Provider: a.summarizeProviderID,
914 })
915 if err != nil {
916 event = AgentEvent{
917 Type: AgentEventTypeError,
918 Error: fmt.Errorf("failed to create summary message: %w", err),
919 Done: true,
920 }
921
922 a.Publish(pubsub.CreatedEvent, event)
923 return
924 }
925 oldSession.SummaryMessageID = msg.ID
926 oldSession.CompletionTokens = finalResponse.Usage.OutputTokens
927 oldSession.PromptTokens = 0
928 model := a.summarizeProvider.Model()
929 usage := finalResponse.Usage
930 cost := model.CostPer1MInCached/1e6*float64(usage.CacheCreationTokens) +
931 model.CostPer1MOutCached/1e6*float64(usage.CacheReadTokens) +
932 model.CostPer1MIn/1e6*float64(usage.InputTokens) +
933 model.CostPer1MOut/1e6*float64(usage.OutputTokens)
934 oldSession.Cost += cost
935 _, err = a.sessions.Save(summarizeCtx, oldSession)
936 if err != nil {
937 event = AgentEvent{
938 Type: AgentEventTypeError,
939 Error: fmt.Errorf("failed to save session: %w", err),
940 Done: true,
941 }
942 a.Publish(pubsub.CreatedEvent, event)
943 }
944
945 event = AgentEvent{
946 Type: AgentEventTypeSummarize,
947 SessionID: oldSession.ID,
948 Progress: "Summary complete",
949 Done: true,
950 }
951 a.Publish(pubsub.CreatedEvent, event)
952 // Send final success event with the new session ID
953 }()
954
955 return nil
956}
957
958func (a *agent) ClearQueue(sessionID string) {
959 if a.QueuedPrompts(sessionID) > 0 {
960 slog.Info("Clearing queued prompts", "session_id", sessionID)
961 a.promptQueue.Del(sessionID)
962 }
963}
964
965func (a *agent) CancelAll() {
966 if !a.IsBusy() {
967 return
968 }
969 for key := range a.activeRequests.Seq2() {
970 a.Cancel(key) // key is sessionID
971 }
972
973 for _, cleanup := range a.cleanupFuncs {
974 if cleanup != nil {
975 cleanup()
976 }
977 }
978
979 timeout := time.After(5 * time.Second)
980 for a.IsBusy() {
981 select {
982 case <-timeout:
983 return
984 default:
985 time.Sleep(200 * time.Millisecond)
986 }
987 }
988}
989
// UpdateModel re-reads the global configuration and rebuilds the agent's
// providers so they match the currently selected models:
//
//   - the main provider is recreated only when the provider configured for the
//     agent's model differs from the one currently in use (a.providerID);
//   - the title provider (small model) is always recreated;
//   - the summarize provider (large model) is recreated only when its provider
//     ID differs from a.summarizeProviderID.
//
// It returns an error when a referenced provider or model is missing from the
// configuration, or when constructing a provider fails. Providers already
// replaced before such an error are kept.
func (a *agent) UpdateModel() error {
	cfg := config.Get()

	// Get current provider configuration
	currentProviderCfg := cfg.GetProviderForModel(a.agentCfg.Model)
	if currentProviderCfg == nil || currentProviderCfg.ID == "" {
		return fmt.Errorf("provider for agent %s not found in config", a.agentCfg.Name)
	}

	// Check if provider has changed
	if string(currentProviderCfg.ID) != a.providerID {
		// Provider changed, need to recreate the main provider.
		// GetModelByType reports a missing model as nil, so check for nil
		// before reading any field to avoid a nil-pointer panic.
		model := cfg.GetModelByType(a.agentCfg.Model)
		if model == nil || model.ID == "" {
			return fmt.Errorf("model not found for agent %s", a.agentCfg.Name)
		}

		promptID := agentPromptMap[a.agentCfg.ID]
		if promptID == "" {
			promptID = prompt.PromptDefault
		}

		opts := []provider.ProviderClientOption{
			provider.WithModel(a.agentCfg.Model),
			provider.WithSystemMessage(prompt.GetPrompt(promptID, currentProviderCfg.ID, cfg.Options.ContextPaths...)),
		}

		newProvider, err := provider.NewProvider(*currentProviderCfg, opts...)
		if err != nil {
			return fmt.Errorf("failed to create new provider: %w", err)
		}

		// Update the provider and provider ID
		a.provider = newProvider
		a.providerID = string(currentProviderCfg.ID)
	}

	// Resolve provider configs for title (small) and summarize (large) models.
	smallModelCfg := cfg.Models[config.SelectedModelTypeSmall]
	var smallModelProviderCfg config.ProviderConfig
	for p := range cfg.Providers.Seq() {
		if p.ID == smallModelCfg.Provider {
			smallModelProviderCfg = p
			break
		}
	}
	if smallModelProviderCfg.ID == "" {
		return fmt.Errorf("provider %s not found in config", smallModelCfg.Provider)
	}

	largeModelCfg := cfg.Models[config.SelectedModelTypeLarge]
	var largeModelProviderCfg config.ProviderConfig
	for p := range cfg.Providers.Seq() {
		if p.ID == largeModelCfg.Provider {
			largeModelProviderCfg = p
			break
		}
	}
	if largeModelProviderCfg.ID == "" {
		return fmt.Errorf("provider %s not found in config", largeModelCfg.Provider)
	}

	var maxTitleTokens int64 = 40

	// if the max output is too low for the gemini provider it won't return anything
	if smallModelCfg.Provider == "gemini" {
		maxTitleTokens = 1000
	}
	// Recreate title provider
	titleOpts := []provider.ProviderClientOption{
		provider.WithModel(config.SelectedModelTypeSmall),
		provider.WithSystemMessage(prompt.GetPrompt(prompt.PromptTitle, smallModelProviderCfg.ID)),
		provider.WithMaxTokens(maxTitleTokens),
	}
	newTitleProvider, err := provider.NewProvider(smallModelProviderCfg, titleOpts...)
	if err != nil {
		return fmt.Errorf("failed to create new title provider: %w", err)
	}
	a.titleProvider = newTitleProvider

	// Recreate summarize provider if provider changed (now large model)
	if string(largeModelProviderCfg.ID) != a.summarizeProviderID {
		largeModel := cfg.GetModelByType(config.SelectedModelTypeLarge)
		if largeModel == nil {
			return fmt.Errorf("model %s not found in provider %s", largeModelCfg.Model, largeModelProviderCfg.ID)
		}
		summarizeOpts := []provider.ProviderClientOption{
			provider.WithModel(config.SelectedModelTypeLarge),
			provider.WithSystemMessage(prompt.GetPrompt(prompt.PromptSummarizer, largeModelProviderCfg.ID)),
		}
		newSummarizeProvider, err := provider.NewProvider(largeModelProviderCfg, summarizeOpts...)
		if err != nil {
			return fmt.Errorf("failed to create new summarize provider: %w", err)
		}
		a.summarizeProvider = newSummarizeProvider
		a.summarizeProviderID = string(largeModelProviderCfg.ID)
	}

	return nil
}
1090
// setupEvents starts a background goroutine that consumes MCP events. When a
// server reports that its tool list changed, the goroutine re-lists that
// server's tools, stores them, refreshes a.mcpTools from the shared MCP tool
// registry, and updates the server's state.
//
// The goroutine exits when the subscription channel closes or when its
// context is cancelled. That context is derived from ctx, and its cancel
// function is appended to a.cleanupFuncs so the goroutine can be stopped
// explicitly.
func (a *agent) setupEvents(ctx context.Context) {
	ctx, cancel := context.WithCancel(ctx)

	go func() {
		subCh := SubscribeMCPEvents(ctx)

		for {
			select {
			// Named ev (not event) to avoid shadowing the imported event package.
			case ev, ok := <-subCh:
				if !ok {
					slog.Debug("MCPEvents subscription channel closed")
					return
				}
				if ev.Payload.Type != MCPEventToolsListChanged {
					continue
				}

				name := ev.Payload.Name
				c, ok := mcpClients.Get(name)
				if !ok {
					slog.Warn("MCP client not found for tools update", "name", name)
					continue
				}
				cfg := config.Get()
				// Named serverTools (not tools) to avoid shadowing the imported tools package.
				serverTools, err := getTools(ctx, name, a.permissions, c, cfg.WorkingDir())
				if err != nil {
					slog.Error("error listing tools", "error", err)
					updateMCPState(name, MCPStateError, err, nil, 0)
					// Best-effort close: the listing error has already been recorded above.
					_ = c.Close()
					continue
				}
				updateMcpTools(name, serverTools)
				a.mcpTools.Reset(maps.Collect(mcpTools.Seq2()))
				// Report the number of tools this server exposes, not the
				// aggregate size of a.mcpTools across every MCP server.
				updateMCPState(name, MCPStateConnected, nil, c, len(serverTools))
			case <-ctx.Done():
				slog.Debug("MCPEvents subscription cancelled")
				return
			}
		}
	}()

	a.cleanupFuncs = append(a.cleanupFuncs, cancel)
}