1package agent
2
3import (
4 "context"
5 "errors"
6 "fmt"
7 "log/slog"
8 "slices"
9 "strings"
10 "time"
11
12 "github.com/charmbracelet/catwalk/pkg/catwalk"
13 "github.com/charmbracelet/crush/internal/config"
14 "github.com/charmbracelet/crush/internal/csync"
15 "github.com/charmbracelet/crush/internal/event"
16 "github.com/charmbracelet/crush/internal/history"
17 "github.com/charmbracelet/crush/internal/llm/prompt"
18 "github.com/charmbracelet/crush/internal/llm/provider"
19 "github.com/charmbracelet/crush/internal/llm/tools"
20 "github.com/charmbracelet/crush/internal/log"
21 "github.com/charmbracelet/crush/internal/lsp"
22 "github.com/charmbracelet/crush/internal/message"
23 "github.com/charmbracelet/crush/internal/permission"
24 "github.com/charmbracelet/crush/internal/pubsub"
25 "github.com/charmbracelet/crush/internal/session"
26 "github.com/charmbracelet/crush/internal/shell"
27)
28
// streamChunkTimeout is the longest the agent waits for the next event on a
// provider stream before it gives up and fails the request with a
// "stream chunk timeout" error.
const streamChunkTimeout = 80 * time.Second
30
// AgentEventType identifies the kind of payload carried by an AgentEvent.
type AgentEventType string

