1// Package agent is the core orchestration layer for Crush AI agents.
2//
3// It provides session-based AI agent functionality for managing
4// conversations, tool execution, and message handling. It coordinates
5// interactions between language models, messages, sessions, and tools while
6// handling features like automatic summarization, queuing, and token
7// management.
8package agent
9
10import (
11 "cmp"
12 "context"
13 _ "embed"
14 "encoding/base64"
15 "errors"
16 "fmt"
17 "log/slog"
18 "os"
19 "regexp"
20 "strconv"
21 "strings"
22 "sync"
23 "time"
24
25 "charm.land/catwalk/pkg/catwalk"
26 "charm.land/fantasy"
27 "charm.land/fantasy/providers/anthropic"
28 "charm.land/fantasy/providers/bedrock"
29 "charm.land/fantasy/providers/google"
30 "charm.land/fantasy/providers/openai"
31 "charm.land/fantasy/providers/openrouter"
32 "charm.land/fantasy/providers/vercel"
33 "charm.land/lipgloss/v2"
34 "git.secluded.site/crush/internal/agent/hyper"
35 "git.secluded.site/crush/internal/agent/notify"
36 "git.secluded.site/crush/internal/agent/tools"
37 "git.secluded.site/crush/internal/agent/tools/mcp"
38 "git.secluded.site/crush/internal/config"
39 "git.secluded.site/crush/internal/csync"
40 "git.secluded.site/crush/internal/message"
41 "git.secluded.site/crush/internal/permission"
42 "git.secluded.site/crush/internal/pubsub"
43 "git.secluded.site/crush/internal/session"
44 "git.secluded.site/crush/internal/stringext"
45 "github.com/charmbracelet/x/exp/charmtone"
46)
47
const (
	// defaultSessionName is used when title generation fails or produces
	// an empty title.
	defaultSessionName = "Untitled Session"

	// Auto-summarization thresholds: sessions with context windows larger
	// than largeContextWindowThreshold summarize when fewer than
	// largeContextWindowBuffer tokens remain; smaller windows summarize
	// when the remaining share drops below smallContextWindowRatio.
	largeContextWindowThreshold = 200_000
	largeContextWindowBuffer    = 20_000
	smallContextWindowRatio     = 0.2
)
56
// titlePrompt is the embedded system prompt used for session title generation.
//
//go:embed templates/title.md
var titlePrompt []byte

// summaryPrompt is the embedded system prompt used for session summarization.
//
//go:embed templates/summary.md
var summaryPrompt []byte

