1package agent
2
3import (
4 "context"
5 "errors"
6 "fmt"
7 "log/slog"
8 "slices"
9 "strings"
10 "time"
11
12 "github.com/charmbracelet/catwalk/pkg/catwalk"
13 "github.com/charmbracelet/crush/internal/config"
14 "github.com/charmbracelet/crush/internal/csync"
15 "github.com/charmbracelet/crush/internal/history"
16 "github.com/charmbracelet/crush/internal/llm/prompt"
17 "github.com/charmbracelet/crush/internal/llm/provider"
18 "github.com/charmbracelet/crush/internal/llm/tools"
19 "github.com/charmbracelet/crush/internal/log"
20 "github.com/charmbracelet/crush/internal/lsp"
21 "github.com/charmbracelet/crush/internal/message"
22 "github.com/charmbracelet/crush/internal/permission"
23 "github.com/charmbracelet/crush/internal/pubsub"
24 "github.com/charmbracelet/crush/internal/session"
25 "github.com/charmbracelet/crush/internal/shell"
26)
27
// ErrRequestCancelled is reported when a generation is aborted because its
// context was cancelled (for example via Cancel).
var ErrRequestCancelled = errors.New("request canceled by user")

// ErrSessionBusy is returned by Summarize when a request is already running
// for the session.
var ErrSessionBusy = errors.New("session is currently processing another request")
33
// AgentEventType tags which kind of AgentEvent is being delivered.
type AgentEventType string

// AgentEventTypeError marks an event whose Error field describes a failure.
const AgentEventTypeError AgentEventType = "error"

// AgentEventTypeResponse marks an event carrying the finished assistant
// message of a Run.
const AgentEventTypeResponse AgentEventType = "response"

