1package agent
2
3import (
4 "context"
5 "errors"
6 "fmt"
7 "log/slog"
8 "slices"
9 "strings"
10 "time"
11
12 "github.com/charmbracelet/catwalk/pkg/catwalk"
13 "github.com/charmbracelet/crush/internal/config"
14 "github.com/charmbracelet/crush/internal/csync"
15 "github.com/charmbracelet/crush/internal/history"
16 "github.com/charmbracelet/crush/internal/llm/prompt"
17 "github.com/charmbracelet/crush/internal/llm/provider"
18 "github.com/charmbracelet/crush/internal/llm/tools"
19 "github.com/charmbracelet/crush/internal/log"
20 "github.com/charmbracelet/crush/internal/lsp"
21 "github.com/charmbracelet/crush/internal/message"
22 "github.com/charmbracelet/crush/internal/permission"
23 "github.com/charmbracelet/crush/internal/pubsub"
24 "github.com/charmbracelet/crush/internal/session"
25 "github.com/charmbracelet/crush/internal/shell"
26)
27
// Sentinel errors reported by the agent.
var (
	// ErrRequestCancelled is reported when a generation ends because the
	// request was cancelled before the model finished its turn.
	ErrRequestCancelled = errors.New("request canceled by user")
	// ErrSessionBusy is returned by Summarize when the session already has an
	// active request registered.
	ErrSessionBusy = errors.New("session is currently processing another request")
)
33
// AgentEventType identifies the kind of payload carried by an AgentEvent.
type AgentEventType string

