1package agent
2
3import (
4 "context"
5 "errors"
6 "fmt"
7 "log/slog"
8 "slices"
9 "strings"
10 "time"
11
12 "github.com/charmbracelet/catwalk/pkg/catwalk"
13 "github.com/charmbracelet/crush/internal/config"
14 "github.com/charmbracelet/crush/internal/csync"
15 "github.com/charmbracelet/crush/internal/history"
16 "github.com/charmbracelet/crush/internal/llm/prompt"
17 "github.com/charmbracelet/crush/internal/llm/provider"
18 "github.com/charmbracelet/crush/internal/llm/tools"
19 "github.com/charmbracelet/crush/internal/log"
20 "github.com/charmbracelet/crush/internal/lsp"
21 "github.com/charmbracelet/crush/internal/message"
22 "github.com/charmbracelet/crush/internal/permission"
23 "github.com/charmbracelet/crush/internal/pubsub"
24 "github.com/charmbracelet/crush/internal/session"
25 "github.com/charmbracelet/crush/internal/shell"
26)
27
// ErrRequestCancelled is reported (inside an error AgentEvent) when a running
// request is cancelled.
var ErrRequestCancelled = errors.New("request canceled by user")

// ErrSessionBusy is returned by Summarize when the session already has an
// active request.
var ErrSessionBusy = errors.New("session is currently processing another request")
33
// AgentEventType identifies which kind of AgentEvent is being delivered.
type AgentEventType string

