1package agent
2
3import (
4 "context"
5 _ "embed"
6 "errors"
7 "fmt"
8 "log/slog"
9 "os"
10 "strconv"
11 "strings"
12 "sync"
13 "time"
14
15 "charm.land/fantasy"
16 "charm.land/fantasy/providers/anthropic"
17 "charm.land/fantasy/providers/bedrock"
18 "charm.land/fantasy/providers/google"
19 "charm.land/fantasy/providers/openai"
20 "charm.land/fantasy/providers/openrouter"
21 "github.com/charmbracelet/catwalk/pkg/catwalk"
22 "github.com/charmbracelet/crush/internal/agent/tools"
23 "github.com/charmbracelet/crush/internal/config"
24 "github.com/charmbracelet/crush/internal/csync"
25 "github.com/charmbracelet/crush/internal/message"
26 "github.com/charmbracelet/crush/internal/permission"
27 "github.com/charmbracelet/crush/internal/session"
28)
29
// titlePrompt holds the embedded contents of templates/title.md. It is used
// as the system prompt when generating a session title.
//
//go:embed templates/title.md
var titlePrompt []byte
32
// summaryPrompt holds the embedded contents of templates/summary.md. It is
// used as the system prompt when summarizing a session.
//
//go:embed templates/summary.md
var summaryPrompt []byte
35
// SessionAgentCall describes a single prompt submitted to a SessionAgent.
type SessionAgentCall struct {
	// SessionID identifies the session the prompt belongs to. sessionAgent.Run
	// rejects an empty value with ErrSessionMissing.
	SessionID string
	// Prompt is the user's text. sessionAgent.Run rejects an empty value with
	// ErrEmptyPrompt.
	Prompt string
	// ProviderOptions is forwarded to the streaming call (and to Summarize
	// when auto-summarization is triggered).
	ProviderOptions fantasy.ProviderOptions
	// Attachments are stored as binary parts on the persisted user message
	// and sent to the model as file parts.
	Attachments []message.Attachment
	// MaxOutputTokens is forwarded (by pointer) to the streaming call.
	MaxOutputTokens int64
	// Temperature, TopP, TopK, FrequencyPenalty and PresencePenalty are
	// forwarded unchanged to the streaming call; nil is passed through as-is.
	Temperature      *float64
	TopP             *float64
	TopK             *int64
	FrequencyPenalty *float64
	PresencePenalty  *float64
}
48
// SessionAgent runs prompts against sessions and tracks which sessions have
// work in flight. The per-method notes below describe the sessionAgent
// implementation in this file.
type SessionAgent interface {
	// Run executes the call, or queues it and returns (nil, nil) when the
	// session is already busy.
	Run(context.Context, SessionAgentCall) (*fantasy.AgentResult, error)
	// SetModels replaces the large and small models.
	SetModels(large Model, small Model)
	// SetTools replaces the tool set offered to the model.
	SetTools(tools []fantasy.AgentTool)
	// Cancel cancels the in-flight request for the session and drops its
	// queued prompts.
	Cancel(sessionID string)
	// CancelAll cancels every active request and waits up to five seconds for
	// the agent to become idle.
	CancelAll()
	// IsSessionBusy reports whether the session has an active request.
	IsSessionBusy(sessionID string) bool
	// IsBusy reports whether any session has an active request.
	IsBusy() bool
	// QueuedPrompts returns the number of prompts queued for the session.
	QueuedPrompts(sessionID string) int
	// ClearQueue drops all prompts queued for the session.
	ClearQueue(sessionID string)
	// Summarize replaces the session's history with a generated summary
	// message; it returns ErrSessionBusy when the session is busy.
	Summarize(context.Context, string, fantasy.ProviderOptions) error
	// Model returns the large model.
	Model() Model
}
62
// Model bundles a language model with the catalog metadata and the selected
// configuration that the agent consults for it.
type Model struct {
	// Model is the language model handed to fantasy.NewAgent.
	Model fantasy.LanguageModel
	// CatwalkCfg supplies the context window, reasoning capability, default
	// max tokens and per-million-token prices read in this file.
	CatwalkCfg catwalk.Model
	// ModelCfg supplies the model and provider identifiers recorded on
	// assistant messages.
	ModelCfg config.SelectedModel
}
68
// sessionAgent is the SessionAgent implementation returned by
// NewSessionAgent.
type sessionAgent struct {
	// largeModel drives Run and Summarize; smallModel drives title generation.
	largeModel Model
	smallModel Model
	// systemPromptPrefix, when non-empty, is prepended as an extra system
	// message on every step.
	systemPromptPrefix string
	// systemPrompt is the system prompt used by Run.
	systemPrompt string
	// tools is the tool set passed to the agent in Run.
	tools []fantasy.AgentTool
	// sessions and messages persist session state and conversation messages.
	sessions session.Service
	messages message.Service
	// disableAutoSummarize turns off the context-window stop condition in Run.
	disableAutoSummarize bool
	// isYolo is copied from SessionAgentOptions.IsYolo; it is not read in the
	// portion of the file shown here.
	isYolo bool

	// messageQueue holds, per session ID, prompts that arrived while the
	// session was busy.
	messageQueue *csync.Map[string, []SessionAgentCall]
	// activeRequests holds, per session ID, the cancel function of the
	// in-flight request; presence of a key marks the session as busy.
	activeRequests *csync.Map[string, context.CancelFunc]
}
83
// SessionAgentOptions carries the dependencies and settings that
// NewSessionAgent copies into a new sessionAgent.
type SessionAgentOptions struct {
	// LargeModel is used for Run and Summarize; SmallModel for titles.
	LargeModel Model
	SmallModel Model
	// SystemPromptPrefix, when non-empty, is prepended as an extra system
	// message on every step.
	SystemPromptPrefix string
	// SystemPrompt is the system prompt used by Run.
	SystemPrompt string
	// DisableAutoSummarize turns off automatic summarization when the
	// context window is nearly full.
	DisableAutoSummarize bool
	// IsYolo is stored on the agent; its effect is not visible in this file.
	IsYolo bool
	// Sessions and Messages are the persistence services for sessions and
	// messages.
	Sessions session.Service
	Messages message.Service
	// Tools is the initial tool set offered to the model.
	Tools []fantasy.AgentTool
}
95
96func NewSessionAgent(
97 opts SessionAgentOptions,
98) SessionAgent {
99 return &sessionAgent{
100 largeModel: opts.LargeModel,
101 smallModel: opts.SmallModel,
102 systemPromptPrefix: opts.SystemPromptPrefix,
103 systemPrompt: opts.SystemPrompt,
104 sessions: opts.Sessions,
105 messages: opts.Messages,
106 disableAutoSummarize: opts.DisableAutoSummarize,
107 tools: opts.Tools,
108 isYolo: opts.IsYolo,
109 messageQueue: csync.NewMap[string, []SessionAgentCall](),
110 activeRequests: csync.NewMap[string, context.CancelFunc](),
111 }
112}
113
114func (a *sessionAgent) Run(ctx context.Context, call SessionAgentCall) (*fantasy.AgentResult, error) {
115 if call.Prompt == "" {
116 return nil, ErrEmptyPrompt
117 }
118 if call.SessionID == "" {
119 return nil, ErrSessionMissing
120 }
121
122 // Queue the message if busy
123 if a.IsSessionBusy(call.SessionID) {
124 existing, ok := a.messageQueue.Get(call.SessionID)
125 if !ok {
126 existing = []SessionAgentCall{}
127 }
128 existing = append(existing, call)
129 a.messageQueue.Set(call.SessionID, existing)
130 return nil, nil
131 }
132
133 if len(a.tools) > 0 {
134 // Add anthropic caching to the last tool.
135 a.tools[len(a.tools)-1].SetProviderOptions(a.getCacheControlOptions())
136 }
137
138 agent := fantasy.NewAgent(
139 a.largeModel.Model,
140 fantasy.WithSystemPrompt(a.systemPrompt),
141 fantasy.WithTools(a.tools...),
142 )
143
144 sessionLock := sync.Mutex{}
145 currentSession, err := a.sessions.Get(ctx, call.SessionID)
146 if err != nil {
147 return nil, fmt.Errorf("failed to get session: %w", err)
148 }
149
150 msgs, err := a.getSessionMessages(ctx, currentSession)
151 if err != nil {
152 return nil, fmt.Errorf("failed to get session messages: %w", err)
153 }
154
155 var wg sync.WaitGroup
156 // Generate title if first message.
157 if len(msgs) == 0 {
158 wg.Go(func() {
159 sessionLock.Lock()
160 a.generateTitle(ctx, ¤tSession, call.Prompt)
161 sessionLock.Unlock()
162 })
163 }
164
165 // Add the user message to the session.
166 _, err = a.createUserMessage(ctx, call)
167 if err != nil {
168 return nil, err
169 }
170
171 // Add the session to the context.
172 ctx = context.WithValue(ctx, tools.SessionIDContextKey, call.SessionID)
173
174 genCtx, cancel := context.WithCancel(ctx)
175 a.activeRequests.Set(call.SessionID, cancel)
176
177 defer cancel()
178 defer a.activeRequests.Del(call.SessionID)
179
180 history, files := a.preparePrompt(msgs, call.Attachments...)
181
182 startTime := time.Now()
183 a.eventPromptSent(call.SessionID)
184
185 var currentAssistant *message.Message
186 var shouldSummarize bool
187 result, err := agent.Stream(genCtx, fantasy.AgentStreamCall{
188 Prompt: call.Prompt,
189 Files: files,
190 Messages: history,
191 ProviderOptions: call.ProviderOptions,
192 MaxOutputTokens: &call.MaxOutputTokens,
193 TopP: call.TopP,
194 Temperature: call.Temperature,
195 PresencePenalty: call.PresencePenalty,
196 TopK: call.TopK,
197 FrequencyPenalty: call.FrequencyPenalty,
198 // Before each step create a new assistant message.
199 PrepareStep: func(callContext context.Context, options fantasy.PrepareStepFunctionOptions) (_ context.Context, prepared fantasy.PrepareStepResult, err error) {
200 prepared.Messages = options.Messages
201 // Reset all cached items.
202 for i := range prepared.Messages {
203 prepared.Messages[i].ProviderOptions = nil
204 }
205
206 queuedCalls, _ := a.messageQueue.Get(call.SessionID)
207 a.messageQueue.Del(call.SessionID)
208 for _, queued := range queuedCalls {
209 userMessage, createErr := a.createUserMessage(callContext, queued)
210 if createErr != nil {
211 return callContext, prepared, createErr
212 }
213 prepared.Messages = append(prepared.Messages, userMessage.ToAIMessage()...)
214 }
215
216 lastSystemRoleInx := 0
217 systemMessageUpdated := false
218 for i, msg := range prepared.Messages {
219 // Only add cache control to the last message.
220 if msg.Role == fantasy.MessageRoleSystem {
221 lastSystemRoleInx = i
222 } else if !systemMessageUpdated {
223 prepared.Messages[lastSystemRoleInx].ProviderOptions = a.getCacheControlOptions()
224 systemMessageUpdated = true
225 }
226 // Than add cache control to the last 2 messages.
227 if i > len(prepared.Messages)-3 {
228 prepared.Messages[i].ProviderOptions = a.getCacheControlOptions()
229 }
230 }
231
232 if a.systemPromptPrefix != "" {
233 prepared.Messages = append([]fantasy.Message{fantasy.NewSystemMessage(a.systemPromptPrefix)}, prepared.Messages...)
234 }
235
236 var assistantMsg message.Message
237 assistantMsg, err = a.messages.Create(callContext, call.SessionID, message.CreateMessageParams{
238 Role: message.Assistant,
239 Parts: []message.ContentPart{},
240 Model: a.largeModel.ModelCfg.Model,
241 Provider: a.largeModel.ModelCfg.Provider,
242 })
243 if err != nil {
244 return callContext, prepared, err
245 }
246 callContext = context.WithValue(callContext, tools.MessageIDContextKey, assistantMsg.ID)
247 currentAssistant = &assistantMsg
248 return callContext, prepared, err
249 },
250 OnReasoningStart: func(id string, reasoning fantasy.ReasoningContent) error {
251 currentAssistant.AppendReasoningContent(reasoning.Text)
252 return a.messages.Update(genCtx, *currentAssistant)
253 },
254 OnReasoningDelta: func(id string, text string) error {
255 currentAssistant.AppendReasoningContent(text)
256 return a.messages.Update(genCtx, *currentAssistant)
257 },
258 OnReasoningEnd: func(id string, reasoning fantasy.ReasoningContent) error {
259 // handle anthropic signature
260 if anthropicData, ok := reasoning.ProviderMetadata[anthropic.Name]; ok {
261 if reasoning, ok := anthropicData.(*anthropic.ReasoningOptionMetadata); ok {
262 currentAssistant.AppendReasoningSignature(reasoning.Signature)
263 }
264 }
265 if googleData, ok := reasoning.ProviderMetadata[google.Name]; ok {
266 if reasoning, ok := googleData.(*google.ReasoningMetadata); ok {
267 currentAssistant.AppendReasoningSignature(reasoning.Signature)
268 }
269 }
270 if openaiData, ok := reasoning.ProviderMetadata[openai.Name]; ok {
271 if reasoning, ok := openaiData.(*openai.ResponsesReasoningMetadata); ok {
272 currentAssistant.SetReasoningResponsesData(reasoning)
273 }
274 }
275 currentAssistant.FinishThinking()
276 return a.messages.Update(genCtx, *currentAssistant)
277 },
278 OnTextDelta: func(id string, text string) error {
279 // Strip leading newline from initial text content. This is is
280 // particularly important in non-interactive mode where leading
281 // newlines are very visible.
282 if len(currentAssistant.Parts) == 0 && strings.HasPrefix(text, "\n") {
283 text = strings.TrimPrefix(text, "\n")
284 }
285
286 currentAssistant.AppendContent(text)
287 return a.messages.Update(genCtx, *currentAssistant)
288 },
289 OnToolInputStart: func(id string, toolName string) error {
290 toolCall := message.ToolCall{
291 ID: id,
292 Name: toolName,
293 ProviderExecuted: false,
294 Finished: false,
295 }
296 currentAssistant.AddToolCall(toolCall)
297 return a.messages.Update(genCtx, *currentAssistant)
298 },
299 OnRetry: func(err *fantasy.APICallError, delay time.Duration) {
300 // TODO: implement
301 },
302 OnToolCall: func(tc fantasy.ToolCallContent) error {
303 toolCall := message.ToolCall{
304 ID: tc.ToolCallID,
305 Name: tc.ToolName,
306 Input: tc.Input,
307 ProviderExecuted: false,
308 Finished: true,
309 }
310 currentAssistant.AddToolCall(toolCall)
311 return a.messages.Update(genCtx, *currentAssistant)
312 },
313 OnToolResult: func(result fantasy.ToolResultContent) error {
314 var resultContent string
315 isError := false
316 switch result.Result.GetType() {
317 case fantasy.ToolResultContentTypeText:
318 r, ok := fantasy.AsToolResultOutputType[fantasy.ToolResultOutputContentText](result.Result)
319 if ok {
320 resultContent = r.Text
321 }
322 case fantasy.ToolResultContentTypeError:
323 r, ok := fantasy.AsToolResultOutputType[fantasy.ToolResultOutputContentError](result.Result)
324 if ok {
325 isError = true
326 resultContent = r.Error.Error()
327 }
328 case fantasy.ToolResultContentTypeMedia:
329 // TODO: handle this message type
330 }
331 toolResult := message.ToolResult{
332 ToolCallID: result.ToolCallID,
333 Name: result.ToolName,
334 Content: resultContent,
335 IsError: isError,
336 Metadata: result.ClientMetadata,
337 }
338 _, createMsgErr := a.messages.Create(genCtx, currentAssistant.SessionID, message.CreateMessageParams{
339 Role: message.Tool,
340 Parts: []message.ContentPart{
341 toolResult,
342 },
343 })
344 if createMsgErr != nil {
345 return createMsgErr
346 }
347 return nil
348 },
349 OnStepFinish: func(stepResult fantasy.StepResult) error {
350 finishReason := message.FinishReasonUnknown
351 switch stepResult.FinishReason {
352 case fantasy.FinishReasonLength:
353 finishReason = message.FinishReasonMaxTokens
354 case fantasy.FinishReasonStop:
355 finishReason = message.FinishReasonEndTurn
356 case fantasy.FinishReasonToolCalls:
357 finishReason = message.FinishReasonToolUse
358 }
359 currentAssistant.AddFinish(finishReason, "", "")
360 a.updateSessionUsage(a.largeModel, ¤tSession, stepResult.Usage, a.openrouterCost(stepResult.ProviderMetadata))
361 sessionLock.Lock()
362 _, sessionErr := a.sessions.Save(genCtx, currentSession)
363 sessionLock.Unlock()
364 if sessionErr != nil {
365 return sessionErr
366 }
367 return a.messages.Update(genCtx, *currentAssistant)
368 },
369 StopWhen: []fantasy.StopCondition{
370 func(_ []fantasy.StepResult) bool {
371 cw := int64(a.largeModel.CatwalkCfg.ContextWindow)
372 tokens := currentSession.CompletionTokens + currentSession.PromptTokens
373 remaining := cw - tokens
374 var threshold int64
375 if cw > 200_000 {
376 threshold = 20_000
377 } else {
378 threshold = int64(float64(cw) * 0.2)
379 }
380 if (remaining <= threshold) && !a.disableAutoSummarize {
381 shouldSummarize = true
382 return true
383 }
384 return false
385 },
386 },
387 })
388
389 a.eventPromptResponded(call.SessionID, time.Since(startTime).Truncate(time.Second))
390
391 if err != nil {
392 isCancelErr := errors.Is(err, context.Canceled)
393 isPermissionErr := errors.Is(err, permission.ErrorPermissionDenied)
394 if currentAssistant == nil {
395 return result, err
396 }
397 // Ensure we finish thinking on error to close the reasoning state.
398 currentAssistant.FinishThinking()
399 toolCalls := currentAssistant.ToolCalls()
400 // INFO: we use the parent context here because the genCtx has been cancelled.
401 msgs, createErr := a.messages.List(ctx, currentAssistant.SessionID)
402 if createErr != nil {
403 return nil, createErr
404 }
405 for _, tc := range toolCalls {
406 if !tc.Finished {
407 tc.Finished = true
408 tc.Input = "{}"
409 currentAssistant.AddToolCall(tc)
410 updateErr := a.messages.Update(ctx, *currentAssistant)
411 if updateErr != nil {
412 return nil, updateErr
413 }
414 }
415
416 found := false
417 for _, msg := range msgs {
418 if msg.Role == message.Tool {
419 for _, tr := range msg.ToolResults() {
420 if tr.ToolCallID == tc.ID {
421 found = true
422 break
423 }
424 }
425 }
426 if found {
427 break
428 }
429 }
430 if found {
431 continue
432 }
433 content := "There was an error while executing the tool"
434 if isCancelErr {
435 content = "Tool execution canceled by user"
436 } else if isPermissionErr {
437 content = "User denied permission"
438 }
439 toolResult := message.ToolResult{
440 ToolCallID: tc.ID,
441 Name: tc.Name,
442 Content: content,
443 IsError: true,
444 }
445 _, createErr = a.messages.Create(context.Background(), currentAssistant.SessionID, message.CreateMessageParams{
446 Role: message.Tool,
447 Parts: []message.ContentPart{
448 toolResult,
449 },
450 })
451 if createErr != nil {
452 return nil, createErr
453 }
454 }
455 if isCancelErr {
456 currentAssistant.AddFinish(message.FinishReasonCanceled, "User canceled request", "")
457 } else if isPermissionErr {
458 currentAssistant.AddFinish(message.FinishReasonPermissionDenied, "User denied permission", "")
459 } else {
460 currentAssistant.AddFinish(message.FinishReasonError, "API Error", err.Error())
461 }
462 // Note: we use the parent context here because the genCtx has been
463 // cancelled.
464 updateErr := a.messages.Update(ctx, *currentAssistant)
465 if updateErr != nil {
466 return nil, updateErr
467 }
468 return nil, err
469 }
470 wg.Wait()
471
472 if shouldSummarize {
473 a.activeRequests.Del(call.SessionID)
474 if summarizeErr := a.Summarize(genCtx, call.SessionID, call.ProviderOptions); summarizeErr != nil {
475 return nil, summarizeErr
476 }
477 // If the agent wasn't done...
478 if len(currentAssistant.ToolCalls()) > 0 {
479 existing, ok := a.messageQueue.Get(call.SessionID)
480 if !ok {
481 existing = []SessionAgentCall{}
482 }
483 call.Prompt = fmt.Sprintf("The previous session was interrupted because it got too long, the initial user request was: `%s`", call.Prompt)
484 existing = append(existing, call)
485 a.messageQueue.Set(call.SessionID, existing)
486 }
487 }
488
489 // Release active request before processing queued messages.
490 a.activeRequests.Del(call.SessionID)
491 cancel()
492
493 queuedMessages, ok := a.messageQueue.Get(call.SessionID)
494 if !ok || len(queuedMessages) == 0 {
495 return result, err
496 }
497 // There are queued messages restart the loop.
498 firstQueuedMessage := queuedMessages[0]
499 a.messageQueue.Set(call.SessionID, queuedMessages[1:])
500 return a.Run(ctx, firstQueuedMessage)
501}
502
503func (a *sessionAgent) Summarize(ctx context.Context, sessionID string, opts fantasy.ProviderOptions) error {
504 if a.IsSessionBusy(sessionID) {
505 return ErrSessionBusy
506 }
507
508 currentSession, err := a.sessions.Get(ctx, sessionID)
509 if err != nil {
510 return fmt.Errorf("failed to get session: %w", err)
511 }
512 msgs, err := a.getSessionMessages(ctx, currentSession)
513 if err != nil {
514 return err
515 }
516 if len(msgs) == 0 {
517 // Nothing to summarize.
518 return nil
519 }
520
521 aiMsgs, _ := a.preparePrompt(msgs)
522
523 genCtx, cancel := context.WithCancel(ctx)
524 a.activeRequests.Set(sessionID, cancel)
525 defer a.activeRequests.Del(sessionID)
526 defer cancel()
527
528 agent := fantasy.NewAgent(a.largeModel.Model,
529 fantasy.WithSystemPrompt(string(summaryPrompt)),
530 )
531 summaryMessage, err := a.messages.Create(ctx, sessionID, message.CreateMessageParams{
532 Role: message.Assistant,
533 Model: a.largeModel.Model.Model(),
534 Provider: a.largeModel.Model.Provider(),
535 IsSummaryMessage: true,
536 })
537 if err != nil {
538 return err
539 }
540
541 resp, err := agent.Stream(genCtx, fantasy.AgentStreamCall{
542 Prompt: "Provide a detailed summary of our conversation above.",
543 Messages: aiMsgs,
544 ProviderOptions: opts,
545 PrepareStep: func(callContext context.Context, options fantasy.PrepareStepFunctionOptions) (_ context.Context, prepared fantasy.PrepareStepResult, err error) {
546 prepared.Messages = options.Messages
547 if a.systemPromptPrefix != "" {
548 prepared.Messages = append([]fantasy.Message{fantasy.NewSystemMessage(a.systemPromptPrefix)}, prepared.Messages...)
549 }
550 return callContext, prepared, nil
551 },
552 OnReasoningDelta: func(id string, text string) error {
553 summaryMessage.AppendReasoningContent(text)
554 return a.messages.Update(genCtx, summaryMessage)
555 },
556 OnReasoningEnd: func(id string, reasoning fantasy.ReasoningContent) error {
557 // Handle anthropic signature.
558 if anthropicData, ok := reasoning.ProviderMetadata["anthropic"]; ok {
559 if signature, ok := anthropicData.(*anthropic.ReasoningOptionMetadata); ok && signature.Signature != "" {
560 summaryMessage.AppendReasoningSignature(signature.Signature)
561 }
562 }
563 summaryMessage.FinishThinking()
564 return a.messages.Update(genCtx, summaryMessage)
565 },
566 OnTextDelta: func(id, text string) error {
567 summaryMessage.AppendContent(text)
568 return a.messages.Update(genCtx, summaryMessage)
569 },
570 })
571 if err != nil {
572 isCancelErr := errors.Is(err, context.Canceled)
573 if isCancelErr {
574 // User cancelled summarize we need to remove the summary message.
575 deleteErr := a.messages.Delete(ctx, summaryMessage.ID)
576 return deleteErr
577 }
578 return err
579 }
580
581 summaryMessage.AddFinish(message.FinishReasonEndTurn, "", "")
582 err = a.messages.Update(genCtx, summaryMessage)
583 if err != nil {
584 return err
585 }
586
587 var openrouterCost *float64
588 for _, step := range resp.Steps {
589 stepCost := a.openrouterCost(step.ProviderMetadata)
590 if stepCost != nil {
591 newCost := *stepCost
592 if openrouterCost != nil {
593 newCost += *openrouterCost
594 }
595 openrouterCost = &newCost
596 }
597 }
598
599 a.updateSessionUsage(a.largeModel, ¤tSession, resp.TotalUsage, openrouterCost)
600
601 // Just in case, get just the last usage info.
602 usage := resp.Response.Usage
603 currentSession.SummaryMessageID = summaryMessage.ID
604 currentSession.CompletionTokens = usage.OutputTokens
605 currentSession.PromptTokens = 0
606 _, err = a.sessions.Save(genCtx, currentSession)
607 return err
608}
609
610func (a *sessionAgent) getCacheControlOptions() fantasy.ProviderOptions {
611 if t, _ := strconv.ParseBool(os.Getenv("CRUSH_DISABLE_ANTHROPIC_CACHE")); t {
612 return fantasy.ProviderOptions{}
613 }
614 return fantasy.ProviderOptions{
615 anthropic.Name: &anthropic.ProviderCacheControlOptions{
616 CacheControl: anthropic.CacheControl{Type: "ephemeral"},
617 },
618 bedrock.Name: &anthropic.ProviderCacheControlOptions{
619 CacheControl: anthropic.CacheControl{Type: "ephemeral"},
620 },
621 }
622}
623
624func (a *sessionAgent) createUserMessage(ctx context.Context, call SessionAgentCall) (message.Message, error) {
625 var attachmentParts []message.ContentPart
626 for _, attachment := range call.Attachments {
627 attachmentParts = append(attachmentParts, message.BinaryContent{Path: attachment.FilePath, MIMEType: attachment.MimeType, Data: attachment.Content})
628 }
629 parts := []message.ContentPart{message.TextContent{Text: call.Prompt}}
630 parts = append(parts, attachmentParts...)
631 msg, err := a.messages.Create(ctx, call.SessionID, message.CreateMessageParams{
632 Role: message.User,
633 Parts: parts,
634 })
635 if err != nil {
636 return message.Message{}, fmt.Errorf("failed to create user message: %w", err)
637 }
638 return msg, nil
639}
640
641func (a *sessionAgent) preparePrompt(msgs []message.Message, attachments ...message.Attachment) ([]fantasy.Message, []fantasy.FilePart) {
642 var history []fantasy.Message
643 for _, m := range msgs {
644 if len(m.Parts) == 0 {
645 continue
646 }
647 // Assistant message without content or tool calls (cancelled before it
648 // returned anything).
649 if m.Role == message.Assistant && len(m.ToolCalls()) == 0 && m.Content().Text == "" && m.ReasoningContent().String() == "" {
650 continue
651 }
652 history = append(history, m.ToAIMessage()...)
653 }
654
655 var files []fantasy.FilePart
656 for _, attachment := range attachments {
657 files = append(files, fantasy.FilePart{
658 Filename: attachment.FileName,
659 Data: attachment.Content,
660 MediaType: attachment.MimeType,
661 })
662 }
663
664 return history, files
665}
666
667func (a *sessionAgent) getSessionMessages(ctx context.Context, session session.Session) ([]message.Message, error) {
668 msgs, err := a.messages.List(ctx, session.ID)
669 if err != nil {
670 return nil, fmt.Errorf("failed to list messages: %w", err)
671 }
672
673 if session.SummaryMessageID != "" {
674 summaryMsgInex := -1
675 for i, msg := range msgs {
676 if msg.ID == session.SummaryMessageID {
677 summaryMsgInex = i
678 break
679 }
680 }
681 if summaryMsgInex != -1 {
682 msgs = msgs[summaryMsgInex:]
683 msgs[0].Role = message.User
684 }
685 }
686 return msgs, nil
687}
688
689func (a *sessionAgent) generateTitle(ctx context.Context, session *session.Session, prompt string) {
690 if prompt == "" {
691 return
692 }
693
694 var maxOutput int64 = 40
695 if a.smallModel.CatwalkCfg.CanReason {
696 maxOutput = a.smallModel.CatwalkCfg.DefaultMaxTokens
697 }
698
699 agent := fantasy.NewAgent(a.smallModel.Model,
700 fantasy.WithSystemPrompt(string(titlePrompt)+"\n /no_think"),
701 fantasy.WithMaxOutputTokens(maxOutput),
702 )
703
704 resp, err := agent.Stream(ctx, fantasy.AgentStreamCall{
705 Prompt: fmt.Sprintf("Generate a concise title for the following content:\n\n%s\n <think>\n\n</think>", prompt),
706 PrepareStep: func(callContext context.Context, options fantasy.PrepareStepFunctionOptions) (_ context.Context, prepared fantasy.PrepareStepResult, err error) {
707 prepared.Messages = options.Messages
708 if a.systemPromptPrefix != "" {
709 prepared.Messages = append([]fantasy.Message{fantasy.NewSystemMessage(a.systemPromptPrefix)}, prepared.Messages...)
710 }
711 return callContext, prepared, nil
712 },
713 })
714 if err != nil {
715 slog.Error("error generating title", "err", err)
716 return
717 }
718
719 title := resp.Response.Content.Text()
720
721 title = strings.ReplaceAll(title, "\n", " ")
722
723 // Remove thinking tags if present.
724 if idx := strings.Index(title, "</think>"); idx > 0 {
725 title = title[idx+len("</think>"):]
726 }
727
728 title = strings.TrimSpace(title)
729 if title == "" {
730 slog.Warn("failed to generate title", "warn", "empty title")
731 return
732 }
733
734 session.Title = title
735
736 var openrouterCost *float64
737 for _, step := range resp.Steps {
738 stepCost := a.openrouterCost(step.ProviderMetadata)
739 if stepCost != nil {
740 newCost := *stepCost
741 if openrouterCost != nil {
742 newCost += *openrouterCost
743 }
744 openrouterCost = &newCost
745 }
746 }
747
748 a.updateSessionUsage(a.smallModel, session, resp.TotalUsage, openrouterCost)
749 _, saveErr := a.sessions.Save(ctx, *session)
750 if saveErr != nil {
751 slog.Error("failed to save session title & usage", "error", saveErr)
752 return
753 }
754}
755
756func (a *sessionAgent) openrouterCost(metadata fantasy.ProviderMetadata) *float64 {
757 openrouterMetadata, ok := metadata[openrouter.Name]
758 if !ok {
759 return nil
760 }
761
762 opts, ok := openrouterMetadata.(*openrouter.ProviderMetadata)
763 if !ok {
764 return nil
765 }
766 return &opts.Usage.Cost
767}
768
769func (a *sessionAgent) updateSessionUsage(model Model, session *session.Session, usage fantasy.Usage, overrideCost *float64) {
770 modelConfig := model.CatwalkCfg
771 cost := modelConfig.CostPer1MInCached/1e6*float64(usage.CacheCreationTokens) +
772 modelConfig.CostPer1MOutCached/1e6*float64(usage.CacheReadTokens) +
773 modelConfig.CostPer1MIn/1e6*float64(usage.InputTokens) +
774 modelConfig.CostPer1MOut/1e6*float64(usage.OutputTokens)
775
776 a.eventTokensUsed(session.ID, model, usage, cost)
777
778 if overrideCost != nil {
779 session.Cost += *overrideCost
780 } else {
781 session.Cost += cost
782 }
783
784 session.CompletionTokens = usage.OutputTokens + usage.CacheReadTokens
785 session.PromptTokens = usage.InputTokens + usage.CacheCreationTokens
786}
787
788func (a *sessionAgent) Cancel(sessionID string) {
789 // Cancel regular requests.
790 if cancel, ok := a.activeRequests.Take(sessionID); ok && cancel != nil {
791 slog.Info("Request cancellation initiated", "session_id", sessionID)
792 cancel()
793 }
794
795 // Also check for summarize requests.
796 if cancel, ok := a.activeRequests.Take(sessionID + "-summarize"); ok && cancel != nil {
797 slog.Info("Summarize cancellation initiated", "session_id", sessionID)
798 cancel()
799 }
800
801 if a.QueuedPrompts(sessionID) > 0 {
802 slog.Info("Clearing queued prompts", "session_id", sessionID)
803 a.messageQueue.Del(sessionID)
804 }
805}
806
807func (a *sessionAgent) ClearQueue(sessionID string) {
808 if a.QueuedPrompts(sessionID) > 0 {
809 slog.Info("Clearing queued prompts", "session_id", sessionID)
810 a.messageQueue.Del(sessionID)
811 }
812}
813
814func (a *sessionAgent) CancelAll() {
815 if !a.IsBusy() {
816 return
817 }
818 for key := range a.activeRequests.Seq2() {
819 a.Cancel(key) // key is sessionID
820 }
821
822 timeout := time.After(5 * time.Second)
823 for a.IsBusy() {
824 select {
825 case <-timeout:
826 return
827 default:
828 time.Sleep(200 * time.Millisecond)
829 }
830 }
831}
832
833func (a *sessionAgent) IsBusy() bool {
834 var busy bool
835 for cancelFunc := range a.activeRequests.Seq() {
836 if cancelFunc != nil {
837 busy = true
838 break
839 }
840 }
841 return busy
842}
843
844func (a *sessionAgent) IsSessionBusy(sessionID string) bool {
845 _, busy := a.activeRequests.Get(sessionID)
846 return busy
847}
848
849func (a *sessionAgent) QueuedPrompts(sessionID string) int {
850 l, ok := a.messageQueue.Get(sessionID)
851 if !ok {
852 return 0
853 }
854 return len(l)
855}
856
857func (a *sessionAgent) SetModels(large Model, small Model) {
858 a.largeModel = large
859 a.smallModel = small
860}
861
// SetTools replaces the agent's tool set. The slice is stored as given, not
// copied.
func (a *sessionAgent) SetTools(tools []fantasy.AgentTool) {
	a.tools = tools
}
865
// Model returns the agent's large model.
func (a *sessionAgent) Model() Model {
	return a.largeModel
}