1package agent
2
3import (
4 "context"
5 "errors"
6 "fmt"
7 "log/slog"
8 "slices"
9 "strings"
10 "time"
11
12 "github.com/charmbracelet/catwalk/pkg/catwalk"
13 "github.com/charmbracelet/crush/internal/config"
14 "github.com/charmbracelet/crush/internal/csync"
15 "github.com/charmbracelet/crush/internal/history"
16 "github.com/charmbracelet/crush/internal/llm/prompt"
17 "github.com/charmbracelet/crush/internal/llm/provider"
18 "github.com/charmbracelet/crush/internal/llm/tools"
19 "github.com/charmbracelet/crush/internal/log"
20 "github.com/charmbracelet/crush/internal/lsp"
21 "github.com/charmbracelet/crush/internal/message"
22 "github.com/charmbracelet/crush/internal/permission"
23 "github.com/charmbracelet/crush/internal/pubsub"
24 "github.com/charmbracelet/crush/internal/session"
25 "github.com/charmbracelet/crush/internal/shell"
26)
27
// ErrRequestCancelled reports that a request was canceled before it finished.
// It is delivered wrapped in an error AgentEvent.
var ErrRequestCancelled = errors.New("request canceled by user")

// ErrSessionBusy reports that the session already has a request in flight.
var ErrSessionBusy = errors.New("session is currently processing another request")
33
// AgentEventType identifies the kind of an AgentEvent. It is serialized as its
// plain string value.
type AgentEventType string

// Known agent event types.
const (
	AgentEventTypeError     AgentEventType = "error"
	AgentEventTypeResponse  AgentEventType = "response"
	AgentEventTypeSummarize AgentEventType = "summarize"
)

// MarshalText implements encoding.TextMarshaler by emitting the raw string
// value of the event type.
func (t AgentEventType) MarshalText() ([]byte, error) {
	text := []byte(string(t))
	return text, nil
}

