1package agent
2
3import (
4 "context"
5 "errors"
6 "fmt"
7 "log/slog"
8 "slices"
9 "strings"
10 "time"
11
12 "github.com/charmbracelet/catwalk/pkg/catwalk"
13 "github.com/charmbracelet/crush/internal/config"
14 "github.com/charmbracelet/crush/internal/csync"
15 "github.com/charmbracelet/crush/internal/history"
16 "github.com/charmbracelet/crush/internal/llm/prompt"
17 "github.com/charmbracelet/crush/internal/llm/provider"
18 "github.com/charmbracelet/crush/internal/llm/tools"
19 "github.com/charmbracelet/crush/internal/log"
20 "github.com/charmbracelet/crush/internal/lsp"
21 "github.com/charmbracelet/crush/internal/message"
22 "github.com/charmbracelet/crush/internal/permission"
23 "github.com/charmbracelet/crush/internal/pubsub"
24 "github.com/charmbracelet/crush/internal/session"
25 "github.com/charmbracelet/crush/internal/shell"
26)
27
// ErrRequestCancelled is reported (wrapped in an error AgentEvent) when a
// request is canceled before it completes.
var ErrRequestCancelled = errors.New("request canceled by user")

// ErrSessionBusy is returned when an operation needs a session that already
// has an active request, for example by Summarize.
var ErrSessionBusy = errors.New("session is currently processing another request")
33
// AgentEventType identifies what an AgentEvent reports.
type AgentEventType string

