1package agent
2
3import (
4 "context"
5 "errors"
6 "fmt"
7 "log/slog"
8 "slices"
9 "strings"
10 "time"
11
12 "github.com/charmbracelet/catwalk/pkg/catwalk"
13 "github.com/charmbracelet/crush/internal/config"
14 "github.com/charmbracelet/crush/internal/csync"
15 "github.com/charmbracelet/crush/internal/history"
16 "github.com/charmbracelet/crush/internal/llm/prompt"
17 "github.com/charmbracelet/crush/internal/llm/provider"
18 "github.com/charmbracelet/crush/internal/llm/tools"
19 "github.com/charmbracelet/crush/internal/log"
20 "github.com/charmbracelet/crush/internal/lsp"
21 "github.com/charmbracelet/crush/internal/message"
22 "github.com/charmbracelet/crush/internal/permission"
23 "github.com/charmbracelet/crush/internal/pubsub"
24 "github.com/charmbracelet/crush/internal/session"
25 "github.com/charmbracelet/crush/internal/shell"
26)
27
// Sentinel errors returned or reported by the agent.

// ErrRequestCancelled is reported when a generation is stopped because the
// request was cancelled.
var ErrRequestCancelled = errors.New("request canceled by user")

// ErrSessionBusy is returned when an operation cannot start because the
// session already has a request in flight.
var ErrSessionBusy = errors.New("session is currently processing another request")
33
// AgentEventType names the kind of event carried by an AgentEvent.
type AgentEventType string