// UnmarshalText implements encoding.TextUnmarshaler. The input is stored
// verbatim; it is not validated against the known constants.
func (t *AgentEventType) UnmarshalText(text []byte) error {
	decoded := string(text)
	*t = AgentEventType(decoded)
	return nil
}
50
// AgentEvent is emitted by the agent, both on the channel returned by Run and
// through the embedded pubsub broker.
type AgentEvent struct {
	// Type says which of the fields below are meaningful.
	Type AgentEventType `json:"type"`
	// Message is the final assistant message for response events.
	Message message.Message `json:"message"`
	// Error is set on error events.
	// NOTE(review): error is an interface; encoding/json will not emit the
	// error text for typical error values (they encode as {}) — confirm
	// whether any JSON consumer relies on this field.
	Error error `json:"error,omitempty"`

	// SessionID and Progress are used by summarize events. Done marks a
	// terminal event: it is set on the final response from Run, on terminal
	// error events, and on the final summarize event.
	SessionID string `json:"session_id,omitempty"`
	Progress  string `json:"progress,omitempty"`
	Done      bool   `json:"done,omitempty"`
}
61
// Service is the public interface of an agent. It runs prompts against the
// configured model, tracks per-session busy state and prompt queues, and
// publishes AgentEvents to subscribers.
type Service interface {
	pubsub.Suscriber[AgentEvent]
	// Model returns the model currently configured for the agent.
	Model() catwalk.Model
	// Run processes content for sessionID. In this package's implementation,
	// a prompt sent while the session is busy is queued and Run returns a
	// nil channel and nil error.
	Run(ctx context.Context, sessionID string, content string, attachments ...message.Attachment) (<-chan AgentEvent, error)
	// Cancel cancels the session's in-flight request and summarization and
	// drops its queued prompts.
	Cancel(sessionID string)
	// CancelAll cancels every in-flight request.
	CancelAll()
	// IsSessionBusy reports whether sessionID has a request in flight.
	IsSessionBusy(sessionID string) bool
	// IsBusy reports whether any request is in flight.
	IsBusy() bool
	// Summarize starts an asynchronous summarization of the session; progress
	// is reported through published AgentEvents.
	Summarize(ctx context.Context, sessionID string) error
	// UpdateModel refreshes the agent's providers from the current config.
	UpdateModel() error
	// QueuedPrompts returns how many prompts are waiting for sessionID.
	QueuedPrompts(sessionID string) int
	// ClearQueue drops all prompts waiting for sessionID.
	ClearQueue(sessionID string)
}
75
// agent is the Service implementation built by NewAgent.
type agent struct {
	*pubsub.Broker[AgentEvent]
	agentCfg config.Agent
	sessions session.Service
	messages message.Service
	// NOTE(review): NewAgent in this file never assigns this field; tool
	// setup appears to use a package-level mcpTools instead — confirm whether
	// the field is still needed.
	mcpTools []McpTool

	// tools holds the agent's base tools, built lazily on first access.
	tools *csync.LazySlice[tools.BaseTool]
	// agentToolFn builds the sub-agent tool on demand (so it follows model
	// changes); nil for agents that do not get one.
	agentToolFn func() (tools.BaseTool, error)

	// provider serves the agent's own model; providerID is its config ID.
	provider   provider.Provider
	providerID string

	// titleProvider generates session titles. summarizeProvider produces
	// conversation summaries; summarizeProviderID is its config ID.
	titleProvider       provider.Provider
	summarizeProvider   provider.Provider
	summarizeProviderID string

	// activeRequests maps a session ID (or sessionID+"-summarize" for a
	// summarization) to the cancel func of its in-flight request.
	activeRequests *csync.Map[string, context.CancelFunc]
	// promptQueue holds prompt texts submitted while a session was busy.
	promptQueue *csync.Map[string, []string]
}
97
// agentPromptMap selects the system prompt for an agent by its ID. Agents not
// listed here fall back to prompt.PromptDefault in NewAgent and UpdateModel.
var agentPromptMap = map[string]prompt.PromptID{
	"coder": prompt.PromptCoder,
	"task":  prompt.PromptTask,
}
102
103func NewAgent(
104 ctx context.Context,
105 agentCfg config.Agent,
106 // These services are needed in the tools
107 permissions permission.Service,
108 sessions session.Service,
109 messages message.Service,
110 history history.Service,
111 lspClients *csync.Map[string, *lsp.Client],
112) (Service, error) {
113 cfg := config.Get()
114
115 var agentToolFn func() (tools.BaseTool, error)
116 if agentCfg.ID == "coder" {
117 agentToolFn = func() (tools.BaseTool, error) {
118 taskAgentCfg := config.Get().Agents["task"]
119 if taskAgentCfg.ID == "" {
120 return nil, fmt.Errorf("task agent not found in config")
121 }
122 taskAgent, err := NewAgent(ctx, taskAgentCfg, permissions, sessions, messages, history, lspClients)
123 if err != nil {
124 return nil, fmt.Errorf("failed to create task agent: %w", err)
125 }
126 return NewAgentTool(taskAgent, sessions, messages), nil
127 }
128 }
129
130 providerCfg := config.Get().GetProviderForModel(agentCfg.Model)
131 if providerCfg == nil {
132 return nil, fmt.Errorf("provider for agent %s not found in config", agentCfg.Name)
133 }
134 model := config.Get().GetModelByType(agentCfg.Model)
135
136 if model == nil {
137 return nil, fmt.Errorf("model not found for agent %s", agentCfg.Name)
138 }
139
140 promptID := agentPromptMap[agentCfg.ID]
141 if promptID == "" {
142 promptID = prompt.PromptDefault
143 }
144 opts := []provider.ProviderClientOption{
145 provider.WithModel(agentCfg.Model),
146 provider.WithSystemMessage(prompt.GetPrompt(promptID, providerCfg.ID, config.Get().Options.ContextPaths...)),
147 }
148 agentProvider, err := provider.NewProvider(*providerCfg, opts...)
149 if err != nil {
150 return nil, err
151 }
152
153 smallModelCfg := cfg.Models[config.SelectedModelTypeSmall]
154 var smallModelProviderCfg *config.ProviderConfig
155 if smallModelCfg.Provider == providerCfg.ID {
156 smallModelProviderCfg = providerCfg
157 } else {
158 smallModelProviderCfg = cfg.GetProviderForModel(config.SelectedModelTypeSmall)
159
160 if smallModelProviderCfg.ID == "" {
161 return nil, fmt.Errorf("provider %s not found in config", smallModelCfg.Provider)
162 }
163 }
164 smallModel := cfg.GetModelByType(config.SelectedModelTypeSmall)
165 if smallModel.ID == "" {
166 return nil, fmt.Errorf("model %s not found in provider %s", smallModelCfg.Model, smallModelProviderCfg.ID)
167 }
168
169 titleOpts := []provider.ProviderClientOption{
170 provider.WithModel(config.SelectedModelTypeSmall),
171 provider.WithSystemMessage(prompt.GetPrompt(prompt.PromptTitle, smallModelProviderCfg.ID)),
172 }
173 titleProvider, err := provider.NewProvider(*smallModelProviderCfg, titleOpts...)
174 if err != nil {
175 return nil, err
176 }
177
178 summarizeOpts := []provider.ProviderClientOption{
179 provider.WithModel(config.SelectedModelTypeLarge),
180 provider.WithSystemMessage(prompt.GetPrompt(prompt.PromptSummarizer, providerCfg.ID)),
181 }
182 summarizeProvider, err := provider.NewProvider(*providerCfg, summarizeOpts...)
183 if err != nil {
184 return nil, err
185 }
186
187 toolFn := func() []tools.BaseTool {
188 slog.Info("Initializing agent tools", "agent", agentCfg.ID)
189 defer func() {
190 slog.Info("Initialized agent tools", "agent", agentCfg.ID)
191 }()
192
193 cwd := cfg.WorkingDir()
194 allTools := []tools.BaseTool{
195 tools.NewBashTool(permissions, cwd, cfg.Options.Attribution),
196 tools.NewDownloadTool(permissions, cwd),
197 tools.NewEditTool(lspClients, permissions, history, cwd),
198 tools.NewMultiEditTool(lspClients, permissions, history, cwd),
199 tools.NewFetchTool(permissions, cwd),
200 tools.NewGlobTool(cwd),
201 tools.NewGrepTool(cwd),
202 tools.NewLsTool(permissions, cwd),
203 tools.NewSourcegraphTool(),
204 tools.NewViewTool(lspClients, permissions, cwd),
205 tools.NewWriteTool(lspClients, permissions, history, cwd),
206 }
207
208 mcpToolsOnce.Do(func() {
209 mcpTools = doGetMCPTools(ctx, permissions, cfg)
210 })
211
212 withCoderTools := func(t []tools.BaseTool) []tools.BaseTool {
213 if agentCfg.ID == "coder" {
214 t = append(t, mcpTools...)
215 if lspClients.Len() > 0 {
216 t = append(t, tools.NewDiagnosticsTool(lspClients))
217 }
218 }
219 return t
220 }
221
222 if agentCfg.AllowedTools == nil {
223 return withCoderTools(allTools)
224 }
225
226 var filteredTools []tools.BaseTool
227 for _, tool := range allTools {
228 if slices.Contains(agentCfg.AllowedTools, tool.Name()) {
229 filteredTools = append(filteredTools, tool)
230 }
231 }
232 return withCoderTools(filteredTools)
233 }
234
235 return &agent{
236 Broker: pubsub.NewBroker[AgentEvent](),
237 agentCfg: agentCfg,
238 provider: agentProvider,
239 providerID: string(providerCfg.ID),
240 messages: messages,
241 sessions: sessions,
242 titleProvider: titleProvider,
243 summarizeProvider: summarizeProvider,
244 summarizeProviderID: string(providerCfg.ID),
245 agentToolFn: agentToolFn,
246 activeRequests: csync.NewMap[string, context.CancelFunc](),
247 tools: csync.NewLazySlice(toolFn),
248 promptQueue: csync.NewMap[string, []string](),
249 }, nil
250}
251
252func (a *agent) Model() catwalk.Model {
253 return *config.Get().GetModelByType(a.agentCfg.Model)
254}
255
256func (a *agent) Cancel(sessionID string) {
257 // Cancel regular requests
258 if cancel, ok := a.activeRequests.Take(sessionID); ok && cancel != nil {
259 slog.Info("Request cancellation initiated", "session_id", sessionID)
260 cancel()
261 }
262
263 // Also check for summarize requests
264 if cancel, ok := a.activeRequests.Take(sessionID + "-summarize"); ok && cancel != nil {
265 slog.Info("Summarize cancellation initiated", "session_id", sessionID)
266 cancel()
267 }
268
269 if a.QueuedPrompts(sessionID) > 0 {
270 slog.Info("Clearing queued prompts", "session_id", sessionID)
271 a.promptQueue.Del(sessionID)
272 }
273}
274
275func (a *agent) IsBusy() bool {
276 var busy bool
277 for cancelFunc := range a.activeRequests.Seq() {
278 if cancelFunc != nil {
279 busy = true
280 break
281 }
282 }
283 return busy
284}
285
286func (a *agent) IsSessionBusy(sessionID string) bool {
287 _, busy := a.activeRequests.Get(sessionID)
288 return busy
289}
290
291func (a *agent) QueuedPrompts(sessionID string) int {
292 l, ok := a.promptQueue.Get(sessionID)
293 if !ok {
294 return 0
295 }
296 return len(l)
297}
298
299func (a *agent) generateTitle(ctx context.Context, sessionID string, content string) error {
300 if content == "" {
301 return nil
302 }
303 if a.titleProvider == nil {
304 return nil
305 }
306 session, err := a.sessions.Get(ctx, sessionID)
307 if err != nil {
308 return err
309 }
310 parts := []message.ContentPart{message.TextContent{
311 Text: fmt.Sprintf("Generate a concise title for the following content:\n\n%s", content),
312 }}
313
314 // Use streaming approach like summarization
315 response := a.titleProvider.StreamResponse(
316 ctx,
317 []message.Message{
318 {
319 Role: message.User,
320 Parts: parts,
321 },
322 },
323 nil,
324 )
325
326 var finalResponse *provider.ProviderResponse
327 for r := range response {
328 if r.Error != nil {
329 return r.Error
330 }
331 finalResponse = r.Response
332 }
333
334 if finalResponse == nil {
335 return fmt.Errorf("no response received from title provider")
336 }
337
338 title := strings.ReplaceAll(finalResponse.Content, "\n", " ")
339
340 if idx := strings.Index(title, "</think>"); idx > 0 {
341 title = title[idx+len("</think>"):]
342 }
343
344 title = strings.TrimSpace(title)
345 if title == "" {
346 return nil
347 }
348
349 session.Title = title
350 _, err = a.sessions.Save(ctx, session)
351 return err
352}
353
354func (a *agent) err(err error) AgentEvent {
355 return AgentEvent{
356 Type: AgentEventTypeError,
357 Error: err,
358 }
359}
360
361func (a *agent) Run(ctx context.Context, sessionID string, content string, attachments ...message.Attachment) (<-chan AgentEvent, error) {
362 if !a.Model().SupportsImages && attachments != nil {
363 attachments = nil
364 }
365 events := make(chan AgentEvent, 1)
366 if a.IsSessionBusy(sessionID) {
367 existing, ok := a.promptQueue.Get(sessionID)
368 if !ok {
369 existing = []string{}
370 }
371 existing = append(existing, content)
372 a.promptQueue.Set(sessionID, existing)
373 return nil, nil
374 }
375
376 genCtx, cancel := context.WithCancel(ctx)
377
378 a.activeRequests.Set(sessionID, cancel)
379 go func() {
380 slog.Debug("Request started", "sessionID", sessionID)
381 defer log.RecoverPanic("agent.Run", func() {
382 events <- a.err(fmt.Errorf("panic while running the agent"))
383 })
384 var attachmentParts []message.ContentPart
385 for _, attachment := range attachments {
386 attachmentParts = append(attachmentParts, message.BinaryContent{Path: attachment.FilePath, MIMEType: attachment.MimeType, Data: attachment.Content})
387 }
388 result := a.processGeneration(genCtx, sessionID, content, attachmentParts)
389 if result.Error != nil && !errors.Is(result.Error, ErrRequestCancelled) && !errors.Is(result.Error, context.Canceled) {
390 slog.Error(result.Error.Error())
391 }
392 slog.Debug("Request completed", "sessionID", sessionID)
393 a.activeRequests.Del(sessionID)
394 cancel()
395 a.Publish(pubsub.CreatedEvent, result)
396 events <- result
397 close(events)
398 }()
399 return events, nil
400}
401
402func (a *agent) processGeneration(ctx context.Context, sessionID, content string, attachmentParts []message.ContentPart) AgentEvent {
403 cfg := config.Get()
404 // List existing messages; if none, start title generation asynchronously.
405 msgs, err := a.messages.List(ctx, sessionID)
406 if err != nil {
407 return a.err(fmt.Errorf("failed to list messages: %w", err))
408 }
409 if len(msgs) == 0 {
410 go func() {
411 defer log.RecoverPanic("agent.Run", func() {
412 slog.Error("panic while generating title")
413 })
414 titleErr := a.generateTitle(ctx, sessionID, content)
415 if titleErr != nil && !errors.Is(titleErr, context.Canceled) && !errors.Is(titleErr, context.DeadlineExceeded) {
416 slog.Error("failed to generate title", "error", titleErr)
417 }
418 }()
419 }
420 session, err := a.sessions.Get(ctx, sessionID)
421 if err != nil {
422 return a.err(fmt.Errorf("failed to get session: %w", err))
423 }
424 if session.SummaryMessageID != "" {
425 summaryMsgInex := -1
426 for i, msg := range msgs {
427 if msg.ID == session.SummaryMessageID {
428 summaryMsgInex = i
429 break
430 }
431 }
432 if summaryMsgInex != -1 {
433 msgs = msgs[summaryMsgInex:]
434 msgs[0].Role = message.User
435 }
436 }
437
438 userMsg, err := a.createUserMessage(ctx, sessionID, content, attachmentParts)
439 if err != nil {
440 return a.err(fmt.Errorf("failed to create user message: %w", err))
441 }
442 // Append the new user message to the conversation history.
443 msgHistory := append(msgs, userMsg)
444
445 for {
446 // Check for cancellation before each iteration
447 select {
448 case <-ctx.Done():
449 return a.err(ctx.Err())
450 default:
451 // Continue processing
452 }
453 agentMessage, toolResults, err := a.streamAndHandleEvents(ctx, sessionID, msgHistory)
454 if err != nil {
455 if errors.Is(err, context.Canceled) {
456 agentMessage.AddFinish(message.FinishReasonCanceled, "Request cancelled", "")
457 a.messages.Update(context.Background(), agentMessage)
458 return a.err(ErrRequestCancelled)
459 }
460 return a.err(fmt.Errorf("failed to process events: %w", err))
461 }
462 if cfg.Options.Debug {
463 slog.Info("Result", "message", agentMessage.FinishReason(), "toolResults", toolResults)
464 }
465 if (agentMessage.FinishReason() == message.FinishReasonToolUse) && toolResults != nil {
466 // We are not done, we need to respond with the tool response
467 msgHistory = append(msgHistory, agentMessage, *toolResults)
468 // If there are queued prompts, process the next one
469 nextPrompt, ok := a.promptQueue.Take(sessionID)
470 if ok {
471 for _, prompt := range nextPrompt {
472 // Create a new user message for the queued prompt
473 userMsg, err := a.createUserMessage(ctx, sessionID, prompt, nil)
474 if err != nil {
475 return a.err(fmt.Errorf("failed to create user message for queued prompt: %w", err))
476 }
477 // Append the new user message to the conversation history
478 msgHistory = append(msgHistory, userMsg)
479 }
480 }
481
482 continue
483 } else if agentMessage.FinishReason() == message.FinishReasonEndTurn {
484 queuePrompts, ok := a.promptQueue.Take(sessionID)
485 if ok {
486 for _, prompt := range queuePrompts {
487 if prompt == "" {
488 continue
489 }
490 userMsg, err := a.createUserMessage(ctx, sessionID, prompt, nil)
491 if err != nil {
492 return a.err(fmt.Errorf("failed to create user message for queued prompt: %w", err))
493 }
494 msgHistory = append(msgHistory, userMsg)
495 }
496 continue
497 }
498 }
499 if agentMessage.FinishReason() == "" {
500 // Kujtim: could not track down where this is happening but this means its cancelled
501 agentMessage.AddFinish(message.FinishReasonCanceled, "Request cancelled", "")
502 _ = a.messages.Update(context.Background(), agentMessage)
503 return a.err(ErrRequestCancelled)
504 }
505 return AgentEvent{
506 Type: AgentEventTypeResponse,
507 Message: agentMessage,
508 Done: true,
509 }
510 }
511}
512
513func (a *agent) createUserMessage(ctx context.Context, sessionID, content string, attachmentParts []message.ContentPart) (message.Message, error) {
514 parts := []message.ContentPart{message.TextContent{Text: content}}
515 parts = append(parts, attachmentParts...)
516 return a.messages.Create(ctx, sessionID, message.CreateMessageParams{
517 Role: message.User,
518 Parts: parts,
519 })
520}
521
522func (a *agent) getAllTools() ([]tools.BaseTool, error) {
523 allTools := slices.Collect(a.tools.Seq())
524 if a.agentToolFn != nil {
525 agentTool, agentToolErr := a.agentToolFn()
526 if agentToolErr != nil {
527 return nil, agentToolErr
528 }
529 allTools = append(allTools, agentTool)
530 }
531 return allTools, nil
532}
533
534func (a *agent) streamAndHandleEvents(ctx context.Context, sessionID string, msgHistory []message.Message) (message.Message, *message.Message, error) {
535 ctx = context.WithValue(ctx, tools.SessionIDContextKey, sessionID)
536
537 // Create the assistant message first so the spinner shows immediately
538 assistantMsg, err := a.messages.Create(ctx, sessionID, message.CreateMessageParams{
539 Role: message.Assistant,
540 Parts: []message.ContentPart{},
541 Model: a.Model().ID,
542 Provider: a.providerID,
543 })
544 if err != nil {
545 return assistantMsg, nil, fmt.Errorf("failed to create assistant message: %w", err)
546 }
547
548 allTools, toolsErr := a.getAllTools()
549 if toolsErr != nil {
550 return assistantMsg, nil, toolsErr
551 }
552 // Now collect tools (which may block on MCP initialization)
553 eventChan := a.provider.StreamResponse(ctx, msgHistory, allTools)
554
555 // Add the session and message ID into the context if needed by tools.
556 ctx = context.WithValue(ctx, tools.MessageIDContextKey, assistantMsg.ID)
557
558 // Process each event in the stream.
559 for event := range eventChan {
560 if processErr := a.processEvent(ctx, sessionID, &assistantMsg, event); processErr != nil {
561 if errors.Is(processErr, context.Canceled) {
562 a.finishMessage(context.Background(), &assistantMsg, message.FinishReasonCanceled, "Request cancelled", "")
563 } else {
564 a.finishMessage(ctx, &assistantMsg, message.FinishReasonError, "API Error", processErr.Error())
565 }
566 return assistantMsg, nil, processErr
567 }
568 if ctx.Err() != nil {
569 a.finishMessage(context.Background(), &assistantMsg, message.FinishReasonCanceled, "Request cancelled", "")
570 return assistantMsg, nil, ctx.Err()
571 }
572 }
573
574 toolResults := make([]message.ToolResult, len(assistantMsg.ToolCalls()))
575 toolCalls := assistantMsg.ToolCalls()
576 for i, toolCall := range toolCalls {
577 select {
578 case <-ctx.Done():
579 a.finishMessage(context.Background(), &assistantMsg, message.FinishReasonCanceled, "Request cancelled", "")
580 // Make all future tool calls cancelled
581 for j := i; j < len(toolCalls); j++ {
582 toolResults[j] = message.ToolResult{
583 ToolCallID: toolCalls[j].ID,
584 Content: "Tool execution canceled by user",
585 IsError: true,
586 }
587 }
588 goto out
589 default:
590 // Continue processing
591 var tool tools.BaseTool
592 allTools, _ := a.getAllTools()
593 for _, availableTool := range allTools {
594 if availableTool.Info().Name == toolCall.Name {
595 tool = availableTool
596 break
597 }
598 }
599
600 // Tool not found
601 if tool == nil {
602 toolResults[i] = message.ToolResult{
603 ToolCallID: toolCall.ID,
604 Content: fmt.Sprintf("Tool not found: %s", toolCall.Name),
605 IsError: true,
606 }
607 continue
608 }
609
610 // Run tool in goroutine to allow cancellation
611 type toolExecResult struct {
612 response tools.ToolResponse
613 err error
614 }
615 resultChan := make(chan toolExecResult, 1)
616
617 go func() {
618 response, err := tool.Run(ctx, tools.ToolCall{
619 ID: toolCall.ID,
620 Name: toolCall.Name,
621 Input: toolCall.Input,
622 })
623 resultChan <- toolExecResult{response: response, err: err}
624 }()
625
626 var toolResponse tools.ToolResponse
627 var toolErr error
628
629 select {
630 case <-ctx.Done():
631 a.finishMessage(context.Background(), &assistantMsg, message.FinishReasonCanceled, "Request cancelled", "")
632 // Mark remaining tool calls as cancelled
633 for j := i; j < len(toolCalls); j++ {
634 toolResults[j] = message.ToolResult{
635 ToolCallID: toolCalls[j].ID,
636 Content: "Tool execution canceled by user",
637 IsError: true,
638 }
639 }
640 goto out
641 case result := <-resultChan:
642 toolResponse = result.response
643 toolErr = result.err
644 }
645
646 if toolErr != nil {
647 slog.Error("Tool execution error", "toolCall", toolCall.ID, "error", toolErr)
648 if errors.Is(toolErr, permission.ErrorPermissionDenied) {
649 toolResults[i] = message.ToolResult{
650 ToolCallID: toolCall.ID,
651 Content: "Permission denied",
652 IsError: true,
653 }
654 for j := i + 1; j < len(toolCalls); j++ {
655 toolResults[j] = message.ToolResult{
656 ToolCallID: toolCalls[j].ID,
657 Content: "Tool execution canceled by user",
658 IsError: true,
659 }
660 }
661 a.finishMessage(ctx, &assistantMsg, message.FinishReasonPermissionDenied, "Permission denied", "")
662 break
663 }
664 }
665 toolResults[i] = message.ToolResult{
666 ToolCallID: toolCall.ID,
667 Content: toolResponse.Content,
668 Metadata: toolResponse.Metadata,
669 IsError: toolResponse.IsError,
670 }
671 }
672 }
673out:
674 if len(toolResults) == 0 {
675 return assistantMsg, nil, nil
676 }
677 parts := make([]message.ContentPart, 0)
678 for _, tr := range toolResults {
679 parts = append(parts, tr)
680 }
681 msg, err := a.messages.Create(context.Background(), assistantMsg.SessionID, message.CreateMessageParams{
682 Role: message.Tool,
683 Parts: parts,
684 Provider: a.providerID,
685 })
686 if err != nil {
687 return assistantMsg, nil, fmt.Errorf("failed to create cancelled tool message: %w", err)
688 }
689
690 return assistantMsg, &msg, err
691}
692
693func (a *agent) finishMessage(ctx context.Context, msg *message.Message, finishReason message.FinishReason, message, details string) {
694 msg.AddFinish(finishReason, message, details)
695 _ = a.messages.Update(ctx, *msg)
696}
697
698func (a *agent) processEvent(ctx context.Context, sessionID string, assistantMsg *message.Message, event provider.ProviderEvent) error {
699 select {
700 case <-ctx.Done():
701 return ctx.Err()
702 default:
703 // Continue processing.
704 }
705
706 switch event.Type {
707 case provider.EventThinkingDelta:
708 assistantMsg.AppendReasoningContent(event.Thinking)
709 return a.messages.Update(ctx, *assistantMsg)
710 case provider.EventSignatureDelta:
711 assistantMsg.AppendReasoningSignature(event.Signature)
712 return a.messages.Update(ctx, *assistantMsg)
713 case provider.EventContentDelta:
714 assistantMsg.FinishThinking()
715 assistantMsg.AppendContent(event.Content)
716 return a.messages.Update(ctx, *assistantMsg)
717 case provider.EventToolUseStart:
718 assistantMsg.FinishThinking()
719 slog.Info("Tool call started", "toolCall", event.ToolCall)
720 assistantMsg.AddToolCall(*event.ToolCall)
721 return a.messages.Update(ctx, *assistantMsg)
722 case provider.EventToolUseDelta:
723 assistantMsg.AppendToolCallInput(event.ToolCall.ID, event.ToolCall.Input)
724 return a.messages.Update(ctx, *assistantMsg)
725 case provider.EventToolUseStop:
726 slog.Info("Finished tool call", "toolCall", event.ToolCall)
727 assistantMsg.FinishToolCall(event.ToolCall.ID)
728 return a.messages.Update(ctx, *assistantMsg)
729 case provider.EventError:
730 return event.Error
731 case provider.EventComplete:
732 assistantMsg.FinishThinking()
733 assistantMsg.SetToolCalls(event.Response.ToolCalls)
734 assistantMsg.AddFinish(event.Response.FinishReason, "", "")
735 if err := a.messages.Update(ctx, *assistantMsg); err != nil {
736 return fmt.Errorf("failed to update message: %w", err)
737 }
738 return a.TrackUsage(ctx, sessionID, a.Model(), event.Response.Usage)
739 }
740
741 return nil
742}
743
744func (a *agent) TrackUsage(ctx context.Context, sessionID string, model catwalk.Model, usage provider.TokenUsage) error {
745 sess, err := a.sessions.Get(ctx, sessionID)
746 if err != nil {
747 return fmt.Errorf("failed to get session: %w", err)
748 }
749
750 cost := model.CostPer1MInCached/1e6*float64(usage.CacheCreationTokens) +
751 model.CostPer1MOutCached/1e6*float64(usage.CacheReadTokens) +
752 model.CostPer1MIn/1e6*float64(usage.InputTokens) +
753 model.CostPer1MOut/1e6*float64(usage.OutputTokens)
754
755 sess.Cost += cost
756 sess.CompletionTokens = usage.OutputTokens + usage.CacheReadTokens
757 sess.PromptTokens = usage.InputTokens + usage.CacheCreationTokens
758
759 _, err = a.sessions.Save(ctx, sess)
760 if err != nil {
761 return fmt.Errorf("failed to save session: %w", err)
762 }
763 return nil
764}
765
766func (a *agent) Summarize(ctx context.Context, sessionID string) error {
767 if a.summarizeProvider == nil {
768 return fmt.Errorf("summarize provider not available")
769 }
770
771 // Check if session is busy
772 if a.IsSessionBusy(sessionID) {
773 return ErrSessionBusy
774 }
775
776 // Create a new context with cancellation
777 summarizeCtx, cancel := context.WithCancel(ctx)
778
779 // Store the cancel function in activeRequests to allow cancellation
780 a.activeRequests.Set(sessionID+"-summarize", cancel)
781
782 go func() {
783 defer a.activeRequests.Del(sessionID + "-summarize")
784 defer cancel()
785 event := AgentEvent{
786 Type: AgentEventTypeSummarize,
787 Progress: "Starting summarization...",
788 }
789
790 a.Publish(pubsub.CreatedEvent, event)
791 // Get all messages from the session
792 msgs, err := a.messages.List(summarizeCtx, sessionID)
793 if err != nil {
794 event = AgentEvent{
795 Type: AgentEventTypeError,
796 Error: fmt.Errorf("failed to list messages: %w", err),
797 Done: true,
798 }
799 a.Publish(pubsub.CreatedEvent, event)
800 return
801 }
802 summarizeCtx = context.WithValue(summarizeCtx, tools.SessionIDContextKey, sessionID)
803
804 if len(msgs) == 0 {
805 event = AgentEvent{
806 Type: AgentEventTypeError,
807 Error: fmt.Errorf("no messages to summarize"),
808 Done: true,
809 }
810 a.Publish(pubsub.CreatedEvent, event)
811 return
812 }
813
814 event = AgentEvent{
815 Type: AgentEventTypeSummarize,
816 Progress: "Analyzing conversation...",
817 }
818 a.Publish(pubsub.CreatedEvent, event)
819
820 // Add a system message to guide the summarization
821 summarizePrompt := "Provide a detailed but concise summary of our conversation above. Focus on information that would be helpful for continuing the conversation, including what we did, what we're doing, which files we're working on, and what we're going to do next."
822
823 // Create a new message with the summarize prompt
824 promptMsg := message.Message{
825 Role: message.User,
826 Parts: []message.ContentPart{message.TextContent{Text: summarizePrompt}},
827 }
828
829 // Append the prompt to the messages
830 msgsWithPrompt := append(msgs, promptMsg)
831
832 event = AgentEvent{
833 Type: AgentEventTypeSummarize,
834 Progress: "Generating summary...",
835 }
836
837 a.Publish(pubsub.CreatedEvent, event)
838
839 // Send the messages to the summarize provider
840 response := a.summarizeProvider.StreamResponse(
841 summarizeCtx,
842 msgsWithPrompt,
843 nil,
844 )
845 var finalResponse *provider.ProviderResponse
846 for r := range response {
847 if r.Error != nil {
848 event = AgentEvent{
849 Type: AgentEventTypeError,
850 Error: fmt.Errorf("failed to summarize: %w", r.Error),
851 Done: true,
852 }
853 a.Publish(pubsub.CreatedEvent, event)
854 return
855 }
856 finalResponse = r.Response
857 }
858
859 summary := strings.TrimSpace(finalResponse.Content)
860 if summary == "" {
861 event = AgentEvent{
862 Type: AgentEventTypeError,
863 Error: fmt.Errorf("empty summary returned"),
864 Done: true,
865 }
866 a.Publish(pubsub.CreatedEvent, event)
867 return
868 }
869 shell := shell.GetPersistentShell(config.Get().WorkingDir())
870 summary += "\n\n**Current working directory of the persistent shell**\n\n" + shell.GetWorkingDir()
871 event = AgentEvent{
872 Type: AgentEventTypeSummarize,
873 Progress: "Creating new session...",
874 }
875
876 a.Publish(pubsub.CreatedEvent, event)
877 oldSession, err := a.sessions.Get(summarizeCtx, sessionID)
878 if err != nil {
879 event = AgentEvent{
880 Type: AgentEventTypeError,
881 Error: fmt.Errorf("failed to get session: %w", err),
882 Done: true,
883 }
884
885 a.Publish(pubsub.CreatedEvent, event)
886 return
887 }
888 // Create a message in the new session with the summary
889 msg, err := a.messages.Create(summarizeCtx, oldSession.ID, message.CreateMessageParams{
890 Role: message.Assistant,
891 Parts: []message.ContentPart{
892 message.TextContent{Text: summary},
893 message.Finish{
894 Reason: message.FinishReasonEndTurn,
895 Time: time.Now().Unix(),
896 },
897 },
898 Model: a.summarizeProvider.Model().ID,
899 Provider: a.summarizeProviderID,
900 })
901 if err != nil {
902 event = AgentEvent{
903 Type: AgentEventTypeError,
904 Error: fmt.Errorf("failed to create summary message: %w", err),
905 Done: true,
906 }
907
908 a.Publish(pubsub.CreatedEvent, event)
909 return
910 }
911 oldSession.SummaryMessageID = msg.ID
912 oldSession.CompletionTokens = finalResponse.Usage.OutputTokens
913 oldSession.PromptTokens = 0
914 model := a.summarizeProvider.Model()
915 usage := finalResponse.Usage
916 cost := model.CostPer1MInCached/1e6*float64(usage.CacheCreationTokens) +
917 model.CostPer1MOutCached/1e6*float64(usage.CacheReadTokens) +
918 model.CostPer1MIn/1e6*float64(usage.InputTokens) +
919 model.CostPer1MOut/1e6*float64(usage.OutputTokens)
920 oldSession.Cost += cost
921 _, err = a.sessions.Save(summarizeCtx, oldSession)
922 if err != nil {
923 event = AgentEvent{
924 Type: AgentEventTypeError,
925 Error: fmt.Errorf("failed to save session: %w", err),
926 Done: true,
927 }
928 a.Publish(pubsub.CreatedEvent, event)
929 }
930
931 event = AgentEvent{
932 Type: AgentEventTypeSummarize,
933 SessionID: oldSession.ID,
934 Progress: "Summary complete",
935 Done: true,
936 }
937 a.Publish(pubsub.CreatedEvent, event)
938 // Send final success event with the new session ID
939 }()
940
941 return nil
942}
943
944func (a *agent) ClearQueue(sessionID string) {
945 if a.QueuedPrompts(sessionID) > 0 {
946 slog.Info("Clearing queued prompts", "session_id", sessionID)
947 a.promptQueue.Del(sessionID)
948 }
949}
950
951func (a *agent) CancelAll() {
952 if !a.IsBusy() {
953 return
954 }
955 for key := range a.activeRequests.Seq2() {
956 a.Cancel(key) // key is sessionID
957 }
958
959 timeout := time.After(5 * time.Second)
960 for a.IsBusy() {
961 select {
962 case <-timeout:
963 return
964 default:
965 time.Sleep(200 * time.Millisecond)
966 }
967 }
968}
969
970func (a *agent) UpdateModel() error {
971 cfg := config.Get()
972
973 // Get current provider configuration
974 currentProviderCfg := cfg.GetProviderForModel(a.agentCfg.Model)
975 if currentProviderCfg == nil || currentProviderCfg.ID == "" {
976 return fmt.Errorf("provider for agent %s not found in config", a.agentCfg.Name)
977 }
978
979 // Check if provider has changed
980 if string(currentProviderCfg.ID) != a.providerID {
981 // Provider changed, need to recreate the main provider
982 model := cfg.GetModelByType(a.agentCfg.Model)
983 if model.ID == "" {
984 return fmt.Errorf("model not found for agent %s", a.agentCfg.Name)
985 }
986
987 promptID := agentPromptMap[a.agentCfg.ID]
988 if promptID == "" {
989 promptID = prompt.PromptDefault
990 }
991
992 opts := []provider.ProviderClientOption{
993 provider.WithModel(a.agentCfg.Model),
994 provider.WithSystemMessage(prompt.GetPrompt(promptID, currentProviderCfg.ID, cfg.Options.ContextPaths...)),
995 }
996
997 newProvider, err := provider.NewProvider(*currentProviderCfg, opts...)
998 if err != nil {
999 return fmt.Errorf("failed to create new provider: %w", err)
1000 }
1001
1002 // Update the provider and provider ID
1003 a.provider = newProvider
1004 a.providerID = string(currentProviderCfg.ID)
1005 }
1006
1007 // Check if providers have changed for title (small) and summarize (large)
1008 smallModelCfg := cfg.Models[config.SelectedModelTypeSmall]
1009 var smallModelProviderCfg config.ProviderConfig
1010 for p := range cfg.Providers.Seq() {
1011 if p.ID == smallModelCfg.Provider {
1012 smallModelProviderCfg = p
1013 break
1014 }
1015 }
1016 if smallModelProviderCfg.ID == "" {
1017 return fmt.Errorf("provider %s not found in config", smallModelCfg.Provider)
1018 }
1019
1020 largeModelCfg := cfg.Models[config.SelectedModelTypeLarge]
1021 var largeModelProviderCfg config.ProviderConfig
1022 for p := range cfg.Providers.Seq() {
1023 if p.ID == largeModelCfg.Provider {
1024 largeModelProviderCfg = p
1025 break
1026 }
1027 }
1028 if largeModelProviderCfg.ID == "" {
1029 return fmt.Errorf("provider %s not found in config", largeModelCfg.Provider)
1030 }
1031
1032 var maxTitleTokens int64 = 40
1033
1034 // if the max output is too low for the gemini provider it won't return anything
1035 if smallModelCfg.Provider == "gemini" {
1036 maxTitleTokens = 1000
1037 }
1038 // Recreate title provider
1039 titleOpts := []provider.ProviderClientOption{
1040 provider.WithModel(config.SelectedModelTypeSmall),
1041 provider.WithSystemMessage(prompt.GetPrompt(prompt.PromptTitle, smallModelProviderCfg.ID)),
1042 provider.WithMaxTokens(maxTitleTokens),
1043 }
1044 newTitleProvider, err := provider.NewProvider(smallModelProviderCfg, titleOpts...)
1045 if err != nil {
1046 return fmt.Errorf("failed to create new title provider: %w", err)
1047 }
1048 a.titleProvider = newTitleProvider
1049
1050 // Recreate summarize provider if provider changed (now large model)
1051 if string(largeModelProviderCfg.ID) != a.summarizeProviderID {
1052 largeModel := cfg.GetModelByType(config.SelectedModelTypeLarge)
1053 if largeModel == nil {
1054 return fmt.Errorf("model %s not found in provider %s", largeModelCfg.Model, largeModelProviderCfg.ID)
1055 }
1056 summarizeOpts := []provider.ProviderClientOption{
1057 provider.WithModel(config.SelectedModelTypeLarge),
1058 provider.WithSystemMessage(prompt.GetPrompt(prompt.PromptSummarizer, largeModelProviderCfg.ID)),
1059 }
1060 newSummarizeProvider, err := provider.NewProvider(largeModelProviderCfg, summarizeOpts...)
1061 if err != nil {
1062 return fmt.Errorf("failed to create new summarize provider: %w", err)
1063 }
1064 a.summarizeProvider = newSummarizeProvider
1065 a.summarizeProviderID = string(largeModelProviderCfg.ID)
1066 }
1067
1068 return nil
1069}