1package agent
2
3import (
4 "context"
5 "errors"
6 "fmt"
7 "log/slog"
8 "slices"
9 "strings"
10 "time"
11
12 "github.com/charmbracelet/catwalk/pkg/catwalk"
13 "github.com/charmbracelet/crush/internal/config"
14 "github.com/charmbracelet/crush/internal/csync"
15 "github.com/charmbracelet/crush/internal/event"
16 "github.com/charmbracelet/crush/internal/history"
17 "github.com/charmbracelet/crush/internal/llm/prompt"
18 "github.com/charmbracelet/crush/internal/llm/provider"
19 "github.com/charmbracelet/crush/internal/llm/tools"
20 "github.com/charmbracelet/crush/internal/log"
21 "github.com/charmbracelet/crush/internal/lsp"
22 "github.com/charmbracelet/crush/internal/message"
23 "github.com/charmbracelet/crush/internal/permission"
24 "github.com/charmbracelet/crush/internal/pubsub"
25 "github.com/charmbracelet/crush/internal/session"
26 "github.com/charmbracelet/crush/internal/shell"
27)
28
// streamChunkTimeout is the longest the agent waits for the next chunk of a
// provider stream. The timer is re-armed every time a chunk arrives, so it
// bounds the gap between chunks rather than the total stream duration.
const streamChunkTimeout = 80 * time.Second
30
// AgentEventType identifies the kind of an AgentEvent.
type AgentEventType string