// AgentEventTypeSummarize marks a summarization progress/completion event.
const AgentEventTypeSummarize AgentEventType = "summarize"
41
// AgentEvent is the payload the agent publishes to subscribers and sends on
// the channel returned by Run.
type AgentEvent struct {
	// Type selects which of the remaining fields are meaningful.
	Type AgentEventType
	// Message is the final assistant message of an AgentEventTypeResponse event.
	Message message.Message
	// Error is set on AgentEventTypeError events.
	Error error

	// Fields used mainly by summarization events.

	// SessionID is set on the final, successful summarization event.
	SessionID string
	// Progress is a human-readable status line for AgentEventTypeSummarize.
	Progress string
	// Done marks the last event of an operation. Besides summarization it is
	// also set on Run's response event and on summarization error events.
	Done bool
}
52
// Service is the public surface of an agent: it runs prompts against the
// configured model, tracks per-session cancellation and prompt queues, and
// publishes AgentEvents to subscribers.
type Service interface {
	pubsub.Suscriber[AgentEvent]
	// Model returns the model currently configured for this agent.
	Model() catwalk.Model
	// Run processes content (plus optional attachments) as a new turn in
	// sessionID and returns a channel that receives the result. In this
	// package's implementation, a busy session gets the prompt queued instead
	// and both return values are nil.
	Run(ctx context.Context, sessionID string, content string, attachments ...message.Attachment) (<-chan AgentEvent, error)
	// Cancel aborts the running request and summarization for sessionID and
	// drops its queued prompts.
	Cancel(sessionID string)
	// CancelAll cancels every active request and waits (bounded) for them to stop.
	CancelAll()
	// IsSessionBusy reports whether a regular request is running for sessionID.
	IsSessionBusy(sessionID string) bool
	// IsBusy reports whether any request, including summarizations, is running.
	IsBusy() bool
	// Summarize starts summarizing sessionID in the background; progress and
	// the outcome are published as events.
	Summarize(ctx context.Context, sessionID string) error
	// UpdateModel rebuilds provider clients after the model configuration changes.
	UpdateModel() error
	// QueuedPrompts returns how many prompts are waiting for sessionID.
	QueuedPrompts(sessionID string) int
	// ClearQueue discards the prompts waiting for sessionID.
	ClearQueue(sessionID string)
}
66
// agent is the Service implementation created by NewAgent.
type agent struct {
	*pubsub.Broker[AgentEvent]
	agentCfg config.Agent
	sessions session.Service
	messages message.Service
	mcpTools []McpTool

	// tools is populated by the toolFn closure built in NewAgent.
	tools *csync.LazySlice[tools.BaseTool]
	// We need this to be able to update it when model changes.
	// Non-nil only for the "coder" agent, where it builds the task-agent tool.
	agentToolFn func() (tools.BaseTool, error)

	// provider serves the agent's own model; providerID is its config ID.
	provider   provider.Provider
	providerID string

	// titleProvider generates session titles; summarizeProvider (and its
	// config ID) powers Summarize.
	titleProvider       provider.Provider
	summarizeProvider   provider.Provider
	summarizeProviderID string

	// activeRequests holds cancel funcs keyed by session ID; summarizations
	// use the key "<sessionID>-summarize".
	activeRequests *csync.Map[string, context.CancelFunc]
	// promptQueue holds prompts submitted while their session was busy.
	promptQueue *csync.Map[string, []string]
}
88
// agentPromptMap selects the system prompt for an agent ID. Agents not listed
// here fall back to prompt.PromptDefault (see NewAgent and UpdateModel).
var agentPromptMap = map[string]prompt.PromptID{
	"coder": prompt.PromptCoder,
	"task":  prompt.PromptTask,
}
93
94func NewAgent(
95 ctx context.Context,
96 agentCfg config.Agent,
97 // These services are needed in the tools
98 permissions permission.Service,
99 sessions session.Service,
100 messages message.Service,
101 history history.Service,
102 lspClients *csync.Map[string, *lsp.Client],
103) (Service, error) {
104 cfg := config.Get()
105
106 var agentToolFn func() (tools.BaseTool, error)
107 if agentCfg.ID == "coder" {
108 agentToolFn = func() (tools.BaseTool, error) {
109 taskAgentCfg := config.Get().Agents["task"]
110 if taskAgentCfg.ID == "" {
111 return nil, fmt.Errorf("task agent not found in config")
112 }
113 taskAgent, err := NewAgent(ctx, taskAgentCfg, permissions, sessions, messages, history, lspClients)
114 if err != nil {
115 return nil, fmt.Errorf("failed to create task agent: %w", err)
116 }
117 return NewAgentTool(taskAgent, sessions, messages), nil
118 }
119 }
120
121 providerCfg := config.Get().GetProviderForModel(agentCfg.Model)
122 if providerCfg == nil {
123 return nil, fmt.Errorf("provider for agent %s not found in config", agentCfg.Name)
124 }
125 model := config.Get().GetModelByType(agentCfg.Model)
126
127 if model == nil {
128 return nil, fmt.Errorf("model not found for agent %s", agentCfg.Name)
129 }
130
131 promptID := agentPromptMap[agentCfg.ID]
132 if promptID == "" {
133 promptID = prompt.PromptDefault
134 }
135 opts := []provider.ProviderClientOption{
136 provider.WithModel(agentCfg.Model),
137 provider.WithSystemMessage(prompt.GetPrompt(promptID, providerCfg.ID, config.Get().Options.ContextPaths...)),
138 }
139 agentProvider, err := provider.NewProvider(*providerCfg, opts...)
140 if err != nil {
141 return nil, err
142 }
143
144 smallModelCfg := cfg.Models[config.SelectedModelTypeSmall]
145 var smallModelProviderCfg *config.ProviderConfig
146 if smallModelCfg.Provider == providerCfg.ID {
147 smallModelProviderCfg = providerCfg
148 } else {
149 smallModelProviderCfg = cfg.GetProviderForModel(config.SelectedModelTypeSmall)
150
151 if smallModelProviderCfg.ID == "" {
152 return nil, fmt.Errorf("provider %s not found in config", smallModelCfg.Provider)
153 }
154 }
155 smallModel := cfg.GetModelByType(config.SelectedModelTypeSmall)
156 if smallModel.ID == "" {
157 return nil, fmt.Errorf("model %s not found in provider %s", smallModelCfg.Model, smallModelProviderCfg.ID)
158 }
159
160 titleOpts := []provider.ProviderClientOption{
161 provider.WithModel(config.SelectedModelTypeSmall),
162 provider.WithSystemMessage(prompt.GetPrompt(prompt.PromptTitle, smallModelProviderCfg.ID)),
163 }
164 titleProvider, err := provider.NewProvider(*smallModelProviderCfg, titleOpts...)
165 if err != nil {
166 return nil, err
167 }
168
169 summarizeOpts := []provider.ProviderClientOption{
170 provider.WithModel(config.SelectedModelTypeLarge),
171 provider.WithSystemMessage(prompt.GetPrompt(prompt.PromptSummarizer, providerCfg.ID)),
172 }
173 summarizeProvider, err := provider.NewProvider(*providerCfg, summarizeOpts...)
174 if err != nil {
175 return nil, err
176 }
177
178 toolFn := func() []tools.BaseTool {
179 slog.Info("Initializing agent tools", "agent", agentCfg.ID)
180 defer func() {
181 slog.Info("Initialized agent tools", "agent", agentCfg.ID)
182 }()
183
184 cwd := cfg.WorkingDir()
185 allTools := []tools.BaseTool{
186 tools.NewBashTool(permissions, cwd, cfg.Options.Attribution),
187 tools.NewDownloadTool(permissions, cwd),
188 tools.NewEditTool(lspClients, permissions, history, cwd),
189 tools.NewMultiEditTool(lspClients, permissions, history, cwd),
190 tools.NewFetchTool(permissions, cwd),
191 tools.NewGlobTool(cwd),
192 tools.NewGrepTool(cwd),
193 tools.NewLsTool(permissions, cwd),
194 tools.NewSourcegraphTool(),
195 tools.NewViewTool(lspClients, permissions, cwd),
196 tools.NewWriteTool(lspClients, permissions, history, cwd),
197 }
198
199 mcpToolsOnce.Do(func() {
200 mcpTools = doGetMCPTools(ctx, permissions, cfg)
201 })
202
203 withCoderTools := func(t []tools.BaseTool) []tools.BaseTool {
204 if agentCfg.ID == "coder" {
205 t = append(t, mcpTools...)
206 if lspClients.Len() > 0 {
207 t = append(t, tools.NewDiagnosticsTool(lspClients))
208 }
209 }
210 return t
211 }
212
213 if agentCfg.AllowedTools == nil {
214 return withCoderTools(allTools)
215 }
216
217 var filteredTools []tools.BaseTool
218 for _, tool := range allTools {
219 if slices.Contains(agentCfg.AllowedTools, tool.Name()) {
220 filteredTools = append(filteredTools, tool)
221 }
222 }
223 return withCoderTools(filteredTools)
224 }
225
226 return &agent{
227 Broker: pubsub.NewBroker[AgentEvent](),
228 agentCfg: agentCfg,
229 provider: agentProvider,
230 providerID: string(providerCfg.ID),
231 messages: messages,
232 sessions: sessions,
233 titleProvider: titleProvider,
234 summarizeProvider: summarizeProvider,
235 summarizeProviderID: string(providerCfg.ID),
236 agentToolFn: agentToolFn,
237 activeRequests: csync.NewMap[string, context.CancelFunc](),
238 tools: csync.NewLazySlice(toolFn),
239 promptQueue: csync.NewMap[string, []string](),
240 }, nil
241}
242
243func (a *agent) Model() catwalk.Model {
244 return *config.Get().GetModelByType(a.agentCfg.Model)
245}
246
247func (a *agent) Cancel(sessionID string) {
248 // Cancel regular requests
249 if cancel, ok := a.activeRequests.Take(sessionID); ok && cancel != nil {
250 slog.Info("Request cancellation initiated", "session_id", sessionID)
251 cancel()
252 }
253
254 // Also check for summarize requests
255 if cancel, ok := a.activeRequests.Take(sessionID + "-summarize"); ok && cancel != nil {
256 slog.Info("Summarize cancellation initiated", "session_id", sessionID)
257 cancel()
258 }
259
260 if a.QueuedPrompts(sessionID) > 0 {
261 slog.Info("Clearing queued prompts", "session_id", sessionID)
262 a.promptQueue.Del(sessionID)
263 }
264}
265
266func (a *agent) IsBusy() bool {
267 var busy bool
268 for cancelFunc := range a.activeRequests.Seq() {
269 if cancelFunc != nil {
270 busy = true
271 break
272 }
273 }
274 return busy
275}
276
277func (a *agent) IsSessionBusy(sessionID string) bool {
278 _, busy := a.activeRequests.Get(sessionID)
279 return busy
280}
281
282func (a *agent) QueuedPrompts(sessionID string) int {
283 l, ok := a.promptQueue.Get(sessionID)
284 if !ok {
285 return 0
286 }
287 return len(l)
288}
289
290func (a *agent) generateTitle(ctx context.Context, sessionID string, content string) error {
291 if content == "" {
292 return nil
293 }
294 if a.titleProvider == nil {
295 return nil
296 }
297 session, err := a.sessions.Get(ctx, sessionID)
298 if err != nil {
299 return err
300 }
301 parts := []message.ContentPart{message.TextContent{
302 Text: fmt.Sprintf("Generate a concise title for the following content:\n\n%s", content),
303 }}
304
305 // Use streaming approach like summarization
306 response := a.titleProvider.StreamResponse(
307 ctx,
308 []message.Message{
309 {
310 Role: message.User,
311 Parts: parts,
312 },
313 },
314 nil,
315 )
316
317 var finalResponse *provider.ProviderResponse
318 for r := range response {
319 if r.Error != nil {
320 return r.Error
321 }
322 finalResponse = r.Response
323 }
324
325 if finalResponse == nil {
326 return fmt.Errorf("no response received from title provider")
327 }
328
329 title := strings.TrimSpace(strings.ReplaceAll(finalResponse.Content, "\n", " "))
330 if title == "" {
331 return nil
332 }
333
334 session.Title = title
335 _, err = a.sessions.Save(ctx, session)
336 return err
337}
338
339func (a *agent) err(err error) AgentEvent {
340 return AgentEvent{
341 Type: AgentEventTypeError,
342 Error: err,
343 }
344}
345
346func (a *agent) Run(ctx context.Context, sessionID string, content string, attachments ...message.Attachment) (<-chan AgentEvent, error) {
347 if !a.Model().SupportsImages && attachments != nil {
348 attachments = nil
349 }
350 events := make(chan AgentEvent, 1)
351 if a.IsSessionBusy(sessionID) {
352 existing, ok := a.promptQueue.Get(sessionID)
353 if !ok {
354 existing = []string{}
355 }
356 existing = append(existing, content)
357 a.promptQueue.Set(sessionID, existing)
358 return nil, nil
359 }
360
361 genCtx, cancel := context.WithCancel(ctx)
362
363 a.activeRequests.Set(sessionID, cancel)
364 go func() {
365 slog.Debug("Request started", "sessionID", sessionID)
366 defer log.RecoverPanic("agent.Run", func() {
367 events <- a.err(fmt.Errorf("panic while running the agent"))
368 })
369 var attachmentParts []message.ContentPart
370 for _, attachment := range attachments {
371 attachmentParts = append(attachmentParts, message.BinaryContent{Path: attachment.FilePath, MIMEType: attachment.MimeType, Data: attachment.Content})
372 }
373 result := a.processGeneration(genCtx, sessionID, content, attachmentParts)
374 if result.Error != nil && !errors.Is(result.Error, ErrRequestCancelled) && !errors.Is(result.Error, context.Canceled) {
375 slog.Error(result.Error.Error())
376 }
377 slog.Debug("Request completed", "sessionID", sessionID)
378 a.activeRequests.Del(sessionID)
379 cancel()
380 a.Publish(pubsub.CreatedEvent, result)
381 events <- result
382 close(events)
383 }()
384 return events, nil
385}
386
387func (a *agent) processGeneration(ctx context.Context, sessionID, content string, attachmentParts []message.ContentPart) AgentEvent {
388 cfg := config.Get()
389 // List existing messages; if none, start title generation asynchronously.
390 msgs, err := a.messages.List(ctx, sessionID)
391 if err != nil {
392 return a.err(fmt.Errorf("failed to list messages: %w", err))
393 }
394 if len(msgs) == 0 {
395 go func() {
396 defer log.RecoverPanic("agent.Run", func() {
397 slog.Error("panic while generating title")
398 })
399 titleErr := a.generateTitle(ctx, sessionID, content)
400 if titleErr != nil && !errors.Is(titleErr, context.Canceled) && !errors.Is(titleErr, context.DeadlineExceeded) {
401 slog.Error("failed to generate title", "error", titleErr)
402 }
403 }()
404 }
405 session, err := a.sessions.Get(ctx, sessionID)
406 if err != nil {
407 return a.err(fmt.Errorf("failed to get session: %w", err))
408 }
409 if session.SummaryMessageID != "" {
410 summaryMsgInex := -1
411 for i, msg := range msgs {
412 if msg.ID == session.SummaryMessageID {
413 summaryMsgInex = i
414 break
415 }
416 }
417 if summaryMsgInex != -1 {
418 msgs = msgs[summaryMsgInex:]
419 msgs[0].Role = message.User
420 }
421 }
422
423 userMsg, err := a.createUserMessage(ctx, sessionID, content, attachmentParts)
424 if err != nil {
425 return a.err(fmt.Errorf("failed to create user message: %w", err))
426 }
427 // Append the new user message to the conversation history.
428 msgHistory := append(msgs, userMsg)
429
430 for {
431 // Check for cancellation before each iteration
432 select {
433 case <-ctx.Done():
434 return a.err(ctx.Err())
435 default:
436 // Continue processing
437 }
438 agentMessage, toolResults, err := a.streamAndHandleEvents(ctx, sessionID, msgHistory)
439 if err != nil {
440 if errors.Is(err, context.Canceled) {
441 agentMessage.AddFinish(message.FinishReasonCanceled, "Request cancelled", "")
442 a.messages.Update(context.Background(), agentMessage)
443 return a.err(ErrRequestCancelled)
444 }
445 return a.err(fmt.Errorf("failed to process events: %w", err))
446 }
447 if cfg.Options.Debug {
448 slog.Info("Result", "message", agentMessage.FinishReason(), "toolResults", toolResults)
449 }
450 if (agentMessage.FinishReason() == message.FinishReasonToolUse) && toolResults != nil {
451 // We are not done, we need to respond with the tool response
452 msgHistory = append(msgHistory, agentMessage, *toolResults)
453 // If there are queued prompts, process the next one
454 nextPrompt, ok := a.promptQueue.Take(sessionID)
455 if ok {
456 for _, prompt := range nextPrompt {
457 // Create a new user message for the queued prompt
458 userMsg, err := a.createUserMessage(ctx, sessionID, prompt, nil)
459 if err != nil {
460 return a.err(fmt.Errorf("failed to create user message for queued prompt: %w", err))
461 }
462 // Append the new user message to the conversation history
463 msgHistory = append(msgHistory, userMsg)
464 }
465 }
466
467 continue
468 } else if agentMessage.FinishReason() == message.FinishReasonEndTurn {
469 queuePrompts, ok := a.promptQueue.Take(sessionID)
470 if ok {
471 for _, prompt := range queuePrompts {
472 if prompt == "" {
473 continue
474 }
475 userMsg, err := a.createUserMessage(ctx, sessionID, prompt, nil)
476 if err != nil {
477 return a.err(fmt.Errorf("failed to create user message for queued prompt: %w", err))
478 }
479 msgHistory = append(msgHistory, userMsg)
480 }
481 continue
482 }
483 }
484 if agentMessage.FinishReason() == "" {
485 // Kujtim: could not track down where this is happening but this means its cancelled
486 agentMessage.AddFinish(message.FinishReasonCanceled, "Request cancelled", "")
487 _ = a.messages.Update(context.Background(), agentMessage)
488 return a.err(ErrRequestCancelled)
489 }
490 return AgentEvent{
491 Type: AgentEventTypeResponse,
492 Message: agentMessage,
493 Done: true,
494 }
495 }
496}
497
498func (a *agent) createUserMessage(ctx context.Context, sessionID, content string, attachmentParts []message.ContentPart) (message.Message, error) {
499 parts := []message.ContentPart{message.TextContent{Text: content}}
500 parts = append(parts, attachmentParts...)
501 return a.messages.Create(ctx, sessionID, message.CreateMessageParams{
502 Role: message.User,
503 Parts: parts,
504 })
505}
506
507func (a *agent) getAllTools() ([]tools.BaseTool, error) {
508 allTools := slices.Collect(a.tools.Seq())
509 if a.agentToolFn != nil {
510 agentTool, agentToolErr := a.agentToolFn()
511 if agentToolErr != nil {
512 return nil, agentToolErr
513 }
514 allTools = append(allTools, agentTool)
515 }
516 return allTools, nil
517}
518
519func (a *agent) streamAndHandleEvents(ctx context.Context, sessionID string, msgHistory []message.Message) (message.Message, *message.Message, error) {
520 ctx = context.WithValue(ctx, tools.SessionIDContextKey, sessionID)
521
522 // Create the assistant message first so the spinner shows immediately
523 assistantMsg, err := a.messages.Create(ctx, sessionID, message.CreateMessageParams{
524 Role: message.Assistant,
525 Parts: []message.ContentPart{},
526 Model: a.Model().ID,
527 Provider: a.providerID,
528 })
529 if err != nil {
530 return assistantMsg, nil, fmt.Errorf("failed to create assistant message: %w", err)
531 }
532
533 allTools, toolsErr := a.getAllTools()
534 if toolsErr != nil {
535 return assistantMsg, nil, toolsErr
536 }
537 // Now collect tools (which may block on MCP initialization)
538 eventChan := a.provider.StreamResponse(ctx, msgHistory, allTools)
539
540 // Add the session and message ID into the context if needed by tools.
541 ctx = context.WithValue(ctx, tools.MessageIDContextKey, assistantMsg.ID)
542
543 // Process each event in the stream.
544 for event := range eventChan {
545 if processErr := a.processEvent(ctx, sessionID, &assistantMsg, event); processErr != nil {
546 if errors.Is(processErr, context.Canceled) {
547 a.finishMessage(context.Background(), &assistantMsg, message.FinishReasonCanceled, "Request cancelled", "")
548 } else {
549 a.finishMessage(ctx, &assistantMsg, message.FinishReasonError, "API Error", processErr.Error())
550 }
551 return assistantMsg, nil, processErr
552 }
553 if ctx.Err() != nil {
554 a.finishMessage(context.Background(), &assistantMsg, message.FinishReasonCanceled, "Request cancelled", "")
555 return assistantMsg, nil, ctx.Err()
556 }
557 }
558
559 toolResults := make([]message.ToolResult, len(assistantMsg.ToolCalls()))
560 toolCalls := assistantMsg.ToolCalls()
561 for i, toolCall := range toolCalls {
562 select {
563 case <-ctx.Done():
564 a.finishMessage(context.Background(), &assistantMsg, message.FinishReasonCanceled, "Request cancelled", "")
565 // Make all future tool calls cancelled
566 for j := i; j < len(toolCalls); j++ {
567 toolResults[j] = message.ToolResult{
568 ToolCallID: toolCalls[j].ID,
569 Content: "Tool execution canceled by user",
570 IsError: true,
571 }
572 }
573 goto out
574 default:
575 // Continue processing
576 var tool tools.BaseTool
577 allTools, _ := a.getAllTools()
578 for _, availableTool := range allTools {
579 if availableTool.Info().Name == toolCall.Name {
580 tool = availableTool
581 break
582 }
583 }
584
585 // Tool not found
586 if tool == nil {
587 toolResults[i] = message.ToolResult{
588 ToolCallID: toolCall.ID,
589 Content: fmt.Sprintf("Tool not found: %s", toolCall.Name),
590 IsError: true,
591 }
592 continue
593 }
594
595 // Run tool in goroutine to allow cancellation
596 type toolExecResult struct {
597 response tools.ToolResponse
598 err error
599 }
600 resultChan := make(chan toolExecResult, 1)
601
602 go func() {
603 response, err := tool.Run(ctx, tools.ToolCall{
604 ID: toolCall.ID,
605 Name: toolCall.Name,
606 Input: toolCall.Input,
607 })
608 resultChan <- toolExecResult{response: response, err: err}
609 }()
610
611 var toolResponse tools.ToolResponse
612 var toolErr error
613
614 select {
615 case <-ctx.Done():
616 a.finishMessage(context.Background(), &assistantMsg, message.FinishReasonCanceled, "Request cancelled", "")
617 // Mark remaining tool calls as cancelled
618 for j := i; j < len(toolCalls); j++ {
619 toolResults[j] = message.ToolResult{
620 ToolCallID: toolCalls[j].ID,
621 Content: "Tool execution canceled by user",
622 IsError: true,
623 }
624 }
625 goto out
626 case result := <-resultChan:
627 toolResponse = result.response
628 toolErr = result.err
629 }
630
631 if toolErr != nil {
632 slog.Error("Tool execution error", "toolCall", toolCall.ID, "error", toolErr)
633 if errors.Is(toolErr, permission.ErrorPermissionDenied) {
634 toolResults[i] = message.ToolResult{
635 ToolCallID: toolCall.ID,
636 Content: "Permission denied",
637 IsError: true,
638 }
639 for j := i + 1; j < len(toolCalls); j++ {
640 toolResults[j] = message.ToolResult{
641 ToolCallID: toolCalls[j].ID,
642 Content: "Tool execution canceled by user",
643 IsError: true,
644 }
645 }
646 a.finishMessage(ctx, &assistantMsg, message.FinishReasonPermissionDenied, "Permission denied", "")
647 break
648 }
649 }
650 toolResults[i] = message.ToolResult{
651 ToolCallID: toolCall.ID,
652 Content: toolResponse.Content,
653 Metadata: toolResponse.Metadata,
654 IsError: toolResponse.IsError,
655 }
656 }
657 }
658out:
659 if len(toolResults) == 0 {
660 return assistantMsg, nil, nil
661 }
662 parts := make([]message.ContentPart, 0)
663 for _, tr := range toolResults {
664 parts = append(parts, tr)
665 }
666 msg, err := a.messages.Create(context.Background(), assistantMsg.SessionID, message.CreateMessageParams{
667 Role: message.Tool,
668 Parts: parts,
669 Provider: a.providerID,
670 })
671 if err != nil {
672 return assistantMsg, nil, fmt.Errorf("failed to create cancelled tool message: %w", err)
673 }
674
675 return assistantMsg, &msg, err
676}
677
678func (a *agent) finishMessage(ctx context.Context, msg *message.Message, finishReason message.FinishReason, message, details string) {
679 msg.AddFinish(finishReason, message, details)
680 _ = a.messages.Update(ctx, *msg)
681}
682
683func (a *agent) processEvent(ctx context.Context, sessionID string, assistantMsg *message.Message, event provider.ProviderEvent) error {
684 select {
685 case <-ctx.Done():
686 return ctx.Err()
687 default:
688 // Continue processing.
689 }
690
691 switch event.Type {
692 case provider.EventThinkingDelta:
693 assistantMsg.AppendReasoningContent(event.Thinking)
694 return a.messages.Update(ctx, *assistantMsg)
695 case provider.EventSignatureDelta:
696 assistantMsg.AppendReasoningSignature(event.Signature)
697 return a.messages.Update(ctx, *assistantMsg)
698 case provider.EventContentDelta:
699 assistantMsg.FinishThinking()
700 assistantMsg.AppendContent(event.Content)
701 return a.messages.Update(ctx, *assistantMsg)
702 case provider.EventToolUseStart:
703 assistantMsg.FinishThinking()
704 slog.Info("Tool call started", "toolCall", event.ToolCall)
705 assistantMsg.AddToolCall(*event.ToolCall)
706 return a.messages.Update(ctx, *assistantMsg)
707 case provider.EventToolUseDelta:
708 assistantMsg.AppendToolCallInput(event.ToolCall.ID, event.ToolCall.Input)
709 return a.messages.Update(ctx, *assistantMsg)
710 case provider.EventToolUseStop:
711 slog.Info("Finished tool call", "toolCall", event.ToolCall)
712 assistantMsg.FinishToolCall(event.ToolCall.ID)
713 return a.messages.Update(ctx, *assistantMsg)
714 case provider.EventError:
715 return event.Error
716 case provider.EventComplete:
717 assistantMsg.FinishThinking()
718 assistantMsg.SetToolCalls(event.Response.ToolCalls)
719 assistantMsg.AddFinish(event.Response.FinishReason, "", "")
720 if err := a.messages.Update(ctx, *assistantMsg); err != nil {
721 return fmt.Errorf("failed to update message: %w", err)
722 }
723 return a.TrackUsage(ctx, sessionID, a.Model(), event.Response.Usage)
724 }
725
726 return nil
727}
728
729func (a *agent) TrackUsage(ctx context.Context, sessionID string, model catwalk.Model, usage provider.TokenUsage) error {
730 sess, err := a.sessions.Get(ctx, sessionID)
731 if err != nil {
732 return fmt.Errorf("failed to get session: %w", err)
733 }
734
735 cost := model.CostPer1MInCached/1e6*float64(usage.CacheCreationTokens) +
736 model.CostPer1MOutCached/1e6*float64(usage.CacheReadTokens) +
737 model.CostPer1MIn/1e6*float64(usage.InputTokens) +
738 model.CostPer1MOut/1e6*float64(usage.OutputTokens)
739
740 sess.Cost += cost
741 sess.CompletionTokens = usage.OutputTokens + usage.CacheReadTokens
742 sess.PromptTokens = usage.InputTokens + usage.CacheCreationTokens
743
744 _, err = a.sessions.Save(ctx, sess)
745 if err != nil {
746 return fmt.Errorf("failed to save session: %w", err)
747 }
748 return nil
749}
750
751func (a *agent) Summarize(ctx context.Context, sessionID string) error {
752 if a.summarizeProvider == nil {
753 return fmt.Errorf("summarize provider not available")
754 }
755
756 // Check if session is busy
757 if a.IsSessionBusy(sessionID) {
758 return ErrSessionBusy
759 }
760
761 // Create a new context with cancellation
762 summarizeCtx, cancel := context.WithCancel(ctx)
763
764 // Store the cancel function in activeRequests to allow cancellation
765 a.activeRequests.Set(sessionID+"-summarize", cancel)
766
767 go func() {
768 defer a.activeRequests.Del(sessionID + "-summarize")
769 defer cancel()
770 event := AgentEvent{
771 Type: AgentEventTypeSummarize,
772 Progress: "Starting summarization...",
773 }
774
775 a.Publish(pubsub.CreatedEvent, event)
776 // Get all messages from the session
777 msgs, err := a.messages.List(summarizeCtx, sessionID)
778 if err != nil {
779 event = AgentEvent{
780 Type: AgentEventTypeError,
781 Error: fmt.Errorf("failed to list messages: %w", err),
782 Done: true,
783 }
784 a.Publish(pubsub.CreatedEvent, event)
785 return
786 }
787 summarizeCtx = context.WithValue(summarizeCtx, tools.SessionIDContextKey, sessionID)
788
789 if len(msgs) == 0 {
790 event = AgentEvent{
791 Type: AgentEventTypeError,
792 Error: fmt.Errorf("no messages to summarize"),
793 Done: true,
794 }
795 a.Publish(pubsub.CreatedEvent, event)
796 return
797 }
798
799 event = AgentEvent{
800 Type: AgentEventTypeSummarize,
801 Progress: "Analyzing conversation...",
802 }
803 a.Publish(pubsub.CreatedEvent, event)
804
805 // Add a system message to guide the summarization
806 summarizePrompt := "Provide a detailed but concise summary of our conversation above. Focus on information that would be helpful for continuing the conversation, including what we did, what we're doing, which files we're working on, and what we're going to do next."
807
808 // Create a new message with the summarize prompt
809 promptMsg := message.Message{
810 Role: message.User,
811 Parts: []message.ContentPart{message.TextContent{Text: summarizePrompt}},
812 }
813
814 // Append the prompt to the messages
815 msgsWithPrompt := append(msgs, promptMsg)
816
817 event = AgentEvent{
818 Type: AgentEventTypeSummarize,
819 Progress: "Generating summary...",
820 }
821
822 a.Publish(pubsub.CreatedEvent, event)
823
824 // Send the messages to the summarize provider
825 response := a.summarizeProvider.StreamResponse(
826 summarizeCtx,
827 msgsWithPrompt,
828 nil,
829 )
830 var finalResponse *provider.ProviderResponse
831 for r := range response {
832 if r.Error != nil {
833 event = AgentEvent{
834 Type: AgentEventTypeError,
835 Error: fmt.Errorf("failed to summarize: %w", err),
836 Done: true,
837 }
838 a.Publish(pubsub.CreatedEvent, event)
839 return
840 }
841 finalResponse = r.Response
842 }
843
844 summary := strings.TrimSpace(finalResponse.Content)
845 if summary == "" {
846 event = AgentEvent{
847 Type: AgentEventTypeError,
848 Error: fmt.Errorf("empty summary returned"),
849 Done: true,
850 }
851 a.Publish(pubsub.CreatedEvent, event)
852 return
853 }
854 shell := shell.GetPersistentShell(config.Get().WorkingDir())
855 summary += "\n\n**Current working directory of the persistent shell**\n\n" + shell.GetWorkingDir()
856 event = AgentEvent{
857 Type: AgentEventTypeSummarize,
858 Progress: "Creating new session...",
859 }
860
861 a.Publish(pubsub.CreatedEvent, event)
862 oldSession, err := a.sessions.Get(summarizeCtx, sessionID)
863 if err != nil {
864 event = AgentEvent{
865 Type: AgentEventTypeError,
866 Error: fmt.Errorf("failed to get session: %w", err),
867 Done: true,
868 }
869
870 a.Publish(pubsub.CreatedEvent, event)
871 return
872 }
873 // Create a message in the new session with the summary
874 msg, err := a.messages.Create(summarizeCtx, oldSession.ID, message.CreateMessageParams{
875 Role: message.Assistant,
876 Parts: []message.ContentPart{
877 message.TextContent{Text: summary},
878 message.Finish{
879 Reason: message.FinishReasonEndTurn,
880 Time: time.Now().Unix(),
881 },
882 },
883 Model: a.summarizeProvider.Model().ID,
884 Provider: a.summarizeProviderID,
885 })
886 if err != nil {
887 event = AgentEvent{
888 Type: AgentEventTypeError,
889 Error: fmt.Errorf("failed to create summary message: %w", err),
890 Done: true,
891 }
892
893 a.Publish(pubsub.CreatedEvent, event)
894 return
895 }
896 oldSession.SummaryMessageID = msg.ID
897 oldSession.CompletionTokens = finalResponse.Usage.OutputTokens
898 oldSession.PromptTokens = 0
899 model := a.summarizeProvider.Model()
900 usage := finalResponse.Usage
901 cost := model.CostPer1MInCached/1e6*float64(usage.CacheCreationTokens) +
902 model.CostPer1MOutCached/1e6*float64(usage.CacheReadTokens) +
903 model.CostPer1MIn/1e6*float64(usage.InputTokens) +
904 model.CostPer1MOut/1e6*float64(usage.OutputTokens)
905 oldSession.Cost += cost
906 _, err = a.sessions.Save(summarizeCtx, oldSession)
907 if err != nil {
908 event = AgentEvent{
909 Type: AgentEventTypeError,
910 Error: fmt.Errorf("failed to save session: %w", err),
911 Done: true,
912 }
913 a.Publish(pubsub.CreatedEvent, event)
914 }
915
916 event = AgentEvent{
917 Type: AgentEventTypeSummarize,
918 SessionID: oldSession.ID,
919 Progress: "Summary complete",
920 Done: true,
921 }
922 a.Publish(pubsub.CreatedEvent, event)
923 // Send final success event with the new session ID
924 }()
925
926 return nil
927}
928
929func (a *agent) ClearQueue(sessionID string) {
930 if a.QueuedPrompts(sessionID) > 0 {
931 slog.Info("Clearing queued prompts", "session_id", sessionID)
932 a.promptQueue.Del(sessionID)
933 }
934}
935
936func (a *agent) CancelAll() {
937 if !a.IsBusy() {
938 return
939 }
940 for key := range a.activeRequests.Seq2() {
941 a.Cancel(key) // key is sessionID
942 }
943
944 timeout := time.After(5 * time.Second)
945 for a.IsBusy() {
946 select {
947 case <-timeout:
948 return
949 default:
950 time.Sleep(200 * time.Millisecond)
951 }
952 }
953}
954
955func (a *agent) UpdateModel() error {
956 cfg := config.Get()
957
958 // Get current provider configuration
959 currentProviderCfg := cfg.GetProviderForModel(a.agentCfg.Model)
960 if currentProviderCfg == nil || currentProviderCfg.ID == "" {
961 return fmt.Errorf("provider for agent %s not found in config", a.agentCfg.Name)
962 }
963
964 // Check if provider has changed
965 if string(currentProviderCfg.ID) != a.providerID {
966 // Provider changed, need to recreate the main provider
967 model := cfg.GetModelByType(a.agentCfg.Model)
968 if model.ID == "" {
969 return fmt.Errorf("model not found for agent %s", a.agentCfg.Name)
970 }
971
972 promptID := agentPromptMap[a.agentCfg.ID]
973 if promptID == "" {
974 promptID = prompt.PromptDefault
975 }
976
977 opts := []provider.ProviderClientOption{
978 provider.WithModel(a.agentCfg.Model),
979 provider.WithSystemMessage(prompt.GetPrompt(promptID, currentProviderCfg.ID, cfg.Options.ContextPaths...)),
980 }
981
982 newProvider, err := provider.NewProvider(*currentProviderCfg, opts...)
983 if err != nil {
984 return fmt.Errorf("failed to create new provider: %w", err)
985 }
986
987 // Update the provider and provider ID
988 a.provider = newProvider
989 a.providerID = string(currentProviderCfg.ID)
990 }
991
992 // Check if providers have changed for title (small) and summarize (large)
993 smallModelCfg := cfg.Models[config.SelectedModelTypeSmall]
994 var smallModelProviderCfg config.ProviderConfig
995 for p := range cfg.Providers.Seq() {
996 if p.ID == smallModelCfg.Provider {
997 smallModelProviderCfg = p
998 break
999 }
1000 }
1001 if smallModelProviderCfg.ID == "" {
1002 return fmt.Errorf("provider %s not found in config", smallModelCfg.Provider)
1003 }
1004
1005 largeModelCfg := cfg.Models[config.SelectedModelTypeLarge]
1006 var largeModelProviderCfg config.ProviderConfig
1007 for p := range cfg.Providers.Seq() {
1008 if p.ID == largeModelCfg.Provider {
1009 largeModelProviderCfg = p
1010 break
1011 }
1012 }
1013 if largeModelProviderCfg.ID == "" {
1014 return fmt.Errorf("provider %s not found in config", largeModelCfg.Provider)
1015 }
1016
1017 var maxTitleTokens int64 = 40
1018
1019 // if the max output is too low for the gemini provider it won't return anything
1020 if smallModelCfg.Provider == "gemini" {
1021 maxTitleTokens = 1000
1022 }
1023 // Recreate title provider
1024 titleOpts := []provider.ProviderClientOption{
1025 provider.WithModel(config.SelectedModelTypeSmall),
1026 provider.WithSystemMessage(prompt.GetPrompt(prompt.PromptTitle, smallModelProviderCfg.ID)),
1027 provider.WithMaxTokens(maxTitleTokens),
1028 }
1029 newTitleProvider, err := provider.NewProvider(smallModelProviderCfg, titleOpts...)
1030 if err != nil {
1031 return fmt.Errorf("failed to create new title provider: %w", err)
1032 }
1033 a.titleProvider = newTitleProvider
1034
1035 // Recreate summarize provider if provider changed (now large model)
1036 if string(largeModelProviderCfg.ID) != a.summarizeProviderID {
1037 largeModel := cfg.GetModelByType(config.SelectedModelTypeLarge)
1038 if largeModel == nil {
1039 return fmt.Errorf("model %s not found in provider %s", largeModelCfg.Model, largeModelProviderCfg.ID)
1040 }
1041 summarizeOpts := []provider.ProviderClientOption{
1042 provider.WithModel(config.SelectedModelTypeLarge),
1043 provider.WithSystemMessage(prompt.GetPrompt(prompt.PromptSummarizer, largeModelProviderCfg.ID)),
1044 }
1045 newSummarizeProvider, err := provider.NewProvider(largeModelProviderCfg, summarizeOpts...)
1046 if err != nil {
1047 return fmt.Errorf("failed to create new summarize provider: %w", err)
1048 }
1049 a.summarizeProvider = newSummarizeProvider
1050 a.summarizeProviderID = string(largeModelProviderCfg.ID)
1051 }
1052
1053 return nil
1054}