1// Package agent is the core orchestration layer for Crush AI agents.
2//
3// It provides session-based AI agent functionality for managing
4// conversations, tool execution, and message handling. It coordinates
5// interactions between language models, messages, sessions, and tools while
6// handling features like automatic summarization, queuing, and token
7// management.
8package agent
9
10import (
11 "cmp"
12 "context"
13 _ "embed"
14 "encoding/base64"
15 "errors"
16 "fmt"
17 "log/slog"
18 "os"
19 "regexp"
20 "strconv"
21 "strings"
22 "sync"
23 "time"
24
25 "charm.land/catwalk/pkg/catwalk"
26 "charm.land/fantasy"
27 "charm.land/fantasy/providers/anthropic"
28 "charm.land/fantasy/providers/bedrock"
29 "charm.land/fantasy/providers/google"
30 "charm.land/fantasy/providers/openai"
31 "charm.land/fantasy/providers/openrouter"
32 "charm.land/fantasy/providers/vercel"
33 "charm.land/lipgloss/v2"
34 "github.com/charmbracelet/crush/internal/agent/hyper"
35 "github.com/charmbracelet/crush/internal/agent/notify"
36 "github.com/charmbracelet/crush/internal/agent/tools"
37 "github.com/charmbracelet/crush/internal/agent/tools/mcp"
38 "github.com/charmbracelet/crush/internal/config"
39 "github.com/charmbracelet/crush/internal/csync"
40 "github.com/charmbracelet/crush/internal/message"
41 "github.com/charmbracelet/crush/internal/permission"
42 "github.com/charmbracelet/crush/internal/pubsub"
43 "github.com/charmbracelet/crush/internal/session"
44 "github.com/charmbracelet/crush/internal/stringext"
45 "github.com/charmbracelet/crush/internal/version"
46 "github.com/charmbracelet/x/exp/charmtone"
47)
48
// DefaultSessionName is the title used for a session when no title could be
// generated for it.
const DefaultSessionName = "Untitled Session"

// Auto-summarization thresholds. When a model's context window is larger than
// largeContextWindowThreshold tokens, the session is summarized once no more
// than largeContextWindowBuffer tokens remain; for smaller windows, once the
// remaining space drops to smallContextWindowRatio of the window.
const (
	largeContextWindowThreshold = 200_000
	largeContextWindowBuffer    = 20_000
	smallContextWindowRatio     = 0.2
)
57
// userAgent identifies Crush and its version to model providers; it is passed
// to every fantasy agent created in this file.
var userAgent = fmt.Sprintf("Charm-Crush/%s (https://charm.land/crush)", version.Version)

// titlePrompt is the system prompt used by generateTitle.
//
//go:embed templates/title.md
var titlePrompt []byte

// summaryPrompt is the system prompt used by Summarize.
//
//go:embed templates/summary.md
var summaryPrompt []byte

