1package agent
2
3import (
4 "context"
5 "errors"
6 "fmt"
7 "log/slog"
8 "slices"
9 "strings"
10 "time"
11
12 "github.com/charmbracelet/catwalk/pkg/catwalk"
13 "github.com/charmbracelet/crush/internal/config"
14 "github.com/charmbracelet/crush/internal/csync"
15 "github.com/charmbracelet/crush/internal/event"
16 "github.com/charmbracelet/crush/internal/history"
17 "github.com/charmbracelet/crush/internal/llm/prompt"
18 "github.com/charmbracelet/crush/internal/llm/provider"
19 "github.com/charmbracelet/crush/internal/llm/tools"
20 "github.com/charmbracelet/crush/internal/log"
21 "github.com/charmbracelet/crush/internal/lsp"
22 "github.com/charmbracelet/crush/internal/message"
23 "github.com/charmbracelet/crush/internal/permission"
24 "github.com/charmbracelet/crush/internal/proto"
25 "github.com/charmbracelet/crush/internal/pubsub"
26 "github.com/charmbracelet/crush/internal/session"
27 "github.com/charmbracelet/crush/internal/shell"
28)
29
// Aliases that make the agent event types declared in the proto package
// available under this package's namespace.
type (
	AgentEventType = proto.AgentEventType
	AgentEvent     = proto.AgentEvent
)
34
// Agent event type values, re-exported from the proto package.
const (
	AgentEventTypeError     = proto.AgentEventTypeError
	AgentEventTypeResponse  = proto.AgentEventTypeResponse
	AgentEventTypeSummarize = proto.AgentEventTypeSummarize
)
40
// Service is the public surface of an agent: it runs prompts against a
// session, reports whether work is in flight, and lets callers cancel,
// summarize, and manage queued prompts. Events are delivered through the
// embedded pubsub subscription interface.
type Service interface {
	// Subscription to AgentEvent values (identifier spelled as declared in
	// the pubsub package).
	pubsub.Suscriber[AgentEvent]
	// Model returns the model the agent is configured to use.
	Model() catwalk.Model
	// Run submits a prompt (with optional attachments) for the session and
	// returns a channel that receives the result. The agent implementation
	// in this file queues the prompt and returns (nil, nil) when the session
	// is already busy.
	Run(ctx context.Context, sessionID string, content string, attachments ...message.Attachment) (<-chan AgentEvent, error)
	// Cancel stops in-flight work for the session.
	Cancel(sessionID string)
	// CancelAll stops in-flight work for every session.
	CancelAll()
	// IsSessionBusy reports whether the session has an active request.
	IsSessionBusy(sessionID string) bool
	// IsBusy reports whether any request is active.
	IsBusy() bool
	// Summarize starts summarization of the session's conversation.
	Summarize(ctx context.Context, sessionID string) error
	// UpdateModel refreshes the providers after a configuration change.
	UpdateModel() error
	// QueuedPrompts returns the number of prompts waiting for the session.
	QueuedPrompts(sessionID string) int
	// ClearQueue discards the prompts waiting for the session.
	ClearQueue(sessionID string)
}
54
// agent is the Service implementation. It embeds a pubsub broker used to
// publish AgentEvent values and holds the providers, services and per-session
// bookkeeping needed to run generations.
type agent struct {
	*pubsub.Broker[AgentEvent]
	agentCfg    config.Agent       // configuration of this agent (ID, model type, allowed tools)
	sessions    session.Service    // session lookup and persistence
	messages    message.Service    // message lookup and persistence
	permissions permission.Service // permission service handed to the tools
	mcpTools    []McpTool          // NOTE(review): not populated by NewAgent in this file — confirm it is still used
	cfg         *config.Config

	// Lazily initialized tool set built by the closure created in NewAgent.
	tools *csync.LazySlice[tools.BaseTool]
	// We need this to be able to update it when model changes
	agentToolFn func() (tools.BaseTool, error)

	provider   provider.Provider // provider used for the main generation loop
	providerID string            // ID of the provider config behind provider

	titleProvider       provider.Provider // provider used to generate session titles
	summarizeProvider   provider.Provider // provider used to summarize conversations
	summarizeProviderID string            // ID of the provider config behind summarizeProvider

	// Cancel functions of in-flight requests, keyed by session ID; summarize
	// requests are stored under sessionID + "-summarize".
	activeRequests *csync.Map[string, context.CancelFunc]
	// Prompts received while a session was busy, keyed by session ID.
	promptQueue *csync.Map[string, []string]
}
78
// agentPromptMap maps an agent ID to the system prompt it uses. Agent IDs that
// are not listed fall back to prompt.PromptDefault at the lookup sites.
var agentPromptMap = map[string]prompt.PromptID{
	"coder": prompt.PromptCoder,
	"task":  prompt.PromptTask,
}
83
84func NewAgent(
85 ctx context.Context,
86 cfg *config.Config,
87 agentCfg config.Agent,
88 // These services are needed in the tools
89 permissions permission.Service,
90 sessions session.Service,
91 messages message.Service,
92 history history.Service,
93 lspClients *csync.Map[string, *lsp.Client],
94) (Service, error) {
95 var agentToolFn func() (tools.BaseTool, error)
96 if agentCfg.ID == "coder" && slices.Contains(agentCfg.AllowedTools, AgentToolName) {
97 agentToolFn = func() (tools.BaseTool, error) {
98 taskAgentCfg := cfg.Agents["task"]
99 if taskAgentCfg.ID == "" {
100 return nil, fmt.Errorf("task agent not found in config")
101 }
102 taskAgent, err := NewAgent(ctx, cfg, taskAgentCfg, permissions, sessions, messages, history, lspClients)
103 if err != nil {
104 return nil, fmt.Errorf("failed to create task agent: %w", err)
105 }
106 return NewAgentTool(taskAgent, sessions, messages), nil
107 }
108 }
109
110 providerCfg := cfg.GetProviderForModel(agentCfg.Model)
111 if providerCfg == nil {
112 return nil, fmt.Errorf("provider for agent %s not found in config", agentCfg.Name)
113 }
114 model := cfg.GetModelByType(agentCfg.Model)
115
116 if model == nil {
117 return nil, fmt.Errorf("model not found for agent %s", agentCfg.Name)
118 }
119
120 promptID := agentPromptMap[agentCfg.ID]
121 if promptID == "" {
122 promptID = prompt.PromptDefault
123 }
124 opts := []provider.ProviderClientOption{
125 provider.WithModel(agentCfg.Model),
126 provider.WithSystemMessage(prompt.GetPrompt(cfg, promptID, providerCfg.ID, cfg.Options.ContextPaths...)),
127 }
128 agentProvider, err := provider.NewProvider(cfg, *providerCfg, opts...)
129 if err != nil {
130 return nil, err
131 }
132
133 smallModelCfg := cfg.Models[config.SelectedModelTypeSmall]
134 var smallModelProviderCfg *config.ProviderConfig
135 if smallModelCfg.Provider == providerCfg.ID {
136 smallModelProviderCfg = providerCfg
137 } else {
138 smallModelProviderCfg = cfg.GetProviderForModel(config.SelectedModelTypeSmall)
139
140 if smallModelProviderCfg.ID == "" {
141 return nil, fmt.Errorf("provider %s not found in config", smallModelCfg.Provider)
142 }
143 }
144 smallModel := cfg.GetModelByType(config.SelectedModelTypeSmall)
145 if smallModel.ID == "" {
146 return nil, fmt.Errorf("model %s not found in provider %s", smallModelCfg.Model, smallModelProviderCfg.ID)
147 }
148
149 titleOpts := []provider.ProviderClientOption{
150 provider.WithModel(config.SelectedModelTypeSmall),
151 provider.WithSystemMessage(prompt.GetPrompt(cfg, prompt.PromptTitle, smallModelProviderCfg.ID)),
152 }
153 titleProvider, err := provider.NewProvider(cfg, *smallModelProviderCfg, titleOpts...)
154 if err != nil {
155 return nil, err
156 }
157
158 summarizeOpts := []provider.ProviderClientOption{
159 provider.WithModel(config.SelectedModelTypeLarge),
160 provider.WithSystemMessage(prompt.GetPrompt(cfg, prompt.PromptSummarizer, providerCfg.ID)),
161 }
162 summarizeProvider, err := provider.NewProvider(cfg, *providerCfg, summarizeOpts...)
163 if err != nil {
164 return nil, err
165 }
166
167 toolFn := func() []tools.BaseTool {
168 slog.Info("Initializing agent tools", "agent", agentCfg.ID)
169 defer func() {
170 slog.Info("Initialized agent tools", "agent", agentCfg.ID)
171 }()
172
173 cwd := cfg.WorkingDir()
174 allTools := []tools.BaseTool{
175 tools.NewBashTool(permissions, cwd, cfg.Options.Attribution),
176 tools.NewDownloadTool(permissions, cwd),
177 tools.NewEditTool(lspClients, permissions, history, cwd),
178 tools.NewMultiEditTool(lspClients, permissions, history, cwd),
179 tools.NewFetchTool(permissions, cwd),
180 tools.NewGlobTool(cwd),
181 tools.NewGrepTool(cwd),
182 tools.NewLsTool(permissions, cwd),
183 tools.NewSourcegraphTool(),
184 tools.NewViewTool(lspClients, permissions, cwd),
185 tools.NewWriteTool(lspClients, permissions, history, cwd),
186 }
187
188 mcpToolsOnce.Do(func() {
189 mcpTools = doGetMCPTools(ctx, permissions, cfg)
190 })
191
192 withCoderTools := func(t []tools.BaseTool) []tools.BaseTool {
193 if agentCfg.ID == "coder" {
194 t = append(t, mcpTools...)
195 if lspClients.Len() > 0 {
196 t = append(t, tools.NewDiagnosticsTool(lspClients))
197 }
198 }
199 return t
200 }
201
202 if agentCfg.AllowedTools == nil {
203 return withCoderTools(allTools)
204 }
205
206 var filteredTools []tools.BaseTool
207 for _, tool := range allTools {
208 if slices.Contains(agentCfg.AllowedTools, tool.Name()) {
209 filteredTools = append(filteredTools, tool)
210 }
211 }
212 return withCoderTools(filteredTools)
213 }
214
215 return &agent{
216 Broker: pubsub.NewBroker[AgentEvent](),
217 agentCfg: agentCfg,
218 provider: agentProvider,
219 providerID: string(providerCfg.ID),
220 messages: messages,
221 sessions: sessions,
222 titleProvider: titleProvider,
223 summarizeProvider: summarizeProvider,
224 summarizeProviderID: string(providerCfg.ID),
225 agentToolFn: agentToolFn,
226 activeRequests: csync.NewMap[string, context.CancelFunc](),
227 tools: csync.NewLazySlice(toolFn),
228 promptQueue: csync.NewMap[string, []string](),
229 cfg: cfg,
230 permissions: permissions,
231 }, nil
232}
233
234func (a *agent) Model() catwalk.Model {
235 return *a.cfg.GetModelByType(a.agentCfg.Model)
236}
237
238func (a *agent) Cancel(sessionID string) {
239 // Cancel regular requests
240 if cancel, ok := a.activeRequests.Take(sessionID); ok && cancel != nil {
241 slog.Info("Request cancellation initiated", "session_id", sessionID)
242 cancel()
243 }
244
245 // Also check for summarize requests
246 if cancel, ok := a.activeRequests.Take(sessionID + "-summarize"); ok && cancel != nil {
247 slog.Info("Summarize cancellation initiated", "session_id", sessionID)
248 cancel()
249 }
250
251 if a.QueuedPrompts(sessionID) > 0 {
252 slog.Info("Clearing queued prompts", "session_id", sessionID)
253 a.promptQueue.Del(sessionID)
254 }
255}
256
257func (a *agent) IsBusy() bool {
258 var busy bool
259 for cancelFunc := range a.activeRequests.Seq() {
260 if cancelFunc != nil {
261 busy = true
262 break
263 }
264 }
265 return busy
266}
267
268func (a *agent) IsSessionBusy(sessionID string) bool {
269 _, busy := a.activeRequests.Get(sessionID)
270 return busy
271}
272
273func (a *agent) QueuedPrompts(sessionID string) int {
274 l, ok := a.promptQueue.Get(sessionID)
275 if !ok {
276 return 0
277 }
278 return len(l)
279}
280
281func (a *agent) generateTitle(ctx context.Context, sessionID string, content string) error {
282 if content == "" {
283 return nil
284 }
285 if a.titleProvider == nil {
286 return nil
287 }
288 session, err := a.sessions.Get(ctx, sessionID)
289 if err != nil {
290 return err
291 }
292 parts := []message.ContentPart{message.TextContent{
293 Text: fmt.Sprintf("Generate a concise title for the following content:\n\n%s", content),
294 }}
295
296 // Use streaming approach like summarization
297 response := a.titleProvider.StreamResponse(
298 ctx,
299 []message.Message{
300 {
301 Role: message.User,
302 Parts: parts,
303 },
304 },
305 nil,
306 )
307
308 var finalResponse *provider.ProviderResponse
309 for r := range response {
310 if r.Error != nil {
311 return r.Error
312 }
313 finalResponse = r.Response
314 }
315
316 if finalResponse == nil {
317 return fmt.Errorf("no response received from title provider")
318 }
319
320 title := strings.ReplaceAll(finalResponse.Content, "\n", " ")
321
322 if idx := strings.Index(title, "</think>"); idx > 0 {
323 title = title[idx+len("</think>"):]
324 }
325
326 title = strings.TrimSpace(title)
327 if title == "" {
328 return nil
329 }
330
331 session.Title = title
332 _, err = a.sessions.Save(ctx, session)
333 return err
334}
335
336func (a *agent) err(err error) AgentEvent {
337 return AgentEvent{
338 Type: AgentEventTypeError,
339 Error: err,
340 }
341}
342
343func (a *agent) Run(ctx context.Context, sessionID string, content string, attachments ...message.Attachment) (<-chan AgentEvent, error) {
344 if !a.Model().SupportsImages && attachments != nil {
345 attachments = nil
346 }
347 events := make(chan AgentEvent, 1)
348 if a.IsSessionBusy(sessionID) {
349 existing, ok := a.promptQueue.Get(sessionID)
350 if !ok {
351 existing = []string{}
352 }
353 existing = append(existing, content)
354 a.promptQueue.Set(sessionID, existing)
355 return nil, nil
356 }
357
358 genCtx, cancel := context.WithCancel(ctx)
359 a.activeRequests.Set(sessionID, cancel)
360 startTime := time.Now()
361
362 go func() {
363 slog.Debug("Request started", "sessionID", sessionID)
364 defer log.RecoverPanic("agent.Run", func() {
365 events <- a.err(fmt.Errorf("panic while running the agent"))
366 })
367 var attachmentParts []message.ContentPart
368 for _, attachment := range attachments {
369 attachmentParts = append(attachmentParts, message.BinaryContent{Path: attachment.FilePath, MIMEType: attachment.MimeType, Data: attachment.Content})
370 }
371 result := a.processGeneration(genCtx, sessionID, content, attachmentParts)
372 if result.Error != nil {
373 if isCancelledErr(result.Error) {
374 slog.Error("Request canceled", "sessionID", sessionID)
375 } else {
376 slog.Error("Request errored", "sessionID", sessionID, "error", result.Error.Error())
377 event.Error(result.Error)
378 }
379 } else {
380 slog.Debug("Request completed", "sessionID", sessionID)
381 }
382 a.eventPromptResponded(sessionID, time.Since(startTime).Truncate(time.Second))
383 a.activeRequests.Del(sessionID)
384 cancel()
385 a.Publish(pubsub.CreatedEvent, result)
386 events <- result
387 close(events)
388 }()
389 a.eventPromptSent(sessionID)
390 return events, nil
391}
392
393func (a *agent) processGeneration(ctx context.Context, sessionID, content string, attachmentParts []message.ContentPart) AgentEvent {
394 // List existing messages; if none, start title generation asynchronously.
395 msgs, err := a.messages.List(ctx, sessionID)
396 if err != nil {
397 return a.err(fmt.Errorf("failed to list messages: %w", err))
398 }
399 if len(msgs) == 0 {
400 go func() {
401 defer log.RecoverPanic("agent.Run", func() {
402 slog.Error("panic while generating title")
403 })
404 titleErr := a.generateTitle(ctx, sessionID, content)
405 if titleErr != nil && !errors.Is(titleErr, context.Canceled) && !errors.Is(titleErr, context.DeadlineExceeded) {
406 slog.Error("failed to generate title", "error", titleErr)
407 }
408 }()
409 }
410 session, err := a.sessions.Get(ctx, sessionID)
411 if err != nil {
412 return a.err(fmt.Errorf("failed to get session: %w", err))
413 }
414 if session.SummaryMessageID != "" {
415 summaryMsgInex := -1
416 for i, msg := range msgs {
417 if msg.ID == session.SummaryMessageID {
418 summaryMsgInex = i
419 break
420 }
421 }
422 if summaryMsgInex != -1 {
423 msgs = msgs[summaryMsgInex:]
424 msgs[0].Role = message.User
425 }
426 }
427
428 userMsg, err := a.createUserMessage(ctx, sessionID, content, attachmentParts)
429 if err != nil {
430 return a.err(fmt.Errorf("failed to create user message: %w", err))
431 }
432 // Append the new user message to the conversation history.
433 msgHistory := append(msgs, userMsg)
434
435 for {
436 // Check for cancellation before each iteration
437 select {
438 case <-ctx.Done():
439 return a.err(ctx.Err())
440 default:
441 // Continue processing
442 }
443 agentMessage, toolResults, err := a.streamAndHandleEvents(ctx, sessionID, msgHistory)
444 if err != nil {
445 if errors.Is(err, context.Canceled) {
446 agentMessage.AddFinish(message.FinishReasonCanceled, "Request cancelled", "")
447 a.messages.Update(context.Background(), agentMessage)
448 return a.err(ErrRequestCancelled)
449 }
450 return a.err(fmt.Errorf("failed to process events: %w", err))
451 }
452 if a.cfg.Options.Debug {
453 slog.Info("Result", "message", agentMessage.FinishReason(), "toolResults", toolResults)
454 }
455 if (agentMessage.FinishReason() == message.FinishReasonToolUse) && toolResults != nil {
456 // We are not done, we need to respond with the tool response
457 msgHistory = append(msgHistory, agentMessage, *toolResults)
458 // If there are queued prompts, process the next one
459 nextPrompt, ok := a.promptQueue.Take(sessionID)
460 if ok {
461 for _, prompt := range nextPrompt {
462 // Create a new user message for the queued prompt
463 userMsg, err := a.createUserMessage(ctx, sessionID, prompt, nil)
464 if err != nil {
465 return a.err(fmt.Errorf("failed to create user message for queued prompt: %w", err))
466 }
467 // Append the new user message to the conversation history
468 msgHistory = append(msgHistory, userMsg)
469 }
470 }
471
472 continue
473 } else if agentMessage.FinishReason() == message.FinishReasonEndTurn {
474 queuePrompts, ok := a.promptQueue.Take(sessionID)
475 if ok {
476 for _, prompt := range queuePrompts {
477 if prompt == "" {
478 continue
479 }
480 userMsg, err := a.createUserMessage(ctx, sessionID, prompt, nil)
481 if err != nil {
482 return a.err(fmt.Errorf("failed to create user message for queued prompt: %w", err))
483 }
484 msgHistory = append(msgHistory, userMsg)
485 }
486 continue
487 }
488 }
489 if agentMessage.FinishReason() == "" {
490 // Kujtim: could not track down where this is happening but this means its cancelled
491 agentMessage.AddFinish(message.FinishReasonCanceled, "Request cancelled", "")
492 _ = a.messages.Update(context.Background(), agentMessage)
493 return a.err(ErrRequestCancelled)
494 }
495 return AgentEvent{
496 Type: AgentEventTypeResponse,
497 Message: agentMessage,
498 Done: true,
499 }
500 }
501}
502
503func (a *agent) createUserMessage(ctx context.Context, sessionID, content string, attachmentParts []message.ContentPart) (message.Message, error) {
504 parts := []message.ContentPart{message.TextContent{Text: content}}
505 parts = append(parts, attachmentParts...)
506 return a.messages.Create(ctx, sessionID, message.CreateMessageParams{
507 Role: message.User,
508 Parts: parts,
509 })
510}
511
512func (a *agent) getAllTools() ([]tools.BaseTool, error) {
513 allTools := slices.Collect(a.tools.Seq())
514 if a.agentToolFn != nil {
515 agentTool, agentToolErr := a.agentToolFn()
516 if agentToolErr != nil {
517 return nil, agentToolErr
518 }
519 allTools = append(allTools, agentTool)
520 }
521 return allTools, nil
522}
523
524func (a *agent) streamAndHandleEvents(ctx context.Context, sessionID string, msgHistory []message.Message) (message.Message, *message.Message, error) {
525 ctx = context.WithValue(ctx, tools.SessionIDContextKey, sessionID)
526
527 // Create the assistant message first so the spinner shows immediately
528 assistantMsg, err := a.messages.Create(ctx, sessionID, message.CreateMessageParams{
529 Role: message.Assistant,
530 Parts: []message.ContentPart{},
531 Model: a.Model().ID,
532 Provider: a.providerID,
533 })
534 if err != nil {
535 return assistantMsg, nil, fmt.Errorf("failed to create assistant message: %w", err)
536 }
537
538 allTools, toolsErr := a.getAllTools()
539 if toolsErr != nil {
540 return assistantMsg, nil, toolsErr
541 }
542 // Now collect tools (which may block on MCP initialization)
543 eventChan := a.provider.StreamResponse(ctx, msgHistory, allTools)
544
545 // Add the session and message ID into the context if needed by tools.
546 ctx = context.WithValue(ctx, tools.MessageIDContextKey, assistantMsg.ID)
547
548 // Process each event in the stream.
549 for event := range eventChan {
550 if processErr := a.processEvent(ctx, sessionID, &assistantMsg, event); processErr != nil {
551 if errors.Is(processErr, context.Canceled) {
552 a.finishMessage(context.Background(), &assistantMsg, message.FinishReasonCanceled, "Request cancelled", "")
553 } else {
554 a.finishMessage(ctx, &assistantMsg, message.FinishReasonError, "API Error", processErr.Error())
555 }
556 return assistantMsg, nil, processErr
557 }
558 if ctx.Err() != nil {
559 a.finishMessage(context.Background(), &assistantMsg, message.FinishReasonCanceled, "Request cancelled", "")
560 return assistantMsg, nil, ctx.Err()
561 }
562 }
563
564 toolResults := make([]message.ToolResult, len(assistantMsg.ToolCalls()))
565 toolCalls := assistantMsg.ToolCalls()
566 for i, toolCall := range toolCalls {
567 select {
568 case <-ctx.Done():
569 a.finishMessage(context.Background(), &assistantMsg, message.FinishReasonCanceled, "Request cancelled", "")
570 // Make all future tool calls cancelled
571 for j := i; j < len(toolCalls); j++ {
572 toolResults[j] = message.ToolResult{
573 ToolCallID: toolCalls[j].ID,
574 Content: "Tool execution canceled by user",
575 IsError: true,
576 }
577 }
578 goto out
579 default:
580 // Continue processing
581 var tool tools.BaseTool
582 allTools, _ := a.getAllTools()
583 for _, availableTool := range allTools {
584 if availableTool.Info().Name == toolCall.Name {
585 tool = availableTool
586 break
587 }
588 }
589
590 // Tool not found
591 if tool == nil {
592 toolResults[i] = message.ToolResult{
593 ToolCallID: toolCall.ID,
594 Content: fmt.Sprintf("Tool not found: %s", toolCall.Name),
595 IsError: true,
596 }
597 continue
598 }
599
600 // Run tool in goroutine to allow cancellation
601 type toolExecResult struct {
602 response tools.ToolResponse
603 err error
604 }
605 resultChan := make(chan toolExecResult, 1)
606
607 go func() {
608 response, err := tool.Run(ctx, tools.ToolCall{
609 ID: toolCall.ID,
610 Name: toolCall.Name,
611 Input: toolCall.Input,
612 })
613 resultChan <- toolExecResult{response: response, err: err}
614 }()
615
616 var toolResponse tools.ToolResponse
617 var toolErr error
618
619 select {
620 case <-ctx.Done():
621 a.finishMessage(context.Background(), &assistantMsg, message.FinishReasonCanceled, "Request cancelled", "")
622 // Mark remaining tool calls as cancelled
623 for j := i; j < len(toolCalls); j++ {
624 toolResults[j] = message.ToolResult{
625 ToolCallID: toolCalls[j].ID,
626 Content: "Tool execution canceled by user",
627 IsError: true,
628 }
629 }
630 goto out
631 case result := <-resultChan:
632 toolResponse = result.response
633 toolErr = result.err
634 }
635
636 if toolErr != nil {
637 slog.Error("Tool execution error", "toolCall", toolCall.ID, "error", toolErr)
638 if errors.Is(toolErr, permission.ErrorPermissionDenied) {
639 toolResults[i] = message.ToolResult{
640 ToolCallID: toolCall.ID,
641 Content: "Permission denied",
642 IsError: true,
643 }
644 for j := i + 1; j < len(toolCalls); j++ {
645 toolResults[j] = message.ToolResult{
646 ToolCallID: toolCalls[j].ID,
647 Content: "Tool execution canceled by user",
648 IsError: true,
649 }
650 }
651 a.finishMessage(ctx, &assistantMsg, message.FinishReasonPermissionDenied, "Permission denied", "")
652 break
653 }
654 }
655 toolResults[i] = message.ToolResult{
656 ToolCallID: toolCall.ID,
657 Content: toolResponse.Content,
658 Metadata: toolResponse.Metadata,
659 IsError: toolResponse.IsError,
660 }
661 }
662 }
663out:
664 if len(toolResults) == 0 {
665 return assistantMsg, nil, nil
666 }
667 parts := make([]message.ContentPart, 0)
668 for _, tr := range toolResults {
669 parts = append(parts, tr)
670 }
671 msg, err := a.messages.Create(context.Background(), assistantMsg.SessionID, message.CreateMessageParams{
672 Role: message.Tool,
673 Parts: parts,
674 Provider: a.providerID,
675 })
676 if err != nil {
677 return assistantMsg, nil, fmt.Errorf("failed to create cancelled tool message: %w", err)
678 }
679
680 return assistantMsg, &msg, err
681}
682
683func (a *agent) finishMessage(ctx context.Context, msg *message.Message, finishReason message.FinishReason, message, details string) {
684 msg.AddFinish(finishReason, message, details)
685 _ = a.messages.Update(ctx, *msg)
686}
687
688func (a *agent) processEvent(ctx context.Context, sessionID string, assistantMsg *message.Message, event provider.ProviderEvent) error {
689 select {
690 case <-ctx.Done():
691 return ctx.Err()
692 default:
693 // Continue processing.
694 }
695
696 switch event.Type {
697 case provider.EventThinkingDelta:
698 assistantMsg.AppendReasoningContent(event.Thinking)
699 return a.messages.Update(ctx, *assistantMsg)
700 case provider.EventSignatureDelta:
701 assistantMsg.AppendReasoningSignature(event.Signature)
702 return a.messages.Update(ctx, *assistantMsg)
703 case provider.EventContentDelta:
704 assistantMsg.FinishThinking()
705 assistantMsg.AppendContent(event.Content)
706 return a.messages.Update(ctx, *assistantMsg)
707 case provider.EventToolUseStart:
708 assistantMsg.FinishThinking()
709 slog.Info("Tool call started", "toolCall", event.ToolCall)
710 assistantMsg.AddToolCall(*event.ToolCall)
711 return a.messages.Update(ctx, *assistantMsg)
712 case provider.EventToolUseDelta:
713 assistantMsg.AppendToolCallInput(event.ToolCall.ID, event.ToolCall.Input)
714 return a.messages.Update(ctx, *assistantMsg)
715 case provider.EventToolUseStop:
716 slog.Info("Finished tool call", "toolCall", event.ToolCall)
717 assistantMsg.FinishToolCall(event.ToolCall.ID)
718 return a.messages.Update(ctx, *assistantMsg)
719 case provider.EventError:
720 return event.Error
721 case provider.EventComplete:
722 assistantMsg.FinishThinking()
723 assistantMsg.SetToolCalls(event.Response.ToolCalls)
724 assistantMsg.AddFinish(event.Response.FinishReason, "", "")
725 if err := a.messages.Update(ctx, *assistantMsg); err != nil {
726 return fmt.Errorf("failed to update message: %w", err)
727 }
728 return a.trackUsage(ctx, sessionID, a.Model(), event.Response.Usage)
729 }
730
731 return nil
732}
733
734func (a *agent) trackUsage(ctx context.Context, sessionID string, model catwalk.Model, usage provider.TokenUsage) error {
735 sess, err := a.sessions.Get(ctx, sessionID)
736 if err != nil {
737 return fmt.Errorf("failed to get session: %w", err)
738 }
739
740 cost := model.CostPer1MInCached/1e6*float64(usage.CacheCreationTokens) +
741 model.CostPer1MOutCached/1e6*float64(usage.CacheReadTokens) +
742 model.CostPer1MIn/1e6*float64(usage.InputTokens) +
743 model.CostPer1MOut/1e6*float64(usage.OutputTokens)
744
745 a.eventTokensUsed(sessionID, usage, cost)
746
747 sess.Cost += cost
748 sess.CompletionTokens = usage.OutputTokens + usage.CacheReadTokens
749 sess.PromptTokens = usage.InputTokens + usage.CacheCreationTokens
750
751 _, err = a.sessions.Save(ctx, sess)
752 if err != nil {
753 return fmt.Errorf("failed to save session: %w", err)
754 }
755 return nil
756}
757
758func (a *agent) Summarize(ctx context.Context, sessionID string) error {
759 if a.summarizeProvider == nil {
760 return fmt.Errorf("summarize provider not available")
761 }
762
763 // Check if session is busy
764 if a.IsSessionBusy(sessionID) {
765 return ErrSessionBusy
766 }
767
768 // Create a new context with cancellation
769 summarizeCtx, cancel := context.WithCancel(ctx)
770
771 // Store the cancel function in activeRequests to allow cancellation
772 a.activeRequests.Set(sessionID+"-summarize", cancel)
773
774 go func() {
775 defer a.activeRequests.Del(sessionID + "-summarize")
776 defer cancel()
777 event := AgentEvent{
778 Type: AgentEventTypeSummarize,
779 Progress: "Starting summarization...",
780 }
781
782 a.Publish(pubsub.CreatedEvent, event)
783 // Get all messages from the session
784 msgs, err := a.messages.List(summarizeCtx, sessionID)
785 if err != nil {
786 event = AgentEvent{
787 Type: AgentEventTypeError,
788 Error: fmt.Errorf("failed to list messages: %w", err),
789 Done: true,
790 }
791 a.Publish(pubsub.CreatedEvent, event)
792 return
793 }
794 summarizeCtx = context.WithValue(summarizeCtx, tools.SessionIDContextKey, sessionID)
795
796 if len(msgs) == 0 {
797 event = AgentEvent{
798 Type: AgentEventTypeError,
799 Error: fmt.Errorf("no messages to summarize"),
800 Done: true,
801 }
802 a.Publish(pubsub.CreatedEvent, event)
803 return
804 }
805
806 event = AgentEvent{
807 Type: AgentEventTypeSummarize,
808 Progress: "Analyzing conversation...",
809 }
810 a.Publish(pubsub.CreatedEvent, event)
811
812 // Add a system message to guide the summarization
813 summarizePrompt := "Provide a detailed but concise summary of our conversation above. Focus on information that would be helpful for continuing the conversation, including what we did, what we're doing, which files we're working on, and what we're going to do next."
814
815 // Create a new message with the summarize prompt
816 promptMsg := message.Message{
817 Role: message.User,
818 Parts: []message.ContentPart{message.TextContent{Text: summarizePrompt}},
819 }
820
821 // Append the prompt to the messages
822 msgsWithPrompt := append(msgs, promptMsg)
823
824 event = AgentEvent{
825 Type: AgentEventTypeSummarize,
826 Progress: "Generating summary...",
827 }
828
829 a.Publish(pubsub.CreatedEvent, event)
830
831 // Send the messages to the summarize provider
832 response := a.summarizeProvider.StreamResponse(
833 summarizeCtx,
834 msgsWithPrompt,
835 nil,
836 )
837 var finalResponse *provider.ProviderResponse
838 for r := range response {
839 if r.Error != nil {
840 event = AgentEvent{
841 Type: AgentEventTypeError,
842 Error: fmt.Errorf("failed to summarize: %w", r.Error),
843 Done: true,
844 }
845 a.Publish(pubsub.CreatedEvent, event)
846 return
847 }
848 finalResponse = r.Response
849 }
850
851 summary := strings.TrimSpace(finalResponse.Content)
852 if summary == "" {
853 event = AgentEvent{
854 Type: AgentEventTypeError,
855 Error: fmt.Errorf("empty summary returned"),
856 Done: true,
857 }
858 a.Publish(pubsub.CreatedEvent, event)
859 return
860 }
861 shell := shell.GetPersistentShell(a.cfg.WorkingDir())
862 summary += "\n\n**Current working directory of the persistent shell**\n\n" + shell.GetWorkingDir()
863 event = AgentEvent{
864 Type: AgentEventTypeSummarize,
865 Progress: "Creating new session...",
866 }
867
868 a.Publish(pubsub.CreatedEvent, event)
869 oldSession, err := a.sessions.Get(summarizeCtx, sessionID)
870 if err != nil {
871 event = AgentEvent{
872 Type: AgentEventTypeError,
873 Error: fmt.Errorf("failed to get session: %w", err),
874 Done: true,
875 }
876
877 a.Publish(pubsub.CreatedEvent, event)
878 return
879 }
880 // Create a message in the new session with the summary
881 msg, err := a.messages.Create(summarizeCtx, oldSession.ID, message.CreateMessageParams{
882 Role: message.Assistant,
883 Parts: []message.ContentPart{
884 message.TextContent{Text: summary},
885 message.Finish{
886 Reason: message.FinishReasonEndTurn,
887 Time: time.Now().Unix(),
888 },
889 },
890 Model: a.summarizeProvider.Model().ID,
891 Provider: a.summarizeProviderID,
892 })
893 if err != nil {
894 event = AgentEvent{
895 Type: AgentEventTypeError,
896 Error: fmt.Errorf("failed to create summary message: %w", err),
897 Done: true,
898 }
899
900 a.Publish(pubsub.CreatedEvent, event)
901 return
902 }
903 oldSession.SummaryMessageID = msg.ID
904 oldSession.CompletionTokens = finalResponse.Usage.OutputTokens
905 oldSession.PromptTokens = 0
906 model := a.summarizeProvider.Model()
907 usage := finalResponse.Usage
908 cost := model.CostPer1MInCached/1e6*float64(usage.CacheCreationTokens) +
909 model.CostPer1MOutCached/1e6*float64(usage.CacheReadTokens) +
910 model.CostPer1MIn/1e6*float64(usage.InputTokens) +
911 model.CostPer1MOut/1e6*float64(usage.OutputTokens)
912 oldSession.Cost += cost
913 _, err = a.sessions.Save(summarizeCtx, oldSession)
914 if err != nil {
915 event = AgentEvent{
916 Type: AgentEventTypeError,
917 Error: fmt.Errorf("failed to save session: %w", err),
918 Done: true,
919 }
920 a.Publish(pubsub.CreatedEvent, event)
921 }
922
923 event = AgentEvent{
924 Type: AgentEventTypeSummarize,
925 SessionID: oldSession.ID,
926 Progress: "Summary complete",
927 Done: true,
928 }
929 a.Publish(pubsub.CreatedEvent, event)
930 // Send final success event with the new session ID
931 }()
932
933 return nil
934}
935
936func (a *agent) ClearQueue(sessionID string) {
937 if a.QueuedPrompts(sessionID) > 0 {
938 slog.Info("Clearing queued prompts", "session_id", sessionID)
939 a.promptQueue.Del(sessionID)
940 }
941}
942
943func (a *agent) CancelAll() {
944 if !a.IsBusy() {
945 return
946 }
947 for key := range a.activeRequests.Seq2() {
948 a.Cancel(key) // key is sessionID
949 }
950
951 timeout := time.After(5 * time.Second)
952 for a.IsBusy() {
953 select {
954 case <-timeout:
955 return
956 default:
957 time.Sleep(200 * time.Millisecond)
958 }
959 }
960}
961
962func (a *agent) UpdateModel() error {
963 // Get current provider configuration
964 currentProviderCfg := a.cfg.GetProviderForModel(a.agentCfg.Model)
965 if currentProviderCfg == nil || currentProviderCfg.ID == "" {
966 return fmt.Errorf("provider for agent %s not found in config", a.agentCfg.Name)
967 }
968
969 // Check if provider has changed
970 if string(currentProviderCfg.ID) != a.providerID {
971 // Provider changed, need to recreate the main provider
972 model := a.cfg.GetModelByType(a.agentCfg.Model)
973 if model.ID == "" {
974 return fmt.Errorf("model not found for agent %s", a.agentCfg.Name)
975 }
976
977 promptID := agentPromptMap[a.agentCfg.ID]
978 if promptID == "" {
979 promptID = prompt.PromptDefault
980 }
981
982 opts := []provider.ProviderClientOption{
983 provider.WithModel(a.agentCfg.Model),
984 provider.WithSystemMessage(prompt.GetPrompt(a.cfg, promptID, currentProviderCfg.ID, a.cfg.Options.ContextPaths...)),
985 }
986
987 newProvider, err := provider.NewProvider(a.cfg, *currentProviderCfg, opts...)
988 if err != nil {
989 return fmt.Errorf("failed to create new provider: %w", err)
990 }
991
992 // Update the provider and provider ID
993 a.provider = newProvider
994 a.providerID = string(currentProviderCfg.ID)
995 }
996
997 // Check if providers have changed for title (small) and summarize (large)
998 smallModelCfg := a.cfg.Models[config.SelectedModelTypeSmall]
999 var smallModelProviderCfg config.ProviderConfig
1000 for p := range a.cfg.Providers.Seq() {
1001 if p.ID == smallModelCfg.Provider {
1002 smallModelProviderCfg = p
1003 break
1004 }
1005 }
1006 if smallModelProviderCfg.ID == "" {
1007 return fmt.Errorf("provider %s not found in config", smallModelCfg.Provider)
1008 }
1009
1010 largeModelCfg := a.cfg.Models[config.SelectedModelTypeLarge]
1011 var largeModelProviderCfg config.ProviderConfig
1012 for p := range a.cfg.Providers.Seq() {
1013 if p.ID == largeModelCfg.Provider {
1014 largeModelProviderCfg = p
1015 break
1016 }
1017 }
1018 if largeModelProviderCfg.ID == "" {
1019 return fmt.Errorf("provider %s not found in config", largeModelCfg.Provider)
1020 }
1021
1022 var maxTitleTokens int64 = 40
1023
1024 // if the max output is too low for the gemini provider it won't return anything
1025 if smallModelCfg.Provider == "gemini" {
1026 maxTitleTokens = 1000
1027 }
1028 // Recreate title provider
1029 titleOpts := []provider.ProviderClientOption{
1030 provider.WithModel(config.SelectedModelTypeSmall),
1031 provider.WithSystemMessage(prompt.GetPrompt(a.cfg, prompt.PromptTitle, smallModelProviderCfg.ID)),
1032 provider.WithMaxTokens(maxTitleTokens),
1033 }
1034 newTitleProvider, err := provider.NewProvider(a.cfg, smallModelProviderCfg, titleOpts...)
1035 if err != nil {
1036 return fmt.Errorf("failed to create new title provider: %w", err)
1037 }
1038 a.titleProvider = newTitleProvider
1039
1040 // Recreate summarize provider if provider changed (now large model)
1041 if string(largeModelProviderCfg.ID) != a.summarizeProviderID {
1042 largeModel := a.cfg.GetModelByType(config.SelectedModelTypeLarge)
1043 if largeModel == nil {
1044 return fmt.Errorf("model %s not found in provider %s", largeModelCfg.Model, largeModelProviderCfg.ID)
1045 }
1046 summarizeOpts := []provider.ProviderClientOption{
1047 provider.WithModel(config.SelectedModelTypeLarge),
1048 provider.WithSystemMessage(prompt.GetPrompt(a.cfg, prompt.PromptSummarizer, largeModelProviderCfg.ID)),
1049 }
1050 newSummarizeProvider, err := provider.NewProvider(a.cfg, largeModelProviderCfg, summarizeOpts...)
1051 if err != nil {
1052 return fmt.Errorf("failed to create new summarize provider: %w", err)
1053 }
1054 a.summarizeProvider = newSummarizeProvider
1055 a.summarizeProviderID = string(largeModelProviderCfg.ID)
1056 }
1057
1058 return nil
1059}