1// Package agent is the core orchestration layer for Crush AI agents.
2//
3// It provides session-based AI agent functionality for managing
4// conversations, tool execution, and message handling. It coordinates
5// interactions between language models, messages, sessions, and tools while
6// handling features like automatic summarization, queuing, and token
7// management.
8package agent
9
10import (
11 "cmp"
12 "context"
13 _ "embed"
14 "encoding/base64"
15 "errors"
16 "fmt"
17 "log/slog"
18 "os"
19 "regexp"
20 "strconv"
21 "strings"
22 "sync"
23 "time"
24
25 "charm.land/catwalk/pkg/catwalk"
26 "charm.land/fantasy"
27 "charm.land/fantasy/providers/anthropic"
28 "charm.land/fantasy/providers/bedrock"
29 "charm.land/fantasy/providers/google"
30 "charm.land/fantasy/providers/openai"
31 "charm.land/fantasy/providers/openrouter"
32 "charm.land/fantasy/providers/vercel"
33 "charm.land/lipgloss/v2"
34 "github.com/charmbracelet/crush/internal/agent/hyper"
35 "github.com/charmbracelet/crush/internal/agent/notify"
36 "github.com/charmbracelet/crush/internal/agent/tools"
37 "github.com/charmbracelet/crush/internal/agent/tools/mcp"
38 "github.com/charmbracelet/crush/internal/config"
39 "github.com/charmbracelet/crush/internal/csync"
40 "github.com/charmbracelet/crush/internal/message"
41 "github.com/charmbracelet/crush/internal/permission"
42 "github.com/charmbracelet/crush/internal/pubsub"
43 "github.com/charmbracelet/crush/internal/session"
44 "github.com/charmbracelet/crush/internal/stringext"
45 "github.com/charmbracelet/crush/internal/version"
46 "github.com/charmbracelet/x/exp/charmtone"
47)
48
// DefaultSessionName is the title used for a session whenever a title cannot
// be generated from its first prompt.
const DefaultSessionName = "Untitled Session"

// Auto-summarization thresholds.
//
// When a model's context window is larger than largeContextWindowThreshold
// tokens, the session is summarized once no more than largeContextWindowBuffer
// tokens remain. For smaller windows the buffer is smallContextWindowRatio of
// the window size instead.
const (
	largeContextWindowThreshold = 200_000
	largeContextWindowBuffer    = 20_000
	smallContextWindowRatio     = 0.2
)
57
// userAgent is sent to model providers (via fantasy.WithUserAgent) to
// identify Crush and its version.
var userAgent = fmt.Sprintf("Charm-Crush/%s (https://charm.land/crush)", version.Version)

// titlePrompt is the system prompt used when generating session titles.
//
//go:embed templates/title.md
var titlePrompt []byte

// summaryPrompt is the system prompt used when summarizing a session.
//
//go:embed templates/summary.md
var summaryPrompt []byte

