1// Package agent is the core orchestration layer for Crush AI agents.
2//
3// It provides session-based AI agent functionality for managing
4// conversations, tool execution, and message handling. It coordinates
5// interactions between language models, messages, sessions, and tools while
6// handling features like automatic summarization, queuing, and token
7// management.
8package agent
9
10import (
11 "cmp"
12 "context"
13 _ "embed"
14 "encoding/base64"
15 "errors"
16 "fmt"
17 "log/slog"
18 "os"
19 "regexp"
20 "strconv"
21 "strings"
22 "sync"
23 "time"
24
25 "charm.land/catwalk/pkg/catwalk"
26 "charm.land/fantasy"
27 "charm.land/fantasy/providers/anthropic"
28 "charm.land/fantasy/providers/bedrock"
29 "charm.land/fantasy/providers/google"
30 "charm.land/fantasy/providers/openai"
31 "charm.land/fantasy/providers/openrouter"
32 "charm.land/fantasy/providers/vercel"
33 "charm.land/lipgloss/v2"
34 "github.com/charmbracelet/crush/internal/agent/hyper"
35 "github.com/charmbracelet/crush/internal/agent/notify"
36 "github.com/charmbracelet/crush/internal/agent/tools"
37 "github.com/charmbracelet/crush/internal/agent/tools/mcp"
38 "github.com/charmbracelet/crush/internal/config"
39 "github.com/charmbracelet/crush/internal/csync"
40 "github.com/charmbracelet/crush/internal/message"
41 "github.com/charmbracelet/crush/internal/permission"
42 "github.com/charmbracelet/crush/internal/pubsub"
43 "github.com/charmbracelet/crush/internal/session"
44 "github.com/charmbracelet/crush/internal/stringext"
45 "github.com/charmbracelet/crush/internal/version"
46 "github.com/charmbracelet/x/exp/charmtone"
47)
48
const (
	// DefaultSessionName is the title a session falls back to when no
	// title could be generated for it.
	DefaultSessionName = "Untitled Session"

	// Constants for auto-summarization thresholds.
	//
	// When the model's context window is larger than
	// largeContextWindowThreshold tokens, summarization is triggered once
	// largeContextWindowBuffer or fewer tokens remain. For smaller windows
	// it is triggered once the remaining tokens drop to
	// smallContextWindowRatio of the window size.
	largeContextWindowThreshold = 200_000
	largeContextWindowBuffer    = 20_000
	smallContextWindowRatio     = 0.2
)
57
// userAgent is the User-Agent string handed to every fantasy agent created in
// this package; it embeds the running Crush version.
var userAgent = fmt.Sprintf("Charm-Crush/%s (https://charm.land/crush)", version.Version)
59
// titlePrompt holds the embedded system prompt used when generating session
// titles.
//
//go:embed templates/title.md
var titlePrompt []byte