const (
	// AgentEventTypeError marks an event that carries a failure in its Error field.
	AgentEventTypeError AgentEventType = "error"
	// AgentEventTypeResponse marks an event that carries the agent's final message.
	AgentEventTypeResponse AgentEventType = "response"
	// AgentEventTypeSummarize marks a progress or completion event of a
	// summarization run.
	AgentEventTypeSummarize AgentEventType = "summarize"
)
38
// AgentEvent is the value published by the agent and delivered on the channel
// returned by Run. Which fields are populated depends on Type.
type AgentEvent struct {
	// Type selects which of the remaining fields are meaningful.
	Type AgentEventType
	// Message is the final assistant message for response events.
	Message message.Message
	// Error is set for error events.
	Error error

	// When summarizing
	// SessionID is the session the summary was stored in (set on completion).
	SessionID string
	// Progress is a human-readable status line.
	Progress string
	// Done reports that no further events follow for this operation.
	Done bool
}
49
// Service is the public surface of an agent: it runs prompts against a
// session, tracks and cancels in-flight work, and publishes AgentEvents.
type Service interface {
	pubsub.Suscriber[AgentEvent]
	// Model returns the model currently configured for the agent.
	Model() catwalk.Model
	// Run starts processing content for the session and returns a channel that
	// receives the result. When the session is already busy the prompt is
	// queued instead and both return values are nil.
	Run(ctx context.Context, sessionID string, content string, attachments ...message.Attachment) (<-chan AgentEvent, error)
	// Cancel stops the in-flight request and summarization for the session and
	// drops its queued prompts.
	Cancel(sessionID string)
	// CancelAll cancels every active request and waits a bounded time for the
	// agent to become idle.
	CancelAll()
	// IsSessionBusy reports whether a request is in flight for the session.
	IsSessionBusy(sessionID string) bool
	// IsBusy reports whether any request is in flight.
	IsBusy() bool
	// Summarize starts an asynchronous summarization of the session; progress
	// and results are published as events.
	Summarize(ctx context.Context, sessionID string) error
	// UpdateModel rebuilds the providers after the configured models changed.
	UpdateModel() error
	// QueuedPrompts returns how many prompts are waiting for the session.
	QueuedPrompts(sessionID string) int
	// ClearQueue drops the prompts waiting for the session.
	ClearQueue(sessionID string)
}
63
// agent is the Service implementation. It embeds the broker used to publish
// AgentEvents and keeps the providers, tools and per-session bookkeeping.
type agent struct {
	*pubsub.Broker[AgentEvent]
	// agentCfg is the configuration this agent was created from.
	agentCfg config.Agent
	// Persistence and permission services shared with the tools.
	sessions    session.Service
	messages    message.Service
	permissions permission.Service
	// NOTE(review): mcpTools is not referenced in the visible part of this
	// file — confirm whether it is still needed.
	mcpTools []McpTool

	// tools is the lazily initialized set of tools offered to the model.
	tools *csync.LazySlice[tools.BaseTool]
	// We need this to be able to update it when model changes
	// agentToolFn, when non-nil, builds the sub-agent tool on demand.
	agentToolFn func() (tools.BaseTool, error)

	// provider is the main completion provider; providerID is recorded on
	// the messages it produces.
	provider   provider.Provider
	providerID string

	// titleProvider generates session titles; summarizeProvider generates
	// conversation summaries.
	titleProvider       provider.Provider
	summarizeProvider   provider.Provider
	summarizeProviderID string

	// activeRequests maps a session ID (or "<sessionID>-summarize") to the
	// cancel function of the work in flight for it.
	activeRequests *csync.Map[string, context.CancelFunc]
	// promptQueue holds prompts submitted while their session was busy.
	promptQueue *csync.Map[string, []string]
}
86
// agentPromptMap maps an agent config ID to the system prompt it uses.
// Agents whose ID is not listed fall back to prompt.PromptDefault.
var agentPromptMap = map[string]prompt.PromptID{
	"coder": prompt.PromptCoder,
	"task":  prompt.PromptTask,
}
91
92func NewAgent(
93 ctx context.Context,
94 agentCfg config.Agent,
95 // These services are needed in the tools
96 permissions permission.Service,
97 sessions session.Service,
98 messages message.Service,
99 history history.Service,
100 lspClients *csync.Map[string, *lsp.Client],
101) (Service, error) {
102 cfg := config.Get()
103
104 var agentToolFn func() (tools.BaseTool, error)
105 if agentCfg.ID == "coder" && slices.Contains(agentCfg.AllowedTools, AgentToolName) {
106 agentToolFn = func() (tools.BaseTool, error) {
107 taskAgentCfg := config.Get().Agents["task"]
108 if taskAgentCfg.ID == "" {
109 return nil, fmt.Errorf("task agent not found in config")
110 }
111 taskAgent, err := NewAgent(ctx, taskAgentCfg, permissions, sessions, messages, history, lspClients)
112 if err != nil {
113 return nil, fmt.Errorf("failed to create task agent: %w", err)
114 }
115 return NewAgentTool(taskAgent, sessions, messages), nil
116 }
117 }
118
119 providerCfg := config.Get().GetProviderForModel(agentCfg.Model)
120 if providerCfg == nil {
121 return nil, fmt.Errorf("provider for agent %s not found in config", agentCfg.Name)
122 }
123 model := config.Get().GetModelByType(agentCfg.Model)
124
125 if model == nil {
126 return nil, fmt.Errorf("model not found for agent %s", agentCfg.Name)
127 }
128
129 promptID := agentPromptMap[agentCfg.ID]
130 if promptID == "" {
131 promptID = prompt.PromptDefault
132 }
133 opts := []provider.ProviderClientOption{
134 provider.WithModel(agentCfg.Model),
135 provider.WithSystemMessage(prompt.GetPrompt(promptID, providerCfg.ID, config.Get().Options.ContextPaths...)),
136 }
137 agentProvider, err := provider.NewProvider(*providerCfg, opts...)
138 if err != nil {
139 return nil, err
140 }
141
142 smallModelCfg := cfg.Models[config.SelectedModelTypeSmall]
143 var smallModelProviderCfg *config.ProviderConfig
144 if smallModelCfg.Provider == providerCfg.ID {
145 smallModelProviderCfg = providerCfg
146 } else {
147 smallModelProviderCfg = cfg.GetProviderForModel(config.SelectedModelTypeSmall)
148
149 if smallModelProviderCfg.ID == "" {
150 return nil, fmt.Errorf("provider %s not found in config", smallModelCfg.Provider)
151 }
152 }
153 smallModel := cfg.GetModelByType(config.SelectedModelTypeSmall)
154 if smallModel.ID == "" {
155 return nil, fmt.Errorf("model %s not found in provider %s", smallModelCfg.Model, smallModelProviderCfg.ID)
156 }
157
158 titleOpts := []provider.ProviderClientOption{
159 provider.WithModel(config.SelectedModelTypeSmall),
160 provider.WithSystemMessage(prompt.GetPrompt(prompt.PromptTitle, smallModelProviderCfg.ID)),
161 }
162 titleProvider, err := provider.NewProvider(*smallModelProviderCfg, titleOpts...)
163 if err != nil {
164 return nil, err
165 }
166
167 summarizeOpts := []provider.ProviderClientOption{
168 provider.WithModel(config.SelectedModelTypeLarge),
169 provider.WithSystemMessage(prompt.GetPrompt(prompt.PromptSummarizer, providerCfg.ID)),
170 }
171 summarizeProvider, err := provider.NewProvider(*providerCfg, summarizeOpts...)
172 if err != nil {
173 return nil, err
174 }
175
176 toolFn := func() []tools.BaseTool {
177 slog.Info("Initializing agent tools", "agent", agentCfg.ID)
178 defer func() {
179 slog.Info("Initialized agent tools", "agent", agentCfg.ID)
180 }()
181
182 cwd := cfg.WorkingDir()
183 allTools := []tools.BaseTool{
184 tools.NewBashTool(permissions, cwd, cfg.Options.Attribution),
185 tools.NewDownloadTool(permissions, cwd),
186 tools.NewEditTool(lspClients, permissions, history, cwd),
187 tools.NewMultiEditTool(lspClients, permissions, history, cwd),
188 tools.NewFetchTool(permissions, cwd),
189 tools.NewGlobTool(cwd),
190 tools.NewGrepTool(cwd),
191 tools.NewLsTool(permissions, cwd),
192 tools.NewSourcegraphTool(),
193 tools.NewViewTool(lspClients, permissions, cwd),
194 tools.NewWriteTool(lspClients, permissions, history, cwd),
195 }
196
197 mcpToolsOnce.Do(func() {
198 mcpTools = doGetMCPTools(ctx, permissions, cfg)
199 })
200
201 withCoderTools := func(t []tools.BaseTool) []tools.BaseTool {
202 if agentCfg.ID == "coder" {
203 t = append(t, mcpTools...)
204 if lspClients.Len() > 0 {
205 t = append(t, tools.NewDiagnosticsTool(lspClients))
206 }
207 }
208 return t
209 }
210
211 if agentCfg.AllowedTools == nil {
212 return withCoderTools(allTools)
213 }
214
215 var filteredTools []tools.BaseTool
216 for _, tool := range allTools {
217 if slices.Contains(agentCfg.AllowedTools, tool.Name()) {
218 filteredTools = append(filteredTools, tool)
219 }
220 }
221 return withCoderTools(filteredTools)
222 }
223
224 return &agent{
225 Broker: pubsub.NewBroker[AgentEvent](),
226 agentCfg: agentCfg,
227 provider: agentProvider,
228 providerID: string(providerCfg.ID),
229 messages: messages,
230 sessions: sessions,
231 titleProvider: titleProvider,
232 summarizeProvider: summarizeProvider,
233 summarizeProviderID: string(providerCfg.ID),
234 agentToolFn: agentToolFn,
235 activeRequests: csync.NewMap[string, context.CancelFunc](),
236 tools: csync.NewLazySlice(toolFn),
237 promptQueue: csync.NewMap[string, []string](),
238 permissions: permissions,
239 }, nil
240}
241
242func (a *agent) Model() catwalk.Model {
243 return *config.Get().GetModelByType(a.agentCfg.Model)
244}
245
246func (a *agent) Cancel(sessionID string) {
247 // Cancel regular requests
248 if cancel, ok := a.activeRequests.Take(sessionID); ok && cancel != nil {
249 slog.Info("Request cancellation initiated", "session_id", sessionID)
250 cancel()
251 }
252
253 // Also check for summarize requests
254 if cancel, ok := a.activeRequests.Take(sessionID + "-summarize"); ok && cancel != nil {
255 slog.Info("Summarize cancellation initiated", "session_id", sessionID)
256 cancel()
257 }
258
259 if a.QueuedPrompts(sessionID) > 0 {
260 slog.Info("Clearing queued prompts", "session_id", sessionID)
261 a.promptQueue.Del(sessionID)
262 }
263}
264
265func (a *agent) IsBusy() bool {
266 var busy bool
267 for cancelFunc := range a.activeRequests.Seq() {
268 if cancelFunc != nil {
269 busy = true
270 break
271 }
272 }
273 return busy
274}
275
276func (a *agent) IsSessionBusy(sessionID string) bool {
277 _, busy := a.activeRequests.Get(sessionID)
278 return busy
279}
280
281func (a *agent) QueuedPrompts(sessionID string) int {
282 l, ok := a.promptQueue.Get(sessionID)
283 if !ok {
284 return 0
285 }
286 return len(l)
287}
288
289func (a *agent) generateTitle(ctx context.Context, sessionID string, content string) error {
290 if content == "" {
291 return nil
292 }
293 if a.titleProvider == nil {
294 return nil
295 }
296 session, err := a.sessions.Get(ctx, sessionID)
297 if err != nil {
298 return err
299 }
300 parts := []message.ContentPart{message.TextContent{
301 Text: fmt.Sprintf("Generate a concise title for the following content:\n\n%s", content),
302 }}
303
304 // Use streaming approach like summarization
305 response := a.titleProvider.StreamResponse(
306 ctx,
307 []message.Message{
308 {
309 Role: message.User,
310 Parts: parts,
311 },
312 },
313 nil,
314 )
315
316 var finalResponse *provider.ProviderResponse
317 for r := range response {
318 if r.Error != nil {
319 return r.Error
320 }
321 finalResponse = r.Response
322 }
323
324 if finalResponse == nil {
325 return fmt.Errorf("no response received from title provider")
326 }
327
328 title := strings.ReplaceAll(finalResponse.Content, "\n", " ")
329
330 if idx := strings.Index(title, "</think>"); idx > 0 {
331 title = title[idx+len("</think>"):]
332 }
333
334 title = strings.TrimSpace(title)
335 if title == "" {
336 return nil
337 }
338
339 session.Title = title
340 _, err = a.sessions.Save(ctx, session)
341 return err
342}
343
344func (a *agent) err(err error) AgentEvent {
345 return AgentEvent{
346 Type: AgentEventTypeError,
347 Error: err,
348 }
349}
350
351func (a *agent) Run(ctx context.Context, sessionID string, content string, attachments ...message.Attachment) (<-chan AgentEvent, error) {
352 if !a.Model().SupportsImages && attachments != nil {
353 attachments = nil
354 }
355 events := make(chan AgentEvent, 1)
356 if a.IsSessionBusy(sessionID) {
357 existing, ok := a.promptQueue.Get(sessionID)
358 if !ok {
359 existing = []string{}
360 }
361 existing = append(existing, content)
362 a.promptQueue.Set(sessionID, existing)
363 return nil, nil
364 }
365
366 genCtx, cancel := context.WithCancel(ctx)
367 a.activeRequests.Set(sessionID, cancel)
368 startTime := time.Now()
369
370 go func() {
371 slog.Debug("Request started", "sessionID", sessionID)
372 defer log.RecoverPanic("agent.Run", func() {
373 events <- a.err(fmt.Errorf("panic while running the agent"))
374 })
375 var attachmentParts []message.ContentPart
376 for _, attachment := range attachments {
377 attachmentParts = append(attachmentParts, message.BinaryContent{Path: attachment.FilePath, MIMEType: attachment.MimeType, Data: attachment.Content})
378 }
379 result := a.processGeneration(genCtx, sessionID, content, attachmentParts)
380 if result.Error != nil {
381 if isCancelledErr(result.Error) {
382 slog.Error("Request canceled", "sessionID", sessionID)
383 } else {
384 slog.Error("Request errored", "sessionID", sessionID, "error", result.Error.Error())
385 event.Error(result.Error)
386 }
387 } else {
388 slog.Debug("Request completed", "sessionID", sessionID)
389 }
390 a.eventPromptResponded(sessionID, time.Since(startTime).Truncate(time.Second))
391 a.activeRequests.Del(sessionID)
392 cancel()
393 a.Publish(pubsub.CreatedEvent, result)
394 events <- result
395 close(events)
396 }()
397 a.eventPromptSent(sessionID)
398 return events, nil
399}
400
401func (a *agent) processGeneration(ctx context.Context, sessionID, content string, attachmentParts []message.ContentPart) AgentEvent {
402 cfg := config.Get()
403 // List existing messages; if none, start title generation asynchronously.
404 msgs, err := a.messages.List(ctx, sessionID)
405 if err != nil {
406 return a.err(fmt.Errorf("failed to list messages: %w", err))
407 }
408 if len(msgs) == 0 {
409 go func() {
410 defer log.RecoverPanic("agent.Run", func() {
411 slog.Error("panic while generating title")
412 })
413 titleErr := a.generateTitle(ctx, sessionID, content)
414 if titleErr != nil && !errors.Is(titleErr, context.Canceled) && !errors.Is(titleErr, context.DeadlineExceeded) {
415 slog.Error("failed to generate title", "error", titleErr)
416 }
417 }()
418 }
419 session, err := a.sessions.Get(ctx, sessionID)
420 if err != nil {
421 return a.err(fmt.Errorf("failed to get session: %w", err))
422 }
423 if session.SummaryMessageID != "" {
424 summaryMsgInex := -1
425 for i, msg := range msgs {
426 if msg.ID == session.SummaryMessageID {
427 summaryMsgInex = i
428 break
429 }
430 }
431 if summaryMsgInex != -1 {
432 msgs = msgs[summaryMsgInex:]
433 msgs[0].Role = message.User
434 }
435 }
436
437 userMsg, err := a.createUserMessage(ctx, sessionID, content, attachmentParts)
438 if err != nil {
439 return a.err(fmt.Errorf("failed to create user message: %w", err))
440 }
441 // Append the new user message to the conversation history.
442 msgHistory := append(msgs, userMsg)
443
444 for {
445 // Check for cancellation before each iteration
446 select {
447 case <-ctx.Done():
448 return a.err(ctx.Err())
449 default:
450 // Continue processing
451 }
452 agentMessage, toolResults, err := a.streamAndHandleEvents(ctx, sessionID, msgHistory)
453 if err != nil {
454 if errors.Is(err, context.Canceled) {
455 agentMessage.AddFinish(message.FinishReasonCanceled, "Request cancelled", "")
456 a.messages.Update(context.Background(), agentMessage)
457 return a.err(ErrRequestCancelled)
458 }
459 return a.err(fmt.Errorf("failed to process events: %w", err))
460 }
461 if cfg.Options.Debug {
462 slog.Info("Result", "message", agentMessage.FinishReason(), "toolResults", toolResults)
463 }
464 if (agentMessage.FinishReason() == message.FinishReasonToolUse) && toolResults != nil {
465 // We are not done, we need to respond with the tool response
466 msgHistory = append(msgHistory, agentMessage, *toolResults)
467 // If there are queued prompts, process the next one
468 nextPrompt, ok := a.promptQueue.Take(sessionID)
469 if ok {
470 for _, prompt := range nextPrompt {
471 // Create a new user message for the queued prompt
472 userMsg, err := a.createUserMessage(ctx, sessionID, prompt, nil)
473 if err != nil {
474 return a.err(fmt.Errorf("failed to create user message for queued prompt: %w", err))
475 }
476 // Append the new user message to the conversation history
477 msgHistory = append(msgHistory, userMsg)
478 }
479 }
480
481 continue
482 } else if agentMessage.FinishReason() == message.FinishReasonEndTurn {
483 queuePrompts, ok := a.promptQueue.Take(sessionID)
484 if ok {
485 for _, prompt := range queuePrompts {
486 if prompt == "" {
487 continue
488 }
489 userMsg, err := a.createUserMessage(ctx, sessionID, prompt, nil)
490 if err != nil {
491 return a.err(fmt.Errorf("failed to create user message for queued prompt: %w", err))
492 }
493 msgHistory = append(msgHistory, userMsg)
494 }
495 continue
496 }
497 }
498 if agentMessage.FinishReason() == "" {
499 // Kujtim: could not track down where this is happening but this means its cancelled
500 agentMessage.AddFinish(message.FinishReasonCanceled, "Request cancelled", "")
501 _ = a.messages.Update(context.Background(), agentMessage)
502 return a.err(ErrRequestCancelled)
503 }
504 return AgentEvent{
505 Type: AgentEventTypeResponse,
506 Message: agentMessage,
507 Done: true,
508 }
509 }
510}
511
512func (a *agent) createUserMessage(ctx context.Context, sessionID, content string, attachmentParts []message.ContentPart) (message.Message, error) {
513 parts := []message.ContentPart{message.TextContent{Text: content}}
514 parts = append(parts, attachmentParts...)
515 return a.messages.Create(ctx, sessionID, message.CreateMessageParams{
516 Role: message.User,
517 Parts: parts,
518 })
519}
520
521func (a *agent) getAllTools() ([]tools.BaseTool, error) {
522 allTools := slices.Collect(a.tools.Seq())
523 if a.agentToolFn != nil {
524 agentTool, agentToolErr := a.agentToolFn()
525 if agentToolErr != nil {
526 return nil, agentToolErr
527 }
528 allTools = append(allTools, agentTool)
529 }
530 return allTools, nil
531}
532
533func (a *agent) streamAndHandleEvents(ctx context.Context, sessionID string, msgHistory []message.Message) (message.Message, *message.Message, error) {
534 ctx = context.WithValue(ctx, tools.SessionIDContextKey, sessionID)
535
536 // Create the assistant message first so the spinner shows immediately
537 assistantMsg, err := a.messages.Create(ctx, sessionID, message.CreateMessageParams{
538 Role: message.Assistant,
539 Parts: []message.ContentPart{},
540 Model: a.Model().ID,
541 Provider: a.providerID,
542 })
543 if err != nil {
544 return assistantMsg, nil, fmt.Errorf("failed to create assistant message: %w", err)
545 }
546
547 allTools, toolsErr := a.getAllTools()
548 if toolsErr != nil {
549 return assistantMsg, nil, toolsErr
550 }
551 // Now collect tools (which may block on MCP initialization)
552 eventChan := a.provider.StreamResponse(ctx, msgHistory, allTools)
553
554 // Add the session and message ID into the context if needed by tools.
555 ctx = context.WithValue(ctx, tools.MessageIDContextKey, assistantMsg.ID)
556
557 // Process each event in the stream.
558 timer := time.NewTimer(streamChunkTimeout)
559 defer timer.Stop()
560
561loop:
562 for {
563 select {
564 case event, ok := <-eventChan:
565 if !ok {
566 break loop
567 }
568 // Reset the timeout timer since we received a chunk
569 timer.Reset(streamChunkTimeout)
570
571 if processErr := a.processEvent(ctx, sessionID, &assistantMsg, event); processErr != nil {
572 if errors.Is(processErr, context.Canceled) {
573 a.finishMessage(context.Background(), &assistantMsg, message.FinishReasonCanceled, "Request cancelled", "")
574 } else {
575 a.finishMessage(ctx, &assistantMsg, message.FinishReasonError, "API Error", processErr.Error())
576 }
577 return assistantMsg, nil, processErr
578 }
579 case <-timer.C:
580 a.finishMessage(ctx, &assistantMsg, message.FinishReasonError, "Stream timeout", "No chunk received within timeout")
581 return assistantMsg, nil, ErrStreamTimeout
582 case <-ctx.Done():
583 a.finishMessage(context.Background(), &assistantMsg, message.FinishReasonCanceled, "Request cancelled", "")
584 return assistantMsg, nil, ctx.Err()
585 }
586 }
587
588 toolResults := make([]message.ToolResult, len(assistantMsg.ToolCalls()))
589 toolCalls := assistantMsg.ToolCalls()
590 for i, toolCall := range toolCalls {
591 select {
592 case <-ctx.Done():
593 a.finishMessage(context.Background(), &assistantMsg, message.FinishReasonCanceled, "Request cancelled", "")
594 // Make all future tool calls cancelled
595 for j := i; j < len(toolCalls); j++ {
596 toolResults[j] = message.ToolResult{
597 ToolCallID: toolCalls[j].ID,
598 Content: "Tool execution canceled by user",
599 IsError: true,
600 }
601 }
602 goto out
603 default:
604 // Continue processing
605 var tool tools.BaseTool
606 allTools, _ := a.getAllTools()
607 for _, availableTool := range allTools {
608 if availableTool.Info().Name == toolCall.Name {
609 tool = availableTool
610 break
611 }
612 }
613
614 // Tool not found
615 if tool == nil {
616 toolResults[i] = message.ToolResult{
617 ToolCallID: toolCall.ID,
618 Content: fmt.Sprintf("Tool not found: %s", toolCall.Name),
619 IsError: true,
620 }
621 continue
622 }
623
624 // Run tool in goroutine to allow cancellation
625 type toolExecResult struct {
626 response tools.ToolResponse
627 err error
628 }
629 resultChan := make(chan toolExecResult, 1)
630
631 go func() {
632 response, err := tool.Run(ctx, tools.ToolCall{
633 ID: toolCall.ID,
634 Name: toolCall.Name,
635 Input: toolCall.Input,
636 })
637 resultChan <- toolExecResult{response: response, err: err}
638 }()
639
640 var toolResponse tools.ToolResponse
641 var toolErr error
642
643 select {
644 case <-ctx.Done():
645 a.finishMessage(context.Background(), &assistantMsg, message.FinishReasonCanceled, "Request cancelled", "")
646 // Mark remaining tool calls as cancelled
647 for j := i; j < len(toolCalls); j++ {
648 toolResults[j] = message.ToolResult{
649 ToolCallID: toolCalls[j].ID,
650 Content: "Tool execution canceled by user",
651 IsError: true,
652 }
653 }
654 goto out
655 case result := <-resultChan:
656 toolResponse = result.response
657 toolErr = result.err
658 }
659
660 if toolErr != nil {
661 slog.Error("Tool execution error", "toolCall", toolCall.ID, "error", toolErr)
662 if errors.Is(toolErr, permission.ErrorPermissionDenied) {
663 toolResults[i] = message.ToolResult{
664 ToolCallID: toolCall.ID,
665 Content: "Permission denied",
666 IsError: true,
667 }
668 for j := i + 1; j < len(toolCalls); j++ {
669 toolResults[j] = message.ToolResult{
670 ToolCallID: toolCalls[j].ID,
671 Content: "Tool execution canceled by user",
672 IsError: true,
673 }
674 }
675 a.finishMessage(ctx, &assistantMsg, message.FinishReasonPermissionDenied, "Permission denied", "")
676 break
677 }
678 }
679 toolResults[i] = message.ToolResult{
680 ToolCallID: toolCall.ID,
681 Content: toolResponse.Content,
682 Metadata: toolResponse.Metadata,
683 IsError: toolResponse.IsError,
684 }
685 }
686 }
687out:
688 if len(toolResults) == 0 {
689 return assistantMsg, nil, nil
690 }
691 parts := make([]message.ContentPart, 0)
692 for _, tr := range toolResults {
693 parts = append(parts, tr)
694 }
695 msg, err := a.messages.Create(context.Background(), assistantMsg.SessionID, message.CreateMessageParams{
696 Role: message.Tool,
697 Parts: parts,
698 Provider: a.providerID,
699 })
700 if err != nil {
701 return assistantMsg, nil, fmt.Errorf("failed to create cancelled tool message: %w", err)
702 }
703
704 return assistantMsg, &msg, err
705}
706
707func (a *agent) finishMessage(ctx context.Context, msg *message.Message, finishReason message.FinishReason, message, details string) {
708 msg.AddFinish(finishReason, message, details)
709 _ = a.messages.Update(ctx, *msg)
710}
711
712func (a *agent) processEvent(ctx context.Context, sessionID string, assistantMsg *message.Message, event provider.ProviderEvent) error {
713 select {
714 case <-ctx.Done():
715 return ctx.Err()
716 default:
717 // Continue processing.
718 }
719
720 switch event.Type {
721 case provider.EventThinkingDelta:
722 assistantMsg.AppendReasoningContent(event.Thinking)
723 return a.messages.Update(ctx, *assistantMsg)
724 case provider.EventSignatureDelta:
725 assistantMsg.AppendReasoningSignature(event.Signature)
726 return a.messages.Update(ctx, *assistantMsg)
727 case provider.EventContentDelta:
728 assistantMsg.FinishThinking()
729 assistantMsg.AppendContent(event.Content)
730 return a.messages.Update(ctx, *assistantMsg)
731 case provider.EventToolUseStart:
732 assistantMsg.FinishThinking()
733 slog.Info("Tool call started", "toolCall", event.ToolCall)
734 assistantMsg.AddToolCall(*event.ToolCall)
735 return a.messages.Update(ctx, *assistantMsg)
736 case provider.EventToolUseDelta:
737 assistantMsg.AppendToolCallInput(event.ToolCall.ID, event.ToolCall.Input)
738 return a.messages.Update(ctx, *assistantMsg)
739 case provider.EventToolUseStop:
740 slog.Info("Finished tool call", "toolCall", event.ToolCall)
741 assistantMsg.FinishToolCall(event.ToolCall.ID)
742 return a.messages.Update(ctx, *assistantMsg)
743 case provider.EventError:
744 return event.Error
745 case provider.EventComplete:
746 assistantMsg.FinishThinking()
747 assistantMsg.SetToolCalls(event.Response.ToolCalls)
748 assistantMsg.AddFinish(event.Response.FinishReason, "", "")
749 if err := a.messages.Update(ctx, *assistantMsg); err != nil {
750 return fmt.Errorf("failed to update message: %w", err)
751 }
752 return a.trackUsage(ctx, sessionID, a.Model(), event.Response.Usage)
753 }
754
755 return nil
756}
757
758func (a *agent) trackUsage(ctx context.Context, sessionID string, model catwalk.Model, usage provider.TokenUsage) error {
759 sess, err := a.sessions.Get(ctx, sessionID)
760 if err != nil {
761 return fmt.Errorf("failed to get session: %w", err)
762 }
763
764 cost := model.CostPer1MInCached/1e6*float64(usage.CacheCreationTokens) +
765 model.CostPer1MOutCached/1e6*float64(usage.CacheReadTokens) +
766 model.CostPer1MIn/1e6*float64(usage.InputTokens) +
767 model.CostPer1MOut/1e6*float64(usage.OutputTokens)
768
769 a.eventTokensUsed(sessionID, usage, cost)
770
771 sess.Cost += cost
772 sess.CompletionTokens = usage.OutputTokens + usage.CacheReadTokens
773 sess.PromptTokens = usage.InputTokens + usage.CacheCreationTokens
774
775 _, err = a.sessions.Save(ctx, sess)
776 if err != nil {
777 return fmt.Errorf("failed to save session: %w", err)
778 }
779 return nil
780}
781
782func (a *agent) Summarize(ctx context.Context, sessionID string) error {
783 if a.summarizeProvider == nil {
784 return fmt.Errorf("summarize provider not available")
785 }
786
787 // Check if session is busy
788 if a.IsSessionBusy(sessionID) {
789 return ErrSessionBusy
790 }
791
792 // Create a new context with cancellation
793 summarizeCtx, cancel := context.WithCancel(ctx)
794
795 // Store the cancel function in activeRequests to allow cancellation
796 a.activeRequests.Set(sessionID+"-summarize", cancel)
797
798 go func() {
799 defer a.activeRequests.Del(sessionID + "-summarize")
800 defer cancel()
801 event := AgentEvent{
802 Type: AgentEventTypeSummarize,
803 Progress: "Starting summarization...",
804 }
805
806 a.Publish(pubsub.CreatedEvent, event)
807 // Get all messages from the session
808 msgs, err := a.messages.List(summarizeCtx, sessionID)
809 if err != nil {
810 event = AgentEvent{
811 Type: AgentEventTypeError,
812 Error: fmt.Errorf("failed to list messages: %w", err),
813 Done: true,
814 }
815 a.Publish(pubsub.CreatedEvent, event)
816 return
817 }
818 summarizeCtx = context.WithValue(summarizeCtx, tools.SessionIDContextKey, sessionID)
819
820 if len(msgs) == 0 {
821 event = AgentEvent{
822 Type: AgentEventTypeError,
823 Error: fmt.Errorf("no messages to summarize"),
824 Done: true,
825 }
826 a.Publish(pubsub.CreatedEvent, event)
827 return
828 }
829
830 event = AgentEvent{
831 Type: AgentEventTypeSummarize,
832 Progress: "Analyzing conversation...",
833 }
834 a.Publish(pubsub.CreatedEvent, event)
835
836 // Add a system message to guide the summarization
837 summarizePrompt := "Provide a detailed but concise summary of our conversation above. Focus on information that would be helpful for continuing the conversation, including what we did, what we're doing, which files we're working on, and what we're going to do next."
838
839 // Create a new message with the summarize prompt
840 promptMsg := message.Message{
841 Role: message.User,
842 Parts: []message.ContentPart{message.TextContent{Text: summarizePrompt}},
843 }
844
845 // Append the prompt to the messages
846 msgsWithPrompt := append(msgs, promptMsg)
847
848 event = AgentEvent{
849 Type: AgentEventTypeSummarize,
850 Progress: "Generating summary...",
851 }
852
853 a.Publish(pubsub.CreatedEvent, event)
854
855 // Send the messages to the summarize provider
856 response := a.summarizeProvider.StreamResponse(
857 summarizeCtx,
858 msgsWithPrompt,
859 nil,
860 )
861 var finalResponse *provider.ProviderResponse
862 for r := range response {
863 if r.Error != nil {
864 event = AgentEvent{
865 Type: AgentEventTypeError,
866 Error: fmt.Errorf("failed to summarize: %w", r.Error),
867 Done: true,
868 }
869 a.Publish(pubsub.CreatedEvent, event)
870 return
871 }
872 finalResponse = r.Response
873 }
874
875 summary := strings.TrimSpace(finalResponse.Content)
876 if summary == "" {
877 event = AgentEvent{
878 Type: AgentEventTypeError,
879 Error: fmt.Errorf("empty summary returned"),
880 Done: true,
881 }
882 a.Publish(pubsub.CreatedEvent, event)
883 return
884 }
885 shell := shell.GetPersistentShell(config.Get().WorkingDir())
886 summary += "\n\n**Current working directory of the persistent shell**\n\n" + shell.GetWorkingDir()
887 event = AgentEvent{
888 Type: AgentEventTypeSummarize,
889 Progress: "Creating new session...",
890 }
891
892 a.Publish(pubsub.CreatedEvent, event)
893 oldSession, err := a.sessions.Get(summarizeCtx, sessionID)
894 if err != nil {
895 event = AgentEvent{
896 Type: AgentEventTypeError,
897 Error: fmt.Errorf("failed to get session: %w", err),
898 Done: true,
899 }
900
901 a.Publish(pubsub.CreatedEvent, event)
902 return
903 }
904 // Create a message in the new session with the summary
905 msg, err := a.messages.Create(summarizeCtx, oldSession.ID, message.CreateMessageParams{
906 Role: message.Assistant,
907 Parts: []message.ContentPart{
908 message.TextContent{Text: summary},
909 message.Finish{
910 Reason: message.FinishReasonEndTurn,
911 Time: time.Now().Unix(),
912 },
913 },
914 Model: a.summarizeProvider.Model().ID,
915 Provider: a.summarizeProviderID,
916 })
917 if err != nil {
918 event = AgentEvent{
919 Type: AgentEventTypeError,
920 Error: fmt.Errorf("failed to create summary message: %w", err),
921 Done: true,
922 }
923
924 a.Publish(pubsub.CreatedEvent, event)
925 return
926 }
927 oldSession.SummaryMessageID = msg.ID
928 oldSession.CompletionTokens = finalResponse.Usage.OutputTokens
929 oldSession.PromptTokens = 0
930 model := a.summarizeProvider.Model()
931 usage := finalResponse.Usage
932 cost := model.CostPer1MInCached/1e6*float64(usage.CacheCreationTokens) +
933 model.CostPer1MOutCached/1e6*float64(usage.CacheReadTokens) +
934 model.CostPer1MIn/1e6*float64(usage.InputTokens) +
935 model.CostPer1MOut/1e6*float64(usage.OutputTokens)
936 oldSession.Cost += cost
937 _, err = a.sessions.Save(summarizeCtx, oldSession)
938 if err != nil {
939 event = AgentEvent{
940 Type: AgentEventTypeError,
941 Error: fmt.Errorf("failed to save session: %w", err),
942 Done: true,
943 }
944 a.Publish(pubsub.CreatedEvent, event)
945 }
946
947 event = AgentEvent{
948 Type: AgentEventTypeSummarize,
949 SessionID: oldSession.ID,
950 Progress: "Summary complete",
951 Done: true,
952 }
953 a.Publish(pubsub.CreatedEvent, event)
954 // Send final success event with the new session ID
955 }()
956
957 return nil
958}
959
960func (a *agent) ClearQueue(sessionID string) {
961 if a.QueuedPrompts(sessionID) > 0 {
962 slog.Info("Clearing queued prompts", "session_id", sessionID)
963 a.promptQueue.Del(sessionID)
964 }
965}
966
967func (a *agent) CancelAll() {
968 if !a.IsBusy() {
969 return
970 }
971 for key := range a.activeRequests.Seq2() {
972 a.Cancel(key) // key is sessionID
973 }
974
975 timeout := time.After(5 * time.Second)
976 for a.IsBusy() {
977 select {
978 case <-timeout:
979 return
980 default:
981 time.Sleep(200 * time.Millisecond)
982 }
983 }
984}
985
986func (a *agent) UpdateModel() error {
987 cfg := config.Get()
988
989 // Get current provider configuration
990 currentProviderCfg := cfg.GetProviderForModel(a.agentCfg.Model)
991 if currentProviderCfg == nil || currentProviderCfg.ID == "" {
992 return fmt.Errorf("provider for agent %s not found in config", a.agentCfg.Name)
993 }
994
995 // Check if provider has changed
996 if string(currentProviderCfg.ID) != a.providerID {
997 // Provider changed, need to recreate the main provider
998 model := cfg.GetModelByType(a.agentCfg.Model)
999 if model.ID == "" {
1000 return fmt.Errorf("model not found for agent %s", a.agentCfg.Name)
1001 }
1002
1003 promptID := agentPromptMap[a.agentCfg.ID]
1004 if promptID == "" {
1005 promptID = prompt.PromptDefault
1006 }
1007
1008 opts := []provider.ProviderClientOption{
1009 provider.WithModel(a.agentCfg.Model),
1010 provider.WithSystemMessage(prompt.GetPrompt(promptID, currentProviderCfg.ID, cfg.Options.ContextPaths...)),
1011 }
1012
1013 newProvider, err := provider.NewProvider(*currentProviderCfg, opts...)
1014 if err != nil {
1015 return fmt.Errorf("failed to create new provider: %w", err)
1016 }
1017
1018 // Update the provider and provider ID
1019 a.provider = newProvider
1020 a.providerID = string(currentProviderCfg.ID)
1021 }
1022
1023 // Check if providers have changed for title (small) and summarize (large)
1024 smallModelCfg := cfg.Models[config.SelectedModelTypeSmall]
1025 var smallModelProviderCfg config.ProviderConfig
1026 for p := range cfg.Providers.Seq() {
1027 if p.ID == smallModelCfg.Provider {
1028 smallModelProviderCfg = p
1029 break
1030 }
1031 }
1032 if smallModelProviderCfg.ID == "" {
1033 return fmt.Errorf("provider %s not found in config", smallModelCfg.Provider)
1034 }
1035
1036 largeModelCfg := cfg.Models[config.SelectedModelTypeLarge]
1037 var largeModelProviderCfg config.ProviderConfig
1038 for p := range cfg.Providers.Seq() {
1039 if p.ID == largeModelCfg.Provider {
1040 largeModelProviderCfg = p
1041 break
1042 }
1043 }
1044 if largeModelProviderCfg.ID == "" {
1045 return fmt.Errorf("provider %s not found in config", largeModelCfg.Provider)
1046 }
1047
1048 var maxTitleTokens int64 = 40
1049
1050 // if the max output is too low for the gemini provider it won't return anything
1051 if smallModelCfg.Provider == "gemini" {
1052 maxTitleTokens = 1000
1053 }
1054 // Recreate title provider
1055 titleOpts := []provider.ProviderClientOption{
1056 provider.WithModel(config.SelectedModelTypeSmall),
1057 provider.WithSystemMessage(prompt.GetPrompt(prompt.PromptTitle, smallModelProviderCfg.ID)),
1058 provider.WithMaxTokens(maxTitleTokens),
1059 }
1060 newTitleProvider, err := provider.NewProvider(smallModelProviderCfg, titleOpts...)
1061 if err != nil {
1062 return fmt.Errorf("failed to create new title provider: %w", err)
1063 }
1064 a.titleProvider = newTitleProvider
1065
1066 // Recreate summarize provider if provider changed (now large model)
1067 if string(largeModelProviderCfg.ID) != a.summarizeProviderID {
1068 largeModel := cfg.GetModelByType(config.SelectedModelTypeLarge)
1069 if largeModel == nil {
1070 return fmt.Errorf("model %s not found in provider %s", largeModelCfg.Model, largeModelProviderCfg.ID)
1071 }
1072 summarizeOpts := []provider.ProviderClientOption{
1073 provider.WithModel(config.SelectedModelTypeLarge),
1074 provider.WithSystemMessage(prompt.GetPrompt(prompt.PromptSummarizer, largeModelProviderCfg.ID)),
1075 }
1076 newSummarizeProvider, err := provider.NewProvider(largeModelProviderCfg, summarizeOpts...)
1077 if err != nil {
1078 return fmt.Errorf("failed to create new summarize provider: %w", err)
1079 }
1080 a.summarizeProvider = newSummarizeProvider
1081 a.summarizeProviderID = string(largeModelProviderCfg.ID)
1082 }
1083
1084 return nil
1085}