// Event kinds published by the agent.
const (
	// AgentEventTypeError marks an event that carries a failure.
	AgentEventTypeError = AgentEventType("error")
	// AgentEventTypeResponse marks an event that carries the final assistant message.
	AgentEventTypeResponse = AgentEventType("response")
	// AgentEventTypeSummarize marks a progress or completion event of a summarization.
	AgentEventTypeSummarize = AgentEventType("summarize")
)
41
42type AgentEvent struct {
43 Type AgentEventType
44 Message message.Message
45 Error error
46
47 // When summarizing
48 SessionID string
49 Progress string
50 Done bool
51}
52
53type Service interface {
54 pubsub.Suscriber[AgentEvent]
55 Model() catwalk.Model
56 Run(ctx context.Context, sessionID string, content string, attachments ...message.Attachment) (<-chan AgentEvent, error)
57 Cancel(sessionID string)
58 CancelAll()
59 IsSessionBusy(sessionID string) bool
60 IsBusy() bool
61 Summarize(ctx context.Context, sessionID string) error
62 UpdateModel() error
63 QueuedPrompts(sessionID string) int
64 ClearQueue(sessionID string)
65}
66
67type agent struct {
68 *pubsub.Broker[AgentEvent]
69 agentCfg config.Agent
70 sessions session.Service
71 messages message.Service
72 mcpTools []McpTool
73
74 tools *csync.LazySlice[tools.BaseTool]
75 // We need this to be able to update it when model changes
76 agentToolFn func() (tools.BaseTool, error)
77
78 provider provider.Provider
79 providerID string
80
81 titleProvider provider.Provider
82 summarizeProvider provider.Provider
83 summarizeProviderID string
84
85 activeRequests *csync.Map[string, context.CancelFunc]
86 promptQueue *csync.Map[string, []string]
87}
88
89var agentPromptMap = map[string]prompt.PromptID{
90 "coder": prompt.PromptCoder,
91 "task": prompt.PromptTask,
92}
93
94func NewAgent(
95 ctx context.Context,
96 agentCfg config.Agent,
97 // These services are needed in the tools
98 permissions permission.Service,
99 sessions session.Service,
100 messages message.Service,
101 history history.Service,
102 lspClients *csync.Map[string, *lsp.Client],
103) (Service, error) {
104 cfg := config.Get()
105
106 var agentToolFn func() (tools.BaseTool, error)
107 if agentCfg.ID == "coder" {
108 agentToolFn = func() (tools.BaseTool, error) {
109 taskAgentCfg := config.Get().Agents["task"]
110 if taskAgentCfg.ID == "" {
111 return nil, fmt.Errorf("task agent not found in config")
112 }
113 taskAgent, err := NewAgent(ctx, taskAgentCfg, permissions, sessions, messages, history, lspClients)
114 if err != nil {
115 return nil, fmt.Errorf("failed to create task agent: %w", err)
116 }
117 return NewAgentTool(taskAgent, sessions, messages), nil
118 }
119 }
120
121 providerCfg := config.Get().GetProviderForModel(agentCfg.Model)
122 if providerCfg == nil {
123 return nil, fmt.Errorf("provider for agent %s not found in config", agentCfg.Name)
124 }
125 model := config.Get().GetModelByType(agentCfg.Model)
126
127 if model == nil {
128 return nil, fmt.Errorf("model not found for agent %s", agentCfg.Name)
129 }
130
131 promptID := agentPromptMap[agentCfg.ID]
132 if promptID == "" {
133 promptID = prompt.PromptDefault
134 }
135 opts := []provider.ProviderClientOption{
136 provider.WithModel(agentCfg.Model),
137 provider.WithSystemMessage(prompt.GetPrompt(promptID, providerCfg.ID, config.Get().Options.ContextPaths...)),
138 }
139 agentProvider, err := provider.NewProvider(*providerCfg, opts...)
140 if err != nil {
141 return nil, err
142 }
143
144 smallModelCfg := cfg.Models[config.SelectedModelTypeSmall]
145 var smallModelProviderCfg *config.ProviderConfig
146 if smallModelCfg.Provider == providerCfg.ID {
147 smallModelProviderCfg = providerCfg
148 } else {
149 smallModelProviderCfg = cfg.GetProviderForModel(config.SelectedModelTypeSmall)
150
151 if smallModelProviderCfg.ID == "" {
152 return nil, fmt.Errorf("provider %s not found in config", smallModelCfg.Provider)
153 }
154 }
155 smallModel := cfg.GetModelByType(config.SelectedModelTypeSmall)
156 if smallModel.ID == "" {
157 return nil, fmt.Errorf("model %s not found in provider %s", smallModelCfg.Model, smallModelProviderCfg.ID)
158 }
159
160 titleOpts := []provider.ProviderClientOption{
161 provider.WithModel(config.SelectedModelTypeSmall),
162 provider.WithSystemMessage(prompt.GetPrompt(prompt.PromptTitle, smallModelProviderCfg.ID)),
163 }
164 titleProvider, err := provider.NewProvider(*smallModelProviderCfg, titleOpts...)
165 if err != nil {
166 return nil, err
167 }
168
169 summarizeOpts := []provider.ProviderClientOption{
170 provider.WithModel(config.SelectedModelTypeLarge),
171 provider.WithSystemMessage(prompt.GetPrompt(prompt.PromptSummarizer, providerCfg.ID)),
172 }
173 summarizeProvider, err := provider.NewProvider(*providerCfg, summarizeOpts...)
174 if err != nil {
175 return nil, err
176 }
177
178 toolFn := func() []tools.BaseTool {
179 slog.Info("Initializing agent tools", "agent", agentCfg.ID)
180 defer func() {
181 slog.Info("Initialized agent tools", "agent", agentCfg.ID)
182 }()
183
184 cwd := cfg.WorkingDir()
185 allTools := []tools.BaseTool{
186 tools.NewBashTool(permissions, cwd, cfg.Options.Attribution),
187 tools.NewDownloadTool(permissions, cwd),
188 tools.NewEditTool(lspClients, permissions, history, cwd),
189 tools.NewMultiEditTool(lspClients, permissions, history, cwd),
190 tools.NewFetchTool(permissions, cwd),
191 tools.NewGlobTool(cwd),
192 tools.NewGrepTool(cwd),
193 tools.NewLsTool(permissions, cwd),
194 tools.NewSourcegraphTool(),
195 tools.NewViewTool(lspClients, permissions, cwd),
196 tools.NewWriteTool(lspClients, permissions, history, cwd),
197 }
198
199 mcpToolsOnce.Do(func() {
200 mcpTools = doGetMCPTools(ctx, permissions, cfg)
201 })
202
203 withCoderTools := func(t []tools.BaseTool) []tools.BaseTool {
204 if agentCfg.ID == "coder" {
205 t = append(t, mcpTools...)
206 if lspClients.Len() > 0 {
207 t = append(t, tools.NewDiagnosticsTool(lspClients))
208 }
209 }
210 return t
211 }
212
213 if agentCfg.AllowedTools == nil {
214 return withCoderTools(allTools)
215 }
216
217 var filteredTools []tools.BaseTool
218 for _, tool := range allTools {
219 if slices.Contains(agentCfg.AllowedTools, tool.Name()) {
220 filteredTools = append(filteredTools, tool)
221 }
222 }
223 return withCoderTools(filteredTools)
224 }
225
226 return &agent{
227 Broker: pubsub.NewBroker[AgentEvent](),
228 agentCfg: agentCfg,
229 provider: agentProvider,
230 providerID: string(providerCfg.ID),
231 messages: messages,
232 sessions: sessions,
233 titleProvider: titleProvider,
234 summarizeProvider: summarizeProvider,
235 summarizeProviderID: string(providerCfg.ID),
236 agentToolFn: agentToolFn,
237 activeRequests: csync.NewMap[string, context.CancelFunc](),
238 tools: csync.NewLazySlice(toolFn),
239 promptQueue: csync.NewMap[string, []string](),
240 }, nil
241}
242
243func (a *agent) Model() catwalk.Model {
244 return *config.Get().GetModelByType(a.agentCfg.Model)
245}
246
247func (a *agent) Cancel(sessionID string) {
248 // Cancel regular requests
249 if cancel, ok := a.activeRequests.Take(sessionID); ok && cancel != nil {
250 slog.Info("Request cancellation initiated", "session_id", sessionID)
251 cancel()
252 }
253
254 // Also check for summarize requests
255 if cancel, ok := a.activeRequests.Take(sessionID + "-summarize"); ok && cancel != nil {
256 slog.Info("Summarize cancellation initiated", "session_id", sessionID)
257 cancel()
258 }
259
260 if a.QueuedPrompts(sessionID) > 0 {
261 slog.Info("Clearing queued prompts", "session_id", sessionID)
262 a.promptQueue.Del(sessionID)
263 }
264}
265
266func (a *agent) IsBusy() bool {
267 var busy bool
268 for cancelFunc := range a.activeRequests.Seq() {
269 if cancelFunc != nil {
270 busy = true
271 break
272 }
273 }
274 return busy
275}
276
277func (a *agent) IsSessionBusy(sessionID string) bool {
278 _, busy := a.activeRequests.Get(sessionID)
279 return busy
280}
281
282func (a *agent) QueuedPrompts(sessionID string) int {
283 l, ok := a.promptQueue.Get(sessionID)
284 if !ok {
285 return 0
286 }
287 return len(l)
288}
289
290func (a *agent) generateTitle(ctx context.Context, sessionID string, content string) error {
291 if content == "" {
292 return nil
293 }
294 if a.titleProvider == nil {
295 return nil
296 }
297 session, err := a.sessions.Get(ctx, sessionID)
298 if err != nil {
299 return err
300 }
301 parts := []message.ContentPart{message.TextContent{
302 Text: fmt.Sprintf("Generate a concise title for the following content:\n\n%s", content),
303 }}
304
305 // Use streaming approach like summarization
306 response := a.titleProvider.StreamResponse(
307 ctx,
308 []message.Message{
309 {
310 Role: message.User,
311 Parts: parts,
312 },
313 },
314 nil,
315 )
316
317 var finalResponse *provider.ProviderResponse
318 for r := range response {
319 if r.Error != nil {
320 return r.Error
321 }
322 finalResponse = r.Response
323 }
324
325 if finalResponse == nil {
326 return fmt.Errorf("no response received from title provider")
327 }
328
329 title := strings.ReplaceAll(finalResponse.Content, "\n", " ")
330
331 if idx := strings.Index(title, "</think>"); idx > 0 {
332 title = title[idx+len("</think>"):]
333 }
334
335 title = strings.TrimSpace(title)
336 if title == "" {
337 return nil
338 }
339
340 session.Title = title
341 _, err = a.sessions.Save(ctx, session)
342 return err
343}
344
345func (a *agent) err(err error) AgentEvent {
346 return AgentEvent{
347 Type: AgentEventTypeError,
348 Error: err,
349 }
350}
351
352func (a *agent) Run(ctx context.Context, sessionID string, content string, attachments ...message.Attachment) (<-chan AgentEvent, error) {
353 if !a.Model().SupportsImages && attachments != nil {
354 attachments = nil
355 }
356 events := make(chan AgentEvent, 1)
357 if a.IsSessionBusy(sessionID) {
358 existing, ok := a.promptQueue.Get(sessionID)
359 if !ok {
360 existing = []string{}
361 }
362 existing = append(existing, content)
363 a.promptQueue.Set(sessionID, existing)
364 return nil, nil
365 }
366
367 genCtx, cancel := context.WithCancel(ctx)
368
369 a.activeRequests.Set(sessionID, cancel)
370 go func() {
371 slog.Debug("Request started", "sessionID", sessionID)
372 defer log.RecoverPanic("agent.Run", func() {
373 events <- a.err(fmt.Errorf("panic while running the agent"))
374 })
375 var attachmentParts []message.ContentPart
376 for _, attachment := range attachments {
377 attachmentParts = append(attachmentParts, message.BinaryContent{Path: attachment.FilePath, MIMEType: attachment.MimeType, Data: attachment.Content})
378 }
379 result := a.processGeneration(genCtx, sessionID, content, attachmentParts)
380 if result.Error != nil && !errors.Is(result.Error, ErrRequestCancelled) && !errors.Is(result.Error, context.Canceled) {
381 slog.Error(result.Error.Error())
382 }
383 slog.Debug("Request completed", "sessionID", sessionID)
384 a.activeRequests.Del(sessionID)
385 cancel()
386 a.Publish(pubsub.CreatedEvent, result)
387 events <- result
388 close(events)
389 }()
390 return events, nil
391}
392
393func (a *agent) processGeneration(ctx context.Context, sessionID, content string, attachmentParts []message.ContentPart) AgentEvent {
394 cfg := config.Get()
395 // List existing messages; if none, start title generation asynchronously.
396 msgs, err := a.messages.List(ctx, sessionID)
397 if err != nil {
398 return a.err(fmt.Errorf("failed to list messages: %w", err))
399 }
400 if len(msgs) == 0 {
401 go func() {
402 defer log.RecoverPanic("agent.Run", func() {
403 slog.Error("panic while generating title")
404 })
405 titleErr := a.generateTitle(ctx, sessionID, content)
406 if titleErr != nil && !errors.Is(titleErr, context.Canceled) && !errors.Is(titleErr, context.DeadlineExceeded) {
407 slog.Error("failed to generate title", "error", titleErr)
408 }
409 }()
410 }
411 session, err := a.sessions.Get(ctx, sessionID)
412 if err != nil {
413 return a.err(fmt.Errorf("failed to get session: %w", err))
414 }
415 if session.SummaryMessageID != "" {
416 summaryMsgInex := -1
417 for i, msg := range msgs {
418 if msg.ID == session.SummaryMessageID {
419 summaryMsgInex = i
420 break
421 }
422 }
423 if summaryMsgInex != -1 {
424 msgs = msgs[summaryMsgInex:]
425 msgs[0].Role = message.User
426 }
427 }
428
429 userMsg, err := a.createUserMessage(ctx, sessionID, content, attachmentParts)
430 if err != nil {
431 return a.err(fmt.Errorf("failed to create user message: %w", err))
432 }
433 // Append the new user message to the conversation history.
434 msgHistory := append(msgs, userMsg)
435
436 for {
437 // Check for cancellation before each iteration
438 select {
439 case <-ctx.Done():
440 return a.err(ctx.Err())
441 default:
442 // Continue processing
443 }
444 agentMessage, toolResults, err := a.streamAndHandleEvents(ctx, sessionID, msgHistory)
445 if err != nil {
446 if errors.Is(err, context.Canceled) {
447 agentMessage.AddFinish(message.FinishReasonCanceled, "Request cancelled", "")
448 a.messages.Update(context.Background(), agentMessage)
449 return a.err(ErrRequestCancelled)
450 }
451 return a.err(fmt.Errorf("failed to process events: %w", err))
452 }
453 if cfg.Options.Debug {
454 slog.Info("Result", "message", agentMessage.FinishReason(), "toolResults", toolResults)
455 }
456 if (agentMessage.FinishReason() == message.FinishReasonToolUse) && toolResults != nil {
457 // We are not done, we need to respond with the tool response
458 msgHistory = append(msgHistory, agentMessage, *toolResults)
459 // If there are queued prompts, process the next one
460 nextPrompt, ok := a.promptQueue.Take(sessionID)
461 if ok {
462 for _, prompt := range nextPrompt {
463 // Create a new user message for the queued prompt
464 userMsg, err := a.createUserMessage(ctx, sessionID, prompt, nil)
465 if err != nil {
466 return a.err(fmt.Errorf("failed to create user message for queued prompt: %w", err))
467 }
468 // Append the new user message to the conversation history
469 msgHistory = append(msgHistory, userMsg)
470 }
471 }
472
473 continue
474 } else if agentMessage.FinishReason() == message.FinishReasonEndTurn {
475 queuePrompts, ok := a.promptQueue.Take(sessionID)
476 if ok {
477 for _, prompt := range queuePrompts {
478 if prompt == "" {
479 continue
480 }
481 userMsg, err := a.createUserMessage(ctx, sessionID, prompt, nil)
482 if err != nil {
483 return a.err(fmt.Errorf("failed to create user message for queued prompt: %w", err))
484 }
485 msgHistory = append(msgHistory, userMsg)
486 }
487 continue
488 }
489 }
490 if agentMessage.FinishReason() == "" {
491 // Kujtim: could not track down where this is happening but this means its cancelled
492 agentMessage.AddFinish(message.FinishReasonCanceled, "Request cancelled", "")
493 _ = a.messages.Update(context.Background(), agentMessage)
494 return a.err(ErrRequestCancelled)
495 }
496 return AgentEvent{
497 Type: AgentEventTypeResponse,
498 Message: agentMessage,
499 Done: true,
500 }
501 }
502}
503
504func (a *agent) createUserMessage(ctx context.Context, sessionID, content string, attachmentParts []message.ContentPart) (message.Message, error) {
505 parts := []message.ContentPart{message.TextContent{Text: content}}
506 parts = append(parts, attachmentParts...)
507 return a.messages.Create(ctx, sessionID, message.CreateMessageParams{
508 Role: message.User,
509 Parts: parts,
510 })
511}
512
513func (a *agent) getAllTools() ([]tools.BaseTool, error) {
514 allTools := slices.Collect(a.tools.Seq())
515 if a.agentToolFn != nil {
516 agentTool, agentToolErr := a.agentToolFn()
517 if agentToolErr != nil {
518 return nil, agentToolErr
519 }
520 allTools = append(allTools, agentTool)
521 }
522 return allTools, nil
523}
524
525func (a *agent) streamAndHandleEvents(ctx context.Context, sessionID string, msgHistory []message.Message) (message.Message, *message.Message, error) {
526 ctx = context.WithValue(ctx, tools.SessionIDContextKey, sessionID)
527
528 // Create the assistant message first so the spinner shows immediately
529 assistantMsg, err := a.messages.Create(ctx, sessionID, message.CreateMessageParams{
530 Role: message.Assistant,
531 Parts: []message.ContentPart{},
532 Model: a.Model().ID,
533 Provider: a.providerID,
534 })
535 if err != nil {
536 return assistantMsg, nil, fmt.Errorf("failed to create assistant message: %w", err)
537 }
538
539 allTools, toolsErr := a.getAllTools()
540 if toolsErr != nil {
541 return assistantMsg, nil, toolsErr
542 }
543 // Now collect tools (which may block on MCP initialization)
544 eventChan := a.provider.StreamResponse(ctx, msgHistory, allTools)
545
546 // Add the session and message ID into the context if needed by tools.
547 ctx = context.WithValue(ctx, tools.MessageIDContextKey, assistantMsg.ID)
548
549 // Process each event in the stream.
550 for event := range eventChan {
551 if processErr := a.processEvent(ctx, sessionID, &assistantMsg, event); processErr != nil {
552 if errors.Is(processErr, context.Canceled) {
553 a.finishMessage(context.Background(), &assistantMsg, message.FinishReasonCanceled, "Request cancelled", "")
554 } else {
555 a.finishMessage(ctx, &assistantMsg, message.FinishReasonError, "API Error", processErr.Error())
556 }
557 return assistantMsg, nil, processErr
558 }
559 if ctx.Err() != nil {
560 a.finishMessage(context.Background(), &assistantMsg, message.FinishReasonCanceled, "Request cancelled", "")
561 return assistantMsg, nil, ctx.Err()
562 }
563 }
564
565 toolResults := make([]message.ToolResult, len(assistantMsg.ToolCalls()))
566 toolCalls := assistantMsg.ToolCalls()
567 for i, toolCall := range toolCalls {
568 select {
569 case <-ctx.Done():
570 a.finishMessage(context.Background(), &assistantMsg, message.FinishReasonCanceled, "Request cancelled", "")
571 // Make all future tool calls cancelled
572 for j := i; j < len(toolCalls); j++ {
573 toolResults[j] = message.ToolResult{
574 ToolCallID: toolCalls[j].ID,
575 Content: "Tool execution canceled by user",
576 IsError: true,
577 }
578 }
579 goto out
580 default:
581 // Continue processing
582 var tool tools.BaseTool
583 allTools, _ := a.getAllTools()
584 for _, availableTool := range allTools {
585 if availableTool.Info().Name == toolCall.Name {
586 tool = availableTool
587 break
588 }
589 }
590
591 // Tool not found
592 if tool == nil {
593 toolResults[i] = message.ToolResult{
594 ToolCallID: toolCall.ID,
595 Content: fmt.Sprintf("Tool not found: %s", toolCall.Name),
596 IsError: true,
597 }
598 continue
599 }
600
601 // Run tool in goroutine to allow cancellation
602 type toolExecResult struct {
603 response tools.ToolResponse
604 err error
605 }
606 resultChan := make(chan toolExecResult, 1)
607
608 go func() {
609 response, err := tool.Run(ctx, tools.ToolCall{
610 ID: toolCall.ID,
611 Name: toolCall.Name,
612 Input: toolCall.Input,
613 })
614 resultChan <- toolExecResult{response: response, err: err}
615 }()
616
617 var toolResponse tools.ToolResponse
618 var toolErr error
619
620 select {
621 case <-ctx.Done():
622 a.finishMessage(context.Background(), &assistantMsg, message.FinishReasonCanceled, "Request cancelled", "")
623 // Mark remaining tool calls as cancelled
624 for j := i; j < len(toolCalls); j++ {
625 toolResults[j] = message.ToolResult{
626 ToolCallID: toolCalls[j].ID,
627 Content: "Tool execution canceled by user",
628 IsError: true,
629 }
630 }
631 goto out
632 case result := <-resultChan:
633 toolResponse = result.response
634 toolErr = result.err
635 }
636
637 if toolErr != nil {
638 slog.Error("Tool execution error", "toolCall", toolCall.ID, "error", toolErr)
639 if errors.Is(toolErr, permission.ErrorPermissionDenied) {
640 toolResults[i] = message.ToolResult{
641 ToolCallID: toolCall.ID,
642 Content: "Permission denied",
643 IsError: true,
644 }
645 for j := i + 1; j < len(toolCalls); j++ {
646 toolResults[j] = message.ToolResult{
647 ToolCallID: toolCalls[j].ID,
648 Content: "Tool execution canceled by user",
649 IsError: true,
650 }
651 }
652 a.finishMessage(ctx, &assistantMsg, message.FinishReasonPermissionDenied, "Permission denied", "")
653 break
654 }
655 }
656 toolResults[i] = message.ToolResult{
657 ToolCallID: toolCall.ID,
658 Content: toolResponse.Content,
659 Metadata: toolResponse.Metadata,
660 IsError: toolResponse.IsError,
661 }
662 }
663 }
664out:
665 if len(toolResults) == 0 {
666 return assistantMsg, nil, nil
667 }
668 parts := make([]message.ContentPart, 0)
669 for _, tr := range toolResults {
670 parts = append(parts, tr)
671 }
672 msg, err := a.messages.Create(context.Background(), assistantMsg.SessionID, message.CreateMessageParams{
673 Role: message.Tool,
674 Parts: parts,
675 Provider: a.providerID,
676 })
677 if err != nil {
678 return assistantMsg, nil, fmt.Errorf("failed to create cancelled tool message: %w", err)
679 }
680
681 return assistantMsg, &msg, err
682}
683
684func (a *agent) finishMessage(ctx context.Context, msg *message.Message, finishReason message.FinishReason, message, details string) {
685 msg.AddFinish(finishReason, message, details)
686 _ = a.messages.Update(ctx, *msg)
687}
688
689func (a *agent) processEvent(ctx context.Context, sessionID string, assistantMsg *message.Message, event provider.ProviderEvent) error {
690 select {
691 case <-ctx.Done():
692 return ctx.Err()
693 default:
694 // Continue processing.
695 }
696
697 switch event.Type {
698 case provider.EventThinkingDelta:
699 assistantMsg.AppendReasoningContent(event.Thinking)
700 return a.messages.Update(ctx, *assistantMsg)
701 case provider.EventSignatureDelta:
702 assistantMsg.AppendReasoningSignature(event.Signature)
703 return a.messages.Update(ctx, *assistantMsg)
704 case provider.EventContentDelta:
705 assistantMsg.FinishThinking()
706 assistantMsg.AppendContent(event.Content)
707 return a.messages.Update(ctx, *assistantMsg)
708 case provider.EventToolUseStart:
709 assistantMsg.FinishThinking()
710 slog.Info("Tool call started", "toolCall", event.ToolCall)
711 assistantMsg.AddToolCall(*event.ToolCall)
712 return a.messages.Update(ctx, *assistantMsg)
713 case provider.EventToolUseDelta:
714 assistantMsg.AppendToolCallInput(event.ToolCall.ID, event.ToolCall.Input)
715 return a.messages.Update(ctx, *assistantMsg)
716 case provider.EventToolUseStop:
717 slog.Info("Finished tool call", "toolCall", event.ToolCall)
718 assistantMsg.FinishToolCall(event.ToolCall.ID)
719 return a.messages.Update(ctx, *assistantMsg)
720 case provider.EventError:
721 return event.Error
722 case provider.EventComplete:
723 assistantMsg.FinishThinking()
724 assistantMsg.SetToolCalls(event.Response.ToolCalls)
725 assistantMsg.AddFinish(event.Response.FinishReason, "", "")
726 if err := a.messages.Update(ctx, *assistantMsg); err != nil {
727 return fmt.Errorf("failed to update message: %w", err)
728 }
729 return a.TrackUsage(ctx, sessionID, a.Model(), event.Response.Usage)
730 }
731
732 return nil
733}
734
735func (a *agent) TrackUsage(ctx context.Context, sessionID string, model catwalk.Model, usage provider.TokenUsage) error {
736 sess, err := a.sessions.Get(ctx, sessionID)
737 if err != nil {
738 return fmt.Errorf("failed to get session: %w", err)
739 }
740
741 cost := model.CostPer1MInCached/1e6*float64(usage.CacheCreationTokens) +
742 model.CostPer1MOutCached/1e6*float64(usage.CacheReadTokens) +
743 model.CostPer1MIn/1e6*float64(usage.InputTokens) +
744 model.CostPer1MOut/1e6*float64(usage.OutputTokens)
745
746 sess.Cost += cost
747 sess.CompletionTokens = usage.OutputTokens + usage.CacheReadTokens
748 sess.PromptTokens = usage.InputTokens + usage.CacheCreationTokens
749
750 _, err = a.sessions.Save(ctx, sess)
751 if err != nil {
752 return fmt.Errorf("failed to save session: %w", err)
753 }
754 return nil
755}
756
757func (a *agent) Summarize(ctx context.Context, sessionID string) error {
758 if a.summarizeProvider == nil {
759 return fmt.Errorf("summarize provider not available")
760 }
761
762 // Check if session is busy
763 if a.IsSessionBusy(sessionID) {
764 return ErrSessionBusy
765 }
766
767 // Create a new context with cancellation
768 summarizeCtx, cancel := context.WithCancel(ctx)
769
770 // Store the cancel function in activeRequests to allow cancellation
771 a.activeRequests.Set(sessionID+"-summarize", cancel)
772
773 go func() {
774 defer a.activeRequests.Del(sessionID + "-summarize")
775 defer cancel()
776 event := AgentEvent{
777 Type: AgentEventTypeSummarize,
778 Progress: "Starting summarization...",
779 }
780
781 a.Publish(pubsub.CreatedEvent, event)
782 // Get all messages from the session
783 msgs, err := a.messages.List(summarizeCtx, sessionID)
784 if err != nil {
785 event = AgentEvent{
786 Type: AgentEventTypeError,
787 Error: fmt.Errorf("failed to list messages: %w", err),
788 Done: true,
789 }
790 a.Publish(pubsub.CreatedEvent, event)
791 return
792 }
793 summarizeCtx = context.WithValue(summarizeCtx, tools.SessionIDContextKey, sessionID)
794
795 if len(msgs) == 0 {
796 event = AgentEvent{
797 Type: AgentEventTypeError,
798 Error: fmt.Errorf("no messages to summarize"),
799 Done: true,
800 }
801 a.Publish(pubsub.CreatedEvent, event)
802 return
803 }
804
805 event = AgentEvent{
806 Type: AgentEventTypeSummarize,
807 Progress: "Analyzing conversation...",
808 }
809 a.Publish(pubsub.CreatedEvent, event)
810
811 // Add a system message to guide the summarization
812 summarizePrompt := "Provide a detailed but concise summary of our conversation above. Focus on information that would be helpful for continuing the conversation, including what we did, what we're doing, which files we're working on, and what we're going to do next."
813
814 // Create a new message with the summarize prompt
815 promptMsg := message.Message{
816 Role: message.User,
817 Parts: []message.ContentPart{message.TextContent{Text: summarizePrompt}},
818 }
819
820 // Append the prompt to the messages
821 msgsWithPrompt := append(msgs, promptMsg)
822
823 event = AgentEvent{
824 Type: AgentEventTypeSummarize,
825 Progress: "Generating summary...",
826 }
827
828 a.Publish(pubsub.CreatedEvent, event)
829
830 // Send the messages to the summarize provider
831 response := a.summarizeProvider.StreamResponse(
832 summarizeCtx,
833 msgsWithPrompt,
834 nil,
835 )
836 var finalResponse *provider.ProviderResponse
837 for r := range response {
838 if r.Error != nil {
839 event = AgentEvent{
840 Type: AgentEventTypeError,
841 Error: fmt.Errorf("failed to summarize: %w", r.Error),
842 Done: true,
843 }
844 a.Publish(pubsub.CreatedEvent, event)
845 return
846 }
847 finalResponse = r.Response
848 }
849
850 summary := strings.TrimSpace(finalResponse.Content)
851 if summary == "" {
852 event = AgentEvent{
853 Type: AgentEventTypeError,
854 Error: fmt.Errorf("empty summary returned"),
855 Done: true,
856 }
857 a.Publish(pubsub.CreatedEvent, event)
858 return
859 }
860 shell := shell.GetPersistentShell(config.Get().WorkingDir())
861 summary += "\n\n**Current working directory of the persistent shell**\n\n" + shell.GetWorkingDir()
862 event = AgentEvent{
863 Type: AgentEventTypeSummarize,
864 Progress: "Creating new session...",
865 }
866
867 a.Publish(pubsub.CreatedEvent, event)
868 oldSession, err := a.sessions.Get(summarizeCtx, sessionID)
869 if err != nil {
870 event = AgentEvent{
871 Type: AgentEventTypeError,
872 Error: fmt.Errorf("failed to get session: %w", err),
873 Done: true,
874 }
875
876 a.Publish(pubsub.CreatedEvent, event)
877 return
878 }
879 // Create a message in the new session with the summary
880 msg, err := a.messages.Create(summarizeCtx, oldSession.ID, message.CreateMessageParams{
881 Role: message.Assistant,
882 Parts: []message.ContentPart{
883 message.TextContent{Text: summary},
884 message.Finish{
885 Reason: message.FinishReasonEndTurn,
886 Time: time.Now().Unix(),
887 },
888 },
889 Model: a.summarizeProvider.Model().ID,
890 Provider: a.summarizeProviderID,
891 })
892 if err != nil {
893 event = AgentEvent{
894 Type: AgentEventTypeError,
895 Error: fmt.Errorf("failed to create summary message: %w", err),
896 Done: true,
897 }
898
899 a.Publish(pubsub.CreatedEvent, event)
900 return
901 }
902 oldSession.SummaryMessageID = msg.ID
903 oldSession.CompletionTokens = finalResponse.Usage.OutputTokens
904 oldSession.PromptTokens = 0
905 model := a.summarizeProvider.Model()
906 usage := finalResponse.Usage
907 cost := model.CostPer1MInCached/1e6*float64(usage.CacheCreationTokens) +
908 model.CostPer1MOutCached/1e6*float64(usage.CacheReadTokens) +
909 model.CostPer1MIn/1e6*float64(usage.InputTokens) +
910 model.CostPer1MOut/1e6*float64(usage.OutputTokens)
911 oldSession.Cost += cost
912 _, err = a.sessions.Save(summarizeCtx, oldSession)
913 if err != nil {
914 event = AgentEvent{
915 Type: AgentEventTypeError,
916 Error: fmt.Errorf("failed to save session: %w", err),
917 Done: true,
918 }
919 a.Publish(pubsub.CreatedEvent, event)
920 }
921
922 event = AgentEvent{
923 Type: AgentEventTypeSummarize,
924 SessionID: oldSession.ID,
925 Progress: "Summary complete",
926 Done: true,
927 }
928 a.Publish(pubsub.CreatedEvent, event)
929 // Send final success event with the new session ID
930 }()
931
932 return nil
933}
934
935func (a *agent) ClearQueue(sessionID string) {
936 if a.QueuedPrompts(sessionID) > 0 {
937 slog.Info("Clearing queued prompts", "session_id", sessionID)
938 a.promptQueue.Del(sessionID)
939 }
940}
941
942func (a *agent) CancelAll() {
943 if !a.IsBusy() {
944 return
945 }
946 for key := range a.activeRequests.Seq2() {
947 a.Cancel(key) // key is sessionID
948 }
949
950 timeout := time.After(5 * time.Second)
951 for a.IsBusy() {
952 select {
953 case <-timeout:
954 return
955 default:
956 time.Sleep(200 * time.Millisecond)
957 }
958 }
959}
960
961func (a *agent) UpdateModel() error {
962 cfg := config.Get()
963
964 // Get current provider configuration
965 currentProviderCfg := cfg.GetProviderForModel(a.agentCfg.Model)
966 if currentProviderCfg == nil || currentProviderCfg.ID == "" {
967 return fmt.Errorf("provider for agent %s not found in config", a.agentCfg.Name)
968 }
969
970 // Check if provider has changed
971 if string(currentProviderCfg.ID) != a.providerID {
972 // Provider changed, need to recreate the main provider
973 model := cfg.GetModelByType(a.agentCfg.Model)
974 if model.ID == "" {
975 return fmt.Errorf("model not found for agent %s", a.agentCfg.Name)
976 }
977
978 promptID := agentPromptMap[a.agentCfg.ID]
979 if promptID == "" {
980 promptID = prompt.PromptDefault
981 }
982
983 opts := []provider.ProviderClientOption{
984 provider.WithModel(a.agentCfg.Model),
985 provider.WithSystemMessage(prompt.GetPrompt(promptID, currentProviderCfg.ID, cfg.Options.ContextPaths...)),
986 }
987
988 newProvider, err := provider.NewProvider(*currentProviderCfg, opts...)
989 if err != nil {
990 return fmt.Errorf("failed to create new provider: %w", err)
991 }
992
993 // Update the provider and provider ID
994 a.provider = newProvider
995 a.providerID = string(currentProviderCfg.ID)
996 }
997
998 // Check if providers have changed for title (small) and summarize (large)
999 smallModelCfg := cfg.Models[config.SelectedModelTypeSmall]
1000 var smallModelProviderCfg config.ProviderConfig
1001 for p := range cfg.Providers.Seq() {
1002 if p.ID == smallModelCfg.Provider {
1003 smallModelProviderCfg = p
1004 break
1005 }
1006 }
1007 if smallModelProviderCfg.ID == "" {
1008 return fmt.Errorf("provider %s not found in config", smallModelCfg.Provider)
1009 }
1010
1011 largeModelCfg := cfg.Models[config.SelectedModelTypeLarge]
1012 var largeModelProviderCfg config.ProviderConfig
1013 for p := range cfg.Providers.Seq() {
1014 if p.ID == largeModelCfg.Provider {
1015 largeModelProviderCfg = p
1016 break
1017 }
1018 }
1019 if largeModelProviderCfg.ID == "" {
1020 return fmt.Errorf("provider %s not found in config", largeModelCfg.Provider)
1021 }
1022
1023 var maxTitleTokens int64 = 40
1024
1025 // if the max output is too low for the gemini provider it won't return anything
1026 if smallModelCfg.Provider == "gemini" {
1027 maxTitleTokens = 1000
1028 }
1029 // Recreate title provider
1030 titleOpts := []provider.ProviderClientOption{
1031 provider.WithModel(config.SelectedModelTypeSmall),
1032 provider.WithSystemMessage(prompt.GetPrompt(prompt.PromptTitle, smallModelProviderCfg.ID)),
1033 provider.WithMaxTokens(maxTitleTokens),
1034 }
1035 newTitleProvider, err := provider.NewProvider(smallModelProviderCfg, titleOpts...)
1036 if err != nil {
1037 return fmt.Errorf("failed to create new title provider: %w", err)
1038 }
1039 a.titleProvider = newTitleProvider
1040
1041 // Recreate summarize provider if provider changed (now large model)
1042 if string(largeModelProviderCfg.ID) != a.summarizeProviderID {
1043 largeModel := cfg.GetModelByType(config.SelectedModelTypeLarge)
1044 if largeModel == nil {
1045 return fmt.Errorf("model %s not found in provider %s", largeModelCfg.Model, largeModelProviderCfg.ID)
1046 }
1047 summarizeOpts := []provider.ProviderClientOption{
1048 provider.WithModel(config.SelectedModelTypeLarge),
1049 provider.WithSystemMessage(prompt.GetPrompt(prompt.PromptSummarizer, largeModelProviderCfg.ID)),
1050 }
1051 newSummarizeProvider, err := provider.NewProvider(largeModelProviderCfg, summarizeOpts...)
1052 if err != nil {
1053 return fmt.Errorf("failed to create new summarize provider: %w", err)
1054 }
1055 a.summarizeProvider = newSummarizeProvider
1056 a.summarizeProviderID = string(largeModelProviderCfg.ID)
1057 }
1058
1059 return nil
1060}