1package agent
2
3import (
4 "context"
5 "errors"
6 "fmt"
7 "log/slog"
8 "slices"
9 "strings"
10 "time"
11
12 "github.com/charmbracelet/catwalk/pkg/catwalk"
13 "github.com/charmbracelet/crush/internal/config"
14 "github.com/charmbracelet/crush/internal/csync"
15 "github.com/charmbracelet/crush/internal/history"
16 "github.com/charmbracelet/crush/internal/llm/prompt"
17 "github.com/charmbracelet/crush/internal/llm/provider"
18 "github.com/charmbracelet/crush/internal/llm/tools"
19 "github.com/charmbracelet/crush/internal/log"
20 "github.com/charmbracelet/crush/internal/lsp"
21 "github.com/charmbracelet/crush/internal/message"
22 "github.com/charmbracelet/crush/internal/permission"
23 "github.com/charmbracelet/crush/internal/pubsub"
24 "github.com/charmbracelet/crush/internal/session"
25 "github.com/charmbracelet/crush/internal/shell"
26)
27
// Common errors
var (
	// ErrRequestCancelled is reported (inside an error AgentEvent) when a
	// generation is cancelled before it completes.
	ErrRequestCancelled = errors.New("request canceled by user")
	// ErrSessionBusy is returned by Summarize when the session already has an
	// active request.
	ErrSessionBusy = errors.New("session is currently processing another request")
)

// AgentEventType identifies the kind of AgentEvent emitted by the agent.
type AgentEventType string

