1package agent
2
3import (
4 "context"
5 "errors"
6 "fmt"
7 "log/slog"
8 "maps"
9 "slices"
10 "strings"
11 "time"
12
13 "github.com/charmbracelet/catwalk/pkg/catwalk"
14 "github.com/charmbracelet/crush/internal/config"
15 "github.com/charmbracelet/crush/internal/csync"
16 "github.com/charmbracelet/crush/internal/event"
17 "github.com/charmbracelet/crush/internal/history"
18 "github.com/charmbracelet/crush/internal/llm/prompt"
19 "github.com/charmbracelet/crush/internal/llm/provider"
20 "github.com/charmbracelet/crush/internal/llm/tools"
21 "github.com/charmbracelet/crush/internal/log"
22 "github.com/charmbracelet/crush/internal/lsp"
23 "github.com/charmbracelet/crush/internal/message"
24 "github.com/charmbracelet/crush/internal/permission"
25 "github.com/charmbracelet/crush/internal/pubsub"
26 "github.com/charmbracelet/crush/internal/session"
27 "github.com/charmbracelet/crush/internal/shell"
28)
29
// streamChunkTimeout bounds how long streamAndHandleEvents waits for the next
// provider event. When it elapses, the assistant message is finished with a
// "Stream timeout" error.
const streamChunkTimeout time.Duration = 80 * time.Second
31
// AgentEventType discriminates the kind of payload carried by an AgentEvent.
type AgentEventType string

