1package agent
2
3import (
4 "context"
5 "errors"
6 "fmt"
7 "log/slog"
8 "slices"
9 "strings"
10 "time"
11
12 "github.com/charmbracelet/catwalk/pkg/catwalk"
13 "github.com/charmbracelet/crush/internal/config"
14 "github.com/charmbracelet/crush/internal/csync"
15 "github.com/charmbracelet/crush/internal/history"
16 "github.com/charmbracelet/crush/internal/llm/prompt"
17 "github.com/charmbracelet/crush/internal/llm/provider"
18 "github.com/charmbracelet/crush/internal/llm/tools"
19 "github.com/charmbracelet/crush/internal/log"
20 "github.com/charmbracelet/crush/internal/lsp"
21 "github.com/charmbracelet/crush/internal/message"
22 "github.com/charmbracelet/crush/internal/permission"
23 "github.com/charmbracelet/crush/internal/pubsub"
24 "github.com/charmbracelet/crush/internal/session"
25 "github.com/charmbracelet/crush/internal/shell"
26)
27
// Sentinel errors reported by the agent; compare with errors.Is.
var (
	// ErrRequestCancelled is reported when an in-flight request was cancelled.
	ErrRequestCancelled = errors.New("request canceled by user")

	// ErrSessionBusy is returned when an operation needs an idle session but
	// the session already has an active request.
	ErrSessionBusy = errors.New("session is currently processing another request")
)
33
// AgentEventType identifies the kind of an AgentEvent.
type AgentEventType string

