1package agent
2
3import (
4 "context"
5 "errors"
6 "fmt"
7 "log/slog"
8 "slices"
9 "strings"
10 "time"
11
12 "github.com/charmbracelet/catwalk/pkg/catwalk"
13 "github.com/charmbracelet/crush/internal/config"
14 "github.com/charmbracelet/crush/internal/csync"
15 "github.com/charmbracelet/crush/internal/history"
16 "github.com/charmbracelet/crush/internal/llm/prompt"
17 "github.com/charmbracelet/crush/internal/llm/provider"
18 "github.com/charmbracelet/crush/internal/llm/tools"
19 "github.com/charmbracelet/crush/internal/log"
20 "github.com/charmbracelet/crush/internal/lsp"
21 "github.com/charmbracelet/crush/internal/message"
22 "github.com/charmbracelet/crush/internal/permission"
23 "github.com/charmbracelet/crush/internal/pubsub"
24 "github.com/charmbracelet/crush/internal/session"
25 "github.com/charmbracelet/crush/internal/shell"
26)
27
28// Common errors
29var (
30 ErrRequestCancelled = errors.New("request canceled by user")
31 ErrSessionBusy = errors.New("session is currently processing another request")
32)
33
// AgentEventType tells consumers which kind of AgentEvent they received.
type AgentEventType string

// AgentEventTypeError marks an event carrying a failure in AgentEvent.Error.
const AgentEventTypeError AgentEventType = "error"

// AgentEventTypeResponse marks an event carrying a finished assistant message.
const AgentEventTypeResponse AgentEventType = "response"