// thinkTagRegex matches <think>...</think> blocks so they can be stripped from
// generated titles. The pattern has no (?s) flag, so "." does not cross
// newlines; generateTitle replaces newlines with spaces before applying it.
var thinkTagRegex = regexp.MustCompile(`<think>.*?</think>`)
68
// SessionAgentCall describes a single prompt submitted to a SessionAgent.
type SessionAgentCall struct {
	// SessionID identifies the session the prompt belongs to. Required.
	SessionID string
	// Prompt is the user's text. It may be empty only when Attachments
	// contains a text attachment.
	Prompt string
	// ProviderOptions are forwarded to the provider for this call (and to
	// Summarize if auto-summarization is triggered).
	ProviderOptions fantasy.ProviderOptions
	// Attachments sent with the prompt. Text attachments are passed through
	// message.PromptWithTextAttachments; the others are sent as file parts.
	Attachments []message.Attachment
	// MaxOutputTokens caps the response length. 0 omits the field from the
	// request entirely, since some providers reject it.
	MaxOutputTokens int64
	// Sampling parameters, forwarded to the stream call as-is. A nil value
	// presumably leaves the provider default in place — not verified here.
	Temperature      *float64
	TopP             *float64
	TopK             *int64
	FrequencyPenalty *float64
	PresencePenalty  *float64
	// NonInteractive suppresses the "agent finished" notification sent at
	// the end of a successful turn.
	NonInteractive bool
}
82
83type SessionAgent interface {
84 Run(context.Context, SessionAgentCall) (*fantasy.AgentResult, error)
85 SetModels(large Model, small Model)
86 SetTools(tools []fantasy.AgentTool)
87 SetSystemPrompt(systemPrompt string)
88 Cancel(sessionID string)
89 CancelAll()
90 IsSessionBusy(sessionID string) bool
91 IsBusy() bool
92 QueuedPrompts(sessionID string) int
93 QueuedPromptsList(sessionID string) []string
94 ClearQueue(sessionID string)
95 Summarize(context.Context, string, fantasy.ProviderOptions) error
96 Model() Model
97}
98
// Model bundles a language model with its catalog metadata and the user's
// model selection.
type Model struct {
	// Model is the language model requests are sent to.
	Model fantasy.LanguageModel
	// CatwalkCfg is catalog metadata for the model; this package reads the
	// context window, per-token costs, reasoning/image support and name.
	CatwalkCfg catwalk.Model
	// ModelCfg is the selected model configuration; its Model and Provider
	// are recorded on assistant messages.
	ModelCfg config.SelectedModel
}
104
// sessionAgent is the SessionAgent implementation returned by NewSessionAgent.
type sessionAgent struct {
	// Configuration that can be replaced while requests are running; read
	// through csync containers (snapshotted at the start of Run/Summarize).
	largeModel         *csync.Value[Model]
	smallModel         *csync.Value[Model]
	systemPromptPrefix *csync.Value[string]
	systemPrompt       *csync.Value[string]
	tools              *csync.Slice[fantasy.AgentTool]

	// isSubAgent disables the todo-list reminder injected by preparePrompt.
	isSubAgent bool
	sessions   session.Service
	messages   message.Service
	// disableAutoSummarize turns off the context-window stop condition's
	// summarization trigger.
	disableAutoSummarize bool
	isYolo               bool
	// notify may be nil; callers check before publishing.
	notify pubsub.Publisher[notify.Notification]

	// messageQueue holds calls submitted while their session was busy, keyed
	// by session ID.
	messageQueue *csync.Map[string, []SessionAgentCall]
	// activeRequests holds the cancel function of the in-flight request for
	// each session ID.
	activeRequests *csync.Map[string, context.CancelFunc]
}
122
// SessionAgentOptions configures NewSessionAgent.
type SessionAgentOptions struct {
	// LargeModel serves prompts and summaries; SmallModel is tried first for
	// title generation.
	LargeModel Model
	SmallModel Model
	// SystemPromptPrefix, when non-empty, is prepended as an extra system
	// message to every request.
	SystemPromptPrefix string
	SystemPrompt       string
	// IsSubAgent disables the todo-list system reminder.
	IsSubAgent bool
	// DisableAutoSummarize prevents automatic summarization when the context
	// window is nearly full.
	DisableAutoSummarize bool
	IsYolo               bool
	Sessions             session.Service
	Messages             message.Service
	Tools                []fantasy.AgentTool
	// Notify is optional; nil disables notifications.
	Notify pubsub.Publisher[notify.Notification]
}
136
137func NewSessionAgent(
138 opts SessionAgentOptions,
139) SessionAgent {
140 return &sessionAgent{
141 largeModel: csync.NewValue(opts.LargeModel),
142 smallModel: csync.NewValue(opts.SmallModel),
143 systemPromptPrefix: csync.NewValue(opts.SystemPromptPrefix),
144 systemPrompt: csync.NewValue(opts.SystemPrompt),
145 isSubAgent: opts.IsSubAgent,
146 sessions: opts.Sessions,
147 messages: opts.Messages,
148 disableAutoSummarize: opts.DisableAutoSummarize,
149 tools: csync.NewSliceFrom(opts.Tools),
150 isYolo: opts.IsYolo,
151 notify: opts.Notify,
152 messageQueue: csync.NewMap[string, []SessionAgentCall](),
153 activeRequests: csync.NewMap[string, context.CancelFunc](),
154 }
155}
156
// Run sends call.Prompt (plus attachments) to the large model within the
// call's session and streams the response into the message store.
//
// Behavior:
//   - Returns ErrEmptyPrompt when there is neither a prompt nor a text
//     attachment, and ErrSessionMissing when call.SessionID is empty.
//   - If the session already has a request in flight, the call is appended to
//     the session's queue and Run returns (nil, nil). Queued calls are
//     injected into the next step of the running request (PrepareStep) or,
//     if still pending, run after this request completes.
//   - For the first message of a session, a title is generated concurrently;
//     Run waits for it before returning.
//   - If the context-window stop condition fires, the session is summarized
//     after the stream ends, and the prompt is re-queued when the last
//     assistant message still had tool calls.
//   - On stream errors, unfinished tool calls are closed, missing tool results
//     are filled with error results, and the assistant message is finished
//     with a reason describing the error.
func (a *sessionAgent) Run(ctx context.Context, call SessionAgentCall) (*fantasy.AgentResult, error) {
	if call.Prompt == "" && !message.ContainsTextAttachment(call.Attachments) {
		return nil, ErrEmptyPrompt
	}
	if call.SessionID == "" {
		return nil, ErrSessionMissing
	}

	// Queue the message if busy; the in-flight request (or the tail of this
	// function in that request) picks it up later.
	if a.IsSessionBusy(call.SessionID) {
		existing, ok := a.messageQueue.Get(call.SessionID)
		if !ok {
			existing = []SessionAgentCall{}
		}
		existing = append(existing, call)
		a.messageQueue.Set(call.SessionID, existing)
		return nil, nil
	}

	// Copy mutable fields under lock to avoid races with SetTools/SetModels.
	agentTools := a.tools.Copy()
	largeModel := a.largeModel.Get()
	systemPrompt := a.systemPrompt.Get()
	promptPrefix := a.systemPromptPrefix.Get()
	var instructions strings.Builder

	// Collect instructions advertised by connected MCP servers.
	for _, server := range mcp.GetStates() {
		if server.State != mcp.StateConnected {
			continue
		}
		if s := server.Client.InitializeResult().Instructions; s != "" {
			instructions.WriteString(s)
			instructions.WriteString("\n\n")
		}
	}

	if s := instructions.String(); s != "" {
		systemPrompt += "\n\n<mcp-instructions>\n" + s + "\n</mcp-instructions>"
	}

	if len(agentTools) > 0 {
		// Add Anthropic caching to the last tool.
		agentTools[len(agentTools)-1].SetProviderOptions(a.getCacheControlOptions())
	}

	agent := fantasy.NewAgent(
		largeModel.Model,
		fantasy.WithSystemPrompt(systemPrompt),
		fantasy.WithTools(agentTools...),
		fantasy.WithUserAgent(userAgent),
	)

	// sessionLock serializes the session read-modify-write in OnStepFinish.
	sessionLock := sync.Mutex{}
	currentSession, err := a.sessions.Get(ctx, call.SessionID)
	if err != nil {
		return nil, fmt.Errorf("failed to get session: %w", err)
	}

	msgs, err := a.getSessionMessages(ctx, currentSession)
	if err != nil {
		return nil, fmt.Errorf("failed to get session messages: %w", err)
	}

	var wg sync.WaitGroup
	// Generate title if first message.
	if len(msgs) == 0 {
		titleCtx := ctx // Copy to avoid race with ctx reassignment below.
		wg.Go(func() {
			a.generateTitle(titleCtx, call.SessionID, call.Prompt)
		})
	}
	defer wg.Wait()

	// Add the user message to the session.
	_, err = a.createUserMessage(ctx, call)
	if err != nil {
		return nil, err
	}

	// Add the session to the context.
	ctx = context.WithValue(ctx, tools.SessionIDContextKey, call.SessionID)

	// Register the cancel func so Cancel can stop this request.
	genCtx, cancel := context.WithCancel(ctx)
	a.activeRequests.Set(call.SessionID, cancel)

	defer cancel()
	defer a.activeRequests.Del(call.SessionID)

	history, files := a.preparePrompt(msgs, call.Attachments...)

	startTime := time.Now()
	a.eventPromptSent(call.SessionID)

	// currentAssistant is the assistant message of the current step; it is
	// replaced in PrepareStep and mutated by the streaming callbacks.
	var currentAssistant *message.Message
	var shouldSummarize bool
	// Don't send MaxOutputTokens if 0 — some providers (e.g. LM Studio) reject it
	var maxOutputTokens *int64
	if call.MaxOutputTokens > 0 {
		maxOutputTokens = &call.MaxOutputTokens
	}
	result, err := agent.Stream(genCtx, fantasy.AgentStreamCall{
		Prompt:           message.PromptWithTextAttachments(call.Prompt, call.Attachments),
		Files:            files,
		Messages:         history,
		ProviderOptions:  call.ProviderOptions,
		MaxOutputTokens:  maxOutputTokens,
		TopP:             call.TopP,
		Temperature:      call.Temperature,
		PresencePenalty:  call.PresencePenalty,
		TopK:             call.TopK,
		FrequencyPenalty: call.FrequencyPenalty,
		PrepareStep: func(callContext context.Context, options fantasy.PrepareStepFunctionOptions) (_ context.Context, prepared fantasy.PrepareStepResult, err error) {
			// Reset per-message provider options; cache control is
			// re-applied below.
			prepared.Messages = options.Messages
			for i := range prepared.Messages {
				prepared.Messages[i].ProviderOptions = nil
			}

			// Use latest tools (updated by SetTools when MCP tools change).
			prepared.Tools = a.tools.Copy()

			// Drain prompts queued while this request was running and
			// append them as user messages for this step.
			queuedCalls, _ := a.messageQueue.Get(call.SessionID)
			a.messageQueue.Del(call.SessionID)
			for _, queued := range queuedCalls {
				userMessage, createErr := a.createUserMessage(callContext, queued)
				if createErr != nil {
					return callContext, prepared, createErr
				}
				prepared.Messages = append(prepared.Messages, userMessage.ToAIMessage()...)
			}

			prepared.Messages = a.workaroundProviderMediaLimitations(prepared.Messages, largeModel)

			lastSystemRoleInx := 0
			systemMessageUpdated := false
			for i, msg := range prepared.Messages {
				// At the first non-system message, add cache control to the
				// last system message seen before it (index 0 if none).
				if msg.Role == fantasy.MessageRoleSystem {
					lastSystemRoleInx = i
				} else if !systemMessageUpdated {
					prepared.Messages[lastSystemRoleInx].ProviderOptions = a.getCacheControlOptions()
					systemMessageUpdated = true
				}
				// Then add cache control to the last 2 messages.
				if i > len(prepared.Messages)-3 {
					prepared.Messages[i].ProviderOptions = a.getCacheControlOptions()
				}
			}

			if promptPrefix != "" {
				prepared.Messages = append([]fantasy.Message{fantasy.NewSystemMessage(promptPrefix)}, prepared.Messages...)
			}

			// Create the (initially empty) assistant message for this step
			// and expose its ID and model capabilities to tools.
			var assistantMsg message.Message
			assistantMsg, err = a.messages.Create(callContext, call.SessionID, message.CreateMessageParams{
				Role:     message.Assistant,
				Parts:    []message.ContentPart{},
				Model:    largeModel.ModelCfg.Model,
				Provider: largeModel.ModelCfg.Provider,
			})
			if err != nil {
				return callContext, prepared, err
			}
			callContext = context.WithValue(callContext, tools.MessageIDContextKey, assistantMsg.ID)
			callContext = context.WithValue(callContext, tools.SupportsImagesContextKey, largeModel.CatwalkCfg.SupportsImages)
			callContext = context.WithValue(callContext, tools.ModelNameContextKey, largeModel.CatwalkCfg.Name)
			currentAssistant = &assistantMsg
			return callContext, prepared, err
		},
		OnReasoningStart: func(id string, reasoning fantasy.ReasoningContent) error {
			currentAssistant.AppendReasoningContent(reasoning.Text)
			return a.messages.Update(genCtx, *currentAssistant)
		},
		OnReasoningDelta: func(id string, text string) error {
			currentAssistant.AppendReasoningContent(text)
			return a.messages.Update(genCtx, *currentAssistant)
		},
		OnReasoningEnd: func(id string, reasoning fantasy.ReasoningContent) error {
			// Persist provider-specific reasoning signatures/data.
			// handle anthropic signature
			if anthropicData, ok := reasoning.ProviderMetadata[anthropic.Name]; ok {
				if reasoning, ok := anthropicData.(*anthropic.ReasoningOptionMetadata); ok {
					currentAssistant.AppendReasoningSignature(reasoning.Signature)
				}
			}
			if googleData, ok := reasoning.ProviderMetadata[google.Name]; ok {
				if reasoning, ok := googleData.(*google.ReasoningMetadata); ok {
					currentAssistant.AppendThoughtSignature(reasoning.Signature, reasoning.ToolID)
				}
			}
			if openaiData, ok := reasoning.ProviderMetadata[openai.Name]; ok {
				if reasoning, ok := openaiData.(*openai.ResponsesReasoningMetadata); ok {
					currentAssistant.SetReasoningResponsesData(reasoning)
				}
			}
			currentAssistant.FinishThinking()
			return a.messages.Update(genCtx, *currentAssistant)
		},
		OnTextDelta: func(id string, text string) error {
			// Strip leading newline from initial text content. This is
			// particularly important in non-interactive mode where leading
			// newlines are very visible.
			if len(currentAssistant.Parts) == 0 {
				text = strings.TrimPrefix(text, "\n")
			}

			currentAssistant.AppendContent(text)
			return a.messages.Update(genCtx, *currentAssistant)
		},
		OnToolInputStart: func(id string, toolName string) error {
			toolCall := message.ToolCall{
				ID:               id,
				Name:             toolName,
				ProviderExecuted: false,
				Finished:         false,
			}
			currentAssistant.AddToolCall(toolCall)
			// Use parent ctx instead of genCtx to ensure the update succeeds
			// even if the request is canceled mid-stream
			return a.messages.Update(ctx, *currentAssistant)
		},
		OnRetry: func(err *fantasy.ProviderError, delay time.Duration) {
			// TODO: implement
		},
		OnToolCall: func(tc fantasy.ToolCallContent) error {
			toolCall := message.ToolCall{
				ID:               tc.ToolCallID,
				Name:             tc.ToolName,
				Input:            tc.Input,
				ProviderExecuted: false,
				Finished:         true,
			}
			currentAssistant.AddToolCall(toolCall)
			// Use parent ctx instead of genCtx to ensure the update succeeds
			// even if the request is canceled mid-stream
			return a.messages.Update(ctx, *currentAssistant)
		},
		OnToolResult: func(result fantasy.ToolResultContent) error {
			toolResult := a.convertToToolResult(result)
			// Use parent ctx instead of genCtx to ensure the message is created
			// even if the request is canceled mid-stream
			_, createMsgErr := a.messages.Create(ctx, currentAssistant.SessionID, message.CreateMessageParams{
				Role: message.Tool,
				Parts: []message.ContentPart{
					toolResult,
				},
			})
			return createMsgErr
		},
		OnStepFinish: func(stepResult fantasy.StepResult) error {
			// Map the provider finish reason onto the message finish reason.
			finishReason := message.FinishReasonUnknown
			switch stepResult.FinishReason {
			case fantasy.FinishReasonLength:
				finishReason = message.FinishReasonMaxTokens
			case fantasy.FinishReasonStop:
				finishReason = message.FinishReasonEndTurn
			case fantasy.FinishReasonToolCalls:
				finishReason = message.FinishReasonToolUse
			}
			currentAssistant.AddFinish(finishReason, "", "")
			sessionLock.Lock()
			defer sessionLock.Unlock()

			// Re-read the session so the usage update does not clobber
			// fields changed elsewhere since Run started.
			updatedSession, getSessionErr := a.sessions.Get(ctx, call.SessionID)
			if getSessionErr != nil {
				return getSessionErr
			}
			a.updateSessionUsage(largeModel, &updatedSession, stepResult.Usage, a.openrouterCost(stepResult.ProviderMetadata))
			_, sessionErr := a.sessions.Save(ctx, updatedSession)
			if sessionErr != nil {
				return sessionErr
			}
			currentSession = updatedSession
			return a.messages.Update(genCtx, *currentAssistant)
		},
		StopWhen: []fantasy.StopCondition{
			// Stop (and request summarization) when the remaining context
			// window falls to the threshold.
			func(_ []fantasy.StepResult) bool {
				cw := int64(largeModel.CatwalkCfg.ContextWindow)
				// If context window is unknown (0), skip auto-summarize
				// to avoid immediately truncating custom/local models.
				if cw == 0 {
					return false
				}
				tokens := currentSession.CompletionTokens + currentSession.PromptTokens
				remaining := cw - tokens
				var threshold int64
				if cw > largeContextWindowThreshold {
					threshold = largeContextWindowBuffer
				} else {
					threshold = int64(float64(cw) * smallContextWindowRatio)
				}
				if (remaining <= threshold) && !a.disableAutoSummarize {
					shouldSummarize = true
					return true
				}
				return false
			},
			// Stop when the model keeps repeating the same tool calls.
			func(steps []fantasy.StepResult) bool {
				return hasRepeatedToolCalls(steps, loopDetectionWindowSize, loopDetectionMaxRepeats)
			},
		},
	})

	a.eventPromptResponded(call.SessionID, time.Since(startTime).Truncate(time.Second))

	if err != nil {
		isCancelErr := errors.Is(err, context.Canceled)
		isPermissionErr := errors.Is(err, permission.ErrorPermissionDenied)
		if currentAssistant == nil {
			return result, err
		}
		// Ensure we finish thinking on error to close the reasoning state.
		currentAssistant.FinishThinking()
		toolCalls := currentAssistant.ToolCalls()
		// INFO: we use the parent context here because genCtx may have been
		// cancelled.
		msgs, createErr := a.messages.List(ctx, currentAssistant.SessionID)
		if createErr != nil {
			return nil, createErr
		}
		// Close every unfinished tool call and give each tool call that has
		// no result yet an error result, so the history stays consistent.
		for _, tc := range toolCalls {
			if !tc.Finished {
				tc.Finished = true
				tc.Input = "{}"
				currentAssistant.AddToolCall(tc)
				updateErr := a.messages.Update(ctx, *currentAssistant)
				if updateErr != nil {
					return nil, updateErr
				}
			}

			found := false
			for _, msg := range msgs {
				if msg.Role == message.Tool {
					for _, tr := range msg.ToolResults() {
						if tr.ToolCallID == tc.ID {
							found = true
							break
						}
					}
				}
				if found {
					break
				}
			}
			if found {
				continue
			}
			content := "There was an error while executing the tool"
			if isCancelErr {
				content = "Error: user cancelled assistant tool calling"
			} else if isPermissionErr {
				content = "User denied permission"
			}
			toolResult := message.ToolResult{
				ToolCallID: tc.ID,
				Name:       tc.Name,
				Content:    content,
				IsError:    true,
			}
			_, createErr = a.messages.Create(ctx, currentAssistant.SessionID, message.CreateMessageParams{
				Role: message.Tool,
				Parts: []message.ContentPart{
					toolResult,
				},
			})
			if createErr != nil {
				return nil, createErr
			}
		}
		// Record a finish reason that explains the failure to the user.
		var fantasyErr *fantasy.Error
		var providerErr *fantasy.ProviderError
		const defaultTitle = "Provider Error"
		linkStyle := lipgloss.NewStyle().Foreground(charmtone.Guac).Underline(true)
		if isCancelErr {
			currentAssistant.AddFinish(message.FinishReasonCanceled, "User canceled request", "")
		} else if isPermissionErr {
			currentAssistant.AddFinish(message.FinishReasonPermissionDenied, "User denied permission", "")
		} else if errors.Is(err, hyper.ErrUnauthorized) {
			currentAssistant.AddFinish(message.FinishReasonError, "Unauthorized", `Please re-authenticate with Hyper. You can also run "crush auth" to re-authenticate.`)
			if a.notify != nil {
				a.notify.Publish(pubsub.CreatedEvent, notify.Notification{
					SessionID:    call.SessionID,
					SessionTitle: currentSession.Title,
					Type:         notify.TypeReAuthenticate,
					ProviderID:   largeModel.ModelCfg.Provider,
				})
			}
		} else if errors.Is(err, hyper.ErrNoCredits) {
			url := hyper.BaseURL()
			link := linkStyle.Hyperlink(url, "id=hyper").Render(url)
			currentAssistant.AddFinish(message.FinishReasonError, "No credits", "You're out of credits. Add more at "+link)
		} else if errors.As(err, &providerErr) {
			if providerErr.Message == "The requested model is not supported." {
				url := "https://github.com/settings/copilot/features"
				link := linkStyle.Hyperlink(url, "id=copilot").Render(url)
				currentAssistant.AddFinish(
					message.FinishReasonError,
					"Copilot model not enabled",
					fmt.Sprintf("%q is not enabled in Copilot. Go to the following page to enable it. Then, wait 5 minutes before trying again. %s", largeModel.CatwalkCfg.Name, link),
				)
			} else {
				currentAssistant.AddFinish(message.FinishReasonError, cmp.Or(stringext.Capitalize(providerErr.Title), defaultTitle), providerErr.Message)
			}
		} else if errors.As(err, &fantasyErr) {
			currentAssistant.AddFinish(message.FinishReasonError, cmp.Or(stringext.Capitalize(fantasyErr.Title), defaultTitle), fantasyErr.Message)
		} else {
			currentAssistant.AddFinish(message.FinishReasonError, defaultTitle, err.Error())
		}
		// Note: we use the parent context here because genCtx may have been
		// cancelled.
		updateErr := a.messages.Update(ctx, *currentAssistant)
		if updateErr != nil {
			return nil, updateErr
		}
		return nil, err
	}

	// Send notification that agent has finished its turn (skip for
	// nested/non-interactive sessions).
	if !call.NonInteractive && a.notify != nil {
		a.notify.Publish(pubsub.CreatedEvent, notify.Notification{
			SessionID:    call.SessionID,
			SessionTitle: currentSession.Title,
			Type:         notify.TypeAgentFinished,
		})
	}

	if shouldSummarize {
		// Drop the active entry first; otherwise Summarize would see the
		// session as busy and return ErrSessionBusy.
		a.activeRequests.Del(call.SessionID)
		if summarizeErr := a.Summarize(genCtx, call.SessionID, call.ProviderOptions); summarizeErr != nil {
			return nil, summarizeErr
		}
		// If the agent wasn't done...
		if len(currentAssistant.ToolCalls()) > 0 {
			existing, ok := a.messageQueue.Get(call.SessionID)
			if !ok {
				existing = []SessionAgentCall{}
			}
			call.Prompt = fmt.Sprintf("The previous session was interrupted because it got too long, the initial user request was: `%s`", call.Prompt)
			existing = append(existing, call)
			a.messageQueue.Set(call.SessionID, existing)
		}
	}

	// Release active request before processing queued messages.
	a.activeRequests.Del(call.SessionID)
	cancel()

	queuedMessages, ok := a.messageQueue.Get(call.SessionID)
	if !ok || len(queuedMessages) == 0 {
		return result, err
	}
	// There are queued messages restart the loop.
	firstQueuedMessage := queuedMessages[0]
	a.messageQueue.Set(call.SessionID, queuedMessages[1:])
	return a.Run(ctx, firstQueuedMessage)
}
612
613func (a *sessionAgent) Summarize(ctx context.Context, sessionID string, opts fantasy.ProviderOptions) error {
614 if a.IsSessionBusy(sessionID) {
615 return ErrSessionBusy
616 }
617
618 // Copy mutable fields under lock to avoid races with SetModels.
619 largeModel := a.largeModel.Get()
620 systemPromptPrefix := a.systemPromptPrefix.Get()
621
622 currentSession, err := a.sessions.Get(ctx, sessionID)
623 if err != nil {
624 return fmt.Errorf("failed to get session: %w", err)
625 }
626 msgs, err := a.getSessionMessages(ctx, currentSession)
627 if err != nil {
628 return err
629 }
630 if len(msgs) == 0 {
631 // Nothing to summarize.
632 return nil
633 }
634
635 aiMsgs, _ := a.preparePrompt(msgs)
636
637 genCtx, cancel := context.WithCancel(ctx)
638 a.activeRequests.Set(sessionID, cancel)
639 defer a.activeRequests.Del(sessionID)
640 defer cancel()
641
642 agent := fantasy.NewAgent(largeModel.Model,
643 fantasy.WithSystemPrompt(string(summaryPrompt)),
644 fantasy.WithUserAgent(userAgent),
645 )
646 summaryMessage, err := a.messages.Create(ctx, sessionID, message.CreateMessageParams{
647 Role: message.Assistant,
648 Model: largeModel.Model.Model(),
649 Provider: largeModel.Model.Provider(),
650 IsSummaryMessage: true,
651 })
652 if err != nil {
653 return err
654 }
655
656 summaryPromptText := buildSummaryPrompt(currentSession.Todos)
657
658 resp, err := agent.Stream(genCtx, fantasy.AgentStreamCall{
659 Prompt: summaryPromptText,
660 Messages: aiMsgs,
661 ProviderOptions: opts,
662 PrepareStep: func(callContext context.Context, options fantasy.PrepareStepFunctionOptions) (_ context.Context, prepared fantasy.PrepareStepResult, err error) {
663 prepared.Messages = options.Messages
664 if systemPromptPrefix != "" {
665 prepared.Messages = append([]fantasy.Message{fantasy.NewSystemMessage(systemPromptPrefix)}, prepared.Messages...)
666 }
667 return callContext, prepared, nil
668 },
669 OnReasoningDelta: func(id string, text string) error {
670 summaryMessage.AppendReasoningContent(text)
671 return a.messages.Update(genCtx, summaryMessage)
672 },
673 OnReasoningEnd: func(id string, reasoning fantasy.ReasoningContent) error {
674 // Handle anthropic signature.
675 if anthropicData, ok := reasoning.ProviderMetadata["anthropic"]; ok {
676 if signature, ok := anthropicData.(*anthropic.ReasoningOptionMetadata); ok && signature.Signature != "" {
677 summaryMessage.AppendReasoningSignature(signature.Signature)
678 }
679 }
680 summaryMessage.FinishThinking()
681 return a.messages.Update(genCtx, summaryMessage)
682 },
683 OnTextDelta: func(id, text string) error {
684 summaryMessage.AppendContent(text)
685 return a.messages.Update(genCtx, summaryMessage)
686 },
687 })
688 if err != nil {
689 isCancelErr := errors.Is(err, context.Canceled)
690 if isCancelErr {
691 // User cancelled summarize we need to remove the summary message.
692 deleteErr := a.messages.Delete(ctx, summaryMessage.ID)
693 return deleteErr
694 }
695 return err
696 }
697
698 summaryMessage.AddFinish(message.FinishReasonEndTurn, "", "")
699 err = a.messages.Update(genCtx, summaryMessage)
700 if err != nil {
701 return err
702 }
703
704 var openrouterCost *float64
705 for _, step := range resp.Steps {
706 stepCost := a.openrouterCost(step.ProviderMetadata)
707 if stepCost != nil {
708 newCost := *stepCost
709 if openrouterCost != nil {
710 newCost += *openrouterCost
711 }
712 openrouterCost = &newCost
713 }
714 }
715
716 a.updateSessionUsage(largeModel, ¤tSession, resp.TotalUsage, openrouterCost)
717
718 // Just in case, get just the last usage info.
719 usage := resp.Response.Usage
720 currentSession.SummaryMessageID = summaryMessage.ID
721 currentSession.CompletionTokens = usage.OutputTokens
722 currentSession.PromptTokens = 0
723 _, err = a.sessions.Save(genCtx, currentSession)
724 return err
725}
726
727func (a *sessionAgent) getCacheControlOptions() fantasy.ProviderOptions {
728 if t, _ := strconv.ParseBool(os.Getenv("CRUSH_DISABLE_ANTHROPIC_CACHE")); t {
729 return fantasy.ProviderOptions{}
730 }
731 return fantasy.ProviderOptions{
732 anthropic.Name: &anthropic.ProviderCacheControlOptions{
733 CacheControl: anthropic.CacheControl{Type: "ephemeral"},
734 },
735 bedrock.Name: &anthropic.ProviderCacheControlOptions{
736 CacheControl: anthropic.CacheControl{Type: "ephemeral"},
737 },
738 vercel.Name: &anthropic.ProviderCacheControlOptions{
739 CacheControl: anthropic.CacheControl{Type: "ephemeral"},
740 },
741 }
742}
743
744func (a *sessionAgent) createUserMessage(ctx context.Context, call SessionAgentCall) (message.Message, error) {
745 parts := []message.ContentPart{message.TextContent{Text: call.Prompt}}
746 var attachmentParts []message.ContentPart
747 for _, attachment := range call.Attachments {
748 attachmentParts = append(attachmentParts, message.BinaryContent{Path: attachment.FilePath, MIMEType: attachment.MimeType, Data: attachment.Content})
749 }
750 parts = append(parts, attachmentParts...)
751 msg, err := a.messages.Create(ctx, call.SessionID, message.CreateMessageParams{
752 Role: message.User,
753 Parts: parts,
754 })
755 if err != nil {
756 return message.Message{}, fmt.Errorf("failed to create user message: %w", err)
757 }
758 return msg, nil
759}
760
761func (a *sessionAgent) preparePrompt(msgs []message.Message, attachments ...message.Attachment) ([]fantasy.Message, []fantasy.FilePart) {
762 var history []fantasy.Message
763 if !a.isSubAgent {
764 history = append(history, fantasy.NewUserMessage(
765 fmt.Sprintf("<system_reminder>%s</system_reminder>",
766 `This is a reminder that your todo list is currently empty. DO NOT mention this to the user explicitly because they are already aware.
767If you are working on tasks that would benefit from a todo list please use the "todos" tool to create one.
768If not, please feel free to ignore. Again do not mention this message to the user.`,
769 ),
770 ))
771 }
772 for _, m := range msgs {
773 if len(m.Parts) == 0 {
774 continue
775 }
776 // Assistant message without content or tool calls (cancelled before it
777 // returned anything).
778 if m.Role == message.Assistant && len(m.ToolCalls()) == 0 && m.Content().Text == "" && m.ReasoningContent().String() == "" {
779 continue
780 }
781 history = append(history, m.ToAIMessage()...)
782 }
783
784 var files []fantasy.FilePart
785 for _, attachment := range attachments {
786 if attachment.IsText() {
787 continue
788 }
789 files = append(files, fantasy.FilePart{
790 Filename: attachment.FileName,
791 Data: attachment.Content,
792 MediaType: attachment.MimeType,
793 })
794 }
795
796 return history, files
797}
798
799func (a *sessionAgent) getSessionMessages(ctx context.Context, session session.Session) ([]message.Message, error) {
800 msgs, err := a.messages.List(ctx, session.ID)
801 if err != nil {
802 return nil, fmt.Errorf("failed to list messages: %w", err)
803 }
804
805 if session.SummaryMessageID != "" {
806 summaryMsgIndex := -1
807 for i, msg := range msgs {
808 if msg.ID == session.SummaryMessageID {
809 summaryMsgIndex = i
810 break
811 }
812 }
813 if summaryMsgIndex != -1 {
814 msgs = msgs[summaryMsgIndex:]
815 msgs[0].Role = message.User
816 }
817 }
818 return msgs, nil
819}
820
821// generateTitle generates a session titled based on the initial prompt.
822func (a *sessionAgent) generateTitle(ctx context.Context, sessionID string, userPrompt string) {
823 if userPrompt == "" {
824 return
825 }
826
827 smallModel := a.smallModel.Get()
828 largeModel := a.largeModel.Get()
829 systemPromptPrefix := a.systemPromptPrefix.Get()
830
831 var maxOutputTokens int64 = 40
832 if smallModel.CatwalkCfg.CanReason {
833 maxOutputTokens = smallModel.CatwalkCfg.DefaultMaxTokens
834 }
835
836 newAgent := func(m fantasy.LanguageModel, p []byte, tok int64) fantasy.Agent {
837 return fantasy.NewAgent(m,
838 fantasy.WithSystemPrompt(string(p)+"\n /no_think"),
839 fantasy.WithMaxOutputTokens(tok),
840 fantasy.WithUserAgent(userAgent),
841 )
842 }
843
844 streamCall := fantasy.AgentStreamCall{
845 Prompt: fmt.Sprintf("Generate a concise title for the following content:\n\n%s\n <think>\n\n</think>", userPrompt),
846 PrepareStep: func(callCtx context.Context, opts fantasy.PrepareStepFunctionOptions) (_ context.Context, prepared fantasy.PrepareStepResult, err error) {
847 prepared.Messages = opts.Messages
848 if systemPromptPrefix != "" {
849 prepared.Messages = append([]fantasy.Message{
850 fantasy.NewSystemMessage(systemPromptPrefix),
851 }, prepared.Messages...)
852 }
853 return callCtx, prepared, nil
854 },
855 }
856
857 // Use the small model to generate the title.
858 model := smallModel
859 agent := newAgent(model.Model, titlePrompt, maxOutputTokens)
860 resp, err := agent.Stream(ctx, streamCall)
861 if err == nil {
862 // We successfully generated a title with the small model.
863 slog.Debug("Generated title with small model")
864 } else {
865 // It didn't work. Let's try with the big model.
866 slog.Error("Error generating title with small model; trying big model", "err", err)
867 model = largeModel
868 agent = newAgent(model.Model, titlePrompt, maxOutputTokens)
869 resp, err = agent.Stream(ctx, streamCall)
870 if err == nil {
871 slog.Debug("Generated title with large model")
872 } else {
873 // Welp, the large model didn't work either. Use the default
874 // session name and return.
875 slog.Error("Error generating title with large model", "err", err)
876 saveErr := a.sessions.Rename(ctx, sessionID, DefaultSessionName)
877 if saveErr != nil {
878 slog.Error("Failed to save session title", "error", saveErr)
879 }
880 return
881 }
882 }
883
884 if resp == nil {
885 // Actually, we didn't get a response so we can't. Use the default
886 // session name and return.
887 slog.Error("Response is nil; can't generate title")
888 saveErr := a.sessions.Rename(ctx, sessionID, DefaultSessionName)
889 if saveErr != nil {
890 slog.Error("Failed to save session title", "error", saveErr)
891 }
892 return
893 }
894
895 // Clean up title.
896 var title string
897 title = strings.ReplaceAll(resp.Response.Content.Text(), "\n", " ")
898
899 // Remove thinking tags if present.
900 title = thinkTagRegex.ReplaceAllString(title, "")
901
902 title = strings.TrimSpace(title)
903 title = cmp.Or(title, DefaultSessionName)
904
905 // Calculate usage and cost.
906 var openrouterCost *float64
907 for _, step := range resp.Steps {
908 stepCost := a.openrouterCost(step.ProviderMetadata)
909 if stepCost != nil {
910 newCost := *stepCost
911 if openrouterCost != nil {
912 newCost += *openrouterCost
913 }
914 openrouterCost = &newCost
915 }
916 }
917
918 modelConfig := model.CatwalkCfg
919 cost := modelConfig.CostPer1MInCached/1e6*float64(resp.TotalUsage.CacheCreationTokens) +
920 modelConfig.CostPer1MOutCached/1e6*float64(resp.TotalUsage.CacheReadTokens) +
921 modelConfig.CostPer1MIn/1e6*float64(resp.TotalUsage.InputTokens) +
922 modelConfig.CostPer1MOut/1e6*float64(resp.TotalUsage.OutputTokens)
923
924 // Use override cost if available (e.g., from OpenRouter).
925 if openrouterCost != nil {
926 cost = *openrouterCost
927 }
928
929 promptTokens := resp.TotalUsage.InputTokens + resp.TotalUsage.CacheCreationTokens
930 completionTokens := resp.TotalUsage.OutputTokens
931
932 // Atomically update only title and usage fields to avoid overriding other
933 // concurrent session updates.
934 saveErr := a.sessions.UpdateTitleAndUsage(ctx, sessionID, title, promptTokens, completionTokens, cost)
935 if saveErr != nil {
936 slog.Error("Failed to save session title and usage", "error", saveErr)
937 return
938 }
939}
940
941func (a *sessionAgent) openrouterCost(metadata fantasy.ProviderMetadata) *float64 {
942 openrouterMetadata, ok := metadata[openrouter.Name]
943 if !ok {
944 return nil
945 }
946
947 opts, ok := openrouterMetadata.(*openrouter.ProviderMetadata)
948 if !ok {
949 return nil
950 }
951 return &opts.Usage.Cost
952}
953
954func (a *sessionAgent) updateSessionUsage(model Model, session *session.Session, usage fantasy.Usage, overrideCost *float64) {
955 modelConfig := model.CatwalkCfg
956 cost := modelConfig.CostPer1MInCached/1e6*float64(usage.CacheCreationTokens) +
957 modelConfig.CostPer1MOutCached/1e6*float64(usage.CacheReadTokens) +
958 modelConfig.CostPer1MIn/1e6*float64(usage.InputTokens) +
959 modelConfig.CostPer1MOut/1e6*float64(usage.OutputTokens)
960
961 a.eventTokensUsed(session.ID, model, usage, cost)
962
963 if overrideCost != nil {
964 session.Cost += *overrideCost
965 } else {
966 session.Cost += cost
967 }
968
969 session.CompletionTokens = usage.OutputTokens
970 session.PromptTokens = usage.InputTokens + usage.CacheReadTokens
971}
972
973func (a *sessionAgent) Cancel(sessionID string) {
974 // Cancel regular requests. Don't use Take() here - we need the entry to
975 // remain in activeRequests so IsBusy() returns true until the goroutine
976 // fully completes (including error handling that may access the DB).
977 // The defer in processRequest will clean up the entry.
978 if cancel, ok := a.activeRequests.Get(sessionID); ok && cancel != nil {
979 slog.Debug("Request cancellation initiated", "session_id", sessionID)
980 cancel()
981 }
982
983 // Also check for summarize requests.
984 if cancel, ok := a.activeRequests.Get(sessionID + "-summarize"); ok && cancel != nil {
985 slog.Debug("Summarize cancellation initiated", "session_id", sessionID)
986 cancel()
987 }
988
989 if a.QueuedPrompts(sessionID) > 0 {
990 slog.Debug("Clearing queued prompts", "session_id", sessionID)
991 a.messageQueue.Del(sessionID)
992 }
993}
994
995func (a *sessionAgent) ClearQueue(sessionID string) {
996 if a.QueuedPrompts(sessionID) > 0 {
997 slog.Debug("Clearing queued prompts", "session_id", sessionID)
998 a.messageQueue.Del(sessionID)
999 }
1000}
1001
1002func (a *sessionAgent) CancelAll() {
1003 if !a.IsBusy() {
1004 return
1005 }
1006 for key := range a.activeRequests.Seq2() {
1007 a.Cancel(key) // key is sessionID
1008 }
1009
1010 timeout := time.After(5 * time.Second)
1011 for a.IsBusy() {
1012 select {
1013 case <-timeout:
1014 return
1015 default:
1016 time.Sleep(200 * time.Millisecond)
1017 }
1018 }
1019}
1020
1021func (a *sessionAgent) IsBusy() bool {
1022 var busy bool
1023 for cancelFunc := range a.activeRequests.Seq() {
1024 if cancelFunc != nil {
1025 busy = true
1026 break
1027 }
1028 }
1029 return busy
1030}
1031
1032func (a *sessionAgent) IsSessionBusy(sessionID string) bool {
1033 _, busy := a.activeRequests.Get(sessionID)
1034 return busy
1035}
1036
1037func (a *sessionAgent) QueuedPrompts(sessionID string) int {
1038 l, ok := a.messageQueue.Get(sessionID)
1039 if !ok {
1040 return 0
1041 }
1042 return len(l)
1043}
1044
1045func (a *sessionAgent) QueuedPromptsList(sessionID string) []string {
1046 l, ok := a.messageQueue.Get(sessionID)
1047 if !ok {
1048 return nil
1049 }
1050 prompts := make([]string, len(l))
1051 for i, call := range l {
1052 prompts[i] = call.Prompt
1053 }
1054 return prompts
1055}
1056
1057func (a *sessionAgent) SetModels(large Model, small Model) {
1058 a.largeModel.Set(large)
1059 a.smallModel.Set(small)
1060}
1061
1062func (a *sessionAgent) SetTools(tools []fantasy.AgentTool) {
1063 a.tools.SetSlice(tools)
1064}
1065
1066func (a *sessionAgent) SetSystemPrompt(systemPrompt string) {
1067 a.systemPrompt.Set(systemPrompt)
1068}
1069
1070func (a *sessionAgent) Model() Model {
1071 return a.largeModel.Get()
1072}
1073
1074// convertToToolResult converts a fantasy tool result to a message tool result.
1075func (a *sessionAgent) convertToToolResult(result fantasy.ToolResultContent) message.ToolResult {
1076 baseResult := message.ToolResult{
1077 ToolCallID: result.ToolCallID,
1078 Name: result.ToolName,
1079 Metadata: result.ClientMetadata,
1080 }
1081
1082 switch result.Result.GetType() {
1083 case fantasy.ToolResultContentTypeText:
1084 if r, ok := fantasy.AsToolResultOutputType[fantasy.ToolResultOutputContentText](result.Result); ok {
1085 baseResult.Content = r.Text
1086 }
1087 case fantasy.ToolResultContentTypeError:
1088 if r, ok := fantasy.AsToolResultOutputType[fantasy.ToolResultOutputContentError](result.Result); ok {
1089 baseResult.Content = r.Error.Error()
1090 baseResult.IsError = true
1091 }
1092 case fantasy.ToolResultContentTypeMedia:
1093 if r, ok := fantasy.AsToolResultOutputType[fantasy.ToolResultOutputContentMedia](result.Result); ok {
1094 content := r.Text
1095 if content == "" {
1096 content = fmt.Sprintf("Loaded %s content", r.MediaType)
1097 }
1098 baseResult.Content = content
1099 baseResult.Data = r.Data
1100 baseResult.MIMEType = r.MediaType
1101 }
1102 }
1103
1104 return baseResult
1105}
1106
1107// workaroundProviderMediaLimitations converts media content in tool results to
1108// user messages for providers that don't natively support images in tool results.
1109//
1110// Problem: OpenAI, Google, OpenRouter, and other OpenAI-compatible providers
1111// don't support sending images/media in tool result messages - they only accept
1112// text in tool results. However, they DO support images in user messages.
1113//
1114// If we send media in tool results to these providers, the API returns an error.
1115//
1116// Solution: For these providers, we:
1117// 1. Replace the media in the tool result with a text placeholder
1118// 2. Inject a user message immediately after with the image as a file attachment
1119// 3. This maintains the tool execution flow while working around API limitations
1120//
1121// Anthropic and Bedrock support images natively in tool results, so we skip
1122// this workaround for them.
1123//
1124// Example transformation:
1125//
1126// BEFORE: [tool result: image data]
1127// AFTER: [tool result: "Image loaded - see attached"], [user: image attachment]
1128func (a *sessionAgent) workaroundProviderMediaLimitations(messages []fantasy.Message, largeModel Model) []fantasy.Message {
1129 providerSupportsMedia := largeModel.ModelCfg.Provider == string(catwalk.InferenceProviderAnthropic) ||
1130 largeModel.ModelCfg.Provider == string(catwalk.InferenceProviderBedrock)
1131
1132 if providerSupportsMedia {
1133 return messages
1134 }
1135
1136 convertedMessages := make([]fantasy.Message, 0, len(messages))
1137
1138 for _, msg := range messages {
1139 if msg.Role != fantasy.MessageRoleTool {
1140 convertedMessages = append(convertedMessages, msg)
1141 continue
1142 }
1143
1144 textParts := make([]fantasy.MessagePart, 0, len(msg.Content))
1145 var mediaFiles []fantasy.FilePart
1146
1147 for _, part := range msg.Content {
1148 toolResult, ok := fantasy.AsMessagePart[fantasy.ToolResultPart](part)
1149 if !ok {
1150 textParts = append(textParts, part)
1151 continue
1152 }
1153
1154 if media, ok := fantasy.AsToolResultOutputType[fantasy.ToolResultOutputContentMedia](toolResult.Output); ok {
1155 decoded, err := base64.StdEncoding.DecodeString(media.Data)
1156 if err != nil {
1157 slog.Warn("Failed to decode media data", "error", err)
1158 textParts = append(textParts, part)
1159 continue
1160 }
1161
1162 mediaFiles = append(mediaFiles, fantasy.FilePart{
1163 Data: decoded,
1164 MediaType: media.MediaType,
1165 Filename: fmt.Sprintf("tool-result-%s", toolResult.ToolCallID),
1166 })
1167
1168 textParts = append(textParts, fantasy.ToolResultPart{
1169 ToolCallID: toolResult.ToolCallID,
1170 Output: fantasy.ToolResultOutputContentText{
1171 Text: "[Image/media content loaded - see attached file]",
1172 },
1173 ProviderOptions: toolResult.ProviderOptions,
1174 })
1175 } else {
1176 textParts = append(textParts, part)
1177 }
1178 }
1179
1180 convertedMessages = append(convertedMessages, fantasy.Message{
1181 Role: fantasy.MessageRoleTool,
1182 Content: textParts,
1183 })
1184
1185 if len(mediaFiles) > 0 {
1186 convertedMessages = append(convertedMessages, fantasy.NewUserMessage(
1187 "Here is the media content from the tool result:",
1188 mediaFiles...,
1189 ))
1190 }
1191 }
1192
1193 return convertedMessages
1194}
1195
1196// buildSummaryPrompt constructs the prompt text for session summarization.
1197func buildSummaryPrompt(todos []session.Todo) string {
1198 var sb strings.Builder
1199 sb.WriteString("Provide a detailed summary of our conversation above.")
1200 if len(todos) > 0 {
1201 sb.WriteString("\n\n## Current Todo List\n\n")
1202 for _, t := range todos {
1203 fmt.Fprintf(&sb, "- [%s] %s\n", t.Status, t.Content)
1204 }
1205 sb.WriteString("\nInclude these tasks and their statuses in your summary. ")
1206 sb.WriteString("Instruct the resuming assistant to use the `todos` tool to continue tracking progress on these tasks.")
1207 }
1208 return sb.String()
1209}