// thinkTagRegex is used to remove <think> tags from generated titles.
var thinkTagRegex = regexp.MustCompile(`<think>.*?</think>`)
65
// SessionAgentCall describes a single prompt to run against a session.
type SessionAgentCall struct {
	SessionID       string // target session; required
	Prompt          string // user prompt text
	ProviderOptions fantasy.ProviderOptions
	Attachments     []message.Attachment // optional attachments; text attachments count as prompt content
	MaxOutputTokens int64
	// Optional sampling parameters; nil leaves the provider default in place.
	Temperature      *float64
	TopP             *float64
	TopK             *int64
	FrequencyPenalty *float64
	PresencePenalty  *float64
	// NonInteractive suppresses the agent-finished notification when the
	// turn completes.
	NonInteractive bool
}
79
// SessionAgent runs AI agent turns against sessions. Implementations queue
// prompts for busy sessions, support per-session cancellation, and can
// summarize a session's history when it grows too long.
type SessionAgent interface {
	// Run executes a prompt for a session, queueing it if the session is busy.
	Run(context.Context, SessionAgentCall) (*fantasy.AgentResult, error)
	// SetModels swaps the large and small models used for future calls.
	SetModels(large Model, small Model)
	// SetTools replaces the tool set available to the agent.
	SetTools(tools []fantasy.AgentTool)
	// SetSystemPrompt replaces the system prompt used for future calls.
	SetSystemPrompt(systemPrompt string)
	// Cancel aborts any active request for the given session.
	Cancel(sessionID string)
	// CancelAll aborts every active request.
	CancelAll()
	// IsSessionBusy reports whether the session has an active request.
	IsSessionBusy(sessionID string) bool
	// IsBusy reports whether any session has an active request.
	IsBusy() bool
	// QueuedPrompts returns the number of prompts queued for the session.
	QueuedPrompts(sessionID string) int
	// QueuedPromptsList returns the queued prompts for the session.
	QueuedPromptsList(sessionID string) []string
	// ClearQueue drops all queued prompts for the session.
	ClearQueue(sessionID string)
	// Summarize condenses the session's history into a summary message.
	Summarize(context.Context, string, fantasy.ProviderOptions) error
	// Model returns the model configuration currently in use.
	Model() Model
}
95
// Model bundles a fantasy language model with its catwalk catalog metadata
// and the user-selected configuration.
type Model struct {
	Model      fantasy.LanguageModel
	CatwalkCfg catwalk.Model        // catalog metadata: context window, costs, capabilities
	ModelCfg   config.SelectedModel // user-selected model and provider IDs
}
101
// sessionAgent is the default SessionAgent implementation. Mutable fields
// are wrapped in csync containers so they can be swapped (SetModels,
// SetTools, SetSystemPrompt) while requests are in flight.
type sessionAgent struct {
	largeModel         *csync.Value[Model] // primary model used for agent turns
	smallModel         *csync.Value[Model] // cheaper model used for title generation
	systemPromptPrefix *csync.Value[string]
	systemPrompt       *csync.Value[string]
	tools              *csync.Slice[fantasy.AgentTool]

	isSubAgent           bool // sub-agents skip the todo-list system reminder
	sessions             session.Service
	messages             message.Service
	disableAutoSummarize bool
	isYolo               bool
	notify               pubsub.Publisher[notify.Notification]

	// messageQueue holds prompts submitted while a session was busy,
	// keyed by session ID; they are drained when the active turn finishes.
	messageQueue *csync.Map[string, []SessionAgentCall]
	// activeRequests maps a session ID to the cancel func of its in-flight
	// request; presence in this map means the session is busy.
	activeRequests *csync.Map[string, context.CancelFunc]
}
119
// SessionAgentOptions configures NewSessionAgent.
type SessionAgentOptions struct {
	LargeModel           Model  // primary model for agent turns
	SmallModel           Model  // cheaper model for auxiliary tasks such as titles
	SystemPromptPrefix   string // optional system message prepended to every request
	SystemPrompt         string
	IsSubAgent           bool // sub-agents skip the todo-list system reminder
	DisableAutoSummarize bool
	IsYolo               bool
	Sessions             session.Service
	Messages             message.Service
	Tools                []fantasy.AgentTool
	Notify               pubsub.Publisher[notify.Notification] // may be nil; notifications are then skipped
}
133
134func NewSessionAgent(
135 opts SessionAgentOptions,
136) SessionAgent {
137 return &sessionAgent{
138 largeModel: csync.NewValue(opts.LargeModel),
139 smallModel: csync.NewValue(opts.SmallModel),
140 systemPromptPrefix: csync.NewValue(opts.SystemPromptPrefix),
141 systemPrompt: csync.NewValue(opts.SystemPrompt),
142 isSubAgent: opts.IsSubAgent,
143 sessions: opts.Sessions,
144 messages: opts.Messages,
145 disableAutoSummarize: opts.DisableAutoSummarize,
146 tools: csync.NewSliceFrom(opts.Tools),
147 isYolo: opts.IsYolo,
148 notify: opts.Notify,
149 messageQueue: csync.NewMap[string, []SessionAgentCall](),
150 activeRequests: csync.NewMap[string, context.CancelFunc](),
151 }
152}
153
// Run executes one agent turn for call.SessionID: it persists the user
// message, streams the model response (persisting reasoning, text, tool
// calls, and tool results as they arrive), tracks usage on the session,
// and may trigger auto-summarization or re-run queued prompts afterwards.
//
// If the session is already busy the call is queued and (nil, nil) is
// returned immediately; the queued prompt is picked up when the active
// turn finishes.
func (a *sessionAgent) Run(ctx context.Context, call SessionAgentCall) (*fantasy.AgentResult, error) {
	if call.Prompt == "" && !message.ContainsTextAttachment(call.Attachments) {
		return nil, ErrEmptyPrompt
	}
	if call.SessionID == "" {
		return nil, ErrSessionMissing
	}

	// Queue the message if busy.
	if a.IsSessionBusy(call.SessionID) {
		existing, ok := a.messageQueue.Get(call.SessionID)
		if !ok {
			existing = []SessionAgentCall{}
		}
		existing = append(existing, call)
		a.messageQueue.Set(call.SessionID, existing)
		return nil, nil
	}

	// Copy mutable fields under lock to avoid races with SetTools/SetModels.
	agentTools := a.tools.Copy()
	largeModel := a.largeModel.Get()
	systemPrompt := a.systemPrompt.Get()
	promptPrefix := a.systemPromptPrefix.Get()
	var instructions strings.Builder

	// Collect instructions advertised by connected MCP servers so they can
	// be appended to the system prompt.
	for _, server := range mcp.GetStates() {
		if server.State != mcp.StateConnected {
			continue
		}
		if s := server.Client.InitializeResult().Instructions; s != "" {
			instructions.WriteString(s)
			instructions.WriteString("\n\n")
		}
	}

	if s := instructions.String(); s != "" {
		systemPrompt += "\n\n<mcp-instructions>\n" + s + "\n</mcp-instructions>"
	}

	if len(agentTools) > 0 {
		// Add Anthropic caching to the last tool.
		agentTools[len(agentTools)-1].SetProviderOptions(a.getCacheControlOptions())
	}

	agent := fantasy.NewAgent(
		largeModel.Model,
		fantasy.WithSystemPrompt(systemPrompt),
		fantasy.WithTools(agentTools...),
	)

	// sessionLock guards currentSession, which is re-read and re-saved from
	// the OnStepFinish callback.
	sessionLock := sync.Mutex{}
	currentSession, err := a.sessions.Get(ctx, call.SessionID)
	if err != nil {
		return nil, fmt.Errorf("failed to get session: %w", err)
	}

	msgs, err := a.getSessionMessages(ctx, currentSession)
	if err != nil {
		return nil, fmt.Errorf("failed to get session messages: %w", err)
	}

	var wg sync.WaitGroup
	// Generate title if first message.
	if len(msgs) == 0 {
		titleCtx := ctx // Copy to avoid race with ctx reassignment below.
		wg.Go(func() {
			a.generateTitle(titleCtx, call.SessionID, call.Prompt)
		})
	}
	defer wg.Wait()

	// Add the user message to the session.
	_, err = a.createUserMessage(ctx, call)
	if err != nil {
		return nil, err
	}

	// Add the session to the context.
	ctx = context.WithValue(ctx, tools.SessionIDContextKey, call.SessionID)

	genCtx, cancel := context.WithCancel(ctx)
	a.activeRequests.Set(call.SessionID, cancel)

	defer cancel()
	defer a.activeRequests.Del(call.SessionID)

	history, files := a.preparePrompt(msgs, call.Attachments...)

	startTime := time.Now()
	a.eventPromptSent(call.SessionID)

	var currentAssistant *message.Message
	var shouldSummarize bool
	result, err := agent.Stream(genCtx, fantasy.AgentStreamCall{
		Prompt:           message.PromptWithTextAttachments(call.Prompt, call.Attachments),
		Files:            files,
		Messages:         history,
		ProviderOptions:  call.ProviderOptions,
		MaxOutputTokens:  &call.MaxOutputTokens,
		TopP:             call.TopP,
		Temperature:      call.Temperature,
		PresencePenalty:  call.PresencePenalty,
		TopK:             call.TopK,
		FrequencyPenalty: call.FrequencyPenalty,
		// PrepareStep runs before each model step: it drains queued user
		// messages into the conversation, applies provider workarounds and
		// cache control, and creates the assistant message for this step.
		PrepareStep: func(callContext context.Context, options fantasy.PrepareStepFunctionOptions) (_ context.Context, prepared fantasy.PrepareStepResult, err error) {
			prepared.Messages = options.Messages
			// Clear stale per-message provider options; cache control is
			// re-applied selectively below.
			for i := range prepared.Messages {
				prepared.Messages[i].ProviderOptions = nil
			}

			queuedCalls, _ := a.messageQueue.Get(call.SessionID)
			a.messageQueue.Del(call.SessionID)
			for _, queued := range queuedCalls {
				userMessage, createErr := a.createUserMessage(callContext, queued)
				if createErr != nil {
					return callContext, prepared, createErr
				}
				prepared.Messages = append(prepared.Messages, userMessage.ToAIMessage()...)
			}

			prepared.Messages = a.workaroundProviderMediaLimitations(prepared.Messages, largeModel)

			lastSystemRoleInx := 0
			systemMessageUpdated := false
			for i, msg := range prepared.Messages {
				// Only add cache control to the last system message.
				if msg.Role == fantasy.MessageRoleSystem {
					lastSystemRoleInx = i
				} else if !systemMessageUpdated {
					prepared.Messages[lastSystemRoleInx].ProviderOptions = a.getCacheControlOptions()
					systemMessageUpdated = true
				}
				// Then add cache control to the last 2 messages.
				if i > len(prepared.Messages)-3 {
					prepared.Messages[i].ProviderOptions = a.getCacheControlOptions()
				}
			}

			if promptPrefix != "" {
				prepared.Messages = append([]fantasy.Message{fantasy.NewSystemMessage(promptPrefix)}, prepared.Messages...)
			}

			var assistantMsg message.Message
			assistantMsg, err = a.messages.Create(callContext, call.SessionID, message.CreateMessageParams{
				Role:     message.Assistant,
				Parts:    []message.ContentPart{},
				Model:    largeModel.ModelCfg.Model,
				Provider: largeModel.ModelCfg.Provider,
			})
			if err != nil {
				return callContext, prepared, err
			}
			callContext = context.WithValue(callContext, tools.MessageIDContextKey, assistantMsg.ID)
			callContext = context.WithValue(callContext, tools.SupportsImagesContextKey, largeModel.CatwalkCfg.SupportsImages)
			callContext = context.WithValue(callContext, tools.ModelNameContextKey, largeModel.CatwalkCfg.Name)
			currentAssistant = &assistantMsg
			return callContext, prepared, err
		},
		OnReasoningStart: func(id string, reasoning fantasy.ReasoningContent) error {
			currentAssistant.AppendReasoningContent(reasoning.Text)
			return a.messages.Update(genCtx, *currentAssistant)
		},
		OnReasoningDelta: func(id string, text string) error {
			currentAssistant.AppendReasoningContent(text)
			return a.messages.Update(genCtx, *currentAssistant)
		},
		// OnReasoningEnd persists provider-specific reasoning metadata
		// (signatures / responses data) before closing the thinking state.
		OnReasoningEnd: func(id string, reasoning fantasy.ReasoningContent) error {
			// handle anthropic signature
			if anthropicData, ok := reasoning.ProviderMetadata[anthropic.Name]; ok {
				if reasoning, ok := anthropicData.(*anthropic.ReasoningOptionMetadata); ok {
					currentAssistant.AppendReasoningSignature(reasoning.Signature)
				}
			}
			if googleData, ok := reasoning.ProviderMetadata[google.Name]; ok {
				if reasoning, ok := googleData.(*google.ReasoningMetadata); ok {
					currentAssistant.AppendThoughtSignature(reasoning.Signature, reasoning.ToolID)
				}
			}
			if openaiData, ok := reasoning.ProviderMetadata[openai.Name]; ok {
				if reasoning, ok := openaiData.(*openai.ResponsesReasoningMetadata); ok {
					currentAssistant.SetReasoningResponsesData(reasoning)
				}
			}
			currentAssistant.FinishThinking()
			return a.messages.Update(genCtx, *currentAssistant)
		},
		OnTextDelta: func(id string, text string) error {
			// Strip leading newline from initial text content. This is
			// particularly important in non-interactive mode where leading
			// newlines are very visible.
			if len(currentAssistant.Parts) == 0 {
				text = strings.TrimPrefix(text, "\n")
			}

			currentAssistant.AppendContent(text)
			return a.messages.Update(genCtx, *currentAssistant)
		},
		OnToolInputStart: func(id string, toolName string) error {
			// Record the tool call immediately so the UI can show it while
			// the input is still streaming.
			toolCall := message.ToolCall{
				ID:               id,
				Name:             toolName,
				ProviderExecuted: false,
				Finished:         false,
			}
			currentAssistant.AddToolCall(toolCall)
			return a.messages.Update(genCtx, *currentAssistant)
		},
		OnRetry: func(err *fantasy.ProviderError, delay time.Duration) {
			// TODO: implement
		},
		OnToolCall: func(tc fantasy.ToolCallContent) error {
			toolCall := message.ToolCall{
				ID:               tc.ToolCallID,
				Name:             tc.ToolName,
				Input:            tc.Input,
				ProviderExecuted: false,
				Finished:         true,
			}
			currentAssistant.AddToolCall(toolCall)
			return a.messages.Update(genCtx, *currentAssistant)
		},
		OnToolResult: func(result fantasy.ToolResultContent) error {
			toolResult := a.convertToToolResult(result)
			_, createMsgErr := a.messages.Create(genCtx, currentAssistant.SessionID, message.CreateMessageParams{
				Role: message.Tool,
				Parts: []message.ContentPart{
					toolResult,
				},
			})
			return createMsgErr
		},
		// OnStepFinish records the finish reason and updates the session's
		// usage/cost totals under sessionLock.
		OnStepFinish: func(stepResult fantasy.StepResult) error {
			finishReason := message.FinishReasonUnknown
			switch stepResult.FinishReason {
			case fantasy.FinishReasonLength:
				finishReason = message.FinishReasonMaxTokens
			case fantasy.FinishReasonStop:
				finishReason = message.FinishReasonEndTurn
			case fantasy.FinishReasonToolCalls:
				finishReason = message.FinishReasonToolUse
			}
			currentAssistant.AddFinish(finishReason, "", "")
			sessionLock.Lock()
			defer sessionLock.Unlock()

			updatedSession, getSessionErr := a.sessions.Get(ctx, call.SessionID)
			if getSessionErr != nil {
				return getSessionErr
			}
			a.updateSessionUsage(largeModel, &updatedSession, stepResult.Usage, a.openrouterCost(stepResult.ProviderMetadata))
			_, sessionErr := a.sessions.Save(ctx, updatedSession)
			if sessionErr != nil {
				return sessionErr
			}
			currentSession = updatedSession
			return a.messages.Update(genCtx, *currentAssistant)
		},
		StopWhen: []fantasy.StopCondition{
			// Stop when the remaining context window falls below the
			// auto-summarization threshold (unless disabled).
			func(_ []fantasy.StepResult) bool {
				cw := int64(largeModel.CatwalkCfg.ContextWindow)
				tokens := currentSession.CompletionTokens + currentSession.PromptTokens
				remaining := cw - tokens
				var threshold int64
				if cw > largeContextWindowThreshold {
					threshold = largeContextWindowBuffer
				} else {
					threshold = int64(float64(cw) * smallContextWindowRatio)
				}
				if (remaining <= threshold) && !a.disableAutoSummarize {
					shouldSummarize = true
					return true
				}
				return false
			},
			// Stop when the agent appears stuck repeating the same tool calls.
			func(steps []fantasy.StepResult) bool {
				return hasRepeatedToolCalls(steps, loopDetectionWindowSize, loopDetectionMaxRepeats)
			},
		},
	})

	a.eventPromptResponded(call.SessionID, time.Since(startTime).Truncate(time.Second))

	if err != nil {
		isCancelErr := errors.Is(err, context.Canceled)
		isPermissionErr := errors.Is(err, permission.ErrorPermissionDenied)
		if currentAssistant == nil {
			return result, err
		}
		// Ensure we finish thinking on error to close the reasoning state.
		currentAssistant.FinishThinking()
		toolCalls := currentAssistant.ToolCalls()
		// INFO: we use the parent context here because the genCtx has been cancelled.
		msgs, createErr := a.messages.List(ctx, currentAssistant.SessionID)
		if createErr != nil {
			return nil, createErr
		}
		// Each pending tool call needs a synthetic tool result so the
		// conversation history stays well-formed for the next turn.
		for _, tc := range toolCalls {
			if !tc.Finished {
				tc.Finished = true
				tc.Input = "{}"
				currentAssistant.AddToolCall(tc)
				updateErr := a.messages.Update(ctx, *currentAssistant)
				if updateErr != nil {
					return nil, updateErr
				}
			}

			// Skip tool calls that already have a persisted result.
			found := false
			for _, msg := range msgs {
				if msg.Role == message.Tool {
					for _, tr := range msg.ToolResults() {
						if tr.ToolCallID == tc.ID {
							found = true
							break
						}
					}
				}
				if found {
					break
				}
			}
			if found {
				continue
			}
			content := "There was an error while executing the tool"
			if isCancelErr {
				content = "Tool execution canceled by user"
			} else if isPermissionErr {
				content = "User denied permission"
			}
			toolResult := message.ToolResult{
				ToolCallID: tc.ID,
				Name:       tc.Name,
				Content:    content,
				IsError:    true,
			}
			_, createErr = a.messages.Create(ctx, currentAssistant.SessionID, message.CreateMessageParams{
				Role: message.Tool,
				Parts: []message.ContentPart{
					toolResult,
				},
			})
			if createErr != nil {
				return nil, createErr
			}
		}
		// Translate the error into a user-facing finish reason/message on
		// the assistant message.
		var fantasyErr *fantasy.Error
		var providerErr *fantasy.ProviderError
		const defaultTitle = "Provider Error"
		linkStyle := lipgloss.NewStyle().Foreground(charmtone.Guac).Underline(true)
		if isCancelErr {
			currentAssistant.AddFinish(message.FinishReasonCanceled, "User canceled request", "")
		} else if isPermissionErr {
			currentAssistant.AddFinish(message.FinishReasonPermissionDenied, "User denied permission", "")
		} else if errors.Is(err, hyper.ErrNoCredits) {
			url := hyper.BaseURL()
			link := linkStyle.Hyperlink(url, "id=hyper").Render(url)
			currentAssistant.AddFinish(message.FinishReasonError, "No credits", "You're out of credits. Add more at "+link)
		} else if errors.As(err, &providerErr) {
			if providerErr.Message == "The requested model is not supported." {
				url := "https://github.com/settings/copilot/features"
				link := linkStyle.Hyperlink(url, "id=copilot").Render(url)
				currentAssistant.AddFinish(
					message.FinishReasonError,
					"Copilot model not enabled",
					fmt.Sprintf("%q is not enabled in Copilot. Go to the following page to enable it. Then, wait 5 minutes before trying again. %s", largeModel.CatwalkCfg.Name, link),
				)
			} else {
				currentAssistant.AddFinish(message.FinishReasonError, cmp.Or(stringext.Capitalize(providerErr.Title), defaultTitle), providerErr.Message)
			}
		} else if errors.As(err, &fantasyErr) {
			currentAssistant.AddFinish(message.FinishReasonError, cmp.Or(stringext.Capitalize(fantasyErr.Title), defaultTitle), fantasyErr.Message)
		} else {
			currentAssistant.AddFinish(message.FinishReasonError, defaultTitle, err.Error())
		}
		// Note: we use the parent context here because the genCtx has been
		// cancelled.
		updateErr := a.messages.Update(ctx, *currentAssistant)
		if updateErr != nil {
			return nil, updateErr
		}
		return nil, err
	}

	// Send notification that agent has finished its turn (skip for
	// nested/non-interactive sessions).
	if !call.NonInteractive && a.notify != nil {
		a.notify.Publish(pubsub.CreatedEvent, notify.Notification{
			SessionID:    call.SessionID,
			SessionTitle: currentSession.Title,
			Type:         notify.TypeAgentFinished,
		})
	}

	if shouldSummarize {
		// Release the active request so Summarize doesn't see the session
		// as busy.
		a.activeRequests.Del(call.SessionID)
		if summarizeErr := a.Summarize(genCtx, call.SessionID, call.ProviderOptions); summarizeErr != nil {
			return nil, summarizeErr
		}
		// If the agent wasn't done (it still had tool calls in flight),
		// queue a follow-up prompt so the turn resumes after the summary.
		if len(currentAssistant.ToolCalls()) > 0 {
			existing, ok := a.messageQueue.Get(call.SessionID)
			if !ok {
				existing = []SessionAgentCall{}
			}
			call.Prompt = fmt.Sprintf("The previous session was interrupted because it got too long, the initial user request was: `%s`", call.Prompt)
			existing = append(existing, call)
			a.messageQueue.Set(call.SessionID, existing)
		}
	}

	// Release active request before processing queued messages.
	a.activeRequests.Del(call.SessionID)
	cancel()

	queuedMessages, ok := a.messageQueue.Get(call.SessionID)
	if !ok || len(queuedMessages) == 0 {
		return result, err
	}
	// There are queued messages restart the loop.
	firstQueuedMessage := queuedMessages[0]
	a.messageQueue.Set(call.SessionID, queuedMessages[1:])
	return a.Run(ctx, firstQueuedMessage)
}
579
580func (a *sessionAgent) Summarize(ctx context.Context, sessionID string, opts fantasy.ProviderOptions) error {
581 if a.IsSessionBusy(sessionID) {
582 return ErrSessionBusy
583 }
584
585 // Copy mutable fields under lock to avoid races with SetModels.
586 largeModel := a.largeModel.Get()
587 systemPromptPrefix := a.systemPromptPrefix.Get()
588
589 currentSession, err := a.sessions.Get(ctx, sessionID)
590 if err != nil {
591 return fmt.Errorf("failed to get session: %w", err)
592 }
593 msgs, err := a.getSessionMessages(ctx, currentSession)
594 if err != nil {
595 return err
596 }
597 if len(msgs) == 0 {
598 // Nothing to summarize.
599 return nil
600 }
601
602 aiMsgs, _ := a.preparePrompt(msgs)
603
604 genCtx, cancel := context.WithCancel(ctx)
605 a.activeRequests.Set(sessionID, cancel)
606 defer a.activeRequests.Del(sessionID)
607 defer cancel()
608
609 agent := fantasy.NewAgent(largeModel.Model,
610 fantasy.WithSystemPrompt(string(summaryPrompt)),
611 )
612 summaryMessage, err := a.messages.Create(ctx, sessionID, message.CreateMessageParams{
613 Role: message.Assistant,
614 Model: largeModel.Model.Model(),
615 Provider: largeModel.Model.Provider(),
616 IsSummaryMessage: true,
617 })
618 if err != nil {
619 return err
620 }
621
622 summaryPromptText := buildSummaryPrompt(currentSession.Todos)
623
624 resp, err := agent.Stream(genCtx, fantasy.AgentStreamCall{
625 Prompt: summaryPromptText,
626 Messages: aiMsgs,
627 ProviderOptions: opts,
628 PrepareStep: func(callContext context.Context, options fantasy.PrepareStepFunctionOptions) (_ context.Context, prepared fantasy.PrepareStepResult, err error) {
629 prepared.Messages = options.Messages
630 if systemPromptPrefix != "" {
631 prepared.Messages = append([]fantasy.Message{fantasy.NewSystemMessage(systemPromptPrefix)}, prepared.Messages...)
632 }
633 return callContext, prepared, nil
634 },
635 OnReasoningDelta: func(id string, text string) error {
636 summaryMessage.AppendReasoningContent(text)
637 return a.messages.Update(genCtx, summaryMessage)
638 },
639 OnReasoningEnd: func(id string, reasoning fantasy.ReasoningContent) error {
640 // Handle anthropic signature.
641 if anthropicData, ok := reasoning.ProviderMetadata["anthropic"]; ok {
642 if signature, ok := anthropicData.(*anthropic.ReasoningOptionMetadata); ok && signature.Signature != "" {
643 summaryMessage.AppendReasoningSignature(signature.Signature)
644 }
645 }
646 summaryMessage.FinishThinking()
647 return a.messages.Update(genCtx, summaryMessage)
648 },
649 OnTextDelta: func(id, text string) error {
650 summaryMessage.AppendContent(text)
651 return a.messages.Update(genCtx, summaryMessage)
652 },
653 })
654 if err != nil {
655 isCancelErr := errors.Is(err, context.Canceled)
656 if isCancelErr {
657 // User cancelled summarize we need to remove the summary message.
658 deleteErr := a.messages.Delete(ctx, summaryMessage.ID)
659 return deleteErr
660 }
661 return err
662 }
663
664 summaryMessage.AddFinish(message.FinishReasonEndTurn, "", "")
665 err = a.messages.Update(genCtx, summaryMessage)
666 if err != nil {
667 return err
668 }
669
670 var openrouterCost *float64
671 for _, step := range resp.Steps {
672 stepCost := a.openrouterCost(step.ProviderMetadata)
673 if stepCost != nil {
674 newCost := *stepCost
675 if openrouterCost != nil {
676 newCost += *openrouterCost
677 }
678 openrouterCost = &newCost
679 }
680 }
681
682 a.updateSessionUsage(largeModel, ¤tSession, resp.TotalUsage, openrouterCost)
683
684 // Just in case, get just the last usage info.
685 usage := resp.Response.Usage
686 currentSession.SummaryMessageID = summaryMessage.ID
687 currentSession.CompletionTokens = usage.OutputTokens
688 currentSession.PromptTokens = 0
689 _, err = a.sessions.Save(genCtx, currentSession)
690 return err
691}
692
693func (a *sessionAgent) getCacheControlOptions() fantasy.ProviderOptions {
694 if t, _ := strconv.ParseBool(os.Getenv("CRUSH_DISABLE_ANTHROPIC_CACHE")); t {
695 return fantasy.ProviderOptions{}
696 }
697 return fantasy.ProviderOptions{
698 anthropic.Name: &anthropic.ProviderCacheControlOptions{
699 CacheControl: anthropic.CacheControl{Type: "ephemeral"},
700 },
701 bedrock.Name: &anthropic.ProviderCacheControlOptions{
702 CacheControl: anthropic.CacheControl{Type: "ephemeral"},
703 },
704 vercel.Name: &anthropic.ProviderCacheControlOptions{
705 CacheControl: anthropic.CacheControl{Type: "ephemeral"},
706 },
707 }
708}
709
710func (a *sessionAgent) createUserMessage(ctx context.Context, call SessionAgentCall) (message.Message, error) {
711 parts := []message.ContentPart{message.TextContent{Text: call.Prompt}}
712 var attachmentParts []message.ContentPart
713 for _, attachment := range call.Attachments {
714 attachmentParts = append(attachmentParts, message.BinaryContent{Path: attachment.FilePath, MIMEType: attachment.MimeType, Data: attachment.Content})
715 }
716 parts = append(parts, attachmentParts...)
717 msg, err := a.messages.Create(ctx, call.SessionID, message.CreateMessageParams{
718 Role: message.User,
719 Parts: parts,
720 })
721 if err != nil {
722 return message.Message{}, fmt.Errorf("failed to create user message: %w", err)
723 }
724 return msg, nil
725}
726
727func (a *sessionAgent) preparePrompt(msgs []message.Message, attachments ...message.Attachment) ([]fantasy.Message, []fantasy.FilePart) {
728 var history []fantasy.Message
729 if !a.isSubAgent {
730 history = append(history, fantasy.NewUserMessage(
731 fmt.Sprintf("<system_reminder>%s</system_reminder>",
732 `This is a reminder that your todo list is currently empty. DO NOT mention this to the user explicitly because they are already aware.
733If you are working on tasks that would benefit from a todo list please use the "todos" tool to create one.
734If not, please feel free to ignore. Again do not mention this message to the user.`,
735 ),
736 ))
737 }
738 for _, m := range msgs {
739 if len(m.Parts) == 0 {
740 continue
741 }
742 // Assistant message without content or tool calls (cancelled before it
743 // returned anything).
744 if m.Role == message.Assistant && len(m.ToolCalls()) == 0 && m.Content().Text == "" && m.ReasoningContent().String() == "" {
745 continue
746 }
747 history = append(history, m.ToAIMessage()...)
748 }
749
750 var files []fantasy.FilePart
751 for _, attachment := range attachments {
752 if attachment.IsText() {
753 continue
754 }
755 files = append(files, fantasy.FilePart{
756 Filename: attachment.FileName,
757 Data: attachment.Content,
758 MediaType: attachment.MimeType,
759 })
760 }
761
762 return history, files
763}
764
765func (a *sessionAgent) getSessionMessages(ctx context.Context, session session.Session) ([]message.Message, error) {
766 msgs, err := a.messages.List(ctx, session.ID)
767 if err != nil {
768 return nil, fmt.Errorf("failed to list messages: %w", err)
769 }
770
771 if session.SummaryMessageID != "" {
772 summaryMsgIndex := -1
773 for i, msg := range msgs {
774 if msg.ID == session.SummaryMessageID {
775 summaryMsgIndex = i
776 break
777 }
778 }
779 if summaryMsgIndex != -1 {
780 msgs = msgs[summaryMsgIndex:]
781 msgs[0].Role = message.User
782 }
783 }
784 return msgs, nil
785}
786
// generateTitle generates a session title based on the initial prompt.
// It first tries the small model, falls back to the large model on error,
// and finally falls back to defaultSessionName if both fail or produce an
// empty title. Usage and cost from the title generation are recorded on
// the session. Errors are logged, not returned, since title generation is
// best-effort and runs concurrently with the main turn.
func (a *sessionAgent) generateTitle(ctx context.Context, sessionID string, userPrompt string) {
	if userPrompt == "" {
		return
	}

	smallModel := a.smallModel.Get()
	largeModel := a.largeModel.Get()
	systemPromptPrefix := a.systemPromptPrefix.Get()

	// Reasoning models need headroom for their thinking tokens; plain
	// models only need enough for a short title.
	var maxOutputTokens int64 = 40
	if smallModel.CatwalkCfg.CanReason {
		maxOutputTokens = smallModel.CatwalkCfg.DefaultMaxTokens
	}

	newAgent := func(m fantasy.LanguageModel, p []byte, tok int64) fantasy.Agent {
		return fantasy.NewAgent(m,
			fantasy.WithSystemPrompt(string(p)+"\n /no_think"),
			fantasy.WithMaxOutputTokens(tok),
		)
	}

	streamCall := fantasy.AgentStreamCall{
		Prompt: fmt.Sprintf("Generate a concise title for the following content:\n\n%s\n <think>\n\n</think>", userPrompt),
		PrepareStep: func(callCtx context.Context, opts fantasy.PrepareStepFunctionOptions) (_ context.Context, prepared fantasy.PrepareStepResult, err error) {
			prepared.Messages = opts.Messages
			if systemPromptPrefix != "" {
				prepared.Messages = append([]fantasy.Message{
					fantasy.NewSystemMessage(systemPromptPrefix),
				}, prepared.Messages...)
			}
			return callCtx, prepared, nil
		},
	}

	// Use the small model to generate the title.
	model := smallModel
	agent := newAgent(model.Model, titlePrompt, maxOutputTokens)
	resp, err := agent.Stream(ctx, streamCall)
	if err == nil {
		// We successfully generated a title with the small model.
		slog.Debug("Generated title with small model")
	} else {
		// It didn't work. Let's try with the big model.
		slog.Error("Error generating title with small model; trying big model", "err", err)
		model = largeModel
		agent = newAgent(model.Model, titlePrompt, maxOutputTokens)
		resp, err = agent.Stream(ctx, streamCall)
		if err == nil {
			slog.Debug("Generated title with large model")
		} else {
			// Welp, the large model didn't work either. Use the default
			// session name and return.
			slog.Error("Error generating title with large model", "err", err)
			saveErr := a.sessions.UpdateTitleAndUsage(ctx, sessionID, defaultSessionName, 0, 0, 0)
			if saveErr != nil {
				slog.Error("Failed to save session title and usage", "error", saveErr)
			}
			return
		}
	}

	if resp == nil {
		// Actually, we didn't get a response so we can't. Use the default
		// session name and return.
		slog.Error("Response is nil; can't generate title")
		saveErr := a.sessions.UpdateTitleAndUsage(ctx, sessionID, defaultSessionName, 0, 0, 0)
		if saveErr != nil {
			slog.Error("Failed to save session title and usage", "error", saveErr)
		}
		return
	}

	// Clean up title: collapse newlines to spaces.
	var title string
	title = strings.ReplaceAll(resp.Response.Content.Text(), "\n", " ")

	// Remove thinking tags if present.
	title = thinkTagRegex.ReplaceAllString(title, "")

	title = strings.TrimSpace(title)
	if title == "" {
		slog.Debug("Empty title; using fallback")
		title = defaultSessionName
	}

	// Calculate usage and cost. Accumulate OpenRouter-reported costs
	// across all steps, if any.
	var openrouterCost *float64
	for _, step := range resp.Steps {
		stepCost := a.openrouterCost(step.ProviderMetadata)
		if stepCost != nil {
			newCost := *stepCost
			if openrouterCost != nil {
				newCost += *openrouterCost
			}
			openrouterCost = &newCost
		}
	}

	// Estimate cost from the catalog's per-million-token rates.
	modelConfig := model.CatwalkCfg
	cost := modelConfig.CostPer1MInCached/1e6*float64(resp.TotalUsage.CacheCreationTokens) +
		modelConfig.CostPer1MOutCached/1e6*float64(resp.TotalUsage.CacheReadTokens) +
		modelConfig.CostPer1MIn/1e6*float64(resp.TotalUsage.InputTokens) +
		modelConfig.CostPer1MOut/1e6*float64(resp.TotalUsage.OutputTokens)

	// Use override cost if available (e.g., from OpenRouter).
	if openrouterCost != nil {
		cost = *openrouterCost
	}

	promptTokens := resp.TotalUsage.InputTokens + resp.TotalUsage.CacheCreationTokens
	completionTokens := resp.TotalUsage.OutputTokens

	// Atomically update only title and usage fields to avoid overriding other
	// concurrent session updates.
	saveErr := a.sessions.UpdateTitleAndUsage(ctx, sessionID, title, promptTokens, completionTokens, cost)
	if saveErr != nil {
		slog.Error("Failed to save session title and usage", "error", saveErr)
		return
	}
}
908
909func (a *sessionAgent) openrouterCost(metadata fantasy.ProviderMetadata) *float64 {
910 openrouterMetadata, ok := metadata[openrouter.Name]
911 if !ok {
912 return nil
913 }
914
915 opts, ok := openrouterMetadata.(*openrouter.ProviderMetadata)
916 if !ok {
917 return nil
918 }
919 return &opts.Usage.Cost
920}
921
922func (a *sessionAgent) updateSessionUsage(model Model, session *session.Session, usage fantasy.Usage, overrideCost *float64) {
923 modelConfig := model.CatwalkCfg
924 cost := modelConfig.CostPer1MInCached/1e6*float64(usage.CacheCreationTokens) +
925 modelConfig.CostPer1MOutCached/1e6*float64(usage.CacheReadTokens) +
926 modelConfig.CostPer1MIn/1e6*float64(usage.InputTokens) +
927 modelConfig.CostPer1MOut/1e6*float64(usage.OutputTokens)
928
929 a.eventTokensUsed(session.ID, model, usage, cost)
930
931 if overrideCost != nil {
932 session.Cost += *overrideCost
933 } else {
934 session.Cost += cost
935 }
936
937 session.CompletionTokens = usage.OutputTokens
938 session.PromptTokens = usage.InputTokens + usage.CacheReadTokens
939}
940
941func (a *sessionAgent) Cancel(sessionID string) {
942 // Cancel regular requests. Don't use Take() here - we need the entry to
943 // remain in activeRequests so IsBusy() returns true until the goroutine
944 // fully completes (including error handling that may access the DB).
945 // The defer in processRequest will clean up the entry.
946 if cancel, ok := a.activeRequests.Get(sessionID); ok && cancel != nil {
947 slog.Debug("Request cancellation initiated", "session_id", sessionID)
948 cancel()
949 }
950
951 // Also check for summarize requests.
952 if cancel, ok := a.activeRequests.Get(sessionID + "-summarize"); ok && cancel != nil {
953 slog.Debug("Summarize cancellation initiated", "session_id", sessionID)
954 cancel()
955 }
956
957 if a.QueuedPrompts(sessionID) > 0 {
958 slog.Debug("Clearing queued prompts", "session_id", sessionID)
959 a.messageQueue.Del(sessionID)
960 }
961}
962
963func (a *sessionAgent) ClearQueue(sessionID string) {
964 if a.QueuedPrompts(sessionID) > 0 {
965 slog.Debug("Clearing queued prompts", "session_id", sessionID)
966 a.messageQueue.Del(sessionID)
967 }
968}
969
970func (a *sessionAgent) CancelAll() {
971 if !a.IsBusy() {
972 return
973 }
974 for key := range a.activeRequests.Seq2() {
975 a.Cancel(key) // key is sessionID
976 }
977
978 timeout := time.After(5 * time.Second)
979 for a.IsBusy() {
980 select {
981 case <-timeout:
982 return
983 default:
984 time.Sleep(200 * time.Millisecond)
985 }
986 }
987}
988
989func (a *sessionAgent) IsBusy() bool {
990 var busy bool
991 for cancelFunc := range a.activeRequests.Seq() {
992 if cancelFunc != nil {
993 busy = true
994 break
995 }
996 }
997 return busy
998}
999
1000func (a *sessionAgent) IsSessionBusy(sessionID string) bool {
1001 _, busy := a.activeRequests.Get(sessionID)
1002 return busy
1003}
1004
1005func (a *sessionAgent) QueuedPrompts(sessionID string) int {
1006 l, ok := a.messageQueue.Get(sessionID)
1007 if !ok {
1008 return 0
1009 }
1010 return len(l)
1011}
1012
1013func (a *sessionAgent) QueuedPromptsList(sessionID string) []string {
1014 l, ok := a.messageQueue.Get(sessionID)
1015 if !ok {
1016 return nil
1017 }
1018 prompts := make([]string, len(l))
1019 for i, call := range l {
1020 prompts[i] = call.Prompt
1021 }
1022 return prompts
1023}
1024
1025func (a *sessionAgent) SetModels(large Model, small Model) {
1026 a.largeModel.Set(large)
1027 a.smallModel.Set(small)
1028}
1029
1030func (a *sessionAgent) SetTools(tools []fantasy.AgentTool) {
1031 a.tools.SetSlice(tools)
1032}
1033
1034func (a *sessionAgent) SetSystemPrompt(systemPrompt string) {
1035 a.systemPrompt.Set(systemPrompt)
1036}
1037
1038func (a *sessionAgent) Model() Model {
1039 return a.largeModel.Get()
1040}
1041
1042// convertToToolResult converts a fantasy tool result to a message tool result.
1043func (a *sessionAgent) convertToToolResult(result fantasy.ToolResultContent) message.ToolResult {
1044 baseResult := message.ToolResult{
1045 ToolCallID: result.ToolCallID,
1046 Name: result.ToolName,
1047 Metadata: result.ClientMetadata,
1048 }
1049
1050 switch result.Result.GetType() {
1051 case fantasy.ToolResultContentTypeText:
1052 if r, ok := fantasy.AsToolResultOutputType[fantasy.ToolResultOutputContentText](result.Result); ok {
1053 baseResult.Content = r.Text
1054 }
1055 case fantasy.ToolResultContentTypeError:
1056 if r, ok := fantasy.AsToolResultOutputType[fantasy.ToolResultOutputContentError](result.Result); ok {
1057 baseResult.Content = r.Error.Error()
1058 baseResult.IsError = true
1059 }
1060 case fantasy.ToolResultContentTypeMedia:
1061 if r, ok := fantasy.AsToolResultOutputType[fantasy.ToolResultOutputContentMedia](result.Result); ok {
1062 content := r.Text
1063 if content == "" {
1064 content = fmt.Sprintf("Loaded %s content", r.MediaType)
1065 }
1066 baseResult.Content = content
1067 baseResult.Data = r.Data
1068 baseResult.MIMEType = r.MediaType
1069 }
1070 }
1071
1072 return baseResult
1073}
1074
1075// workaroundProviderMediaLimitations converts media content in tool results to
1076// user messages for providers that don't natively support images in tool results.
1077//
1078// Problem: OpenAI, Google, OpenRouter, and other OpenAI-compatible providers
1079// don't support sending images/media in tool result messages - they only accept
1080// text in tool results. However, they DO support images in user messages.
1081//
1082// If we send media in tool results to these providers, the API returns an error.
1083//
1084// Solution: For these providers, we:
1085// 1. Replace the media in the tool result with a text placeholder
1086// 2. Inject a user message immediately after with the image as a file attachment
1087// 3. This maintains the tool execution flow while working around API limitations
1088//
1089// Anthropic and Bedrock support images natively in tool results, so we skip
1090// this workaround for them.
1091//
1092// Example transformation:
1093//
1094// BEFORE: [tool result: image data]
1095// AFTER: [tool result: "Image loaded - see attached"], [user: image attachment]
1096func (a *sessionAgent) workaroundProviderMediaLimitations(messages []fantasy.Message, largeModel Model) []fantasy.Message {
1097 providerSupportsMedia := largeModel.ModelCfg.Provider == string(catwalk.InferenceProviderAnthropic) ||
1098 largeModel.ModelCfg.Provider == string(catwalk.InferenceProviderBedrock)
1099
1100 if providerSupportsMedia {
1101 return messages
1102 }
1103
1104 convertedMessages := make([]fantasy.Message, 0, len(messages))
1105
1106 for _, msg := range messages {
1107 if msg.Role != fantasy.MessageRoleTool {
1108 convertedMessages = append(convertedMessages, msg)
1109 continue
1110 }
1111
1112 textParts := make([]fantasy.MessagePart, 0, len(msg.Content))
1113 var mediaFiles []fantasy.FilePart
1114
1115 for _, part := range msg.Content {
1116 toolResult, ok := fantasy.AsMessagePart[fantasy.ToolResultPart](part)
1117 if !ok {
1118 textParts = append(textParts, part)
1119 continue
1120 }
1121
1122 if media, ok := fantasy.AsToolResultOutputType[fantasy.ToolResultOutputContentMedia](toolResult.Output); ok {
1123 decoded, err := base64.StdEncoding.DecodeString(media.Data)
1124 if err != nil {
1125 slog.Warn("Failed to decode media data", "error", err)
1126 textParts = append(textParts, part)
1127 continue
1128 }
1129
1130 mediaFiles = append(mediaFiles, fantasy.FilePart{
1131 Data: decoded,
1132 MediaType: media.MediaType,
1133 Filename: fmt.Sprintf("tool-result-%s", toolResult.ToolCallID),
1134 })
1135
1136 textParts = append(textParts, fantasy.ToolResultPart{
1137 ToolCallID: toolResult.ToolCallID,
1138 Output: fantasy.ToolResultOutputContentText{
1139 Text: "[Image/media content loaded - see attached file]",
1140 },
1141 ProviderOptions: toolResult.ProviderOptions,
1142 })
1143 } else {
1144 textParts = append(textParts, part)
1145 }
1146 }
1147
1148 convertedMessages = append(convertedMessages, fantasy.Message{
1149 Role: fantasy.MessageRoleTool,
1150 Content: textParts,
1151 })
1152
1153 if len(mediaFiles) > 0 {
1154 convertedMessages = append(convertedMessages, fantasy.NewUserMessage(
1155 "Here is the media content from the tool result:",
1156 mediaFiles...,
1157 ))
1158 }
1159 }
1160
1161 return convertedMessages
1162}
1163
1164// buildSummaryPrompt constructs the prompt text for session summarization.
1165func buildSummaryPrompt(todos []session.Todo) string {
1166 var sb strings.Builder
1167 sb.WriteString("Provide a detailed summary of our conversation above.")
1168 if len(todos) > 0 {
1169 sb.WriteString("\n\n## Current Todo List\n\n")
1170 for _, t := range todos {
1171 fmt.Fprintf(&sb, "- [%s] %s\n", t.Status, t.Content)
1172 }
1173 sb.WriteString("\nInclude these tasks and their statuses in your summary. ")
1174 sb.WriteString("Instruct the resuming assistant to use the `todos` tool to continue tracking progress on these tasks.")
1175 }
1176 return sb.String()
1177}