1// Package agent is the core orchestration layer for Crush AI agents.
2//
3// It provides session-based AI agent functionality for managing
4// conversations, tool execution, and message handling. It coordinates
5// interactions between language models, messages, sessions, and tools while
6// handling features like automatic summarization, queuing, and token
7// management.
8package agent
9
10import (
11 "cmp"
12 "context"
13 _ "embed"
14 "encoding/base64"
15 "errors"
16 "fmt"
17 "log/slog"
18 "os"
19 "regexp"
20 "strconv"
21 "strings"
22 "sync"
23 "time"
24
25 "charm.land/catwalk/pkg/catwalk"
26 "charm.land/fantasy"
27 "charm.land/fantasy/providers/anthropic"
28 "charm.land/fantasy/providers/bedrock"
29 "charm.land/fantasy/providers/google"
30 "charm.land/fantasy/providers/openai"
31 "charm.land/fantasy/providers/openrouter"
32 "charm.land/fantasy/providers/vercel"
33 "charm.land/lipgloss/v2"
34 "github.com/charmbracelet/crush/internal/agent/hyper"
35 "github.com/charmbracelet/crush/internal/agent/notify"
36 "github.com/charmbracelet/crush/internal/agent/tools"
37 "github.com/charmbracelet/crush/internal/agent/tools/mcp"
38 "github.com/charmbracelet/crush/internal/config"
39 "github.com/charmbracelet/crush/internal/csync"
40 "github.com/charmbracelet/crush/internal/message"
41 "github.com/charmbracelet/crush/internal/permission"
42 "github.com/charmbracelet/crush/internal/pubsub"
43 "github.com/charmbracelet/crush/internal/session"
44 "github.com/charmbracelet/crush/internal/stringext"
45 "github.com/charmbracelet/crush/internal/version"
46 "github.com/charmbracelet/x/exp/charmtone"
47)
48
const (
	// DefaultSessionName is the title saved for a session when title
	// generation fails or yields an empty title.
	DefaultSessionName = "Untitled Session"

	// Constants for auto-summarization thresholds.
	//
	// When the model's context window is larger than
	// largeContextWindowThreshold tokens, summarization is triggered once the
	// remaining tokens drop to largeContextWindowBuffer or fewer. For smaller
	// context windows the trigger point is smallContextWindowRatio of the
	// window size instead.
	largeContextWindowThreshold = 200_000
	largeContextWindowBuffer    = 20_000
	smallContextWindowRatio     = 0.2
)
57
// titlePrompt is the embedded system prompt used when generating a session
// title.
//
//go:embed templates/title.md
var titlePrompt []byte

