1package agent
2
3import (
4 "context"
5 "errors"
6 "fmt"
7 "log/slog"
8 "slices"
9 "strings"
10 "time"
11
12 "github.com/charmbracelet/catwalk/pkg/catwalk"
13 "github.com/charmbracelet/crush/internal/config"
14 "github.com/charmbracelet/crush/internal/csync"
15 "github.com/charmbracelet/crush/internal/history"
16 "github.com/charmbracelet/crush/internal/llm/prompt"
17 "github.com/charmbracelet/crush/internal/llm/provider"
18 "github.com/charmbracelet/crush/internal/llm/tools"
19 "github.com/charmbracelet/crush/internal/log"
20 "github.com/charmbracelet/crush/internal/lsp"
21 "github.com/charmbracelet/crush/internal/message"
22 "github.com/charmbracelet/crush/internal/permission"
23 "github.com/charmbracelet/crush/internal/pubsub"
24 "github.com/charmbracelet/crush/internal/session"
25 "github.com/charmbracelet/crush/internal/shell"
26)
27
// ErrRequestCancelled is reported when a request is aborted before it
// produced a final response.
var ErrRequestCancelled = errors.New("request canceled by user")

// ErrSessionBusy is returned when a session is already processing another
// request.
var ErrSessionBusy = errors.New("session is currently processing another request")
33
// AgentEventType identifies the kind of an AgentEvent.
type AgentEventType string

