1package agent
2
3import (
4 "context"
5 "errors"
6 "fmt"
7 "log/slog"
8 "slices"
9 "strings"
10 "time"
11
12 "github.com/charmbracelet/catwalk/pkg/catwalk"
13 "github.com/charmbracelet/crush/internal/config"
14 "github.com/charmbracelet/crush/internal/csync"
15 "github.com/charmbracelet/crush/internal/history"
16 "github.com/charmbracelet/crush/internal/llm/prompt"
17 "github.com/charmbracelet/crush/internal/llm/provider"
18 "github.com/charmbracelet/crush/internal/llm/tools"
19 "github.com/charmbracelet/crush/internal/log"
20 "github.com/charmbracelet/crush/internal/lsp"
21 "github.com/charmbracelet/crush/internal/message"
22 "github.com/charmbracelet/crush/internal/permission"
23 "github.com/charmbracelet/crush/internal/pubsub"
24 "github.com/charmbracelet/crush/internal/session"
25 "github.com/charmbracelet/crush/internal/shell"
26)
27
// Common errors
var (
	// ErrRequestCancelled is returned (inside an error AgentEvent) when a
	// request is cancelled while it is being processed.
	ErrRequestCancelled = errors.New("request canceled by user")
	// ErrSessionBusy is returned by Summarize when the session already has an
	// active request.
	ErrSessionBusy = errors.New("session is currently processing another request")
)

// AgentEventType identifies the kind of an AgentEvent.
type AgentEventType string

const (
	// AgentEventTypeError carries a failure in AgentEvent.Error.
	AgentEventTypeError AgentEventType = "error"
	// AgentEventTypeResponse carries the final assistant message of a request.
	AgentEventTypeResponse AgentEventType = "response"
	// AgentEventTypeSummarize reports summarization progress/completion.
	AgentEventTypeSummarize AgentEventType = "summarize"
)

// AgentEvent is published on the agent's broker (and, for Run, sent on the
// returned channel) to report the outcome or progress of agent work.
type AgentEvent struct {
	Type    AgentEventType
	Message message.Message
	Error   error

	// When summarizing
	SessionID string
	Progress  string
	// Done marks a terminal event (also set on final response events).
	Done bool
}
52
// Service is the interface exposed by an agent. It embeds a subscriber for
// the AgentEvents published while requests and summarizations run.
type Service interface {
	pubsub.Suscriber[AgentEvent]
	// Model returns the model currently selected for the agent.
	Model() catwalk.Model
	// Run starts processing content for a session and returns a channel that
	// delivers the outcome. If the session is busy, the implementation in this
	// file queues the prompt instead and returns a nil channel and nil error.
	Run(ctx context.Context, sessionID string, content string, attachments ...message.Attachment) (<-chan AgentEvent, error)
	// Cancel cancels the session's active request and summarization and
	// clears its queued prompts.
	Cancel(sessionID string)
	// CancelAll cancels every active request and waits briefly for them to stop.
	CancelAll()
	// IsSessionBusy reports whether the session has an active request.
	IsSessionBusy(sessionID string) bool
	// IsBusy reports whether any request or summarization is active.
	IsBusy() bool
	// Summarize starts summarizing the session in the background; progress
	// and the result are published as AgentEvents.
	Summarize(ctx context.Context, sessionID string) error
	// UpdateModel rebuilds the agent's providers from the current config.
	UpdateModel() error
	// QueuedPrompts returns how many prompts are queued for the session.
	QueuedPrompts(sessionID string) int
	// ClearQueue discards the session's queued prompts.
	ClearQueue(sessionID string)
}
66
// agent is the Service implementation returned by NewAgent.
type agent struct {
	*pubsub.Broker[AgentEvent]
	agentCfg config.Agent
	sessions session.Service
	messages message.Service
	// NOTE(review): not populated by NewAgent in this file; tool construction
	// uses the package-level mcpTools instead. Confirm whether still needed.
	mcpTools []McpTool

	// tools is built from NewAgent's toolFn via csync.NewLazySlice.
	tools *csync.LazySlice[tools.BaseTool]

	// provider serves the agent's configured model type; providerID is the
	// ID of the provider config it was built from.
	provider   provider.Provider
	providerID string

	// titleProvider generates session titles (small model);
	// summarizeProvider produces conversation summaries.
	titleProvider       provider.Provider
	summarizeProvider   provider.Provider
	summarizeProviderID string

	// activeRequests maps a session ID (or "<sessionID>-summarize" for a
	// summarization) to the cancel function of its in-flight work.
	activeRequests *csync.Map[string, context.CancelFunc]

	// promptQueue holds prompts submitted via Run while the session was busy.
	promptQueue *csync.Map[string, []string]
}

