agent.go

  1package agent
  2
  3import (
  4	"context"
  5	_ "embed"
  6	"errors"
  7	"fmt"
  8	"log/slog"
  9	"strings"
 10	"sync"
 11	"time"
 12
 13	"charm.land/fantasy"
 14	"charm.land/fantasy/providers/anthropic"
 15	"charm.land/fantasy/providers/openai"
 16	"github.com/charmbracelet/catwalk/pkg/catwalk"
 17	"github.com/charmbracelet/crush/internal/agent/tools"
 18	"github.com/charmbracelet/crush/internal/config"
 19	"github.com/charmbracelet/crush/internal/csync"
 20	"github.com/charmbracelet/crush/internal/message"
 21	"github.com/charmbracelet/crush/internal/permission"
 22	"github.com/charmbracelet/crush/internal/session"
 23)
 24
// titlePrompt holds the contents of templates/title.md, embedded at build
// time. generateTitle uses it as the system prompt of the title request.
//
//go:embed templates/title.md
var titlePrompt []byte

// summaryPrompt holds the contents of templates/summary.md, embedded at build
// time. Summarize uses it as the system prompt of the summary request.
//
//go:embed templates/summary.md
var summaryPrompt []byte
 30
// SessionAgentCall describes one prompt submitted to a session agent together
// with the per-request generation settings that Run forwards to the model.
type SessionAgentCall struct {
	// SessionID names the session the prompt belongs to. Run rejects an empty
	// value with ErrSessionMissing.
	SessionID        string
	// Prompt is the user text. Run rejects an empty value with ErrEmptyPrompt.
	Prompt           string
	// ProviderOptions is passed through unchanged to the stream call.
	ProviderOptions  fantasy.ProviderOptions
	// Attachments are stored as binary parts of the user message and are
	// handed to the model as files.
	Attachments      []message.Attachment
	// MaxOutputTokens and the sampling settings below are forwarded as-is to
	// fantasy.AgentStreamCall.
	MaxOutputTokens  int64
	Temperature      *float64
	TopP             *float64
	TopK             *int64
	FrequencyPenalty *float64
	PresencePenalty  *float64
}
 43
// SessionAgent runs prompts against chat sessions and keeps track of the
// active and queued requests of each session. The method comments describe
// the behavior of the sessionAgent implementation in this file.
type SessionAgent interface {
	// Run submits a prompt for the session named in the call. When that
	// session already has an active request the call is queued and Run
	// returns (nil, nil).
	Run(context.Context, SessionAgentCall) (*fantasy.AgentResult, error)
	// SetModels replaces the large and small models used by the agent.
	SetModels(large Model, small Model)
	// SetTools replaces the tools offered to the model.
	SetTools(tools []fantasy.AgentTool)
	// Cancel cancels the active request of the session, if any, and drops the
	// prompts queued for it.
	Cancel(sessionID string)
	// CancelAll cancels the active requests of every session and waits up to
	// five seconds for them to finish.
	CancelAll()
	// IsSessionBusy reports whether the session has an active request.
	IsSessionBusy(sessionID string) bool
	// IsBusy reports whether any session has an active request.
	IsBusy() bool
	// QueuedPrompts returns the number of prompts queued for the session.
	QueuedPrompts(sessionID string) int
	// ClearQueue drops the prompts queued for the session.
	ClearQueue(sessionID string)
	// Summarize generates a summary of the session whose ID is the string
	// argument and records it as the session's summary message.
	Summarize(context.Context, string, fantasy.ProviderOptions) error
	// Model returns the large model currently in use.
	Model() Model
}
 57
// Model bundles a language model with the two pieces of configuration this
// package reads alongside it.
type Model struct {
	// Model is the language model the fantasy agent talks to.
	Model      fantasy.LanguageModel
	// CatwalkCfg supplies the context window, pricing, reasoning flag and
	// default max-token values read by this package.
	CatwalkCfg catwalk.Model
	// ModelCfg supplies the model and provider identifiers recorded on
	// assistant messages created by Run.
	ModelCfg   config.SelectedModel
}
 63