const (
	// AgentEventTypeError marks an event whose Error field describes a failure.
	AgentEventTypeError AgentEventType = "error"
	// AgentEventTypeResponse marks an event carrying the assistant message that
	// ended a generation.
	AgentEventTypeResponse AgentEventType = "response"
	// AgentEventTypeSummarize marks a progress or completion event published
	// while a session is being summarized.
	AgentEventTypeSummarize AgentEventType = "summarize"
)
41
// AgentEvent is the value published through the agent's broker and sent on the
// channel returned by Run.
type AgentEvent struct {
	// Type selects which of the fields below are meaningful.
	Type AgentEventType
	// Message is the assistant message for response events.
	Message message.Message
	// Error is the failure for error events.
	Error error

	// When summarizing
	// SessionID is set on the final summarize event.
	SessionID string
	// Progress is a human-readable status line.
	Progress string
	// Done is true on the last event of an operation.
	Done bool
}
52
// Service is the public surface of an agent: it runs prompts against a model,
// reports progress as AgentEvents and lets callers cancel or queue work per
// session.
type Service interface {
	// Subscription to the AgentEvents the agent publishes.
	pubsub.Suscriber[AgentEvent]
	// Model returns the model currently configured for this agent.
	Model() catwalk.Model
	// Run starts processing content for the session. When the session is
	// already busy the prompt is queued instead and both results are nil.
	Run(ctx context.Context, sessionID string, content string, attachments ...message.Attachment) (<-chan AgentEvent, error)
	// Cancel stops the active request and summarization of the session and
	// drops its queued prompts.
	Cancel(sessionID string)
	// CancelAll cancels every active request and waits briefly for them to end.
	CancelAll()
	// IsSessionBusy reports whether the session has an active request.
	IsSessionBusy(sessionID string) bool
	// IsBusy reports whether any request is active.
	IsBusy() bool
	// Summarize starts summarizing the session in the background; progress and
	// failures are delivered as published events.
	Summarize(ctx context.Context, sessionID string) error
	// UpdateModel rebuilds the agent's providers from the current config.
	UpdateModel() error
	// QueuedPrompts returns how many prompts are waiting for the session.
	QueuedPrompts(sessionID string) int
	// ClearQueue drops the prompts waiting for the session.
	ClearQueue(sessionID string)
}
66
// agent is the Service implementation returned by NewAgent.
type agent struct {
	// Broker publishes AgentEvents to subscribers.
	*pubsub.Broker[AgentEvent]
	// agentCfg is the configuration this agent was created from.
	agentCfg config.Agent
	sessions session.Service
	messages message.Service
	// NOTE(review): mcpTools is not assigned by NewAgent in this file; the tool
	// builder uses the package-level MCP tool list instead — confirm it is
	// still needed.
	mcpTools []McpTool

	// tools is built lazily on first use by the function created in NewAgent.
	tools *csync.LazySlice[tools.BaseTool]

	// provider serves the main conversation; providerID is its config ID.
	provider   provider.Provider
	providerID string

	// titleProvider generates session titles; summarizeProvider generates
	// conversation summaries.
	titleProvider       provider.Provider
	summarizeProvider   provider.Provider
	summarizeProviderID string

	// activeRequests maps a session ID (or sessionID+"-summarize" for
	// summarization) to the cancel function of the running request.
	activeRequests *csync.Map[string, context.CancelFunc]

	// promptQueue holds prompts submitted while a session was busy, keyed by
	// session ID.
	promptQueue *csync.Map[string, []string]
}
87
// agentPromptMap maps an agent ID to the system prompt it uses. IDs that are
// not listed fall back to prompt.PromptDefault at the lookup sites.
var agentPromptMap = map[string]prompt.PromptID{
	"coder": prompt.PromptCoder,
	"task":  prompt.PromptTask,
}
92
93func NewAgent(
94 ctx context.Context,
95 agentCfg config.Agent,
96 // These services are needed in the tools
97 permissions permission.Service,
98 sessions session.Service,
99 messages message.Service,
100 history history.Service,
101 lspClients map[string]*lsp.Client,
102) (Service, error) {
103 cfg := config.Get()
104
105 var agentTool tools.BaseTool
106 if agentCfg.ID == "coder" {
107 taskAgentCfg := config.Get().Agents["task"]
108 if taskAgentCfg.ID == "" {
109 return nil, fmt.Errorf("task agent not found in config")
110 }
111 taskAgent, err := NewAgent(ctx, taskAgentCfg, permissions, sessions, messages, history, lspClients)
112 if err != nil {
113 return nil, fmt.Errorf("failed to create task agent: %w", err)
114 }
115
116 agentTool = NewAgentTool(taskAgent, sessions, messages)
117 }
118
119 providerCfg := config.Get().GetProviderForModel(agentCfg.Model)
120 if providerCfg == nil {
121 return nil, fmt.Errorf("provider for agent %s not found in config", agentCfg.Name)
122 }
123 model := config.Get().GetModelByType(agentCfg.Model)
124
125 if model == nil {
126 return nil, fmt.Errorf("model not found for agent %s", agentCfg.Name)
127 }
128
129 promptID := agentPromptMap[agentCfg.ID]
130 if promptID == "" {
131 promptID = prompt.PromptDefault
132 }
133 opts := []provider.ProviderClientOption{
134 provider.WithModel(agentCfg.Model),
135 provider.WithSystemMessage(prompt.GetPrompt(promptID, providerCfg.ID, config.Get().Options.ContextPaths...)),
136 }
137 agentProvider, err := provider.NewProvider(*providerCfg, opts...)
138 if err != nil {
139 return nil, err
140 }
141
142 smallModelCfg := cfg.Models[config.SelectedModelTypeSmall]
143 var smallModelProviderCfg *config.ProviderConfig
144 if smallModelCfg.Provider == providerCfg.ID {
145 smallModelProviderCfg = providerCfg
146 } else {
147 smallModelProviderCfg = cfg.GetProviderForModel(config.SelectedModelTypeSmall)
148
149 if smallModelProviderCfg.ID == "" {
150 return nil, fmt.Errorf("provider %s not found in config", smallModelCfg.Provider)
151 }
152 }
153 smallModel := cfg.GetModelByType(config.SelectedModelTypeSmall)
154 if smallModel.ID == "" {
155 return nil, fmt.Errorf("model %s not found in provider %s", smallModelCfg.Model, smallModelProviderCfg.ID)
156 }
157
158 titleOpts := []provider.ProviderClientOption{
159 provider.WithModel(config.SelectedModelTypeSmall),
160 provider.WithSystemMessage(prompt.GetPrompt(prompt.PromptTitle, smallModelProviderCfg.ID)),
161 }
162 titleProvider, err := provider.NewProvider(*smallModelProviderCfg, titleOpts...)
163 if err != nil {
164 return nil, err
165 }
166
167 summarizeOpts := []provider.ProviderClientOption{
168 provider.WithModel(config.SelectedModelTypeLarge),
169 provider.WithSystemMessage(prompt.GetPrompt(prompt.PromptSummarizer, providerCfg.ID)),
170 }
171 summarizeProvider, err := provider.NewProvider(*providerCfg, summarizeOpts...)
172 if err != nil {
173 return nil, err
174 }
175
176 toolFn := func() []tools.BaseTool {
177 slog.Info("Initializing agent tools", "agent", agentCfg.ID)
178 defer func() {
179 slog.Info("Initialized agent tools", "agent", agentCfg.ID)
180 }()
181
182 cwd := cfg.WorkingDir()
183 allTools := []tools.BaseTool{
184 tools.NewBashTool(permissions, cwd),
185 tools.NewDownloadTool(permissions, cwd),
186 tools.NewEditTool(lspClients, permissions, history, cwd),
187 tools.NewMultiEditTool(lspClients, permissions, history, cwd),
188 tools.NewFetchTool(permissions, cwd),
189 tools.NewGlobTool(cwd),
190 tools.NewGrepTool(cwd),
191 tools.NewLsTool(permissions, cwd),
192 tools.NewSourcegraphTool(),
193 tools.NewViewTool(lspClients, permissions, cwd),
194 tools.NewWriteTool(lspClients, permissions, history, cwd),
195 }
196
197 mcpToolsOnce.Do(func() {
198 mcpTools = doGetMCPTools(ctx, permissions, cfg)
199 })
200 allTools = append(allTools, mcpTools...)
201
202 if len(lspClients) > 0 {
203 allTools = append(allTools, tools.NewDiagnosticsTool(lspClients))
204 }
205
206 if agentTool != nil {
207 allTools = append(allTools, agentTool)
208 }
209
210 if agentCfg.AllowedTools == nil {
211 return allTools
212 }
213
214 var filteredTools []tools.BaseTool
215 for _, tool := range allTools {
216 if slices.Contains(agentCfg.AllowedTools, tool.Name()) {
217 filteredTools = append(filteredTools, tool)
218 }
219 }
220 return filteredTools
221 }
222
223 return &agent{
224 Broker: pubsub.NewBroker[AgentEvent](),
225 agentCfg: agentCfg,
226 provider: agentProvider,
227 providerID: string(providerCfg.ID),
228 messages: messages,
229 sessions: sessions,
230 titleProvider: titleProvider,
231 summarizeProvider: summarizeProvider,
232 summarizeProviderID: string(providerCfg.ID),
233 activeRequests: csync.NewMap[string, context.CancelFunc](),
234 tools: csync.NewLazySlice(toolFn),
235 promptQueue: csync.NewMap[string, []string](),
236 }, nil
237}
238
239func (a *agent) Model() catwalk.Model {
240 return *config.Get().GetModelByType(a.agentCfg.Model)
241}
242
243func (a *agent) Cancel(sessionID string) {
244 // Cancel regular requests
245 if cancel, ok := a.activeRequests.Take(sessionID); ok && cancel != nil {
246 slog.Info("Request cancellation initiated", "session_id", sessionID)
247 cancel()
248 }
249
250 // Also check for summarize requests
251 if cancel, ok := a.activeRequests.Take(sessionID + "-summarize"); ok && cancel != nil {
252 slog.Info("Summarize cancellation initiated", "session_id", sessionID)
253 cancel()
254 }
255
256 if a.QueuedPrompts(sessionID) > 0 {
257 slog.Info("Clearing queued prompts", "session_id", sessionID)
258 a.promptQueue.Del(sessionID)
259 }
260}
261
262func (a *agent) IsBusy() bool {
263 var busy bool
264 for cancelFunc := range a.activeRequests.Seq() {
265 if cancelFunc != nil {
266 busy = true
267 break
268 }
269 }
270 return busy
271}
272
273func (a *agent) IsSessionBusy(sessionID string) bool {
274 _, busy := a.activeRequests.Get(sessionID)
275 return busy
276}
277
278func (a *agent) QueuedPrompts(sessionID string) int {
279 l, ok := a.promptQueue.Get(sessionID)
280 if !ok {
281 return 0
282 }
283 return len(l)
284}
285
286func (a *agent) generateTitle(ctx context.Context, sessionID string, content string) error {
287 if content == "" {
288 return nil
289 }
290 if a.titleProvider == nil {
291 return nil
292 }
293 session, err := a.sessions.Get(ctx, sessionID)
294 if err != nil {
295 return err
296 }
297 parts := []message.ContentPart{message.TextContent{
298 Text: fmt.Sprintf("Generate a concise title for the following content:\n\n%s", content),
299 }}
300
301 // Use streaming approach like summarization
302 response := a.titleProvider.StreamResponse(
303 ctx,
304 []message.Message{
305 {
306 Role: message.User,
307 Parts: parts,
308 },
309 },
310 nil,
311 )
312
313 var finalResponse *provider.ProviderResponse
314 for r := range response {
315 if r.Error != nil {
316 return r.Error
317 }
318 finalResponse = r.Response
319 }
320
321 if finalResponse == nil {
322 return fmt.Errorf("no response received from title provider")
323 }
324
325 title := strings.TrimSpace(strings.ReplaceAll(finalResponse.Content, "\n", " "))
326 if title == "" {
327 return nil
328 }
329
330 session.Title = title
331 _, err = a.sessions.Save(ctx, session)
332 return err
333}
334
335func (a *agent) err(err error) AgentEvent {
336 return AgentEvent{
337 Type: AgentEventTypeError,
338 Error: err,
339 }
340}
341
342func (a *agent) Run(ctx context.Context, sessionID string, content string, attachments ...message.Attachment) (<-chan AgentEvent, error) {
343 if !a.Model().SupportsImages && attachments != nil {
344 attachments = nil
345 }
346 events := make(chan AgentEvent)
347 if a.IsSessionBusy(sessionID) {
348 existing, ok := a.promptQueue.Get(sessionID)
349 if !ok {
350 existing = []string{}
351 }
352 existing = append(existing, content)
353 a.promptQueue.Set(sessionID, existing)
354 return nil, nil
355 }
356
357 genCtx, cancel := context.WithCancel(ctx)
358 defer cancel() // Ensure cancel is always called
359
360 a.activeRequests.Set(sessionID, cancel)
361 defer a.activeRequests.Del(sessionID) // Clean up on exit
362
363 go func() {
364 slog.Debug("Request started", "sessionID", sessionID)
365 defer log.RecoverPanic("agent.Run", func() {
366 events <- a.err(fmt.Errorf("panic while running the agent"))
367 })
368 var attachmentParts []message.ContentPart
369 for _, attachment := range attachments {
370 attachmentParts = append(attachmentParts, message.BinaryContent{Path: attachment.FilePath, MIMEType: attachment.MimeType, Data: attachment.Content})
371 }
372 result := a.processGeneration(genCtx, sessionID, content, attachmentParts)
373 if result.Error != nil && !errors.Is(result.Error, ErrRequestCancelled) && !errors.Is(result.Error, context.Canceled) {
374 slog.Error(result.Error.Error())
375 }
376 slog.Debug("Request completed", "sessionID", sessionID)
377 a.Publish(pubsub.CreatedEvent, result)
378 events <- result
379 close(events)
380 }()
381 return events, nil
382}
383
384func (a *agent) processGeneration(ctx context.Context, sessionID, content string, attachmentParts []message.ContentPart) AgentEvent {
385 cfg := config.Get()
386 // List existing messages; if none, start title generation asynchronously.
387 msgs, err := a.messages.List(ctx, sessionID)
388 if err != nil {
389 return a.err(fmt.Errorf("failed to list messages: %w", err))
390 }
391
392 // sliding window to limit message history
393 maxMessagesInContext := cfg.Options.MaxMessages
394 if maxMessagesInContext > 0 && len(msgs) > maxMessagesInContext {
395 // Keep the first message (usually system/context) and the last N-1 messages
396 msgs = append(msgs[:1], msgs[len(msgs)-maxMessagesInContext+1:]...)
397 }
398
399 if len(msgs) == 0 {
400 // Use a context with timeout for title generation
401 titleCtx, titleCancel := context.WithTimeout(context.Background(), 30*time.Second)
402 go func() {
403 defer titleCancel()
404 defer log.RecoverPanic("agent.Run", func() {
405 slog.Error("panic while generating title")
406 })
407 titleErr := a.generateTitle(titleCtx, sessionID, content)
408 if titleErr != nil && !errors.Is(titleErr, context.Canceled) && !errors.Is(titleErr, context.DeadlineExceeded) {
409 slog.Error("failed to generate title", "error", titleErr)
410 }
411 }()
412 }
413 session, err := a.sessions.Get(ctx, sessionID)
414 if err != nil {
415 return a.err(fmt.Errorf("failed to get session: %w", err))
416 }
417 if session.SummaryMessageID != "" {
418 summaryMsgInex := -1
419 for i, msg := range msgs {
420 if msg.ID == session.SummaryMessageID {
421 summaryMsgInex = i
422 break
423 }
424 }
425 if summaryMsgInex != -1 {
426 msgs = msgs[summaryMsgInex:]
427 msgs[0].Role = message.User
428 }
429 }
430
431 userMsg, err := a.createUserMessage(ctx, sessionID, content, attachmentParts)
432 if err != nil {
433 return a.err(fmt.Errorf("failed to create user message: %w", err))
434 }
435 // Append the new user message to the conversation history.
436 msgHistory := append(msgs, userMsg)
437
438 for {
439 // Check for cancellation before each iteration
440 select {
441 case <-ctx.Done():
442 return a.err(ctx.Err())
443 default:
444 // Continue processing
445 }
446 agentMessage, toolResults, err := a.streamAndHandleEvents(ctx, sessionID, msgHistory)
447 if err != nil {
448 if errors.Is(err, context.Canceled) {
449 agentMessage.AddFinish(message.FinishReasonCanceled, "Request cancelled", "")
450 a.messages.Update(context.Background(), agentMessage)
451 return a.err(ErrRequestCancelled)
452 }
453 return a.err(fmt.Errorf("failed to process events: %w", err))
454 }
455 if cfg.Options.Debug {
456 slog.Info("Result", "message", agentMessage.FinishReason(), "toolResults", toolResults)
457 }
458 if (agentMessage.FinishReason() == message.FinishReasonToolUse) && toolResults != nil {
459 // We are not done, we need to respond with the tool response
460 msgHistory = append(msgHistory, agentMessage, *toolResults)
461 // If there are queued prompts, process the next one
462 nextPrompt, ok := a.promptQueue.Take(sessionID)
463 if ok {
464 for _, prompt := range nextPrompt {
465 // Create a new user message for the queued prompt
466 userMsg, err := a.createUserMessage(ctx, sessionID, prompt, nil)
467 if err != nil {
468 return a.err(fmt.Errorf("failed to create user message for queued prompt: %w", err))
469 }
470 // Append the new user message to the conversation history
471 msgHistory = append(msgHistory, userMsg)
472 }
473 }
474
475 continue
476 } else if agentMessage.FinishReason() == message.FinishReasonEndTurn {
477 queuePrompts, ok := a.promptQueue.Take(sessionID)
478 if ok {
479 for _, prompt := range queuePrompts {
480 if prompt == "" {
481 continue
482 }
483 userMsg, err := a.createUserMessage(ctx, sessionID, prompt, nil)
484 if err != nil {
485 return a.err(fmt.Errorf("failed to create user message for queued prompt: %w", err))
486 }
487 msgHistory = append(msgHistory, userMsg)
488 }
489 continue
490 }
491 }
492 if agentMessage.FinishReason() == "" {
493 // Kujtim: could not track down where this is happening but this means its cancelled
494 agentMessage.AddFinish(message.FinishReasonCanceled, "Request cancelled", "")
495 _ = a.messages.Update(context.Background(), agentMessage)
496 return a.err(ErrRequestCancelled)
497 }
498 return AgentEvent{
499 Type: AgentEventTypeResponse,
500 Message: agentMessage,
501 Done: true,
502 }
503 }
504}
505
506func (a *agent) createUserMessage(ctx context.Context, sessionID, content string, attachmentParts []message.ContentPart) (message.Message, error) {
507 parts := []message.ContentPart{message.TextContent{Text: content}}
508 parts = append(parts, attachmentParts...)
509 return a.messages.Create(ctx, sessionID, message.CreateMessageParams{
510 Role: message.User,
511 Parts: parts,
512 })
513}
514
515func (a *agent) streamAndHandleEvents(ctx context.Context, sessionID string, msgHistory []message.Message) (message.Message, *message.Message, error) {
516 ctx = context.WithValue(ctx, tools.SessionIDContextKey, sessionID)
517
518 // Create the assistant message first so the spinner shows immediately
519 assistantMsg, err := a.messages.Create(ctx, sessionID, message.CreateMessageParams{
520 Role: message.Assistant,
521 Parts: []message.ContentPart{},
522 Model: a.Model().ID,
523 Provider: a.providerID,
524 })
525 if err != nil {
526 return assistantMsg, nil, fmt.Errorf("failed to create assistant message: %w", err)
527 }
528
529 // Now collect tools (which may block on MCP initialization)
530 eventChan := a.provider.StreamResponse(ctx, msgHistory, slices.Collect(a.tools.Seq()))
531
532 // Add the session and message ID into the context if needed by tools.
533 ctx = context.WithValue(ctx, tools.MessageIDContextKey, assistantMsg.ID)
534
535 // Process each event in the stream.
536 for event := range eventChan {
537 if processErr := a.processEvent(ctx, sessionID, &assistantMsg, event); processErr != nil {
538 if errors.Is(processErr, context.Canceled) {
539 a.finishMessage(context.Background(), &assistantMsg, message.FinishReasonCanceled, "Request cancelled", "")
540 } else {
541 a.finishMessage(ctx, &assistantMsg, message.FinishReasonError, "API Error", processErr.Error())
542 }
543 return assistantMsg, nil, processErr
544 }
545 if ctx.Err() != nil {
546 a.finishMessage(context.Background(), &assistantMsg, message.FinishReasonCanceled, "Request cancelled", "")
547 return assistantMsg, nil, ctx.Err()
548 }
549 }
550
551 toolResults := make([]message.ToolResult, len(assistantMsg.ToolCalls()))
552 toolCalls := assistantMsg.ToolCalls()
553 for i, toolCall := range toolCalls {
554 select {
555 case <-ctx.Done():
556 a.finishMessage(context.Background(), &assistantMsg, message.FinishReasonCanceled, "Request cancelled", "")
557 // Make all future tool calls cancelled
558 for j := i; j < len(toolCalls); j++ {
559 toolResults[j] = message.ToolResult{
560 ToolCallID: toolCalls[j].ID,
561 Content: "Tool execution canceled by user",
562 IsError: true,
563 }
564 }
565 goto out
566 default:
567 // Continue processing
568 var tool tools.BaseTool
569 for availableTool := range a.tools.Seq() {
570 if availableTool.Info().Name == toolCall.Name {
571 tool = availableTool
572 break
573 }
574 }
575
576 // Tool not found
577 if tool == nil {
578 toolResults[i] = message.ToolResult{
579 ToolCallID: toolCall.ID,
580 Content: fmt.Sprintf("Tool not found: %s", toolCall.Name),
581 IsError: true,
582 }
583 continue
584 }
585
586 // Run tool in goroutine to allow cancellation
587 type toolExecResult struct {
588 response tools.ToolResponse
589 err error
590 }
591 resultChan := make(chan toolExecResult, 1)
592
593 go func() {
594 response, err := tool.Run(ctx, tools.ToolCall{
595 ID: toolCall.ID,
596 Name: toolCall.Name,
597 Input: toolCall.Input,
598 })
599 resultChan <- toolExecResult{response: response, err: err}
600 }()
601
602 var toolResponse tools.ToolResponse
603 var toolErr error
604
605 select {
606 case <-ctx.Done():
607 a.finishMessage(context.Background(), &assistantMsg, message.FinishReasonCanceled, "Request cancelled", "")
608 // Mark remaining tool calls as cancelled
609 for j := i; j < len(toolCalls); j++ {
610 toolResults[j] = message.ToolResult{
611 ToolCallID: toolCalls[j].ID,
612 Content: "Tool execution canceled by user",
613 IsError: true,
614 }
615 }
616 goto out
617 case result := <-resultChan:
618 toolResponse = result.response
619 toolErr = result.err
620 }
621
622 if toolErr != nil {
623 slog.Error("Tool execution error", "toolCall", toolCall.ID, "error", toolErr)
624 if errors.Is(toolErr, permission.ErrorPermissionDenied) {
625 toolResults[i] = message.ToolResult{
626 ToolCallID: toolCall.ID,
627 Content: "Permission denied",
628 IsError: true,
629 }
630 for j := i + 1; j < len(toolCalls); j++ {
631 toolResults[j] = message.ToolResult{
632 ToolCallID: toolCalls[j].ID,
633 Content: "Tool execution canceled by user",
634 IsError: true,
635 }
636 }
637 a.finishMessage(ctx, &assistantMsg, message.FinishReasonPermissionDenied, "Permission denied", "")
638 break
639 }
640 }
641 toolResults[i] = message.ToolResult{
642 ToolCallID: toolCall.ID,
643 Content: toolResponse.Content,
644 Metadata: toolResponse.Metadata,
645 IsError: toolResponse.IsError,
646 }
647 }
648 }
649out:
650 if len(toolResults) == 0 {
651 return assistantMsg, nil, nil
652 }
653 parts := make([]message.ContentPart, 0)
654 for _, tr := range toolResults {
655 parts = append(parts, tr)
656 }
657 msg, err := a.messages.Create(context.Background(), assistantMsg.SessionID, message.CreateMessageParams{
658 Role: message.Tool,
659 Parts: parts,
660 Provider: a.providerID,
661 })
662 if err != nil {
663 return assistantMsg, nil, fmt.Errorf("failed to create cancelled tool message: %w", err)
664 }
665
666 return assistantMsg, &msg, err
667}
668
669func (a *agent) finishMessage(ctx context.Context, msg *message.Message, finishReason message.FinishReason, message, details string) {
670 msg.AddFinish(finishReason, message, details)
671 _ = a.messages.Update(ctx, *msg)
672}
673
674func (a *agent) processEvent(ctx context.Context, sessionID string, assistantMsg *message.Message, event provider.ProviderEvent) error {
675 select {
676 case <-ctx.Done():
677 return ctx.Err()
678 default:
679 // Continue processing.
680 }
681
682 switch event.Type {
683 case provider.EventThinkingDelta:
684 assistantMsg.AppendReasoningContent(event.Thinking)
685 return a.messages.Update(ctx, *assistantMsg)
686 case provider.EventSignatureDelta:
687 assistantMsg.AppendReasoningSignature(event.Signature)
688 return a.messages.Update(ctx, *assistantMsg)
689 case provider.EventContentDelta:
690 assistantMsg.FinishThinking()
691 assistantMsg.AppendContent(event.Content)
692 return a.messages.Update(ctx, *assistantMsg)
693 case provider.EventToolUseStart:
694 assistantMsg.FinishThinking()
695 slog.Info("Tool call started", "toolCall", event.ToolCall)
696 assistantMsg.AddToolCall(*event.ToolCall)
697 return a.messages.Update(ctx, *assistantMsg)
698 case provider.EventToolUseDelta:
699 assistantMsg.AppendToolCallInput(event.ToolCall.ID, event.ToolCall.Input)
700 return a.messages.Update(ctx, *assistantMsg)
701 case provider.EventToolUseStop:
702 slog.Info("Finished tool call", "toolCall", event.ToolCall)
703 assistantMsg.FinishToolCall(event.ToolCall.ID)
704 return a.messages.Update(ctx, *assistantMsg)
705 case provider.EventError:
706 return event.Error
707 case provider.EventComplete:
708 assistantMsg.FinishThinking()
709 assistantMsg.SetToolCalls(event.Response.ToolCalls)
710 assistantMsg.AddFinish(event.Response.FinishReason, "", "")
711 if err := a.messages.Update(ctx, *assistantMsg); err != nil {
712 return fmt.Errorf("failed to update message: %w", err)
713 }
714 return a.TrackUsage(ctx, sessionID, a.Model(), event.Response.Usage)
715 }
716
717 return nil
718}
719
720func (a *agent) TrackUsage(ctx context.Context, sessionID string, model catwalk.Model, usage provider.TokenUsage) error {
721 sess, err := a.sessions.Get(ctx, sessionID)
722 if err != nil {
723 return fmt.Errorf("failed to get session: %w", err)
724 }
725
726 cost := model.CostPer1MInCached/1e6*float64(usage.CacheCreationTokens) +
727 model.CostPer1MOutCached/1e6*float64(usage.CacheReadTokens) +
728 model.CostPer1MIn/1e6*float64(usage.InputTokens) +
729 model.CostPer1MOut/1e6*float64(usage.OutputTokens)
730
731 sess.Cost += cost
732 sess.CompletionTokens = usage.OutputTokens + usage.CacheReadTokens
733 sess.PromptTokens = usage.InputTokens + usage.CacheCreationTokens
734
735 _, err = a.sessions.Save(ctx, sess)
736 if err != nil {
737 return fmt.Errorf("failed to save session: %w", err)
738 }
739 return nil
740}
741
742func (a *agent) Summarize(ctx context.Context, sessionID string) error {
743 if a.summarizeProvider == nil {
744 return fmt.Errorf("summarize provider not available")
745 }
746
747 // Check if session is busy
748 if a.IsSessionBusy(sessionID) {
749 return ErrSessionBusy
750 }
751
752 // Create a new context with cancellation
753 summarizeCtx, cancel := context.WithCancel(ctx)
754
755 // Store the cancel function in activeRequests to allow cancellation
756 a.activeRequests.Set(sessionID+"-summarize", cancel)
757
758 go func() {
759 defer a.activeRequests.Del(sessionID + "-summarize")
760 defer cancel()
761 event := AgentEvent{
762 Type: AgentEventTypeSummarize,
763 Progress: "Starting summarization...",
764 }
765
766 a.Publish(pubsub.CreatedEvent, event)
767 // Get all messages from the session
768 msgs, err := a.messages.List(summarizeCtx, sessionID)
769 if err != nil {
770 event = AgentEvent{
771 Type: AgentEventTypeError,
772 Error: fmt.Errorf("failed to list messages: %w", err),
773 Done: true,
774 }
775 a.Publish(pubsub.CreatedEvent, event)
776 return
777 }
778 summarizeCtx = context.WithValue(summarizeCtx, tools.SessionIDContextKey, sessionID)
779
780 if len(msgs) == 0 {
781 event = AgentEvent{
782 Type: AgentEventTypeError,
783 Error: fmt.Errorf("no messages to summarize"),
784 Done: true,
785 }
786 a.Publish(pubsub.CreatedEvent, event)
787 return
788 }
789
790 event = AgentEvent{
791 Type: AgentEventTypeSummarize,
792 Progress: "Analyzing conversation...",
793 }
794 a.Publish(pubsub.CreatedEvent, event)
795
796 // Add a system message to guide the summarization
797 summarizePrompt := "Provide a detailed but concise summary of our conversation above. Focus on information that would be helpful for continuing the conversation, including what we did, what we're doing, which files we're working on, and what we're going to do next."
798
799 // Create a new message with the summarize prompt
800 promptMsg := message.Message{
801 Role: message.User,
802 Parts: []message.ContentPart{message.TextContent{Text: summarizePrompt}},
803 }
804
805 // Append the prompt to the messages
806 msgsWithPrompt := append(msgs, promptMsg)
807
808 event = AgentEvent{
809 Type: AgentEventTypeSummarize,
810 Progress: "Generating summary...",
811 }
812
813 a.Publish(pubsub.CreatedEvent, event)
814
815 // Send the messages to the summarize provider
816 response := a.summarizeProvider.StreamResponse(
817 summarizeCtx,
818 msgsWithPrompt,
819 nil,
820 )
821 var finalResponse *provider.ProviderResponse
822 for r := range response {
823 if r.Error != nil {
824 event = AgentEvent{
825 Type: AgentEventTypeError,
826 Error: fmt.Errorf("failed to summarize: %w", err),
827 Done: true,
828 }
829 a.Publish(pubsub.CreatedEvent, event)
830 return
831 }
832 finalResponse = r.Response
833 }
834
835 summary := strings.TrimSpace(finalResponse.Content)
836 if summary == "" {
837 event = AgentEvent{
838 Type: AgentEventTypeError,
839 Error: fmt.Errorf("empty summary returned"),
840 Done: true,
841 }
842 a.Publish(pubsub.CreatedEvent, event)
843 return
844 }
845 shell := shell.GetPersistentShell(config.Get().WorkingDir())
846 summary += "\n\n**Current working directory of the persistent shell**\n\n" + shell.GetWorkingDir()
847 event = AgentEvent{
848 Type: AgentEventTypeSummarize,
849 Progress: "Creating new session...",
850 }
851
852 a.Publish(pubsub.CreatedEvent, event)
853 oldSession, err := a.sessions.Get(summarizeCtx, sessionID)
854 if err != nil {
855 event = AgentEvent{
856 Type: AgentEventTypeError,
857 Error: fmt.Errorf("failed to get session: %w", err),
858 Done: true,
859 }
860
861 a.Publish(pubsub.CreatedEvent, event)
862 return
863 }
864 // Create a message in the new session with the summary
865 msg, err := a.messages.Create(summarizeCtx, oldSession.ID, message.CreateMessageParams{
866 Role: message.Assistant,
867 Parts: []message.ContentPart{
868 message.TextContent{Text: summary},
869 message.Finish{
870 Reason: message.FinishReasonEndTurn,
871 Time: time.Now().Unix(),
872 },
873 },
874 Model: a.summarizeProvider.Model().ID,
875 Provider: a.summarizeProviderID,
876 })
877 if err != nil {
878 event = AgentEvent{
879 Type: AgentEventTypeError,
880 Error: fmt.Errorf("failed to create summary message: %w", err),
881 Done: true,
882 }
883
884 a.Publish(pubsub.CreatedEvent, event)
885 return
886 }
887 oldSession.SummaryMessageID = msg.ID
888 oldSession.CompletionTokens = finalResponse.Usage.OutputTokens
889 oldSession.PromptTokens = 0
890 model := a.summarizeProvider.Model()
891 usage := finalResponse.Usage
892 cost := model.CostPer1MInCached/1e6*float64(usage.CacheCreationTokens) +
893 model.CostPer1MOutCached/1e6*float64(usage.CacheReadTokens) +
894 model.CostPer1MIn/1e6*float64(usage.InputTokens) +
895 model.CostPer1MOut/1e6*float64(usage.OutputTokens)
896 oldSession.Cost += cost
897 _, err = a.sessions.Save(summarizeCtx, oldSession)
898 if err != nil {
899 event = AgentEvent{
900 Type: AgentEventTypeError,
901 Error: fmt.Errorf("failed to save session: %w", err),
902 Done: true,
903 }
904 a.Publish(pubsub.CreatedEvent, event)
905 }
906
907 event = AgentEvent{
908 Type: AgentEventTypeSummarize,
909 SessionID: oldSession.ID,
910 Progress: "Summary complete",
911 Done: true,
912 }
913 a.Publish(pubsub.CreatedEvent, event)
914 // Send final success event with the new session ID
915 }()
916
917 return nil
918}
919
920func (a *agent) ClearQueue(sessionID string) {
921 if a.QueuedPrompts(sessionID) > 0 {
922 slog.Info("Clearing queued prompts", "session_id", sessionID)
923 a.promptQueue.Del(sessionID)
924 }
925}
926
927func (a *agent) CancelAll() {
928 if !a.IsBusy() {
929 return
930 }
931 for key := range a.activeRequests.Seq2() {
932 a.Cancel(key) // key is sessionID
933 }
934
935 timeout := time.After(5 * time.Second)
936 for a.IsBusy() {
937 select {
938 case <-timeout:
939 return
940 default:
941 time.Sleep(200 * time.Millisecond)
942 }
943 }
944}
945
946func (a *agent) UpdateModel() error {
947 cfg := config.Get()
948
949 // Get current provider configuration
950 currentProviderCfg := cfg.GetProviderForModel(a.agentCfg.Model)
951 if currentProviderCfg == nil || currentProviderCfg.ID == "" {
952 return fmt.Errorf("provider for agent %s not found in config", a.agentCfg.Name)
953 }
954
955 // Check if provider has changed
956 if string(currentProviderCfg.ID) != a.providerID {
957 // Provider changed, need to recreate the main provider
958 model := cfg.GetModelByType(a.agentCfg.Model)
959 if model.ID == "" {
960 return fmt.Errorf("model not found for agent %s", a.agentCfg.Name)
961 }
962
963 promptID := agentPromptMap[a.agentCfg.ID]
964 if promptID == "" {
965 promptID = prompt.PromptDefault
966 }
967
968 opts := []provider.ProviderClientOption{
969 provider.WithModel(a.agentCfg.Model),
970 provider.WithSystemMessage(prompt.GetPrompt(promptID, currentProviderCfg.ID, cfg.Options.ContextPaths...)),
971 }
972
973 newProvider, err := provider.NewProvider(*currentProviderCfg, opts...)
974 if err != nil {
975 return fmt.Errorf("failed to create new provider: %w", err)
976 }
977
978 // Update the provider and provider ID
979 a.provider = newProvider
980 a.providerID = string(currentProviderCfg.ID)
981 }
982
983 // Check if providers have changed for title (small) and summarize (large)
984 smallModelCfg := cfg.Models[config.SelectedModelTypeSmall]
985 var smallModelProviderCfg config.ProviderConfig
986 for p := range cfg.Providers.Seq() {
987 if p.ID == smallModelCfg.Provider {
988 smallModelProviderCfg = p
989 break
990 }
991 }
992 if smallModelProviderCfg.ID == "" {
993 return fmt.Errorf("provider %s not found in config", smallModelCfg.Provider)
994 }
995
996 largeModelCfg := cfg.Models[config.SelectedModelTypeLarge]
997 var largeModelProviderCfg config.ProviderConfig
998 for p := range cfg.Providers.Seq() {
999 if p.ID == largeModelCfg.Provider {
1000 largeModelProviderCfg = p
1001 break
1002 }
1003 }
1004 if largeModelProviderCfg.ID == "" {
1005 return fmt.Errorf("provider %s not found in config", largeModelCfg.Provider)
1006 }
1007
1008 // Recreate title provider
1009 titleOpts := []provider.ProviderClientOption{
1010 provider.WithModel(config.SelectedModelTypeSmall),
1011 provider.WithSystemMessage(prompt.GetPrompt(prompt.PromptTitle, smallModelProviderCfg.ID)),
1012 provider.WithMaxTokens(40),
1013 }
1014 newTitleProvider, err := provider.NewProvider(smallModelProviderCfg, titleOpts...)
1015 if err != nil {
1016 return fmt.Errorf("failed to create new title provider: %w", err)
1017 }
1018 a.titleProvider = newTitleProvider
1019
1020 // Recreate summarize provider if provider changed (now large model)
1021 if string(largeModelProviderCfg.ID) != a.summarizeProviderID {
1022 largeModel := cfg.GetModelByType(config.SelectedModelTypeLarge)
1023 if largeModel == nil {
1024 return fmt.Errorf("model %s not found in provider %s", largeModelCfg.Model, largeModelProviderCfg.ID)
1025 }
1026 summarizeOpts := []provider.ProviderClientOption{
1027 provider.WithModel(config.SelectedModelTypeLarge),
1028 provider.WithSystemMessage(prompt.GetPrompt(prompt.PromptSummarizer, largeModelProviderCfg.ID)),
1029 }
1030 newSummarizeProvider, err := provider.NewProvider(largeModelProviderCfg, summarizeOpts...)
1031 if err != nil {
1032 return fmt.Errorf("failed to create new summarize provider: %w", err)
1033 }
1034 a.summarizeProvider = newSummarizeProvider
1035 a.summarizeProviderID = string(largeModelProviderCfg.ID)
1036 }
1037
1038 return nil
1039}