1package agent
2
3import (
4 "context"
5 "errors"
6 "fmt"
7 "log/slog"
8 "slices"
9 "strings"
10 "time"
11
12 "github.com/charmbracelet/catwalk/pkg/catwalk"
13 "github.com/charmbracelet/crush/internal/config"
14 "github.com/charmbracelet/crush/internal/csync"
15 "github.com/charmbracelet/crush/internal/event"
16 "github.com/charmbracelet/crush/internal/history"
17 "github.com/charmbracelet/crush/internal/llm/prompt"
18 "github.com/charmbracelet/crush/internal/llm/provider"
19 "github.com/charmbracelet/crush/internal/llm/tools"
20 "github.com/charmbracelet/crush/internal/log"
21 "github.com/charmbracelet/crush/internal/lsp"
22 "github.com/charmbracelet/crush/internal/message"
23 "github.com/charmbracelet/crush/internal/permission"
24 "github.com/charmbracelet/crush/internal/pubsub"
25 "github.com/charmbracelet/crush/internal/session"
26 "github.com/charmbracelet/crush/internal/shell"
27)
28
// streamChunkTimeout is the longest pause tolerated between two consecutive
// provider stream events before the stream is abandoned with ErrStreamTimeout.
const streamChunkTimeout time.Duration = 120 * time.Second
30
// AgentEventType tells subscribers which kind of AgentEvent they received.
type AgentEventType string

