1package agent
2
3import (
4 "context"
5 "errors"
6 "fmt"
7 "log/slog"
8 "slices"
9 "strings"
10 "time"
11
12 "github.com/charmbracelet/catwalk/pkg/catwalk"
13 "github.com/charmbracelet/crush/internal/config"
14 "github.com/charmbracelet/crush/internal/csync"
15 "github.com/charmbracelet/crush/internal/event"
16 "github.com/charmbracelet/crush/internal/history"
17 "github.com/charmbracelet/crush/internal/llm/prompt"
18 "github.com/charmbracelet/crush/internal/llm/provider"
19 "github.com/charmbracelet/crush/internal/llm/tools"
20 "github.com/charmbracelet/crush/internal/log"
21 "github.com/charmbracelet/crush/internal/lsp"
22 "github.com/charmbracelet/crush/internal/message"
23 "github.com/charmbracelet/crush/internal/permission"
24 "github.com/charmbracelet/crush/internal/pubsub"
25 "github.com/charmbracelet/crush/internal/session"
26 "github.com/charmbracelet/crush/internal/shell"
27)
28
// AgentEventType identifies the kind of AgentEvent emitted by an agent.
type AgentEventType string

const (
	// AgentEventTypeError marks an event whose Error field carries a failure.
	AgentEventTypeError AgentEventType = "error"
	// AgentEventTypeResponse marks an event carrying the final assistant message
	// of a Run.
	AgentEventTypeResponse AgentEventType = "response"
	// AgentEventTypeSummarize marks progress and completion events published by
	// Summarize.
	AgentEventTypeSummarize AgentEventType = "summarize"
)
36
// AgentEvent is the payload delivered on the channel returned by Run and
// published through the agent's pubsub broker.
type AgentEvent struct {
	// Type tells which of the remaining fields are meaningful.
	Type AgentEventType
	// Message is the assistant message for response events.
	Message message.Message
	// Error is set for error events.
	Error error

	// When summarizing
	// SessionID is filled in on the final "Summary complete" event.
	SessionID string
	// Progress is a human-readable status line.
	Progress string
	// Done reports that no further events follow for the operation.
	Done bool
}
47
// Service is the public surface of an agent. The descriptions below reflect
// the behavior of the agent implementation in this file.
type Service interface {
	pubsub.Suscriber[AgentEvent]
	// Model returns the model configured for the agent's model type.
	Model() catwalk.Model
	// Run starts processing content for sessionID in the background and returns
	// a channel that receives the final event. If the session already has an
	// active request, the content is queued instead and both return values are
	// nil.
	Run(ctx context.Context, sessionID string, content string, attachments ...message.Attachment) (<-chan AgentEvent, error)
	// Cancel cancels the active request and any running summarization for the
	// session and drops its queued prompts.
	Cancel(sessionID string)
	// CancelAll cancels every active request and waits up to five seconds for
	// the agent to become idle.
	CancelAll()
	// IsSessionBusy reports whether the session has an active request.
	IsSessionBusy(sessionID string) bool
	// IsBusy reports whether any request is active.
	IsBusy() bool
	// Summarize starts an asynchronous summarization of the session; progress
	// and results are delivered as published events.
	Summarize(ctx context.Context, sessionID string) error
	// UpdateModel rebuilds the providers after the model configuration changed.
	UpdateModel() error
	// QueuedPrompts returns the number of prompts queued for the session.
	QueuedPrompts(sessionID string) int
	// ClearQueue drops all prompts queued for the session.
	ClearQueue(sessionID string)
}
61
// agent is the Service implementation. It embeds a pubsub broker so callers
// can subscribe to the AgentEvents it publishes.
type agent struct {
	*pubsub.Broker[AgentEvent]
	// agentCfg is the configuration this agent was created from.
	agentCfg    config.Agent
	sessions    session.Service
	messages    message.Service
	permissions permission.Service
	// NOTE(review): mcpTools is not assigned anywhere in this file; confirm
	// whether it is still needed.
	mcpTools []McpTool

	// tools holds the base tool set, built lazily by the function passed to
	// csync.NewLazySlice in NewAgent.
	tools *csync.LazySlice[tools.BaseTool]
	// We need this to be able to update it when model changes
	agentToolFn func() (tools.BaseTool, error)

	// provider is the main LLM provider; providerID is its config ID.
	provider   provider.Provider
	providerID string

	// titleProvider generates session titles; summarizeProvider produces
	// conversation summaries.
	titleProvider       provider.Provider
	summarizeProvider   provider.Provider
	summarizeProviderID string

	// activeRequests maps a session ID (or sessionID+"-summarize") to the
	// cancel function of the request currently running for it.
	activeRequests *csync.Map[string, context.CancelFunc]
	// promptQueue holds prompts submitted while a session was busy.
	promptQueue *csync.Map[string, []string]
}
84
// agentPromptMap maps an agent ID to the system prompt it uses. Agent IDs
// missing from the map fall back to prompt.PromptDefault at the lookup sites.
var agentPromptMap = map[string]prompt.PromptID{
	"coder": prompt.PromptCoder,
	"task":  prompt.PromptTask,
}
89
90func NewAgent(
91 ctx context.Context,
92 agentCfg config.Agent,
93 // These services are needed in the tools
94 permissions permission.Service,
95 sessions session.Service,
96 messages message.Service,
97 history history.Service,
98 lspClients *csync.Map[string, *lsp.Client],
99) (Service, error) {
100 cfg := config.Get()
101
102 var agentToolFn func() (tools.BaseTool, error)
103 if agentCfg.ID == "coder" {
104 agentToolFn = func() (tools.BaseTool, error) {
105 taskAgentCfg := config.Get().Agents["task"]
106 if taskAgentCfg.ID == "" {
107 return nil, fmt.Errorf("task agent not found in config")
108 }
109 taskAgent, err := NewAgent(ctx, taskAgentCfg, permissions, sessions, messages, history, lspClients)
110 if err != nil {
111 return nil, fmt.Errorf("failed to create task agent: %w", err)
112 }
113 return NewAgentTool(taskAgent, sessions, messages), nil
114 }
115 }
116
117 providerCfg := config.Get().GetProviderForModel(agentCfg.Model)
118 if providerCfg == nil {
119 return nil, fmt.Errorf("provider for agent %s not found in config", agentCfg.Name)
120 }
121 model := config.Get().GetModelByType(agentCfg.Model)
122
123 if model == nil {
124 return nil, fmt.Errorf("model not found for agent %s", agentCfg.Name)
125 }
126
127 promptID := agentPromptMap[agentCfg.ID]
128 if promptID == "" {
129 promptID = prompt.PromptDefault
130 }
131 opts := []provider.ProviderClientOption{
132 provider.WithModel(agentCfg.Model),
133 provider.WithSystemMessage(prompt.GetPrompt(promptID, providerCfg.ID, config.Get().Options.ContextPaths...)),
134 }
135 agentProvider, err := provider.NewProvider(*providerCfg, opts...)
136 if err != nil {
137 return nil, err
138 }
139
140 smallModelCfg := cfg.Models[config.SelectedModelTypeSmall]
141 var smallModelProviderCfg *config.ProviderConfig
142 if smallModelCfg.Provider == providerCfg.ID {
143 smallModelProviderCfg = providerCfg
144 } else {
145 smallModelProviderCfg = cfg.GetProviderForModel(config.SelectedModelTypeSmall)
146
147 if smallModelProviderCfg.ID == "" {
148 return nil, fmt.Errorf("provider %s not found in config", smallModelCfg.Provider)
149 }
150 }
151 smallModel := cfg.GetModelByType(config.SelectedModelTypeSmall)
152 if smallModel.ID == "" {
153 return nil, fmt.Errorf("model %s not found in provider %s", smallModelCfg.Model, smallModelProviderCfg.ID)
154 }
155
156 titleOpts := []provider.ProviderClientOption{
157 provider.WithModel(config.SelectedModelTypeSmall),
158 provider.WithSystemMessage(prompt.GetPrompt(prompt.PromptTitle, smallModelProviderCfg.ID)),
159 }
160 titleProvider, err := provider.NewProvider(*smallModelProviderCfg, titleOpts...)
161 if err != nil {
162 return nil, err
163 }
164
165 summarizeOpts := []provider.ProviderClientOption{
166 provider.WithModel(config.SelectedModelTypeLarge),
167 provider.WithSystemMessage(prompt.GetPrompt(prompt.PromptSummarizer, providerCfg.ID)),
168 }
169 summarizeProvider, err := provider.NewProvider(*providerCfg, summarizeOpts...)
170 if err != nil {
171 return nil, err
172 }
173
174 toolFn := func() []tools.BaseTool {
175 slog.Info("Initializing agent tools", "agent", agentCfg.ID)
176 defer func() {
177 slog.Info("Initialized agent tools", "agent", agentCfg.ID)
178 }()
179
180 cwd := cfg.WorkingDir()
181 allTools := []tools.BaseTool{
182 tools.NewBashTool(permissions, cwd, cfg.Options.Attribution),
183 tools.NewDownloadTool(permissions, cwd),
184 tools.NewEditTool(lspClients, permissions, history, cwd),
185 tools.NewMultiEditTool(lspClients, permissions, history, cwd),
186 tools.NewFetchTool(permissions, cwd),
187 tools.NewGlobTool(cwd),
188 tools.NewGrepTool(cwd),
189 tools.NewLsTool(permissions, cwd),
190 tools.NewSourcegraphTool(),
191 tools.NewViewTool(lspClients, permissions, cwd),
192 tools.NewWriteTool(lspClients, permissions, history, cwd),
193 }
194
195 mcpToolsOnce.Do(func() {
196 mcpTools = doGetMCPTools(ctx, permissions, cfg)
197 })
198
199 withCoderTools := func(t []tools.BaseTool) []tools.BaseTool {
200 if agentCfg.ID == "coder" {
201 t = append(t, mcpTools...)
202 if lspClients.Len() > 0 {
203 t = append(t, tools.NewDiagnosticsTool(lspClients))
204 }
205 }
206 return t
207 }
208
209 if agentCfg.AllowedTools == nil {
210 return withCoderTools(allTools)
211 }
212
213 var filteredTools []tools.BaseTool
214 for _, tool := range allTools {
215 if slices.Contains(agentCfg.AllowedTools, tool.Name()) {
216 filteredTools = append(filteredTools, tool)
217 }
218 }
219 return withCoderTools(filteredTools)
220 }
221
222 return &agent{
223 Broker: pubsub.NewBroker[AgentEvent](),
224 agentCfg: agentCfg,
225 provider: agentProvider,
226 providerID: string(providerCfg.ID),
227 messages: messages,
228 sessions: sessions,
229 titleProvider: titleProvider,
230 summarizeProvider: summarizeProvider,
231 summarizeProviderID: string(providerCfg.ID),
232 agentToolFn: agentToolFn,
233 activeRequests: csync.NewMap[string, context.CancelFunc](),
234 tools: csync.NewLazySlice(toolFn),
235 promptQueue: csync.NewMap[string, []string](),
236 permissions: permissions,
237 }, nil
238}
239
240func (a *agent) Model() catwalk.Model {
241 return *config.Get().GetModelByType(a.agentCfg.Model)
242}
243
244func (a *agent) Cancel(sessionID string) {
245 // Cancel regular requests
246 if cancel, ok := a.activeRequests.Take(sessionID); ok && cancel != nil {
247 slog.Info("Request cancellation initiated", "session_id", sessionID)
248 cancel()
249 }
250
251 // Also check for summarize requests
252 if cancel, ok := a.activeRequests.Take(sessionID + "-summarize"); ok && cancel != nil {
253 slog.Info("Summarize cancellation initiated", "session_id", sessionID)
254 cancel()
255 }
256
257 if a.QueuedPrompts(sessionID) > 0 {
258 slog.Info("Clearing queued prompts", "session_id", sessionID)
259 a.promptQueue.Del(sessionID)
260 }
261}
262
263func (a *agent) IsBusy() bool {
264 var busy bool
265 for cancelFunc := range a.activeRequests.Seq() {
266 if cancelFunc != nil {
267 busy = true
268 break
269 }
270 }
271 return busy
272}
273
274func (a *agent) IsSessionBusy(sessionID string) bool {
275 _, busy := a.activeRequests.Get(sessionID)
276 return busy
277}
278
279func (a *agent) QueuedPrompts(sessionID string) int {
280 l, ok := a.promptQueue.Get(sessionID)
281 if !ok {
282 return 0
283 }
284 return len(l)
285}
286
287func (a *agent) generateTitle(ctx context.Context, sessionID string, content string) error {
288 if content == "" {
289 return nil
290 }
291 if a.titleProvider == nil {
292 return nil
293 }
294 session, err := a.sessions.Get(ctx, sessionID)
295 if err != nil {
296 return err
297 }
298 parts := []message.ContentPart{message.TextContent{
299 Text: fmt.Sprintf("Generate a concise title for the following content:\n\n%s", content),
300 }}
301
302 // Use streaming approach like summarization
303 response := a.titleProvider.StreamResponse(
304 ctx,
305 []message.Message{
306 {
307 Role: message.User,
308 Parts: parts,
309 },
310 },
311 nil,
312 )
313
314 var finalResponse *provider.ProviderResponse
315 for r := range response {
316 if r.Error != nil {
317 return r.Error
318 }
319 finalResponse = r.Response
320 }
321
322 if finalResponse == nil {
323 return fmt.Errorf("no response received from title provider")
324 }
325
326 title := strings.ReplaceAll(finalResponse.Content, "\n", " ")
327
328 if idx := strings.Index(title, "</think>"); idx > 0 {
329 title = title[idx+len("</think>"):]
330 }
331
332 title = strings.TrimSpace(title)
333 if title == "" {
334 return nil
335 }
336
337 session.Title = title
338 _, err = a.sessions.Save(ctx, session)
339 return err
340}
341
342func (a *agent) err(err error) AgentEvent {
343 return AgentEvent{
344 Type: AgentEventTypeError,
345 Error: err,
346 }
347}
348
349func (a *agent) Run(ctx context.Context, sessionID string, content string, attachments ...message.Attachment) (<-chan AgentEvent, error) {
350 if !a.Model().SupportsImages && attachments != nil {
351 attachments = nil
352 }
353 events := make(chan AgentEvent, 1)
354 if a.IsSessionBusy(sessionID) {
355 existing, ok := a.promptQueue.Get(sessionID)
356 if !ok {
357 existing = []string{}
358 }
359 existing = append(existing, content)
360 a.promptQueue.Set(sessionID, existing)
361 return nil, nil
362 }
363
364 genCtx, cancel := context.WithCancel(ctx)
365 a.activeRequests.Set(sessionID, cancel)
366 startTime := time.Now()
367
368 go func() {
369 slog.Debug("Request started", "sessionID", sessionID)
370 defer log.RecoverPanic("agent.Run", func() {
371 events <- a.err(fmt.Errorf("panic while running the agent"))
372 })
373 var attachmentParts []message.ContentPart
374 for _, attachment := range attachments {
375 attachmentParts = append(attachmentParts, message.BinaryContent{Path: attachment.FilePath, MIMEType: attachment.MimeType, Data: attachment.Content})
376 }
377 result := a.processGeneration(genCtx, sessionID, content, attachmentParts)
378 if result.Error != nil {
379 if isCancelledErr(result.Error) {
380 slog.Error("Request canceled", "sessionID", sessionID)
381 } else {
382 slog.Error("Request errored", "sessionID", sessionID, "error", result.Error.Error())
383 event.Error(result.Error)
384 }
385 } else {
386 slog.Debug("Request completed", "sessionID", sessionID)
387 }
388 a.eventPromptResponded(sessionID, time.Since(startTime).Truncate(time.Second))
389 a.activeRequests.Del(sessionID)
390 cancel()
391 a.Publish(pubsub.CreatedEvent, result)
392 events <- result
393 close(events)
394 }()
395 a.eventPromptSent(sessionID)
396 return events, nil
397}
398
399func (a *agent) processGeneration(ctx context.Context, sessionID, content string, attachmentParts []message.ContentPart) AgentEvent {
400 cfg := config.Get()
401 // List existing messages; if none, start title generation asynchronously.
402 msgs, err := a.messages.List(ctx, sessionID)
403 if err != nil {
404 return a.err(fmt.Errorf("failed to list messages: %w", err))
405 }
406 if len(msgs) == 0 {
407 go func() {
408 defer log.RecoverPanic("agent.Run", func() {
409 slog.Error("panic while generating title")
410 })
411 titleErr := a.generateTitle(ctx, sessionID, content)
412 if titleErr != nil && !errors.Is(titleErr, context.Canceled) && !errors.Is(titleErr, context.DeadlineExceeded) {
413 slog.Error("failed to generate title", "error", titleErr)
414 }
415 }()
416 }
417 session, err := a.sessions.Get(ctx, sessionID)
418 if err != nil {
419 return a.err(fmt.Errorf("failed to get session: %w", err))
420 }
421 if session.SummaryMessageID != "" {
422 summaryMsgInex := -1
423 for i, msg := range msgs {
424 if msg.ID == session.SummaryMessageID {
425 summaryMsgInex = i
426 break
427 }
428 }
429 if summaryMsgInex != -1 {
430 msgs = msgs[summaryMsgInex:]
431 msgs[0].Role = message.User
432 }
433 }
434
435 userMsg, err := a.createUserMessage(ctx, sessionID, content, attachmentParts)
436 if err != nil {
437 return a.err(fmt.Errorf("failed to create user message: %w", err))
438 }
439 // Append the new user message to the conversation history.
440 msgHistory := append(msgs, userMsg)
441
442 for {
443 // Check for cancellation before each iteration
444 select {
445 case <-ctx.Done():
446 return a.err(ctx.Err())
447 default:
448 // Continue processing
449 }
450 agentMessage, toolResults, err := a.streamAndHandleEvents(ctx, sessionID, msgHistory)
451 if err != nil {
452 if errors.Is(err, context.Canceled) {
453 agentMessage.AddFinish(message.FinishReasonCanceled, "Request cancelled", "")
454 a.messages.Update(context.Background(), agentMessage)
455 return a.err(ErrRequestCancelled)
456 }
457 return a.err(fmt.Errorf("failed to process events: %w", err))
458 }
459 if cfg.Options.Debug {
460 slog.Info("Result", "message", agentMessage.FinishReason(), "toolResults", toolResults)
461 }
462 if (agentMessage.FinishReason() == message.FinishReasonToolUse) && toolResults != nil {
463 // We are not done, we need to respond with the tool response
464 msgHistory = append(msgHistory, agentMessage, *toolResults)
465 // If there are queued prompts, process the next one
466 nextPrompt, ok := a.promptQueue.Take(sessionID)
467 if ok {
468 for _, prompt := range nextPrompt {
469 // Create a new user message for the queued prompt
470 userMsg, err := a.createUserMessage(ctx, sessionID, prompt, nil)
471 if err != nil {
472 return a.err(fmt.Errorf("failed to create user message for queued prompt: %w", err))
473 }
474 // Append the new user message to the conversation history
475 msgHistory = append(msgHistory, userMsg)
476 }
477 }
478
479 continue
480 } else if agentMessage.FinishReason() == message.FinishReasonEndTurn {
481 queuePrompts, ok := a.promptQueue.Take(sessionID)
482 if ok {
483 for _, prompt := range queuePrompts {
484 if prompt == "" {
485 continue
486 }
487 userMsg, err := a.createUserMessage(ctx, sessionID, prompt, nil)
488 if err != nil {
489 return a.err(fmt.Errorf("failed to create user message for queued prompt: %w", err))
490 }
491 msgHistory = append(msgHistory, userMsg)
492 }
493 continue
494 }
495 }
496 if agentMessage.FinishReason() == "" {
497 // Kujtim: could not track down where this is happening but this means its cancelled
498 agentMessage.AddFinish(message.FinishReasonCanceled, "Request cancelled", "")
499 _ = a.messages.Update(context.Background(), agentMessage)
500 return a.err(ErrRequestCancelled)
501 }
502 return AgentEvent{
503 Type: AgentEventTypeResponse,
504 Message: agentMessage,
505 Done: true,
506 }
507 }
508}
509
510func (a *agent) createUserMessage(ctx context.Context, sessionID, content string, attachmentParts []message.ContentPart) (message.Message, error) {
511 parts := []message.ContentPart{message.TextContent{Text: content}}
512 parts = append(parts, attachmentParts...)
513 return a.messages.Create(ctx, sessionID, message.CreateMessageParams{
514 Role: message.User,
515 Parts: parts,
516 })
517}
518
519func (a *agent) getAllTools() ([]tools.BaseTool, error) {
520 allTools := slices.Collect(a.tools.Seq())
521 if a.agentToolFn != nil {
522 agentTool, agentToolErr := a.agentToolFn()
523 if agentToolErr != nil {
524 return nil, agentToolErr
525 }
526 allTools = append(allTools, agentTool)
527 }
528 return allTools, nil
529}
530
531func (a *agent) streamAndHandleEvents(ctx context.Context, sessionID string, msgHistory []message.Message) (message.Message, *message.Message, error) {
532 ctx = context.WithValue(ctx, tools.SessionIDContextKey, sessionID)
533
534 // Create the assistant message first so the spinner shows immediately
535 assistantMsg, err := a.messages.Create(ctx, sessionID, message.CreateMessageParams{
536 Role: message.Assistant,
537 Parts: []message.ContentPart{},
538 Model: a.Model().ID,
539 Provider: a.providerID,
540 })
541 if err != nil {
542 return assistantMsg, nil, fmt.Errorf("failed to create assistant message: %w", err)
543 }
544
545 allTools, toolsErr := a.getAllTools()
546 if toolsErr != nil {
547 return assistantMsg, nil, toolsErr
548 }
549 // Now collect tools (which may block on MCP initialization)
550 eventChan := a.provider.StreamResponse(ctx, msgHistory, allTools)
551
552 // Add the session and message ID into the context if needed by tools.
553 ctx = context.WithValue(ctx, tools.MessageIDContextKey, assistantMsg.ID)
554
555 // Process each event in the stream.
556 for event := range eventChan {
557 if processErr := a.processEvent(ctx, sessionID, &assistantMsg, event); processErr != nil {
558 if errors.Is(processErr, context.Canceled) {
559 a.finishMessage(context.Background(), &assistantMsg, message.FinishReasonCanceled, "Request cancelled", "")
560 } else {
561 a.finishMessage(ctx, &assistantMsg, message.FinishReasonError, "API Error", processErr.Error())
562 }
563 return assistantMsg, nil, processErr
564 }
565 if ctx.Err() != nil {
566 a.finishMessage(context.Background(), &assistantMsg, message.FinishReasonCanceled, "Request cancelled", "")
567 return assistantMsg, nil, ctx.Err()
568 }
569 }
570
571 toolResults := make([]message.ToolResult, len(assistantMsg.ToolCalls()))
572 toolCalls := assistantMsg.ToolCalls()
573 for i, toolCall := range toolCalls {
574 select {
575 case <-ctx.Done():
576 a.finishMessage(context.Background(), &assistantMsg, message.FinishReasonCanceled, "Request cancelled", "")
577 // Make all future tool calls cancelled
578 for j := i; j < len(toolCalls); j++ {
579 toolResults[j] = message.ToolResult{
580 ToolCallID: toolCalls[j].ID,
581 Content: "Tool execution canceled by user",
582 IsError: true,
583 }
584 }
585 goto out
586 default:
587 // Continue processing
588 var tool tools.BaseTool
589 allTools, _ := a.getAllTools()
590 for _, availableTool := range allTools {
591 if availableTool.Info().Name == toolCall.Name {
592 tool = availableTool
593 break
594 }
595 }
596
597 // Tool not found
598 if tool == nil {
599 toolResults[i] = message.ToolResult{
600 ToolCallID: toolCall.ID,
601 Content: fmt.Sprintf("Tool not found: %s", toolCall.Name),
602 IsError: true,
603 }
604 continue
605 }
606
607 // Run tool in goroutine to allow cancellation
608 type toolExecResult struct {
609 response tools.ToolResponse
610 err error
611 }
612 resultChan := make(chan toolExecResult, 1)
613
614 go func() {
615 response, err := tool.Run(ctx, tools.ToolCall{
616 ID: toolCall.ID,
617 Name: toolCall.Name,
618 Input: toolCall.Input,
619 })
620 resultChan <- toolExecResult{response: response, err: err}
621 }()
622
623 var toolResponse tools.ToolResponse
624 var toolErr error
625
626 select {
627 case <-ctx.Done():
628 a.finishMessage(context.Background(), &assistantMsg, message.FinishReasonCanceled, "Request cancelled", "")
629 // Mark remaining tool calls as cancelled
630 for j := i; j < len(toolCalls); j++ {
631 toolResults[j] = message.ToolResult{
632 ToolCallID: toolCalls[j].ID,
633 Content: "Tool execution canceled by user",
634 IsError: true,
635 }
636 }
637 goto out
638 case result := <-resultChan:
639 toolResponse = result.response
640 toolErr = result.err
641 }
642
643 if toolErr != nil {
644 slog.Error("Tool execution error", "toolCall", toolCall.ID, "error", toolErr)
645 if errors.Is(toolErr, permission.ErrorPermissionDenied) {
646 toolResults[i] = message.ToolResult{
647 ToolCallID: toolCall.ID,
648 Content: "Permission denied",
649 IsError: true,
650 }
651 for j := i + 1; j < len(toolCalls); j++ {
652 toolResults[j] = message.ToolResult{
653 ToolCallID: toolCalls[j].ID,
654 Content: "Tool execution canceled by user",
655 IsError: true,
656 }
657 }
658 a.finishMessage(ctx, &assistantMsg, message.FinishReasonPermissionDenied, "Permission denied", "")
659 break
660 }
661 }
662 toolResults[i] = message.ToolResult{
663 ToolCallID: toolCall.ID,
664 Content: toolResponse.Content,
665 Metadata: toolResponse.Metadata,
666 IsError: toolResponse.IsError,
667 }
668 }
669 }
670out:
671 if len(toolResults) == 0 {
672 return assistantMsg, nil, nil
673 }
674 parts := make([]message.ContentPart, 0)
675 for _, tr := range toolResults {
676 parts = append(parts, tr)
677 }
678 msg, err := a.messages.Create(context.Background(), assistantMsg.SessionID, message.CreateMessageParams{
679 Role: message.Tool,
680 Parts: parts,
681 Provider: a.providerID,
682 })
683 if err != nil {
684 return assistantMsg, nil, fmt.Errorf("failed to create cancelled tool message: %w", err)
685 }
686
687 return assistantMsg, &msg, err
688}
689
690func (a *agent) finishMessage(ctx context.Context, msg *message.Message, finishReason message.FinishReason, message, details string) {
691 msg.AddFinish(finishReason, message, details)
692 _ = a.messages.Update(ctx, *msg)
693}
694
695func (a *agent) processEvent(ctx context.Context, sessionID string, assistantMsg *message.Message, event provider.ProviderEvent) error {
696 select {
697 case <-ctx.Done():
698 return ctx.Err()
699 default:
700 // Continue processing.
701 }
702
703 switch event.Type {
704 case provider.EventThinkingDelta:
705 assistantMsg.AppendReasoningContent(event.Thinking)
706 return a.messages.Update(ctx, *assistantMsg)
707 case provider.EventSignatureDelta:
708 assistantMsg.AppendReasoningSignature(event.Signature)
709 return a.messages.Update(ctx, *assistantMsg)
710 case provider.EventContentDelta:
711 assistantMsg.FinishThinking()
712 assistantMsg.AppendContent(event.Content)
713 return a.messages.Update(ctx, *assistantMsg)
714 case provider.EventToolUseStart:
715 assistantMsg.FinishThinking()
716 slog.Info("Tool call started", "toolCall", event.ToolCall)
717 assistantMsg.AddToolCall(*event.ToolCall)
718 return a.messages.Update(ctx, *assistantMsg)
719 case provider.EventToolUseDelta:
720 assistantMsg.AppendToolCallInput(event.ToolCall.ID, event.ToolCall.Input)
721 return a.messages.Update(ctx, *assistantMsg)
722 case provider.EventToolUseStop:
723 slog.Info("Finished tool call", "toolCall", event.ToolCall)
724 assistantMsg.FinishToolCall(event.ToolCall.ID)
725 return a.messages.Update(ctx, *assistantMsg)
726 case provider.EventError:
727 return event.Error
728 case provider.EventComplete:
729 assistantMsg.FinishThinking()
730 assistantMsg.SetToolCalls(event.Response.ToolCalls)
731 assistantMsg.AddFinish(event.Response.FinishReason, "", "")
732 if err := a.messages.Update(ctx, *assistantMsg); err != nil {
733 return fmt.Errorf("failed to update message: %w", err)
734 }
735 return a.trackUsage(ctx, sessionID, a.Model(), event.Response.Usage)
736 }
737
738 return nil
739}
740
741func (a *agent) trackUsage(ctx context.Context, sessionID string, model catwalk.Model, usage provider.TokenUsage) error {
742 sess, err := a.sessions.Get(ctx, sessionID)
743 if err != nil {
744 return fmt.Errorf("failed to get session: %w", err)
745 }
746
747 cost := model.CostPer1MInCached/1e6*float64(usage.CacheCreationTokens) +
748 model.CostPer1MOutCached/1e6*float64(usage.CacheReadTokens) +
749 model.CostPer1MIn/1e6*float64(usage.InputTokens) +
750 model.CostPer1MOut/1e6*float64(usage.OutputTokens)
751
752 a.eventTokensUsed(sessionID, usage, cost)
753
754 sess.Cost += cost
755 sess.CompletionTokens = usage.OutputTokens + usage.CacheReadTokens
756 sess.PromptTokens = usage.InputTokens + usage.CacheCreationTokens
757
758 _, err = a.sessions.Save(ctx, sess)
759 if err != nil {
760 return fmt.Errorf("failed to save session: %w", err)
761 }
762 return nil
763}
764
765func (a *agent) Summarize(ctx context.Context, sessionID string) error {
766 if a.summarizeProvider == nil {
767 return fmt.Errorf("summarize provider not available")
768 }
769
770 // Check if session is busy
771 if a.IsSessionBusy(sessionID) {
772 return ErrSessionBusy
773 }
774
775 // Create a new context with cancellation
776 summarizeCtx, cancel := context.WithCancel(ctx)
777
778 // Store the cancel function in activeRequests to allow cancellation
779 a.activeRequests.Set(sessionID+"-summarize", cancel)
780
781 go func() {
782 defer a.activeRequests.Del(sessionID + "-summarize")
783 defer cancel()
784 event := AgentEvent{
785 Type: AgentEventTypeSummarize,
786 Progress: "Starting summarization...",
787 }
788
789 a.Publish(pubsub.CreatedEvent, event)
790 // Get all messages from the session
791 msgs, err := a.messages.List(summarizeCtx, sessionID)
792 if err != nil {
793 event = AgentEvent{
794 Type: AgentEventTypeError,
795 Error: fmt.Errorf("failed to list messages: %w", err),
796 Done: true,
797 }
798 a.Publish(pubsub.CreatedEvent, event)
799 return
800 }
801 summarizeCtx = context.WithValue(summarizeCtx, tools.SessionIDContextKey, sessionID)
802
803 if len(msgs) == 0 {
804 event = AgentEvent{
805 Type: AgentEventTypeError,
806 Error: fmt.Errorf("no messages to summarize"),
807 Done: true,
808 }
809 a.Publish(pubsub.CreatedEvent, event)
810 return
811 }
812
813 event = AgentEvent{
814 Type: AgentEventTypeSummarize,
815 Progress: "Analyzing conversation...",
816 }
817 a.Publish(pubsub.CreatedEvent, event)
818
819 // Add a system message to guide the summarization
820 summarizePrompt := "Provide a detailed but concise summary of our conversation above. Focus on information that would be helpful for continuing the conversation, including what we did, what we're doing, which files we're working on, and what we're going to do next."
821
822 // Create a new message with the summarize prompt
823 promptMsg := message.Message{
824 Role: message.User,
825 Parts: []message.ContentPart{message.TextContent{Text: summarizePrompt}},
826 }
827
828 // Append the prompt to the messages
829 msgsWithPrompt := append(msgs, promptMsg)
830
831 event = AgentEvent{
832 Type: AgentEventTypeSummarize,
833 Progress: "Generating summary...",
834 }
835
836 a.Publish(pubsub.CreatedEvent, event)
837
838 // Send the messages to the summarize provider
839 response := a.summarizeProvider.StreamResponse(
840 summarizeCtx,
841 msgsWithPrompt,
842 nil,
843 )
844 var finalResponse *provider.ProviderResponse
845 for r := range response {
846 if r.Error != nil {
847 event = AgentEvent{
848 Type: AgentEventTypeError,
849 Error: fmt.Errorf("failed to summarize: %w", r.Error),
850 Done: true,
851 }
852 a.Publish(pubsub.CreatedEvent, event)
853 return
854 }
855 finalResponse = r.Response
856 }
857
858 summary := strings.TrimSpace(finalResponse.Content)
859 if summary == "" {
860 event = AgentEvent{
861 Type: AgentEventTypeError,
862 Error: fmt.Errorf("empty summary returned"),
863 Done: true,
864 }
865 a.Publish(pubsub.CreatedEvent, event)
866 return
867 }
868 shell := shell.GetPersistentShell(config.Get().WorkingDir())
869 summary += "\n\n**Current working directory of the persistent shell**\n\n" + shell.GetWorkingDir()
870 event = AgentEvent{
871 Type: AgentEventTypeSummarize,
872 Progress: "Creating new session...",
873 }
874
875 a.Publish(pubsub.CreatedEvent, event)
876 oldSession, err := a.sessions.Get(summarizeCtx, sessionID)
877 if err != nil {
878 event = AgentEvent{
879 Type: AgentEventTypeError,
880 Error: fmt.Errorf("failed to get session: %w", err),
881 Done: true,
882 }
883
884 a.Publish(pubsub.CreatedEvent, event)
885 return
886 }
887 // Create a message in the new session with the summary
888 msg, err := a.messages.Create(summarizeCtx, oldSession.ID, message.CreateMessageParams{
889 Role: message.Assistant,
890 Parts: []message.ContentPart{
891 message.TextContent{Text: summary},
892 message.Finish{
893 Reason: message.FinishReasonEndTurn,
894 Time: time.Now().Unix(),
895 },
896 },
897 Model: a.summarizeProvider.Model().ID,
898 Provider: a.summarizeProviderID,
899 })
900 if err != nil {
901 event = AgentEvent{
902 Type: AgentEventTypeError,
903 Error: fmt.Errorf("failed to create summary message: %w", err),
904 Done: true,
905 }
906
907 a.Publish(pubsub.CreatedEvent, event)
908 return
909 }
910 oldSession.SummaryMessageID = msg.ID
911 oldSession.CompletionTokens = finalResponse.Usage.OutputTokens
912 oldSession.PromptTokens = 0
913 model := a.summarizeProvider.Model()
914 usage := finalResponse.Usage
915 cost := model.CostPer1MInCached/1e6*float64(usage.CacheCreationTokens) +
916 model.CostPer1MOutCached/1e6*float64(usage.CacheReadTokens) +
917 model.CostPer1MIn/1e6*float64(usage.InputTokens) +
918 model.CostPer1MOut/1e6*float64(usage.OutputTokens)
919 oldSession.Cost += cost
920 _, err = a.sessions.Save(summarizeCtx, oldSession)
921 if err != nil {
922 event = AgentEvent{
923 Type: AgentEventTypeError,
924 Error: fmt.Errorf("failed to save session: %w", err),
925 Done: true,
926 }
927 a.Publish(pubsub.CreatedEvent, event)
928 }
929
930 event = AgentEvent{
931 Type: AgentEventTypeSummarize,
932 SessionID: oldSession.ID,
933 Progress: "Summary complete",
934 Done: true,
935 }
936 a.Publish(pubsub.CreatedEvent, event)
937 // Send final success event with the new session ID
938 }()
939
940 return nil
941}
942
943func (a *agent) ClearQueue(sessionID string) {
944 if a.QueuedPrompts(sessionID) > 0 {
945 slog.Info("Clearing queued prompts", "session_id", sessionID)
946 a.promptQueue.Del(sessionID)
947 }
948}
949
950func (a *agent) CancelAll() {
951 if !a.IsBusy() {
952 return
953 }
954 for key := range a.activeRequests.Seq2() {
955 a.Cancel(key) // key is sessionID
956 }
957
958 timeout := time.After(5 * time.Second)
959 for a.IsBusy() {
960 select {
961 case <-timeout:
962 return
963 default:
964 time.Sleep(200 * time.Millisecond)
965 }
966 }
967}
968
969func (a *agent) UpdateModel() error {
970 cfg := config.Get()
971
972 // Get current provider configuration
973 currentProviderCfg := cfg.GetProviderForModel(a.agentCfg.Model)
974 if currentProviderCfg == nil || currentProviderCfg.ID == "" {
975 return fmt.Errorf("provider for agent %s not found in config", a.agentCfg.Name)
976 }
977
978 // Check if provider has changed
979 if string(currentProviderCfg.ID) != a.providerID {
980 // Provider changed, need to recreate the main provider
981 model := cfg.GetModelByType(a.agentCfg.Model)
982 if model.ID == "" {
983 return fmt.Errorf("model not found for agent %s", a.agentCfg.Name)
984 }
985
986 promptID := agentPromptMap[a.agentCfg.ID]
987 if promptID == "" {
988 promptID = prompt.PromptDefault
989 }
990
991 opts := []provider.ProviderClientOption{
992 provider.WithModel(a.agentCfg.Model),
993 provider.WithSystemMessage(prompt.GetPrompt(promptID, currentProviderCfg.ID, cfg.Options.ContextPaths...)),
994 }
995
996 newProvider, err := provider.NewProvider(*currentProviderCfg, opts...)
997 if err != nil {
998 return fmt.Errorf("failed to create new provider: %w", err)
999 }
1000
1001 // Update the provider and provider ID
1002 a.provider = newProvider
1003 a.providerID = string(currentProviderCfg.ID)
1004 }
1005
1006 // Check if providers have changed for title (small) and summarize (large)
1007 smallModelCfg := cfg.Models[config.SelectedModelTypeSmall]
1008 var smallModelProviderCfg config.ProviderConfig
1009 for p := range cfg.Providers.Seq() {
1010 if p.ID == smallModelCfg.Provider {
1011 smallModelProviderCfg = p
1012 break
1013 }
1014 }
1015 if smallModelProviderCfg.ID == "" {
1016 return fmt.Errorf("provider %s not found in config", smallModelCfg.Provider)
1017 }
1018
1019 largeModelCfg := cfg.Models[config.SelectedModelTypeLarge]
1020 var largeModelProviderCfg config.ProviderConfig
1021 for p := range cfg.Providers.Seq() {
1022 if p.ID == largeModelCfg.Provider {
1023 largeModelProviderCfg = p
1024 break
1025 }
1026 }
1027 if largeModelProviderCfg.ID == "" {
1028 return fmt.Errorf("provider %s not found in config", largeModelCfg.Provider)
1029 }
1030
1031 var maxTitleTokens int64 = 40
1032
1033 // if the max output is too low for the gemini provider it won't return anything
1034 if smallModelCfg.Provider == "gemini" {
1035 maxTitleTokens = 1000
1036 }
1037 // Recreate title provider
1038 titleOpts := []provider.ProviderClientOption{
1039 provider.WithModel(config.SelectedModelTypeSmall),
1040 provider.WithSystemMessage(prompt.GetPrompt(prompt.PromptTitle, smallModelProviderCfg.ID)),
1041 provider.WithMaxTokens(maxTitleTokens),
1042 }
1043 newTitleProvider, err := provider.NewProvider(smallModelProviderCfg, titleOpts...)
1044 if err != nil {
1045 return fmt.Errorf("failed to create new title provider: %w", err)
1046 }
1047 a.titleProvider = newTitleProvider
1048
1049 // Recreate summarize provider if provider changed (now large model)
1050 if string(largeModelProviderCfg.ID) != a.summarizeProviderID {
1051 largeModel := cfg.GetModelByType(config.SelectedModelTypeLarge)
1052 if largeModel == nil {
1053 return fmt.Errorf("model %s not found in provider %s", largeModelCfg.Model, largeModelProviderCfg.ID)
1054 }
1055 summarizeOpts := []provider.ProviderClientOption{
1056 provider.WithModel(config.SelectedModelTypeLarge),
1057 provider.WithSystemMessage(prompt.GetPrompt(prompt.PromptSummarizer, largeModelProviderCfg.ID)),
1058 }
1059 newSummarizeProvider, err := provider.NewProvider(largeModelProviderCfg, summarizeOpts...)
1060 if err != nil {
1061 return fmt.Errorf("failed to create new summarize provider: %w", err)
1062 }
1063 a.summarizeProvider = newSummarizeProvider
1064 a.summarizeProviderID = string(largeModelProviderCfg.ID)
1065 }
1066
1067 return nil
1068}