1package agent
2
3import (
4 "context"
5 "errors"
6 "fmt"
7 "log/slog"
8 "slices"
9 "strings"
10 "time"
11
12 "github.com/charmbracelet/catwalk/pkg/catwalk"
13 "github.com/charmbracelet/crush/internal/config"
14 "github.com/charmbracelet/crush/internal/csync"
15 "github.com/charmbracelet/crush/internal/history"
16 "github.com/charmbracelet/crush/internal/llm/prompt"
17 "github.com/charmbracelet/crush/internal/llm/provider"
18 "github.com/charmbracelet/crush/internal/llm/tools"
19 "github.com/charmbracelet/crush/internal/log"
20 "github.com/charmbracelet/crush/internal/lsp"
21 "github.com/charmbracelet/crush/internal/message"
22 "github.com/charmbracelet/crush/internal/permission"
23 "github.com/charmbracelet/crush/internal/proto"
24 "github.com/charmbracelet/crush/internal/pubsub"
25 "github.com/charmbracelet/crush/internal/session"
26 "github.com/charmbracelet/crush/internal/shell"
27)
28
// Common errors
var (
	// ErrRequestCancelled is the error carried by the AgentEvent that
	// processGeneration returns when a generation is cancelled.
	ErrRequestCancelled = errors.New("request canceled by user")
	// ErrSessionBusy is returned by Summarize when the session already has an
	// active request.
	ErrSessionBusy = errors.New("session is currently processing another request")
)
34
// Aliases that expose the proto event types under this package's names.
type (
	AgentEventType = proto.AgentEventType
	AgentEvent     = proto.AgentEvent
)
39
// Event type values re-exported from proto.
const (
	// AgentEventTypeError marks events that carry an error.
	AgentEventTypeError = proto.AgentEventTypeError
	// AgentEventTypeResponse marks the final event of a completed generation.
	AgentEventTypeResponse = proto.AgentEventTypeResponse
	// AgentEventTypeSummarize marks progress and completion events published
	// by Summarize.
	AgentEventTypeSummarize = proto.AgentEventTypeSummarize
)
45
// Service is the public surface of an agent: it runs prompts against a
// session, reports and cancels in-flight work, and publishes AgentEvents to
// subscribers.
type Service interface {
	pubsub.Suscriber[AgentEvent]
	// Model returns the model currently selected for the agent.
	Model() catwalk.Model
	// Run starts a generation for sessionID in the background and returns a
	// channel that receives the final event. If the session is already busy
	// the prompt is queued instead and both return values are nil.
	Run(ctx context.Context, sessionID string, content string, attachments ...message.Attachment) (<-chan AgentEvent, error)
	// Cancel cancels the active generation and summarization for sessionID and
	// drops its queued prompts.
	Cancel(sessionID string)
	// CancelAll cancels every active request and waits (up to five seconds)
	// for them to finish.
	CancelAll()
	// IsSessionBusy reports whether sessionID has an active generation.
	IsSessionBusy(sessionID string) bool
	// IsBusy reports whether any request is active.
	IsBusy() bool
	// Summarize starts summarizing the session in the background; progress and
	// errors are delivered as published events. It returns ErrSessionBusy when
	// the session has an active generation.
	Summarize(ctx context.Context, sessionID string) error
	// UpdateModel rebuilds the agent's providers from the current config.
	UpdateModel() error
	// QueuedPrompts returns the number of prompts queued for sessionID.
	QueuedPrompts(sessionID string) int
	// ClearQueue drops the prompts queued for sessionID.
	ClearQueue(sessionID string)
}
59
// agent is the Service implementation returned by NewAgent.
type agent struct {
	// Broker publishes AgentEvents to subscribers.
	*pubsub.Broker[AgentEvent]
	agentCfg config.Agent
	sessions session.Service
	messages message.Service
	// NOTE(review): mcpTools is not assigned by NewAgent in this file — confirm
	// whether it is still used.
	mcpTools []McpTool
	cfg      *config.Config

	// tools is the lazily built list of base tools available to the agent.
	tools *csync.LazySlice[tools.BaseTool]
	// We need this to be able to update it when model changes
	agentToolFn func() (tools.BaseTool, error)

	// provider serves the main generation; providerID is the ID of the
	// provider config it was built from.
	provider   provider.Provider
	providerID string

	// titleProvider generates session titles; summarizeProvider produces
	// conversation summaries.
	titleProvider       provider.Provider
	summarizeProvider   provider.Provider
	summarizeProviderID string

	// activeRequests holds the cancel function of each in-flight request,
	// keyed by session ID (or session ID + "-summarize" for summarization).
	activeRequests *csync.Map[string, context.CancelFunc]
	// promptQueue holds prompts submitted while a session was busy, keyed by
	// session ID.
	promptQueue *csync.Map[string, []string]
}
82
// agentPromptMap maps an agent ID to the system prompt it uses. Agents that
// are not listed fall back to prompt.PromptDefault (see NewAgent and
// UpdateModel).
var agentPromptMap = map[string]prompt.PromptID{
	"coder": prompt.PromptCoder,
	"task":  prompt.PromptTask,
}
87
88func NewAgent(
89 ctx context.Context,
90 cfg *config.Config,
91 agentCfg config.Agent,
92 // These services are needed in the tools
93 permissions permission.Service,
94 sessions session.Service,
95 messages message.Service,
96 history history.Service,
97 lspClients *csync.Map[string, *lsp.Client],
98) (Service, error) {
99 var agentToolFn func() (tools.BaseTool, error)
100 if agentCfg.ID == "coder" {
101 agentToolFn = func() (tools.BaseTool, error) {
102 taskAgentCfg := cfg.Agents["task"]
103 if taskAgentCfg.ID == "" {
104 return nil, fmt.Errorf("task agent not found in config")
105 }
106 taskAgent, err := NewAgent(ctx, cfg, taskAgentCfg, permissions, sessions, messages, history, lspClients)
107 if err != nil {
108 return nil, fmt.Errorf("failed to create task agent: %w", err)
109 }
110 return NewAgentTool(taskAgent, sessions, messages), nil
111 }
112 }
113
114 providerCfg := cfg.GetProviderForModel(agentCfg.Model)
115 if providerCfg == nil {
116 return nil, fmt.Errorf("provider for agent %s not found in config", agentCfg.Name)
117 }
118 model := cfg.GetModelByType(agentCfg.Model)
119
120 if model == nil {
121 return nil, fmt.Errorf("model not found for agent %s", agentCfg.Name)
122 }
123
124 promptID := agentPromptMap[agentCfg.ID]
125 if promptID == "" {
126 promptID = prompt.PromptDefault
127 }
128 opts := []provider.ProviderClientOption{
129 provider.WithModel(agentCfg.Model),
130 provider.WithSystemMessage(prompt.GetPrompt(cfg, promptID, providerCfg.ID, cfg.Options.ContextPaths...)),
131 }
132 agentProvider, err := provider.NewProvider(cfg, *providerCfg, opts...)
133 if err != nil {
134 return nil, err
135 }
136
137 smallModelCfg := cfg.Models[config.SelectedModelTypeSmall]
138 var smallModelProviderCfg *config.ProviderConfig
139 if smallModelCfg.Provider == providerCfg.ID {
140 smallModelProviderCfg = providerCfg
141 } else {
142 smallModelProviderCfg = cfg.GetProviderForModel(config.SelectedModelTypeSmall)
143
144 if smallModelProviderCfg.ID == "" {
145 return nil, fmt.Errorf("provider %s not found in config", smallModelCfg.Provider)
146 }
147 }
148 smallModel := cfg.GetModelByType(config.SelectedModelTypeSmall)
149 if smallModel.ID == "" {
150 return nil, fmt.Errorf("model %s not found in provider %s", smallModelCfg.Model, smallModelProviderCfg.ID)
151 }
152
153 titleOpts := []provider.ProviderClientOption{
154 provider.WithModel(config.SelectedModelTypeSmall),
155 provider.WithSystemMessage(prompt.GetPrompt(cfg, prompt.PromptTitle, smallModelProviderCfg.ID)),
156 }
157 titleProvider, err := provider.NewProvider(cfg, *smallModelProviderCfg, titleOpts...)
158 if err != nil {
159 return nil, err
160 }
161
162 summarizeOpts := []provider.ProviderClientOption{
163 provider.WithModel(config.SelectedModelTypeLarge),
164 provider.WithSystemMessage(prompt.GetPrompt(cfg, prompt.PromptSummarizer, providerCfg.ID)),
165 }
166 summarizeProvider, err := provider.NewProvider(cfg, *providerCfg, summarizeOpts...)
167 if err != nil {
168 return nil, err
169 }
170
171 toolFn := func() []tools.BaseTool {
172 slog.Info("Initializing agent tools", "agent", agentCfg.ID)
173 defer func() {
174 slog.Info("Initialized agent tools", "agent", agentCfg.ID)
175 }()
176
177 cwd := cfg.WorkingDir()
178 allTools := []tools.BaseTool{
179 tools.NewBashTool(permissions, cwd, cfg.Options.Attribution),
180 tools.NewDownloadTool(permissions, cwd),
181 tools.NewEditTool(lspClients, permissions, history, cwd),
182 tools.NewMultiEditTool(lspClients, permissions, history, cwd),
183 tools.NewFetchTool(permissions, cwd),
184 tools.NewGlobTool(cwd),
185 tools.NewGrepTool(cwd),
186 tools.NewLsTool(permissions, cwd),
187 tools.NewSourcegraphTool(),
188 tools.NewViewTool(lspClients, permissions, cwd),
189 tools.NewWriteTool(lspClients, permissions, history, cwd),
190 }
191
192 mcpToolsOnce.Do(func() {
193 mcpTools = doGetMCPTools(ctx, permissions, cfg)
194 })
195
196 withCoderTools := func(t []tools.BaseTool) []tools.BaseTool {
197 if agentCfg.ID == "coder" {
198 t = append(t, mcpTools...)
199 if lspClients.Len() > 0 {
200 t = append(t, tools.NewDiagnosticsTool(lspClients))
201 }
202 }
203 return t
204 }
205
206 if agentCfg.AllowedTools == nil {
207 return withCoderTools(allTools)
208 }
209
210 var filteredTools []tools.BaseTool
211 for _, tool := range allTools {
212 if slices.Contains(agentCfg.AllowedTools, tool.Name()) {
213 filteredTools = append(filteredTools, tool)
214 }
215 }
216 return withCoderTools(filteredTools)
217 }
218
219 return &agent{
220 Broker: pubsub.NewBroker[AgentEvent](),
221 agentCfg: agentCfg,
222 provider: agentProvider,
223 providerID: string(providerCfg.ID),
224 messages: messages,
225 sessions: sessions,
226 titleProvider: titleProvider,
227 summarizeProvider: summarizeProvider,
228 summarizeProviderID: string(providerCfg.ID),
229 agentToolFn: agentToolFn,
230 activeRequests: csync.NewMap[string, context.CancelFunc](),
231 tools: csync.NewLazySlice(toolFn),
232 promptQueue: csync.NewMap[string, []string](),
233 cfg: cfg,
234 }, nil
235}
236
237func (a *agent) Model() catwalk.Model {
238 return *a.cfg.GetModelByType(a.agentCfg.Model)
239}
240
241func (a *agent) Cancel(sessionID string) {
242 // Cancel regular requests
243 if cancel, ok := a.activeRequests.Take(sessionID); ok && cancel != nil {
244 slog.Info("Request cancellation initiated", "session_id", sessionID)
245 cancel()
246 }
247
248 // Also check for summarize requests
249 if cancel, ok := a.activeRequests.Take(sessionID + "-summarize"); ok && cancel != nil {
250 slog.Info("Summarize cancellation initiated", "session_id", sessionID)
251 cancel()
252 }
253
254 if a.QueuedPrompts(sessionID) > 0 {
255 slog.Info("Clearing queued prompts", "session_id", sessionID)
256 a.promptQueue.Del(sessionID)
257 }
258}
259
260func (a *agent) IsBusy() bool {
261 var busy bool
262 for cancelFunc := range a.activeRequests.Seq() {
263 if cancelFunc != nil {
264 busy = true
265 break
266 }
267 }
268 return busy
269}
270
271func (a *agent) IsSessionBusy(sessionID string) bool {
272 _, busy := a.activeRequests.Get(sessionID)
273 return busy
274}
275
276func (a *agent) QueuedPrompts(sessionID string) int {
277 l, ok := a.promptQueue.Get(sessionID)
278 if !ok {
279 return 0
280 }
281 return len(l)
282}
283
284func (a *agent) generateTitle(ctx context.Context, sessionID string, content string) error {
285 if content == "" {
286 return nil
287 }
288 if a.titleProvider == nil {
289 return nil
290 }
291 session, err := a.sessions.Get(ctx, sessionID)
292 if err != nil {
293 return err
294 }
295 parts := []message.ContentPart{message.TextContent{
296 Text: fmt.Sprintf("Generate a concise title for the following content:\n\n%s", content),
297 }}
298
299 // Use streaming approach like summarization
300 response := a.titleProvider.StreamResponse(
301 ctx,
302 []message.Message{
303 {
304 Role: message.User,
305 Parts: parts,
306 },
307 },
308 nil,
309 )
310
311 var finalResponse *provider.ProviderResponse
312 for r := range response {
313 if r.Error != nil {
314 return r.Error
315 }
316 finalResponse = r.Response
317 }
318
319 if finalResponse == nil {
320 return fmt.Errorf("no response received from title provider")
321 }
322
323 title := strings.ReplaceAll(finalResponse.Content, "\n", " ")
324
325 if idx := strings.Index(title, "</think>"); idx > 0 {
326 title = title[idx+len("</think>"):]
327 }
328
329 title = strings.TrimSpace(title)
330 if title == "" {
331 return nil
332 }
333
334 session.Title = title
335 _, err = a.sessions.Save(ctx, session)
336 return err
337}
338
339func (a *agent) err(err error) AgentEvent {
340 return AgentEvent{
341 Type: AgentEventTypeError,
342 Error: err,
343 }
344}
345
346func (a *agent) Run(ctx context.Context, sessionID string, content string, attachments ...message.Attachment) (<-chan AgentEvent, error) {
347 if !a.Model().SupportsImages && attachments != nil {
348 attachments = nil
349 }
350 events := make(chan AgentEvent, 1)
351 if a.IsSessionBusy(sessionID) {
352 existing, ok := a.promptQueue.Get(sessionID)
353 if !ok {
354 existing = []string{}
355 }
356 existing = append(existing, content)
357 a.promptQueue.Set(sessionID, existing)
358 return nil, nil
359 }
360
361 genCtx, cancel := context.WithCancel(ctx)
362
363 a.activeRequests.Set(sessionID, cancel)
364 go func() {
365 slog.Debug("Request started", "sessionID", sessionID)
366 defer log.RecoverPanic("agent.Run", func() {
367 events <- a.err(fmt.Errorf("panic while running the agent"))
368 })
369 var attachmentParts []message.ContentPart
370 for _, attachment := range attachments {
371 attachmentParts = append(attachmentParts, message.BinaryContent{Path: attachment.FilePath, MIMEType: attachment.MimeType, Data: attachment.Content})
372 }
373 result := a.processGeneration(genCtx, sessionID, content, attachmentParts)
374 if result.Error != nil && !errors.Is(result.Error, ErrRequestCancelled) && !errors.Is(result.Error, context.Canceled) {
375 slog.Error(result.Error.Error())
376 }
377 slog.Debug("Request completed", "sessionID", sessionID)
378 a.activeRequests.Del(sessionID)
379 cancel()
380 a.Publish(pubsub.CreatedEvent, result)
381 events <- result
382 close(events)
383 }()
384 return events, nil
385}
386
387func (a *agent) processGeneration(ctx context.Context, sessionID, content string, attachmentParts []message.ContentPart) AgentEvent {
388 // List existing messages; if none, start title generation asynchronously.
389 msgs, err := a.messages.List(ctx, sessionID)
390 if err != nil {
391 return a.err(fmt.Errorf("failed to list messages: %w", err))
392 }
393 if len(msgs) == 0 {
394 go func() {
395 defer log.RecoverPanic("agent.Run", func() {
396 slog.Error("panic while generating title")
397 })
398 titleErr := a.generateTitle(ctx, sessionID, content)
399 if titleErr != nil && !errors.Is(titleErr, context.Canceled) && !errors.Is(titleErr, context.DeadlineExceeded) {
400 slog.Error("failed to generate title", "error", titleErr)
401 }
402 }()
403 }
404 session, err := a.sessions.Get(ctx, sessionID)
405 if err != nil {
406 return a.err(fmt.Errorf("failed to get session: %w", err))
407 }
408 if session.SummaryMessageID != "" {
409 summaryMsgInex := -1
410 for i, msg := range msgs {
411 if msg.ID == session.SummaryMessageID {
412 summaryMsgInex = i
413 break
414 }
415 }
416 if summaryMsgInex != -1 {
417 msgs = msgs[summaryMsgInex:]
418 msgs[0].Role = message.User
419 }
420 }
421
422 userMsg, err := a.createUserMessage(ctx, sessionID, content, attachmentParts)
423 if err != nil {
424 return a.err(fmt.Errorf("failed to create user message: %w", err))
425 }
426 // Append the new user message to the conversation history.
427 msgHistory := append(msgs, userMsg)
428
429 for {
430 // Check for cancellation before each iteration
431 select {
432 case <-ctx.Done():
433 return a.err(ctx.Err())
434 default:
435 // Continue processing
436 }
437 agentMessage, toolResults, err := a.streamAndHandleEvents(ctx, sessionID, msgHistory)
438 if err != nil {
439 if errors.Is(err, context.Canceled) {
440 agentMessage.AddFinish(message.FinishReasonCanceled, "Request cancelled", "")
441 a.messages.Update(context.Background(), agentMessage)
442 return a.err(ErrRequestCancelled)
443 }
444 return a.err(fmt.Errorf("failed to process events: %w", err))
445 }
446 if a.cfg.Options.Debug {
447 slog.Info("Result", "message", agentMessage.FinishReason(), "toolResults", toolResults)
448 }
449 if (agentMessage.FinishReason() == message.FinishReasonToolUse) && toolResults != nil {
450 // We are not done, we need to respond with the tool response
451 msgHistory = append(msgHistory, agentMessage, *toolResults)
452 // If there are queued prompts, process the next one
453 nextPrompt, ok := a.promptQueue.Take(sessionID)
454 if ok {
455 for _, prompt := range nextPrompt {
456 // Create a new user message for the queued prompt
457 userMsg, err := a.createUserMessage(ctx, sessionID, prompt, nil)
458 if err != nil {
459 return a.err(fmt.Errorf("failed to create user message for queued prompt: %w", err))
460 }
461 // Append the new user message to the conversation history
462 msgHistory = append(msgHistory, userMsg)
463 }
464 }
465
466 continue
467 } else if agentMessage.FinishReason() == message.FinishReasonEndTurn {
468 queuePrompts, ok := a.promptQueue.Take(sessionID)
469 if ok {
470 for _, prompt := range queuePrompts {
471 if prompt == "" {
472 continue
473 }
474 userMsg, err := a.createUserMessage(ctx, sessionID, prompt, nil)
475 if err != nil {
476 return a.err(fmt.Errorf("failed to create user message for queued prompt: %w", err))
477 }
478 msgHistory = append(msgHistory, userMsg)
479 }
480 continue
481 }
482 }
483 if agentMessage.FinishReason() == "" {
484 // Kujtim: could not track down where this is happening but this means its cancelled
485 agentMessage.AddFinish(message.FinishReasonCanceled, "Request cancelled", "")
486 _ = a.messages.Update(context.Background(), agentMessage)
487 return a.err(ErrRequestCancelled)
488 }
489 return AgentEvent{
490 Type: AgentEventTypeResponse,
491 Message: agentMessage,
492 Done: true,
493 }
494 }
495}
496
497func (a *agent) createUserMessage(ctx context.Context, sessionID, content string, attachmentParts []message.ContentPart) (message.Message, error) {
498 parts := []message.ContentPart{message.TextContent{Text: content}}
499 parts = append(parts, attachmentParts...)
500 return a.messages.Create(ctx, sessionID, message.CreateMessageParams{
501 Role: message.User,
502 Parts: parts,
503 })
504}
505
506func (a *agent) getAllTools() ([]tools.BaseTool, error) {
507 allTools := slices.Collect(a.tools.Seq())
508 if a.agentToolFn != nil {
509 agentTool, agentToolErr := a.agentToolFn()
510 if agentToolErr != nil {
511 return nil, agentToolErr
512 }
513 allTools = append(allTools, agentTool)
514 }
515 return allTools, nil
516}
517
518func (a *agent) streamAndHandleEvents(ctx context.Context, sessionID string, msgHistory []message.Message) (message.Message, *message.Message, error) {
519 ctx = context.WithValue(ctx, tools.SessionIDContextKey, sessionID)
520
521 // Create the assistant message first so the spinner shows immediately
522 assistantMsg, err := a.messages.Create(ctx, sessionID, message.CreateMessageParams{
523 Role: message.Assistant,
524 Parts: []message.ContentPart{},
525 Model: a.Model().ID,
526 Provider: a.providerID,
527 })
528 if err != nil {
529 return assistantMsg, nil, fmt.Errorf("failed to create assistant message: %w", err)
530 }
531
532 allTools, toolsErr := a.getAllTools()
533 if toolsErr != nil {
534 return assistantMsg, nil, toolsErr
535 }
536 // Now collect tools (which may block on MCP initialization)
537 eventChan := a.provider.StreamResponse(ctx, msgHistory, allTools)
538
539 // Add the session and message ID into the context if needed by tools.
540 ctx = context.WithValue(ctx, tools.MessageIDContextKey, assistantMsg.ID)
541
542 // Process each event in the stream.
543 for event := range eventChan {
544 if processErr := a.processEvent(ctx, sessionID, &assistantMsg, event); processErr != nil {
545 if errors.Is(processErr, context.Canceled) {
546 a.finishMessage(context.Background(), &assistantMsg, message.FinishReasonCanceled, "Request cancelled", "")
547 } else {
548 a.finishMessage(ctx, &assistantMsg, message.FinishReasonError, "API Error", processErr.Error())
549 }
550 return assistantMsg, nil, processErr
551 }
552 if ctx.Err() != nil {
553 a.finishMessage(context.Background(), &assistantMsg, message.FinishReasonCanceled, "Request cancelled", "")
554 return assistantMsg, nil, ctx.Err()
555 }
556 }
557
558 toolResults := make([]message.ToolResult, len(assistantMsg.ToolCalls()))
559 toolCalls := assistantMsg.ToolCalls()
560 for i, toolCall := range toolCalls {
561 select {
562 case <-ctx.Done():
563 a.finishMessage(context.Background(), &assistantMsg, message.FinishReasonCanceled, "Request cancelled", "")
564 // Make all future tool calls cancelled
565 for j := i; j < len(toolCalls); j++ {
566 toolResults[j] = message.ToolResult{
567 ToolCallID: toolCalls[j].ID,
568 Content: "Tool execution canceled by user",
569 IsError: true,
570 }
571 }
572 goto out
573 default:
574 // Continue processing
575 var tool tools.BaseTool
576 allTools, _ := a.getAllTools()
577 for _, availableTool := range allTools {
578 if availableTool.Info().Name == toolCall.Name {
579 tool = availableTool
580 break
581 }
582 }
583
584 // Tool not found
585 if tool == nil {
586 toolResults[i] = message.ToolResult{
587 ToolCallID: toolCall.ID,
588 Content: fmt.Sprintf("Tool not found: %s", toolCall.Name),
589 IsError: true,
590 }
591 continue
592 }
593
594 // Run tool in goroutine to allow cancellation
595 type toolExecResult struct {
596 response tools.ToolResponse
597 err error
598 }
599 resultChan := make(chan toolExecResult, 1)
600
601 go func() {
602 response, err := tool.Run(ctx, tools.ToolCall{
603 ID: toolCall.ID,
604 Name: toolCall.Name,
605 Input: toolCall.Input,
606 })
607 resultChan <- toolExecResult{response: response, err: err}
608 }()
609
610 var toolResponse tools.ToolResponse
611 var toolErr error
612
613 select {
614 case <-ctx.Done():
615 a.finishMessage(context.Background(), &assistantMsg, message.FinishReasonCanceled, "Request cancelled", "")
616 // Mark remaining tool calls as cancelled
617 for j := i; j < len(toolCalls); j++ {
618 toolResults[j] = message.ToolResult{
619 ToolCallID: toolCalls[j].ID,
620 Content: "Tool execution canceled by user",
621 IsError: true,
622 }
623 }
624 goto out
625 case result := <-resultChan:
626 toolResponse = result.response
627 toolErr = result.err
628 }
629
630 if toolErr != nil {
631 slog.Error("Tool execution error", "toolCall", toolCall.ID, "error", toolErr)
632 if errors.Is(toolErr, permission.ErrorPermissionDenied) {
633 toolResults[i] = message.ToolResult{
634 ToolCallID: toolCall.ID,
635 Content: "Permission denied",
636 IsError: true,
637 }
638 for j := i + 1; j < len(toolCalls); j++ {
639 toolResults[j] = message.ToolResult{
640 ToolCallID: toolCalls[j].ID,
641 Content: "Tool execution canceled by user",
642 IsError: true,
643 }
644 }
645 a.finishMessage(ctx, &assistantMsg, message.FinishReasonPermissionDenied, "Permission denied", "")
646 break
647 }
648 }
649 toolResults[i] = message.ToolResult{
650 ToolCallID: toolCall.ID,
651 Content: toolResponse.Content,
652 Metadata: toolResponse.Metadata,
653 IsError: toolResponse.IsError,
654 }
655 }
656 }
657out:
658 if len(toolResults) == 0 {
659 return assistantMsg, nil, nil
660 }
661 parts := make([]message.ContentPart, 0)
662 for _, tr := range toolResults {
663 parts = append(parts, tr)
664 }
665 msg, err := a.messages.Create(context.Background(), assistantMsg.SessionID, message.CreateMessageParams{
666 Role: message.Tool,
667 Parts: parts,
668 Provider: a.providerID,
669 })
670 if err != nil {
671 return assistantMsg, nil, fmt.Errorf("failed to create cancelled tool message: %w", err)
672 }
673
674 return assistantMsg, &msg, err
675}
676
677func (a *agent) finishMessage(ctx context.Context, msg *message.Message, finishReason message.FinishReason, message, details string) {
678 msg.AddFinish(finishReason, message, details)
679 _ = a.messages.Update(ctx, *msg)
680}
681
682func (a *agent) processEvent(ctx context.Context, sessionID string, assistantMsg *message.Message, event provider.ProviderEvent) error {
683 select {
684 case <-ctx.Done():
685 return ctx.Err()
686 default:
687 // Continue processing.
688 }
689
690 switch event.Type {
691 case provider.EventThinkingDelta:
692 assistantMsg.AppendReasoningContent(event.Thinking)
693 return a.messages.Update(ctx, *assistantMsg)
694 case provider.EventSignatureDelta:
695 assistantMsg.AppendReasoningSignature(event.Signature)
696 return a.messages.Update(ctx, *assistantMsg)
697 case provider.EventContentDelta:
698 assistantMsg.FinishThinking()
699 assistantMsg.AppendContent(event.Content)
700 return a.messages.Update(ctx, *assistantMsg)
701 case provider.EventToolUseStart:
702 assistantMsg.FinishThinking()
703 slog.Info("Tool call started", "toolCall", event.ToolCall)
704 assistantMsg.AddToolCall(*event.ToolCall)
705 return a.messages.Update(ctx, *assistantMsg)
706 case provider.EventToolUseDelta:
707 assistantMsg.AppendToolCallInput(event.ToolCall.ID, event.ToolCall.Input)
708 return a.messages.Update(ctx, *assistantMsg)
709 case provider.EventToolUseStop:
710 slog.Info("Finished tool call", "toolCall", event.ToolCall)
711 assistantMsg.FinishToolCall(event.ToolCall.ID)
712 return a.messages.Update(ctx, *assistantMsg)
713 case provider.EventError:
714 return event.Error
715 case provider.EventComplete:
716 assistantMsg.FinishThinking()
717 assistantMsg.SetToolCalls(event.Response.ToolCalls)
718 assistantMsg.AddFinish(event.Response.FinishReason, "", "")
719 if err := a.messages.Update(ctx, *assistantMsg); err != nil {
720 return fmt.Errorf("failed to update message: %w", err)
721 }
722 return a.TrackUsage(ctx, sessionID, a.Model(), event.Response.Usage)
723 }
724
725 return nil
726}
727
728func (a *agent) TrackUsage(ctx context.Context, sessionID string, model catwalk.Model, usage provider.TokenUsage) error {
729 sess, err := a.sessions.Get(ctx, sessionID)
730 if err != nil {
731 return fmt.Errorf("failed to get session: %w", err)
732 }
733
734 cost := model.CostPer1MInCached/1e6*float64(usage.CacheCreationTokens) +
735 model.CostPer1MOutCached/1e6*float64(usage.CacheReadTokens) +
736 model.CostPer1MIn/1e6*float64(usage.InputTokens) +
737 model.CostPer1MOut/1e6*float64(usage.OutputTokens)
738
739 sess.Cost += cost
740 sess.CompletionTokens = usage.OutputTokens + usage.CacheReadTokens
741 sess.PromptTokens = usage.InputTokens + usage.CacheCreationTokens
742
743 _, err = a.sessions.Save(ctx, sess)
744 if err != nil {
745 return fmt.Errorf("failed to save session: %w", err)
746 }
747 return nil
748}
749
750func (a *agent) Summarize(ctx context.Context, sessionID string) error {
751 if a.summarizeProvider == nil {
752 return fmt.Errorf("summarize provider not available")
753 }
754
755 // Check if session is busy
756 if a.IsSessionBusy(sessionID) {
757 return ErrSessionBusy
758 }
759
760 // Create a new context with cancellation
761 summarizeCtx, cancel := context.WithCancel(ctx)
762
763 // Store the cancel function in activeRequests to allow cancellation
764 a.activeRequests.Set(sessionID+"-summarize", cancel)
765
766 go func() {
767 defer a.activeRequests.Del(sessionID + "-summarize")
768 defer cancel()
769 event := AgentEvent{
770 Type: AgentEventTypeSummarize,
771 Progress: "Starting summarization...",
772 }
773
774 a.Publish(pubsub.CreatedEvent, event)
775 // Get all messages from the session
776 msgs, err := a.messages.List(summarizeCtx, sessionID)
777 if err != nil {
778 event = AgentEvent{
779 Type: AgentEventTypeError,
780 Error: fmt.Errorf("failed to list messages: %w", err),
781 Done: true,
782 }
783 a.Publish(pubsub.CreatedEvent, event)
784 return
785 }
786 summarizeCtx = context.WithValue(summarizeCtx, tools.SessionIDContextKey, sessionID)
787
788 if len(msgs) == 0 {
789 event = AgentEvent{
790 Type: AgentEventTypeError,
791 Error: fmt.Errorf("no messages to summarize"),
792 Done: true,
793 }
794 a.Publish(pubsub.CreatedEvent, event)
795 return
796 }
797
798 event = AgentEvent{
799 Type: AgentEventTypeSummarize,
800 Progress: "Analyzing conversation...",
801 }
802 a.Publish(pubsub.CreatedEvent, event)
803
804 // Add a system message to guide the summarization
805 summarizePrompt := "Provide a detailed but concise summary of our conversation above. Focus on information that would be helpful for continuing the conversation, including what we did, what we're doing, which files we're working on, and what we're going to do next."
806
807 // Create a new message with the summarize prompt
808 promptMsg := message.Message{
809 Role: message.User,
810 Parts: []message.ContentPart{message.TextContent{Text: summarizePrompt}},
811 }
812
813 // Append the prompt to the messages
814 msgsWithPrompt := append(msgs, promptMsg)
815
816 event = AgentEvent{
817 Type: AgentEventTypeSummarize,
818 Progress: "Generating summary...",
819 }
820
821 a.Publish(pubsub.CreatedEvent, event)
822
823 // Send the messages to the summarize provider
824 response := a.summarizeProvider.StreamResponse(
825 summarizeCtx,
826 msgsWithPrompt,
827 nil,
828 )
829 var finalResponse *provider.ProviderResponse
830 for r := range response {
831 if r.Error != nil {
832 event = AgentEvent{
833 Type: AgentEventTypeError,
834 Error: fmt.Errorf("failed to summarize: %w", r.Error),
835 Done: true,
836 }
837 a.Publish(pubsub.CreatedEvent, event)
838 return
839 }
840 finalResponse = r.Response
841 }
842
843 summary := strings.TrimSpace(finalResponse.Content)
844 if summary == "" {
845 event = AgentEvent{
846 Type: AgentEventTypeError,
847 Error: fmt.Errorf("empty summary returned"),
848 Done: true,
849 }
850 a.Publish(pubsub.CreatedEvent, event)
851 return
852 }
853 shell := shell.GetPersistentShell(a.cfg.WorkingDir())
854 summary += "\n\n**Current working directory of the persistent shell**\n\n" + shell.GetWorkingDir()
855 event = AgentEvent{
856 Type: AgentEventTypeSummarize,
857 Progress: "Creating new session...",
858 }
859
860 a.Publish(pubsub.CreatedEvent, event)
861 oldSession, err := a.sessions.Get(summarizeCtx, sessionID)
862 if err != nil {
863 event = AgentEvent{
864 Type: AgentEventTypeError,
865 Error: fmt.Errorf("failed to get session: %w", err),
866 Done: true,
867 }
868
869 a.Publish(pubsub.CreatedEvent, event)
870 return
871 }
872 // Create a message in the new session with the summary
873 msg, err := a.messages.Create(summarizeCtx, oldSession.ID, message.CreateMessageParams{
874 Role: message.Assistant,
875 Parts: []message.ContentPart{
876 message.TextContent{Text: summary},
877 message.Finish{
878 Reason: message.FinishReasonEndTurn,
879 Time: time.Now().Unix(),
880 },
881 },
882 Model: a.summarizeProvider.Model().ID,
883 Provider: a.summarizeProviderID,
884 })
885 if err != nil {
886 event = AgentEvent{
887 Type: AgentEventTypeError,
888 Error: fmt.Errorf("failed to create summary message: %w", err),
889 Done: true,
890 }
891
892 a.Publish(pubsub.CreatedEvent, event)
893 return
894 }
895 oldSession.SummaryMessageID = msg.ID
896 oldSession.CompletionTokens = finalResponse.Usage.OutputTokens
897 oldSession.PromptTokens = 0
898 model := a.summarizeProvider.Model()
899 usage := finalResponse.Usage
900 cost := model.CostPer1MInCached/1e6*float64(usage.CacheCreationTokens) +
901 model.CostPer1MOutCached/1e6*float64(usage.CacheReadTokens) +
902 model.CostPer1MIn/1e6*float64(usage.InputTokens) +
903 model.CostPer1MOut/1e6*float64(usage.OutputTokens)
904 oldSession.Cost += cost
905 _, err = a.sessions.Save(summarizeCtx, oldSession)
906 if err != nil {
907 event = AgentEvent{
908 Type: AgentEventTypeError,
909 Error: fmt.Errorf("failed to save session: %w", err),
910 Done: true,
911 }
912 a.Publish(pubsub.CreatedEvent, event)
913 }
914
915 event = AgentEvent{
916 Type: AgentEventTypeSummarize,
917 SessionID: oldSession.ID,
918 Progress: "Summary complete",
919 Done: true,
920 }
921 a.Publish(pubsub.CreatedEvent, event)
922 // Send final success event with the new session ID
923 }()
924
925 return nil
926}
927
928func (a *agent) ClearQueue(sessionID string) {
929 if a.QueuedPrompts(sessionID) > 0 {
930 slog.Info("Clearing queued prompts", "session_id", sessionID)
931 a.promptQueue.Del(sessionID)
932 }
933}
934
935func (a *agent) CancelAll() {
936 if !a.IsBusy() {
937 return
938 }
939 for key := range a.activeRequests.Seq2() {
940 a.Cancel(key) // key is sessionID
941 }
942
943 timeout := time.After(5 * time.Second)
944 for a.IsBusy() {
945 select {
946 case <-timeout:
947 return
948 default:
949 time.Sleep(200 * time.Millisecond)
950 }
951 }
952}
953
954func (a *agent) UpdateModel() error {
955 // Get current provider configuration
956 currentProviderCfg := a.cfg.GetProviderForModel(a.agentCfg.Model)
957 if currentProviderCfg == nil || currentProviderCfg.ID == "" {
958 return fmt.Errorf("provider for agent %s not found in config", a.agentCfg.Name)
959 }
960
961 // Check if provider has changed
962 if string(currentProviderCfg.ID) != a.providerID {
963 // Provider changed, need to recreate the main provider
964 model := a.cfg.GetModelByType(a.agentCfg.Model)
965 if model.ID == "" {
966 return fmt.Errorf("model not found for agent %s", a.agentCfg.Name)
967 }
968
969 promptID := agentPromptMap[a.agentCfg.ID]
970 if promptID == "" {
971 promptID = prompt.PromptDefault
972 }
973
974 opts := []provider.ProviderClientOption{
975 provider.WithModel(a.agentCfg.Model),
976 provider.WithSystemMessage(prompt.GetPrompt(a.cfg, promptID, currentProviderCfg.ID, a.cfg.Options.ContextPaths...)),
977 }
978
979 newProvider, err := provider.NewProvider(a.cfg, *currentProviderCfg, opts...)
980 if err != nil {
981 return fmt.Errorf("failed to create new provider: %w", err)
982 }
983
984 // Update the provider and provider ID
985 a.provider = newProvider
986 a.providerID = string(currentProviderCfg.ID)
987 }
988
989 // Check if providers have changed for title (small) and summarize (large)
990 smallModelCfg := a.cfg.Models[config.SelectedModelTypeSmall]
991 var smallModelProviderCfg config.ProviderConfig
992 for p := range a.cfg.Providers.Seq() {
993 if p.ID == smallModelCfg.Provider {
994 smallModelProviderCfg = p
995 break
996 }
997 }
998 if smallModelProviderCfg.ID == "" {
999 return fmt.Errorf("provider %s not found in config", smallModelCfg.Provider)
1000 }
1001
1002 largeModelCfg := a.cfg.Models[config.SelectedModelTypeLarge]
1003 var largeModelProviderCfg config.ProviderConfig
1004 for p := range a.cfg.Providers.Seq() {
1005 if p.ID == largeModelCfg.Provider {
1006 largeModelProviderCfg = p
1007 break
1008 }
1009 }
1010 if largeModelProviderCfg.ID == "" {
1011 return fmt.Errorf("provider %s not found in config", largeModelCfg.Provider)
1012 }
1013
1014 var maxTitleTokens int64 = 40
1015
1016 // if the max output is too low for the gemini provider it won't return anything
1017 if smallModelCfg.Provider == "gemini" {
1018 maxTitleTokens = 1000
1019 }
1020 // Recreate title provider
1021 titleOpts := []provider.ProviderClientOption{
1022 provider.WithModel(config.SelectedModelTypeSmall),
1023 provider.WithSystemMessage(prompt.GetPrompt(a.cfg, prompt.PromptTitle, smallModelProviderCfg.ID)),
1024 provider.WithMaxTokens(maxTitleTokens),
1025 }
1026 newTitleProvider, err := provider.NewProvider(a.cfg, smallModelProviderCfg, titleOpts...)
1027 if err != nil {
1028 return fmt.Errorf("failed to create new title provider: %w", err)
1029 }
1030 a.titleProvider = newTitleProvider
1031
1032 // Recreate summarize provider if provider changed (now large model)
1033 if string(largeModelProviderCfg.ID) != a.summarizeProviderID {
1034 largeModel := a.cfg.GetModelByType(config.SelectedModelTypeLarge)
1035 if largeModel == nil {
1036 return fmt.Errorf("model %s not found in provider %s", largeModelCfg.Model, largeModelProviderCfg.ID)
1037 }
1038 summarizeOpts := []provider.ProviderClientOption{
1039 provider.WithModel(config.SelectedModelTypeLarge),
1040 provider.WithSystemMessage(prompt.GetPrompt(a.cfg, prompt.PromptSummarizer, largeModelProviderCfg.ID)),
1041 }
1042 newSummarizeProvider, err := provider.NewProvider(a.cfg, largeModelProviderCfg, summarizeOpts...)
1043 if err != nil {
1044 return fmt.Errorf("failed to create new summarize provider: %w", err)
1045 }
1046 a.summarizeProvider = newSummarizeProvider
1047 a.summarizeProviderID = string(largeModelProviderCfg.ID)
1048 }
1049
1050 return nil
1051}