1package agent
2
3import (
4 "context"
5 "errors"
6 "fmt"
7 "log/slog"
8 "maps"
9 "slices"
10 "strings"
11 "time"
12
13 "github.com/charmbracelet/catwalk/pkg/catwalk"
14 "github.com/charmbracelet/crush/internal/config"
15 "github.com/charmbracelet/crush/internal/csync"
16 "github.com/charmbracelet/crush/internal/history"
17 "github.com/charmbracelet/crush/internal/llm/prompt"
18 "github.com/charmbracelet/crush/internal/llm/provider"
19 "github.com/charmbracelet/crush/internal/llm/tools"
20 "github.com/charmbracelet/crush/internal/log"
21 "github.com/charmbracelet/crush/internal/lsp"
22 "github.com/charmbracelet/crush/internal/message"
23 "github.com/charmbracelet/crush/internal/permission"
24 "github.com/charmbracelet/crush/internal/pubsub"
25 "github.com/charmbracelet/crush/internal/session"
26 "github.com/charmbracelet/crush/internal/shell"
27)
28
29// Common errors
30var (
31 ErrRequestCancelled = errors.New("request canceled by user")
32 ErrSessionBusy = errors.New("session is currently processing another request")
33)
34
// AgentEventType identifies the kind of an AgentEvent.
type AgentEventType string