const (
	// AgentEventTypeError marks an event whose Error field is set.
	AgentEventTypeError AgentEventType = "error"
	// AgentEventTypeResponse marks an event carrying the assistant message
	// that ended a run.
	AgentEventTypeResponse AgentEventType = "response"
	// AgentEventTypeSummarize marks a progress or completion event published
	// while a session is being summarized.
	AgentEventTypeSummarize AgentEventType = "summarize"
)
38
// AgentEvent is the value the agent publishes through its broker and sends on
// the channel returned by Run.
type AgentEvent struct {
	// Type tells consumers which of the remaining fields are meaningful.
	Type AgentEventType
	// Message is the assistant message for response events.
	Message message.Message
	// Error is the failure reported by error events.
	Error error

	// The fields below are populated while summarizing. Done is also set on
	// the final response event of a run and on summarization error events.
	SessionID string
	Progress  string
	Done      bool
}
49
// Service is the public surface of an agent. The notes on each method describe
// what the agent implementation in this file does.
type Service interface {
	pubsub.Suscriber[AgentEvent]
	// Model returns the model currently configured for the agent.
	Model() catwalk.Model
	// Run starts processing a prompt for the session and returns a channel
	// that receives the final event. If the session is already busy the
	// prompt is queued instead and a nil channel with a nil error is returned.
	Run(ctx context.Context, sessionID string, content string, attachments ...message.Attachment) (<-chan AgentEvent, error)
	// Cancel aborts any in-flight request or summarization for the session
	// and drops its queued prompts.
	Cancel(sessionID string)
	// CancelAll cancels every active request and waits a bounded time for
	// them to wind down.
	CancelAll()
	// IsSessionBusy reports whether the session has an active request.
	IsSessionBusy(sessionID string) bool
	// IsBusy reports whether any request is active.
	IsBusy() bool
	// Summarize starts an asynchronous summarization of the session; progress
	// and the outcome are delivered as published events.
	Summarize(ctx context.Context, sessionID string) error
	// UpdateModel rebuilds the providers after the configuration changed.
	UpdateModel() error
	// QueuedPrompts returns how many prompts are waiting for the session.
	QueuedPrompts(sessionID string) int
	// ClearQueue discards the prompts waiting for the session.
	ClearQueue(sessionID string)
}
63
// agent is the Service implementation. It owns the providers used for
// generation, title creation and summarization, plus the bookkeeping for
// in-flight requests and queued prompts.
type agent struct {
	*pubsub.Broker[AgentEvent]
	agentCfg    config.Agent
	sessions    session.Service
	messages    message.Service
	permissions permission.Service
	// NOTE(review): mcpTools is not assigned anywhere in the visible code;
	// confirm whether it is still needed.
	mcpTools []McpTool

	// tools is built lazily on first use by the function passed to
	// csync.NewLazySlice in NewAgent.
	tools *csync.LazySlice[tools.BaseTool]
	// agentToolFn builds the sub-agent tool on demand so that it can pick up
	// the current configuration when the model changes. It is nil when the
	// agent is not allowed to delegate.
	agentToolFn func() (tools.BaseTool, error)

	// provider serves the main generation requests; providerID is recorded on
	// the messages it produces.
	provider   provider.Provider
	providerID string

	titleProvider       provider.Provider
	summarizeProvider   provider.Provider
	summarizeProviderID string

	// activeRequests maps a session ID (or sessionID+"-summarize" for a
	// summarization) to the function that cancels the running request.
	activeRequests *csync.Map[string, context.CancelFunc]
	// promptQueue holds prompts submitted while their session was busy.
	promptQueue *csync.Map[string, []string]
}
86
// agentPromptMap selects the system prompt for an agent by its config ID.
// IDs missing from the map fall back to prompt.PromptDefault at the lookup
// sites.
var agentPromptMap = map[string]prompt.PromptID{
	"coder": prompt.PromptCoder,
	"task":  prompt.PromptTask,
}
91
92func NewAgent(
93 ctx context.Context,
94 agentCfg config.Agent,
95 // These services are needed in the tools
96 permissions permission.Service,
97 sessions session.Service,
98 messages message.Service,
99 history history.Service,
100 lspClients *csync.Map[string, *lsp.Client],
101) (Service, error) {
102 cfg := config.Get()
103
104 var agentToolFn func() (tools.BaseTool, error)
105 if agentCfg.ID == "coder" && slices.Contains(agentCfg.AllowedTools, AgentToolName) {
106 agentToolFn = func() (tools.BaseTool, error) {
107 taskAgentCfg := config.Get().Agents["task"]
108 if taskAgentCfg.ID == "" {
109 return nil, fmt.Errorf("task agent not found in config")
110 }
111 taskAgent, err := NewAgent(ctx, taskAgentCfg, permissions, sessions, messages, history, lspClients)
112 if err != nil {
113 return nil, fmt.Errorf("failed to create task agent: %w", err)
114 }
115 return NewAgentTool(taskAgent, sessions, messages), nil
116 }
117 }
118
119 providerCfg := config.Get().GetProviderForModel(agentCfg.Model)
120 if providerCfg == nil {
121 return nil, fmt.Errorf("provider for agent %s not found in config", agentCfg.Name)
122 }
123 model := config.Get().GetModelByType(agentCfg.Model)
124
125 if model == nil {
126 return nil, fmt.Errorf("model not found for agent %s", agentCfg.Name)
127 }
128
129 promptID := agentPromptMap[agentCfg.ID]
130 if promptID == "" {
131 promptID = prompt.PromptDefault
132 }
133 opts := []provider.ProviderClientOption{
134 provider.WithModel(agentCfg.Model),
135 provider.WithSystemMessage(prompt.GetPrompt(promptID, providerCfg.ID, config.Get().Options.ContextPaths...)),
136 }
137 agentProvider, err := provider.NewProvider(*providerCfg, opts...)
138 if err != nil {
139 return nil, err
140 }
141
142 smallModelCfg := cfg.Models[config.SelectedModelTypeSmall]
143 var smallModelProviderCfg *config.ProviderConfig
144 if smallModelCfg.Provider == providerCfg.ID {
145 smallModelProviderCfg = providerCfg
146 } else {
147 smallModelProviderCfg = cfg.GetProviderForModel(config.SelectedModelTypeSmall)
148
149 if smallModelProviderCfg.ID == "" {
150 return nil, fmt.Errorf("provider %s not found in config", smallModelCfg.Provider)
151 }
152 }
153 smallModel := cfg.GetModelByType(config.SelectedModelTypeSmall)
154 if smallModel.ID == "" {
155 return nil, fmt.Errorf("model %s not found in provider %s", smallModelCfg.Model, smallModelProviderCfg.ID)
156 }
157
158 titleOpts := []provider.ProviderClientOption{
159 provider.WithModel(config.SelectedModelTypeSmall),
160 provider.WithSystemMessage(prompt.GetPrompt(prompt.PromptTitle, smallModelProviderCfg.ID)),
161 }
162 titleProvider, err := provider.NewProvider(*smallModelProviderCfg, titleOpts...)
163 if err != nil {
164 return nil, err
165 }
166
167 summarizeOpts := []provider.ProviderClientOption{
168 provider.WithModel(config.SelectedModelTypeLarge),
169 provider.WithSystemMessage(prompt.GetPrompt(prompt.PromptSummarizer, providerCfg.ID)),
170 }
171 summarizeProvider, err := provider.NewProvider(*providerCfg, summarizeOpts...)
172 if err != nil {
173 return nil, err
174 }
175
176 toolFn := func() []tools.BaseTool {
177 slog.Info("Initializing agent tools", "agent", agentCfg.ID)
178 defer func() {
179 slog.Info("Initialized agent tools", "agent", agentCfg.ID)
180 }()
181
182 cwd := cfg.WorkingDir()
183 allTools := []tools.BaseTool{
184 tools.NewBashTool(permissions, cwd, cfg.Options.Attribution),
185 tools.NewDownloadTool(permissions, cwd),
186 tools.NewEditTool(lspClients, permissions, history, cwd),
187 tools.NewMultiEditTool(lspClients, permissions, history, cwd),
188 tools.NewFetchTool(permissions, cwd),
189 tools.NewGlobTool(cwd),
190 tools.NewGrepTool(cwd),
191 tools.NewLsTool(permissions, cwd),
192 tools.NewSourcegraphTool(),
193 tools.NewViewTool(lspClients, permissions, cwd),
194 tools.NewWriteTool(lspClients, permissions, history, cwd),
195 }
196
197 mcpToolsOnce.Do(func() {
198 mcpTools = doGetMCPTools(ctx, permissions, cfg)
199 })
200
201 withCoderTools := func(t []tools.BaseTool) []tools.BaseTool {
202 if agentCfg.ID == "coder" {
203 t = append(t, mcpTools...)
204 if lspClients.Len() > 0 {
205 t = append(t, tools.NewDiagnosticsTool(lspClients))
206 }
207 }
208 return t
209 }
210
211 if agentCfg.AllowedTools == nil {
212 return withCoderTools(allTools)
213 }
214
215 var filteredTools []tools.BaseTool
216 for _, tool := range allTools {
217 if slices.Contains(agentCfg.AllowedTools, tool.Name()) {
218 filteredTools = append(filteredTools, tool)
219 }
220 }
221 return withCoderTools(filteredTools)
222 }
223
224 return &agent{
225 Broker: pubsub.NewBroker[AgentEvent](),
226 agentCfg: agentCfg,
227 provider: agentProvider,
228 providerID: string(providerCfg.ID),
229 messages: messages,
230 sessions: sessions,
231 titleProvider: titleProvider,
232 summarizeProvider: summarizeProvider,
233 summarizeProviderID: string(providerCfg.ID),
234 agentToolFn: agentToolFn,
235 activeRequests: csync.NewMap[string, context.CancelFunc](),
236 tools: csync.NewLazySlice(toolFn),
237 promptQueue: csync.NewMap[string, []string](),
238 permissions: permissions,
239 }, nil
240}
241
242func (a *agent) Model() catwalk.Model {
243 return *config.Get().GetModelByType(a.agentCfg.Model)
244}
245
246func (a *agent) Cancel(sessionID string) {
247 // Cancel regular requests
248 if cancel, ok := a.activeRequests.Take(sessionID); ok && cancel != nil {
249 slog.Info("Request cancellation initiated", "session_id", sessionID)
250 cancel()
251 }
252
253 // Also check for summarize requests
254 if cancel, ok := a.activeRequests.Take(sessionID + "-summarize"); ok && cancel != nil {
255 slog.Info("Summarize cancellation initiated", "session_id", sessionID)
256 cancel()
257 }
258
259 if a.QueuedPrompts(sessionID) > 0 {
260 slog.Info("Clearing queued prompts", "session_id", sessionID)
261 a.promptQueue.Del(sessionID)
262 }
263}
264
265func (a *agent) IsBusy() bool {
266 var busy bool
267 for cancelFunc := range a.activeRequests.Seq() {
268 if cancelFunc != nil {
269 busy = true
270 break
271 }
272 }
273 return busy
274}
275
276func (a *agent) IsSessionBusy(sessionID string) bool {
277 _, busy := a.activeRequests.Get(sessionID)
278 return busy
279}
280
281func (a *agent) QueuedPrompts(sessionID string) int {
282 l, ok := a.promptQueue.Get(sessionID)
283 if !ok {
284 return 0
285 }
286 return len(l)
287}
288
289func (a *agent) generateTitle(ctx context.Context, sessionID string, content string) error {
290 if content == "" {
291 return nil
292 }
293 if a.titleProvider == nil {
294 return nil
295 }
296 session, err := a.sessions.Get(ctx, sessionID)
297 if err != nil {
298 return err
299 }
300 parts := []message.ContentPart{message.TextContent{
301 Text: fmt.Sprintf("Generate a concise title for the following content:\n\n%s", content),
302 }}
303
304 // Use streaming approach like summarization
305 response := a.titleProvider.StreamResponse(
306 ctx,
307 []message.Message{
308 {
309 Role: message.User,
310 Parts: parts,
311 },
312 },
313 nil,
314 )
315
316 var finalResponse *provider.ProviderResponse
317 for r := range response {
318 if r.Error != nil {
319 return r.Error
320 }
321 finalResponse = r.Response
322 }
323
324 if finalResponse == nil {
325 return fmt.Errorf("no response received from title provider")
326 }
327
328 title := strings.ReplaceAll(finalResponse.Content, "\n", " ")
329
330 if idx := strings.Index(title, "</think>"); idx > 0 {
331 title = title[idx+len("</think>"):]
332 }
333
334 title = strings.TrimSpace(title)
335 if title == "" {
336 return nil
337 }
338
339 session.Title = title
340 _, err = a.sessions.Save(ctx, session)
341 return err
342}
343
344func (a *agent) err(err error) AgentEvent {
345 return AgentEvent{
346 Type: AgentEventTypeError,
347 Error: err,
348 }
349}
350
351func (a *agent) Run(ctx context.Context, sessionID string, content string, attachments ...message.Attachment) (<-chan AgentEvent, error) {
352 if !a.Model().SupportsImages && attachments != nil {
353 attachments = nil
354 }
355 events := make(chan AgentEvent, 1)
356 if a.IsSessionBusy(sessionID) {
357 existing, ok := a.promptQueue.Get(sessionID)
358 if !ok {
359 existing = []string{}
360 }
361 existing = append(existing, content)
362 a.promptQueue.Set(sessionID, existing)
363 return nil, nil
364 }
365
366 genCtx, cancel := context.WithCancel(ctx)
367 a.activeRequests.Set(sessionID, cancel)
368 startTime := time.Now()
369
370 go func() {
371 slog.Debug("Request started", "sessionID", sessionID)
372 defer log.RecoverPanic("agent.Run", func() {
373 events <- a.err(fmt.Errorf("panic while running the agent"))
374 })
375 var attachmentParts []message.ContentPart
376 for _, attachment := range attachments {
377 attachmentParts = append(attachmentParts, message.BinaryContent{Path: attachment.FilePath, MIMEType: attachment.MimeType, Data: attachment.Content})
378 }
379 result := a.processGeneration(genCtx, sessionID, content, attachmentParts)
380 if result.Error != nil {
381 if isCancelledErr(result.Error) {
382 slog.Error("Request canceled", "sessionID", sessionID)
383 } else {
384 slog.Error("Request errored", "sessionID", sessionID, "error", result.Error.Error())
385 event.Error(result.Error)
386 }
387 } else {
388 slog.Debug("Request completed", "sessionID", sessionID)
389 }
390 a.eventPromptResponded(sessionID, time.Since(startTime).Truncate(time.Second))
391 a.activeRequests.Del(sessionID)
392 cancel()
393 a.Publish(pubsub.CreatedEvent, result)
394 events <- result
395 close(events)
396 }()
397 a.eventPromptSent(sessionID)
398 return events, nil
399}
400
401func (a *agent) processGeneration(ctx context.Context, sessionID, content string, attachmentParts []message.ContentPart) AgentEvent {
402 cfg := config.Get()
403 // List existing messages; if none, start title generation asynchronously.
404 msgs, err := a.messages.List(ctx, sessionID)
405 if err != nil {
406 return a.err(fmt.Errorf("failed to list messages: %w", err))
407 }
408 if len(msgs) == 0 {
409 go func() {
410 defer log.RecoverPanic("agent.Run", func() {
411 slog.Error("panic while generating title")
412 })
413 titleErr := a.generateTitle(ctx, sessionID, content)
414 if titleErr != nil && !errors.Is(titleErr, context.Canceled) && !errors.Is(titleErr, context.DeadlineExceeded) {
415 slog.Error("failed to generate title", "error", titleErr)
416 }
417 }()
418 }
419 session, err := a.sessions.Get(ctx, sessionID)
420 if err != nil {
421 return a.err(fmt.Errorf("failed to get session: %w", err))
422 }
423 if session.SummaryMessageID != "" {
424 summaryMsgInex := -1
425 for i, msg := range msgs {
426 if msg.ID == session.SummaryMessageID {
427 summaryMsgInex = i
428 break
429 }
430 }
431 if summaryMsgInex != -1 {
432 msgs = msgs[summaryMsgInex:]
433 msgs[0].Role = message.User
434 }
435 }
436
437 userMsg, err := a.createUserMessage(ctx, sessionID, content, attachmentParts)
438 if err != nil {
439 return a.err(fmt.Errorf("failed to create user message: %w", err))
440 }
441 // Append the new user message to the conversation history.
442 msgHistory := append(msgs, userMsg)
443
444 for {
445 // Check for cancellation before each iteration
446 select {
447 case <-ctx.Done():
448 return a.err(ctx.Err())
449 default:
450 // Continue processing
451 }
452 agentMessage, toolResults, err := a.streamAndHandleEvents(ctx, sessionID, msgHistory)
453 if err != nil {
454 if errors.Is(err, context.Canceled) {
455 agentMessage.AddFinish(message.FinishReasonCanceled, "Request cancelled", "")
456 a.messages.Update(context.Background(), agentMessage)
457 return a.err(ErrRequestCancelled)
458 }
459 return a.err(fmt.Errorf("failed to process events: %w", err))
460 }
461 if cfg.Options.Debug {
462 slog.Info("Result", "message", agentMessage.FinishReason(), "toolResults", toolResults)
463 }
464 if (agentMessage.FinishReason() == message.FinishReasonToolUse) && toolResults != nil {
465 // We are not done, we need to respond with the tool response
466 msgHistory = append(msgHistory, agentMessage, *toolResults)
467 // If there are queued prompts, process the next one
468 nextPrompt, ok := a.promptQueue.Take(sessionID)
469 if ok {
470 for _, prompt := range nextPrompt {
471 // Create a new user message for the queued prompt
472 userMsg, err := a.createUserMessage(ctx, sessionID, prompt, nil)
473 if err != nil {
474 return a.err(fmt.Errorf("failed to create user message for queued prompt: %w", err))
475 }
476 // Append the new user message to the conversation history
477 msgHistory = append(msgHistory, userMsg)
478 }
479 }
480
481 continue
482 } else if agentMessage.FinishReason() == message.FinishReasonEndTurn {
483 queuePrompts, ok := a.promptQueue.Take(sessionID)
484 if ok {
485 for _, prompt := range queuePrompts {
486 if prompt == "" {
487 continue
488 }
489 userMsg, err := a.createUserMessage(ctx, sessionID, prompt, nil)
490 if err != nil {
491 return a.err(fmt.Errorf("failed to create user message for queued prompt: %w", err))
492 }
493 msgHistory = append(msgHistory, userMsg)
494 }
495 continue
496 }
497 }
498 if agentMessage.FinishReason() == "" {
499 // Kujtim: could not track down where this is happening but this means its cancelled
500 agentMessage.AddFinish(message.FinishReasonCanceled, "Request cancelled", "")
501 _ = a.messages.Update(context.Background(), agentMessage)
502 return a.err(ErrRequestCancelled)
503 }
504 return AgentEvent{
505 Type: AgentEventTypeResponse,
506 Message: agentMessage,
507 Done: true,
508 }
509 }
510}
511
512func (a *agent) createUserMessage(ctx context.Context, sessionID, content string, attachmentParts []message.ContentPart) (message.Message, error) {
513 parts := []message.ContentPart{message.TextContent{Text: content}}
514 parts = append(parts, attachmentParts...)
515 return a.messages.Create(ctx, sessionID, message.CreateMessageParams{
516 Role: message.User,
517 Parts: parts,
518 })
519}
520
521func (a *agent) getAllTools() ([]tools.BaseTool, error) {
522 allTools := slices.Collect(a.tools.Seq())
523 if a.agentToolFn != nil {
524 agentTool, agentToolErr := a.agentToolFn()
525 if agentToolErr != nil {
526 return nil, agentToolErr
527 }
528 allTools = append(allTools, agentTool)
529 }
530 return allTools, nil
531}
532
533func (a *agent) streamAndHandleEvents(ctx context.Context, sessionID string, msgHistory []message.Message) (message.Message, *message.Message, error) {
534 ctx = context.WithValue(ctx, tools.SessionIDContextKey, sessionID)
535
536 // Create the assistant message first so the spinner shows immediately
537 assistantMsg, err := a.messages.Create(ctx, sessionID, message.CreateMessageParams{
538 Role: message.Assistant,
539 Parts: []message.ContentPart{},
540 Model: a.Model().ID,
541 Provider: a.providerID,
542 })
543 if err != nil {
544 return assistantMsg, nil, fmt.Errorf("failed to create assistant message: %w", err)
545 }
546
547 allTools, toolsErr := a.getAllTools()
548 if toolsErr != nil {
549 return assistantMsg, nil, toolsErr
550 }
551 // Now collect tools (which may block on MCP initialization)
552 eventChan := a.provider.StreamResponse(ctx, msgHistory, allTools)
553
554 // Add the session and message ID into the context if needed by tools.
555 ctx = context.WithValue(ctx, tools.MessageIDContextKey, assistantMsg.ID)
556
557 // Process each event in the stream.
558loop:
559 for {
560 select {
561 case event, ok := <-eventChan:
562 if !ok {
563 break loop
564 }
565 if processErr := a.processEvent(ctx, sessionID, &assistantMsg, event); processErr != nil {
566 if errors.Is(processErr, context.Canceled) {
567 a.finishMessage(context.Background(), &assistantMsg, message.FinishReasonCanceled, "Request cancelled", "")
568 } else {
569 a.finishMessage(ctx, &assistantMsg, message.FinishReasonError, "API Error", processErr.Error())
570 }
571 return assistantMsg, nil, processErr
572 }
573 case <-time.After(streamChunkTimeout):
574 a.finishMessage(ctx, &assistantMsg, message.FinishReasonError, "Stream timeout", "No chunk received within timeout")
575 return assistantMsg, nil, fmt.Errorf("stream chunk timeout")
576 case <-ctx.Done():
577 a.finishMessage(context.Background(), &assistantMsg, message.FinishReasonCanceled, "Request cancelled", "")
578 return assistantMsg, nil, ctx.Err()
579 }
580 }
581
582 toolResults := make([]message.ToolResult, len(assistantMsg.ToolCalls()))
583 toolCalls := assistantMsg.ToolCalls()
584 for i, toolCall := range toolCalls {
585 select {
586 case <-ctx.Done():
587 a.finishMessage(context.Background(), &assistantMsg, message.FinishReasonCanceled, "Request cancelled", "")
588 // Make all future tool calls cancelled
589 for j := i; j < len(toolCalls); j++ {
590 toolResults[j] = message.ToolResult{
591 ToolCallID: toolCalls[j].ID,
592 Content: "Tool execution canceled by user",
593 IsError: true,
594 }
595 }
596 goto out
597 default:
598 // Continue processing
599 var tool tools.BaseTool
600 allTools, _ := a.getAllTools()
601 for _, availableTool := range allTools {
602 if availableTool.Info().Name == toolCall.Name {
603 tool = availableTool
604 break
605 }
606 }
607
608 // Tool not found
609 if tool == nil {
610 toolResults[i] = message.ToolResult{
611 ToolCallID: toolCall.ID,
612 Content: fmt.Sprintf("Tool not found: %s", toolCall.Name),
613 IsError: true,
614 }
615 continue
616 }
617
618 // Run tool in goroutine to allow cancellation
619 type toolExecResult struct {
620 response tools.ToolResponse
621 err error
622 }
623 resultChan := make(chan toolExecResult, 1)
624
625 go func() {
626 response, err := tool.Run(ctx, tools.ToolCall{
627 ID: toolCall.ID,
628 Name: toolCall.Name,
629 Input: toolCall.Input,
630 })
631 resultChan <- toolExecResult{response: response, err: err}
632 }()
633
634 var toolResponse tools.ToolResponse
635 var toolErr error
636
637 select {
638 case <-ctx.Done():
639 a.finishMessage(context.Background(), &assistantMsg, message.FinishReasonCanceled, "Request cancelled", "")
640 // Mark remaining tool calls as cancelled
641 for j := i; j < len(toolCalls); j++ {
642 toolResults[j] = message.ToolResult{
643 ToolCallID: toolCalls[j].ID,
644 Content: "Tool execution canceled by user",
645 IsError: true,
646 }
647 }
648 goto out
649 case result := <-resultChan:
650 toolResponse = result.response
651 toolErr = result.err
652 }
653
654 if toolErr != nil {
655 slog.Error("Tool execution error", "toolCall", toolCall.ID, "error", toolErr)
656 if errors.Is(toolErr, permission.ErrorPermissionDenied) {
657 toolResults[i] = message.ToolResult{
658 ToolCallID: toolCall.ID,
659 Content: "Permission denied",
660 IsError: true,
661 }
662 for j := i + 1; j < len(toolCalls); j++ {
663 toolResults[j] = message.ToolResult{
664 ToolCallID: toolCalls[j].ID,
665 Content: "Tool execution canceled by user",
666 IsError: true,
667 }
668 }
669 a.finishMessage(ctx, &assistantMsg, message.FinishReasonPermissionDenied, "Permission denied", "")
670 break
671 }
672 }
673 toolResults[i] = message.ToolResult{
674 ToolCallID: toolCall.ID,
675 Content: toolResponse.Content,
676 Metadata: toolResponse.Metadata,
677 IsError: toolResponse.IsError,
678 }
679 }
680 }
681out:
682 if len(toolResults) == 0 {
683 return assistantMsg, nil, nil
684 }
685 parts := make([]message.ContentPart, 0)
686 for _, tr := range toolResults {
687 parts = append(parts, tr)
688 }
689 msg, err := a.messages.Create(context.Background(), assistantMsg.SessionID, message.CreateMessageParams{
690 Role: message.Tool,
691 Parts: parts,
692 Provider: a.providerID,
693 })
694 if err != nil {
695 return assistantMsg, nil, fmt.Errorf("failed to create cancelled tool message: %w", err)
696 }
697
698 return assistantMsg, &msg, err
699}
700
701func (a *agent) finishMessage(ctx context.Context, msg *message.Message, finishReason message.FinishReason, message, details string) {
702 msg.AddFinish(finishReason, message, details)
703 _ = a.messages.Update(ctx, *msg)
704}
705
706func (a *agent) processEvent(ctx context.Context, sessionID string, assistantMsg *message.Message, event provider.ProviderEvent) error {
707 select {
708 case <-ctx.Done():
709 return ctx.Err()
710 default:
711 // Continue processing.
712 }
713
714 switch event.Type {
715 case provider.EventThinkingDelta:
716 assistantMsg.AppendReasoningContent(event.Thinking)
717 return a.messages.Update(ctx, *assistantMsg)
718 case provider.EventSignatureDelta:
719 assistantMsg.AppendReasoningSignature(event.Signature)
720 return a.messages.Update(ctx, *assistantMsg)
721 case provider.EventContentDelta:
722 assistantMsg.FinishThinking()
723 assistantMsg.AppendContent(event.Content)
724 return a.messages.Update(ctx, *assistantMsg)
725 case provider.EventToolUseStart:
726 assistantMsg.FinishThinking()
727 slog.Info("Tool call started", "toolCall", event.ToolCall)
728 assistantMsg.AddToolCall(*event.ToolCall)
729 return a.messages.Update(ctx, *assistantMsg)
730 case provider.EventToolUseDelta:
731 assistantMsg.AppendToolCallInput(event.ToolCall.ID, event.ToolCall.Input)
732 return a.messages.Update(ctx, *assistantMsg)
733 case provider.EventToolUseStop:
734 slog.Info("Finished tool call", "toolCall", event.ToolCall)
735 assistantMsg.FinishToolCall(event.ToolCall.ID)
736 return a.messages.Update(ctx, *assistantMsg)
737 case provider.EventError:
738 return event.Error
739 case provider.EventComplete:
740 assistantMsg.FinishThinking()
741 assistantMsg.SetToolCalls(event.Response.ToolCalls)
742 assistantMsg.AddFinish(event.Response.FinishReason, "", "")
743 if err := a.messages.Update(ctx, *assistantMsg); err != nil {
744 return fmt.Errorf("failed to update message: %w", err)
745 }
746 return a.trackUsage(ctx, sessionID, a.Model(), event.Response.Usage)
747 }
748
749 return nil
750}
751
752func (a *agent) trackUsage(ctx context.Context, sessionID string, model catwalk.Model, usage provider.TokenUsage) error {
753 sess, err := a.sessions.Get(ctx, sessionID)
754 if err != nil {
755 return fmt.Errorf("failed to get session: %w", err)
756 }
757
758 cost := model.CostPer1MInCached/1e6*float64(usage.CacheCreationTokens) +
759 model.CostPer1MOutCached/1e6*float64(usage.CacheReadTokens) +
760 model.CostPer1MIn/1e6*float64(usage.InputTokens) +
761 model.CostPer1MOut/1e6*float64(usage.OutputTokens)
762
763 a.eventTokensUsed(sessionID, usage, cost)
764
765 sess.Cost += cost
766 sess.CompletionTokens = usage.OutputTokens + usage.CacheReadTokens
767 sess.PromptTokens = usage.InputTokens + usage.CacheCreationTokens
768
769 _, err = a.sessions.Save(ctx, sess)
770 if err != nil {
771 return fmt.Errorf("failed to save session: %w", err)
772 }
773 return nil
774}
775
776func (a *agent) Summarize(ctx context.Context, sessionID string) error {
777 if a.summarizeProvider == nil {
778 return fmt.Errorf("summarize provider not available")
779 }
780
781 // Check if session is busy
782 if a.IsSessionBusy(sessionID) {
783 return ErrSessionBusy
784 }
785
786 // Create a new context with cancellation
787 summarizeCtx, cancel := context.WithCancel(ctx)
788
789 // Store the cancel function in activeRequests to allow cancellation
790 a.activeRequests.Set(sessionID+"-summarize", cancel)
791
792 go func() {
793 defer a.activeRequests.Del(sessionID + "-summarize")
794 defer cancel()
795 event := AgentEvent{
796 Type: AgentEventTypeSummarize,
797 Progress: "Starting summarization...",
798 }
799
800 a.Publish(pubsub.CreatedEvent, event)
801 // Get all messages from the session
802 msgs, err := a.messages.List(summarizeCtx, sessionID)
803 if err != nil {
804 event = AgentEvent{
805 Type: AgentEventTypeError,
806 Error: fmt.Errorf("failed to list messages: %w", err),
807 Done: true,
808 }
809 a.Publish(pubsub.CreatedEvent, event)
810 return
811 }
812 summarizeCtx = context.WithValue(summarizeCtx, tools.SessionIDContextKey, sessionID)
813
814 if len(msgs) == 0 {
815 event = AgentEvent{
816 Type: AgentEventTypeError,
817 Error: fmt.Errorf("no messages to summarize"),
818 Done: true,
819 }
820 a.Publish(pubsub.CreatedEvent, event)
821 return
822 }
823
824 event = AgentEvent{
825 Type: AgentEventTypeSummarize,
826 Progress: "Analyzing conversation...",
827 }
828 a.Publish(pubsub.CreatedEvent, event)
829
830 // Add a system message to guide the summarization
831 summarizePrompt := "Provide a detailed but concise summary of our conversation above. Focus on information that would be helpful for continuing the conversation, including what we did, what we're doing, which files we're working on, and what we're going to do next."
832
833 // Create a new message with the summarize prompt
834 promptMsg := message.Message{
835 Role: message.User,
836 Parts: []message.ContentPart{message.TextContent{Text: summarizePrompt}},
837 }
838
839 // Append the prompt to the messages
840 msgsWithPrompt := append(msgs, promptMsg)
841
842 event = AgentEvent{
843 Type: AgentEventTypeSummarize,
844 Progress: "Generating summary...",
845 }
846
847 a.Publish(pubsub.CreatedEvent, event)
848
849 // Send the messages to the summarize provider
850 response := a.summarizeProvider.StreamResponse(
851 summarizeCtx,
852 msgsWithPrompt,
853 nil,
854 )
855 var finalResponse *provider.ProviderResponse
856 for r := range response {
857 if r.Error != nil {
858 event = AgentEvent{
859 Type: AgentEventTypeError,
860 Error: fmt.Errorf("failed to summarize: %w", r.Error),
861 Done: true,
862 }
863 a.Publish(pubsub.CreatedEvent, event)
864 return
865 }
866 finalResponse = r.Response
867 }
868
869 summary := strings.TrimSpace(finalResponse.Content)
870 if summary == "" {
871 event = AgentEvent{
872 Type: AgentEventTypeError,
873 Error: fmt.Errorf("empty summary returned"),
874 Done: true,
875 }
876 a.Publish(pubsub.CreatedEvent, event)
877 return
878 }
879 shell := shell.GetPersistentShell(config.Get().WorkingDir())
880 summary += "\n\n**Current working directory of the persistent shell**\n\n" + shell.GetWorkingDir()
881 event = AgentEvent{
882 Type: AgentEventTypeSummarize,
883 Progress: "Creating new session...",
884 }
885
886 a.Publish(pubsub.CreatedEvent, event)
887 oldSession, err := a.sessions.Get(summarizeCtx, sessionID)
888 if err != nil {
889 event = AgentEvent{
890 Type: AgentEventTypeError,
891 Error: fmt.Errorf("failed to get session: %w", err),
892 Done: true,
893 }
894
895 a.Publish(pubsub.CreatedEvent, event)
896 return
897 }
898 // Create a message in the new session with the summary
899 msg, err := a.messages.Create(summarizeCtx, oldSession.ID, message.CreateMessageParams{
900 Role: message.Assistant,
901 Parts: []message.ContentPart{
902 message.TextContent{Text: summary},
903 message.Finish{
904 Reason: message.FinishReasonEndTurn,
905 Time: time.Now().Unix(),
906 },
907 },
908 Model: a.summarizeProvider.Model().ID,
909 Provider: a.summarizeProviderID,
910 })
911 if err != nil {
912 event = AgentEvent{
913 Type: AgentEventTypeError,
914 Error: fmt.Errorf("failed to create summary message: %w", err),
915 Done: true,
916 }
917
918 a.Publish(pubsub.CreatedEvent, event)
919 return
920 }
921 oldSession.SummaryMessageID = msg.ID
922 oldSession.CompletionTokens = finalResponse.Usage.OutputTokens
923 oldSession.PromptTokens = 0
924 model := a.summarizeProvider.Model()
925 usage := finalResponse.Usage
926 cost := model.CostPer1MInCached/1e6*float64(usage.CacheCreationTokens) +
927 model.CostPer1MOutCached/1e6*float64(usage.CacheReadTokens) +
928 model.CostPer1MIn/1e6*float64(usage.InputTokens) +
929 model.CostPer1MOut/1e6*float64(usage.OutputTokens)
930 oldSession.Cost += cost
931 _, err = a.sessions.Save(summarizeCtx, oldSession)
932 if err != nil {
933 event = AgentEvent{
934 Type: AgentEventTypeError,
935 Error: fmt.Errorf("failed to save session: %w", err),
936 Done: true,
937 }
938 a.Publish(pubsub.CreatedEvent, event)
939 }
940
941 event = AgentEvent{
942 Type: AgentEventTypeSummarize,
943 SessionID: oldSession.ID,
944 Progress: "Summary complete",
945 Done: true,
946 }
947 a.Publish(pubsub.CreatedEvent, event)
948 // Send final success event with the new session ID
949 }()
950
951 return nil
952}
953
954func (a *agent) ClearQueue(sessionID string) {
955 if a.QueuedPrompts(sessionID) > 0 {
956 slog.Info("Clearing queued prompts", "session_id", sessionID)
957 a.promptQueue.Del(sessionID)
958 }
959}
960
961func (a *agent) CancelAll() {
962 if !a.IsBusy() {
963 return
964 }
965 for key := range a.activeRequests.Seq2() {
966 a.Cancel(key) // key is sessionID
967 }
968
969 timeout := time.After(5 * time.Second)
970 for a.IsBusy() {
971 select {
972 case <-timeout:
973 return
974 default:
975 time.Sleep(200 * time.Millisecond)
976 }
977 }
978}
979
980func (a *agent) UpdateModel() error {
981 cfg := config.Get()
982
983 // Get current provider configuration
984 currentProviderCfg := cfg.GetProviderForModel(a.agentCfg.Model)
985 if currentProviderCfg == nil || currentProviderCfg.ID == "" {
986 return fmt.Errorf("provider for agent %s not found in config", a.agentCfg.Name)
987 }
988
989 // Check if provider has changed
990 if string(currentProviderCfg.ID) != a.providerID {
991 // Provider changed, need to recreate the main provider
992 model := cfg.GetModelByType(a.agentCfg.Model)
993 if model.ID == "" {
994 return fmt.Errorf("model not found for agent %s", a.agentCfg.Name)
995 }
996
997 promptID := agentPromptMap[a.agentCfg.ID]
998 if promptID == "" {
999 promptID = prompt.PromptDefault
1000 }
1001
1002 opts := []provider.ProviderClientOption{
1003 provider.WithModel(a.agentCfg.Model),
1004 provider.WithSystemMessage(prompt.GetPrompt(promptID, currentProviderCfg.ID, cfg.Options.ContextPaths...)),
1005 }
1006
1007 newProvider, err := provider.NewProvider(*currentProviderCfg, opts...)
1008 if err != nil {
1009 return fmt.Errorf("failed to create new provider: %w", err)
1010 }
1011
1012 // Update the provider and provider ID
1013 a.provider = newProvider
1014 a.providerID = string(currentProviderCfg.ID)
1015 }
1016
1017 // Check if providers have changed for title (small) and summarize (large)
1018 smallModelCfg := cfg.Models[config.SelectedModelTypeSmall]
1019 var smallModelProviderCfg config.ProviderConfig
1020 for p := range cfg.Providers.Seq() {
1021 if p.ID == smallModelCfg.Provider {
1022 smallModelProviderCfg = p
1023 break
1024 }
1025 }
1026 if smallModelProviderCfg.ID == "" {
1027 return fmt.Errorf("provider %s not found in config", smallModelCfg.Provider)
1028 }
1029
1030 largeModelCfg := cfg.Models[config.SelectedModelTypeLarge]
1031 var largeModelProviderCfg config.ProviderConfig
1032 for p := range cfg.Providers.Seq() {
1033 if p.ID == largeModelCfg.Provider {
1034 largeModelProviderCfg = p
1035 break
1036 }
1037 }
1038 if largeModelProviderCfg.ID == "" {
1039 return fmt.Errorf("provider %s not found in config", largeModelCfg.Provider)
1040 }
1041
1042 var maxTitleTokens int64 = 40
1043
1044 // if the max output is too low for the gemini provider it won't return anything
1045 if smallModelCfg.Provider == "gemini" {
1046 maxTitleTokens = 1000
1047 }
1048 // Recreate title provider
1049 titleOpts := []provider.ProviderClientOption{
1050 provider.WithModel(config.SelectedModelTypeSmall),
1051 provider.WithSystemMessage(prompt.GetPrompt(prompt.PromptTitle, smallModelProviderCfg.ID)),
1052 provider.WithMaxTokens(maxTitleTokens),
1053 }
1054 newTitleProvider, err := provider.NewProvider(smallModelProviderCfg, titleOpts...)
1055 if err != nil {
1056 return fmt.Errorf("failed to create new title provider: %w", err)
1057 }
1058 a.titleProvider = newTitleProvider
1059
1060 // Recreate summarize provider if provider changed (now large model)
1061 if string(largeModelProviderCfg.ID) != a.summarizeProviderID {
1062 largeModel := cfg.GetModelByType(config.SelectedModelTypeLarge)
1063 if largeModel == nil {
1064 return fmt.Errorf("model %s not found in provider %s", largeModelCfg.Model, largeModelProviderCfg.ID)
1065 }
1066 summarizeOpts := []provider.ProviderClientOption{
1067 provider.WithModel(config.SelectedModelTypeLarge),
1068 provider.WithSystemMessage(prompt.GetPrompt(prompt.PromptSummarizer, largeModelProviderCfg.ID)),
1069 }
1070 newSummarizeProvider, err := provider.NewProvider(largeModelProviderCfg, summarizeOpts...)
1071 if err != nil {
1072 return fmt.Errorf("failed to create new summarize provider: %w", err)
1073 }
1074 a.summarizeProvider = newSummarizeProvider
1075 a.summarizeProviderID = string(largeModelProviderCfg.ID)
1076 }
1077
1078 return nil
1079}