const (
	// AgentEventTypeError marks an event whose Error field describes a failure.
	AgentEventTypeError AgentEventType = "error"
	// AgentEventTypeResponse marks an event carrying the final assistant message.
	AgentEventTypeResponse AgentEventType = "response"
	// AgentEventTypeSummarize marks a summarization progress/completion event.
	AgentEventTypeSummarize AgentEventType = "summarize"
)
41
// AgentEvent is published on the agent's broker and, for Run, also delivered
// on the channel Run returns.
type AgentEvent struct {
	// Type selects which of the fields below are meaningful.
	Type AgentEventType
	// Message is the assistant message for AgentEventTypeResponse events.
	Message message.Message
	// Error is set for AgentEventTypeError events.
	Error error

	// When summarizing
	// SessionID is filled in on the final successful summarize event.
	SessionID string
	// Progress is a human-readable description of the current summarize step.
	Progress string
	// Done marks a terminal event: a final response, summarize completion, or
	// a summarize failure.
	Done bool
}
52
// Service is the interface the rest of the application uses to drive an agent.
// It embeds a pubsub subscriber so consumers can receive the AgentEvents the
// agent publishes.
type Service interface {
	pubsub.Suscriber[AgentEvent]
	// Model returns the model the agent is currently configured to use.
	Model() catwalk.Model
	// Run sends content (and optional attachments) as a user prompt in the
	// given session. In this package's implementation a prompt for a busy
	// session is queued and (nil, nil) is returned.
	Run(ctx context.Context, sessionID string, content string, attachments ...message.Attachment) (<-chan AgentEvent, error)
	// Cancel stops in-flight work for a session and clears its queued prompts.
	Cancel(sessionID string)
	// CancelAll cancels in-flight work across all sessions.
	CancelAll()
	// IsSessionBusy reports whether the session has a generation in flight.
	IsSessionBusy(sessionID string) bool
	// IsBusy reports whether any request is in flight.
	IsBusy() bool
	// Summarize starts summarizing a session in the background.
	Summarize(ctx context.Context, sessionID string) error
	// UpdateModel re-reads the model configuration and rebuilds provider
	// clients as needed.
	UpdateModel() error
	// QueuedPrompts returns how many prompts are waiting for a session.
	QueuedPrompts(sessionID string) int
	// ClearQueue discards the prompts waiting for a session.
	ClearQueue(sessionID string)
}
66
// agent is the Service implementation backed by LLM provider clients and a set
// of tools.
type agent struct {
	// globalCtx is the parent context for background work started by
	// setupEvents.
	globalCtx context.Context
	// cleanupFuncs are invoked by CancelAll.
	cleanupFuncs []func()

	*pubsub.Broker[AgentEvent]
	agentCfg config.Agent
	sessions session.Service
	messages message.Service

	// Tool sets keyed by tool name; NewAgent creates both as lazy maps.
	baseTools *csync.Map[string, tools.BaseTool]
	mcpTools  *csync.Map[string, tools.BaseTool]

	// provider serves the main conversation; providerID is its config ID.
	provider   provider.Provider
	providerID string

	// titleProvider generates session titles; summarizeProvider writes session
	// summaries and summarizeProviderID is its config ID.
	titleProvider       provider.Provider
	summarizeProvider   provider.Provider
	summarizeProviderID string

	// activeRequests maps a session ID (or sessionID+"-summarize") to the
	// cancel func of the request running for it.
	activeRequests *csync.Map[string, context.CancelFunc]

	// promptQueue holds prompt texts submitted via Run while the session was
	// busy, keyed by session ID.
	promptQueue *csync.Map[string, []string]
}
90
// agentPromptMap maps an agent ID to the system prompt it uses. Agents not
// listed here fall back to prompt.PromptDefault.
var agentPromptMap = map[string]prompt.PromptID{
	"coder": prompt.PromptCoder,
	"task":  prompt.PromptTask,
}
95
96func NewAgent(
97 ctx context.Context,
98 agentCfg config.Agent,
99 // These services are needed in the tools
100 permissions permission.Service,
101 sessions session.Service,
102 messages message.Service,
103 history history.Service,
104 lspClients map[string]*lsp.Client,
105) (Service, error) {
106 cfg := config.Get()
107
108 var agentTool tools.BaseTool
109 if agentCfg.ID == "coder" {
110 taskAgentCfg := config.Get().Agents["task"]
111 if taskAgentCfg.ID == "" {
112 return nil, fmt.Errorf("task agent not found in config")
113 }
114 taskAgent, err := NewAgent(ctx, taskAgentCfg, permissions, sessions, messages, history, lspClients)
115 if err != nil {
116 return nil, fmt.Errorf("failed to create task agent: %w", err)
117 }
118
119 agentTool = NewAgentTool(taskAgent, sessions, messages)
120 }
121
122 providerCfg := config.Get().GetProviderForModel(agentCfg.Model)
123 if providerCfg == nil {
124 return nil, fmt.Errorf("provider for agent %s not found in config", agentCfg.Name)
125 }
126 model := config.Get().GetModelByType(agentCfg.Model)
127
128 if model == nil {
129 return nil, fmt.Errorf("model not found for agent %s", agentCfg.Name)
130 }
131
132 promptID := agentPromptMap[agentCfg.ID]
133 if promptID == "" {
134 promptID = prompt.PromptDefault
135 }
136 opts := []provider.ProviderClientOption{
137 provider.WithModel(agentCfg.Model),
138 provider.WithSystemMessage(prompt.GetPrompt(promptID, providerCfg.ID, config.Get().Options.ContextPaths...)),
139 }
140 agentProvider, err := provider.NewProvider(*providerCfg, opts...)
141 if err != nil {
142 return nil, err
143 }
144
145 smallModelCfg := cfg.Models[config.SelectedModelTypeSmall]
146 var smallModelProviderCfg *config.ProviderConfig
147 if smallModelCfg.Provider == providerCfg.ID {
148 smallModelProviderCfg = providerCfg
149 } else {
150 smallModelProviderCfg = cfg.GetProviderForModel(config.SelectedModelTypeSmall)
151
152 if smallModelProviderCfg.ID == "" {
153 return nil, fmt.Errorf("provider %s not found in config", smallModelCfg.Provider)
154 }
155 }
156 smallModel := cfg.GetModelByType(config.SelectedModelTypeSmall)
157 if smallModel.ID == "" {
158 return nil, fmt.Errorf("model %s not found in provider %s", smallModelCfg.Model, smallModelProviderCfg.ID)
159 }
160
161 titleOpts := []provider.ProviderClientOption{
162 provider.WithModel(config.SelectedModelTypeSmall),
163 provider.WithSystemMessage(prompt.GetPrompt(prompt.PromptTitle, smallModelProviderCfg.ID)),
164 }
165 titleProvider, err := provider.NewProvider(*smallModelProviderCfg, titleOpts...)
166 if err != nil {
167 return nil, err
168 }
169
170 summarizeOpts := []provider.ProviderClientOption{
171 provider.WithModel(config.SelectedModelTypeLarge),
172 provider.WithSystemMessage(prompt.GetPrompt(prompt.PromptSummarizer, providerCfg.ID)),
173 }
174 summarizeProvider, err := provider.NewProvider(*providerCfg, summarizeOpts...)
175 if err != nil {
176 return nil, err
177 }
178
179 baseToolsFn := func() map[string]tools.BaseTool {
180 slog.Info("Initializing agent base tools", "agent", agentCfg.ID)
181 defer func() {
182 slog.Info("Initialized agent base tools", "agent", agentCfg.ID)
183 }()
184
185 // Base tools available to all agents
186 cwd := cfg.WorkingDir()
187 result := make(map[string]tools.BaseTool)
188 for _, tool := range []tools.BaseTool{
189 tools.NewBashTool(permissions, cwd),
190 tools.NewDownloadTool(permissions, cwd),
191 tools.NewEditTool(lspClients, permissions, history, cwd),
192 tools.NewMultiEditTool(lspClients, permissions, history, cwd),
193 tools.NewFetchTool(permissions, cwd),
194 tools.NewGlobTool(cwd),
195 tools.NewGrepTool(cwd),
196 tools.NewLsTool(permissions, cwd),
197 tools.NewSourcegraphTool(),
198 tools.NewViewTool(lspClients, permissions, cwd),
199 tools.NewWriteTool(lspClients, permissions, history, cwd),
200 } {
201 result[tool.Name()] = tool
202 }
203
204 if len(lspClients) > 0 {
205 diagnosticsTool := tools.NewDiagnosticsTool(lspClients)
206 result[diagnosticsTool.Name()] = diagnosticsTool
207 }
208
209 if agentTool != nil {
210 result[agentTool.Name()] = agentTool
211 }
212
213 return result
214 }
215 mcpToolsFn := func() map[string]tools.BaseTool {
216 slog.Info("Initializing agent mcp tools", "agent", agentCfg.ID)
217 defer func() {
218 slog.Info("Initialized agent mcp tools", "agent", agentCfg.ID)
219 }()
220
221 mcpToolsOnce.Do(func() {
222 doGetMCPTools(ctx, permissions, cfg)
223 })
224
225 result := make(map[string]tools.BaseTool)
226 for _, mcpTool := range mcpTools.Seq2() {
227 result[mcpTool.Name()] = mcpTool
228 }
229 return result
230 }
231
232 a := &agent{
233 globalCtx: ctx,
234 Broker: pubsub.NewBroker[AgentEvent](),
235 agentCfg: agentCfg,
236 provider: agentProvider,
237 providerID: string(providerCfg.ID),
238 messages: messages,
239 sessions: sessions,
240 titleProvider: titleProvider,
241 summarizeProvider: summarizeProvider,
242 summarizeProviderID: string(providerCfg.ID),
243 activeRequests: csync.NewMap[string, context.CancelFunc](),
244 mcpTools: csync.NewLazyMap(mcpToolsFn),
245 baseTools: csync.NewLazyMap(baseToolsFn),
246 promptQueue: csync.NewMap[string, []string](),
247 }
248 a.setupEvents()
249 return a, nil
250}
251
252func (a *agent) Model() catwalk.Model {
253 return *config.Get().GetModelByType(a.agentCfg.Model)
254}
255
256func (a *agent) Cancel(sessionID string) {
257 // Cancel regular requests
258 if cancel, ok := a.activeRequests.Take(sessionID); ok && cancel != nil {
259 slog.Info("Request cancellation initiated", "session_id", sessionID)
260 cancel()
261 }
262
263 // Also check for summarize requests
264 if cancel, ok := a.activeRequests.Take(sessionID + "-summarize"); ok && cancel != nil {
265 slog.Info("Summarize cancellation initiated", "session_id", sessionID)
266 cancel()
267 }
268
269 if a.QueuedPrompts(sessionID) > 0 {
270 slog.Info("Clearing queued prompts", "session_id", sessionID)
271 a.promptQueue.Del(sessionID)
272 }
273}
274
275func (a *agent) IsBusy() bool {
276 var busy bool
277 for cancelFunc := range a.activeRequests.Seq() {
278 if cancelFunc != nil {
279 busy = true
280 break
281 }
282 }
283 return busy
284}
285
286func (a *agent) IsSessionBusy(sessionID string) bool {
287 _, busy := a.activeRequests.Get(sessionID)
288 return busy
289}
290
291func (a *agent) QueuedPrompts(sessionID string) int {
292 l, ok := a.promptQueue.Get(sessionID)
293 if !ok {
294 return 0
295 }
296 return len(l)
297}
298
299func (a *agent) generateTitle(ctx context.Context, sessionID string, content string) error {
300 if content == "" {
301 return nil
302 }
303 if a.titleProvider == nil {
304 return nil
305 }
306 session, err := a.sessions.Get(ctx, sessionID)
307 if err != nil {
308 return err
309 }
310 parts := []message.ContentPart{message.TextContent{
311 Text: fmt.Sprintf("Generate a concise title for the following content:\n\n%s", content),
312 }}
313
314 // Use streaming approach like summarization
315 response := a.titleProvider.StreamResponse(
316 ctx,
317 []message.Message{
318 {
319 Role: message.User,
320 Parts: parts,
321 },
322 },
323 nil,
324 )
325
326 var finalResponse *provider.ProviderResponse
327 for r := range response {
328 if r.Error != nil {
329 return r.Error
330 }
331 finalResponse = r.Response
332 }
333
334 if finalResponse == nil {
335 return fmt.Errorf("no response received from title provider")
336 }
337
338 title := strings.TrimSpace(strings.ReplaceAll(finalResponse.Content, "\n", " "))
339 if title == "" {
340 return nil
341 }
342
343 session.Title = title
344 _, err = a.sessions.Save(ctx, session)
345 return err
346}
347
348func (a *agent) err(err error) AgentEvent {
349 return AgentEvent{
350 Type: AgentEventTypeError,
351 Error: err,
352 }
353}
354
355func (a *agent) Run(ctx context.Context, sessionID string, content string, attachments ...message.Attachment) (<-chan AgentEvent, error) {
356 if !a.Model().SupportsImages && attachments != nil {
357 attachments = nil
358 }
359 events := make(chan AgentEvent)
360 if a.IsSessionBusy(sessionID) {
361 existing, ok := a.promptQueue.Get(sessionID)
362 if !ok {
363 existing = []string{}
364 }
365 existing = append(existing, content)
366 a.promptQueue.Set(sessionID, existing)
367 return nil, nil
368 }
369
370 genCtx, cancel := context.WithCancel(ctx)
371
372 a.activeRequests.Set(sessionID, cancel)
373 go func() {
374 slog.Debug("Request started", "sessionID", sessionID)
375 defer log.RecoverPanic("agent.Run", func() {
376 events <- a.err(fmt.Errorf("panic while running the agent"))
377 })
378 var attachmentParts []message.ContentPart
379 for _, attachment := range attachments {
380 attachmentParts = append(attachmentParts, message.BinaryContent{Path: attachment.FilePath, MIMEType: attachment.MimeType, Data: attachment.Content})
381 }
382 result := a.processGeneration(genCtx, sessionID, content, attachmentParts)
383 if result.Error != nil && !errors.Is(result.Error, ErrRequestCancelled) && !errors.Is(result.Error, context.Canceled) {
384 slog.Error(result.Error.Error())
385 }
386 slog.Debug("Request completed", "sessionID", sessionID)
387 a.activeRequests.Del(sessionID)
388 cancel()
389 a.Publish(pubsub.CreatedEvent, result)
390 select {
391 case events <- result:
392 case <-genCtx.Done():
393 }
394 close(events)
395 }()
396 return events, nil
397}
398
399func (a *agent) processGeneration(ctx context.Context, sessionID, content string, attachmentParts []message.ContentPart) AgentEvent {
400 cfg := config.Get()
401 // List existing messages; if none, start title generation asynchronously.
402 msgs, err := a.messages.List(ctx, sessionID)
403 if err != nil {
404 return a.err(fmt.Errorf("failed to list messages: %w", err))
405 }
406 if len(msgs) == 0 {
407 go func() {
408 defer log.RecoverPanic("agent.Run", func() {
409 slog.Error("panic while generating title")
410 })
411 titleErr := a.generateTitle(context.Background(), sessionID, content)
412 if titleErr != nil && !errors.Is(titleErr, context.Canceled) && !errors.Is(titleErr, context.DeadlineExceeded) {
413 slog.Error("failed to generate title", "error", titleErr)
414 }
415 }()
416 }
417 session, err := a.sessions.Get(ctx, sessionID)
418 if err != nil {
419 return a.err(fmt.Errorf("failed to get session: %w", err))
420 }
421 if session.SummaryMessageID != "" {
422 summaryMsgInex := -1
423 for i, msg := range msgs {
424 if msg.ID == session.SummaryMessageID {
425 summaryMsgInex = i
426 break
427 }
428 }
429 if summaryMsgInex != -1 {
430 msgs = msgs[summaryMsgInex:]
431 msgs[0].Role = message.User
432 }
433 }
434
435 userMsg, err := a.createUserMessage(ctx, sessionID, content, attachmentParts)
436 if err != nil {
437 return a.err(fmt.Errorf("failed to create user message: %w", err))
438 }
439 // Append the new user message to the conversation history.
440 msgHistory := append(msgs, userMsg)
441
442 for {
443 // Check for cancellation before each iteration
444 select {
445 case <-ctx.Done():
446 return a.err(ctx.Err())
447 default:
448 // Continue processing
449 }
450 agentMessage, toolResults, err := a.streamAndHandleEvents(ctx, sessionID, msgHistory)
451 if err != nil {
452 if errors.Is(err, context.Canceled) {
453 agentMessage.AddFinish(message.FinishReasonCanceled, "Request cancelled", "")
454 a.messages.Update(context.Background(), agentMessage)
455 return a.err(ErrRequestCancelled)
456 }
457 return a.err(fmt.Errorf("failed to process events: %w", err))
458 }
459 if cfg.Options.Debug {
460 slog.Info("Result", "message", agentMessage.FinishReason(), "toolResults", toolResults)
461 }
462 if (agentMessage.FinishReason() == message.FinishReasonToolUse) && toolResults != nil {
463 // We are not done, we need to respond with the tool response
464 msgHistory = append(msgHistory, agentMessage, *toolResults)
465 // If there are queued prompts, process the next one
466 nextPrompt, ok := a.promptQueue.Take(sessionID)
467 if ok {
468 for _, prompt := range nextPrompt {
469 // Create a new user message for the queued prompt
470 userMsg, err := a.createUserMessage(ctx, sessionID, prompt, nil)
471 if err != nil {
472 return a.err(fmt.Errorf("failed to create user message for queued prompt: %w", err))
473 }
474 // Append the new user message to the conversation history
475 msgHistory = append(msgHistory, userMsg)
476 }
477 }
478
479 continue
480 } else if agentMessage.FinishReason() == message.FinishReasonEndTurn {
481 queuePrompts, ok := a.promptQueue.Take(sessionID)
482 if ok {
483 for _, prompt := range queuePrompts {
484 if prompt == "" {
485 continue
486 }
487 userMsg, err := a.createUserMessage(ctx, sessionID, prompt, nil)
488 if err != nil {
489 return a.err(fmt.Errorf("failed to create user message for queued prompt: %w", err))
490 }
491 msgHistory = append(msgHistory, userMsg)
492 }
493 continue
494 }
495 }
496 if agentMessage.FinishReason() == "" {
497 // Kujtim: could not track down where this is happening but this means its cancelled
498 agentMessage.AddFinish(message.FinishReasonCanceled, "Request cancelled", "")
499 _ = a.messages.Update(context.Background(), agentMessage)
500 return a.err(ErrRequestCancelled)
501 }
502 return AgentEvent{
503 Type: AgentEventTypeResponse,
504 Message: agentMessage,
505 Done: true,
506 }
507 }
508}
509
510func (a *agent) createUserMessage(ctx context.Context, sessionID, content string, attachmentParts []message.ContentPart) (message.Message, error) {
511 parts := []message.ContentPart{message.TextContent{Text: content}}
512 parts = append(parts, attachmentParts...)
513 return a.messages.Create(ctx, sessionID, message.CreateMessageParams{
514 Role: message.User,
515 Parts: parts,
516 })
517}
518
519func (a *agent) streamAndHandleEvents(ctx context.Context, sessionID string, msgHistory []message.Message) (message.Message, *message.Message, error) {
520 ctx = context.WithValue(ctx, tools.SessionIDContextKey, sessionID)
521
522 // Create the assistant message first so the spinner shows immediately
523 assistantMsg, err := a.messages.Create(ctx, sessionID, message.CreateMessageParams{
524 Role: message.Assistant,
525 Parts: []message.ContentPart{},
526 Model: a.Model().ID,
527 Provider: a.providerID,
528 })
529 if err != nil {
530 return assistantMsg, nil, fmt.Errorf("failed to create assistant message: %w", err)
531 }
532
533 // Now collect tools (which may block on MCP initialization)
534 eventChan := a.provider.StreamResponse(ctx, msgHistory, a.allTools())
535
536 // Add the session and message ID into the context if needed by tools.
537 ctx = context.WithValue(ctx, tools.MessageIDContextKey, assistantMsg.ID)
538
539 // Process each event in the stream.
540 for event := range eventChan {
541 if processErr := a.processEvent(ctx, sessionID, &assistantMsg, event); processErr != nil {
542 if errors.Is(processErr, context.Canceled) {
543 a.finishMessage(context.Background(), &assistantMsg, message.FinishReasonCanceled, "Request cancelled", "")
544 } else {
545 a.finishMessage(ctx, &assistantMsg, message.FinishReasonError, "API Error", processErr.Error())
546 }
547 return assistantMsg, nil, processErr
548 }
549 if ctx.Err() != nil {
550 a.finishMessage(context.Background(), &assistantMsg, message.FinishReasonCanceled, "Request cancelled", "")
551 return assistantMsg, nil, ctx.Err()
552 }
553 }
554
555 toolResults := make([]message.ToolResult, len(assistantMsg.ToolCalls()))
556 toolCalls := assistantMsg.ToolCalls()
557 for i, toolCall := range toolCalls {
558 select {
559 case <-ctx.Done():
560 a.finishMessage(context.Background(), &assistantMsg, message.FinishReasonCanceled, "Request cancelled", "")
561 // Make all future tool calls cancelled
562 for j := i; j < len(toolCalls); j++ {
563 toolResults[j] = message.ToolResult{
564 ToolCallID: toolCalls[j].ID,
565 Content: "Tool execution canceled by user",
566 IsError: true,
567 }
568 }
569 goto out
570 default:
571 // Continue processing
572 tool, ok := a.getToolByName(toolCall.Name)
573
574 // Tool not found
575 if !ok {
576 toolResults[i] = message.ToolResult{
577 ToolCallID: toolCall.ID,
578 Content: fmt.Sprintf("Tool not found: %s", toolCall.Name),
579 IsError: true,
580 }
581 continue
582 }
583
584 // Run tool in goroutine to allow cancellation
585 type toolExecResult struct {
586 response tools.ToolResponse
587 err error
588 }
589 resultChan := make(chan toolExecResult, 1)
590
591 go func() {
592 response, err := tool.Run(ctx, tools.ToolCall{
593 ID: toolCall.ID,
594 Name: toolCall.Name,
595 Input: toolCall.Input,
596 })
597 resultChan <- toolExecResult{response: response, err: err}
598 }()
599
600 var toolResponse tools.ToolResponse
601 var toolErr error
602
603 select {
604 case <-ctx.Done():
605 a.finishMessage(context.Background(), &assistantMsg, message.FinishReasonCanceled, "Request cancelled", "")
606 // Mark remaining tool calls as cancelled
607 for j := i; j < len(toolCalls); j++ {
608 toolResults[j] = message.ToolResult{
609 ToolCallID: toolCalls[j].ID,
610 Content: "Tool execution canceled by user",
611 IsError: true,
612 }
613 }
614 goto out
615 case result := <-resultChan:
616 toolResponse = result.response
617 toolErr = result.err
618 }
619
620 if toolErr != nil {
621 slog.Error("Tool execution error", "toolCall", toolCall.ID, "error", toolErr)
622 if errors.Is(toolErr, permission.ErrorPermissionDenied) {
623 toolResults[i] = message.ToolResult{
624 ToolCallID: toolCall.ID,
625 Content: "Permission denied",
626 IsError: true,
627 }
628 for j := i + 1; j < len(toolCalls); j++ {
629 toolResults[j] = message.ToolResult{
630 ToolCallID: toolCalls[j].ID,
631 Content: "Tool execution canceled by user",
632 IsError: true,
633 }
634 }
635 a.finishMessage(ctx, &assistantMsg, message.FinishReasonPermissionDenied, "Permission denied", "")
636 break
637 }
638 }
639 toolResults[i] = message.ToolResult{
640 ToolCallID: toolCall.ID,
641 Content: toolResponse.Content,
642 Metadata: toolResponse.Metadata,
643 IsError: toolResponse.IsError,
644 }
645 }
646 }
647out:
648 if len(toolResults) == 0 {
649 return assistantMsg, nil, nil
650 }
651 parts := make([]message.ContentPart, 0)
652 for _, tr := range toolResults {
653 parts = append(parts, tr)
654 }
655 msg, err := a.messages.Create(context.Background(), assistantMsg.SessionID, message.CreateMessageParams{
656 Role: message.Tool,
657 Parts: parts,
658 Provider: a.providerID,
659 })
660 if err != nil {
661 return assistantMsg, nil, fmt.Errorf("failed to create cancelled tool message: %w", err)
662 }
663
664 return assistantMsg, &msg, err
665}
666
667func (a *agent) finishMessage(ctx context.Context, msg *message.Message, finishReason message.FinishReason, message, details string) {
668 msg.AddFinish(finishReason, message, details)
669 _ = a.messages.Update(ctx, *msg)
670}
671
672func (a *agent) processEvent(ctx context.Context, sessionID string, assistantMsg *message.Message, event provider.ProviderEvent) error {
673 select {
674 case <-ctx.Done():
675 return ctx.Err()
676 default:
677 // Continue processing.
678 }
679
680 switch event.Type {
681 case provider.EventThinkingDelta:
682 assistantMsg.AppendReasoningContent(event.Thinking)
683 return a.messages.Update(ctx, *assistantMsg)
684 case provider.EventSignatureDelta:
685 assistantMsg.AppendReasoningSignature(event.Signature)
686 return a.messages.Update(ctx, *assistantMsg)
687 case provider.EventContentDelta:
688 assistantMsg.FinishThinking()
689 assistantMsg.AppendContent(event.Content)
690 return a.messages.Update(ctx, *assistantMsg)
691 case provider.EventToolUseStart:
692 assistantMsg.FinishThinking()
693 slog.Info("Tool call started", "toolCall", event.ToolCall)
694 assistantMsg.AddToolCall(*event.ToolCall)
695 return a.messages.Update(ctx, *assistantMsg)
696 case provider.EventToolUseDelta:
697 assistantMsg.AppendToolCallInput(event.ToolCall.ID, event.ToolCall.Input)
698 return a.messages.Update(ctx, *assistantMsg)
699 case provider.EventToolUseStop:
700 slog.Info("Finished tool call", "toolCall", event.ToolCall)
701 assistantMsg.FinishToolCall(event.ToolCall.ID)
702 return a.messages.Update(ctx, *assistantMsg)
703 case provider.EventError:
704 return event.Error
705 case provider.EventComplete:
706 assistantMsg.FinishThinking()
707 assistantMsg.SetToolCalls(event.Response.ToolCalls)
708 assistantMsg.AddFinish(event.Response.FinishReason, "", "")
709 if err := a.messages.Update(ctx, *assistantMsg); err != nil {
710 return fmt.Errorf("failed to update message: %w", err)
711 }
712 return a.TrackUsage(ctx, sessionID, a.Model(), event.Response.Usage)
713 }
714
715 return nil
716}
717
718func (a *agent) TrackUsage(ctx context.Context, sessionID string, model catwalk.Model, usage provider.TokenUsage) error {
719 sess, err := a.sessions.Get(ctx, sessionID)
720 if err != nil {
721 return fmt.Errorf("failed to get session: %w", err)
722 }
723
724 cost := model.CostPer1MInCached/1e6*float64(usage.CacheCreationTokens) +
725 model.CostPer1MOutCached/1e6*float64(usage.CacheReadTokens) +
726 model.CostPer1MIn/1e6*float64(usage.InputTokens) +
727 model.CostPer1MOut/1e6*float64(usage.OutputTokens)
728
729 sess.Cost += cost
730 sess.CompletionTokens = usage.OutputTokens + usage.CacheReadTokens
731 sess.PromptTokens = usage.InputTokens + usage.CacheCreationTokens
732
733 _, err = a.sessions.Save(ctx, sess)
734 if err != nil {
735 return fmt.Errorf("failed to save session: %w", err)
736 }
737 return nil
738}
739
740func (a *agent) Summarize(ctx context.Context, sessionID string) error {
741 if a.summarizeProvider == nil {
742 return fmt.Errorf("summarize provider not available")
743 }
744
745 // Check if session is busy
746 if a.IsSessionBusy(sessionID) {
747 return ErrSessionBusy
748 }
749
750 // Create a new context with cancellation
751 summarizeCtx, cancel := context.WithCancel(ctx)
752
753 // Store the cancel function in activeRequests to allow cancellation
754 a.activeRequests.Set(sessionID+"-summarize", cancel)
755
756 go func() {
757 defer a.activeRequests.Del(sessionID + "-summarize")
758 defer cancel()
759 event := AgentEvent{
760 Type: AgentEventTypeSummarize,
761 Progress: "Starting summarization...",
762 }
763
764 a.Publish(pubsub.CreatedEvent, event)
765 // Get all messages from the session
766 msgs, err := a.messages.List(summarizeCtx, sessionID)
767 if err != nil {
768 event = AgentEvent{
769 Type: AgentEventTypeError,
770 Error: fmt.Errorf("failed to list messages: %w", err),
771 Done: true,
772 }
773 a.Publish(pubsub.CreatedEvent, event)
774 return
775 }
776 summarizeCtx = context.WithValue(summarizeCtx, tools.SessionIDContextKey, sessionID)
777
778 if len(msgs) == 0 {
779 event = AgentEvent{
780 Type: AgentEventTypeError,
781 Error: fmt.Errorf("no messages to summarize"),
782 Done: true,
783 }
784 a.Publish(pubsub.CreatedEvent, event)
785 return
786 }
787
788 event = AgentEvent{
789 Type: AgentEventTypeSummarize,
790 Progress: "Analyzing conversation...",
791 }
792 a.Publish(pubsub.CreatedEvent, event)
793
794 // Add a system message to guide the summarization
795 summarizePrompt := "Provide a detailed but concise summary of our conversation above. Focus on information that would be helpful for continuing the conversation, including what we did, what we're doing, which files we're working on, and what we're going to do next."
796
797 // Create a new message with the summarize prompt
798 promptMsg := message.Message{
799 Role: message.User,
800 Parts: []message.ContentPart{message.TextContent{Text: summarizePrompt}},
801 }
802
803 // Append the prompt to the messages
804 msgsWithPrompt := append(msgs, promptMsg)
805
806 event = AgentEvent{
807 Type: AgentEventTypeSummarize,
808 Progress: "Generating summary...",
809 }
810
811 a.Publish(pubsub.CreatedEvent, event)
812
813 // Send the messages to the summarize provider
814 response := a.summarizeProvider.StreamResponse(
815 summarizeCtx,
816 msgsWithPrompt,
817 nil,
818 )
819 var finalResponse *provider.ProviderResponse
820 for r := range response {
821 if r.Error != nil {
822 event = AgentEvent{
823 Type: AgentEventTypeError,
824 Error: fmt.Errorf("failed to summarize: %w", err),
825 Done: true,
826 }
827 a.Publish(pubsub.CreatedEvent, event)
828 return
829 }
830 finalResponse = r.Response
831 }
832
833 summary := strings.TrimSpace(finalResponse.Content)
834 if summary == "" {
835 event = AgentEvent{
836 Type: AgentEventTypeError,
837 Error: fmt.Errorf("empty summary returned"),
838 Done: true,
839 }
840 a.Publish(pubsub.CreatedEvent, event)
841 return
842 }
843 shell := shell.GetPersistentShell(config.Get().WorkingDir())
844 summary += "\n\n**Current working directory of the persistent shell**\n\n" + shell.GetWorkingDir()
845 event = AgentEvent{
846 Type: AgentEventTypeSummarize,
847 Progress: "Creating new session...",
848 }
849
850 a.Publish(pubsub.CreatedEvent, event)
851 oldSession, err := a.sessions.Get(summarizeCtx, sessionID)
852 if err != nil {
853 event = AgentEvent{
854 Type: AgentEventTypeError,
855 Error: fmt.Errorf("failed to get session: %w", err),
856 Done: true,
857 }
858
859 a.Publish(pubsub.CreatedEvent, event)
860 return
861 }
862 // Create a message in the new session with the summary
863 msg, err := a.messages.Create(summarizeCtx, oldSession.ID, message.CreateMessageParams{
864 Role: message.Assistant,
865 Parts: []message.ContentPart{
866 message.TextContent{Text: summary},
867 message.Finish{
868 Reason: message.FinishReasonEndTurn,
869 Time: time.Now().Unix(),
870 },
871 },
872 Model: a.summarizeProvider.Model().ID,
873 Provider: a.summarizeProviderID,
874 })
875 if err != nil {
876 event = AgentEvent{
877 Type: AgentEventTypeError,
878 Error: fmt.Errorf("failed to create summary message: %w", err),
879 Done: true,
880 }
881
882 a.Publish(pubsub.CreatedEvent, event)
883 return
884 }
885 oldSession.SummaryMessageID = msg.ID
886 oldSession.CompletionTokens = finalResponse.Usage.OutputTokens
887 oldSession.PromptTokens = 0
888 model := a.summarizeProvider.Model()
889 usage := finalResponse.Usage
890 cost := model.CostPer1MInCached/1e6*float64(usage.CacheCreationTokens) +
891 model.CostPer1MOutCached/1e6*float64(usage.CacheReadTokens) +
892 model.CostPer1MIn/1e6*float64(usage.InputTokens) +
893 model.CostPer1MOut/1e6*float64(usage.OutputTokens)
894 oldSession.Cost += cost
895 _, err = a.sessions.Save(summarizeCtx, oldSession)
896 if err != nil {
897 event = AgentEvent{
898 Type: AgentEventTypeError,
899 Error: fmt.Errorf("failed to save session: %w", err),
900 Done: true,
901 }
902 a.Publish(pubsub.CreatedEvent, event)
903 }
904
905 event = AgentEvent{
906 Type: AgentEventTypeSummarize,
907 SessionID: oldSession.ID,
908 Progress: "Summary complete",
909 Done: true,
910 }
911 a.Publish(pubsub.CreatedEvent, event)
912 // Send final success event with the new session ID
913 }()
914
915 return nil
916}
917
918func (a *agent) ClearQueue(sessionID string) {
919 if a.QueuedPrompts(sessionID) > 0 {
920 slog.Info("Clearing queued prompts", "session_id", sessionID)
921 a.promptQueue.Del(sessionID)
922 }
923}
924
925func (a *agent) CancelAll() {
926 if !a.IsBusy() {
927 return
928 }
929 for key := range a.activeRequests.Seq2() {
930 a.Cancel(key) // key is sessionID
931 }
932
933 for _, cleanup := range a.cleanupFuncs {
934 if cleanup != nil {
935 cleanup()
936 }
937 }
938
939 timeout := time.After(5 * time.Second)
940 for a.IsBusy() {
941 select {
942 case <-timeout:
943 return
944 default:
945 time.Sleep(200 * time.Millisecond)
946 }
947 }
948}
949
950func (a *agent) UpdateModel() error {
951 cfg := config.Get()
952
953 // Get current provider configuration
954 currentProviderCfg := cfg.GetProviderForModel(a.agentCfg.Model)
955 if currentProviderCfg == nil || currentProviderCfg.ID == "" {
956 return fmt.Errorf("provider for agent %s not found in config", a.agentCfg.Name)
957 }
958
959 // Check if provider has changed
960 if string(currentProviderCfg.ID) != a.providerID {
961 // Provider changed, need to recreate the main provider
962 model := cfg.GetModelByType(a.agentCfg.Model)
963 if model.ID == "" {
964 return fmt.Errorf("model not found for agent %s", a.agentCfg.Name)
965 }
966
967 promptID := agentPromptMap[a.agentCfg.ID]
968 if promptID == "" {
969 promptID = prompt.PromptDefault
970 }
971
972 opts := []provider.ProviderClientOption{
973 provider.WithModel(a.agentCfg.Model),
974 provider.WithSystemMessage(prompt.GetPrompt(promptID, currentProviderCfg.ID, cfg.Options.ContextPaths...)),
975 }
976
977 newProvider, err := provider.NewProvider(*currentProviderCfg, opts...)
978 if err != nil {
979 return fmt.Errorf("failed to create new provider: %w", err)
980 }
981
982 // Update the provider and provider ID
983 a.provider = newProvider
984 a.providerID = string(currentProviderCfg.ID)
985 }
986
987 // Check if providers have changed for title (small) and summarize (large)
988 smallModelCfg := cfg.Models[config.SelectedModelTypeSmall]
989 var smallModelProviderCfg config.ProviderConfig
990 for p := range cfg.Providers.Seq() {
991 if p.ID == smallModelCfg.Provider {
992 smallModelProviderCfg = p
993 break
994 }
995 }
996 if smallModelProviderCfg.ID == "" {
997 return fmt.Errorf("provider %s not found in config", smallModelCfg.Provider)
998 }
999
1000 largeModelCfg := cfg.Models[config.SelectedModelTypeLarge]
1001 var largeModelProviderCfg config.ProviderConfig
1002 for p := range cfg.Providers.Seq() {
1003 if p.ID == largeModelCfg.Provider {
1004 largeModelProviderCfg = p
1005 break
1006 }
1007 }
1008 if largeModelProviderCfg.ID == "" {
1009 return fmt.Errorf("provider %s not found in config", largeModelCfg.Provider)
1010 }
1011
1012 // Recreate title provider
1013 titleOpts := []provider.ProviderClientOption{
1014 provider.WithModel(config.SelectedModelTypeSmall),
1015 provider.WithSystemMessage(prompt.GetPrompt(prompt.PromptTitle, smallModelProviderCfg.ID)),
1016 provider.WithMaxTokens(40),
1017 }
1018 newTitleProvider, err := provider.NewProvider(smallModelProviderCfg, titleOpts...)
1019 if err != nil {
1020 return fmt.Errorf("failed to create new title provider: %w", err)
1021 }
1022 a.titleProvider = newTitleProvider
1023
1024 // Recreate summarize provider if provider changed (now large model)
1025 if string(largeModelProviderCfg.ID) != a.summarizeProviderID {
1026 largeModel := cfg.GetModelByType(config.SelectedModelTypeLarge)
1027 if largeModel == nil {
1028 return fmt.Errorf("model %s not found in provider %s", largeModelCfg.Model, largeModelProviderCfg.ID)
1029 }
1030 summarizeOpts := []provider.ProviderClientOption{
1031 provider.WithModel(config.SelectedModelTypeLarge),
1032 provider.WithSystemMessage(prompt.GetPrompt(prompt.PromptSummarizer, largeModelProviderCfg.ID)),
1033 }
1034 newSummarizeProvider, err := provider.NewProvider(largeModelProviderCfg, summarizeOpts...)
1035 if err != nil {
1036 return fmt.Errorf("failed to create new summarize provider: %w", err)
1037 }
1038 a.summarizeProvider = newSummarizeProvider
1039 a.summarizeProviderID = string(largeModelProviderCfg.ID)
1040 }
1041
1042 return nil
1043}
1044
1045func (a *agent) allTools() []tools.BaseTool {
1046 result := slices.Collect(a.baseTools.Seq())
1047 result = slices.AppendSeq(result, a.mcpTools.Seq())
1048 return result
1049}
1050
1051func (a *agent) getToolByName(name string) (tools.BaseTool, bool) {
1052 tool, ok := a.baseTools.Get(name)
1053 if !ok {
1054 tool, ok = a.mcpTools.Get(name)
1055 }
1056 return tool, ok
1057}
1058
1059func (a *agent) setupEvents() {
1060 ctx, cancel := context.WithCancel(a.globalCtx)
1061
1062 go func() {
1063 for event := range SubscribeMCPEvents(ctx) {
1064 switch event.Payload.Type {
1065 case MCPEventToolsListChanged:
1066 name := event.Payload.Name
1067 c, ok := mcpClients.Get(name)
1068 if !ok {
1069 slog.Warn("MCP client not found for tools update", "name", name)
1070 continue
1071 }
1072 tools := getTools(ctx, name, c)
1073 updateMcpTools(name, tools)
1074 // Update the lazy map with the new tools
1075 a.mcpTools = csync.NewMapFromSeq(mcpTools.Seq2())
1076 updateMCPState(name, MCPStateConnected, nil, c, a.mcpTools.Len())
1077 default:
1078 continue
1079 }
1080 }
1081 }()
1082
1083 a.cleanupFuncs = append(a.cleanupFuncs, func() {
1084 cancel()
1085 })
1086}