agent.go

  1package agent
  2
  3import (
  4	"context"
  5	_ "embed"
  6	"errors"
  7	"fmt"
  8	"log/slog"
  9	"os"
 10	"strconv"
 11	"strings"
 12	"sync"
 13	"time"
 14
 15	"charm.land/fantasy"
 16	"charm.land/fantasy/providers/anthropic"
 17	"charm.land/fantasy/providers/bedrock"
 18	"charm.land/fantasy/providers/google"
 19	"charm.land/fantasy/providers/openai"
 20	"github.com/charmbracelet/catwalk/pkg/catwalk"
 21	"github.com/charmbracelet/crush/internal/agent/tools"
 22	"github.com/charmbracelet/crush/internal/config"
 23	"github.com/charmbracelet/crush/internal/csync"
 24	"github.com/charmbracelet/crush/internal/message"
 25	"github.com/charmbracelet/crush/internal/permission"
 26	"github.com/charmbracelet/crush/internal/session"
 27)
 28
 29//go:embed templates/title.md
 30var titlePrompt []byte
 31
 32//go:embed templates/summary.md
 33var summaryPrompt []byte
 34
 35type SessionAgentCall struct {
 36	SessionID        string
 37	Prompt           string
 38	ProviderOptions  fantasy.ProviderOptions
 39	Attachments      []message.Attachment
 40	MaxOutputTokens  int64
 41	Temperature      *float64
 42	TopP             *float64
 43	TopK             *int64
 44	FrequencyPenalty *float64
 45	PresencePenalty  *float64
 46}
 47
 48type SessionAgent interface {
 49	Run(context.Context, SessionAgentCall) (*fantasy.AgentResult, error)
 50	SetModels(large Model, small Model)
 51	SetTools(tools []fantasy.AgentTool)
 52	Cancel(sessionID string)
 53	CancelAll()
 54	IsSessionBusy(sessionID string) bool
 55	IsBusy() bool
 56	QueuedPrompts(sessionID string) int
 57	ClearQueue(sessionID string)
 58	Summarize(context.Context, string, fantasy.ProviderOptions) error
 59	Model() Model
 60}
 61
 62type Model struct {
 63	Model      fantasy.LanguageModel
 64	CatwalkCfg catwalk.Model
 65	ModelCfg   config.SelectedModel
 66}
 67
 68type sessionAgent struct {
 69	largeModel           Model
 70	smallModel           Model
 71	systemPrompt         string
 72	tools                []fantasy.AgentTool
 73	sessions             session.Service
 74	messages             message.Service
 75	disableAutoSummarize bool
 76	isYolo               bool
 77
 78	messageQueue   *csync.Map[string, []SessionAgentCall]
 79	activeRequests *csync.Map[string, context.CancelFunc]
 80}
 81
 82type SessionAgentOptions struct {
 83	LargeModel           Model
 84	SmallModel           Model
 85	SystemPrompt         string
 86	DisableAutoSummarize bool
 87	IsYolo               bool
 88	Sessions             session.Service
 89	Messages             message.Service
 90	Tools                []fantasy.AgentTool
 91}
 92
 93func NewSessionAgent(
 94	opts SessionAgentOptions,
 95) SessionAgent {
 96	return &sessionAgent{
 97		largeModel:           opts.LargeModel,
 98		smallModel:           opts.SmallModel,
 99		systemPrompt:         opts.SystemPrompt,
100		sessions:             opts.Sessions,
101		messages:             opts.Messages,
102		disableAutoSummarize: opts.DisableAutoSummarize,
103		tools:                opts.Tools,
104		isYolo:               opts.IsYolo,
105		messageQueue:         csync.NewMap[string, []SessionAgentCall](),
106		activeRequests:       csync.NewMap[string, context.CancelFunc](),
107	}
108}
109
// Run sends call.Prompt to the large model for call.SessionID and streams the
// response into the message store.
//
// If the session already has an active request, the call is appended to the
// session's queue and Run returns (nil, nil). Queued prompts are injected at
// the next step of the active request (PrepareStep), or started as a fresh
// Run once it finishes. When the session has no messages yet, a title is
// generated concurrently with the small model.
//
// When auto-summarize is enabled and the session's token counters exceed 80%
// of the large model's context window, the stream is stopped and the session
// is summarized. If the last assistant message had tool calls, the original
// prompt is re-queued with an explanatory prefix.
//
// Returns ErrEmptyPrompt / ErrSessionMissing for invalid calls. On stream
// errors, unfinished tool calls are closed, tool calls without a result get
// an error result, and the assistant message is finished with a reason that
// matches the error (cancelled, permission denied, or API error).
func (a *sessionAgent) Run(ctx context.Context, call SessionAgentCall) (*fantasy.AgentResult, error) {
	if call.Prompt == "" {
		return nil, ErrEmptyPrompt
	}
	if call.SessionID == "" {
		return nil, ErrSessionMissing
	}

	// Queue the message if busy
	if a.IsSessionBusy(call.SessionID) {
		existing, ok := a.messageQueue.Get(call.SessionID)
		if !ok {
			existing = []SessionAgentCall{}
		}
		existing = append(existing, call)
		a.messageQueue.Set(call.SessionID, existing)
		return nil, nil
	}

	if len(a.tools) > 0 {
		// add anthropic caching to the last tool
		a.tools[len(a.tools)-1].SetProviderOptions(a.getCacheControlOptions())
	}

	agent := fantasy.NewAgent(
		a.largeModel.Model,
		fantasy.WithSystemPrompt(a.systemPrompt),
		fantasy.WithTools(a.tools...),
	)

	// sessionLock guards currentSession, which is read and written both by
	// the title goroutine and by the stream callbacks (OnStepFinish, StopWhen).
	var sessionLock sync.Mutex
	currentSession, err := a.sessions.Get(ctx, call.SessionID)
	if err != nil {
		return nil, fmt.Errorf("failed to get session: %w", err)
	}

	msgs, err := a.getSessionMessages(ctx, currentSession)
	if err != nil {
		return nil, fmt.Errorf("failed to get session messages: %w", err)
	}

	var wg sync.WaitGroup
	// Generate title if first message
	if len(msgs) == 0 {
		// Copy ctx for the goroutine: the ctx variable is reassigned below,
		// and reading it from the goroutine would be a data race.
		titleCtx := ctx
		wg.Go(func() {
			sessionLock.Lock()
			defer sessionLock.Unlock()
			a.generateTitle(titleCtx, &currentSession, call.Prompt)
		})
	}

	// Add the user message to the session
	if _, err = a.createUserMessage(ctx, call); err != nil {
		return nil, err
	}

	// add the session to the context
	ctx = context.WithValue(ctx, tools.SessionIDContextKey, call.SessionID)

	genCtx, cancel := context.WithCancel(ctx)
	a.activeRequests.Set(call.SessionID, cancel)

	defer cancel()
	defer a.activeRequests.Del(call.SessionID)

	history, files := a.preparePrompt(msgs, call.Attachments...)

	startTime := time.Now()
	a.eventPromptSent(call.SessionID)

	var currentAssistant *message.Message
	var shouldSummarize bool
	result, err := agent.Stream(genCtx, fantasy.AgentStreamCall{
		Prompt:           call.Prompt,
		Files:            files,
		Messages:         history,
		ProviderOptions:  call.ProviderOptions,
		MaxOutputTokens:  &call.MaxOutputTokens,
		TopP:             call.TopP,
		Temperature:      call.Temperature,
		PresencePenalty:  call.PresencePenalty,
		TopK:             call.TopK,
		FrequencyPenalty: call.FrequencyPenalty,
		// Before each step create the new assistant message
		PrepareStep: func(callContext context.Context, options fantasy.PrepareStepFunctionOptions) (_ context.Context, prepared fantasy.PrepareStepResult, err error) {
			prepared.Messages = options.Messages
			// reset all cached items
			for i := range prepared.Messages {
				prepared.Messages[i].ProviderOptions = nil
			}

			// Drain prompts queued while this request was running. Take reads
			// and removes the entry in one call, so a prompt queued between a
			// separate read and delete cannot be silently dropped.
			// NOTE(review): assumes csync.Map.Take is a single locked get-and-delete.
			queuedCalls, _ := a.messageQueue.Take(call.SessionID)
			for _, queued := range queuedCalls {
				userMessage, createErr := a.createUserMessage(callContext, queued)
				if createErr != nil {
					return callContext, prepared, createErr
				}
				prepared.Messages = append(prepared.Messages, userMessage.ToAIMessage()...)
			}

			// Cache breakpoints: mark the last system message before the first
			// non-system message (index 0 if the first message is not a system
			// message), then the last two messages of the conversation.
			lastSystemRoleInx := 0
			systemMessageUpdated := false
			for i, msg := range prepared.Messages {
				if msg.Role == fantasy.MessageRoleSystem {
					lastSystemRoleInx = i
				} else if !systemMessageUpdated {
					prepared.Messages[lastSystemRoleInx].ProviderOptions = a.getCacheControlOptions()
					systemMessageUpdated = true
				}
				// then add cache control to the last 2 messages
				if i > len(prepared.Messages)-3 {
					prepared.Messages[i].ProviderOptions = a.getCacheControlOptions()
				}
			}

			var assistantMsg message.Message
			assistantMsg, err = a.messages.Create(callContext, call.SessionID, message.CreateMessageParams{
				Role:     message.Assistant,
				Parts:    []message.ContentPart{},
				Model:    a.largeModel.ModelCfg.Model,
				Provider: a.largeModel.ModelCfg.Provider,
			})
			if err != nil {
				return callContext, prepared, err
			}
			callContext = context.WithValue(callContext, tools.MessageIDContextKey, assistantMsg.ID)
			currentAssistant = &assistantMsg
			return callContext, prepared, err
		},
		OnReasoningDelta: func(id string, text string) error {
			currentAssistant.AppendReasoningContent(text)
			return a.messages.Update(genCtx, *currentAssistant)
		},
		OnReasoningEnd: func(id string, reasoning fantasy.ReasoningContent) error {
			// handle anthropic signature
			if anthropicData, ok := reasoning.ProviderMetadata[anthropic.Name]; ok {
				if meta, ok := anthropicData.(*anthropic.ReasoningOptionMetadata); ok {
					currentAssistant.AppendReasoningSignature(meta.Signature)
				}
			}
			if googleData, ok := reasoning.ProviderMetadata[google.Name]; ok {
				if meta, ok := googleData.(*google.ReasoningMetadata); ok {
					currentAssistant.AppendReasoningSignature(meta.Signature)
				}
			}
			if openaiData, ok := reasoning.ProviderMetadata[openai.Name]; ok {
				if meta, ok := openaiData.(*openai.ResponsesReasoningMetadata); ok {
					currentAssistant.SetReasoningResponsesData(meta)
				}
			}
			currentAssistant.FinishThinking()
			return a.messages.Update(genCtx, *currentAssistant)
		},
		OnTextDelta: func(id string, text string) error {
			currentAssistant.AppendContent(text)
			return a.messages.Update(genCtx, *currentAssistant)
		},
		OnToolInputStart: func(id string, toolName string) error {
			toolCall := message.ToolCall{
				ID:               id,
				Name:             toolName,
				ProviderExecuted: false,
				Finished:         false,
			}
			currentAssistant.AddToolCall(toolCall)
			return a.messages.Update(genCtx, *currentAssistant)
		},
		OnRetry: func(err *fantasy.APICallError, delay time.Duration) {
			// TODO: implement
		},
		OnToolCall: func(tc fantasy.ToolCallContent) error {
			toolCall := message.ToolCall{
				ID:               tc.ToolCallID,
				Name:             tc.ToolName,
				Input:            tc.Input,
				ProviderExecuted: false,
				Finished:         true,
			}
			currentAssistant.AddToolCall(toolCall)
			return a.messages.Update(genCtx, *currentAssistant)
		},
		OnToolResult: func(result fantasy.ToolResultContent) error {
			var resultContent string
			isError := false
			switch result.Result.GetType() {
			case fantasy.ToolResultContentTypeText:
				r, ok := fantasy.AsToolResultOutputType[fantasy.ToolResultOutputContentText](result.Result)
				if ok {
					resultContent = r.Text
				}
			case fantasy.ToolResultContentTypeError:
				r, ok := fantasy.AsToolResultOutputType[fantasy.ToolResultOutputContentError](result.Result)
				if ok {
					isError = true
					resultContent = r.Error.Error()
				}
			case fantasy.ToolResultContentTypeMedia:
				// TODO: handle this message type
			}
			toolResult := message.ToolResult{
				ToolCallID: result.ToolCallID,
				Name:       result.ToolName,
				Content:    resultContent,
				IsError:    isError,
				Metadata:   result.ClientMetadata,
			}
			_, createMsgErr := a.messages.Create(genCtx, currentAssistant.SessionID, message.CreateMessageParams{
				Role: message.Tool,
				Parts: []message.ContentPart{
					toolResult,
				},
			})
			if createMsgErr != nil {
				return createMsgErr
			}
			return nil
		},
		OnStepFinish: func(stepResult fantasy.StepResult) error {
			finishReason := message.FinishReasonUnknown
			switch stepResult.FinishReason {
			case fantasy.FinishReasonLength:
				finishReason = message.FinishReasonMaxTokens
			case fantasy.FinishReasonStop:
				finishReason = message.FinishReasonEndTurn
			case fantasy.FinishReasonToolCalls:
				finishReason = message.FinishReasonToolUse
			}
			currentAssistant.AddFinish(finishReason, "", "")
			// Usage must be applied under the lock too: the title goroutine
			// mutates the same session (cost, token counters) concurrently.
			sessionLock.Lock()
			a.updateSessionUsage(a.largeModel, &currentSession, stepResult.Usage)
			_, sessionErr := a.sessions.Save(genCtx, currentSession)
			sessionLock.Unlock()
			if sessionErr != nil {
				return sessionErr
			}
			return a.messages.Update(genCtx, *currentAssistant)
		},
		StopWhen: []fantasy.StopCondition{
			func(_ []fantasy.StepResult) bool {
				contextWindow := a.largeModel.CatwalkCfg.ContextWindow
				// With an unknown (zero) context window the percentage would be
				// +Inf for any non-zero usage and force a summary every step.
				if a.disableAutoSummarize || contextWindow <= 0 {
					return false
				}
				sessionLock.Lock()
				tokens := currentSession.CompletionTokens + currentSession.PromptTokens
				sessionLock.Unlock()
				percentage := (float64(tokens) / float64(contextWindow)) * 100
				if percentage > 80 {
					shouldSummarize = true
					return true
				}
				return false
			},
		},
	})

	a.eventPromptResponded(call.SessionID, time.Since(startTime).Truncate(time.Second))

	if err != nil {
		isCancelErr := errors.Is(err, context.Canceled)
		isPermissionErr := errors.Is(err, permission.ErrorPermissionDenied)
		if currentAssistant == nil {
			return result, err
		}
		// Ensure we finish thinking on error to close the reasoning state
		currentAssistant.FinishThinking()
		toolCalls := currentAssistant.ToolCalls()
		// INFO: we use the parent context here because the genCtx has been cancelled
		msgs, listErr := a.messages.List(ctx, currentAssistant.SessionID)
		if listErr != nil {
			return nil, listErr
		}
		for _, tc := range toolCalls {
			if !tc.Finished {
				tc.Finished = true
				tc.Input = "{}"
				currentAssistant.AddToolCall(tc)
				updateErr := a.messages.Update(ctx, *currentAssistant)
				if updateErr != nil {
					return nil, updateErr
				}
			}

			found := false
			for _, msg := range msgs {
				if msg.Role == message.Tool {
					for _, tr := range msg.ToolResults() {
						if tr.ToolCallID == tc.ID {
							found = true
							break
						}
					}
				}
				if found {
					break
				}
			}
			if found {
				continue
			}
			content := "There was an error while executing the tool"
			if isCancelErr {
				content = "Tool execution canceled by user"
			} else if isPermissionErr {
				content = "Permission denied"
			}
			toolResult := message.ToolResult{
				ToolCallID: tc.ID,
				Name:       tc.Name,
				Content:    content,
				IsError:    true,
			}
			_, createErr := a.messages.Create(context.Background(), currentAssistant.SessionID, message.CreateMessageParams{
				Role: message.Tool,
				Parts: []message.ContentPart{
					toolResult,
				},
			})
			if createErr != nil {
				return nil, createErr
			}
		}
		if isCancelErr {
			currentAssistant.AddFinish(message.FinishReasonCanceled, "Request cancelled", "")
		} else if isPermissionErr {
			currentAssistant.AddFinish(message.FinishReasonPermissionDenied, "Permission denied", "")
		} else {
			currentAssistant.AddFinish(message.FinishReasonError, "API Error", err.Error())
		}
		// INFO: we use the parent context here because the genCtx has been cancelled
		updateErr := a.messages.Update(ctx, *currentAssistant)
		if updateErr != nil {
			return nil, updateErr
		}
		return nil, err
	}
	wg.Wait()

	if shouldSummarize {
		// Summarize refuses to run while the session is marked active.
		a.activeRequests.Del(call.SessionID)
		if summarizeErr := a.Summarize(genCtx, call.SessionID, call.ProviderOptions); summarizeErr != nil {
			return nil, summarizeErr
		}
		// if the agent was not done (it was still calling tools), re-queue the
		// original request so the work continues on top of the summary
		if len(currentAssistant.ToolCalls()) > 0 {
			existing, ok := a.messageQueue.Get(call.SessionID)
			if !ok {
				existing = []SessionAgentCall{}
			}
			call.Prompt = fmt.Sprintf("The previous session was interrupted because it got too long, the initial user request was: `%s`", call.Prompt)
			existing = append(existing, call)
			a.messageQueue.Set(call.SessionID, existing)
		}
	}

	// release active request before processing queued messages
	a.activeRequests.Del(call.SessionID)
	cancel()

	queuedMessages, ok := a.messageQueue.Get(call.SessionID)
	if !ok || len(queuedMessages) == 0 {
		return result, nil
	}
	// there are queued messages restart the loop
	firstQueuedMessage := queuedMessages[0]
	a.messageQueue.Set(call.SessionID, queuedMessages[1:])
	return a.Run(ctx, firstQueuedMessage)
}
476
// Summarize asks the large model to summarize the session's history (from
// the latest summary message onward) and stores the result as a new summary
// message.
//
// It returns ErrSessionBusy if the session has an active request. While it
// runs, the session is registered in activeRequests so Cancel can stop it.
// A cancelled summary deletes the partial summary message and returns the
// deletion's result. Any other stream error finishes the partial message
// with an error reason and is returned. On success, the session's
// SummaryMessageID points at the new message, and its token counters are
// reset to the last response's output tokens.
func (a *sessionAgent) Summarize(ctx context.Context, sessionID string, opts fantasy.ProviderOptions) error {
	if a.IsSessionBusy(sessionID) {
		return ErrSessionBusy
	}

	currentSession, err := a.sessions.Get(ctx, sessionID)
	if err != nil {
		return fmt.Errorf("failed to get session: %w", err)
	}
	msgs, err := a.getSessionMessages(ctx, currentSession)
	if err != nil {
		return err
	}
	if len(msgs) == 0 {
		// nothing to summarize
		return nil
	}

	aiMsgs, _ := a.preparePrompt(msgs)

	// Register as the session's active request so Cancel/IsSessionBusy see it.
	genCtx, cancel := context.WithCancel(ctx)
	a.activeRequests.Set(sessionID, cancel)
	defer a.activeRequests.Del(sessionID)
	defer cancel()

	agent := fantasy.NewAgent(a.largeModel.Model,
		fantasy.WithSystemPrompt(string(summaryPrompt)),
	)
	summaryMessage, err := a.messages.Create(ctx, sessionID, message.CreateMessageParams{
		Role:             message.Assistant,
		Model:            a.largeModel.Model.Model(),
		Provider:         a.largeModel.Model.Provider(),
		IsSummaryMessage: true,
	})
	if err != nil {
		return err
	}

	resp, err := agent.Stream(genCtx, fantasy.AgentStreamCall{
		Prompt:          "Provide a detailed summary of our conversation above.",
		Messages:        aiMsgs,
		ProviderOptions: opts,
		OnReasoningDelta: func(id string, text string) error {
			summaryMessage.AppendReasoningContent(text)
			return a.messages.Update(genCtx, summaryMessage)
		},
		OnReasoningEnd: func(id string, reasoning fantasy.ReasoningContent) error {
			// Preserve provider-specific reasoning signatures, the same way
			// Run does for regular assistant messages.
			if anthropicData, ok := reasoning.ProviderMetadata[anthropic.Name]; ok {
				if meta, ok := anthropicData.(*anthropic.ReasoningOptionMetadata); ok && meta.Signature != "" {
					summaryMessage.AppendReasoningSignature(meta.Signature)
				}
			}
			if googleData, ok := reasoning.ProviderMetadata[google.Name]; ok {
				if meta, ok := googleData.(*google.ReasoningMetadata); ok {
					summaryMessage.AppendReasoningSignature(meta.Signature)
				}
			}
			if openaiData, ok := reasoning.ProviderMetadata[openai.Name]; ok {
				if meta, ok := openaiData.(*openai.ResponsesReasoningMetadata); ok {
					summaryMessage.SetReasoningResponsesData(meta)
				}
			}
			summaryMessage.FinishThinking()
			return a.messages.Update(genCtx, summaryMessage)
		},
		OnTextDelta: func(id, text string) error {
			summaryMessage.AppendContent(text)
			return a.messages.Update(genCtx, summaryMessage)
		},
	})
	if err != nil {
		if errors.Is(err, context.Canceled) {
			// User cancelled summarize we need to remove the summary message
			return a.messages.Delete(ctx, summaryMessage.ID)
		}
		// Do not leave the partial summary unfinished; close it out with an
		// error reason as Run does for its assistant message. The parent ctx
		// is used because genCtx may already be done.
		summaryMessage.FinishThinking()
		summaryMessage.AddFinish(message.FinishReasonError, "API Error", err.Error())
		if updateErr := a.messages.Update(ctx, summaryMessage); updateErr != nil {
			return errors.Join(err, updateErr)
		}
		return err
	}

	summaryMessage.AddFinish(message.FinishReasonEndTurn, "", "")
	err = a.messages.Update(genCtx, summaryMessage)
	if err != nil {
		return err
	}

	a.updateSessionUsage(a.largeModel, &currentSession, resp.TotalUsage)

	// just in case get just the last usage
	usage := resp.Response.Usage
	currentSession.SummaryMessageID = summaryMessage.ID
	currentSession.CompletionTokens = usage.OutputTokens
	currentSession.PromptTokens = 0
	_, err = a.sessions.Save(genCtx, currentSession)
	return err
}
564
// getCacheControlOptions returns provider options that mark an item as an
// "ephemeral" cache breakpoint for the anthropic and bedrock providers.
// If CRUSH_DISABLE_ANTHROPIC_CACHE parses as true (strconv.ParseBool), it
// returns empty options instead.
func (a *sessionAgent) getCacheControlOptions() fantasy.ProviderOptions {
	// An unset or unparseable value is treated as "not disabled".
	disabled, _ := strconv.ParseBool(os.Getenv("CRUSH_DISABLE_ANTHROPIC_CACHE"))
	if disabled {
		return fantasy.ProviderOptions{}
	}
	ephemeral := anthropic.CacheControl{Type: "ephemeral"}
	return fantasy.ProviderOptions{
		anthropic.Name: &anthropic.ProviderCacheControlOptions{CacheControl: ephemeral},
		bedrock.Name:   &anthropic.ProviderCacheControlOptions{CacheControl: ephemeral},
	}
}
578
579func (a *sessionAgent) createUserMessage(ctx context.Context, call SessionAgentCall) (message.Message, error) {
580	var attachmentParts []message.ContentPart
581	for _, attachment := range call.Attachments {
582		attachmentParts = append(attachmentParts, message.BinaryContent{Path: attachment.FilePath, MIMEType: attachment.MimeType, Data: attachment.Content})
583	}
584	parts := []message.ContentPart{message.TextContent{Text: call.Prompt}}
585	parts = append(parts, attachmentParts...)
586	msg, err := a.messages.Create(ctx, call.SessionID, message.CreateMessageParams{
587		Role:  message.User,
588		Parts: parts,
589	})
590	if err != nil {
591		return message.Message{}, fmt.Errorf("failed to create user message: %w", err)
592	}
593	return msg, nil
594}
595
596func (a *sessionAgent) preparePrompt(msgs []message.Message, attachments ...message.Attachment) ([]fantasy.Message, []fantasy.FilePart) {
597	var history []fantasy.Message
598	for _, m := range msgs {
599		if len(m.Parts) == 0 {
600			continue
601		}
602		// Assistant message without content or tool calls (cancelled before it returned anything)
603		if m.Role == message.Assistant && len(m.ToolCalls()) == 0 && m.Content().Text == "" && m.ReasoningContent().String() == "" {
604			continue
605		}
606		history = append(history, m.ToAIMessage()...)
607	}
608
609	var files []fantasy.FilePart
610	for _, attachment := range attachments {
611		files = append(files, fantasy.FilePart{
612			Filename:  attachment.FileName,
613			Data:      attachment.Content,
614			MediaType: attachment.MimeType,
615		})
616	}
617
618	return history, files
619}
620
// getSessionMessages lists the session's messages. If the session's summary
// message is among them, the messages before it are dropped. The summary's
// role is rewritten to User in the returned slice; nothing is saved.
func (a *sessionAgent) getSessionMessages(ctx context.Context, session session.Session) ([]message.Message, error) {
	msgs, err := a.messages.List(ctx, session.ID)
	if err != nil {
		return nil, fmt.Errorf("failed to list messages: %w", err)
	}
	if session.SummaryMessageID == "" {
		return msgs, nil
	}

	for i := range msgs {
		if msgs[i].ID != session.SummaryMessageID {
			continue
		}
		tail := msgs[i:]
		tail[0].Role = message.User
		return tail, nil
	}
	return msgs, nil
}
642
// generateTitle asks the small model for a short title for prompt. When a
// non-empty title comes back, it is stored on session, the small model's usage
// is applied via updateSessionUsage, and the session is saved. Failures are
// logged rather than returned.
func (a *sessionAgent) generateTitle(ctx context.Context, session *session.Session, prompt string) {
	if prompt == "" {
		return
	}

	var maxOutput int64 = 40
	if a.smallModel.CatwalkCfg.CanReason {
		// Reasoning-capable models get their default token budget instead of
		// the 40-token cap (presumably so reasoning output cannot exhaust it).
		maxOutput = a.smallModel.CatwalkCfg.DefaultMaxTokens
	}

	agent := fantasy.NewAgent(a.smallModel.Model,
		fantasy.WithSystemPrompt(string(titlePrompt)+"\n /no_think"),
		fantasy.WithMaxOutputTokens(maxOutput),
	)

	resp, err := agent.Stream(ctx, fantasy.AgentStreamCall{
		Prompt: fmt.Sprintf("Generate a concise title for the following content:\n\n%s\n <think>\n\n</think>", prompt),
	})
	if err != nil {
		slog.Error("error generating title", "err", err)
		return
	}

	title := strings.ReplaceAll(resp.Response.Content.Text(), "\n", " ")

	// remove thinking tags if present: drop everything up to and including
	// the first "</think>", also when the response starts with the tag
	if _, afterThink, found := strings.Cut(title, "</think>"); found {
		title = afterThink
	}

	title = strings.TrimSpace(title)
	if title == "" {
		slog.Warn("failed to generate title", "warn", "empty title")
		return
	}

	session.Title = title
	a.updateSessionUsage(a.smallModel, session, resp.TotalUsage)
	if _, saveErr := a.sessions.Save(ctx, *session); saveErr != nil {
		slog.Error("failed to save session title & usage", "error", saveErr)
	}
}
689
690func (a *sessionAgent) updateSessionUsage(model Model, session *session.Session, usage fantasy.Usage) {
691	modelConfig := model.CatwalkCfg
692	cost := modelConfig.CostPer1MInCached/1e6*float64(usage.CacheCreationTokens) +
693		modelConfig.CostPer1MOutCached/1e6*float64(usage.CacheReadTokens) +
694		modelConfig.CostPer1MIn/1e6*float64(usage.InputTokens) +
695		modelConfig.CostPer1MOut/1e6*float64(usage.OutputTokens)
696
697	a.eventTokensUsed(session.ID, model, usage, cost)
698
699	session.Cost += cost
700	session.CompletionTokens = usage.OutputTokens + usage.CacheReadTokens
701	session.PromptTokens = usage.InputTokens + usage.CacheCreationTokens
702}
703
// Cancel cancels the in-flight request registered for sessionID. It also
// cancels any request registered under "<sessionID>-summarize", then clears
// the session's queued prompts.
func (a *sessionAgent) Cancel(sessionID string) {
	targets := []struct {
		key, logMsg string
	}{
		{sessionID, "Request cancellation initiated"},
		{sessionID + "-summarize", "Summarize cancellation initiated"},
	}
	for _, target := range targets {
		cancel, ok := a.activeRequests.Take(target.key)
		if !ok || cancel == nil {
			continue
		}
		slog.Info(target.logMsg, "session_id", sessionID)
		cancel()
	}

	a.ClearQueue(sessionID)
}
722
// ClearQueue drops any prompts queued for sessionID, logging when there were
// some to drop.
func (a *sessionAgent) ClearQueue(sessionID string) {
	if a.QueuedPrompts(sessionID) == 0 {
		return
	}
	slog.Info("Clearing queued prompts", "session_id", sessionID)
	a.messageQueue.Del(sessionID)
}
729
// CancelAll cancels every registered request. It then polls every 200ms
// until no request is active, giving up after 5 seconds.
func (a *sessionAgent) CancelAll() {
	if !a.IsBusy() {
		return
	}
	for sessionID := range a.activeRequests.Seq2() {
		a.Cancel(sessionID)
	}

	giveUp := time.After(5 * time.Second)
	for a.IsBusy() {
		select {
		case <-giveUp:
			return
		default:
		}
		time.Sleep(200 * time.Millisecond)
	}
}
748
749func (a *sessionAgent) IsBusy() bool {
750	var busy bool
751	for cancelFunc := range a.activeRequests.Seq() {
752		if cancelFunc != nil {
753			busy = true
754			break
755		}
756	}
757	return busy
758}
759
760func (a *sessionAgent) IsSessionBusy(sessionID string) bool {
761	_, busy := a.activeRequests.Get(sessionID)
762	return busy
763}
764
765func (a *sessionAgent) QueuedPrompts(sessionID string) int {
766	l, ok := a.messageQueue.Get(sessionID)
767	if !ok {
768		return 0
769	}
770	return len(l)
771}
772
773func (a *sessionAgent) SetModels(large Model, small Model) {
774	a.largeModel = large
775	a.smallModel = small
776}
777
778func (a *sessionAgent) SetTools(tools []fantasy.AgentTool) {
779	a.tools = tools
780}
781
782func (a *sessionAgent) Model() Model {
783	return a.largeModel
784}