// Recognized AgentEventType values.
const (
	// AgentEventTypeError marks an event whose Error field describes a failure.
	AgentEventTypeError = AgentEventType("error")
	// AgentEventTypeResponse marks an event carrying the agent's final Message.
	AgentEventTypeResponse = AgentEventType("response")
	// AgentEventTypeSummarize marks a progress or completion event emitted by Summarize.
	AgentEventTypeSummarize = AgentEventType("summarize")
)
39
// AgentEvent is delivered on the channel returned by Run and is also
// published through the agent's broker. Type indicates which of the other
// fields are populated.
type AgentEvent struct {
	Type    AgentEventType
	// Message is the assistant message of an AgentEventTypeResponse event.
	Message message.Message
	// Error describes the failure of an AgentEventTypeError event.
	Error   error

	// When summarizing
	// SessionID is set on the final, successful summarize event.
	SessionID string
	// Progress is a human-readable status line for summarize events.
	Progress  string
	// Done is set on the final response of a Run and on terminal summarize
	// events (success or error). Error events built by agent.err leave it false.
	Done      bool
}
50
// Service is the public interface of an agent. It embeds a subscriber for the
// AgentEvents the agent publishes.
type Service interface {
	pubsub.Suscriber[AgentEvent]
	// Model returns the model currently configured for the agent.
	Model() catwalk.Model
	// Run starts processing content (plus optional attachments) for sessionID.
	// It returns a channel that receives the final AgentEvent and is then
	// closed. Attachments are dropped when the model does not support images.
	// If the session is already busy, content is queued instead (attachments
	// are not queued), and Run returns a nil channel and a nil error.
	Run(ctx context.Context, sessionID string, content string, attachments ...message.Attachment) (<-chan AgentEvent, error)
	// Cancel stops the running request and any summarization for sessionID
	// and discards its queued prompts.
	Cancel(sessionID string)
	// CancelAll cancels every in-flight request and waits briefly for them to finish.
	CancelAll()
	// IsSessionBusy reports whether a Run request is in flight for sessionID.
	IsSessionBusy(sessionID string) bool
	// IsBusy reports whether any request, including summarizations, is in flight.
	IsBusy() bool
	// Summarize starts summarizing sessionID in the background. Progress and
	// outcome are published as AgentEvents.
	Summarize(ctx context.Context, sessionID string) error
	// UpdateModel refreshes the agent's providers from the current configuration.
	UpdateModel() error
	// QueuedPrompts returns the number of prompts waiting for sessionID.
	QueuedPrompts(sessionID string) int
	// ClearQueue discards the prompts waiting for sessionID.
	ClearQueue(sessionID string)
}
64
// agent is the Service implementation returned by NewAgent.
type agent struct {
	// cleanupFuncs are invoked by CancelAll.
	cleanupFuncs []func()

	*pubsub.Broker[AgentEvent]
	agentCfg config.Agent
	sessions session.Service
	messages message.Service

	permissions permission.Service
	// baseTools and mcpTools are keyed by tool name and are built with
	// csync.NewLazyMap from the constructors in NewAgent.
	baseTools   *csync.Map[string, tools.BaseTool]
	mcpTools    *csync.Map[string, tools.BaseTool]
	lspClients  *csync.Map[string, *lsp.Client]

	// We need this to be able to update it when model changes.
	// It is nil unless this is the "coder" agent with the agent tool allowed.
	agentToolFn func() (tools.BaseTool, error)

	// provider serves the main conversation; providerID names its configuration.
	provider   provider.Provider
	providerID string

	// titleProvider generates session titles; summarizeProvider backs Summarize.
	titleProvider       provider.Provider
	summarizeProvider   provider.Provider
	summarizeProviderID string

	// activeRequests maps a session ID (Run) or "<session ID>-summarize"
	// (Summarize) to the cancel func of the in-flight request.
	activeRequests *csync.Map[string, context.CancelFunc]
	// promptQueue holds prompts submitted via Run while the session was busy.
	promptQueue    *csync.Map[string, []string]
}
91
92var agentPromptMap = map[string]prompt.PromptID{
93 "coder": prompt.PromptCoder,
94 "task": prompt.PromptTask,
95}
96
97func NewAgent(
98 ctx context.Context,
99 agentCfg config.Agent,
100 // These services are needed in the tools
101 permissions permission.Service,
102 sessions session.Service,
103 messages message.Service,
104 history history.Service,
105 lspClients *csync.Map[string, *lsp.Client],
106) (Service, error) {
107 cfg := config.Get()
108
109 var agentToolFn func() (tools.BaseTool, error)
110 if agentCfg.ID == "coder" && slices.Contains(agentCfg.AllowedTools, AgentToolName) {
111 agentToolFn = func() (tools.BaseTool, error) {
112 taskAgentCfg := config.Get().Agents["task"]
113 if taskAgentCfg.ID == "" {
114 return nil, fmt.Errorf("task agent not found in config")
115 }
116 taskAgent, err := NewAgent(ctx, taskAgentCfg, permissions, sessions, messages, history, lspClients)
117 if err != nil {
118 return nil, fmt.Errorf("failed to create task agent: %w", err)
119 }
120 return NewAgentTool(taskAgent, sessions, messages), nil
121 }
122 }
123
124 providerCfg := config.Get().GetProviderForModel(agentCfg.Model)
125 if providerCfg == nil {
126 return nil, fmt.Errorf("provider for agent %s not found in config", agentCfg.Name)
127 }
128 model := config.Get().GetModelByType(agentCfg.Model)
129
130 if model == nil {
131 return nil, fmt.Errorf("model not found for agent %s", agentCfg.Name)
132 }
133
134 promptID := agentPromptMap[agentCfg.ID]
135 if promptID == "" {
136 promptID = prompt.PromptDefault
137 }
138 opts := []provider.ProviderClientOption{
139 provider.WithModel(agentCfg.Model),
140 provider.WithSystemMessage(prompt.GetPrompt(promptID, providerCfg.ID, config.Get().Options.ContextPaths...)),
141 }
142 agentProvider, err := provider.NewProvider(*providerCfg, opts...)
143 if err != nil {
144 return nil, err
145 }
146
147 smallModelCfg := cfg.Models[config.SelectedModelTypeSmall]
148 var smallModelProviderCfg *config.ProviderConfig
149 if smallModelCfg.Provider == providerCfg.ID {
150 smallModelProviderCfg = providerCfg
151 } else {
152 smallModelProviderCfg = cfg.GetProviderForModel(config.SelectedModelTypeSmall)
153
154 if smallModelProviderCfg.ID == "" {
155 return nil, fmt.Errorf("provider %s not found in config", smallModelCfg.Provider)
156 }
157 }
158 smallModel := cfg.GetModelByType(config.SelectedModelTypeSmall)
159 if smallModel.ID == "" {
160 return nil, fmt.Errorf("model %s not found in provider %s", smallModelCfg.Model, smallModelProviderCfg.ID)
161 }
162
163 titleOpts := []provider.ProviderClientOption{
164 provider.WithModel(config.SelectedModelTypeSmall),
165 provider.WithSystemMessage(prompt.GetPrompt(prompt.PromptTitle, smallModelProviderCfg.ID)),
166 }
167 titleProvider, err := provider.NewProvider(*smallModelProviderCfg, titleOpts...)
168 if err != nil {
169 return nil, err
170 }
171
172 summarizeOpts := []provider.ProviderClientOption{
173 provider.WithModel(config.SelectedModelTypeLarge),
174 provider.WithSystemMessage(prompt.GetPrompt(prompt.PromptSummarizer, providerCfg.ID)),
175 }
176 summarizeProvider, err := provider.NewProvider(*providerCfg, summarizeOpts...)
177 if err != nil {
178 return nil, err
179 }
180
181 baseToolsFn := func() map[string]tools.BaseTool {
182 slog.Info("Initializing agent base tools", "agent", agentCfg.ID)
183 defer func() {
184 slog.Info("Initialized agent base tools", "agent", agentCfg.ID)
185 }()
186
187 // Base tools available to all agents
188 cwd := cfg.WorkingDir()
189 result := make(map[string]tools.BaseTool)
190 for _, tool := range []tools.BaseTool{
191 tools.NewBashTool(permissions, cwd, cfg.Options.Attribution),
192 tools.NewDownloadTool(permissions, cwd),
193 tools.NewEditTool(lspClients, permissions, history, cwd),
194 tools.NewMultiEditTool(lspClients, permissions, history, cwd),
195 tools.NewFetchTool(permissions, cwd),
196 tools.NewGlobTool(cwd),
197 tools.NewGrepTool(cwd),
198 tools.NewLsTool(permissions, cwd),
199 tools.NewSourcegraphTool(),
200 tools.NewViewTool(lspClients, permissions, cwd),
201 tools.NewWriteTool(lspClients, permissions, history, cwd),
202 } {
203 result[tool.Name()] = tool
204 }
205 return result
206 }
207 mcpToolsFn := func() map[string]tools.BaseTool {
208 slog.Info("Initializing agent mcp tools", "agent", agentCfg.ID)
209 defer func() {
210 slog.Info("Initialized agent mcp tools", "agent", agentCfg.ID)
211 }()
212
213 mcpToolsOnce.Do(func() {
214 doGetMCPTools(ctx, permissions, cfg)
215 })
216 result := make(map[string]tools.BaseTool)
217 for _, mcpTool := range mcpTools.Seq2() {
218 result[mcpTool.Name()] = mcpTool
219 }
220 return result
221 }
222
223 a := &agent{
224 Broker: pubsub.NewBroker[AgentEvent](),
225 agentCfg: agentCfg,
226 provider: agentProvider,
227 providerID: string(providerCfg.ID),
228 messages: messages,
229 sessions: sessions,
230 titleProvider: titleProvider,
231 summarizeProvider: summarizeProvider,
232 summarizeProviderID: string(providerCfg.ID),
233 agentToolFn: agentToolFn,
234 activeRequests: csync.NewMap[string, context.CancelFunc](),
235 mcpTools: csync.NewLazyMap(mcpToolsFn),
236 baseTools: csync.NewLazyMap(baseToolsFn),
237 promptQueue: csync.NewMap[string, []string](),
238 permissions: permissions,
239 lspClients: lspClients,
240 }
241 a.setupEvents(ctx)
242 return a, nil
243}
244
245func (a *agent) Model() catwalk.Model {
246 return *config.Get().GetModelByType(a.agentCfg.Model)
247}
248
249func (a *agent) Cancel(sessionID string) {
250 // Cancel regular requests
251 if cancel, ok := a.activeRequests.Take(sessionID); ok && cancel != nil {
252 slog.Info("Request cancellation initiated", "session_id", sessionID)
253 cancel()
254 }
255
256 // Also check for summarize requests
257 if cancel, ok := a.activeRequests.Take(sessionID + "-summarize"); ok && cancel != nil {
258 slog.Info("Summarize cancellation initiated", "session_id", sessionID)
259 cancel()
260 }
261
262 if a.QueuedPrompts(sessionID) > 0 {
263 slog.Info("Clearing queued prompts", "session_id", sessionID)
264 a.promptQueue.Del(sessionID)
265 }
266}
267
268func (a *agent) IsBusy() bool {
269 var busy bool
270 for cancelFunc := range a.activeRequests.Seq() {
271 if cancelFunc != nil {
272 busy = true
273 break
274 }
275 }
276 return busy
277}
278
279func (a *agent) IsSessionBusy(sessionID string) bool {
280 _, busy := a.activeRequests.Get(sessionID)
281 return busy
282}
283
284func (a *agent) QueuedPrompts(sessionID string) int {
285 l, ok := a.promptQueue.Get(sessionID)
286 if !ok {
287 return 0
288 }
289 return len(l)
290}
291
292func (a *agent) generateTitle(ctx context.Context, sessionID string, content string) error {
293 if content == "" {
294 return nil
295 }
296 if a.titleProvider == nil {
297 return nil
298 }
299 session, err := a.sessions.Get(ctx, sessionID)
300 if err != nil {
301 return err
302 }
303 parts := []message.ContentPart{message.TextContent{
304 Text: fmt.Sprintf("Generate a concise title for the following content:\n\n%s", content),
305 }}
306
307 // Use streaming approach like summarization
308 response := a.titleProvider.StreamResponse(
309 ctx,
310 []message.Message{
311 {
312 Role: message.User,
313 Parts: parts,
314 },
315 },
316 nil,
317 )
318
319 var finalResponse *provider.ProviderResponse
320 for r := range response {
321 if r.Error != nil {
322 return r.Error
323 }
324 finalResponse = r.Response
325 }
326
327 if finalResponse == nil {
328 return fmt.Errorf("no response received from title provider")
329 }
330
331 title := strings.ReplaceAll(finalResponse.Content, "\n", " ")
332
333 if idx := strings.Index(title, "</think>"); idx > 0 {
334 title = title[idx+len("</think>"):]
335 }
336
337 title = strings.TrimSpace(title)
338 if title == "" {
339 return nil
340 }
341
342 session.Title = title
343 _, err = a.sessions.Save(ctx, session)
344 return err
345}
346
347func (a *agent) err(err error) AgentEvent {
348 return AgentEvent{
349 Type: AgentEventTypeError,
350 Error: err,
351 }
352}
353
354func (a *agent) Run(ctx context.Context, sessionID string, content string, attachments ...message.Attachment) (<-chan AgentEvent, error) {
355 if !a.Model().SupportsImages && attachments != nil {
356 attachments = nil
357 }
358 events := make(chan AgentEvent, 1)
359 if a.IsSessionBusy(sessionID) {
360 existing, ok := a.promptQueue.Get(sessionID)
361 if !ok {
362 existing = []string{}
363 }
364 existing = append(existing, content)
365 a.promptQueue.Set(sessionID, existing)
366 return nil, nil
367 }
368
369 genCtx, cancel := context.WithCancel(ctx)
370 a.activeRequests.Set(sessionID, cancel)
371 startTime := time.Now()
372
373 go func() {
374 slog.Debug("Request started", "sessionID", sessionID)
375 defer log.RecoverPanic("agent.Run", func() {
376 events <- a.err(fmt.Errorf("panic while running the agent"))
377 })
378 var attachmentParts []message.ContentPart
379 for _, attachment := range attachments {
380 attachmentParts = append(attachmentParts, message.BinaryContent{Path: attachment.FilePath, MIMEType: attachment.MimeType, Data: attachment.Content})
381 }
382 result := a.processGeneration(genCtx, sessionID, content, attachmentParts)
383 if result.Error != nil {
384 if isCancelledErr(result.Error) {
385 slog.Error("Request canceled", "sessionID", sessionID)
386 } else {
387 slog.Error("Request errored", "sessionID", sessionID, "error", result.Error.Error())
388 event.Error(result.Error)
389 }
390 } else {
391 slog.Debug("Request completed", "sessionID", sessionID)
392 }
393 a.eventPromptResponded(sessionID, time.Since(startTime).Truncate(time.Second))
394 a.activeRequests.Del(sessionID)
395 cancel()
396 a.Publish(pubsub.CreatedEvent, result)
397 events <- result
398 close(events)
399 }()
400 a.eventPromptSent(sessionID)
401 return events, nil
402}
403
404func (a *agent) processGeneration(ctx context.Context, sessionID, content string, attachmentParts []message.ContentPart) AgentEvent {
405 cfg := config.Get()
406 // List existing messages; if none, start title generation asynchronously.
407 msgs, err := a.messages.List(ctx, sessionID)
408 if err != nil {
409 return a.err(fmt.Errorf("failed to list messages: %w", err))
410 }
411 if len(msgs) == 0 {
412 go func() {
413 defer log.RecoverPanic("agent.Run", func() {
414 slog.Error("panic while generating title")
415 })
416 titleErr := a.generateTitle(ctx, sessionID, content)
417 if titleErr != nil && !errors.Is(titleErr, context.Canceled) && !errors.Is(titleErr, context.DeadlineExceeded) {
418 slog.Error("failed to generate title", "error", titleErr)
419 }
420 }()
421 }
422 session, err := a.sessions.Get(ctx, sessionID)
423 if err != nil {
424 return a.err(fmt.Errorf("failed to get session: %w", err))
425 }
426 if session.SummaryMessageID != "" {
427 summaryMsgInex := -1
428 for i, msg := range msgs {
429 if msg.ID == session.SummaryMessageID {
430 summaryMsgInex = i
431 break
432 }
433 }
434 if summaryMsgInex != -1 {
435 msgs = msgs[summaryMsgInex:]
436 msgs[0].Role = message.User
437 }
438 }
439
440 userMsg, err := a.createUserMessage(ctx, sessionID, content, attachmentParts)
441 if err != nil {
442 return a.err(fmt.Errorf("failed to create user message: %w", err))
443 }
444 // Append the new user message to the conversation history.
445 msgHistory := append(msgs, userMsg)
446
447 for {
448 // Check for cancellation before each iteration
449 select {
450 case <-ctx.Done():
451 return a.err(ctx.Err())
452 default:
453 // Continue processing
454 }
455 agentMessage, toolResults, err := a.streamAndHandleEvents(ctx, sessionID, msgHistory)
456 if err != nil {
457 if errors.Is(err, context.Canceled) {
458 agentMessage.AddFinish(message.FinishReasonCanceled, "Request cancelled", "")
459 a.messages.Update(context.Background(), agentMessage)
460 return a.err(ErrRequestCancelled)
461 }
462 return a.err(fmt.Errorf("failed to process events: %w", err))
463 }
464 if cfg.Options.Debug {
465 slog.Info("Result", "message", agentMessage.FinishReason(), "toolResults", toolResults)
466 }
467 if (agentMessage.FinishReason() == message.FinishReasonToolUse) && toolResults != nil {
468 // We are not done, we need to respond with the tool response
469 msgHistory = append(msgHistory, agentMessage, *toolResults)
470 // If there are queued prompts, process the next one
471 nextPrompt, ok := a.promptQueue.Take(sessionID)
472 if ok {
473 for _, prompt := range nextPrompt {
474 // Create a new user message for the queued prompt
475 userMsg, err := a.createUserMessage(ctx, sessionID, prompt, nil)
476 if err != nil {
477 return a.err(fmt.Errorf("failed to create user message for queued prompt: %w", err))
478 }
479 // Append the new user message to the conversation history
480 msgHistory = append(msgHistory, userMsg)
481 }
482 }
483
484 continue
485 } else if agentMessage.FinishReason() == message.FinishReasonEndTurn {
486 queuePrompts, ok := a.promptQueue.Take(sessionID)
487 if ok {
488 for _, prompt := range queuePrompts {
489 if prompt == "" {
490 continue
491 }
492 userMsg, err := a.createUserMessage(ctx, sessionID, prompt, nil)
493 if err != nil {
494 return a.err(fmt.Errorf("failed to create user message for queued prompt: %w", err))
495 }
496 msgHistory = append(msgHistory, userMsg)
497 }
498 continue
499 }
500 }
501 if agentMessage.FinishReason() == "" {
502 // Kujtim: could not track down where this is happening but this means its cancelled
503 agentMessage.AddFinish(message.FinishReasonCanceled, "Request cancelled", "")
504 _ = a.messages.Update(context.Background(), agentMessage)
505 return a.err(ErrRequestCancelled)
506 }
507 return AgentEvent{
508 Type: AgentEventTypeResponse,
509 Message: agentMessage,
510 Done: true,
511 }
512 }
513}
514
515func (a *agent) createUserMessage(ctx context.Context, sessionID, content string, attachmentParts []message.ContentPart) (message.Message, error) {
516 parts := []message.ContentPart{message.TextContent{Text: content}}
517 parts = append(parts, attachmentParts...)
518 return a.messages.Create(ctx, sessionID, message.CreateMessageParams{
519 Role: message.User,
520 Parts: parts,
521 })
522}
523
524func (a *agent) getAllTools() ([]tools.BaseTool, error) {
525 allTools := slices.Collect(a.baseTools.Seq())
526
527 withCoderTools := func(t []tools.BaseTool) []tools.BaseTool {
528 if a.agentCfg.ID == "coder" {
529 t = append(t, slices.Collect(a.mcpTools.Seq())...)
530 if a.lspClients.Len() > 0 {
531 t = append(t, tools.NewDiagnosticsTool(a.lspClients))
532 }
533 }
534 return t
535 }
536
537 if a.agentCfg.AllowedTools == nil {
538 allTools = withCoderTools(allTools)
539 } else {
540 var filteredTools []tools.BaseTool
541 for _, tool := range allTools {
542 if slices.Contains(a.agentCfg.AllowedTools, tool.Name()) {
543 filteredTools = append(filteredTools, tool)
544 }
545 }
546 allTools = withCoderTools(filteredTools)
547 }
548
549 if a.agentToolFn != nil {
550 agentTool, agentToolErr := a.agentToolFn()
551 if agentToolErr != nil {
552 return nil, agentToolErr
553 }
554 allTools = append(allTools, agentTool)
555 }
556 return allTools, nil
557}
558
559func (a *agent) streamAndHandleEvents(ctx context.Context, sessionID string, msgHistory []message.Message) (message.Message, *message.Message, error) {
560 ctx = context.WithValue(ctx, tools.SessionIDContextKey, sessionID)
561
562 // Create the assistant message first so the spinner shows immediately
563 assistantMsg, err := a.messages.Create(ctx, sessionID, message.CreateMessageParams{
564 Role: message.Assistant,
565 Parts: []message.ContentPart{},
566 Model: a.Model().ID,
567 Provider: a.providerID,
568 })
569 if err != nil {
570 return assistantMsg, nil, fmt.Errorf("failed to create assistant message: %w", err)
571 }
572
573 allTools, toolsErr := a.getAllTools()
574 if toolsErr != nil {
575 return assistantMsg, nil, toolsErr
576 }
577 // Now collect tools (which may block on MCP initialization)
578 eventChan := a.provider.StreamResponse(ctx, msgHistory, allTools)
579
580 // Add the session and message ID into the context if needed by tools.
581 ctx = context.WithValue(ctx, tools.MessageIDContextKey, assistantMsg.ID)
582
583 // Process each event in the stream.
584loop:
585 for {
586 select {
587 case event, ok := <-eventChan:
588 if !ok {
589 break loop
590 }
591 if processErr := a.processEvent(ctx, sessionID, &assistantMsg, event); processErr != nil {
592 if errors.Is(processErr, context.Canceled) {
593 a.finishMessage(context.Background(), &assistantMsg, message.FinishReasonCanceled, "Request cancelled", "")
594 } else {
595 a.finishMessage(ctx, &assistantMsg, message.FinishReasonError, "API Error", processErr.Error())
596 }
597 return assistantMsg, nil, processErr
598 }
599 case <-time.After(streamChunkTimeout):
600 a.finishMessage(ctx, &assistantMsg, message.FinishReasonError, "Stream timeout", "No chunk received within timeout")
601 return assistantMsg, nil, fmt.Errorf("stream chunk timeout")
602 case <-ctx.Done():
603 a.finishMessage(context.Background(), &assistantMsg, message.FinishReasonCanceled, "Request cancelled", "")
604 return assistantMsg, nil, ctx.Err()
605 }
606 }
607
608 toolResults := make([]message.ToolResult, len(assistantMsg.ToolCalls()))
609 toolCalls := assistantMsg.ToolCalls()
610 for i, toolCall := range toolCalls {
611 select {
612 case <-ctx.Done():
613 a.finishMessage(context.Background(), &assistantMsg, message.FinishReasonCanceled, "Request cancelled", "")
614 // Make all future tool calls cancelled
615 for j := i; j < len(toolCalls); j++ {
616 toolResults[j] = message.ToolResult{
617 ToolCallID: toolCalls[j].ID,
618 Content: "Tool execution canceled by user",
619 IsError: true,
620 }
621 }
622 goto out
623 default:
624 // Continue processing
625 var tool tools.BaseTool
626 allTools, _ = a.getAllTools()
627 for _, availableTool := range allTools {
628 if availableTool.Info().Name == toolCall.Name {
629 tool = availableTool
630 break
631 }
632 }
633
634 // Tool not found
635 if tool == nil {
636 toolResults[i] = message.ToolResult{
637 ToolCallID: toolCall.ID,
638 Content: fmt.Sprintf("Tool not found: %s", toolCall.Name),
639 IsError: true,
640 }
641 continue
642 }
643
644 // Run tool in goroutine to allow cancellation
645 type toolExecResult struct {
646 response tools.ToolResponse
647 err error
648 }
649 resultChan := make(chan toolExecResult, 1)
650
651 go func() {
652 response, err := tool.Run(ctx, tools.ToolCall{
653 ID: toolCall.ID,
654 Name: toolCall.Name,
655 Input: toolCall.Input,
656 })
657 resultChan <- toolExecResult{response: response, err: err}
658 }()
659
660 var toolResponse tools.ToolResponse
661 var toolErr error
662
663 select {
664 case <-ctx.Done():
665 a.finishMessage(context.Background(), &assistantMsg, message.FinishReasonCanceled, "Request cancelled", "")
666 // Mark remaining tool calls as cancelled
667 for j := i; j < len(toolCalls); j++ {
668 toolResults[j] = message.ToolResult{
669 ToolCallID: toolCalls[j].ID,
670 Content: "Tool execution canceled by user",
671 IsError: true,
672 }
673 }
674 goto out
675 case result := <-resultChan:
676 toolResponse = result.response
677 toolErr = result.err
678 }
679
680 if toolErr != nil {
681 slog.Error("Tool execution error", "toolCall", toolCall.ID, "error", toolErr)
682 if errors.Is(toolErr, permission.ErrorPermissionDenied) {
683 toolResults[i] = message.ToolResult{
684 ToolCallID: toolCall.ID,
685 Content: "Permission denied",
686 IsError: true,
687 }
688 for j := i + 1; j < len(toolCalls); j++ {
689 toolResults[j] = message.ToolResult{
690 ToolCallID: toolCalls[j].ID,
691 Content: "Tool execution canceled by user",
692 IsError: true,
693 }
694 }
695 a.finishMessage(ctx, &assistantMsg, message.FinishReasonPermissionDenied, "Permission denied", "")
696 break
697 }
698 }
699 toolResults[i] = message.ToolResult{
700 ToolCallID: toolCall.ID,
701 Content: toolResponse.Content,
702 Metadata: toolResponse.Metadata,
703 IsError: toolResponse.IsError,
704 }
705 }
706 }
707out:
708 if len(toolResults) == 0 {
709 return assistantMsg, nil, nil
710 }
711 parts := make([]message.ContentPart, 0)
712 for _, tr := range toolResults {
713 parts = append(parts, tr)
714 }
715 msg, err := a.messages.Create(context.Background(), assistantMsg.SessionID, message.CreateMessageParams{
716 Role: message.Tool,
717 Parts: parts,
718 Provider: a.providerID,
719 })
720 if err != nil {
721 return assistantMsg, nil, fmt.Errorf("failed to create cancelled tool message: %w", err)
722 }
723
724 return assistantMsg, &msg, err
725}
726
727func (a *agent) finishMessage(ctx context.Context, msg *message.Message, finishReason message.FinishReason, message, details string) {
728 msg.AddFinish(finishReason, message, details)
729 _ = a.messages.Update(ctx, *msg)
730}
731
732func (a *agent) processEvent(ctx context.Context, sessionID string, assistantMsg *message.Message, event provider.ProviderEvent) error {
733 select {
734 case <-ctx.Done():
735 return ctx.Err()
736 default:
737 // Continue processing.
738 }
739
740 switch event.Type {
741 case provider.EventThinkingDelta:
742 assistantMsg.AppendReasoningContent(event.Thinking)
743 return a.messages.Update(ctx, *assistantMsg)
744 case provider.EventSignatureDelta:
745 assistantMsg.AppendReasoningSignature(event.Signature)
746 return a.messages.Update(ctx, *assistantMsg)
747 case provider.EventContentDelta:
748 assistantMsg.FinishThinking()
749 assistantMsg.AppendContent(event.Content)
750 return a.messages.Update(ctx, *assistantMsg)
751 case provider.EventToolUseStart:
752 assistantMsg.FinishThinking()
753 slog.Info("Tool call started", "toolCall", event.ToolCall)
754 assistantMsg.AddToolCall(*event.ToolCall)
755 return a.messages.Update(ctx, *assistantMsg)
756 case provider.EventToolUseDelta:
757 assistantMsg.AppendToolCallInput(event.ToolCall.ID, event.ToolCall.Input)
758 return a.messages.Update(ctx, *assistantMsg)
759 case provider.EventToolUseStop:
760 slog.Info("Finished tool call", "toolCall", event.ToolCall)
761 assistantMsg.FinishToolCall(event.ToolCall.ID)
762 return a.messages.Update(ctx, *assistantMsg)
763 case provider.EventError:
764 return event.Error
765 case provider.EventComplete:
766 assistantMsg.FinishThinking()
767 assistantMsg.SetToolCalls(event.Response.ToolCalls)
768 assistantMsg.AddFinish(event.Response.FinishReason, "", "")
769 if err := a.messages.Update(ctx, *assistantMsg); err != nil {
770 return fmt.Errorf("failed to update message: %w", err)
771 }
772 return a.trackUsage(ctx, sessionID, a.Model(), event.Response.Usage)
773 }
774
775 return nil
776}
777
778func (a *agent) trackUsage(ctx context.Context, sessionID string, model catwalk.Model, usage provider.TokenUsage) error {
779 sess, err := a.sessions.Get(ctx, sessionID)
780 if err != nil {
781 return fmt.Errorf("failed to get session: %w", err)
782 }
783
784 cost := model.CostPer1MInCached/1e6*float64(usage.CacheCreationTokens) +
785 model.CostPer1MOutCached/1e6*float64(usage.CacheReadTokens) +
786 model.CostPer1MIn/1e6*float64(usage.InputTokens) +
787 model.CostPer1MOut/1e6*float64(usage.OutputTokens)
788
789 a.eventTokensUsed(sessionID, usage, cost)
790
791 sess.Cost += cost
792 sess.CompletionTokens = usage.OutputTokens + usage.CacheReadTokens
793 sess.PromptTokens = usage.InputTokens + usage.CacheCreationTokens
794
795 _, err = a.sessions.Save(ctx, sess)
796 if err != nil {
797 return fmt.Errorf("failed to save session: %w", err)
798 }
799 return nil
800}
801
802func (a *agent) Summarize(ctx context.Context, sessionID string) error {
803 if a.summarizeProvider == nil {
804 return fmt.Errorf("summarize provider not available")
805 }
806
807 // Check if session is busy
808 if a.IsSessionBusy(sessionID) {
809 return ErrSessionBusy
810 }
811
812 // Create a new context with cancellation
813 summarizeCtx, cancel := context.WithCancel(ctx)
814
815 // Store the cancel function in activeRequests to allow cancellation
816 a.activeRequests.Set(sessionID+"-summarize", cancel)
817
818 go func() {
819 defer a.activeRequests.Del(sessionID + "-summarize")
820 defer cancel()
821 event := AgentEvent{
822 Type: AgentEventTypeSummarize,
823 Progress: "Starting summarization...",
824 }
825
826 a.Publish(pubsub.CreatedEvent, event)
827 // Get all messages from the session
828 msgs, err := a.messages.List(summarizeCtx, sessionID)
829 if err != nil {
830 event = AgentEvent{
831 Type: AgentEventTypeError,
832 Error: fmt.Errorf("failed to list messages: %w", err),
833 Done: true,
834 }
835 a.Publish(pubsub.CreatedEvent, event)
836 return
837 }
838 summarizeCtx = context.WithValue(summarizeCtx, tools.SessionIDContextKey, sessionID)
839
840 if len(msgs) == 0 {
841 event = AgentEvent{
842 Type: AgentEventTypeError,
843 Error: fmt.Errorf("no messages to summarize"),
844 Done: true,
845 }
846 a.Publish(pubsub.CreatedEvent, event)
847 return
848 }
849
850 event = AgentEvent{
851 Type: AgentEventTypeSummarize,
852 Progress: "Analyzing conversation...",
853 }
854 a.Publish(pubsub.CreatedEvent, event)
855
856 // Add a system message to guide the summarization
857 summarizePrompt := "Provide a detailed but concise summary of our conversation above. Focus on information that would be helpful for continuing the conversation, including what we did, what we're doing, which files we're working on, and what we're going to do next."
858
859 // Create a new message with the summarize prompt
860 promptMsg := message.Message{
861 Role: message.User,
862 Parts: []message.ContentPart{message.TextContent{Text: summarizePrompt}},
863 }
864
865 // Append the prompt to the messages
866 msgsWithPrompt := append(msgs, promptMsg)
867
868 event = AgentEvent{
869 Type: AgentEventTypeSummarize,
870 Progress: "Generating summary...",
871 }
872
873 a.Publish(pubsub.CreatedEvent, event)
874
875 // Send the messages to the summarize provider
876 response := a.summarizeProvider.StreamResponse(
877 summarizeCtx,
878 msgsWithPrompt,
879 nil,
880 )
881 var finalResponse *provider.ProviderResponse
882 for r := range response {
883 if r.Error != nil {
884 event = AgentEvent{
885 Type: AgentEventTypeError,
886 Error: fmt.Errorf("failed to summarize: %w", r.Error),
887 Done: true,
888 }
889 a.Publish(pubsub.CreatedEvent, event)
890 return
891 }
892 finalResponse = r.Response
893 }
894
895 summary := strings.TrimSpace(finalResponse.Content)
896 if summary == "" {
897 event = AgentEvent{
898 Type: AgentEventTypeError,
899 Error: fmt.Errorf("empty summary returned"),
900 Done: true,
901 }
902 a.Publish(pubsub.CreatedEvent, event)
903 return
904 }
905 shell := shell.GetPersistentShell(config.Get().WorkingDir())
906 summary += "\n\n**Current working directory of the persistent shell**\n\n" + shell.GetWorkingDir()
907 event = AgentEvent{
908 Type: AgentEventTypeSummarize,
909 Progress: "Creating new session...",
910 }
911
912 a.Publish(pubsub.CreatedEvent, event)
913 oldSession, err := a.sessions.Get(summarizeCtx, sessionID)
914 if err != nil {
915 event = AgentEvent{
916 Type: AgentEventTypeError,
917 Error: fmt.Errorf("failed to get session: %w", err),
918 Done: true,
919 }
920
921 a.Publish(pubsub.CreatedEvent, event)
922 return
923 }
924 // Create a message in the new session with the summary
925 msg, err := a.messages.Create(summarizeCtx, oldSession.ID, message.CreateMessageParams{
926 Role: message.Assistant,
927 Parts: []message.ContentPart{
928 message.TextContent{Text: summary},
929 message.Finish{
930 Reason: message.FinishReasonEndTurn,
931 Time: time.Now().Unix(),
932 },
933 },
934 Model: a.summarizeProvider.Model().ID,
935 Provider: a.summarizeProviderID,
936 })
937 if err != nil {
938 event = AgentEvent{
939 Type: AgentEventTypeError,
940 Error: fmt.Errorf("failed to create summary message: %w", err),
941 Done: true,
942 }
943
944 a.Publish(pubsub.CreatedEvent, event)
945 return
946 }
947 oldSession.SummaryMessageID = msg.ID
948 oldSession.CompletionTokens = finalResponse.Usage.OutputTokens
949 oldSession.PromptTokens = 0
950 model := a.summarizeProvider.Model()
951 usage := finalResponse.Usage
952 cost := model.CostPer1MInCached/1e6*float64(usage.CacheCreationTokens) +
953 model.CostPer1MOutCached/1e6*float64(usage.CacheReadTokens) +
954 model.CostPer1MIn/1e6*float64(usage.InputTokens) +
955 model.CostPer1MOut/1e6*float64(usage.OutputTokens)
956 oldSession.Cost += cost
957 _, err = a.sessions.Save(summarizeCtx, oldSession)
958 if err != nil {
959 event = AgentEvent{
960 Type: AgentEventTypeError,
961 Error: fmt.Errorf("failed to save session: %w", err),
962 Done: true,
963 }
964 a.Publish(pubsub.CreatedEvent, event)
965 }
966
967 event = AgentEvent{
968 Type: AgentEventTypeSummarize,
969 SessionID: oldSession.ID,
970 Progress: "Summary complete",
971 Done: true,
972 }
973 a.Publish(pubsub.CreatedEvent, event)
974 // Send final success event with the new session ID
975 }()
976
977 return nil
978}
979
980func (a *agent) ClearQueue(sessionID string) {
981 if a.QueuedPrompts(sessionID) > 0 {
982 slog.Info("Clearing queued prompts", "session_id", sessionID)
983 a.promptQueue.Del(sessionID)
984 }
985}
986
987func (a *agent) CancelAll() {
988 if !a.IsBusy() {
989 return
990 }
991 for key := range a.activeRequests.Seq2() {
992 a.Cancel(key) // key is sessionID
993 }
994
995 for _, cleanup := range a.cleanupFuncs {
996 if cleanup != nil {
997 cleanup()
998 }
999 }
1000
1001 timeout := time.After(5 * time.Second)
1002 for a.IsBusy() {
1003 select {
1004 case <-timeout:
1005 return
1006 default:
1007 time.Sleep(200 * time.Millisecond)
1008 }
1009 }
1010}
1011
1012func (a *agent) UpdateModel() error {
1013 cfg := config.Get()
1014
1015 // Get current provider configuration
1016 currentProviderCfg := cfg.GetProviderForModel(a.agentCfg.Model)
1017 if currentProviderCfg == nil || currentProviderCfg.ID == "" {
1018 return fmt.Errorf("provider for agent %s not found in config", a.agentCfg.Name)
1019 }
1020
1021 // Check if provider has changed
1022 if string(currentProviderCfg.ID) != a.providerID {
1023 // Provider changed, need to recreate the main provider
1024 model := cfg.GetModelByType(a.agentCfg.Model)
1025 if model.ID == "" {
1026 return fmt.Errorf("model not found for agent %s", a.agentCfg.Name)
1027 }
1028
1029 promptID := agentPromptMap[a.agentCfg.ID]
1030 if promptID == "" {
1031 promptID = prompt.PromptDefault
1032 }
1033
1034 opts := []provider.ProviderClientOption{
1035 provider.WithModel(a.agentCfg.Model),
1036 provider.WithSystemMessage(prompt.GetPrompt(promptID, currentProviderCfg.ID, cfg.Options.ContextPaths...)),
1037 }
1038
1039 newProvider, err := provider.NewProvider(*currentProviderCfg, opts...)
1040 if err != nil {
1041 return fmt.Errorf("failed to create new provider: %w", err)
1042 }
1043
1044 // Update the provider and provider ID
1045 a.provider = newProvider
1046 a.providerID = string(currentProviderCfg.ID)
1047 }
1048
1049 // Check if providers have changed for title (small) and summarize (large)
1050 smallModelCfg := cfg.Models[config.SelectedModelTypeSmall]
1051 var smallModelProviderCfg config.ProviderConfig
1052 for p := range cfg.Providers.Seq() {
1053 if p.ID == smallModelCfg.Provider {
1054 smallModelProviderCfg = p
1055 break
1056 }
1057 }
1058 if smallModelProviderCfg.ID == "" {
1059 return fmt.Errorf("provider %s not found in config", smallModelCfg.Provider)
1060 }
1061
1062 largeModelCfg := cfg.Models[config.SelectedModelTypeLarge]
1063 var largeModelProviderCfg config.ProviderConfig
1064 for p := range cfg.Providers.Seq() {
1065 if p.ID == largeModelCfg.Provider {
1066 largeModelProviderCfg = p
1067 break
1068 }
1069 }
1070 if largeModelProviderCfg.ID == "" {
1071 return fmt.Errorf("provider %s not found in config", largeModelCfg.Provider)
1072 }
1073
1074 var maxTitleTokens int64 = 40
1075
1076 // if the max output is too low for the gemini provider it won't return anything
1077 if smallModelCfg.Provider == "gemini" {
1078 maxTitleTokens = 1000
1079 }
1080 // Recreate title provider
1081 titleOpts := []provider.ProviderClientOption{
1082 provider.WithModel(config.SelectedModelTypeSmall),
1083 provider.WithSystemMessage(prompt.GetPrompt(prompt.PromptTitle, smallModelProviderCfg.ID)),
1084 provider.WithMaxTokens(maxTitleTokens),
1085 }
1086 newTitleProvider, err := provider.NewProvider(smallModelProviderCfg, titleOpts...)
1087 if err != nil {
1088 return fmt.Errorf("failed to create new title provider: %w", err)
1089 }
1090 a.titleProvider = newTitleProvider
1091
1092 // Recreate summarize provider if provider changed (now large model)
1093 if string(largeModelProviderCfg.ID) != a.summarizeProviderID {
1094 largeModel := cfg.GetModelByType(config.SelectedModelTypeLarge)
1095 if largeModel == nil {
1096 return fmt.Errorf("model %s not found in provider %s", largeModelCfg.Model, largeModelProviderCfg.ID)
1097 }
1098 summarizeOpts := []provider.ProviderClientOption{
1099 provider.WithModel(config.SelectedModelTypeLarge),
1100 provider.WithSystemMessage(prompt.GetPrompt(prompt.PromptSummarizer, largeModelProviderCfg.ID)),
1101 }
1102 newSummarizeProvider, err := provider.NewProvider(largeModelProviderCfg, summarizeOpts...)
1103 if err != nil {
1104 return fmt.Errorf("failed to create new summarize provider: %w", err)
1105 }
1106 a.summarizeProvider = newSummarizeProvider
1107 a.summarizeProviderID = string(largeModelProviderCfg.ID)
1108 }
1109
1110 return nil
1111}
1112
1113func (a *agent) setupEvents(ctx context.Context) {
1114 ctx, cancel := context.WithCancel(ctx)
1115
1116 go func() {
1117 subCh := SubscribeMCPEvents(ctx)
1118
1119 for {
1120 select {
1121 case event, ok := <-subCh:
1122 if !ok {
1123 slog.Debug("MCPEvents subscription channel closed")
1124 return
1125 }
1126 switch event.Payload.Type {
1127 case MCPEventToolsListChanged:
1128 name := event.Payload.Name
1129 c, ok := mcpClients.Get(name)
1130 if !ok {
1131 slog.Warn("MCP client not found for tools update", "name", name)
1132 continue
1133 }
1134 cfg := config.Get()
1135 tools := getTools(ctx, name, a.permissions, c, cfg.WorkingDir())
1136 updateMcpTools(name, tools)
1137 // Update the lazy map with the new tools
1138 a.mcpTools = csync.NewMapFrom(maps.Collect(mcpTools.Seq2()))
1139 updateMCPState(name, MCPStateConnected, nil, c, a.mcpTools.Len())
1140 default:
1141 continue
1142 }
1143 case <-ctx.Done():
1144 slog.Debug("MCPEvents subscription cancelled")
1145 return
1146 }
1147 }
1148 }()
1149
1150 a.cleanupFuncs = append(a.cleanupFuncs, func() {
1151 cancel()
1152 })
1153}