1package agent
2
3import (
4 "context"
5 "errors"
6 "fmt"
7 "log/slog"
8 "slices"
9 "strings"
10 "time"
11
12 "github.com/charmbracelet/catwalk/pkg/catwalk"
13 "github.com/charmbracelet/crush/internal/config"
14 "github.com/charmbracelet/crush/internal/csync"
15 "github.com/charmbracelet/crush/internal/history"
16 "github.com/charmbracelet/crush/internal/llm/prompt"
17 "github.com/charmbracelet/crush/internal/llm/provider"
18 "github.com/charmbracelet/crush/internal/llm/tools"
19 "github.com/charmbracelet/crush/internal/log"
20 "github.com/charmbracelet/crush/internal/lsp"
21 "github.com/charmbracelet/crush/internal/message"
22 "github.com/charmbracelet/crush/internal/permission"
23 "github.com/charmbracelet/crush/internal/pubsub"
24 "github.com/charmbracelet/crush/internal/session"
25 "github.com/charmbracelet/crush/internal/shell"
26)
27
28// Common errors
29var (
30 ErrRequestCancelled = errors.New("request canceled by user")
31 ErrSessionBusy = errors.New("session is currently processing another request")
32)
33
// AgentEventType tells consumers which fields of an AgentEvent are meaningful.
type AgentEventType string

// AgentEventTypeError marks an event whose failure is in AgentEvent.Error.
const AgentEventTypeError AgentEventType = "error"

// AgentEventTypeResponse marks an event carrying the final assistant message
// in AgentEvent.Message.
const AgentEventTypeResponse AgentEventType = "response"