// summaryPrompt is the embedded system prompt used when summarizing a
// session.
//
//go:embed templates/summary.md
var summaryPrompt []byte
63
// thinkTagRegex is used to remove <think> tags (and whatever they enclose)
// from generated titles. The match is non-greedy, and "." does not match
// newlines, so generateTitle replaces newlines with spaces before applying
// it.
var thinkTagRegex = regexp.MustCompile(`<think>.*?</think>`)
66
// SessionAgentCall describes a single prompt submitted to a SessionAgent.
type SessionAgentCall struct {
	// SessionID identifies the target session. Run rejects an empty value
	// with ErrSessionMissing.
	SessionID string
	// Prompt is the user's text. Run rejects a call that has neither a
	// prompt nor a text attachment with ErrEmptyPrompt.
	Prompt string
	// ProviderOptions is forwarded to the model stream call, and to the
	// summarization call when auto-summarization is triggered.
	ProviderOptions fantasy.ProviderOptions
	// Attachments are stored on the user message; non-text attachments are
	// also sent to the model as files.
	Attachments []message.Attachment
	// MaxOutputTokens is forwarded to the model stream call.
	MaxOutputTokens int64
	// Sampling parameters, forwarded unchanged to the model stream call.
	Temperature      *float64
	TopP             *float64
	TopK             *int64
	FrequencyPenalty *float64
	PresencePenalty  *float64
	// NonInteractive suppresses the "agent finished" notification that Run
	// otherwise publishes on success.
	NonInteractive bool
}
80
// SessionAgent is the interface through which callers drive a session-based
// agent: submitting prompts, cancelling work, managing the per-session prompt
// queue, and summarizing a session.
type SessionAgent interface {
	// Run submits a prompt to a session. If the session is already busy the
	// call is queued and Run returns a nil result and a nil error.
	Run(context.Context, SessionAgentCall) (*fantasy.AgentResult, error)
	SetModels(large Model, small Model)
	SetTools(tools []fantasy.AgentTool)
	SetSystemPrompt(systemPrompt string)
	// Cancel cancels the active request for the session, if any, and drops
	// the session's queued prompts.
	Cancel(sessionID string)
	CancelAll()
	IsSessionBusy(sessionID string) bool
	IsBusy() bool
	QueuedPrompts(sessionID string) int
	QueuedPromptsList(sessionID string) []string
	// ClearQueue drops the session's queued prompts without cancelling the
	// active request.
	ClearQueue(sessionID string)
	// Summarize replaces the session's working history with a
	// model-generated summary message. It returns ErrSessionBusy when the
	// session has an active request.
	Summarize(context.Context, string, fantasy.ProviderOptions) error
	Model() Model
}
96
// Model bundles a language model with the metadata the agent needs about it.
type Model struct {
	// Model is the language model handed to the fantasy agent.
	Model fantasy.LanguageModel
	// CatwalkCfg supplies the model's context window, per-token pricing,
	// display name and capability flags.
	CatwalkCfg catwalk.Model
	// ModelCfg supplies the model and provider identifiers recorded on
	// assistant messages.
	ModelCfg config.SelectedModel
}
102
// sessionAgent is the SessionAgent implementation returned by
// NewSessionAgent.
type sessionAgent struct {
	// Values that can change while requests are in flight are kept in csync
	// containers; Run, Summarize and generateTitle take a copy at the start
	// of each request.
	largeModel         *csync.Value[Model]
	smallModel         *csync.Value[Model]
	systemPromptPrefix *csync.Value[string]
	systemPrompt       *csync.Value[string]
	tools              *csync.Slice[fantasy.AgentTool]

	// isSubAgent disables the todo-list reminder that preparePrompt
	// otherwise prepends to the history.
	isSubAgent bool
	sessions   session.Service
	messages   message.Service
	// disableAutoSummarize turns off the context-window stop condition in
	// Run.
	disableAutoSummarize bool
	isYolo               bool
	// notify, when non-nil, receives a notification each time an interactive
	// Run finishes successfully.
	notify pubsub.Publisher[notify.Notification]

	// messageQueue holds calls that arrived while their session was busy,
	// keyed by session ID.
	messageQueue *csync.Map[string, []SessionAgentCall]
	// activeRequests holds the cancel function of each in-flight request,
	// keyed by session ID.
	activeRequests *csync.Map[string, context.CancelFunc]
}
120
// SessionAgentOptions carries the configuration and dependencies consumed by
// NewSessionAgent. Each field is copied into the corresponding field of the
// agent being constructed.
type SessionAgentOptions struct {
	LargeModel           Model
	SmallModel           Model
	SystemPromptPrefix   string
	SystemPrompt         string
	IsSubAgent           bool
	DisableAutoSummarize bool
	IsYolo               bool
	Sessions             session.Service
	Messages             message.Service
	Tools                []fantasy.AgentTool
	Notify               pubsub.Publisher[notify.Notification]
}
134
135func NewSessionAgent(
136 opts SessionAgentOptions,
137) SessionAgent {
138 return &sessionAgent{
139 largeModel: csync.NewValue(opts.LargeModel),
140 smallModel: csync.NewValue(opts.SmallModel),
141 systemPromptPrefix: csync.NewValue(opts.SystemPromptPrefix),
142 systemPrompt: csync.NewValue(opts.SystemPrompt),
143 isSubAgent: opts.IsSubAgent,
144 sessions: opts.Sessions,
145 messages: opts.Messages,
146 disableAutoSummarize: opts.DisableAutoSummarize,
147 tools: csync.NewSliceFrom(opts.Tools),
148 isYolo: opts.IsYolo,
149 notify: opts.Notify,
150 messageQueue: csync.NewMap[string, []SessionAgentCall](),
151 activeRequests: csync.NewMap[string, context.CancelFunc](),
152 }
153}
154
// Run submits call to its session and streams the large model's response into
// the message store.
//
// It returns ErrEmptyPrompt when the call has neither a prompt nor a text
// attachment, and ErrSessionMissing when no session ID is set. If the session
// already has an active request, the call is appended to the session's queue
// and Run returns (nil, nil) immediately.
//
// Otherwise Run stores the user message, registers a cancel function for the
// session, and streams the response. Calls queued in the meantime are folded
// into the conversation at the start of the next step. After a successful
// stream, Run publishes a "finished" notification for interactive calls,
// summarizes the session if the context-window stop condition fired, and then
// recurses into the first queued call, if there is one.
//
// When the stream fails after an assistant message was created, Run closes
// any unfinished tool calls, stores an error tool result for every tool call
// that has none, records a finish reason describing the failure, and returns
// a nil result together with the stream error.
func (a *sessionAgent) Run(ctx context.Context, call SessionAgentCall) (*fantasy.AgentResult, error) {
	if call.Prompt == "" && !message.ContainsTextAttachment(call.Attachments) {
		return nil, ErrEmptyPrompt
	}
	if call.SessionID == "" {
		return nil, ErrSessionMissing
	}

	// Queue the message if busy
	if a.IsSessionBusy(call.SessionID) {
		existing, ok := a.messageQueue.Get(call.SessionID)
		if !ok {
			existing = []SessionAgentCall{}
		}
		existing = append(existing, call)
		a.messageQueue.Set(call.SessionID, existing)
		return nil, nil
	}

	// Copy mutable fields under lock to avoid races with SetTools/SetModels.
	agentTools := a.tools.Copy()
	largeModel := a.largeModel.Get()
	systemPrompt := a.systemPrompt.Get()
	promptPrefix := a.systemPromptPrefix.Get()
	var instructions strings.Builder

	// Collect the instructions advertised by every connected MCP server.
	for _, server := range mcp.GetStates() {
		if server.State != mcp.StateConnected {
			continue
		}
		if s := server.Client.InitializeResult().Instructions; s != "" {
			instructions.WriteString(s)
			instructions.WriteString("\n\n")
		}
	}

	// Append the collected MCP instructions to the system prompt.
	if s := instructions.String(); s != "" {
		systemPrompt += "\n\n<mcp-instructions>\n" + s + "\n</mcp-instructions>"
	}

	if len(agentTools) > 0 {
		// Add Anthropic caching to the last tool.
		agentTools[len(agentTools)-1].SetProviderOptions(a.getCacheControlOptions())
	}

	agent := fantasy.NewAgent(
		largeModel.Model,
		fantasy.WithSystemPrompt(systemPrompt),
		fantasy.WithTools(agentTools...),
		fantasy.WithUserAgent("Charm Crush/"+version.Version),
	)

	// sessionLock serializes the session read-modify-write in OnStepFinish.
	sessionLock := sync.Mutex{}
	currentSession, err := a.sessions.Get(ctx, call.SessionID)
	if err != nil {
		return nil, fmt.Errorf("failed to get session: %w", err)
	}

	msgs, err := a.getSessionMessages(ctx, currentSession)
	if err != nil {
		return nil, fmt.Errorf("failed to get session messages: %w", err)
	}

	var wg sync.WaitGroup
	// Generate title if first message.
	if len(msgs) == 0 {
		titleCtx := ctx // Copy to avoid race with ctx reassignment below.
		wg.Go(func() {
			a.generateTitle(titleCtx, call.SessionID, call.Prompt)
		})
	}
	// Do not return before title generation has finished.
	defer wg.Wait()

	// Add the user message to the session.
	_, err = a.createUserMessage(ctx, call)
	if err != nil {
		return nil, err
	}

	// Add the session to the context.
	ctx = context.WithValue(ctx, tools.SessionIDContextKey, call.SessionID)

	// Register the cancel function so Cancel can stop this request; the
	// entry is removed again when Run returns.
	genCtx, cancel := context.WithCancel(ctx)
	a.activeRequests.Set(call.SessionID, cancel)

	defer cancel()
	defer a.activeRequests.Del(call.SessionID)

	history, files := a.preparePrompt(msgs, call.Attachments...)

	startTime := time.Now()
	a.eventPromptSent(call.SessionID)

	// currentAssistant is set by PrepareStep at the start of every step and
	// read by the streaming callbacks below. shouldSummarize is set by the
	// context-window stop condition.
	var currentAssistant *message.Message
	var shouldSummarize bool
	result, err := agent.Stream(genCtx, fantasy.AgentStreamCall{
		Prompt:           message.PromptWithTextAttachments(call.Prompt, call.Attachments),
		Files:            files,
		Messages:         history,
		ProviderOptions:  call.ProviderOptions,
		MaxOutputTokens:  &call.MaxOutputTokens,
		TopP:             call.TopP,
		Temperature:      call.Temperature,
		PresencePenalty:  call.PresencePenalty,
		TopK:             call.TopK,
		FrequencyPenalty: call.FrequencyPenalty,
		PrepareStep: func(callContext context.Context, options fantasy.PrepareStepFunctionOptions) (_ context.Context, prepared fantasy.PrepareStepResult, err error) {
			// Start from the step's messages with all provider options
			// cleared; cache control is re-applied below.
			prepared.Messages = options.Messages
			for i := range prepared.Messages {
				prepared.Messages[i].ProviderOptions = nil
			}

			// Use latest tools (updated by SetTools when MCP tools change).
			prepared.Tools = a.tools.Copy()

			// Drain the session's queue: every queued call is stored as a
			// user message and appended to this step's messages.
			queuedCalls, _ := a.messageQueue.Get(call.SessionID)
			a.messageQueue.Del(call.SessionID)
			for _, queued := range queuedCalls {
				userMessage, createErr := a.createUserMessage(callContext, queued)
				if createErr != nil {
					return callContext, prepared, createErr
				}
				prepared.Messages = append(prepared.Messages, userMessage.ToAIMessage()...)
			}

			prepared.Messages = a.workaroundProviderMediaLimitations(prepared.Messages, largeModel)

			lastSystemRoleInx := 0
			systemMessageUpdated := false
			for i, msg := range prepared.Messages {
				// Track the most recent system message; when the first
				// non-system message is reached, add cache control to that
				// system message (index 0 if none was seen).
				if msg.Role == fantasy.MessageRoleSystem {
					lastSystemRoleInx = i
				} else if !systemMessageUpdated {
					prepared.Messages[lastSystemRoleInx].ProviderOptions = a.getCacheControlOptions()
					systemMessageUpdated = true
				}
				// Then add cache control to the last 2 messages.
				if i > len(prepared.Messages)-3 {
					prepared.Messages[i].ProviderOptions = a.getCacheControlOptions()
				}
			}

			// The prompt prefix, when set, is sent as the very first system
			// message.
			if promptPrefix != "" {
				prepared.Messages = append([]fantasy.Message{fantasy.NewSystemMessage(promptPrefix)}, prepared.Messages...)
			}

			// Create the empty assistant message that the streaming
			// callbacks fill in for this step.
			var assistantMsg message.Message
			assistantMsg, err = a.messages.Create(callContext, call.SessionID, message.CreateMessageParams{
				Role:     message.Assistant,
				Parts:    []message.ContentPart{},
				Model:    largeModel.ModelCfg.Model,
				Provider: largeModel.ModelCfg.Provider,
			})
			if err != nil {
				return callContext, prepared, err
			}
			// Attach the assistant message ID and model details to the
			// context returned for this step.
			callContext = context.WithValue(callContext, tools.MessageIDContextKey, assistantMsg.ID)
			callContext = context.WithValue(callContext, tools.SupportsImagesContextKey, largeModel.CatwalkCfg.SupportsImages)
			callContext = context.WithValue(callContext, tools.ModelNameContextKey, largeModel.CatwalkCfg.Name)
			currentAssistant = &assistantMsg
			return callContext, prepared, err
		},
		OnReasoningStart: func(id string, reasoning fantasy.ReasoningContent) error {
			currentAssistant.AppendReasoningContent(reasoning.Text)
			return a.messages.Update(genCtx, *currentAssistant)
		},
		OnReasoningDelta: func(id string, text string) error {
			currentAssistant.AppendReasoningContent(text)
			return a.messages.Update(genCtx, *currentAssistant)
		},
		OnReasoningEnd: func(id string, reasoning fantasy.ReasoningContent) error {
			// handle anthropic signature
			if anthropicData, ok := reasoning.ProviderMetadata[anthropic.Name]; ok {
				if reasoning, ok := anthropicData.(*anthropic.ReasoningOptionMetadata); ok {
					currentAssistant.AppendReasoningSignature(reasoning.Signature)
				}
			}
			// Handle the Google thought signature.
			if googleData, ok := reasoning.ProviderMetadata[google.Name]; ok {
				if reasoning, ok := googleData.(*google.ReasoningMetadata); ok {
					currentAssistant.AppendThoughtSignature(reasoning.Signature, reasoning.ToolID)
				}
			}
			// Handle OpenAI Responses reasoning metadata.
			if openaiData, ok := reasoning.ProviderMetadata[openai.Name]; ok {
				if reasoning, ok := openaiData.(*openai.ResponsesReasoningMetadata); ok {
					currentAssistant.SetReasoningResponsesData(reasoning)
				}
			}
			currentAssistant.FinishThinking()
			return a.messages.Update(genCtx, *currentAssistant)
		},
		OnTextDelta: func(id string, text string) error {
			// Strip leading newline from initial text content. This is
			// particularly important in non-interactive mode where leading
			// newlines are very visible.
			if len(currentAssistant.Parts) == 0 {
				text = strings.TrimPrefix(text, "\n")
			}

			currentAssistant.AppendContent(text)
			return a.messages.Update(genCtx, *currentAssistant)
		},
		OnToolInputStart: func(id string, toolName string) error {
			// Record the tool call as unfinished; OnToolCall adds it again
			// with its input once the call is complete.
			toolCall := message.ToolCall{
				ID:               id,
				Name:             toolName,
				ProviderExecuted: false,
				Finished:         false,
			}
			currentAssistant.AddToolCall(toolCall)
			return a.messages.Update(genCtx, *currentAssistant)
		},
		OnRetry: func(err *fantasy.ProviderError, delay time.Duration) {
			// TODO: implement
		},
		OnToolCall: func(tc fantasy.ToolCallContent) error {
			toolCall := message.ToolCall{
				ID:               tc.ToolCallID,
				Name:             tc.ToolName,
				Input:            tc.Input,
				ProviderExecuted: false,
				Finished:         true,
			}
			currentAssistant.AddToolCall(toolCall)
			return a.messages.Update(genCtx, *currentAssistant)
		},
		OnToolResult: func(result fantasy.ToolResultContent) error {
			// Each tool result is stored as its own tool-role message.
			toolResult := a.convertToToolResult(result)
			_, createMsgErr := a.messages.Create(genCtx, currentAssistant.SessionID, message.CreateMessageParams{
				Role: message.Tool,
				Parts: []message.ContentPart{
					toolResult,
				},
			})
			return createMsgErr
		},
		OnStepFinish: func(stepResult fantasy.StepResult) error {
			// Map the provider's finish reason onto the message finish
			// reason; anything unrecognized stays FinishReasonUnknown.
			finishReason := message.FinishReasonUnknown
			switch stepResult.FinishReason {
			case fantasy.FinishReasonLength:
				finishReason = message.FinishReasonMaxTokens
			case fantasy.FinishReasonStop:
				finishReason = message.FinishReasonEndTurn
			case fantasy.FinishReasonToolCalls:
				finishReason = message.FinishReasonToolUse
			}
			currentAssistant.AddFinish(finishReason, "", "")
			sessionLock.Lock()
			defer sessionLock.Unlock()

			// Re-read the session before updating usage so the save is based
			// on the stored session rather than the copy loaded at the start
			// of Run.
			updatedSession, getSessionErr := a.sessions.Get(ctx, call.SessionID)
			if getSessionErr != nil {
				return getSessionErr
			}
			a.updateSessionUsage(largeModel, &updatedSession, stepResult.Usage, a.openrouterCost(stepResult.ProviderMetadata))
			_, sessionErr := a.sessions.Save(ctx, updatedSession)
			if sessionErr != nil {
				return sessionErr
			}
			currentSession = updatedSession
			return a.messages.Update(genCtx, *currentAssistant)
		},
		StopWhen: []fantasy.StopCondition{
			// Stop, and request a summary, when the remaining context window
			// falls to the threshold and auto-summarization is enabled.
			func(_ []fantasy.StepResult) bool {
				cw := int64(largeModel.CatwalkCfg.ContextWindow)
				tokens := currentSession.CompletionTokens + currentSession.PromptTokens
				remaining := cw - tokens
				var threshold int64
				if cw > largeContextWindowThreshold {
					threshold = largeContextWindowBuffer
				} else {
					threshold = int64(float64(cw) * smallContextWindowRatio)
				}
				if (remaining <= threshold) && !a.disableAutoSummarize {
					shouldSummarize = true
					return true
				}
				return false
			},
			// Stop when the loop detector reports repeated tool calls.
			func(steps []fantasy.StepResult) bool {
				return hasRepeatedToolCalls(steps, loopDetectionWindowSize, loopDetectionMaxRepeats)
			},
		},
	})

	a.eventPromptResponded(call.SessionID, time.Since(startTime).Truncate(time.Second))

	if err != nil {
		isCancelErr := errors.Is(err, context.Canceled)
		isPermissionErr := errors.Is(err, permission.ErrorPermissionDenied)
		// No assistant message was created, so there is nothing to clean up.
		if currentAssistant == nil {
			return result, err
		}
		// Ensure we finish thinking on error to close the reasoning state.
		currentAssistant.FinishThinking()
		toolCalls := currentAssistant.ToolCalls()
		// INFO: we use the parent context here because the genCtx has been cancelled.
		msgs, createErr := a.messages.List(ctx, currentAssistant.SessionID)
		if createErr != nil {
			return nil, createErr
		}
		for _, tc := range toolCalls {
			// Close tool calls whose input never finished streaming, using
			// an empty JSON object as their input.
			if !tc.Finished {
				tc.Finished = true
				tc.Input = "{}"
				currentAssistant.AddToolCall(tc)
				updateErr := a.messages.Update(ctx, *currentAssistant)
				if updateErr != nil {
					return nil, updateErr
				}
			}

			// Look for an existing tool result for this tool call.
			found := false
			for _, msg := range msgs {
				if msg.Role == message.Tool {
					for _, tr := range msg.ToolResults() {
						if tr.ToolCallID == tc.ID {
							found = true
							break
						}
					}
				}
				if found {
					break
				}
			}
			if found {
				continue
			}
			// No result was stored, so store an error result explaining why
			// the tool did not complete.
			content := "There was an error while executing the tool"
			if isCancelErr {
				content = "Tool execution canceled by user"
			} else if isPermissionErr {
				content = "User denied permission"
			}
			toolResult := message.ToolResult{
				ToolCallID: tc.ID,
				Name:       tc.Name,
				Content:    content,
				IsError:    true,
			}
			_, createErr = a.messages.Create(ctx, currentAssistant.SessionID, message.CreateMessageParams{
				Role: message.Tool,
				Parts: []message.ContentPart{
					toolResult,
				},
			})
			if createErr != nil {
				return nil, createErr
			}
		}
		// Record a finish reason, with a title and details, that matches
		// the kind of failure.
		var fantasyErr *fantasy.Error
		var providerErr *fantasy.ProviderError
		const defaultTitle = "Provider Error"
		linkStyle := lipgloss.NewStyle().Foreground(charmtone.Guac).Underline(true)
		if isCancelErr {
			currentAssistant.AddFinish(message.FinishReasonCanceled, "User canceled request", "")
		} else if isPermissionErr {
			currentAssistant.AddFinish(message.FinishReasonPermissionDenied, "User denied permission", "")
		} else if errors.Is(err, hyper.ErrNoCredits) {
			url := hyper.BaseURL()
			link := linkStyle.Hyperlink(url, "id=hyper").Render(url)
			currentAssistant.AddFinish(message.FinishReasonError, "No credits", "You're out of credits. Add more at "+link)
		} else if errors.As(err, &providerErr) {
			// This exact provider message is reported as a Copilot model
			// that still has to be enabled.
			if providerErr.Message == "The requested model is not supported." {
				url := "https://github.com/settings/copilot/features"
				link := linkStyle.Hyperlink(url, "id=copilot").Render(url)
				currentAssistant.AddFinish(
					message.FinishReasonError,
					"Copilot model not enabled",
					fmt.Sprintf("%q is not enabled in Copilot. Go to the following page to enable it. Then, wait 5 minutes before trying again. %s", largeModel.CatwalkCfg.Name, link),
				)
			} else {
				currentAssistant.AddFinish(message.FinishReasonError, cmp.Or(stringext.Capitalize(providerErr.Title), defaultTitle), providerErr.Message)
			}
		} else if errors.As(err, &fantasyErr) {
			currentAssistant.AddFinish(message.FinishReasonError, cmp.Or(stringext.Capitalize(fantasyErr.Title), defaultTitle), fantasyErr.Message)
		} else {
			currentAssistant.AddFinish(message.FinishReasonError, defaultTitle, err.Error())
		}
		// Note: we use the parent context here because the genCtx has been
		// cancelled.
		updateErr := a.messages.Update(ctx, *currentAssistant)
		if updateErr != nil {
			return nil, updateErr
		}
		return nil, err
	}

	// Send notification that agent has finished its turn (skip for
	// nested/non-interactive sessions).
	if !call.NonInteractive && a.notify != nil {
		a.notify.Publish(pubsub.CreatedEvent, notify.Notification{
			SessionID:    call.SessionID,
			SessionTitle: currentSession.Title,
			Type:         notify.TypeAgentFinished,
		})
	}

	if shouldSummarize {
		// Summarize returns ErrSessionBusy for a busy session, so remove
		// this request's entry first.
		a.activeRequests.Del(call.SessionID)
		if summarizeErr := a.Summarize(genCtx, call.SessionID, call.ProviderOptions); summarizeErr != nil {
			return nil, summarizeErr
		}
		// If the agent wasn't done...
		if len(currentAssistant.ToolCalls()) > 0 {
			// ...queue the original request again, wrapped in a note saying
			// that the previous run was interrupted.
			existing, ok := a.messageQueue.Get(call.SessionID)
			if !ok {
				existing = []SessionAgentCall{}
			}
			call.Prompt = fmt.Sprintf("The previous session was interrupted because it got too long, the initial user request was: `%s`", call.Prompt)
			existing = append(existing, call)
			a.messageQueue.Set(call.SessionID, existing)
		}
	}

	// Release active request before processing queued messages.
	a.activeRequests.Del(call.SessionID)
	cancel()

	queuedMessages, ok := a.messageQueue.Get(call.SessionID)
	if !ok || len(queuedMessages) == 0 {
		return result, err
	}
	// There are queued messages restart the loop.
	firstQueuedMessage := queuedMessages[0]
	a.messageQueue.Set(call.SessionID, queuedMessages[1:])
	return a.Run(ctx, firstQueuedMessage)
}
584
585func (a *sessionAgent) Summarize(ctx context.Context, sessionID string, opts fantasy.ProviderOptions) error {
586 if a.IsSessionBusy(sessionID) {
587 return ErrSessionBusy
588 }
589
590 // Copy mutable fields under lock to avoid races with SetModels.
591 largeModel := a.largeModel.Get()
592 systemPromptPrefix := a.systemPromptPrefix.Get()
593
594 currentSession, err := a.sessions.Get(ctx, sessionID)
595 if err != nil {
596 return fmt.Errorf("failed to get session: %w", err)
597 }
598 msgs, err := a.getSessionMessages(ctx, currentSession)
599 if err != nil {
600 return err
601 }
602 if len(msgs) == 0 {
603 // Nothing to summarize.
604 return nil
605 }
606
607 aiMsgs, _ := a.preparePrompt(msgs)
608
609 genCtx, cancel := context.WithCancel(ctx)
610 a.activeRequests.Set(sessionID, cancel)
611 defer a.activeRequests.Del(sessionID)
612 defer cancel()
613
614 agent := fantasy.NewAgent(largeModel.Model,
615 fantasy.WithSystemPrompt(string(summaryPrompt)),
616 fantasy.WithUserAgent("Charm Crush/"+version.Version),
617 )
618 summaryMessage, err := a.messages.Create(ctx, sessionID, message.CreateMessageParams{
619 Role: message.Assistant,
620 Model: largeModel.Model.Model(),
621 Provider: largeModel.Model.Provider(),
622 IsSummaryMessage: true,
623 })
624 if err != nil {
625 return err
626 }
627
628 summaryPromptText := buildSummaryPrompt(currentSession.Todos)
629
630 resp, err := agent.Stream(genCtx, fantasy.AgentStreamCall{
631 Prompt: summaryPromptText,
632 Messages: aiMsgs,
633 ProviderOptions: opts,
634 PrepareStep: func(callContext context.Context, options fantasy.PrepareStepFunctionOptions) (_ context.Context, prepared fantasy.PrepareStepResult, err error) {
635 prepared.Messages = options.Messages
636 if systemPromptPrefix != "" {
637 prepared.Messages = append([]fantasy.Message{fantasy.NewSystemMessage(systemPromptPrefix)}, prepared.Messages...)
638 }
639 return callContext, prepared, nil
640 },
641 OnReasoningDelta: func(id string, text string) error {
642 summaryMessage.AppendReasoningContent(text)
643 return a.messages.Update(genCtx, summaryMessage)
644 },
645 OnReasoningEnd: func(id string, reasoning fantasy.ReasoningContent) error {
646 // Handle anthropic signature.
647 if anthropicData, ok := reasoning.ProviderMetadata["anthropic"]; ok {
648 if signature, ok := anthropicData.(*anthropic.ReasoningOptionMetadata); ok && signature.Signature != "" {
649 summaryMessage.AppendReasoningSignature(signature.Signature)
650 }
651 }
652 summaryMessage.FinishThinking()
653 return a.messages.Update(genCtx, summaryMessage)
654 },
655 OnTextDelta: func(id, text string) error {
656 summaryMessage.AppendContent(text)
657 return a.messages.Update(genCtx, summaryMessage)
658 },
659 })
660 if err != nil {
661 isCancelErr := errors.Is(err, context.Canceled)
662 if isCancelErr {
663 // User cancelled summarize we need to remove the summary message.
664 deleteErr := a.messages.Delete(ctx, summaryMessage.ID)
665 return deleteErr
666 }
667 return err
668 }
669
670 summaryMessage.AddFinish(message.FinishReasonEndTurn, "", "")
671 err = a.messages.Update(genCtx, summaryMessage)
672 if err != nil {
673 return err
674 }
675
676 var openrouterCost *float64
677 for _, step := range resp.Steps {
678 stepCost := a.openrouterCost(step.ProviderMetadata)
679 if stepCost != nil {
680 newCost := *stepCost
681 if openrouterCost != nil {
682 newCost += *openrouterCost
683 }
684 openrouterCost = &newCost
685 }
686 }
687
688 a.updateSessionUsage(largeModel, ¤tSession, resp.TotalUsage, openrouterCost)
689
690 // Just in case, get just the last usage info.
691 usage := resp.Response.Usage
692 currentSession.SummaryMessageID = summaryMessage.ID
693 currentSession.CompletionTokens = usage.OutputTokens
694 currentSession.PromptTokens = 0
695 _, err = a.sessions.Save(genCtx, currentSession)
696 return err
697}
698
699func (a *sessionAgent) getCacheControlOptions() fantasy.ProviderOptions {
700 if t, _ := strconv.ParseBool(os.Getenv("CRUSH_DISABLE_ANTHROPIC_CACHE")); t {
701 return fantasy.ProviderOptions{}
702 }
703 return fantasy.ProviderOptions{
704 anthropic.Name: &anthropic.ProviderCacheControlOptions{
705 CacheControl: anthropic.CacheControl{Type: "ephemeral"},
706 },
707 bedrock.Name: &anthropic.ProviderCacheControlOptions{
708 CacheControl: anthropic.CacheControl{Type: "ephemeral"},
709 },
710 vercel.Name: &anthropic.ProviderCacheControlOptions{
711 CacheControl: anthropic.CacheControl{Type: "ephemeral"},
712 },
713 }
714}
715
716func (a *sessionAgent) createUserMessage(ctx context.Context, call SessionAgentCall) (message.Message, error) {
717 parts := []message.ContentPart{message.TextContent{Text: call.Prompt}}
718 var attachmentParts []message.ContentPart
719 for _, attachment := range call.Attachments {
720 attachmentParts = append(attachmentParts, message.BinaryContent{Path: attachment.FilePath, MIMEType: attachment.MimeType, Data: attachment.Content})
721 }
722 parts = append(parts, attachmentParts...)
723 msg, err := a.messages.Create(ctx, call.SessionID, message.CreateMessageParams{
724 Role: message.User,
725 Parts: parts,
726 })
727 if err != nil {
728 return message.Message{}, fmt.Errorf("failed to create user message: %w", err)
729 }
730 return msg, nil
731}
732
733func (a *sessionAgent) preparePrompt(msgs []message.Message, attachments ...message.Attachment) ([]fantasy.Message, []fantasy.FilePart) {
734 var history []fantasy.Message
735 if !a.isSubAgent {
736 history = append(history, fantasy.NewUserMessage(
737 fmt.Sprintf("<system_reminder>%s</system_reminder>",
738 `This is a reminder that your todo list is currently empty. DO NOT mention this to the user explicitly because they are already aware.
739If you are working on tasks that would benefit from a todo list please use the "todos" tool to create one.
740If not, please feel free to ignore. Again do not mention this message to the user.`,
741 ),
742 ))
743 }
744 for _, m := range msgs {
745 if len(m.Parts) == 0 {
746 continue
747 }
748 // Assistant message without content or tool calls (cancelled before it
749 // returned anything).
750 if m.Role == message.Assistant && len(m.ToolCalls()) == 0 && m.Content().Text == "" && m.ReasoningContent().String() == "" {
751 continue
752 }
753 history = append(history, m.ToAIMessage()...)
754 }
755
756 var files []fantasy.FilePart
757 for _, attachment := range attachments {
758 if attachment.IsText() {
759 continue
760 }
761 files = append(files, fantasy.FilePart{
762 Filename: attachment.FileName,
763 Data: attachment.Content,
764 MediaType: attachment.MimeType,
765 })
766 }
767
768 return history, files
769}
770
771func (a *sessionAgent) getSessionMessages(ctx context.Context, session session.Session) ([]message.Message, error) {
772 msgs, err := a.messages.List(ctx, session.ID)
773 if err != nil {
774 return nil, fmt.Errorf("failed to list messages: %w", err)
775 }
776
777 if session.SummaryMessageID != "" {
778 summaryMsgIndex := -1
779 for i, msg := range msgs {
780 if msg.ID == session.SummaryMessageID {
781 summaryMsgIndex = i
782 break
783 }
784 }
785 if summaryMsgIndex != -1 {
786 msgs = msgs[summaryMsgIndex:]
787 msgs[0].Role = message.User
788 }
789 }
790 return msgs, nil
791}
792
793// generateTitle generates a session titled based on the initial prompt.
794func (a *sessionAgent) generateTitle(ctx context.Context, sessionID string, userPrompt string) {
795 if userPrompt == "" {
796 return
797 }
798
799 smallModel := a.smallModel.Get()
800 largeModel := a.largeModel.Get()
801 systemPromptPrefix := a.systemPromptPrefix.Get()
802
803 var maxOutputTokens int64 = 40
804 if smallModel.CatwalkCfg.CanReason {
805 maxOutputTokens = smallModel.CatwalkCfg.DefaultMaxTokens
806 }
807
808 newAgent := func(m fantasy.LanguageModel, p []byte, tok int64) fantasy.Agent {
809 return fantasy.NewAgent(m,
810 fantasy.WithSystemPrompt(string(p)+"\n /no_think"),
811 fantasy.WithMaxOutputTokens(tok),
812 fantasy.WithUserAgent("Charm Crush/"+version.Version),
813 )
814 }
815
816 streamCall := fantasy.AgentStreamCall{
817 Prompt: fmt.Sprintf("Generate a concise title for the following content:\n\n%s\n <think>\n\n</think>", userPrompt),
818 PrepareStep: func(callCtx context.Context, opts fantasy.PrepareStepFunctionOptions) (_ context.Context, prepared fantasy.PrepareStepResult, err error) {
819 prepared.Messages = opts.Messages
820 if systemPromptPrefix != "" {
821 prepared.Messages = append([]fantasy.Message{
822 fantasy.NewSystemMessage(systemPromptPrefix),
823 }, prepared.Messages...)
824 }
825 return callCtx, prepared, nil
826 },
827 }
828
829 // Use the small model to generate the title.
830 model := smallModel
831 agent := newAgent(model.Model, titlePrompt, maxOutputTokens)
832 resp, err := agent.Stream(ctx, streamCall)
833 if err == nil {
834 // We successfully generated a title with the small model.
835 slog.Debug("Generated title with small model")
836 } else {
837 // It didn't work. Let's try with the big model.
838 slog.Error("Error generating title with small model; trying big model", "err", err)
839 model = largeModel
840 agent = newAgent(model.Model, titlePrompt, maxOutputTokens)
841 resp, err = agent.Stream(ctx, streamCall)
842 if err == nil {
843 slog.Debug("Generated title with large model")
844 } else {
845 // Welp, the large model didn't work either. Use the default
846 // session name and return.
847 slog.Error("Error generating title with large model", "err", err)
848 saveErr := a.sessions.UpdateTitleAndUsage(ctx, sessionID, DefaultSessionName, 0, 0, 0)
849 if saveErr != nil {
850 slog.Error("Failed to save session title and usage", "error", saveErr)
851 }
852 return
853 }
854 }
855
856 if resp == nil {
857 // Actually, we didn't get a response so we can't. Use the default
858 // session name and return.
859 slog.Error("Response is nil; can't generate title")
860 saveErr := a.sessions.UpdateTitleAndUsage(ctx, sessionID, DefaultSessionName, 0, 0, 0)
861 if saveErr != nil {
862 slog.Error("Failed to save session title and usage", "error", saveErr)
863 }
864 return
865 }
866
867 // Clean up title.
868 var title string
869 title = strings.ReplaceAll(resp.Response.Content.Text(), "\n", " ")
870
871 // Remove thinking tags if present.
872 title = thinkTagRegex.ReplaceAllString(title, "")
873
874 title = strings.TrimSpace(title)
875 title = cmp.Or(title, DefaultSessionName)
876
877 // Calculate usage and cost.
878 var openrouterCost *float64
879 for _, step := range resp.Steps {
880 stepCost := a.openrouterCost(step.ProviderMetadata)
881 if stepCost != nil {
882 newCost := *stepCost
883 if openrouterCost != nil {
884 newCost += *openrouterCost
885 }
886 openrouterCost = &newCost
887 }
888 }
889
890 modelConfig := model.CatwalkCfg
891 cost := modelConfig.CostPer1MInCached/1e6*float64(resp.TotalUsage.CacheCreationTokens) +
892 modelConfig.CostPer1MOutCached/1e6*float64(resp.TotalUsage.CacheReadTokens) +
893 modelConfig.CostPer1MIn/1e6*float64(resp.TotalUsage.InputTokens) +
894 modelConfig.CostPer1MOut/1e6*float64(resp.TotalUsage.OutputTokens)
895
896 // Use override cost if available (e.g., from OpenRouter).
897 if openrouterCost != nil {
898 cost = *openrouterCost
899 }
900
901 promptTokens := resp.TotalUsage.InputTokens + resp.TotalUsage.CacheCreationTokens
902 completionTokens := resp.TotalUsage.OutputTokens
903
904 // Atomically update only title and usage fields to avoid overriding other
905 // concurrent session updates.
906 saveErr := a.sessions.UpdateTitleAndUsage(ctx, sessionID, title, promptTokens, completionTokens, cost)
907 if saveErr != nil {
908 slog.Error("Failed to save session title and usage", "error", saveErr)
909 return
910 }
911}
912
913func (a *sessionAgent) openrouterCost(metadata fantasy.ProviderMetadata) *float64 {
914 openrouterMetadata, ok := metadata[openrouter.Name]
915 if !ok {
916 return nil
917 }
918
919 opts, ok := openrouterMetadata.(*openrouter.ProviderMetadata)
920 if !ok {
921 return nil
922 }
923 return &opts.Usage.Cost
924}
925
926func (a *sessionAgent) updateSessionUsage(model Model, session *session.Session, usage fantasy.Usage, overrideCost *float64) {
927 modelConfig := model.CatwalkCfg
928 cost := modelConfig.CostPer1MInCached/1e6*float64(usage.CacheCreationTokens) +
929 modelConfig.CostPer1MOutCached/1e6*float64(usage.CacheReadTokens) +
930 modelConfig.CostPer1MIn/1e6*float64(usage.InputTokens) +
931 modelConfig.CostPer1MOut/1e6*float64(usage.OutputTokens)
932
933 a.eventTokensUsed(session.ID, model, usage, cost)
934
935 if overrideCost != nil {
936 session.Cost += *overrideCost
937 } else {
938 session.Cost += cost
939 }
940
941 session.CompletionTokens = usage.OutputTokens
942 session.PromptTokens = usage.InputTokens + usage.CacheReadTokens
943}
944
945func (a *sessionAgent) Cancel(sessionID string) {
946 // Cancel regular requests. Don't use Take() here - we need the entry to
947 // remain in activeRequests so IsBusy() returns true until the goroutine
948 // fully completes (including error handling that may access the DB).
949 // The defer in processRequest will clean up the entry.
950 if cancel, ok := a.activeRequests.Get(sessionID); ok && cancel != nil {
951 slog.Debug("Request cancellation initiated", "session_id", sessionID)
952 cancel()
953 }
954
955 // Also check for summarize requests.
956 if cancel, ok := a.activeRequests.Get(sessionID + "-summarize"); ok && cancel != nil {
957 slog.Debug("Summarize cancellation initiated", "session_id", sessionID)
958 cancel()
959 }
960
961 if a.QueuedPrompts(sessionID) > 0 {
962 slog.Debug("Clearing queued prompts", "session_id", sessionID)
963 a.messageQueue.Del(sessionID)
964 }
965}
966
967func (a *sessionAgent) ClearQueue(sessionID string) {
968 if a.QueuedPrompts(sessionID) > 0 {
969 slog.Debug("Clearing queued prompts", "session_id", sessionID)
970 a.messageQueue.Del(sessionID)
971 }
972}
973
974func (a *sessionAgent) CancelAll() {
975 if !a.IsBusy() {
976 return
977 }
978 for key := range a.activeRequests.Seq2() {
979 a.Cancel(key) // key is sessionID
980 }
981
982 timeout := time.After(5 * time.Second)
983 for a.IsBusy() {
984 select {
985 case <-timeout:
986 return
987 default:
988 time.Sleep(200 * time.Millisecond)
989 }
990 }
991}
992
993func (a *sessionAgent) IsBusy() bool {
994 var busy bool
995 for cancelFunc := range a.activeRequests.Seq() {
996 if cancelFunc != nil {
997 busy = true
998 break
999 }
1000 }
1001 return busy
1002}
1003
1004func (a *sessionAgent) IsSessionBusy(sessionID string) bool {
1005 _, busy := a.activeRequests.Get(sessionID)
1006 return busy
1007}
1008
1009func (a *sessionAgent) QueuedPrompts(sessionID string) int {
1010 l, ok := a.messageQueue.Get(sessionID)
1011 if !ok {
1012 return 0
1013 }
1014 return len(l)
1015}
1016
1017func (a *sessionAgent) QueuedPromptsList(sessionID string) []string {
1018 l, ok := a.messageQueue.Get(sessionID)
1019 if !ok {
1020 return nil
1021 }
1022 prompts := make([]string, len(l))
1023 for i, call := range l {
1024 prompts[i] = call.Prompt
1025 }
1026 return prompts
1027}
1028
1029func (a *sessionAgent) SetModels(large Model, small Model) {
1030 a.largeModel.Set(large)
1031 a.smallModel.Set(small)
1032}
1033
1034func (a *sessionAgent) SetTools(tools []fantasy.AgentTool) {
1035 a.tools.SetSlice(tools)
1036}
1037
1038func (a *sessionAgent) SetSystemPrompt(systemPrompt string) {
1039 a.systemPrompt.Set(systemPrompt)
1040}
1041
1042func (a *sessionAgent) Model() Model {
1043 return a.largeModel.Get()
1044}
1045
1046// convertToToolResult converts a fantasy tool result to a message tool result.
1047func (a *sessionAgent) convertToToolResult(result fantasy.ToolResultContent) message.ToolResult {
1048 baseResult := message.ToolResult{
1049 ToolCallID: result.ToolCallID,
1050 Name: result.ToolName,
1051 Metadata: result.ClientMetadata,
1052 }
1053
1054 switch result.Result.GetType() {
1055 case fantasy.ToolResultContentTypeText:
1056 if r, ok := fantasy.AsToolResultOutputType[fantasy.ToolResultOutputContentText](result.Result); ok {
1057 baseResult.Content = r.Text
1058 }
1059 case fantasy.ToolResultContentTypeError:
1060 if r, ok := fantasy.AsToolResultOutputType[fantasy.ToolResultOutputContentError](result.Result); ok {
1061 baseResult.Content = r.Error.Error()
1062 baseResult.IsError = true
1063 }
1064 case fantasy.ToolResultContentTypeMedia:
1065 if r, ok := fantasy.AsToolResultOutputType[fantasy.ToolResultOutputContentMedia](result.Result); ok {
1066 content := r.Text
1067 if content == "" {
1068 content = fmt.Sprintf("Loaded %s content", r.MediaType)
1069 }
1070 baseResult.Content = content
1071 baseResult.Data = r.Data
1072 baseResult.MIMEType = r.MediaType
1073 }
1074 }
1075
1076 return baseResult
1077}
1078
1079// workaroundProviderMediaLimitations converts media content in tool results to
1080// user messages for providers that don't natively support images in tool results.
1081//
1082// Problem: OpenAI, Google, OpenRouter, and other OpenAI-compatible providers
1083// don't support sending images/media in tool result messages - they only accept
1084// text in tool results. However, they DO support images in user messages.
1085//
1086// If we send media in tool results to these providers, the API returns an error.
1087//
1088// Solution: For these providers, we:
1089// 1. Replace the media in the tool result with a text placeholder
1090// 2. Inject a user message immediately after with the image as a file attachment
1091// 3. This maintains the tool execution flow while working around API limitations
1092//
1093// Anthropic and Bedrock support images natively in tool results, so we skip
1094// this workaround for them.
1095//
1096// Example transformation:
1097//
1098// BEFORE: [tool result: image data]
1099// AFTER: [tool result: "Image loaded - see attached"], [user: image attachment]
1100func (a *sessionAgent) workaroundProviderMediaLimitations(messages []fantasy.Message, largeModel Model) []fantasy.Message {
1101 providerSupportsMedia := largeModel.ModelCfg.Provider == string(catwalk.InferenceProviderAnthropic) ||
1102 largeModel.ModelCfg.Provider == string(catwalk.InferenceProviderBedrock)
1103
1104 if providerSupportsMedia {
1105 return messages
1106 }
1107
1108 convertedMessages := make([]fantasy.Message, 0, len(messages))
1109
1110 for _, msg := range messages {
1111 if msg.Role != fantasy.MessageRoleTool {
1112 convertedMessages = append(convertedMessages, msg)
1113 continue
1114 }
1115
1116 textParts := make([]fantasy.MessagePart, 0, len(msg.Content))
1117 var mediaFiles []fantasy.FilePart
1118
1119 for _, part := range msg.Content {
1120 toolResult, ok := fantasy.AsMessagePart[fantasy.ToolResultPart](part)
1121 if !ok {
1122 textParts = append(textParts, part)
1123 continue
1124 }
1125
1126 if media, ok := fantasy.AsToolResultOutputType[fantasy.ToolResultOutputContentMedia](toolResult.Output); ok {
1127 decoded, err := base64.StdEncoding.DecodeString(media.Data)
1128 if err != nil {
1129 slog.Warn("Failed to decode media data", "error", err)
1130 textParts = append(textParts, part)
1131 continue
1132 }
1133
1134 mediaFiles = append(mediaFiles, fantasy.FilePart{
1135 Data: decoded,
1136 MediaType: media.MediaType,
1137 Filename: fmt.Sprintf("tool-result-%s", toolResult.ToolCallID),
1138 })
1139
1140 textParts = append(textParts, fantasy.ToolResultPart{
1141 ToolCallID: toolResult.ToolCallID,
1142 Output: fantasy.ToolResultOutputContentText{
1143 Text: "[Image/media content loaded - see attached file]",
1144 },
1145 ProviderOptions: toolResult.ProviderOptions,
1146 })
1147 } else {
1148 textParts = append(textParts, part)
1149 }
1150 }
1151
1152 convertedMessages = append(convertedMessages, fantasy.Message{
1153 Role: fantasy.MessageRoleTool,
1154 Content: textParts,
1155 })
1156
1157 if len(mediaFiles) > 0 {
1158 convertedMessages = append(convertedMessages, fantasy.NewUserMessage(
1159 "Here is the media content from the tool result:",
1160 mediaFiles...,
1161 ))
1162 }
1163 }
1164
1165 return convertedMessages
1166}
1167
1168// buildSummaryPrompt constructs the prompt text for session summarization.
1169func buildSummaryPrompt(todos []session.Todo) string {
1170 var sb strings.Builder
1171 sb.WriteString("Provide a detailed summary of our conversation above.")
1172 if len(todos) > 0 {
1173 sb.WriteString("\n\n## Current Todo List\n\n")
1174 for _, t := range todos {
1175 fmt.Fprintf(&sb, "- [%s] %s\n", t.Status, t.Content)
1176 }
1177 sb.WriteString("\nInclude these tasks and their statuses in your summary. ")
1178 sb.WriteString("Instruct the resuming assistant to use the `todos` tool to continue tracking progress on these tasks.")
1179 }
1180 return sb.String()
1181}