1package agent
2
3import (
4 "context"
5 "errors"
6 "fmt"
7 "log/slog"
8 "maps"
9 "slices"
10 "strings"
11 "time"
12
13 "github.com/charmbracelet/catwalk/pkg/catwalk"
14 "github.com/charmbracelet/crush/internal/config"
15 "github.com/charmbracelet/crush/internal/csync"
16 "github.com/charmbracelet/crush/internal/history"
17 "github.com/charmbracelet/crush/internal/llm/prompt"
18 "github.com/charmbracelet/crush/internal/llm/provider"
19 "github.com/charmbracelet/crush/internal/llm/tools"
20 "github.com/charmbracelet/crush/internal/log"
21 "github.com/charmbracelet/crush/internal/lsp"
22 "github.com/charmbracelet/crush/internal/message"
23 "github.com/charmbracelet/crush/internal/permission"
24 "github.com/charmbracelet/crush/internal/pubsub"
25 "github.com/charmbracelet/crush/internal/session"
26 "github.com/charmbracelet/crush/internal/shell"
27)
28
// ErrRequestCancelled is reported when an in-flight request ends because it
// was cancelled rather than because it completed or failed.
var ErrRequestCancelled = errors.New("request canceled by user")

// ErrSessionBusy is returned when an operation cannot start because the
// session already has a request in progress.
var ErrSessionBusy = errors.New("session is currently processing another request")
34
// AgentEventType names the kind of event an agent publishes.
type AgentEventType string

