1package agent
2
3import (
4 "context"
5 _ "embed"
6 "errors"
7 "fmt"
8 "log/slog"
9 "os"
10 "strconv"
11 "strings"
12 "sync"
13 "time"
14
15 "charm.land/fantasy"
16 "charm.land/fantasy/providers/anthropic"
17 "charm.land/fantasy/providers/bedrock"
18 "charm.land/fantasy/providers/google"
19 "charm.land/fantasy/providers/openai"
20 "github.com/charmbracelet/catwalk/pkg/catwalk"
21 "github.com/charmbracelet/crush/internal/agent/tools"
22 "github.com/charmbracelet/crush/internal/config"
23 "github.com/charmbracelet/crush/internal/csync"
24 "github.com/charmbracelet/crush/internal/message"
25 "github.com/charmbracelet/crush/internal/permission"
26 "github.com/charmbracelet/crush/internal/session"
27)
28
// titlePrompt is the system prompt used by generateTitle.
//
//go:embed templates/title.md
var titlePrompt []byte

// summaryPrompt is the system prompt used by Summarize.
//
//go:embed templates/summary.md
var summaryPrompt []byte
34
// SessionAgentCall describes one prompt to run against a session via
// SessionAgent.Run. The sampling fields are forwarded unchanged to
// fantasy.AgentStreamCall; presumably nil means "use the provider default"
// (confirm against fantasy).
type SessionAgentCall struct {
	// SessionID identifies the session; Run rejects an empty value with ErrSessionMissing.
	SessionID string
	// Prompt is the user's text; Run rejects an empty value with ErrEmptyPrompt.
	Prompt string
	// ProviderOptions are forwarded to the stream call (and to Summarize when
	// Run triggers an automatic summary).
	ProviderOptions fantasy.ProviderOptions
	// Attachments are stored with the user message and sent as file parts.
	Attachments []message.Attachment
	// MaxOutputTokens is passed by pointer to the stream call.
	// NOTE(review): a zero value is still sent as a pointer to 0 — confirm how
	// fantasy interprets it.
	MaxOutputTokens  int64
	Temperature      *float64
	TopP             *float64
	TopK             *int64
	FrequencyPenalty *float64
	PresencePenalty  *float64
}
47
// SessionAgent runs prompts for chat sessions. Prompts that arrive while a
// session is busy are queued; requests can be cancelled and sessions can be
// summarized.
type SessionAgent interface {
	// Run executes the call, or queues it when the session already has an
	// active request (returning a nil result and nil error in that case).
	Run(context.Context, SessionAgentCall) (*fantasy.AgentResult, error)
	// SetModels replaces the large (main) and small (title) models.
	SetModels(large Model, small Model)
	// SetTools replaces the tools offered to the model on subsequent runs.
	SetTools(tools []fantasy.AgentTool)
	// Cancel aborts the session's active request and drops its queued prompts.
	Cancel(sessionID string)
	// CancelAll cancels every active request and waits a bounded time for them to stop.
	CancelAll()
	// IsSessionBusy reports whether the session has an active request.
	IsSessionBusy(sessionID string) bool
	// IsBusy reports whether any session has an active request.
	IsBusy() bool
	// QueuedPrompts returns how many prompts are queued for the session.
	QueuedPrompts(sessionID string) int
	// ClearQueue drops the session's queued prompts.
	ClearQueue(sessionID string)
	// Summarize stores a model-generated summary of the session, from which
	// later prompts start.
	Summarize(context.Context, string, fantasy.ProviderOptions) error
	// Model returns the large model.
	Model() Model
}
61
// Model bundles a language model with its catwalk metadata (context window,
// pricing, reasoning support) and the selected-model configuration.
type Model struct {
	// Model is the client used to stream completions.
	Model fantasy.LanguageModel
	// CatwalkCfg supplies the context window, per-token costs and CanReason flag.
	CatwalkCfg catwalk.Model
	// ModelCfg holds the configured model and provider identifiers.
	ModelCfg config.SelectedModel
}
67
// sessionAgent is the SessionAgent implementation returned by NewSessionAgent.
type sessionAgent struct {
	// largeModel runs prompts and summaries; smallModel generates titles.
	largeModel   Model
	smallModel   Model
	systemPrompt string
	// tools are offered to the large model; Run sets cache-control options on the last one.
	tools    []fantasy.AgentTool
	sessions session.Service
	messages message.Service
	// disableAutoSummarize stops Run from summarizing near the context limit.
	disableAutoSummarize bool
	// isYolo is not read in this file. NOTE(review): presumably it bypasses
	// permission prompts elsewhere — confirm.
	isYolo bool

	// messageQueue holds prompts received while a session was busy, keyed by session ID.
	messageQueue *csync.Map[string, []SessionAgentCall]
	// activeRequests maps a session ID to the cancel func of its in-flight
	// Run or Summarize; presence of a key means the session is busy.
	activeRequests *csync.Map[string, context.CancelFunc]
}
81
// SessionAgentOptions configures NewSessionAgent; each field is copied into
// the corresponding sessionAgent field.
type SessionAgentOptions struct {
	// LargeModel runs prompts and summaries.
	LargeModel Model
	// SmallModel generates session titles.
	SmallModel   Model
	SystemPrompt string
	// DisableAutoSummarize prevents automatic summarization near the context limit.
	DisableAutoSummarize bool
	IsYolo               bool
	Sessions             session.Service
	Messages             message.Service
	Tools                []fantasy.AgentTool
}
92
93func NewSessionAgent(
94 opts SessionAgentOptions,
95) SessionAgent {
96 return &sessionAgent{
97 largeModel: opts.LargeModel,
98 smallModel: opts.SmallModel,
99 systemPrompt: opts.SystemPrompt,
100 sessions: opts.Sessions,
101 messages: opts.Messages,
102 disableAutoSummarize: opts.DisableAutoSummarize,
103 tools: opts.Tools,
104 isYolo: opts.IsYolo,
105 messageQueue: csync.NewMap[string, []SessionAgentCall](),
106 activeRequests: csync.NewMap[string, context.CancelFunc](),
107 }
108}
109
110func (a *sessionAgent) Run(ctx context.Context, call SessionAgentCall) (*fantasy.AgentResult, error) {
111 if call.Prompt == "" {
112 return nil, ErrEmptyPrompt
113 }
114 if call.SessionID == "" {
115 return nil, ErrSessionMissing
116 }
117
118 // Queue the message if busy
119 if a.IsSessionBusy(call.SessionID) {
120 existing, ok := a.messageQueue.Get(call.SessionID)
121 if !ok {
122 existing = []SessionAgentCall{}
123 }
124 existing = append(existing, call)
125 a.messageQueue.Set(call.SessionID, existing)
126 return nil, nil
127 }
128
129 if len(a.tools) > 0 {
130 // add anthropic caching to the last tool
131 a.tools[len(a.tools)-1].SetProviderOptions(a.getCacheControlOptions())
132 }
133
134 agent := fantasy.NewAgent(
135 a.largeModel.Model,
136 fantasy.WithSystemPrompt(a.systemPrompt),
137 fantasy.WithTools(a.tools...),
138 )
139
140 sessionLock := sync.Mutex{}
141 currentSession, err := a.sessions.Get(ctx, call.SessionID)
142 if err != nil {
143 return nil, fmt.Errorf("failed to get session: %w", err)
144 }
145
146 msgs, err := a.getSessionMessages(ctx, currentSession)
147 if err != nil {
148 return nil, fmt.Errorf("failed to get session messages: %w", err)
149 }
150
151 var wg sync.WaitGroup
152 // Generate title if first message
153 if len(msgs) == 0 {
154 wg.Go(func() {
155 sessionLock.Lock()
156 a.generateTitle(ctx, ¤tSession, call.Prompt)
157 sessionLock.Unlock()
158 })
159 }
160
161 // Add the user message to the session
162 _, err = a.createUserMessage(ctx, call)
163 if err != nil {
164 return nil, err
165 }
166
167 // add the session to the context
168 ctx = context.WithValue(ctx, tools.SessionIDContextKey, call.SessionID)
169
170 genCtx, cancel := context.WithCancel(ctx)
171 a.activeRequests.Set(call.SessionID, cancel)
172
173 defer cancel()
174 defer a.activeRequests.Del(call.SessionID)
175
176 history, files := a.preparePrompt(msgs, call.Attachments...)
177
178 startTime := time.Now()
179 a.eventPromptSent(call.SessionID)
180
181 var currentAssistant *message.Message
182 var shouldSummarize bool
183 result, err := agent.Stream(genCtx, fantasy.AgentStreamCall{
184 Prompt: call.Prompt,
185 Files: files,
186 Messages: history,
187 ProviderOptions: call.ProviderOptions,
188 MaxOutputTokens: &call.MaxOutputTokens,
189 TopP: call.TopP,
190 Temperature: call.Temperature,
191 PresencePenalty: call.PresencePenalty,
192 TopK: call.TopK,
193 FrequencyPenalty: call.FrequencyPenalty,
194 // Before each step create the new assistant message
195 PrepareStep: func(callContext context.Context, options fantasy.PrepareStepFunctionOptions) (_ context.Context, prepared fantasy.PrepareStepResult, err error) {
196 prepared.Messages = options.Messages
197 // reset all cached items
198 for i := range prepared.Messages {
199 prepared.Messages[i].ProviderOptions = nil
200 }
201
202 queuedCalls, _ := a.messageQueue.Get(call.SessionID)
203 a.messageQueue.Del(call.SessionID)
204 for _, queued := range queuedCalls {
205 userMessage, createErr := a.createUserMessage(callContext, queued)
206 if createErr != nil {
207 return callContext, prepared, createErr
208 }
209 prepared.Messages = append(prepared.Messages, userMessage.ToAIMessage()...)
210 }
211
212 lastSystemRoleInx := 0
213 systemMessageUpdated := false
214 for i, msg := range prepared.Messages {
215 // only add cache control to the last message
216 if msg.Role == fantasy.MessageRoleSystem {
217 lastSystemRoleInx = i
218 } else if !systemMessageUpdated {
219 prepared.Messages[lastSystemRoleInx].ProviderOptions = a.getCacheControlOptions()
220 systemMessageUpdated = true
221 }
222 // than add cache control to the last 2 messages
223 if i > len(prepared.Messages)-3 {
224 prepared.Messages[i].ProviderOptions = a.getCacheControlOptions()
225 }
226 }
227
228 var assistantMsg message.Message
229 assistantMsg, err = a.messages.Create(callContext, call.SessionID, message.CreateMessageParams{
230 Role: message.Assistant,
231 Parts: []message.ContentPart{},
232 Model: a.largeModel.ModelCfg.Model,
233 Provider: a.largeModel.ModelCfg.Provider,
234 })
235 if err != nil {
236 return callContext, prepared, err
237 }
238 callContext = context.WithValue(callContext, tools.MessageIDContextKey, assistantMsg.ID)
239 currentAssistant = &assistantMsg
240 return callContext, prepared, err
241 },
242 OnReasoningDelta: func(id string, text string) error {
243 currentAssistant.AppendReasoningContent(text)
244 return a.messages.Update(genCtx, *currentAssistant)
245 },
246 OnReasoningEnd: func(id string, reasoning fantasy.ReasoningContent) error {
247 // handle anthropic signature
248 if anthropicData, ok := reasoning.ProviderMetadata[anthropic.Name]; ok {
249 if reasoning, ok := anthropicData.(*anthropic.ReasoningOptionMetadata); ok {
250 currentAssistant.AppendReasoningSignature(reasoning.Signature)
251 }
252 }
253 if googleData, ok := reasoning.ProviderMetadata[google.Name]; ok {
254 if reasoning, ok := googleData.(*google.ReasoningMetadata); ok {
255 currentAssistant.AppendReasoningSignature(reasoning.Signature)
256 }
257 }
258 if openaiData, ok := reasoning.ProviderMetadata[openai.Name]; ok {
259 if reasoning, ok := openaiData.(*openai.ResponsesReasoningMetadata); ok {
260 currentAssistant.SetReasoningResponsesData(reasoning)
261 }
262 }
263 currentAssistant.FinishThinking()
264 return a.messages.Update(genCtx, *currentAssistant)
265 },
266 OnTextDelta: func(id string, text string) error {
267 currentAssistant.AppendContent(text)
268 return a.messages.Update(genCtx, *currentAssistant)
269 },
270 OnToolInputStart: func(id string, toolName string) error {
271 toolCall := message.ToolCall{
272 ID: id,
273 Name: toolName,
274 ProviderExecuted: false,
275 Finished: false,
276 }
277 currentAssistant.AddToolCall(toolCall)
278 return a.messages.Update(genCtx, *currentAssistant)
279 },
280 OnRetry: func(err *fantasy.APICallError, delay time.Duration) {
281 // TODO: implement
282 },
283 OnToolCall: func(tc fantasy.ToolCallContent) error {
284 toolCall := message.ToolCall{
285 ID: tc.ToolCallID,
286 Name: tc.ToolName,
287 Input: tc.Input,
288 ProviderExecuted: false,
289 Finished: true,
290 }
291 currentAssistant.AddToolCall(toolCall)
292 return a.messages.Update(genCtx, *currentAssistant)
293 },
294 OnToolResult: func(result fantasy.ToolResultContent) error {
295 var resultContent string
296 isError := false
297 switch result.Result.GetType() {
298 case fantasy.ToolResultContentTypeText:
299 r, ok := fantasy.AsToolResultOutputType[fantasy.ToolResultOutputContentText](result.Result)
300 if ok {
301 resultContent = r.Text
302 }
303 case fantasy.ToolResultContentTypeError:
304 r, ok := fantasy.AsToolResultOutputType[fantasy.ToolResultOutputContentError](result.Result)
305 if ok {
306 isError = true
307 resultContent = r.Error.Error()
308 }
309 case fantasy.ToolResultContentTypeMedia:
310 // TODO: handle this message type
311 }
312 toolResult := message.ToolResult{
313 ToolCallID: result.ToolCallID,
314 Name: result.ToolName,
315 Content: resultContent,
316 IsError: isError,
317 Metadata: result.ClientMetadata,
318 }
319 _, createMsgErr := a.messages.Create(genCtx, currentAssistant.SessionID, message.CreateMessageParams{
320 Role: message.Tool,
321 Parts: []message.ContentPart{
322 toolResult,
323 },
324 })
325 if createMsgErr != nil {
326 return createMsgErr
327 }
328 return nil
329 },
330 OnStepFinish: func(stepResult fantasy.StepResult) error {
331 finishReason := message.FinishReasonUnknown
332 switch stepResult.FinishReason {
333 case fantasy.FinishReasonLength:
334 finishReason = message.FinishReasonMaxTokens
335 case fantasy.FinishReasonStop:
336 finishReason = message.FinishReasonEndTurn
337 case fantasy.FinishReasonToolCalls:
338 finishReason = message.FinishReasonToolUse
339 }
340 currentAssistant.AddFinish(finishReason, "", "")
341 a.updateSessionUsage(a.largeModel, ¤tSession, stepResult.Usage)
342 sessionLock.Lock()
343 _, sessionErr := a.sessions.Save(genCtx, currentSession)
344 sessionLock.Unlock()
345 if sessionErr != nil {
346 return sessionErr
347 }
348 return a.messages.Update(genCtx, *currentAssistant)
349 },
350 StopWhen: []fantasy.StopCondition{
351 func(_ []fantasy.StepResult) bool {
352 contextWindow := a.largeModel.CatwalkCfg.ContextWindow
353 tokens := currentSession.CompletionTokens + currentSession.PromptTokens
354 percentage := (float64(tokens) / float64(contextWindow)) * 100
355 if (percentage > 80) && !a.disableAutoSummarize {
356 shouldSummarize = true
357 return true
358 }
359 return false
360 },
361 },
362 })
363
364 a.eventPromptResponded(call.SessionID, time.Since(startTime).Truncate(time.Second))
365
366 if err != nil {
367 isCancelErr := errors.Is(err, context.Canceled)
368 isPermissionErr := errors.Is(err, permission.ErrorPermissionDenied)
369 if currentAssistant == nil {
370 return result, err
371 }
372 // Ensure we finish thinking on error to close the reasoning state
373 currentAssistant.FinishThinking()
374 toolCalls := currentAssistant.ToolCalls()
375 // INFO: we use the parent context here because the genCtx has been cancelled
376 msgs, createErr := a.messages.List(ctx, currentAssistant.SessionID)
377 if createErr != nil {
378 return nil, createErr
379 }
380 for _, tc := range toolCalls {
381 if !tc.Finished {
382 tc.Finished = true
383 tc.Input = "{}"
384 currentAssistant.AddToolCall(tc)
385 updateErr := a.messages.Update(ctx, *currentAssistant)
386 if updateErr != nil {
387 return nil, updateErr
388 }
389 }
390
391 found := false
392 for _, msg := range msgs {
393 if msg.Role == message.Tool {
394 for _, tr := range msg.ToolResults() {
395 if tr.ToolCallID == tc.ID {
396 found = true
397 break
398 }
399 }
400 }
401 if found {
402 break
403 }
404 }
405 if found {
406 continue
407 }
408 content := "There was an error while executing the tool"
409 if isCancelErr {
410 content = "Tool execution canceled by user"
411 } else if isPermissionErr {
412 content = "Permission denied"
413 }
414 toolResult := message.ToolResult{
415 ToolCallID: tc.ID,
416 Name: tc.Name,
417 Content: content,
418 IsError: true,
419 }
420 _, createErr = a.messages.Create(context.Background(), currentAssistant.SessionID, message.CreateMessageParams{
421 Role: message.Tool,
422 Parts: []message.ContentPart{
423 toolResult,
424 },
425 })
426 if createErr != nil {
427 return nil, createErr
428 }
429 }
430 if isCancelErr {
431 currentAssistant.AddFinish(message.FinishReasonCanceled, "Request cancelled", "")
432 } else if isPermissionErr {
433 currentAssistant.AddFinish(message.FinishReasonPermissionDenied, "Permission denied", "")
434 } else {
435 currentAssistant.AddFinish(message.FinishReasonError, "API Error", err.Error())
436 }
437 // INFO: we use the parent context here because the genCtx has been cancelled
438 updateErr := a.messages.Update(ctx, *currentAssistant)
439 if updateErr != nil {
440 return nil, updateErr
441 }
442 return nil, err
443 }
444 wg.Wait()
445
446 if shouldSummarize {
447 a.activeRequests.Del(call.SessionID)
448 if summarizeErr := a.Summarize(genCtx, call.SessionID, call.ProviderOptions); summarizeErr != nil {
449 return nil, summarizeErr
450 }
451 // if the agent was not done...
452 if len(currentAssistant.ToolCalls()) > 0 {
453 existing, ok := a.messageQueue.Get(call.SessionID)
454 if !ok {
455 existing = []SessionAgentCall{}
456 }
457 call.Prompt = fmt.Sprintf("The previous session was interrupted because it got too long, the initial user request was: `%s`", call.Prompt)
458 existing = append(existing, call)
459 a.messageQueue.Set(call.SessionID, existing)
460 }
461 }
462
463 // release active request before processing queued messages
464 a.activeRequests.Del(call.SessionID)
465 cancel()
466
467 queuedMessages, ok := a.messageQueue.Get(call.SessionID)
468 if !ok || len(queuedMessages) == 0 {
469 return result, err
470 }
471 // there are queued messages restart the loop
472 firstQueuedMessage := queuedMessages[0]
473 a.messageQueue.Set(call.SessionID, queuedMessages[1:])
474 return a.Run(ctx, firstQueuedMessage)
475}
476
477func (a *sessionAgent) Summarize(ctx context.Context, sessionID string, opts fantasy.ProviderOptions) error {
478 if a.IsSessionBusy(sessionID) {
479 return ErrSessionBusy
480 }
481
482 currentSession, err := a.sessions.Get(ctx, sessionID)
483 if err != nil {
484 return fmt.Errorf("failed to get session: %w", err)
485 }
486 msgs, err := a.getSessionMessages(ctx, currentSession)
487 if err != nil {
488 return err
489 }
490 if len(msgs) == 0 {
491 // nothing to summarize
492 return nil
493 }
494
495 aiMsgs, _ := a.preparePrompt(msgs)
496
497 genCtx, cancel := context.WithCancel(ctx)
498 a.activeRequests.Set(sessionID, cancel)
499 defer a.activeRequests.Del(sessionID)
500 defer cancel()
501
502 agent := fantasy.NewAgent(a.largeModel.Model,
503 fantasy.WithSystemPrompt(string(summaryPrompt)),
504 )
505 summaryMessage, err := a.messages.Create(ctx, sessionID, message.CreateMessageParams{
506 Role: message.Assistant,
507 Model: a.largeModel.Model.Model(),
508 Provider: a.largeModel.Model.Provider(),
509 IsSummaryMessage: true,
510 })
511 if err != nil {
512 return err
513 }
514
515 resp, err := agent.Stream(genCtx, fantasy.AgentStreamCall{
516 Prompt: "Provide a detailed summary of our conversation above.",
517 Messages: aiMsgs,
518 ProviderOptions: opts,
519 OnReasoningDelta: func(id string, text string) error {
520 summaryMessage.AppendReasoningContent(text)
521 return a.messages.Update(genCtx, summaryMessage)
522 },
523 OnReasoningEnd: func(id string, reasoning fantasy.ReasoningContent) error {
524 // handle anthropic signature
525 if anthropicData, ok := reasoning.ProviderMetadata["anthropic"]; ok {
526 if signature, ok := anthropicData.(*anthropic.ReasoningOptionMetadata); ok && signature.Signature != "" {
527 summaryMessage.AppendReasoningSignature(signature.Signature)
528 }
529 }
530 summaryMessage.FinishThinking()
531 return a.messages.Update(genCtx, summaryMessage)
532 },
533 OnTextDelta: func(id, text string) error {
534 summaryMessage.AppendContent(text)
535 return a.messages.Update(genCtx, summaryMessage)
536 },
537 })
538 if err != nil {
539 isCancelErr := errors.Is(err, context.Canceled)
540 if isCancelErr {
541 // User cancelled summarize we need to remove the summary message
542 deleteErr := a.messages.Delete(ctx, summaryMessage.ID)
543 return deleteErr
544 }
545 return err
546 }
547
548 summaryMessage.AddFinish(message.FinishReasonEndTurn, "", "")
549 err = a.messages.Update(genCtx, summaryMessage)
550 if err != nil {
551 return err
552 }
553
554 a.updateSessionUsage(a.largeModel, ¤tSession, resp.TotalUsage)
555
556 // just in case get just the last usage
557 usage := resp.Response.Usage
558 currentSession.SummaryMessageID = summaryMessage.ID
559 currentSession.CompletionTokens = usage.OutputTokens
560 currentSession.PromptTokens = 0
561 _, err = a.sessions.Save(genCtx, currentSession)
562 return err
563}
564
565func (a *sessionAgent) getCacheControlOptions() fantasy.ProviderOptions {
566 if t, _ := strconv.ParseBool(os.Getenv("CRUSH_DISABLE_ANTHROPIC_CACHE")); t {
567 return fantasy.ProviderOptions{}
568 }
569 return fantasy.ProviderOptions{
570 anthropic.Name: &anthropic.ProviderCacheControlOptions{
571 CacheControl: anthropic.CacheControl{Type: "ephemeral"},
572 },
573 bedrock.Name: &anthropic.ProviderCacheControlOptions{
574 CacheControl: anthropic.CacheControl{Type: "ephemeral"},
575 },
576 }
577}
578
579func (a *sessionAgent) createUserMessage(ctx context.Context, call SessionAgentCall) (message.Message, error) {
580 var attachmentParts []message.ContentPart
581 for _, attachment := range call.Attachments {
582 attachmentParts = append(attachmentParts, message.BinaryContent{Path: attachment.FilePath, MIMEType: attachment.MimeType, Data: attachment.Content})
583 }
584 parts := []message.ContentPart{message.TextContent{Text: call.Prompt}}
585 parts = append(parts, attachmentParts...)
586 msg, err := a.messages.Create(ctx, call.SessionID, message.CreateMessageParams{
587 Role: message.User,
588 Parts: parts,
589 })
590 if err != nil {
591 return message.Message{}, fmt.Errorf("failed to create user message: %w", err)
592 }
593 return msg, nil
594}
595
596func (a *sessionAgent) preparePrompt(msgs []message.Message, attachments ...message.Attachment) ([]fantasy.Message, []fantasy.FilePart) {
597 var history []fantasy.Message
598 for _, m := range msgs {
599 if len(m.Parts) == 0 {
600 continue
601 }
602 // Assistant message without content or tool calls (cancelled before it returned anything)
603 if m.Role == message.Assistant && len(m.ToolCalls()) == 0 && m.Content().Text == "" && m.ReasoningContent().String() == "" {
604 continue
605 }
606 history = append(history, m.ToAIMessage()...)
607 }
608
609 var files []fantasy.FilePart
610 for _, attachment := range attachments {
611 files = append(files, fantasy.FilePart{
612 Filename: attachment.FileName,
613 Data: attachment.Content,
614 MediaType: attachment.MimeType,
615 })
616 }
617
618 return history, files
619}
620
621func (a *sessionAgent) getSessionMessages(ctx context.Context, session session.Session) ([]message.Message, error) {
622 msgs, err := a.messages.List(ctx, session.ID)
623 if err != nil {
624 return nil, fmt.Errorf("failed to list messages: %w", err)
625 }
626
627 if session.SummaryMessageID != "" {
628 summaryMsgInex := -1
629 for i, msg := range msgs {
630 if msg.ID == session.SummaryMessageID {
631 summaryMsgInex = i
632 break
633 }
634 }
635 if summaryMsgInex != -1 {
636 msgs = msgs[summaryMsgInex:]
637 msgs[0].Role = message.User
638 }
639 }
640 return msgs, nil
641}
642
643func (a *sessionAgent) generateTitle(ctx context.Context, session *session.Session, prompt string) {
644 if prompt == "" {
645 return
646 }
647
648 var maxOutput int64 = 40
649 if a.smallModel.CatwalkCfg.CanReason {
650 maxOutput = a.smallModel.CatwalkCfg.DefaultMaxTokens
651 }
652
653 agent := fantasy.NewAgent(a.smallModel.Model,
654 fantasy.WithSystemPrompt(string(titlePrompt)+"\n /no_think"),
655 fantasy.WithMaxOutputTokens(maxOutput),
656 )
657
658 resp, err := agent.Stream(ctx, fantasy.AgentStreamCall{
659 Prompt: fmt.Sprintf("Generate a concise title for the following content:\n\n%s\n <think>\n\n</think>", prompt),
660 })
661 if err != nil {
662 slog.Error("error generating title", "err", err)
663 return
664 }
665
666 title := resp.Response.Content.Text()
667
668 title = strings.ReplaceAll(title, "\n", " ")
669
670 // remove thinking tags if present
671 if idx := strings.Index(title, "</think>"); idx > 0 {
672 title = title[idx+len("</think>"):]
673 }
674
675 title = strings.TrimSpace(title)
676 if title == "" {
677 slog.Warn("failed to generate title", "warn", "empty title")
678 return
679 }
680
681 session.Title = title
682 a.updateSessionUsage(a.smallModel, session, resp.TotalUsage)
683 _, saveErr := a.sessions.Save(ctx, *session)
684 if saveErr != nil {
685 slog.Error("failed to save session title & usage", "error", saveErr)
686 return
687 }
688}
689
690func (a *sessionAgent) updateSessionUsage(model Model, session *session.Session, usage fantasy.Usage) {
691 modelConfig := model.CatwalkCfg
692 cost := modelConfig.CostPer1MInCached/1e6*float64(usage.CacheCreationTokens) +
693 modelConfig.CostPer1MOutCached/1e6*float64(usage.CacheReadTokens) +
694 modelConfig.CostPer1MIn/1e6*float64(usage.InputTokens) +
695 modelConfig.CostPer1MOut/1e6*float64(usage.OutputTokens)
696
697 a.eventTokensUsed(session.ID, model, usage, cost)
698
699 session.Cost += cost
700 session.CompletionTokens = usage.OutputTokens + usage.CacheReadTokens
701 session.PromptTokens = usage.InputTokens + usage.CacheCreationTokens
702}
703
704func (a *sessionAgent) Cancel(sessionID string) {
705 // Cancel regular requests
706 if cancel, ok := a.activeRequests.Take(sessionID); ok && cancel != nil {
707 slog.Info("Request cancellation initiated", "session_id", sessionID)
708 cancel()
709 }
710
711 // Also check for summarize requests
712 if cancel, ok := a.activeRequests.Take(sessionID + "-summarize"); ok && cancel != nil {
713 slog.Info("Summarize cancellation initiated", "session_id", sessionID)
714 cancel()
715 }
716
717 if a.QueuedPrompts(sessionID) > 0 {
718 slog.Info("Clearing queued prompts", "session_id", sessionID)
719 a.messageQueue.Del(sessionID)
720 }
721}
722
723func (a *sessionAgent) ClearQueue(sessionID string) {
724 if a.QueuedPrompts(sessionID) > 0 {
725 slog.Info("Clearing queued prompts", "session_id", sessionID)
726 a.messageQueue.Del(sessionID)
727 }
728}
729
730func (a *sessionAgent) CancelAll() {
731 if !a.IsBusy() {
732 return
733 }
734 for key := range a.activeRequests.Seq2() {
735 a.Cancel(key) // key is sessionID
736 }
737
738 timeout := time.After(5 * time.Second)
739 for a.IsBusy() {
740 select {
741 case <-timeout:
742 return
743 default:
744 time.Sleep(200 * time.Millisecond)
745 }
746 }
747}
748
749func (a *sessionAgent) IsBusy() bool {
750 var busy bool
751 for cancelFunc := range a.activeRequests.Seq() {
752 if cancelFunc != nil {
753 busy = true
754 break
755 }
756 }
757 return busy
758}
759
760func (a *sessionAgent) IsSessionBusy(sessionID string) bool {
761 _, busy := a.activeRequests.Get(sessionID)
762 return busy
763}
764
765func (a *sessionAgent) QueuedPrompts(sessionID string) int {
766 l, ok := a.messageQueue.Get(sessionID)
767 if !ok {
768 return 0
769 }
770 return len(l)
771}
772
773func (a *sessionAgent) SetModels(large Model, small Model) {
774 a.largeModel = large
775 a.smallModel = small
776}
777
// SetTools replaces the tools offered to the model. The slice is stored as-is
// (not copied), and Run later sets provider options on its last element.
func (a *sessionAgent) SetTools(tools []fantasy.AgentTool) {
	a.tools = tools
}
781
// Model returns the large model used for prompts and summaries.
func (a *sessionAgent) Model() Model {
	return a.largeModel
}