1package agent
2
3import (
4 "context"
5 "errors"
6 "fmt"
7 "log/slog"
8 "slices"
9 "strings"
10 "time"
11
12 "github.com/charmbracelet/catwalk/pkg/catwalk"
13 "github.com/charmbracelet/crush/internal/config"
14 "github.com/charmbracelet/crush/internal/csync"
15 "github.com/charmbracelet/crush/internal/history"
16 "github.com/charmbracelet/crush/internal/llm/prompt"
17 "github.com/charmbracelet/crush/internal/llm/provider"
18 "github.com/charmbracelet/crush/internal/llm/tools"
19 "github.com/charmbracelet/crush/internal/log"
20 "github.com/charmbracelet/crush/internal/lsp"
21 "github.com/charmbracelet/crush/internal/message"
22 "github.com/charmbracelet/crush/internal/permission"
23 "github.com/charmbracelet/crush/internal/pubsub"
24 "github.com/charmbracelet/crush/internal/session"
25 "github.com/charmbracelet/crush/internal/shell"
26)
27
// Common errors
var (
	// ErrRequestCancelled is reported (wrapped in an error AgentEvent) when a
	// request stops because its context was canceled.
	ErrRequestCancelled = errors.New("request canceled by user")
	// ErrSessionBusy is returned by Summarize when the session already has an
	// active request.
	ErrSessionBusy = errors.New("session is currently processing another request")
)
33
// AgentEventType identifies the kind of AgentEvent emitted by the agent.
type AgentEventType string