// Used to remove <think> tags from generated titles. thinkTagRegex strips
// complete <think>...</think> blocks (spanning newlines, matched
// non-greedily); orphanThinkTagRegex then strips any unpaired opening or
// closing tag that remains.
var (
	thinkTagRegex       = regexp.MustCompile(`(?s)<think>.*?</think>`)
	orphanThinkTagRegex = regexp.MustCompile(`</?think>`)
)
71
72type SessionAgentCall struct {
73 SessionID string
74 Prompt string
75 ProviderOptions fantasy.ProviderOptions
76 Attachments []message.Attachment
77 MaxOutputTokens int64
78 Temperature *float64
79 TopP *float64
80 TopK *int64
81 FrequencyPenalty *float64
82 PresencePenalty *float64
83 NonInteractive bool
84}
85
86type SessionAgent interface {
87 Run(context.Context, SessionAgentCall) (*fantasy.AgentResult, error)
88 SetModels(large Model, small Model)
89 SetTools(tools []fantasy.AgentTool)
90 SetSystemPrompt(systemPrompt string)
91 Cancel(sessionID string)
92 CancelAll()
93 IsSessionBusy(sessionID string) bool
94 IsBusy() bool
95 QueuedPrompts(sessionID string) int
96 QueuedPromptsList(sessionID string) []string
97 ClearQueue(sessionID string)
98 Summarize(context.Context, string, fantasy.ProviderOptions) error
99 Model() Model
100}
101
102type Model struct {
103 Model fantasy.LanguageModel
104 CatwalkCfg catwalk.Model
105 ModelCfg config.SelectedModel
106}
107
108type sessionAgent struct {
109 largeModel *csync.Value[Model]
110 smallModel *csync.Value[Model]
111 systemPromptPrefix *csync.Value[string]
112 systemPrompt *csync.Value[string]
113 tools *csync.Slice[fantasy.AgentTool]
114
115 isSubAgent bool
116 sessions session.Service
117 messages message.Service
118 disableAutoSummarize bool
119 isYolo bool
120 notify pubsub.Publisher[notify.Notification]
121
122 messageQueue *csync.Map[string, []SessionAgentCall]
123 activeRequests *csync.Map[string, context.CancelFunc]
124}
125
126type SessionAgentOptions struct {
127 LargeModel Model
128 SmallModel Model
129 SystemPromptPrefix string
130 SystemPrompt string
131 IsSubAgent bool
132 DisableAutoSummarize bool
133 IsYolo bool
134 Sessions session.Service
135 Messages message.Service
136 Tools []fantasy.AgentTool
137 Notify pubsub.Publisher[notify.Notification]
138}
139
// NewSessionAgent builds a SessionAgent from opts. Settings that may change at
// runtime (models, prompts, tools) go into concurrency-safe containers, as do
// the per-session queue and active-request tables.
func NewSessionAgent(
	opts SessionAgentOptions,
) SessionAgent {
	a := &sessionAgent{
		isSubAgent:           opts.IsSubAgent,
		sessions:             opts.Sessions,
		messages:             opts.Messages,
		disableAutoSummarize: opts.DisableAutoSummarize,
		isYolo:               opts.IsYolo,
		notify:               opts.Notify,
	}

	// Mutable settings.
	a.largeModel = csync.NewValue(opts.LargeModel)
	a.smallModel = csync.NewValue(opts.SmallModel)
	a.systemPromptPrefix = csync.NewValue(opts.SystemPromptPrefix)
	a.systemPrompt = csync.NewValue(opts.SystemPrompt)
	a.tools = csync.NewSliceFrom(opts.Tools)

	// Per-session bookkeeping.
	a.messageQueue = csync.NewMap[string, []SessionAgentCall]()
	a.activeRequests = csync.NewMap[string, context.CancelFunc]()

	return a
}
159
// Run sends call to the large model as part of session call.SessionID and
// streams the reply into the message store.
//
//   - Returns ErrEmptyPrompt when there is neither a prompt nor a text
//     attachment, and ErrSessionMissing when SessionID is empty.
//   - If the session already has an active request, call is appended to the
//     session's queue and Run returns (nil, nil). Queued calls are injected
//     as user messages at the start of the running request's next step (in
//     PrepareStep). If they are still queued when this run finishes, they are
//     processed one at a time by a recursive Run.
//   - Instructions from connected MCP servers are appended to the system
//     prompt.
//   - For the first message of a session, a title is generated concurrently.
//     Run waits for it before returning.
//   - When the context window runs low, streaming stops and the session is
//     summarized. If the assistant was in the middle of tool use, the
//     original prompt is re-queued with a note that it was interrupted.
//   - On a stream error, unfinished tool calls are closed, and tool calls
//     without a stored result get a synthetic error result. The assistant
//     message is then finished with a reason describing the error, and the
//     error is returned.
func (a *sessionAgent) Run(ctx context.Context, call SessionAgentCall) (*fantasy.AgentResult, error) {
	if call.Prompt == "" && !message.ContainsTextAttachment(call.Attachments) {
		return nil, ErrEmptyPrompt
	}
	if call.SessionID == "" {
		return nil, ErrSessionMissing
	}

	// Queue the message if busy.
	// NOTE(review): the busy check and the Get/append/Set below are separate
	// operations, and PrepareStep drains the queue with a separate Get and
	// Del. A call appended in between could be lost. Confirm whether
	// csync.Map offers an atomic update.
	if a.IsSessionBusy(call.SessionID) {
		existing, ok := a.messageQueue.Get(call.SessionID)
		if !ok {
			existing = []SessionAgentCall{}
		}
		existing = append(existing, call)
		a.messageQueue.Set(call.SessionID, existing)
		return nil, nil
	}

	// Copy mutable fields under lock to avoid races with SetTools/SetModels.
	agentTools := a.tools.Copy()
	largeModel := a.largeModel.Get()
	systemPrompt := a.systemPrompt.Get()
	promptPrefix := a.systemPromptPrefix.Get()
	var instructions strings.Builder

	// Collect the instructions advertised by connected MCP servers.
	for _, server := range mcp.GetStates() {
		if server.State != mcp.StateConnected {
			continue
		}
		if s := server.Client.InitializeResult().Instructions; s != "" {
			instructions.WriteString(s)
			instructions.WriteString("\n\n")
		}
	}

	if s := instructions.String(); s != "" {
		systemPrompt += "\n\n<mcp-instructions>\n" + s + "\n</mcp-instructions>"
	}

	if len(agentTools) > 0 {
		// Add Anthropic caching to the last tool.
		agentTools[len(agentTools)-1].SetProviderOptions(a.getCacheControlOptions())
	}

	agent := fantasy.NewAgent(
		largeModel.Model,
		fantasy.WithSystemPrompt(systemPrompt),
		fantasy.WithTools(agentTools...),
		fantasy.WithUserAgent(userAgent),
	)

	// sessionLock serializes the session read-modify-write in OnStepFinish.
	sessionLock := sync.Mutex{}
	currentSession, err := a.sessions.Get(ctx, call.SessionID)
	if err != nil {
		return nil, fmt.Errorf("failed to get session: %w", err)
	}

	msgs, err := a.getSessionMessages(ctx, currentSession)
	if err != nil {
		return nil, fmt.Errorf("failed to get session messages: %w", err)
	}

	var wg sync.WaitGroup
	// Generate title if first message.
	if len(msgs) == 0 {
		titleCtx := ctx // Copy to avoid race with ctx reassignment below.
		wg.Go(func() {
			a.generateTitle(titleCtx, call.SessionID, call.Prompt)
		})
	}
	defer wg.Wait()

	// Add the user message to the session.
	_, err = a.createUserMessage(ctx, call)
	if err != nil {
		return nil, err
	}

	// Add the session to the context.
	ctx = context.WithValue(ctx, tools.SessionIDContextKey, call.SessionID)

	// genCtx drives the stream. Its cancel func is registered in
	// activeRequests, which is presumably what IsSessionBusy/Cancel consult.
	// ctx stays live so state can still be persisted after a cancellation.
	genCtx, cancel := context.WithCancel(ctx)
	a.activeRequests.Set(call.SessionID, cancel)

	defer cancel()
	defer a.activeRequests.Del(call.SessionID)

	history, files := a.preparePrompt(msgs, call.Attachments...)

	startTime := time.Now()
	a.eventPromptSent(call.SessionID)

	// currentAssistant is the assistant message of the current step. It is
	// created in PrepareStep and updated by the streaming callbacks.
	var currentAssistant *message.Message
	// shouldSummarize is set by the context-window stop condition.
	var shouldSummarize bool
	// Don't send MaxOutputTokens if 0 — some providers (e.g. LM Studio) reject it
	var maxOutputTokens *int64
	if call.MaxOutputTokens > 0 {
		maxOutputTokens = &call.MaxOutputTokens
	}
	result, err := agent.Stream(genCtx, fantasy.AgentStreamCall{
		Prompt:           message.PromptWithTextAttachments(call.Prompt, call.Attachments),
		Files:            files,
		Messages:         history,
		ProviderOptions:  call.ProviderOptions,
		MaxOutputTokens:  maxOutputTokens,
		TopP:             call.TopP,
		Temperature:      call.Temperature,
		PresencePenalty:  call.PresencePenalty,
		TopK:             call.TopK,
		FrequencyPenalty: call.FrequencyPenalty,
		// PrepareStep runs before each model call. It resets per-message
		// provider options, refreshes tools, injects queued prompts, sets
		// cache breakpoints, prepends the prompt prefix, and creates the
		// assistant message for the step.
		PrepareStep: func(callContext context.Context, options fantasy.PrepareStepFunctionOptions) (_ context.Context, prepared fantasy.PrepareStepResult, err error) {
			prepared.Messages = options.Messages
			for i := range prepared.Messages {
				prepared.Messages[i].ProviderOptions = nil
			}

			// Use latest tools (updated by SetTools when MCP tools change).
			prepared.Tools = a.tools.Copy()

			// Drain prompts queued while this request was running and
			// append them as user messages.
			queuedCalls, _ := a.messageQueue.Get(call.SessionID)
			a.messageQueue.Del(call.SessionID)
			for _, queued := range queuedCalls {
				userMessage, createErr := a.createUserMessage(callContext, queued)
				if createErr != nil {
					return callContext, prepared, createErr
				}
				prepared.Messages = append(prepared.Messages, userMessage.ToAIMessage()...)
			}

			prepared.Messages = a.workaroundProviderMediaLimitations(prepared.Messages, largeModel)

			lastSystemRoleInx := 0
			systemMessageUpdated := false
			for i, msg := range prepared.Messages {
				// Add cache control to the last system message that precedes
				// the first non-system message (index 0 if the history does
				// not start with a system message).
				if msg.Role == fantasy.MessageRoleSystem {
					lastSystemRoleInx = i
				} else if !systemMessageUpdated {
					prepared.Messages[lastSystemRoleInx].ProviderOptions = a.getCacheControlOptions()
					systemMessageUpdated = true
				}
				// Then add cache control to the last 2 messages.
				if i > len(prepared.Messages)-3 {
					prepared.Messages[i].ProviderOptions = a.getCacheControlOptions()
				}
			}

			if promptPrefix != "" {
				prepared.Messages = append([]fantasy.Message{fantasy.NewSystemMessage(promptPrefix)}, prepared.Messages...)
			}

			var assistantMsg message.Message
			assistantMsg, err = a.messages.Create(callContext, call.SessionID, message.CreateMessageParams{
				Role:     message.Assistant,
				Parts:    []message.ContentPart{},
				Model:    largeModel.ModelCfg.Model,
				Provider: largeModel.ModelCfg.Provider,
			})
			if err != nil {
				return callContext, prepared, err
			}
			// Store the message ID and model capabilities in the step
			// context under the tools package's keys.
			callContext = context.WithValue(callContext, tools.MessageIDContextKey, assistantMsg.ID)
			callContext = context.WithValue(callContext, tools.SupportsImagesContextKey, largeModel.CatwalkCfg.SupportsImages)
			callContext = context.WithValue(callContext, tools.ModelNameContextKey, largeModel.CatwalkCfg.Name)
			currentAssistant = &assistantMsg
			return callContext, prepared, err
		},
		OnReasoningStart: func(id string, reasoning fantasy.ReasoningContent) error {
			currentAssistant.AppendReasoningContent(reasoning.Text)
			return a.messages.Update(genCtx, *currentAssistant)
		},
		OnReasoningDelta: func(id string, text string) error {
			currentAssistant.AppendReasoningContent(text)
			return a.messages.Update(genCtx, *currentAssistant)
		},
		OnReasoningEnd: func(id string, reasoning fantasy.ReasoningContent) error {
			// handle anthropic signature
			if anthropicData, ok := reasoning.ProviderMetadata[anthropic.Name]; ok {
				if reasoning, ok := anthropicData.(*anthropic.ReasoningOptionMetadata); ok {
					currentAssistant.AppendReasoningSignature(reasoning.Signature)
				}
			}
			// Google thought signature, tied to a tool ID.
			if googleData, ok := reasoning.ProviderMetadata[google.Name]; ok {
				if reasoning, ok := googleData.(*google.ReasoningMetadata); ok {
					currentAssistant.AppendThoughtSignature(reasoning.Signature, reasoning.ToolID)
				}
			}
			// OpenAI Responses API reasoning data.
			if openaiData, ok := reasoning.ProviderMetadata[openai.Name]; ok {
				if reasoning, ok := openaiData.(*openai.ResponsesReasoningMetadata); ok {
					currentAssistant.SetReasoningResponsesData(reasoning)
				}
			}
			currentAssistant.FinishThinking()
			return a.messages.Update(genCtx, *currentAssistant)
		},
		OnTextDelta: func(id string, text string) error {
			// Strip leading newline from initial text content. This is
			// particularly important in non-interactive mode where leading
			// newlines are very visible.
			if len(currentAssistant.Parts) == 0 {
				text = strings.TrimPrefix(text, "\n")
			}

			currentAssistant.AppendContent(text)
			return a.messages.Update(genCtx, *currentAssistant)
		},
		OnToolInputStart: func(id string, toolName string) error {
			// Record the tool call as unfinished until OnToolCall delivers
			// its complete input.
			toolCall := message.ToolCall{
				ID:               id,
				Name:             toolName,
				ProviderExecuted: false,
				Finished:         false,
			}
			currentAssistant.AddToolCall(toolCall)
			// Use parent ctx instead of genCtx to ensure the update succeeds
			// even if the request is canceled mid-stream
			return a.messages.Update(ctx, *currentAssistant)
		},
		OnRetry: func(err *fantasy.ProviderError, delay time.Duration) {
			// TODO: implement
		},
		OnToolCall: func(tc fantasy.ToolCallContent) error {
			toolCall := message.ToolCall{
				ID:               tc.ToolCallID,
				Name:             tc.ToolName,
				Input:            tc.Input,
				ProviderExecuted: false,
				Finished:         true,
			}
			currentAssistant.AddToolCall(toolCall)
			// Use parent ctx instead of genCtx to ensure the update succeeds
			// even if the request is canceled mid-stream
			return a.messages.Update(ctx, *currentAssistant)
		},
		OnToolResult: func(result fantasy.ToolResultContent) error {
			toolResult := a.convertToToolResult(result)
			// Use parent ctx instead of genCtx to ensure the message is created
			// even if the request is canceled mid-stream
			_, createMsgErr := a.messages.Create(ctx, currentAssistant.SessionID, message.CreateMessageParams{
				Role: message.Tool,
				Parts: []message.ContentPart{
					toolResult,
				},
			})
			return createMsgErr
		},
		OnStepFinish: func(stepResult fantasy.StepResult) error {
			// Map the provider's finish reason onto the message model.
			finishReason := message.FinishReasonUnknown
			switch stepResult.FinishReason {
			case fantasy.FinishReasonLength:
				finishReason = message.FinishReasonMaxTokens
			case fantasy.FinishReasonStop:
				finishReason = message.FinishReasonEndTurn
			case fantasy.FinishReasonToolCalls:
				finishReason = message.FinishReasonToolUse
			}
			currentAssistant.AddFinish(finishReason, "", "")
			sessionLock.Lock()
			defer sessionLock.Unlock()

			// Re-read the session so usage is added to its latest state,
			// then keep the saved copy for the stop condition below.
			updatedSession, getSessionErr := a.sessions.Get(ctx, call.SessionID)
			if getSessionErr != nil {
				return getSessionErr
			}
			a.updateSessionUsage(largeModel, &updatedSession, stepResult.Usage, a.openrouterCost(stepResult.ProviderMetadata))
			_, sessionErr := a.sessions.Save(ctx, updatedSession)
			if sessionErr != nil {
				return sessionErr
			}
			currentSession = updatedSession
			return a.messages.Update(genCtx, *currentAssistant)
		},
		StopWhen: []fantasy.StopCondition{
			// Stop (and request summarization) when the remaining context
			// window falls to the threshold.
			func(_ []fantasy.StepResult) bool {
				cw := int64(largeModel.CatwalkCfg.ContextWindow)
				// If context window is unknown (0), skip auto-summarize
				// to avoid immediately truncating custom/local models.
				if cw == 0 {
					return false
				}
				tokens := currentSession.CompletionTokens + currentSession.PromptTokens
				remaining := cw - tokens
				var threshold int64
				if cw > largeContextWindowThreshold {
					threshold = largeContextWindowBuffer
				} else {
					threshold = int64(float64(cw) * smallContextWindowRatio)
				}
				if (remaining <= threshold) && !a.disableAutoSummarize {
					shouldSummarize = true
					return true
				}
				return false
			},
			// Stop when the model keeps repeating the same tool calls.
			func(steps []fantasy.StepResult) bool {
				return hasRepeatedToolCalls(steps, loopDetectionWindowSize, loopDetectionMaxRepeats)
			},
		},
	})

	a.eventPromptResponded(call.SessionID, time.Since(startTime).Truncate(time.Second))

	if err != nil {
		isCancelErr := errors.Is(err, context.Canceled)
		isPermissionErr := errors.Is(err, permission.ErrorPermissionDenied)
		// No assistant message was created, so there is nothing to finalize.
		if currentAssistant == nil {
			return result, err
		}
		// Ensure we finish thinking on error to close the reasoning state.
		currentAssistant.FinishThinking()
		toolCalls := currentAssistant.ToolCalls()
		// INFO: we use the parent context here because the genCtx has been cancelled.
		msgs, createErr := a.messages.List(ctx, currentAssistant.SessionID)
		if createErr != nil {
			return nil, createErr
		}
		for _, tc := range toolCalls {
			// A tool call whose input never completed (started by
			// OnToolInputStart, never delivered by OnToolCall) is closed
			// with empty input.
			if !tc.Finished {
				tc.Finished = true
				tc.Input = "{}"
				currentAssistant.AddToolCall(tc)
				updateErr := a.messages.Update(ctx, *currentAssistant)
				if updateErr != nil {
					return nil, updateErr
				}
			}

			// Skip tool calls that already have a stored result.
			found := false
			for _, msg := range msgs {
				if msg.Role == message.Tool {
					for _, tr := range msg.ToolResults() {
						if tr.ToolCallID == tc.ID {
							found = true
							break
						}
					}
				}
				if found {
					break
				}
			}
			if found {
				continue
			}
			// Otherwise store a synthetic error result so the tool call has
			// a matching result in the history.
			content := "There was an error while executing the tool"
			if isCancelErr {
				content = "Error: user cancelled assistant tool calling"
			} else if isPermissionErr {
				content = "User denied permission"
			}
			toolResult := message.ToolResult{
				ToolCallID: tc.ID,
				Name:       tc.Name,
				Content:    content,
				IsError:    true,
			}
			_, createErr = a.messages.Create(ctx, currentAssistant.SessionID, message.CreateMessageParams{
				Role: message.Tool,
				Parts: []message.ContentPart{
					toolResult,
				},
			})
			if createErr != nil {
				return nil, createErr
			}
		}
		// Record on the assistant message why the run ended.
		var fantasyErr *fantasy.Error
		var providerErr *fantasy.ProviderError
		const defaultTitle = "Provider Error"
		linkStyle := lipgloss.NewStyle().Foreground(charmtone.Guac).Underline(true)
		if isCancelErr {
			currentAssistant.AddFinish(message.FinishReasonCanceled, "User canceled request", "")
		} else if isPermissionErr {
			currentAssistant.AddFinish(message.FinishReasonPermissionDenied, "User denied permission", "")
		} else if errors.Is(err, hyper.ErrUnauthorized) {
			currentAssistant.AddFinish(message.FinishReasonError, "Unauthorized", `Please re-authenticate with Hyper. You can also run "crush auth" to re-authenticate.`)
			if a.notify != nil {
				a.notify.Publish(pubsub.CreatedEvent, notify.Notification{
					SessionID:    call.SessionID,
					SessionTitle: currentSession.Title,
					Type:         notify.TypeReAuthenticate,
					ProviderID:   largeModel.ModelCfg.Provider,
				})
			}
		} else if errors.Is(err, hyper.ErrNoCredits) {
			url := hyper.BaseURL()
			link := linkStyle.Hyperlink(url, "id=hyper").Render(url)
			currentAssistant.AddFinish(message.FinishReasonError, "No credits", "You're out of credits. Add more at "+link)
		} else if errors.As(err, &providerErr) {
			if providerErr.Message == "The requested model is not supported." {
				url := "https://github.com/settings/copilot/features"
				link := linkStyle.Hyperlink(url, "id=copilot").Render(url)
				currentAssistant.AddFinish(
					message.FinishReasonError,
					"Copilot model not enabled",
					fmt.Sprintf("%q is not enabled in Copilot. Go to the following page to enable it. Then, wait 5 minutes before trying again. %s", largeModel.CatwalkCfg.Name, link),
				)
			} else {
				currentAssistant.AddFinish(message.FinishReasonError, cmp.Or(stringext.Capitalize(providerErr.Title), defaultTitle), providerErr.Message)
			}
		} else if errors.As(err, &fantasyErr) {
			currentAssistant.AddFinish(message.FinishReasonError, cmp.Or(stringext.Capitalize(fantasyErr.Title), defaultTitle), fantasyErr.Message)
		} else {
			currentAssistant.AddFinish(message.FinishReasonError, defaultTitle, err.Error())
		}
		// Note: we use the parent context here because the genCtx has been
		// cancelled.
		updateErr := a.messages.Update(ctx, *currentAssistant)
		if updateErr != nil {
			return nil, updateErr
		}
		return nil, err
	}

	// Send notification that agent has finished its turn (skip for
	// nested/non-interactive sessions).
	if !call.NonInteractive && a.notify != nil {
		a.notify.Publish(pubsub.CreatedEvent, notify.Notification{
			SessionID:    call.SessionID,
			SessionTitle: currentSession.Title,
			Type:         notify.TypeAgentFinished,
		})
	}

	if shouldSummarize {
		// Clear the busy marker first: Summarize returns ErrSessionBusy when
		// IsSessionBusy reports the session busy.
		a.activeRequests.Del(call.SessionID)
		if summarizeErr := a.Summarize(genCtx, call.SessionID, call.ProviderOptions); summarizeErr != nil {
			return nil, summarizeErr
		}
		// If the agent wasn't done (it was still calling tools), re-queue
		// the original request so work resumes from the summary.
		if len(currentAssistant.ToolCalls()) > 0 {
			existing, ok := a.messageQueue.Get(call.SessionID)
			if !ok {
				existing = []SessionAgentCall{}
			}
			call.Prompt = fmt.Sprintf("The previous session was interrupted because it got too long, the initial user request was: `%s`", call.Prompt)
			existing = append(existing, call)
			a.messageQueue.Set(call.SessionID, existing)
		}
	}

	// Release active request before processing queued messages.
	a.activeRequests.Del(call.SessionID)
	cancel()

	queuedMessages, ok := a.messageQueue.Get(call.SessionID)
	if !ok || len(queuedMessages) == 0 {
		return result, err
	}
	// There are queued messages: run the first one now and leave the rest
	// queued for that run to pick up.
	firstQueuedMessage := queuedMessages[0]
	a.messageQueue.Set(call.SessionID, queuedMessages[1:])
	return a.Run(ctx, firstQueuedMessage)
}
615
// Summarize condenses the session's current history into a single summary
// message written by the large model. It then points the session's
// SummaryMessageID at that message, so later prompts start from the summary,
// and resets the session's token counters.
//
// It returns ErrSessionBusy if a request is already active for sessionID, and
// does nothing when the session has no messages. opts are forwarded to the
// provider. If the stream is canceled, the partial summary message is deleted
// and the result of that deletion is returned. Any other stream error is
// returned as-is.
func (a *sessionAgent) Summarize(ctx context.Context, sessionID string, opts fantasy.ProviderOptions) error {
	if a.IsSessionBusy(sessionID) {
		return ErrSessionBusy
	}

	// Copy mutable fields under lock to avoid races with SetModels.
	largeModel := a.largeModel.Get()
	systemPromptPrefix := a.systemPromptPrefix.Get()

	currentSession, err := a.sessions.Get(ctx, sessionID)
	if err != nil {
		return fmt.Errorf("failed to get session: %w", err)
	}
	msgs, err := a.getSessionMessages(ctx, currentSession)
	if err != nil {
		return err
	}
	if len(msgs) == 0 {
		// Nothing to summarize.
		return nil
	}

	aiMsgs, _ := a.preparePrompt(msgs)

	// Register the cancel func so the summary counts as an active request
	// for this session while it runs.
	genCtx, cancel := context.WithCancel(ctx)
	a.activeRequests.Set(sessionID, cancel)
	defer a.activeRequests.Del(sessionID)
	defer cancel()

	agent := fantasy.NewAgent(largeModel.Model,
		fantasy.WithSystemPrompt(string(summaryPrompt)),
		fantasy.WithUserAgent(userAgent),
	)
	summaryMessage, err := a.messages.Create(ctx, sessionID, message.CreateMessageParams{
		Role:             message.Assistant,
		Model:            largeModel.Model.Model(),
		Provider:         largeModel.Model.Provider(),
		IsSummaryMessage: true,
	})
	if err != nil {
		return err
	}

	summaryPromptText := buildSummaryPrompt(currentSession.Todos)

	resp, err := agent.Stream(genCtx, fantasy.AgentStreamCall{
		Prompt:          summaryPromptText,
		Messages:        aiMsgs,
		ProviderOptions: opts,
		PrepareStep: func(callContext context.Context, options fantasy.PrepareStepFunctionOptions) (_ context.Context, prepared fantasy.PrepareStepResult, err error) {
			prepared.Messages = options.Messages
			if systemPromptPrefix != "" {
				prepared.Messages = append([]fantasy.Message{fantasy.NewSystemMessage(systemPromptPrefix)}, prepared.Messages...)
			}
			return callContext, prepared, nil
		},
		OnReasoningDelta: func(id string, text string) error {
			summaryMessage.AppendReasoningContent(text)
			return a.messages.Update(genCtx, summaryMessage)
		},
		OnReasoningEnd: func(id string, reasoning fantasy.ReasoningContent) error {
			// Handle anthropic signature (same metadata key Run uses).
			if anthropicData, ok := reasoning.ProviderMetadata[anthropic.Name]; ok {
				if signature, ok := anthropicData.(*anthropic.ReasoningOptionMetadata); ok && signature.Signature != "" {
					summaryMessage.AppendReasoningSignature(signature.Signature)
				}
			}
			summaryMessage.FinishThinking()
			return a.messages.Update(genCtx, summaryMessage)
		},
		OnTextDelta: func(id, text string) error {
			summaryMessage.AppendContent(text)
			return a.messages.Update(genCtx, summaryMessage)
		},
	})
	if err != nil {
		if errors.Is(err, context.Canceled) {
			// User cancelled summarize we need to remove the summary message.
			return a.messages.Delete(ctx, summaryMessage.ID)
		}
		return err
	}

	// The stream has completed. Persist the result with the parent ctx
	// rather than genCtx: a cancel arriving now must not leave a finished
	// summary unsaved, or the session still pointing at the full history.
	// This mirrors how Run persists state after genCtx is cancelled.
	summaryMessage.AddFinish(message.FinishReasonEndTurn, "", "")
	if err = a.messages.Update(ctx, summaryMessage); err != nil {
		return err
	}

	// Sum any OpenRouter-reported cost across all steps.
	var openrouterCost *float64
	for _, step := range resp.Steps {
		stepCost := a.openrouterCost(step.ProviderMetadata)
		if stepCost != nil {
			newCost := *stepCost
			if openrouterCost != nil {
				newCost += *openrouterCost
			}
			openrouterCost = &newCost
		}
	}

	a.updateSessionUsage(largeModel, &currentSession, resp.TotalUsage, openrouterCost)

	// The summary becomes the start of the working history. Its output
	// tokens are the new completion count, and the prompt count restarts
	// from zero.
	usage := resp.Response.Usage
	currentSession.SummaryMessageID = summaryMessage.ID
	currentSession.CompletionTokens = usage.OutputTokens
	currentSession.PromptTokens = 0
	_, err = a.sessions.Save(ctx, currentSession)
	return err
}
729
// getCacheControlOptions returns provider options that mark a message or tool
// as an "ephemeral" cache breakpoint for the anthropic, bedrock and vercel
// providers. Each provider gets its own options value. If
// CRUSH_DISABLE_ANTHROPIC_CACHE parses as true (strconv.ParseBool), empty
// options are returned instead.
func (a *sessionAgent) getCacheControlOptions() fantasy.ProviderOptions {
	disabled, _ := strconv.ParseBool(os.Getenv("CRUSH_DISABLE_ANTHROPIC_CACHE"))
	if disabled {
		return fantasy.ProviderOptions{}
	}

	ephemeral := func() *anthropic.ProviderCacheControlOptions {
		return &anthropic.ProviderCacheControlOptions{
			CacheControl: anthropic.CacheControl{Type: "ephemeral"},
		}
	}
	return fantasy.ProviderOptions{
		anthropic.Name: ephemeral(),
		bedrock.Name:   ephemeral(),
		vercel.Name:    ephemeral(),
	}
}
746
747func (a *sessionAgent) createUserMessage(ctx context.Context, call SessionAgentCall) (message.Message, error) {
748 parts := []message.ContentPart{message.TextContent{Text: call.Prompt}}
749 var attachmentParts []message.ContentPart
750 for _, attachment := range call.Attachments {
751 attachmentParts = append(attachmentParts, message.BinaryContent{Path: attachment.FilePath, MIMEType: attachment.MimeType, Data: attachment.Content})
752 }
753 parts = append(parts, attachmentParts...)
754 msg, err := a.messages.Create(ctx, call.SessionID, message.CreateMessageParams{
755 Role: message.User,
756 Parts: parts,
757 })
758 if err != nil {
759 return message.Message{}, fmt.Errorf("failed to create user message: %w", err)
760 }
761 return msg, nil
762}
763
// preparePrompt converts stored session messages into the history sent to the
// model, and converts non-text attachments into file parts (text attachments
// are skipped here).
//
// For agents that are not sub-agents, a todo-list system reminder is
// prepended as a user message. The following are left out of the history:
//   - messages with no parts;
//   - assistant messages with no text, reasoning or tool calls;
//   - tool results whose tool call ID does not appear in any assistant
//     message.
func (a *sessionAgent) preparePrompt(msgs []message.Message, attachments ...message.Attachment) ([]fantasy.Message, []fantasy.FilePart) {
	var history []fantasy.Message
	if !a.isSubAgent {
		history = append(history, fantasy.NewUserMessage(
			fmt.Sprintf("<system_reminder>%s</system_reminder>",
				`This is a reminder that your todo list is currently empty. DO NOT mention this to the user explicitly because they are already aware.
If you are working on tasks that would benefit from a todo list please use the "todos" tool to create one.
If not, please feel free to ignore. Again do not mention this message to the user.`,
			),
		))
	}
	// Collect all tool call IDs present in assistant messages.
	knownToolCallIDs := make(map[string]struct{})
	for _, m := range msgs {
		if m.Role == message.Assistant {
			for _, tc := range m.ToolCalls() {
				knownToolCallIDs[tc.ID] = struct{}{}
			}
		}
	}

	for _, m := range msgs {
		if len(m.Parts) == 0 {
			continue
		}
		// Assistant message without content or tool calls (cancelled before it
		// returned anything).
		if m.Role == message.Assistant && len(m.ToolCalls()) == 0 && m.Content().Text == "" && m.ReasoningContent().String() == "" {
			continue
		}
		if m.Role != message.Tool {
			history = append(history, m.ToAIMessage()...)
			continue
		}
		// Filter out tool results that have no matching tool call. An orphaned
		// result causes every subsequent API call to fail validation. Every
		// converted message is filtered, not only the first, so no results are
		// silently lost.
		for _, aiMsg := range m.ToAIMessage() {
			var validParts []fantasy.MessagePart
			for _, part := range aiMsg.Content {
				tr, ok := fantasy.AsMessagePart[fantasy.ToolResultPart](part)
				if !ok {
					validParts = append(validParts, part)
					continue
				}
				if _, known := knownToolCallIDs[tr.ToolCallID]; known {
					validParts = append(validParts, part)
				} else {
					slog.Warn("Dropping orphaned tool result with no matching tool call",
						"tool_call_id", tr.ToolCallID,
					)
				}
			}
			if len(validParts) > 0 {
				aiMsg.Content = validParts
				history = append(history, aiMsg)
			}
		}
	}

	var files []fantasy.FilePart
	for _, attachment := range attachments {
		if attachment.IsText() {
			continue
		}
		files = append(files, fantasy.FilePart{
			Filename:  attachment.FileName,
			Data:      attachment.Content,
			MediaType: attachment.MimeType,
		})
	}

	return history, files
}
840
// getSessionMessages lists the messages of sess. If the session has been
// summarized and its summary message is present, the returned history starts
// at that summary, which is relabelled as a user message. If the summary ID is
// set but not found, every message is returned.
func (a *sessionAgent) getSessionMessages(ctx context.Context, sess session.Session) ([]message.Message, error) {
	all, err := a.messages.List(ctx, sess.ID)
	if err != nil {
		return nil, fmt.Errorf("failed to list messages: %w", err)
	}
	if sess.SummaryMessageID == "" {
		return all, nil
	}

	for idx := range all {
		if all[idx].ID != sess.SummaryMessageID {
			continue
		}
		tail := all[idx:]
		tail[0].Role = message.User
		return tail, nil
	}
	return all, nil
}
862
// generateTitle generates a session title from the initial prompt. It stores
// the title on the session together with the token usage and cost of
// generating it.
//
// The small model is tried first, then the large model if the small model's
// stream fails. If both fail, or no response is returned, the session is
// renamed to DefaultSessionName. Failures are logged rather than returned.
func (a *sessionAgent) generateTitle(ctx context.Context, sessionID string, userPrompt string) {
	if userPrompt == "" {
		return
	}

	smallModel := a.smallModel.Get()
	largeModel := a.largeModel.Get()
	systemPromptPrefix := a.systemPromptPrefix.Get()

	// titleMaxTokens returns the output-token budget for a title request to
	// m. Titles are short, but reasoning models spend tokens thinking before
	// they answer, so they get their catalog default instead.
	//   - The budget is computed per model, so the large-model fallback is not
	//     limited by the small model's settings.
	//   - A missing (zero) default falls back to the short budget instead of
	//     being sent as 0.
	titleMaxTokens := func(m Model) int64 {
		const shortBudget int64 = 40
		if m.CatwalkCfg.CanReason {
			return cmp.Or(m.CatwalkCfg.DefaultMaxTokens, shortBudget)
		}
		return shortBudget
	}

	newAgent := func(m Model) fantasy.Agent {
		return fantasy.NewAgent(m.Model,
			fantasy.WithSystemPrompt(string(titlePrompt)+"\n /no_think"),
			fantasy.WithMaxOutputTokens(titleMaxTokens(m)),
			fantasy.WithUserAgent(userAgent),
		)
	}

	// useDefaultTitle renames the session to DefaultSessionName, logging (not
	// returning) any failure.
	useDefaultTitle := func() {
		if saveErr := a.sessions.Rename(ctx, sessionID, DefaultSessionName); saveErr != nil {
			slog.Error("Failed to save session title", "error", saveErr)
		}
	}

	streamCall := fantasy.AgentStreamCall{
		Prompt: fmt.Sprintf("Generate a concise title for the following content:\n\n%s\n <think>\n\n</think>", userPrompt),
		PrepareStep: func(callCtx context.Context, opts fantasy.PrepareStepFunctionOptions) (_ context.Context, prepared fantasy.PrepareStepResult, err error) {
			prepared.Messages = opts.Messages
			if systemPromptPrefix != "" {
				prepared.Messages = append([]fantasy.Message{
					fantasy.NewSystemMessage(systemPromptPrefix),
				}, prepared.Messages...)
			}
			return callCtx, prepared, nil
		},
	}

	// Use the small model to generate the title, falling back to the large
	// model if that fails.
	model := smallModel
	resp, err := newAgent(model).Stream(ctx, streamCall)
	if err == nil {
		slog.Debug("Generated title with small model")
	} else {
		slog.Error("Error generating title with small model; trying big model", "err", err)
		model = largeModel
		resp, err = newAgent(model).Stream(ctx, streamCall)
		if err != nil {
			// Neither model worked; use the default session name.
			slog.Error("Error generating title with large model", "err", err)
			useDefaultTitle()
			return
		}
		slog.Debug("Generated title with large model")
	}

	if resp == nil {
		// We didn't get a response, so there is nothing to build a title
		// from. Use the default session name.
		slog.Error("Response is nil; can't generate title")
		useDefaultTitle()
		return
	}

	// Clean up the title: collapse newlines, strip <think> blocks and stray
	// think tags, and fall back to the default if nothing is left.
	title := strings.ReplaceAll(resp.Response.Content.Text(), "\n", " ")
	title = thinkTagRegex.ReplaceAllString(title, "")
	title = orphanThinkTagRegex.ReplaceAllString(title, "")
	title = strings.TrimSpace(title)
	title = cmp.Or(title, DefaultSessionName)

	// Calculate usage and cost.
	var openrouterCost *float64
	for _, step := range resp.Steps {
		stepCost := a.openrouterCost(step.ProviderMetadata)
		if stepCost != nil {
			newCost := *stepCost
			if openrouterCost != nil {
				newCost += *openrouterCost
			}
			openrouterCost = &newCost
		}
	}

	// Price with the catalog rates of whichever model produced the title.
	modelConfig := model.CatwalkCfg
	cost := modelConfig.CostPer1MInCached/1e6*float64(resp.TotalUsage.CacheCreationTokens) +
		modelConfig.CostPer1MOutCached/1e6*float64(resp.TotalUsage.CacheReadTokens) +
		modelConfig.CostPer1MIn/1e6*float64(resp.TotalUsage.InputTokens) +
		modelConfig.CostPer1MOut/1e6*float64(resp.TotalUsage.OutputTokens)

	// Use override cost if available (e.g., from OpenRouter).
	if openrouterCost != nil {
		cost = *openrouterCost
	}

	promptTokens := resp.TotalUsage.InputTokens + resp.TotalUsage.CacheCreationTokens
	completionTokens := resp.TotalUsage.OutputTokens

	// Atomically update only title and usage fields to avoid overriding other
	// concurrent session updates.
	saveErr := a.sessions.UpdateTitleAndUsage(ctx, sessionID, title, promptTokens, completionTokens, cost)
	if saveErr != nil {
		slog.Error("Failed to save session title and usage", "error", saveErr)
		return
	}
}
983
984func (a *sessionAgent) openrouterCost(metadata fantasy.ProviderMetadata) *float64 {
985 openrouterMetadata, ok := metadata[openrouter.Name]
986 if !ok {
987 return nil
988 }
989
990 opts, ok := openrouterMetadata.(*openrouter.ProviderMetadata)
991 if !ok {
992 return nil
993 }
994 return &opts.Usage.Cost
995}
996
997func (a *sessionAgent) updateSessionUsage(model Model, session *session.Session, usage fantasy.Usage, overrideCost *float64) {
998 modelConfig := model.CatwalkCfg
999 cost := modelConfig.CostPer1MInCached/1e6*float64(usage.CacheCreationTokens) +
1000 modelConfig.CostPer1MOutCached/1e6*float64(usage.CacheReadTokens) +
1001 modelConfig.CostPer1MIn/1e6*float64(usage.InputTokens) +
1002 modelConfig.CostPer1MOut/1e6*float64(usage.OutputTokens)
1003
1004 a.eventTokensUsed(session.ID, model, usage, cost)
1005
1006 if overrideCost != nil {
1007 session.Cost += *overrideCost
1008 } else {
1009 session.Cost += cost
1010 }
1011
1012 session.CompletionTokens = usage.OutputTokens
1013 session.PromptTokens = usage.InputTokens + usage.CacheReadTokens
1014}
1015
1016func (a *sessionAgent) Cancel(sessionID string) {
1017 // Cancel regular requests. Don't use Take() here - we need the entry to
1018 // remain in activeRequests so IsBusy() returns true until the goroutine
1019 // fully completes (including error handling that may access the DB).
1020 // The defer in processRequest will clean up the entry.
1021 if cancel, ok := a.activeRequests.Get(sessionID); ok && cancel != nil {
1022 slog.Debug("Request cancellation initiated", "session_id", sessionID)
1023 cancel()
1024 }
1025
1026 // Also check for summarize requests.
1027 if cancel, ok := a.activeRequests.Get(sessionID + "-summarize"); ok && cancel != nil {
1028 slog.Debug("Summarize cancellation initiated", "session_id", sessionID)
1029 cancel()
1030 }
1031
1032 if a.QueuedPrompts(sessionID) > 0 {
1033 slog.Debug("Clearing queued prompts", "session_id", sessionID)
1034 a.messageQueue.Del(sessionID)
1035 }
1036}
1037
1038func (a *sessionAgent) ClearQueue(sessionID string) {
1039 if a.QueuedPrompts(sessionID) > 0 {
1040 slog.Debug("Clearing queued prompts", "session_id", sessionID)
1041 a.messageQueue.Del(sessionID)
1042 }
1043}
1044
1045func (a *sessionAgent) CancelAll() {
1046 if !a.IsBusy() {
1047 return
1048 }
1049 for key := range a.activeRequests.Seq2() {
1050 a.Cancel(key) // key is sessionID
1051 }
1052
1053 timeout := time.After(5 * time.Second)
1054 for a.IsBusy() {
1055 select {
1056 case <-timeout:
1057 return
1058 default:
1059 time.Sleep(200 * time.Millisecond)
1060 }
1061 }
1062}
1063
1064func (a *sessionAgent) IsBusy() bool {
1065 var busy bool
1066 for cancelFunc := range a.activeRequests.Seq() {
1067 if cancelFunc != nil {
1068 busy = true
1069 break
1070 }
1071 }
1072 return busy
1073}
1074
1075func (a *sessionAgent) IsSessionBusy(sessionID string) bool {
1076 _, busy := a.activeRequests.Get(sessionID)
1077 return busy
1078}
1079
1080func (a *sessionAgent) QueuedPrompts(sessionID string) int {
1081 l, ok := a.messageQueue.Get(sessionID)
1082 if !ok {
1083 return 0
1084 }
1085 return len(l)
1086}
1087
1088func (a *sessionAgent) QueuedPromptsList(sessionID string) []string {
1089 l, ok := a.messageQueue.Get(sessionID)
1090 if !ok {
1091 return nil
1092 }
1093 prompts := make([]string, len(l))
1094 for i, call := range l {
1095 prompts[i] = call.Prompt
1096 }
1097 return prompts
1098}
1099
1100func (a *sessionAgent) SetModels(large Model, small Model) {
1101 a.largeModel.Set(large)
1102 a.smallModel.Set(small)
1103}
1104
1105func (a *sessionAgent) SetTools(tools []fantasy.AgentTool) {
1106 a.tools.SetSlice(tools)
1107}
1108
1109func (a *sessionAgent) SetSystemPrompt(systemPrompt string) {
1110 a.systemPrompt.Set(systemPrompt)
1111}
1112
1113func (a *sessionAgent) Model() Model {
1114 return a.largeModel.Get()
1115}
1116
1117// convertToToolResult converts a fantasy tool result to a message tool result.
1118func (a *sessionAgent) convertToToolResult(result fantasy.ToolResultContent) message.ToolResult {
1119 baseResult := message.ToolResult{
1120 ToolCallID: result.ToolCallID,
1121 Name: result.ToolName,
1122 Metadata: result.ClientMetadata,
1123 }
1124
1125 switch result.Result.GetType() {
1126 case fantasy.ToolResultContentTypeText:
1127 if r, ok := fantasy.AsToolResultOutputType[fantasy.ToolResultOutputContentText](result.Result); ok {
1128 baseResult.Content = r.Text
1129 }
1130 case fantasy.ToolResultContentTypeError:
1131 if r, ok := fantasy.AsToolResultOutputType[fantasy.ToolResultOutputContentError](result.Result); ok {
1132 baseResult.Content = r.Error.Error()
1133 baseResult.IsError = true
1134 }
1135 case fantasy.ToolResultContentTypeMedia:
1136 if r, ok := fantasy.AsToolResultOutputType[fantasy.ToolResultOutputContentMedia](result.Result); ok {
1137 if !stringext.IsValidBase64(r.Data) {
1138 slog.Warn("Tool returned media with invalid base64 data, discarding image",
1139 "tool", result.ToolName,
1140 "tool_call_id", result.ToolCallID,
1141 )
1142 baseResult.Content = "Tool returned image data with invalid encoding"
1143 baseResult.IsError = true
1144 } else {
1145 content := r.Text
1146 if content == "" {
1147 content = fmt.Sprintf("Loaded %s content", r.MediaType)
1148 }
1149 baseResult.Content = content
1150 baseResult.Data = r.Data
1151 baseResult.MIMEType = r.MediaType
1152 }
1153 }
1154 }
1155
1156 return baseResult
1157}
1158
1159// workaroundProviderMediaLimitations converts media content in tool results to
1160// user messages for providers that don't natively support images in tool results.
1161//
1162// Problem: OpenAI, Google, OpenRouter, and other OpenAI-compatible providers
1163// don't support sending images/media in tool result messages - they only accept
1164// text in tool results. However, they DO support images in user messages.
1165//
1166// If we send media in tool results to these providers, the API returns an error.
1167//
1168// Solution: For these providers, we:
1169// 1. Replace the media in the tool result with a text placeholder
1170// 2. Inject a user message immediately after with the image as a file attachment
1171// 3. This maintains the tool execution flow while working around API limitations
1172//
1173// Anthropic and Bedrock support images natively in tool results, so we skip
1174// this workaround for them.
1175//
1176// Example transformation:
1177//
1178// BEFORE: [tool result: image data]
1179// AFTER: [tool result: "Image loaded - see attached"], [user: image attachment]
1180func (a *sessionAgent) workaroundProviderMediaLimitations(messages []fantasy.Message, largeModel Model) []fantasy.Message {
1181 providerSupportsMedia := largeModel.ModelCfg.Provider == string(catwalk.InferenceProviderAnthropic) ||
1182 largeModel.ModelCfg.Provider == string(catwalk.InferenceProviderBedrock)
1183
1184 if providerSupportsMedia {
1185 return messages
1186 }
1187
1188 convertedMessages := make([]fantasy.Message, 0, len(messages))
1189
1190 for _, msg := range messages {
1191 if msg.Role != fantasy.MessageRoleTool {
1192 convertedMessages = append(convertedMessages, msg)
1193 continue
1194 }
1195
1196 textParts := make([]fantasy.MessagePart, 0, len(msg.Content))
1197 var mediaFiles []fantasy.FilePart
1198
1199 for _, part := range msg.Content {
1200 toolResult, ok := fantasy.AsMessagePart[fantasy.ToolResultPart](part)
1201 if !ok {
1202 textParts = append(textParts, part)
1203 continue
1204 }
1205
1206 if media, ok := fantasy.AsToolResultOutputType[fantasy.ToolResultOutputContentMedia](toolResult.Output); ok {
1207 decoded, err := base64.StdEncoding.DecodeString(media.Data)
1208 if err != nil {
1209 slog.Warn("Failed to decode media data", "error", err)
1210 textParts = append(textParts, part)
1211 continue
1212 }
1213
1214 mediaFiles = append(mediaFiles, fantasy.FilePart{
1215 Data: decoded,
1216 MediaType: media.MediaType,
1217 Filename: fmt.Sprintf("tool-result-%s", toolResult.ToolCallID),
1218 })
1219
1220 textParts = append(textParts, fantasy.ToolResultPart{
1221 ToolCallID: toolResult.ToolCallID,
1222 Output: fantasy.ToolResultOutputContentText{
1223 Text: "[Image/media content loaded - see attached file]",
1224 },
1225 ProviderOptions: toolResult.ProviderOptions,
1226 })
1227 } else {
1228 textParts = append(textParts, part)
1229 }
1230 }
1231
1232 convertedMessages = append(convertedMessages, fantasy.Message{
1233 Role: fantasy.MessageRoleTool,
1234 Content: textParts,
1235 })
1236
1237 if len(mediaFiles) > 0 {
1238 convertedMessages = append(convertedMessages, fantasy.NewUserMessage(
1239 "Here is the media content from the tool result:",
1240 mediaFiles...,
1241 ))
1242 }
1243 }
1244
1245 return convertedMessages
1246}
1247
1248// buildSummaryPrompt constructs the prompt text for session summarization.
1249func buildSummaryPrompt(todos []session.Todo) string {
1250 var sb strings.Builder
1251 sb.WriteString("Provide a detailed summary of our conversation above.")
1252 if len(todos) > 0 {
1253 sb.WriteString("\n\n## Current Todo List\n\n")
1254 for _, t := range todos {
1255 fmt.Fprintf(&sb, "- [%s] %s\n", t.Status, t.Content)
1256 }
1257 sb.WriteString("\nInclude these tasks and their statuses in your summary. ")
1258 sb.WriteString("Instruct the resuming assistant to use the `todos` tool to continue tracking progress on these tasks.")
1259 }
1260 return sb.String()
1261}