// AgentEventTypeSummarize marks a summarization progress update; the last one
// has Done set and SessionID populated.
const AgentEventTypeSummarize AgentEventType = "summarize"
41
42type AgentEvent struct {
43 Type AgentEventType
44 Message message.Message
45 Error error
46
47 // When summarizing
48 SessionID string
49 Progress string
50 Done bool
51}
52
53type Service interface {
54 pubsub.Suscriber[AgentEvent]
55 Model() catwalk.Model
56 Run(ctx context.Context, sessionID string, content string, attachments ...message.Attachment) (<-chan AgentEvent, error)
57 Cancel(sessionID string)
58 CancelAll()
59 IsSessionBusy(sessionID string) bool
60 IsBusy() bool
61 Summarize(ctx context.Context, sessionID string) error
62 UpdateModel() error
63 QueuedPrompts(sessionID string) int
64 ClearQueue(sessionID string)
65}
66
67type agent struct {
68 *pubsub.Broker[AgentEvent]
69 agentCfg config.Agent
70 sessions session.Service
71 messages message.Service
72 mcpTools []McpTool
73
74 tools *csync.LazySlice[tools.BaseTool]
75
76 provider provider.Provider
77 providerID string
78
79 titleProvider provider.Provider
80 summarizeProvider provider.Provider
81 summarizeProviderID string
82
83 activeRequests *csync.Map[string, context.CancelFunc]
84
85 promptQueue *csync.Map[string, []string]
86}
87
88var agentPromptMap = map[string]prompt.PromptID{
89 "coder": prompt.PromptCoder,
90 "task": prompt.PromptTask,
91}
92
93func NewAgent(
94 ctx context.Context,
95 agentCfg config.Agent,
96 // These services are needed in the tools
97 permissions permission.Service,
98 sessions session.Service,
99 messages message.Service,
100 history history.Service,
101 lspClients map[string]*lsp.Client,
102) (Service, error) {
103 cfg := config.Get()
104
105 var agentTool tools.BaseTool
106 if agentCfg.ID == "coder" {
107 taskAgentCfg := config.Get().Agents["task"]
108 if taskAgentCfg.ID == "" {
109 return nil, fmt.Errorf("task agent not found in config")
110 }
111 taskAgent, err := NewAgent(ctx, taskAgentCfg, permissions, sessions, messages, history, lspClients)
112 if err != nil {
113 return nil, fmt.Errorf("failed to create task agent: %w", err)
114 }
115
116 agentTool = NewAgentTool(taskAgent, sessions, messages)
117 }
118
119 providerCfg := config.Get().GetProviderForModel(agentCfg.Model)
120 if providerCfg == nil {
121 return nil, fmt.Errorf("provider for agent %s not found in config", agentCfg.Name)
122 }
123 model := config.Get().GetModelByType(agentCfg.Model)
124
125 if model == nil {
126 return nil, fmt.Errorf("model not found for agent %s", agentCfg.Name)
127 }
128
129 promptID := agentPromptMap[agentCfg.ID]
130 if promptID == "" {
131 promptID = prompt.PromptDefault
132 }
133 opts := []provider.ProviderClientOption{
134 provider.WithModel(agentCfg.Model),
135 provider.WithSystemMessage(prompt.GetPrompt(promptID, providerCfg.ID, config.Get().Options.ContextPaths...)),
136 }
137 agentProvider, err := provider.NewProvider(*providerCfg, opts...)
138 if err != nil {
139 return nil, err
140 }
141
142 smallModelCfg := cfg.Models[config.SelectedModelTypeSmall]
143 var smallModelProviderCfg *config.ProviderConfig
144 if smallModelCfg.Provider == providerCfg.ID {
145 smallModelProviderCfg = providerCfg
146 } else {
147 smallModelProviderCfg = cfg.GetProviderForModel(config.SelectedModelTypeSmall)
148
149 if smallModelProviderCfg.ID == "" {
150 return nil, fmt.Errorf("provider %s not found in config", smallModelCfg.Provider)
151 }
152 }
153 smallModel := cfg.GetModelByType(config.SelectedModelTypeSmall)
154 if smallModel.ID == "" {
155 return nil, fmt.Errorf("model %s not found in provider %s", smallModelCfg.Model, smallModelProviderCfg.ID)
156 }
157
158 titleOpts := []provider.ProviderClientOption{
159 provider.WithModel(config.SelectedModelTypeSmall),
160 provider.WithSystemMessage(prompt.GetPrompt(prompt.PromptTitle, smallModelProviderCfg.ID)),
161 }
162 titleProvider, err := provider.NewProvider(*smallModelProviderCfg, titleOpts...)
163 if err != nil {
164 return nil, err
165 }
166
167 summarizeOpts := []provider.ProviderClientOption{
168 provider.WithModel(config.SelectedModelTypeLarge),
169 provider.WithSystemMessage(prompt.GetPrompt(prompt.PromptSummarizer, providerCfg.ID)),
170 }
171 summarizeProvider, err := provider.NewProvider(*providerCfg, summarizeOpts...)
172 if err != nil {
173 return nil, err
174 }
175
176 toolFn := func() []tools.BaseTool {
177 slog.Info("Initializing agent tools", "agent", agentCfg.ID)
178 defer func() {
179 slog.Info("Initialized agent tools", "agent", agentCfg.ID)
180 }()
181
182 cwd := cfg.WorkingDir()
183 allTools := []tools.BaseTool{
184 tools.NewBashTool(permissions, cwd),
185 tools.NewDownloadTool(permissions, cwd),
186 tools.NewEditTool(lspClients, permissions, history, cwd),
187 tools.NewMultiEditTool(lspClients, permissions, history, cwd),
188 tools.NewFetchTool(permissions, cwd),
189 tools.NewGlobTool(cwd),
190 tools.NewGrepTool(cwd),
191 tools.NewLsTool(permissions, cwd),
192 tools.NewSourcegraphTool(),
193 tools.NewViewTool(lspClients, permissions, cwd),
194 tools.NewWriteTool(lspClients, permissions, history, cwd),
195 }
196
197 mcpToolsOnce.Do(func() {
198 mcpTools = doGetMCPTools(ctx, permissions, cfg)
199 })
200 allTools = append(allTools, mcpTools...)
201
202 if len(lspClients) > 0 {
203 allTools = append(allTools, tools.NewDiagnosticsTool(lspClients))
204 }
205
206 if agentTool != nil {
207 allTools = append(allTools, agentTool)
208 }
209
210 if agentCfg.AllowedTools == nil {
211 return allTools
212 }
213
214 var filteredTools []tools.BaseTool
215 for _, tool := range allTools {
216 if slices.Contains(agentCfg.AllowedTools, tool.Name()) {
217 filteredTools = append(filteredTools, tool)
218 }
219 }
220 return filteredTools
221 }
222
223 return &agent{
224 Broker: pubsub.NewBroker[AgentEvent](),
225 agentCfg: agentCfg,
226 provider: agentProvider,
227 providerID: string(providerCfg.ID),
228 messages: messages,
229 sessions: sessions,
230 titleProvider: titleProvider,
231 summarizeProvider: summarizeProvider,
232 summarizeProviderID: string(providerCfg.ID),
233 activeRequests: csync.NewMap[string, context.CancelFunc](),
234 tools: csync.NewLazySlice(toolFn),
235 promptQueue: csync.NewMap[string, []string](),
236 }, nil
237}
238
239func (a *agent) Model() catwalk.Model {
240 return *config.Get().GetModelByType(a.agentCfg.Model)
241}
242
243func (a *agent) Cancel(sessionID string) {
244 // Cancel regular requests
245 if cancel, ok := a.activeRequests.Take(sessionID); ok && cancel != nil {
246 slog.Info("Request cancellation initiated", "session_id", sessionID)
247 cancel()
248 }
249
250 // Also check for summarize requests
251 if cancel, ok := a.activeRequests.Take(sessionID + "-summarize"); ok && cancel != nil {
252 slog.Info("Summarize cancellation initiated", "session_id", sessionID)
253 cancel()
254 }
255
256 if a.QueuedPrompts(sessionID) > 0 {
257 slog.Info("Clearing queued prompts", "session_id", sessionID)
258 a.promptQueue.Del(sessionID)
259 }
260}
261
262func (a *agent) IsBusy() bool {
263 var busy bool
264 for cancelFunc := range a.activeRequests.Seq() {
265 if cancelFunc != nil {
266 busy = true
267 break
268 }
269 }
270 return busy
271}
272
273func (a *agent) IsSessionBusy(sessionID string) bool {
274 _, busy := a.activeRequests.Get(sessionID)
275 return busy
276}
277
278func (a *agent) QueuedPrompts(sessionID string) int {
279 l, ok := a.promptQueue.Get(sessionID)
280 if !ok {
281 return 0
282 }
283 return len(l)
284}
285
286func (a *agent) generateTitle(ctx context.Context, sessionID string, content string) error {
287 if content == "" {
288 return nil
289 }
290 if a.titleProvider == nil {
291 return nil
292 }
293 session, err := a.sessions.Get(ctx, sessionID)
294 if err != nil {
295 return err
296 }
297 parts := []message.ContentPart{message.TextContent{
298 Text: fmt.Sprintf("Generate a concise title for the following content:\n\n%s", content),
299 }}
300
301 // Use streaming approach like summarization
302 response := a.titleProvider.StreamResponse(
303 ctx,
304 []message.Message{
305 {
306 Role: message.User,
307 Parts: parts,
308 },
309 },
310 nil,
311 )
312
313 var finalResponse *provider.ProviderResponse
314 for r := range response {
315 if r.Error != nil {
316 return r.Error
317 }
318 finalResponse = r.Response
319 }
320
321 if finalResponse == nil {
322 return fmt.Errorf("no response received from title provider")
323 }
324
325 title := strings.TrimSpace(strings.ReplaceAll(finalResponse.Content, "\n", " "))
326 if title == "" {
327 return nil
328 }
329
330 session.Title = title
331 _, err = a.sessions.Save(ctx, session)
332 return err
333}
334
335func (a *agent) err(err error) AgentEvent {
336 return AgentEvent{
337 Type: AgentEventTypeError,
338 Error: err,
339 }
340}
341
342func (a *agent) Run(ctx context.Context, sessionID string, content string, attachments ...message.Attachment) (<-chan AgentEvent, error) {
343 if !a.Model().SupportsImages && attachments != nil {
344 attachments = nil
345 }
346 events := make(chan AgentEvent)
347 if a.IsSessionBusy(sessionID) {
348 existing, ok := a.promptQueue.Get(sessionID)
349 if !ok {
350 existing = []string{}
351 }
352 existing = append(existing, content)
353 a.promptQueue.Set(sessionID, existing)
354 return nil, nil
355 }
356
357 genCtx, cancel := context.WithCancel(ctx)
358
359 a.activeRequests.Set(sessionID, cancel)
360 go func() {
361 slog.Debug("Request started", "sessionID", sessionID)
362 defer log.RecoverPanic("agent.Run", func() {
363 events <- a.err(fmt.Errorf("panic while running the agent"))
364 })
365 var attachmentParts []message.ContentPart
366 for _, attachment := range attachments {
367 attachmentParts = append(attachmentParts, message.BinaryContent{Path: attachment.FilePath, MIMEType: attachment.MimeType, Data: attachment.Content})
368 }
369 result := a.processGeneration(genCtx, sessionID, content, attachmentParts)
370 if result.Error != nil && !errors.Is(result.Error, ErrRequestCancelled) && !errors.Is(result.Error, context.Canceled) {
371 slog.Error(result.Error.Error())
372 }
373 slog.Debug("Request completed", "sessionID", sessionID)
374 a.activeRequests.Del(sessionID)
375 cancel()
376 a.Publish(pubsub.CreatedEvent, result)
377 events <- result
378 close(events)
379 }()
380 return events, nil
381}
382
383func (a *agent) processGeneration(ctx context.Context, sessionID, content string, attachmentParts []message.ContentPart) AgentEvent {
384 cfg := config.Get()
385 // List existing messages; if none, start title generation asynchronously.
386 msgs, err := a.messages.List(ctx, sessionID)
387 if err != nil {
388 return a.err(fmt.Errorf("failed to list messages: %w", err))
389 }
390
391 if len(msgs) == 0 {
392 // Use a context with timeout for title generation
393 titleCtx, titleCancel := context.WithTimeout(context.Background(), 30*time.Second)
394 go func() {
395 defer titleCancel()
396 defer log.RecoverPanic("agent.Run", func() {
397 slog.Error("panic while generating title")
398 })
399 titleErr := a.generateTitle(titleCtx, sessionID, content)
400 if titleErr != nil && !errors.Is(titleErr, context.Canceled) && !errors.Is(titleErr, context.DeadlineExceeded) {
401 slog.Error("failed to generate title", "error", titleErr)
402 }
403 }()
404 }
405 session, err := a.sessions.Get(ctx, sessionID)
406 if err != nil {
407 return a.err(fmt.Errorf("failed to get session: %w", err))
408 }
409 if session.SummaryMessageID != "" {
410 summaryMsgInex := -1
411 for i, msg := range msgs {
412 if msg.ID == session.SummaryMessageID {
413 summaryMsgInex = i
414 break
415 }
416 }
417 if summaryMsgInex != -1 {
418 msgs = msgs[summaryMsgInex:]
419 msgs[0].Role = message.User
420 }
421 }
422
423 userMsg, err := a.createUserMessage(ctx, sessionID, content, attachmentParts)
424 if err != nil {
425 return a.err(fmt.Errorf("failed to create user message: %w", err))
426 }
427 // Append the new user message to the conversation history.
428 msgHistory := append(msgs, userMsg)
429
430 for {
431 // Check for cancellation before each iteration
432 select {
433 case <-ctx.Done():
434 return a.err(ctx.Err())
435 default:
436 // Continue processing
437 }
438 agentMessage, toolResults, err := a.streamAndHandleEvents(ctx, sessionID, msgHistory)
439 if err != nil {
440 if errors.Is(err, context.Canceled) {
441 agentMessage.AddFinish(message.FinishReasonCanceled, "Request cancelled", "")
442 a.messages.Update(context.Background(), agentMessage)
443 return a.err(ErrRequestCancelled)
444 }
445 return a.err(fmt.Errorf("failed to process events: %w", err))
446 }
447 if cfg.Options.Debug {
448 slog.Info("Result", "message", agentMessage.FinishReason(), "toolResults", toolResults)
449 }
450 if (agentMessage.FinishReason() == message.FinishReasonToolUse) && toolResults != nil {
451 // We are not done, we need to respond with the tool response
452 msgHistory = append(msgHistory, agentMessage, *toolResults)
453 // If there are queued prompts, process the next one
454 nextPrompt, ok := a.promptQueue.Take(sessionID)
455 if ok {
456 for _, prompt := range nextPrompt {
457 // Create a new user message for the queued prompt
458 userMsg, err := a.createUserMessage(ctx, sessionID, prompt, nil)
459 if err != nil {
460 return a.err(fmt.Errorf("failed to create user message for queued prompt: %w", err))
461 }
462 // Append the new user message to the conversation history
463 msgHistory = append(msgHistory, userMsg)
464 }
465 }
466
467 continue
468 } else if agentMessage.FinishReason() == message.FinishReasonEndTurn {
469 queuePrompts, ok := a.promptQueue.Take(sessionID)
470 if ok {
471 for _, prompt := range queuePrompts {
472 if prompt == "" {
473 continue
474 }
475 userMsg, err := a.createUserMessage(ctx, sessionID, prompt, nil)
476 if err != nil {
477 return a.err(fmt.Errorf("failed to create user message for queued prompt: %w", err))
478 }
479 msgHistory = append(msgHistory, userMsg)
480 }
481 continue
482 }
483 }
484 if agentMessage.FinishReason() == "" {
485 // Kujtim: could not track down where this is happening but this means its cancelled
486 agentMessage.AddFinish(message.FinishReasonCanceled, "Request cancelled", "")
487 _ = a.messages.Update(context.Background(), agentMessage)
488 return a.err(ErrRequestCancelled)
489 }
490 return AgentEvent{
491 Type: AgentEventTypeResponse,
492 Message: agentMessage,
493 Done: true,
494 }
495 }
496}
497
498func (a *agent) createUserMessage(ctx context.Context, sessionID, content string, attachmentParts []message.ContentPart) (message.Message, error) {
499 parts := []message.ContentPart{message.TextContent{Text: content}}
500 parts = append(parts, attachmentParts...)
501 return a.messages.Create(ctx, sessionID, message.CreateMessageParams{
502 Role: message.User,
503 Parts: parts,
504 })
505}
506
507func (a *agent) streamAndHandleEvents(ctx context.Context, sessionID string, msgHistory []message.Message) (message.Message, *message.Message, error) {
508 ctx = context.WithValue(ctx, tools.SessionIDContextKey, sessionID)
509
510 // Create the assistant message first so the spinner shows immediately
511 assistantMsg, err := a.messages.Create(ctx, sessionID, message.CreateMessageParams{
512 Role: message.Assistant,
513 Parts: []message.ContentPart{},
514 Model: a.Model().ID,
515 Provider: a.providerID,
516 })
517 if err != nil {
518 return assistantMsg, nil, fmt.Errorf("failed to create assistant message: %w", err)
519 }
520
521 // Now collect tools (which may block on MCP initialization)
522 eventChan := a.provider.StreamResponse(ctx, msgHistory, slices.Collect(a.tools.Seq()))
523
524 // Add the session and message ID into the context if needed by tools.
525 ctx = context.WithValue(ctx, tools.MessageIDContextKey, assistantMsg.ID)
526
527 // Process each event in the stream.
528 for event := range eventChan {
529 if processErr := a.processEvent(ctx, sessionID, &assistantMsg, event); processErr != nil {
530 if errors.Is(processErr, context.Canceled) {
531 a.finishMessage(context.Background(), &assistantMsg, message.FinishReasonCanceled, "Request cancelled", "")
532 } else {
533 a.finishMessage(ctx, &assistantMsg, message.FinishReasonError, "API Error", processErr.Error())
534 }
535 return assistantMsg, nil, processErr
536 }
537 if ctx.Err() != nil {
538 a.finishMessage(context.Background(), &assistantMsg, message.FinishReasonCanceled, "Request cancelled", "")
539 return assistantMsg, nil, ctx.Err()
540 }
541 }
542
543 toolResults := make([]message.ToolResult, len(assistantMsg.ToolCalls()))
544 toolCalls := assistantMsg.ToolCalls()
545 for i, toolCall := range toolCalls {
546 select {
547 case <-ctx.Done():
548 a.finishMessage(context.Background(), &assistantMsg, message.FinishReasonCanceled, "Request cancelled", "")
549 // Make all future tool calls cancelled
550 for j := i; j < len(toolCalls); j++ {
551 toolResults[j] = message.ToolResult{
552 ToolCallID: toolCalls[j].ID,
553 Content: "Tool execution canceled by user",
554 IsError: true,
555 }
556 }
557 goto out
558 default:
559 // Continue processing
560 var tool tools.BaseTool
561 for availableTool := range a.tools.Seq() {
562 if availableTool.Info().Name == toolCall.Name {
563 tool = availableTool
564 break
565 }
566 }
567
568 // Tool not found
569 if tool == nil {
570 toolResults[i] = message.ToolResult{
571 ToolCallID: toolCall.ID,
572 Content: fmt.Sprintf("Tool not found: %s", toolCall.Name),
573 IsError: true,
574 }
575 continue
576 }
577
578 // Run tool in goroutine to allow cancellation
579 type toolExecResult struct {
580 response tools.ToolResponse
581 err error
582 }
583 resultChan := make(chan toolExecResult, 1)
584
585 go func() {
586 response, err := tool.Run(ctx, tools.ToolCall{
587 ID: toolCall.ID,
588 Name: toolCall.Name,
589 Input: toolCall.Input,
590 })
591 resultChan <- toolExecResult{response: response, err: err}
592 }()
593
594 var toolResponse tools.ToolResponse
595 var toolErr error
596
597 select {
598 case <-ctx.Done():
599 a.finishMessage(context.Background(), &assistantMsg, message.FinishReasonCanceled, "Request cancelled", "")
600 // Mark remaining tool calls as cancelled
601 for j := i; j < len(toolCalls); j++ {
602 toolResults[j] = message.ToolResult{
603 ToolCallID: toolCalls[j].ID,
604 Content: "Tool execution canceled by user",
605 IsError: true,
606 }
607 }
608 goto out
609 case result := <-resultChan:
610 toolResponse = result.response
611 toolErr = result.err
612 }
613
614 if toolErr != nil {
615 slog.Error("Tool execution error", "toolCall", toolCall.ID, "error", toolErr)
616 if errors.Is(toolErr, permission.ErrorPermissionDenied) {
617 toolResults[i] = message.ToolResult{
618 ToolCallID: toolCall.ID,
619 Content: "Permission denied",
620 IsError: true,
621 }
622 for j := i + 1; j < len(toolCalls); j++ {
623 toolResults[j] = message.ToolResult{
624 ToolCallID: toolCalls[j].ID,
625 Content: "Tool execution canceled by user",
626 IsError: true,
627 }
628 }
629 a.finishMessage(ctx, &assistantMsg, message.FinishReasonPermissionDenied, "Permission denied", "")
630 break
631 }
632 }
633 toolResults[i] = message.ToolResult{
634 ToolCallID: toolCall.ID,
635 Content: toolResponse.Content,
636 Metadata: toolResponse.Metadata,
637 IsError: toolResponse.IsError,
638 }
639 }
640 }
641out:
642 if len(toolResults) == 0 {
643 return assistantMsg, nil, nil
644 }
645 parts := make([]message.ContentPart, 0)
646 for _, tr := range toolResults {
647 parts = append(parts, tr)
648 }
649 msg, err := a.messages.Create(context.Background(), assistantMsg.SessionID, message.CreateMessageParams{
650 Role: message.Tool,
651 Parts: parts,
652 Provider: a.providerID,
653 })
654 if err != nil {
655 return assistantMsg, nil, fmt.Errorf("failed to create cancelled tool message: %w", err)
656 }
657
658 return assistantMsg, &msg, err
659}
660
661func (a *agent) finishMessage(ctx context.Context, msg *message.Message, finishReason message.FinishReason, message, details string) {
662 msg.AddFinish(finishReason, message, details)
663 _ = a.messages.Update(ctx, *msg)
664}
665
666func (a *agent) processEvent(ctx context.Context, sessionID string, assistantMsg *message.Message, event provider.ProviderEvent) error {
667 select {
668 case <-ctx.Done():
669 return ctx.Err()
670 default:
671 // Continue processing.
672 }
673
674 switch event.Type {
675 case provider.EventThinkingDelta:
676 assistantMsg.AppendReasoningContent(event.Thinking)
677 return a.messages.Update(ctx, *assistantMsg)
678 case provider.EventSignatureDelta:
679 assistantMsg.AppendReasoningSignature(event.Signature)
680 return a.messages.Update(ctx, *assistantMsg)
681 case provider.EventContentDelta:
682 assistantMsg.FinishThinking()
683 assistantMsg.AppendContent(event.Content)
684 return a.messages.Update(ctx, *assistantMsg)
685 case provider.EventToolUseStart:
686 assistantMsg.FinishThinking()
687 slog.Info("Tool call started", "toolCall", event.ToolCall)
688 assistantMsg.AddToolCall(*event.ToolCall)
689 return a.messages.Update(ctx, *assistantMsg)
690 case provider.EventToolUseDelta:
691 assistantMsg.AppendToolCallInput(event.ToolCall.ID, event.ToolCall.Input)
692 return a.messages.Update(ctx, *assistantMsg)
693 case provider.EventToolUseStop:
694 slog.Info("Finished tool call", "toolCall", event.ToolCall)
695 assistantMsg.FinishToolCall(event.ToolCall.ID)
696 return a.messages.Update(ctx, *assistantMsg)
697 case provider.EventError:
698 return event.Error
699 case provider.EventComplete:
700 assistantMsg.FinishThinking()
701 assistantMsg.SetToolCalls(event.Response.ToolCalls)
702 assistantMsg.AddFinish(event.Response.FinishReason, "", "")
703 if err := a.messages.Update(ctx, *assistantMsg); err != nil {
704 return fmt.Errorf("failed to update message: %w", err)
705 }
706 return a.TrackUsage(ctx, sessionID, a.Model(), event.Response.Usage)
707 }
708
709 return nil
710}
711
712func (a *agent) TrackUsage(ctx context.Context, sessionID string, model catwalk.Model, usage provider.TokenUsage) error {
713 sess, err := a.sessions.Get(ctx, sessionID)
714 if err != nil {
715 return fmt.Errorf("failed to get session: %w", err)
716 }
717
718 cost := model.CostPer1MInCached/1e6*float64(usage.CacheCreationTokens) +
719 model.CostPer1MOutCached/1e6*float64(usage.CacheReadTokens) +
720 model.CostPer1MIn/1e6*float64(usage.InputTokens) +
721 model.CostPer1MOut/1e6*float64(usage.OutputTokens)
722
723 sess.Cost += cost
724 sess.CompletionTokens = usage.OutputTokens + usage.CacheReadTokens
725 sess.PromptTokens = usage.InputTokens + usage.CacheCreationTokens
726
727 _, err = a.sessions.Save(ctx, sess)
728 if err != nil {
729 return fmt.Errorf("failed to save session: %w", err)
730 }
731 return nil
732}
733
734func (a *agent) Summarize(ctx context.Context, sessionID string) error {
735 if a.summarizeProvider == nil {
736 return fmt.Errorf("summarize provider not available")
737 }
738
739 // Check if session is busy
740 if a.IsSessionBusy(sessionID) {
741 return ErrSessionBusy
742 }
743
744 // Create a new context with cancellation
745 summarizeCtx, cancel := context.WithCancel(ctx)
746
747 // Store the cancel function in activeRequests to allow cancellation
748 a.activeRequests.Set(sessionID+"-summarize", cancel)
749
750 go func() {
751 defer a.activeRequests.Del(sessionID + "-summarize")
752 defer cancel()
753 event := AgentEvent{
754 Type: AgentEventTypeSummarize,
755 Progress: "Starting summarization...",
756 }
757
758 a.Publish(pubsub.CreatedEvent, event)
759 // Get all messages from the session
760 msgs, err := a.messages.List(summarizeCtx, sessionID)
761 if err != nil {
762 event = AgentEvent{
763 Type: AgentEventTypeError,
764 Error: fmt.Errorf("failed to list messages: %w", err),
765 Done: true,
766 }
767 a.Publish(pubsub.CreatedEvent, event)
768 return
769 }
770 summarizeCtx = context.WithValue(summarizeCtx, tools.SessionIDContextKey, sessionID)
771
772 if len(msgs) == 0 {
773 event = AgentEvent{
774 Type: AgentEventTypeError,
775 Error: fmt.Errorf("no messages to summarize"),
776 Done: true,
777 }
778 a.Publish(pubsub.CreatedEvent, event)
779 return
780 }
781
782 event = AgentEvent{
783 Type: AgentEventTypeSummarize,
784 Progress: "Analyzing conversation...",
785 }
786 a.Publish(pubsub.CreatedEvent, event)
787
788 // Add a system message to guide the summarization
789 summarizePrompt := "Provide a detailed but concise summary of our conversation above. Focus on information that would be helpful for continuing the conversation, including what we did, what we're doing, which files we're working on, and what we're going to do next."
790
791 // Create a new message with the summarize prompt
792 promptMsg := message.Message{
793 Role: message.User,
794 Parts: []message.ContentPart{message.TextContent{Text: summarizePrompt}},
795 }
796
797 // Append the prompt to the messages
798 msgsWithPrompt := append(msgs, promptMsg)
799
800 event = AgentEvent{
801 Type: AgentEventTypeSummarize,
802 Progress: "Generating summary...",
803 }
804
805 a.Publish(pubsub.CreatedEvent, event)
806
807 // Send the messages to the summarize provider
808 response := a.summarizeProvider.StreamResponse(
809 summarizeCtx,
810 msgsWithPrompt,
811 nil,
812 )
813 var finalResponse *provider.ProviderResponse
814 for r := range response {
815 if r.Error != nil {
816 event = AgentEvent{
817 Type: AgentEventTypeError,
818 Error: fmt.Errorf("failed to summarize: %w", err),
819 Done: true,
820 }
821 a.Publish(pubsub.CreatedEvent, event)
822 return
823 }
824 finalResponse = r.Response
825 }
826
827 summary := strings.TrimSpace(finalResponse.Content)
828 if summary == "" {
829 event = AgentEvent{
830 Type: AgentEventTypeError,
831 Error: fmt.Errorf("empty summary returned"),
832 Done: true,
833 }
834 a.Publish(pubsub.CreatedEvent, event)
835 return
836 }
837 shell := shell.GetPersistentShell(config.Get().WorkingDir())
838 summary += "\n\n**Current working directory of the persistent shell**\n\n" + shell.GetWorkingDir()
839 event = AgentEvent{
840 Type: AgentEventTypeSummarize,
841 Progress: "Creating new session...",
842 }
843
844 a.Publish(pubsub.CreatedEvent, event)
845 oldSession, err := a.sessions.Get(summarizeCtx, sessionID)
846 if err != nil {
847 event = AgentEvent{
848 Type: AgentEventTypeError,
849 Error: fmt.Errorf("failed to get session: %w", err),
850 Done: true,
851 }
852
853 a.Publish(pubsub.CreatedEvent, event)
854 return
855 }
856 // Create a message in the new session with the summary
857 msg, err := a.messages.Create(summarizeCtx, oldSession.ID, message.CreateMessageParams{
858 Role: message.Assistant,
859 Parts: []message.ContentPart{
860 message.TextContent{Text: summary},
861 message.Finish{
862 Reason: message.FinishReasonEndTurn,
863 Time: time.Now().Unix(),
864 },
865 },
866 Model: a.summarizeProvider.Model().ID,
867 Provider: a.summarizeProviderID,
868 })
869 if err != nil {
870 event = AgentEvent{
871 Type: AgentEventTypeError,
872 Error: fmt.Errorf("failed to create summary message: %w", err),
873 Done: true,
874 }
875
876 a.Publish(pubsub.CreatedEvent, event)
877 return
878 }
879 oldSession.SummaryMessageID = msg.ID
880 oldSession.CompletionTokens = finalResponse.Usage.OutputTokens
881 oldSession.PromptTokens = 0
882 model := a.summarizeProvider.Model()
883 usage := finalResponse.Usage
884 cost := model.CostPer1MInCached/1e6*float64(usage.CacheCreationTokens) +
885 model.CostPer1MOutCached/1e6*float64(usage.CacheReadTokens) +
886 model.CostPer1MIn/1e6*float64(usage.InputTokens) +
887 model.CostPer1MOut/1e6*float64(usage.OutputTokens)
888 oldSession.Cost += cost
889 _, err = a.sessions.Save(summarizeCtx, oldSession)
890 if err != nil {
891 event = AgentEvent{
892 Type: AgentEventTypeError,
893 Error: fmt.Errorf("failed to save session: %w", err),
894 Done: true,
895 }
896 a.Publish(pubsub.CreatedEvent, event)
897 }
898
899 event = AgentEvent{
900 Type: AgentEventTypeSummarize,
901 SessionID: oldSession.ID,
902 Progress: "Summary complete",
903 Done: true,
904 }
905 a.Publish(pubsub.CreatedEvent, event)
906 // Send final success event with the new session ID
907 }()
908
909 return nil
910}
911
912func (a *agent) ClearQueue(sessionID string) {
913 if a.QueuedPrompts(sessionID) > 0 {
914 slog.Info("Clearing queued prompts", "session_id", sessionID)
915 a.promptQueue.Del(sessionID)
916 }
917}
918
919func (a *agent) CancelAll() {
920 if !a.IsBusy() {
921 return
922 }
923 for key := range a.activeRequests.Seq2() {
924 a.Cancel(key) // key is sessionID
925 }
926
927 timeout := time.After(5 * time.Second)
928 for a.IsBusy() {
929 select {
930 case <-timeout:
931 return
932 default:
933 time.Sleep(200 * time.Millisecond)
934 }
935 }
936}
937
938func (a *agent) UpdateModel() error {
939 cfg := config.Get()
940
941 // Get current provider configuration
942 currentProviderCfg := cfg.GetProviderForModel(a.agentCfg.Model)
943 if currentProviderCfg == nil || currentProviderCfg.ID == "" {
944 return fmt.Errorf("provider for agent %s not found in config", a.agentCfg.Name)
945 }
946
947 // Check if provider has changed
948 if string(currentProviderCfg.ID) != a.providerID {
949 // Provider changed, need to recreate the main provider
950 model := cfg.GetModelByType(a.agentCfg.Model)
951 if model.ID == "" {
952 return fmt.Errorf("model not found for agent %s", a.agentCfg.Name)
953 }
954
955 promptID := agentPromptMap[a.agentCfg.ID]
956 if promptID == "" {
957 promptID = prompt.PromptDefault
958 }
959
960 opts := []provider.ProviderClientOption{
961 provider.WithModel(a.agentCfg.Model),
962 provider.WithSystemMessage(prompt.GetPrompt(promptID, currentProviderCfg.ID, cfg.Options.ContextPaths...)),
963 }
964
965 newProvider, err := provider.NewProvider(*currentProviderCfg, opts...)
966 if err != nil {
967 return fmt.Errorf("failed to create new provider: %w", err)
968 }
969
970 // Update the provider and provider ID
971 a.provider = newProvider
972 a.providerID = string(currentProviderCfg.ID)
973 }
974
975 // Check if providers have changed for title (small) and summarize (large)
976 smallModelCfg := cfg.Models[config.SelectedModelTypeSmall]
977 var smallModelProviderCfg config.ProviderConfig
978 for p := range cfg.Providers.Seq() {
979 if p.ID == smallModelCfg.Provider {
980 smallModelProviderCfg = p
981 break
982 }
983 }
984 if smallModelProviderCfg.ID == "" {
985 return fmt.Errorf("provider %s not found in config", smallModelCfg.Provider)
986 }
987
988 largeModelCfg := cfg.Models[config.SelectedModelTypeLarge]
989 var largeModelProviderCfg config.ProviderConfig
990 for p := range cfg.Providers.Seq() {
991 if p.ID == largeModelCfg.Provider {
992 largeModelProviderCfg = p
993 break
994 }
995 }
996 if largeModelProviderCfg.ID == "" {
997 return fmt.Errorf("provider %s not found in config", largeModelCfg.Provider)
998 }
999
1000 // Recreate title provider
1001 titleOpts := []provider.ProviderClientOption{
1002 provider.WithModel(config.SelectedModelTypeSmall),
1003 provider.WithSystemMessage(prompt.GetPrompt(prompt.PromptTitle, smallModelProviderCfg.ID)),
1004 provider.WithMaxTokens(40),
1005 }
1006 newTitleProvider, err := provider.NewProvider(smallModelProviderCfg, titleOpts...)
1007 if err != nil {
1008 return fmt.Errorf("failed to create new title provider: %w", err)
1009 }
1010 a.titleProvider = newTitleProvider
1011
1012 // Recreate summarize provider if provider changed (now large model)
1013 if string(largeModelProviderCfg.ID) != a.summarizeProviderID {
1014 largeModel := cfg.GetModelByType(config.SelectedModelTypeLarge)
1015 if largeModel == nil {
1016 return fmt.Errorf("model %s not found in provider %s", largeModelCfg.Model, largeModelProviderCfg.ID)
1017 }
1018 summarizeOpts := []provider.ProviderClientOption{
1019 provider.WithModel(config.SelectedModelTypeLarge),
1020 provider.WithSystemMessage(prompt.GetPrompt(prompt.PromptSummarizer, largeModelProviderCfg.ID)),
1021 }
1022 newSummarizeProvider, err := provider.NewProvider(largeModelProviderCfg, summarizeOpts...)
1023 if err != nil {
1024 return fmt.Errorf("failed to create new summarize provider: %w", err)
1025 }
1026 a.summarizeProvider = newSummarizeProvider
1027 a.summarizeProviderID = string(largeModelProviderCfg.ID)
1028 }
1029
1030 return nil
1031}