// The kinds of AgentEvent emitted by the agent.
const (
	// AgentEventTypeResponse carries a finished assistant message.
	AgentEventTypeResponse AgentEventType = "response"
	// AgentEventTypeError carries a failure in AgentEvent.Error.
	AgentEventTypeError AgentEventType = "error"
	// AgentEventTypeSummarize reports progress or completion of a summarization.
	AgentEventTypeSummarize AgentEventType = "summarize"
)
41
// AgentEvent reports the outcome or progress of agent work. Events are
// published to subscribers and, for Run, also sent on the returned channel.
type AgentEvent struct {
	// Type selects which of the remaining fields are meaningful.
	Type AgentEventType
	// Message is the final assistant message of a response event.
	Message message.Message
	// Error is set on AgentEventTypeError events.
	Error error

	// When summarizing
	// SessionID identifies the summarized session on the final summarize
	// event, Progress is a human-readable status line, and Done marks the
	// last event of an operation (it is also set on final response events).
	SessionID string
	Progress  string
	Done      bool
}
52
53type Service interface {
54 pubsub.Suscriber[AgentEvent]
55 Model() catwalk.Model
56 Run(ctx context.Context, sessionID string, content string, attachments ...message.Attachment) (<-chan AgentEvent, error)
57 Cancel(sessionID string)
58 CancelAll()
59 IsSessionBusy(sessionID string) bool
60 IsBusy() bool
61 Summarize(ctx context.Context, sessionID string) error
62 UpdateModel() error
63 QueuedPrompts(sessionID string) int
64 ClearQueue(sessionID string)
65}
66
// agent is the Service implementation returned by NewAgent.
type agent struct {
	*pubsub.Broker[AgentEvent]
	agentCfg config.Agent
	sessions session.Service
	messages message.Service
	// NOTE(review): not assigned by NewAgent; tool setup there uses the
	// package-level mcpTools instead. Confirm whether this field is still used.
	mcpTools []McpTool

	// tools available to the model, keyed by tool name; built once in NewAgent.
	tools *csync.Map[string, tools.BaseTool]

	// provider serves the agent's main conversation; providerID is the ID of
	// the provider config it was built from.
	provider   provider.Provider
	providerID string

	// titleProvider generates session titles (small model type);
	// summarizeProvider produces summaries for Summarize and
	// summarizeProviderID records the provider config it was built from.
	titleProvider       provider.Provider
	summarizeProvider   provider.Provider
	summarizeProviderID string

	// activeRequests maps a session ID (or sessionID+"-summarize" for a
	// summarization) to the cancel func of its running work.
	activeRequests *csync.Map[string, context.CancelFunc]

	// promptQueue holds, per session, prompts submitted while it was busy.
	promptQueue *csync.Map[string, []string]
}
87
88var agentPromptMap = map[string]prompt.PromptID{
89 "coder": prompt.PromptCoder,
90 "task": prompt.PromptTask,
91}
92
93func NewAgent(
94 ctx context.Context,
95 agentCfg config.Agent,
96 // These services are needed in the tools
97 permissions permission.Service,
98 sessions session.Service,
99 messages message.Service,
100 history history.Service,
101 lspClients map[string]*lsp.Client,
102) (Service, error) {
103 cfg := config.Get()
104
105 var agentTool tools.BaseTool
106 if agentCfg.ID == "coder" {
107 taskAgentCfg := config.Get().Agents["task"]
108 if taskAgentCfg.ID == "" {
109 return nil, fmt.Errorf("task agent not found in config")
110 }
111 taskAgent, err := NewAgent(ctx, taskAgentCfg, permissions, sessions, messages, history, lspClients)
112 if err != nil {
113 return nil, fmt.Errorf("failed to create task agent: %w", err)
114 }
115
116 agentTool = NewAgentTool(taskAgent, sessions, messages)
117 }
118
119 providerCfg := config.Get().GetProviderForModel(agentCfg.Model)
120 if providerCfg == nil {
121 return nil, fmt.Errorf("provider for agent %s not found in config", agentCfg.Name)
122 }
123 model := config.Get().GetModelByType(agentCfg.Model)
124
125 if model == nil {
126 return nil, fmt.Errorf("model not found for agent %s", agentCfg.Name)
127 }
128
129 promptID := agentPromptMap[agentCfg.ID]
130 if promptID == "" {
131 promptID = prompt.PromptDefault
132 }
133 opts := []provider.ProviderClientOption{
134 provider.WithModel(agentCfg.Model),
135 provider.WithSystemMessage(prompt.GetPrompt(promptID, providerCfg.ID, config.Get().Options.ContextPaths...)),
136 }
137 agentProvider, err := provider.NewProvider(*providerCfg, opts...)
138 if err != nil {
139 return nil, err
140 }
141
142 smallModelCfg := cfg.Models[config.SelectedModelTypeSmall]
143 var smallModelProviderCfg *config.ProviderConfig
144 if smallModelCfg.Provider == providerCfg.ID {
145 smallModelProviderCfg = providerCfg
146 } else {
147 smallModelProviderCfg = cfg.GetProviderForModel(config.SelectedModelTypeSmall)
148
149 if smallModelProviderCfg.ID == "" {
150 return nil, fmt.Errorf("provider %s not found in config", smallModelCfg.Provider)
151 }
152 }
153 smallModel := cfg.GetModelByType(config.SelectedModelTypeSmall)
154 if smallModel.ID == "" {
155 return nil, fmt.Errorf("model %s not found in provider %s", smallModelCfg.Model, smallModelProviderCfg.ID)
156 }
157
158 titleOpts := []provider.ProviderClientOption{
159 provider.WithModel(config.SelectedModelTypeSmall),
160 provider.WithSystemMessage(prompt.GetPrompt(prompt.PromptTitle, smallModelProviderCfg.ID)),
161 }
162 titleProvider, err := provider.NewProvider(*smallModelProviderCfg, titleOpts...)
163 if err != nil {
164 return nil, err
165 }
166
167 summarizeOpts := []provider.ProviderClientOption{
168 provider.WithModel(config.SelectedModelTypeLarge),
169 provider.WithSystemMessage(prompt.GetPrompt(prompt.PromptSummarizer, providerCfg.ID)),
170 }
171 summarizeProvider, err := provider.NewProvider(*providerCfg, summarizeOpts...)
172 if err != nil {
173 return nil, err
174 }
175
176 toolFn := func() *csync.Map[string, tools.BaseTool] {
177 slog.Info("Initializing agent tools", "agent", agentCfg.ID)
178 defer func() {
179 slog.Info("Initialized agent tools", "agent", agentCfg.ID)
180 }()
181
182 cwd := cfg.WorkingDir()
183 toolMap := csync.NewMap[string, tools.BaseTool]()
184
185 // Base tools available to all agents
186 baseTools := []tools.BaseTool{
187 tools.NewBashTool(permissions, cwd),
188 tools.NewDownloadTool(permissions, cwd),
189 tools.NewEditTool(lspClients, permissions, history, cwd),
190 tools.NewMultiEditTool(lspClients, permissions, history, cwd),
191 tools.NewFetchTool(permissions, cwd),
192 tools.NewGlobTool(cwd),
193 tools.NewGrepTool(cwd),
194 tools.NewLsTool(permissions, cwd),
195 tools.NewSourcegraphTool(),
196 tools.NewViewTool(lspClients, permissions, cwd),
197 tools.NewWriteTool(lspClients, permissions, history, cwd),
198 }
199 for _, tool := range baseTools {
200 toolMap.Set(tool.Name(), tool)
201 }
202
203 mcpToolsOnce.Do(func() {
204 mcpTools = doGetMCPTools(ctx, permissions, cfg)
205 })
206 for _, mcpTool := range mcpTools {
207 toolMap.Set(mcpTool.Name(), mcpTool)
208 }
209
210 if len(lspClients) > 0 {
211 diagnosticsTool := tools.NewDiagnosticsTool(lspClients)
212 toolMap.Set(diagnosticsTool.Name(), diagnosticsTool)
213 }
214
215 if agentTool != nil {
216 toolMap.Set(agentTool.Name(), agentTool)
217 }
218
219 if agentCfg.AllowedTools != nil {
220 // Filter tools based on allowed tools list
221 for toolName := range toolMap.Seq2() {
222 if !slices.Contains(agentCfg.AllowedTools, toolName) {
223 toolMap.Del(toolName)
224 }
225 }
226 }
227 return toolMap
228 }
229
230 return &agent{
231 Broker: pubsub.NewBroker[AgentEvent](),
232 agentCfg: agentCfg,
233 provider: agentProvider,
234 providerID: string(providerCfg.ID),
235 messages: messages,
236 sessions: sessions,
237 titleProvider: titleProvider,
238 summarizeProvider: summarizeProvider,
239 summarizeProviderID: string(providerCfg.ID),
240 activeRequests: csync.NewMap[string, context.CancelFunc](),
241 tools: toolFn(),
242 promptQueue: csync.NewMap[string, []string](),
243 }, nil
244}
245
246func (a *agent) Model() catwalk.Model {
247 return *config.Get().GetModelByType(a.agentCfg.Model)
248}
249
250func (a *agent) Cancel(sessionID string) {
251 // Cancel regular requests
252 if cancel, ok := a.activeRequests.Take(sessionID); ok && cancel != nil {
253 slog.Info("Request cancellation initiated", "session_id", sessionID)
254 cancel()
255 }
256
257 // Also check for summarize requests
258 if cancel, ok := a.activeRequests.Take(sessionID + "-summarize"); ok && cancel != nil {
259 slog.Info("Summarize cancellation initiated", "session_id", sessionID)
260 cancel()
261 }
262
263 if a.QueuedPrompts(sessionID) > 0 {
264 slog.Info("Clearing queued prompts", "session_id", sessionID)
265 a.promptQueue.Del(sessionID)
266 }
267}
268
269func (a *agent) IsBusy() bool {
270 var busy bool
271 for cancelFunc := range a.activeRequests.Seq() {
272 if cancelFunc != nil {
273 busy = true
274 break
275 }
276 }
277 return busy
278}
279
280func (a *agent) IsSessionBusy(sessionID string) bool {
281 _, busy := a.activeRequests.Get(sessionID)
282 return busy
283}
284
285func (a *agent) QueuedPrompts(sessionID string) int {
286 l, ok := a.promptQueue.Get(sessionID)
287 if !ok {
288 return 0
289 }
290 return len(l)
291}
292
293func (a *agent) generateTitle(ctx context.Context, sessionID string, content string) error {
294 if content == "" {
295 return nil
296 }
297 if a.titleProvider == nil {
298 return nil
299 }
300 session, err := a.sessions.Get(ctx, sessionID)
301 if err != nil {
302 return err
303 }
304 parts := []message.ContentPart{message.TextContent{
305 Text: fmt.Sprintf("Generate a concise title for the following content:\n\n%s", content),
306 }}
307
308 // Use streaming approach like summarization
309 response := a.titleProvider.StreamResponse(
310 ctx,
311 []message.Message{
312 {
313 Role: message.User,
314 Parts: parts,
315 },
316 },
317 nil,
318 )
319
320 var finalResponse *provider.ProviderResponse
321 for r := range response {
322 if r.Error != nil {
323 return r.Error
324 }
325 finalResponse = r.Response
326 }
327
328 if finalResponse == nil {
329 return fmt.Errorf("no response received from title provider")
330 }
331
332 title := strings.TrimSpace(strings.ReplaceAll(finalResponse.Content, "\n", " "))
333 if title == "" {
334 return nil
335 }
336
337 session.Title = title
338 _, err = a.sessions.Save(ctx, session)
339 return err
340}
341
342func (a *agent) err(err error) AgentEvent {
343 return AgentEvent{
344 Type: AgentEventTypeError,
345 Error: err,
346 }
347}
348
349func (a *agent) Run(ctx context.Context, sessionID string, content string, attachments ...message.Attachment) (<-chan AgentEvent, error) {
350 if !a.Model().SupportsImages && attachments != nil {
351 attachments = nil
352 }
353 events := make(chan AgentEvent)
354 if a.IsSessionBusy(sessionID) {
355 existing, ok := a.promptQueue.Get(sessionID)
356 if !ok {
357 existing = []string{}
358 }
359 existing = append(existing, content)
360 a.promptQueue.Set(sessionID, existing)
361 return nil, nil
362 }
363
364 genCtx, cancel := context.WithCancel(ctx)
365
366 a.activeRequests.Set(sessionID, cancel)
367 go func() {
368 slog.Debug("Request started", "sessionID", sessionID)
369 defer log.RecoverPanic("agent.Run", func() {
370 events <- a.err(fmt.Errorf("panic while running the agent"))
371 })
372 var attachmentParts []message.ContentPart
373 for _, attachment := range attachments {
374 attachmentParts = append(attachmentParts, message.BinaryContent{Path: attachment.FilePath, MIMEType: attachment.MimeType, Data: attachment.Content})
375 }
376 result := a.processGeneration(genCtx, sessionID, content, attachmentParts)
377 if result.Error != nil && !errors.Is(result.Error, ErrRequestCancelled) && !errors.Is(result.Error, context.Canceled) {
378 slog.Error(result.Error.Error())
379 }
380 slog.Debug("Request completed", "sessionID", sessionID)
381 a.activeRequests.Del(sessionID)
382 cancel()
383 a.Publish(pubsub.CreatedEvent, result)
384 select {
385 case events <- result:
386 case <-genCtx.Done():
387 }
388 close(events)
389 }()
390 return events, nil
391}
392
393func (a *agent) processGeneration(ctx context.Context, sessionID, content string, attachmentParts []message.ContentPart) AgentEvent {
394 cfg := config.Get()
395 // List existing messages; if none, start title generation asynchronously.
396 msgs, err := a.messages.List(ctx, sessionID)
397 if err != nil {
398 return a.err(fmt.Errorf("failed to list messages: %w", err))
399 }
400 if len(msgs) == 0 {
401 go func() {
402 defer log.RecoverPanic("agent.Run", func() {
403 slog.Error("panic while generating title")
404 })
405 titleErr := a.generateTitle(context.Background(), sessionID, content)
406 if titleErr != nil && !errors.Is(titleErr, context.Canceled) && !errors.Is(titleErr, context.DeadlineExceeded) {
407 slog.Error("failed to generate title", "error", titleErr)
408 }
409 }()
410 }
411 session, err := a.sessions.Get(ctx, sessionID)
412 if err != nil {
413 return a.err(fmt.Errorf("failed to get session: %w", err))
414 }
415 if session.SummaryMessageID != "" {
416 summaryMsgInex := -1
417 for i, msg := range msgs {
418 if msg.ID == session.SummaryMessageID {
419 summaryMsgInex = i
420 break
421 }
422 }
423 if summaryMsgInex != -1 {
424 msgs = msgs[summaryMsgInex:]
425 msgs[0].Role = message.User
426 }
427 }
428
429 userMsg, err := a.createUserMessage(ctx, sessionID, content, attachmentParts)
430 if err != nil {
431 return a.err(fmt.Errorf("failed to create user message: %w", err))
432 }
433 // Append the new user message to the conversation history.
434 msgHistory := append(msgs, userMsg)
435
436 for {
437 // Check for cancellation before each iteration
438 select {
439 case <-ctx.Done():
440 return a.err(ctx.Err())
441 default:
442 // Continue processing
443 }
444 agentMessage, toolResults, err := a.streamAndHandleEvents(ctx, sessionID, msgHistory)
445 if err != nil {
446 if errors.Is(err, context.Canceled) {
447 agentMessage.AddFinish(message.FinishReasonCanceled, "Request cancelled", "")
448 a.messages.Update(context.Background(), agentMessage)
449 return a.err(ErrRequestCancelled)
450 }
451 return a.err(fmt.Errorf("failed to process events: %w", err))
452 }
453 if cfg.Options.Debug {
454 slog.Info("Result", "message", agentMessage.FinishReason(), "toolResults", toolResults)
455 }
456 if (agentMessage.FinishReason() == message.FinishReasonToolUse) && toolResults != nil {
457 // We are not done, we need to respond with the tool response
458 msgHistory = append(msgHistory, agentMessage, *toolResults)
459 // If there are queued prompts, process the next one
460 nextPrompt, ok := a.promptQueue.Take(sessionID)
461 if ok {
462 for _, prompt := range nextPrompt {
463 // Create a new user message for the queued prompt
464 userMsg, err := a.createUserMessage(ctx, sessionID, prompt, nil)
465 if err != nil {
466 return a.err(fmt.Errorf("failed to create user message for queued prompt: %w", err))
467 }
468 // Append the new user message to the conversation history
469 msgHistory = append(msgHistory, userMsg)
470 }
471 }
472
473 continue
474 } else if agentMessage.FinishReason() == message.FinishReasonEndTurn {
475 queuePrompts, ok := a.promptQueue.Take(sessionID)
476 if ok {
477 for _, prompt := range queuePrompts {
478 if prompt == "" {
479 continue
480 }
481 userMsg, err := a.createUserMessage(ctx, sessionID, prompt, nil)
482 if err != nil {
483 return a.err(fmt.Errorf("failed to create user message for queued prompt: %w", err))
484 }
485 msgHistory = append(msgHistory, userMsg)
486 }
487 continue
488 }
489 }
490 if agentMessage.FinishReason() == "" {
491 // Kujtim: could not track down where this is happening but this means its cancelled
492 agentMessage.AddFinish(message.FinishReasonCanceled, "Request cancelled", "")
493 _ = a.messages.Update(context.Background(), agentMessage)
494 return a.err(ErrRequestCancelled)
495 }
496 return AgentEvent{
497 Type: AgentEventTypeResponse,
498 Message: agentMessage,
499 Done: true,
500 }
501 }
502}
503
504func (a *agent) createUserMessage(ctx context.Context, sessionID, content string, attachmentParts []message.ContentPart) (message.Message, error) {
505 parts := []message.ContentPart{message.TextContent{Text: content}}
506 parts = append(parts, attachmentParts...)
507 return a.messages.Create(ctx, sessionID, message.CreateMessageParams{
508 Role: message.User,
509 Parts: parts,
510 })
511}
512
513func (a *agent) streamAndHandleEvents(ctx context.Context, sessionID string, msgHistory []message.Message) (message.Message, *message.Message, error) {
514 ctx = context.WithValue(ctx, tools.SessionIDContextKey, sessionID)
515
516 // Create the assistant message first so the spinner shows immediately
517 assistantMsg, err := a.messages.Create(ctx, sessionID, message.CreateMessageParams{
518 Role: message.Assistant,
519 Parts: []message.ContentPart{},
520 Model: a.Model().ID,
521 Provider: a.providerID,
522 })
523 if err != nil {
524 return assistantMsg, nil, fmt.Errorf("failed to create assistant message: %w", err)
525 }
526
527 // Now collect tools (which may block on MCP initialization)
528 eventChan := a.provider.StreamResponse(ctx, msgHistory, slices.Collect(a.tools.Seq()))
529
530 // Add the session and message ID into the context if needed by tools.
531 ctx = context.WithValue(ctx, tools.MessageIDContextKey, assistantMsg.ID)
532
533 // Process each event in the stream.
534 for event := range eventChan {
535 if processErr := a.processEvent(ctx, sessionID, &assistantMsg, event); processErr != nil {
536 if errors.Is(processErr, context.Canceled) {
537 a.finishMessage(context.Background(), &assistantMsg, message.FinishReasonCanceled, "Request cancelled", "")
538 } else {
539 a.finishMessage(ctx, &assistantMsg, message.FinishReasonError, "API Error", processErr.Error())
540 }
541 return assistantMsg, nil, processErr
542 }
543 if ctx.Err() != nil {
544 a.finishMessage(context.Background(), &assistantMsg, message.FinishReasonCanceled, "Request cancelled", "")
545 return assistantMsg, nil, ctx.Err()
546 }
547 }
548
549 toolResults := make([]message.ToolResult, len(assistantMsg.ToolCalls()))
550 toolCalls := assistantMsg.ToolCalls()
551 for i, toolCall := range toolCalls {
552 select {
553 case <-ctx.Done():
554 a.finishMessage(context.Background(), &assistantMsg, message.FinishReasonCanceled, "Request cancelled", "")
555 // Make all future tool calls cancelled
556 for j := i; j < len(toolCalls); j++ {
557 toolResults[j] = message.ToolResult{
558 ToolCallID: toolCalls[j].ID,
559 Content: "Tool execution canceled by user",
560 IsError: true,
561 }
562 }
563 goto out
564 default:
565 // Continue processing
566 tool, ok := a.tools.Get(toolCall.Name)
567
568 // Tool not found
569 if !ok {
570 toolResults[i] = message.ToolResult{
571 ToolCallID: toolCall.ID,
572 Content: fmt.Sprintf("Tool not found: %s", toolCall.Name),
573 IsError: true,
574 }
575 continue
576 }
577
578 // Run tool in goroutine to allow cancellation
579 type toolExecResult struct {
580 response tools.ToolResponse
581 err error
582 }
583 resultChan := make(chan toolExecResult, 1)
584
585 go func() {
586 response, err := tool.Run(ctx, tools.ToolCall{
587 ID: toolCall.ID,
588 Name: toolCall.Name,
589 Input: toolCall.Input,
590 })
591 resultChan <- toolExecResult{response: response, err: err}
592 }()
593
594 var toolResponse tools.ToolResponse
595 var toolErr error
596
597 select {
598 case <-ctx.Done():
599 a.finishMessage(context.Background(), &assistantMsg, message.FinishReasonCanceled, "Request cancelled", "")
600 // Mark remaining tool calls as cancelled
601 for j := i; j < len(toolCalls); j++ {
602 toolResults[j] = message.ToolResult{
603 ToolCallID: toolCalls[j].ID,
604 Content: "Tool execution canceled by user",
605 IsError: true,
606 }
607 }
608 goto out
609 case result := <-resultChan:
610 toolResponse = result.response
611 toolErr = result.err
612 }
613
614 if toolErr != nil {
615 slog.Error("Tool execution error", "toolCall", toolCall.ID, "error", toolErr)
616 if errors.Is(toolErr, permission.ErrorPermissionDenied) {
617 toolResults[i] = message.ToolResult{
618 ToolCallID: toolCall.ID,
619 Content: "Permission denied",
620 IsError: true,
621 }
622 for j := i + 1; j < len(toolCalls); j++ {
623 toolResults[j] = message.ToolResult{
624 ToolCallID: toolCalls[j].ID,
625 Content: "Tool execution canceled by user",
626 IsError: true,
627 }
628 }
629 a.finishMessage(ctx, &assistantMsg, message.FinishReasonPermissionDenied, "Permission denied", "")
630 break
631 }
632 }
633 toolResults[i] = message.ToolResult{
634 ToolCallID: toolCall.ID,
635 Content: toolResponse.Content,
636 Metadata: toolResponse.Metadata,
637 IsError: toolResponse.IsError,
638 }
639 }
640 }
641out:
642 if len(toolResults) == 0 {
643 return assistantMsg, nil, nil
644 }
645 parts := make([]message.ContentPart, 0)
646 for _, tr := range toolResults {
647 parts = append(parts, tr)
648 }
649 msg, err := a.messages.Create(context.Background(), assistantMsg.SessionID, message.CreateMessageParams{
650 Role: message.Tool,
651 Parts: parts,
652 Provider: a.providerID,
653 })
654 if err != nil {
655 return assistantMsg, nil, fmt.Errorf("failed to create cancelled tool message: %w", err)
656 }
657
658 return assistantMsg, &msg, err
659}
660
661func (a *agent) finishMessage(ctx context.Context, msg *message.Message, finishReason message.FinishReason, message, details string) {
662 msg.AddFinish(finishReason, message, details)
663 _ = a.messages.Update(ctx, *msg)
664}
665
666func (a *agent) processEvent(ctx context.Context, sessionID string, assistantMsg *message.Message, event provider.ProviderEvent) error {
667 select {
668 case <-ctx.Done():
669 return ctx.Err()
670 default:
671 // Continue processing.
672 }
673
674 switch event.Type {
675 case provider.EventThinkingDelta:
676 assistantMsg.AppendReasoningContent(event.Thinking)
677 return a.messages.Update(ctx, *assistantMsg)
678 case provider.EventSignatureDelta:
679 assistantMsg.AppendReasoningSignature(event.Signature)
680 return a.messages.Update(ctx, *assistantMsg)
681 case provider.EventContentDelta:
682 assistantMsg.FinishThinking()
683 assistantMsg.AppendContent(event.Content)
684 return a.messages.Update(ctx, *assistantMsg)
685 case provider.EventToolUseStart:
686 assistantMsg.FinishThinking()
687 slog.Info("Tool call started", "toolCall", event.ToolCall)
688 assistantMsg.AddToolCall(*event.ToolCall)
689 return a.messages.Update(ctx, *assistantMsg)
690 case provider.EventToolUseDelta:
691 assistantMsg.AppendToolCallInput(event.ToolCall.ID, event.ToolCall.Input)
692 return a.messages.Update(ctx, *assistantMsg)
693 case provider.EventToolUseStop:
694 slog.Info("Finished tool call", "toolCall", event.ToolCall)
695 assistantMsg.FinishToolCall(event.ToolCall.ID)
696 return a.messages.Update(ctx, *assistantMsg)
697 case provider.EventError:
698 return event.Error
699 case provider.EventComplete:
700 assistantMsg.FinishThinking()
701 assistantMsg.SetToolCalls(event.Response.ToolCalls)
702 assistantMsg.AddFinish(event.Response.FinishReason, "", "")
703 if err := a.messages.Update(ctx, *assistantMsg); err != nil {
704 return fmt.Errorf("failed to update message: %w", err)
705 }
706 return a.TrackUsage(ctx, sessionID, a.Model(), event.Response.Usage)
707 }
708
709 return nil
710}
711
712func (a *agent) TrackUsage(ctx context.Context, sessionID string, model catwalk.Model, usage provider.TokenUsage) error {
713 sess, err := a.sessions.Get(ctx, sessionID)
714 if err != nil {
715 return fmt.Errorf("failed to get session: %w", err)
716 }
717
718 cost := model.CostPer1MInCached/1e6*float64(usage.CacheCreationTokens) +
719 model.CostPer1MOutCached/1e6*float64(usage.CacheReadTokens) +
720 model.CostPer1MIn/1e6*float64(usage.InputTokens) +
721 model.CostPer1MOut/1e6*float64(usage.OutputTokens)
722
723 sess.Cost += cost
724 sess.CompletionTokens = usage.OutputTokens + usage.CacheReadTokens
725 sess.PromptTokens = usage.InputTokens + usage.CacheCreationTokens
726
727 _, err = a.sessions.Save(ctx, sess)
728 if err != nil {
729 return fmt.Errorf("failed to save session: %w", err)
730 }
731 return nil
732}
733
734func (a *agent) Summarize(ctx context.Context, sessionID string) error {
735 if a.summarizeProvider == nil {
736 return fmt.Errorf("summarize provider not available")
737 }
738
739 // Check if session is busy
740 if a.IsSessionBusy(sessionID) {
741 return ErrSessionBusy
742 }
743
744 // Create a new context with cancellation
745 summarizeCtx, cancel := context.WithCancel(ctx)
746
747 // Store the cancel function in activeRequests to allow cancellation
748 a.activeRequests.Set(sessionID+"-summarize", cancel)
749
750 go func() {
751 defer a.activeRequests.Del(sessionID + "-summarize")
752 defer cancel()
753 event := AgentEvent{
754 Type: AgentEventTypeSummarize,
755 Progress: "Starting summarization...",
756 }
757
758 a.Publish(pubsub.CreatedEvent, event)
759 // Get all messages from the session
760 msgs, err := a.messages.List(summarizeCtx, sessionID)
761 if err != nil {
762 event = AgentEvent{
763 Type: AgentEventTypeError,
764 Error: fmt.Errorf("failed to list messages: %w", err),
765 Done: true,
766 }
767 a.Publish(pubsub.CreatedEvent, event)
768 return
769 }
770 summarizeCtx = context.WithValue(summarizeCtx, tools.SessionIDContextKey, sessionID)
771
772 if len(msgs) == 0 {
773 event = AgentEvent{
774 Type: AgentEventTypeError,
775 Error: fmt.Errorf("no messages to summarize"),
776 Done: true,
777 }
778 a.Publish(pubsub.CreatedEvent, event)
779 return
780 }
781
782 event = AgentEvent{
783 Type: AgentEventTypeSummarize,
784 Progress: "Analyzing conversation...",
785 }
786 a.Publish(pubsub.CreatedEvent, event)
787
788 // Add a system message to guide the summarization
789 summarizePrompt := "Provide a detailed but concise summary of our conversation above. Focus on information that would be helpful for continuing the conversation, including what we did, what we're doing, which files we're working on, and what we're going to do next."
790
791 // Create a new message with the summarize prompt
792 promptMsg := message.Message{
793 Role: message.User,
794 Parts: []message.ContentPart{message.TextContent{Text: summarizePrompt}},
795 }
796
797 // Append the prompt to the messages
798 msgsWithPrompt := append(msgs, promptMsg)
799
800 event = AgentEvent{
801 Type: AgentEventTypeSummarize,
802 Progress: "Generating summary...",
803 }
804
805 a.Publish(pubsub.CreatedEvent, event)
806
807 // Send the messages to the summarize provider
808 response := a.summarizeProvider.StreamResponse(
809 summarizeCtx,
810 msgsWithPrompt,
811 nil,
812 )
813 var finalResponse *provider.ProviderResponse
814 for r := range response {
815 if r.Error != nil {
816 event = AgentEvent{
817 Type: AgentEventTypeError,
818 Error: fmt.Errorf("failed to summarize: %w", err),
819 Done: true,
820 }
821 a.Publish(pubsub.CreatedEvent, event)
822 return
823 }
824 finalResponse = r.Response
825 }
826
827 summary := strings.TrimSpace(finalResponse.Content)
828 if summary == "" {
829 event = AgentEvent{
830 Type: AgentEventTypeError,
831 Error: fmt.Errorf("empty summary returned"),
832 Done: true,
833 }
834 a.Publish(pubsub.CreatedEvent, event)
835 return
836 }
837 shell := shell.GetPersistentShell(config.Get().WorkingDir())
838 summary += "\n\n**Current working directory of the persistent shell**\n\n" + shell.GetWorkingDir()
839 event = AgentEvent{
840 Type: AgentEventTypeSummarize,
841 Progress: "Creating new session...",
842 }
843
844 a.Publish(pubsub.CreatedEvent, event)
845 oldSession, err := a.sessions.Get(summarizeCtx, sessionID)
846 if err != nil {
847 event = AgentEvent{
848 Type: AgentEventTypeError,
849 Error: fmt.Errorf("failed to get session: %w", err),
850 Done: true,
851 }
852
853 a.Publish(pubsub.CreatedEvent, event)
854 return
855 }
856 // Create a message in the new session with the summary
857 msg, err := a.messages.Create(summarizeCtx, oldSession.ID, message.CreateMessageParams{
858 Role: message.Assistant,
859 Parts: []message.ContentPart{
860 message.TextContent{Text: summary},
861 message.Finish{
862 Reason: message.FinishReasonEndTurn,
863 Time: time.Now().Unix(),
864 },
865 },
866 Model: a.summarizeProvider.Model().ID,
867 Provider: a.summarizeProviderID,
868 })
869 if err != nil {
870 event = AgentEvent{
871 Type: AgentEventTypeError,
872 Error: fmt.Errorf("failed to create summary message: %w", err),
873 Done: true,
874 }
875
876 a.Publish(pubsub.CreatedEvent, event)
877 return
878 }
879 oldSession.SummaryMessageID = msg.ID
880 oldSession.CompletionTokens = finalResponse.Usage.OutputTokens
881 oldSession.PromptTokens = 0
882 model := a.summarizeProvider.Model()
883 usage := finalResponse.Usage
884 cost := model.CostPer1MInCached/1e6*float64(usage.CacheCreationTokens) +
885 model.CostPer1MOutCached/1e6*float64(usage.CacheReadTokens) +
886 model.CostPer1MIn/1e6*float64(usage.InputTokens) +
887 model.CostPer1MOut/1e6*float64(usage.OutputTokens)
888 oldSession.Cost += cost
889 _, err = a.sessions.Save(summarizeCtx, oldSession)
890 if err != nil {
891 event = AgentEvent{
892 Type: AgentEventTypeError,
893 Error: fmt.Errorf("failed to save session: %w", err),
894 Done: true,
895 }
896 a.Publish(pubsub.CreatedEvent, event)
897 }
898
899 event = AgentEvent{
900 Type: AgentEventTypeSummarize,
901 SessionID: oldSession.ID,
902 Progress: "Summary complete",
903 Done: true,
904 }
905 a.Publish(pubsub.CreatedEvent, event)
906 // Send final success event with the new session ID
907 }()
908
909 return nil
910}
911
912func (a *agent) ClearQueue(sessionID string) {
913 if a.QueuedPrompts(sessionID) > 0 {
914 slog.Info("Clearing queued prompts", "session_id", sessionID)
915 a.promptQueue.Del(sessionID)
916 }
917}
918
919func (a *agent) CancelAll() {
920 if !a.IsBusy() {
921 return
922 }
923 for key := range a.activeRequests.Seq2() {
924 a.Cancel(key) // key is sessionID
925 }
926
927 timeout := time.After(5 * time.Second)
928 for a.IsBusy() {
929 select {
930 case <-timeout:
931 return
932 default:
933 time.Sleep(200 * time.Millisecond)
934 }
935 }
936}
937
938func (a *agent) UpdateModel() error {
939 cfg := config.Get()
940
941 // Get current provider configuration
942 currentProviderCfg := cfg.GetProviderForModel(a.agentCfg.Model)
943 if currentProviderCfg == nil || currentProviderCfg.ID == "" {
944 return fmt.Errorf("provider for agent %s not found in config", a.agentCfg.Name)
945 }
946
947 // Check if provider has changed
948 if string(currentProviderCfg.ID) != a.providerID {
949 // Provider changed, need to recreate the main provider
950 model := cfg.GetModelByType(a.agentCfg.Model)
951 if model.ID == "" {
952 return fmt.Errorf("model not found for agent %s", a.agentCfg.Name)
953 }
954
955 promptID := agentPromptMap[a.agentCfg.ID]
956 if promptID == "" {
957 promptID = prompt.PromptDefault
958 }
959
960 opts := []provider.ProviderClientOption{
961 provider.WithModel(a.agentCfg.Model),
962 provider.WithSystemMessage(prompt.GetPrompt(promptID, currentProviderCfg.ID, cfg.Options.ContextPaths...)),
963 }
964
965 newProvider, err := provider.NewProvider(*currentProviderCfg, opts...)
966 if err != nil {
967 return fmt.Errorf("failed to create new provider: %w", err)
968 }
969
970 // Update the provider and provider ID
971 a.provider = newProvider
972 a.providerID = string(currentProviderCfg.ID)
973 }
974
975 // Check if providers have changed for title (small) and summarize (large)
976 smallModelCfg := cfg.Models[config.SelectedModelTypeSmall]
977 var smallModelProviderCfg config.ProviderConfig
978 for p := range cfg.Providers.Seq() {
979 if p.ID == smallModelCfg.Provider {
980 smallModelProviderCfg = p
981 break
982 }
983 }
984 if smallModelProviderCfg.ID == "" {
985 return fmt.Errorf("provider %s not found in config", smallModelCfg.Provider)
986 }
987
988 largeModelCfg := cfg.Models[config.SelectedModelTypeLarge]
989 var largeModelProviderCfg config.ProviderConfig
990 for p := range cfg.Providers.Seq() {
991 if p.ID == largeModelCfg.Provider {
992 largeModelProviderCfg = p
993 break
994 }
995 }
996 if largeModelProviderCfg.ID == "" {
997 return fmt.Errorf("provider %s not found in config", largeModelCfg.Provider)
998 }
999
1000 // Recreate title provider
1001 titleOpts := []provider.ProviderClientOption{
1002 provider.WithModel(config.SelectedModelTypeSmall),
1003 provider.WithSystemMessage(prompt.GetPrompt(prompt.PromptTitle, smallModelProviderCfg.ID)),
1004 provider.WithMaxTokens(40),
1005 }
1006 newTitleProvider, err := provider.NewProvider(smallModelProviderCfg, titleOpts...)
1007 if err != nil {
1008 return fmt.Errorf("failed to create new title provider: %w", err)
1009 }
1010 a.titleProvider = newTitleProvider
1011
1012 // Recreate summarize provider if provider changed (now large model)
1013 if string(largeModelProviderCfg.ID) != a.summarizeProviderID {
1014 largeModel := cfg.GetModelByType(config.SelectedModelTypeLarge)
1015 if largeModel == nil {
1016 return fmt.Errorf("model %s not found in provider %s", largeModelCfg.Model, largeModelProviderCfg.ID)
1017 }
1018 summarizeOpts := []provider.ProviderClientOption{
1019 provider.WithModel(config.SelectedModelTypeLarge),
1020 provider.WithSystemMessage(prompt.GetPrompt(prompt.PromptSummarizer, largeModelProviderCfg.ID)),
1021 }
1022 newSummarizeProvider, err := provider.NewProvider(largeModelProviderCfg, summarizeOpts...)
1023 if err != nil {
1024 return fmt.Errorf("failed to create new summarize provider: %w", err)
1025 }
1026 a.summarizeProvider = newSummarizeProvider
1027 a.summarizeProviderID = string(largeModelProviderCfg.ID)
1028 }
1029
1030 return nil
1031}