// Known AgentEventType values.
const (
	// AgentEventTypeResponse carries the final assistant message of a Run.
	AgentEventTypeResponse AgentEventType = "response"
	// AgentEventTypeError reports a failure; the cause is in AgentEvent.Error.
	AgentEventTypeError AgentEventType = "error"
	// AgentEventTypeSummarize reports progress or completion of Summarize.
	AgentEventTypeSummarize AgentEventType = "summarize"
)
38
// AgentEvent is published through the agent's pubsub broker and sent on the
// channel returned by Run.
type AgentEvent struct {
	// Type determines which of the remaining fields are meaningful.
	Type AgentEventType
	// Message is the final assistant message of a successful Run.
	Message message.Message
	// Error is set on AgentEventTypeError events.
	Error error

	// When summarizing
	// SessionID is the summarized session (set on the completion event).
	SessionID string
	// Progress is a human-readable status line while summarizing.
	Progress string
	// Done marks the final event of a Run or Summarize operation.
	Done bool
}
49
50type Service interface {
51 pubsub.Suscriber[AgentEvent]
52 Model() catwalk.Model
53 Run(ctx context.Context, sessionID string, content string, attachments ...message.Attachment) (<-chan AgentEvent, error)
54 Cancel(sessionID string)
55 CancelAll()
56 IsSessionBusy(sessionID string) bool
57 IsBusy() bool
58 Summarize(ctx context.Context, sessionID string) error
59 UpdateModel() error
60 QueuedPrompts(sessionID string) int
61 ClearQueue(sessionID string)
62}
63
// agent is the Service implementation returned by NewAgent.
type agent struct {
	*pubsub.Broker[AgentEvent]
	agentCfg    config.Agent
	sessions    session.Service
	messages    message.Service
	permissions permission.Service
	// NOTE(review): this field is never assigned by NewAgent; the tool
	// initializer uses package-level MCP state instead. Possibly dead.
	mcpTools []McpTool

	// tools holds the agent's own tools; the slice is built on first use.
	tools *csync.LazySlice[tools.BaseTool]
	// We need this to be able to update it when model changes
	agentToolFn func() (tools.BaseTool, error)

	// provider serves the agent's main model. providerID is the ID of its
	// provider config; UpdateModel compares it to detect provider changes.
	provider   provider.Provider
	providerID string

	// titleProvider generates session titles using the small model.
	titleProvider provider.Provider
	// summarizeProvider produces conversation summaries; summarizeProviderID
	// is the ID of its provider config.
	summarizeProvider   provider.Provider
	summarizeProviderID string

	// activeRequests maps a session ID (or "<sessionID>-summarize") to the
	// cancel function of the request running for it.
	activeRequests *csync.Map[string, context.CancelFunc]
	// promptQueue holds prompts submitted while their session was busy.
	promptQueue *csync.Map[string, []string]
}
86
87var agentPromptMap = map[string]prompt.PromptID{
88 "coder": prompt.PromptCoder,
89 "task": prompt.PromptTask,
90}
91
92func NewAgent(
93 ctx context.Context,
94 agentCfg config.Agent,
95 // These services are needed in the tools
96 permissions permission.Service,
97 sessions session.Service,
98 messages message.Service,
99 history history.Service,
100 lspClients *csync.Map[string, *lsp.Client],
101) (Service, error) {
102 cfg := config.Get()
103
104 var agentToolFn func() (tools.BaseTool, error)
105 if agentCfg.ID == "coder" && slices.Contains(agentCfg.AllowedTools, AgentToolName) {
106 agentToolFn = func() (tools.BaseTool, error) {
107 taskAgentCfg := config.Get().Agents["task"]
108 if taskAgentCfg.ID == "" {
109 return nil, fmt.Errorf("task agent not found in config")
110 }
111 taskAgent, err := NewAgent(ctx, taskAgentCfg, permissions, sessions, messages, history, lspClients)
112 if err != nil {
113 return nil, fmt.Errorf("failed to create task agent: %w", err)
114 }
115 return NewAgentTool(taskAgent, sessions, messages), nil
116 }
117 }
118
119 providerCfg := config.Get().GetProviderForModel(agentCfg.Model)
120 if providerCfg == nil {
121 return nil, fmt.Errorf("provider for agent %s not found in config", agentCfg.Name)
122 }
123 model := config.Get().GetModelByType(agentCfg.Model)
124
125 if model == nil {
126 return nil, fmt.Errorf("model not found for agent %s", agentCfg.Name)
127 }
128
129 promptID := agentPromptMap[agentCfg.ID]
130 if promptID == "" {
131 promptID = prompt.PromptDefault
132 }
133 opts := []provider.ProviderClientOption{
134 provider.WithModel(agentCfg.Model),
135 provider.WithSystemMessage(prompt.GetPrompt(promptID, providerCfg.ID, config.Get().Options.ContextPaths...)),
136 }
137 agentProvider, err := provider.NewProvider(*providerCfg, opts...)
138 if err != nil {
139 return nil, err
140 }
141
142 smallModelCfg := cfg.Models[config.SelectedModelTypeSmall]
143 var smallModelProviderCfg *config.ProviderConfig
144 if smallModelCfg.Provider == providerCfg.ID {
145 smallModelProviderCfg = providerCfg
146 } else {
147 smallModelProviderCfg = cfg.GetProviderForModel(config.SelectedModelTypeSmall)
148
149 if smallModelProviderCfg.ID == "" {
150 return nil, fmt.Errorf("provider %s not found in config", smallModelCfg.Provider)
151 }
152 }
153 smallModel := cfg.GetModelByType(config.SelectedModelTypeSmall)
154 if smallModel.ID == "" {
155 return nil, fmt.Errorf("model %s not found in provider %s", smallModelCfg.Model, smallModelProviderCfg.ID)
156 }
157
158 titleOpts := []provider.ProviderClientOption{
159 provider.WithModel(config.SelectedModelTypeSmall),
160 provider.WithSystemMessage(prompt.GetPrompt(prompt.PromptTitle, smallModelProviderCfg.ID)),
161 }
162 titleProvider, err := provider.NewProvider(*smallModelProviderCfg, titleOpts...)
163 if err != nil {
164 return nil, err
165 }
166
167 summarizeOpts := []provider.ProviderClientOption{
168 provider.WithModel(config.SelectedModelTypeLarge),
169 provider.WithSystemMessage(prompt.GetPrompt(prompt.PromptSummarizer, providerCfg.ID)),
170 }
171 summarizeProvider, err := provider.NewProvider(*providerCfg, summarizeOpts...)
172 if err != nil {
173 return nil, err
174 }
175
176 toolFn := func() []tools.BaseTool {
177 slog.Info("Initializing agent tools", "agent", agentCfg.ID)
178 defer func() {
179 slog.Info("Initialized agent tools", "agent", agentCfg.ID)
180 }()
181
182 cwd := cfg.WorkingDir()
183 allTools := []tools.BaseTool{
184 tools.NewBashTool(permissions, cwd, cfg.Options.Attribution),
185 tools.NewDownloadTool(permissions, cwd),
186 tools.NewEditTool(lspClients, permissions, history, cwd),
187 tools.NewMultiEditTool(lspClients, permissions, history, cwd),
188 tools.NewFetchTool(permissions, cwd),
189 tools.NewGlobTool(cwd),
190 tools.NewGrepTool(cwd),
191 tools.NewLsTool(permissions, cwd),
192 tools.NewSourcegraphTool(),
193 tools.NewViewTool(lspClients, permissions, cwd),
194 tools.NewWriteTool(lspClients, permissions, history, cwd),
195 }
196
197 mcpToolsOnce.Do(func() {
198 mcpTools = doGetMCPTools(ctx, permissions, cfg)
199 })
200
201 withCoderTools := func(t []tools.BaseTool) []tools.BaseTool {
202 if agentCfg.ID == "coder" {
203 t = append(t, mcpTools...)
204 if lspClients.Len() > 0 {
205 t = append(t, tools.NewDiagnosticsTool(lspClients))
206 }
207 }
208 return t
209 }
210
211 if agentCfg.AllowedTools == nil {
212 return withCoderTools(allTools)
213 }
214
215 var filteredTools []tools.BaseTool
216 for _, tool := range allTools {
217 if slices.Contains(agentCfg.AllowedTools, tool.Name()) {
218 filteredTools = append(filteredTools, tool)
219 }
220 }
221 return withCoderTools(filteredTools)
222 }
223
224 return &agent{
225 Broker: pubsub.NewBroker[AgentEvent](),
226 agentCfg: agentCfg,
227 provider: agentProvider,
228 providerID: string(providerCfg.ID),
229 messages: messages,
230 sessions: sessions,
231 titleProvider: titleProvider,
232 summarizeProvider: summarizeProvider,
233 summarizeProviderID: string(providerCfg.ID),
234 agentToolFn: agentToolFn,
235 activeRequests: csync.NewMap[string, context.CancelFunc](),
236 tools: csync.NewLazySlice(toolFn),
237 promptQueue: csync.NewMap[string, []string](),
238 permissions: permissions,
239 }, nil
240}
241
242func (a *agent) Model() catwalk.Model {
243 return *config.Get().GetModelByType(a.agentCfg.Model)
244}
245
246func (a *agent) Cancel(sessionID string) {
247 // Cancel regular requests
248 if cancel, ok := a.activeRequests.Take(sessionID); ok && cancel != nil {
249 slog.Info("Request cancellation initiated", "session_id", sessionID)
250 cancel()
251 }
252
253 // Also check for summarize requests
254 if cancel, ok := a.activeRequests.Take(sessionID + "-summarize"); ok && cancel != nil {
255 slog.Info("Summarize cancellation initiated", "session_id", sessionID)
256 cancel()
257 }
258
259 if a.QueuedPrompts(sessionID) > 0 {
260 slog.Info("Clearing queued prompts", "session_id", sessionID)
261 a.promptQueue.Del(sessionID)
262 }
263}
264
265func (a *agent) IsBusy() bool {
266 var busy bool
267 for cancelFunc := range a.activeRequests.Seq() {
268 if cancelFunc != nil {
269 busy = true
270 break
271 }
272 }
273 return busy
274}
275
276func (a *agent) IsSessionBusy(sessionID string) bool {
277 _, busy := a.activeRequests.Get(sessionID)
278 return busy
279}
280
281func (a *agent) QueuedPrompts(sessionID string) int {
282 l, ok := a.promptQueue.Get(sessionID)
283 if !ok {
284 return 0
285 }
286 return len(l)
287}
288
289func (a *agent) generateTitle(ctx context.Context, sessionID string, content string) error {
290 if content == "" {
291 return nil
292 }
293 if a.titleProvider == nil {
294 return nil
295 }
296 session, err := a.sessions.Get(ctx, sessionID)
297 if err != nil {
298 return err
299 }
300 parts := []message.ContentPart{message.TextContent{
301 Text: fmt.Sprintf("Generate a concise title for the following content:\n\n%s", content),
302 }}
303
304 // Use streaming approach like summarization
305 response := a.titleProvider.StreamResponse(
306 ctx,
307 []message.Message{
308 {
309 Role: message.User,
310 Parts: parts,
311 },
312 },
313 nil,
314 )
315
316 var finalResponse *provider.ProviderResponse
317 for r := range response {
318 if r.Error != nil {
319 return r.Error
320 }
321 finalResponse = r.Response
322 }
323
324 if finalResponse == nil {
325 return fmt.Errorf("no response received from title provider")
326 }
327
328 title := strings.ReplaceAll(finalResponse.Content, "\n", " ")
329
330 if idx := strings.Index(title, "</think>"); idx > 0 {
331 title = title[idx+len("</think>"):]
332 }
333
334 title = strings.TrimSpace(title)
335 if title == "" {
336 return nil
337 }
338
339 session.Title = title
340 _, err = a.sessions.Save(ctx, session)
341 return err
342}
343
344func (a *agent) err(err error) AgentEvent {
345 return AgentEvent{
346 Type: AgentEventTypeError,
347 Error: err,
348 }
349}
350
351func (a *agent) Run(ctx context.Context, sessionID string, content string, attachments ...message.Attachment) (<-chan AgentEvent, error) {
352 if !a.Model().SupportsImages && attachments != nil {
353 attachments = nil
354 }
355 events := make(chan AgentEvent, 1)
356 if a.IsSessionBusy(sessionID) {
357 existing, ok := a.promptQueue.Get(sessionID)
358 if !ok {
359 existing = []string{}
360 }
361 existing = append(existing, content)
362 a.promptQueue.Set(sessionID, existing)
363 return nil, nil
364 }
365
366 genCtx, cancel := context.WithCancel(ctx)
367 a.activeRequests.Set(sessionID, cancel)
368
369 go func() {
370 slog.Debug("Request started", "sessionID", sessionID)
371 defer log.RecoverPanic("agent.Run", func() {
372 events <- a.err(fmt.Errorf("panic while running the agent"))
373 })
374 var attachmentParts []message.ContentPart
375 for _, attachment := range attachments {
376 attachmentParts = append(attachmentParts, message.BinaryContent{Path: attachment.FilePath, MIMEType: attachment.MimeType, Data: attachment.Content})
377 }
378 result := a.processGeneration(genCtx, sessionID, content, attachmentParts)
379 if result.Error != nil {
380 if isCancelledErr(result.Error) {
381 slog.Error("Request canceled", "sessionID", sessionID)
382 } else {
383 slog.Error("Request errored", "sessionID", sessionID, "error", result.Error.Error())
384 event.Error(result.Error)
385 }
386 } else {
387 slog.Debug("Request completed", "sessionID", sessionID)
388 }
389 a.activeRequests.Del(sessionID)
390 cancel()
391 a.Publish(pubsub.CreatedEvent, result)
392 events <- result
393 close(events)
394 }()
395 a.eventPromptSent(sessionID)
396 return events, nil
397}
398
399func (a *agent) processGeneration(ctx context.Context, sessionID, content string, attachmentParts []message.ContentPart) AgentEvent {
400 cfg := config.Get()
401 // List existing messages; if none, start title generation asynchronously.
402 msgs, err := a.messages.List(ctx, sessionID)
403 if err != nil {
404 return a.err(fmt.Errorf("failed to list messages: %w", err))
405 }
406 if len(msgs) == 0 {
407 go func() {
408 defer log.RecoverPanic("agent.Run", func() {
409 slog.Error("panic while generating title")
410 })
411 titleErr := a.generateTitle(ctx, sessionID, content)
412 if titleErr != nil && !errors.Is(titleErr, context.Canceled) && !errors.Is(titleErr, context.DeadlineExceeded) {
413 slog.Error("failed to generate title", "error", titleErr)
414 }
415 }()
416 }
417 session, err := a.sessions.Get(ctx, sessionID)
418 if err != nil {
419 return a.err(fmt.Errorf("failed to get session: %w", err))
420 }
421 if session.SummaryMessageID != "" {
422 summaryMsgInex := -1
423 for i, msg := range msgs {
424 if msg.ID == session.SummaryMessageID {
425 summaryMsgInex = i
426 break
427 }
428 }
429 if summaryMsgInex != -1 {
430 msgs = msgs[summaryMsgInex:]
431 msgs[0].Role = message.User
432 }
433 }
434
435 userMsg, err := a.createUserMessage(ctx, sessionID, content, attachmentParts)
436 if err != nil {
437 return a.err(fmt.Errorf("failed to create user message: %w", err))
438 }
439 // Append the new user message to the conversation history.
440 msgHistory := append(msgs, userMsg)
441
442 for {
443 // Check for cancellation before each iteration
444 select {
445 case <-ctx.Done():
446 return a.err(ctx.Err())
447 default:
448 // Continue processing
449 }
450 agentMessage, toolResults, err := a.streamAndHandleEvents(ctx, sessionID, msgHistory)
451 if err != nil {
452 if errors.Is(err, context.Canceled) {
453 agentMessage.AddFinish(message.FinishReasonCanceled, "Request cancelled", "")
454 a.messages.Update(context.Background(), agentMessage)
455 return a.err(ErrRequestCancelled)
456 }
457 return a.err(fmt.Errorf("failed to process events: %w", err))
458 }
459 if cfg.Options.Debug {
460 slog.Info("Result", "message", agentMessage.FinishReason(), "toolResults", toolResults)
461 }
462 if (agentMessage.FinishReason() == message.FinishReasonToolUse) && toolResults != nil {
463 // We are not done, we need to respond with the tool response
464 msgHistory = append(msgHistory, agentMessage, *toolResults)
465 // If there are queued prompts, process the next one
466 nextPrompt, ok := a.promptQueue.Take(sessionID)
467 if ok {
468 for _, prompt := range nextPrompt {
469 // Create a new user message for the queued prompt
470 userMsg, err := a.createUserMessage(ctx, sessionID, prompt, nil)
471 if err != nil {
472 return a.err(fmt.Errorf("failed to create user message for queued prompt: %w", err))
473 }
474 // Append the new user message to the conversation history
475 msgHistory = append(msgHistory, userMsg)
476 }
477 }
478
479 continue
480 } else if agentMessage.FinishReason() == message.FinishReasonEndTurn {
481 queuePrompts, ok := a.promptQueue.Take(sessionID)
482 if ok {
483 for _, prompt := range queuePrompts {
484 if prompt == "" {
485 continue
486 }
487 userMsg, err := a.createUserMessage(ctx, sessionID, prompt, nil)
488 if err != nil {
489 return a.err(fmt.Errorf("failed to create user message for queued prompt: %w", err))
490 }
491 msgHistory = append(msgHistory, userMsg)
492 }
493 continue
494 }
495 }
496 if agentMessage.FinishReason() == "" {
497 // Kujtim: could not track down where this is happening but this means its cancelled
498 agentMessage.AddFinish(message.FinishReasonCanceled, "Request cancelled", "")
499 _ = a.messages.Update(context.Background(), agentMessage)
500 return a.err(ErrRequestCancelled)
501 }
502 return AgentEvent{
503 Type: AgentEventTypeResponse,
504 Message: agentMessage,
505 Done: true,
506 }
507 }
508}
509
510func (a *agent) createUserMessage(ctx context.Context, sessionID, content string, attachmentParts []message.ContentPart) (message.Message, error) {
511 parts := []message.ContentPart{message.TextContent{Text: content}}
512 parts = append(parts, attachmentParts...)
513 return a.messages.Create(ctx, sessionID, message.CreateMessageParams{
514 Role: message.User,
515 Parts: parts,
516 })
517}
518
519func (a *agent) getAllTools() ([]tools.BaseTool, error) {
520 allTools := slices.Collect(a.tools.Seq())
521 if a.agentToolFn != nil {
522 agentTool, agentToolErr := a.agentToolFn()
523 if agentToolErr != nil {
524 return nil, agentToolErr
525 }
526 allTools = append(allTools, agentTool)
527 }
528 return allTools, nil
529}
530
531func (a *agent) streamAndHandleEvents(ctx context.Context, sessionID string, msgHistory []message.Message) (message.Message, *message.Message, error) {
532 ctx = context.WithValue(ctx, tools.SessionIDContextKey, sessionID)
533
534 // Create the assistant message first so the spinner shows immediately
535 assistantMsg, err := a.messages.Create(ctx, sessionID, message.CreateMessageParams{
536 Role: message.Assistant,
537 Parts: []message.ContentPart{},
538 Model: a.Model().ID,
539 Provider: a.providerID,
540 })
541 if err != nil {
542 return assistantMsg, nil, fmt.Errorf("failed to create assistant message: %w", err)
543 }
544
545 allTools, toolsErr := a.getAllTools()
546 if toolsErr != nil {
547 return assistantMsg, nil, toolsErr
548 }
549 // Now collect tools (which may block on MCP initialization)
550 eventChan := a.provider.StreamResponse(ctx, msgHistory, allTools)
551
552 // Add the session and message ID into the context if needed by tools.
553 ctx = context.WithValue(ctx, tools.MessageIDContextKey, assistantMsg.ID)
554
555 // Process each event in the stream.
556 timer := time.NewTimer(streamChunkTimeout)
557 defer timer.Stop()
558
559loop:
560 for {
561 select {
562 case event, ok := <-eventChan:
563 if !ok {
564 break loop
565 }
566 // Reset the timeout timer since we received a chunk
567 timer.Reset(streamChunkTimeout)
568
569 if processErr := a.processEvent(ctx, sessionID, &assistantMsg, event); processErr != nil {
570 if errors.Is(processErr, context.Canceled) {
571 a.finishMessage(context.Background(), &assistantMsg, message.FinishReasonCanceled, "Request cancelled", "")
572 } else {
573 a.finishMessage(ctx, &assistantMsg, message.FinishReasonError, "API Error", processErr.Error())
574 }
575 return assistantMsg, nil, processErr
576 }
577 case <-timer.C:
578 a.finishMessage(ctx, &assistantMsg, message.FinishReasonError, "Stream timeout", "No chunk received within timeout")
579 return assistantMsg, nil, ErrStreamTimeout
580 case <-ctx.Done():
581 a.finishMessage(context.Background(), &assistantMsg, message.FinishReasonCanceled, "Request cancelled", "")
582 return assistantMsg, nil, ctx.Err()
583 }
584 }
585
586 toolResults := make([]message.ToolResult, len(assistantMsg.ToolCalls()))
587 toolCalls := assistantMsg.ToolCalls()
588 for i, toolCall := range toolCalls {
589 select {
590 case <-ctx.Done():
591 a.finishMessage(context.Background(), &assistantMsg, message.FinishReasonCanceled, "Request cancelled", "")
592 // Make all future tool calls cancelled
593 for j := i; j < len(toolCalls); j++ {
594 toolResults[j] = message.ToolResult{
595 ToolCallID: toolCalls[j].ID,
596 Content: "Tool execution canceled by user",
597 IsError: true,
598 }
599 }
600 goto out
601 default:
602 // Continue processing
603 var tool tools.BaseTool
604 allTools, _ := a.getAllTools()
605 for _, availableTool := range allTools {
606 if availableTool.Info().Name == toolCall.Name {
607 tool = availableTool
608 break
609 }
610 }
611
612 // Tool not found
613 if tool == nil {
614 toolResults[i] = message.ToolResult{
615 ToolCallID: toolCall.ID,
616 Content: fmt.Sprintf("Tool not found: %s", toolCall.Name),
617 IsError: true,
618 }
619 continue
620 }
621
622 // Run tool in goroutine to allow cancellation
623 type toolExecResult struct {
624 response tools.ToolResponse
625 err error
626 }
627 resultChan := make(chan toolExecResult, 1)
628
629 go func() {
630 response, err := tool.Run(ctx, tools.ToolCall{
631 ID: toolCall.ID,
632 Name: toolCall.Name,
633 Input: toolCall.Input,
634 })
635 resultChan <- toolExecResult{response: response, err: err}
636 }()
637
638 var toolResponse tools.ToolResponse
639 var toolErr error
640
641 select {
642 case <-ctx.Done():
643 a.finishMessage(context.Background(), &assistantMsg, message.FinishReasonCanceled, "Request cancelled", "")
644 // Mark remaining tool calls as cancelled
645 for j := i; j < len(toolCalls); j++ {
646 toolResults[j] = message.ToolResult{
647 ToolCallID: toolCalls[j].ID,
648 Content: "Tool execution canceled by user",
649 IsError: true,
650 }
651 }
652 goto out
653 case result := <-resultChan:
654 toolResponse = result.response
655 toolErr = result.err
656 }
657
658 if toolErr != nil {
659 slog.Error("Tool execution error", "toolCall", toolCall.ID, "error", toolErr)
660 if errors.Is(toolErr, permission.ErrorPermissionDenied) {
661 toolResults[i] = message.ToolResult{
662 ToolCallID: toolCall.ID,
663 Content: "Permission denied",
664 IsError: true,
665 }
666 for j := i + 1; j < len(toolCalls); j++ {
667 toolResults[j] = message.ToolResult{
668 ToolCallID: toolCalls[j].ID,
669 Content: "Tool execution canceled by user",
670 IsError: true,
671 }
672 }
673 a.finishMessage(ctx, &assistantMsg, message.FinishReasonPermissionDenied, "Permission denied", "")
674 break
675 }
676 }
677 toolResults[i] = message.ToolResult{
678 ToolCallID: toolCall.ID,
679 Content: toolResponse.Content,
680 Metadata: toolResponse.Metadata,
681 IsError: toolResponse.IsError,
682 }
683 }
684 }
685out:
686 if len(toolResults) == 0 {
687 return assistantMsg, nil, nil
688 }
689 parts := make([]message.ContentPart, 0)
690 for _, tr := range toolResults {
691 parts = append(parts, tr)
692 }
693 msg, err := a.messages.Create(context.Background(), assistantMsg.SessionID, message.CreateMessageParams{
694 Role: message.Tool,
695 Parts: parts,
696 Provider: a.providerID,
697 })
698 if err != nil {
699 return assistantMsg, nil, fmt.Errorf("failed to create cancelled tool message: %w", err)
700 }
701
702 return assistantMsg, &msg, err
703}
704
705func (a *agent) finishMessage(ctx context.Context, msg *message.Message, finishReason message.FinishReason, message, details string) {
706 msg.AddFinish(finishReason, message, details)
707 _ = a.messages.Update(ctx, *msg)
708}
709
710func (a *agent) processEvent(ctx context.Context, sessionID string, assistantMsg *message.Message, event provider.ProviderEvent) error {
711 select {
712 case <-ctx.Done():
713 return ctx.Err()
714 default:
715 // Continue processing.
716 }
717
718 switch event.Type {
719 case provider.EventThinkingDelta:
720 assistantMsg.AppendReasoningContent(event.Thinking)
721 return a.messages.Update(ctx, *assistantMsg)
722 case provider.EventSignatureDelta:
723 assistantMsg.AppendReasoningSignature(event.Signature)
724 return a.messages.Update(ctx, *assistantMsg)
725 case provider.EventContentDelta:
726 assistantMsg.FinishThinking()
727 assistantMsg.AppendContent(event.Content)
728 return a.messages.Update(ctx, *assistantMsg)
729 case provider.EventToolUseStart:
730 assistantMsg.FinishThinking()
731 slog.Info("Tool call started", "toolCall", event.ToolCall)
732 assistantMsg.AddToolCall(*event.ToolCall)
733 return a.messages.Update(ctx, *assistantMsg)
734 case provider.EventToolUseDelta:
735 assistantMsg.AppendToolCallInput(event.ToolCall.ID, event.ToolCall.Input)
736 return a.messages.Update(ctx, *assistantMsg)
737 case provider.EventToolUseStop:
738 slog.Info("Finished tool call", "toolCall", event.ToolCall)
739 assistantMsg.FinishToolCall(event.ToolCall.ID)
740 return a.messages.Update(ctx, *assistantMsg)
741 case provider.EventError:
742 return event.Error
743 case provider.EventComplete:
744 assistantMsg.FinishThinking()
745 assistantMsg.SetToolCalls(event.Response.ToolCalls)
746 assistantMsg.AddFinish(event.Response.FinishReason, "", "")
747 if err := a.messages.Update(ctx, *assistantMsg); err != nil {
748 return fmt.Errorf("failed to update message: %w", err)
749 }
750 return a.trackUsage(ctx, sessionID, a.Model(), event.Response.Usage)
751 }
752
753 return nil
754}
755
756func (a *agent) trackUsage(ctx context.Context, sessionID string, model catwalk.Model, usage provider.TokenUsage) error {
757 sess, err := a.sessions.Get(ctx, sessionID)
758 if err != nil {
759 return fmt.Errorf("failed to get session: %w", err)
760 }
761
762 cost := model.CostPer1MInCached/1e6*float64(usage.CacheCreationTokens) +
763 model.CostPer1MOutCached/1e6*float64(usage.CacheReadTokens) +
764 model.CostPer1MIn/1e6*float64(usage.InputTokens) +
765 model.CostPer1MOut/1e6*float64(usage.OutputTokens)
766
767 a.eventTokensUsed(sessionID, usage, cost)
768
769 sess.Cost += cost
770 sess.CompletionTokens = usage.OutputTokens + usage.CacheReadTokens
771 sess.PromptTokens = usage.InputTokens + usage.CacheCreationTokens
772
773 _, err = a.sessions.Save(ctx, sess)
774 if err != nil {
775 return fmt.Errorf("failed to save session: %w", err)
776 }
777 return nil
778}
779
780func (a *agent) Summarize(ctx context.Context, sessionID string) error {
781 if a.summarizeProvider == nil {
782 return fmt.Errorf("summarize provider not available")
783 }
784
785 // Check if session is busy
786 if a.IsSessionBusy(sessionID) {
787 return ErrSessionBusy
788 }
789
790 // Create a new context with cancellation
791 summarizeCtx, cancel := context.WithCancel(ctx)
792
793 // Store the cancel function in activeRequests to allow cancellation
794 a.activeRequests.Set(sessionID+"-summarize", cancel)
795
796 go func() {
797 defer a.activeRequests.Del(sessionID + "-summarize")
798 defer cancel()
799 event := AgentEvent{
800 Type: AgentEventTypeSummarize,
801 Progress: "Starting summarization...",
802 }
803
804 a.Publish(pubsub.CreatedEvent, event)
805 // Get all messages from the session
806 msgs, err := a.messages.List(summarizeCtx, sessionID)
807 if err != nil {
808 event = AgentEvent{
809 Type: AgentEventTypeError,
810 Error: fmt.Errorf("failed to list messages: %w", err),
811 Done: true,
812 }
813 a.Publish(pubsub.CreatedEvent, event)
814 return
815 }
816 summarizeCtx = context.WithValue(summarizeCtx, tools.SessionIDContextKey, sessionID)
817
818 if len(msgs) == 0 {
819 event = AgentEvent{
820 Type: AgentEventTypeError,
821 Error: fmt.Errorf("no messages to summarize"),
822 Done: true,
823 }
824 a.Publish(pubsub.CreatedEvent, event)
825 return
826 }
827
828 event = AgentEvent{
829 Type: AgentEventTypeSummarize,
830 Progress: "Analyzing conversation...",
831 }
832 a.Publish(pubsub.CreatedEvent, event)
833
834 // Add a system message to guide the summarization
835 summarizePrompt := "Provide a detailed but concise summary of our conversation above. Focus on information that would be helpful for continuing the conversation, including what we did, what we're doing, which files we're working on, and what we're going to do next."
836
837 // Create a new message with the summarize prompt
838 promptMsg := message.Message{
839 Role: message.User,
840 Parts: []message.ContentPart{message.TextContent{Text: summarizePrompt}},
841 }
842
843 // Append the prompt to the messages
844 msgsWithPrompt := append(msgs, promptMsg)
845
846 event = AgentEvent{
847 Type: AgentEventTypeSummarize,
848 Progress: "Generating summary...",
849 }
850
851 a.Publish(pubsub.CreatedEvent, event)
852
853 // Send the messages to the summarize provider
854 response := a.summarizeProvider.StreamResponse(
855 summarizeCtx,
856 msgsWithPrompt,
857 nil,
858 )
859 var finalResponse *provider.ProviderResponse
860 for r := range response {
861 if r.Error != nil {
862 event = AgentEvent{
863 Type: AgentEventTypeError,
864 Error: fmt.Errorf("failed to summarize: %w", r.Error),
865 Done: true,
866 }
867 a.Publish(pubsub.CreatedEvent, event)
868 return
869 }
870 finalResponse = r.Response
871 }
872
873 summary := strings.TrimSpace(finalResponse.Content)
874 if summary == "" {
875 event = AgentEvent{
876 Type: AgentEventTypeError,
877 Error: fmt.Errorf("empty summary returned"),
878 Done: true,
879 }
880 a.Publish(pubsub.CreatedEvent, event)
881 return
882 }
883 shell := shell.GetPersistentShell(config.Get().WorkingDir())
884 summary += "\n\n**Current working directory of the persistent shell**\n\n" + shell.GetWorkingDir()
885 event = AgentEvent{
886 Type: AgentEventTypeSummarize,
887 Progress: "Creating new session...",
888 }
889
890 a.Publish(pubsub.CreatedEvent, event)
891 oldSession, err := a.sessions.Get(summarizeCtx, sessionID)
892 if err != nil {
893 event = AgentEvent{
894 Type: AgentEventTypeError,
895 Error: fmt.Errorf("failed to get session: %w", err),
896 Done: true,
897 }
898
899 a.Publish(pubsub.CreatedEvent, event)
900 return
901 }
902 // Create a message in the new session with the summary
903 msg, err := a.messages.Create(summarizeCtx, oldSession.ID, message.CreateMessageParams{
904 Role: message.Assistant,
905 Parts: []message.ContentPart{
906 message.TextContent{Text: summary},
907 message.Finish{
908 Reason: message.FinishReasonEndTurn,
909 Time: time.Now().Unix(),
910 },
911 },
912 Model: a.summarizeProvider.Model().ID,
913 Provider: a.summarizeProviderID,
914 })
915 if err != nil {
916 event = AgentEvent{
917 Type: AgentEventTypeError,
918 Error: fmt.Errorf("failed to create summary message: %w", err),
919 Done: true,
920 }
921
922 a.Publish(pubsub.CreatedEvent, event)
923 return
924 }
925 oldSession.SummaryMessageID = msg.ID
926 oldSession.CompletionTokens = finalResponse.Usage.OutputTokens
927 oldSession.PromptTokens = 0
928 model := a.summarizeProvider.Model()
929 usage := finalResponse.Usage
930 cost := model.CostPer1MInCached/1e6*float64(usage.CacheCreationTokens) +
931 model.CostPer1MOutCached/1e6*float64(usage.CacheReadTokens) +
932 model.CostPer1MIn/1e6*float64(usage.InputTokens) +
933 model.CostPer1MOut/1e6*float64(usage.OutputTokens)
934 oldSession.Cost += cost
935 _, err = a.sessions.Save(summarizeCtx, oldSession)
936 if err != nil {
937 event = AgentEvent{
938 Type: AgentEventTypeError,
939 Error: fmt.Errorf("failed to save session: %w", err),
940 Done: true,
941 }
942 a.Publish(pubsub.CreatedEvent, event)
943 }
944
945 event = AgentEvent{
946 Type: AgentEventTypeSummarize,
947 SessionID: oldSession.ID,
948 Progress: "Summary complete",
949 Done: true,
950 }
951 a.Publish(pubsub.CreatedEvent, event)
952 // Send final success event with the new session ID
953 }()
954
955 return nil
956}
957
958func (a *agent) ClearQueue(sessionID string) {
959 if a.QueuedPrompts(sessionID) > 0 {
960 slog.Info("Clearing queued prompts", "session_id", sessionID)
961 a.promptQueue.Del(sessionID)
962 }
963}
964
965func (a *agent) CancelAll() {
966 if !a.IsBusy() {
967 return
968 }
969 for key := range a.activeRequests.Seq2() {
970 a.Cancel(key) // key is sessionID
971 }
972
973 timeout := time.After(5 * time.Second)
974 for a.IsBusy() {
975 select {
976 case <-timeout:
977 return
978 default:
979 time.Sleep(200 * time.Millisecond)
980 }
981 }
982}
983
984func (a *agent) UpdateModel() error {
985 cfg := config.Get()
986
987 // Get current provider configuration
988 currentProviderCfg := cfg.GetProviderForModel(a.agentCfg.Model)
989 if currentProviderCfg == nil || currentProviderCfg.ID == "" {
990 return fmt.Errorf("provider for agent %s not found in config", a.agentCfg.Name)
991 }
992
993 // Check if provider has changed
994 if string(currentProviderCfg.ID) != a.providerID {
995 // Provider changed, need to recreate the main provider
996 model := cfg.GetModelByType(a.agentCfg.Model)
997 if model.ID == "" {
998 return fmt.Errorf("model not found for agent %s", a.agentCfg.Name)
999 }
1000
1001 promptID := agentPromptMap[a.agentCfg.ID]
1002 if promptID == "" {
1003 promptID = prompt.PromptDefault
1004 }
1005
1006 opts := []provider.ProviderClientOption{
1007 provider.WithModel(a.agentCfg.Model),
1008 provider.WithSystemMessage(prompt.GetPrompt(promptID, currentProviderCfg.ID, cfg.Options.ContextPaths...)),
1009 }
1010
1011 newProvider, err := provider.NewProvider(*currentProviderCfg, opts...)
1012 if err != nil {
1013 return fmt.Errorf("failed to create new provider: %w", err)
1014 }
1015
1016 // Update the provider and provider ID
1017 a.provider = newProvider
1018 a.providerID = string(currentProviderCfg.ID)
1019 }
1020
1021 // Check if providers have changed for title (small) and summarize (large)
1022 smallModelCfg := cfg.Models[config.SelectedModelTypeSmall]
1023 var smallModelProviderCfg config.ProviderConfig
1024 for p := range cfg.Providers.Seq() {
1025 if p.ID == smallModelCfg.Provider {
1026 smallModelProviderCfg = p
1027 break
1028 }
1029 }
1030 if smallModelProviderCfg.ID == "" {
1031 return fmt.Errorf("provider %s not found in config", smallModelCfg.Provider)
1032 }
1033
1034 largeModelCfg := cfg.Models[config.SelectedModelTypeLarge]
1035 var largeModelProviderCfg config.ProviderConfig
1036 for p := range cfg.Providers.Seq() {
1037 if p.ID == largeModelCfg.Provider {
1038 largeModelProviderCfg = p
1039 break
1040 }
1041 }
1042 if largeModelProviderCfg.ID == "" {
1043 return fmt.Errorf("provider %s not found in config", largeModelCfg.Provider)
1044 }
1045
1046 var maxTitleTokens int64 = 40
1047
1048 // if the max output is too low for the gemini provider it won't return anything
1049 if smallModelCfg.Provider == "gemini" {
1050 maxTitleTokens = 1000
1051 }
1052 // Recreate title provider
1053 titleOpts := []provider.ProviderClientOption{
1054 provider.WithModel(config.SelectedModelTypeSmall),
1055 provider.WithSystemMessage(prompt.GetPrompt(prompt.PromptTitle, smallModelProviderCfg.ID)),
1056 provider.WithMaxTokens(maxTitleTokens),
1057 }
1058 newTitleProvider, err := provider.NewProvider(smallModelProviderCfg, titleOpts...)
1059 if err != nil {
1060 return fmt.Errorf("failed to create new title provider: %w", err)
1061 }
1062 a.titleProvider = newTitleProvider
1063
1064 // Recreate summarize provider if provider changed (now large model)
1065 if string(largeModelProviderCfg.ID) != a.summarizeProviderID {
1066 largeModel := cfg.GetModelByType(config.SelectedModelTypeLarge)
1067 if largeModel == nil {
1068 return fmt.Errorf("model %s not found in provider %s", largeModelCfg.Model, largeModelProviderCfg.ID)
1069 }
1070 summarizeOpts := []provider.ProviderClientOption{
1071 provider.WithModel(config.SelectedModelTypeLarge),
1072 provider.WithSystemMessage(prompt.GetPrompt(prompt.PromptSummarizer, largeModelProviderCfg.ID)),
1073 }
1074 newSummarizeProvider, err := provider.NewProvider(largeModelProviderCfg, summarizeOpts...)
1075 if err != nil {
1076 return fmt.Errorf("failed to create new summarize provider: %w", err)
1077 }
1078 a.summarizeProvider = newSummarizeProvider
1079 a.summarizeProviderID = string(largeModelProviderCfg.ID)
1080 }
1081
1082 return nil
1083}