1package agent
2
3import (
4 "context"
5 _ "embed"
6 "errors"
7 "fmt"
8 "log/slog"
9 "os"
10 "strconv"
11 "strings"
12 "sync"
13 "time"
14
15 "charm.land/fantasy"
16 "charm.land/fantasy/providers/anthropic"
17 "charm.land/fantasy/providers/bedrock"
18 "charm.land/fantasy/providers/google"
19 "charm.land/fantasy/providers/openai"
20 "github.com/charmbracelet/catwalk/pkg/catwalk"
21 "github.com/charmbracelet/crush/internal/agent/tools"
22 "github.com/charmbracelet/crush/internal/config"
23 "github.com/charmbracelet/crush/internal/csync"
24 "github.com/charmbracelet/crush/internal/message"
25 "github.com/charmbracelet/crush/internal/permission"
26 "github.com/charmbracelet/crush/internal/session"
27)
28
// titlePrompt holds the embedded contents of templates/title.md. It is the
// base of the system prompt used by generateTitle.
//
//go:embed templates/title.md
var titlePrompt []byte

// summaryPrompt holds the embedded contents of templates/summary.md. It is
// the system prompt used by Summarize.
//
//go:embed templates/summary.md
var summaryPrompt []byte
34
// SessionAgentCall describes a single prompt submitted to a SessionAgent.
type SessionAgentCall struct {
	// SessionID identifies the session the prompt belongs to. Run rejects
	// an empty value with ErrSessionMissing.
	SessionID string
	// Prompt is the user text. Run rejects an empty value with
	// ErrEmptyPrompt.
	Prompt string
	// ProviderOptions is forwarded unchanged to the model stream call.
	ProviderOptions fantasy.ProviderOptions
	// Attachments are stored with the user message and sent to the model
	// as file parts.
	Attachments []message.Attachment
	// MaxOutputTokens is forwarded to the model stream call.
	MaxOutputTokens int64
	// The sampling parameters below are optional and are forwarded to the
	// model stream call as-is (nil pointers are passed through as nil).
	Temperature      *float64
	TopP             *float64
	TopK             *int64
	FrequencyPenalty *float64
	PresencePenalty  *float64
}
47
// SessionAgent runs prompts against a language model on a per-session basis
// and tracks which sessions currently have an active request.
type SessionAgent interface {
	// Run executes the call, or queues it when the session is busy.
	Run(context.Context, SessionAgentCall) (*fantasy.AgentResult, error)
	// SetModels replaces the large and small models.
	SetModels(large Model, small Model)
	// SetTools replaces the tool set offered to the model.
	SetTools(tools []fantasy.AgentTool)
	// Cancel cancels the active request of a session and drops its queue.
	Cancel(sessionID string)
	// CancelAll cancels the active requests of every session.
	CancelAll()
	// IsSessionBusy reports whether the session has an active request.
	IsSessionBusy(sessionID string) bool
	// IsBusy reports whether any session has an active request.
	IsBusy() bool
	// QueuedPrompts returns the number of prompts queued for a session.
	QueuedPrompts(sessionID string) int
	// ClearQueue drops the prompts queued for a session.
	ClearQueue(sessionID string)
	// Summarize replaces the session history with a model-written summary.
	Summarize(context.Context, string, fantasy.ProviderOptions) error
	// Model returns the large model.
	Model() Model
}
61
// Model bundles a language model with the metadata the agent needs about it.
type Model struct {
	// Model is the language model handed to fantasy.NewAgent.
	Model fantasy.LanguageModel
	// CatwalkCfg provides the context window, pricing and reasoning
	// capability data read by the agent.
	CatwalkCfg catwalk.Model
	// ModelCfg provides the model and provider identifiers recorded on
	// assistant messages created by Run.
	ModelCfg config.SelectedModel
}
67
// sessionAgent is the SessionAgent implementation returned by
// NewSessionAgent.
type sessionAgent struct {
	// largeModel serves prompts and summaries; smallModel generates titles.
	largeModel Model
	smallModel Model
	// systemPrompt is the system prompt used for regular prompts.
	systemPrompt string
	// tools is the tool set offered to the model in Run.
	tools []fantasy.AgentTool
	// sessions and messages are the persistence services for sessions and
	// their messages.
	sessions session.Service
	messages message.Service
	// disableAutoSummarize turns off the automatic summarization that Run
	// triggers when the context window is nearly full.
	disableAutoSummarize bool
	// isYolo is copied from the options; it is not read in this part of
	// the file.
	isYolo bool

	// messageQueue holds, per session ID, the prompts that arrived while
	// the session was busy.
	messageQueue *csync.Map[string, []SessionAgentCall]
	// activeRequests holds, per session ID, the cancel function of the
	// request currently in flight. Presence of a key marks a session busy.
	activeRequests *csync.Map[string, context.CancelFunc]
}
81
// SessionAgentOptions carries the dependencies and settings that
// NewSessionAgent copies into a new agent.
type SessionAgentOptions struct {
	// LargeModel serves prompts and summaries; SmallModel generates titles.
	LargeModel Model
	SmallModel Model
	// SystemPrompt is the system prompt used for regular prompts.
	SystemPrompt string
	// DisableAutoSummarize turns off automatic summarization in Run.
	DisableAutoSummarize bool
	// IsYolo is stored on the agent as-is.
	IsYolo bool
	// Sessions and Messages are the persistence services the agent uses.
	Sessions session.Service
	Messages message.Service
	// Tools is the initial tool set offered to the model.
	Tools []fantasy.AgentTool
}
92
93func NewSessionAgent(
94 opts SessionAgentOptions,
95) SessionAgent {
96 return &sessionAgent{
97 largeModel: opts.LargeModel,
98 smallModel: opts.SmallModel,
99 systemPrompt: opts.SystemPrompt,
100 sessions: opts.Sessions,
101 messages: opts.Messages,
102 disableAutoSummarize: opts.DisableAutoSummarize,
103 tools: opts.Tools,
104 isYolo: opts.IsYolo,
105 messageQueue: csync.NewMap[string, []SessionAgentCall](),
106 activeRequests: csync.NewMap[string, context.CancelFunc](),
107 }
108}
109
110func (a *sessionAgent) Run(ctx context.Context, call SessionAgentCall) (*fantasy.AgentResult, error) {
111 if call.Prompt == "" {
112 return nil, ErrEmptyPrompt
113 }
114 if call.SessionID == "" {
115 return nil, ErrSessionMissing
116 }
117
118 // Queue the message if busy
119 if a.IsSessionBusy(call.SessionID) {
120 existing, ok := a.messageQueue.Get(call.SessionID)
121 if !ok {
122 existing = []SessionAgentCall{}
123 }
124 existing = append(existing, call)
125 a.messageQueue.Set(call.SessionID, existing)
126 return nil, nil
127 }
128
129 if len(a.tools) > 0 {
130 // add anthropic caching to the last tool
131 a.tools[len(a.tools)-1].SetProviderOptions(a.getCacheControlOptions())
132 }
133
134 agent := fantasy.NewAgent(
135 a.largeModel.Model,
136 fantasy.WithSystemPrompt(a.systemPrompt),
137 fantasy.WithTools(a.tools...),
138 )
139
140 sessionLock := sync.Mutex{}
141 currentSession, err := a.sessions.Get(ctx, call.SessionID)
142 if err != nil {
143 return nil, fmt.Errorf("failed to get session: %w", err)
144 }
145
146 msgs, err := a.getSessionMessages(ctx, currentSession)
147 if err != nil {
148 return nil, fmt.Errorf("failed to get session messages: %w", err)
149 }
150
151 var wg sync.WaitGroup
152 // Generate title if first message
153 if len(msgs) == 0 {
154 wg.Go(func() {
155 sessionLock.Lock()
156 a.generateTitle(ctx, ¤tSession, call.Prompt)
157 sessionLock.Unlock()
158 })
159 }
160
161 // Add the user message to the session
162 _, err = a.createUserMessage(ctx, call)
163 if err != nil {
164 return nil, err
165 }
166
167 // add the session to the context
168 ctx = context.WithValue(ctx, tools.SessionIDContextKey, call.SessionID)
169
170 genCtx, cancel := context.WithCancel(ctx)
171 a.activeRequests.Set(call.SessionID, cancel)
172
173 defer cancel()
174 defer a.activeRequests.Del(call.SessionID)
175
176 history, files := a.preparePrompt(msgs, call.Attachments...)
177
178 startTime := time.Now()
179 a.eventPromptSent(call.SessionID)
180
181 var currentAssistant *message.Message
182 var shouldSummarize bool
183 result, err := agent.Stream(genCtx, fantasy.AgentStreamCall{
184 Prompt: call.Prompt,
185 Files: files,
186 Messages: history,
187 ProviderOptions: call.ProviderOptions,
188 MaxOutputTokens: &call.MaxOutputTokens,
189 TopP: call.TopP,
190 Temperature: call.Temperature,
191 PresencePenalty: call.PresencePenalty,
192 TopK: call.TopK,
193 FrequencyPenalty: call.FrequencyPenalty,
194 // Before each step create the new assistant message
195 PrepareStep: func(callContext context.Context, options fantasy.PrepareStepFunctionOptions) (_ context.Context, prepared fantasy.PrepareStepResult, err error) {
196 prepared.Messages = options.Messages
197 // reset all cached items
198 for i := range prepared.Messages {
199 prepared.Messages[i].ProviderOptions = nil
200 }
201
202 queuedCalls, _ := a.messageQueue.Get(call.SessionID)
203 a.messageQueue.Del(call.SessionID)
204 for _, queued := range queuedCalls {
205 userMessage, createErr := a.createUserMessage(callContext, queued)
206 if createErr != nil {
207 return callContext, prepared, createErr
208 }
209 prepared.Messages = append(prepared.Messages, userMessage.ToAIMessage()...)
210 }
211
212 lastSystemRoleInx := 0
213 systemMessageUpdated := false
214 for i, msg := range prepared.Messages {
215 // only add cache control to the last message
216 if msg.Role == fantasy.MessageRoleSystem {
217 lastSystemRoleInx = i
218 } else if !systemMessageUpdated {
219 prepared.Messages[lastSystemRoleInx].ProviderOptions = a.getCacheControlOptions()
220 systemMessageUpdated = true
221 }
222 // than add cache control to the last 2 messages
223 if i > len(prepared.Messages)-3 {
224 prepared.Messages[i].ProviderOptions = a.getCacheControlOptions()
225 }
226 }
227
228 var assistantMsg message.Message
229 assistantMsg, err = a.messages.Create(callContext, call.SessionID, message.CreateMessageParams{
230 Role: message.Assistant,
231 Parts: []message.ContentPart{},
232 Model: a.largeModel.ModelCfg.Model,
233 Provider: a.largeModel.ModelCfg.Provider,
234 })
235 if err != nil {
236 return callContext, prepared, err
237 }
238 callContext = context.WithValue(callContext, tools.MessageIDContextKey, assistantMsg.ID)
239 currentAssistant = &assistantMsg
240 return callContext, prepared, err
241 },
242 OnReasoningStart: func(id string, reasoning fantasy.ReasoningContent) error {
243 currentAssistant.AppendReasoningContent(reasoning.Text)
244 return a.messages.Update(genCtx, *currentAssistant)
245 },
246 OnReasoningDelta: func(id string, text string) error {
247 currentAssistant.AppendReasoningContent(text)
248 return a.messages.Update(genCtx, *currentAssistant)
249 },
250 OnReasoningEnd: func(id string, reasoning fantasy.ReasoningContent) error {
251 // handle anthropic signature
252 if anthropicData, ok := reasoning.ProviderMetadata[anthropic.Name]; ok {
253 if reasoning, ok := anthropicData.(*anthropic.ReasoningOptionMetadata); ok {
254 currentAssistant.AppendReasoningSignature(reasoning.Signature)
255 }
256 }
257 if googleData, ok := reasoning.ProviderMetadata[google.Name]; ok {
258 if reasoning, ok := googleData.(*google.ReasoningMetadata); ok {
259 currentAssistant.AppendReasoningSignature(reasoning.Signature)
260 }
261 }
262 if openaiData, ok := reasoning.ProviderMetadata[openai.Name]; ok {
263 if reasoning, ok := openaiData.(*openai.ResponsesReasoningMetadata); ok {
264 currentAssistant.SetReasoningResponsesData(reasoning)
265 }
266 }
267 currentAssistant.FinishThinking()
268 return a.messages.Update(genCtx, *currentAssistant)
269 },
270 OnTextDelta: func(id string, text string) error {
271 currentAssistant.AppendContent(text)
272 return a.messages.Update(genCtx, *currentAssistant)
273 },
274 OnToolInputStart: func(id string, toolName string) error {
275 toolCall := message.ToolCall{
276 ID: id,
277 Name: toolName,
278 ProviderExecuted: false,
279 Finished: false,
280 }
281 currentAssistant.AddToolCall(toolCall)
282 return a.messages.Update(genCtx, *currentAssistant)
283 },
284 OnRetry: func(err *fantasy.APICallError, delay time.Duration) {
285 // TODO: implement
286 },
287 OnToolCall: func(tc fantasy.ToolCallContent) error {
288 toolCall := message.ToolCall{
289 ID: tc.ToolCallID,
290 Name: tc.ToolName,
291 Input: tc.Input,
292 ProviderExecuted: false,
293 Finished: true,
294 }
295 currentAssistant.AddToolCall(toolCall)
296 return a.messages.Update(genCtx, *currentAssistant)
297 },
298 OnToolResult: func(result fantasy.ToolResultContent) error {
299 var resultContent string
300 isError := false
301 switch result.Result.GetType() {
302 case fantasy.ToolResultContentTypeText:
303 r, ok := fantasy.AsToolResultOutputType[fantasy.ToolResultOutputContentText](result.Result)
304 if ok {
305 resultContent = r.Text
306 }
307 case fantasy.ToolResultContentTypeError:
308 r, ok := fantasy.AsToolResultOutputType[fantasy.ToolResultOutputContentError](result.Result)
309 if ok {
310 isError = true
311 resultContent = r.Error.Error()
312 }
313 case fantasy.ToolResultContentTypeMedia:
314 // TODO: handle this message type
315 }
316 toolResult := message.ToolResult{
317 ToolCallID: result.ToolCallID,
318 Name: result.ToolName,
319 Content: resultContent,
320 IsError: isError,
321 Metadata: result.ClientMetadata,
322 }
323 _, createMsgErr := a.messages.Create(genCtx, currentAssistant.SessionID, message.CreateMessageParams{
324 Role: message.Tool,
325 Parts: []message.ContentPart{
326 toolResult,
327 },
328 })
329 if createMsgErr != nil {
330 return createMsgErr
331 }
332 return nil
333 },
334 OnStepFinish: func(stepResult fantasy.StepResult) error {
335 finishReason := message.FinishReasonUnknown
336 switch stepResult.FinishReason {
337 case fantasy.FinishReasonLength:
338 finishReason = message.FinishReasonMaxTokens
339 case fantasy.FinishReasonStop:
340 finishReason = message.FinishReasonEndTurn
341 case fantasy.FinishReasonToolCalls:
342 finishReason = message.FinishReasonToolUse
343 }
344 currentAssistant.AddFinish(finishReason, "", "")
345 a.updateSessionUsage(a.largeModel, ¤tSession, stepResult.Usage)
346 sessionLock.Lock()
347 _, sessionErr := a.sessions.Save(genCtx, currentSession)
348 sessionLock.Unlock()
349 if sessionErr != nil {
350 return sessionErr
351 }
352 return a.messages.Update(genCtx, *currentAssistant)
353 },
354 StopWhen: []fantasy.StopCondition{
355 func(_ []fantasy.StepResult) bool {
356 contextWindow := a.largeModel.CatwalkCfg.ContextWindow
357 tokens := currentSession.CompletionTokens + currentSession.PromptTokens
358 percentage := (float64(tokens) / float64(contextWindow)) * 100
359 if (percentage > 80) && !a.disableAutoSummarize {
360 shouldSummarize = true
361 return true
362 }
363 return false
364 },
365 },
366 })
367
368 a.eventPromptResponded(call.SessionID, time.Since(startTime).Truncate(time.Second))
369
370 if err != nil {
371 isCancelErr := errors.Is(err, context.Canceled)
372 isPermissionErr := errors.Is(err, permission.ErrorPermissionDenied)
373 if currentAssistant == nil {
374 return result, err
375 }
376 // Ensure we finish thinking on error to close the reasoning state
377 currentAssistant.FinishThinking()
378 toolCalls := currentAssistant.ToolCalls()
379 // INFO: we use the parent context here because the genCtx has been cancelled
380 msgs, createErr := a.messages.List(ctx, currentAssistant.SessionID)
381 if createErr != nil {
382 return nil, createErr
383 }
384 for _, tc := range toolCalls {
385 if !tc.Finished {
386 tc.Finished = true
387 tc.Input = "{}"
388 currentAssistant.AddToolCall(tc)
389 updateErr := a.messages.Update(ctx, *currentAssistant)
390 if updateErr != nil {
391 return nil, updateErr
392 }
393 }
394
395 found := false
396 for _, msg := range msgs {
397 if msg.Role == message.Tool {
398 for _, tr := range msg.ToolResults() {
399 if tr.ToolCallID == tc.ID {
400 found = true
401 break
402 }
403 }
404 }
405 if found {
406 break
407 }
408 }
409 if found {
410 continue
411 }
412 content := "There was an error while executing the tool"
413 if isCancelErr {
414 content = "Tool execution canceled by user"
415 } else if isPermissionErr {
416 content = "Permission denied"
417 }
418 toolResult := message.ToolResult{
419 ToolCallID: tc.ID,
420 Name: tc.Name,
421 Content: content,
422 IsError: true,
423 }
424 _, createErr = a.messages.Create(context.Background(), currentAssistant.SessionID, message.CreateMessageParams{
425 Role: message.Tool,
426 Parts: []message.ContentPart{
427 toolResult,
428 },
429 })
430 if createErr != nil {
431 return nil, createErr
432 }
433 }
434 if isCancelErr {
435 currentAssistant.AddFinish(message.FinishReasonCanceled, "Request cancelled", "")
436 } else if isPermissionErr {
437 currentAssistant.AddFinish(message.FinishReasonPermissionDenied, "Permission denied", "")
438 } else {
439 currentAssistant.AddFinish(message.FinishReasonError, "API Error", err.Error())
440 }
441 // INFO: we use the parent context here because the genCtx has been cancelled
442 updateErr := a.messages.Update(ctx, *currentAssistant)
443 if updateErr != nil {
444 return nil, updateErr
445 }
446 return nil, err
447 }
448 wg.Wait()
449
450 if shouldSummarize {
451 a.activeRequests.Del(call.SessionID)
452 if summarizeErr := a.Summarize(genCtx, call.SessionID, call.ProviderOptions); summarizeErr != nil {
453 return nil, summarizeErr
454 }
455 // if the agent was not done...
456 if len(currentAssistant.ToolCalls()) > 0 {
457 existing, ok := a.messageQueue.Get(call.SessionID)
458 if !ok {
459 existing = []SessionAgentCall{}
460 }
461 call.Prompt = fmt.Sprintf("The previous session was interrupted because it got too long, the initial user request was: `%s`", call.Prompt)
462 existing = append(existing, call)
463 a.messageQueue.Set(call.SessionID, existing)
464 }
465 }
466
467 // release active request before processing queued messages
468 a.activeRequests.Del(call.SessionID)
469 cancel()
470
471 queuedMessages, ok := a.messageQueue.Get(call.SessionID)
472 if !ok || len(queuedMessages) == 0 {
473 return result, err
474 }
475 // there are queued messages restart the loop
476 firstQueuedMessage := queuedMessages[0]
477 a.messageQueue.Set(call.SessionID, queuedMessages[1:])
478 return a.Run(ctx, firstQueuedMessage)
479}
480
481func (a *sessionAgent) Summarize(ctx context.Context, sessionID string, opts fantasy.ProviderOptions) error {
482 if a.IsSessionBusy(sessionID) {
483 return ErrSessionBusy
484 }
485
486 currentSession, err := a.sessions.Get(ctx, sessionID)
487 if err != nil {
488 return fmt.Errorf("failed to get session: %w", err)
489 }
490 msgs, err := a.getSessionMessages(ctx, currentSession)
491 if err != nil {
492 return err
493 }
494 if len(msgs) == 0 {
495 // nothing to summarize
496 return nil
497 }
498
499 aiMsgs, _ := a.preparePrompt(msgs)
500
501 genCtx, cancel := context.WithCancel(ctx)
502 a.activeRequests.Set(sessionID, cancel)
503 defer a.activeRequests.Del(sessionID)
504 defer cancel()
505
506 agent := fantasy.NewAgent(a.largeModel.Model,
507 fantasy.WithSystemPrompt(string(summaryPrompt)),
508 )
509 summaryMessage, err := a.messages.Create(ctx, sessionID, message.CreateMessageParams{
510 Role: message.Assistant,
511 Model: a.largeModel.Model.Model(),
512 Provider: a.largeModel.Model.Provider(),
513 IsSummaryMessage: true,
514 })
515 if err != nil {
516 return err
517 }
518
519 resp, err := agent.Stream(genCtx, fantasy.AgentStreamCall{
520 Prompt: "Provide a detailed summary of our conversation above.",
521 Messages: aiMsgs,
522 ProviderOptions: opts,
523 OnReasoningDelta: func(id string, text string) error {
524 summaryMessage.AppendReasoningContent(text)
525 return a.messages.Update(genCtx, summaryMessage)
526 },
527 OnReasoningEnd: func(id string, reasoning fantasy.ReasoningContent) error {
528 // handle anthropic signature
529 if anthropicData, ok := reasoning.ProviderMetadata["anthropic"]; ok {
530 if signature, ok := anthropicData.(*anthropic.ReasoningOptionMetadata); ok && signature.Signature != "" {
531 summaryMessage.AppendReasoningSignature(signature.Signature)
532 }
533 }
534 summaryMessage.FinishThinking()
535 return a.messages.Update(genCtx, summaryMessage)
536 },
537 OnTextDelta: func(id, text string) error {
538 summaryMessage.AppendContent(text)
539 return a.messages.Update(genCtx, summaryMessage)
540 },
541 })
542 if err != nil {
543 isCancelErr := errors.Is(err, context.Canceled)
544 if isCancelErr {
545 // User cancelled summarize we need to remove the summary message
546 deleteErr := a.messages.Delete(ctx, summaryMessage.ID)
547 return deleteErr
548 }
549 return err
550 }
551
552 summaryMessage.AddFinish(message.FinishReasonEndTurn, "", "")
553 err = a.messages.Update(genCtx, summaryMessage)
554 if err != nil {
555 return err
556 }
557
558 a.updateSessionUsage(a.largeModel, ¤tSession, resp.TotalUsage)
559
560 // just in case get just the last usage
561 usage := resp.Response.Usage
562 currentSession.SummaryMessageID = summaryMessage.ID
563 currentSession.CompletionTokens = usage.OutputTokens
564 currentSession.PromptTokens = 0
565 _, err = a.sessions.Save(genCtx, currentSession)
566 return err
567}
568
569func (a *sessionAgent) getCacheControlOptions() fantasy.ProviderOptions {
570 if t, _ := strconv.ParseBool(os.Getenv("CRUSH_DISABLE_ANTHROPIC_CACHE")); t {
571 return fantasy.ProviderOptions{}
572 }
573 return fantasy.ProviderOptions{
574 anthropic.Name: &anthropic.ProviderCacheControlOptions{
575 CacheControl: anthropic.CacheControl{Type: "ephemeral"},
576 },
577 bedrock.Name: &anthropic.ProviderCacheControlOptions{
578 CacheControl: anthropic.CacheControl{Type: "ephemeral"},
579 },
580 }
581}
582
583func (a *sessionAgent) createUserMessage(ctx context.Context, call SessionAgentCall) (message.Message, error) {
584 var attachmentParts []message.ContentPart
585 for _, attachment := range call.Attachments {
586 attachmentParts = append(attachmentParts, message.BinaryContent{Path: attachment.FilePath, MIMEType: attachment.MimeType, Data: attachment.Content})
587 }
588 parts := []message.ContentPart{message.TextContent{Text: call.Prompt}}
589 parts = append(parts, attachmentParts...)
590 msg, err := a.messages.Create(ctx, call.SessionID, message.CreateMessageParams{
591 Role: message.User,
592 Parts: parts,
593 })
594 if err != nil {
595 return message.Message{}, fmt.Errorf("failed to create user message: %w", err)
596 }
597 return msg, nil
598}
599
600func (a *sessionAgent) preparePrompt(msgs []message.Message, attachments ...message.Attachment) ([]fantasy.Message, []fantasy.FilePart) {
601 var history []fantasy.Message
602 for _, m := range msgs {
603 if len(m.Parts) == 0 {
604 continue
605 }
606 // Assistant message without content or tool calls (cancelled before it returned anything)
607 if m.Role == message.Assistant && len(m.ToolCalls()) == 0 && m.Content().Text == "" && m.ReasoningContent().String() == "" {
608 continue
609 }
610 history = append(history, m.ToAIMessage()...)
611 }
612
613 var files []fantasy.FilePart
614 for _, attachment := range attachments {
615 files = append(files, fantasy.FilePart{
616 Filename: attachment.FileName,
617 Data: attachment.Content,
618 MediaType: attachment.MimeType,
619 })
620 }
621
622 return history, files
623}
624
625func (a *sessionAgent) getSessionMessages(ctx context.Context, session session.Session) ([]message.Message, error) {
626 msgs, err := a.messages.List(ctx, session.ID)
627 if err != nil {
628 return nil, fmt.Errorf("failed to list messages: %w", err)
629 }
630
631 if session.SummaryMessageID != "" {
632 summaryMsgInex := -1
633 for i, msg := range msgs {
634 if msg.ID == session.SummaryMessageID {
635 summaryMsgInex = i
636 break
637 }
638 }
639 if summaryMsgInex != -1 {
640 msgs = msgs[summaryMsgInex:]
641 msgs[0].Role = message.User
642 }
643 }
644 return msgs, nil
645}
646
647func (a *sessionAgent) generateTitle(ctx context.Context, session *session.Session, prompt string) {
648 if prompt == "" {
649 return
650 }
651
652 var maxOutput int64 = 40
653 if a.smallModel.CatwalkCfg.CanReason {
654 maxOutput = a.smallModel.CatwalkCfg.DefaultMaxTokens
655 }
656
657 agent := fantasy.NewAgent(a.smallModel.Model,
658 fantasy.WithSystemPrompt(string(titlePrompt)+"\n /no_think"),
659 fantasy.WithMaxOutputTokens(maxOutput),
660 )
661
662 resp, err := agent.Stream(ctx, fantasy.AgentStreamCall{
663 Prompt: fmt.Sprintf("Generate a concise title for the following content:\n\n%s\n <think>\n\n</think>", prompt),
664 })
665 if err != nil {
666 slog.Error("error generating title", "err", err)
667 return
668 }
669
670 title := resp.Response.Content.Text()
671
672 title = strings.ReplaceAll(title, "\n", " ")
673
674 // remove thinking tags if present
675 if idx := strings.Index(title, "</think>"); idx > 0 {
676 title = title[idx+len("</think>"):]
677 }
678
679 title = strings.TrimSpace(title)
680 if title == "" {
681 slog.Warn("failed to generate title", "warn", "empty title")
682 return
683 }
684
685 session.Title = title
686 a.updateSessionUsage(a.smallModel, session, resp.TotalUsage)
687 _, saveErr := a.sessions.Save(ctx, *session)
688 if saveErr != nil {
689 slog.Error("failed to save session title & usage", "error", saveErr)
690 return
691 }
692}
693
694func (a *sessionAgent) updateSessionUsage(model Model, session *session.Session, usage fantasy.Usage) {
695 modelConfig := model.CatwalkCfg
696 cost := modelConfig.CostPer1MInCached/1e6*float64(usage.CacheCreationTokens) +
697 modelConfig.CostPer1MOutCached/1e6*float64(usage.CacheReadTokens) +
698 modelConfig.CostPer1MIn/1e6*float64(usage.InputTokens) +
699 modelConfig.CostPer1MOut/1e6*float64(usage.OutputTokens)
700
701 a.eventTokensUsed(session.ID, model, usage, cost)
702
703 session.Cost += cost
704 session.CompletionTokens = usage.OutputTokens + usage.CacheReadTokens
705 session.PromptTokens = usage.InputTokens + usage.CacheCreationTokens
706}
707
708func (a *sessionAgent) Cancel(sessionID string) {
709 // Cancel regular requests
710 if cancel, ok := a.activeRequests.Take(sessionID); ok && cancel != nil {
711 slog.Info("Request cancellation initiated", "session_id", sessionID)
712 cancel()
713 }
714
715 // Also check for summarize requests
716 if cancel, ok := a.activeRequests.Take(sessionID + "-summarize"); ok && cancel != nil {
717 slog.Info("Summarize cancellation initiated", "session_id", sessionID)
718 cancel()
719 }
720
721 if a.QueuedPrompts(sessionID) > 0 {
722 slog.Info("Clearing queued prompts", "session_id", sessionID)
723 a.messageQueue.Del(sessionID)
724 }
725}
726
727func (a *sessionAgent) ClearQueue(sessionID string) {
728 if a.QueuedPrompts(sessionID) > 0 {
729 slog.Info("Clearing queued prompts", "session_id", sessionID)
730 a.messageQueue.Del(sessionID)
731 }
732}
733
734func (a *sessionAgent) CancelAll() {
735 if !a.IsBusy() {
736 return
737 }
738 for key := range a.activeRequests.Seq2() {
739 a.Cancel(key) // key is sessionID
740 }
741
742 timeout := time.After(5 * time.Second)
743 for a.IsBusy() {
744 select {
745 case <-timeout:
746 return
747 default:
748 time.Sleep(200 * time.Millisecond)
749 }
750 }
751}
752
753func (a *sessionAgent) IsBusy() bool {
754 var busy bool
755 for cancelFunc := range a.activeRequests.Seq() {
756 if cancelFunc != nil {
757 busy = true
758 break
759 }
760 }
761 return busy
762}
763
764func (a *sessionAgent) IsSessionBusy(sessionID string) bool {
765 _, busy := a.activeRequests.Get(sessionID)
766 return busy
767}
768
769func (a *sessionAgent) QueuedPrompts(sessionID string) int {
770 l, ok := a.messageQueue.Get(sessionID)
771 if !ok {
772 return 0
773 }
774 return len(l)
775}
776
777func (a *sessionAgent) SetModels(large Model, small Model) {
778 a.largeModel = large
779 a.smallModel = small
780}
781
// SetTools replaces the tool set offered to the model. The slice is stored
// as given, without copying.
func (a *sessionAgent) SetTools(tools []fantasy.AgentTool) {
	a.tools = tools
}
785
// Model returns the agent's large model.
func (a *sessionAgent) Model() Model {
	return a.largeModel
}