1package agent
2
3import (
4 "context"
5 "errors"
6 "fmt"
7 "log/slog"
8 "slices"
9 "strings"
10 "time"
11
12 "github.com/charmbracelet/catwalk/pkg/catwalk"
13 "github.com/charmbracelet/crush/internal/config"
14 "github.com/charmbracelet/crush/internal/csync"
15 "github.com/charmbracelet/crush/internal/event"
16 "github.com/charmbracelet/crush/internal/history"
17 "github.com/charmbracelet/crush/internal/llm/prompt"
18 "github.com/charmbracelet/crush/internal/llm/provider"
19 "github.com/charmbracelet/crush/internal/llm/tools"
20 "github.com/charmbracelet/crush/internal/log"
21 "github.com/charmbracelet/crush/internal/lsp"
22 "github.com/charmbracelet/crush/internal/message"
23 "github.com/charmbracelet/crush/internal/permission"
24 "github.com/charmbracelet/crush/internal/pubsub"
25 "github.com/charmbracelet/crush/internal/session"
26 "github.com/charmbracelet/crush/internal/shell"
27)
28
// AgentEventType identifies what kind of AgentEvent is being delivered.
type AgentEventType string

// Kinds of events produced by the agent.
const (
	// AgentEventTypeError reports a failure; AgentEvent.Error carries it.
	AgentEventTypeError = AgentEventType("error")
	// AgentEventTypeResponse carries the final assistant message of a run.
	AgentEventTypeResponse = AgentEventType("response")
	// AgentEventTypeSummarize reports summarization progress.
	AgentEventTypeSummarize = AgentEventType("summarize")
)
36
37type AgentEvent struct {
38 Type AgentEventType
39 Message message.Message
40 Error error
41
42 // When summarizing
43 SessionID string
44 Progress string
45 Done bool
46}
47
48type Service interface {
49 pubsub.Suscriber[AgentEvent]
50 Model() catwalk.Model
51 Run(ctx context.Context, sessionID string, content string, attachments ...message.Attachment) (<-chan AgentEvent, error)
52 Cancel(sessionID string)
53 CancelAll()
54 IsSessionBusy(sessionID string) bool
55 IsBusy() bool
56 Summarize(ctx context.Context, sessionID string) error
57 UpdateModel() error
58 QueuedPrompts(sessionID string) int
59 ClearQueue(sessionID string)
60}
61
62type agent struct {
63 *pubsub.Broker[AgentEvent]
64 agentCfg config.Agent
65 sessions session.Service
66 messages message.Service
67 permissions permission.Service
68 mcpTools []McpTool
69
70 tools *csync.LazySlice[tools.BaseTool]
71 // We need this to be able to update it when model changes
72 agentToolFn func() (tools.BaseTool, error)
73
74 provider provider.Provider
75 providerID string
76
77 titleProvider provider.Provider
78 summarizeProvider provider.Provider
79 summarizeProviderID string
80
81 activeRequests *csync.Map[string, context.CancelFunc]
82 promptQueue *csync.Map[string, []string]
83}
84
85var agentPromptMap = map[string]prompt.PromptID{
86 "coder": prompt.PromptCoder,
87 "task": prompt.PromptTask,
88}
89
90func NewAgent(
91 ctx context.Context,
92 agentCfg config.Agent,
93 // These services are needed in the tools
94 permissions permission.Service,
95 sessions session.Service,
96 messages message.Service,
97 history history.Service,
98 lspClients *csync.Map[string, *lsp.Client],
99) (Service, error) {
100 cfg := config.Get()
101
102 var agentToolFn func() (tools.BaseTool, error)
103 if agentCfg.ID == "coder" && slices.Contains(agentCfg.AllowedTools, AgentToolName) {
104 agentToolFn = func() (tools.BaseTool, error) {
105 taskAgentCfg := config.Get().Agents["task"]
106 if taskAgentCfg.ID == "" {
107 return nil, fmt.Errorf("task agent not found in config")
108 }
109 taskAgent, err := NewAgent(ctx, taskAgentCfg, permissions, sessions, messages, history, lspClients)
110 if err != nil {
111 return nil, fmt.Errorf("failed to create task agent: %w", err)
112 }
113 return NewAgentTool(taskAgent, sessions, messages), nil
114 }
115 }
116
117 providerCfg := config.Get().GetProviderForModel(agentCfg.Model)
118 if providerCfg == nil {
119 return nil, fmt.Errorf("provider for agent %s not found in config", agentCfg.Name)
120 }
121 model := config.Get().GetModelByType(agentCfg.Model)
122
123 if model == nil {
124 return nil, fmt.Errorf("model not found for agent %s", agentCfg.Name)
125 }
126
127 promptID := agentPromptMap[agentCfg.ID]
128 if promptID == "" {
129 promptID = prompt.PromptDefault
130 }
131 opts := []provider.ProviderClientOption{
132 provider.WithModel(agentCfg.Model),
133 provider.WithSystemMessage(prompt.GetPrompt(promptID, providerCfg.ID, config.Get().Options.ContextPaths...)),
134 }
135 agentProvider, err := provider.NewProvider(*providerCfg, opts...)
136 if err != nil {
137 return nil, err
138 }
139
140 smallModelCfg := cfg.Models[config.SelectedModelTypeSmall]
141 var smallModelProviderCfg *config.ProviderConfig
142 if smallModelCfg.Provider == providerCfg.ID {
143 smallModelProviderCfg = providerCfg
144 } else {
145 smallModelProviderCfg = cfg.GetProviderForModel(config.SelectedModelTypeSmall)
146
147 if smallModelProviderCfg.ID == "" {
148 return nil, fmt.Errorf("provider %s not found in config", smallModelCfg.Provider)
149 }
150 }
151 smallModel := cfg.GetModelByType(config.SelectedModelTypeSmall)
152 if smallModel.ID == "" {
153 return nil, fmt.Errorf("model %s not found in provider %s", smallModelCfg.Model, smallModelProviderCfg.ID)
154 }
155
156 titleOpts := []provider.ProviderClientOption{
157 provider.WithModel(config.SelectedModelTypeSmall),
158 provider.WithSystemMessage(prompt.GetPrompt(prompt.PromptTitle, smallModelProviderCfg.ID)),
159 }
160 titleProvider, err := provider.NewProvider(*smallModelProviderCfg, titleOpts...)
161 if err != nil {
162 return nil, err
163 }
164
165 summarizeOpts := []provider.ProviderClientOption{
166 provider.WithModel(config.SelectedModelTypeLarge),
167 provider.WithSystemMessage(prompt.GetPrompt(prompt.PromptSummarizer, providerCfg.ID)),
168 }
169 summarizeProvider, err := provider.NewProvider(*providerCfg, summarizeOpts...)
170 if err != nil {
171 return nil, err
172 }
173
174 toolFn := func() []tools.BaseTool {
175 slog.Info("Initializing agent tools", "agent", agentCfg.ID)
176 defer func() {
177 slog.Info("Initialized agent tools", "agent", agentCfg.ID)
178 }()
179
180 cwd := cfg.WorkingDir()
181 allTools := []tools.BaseTool{
182 tools.NewBashTool(permissions, cwd, cfg.Options.Attribution),
183 tools.NewDownloadTool(permissions, cwd),
184 tools.NewEditTool(lspClients, permissions, history, cwd),
185 tools.NewMultiEditTool(lspClients, permissions, history, cwd),
186 tools.NewFetchTool(permissions, cwd),
187 tools.NewGlobTool(cwd),
188 tools.NewGrepTool(cwd),
189 tools.NewLsTool(permissions, cwd),
190 tools.NewSourcegraphTool(),
191 tools.NewViewTool(lspClients, permissions, cwd),
192 tools.NewWriteTool(lspClients, permissions, history, cwd),
193 }
194
195 mcpToolsOnce.Do(func() {
196 mcpTools = doGetMCPTools(ctx, permissions, cfg)
197 })
198
199 withCoderTools := func(t []tools.BaseTool) []tools.BaseTool {
200 if agentCfg.ID == "coder" {
201 t = append(t, mcpTools...)
202 if lspClients.Len() > 0 {
203 t = append(t, tools.NewDiagnosticsTool(lspClients))
204 }
205 }
206 return t
207 }
208
209 if agentCfg.AllowedTools == nil {
210 return withCoderTools(allTools)
211 }
212
213 var filteredTools []tools.BaseTool
214 for _, tool := range allTools {
215 if slices.Contains(agentCfg.AllowedTools, tool.Name()) {
216 filteredTools = append(filteredTools, tool)
217 }
218 }
219 return withCoderTools(filteredTools)
220 }
221
222 return &agent{
223 Broker: pubsub.NewBroker[AgentEvent](),
224 agentCfg: agentCfg,
225 provider: agentProvider,
226 providerID: string(providerCfg.ID),
227 messages: messages,
228 sessions: sessions,
229 titleProvider: titleProvider,
230 summarizeProvider: summarizeProvider,
231 summarizeProviderID: string(providerCfg.ID),
232 agentToolFn: agentToolFn,
233 activeRequests: csync.NewMap[string, context.CancelFunc](),
234 tools: csync.NewLazySlice(toolFn),
235 promptQueue: csync.NewMap[string, []string](),
236 permissions: permissions,
237 }, nil
238}
239
240func (a *agent) Model() catwalk.Model {
241 return *config.Get().GetModelByType(a.agentCfg.Model)
242}
243
244func (a *agent) Cancel(sessionID string) {
245 // Cancel regular requests
246 if cancel, ok := a.activeRequests.Take(sessionID); ok && cancel != nil {
247 slog.Info("Request cancellation initiated", "session_id", sessionID)
248 cancel()
249 }
250
251 // Also check for summarize requests
252 if cancel, ok := a.activeRequests.Take(sessionID + "-summarize"); ok && cancel != nil {
253 slog.Info("Summarize cancellation initiated", "session_id", sessionID)
254 cancel()
255 }
256
257 if a.QueuedPrompts(sessionID) > 0 {
258 slog.Info("Clearing queued prompts", "session_id", sessionID)
259 a.promptQueue.Del(sessionID)
260 }
261}
262
263func (a *agent) IsBusy() bool {
264 var busy bool
265 for cancelFunc := range a.activeRequests.Seq() {
266 if cancelFunc != nil {
267 busy = true
268 break
269 }
270 }
271 return busy
272}
273
274func (a *agent) IsSessionBusy(sessionID string) bool {
275 _, busy := a.activeRequests.Get(sessionID)
276 return busy
277}
278
279func (a *agent) QueuedPrompts(sessionID string) int {
280 l, ok := a.promptQueue.Get(sessionID)
281 if !ok {
282 return 0
283 }
284 return len(l)
285}
286
287func (a *agent) generateTitle(ctx context.Context, sessionID string, content string) error {
288 if content == "" {
289 return nil
290 }
291 if a.titleProvider == nil {
292 return nil
293 }
294 session, err := a.sessions.Get(ctx, sessionID)
295 if err != nil {
296 return err
297 }
298 parts := []message.ContentPart{message.TextContent{
299 Text: fmt.Sprintf("Generate a concise title for the following content:\n\n%s", content),
300 }}
301
302 // Use streaming approach like summarization
303 response := a.titleProvider.StreamResponse(
304 ctx,
305 []message.Message{
306 {
307 Role: message.User,
308 Parts: parts,
309 },
310 },
311 nil,
312 )
313
314 var finalResponse *provider.ProviderResponse
315 for r := range response {
316 if r.Error != nil {
317 return r.Error
318 }
319 finalResponse = r.Response
320 }
321
322 if finalResponse == nil {
323 return fmt.Errorf("no response received from title provider")
324 }
325
326 title := strings.ReplaceAll(finalResponse.Content, "\n", " ")
327
328 if idx := strings.Index(title, "</think>"); idx > 0 {
329 title = title[idx+len("</think>"):]
330 }
331
332 title = strings.TrimSpace(title)
333 if title == "" {
334 return nil
335 }
336
337 session.Title = title
338 _, err = a.sessions.Save(ctx, session)
339 return err
340}
341
342func (a *agent) err(err error) AgentEvent {
343 return AgentEvent{
344 Type: AgentEventTypeError,
345 Error: err,
346 }
347}
348
349func (a *agent) Run(ctx context.Context, sessionID string, content string, attachments ...message.Attachment) (<-chan AgentEvent, error) {
350 if !a.Model().SupportsImages && attachments != nil {
351 attachments = nil
352 }
353 events := make(chan AgentEvent, 1)
354 if a.IsSessionBusy(sessionID) {
355 existing, ok := a.promptQueue.Get(sessionID)
356 if !ok {
357 existing = []string{}
358 }
359 existing = append(existing, content)
360 a.promptQueue.Set(sessionID, existing)
361 return nil, nil
362 }
363
364 genCtx, cancel := context.WithCancel(ctx)
365 a.activeRequests.Set(sessionID, cancel)
366 startTime := time.Now()
367
368 go func() {
369 slog.Debug("Request started", "sessionID", sessionID)
370 defer log.RecoverPanic("agent.Run", func() {
371 events <- a.err(fmt.Errorf("panic while running the agent"))
372 })
373 var attachmentParts []message.ContentPart
374 for _, attachment := range attachments {
375 attachmentParts = append(attachmentParts, message.BinaryContent{Path: attachment.FilePath, MIMEType: attachment.MimeType, Data: attachment.Content})
376 }
377 result := a.processGeneration(genCtx, sessionID, content, attachmentParts)
378 if result.Error != nil {
379 if isCancelledErr(result.Error) {
380 slog.Error("Request canceled", "sessionID", sessionID)
381 } else {
382 slog.Error("Request errored", "sessionID", sessionID, "error", result.Error.Error())
383 event.Error(result.Error)
384 }
385 } else {
386 slog.Debug("Request completed", "sessionID", sessionID)
387 }
388 a.eventPromptResponded(sessionID, time.Since(startTime).Truncate(time.Second))
389 a.activeRequests.Del(sessionID)
390 cancel()
391 a.Publish(pubsub.CreatedEvent, result)
392 events <- result
393 close(events)
394 }()
395 a.eventPromptSent(sessionID)
396 return events, nil
397}
398
399func (a *agent) processGeneration(ctx context.Context, sessionID, content string, attachmentParts []message.ContentPart) AgentEvent {
400 cfg := config.Get()
401 // List existing messages; if none, start title generation asynchronously.
402 msgs, err := a.messages.List(ctx, sessionID)
403 if err != nil {
404 return a.err(fmt.Errorf("failed to list messages: %w", err))
405 }
406 if len(msgs) == 0 {
407 go func() {
408 defer log.RecoverPanic("agent.Run", func() {
409 slog.Error("panic while generating title")
410 })
411 titleErr := a.generateTitle(ctx, sessionID, content)
412 if titleErr != nil && !errors.Is(titleErr, context.Canceled) && !errors.Is(titleErr, context.DeadlineExceeded) {
413 slog.Error("failed to generate title", "error", titleErr)
414 }
415 }()
416 }
417 session, err := a.sessions.Get(ctx, sessionID)
418 if err != nil {
419 return a.err(fmt.Errorf("failed to get session: %w", err))
420 }
421 if session.SummaryMessageID != "" {
422 summaryMsgInex := -1
423 for i, msg := range msgs {
424 if msg.ID == session.SummaryMessageID {
425 summaryMsgInex = i
426 break
427 }
428 }
429 if summaryMsgInex != -1 {
430 msgs = msgs[summaryMsgInex:]
431 msgs[0].Role = message.User
432 }
433 }
434
435 userMsg, err := a.createUserMessage(ctx, sessionID, content, attachmentParts)
436 if err != nil {
437 return a.err(fmt.Errorf("failed to create user message: %w", err))
438 }
439 // Append the new user message to the conversation history.
440 msgHistory := append(msgs, userMsg)
441
442 for {
443 // Check for cancellation before each iteration
444 select {
445 case <-ctx.Done():
446 return a.err(ctx.Err())
447 default:
448 // Continue processing
449 }
450 agentMessage, toolResults, err := a.streamAndHandleEvents(ctx, sessionID, msgHistory)
451 if err != nil {
452 if errors.Is(err, context.Canceled) {
453 agentMessage.AddFinish(message.FinishReasonCanceled, "Request cancelled", "")
454 a.messages.Update(context.Background(), agentMessage)
455 return a.err(ErrRequestCancelled)
456 }
457 return a.err(fmt.Errorf("failed to process events: %w", err))
458 }
459 if cfg.Options.Debug {
460 slog.Info("Result", "message", agentMessage.FinishReason(), "toolResults", toolResults)
461 }
462 if (agentMessage.FinishReason() == message.FinishReasonToolUse) && toolResults != nil {
463 // We are not done, we need to respond with the tool response
464 msgHistory = append(msgHistory, agentMessage, *toolResults)
465 // If there are queued prompts, process the next one
466 nextPrompt, ok := a.promptQueue.Take(sessionID)
467 if ok {
468 for _, prompt := range nextPrompt {
469 // Create a new user message for the queued prompt
470 userMsg, err := a.createUserMessage(ctx, sessionID, prompt, nil)
471 if err != nil {
472 return a.err(fmt.Errorf("failed to create user message for queued prompt: %w", err))
473 }
474 // Append the new user message to the conversation history
475 msgHistory = append(msgHistory, userMsg)
476 }
477 }
478
479 continue
480 } else if agentMessage.FinishReason() == message.FinishReasonEndTurn {
481 queuePrompts, ok := a.promptQueue.Take(sessionID)
482 if ok {
483 for _, prompt := range queuePrompts {
484 if prompt == "" {
485 continue
486 }
487 userMsg, err := a.createUserMessage(ctx, sessionID, prompt, nil)
488 if err != nil {
489 return a.err(fmt.Errorf("failed to create user message for queued prompt: %w", err))
490 }
491 msgHistory = append(msgHistory, userMsg)
492 }
493 continue
494 }
495 }
496 if agentMessage.FinishReason() == "" {
497 // Kujtim: could not track down where this is happening but this means its cancelled
498 agentMessage.AddFinish(message.FinishReasonCanceled, "Request cancelled", "")
499 _ = a.messages.Update(context.Background(), agentMessage)
500 return a.err(ErrRequestCancelled)
501 }
502 return AgentEvent{
503 Type: AgentEventTypeResponse,
504 Message: agentMessage,
505 Done: true,
506 }
507 }
508}
509
510func (a *agent) createUserMessage(ctx context.Context, sessionID, content string, attachmentParts []message.ContentPart) (message.Message, error) {
511 parts := []message.ContentPart{message.TextContent{Text: content}}
512 parts = append(parts, attachmentParts...)
513 return a.messages.Create(ctx, sessionID, message.CreateMessageParams{
514 Role: message.User,
515 Parts: parts,
516 })
517}
518
519func (a *agent) getAllTools() ([]tools.BaseTool, error) {
520 allTools := slices.Collect(a.tools.Seq())
521 if a.agentToolFn != nil {
522 agentTool, agentToolErr := a.agentToolFn()
523 if agentToolErr != nil {
524 return nil, agentToolErr
525 }
526 allTools = append(allTools, agentTool)
527 }
528 return allTools, nil
529}
530
531func (a *agent) streamAndHandleEvents(ctx context.Context, sessionID string, msgHistory []message.Message) (message.Message, *message.Message, error) {
532 ctx = context.WithValue(ctx, tools.SessionIDContextKey, sessionID)
533
534 // Create the assistant message first so the spinner shows immediately
535 assistantMsg, err := a.messages.Create(ctx, sessionID, message.CreateMessageParams{
536 Role: message.Assistant,
537 Parts: []message.ContentPart{},
538 Model: a.Model().ID,
539 Provider: a.providerID,
540 })
541 if err != nil {
542 return assistantMsg, nil, fmt.Errorf("failed to create assistant message: %w", err)
543 }
544
545 allTools, toolsErr := a.getAllTools()
546 if toolsErr != nil {
547 return assistantMsg, nil, toolsErr
548 }
549 // Now collect tools (which may block on MCP initialization)
550 eventChan := a.provider.StreamResponse(ctx, msgHistory, allTools)
551
552 // Add the session and message ID into the context if needed by tools.
553 ctx = context.WithValue(ctx, tools.MessageIDContextKey, assistantMsg.ID)
554
555loop:
556 for {
557 select {
558 case event, ok := <-eventChan:
559 if !ok {
560 break loop
561 }
562 if processErr := a.processEvent(ctx, sessionID, &assistantMsg, event); processErr != nil {
563 if errors.Is(processErr, context.Canceled) {
564 a.finishMessage(context.Background(), &assistantMsg, message.FinishReasonCanceled, "Request cancelled", "")
565 } else {
566 a.finishMessage(ctx, &assistantMsg, message.FinishReasonError, "API Error", processErr.Error())
567 }
568 return assistantMsg, nil, processErr
569 }
570 case <-ctx.Done():
571 a.finishMessage(context.Background(), &assistantMsg, message.FinishReasonCanceled, "Request cancelled", "")
572 return assistantMsg, nil, ctx.Err()
573 }
574 }
575
576 toolResults := make([]message.ToolResult, len(assistantMsg.ToolCalls()))
577 toolCalls := assistantMsg.ToolCalls()
578 for i, toolCall := range toolCalls {
579 select {
580 case <-ctx.Done():
581 a.finishMessage(context.Background(), &assistantMsg, message.FinishReasonCanceled, "Request cancelled", "")
582 // Make all future tool calls cancelled
583 for j := i; j < len(toolCalls); j++ {
584 toolResults[j] = message.ToolResult{
585 ToolCallID: toolCalls[j].ID,
586 Content: "Tool execution canceled by user",
587 IsError: true,
588 }
589 }
590 goto out
591 default:
592 // Continue processing
593 var tool tools.BaseTool
594 allTools, _ := a.getAllTools()
595 for _, availableTool := range allTools {
596 if availableTool.Info().Name == toolCall.Name {
597 tool = availableTool
598 break
599 }
600 }
601
602 // Tool not found
603 if tool == nil {
604 toolResults[i] = message.ToolResult{
605 ToolCallID: toolCall.ID,
606 Content: fmt.Sprintf("Tool not found: %s", toolCall.Name),
607 IsError: true,
608 }
609 continue
610 }
611
612 // Run tool in goroutine to allow cancellation
613 type toolExecResult struct {
614 response tools.ToolResponse
615 err error
616 }
617 resultChan := make(chan toolExecResult, 1)
618
619 go func() {
620 response, err := tool.Run(ctx, tools.ToolCall{
621 ID: toolCall.ID,
622 Name: toolCall.Name,
623 Input: toolCall.Input,
624 })
625 resultChan <- toolExecResult{response: response, err: err}
626 }()
627
628 var toolResponse tools.ToolResponse
629 var toolErr error
630
631 select {
632 case <-ctx.Done():
633 a.finishMessage(context.Background(), &assistantMsg, message.FinishReasonCanceled, "Request cancelled", "")
634 // Mark remaining tool calls as cancelled
635 for j := i; j < len(toolCalls); j++ {
636 toolResults[j] = message.ToolResult{
637 ToolCallID: toolCalls[j].ID,
638 Content: "Tool execution canceled by user",
639 IsError: true,
640 }
641 }
642 goto out
643 case result := <-resultChan:
644 toolResponse = result.response
645 toolErr = result.err
646 }
647
648 if toolErr != nil {
649 slog.Error("Tool execution error", "toolCall", toolCall.ID, "error", toolErr)
650 if errors.Is(toolErr, permission.ErrorPermissionDenied) {
651 toolResults[i] = message.ToolResult{
652 ToolCallID: toolCall.ID,
653 Content: "Permission denied",
654 IsError: true,
655 }
656 for j := i + 1; j < len(toolCalls); j++ {
657 toolResults[j] = message.ToolResult{
658 ToolCallID: toolCalls[j].ID,
659 Content: "Tool execution canceled by user",
660 IsError: true,
661 }
662 }
663 a.finishMessage(ctx, &assistantMsg, message.FinishReasonPermissionDenied, "Permission denied", "")
664 break
665 }
666 }
667 toolResults[i] = message.ToolResult{
668 ToolCallID: toolCall.ID,
669 Content: toolResponse.Content,
670 Metadata: toolResponse.Metadata,
671 IsError: toolResponse.IsError,
672 }
673 }
674 }
675out:
676 if len(toolResults) == 0 {
677 return assistantMsg, nil, nil
678 }
679 parts := make([]message.ContentPart, 0)
680 for _, tr := range toolResults {
681 parts = append(parts, tr)
682 }
683 msg, err := a.messages.Create(context.Background(), assistantMsg.SessionID, message.CreateMessageParams{
684 Role: message.Tool,
685 Parts: parts,
686 Provider: a.providerID,
687 })
688 if err != nil {
689 return assistantMsg, nil, fmt.Errorf("failed to create cancelled tool message: %w", err)
690 }
691
692 return assistantMsg, &msg, err
693}
694
695func (a *agent) finishMessage(ctx context.Context, msg *message.Message, finishReason message.FinishReason, message, details string) {
696 msg.AddFinish(finishReason, message, details)
697 _ = a.messages.Update(ctx, *msg)
698}
699
700func (a *agent) processEvent(ctx context.Context, sessionID string, assistantMsg *message.Message, event provider.ProviderEvent) error {
701 select {
702 case <-ctx.Done():
703 return ctx.Err()
704 default:
705 // Continue processing.
706 }
707
708 switch event.Type {
709 case provider.EventThinkingDelta:
710 assistantMsg.AppendReasoningContent(event.Thinking)
711 return a.messages.Update(ctx, *assistantMsg)
712 case provider.EventSignatureDelta:
713 assistantMsg.AppendReasoningSignature(event.Signature)
714 return a.messages.Update(ctx, *assistantMsg)
715 case provider.EventContentDelta:
716 assistantMsg.FinishThinking()
717 assistantMsg.AppendContent(event.Content)
718 return a.messages.Update(ctx, *assistantMsg)
719 case provider.EventToolUseStart:
720 assistantMsg.FinishThinking()
721 slog.Info("Tool call started", "toolCall", event.ToolCall)
722 assistantMsg.AddToolCall(*event.ToolCall)
723 return a.messages.Update(ctx, *assistantMsg)
724 case provider.EventToolUseDelta:
725 assistantMsg.AppendToolCallInput(event.ToolCall.ID, event.ToolCall.Input)
726 return a.messages.Update(ctx, *assistantMsg)
727 case provider.EventToolUseStop:
728 slog.Info("Finished tool call", "toolCall", event.ToolCall)
729 assistantMsg.FinishToolCall(event.ToolCall.ID)
730 return a.messages.Update(ctx, *assistantMsg)
731 case provider.EventError:
732 return event.Error
733 case provider.EventComplete:
734 assistantMsg.FinishThinking()
735 assistantMsg.SetToolCalls(event.Response.ToolCalls)
736 assistantMsg.AddFinish(event.Response.FinishReason, "", "")
737 if err := a.messages.Update(ctx, *assistantMsg); err != nil {
738 return fmt.Errorf("failed to update message: %w", err)
739 }
740 return a.trackUsage(ctx, sessionID, a.Model(), event.Response.Usage)
741 }
742
743 return nil
744}
745
746func (a *agent) trackUsage(ctx context.Context, sessionID string, model catwalk.Model, usage provider.TokenUsage) error {
747 sess, err := a.sessions.Get(ctx, sessionID)
748 if err != nil {
749 return fmt.Errorf("failed to get session: %w", err)
750 }
751
752 cost := model.CostPer1MInCached/1e6*float64(usage.CacheCreationTokens) +
753 model.CostPer1MOutCached/1e6*float64(usage.CacheReadTokens) +
754 model.CostPer1MIn/1e6*float64(usage.InputTokens) +
755 model.CostPer1MOut/1e6*float64(usage.OutputTokens)
756
757 a.eventTokensUsed(sessionID, usage, cost)
758
759 sess.Cost += cost
760 sess.CompletionTokens = usage.OutputTokens + usage.CacheReadTokens
761 sess.PromptTokens = usage.InputTokens + usage.CacheCreationTokens
762
763 _, err = a.sessions.Save(ctx, sess)
764 if err != nil {
765 return fmt.Errorf("failed to save session: %w", err)
766 }
767 return nil
768}
769
770func (a *agent) Summarize(ctx context.Context, sessionID string) error {
771 if a.summarizeProvider == nil {
772 return fmt.Errorf("summarize provider not available")
773 }
774
775 // Check if session is busy
776 if a.IsSessionBusy(sessionID) {
777 return ErrSessionBusy
778 }
779
780 // Create a new context with cancellation
781 summarizeCtx, cancel := context.WithCancel(ctx)
782
783 // Store the cancel function in activeRequests to allow cancellation
784 a.activeRequests.Set(sessionID+"-summarize", cancel)
785
786 go func() {
787 defer a.activeRequests.Del(sessionID + "-summarize")
788 defer cancel()
789 event := AgentEvent{
790 Type: AgentEventTypeSummarize,
791 Progress: "Starting summarization...",
792 }
793
794 a.Publish(pubsub.CreatedEvent, event)
795 // Get all messages from the session
796 msgs, err := a.messages.List(summarizeCtx, sessionID)
797 if err != nil {
798 event = AgentEvent{
799 Type: AgentEventTypeError,
800 Error: fmt.Errorf("failed to list messages: %w", err),
801 Done: true,
802 }
803 a.Publish(pubsub.CreatedEvent, event)
804 return
805 }
806 summarizeCtx = context.WithValue(summarizeCtx, tools.SessionIDContextKey, sessionID)
807
808 if len(msgs) == 0 {
809 event = AgentEvent{
810 Type: AgentEventTypeError,
811 Error: fmt.Errorf("no messages to summarize"),
812 Done: true,
813 }
814 a.Publish(pubsub.CreatedEvent, event)
815 return
816 }
817
818 event = AgentEvent{
819 Type: AgentEventTypeSummarize,
820 Progress: "Analyzing conversation...",
821 }
822 a.Publish(pubsub.CreatedEvent, event)
823
824 // Add a system message to guide the summarization
825 summarizePrompt := "Provide a detailed but concise summary of our conversation above. Focus on information that would be helpful for continuing the conversation, including what we did, what we're doing, which files we're working on, and what we're going to do next."
826
827 // Create a new message with the summarize prompt
828 promptMsg := message.Message{
829 Role: message.User,
830 Parts: []message.ContentPart{message.TextContent{Text: summarizePrompt}},
831 }
832
833 // Append the prompt to the messages
834 msgsWithPrompt := append(msgs, promptMsg)
835
836 event = AgentEvent{
837 Type: AgentEventTypeSummarize,
838 Progress: "Generating summary...",
839 }
840
841 a.Publish(pubsub.CreatedEvent, event)
842
843 // Send the messages to the summarize provider
844 response := a.summarizeProvider.StreamResponse(
845 summarizeCtx,
846 msgsWithPrompt,
847 nil,
848 )
849 var finalResponse *provider.ProviderResponse
850 for r := range response {
851 if r.Error != nil {
852 event = AgentEvent{
853 Type: AgentEventTypeError,
854 Error: fmt.Errorf("failed to summarize: %w", r.Error),
855 Done: true,
856 }
857 a.Publish(pubsub.CreatedEvent, event)
858 return
859 }
860 finalResponse = r.Response
861 }
862
863 summary := strings.TrimSpace(finalResponse.Content)
864 if summary == "" {
865 event = AgentEvent{
866 Type: AgentEventTypeError,
867 Error: fmt.Errorf("empty summary returned"),
868 Done: true,
869 }
870 a.Publish(pubsub.CreatedEvent, event)
871 return
872 }
873 shell := shell.GetPersistentShell(config.Get().WorkingDir())
874 summary += "\n\n**Current working directory of the persistent shell**\n\n" + shell.GetWorkingDir()
875 event = AgentEvent{
876 Type: AgentEventTypeSummarize,
877 Progress: "Creating new session...",
878 }
879
880 a.Publish(pubsub.CreatedEvent, event)
881 oldSession, err := a.sessions.Get(summarizeCtx, sessionID)
882 if err != nil {
883 event = AgentEvent{
884 Type: AgentEventTypeError,
885 Error: fmt.Errorf("failed to get session: %w", err),
886 Done: true,
887 }
888
889 a.Publish(pubsub.CreatedEvent, event)
890 return
891 }
892 // Create a message in the new session with the summary
893 msg, err := a.messages.Create(summarizeCtx, oldSession.ID, message.CreateMessageParams{
894 Role: message.Assistant,
895 Parts: []message.ContentPart{
896 message.TextContent{Text: summary},
897 message.Finish{
898 Reason: message.FinishReasonEndTurn,
899 Time: time.Now().Unix(),
900 },
901 },
902 Model: a.summarizeProvider.Model().ID,
903 Provider: a.summarizeProviderID,
904 })
905 if err != nil {
906 event = AgentEvent{
907 Type: AgentEventTypeError,
908 Error: fmt.Errorf("failed to create summary message: %w", err),
909 Done: true,
910 }
911
912 a.Publish(pubsub.CreatedEvent, event)
913 return
914 }
915 oldSession.SummaryMessageID = msg.ID
916 oldSession.CompletionTokens = finalResponse.Usage.OutputTokens
917 oldSession.PromptTokens = 0
918 model := a.summarizeProvider.Model()
919 usage := finalResponse.Usage
920 cost := model.CostPer1MInCached/1e6*float64(usage.CacheCreationTokens) +
921 model.CostPer1MOutCached/1e6*float64(usage.CacheReadTokens) +
922 model.CostPer1MIn/1e6*float64(usage.InputTokens) +
923 model.CostPer1MOut/1e6*float64(usage.OutputTokens)
924 oldSession.Cost += cost
925 _, err = a.sessions.Save(summarizeCtx, oldSession)
926 if err != nil {
927 event = AgentEvent{
928 Type: AgentEventTypeError,
929 Error: fmt.Errorf("failed to save session: %w", err),
930 Done: true,
931 }
932 a.Publish(pubsub.CreatedEvent, event)
933 }
934
935 event = AgentEvent{
936 Type: AgentEventTypeSummarize,
937 SessionID: oldSession.ID,
938 Progress: "Summary complete",
939 Done: true,
940 }
941 a.Publish(pubsub.CreatedEvent, event)
942 // Send final success event with the new session ID
943 }()
944
945 return nil
946}
947
948func (a *agent) ClearQueue(sessionID string) {
949 if a.QueuedPrompts(sessionID) > 0 {
950 slog.Info("Clearing queued prompts", "session_id", sessionID)
951 a.promptQueue.Del(sessionID)
952 }
953}
954
955func (a *agent) CancelAll() {
956 if !a.IsBusy() {
957 return
958 }
959 for key := range a.activeRequests.Seq2() {
960 a.Cancel(key) // key is sessionID
961 }
962
963 timeout := time.After(5 * time.Second)
964 for a.IsBusy() {
965 select {
966 case <-timeout:
967 return
968 default:
969 time.Sleep(200 * time.Millisecond)
970 }
971 }
972}
973
974func (a *agent) UpdateModel() error {
975 cfg := config.Get()
976
977 // Get current provider configuration
978 currentProviderCfg := cfg.GetProviderForModel(a.agentCfg.Model)
979 if currentProviderCfg == nil || currentProviderCfg.ID == "" {
980 return fmt.Errorf("provider for agent %s not found in config", a.agentCfg.Name)
981 }
982
983 // Check if provider has changed
984 if string(currentProviderCfg.ID) != a.providerID {
985 // Provider changed, need to recreate the main provider
986 model := cfg.GetModelByType(a.agentCfg.Model)
987 if model.ID == "" {
988 return fmt.Errorf("model not found for agent %s", a.agentCfg.Name)
989 }
990
991 promptID := agentPromptMap[a.agentCfg.ID]
992 if promptID == "" {
993 promptID = prompt.PromptDefault
994 }
995
996 opts := []provider.ProviderClientOption{
997 provider.WithModel(a.agentCfg.Model),
998 provider.WithSystemMessage(prompt.GetPrompt(promptID, currentProviderCfg.ID, cfg.Options.ContextPaths...)),
999 }
1000
1001 newProvider, err := provider.NewProvider(*currentProviderCfg, opts...)
1002 if err != nil {
1003 return fmt.Errorf("failed to create new provider: %w", err)
1004 }
1005
1006 // Update the provider and provider ID
1007 a.provider = newProvider
1008 a.providerID = string(currentProviderCfg.ID)
1009 }
1010
1011 // Check if providers have changed for title (small) and summarize (large)
1012 smallModelCfg := cfg.Models[config.SelectedModelTypeSmall]
1013 var smallModelProviderCfg config.ProviderConfig
1014 for p := range cfg.Providers.Seq() {
1015 if p.ID == smallModelCfg.Provider {
1016 smallModelProviderCfg = p
1017 break
1018 }
1019 }
1020 if smallModelProviderCfg.ID == "" {
1021 return fmt.Errorf("provider %s not found in config", smallModelCfg.Provider)
1022 }
1023
1024 largeModelCfg := cfg.Models[config.SelectedModelTypeLarge]
1025 var largeModelProviderCfg config.ProviderConfig
1026 for p := range cfg.Providers.Seq() {
1027 if p.ID == largeModelCfg.Provider {
1028 largeModelProviderCfg = p
1029 break
1030 }
1031 }
1032 if largeModelProviderCfg.ID == "" {
1033 return fmt.Errorf("provider %s not found in config", largeModelCfg.Provider)
1034 }
1035
1036 var maxTitleTokens int64 = 40
1037
1038 // if the max output is too low for the gemini provider it won't return anything
1039 if smallModelCfg.Provider == "gemini" {
1040 maxTitleTokens = 1000
1041 }
1042 // Recreate title provider
1043 titleOpts := []provider.ProviderClientOption{
1044 provider.WithModel(config.SelectedModelTypeSmall),
1045 provider.WithSystemMessage(prompt.GetPrompt(prompt.PromptTitle, smallModelProviderCfg.ID)),
1046 provider.WithMaxTokens(maxTitleTokens),
1047 }
1048 newTitleProvider, err := provider.NewProvider(smallModelProviderCfg, titleOpts...)
1049 if err != nil {
1050 return fmt.Errorf("failed to create new title provider: %w", err)
1051 }
1052 a.titleProvider = newTitleProvider
1053
1054 // Recreate summarize provider if provider changed (now large model)
1055 if string(largeModelProviderCfg.ID) != a.summarizeProviderID {
1056 largeModel := cfg.GetModelByType(config.SelectedModelTypeLarge)
1057 if largeModel == nil {
1058 return fmt.Errorf("model %s not found in provider %s", largeModelCfg.Model, largeModelProviderCfg.ID)
1059 }
1060 summarizeOpts := []provider.ProviderClientOption{
1061 provider.WithModel(config.SelectedModelTypeLarge),
1062 provider.WithSystemMessage(prompt.GetPrompt(prompt.PromptSummarizer, largeModelProviderCfg.ID)),
1063 }
1064 newSummarizeProvider, err := provider.NewProvider(largeModelProviderCfg, summarizeOpts...)
1065 if err != nil {
1066 return fmt.Errorf("failed to create new summarize provider: %w", err)
1067 }
1068 a.summarizeProvider = newSummarizeProvider
1069 a.summarizeProviderID = string(largeModelProviderCfg.ID)
1070 }
1071
1072 return nil
1073}