1package agent
2
3import (
4 "context"
5 "errors"
6 "fmt"
7 "log/slog"
8 "slices"
9 "strings"
10 "time"
11
12 "github.com/charmbracelet/catwalk/pkg/catwalk"
13 "github.com/charmbracelet/crush/internal/config"
14 "github.com/charmbracelet/crush/internal/csync"
15 "github.com/charmbracelet/crush/internal/history"
16 "github.com/charmbracelet/crush/internal/llm/prompt"
17 "github.com/charmbracelet/crush/internal/llm/provider"
18 "github.com/charmbracelet/crush/internal/llm/tools"
19 "github.com/charmbracelet/crush/internal/log"
20 "github.com/charmbracelet/crush/internal/lsp"
21 "github.com/charmbracelet/crush/internal/message"
22 "github.com/charmbracelet/crush/internal/permission"
23 "github.com/charmbracelet/crush/internal/pubsub"
24 "github.com/charmbracelet/crush/internal/session"
25 "github.com/charmbracelet/crush/internal/shell"
26)
27
// Sentinel errors reported by the agent.
var (
	// ErrRequestCancelled is reported (as an error event) when a running
	// request stops because its context was cancelled.
	ErrRequestCancelled = errors.New("request canceled by user")

	// ErrSessionBusy is returned by Summarize when the session already has
	// an active request.
	ErrSessionBusy = errors.New("session is currently processing another request")
)
33
// AgentEventType classifies an AgentEvent.
type AgentEventType string

