1package agent
  2
  3import (
  4	"context"
  5	_ "embed"
  6	"errors"
  7	"fmt"
  8	"log/slog"
  9	"os"
 10	"strconv"
 11	"strings"
 12	"sync"
 13	"time"
 14
 15	"charm.land/fantasy"
 16	"charm.land/fantasy/providers/anthropic"
 17	"charm.land/fantasy/providers/bedrock"
 18	"charm.land/fantasy/providers/google"
 19	"charm.land/fantasy/providers/openai"
 20	"github.com/charmbracelet/catwalk/pkg/catwalk"
 21	"github.com/charmbracelet/crush/internal/agent/tools"
 22	"github.com/charmbracelet/crush/internal/config"
 23	"github.com/charmbracelet/crush/internal/csync"
 24	"github.com/charmbracelet/crush/internal/message"
 25	"github.com/charmbracelet/crush/internal/permission"
 26	"github.com/charmbracelet/crush/internal/session"
 27)
 28
// titlePrompt is the system prompt given to the small model when it is asked
// to generate a session title.
//
//go:embed templates/title.md
var titlePrompt []byte

// summaryPrompt is the system prompt given to the large model when it is asked
// to summarize a session's conversation.
//
//go:embed templates/summary.md
var summaryPrompt []byte
 34
// SessionAgentCall is a single prompt submitted to a SessionAgent together
// with the per-call generation settings that are forwarded to the model.
type SessionAgentCall struct {
	// SessionID identifies the target session; Run rejects an empty value
	// with ErrSessionMissing.
	SessionID        string
	// Prompt is the user's text; Run rejects an empty value with
	// ErrEmptyPrompt. It is stored as the text part of the user message.
	Prompt           string
	// ProviderOptions are passed to the model stream (and to Summarize when
	// Run triggers an automatic summary).
	ProviderOptions  fantasy.ProviderOptions
	// Attachments are stored as binary parts on the user message and sent to
	// the model as files.
	Attachments      []message.Attachment
	// MaxOutputTokens is forwarded by pointer on every call.
	// NOTE(review): a zero value is still forwarded, not omitted — confirm how
	// the provider treats 0.
	MaxOutputTokens  int64
	// The optional sampling parameters below are forwarded unchanged; nil
	// presumably leaves the provider default in place.
	Temperature      *float64
	TopP             *float64
	TopK             *int64
	FrequencyPenalty *float64
	PresencePenalty  *float64
}
 47
 48type SessionAgent interface {
 49	Run(context.Context, SessionAgentCall) (*fantasy.AgentResult, error)
 50	SetModels(large Model, small Model)
 51	SetTools(tools []fantasy.AgentTool)
 52	Cancel(sessionID string)
 53	CancelAll()
 54	IsSessionBusy(sessionID string) bool
 55	IsBusy() bool
 56	QueuedPrompts(sessionID string) int
 57	ClearQueue(sessionID string)
 58	Summarize(context.Context, string, fantasy.ProviderOptions) error
 59	Model() Model
 60}
 61
// Model bundles a language model with the metadata the agent needs about it.
type Model struct {
	// Model is the fantasy model used to generate responses.
	Model      fantasy.LanguageModel
	// CatwalkCfg supplies capabilities and pricing: context window,
	// per-million-token costs, reasoning support and default max tokens.
	CatwalkCfg catwalk.Model
	// ModelCfg is the selected model configuration; its Model and Provider
	// are recorded on the assistant messages Run creates.
	ModelCfg   config.SelectedModel
}
 67
