1package agent
2
3import (
4 "context"
5 "errors"
6 "fmt"
7 "log/slog"
8 "maps"
9 "slices"
10 "strings"
11 "time"
12
13 "github.com/charmbracelet/catwalk/pkg/catwalk"
14 "github.com/charmbracelet/crush/internal/config"
15 "github.com/charmbracelet/crush/internal/csync"
16 "github.com/charmbracelet/crush/internal/event"
17 "github.com/charmbracelet/crush/internal/history"
18 "github.com/charmbracelet/crush/internal/llm/prompt"
19 "github.com/charmbracelet/crush/internal/llm/provider"
20 "github.com/charmbracelet/crush/internal/llm/tools"
21 "github.com/charmbracelet/crush/internal/log"
22 "github.com/charmbracelet/crush/internal/lsp"
23 "github.com/charmbracelet/crush/internal/message"
24 "github.com/charmbracelet/crush/internal/permission"
25 "github.com/charmbracelet/crush/internal/pubsub"
26 "github.com/charmbracelet/crush/internal/session"
27 "github.com/charmbracelet/crush/internal/shell"
28)
29
// AgentEventType identifies the kind of AgentEvent published by an agent.
type AgentEventType string

const (
	// AgentEventTypeError marks an event whose Error field describes a failure.
	AgentEventTypeError AgentEventType = "error"
	// AgentEventTypeResponse marks an event that carries the assistant message
	// produced by a completed generation.
	AgentEventTypeResponse AgentEventType = "response"
	// AgentEventTypeSummarize marks a progress or completion event emitted
	// while a session is being summarized.
	AgentEventTypeSummarize AgentEventType = "summarize"
)
37
// AgentEvent is the payload an agent publishes to its subscribers and sends
// on the channel returned by Run.
type AgentEvent struct {
	// Type tells consumers which of the fields below are meaningful.
	Type AgentEventType
	// Message is the assistant message of a response event.
	Message message.Message
	// Error is the failure reported by an error event.
	Error error

	// SessionID and Progress are filled in by Summarize. Done marks a terminal
	// event; it is set by Summarize and on the final response event of a
	// generation.
	SessionID string
	Progress  string
	Done      bool
}
48
// Service is the public surface of an agent: it runs prompts against a
// session, reports and cancels in-flight work, and publishes AgentEvents.
type Service interface {
	// Subscribers receive every AgentEvent the agent publishes.
	pubsub.Suscriber[AgentEvent]
	// Model returns the model currently configured for the agent.
	Model() catwalk.Model
	// Run starts a generation for sessionID and returns a channel that
	// receives the final event. When the session is already busy the prompt
	// is queued instead and both return values are nil.
	Run(ctx context.Context, sessionID string, content string, attachments ...message.Attachment) (<-chan AgentEvent, error)
	// Cancel stops the active request and any summarization for sessionID and
	// drops its queued prompts.
	Cancel(sessionID string)
	// CancelAll cancels every active request and waits a bounded time for
	// them to finish.
	CancelAll()
	// IsSessionBusy reports whether sessionID has an active request.
	IsSessionBusy(sessionID string) bool
	// IsBusy reports whether any request is active.
	IsBusy() bool
	// Summarize starts an asynchronous summarization of sessionID; progress
	// and the outcome are delivered as published events.
	Summarize(ctx context.Context, sessionID string) error
	// UpdateModel re-reads the configuration and refreshes the agent's
	// providers to match it.
	UpdateModel() error
	// QueuedPrompts returns how many prompts are waiting for sessionID.
	QueuedPrompts(sessionID string) int
	// ClearQueue drops all queued prompts for sessionID.
	ClearQueue(sessionID string)
}
62
// agent is the Service implementation. It embeds a Broker so that it can
// publish AgentEvents to subscribers.
type agent struct {
	*pubsub.Broker[AgentEvent]
	// agentCfg is the configuration this agent was created with.
	agentCfg    config.Agent
	sessions    session.Service
	messages    message.Service
	permissions permission.Service
	// baseTools and mcpTools are created lazily on first use (see NewAgent).
	baseTools  *csync.Map[string, tools.BaseTool]
	mcpTools   *csync.Map[string, tools.BaseTool]
	lspClients *csync.Map[string, *lsp.Client]

	// We need this to be able to update it when model changes
	agentToolFn func() (tools.BaseTool, error)
	// cleanupFuncs are invoked by CancelAll after requests are cancelled.
	cleanupFuncs []func()

	// provider serves the main conversation; providerID is recorded on the
	// messages it produces.
	provider   provider.Provider
	providerID string

	// titleProvider generates session titles; summarizeProvider produces
	// conversation summaries.
	titleProvider       provider.Provider
	summarizeProvider   provider.Provider
	summarizeProviderID string

	// activeRequests maps a session ID (or sessionID+"-summarize") to the
	// cancel function of its in-flight request.
	activeRequests *csync.Map[string, context.CancelFunc]
	// promptQueue holds prompts submitted while a session was busy.
	promptQueue *csync.Map[string, []string]
}
87
// agentPromptMap maps an agent ID to the system prompt it uses. Agents whose
// ID is not listed fall back to prompt.PromptDefault.
var agentPromptMap = map[string]prompt.PromptID{
	"coder": prompt.PromptCoder,
	"task":  prompt.PromptTask,
}
92
93func NewAgent(
94 ctx context.Context,
95 agentCfg config.Agent,
96 // These services are needed in the tools
97 permissions permission.Service,
98 sessions session.Service,
99 messages message.Service,
100 history history.Service,
101 lspClients *csync.Map[string, *lsp.Client],
102) (Service, error) {
103 cfg := config.Get()
104
105 var agentToolFn func() (tools.BaseTool, error)
106 if agentCfg.ID == "coder" && slices.Contains(agentCfg.AllowedTools, AgentToolName) {
107 agentToolFn = func() (tools.BaseTool, error) {
108 taskAgentCfg := config.Get().Agents["task"]
109 if taskAgentCfg.ID == "" {
110 return nil, fmt.Errorf("task agent not found in config")
111 }
112 taskAgent, err := NewAgent(ctx, taskAgentCfg, permissions, sessions, messages, history, lspClients)
113 if err != nil {
114 return nil, fmt.Errorf("failed to create task agent: %w", err)
115 }
116 return NewAgentTool(taskAgent, sessions, messages), nil
117 }
118 }
119
120 providerCfg := config.Get().GetProviderForModel(agentCfg.Model)
121 if providerCfg == nil {
122 return nil, fmt.Errorf("provider for agent %s not found in config", agentCfg.Name)
123 }
124 model := config.Get().GetModelByType(agentCfg.Model)
125
126 if model == nil {
127 return nil, fmt.Errorf("model not found for agent %s", agentCfg.Name)
128 }
129
130 promptID := agentPromptMap[agentCfg.ID]
131 if promptID == "" {
132 promptID = prompt.PromptDefault
133 }
134 opts := []provider.ProviderClientOption{
135 provider.WithModel(agentCfg.Model),
136 provider.WithSystemMessage(prompt.GetPrompt(promptID, providerCfg.ID, config.Get().Options.ContextPaths...)),
137 }
138 agentProvider, err := provider.NewProvider(*providerCfg, opts...)
139 if err != nil {
140 return nil, err
141 }
142
143 smallModelCfg := cfg.Models[config.SelectedModelTypeSmall]
144 var smallModelProviderCfg *config.ProviderConfig
145 if smallModelCfg.Provider == providerCfg.ID {
146 smallModelProviderCfg = providerCfg
147 } else {
148 smallModelProviderCfg = cfg.GetProviderForModel(config.SelectedModelTypeSmall)
149
150 if smallModelProviderCfg.ID == "" {
151 return nil, fmt.Errorf("provider %s not found in config", smallModelCfg.Provider)
152 }
153 }
154 smallModel := cfg.GetModelByType(config.SelectedModelTypeSmall)
155 if smallModel.ID == "" {
156 return nil, fmt.Errorf("model %s not found in provider %s", smallModelCfg.Model, smallModelProviderCfg.ID)
157 }
158
159 titleOpts := []provider.ProviderClientOption{
160 provider.WithModel(config.SelectedModelTypeSmall),
161 provider.WithSystemMessage(prompt.GetPrompt(prompt.PromptTitle, smallModelProviderCfg.ID)),
162 }
163 titleProvider, err := provider.NewProvider(*smallModelProviderCfg, titleOpts...)
164 if err != nil {
165 return nil, err
166 }
167
168 summarizeOpts := []provider.ProviderClientOption{
169 provider.WithModel(config.SelectedModelTypeLarge),
170 provider.WithSystemMessage(prompt.GetPrompt(prompt.PromptSummarizer, providerCfg.ID)),
171 }
172 summarizeProvider, err := provider.NewProvider(*providerCfg, summarizeOpts...)
173 if err != nil {
174 return nil, err
175 }
176
177 baseToolsFn := func() map[string]tools.BaseTool {
178 slog.Info("Initializing agent base tools", "agent", agentCfg.ID)
179 defer func() {
180 slog.Info("Initialized agent base tools", "agent", agentCfg.ID)
181 }()
182
183 // Base tools available to all agents
184 cwd := cfg.WorkingDir()
185 result := make(map[string]tools.BaseTool)
186 for _, tool := range []tools.BaseTool{
187 tools.NewBashTool(permissions, cwd, cfg.Options.Attribution),
188 tools.NewDownloadTool(permissions, cwd),
189 tools.NewEditTool(lspClients, permissions, history, cwd),
190 tools.NewMultiEditTool(lspClients, permissions, history, cwd),
191 tools.NewFetchTool(permissions, cwd),
192 tools.NewGlobTool(cwd),
193 tools.NewGrepTool(cwd),
194 tools.NewLsTool(permissions, cwd),
195 tools.NewSourcegraphTool(),
196 tools.NewViewTool(lspClients, permissions, cwd),
197 tools.NewWriteTool(lspClients, permissions, history, cwd),
198 } {
199 result[tool.Name()] = tool
200 }
201 return result
202 }
203 mcpToolsFn := func() map[string]tools.BaseTool {
204 slog.Info("Initializing agent mcp tools", "agent", agentCfg.ID)
205 defer func() {
206 slog.Info("Initialized agent mcp tools", "agent", agentCfg.ID)
207 }()
208
209 mcpToolsOnce.Do(func() {
210 doGetMCPTools(ctx, permissions, cfg)
211 })
212
213 return maps.Collect(mcpTools.Seq2())
214 }
215
216 a := &agent{
217 Broker: pubsub.NewBroker[AgentEvent](),
218 agentCfg: agentCfg,
219 provider: agentProvider,
220 providerID: string(providerCfg.ID),
221 messages: messages,
222 sessions: sessions,
223 titleProvider: titleProvider,
224 summarizeProvider: summarizeProvider,
225 summarizeProviderID: string(providerCfg.ID),
226 agentToolFn: agentToolFn,
227 activeRequests: csync.NewMap[string, context.CancelFunc](),
228 mcpTools: csync.NewLazyMap(mcpToolsFn),
229 baseTools: csync.NewLazyMap(baseToolsFn),
230 promptQueue: csync.NewMap[string, []string](),
231 permissions: permissions,
232 lspClients: lspClients,
233 }
234 a.setupEvents(ctx)
235 return a, nil
236}
237
238func (a *agent) Model() catwalk.Model {
239 return *config.Get().GetModelByType(a.agentCfg.Model)
240}
241
242func (a *agent) Cancel(sessionID string) {
243 // Cancel regular requests
244 if cancel, ok := a.activeRequests.Take(sessionID); ok && cancel != nil {
245 slog.Info("Request cancellation initiated", "session_id", sessionID)
246 cancel()
247 }
248
249 // Also check for summarize requests
250 if cancel, ok := a.activeRequests.Take(sessionID + "-summarize"); ok && cancel != nil {
251 slog.Info("Summarize cancellation initiated", "session_id", sessionID)
252 cancel()
253 }
254
255 if a.QueuedPrompts(sessionID) > 0 {
256 slog.Info("Clearing queued prompts", "session_id", sessionID)
257 a.promptQueue.Del(sessionID)
258 }
259}
260
261func (a *agent) IsBusy() bool {
262 var busy bool
263 for cancelFunc := range a.activeRequests.Seq() {
264 if cancelFunc != nil {
265 busy = true
266 break
267 }
268 }
269 return busy
270}
271
272func (a *agent) IsSessionBusy(sessionID string) bool {
273 _, busy := a.activeRequests.Get(sessionID)
274 return busy
275}
276
277func (a *agent) QueuedPrompts(sessionID string) int {
278 l, ok := a.promptQueue.Get(sessionID)
279 if !ok {
280 return 0
281 }
282 return len(l)
283}
284
285func (a *agent) generateTitle(ctx context.Context, sessionID string, content string) error {
286 if content == "" {
287 return nil
288 }
289 if a.titleProvider == nil {
290 return nil
291 }
292 session, err := a.sessions.Get(ctx, sessionID)
293 if err != nil {
294 return err
295 }
296 parts := []message.ContentPart{message.TextContent{
297 Text: fmt.Sprintf("Generate a concise title for the following content:\n\n%s", content),
298 }}
299
300 // Use streaming approach like summarization
301 response := a.titleProvider.StreamResponse(
302 ctx,
303 []message.Message{
304 {
305 Role: message.User,
306 Parts: parts,
307 },
308 },
309 nil,
310 )
311
312 var finalResponse *provider.ProviderResponse
313 for r := range response {
314 if r.Error != nil {
315 return r.Error
316 }
317 finalResponse = r.Response
318 }
319
320 if finalResponse == nil {
321 return fmt.Errorf("no response received from title provider")
322 }
323
324 title := strings.ReplaceAll(finalResponse.Content, "\n", " ")
325
326 if idx := strings.Index(title, "</think>"); idx > 0 {
327 title = title[idx+len("</think>"):]
328 }
329
330 title = strings.TrimSpace(title)
331 if title == "" {
332 return nil
333 }
334
335 session.Title = title
336 _, err = a.sessions.Save(ctx, session)
337 return err
338}
339
340func (a *agent) err(err error) AgentEvent {
341 return AgentEvent{
342 Type: AgentEventTypeError,
343 Error: err,
344 }
345}
346
347func (a *agent) Run(ctx context.Context, sessionID string, content string, attachments ...message.Attachment) (<-chan AgentEvent, error) {
348 if !a.Model().SupportsImages && attachments != nil {
349 attachments = nil
350 }
351 events := make(chan AgentEvent, 1)
352 if a.IsSessionBusy(sessionID) {
353 existing, ok := a.promptQueue.Get(sessionID)
354 if !ok {
355 existing = []string{}
356 }
357 existing = append(existing, content)
358 a.promptQueue.Set(sessionID, existing)
359 return nil, nil
360 }
361
362 genCtx, cancel := context.WithCancel(ctx)
363 a.activeRequests.Set(sessionID, cancel)
364 startTime := time.Now()
365
366 go func() {
367 slog.Debug("Request started", "sessionID", sessionID)
368 defer log.RecoverPanic("agent.Run", func() {
369 events <- a.err(fmt.Errorf("panic while running the agent"))
370 })
371 var attachmentParts []message.ContentPart
372 for _, attachment := range attachments {
373 attachmentParts = append(attachmentParts, message.BinaryContent{Path: attachment.FilePath, MIMEType: attachment.MimeType, Data: attachment.Content})
374 }
375 result := a.processGeneration(genCtx, sessionID, content, attachmentParts)
376 if result.Error != nil {
377 if isCancelledErr(result.Error) {
378 slog.Error("Request canceled", "sessionID", sessionID)
379 } else {
380 slog.Error("Request errored", "sessionID", sessionID, "error", result.Error.Error())
381 event.Error(result.Error)
382 }
383 } else {
384 slog.Debug("Request completed", "sessionID", sessionID)
385 }
386 a.eventPromptResponded(sessionID, time.Since(startTime).Truncate(time.Second))
387 a.activeRequests.Del(sessionID)
388 cancel()
389 a.Publish(pubsub.CreatedEvent, result)
390 events <- result
391 close(events)
392 }()
393 a.eventPromptSent(sessionID)
394 return events, nil
395}
396
// processGeneration drives one complete agent turn for sessionID and returns
// the final event.
//
// It loads the session history (trimmed to start at the summary message when
// the session has one), stores content as a new user message, and then loops:
// each iteration streams one assistant response and runs its tool calls. The
// loop continues while the assistant asks for tools or while prompts were
// queued during the turn, and ends with a response event or an error event.
// For a session with no prior messages, title generation is started in the
// background.
func (a *agent) processGeneration(ctx context.Context, sessionID, content string, attachmentParts []message.ContentPart) AgentEvent {
	cfg := config.Get()
	// List existing messages; if none, start title generation asynchronously.
	msgs, err := a.messages.List(ctx, sessionID)
	if err != nil {
		return a.err(fmt.Errorf("failed to list messages: %w", err))
	}
	if len(msgs) == 0 {
		go func() {
			// NOTE(review): the recover label reads "agent.Run" although this
			// goroutine generates the title — looks like a copy-paste; confirm.
			defer log.RecoverPanic("agent.Run", func() {
				slog.Error("panic while generating title")
			})
			titleErr := a.generateTitle(ctx, sessionID, content)
			// Cancellation and deadline errors are not logged.
			if titleErr != nil && !errors.Is(titleErr, context.Canceled) && !errors.Is(titleErr, context.DeadlineExceeded) {
				slog.Error("failed to generate title", "error", titleErr)
			}
		}()
	}
	session, err := a.sessions.Get(ctx, sessionID)
	if err != nil {
		return a.err(fmt.Errorf("failed to get session: %w", err))
	}
	if session.SummaryMessageID != "" {
		// The session was summarized: drop everything before the summary
		// message and present the summary itself as a user message.
		summaryMsgInex := -1
		for i, msg := range msgs {
			if msg.ID == session.SummaryMessageID {
				summaryMsgInex = i
				break
			}
		}
		if summaryMsgInex != -1 {
			msgs = msgs[summaryMsgInex:]
			msgs[0].Role = message.User
		}
	}

	userMsg, err := a.createUserMessage(ctx, sessionID, content, attachmentParts)
	if err != nil {
		return a.err(fmt.Errorf("failed to create user message: %w", err))
	}
	// Append the new user message to the conversation history.
	msgHistory := append(msgs, userMsg)

	for {
		// Check for cancellation before each iteration
		select {
		case <-ctx.Done():
			return a.err(ctx.Err())
		default:
			// Continue processing
		}
		agentMessage, toolResults, err := a.streamAndHandleEvents(ctx, sessionID, msgHistory)
		if err != nil {
			if errors.Is(err, context.Canceled) {
				// Persist the cancelled state with a fresh context, since ctx
				// is already cancelled. The update error is ignored here.
				agentMessage.AddFinish(message.FinishReasonCanceled, "Request cancelled", "")
				a.messages.Update(context.Background(), agentMessage)
				return a.err(ErrRequestCancelled)
			}
			return a.err(fmt.Errorf("failed to process events: %w", err))
		}
		if cfg.Options.Debug {
			slog.Info("Result", "message", agentMessage.FinishReason(), "toolResults", toolResults)
		}
		if (agentMessage.FinishReason() == message.FinishReasonToolUse) && toolResults != nil {
			// We are not done, we need to respond with the tool response
			msgHistory = append(msgHistory, agentMessage, *toolResults)
			// If there are queued prompts, fold all of them into the history
			// so the next iteration sees them.
			nextPrompt, ok := a.promptQueue.Take(sessionID)
			if ok {
				// NOTE(review): unlike the end-turn branch below, empty
				// prompts are not skipped here — confirm whether intended.
				for _, prompt := range nextPrompt {
					// Create a new user message for the queued prompt
					userMsg, err := a.createUserMessage(ctx, sessionID, prompt, nil)
					if err != nil {
						return a.err(fmt.Errorf("failed to create user message for queued prompt: %w", err))
					}
					// Append the new user message to the conversation history
					msgHistory = append(msgHistory, userMsg)
				}
			}

			continue
		} else if agentMessage.FinishReason() == message.FinishReasonEndTurn {
			// The assistant finished its turn; if prompts were queued in the
			// meantime, add the non-empty ones and run another iteration.
			queuePrompts, ok := a.promptQueue.Take(sessionID)
			if ok {
				for _, prompt := range queuePrompts {
					if prompt == "" {
						continue
					}
					userMsg, err := a.createUserMessage(ctx, sessionID, prompt, nil)
					if err != nil {
						return a.err(fmt.Errorf("failed to create user message for queued prompt: %w", err))
					}
					msgHistory = append(msgHistory, userMsg)
				}
				continue
			}
		}
		if agentMessage.FinishReason() == "" {
			// Kujtim: could not track down where this is happening but this means its cancelled
			agentMessage.AddFinish(message.FinishReasonCanceled, "Request cancelled", "")
			_ = a.messages.Update(context.Background(), agentMessage)
			return a.err(ErrRequestCancelled)
		}
		return AgentEvent{
			Type:    AgentEventTypeResponse,
			Message: agentMessage,
			Done:    true,
		}
	}
}
507
508func (a *agent) createUserMessage(ctx context.Context, sessionID, content string, attachmentParts []message.ContentPart) (message.Message, error) {
509 parts := []message.ContentPart{message.TextContent{Text: content}}
510 parts = append(parts, attachmentParts...)
511 return a.messages.Create(ctx, sessionID, message.CreateMessageParams{
512 Role: message.User,
513 Parts: parts,
514 })
515}
516
517func (a *agent) getAllTools() ([]tools.BaseTool, error) {
518 allTools := slices.Collect(a.baseTools.Seq())
519
520 withCoderTools := func(t []tools.BaseTool) []tools.BaseTool {
521 if a.agentCfg.ID == "coder" {
522 t = append(t, slices.Collect(a.mcpTools.Seq())...)
523 if a.lspClients.Len() > 0 {
524 t = append(t, tools.NewDiagnosticsTool(a.lspClients))
525 }
526 }
527 return t
528 }
529
530 if a.agentCfg.AllowedTools == nil {
531 allTools = withCoderTools(allTools)
532 } else {
533 var filteredTools []tools.BaseTool
534 for _, tool := range allTools {
535 if slices.Contains(a.agentCfg.AllowedTools, tool.Name()) {
536 filteredTools = append(filteredTools, tool)
537 }
538 }
539 allTools = withCoderTools(filteredTools)
540 }
541
542 if a.agentToolFn != nil {
543 agentTool, agentToolErr := a.agentToolFn()
544 if agentToolErr != nil {
545 return nil, agentToolErr
546 }
547 allTools = append(allTools, agentTool)
548 }
549 return allTools, nil
550}
551
552func (a *agent) streamAndHandleEvents(ctx context.Context, sessionID string, msgHistory []message.Message) (message.Message, *message.Message, error) {
553 ctx = context.WithValue(ctx, tools.SessionIDContextKey, sessionID)
554
555 // Create the assistant message first so the spinner shows immediately
556 assistantMsg, err := a.messages.Create(ctx, sessionID, message.CreateMessageParams{
557 Role: message.Assistant,
558 Parts: []message.ContentPart{},
559 Model: a.Model().ID,
560 Provider: a.providerID,
561 })
562 if err != nil {
563 return assistantMsg, nil, fmt.Errorf("failed to create assistant message: %w", err)
564 }
565
566 allTools, toolsErr := a.getAllTools()
567 if toolsErr != nil {
568 return assistantMsg, nil, toolsErr
569 }
570 // Now collect tools (which may block on MCP initialization)
571 eventChan := a.provider.StreamResponse(ctx, msgHistory, allTools)
572
573 // Add the session and message ID into the context if needed by tools.
574 ctx = context.WithValue(ctx, tools.MessageIDContextKey, assistantMsg.ID)
575
576loop:
577 for {
578 select {
579 case event, ok := <-eventChan:
580 if !ok {
581 break loop
582 }
583 if processErr := a.processEvent(ctx, sessionID, &assistantMsg, event); processErr != nil {
584 if errors.Is(processErr, context.Canceled) {
585 a.finishMessage(context.Background(), &assistantMsg, message.FinishReasonCanceled, "Request cancelled", "")
586 } else {
587 a.finishMessage(ctx, &assistantMsg, message.FinishReasonError, "API Error", processErr.Error())
588 }
589 return assistantMsg, nil, processErr
590 }
591 case <-ctx.Done():
592 a.finishMessage(context.Background(), &assistantMsg, message.FinishReasonCanceled, "Request cancelled", "")
593 return assistantMsg, nil, ctx.Err()
594 }
595 }
596
597 toolResults := make([]message.ToolResult, len(assistantMsg.ToolCalls()))
598 toolCalls := assistantMsg.ToolCalls()
599 for i, toolCall := range toolCalls {
600 select {
601 case <-ctx.Done():
602 a.finishMessage(context.Background(), &assistantMsg, message.FinishReasonCanceled, "Request cancelled", "")
603 // Make all future tool calls cancelled
604 for j := i; j < len(toolCalls); j++ {
605 toolResults[j] = message.ToolResult{
606 ToolCallID: toolCalls[j].ID,
607 Content: "Tool execution canceled by user",
608 IsError: true,
609 }
610 }
611 goto out
612 default:
613 // Continue processing
614 var tool tools.BaseTool
615 allTools, _ = a.getAllTools()
616 for _, availableTool := range allTools {
617 if availableTool.Info().Name == toolCall.Name {
618 tool = availableTool
619 break
620 }
621 }
622
623 // Tool not found
624 if tool == nil {
625 toolResults[i] = message.ToolResult{
626 ToolCallID: toolCall.ID,
627 Content: fmt.Sprintf("Tool not found: %s", toolCall.Name),
628 IsError: true,
629 }
630 continue
631 }
632
633 // Run tool in goroutine to allow cancellation
634 type toolExecResult struct {
635 response tools.ToolResponse
636 err error
637 }
638 resultChan := make(chan toolExecResult, 1)
639
640 go func() {
641 response, err := tool.Run(ctx, tools.ToolCall{
642 ID: toolCall.ID,
643 Name: toolCall.Name,
644 Input: toolCall.Input,
645 })
646 resultChan <- toolExecResult{response: response, err: err}
647 }()
648
649 var toolResponse tools.ToolResponse
650 var toolErr error
651
652 select {
653 case <-ctx.Done():
654 a.finishMessage(context.Background(), &assistantMsg, message.FinishReasonCanceled, "Request cancelled", "")
655 // Mark remaining tool calls as cancelled
656 for j := i; j < len(toolCalls); j++ {
657 toolResults[j] = message.ToolResult{
658 ToolCallID: toolCalls[j].ID,
659 Content: "Tool execution canceled by user",
660 IsError: true,
661 }
662 }
663 goto out
664 case result := <-resultChan:
665 toolResponse = result.response
666 toolErr = result.err
667 }
668
669 if toolErr != nil {
670 slog.Error("Tool execution error", "toolCall", toolCall.ID, "error", toolErr)
671 if errors.Is(toolErr, permission.ErrorPermissionDenied) {
672 toolResults[i] = message.ToolResult{
673 ToolCallID: toolCall.ID,
674 Content: "Permission denied",
675 IsError: true,
676 }
677 for j := i + 1; j < len(toolCalls); j++ {
678 toolResults[j] = message.ToolResult{
679 ToolCallID: toolCalls[j].ID,
680 Content: "Tool execution canceled by user",
681 IsError: true,
682 }
683 }
684 a.finishMessage(ctx, &assistantMsg, message.FinishReasonPermissionDenied, "Permission denied", "")
685 break
686 }
687 }
688 toolResults[i] = message.ToolResult{
689 ToolCallID: toolCall.ID,
690 Content: toolResponse.Content,
691 Metadata: toolResponse.Metadata,
692 IsError: toolResponse.IsError,
693 }
694 }
695 }
696out:
697 if len(toolResults) == 0 {
698 return assistantMsg, nil, nil
699 }
700 parts := make([]message.ContentPart, 0)
701 for _, tr := range toolResults {
702 parts = append(parts, tr)
703 }
704 msg, err := a.messages.Create(context.Background(), assistantMsg.SessionID, message.CreateMessageParams{
705 Role: message.Tool,
706 Parts: parts,
707 Provider: a.providerID,
708 })
709 if err != nil {
710 return assistantMsg, nil, fmt.Errorf("failed to create cancelled tool message: %w", err)
711 }
712
713 return assistantMsg, &msg, err
714}
715
716func (a *agent) finishMessage(ctx context.Context, msg *message.Message, finishReason message.FinishReason, message, details string) {
717 msg.AddFinish(finishReason, message, details)
718 _ = a.messages.Update(ctx, *msg)
719}
720
721func (a *agent) processEvent(ctx context.Context, sessionID string, assistantMsg *message.Message, event provider.ProviderEvent) error {
722 select {
723 case <-ctx.Done():
724 return ctx.Err()
725 default:
726 // Continue processing.
727 }
728
729 switch event.Type {
730 case provider.EventThinkingDelta:
731 assistantMsg.AppendReasoningContent(event.Thinking)
732 return a.messages.Update(ctx, *assistantMsg)
733 case provider.EventSignatureDelta:
734 assistantMsg.AppendReasoningSignature(event.Signature)
735 return a.messages.Update(ctx, *assistantMsg)
736 case provider.EventContentDelta:
737 assistantMsg.FinishThinking()
738 assistantMsg.AppendContent(event.Content)
739 return a.messages.Update(ctx, *assistantMsg)
740 case provider.EventToolUseStart:
741 assistantMsg.FinishThinking()
742 slog.Info("Tool call started", "toolCall", event.ToolCall)
743 assistantMsg.AddToolCall(*event.ToolCall)
744 return a.messages.Update(ctx, *assistantMsg)
745 case provider.EventToolUseDelta:
746 assistantMsg.AppendToolCallInput(event.ToolCall.ID, event.ToolCall.Input)
747 return a.messages.Update(ctx, *assistantMsg)
748 case provider.EventToolUseStop:
749 slog.Info("Finished tool call", "toolCall", event.ToolCall)
750 assistantMsg.FinishToolCall(event.ToolCall.ID)
751 return a.messages.Update(ctx, *assistantMsg)
752 case provider.EventError:
753 return event.Error
754 case provider.EventComplete:
755 assistantMsg.FinishThinking()
756 assistantMsg.SetToolCalls(event.Response.ToolCalls)
757 assistantMsg.AddFinish(event.Response.FinishReason, "", "")
758 if err := a.messages.Update(ctx, *assistantMsg); err != nil {
759 return fmt.Errorf("failed to update message: %w", err)
760 }
761 return a.trackUsage(ctx, sessionID, a.Model(), event.Response.Usage)
762 }
763
764 return nil
765}
766
767func (a *agent) trackUsage(ctx context.Context, sessionID string, model catwalk.Model, usage provider.TokenUsage) error {
768 sess, err := a.sessions.Get(ctx, sessionID)
769 if err != nil {
770 return fmt.Errorf("failed to get session: %w", err)
771 }
772
773 cost := model.CostPer1MInCached/1e6*float64(usage.CacheCreationTokens) +
774 model.CostPer1MOutCached/1e6*float64(usage.CacheReadTokens) +
775 model.CostPer1MIn/1e6*float64(usage.InputTokens) +
776 model.CostPer1MOut/1e6*float64(usage.OutputTokens)
777
778 a.eventTokensUsed(sessionID, usage, cost)
779
780 sess.Cost += cost
781 sess.CompletionTokens = usage.OutputTokens + usage.CacheReadTokens
782 sess.PromptTokens = usage.InputTokens + usage.CacheCreationTokens
783
784 _, err = a.sessions.Save(ctx, sess)
785 if err != nil {
786 return fmt.Errorf("failed to save session: %w", err)
787 }
788 return nil
789}
790
791func (a *agent) Summarize(ctx context.Context, sessionID string) error {
792 if a.summarizeProvider == nil {
793 return fmt.Errorf("summarize provider not available")
794 }
795
796 // Check if session is busy
797 if a.IsSessionBusy(sessionID) {
798 return ErrSessionBusy
799 }
800
801 // Create a new context with cancellation
802 summarizeCtx, cancel := context.WithCancel(ctx)
803
804 // Store the cancel function in activeRequests to allow cancellation
805 a.activeRequests.Set(sessionID+"-summarize", cancel)
806
807 go func() {
808 defer a.activeRequests.Del(sessionID + "-summarize")
809 defer cancel()
810 event := AgentEvent{
811 Type: AgentEventTypeSummarize,
812 Progress: "Starting summarization...",
813 }
814
815 a.Publish(pubsub.CreatedEvent, event)
816 // Get all messages from the session
817 msgs, err := a.messages.List(summarizeCtx, sessionID)
818 if err != nil {
819 event = AgentEvent{
820 Type: AgentEventTypeError,
821 Error: fmt.Errorf("failed to list messages: %w", err),
822 Done: true,
823 }
824 a.Publish(pubsub.CreatedEvent, event)
825 return
826 }
827 summarizeCtx = context.WithValue(summarizeCtx, tools.SessionIDContextKey, sessionID)
828
829 if len(msgs) == 0 {
830 event = AgentEvent{
831 Type: AgentEventTypeError,
832 Error: fmt.Errorf("no messages to summarize"),
833 Done: true,
834 }
835 a.Publish(pubsub.CreatedEvent, event)
836 return
837 }
838
839 event = AgentEvent{
840 Type: AgentEventTypeSummarize,
841 Progress: "Analyzing conversation...",
842 }
843 a.Publish(pubsub.CreatedEvent, event)
844
845 // Add a system message to guide the summarization
846 summarizePrompt := "Provide a detailed but concise summary of our conversation above. Focus on information that would be helpful for continuing the conversation, including what we did, what we're doing, which files we're working on, and what we're going to do next."
847
848 // Create a new message with the summarize prompt
849 promptMsg := message.Message{
850 Role: message.User,
851 Parts: []message.ContentPart{message.TextContent{Text: summarizePrompt}},
852 }
853
854 // Append the prompt to the messages
855 msgsWithPrompt := append(msgs, promptMsg)
856
857 event = AgentEvent{
858 Type: AgentEventTypeSummarize,
859 Progress: "Generating summary...",
860 }
861
862 a.Publish(pubsub.CreatedEvent, event)
863
864 // Send the messages to the summarize provider
865 response := a.summarizeProvider.StreamResponse(
866 summarizeCtx,
867 msgsWithPrompt,
868 nil,
869 )
870 var finalResponse *provider.ProviderResponse
871 for r := range response {
872 if r.Error != nil {
873 event = AgentEvent{
874 Type: AgentEventTypeError,
875 Error: fmt.Errorf("failed to summarize: %w", r.Error),
876 Done: true,
877 }
878 a.Publish(pubsub.CreatedEvent, event)
879 return
880 }
881 finalResponse = r.Response
882 }
883
884 summary := strings.TrimSpace(finalResponse.Content)
885 if summary == "" {
886 event = AgentEvent{
887 Type: AgentEventTypeError,
888 Error: fmt.Errorf("empty summary returned"),
889 Done: true,
890 }
891 a.Publish(pubsub.CreatedEvent, event)
892 return
893 }
894 shell := shell.GetPersistentShell(config.Get().WorkingDir())
895 summary += "\n\n**Current working directory of the persistent shell**\n\n" + shell.GetWorkingDir()
896 event = AgentEvent{
897 Type: AgentEventTypeSummarize,
898 Progress: "Creating new session...",
899 }
900
901 a.Publish(pubsub.CreatedEvent, event)
902 oldSession, err := a.sessions.Get(summarizeCtx, sessionID)
903 if err != nil {
904 event = AgentEvent{
905 Type: AgentEventTypeError,
906 Error: fmt.Errorf("failed to get session: %w", err),
907 Done: true,
908 }
909
910 a.Publish(pubsub.CreatedEvent, event)
911 return
912 }
913 // Create a message in the new session with the summary
914 msg, err := a.messages.Create(summarizeCtx, oldSession.ID, message.CreateMessageParams{
915 Role: message.Assistant,
916 Parts: []message.ContentPart{
917 message.TextContent{Text: summary},
918 message.Finish{
919 Reason: message.FinishReasonEndTurn,
920 Time: time.Now().Unix(),
921 },
922 },
923 Model: a.summarizeProvider.Model().ID,
924 Provider: a.summarizeProviderID,
925 })
926 if err != nil {
927 event = AgentEvent{
928 Type: AgentEventTypeError,
929 Error: fmt.Errorf("failed to create summary message: %w", err),
930 Done: true,
931 }
932
933 a.Publish(pubsub.CreatedEvent, event)
934 return
935 }
936 oldSession.SummaryMessageID = msg.ID
937 oldSession.CompletionTokens = finalResponse.Usage.OutputTokens
938 oldSession.PromptTokens = 0
939 model := a.summarizeProvider.Model()
940 usage := finalResponse.Usage
941 cost := model.CostPer1MInCached/1e6*float64(usage.CacheCreationTokens) +
942 model.CostPer1MOutCached/1e6*float64(usage.CacheReadTokens) +
943 model.CostPer1MIn/1e6*float64(usage.InputTokens) +
944 model.CostPer1MOut/1e6*float64(usage.OutputTokens)
945 oldSession.Cost += cost
946 _, err = a.sessions.Save(summarizeCtx, oldSession)
947 if err != nil {
948 event = AgentEvent{
949 Type: AgentEventTypeError,
950 Error: fmt.Errorf("failed to save session: %w", err),
951 Done: true,
952 }
953 a.Publish(pubsub.CreatedEvent, event)
954 }
955
956 event = AgentEvent{
957 Type: AgentEventTypeSummarize,
958 SessionID: oldSession.ID,
959 Progress: "Summary complete",
960 Done: true,
961 }
962 a.Publish(pubsub.CreatedEvent, event)
963 // Send final success event with the new session ID
964 }()
965
966 return nil
967}
968
969func (a *agent) ClearQueue(sessionID string) {
970 if a.QueuedPrompts(sessionID) > 0 {
971 slog.Info("Clearing queued prompts", "session_id", sessionID)
972 a.promptQueue.Del(sessionID)
973 }
974}
975
976func (a *agent) CancelAll() {
977 if !a.IsBusy() {
978 return
979 }
980 for key := range a.activeRequests.Seq2() {
981 a.Cancel(key) // key is sessionID
982 }
983
984 for _, cleanup := range a.cleanupFuncs {
985 if cleanup != nil {
986 cleanup()
987 }
988 }
989
990 timeout := time.After(5 * time.Second)
991 for a.IsBusy() {
992 select {
993 case <-timeout:
994 return
995 default:
996 time.Sleep(200 * time.Millisecond)
997 }
998 }
999}
1000
1001func (a *agent) UpdateModel() error {
1002 cfg := config.Get()
1003
1004 // Get current provider configuration
1005 currentProviderCfg := cfg.GetProviderForModel(a.agentCfg.Model)
1006 if currentProviderCfg == nil || currentProviderCfg.ID == "" {
1007 return fmt.Errorf("provider for agent %s not found in config", a.agentCfg.Name)
1008 }
1009
1010 // Check if provider has changed
1011 if string(currentProviderCfg.ID) != a.providerID {
1012 // Provider changed, need to recreate the main provider
1013 model := cfg.GetModelByType(a.agentCfg.Model)
1014 if model.ID == "" {
1015 return fmt.Errorf("model not found for agent %s", a.agentCfg.Name)
1016 }
1017
1018 promptID := agentPromptMap[a.agentCfg.ID]
1019 if promptID == "" {
1020 promptID = prompt.PromptDefault
1021 }
1022
1023 opts := []provider.ProviderClientOption{
1024 provider.WithModel(a.agentCfg.Model),
1025 provider.WithSystemMessage(prompt.GetPrompt(promptID, currentProviderCfg.ID, cfg.Options.ContextPaths...)),
1026 }
1027
1028 newProvider, err := provider.NewProvider(*currentProviderCfg, opts...)
1029 if err != nil {
1030 return fmt.Errorf("failed to create new provider: %w", err)
1031 }
1032
1033 // Update the provider and provider ID
1034 a.provider = newProvider
1035 a.providerID = string(currentProviderCfg.ID)
1036 }
1037
1038 // Check if providers have changed for title (small) and summarize (large)
1039 smallModelCfg := cfg.Models[config.SelectedModelTypeSmall]
1040 var smallModelProviderCfg config.ProviderConfig
1041 for p := range cfg.Providers.Seq() {
1042 if p.ID == smallModelCfg.Provider {
1043 smallModelProviderCfg = p
1044 break
1045 }
1046 }
1047 if smallModelProviderCfg.ID == "" {
1048 return fmt.Errorf("provider %s not found in config", smallModelCfg.Provider)
1049 }
1050
1051 largeModelCfg := cfg.Models[config.SelectedModelTypeLarge]
1052 var largeModelProviderCfg config.ProviderConfig
1053 for p := range cfg.Providers.Seq() {
1054 if p.ID == largeModelCfg.Provider {
1055 largeModelProviderCfg = p
1056 break
1057 }
1058 }
1059 if largeModelProviderCfg.ID == "" {
1060 return fmt.Errorf("provider %s not found in config", largeModelCfg.Provider)
1061 }
1062
1063 var maxTitleTokens int64 = 40
1064
1065 // if the max output is too low for the gemini provider it won't return anything
1066 if smallModelCfg.Provider == "gemini" {
1067 maxTitleTokens = 1000
1068 }
1069 // Recreate title provider
1070 titleOpts := []provider.ProviderClientOption{
1071 provider.WithModel(config.SelectedModelTypeSmall),
1072 provider.WithSystemMessage(prompt.GetPrompt(prompt.PromptTitle, smallModelProviderCfg.ID)),
1073 provider.WithMaxTokens(maxTitleTokens),
1074 }
1075 newTitleProvider, err := provider.NewProvider(smallModelProviderCfg, titleOpts...)
1076 if err != nil {
1077 return fmt.Errorf("failed to create new title provider: %w", err)
1078 }
1079 a.titleProvider = newTitleProvider
1080
1081 // Recreate summarize provider if provider changed (now large model)
1082 if string(largeModelProviderCfg.ID) != a.summarizeProviderID {
1083 largeModel := cfg.GetModelByType(config.SelectedModelTypeLarge)
1084 if largeModel == nil {
1085 return fmt.Errorf("model %s not found in provider %s", largeModelCfg.Model, largeModelProviderCfg.ID)
1086 }
1087 summarizeOpts := []provider.ProviderClientOption{
1088 provider.WithModel(config.SelectedModelTypeLarge),
1089 provider.WithSystemMessage(prompt.GetPrompt(prompt.PromptSummarizer, largeModelProviderCfg.ID)),
1090 }
1091 newSummarizeProvider, err := provider.NewProvider(largeModelProviderCfg, summarizeOpts...)
1092 if err != nil {
1093 return fmt.Errorf("failed to create new summarize provider: %w", err)
1094 }
1095 a.summarizeProvider = newSummarizeProvider
1096 a.summarizeProviderID = string(largeModelProviderCfg.ID)
1097 }
1098
1099 return nil
1100}
1101
1102func (a *agent) setupEvents(ctx context.Context) {
1103 ctx, cancel := context.WithCancel(ctx)
1104
1105 go func() {
1106 subCh := SubscribeMCPEvents(ctx)
1107
1108 for {
1109 select {
1110 case event, ok := <-subCh:
1111 if !ok {
1112 slog.Debug("MCPEvents subscription channel closed")
1113 return
1114 }
1115 switch event.Payload.Type {
1116 case MCPEventToolsListChanged:
1117 name := event.Payload.Name
1118 c, ok := mcpClients.Get(name)
1119 if !ok {
1120 slog.Warn("MCP client not found for tools update", "name", name)
1121 continue
1122 }
1123 cfg := config.Get()
1124 tools, err := getTools(ctx, name, a.permissions, c, cfg.WorkingDir())
1125 if err != nil {
1126 slog.Error("error listing tools", "error", err)
1127 updateMCPState(name, MCPStateError, err, nil, 0)
1128 _ = c.Close()
1129 continue
1130 }
1131 updateMcpTools(name, tools)
1132 a.mcpTools.Reset(maps.Collect(mcpTools.Seq2()))
1133 updateMCPState(name, MCPStateConnected, nil, c, a.mcpTools.Len())
1134 default:
1135 continue
1136 }
1137 case <-ctx.Done():
1138 slog.Debug("MCPEvents subscription cancelled")
1139 return
1140 }
1141 }
1142 }()
1143
1144 a.cleanupFuncs = append(a.cleanupFuncs, cancel)
1145}