1package agent
2
3import (
4 "context"
5 _ "embed"
6 "errors"
7 "fmt"
8 "log/slog"
9 "os"
10 "strconv"
11 "strings"
12 "sync"
13 "time"
14
15 "charm.land/fantasy"
16 "charm.land/fantasy/providers/anthropic"
17 "charm.land/fantasy/providers/bedrock"
18 "charm.land/fantasy/providers/google"
19 "charm.land/fantasy/providers/openai"
20 "github.com/charmbracelet/catwalk/pkg/catwalk"
21 "github.com/charmbracelet/crush/internal/agent/tools"
22 "github.com/charmbracelet/crush/internal/config"
23 "github.com/charmbracelet/crush/internal/csync"
24 "github.com/charmbracelet/crush/internal/message"
25 "github.com/charmbracelet/crush/internal/permission"
26 "github.com/charmbracelet/crush/internal/session"
27)
28
// titlePrompt is the system prompt generateTitle sends to the small model
// when asking for a session title (embedded from templates/title.md).
//
//go:embed templates/title.md
var titlePrompt []byte

// summaryPrompt is the system prompt Summarize sends to the large model when
// asking for a conversation summary (embedded from templates/summary.md).
//
//go:embed templates/summary.md
var summaryPrompt []byte
34
// SessionAgentCall is one prompt submitted to a SessionAgent. A call made
// while its session already has an active request is queued and replayed
// later, so the whole value is retained until then.
type SessionAgentCall struct {
	// SessionID identifies the target session; Run rejects an empty ID.
	SessionID string
	// Prompt is the user text; Run rejects an empty prompt.
	Prompt string
	// ProviderOptions are forwarded unchanged to the model call.
	ProviderOptions fantasy.ProviderOptions
	// Attachments are stored on the user message and, for this call only,
	// also sent to the model as file parts.
	Attachments []message.Attachment
	// MaxOutputTokens is forwarded by address even when it is zero.
	// NOTE(review): confirm fantasy treats a pointer to 0 as "no limit".
	MaxOutputTokens int64
	// Optional sampling parameters, forwarded as-is; nil presumably leaves
	// the provider default in place — confirm against fantasy.
	Temperature      *float64
	TopP             *float64
	TopK             *int64
	FrequencyPenalty *float64
	PresencePenalty  *float64
}
47
// SessionAgent runs prompts against a language model on behalf of chat
// sessions and persists the conversation through the session and message
// services. Prompts sent while a session already has an active request are
// queued instead of running concurrently.
type SessionAgent interface {
	// Run executes a prompt, or queues it if the session is busy.
	Run(context.Context, SessionAgentCall) (*fantasy.AgentResult, error)
	// SetModels replaces the large (main) and small (titles) models.
	SetModels(large Model, small Model)
	// SetTools replaces the tools offered to the model.
	SetTools(tools []fantasy.AgentTool)
	// Cancel aborts the session's active request and drops its queue.
	Cancel(sessionID string)
	// CancelAll cancels every active request and waits briefly for them.
	CancelAll()
	// IsSessionBusy reports whether the session has an active request.
	IsSessionBusy(sessionID string) bool
	// IsBusy reports whether any session has an active request.
	IsBusy() bool
	// QueuedPrompts returns how many prompts wait for the session.
	QueuedPrompts(sessionID string) int
	// ClearQueue drops the prompts waiting for the session.
	ClearQueue(sessionID string)
	// Summarize writes a model-generated summary of the session that later
	// runs start from.
	Summarize(context.Context, string, fantasy.ProviderOptions) error
	// Model returns the large model.
	Model() Model
}
61
// Model bundles a fantasy language model with its catwalk metadata (context
// window, per-token costs, reasoning support, default max tokens) and the
// selected-model entry from the config (model and provider identifiers).
type Model struct {
	Model      fantasy.LanguageModel
	CatwalkCfg catwalk.Model
	ModelCfg   config.SelectedModel
}
67
// sessionAgent is the SessionAgent implementation built by NewSessionAgent.
//
// NOTE(review): SetModels and SetTools assign largeModel, smallModel and tools
// without synchronization while Run reads them; callers presumably avoid
// swapping them mid-run — confirm.
type sessionAgent struct {
	// largeModel runs prompts and summaries; smallModel generates titles.
	largeModel   Model
	smallModel   Model
	systemPrompt string
	// tools are offered to the model on every Run, which also sets
	// cache-control provider options on the last tool.
	tools    []fantasy.AgentTool
	sessions session.Service
	messages message.Service
	// disableAutoSummarize turns off the context-window check that makes Run
	// stop and summarize the session.
	disableAutoSummarize bool

	// messageQueue holds prompts submitted while their session was busy,
	// keyed by session ID.
	messageQueue *csync.Map[string, []SessionAgentCall]
	// activeRequests maps a session ID to the cancel func of its in-flight
	// Run or Summarize; the presence of a key marks the session as busy.
	activeRequests *csync.Map[string, context.CancelFunc]
}
80
// SessionAgentOptions configures NewSessionAgent; each field is copied onto
// the new agent.
type SessionAgentOptions struct {
	// LargeModel runs prompts and summaries.
	LargeModel Model
	// SmallModel generates session titles.
	SmallModel Model
	// SystemPrompt is sent with every Run.
	SystemPrompt string
	// DisableAutoSummarize stops Run from summarizing a session whose token
	// usage exceeds 80% of the large model's context window.
	DisableAutoSummarize bool
	Sessions             session.Service
	Messages             message.Service
	Tools                []fantasy.AgentTool
}
90
91func NewSessionAgent(
92 opts SessionAgentOptions,
93) SessionAgent {
94 return &sessionAgent{
95 largeModel: opts.LargeModel,
96 smallModel: opts.SmallModel,
97 systemPrompt: opts.SystemPrompt,
98 sessions: opts.Sessions,
99 messages: opts.Messages,
100 disableAutoSummarize: opts.DisableAutoSummarize,
101 tools: opts.Tools,
102 messageQueue: csync.NewMap[string, []SessionAgentCall](),
103 activeRequests: csync.NewMap[string, context.CancelFunc](),
104 }
105}
106
107func (a *sessionAgent) Run(ctx context.Context, call SessionAgentCall) (*fantasy.AgentResult, error) {
108 if call.Prompt == "" {
109 return nil, ErrEmptyPrompt
110 }
111 if call.SessionID == "" {
112 return nil, ErrSessionMissing
113 }
114
115 // Queue the message if busy
116 if a.IsSessionBusy(call.SessionID) {
117 existing, ok := a.messageQueue.Get(call.SessionID)
118 if !ok {
119 existing = []SessionAgentCall{}
120 }
121 existing = append(existing, call)
122 a.messageQueue.Set(call.SessionID, existing)
123 return nil, nil
124 }
125
126 if len(a.tools) > 0 {
127 // add anthropic caching to the last tool
128 a.tools[len(a.tools)-1].SetProviderOptions(a.getCacheControlOptions())
129 }
130
131 agent := fantasy.NewAgent(
132 a.largeModel.Model,
133 fantasy.WithSystemPrompt(a.systemPrompt),
134 fantasy.WithTools(a.tools...),
135 )
136
137 sessionLock := sync.Mutex{}
138 currentSession, err := a.sessions.Get(ctx, call.SessionID)
139 if err != nil {
140 return nil, fmt.Errorf("failed to get session: %w", err)
141 }
142
143 msgs, err := a.getSessionMessages(ctx, currentSession)
144 if err != nil {
145 return nil, fmt.Errorf("failed to get session messages: %w", err)
146 }
147
148 var wg sync.WaitGroup
149 // Generate title if first message
150 if len(msgs) == 0 {
151 wg.Go(func() {
152 sessionLock.Lock()
153 a.generateTitle(ctx, ¤tSession, call.Prompt)
154 sessionLock.Unlock()
155 })
156 }
157
158 // Add the user message to the session
159 _, err = a.createUserMessage(ctx, call)
160 if err != nil {
161 return nil, err
162 }
163
164 // add the session to the context
165 ctx = context.WithValue(ctx, tools.SessionIDContextKey, call.SessionID)
166
167 genCtx, cancel := context.WithCancel(ctx)
168 a.activeRequests.Set(call.SessionID, cancel)
169
170 defer cancel()
171 defer a.activeRequests.Del(call.SessionID)
172
173 history, files := a.preparePrompt(msgs, call.Attachments...)
174
175 var currentAssistant *message.Message
176 var shouldSummarize bool
177 result, err := agent.Stream(genCtx, fantasy.AgentStreamCall{
178 Prompt: call.Prompt,
179 Files: files,
180 Messages: history,
181 ProviderOptions: call.ProviderOptions,
182 MaxOutputTokens: &call.MaxOutputTokens,
183 TopP: call.TopP,
184 Temperature: call.Temperature,
185 PresencePenalty: call.PresencePenalty,
186 TopK: call.TopK,
187 FrequencyPenalty: call.FrequencyPenalty,
188 // Before each step create the new assistant message
189 PrepareStep: func(callContext context.Context, options fantasy.PrepareStepFunctionOptions) (_ context.Context, prepared fantasy.PrepareStepResult, err error) {
190 prepared.Messages = options.Messages
191 // reset all cached items
192 for i := range prepared.Messages {
193 prepared.Messages[i].ProviderOptions = nil
194 }
195
196 queuedCalls, _ := a.messageQueue.Get(call.SessionID)
197 a.messageQueue.Del(call.SessionID)
198 for _, queued := range queuedCalls {
199 userMessage, createErr := a.createUserMessage(callContext, queued)
200 if createErr != nil {
201 return callContext, prepared, createErr
202 }
203 prepared.Messages = append(prepared.Messages, userMessage.ToAIMessage()...)
204 }
205
206 lastSystemRoleInx := 0
207 systemMessageUpdated := false
208 for i, msg := range prepared.Messages {
209 // only add cache control to the last message
210 if msg.Role == fantasy.MessageRoleSystem {
211 lastSystemRoleInx = i
212 } else if !systemMessageUpdated {
213 prepared.Messages[lastSystemRoleInx].ProviderOptions = a.getCacheControlOptions()
214 systemMessageUpdated = true
215 }
216 // than add cache control to the last 2 messages
217 if i > len(prepared.Messages)-3 {
218 prepared.Messages[i].ProviderOptions = a.getCacheControlOptions()
219 }
220 }
221
222 var assistantMsg message.Message
223 assistantMsg, err = a.messages.Create(callContext, call.SessionID, message.CreateMessageParams{
224 Role: message.Assistant,
225 Parts: []message.ContentPart{},
226 Model: a.largeModel.ModelCfg.Model,
227 Provider: a.largeModel.ModelCfg.Provider,
228 })
229 if err != nil {
230 return callContext, prepared, err
231 }
232 callContext = context.WithValue(callContext, tools.MessageIDContextKey, assistantMsg.ID)
233 currentAssistant = &assistantMsg
234 return callContext, prepared, err
235 },
236 OnReasoningDelta: func(id string, text string) error {
237 currentAssistant.AppendReasoningContent(text)
238 return a.messages.Update(genCtx, *currentAssistant)
239 },
240 OnReasoningEnd: func(id string, reasoning fantasy.ReasoningContent) error {
241 // handle anthropic signature
242 if anthropicData, ok := reasoning.ProviderMetadata[anthropic.Name]; ok {
243 if reasoning, ok := anthropicData.(*anthropic.ReasoningOptionMetadata); ok {
244 currentAssistant.AppendReasoningSignature(reasoning.Signature)
245 }
246 }
247 if googleData, ok := reasoning.ProviderMetadata[google.Name]; ok {
248 if reasoning, ok := googleData.(*google.ReasoningMetadata); ok {
249 currentAssistant.AppendReasoningSignature(reasoning.Signature)
250 }
251 }
252 if openaiData, ok := reasoning.ProviderMetadata[openai.Name]; ok {
253 if reasoning, ok := openaiData.(*openai.ResponsesReasoningMetadata); ok {
254 currentAssistant.SetReasoningResponsesData(reasoning)
255 }
256 }
257 currentAssistant.FinishThinking()
258 return a.messages.Update(genCtx, *currentAssistant)
259 },
260 OnTextDelta: func(id string, text string) error {
261 currentAssistant.AppendContent(text)
262 return a.messages.Update(genCtx, *currentAssistant)
263 },
264 OnToolInputStart: func(id string, toolName string) error {
265 toolCall := message.ToolCall{
266 ID: id,
267 Name: toolName,
268 ProviderExecuted: false,
269 Finished: false,
270 }
271 currentAssistant.AddToolCall(toolCall)
272 return a.messages.Update(genCtx, *currentAssistant)
273 },
274 OnRetry: func(err *fantasy.APICallError, delay time.Duration) {
275 // TODO: implement
276 },
277 OnToolCall: func(tc fantasy.ToolCallContent) error {
278 toolCall := message.ToolCall{
279 ID: tc.ToolCallID,
280 Name: tc.ToolName,
281 Input: tc.Input,
282 ProviderExecuted: false,
283 Finished: true,
284 }
285 currentAssistant.AddToolCall(toolCall)
286 return a.messages.Update(genCtx, *currentAssistant)
287 },
288 OnToolResult: func(result fantasy.ToolResultContent) error {
289 var resultContent string
290 isError := false
291 switch result.Result.GetType() {
292 case fantasy.ToolResultContentTypeText:
293 r, ok := fantasy.AsToolResultOutputType[fantasy.ToolResultOutputContentText](result.Result)
294 if ok {
295 resultContent = r.Text
296 }
297 case fantasy.ToolResultContentTypeError:
298 r, ok := fantasy.AsToolResultOutputType[fantasy.ToolResultOutputContentError](result.Result)
299 if ok {
300 isError = true
301 resultContent = r.Error.Error()
302 }
303 case fantasy.ToolResultContentTypeMedia:
304 // TODO: handle this message type
305 }
306 toolResult := message.ToolResult{
307 ToolCallID: result.ToolCallID,
308 Name: result.ToolName,
309 Content: resultContent,
310 IsError: isError,
311 Metadata: result.ClientMetadata,
312 }
313 _, createMsgErr := a.messages.Create(genCtx, currentAssistant.SessionID, message.CreateMessageParams{
314 Role: message.Tool,
315 Parts: []message.ContentPart{
316 toolResult,
317 },
318 })
319 if createMsgErr != nil {
320 return createMsgErr
321 }
322 return nil
323 },
324 OnStepFinish: func(stepResult fantasy.StepResult) error {
325 finishReason := message.FinishReasonUnknown
326 switch stepResult.FinishReason {
327 case fantasy.FinishReasonLength:
328 finishReason = message.FinishReasonMaxTokens
329 case fantasy.FinishReasonStop:
330 finishReason = message.FinishReasonEndTurn
331 case fantasy.FinishReasonToolCalls:
332 finishReason = message.FinishReasonToolUse
333 }
334 currentAssistant.AddFinish(finishReason, "", "")
335 a.updateSessionUsage(a.largeModel, ¤tSession, stepResult.Usage)
336 sessionLock.Lock()
337 _, sessionErr := a.sessions.Save(genCtx, currentSession)
338 sessionLock.Unlock()
339 if sessionErr != nil {
340 return sessionErr
341 }
342 return a.messages.Update(genCtx, *currentAssistant)
343 },
344 StopWhen: []fantasy.StopCondition{
345 func(_ []fantasy.StepResult) bool {
346 contextWindow := a.largeModel.CatwalkCfg.ContextWindow
347 tokens := currentSession.CompletionTokens + currentSession.PromptTokens
348 percentage := (float64(tokens) / float64(contextWindow)) * 100
349 if (percentage > 80) && !a.disableAutoSummarize {
350 shouldSummarize = true
351 return true
352 }
353 return false
354 },
355 },
356 })
357 if err != nil {
358 isCancelErr := errors.Is(err, context.Canceled)
359 isPermissionErr := errors.Is(err, permission.ErrorPermissionDenied)
360 if currentAssistant == nil {
361 return result, err
362 }
363 // Ensure we finish thinking on error to close the reasoning state
364 currentAssistant.FinishThinking()
365 toolCalls := currentAssistant.ToolCalls()
366 // INFO: we use the parent context here because the genCtx has been cancelled
367 msgs, createErr := a.messages.List(ctx, currentAssistant.SessionID)
368 if createErr != nil {
369 return nil, createErr
370 }
371 for _, tc := range toolCalls {
372 if !tc.Finished {
373 tc.Finished = true
374 tc.Input = "{}"
375 currentAssistant.AddToolCall(tc)
376 updateErr := a.messages.Update(ctx, *currentAssistant)
377 if updateErr != nil {
378 return nil, updateErr
379 }
380 }
381
382 found := false
383 for _, msg := range msgs {
384 if msg.Role == message.Tool {
385 for _, tr := range msg.ToolResults() {
386 if tr.ToolCallID == tc.ID {
387 found = true
388 break
389 }
390 }
391 }
392 if found {
393 break
394 }
395 }
396 if found {
397 continue
398 }
399 content := "There was an error while executing the tool"
400 if isCancelErr {
401 content = "Tool execution canceled by user"
402 } else if isPermissionErr {
403 content = "Permission denied"
404 }
405 toolResult := message.ToolResult{
406 ToolCallID: tc.ID,
407 Name: tc.Name,
408 Content: content,
409 IsError: true,
410 }
411 _, createErr = a.messages.Create(context.Background(), currentAssistant.SessionID, message.CreateMessageParams{
412 Role: message.Tool,
413 Parts: []message.ContentPart{
414 toolResult,
415 },
416 })
417 if createErr != nil {
418 return nil, createErr
419 }
420 }
421 if isCancelErr {
422 currentAssistant.AddFinish(message.FinishReasonCanceled, "Request cancelled", "")
423 } else if isPermissionErr {
424 currentAssistant.AddFinish(message.FinishReasonPermissionDenied, "Permission denied", "")
425 } else {
426 currentAssistant.AddFinish(message.FinishReasonError, "API Error", err.Error())
427 }
428 // INFO: we use the parent context here because the genCtx has been cancelled
429 updateErr := a.messages.Update(ctx, *currentAssistant)
430 if updateErr != nil {
431 return nil, updateErr
432 }
433 return nil, err
434 }
435 wg.Wait()
436
437 if shouldSummarize {
438 a.activeRequests.Del(call.SessionID)
439 if summarizeErr := a.Summarize(genCtx, call.SessionID, call.ProviderOptions); summarizeErr != nil {
440 return nil, summarizeErr
441 }
442 // if the agent was not done...
443 if len(currentAssistant.ToolCalls()) > 0 {
444 existing, ok := a.messageQueue.Get(call.SessionID)
445 if !ok {
446 existing = []SessionAgentCall{}
447 }
448 call.Prompt = fmt.Sprintf("The previous session was interrupted because it got too long, the initial user request was: `%s`", call.Prompt)
449 existing = append(existing, call)
450 a.messageQueue.Set(call.SessionID, existing)
451 }
452 }
453
454 // release active request before processing queued messages
455 a.activeRequests.Del(call.SessionID)
456 cancel()
457
458 queuedMessages, ok := a.messageQueue.Get(call.SessionID)
459 if !ok || len(queuedMessages) == 0 {
460 return result, err
461 }
462 // there are queued messages restart the loop
463 firstQueuedMessage := queuedMessages[0]
464 a.messageQueue.Set(call.SessionID, queuedMessages[1:])
465 return a.Run(ctx, firstQueuedMessage)
466}
467
468func (a *sessionAgent) Summarize(ctx context.Context, sessionID string, opts fantasy.ProviderOptions) error {
469 if a.IsSessionBusy(sessionID) {
470 return ErrSessionBusy
471 }
472
473 currentSession, err := a.sessions.Get(ctx, sessionID)
474 if err != nil {
475 return fmt.Errorf("failed to get session: %w", err)
476 }
477 msgs, err := a.getSessionMessages(ctx, currentSession)
478 if err != nil {
479 return err
480 }
481 if len(msgs) == 0 {
482 // nothing to summarize
483 return nil
484 }
485
486 aiMsgs, _ := a.preparePrompt(msgs)
487
488 genCtx, cancel := context.WithCancel(ctx)
489 a.activeRequests.Set(sessionID, cancel)
490 defer a.activeRequests.Del(sessionID)
491 defer cancel()
492
493 agent := fantasy.NewAgent(a.largeModel.Model,
494 fantasy.WithSystemPrompt(string(summaryPrompt)),
495 )
496 summaryMessage, err := a.messages.Create(ctx, sessionID, message.CreateMessageParams{
497 Role: message.Assistant,
498 Model: a.largeModel.Model.Model(),
499 Provider: a.largeModel.Model.Provider(),
500 IsSummaryMessage: true,
501 })
502 if err != nil {
503 return err
504 }
505
506 resp, err := agent.Stream(genCtx, fantasy.AgentStreamCall{
507 Prompt: "Provide a detailed summary of our conversation above.",
508 Messages: aiMsgs,
509 ProviderOptions: opts,
510 OnReasoningDelta: func(id string, text string) error {
511 summaryMessage.AppendReasoningContent(text)
512 return a.messages.Update(genCtx, summaryMessage)
513 },
514 OnReasoningEnd: func(id string, reasoning fantasy.ReasoningContent) error {
515 // handle anthropic signature
516 if anthropicData, ok := reasoning.ProviderMetadata["anthropic"]; ok {
517 if signature, ok := anthropicData.(*anthropic.ReasoningOptionMetadata); ok && signature.Signature != "" {
518 summaryMessage.AppendReasoningSignature(signature.Signature)
519 }
520 }
521 summaryMessage.FinishThinking()
522 return a.messages.Update(genCtx, summaryMessage)
523 },
524 OnTextDelta: func(id, text string) error {
525 summaryMessage.AppendContent(text)
526 return a.messages.Update(genCtx, summaryMessage)
527 },
528 })
529 if err != nil {
530 isCancelErr := errors.Is(err, context.Canceled)
531 if isCancelErr {
532 // User cancelled summarize we need to remove the summary message
533 deleteErr := a.messages.Delete(ctx, summaryMessage.ID)
534 return deleteErr
535 }
536 return err
537 }
538
539 summaryMessage.AddFinish(message.FinishReasonEndTurn, "", "")
540 err = a.messages.Update(genCtx, summaryMessage)
541 if err != nil {
542 return err
543 }
544
545 a.updateSessionUsage(a.largeModel, ¤tSession, resp.TotalUsage)
546
547 // just in case get just the last usage
548 usage := resp.Response.Usage
549 currentSession.SummaryMessageID = summaryMessage.ID
550 currentSession.CompletionTokens = usage.OutputTokens
551 currentSession.PromptTokens = 0
552 _, err = a.sessions.Save(genCtx, currentSession)
553 return err
554}
555
556func (a *sessionAgent) getCacheControlOptions() fantasy.ProviderOptions {
557 if t, _ := strconv.ParseBool(os.Getenv("CRUSH_DISABLE_ANTHROPIC_CACHE")); t {
558 return fantasy.ProviderOptions{}
559 }
560 return fantasy.ProviderOptions{
561 anthropic.Name: &anthropic.ProviderCacheControlOptions{
562 CacheControl: anthropic.CacheControl{Type: "ephemeral"},
563 },
564 bedrock.Name: &anthropic.ProviderCacheControlOptions{
565 CacheControl: anthropic.CacheControl{Type: "ephemeral"},
566 },
567 }
568}
569
570func (a *sessionAgent) createUserMessage(ctx context.Context, call SessionAgentCall) (message.Message, error) {
571 var attachmentParts []message.ContentPart
572 for _, attachment := range call.Attachments {
573 attachmentParts = append(attachmentParts, message.BinaryContent{Path: attachment.FilePath, MIMEType: attachment.MimeType, Data: attachment.Content})
574 }
575 parts := []message.ContentPart{message.TextContent{Text: call.Prompt}}
576 parts = append(parts, attachmentParts...)
577 msg, err := a.messages.Create(ctx, call.SessionID, message.CreateMessageParams{
578 Role: message.User,
579 Parts: parts,
580 })
581 if err != nil {
582 return message.Message{}, fmt.Errorf("failed to create user message: %w", err)
583 }
584 return msg, nil
585}
586
587func (a *sessionAgent) preparePrompt(msgs []message.Message, attachments ...message.Attachment) ([]fantasy.Message, []fantasy.FilePart) {
588 var history []fantasy.Message
589 for _, m := range msgs {
590 if len(m.Parts) == 0 {
591 continue
592 }
593 // Assistant message without content or tool calls (cancelled before it returned anything)
594 if m.Role == message.Assistant && len(m.ToolCalls()) == 0 && m.Content().Text == "" && m.ReasoningContent().String() == "" {
595 continue
596 }
597 history = append(history, m.ToAIMessage()...)
598 }
599
600 var files []fantasy.FilePart
601 for _, attachment := range attachments {
602 files = append(files, fantasy.FilePart{
603 Filename: attachment.FileName,
604 Data: attachment.Content,
605 MediaType: attachment.MimeType,
606 })
607 }
608
609 return history, files
610}
611
612func (a *sessionAgent) getSessionMessages(ctx context.Context, session session.Session) ([]message.Message, error) {
613 msgs, err := a.messages.List(ctx, session.ID)
614 if err != nil {
615 return nil, fmt.Errorf("failed to list messages: %w", err)
616 }
617
618 if session.SummaryMessageID != "" {
619 summaryMsgInex := -1
620 for i, msg := range msgs {
621 if msg.ID == session.SummaryMessageID {
622 summaryMsgInex = i
623 break
624 }
625 }
626 if summaryMsgInex != -1 {
627 msgs = msgs[summaryMsgInex:]
628 msgs[0].Role = message.User
629 }
630 }
631 return msgs, nil
632}
633
634func (a *sessionAgent) generateTitle(ctx context.Context, session *session.Session, prompt string) {
635 if prompt == "" {
636 return
637 }
638
639 var maxOutput int64 = 40
640 if a.smallModel.CatwalkCfg.CanReason {
641 maxOutput = a.smallModel.CatwalkCfg.DefaultMaxTokens
642 }
643
644 agent := fantasy.NewAgent(a.smallModel.Model,
645 fantasy.WithSystemPrompt(string(titlePrompt)+"\n /no_think"),
646 fantasy.WithMaxOutputTokens(maxOutput),
647 )
648
649 resp, err := agent.Stream(ctx, fantasy.AgentStreamCall{
650 Prompt: fmt.Sprintf("Generate a concise title for the following content:\n\n%s\n <think>\n\n</think>", prompt),
651 })
652 if err != nil {
653 slog.Error("error generating title", "err", err)
654 return
655 }
656
657 title := resp.Response.Content.Text()
658
659 title = strings.ReplaceAll(title, "\n", " ")
660
661 // remove thinking tags if present
662 if idx := strings.Index(title, "</think>"); idx > 0 {
663 title = title[idx+len("</think>"):]
664 }
665
666 title = strings.TrimSpace(title)
667 if title == "" {
668 slog.Warn("failed to generate title", "warn", "empty title")
669 return
670 }
671
672 session.Title = title
673 a.updateSessionUsage(a.smallModel, session, resp.TotalUsage)
674 _, saveErr := a.sessions.Save(ctx, *session)
675 if saveErr != nil {
676 slog.Error("failed to save session title & usage", "error", saveErr)
677 return
678 }
679}
680
681func (a *sessionAgent) updateSessionUsage(model Model, session *session.Session, usage fantasy.Usage) {
682 modelConfig := model.CatwalkCfg
683 cost := modelConfig.CostPer1MInCached/1e6*float64(usage.CacheCreationTokens) +
684 modelConfig.CostPer1MOutCached/1e6*float64(usage.CacheReadTokens) +
685 modelConfig.CostPer1MIn/1e6*float64(usage.InputTokens) +
686 modelConfig.CostPer1MOut/1e6*float64(usage.OutputTokens)
687 session.Cost += cost
688 session.CompletionTokens = usage.OutputTokens + usage.CacheReadTokens
689 session.PromptTokens = usage.InputTokens + usage.CacheCreationTokens
690}
691
692func (a *sessionAgent) Cancel(sessionID string) {
693 // Cancel regular requests
694 if cancel, ok := a.activeRequests.Take(sessionID); ok && cancel != nil {
695 slog.Info("Request cancellation initiated", "session_id", sessionID)
696 cancel()
697 }
698
699 // Also check for summarize requests
700 if cancel, ok := a.activeRequests.Take(sessionID + "-summarize"); ok && cancel != nil {
701 slog.Info("Summarize cancellation initiated", "session_id", sessionID)
702 cancel()
703 }
704
705 if a.QueuedPrompts(sessionID) > 0 {
706 slog.Info("Clearing queued prompts", "session_id", sessionID)
707 a.messageQueue.Del(sessionID)
708 }
709}
710
711func (a *sessionAgent) ClearQueue(sessionID string) {
712 if a.QueuedPrompts(sessionID) > 0 {
713 slog.Info("Clearing queued prompts", "session_id", sessionID)
714 a.messageQueue.Del(sessionID)
715 }
716}
717
718func (a *sessionAgent) CancelAll() {
719 if !a.IsBusy() {
720 return
721 }
722 for key := range a.activeRequests.Seq2() {
723 a.Cancel(key) // key is sessionID
724 }
725
726 timeout := time.After(5 * time.Second)
727 for a.IsBusy() {
728 select {
729 case <-timeout:
730 return
731 default:
732 time.Sleep(200 * time.Millisecond)
733 }
734 }
735}
736
737func (a *sessionAgent) IsBusy() bool {
738 var busy bool
739 for cancelFunc := range a.activeRequests.Seq() {
740 if cancelFunc != nil {
741 busy = true
742 break
743 }
744 }
745 return busy
746}
747
748func (a *sessionAgent) IsSessionBusy(sessionID string) bool {
749 _, busy := a.activeRequests.Get(sessionID)
750 return busy
751}
752
753func (a *sessionAgent) QueuedPrompts(sessionID string) int {
754 l, ok := a.messageQueue.Get(sessionID)
755 if !ok {
756 return 0
757 }
758 return len(l)
759}
760
761func (a *sessionAgent) SetModels(large Model, small Model) {
762 a.largeModel = large
763 a.smallModel = small
764}
765
766func (a *sessionAgent) SetTools(tools []fantasy.AgentTool) {
767 a.tools = tools
768}
769
770func (a *sessionAgent) Model() Model {
771 return a.largeModel
772}