// sessionAgent is the SessionAgent implementation returned by NewSessionAgent.
type sessionAgent struct {
	// largeModel runs conversation turns and summaries.
	largeModel           Model
	// smallModel generates session titles.
	smallModel           Model
	systemPrompt         string
	// tools are offered to the large model; Run attaches cache-control
	// provider options to the last one.
	tools                []fantasy.AgentTool
	sessions             session.Service
	messages             message.Service
	// disableAutoSummarize turns off the automatic summary Run triggers when
	// more than 80% of the context window is used.
	disableAutoSummarize bool

	// messageQueue holds prompts submitted while their session was busy,
	// keyed by session ID.
	messageQueue   *csync.Map[string, []SessionAgentCall]
	// activeRequests maps a request key (normally the session ID) to the
	// cancel function of the in-flight request; presence means "busy".
	activeRequests *csync.Map[string, context.CancelFunc]
}
 80
// SessionAgentOptions configures NewSessionAgent.
type SessionAgentOptions struct {
	// LargeModel runs conversation turns and summaries.
	LargeModel           Model
	// SmallModel generates session titles.
	SmallModel           Model
	// SystemPrompt is the system prompt for conversation turns.
	SystemPrompt         string
	// DisableAutoSummarize turns off automatic summaries near the context limit.
	DisableAutoSummarize bool
	// Sessions loads and saves session records (title, usage, summary ID).
	Sessions             session.Service
	// Messages stores user, assistant and tool messages.
	Messages             message.Service
	// Tools are offered to the large model.
	Tools                []fantasy.AgentTool
}
 90
 91func NewSessionAgent(
 92	opts SessionAgentOptions,
 93) SessionAgent {
 94	return &sessionAgent{
 95		largeModel:           opts.LargeModel,
 96		smallModel:           opts.SmallModel,
 97		systemPrompt:         opts.SystemPrompt,
 98		sessions:             opts.Sessions,
 99		messages:             opts.Messages,
100		disableAutoSummarize: opts.DisableAutoSummarize,
101		tools:                opts.Tools,
102		messageQueue:         csync.NewMap[string, []SessionAgentCall](),
103		activeRequests:       csync.NewMap[string, context.CancelFunc](),
104	}
105}
106
107func (a *sessionAgent) Run(ctx context.Context, call SessionAgentCall) (*fantasy.AgentResult, error) {
108	if call.Prompt == "" {
109		return nil, ErrEmptyPrompt
110	}
111	if call.SessionID == "" {
112		return nil, ErrSessionMissing
113	}
114
115	// Queue the message if busy
116	if a.IsSessionBusy(call.SessionID) {
117		existing, ok := a.messageQueue.Get(call.SessionID)
118		if !ok {
119			existing = []SessionAgentCall{}
120		}
121		existing = append(existing, call)
122		a.messageQueue.Set(call.SessionID, existing)
123		return nil, nil
124	}
125
126	if len(a.tools) > 0 {
127		// add anthropic caching to the last tool
128		a.tools[len(a.tools)-1].SetProviderOptions(a.getCacheControlOptions())
129	}
130
131	agent := fantasy.NewAgent(
132		a.largeModel.Model,
133		fantasy.WithSystemPrompt(a.systemPrompt),
134		fantasy.WithTools(a.tools...),
135	)
136
137	sessionLock := sync.Mutex{}
138	currentSession, err := a.sessions.Get(ctx, call.SessionID)
139	if err != nil {
140		return nil, fmt.Errorf("failed to get session: %w", err)
141	}
142
143	msgs, err := a.getSessionMessages(ctx, currentSession)
144	if err != nil {
145		return nil, fmt.Errorf("failed to get session messages: %w", err)
146	}
147
148	var wg sync.WaitGroup
149	// Generate title if first message
150	if len(msgs) == 0 {
151		wg.Go(func() {
152			sessionLock.Lock()
153			a.generateTitle(ctx, ¤tSession, call.Prompt)
154			sessionLock.Unlock()
155		})
156	}
157
158	// Add the user message to the session
159	_, err = a.createUserMessage(ctx, call)
160	if err != nil {
161		return nil, err
162	}
163
164	// add the session to the context
165	ctx = context.WithValue(ctx, tools.SessionIDContextKey, call.SessionID)
166
167	genCtx, cancel := context.WithCancel(ctx)
168	a.activeRequests.Set(call.SessionID, cancel)
169
170	defer cancel()
171	defer a.activeRequests.Del(call.SessionID)
172
173	history, files := a.preparePrompt(msgs, call.Attachments...)
174
175	var currentAssistant *message.Message
176	var shouldSummarize bool
177	result, err := agent.Stream(genCtx, fantasy.AgentStreamCall{
178		Prompt:           call.Prompt,
179		Files:            files,
180		Messages:         history,
181		ProviderOptions:  call.ProviderOptions,
182		MaxOutputTokens:  &call.MaxOutputTokens,
183		TopP:             call.TopP,
184		Temperature:      call.Temperature,
185		PresencePenalty:  call.PresencePenalty,
186		TopK:             call.TopK,
187		FrequencyPenalty: call.FrequencyPenalty,
188		// Before each step create the new assistant message
189		PrepareStep: func(callContext context.Context, options fantasy.PrepareStepFunctionOptions) (_ context.Context, prepared fantasy.PrepareStepResult, err error) {
190			prepared.Messages = options.Messages
191			// reset all cached items
192			for i := range prepared.Messages {
193				prepared.Messages[i].ProviderOptions = nil
194			}
195
196			queuedCalls, _ := a.messageQueue.Get(call.SessionID)
197			a.messageQueue.Del(call.SessionID)
198			for _, queued := range queuedCalls {
199				userMessage, createErr := a.createUserMessage(callContext, queued)
200				if createErr != nil {
201					return callContext, prepared, createErr
202				}
203				prepared.Messages = append(prepared.Messages, userMessage.ToAIMessage()...)
204			}
205
206			lastSystemRoleInx := 0
207			systemMessageUpdated := false
208			for i, msg := range prepared.Messages {
209				// only add cache control to the last message
210				if msg.Role == fantasy.MessageRoleSystem {
211					lastSystemRoleInx = i
212				} else if !systemMessageUpdated {
213					prepared.Messages[lastSystemRoleInx].ProviderOptions = a.getCacheControlOptions()
214					systemMessageUpdated = true
215				}
216				// than add cache control to the last 2 messages
217				if i > len(prepared.Messages)-3 {
218					prepared.Messages[i].ProviderOptions = a.getCacheControlOptions()
219				}
220			}
221
222			var assistantMsg message.Message
223			assistantMsg, err = a.messages.Create(callContext, call.SessionID, message.CreateMessageParams{
224				Role:     message.Assistant,
225				Parts:    []message.ContentPart{},
226				Model:    a.largeModel.ModelCfg.Model,
227				Provider: a.largeModel.ModelCfg.Provider,
228			})
229			if err != nil {
230				return callContext, prepared, err
231			}
232			callContext = context.WithValue(callContext, tools.MessageIDContextKey, assistantMsg.ID)
233			currentAssistant = &assistantMsg
234			return callContext, prepared, err
235		},
236		OnReasoningDelta: func(id string, text string) error {
237			currentAssistant.AppendReasoningContent(text)
238			return a.messages.Update(genCtx, *currentAssistant)
239		},
240		OnReasoningEnd: func(id string, reasoning fantasy.ReasoningContent) error {
241			// handle anthropic signature
242			if anthropicData, ok := reasoning.ProviderMetadata[anthropic.Name]; ok {
243				if reasoning, ok := anthropicData.(*anthropic.ReasoningOptionMetadata); ok {
244					currentAssistant.AppendReasoningSignature(reasoning.Signature)
245				}
246			}
247			if googleData, ok := reasoning.ProviderMetadata[google.Name]; ok {
248				if reasoning, ok := googleData.(*google.ReasoningMetadata); ok {
249					currentAssistant.AppendReasoningSignature(reasoning.Signature)
250				}
251			}
252			if openaiData, ok := reasoning.ProviderMetadata[openai.Name]; ok {
253				if reasoning, ok := openaiData.(*openai.ResponsesReasoningMetadata); ok {
254					currentAssistant.SetReasoningResponsesData(reasoning)
255				}
256			}
257			currentAssistant.FinishThinking()
258			return a.messages.Update(genCtx, *currentAssistant)
259		},
260		OnTextDelta: func(id string, text string) error {
261			currentAssistant.AppendContent(text)
262			return a.messages.Update(genCtx, *currentAssistant)
263		},
264		OnToolInputStart: func(id string, toolName string) error {
265			toolCall := message.ToolCall{
266				ID:               id,
267				Name:             toolName,
268				ProviderExecuted: false,
269				Finished:         false,
270			}
271			currentAssistant.AddToolCall(toolCall)
272			return a.messages.Update(genCtx, *currentAssistant)
273		},
274		OnRetry: func(err *fantasy.APICallError, delay time.Duration) {
275			// TODO: implement
276		},
277		OnToolCall: func(tc fantasy.ToolCallContent) error {
278			toolCall := message.ToolCall{
279				ID:               tc.ToolCallID,
280				Name:             tc.ToolName,
281				Input:            tc.Input,
282				ProviderExecuted: false,
283				Finished:         true,
284			}
285			currentAssistant.AddToolCall(toolCall)
286			return a.messages.Update(genCtx, *currentAssistant)
287		},
288		OnToolResult: func(result fantasy.ToolResultContent) error {
289			var resultContent string
290			isError := false
291			switch result.Result.GetType() {
292			case fantasy.ToolResultContentTypeText:
293				r, ok := fantasy.AsToolResultOutputType[fantasy.ToolResultOutputContentText](result.Result)
294				if ok {
295					resultContent = r.Text
296				}
297			case fantasy.ToolResultContentTypeError:
298				r, ok := fantasy.AsToolResultOutputType[fantasy.ToolResultOutputContentError](result.Result)
299				if ok {
300					isError = true
301					resultContent = r.Error.Error()
302				}
303			case fantasy.ToolResultContentTypeMedia:
304				// TODO: handle this message type
305			}
306			toolResult := message.ToolResult{
307				ToolCallID: result.ToolCallID,
308				Name:       result.ToolName,
309				Content:    resultContent,
310				IsError:    isError,
311				Metadata:   result.ClientMetadata,
312			}
313			_, createMsgErr := a.messages.Create(genCtx, currentAssistant.SessionID, message.CreateMessageParams{
314				Role: message.Tool,
315				Parts: []message.ContentPart{
316					toolResult,
317				},
318			})
319			if createMsgErr != nil {
320				return createMsgErr
321			}
322			return nil
323		},
324		OnStepFinish: func(stepResult fantasy.StepResult) error {
325			finishReason := message.FinishReasonUnknown
326			switch stepResult.FinishReason {
327			case fantasy.FinishReasonLength:
328				finishReason = message.FinishReasonMaxTokens
329			case fantasy.FinishReasonStop:
330				finishReason = message.FinishReasonEndTurn
331			case fantasy.FinishReasonToolCalls:
332				finishReason = message.FinishReasonToolUse
333			}
334			currentAssistant.AddFinish(finishReason, "", "")
335			a.updateSessionUsage(a.largeModel, ¤tSession, stepResult.Usage)
336			sessionLock.Lock()
337			_, sessionErr := a.sessions.Save(genCtx, currentSession)
338			sessionLock.Unlock()
339			if sessionErr != nil {
340				return sessionErr
341			}
342			return a.messages.Update(genCtx, *currentAssistant)
343		},
344		StopWhen: []fantasy.StopCondition{
345			func(_ []fantasy.StepResult) bool {
346				contextWindow := a.largeModel.CatwalkCfg.ContextWindow
347				tokens := currentSession.CompletionTokens + currentSession.PromptTokens
348				percentage := (float64(tokens) / float64(contextWindow)) * 100
349				if (percentage > 80) && !a.disableAutoSummarize {
350					shouldSummarize = true
351					return true
352				}
353				return false
354			},
355		},
356	})
357	if err != nil {
358		isCancelErr := errors.Is(err, context.Canceled)
359		isPermissionErr := errors.Is(err, permission.ErrorPermissionDenied)
360		if currentAssistant == nil {
361			return result, err
362		}
363		// Ensure we finish thinking on error to close the reasoning state
364		currentAssistant.FinishThinking()
365		toolCalls := currentAssistant.ToolCalls()
366		// INFO: we use the parent context here because the genCtx has been cancelled
367		msgs, createErr := a.messages.List(ctx, currentAssistant.SessionID)
368		if createErr != nil {
369			return nil, createErr
370		}
371		for _, tc := range toolCalls {
372			if !tc.Finished {
373				tc.Finished = true
374				tc.Input = "{}"
375				currentAssistant.AddToolCall(tc)
376				updateErr := a.messages.Update(ctx, *currentAssistant)
377				if updateErr != nil {
378					return nil, updateErr
379				}
380			}
381
382			found := false
383			for _, msg := range msgs {
384				if msg.Role == message.Tool {
385					for _, tr := range msg.ToolResults() {
386						if tr.ToolCallID == tc.ID {
387							found = true
388							break
389						}
390					}
391				}
392				if found {
393					break
394				}
395			}
396			if found {
397				continue
398			}
399			content := "There was an error while executing the tool"
400			if isCancelErr {
401				content = "Tool execution canceled by user"
402			} else if isPermissionErr {
403				content = "Permission denied"
404			}
405			toolResult := message.ToolResult{
406				ToolCallID: tc.ID,
407				Name:       tc.Name,
408				Content:    content,
409				IsError:    true,
410			}
411			_, createErr = a.messages.Create(context.Background(), currentAssistant.SessionID, message.CreateMessageParams{
412				Role: message.Tool,
413				Parts: []message.ContentPart{
414					toolResult,
415				},
416			})
417			if createErr != nil {
418				return nil, createErr
419			}
420		}
421		if isCancelErr {
422			currentAssistant.AddFinish(message.FinishReasonCanceled, "Request cancelled", "")
423		} else if isPermissionErr {
424			currentAssistant.AddFinish(message.FinishReasonPermissionDenied, "Permission denied", "")
425		} else {
426			currentAssistant.AddFinish(message.FinishReasonError, "API Error", err.Error())
427		}
428		// INFO: we use the parent context here because the genCtx has been cancelled
429		updateErr := a.messages.Update(ctx, *currentAssistant)
430		if updateErr != nil {
431			return nil, updateErr
432		}
433		return nil, err
434	}
435	wg.Wait()
436
437	if shouldSummarize {
438		a.activeRequests.Del(call.SessionID)
439		if summarizeErr := a.Summarize(genCtx, call.SessionID, call.ProviderOptions); summarizeErr != nil {
440			return nil, summarizeErr
441		}
442		// if the agent was not done...
443		if len(currentAssistant.ToolCalls()) > 0 {
444			existing, ok := a.messageQueue.Get(call.SessionID)
445			if !ok {
446				existing = []SessionAgentCall{}
447			}
448			call.Prompt = fmt.Sprintf("The previous session was interrupted because it got too long, the initial user request was: `%s`", call.Prompt)
449			existing = append(existing, call)
450			a.messageQueue.Set(call.SessionID, existing)
451		}
452	}
453
454	// release active request before processing queued messages
455	a.activeRequests.Del(call.SessionID)
456	cancel()
457
458	queuedMessages, ok := a.messageQueue.Get(call.SessionID)
459	if !ok || len(queuedMessages) == 0 {
460		return result, err
461	}
462	// there are queued messages restart the loop
463	firstQueuedMessage := queuedMessages[0]
464	a.messageQueue.Set(call.SessionID, queuedMessages[1:])
465	return a.Run(ctx, firstQueuedMessage)
466}
467
468func (a *sessionAgent) Summarize(ctx context.Context, sessionID string, opts fantasy.ProviderOptions) error {
469	if a.IsSessionBusy(sessionID) {
470		return ErrSessionBusy
471	}
472
473	currentSession, err := a.sessions.Get(ctx, sessionID)
474	if err != nil {
475		return fmt.Errorf("failed to get session: %w", err)
476	}
477	msgs, err := a.getSessionMessages(ctx, currentSession)
478	if err != nil {
479		return err
480	}
481	if len(msgs) == 0 {
482		// nothing to summarize
483		return nil
484	}
485
486	aiMsgs, _ := a.preparePrompt(msgs)
487
488	genCtx, cancel := context.WithCancel(ctx)
489	a.activeRequests.Set(sessionID, cancel)
490	defer a.activeRequests.Del(sessionID)
491	defer cancel()
492
493	agent := fantasy.NewAgent(a.largeModel.Model,
494		fantasy.WithSystemPrompt(string(summaryPrompt)),
495	)
496	summaryMessage, err := a.messages.Create(ctx, sessionID, message.CreateMessageParams{
497		Role:             message.Assistant,
498		Model:            a.largeModel.Model.Model(),
499		Provider:         a.largeModel.Model.Provider(),
500		IsSummaryMessage: true,
501	})
502	if err != nil {
503		return err
504	}
505
506	resp, err := agent.Stream(genCtx, fantasy.AgentStreamCall{
507		Prompt:          "Provide a detailed summary of our conversation above.",
508		Messages:        aiMsgs,
509		ProviderOptions: opts,
510		OnReasoningDelta: func(id string, text string) error {
511			summaryMessage.AppendReasoningContent(text)
512			return a.messages.Update(genCtx, summaryMessage)
513		},
514		OnReasoningEnd: func(id string, reasoning fantasy.ReasoningContent) error {
515			// handle anthropic signature
516			if anthropicData, ok := reasoning.ProviderMetadata["anthropic"]; ok {
517				if signature, ok := anthropicData.(*anthropic.ReasoningOptionMetadata); ok && signature.Signature != "" {
518					summaryMessage.AppendReasoningSignature(signature.Signature)
519				}
520			}
521			summaryMessage.FinishThinking()
522			return a.messages.Update(genCtx, summaryMessage)
523		},
524		OnTextDelta: func(id, text string) error {
525			summaryMessage.AppendContent(text)
526			return a.messages.Update(genCtx, summaryMessage)
527		},
528	})
529	if err != nil {
530		isCancelErr := errors.Is(err, context.Canceled)
531		if isCancelErr {
532			// User cancelled summarize we need to remove the summary message
533			deleteErr := a.messages.Delete(ctx, summaryMessage.ID)
534			return deleteErr
535		}
536		return err
537	}
538
539	summaryMessage.AddFinish(message.FinishReasonEndTurn, "", "")
540	err = a.messages.Update(genCtx, summaryMessage)
541	if err != nil {
542		return err
543	}
544
545	a.updateSessionUsage(a.largeModel, ¤tSession, resp.TotalUsage)
546
547	// just in case get just the last usage
548	usage := resp.Response.Usage
549	currentSession.SummaryMessageID = summaryMessage.ID
550	currentSession.CompletionTokens = usage.OutputTokens
551	currentSession.PromptTokens = 0
552	_, err = a.sessions.Save(genCtx, currentSession)
553	return err
554}
555
556func (a *sessionAgent) getCacheControlOptions() fantasy.ProviderOptions {
557	if t, _ := strconv.ParseBool(os.Getenv("CRUSH_DISABLE_ANTHROPIC_CACHE")); t {
558		return fantasy.ProviderOptions{}
559	}
560	return fantasy.ProviderOptions{
561		anthropic.Name: &anthropic.ProviderCacheControlOptions{
562			CacheControl: anthropic.CacheControl{Type: "ephemeral"},
563		},
564		bedrock.Name: &anthropic.ProviderCacheControlOptions{
565			CacheControl: anthropic.CacheControl{Type: "ephemeral"},
566		},
567	}
568}
569
570func (a *sessionAgent) createUserMessage(ctx context.Context, call SessionAgentCall) (message.Message, error) {
571	var attachmentParts []message.ContentPart
572	for _, attachment := range call.Attachments {
573		attachmentParts = append(attachmentParts, message.BinaryContent{Path: attachment.FilePath, MIMEType: attachment.MimeType, Data: attachment.Content})
574	}
575	parts := []message.ContentPart{message.TextContent{Text: call.Prompt}}
576	parts = append(parts, attachmentParts...)
577	msg, err := a.messages.Create(ctx, call.SessionID, message.CreateMessageParams{
578		Role:  message.User,
579		Parts: parts,
580	})
581	if err != nil {
582		return message.Message{}, fmt.Errorf("failed to create user message: %w", err)
583	}
584	return msg, nil
585}
586
587func (a *sessionAgent) preparePrompt(msgs []message.Message, attachments ...message.Attachment) ([]fantasy.Message, []fantasy.FilePart) {
588	var history []fantasy.Message
589	for _, m := range msgs {
590		if len(m.Parts) == 0 {
591			continue
592		}
593		// Assistant message without content or tool calls (cancelled before it returned anything)
594		if m.Role == message.Assistant && len(m.ToolCalls()) == 0 && m.Content().Text == "" && m.ReasoningContent().String() == "" {
595			continue
596		}
597		history = append(history, m.ToAIMessage()...)
598	}
599
600	var files []fantasy.FilePart
601	for _, attachment := range attachments {
602		files = append(files, fantasy.FilePart{
603			Filename:  attachment.FileName,
604			Data:      attachment.Content,
605			MediaType: attachment.MimeType,
606		})
607	}
608
609	return history, files
610}
611
612func (a *sessionAgent) getSessionMessages(ctx context.Context, session session.Session) ([]message.Message, error) {
613	msgs, err := a.messages.List(ctx, session.ID)
614	if err != nil {
615		return nil, fmt.Errorf("failed to list messages: %w", err)
616	}
617
618	if session.SummaryMessageID != "" {
619		summaryMsgInex := -1
620		for i, msg := range msgs {
621			if msg.ID == session.SummaryMessageID {
622				summaryMsgInex = i
623				break
624			}
625		}
626		if summaryMsgInex != -1 {
627			msgs = msgs[summaryMsgInex:]
628			msgs[0].Role = message.User
629		}
630	}
631	return msgs, nil
632}
633
634func (a *sessionAgent) generateTitle(ctx context.Context, session *session.Session, prompt string) {
635	if prompt == "" {
636		return
637	}
638
639	var maxOutput int64 = 40
640	if a.smallModel.CatwalkCfg.CanReason {
641		maxOutput = a.smallModel.CatwalkCfg.DefaultMaxTokens
642	}
643
644	agent := fantasy.NewAgent(a.smallModel.Model,
645		fantasy.WithSystemPrompt(string(titlePrompt)+"\n /no_think"),
646		fantasy.WithMaxOutputTokens(maxOutput),
647	)
648
649	resp, err := agent.Stream(ctx, fantasy.AgentStreamCall{
650		Prompt: fmt.Sprintf("Generate a concise title for the following content:\n\n%s\n <think>\n\n</think>", prompt),
651	})
652	if err != nil {
653		slog.Error("error generating title", "err", err)
654		return
655	}
656
657	title := resp.Response.Content.Text()
658
659	title = strings.ReplaceAll(title, "\n", " ")
660
661	// remove thinking tags if present
662	if idx := strings.Index(title, "</think>"); idx > 0 {
663		title = title[idx+len("</think>"):]
664	}
665
666	title = strings.TrimSpace(title)
667	if title == "" {
668		slog.Warn("failed to generate title", "warn", "empty title")
669		return
670	}
671
672	session.Title = title
673	a.updateSessionUsage(a.smallModel, session, resp.TotalUsage)
674	_, saveErr := a.sessions.Save(ctx, *session)
675	if saveErr != nil {
676		slog.Error("failed to save session title & usage", "error", saveErr)
677		return
678	}
679}
680
681func (a *sessionAgent) updateSessionUsage(model Model, session *session.Session, usage fantasy.Usage) {
682	modelConfig := model.CatwalkCfg
683	cost := modelConfig.CostPer1MInCached/1e6*float64(usage.CacheCreationTokens) +
684		modelConfig.CostPer1MOutCached/1e6*float64(usage.CacheReadTokens) +
685		modelConfig.CostPer1MIn/1e6*float64(usage.InputTokens) +
686		modelConfig.CostPer1MOut/1e6*float64(usage.OutputTokens)
687	session.Cost += cost
688	session.CompletionTokens = usage.OutputTokens + usage.CacheReadTokens
689	session.PromptTokens = usage.InputTokens + usage.CacheCreationTokens
690}
691
692func (a *sessionAgent) Cancel(sessionID string) {
693	// Cancel regular requests
694	if cancel, ok := a.activeRequests.Take(sessionID); ok && cancel != nil {
695		slog.Info("Request cancellation initiated", "session_id", sessionID)
696		cancel()
697	}
698
699	// Also check for summarize requests
700	if cancel, ok := a.activeRequests.Take(sessionID + "-summarize"); ok && cancel != nil {
701		slog.Info("Summarize cancellation initiated", "session_id", sessionID)
702		cancel()
703	}
704
705	if a.QueuedPrompts(sessionID) > 0 {
706		slog.Info("Clearing queued prompts", "session_id", sessionID)
707		a.messageQueue.Del(sessionID)
708	}
709}
710
711func (a *sessionAgent) ClearQueue(sessionID string) {
712	if a.QueuedPrompts(sessionID) > 0 {
713		slog.Info("Clearing queued prompts", "session_id", sessionID)
714		a.messageQueue.Del(sessionID)
715	}
716}
717
718func (a *sessionAgent) CancelAll() {
719	if !a.IsBusy() {
720		return
721	}
722	for key := range a.activeRequests.Seq2() {
723		a.Cancel(key) // key is sessionID
724	}
725
726	timeout := time.After(5 * time.Second)
727	for a.IsBusy() {
728		select {
729		case <-timeout:
730			return
731		default:
732			time.Sleep(200 * time.Millisecond)
733		}
734	}
735}
736
737func (a *sessionAgent) IsBusy() bool {
738	var busy bool
739	for cancelFunc := range a.activeRequests.Seq() {
740		if cancelFunc != nil {
741			busy = true
742			break
743		}
744	}
745	return busy
746}
747
748func (a *sessionAgent) IsSessionBusy(sessionID string) bool {
749	_, busy := a.activeRequests.Get(sessionID)
750	return busy
751}
752
753func (a *sessionAgent) QueuedPrompts(sessionID string) int {
754	l, ok := a.messageQueue.Get(sessionID)
755	if !ok {
756		return 0
757	}
758	return len(l)
759}
760
761func (a *sessionAgent) SetModels(large Model, small Model) {
762	a.largeModel = large
763	a.smallModel = small
764}
765
766func (a *sessionAgent) SetTools(tools []fantasy.AgentTool) {
767	a.tools = tools
768}
769
770func (a *sessionAgent) Model() Model {
771	return a.largeModel
772}