// Kinds of AgentEvent emitted by the agent.
const (
	// AgentEventTypeError marks an event whose Error field describes a failure.
	AgentEventTypeError AgentEventType = "error"
	// AgentEventTypeResponse marks an event carrying the agent's final Message.
	AgentEventTypeResponse AgentEventType = "response"
	// AgentEventTypeSummarize marks a summarization progress or completion event.
	AgentEventTypeSummarize AgentEventType = "summarize"
)
41
// AgentEvent is what the agent sends on the channel returned by Run and
// publishes to its subscribers.
type AgentEvent struct {
	// Type tells which of the remaining fields are meaningful.
	Type AgentEventType
	// Message is the finished assistant message on AgentEventTypeResponse events.
	Message message.Message
	// Error is set on AgentEventTypeError events.
	Error error

	// When summarizing
	// SessionID is set on the final, successful summarize event.
	SessionID string
	// Progress is a human-readable status line on summarize events.
	Progress string
	// Done is set on the last event of a summarization (success or error) and
	// on successful response events; error events produced by Run leave it false.
	Done bool
}
52
// Service is the public interface of an agent: it runs prompts against the
// configured model, queues prompts for busy sessions, summarizes sessions and
// publishes AgentEvents to subscribers. The notes below describe the agent
// implementation in this package.
type Service interface {
	pubsub.Suscriber[AgentEvent]
	// Model returns the model currently configured for the agent.
	Model() catwalk.Model
	// Run processes content for sessionID in the background and returns a
	// channel for the outcome; when the session is busy the prompt is queued
	// instead and the returned channel is nil.
	Run(ctx context.Context, sessionID string, content string, attachments ...message.Attachment) (<-chan AgentEvent, error)
	// Cancel stops the session's active request and summarization and drops
	// its queued prompts.
	Cancel(sessionID string)
	// CancelAll cancels every active request.
	CancelAll()
	// IsSessionBusy reports whether sessionID has an active (non-summarize) request.
	IsSessionBusy(sessionID string) bool
	// IsBusy reports whether any request is active.
	IsBusy() bool
	// Summarize starts summarizing sessionID in the background; progress is
	// published as events.
	Summarize(ctx context.Context, sessionID string) error
	// UpdateModel re-reads the model configuration and rebuilds providers as needed.
	UpdateModel() error
	// QueuedPrompts returns how many prompts are waiting for sessionID.
	QueuedPrompts(sessionID string) int
	// ClearQueue drops the prompts waiting for sessionID.
	ClearQueue(sessionID string)
}
66
// agent is the Service implementation built by NewAgent.
type agent struct {
	*pubsub.Broker[AgentEvent]
	agentCfg config.Agent
	sessions session.Service
	messages message.Service
	// NOTE(review): not set by NewAgent nor read in this file; the tool set
	// uses the package-level mcpTools instead — confirm whether this is still needed.
	mcpTools []McpTool

	// tools is built lazily on first use (see toolFn in NewAgent).
	tools *csync.LazySlice[tools.BaseTool]

	// provider serves the agent's own model; providerID is recorded on the
	// messages it creates.
	provider   provider.Provider
	providerID string

	// titleProvider (small model) names new sessions; summarizeProvider
	// condenses a session's history.
	// NOTE(review): UpdateModel reassigns these provider fields with no
	// synchronization visible here while requests may be reading them.
	titleProvider       provider.Provider
	summarizeProvider   provider.Provider
	summarizeProviderID string

	// activeRequests maps a session ID (or "<sessionID>-summarize" for a
	// summarization) to the cancel func of its in-flight work.
	activeRequests *csync.Map[string, context.CancelFunc]

	// promptQueue holds prompts received while a session was busy, keyed by
	// session ID, in arrival order.
	promptQueue *csync.Map[string, []string]
}
87
88var agentPromptMap = map[string]prompt.PromptID{
89 "coder": prompt.PromptCoder,
90 "task": prompt.PromptTask,
91}
92
93func NewAgent(
94 ctx context.Context,
95 agentCfg config.Agent,
96 // These services are needed in the tools
97 permissions permission.Service,
98 sessions session.Service,
99 messages message.Service,
100 history history.Service,
101 lspClients map[string]*lsp.Client,
102) (Service, error) {
103 cfg := config.Get()
104
105 var agentTool tools.BaseTool
106 if agentCfg.ID == "coder" {
107 taskAgentCfg := config.Get().Agents["task"]
108 if taskAgentCfg.ID == "" {
109 return nil, fmt.Errorf("task agent not found in config")
110 }
111 taskAgent, err := NewAgent(ctx, taskAgentCfg, permissions, sessions, messages, history, lspClients)
112 if err != nil {
113 return nil, fmt.Errorf("failed to create task agent: %w", err)
114 }
115
116 agentTool = NewAgentTool(taskAgent, sessions, messages)
117 }
118
119 providerCfg := config.Get().GetProviderForModel(agentCfg.Model)
120 if providerCfg == nil {
121 return nil, fmt.Errorf("provider for agent %s not found in config", agentCfg.Name)
122 }
123 model := config.Get().GetModelByType(agentCfg.Model)
124
125 if model == nil {
126 return nil, fmt.Errorf("model not found for agent %s", agentCfg.Name)
127 }
128
129 promptID := agentPromptMap[agentCfg.ID]
130 if promptID == "" {
131 promptID = prompt.PromptDefault
132 }
133 opts := []provider.ProviderClientOption{
134 provider.WithModel(agentCfg.Model),
135 provider.WithSystemMessage(prompt.GetPrompt(promptID, providerCfg.ID, config.Get().Options.ContextPaths...)),
136 }
137 agentProvider, err := provider.NewProvider(*providerCfg, opts...)
138 if err != nil {
139 return nil, err
140 }
141
142 smallModelCfg := cfg.Models[config.SelectedModelTypeSmall]
143 var smallModelProviderCfg *config.ProviderConfig
144 if smallModelCfg.Provider == providerCfg.ID {
145 smallModelProviderCfg = providerCfg
146 } else {
147 smallModelProviderCfg = cfg.GetProviderForModel(config.SelectedModelTypeSmall)
148
149 if smallModelProviderCfg.ID == "" {
150 return nil, fmt.Errorf("provider %s not found in config", smallModelCfg.Provider)
151 }
152 }
153 smallModel := cfg.GetModelByType(config.SelectedModelTypeSmall)
154 if smallModel.ID == "" {
155 return nil, fmt.Errorf("model %s not found in provider %s", smallModelCfg.Model, smallModelProviderCfg.ID)
156 }
157
158 titleOpts := []provider.ProviderClientOption{
159 provider.WithModel(config.SelectedModelTypeSmall),
160 provider.WithSystemMessage(prompt.GetPrompt(prompt.PromptTitle, smallModelProviderCfg.ID)),
161 }
162 titleProvider, err := provider.NewProvider(*smallModelProviderCfg, titleOpts...)
163 if err != nil {
164 return nil, err
165 }
166
167 summarizeOpts := []provider.ProviderClientOption{
168 provider.WithModel(config.SelectedModelTypeLarge),
169 provider.WithSystemMessage(prompt.GetPrompt(prompt.PromptSummarizer, providerCfg.ID)),
170 }
171 summarizeProvider, err := provider.NewProvider(*providerCfg, summarizeOpts...)
172 if err != nil {
173 return nil, err
174 }
175
176 toolFn := func() []tools.BaseTool {
177 slog.Info("Initializing agent tools", "agent", agentCfg.ID)
178 defer func() {
179 slog.Info("Initialized agent tools", "agent", agentCfg.ID)
180 }()
181
182 cwd := cfg.WorkingDir()
183 allTools := []tools.BaseTool{
184 tools.NewBashTool(permissions, cwd),
185 tools.NewDownloadTool(permissions, cwd),
186 tools.NewEditTool(lspClients, permissions, history, cwd),
187 tools.NewMultiEditTool(lspClients, permissions, history, cwd),
188 tools.NewFetchTool(permissions, cwd),
189 tools.NewGlobTool(cwd),
190 tools.NewGrepTool(cwd),
191 tools.NewLsTool(permissions, cwd),
192 tools.NewSourcegraphTool(),
193 tools.NewViewTool(lspClients, permissions, cwd),
194 tools.NewWriteTool(lspClients, permissions, history, cwd),
195 }
196
197 mcpToolsOnce.Do(func() {
198 mcpTools = doGetMCPTools(ctx, permissions, cfg)
199 })
200 allTools = append(allTools, mcpTools...)
201
202 if len(lspClients) > 0 {
203 allTools = append(allTools, tools.NewDiagnosticsTool(lspClients))
204 }
205
206 if agentTool != nil {
207 allTools = append(allTools, agentTool)
208 }
209
210 if agentCfg.AllowedTools == nil {
211 return allTools
212 }
213
214 var filteredTools []tools.BaseTool
215 for _, tool := range allTools {
216 if slices.Contains(agentCfg.AllowedTools, tool.Name()) {
217 filteredTools = append(filteredTools, tool)
218 }
219 }
220 return filteredTools
221 }
222
223 return &agent{
224 Broker: pubsub.NewBroker[AgentEvent](),
225 agentCfg: agentCfg,
226 provider: agentProvider,
227 providerID: string(providerCfg.ID),
228 messages: messages,
229 sessions: sessions,
230 titleProvider: titleProvider,
231 summarizeProvider: summarizeProvider,
232 summarizeProviderID: string(providerCfg.ID),
233 activeRequests: csync.NewMap[string, context.CancelFunc](),
234 tools: csync.NewLazySlice(toolFn),
235 promptQueue: csync.NewMap[string, []string](),
236 }, nil
237}
238
239func (a *agent) Model() catwalk.Model {
240 return *config.Get().GetModelByType(a.agentCfg.Model)
241}
242
243func (a *agent) Cancel(sessionID string) {
244 // Cancel regular requests
245 if cancel, ok := a.activeRequests.Take(sessionID); ok && cancel != nil {
246 slog.Info("Request cancellation initiated", "session_id", sessionID)
247 cancel()
248 }
249
250 // Also check for summarize requests
251 if cancel, ok := a.activeRequests.Take(sessionID + "-summarize"); ok && cancel != nil {
252 slog.Info("Summarize cancellation initiated", "session_id", sessionID)
253 cancel()
254 }
255
256 if a.QueuedPrompts(sessionID) > 0 {
257 slog.Info("Clearing queued prompts", "session_id", sessionID)
258 a.promptQueue.Del(sessionID)
259 }
260}
261
262func (a *agent) IsBusy() bool {
263 var busy bool
264 for cancelFunc := range a.activeRequests.Seq() {
265 if cancelFunc != nil {
266 busy = true
267 break
268 }
269 }
270 return busy
271}
272
273func (a *agent) IsSessionBusy(sessionID string) bool {
274 _, busy := a.activeRequests.Get(sessionID)
275 return busy
276}
277
278func (a *agent) QueuedPrompts(sessionID string) int {
279 l, ok := a.promptQueue.Get(sessionID)
280 if !ok {
281 return 0
282 }
283 return len(l)
284}
285
286func (a *agent) generateTitle(ctx context.Context, sessionID string, content string) error {
287 if content == "" {
288 return nil
289 }
290 if a.titleProvider == nil {
291 return nil
292 }
293 session, err := a.sessions.Get(ctx, sessionID)
294 if err != nil {
295 return err
296 }
297 parts := []message.ContentPart{message.TextContent{
298 Text: fmt.Sprintf("Generate a concise title for the following content:\n\n%s", content),
299 }}
300
301 // Use streaming approach like summarization
302 response := a.titleProvider.StreamResponse(
303 ctx,
304 []message.Message{
305 {
306 Role: message.User,
307 Parts: parts,
308 },
309 },
310 nil,
311 )
312
313 var finalResponse *provider.ProviderResponse
314 for r := range response {
315 if r.Error != nil {
316 return r.Error
317 }
318 finalResponse = r.Response
319 }
320
321 if finalResponse == nil {
322 return fmt.Errorf("no response received from title provider")
323 }
324
325 title := strings.TrimSpace(strings.ReplaceAll(finalResponse.Content, "\n", " "))
326 if title == "" {
327 return nil
328 }
329
330 session.Title = title
331 _, err = a.sessions.Save(ctx, session)
332 return err
333}
334
335func (a *agent) err(err error) AgentEvent {
336 return AgentEvent{
337 Type: AgentEventTypeError,
338 Error: err,
339 }
340}
341
342func (a *agent) Run(ctx context.Context, sessionID string, content string, attachments ...message.Attachment) (<-chan AgentEvent, error) {
343 if !a.Model().SupportsImages && attachments != nil {
344 attachments = nil
345 }
346 events := make(chan AgentEvent, 1)
347 if a.IsSessionBusy(sessionID) {
348 existing, ok := a.promptQueue.Get(sessionID)
349 if !ok {
350 existing = []string{}
351 }
352 existing = append(existing, content)
353 a.promptQueue.Set(sessionID, existing)
354 return nil, nil
355 }
356
357 genCtx, cancel := context.WithCancel(ctx)
358
359 a.activeRequests.Set(sessionID, cancel)
360 go func() {
361 slog.Debug("Request started", "sessionID", sessionID)
362 defer log.RecoverPanic("agent.Run", func() {
363 events <- a.err(fmt.Errorf("panic while running the agent"))
364 })
365 var attachmentParts []message.ContentPart
366 for _, attachment := range attachments {
367 attachmentParts = append(attachmentParts, message.BinaryContent{Path: attachment.FilePath, MIMEType: attachment.MimeType, Data: attachment.Content})
368 }
369 result := a.processGeneration(genCtx, sessionID, content, attachmentParts)
370 if result.Error != nil && !errors.Is(result.Error, ErrRequestCancelled) && !errors.Is(result.Error, context.Canceled) {
371 slog.Error(result.Error.Error())
372 }
373 slog.Debug("Request completed", "sessionID", sessionID)
374 a.activeRequests.Del(sessionID)
375 cancel()
376 a.Publish(pubsub.CreatedEvent, result)
377 events <- result
378 close(events)
379 }()
380 return events, nil
381}
382
383func (a *agent) processGeneration(ctx context.Context, sessionID, content string, attachmentParts []message.ContentPart) AgentEvent {
384 cfg := config.Get()
385 // List existing messages; if none, start title generation asynchronously.
386 msgs, err := a.messages.List(ctx, sessionID)
387 if err != nil {
388 return a.err(fmt.Errorf("failed to list messages: %w", err))
389 }
390 if len(msgs) == 0 {
391 go func() {
392 defer log.RecoverPanic("agent.Run", func() {
393 slog.Error("panic while generating title")
394 })
395 titleErr := a.generateTitle(ctx, sessionID, content)
396 if titleErr != nil && !errors.Is(titleErr, context.Canceled) && !errors.Is(titleErr, context.DeadlineExceeded) {
397 slog.Error("failed to generate title", "error", titleErr)
398 }
399 }()
400 }
401 session, err := a.sessions.Get(ctx, sessionID)
402 if err != nil {
403 return a.err(fmt.Errorf("failed to get session: %w", err))
404 }
405 if session.SummaryMessageID != "" {
406 summaryMsgInex := -1
407 for i, msg := range msgs {
408 if msg.ID == session.SummaryMessageID {
409 summaryMsgInex = i
410 break
411 }
412 }
413 if summaryMsgInex != -1 {
414 msgs = msgs[summaryMsgInex:]
415 msgs[0].Role = message.User
416 }
417 }
418
419 userMsg, err := a.createUserMessage(ctx, sessionID, content, attachmentParts)
420 if err != nil {
421 return a.err(fmt.Errorf("failed to create user message: %w", err))
422 }
423 // Append the new user message to the conversation history.
424 msgHistory := append(msgs, userMsg)
425
426 for {
427 // Check for cancellation before each iteration
428 select {
429 case <-ctx.Done():
430 return a.err(ctx.Err())
431 default:
432 // Continue processing
433 }
434 agentMessage, toolResults, err := a.streamAndHandleEvents(ctx, sessionID, msgHistory)
435 if err != nil {
436 if errors.Is(err, context.Canceled) {
437 agentMessage.AddFinish(message.FinishReasonCanceled, "Request cancelled", "")
438 a.messages.Update(context.Background(), agentMessage)
439 return a.err(ErrRequestCancelled)
440 }
441 return a.err(fmt.Errorf("failed to process events: %w", err))
442 }
443 if cfg.Options.Debug {
444 slog.Info("Result", "message", agentMessage.FinishReason(), "toolResults", toolResults)
445 }
446 if (agentMessage.FinishReason() == message.FinishReasonToolUse) && toolResults != nil {
447 // We are not done, we need to respond with the tool response
448 msgHistory = append(msgHistory, agentMessage, *toolResults)
449 // If there are queued prompts, process the next one
450 nextPrompt, ok := a.promptQueue.Take(sessionID)
451 if ok {
452 for _, prompt := range nextPrompt {
453 // Create a new user message for the queued prompt
454 userMsg, err := a.createUserMessage(ctx, sessionID, prompt, nil)
455 if err != nil {
456 return a.err(fmt.Errorf("failed to create user message for queued prompt: %w", err))
457 }
458 // Append the new user message to the conversation history
459 msgHistory = append(msgHistory, userMsg)
460 }
461 }
462
463 continue
464 } else if agentMessage.FinishReason() == message.FinishReasonEndTurn {
465 queuePrompts, ok := a.promptQueue.Take(sessionID)
466 if ok {
467 for _, prompt := range queuePrompts {
468 if prompt == "" {
469 continue
470 }
471 userMsg, err := a.createUserMessage(ctx, sessionID, prompt, nil)
472 if err != nil {
473 return a.err(fmt.Errorf("failed to create user message for queued prompt: %w", err))
474 }
475 msgHistory = append(msgHistory, userMsg)
476 }
477 continue
478 }
479 }
480 if agentMessage.FinishReason() == "" {
481 // Kujtim: could not track down where this is happening but this means its cancelled
482 agentMessage.AddFinish(message.FinishReasonCanceled, "Request cancelled", "")
483 _ = a.messages.Update(context.Background(), agentMessage)
484 return a.err(ErrRequestCancelled)
485 }
486 return AgentEvent{
487 Type: AgentEventTypeResponse,
488 Message: agentMessage,
489 Done: true,
490 }
491 }
492}
493
494func (a *agent) createUserMessage(ctx context.Context, sessionID, content string, attachmentParts []message.ContentPart) (message.Message, error) {
495 parts := []message.ContentPart{message.TextContent{Text: content}}
496 parts = append(parts, attachmentParts...)
497 return a.messages.Create(ctx, sessionID, message.CreateMessageParams{
498 Role: message.User,
499 Parts: parts,
500 })
501}
502
503func (a *agent) streamAndHandleEvents(ctx context.Context, sessionID string, msgHistory []message.Message) (message.Message, *message.Message, error) {
504 ctx = context.WithValue(ctx, tools.SessionIDContextKey, sessionID)
505
506 // Create the assistant message first so the spinner shows immediately
507 assistantMsg, err := a.messages.Create(ctx, sessionID, message.CreateMessageParams{
508 Role: message.Assistant,
509 Parts: []message.ContentPart{},
510 Model: a.Model().ID,
511 Provider: a.providerID,
512 })
513 if err != nil {
514 return assistantMsg, nil, fmt.Errorf("failed to create assistant message: %w", err)
515 }
516
517 // Now collect tools (which may block on MCP initialization)
518 eventChan := a.provider.StreamResponse(ctx, msgHistory, slices.Collect(a.tools.Seq()))
519
520 // Add the session and message ID into the context if needed by tools.
521 ctx = context.WithValue(ctx, tools.MessageIDContextKey, assistantMsg.ID)
522
523 // Process each event in the stream.
524 for event := range eventChan {
525 if processErr := a.processEvent(ctx, sessionID, &assistantMsg, event); processErr != nil {
526 if errors.Is(processErr, context.Canceled) {
527 a.finishMessage(context.Background(), &assistantMsg, message.FinishReasonCanceled, "Request cancelled", "")
528 } else {
529 a.finishMessage(ctx, &assistantMsg, message.FinishReasonError, "API Error", processErr.Error())
530 }
531 return assistantMsg, nil, processErr
532 }
533 if ctx.Err() != nil {
534 a.finishMessage(context.Background(), &assistantMsg, message.FinishReasonCanceled, "Request cancelled", "")
535 return assistantMsg, nil, ctx.Err()
536 }
537 }
538
539 toolResults := make([]message.ToolResult, len(assistantMsg.ToolCalls()))
540 toolCalls := assistantMsg.ToolCalls()
541 for i, toolCall := range toolCalls {
542 select {
543 case <-ctx.Done():
544 a.finishMessage(context.Background(), &assistantMsg, message.FinishReasonCanceled, "Request cancelled", "")
545 // Make all future tool calls cancelled
546 for j := i; j < len(toolCalls); j++ {
547 toolResults[j] = message.ToolResult{
548 ToolCallID: toolCalls[j].ID,
549 Content: "Tool execution canceled by user",
550 IsError: true,
551 }
552 }
553 goto out
554 default:
555 // Continue processing
556 var tool tools.BaseTool
557 for availableTool := range a.tools.Seq() {
558 if availableTool.Info().Name == toolCall.Name {
559 tool = availableTool
560 break
561 }
562 }
563
564 // Tool not found
565 if tool == nil {
566 toolResults[i] = message.ToolResult{
567 ToolCallID: toolCall.ID,
568 Content: fmt.Sprintf("Tool not found: %s", toolCall.Name),
569 IsError: true,
570 }
571 continue
572 }
573
574 // Run tool in goroutine to allow cancellation
575 type toolExecResult struct {
576 response tools.ToolResponse
577 err error
578 }
579 resultChan := make(chan toolExecResult, 1)
580
581 go func() {
582 response, err := tool.Run(ctx, tools.ToolCall{
583 ID: toolCall.ID,
584 Name: toolCall.Name,
585 Input: toolCall.Input,
586 })
587 resultChan <- toolExecResult{response: response, err: err}
588 }()
589
590 var toolResponse tools.ToolResponse
591 var toolErr error
592
593 select {
594 case <-ctx.Done():
595 a.finishMessage(context.Background(), &assistantMsg, message.FinishReasonCanceled, "Request cancelled", "")
596 // Mark remaining tool calls as cancelled
597 for j := i; j < len(toolCalls); j++ {
598 toolResults[j] = message.ToolResult{
599 ToolCallID: toolCalls[j].ID,
600 Content: "Tool execution canceled by user",
601 IsError: true,
602 }
603 }
604 goto out
605 case result := <-resultChan:
606 toolResponse = result.response
607 toolErr = result.err
608 }
609
610 if toolErr != nil {
611 slog.Error("Tool execution error", "toolCall", toolCall.ID, "error", toolErr)
612 if errors.Is(toolErr, permission.ErrorPermissionDenied) {
613 toolResults[i] = message.ToolResult{
614 ToolCallID: toolCall.ID,
615 Content: "Permission denied",
616 IsError: true,
617 }
618 for j := i + 1; j < len(toolCalls); j++ {
619 toolResults[j] = message.ToolResult{
620 ToolCallID: toolCalls[j].ID,
621 Content: "Tool execution canceled by user",
622 IsError: true,
623 }
624 }
625 a.finishMessage(ctx, &assistantMsg, message.FinishReasonPermissionDenied, "Permission denied", "")
626 break
627 }
628 }
629 toolResults[i] = message.ToolResult{
630 ToolCallID: toolCall.ID,
631 Content: toolResponse.Content,
632 Metadata: toolResponse.Metadata,
633 IsError: toolResponse.IsError,
634 }
635 }
636 }
637out:
638 if len(toolResults) == 0 {
639 return assistantMsg, nil, nil
640 }
641 parts := make([]message.ContentPart, 0)
642 for _, tr := range toolResults {
643 parts = append(parts, tr)
644 }
645 msg, err := a.messages.Create(context.Background(), assistantMsg.SessionID, message.CreateMessageParams{
646 Role: message.Tool,
647 Parts: parts,
648 Provider: a.providerID,
649 })
650 if err != nil {
651 return assistantMsg, nil, fmt.Errorf("failed to create cancelled tool message: %w", err)
652 }
653
654 return assistantMsg, &msg, err
655}
656
657func (a *agent) finishMessage(ctx context.Context, msg *message.Message, finishReason message.FinishReason, message, details string) {
658 msg.AddFinish(finishReason, message, details)
659 _ = a.messages.Update(ctx, *msg)
660}
661
662func (a *agent) processEvent(ctx context.Context, sessionID string, assistantMsg *message.Message, event provider.ProviderEvent) error {
663 select {
664 case <-ctx.Done():
665 return ctx.Err()
666 default:
667 // Continue processing.
668 }
669
670 switch event.Type {
671 case provider.EventThinkingDelta:
672 assistantMsg.AppendReasoningContent(event.Thinking)
673 return a.messages.Update(ctx, *assistantMsg)
674 case provider.EventSignatureDelta:
675 assistantMsg.AppendReasoningSignature(event.Signature)
676 return a.messages.Update(ctx, *assistantMsg)
677 case provider.EventContentDelta:
678 assistantMsg.FinishThinking()
679 assistantMsg.AppendContent(event.Content)
680 return a.messages.Update(ctx, *assistantMsg)
681 case provider.EventToolUseStart:
682 assistantMsg.FinishThinking()
683 slog.Info("Tool call started", "toolCall", event.ToolCall)
684 assistantMsg.AddToolCall(*event.ToolCall)
685 return a.messages.Update(ctx, *assistantMsg)
686 case provider.EventToolUseDelta:
687 assistantMsg.AppendToolCallInput(event.ToolCall.ID, event.ToolCall.Input)
688 return a.messages.Update(ctx, *assistantMsg)
689 case provider.EventToolUseStop:
690 slog.Info("Finished tool call", "toolCall", event.ToolCall)
691 assistantMsg.FinishToolCall(event.ToolCall.ID)
692 return a.messages.Update(ctx, *assistantMsg)
693 case provider.EventError:
694 return event.Error
695 case provider.EventComplete:
696 assistantMsg.FinishThinking()
697 assistantMsg.SetToolCalls(event.Response.ToolCalls)
698 assistantMsg.AddFinish(event.Response.FinishReason, "", "")
699 if err := a.messages.Update(ctx, *assistantMsg); err != nil {
700 return fmt.Errorf("failed to update message: %w", err)
701 }
702 return a.TrackUsage(ctx, sessionID, a.Model(), event.Response.Usage)
703 }
704
705 return nil
706}
707
708func (a *agent) TrackUsage(ctx context.Context, sessionID string, model catwalk.Model, usage provider.TokenUsage) error {
709 sess, err := a.sessions.Get(ctx, sessionID)
710 if err != nil {
711 return fmt.Errorf("failed to get session: %w", err)
712 }
713
714 cost := model.CostPer1MInCached/1e6*float64(usage.CacheCreationTokens) +
715 model.CostPer1MOutCached/1e6*float64(usage.CacheReadTokens) +
716 model.CostPer1MIn/1e6*float64(usage.InputTokens) +
717 model.CostPer1MOut/1e6*float64(usage.OutputTokens)
718
719 sess.Cost += cost
720 sess.CompletionTokens = usage.OutputTokens + usage.CacheReadTokens
721 sess.PromptTokens = usage.InputTokens + usage.CacheCreationTokens
722
723 _, err = a.sessions.Save(ctx, sess)
724 if err != nil {
725 return fmt.Errorf("failed to save session: %w", err)
726 }
727 return nil
728}
729
730func (a *agent) Summarize(ctx context.Context, sessionID string) error {
731 if a.summarizeProvider == nil {
732 return fmt.Errorf("summarize provider not available")
733 }
734
735 // Check if session is busy
736 if a.IsSessionBusy(sessionID) {
737 return ErrSessionBusy
738 }
739
740 // Create a new context with cancellation
741 summarizeCtx, cancel := context.WithCancel(ctx)
742
743 // Store the cancel function in activeRequests to allow cancellation
744 a.activeRequests.Set(sessionID+"-summarize", cancel)
745
746 go func() {
747 defer a.activeRequests.Del(sessionID + "-summarize")
748 defer cancel()
749 event := AgentEvent{
750 Type: AgentEventTypeSummarize,
751 Progress: "Starting summarization...",
752 }
753
754 a.Publish(pubsub.CreatedEvent, event)
755 // Get all messages from the session
756 msgs, err := a.messages.List(summarizeCtx, sessionID)
757 if err != nil {
758 event = AgentEvent{
759 Type: AgentEventTypeError,
760 Error: fmt.Errorf("failed to list messages: %w", err),
761 Done: true,
762 }
763 a.Publish(pubsub.CreatedEvent, event)
764 return
765 }
766 summarizeCtx = context.WithValue(summarizeCtx, tools.SessionIDContextKey, sessionID)
767
768 if len(msgs) == 0 {
769 event = AgentEvent{
770 Type: AgentEventTypeError,
771 Error: fmt.Errorf("no messages to summarize"),
772 Done: true,
773 }
774 a.Publish(pubsub.CreatedEvent, event)
775 return
776 }
777
778 event = AgentEvent{
779 Type: AgentEventTypeSummarize,
780 Progress: "Analyzing conversation...",
781 }
782 a.Publish(pubsub.CreatedEvent, event)
783
784 // Add a system message to guide the summarization
785 summarizePrompt := "Provide a detailed but concise summary of our conversation above. Focus on information that would be helpful for continuing the conversation, including what we did, what we're doing, which files we're working on, and what we're going to do next."
786
787 // Create a new message with the summarize prompt
788 promptMsg := message.Message{
789 Role: message.User,
790 Parts: []message.ContentPart{message.TextContent{Text: summarizePrompt}},
791 }
792
793 // Append the prompt to the messages
794 msgsWithPrompt := append(msgs, promptMsg)
795
796 event = AgentEvent{
797 Type: AgentEventTypeSummarize,
798 Progress: "Generating summary...",
799 }
800
801 a.Publish(pubsub.CreatedEvent, event)
802
803 // Send the messages to the summarize provider
804 response := a.summarizeProvider.StreamResponse(
805 summarizeCtx,
806 msgsWithPrompt,
807 nil,
808 )
809 var finalResponse *provider.ProviderResponse
810 for r := range response {
811 if r.Error != nil {
812 event = AgentEvent{
813 Type: AgentEventTypeError,
814 Error: fmt.Errorf("failed to summarize: %w", err),
815 Done: true,
816 }
817 a.Publish(pubsub.CreatedEvent, event)
818 return
819 }
820 finalResponse = r.Response
821 }
822
823 summary := strings.TrimSpace(finalResponse.Content)
824 if summary == "" {
825 event = AgentEvent{
826 Type: AgentEventTypeError,
827 Error: fmt.Errorf("empty summary returned"),
828 Done: true,
829 }
830 a.Publish(pubsub.CreatedEvent, event)
831 return
832 }
833 shell := shell.GetPersistentShell(config.Get().WorkingDir())
834 summary += "\n\n**Current working directory of the persistent shell**\n\n" + shell.GetWorkingDir()
835 event = AgentEvent{
836 Type: AgentEventTypeSummarize,
837 Progress: "Creating new session...",
838 }
839
840 a.Publish(pubsub.CreatedEvent, event)
841 oldSession, err := a.sessions.Get(summarizeCtx, sessionID)
842 if err != nil {
843 event = AgentEvent{
844 Type: AgentEventTypeError,
845 Error: fmt.Errorf("failed to get session: %w", err),
846 Done: true,
847 }
848
849 a.Publish(pubsub.CreatedEvent, event)
850 return
851 }
852 // Create a message in the new session with the summary
853 msg, err := a.messages.Create(summarizeCtx, oldSession.ID, message.CreateMessageParams{
854 Role: message.Assistant,
855 Parts: []message.ContentPart{
856 message.TextContent{Text: summary},
857 message.Finish{
858 Reason: message.FinishReasonEndTurn,
859 Time: time.Now().Unix(),
860 },
861 },
862 Model: a.summarizeProvider.Model().ID,
863 Provider: a.summarizeProviderID,
864 })
865 if err != nil {
866 event = AgentEvent{
867 Type: AgentEventTypeError,
868 Error: fmt.Errorf("failed to create summary message: %w", err),
869 Done: true,
870 }
871
872 a.Publish(pubsub.CreatedEvent, event)
873 return
874 }
875 oldSession.SummaryMessageID = msg.ID
876 oldSession.CompletionTokens = finalResponse.Usage.OutputTokens
877 oldSession.PromptTokens = 0
878 model := a.summarizeProvider.Model()
879 usage := finalResponse.Usage
880 cost := model.CostPer1MInCached/1e6*float64(usage.CacheCreationTokens) +
881 model.CostPer1MOutCached/1e6*float64(usage.CacheReadTokens) +
882 model.CostPer1MIn/1e6*float64(usage.InputTokens) +
883 model.CostPer1MOut/1e6*float64(usage.OutputTokens)
884 oldSession.Cost += cost
885 _, err = a.sessions.Save(summarizeCtx, oldSession)
886 if err != nil {
887 event = AgentEvent{
888 Type: AgentEventTypeError,
889 Error: fmt.Errorf("failed to save session: %w", err),
890 Done: true,
891 }
892 a.Publish(pubsub.CreatedEvent, event)
893 }
894
895 event = AgentEvent{
896 Type: AgentEventTypeSummarize,
897 SessionID: oldSession.ID,
898 Progress: "Summary complete",
899 Done: true,
900 }
901 a.Publish(pubsub.CreatedEvent, event)
902 // Send final success event with the new session ID
903 }()
904
905 return nil
906}
907
908func (a *agent) ClearQueue(sessionID string) {
909 if a.QueuedPrompts(sessionID) > 0 {
910 slog.Info("Clearing queued prompts", "session_id", sessionID)
911 a.promptQueue.Del(sessionID)
912 }
913}
914
915func (a *agent) CancelAll() {
916 if !a.IsBusy() {
917 return
918 }
919 for key := range a.activeRequests.Seq2() {
920 a.Cancel(key) // key is sessionID
921 }
922
923 timeout := time.After(5 * time.Second)
924 for a.IsBusy() {
925 select {
926 case <-timeout:
927 return
928 default:
929 time.Sleep(200 * time.Millisecond)
930 }
931 }
932}
933
934func (a *agent) UpdateModel() error {
935 cfg := config.Get()
936
937 // Get current provider configuration
938 currentProviderCfg := cfg.GetProviderForModel(a.agentCfg.Model)
939 if currentProviderCfg == nil || currentProviderCfg.ID == "" {
940 return fmt.Errorf("provider for agent %s not found in config", a.agentCfg.Name)
941 }
942
943 // Check if provider has changed
944 if string(currentProviderCfg.ID) != a.providerID {
945 // Provider changed, need to recreate the main provider
946 model := cfg.GetModelByType(a.agentCfg.Model)
947 if model.ID == "" {
948 return fmt.Errorf("model not found for agent %s", a.agentCfg.Name)
949 }
950
951 promptID := agentPromptMap[a.agentCfg.ID]
952 if promptID == "" {
953 promptID = prompt.PromptDefault
954 }
955
956 opts := []provider.ProviderClientOption{
957 provider.WithModel(a.agentCfg.Model),
958 provider.WithSystemMessage(prompt.GetPrompt(promptID, currentProviderCfg.ID, cfg.Options.ContextPaths...)),
959 }
960
961 newProvider, err := provider.NewProvider(*currentProviderCfg, opts...)
962 if err != nil {
963 return fmt.Errorf("failed to create new provider: %w", err)
964 }
965
966 // Update the provider and provider ID
967 a.provider = newProvider
968 a.providerID = string(currentProviderCfg.ID)
969 }
970
971 // Check if providers have changed for title (small) and summarize (large)
972 smallModelCfg := cfg.Models[config.SelectedModelTypeSmall]
973 var smallModelProviderCfg config.ProviderConfig
974 for p := range cfg.Providers.Seq() {
975 if p.ID == smallModelCfg.Provider {
976 smallModelProviderCfg = p
977 break
978 }
979 }
980 if smallModelProviderCfg.ID == "" {
981 return fmt.Errorf("provider %s not found in config", smallModelCfg.Provider)
982 }
983
984 largeModelCfg := cfg.Models[config.SelectedModelTypeLarge]
985 var largeModelProviderCfg config.ProviderConfig
986 for p := range cfg.Providers.Seq() {
987 if p.ID == largeModelCfg.Provider {
988 largeModelProviderCfg = p
989 break
990 }
991 }
992 if largeModelProviderCfg.ID == "" {
993 return fmt.Errorf("provider %s not found in config", largeModelCfg.Provider)
994 }
995
996 var maxTitleTokens int64 = 40
997
998 // if the max output is too low for the gemini provider it won't return anything
999 if smallModelCfg.Provider == "gemini" {
1000 maxTitleTokens = 1000
1001 }
1002 // Recreate title provider
1003 titleOpts := []provider.ProviderClientOption{
1004 provider.WithModel(config.SelectedModelTypeSmall),
1005 provider.WithSystemMessage(prompt.GetPrompt(prompt.PromptTitle, smallModelProviderCfg.ID)),
1006 provider.WithMaxTokens(maxTitleTokens),
1007 }
1008 newTitleProvider, err := provider.NewProvider(smallModelProviderCfg, titleOpts...)
1009 if err != nil {
1010 return fmt.Errorf("failed to create new title provider: %w", err)
1011 }
1012 a.titleProvider = newTitleProvider
1013
1014 // Recreate summarize provider if provider changed (now large model)
1015 if string(largeModelProviderCfg.ID) != a.summarizeProviderID {
1016 largeModel := cfg.GetModelByType(config.SelectedModelTypeLarge)
1017 if largeModel == nil {
1018 return fmt.Errorf("model %s not found in provider %s", largeModelCfg.Model, largeModelProviderCfg.ID)
1019 }
1020 summarizeOpts := []provider.ProviderClientOption{
1021 provider.WithModel(config.SelectedModelTypeLarge),
1022 provider.WithSystemMessage(prompt.GetPrompt(prompt.PromptSummarizer, largeModelProviderCfg.ID)),
1023 }
1024 newSummarizeProvider, err := provider.NewProvider(largeModelProviderCfg, summarizeOpts...)
1025 if err != nil {
1026 return fmt.Errorf("failed to create new summarize provider: %w", err)
1027 }
1028 a.summarizeProvider = newSummarizeProvider
1029 a.summarizeProviderID = string(largeModelProviderCfg.ID)
1030 }
1031
1032 return nil
1033}