1package agent
2
3import (
4 "context"
5 _ "embed"
6 "encoding/json"
7 "errors"
8 "fmt"
9 "log/slog"
10 "os"
11 "strconv"
12 "strings"
13 "sync"
14 "time"
15
16 "charm.land/fantasy"
17 "charm.land/fantasy/providers/anthropic"
18 "charm.land/fantasy/providers/bedrock"
19 "charm.land/fantasy/providers/google"
20 "charm.land/fantasy/providers/openai"
21 "charm.land/fantasy/providers/openrouter"
22 "github.com/charmbracelet/catwalk/pkg/catwalk"
23 "github.com/charmbracelet/crush/internal/agent/tools"
24 "github.com/charmbracelet/crush/internal/config"
25 "github.com/charmbracelet/crush/internal/csync"
26 "github.com/charmbracelet/crush/internal/hooks"
27 "github.com/charmbracelet/crush/internal/message"
28 "github.com/charmbracelet/crush/internal/permission"
29 "github.com/charmbracelet/crush/internal/session"
30)
31
// titlePrompt holds the embedded contents of templates/title.md. It is used
// as the system prompt when generating a session title.
//
//go:embed templates/title.md
var titlePrompt []byte

// summaryPrompt holds the embedded contents of templates/summary.md. It is
// used as the system prompt when summarizing a session.
//
//go:embed templates/summary.md
var summaryPrompt []byte
37
// SessionAgentCall describes a single prompt submitted to a session agent,
// together with the generation parameters forwarded to the model.
type SessionAgentCall struct {
	// SessionID identifies the session the prompt belongs to. Run rejects an
	// empty value with ErrSessionMissing.
	SessionID string
	// Prompt is the user's text. Run rejects an empty value with
	// ErrEmptyPrompt.
	Prompt string
	// ProviderOptions is forwarded unchanged to the streaming call.
	ProviderOptions fantasy.ProviderOptions
	// Attachments are stored as binary parts of the user message and sent to
	// the model as files.
	Attachments []message.Attachment
	// MaxOutputTokens is passed to the streaming call by pointer.
	MaxOutputTokens int64
	// Optional sampling parameters, forwarded as-is (nil means unset).
	Temperature      *float64
	TopP             *float64
	TopK             *int64
	FrequencyPenalty *float64
	PresencePenalty  *float64
}
50
// SessionAgent is the interface returned by NewSessionAgent. The per-method
// notes below describe the behavior of the sessionAgent implementation in this
// file.
type SessionAgent interface {
	// Run submits a prompt for a session. If the session is already busy the
	// call is queued and (nil, nil) is returned.
	Run(context.Context, SessionAgentCall) (*fantasy.AgentResult, error)
	// SetModels replaces the large and small models.
	SetModels(large Model, small Model)
	// SetTools replaces the tool set offered to the model.
	SetTools(tools []fantasy.AgentTool)
	// Cancel cancels the active request of a session and drops its queued
	// prompts.
	Cancel(sessionID string)
	// CancelAll cancels every active request and waits up to five seconds for
	// the agent to become idle.
	CancelAll()
	// IsSessionBusy reports whether the session has an active request.
	IsSessionBusy(sessionID string) bool
	// IsBusy reports whether any session has an active request.
	IsBusy() bool
	// QueuedPrompts returns the number of prompts queued for the session.
	QueuedPrompts(sessionID string) int
	// ClearQueue drops the prompts queued for the session.
	ClearQueue(sessionID string)
	// Summarize replaces the session history with a model-generated summary.
	Summarize(context.Context, string, fantasy.ProviderOptions) error
	// Model returns the large model.
	Model() Model
}
64
// Model bundles a language model with its catalog entry and the user's
// selection for it.
type Model struct {
	// Model is the language model used for generation.
	Model fantasy.LanguageModel
	// CatwalkCfg supplies the pricing, context window and reasoning
	// capability read elsewhere in this file.
	CatwalkCfg catwalk.Model
	// ModelCfg supplies the provider and model names recorded on messages and
	// hook contexts.
	ModelCfg config.SelectedModel
}
70
// sessionAgent is the SessionAgent implementation returned by
// NewSessionAgent.
type sessionAgent struct {
	// largeModel drives Run and Summarize; smallModel generates titles.
	largeModel Model
	smallModel Model
	// systemPromptPrefix, when non-empty, is prepended as an extra system
	// message on every step.
	systemPromptPrefix string
	// systemPrompt is the system prompt for Run.
	systemPrompt string
	// tools is the tool set offered to the model in Run.
	tools []fantasy.AgentTool
	// sessions and messages are the persistence services.
	sessions session.Service
	messages message.Service
	// disableAutoSummarize turns off the context-window stop condition in Run.
	disableAutoSummarize bool
	// isYolo is stored from the options; it is not read in this file.
	isYolo bool
	// hooks, when non-nil, receives lifecycle events (prompt submit, tool use,
	// stop).
	hooks *hooks.Executor

	// messageQueue holds prompts received while a session was busy, keyed by
	// session ID.
	messageQueue *csync.Map[string, []SessionAgentCall]
	// activeRequests maps a session ID to the cancel function of its in-flight
	// request; presence of a key marks the session as busy.
	activeRequests *csync.Map[string, context.CancelFunc]
}
86
// SessionAgentOptions carries the dependencies and settings NewSessionAgent
// copies into a new agent.
type SessionAgentOptions struct {
	// LargeModel drives Run and Summarize; SmallModel generates titles.
	LargeModel Model
	SmallModel Model
	// SystemPromptPrefix, when non-empty, is prepended as an extra system
	// message on every step.
	SystemPromptPrefix string
	// SystemPrompt is the system prompt for Run.
	SystemPrompt string
	// DisableAutoSummarize turns off automatic summarization when the context
	// window runs low.
	DisableAutoSummarize bool
	// IsYolo is stored on the agent; it is not read in this file.
	IsYolo bool
	// Sessions and Messages are the persistence services.
	Sessions session.Service
	Messages message.Service
	// Tools is the initial tool set.
	Tools []fantasy.AgentTool
	// Hooks is optional; a nil executor disables hook execution.
	Hooks *hooks.Executor
}
99
100func NewSessionAgent(
101 opts SessionAgentOptions,
102) SessionAgent {
103 return &sessionAgent{
104 largeModel: opts.LargeModel,
105 smallModel: opts.SmallModel,
106 systemPromptPrefix: opts.SystemPromptPrefix,
107 systemPrompt: opts.SystemPrompt,
108 sessions: opts.Sessions,
109 messages: opts.Messages,
110 disableAutoSummarize: opts.DisableAutoSummarize,
111 tools: opts.Tools,
112 isYolo: opts.IsYolo,
113 hooks: opts.Hooks,
114 messageQueue: csync.NewMap[string, []SessionAgentCall](),
115 activeRequests: csync.NewMap[string, context.CancelFunc](),
116 }
117}
118
119func (a *sessionAgent) Run(ctx context.Context, call SessionAgentCall) (*fantasy.AgentResult, error) {
120 if call.Prompt == "" {
121 return nil, ErrEmptyPrompt
122 }
123 if call.SessionID == "" {
124 return nil, ErrSessionMissing
125 }
126
127 // Queue the message if busy
128 if a.IsSessionBusy(call.SessionID) {
129 existing, ok := a.messageQueue.Get(call.SessionID)
130 if !ok {
131 existing = []SessionAgentCall{}
132 }
133 existing = append(existing, call)
134 a.messageQueue.Set(call.SessionID, existing)
135 return nil, nil
136 }
137
138 if len(a.tools) > 0 {
139 // add anthropic caching to the last tool
140 a.tools[len(a.tools)-1].SetProviderOptions(a.getCacheControlOptions())
141 }
142
143 agent := fantasy.NewAgent(
144 a.largeModel.Model,
145 fantasy.WithSystemPrompt(a.systemPrompt),
146 fantasy.WithTools(a.tools...),
147 )
148
149 sessionLock := sync.Mutex{}
150 currentSession, err := a.sessions.Get(ctx, call.SessionID)
151 if err != nil {
152 return nil, fmt.Errorf("failed to get session: %w", err)
153 }
154
155 msgs, err := a.getSessionMessages(ctx, currentSession)
156 if err != nil {
157 return nil, fmt.Errorf("failed to get session messages: %w", err)
158 }
159
160 var wg sync.WaitGroup
161 // Generate title if first message
162 if len(msgs) == 0 {
163 wg.Go(func() {
164 sessionLock.Lock()
165 a.generateTitle(ctx, ¤tSession, call.Prompt)
166 sessionLock.Unlock()
167 })
168 }
169
170 // Add the user message to the session
171 _, err = a.createUserMessage(ctx, call)
172 if err != nil {
173 return nil, err
174 }
175
176 // Execute UserPromptSubmit hook
177 if a.hooks != nil {
178 _ = a.hooks.Execute(ctx, hooks.HookContext{
179 EventType: config.UserPromptSubmit,
180 SessionID: call.SessionID,
181 UserPrompt: call.Prompt,
182 Provider: a.largeModel.ModelCfg.Provider,
183 Model: a.largeModel.ModelCfg.Model,
184 })
185 }
186
187 // add the session to the context
188 ctx = context.WithValue(ctx, tools.SessionIDContextKey, call.SessionID)
189
190 genCtx, cancel := context.WithCancel(ctx)
191 a.activeRequests.Set(call.SessionID, cancel)
192
193 defer cancel()
194 defer a.activeRequests.Del(call.SessionID)
195
196 history, files := a.preparePrompt(msgs, call.Attachments...)
197
198 startTime := time.Now()
199 a.eventPromptSent(call.SessionID)
200
201 var currentAssistant *message.Message
202 var shouldSummarize bool
203 result, err := agent.Stream(genCtx, fantasy.AgentStreamCall{
204 Prompt: call.Prompt,
205 Files: files,
206 Messages: history,
207 ProviderOptions: call.ProviderOptions,
208 MaxOutputTokens: &call.MaxOutputTokens,
209 TopP: call.TopP,
210 Temperature: call.Temperature,
211 PresencePenalty: call.PresencePenalty,
212 TopK: call.TopK,
213 FrequencyPenalty: call.FrequencyPenalty,
214 // Before each step create the new assistant message
215 PrepareStep: func(callContext context.Context, options fantasy.PrepareStepFunctionOptions) (_ context.Context, prepared fantasy.PrepareStepResult, err error) {
216 prepared.Messages = options.Messages
217 // reset all cached items
218 for i := range prepared.Messages {
219 prepared.Messages[i].ProviderOptions = nil
220 }
221
222 queuedCalls, _ := a.messageQueue.Get(call.SessionID)
223 a.messageQueue.Del(call.SessionID)
224 for _, queued := range queuedCalls {
225 userMessage, createErr := a.createUserMessage(callContext, queued)
226 if createErr != nil {
227 return callContext, prepared, createErr
228 }
229 prepared.Messages = append(prepared.Messages, userMessage.ToAIMessage()...)
230 }
231
232 lastSystemRoleInx := 0
233 systemMessageUpdated := false
234 for i, msg := range prepared.Messages {
235 // only add cache control to the last message
236 if msg.Role == fantasy.MessageRoleSystem {
237 lastSystemRoleInx = i
238 } else if !systemMessageUpdated {
239 prepared.Messages[lastSystemRoleInx].ProviderOptions = a.getCacheControlOptions()
240 systemMessageUpdated = true
241 }
242 // than add cache control to the last 2 messages
243 if i > len(prepared.Messages)-3 {
244 prepared.Messages[i].ProviderOptions = a.getCacheControlOptions()
245 }
246 }
247
248 if a.systemPromptPrefix != "" {
249 prepared.Messages = append([]fantasy.Message{fantasy.NewSystemMessage(a.systemPromptPrefix)}, prepared.Messages...)
250 }
251
252 var assistantMsg message.Message
253 assistantMsg, err = a.messages.Create(callContext, call.SessionID, message.CreateMessageParams{
254 Role: message.Assistant,
255 Parts: []message.ContentPart{},
256 Model: a.largeModel.ModelCfg.Model,
257 Provider: a.largeModel.ModelCfg.Provider,
258 })
259 if err != nil {
260 return callContext, prepared, err
261 }
262 callContext = context.WithValue(callContext, tools.MessageIDContextKey, assistantMsg.ID)
263 currentAssistant = &assistantMsg
264 return callContext, prepared, err
265 },
266 OnReasoningStart: func(id string, reasoning fantasy.ReasoningContent) error {
267 currentAssistant.AppendReasoningContent(reasoning.Text)
268 return a.messages.Update(genCtx, *currentAssistant)
269 },
270 OnReasoningDelta: func(id string, text string) error {
271 currentAssistant.AppendReasoningContent(text)
272 return a.messages.Update(genCtx, *currentAssistant)
273 },
274 OnReasoningEnd: func(id string, reasoning fantasy.ReasoningContent) error {
275 // handle anthropic signature
276 if anthropicData, ok := reasoning.ProviderMetadata[anthropic.Name]; ok {
277 if reasoning, ok := anthropicData.(*anthropic.ReasoningOptionMetadata); ok {
278 currentAssistant.AppendReasoningSignature(reasoning.Signature)
279 }
280 }
281 if googleData, ok := reasoning.ProviderMetadata[google.Name]; ok {
282 if reasoning, ok := googleData.(*google.ReasoningMetadata); ok {
283 currentAssistant.AppendReasoningSignature(reasoning.Signature)
284 }
285 }
286 if openaiData, ok := reasoning.ProviderMetadata[openai.Name]; ok {
287 if reasoning, ok := openaiData.(*openai.ResponsesReasoningMetadata); ok {
288 currentAssistant.SetReasoningResponsesData(reasoning)
289 }
290 }
291 currentAssistant.FinishThinking()
292 return a.messages.Update(genCtx, *currentAssistant)
293 },
294 OnTextDelta: func(id string, text string) error {
295 currentAssistant.AppendContent(text)
296 return a.messages.Update(genCtx, *currentAssistant)
297 },
298 OnToolInputStart: func(id string, toolName string) error {
299 toolCall := message.ToolCall{
300 ID: id,
301 Name: toolName,
302 ProviderExecuted: false,
303 Finished: false,
304 }
305 currentAssistant.AddToolCall(toolCall)
306 return a.messages.Update(genCtx, *currentAssistant)
307 },
308 OnRetry: func(err *fantasy.APICallError, delay time.Duration) {
309 // TODO: implement
310 },
311 OnToolCall: func(tc fantasy.ToolCallContent) error {
312 // Execute PreToolUse hook - blocks tool execution on error
313 if a.hooks != nil {
314 toolInput := make(map[string]any)
315 if err := json.Unmarshal([]byte(tc.Input), &toolInput); err == nil {
316 if err := a.hooks.Execute(genCtx, hooks.HookContext{
317 EventType: config.PreToolUse,
318 SessionID: call.SessionID,
319 ToolName: tc.ToolName,
320 ToolInput: toolInput,
321 MessageID: currentAssistant.ID,
322 Provider: a.largeModel.ModelCfg.Provider,
323 Model: a.largeModel.ModelCfg.Model,
324 }); err != nil {
325 return fmt.Errorf("PreToolUse hook blocked tool execution: %w", err)
326 }
327 }
328 }
329
330 toolCall := message.ToolCall{
331 ID: tc.ToolCallID,
332 Name: tc.ToolName,
333 Input: tc.Input,
334 ProviderExecuted: false,
335 Finished: true,
336 }
337 currentAssistant.AddToolCall(toolCall)
338 return a.messages.Update(genCtx, *currentAssistant)
339 },
340 OnToolResult: func(result fantasy.ToolResultContent) error {
341 var resultContent string
342 isError := false
343 switch result.Result.GetType() {
344 case fantasy.ToolResultContentTypeText:
345 r, ok := fantasy.AsToolResultOutputType[fantasy.ToolResultOutputContentText](result.Result)
346 if ok {
347 resultContent = r.Text
348 }
349 case fantasy.ToolResultContentTypeError:
350 r, ok := fantasy.AsToolResultOutputType[fantasy.ToolResultOutputContentError](result.Result)
351 if ok {
352 isError = true
353 resultContent = r.Error.Error()
354 }
355 case fantasy.ToolResultContentTypeMedia:
356 // TODO: handle this message type
357 }
358
359 // Execute PostToolUse hook
360 if a.hooks != nil {
361 toolInput := make(map[string]any)
362 // Try to get tool input from the assistant message
363 toolCalls := currentAssistant.ToolCalls()
364 for _, tc := range toolCalls {
365 if tc.ID == result.ToolCallID {
366 _ = json.Unmarshal([]byte(tc.Input), &toolInput)
367 break
368 }
369 }
370
371 _ = a.hooks.Execute(genCtx, hooks.HookContext{
372 EventType: config.PostToolUse,
373 SessionID: call.SessionID,
374 ToolName: result.ToolName,
375 ToolInput: toolInput,
376 ToolResult: resultContent,
377 ToolError: isError,
378 MessageID: currentAssistant.ID,
379 Provider: a.largeModel.ModelCfg.Provider,
380 Model: a.largeModel.ModelCfg.Model,
381 })
382 }
383
384 toolResult := message.ToolResult{
385 ToolCallID: result.ToolCallID,
386 Name: result.ToolName,
387 Content: resultContent,
388 IsError: isError,
389 Metadata: result.ClientMetadata,
390 }
391 _, createMsgErr := a.messages.Create(genCtx, currentAssistant.SessionID, message.CreateMessageParams{
392 Role: message.Tool,
393 Parts: []message.ContentPart{
394 toolResult,
395 },
396 })
397 if createMsgErr != nil {
398 return createMsgErr
399 }
400 return nil
401 },
402 OnStepFinish: func(stepResult fantasy.StepResult) error {
403 finishReason := message.FinishReasonUnknown
404 switch stepResult.FinishReason {
405 case fantasy.FinishReasonLength:
406 finishReason = message.FinishReasonMaxTokens
407 case fantasy.FinishReasonStop:
408 finishReason = message.FinishReasonEndTurn
409 case fantasy.FinishReasonToolCalls:
410 finishReason = message.FinishReasonToolUse
411 }
412 currentAssistant.AddFinish(finishReason, "", "")
413 a.updateSessionUsage(a.largeModel, ¤tSession, stepResult.Usage, a.openrouterCost(stepResult.ProviderMetadata))
414 sessionLock.Lock()
415 _, sessionErr := a.sessions.Save(genCtx, currentSession)
416 sessionLock.Unlock()
417 if sessionErr != nil {
418 return sessionErr
419 }
420 return a.messages.Update(genCtx, *currentAssistant)
421 },
422 StopWhen: []fantasy.StopCondition{
423 func(_ []fantasy.StepResult) bool {
424 cw := int64(a.largeModel.CatwalkCfg.ContextWindow)
425 tokens := currentSession.CompletionTokens + currentSession.PromptTokens
426 remaining := cw - tokens
427 var threshold int64
428 if cw > 200_000 {
429 threshold = 20_000
430 } else {
431 threshold = int64(float64(cw) * 0.2)
432 }
433 if (remaining <= threshold) && !a.disableAutoSummarize {
434 shouldSummarize = true
435 return true
436 }
437 return false
438 },
439 },
440 })
441
442 a.eventPromptResponded(call.SessionID, time.Since(startTime).Truncate(time.Second))
443
444 if err != nil {
445 isCancelErr := errors.Is(err, context.Canceled)
446 isPermissionErr := errors.Is(err, permission.ErrorPermissionDenied)
447 if currentAssistant == nil {
448 return result, err
449 }
450 // Ensure we finish thinking on error to close the reasoning state
451 currentAssistant.FinishThinking()
452 toolCalls := currentAssistant.ToolCalls()
453 // INFO: we use the parent context here because the genCtx has been cancelled
454 msgs, createErr := a.messages.List(ctx, currentAssistant.SessionID)
455 if createErr != nil {
456 return nil, createErr
457 }
458 for _, tc := range toolCalls {
459 if !tc.Finished {
460 tc.Finished = true
461 tc.Input = "{}"
462 currentAssistant.AddToolCall(tc)
463 updateErr := a.messages.Update(ctx, *currentAssistant)
464 if updateErr != nil {
465 return nil, updateErr
466 }
467 }
468
469 found := false
470 for _, msg := range msgs {
471 if msg.Role == message.Tool {
472 for _, tr := range msg.ToolResults() {
473 if tr.ToolCallID == tc.ID {
474 found = true
475 break
476 }
477 }
478 }
479 if found {
480 break
481 }
482 }
483 if found {
484 continue
485 }
486 content := "There was an error while executing the tool"
487 if isCancelErr {
488 content = "Tool execution canceled by user"
489 } else if isPermissionErr {
490 content = "Permission denied"
491 }
492 toolResult := message.ToolResult{
493 ToolCallID: tc.ID,
494 Name: tc.Name,
495 Content: content,
496 IsError: true,
497 }
498 _, createErr = a.messages.Create(context.Background(), currentAssistant.SessionID, message.CreateMessageParams{
499 Role: message.Tool,
500 Parts: []message.ContentPart{
501 toolResult,
502 },
503 })
504 if createErr != nil {
505 return nil, createErr
506 }
507 }
508 if isCancelErr {
509 currentAssistant.AddFinish(message.FinishReasonCanceled, "Request cancelled", "")
510 } else if isPermissionErr {
511 currentAssistant.AddFinish(message.FinishReasonPermissionDenied, "Permission denied", "")
512 } else {
513 currentAssistant.AddFinish(message.FinishReasonError, "API Error", err.Error())
514 }
515 // INFO: we use the parent context here because the genCtx has been cancelled
516 updateErr := a.messages.Update(ctx, *currentAssistant)
517 if updateErr != nil {
518 return nil, updateErr
519 }
520 return nil, err
521 }
522 wg.Wait()
523
524 // Execute Stop hook
525 if a.hooks != nil && result != nil {
526 var totalTokens, inputTokens int64
527 for _, step := range result.Steps {
528 totalTokens += step.Usage.TotalTokens
529 inputTokens += step.Usage.InputTokens
530 }
531
532 _ = a.hooks.Execute(ctx, hooks.HookContext{
533 EventType: config.Stop,
534 SessionID: call.SessionID,
535 MessageID: currentAssistant.ID,
536 Provider: a.largeModel.ModelCfg.Provider,
537 Model: a.largeModel.ModelCfg.Model,
538 TokensUsed: totalTokens,
539 TokensInput: inputTokens,
540 })
541 }
542
543 if shouldSummarize {
544 a.activeRequests.Del(call.SessionID)
545 if summarizeErr := a.Summarize(genCtx, call.SessionID, call.ProviderOptions); summarizeErr != nil {
546 return nil, summarizeErr
547 }
548 // if the agent was not done...
549 if len(currentAssistant.ToolCalls()) > 0 {
550 existing, ok := a.messageQueue.Get(call.SessionID)
551 if !ok {
552 existing = []SessionAgentCall{}
553 }
554 call.Prompt = fmt.Sprintf("The previous session was interrupted because it got too long, the initial user request was: `%s`", call.Prompt)
555 existing = append(existing, call)
556 a.messageQueue.Set(call.SessionID, existing)
557 }
558 }
559
560 // release active request before processing queued messages
561 a.activeRequests.Del(call.SessionID)
562 cancel()
563
564 queuedMessages, ok := a.messageQueue.Get(call.SessionID)
565 if !ok || len(queuedMessages) == 0 {
566 return result, err
567 }
568 // there are queued messages restart the loop
569 firstQueuedMessage := queuedMessages[0]
570 a.messageQueue.Set(call.SessionID, queuedMessages[1:])
571 return a.Run(ctx, firstQueuedMessage)
572}
573
574func (a *sessionAgent) Summarize(ctx context.Context, sessionID string, opts fantasy.ProviderOptions) error {
575 if a.IsSessionBusy(sessionID) {
576 return ErrSessionBusy
577 }
578
579 currentSession, err := a.sessions.Get(ctx, sessionID)
580 if err != nil {
581 return fmt.Errorf("failed to get session: %w", err)
582 }
583 msgs, err := a.getSessionMessages(ctx, currentSession)
584 if err != nil {
585 return err
586 }
587 if len(msgs) == 0 {
588 // nothing to summarize
589 return nil
590 }
591
592 aiMsgs, _ := a.preparePrompt(msgs)
593
594 genCtx, cancel := context.WithCancel(ctx)
595 a.activeRequests.Set(sessionID, cancel)
596 defer a.activeRequests.Del(sessionID)
597 defer cancel()
598
599 agent := fantasy.NewAgent(a.largeModel.Model,
600 fantasy.WithSystemPrompt(string(summaryPrompt)),
601 )
602 summaryMessage, err := a.messages.Create(ctx, sessionID, message.CreateMessageParams{
603 Role: message.Assistant,
604 Model: a.largeModel.Model.Model(),
605 Provider: a.largeModel.Model.Provider(),
606 IsSummaryMessage: true,
607 })
608 if err != nil {
609 return err
610 }
611
612 resp, err := agent.Stream(genCtx, fantasy.AgentStreamCall{
613 Prompt: "Provide a detailed summary of our conversation above.",
614 Messages: aiMsgs,
615 ProviderOptions: opts,
616 PrepareStep: func(callContext context.Context, options fantasy.PrepareStepFunctionOptions) (_ context.Context, prepared fantasy.PrepareStepResult, err error) {
617 prepared.Messages = options.Messages
618 if a.systemPromptPrefix != "" {
619 prepared.Messages = append([]fantasy.Message{fantasy.NewSystemMessage(a.systemPromptPrefix)}, prepared.Messages...)
620 }
621 return callContext, prepared, nil
622 },
623 OnReasoningDelta: func(id string, text string) error {
624 summaryMessage.AppendReasoningContent(text)
625 return a.messages.Update(genCtx, summaryMessage)
626 },
627 OnReasoningEnd: func(id string, reasoning fantasy.ReasoningContent) error {
628 // handle anthropic signature
629 if anthropicData, ok := reasoning.ProviderMetadata["anthropic"]; ok {
630 if signature, ok := anthropicData.(*anthropic.ReasoningOptionMetadata); ok && signature.Signature != "" {
631 summaryMessage.AppendReasoningSignature(signature.Signature)
632 }
633 }
634 summaryMessage.FinishThinking()
635 return a.messages.Update(genCtx, summaryMessage)
636 },
637 OnTextDelta: func(id, text string) error {
638 summaryMessage.AppendContent(text)
639 return a.messages.Update(genCtx, summaryMessage)
640 },
641 })
642 if err != nil {
643 isCancelErr := errors.Is(err, context.Canceled)
644 if isCancelErr {
645 // User cancelled summarize we need to remove the summary message
646 deleteErr := a.messages.Delete(ctx, summaryMessage.ID)
647 return deleteErr
648 }
649 return err
650 }
651
652 summaryMessage.AddFinish(message.FinishReasonEndTurn, "", "")
653 err = a.messages.Update(genCtx, summaryMessage)
654 if err != nil {
655 return err
656 }
657
658 var openrouterCost *float64
659 for _, step := range resp.Steps {
660 stepCost := a.openrouterCost(step.ProviderMetadata)
661 if stepCost != nil {
662 newCost := *stepCost
663 if openrouterCost != nil {
664 newCost += *openrouterCost
665 }
666 openrouterCost = &newCost
667 }
668 }
669
670 a.updateSessionUsage(a.largeModel, ¤tSession, resp.TotalUsage, openrouterCost)
671
672 // just in case get just the last usage
673 usage := resp.Response.Usage
674 currentSession.SummaryMessageID = summaryMessage.ID
675 currentSession.CompletionTokens = usage.OutputTokens
676 currentSession.PromptTokens = 0
677 _, err = a.sessions.Save(genCtx, currentSession)
678 return err
679}
680
681func (a *sessionAgent) getCacheControlOptions() fantasy.ProviderOptions {
682 if t, _ := strconv.ParseBool(os.Getenv("CRUSH_DISABLE_ANTHROPIC_CACHE")); t {
683 return fantasy.ProviderOptions{}
684 }
685 return fantasy.ProviderOptions{
686 anthropic.Name: &anthropic.ProviderCacheControlOptions{
687 CacheControl: anthropic.CacheControl{Type: "ephemeral"},
688 },
689 bedrock.Name: &anthropic.ProviderCacheControlOptions{
690 CacheControl: anthropic.CacheControl{Type: "ephemeral"},
691 },
692 }
693}
694
695func (a *sessionAgent) createUserMessage(ctx context.Context, call SessionAgentCall) (message.Message, error) {
696 var attachmentParts []message.ContentPart
697 for _, attachment := range call.Attachments {
698 attachmentParts = append(attachmentParts, message.BinaryContent{Path: attachment.FilePath, MIMEType: attachment.MimeType, Data: attachment.Content})
699 }
700 parts := []message.ContentPart{message.TextContent{Text: call.Prompt}}
701 parts = append(parts, attachmentParts...)
702 msg, err := a.messages.Create(ctx, call.SessionID, message.CreateMessageParams{
703 Role: message.User,
704 Parts: parts,
705 })
706 if err != nil {
707 return message.Message{}, fmt.Errorf("failed to create user message: %w", err)
708 }
709 return msg, nil
710}
711
712func (a *sessionAgent) preparePrompt(msgs []message.Message, attachments ...message.Attachment) ([]fantasy.Message, []fantasy.FilePart) {
713 var history []fantasy.Message
714 for _, m := range msgs {
715 if len(m.Parts) == 0 {
716 continue
717 }
718 // Assistant message without content or tool calls (cancelled before it returned anything)
719 if m.Role == message.Assistant && len(m.ToolCalls()) == 0 && m.Content().Text == "" && m.ReasoningContent().String() == "" {
720 continue
721 }
722 history = append(history, m.ToAIMessage()...)
723 }
724
725 var files []fantasy.FilePart
726 for _, attachment := range attachments {
727 files = append(files, fantasy.FilePart{
728 Filename: attachment.FileName,
729 Data: attachment.Content,
730 MediaType: attachment.MimeType,
731 })
732 }
733
734 return history, files
735}
736
737func (a *sessionAgent) getSessionMessages(ctx context.Context, session session.Session) ([]message.Message, error) {
738 msgs, err := a.messages.List(ctx, session.ID)
739 if err != nil {
740 return nil, fmt.Errorf("failed to list messages: %w", err)
741 }
742
743 if session.SummaryMessageID != "" {
744 summaryMsgInex := -1
745 for i, msg := range msgs {
746 if msg.ID == session.SummaryMessageID {
747 summaryMsgInex = i
748 break
749 }
750 }
751 if summaryMsgInex != -1 {
752 msgs = msgs[summaryMsgInex:]
753 msgs[0].Role = message.User
754 }
755 }
756 return msgs, nil
757}
758
759func (a *sessionAgent) generateTitle(ctx context.Context, session *session.Session, prompt string) {
760 if prompt == "" {
761 return
762 }
763
764 var maxOutput int64 = 40
765 if a.smallModel.CatwalkCfg.CanReason {
766 maxOutput = a.smallModel.CatwalkCfg.DefaultMaxTokens
767 }
768
769 agent := fantasy.NewAgent(a.smallModel.Model,
770 fantasy.WithSystemPrompt(string(titlePrompt)+"\n /no_think"),
771 fantasy.WithMaxOutputTokens(maxOutput),
772 )
773
774 resp, err := agent.Stream(ctx, fantasy.AgentStreamCall{
775 Prompt: fmt.Sprintf("Generate a concise title for the following content:\n\n%s\n <think>\n\n</think>", prompt),
776 PrepareStep: func(callContext context.Context, options fantasy.PrepareStepFunctionOptions) (_ context.Context, prepared fantasy.PrepareStepResult, err error) {
777 prepared.Messages = options.Messages
778 if a.systemPromptPrefix != "" {
779 prepared.Messages = append([]fantasy.Message{fantasy.NewSystemMessage(a.systemPromptPrefix)}, prepared.Messages...)
780 }
781 return callContext, prepared, nil
782 },
783 })
784 if err != nil {
785 slog.Error("error generating title", "err", err)
786 return
787 }
788
789 title := resp.Response.Content.Text()
790
791 title = strings.ReplaceAll(title, "\n", " ")
792
793 // remove thinking tags if present
794 if idx := strings.Index(title, "</think>"); idx > 0 {
795 title = title[idx+len("</think>"):]
796 }
797
798 title = strings.TrimSpace(title)
799 if title == "" {
800 slog.Warn("failed to generate title", "warn", "empty title")
801 return
802 }
803
804 session.Title = title
805
806 var openrouterCost *float64
807 for _, step := range resp.Steps {
808 stepCost := a.openrouterCost(step.ProviderMetadata)
809 if stepCost != nil {
810 newCost := *stepCost
811 if openrouterCost != nil {
812 newCost += *openrouterCost
813 }
814 openrouterCost = &newCost
815 }
816 }
817
818 a.updateSessionUsage(a.smallModel, session, resp.TotalUsage, openrouterCost)
819 _, saveErr := a.sessions.Save(ctx, *session)
820 if saveErr != nil {
821 slog.Error("failed to save session title & usage", "error", saveErr)
822 return
823 }
824}
825
826func (a *sessionAgent) openrouterCost(metadata fantasy.ProviderMetadata) *float64 {
827 openrouterMetadata, ok := metadata[openrouter.Name]
828 if !ok {
829 return nil
830 }
831
832 opts, ok := openrouterMetadata.(*openrouter.ProviderMetadata)
833 if !ok {
834 return nil
835 }
836 return &opts.Usage.Cost
837}
838
839func (a *sessionAgent) updateSessionUsage(model Model, session *session.Session, usage fantasy.Usage, overrideCost *float64) {
840 modelConfig := model.CatwalkCfg
841 cost := modelConfig.CostPer1MInCached/1e6*float64(usage.CacheCreationTokens) +
842 modelConfig.CostPer1MOutCached/1e6*float64(usage.CacheReadTokens) +
843 modelConfig.CostPer1MIn/1e6*float64(usage.InputTokens) +
844 modelConfig.CostPer1MOut/1e6*float64(usage.OutputTokens)
845
846 a.eventTokensUsed(session.ID, model, usage, cost)
847
848 if overrideCost != nil {
849 session.Cost += *overrideCost
850 } else {
851 session.Cost += cost
852 }
853
854 session.CompletionTokens = usage.OutputTokens + usage.CacheReadTokens
855 session.PromptTokens = usage.InputTokens + usage.CacheCreationTokens
856}
857
858func (a *sessionAgent) Cancel(sessionID string) {
859 // Cancel regular requests
860 if cancel, ok := a.activeRequests.Take(sessionID); ok && cancel != nil {
861 slog.Info("Request cancellation initiated", "session_id", sessionID)
862 cancel()
863 }
864
865 // Also check for summarize requests
866 if cancel, ok := a.activeRequests.Take(sessionID + "-summarize"); ok && cancel != nil {
867 slog.Info("Summarize cancellation initiated", "session_id", sessionID)
868 cancel()
869 }
870
871 if a.QueuedPrompts(sessionID) > 0 {
872 slog.Info("Clearing queued prompts", "session_id", sessionID)
873 a.messageQueue.Del(sessionID)
874 }
875}
876
877func (a *sessionAgent) ClearQueue(sessionID string) {
878 if a.QueuedPrompts(sessionID) > 0 {
879 slog.Info("Clearing queued prompts", "session_id", sessionID)
880 a.messageQueue.Del(sessionID)
881 }
882}
883
884func (a *sessionAgent) CancelAll() {
885 if !a.IsBusy() {
886 return
887 }
888 for key := range a.activeRequests.Seq2() {
889 a.Cancel(key) // key is sessionID
890 }
891
892 timeout := time.After(5 * time.Second)
893 for a.IsBusy() {
894 select {
895 case <-timeout:
896 return
897 default:
898 time.Sleep(200 * time.Millisecond)
899 }
900 }
901}
902
903func (a *sessionAgent) IsBusy() bool {
904 var busy bool
905 for cancelFunc := range a.activeRequests.Seq() {
906 if cancelFunc != nil {
907 busy = true
908 break
909 }
910 }
911 return busy
912}
913
914func (a *sessionAgent) IsSessionBusy(sessionID string) bool {
915 _, busy := a.activeRequests.Get(sessionID)
916 return busy
917}
918
919func (a *sessionAgent) QueuedPrompts(sessionID string) int {
920 l, ok := a.messageQueue.Get(sessionID)
921 if !ok {
922 return 0
923 }
924 return len(l)
925}
926
927func (a *sessionAgent) SetModels(large Model, small Model) {
928 a.largeModel = large
929 a.smallModel = small
930}
931
932func (a *sessionAgent) SetTools(tools []fantasy.AgentTool) {
933 a.tools = tools
934}
935
936func (a *sessionAgent) Model() Model {
937 return a.largeModel
938}