// Known AgentEventType values.
const (
	// AgentEventTypeError reports a failure; see AgentEvent.Error.
	AgentEventTypeError AgentEventType = "error"
	// AgentEventTypeResponse carries the agent's final response message.
	AgentEventTypeResponse AgentEventType = "response"
	// AgentEventTypeSummarize reports summarization progress or completion.
	AgentEventTypeSummarize AgentEventType = "summarize"
)
42
43type AgentEvent struct {
44 Type AgentEventType
45 Message message.Message
46 Error error
47
48 // When summarizing
49 SessionID string
50 Progress string
51 Done bool
52}
53
54type Service interface {
55 pubsub.Suscriber[AgentEvent]
56 Model() catwalk.Model
57 Run(ctx context.Context, sessionID string, content string, attachments ...message.Attachment) (<-chan AgentEvent, error)
58 Cancel(sessionID string)
59 CancelAll()
60 IsSessionBusy(sessionID string) bool
61 IsBusy() bool
62 Summarize(ctx context.Context, sessionID string) error
63 UpdateModel() error
64 QueuedPrompts(sessionID string) int
65 ClearQueue(sessionID string)
66}
67
68type agent struct {
69 globalCtx context.Context
70 cleanupFuncs []func()
71
72 *pubsub.Broker[AgentEvent]
73 agentCfg config.Agent
74 sessions session.Service
75 messages message.Service
76
77 baseTools *csync.Map[string, tools.BaseTool]
78 mcpTools *csync.Map[string, tools.BaseTool]
79
80 provider provider.Provider
81 providerID string
82
83 titleProvider provider.Provider
84 summarizeProvider provider.Provider
85 summarizeProviderID string
86
87 activeRequests *csync.Map[string, context.CancelFunc]
88
89 promptQueue *csync.Map[string, []string]
90}
91
92var agentPromptMap = map[string]prompt.PromptID{
93 "coder": prompt.PromptCoder,
94 "task": prompt.PromptTask,
95}
96
97func NewAgent(
98 ctx context.Context,
99 agentCfg config.Agent,
100 // These services are needed in the tools
101 permissions permission.Service,
102 sessions session.Service,
103 messages message.Service,
104 history history.Service,
105 lspClients map[string]*lsp.Client,
106) (Service, error) {
107 cfg := config.Get()
108
109 var agentTool tools.BaseTool
110 if agentCfg.ID == "coder" {
111 taskAgentCfg := config.Get().Agents["task"]
112 if taskAgentCfg.ID == "" {
113 return nil, fmt.Errorf("task agent not found in config")
114 }
115 taskAgent, err := NewAgent(ctx, taskAgentCfg, permissions, sessions, messages, history, lspClients)
116 if err != nil {
117 return nil, fmt.Errorf("failed to create task agent: %w", err)
118 }
119
120 agentTool = NewAgentTool(taskAgent, sessions, messages)
121 }
122
123 providerCfg := config.Get().GetProviderForModel(agentCfg.Model)
124 if providerCfg == nil {
125 return nil, fmt.Errorf("provider for agent %s not found in config", agentCfg.Name)
126 }
127 model := config.Get().GetModelByType(agentCfg.Model)
128
129 if model == nil {
130 return nil, fmt.Errorf("model not found for agent %s", agentCfg.Name)
131 }
132
133 promptID := agentPromptMap[agentCfg.ID]
134 if promptID == "" {
135 promptID = prompt.PromptDefault
136 }
137 opts := []provider.ProviderClientOption{
138 provider.WithModel(agentCfg.Model),
139 provider.WithSystemMessage(prompt.GetPrompt(promptID, providerCfg.ID, config.Get().Options.ContextPaths...)),
140 }
141 agentProvider, err := provider.NewProvider(*providerCfg, opts...)
142 if err != nil {
143 return nil, err
144 }
145
146 smallModelCfg := cfg.Models[config.SelectedModelTypeSmall]
147 var smallModelProviderCfg *config.ProviderConfig
148 if smallModelCfg.Provider == providerCfg.ID {
149 smallModelProviderCfg = providerCfg
150 } else {
151 smallModelProviderCfg = cfg.GetProviderForModel(config.SelectedModelTypeSmall)
152
153 if smallModelProviderCfg.ID == "" {
154 return nil, fmt.Errorf("provider %s not found in config", smallModelCfg.Provider)
155 }
156 }
157 smallModel := cfg.GetModelByType(config.SelectedModelTypeSmall)
158 if smallModel.ID == "" {
159 return nil, fmt.Errorf("model %s not found in provider %s", smallModelCfg.Model, smallModelProviderCfg.ID)
160 }
161
162 titleOpts := []provider.ProviderClientOption{
163 provider.WithModel(config.SelectedModelTypeSmall),
164 provider.WithSystemMessage(prompt.GetPrompt(prompt.PromptTitle, smallModelProviderCfg.ID)),
165 }
166 titleProvider, err := provider.NewProvider(*smallModelProviderCfg, titleOpts...)
167 if err != nil {
168 return nil, err
169 }
170
171 summarizeOpts := []provider.ProviderClientOption{
172 provider.WithModel(config.SelectedModelTypeLarge),
173 provider.WithSystemMessage(prompt.GetPrompt(prompt.PromptSummarizer, providerCfg.ID)),
174 }
175 summarizeProvider, err := provider.NewProvider(*providerCfg, summarizeOpts...)
176 if err != nil {
177 return nil, err
178 }
179
180 baseToolsFn := func() map[string]tools.BaseTool {
181 slog.Info("Initializing agent base tools", "agent", agentCfg.ID)
182 defer func() {
183 slog.Info("Initialized agent base tools", "agent", agentCfg.ID)
184 }()
185
186 // Base tools available to all agents
187 cwd := cfg.WorkingDir()
188 result := make(map[string]tools.BaseTool)
189 for _, tool := range []tools.BaseTool{
190 tools.NewBashTool(permissions, cwd),
191 tools.NewDownloadTool(permissions, cwd),
192 tools.NewEditTool(lspClients, permissions, history, cwd),
193 tools.NewMultiEditTool(lspClients, permissions, history, cwd),
194 tools.NewFetchTool(permissions, cwd),
195 tools.NewGlobTool(cwd),
196 tools.NewGrepTool(cwd),
197 tools.NewLsTool(permissions, cwd),
198 tools.NewSourcegraphTool(),
199 tools.NewViewTool(lspClients, permissions, cwd),
200 tools.NewWriteTool(lspClients, permissions, history, cwd),
201 } {
202 result[tool.Name()] = tool
203 }
204
205 if len(lspClients) > 0 {
206 diagnosticsTool := tools.NewDiagnosticsTool(lspClients)
207 result[diagnosticsTool.Name()] = diagnosticsTool
208 }
209
210 if agentTool != nil {
211 result[agentTool.Name()] = agentTool
212 }
213
214 return result
215 }
216 mcpToolsFn := func() map[string]tools.BaseTool {
217 slog.Info("Initializing agent mcp tools", "agent", agentCfg.ID)
218 defer func() {
219 slog.Info("Initialized agent mcp tools", "agent", agentCfg.ID)
220 }()
221
222 mcpToolsOnce.Do(func() {
223 doGetMCPTools(ctx, permissions, cfg)
224 })
225
226 result := make(map[string]tools.BaseTool)
227 for _, mcpTool := range mcpTools.Seq2() {
228 result[mcpTool.Name()] = mcpTool
229 }
230 return result
231 }
232
233 a := &agent{
234 globalCtx: ctx,
235 Broker: pubsub.NewBroker[AgentEvent](),
236 agentCfg: agentCfg,
237 provider: agentProvider,
238 providerID: string(providerCfg.ID),
239 messages: messages,
240 sessions: sessions,
241 titleProvider: titleProvider,
242 summarizeProvider: summarizeProvider,
243 summarizeProviderID: string(providerCfg.ID),
244 activeRequests: csync.NewMap[string, context.CancelFunc](),
245 mcpTools: csync.NewLazyMap(mcpToolsFn),
246 baseTools: csync.NewLazyMap(baseToolsFn),
247 promptQueue: csync.NewMap[string, []string](),
248 }
249 a.setupEvents()
250 return a, nil
251}
252
253func (a *agent) Model() catwalk.Model {
254 return *config.Get().GetModelByType(a.agentCfg.Model)
255}
256
257func (a *agent) Cancel(sessionID string) {
258 // Cancel regular requests
259 if cancel, ok := a.activeRequests.Take(sessionID); ok && cancel != nil {
260 slog.Info("Request cancellation initiated", "session_id", sessionID)
261 cancel()
262 }
263
264 // Also check for summarize requests
265 if cancel, ok := a.activeRequests.Take(sessionID + "-summarize"); ok && cancel != nil {
266 slog.Info("Summarize cancellation initiated", "session_id", sessionID)
267 cancel()
268 }
269
270 if a.QueuedPrompts(sessionID) > 0 {
271 slog.Info("Clearing queued prompts", "session_id", sessionID)
272 a.promptQueue.Del(sessionID)
273 }
274}
275
276func (a *agent) IsBusy() bool {
277 var busy bool
278 for cancelFunc := range a.activeRequests.Seq() {
279 if cancelFunc != nil {
280 busy = true
281 break
282 }
283 }
284 return busy
285}
286
287func (a *agent) IsSessionBusy(sessionID string) bool {
288 _, busy := a.activeRequests.Get(sessionID)
289 return busy
290}
291
292func (a *agent) QueuedPrompts(sessionID string) int {
293 l, ok := a.promptQueue.Get(sessionID)
294 if !ok {
295 return 0
296 }
297 return len(l)
298}
299
300func (a *agent) generateTitle(ctx context.Context, sessionID string, content string) error {
301 if content == "" {
302 return nil
303 }
304 if a.titleProvider == nil {
305 return nil
306 }
307 session, err := a.sessions.Get(ctx, sessionID)
308 if err != nil {
309 return err
310 }
311 parts := []message.ContentPart{message.TextContent{
312 Text: fmt.Sprintf("Generate a concise title for the following content:\n\n%s", content),
313 }}
314
315 // Use streaming approach like summarization
316 response := a.titleProvider.StreamResponse(
317 ctx,
318 []message.Message{
319 {
320 Role: message.User,
321 Parts: parts,
322 },
323 },
324 nil,
325 )
326
327 var finalResponse *provider.ProviderResponse
328 for r := range response {
329 if r.Error != nil {
330 return r.Error
331 }
332 finalResponse = r.Response
333 }
334
335 if finalResponse == nil {
336 return fmt.Errorf("no response received from title provider")
337 }
338
339 title := strings.TrimSpace(strings.ReplaceAll(finalResponse.Content, "\n", " "))
340 if title == "" {
341 return nil
342 }
343
344 session.Title = title
345 _, err = a.sessions.Save(ctx, session)
346 return err
347}
348
349func (a *agent) err(err error) AgentEvent {
350 return AgentEvent{
351 Type: AgentEventTypeError,
352 Error: err,
353 }
354}
355
356func (a *agent) Run(ctx context.Context, sessionID string, content string, attachments ...message.Attachment) (<-chan AgentEvent, error) {
357 if !a.Model().SupportsImages && attachments != nil {
358 attachments = nil
359 }
360 events := make(chan AgentEvent)
361 if a.IsSessionBusy(sessionID) {
362 existing, ok := a.promptQueue.Get(sessionID)
363 if !ok {
364 existing = []string{}
365 }
366 existing = append(existing, content)
367 a.promptQueue.Set(sessionID, existing)
368 return nil, nil
369 }
370
371 genCtx, cancel := context.WithCancel(ctx)
372
373 a.activeRequests.Set(sessionID, cancel)
374 go func() {
375 slog.Debug("Request started", "sessionID", sessionID)
376 defer log.RecoverPanic("agent.Run", func() {
377 events <- a.err(fmt.Errorf("panic while running the agent"))
378 })
379 var attachmentParts []message.ContentPart
380 for _, attachment := range attachments {
381 attachmentParts = append(attachmentParts, message.BinaryContent{Path: attachment.FilePath, MIMEType: attachment.MimeType, Data: attachment.Content})
382 }
383 result := a.processGeneration(genCtx, sessionID, content, attachmentParts)
384 if result.Error != nil && !errors.Is(result.Error, ErrRequestCancelled) && !errors.Is(result.Error, context.Canceled) {
385 slog.Error(result.Error.Error())
386 }
387 slog.Debug("Request completed", "sessionID", sessionID)
388 a.activeRequests.Del(sessionID)
389 cancel()
390 a.Publish(pubsub.CreatedEvent, result)
391 select {
392 case events <- result:
393 case <-genCtx.Done():
394 }
395 close(events)
396 }()
397 return events, nil
398}
399
400func (a *agent) processGeneration(ctx context.Context, sessionID, content string, attachmentParts []message.ContentPart) AgentEvent {
401 cfg := config.Get()
402 // List existing messages; if none, start title generation asynchronously.
403 msgs, err := a.messages.List(ctx, sessionID)
404 if err != nil {
405 return a.err(fmt.Errorf("failed to list messages: %w", err))
406 }
407 if len(msgs) == 0 {
408 go func() {
409 defer log.RecoverPanic("agent.Run", func() {
410 slog.Error("panic while generating title")
411 })
412 titleErr := a.generateTitle(context.Background(), sessionID, content)
413 if titleErr != nil && !errors.Is(titleErr, context.Canceled) && !errors.Is(titleErr, context.DeadlineExceeded) {
414 slog.Error("failed to generate title", "error", titleErr)
415 }
416 }()
417 }
418 session, err := a.sessions.Get(ctx, sessionID)
419 if err != nil {
420 return a.err(fmt.Errorf("failed to get session: %w", err))
421 }
422 if session.SummaryMessageID != "" {
423 summaryMsgInex := -1
424 for i, msg := range msgs {
425 if msg.ID == session.SummaryMessageID {
426 summaryMsgInex = i
427 break
428 }
429 }
430 if summaryMsgInex != -1 {
431 msgs = msgs[summaryMsgInex:]
432 msgs[0].Role = message.User
433 }
434 }
435
436 userMsg, err := a.createUserMessage(ctx, sessionID, content, attachmentParts)
437 if err != nil {
438 return a.err(fmt.Errorf("failed to create user message: %w", err))
439 }
440 // Append the new user message to the conversation history.
441 msgHistory := append(msgs, userMsg)
442
443 for {
444 // Check for cancellation before each iteration
445 select {
446 case <-ctx.Done():
447 return a.err(ctx.Err())
448 default:
449 // Continue processing
450 }
451 agentMessage, toolResults, err := a.streamAndHandleEvents(ctx, sessionID, msgHistory)
452 if err != nil {
453 if errors.Is(err, context.Canceled) {
454 agentMessage.AddFinish(message.FinishReasonCanceled, "Request cancelled", "")
455 a.messages.Update(context.Background(), agentMessage)
456 return a.err(ErrRequestCancelled)
457 }
458 return a.err(fmt.Errorf("failed to process events: %w", err))
459 }
460 if cfg.Options.Debug {
461 slog.Info("Result", "message", agentMessage.FinishReason(), "toolResults", toolResults)
462 }
463 if (agentMessage.FinishReason() == message.FinishReasonToolUse) && toolResults != nil {
464 // We are not done, we need to respond with the tool response
465 msgHistory = append(msgHistory, agentMessage, *toolResults)
466 // If there are queued prompts, process the next one
467 nextPrompt, ok := a.promptQueue.Take(sessionID)
468 if ok {
469 for _, prompt := range nextPrompt {
470 // Create a new user message for the queued prompt
471 userMsg, err := a.createUserMessage(ctx, sessionID, prompt, nil)
472 if err != nil {
473 return a.err(fmt.Errorf("failed to create user message for queued prompt: %w", err))
474 }
475 // Append the new user message to the conversation history
476 msgHistory = append(msgHistory, userMsg)
477 }
478 }
479
480 continue
481 } else if agentMessage.FinishReason() == message.FinishReasonEndTurn {
482 queuePrompts, ok := a.promptQueue.Take(sessionID)
483 if ok {
484 for _, prompt := range queuePrompts {
485 if prompt == "" {
486 continue
487 }
488 userMsg, err := a.createUserMessage(ctx, sessionID, prompt, nil)
489 if err != nil {
490 return a.err(fmt.Errorf("failed to create user message for queued prompt: %w", err))
491 }
492 msgHistory = append(msgHistory, userMsg)
493 }
494 continue
495 }
496 }
497 if agentMessage.FinishReason() == "" {
498 // Kujtim: could not track down where this is happening but this means its cancelled
499 agentMessage.AddFinish(message.FinishReasonCanceled, "Request cancelled", "")
500 _ = a.messages.Update(context.Background(), agentMessage)
501 return a.err(ErrRequestCancelled)
502 }
503 return AgentEvent{
504 Type: AgentEventTypeResponse,
505 Message: agentMessage,
506 Done: true,
507 }
508 }
509}
510
511func (a *agent) createUserMessage(ctx context.Context, sessionID, content string, attachmentParts []message.ContentPart) (message.Message, error) {
512 parts := []message.ContentPart{message.TextContent{Text: content}}
513 parts = append(parts, attachmentParts...)
514 return a.messages.Create(ctx, sessionID, message.CreateMessageParams{
515 Role: message.User,
516 Parts: parts,
517 })
518}
519
520func (a *agent) streamAndHandleEvents(ctx context.Context, sessionID string, msgHistory []message.Message) (message.Message, *message.Message, error) {
521 ctx = context.WithValue(ctx, tools.SessionIDContextKey, sessionID)
522
523 // Create the assistant message first so the spinner shows immediately
524 assistantMsg, err := a.messages.Create(ctx, sessionID, message.CreateMessageParams{
525 Role: message.Assistant,
526 Parts: []message.ContentPart{},
527 Model: a.Model().ID,
528 Provider: a.providerID,
529 })
530 if err != nil {
531 return assistantMsg, nil, fmt.Errorf("failed to create assistant message: %w", err)
532 }
533
534 // Now collect tools (which may block on MCP initialization)
535 eventChan := a.provider.StreamResponse(ctx, msgHistory, a.allTools())
536
537 // Add the session and message ID into the context if needed by tools.
538 ctx = context.WithValue(ctx, tools.MessageIDContextKey, assistantMsg.ID)
539
540 // Process each event in the stream.
541 for event := range eventChan {
542 if processErr := a.processEvent(ctx, sessionID, &assistantMsg, event); processErr != nil {
543 if errors.Is(processErr, context.Canceled) {
544 a.finishMessage(context.Background(), &assistantMsg, message.FinishReasonCanceled, "Request cancelled", "")
545 } else {
546 a.finishMessage(ctx, &assistantMsg, message.FinishReasonError, "API Error", processErr.Error())
547 }
548 return assistantMsg, nil, processErr
549 }
550 if ctx.Err() != nil {
551 a.finishMessage(context.Background(), &assistantMsg, message.FinishReasonCanceled, "Request cancelled", "")
552 return assistantMsg, nil, ctx.Err()
553 }
554 }
555
556 toolResults := make([]message.ToolResult, len(assistantMsg.ToolCalls()))
557 toolCalls := assistantMsg.ToolCalls()
558 for i, toolCall := range toolCalls {
559 select {
560 case <-ctx.Done():
561 a.finishMessage(context.Background(), &assistantMsg, message.FinishReasonCanceled, "Request cancelled", "")
562 // Make all future tool calls cancelled
563 for j := i; j < len(toolCalls); j++ {
564 toolResults[j] = message.ToolResult{
565 ToolCallID: toolCalls[j].ID,
566 Content: "Tool execution canceled by user",
567 IsError: true,
568 }
569 }
570 goto out
571 default:
572 // Continue processing
573 tool, ok := a.getToolByName(toolCall.Name)
574
575 // Tool not found
576 if !ok {
577 toolResults[i] = message.ToolResult{
578 ToolCallID: toolCall.ID,
579 Content: fmt.Sprintf("Tool not found: %s", toolCall.Name),
580 IsError: true,
581 }
582 continue
583 }
584
585 // Run tool in goroutine to allow cancellation
586 type toolExecResult struct {
587 response tools.ToolResponse
588 err error
589 }
590 resultChan := make(chan toolExecResult, 1)
591
592 go func() {
593 response, err := tool.Run(ctx, tools.ToolCall{
594 ID: toolCall.ID,
595 Name: toolCall.Name,
596 Input: toolCall.Input,
597 })
598 resultChan <- toolExecResult{response: response, err: err}
599 }()
600
601 var toolResponse tools.ToolResponse
602 var toolErr error
603
604 select {
605 case <-ctx.Done():
606 a.finishMessage(context.Background(), &assistantMsg, message.FinishReasonCanceled, "Request cancelled", "")
607 // Mark remaining tool calls as cancelled
608 for j := i; j < len(toolCalls); j++ {
609 toolResults[j] = message.ToolResult{
610 ToolCallID: toolCalls[j].ID,
611 Content: "Tool execution canceled by user",
612 IsError: true,
613 }
614 }
615 goto out
616 case result := <-resultChan:
617 toolResponse = result.response
618 toolErr = result.err
619 }
620
621 if toolErr != nil {
622 slog.Error("Tool execution error", "toolCall", toolCall.ID, "error", toolErr)
623 if errors.Is(toolErr, permission.ErrorPermissionDenied) {
624 toolResults[i] = message.ToolResult{
625 ToolCallID: toolCall.ID,
626 Content: "Permission denied",
627 IsError: true,
628 }
629 for j := i + 1; j < len(toolCalls); j++ {
630 toolResults[j] = message.ToolResult{
631 ToolCallID: toolCalls[j].ID,
632 Content: "Tool execution canceled by user",
633 IsError: true,
634 }
635 }
636 a.finishMessage(ctx, &assistantMsg, message.FinishReasonPermissionDenied, "Permission denied", "")
637 break
638 }
639 }
640 toolResults[i] = message.ToolResult{
641 ToolCallID: toolCall.ID,
642 Content: toolResponse.Content,
643 Metadata: toolResponse.Metadata,
644 IsError: toolResponse.IsError,
645 }
646 }
647 }
648out:
649 if len(toolResults) == 0 {
650 return assistantMsg, nil, nil
651 }
652 parts := make([]message.ContentPart, 0)
653 for _, tr := range toolResults {
654 parts = append(parts, tr)
655 }
656 msg, err := a.messages.Create(context.Background(), assistantMsg.SessionID, message.CreateMessageParams{
657 Role: message.Tool,
658 Parts: parts,
659 Provider: a.providerID,
660 })
661 if err != nil {
662 return assistantMsg, nil, fmt.Errorf("failed to create cancelled tool message: %w", err)
663 }
664
665 return assistantMsg, &msg, err
666}
667
668func (a *agent) finishMessage(ctx context.Context, msg *message.Message, finishReason message.FinishReason, message, details string) {
669 msg.AddFinish(finishReason, message, details)
670 _ = a.messages.Update(ctx, *msg)
671}
672
673func (a *agent) processEvent(ctx context.Context, sessionID string, assistantMsg *message.Message, event provider.ProviderEvent) error {
674 select {
675 case <-ctx.Done():
676 return ctx.Err()
677 default:
678 // Continue processing.
679 }
680
681 switch event.Type {
682 case provider.EventThinkingDelta:
683 assistantMsg.AppendReasoningContent(event.Thinking)
684 return a.messages.Update(ctx, *assistantMsg)
685 case provider.EventSignatureDelta:
686 assistantMsg.AppendReasoningSignature(event.Signature)
687 return a.messages.Update(ctx, *assistantMsg)
688 case provider.EventContentDelta:
689 assistantMsg.FinishThinking()
690 assistantMsg.AppendContent(event.Content)
691 return a.messages.Update(ctx, *assistantMsg)
692 case provider.EventToolUseStart:
693 assistantMsg.FinishThinking()
694 slog.Info("Tool call started", "toolCall", event.ToolCall)
695 assistantMsg.AddToolCall(*event.ToolCall)
696 return a.messages.Update(ctx, *assistantMsg)
697 case provider.EventToolUseDelta:
698 assistantMsg.AppendToolCallInput(event.ToolCall.ID, event.ToolCall.Input)
699 return a.messages.Update(ctx, *assistantMsg)
700 case provider.EventToolUseStop:
701 slog.Info("Finished tool call", "toolCall", event.ToolCall)
702 assistantMsg.FinishToolCall(event.ToolCall.ID)
703 return a.messages.Update(ctx, *assistantMsg)
704 case provider.EventError:
705 return event.Error
706 case provider.EventComplete:
707 assistantMsg.FinishThinking()
708 assistantMsg.SetToolCalls(event.Response.ToolCalls)
709 assistantMsg.AddFinish(event.Response.FinishReason, "", "")
710 if err := a.messages.Update(ctx, *assistantMsg); err != nil {
711 return fmt.Errorf("failed to update message: %w", err)
712 }
713 return a.TrackUsage(ctx, sessionID, a.Model(), event.Response.Usage)
714 }
715
716 return nil
717}
718
719func (a *agent) TrackUsage(ctx context.Context, sessionID string, model catwalk.Model, usage provider.TokenUsage) error {
720 sess, err := a.sessions.Get(ctx, sessionID)
721 if err != nil {
722 return fmt.Errorf("failed to get session: %w", err)
723 }
724
725 cost := model.CostPer1MInCached/1e6*float64(usage.CacheCreationTokens) +
726 model.CostPer1MOutCached/1e6*float64(usage.CacheReadTokens) +
727 model.CostPer1MIn/1e6*float64(usage.InputTokens) +
728 model.CostPer1MOut/1e6*float64(usage.OutputTokens)
729
730 sess.Cost += cost
731 sess.CompletionTokens = usage.OutputTokens + usage.CacheReadTokens
732 sess.PromptTokens = usage.InputTokens + usage.CacheCreationTokens
733
734 _, err = a.sessions.Save(ctx, sess)
735 if err != nil {
736 return fmt.Errorf("failed to save session: %w", err)
737 }
738 return nil
739}
740
741func (a *agent) Summarize(ctx context.Context, sessionID string) error {
742 if a.summarizeProvider == nil {
743 return fmt.Errorf("summarize provider not available")
744 }
745
746 // Check if session is busy
747 if a.IsSessionBusy(sessionID) {
748 return ErrSessionBusy
749 }
750
751 // Create a new context with cancellation
752 summarizeCtx, cancel := context.WithCancel(ctx)
753
754 // Store the cancel function in activeRequests to allow cancellation
755 a.activeRequests.Set(sessionID+"-summarize", cancel)
756
757 go func() {
758 defer a.activeRequests.Del(sessionID + "-summarize")
759 defer cancel()
760 event := AgentEvent{
761 Type: AgentEventTypeSummarize,
762 Progress: "Starting summarization...",
763 }
764
765 a.Publish(pubsub.CreatedEvent, event)
766 // Get all messages from the session
767 msgs, err := a.messages.List(summarizeCtx, sessionID)
768 if err != nil {
769 event = AgentEvent{
770 Type: AgentEventTypeError,
771 Error: fmt.Errorf("failed to list messages: %w", err),
772 Done: true,
773 }
774 a.Publish(pubsub.CreatedEvent, event)
775 return
776 }
777 summarizeCtx = context.WithValue(summarizeCtx, tools.SessionIDContextKey, sessionID)
778
779 if len(msgs) == 0 {
780 event = AgentEvent{
781 Type: AgentEventTypeError,
782 Error: fmt.Errorf("no messages to summarize"),
783 Done: true,
784 }
785 a.Publish(pubsub.CreatedEvent, event)
786 return
787 }
788
789 event = AgentEvent{
790 Type: AgentEventTypeSummarize,
791 Progress: "Analyzing conversation...",
792 }
793 a.Publish(pubsub.CreatedEvent, event)
794
795 // Add a system message to guide the summarization
796 summarizePrompt := "Provide a detailed but concise summary of our conversation above. Focus on information that would be helpful for continuing the conversation, including what we did, what we're doing, which files we're working on, and what we're going to do next."
797
798 // Create a new message with the summarize prompt
799 promptMsg := message.Message{
800 Role: message.User,
801 Parts: []message.ContentPart{message.TextContent{Text: summarizePrompt}},
802 }
803
804 // Append the prompt to the messages
805 msgsWithPrompt := append(msgs, promptMsg)
806
807 event = AgentEvent{
808 Type: AgentEventTypeSummarize,
809 Progress: "Generating summary...",
810 }
811
812 a.Publish(pubsub.CreatedEvent, event)
813
814 // Send the messages to the summarize provider
815 response := a.summarizeProvider.StreamResponse(
816 summarizeCtx,
817 msgsWithPrompt,
818 nil,
819 )
820 var finalResponse *provider.ProviderResponse
821 for r := range response {
822 if r.Error != nil {
823 event = AgentEvent{
824 Type: AgentEventTypeError,
825 Error: fmt.Errorf("failed to summarize: %w", err),
826 Done: true,
827 }
828 a.Publish(pubsub.CreatedEvent, event)
829 return
830 }
831 finalResponse = r.Response
832 }
833
834 summary := strings.TrimSpace(finalResponse.Content)
835 if summary == "" {
836 event = AgentEvent{
837 Type: AgentEventTypeError,
838 Error: fmt.Errorf("empty summary returned"),
839 Done: true,
840 }
841 a.Publish(pubsub.CreatedEvent, event)
842 return
843 }
844 shell := shell.GetPersistentShell(config.Get().WorkingDir())
845 summary += "\n\n**Current working directory of the persistent shell**\n\n" + shell.GetWorkingDir()
846 event = AgentEvent{
847 Type: AgentEventTypeSummarize,
848 Progress: "Creating new session...",
849 }
850
851 a.Publish(pubsub.CreatedEvent, event)
852 oldSession, err := a.sessions.Get(summarizeCtx, sessionID)
853 if err != nil {
854 event = AgentEvent{
855 Type: AgentEventTypeError,
856 Error: fmt.Errorf("failed to get session: %w", err),
857 Done: true,
858 }
859
860 a.Publish(pubsub.CreatedEvent, event)
861 return
862 }
863 // Create a message in the new session with the summary
864 msg, err := a.messages.Create(summarizeCtx, oldSession.ID, message.CreateMessageParams{
865 Role: message.Assistant,
866 Parts: []message.ContentPart{
867 message.TextContent{Text: summary},
868 message.Finish{
869 Reason: message.FinishReasonEndTurn,
870 Time: time.Now().Unix(),
871 },
872 },
873 Model: a.summarizeProvider.Model().ID,
874 Provider: a.summarizeProviderID,
875 })
876 if err != nil {
877 event = AgentEvent{
878 Type: AgentEventTypeError,
879 Error: fmt.Errorf("failed to create summary message: %w", err),
880 Done: true,
881 }
882
883 a.Publish(pubsub.CreatedEvent, event)
884 return
885 }
886 oldSession.SummaryMessageID = msg.ID
887 oldSession.CompletionTokens = finalResponse.Usage.OutputTokens
888 oldSession.PromptTokens = 0
889 model := a.summarizeProvider.Model()
890 usage := finalResponse.Usage
891 cost := model.CostPer1MInCached/1e6*float64(usage.CacheCreationTokens) +
892 model.CostPer1MOutCached/1e6*float64(usage.CacheReadTokens) +
893 model.CostPer1MIn/1e6*float64(usage.InputTokens) +
894 model.CostPer1MOut/1e6*float64(usage.OutputTokens)
895 oldSession.Cost += cost
896 _, err = a.sessions.Save(summarizeCtx, oldSession)
897 if err != nil {
898 event = AgentEvent{
899 Type: AgentEventTypeError,
900 Error: fmt.Errorf("failed to save session: %w", err),
901 Done: true,
902 }
903 a.Publish(pubsub.CreatedEvent, event)
904 }
905
906 event = AgentEvent{
907 Type: AgentEventTypeSummarize,
908 SessionID: oldSession.ID,
909 Progress: "Summary complete",
910 Done: true,
911 }
912 a.Publish(pubsub.CreatedEvent, event)
913 // Send final success event with the new session ID
914 }()
915
916 return nil
917}
918
919func (a *agent) ClearQueue(sessionID string) {
920 if a.QueuedPrompts(sessionID) > 0 {
921 slog.Info("Clearing queued prompts", "session_id", sessionID)
922 a.promptQueue.Del(sessionID)
923 }
924}
925
926func (a *agent) CancelAll() {
927 if !a.IsBusy() {
928 return
929 }
930 for key := range a.activeRequests.Seq2() {
931 a.Cancel(key) // key is sessionID
932 }
933
934 for _, cleanup := range a.cleanupFuncs {
935 if cleanup != nil {
936 cleanup()
937 }
938 }
939
940 timeout := time.After(5 * time.Second)
941 for a.IsBusy() {
942 select {
943 case <-timeout:
944 return
945 default:
946 time.Sleep(200 * time.Millisecond)
947 }
948 }
949}
950
951func (a *agent) UpdateModel() error {
952 cfg := config.Get()
953
954 // Get current provider configuration
955 currentProviderCfg := cfg.GetProviderForModel(a.agentCfg.Model)
956 if currentProviderCfg == nil || currentProviderCfg.ID == "" {
957 return fmt.Errorf("provider for agent %s not found in config", a.agentCfg.Name)
958 }
959
960 // Check if provider has changed
961 if string(currentProviderCfg.ID) != a.providerID {
962 // Provider changed, need to recreate the main provider
963 model := cfg.GetModelByType(a.agentCfg.Model)
964 if model.ID == "" {
965 return fmt.Errorf("model not found for agent %s", a.agentCfg.Name)
966 }
967
968 promptID := agentPromptMap[a.agentCfg.ID]
969 if promptID == "" {
970 promptID = prompt.PromptDefault
971 }
972
973 opts := []provider.ProviderClientOption{
974 provider.WithModel(a.agentCfg.Model),
975 provider.WithSystemMessage(prompt.GetPrompt(promptID, currentProviderCfg.ID, cfg.Options.ContextPaths...)),
976 }
977
978 newProvider, err := provider.NewProvider(*currentProviderCfg, opts...)
979 if err != nil {
980 return fmt.Errorf("failed to create new provider: %w", err)
981 }
982
983 // Update the provider and provider ID
984 a.provider = newProvider
985 a.providerID = string(currentProviderCfg.ID)
986 }
987
988 // Check if providers have changed for title (small) and summarize (large)
989 smallModelCfg := cfg.Models[config.SelectedModelTypeSmall]
990 var smallModelProviderCfg config.ProviderConfig
991 for p := range cfg.Providers.Seq() {
992 if p.ID == smallModelCfg.Provider {
993 smallModelProviderCfg = p
994 break
995 }
996 }
997 if smallModelProviderCfg.ID == "" {
998 return fmt.Errorf("provider %s not found in config", smallModelCfg.Provider)
999 }
1000
1001 largeModelCfg := cfg.Models[config.SelectedModelTypeLarge]
1002 var largeModelProviderCfg config.ProviderConfig
1003 for p := range cfg.Providers.Seq() {
1004 if p.ID == largeModelCfg.Provider {
1005 largeModelProviderCfg = p
1006 break
1007 }
1008 }
1009 if largeModelProviderCfg.ID == "" {
1010 return fmt.Errorf("provider %s not found in config", largeModelCfg.Provider)
1011 }
1012
1013 // Recreate title provider
1014 titleOpts := []provider.ProviderClientOption{
1015 provider.WithModel(config.SelectedModelTypeSmall),
1016 provider.WithSystemMessage(prompt.GetPrompt(prompt.PromptTitle, smallModelProviderCfg.ID)),
1017 provider.WithMaxTokens(40),
1018 }
1019 newTitleProvider, err := provider.NewProvider(smallModelProviderCfg, titleOpts...)
1020 if err != nil {
1021 return fmt.Errorf("failed to create new title provider: %w", err)
1022 }
1023 a.titleProvider = newTitleProvider
1024
1025 // Recreate summarize provider if provider changed (now large model)
1026 if string(largeModelProviderCfg.ID) != a.summarizeProviderID {
1027 largeModel := cfg.GetModelByType(config.SelectedModelTypeLarge)
1028 if largeModel == nil {
1029 return fmt.Errorf("model %s not found in provider %s", largeModelCfg.Model, largeModelProviderCfg.ID)
1030 }
1031 summarizeOpts := []provider.ProviderClientOption{
1032 provider.WithModel(config.SelectedModelTypeLarge),
1033 provider.WithSystemMessage(prompt.GetPrompt(prompt.PromptSummarizer, largeModelProviderCfg.ID)),
1034 }
1035 newSummarizeProvider, err := provider.NewProvider(largeModelProviderCfg, summarizeOpts...)
1036 if err != nil {
1037 return fmt.Errorf("failed to create new summarize provider: %w", err)
1038 }
1039 a.summarizeProvider = newSummarizeProvider
1040 a.summarizeProviderID = string(largeModelProviderCfg.ID)
1041 }
1042
1043 return nil
1044}
1045
1046func (a *agent) allTools() []tools.BaseTool {
1047 result := slices.Collect(a.baseTools.Seq())
1048 result = slices.AppendSeq(result, a.mcpTools.Seq())
1049 return result
1050}
1051
1052func (a *agent) getToolByName(name string) (tools.BaseTool, bool) {
1053 tool, ok := a.baseTools.Get(name)
1054 if !ok {
1055 tool, ok = a.mcpTools.Get(name)
1056 }
1057 return tool, ok
1058}
1059
1060func (a *agent) setupEvents() {
1061 ctx, cancel := context.WithCancel(a.globalCtx)
1062
1063 go func() {
1064 for event := range SubscribeMCPEvents(ctx) {
1065 switch event.Payload.Type {
1066 case MCPEventToolsListChanged:
1067 name := event.Payload.Name
1068 c, ok := mcpClients.Get(name)
1069 if !ok {
1070 slog.Warn("MCP client not found for tools update", "name", name)
1071 continue
1072 }
1073 tools := getTools(ctx, name, c)
1074 updateMcpTools(name, tools)
1075 // Update the lazy map with the new tools
1076 a.mcpTools = csync.NewMapFrom(maps.Collect(mcpTools.Seq2()))
1077 updateMCPState(name, MCPStateConnected, nil, c, a.mcpTools.Len())
1078 default:
1079 continue
1080 }
1081 }
1082 }()
1083
1084 a.cleanupFuncs = append(a.cleanupFuncs, func() {
1085 cancel()
1086 })
1087}