// Kinds of events emitted by the agent.
const (
	// AgentEventTypeError carries a failure in AgentEvent.Error.
	AgentEventTypeError AgentEventType = "error"
	// AgentEventTypeResponse carries a finished assistant message.
	AgentEventTypeResponse AgentEventType = "response"
	// AgentEventTypeSummarize reports summarization progress.
	AgentEventTypeSummarize AgentEventType = "summarize"
)
41
// AgentEvent reports the outcome or progress of an agent operation. Events are
// published to subscribers and, for Run, also sent on the returned channel.
type AgentEvent struct {
	// Type selects which of the fields below are meaningful.
	Type AgentEventType
	// Message is the assistant message of an AgentEventTypeResponse event.
	Message message.Message
	// Error is set on AgentEventTypeError events.
	Error error

	// When summarizing
	// SessionID is set on the final successful summarize event.
	SessionID string
	// Progress is a human-readable status line for summarize events.
	Progress string
	// Done is set on the final event of a response or summarization (including
	// summarization failures).
	Done bool
}
52
// Service is the agent API: it runs prompts against the configured model,
// tracks and cancels in-flight requests per session, queues prompts sent while
// a session is busy, and summarizes sessions. Events are delivered through the
// embedded subscriber.
type Service interface {
	pubsub.Suscriber[AgentEvent]
	// Model returns the model currently configured for this agent.
	Model() catwalk.Model
	// Run starts processing content for sessionID. In this implementation, if
	// the session is busy the prompt text is queued and (nil, nil) is returned.
	Run(ctx context.Context, sessionID string, content string, attachments ...message.Attachment) (<-chan AgentEvent, error)
	// Cancel stops the request and any summarization for sessionID and drops
	// its queued prompts.
	Cancel(sessionID string)
	// CancelAll cancels every active request and waits briefly for them to stop.
	CancelAll()
	// IsSessionBusy reports whether a request is active for sessionID.
	IsSessionBusy(sessionID string) bool
	// IsBusy reports whether any request is active.
	IsBusy() bool
	// Summarize starts summarizing sessionID in the background; progress is
	// reported through published events.
	Summarize(ctx context.Context, sessionID string) error
	// UpdateModel rebuilds providers after the model configuration changed.
	UpdateModel() error
	// QueuedPrompts returns the number of prompts waiting for sessionID.
	QueuedPrompts(sessionID string) int
	// ClearQueue drops all prompts waiting for sessionID.
	ClearQueue(sessionID string)
}
66
// agent is the Service implementation returned by NewAgent.
type agent struct {
	// Broker publishes AgentEvents to subscribers.
	*pubsub.Broker[AgentEvent]
	agentCfg config.Agent
	sessions session.Service
	messages message.Service
	// NOTE(review): not populated by NewAgent in this file (its tools use the
	// package-level mcpTools); confirm whether this field is still needed.
	mcpTools []McpTool

	// tools is built lazily, on first use, by the tool factory set in NewAgent.
	tools *csync.LazySlice[tools.BaseTool]
	// We need this to be able to update it when model changes
	// (builds the sub-agent tool; nil for agents other than "coder").
	agentToolFn func() (tools.BaseTool, error)

	// provider serves the agent's own model; providerID is recorded on created
	// messages and compared in UpdateModel to detect provider changes.
	provider   provider.Provider
	providerID string

	// titleProvider (small model) generates session titles; summarizeProvider
	// (large model) writes session summaries.
	titleProvider       provider.Provider
	summarizeProvider   provider.Provider
	summarizeProviderID string

	// activeRequests maps a session ID (or "<sessionID>-summarize" for a
	// running summarization) to the cancel func of its in-flight work.
	activeRequests *csync.Map[string, context.CancelFunc]

	// promptQueue holds prompt texts submitted while a session was busy.
	promptQueue *csync.Map[string, []string]
}
89
90var agentPromptMap = map[string]prompt.PromptID{
91 "coder": prompt.PromptCoder,
92 "task": prompt.PromptTask,
93}
94
95func NewAgent(
96 ctx context.Context,
97 agentCfg config.Agent,
98 // These services are needed in the tools
99 permissions permission.Service,
100 sessions session.Service,
101 messages message.Service,
102 history history.Service,
103 lspClients map[string]*lsp.Client,
104) (Service, error) {
105 cfg := config.Get()
106
107 var agentToolFn func() (tools.BaseTool, error)
108 if agentCfg.ID == "coder" {
109 agentToolFn = func() (tools.BaseTool, error) {
110 taskAgentCfg := config.Get().Agents["task"]
111 if taskAgentCfg.ID == "" {
112 return nil, fmt.Errorf("task agent not found in config")
113 }
114 taskAgent, err := NewAgent(ctx, taskAgentCfg, permissions, sessions, messages, history, lspClients)
115 if err != nil {
116 return nil, fmt.Errorf("failed to create task agent: %w", err)
117 }
118 return NewAgentTool(taskAgent, sessions, messages), nil
119 }
120 }
121
122 providerCfg := config.Get().GetProviderForModel(agentCfg.Model)
123 if providerCfg == nil {
124 return nil, fmt.Errorf("provider for agent %s not found in config", agentCfg.Name)
125 }
126 model := config.Get().GetModelByType(agentCfg.Model)
127
128 if model == nil {
129 return nil, fmt.Errorf("model not found for agent %s", agentCfg.Name)
130 }
131
132 promptID := agentPromptMap[agentCfg.ID]
133 if promptID == "" {
134 promptID = prompt.PromptDefault
135 }
136 opts := []provider.ProviderClientOption{
137 provider.WithModel(agentCfg.Model),
138 provider.WithSystemMessage(prompt.GetPrompt(promptID, providerCfg.ID, config.Get().Options.ContextPaths...)),
139 }
140 agentProvider, err := provider.NewProvider(*providerCfg, opts...)
141 if err != nil {
142 return nil, err
143 }
144
145 smallModelCfg := cfg.Models[config.SelectedModelTypeSmall]
146 var smallModelProviderCfg *config.ProviderConfig
147 if smallModelCfg.Provider == providerCfg.ID {
148 smallModelProviderCfg = providerCfg
149 } else {
150 smallModelProviderCfg = cfg.GetProviderForModel(config.SelectedModelTypeSmall)
151
152 if smallModelProviderCfg.ID == "" {
153 return nil, fmt.Errorf("provider %s not found in config", smallModelCfg.Provider)
154 }
155 }
156 smallModel := cfg.GetModelByType(config.SelectedModelTypeSmall)
157 if smallModel.ID == "" {
158 return nil, fmt.Errorf("model %s not found in provider %s", smallModelCfg.Model, smallModelProviderCfg.ID)
159 }
160
161 titleOpts := []provider.ProviderClientOption{
162 provider.WithModel(config.SelectedModelTypeSmall),
163 provider.WithSystemMessage(prompt.GetPrompt(prompt.PromptTitle, smallModelProviderCfg.ID)),
164 }
165 titleProvider, err := provider.NewProvider(*smallModelProviderCfg, titleOpts...)
166 if err != nil {
167 return nil, err
168 }
169
170 summarizeOpts := []provider.ProviderClientOption{
171 provider.WithModel(config.SelectedModelTypeLarge),
172 provider.WithSystemMessage(prompt.GetPrompt(prompt.PromptSummarizer, providerCfg.ID)),
173 }
174 summarizeProvider, err := provider.NewProvider(*providerCfg, summarizeOpts...)
175 if err != nil {
176 return nil, err
177 }
178
179 toolFn := func() []tools.BaseTool {
180 slog.Info("Initializing agent tools", "agent", agentCfg.ID)
181 defer func() {
182 slog.Info("Initialized agent tools", "agent", agentCfg.ID)
183 }()
184
185 cwd := cfg.WorkingDir()
186 allTools := []tools.BaseTool{
187 tools.NewBashTool(permissions, cwd, cfg.Options.Attribution),
188 tools.NewDownloadTool(permissions, cwd),
189 tools.NewEditTool(lspClients, permissions, history, cwd),
190 tools.NewMultiEditTool(lspClients, permissions, history, cwd),
191 tools.NewFetchTool(permissions, cwd),
192 tools.NewGlobTool(cwd),
193 tools.NewGrepTool(cwd),
194 tools.NewLsTool(permissions, cwd),
195 tools.NewSourcegraphTool(),
196 tools.NewViewTool(lspClients, permissions, cwd),
197 tools.NewWriteTool(lspClients, permissions, history, cwd),
198 }
199
200 mcpToolsOnce.Do(func() {
201 mcpTools = doGetMCPTools(ctx, permissions, cfg)
202 })
203
204 withCoderTools := func(t []tools.BaseTool) []tools.BaseTool {
205 if agentCfg.ID == "coder" {
206 t = append(t, mcpTools...)
207 if len(lspClients) > 0 {
208 t = append(t, tools.NewDiagnosticsTool(lspClients))
209 }
210 }
211 return t
212 }
213
214 if agentCfg.AllowedTools == nil {
215 return withCoderTools(allTools)
216 }
217
218 var filteredTools []tools.BaseTool
219 for _, tool := range allTools {
220 if slices.Contains(agentCfg.AllowedTools, tool.Name()) {
221 filteredTools = append(filteredTools, tool)
222 }
223 }
224 return withCoderTools(filteredTools)
225 }
226
227 return &agent{
228 Broker: pubsub.NewBroker[AgentEvent](),
229 agentCfg: agentCfg,
230 provider: agentProvider,
231 providerID: string(providerCfg.ID),
232 messages: messages,
233 sessions: sessions,
234 titleProvider: titleProvider,
235 summarizeProvider: summarizeProvider,
236 summarizeProviderID: string(providerCfg.ID),
237 agentToolFn: agentToolFn,
238 activeRequests: csync.NewMap[string, context.CancelFunc](),
239 tools: csync.NewLazySlice(toolFn),
240 promptQueue: csync.NewMap[string, []string](),
241 }, nil
242}
243
244func (a *agent) Model() catwalk.Model {
245 return *config.Get().GetModelByType(a.agentCfg.Model)
246}
247
248func (a *agent) Cancel(sessionID string) {
249 // Cancel regular requests
250 if cancel, ok := a.activeRequests.Take(sessionID); ok && cancel != nil {
251 slog.Info("Request cancellation initiated", "session_id", sessionID)
252 cancel()
253 }
254
255 // Also check for summarize requests
256 if cancel, ok := a.activeRequests.Take(sessionID + "-summarize"); ok && cancel != nil {
257 slog.Info("Summarize cancellation initiated", "session_id", sessionID)
258 cancel()
259 }
260
261 if a.QueuedPrompts(sessionID) > 0 {
262 slog.Info("Clearing queued prompts", "session_id", sessionID)
263 a.promptQueue.Del(sessionID)
264 }
265}
266
267func (a *agent) IsBusy() bool {
268 var busy bool
269 for cancelFunc := range a.activeRequests.Seq() {
270 if cancelFunc != nil {
271 busy = true
272 break
273 }
274 }
275 return busy
276}
277
278func (a *agent) IsSessionBusy(sessionID string) bool {
279 _, busy := a.activeRequests.Get(sessionID)
280 return busy
281}
282
283func (a *agent) QueuedPrompts(sessionID string) int {
284 l, ok := a.promptQueue.Get(sessionID)
285 if !ok {
286 return 0
287 }
288 return len(l)
289}
290
291func (a *agent) generateTitle(ctx context.Context, sessionID string, content string) error {
292 if content == "" {
293 return nil
294 }
295 if a.titleProvider == nil {
296 return nil
297 }
298 session, err := a.sessions.Get(ctx, sessionID)
299 if err != nil {
300 return err
301 }
302 parts := []message.ContentPart{message.TextContent{
303 Text: fmt.Sprintf("Generate a concise title for the following content:\n\n%s", content),
304 }}
305
306 // Use streaming approach like summarization
307 response := a.titleProvider.StreamResponse(
308 ctx,
309 []message.Message{
310 {
311 Role: message.User,
312 Parts: parts,
313 },
314 },
315 nil,
316 )
317
318 var finalResponse *provider.ProviderResponse
319 for r := range response {
320 if r.Error != nil {
321 return r.Error
322 }
323 finalResponse = r.Response
324 }
325
326 if finalResponse == nil {
327 return fmt.Errorf("no response received from title provider")
328 }
329
330 title := strings.TrimSpace(strings.ReplaceAll(finalResponse.Content, "\n", " "))
331 if title == "" {
332 return nil
333 }
334
335 session.Title = title
336 _, err = a.sessions.Save(ctx, session)
337 return err
338}
339
340func (a *agent) err(err error) AgentEvent {
341 return AgentEvent{
342 Type: AgentEventTypeError,
343 Error: err,
344 }
345}
346
347func (a *agent) Run(ctx context.Context, sessionID string, content string, attachments ...message.Attachment) (<-chan AgentEvent, error) {
348 if !a.Model().SupportsImages && attachments != nil {
349 attachments = nil
350 }
351 events := make(chan AgentEvent, 1)
352 if a.IsSessionBusy(sessionID) {
353 existing, ok := a.promptQueue.Get(sessionID)
354 if !ok {
355 existing = []string{}
356 }
357 existing = append(existing, content)
358 a.promptQueue.Set(sessionID, existing)
359 return nil, nil
360 }
361
362 genCtx, cancel := context.WithCancel(ctx)
363
364 a.activeRequests.Set(sessionID, cancel)
365 go func() {
366 slog.Debug("Request started", "sessionID", sessionID)
367 defer log.RecoverPanic("agent.Run", func() {
368 events <- a.err(fmt.Errorf("panic while running the agent"))
369 })
370 var attachmentParts []message.ContentPart
371 for _, attachment := range attachments {
372 attachmentParts = append(attachmentParts, message.BinaryContent{Path: attachment.FilePath, MIMEType: attachment.MimeType, Data: attachment.Content})
373 }
374 result := a.processGeneration(genCtx, sessionID, content, attachmentParts)
375 if result.Error != nil && !errors.Is(result.Error, ErrRequestCancelled) && !errors.Is(result.Error, context.Canceled) {
376 slog.Error(result.Error.Error())
377 }
378 slog.Debug("Request completed", "sessionID", sessionID)
379 a.activeRequests.Del(sessionID)
380 cancel()
381 a.Publish(pubsub.CreatedEvent, result)
382 events <- result
383 close(events)
384 }()
385 return events, nil
386}
387
388func (a *agent) processGeneration(ctx context.Context, sessionID, content string, attachmentParts []message.ContentPart) AgentEvent {
389 cfg := config.Get()
390 // List existing messages; if none, start title generation asynchronously.
391 msgs, err := a.messages.List(ctx, sessionID)
392 if err != nil {
393 return a.err(fmt.Errorf("failed to list messages: %w", err))
394 }
395 if len(msgs) == 0 {
396 go func() {
397 defer log.RecoverPanic("agent.Run", func() {
398 slog.Error("panic while generating title")
399 })
400 titleErr := a.generateTitle(ctx, sessionID, content)
401 if titleErr != nil && !errors.Is(titleErr, context.Canceled) && !errors.Is(titleErr, context.DeadlineExceeded) {
402 slog.Error("failed to generate title", "error", titleErr)
403 }
404 }()
405 }
406 session, err := a.sessions.Get(ctx, sessionID)
407 if err != nil {
408 return a.err(fmt.Errorf("failed to get session: %w", err))
409 }
410 if session.SummaryMessageID != "" {
411 summaryMsgInex := -1
412 for i, msg := range msgs {
413 if msg.ID == session.SummaryMessageID {
414 summaryMsgInex = i
415 break
416 }
417 }
418 if summaryMsgInex != -1 {
419 msgs = msgs[summaryMsgInex:]
420 msgs[0].Role = message.User
421 }
422 }
423
424 userMsg, err := a.createUserMessage(ctx, sessionID, content, attachmentParts)
425 if err != nil {
426 return a.err(fmt.Errorf("failed to create user message: %w", err))
427 }
428 // Append the new user message to the conversation history.
429 msgHistory := append(msgs, userMsg)
430
431 for {
432 // Check for cancellation before each iteration
433 select {
434 case <-ctx.Done():
435 return a.err(ctx.Err())
436 default:
437 // Continue processing
438 }
439 agentMessage, toolResults, err := a.streamAndHandleEvents(ctx, sessionID, msgHistory)
440 if err != nil {
441 if errors.Is(err, context.Canceled) {
442 agentMessage.AddFinish(message.FinishReasonCanceled, "Request cancelled", "")
443 a.messages.Update(context.Background(), agentMessage)
444 return a.err(ErrRequestCancelled)
445 }
446 return a.err(fmt.Errorf("failed to process events: %w", err))
447 }
448 if cfg.Options.Debug {
449 slog.Info("Result", "message", agentMessage.FinishReason(), "toolResults", toolResults)
450 }
451 if (agentMessage.FinishReason() == message.FinishReasonToolUse) && toolResults != nil {
452 // We are not done, we need to respond with the tool response
453 msgHistory = append(msgHistory, agentMessage, *toolResults)
454 // If there are queued prompts, process the next one
455 nextPrompt, ok := a.promptQueue.Take(sessionID)
456 if ok {
457 for _, prompt := range nextPrompt {
458 // Create a new user message for the queued prompt
459 userMsg, err := a.createUserMessage(ctx, sessionID, prompt, nil)
460 if err != nil {
461 return a.err(fmt.Errorf("failed to create user message for queued prompt: %w", err))
462 }
463 // Append the new user message to the conversation history
464 msgHistory = append(msgHistory, userMsg)
465 }
466 }
467
468 continue
469 } else if agentMessage.FinishReason() == message.FinishReasonEndTurn {
470 queuePrompts, ok := a.promptQueue.Take(sessionID)
471 if ok {
472 for _, prompt := range queuePrompts {
473 if prompt == "" {
474 continue
475 }
476 userMsg, err := a.createUserMessage(ctx, sessionID, prompt, nil)
477 if err != nil {
478 return a.err(fmt.Errorf("failed to create user message for queued prompt: %w", err))
479 }
480 msgHistory = append(msgHistory, userMsg)
481 }
482 continue
483 }
484 }
485 if agentMessage.FinishReason() == "" {
486 // Kujtim: could not track down where this is happening but this means its cancelled
487 agentMessage.AddFinish(message.FinishReasonCanceled, "Request cancelled", "")
488 _ = a.messages.Update(context.Background(), agentMessage)
489 return a.err(ErrRequestCancelled)
490 }
491 return AgentEvent{
492 Type: AgentEventTypeResponse,
493 Message: agentMessage,
494 Done: true,
495 }
496 }
497}
498
499func (a *agent) createUserMessage(ctx context.Context, sessionID, content string, attachmentParts []message.ContentPart) (message.Message, error) {
500 parts := []message.ContentPart{message.TextContent{Text: content}}
501 parts = append(parts, attachmentParts...)
502 return a.messages.Create(ctx, sessionID, message.CreateMessageParams{
503 Role: message.User,
504 Parts: parts,
505 })
506}
507
508func (a *agent) getAllTools() ([]tools.BaseTool, error) {
509 allTools := slices.Collect(a.tools.Seq())
510 if a.agentToolFn != nil {
511 agentTool, agentToolErr := a.agentToolFn()
512 if agentToolErr != nil {
513 return nil, agentToolErr
514 }
515 allTools = append(allTools, agentTool)
516 }
517 return allTools, nil
518}
519
520func (a *agent) streamAndHandleEvents(ctx context.Context, sessionID string, msgHistory []message.Message) (message.Message, *message.Message, error) {
521 ctx = context.WithValue(ctx, tools.SessionIDContextKey, sessionID)
522
523 // Create the assistant message first so the spinner shows immediately
524 assistantMsg, err := a.messages.Create(ctx, sessionID, message.CreateMessageParams{
525 Role: message.Assistant,
526 Parts: []message.ContentPart{},
527 Model: a.Model().ID,
528 Provider: a.providerID,
529 })
530 if err != nil {
531 return assistantMsg, nil, fmt.Errorf("failed to create assistant message: %w", err)
532 }
533
534 allTools, toolsErr := a.getAllTools()
535 if toolsErr != nil {
536 return assistantMsg, nil, toolsErr
537 }
538 // Now collect tools (which may block on MCP initialization)
539 eventChan := a.provider.StreamResponse(ctx, msgHistory, allTools)
540
541 // Add the session and message ID into the context if needed by tools.
542 ctx = context.WithValue(ctx, tools.MessageIDContextKey, assistantMsg.ID)
543
544 // Process each event in the stream.
545 for event := range eventChan {
546 if processErr := a.processEvent(ctx, sessionID, &assistantMsg, event); processErr != nil {
547 if errors.Is(processErr, context.Canceled) {
548 a.finishMessage(context.Background(), &assistantMsg, message.FinishReasonCanceled, "Request cancelled", "")
549 } else {
550 a.finishMessage(ctx, &assistantMsg, message.FinishReasonError, "API Error", processErr.Error())
551 }
552 return assistantMsg, nil, processErr
553 }
554 if ctx.Err() != nil {
555 a.finishMessage(context.Background(), &assistantMsg, message.FinishReasonCanceled, "Request cancelled", "")
556 return assistantMsg, nil, ctx.Err()
557 }
558 }
559
560 toolResults := make([]message.ToolResult, len(assistantMsg.ToolCalls()))
561 toolCalls := assistantMsg.ToolCalls()
562 for i, toolCall := range toolCalls {
563 select {
564 case <-ctx.Done():
565 a.finishMessage(context.Background(), &assistantMsg, message.FinishReasonCanceled, "Request cancelled", "")
566 // Make all future tool calls cancelled
567 for j := i; j < len(toolCalls); j++ {
568 toolResults[j] = message.ToolResult{
569 ToolCallID: toolCalls[j].ID,
570 Content: "Tool execution canceled by user",
571 IsError: true,
572 }
573 }
574 goto out
575 default:
576 // Continue processing
577 var tool tools.BaseTool
578 allTools, _ := a.getAllTools()
579 for _, availableTool := range allTools {
580 if availableTool.Info().Name == toolCall.Name {
581 tool = availableTool
582 break
583 }
584 }
585
586 // Tool not found
587 if tool == nil {
588 toolResults[i] = message.ToolResult{
589 ToolCallID: toolCall.ID,
590 Content: fmt.Sprintf("Tool not found: %s", toolCall.Name),
591 IsError: true,
592 }
593 continue
594 }
595
596 // Run tool in goroutine to allow cancellation
597 type toolExecResult struct {
598 response tools.ToolResponse
599 err error
600 }
601 resultChan := make(chan toolExecResult, 1)
602
603 go func() {
604 response, err := tool.Run(ctx, tools.ToolCall{
605 ID: toolCall.ID,
606 Name: toolCall.Name,
607 Input: toolCall.Input,
608 })
609 resultChan <- toolExecResult{response: response, err: err}
610 }()
611
612 var toolResponse tools.ToolResponse
613 var toolErr error
614
615 select {
616 case <-ctx.Done():
617 a.finishMessage(context.Background(), &assistantMsg, message.FinishReasonCanceled, "Request cancelled", "")
618 // Mark remaining tool calls as cancelled
619 for j := i; j < len(toolCalls); j++ {
620 toolResults[j] = message.ToolResult{
621 ToolCallID: toolCalls[j].ID,
622 Content: "Tool execution canceled by user",
623 IsError: true,
624 }
625 }
626 goto out
627 case result := <-resultChan:
628 toolResponse = result.response
629 toolErr = result.err
630 }
631
632 if toolErr != nil {
633 slog.Error("Tool execution error", "toolCall", toolCall.ID, "error", toolErr)
634 if errors.Is(toolErr, permission.ErrorPermissionDenied) {
635 toolResults[i] = message.ToolResult{
636 ToolCallID: toolCall.ID,
637 Content: "Permission denied",
638 IsError: true,
639 }
640 for j := i + 1; j < len(toolCalls); j++ {
641 toolResults[j] = message.ToolResult{
642 ToolCallID: toolCalls[j].ID,
643 Content: "Tool execution canceled by user",
644 IsError: true,
645 }
646 }
647 a.finishMessage(ctx, &assistantMsg, message.FinishReasonPermissionDenied, "Permission denied", "")
648 break
649 }
650 }
651 toolResults[i] = message.ToolResult{
652 ToolCallID: toolCall.ID,
653 Content: toolResponse.Content,
654 Metadata: toolResponse.Metadata,
655 IsError: toolResponse.IsError,
656 }
657 }
658 }
659out:
660 if len(toolResults) == 0 {
661 return assistantMsg, nil, nil
662 }
663 parts := make([]message.ContentPart, 0)
664 for _, tr := range toolResults {
665 parts = append(parts, tr)
666 }
667 msg, err := a.messages.Create(context.Background(), assistantMsg.SessionID, message.CreateMessageParams{
668 Role: message.Tool,
669 Parts: parts,
670 Provider: a.providerID,
671 })
672 if err != nil {
673 return assistantMsg, nil, fmt.Errorf("failed to create cancelled tool message: %w", err)
674 }
675
676 return assistantMsg, &msg, err
677}
678
679func (a *agent) finishMessage(ctx context.Context, msg *message.Message, finishReason message.FinishReason, message, details string) {
680 msg.AddFinish(finishReason, message, details)
681 _ = a.messages.Update(ctx, *msg)
682}
683
684func (a *agent) processEvent(ctx context.Context, sessionID string, assistantMsg *message.Message, event provider.ProviderEvent) error {
685 select {
686 case <-ctx.Done():
687 return ctx.Err()
688 default:
689 // Continue processing.
690 }
691
692 switch event.Type {
693 case provider.EventThinkingDelta:
694 assistantMsg.AppendReasoningContent(event.Thinking)
695 return a.messages.Update(ctx, *assistantMsg)
696 case provider.EventSignatureDelta:
697 assistantMsg.AppendReasoningSignature(event.Signature)
698 return a.messages.Update(ctx, *assistantMsg)
699 case provider.EventContentDelta:
700 assistantMsg.FinishThinking()
701 assistantMsg.AppendContent(event.Content)
702 return a.messages.Update(ctx, *assistantMsg)
703 case provider.EventToolUseStart:
704 assistantMsg.FinishThinking()
705 slog.Info("Tool call started", "toolCall", event.ToolCall)
706 assistantMsg.AddToolCall(*event.ToolCall)
707 return a.messages.Update(ctx, *assistantMsg)
708 case provider.EventToolUseDelta:
709 assistantMsg.AppendToolCallInput(event.ToolCall.ID, event.ToolCall.Input)
710 return a.messages.Update(ctx, *assistantMsg)
711 case provider.EventToolUseStop:
712 slog.Info("Finished tool call", "toolCall", event.ToolCall)
713 assistantMsg.FinishToolCall(event.ToolCall.ID)
714 return a.messages.Update(ctx, *assistantMsg)
715 case provider.EventError:
716 return event.Error
717 case provider.EventComplete:
718 assistantMsg.FinishThinking()
719 assistantMsg.SetToolCalls(event.Response.ToolCalls)
720 assistantMsg.AddFinish(event.Response.FinishReason, "", "")
721 if err := a.messages.Update(ctx, *assistantMsg); err != nil {
722 return fmt.Errorf("failed to update message: %w", err)
723 }
724 return a.TrackUsage(ctx, sessionID, a.Model(), event.Response.Usage)
725 }
726
727 return nil
728}
729
730func (a *agent) TrackUsage(ctx context.Context, sessionID string, model catwalk.Model, usage provider.TokenUsage) error {
731 sess, err := a.sessions.Get(ctx, sessionID)
732 if err != nil {
733 return fmt.Errorf("failed to get session: %w", err)
734 }
735
736 cost := model.CostPer1MInCached/1e6*float64(usage.CacheCreationTokens) +
737 model.CostPer1MOutCached/1e6*float64(usage.CacheReadTokens) +
738 model.CostPer1MIn/1e6*float64(usage.InputTokens) +
739 model.CostPer1MOut/1e6*float64(usage.OutputTokens)
740
741 sess.Cost += cost
742 sess.CompletionTokens = usage.OutputTokens + usage.CacheReadTokens
743 sess.PromptTokens = usage.InputTokens + usage.CacheCreationTokens
744
745 _, err = a.sessions.Save(ctx, sess)
746 if err != nil {
747 return fmt.Errorf("failed to save session: %w", err)
748 }
749 return nil
750}
751
752func (a *agent) Summarize(ctx context.Context, sessionID string) error {
753 if a.summarizeProvider == nil {
754 return fmt.Errorf("summarize provider not available")
755 }
756
757 // Check if session is busy
758 if a.IsSessionBusy(sessionID) {
759 return ErrSessionBusy
760 }
761
762 // Create a new context with cancellation
763 summarizeCtx, cancel := context.WithCancel(ctx)
764
765 // Store the cancel function in activeRequests to allow cancellation
766 a.activeRequests.Set(sessionID+"-summarize", cancel)
767
768 go func() {
769 defer a.activeRequests.Del(sessionID + "-summarize")
770 defer cancel()
771 event := AgentEvent{
772 Type: AgentEventTypeSummarize,
773 Progress: "Starting summarization...",
774 }
775
776 a.Publish(pubsub.CreatedEvent, event)
777 // Get all messages from the session
778 msgs, err := a.messages.List(summarizeCtx, sessionID)
779 if err != nil {
780 event = AgentEvent{
781 Type: AgentEventTypeError,
782 Error: fmt.Errorf("failed to list messages: %w", err),
783 Done: true,
784 }
785 a.Publish(pubsub.CreatedEvent, event)
786 return
787 }
788 summarizeCtx = context.WithValue(summarizeCtx, tools.SessionIDContextKey, sessionID)
789
790 if len(msgs) == 0 {
791 event = AgentEvent{
792 Type: AgentEventTypeError,
793 Error: fmt.Errorf("no messages to summarize"),
794 Done: true,
795 }
796 a.Publish(pubsub.CreatedEvent, event)
797 return
798 }
799
800 event = AgentEvent{
801 Type: AgentEventTypeSummarize,
802 Progress: "Analyzing conversation...",
803 }
804 a.Publish(pubsub.CreatedEvent, event)
805
806 // Add a system message to guide the summarization
807 summarizePrompt := "Provide a detailed but concise summary of our conversation above. Focus on information that would be helpful for continuing the conversation, including what we did, what we're doing, which files we're working on, and what we're going to do next."
808
809 // Create a new message with the summarize prompt
810 promptMsg := message.Message{
811 Role: message.User,
812 Parts: []message.ContentPart{message.TextContent{Text: summarizePrompt}},
813 }
814
815 // Append the prompt to the messages
816 msgsWithPrompt := append(msgs, promptMsg)
817
818 event = AgentEvent{
819 Type: AgentEventTypeSummarize,
820 Progress: "Generating summary...",
821 }
822
823 a.Publish(pubsub.CreatedEvent, event)
824
825 // Send the messages to the summarize provider
826 response := a.summarizeProvider.StreamResponse(
827 summarizeCtx,
828 msgsWithPrompt,
829 nil,
830 )
831 var finalResponse *provider.ProviderResponse
832 for r := range response {
833 if r.Error != nil {
834 event = AgentEvent{
835 Type: AgentEventTypeError,
836 Error: fmt.Errorf("failed to summarize: %w", err),
837 Done: true,
838 }
839 a.Publish(pubsub.CreatedEvent, event)
840 return
841 }
842 finalResponse = r.Response
843 }
844
845 summary := strings.TrimSpace(finalResponse.Content)
846 if summary == "" {
847 event = AgentEvent{
848 Type: AgentEventTypeError,
849 Error: fmt.Errorf("empty summary returned"),
850 Done: true,
851 }
852 a.Publish(pubsub.CreatedEvent, event)
853 return
854 }
855 shell := shell.GetPersistentShell(config.Get().WorkingDir())
856 summary += "\n\n**Current working directory of the persistent shell**\n\n" + shell.GetWorkingDir()
857 event = AgentEvent{
858 Type: AgentEventTypeSummarize,
859 Progress: "Creating new session...",
860 }
861
862 a.Publish(pubsub.CreatedEvent, event)
863 oldSession, err := a.sessions.Get(summarizeCtx, sessionID)
864 if err != nil {
865 event = AgentEvent{
866 Type: AgentEventTypeError,
867 Error: fmt.Errorf("failed to get session: %w", err),
868 Done: true,
869 }
870
871 a.Publish(pubsub.CreatedEvent, event)
872 return
873 }
874 // Create a message in the new session with the summary
875 msg, err := a.messages.Create(summarizeCtx, oldSession.ID, message.CreateMessageParams{
876 Role: message.Assistant,
877 Parts: []message.ContentPart{
878 message.TextContent{Text: summary},
879 message.Finish{
880 Reason: message.FinishReasonEndTurn,
881 Time: time.Now().Unix(),
882 },
883 },
884 Model: a.summarizeProvider.Model().ID,
885 Provider: a.summarizeProviderID,
886 })
887 if err != nil {
888 event = AgentEvent{
889 Type: AgentEventTypeError,
890 Error: fmt.Errorf("failed to create summary message: %w", err),
891 Done: true,
892 }
893
894 a.Publish(pubsub.CreatedEvent, event)
895 return
896 }
897 oldSession.SummaryMessageID = msg.ID
898 oldSession.CompletionTokens = finalResponse.Usage.OutputTokens
899 oldSession.PromptTokens = 0
900 model := a.summarizeProvider.Model()
901 usage := finalResponse.Usage
902 cost := model.CostPer1MInCached/1e6*float64(usage.CacheCreationTokens) +
903 model.CostPer1MOutCached/1e6*float64(usage.CacheReadTokens) +
904 model.CostPer1MIn/1e6*float64(usage.InputTokens) +
905 model.CostPer1MOut/1e6*float64(usage.OutputTokens)
906 oldSession.Cost += cost
907 _, err = a.sessions.Save(summarizeCtx, oldSession)
908 if err != nil {
909 event = AgentEvent{
910 Type: AgentEventTypeError,
911 Error: fmt.Errorf("failed to save session: %w", err),
912 Done: true,
913 }
914 a.Publish(pubsub.CreatedEvent, event)
915 }
916
917 event = AgentEvent{
918 Type: AgentEventTypeSummarize,
919 SessionID: oldSession.ID,
920 Progress: "Summary complete",
921 Done: true,
922 }
923 a.Publish(pubsub.CreatedEvent, event)
924 // Send final success event with the new session ID
925 }()
926
927 return nil
928}
929
930func (a *agent) ClearQueue(sessionID string) {
931 if a.QueuedPrompts(sessionID) > 0 {
932 slog.Info("Clearing queued prompts", "session_id", sessionID)
933 a.promptQueue.Del(sessionID)
934 }
935}
936
937func (a *agent) CancelAll() {
938 if !a.IsBusy() {
939 return
940 }
941 for key := range a.activeRequests.Seq2() {
942 a.Cancel(key) // key is sessionID
943 }
944
945 timeout := time.After(5 * time.Second)
946 for a.IsBusy() {
947 select {
948 case <-timeout:
949 return
950 default:
951 time.Sleep(200 * time.Millisecond)
952 }
953 }
954}
955
956func (a *agent) UpdateModel() error {
957 cfg := config.Get()
958
959 // Get current provider configuration
960 currentProviderCfg := cfg.GetProviderForModel(a.agentCfg.Model)
961 if currentProviderCfg == nil || currentProviderCfg.ID == "" {
962 return fmt.Errorf("provider for agent %s not found in config", a.agentCfg.Name)
963 }
964
965 // Check if provider has changed
966 if string(currentProviderCfg.ID) != a.providerID {
967 // Provider changed, need to recreate the main provider
968 model := cfg.GetModelByType(a.agentCfg.Model)
969 if model.ID == "" {
970 return fmt.Errorf("model not found for agent %s", a.agentCfg.Name)
971 }
972
973 promptID := agentPromptMap[a.agentCfg.ID]
974 if promptID == "" {
975 promptID = prompt.PromptDefault
976 }
977
978 opts := []provider.ProviderClientOption{
979 provider.WithModel(a.agentCfg.Model),
980 provider.WithSystemMessage(prompt.GetPrompt(promptID, currentProviderCfg.ID, cfg.Options.ContextPaths...)),
981 }
982
983 newProvider, err := provider.NewProvider(*currentProviderCfg, opts...)
984 if err != nil {
985 return fmt.Errorf("failed to create new provider: %w", err)
986 }
987
988 // Update the provider and provider ID
989 a.provider = newProvider
990 a.providerID = string(currentProviderCfg.ID)
991 }
992
993 // Check if providers have changed for title (small) and summarize (large)
994 smallModelCfg := cfg.Models[config.SelectedModelTypeSmall]
995 var smallModelProviderCfg config.ProviderConfig
996 for p := range cfg.Providers.Seq() {
997 if p.ID == smallModelCfg.Provider {
998 smallModelProviderCfg = p
999 break
1000 }
1001 }
1002 if smallModelProviderCfg.ID == "" {
1003 return fmt.Errorf("provider %s not found in config", smallModelCfg.Provider)
1004 }
1005
1006 largeModelCfg := cfg.Models[config.SelectedModelTypeLarge]
1007 var largeModelProviderCfg config.ProviderConfig
1008 for p := range cfg.Providers.Seq() {
1009 if p.ID == largeModelCfg.Provider {
1010 largeModelProviderCfg = p
1011 break
1012 }
1013 }
1014 if largeModelProviderCfg.ID == "" {
1015 return fmt.Errorf("provider %s not found in config", largeModelCfg.Provider)
1016 }
1017
1018 var maxTitleTokens int64 = 40
1019
1020 // if the max output is too low for the gemini provider it won't return anything
1021 if smallModelCfg.Provider == "gemini" {
1022 maxTitleTokens = 1000
1023 }
1024 // Recreate title provider
1025 titleOpts := []provider.ProviderClientOption{
1026 provider.WithModel(config.SelectedModelTypeSmall),
1027 provider.WithSystemMessage(prompt.GetPrompt(prompt.PromptTitle, smallModelProviderCfg.ID)),
1028 provider.WithMaxTokens(maxTitleTokens),
1029 }
1030 newTitleProvider, err := provider.NewProvider(smallModelProviderCfg, titleOpts...)
1031 if err != nil {
1032 return fmt.Errorf("failed to create new title provider: %w", err)
1033 }
1034 a.titleProvider = newTitleProvider
1035
1036 // Recreate summarize provider if provider changed (now large model)
1037 if string(largeModelProviderCfg.ID) != a.summarizeProviderID {
1038 largeModel := cfg.GetModelByType(config.SelectedModelTypeLarge)
1039 if largeModel == nil {
1040 return fmt.Errorf("model %s not found in provider %s", largeModelCfg.Model, largeModelProviderCfg.ID)
1041 }
1042 summarizeOpts := []provider.ProviderClientOption{
1043 provider.WithModel(config.SelectedModelTypeLarge),
1044 provider.WithSystemMessage(prompt.GetPrompt(prompt.PromptSummarizer, largeModelProviderCfg.ID)),
1045 }
1046 newSummarizeProvider, err := provider.NewProvider(largeModelProviderCfg, summarizeOpts...)
1047 if err != nil {
1048 return fmt.Errorf("failed to create new summarize provider: %w", err)
1049 }
1050 a.summarizeProvider = newSummarizeProvider
1051 a.summarizeProviderID = string(largeModelProviderCfg.ID)
1052 }
1053
1054 return nil
1055}