agent.go

  1package agent
  2
  3import (
  4	"context"
  5	_ "embed"
  6	"errors"
  7	"fmt"
  8	"log/slog"
  9	"os"
 10	"strconv"
 11	"strings"
 12	"sync"
 13	"time"
 14
 15	"charm.land/fantasy"
 16	"charm.land/fantasy/providers/anthropic"
 17	"charm.land/fantasy/providers/bedrock"
 18	"charm.land/fantasy/providers/openai"
 19	"github.com/charmbracelet/catwalk/pkg/catwalk"
 20	"github.com/charmbracelet/crush/internal/agent/tools"
 21	"github.com/charmbracelet/crush/internal/config"
 22	"github.com/charmbracelet/crush/internal/csync"
 23	"github.com/charmbracelet/crush/internal/message"
 24	"github.com/charmbracelet/crush/internal/permission"
 25	"github.com/charmbracelet/crush/internal/session"
 26)
 27
 28//go:embed templates/title.md
 29var titlePrompt []byte
 30
 31//go:embed templates/summary.md
 32var summaryPrompt []byte
 33
 34type SessionAgentCall struct {
 35	SessionID        string
 36	Prompt           string
 37	ProviderOptions  fantasy.ProviderOptions
 38	Attachments      []message.Attachment
 39	MaxOutputTokens  int64
 40	Temperature      *float64
 41	TopP             *float64
 42	TopK             *int64
 43	FrequencyPenalty *float64
 44	PresencePenalty  *float64
 45}
 46
 47type SessionAgent interface {
 48	Run(context.Context, SessionAgentCall) (*fantasy.AgentResult, error)
 49	SetModels(large Model, small Model)
 50	SetTools(tools []fantasy.AgentTool)
 51	Cancel(sessionID string)
 52	CancelAll()
 53	IsSessionBusy(sessionID string) bool
 54	IsBusy() bool
 55	QueuedPrompts(sessionID string) int
 56	ClearQueue(sessionID string)
 57	Summarize(context.Context, string, fantasy.ProviderOptions) error
 58	Model() Model
 59}
 60
 61type Model struct {
 62	Model      fantasy.LanguageModel
 63	CatwalkCfg catwalk.Model
 64	ModelCfg   config.SelectedModel
 65}
 66
 67type sessionAgent struct {
 68	largeModel           Model
 69	smallModel           Model
 70	systemPrompt         string
 71	tools                []fantasy.AgentTool
 72	sessions             session.Service
 73	messages             message.Service
 74	disableAutoSummarize bool
 75
 76	messageQueue   *csync.Map[string, []SessionAgentCall]
 77	activeRequests *csync.Map[string, context.CancelFunc]
 78}
 79
 80type SessionAgentOptions struct {
 81	LargeModel           Model
 82	SmallModel           Model
 83	SystemPrompt         string
 84	DisableAutoSummarize bool
 85	Sessions             session.Service
 86	Messages             message.Service
 87	Tools                []fantasy.AgentTool
 88}
 89
 90func NewSessionAgent(
 91	opts SessionAgentOptions,
 92) SessionAgent {
 93	return &sessionAgent{
 94		largeModel:           opts.LargeModel,
 95		smallModel:           opts.SmallModel,
 96		systemPrompt:         opts.SystemPrompt,
 97		sessions:             opts.Sessions,
 98		messages:             opts.Messages,
 99		disableAutoSummarize: opts.DisableAutoSummarize,
100		tools:                opts.Tools,
101		messageQueue:         csync.NewMap[string, []SessionAgentCall](),
102		activeRequests:       csync.NewMap[string, context.CancelFunc](),
103	}
104}
105
106func (a *sessionAgent) Run(ctx context.Context, call SessionAgentCall) (*fantasy.AgentResult, error) {
107	if call.Prompt == "" {
108		return nil, ErrEmptyPrompt
109	}
110	if call.SessionID == "" {
111		return nil, ErrSessionMissing
112	}
113
114	// Queue the message if busy
115	if a.IsSessionBusy(call.SessionID) {
116		existing, ok := a.messageQueue.Get(call.SessionID)
117		if !ok {
118			existing = []SessionAgentCall{}
119		}
120		existing = append(existing, call)
121		a.messageQueue.Set(call.SessionID, existing)
122		return nil, nil
123	}
124
125	if len(a.tools) > 0 {
126		// add anthropic caching to the last tool
127		a.tools[len(a.tools)-1].SetProviderOptions(a.getCacheControlOptions())
128	}
129
130	agent := fantasy.NewAgent(
131		a.largeModel.Model,
132		fantasy.WithSystemPrompt(a.systemPrompt),
133		fantasy.WithTools(a.tools...),
134	)
135
136	sessionLock := sync.Mutex{}
137	currentSession, err := a.sessions.Get(ctx, call.SessionID)
138	if err != nil {
139		return nil, fmt.Errorf("failed to get session: %w", err)
140	}
141
142	msgs, err := a.getSessionMessages(ctx, currentSession)
143	if err != nil {
144		return nil, fmt.Errorf("failed to get session messages: %w", err)
145	}
146
147	var wg sync.WaitGroup
148	// Generate title if first message
149	if len(msgs) == 0 {
150		wg.Go(func() {
151			sessionLock.Lock()
152			a.generateTitle(ctx, &currentSession, call.Prompt)
153			sessionLock.Unlock()
154		})
155	}
156
157	// Add the user message to the session
158	_, err = a.createUserMessage(ctx, call)
159	if err != nil {
160		return nil, err
161	}
162
163	// add the session to the context
164	ctx = context.WithValue(ctx, tools.SessionIDContextKey, call.SessionID)
165
166	genCtx, cancel := context.WithCancel(ctx)
167	a.activeRequests.Set(call.SessionID, cancel)
168
169	defer cancel()
170	defer a.activeRequests.Del(call.SessionID)
171
172	history, files := a.preparePrompt(msgs, call.Attachments...)
173
174	var currentAssistant *message.Message
175	var shouldSummarize bool
176	result, err := agent.Stream(genCtx, fantasy.AgentStreamCall{
177		Prompt:           call.Prompt,
178		Files:            files,
179		Messages:         history,
180		ProviderOptions:  call.ProviderOptions,
181		MaxOutputTokens:  &call.MaxOutputTokens,
182		TopP:             call.TopP,
183		Temperature:      call.Temperature,
184		PresencePenalty:  call.PresencePenalty,
185		TopK:             call.TopK,
186		FrequencyPenalty: call.FrequencyPenalty,
187		// Before each step create the new assistant message
188		PrepareStep: func(callContext context.Context, options fantasy.PrepareStepFunctionOptions) (_ context.Context, prepared fantasy.PrepareStepResult, err error) {
189			prepared.Messages = options.Messages
190			// reset all cached items
191			for i := range prepared.Messages {
192				prepared.Messages[i].ProviderOptions = nil
193			}
194
195			queuedCalls, _ := a.messageQueue.Get(call.SessionID)
196			a.messageQueue.Del(call.SessionID)
197			for _, queued := range queuedCalls {
198				userMessage, createErr := a.createUserMessage(callContext, queued)
199				if createErr != nil {
200					return callContext, prepared, createErr
201				}
202				prepared.Messages = append(prepared.Messages, userMessage.ToAIMessage()...)
203			}
204
205			lastSystemRoleInx := 0
206			systemMessageUpdated := false
207			for i, msg := range prepared.Messages {
208				// only add cache control to the last message
209				if msg.Role == fantasy.MessageRoleSystem {
210					lastSystemRoleInx = i
211				} else if !systemMessageUpdated {
212					prepared.Messages[lastSystemRoleInx].ProviderOptions = a.getCacheControlOptions()
213					systemMessageUpdated = true
214				}
215				// than add cache control to the last 2 messages
216				if i > len(prepared.Messages)-3 {
217					prepared.Messages[i].ProviderOptions = a.getCacheControlOptions()
218				}
219			}
220
221			var assistantMsg message.Message
222			assistantMsg, err = a.messages.Create(callContext, call.SessionID, message.CreateMessageParams{
223				Role:     message.Assistant,
224				Parts:    []message.ContentPart{},
225				Model:    a.largeModel.ModelCfg.Model,
226				Provider: a.largeModel.ModelCfg.Provider,
227			})
228			if err != nil {
229				return callContext, prepared, err
230			}
231			callContext = context.WithValue(callContext, tools.MessageIDContextKey, assistantMsg.ID)
232			currentAssistant = &assistantMsg
233			return callContext, prepared, err
234		},
235		OnReasoningDelta: func(id string, text string) error {
236			currentAssistant.AppendReasoningContent(text)
237			return a.messages.Update(genCtx, *currentAssistant)
238		},
239		OnReasoningEnd: func(id string, reasoning fantasy.ReasoningContent) error {
240			// handle anthropic signature
241			if anthropicData, ok := reasoning.ProviderMetadata[anthropic.Name]; ok {
242				if reasoning, ok := anthropicData.(*anthropic.ReasoningOptionMetadata); ok {
243					currentAssistant.AppendReasoningSignature(reasoning.Signature)
244				}
245			}
246			if openaiData, ok := reasoning.ProviderMetadata[openai.Name]; ok {
247				if reasoning, ok := openaiData.(*openai.ResponsesReasoningMetadata); ok {
248					currentAssistant.SetReasoningResponsesData(reasoning)
249				}
250			}
251			currentAssistant.FinishThinking()
252			return a.messages.Update(genCtx, *currentAssistant)
253		},
254		OnTextDelta: func(id string, text string) error {
255			currentAssistant.AppendContent(text)
256			return a.messages.Update(genCtx, *currentAssistant)
257		},
258		OnToolInputStart: func(id string, toolName string) error {
259			toolCall := message.ToolCall{
260				ID:               id,
261				Name:             toolName,
262				ProviderExecuted: false,
263				Finished:         false,
264			}
265			currentAssistant.AddToolCall(toolCall)
266			return a.messages.Update(genCtx, *currentAssistant)
267		},
268		OnRetry: func(err *fantasy.APICallError, delay time.Duration) {
269			// TODO: implement
270		},
271		OnToolCall: func(tc fantasy.ToolCallContent) error {
272			toolCall := message.ToolCall{
273				ID:               tc.ToolCallID,
274				Name:             tc.ToolName,
275				Input:            tc.Input,
276				ProviderExecuted: false,
277				Finished:         true,
278			}
279			currentAssistant.AddToolCall(toolCall)
280			return a.messages.Update(genCtx, *currentAssistant)
281		},
282		OnToolResult: func(result fantasy.ToolResultContent) error {
283			var resultContent string
284			isError := false
285			switch result.Result.GetType() {
286			case fantasy.ToolResultContentTypeText:
287				r, ok := fantasy.AsToolResultOutputType[fantasy.ToolResultOutputContentText](result.Result)
288				if ok {
289					resultContent = r.Text
290				}
291			case fantasy.ToolResultContentTypeError:
292				r, ok := fantasy.AsToolResultOutputType[fantasy.ToolResultOutputContentError](result.Result)
293				if ok {
294					isError = true
295					resultContent = r.Error.Error()
296				}
297			case fantasy.ToolResultContentTypeMedia:
298				// TODO: handle this message type
299			}
300			toolResult := message.ToolResult{
301				ToolCallID: result.ToolCallID,
302				Name:       result.ToolName,
303				Content:    resultContent,
304				IsError:    isError,
305				Metadata:   result.ClientMetadata,
306			}
307			_, createMsgErr := a.messages.Create(genCtx, currentAssistant.SessionID, message.CreateMessageParams{
308				Role: message.Tool,
309				Parts: []message.ContentPart{
310					toolResult,
311				},
312			})
313			if createMsgErr != nil {
314				return createMsgErr
315			}
316			return nil
317		},
318		OnStepFinish: func(stepResult fantasy.StepResult) error {
319			finishReason := message.FinishReasonUnknown
320			switch stepResult.FinishReason {
321			case fantasy.FinishReasonLength:
322				finishReason = message.FinishReasonMaxTokens
323			case fantasy.FinishReasonStop:
324				finishReason = message.FinishReasonEndTurn
325			case fantasy.FinishReasonToolCalls:
326				finishReason = message.FinishReasonToolUse
327			}
328			currentAssistant.AddFinish(finishReason, "", "")
329			a.updateSessionUsage(a.largeModel, &currentSession, stepResult.Usage)
330			sessionLock.Lock()
331			_, sessionErr := a.sessions.Save(genCtx, currentSession)
332			sessionLock.Unlock()
333			if sessionErr != nil {
334				return sessionErr
335			}
336			return a.messages.Update(genCtx, *currentAssistant)
337		},
338		StopWhen: []fantasy.StopCondition{
339			func(_ []fantasy.StepResult) bool {
340				contextWindow := a.largeModel.CatwalkCfg.ContextWindow
341				tokens := currentSession.CompletionTokens + currentSession.PromptTokens
342				percentage := (float64(tokens) / float64(contextWindow)) * 100
343				if (percentage > 80) && !a.disableAutoSummarize {
344					shouldSummarize = true
345					return true
346				}
347				return false
348			},
349		},
350	})
351	if err != nil {
352		isCancelErr := errors.Is(err, context.Canceled)
353		isPermissionErr := errors.Is(err, permission.ErrorPermissionDenied)
354		if currentAssistant == nil {
355			return result, err
356		}
357		// Ensure we finish thinking on error to close the reasoning state
358		currentAssistant.FinishThinking()
359		toolCalls := currentAssistant.ToolCalls()
360		// INFO: we use the parent context here because the genCtx has been cancelled
361		msgs, createErr := a.messages.List(ctx, currentAssistant.SessionID)
362		if createErr != nil {
363			return nil, createErr
364		}
365		for _, tc := range toolCalls {
366			if !tc.Finished {
367				tc.Finished = true
368				tc.Input = "{}"
369				currentAssistant.AddToolCall(tc)
370				updateErr := a.messages.Update(ctx, *currentAssistant)
371				if updateErr != nil {
372					return nil, updateErr
373				}
374			}
375
376			found := false
377			for _, msg := range msgs {
378				if msg.Role == message.Tool {
379					for _, tr := range msg.ToolResults() {
380						if tr.ToolCallID == tc.ID {
381							found = true
382							break
383						}
384					}
385				}
386				if found {
387					break
388				}
389			}
390			if found {
391				continue
392			}
393			content := "There was an error while executing the tool"
394			if isCancelErr {
395				content = "Tool execution canceled by user"
396			} else if isPermissionErr {
397				content = "Permission denied"
398			}
399			toolResult := message.ToolResult{
400				ToolCallID: tc.ID,
401				Name:       tc.Name,
402				Content:    content,
403				IsError:    true,
404			}
405			_, createErr = a.messages.Create(context.Background(), currentAssistant.SessionID, message.CreateMessageParams{
406				Role: message.Tool,
407				Parts: []message.ContentPart{
408					toolResult,
409				},
410			})
411			if createErr != nil {
412				return nil, createErr
413			}
414		}
415		if isCancelErr {
416			currentAssistant.AddFinish(message.FinishReasonCanceled, "Request cancelled", "")
417		} else if isPermissionErr {
418			currentAssistant.AddFinish(message.FinishReasonPermissionDenied, "Permission denied", "")
419		} else {
420			currentAssistant.AddFinish(message.FinishReasonError, "API Error", err.Error())
421		}
422		// INFO: we use the parent context here because the genCtx has been cancelled
423		updateErr := a.messages.Update(ctx, *currentAssistant)
424		if updateErr != nil {
425			return nil, updateErr
426		}
427		return nil, err
428	}
429	wg.Wait()
430
431	if shouldSummarize {
432		a.activeRequests.Del(call.SessionID)
433		if summarizeErr := a.Summarize(genCtx, call.SessionID, call.ProviderOptions); summarizeErr != nil {
434			return nil, summarizeErr
435		}
436		// if the agent was not done...
437		if len(currentAssistant.ToolCalls()) > 0 {
438			existing, ok := a.messageQueue.Get(call.SessionID)
439			if !ok {
440				existing = []SessionAgentCall{}
441			}
442			call.Prompt = fmt.Sprintf("The previous session was interrupted because it got too long, the initial user request was: `%s`", call.Prompt)
443			existing = append(existing, call)
444			a.messageQueue.Set(call.SessionID, existing)
445		}
446	}
447
448	// release active request before processing queued messages
449	a.activeRequests.Del(call.SessionID)
450	cancel()
451
452	queuedMessages, ok := a.messageQueue.Get(call.SessionID)
453	if !ok || len(queuedMessages) == 0 {
454		return result, err
455	}
456	// there are queued messages restart the loop
457	firstQueuedMessage := queuedMessages[0]
458	a.messageQueue.Set(call.SessionID, queuedMessages[1:])
459	return a.Run(ctx, firstQueuedMessage)
460}
461
462func (a *sessionAgent) Summarize(ctx context.Context, sessionID string, opts fantasy.ProviderOptions) error {
463	if a.IsSessionBusy(sessionID) {
464		return ErrSessionBusy
465	}
466
467	currentSession, err := a.sessions.Get(ctx, sessionID)
468	if err != nil {
469		return fmt.Errorf("failed to get session: %w", err)
470	}
471	msgs, err := a.getSessionMessages(ctx, currentSession)
472	if err != nil {
473		return err
474	}
475	if len(msgs) == 0 {
476		// nothing to summarize
477		return nil
478	}
479
480	aiMsgs, _ := a.preparePrompt(msgs)
481
482	genCtx, cancel := context.WithCancel(ctx)
483	a.activeRequests.Set(sessionID, cancel)
484	defer a.activeRequests.Del(sessionID)
485	defer cancel()
486
487	agent := fantasy.NewAgent(a.largeModel.Model,
488		fantasy.WithSystemPrompt(string(summaryPrompt)),
489	)
490	summaryMessage, err := a.messages.Create(ctx, sessionID, message.CreateMessageParams{
491		Role:             message.Assistant,
492		Model:            a.largeModel.Model.Model(),
493		Provider:         a.largeModel.Model.Provider(),
494		IsSummaryMessage: true,
495	})
496	if err != nil {
497		return err
498	}
499
500	resp, err := agent.Stream(genCtx, fantasy.AgentStreamCall{
501		Prompt:          "Provide a detailed summary of our conversation above.",
502		Messages:        aiMsgs,
503		ProviderOptions: opts,
504		OnReasoningDelta: func(id string, text string) error {
505			summaryMessage.AppendReasoningContent(text)
506			return a.messages.Update(genCtx, summaryMessage)
507		},
508		OnReasoningEnd: func(id string, reasoning fantasy.ReasoningContent) error {
509			// handle anthropic signature
510			if anthropicData, ok := reasoning.ProviderMetadata["anthropic"]; ok {
511				if signature, ok := anthropicData.(*anthropic.ReasoningOptionMetadata); ok && signature.Signature != "" {
512					summaryMessage.AppendReasoningSignature(signature.Signature)
513				}
514			}
515			summaryMessage.FinishThinking()
516			return a.messages.Update(genCtx, summaryMessage)
517		},
518		OnTextDelta: func(id, text string) error {
519			summaryMessage.AppendContent(text)
520			return a.messages.Update(genCtx, summaryMessage)
521		},
522	})
523	if err != nil {
524		isCancelErr := errors.Is(err, context.Canceled)
525		if isCancelErr {
526			// User cancelled summarize we need to remove the summary message
527			deleteErr := a.messages.Delete(ctx, summaryMessage.ID)
528			return deleteErr
529		}
530		return err
531	}
532
533	summaryMessage.AddFinish(message.FinishReasonEndTurn, "", "")
534	err = a.messages.Update(genCtx, summaryMessage)
535	if err != nil {
536		return err
537	}
538
539	a.updateSessionUsage(a.largeModel, &currentSession, resp.TotalUsage)
540
541	// just in case get just the last usage
542	usage := resp.Response.Usage
543	currentSession.SummaryMessageID = summaryMessage.ID
544	currentSession.CompletionTokens = usage.OutputTokens
545	currentSession.PromptTokens = 0
546	_, err = a.sessions.Save(genCtx, currentSession)
547	return err
548}
549
550func (a *sessionAgent) getCacheControlOptions() fantasy.ProviderOptions {
551	if t, _ := strconv.ParseBool(os.Getenv("CRUSH_DISABLE_ANTHROPIC_CACHE")); t {
552		return fantasy.ProviderOptions{}
553	}
554	return fantasy.ProviderOptions{
555		anthropic.Name: &anthropic.ProviderCacheControlOptions{
556			CacheControl: anthropic.CacheControl{Type: "ephemeral"},
557		},
558		bedrock.Name: &anthropic.ProviderCacheControlOptions{
559			CacheControl: anthropic.CacheControl{Type: "ephemeral"},
560		},
561	}
562}
563
564func (a *sessionAgent) createUserMessage(ctx context.Context, call SessionAgentCall) (message.Message, error) {
565	var attachmentParts []message.ContentPart
566	for _, attachment := range call.Attachments {
567		attachmentParts = append(attachmentParts, message.BinaryContent{Path: attachment.FilePath, MIMEType: attachment.MimeType, Data: attachment.Content})
568	}
569	parts := []message.ContentPart{message.TextContent{Text: call.Prompt}}
570	parts = append(parts, attachmentParts...)
571	msg, err := a.messages.Create(ctx, call.SessionID, message.CreateMessageParams{
572		Role:  message.User,
573		Parts: parts,
574	})
575	if err != nil {
576		return message.Message{}, fmt.Errorf("failed to create user message: %w", err)
577	}
578	return msg, nil
579}
580
581func (a *sessionAgent) preparePrompt(msgs []message.Message, attachments ...message.Attachment) ([]fantasy.Message, []fantasy.FilePart) {
582	var history []fantasy.Message
583	for _, m := range msgs {
584		if len(m.Parts) == 0 {
585			continue
586		}
587		// Assistant message without content or tool calls (cancelled before it returned anything)
588		if m.Role == message.Assistant && len(m.ToolCalls()) == 0 && m.Content().Text == "" && m.ReasoningContent().String() == "" {
589			continue
590		}
591		history = append(history, m.ToAIMessage()...)
592	}
593
594	var files []fantasy.FilePart
595	for _, attachment := range attachments {
596		files = append(files, fantasy.FilePart{
597			Filename:  attachment.FileName,
598			Data:      attachment.Content,
599			MediaType: attachment.MimeType,
600		})
601	}
602
603	return history, files
604}
605
606func (a *sessionAgent) getSessionMessages(ctx context.Context, session session.Session) ([]message.Message, error) {
607	msgs, err := a.messages.List(ctx, session.ID)
608	if err != nil {
609		return nil, fmt.Errorf("failed to list messages: %w", err)
610	}
611
612	if session.SummaryMessageID != "" {
613		summaryMsgInex := -1
614		for i, msg := range msgs {
615			if msg.ID == session.SummaryMessageID {
616				summaryMsgInex = i
617				break
618			}
619		}
620		if summaryMsgInex != -1 {
621			msgs = msgs[summaryMsgInex:]
622			msgs[0].Role = message.User
623		}
624	}
625	return msgs, nil
626}
627
628func (a *sessionAgent) generateTitle(ctx context.Context, session *session.Session, prompt string) {
629	if prompt == "" {
630		return
631	}
632
633	var maxOutput int64 = 40
634	if a.smallModel.CatwalkCfg.CanReason {
635		maxOutput = a.smallModel.CatwalkCfg.DefaultMaxTokens
636	}
637
638	agent := fantasy.NewAgent(a.smallModel.Model,
639		fantasy.WithSystemPrompt(string(titlePrompt)+"\n /no_think"),
640		fantasy.WithMaxOutputTokens(maxOutput),
641	)
642
643	resp, err := agent.Stream(ctx, fantasy.AgentStreamCall{
644		Prompt: fmt.Sprintf("Generate a concise title for the following content:\n\n%s\n <think>\n\n</think>", prompt),
645	})
646	if err != nil {
647		slog.Error("error generating title", "err", err)
648		return
649	}
650
651	title := resp.Response.Content.Text()
652
653	title = strings.ReplaceAll(title, "\n", " ")
654
655	// remove thinking tags if present
656	if idx := strings.Index(title, "</think>"); idx > 0 {
657		title = title[idx+len("</think>"):]
658	}
659
660	title = strings.TrimSpace(title)
661	if title == "" {
662		slog.Warn("failed to generate title", "warn", "empty title")
663		return
664	}
665
666	session.Title = title
667	a.updateSessionUsage(a.smallModel, session, resp.TotalUsage)
668	_, saveErr := a.sessions.Save(ctx, *session)
669	if saveErr != nil {
670		slog.Error("failed to save session title & usage", "error", saveErr)
671		return
672	}
673}
674
675func (a *sessionAgent) updateSessionUsage(model Model, session *session.Session, usage fantasy.Usage) {
676	modelConfig := model.CatwalkCfg
677	cost := modelConfig.CostPer1MInCached/1e6*float64(usage.CacheCreationTokens) +
678		modelConfig.CostPer1MOutCached/1e6*float64(usage.CacheReadTokens) +
679		modelConfig.CostPer1MIn/1e6*float64(usage.InputTokens) +
680		modelConfig.CostPer1MOut/1e6*float64(usage.OutputTokens)
681	session.Cost += cost
682	session.CompletionTokens = usage.OutputTokens + usage.CacheReadTokens
683	session.PromptTokens = usage.InputTokens + usage.CacheCreationTokens
684}
685
686func (a *sessionAgent) Cancel(sessionID string) {
687	// Cancel regular requests
688	if cancel, ok := a.activeRequests.Take(sessionID); ok && cancel != nil {
689		slog.Info("Request cancellation initiated", "session_id", sessionID)
690		cancel()
691	}
692
693	// Also check for summarize requests
694	if cancel, ok := a.activeRequests.Take(sessionID + "-summarize"); ok && cancel != nil {
695		slog.Info("Summarize cancellation initiated", "session_id", sessionID)
696		cancel()
697	}
698
699	if a.QueuedPrompts(sessionID) > 0 {
700		slog.Info("Clearing queued prompts", "session_id", sessionID)
701		a.messageQueue.Del(sessionID)
702	}
703}
704
705func (a *sessionAgent) ClearQueue(sessionID string) {
706	if a.QueuedPrompts(sessionID) > 0 {
707		slog.Info("Clearing queued prompts", "session_id", sessionID)
708		a.messageQueue.Del(sessionID)
709	}
710}
711
712func (a *sessionAgent) CancelAll() {
713	if !a.IsBusy() {
714		return
715	}
716	for key := range a.activeRequests.Seq2() {
717		a.Cancel(key) // key is sessionID
718	}
719
720	timeout := time.After(5 * time.Second)
721	for a.IsBusy() {
722		select {
723		case <-timeout:
724			return
725		default:
726			time.Sleep(200 * time.Millisecond)
727		}
728	}
729}
730
731func (a *sessionAgent) IsBusy() bool {
732	var busy bool
733	for cancelFunc := range a.activeRequests.Seq() {
734		if cancelFunc != nil {
735			busy = true
736			break
737		}
738	}
739	return busy
740}
741
742func (a *sessionAgent) IsSessionBusy(sessionID string) bool {
743	_, busy := a.activeRequests.Get(sessionID)
744	return busy
745}
746
747func (a *sessionAgent) QueuedPrompts(sessionID string) int {
748	l, ok := a.messageQueue.Get(sessionID)
749	if !ok {
750		return 0
751	}
752	return len(l)
753}
754
755func (a *sessionAgent) SetModels(large Model, small Model) {
756	a.largeModel = large
757	a.smallModel = small
758}
759
760func (a *sessionAgent) SetTools(tools []fantasy.AgentTool) {
761	a.tools = tools
762}
763
764func (a *sessionAgent) Model() Model {
765	return a.largeModel
766}