// agentPromptMap selects the system prompt for an agent ID; IDs not listed
// fall back to prompt.PromptDefault.
var agentPromptMap = map[string]prompt.PromptID{
	"coder": prompt.PromptCoder,
	"task":  prompt.PromptTask,
}
92
93func NewAgent(
94 ctx context.Context,
95 agentCfg config.Agent,
96 // These services are needed in the tools
97 permissions permission.Service,
98 sessions session.Service,
99 messages message.Service,
100 history history.Service,
101 lspClients map[string]*lsp.Client,
102) (Service, error) {
103 cfg := config.Get()
104
105 var agentTool tools.BaseTool
106 if agentCfg.ID == "coder" {
107 taskAgentCfg := config.Get().Agents["task"]
108 if taskAgentCfg.ID == "" {
109 return nil, fmt.Errorf("task agent not found in config")
110 }
111 taskAgent, err := NewAgent(ctx, taskAgentCfg, permissions, sessions, messages, history, lspClients)
112 if err != nil {
113 return nil, fmt.Errorf("failed to create task agent: %w", err)
114 }
115
116 agentTool = NewAgentTool(taskAgent, sessions, messages)
117 }
118
119 providerCfg := config.Get().GetProviderForModel(agentCfg.Model)
120 if providerCfg == nil {
121 return nil, fmt.Errorf("provider for agent %s not found in config", agentCfg.Name)
122 }
123 model := config.Get().GetModelByType(agentCfg.Model)
124
125 if model == nil {
126 return nil, fmt.Errorf("model not found for agent %s", agentCfg.Name)
127 }
128
129 promptID := agentPromptMap[agentCfg.ID]
130 if promptID == "" {
131 promptID = prompt.PromptDefault
132 }
133 opts := []provider.ProviderClientOption{
134 provider.WithModel(agentCfg.Model),
135 provider.WithSystemMessage(prompt.GetPrompt(promptID, providerCfg.ID, config.Get().Options.ContextPaths...)),
136 }
137 agentProvider, err := provider.NewProvider(*providerCfg, opts...)
138 if err != nil {
139 return nil, err
140 }
141
142 smallModelCfg := cfg.Models[config.SelectedModelTypeSmall]
143 var smallModelProviderCfg *config.ProviderConfig
144 if smallModelCfg.Provider == providerCfg.ID {
145 smallModelProviderCfg = providerCfg
146 } else {
147 smallModelProviderCfg = cfg.GetProviderForModel(config.SelectedModelTypeSmall)
148
149 if smallModelProviderCfg.ID == "" {
150 return nil, fmt.Errorf("provider %s not found in config", smallModelCfg.Provider)
151 }
152 }
153 smallModel := cfg.GetModelByType(config.SelectedModelTypeSmall)
154 if smallModel.ID == "" {
155 return nil, fmt.Errorf("model %s not found in provider %s", smallModelCfg.Model, smallModelProviderCfg.ID)
156 }
157
158 titleOpts := []provider.ProviderClientOption{
159 provider.WithModel(config.SelectedModelTypeSmall),
160 provider.WithSystemMessage(prompt.GetPrompt(prompt.PromptTitle, smallModelProviderCfg.ID)),
161 }
162 titleProvider, err := provider.NewProvider(*smallModelProviderCfg, titleOpts...)
163 if err != nil {
164 return nil, err
165 }
166
167 summarizeOpts := []provider.ProviderClientOption{
168 provider.WithModel(config.SelectedModelTypeLarge),
169 provider.WithSystemMessage(prompt.GetPrompt(prompt.PromptSummarizer, providerCfg.ID)),
170 }
171 summarizeProvider, err := provider.NewProvider(*providerCfg, summarizeOpts...)
172 if err != nil {
173 return nil, err
174 }
175
176 toolFn := func() []tools.BaseTool {
177 slog.Info("Initializing agent tools", "agent", agentCfg.ID)
178 defer func() {
179 slog.Info("Initialized agent tools", "agent", agentCfg.ID)
180 }()
181
182 cwd := cfg.WorkingDir()
183 allTools := []tools.BaseTool{
184 tools.NewBashTool(permissions, cwd),
185 tools.NewDownloadTool(permissions, cwd),
186 tools.NewEditTool(lspClients, permissions, history, cwd),
187 tools.NewMultiEditTool(lspClients, permissions, history, cwd),
188 tools.NewFetchTool(permissions, cwd),
189 tools.NewGlobTool(cwd),
190 tools.NewGrepTool(cwd),
191 tools.NewLsTool(permissions, cwd),
192 tools.NewSourcegraphTool(),
193 tools.NewViewTool(lspClients, permissions, cwd),
194 tools.NewWriteTool(lspClients, permissions, history, cwd),
195 }
196
197 mcpToolsOnce.Do(func() {
198 mcpTools = doGetMCPTools(ctx, permissions, cfg)
199 })
200 allTools = append(allTools, mcpTools...)
201
202 if len(lspClients) > 0 {
203 allTools = append(allTools, tools.NewDiagnosticsTool(lspClients))
204 }
205
206 if agentTool != nil {
207 allTools = append(allTools, agentTool)
208 }
209
210 if agentCfg.AllowedTools == nil {
211 return allTools
212 }
213
214 var filteredTools []tools.BaseTool
215 for _, tool := range allTools {
216 if slices.Contains(agentCfg.AllowedTools, tool.Name()) {
217 filteredTools = append(filteredTools, tool)
218 }
219 }
220 return filteredTools
221 }
222
223 return &agent{
224 Broker: pubsub.NewBroker[AgentEvent](),
225 agentCfg: agentCfg,
226 provider: agentProvider,
227 providerID: string(providerCfg.ID),
228 messages: messages,
229 sessions: sessions,
230 titleProvider: titleProvider,
231 summarizeProvider: summarizeProvider,
232 summarizeProviderID: string(providerCfg.ID),
233 activeRequests: csync.NewMap[string, context.CancelFunc](),
234 tools: csync.NewLazySlice(toolFn),
235 promptQueue: csync.NewMap[string, []string](),
236 }, nil
237}
238
239func (a *agent) Model() catwalk.Model {
240 return *config.Get().GetModelByType(a.agentCfg.Model)
241}
242
243func (a *agent) Cancel(sessionID string) {
244 // Cancel regular requests
245 if cancel, ok := a.activeRequests.Take(sessionID); ok && cancel != nil {
246 slog.Info("Request cancellation initiated", "session_id", sessionID)
247 cancel()
248 }
249
250 // Also check for summarize requests
251 if cancel, ok := a.activeRequests.Take(sessionID + "-summarize"); ok && cancel != nil {
252 slog.Info("Summarize cancellation initiated", "session_id", sessionID)
253 cancel()
254 }
255
256 if a.QueuedPrompts(sessionID) > 0 {
257 slog.Info("Clearing queued prompts", "session_id", sessionID)
258 a.promptQueue.Del(sessionID)
259 }
260}
261
262func (a *agent) IsBusy() bool {
263 var busy bool
264 for cancelFunc := range a.activeRequests.Seq() {
265 if cancelFunc != nil {
266 busy = true
267 break
268 }
269 }
270 return busy
271}
272
273func (a *agent) IsSessionBusy(sessionID string) bool {
274 _, busy := a.activeRequests.Get(sessionID)
275 return busy
276}
277
278func (a *agent) QueuedPrompts(sessionID string) int {
279 l, ok := a.promptQueue.Get(sessionID)
280 if !ok {
281 return 0
282 }
283 return len(l)
284}
285
286func (a *agent) generateTitle(ctx context.Context, sessionID string, content string) error {
287 if content == "" {
288 return nil
289 }
290 if a.titleProvider == nil {
291 return nil
292 }
293 session, err := a.sessions.Get(ctx, sessionID)
294 if err != nil {
295 return err
296 }
297 parts := []message.ContentPart{message.TextContent{
298 Text: fmt.Sprintf("Generate a concise title for the following content:\n\n%s", content),
299 }}
300
301 // Use streaming approach like summarization
302 response := a.titleProvider.StreamResponse(
303 ctx,
304 []message.Message{
305 {
306 Role: message.User,
307 Parts: parts,
308 },
309 },
310 nil,
311 )
312
313 var finalResponse *provider.ProviderResponse
314 for r := range response {
315 if r.Error != nil {
316 return r.Error
317 }
318 finalResponse = r.Response
319 }
320
321 if finalResponse == nil {
322 return fmt.Errorf("no response received from title provider")
323 }
324
325 title := strings.TrimSpace(strings.ReplaceAll(finalResponse.Content, "\n", " "))
326 if title == "" {
327 return nil
328 }
329
330 session.Title = title
331 _, err = a.sessions.Save(ctx, session)
332 return err
333}
334
335func (a *agent) err(err error) AgentEvent {
336 return AgentEvent{
337 Type: AgentEventTypeError,
338 Error: err,
339 }
340}
341
342func (a *agent) Run(ctx context.Context, sessionID string, content string, attachments ...message.Attachment) (<-chan AgentEvent, error) {
343 if !a.Model().SupportsImages && attachments != nil {
344 attachments = nil
345 }
346 events := make(chan AgentEvent)
347 if a.IsSessionBusy(sessionID) {
348 existing, ok := a.promptQueue.Get(sessionID)
349 if !ok {
350 existing = []string{}
351 }
352 existing = append(existing, content)
353 a.promptQueue.Set(sessionID, existing)
354 return nil, nil
355 }
356
357 genCtx, cancel := context.WithCancel(ctx)
358
359 a.activeRequests.Set(sessionID, cancel)
360 go func() {
361 slog.Debug("Request started", "sessionID", sessionID)
362 defer log.RecoverPanic("agent.Run", func() {
363 events <- a.err(fmt.Errorf("panic while running the agent"))
364 })
365 var attachmentParts []message.ContentPart
366 for _, attachment := range attachments {
367 attachmentParts = append(attachmentParts, message.BinaryContent{Path: attachment.FilePath, MIMEType: attachment.MimeType, Data: attachment.Content})
368 }
369 result := a.processGeneration(genCtx, sessionID, content, attachmentParts)
370 if result.Error != nil && !errors.Is(result.Error, ErrRequestCancelled) && !errors.Is(result.Error, context.Canceled) {
371 slog.Error(result.Error.Error())
372 }
373 slog.Debug("Request completed", "sessionID", sessionID)
374 a.activeRequests.Del(sessionID)
375 cancel()
376 a.Publish(pubsub.CreatedEvent, result)
377 select {
378 case events <- result:
379 case <-genCtx.Done():
380 }
381 close(events)
382 }()
383 return events, nil
384}
385
386func (a *agent) processGeneration(ctx context.Context, sessionID, content string, attachmentParts []message.ContentPart) AgentEvent {
387 cfg := config.Get()
388 // List existing messages; if none, start title generation asynchronously.
389 msgs, err := a.messages.List(ctx, sessionID)
390 if err != nil {
391 return a.err(fmt.Errorf("failed to list messages: %w", err))
392 }
393 if len(msgs) == 0 {
394 go func() {
395 defer log.RecoverPanic("agent.Run", func() {
396 slog.Error("panic while generating title")
397 })
398 titleErr := a.generateTitle(ctx, sessionID, content)
399 if titleErr != nil && !errors.Is(titleErr, context.Canceled) && !errors.Is(titleErr, context.DeadlineExceeded) {
400 slog.Error("failed to generate title", "error", titleErr)
401 }
402 }()
403 }
404 session, err := a.sessions.Get(ctx, sessionID)
405 if err != nil {
406 return a.err(fmt.Errorf("failed to get session: %w", err))
407 }
408 if session.SummaryMessageID != "" {
409 summaryMsgInex := -1
410 for i, msg := range msgs {
411 if msg.ID == session.SummaryMessageID {
412 summaryMsgInex = i
413 break
414 }
415 }
416 if summaryMsgInex != -1 {
417 msgs = msgs[summaryMsgInex:]
418 msgs[0].Role = message.User
419 }
420 }
421
422 userMsg, err := a.createUserMessage(ctx, sessionID, content, attachmentParts)
423 if err != nil {
424 return a.err(fmt.Errorf("failed to create user message: %w", err))
425 }
426 // Append the new user message to the conversation history.
427 msgHistory := append(msgs, userMsg)
428
429 for {
430 // Check for cancellation before each iteration
431 select {
432 case <-ctx.Done():
433 return a.err(ctx.Err())
434 default:
435 // Continue processing
436 }
437 agentMessage, toolResults, err := a.streamAndHandleEvents(ctx, sessionID, msgHistory)
438 if err != nil {
439 if errors.Is(err, context.Canceled) {
440 agentMessage.AddFinish(message.FinishReasonCanceled, "Request cancelled", "")
441 a.messages.Update(context.Background(), agentMessage)
442 return a.err(ErrRequestCancelled)
443 }
444 return a.err(fmt.Errorf("failed to process events: %w", err))
445 }
446 if cfg.Options.Debug {
447 slog.Info("Result", "message", agentMessage.FinishReason(), "toolResults", toolResults)
448 }
449 if (agentMessage.FinishReason() == message.FinishReasonToolUse) && toolResults != nil {
450 // We are not done, we need to respond with the tool response
451 msgHistory = append(msgHistory, agentMessage, *toolResults)
452 // If there are queued prompts, process the next one
453 nextPrompt, ok := a.promptQueue.Take(sessionID)
454 if ok {
455 for _, prompt := range nextPrompt {
456 // Create a new user message for the queued prompt
457 userMsg, err := a.createUserMessage(ctx, sessionID, prompt, nil)
458 if err != nil {
459 return a.err(fmt.Errorf("failed to create user message for queued prompt: %w", err))
460 }
461 // Append the new user message to the conversation history
462 msgHistory = append(msgHistory, userMsg)
463 }
464 }
465
466 continue
467 } else if agentMessage.FinishReason() == message.FinishReasonEndTurn {
468 queuePrompts, ok := a.promptQueue.Take(sessionID)
469 if ok {
470 for _, prompt := range queuePrompts {
471 if prompt == "" {
472 continue
473 }
474 userMsg, err := a.createUserMessage(ctx, sessionID, prompt, nil)
475 if err != nil {
476 return a.err(fmt.Errorf("failed to create user message for queued prompt: %w", err))
477 }
478 msgHistory = append(msgHistory, userMsg)
479 }
480 continue
481 }
482 }
483 if agentMessage.FinishReason() == "" {
484 // Kujtim: could not track down where this is happening but this means its cancelled
485 agentMessage.AddFinish(message.FinishReasonCanceled, "Request cancelled", "")
486 _ = a.messages.Update(context.Background(), agentMessage)
487 return a.err(ErrRequestCancelled)
488 }
489 return AgentEvent{
490 Type: AgentEventTypeResponse,
491 Message: agentMessage,
492 Done: true,
493 }
494 }
495}
496
497func (a *agent) createUserMessage(ctx context.Context, sessionID, content string, attachmentParts []message.ContentPart) (message.Message, error) {
498 parts := []message.ContentPart{message.TextContent{Text: content}}
499 parts = append(parts, attachmentParts...)
500 return a.messages.Create(ctx, sessionID, message.CreateMessageParams{
501 Role: message.User,
502 Parts: parts,
503 })
504}
505
506func (a *agent) streamAndHandleEvents(ctx context.Context, sessionID string, msgHistory []message.Message) (message.Message, *message.Message, error) {
507 ctx = context.WithValue(ctx, tools.SessionIDContextKey, sessionID)
508
509 // Create the assistant message first so the spinner shows immediately
510 assistantMsg, err := a.messages.Create(ctx, sessionID, message.CreateMessageParams{
511 Role: message.Assistant,
512 Parts: []message.ContentPart{},
513 Model: a.Model().ID,
514 Provider: a.providerID,
515 })
516 if err != nil {
517 return assistantMsg, nil, fmt.Errorf("failed to create assistant message: %w", err)
518 }
519
520 // Now collect tools (which may block on MCP initialization)
521 eventChan := a.provider.StreamResponse(ctx, msgHistory, slices.Collect(a.tools.Seq()))
522
523 // Add the session and message ID into the context if needed by tools.
524 ctx = context.WithValue(ctx, tools.MessageIDContextKey, assistantMsg.ID)
525
526 // Process each event in the stream.
527 for event := range eventChan {
528 if processErr := a.processEvent(ctx, sessionID, &assistantMsg, event); processErr != nil {
529 if errors.Is(processErr, context.Canceled) {
530 a.finishMessage(context.Background(), &assistantMsg, message.FinishReasonCanceled, "Request cancelled", "")
531 } else {
532 a.finishMessage(ctx, &assistantMsg, message.FinishReasonError, "API Error", processErr.Error())
533 }
534 return assistantMsg, nil, processErr
535 }
536 if ctx.Err() != nil {
537 a.finishMessage(context.Background(), &assistantMsg, message.FinishReasonCanceled, "Request cancelled", "")
538 return assistantMsg, nil, ctx.Err()
539 }
540 }
541
542 toolResults := make([]message.ToolResult, len(assistantMsg.ToolCalls()))
543 toolCalls := assistantMsg.ToolCalls()
544 for i, toolCall := range toolCalls {
545 select {
546 case <-ctx.Done():
547 a.finishMessage(context.Background(), &assistantMsg, message.FinishReasonCanceled, "Request cancelled", "")
548 // Make all future tool calls cancelled
549 for j := i; j < len(toolCalls); j++ {
550 toolResults[j] = message.ToolResult{
551 ToolCallID: toolCalls[j].ID,
552 Content: "Tool execution canceled by user",
553 IsError: true,
554 }
555 }
556 goto out
557 default:
558 // Continue processing
559 var tool tools.BaseTool
560 for availableTool := range a.tools.Seq() {
561 if availableTool.Info().Name == toolCall.Name {
562 tool = availableTool
563 break
564 }
565 }
566
567 // Tool not found
568 if tool == nil {
569 toolResults[i] = message.ToolResult{
570 ToolCallID: toolCall.ID,
571 Content: fmt.Sprintf("Tool not found: %s", toolCall.Name),
572 IsError: true,
573 }
574 continue
575 }
576
577 // Run tool in goroutine to allow cancellation
578 type toolExecResult struct {
579 response tools.ToolResponse
580 err error
581 }
582 resultChan := make(chan toolExecResult, 1)
583
584 go func() {
585 response, err := tool.Run(ctx, tools.ToolCall{
586 ID: toolCall.ID,
587 Name: toolCall.Name,
588 Input: toolCall.Input,
589 })
590 resultChan <- toolExecResult{response: response, err: err}
591 }()
592
593 var toolResponse tools.ToolResponse
594 var toolErr error
595
596 select {
597 case <-ctx.Done():
598 a.finishMessage(context.Background(), &assistantMsg, message.FinishReasonCanceled, "Request cancelled", "")
599 // Mark remaining tool calls as cancelled
600 for j := i; j < len(toolCalls); j++ {
601 toolResults[j] = message.ToolResult{
602 ToolCallID: toolCalls[j].ID,
603 Content: "Tool execution canceled by user",
604 IsError: true,
605 }
606 }
607 goto out
608 case result := <-resultChan:
609 toolResponse = result.response
610 toolErr = result.err
611 }
612
613 if toolErr != nil {
614 slog.Error("Tool execution error", "toolCall", toolCall.ID, "error", toolErr)
615 if errors.Is(toolErr, permission.ErrorPermissionDenied) {
616 toolResults[i] = message.ToolResult{
617 ToolCallID: toolCall.ID,
618 Content: "Permission denied",
619 IsError: true,
620 }
621 for j := i + 1; j < len(toolCalls); j++ {
622 toolResults[j] = message.ToolResult{
623 ToolCallID: toolCalls[j].ID,
624 Content: "Tool execution canceled by user",
625 IsError: true,
626 }
627 }
628 a.finishMessage(ctx, &assistantMsg, message.FinishReasonPermissionDenied, "Permission denied", "")
629 break
630 }
631 }
632 toolResults[i] = message.ToolResult{
633 ToolCallID: toolCall.ID,
634 Content: toolResponse.Content,
635 Metadata: toolResponse.Metadata,
636 IsError: toolResponse.IsError,
637 }
638 }
639 }
640out:
641 if len(toolResults) == 0 {
642 return assistantMsg, nil, nil
643 }
644 parts := make([]message.ContentPart, 0)
645 for _, tr := range toolResults {
646 parts = append(parts, tr)
647 }
648 msg, err := a.messages.Create(context.Background(), assistantMsg.SessionID, message.CreateMessageParams{
649 Role: message.Tool,
650 Parts: parts,
651 Provider: a.providerID,
652 })
653 if err != nil {
654 return assistantMsg, nil, fmt.Errorf("failed to create cancelled tool message: %w", err)
655 }
656
657 return assistantMsg, &msg, err
658}
659
660func (a *agent) finishMessage(ctx context.Context, msg *message.Message, finishReason message.FinishReason, message, details string) {
661 msg.AddFinish(finishReason, message, details)
662 _ = a.messages.Update(ctx, *msg)
663}
664
665func (a *agent) processEvent(ctx context.Context, sessionID string, assistantMsg *message.Message, event provider.ProviderEvent) error {
666 select {
667 case <-ctx.Done():
668 return ctx.Err()
669 default:
670 // Continue processing.
671 }
672
673 switch event.Type {
674 case provider.EventThinkingDelta:
675 assistantMsg.AppendReasoningContent(event.Thinking)
676 return a.messages.Update(ctx, *assistantMsg)
677 case provider.EventSignatureDelta:
678 assistantMsg.AppendReasoningSignature(event.Signature)
679 return a.messages.Update(ctx, *assistantMsg)
680 case provider.EventContentDelta:
681 assistantMsg.FinishThinking()
682 assistantMsg.AppendContent(event.Content)
683 return a.messages.Update(ctx, *assistantMsg)
684 case provider.EventToolUseStart:
685 assistantMsg.FinishThinking()
686 slog.Info("Tool call started", "toolCall", event.ToolCall)
687 assistantMsg.AddToolCall(*event.ToolCall)
688 return a.messages.Update(ctx, *assistantMsg)
689 case provider.EventToolUseDelta:
690 assistantMsg.AppendToolCallInput(event.ToolCall.ID, event.ToolCall.Input)
691 return a.messages.Update(ctx, *assistantMsg)
692 case provider.EventToolUseStop:
693 slog.Info("Finished tool call", "toolCall", event.ToolCall)
694 assistantMsg.FinishToolCall(event.ToolCall.ID)
695 return a.messages.Update(ctx, *assistantMsg)
696 case provider.EventError:
697 return event.Error
698 case provider.EventComplete:
699 assistantMsg.FinishThinking()
700 assistantMsg.SetToolCalls(event.Response.ToolCalls)
701 assistantMsg.AddFinish(event.Response.FinishReason, "", "")
702 if err := a.messages.Update(ctx, *assistantMsg); err != nil {
703 return fmt.Errorf("failed to update message: %w", err)
704 }
705 return a.TrackUsage(ctx, sessionID, a.Model(), event.Response.Usage)
706 }
707
708 return nil
709}
710
711func (a *agent) TrackUsage(ctx context.Context, sessionID string, model catwalk.Model, usage provider.TokenUsage) error {
712 sess, err := a.sessions.Get(ctx, sessionID)
713 if err != nil {
714 return fmt.Errorf("failed to get session: %w", err)
715 }
716
717 cost := model.CostPer1MInCached/1e6*float64(usage.CacheCreationTokens) +
718 model.CostPer1MOutCached/1e6*float64(usage.CacheReadTokens) +
719 model.CostPer1MIn/1e6*float64(usage.InputTokens) +
720 model.CostPer1MOut/1e6*float64(usage.OutputTokens)
721
722 sess.Cost += cost
723 sess.CompletionTokens = usage.OutputTokens + usage.CacheReadTokens
724 sess.PromptTokens = usage.InputTokens + usage.CacheCreationTokens
725
726 _, err = a.sessions.Save(ctx, sess)
727 if err != nil {
728 return fmt.Errorf("failed to save session: %w", err)
729 }
730 return nil
731}
732
733func (a *agent) Summarize(ctx context.Context, sessionID string) error {
734 if a.summarizeProvider == nil {
735 return fmt.Errorf("summarize provider not available")
736 }
737
738 // Check if session is busy
739 if a.IsSessionBusy(sessionID) {
740 return ErrSessionBusy
741 }
742
743 // Create a new context with cancellation
744 summarizeCtx, cancel := context.WithCancel(ctx)
745
746 // Store the cancel function in activeRequests to allow cancellation
747 a.activeRequests.Set(sessionID+"-summarize", cancel)
748
749 go func() {
750 defer a.activeRequests.Del(sessionID + "-summarize")
751 defer cancel()
752 event := AgentEvent{
753 Type: AgentEventTypeSummarize,
754 Progress: "Starting summarization...",
755 }
756
757 a.Publish(pubsub.CreatedEvent, event)
758 // Get all messages from the session
759 msgs, err := a.messages.List(summarizeCtx, sessionID)
760 if err != nil {
761 event = AgentEvent{
762 Type: AgentEventTypeError,
763 Error: fmt.Errorf("failed to list messages: %w", err),
764 Done: true,
765 }
766 a.Publish(pubsub.CreatedEvent, event)
767 return
768 }
769 summarizeCtx = context.WithValue(summarizeCtx, tools.SessionIDContextKey, sessionID)
770
771 if len(msgs) == 0 {
772 event = AgentEvent{
773 Type: AgentEventTypeError,
774 Error: fmt.Errorf("no messages to summarize"),
775 Done: true,
776 }
777 a.Publish(pubsub.CreatedEvent, event)
778 return
779 }
780
781 event = AgentEvent{
782 Type: AgentEventTypeSummarize,
783 Progress: "Analyzing conversation...",
784 }
785 a.Publish(pubsub.CreatedEvent, event)
786
787 // Add a system message to guide the summarization
788 summarizePrompt := "Provide a detailed but concise summary of our conversation above. Focus on information that would be helpful for continuing the conversation, including what we did, what we're doing, which files we're working on, and what we're going to do next."
789
790 // Create a new message with the summarize prompt
791 promptMsg := message.Message{
792 Role: message.User,
793 Parts: []message.ContentPart{message.TextContent{Text: summarizePrompt}},
794 }
795
796 // Append the prompt to the messages
797 msgsWithPrompt := append(msgs, promptMsg)
798
799 event = AgentEvent{
800 Type: AgentEventTypeSummarize,
801 Progress: "Generating summary...",
802 }
803
804 a.Publish(pubsub.CreatedEvent, event)
805
806 // Send the messages to the summarize provider
807 response := a.summarizeProvider.StreamResponse(
808 summarizeCtx,
809 msgsWithPrompt,
810 nil,
811 )
812 var finalResponse *provider.ProviderResponse
813 for r := range response {
814 if r.Error != nil {
815 event = AgentEvent{
816 Type: AgentEventTypeError,
817 Error: fmt.Errorf("failed to summarize: %w", err),
818 Done: true,
819 }
820 a.Publish(pubsub.CreatedEvent, event)
821 return
822 }
823 finalResponse = r.Response
824 }
825
826 summary := strings.TrimSpace(finalResponse.Content)
827 if summary == "" {
828 event = AgentEvent{
829 Type: AgentEventTypeError,
830 Error: fmt.Errorf("empty summary returned"),
831 Done: true,
832 }
833 a.Publish(pubsub.CreatedEvent, event)
834 return
835 }
836 shell := shell.GetPersistentShell(config.Get().WorkingDir())
837 summary += "\n\n**Current working directory of the persistent shell**\n\n" + shell.GetWorkingDir()
838 event = AgentEvent{
839 Type: AgentEventTypeSummarize,
840 Progress: "Creating new session...",
841 }
842
843 a.Publish(pubsub.CreatedEvent, event)
844 oldSession, err := a.sessions.Get(summarizeCtx, sessionID)
845 if err != nil {
846 event = AgentEvent{
847 Type: AgentEventTypeError,
848 Error: fmt.Errorf("failed to get session: %w", err),
849 Done: true,
850 }
851
852 a.Publish(pubsub.CreatedEvent, event)
853 return
854 }
855 // Create a message in the new session with the summary
856 msg, err := a.messages.Create(summarizeCtx, oldSession.ID, message.CreateMessageParams{
857 Role: message.Assistant,
858 Parts: []message.ContentPart{
859 message.TextContent{Text: summary},
860 message.Finish{
861 Reason: message.FinishReasonEndTurn,
862 Time: time.Now().Unix(),
863 },
864 },
865 Model: a.summarizeProvider.Model().ID,
866 Provider: a.summarizeProviderID,
867 })
868 if err != nil {
869 event = AgentEvent{
870 Type: AgentEventTypeError,
871 Error: fmt.Errorf("failed to create summary message: %w", err),
872 Done: true,
873 }
874
875 a.Publish(pubsub.CreatedEvent, event)
876 return
877 }
878 oldSession.SummaryMessageID = msg.ID
879 oldSession.CompletionTokens = finalResponse.Usage.OutputTokens
880 oldSession.PromptTokens = 0
881 model := a.summarizeProvider.Model()
882 usage := finalResponse.Usage
883 cost := model.CostPer1MInCached/1e6*float64(usage.CacheCreationTokens) +
884 model.CostPer1MOutCached/1e6*float64(usage.CacheReadTokens) +
885 model.CostPer1MIn/1e6*float64(usage.InputTokens) +
886 model.CostPer1MOut/1e6*float64(usage.OutputTokens)
887 oldSession.Cost += cost
888 _, err = a.sessions.Save(summarizeCtx, oldSession)
889 if err != nil {
890 event = AgentEvent{
891 Type: AgentEventTypeError,
892 Error: fmt.Errorf("failed to save session: %w", err),
893 Done: true,
894 }
895 a.Publish(pubsub.CreatedEvent, event)
896 }
897
898 event = AgentEvent{
899 Type: AgentEventTypeSummarize,
900 SessionID: oldSession.ID,
901 Progress: "Summary complete",
902 Done: true,
903 }
904 a.Publish(pubsub.CreatedEvent, event)
905 // Send final success event with the new session ID
906 }()
907
908 return nil
909}
910
911func (a *agent) ClearQueue(sessionID string) {
912 if a.QueuedPrompts(sessionID) > 0 {
913 slog.Info("Clearing queued prompts", "session_id", sessionID)
914 a.promptQueue.Del(sessionID)
915 }
916}
917
918func (a *agent) CancelAll() {
919 if !a.IsBusy() {
920 return
921 }
922 for key := range a.activeRequests.Seq2() {
923 a.Cancel(key) // key is sessionID
924 }
925
926 timeout := time.After(5 * time.Second)
927 for a.IsBusy() {
928 select {
929 case <-timeout:
930 return
931 default:
932 time.Sleep(200 * time.Millisecond)
933 }
934 }
935}
936
937func (a *agent) UpdateModel() error {
938 cfg := config.Get()
939
940 // Get current provider configuration
941 currentProviderCfg := cfg.GetProviderForModel(a.agentCfg.Model)
942 if currentProviderCfg == nil || currentProviderCfg.ID == "" {
943 return fmt.Errorf("provider for agent %s not found in config", a.agentCfg.Name)
944 }
945
946 // Check if provider has changed
947 if string(currentProviderCfg.ID) != a.providerID {
948 // Provider changed, need to recreate the main provider
949 model := cfg.GetModelByType(a.agentCfg.Model)
950 if model.ID == "" {
951 return fmt.Errorf("model not found for agent %s", a.agentCfg.Name)
952 }
953
954 promptID := agentPromptMap[a.agentCfg.ID]
955 if promptID == "" {
956 promptID = prompt.PromptDefault
957 }
958
959 opts := []provider.ProviderClientOption{
960 provider.WithModel(a.agentCfg.Model),
961 provider.WithSystemMessage(prompt.GetPrompt(promptID, currentProviderCfg.ID, cfg.Options.ContextPaths...)),
962 }
963
964 newProvider, err := provider.NewProvider(*currentProviderCfg, opts...)
965 if err != nil {
966 return fmt.Errorf("failed to create new provider: %w", err)
967 }
968
969 // Update the provider and provider ID
970 a.provider = newProvider
971 a.providerID = string(currentProviderCfg.ID)
972 }
973
974 // Check if providers have changed for title (small) and summarize (large)
975 smallModelCfg := cfg.Models[config.SelectedModelTypeSmall]
976 var smallModelProviderCfg config.ProviderConfig
977 for p := range cfg.Providers.Seq() {
978 if p.ID == smallModelCfg.Provider {
979 smallModelProviderCfg = p
980 break
981 }
982 }
983 if smallModelProviderCfg.ID == "" {
984 return fmt.Errorf("provider %s not found in config", smallModelCfg.Provider)
985 }
986
987 largeModelCfg := cfg.Models[config.SelectedModelTypeLarge]
988 var largeModelProviderCfg config.ProviderConfig
989 for p := range cfg.Providers.Seq() {
990 if p.ID == largeModelCfg.Provider {
991 largeModelProviderCfg = p
992 break
993 }
994 }
995 if largeModelProviderCfg.ID == "" {
996 return fmt.Errorf("provider %s not found in config", largeModelCfg.Provider)
997 }
998
999 var maxTitleTokens int64 = 40
1000
1001 // if the max output is too low for the gemini provider it won't return anything
1002 if smallModelCfg.Provider == "gemini" {
1003 maxTitleTokens = 1000
1004 }
1005 // Recreate title provider
1006 titleOpts := []provider.ProviderClientOption{
1007 provider.WithModel(config.SelectedModelTypeSmall),
1008 provider.WithSystemMessage(prompt.GetPrompt(prompt.PromptTitle, smallModelProviderCfg.ID)),
1009 provider.WithMaxTokens(maxTitleTokens),
1010 }
1011 newTitleProvider, err := provider.NewProvider(smallModelProviderCfg, titleOpts...)
1012 if err != nil {
1013 return fmt.Errorf("failed to create new title provider: %w", err)
1014 }
1015 a.titleProvider = newTitleProvider
1016
1017 // Recreate summarize provider if provider changed (now large model)
1018 if string(largeModelProviderCfg.ID) != a.summarizeProviderID {
1019 largeModel := cfg.GetModelByType(config.SelectedModelTypeLarge)
1020 if largeModel == nil {
1021 return fmt.Errorf("model %s not found in provider %s", largeModelCfg.Model, largeModelProviderCfg.ID)
1022 }
1023 summarizeOpts := []provider.ProviderClientOption{
1024 provider.WithModel(config.SelectedModelTypeLarge),
1025 provider.WithSystemMessage(prompt.GetPrompt(prompt.PromptSummarizer, largeModelProviderCfg.ID)),
1026 }
1027 newSummarizeProvider, err := provider.NewProvider(largeModelProviderCfg, summarizeOpts...)
1028 if err != nil {
1029 return fmt.Errorf("failed to create new summarize provider: %w", err)
1030 }
1031 a.summarizeProvider = newSummarizeProvider
1032 a.summarizeProviderID = string(largeModelProviderCfg.ID)
1033 }
1034
1035 return nil
1036}