// sessionAgent is the SessionAgent implementation returned by
// NewSessionAgent.
type sessionAgent struct {
	// largeModel serves prompts and summaries; smallModel generates titles.
	largeModel           Model
	smallModel           Model
	// systemPrompt is the system prompt of every Run request.
	systemPrompt         string
	// tools are the tools offered to the model in Run.
	tools                []fantasy.AgentTool
	sessions             session.Service
	messages             message.Service
	// disableAutoSummarize turns off the context-window check that makes Run
	// stop and summarize the session.
	disableAutoSummarize bool

	// messageQueue holds, per session ID, the calls that arrived while the
	// session was busy.
	messageQueue   *csync.Map[string, []SessionAgentCall]
	// activeRequests holds, per session ID, the cancel function of the
	// request currently in flight; presence of a key marks the session busy.
	activeRequests *csync.Map[string, context.CancelFunc]
}
 76
// SessionAgentOptions carries the dependencies and settings NewSessionAgent
// copies into a new agent.
type SessionAgentOptions struct {
	// LargeModel serves prompts and summaries; SmallModel generates titles.
	LargeModel           Model
	SmallModel           Model
	// SystemPrompt is the system prompt of every Run request.
	SystemPrompt         string
	// DisableAutoSummarize turns off automatic summarization in Run.
	DisableAutoSummarize bool
	Sessions             session.Service
	Messages             message.Service
	// Tools are the tools offered to the model in Run.
	Tools                []fantasy.AgentTool
}
 86
 87func NewSessionAgent(
 88	opts SessionAgentOptions,
 89) SessionAgent {
 90	return &sessionAgent{
 91		largeModel:           opts.LargeModel,
 92		smallModel:           opts.SmallModel,
 93		systemPrompt:         opts.SystemPrompt,
 94		sessions:             opts.Sessions,
 95		messages:             opts.Messages,
 96		disableAutoSummarize: opts.DisableAutoSummarize,
 97		tools:                opts.Tools,
 98		messageQueue:         csync.NewMap[string, []SessionAgentCall](),
 99		activeRequests:       csync.NewMap[string, context.CancelFunc](),
100	}
101}
102
103func (a *sessionAgent) Run(ctx context.Context, call SessionAgentCall) (*fantasy.AgentResult, error) {
104	if call.Prompt == "" {
105		return nil, ErrEmptyPrompt
106	}
107	if call.SessionID == "" {
108		return nil, ErrSessionMissing
109	}
110
111	// Queue the message if busy
112	if a.IsSessionBusy(call.SessionID) {
113		existing, ok := a.messageQueue.Get(call.SessionID)
114		if !ok {
115			existing = []SessionAgentCall{}
116		}
117		existing = append(existing, call)
118		a.messageQueue.Set(call.SessionID, existing)
119		return nil, nil
120	}
121
122	if len(a.tools) > 0 {
123		// add anthropic caching to the last tool
124		a.tools[len(a.tools)-1].SetProviderOptions(a.getCacheControlOptions())
125	}
126
127	agent := fantasy.NewAgent(
128		a.largeModel.Model,
129		fantasy.WithSystemPrompt(a.systemPrompt),
130		fantasy.WithTools(a.tools...),
131	)
132
133	sessionLock := sync.Mutex{}
134	currentSession, err := a.sessions.Get(ctx, call.SessionID)
135	if err != nil {
136		return nil, fmt.Errorf("failed to get session: %w", err)
137	}
138
139	msgs, err := a.getSessionMessages(ctx, currentSession)
140	if err != nil {
141		return nil, fmt.Errorf("failed to get session messages: %w", err)
142	}
143
144	var wg sync.WaitGroup
145	// Generate title if first message
146	if len(msgs) == 0 {
147		wg.Go(func() {
148			sessionLock.Lock()
149			a.generateTitle(ctx, &currentSession, call.Prompt)
150			sessionLock.Unlock()
151		})
152	}
153
154	// Add the user message to the session
155	_, err = a.createUserMessage(ctx, call)
156	if err != nil {
157		return nil, err
158	}
159
160	// add the session to the context
161	ctx = context.WithValue(ctx, tools.SessionIDContextKey, call.SessionID)
162
163	genCtx, cancel := context.WithCancel(ctx)
164	a.activeRequests.Set(call.SessionID, cancel)
165
166	defer cancel()
167	defer a.activeRequests.Del(call.SessionID)
168
169	history, files := a.preparePrompt(msgs, call.Attachments...)
170
171	var currentAssistant *message.Message
172	var shouldSummarize bool
173	result, err := agent.Stream(genCtx, fantasy.AgentStreamCall{
174		Prompt:           call.Prompt,
175		Files:            files,
176		Messages:         history,
177		ProviderOptions:  call.ProviderOptions,
178		MaxOutputTokens:  &call.MaxOutputTokens,
179		TopP:             call.TopP,
180		Temperature:      call.Temperature,
181		PresencePenalty:  call.PresencePenalty,
182		TopK:             call.TopK,
183		FrequencyPenalty: call.FrequencyPenalty,
184		// Before each step create the new assistant message
185		PrepareStep: func(callContext context.Context, options fantasy.PrepareStepFunctionOptions) (_ context.Context, prepared fantasy.PrepareStepResult, err error) {
186			prepared.Messages = options.Messages
187			// reset all cached items
188			for i := range prepared.Messages {
189				prepared.Messages[i].ProviderOptions = nil
190			}
191
192			queuedCalls, _ := a.messageQueue.Get(call.SessionID)
193			a.messageQueue.Del(call.SessionID)
194			for _, queued := range queuedCalls {
195				userMessage, createErr := a.createUserMessage(callContext, queued)
196				if createErr != nil {
197					return callContext, prepared, createErr
198				}
199				prepared.Messages = append(prepared.Messages, userMessage.ToAIMessage()...)
200			}
201
202			lastSystemRoleInx := 0
203			systemMessageUpdated := false
204			for i, msg := range prepared.Messages {
205				// only add cache control to the last message
206				if msg.Role == fantasy.MessageRoleSystem {
207					lastSystemRoleInx = i
208				} else if !systemMessageUpdated {
209					prepared.Messages[lastSystemRoleInx].ProviderOptions = a.getCacheControlOptions()
210					systemMessageUpdated = true
211				}
212				// than add cache control to the last 2 messages
213				if i > len(prepared.Messages)-3 {
214					prepared.Messages[i].ProviderOptions = a.getCacheControlOptions()
215				}
216			}
217
218			var assistantMsg message.Message
219			assistantMsg, err = a.messages.Create(callContext, call.SessionID, message.CreateMessageParams{
220				Role:     message.Assistant,
221				Parts:    []message.ContentPart{},
222				Model:    a.largeModel.ModelCfg.Model,
223				Provider: a.largeModel.ModelCfg.Provider,
224			})
225			if err != nil {
226				return callContext, prepared, err
227			}
228			callContext = context.WithValue(callContext, tools.MessageIDContextKey, assistantMsg.ID)
229			currentAssistant = &assistantMsg
230			return callContext, prepared, err
231		},
232		OnReasoningDelta: func(id string, text string) error {
233			currentAssistant.AppendReasoningContent(text)
234			return a.messages.Update(genCtx, *currentAssistant)
235		},
236		OnReasoningEnd: func(id string, reasoning fantasy.ReasoningContent) error {
237			// handle anthropic signature
238			if anthropicData, ok := reasoning.ProviderMetadata[anthropic.Name]; ok {
239				if reasoning, ok := anthropicData.(*anthropic.ReasoningOptionMetadata); ok {
240					currentAssistant.AppendReasoningSignature(reasoning.Signature)
241				}
242			}
243			if openaiData, ok := reasoning.ProviderMetadata[openai.Name]; ok {
244				if reasoning, ok := openaiData.(*openai.ResponsesReasoningMetadata); ok {
245					currentAssistant.SetReasoningResponsesData(reasoning)
246				}
247			}
248			currentAssistant.FinishThinking()
249			return a.messages.Update(genCtx, *currentAssistant)
250		},
251		OnTextDelta: func(id string, text string) error {
252			currentAssistant.AppendContent(text)
253			return a.messages.Update(genCtx, *currentAssistant)
254		},
255		OnToolInputStart: func(id string, toolName string) error {
256			toolCall := message.ToolCall{
257				ID:               id,
258				Name:             toolName,
259				ProviderExecuted: false,
260				Finished:         false,
261			}
262			currentAssistant.AddToolCall(toolCall)
263			return a.messages.Update(genCtx, *currentAssistant)
264		},
265		OnRetry: func(err *fantasy.APICallError, delay time.Duration) {
266			// TODO: implement
267		},
268		OnToolCall: func(tc fantasy.ToolCallContent) error {
269			toolCall := message.ToolCall{
270				ID:               tc.ToolCallID,
271				Name:             tc.ToolName,
272				Input:            tc.Input,
273				ProviderExecuted: false,
274				Finished:         true,
275			}
276			currentAssistant.AddToolCall(toolCall)
277			return a.messages.Update(genCtx, *currentAssistant)
278		},
279		OnToolResult: func(result fantasy.ToolResultContent) error {
280			var resultContent string
281			isError := false
282			switch result.Result.GetType() {
283			case fantasy.ToolResultContentTypeText:
284				r, ok := fantasy.AsToolResultOutputType[fantasy.ToolResultOutputContentText](result.Result)
285				if ok {
286					resultContent = r.Text
287				}
288			case fantasy.ToolResultContentTypeError:
289				r, ok := fantasy.AsToolResultOutputType[fantasy.ToolResultOutputContentError](result.Result)
290				if ok {
291					isError = true
292					resultContent = r.Error.Error()
293				}
294			case fantasy.ToolResultContentTypeMedia:
295				// TODO: handle this message type
296			}
297			toolResult := message.ToolResult{
298				ToolCallID: result.ToolCallID,
299				Name:       result.ToolName,
300				Content:    resultContent,
301				IsError:    isError,
302				Metadata:   result.ClientMetadata,
303			}
304			_, createMsgErr := a.messages.Create(genCtx, currentAssistant.SessionID, message.CreateMessageParams{
305				Role: message.Tool,
306				Parts: []message.ContentPart{
307					toolResult,
308				},
309			})
310			if createMsgErr != nil {
311				return createMsgErr
312			}
313			return nil
314		},
315		OnStepFinish: func(stepResult fantasy.StepResult) error {
316			finishReason := message.FinishReasonUnknown
317			switch stepResult.FinishReason {
318			case fantasy.FinishReasonLength:
319				finishReason = message.FinishReasonMaxTokens
320			case fantasy.FinishReasonStop:
321				finishReason = message.FinishReasonEndTurn
322			case fantasy.FinishReasonToolCalls:
323				finishReason = message.FinishReasonToolUse
324			}
325			currentAssistant.AddFinish(finishReason, "", "")
326			a.updateSessionUsage(a.largeModel, &currentSession, stepResult.Usage)
327			sessionLock.Lock()
328			_, sessionErr := a.sessions.Save(genCtx, currentSession)
329			sessionLock.Unlock()
330			if sessionErr != nil {
331				return sessionErr
332			}
333			return a.messages.Update(genCtx, *currentAssistant)
334		},
335		StopWhen: []fantasy.StopCondition{
336			func(_ []fantasy.StepResult) bool {
337				contextWindow := a.largeModel.CatwalkCfg.ContextWindow
338				tokens := currentSession.CompletionTokens + currentSession.PromptTokens
339				percentage := (float64(tokens) / float64(contextWindow)) * 100
340				if (percentage > 80) && !a.disableAutoSummarize {
341					shouldSummarize = true
342					return true
343				}
344				return false
345			},
346		},
347	})
348	if err != nil {
349		isCancelErr := errors.Is(err, context.Canceled)
350		isPermissionErr := errors.Is(err, permission.ErrorPermissionDenied)
351		if currentAssistant == nil {
352			return result, err
353		}
354		// Ensure we finish thinking on error to close the reasoning state
355		currentAssistant.FinishThinking()
356		toolCalls := currentAssistant.ToolCalls()
357		// INFO: we use the parent context here because the genCtx has been cancelled
358		msgs, createErr := a.messages.List(ctx, currentAssistant.SessionID)
359		if createErr != nil {
360			return nil, createErr
361		}
362		for _, tc := range toolCalls {
363			if !tc.Finished {
364				tc.Finished = true
365				tc.Input = "{}"
366				currentAssistant.AddToolCall(tc)
367				updateErr := a.messages.Update(ctx, *currentAssistant)
368				if updateErr != nil {
369					return nil, updateErr
370				}
371			}
372
373			found := false
374			for _, msg := range msgs {
375				if msg.Role == message.Tool {
376					for _, tr := range msg.ToolResults() {
377						if tr.ToolCallID == tc.ID {
378							found = true
379							break
380						}
381					}
382				}
383				if found {
384					break
385				}
386			}
387			if found {
388				continue
389			}
390			content := "There was an error while executing the tool"
391			if isCancelErr {
392				content = "Tool execution canceled by user"
393			} else if isPermissionErr {
394				content = "Permission denied"
395			}
396			toolResult := message.ToolResult{
397				ToolCallID: tc.ID,
398				Name:       tc.Name,
399				Content:    content,
400				IsError:    true,
401			}
402			_, createErr = a.messages.Create(context.Background(), currentAssistant.SessionID, message.CreateMessageParams{
403				Role: message.Tool,
404				Parts: []message.ContentPart{
405					toolResult,
406				},
407			})
408			if createErr != nil {
409				return nil, createErr
410			}
411		}
412		if isCancelErr {
413			currentAssistant.AddFinish(message.FinishReasonCanceled, "Request cancelled", "")
414		} else if isPermissionErr {
415			currentAssistant.AddFinish(message.FinishReasonPermissionDenied, "Permission denied", "")
416		} else {
417			currentAssistant.AddFinish(message.FinishReasonError, "API Error", err.Error())
418		}
419		// INFO: we use the parent context here because the genCtx has been cancelled
420		updateErr := a.messages.Update(ctx, *currentAssistant)
421		if updateErr != nil {
422			return nil, updateErr
423		}
424		return nil, err
425	}
426	wg.Wait()
427
428	if shouldSummarize {
429		a.activeRequests.Del(call.SessionID)
430		if summarizeErr := a.Summarize(genCtx, call.SessionID, call.ProviderOptions); summarizeErr != nil {
431			return nil, summarizeErr
432		}
433		// if the agent was not done...
434		if len(currentAssistant.ToolCalls()) > 0 {
435			existing, ok := a.messageQueue.Get(call.SessionID)
436			if !ok {
437				existing = []SessionAgentCall{}
438			}
439			call.Prompt = fmt.Sprintf("The previous session was interrupted because it got too long, the initial user request was: `%s`", call.Prompt)
440			existing = append(existing, call)
441			a.messageQueue.Set(call.SessionID, existing)
442		}
443	}
444
445	// release active request before processing queued messages
446	a.activeRequests.Del(call.SessionID)
447	cancel()
448
449	queuedMessages, ok := a.messageQueue.Get(call.SessionID)
450	if !ok || len(queuedMessages) == 0 {
451		return result, err
452	}
453	// there are queued messages restart the loop
454	firstQueuedMessage := queuedMessages[0]
455	a.messageQueue.Set(call.SessionID, queuedMessages[1:])
456	return a.Run(ctx, firstQueuedMessage)
457}
458
459func (a *sessionAgent) Summarize(ctx context.Context, sessionID string, opts fantasy.ProviderOptions) error {
460	if a.IsSessionBusy(sessionID) {
461		return ErrSessionBusy
462	}
463
464	currentSession, err := a.sessions.Get(ctx, sessionID)
465	if err != nil {
466		return fmt.Errorf("failed to get session: %w", err)
467	}
468	msgs, err := a.getSessionMessages(ctx, currentSession)
469	if err != nil {
470		return err
471	}
472	if len(msgs) == 0 {
473		// nothing to summarize
474		return nil
475	}
476
477	aiMsgs, _ := a.preparePrompt(msgs)
478
479	genCtx, cancel := context.WithCancel(ctx)
480	a.activeRequests.Set(sessionID, cancel)
481	defer a.activeRequests.Del(sessionID)
482	defer cancel()
483
484	agent := fantasy.NewAgent(a.largeModel.Model,
485		fantasy.WithSystemPrompt(string(summaryPrompt)),
486	)
487	summaryMessage, err := a.messages.Create(ctx, sessionID, message.CreateMessageParams{
488		Role:             message.Assistant,
489		Model:            a.largeModel.Model.Model(),
490		Provider:         a.largeModel.Model.Provider(),
491		IsSummaryMessage: true,
492	})
493	if err != nil {
494		return err
495	}
496
497	resp, err := agent.Stream(genCtx, fantasy.AgentStreamCall{
498		Prompt:          "Provide a detailed summary of our conversation above.",
499		Messages:        aiMsgs,
500		ProviderOptions: opts,
501		OnReasoningDelta: func(id string, text string) error {
502			summaryMessage.AppendReasoningContent(text)
503			return a.messages.Update(genCtx, summaryMessage)
504		},
505		OnReasoningEnd: func(id string, reasoning fantasy.ReasoningContent) error {
506			// handle anthropic signature
507			if anthropicData, ok := reasoning.ProviderMetadata["anthropic"]; ok {
508				if signature, ok := anthropicData.(*anthropic.ReasoningOptionMetadata); ok && signature.Signature != "" {
509					summaryMessage.AppendReasoningSignature(signature.Signature)
510				}
511			}
512			summaryMessage.FinishThinking()
513			return a.messages.Update(genCtx, summaryMessage)
514		},
515		OnTextDelta: func(id, text string) error {
516			summaryMessage.AppendContent(text)
517			return a.messages.Update(genCtx, summaryMessage)
518		},
519	})
520	if err != nil {
521		isCancelErr := errors.Is(err, context.Canceled)
522		if isCancelErr {
523			// User cancelled summarize we need to remove the summary message
524			deleteErr := a.messages.Delete(ctx, summaryMessage.ID)
525			return deleteErr
526		}
527		return err
528	}
529
530	summaryMessage.AddFinish(message.FinishReasonEndTurn, "", "")
531	err = a.messages.Update(genCtx, summaryMessage)
532	if err != nil {
533		return err
534	}
535
536	a.updateSessionUsage(a.largeModel, &currentSession, resp.TotalUsage)
537
538	// just in case get just the last usage
539	usage := resp.Response.Usage
540	currentSession.SummaryMessageID = summaryMessage.ID
541	currentSession.CompletionTokens = usage.OutputTokens
542	currentSession.PromptTokens = 0
543	_, err = a.sessions.Save(genCtx, currentSession)
544	return err
545}
546
547func (a *sessionAgent) getCacheControlOptions() fantasy.ProviderOptions {
548	return fantasy.ProviderOptions{
549		anthropic.Name: &anthropic.ProviderCacheControlOptions{
550			CacheControl: anthropic.CacheControl{Type: "ephemeral"},
551		},
552	}
553}
554
555func (a *sessionAgent) createUserMessage(ctx context.Context, call SessionAgentCall) (message.Message, error) {
556	var attachmentParts []message.ContentPart
557	for _, attachment := range call.Attachments {
558		attachmentParts = append(attachmentParts, message.BinaryContent{Path: attachment.FilePath, MIMEType: attachment.MimeType, Data: attachment.Content})
559	}
560	parts := []message.ContentPart{message.TextContent{Text: call.Prompt}}
561	parts = append(parts, attachmentParts...)
562	msg, err := a.messages.Create(ctx, call.SessionID, message.CreateMessageParams{
563		Role:  message.User,
564		Parts: parts,
565	})
566	if err != nil {
567		return message.Message{}, fmt.Errorf("failed to create user message: %w", err)
568	}
569	return msg, nil
570}
571
572func (a *sessionAgent) preparePrompt(msgs []message.Message, attachments ...message.Attachment) ([]fantasy.Message, []fantasy.FilePart) {
573	var history []fantasy.Message
574	for _, m := range msgs {
575		if len(m.Parts) == 0 {
576			continue
577		}
578		// Assistant message without content or tool calls (cancelled before it returned anything)
579		if m.Role == message.Assistant && len(m.ToolCalls()) == 0 && m.Content().Text == "" && m.ReasoningContent().String() == "" {
580			continue
581		}
582		history = append(history, m.ToAIMessage()...)
583	}
584
585	var files []fantasy.FilePart
586	for _, attachment := range attachments {
587		files = append(files, fantasy.FilePart{
588			Filename:  attachment.FileName,
589			Data:      attachment.Content,
590			MediaType: attachment.MimeType,
591		})
592	}
593
594	return history, files
595}
596
597func (a *sessionAgent) getSessionMessages(ctx context.Context, session session.Session) ([]message.Message, error) {
598	msgs, err := a.messages.List(ctx, session.ID)
599	if err != nil {
600		return nil, fmt.Errorf("failed to list messages: %w", err)
601	}
602
603	if session.SummaryMessageID != "" {
604		summaryMsgInex := -1
605		for i, msg := range msgs {
606			if msg.ID == session.SummaryMessageID {
607				summaryMsgInex = i
608				break
609			}
610		}
611		if summaryMsgInex != -1 {
612			msgs = msgs[summaryMsgInex:]
613			msgs[0].Role = message.User
614		}
615	}
616	return msgs, nil
617}
618
619func (a *sessionAgent) generateTitle(ctx context.Context, session *session.Session, prompt string) {
620	if prompt == "" {
621		return
622	}
623
624	var maxOutput int64 = 40
625	if a.smallModel.CatwalkCfg.CanReason {
626		maxOutput = a.smallModel.CatwalkCfg.DefaultMaxTokens
627	}
628
629	agent := fantasy.NewAgent(a.smallModel.Model,
630		fantasy.WithSystemPrompt(string(titlePrompt)+"\n /no_think"),
631		fantasy.WithMaxOutputTokens(maxOutput),
632	)
633
634	resp, err := agent.Stream(ctx, fantasy.AgentStreamCall{
635		Prompt: fmt.Sprintf("Generate a concise title for the following content:\n\n%s\n <think>\n\n</think>", prompt),
636	})
637	if err != nil {
638		slog.Error("error generating title", "err", err)
639		return
640	}
641
642	title := resp.Response.Content.Text()
643
644	title = strings.ReplaceAll(title, "\n", " ")
645
646	// remove thinking tags if present
647	if idx := strings.Index(title, "</think>"); idx > 0 {
648		title = title[idx+len("</think>"):]
649	}
650
651	title = strings.TrimSpace(title)
652	if title == "" {
653		slog.Warn("failed to generate title", "warn", "empty title")
654		return
655	}
656
657	session.Title = title
658	a.updateSessionUsage(a.smallModel, session, resp.TotalUsage)
659	_, saveErr := a.sessions.Save(ctx, *session)
660	if saveErr != nil {
661		slog.Error("failed to save session title & usage", "error", saveErr)
662		return
663	}
664}
665
666func (a *sessionAgent) updateSessionUsage(model Model, session *session.Session, usage fantasy.Usage) {
667	modelConfig := model.CatwalkCfg
668	cost := modelConfig.CostPer1MInCached/1e6*float64(usage.CacheCreationTokens) +
669		modelConfig.CostPer1MOutCached/1e6*float64(usage.CacheReadTokens) +
670		modelConfig.CostPer1MIn/1e6*float64(usage.InputTokens) +
671		modelConfig.CostPer1MOut/1e6*float64(usage.OutputTokens)
672	session.Cost += cost
673	session.CompletionTokens = usage.OutputTokens + usage.CacheReadTokens
674	session.PromptTokens = usage.InputTokens + usage.CacheCreationTokens
675}
676
677func (a *sessionAgent) Cancel(sessionID string) {
678	// Cancel regular requests
679	if cancel, ok := a.activeRequests.Take(sessionID); ok && cancel != nil {
680		slog.Info("Request cancellation initiated", "session_id", sessionID)
681		cancel()
682	}
683
684	// Also check for summarize requests
685	if cancel, ok := a.activeRequests.Take(sessionID + "-summarize"); ok && cancel != nil {
686		slog.Info("Summarize cancellation initiated", "session_id", sessionID)
687		cancel()
688	}
689
690	if a.QueuedPrompts(sessionID) > 0 {
691		slog.Info("Clearing queued prompts", "session_id", sessionID)
692		a.messageQueue.Del(sessionID)
693	}
694}
695
696func (a *sessionAgent) ClearQueue(sessionID string) {
697	if a.QueuedPrompts(sessionID) > 0 {
698		slog.Info("Clearing queued prompts", "session_id", sessionID)
699		a.messageQueue.Del(sessionID)
700	}
701}
702
703func (a *sessionAgent) CancelAll() {
704	if !a.IsBusy() {
705		return
706	}
707	for key := range a.activeRequests.Seq2() {
708		a.Cancel(key) // key is sessionID
709	}
710
711	timeout := time.After(5 * time.Second)
712	for a.IsBusy() {
713		select {
714		case <-timeout:
715			return
716		default:
717			time.Sleep(200 * time.Millisecond)
718		}
719	}
720}
721
722func (a *sessionAgent) IsBusy() bool {
723	var busy bool
724	for cancelFunc := range a.activeRequests.Seq() {
725		if cancelFunc != nil {
726			busy = true
727			break
728		}
729	}
730	return busy
731}
732
733func (a *sessionAgent) IsSessionBusy(sessionID string) bool {
734	_, busy := a.activeRequests.Get(sessionID)
735	return busy
736}
737
738func (a *sessionAgent) QueuedPrompts(sessionID string) int {
739	l, ok := a.messageQueue.Get(sessionID)
740	if !ok {
741		return 0
742	}
743	return len(l)
744}
745
746func (a *sessionAgent) SetModels(large Model, small Model) {
747	a.largeModel = large
748	a.smallModel = small
749}
750
751func (a *sessionAgent) SetTools(tools []fantasy.AgentTool) {
752	a.tools = tools
753}
754
755func (a *sessionAgent) Model() Model {
756	return a.largeModel
757}