// Event kinds emitted by the agent.
const (
	// AgentEventTypeError reports a failure; AgentEvent.Error holds the cause.
	AgentEventTypeError = AgentEventType("error")
	// AgentEventTypeResponse carries the final assistant message of a run.
	AgentEventTypeResponse = AgentEventType("response")
	// AgentEventTypeSummarize reports summarization progress.
	AgentEventTypeSummarize = AgentEventType("summarize")
)
41
42type AgentEvent struct {
43 Type AgentEventType
44 Message message.Message
45 Error error
46
47 // When summarizing
48 SessionID string
49 Progress string
50 Done bool
51}
52
53type Service interface {
54 pubsub.Suscriber[AgentEvent]
55 Model() catwalk.Model
56 Run(ctx context.Context, sessionID string, content string, attachments ...message.Attachment) (<-chan AgentEvent, error)
57 Cancel(sessionID string)
58 CancelAll()
59 IsSessionBusy(sessionID string) bool
60 IsBusy() bool
61 Summarize(ctx context.Context, sessionID string) error
62 UpdateModel() error
63 QueuedPrompts(sessionID string) int
64 ClearQueue(sessionID string)
65}
66
67type agent struct {
68 *pubsub.Broker[AgentEvent]
69 agentCfg config.Agent
70 sessions session.Service
71 messages message.Service
72 mcpTools []McpTool
73
74 tools *csync.LazySlice[tools.BaseTool]
75 // We need this to be able to update it when model changes
76 agentToolFn func() (tools.BaseTool, error)
77
78 provider provider.Provider
79 providerID string
80
81 titleProvider provider.Provider
82 summarizeProvider provider.Provider
83 summarizeProviderID string
84
85 activeRequests *csync.Map[string, context.CancelFunc]
86
87 promptQueue *csync.Map[string, []string]
88}
89
90var agentPromptMap = map[string]prompt.PromptID{
91 "coder": prompt.PromptCoder,
92 "task": prompt.PromptTask,
93}
94
95func NewAgent(
96 ctx context.Context,
97 agentCfg config.Agent,
98 // These services are needed in the tools
99 permissions permission.Service,
100 sessions session.Service,
101 messages message.Service,
102 history history.Service,
103 lspClients map[string]*lsp.Client,
104) (Service, error) {
105 cfg := config.Get()
106
107 var agentToolFn func() (tools.BaseTool, error)
108 if agentCfg.ID == "coder" {
109 agentToolFn = func() (tools.BaseTool, error) {
110 taskAgentCfg := config.Get().Agents["task"]
111 if taskAgentCfg.ID == "" {
112 return nil, fmt.Errorf("task agent not found in config")
113 }
114 taskAgent, err := NewAgent(ctx, taskAgentCfg, permissions, sessions, messages, history, lspClients)
115 if err != nil {
116 return nil, fmt.Errorf("failed to create task agent: %w", err)
117 }
118 return NewAgentTool(taskAgent, sessions, messages), nil
119 }
120 }
121
122 providerCfg := config.Get().GetProviderForModel(agentCfg.Model)
123 if providerCfg == nil {
124 return nil, fmt.Errorf("provider for agent %s not found in config", agentCfg.Name)
125 }
126 model := config.Get().GetModelByType(agentCfg.Model)
127
128 if model == nil {
129 return nil, fmt.Errorf("model not found for agent %s", agentCfg.Name)
130 }
131
132 promptID := agentPromptMap[agentCfg.ID]
133 if promptID == "" {
134 promptID = prompt.PromptDefault
135 }
136 opts := []provider.ProviderClientOption{
137 provider.WithModel(agentCfg.Model),
138 provider.WithSystemMessage(prompt.GetPrompt(promptID, providerCfg.ID, config.Get().Options.ContextPaths...)),
139 }
140 agentProvider, err := provider.NewProvider(*providerCfg, opts...)
141 if err != nil {
142 return nil, err
143 }
144
145 smallModelCfg := cfg.Models[config.SelectedModelTypeSmall]
146 var smallModelProviderCfg *config.ProviderConfig
147 if smallModelCfg.Provider == providerCfg.ID {
148 smallModelProviderCfg = providerCfg
149 } else {
150 smallModelProviderCfg = cfg.GetProviderForModel(config.SelectedModelTypeSmall)
151
152 if smallModelProviderCfg.ID == "" {
153 return nil, fmt.Errorf("provider %s not found in config", smallModelCfg.Provider)
154 }
155 }
156 smallModel := cfg.GetModelByType(config.SelectedModelTypeSmall)
157 if smallModel.ID == "" {
158 return nil, fmt.Errorf("model %s not found in provider %s", smallModelCfg.Model, smallModelProviderCfg.ID)
159 }
160
161 titleOpts := []provider.ProviderClientOption{
162 provider.WithModel(config.SelectedModelTypeSmall),
163 provider.WithSystemMessage(prompt.GetPrompt(prompt.PromptTitle, smallModelProviderCfg.ID)),
164 }
165 titleProvider, err := provider.NewProvider(*smallModelProviderCfg, titleOpts...)
166 if err != nil {
167 return nil, err
168 }
169
170 summarizeOpts := []provider.ProviderClientOption{
171 provider.WithModel(config.SelectedModelTypeLarge),
172 provider.WithSystemMessage(prompt.GetPrompt(prompt.PromptSummarizer, providerCfg.ID)),
173 }
174 summarizeProvider, err := provider.NewProvider(*providerCfg, summarizeOpts...)
175 if err != nil {
176 return nil, err
177 }
178
179 toolFn := func() []tools.BaseTool {
180 slog.Info("Initializing agent tools", "agent", agentCfg.ID)
181 defer func() {
182 slog.Info("Initialized agent tools", "agent", agentCfg.ID)
183 }()
184
185 cwd := cfg.WorkingDir()
186 allTools := []tools.BaseTool{
187 tools.NewBashTool(permissions, cwd),
188 tools.NewDownloadTool(permissions, cwd),
189 tools.NewEditTool(lspClients, permissions, history, cwd),
190 tools.NewMultiEditTool(lspClients, permissions, history, cwd),
191 tools.NewFetchTool(permissions, cwd),
192 tools.NewGlobTool(cwd),
193 tools.NewGrepTool(cwd),
194 tools.NewLsTool(permissions, cwd),
195 tools.NewSourcegraphTool(),
196 tools.NewViewTool(lspClients, permissions, cwd),
197 tools.NewWriteTool(lspClients, permissions, history, cwd),
198 }
199
200 if agentCfg.AllowedTools == nil {
201 return allTools
202 }
203
204 var filteredTools []tools.BaseTool
205 for _, tool := range allTools {
206 if slices.Contains(agentCfg.AllowedTools, tool.Name()) {
207 filteredTools = append(filteredTools, tool)
208 }
209 }
210
211 if agentCfg.ID == "coder" {
212 mcpToolsOnce.Do(func() {
213 mcpTools = doGetMCPTools(ctx, permissions, cfg)
214 })
215 filteredTools = append(filteredTools, mcpTools...)
216 if len(lspClients) > 0 {
217 filteredTools = append(filteredTools, tools.NewDiagnosticsTool(lspClients))
218 }
219
220 }
221 return filteredTools
222 }
223
224 return &agent{
225 Broker: pubsub.NewBroker[AgentEvent](),
226 agentCfg: agentCfg,
227 provider: agentProvider,
228 providerID: string(providerCfg.ID),
229 messages: messages,
230 sessions: sessions,
231 titleProvider: titleProvider,
232 summarizeProvider: summarizeProvider,
233 summarizeProviderID: string(providerCfg.ID),
234 agentToolFn: agentToolFn,
235 activeRequests: csync.NewMap[string, context.CancelFunc](),
236 tools: csync.NewLazySlice(toolFn),
237 promptQueue: csync.NewMap[string, []string](),
238 }, nil
239}
240
241func (a *agent) Model() catwalk.Model {
242 return *config.Get().GetModelByType(a.agentCfg.Model)
243}
244
245func (a *agent) Cancel(sessionID string) {
246 // Cancel regular requests
247 if cancel, ok := a.activeRequests.Take(sessionID); ok && cancel != nil {
248 slog.Info("Request cancellation initiated", "session_id", sessionID)
249 cancel()
250 }
251
252 // Also check for summarize requests
253 if cancel, ok := a.activeRequests.Take(sessionID + "-summarize"); ok && cancel != nil {
254 slog.Info("Summarize cancellation initiated", "session_id", sessionID)
255 cancel()
256 }
257
258 if a.QueuedPrompts(sessionID) > 0 {
259 slog.Info("Clearing queued prompts", "session_id", sessionID)
260 a.promptQueue.Del(sessionID)
261 }
262}
263
264func (a *agent) IsBusy() bool {
265 var busy bool
266 for cancelFunc := range a.activeRequests.Seq() {
267 if cancelFunc != nil {
268 busy = true
269 break
270 }
271 }
272 return busy
273}
274
275func (a *agent) IsSessionBusy(sessionID string) bool {
276 _, busy := a.activeRequests.Get(sessionID)
277 return busy
278}
279
280func (a *agent) QueuedPrompts(sessionID string) int {
281 l, ok := a.promptQueue.Get(sessionID)
282 if !ok {
283 return 0
284 }
285 return len(l)
286}
287
288func (a *agent) generateTitle(ctx context.Context, sessionID string, content string) error {
289 if content == "" {
290 return nil
291 }
292 if a.titleProvider == nil {
293 return nil
294 }
295 session, err := a.sessions.Get(ctx, sessionID)
296 if err != nil {
297 return err
298 }
299 parts := []message.ContentPart{message.TextContent{
300 Text: fmt.Sprintf("Generate a concise title for the following content:\n\n%s", content),
301 }}
302
303 // Use streaming approach like summarization
304 response := a.titleProvider.StreamResponse(
305 ctx,
306 []message.Message{
307 {
308 Role: message.User,
309 Parts: parts,
310 },
311 },
312 nil,
313 )
314
315 var finalResponse *provider.ProviderResponse
316 for r := range response {
317 if r.Error != nil {
318 return r.Error
319 }
320 finalResponse = r.Response
321 }
322
323 if finalResponse == nil {
324 return fmt.Errorf("no response received from title provider")
325 }
326
327 title := strings.TrimSpace(strings.ReplaceAll(finalResponse.Content, "\n", " "))
328 if title == "" {
329 return nil
330 }
331
332 session.Title = title
333 _, err = a.sessions.Save(ctx, session)
334 return err
335}
336
337func (a *agent) err(err error) AgentEvent {
338 return AgentEvent{
339 Type: AgentEventTypeError,
340 Error: err,
341 }
342}
343
344func (a *agent) Run(ctx context.Context, sessionID string, content string, attachments ...message.Attachment) (<-chan AgentEvent, error) {
345 if !a.Model().SupportsImages && attachments != nil {
346 attachments = nil
347 }
348 events := make(chan AgentEvent, 1)
349 if a.IsSessionBusy(sessionID) {
350 existing, ok := a.promptQueue.Get(sessionID)
351 if !ok {
352 existing = []string{}
353 }
354 existing = append(existing, content)
355 a.promptQueue.Set(sessionID, existing)
356 return nil, nil
357 }
358
359 genCtx, cancel := context.WithCancel(ctx)
360
361 a.activeRequests.Set(sessionID, cancel)
362 go func() {
363 slog.Debug("Request started", "sessionID", sessionID)
364 defer log.RecoverPanic("agent.Run", func() {
365 events <- a.err(fmt.Errorf("panic while running the agent"))
366 })
367 var attachmentParts []message.ContentPart
368 for _, attachment := range attachments {
369 attachmentParts = append(attachmentParts, message.BinaryContent{Path: attachment.FilePath, MIMEType: attachment.MimeType, Data: attachment.Content})
370 }
371 result := a.processGeneration(genCtx, sessionID, content, attachmentParts)
372 if result.Error != nil && !errors.Is(result.Error, ErrRequestCancelled) && !errors.Is(result.Error, context.Canceled) {
373 slog.Error(result.Error.Error())
374 }
375 slog.Debug("Request completed", "sessionID", sessionID)
376 a.activeRequests.Del(sessionID)
377 cancel()
378 a.Publish(pubsub.CreatedEvent, result)
379 events <- result
380 close(events)
381 }()
382 return events, nil
383}
384
385func (a *agent) processGeneration(ctx context.Context, sessionID, content string, attachmentParts []message.ContentPart) AgentEvent {
386 cfg := config.Get()
387 // List existing messages; if none, start title generation asynchronously.
388 msgs, err := a.messages.List(ctx, sessionID)
389 if err != nil {
390 return a.err(fmt.Errorf("failed to list messages: %w", err))
391 }
392 if len(msgs) == 0 {
393 go func() {
394 defer log.RecoverPanic("agent.Run", func() {
395 slog.Error("panic while generating title")
396 })
397 titleErr := a.generateTitle(ctx, sessionID, content)
398 if titleErr != nil && !errors.Is(titleErr, context.Canceled) && !errors.Is(titleErr, context.DeadlineExceeded) {
399 slog.Error("failed to generate title", "error", titleErr)
400 }
401 }()
402 }
403 session, err := a.sessions.Get(ctx, sessionID)
404 if err != nil {
405 return a.err(fmt.Errorf("failed to get session: %w", err))
406 }
407 if session.SummaryMessageID != "" {
408 summaryMsgInex := -1
409 for i, msg := range msgs {
410 if msg.ID == session.SummaryMessageID {
411 summaryMsgInex = i
412 break
413 }
414 }
415 if summaryMsgInex != -1 {
416 msgs = msgs[summaryMsgInex:]
417 msgs[0].Role = message.User
418 }
419 }
420
421 userMsg, err := a.createUserMessage(ctx, sessionID, content, attachmentParts)
422 if err != nil {
423 return a.err(fmt.Errorf("failed to create user message: %w", err))
424 }
425 // Append the new user message to the conversation history.
426 msgHistory := append(msgs, userMsg)
427
428 for {
429 // Check for cancellation before each iteration
430 select {
431 case <-ctx.Done():
432 return a.err(ctx.Err())
433 default:
434 // Continue processing
435 }
436 agentMessage, toolResults, err := a.streamAndHandleEvents(ctx, sessionID, msgHistory)
437 if err != nil {
438 if errors.Is(err, context.Canceled) {
439 agentMessage.AddFinish(message.FinishReasonCanceled, "Request cancelled", "")
440 a.messages.Update(context.Background(), agentMessage)
441 return a.err(ErrRequestCancelled)
442 }
443 return a.err(fmt.Errorf("failed to process events: %w", err))
444 }
445 if cfg.Options.Debug {
446 slog.Info("Result", "message", agentMessage.FinishReason(), "toolResults", toolResults)
447 }
448 if (agentMessage.FinishReason() == message.FinishReasonToolUse) && toolResults != nil {
449 // We are not done, we need to respond with the tool response
450 msgHistory = append(msgHistory, agentMessage, *toolResults)
451 // If there are queued prompts, process the next one
452 nextPrompt, ok := a.promptQueue.Take(sessionID)
453 if ok {
454 for _, prompt := range nextPrompt {
455 // Create a new user message for the queued prompt
456 userMsg, err := a.createUserMessage(ctx, sessionID, prompt, nil)
457 if err != nil {
458 return a.err(fmt.Errorf("failed to create user message for queued prompt: %w", err))
459 }
460 // Append the new user message to the conversation history
461 msgHistory = append(msgHistory, userMsg)
462 }
463 }
464
465 continue
466 } else if agentMessage.FinishReason() == message.FinishReasonEndTurn {
467 queuePrompts, ok := a.promptQueue.Take(sessionID)
468 if ok {
469 for _, prompt := range queuePrompts {
470 if prompt == "" {
471 continue
472 }
473 userMsg, err := a.createUserMessage(ctx, sessionID, prompt, nil)
474 if err != nil {
475 return a.err(fmt.Errorf("failed to create user message for queued prompt: %w", err))
476 }
477 msgHistory = append(msgHistory, userMsg)
478 }
479 continue
480 }
481 }
482 if agentMessage.FinishReason() == "" {
483 // Kujtim: could not track down where this is happening but this means its cancelled
484 agentMessage.AddFinish(message.FinishReasonCanceled, "Request cancelled", "")
485 _ = a.messages.Update(context.Background(), agentMessage)
486 return a.err(ErrRequestCancelled)
487 }
488 return AgentEvent{
489 Type: AgentEventTypeResponse,
490 Message: agentMessage,
491 Done: true,
492 }
493 }
494}
495
496func (a *agent) createUserMessage(ctx context.Context, sessionID, content string, attachmentParts []message.ContentPart) (message.Message, error) {
497 parts := []message.ContentPart{message.TextContent{Text: content}}
498 parts = append(parts, attachmentParts...)
499 return a.messages.Create(ctx, sessionID, message.CreateMessageParams{
500 Role: message.User,
501 Parts: parts,
502 })
503}
504
505func (a *agent) getAllTools() ([]tools.BaseTool, error) {
506 allTools := slices.Collect(a.tools.Seq())
507 if a.agentToolFn != nil {
508 agentTool, agentToolErr := a.agentToolFn()
509 if agentToolErr != nil {
510 return nil, agentToolErr
511 }
512 allTools = append(allTools, agentTool)
513 }
514 return allTools, nil
515}
516
517func (a *agent) streamAndHandleEvents(ctx context.Context, sessionID string, msgHistory []message.Message) (message.Message, *message.Message, error) {
518 ctx = context.WithValue(ctx, tools.SessionIDContextKey, sessionID)
519
520 // Create the assistant message first so the spinner shows immediately
521 assistantMsg, err := a.messages.Create(ctx, sessionID, message.CreateMessageParams{
522 Role: message.Assistant,
523 Parts: []message.ContentPart{},
524 Model: a.Model().ID,
525 Provider: a.providerID,
526 })
527 if err != nil {
528 return assistantMsg, nil, fmt.Errorf("failed to create assistant message: %w", err)
529 }
530
531 allTools, toolsErr := a.getAllTools()
532 if toolsErr != nil {
533 return assistantMsg, nil, toolsErr
534 }
535 // Now collect tools (which may block on MCP initialization)
536 eventChan := a.provider.StreamResponse(ctx, msgHistory, allTools)
537
538 // Add the session and message ID into the context if needed by tools.
539 ctx = context.WithValue(ctx, tools.MessageIDContextKey, assistantMsg.ID)
540
541 // Process each event in the stream.
542 for event := range eventChan {
543 if processErr := a.processEvent(ctx, sessionID, &assistantMsg, event); processErr != nil {
544 if errors.Is(processErr, context.Canceled) {
545 a.finishMessage(context.Background(), &assistantMsg, message.FinishReasonCanceled, "Request cancelled", "")
546 } else {
547 a.finishMessage(ctx, &assistantMsg, message.FinishReasonError, "API Error", processErr.Error())
548 }
549 return assistantMsg, nil, processErr
550 }
551 if ctx.Err() != nil {
552 a.finishMessage(context.Background(), &assistantMsg, message.FinishReasonCanceled, "Request cancelled", "")
553 return assistantMsg, nil, ctx.Err()
554 }
555 }
556
557 toolResults := make([]message.ToolResult, len(assistantMsg.ToolCalls()))
558 toolCalls := assistantMsg.ToolCalls()
559 for i, toolCall := range toolCalls {
560 select {
561 case <-ctx.Done():
562 a.finishMessage(context.Background(), &assistantMsg, message.FinishReasonCanceled, "Request cancelled", "")
563 // Make all future tool calls cancelled
564 for j := i; j < len(toolCalls); j++ {
565 toolResults[j] = message.ToolResult{
566 ToolCallID: toolCalls[j].ID,
567 Content: "Tool execution canceled by user",
568 IsError: true,
569 }
570 }
571 goto out
572 default:
573 // Continue processing
574 var tool tools.BaseTool
575 allTools, _ := a.getAllTools()
576 for _, availableTool := range allTools {
577 if availableTool.Info().Name == toolCall.Name {
578 tool = availableTool
579 break
580 }
581 }
582
583 // Tool not found
584 if tool == nil {
585 toolResults[i] = message.ToolResult{
586 ToolCallID: toolCall.ID,
587 Content: fmt.Sprintf("Tool not found: %s", toolCall.Name),
588 IsError: true,
589 }
590 continue
591 }
592
593 // Run tool in goroutine to allow cancellation
594 type toolExecResult struct {
595 response tools.ToolResponse
596 err error
597 }
598 resultChan := make(chan toolExecResult, 1)
599
600 go func() {
601 response, err := tool.Run(ctx, tools.ToolCall{
602 ID: toolCall.ID,
603 Name: toolCall.Name,
604 Input: toolCall.Input,
605 })
606 resultChan <- toolExecResult{response: response, err: err}
607 }()
608
609 var toolResponse tools.ToolResponse
610 var toolErr error
611
612 select {
613 case <-ctx.Done():
614 a.finishMessage(context.Background(), &assistantMsg, message.FinishReasonCanceled, "Request cancelled", "")
615 // Mark remaining tool calls as cancelled
616 for j := i; j < len(toolCalls); j++ {
617 toolResults[j] = message.ToolResult{
618 ToolCallID: toolCalls[j].ID,
619 Content: "Tool execution canceled by user",
620 IsError: true,
621 }
622 }
623 goto out
624 case result := <-resultChan:
625 toolResponse = result.response
626 toolErr = result.err
627 }
628
629 if toolErr != nil {
630 slog.Error("Tool execution error", "toolCall", toolCall.ID, "error", toolErr)
631 if errors.Is(toolErr, permission.ErrorPermissionDenied) {
632 toolResults[i] = message.ToolResult{
633 ToolCallID: toolCall.ID,
634 Content: "Permission denied",
635 IsError: true,
636 }
637 for j := i + 1; j < len(toolCalls); j++ {
638 toolResults[j] = message.ToolResult{
639 ToolCallID: toolCalls[j].ID,
640 Content: "Tool execution canceled by user",
641 IsError: true,
642 }
643 }
644 a.finishMessage(ctx, &assistantMsg, message.FinishReasonPermissionDenied, "Permission denied", "")
645 break
646 }
647 }
648 toolResults[i] = message.ToolResult{
649 ToolCallID: toolCall.ID,
650 Content: toolResponse.Content,
651 Metadata: toolResponse.Metadata,
652 IsError: toolResponse.IsError,
653 }
654 }
655 }
656out:
657 if len(toolResults) == 0 {
658 return assistantMsg, nil, nil
659 }
660 parts := make([]message.ContentPart, 0)
661 for _, tr := range toolResults {
662 parts = append(parts, tr)
663 }
664 msg, err := a.messages.Create(context.Background(), assistantMsg.SessionID, message.CreateMessageParams{
665 Role: message.Tool,
666 Parts: parts,
667 Provider: a.providerID,
668 })
669 if err != nil {
670 return assistantMsg, nil, fmt.Errorf("failed to create cancelled tool message: %w", err)
671 }
672
673 return assistantMsg, &msg, err
674}
675
676func (a *agent) finishMessage(ctx context.Context, msg *message.Message, finishReason message.FinishReason, message, details string) {
677 msg.AddFinish(finishReason, message, details)
678 _ = a.messages.Update(ctx, *msg)
679}
680
681func (a *agent) processEvent(ctx context.Context, sessionID string, assistantMsg *message.Message, event provider.ProviderEvent) error {
682 select {
683 case <-ctx.Done():
684 return ctx.Err()
685 default:
686 // Continue processing.
687 }
688
689 switch event.Type {
690 case provider.EventThinkingDelta:
691 assistantMsg.AppendReasoningContent(event.Thinking)
692 return a.messages.Update(ctx, *assistantMsg)
693 case provider.EventSignatureDelta:
694 assistantMsg.AppendReasoningSignature(event.Signature)
695 return a.messages.Update(ctx, *assistantMsg)
696 case provider.EventContentDelta:
697 assistantMsg.FinishThinking()
698 assistantMsg.AppendContent(event.Content)
699 return a.messages.Update(ctx, *assistantMsg)
700 case provider.EventToolUseStart:
701 assistantMsg.FinishThinking()
702 slog.Info("Tool call started", "toolCall", event.ToolCall)
703 assistantMsg.AddToolCall(*event.ToolCall)
704 return a.messages.Update(ctx, *assistantMsg)
705 case provider.EventToolUseDelta:
706 assistantMsg.AppendToolCallInput(event.ToolCall.ID, event.ToolCall.Input)
707 return a.messages.Update(ctx, *assistantMsg)
708 case provider.EventToolUseStop:
709 slog.Info("Finished tool call", "toolCall", event.ToolCall)
710 assistantMsg.FinishToolCall(event.ToolCall.ID)
711 return a.messages.Update(ctx, *assistantMsg)
712 case provider.EventError:
713 return event.Error
714 case provider.EventComplete:
715 assistantMsg.FinishThinking()
716 assistantMsg.SetToolCalls(event.Response.ToolCalls)
717 assistantMsg.AddFinish(event.Response.FinishReason, "", "")
718 if err := a.messages.Update(ctx, *assistantMsg); err != nil {
719 return fmt.Errorf("failed to update message: %w", err)
720 }
721 return a.TrackUsage(ctx, sessionID, a.Model(), event.Response.Usage)
722 }
723
724 return nil
725}
726
727func (a *agent) TrackUsage(ctx context.Context, sessionID string, model catwalk.Model, usage provider.TokenUsage) error {
728 sess, err := a.sessions.Get(ctx, sessionID)
729 if err != nil {
730 return fmt.Errorf("failed to get session: %w", err)
731 }
732
733 cost := model.CostPer1MInCached/1e6*float64(usage.CacheCreationTokens) +
734 model.CostPer1MOutCached/1e6*float64(usage.CacheReadTokens) +
735 model.CostPer1MIn/1e6*float64(usage.InputTokens) +
736 model.CostPer1MOut/1e6*float64(usage.OutputTokens)
737
738 sess.Cost += cost
739 sess.CompletionTokens = usage.OutputTokens + usage.CacheReadTokens
740 sess.PromptTokens = usage.InputTokens + usage.CacheCreationTokens
741
742 _, err = a.sessions.Save(ctx, sess)
743 if err != nil {
744 return fmt.Errorf("failed to save session: %w", err)
745 }
746 return nil
747}
748
749func (a *agent) Summarize(ctx context.Context, sessionID string) error {
750 if a.summarizeProvider == nil {
751 return fmt.Errorf("summarize provider not available")
752 }
753
754 // Check if session is busy
755 if a.IsSessionBusy(sessionID) {
756 return ErrSessionBusy
757 }
758
759 // Create a new context with cancellation
760 summarizeCtx, cancel := context.WithCancel(ctx)
761
762 // Store the cancel function in activeRequests to allow cancellation
763 a.activeRequests.Set(sessionID+"-summarize", cancel)
764
765 go func() {
766 defer a.activeRequests.Del(sessionID + "-summarize")
767 defer cancel()
768 event := AgentEvent{
769 Type: AgentEventTypeSummarize,
770 Progress: "Starting summarization...",
771 }
772
773 a.Publish(pubsub.CreatedEvent, event)
774 // Get all messages from the session
775 msgs, err := a.messages.List(summarizeCtx, sessionID)
776 if err != nil {
777 event = AgentEvent{
778 Type: AgentEventTypeError,
779 Error: fmt.Errorf("failed to list messages: %w", err),
780 Done: true,
781 }
782 a.Publish(pubsub.CreatedEvent, event)
783 return
784 }
785 summarizeCtx = context.WithValue(summarizeCtx, tools.SessionIDContextKey, sessionID)
786
787 if len(msgs) == 0 {
788 event = AgentEvent{
789 Type: AgentEventTypeError,
790 Error: fmt.Errorf("no messages to summarize"),
791 Done: true,
792 }
793 a.Publish(pubsub.CreatedEvent, event)
794 return
795 }
796
797 event = AgentEvent{
798 Type: AgentEventTypeSummarize,
799 Progress: "Analyzing conversation...",
800 }
801 a.Publish(pubsub.CreatedEvent, event)
802
803 // Add a system message to guide the summarization
804 summarizePrompt := "Provide a detailed but concise summary of our conversation above. Focus on information that would be helpful for continuing the conversation, including what we did, what we're doing, which files we're working on, and what we're going to do next."
805
806 // Create a new message with the summarize prompt
807 promptMsg := message.Message{
808 Role: message.User,
809 Parts: []message.ContentPart{message.TextContent{Text: summarizePrompt}},
810 }
811
812 // Append the prompt to the messages
813 msgsWithPrompt := append(msgs, promptMsg)
814
815 event = AgentEvent{
816 Type: AgentEventTypeSummarize,
817 Progress: "Generating summary...",
818 }
819
820 a.Publish(pubsub.CreatedEvent, event)
821
822 // Send the messages to the summarize provider
823 response := a.summarizeProvider.StreamResponse(
824 summarizeCtx,
825 msgsWithPrompt,
826 nil,
827 )
828 var finalResponse *provider.ProviderResponse
829 for r := range response {
830 if r.Error != nil {
831 event = AgentEvent{
832 Type: AgentEventTypeError,
833 Error: fmt.Errorf("failed to summarize: %w", err),
834 Done: true,
835 }
836 a.Publish(pubsub.CreatedEvent, event)
837 return
838 }
839 finalResponse = r.Response
840 }
841
842 summary := strings.TrimSpace(finalResponse.Content)
843 if summary == "" {
844 event = AgentEvent{
845 Type: AgentEventTypeError,
846 Error: fmt.Errorf("empty summary returned"),
847 Done: true,
848 }
849 a.Publish(pubsub.CreatedEvent, event)
850 return
851 }
852 shell := shell.GetPersistentShell(config.Get().WorkingDir())
853 summary += "\n\n**Current working directory of the persistent shell**\n\n" + shell.GetWorkingDir()
854 event = AgentEvent{
855 Type: AgentEventTypeSummarize,
856 Progress: "Creating new session...",
857 }
858
859 a.Publish(pubsub.CreatedEvent, event)
860 oldSession, err := a.sessions.Get(summarizeCtx, sessionID)
861 if err != nil {
862 event = AgentEvent{
863 Type: AgentEventTypeError,
864 Error: fmt.Errorf("failed to get session: %w", err),
865 Done: true,
866 }
867
868 a.Publish(pubsub.CreatedEvent, event)
869 return
870 }
871 // Create a message in the new session with the summary
872 msg, err := a.messages.Create(summarizeCtx, oldSession.ID, message.CreateMessageParams{
873 Role: message.Assistant,
874 Parts: []message.ContentPart{
875 message.TextContent{Text: summary},
876 message.Finish{
877 Reason: message.FinishReasonEndTurn,
878 Time: time.Now().Unix(),
879 },
880 },
881 Model: a.summarizeProvider.Model().ID,
882 Provider: a.summarizeProviderID,
883 })
884 if err != nil {
885 event = AgentEvent{
886 Type: AgentEventTypeError,
887 Error: fmt.Errorf("failed to create summary message: %w", err),
888 Done: true,
889 }
890
891 a.Publish(pubsub.CreatedEvent, event)
892 return
893 }
894 oldSession.SummaryMessageID = msg.ID
895 oldSession.CompletionTokens = finalResponse.Usage.OutputTokens
896 oldSession.PromptTokens = 0
897 model := a.summarizeProvider.Model()
898 usage := finalResponse.Usage
899 cost := model.CostPer1MInCached/1e6*float64(usage.CacheCreationTokens) +
900 model.CostPer1MOutCached/1e6*float64(usage.CacheReadTokens) +
901 model.CostPer1MIn/1e6*float64(usage.InputTokens) +
902 model.CostPer1MOut/1e6*float64(usage.OutputTokens)
903 oldSession.Cost += cost
904 _, err = a.sessions.Save(summarizeCtx, oldSession)
905 if err != nil {
906 event = AgentEvent{
907 Type: AgentEventTypeError,
908 Error: fmt.Errorf("failed to save session: %w", err),
909 Done: true,
910 }
911 a.Publish(pubsub.CreatedEvent, event)
912 }
913
914 event = AgentEvent{
915 Type: AgentEventTypeSummarize,
916 SessionID: oldSession.ID,
917 Progress: "Summary complete",
918 Done: true,
919 }
920 a.Publish(pubsub.CreatedEvent, event)
921 // Send final success event with the new session ID
922 }()
923
924 return nil
925}
926
927func (a *agent) ClearQueue(sessionID string) {
928 if a.QueuedPrompts(sessionID) > 0 {
929 slog.Info("Clearing queued prompts", "session_id", sessionID)
930 a.promptQueue.Del(sessionID)
931 }
932}
933
934func (a *agent) CancelAll() {
935 if !a.IsBusy() {
936 return
937 }
938 for key := range a.activeRequests.Seq2() {
939 a.Cancel(key) // key is sessionID
940 }
941
942 timeout := time.After(5 * time.Second)
943 for a.IsBusy() {
944 select {
945 case <-timeout:
946 return
947 default:
948 time.Sleep(200 * time.Millisecond)
949 }
950 }
951}
952
953func (a *agent) UpdateModel() error {
954 cfg := config.Get()
955
956 // Get current provider configuration
957 currentProviderCfg := cfg.GetProviderForModel(a.agentCfg.Model)
958 if currentProviderCfg == nil || currentProviderCfg.ID == "" {
959 return fmt.Errorf("provider for agent %s not found in config", a.agentCfg.Name)
960 }
961
962 // Check if provider has changed
963 if string(currentProviderCfg.ID) != a.providerID {
964 // Provider changed, need to recreate the main provider
965 model := cfg.GetModelByType(a.agentCfg.Model)
966 if model.ID == "" {
967 return fmt.Errorf("model not found for agent %s", a.agentCfg.Name)
968 }
969
970 promptID := agentPromptMap[a.agentCfg.ID]
971 if promptID == "" {
972 promptID = prompt.PromptDefault
973 }
974
975 opts := []provider.ProviderClientOption{
976 provider.WithModel(a.agentCfg.Model),
977 provider.WithSystemMessage(prompt.GetPrompt(promptID, currentProviderCfg.ID, cfg.Options.ContextPaths...)),
978 }
979
980 newProvider, err := provider.NewProvider(*currentProviderCfg, opts...)
981 if err != nil {
982 return fmt.Errorf("failed to create new provider: %w", err)
983 }
984
985 // Update the provider and provider ID
986 a.provider = newProvider
987 a.providerID = string(currentProviderCfg.ID)
988 }
989
990 // Check if providers have changed for title (small) and summarize (large)
991 smallModelCfg := cfg.Models[config.SelectedModelTypeSmall]
992 var smallModelProviderCfg config.ProviderConfig
993 for p := range cfg.Providers.Seq() {
994 if p.ID == smallModelCfg.Provider {
995 smallModelProviderCfg = p
996 break
997 }
998 }
999 if smallModelProviderCfg.ID == "" {
1000 return fmt.Errorf("provider %s not found in config", smallModelCfg.Provider)
1001 }
1002
1003 largeModelCfg := cfg.Models[config.SelectedModelTypeLarge]
1004 var largeModelProviderCfg config.ProviderConfig
1005 for p := range cfg.Providers.Seq() {
1006 if p.ID == largeModelCfg.Provider {
1007 largeModelProviderCfg = p
1008 break
1009 }
1010 }
1011 if largeModelProviderCfg.ID == "" {
1012 return fmt.Errorf("provider %s not found in config", largeModelCfg.Provider)
1013 }
1014
1015 var maxTitleTokens int64 = 40
1016
1017 // if the max output is too low for the gemini provider it won't return anything
1018 if smallModelCfg.Provider == "gemini" {
1019 maxTitleTokens = 1000
1020 }
1021 // Recreate title provider
1022 titleOpts := []provider.ProviderClientOption{
1023 provider.WithModel(config.SelectedModelTypeSmall),
1024 provider.WithSystemMessage(prompt.GetPrompt(prompt.PromptTitle, smallModelProviderCfg.ID)),
1025 provider.WithMaxTokens(maxTitleTokens),
1026 }
1027 newTitleProvider, err := provider.NewProvider(smallModelProviderCfg, titleOpts...)
1028 if err != nil {
1029 return fmt.Errorf("failed to create new title provider: %w", err)
1030 }
1031 a.titleProvider = newTitleProvider
1032
1033 // Recreate summarize provider if provider changed (now large model)
1034 if string(largeModelProviderCfg.ID) != a.summarizeProviderID {
1035 largeModel := cfg.GetModelByType(config.SelectedModelTypeLarge)
1036 if largeModel == nil {
1037 return fmt.Errorf("model %s not found in provider %s", largeModelCfg.Model, largeModelProviderCfg.ID)
1038 }
1039 summarizeOpts := []provider.ProviderClientOption{
1040 provider.WithModel(config.SelectedModelTypeLarge),
1041 provider.WithSystemMessage(prompt.GetPrompt(prompt.PromptSummarizer, largeModelProviderCfg.ID)),
1042 }
1043 newSummarizeProvider, err := provider.NewProvider(largeModelProviderCfg, summarizeOpts...)
1044 if err != nil {
1045 return fmt.Errorf("failed to create new summarize provider: %w", err)
1046 }
1047 a.summarizeProvider = newSummarizeProvider
1048 a.summarizeProviderID = string(largeModelProviderCfg.ID)
1049 }
1050
1051 return nil
1052}