// Kinds of events emitted by the agent.
const (
	// AgentEventTypeError reports a failed request or summarization.
	AgentEventTypeError AgentEventType = "error"
	// AgentEventTypeResponse carries the assistant response of a finished request.
	AgentEventTypeResponse AgentEventType = "response"
	// AgentEventTypeSummarize reports summarization progress and completion.
	AgentEventTypeSummarize AgentEventType = "summarize"
)
41
// AgentEvent is emitted by the agent: it is sent on the channel returned by
// Run and published to the agent's broker subscribers.
type AgentEvent struct {
	// Type selects which of the remaining fields are meaningful.
	Type AgentEventType
	// Message is the final assistant message of a response event.
	Message message.Message
	// Error describes the failure for error events.
	Error error

	// Summarization fields: SessionID is set on the final summarize event and
	// Progress carries a human-readable status. Done marks the last event of
	// an operation (response events set it as well).
	SessionID string
	Progress  string
	Done      bool
}
52
// Service runs prompts against an LLM provider on behalf of a session and
// publishes AgentEvents to its subscribers.
type Service interface {
	pubsub.Suscriber[AgentEvent]
	// Model returns the model the agent is currently configured to use.
	Model() catwalk.Model
	// Run processes content (and optional attachments) for sessionID in the
	// background. In the agent implementation the returned channel receives
	// the final AgentEvent; if the session is already busy the prompt is
	// queued instead and a nil channel with a nil error is returned.
	Run(ctx context.Context, sessionID string, content string, attachments ...message.Attachment) (<-chan AgentEvent, error)
	// Cancel aborts the session's active request and summarization and
	// discards its queued prompts.
	Cancel(sessionID string)
	// CancelAll cancels every active request.
	CancelAll()
	// IsSessionBusy reports whether sessionID has an active request.
	IsSessionBusy(sessionID string) bool
	// IsBusy reports whether any request is active.
	IsBusy() bool
	// Summarize starts summarizing the session's conversation; progress and
	// results are delivered as published events.
	Summarize(ctx context.Context, sessionID string) error
	// UpdateModel rebuilds the providers after the model configuration changed.
	UpdateModel() error
	// QueuedPrompts returns the number of prompts waiting for sessionID.
	QueuedPrompts(sessionID string) int
	// ClearQueue discards the prompts waiting for sessionID.
	ClearQueue(sessionID string)
}
66
// agent is the Service implementation. It streams responses from an LLM
// provider, executes tools, tracks per-session requests and publishes
// AgentEvents through the embedded broker.
type agent struct {
	*pubsub.Broker[AgentEvent]
	agentCfg config.Agent
	sessions session.Service
	messages message.Service
	// NOTE(review): NewAgent never sets this field; tool construction uses a
	// package-level mcpTools variable instead. Possibly unused — confirm.
	mcpTools []McpTool

	// tools holds the base tool set produced by NewAgent's toolFn via
	// csync.LazySlice (presumably built on first use).
	tools *csync.LazySlice[tools.BaseTool]
	// We need this to be able to update it when model changes.
	// When non-nil it builds the tool that delegates to a task agent.
	agentToolFn func() (tools.BaseTool, error)

	// provider serves the agent's main model; providerID is its config ID.
	provider   provider.Provider
	providerID string

	// titleProvider generates session titles. summarizeProvider produces
	// conversation summaries; summarizeProviderID is its config ID.
	titleProvider       provider.Provider
	summarizeProvider   provider.Provider
	summarizeProviderID string

	// activeRequests maps a session ID to the cancel func of its running
	// request; summarizations use the key "<sessionID>-summarize".
	activeRequests *csync.Map[string, context.CancelFunc]

	// promptQueue holds prompts submitted via Run while the session was busy.
	promptQueue *csync.Map[string, []string]
}
89
90var agentPromptMap = map[string]prompt.PromptID{
91 "coder": prompt.PromptCoder,
92 "task": prompt.PromptTask,
93}
94
95func NewAgent(
96 ctx context.Context,
97 agentCfg config.Agent,
98 // These services are needed in the tools
99 permissions permission.Service,
100 sessions session.Service,
101 messages message.Service,
102 history history.Service,
103 lspClients map[string]*lsp.Client,
104) (Service, error) {
105 cfg := config.Get()
106
107 var agentToolFn func() (tools.BaseTool, error)
108 if agentCfg.ID == "coder" {
109 agentToolFn = func() (tools.BaseTool, error) {
110 taskAgentCfg := config.Get().Agents["task"]
111 if taskAgentCfg.ID == "" {
112 return nil, fmt.Errorf("task agent not found in config")
113 }
114 taskAgent, err := NewAgent(ctx, taskAgentCfg, permissions, sessions, messages, history, lspClients)
115 if err != nil {
116 return nil, fmt.Errorf("failed to create task agent: %w", err)
117 }
118 return NewAgentTool(taskAgent, sessions, messages), nil
119 }
120 }
121
122 providerCfg := config.Get().GetProviderForModel(agentCfg.Model)
123 if providerCfg == nil {
124 return nil, fmt.Errorf("provider for agent %s not found in config", agentCfg.Name)
125 }
126 model := config.Get().GetModelByType(agentCfg.Model)
127
128 if model == nil {
129 return nil, fmt.Errorf("model not found for agent %s", agentCfg.Name)
130 }
131
132 promptID := agentPromptMap[agentCfg.ID]
133 if promptID == "" {
134 promptID = prompt.PromptDefault
135 }
136 opts := []provider.ProviderClientOption{
137 provider.WithModel(agentCfg.Model),
138 provider.WithSystemMessage(prompt.GetPrompt(promptID, providerCfg.ID, config.Get().Options.ContextPaths...)),
139 }
140 agentProvider, err := provider.NewProvider(*providerCfg, opts...)
141 if err != nil {
142 return nil, err
143 }
144
145 smallModelCfg := cfg.Models[config.SelectedModelTypeSmall]
146 var smallModelProviderCfg *config.ProviderConfig
147 if smallModelCfg.Provider == providerCfg.ID {
148 smallModelProviderCfg = providerCfg
149 } else {
150 smallModelProviderCfg = cfg.GetProviderForModel(config.SelectedModelTypeSmall)
151
152 if smallModelProviderCfg.ID == "" {
153 return nil, fmt.Errorf("provider %s not found in config", smallModelCfg.Provider)
154 }
155 }
156 smallModel := cfg.GetModelByType(config.SelectedModelTypeSmall)
157 if smallModel.ID == "" {
158 return nil, fmt.Errorf("model %s not found in provider %s", smallModelCfg.Model, smallModelProviderCfg.ID)
159 }
160
161 titleOpts := []provider.ProviderClientOption{
162 provider.WithModel(config.SelectedModelTypeSmall),
163 provider.WithSystemMessage(prompt.GetPrompt(prompt.PromptTitle, smallModelProviderCfg.ID)),
164 }
165 titleProvider, err := provider.NewProvider(*smallModelProviderCfg, titleOpts...)
166 if err != nil {
167 return nil, err
168 }
169
170 summarizeOpts := []provider.ProviderClientOption{
171 provider.WithModel(config.SelectedModelTypeLarge),
172 provider.WithSystemMessage(prompt.GetPrompt(prompt.PromptSummarizer, providerCfg.ID)),
173 }
174 summarizeProvider, err := provider.NewProvider(*providerCfg, summarizeOpts...)
175 if err != nil {
176 return nil, err
177 }
178
179 toolFn := func() []tools.BaseTool {
180 slog.Info("Initializing agent tools", "agent", agentCfg.ID)
181 defer func() {
182 slog.Info("Initialized agent tools", "agent", agentCfg.ID)
183 }()
184
185 cwd := cfg.WorkingDir()
186 allTools := []tools.BaseTool{
187 tools.NewBashTool(permissions, cwd),
188 tools.NewDownloadTool(permissions, cwd),
189 tools.NewEditTool(lspClients, permissions, history, cwd),
190 tools.NewMultiEditTool(lspClients, permissions, history, cwd),
191 tools.NewFetchTool(permissions, cwd),
192 tools.NewGlobTool(cwd),
193 tools.NewGrepTool(cwd),
194 tools.NewLsTool(permissions, cwd),
195 tools.NewSourcegraphTool(),
196 tools.NewViewTool(lspClients, permissions, cwd),
197 tools.NewWriteTool(lspClients, permissions, history, cwd),
198 }
199
200 mcpToolsOnce.Do(func() {
201 mcpTools = doGetMCPTools(ctx, permissions, cfg)
202 })
203 allTools = append(allTools, mcpTools...)
204
205 if len(lspClients) > 0 {
206 allTools = append(allTools, tools.NewDiagnosticsTool(lspClients))
207 }
208
209 if agentCfg.AllowedTools == nil {
210 return allTools
211 }
212
213 var filteredTools []tools.BaseTool
214 for _, tool := range allTools {
215 if slices.Contains(agentCfg.AllowedTools, tool.Name()) {
216 filteredTools = append(filteredTools, tool)
217 }
218 }
219 return filteredTools
220 }
221
222 return &agent{
223 Broker: pubsub.NewBroker[AgentEvent](),
224 agentCfg: agentCfg,
225 provider: agentProvider,
226 providerID: string(providerCfg.ID),
227 messages: messages,
228 sessions: sessions,
229 titleProvider: titleProvider,
230 summarizeProvider: summarizeProvider,
231 summarizeProviderID: string(providerCfg.ID),
232 agentToolFn: agentToolFn,
233 activeRequests: csync.NewMap[string, context.CancelFunc](),
234 tools: csync.NewLazySlice(toolFn),
235 promptQueue: csync.NewMap[string, []string](),
236 }, nil
237}
238
239func (a *agent) Model() catwalk.Model {
240 return *config.Get().GetModelByType(a.agentCfg.Model)
241}
242
243func (a *agent) Cancel(sessionID string) {
244 // Cancel regular requests
245 if cancel, ok := a.activeRequests.Take(sessionID); ok && cancel != nil {
246 slog.Info("Request cancellation initiated", "session_id", sessionID)
247 cancel()
248 }
249
250 // Also check for summarize requests
251 if cancel, ok := a.activeRequests.Take(sessionID + "-summarize"); ok && cancel != nil {
252 slog.Info("Summarize cancellation initiated", "session_id", sessionID)
253 cancel()
254 }
255
256 if a.QueuedPrompts(sessionID) > 0 {
257 slog.Info("Clearing queued prompts", "session_id", sessionID)
258 a.promptQueue.Del(sessionID)
259 }
260}
261
262func (a *agent) IsBusy() bool {
263 var busy bool
264 for cancelFunc := range a.activeRequests.Seq() {
265 if cancelFunc != nil {
266 busy = true
267 break
268 }
269 }
270 return busy
271}
272
273func (a *agent) IsSessionBusy(sessionID string) bool {
274 _, busy := a.activeRequests.Get(sessionID)
275 return busy
276}
277
278func (a *agent) QueuedPrompts(sessionID string) int {
279 l, ok := a.promptQueue.Get(sessionID)
280 if !ok {
281 return 0
282 }
283 return len(l)
284}
285
286func (a *agent) generateTitle(ctx context.Context, sessionID string, content string) error {
287 if content == "" {
288 return nil
289 }
290 if a.titleProvider == nil {
291 return nil
292 }
293 session, err := a.sessions.Get(ctx, sessionID)
294 if err != nil {
295 return err
296 }
297 parts := []message.ContentPart{message.TextContent{
298 Text: fmt.Sprintf("Generate a concise title for the following content:\n\n%s", content),
299 }}
300
301 // Use streaming approach like summarization
302 response := a.titleProvider.StreamResponse(
303 ctx,
304 []message.Message{
305 {
306 Role: message.User,
307 Parts: parts,
308 },
309 },
310 nil,
311 )
312
313 var finalResponse *provider.ProviderResponse
314 for r := range response {
315 if r.Error != nil {
316 return r.Error
317 }
318 finalResponse = r.Response
319 }
320
321 if finalResponse == nil {
322 return fmt.Errorf("no response received from title provider")
323 }
324
325 title := strings.TrimSpace(strings.ReplaceAll(finalResponse.Content, "\n", " "))
326 if title == "" {
327 return nil
328 }
329
330 session.Title = title
331 _, err = a.sessions.Save(ctx, session)
332 return err
333}
334
335func (a *agent) err(err error) AgentEvent {
336 return AgentEvent{
337 Type: AgentEventTypeError,
338 Error: err,
339 }
340}
341
342func (a *agent) Run(ctx context.Context, sessionID string, content string, attachments ...message.Attachment) (<-chan AgentEvent, error) {
343 if !a.Model().SupportsImages && attachments != nil {
344 attachments = nil
345 }
346 events := make(chan AgentEvent, 1)
347 if a.IsSessionBusy(sessionID) {
348 existing, ok := a.promptQueue.Get(sessionID)
349 if !ok {
350 existing = []string{}
351 }
352 existing = append(existing, content)
353 a.promptQueue.Set(sessionID, existing)
354 return nil, nil
355 }
356
357 genCtx, cancel := context.WithCancel(ctx)
358
359 a.activeRequests.Set(sessionID, cancel)
360 go func() {
361 slog.Debug("Request started", "sessionID", sessionID)
362 defer log.RecoverPanic("agent.Run", func() {
363 events <- a.err(fmt.Errorf("panic while running the agent"))
364 })
365 var attachmentParts []message.ContentPart
366 for _, attachment := range attachments {
367 attachmentParts = append(attachmentParts, message.BinaryContent{Path: attachment.FilePath, MIMEType: attachment.MimeType, Data: attachment.Content})
368 }
369 result := a.processGeneration(genCtx, sessionID, content, attachmentParts)
370 if result.Error != nil && !errors.Is(result.Error, ErrRequestCancelled) && !errors.Is(result.Error, context.Canceled) {
371 slog.Error(result.Error.Error())
372 }
373 slog.Debug("Request completed", "sessionID", sessionID)
374 a.activeRequests.Del(sessionID)
375 cancel()
376 a.Publish(pubsub.CreatedEvent, result)
377 events <- result
378 close(events)
379 }()
380 return events, nil
381}
382
383func (a *agent) processGeneration(ctx context.Context, sessionID, content string, attachmentParts []message.ContentPart) AgentEvent {
384 cfg := config.Get()
385 // List existing messages; if none, start title generation asynchronously.
386 msgs, err := a.messages.List(ctx, sessionID)
387 if err != nil {
388 return a.err(fmt.Errorf("failed to list messages: %w", err))
389 }
390 if len(msgs) == 0 {
391 go func() {
392 defer log.RecoverPanic("agent.Run", func() {
393 slog.Error("panic while generating title")
394 })
395 titleErr := a.generateTitle(ctx, sessionID, content)
396 if titleErr != nil && !errors.Is(titleErr, context.Canceled) && !errors.Is(titleErr, context.DeadlineExceeded) {
397 slog.Error("failed to generate title", "error", titleErr)
398 }
399 }()
400 }
401 session, err := a.sessions.Get(ctx, sessionID)
402 if err != nil {
403 return a.err(fmt.Errorf("failed to get session: %w", err))
404 }
405 if session.SummaryMessageID != "" {
406 summaryMsgInex := -1
407 for i, msg := range msgs {
408 if msg.ID == session.SummaryMessageID {
409 summaryMsgInex = i
410 break
411 }
412 }
413 if summaryMsgInex != -1 {
414 msgs = msgs[summaryMsgInex:]
415 msgs[0].Role = message.User
416 }
417 }
418
419 userMsg, err := a.createUserMessage(ctx, sessionID, content, attachmentParts)
420 if err != nil {
421 return a.err(fmt.Errorf("failed to create user message: %w", err))
422 }
423 // Append the new user message to the conversation history.
424 msgHistory := append(msgs, userMsg)
425
426 for {
427 // Check for cancellation before each iteration
428 select {
429 case <-ctx.Done():
430 return a.err(ctx.Err())
431 default:
432 // Continue processing
433 }
434 agentMessage, toolResults, err := a.streamAndHandleEvents(ctx, sessionID, msgHistory)
435 if err != nil {
436 if errors.Is(err, context.Canceled) {
437 agentMessage.AddFinish(message.FinishReasonCanceled, "Request cancelled", "")
438 a.messages.Update(context.Background(), agentMessage)
439 return a.err(ErrRequestCancelled)
440 }
441 return a.err(fmt.Errorf("failed to process events: %w", err))
442 }
443 if cfg.Options.Debug {
444 slog.Info("Result", "message", agentMessage.FinishReason(), "toolResults", toolResults)
445 }
446 if (agentMessage.FinishReason() == message.FinishReasonToolUse) && toolResults != nil {
447 // We are not done, we need to respond with the tool response
448 msgHistory = append(msgHistory, agentMessage, *toolResults)
449 // If there are queued prompts, process the next one
450 nextPrompt, ok := a.promptQueue.Take(sessionID)
451 if ok {
452 for _, prompt := range nextPrompt {
453 // Create a new user message for the queued prompt
454 userMsg, err := a.createUserMessage(ctx, sessionID, prompt, nil)
455 if err != nil {
456 return a.err(fmt.Errorf("failed to create user message for queued prompt: %w", err))
457 }
458 // Append the new user message to the conversation history
459 msgHistory = append(msgHistory, userMsg)
460 }
461 }
462
463 continue
464 } else if agentMessage.FinishReason() == message.FinishReasonEndTurn {
465 queuePrompts, ok := a.promptQueue.Take(sessionID)
466 if ok {
467 for _, prompt := range queuePrompts {
468 if prompt == "" {
469 continue
470 }
471 userMsg, err := a.createUserMessage(ctx, sessionID, prompt, nil)
472 if err != nil {
473 return a.err(fmt.Errorf("failed to create user message for queued prompt: %w", err))
474 }
475 msgHistory = append(msgHistory, userMsg)
476 }
477 continue
478 }
479 }
480 if agentMessage.FinishReason() == "" {
481 // Kujtim: could not track down where this is happening but this means its cancelled
482 agentMessage.AddFinish(message.FinishReasonCanceled, "Request cancelled", "")
483 _ = a.messages.Update(context.Background(), agentMessage)
484 return a.err(ErrRequestCancelled)
485 }
486 return AgentEvent{
487 Type: AgentEventTypeResponse,
488 Message: agentMessage,
489 Done: true,
490 }
491 }
492}
493
494func (a *agent) createUserMessage(ctx context.Context, sessionID, content string, attachmentParts []message.ContentPart) (message.Message, error) {
495 parts := []message.ContentPart{message.TextContent{Text: content}}
496 parts = append(parts, attachmentParts...)
497 return a.messages.Create(ctx, sessionID, message.CreateMessageParams{
498 Role: message.User,
499 Parts: parts,
500 })
501}
502
503func (a *agent) getAllTools() ([]tools.BaseTool, error) {
504 allTools := slices.Collect(a.tools.Seq())
505 if a.agentToolFn != nil {
506 agentTool, agentToolErr := a.agentToolFn()
507 if agentToolErr != nil {
508 return nil, agentToolErr
509 }
510 allTools = append(allTools, agentTool)
511 }
512 return allTools, nil
513}
514
515func (a *agent) streamAndHandleEvents(ctx context.Context, sessionID string, msgHistory []message.Message) (message.Message, *message.Message, error) {
516 ctx = context.WithValue(ctx, tools.SessionIDContextKey, sessionID)
517
518 // Create the assistant message first so the spinner shows immediately
519 assistantMsg, err := a.messages.Create(ctx, sessionID, message.CreateMessageParams{
520 Role: message.Assistant,
521 Parts: []message.ContentPart{},
522 Model: a.Model().ID,
523 Provider: a.providerID,
524 })
525 if err != nil {
526 return assistantMsg, nil, fmt.Errorf("failed to create assistant message: %w", err)
527 }
528
529 allTools, toolsErr := a.getAllTools()
530 if toolsErr != nil {
531 return assistantMsg, nil, toolsErr
532 }
533 // Now collect tools (which may block on MCP initialization)
534 eventChan := a.provider.StreamResponse(ctx, msgHistory, allTools)
535
536 // Add the session and message ID into the context if needed by tools.
537 ctx = context.WithValue(ctx, tools.MessageIDContextKey, assistantMsg.ID)
538
539 // Process each event in the stream.
540 for event := range eventChan {
541 if processErr := a.processEvent(ctx, sessionID, &assistantMsg, event); processErr != nil {
542 if errors.Is(processErr, context.Canceled) {
543 a.finishMessage(context.Background(), &assistantMsg, message.FinishReasonCanceled, "Request cancelled", "")
544 } else {
545 a.finishMessage(ctx, &assistantMsg, message.FinishReasonError, "API Error", processErr.Error())
546 }
547 return assistantMsg, nil, processErr
548 }
549 if ctx.Err() != nil {
550 a.finishMessage(context.Background(), &assistantMsg, message.FinishReasonCanceled, "Request cancelled", "")
551 return assistantMsg, nil, ctx.Err()
552 }
553 }
554
555 toolResults := make([]message.ToolResult, len(assistantMsg.ToolCalls()))
556 toolCalls := assistantMsg.ToolCalls()
557 for i, toolCall := range toolCalls {
558 select {
559 case <-ctx.Done():
560 a.finishMessage(context.Background(), &assistantMsg, message.FinishReasonCanceled, "Request cancelled", "")
561 // Make all future tool calls cancelled
562 for j := i; j < len(toolCalls); j++ {
563 toolResults[j] = message.ToolResult{
564 ToolCallID: toolCalls[j].ID,
565 Content: "Tool execution canceled by user",
566 IsError: true,
567 }
568 }
569 goto out
570 default:
571 // Continue processing
572 var tool tools.BaseTool
573 allTools, _ := a.getAllTools()
574 for _, availableTool := range allTools {
575 if availableTool.Info().Name == toolCall.Name {
576 tool = availableTool
577 break
578 }
579 }
580
581 // Tool not found
582 if tool == nil {
583 toolResults[i] = message.ToolResult{
584 ToolCallID: toolCall.ID,
585 Content: fmt.Sprintf("Tool not found: %s", toolCall.Name),
586 IsError: true,
587 }
588 continue
589 }
590
591 // Run tool in goroutine to allow cancellation
592 type toolExecResult struct {
593 response tools.ToolResponse
594 err error
595 }
596 resultChan := make(chan toolExecResult, 1)
597
598 go func() {
599 response, err := tool.Run(ctx, tools.ToolCall{
600 ID: toolCall.ID,
601 Name: toolCall.Name,
602 Input: toolCall.Input,
603 })
604 resultChan <- toolExecResult{response: response, err: err}
605 }()
606
607 var toolResponse tools.ToolResponse
608 var toolErr error
609
610 select {
611 case <-ctx.Done():
612 a.finishMessage(context.Background(), &assistantMsg, message.FinishReasonCanceled, "Request cancelled", "")
613 // Mark remaining tool calls as cancelled
614 for j := i; j < len(toolCalls); j++ {
615 toolResults[j] = message.ToolResult{
616 ToolCallID: toolCalls[j].ID,
617 Content: "Tool execution canceled by user",
618 IsError: true,
619 }
620 }
621 goto out
622 case result := <-resultChan:
623 toolResponse = result.response
624 toolErr = result.err
625 }
626
627 if toolErr != nil {
628 slog.Error("Tool execution error", "toolCall", toolCall.ID, "error", toolErr)
629 if errors.Is(toolErr, permission.ErrorPermissionDenied) {
630 toolResults[i] = message.ToolResult{
631 ToolCallID: toolCall.ID,
632 Content: "Permission denied",
633 IsError: true,
634 }
635 for j := i + 1; j < len(toolCalls); j++ {
636 toolResults[j] = message.ToolResult{
637 ToolCallID: toolCalls[j].ID,
638 Content: "Tool execution canceled by user",
639 IsError: true,
640 }
641 }
642 a.finishMessage(ctx, &assistantMsg, message.FinishReasonPermissionDenied, "Permission denied", "")
643 break
644 }
645 }
646 toolResults[i] = message.ToolResult{
647 ToolCallID: toolCall.ID,
648 Content: toolResponse.Content,
649 Metadata: toolResponse.Metadata,
650 IsError: toolResponse.IsError,
651 }
652 }
653 }
654out:
655 if len(toolResults) == 0 {
656 return assistantMsg, nil, nil
657 }
658 parts := make([]message.ContentPart, 0)
659 for _, tr := range toolResults {
660 parts = append(parts, tr)
661 }
662 msg, err := a.messages.Create(context.Background(), assistantMsg.SessionID, message.CreateMessageParams{
663 Role: message.Tool,
664 Parts: parts,
665 Provider: a.providerID,
666 })
667 if err != nil {
668 return assistantMsg, nil, fmt.Errorf("failed to create cancelled tool message: %w", err)
669 }
670
671 return assistantMsg, &msg, err
672}
673
674func (a *agent) finishMessage(ctx context.Context, msg *message.Message, finishReason message.FinishReason, message, details string) {
675 msg.AddFinish(finishReason, message, details)
676 _ = a.messages.Update(ctx, *msg)
677}
678
679func (a *agent) processEvent(ctx context.Context, sessionID string, assistantMsg *message.Message, event provider.ProviderEvent) error {
680 select {
681 case <-ctx.Done():
682 return ctx.Err()
683 default:
684 // Continue processing.
685 }
686
687 switch event.Type {
688 case provider.EventThinkingDelta:
689 assistantMsg.AppendReasoningContent(event.Thinking)
690 return a.messages.Update(ctx, *assistantMsg)
691 case provider.EventSignatureDelta:
692 assistantMsg.AppendReasoningSignature(event.Signature)
693 return a.messages.Update(ctx, *assistantMsg)
694 case provider.EventContentDelta:
695 assistantMsg.FinishThinking()
696 assistantMsg.AppendContent(event.Content)
697 return a.messages.Update(ctx, *assistantMsg)
698 case provider.EventToolUseStart:
699 assistantMsg.FinishThinking()
700 slog.Info("Tool call started", "toolCall", event.ToolCall)
701 assistantMsg.AddToolCall(*event.ToolCall)
702 return a.messages.Update(ctx, *assistantMsg)
703 case provider.EventToolUseDelta:
704 assistantMsg.AppendToolCallInput(event.ToolCall.ID, event.ToolCall.Input)
705 return a.messages.Update(ctx, *assistantMsg)
706 case provider.EventToolUseStop:
707 slog.Info("Finished tool call", "toolCall", event.ToolCall)
708 assistantMsg.FinishToolCall(event.ToolCall.ID)
709 return a.messages.Update(ctx, *assistantMsg)
710 case provider.EventError:
711 return event.Error
712 case provider.EventComplete:
713 assistantMsg.FinishThinking()
714 assistantMsg.SetToolCalls(event.Response.ToolCalls)
715 assistantMsg.AddFinish(event.Response.FinishReason, "", "")
716 if err := a.messages.Update(ctx, *assistantMsg); err != nil {
717 return fmt.Errorf("failed to update message: %w", err)
718 }
719 return a.TrackUsage(ctx, sessionID, a.Model(), event.Response.Usage)
720 }
721
722 return nil
723}
724
725func (a *agent) TrackUsage(ctx context.Context, sessionID string, model catwalk.Model, usage provider.TokenUsage) error {
726 sess, err := a.sessions.Get(ctx, sessionID)
727 if err != nil {
728 return fmt.Errorf("failed to get session: %w", err)
729 }
730
731 cost := model.CostPer1MInCached/1e6*float64(usage.CacheCreationTokens) +
732 model.CostPer1MOutCached/1e6*float64(usage.CacheReadTokens) +
733 model.CostPer1MIn/1e6*float64(usage.InputTokens) +
734 model.CostPer1MOut/1e6*float64(usage.OutputTokens)
735
736 sess.Cost += cost
737 sess.CompletionTokens = usage.OutputTokens + usage.CacheReadTokens
738 sess.PromptTokens = usage.InputTokens + usage.CacheCreationTokens
739
740 _, err = a.sessions.Save(ctx, sess)
741 if err != nil {
742 return fmt.Errorf("failed to save session: %w", err)
743 }
744 return nil
745}
746
747func (a *agent) Summarize(ctx context.Context, sessionID string) error {
748 if a.summarizeProvider == nil {
749 return fmt.Errorf("summarize provider not available")
750 }
751
752 // Check if session is busy
753 if a.IsSessionBusy(sessionID) {
754 return ErrSessionBusy
755 }
756
757 // Create a new context with cancellation
758 summarizeCtx, cancel := context.WithCancel(ctx)
759
760 // Store the cancel function in activeRequests to allow cancellation
761 a.activeRequests.Set(sessionID+"-summarize", cancel)
762
763 go func() {
764 defer a.activeRequests.Del(sessionID + "-summarize")
765 defer cancel()
766 event := AgentEvent{
767 Type: AgentEventTypeSummarize,
768 Progress: "Starting summarization...",
769 }
770
771 a.Publish(pubsub.CreatedEvent, event)
772 // Get all messages from the session
773 msgs, err := a.messages.List(summarizeCtx, sessionID)
774 if err != nil {
775 event = AgentEvent{
776 Type: AgentEventTypeError,
777 Error: fmt.Errorf("failed to list messages: %w", err),
778 Done: true,
779 }
780 a.Publish(pubsub.CreatedEvent, event)
781 return
782 }
783 summarizeCtx = context.WithValue(summarizeCtx, tools.SessionIDContextKey, sessionID)
784
785 if len(msgs) == 0 {
786 event = AgentEvent{
787 Type: AgentEventTypeError,
788 Error: fmt.Errorf("no messages to summarize"),
789 Done: true,
790 }
791 a.Publish(pubsub.CreatedEvent, event)
792 return
793 }
794
795 event = AgentEvent{
796 Type: AgentEventTypeSummarize,
797 Progress: "Analyzing conversation...",
798 }
799 a.Publish(pubsub.CreatedEvent, event)
800
801 // Add a system message to guide the summarization
802 summarizePrompt := "Provide a detailed but concise summary of our conversation above. Focus on information that would be helpful for continuing the conversation, including what we did, what we're doing, which files we're working on, and what we're going to do next."
803
804 // Create a new message with the summarize prompt
805 promptMsg := message.Message{
806 Role: message.User,
807 Parts: []message.ContentPart{message.TextContent{Text: summarizePrompt}},
808 }
809
810 // Append the prompt to the messages
811 msgsWithPrompt := append(msgs, promptMsg)
812
813 event = AgentEvent{
814 Type: AgentEventTypeSummarize,
815 Progress: "Generating summary...",
816 }
817
818 a.Publish(pubsub.CreatedEvent, event)
819
820 // Send the messages to the summarize provider
821 response := a.summarizeProvider.StreamResponse(
822 summarizeCtx,
823 msgsWithPrompt,
824 nil,
825 )
826 var finalResponse *provider.ProviderResponse
827 for r := range response {
828 if r.Error != nil {
829 event = AgentEvent{
830 Type: AgentEventTypeError,
831 Error: fmt.Errorf("failed to summarize: %w", err),
832 Done: true,
833 }
834 a.Publish(pubsub.CreatedEvent, event)
835 return
836 }
837 finalResponse = r.Response
838 }
839
840 summary := strings.TrimSpace(finalResponse.Content)
841 if summary == "" {
842 event = AgentEvent{
843 Type: AgentEventTypeError,
844 Error: fmt.Errorf("empty summary returned"),
845 Done: true,
846 }
847 a.Publish(pubsub.CreatedEvent, event)
848 return
849 }
850 shell := shell.GetPersistentShell(config.Get().WorkingDir())
851 summary += "\n\n**Current working directory of the persistent shell**\n\n" + shell.GetWorkingDir()
852 event = AgentEvent{
853 Type: AgentEventTypeSummarize,
854 Progress: "Creating new session...",
855 }
856
857 a.Publish(pubsub.CreatedEvent, event)
858 oldSession, err := a.sessions.Get(summarizeCtx, sessionID)
859 if err != nil {
860 event = AgentEvent{
861 Type: AgentEventTypeError,
862 Error: fmt.Errorf("failed to get session: %w", err),
863 Done: true,
864 }
865
866 a.Publish(pubsub.CreatedEvent, event)
867 return
868 }
869 // Create a message in the new session with the summary
870 msg, err := a.messages.Create(summarizeCtx, oldSession.ID, message.CreateMessageParams{
871 Role: message.Assistant,
872 Parts: []message.ContentPart{
873 message.TextContent{Text: summary},
874 message.Finish{
875 Reason: message.FinishReasonEndTurn,
876 Time: time.Now().Unix(),
877 },
878 },
879 Model: a.summarizeProvider.Model().ID,
880 Provider: a.summarizeProviderID,
881 })
882 if err != nil {
883 event = AgentEvent{
884 Type: AgentEventTypeError,
885 Error: fmt.Errorf("failed to create summary message: %w", err),
886 Done: true,
887 }
888
889 a.Publish(pubsub.CreatedEvent, event)
890 return
891 }
892 oldSession.SummaryMessageID = msg.ID
893 oldSession.CompletionTokens = finalResponse.Usage.OutputTokens
894 oldSession.PromptTokens = 0
895 model := a.summarizeProvider.Model()
896 usage := finalResponse.Usage
897 cost := model.CostPer1MInCached/1e6*float64(usage.CacheCreationTokens) +
898 model.CostPer1MOutCached/1e6*float64(usage.CacheReadTokens) +
899 model.CostPer1MIn/1e6*float64(usage.InputTokens) +
900 model.CostPer1MOut/1e6*float64(usage.OutputTokens)
901 oldSession.Cost += cost
902 _, err = a.sessions.Save(summarizeCtx, oldSession)
903 if err != nil {
904 event = AgentEvent{
905 Type: AgentEventTypeError,
906 Error: fmt.Errorf("failed to save session: %w", err),
907 Done: true,
908 }
909 a.Publish(pubsub.CreatedEvent, event)
910 }
911
912 event = AgentEvent{
913 Type: AgentEventTypeSummarize,
914 SessionID: oldSession.ID,
915 Progress: "Summary complete",
916 Done: true,
917 }
918 a.Publish(pubsub.CreatedEvent, event)
919 // Send final success event with the new session ID
920 }()
921
922 return nil
923}
924
925func (a *agent) ClearQueue(sessionID string) {
926 if a.QueuedPrompts(sessionID) > 0 {
927 slog.Info("Clearing queued prompts", "session_id", sessionID)
928 a.promptQueue.Del(sessionID)
929 }
930}
931
932func (a *agent) CancelAll() {
933 if !a.IsBusy() {
934 return
935 }
936 for key := range a.activeRequests.Seq2() {
937 a.Cancel(key) // key is sessionID
938 }
939
940 timeout := time.After(5 * time.Second)
941 for a.IsBusy() {
942 select {
943 case <-timeout:
944 return
945 default:
946 time.Sleep(200 * time.Millisecond)
947 }
948 }
949}
950
951func (a *agent) UpdateModel() error {
952 cfg := config.Get()
953
954 // Get current provider configuration
955 currentProviderCfg := cfg.GetProviderForModel(a.agentCfg.Model)
956 if currentProviderCfg == nil || currentProviderCfg.ID == "" {
957 return fmt.Errorf("provider for agent %s not found in config", a.agentCfg.Name)
958 }
959
960 // Check if provider has changed
961 if string(currentProviderCfg.ID) != a.providerID {
962 // Provider changed, need to recreate the main provider
963 model := cfg.GetModelByType(a.agentCfg.Model)
964 if model.ID == "" {
965 return fmt.Errorf("model not found for agent %s", a.agentCfg.Name)
966 }
967
968 promptID := agentPromptMap[a.agentCfg.ID]
969 if promptID == "" {
970 promptID = prompt.PromptDefault
971 }
972
973 opts := []provider.ProviderClientOption{
974 provider.WithModel(a.agentCfg.Model),
975 provider.WithSystemMessage(prompt.GetPrompt(promptID, currentProviderCfg.ID, cfg.Options.ContextPaths...)),
976 }
977
978 newProvider, err := provider.NewProvider(*currentProviderCfg, opts...)
979 if err != nil {
980 return fmt.Errorf("failed to create new provider: %w", err)
981 }
982
983 // Update the provider and provider ID
984 a.provider = newProvider
985 a.providerID = string(currentProviderCfg.ID)
986 }
987
988 // Check if providers have changed for title (small) and summarize (large)
989 smallModelCfg := cfg.Models[config.SelectedModelTypeSmall]
990 var smallModelProviderCfg config.ProviderConfig
991 for p := range cfg.Providers.Seq() {
992 if p.ID == smallModelCfg.Provider {
993 smallModelProviderCfg = p
994 break
995 }
996 }
997 if smallModelProviderCfg.ID == "" {
998 return fmt.Errorf("provider %s not found in config", smallModelCfg.Provider)
999 }
1000
1001 largeModelCfg := cfg.Models[config.SelectedModelTypeLarge]
1002 var largeModelProviderCfg config.ProviderConfig
1003 for p := range cfg.Providers.Seq() {
1004 if p.ID == largeModelCfg.Provider {
1005 largeModelProviderCfg = p
1006 break
1007 }
1008 }
1009 if largeModelProviderCfg.ID == "" {
1010 return fmt.Errorf("provider %s not found in config", largeModelCfg.Provider)
1011 }
1012
1013 var maxTitleTokens int64 = 40
1014
1015 // if the max output is too low for the gemini provider it won't return anything
1016 if smallModelCfg.Provider == "gemini" {
1017 maxTitleTokens = 1000
1018 }
1019 // Recreate title provider
1020 titleOpts := []provider.ProviderClientOption{
1021 provider.WithModel(config.SelectedModelTypeSmall),
1022 provider.WithSystemMessage(prompt.GetPrompt(prompt.PromptTitle, smallModelProviderCfg.ID)),
1023 provider.WithMaxTokens(maxTitleTokens),
1024 }
1025 newTitleProvider, err := provider.NewProvider(smallModelProviderCfg, titleOpts...)
1026 if err != nil {
1027 return fmt.Errorf("failed to create new title provider: %w", err)
1028 }
1029 a.titleProvider = newTitleProvider
1030
1031 // Recreate summarize provider if provider changed (now large model)
1032 if string(largeModelProviderCfg.ID) != a.summarizeProviderID {
1033 largeModel := cfg.GetModelByType(config.SelectedModelTypeLarge)
1034 if largeModel == nil {
1035 return fmt.Errorf("model %s not found in provider %s", largeModelCfg.Model, largeModelProviderCfg.ID)
1036 }
1037 summarizeOpts := []provider.ProviderClientOption{
1038 provider.WithModel(config.SelectedModelTypeLarge),
1039 provider.WithSystemMessage(prompt.GetPrompt(prompt.PromptSummarizer, largeModelProviderCfg.ID)),
1040 }
1041 newSummarizeProvider, err := provider.NewProvider(largeModelProviderCfg, summarizeOpts...)
1042 if err != nil {
1043 return fmt.Errorf("failed to create new summarize provider: %w", err)
1044 }
1045 a.summarizeProvider = newSummarizeProvider
1046 a.summarizeProviderID = string(largeModelProviderCfg.ID)
1047 }
1048
1049 return nil
1050}