// The event kinds emitted by the agent.
const (
	// AgentEventTypeError marks an event that carries a failure.
	AgentEventTypeError = AgentEventType("error")
	// AgentEventTypeResponse marks the final response of a request.
	AgentEventTypeResponse = AgentEventType("response")
	// AgentEventTypeSummarize marks a summarization progress update.
	AgentEventTypeSummarize = AgentEventType("summarize")
)
42
43type AgentEvent struct {
44 Type AgentEventType
45 Message message.Message
46 Error error
47
48 // When summarizing
49 SessionID string
50 Progress string
51 Done bool
52}
53
54type Service interface {
55 pubsub.Suscriber[AgentEvent]
56 Model() catwalk.Model
57 Run(ctx context.Context, sessionID string, content string, attachments ...message.Attachment) (<-chan AgentEvent, error)
58 Cancel(sessionID string)
59 CancelAll()
60 IsSessionBusy(sessionID string) bool
61 IsBusy() bool
62 Summarize(ctx context.Context, sessionID string) error
63 UpdateModel() error
64 QueuedPrompts(sessionID string) int
65 ClearQueue(sessionID string)
66}
67
68type agent struct {
69 cleanupFuncs []func()
70
71 *pubsub.Broker[AgentEvent]
72 agentCfg config.Agent
73 sessions session.Service
74 messages message.Service
75
76 baseTools *csync.Map[string, tools.BaseTool]
77 mcpTools *csync.Map[string, tools.BaseTool]
78
79 provider provider.Provider
80 providerID string
81
82 titleProvider provider.Provider
83 summarizeProvider provider.Provider
84 summarizeProviderID string
85
86 activeRequests *csync.Map[string, context.CancelFunc]
87
88 promptQueue *csync.Map[string, []string]
89}
90
91var agentPromptMap = map[string]prompt.PromptID{
92 "coder": prompt.PromptCoder,
93 "task": prompt.PromptTask,
94}
95
96func NewAgent(
97 ctx context.Context,
98 agentCfg config.Agent,
99 // These services are needed in the tools
100 permissions permission.Service,
101 sessions session.Service,
102 messages message.Service,
103 history history.Service,
104 lspClients map[string]*lsp.Client,
105) (Service, error) {
106 cfg := config.Get()
107
108 var agentTool tools.BaseTool
109 if agentCfg.ID == "coder" {
110 taskAgentCfg := config.Get().Agents["task"]
111 if taskAgentCfg.ID == "" {
112 return nil, fmt.Errorf("task agent not found in config")
113 }
114 taskAgent, err := NewAgent(ctx, taskAgentCfg, permissions, sessions, messages, history, lspClients)
115 if err != nil {
116 return nil, fmt.Errorf("failed to create task agent: %w", err)
117 }
118
119 agentTool = NewAgentTool(taskAgent, sessions, messages)
120 }
121
122 providerCfg := config.Get().GetProviderForModel(agentCfg.Model)
123 if providerCfg == nil {
124 return nil, fmt.Errorf("provider for agent %s not found in config", agentCfg.Name)
125 }
126 model := config.Get().GetModelByType(agentCfg.Model)
127
128 if model == nil {
129 return nil, fmt.Errorf("model not found for agent %s", agentCfg.Name)
130 }
131
132 promptID := agentPromptMap[agentCfg.ID]
133 if promptID == "" {
134 promptID = prompt.PromptDefault
135 }
136 opts := []provider.ProviderClientOption{
137 provider.WithModel(agentCfg.Model),
138 provider.WithSystemMessage(prompt.GetPrompt(promptID, providerCfg.ID, config.Get().Options.ContextPaths...)),
139 }
140 agentProvider, err := provider.NewProvider(*providerCfg, opts...)
141 if err != nil {
142 return nil, err
143 }
144
145 smallModelCfg := cfg.Models[config.SelectedModelTypeSmall]
146 var smallModelProviderCfg *config.ProviderConfig
147 if smallModelCfg.Provider == providerCfg.ID {
148 smallModelProviderCfg = providerCfg
149 } else {
150 smallModelProviderCfg = cfg.GetProviderForModel(config.SelectedModelTypeSmall)
151
152 if smallModelProviderCfg.ID == "" {
153 return nil, fmt.Errorf("provider %s not found in config", smallModelCfg.Provider)
154 }
155 }
156 smallModel := cfg.GetModelByType(config.SelectedModelTypeSmall)
157 if smallModel.ID == "" {
158 return nil, fmt.Errorf("model %s not found in provider %s", smallModelCfg.Model, smallModelProviderCfg.ID)
159 }
160
161 titleOpts := []provider.ProviderClientOption{
162 provider.WithModel(config.SelectedModelTypeSmall),
163 provider.WithSystemMessage(prompt.GetPrompt(prompt.PromptTitle, smallModelProviderCfg.ID)),
164 }
165 titleProvider, err := provider.NewProvider(*smallModelProviderCfg, titleOpts...)
166 if err != nil {
167 return nil, err
168 }
169
170 summarizeOpts := []provider.ProviderClientOption{
171 provider.WithModel(config.SelectedModelTypeLarge),
172 provider.WithSystemMessage(prompt.GetPrompt(prompt.PromptSummarizer, providerCfg.ID)),
173 }
174 summarizeProvider, err := provider.NewProvider(*providerCfg, summarizeOpts...)
175 if err != nil {
176 return nil, err
177 }
178
179 baseToolsFn := func() map[string]tools.BaseTool {
180 slog.Info("Initializing agent base tools", "agent", agentCfg.ID)
181 defer func() {
182 slog.Info("Initialized agent base tools", "agent", agentCfg.ID)
183 }()
184
185 // Base tools available to all agents
186 cwd := cfg.WorkingDir()
187 result := make(map[string]tools.BaseTool)
188 for _, tool := range []tools.BaseTool{
189 tools.NewBashTool(permissions, cwd),
190 tools.NewDownloadTool(permissions, cwd),
191 tools.NewEditTool(lspClients, permissions, history, cwd),
192 tools.NewMultiEditTool(lspClients, permissions, history, cwd),
193 tools.NewFetchTool(permissions, cwd),
194 tools.NewGlobTool(cwd),
195 tools.NewGrepTool(cwd),
196 tools.NewLsTool(permissions, cwd),
197 tools.NewSourcegraphTool(),
198 tools.NewViewTool(lspClients, permissions, cwd),
199 tools.NewWriteTool(lspClients, permissions, history, cwd),
200 } {
201 result[tool.Name()] = tool
202 }
203
204 if len(lspClients) > 0 {
205 diagnosticsTool := tools.NewDiagnosticsTool(lspClients)
206 result[diagnosticsTool.Name()] = diagnosticsTool
207 }
208
209 if agentTool != nil {
210 result[agentTool.Name()] = agentTool
211 }
212
213 return result
214 }
215 mcpToolsFn := func() map[string]tools.BaseTool {
216 slog.Info("Initializing agent mcp tools", "agent", agentCfg.ID)
217 defer func() {
218 slog.Info("Initialized agent mcp tools", "agent", agentCfg.ID)
219 }()
220
221 mcpToolsOnce.Do(func() {
222 doGetMCPTools(ctx, permissions, cfg)
223 })
224
225 result := make(map[string]tools.BaseTool)
226 for _, mcpTool := range mcpTools.Seq2() {
227 result[mcpTool.Name()] = mcpTool
228 }
229 return result
230 }
231
232 a := &agent{
233 Broker: pubsub.NewBroker[AgentEvent](),
234 agentCfg: agentCfg,
235 provider: agentProvider,
236 providerID: string(providerCfg.ID),
237 messages: messages,
238 sessions: sessions,
239 titleProvider: titleProvider,
240 summarizeProvider: summarizeProvider,
241 summarizeProviderID: string(providerCfg.ID),
242 activeRequests: csync.NewMap[string, context.CancelFunc](),
243 mcpTools: csync.NewLazyMap(mcpToolsFn),
244 baseTools: csync.NewLazyMap(baseToolsFn),
245 promptQueue: csync.NewMap[string, []string](),
246 }
247 a.setupEvents(ctx)
248 return a, nil
249}
250
251func (a *agent) Model() catwalk.Model {
252 return *config.Get().GetModelByType(a.agentCfg.Model)
253}
254
255func (a *agent) Cancel(sessionID string) {
256 // Cancel regular requests
257 if cancel, ok := a.activeRequests.Take(sessionID); ok && cancel != nil {
258 slog.Info("Request cancellation initiated", "session_id", sessionID)
259 cancel()
260 }
261
262 // Also check for summarize requests
263 if cancel, ok := a.activeRequests.Take(sessionID + "-summarize"); ok && cancel != nil {
264 slog.Info("Summarize cancellation initiated", "session_id", sessionID)
265 cancel()
266 }
267
268 if a.QueuedPrompts(sessionID) > 0 {
269 slog.Info("Clearing queued prompts", "session_id", sessionID)
270 a.promptQueue.Del(sessionID)
271 }
272}
273
274func (a *agent) IsBusy() bool {
275 var busy bool
276 for cancelFunc := range a.activeRequests.Seq() {
277 if cancelFunc != nil {
278 busy = true
279 break
280 }
281 }
282 return busy
283}
284
285func (a *agent) IsSessionBusy(sessionID string) bool {
286 _, busy := a.activeRequests.Get(sessionID)
287 return busy
288}
289
290func (a *agent) QueuedPrompts(sessionID string) int {
291 l, ok := a.promptQueue.Get(sessionID)
292 if !ok {
293 return 0
294 }
295 return len(l)
296}
297
298func (a *agent) generateTitle(ctx context.Context, sessionID string, content string) error {
299 if content == "" {
300 return nil
301 }
302 if a.titleProvider == nil {
303 return nil
304 }
305 session, err := a.sessions.Get(ctx, sessionID)
306 if err != nil {
307 return err
308 }
309 parts := []message.ContentPart{message.TextContent{
310 Text: fmt.Sprintf("Generate a concise title for the following content:\n\n%s", content),
311 }}
312
313 // Use streaming approach like summarization
314 response := a.titleProvider.StreamResponse(
315 ctx,
316 []message.Message{
317 {
318 Role: message.User,
319 Parts: parts,
320 },
321 },
322 nil,
323 )
324
325 var finalResponse *provider.ProviderResponse
326 for r := range response {
327 if r.Error != nil {
328 return r.Error
329 }
330 finalResponse = r.Response
331 }
332
333 if finalResponse == nil {
334 return fmt.Errorf("no response received from title provider")
335 }
336
337 title := strings.TrimSpace(strings.ReplaceAll(finalResponse.Content, "\n", " "))
338 if title == "" {
339 return nil
340 }
341
342 session.Title = title
343 _, err = a.sessions.Save(ctx, session)
344 return err
345}
346
347func (a *agent) err(err error) AgentEvent {
348 return AgentEvent{
349 Type: AgentEventTypeError,
350 Error: err,
351 }
352}
353
354func (a *agent) Run(ctx context.Context, sessionID string, content string, attachments ...message.Attachment) (<-chan AgentEvent, error) {
355 if !a.Model().SupportsImages && attachments != nil {
356 attachments = nil
357 }
358 events := make(chan AgentEvent)
359 if a.IsSessionBusy(sessionID) {
360 existing, ok := a.promptQueue.Get(sessionID)
361 if !ok {
362 existing = []string{}
363 }
364 existing = append(existing, content)
365 a.promptQueue.Set(sessionID, existing)
366 return nil, nil
367 }
368
369 genCtx, cancel := context.WithCancel(ctx)
370
371 a.activeRequests.Set(sessionID, cancel)
372 go func() {
373 slog.Debug("Request started", "sessionID", sessionID)
374 defer log.RecoverPanic("agent.Run", func() {
375 events <- a.err(fmt.Errorf("panic while running the agent"))
376 })
377 var attachmentParts []message.ContentPart
378 for _, attachment := range attachments {
379 attachmentParts = append(attachmentParts, message.BinaryContent{Path: attachment.FilePath, MIMEType: attachment.MimeType, Data: attachment.Content})
380 }
381 result := a.processGeneration(genCtx, sessionID, content, attachmentParts)
382 if result.Error != nil && !errors.Is(result.Error, ErrRequestCancelled) && !errors.Is(result.Error, context.Canceled) {
383 slog.Error(result.Error.Error())
384 }
385 slog.Debug("Request completed", "sessionID", sessionID)
386 a.activeRequests.Del(sessionID)
387 cancel()
388 a.Publish(pubsub.CreatedEvent, result)
389 select {
390 case events <- result:
391 case <-genCtx.Done():
392 }
393 close(events)
394 }()
395 return events, nil
396}
397
398func (a *agent) processGeneration(ctx context.Context, sessionID, content string, attachmentParts []message.ContentPart) AgentEvent {
399 cfg := config.Get()
400 // List existing messages; if none, start title generation asynchronously.
401 msgs, err := a.messages.List(ctx, sessionID)
402 if err != nil {
403 return a.err(fmt.Errorf("failed to list messages: %w", err))
404 }
405 if len(msgs) == 0 {
406 go func() {
407 defer log.RecoverPanic("agent.Run", func() {
408 slog.Error("panic while generating title")
409 })
410 titleErr := a.generateTitle(context.Background(), sessionID, content)
411 if titleErr != nil && !errors.Is(titleErr, context.Canceled) && !errors.Is(titleErr, context.DeadlineExceeded) {
412 slog.Error("failed to generate title", "error", titleErr)
413 }
414 }()
415 }
416 session, err := a.sessions.Get(ctx, sessionID)
417 if err != nil {
418 return a.err(fmt.Errorf("failed to get session: %w", err))
419 }
420 if session.SummaryMessageID != "" {
421 summaryMsgInex := -1
422 for i, msg := range msgs {
423 if msg.ID == session.SummaryMessageID {
424 summaryMsgInex = i
425 break
426 }
427 }
428 if summaryMsgInex != -1 {
429 msgs = msgs[summaryMsgInex:]
430 msgs[0].Role = message.User
431 }
432 }
433
434 userMsg, err := a.createUserMessage(ctx, sessionID, content, attachmentParts)
435 if err != nil {
436 return a.err(fmt.Errorf("failed to create user message: %w", err))
437 }
438 // Append the new user message to the conversation history.
439 msgHistory := append(msgs, userMsg)
440
441 for {
442 // Check for cancellation before each iteration
443 select {
444 case <-ctx.Done():
445 return a.err(ctx.Err())
446 default:
447 // Continue processing
448 }
449 agentMessage, toolResults, err := a.streamAndHandleEvents(ctx, sessionID, msgHistory)
450 if err != nil {
451 if errors.Is(err, context.Canceled) {
452 agentMessage.AddFinish(message.FinishReasonCanceled, "Request cancelled", "")
453 a.messages.Update(context.Background(), agentMessage)
454 return a.err(ErrRequestCancelled)
455 }
456 return a.err(fmt.Errorf("failed to process events: %w", err))
457 }
458 if cfg.Options.Debug {
459 slog.Info("Result", "message", agentMessage.FinishReason(), "toolResults", toolResults)
460 }
461 if (agentMessage.FinishReason() == message.FinishReasonToolUse) && toolResults != nil {
462 // We are not done, we need to respond with the tool response
463 msgHistory = append(msgHistory, agentMessage, *toolResults)
464 // If there are queued prompts, process the next one
465 nextPrompt, ok := a.promptQueue.Take(sessionID)
466 if ok {
467 for _, prompt := range nextPrompt {
468 // Create a new user message for the queued prompt
469 userMsg, err := a.createUserMessage(ctx, sessionID, prompt, nil)
470 if err != nil {
471 return a.err(fmt.Errorf("failed to create user message for queued prompt: %w", err))
472 }
473 // Append the new user message to the conversation history
474 msgHistory = append(msgHistory, userMsg)
475 }
476 }
477
478 continue
479 } else if agentMessage.FinishReason() == message.FinishReasonEndTurn {
480 queuePrompts, ok := a.promptQueue.Take(sessionID)
481 if ok {
482 for _, prompt := range queuePrompts {
483 if prompt == "" {
484 continue
485 }
486 userMsg, err := a.createUserMessage(ctx, sessionID, prompt, nil)
487 if err != nil {
488 return a.err(fmt.Errorf("failed to create user message for queued prompt: %w", err))
489 }
490 msgHistory = append(msgHistory, userMsg)
491 }
492 continue
493 }
494 }
495 if agentMessage.FinishReason() == "" {
496 // Kujtim: could not track down where this is happening but this means its cancelled
497 agentMessage.AddFinish(message.FinishReasonCanceled, "Request cancelled", "")
498 _ = a.messages.Update(context.Background(), agentMessage)
499 return a.err(ErrRequestCancelled)
500 }
501 return AgentEvent{
502 Type: AgentEventTypeResponse,
503 Message: agentMessage,
504 Done: true,
505 }
506 }
507}
508
509func (a *agent) createUserMessage(ctx context.Context, sessionID, content string, attachmentParts []message.ContentPart) (message.Message, error) {
510 parts := []message.ContentPart{message.TextContent{Text: content}}
511 parts = append(parts, attachmentParts...)
512 return a.messages.Create(ctx, sessionID, message.CreateMessageParams{
513 Role: message.User,
514 Parts: parts,
515 })
516}
517
518func (a *agent) streamAndHandleEvents(ctx context.Context, sessionID string, msgHistory []message.Message) (message.Message, *message.Message, error) {
519 ctx = context.WithValue(ctx, tools.SessionIDContextKey, sessionID)
520
521 // Create the assistant message first so the spinner shows immediately
522 assistantMsg, err := a.messages.Create(ctx, sessionID, message.CreateMessageParams{
523 Role: message.Assistant,
524 Parts: []message.ContentPart{},
525 Model: a.Model().ID,
526 Provider: a.providerID,
527 })
528 if err != nil {
529 return assistantMsg, nil, fmt.Errorf("failed to create assistant message: %w", err)
530 }
531
532 // Now collect tools (which may block on MCP initialization)
533 eventChan := a.provider.StreamResponse(ctx, msgHistory, a.allTools())
534
535 // Add the session and message ID into the context if needed by tools.
536 ctx = context.WithValue(ctx, tools.MessageIDContextKey, assistantMsg.ID)
537
538 // Process each event in the stream.
539 for event := range eventChan {
540 if processErr := a.processEvent(ctx, sessionID, &assistantMsg, event); processErr != nil {
541 if errors.Is(processErr, context.Canceled) {
542 a.finishMessage(context.Background(), &assistantMsg, message.FinishReasonCanceled, "Request cancelled", "")
543 } else {
544 a.finishMessage(ctx, &assistantMsg, message.FinishReasonError, "API Error", processErr.Error())
545 }
546 return assistantMsg, nil, processErr
547 }
548 if ctx.Err() != nil {
549 a.finishMessage(context.Background(), &assistantMsg, message.FinishReasonCanceled, "Request cancelled", "")
550 return assistantMsg, nil, ctx.Err()
551 }
552 }
553
554 toolResults := make([]message.ToolResult, len(assistantMsg.ToolCalls()))
555 toolCalls := assistantMsg.ToolCalls()
556 for i, toolCall := range toolCalls {
557 select {
558 case <-ctx.Done():
559 a.finishMessage(context.Background(), &assistantMsg, message.FinishReasonCanceled, "Request cancelled", "")
560 // Make all future tool calls cancelled
561 for j := i; j < len(toolCalls); j++ {
562 toolResults[j] = message.ToolResult{
563 ToolCallID: toolCalls[j].ID,
564 Content: "Tool execution canceled by user",
565 IsError: true,
566 }
567 }
568 goto out
569 default:
570 // Continue processing
571 tool, ok := a.getToolByName(toolCall.Name)
572
573 // Tool not found
574 if !ok {
575 toolResults[i] = message.ToolResult{
576 ToolCallID: toolCall.ID,
577 Content: fmt.Sprintf("Tool not found: %s", toolCall.Name),
578 IsError: true,
579 }
580 continue
581 }
582
583 // Run tool in goroutine to allow cancellation
584 type toolExecResult struct {
585 response tools.ToolResponse
586 err error
587 }
588 resultChan := make(chan toolExecResult, 1)
589
590 go func() {
591 response, err := tool.Run(ctx, tools.ToolCall{
592 ID: toolCall.ID,
593 Name: toolCall.Name,
594 Input: toolCall.Input,
595 })
596 resultChan <- toolExecResult{response: response, err: err}
597 }()
598
599 var toolResponse tools.ToolResponse
600 var toolErr error
601
602 select {
603 case <-ctx.Done():
604 a.finishMessage(context.Background(), &assistantMsg, message.FinishReasonCanceled, "Request cancelled", "")
605 // Mark remaining tool calls as cancelled
606 for j := i; j < len(toolCalls); j++ {
607 toolResults[j] = message.ToolResult{
608 ToolCallID: toolCalls[j].ID,
609 Content: "Tool execution canceled by user",
610 IsError: true,
611 }
612 }
613 goto out
614 case result := <-resultChan:
615 toolResponse = result.response
616 toolErr = result.err
617 }
618
619 if toolErr != nil {
620 slog.Error("Tool execution error", "toolCall", toolCall.ID, "error", toolErr)
621 if errors.Is(toolErr, permission.ErrorPermissionDenied) {
622 toolResults[i] = message.ToolResult{
623 ToolCallID: toolCall.ID,
624 Content: "Permission denied",
625 IsError: true,
626 }
627 for j := i + 1; j < len(toolCalls); j++ {
628 toolResults[j] = message.ToolResult{
629 ToolCallID: toolCalls[j].ID,
630 Content: "Tool execution canceled by user",
631 IsError: true,
632 }
633 }
634 a.finishMessage(ctx, &assistantMsg, message.FinishReasonPermissionDenied, "Permission denied", "")
635 break
636 }
637 }
638 toolResults[i] = message.ToolResult{
639 ToolCallID: toolCall.ID,
640 Content: toolResponse.Content,
641 Metadata: toolResponse.Metadata,
642 IsError: toolResponse.IsError,
643 }
644 }
645 }
646out:
647 if len(toolResults) == 0 {
648 return assistantMsg, nil, nil
649 }
650 parts := make([]message.ContentPart, 0)
651 for _, tr := range toolResults {
652 parts = append(parts, tr)
653 }
654 msg, err := a.messages.Create(context.Background(), assistantMsg.SessionID, message.CreateMessageParams{
655 Role: message.Tool,
656 Parts: parts,
657 Provider: a.providerID,
658 })
659 if err != nil {
660 return assistantMsg, nil, fmt.Errorf("failed to create cancelled tool message: %w", err)
661 }
662
663 return assistantMsg, &msg, err
664}
665
666func (a *agent) finishMessage(ctx context.Context, msg *message.Message, finishReason message.FinishReason, message, details string) {
667 msg.AddFinish(finishReason, message, details)
668 _ = a.messages.Update(ctx, *msg)
669}
670
671func (a *agent) processEvent(ctx context.Context, sessionID string, assistantMsg *message.Message, event provider.ProviderEvent) error {
672 select {
673 case <-ctx.Done():
674 return ctx.Err()
675 default:
676 // Continue processing.
677 }
678
679 switch event.Type {
680 case provider.EventThinkingDelta:
681 assistantMsg.AppendReasoningContent(event.Thinking)
682 return a.messages.Update(ctx, *assistantMsg)
683 case provider.EventSignatureDelta:
684 assistantMsg.AppendReasoningSignature(event.Signature)
685 return a.messages.Update(ctx, *assistantMsg)
686 case provider.EventContentDelta:
687 assistantMsg.FinishThinking()
688 assistantMsg.AppendContent(event.Content)
689 return a.messages.Update(ctx, *assistantMsg)
690 case provider.EventToolUseStart:
691 assistantMsg.FinishThinking()
692 slog.Info("Tool call started", "toolCall", event.ToolCall)
693 assistantMsg.AddToolCall(*event.ToolCall)
694 return a.messages.Update(ctx, *assistantMsg)
695 case provider.EventToolUseDelta:
696 assistantMsg.AppendToolCallInput(event.ToolCall.ID, event.ToolCall.Input)
697 return a.messages.Update(ctx, *assistantMsg)
698 case provider.EventToolUseStop:
699 slog.Info("Finished tool call", "toolCall", event.ToolCall)
700 assistantMsg.FinishToolCall(event.ToolCall.ID)
701 return a.messages.Update(ctx, *assistantMsg)
702 case provider.EventError:
703 return event.Error
704 case provider.EventComplete:
705 assistantMsg.FinishThinking()
706 assistantMsg.SetToolCalls(event.Response.ToolCalls)
707 assistantMsg.AddFinish(event.Response.FinishReason, "", "")
708 if err := a.messages.Update(ctx, *assistantMsg); err != nil {
709 return fmt.Errorf("failed to update message: %w", err)
710 }
711 return a.TrackUsage(ctx, sessionID, a.Model(), event.Response.Usage)
712 }
713
714 return nil
715}
716
717func (a *agent) TrackUsage(ctx context.Context, sessionID string, model catwalk.Model, usage provider.TokenUsage) error {
718 sess, err := a.sessions.Get(ctx, sessionID)
719 if err != nil {
720 return fmt.Errorf("failed to get session: %w", err)
721 }
722
723 cost := model.CostPer1MInCached/1e6*float64(usage.CacheCreationTokens) +
724 model.CostPer1MOutCached/1e6*float64(usage.CacheReadTokens) +
725 model.CostPer1MIn/1e6*float64(usage.InputTokens) +
726 model.CostPer1MOut/1e6*float64(usage.OutputTokens)
727
728 sess.Cost += cost
729 sess.CompletionTokens = usage.OutputTokens + usage.CacheReadTokens
730 sess.PromptTokens = usage.InputTokens + usage.CacheCreationTokens
731
732 _, err = a.sessions.Save(ctx, sess)
733 if err != nil {
734 return fmt.Errorf("failed to save session: %w", err)
735 }
736 return nil
737}
738
739func (a *agent) Summarize(ctx context.Context, sessionID string) error {
740 if a.summarizeProvider == nil {
741 return fmt.Errorf("summarize provider not available")
742 }
743
744 // Check if session is busy
745 if a.IsSessionBusy(sessionID) {
746 return ErrSessionBusy
747 }
748
749 // Create a new context with cancellation
750 summarizeCtx, cancel := context.WithCancel(ctx)
751
752 // Store the cancel function in activeRequests to allow cancellation
753 a.activeRequests.Set(sessionID+"-summarize", cancel)
754
755 go func() {
756 defer a.activeRequests.Del(sessionID + "-summarize")
757 defer cancel()
758 event := AgentEvent{
759 Type: AgentEventTypeSummarize,
760 Progress: "Starting summarization...",
761 }
762
763 a.Publish(pubsub.CreatedEvent, event)
764 // Get all messages from the session
765 msgs, err := a.messages.List(summarizeCtx, sessionID)
766 if err != nil {
767 event = AgentEvent{
768 Type: AgentEventTypeError,
769 Error: fmt.Errorf("failed to list messages: %w", err),
770 Done: true,
771 }
772 a.Publish(pubsub.CreatedEvent, event)
773 return
774 }
775 summarizeCtx = context.WithValue(summarizeCtx, tools.SessionIDContextKey, sessionID)
776
777 if len(msgs) == 0 {
778 event = AgentEvent{
779 Type: AgentEventTypeError,
780 Error: fmt.Errorf("no messages to summarize"),
781 Done: true,
782 }
783 a.Publish(pubsub.CreatedEvent, event)
784 return
785 }
786
787 event = AgentEvent{
788 Type: AgentEventTypeSummarize,
789 Progress: "Analyzing conversation...",
790 }
791 a.Publish(pubsub.CreatedEvent, event)
792
793 // Add a system message to guide the summarization
794 summarizePrompt := "Provide a detailed but concise summary of our conversation above. Focus on information that would be helpful for continuing the conversation, including what we did, what we're doing, which files we're working on, and what we're going to do next."
795
796 // Create a new message with the summarize prompt
797 promptMsg := message.Message{
798 Role: message.User,
799 Parts: []message.ContentPart{message.TextContent{Text: summarizePrompt}},
800 }
801
802 // Append the prompt to the messages
803 msgsWithPrompt := append(msgs, promptMsg)
804
805 event = AgentEvent{
806 Type: AgentEventTypeSummarize,
807 Progress: "Generating summary...",
808 }
809
810 a.Publish(pubsub.CreatedEvent, event)
811
812 // Send the messages to the summarize provider
813 response := a.summarizeProvider.StreamResponse(
814 summarizeCtx,
815 msgsWithPrompt,
816 nil,
817 )
818 var finalResponse *provider.ProviderResponse
819 for r := range response {
820 if r.Error != nil {
821 event = AgentEvent{
822 Type: AgentEventTypeError,
823 Error: fmt.Errorf("failed to summarize: %w", err),
824 Done: true,
825 }
826 a.Publish(pubsub.CreatedEvent, event)
827 return
828 }
829 finalResponse = r.Response
830 }
831
832 summary := strings.TrimSpace(finalResponse.Content)
833 if summary == "" {
834 event = AgentEvent{
835 Type: AgentEventTypeError,
836 Error: fmt.Errorf("empty summary returned"),
837 Done: true,
838 }
839 a.Publish(pubsub.CreatedEvent, event)
840 return
841 }
842 shell := shell.GetPersistentShell(config.Get().WorkingDir())
843 summary += "\n\n**Current working directory of the persistent shell**\n\n" + shell.GetWorkingDir()
844 event = AgentEvent{
845 Type: AgentEventTypeSummarize,
846 Progress: "Creating new session...",
847 }
848
849 a.Publish(pubsub.CreatedEvent, event)
850 oldSession, err := a.sessions.Get(summarizeCtx, sessionID)
851 if err != nil {
852 event = AgentEvent{
853 Type: AgentEventTypeError,
854 Error: fmt.Errorf("failed to get session: %w", err),
855 Done: true,
856 }
857
858 a.Publish(pubsub.CreatedEvent, event)
859 return
860 }
861 // Create a message in the new session with the summary
862 msg, err := a.messages.Create(summarizeCtx, oldSession.ID, message.CreateMessageParams{
863 Role: message.Assistant,
864 Parts: []message.ContentPart{
865 message.TextContent{Text: summary},
866 message.Finish{
867 Reason: message.FinishReasonEndTurn,
868 Time: time.Now().Unix(),
869 },
870 },
871 Model: a.summarizeProvider.Model().ID,
872 Provider: a.summarizeProviderID,
873 })
874 if err != nil {
875 event = AgentEvent{
876 Type: AgentEventTypeError,
877 Error: fmt.Errorf("failed to create summary message: %w", err),
878 Done: true,
879 }
880
881 a.Publish(pubsub.CreatedEvent, event)
882 return
883 }
884 oldSession.SummaryMessageID = msg.ID
885 oldSession.CompletionTokens = finalResponse.Usage.OutputTokens
886 oldSession.PromptTokens = 0
887 model := a.summarizeProvider.Model()
888 usage := finalResponse.Usage
889 cost := model.CostPer1MInCached/1e6*float64(usage.CacheCreationTokens) +
890 model.CostPer1MOutCached/1e6*float64(usage.CacheReadTokens) +
891 model.CostPer1MIn/1e6*float64(usage.InputTokens) +
892 model.CostPer1MOut/1e6*float64(usage.OutputTokens)
893 oldSession.Cost += cost
894 _, err = a.sessions.Save(summarizeCtx, oldSession)
895 if err != nil {
896 event = AgentEvent{
897 Type: AgentEventTypeError,
898 Error: fmt.Errorf("failed to save session: %w", err),
899 Done: true,
900 }
901 a.Publish(pubsub.CreatedEvent, event)
902 }
903
904 event = AgentEvent{
905 Type: AgentEventTypeSummarize,
906 SessionID: oldSession.ID,
907 Progress: "Summary complete",
908 Done: true,
909 }
910 a.Publish(pubsub.CreatedEvent, event)
911 // Send final success event with the new session ID
912 }()
913
914 return nil
915}
916
917func (a *agent) ClearQueue(sessionID string) {
918 if a.QueuedPrompts(sessionID) > 0 {
919 slog.Info("Clearing queued prompts", "session_id", sessionID)
920 a.promptQueue.Del(sessionID)
921 }
922}
923
924func (a *agent) CancelAll() {
925 if !a.IsBusy() {
926 return
927 }
928 for key := range a.activeRequests.Seq2() {
929 a.Cancel(key) // key is sessionID
930 }
931
932 for _, cleanup := range a.cleanupFuncs {
933 if cleanup != nil {
934 cleanup()
935 }
936 }
937
938 timeout := time.After(5 * time.Second)
939 for a.IsBusy() {
940 select {
941 case <-timeout:
942 return
943 default:
944 time.Sleep(200 * time.Millisecond)
945 }
946 }
947}
948
949func (a *agent) UpdateModel() error {
950 cfg := config.Get()
951
952 // Get current provider configuration
953 currentProviderCfg := cfg.GetProviderForModel(a.agentCfg.Model)
954 if currentProviderCfg == nil || currentProviderCfg.ID == "" {
955 return fmt.Errorf("provider for agent %s not found in config", a.agentCfg.Name)
956 }
957
958 // Check if provider has changed
959 if string(currentProviderCfg.ID) != a.providerID {
960 // Provider changed, need to recreate the main provider
961 model := cfg.GetModelByType(a.agentCfg.Model)
962 if model.ID == "" {
963 return fmt.Errorf("model not found for agent %s", a.agentCfg.Name)
964 }
965
966 promptID := agentPromptMap[a.agentCfg.ID]
967 if promptID == "" {
968 promptID = prompt.PromptDefault
969 }
970
971 opts := []provider.ProviderClientOption{
972 provider.WithModel(a.agentCfg.Model),
973 provider.WithSystemMessage(prompt.GetPrompt(promptID, currentProviderCfg.ID, cfg.Options.ContextPaths...)),
974 }
975
976 newProvider, err := provider.NewProvider(*currentProviderCfg, opts...)
977 if err != nil {
978 return fmt.Errorf("failed to create new provider: %w", err)
979 }
980
981 // Update the provider and provider ID
982 a.provider = newProvider
983 a.providerID = string(currentProviderCfg.ID)
984 }
985
986 // Check if providers have changed for title (small) and summarize (large)
987 smallModelCfg := cfg.Models[config.SelectedModelTypeSmall]
988 var smallModelProviderCfg config.ProviderConfig
989 for p := range cfg.Providers.Seq() {
990 if p.ID == smallModelCfg.Provider {
991 smallModelProviderCfg = p
992 break
993 }
994 }
995 if smallModelProviderCfg.ID == "" {
996 return fmt.Errorf("provider %s not found in config", smallModelCfg.Provider)
997 }
998
999 largeModelCfg := cfg.Models[config.SelectedModelTypeLarge]
1000 var largeModelProviderCfg config.ProviderConfig
1001 for p := range cfg.Providers.Seq() {
1002 if p.ID == largeModelCfg.Provider {
1003 largeModelProviderCfg = p
1004 break
1005 }
1006 }
1007 if largeModelProviderCfg.ID == "" {
1008 return fmt.Errorf("provider %s not found in config", largeModelCfg.Provider)
1009 }
1010
1011 // Recreate title provider
1012 titleOpts := []provider.ProviderClientOption{
1013 provider.WithModel(config.SelectedModelTypeSmall),
1014 provider.WithSystemMessage(prompt.GetPrompt(prompt.PromptTitle, smallModelProviderCfg.ID)),
1015 provider.WithMaxTokens(40),
1016 }
1017 newTitleProvider, err := provider.NewProvider(smallModelProviderCfg, titleOpts...)
1018 if err != nil {
1019 return fmt.Errorf("failed to create new title provider: %w", err)
1020 }
1021 a.titleProvider = newTitleProvider
1022
1023 // Recreate summarize provider if provider changed (now large model)
1024 if string(largeModelProviderCfg.ID) != a.summarizeProviderID {
1025 largeModel := cfg.GetModelByType(config.SelectedModelTypeLarge)
1026 if largeModel == nil {
1027 return fmt.Errorf("model %s not found in provider %s", largeModelCfg.Model, largeModelProviderCfg.ID)
1028 }
1029 summarizeOpts := []provider.ProviderClientOption{
1030 provider.WithModel(config.SelectedModelTypeLarge),
1031 provider.WithSystemMessage(prompt.GetPrompt(prompt.PromptSummarizer, largeModelProviderCfg.ID)),
1032 }
1033 newSummarizeProvider, err := provider.NewProvider(largeModelProviderCfg, summarizeOpts...)
1034 if err != nil {
1035 return fmt.Errorf("failed to create new summarize provider: %w", err)
1036 }
1037 a.summarizeProvider = newSummarizeProvider
1038 a.summarizeProviderID = string(largeModelProviderCfg.ID)
1039 }
1040
1041 return nil
1042}
1043
1044func (a *agent) allTools() []tools.BaseTool {
1045 result := slices.Collect(a.baseTools.Seq())
1046 result = slices.AppendSeq(result, a.mcpTools.Seq())
1047 return result
1048}
1049
1050func (a *agent) getToolByName(name string) (tools.BaseTool, bool) {
1051 tool, ok := a.baseTools.Get(name)
1052 if !ok {
1053 tool, ok = a.mcpTools.Get(name)
1054 }
1055 return tool, ok
1056}
1057
1058func (a *agent) setupEvents(ctx context.Context) {
1059 ctx, cancel := context.WithCancel(ctx)
1060
1061 go func() {
1062 for event := range SubscribeMCPEvents(ctx) {
1063 switch event.Payload.Type {
1064 case MCPEventToolsListChanged:
1065 name := event.Payload.Name
1066 c, ok := mcpClients.Get(name)
1067 if !ok {
1068 slog.Warn("MCP client not found for tools update", "name", name)
1069 continue
1070 }
1071 tools := getTools(ctx, name, c)
1072 updateMcpTools(name, tools)
1073 // Update the lazy map with the new tools
1074 a.mcpTools = csync.NewMapFrom(maps.Collect(mcpTools.Seq2()))
1075 updateMCPState(name, MCPStateConnected, nil, c, a.mcpTools.Len())
1076 default:
1077 continue
1078 }
1079 }
1080 }()
1081
1082 a.cleanupFuncs = append(a.cleanupFuncs, func() {
1083 cancel()
1084 })
1085}