// summaryPrompt holds the embedded system prompt used when summarizing a
// session.
//
//go:embed templates/summary.md
var summaryPrompt []byte
65
// Used to remove <think> tags from generated titles.
var (
	// thinkTagRegex matches a complete <think>...</think> pair together
	// with its contents. The (?s) flag lets "." span newlines and the
	// match is non-greedy, so each pair is matched separately.
	thinkTagRegex = regexp.MustCompile(`(?s)<think>.*?</think>`)
	// orphanThinkTagRegex matches a lone opening or closing think tag left
	// behind after the paired tags have been removed.
	orphanThinkTagRegex = regexp.MustCompile(`</?think>`)
)
71
// SessionAgentCall describes a single prompt submitted to a SessionAgent.
type SessionAgentCall struct {
	// SessionID identifies the target session. Run rejects an empty value
	// with ErrSessionMissing.
	SessionID string
	// Prompt is the user's text. Run rejects the call with ErrEmptyPrompt
	// when it is empty and no text attachment is present.
	Prompt string
	// ProviderOptions are forwarded to the model call unchanged.
	ProviderOptions fantasy.ProviderOptions
	// Attachments are stored with the user message; non-text attachments
	// are also sent to the model as files.
	Attachments []message.Attachment
	// MaxOutputTokens caps the output length. A value of 0 (or less) means
	// the limit is not sent to the provider at all.
	MaxOutputTokens int64
	// Optional sampling parameters, forwarded to the model call as-is;
	// nil leaves the corresponding parameter unset.
	Temperature      *float64
	TopP             *float64
	TopK             *int64
	FrequencyPenalty *float64
	PresencePenalty  *float64
	// NonInteractive suppresses the "agent finished" notification that Run
	// otherwise publishes after a successful turn.
	NonInteractive bool
}
85
// SessionAgent is the interface through which callers drive a session-bound
// AI agent: running prompts, reconfiguring it, inspecting and clearing queued
// prompts, cancelling work, and summarizing a session.
//
// NOTE(review): apart from Run and Summarize, the per-method notes below are
// derived from method names and signatures — confirm against the
// implementation.
type SessionAgent interface {
	// Run submits a prompt for the call's session. The sessionAgent
	// implementation returns (nil, nil) when the session is busy and the
	// call was queued instead of executed.
	Run(context.Context, SessionAgentCall) (*fantasy.AgentResult, error)
	// SetModels sets the large and small models.
	SetModels(large Model, small Model)
	// SetTools sets the tools available to the agent.
	SetTools(tools []fantasy.AgentTool)
	// SetSystemPrompt sets the system prompt.
	SetSystemPrompt(systemPrompt string)
	// Cancel targets the given session; CancelAll targets every session.
	Cancel(sessionID string)
	CancelAll()
	// IsSessionBusy reports on a single session, IsBusy on the agent as a
	// whole.
	IsSessionBusy(sessionID string) bool
	IsBusy() bool
	// QueuedPrompts and QueuedPromptsList expose the prompts queued for a
	// session as a count and as a list of strings, respectively.
	QueuedPrompts(sessionID string) int
	QueuedPromptsList(sessionID string) []string
	// ClearQueue clears the queue of the given session.
	ClearQueue(sessionID string)
	// Summarize replaces the working history of the session identified by
	// the string argument with a model-generated summary.
	Summarize(context.Context, string, fantasy.ProviderOptions) error
	// Model returns a Model used by the agent.
	Model() Model
}
101
// Model bundles a language model with the metadata and user configuration
// that describe it.
type Model struct {
	// Model is the fantasy language model that requests are sent to.
	Model fantasy.LanguageModel
	// CatwalkCfg is the catalog metadata for the model; this package reads
	// its name, context window, image support, reasoning capability,
	// default max tokens and pricing.
	CatwalkCfg catwalk.Model
	// ModelCfg is the selected-model configuration; its Model and Provider
	// values are recorded on assistant messages.
	ModelCfg config.SelectedModel
}
107
// sessionAgent is the SessionAgent implementation returned by
// NewSessionAgent.
type sessionAgent struct {
	// Settings that can change while the agent is in use; they are read
	// through csync wrappers and copied at the start of each operation.
	largeModel *csync.Value[Model]
	smallModel *csync.Value[Model]
	// systemPromptPrefix, when non-empty, is prepended to the conversation
	// as its own system message.
	systemPromptPrefix *csync.Value[string]
	systemPrompt       *csync.Value[string]
	tools              *csync.Slice[fantasy.AgentTool]

	// isSubAgent disables the todo-list reminder that is otherwise
	// injected at the start of the prompt history.
	isSubAgent bool
	sessions   session.Service
	messages   message.Service
	// disableAutoSummarize turns off the automatic summarization that Run
	// triggers when the context window is nearly full.
	disableAutoSummarize bool
	// NOTE(review): isYolo is stored but not read in the part of the file
	// reviewed here — confirm where it is consumed.
	isYolo bool
	// notify may be nil; it is checked before every publish.
	notify pubsub.Publisher[notify.Notification]

	// messageQueue holds calls that arrived while their session was busy,
	// keyed by session ID.
	messageQueue *csync.Map[string, []SessionAgentCall]
	// activeRequests holds the cancel function of the in-flight request
	// for each session, keyed by session ID.
	activeRequests *csync.Map[string, context.CancelFunc]
}
125
// SessionAgentOptions carries the dependencies and settings used by
// NewSessionAgent to build a SessionAgent. Each field is copied into the
// field of the same purpose on the resulting agent.
type SessionAgentOptions struct {
	// LargeModel and SmallModel are the initial models.
	LargeModel Model
	SmallModel Model
	// SystemPromptPrefix and SystemPrompt are the initial prompt strings.
	SystemPromptPrefix string
	SystemPrompt       string
	// IsSubAgent marks the agent as a sub-agent.
	IsSubAgent bool
	// DisableAutoSummarize turns off automatic summarization.
	DisableAutoSummarize bool
	IsYolo               bool
	// Sessions and Messages are the persistence services the agent uses.
	Sessions session.Service
	Messages message.Service
	// Tools is the initial tool set.
	Tools []fantasy.AgentTool
	// Notify receives agent notifications; it may be nil.
	Notify pubsub.Publisher[notify.Notification]
}
139
140func NewSessionAgent(
141 opts SessionAgentOptions,
142) SessionAgent {
143 return &sessionAgent{
144 largeModel: csync.NewValue(opts.LargeModel),
145 smallModel: csync.NewValue(opts.SmallModel),
146 systemPromptPrefix: csync.NewValue(opts.SystemPromptPrefix),
147 systemPrompt: csync.NewValue(opts.SystemPrompt),
148 isSubAgent: opts.IsSubAgent,
149 sessions: opts.Sessions,
150 messages: opts.Messages,
151 disableAutoSummarize: opts.DisableAutoSummarize,
152 tools: csync.NewSliceFrom(opts.Tools),
153 isYolo: opts.IsYolo,
154 notify: opts.Notify,
155 messageQueue: csync.NewMap[string, []SessionAgentCall](),
156 activeRequests: csync.NewMap[string, context.CancelFunc](),
157 }
158}
159
// Run executes one agent turn for call.SessionID.
//
// The call is rejected with ErrEmptyPrompt when it has neither a prompt nor a
// text attachment, and with ErrSessionMissing when it has no session ID. If
// the session already has a request in flight the call is appended to the
// session's queue and Run returns (nil, nil).
//
// Otherwise Run stores the user message, streams the model response while
// persisting assistant/tool messages and session usage through the stream
// callbacks, and then:
//   - on error, closes any dangling tool calls, records a finish reason on
//     the assistant message and returns the error;
//   - when the context window is nearly exhausted, summarizes the session
//     and, if the assistant still had tool calls, re-queues the request;
//   - if prompts are queued for the session, runs the first of them
//     recursively and returns its result.
func (a *sessionAgent) Run(ctx context.Context, call SessionAgentCall) (*fantasy.AgentResult, error) {
	if call.Prompt == "" && !message.ContainsTextAttachment(call.Attachments) {
		return nil, ErrEmptyPrompt
	}
	if call.SessionID == "" {
		return nil, ErrSessionMissing
	}

	// Queue the message if busy
	if a.IsSessionBusy(call.SessionID) {
		existing, ok := a.messageQueue.Get(call.SessionID)
		if !ok {
			existing = []SessionAgentCall{}
		}
		existing = append(existing, call)
		a.messageQueue.Set(call.SessionID, existing)
		return nil, nil
	}

	// Copy mutable fields under lock to avoid races with SetTools/SetModels.
	agentTools := a.tools.Copy()
	largeModel := a.largeModel.Get()
	systemPrompt := a.systemPrompt.Get()
	promptPrefix := a.systemPromptPrefix.Get()
	var instructions strings.Builder

	// Collect the instructions advertised by every connected MCP server.
	for _, server := range mcp.GetStates() {
		if server.State != mcp.StateConnected {
			continue
		}
		if s := server.Client.InitializeResult().Instructions; s != "" {
			instructions.WriteString(s)
			instructions.WriteString("\n\n")
		}
	}

	// Append the collected MCP instructions to the system prompt.
	if s := instructions.String(); s != "" {
		systemPrompt += "\n\n<mcp-instructions>\n" + s + "\n</mcp-instructions>"
	}

	if len(agentTools) > 0 {
		// Add Anthropic caching to the last tool.
		agentTools[len(agentTools)-1].SetProviderOptions(a.getCacheControlOptions())
	}

	agent := fantasy.NewAgent(
		largeModel.Model,
		fantasy.WithSystemPrompt(systemPrompt),
		fantasy.WithTools(agentTools...),
		fantasy.WithUserAgent(userAgent),
	)

	// sessionLock guards the session read-modify-write in OnStepFinish.
	sessionLock := sync.Mutex{}
	currentSession, err := a.sessions.Get(ctx, call.SessionID)
	if err != nil {
		return nil, fmt.Errorf("failed to get session: %w", err)
	}

	msgs, err := a.getSessionMessages(ctx, currentSession)
	if err != nil {
		return nil, fmt.Errorf("failed to get session messages: %w", err)
	}

	var wg sync.WaitGroup
	// Generate title if first message.
	if len(msgs) == 0 {
		titleCtx := ctx // Copy to avoid race with ctx reassignment below.
		wg.Go(func() {
			a.generateTitle(titleCtx, call.SessionID, call.Prompt)
		})
	}
	// Run does not return until title generation has finished.
	defer wg.Wait()

	// Add the user message to the session.
	_, err = a.createUserMessage(ctx, call)
	if err != nil {
		return nil, err
	}

	// Add the session to the context.
	ctx = context.WithValue(ctx, tools.SessionIDContextKey, call.SessionID)

	// Registering the cancel function is what marks the request as active
	// and lets it be cancelled from outside.
	genCtx, cancel := context.WithCancel(ctx)
	a.activeRequests.Set(call.SessionID, cancel)

	defer cancel()
	defer a.activeRequests.Del(call.SessionID)

	history, files := a.preparePrompt(msgs, call.Attachments...)

	startTime := time.Now()
	a.eventPromptSent(call.SessionID)

	// currentAssistant is set by PrepareStep and updated by the streaming
	// callbacks; shouldSummarize is set by the first stop condition.
	var currentAssistant *message.Message
	var shouldSummarize bool
	// Don't send MaxOutputTokens if 0 — some providers (e.g. LM Studio) reject it
	var maxOutputTokens *int64
	if call.MaxOutputTokens > 0 {
		maxOutputTokens = &call.MaxOutputTokens
	}
	result, err := agent.Stream(genCtx, fantasy.AgentStreamCall{
		Prompt:           message.PromptWithTextAttachments(call.Prompt, call.Attachments),
		Files:            files,
		Messages:         history,
		ProviderOptions:  call.ProviderOptions,
		MaxOutputTokens:  maxOutputTokens,
		TopP:             call.TopP,
		Temperature:      call.Temperature,
		PresencePenalty:  call.PresencePenalty,
		TopK:             call.TopK,
		FrequencyPenalty: call.FrequencyPenalty,
		// PrepareStep rebuilds the message list for a step: it folds in
		// prompts queued meanwhile, applies cache-control markers, and
		// creates the assistant message the other callbacks write to.
		PrepareStep: func(callContext context.Context, options fantasy.PrepareStepFunctionOptions) (_ context.Context, prepared fantasy.PrepareStepResult, err error) {
			prepared.Messages = options.Messages
			// Clear provider options; the cache-control markers are
			// re-applied below.
			for i := range prepared.Messages {
				prepared.Messages[i].ProviderOptions = nil
			}

			// Use latest tools (updated by SetTools when MCP tools change).
			prepared.Tools = a.tools.Copy()

			// Drain the session's queue: each queued prompt is stored as a
			// user message and appended to this step's conversation.
			queuedCalls, _ := a.messageQueue.Get(call.SessionID)
			a.messageQueue.Del(call.SessionID)
			for _, queued := range queuedCalls {
				userMessage, createErr := a.createUserMessage(callContext, queued)
				if createErr != nil {
					return callContext, prepared, createErr
				}
				prepared.Messages = append(prepared.Messages, userMessage.ToAIMessage()...)
			}

			prepared.Messages = a.workaroundProviderMediaLimitations(prepared.Messages, largeModel)

			lastSystemRoleInx := 0
			systemMessageUpdated := false
			for i, msg := range prepared.Messages {
				// Add cache control to the last system message seen before
				// the first non-system message (index 0 if none was seen).
				if msg.Role == fantasy.MessageRoleSystem {
					lastSystemRoleInx = i
				} else if !systemMessageUpdated {
					prepared.Messages[lastSystemRoleInx].ProviderOptions = a.getCacheControlOptions()
					systemMessageUpdated = true
				}
				// Then add cache control to the last 2 messages.
				if i > len(prepared.Messages)-3 {
					prepared.Messages[i].ProviderOptions = a.getCacheControlOptions()
				}
			}

			// The prefix is prepended after the cache markers were applied,
			// so it never carries one itself.
			if promptPrefix != "" {
				prepared.Messages = append([]fantasy.Message{fantasy.NewSystemMessage(promptPrefix)}, prepared.Messages...)
			}

			var assistantMsg message.Message
			assistantMsg, err = a.messages.Create(callContext, call.SessionID, message.CreateMessageParams{
				Role:     message.Assistant,
				Parts:    []message.ContentPart{},
				Model:    largeModel.ModelCfg.Model,
				Provider: largeModel.ModelCfg.Provider,
			})
			if err != nil {
				return callContext, prepared, err
			}
			// Expose the assistant message ID and model capabilities to
			// tools through the context.
			callContext = context.WithValue(callContext, tools.MessageIDContextKey, assistantMsg.ID)
			callContext = context.WithValue(callContext, tools.SupportsImagesContextKey, largeModel.CatwalkCfg.SupportsImages)
			callContext = context.WithValue(callContext, tools.ModelNameContextKey, largeModel.CatwalkCfg.Name)
			currentAssistant = &assistantMsg
			return callContext, prepared, err
		},
		OnReasoningStart: func(id string, reasoning fantasy.ReasoningContent) error {
			currentAssistant.AppendReasoningContent(reasoning.Text)
			return a.messages.Update(genCtx, *currentAssistant)
		},
		OnReasoningDelta: func(id string, text string) error {
			currentAssistant.AppendReasoningContent(text)
			return a.messages.Update(genCtx, *currentAssistant)
		},
		OnReasoningEnd: func(id string, reasoning fantasy.ReasoningContent) error {
			// handle anthropic signature
			if anthropicData, ok := reasoning.ProviderMetadata[anthropic.Name]; ok {
				if reasoning, ok := anthropicData.(*anthropic.ReasoningOptionMetadata); ok {
					currentAssistant.AppendReasoningSignature(reasoning.Signature)
				}
			}
			// handle google thought signature
			if googleData, ok := reasoning.ProviderMetadata[google.Name]; ok {
				if reasoning, ok := googleData.(*google.ReasoningMetadata); ok {
					currentAssistant.AppendThoughtSignature(reasoning.Signature, reasoning.ToolID)
				}
			}
			// handle openai responses reasoning data
			if openaiData, ok := reasoning.ProviderMetadata[openai.Name]; ok {
				if reasoning, ok := openaiData.(*openai.ResponsesReasoningMetadata); ok {
					currentAssistant.SetReasoningResponsesData(reasoning)
				}
			}
			currentAssistant.FinishThinking()
			return a.messages.Update(genCtx, *currentAssistant)
		},
		OnTextDelta: func(id string, text string) error {
			// Strip leading newline from initial text content. This is
			// particularly important in non-interactive mode where leading
			// newlines are very visible.
			if len(currentAssistant.Parts) == 0 {
				text = strings.TrimPrefix(text, "\n")
			}

			currentAssistant.AppendContent(text)
			return a.messages.Update(genCtx, *currentAssistant)
		},
		OnToolInputStart: func(id string, toolName string) error {
			// Record the tool call as unfinished; OnToolCall adds the
			// finished version once the input is complete.
			toolCall := message.ToolCall{
				ID:               id,
				Name:             toolName,
				ProviderExecuted: false,
				Finished:         false,
			}
			currentAssistant.AddToolCall(toolCall)
			// Use parent ctx instead of genCtx to ensure the update succeeds
			// even if the request is canceled mid-stream
			return a.messages.Update(ctx, *currentAssistant)
		},
		OnRetry: func(err *fantasy.ProviderError, delay time.Duration) {
			// TODO: implement
		},
		OnToolCall: func(tc fantasy.ToolCallContent) error {
			toolCall := message.ToolCall{
				ID:               tc.ToolCallID,
				Name:             tc.ToolName,
				Input:            tc.Input,
				ProviderExecuted: false,
				Finished:         true,
			}
			currentAssistant.AddToolCall(toolCall)
			// Use parent ctx instead of genCtx to ensure the update succeeds
			// even if the request is canceled mid-stream
			return a.messages.Update(ctx, *currentAssistant)
		},
		OnToolResult: func(result fantasy.ToolResultContent) error {
			toolResult := a.convertToToolResult(result)
			// Use parent ctx instead of genCtx to ensure the message is created
			// even if the request is canceled mid-stream
			_, createMsgErr := a.messages.Create(ctx, currentAssistant.SessionID, message.CreateMessageParams{
				Role: message.Tool,
				Parts: []message.ContentPart{
					toolResult,
				},
			})
			return createMsgErr
		},
		OnStepFinish: func(stepResult fantasy.StepResult) error {
			// Map the provider finish reason onto the message finish
			// reason; anything unrecognized stays "unknown".
			finishReason := message.FinishReasonUnknown
			switch stepResult.FinishReason {
			case fantasy.FinishReasonLength:
				finishReason = message.FinishReasonMaxTokens
			case fantasy.FinishReasonStop:
				finishReason = message.FinishReasonEndTurn
			case fantasy.FinishReasonToolCalls:
				finishReason = message.FinishReasonToolUse
			}
			currentAssistant.AddFinish(finishReason, "", "")
			sessionLock.Lock()
			defer sessionLock.Unlock()

			// Reload the session, add this step's usage and cost, and save
			// it before refreshing the local copy used by the stop
			// condition below.
			updatedSession, getSessionErr := a.sessions.Get(ctx, call.SessionID)
			if getSessionErr != nil {
				return getSessionErr
			}
			a.updateSessionUsage(largeModel, &updatedSession, stepResult.Usage, a.openrouterCost(stepResult.ProviderMetadata))
			_, sessionErr := a.sessions.Save(ctx, updatedSession)
			if sessionErr != nil {
				return sessionErr
			}
			currentSession = updatedSession
			return a.messages.Update(genCtx, *currentAssistant)
		},
		StopWhen: []fantasy.StopCondition{
			// Stop (and request summarization) when the remaining context
			// window falls to the configured threshold.
			func(_ []fantasy.StepResult) bool {
				cw := int64(largeModel.CatwalkCfg.ContextWindow)
				// If context window is unknown (0), skip auto-summarize
				// to avoid immediately truncating custom/local models.
				if cw == 0 {
					return false
				}
				tokens := currentSession.CompletionTokens + currentSession.PromptTokens
				remaining := cw - tokens
				var threshold int64
				if cw > largeContextWindowThreshold {
					threshold = largeContextWindowBuffer
				} else {
					threshold = int64(float64(cw) * smallContextWindowRatio)
				}
				if (remaining <= threshold) && !a.disableAutoSummarize {
					shouldSummarize = true
					return true
				}
				return false
			},
			// Stop when the recent steps keep repeating the same tool calls.
			func(steps []fantasy.StepResult) bool {
				return hasRepeatedToolCalls(steps, loopDetectionWindowSize, loopDetectionMaxRepeats)
			},
		},
	})

	a.eventPromptResponded(call.SessionID, time.Since(startTime).Truncate(time.Second))

	if err != nil {
		isCancelErr := errors.Is(err, context.Canceled)
		isPermissionErr := errors.Is(err, permission.ErrorPermissionDenied)
		// The stream failed before any assistant message was created, so
		// there is nothing to clean up.
		if currentAssistant == nil {
			return result, err
		}
		// Ensure we finish thinking on error to close the reasoning state.
		currentAssistant.FinishThinking()
		toolCalls := currentAssistant.ToolCalls()
		// INFO: we use the parent context here because the genCtx has been cancelled.
		msgs, createErr := a.messages.List(ctx, currentAssistant.SessionID)
		if createErr != nil {
			return nil, createErr
		}
		// Every tool call must end up finished and paired with a tool
		// result; synthesize an error result where one is missing.
		for _, tc := range toolCalls {
			if !tc.Finished {
				tc.Finished = true
				tc.Input = "{}"
				currentAssistant.AddToolCall(tc)
				updateErr := a.messages.Update(ctx, *currentAssistant)
				if updateErr != nil {
					return nil, updateErr
				}
			}

			// Look for an existing tool result for this call.
			found := false
			for _, msg := range msgs {
				if msg.Role == message.Tool {
					for _, tr := range msg.ToolResults() {
						if tr.ToolCallID == tc.ID {
							found = true
							break
						}
					}
				}
				if found {
					break
				}
			}
			if found {
				continue
			}
			content := "There was an error while executing the tool"
			if isCancelErr {
				content = "Error: user cancelled assistant tool calling"
			} else if isPermissionErr {
				content = "User denied permission"
			}
			toolResult := message.ToolResult{
				ToolCallID: tc.ID,
				Name:       tc.Name,
				Content:    content,
				IsError:    true,
			}
			_, createErr = a.messages.Create(ctx, currentAssistant.SessionID, message.CreateMessageParams{
				Role: message.Tool,
				Parts: []message.ContentPart{
					toolResult,
				},
			})
			if createErr != nil {
				return nil, createErr
			}
		}
		// Record a finish reason, title and details that match the kind of
		// error, from most to least specific.
		var fantasyErr *fantasy.Error
		var providerErr *fantasy.ProviderError
		const defaultTitle = "Provider Error"
		linkStyle := lipgloss.NewStyle().Foreground(charmtone.Guac).Underline(true)
		if isCancelErr {
			currentAssistant.AddFinish(message.FinishReasonCanceled, "User canceled request", "")
		} else if isPermissionErr {
			currentAssistant.AddFinish(message.FinishReasonPermissionDenied, "User denied permission", "")
		} else if errors.Is(err, hyper.ErrUnauthorized) {
			currentAssistant.AddFinish(message.FinishReasonError, "Unauthorized", `Please re-authenticate with Hyper. You can also run "crush auth" to re-authenticate.`)
			if a.notify != nil {
				a.notify.Publish(pubsub.CreatedEvent, notify.Notification{
					SessionID:    call.SessionID,
					SessionTitle: currentSession.Title,
					Type:         notify.TypeReAuthenticate,
					ProviderID:   largeModel.ModelCfg.Provider,
				})
			}
		} else if errors.Is(err, hyper.ErrNoCredits) {
			url := hyper.BaseURL()
			link := linkStyle.Hyperlink(url, "id=hyper").Render(url)
			currentAssistant.AddFinish(message.FinishReasonError, "No credits", "You're out of credits. Add more at "+link)
		} else if errors.As(err, &providerErr) {
			// This exact provider message is treated as "model not enabled
			// in Copilot" and gets a dedicated hint.
			if providerErr.Message == "The requested model is not supported." {
				url := "https://github.com/settings/copilot/features"
				link := linkStyle.Hyperlink(url, "id=copilot").Render(url)
				currentAssistant.AddFinish(
					message.FinishReasonError,
					"Copilot model not enabled",
					fmt.Sprintf("%q is not enabled in Copilot. Go to the following page to enable it. Then, wait 5 minutes before trying again. %s", largeModel.CatwalkCfg.Name, link),
				)
			} else {
				currentAssistant.AddFinish(message.FinishReasonError, cmp.Or(stringext.Capitalize(providerErr.Title), defaultTitle), providerErr.Message)
			}
		} else if errors.As(err, &fantasyErr) {
			currentAssistant.AddFinish(message.FinishReasonError, cmp.Or(stringext.Capitalize(fantasyErr.Title), defaultTitle), fantasyErr.Message)
		} else {
			currentAssistant.AddFinish(message.FinishReasonError, defaultTitle, err.Error())
		}
		// Note: we use the parent context here because the genCtx has been
		// cancelled.
		updateErr := a.messages.Update(ctx, *currentAssistant)
		if updateErr != nil {
			return nil, updateErr
		}
		return nil, err
	}

	// Send notification that agent has finished its turn (skip for
	// nested/non-interactive sessions).
	if !call.NonInteractive && a.notify != nil {
		a.notify.Publish(pubsub.CreatedEvent, notify.Notification{
			SessionID:    call.SessionID,
			SessionTitle: currentSession.Title,
			Type:         notify.TypeAgentFinished,
		})
	}

	if shouldSummarize {
		// Release the session first: Summarize refuses to run while the
		// session is busy.
		a.activeRequests.Del(call.SessionID)
		if summarizeErr := a.Summarize(genCtx, call.SessionID, call.ProviderOptions); summarizeErr != nil {
			return nil, summarizeErr
		}
		// If the agent wasn't done...
		if len(currentAssistant.ToolCalls()) > 0 {
			// ...re-queue the original request so the work resumes after
			// the summary.
			existing, ok := a.messageQueue.Get(call.SessionID)
			if !ok {
				existing = []SessionAgentCall{}
			}
			call.Prompt = fmt.Sprintf("The previous session was interrupted because it got too long, the initial user request was: `%s`", call.Prompt)
			existing = append(existing, call)
			a.messageQueue.Set(call.SessionID, existing)
		}
	}

	// Release active request before processing queued messages.
	a.activeRequests.Del(call.SessionID)
	cancel()

	queuedMessages, ok := a.messageQueue.Get(call.SessionID)
	if !ok || len(queuedMessages) == 0 {
		return result, err
	}
	// There are queued messages restart the loop.
	firstQueuedMessage := queuedMessages[0]
	a.messageQueue.Set(call.SessionID, queuedMessages[1:])
	return a.Run(ctx, firstQueuedMessage)
}
615
616func (a *sessionAgent) Summarize(ctx context.Context, sessionID string, opts fantasy.ProviderOptions) error {
617 if a.IsSessionBusy(sessionID) {
618 return ErrSessionBusy
619 }
620
621 // Copy mutable fields under lock to avoid races with SetModels.
622 largeModel := a.largeModel.Get()
623 systemPromptPrefix := a.systemPromptPrefix.Get()
624
625 currentSession, err := a.sessions.Get(ctx, sessionID)
626 if err != nil {
627 return fmt.Errorf("failed to get session: %w", err)
628 }
629 msgs, err := a.getSessionMessages(ctx, currentSession)
630 if err != nil {
631 return err
632 }
633 if len(msgs) == 0 {
634 // Nothing to summarize.
635 return nil
636 }
637
638 aiMsgs, _ := a.preparePrompt(msgs)
639
640 genCtx, cancel := context.WithCancel(ctx)
641 a.activeRequests.Set(sessionID, cancel)
642 defer a.activeRequests.Del(sessionID)
643 defer cancel()
644
645 agent := fantasy.NewAgent(largeModel.Model,
646 fantasy.WithSystemPrompt(string(summaryPrompt)),
647 fantasy.WithUserAgent(userAgent),
648 )
649 summaryMessage, err := a.messages.Create(ctx, sessionID, message.CreateMessageParams{
650 Role: message.Assistant,
651 Model: largeModel.Model.Model(),
652 Provider: largeModel.Model.Provider(),
653 IsSummaryMessage: true,
654 })
655 if err != nil {
656 return err
657 }
658
659 summaryPromptText := buildSummaryPrompt(currentSession.Todos)
660
661 resp, err := agent.Stream(genCtx, fantasy.AgentStreamCall{
662 Prompt: summaryPromptText,
663 Messages: aiMsgs,
664 ProviderOptions: opts,
665 PrepareStep: func(callContext context.Context, options fantasy.PrepareStepFunctionOptions) (_ context.Context, prepared fantasy.PrepareStepResult, err error) {
666 prepared.Messages = options.Messages
667 if systemPromptPrefix != "" {
668 prepared.Messages = append([]fantasy.Message{fantasy.NewSystemMessage(systemPromptPrefix)}, prepared.Messages...)
669 }
670 return callContext, prepared, nil
671 },
672 OnReasoningDelta: func(id string, text string) error {
673 summaryMessage.AppendReasoningContent(text)
674 return a.messages.Update(genCtx, summaryMessage)
675 },
676 OnReasoningEnd: func(id string, reasoning fantasy.ReasoningContent) error {
677 // Handle anthropic signature.
678 if anthropicData, ok := reasoning.ProviderMetadata["anthropic"]; ok {
679 if signature, ok := anthropicData.(*anthropic.ReasoningOptionMetadata); ok && signature.Signature != "" {
680 summaryMessage.AppendReasoningSignature(signature.Signature)
681 }
682 }
683 summaryMessage.FinishThinking()
684 return a.messages.Update(genCtx, summaryMessage)
685 },
686 OnTextDelta: func(id, text string) error {
687 summaryMessage.AppendContent(text)
688 return a.messages.Update(genCtx, summaryMessage)
689 },
690 })
691 if err != nil {
692 isCancelErr := errors.Is(err, context.Canceled)
693 if isCancelErr {
694 // User cancelled summarize we need to remove the summary message.
695 deleteErr := a.messages.Delete(ctx, summaryMessage.ID)
696 return deleteErr
697 }
698 return err
699 }
700
701 summaryMessage.AddFinish(message.FinishReasonEndTurn, "", "")
702 err = a.messages.Update(genCtx, summaryMessage)
703 if err != nil {
704 return err
705 }
706
707 var openrouterCost *float64
708 for _, step := range resp.Steps {
709 stepCost := a.openrouterCost(step.ProviderMetadata)
710 if stepCost != nil {
711 newCost := *stepCost
712 if openrouterCost != nil {
713 newCost += *openrouterCost
714 }
715 openrouterCost = &newCost
716 }
717 }
718
719 a.updateSessionUsage(largeModel, ¤tSession, resp.TotalUsage, openrouterCost)
720
721 // Just in case, get just the last usage info.
722 usage := resp.Response.Usage
723 currentSession.SummaryMessageID = summaryMessage.ID
724 currentSession.CompletionTokens = usage.OutputTokens
725 currentSession.PromptTokens = 0
726 _, err = a.sessions.Save(genCtx, currentSession)
727 return err
728}
729
730func (a *sessionAgent) getCacheControlOptions() fantasy.ProviderOptions {
731 if t, _ := strconv.ParseBool(os.Getenv("CRUSH_DISABLE_ANTHROPIC_CACHE")); t {
732 return fantasy.ProviderOptions{}
733 }
734 return fantasy.ProviderOptions{
735 anthropic.Name: &anthropic.ProviderCacheControlOptions{
736 CacheControl: anthropic.CacheControl{Type: "ephemeral"},
737 },
738 bedrock.Name: &anthropic.ProviderCacheControlOptions{
739 CacheControl: anthropic.CacheControl{Type: "ephemeral"},
740 },
741 vercel.Name: &anthropic.ProviderCacheControlOptions{
742 CacheControl: anthropic.CacheControl{Type: "ephemeral"},
743 },
744 }
745}
746
747func (a *sessionAgent) createUserMessage(ctx context.Context, call SessionAgentCall) (message.Message, error) {
748 parts := []message.ContentPart{message.TextContent{Text: call.Prompt}}
749 var attachmentParts []message.ContentPart
750 for _, attachment := range call.Attachments {
751 attachmentParts = append(attachmentParts, message.BinaryContent{Path: attachment.FilePath, MIMEType: attachment.MimeType, Data: attachment.Content})
752 }
753 parts = append(parts, attachmentParts...)
754 msg, err := a.messages.Create(ctx, call.SessionID, message.CreateMessageParams{
755 Role: message.User,
756 Parts: parts,
757 })
758 if err != nil {
759 return message.Message{}, fmt.Errorf("failed to create user message: %w", err)
760 }
761 return msg, nil
762}
763
764func (a *sessionAgent) preparePrompt(msgs []message.Message, attachments ...message.Attachment) ([]fantasy.Message, []fantasy.FilePart) {
765 var history []fantasy.Message
766 if !a.isSubAgent {
767 history = append(history, fantasy.NewUserMessage(
768 fmt.Sprintf("<system_reminder>%s</system_reminder>",
769 `This is a reminder that your todo list is currently empty. DO NOT mention this to the user explicitly because they are already aware.
770If you are working on tasks that would benefit from a todo list please use the "todos" tool to create one.
771If not, please feel free to ignore. Again do not mention this message to the user.`,
772 ),
773 ))
774 }
775 // Collect all tool call IDs present in assistant messages.
776 knownToolCallIDs := make(map[string]struct{})
777 for _, m := range msgs {
778 if m.Role == message.Assistant {
779 for _, tc := range m.ToolCalls() {
780 knownToolCallIDs[tc.ID] = struct{}{}
781 }
782 }
783 }
784
785 for _, m := range msgs {
786 if len(m.Parts) == 0 {
787 continue
788 }
789 // Assistant message without content or tool calls (cancelled before it
790 // returned anything).
791 if m.Role == message.Assistant && len(m.ToolCalls()) == 0 && m.Content().Text == "" && m.ReasoningContent().String() == "" {
792 continue
793 }
794 if m.Role == message.Tool {
795 if msg, ok := filterOrphanedToolResults(m, knownToolCallIDs); ok {
796 history = append(history, msg)
797 }
798 continue
799 }
800 history = append(history, m.ToAIMessage()...)
801 }
802
803 var files []fantasy.FilePart
804 for _, attachment := range attachments {
805 if attachment.IsText() {
806 continue
807 }
808 files = append(files, fantasy.FilePart{
809 Filename: attachment.FileName,
810 Data: attachment.Content,
811 MediaType: attachment.MimeType,
812 })
813 }
814
815 return history, files
816}
817
818// filterOrphanedToolResults converts a tool message to a fantasy.Message,
819// dropping any tool result parts whose tool_call_id has no matching tool call
820// in the known set. An orphaned result causes API validation to fail on every
821// subsequent turn, permanently locking the session. Returns the filtered
822// message and true if at least one valid part remains.
823func filterOrphanedToolResults(m message.Message, knownToolCallIDs map[string]struct{}) (fantasy.Message, bool) {
824 aiMsgs := m.ToAIMessage()
825 if len(aiMsgs) == 0 {
826 return fantasy.Message{}, false
827 }
828 var validParts []fantasy.MessagePart
829 for _, part := range aiMsgs[0].Content {
830 tr, ok := fantasy.AsMessagePart[fantasy.ToolResultPart](part)
831 if !ok {
832 validParts = append(validParts, part)
833 continue
834 }
835 if _, known := knownToolCallIDs[tr.ToolCallID]; known {
836 validParts = append(validParts, part)
837 } else {
838 slog.Warn("Dropping orphaned tool result with no matching tool call",
839 "tool_call_id", tr.ToolCallID,
840 )
841 }
842 }
843 if len(validParts) == 0 {
844 return fantasy.Message{}, false
845 }
846 msg := aiMsgs[0]
847 msg.Content = validParts
848 return msg, true
849}
850
851func (a *sessionAgent) getSessionMessages(ctx context.Context, session session.Session) ([]message.Message, error) {
852 msgs, err := a.messages.List(ctx, session.ID)
853 if err != nil {
854 return nil, fmt.Errorf("failed to list messages: %w", err)
855 }
856
857 if session.SummaryMessageID != "" {
858 summaryMsgIndex := -1
859 for i, msg := range msgs {
860 if msg.ID == session.SummaryMessageID {
861 summaryMsgIndex = i
862 break
863 }
864 }
865 if summaryMsgIndex != -1 {
866 msgs = msgs[summaryMsgIndex:]
867 msgs[0].Role = message.User
868 }
869 }
870 return msgs, nil
871}
872
// generateTitle generates a session title based on the initial prompt.
874func (a *sessionAgent) generateTitle(ctx context.Context, sessionID string, userPrompt string) {
875 if userPrompt == "" {
876 return
877 }
878
879 smallModel := a.smallModel.Get()
880 largeModel := a.largeModel.Get()
881 systemPromptPrefix := a.systemPromptPrefix.Get()
882
883 var maxOutputTokens int64 = 40
884 if smallModel.CatwalkCfg.CanReason {
885 maxOutputTokens = smallModel.CatwalkCfg.DefaultMaxTokens
886 }
887
888 newAgent := func(m fantasy.LanguageModel, p []byte, tok int64) fantasy.Agent {
889 return fantasy.NewAgent(m,
890 fantasy.WithSystemPrompt(string(p)+"\n /no_think"),
891 fantasy.WithMaxOutputTokens(tok),
892 fantasy.WithUserAgent(userAgent),
893 )
894 }
895
896 streamCall := fantasy.AgentStreamCall{
897 Prompt: fmt.Sprintf("Generate a concise title for the following content:\n\n%s\n <think>\n\n</think>", userPrompt),
898 PrepareStep: func(callCtx context.Context, opts fantasy.PrepareStepFunctionOptions) (_ context.Context, prepared fantasy.PrepareStepResult, err error) {
899 prepared.Messages = opts.Messages
900 if systemPromptPrefix != "" {
901 prepared.Messages = append([]fantasy.Message{
902 fantasy.NewSystemMessage(systemPromptPrefix),
903 }, prepared.Messages...)
904 }
905 return callCtx, prepared, nil
906 },
907 }
908
909 // Use the small model to generate the title.
910 model := smallModel
911 agent := newAgent(model.Model, titlePrompt, maxOutputTokens)
912 resp, err := agent.Stream(ctx, streamCall)
913 if err == nil {
914 // We successfully generated a title with the small model.
915 slog.Debug("Generated title with small model")
916 } else {
917 // It didn't work. Let's try with the big model.
918 slog.Error("Error generating title with small model; trying big model", "err", err)
919 model = largeModel
920 agent = newAgent(model.Model, titlePrompt, maxOutputTokens)
921 resp, err = agent.Stream(ctx, streamCall)
922 if err == nil {
923 slog.Debug("Generated title with large model")
924 } else {
925 // Welp, the large model didn't work either. Use the default
926 // session name and return.
927 slog.Error("Error generating title with large model", "err", err)
928 saveErr := a.sessions.Rename(ctx, sessionID, DefaultSessionName)
929 if saveErr != nil {
930 slog.Error("Failed to save session title", "error", saveErr)
931 }
932 return
933 }
934 }
935
936 if resp == nil {
937 // Actually, we didn't get a response so we can't. Use the default
938 // session name and return.
939 slog.Error("Response is nil; can't generate title")
940 saveErr := a.sessions.Rename(ctx, sessionID, DefaultSessionName)
941 if saveErr != nil {
942 slog.Error("Failed to save session title", "error", saveErr)
943 }
944 return
945 }
946
947 // Clean up title.
948 var title string
949 title = strings.ReplaceAll(resp.Response.Content.Text(), "\n", " ")
950
951 // Remove thinking tags if present.
952 title = thinkTagRegex.ReplaceAllString(title, "")
953 title = orphanThinkTagRegex.ReplaceAllString(title, "")
954
955 title = strings.TrimSpace(title)
956 title = cmp.Or(title, DefaultSessionName)
957
958 // Calculate usage and cost.
959 var openrouterCost *float64
960 for _, step := range resp.Steps {
961 stepCost := a.openrouterCost(step.ProviderMetadata)
962 if stepCost != nil {
963 newCost := *stepCost
964 if openrouterCost != nil {
965 newCost += *openrouterCost
966 }
967 openrouterCost = &newCost
968 }
969 }
970
971 modelConfig := model.CatwalkCfg
972 cost := modelConfig.CostPer1MInCached/1e6*float64(resp.TotalUsage.CacheCreationTokens) +
973 modelConfig.CostPer1MOutCached/1e6*float64(resp.TotalUsage.CacheReadTokens) +
974 modelConfig.CostPer1MIn/1e6*float64(resp.TotalUsage.InputTokens) +
975 modelConfig.CostPer1MOut/1e6*float64(resp.TotalUsage.OutputTokens)
976
977 // Use override cost if available (e.g., from OpenRouter).
978 if openrouterCost != nil {
979 cost = *openrouterCost
980 }
981
982 promptTokens := resp.TotalUsage.InputTokens + resp.TotalUsage.CacheCreationTokens
983 completionTokens := resp.TotalUsage.OutputTokens
984
985 // Atomically update only title and usage fields to avoid overriding other
986 // concurrent session updates.
987 saveErr := a.sessions.UpdateTitleAndUsage(ctx, sessionID, title, promptTokens, completionTokens, cost)
988 if saveErr != nil {
989 slog.Error("Failed to save session title and usage", "error", saveErr)
990 return
991 }
992}
993
994func (a *sessionAgent) openrouterCost(metadata fantasy.ProviderMetadata) *float64 {
995 openrouterMetadata, ok := metadata[openrouter.Name]
996 if !ok {
997 return nil
998 }
999
1000 opts, ok := openrouterMetadata.(*openrouter.ProviderMetadata)
1001 if !ok {
1002 return nil
1003 }
1004 return &opts.Usage.Cost
1005}
1006
1007func (a *sessionAgent) updateSessionUsage(model Model, session *session.Session, usage fantasy.Usage, overrideCost *float64) {
1008 modelConfig := model.CatwalkCfg
1009 cost := modelConfig.CostPer1MInCached/1e6*float64(usage.CacheCreationTokens) +
1010 modelConfig.CostPer1MOutCached/1e6*float64(usage.CacheReadTokens) +
1011 modelConfig.CostPer1MIn/1e6*float64(usage.InputTokens) +
1012 modelConfig.CostPer1MOut/1e6*float64(usage.OutputTokens)
1013
1014 a.eventTokensUsed(session.ID, model, usage, cost)
1015
1016 if overrideCost != nil {
1017 session.Cost += *overrideCost
1018 } else {
1019 session.Cost += cost
1020 }
1021
1022 session.CompletionTokens = usage.OutputTokens
1023 session.PromptTokens = usage.InputTokens + usage.CacheReadTokens
1024}
1025
1026func (a *sessionAgent) Cancel(sessionID string) {
1027 // Cancel regular requests. Don't use Take() here - we need the entry to
1028 // remain in activeRequests so IsBusy() returns true until the goroutine
1029 // fully completes (including error handling that may access the DB).
1030 // The defer in processRequest will clean up the entry.
1031 if cancel, ok := a.activeRequests.Get(sessionID); ok && cancel != nil {
1032 slog.Debug("Request cancellation initiated", "session_id", sessionID)
1033 cancel()
1034 }
1035
1036 // Also check for summarize requests.
1037 if cancel, ok := a.activeRequests.Get(sessionID + "-summarize"); ok && cancel != nil {
1038 slog.Debug("Summarize cancellation initiated", "session_id", sessionID)
1039 cancel()
1040 }
1041
1042 if a.QueuedPrompts(sessionID) > 0 {
1043 slog.Debug("Clearing queued prompts", "session_id", sessionID)
1044 a.messageQueue.Del(sessionID)
1045 }
1046}
1047
1048func (a *sessionAgent) ClearQueue(sessionID string) {
1049 if a.QueuedPrompts(sessionID) > 0 {
1050 slog.Debug("Clearing queued prompts", "session_id", sessionID)
1051 a.messageQueue.Del(sessionID)
1052 }
1053}
1054
1055func (a *sessionAgent) CancelAll() {
1056 if !a.IsBusy() {
1057 return
1058 }
1059 for key := range a.activeRequests.Seq2() {
1060 a.Cancel(key) // key is sessionID
1061 }
1062
1063 timeout := time.After(5 * time.Second)
1064 for a.IsBusy() {
1065 select {
1066 case <-timeout:
1067 return
1068 default:
1069 time.Sleep(200 * time.Millisecond)
1070 }
1071 }
1072}
1073
1074func (a *sessionAgent) IsBusy() bool {
1075 var busy bool
1076 for cancelFunc := range a.activeRequests.Seq() {
1077 if cancelFunc != nil {
1078 busy = true
1079 break
1080 }
1081 }
1082 return busy
1083}
1084
1085func (a *sessionAgent) IsSessionBusy(sessionID string) bool {
1086 _, busy := a.activeRequests.Get(sessionID)
1087 return busy
1088}
1089
1090func (a *sessionAgent) QueuedPrompts(sessionID string) int {
1091 l, ok := a.messageQueue.Get(sessionID)
1092 if !ok {
1093 return 0
1094 }
1095 return len(l)
1096}
1097
1098func (a *sessionAgent) QueuedPromptsList(sessionID string) []string {
1099 l, ok := a.messageQueue.Get(sessionID)
1100 if !ok {
1101 return nil
1102 }
1103 prompts := make([]string, len(l))
1104 for i, call := range l {
1105 prompts[i] = call.Prompt
1106 }
1107 return prompts
1108}
1109
1110func (a *sessionAgent) SetModels(large Model, small Model) {
1111 a.largeModel.Set(large)
1112 a.smallModel.Set(small)
1113}
1114
1115func (a *sessionAgent) SetTools(tools []fantasy.AgentTool) {
1116 a.tools.SetSlice(tools)
1117}
1118
1119func (a *sessionAgent) SetSystemPrompt(systemPrompt string) {
1120 a.systemPrompt.Set(systemPrompt)
1121}
1122
1123func (a *sessionAgent) Model() Model {
1124 return a.largeModel.Get()
1125}
1126
1127// convertToToolResult converts a fantasy tool result to a message tool result.
1128func (a *sessionAgent) convertToToolResult(result fantasy.ToolResultContent) message.ToolResult {
1129 baseResult := message.ToolResult{
1130 ToolCallID: result.ToolCallID,
1131 Name: result.ToolName,
1132 Metadata: result.ClientMetadata,
1133 }
1134
1135 switch result.Result.GetType() {
1136 case fantasy.ToolResultContentTypeText:
1137 if r, ok := fantasy.AsToolResultOutputType[fantasy.ToolResultOutputContentText](result.Result); ok {
1138 baseResult.Content = r.Text
1139 }
1140 case fantasy.ToolResultContentTypeError:
1141 if r, ok := fantasy.AsToolResultOutputType[fantasy.ToolResultOutputContentError](result.Result); ok {
1142 baseResult.Content = r.Error.Error()
1143 baseResult.IsError = true
1144 }
1145 case fantasy.ToolResultContentTypeMedia:
1146 if r, ok := fantasy.AsToolResultOutputType[fantasy.ToolResultOutputContentMedia](result.Result); ok {
1147 if !stringext.IsValidBase64(r.Data) {
1148 slog.Warn("Tool returned media with invalid base64 data, discarding image",
1149 "tool", result.ToolName,
1150 "tool_call_id", result.ToolCallID,
1151 )
1152 baseResult.Content = "Tool returned image data with invalid encoding"
1153 baseResult.IsError = true
1154 } else {
1155 content := r.Text
1156 if content == "" {
1157 content = fmt.Sprintf("Loaded %s content", r.MediaType)
1158 }
1159 baseResult.Content = content
1160 baseResult.Data = r.Data
1161 baseResult.MIMEType = r.MediaType
1162 }
1163 }
1164 }
1165
1166 return baseResult
1167}
1168
1169// workaroundProviderMediaLimitations converts media content in tool results to
1170// user messages for providers that don't natively support images in tool results.
1171//
1172// Problem: OpenAI, Google, OpenRouter, and other OpenAI-compatible providers
1173// don't support sending images/media in tool result messages - they only accept
1174// text in tool results. However, they DO support images in user messages.
1175//
1176// If we send media in tool results to these providers, the API returns an error.
1177//
1178// Solution: For these providers, we:
1179// 1. Replace the media in the tool result with a text placeholder
1180// 2. Inject a user message immediately after with the image as a file attachment
1181// 3. This maintains the tool execution flow while working around API limitations
1182//
1183// Anthropic and Bedrock support images natively in tool results, so we skip
1184// this workaround for them.
1185//
1186// Example transformation:
1187//
1188// BEFORE: [tool result: image data]
1189// AFTER: [tool result: "Image loaded - see attached"], [user: image attachment]
1190func (a *sessionAgent) workaroundProviderMediaLimitations(messages []fantasy.Message, largeModel Model) []fantasy.Message {
1191 providerSupportsMedia := largeModel.ModelCfg.Provider == string(catwalk.InferenceProviderAnthropic) ||
1192 largeModel.ModelCfg.Provider == string(catwalk.InferenceProviderBedrock)
1193
1194 if providerSupportsMedia {
1195 return messages
1196 }
1197
1198 convertedMessages := make([]fantasy.Message, 0, len(messages))
1199
1200 for _, msg := range messages {
1201 if msg.Role != fantasy.MessageRoleTool {
1202 convertedMessages = append(convertedMessages, msg)
1203 continue
1204 }
1205
1206 textParts := make([]fantasy.MessagePart, 0, len(msg.Content))
1207 var mediaFiles []fantasy.FilePart
1208
1209 for _, part := range msg.Content {
1210 toolResult, ok := fantasy.AsMessagePart[fantasy.ToolResultPart](part)
1211 if !ok {
1212 textParts = append(textParts, part)
1213 continue
1214 }
1215
1216 if media, ok := fantasy.AsToolResultOutputType[fantasy.ToolResultOutputContentMedia](toolResult.Output); ok {
1217 decoded, err := base64.StdEncoding.DecodeString(media.Data)
1218 if err != nil {
1219 slog.Warn("Failed to decode media data", "error", err)
1220 textParts = append(textParts, part)
1221 continue
1222 }
1223
1224 mediaFiles = append(mediaFiles, fantasy.FilePart{
1225 Data: decoded,
1226 MediaType: media.MediaType,
1227 Filename: fmt.Sprintf("tool-result-%s", toolResult.ToolCallID),
1228 })
1229
1230 textParts = append(textParts, fantasy.ToolResultPart{
1231 ToolCallID: toolResult.ToolCallID,
1232 Output: fantasy.ToolResultOutputContentText{
1233 Text: "[Image/media content loaded - see attached file]",
1234 },
1235 ProviderOptions: toolResult.ProviderOptions,
1236 })
1237 } else {
1238 textParts = append(textParts, part)
1239 }
1240 }
1241
1242 convertedMessages = append(convertedMessages, fantasy.Message{
1243 Role: fantasy.MessageRoleTool,
1244 Content: textParts,
1245 })
1246
1247 if len(mediaFiles) > 0 {
1248 convertedMessages = append(convertedMessages, fantasy.NewUserMessage(
1249 "Here is the media content from the tool result:",
1250 mediaFiles...,
1251 ))
1252 }
1253 }
1254
1255 return convertedMessages
1256}
1257
1258// buildSummaryPrompt constructs the prompt text for session summarization.
1259func buildSummaryPrompt(todos []session.Todo) string {
1260 var sb strings.Builder
1261 sb.WriteString("Provide a detailed summary of our conversation above.")
1262 if len(todos) > 0 {
1263 sb.WriteString("\n\n## Current Todo List\n\n")
1264 for _, t := range todos {
1265 fmt.Fprintf(&sb, "- [%s] %s\n", t.Status, t.Content)
1266 }
1267 sb.WriteString("\nInclude these tasks and their statuses in your summary. ")
1268 sb.WriteString("Instruct the resuming assistant to use the `todos` tool to continue tracking progress on these tasks.")
1269 }
1270 return sb.String()
1271}