// AgentEventTypeSummarize marks a summarization progress or completion event.
const AgentEventTypeSummarize AgentEventType = "summarize"
41
// AgentEvent is what the agent reports to callers: it is sent on the channel
// returned by Run and published through the agent's broker (Run and
// Summarize both call Publish with it).
type AgentEvent struct {
	// Type selects which of the remaining fields are meaningful.
	Type AgentEventType
	// Message is the final assistant message of a successful response event.
	Message message.Message
	// Error is set on AgentEventTypeError events.
	Error error

	// When summarizing
	// SessionID is the session that received the summary (set on the final
	// summarize event).
	SessionID string
	// Progress is a human-readable summarization status line.
	Progress string
	// Done is set on the final response event and on terminal summarization
	// events (success or error).
	Done bool
}
52
53type Service interface {
54 pubsub.Suscriber[AgentEvent]
55 Model() catwalk.Model
56 Run(ctx context.Context, sessionID string, content string, attachments ...message.Attachment) (<-chan AgentEvent, error)
57 Cancel(sessionID string)
58 CancelAll()
59 IsSessionBusy(sessionID string) bool
60 IsBusy() bool
61 Summarize(ctx context.Context, sessionID string) error
62 UpdateModel() error
63 QueuedPrompts(sessionID string) int
64 ClearQueue(sessionID string)
65}
66
// agent is the Service implementation returned by NewAgent. Events are
// published through the embedded broker.
type agent struct {
	*pubsub.Broker[AgentEvent]
	agentCfg config.Agent
	sessions session.Service
	messages message.Service
	// NOTE(review): NewAgent in this file never assigns mcpTools; tool
	// construction reads a package-level mcpTools instead. Confirm whether
	// this field is still needed.
	mcpTools []McpTool

	// tools is built lazily by the toolFn passed to csync.NewLazySlice in NewAgent.
	tools *csync.LazySlice[tools.BaseTool]
	// We need this to be able to update it when model changes
	agentToolFn func() (tools.BaseTool, error)

	// provider serves the main conversation; providerID is its config ID and
	// is compared against the config in UpdateModel to detect changes.
	provider   provider.Provider
	providerID string

	// titleProvider generates session titles; summarizeProvider backs Summarize.
	titleProvider       provider.Provider
	summarizeProvider   provider.Provider
	summarizeProviderID string

	// activeRequests maps a session ID (or "<sessionID>-summarize" for a
	// summarization) to the cancel func of its in-flight work.
	activeRequests *csync.Map[string, context.CancelFunc]

	// promptQueue holds prompts submitted via Run while the session was busy;
	// processGeneration drains it between model turns.
	promptQueue *csync.Map[string, []string]
}
89
90var agentPromptMap = map[string]prompt.PromptID{
91 "coder": prompt.PromptCoder,
92 "task": prompt.PromptTask,
93}
94
95func NewAgent(
96 ctx context.Context,
97 agentCfg config.Agent,
98 // These services are needed in the tools
99 permissions permission.Service,
100 sessions session.Service,
101 messages message.Service,
102 history history.Service,
103 lspClients map[string]*lsp.Client,
104) (Service, error) {
105 cfg := config.Get()
106
107 var agentToolFn func() (tools.BaseTool, error)
108 if agentCfg.ID == "coder" {
109 agentToolFn = func() (tools.BaseTool, error) {
110 taskAgentCfg := config.Get().Agents["task"]
111 if taskAgentCfg.ID == "" {
112 return nil, fmt.Errorf("task agent not found in config")
113 }
114 taskAgent, err := NewAgent(ctx, taskAgentCfg, permissions, sessions, messages, history, lspClients)
115 if err != nil {
116 return nil, fmt.Errorf("failed to create task agent: %w", err)
117 }
118 return NewAgentTool(taskAgent, sessions, messages), nil
119 }
120 }
121
122 providerCfg := config.Get().GetProviderForModel(agentCfg.Model)
123 if providerCfg == nil {
124 return nil, fmt.Errorf("provider for agent %s not found in config", agentCfg.Name)
125 }
126 model := config.Get().GetModelByType(agentCfg.Model)
127
128 if model == nil {
129 return nil, fmt.Errorf("model not found for agent %s", agentCfg.Name)
130 }
131
132 promptID := agentPromptMap[agentCfg.ID]
133 if promptID == "" {
134 promptID = prompt.PromptDefault
135 }
136 opts := []provider.ProviderClientOption{
137 provider.WithModel(agentCfg.Model),
138 provider.WithSystemMessage(prompt.GetPrompt(promptID, providerCfg.ID, config.Get().Options.ContextPaths...)),
139 }
140 agentProvider, err := provider.NewProvider(*providerCfg, opts...)
141 if err != nil {
142 return nil, err
143 }
144
145 smallModelCfg := cfg.Models[config.SelectedModelTypeSmall]
146 var smallModelProviderCfg *config.ProviderConfig
147 if smallModelCfg.Provider == providerCfg.ID {
148 smallModelProviderCfg = providerCfg
149 } else {
150 smallModelProviderCfg = cfg.GetProviderForModel(config.SelectedModelTypeSmall)
151
152 if smallModelProviderCfg.ID == "" {
153 return nil, fmt.Errorf("provider %s not found in config", smallModelCfg.Provider)
154 }
155 }
156 smallModel := cfg.GetModelByType(config.SelectedModelTypeSmall)
157 if smallModel.ID == "" {
158 return nil, fmt.Errorf("model %s not found in provider %s", smallModelCfg.Model, smallModelProviderCfg.ID)
159 }
160
161 titleOpts := []provider.ProviderClientOption{
162 provider.WithModel(config.SelectedModelTypeSmall),
163 provider.WithSystemMessage(prompt.GetPrompt(prompt.PromptTitle, smallModelProviderCfg.ID)),
164 }
165 titleProvider, err := provider.NewProvider(*smallModelProviderCfg, titleOpts...)
166 if err != nil {
167 return nil, err
168 }
169
170 summarizeOpts := []provider.ProviderClientOption{
171 provider.WithModel(config.SelectedModelTypeLarge),
172 provider.WithSystemMessage(prompt.GetPrompt(prompt.PromptSummarizer, providerCfg.ID)),
173 }
174 summarizeProvider, err := provider.NewProvider(*providerCfg, summarizeOpts...)
175 if err != nil {
176 return nil, err
177 }
178
179 toolFn := func() []tools.BaseTool {
180 slog.Info("Initializing agent tools", "agent", agentCfg.ID)
181 defer func() {
182 slog.Info("Initialized agent tools", "agent", agentCfg.ID)
183 }()
184
185 cwd := cfg.WorkingDir()
186 allTools := []tools.BaseTool{
187 tools.NewBashTool(permissions, cwd),
188 tools.NewDownloadTool(permissions, cwd),
189 tools.NewEditTool(lspClients, permissions, history, cwd),
190 tools.NewMultiEditTool(lspClients, permissions, history, cwd),
191 tools.NewFetchTool(permissions, cwd),
192 tools.NewGlobTool(cwd),
193 tools.NewGrepTool(cwd),
194 tools.NewLsTool(permissions, cwd),
195 tools.NewSourcegraphTool(),
196 tools.NewViewTool(lspClients, permissions, cwd),
197 tools.NewWriteTool(lspClients, permissions, history, cwd),
198 }
199
200 mcpToolsOnce.Do(func() {
201 mcpTools = doGetMCPTools(ctx, permissions, cfg)
202 })
203
204 withCoderTools := func(t []tools.BaseTool) []tools.BaseTool {
205 if agentCfg.ID == "coder" {
206 t = append(t, mcpTools...)
207 if len(lspClients) > 0 {
208 t = append(t, tools.NewDiagnosticsTool(lspClients))
209 }
210 }
211 return t
212 }
213
214 if agentCfg.AllowedTools == nil {
215 return withCoderTools(allTools)
216 }
217
218 var filteredTools []tools.BaseTool
219 for _, tool := range allTools {
220 if slices.Contains(agentCfg.AllowedTools, tool.Name()) {
221 filteredTools = append(filteredTools, tool)
222 }
223 }
224
225 if agentCfg.ID == "coder" {
226 filteredTools = append(filteredTools, mcpTools...)
227 }
228 return withCoderTools(filteredTools)
229 }
230
231 return &agent{
232 Broker: pubsub.NewBroker[AgentEvent](),
233 agentCfg: agentCfg,
234 provider: agentProvider,
235 providerID: string(providerCfg.ID),
236 messages: messages,
237 sessions: sessions,
238 titleProvider: titleProvider,
239 summarizeProvider: summarizeProvider,
240 summarizeProviderID: string(providerCfg.ID),
241 agentToolFn: agentToolFn,
242 activeRequests: csync.NewMap[string, context.CancelFunc](),
243 tools: csync.NewLazySlice(toolFn),
244 promptQueue: csync.NewMap[string, []string](),
245 }, nil
246}
247
248func (a *agent) Model() catwalk.Model {
249 return *config.Get().GetModelByType(a.agentCfg.Model)
250}
251
252func (a *agent) Cancel(sessionID string) {
253 // Cancel regular requests
254 if cancel, ok := a.activeRequests.Take(sessionID); ok && cancel != nil {
255 slog.Info("Request cancellation initiated", "session_id", sessionID)
256 cancel()
257 }
258
259 // Also check for summarize requests
260 if cancel, ok := a.activeRequests.Take(sessionID + "-summarize"); ok && cancel != nil {
261 slog.Info("Summarize cancellation initiated", "session_id", sessionID)
262 cancel()
263 }
264
265 if a.QueuedPrompts(sessionID) > 0 {
266 slog.Info("Clearing queued prompts", "session_id", sessionID)
267 a.promptQueue.Del(sessionID)
268 }
269}
270
271func (a *agent) IsBusy() bool {
272 var busy bool
273 for cancelFunc := range a.activeRequests.Seq() {
274 if cancelFunc != nil {
275 busy = true
276 break
277 }
278 }
279 return busy
280}
281
282func (a *agent) IsSessionBusy(sessionID string) bool {
283 _, busy := a.activeRequests.Get(sessionID)
284 return busy
285}
286
287func (a *agent) QueuedPrompts(sessionID string) int {
288 l, ok := a.promptQueue.Get(sessionID)
289 if !ok {
290 return 0
291 }
292 return len(l)
293}
294
295func (a *agent) generateTitle(ctx context.Context, sessionID string, content string) error {
296 if content == "" {
297 return nil
298 }
299 if a.titleProvider == nil {
300 return nil
301 }
302 session, err := a.sessions.Get(ctx, sessionID)
303 if err != nil {
304 return err
305 }
306 parts := []message.ContentPart{message.TextContent{
307 Text: fmt.Sprintf("Generate a concise title for the following content:\n\n%s", content),
308 }}
309
310 // Use streaming approach like summarization
311 response := a.titleProvider.StreamResponse(
312 ctx,
313 []message.Message{
314 {
315 Role: message.User,
316 Parts: parts,
317 },
318 },
319 nil,
320 )
321
322 var finalResponse *provider.ProviderResponse
323 for r := range response {
324 if r.Error != nil {
325 return r.Error
326 }
327 finalResponse = r.Response
328 }
329
330 if finalResponse == nil {
331 return fmt.Errorf("no response received from title provider")
332 }
333
334 title := strings.TrimSpace(strings.ReplaceAll(finalResponse.Content, "\n", " "))
335 if title == "" {
336 return nil
337 }
338
339 session.Title = title
340 _, err = a.sessions.Save(ctx, session)
341 return err
342}
343
344func (a *agent) err(err error) AgentEvent {
345 return AgentEvent{
346 Type: AgentEventTypeError,
347 Error: err,
348 }
349}
350
351func (a *agent) Run(ctx context.Context, sessionID string, content string, attachments ...message.Attachment) (<-chan AgentEvent, error) {
352 if !a.Model().SupportsImages && attachments != nil {
353 attachments = nil
354 }
355 events := make(chan AgentEvent, 1)
356 if a.IsSessionBusy(sessionID) {
357 existing, ok := a.promptQueue.Get(sessionID)
358 if !ok {
359 existing = []string{}
360 }
361 existing = append(existing, content)
362 a.promptQueue.Set(sessionID, existing)
363 return nil, nil
364 }
365
366 genCtx, cancel := context.WithCancel(ctx)
367
368 a.activeRequests.Set(sessionID, cancel)
369 go func() {
370 slog.Debug("Request started", "sessionID", sessionID)
371 defer log.RecoverPanic("agent.Run", func() {
372 events <- a.err(fmt.Errorf("panic while running the agent"))
373 })
374 var attachmentParts []message.ContentPart
375 for _, attachment := range attachments {
376 attachmentParts = append(attachmentParts, message.BinaryContent{Path: attachment.FilePath, MIMEType: attachment.MimeType, Data: attachment.Content})
377 }
378 result := a.processGeneration(genCtx, sessionID, content, attachmentParts)
379 if result.Error != nil && !errors.Is(result.Error, ErrRequestCancelled) && !errors.Is(result.Error, context.Canceled) {
380 slog.Error(result.Error.Error())
381 }
382 slog.Debug("Request completed", "sessionID", sessionID)
383 a.activeRequests.Del(sessionID)
384 cancel()
385 a.Publish(pubsub.CreatedEvent, result)
386 events <- result
387 close(events)
388 }()
389 return events, nil
390}
391
392func (a *agent) processGeneration(ctx context.Context, sessionID, content string, attachmentParts []message.ContentPart) AgentEvent {
393 cfg := config.Get()
394 // List existing messages; if none, start title generation asynchronously.
395 msgs, err := a.messages.List(ctx, sessionID)
396 if err != nil {
397 return a.err(fmt.Errorf("failed to list messages: %w", err))
398 }
399 if len(msgs) == 0 {
400 go func() {
401 defer log.RecoverPanic("agent.Run", func() {
402 slog.Error("panic while generating title")
403 })
404 titleErr := a.generateTitle(ctx, sessionID, content)
405 if titleErr != nil && !errors.Is(titleErr, context.Canceled) && !errors.Is(titleErr, context.DeadlineExceeded) {
406 slog.Error("failed to generate title", "error", titleErr)
407 }
408 }()
409 }
410 session, err := a.sessions.Get(ctx, sessionID)
411 if err != nil {
412 return a.err(fmt.Errorf("failed to get session: %w", err))
413 }
414 if session.SummaryMessageID != "" {
415 summaryMsgInex := -1
416 for i, msg := range msgs {
417 if msg.ID == session.SummaryMessageID {
418 summaryMsgInex = i
419 break
420 }
421 }
422 if summaryMsgInex != -1 {
423 msgs = msgs[summaryMsgInex:]
424 msgs[0].Role = message.User
425 }
426 }
427
428 userMsg, err := a.createUserMessage(ctx, sessionID, content, attachmentParts)
429 if err != nil {
430 return a.err(fmt.Errorf("failed to create user message: %w", err))
431 }
432 // Append the new user message to the conversation history.
433 msgHistory := append(msgs, userMsg)
434
435 for {
436 // Check for cancellation before each iteration
437 select {
438 case <-ctx.Done():
439 return a.err(ctx.Err())
440 default:
441 // Continue processing
442 }
443 agentMessage, toolResults, err := a.streamAndHandleEvents(ctx, sessionID, msgHistory)
444 if err != nil {
445 if errors.Is(err, context.Canceled) {
446 agentMessage.AddFinish(message.FinishReasonCanceled, "Request cancelled", "")
447 a.messages.Update(context.Background(), agentMessage)
448 return a.err(ErrRequestCancelled)
449 }
450 return a.err(fmt.Errorf("failed to process events: %w", err))
451 }
452 if cfg.Options.Debug {
453 slog.Info("Result", "message", agentMessage.FinishReason(), "toolResults", toolResults)
454 }
455 if (agentMessage.FinishReason() == message.FinishReasonToolUse) && toolResults != nil {
456 // We are not done, we need to respond with the tool response
457 msgHistory = append(msgHistory, agentMessage, *toolResults)
458 // If there are queued prompts, process the next one
459 nextPrompt, ok := a.promptQueue.Take(sessionID)
460 if ok {
461 for _, prompt := range nextPrompt {
462 // Create a new user message for the queued prompt
463 userMsg, err := a.createUserMessage(ctx, sessionID, prompt, nil)
464 if err != nil {
465 return a.err(fmt.Errorf("failed to create user message for queued prompt: %w", err))
466 }
467 // Append the new user message to the conversation history
468 msgHistory = append(msgHistory, userMsg)
469 }
470 }
471
472 continue
473 } else if agentMessage.FinishReason() == message.FinishReasonEndTurn {
474 queuePrompts, ok := a.promptQueue.Take(sessionID)
475 if ok {
476 for _, prompt := range queuePrompts {
477 if prompt == "" {
478 continue
479 }
480 userMsg, err := a.createUserMessage(ctx, sessionID, prompt, nil)
481 if err != nil {
482 return a.err(fmt.Errorf("failed to create user message for queued prompt: %w", err))
483 }
484 msgHistory = append(msgHistory, userMsg)
485 }
486 continue
487 }
488 }
489 if agentMessage.FinishReason() == "" {
490 // Kujtim: could not track down where this is happening but this means its cancelled
491 agentMessage.AddFinish(message.FinishReasonCanceled, "Request cancelled", "")
492 _ = a.messages.Update(context.Background(), agentMessage)
493 return a.err(ErrRequestCancelled)
494 }
495 return AgentEvent{
496 Type: AgentEventTypeResponse,
497 Message: agentMessage,
498 Done: true,
499 }
500 }
501}
502
503func (a *agent) createUserMessage(ctx context.Context, sessionID, content string, attachmentParts []message.ContentPart) (message.Message, error) {
504 parts := []message.ContentPart{message.TextContent{Text: content}}
505 parts = append(parts, attachmentParts...)
506 return a.messages.Create(ctx, sessionID, message.CreateMessageParams{
507 Role: message.User,
508 Parts: parts,
509 })
510}
511
512func (a *agent) getAllTools() ([]tools.BaseTool, error) {
513 allTools := slices.Collect(a.tools.Seq())
514 if a.agentToolFn != nil {
515 agentTool, agentToolErr := a.agentToolFn()
516 if agentToolErr != nil {
517 return nil, agentToolErr
518 }
519 allTools = append(allTools, agentTool)
520 }
521 return allTools, nil
522}
523
524func (a *agent) streamAndHandleEvents(ctx context.Context, sessionID string, msgHistory []message.Message) (message.Message, *message.Message, error) {
525 ctx = context.WithValue(ctx, tools.SessionIDContextKey, sessionID)
526
527 // Create the assistant message first so the spinner shows immediately
528 assistantMsg, err := a.messages.Create(ctx, sessionID, message.CreateMessageParams{
529 Role: message.Assistant,
530 Parts: []message.ContentPart{},
531 Model: a.Model().ID,
532 Provider: a.providerID,
533 })
534 if err != nil {
535 return assistantMsg, nil, fmt.Errorf("failed to create assistant message: %w", err)
536 }
537
538 allTools, toolsErr := a.getAllTools()
539 if toolsErr != nil {
540 return assistantMsg, nil, toolsErr
541 }
542 // Now collect tools (which may block on MCP initialization)
543 eventChan := a.provider.StreamResponse(ctx, msgHistory, allTools)
544
545 // Add the session and message ID into the context if needed by tools.
546 ctx = context.WithValue(ctx, tools.MessageIDContextKey, assistantMsg.ID)
547
548 // Process each event in the stream.
549 for event := range eventChan {
550 if processErr := a.processEvent(ctx, sessionID, &assistantMsg, event); processErr != nil {
551 if errors.Is(processErr, context.Canceled) {
552 a.finishMessage(context.Background(), &assistantMsg, message.FinishReasonCanceled, "Request cancelled", "")
553 } else {
554 a.finishMessage(ctx, &assistantMsg, message.FinishReasonError, "API Error", processErr.Error())
555 }
556 return assistantMsg, nil, processErr
557 }
558 if ctx.Err() != nil {
559 a.finishMessage(context.Background(), &assistantMsg, message.FinishReasonCanceled, "Request cancelled", "")
560 return assistantMsg, nil, ctx.Err()
561 }
562 }
563
564 toolResults := make([]message.ToolResult, len(assistantMsg.ToolCalls()))
565 toolCalls := assistantMsg.ToolCalls()
566 for i, toolCall := range toolCalls {
567 select {
568 case <-ctx.Done():
569 a.finishMessage(context.Background(), &assistantMsg, message.FinishReasonCanceled, "Request cancelled", "")
570 // Make all future tool calls cancelled
571 for j := i; j < len(toolCalls); j++ {
572 toolResults[j] = message.ToolResult{
573 ToolCallID: toolCalls[j].ID,
574 Content: "Tool execution canceled by user",
575 IsError: true,
576 }
577 }
578 goto out
579 default:
580 // Continue processing
581 var tool tools.BaseTool
582 allTools, _ := a.getAllTools()
583 for _, availableTool := range allTools {
584 if availableTool.Info().Name == toolCall.Name {
585 tool = availableTool
586 break
587 }
588 }
589
590 // Tool not found
591 if tool == nil {
592 toolResults[i] = message.ToolResult{
593 ToolCallID: toolCall.ID,
594 Content: fmt.Sprintf("Tool not found: %s", toolCall.Name),
595 IsError: true,
596 }
597 continue
598 }
599
600 // Run tool in goroutine to allow cancellation
601 type toolExecResult struct {
602 response tools.ToolResponse
603 err error
604 }
605 resultChan := make(chan toolExecResult, 1)
606
607 go func() {
608 response, err := tool.Run(ctx, tools.ToolCall{
609 ID: toolCall.ID,
610 Name: toolCall.Name,
611 Input: toolCall.Input,
612 })
613 resultChan <- toolExecResult{response: response, err: err}
614 }()
615
616 var toolResponse tools.ToolResponse
617 var toolErr error
618
619 select {
620 case <-ctx.Done():
621 a.finishMessage(context.Background(), &assistantMsg, message.FinishReasonCanceled, "Request cancelled", "")
622 // Mark remaining tool calls as cancelled
623 for j := i; j < len(toolCalls); j++ {
624 toolResults[j] = message.ToolResult{
625 ToolCallID: toolCalls[j].ID,
626 Content: "Tool execution canceled by user",
627 IsError: true,
628 }
629 }
630 goto out
631 case result := <-resultChan:
632 toolResponse = result.response
633 toolErr = result.err
634 }
635
636 if toolErr != nil {
637 slog.Error("Tool execution error", "toolCall", toolCall.ID, "error", toolErr)
638 if errors.Is(toolErr, permission.ErrorPermissionDenied) {
639 toolResults[i] = message.ToolResult{
640 ToolCallID: toolCall.ID,
641 Content: "Permission denied",
642 IsError: true,
643 }
644 for j := i + 1; j < len(toolCalls); j++ {
645 toolResults[j] = message.ToolResult{
646 ToolCallID: toolCalls[j].ID,
647 Content: "Tool execution canceled by user",
648 IsError: true,
649 }
650 }
651 a.finishMessage(ctx, &assistantMsg, message.FinishReasonPermissionDenied, "Permission denied", "")
652 break
653 }
654 }
655 toolResults[i] = message.ToolResult{
656 ToolCallID: toolCall.ID,
657 Content: toolResponse.Content,
658 Metadata: toolResponse.Metadata,
659 IsError: toolResponse.IsError,
660 }
661 }
662 }
663out:
664 if len(toolResults) == 0 {
665 return assistantMsg, nil, nil
666 }
667 parts := make([]message.ContentPart, 0)
668 for _, tr := range toolResults {
669 parts = append(parts, tr)
670 }
671 msg, err := a.messages.Create(context.Background(), assistantMsg.SessionID, message.CreateMessageParams{
672 Role: message.Tool,
673 Parts: parts,
674 Provider: a.providerID,
675 })
676 if err != nil {
677 return assistantMsg, nil, fmt.Errorf("failed to create cancelled tool message: %w", err)
678 }
679
680 return assistantMsg, &msg, err
681}
682
683func (a *agent) finishMessage(ctx context.Context, msg *message.Message, finishReason message.FinishReason, message, details string) {
684 msg.AddFinish(finishReason, message, details)
685 _ = a.messages.Update(ctx, *msg)
686}
687
688func (a *agent) processEvent(ctx context.Context, sessionID string, assistantMsg *message.Message, event provider.ProviderEvent) error {
689 select {
690 case <-ctx.Done():
691 return ctx.Err()
692 default:
693 // Continue processing.
694 }
695
696 switch event.Type {
697 case provider.EventThinkingDelta:
698 assistantMsg.AppendReasoningContent(event.Thinking)
699 return a.messages.Update(ctx, *assistantMsg)
700 case provider.EventSignatureDelta:
701 assistantMsg.AppendReasoningSignature(event.Signature)
702 return a.messages.Update(ctx, *assistantMsg)
703 case provider.EventContentDelta:
704 assistantMsg.FinishThinking()
705 assistantMsg.AppendContent(event.Content)
706 return a.messages.Update(ctx, *assistantMsg)
707 case provider.EventToolUseStart:
708 assistantMsg.FinishThinking()
709 slog.Info("Tool call started", "toolCall", event.ToolCall)
710 assistantMsg.AddToolCall(*event.ToolCall)
711 return a.messages.Update(ctx, *assistantMsg)
712 case provider.EventToolUseDelta:
713 assistantMsg.AppendToolCallInput(event.ToolCall.ID, event.ToolCall.Input)
714 return a.messages.Update(ctx, *assistantMsg)
715 case provider.EventToolUseStop:
716 slog.Info("Finished tool call", "toolCall", event.ToolCall)
717 assistantMsg.FinishToolCall(event.ToolCall.ID)
718 return a.messages.Update(ctx, *assistantMsg)
719 case provider.EventError:
720 return event.Error
721 case provider.EventComplete:
722 assistantMsg.FinishThinking()
723 assistantMsg.SetToolCalls(event.Response.ToolCalls)
724 assistantMsg.AddFinish(event.Response.FinishReason, "", "")
725 if err := a.messages.Update(ctx, *assistantMsg); err != nil {
726 return fmt.Errorf("failed to update message: %w", err)
727 }
728 return a.TrackUsage(ctx, sessionID, a.Model(), event.Response.Usage)
729 }
730
731 return nil
732}
733
734func (a *agent) TrackUsage(ctx context.Context, sessionID string, model catwalk.Model, usage provider.TokenUsage) error {
735 sess, err := a.sessions.Get(ctx, sessionID)
736 if err != nil {
737 return fmt.Errorf("failed to get session: %w", err)
738 }
739
740 cost := model.CostPer1MInCached/1e6*float64(usage.CacheCreationTokens) +
741 model.CostPer1MOutCached/1e6*float64(usage.CacheReadTokens) +
742 model.CostPer1MIn/1e6*float64(usage.InputTokens) +
743 model.CostPer1MOut/1e6*float64(usage.OutputTokens)
744
745 sess.Cost += cost
746 sess.CompletionTokens = usage.OutputTokens + usage.CacheReadTokens
747 sess.PromptTokens = usage.InputTokens + usage.CacheCreationTokens
748
749 _, err = a.sessions.Save(ctx, sess)
750 if err != nil {
751 return fmt.Errorf("failed to save session: %w", err)
752 }
753 return nil
754}
755
756func (a *agent) Summarize(ctx context.Context, sessionID string) error {
757 if a.summarizeProvider == nil {
758 return fmt.Errorf("summarize provider not available")
759 }
760
761 // Check if session is busy
762 if a.IsSessionBusy(sessionID) {
763 return ErrSessionBusy
764 }
765
766 // Create a new context with cancellation
767 summarizeCtx, cancel := context.WithCancel(ctx)
768
769 // Store the cancel function in activeRequests to allow cancellation
770 a.activeRequests.Set(sessionID+"-summarize", cancel)
771
772 go func() {
773 defer a.activeRequests.Del(sessionID + "-summarize")
774 defer cancel()
775 event := AgentEvent{
776 Type: AgentEventTypeSummarize,
777 Progress: "Starting summarization...",
778 }
779
780 a.Publish(pubsub.CreatedEvent, event)
781 // Get all messages from the session
782 msgs, err := a.messages.List(summarizeCtx, sessionID)
783 if err != nil {
784 event = AgentEvent{
785 Type: AgentEventTypeError,
786 Error: fmt.Errorf("failed to list messages: %w", err),
787 Done: true,
788 }
789 a.Publish(pubsub.CreatedEvent, event)
790 return
791 }
792 summarizeCtx = context.WithValue(summarizeCtx, tools.SessionIDContextKey, sessionID)
793
794 if len(msgs) == 0 {
795 event = AgentEvent{
796 Type: AgentEventTypeError,
797 Error: fmt.Errorf("no messages to summarize"),
798 Done: true,
799 }
800 a.Publish(pubsub.CreatedEvent, event)
801 return
802 }
803
804 event = AgentEvent{
805 Type: AgentEventTypeSummarize,
806 Progress: "Analyzing conversation...",
807 }
808 a.Publish(pubsub.CreatedEvent, event)
809
810 // Add a system message to guide the summarization
811 summarizePrompt := "Provide a detailed but concise summary of our conversation above. Focus on information that would be helpful for continuing the conversation, including what we did, what we're doing, which files we're working on, and what we're going to do next."
812
813 // Create a new message with the summarize prompt
814 promptMsg := message.Message{
815 Role: message.User,
816 Parts: []message.ContentPart{message.TextContent{Text: summarizePrompt}},
817 }
818
819 // Append the prompt to the messages
820 msgsWithPrompt := append(msgs, promptMsg)
821
822 event = AgentEvent{
823 Type: AgentEventTypeSummarize,
824 Progress: "Generating summary...",
825 }
826
827 a.Publish(pubsub.CreatedEvent, event)
828
829 // Send the messages to the summarize provider
830 response := a.summarizeProvider.StreamResponse(
831 summarizeCtx,
832 msgsWithPrompt,
833 nil,
834 )
835 var finalResponse *provider.ProviderResponse
836 for r := range response {
837 if r.Error != nil {
838 event = AgentEvent{
839 Type: AgentEventTypeError,
840 Error: fmt.Errorf("failed to summarize: %w", err),
841 Done: true,
842 }
843 a.Publish(pubsub.CreatedEvent, event)
844 return
845 }
846 finalResponse = r.Response
847 }
848
849 summary := strings.TrimSpace(finalResponse.Content)
850 if summary == "" {
851 event = AgentEvent{
852 Type: AgentEventTypeError,
853 Error: fmt.Errorf("empty summary returned"),
854 Done: true,
855 }
856 a.Publish(pubsub.CreatedEvent, event)
857 return
858 }
859 shell := shell.GetPersistentShell(config.Get().WorkingDir())
860 summary += "\n\n**Current working directory of the persistent shell**\n\n" + shell.GetWorkingDir()
861 event = AgentEvent{
862 Type: AgentEventTypeSummarize,
863 Progress: "Creating new session...",
864 }
865
866 a.Publish(pubsub.CreatedEvent, event)
867 oldSession, err := a.sessions.Get(summarizeCtx, sessionID)
868 if err != nil {
869 event = AgentEvent{
870 Type: AgentEventTypeError,
871 Error: fmt.Errorf("failed to get session: %w", err),
872 Done: true,
873 }
874
875 a.Publish(pubsub.CreatedEvent, event)
876 return
877 }
878 // Create a message in the new session with the summary
879 msg, err := a.messages.Create(summarizeCtx, oldSession.ID, message.CreateMessageParams{
880 Role: message.Assistant,
881 Parts: []message.ContentPart{
882 message.TextContent{Text: summary},
883 message.Finish{
884 Reason: message.FinishReasonEndTurn,
885 Time: time.Now().Unix(),
886 },
887 },
888 Model: a.summarizeProvider.Model().ID,
889 Provider: a.summarizeProviderID,
890 })
891 if err != nil {
892 event = AgentEvent{
893 Type: AgentEventTypeError,
894 Error: fmt.Errorf("failed to create summary message: %w", err),
895 Done: true,
896 }
897
898 a.Publish(pubsub.CreatedEvent, event)
899 return
900 }
901 oldSession.SummaryMessageID = msg.ID
902 oldSession.CompletionTokens = finalResponse.Usage.OutputTokens
903 oldSession.PromptTokens = 0
904 model := a.summarizeProvider.Model()
905 usage := finalResponse.Usage
906 cost := model.CostPer1MInCached/1e6*float64(usage.CacheCreationTokens) +
907 model.CostPer1MOutCached/1e6*float64(usage.CacheReadTokens) +
908 model.CostPer1MIn/1e6*float64(usage.InputTokens) +
909 model.CostPer1MOut/1e6*float64(usage.OutputTokens)
910 oldSession.Cost += cost
911 _, err = a.sessions.Save(summarizeCtx, oldSession)
912 if err != nil {
913 event = AgentEvent{
914 Type: AgentEventTypeError,
915 Error: fmt.Errorf("failed to save session: %w", err),
916 Done: true,
917 }
918 a.Publish(pubsub.CreatedEvent, event)
919 }
920
921 event = AgentEvent{
922 Type: AgentEventTypeSummarize,
923 SessionID: oldSession.ID,
924 Progress: "Summary complete",
925 Done: true,
926 }
927 a.Publish(pubsub.CreatedEvent, event)
928 // Send final success event with the new session ID
929 }()
930
931 return nil
932}
933
934func (a *agent) ClearQueue(sessionID string) {
935 if a.QueuedPrompts(sessionID) > 0 {
936 slog.Info("Clearing queued prompts", "session_id", sessionID)
937 a.promptQueue.Del(sessionID)
938 }
939}
940
941func (a *agent) CancelAll() {
942 if !a.IsBusy() {
943 return
944 }
945 for key := range a.activeRequests.Seq2() {
946 a.Cancel(key) // key is sessionID
947 }
948
949 timeout := time.After(5 * time.Second)
950 for a.IsBusy() {
951 select {
952 case <-timeout:
953 return
954 default:
955 time.Sleep(200 * time.Millisecond)
956 }
957 }
958}
959
960func (a *agent) UpdateModel() error {
961 cfg := config.Get()
962
963 // Get current provider configuration
964 currentProviderCfg := cfg.GetProviderForModel(a.agentCfg.Model)
965 if currentProviderCfg == nil || currentProviderCfg.ID == "" {
966 return fmt.Errorf("provider for agent %s not found in config", a.agentCfg.Name)
967 }
968
969 // Check if provider has changed
970 if string(currentProviderCfg.ID) != a.providerID {
971 // Provider changed, need to recreate the main provider
972 model := cfg.GetModelByType(a.agentCfg.Model)
973 if model.ID == "" {
974 return fmt.Errorf("model not found for agent %s", a.agentCfg.Name)
975 }
976
977 promptID := agentPromptMap[a.agentCfg.ID]
978 if promptID == "" {
979 promptID = prompt.PromptDefault
980 }
981
982 opts := []provider.ProviderClientOption{
983 provider.WithModel(a.agentCfg.Model),
984 provider.WithSystemMessage(prompt.GetPrompt(promptID, currentProviderCfg.ID, cfg.Options.ContextPaths...)),
985 }
986
987 newProvider, err := provider.NewProvider(*currentProviderCfg, opts...)
988 if err != nil {
989 return fmt.Errorf("failed to create new provider: %w", err)
990 }
991
992 // Update the provider and provider ID
993 a.provider = newProvider
994 a.providerID = string(currentProviderCfg.ID)
995 }
996
997 // Check if providers have changed for title (small) and summarize (large)
998 smallModelCfg := cfg.Models[config.SelectedModelTypeSmall]
999 var smallModelProviderCfg config.ProviderConfig
1000 for p := range cfg.Providers.Seq() {
1001 if p.ID == smallModelCfg.Provider {
1002 smallModelProviderCfg = p
1003 break
1004 }
1005 }
1006 if smallModelProviderCfg.ID == "" {
1007 return fmt.Errorf("provider %s not found in config", smallModelCfg.Provider)
1008 }
1009
1010 largeModelCfg := cfg.Models[config.SelectedModelTypeLarge]
1011 var largeModelProviderCfg config.ProviderConfig
1012 for p := range cfg.Providers.Seq() {
1013 if p.ID == largeModelCfg.Provider {
1014 largeModelProviderCfg = p
1015 break
1016 }
1017 }
1018 if largeModelProviderCfg.ID == "" {
1019 return fmt.Errorf("provider %s not found in config", largeModelCfg.Provider)
1020 }
1021
1022 var maxTitleTokens int64 = 40
1023
1024 // if the max output is too low for the gemini provider it won't return anything
1025 if smallModelCfg.Provider == "gemini" {
1026 maxTitleTokens = 1000
1027 }
1028 // Recreate title provider
1029 titleOpts := []provider.ProviderClientOption{
1030 provider.WithModel(config.SelectedModelTypeSmall),
1031 provider.WithSystemMessage(prompt.GetPrompt(prompt.PromptTitle, smallModelProviderCfg.ID)),
1032 provider.WithMaxTokens(maxTitleTokens),
1033 }
1034 newTitleProvider, err := provider.NewProvider(smallModelProviderCfg, titleOpts...)
1035 if err != nil {
1036 return fmt.Errorf("failed to create new title provider: %w", err)
1037 }
1038 a.titleProvider = newTitleProvider
1039
1040 // Recreate summarize provider if provider changed (now large model)
1041 if string(largeModelProviderCfg.ID) != a.summarizeProviderID {
1042 largeModel := cfg.GetModelByType(config.SelectedModelTypeLarge)
1043 if largeModel == nil {
1044 return fmt.Errorf("model %s not found in provider %s", largeModelCfg.Model, largeModelProviderCfg.ID)
1045 }
1046 summarizeOpts := []provider.ProviderClientOption{
1047 provider.WithModel(config.SelectedModelTypeLarge),
1048 provider.WithSystemMessage(prompt.GetPrompt(prompt.PromptSummarizer, largeModelProviderCfg.ID)),
1049 }
1050 newSummarizeProvider, err := provider.NewProvider(largeModelProviderCfg, summarizeOpts...)
1051 if err != nil {
1052 return fmt.Errorf("failed to create new summarize provider: %w", err)
1053 }
1054 a.summarizeProvider = newSummarizeProvider
1055 a.summarizeProviderID = string(largeModelProviderCfg.ID)
1056 }
1057
1058 return nil
1059}