const (
	// AgentEventTypeError reports a failure; AgentEvent.Error holds the cause.
	AgentEventTypeError AgentEventType = "error"
	// AgentEventTypeResponse carries the final assistant message of a Run.
	AgentEventTypeResponse AgentEventType = "response"
	// AgentEventTypeSummarize reports summarization progress or completion.
	AgentEventTypeSummarize AgentEventType = "summarize"
)
41
// AgentEvent is sent on the channel returned by Run and/or published to the
// agent's subscribers. Type determines which of the other fields are set.
type AgentEvent struct {
	Type    AgentEventType
	Message message.Message // final assistant message (AgentEventTypeResponse)
	Error   error           // failure cause (AgentEventTypeError)

	// When summarizing
	SessionID string // summarized session; set on the final success event
	Progress  string // human-readable progress text
	Done      bool   // final event of a summarization; also set on Run responses
}
52
// Service is the public surface of an agent: it runs prompts against the
// configured model, manages per-session request state, and publishes
// AgentEvents to subscribers.
type Service interface {
	pubsub.Suscriber[AgentEvent]
	// Model returns the model currently configured for the agent.
	Model() catwalk.Model
	// Run processes content for a session in the background. If the session is
	// busy, the agent implementation queues content and returns (nil, nil).
	Run(ctx context.Context, sessionID string, content string, attachments ...message.Attachment) (<-chan AgentEvent, error)
	// Cancel stops the session's request and summarization and clears its queue.
	Cancel(sessionID string)
	// CancelAll cancels every active request and waits briefly for them to end.
	CancelAll()
	// IsSessionBusy reports whether the session has a regular request running.
	IsSessionBusy(sessionID string) bool
	// IsBusy reports whether any request is running on the agent.
	IsBusy() bool
	// Summarize starts an asynchronous summarization of the session.
	Summarize(ctx context.Context, sessionID string) error
	// UpdateModel rebuilds providers after a configuration change.
	UpdateModel() error
	// QueuedPrompts returns how many prompts are waiting for the session.
	QueuedPrompts(sessionID string) int
	// ClearQueue discards the session's queued prompts.
	ClearQueue(sessionID string)
}
66
// agent is the Service implementation built by NewAgent.
type agent struct {
	*pubsub.Broker[AgentEvent]
	agentCfg config.Agent
	sessions session.Service
	messages message.Service
	// NOTE(review): not read in this file; the tool closure in NewAgent uses
	// the package-level mcpTools instead — confirm whether this is still needed.
	mcpTools []McpTool

	// tools are built lazily on first use.
	tools *csync.LazySlice[tools.BaseTool]
	// We need this to be able to update it when model changes
	agentToolFn func() (tools.BaseTool, error)

	// provider serves the agent's own model; providerID is its config ID.
	provider   provider.Provider
	providerID string

	// titleProvider generates session titles (small model); summarizeProvider
	// produces conversation summaries (large model).
	titleProvider       provider.Provider
	summarizeProvider   provider.Provider
	summarizeProviderID string

	// activeRequests maps a session ID (or "<sessionID>-summarize") to the
	// cancel func of the request running for it.
	activeRequests *csync.Map[string, context.CancelFunc]

	// promptQueue holds prompts submitted while a session was busy.
	promptQueue *csync.Map[string, []string]
}
89
// agentPromptMap selects the system prompt for an agent by its ID. Agents not
// listed here fall back to prompt.PromptDefault.
var agentPromptMap = map[string]prompt.PromptID{
	"coder": prompt.PromptCoder,
	"task":  prompt.PromptTask,
}
94
95func NewAgent(
96 agentCfg config.Agent,
97 // These services are needed in the tools
98 permissions permission.Service,
99 sessions session.Service,
100 messages message.Service,
101 history history.Service,
102 lspClients map[string]*lsp.Client,
103) (Service, error) {
104 cfg := config.Get()
105
106 var agentToolFn func() (tools.BaseTool, error)
107 if agentCfg.ID == "coder" {
108 agentToolFn = func() (tools.BaseTool, error) {
109 taskAgentCfg := config.Get().Agents["task"]
110 if taskAgentCfg.ID == "" {
111 return nil, fmt.Errorf("task agent not found in config")
112 }
113 taskAgent, err := NewAgent(taskAgentCfg, permissions, sessions, messages, history, lspClients)
114 if err != nil {
115 return nil, fmt.Errorf("failed to create task agent: %w", err)
116 }
117 return NewAgentTool(taskAgent, sessions, messages), nil
118 }
119 }
120
121 providerCfg := config.Get().GetProviderForModel(agentCfg.Model)
122 if providerCfg == nil {
123 return nil, fmt.Errorf("provider for agent %s not found in config", agentCfg.Name)
124 }
125 model := config.Get().GetModelByType(agentCfg.Model)
126
127 if model == nil {
128 return nil, fmt.Errorf("model not found for agent %s", agentCfg.Name)
129 }
130
131 promptID := agentPromptMap[agentCfg.ID]
132 if promptID == "" {
133 promptID = prompt.PromptDefault
134 }
135 opts := []provider.ProviderClientOption{
136 provider.WithModel(agentCfg.Model),
137 provider.WithSystemMessage(prompt.GetPrompt(promptID, providerCfg.ID, config.Get().Options.ContextPaths...)),
138 }
139 agentProvider, err := provider.NewProvider(*providerCfg, opts...)
140 if err != nil {
141 return nil, err
142 }
143
144 smallModelCfg := cfg.Models[config.SelectedModelTypeSmall]
145 var smallModelProviderCfg *config.ProviderConfig
146 if smallModelCfg.Provider == providerCfg.ID {
147 smallModelProviderCfg = providerCfg
148 } else {
149 smallModelProviderCfg = cfg.GetProviderForModel(config.SelectedModelTypeSmall)
150
151 if smallModelProviderCfg.ID == "" {
152 return nil, fmt.Errorf("provider %s not found in config", smallModelCfg.Provider)
153 }
154 }
155 smallModel := cfg.GetModelByType(config.SelectedModelTypeSmall)
156 if smallModel.ID == "" {
157 return nil, fmt.Errorf("model %s not found in provider %s", smallModelCfg.Model, smallModelProviderCfg.ID)
158 }
159
160 titleOpts := []provider.ProviderClientOption{
161 provider.WithModel(config.SelectedModelTypeSmall),
162 provider.WithSystemMessage(prompt.GetPrompt(prompt.PromptTitle, smallModelProviderCfg.ID)),
163 }
164 titleProvider, err := provider.NewProvider(*smallModelProviderCfg, titleOpts...)
165 if err != nil {
166 return nil, err
167 }
168
169 summarizeOpts := []provider.ProviderClientOption{
170 provider.WithModel(config.SelectedModelTypeLarge),
171 provider.WithSystemMessage(prompt.GetPrompt(prompt.PromptSummarizer, providerCfg.ID)),
172 }
173 summarizeProvider, err := provider.NewProvider(*providerCfg, summarizeOpts...)
174 if err != nil {
175 return nil, err
176 }
177
178 toolFn := func() []tools.BaseTool {
179 slog.Info("Initializing agent tools", "agent", agentCfg.ID)
180 defer func() {
181 slog.Info("Initialized agent tools", "agent", agentCfg.ID)
182 }()
183
184 cwd := cfg.WorkingDir()
185 allTools := []tools.BaseTool{
186 tools.NewBashTool(permissions, cwd),
187 tools.NewDownloadTool(permissions, cwd),
188 tools.NewEditTool(lspClients, permissions, history, cwd),
189 tools.NewMultiEditTool(lspClients, permissions, history, cwd),
190 tools.NewFetchTool(permissions, cwd),
191 tools.NewGlobTool(cwd),
192 tools.NewGrepTool(cwd),
193 tools.NewLsTool(permissions, cwd),
194 tools.NewSourcegraphTool(),
195 tools.NewViewTool(lspClients, permissions, cwd),
196 tools.NewWriteTool(lspClients, permissions, history, cwd),
197 }
198
199 mcpToolsOnce.Do(func() {
200 mcpTools = doGetMCPTools(permissions, cfg)
201 })
202 allTools = append(allTools, mcpTools...)
203
204 if len(lspClients) > 0 {
205 allTools = append(allTools, tools.NewDiagnosticsTool(lspClients))
206 }
207
208 if agentCfg.AllowedTools == nil {
209 return allTools
210 }
211
212 var filteredTools []tools.BaseTool
213 for _, tool := range allTools {
214 if slices.Contains(agentCfg.AllowedTools, tool.Name()) {
215 filteredTools = append(filteredTools, tool)
216 }
217 }
218 return filteredTools
219 }
220
221 return &agent{
222 Broker: pubsub.NewBroker[AgentEvent](),
223 agentCfg: agentCfg,
224 provider: agentProvider,
225 providerID: string(providerCfg.ID),
226 messages: messages,
227 sessions: sessions,
228 titleProvider: titleProvider,
229 summarizeProvider: summarizeProvider,
230 summarizeProviderID: string(providerCfg.ID),
231 agentToolFn: agentToolFn,
232 activeRequests: csync.NewMap[string, context.CancelFunc](),
233 tools: csync.NewLazySlice(toolFn),
234 promptQueue: csync.NewMap[string, []string](),
235 }, nil
236}
237
238func (a *agent) Model() catwalk.Model {
239 return *config.Get().GetModelByType(a.agentCfg.Model)
240}
241
242func (a *agent) Cancel(sessionID string) {
243 // Cancel regular requests
244 if cancel, ok := a.activeRequests.Take(sessionID); ok && cancel != nil {
245 slog.Info("Request cancellation initiated", "session_id", sessionID)
246 cancel()
247 }
248
249 // Also check for summarize requests
250 if cancel, ok := a.activeRequests.Take(sessionID + "-summarize"); ok && cancel != nil {
251 slog.Info("Summarize cancellation initiated", "session_id", sessionID)
252 cancel()
253 }
254
255 if a.QueuedPrompts(sessionID) > 0 {
256 slog.Info("Clearing queued prompts", "session_id", sessionID)
257 a.promptQueue.Del(sessionID)
258 }
259}
260
261func (a *agent) IsBusy() bool {
262 var busy bool
263 for cancelFunc := range a.activeRequests.Seq() {
264 if cancelFunc != nil {
265 busy = true
266 break
267 }
268 }
269 return busy
270}
271
272func (a *agent) IsSessionBusy(sessionID string) bool {
273 _, busy := a.activeRequests.Get(sessionID)
274 return busy
275}
276
277func (a *agent) QueuedPrompts(sessionID string) int {
278 l, ok := a.promptQueue.Get(sessionID)
279 if !ok {
280 return 0
281 }
282 return len(l)
283}
284
285func (a *agent) generateTitle(ctx context.Context, sessionID string, content string) error {
286 if content == "" {
287 return nil
288 }
289 if a.titleProvider == nil {
290 return nil
291 }
292 session, err := a.sessions.Get(ctx, sessionID)
293 if err != nil {
294 return err
295 }
296 parts := []message.ContentPart{message.TextContent{
297 Text: fmt.Sprintf("Generate a concise title for the following content:\n\n%s", content),
298 }}
299
300 // Use streaming approach like summarization
301 response := a.titleProvider.StreamResponse(
302 ctx,
303 []message.Message{
304 {
305 Role: message.User,
306 Parts: parts,
307 },
308 },
309 nil,
310 )
311
312 var finalResponse *provider.ProviderResponse
313 for r := range response {
314 if r.Error != nil {
315 return r.Error
316 }
317 finalResponse = r.Response
318 }
319
320 if finalResponse == nil {
321 return fmt.Errorf("no response received from title provider")
322 }
323
324 title := strings.TrimSpace(strings.ReplaceAll(finalResponse.Content, "\n", " "))
325 if title == "" {
326 return nil
327 }
328
329 session.Title = title
330 _, err = a.sessions.Save(ctx, session)
331 return err
332}
333
334func (a *agent) err(err error) AgentEvent {
335 return AgentEvent{
336 Type: AgentEventTypeError,
337 Error: err,
338 }
339}
340
341func (a *agent) Run(ctx context.Context, sessionID string, content string, attachments ...message.Attachment) (<-chan AgentEvent, error) {
342 if !a.Model().SupportsImages && attachments != nil {
343 attachments = nil
344 }
345 events := make(chan AgentEvent, 1)
346 if a.IsSessionBusy(sessionID) {
347 existing, ok := a.promptQueue.Get(sessionID)
348 if !ok {
349 existing = []string{}
350 }
351 existing = append(existing, content)
352 a.promptQueue.Set(sessionID, existing)
353 return nil, nil
354 }
355
356 genCtx, cancel := context.WithCancel(ctx)
357
358 a.activeRequests.Set(sessionID, cancel)
359 go func() {
360 slog.Debug("Request started", "sessionID", sessionID)
361 defer log.RecoverPanic("agent.Run", func() {
362 events <- a.err(fmt.Errorf("panic while running the agent"))
363 })
364 var attachmentParts []message.ContentPart
365 for _, attachment := range attachments {
366 attachmentParts = append(attachmentParts, message.BinaryContent{Path: attachment.FilePath, MIMEType: attachment.MimeType, Data: attachment.Content})
367 }
368 result := a.processGeneration(genCtx, sessionID, content, attachmentParts)
369 if result.Error != nil && !errors.Is(result.Error, ErrRequestCancelled) && !errors.Is(result.Error, context.Canceled) {
370 slog.Error(result.Error.Error())
371 }
372 slog.Debug("Request completed", "sessionID", sessionID)
373 a.activeRequests.Del(sessionID)
374 cancel()
375 a.Publish(pubsub.CreatedEvent, result)
376 events <- result
377 close(events)
378 }()
379 return events, nil
380}
381
382func (a *agent) processGeneration(ctx context.Context, sessionID, content string, attachmentParts []message.ContentPart) AgentEvent {
383 cfg := config.Get()
384 // List existing messages; if none, start title generation asynchronously.
385 msgs, err := a.messages.List(ctx, sessionID)
386 if err != nil {
387 return a.err(fmt.Errorf("failed to list messages: %w", err))
388 }
389 if len(msgs) == 0 {
390 go func() {
391 defer log.RecoverPanic("agent.Run", func() {
392 slog.Error("panic while generating title")
393 })
394 titleErr := a.generateTitle(ctx, sessionID, content)
395 if titleErr != nil && !errors.Is(titleErr, context.Canceled) && !errors.Is(titleErr, context.DeadlineExceeded) {
396 slog.Error("failed to generate title", "error", titleErr)
397 }
398 }()
399 }
400 session, err := a.sessions.Get(ctx, sessionID)
401 if err != nil {
402 return a.err(fmt.Errorf("failed to get session: %w", err))
403 }
404 if session.SummaryMessageID != "" {
405 summaryMsgInex := -1
406 for i, msg := range msgs {
407 if msg.ID == session.SummaryMessageID {
408 summaryMsgInex = i
409 break
410 }
411 }
412 if summaryMsgInex != -1 {
413 msgs = msgs[summaryMsgInex:]
414 msgs[0].Role = message.User
415 }
416 }
417
418 userMsg, err := a.createUserMessage(ctx, sessionID, content, attachmentParts)
419 if err != nil {
420 return a.err(fmt.Errorf("failed to create user message: %w", err))
421 }
422 // Append the new user message to the conversation history.
423 msgHistory := append(msgs, userMsg)
424
425 for {
426 // Check for cancellation before each iteration
427 select {
428 case <-ctx.Done():
429 return a.err(ctx.Err())
430 default:
431 // Continue processing
432 }
433 agentMessage, toolResults, err := a.streamAndHandleEvents(ctx, sessionID, msgHistory)
434 if err != nil {
435 if errors.Is(err, context.Canceled) {
436 agentMessage.AddFinish(message.FinishReasonCanceled, "Request cancelled", "")
437 a.messages.Update(context.Background(), agentMessage)
438 return a.err(ErrRequestCancelled)
439 }
440 return a.err(fmt.Errorf("failed to process events: %w", err))
441 }
442 if cfg.Options.Debug {
443 slog.Info("Result", "message", agentMessage.FinishReason(), "toolResults", toolResults)
444 }
445 if (agentMessage.FinishReason() == message.FinishReasonToolUse) && toolResults != nil {
446 // We are not done, we need to respond with the tool response
447 msgHistory = append(msgHistory, agentMessage, *toolResults)
448 // If there are queued prompts, process the next one
449 nextPrompt, ok := a.promptQueue.Take(sessionID)
450 if ok {
451 for _, prompt := range nextPrompt {
452 // Create a new user message for the queued prompt
453 userMsg, err := a.createUserMessage(ctx, sessionID, prompt, nil)
454 if err != nil {
455 return a.err(fmt.Errorf("failed to create user message for queued prompt: %w", err))
456 }
457 // Append the new user message to the conversation history
458 msgHistory = append(msgHistory, userMsg)
459 }
460 }
461
462 continue
463 } else if agentMessage.FinishReason() == message.FinishReasonEndTurn {
464 queuePrompts, ok := a.promptQueue.Take(sessionID)
465 if ok {
466 for _, prompt := range queuePrompts {
467 if prompt == "" {
468 continue
469 }
470 userMsg, err := a.createUserMessage(ctx, sessionID, prompt, nil)
471 if err != nil {
472 return a.err(fmt.Errorf("failed to create user message for queued prompt: %w", err))
473 }
474 msgHistory = append(msgHistory, userMsg)
475 }
476 continue
477 }
478 }
479 if agentMessage.FinishReason() == "" {
480 // Kujtim: could not track down where this is happening but this means its cancelled
481 agentMessage.AddFinish(message.FinishReasonCanceled, "Request cancelled", "")
482 _ = a.messages.Update(context.Background(), agentMessage)
483 return a.err(ErrRequestCancelled)
484 }
485 return AgentEvent{
486 Type: AgentEventTypeResponse,
487 Message: agentMessage,
488 Done: true,
489 }
490 }
491}
492
493func (a *agent) createUserMessage(ctx context.Context, sessionID, content string, attachmentParts []message.ContentPart) (message.Message, error) {
494 parts := []message.ContentPart{message.TextContent{Text: content}}
495 parts = append(parts, attachmentParts...)
496 return a.messages.Create(ctx, sessionID, message.CreateMessageParams{
497 Role: message.User,
498 Parts: parts,
499 })
500}
501
502func (a *agent) getAllTools() ([]tools.BaseTool, error) {
503 allTools := slices.Collect(a.tools.Seq())
504 if a.agentToolFn != nil {
505 agentTool, agentToolErr := a.agentToolFn()
506 if agentToolErr != nil {
507 return nil, agentToolErr
508 }
509 allTools = append(allTools, agentTool)
510 }
511 return allTools, nil
512}
513
514func (a *agent) streamAndHandleEvents(ctx context.Context, sessionID string, msgHistory []message.Message) (message.Message, *message.Message, error) {
515 ctx = context.WithValue(ctx, tools.SessionIDContextKey, sessionID)
516
517 // Create the assistant message first so the spinner shows immediately
518 assistantMsg, err := a.messages.Create(ctx, sessionID, message.CreateMessageParams{
519 Role: message.Assistant,
520 Parts: []message.ContentPart{},
521 Model: a.Model().ID,
522 Provider: a.providerID,
523 })
524 if err != nil {
525 return assistantMsg, nil, fmt.Errorf("failed to create assistant message: %w", err)
526 }
527
528 allTools, toolsErr := a.getAllTools()
529 if toolsErr != nil {
530 return assistantMsg, nil, toolsErr
531 }
532 // Now collect tools (which may block on MCP initialization)
533 eventChan := a.provider.StreamResponse(ctx, msgHistory, allTools)
534
535 // Add the session and message ID into the context if needed by tools.
536 ctx = context.WithValue(ctx, tools.MessageIDContextKey, assistantMsg.ID)
537
538 // Process each event in the stream.
539 for event := range eventChan {
540 if processErr := a.processEvent(ctx, sessionID, &assistantMsg, event); processErr != nil {
541 if errors.Is(processErr, context.Canceled) {
542 a.finishMessage(context.Background(), &assistantMsg, message.FinishReasonCanceled, "Request cancelled", "")
543 } else {
544 a.finishMessage(ctx, &assistantMsg, message.FinishReasonError, "API Error", processErr.Error())
545 }
546 return assistantMsg, nil, processErr
547 }
548 if ctx.Err() != nil {
549 a.finishMessage(context.Background(), &assistantMsg, message.FinishReasonCanceled, "Request cancelled", "")
550 return assistantMsg, nil, ctx.Err()
551 }
552 }
553
554 toolResults := make([]message.ToolResult, len(assistantMsg.ToolCalls()))
555 toolCalls := assistantMsg.ToolCalls()
556 for i, toolCall := range toolCalls {
557 select {
558 case <-ctx.Done():
559 a.finishMessage(context.Background(), &assistantMsg, message.FinishReasonCanceled, "Request cancelled", "")
560 // Make all future tool calls cancelled
561 for j := i; j < len(toolCalls); j++ {
562 toolResults[j] = message.ToolResult{
563 ToolCallID: toolCalls[j].ID,
564 Content: "Tool execution canceled by user",
565 IsError: true,
566 }
567 }
568 goto out
569 default:
570 // Continue processing
571 var tool tools.BaseTool
572 allTools, _ := a.getAllTools()
573 for _, availableTool := range allTools {
574 if availableTool.Info().Name == toolCall.Name {
575 tool = availableTool
576 break
577 }
578 }
579
580 // Tool not found
581 if tool == nil {
582 toolResults[i] = message.ToolResult{
583 ToolCallID: toolCall.ID,
584 Content: fmt.Sprintf("Tool not found: %s", toolCall.Name),
585 IsError: true,
586 }
587 continue
588 }
589
590 // Run tool in goroutine to allow cancellation
591 type toolExecResult struct {
592 response tools.ToolResponse
593 err error
594 }
595 resultChan := make(chan toolExecResult, 1)
596
597 go func() {
598 response, err := tool.Run(ctx, tools.ToolCall{
599 ID: toolCall.ID,
600 Name: toolCall.Name,
601 Input: toolCall.Input,
602 })
603 resultChan <- toolExecResult{response: response, err: err}
604 }()
605
606 var toolResponse tools.ToolResponse
607 var toolErr error
608
609 select {
610 case <-ctx.Done():
611 a.finishMessage(context.Background(), &assistantMsg, message.FinishReasonCanceled, "Request cancelled", "")
612 // Mark remaining tool calls as cancelled
613 for j := i; j < len(toolCalls); j++ {
614 toolResults[j] = message.ToolResult{
615 ToolCallID: toolCalls[j].ID,
616 Content: "Tool execution canceled by user",
617 IsError: true,
618 }
619 }
620 goto out
621 case result := <-resultChan:
622 toolResponse = result.response
623 toolErr = result.err
624 }
625
626 if toolErr != nil {
627 slog.Error("Tool execution error", "toolCall", toolCall.ID, "error", toolErr)
628 if errors.Is(toolErr, permission.ErrorPermissionDenied) {
629 toolResults[i] = message.ToolResult{
630 ToolCallID: toolCall.ID,
631 Content: "Permission denied",
632 IsError: true,
633 }
634 for j := i + 1; j < len(toolCalls); j++ {
635 toolResults[j] = message.ToolResult{
636 ToolCallID: toolCalls[j].ID,
637 Content: "Tool execution canceled by user",
638 IsError: true,
639 }
640 }
641 a.finishMessage(ctx, &assistantMsg, message.FinishReasonPermissionDenied, "Permission denied", "")
642 break
643 }
644 }
645 toolResults[i] = message.ToolResult{
646 ToolCallID: toolCall.ID,
647 Content: toolResponse.Content,
648 Metadata: toolResponse.Metadata,
649 IsError: toolResponse.IsError,
650 }
651 }
652 }
653out:
654 if len(toolResults) == 0 {
655 return assistantMsg, nil, nil
656 }
657 parts := make([]message.ContentPart, 0)
658 for _, tr := range toolResults {
659 parts = append(parts, tr)
660 }
661 msg, err := a.messages.Create(context.Background(), assistantMsg.SessionID, message.CreateMessageParams{
662 Role: message.Tool,
663 Parts: parts,
664 Provider: a.providerID,
665 })
666 if err != nil {
667 return assistantMsg, nil, fmt.Errorf("failed to create cancelled tool message: %w", err)
668 }
669
670 return assistantMsg, &msg, err
671}
672
673func (a *agent) finishMessage(ctx context.Context, msg *message.Message, finishReason message.FinishReason, message, details string) {
674 msg.AddFinish(finishReason, message, details)
675 _ = a.messages.Update(ctx, *msg)
676}
677
678func (a *agent) processEvent(ctx context.Context, sessionID string, assistantMsg *message.Message, event provider.ProviderEvent) error {
679 select {
680 case <-ctx.Done():
681 return ctx.Err()
682 default:
683 // Continue processing.
684 }
685
686 switch event.Type {
687 case provider.EventThinkingDelta:
688 assistantMsg.AppendReasoningContent(event.Thinking)
689 return a.messages.Update(ctx, *assistantMsg)
690 case provider.EventSignatureDelta:
691 assistantMsg.AppendReasoningSignature(event.Signature)
692 return a.messages.Update(ctx, *assistantMsg)
693 case provider.EventContentDelta:
694 assistantMsg.FinishThinking()
695 assistantMsg.AppendContent(event.Content)
696 return a.messages.Update(ctx, *assistantMsg)
697 case provider.EventToolUseStart:
698 assistantMsg.FinishThinking()
699 slog.Info("Tool call started", "toolCall", event.ToolCall)
700 assistantMsg.AddToolCall(*event.ToolCall)
701 return a.messages.Update(ctx, *assistantMsg)
702 case provider.EventToolUseDelta:
703 assistantMsg.AppendToolCallInput(event.ToolCall.ID, event.ToolCall.Input)
704 return a.messages.Update(ctx, *assistantMsg)
705 case provider.EventToolUseStop:
706 slog.Info("Finished tool call", "toolCall", event.ToolCall)
707 assistantMsg.FinishToolCall(event.ToolCall.ID)
708 return a.messages.Update(ctx, *assistantMsg)
709 case provider.EventError:
710 return event.Error
711 case provider.EventComplete:
712 assistantMsg.FinishThinking()
713 assistantMsg.SetToolCalls(event.Response.ToolCalls)
714 assistantMsg.AddFinish(event.Response.FinishReason, "", "")
715 if err := a.messages.Update(ctx, *assistantMsg); err != nil {
716 return fmt.Errorf("failed to update message: %w", err)
717 }
718 return a.TrackUsage(ctx, sessionID, a.Model(), event.Response.Usage)
719 }
720
721 return nil
722}
723
724func (a *agent) TrackUsage(ctx context.Context, sessionID string, model catwalk.Model, usage provider.TokenUsage) error {
725 sess, err := a.sessions.Get(ctx, sessionID)
726 if err != nil {
727 return fmt.Errorf("failed to get session: %w", err)
728 }
729
730 cost := model.CostPer1MInCached/1e6*float64(usage.CacheCreationTokens) +
731 model.CostPer1MOutCached/1e6*float64(usage.CacheReadTokens) +
732 model.CostPer1MIn/1e6*float64(usage.InputTokens) +
733 model.CostPer1MOut/1e6*float64(usage.OutputTokens)
734
735 sess.Cost += cost
736 sess.CompletionTokens = usage.OutputTokens + usage.CacheReadTokens
737 sess.PromptTokens = usage.InputTokens + usage.CacheCreationTokens
738
739 _, err = a.sessions.Save(ctx, sess)
740 if err != nil {
741 return fmt.Errorf("failed to save session: %w", err)
742 }
743 return nil
744}
745
746func (a *agent) Summarize(ctx context.Context, sessionID string) error {
747 if a.summarizeProvider == nil {
748 return fmt.Errorf("summarize provider not available")
749 }
750
751 // Check if session is busy
752 if a.IsSessionBusy(sessionID) {
753 return ErrSessionBusy
754 }
755
756 // Create a new context with cancellation
757 summarizeCtx, cancel := context.WithCancel(ctx)
758
759 // Store the cancel function in activeRequests to allow cancellation
760 a.activeRequests.Set(sessionID+"-summarize", cancel)
761
762 go func() {
763 defer a.activeRequests.Del(sessionID + "-summarize")
764 defer cancel()
765 event := AgentEvent{
766 Type: AgentEventTypeSummarize,
767 Progress: "Starting summarization...",
768 }
769
770 a.Publish(pubsub.CreatedEvent, event)
771 // Get all messages from the session
772 msgs, err := a.messages.List(summarizeCtx, sessionID)
773 if err != nil {
774 event = AgentEvent{
775 Type: AgentEventTypeError,
776 Error: fmt.Errorf("failed to list messages: %w", err),
777 Done: true,
778 }
779 a.Publish(pubsub.CreatedEvent, event)
780 return
781 }
782 summarizeCtx = context.WithValue(summarizeCtx, tools.SessionIDContextKey, sessionID)
783
784 if len(msgs) == 0 {
785 event = AgentEvent{
786 Type: AgentEventTypeError,
787 Error: fmt.Errorf("no messages to summarize"),
788 Done: true,
789 }
790 a.Publish(pubsub.CreatedEvent, event)
791 return
792 }
793
794 event = AgentEvent{
795 Type: AgentEventTypeSummarize,
796 Progress: "Analyzing conversation...",
797 }
798 a.Publish(pubsub.CreatedEvent, event)
799
800 // Add a system message to guide the summarization
801 summarizePrompt := "Provide a detailed but concise summary of our conversation above. Focus on information that would be helpful for continuing the conversation, including what we did, what we're doing, which files we're working on, and what we're going to do next."
802
803 // Create a new message with the summarize prompt
804 promptMsg := message.Message{
805 Role: message.User,
806 Parts: []message.ContentPart{message.TextContent{Text: summarizePrompt}},
807 }
808
809 // Append the prompt to the messages
810 msgsWithPrompt := append(msgs, promptMsg)
811
812 event = AgentEvent{
813 Type: AgentEventTypeSummarize,
814 Progress: "Generating summary...",
815 }
816
817 a.Publish(pubsub.CreatedEvent, event)
818
819 // Send the messages to the summarize provider
820 response := a.summarizeProvider.StreamResponse(
821 summarizeCtx,
822 msgsWithPrompt,
823 nil,
824 )
825 var finalResponse *provider.ProviderResponse
826 for r := range response {
827 if r.Error != nil {
828 event = AgentEvent{
829 Type: AgentEventTypeError,
830 Error: fmt.Errorf("failed to summarize: %w", err),
831 Done: true,
832 }
833 a.Publish(pubsub.CreatedEvent, event)
834 return
835 }
836 finalResponse = r.Response
837 }
838
839 summary := strings.TrimSpace(finalResponse.Content)
840 if summary == "" {
841 event = AgentEvent{
842 Type: AgentEventTypeError,
843 Error: fmt.Errorf("empty summary returned"),
844 Done: true,
845 }
846 a.Publish(pubsub.CreatedEvent, event)
847 return
848 }
849 shell := shell.GetPersistentShell(config.Get().WorkingDir())
850 summary += "\n\n**Current working directory of the persistent shell**\n\n" + shell.GetWorkingDir()
851 event = AgentEvent{
852 Type: AgentEventTypeSummarize,
853 Progress: "Creating new session...",
854 }
855
856 a.Publish(pubsub.CreatedEvent, event)
857 oldSession, err := a.sessions.Get(summarizeCtx, sessionID)
858 if err != nil {
859 event = AgentEvent{
860 Type: AgentEventTypeError,
861 Error: fmt.Errorf("failed to get session: %w", err),
862 Done: true,
863 }
864
865 a.Publish(pubsub.CreatedEvent, event)
866 return
867 }
868 // Create a message in the new session with the summary
869 msg, err := a.messages.Create(summarizeCtx, oldSession.ID, message.CreateMessageParams{
870 Role: message.Assistant,
871 Parts: []message.ContentPart{
872 message.TextContent{Text: summary},
873 message.Finish{
874 Reason: message.FinishReasonEndTurn,
875 Time: time.Now().Unix(),
876 },
877 },
878 Model: a.summarizeProvider.Model().ID,
879 Provider: a.summarizeProviderID,
880 })
881 if err != nil {
882 event = AgentEvent{
883 Type: AgentEventTypeError,
884 Error: fmt.Errorf("failed to create summary message: %w", err),
885 Done: true,
886 }
887
888 a.Publish(pubsub.CreatedEvent, event)
889 return
890 }
891 oldSession.SummaryMessageID = msg.ID
892 oldSession.CompletionTokens = finalResponse.Usage.OutputTokens
893 oldSession.PromptTokens = 0
894 model := a.summarizeProvider.Model()
895 usage := finalResponse.Usage
896 cost := model.CostPer1MInCached/1e6*float64(usage.CacheCreationTokens) +
897 model.CostPer1MOutCached/1e6*float64(usage.CacheReadTokens) +
898 model.CostPer1MIn/1e6*float64(usage.InputTokens) +
899 model.CostPer1MOut/1e6*float64(usage.OutputTokens)
900 oldSession.Cost += cost
901 _, err = a.sessions.Save(summarizeCtx, oldSession)
902 if err != nil {
903 event = AgentEvent{
904 Type: AgentEventTypeError,
905 Error: fmt.Errorf("failed to save session: %w", err),
906 Done: true,
907 }
908 a.Publish(pubsub.CreatedEvent, event)
909 }
910
911 event = AgentEvent{
912 Type: AgentEventTypeSummarize,
913 SessionID: oldSession.ID,
914 Progress: "Summary complete",
915 Done: true,
916 }
917 a.Publish(pubsub.CreatedEvent, event)
918 // Send final success event with the new session ID
919 }()
920
921 return nil
922}
923
924func (a *agent) ClearQueue(sessionID string) {
925 if a.QueuedPrompts(sessionID) > 0 {
926 slog.Info("Clearing queued prompts", "session_id", sessionID)
927 a.promptQueue.Del(sessionID)
928 }
929}
930
931func (a *agent) CancelAll() {
932 if !a.IsBusy() {
933 return
934 }
935 for key := range a.activeRequests.Seq2() {
936 a.Cancel(key) // key is sessionID
937 }
938
939 timeout := time.After(5 * time.Second)
940 for a.IsBusy() {
941 select {
942 case <-timeout:
943 return
944 default:
945 time.Sleep(200 * time.Millisecond)
946 }
947 }
948}
949
950func (a *agent) UpdateModel() error {
951 cfg := config.Get()
952
953 // Get current provider configuration
954 currentProviderCfg := cfg.GetProviderForModel(a.agentCfg.Model)
955 if currentProviderCfg == nil || currentProviderCfg.ID == "" {
956 return fmt.Errorf("provider for agent %s not found in config", a.agentCfg.Name)
957 }
958
959 // Check if provider has changed
960 if string(currentProviderCfg.ID) != a.providerID {
961 // Provider changed, need to recreate the main provider
962 model := cfg.GetModelByType(a.agentCfg.Model)
963 if model.ID == "" {
964 return fmt.Errorf("model not found for agent %s", a.agentCfg.Name)
965 }
966
967 promptID := agentPromptMap[a.agentCfg.ID]
968 if promptID == "" {
969 promptID = prompt.PromptDefault
970 }
971
972 opts := []provider.ProviderClientOption{
973 provider.WithModel(a.agentCfg.Model),
974 provider.WithSystemMessage(prompt.GetPrompt(promptID, currentProviderCfg.ID, cfg.Options.ContextPaths...)),
975 }
976
977 newProvider, err := provider.NewProvider(*currentProviderCfg, opts...)
978 if err != nil {
979 return fmt.Errorf("failed to create new provider: %w", err)
980 }
981
982 // Update the provider and provider ID
983 a.provider = newProvider
984 a.providerID = string(currentProviderCfg.ID)
985 }
986
987 // Check if providers have changed for title (small) and summarize (large)
988 smallModelCfg := cfg.Models[config.SelectedModelTypeSmall]
989 var smallModelProviderCfg config.ProviderConfig
990 for p := range cfg.Providers.Seq() {
991 if p.ID == smallModelCfg.Provider {
992 smallModelProviderCfg = p
993 break
994 }
995 }
996 if smallModelProviderCfg.ID == "" {
997 return fmt.Errorf("provider %s not found in config", smallModelCfg.Provider)
998 }
999
1000 largeModelCfg := cfg.Models[config.SelectedModelTypeLarge]
1001 var largeModelProviderCfg config.ProviderConfig
1002 for p := range cfg.Providers.Seq() {
1003 if p.ID == largeModelCfg.Provider {
1004 largeModelProviderCfg = p
1005 break
1006 }
1007 }
1008 if largeModelProviderCfg.ID == "" {
1009 return fmt.Errorf("provider %s not found in config", largeModelCfg.Provider)
1010 }
1011
1012 var maxTitleTokens int64 = 40
1013
1014 // if the max output is too low for the gemini provider it won't return anything
1015 if smallModelCfg.Provider == "gemini" {
1016 maxTitleTokens = 1000
1017 }
1018 // Recreate title provider
1019 titleOpts := []provider.ProviderClientOption{
1020 provider.WithModel(config.SelectedModelTypeSmall),
1021 provider.WithSystemMessage(prompt.GetPrompt(prompt.PromptTitle, smallModelProviderCfg.ID)),
1022 provider.WithMaxTokens(maxTitleTokens),
1023 }
1024 newTitleProvider, err := provider.NewProvider(smallModelProviderCfg, titleOpts...)
1025 if err != nil {
1026 return fmt.Errorf("failed to create new title provider: %w", err)
1027 }
1028 a.titleProvider = newTitleProvider
1029
1030 // Recreate summarize provider if provider changed (now large model)
1031 if string(largeModelProviderCfg.ID) != a.summarizeProviderID {
1032 largeModel := cfg.GetModelByType(config.SelectedModelTypeLarge)
1033 if largeModel == nil {
1034 return fmt.Errorf("model %s not found in provider %s", largeModelCfg.Model, largeModelProviderCfg.ID)
1035 }
1036 summarizeOpts := []provider.ProviderClientOption{
1037 provider.WithModel(config.SelectedModelTypeLarge),
1038 provider.WithSystemMessage(prompt.GetPrompt(prompt.PromptSummarizer, largeModelProviderCfg.ID)),
1039 }
1040 newSummarizeProvider, err := provider.NewProvider(largeModelProviderCfg, summarizeOpts...)
1041 if err != nil {
1042 return fmt.Errorf("failed to create new summarize provider: %w", err)
1043 }
1044 a.summarizeProvider = newSummarizeProvider
1045 a.summarizeProviderID = string(largeModelProviderCfg.ID)
1046 }
1047
1048 return nil
1049}