agent.go

package agent

import (
	"context"
	_ "embed"
	"encoding/json"
	"errors"
	"fmt"
	"log/slog"
	"strings"
	"sync"
	"time"

	"github.com/charmbracelet/catwalk/pkg/catwalk"
	"github.com/charmbracelet/crush/internal/agent/tools"
	"github.com/charmbracelet/crush/internal/config"
	"github.com/charmbracelet/crush/internal/csync"
	"github.com/charmbracelet/crush/internal/message"
	"github.com/charmbracelet/crush/internal/permission"
	"github.com/charmbracelet/crush/internal/session"
	"github.com/charmbracelet/fantasy/ai"
	"github.com/charmbracelet/fantasy/anthropic"
)

//go:embed templates/title.md
var titlePrompt []byte

//go:embed templates/summary.md
var summaryPrompt []byte

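// SessionAgentCall describes a single prompt sent to a session, along with
// optional attachments and per-request sampling parameters.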
type SessionAgentCall struct {
	SessionID        string
	Prompt           string
	ProviderOptions  ai.ProviderOptions
	Attachments      []message.Attachment
	MaxOutputTokens  int64
	Temperature      *float64
	TopP             *float64
	TopK             *int64
	FrequencyPenalty *float64
	PresencePenalty  *float64
}

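// SessionAgent runs prompts against a session, queuing calls while the
// session is busy, and exposes cancellation, queue management, and
// summarization controls.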
type SessionAgent interface {
	Run(context.Context, SessionAgentCall) (*ai.AgentResult, error)
	SetModels(large Model, small Model)
	SetTools(tools []ai.AgentTool)
	Cancel(sessionID string)
	CancelAll()
	IsSessionBusy(sessionID string) bool
	IsBusy() bool
	QueuedPrompts(sessionID string) int
	ClearQueue(sessionID string)
	Summarize(context.Context, string) error
	Model() Model
}

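// Model bundles a fantasy language model with its catwalk metadata and the
// user's selected model configuration.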
type Model struct {
	Model      ai.LanguageModel
	CatwalkCfg catwalk.Model
	ModelCfg   config.SelectedModel
}

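// sessionAgent is the default SessionAgent implementation. It tracks one
// active request and one prompt queue per session ID.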
type sessionAgent struct {
	largeModel           Model
	smallModel           Model
	systemPrompt         string
	tools                []ai.AgentTool
	sessions             session.Service
	messages             message.Service
	disableAutoSummarize bool

	messageQueue   *csync.Map[string, []SessionAgentCall]
	activeRequests *csync.Map[string, context.CancelFunc]
}

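// SessionAgentOptions holds the models, services, and configuration used to
// construct a SessionAgent.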
type SessionAgentOptions struct {
	LargeModel           Model
	SmallModel           Model
	SystemPrompt         string
	DisableAutoSummarize bool
	Sessions             session.Service
	Messages             message.Service
	Tools                []ai.AgentTool
}

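// NewSessionAgent creates a SessionAgent from the given options.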
func NewSessionAgent(
	opts SessionAgentOptions,
) SessionAgent {
	return &sessionAgent{
		largeModel:           opts.LargeModel,
		smallModel:           opts.SmallModel,
		systemPrompt:         opts.SystemPrompt,
		sessions:             opts.Sessions,
		messages:             opts.Messages,
		disableAutoSummarize: opts.DisableAutoSummarize,
		tools:                opts.Tools,
		messageQueue:         csync.NewMap[string, []SessionAgentCall](),
		activeRequests:       csync.NewMap[string, context.CancelFunc](),
	}
}

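// Run sends a prompt to the session. If the session already has an active
// request the call is queued and (nil, nil) is returned. Otherwise it stores
// the user message, streams the assistant response (persisting deltas, tool
// calls, and tool results as messages), updates session usage, triggers an
// automatic summarization when usage passes 80% of the context window, and
// finally drains any prompts that were queued while the request was running.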
func (a *sessionAgent) Run(ctx context.Context, call SessionAgentCall) (*ai.AgentResult, error) {
	if call.Prompt == "" {
		return nil, ErrEmptyPrompt
	}
	if call.SessionID == "" {
		return nil, ErrSessionMissing
	}

	// Queue the message if the session is busy
	if a.IsSessionBusy(call.SessionID) {
		existing, ok := a.messageQueue.Get(call.SessionID)
		if !ok {
			existing = []SessionAgentCall{}
		}
		existing = append(existing, call)
		a.messageQueue.Set(call.SessionID, existing)
		return nil, nil
	}

	if len(a.tools) > 0 {
		// add anthropic caching to the last tool
		a.tools[len(a.tools)-1].SetProviderOptions(a.getCacheControlOptions())
	}

	agent := ai.NewAgent(
		a.largeModel.Model,
		ai.WithSystemPrompt(a.systemPrompt),
		ai.WithTools(a.tools...),
	)

	sessionLock := sync.Mutex{}
	currentSession, err := a.sessions.Get(ctx, call.SessionID)
	if err != nil {
		return nil, fmt.Errorf("failed to get session: %w", err)
	}

	msgs, err := a.getSessionMessages(ctx, currentSession)
	if err != nil {
		return nil, fmt.Errorf("failed to get session messages: %w", err)
	}

	var wg sync.WaitGroup
	// Generate a title if this is the first message
	if len(msgs) == 0 {
		wg.Go(func() {
			sessionLock.Lock()
			a.generateTitle(ctx, &currentSession, call.Prompt)
			sessionLock.Unlock()
		})
	}

	// Add the user message to the session
	_, err = a.createUserMessage(ctx, call)
	if err != nil {
		return nil, err
	}

	// add the session ID to the context so tools can access it
	ctx = context.WithValue(ctx, tools.SessionIDContextKey, call.SessionID)

	genCtx, cancel := context.WithCancel(ctx)
	a.activeRequests.Set(call.SessionID, cancel)

	defer cancel()
	defer a.activeRequests.Del(call.SessionID)

	history, files := a.preparePrompt(msgs, call.Attachments...)

	var currentAssistant *message.Message
	var shouldSummarize bool
	result, err := agent.Stream(genCtx, ai.AgentStreamCall{
		Prompt:           call.Prompt,
		Files:            files,
		Messages:         history,
		ProviderOptions:  call.ProviderOptions,
		MaxOutputTokens:  &call.MaxOutputTokens,
		TopP:             call.TopP,
		Temperature:      call.Temperature,
		PresencePenalty:  call.PresencePenalty,
		TopK:             call.TopK,
		FrequencyPenalty: call.FrequencyPenalty,
		// Before each step create the new assistant message
		PrepareStep: func(callContext context.Context, options ai.PrepareStepFunctionOptions) (_ context.Context, prepared ai.PrepareStepResult, err error) {
			var assistantMsg message.Message
			assistantMsg, err = a.messages.Create(callContext, call.SessionID, message.CreateMessageParams{
				Role:     message.Assistant,
				Parts:    []message.ContentPart{},
				Model:    a.largeModel.ModelCfg.Model,
				Provider: a.largeModel.ModelCfg.Provider,
			})
			if err != nil {
				return callContext, prepared, err
			}

			callContext = context.WithValue(callContext, tools.MessageIDContextKey, assistantMsg.ID)

			currentAssistant = &assistantMsg

			prepared.Messages = options.Messages
			// reset all cached items
			for i := range prepared.Messages {
				prepared.Messages[i].ProviderOptions = nil
			}

			queuedCalls, _ := a.messageQueue.Get(call.SessionID)
			a.messageQueue.Del(call.SessionID)
			for _, queued := range queuedCalls {
				userMessage, createErr := a.createUserMessage(callContext, queued)
				if createErr != nil {
					return callContext, prepared, createErr
				}
				prepared.Messages = append(prepared.Messages, userMessage.ToAIMessage()...)
			}

			lastSystemRoleIdx := 0
			systemMessageUpdated := false
			for i, msg := range prepared.Messages {
				// only add cache control to the last system message
				if msg.Role == ai.MessageRoleSystem {
					lastSystemRoleIdx = i
				} else if !systemMessageUpdated {
					prepared.Messages[lastSystemRoleIdx].ProviderOptions = a.getCacheControlOptions()
					systemMessageUpdated = true
				}
				// then add cache control to the last 2 messages
				if i > len(prepared.Messages)-3 {
					prepared.Messages[i].ProviderOptions = a.getCacheControlOptions()
				}
			}
			return callContext, prepared, err
		},
		OnReasoningDelta: func(id string, text string) error {
			currentAssistant.AppendReasoningContent(text)
			return a.messages.Update(genCtx, *currentAssistant)
		},
		OnReasoningEnd: func(id string, reasoning ai.ReasoningContent) error {
			// handle anthropic signature
			if anthropicData, ok := reasoning.ProviderMetadata[anthropic.Name]; ok {
				if reasoning, ok := anthropicData.(*anthropic.ReasoningOptionMetadata); ok {
					currentAssistant.AppendReasoningSignature(reasoning.Signature)
				}
			}
			currentAssistant.FinishThinking()
			return a.messages.Update(genCtx, *currentAssistant)
		},
		OnTextDelta: func(id string, text string) error {
			currentAssistant.AppendContent(text)
			return a.messages.Update(genCtx, *currentAssistant)
		},
		OnToolInputStart: func(id string, toolName string) error {
			toolCall := message.ToolCall{
				ID:               id,
				Name:             toolName,
				ProviderExecuted: false,
				Finished:         false,
			}
			currentAssistant.AddToolCall(toolCall)
			return a.messages.Update(genCtx, *currentAssistant)
		},
		OnRetry: func(err *ai.APICallError, delay time.Duration) {
			// TODO: implement
		},
		OnToolCall: func(tc ai.ToolCallContent) error {
			toolCall := message.ToolCall{
				ID:               tc.ToolCallID,
				Name:             tc.ToolName,
				Input:            tc.Input,
				ProviderExecuted: false,
				Finished:         true,
			}
			currentAssistant.AddToolCall(toolCall)
			return a.messages.Update(genCtx, *currentAssistant)
		},
		OnToolResult: func(result ai.ToolResultContent) error {
			var resultContent string
			isError := false
			switch result.Result.GetType() {
			case ai.ToolResultContentTypeText:
				r, ok := ai.AsToolResultOutputType[ai.ToolResultOutputContentText](result.Result)
				if ok {
					resultContent = r.Text
				}
			case ai.ToolResultContentTypeError:
				r, ok := ai.AsToolResultOutputType[ai.ToolResultOutputContentError](result.Result)
				if ok {
					isError = true
					resultContent = r.Error.Error()
				}
			case ai.ToolResultContentTypeMedia:
				// TODO: handle this message type
			}
			toolResult := message.ToolResult{
				ToolCallID: result.ToolCallID,
				Name:       result.ToolName,
				Content:    resultContent,
				IsError:    isError,
				Metadata:   result.ClientMetadata,
			}
			_, createMsgErr := a.messages.Create(genCtx, currentAssistant.SessionID, message.CreateMessageParams{
				Role: message.Tool,
				Parts: []message.ContentPart{
					toolResult,
				},
			})
			if createMsgErr != nil {
				return createMsgErr
			}
			return nil
		},
		OnStepFinish: func(stepResult ai.StepResult) error {
			finishReason := message.FinishReasonUnknown
			switch stepResult.FinishReason {
			case ai.FinishReasonLength:
				finishReason = message.FinishReasonMaxTokens
			case ai.FinishReasonStop:
				finishReason = message.FinishReasonEndTurn
			case ai.FinishReasonToolCalls:
				finishReason = message.FinishReasonToolUse
			}
			currentAssistant.AddFinish(finishReason, "", "")
			a.updateSessionUsage(a.largeModel, &currentSession, stepResult.Usage)
			sessionLock.Lock()
			_, sessionErr := a.sessions.Save(genCtx, currentSession)
			sessionLock.Unlock()
			if sessionErr != nil {
				return sessionErr
			}
			return a.messages.Update(genCtx, *currentAssistant)
		},
		StopWhen: []ai.StopCondition{
			func(_ []ai.StepResult) bool {
				contextWindow := a.largeModel.CatwalkCfg.ContextWindow
				tokens := currentSession.CompletionTokens + currentSession.PromptTokens
				percentage := (float64(tokens) / float64(contextWindow)) * 100
				if (percentage > 80) && !a.disableAutoSummarize {
					shouldSummarize = true
					return true
				}
				return false
			},
		},
	})
	if err != nil {
		isCancelErr := errors.Is(err, context.Canceled)
		isPermissionErr := errors.Is(err, permission.ErrorPermissionDenied)
		if currentAssistant == nil {
			return result, err
		}
		toolCalls := currentAssistant.ToolCalls()
		// INFO: we use the parent context here because the genCtx has been cancelled
		msgs, listErr := a.messages.List(ctx, currentAssistant.SessionID)
		if listErr != nil {
			return nil, listErr
		}
		for _, tc := range toolCalls {
			if !tc.Finished {
				tc.Finished = true
				tc.Input = "{}"
				currentAssistant.AddToolCall(tc)
				updateErr := a.messages.Update(ctx, *currentAssistant)
				if updateErr != nil {
					return nil, updateErr
				}
			}

			found := false
			for _, msg := range msgs {
				if msg.Role == message.Tool {
					for _, tr := range msg.ToolResults() {
						if tr.ToolCallID == tc.ID {
							found = true
							break
						}
					}
				}
				if found {
					break
				}
			}
			if found {
				continue
			}
			content := "There was an error while executing the tool"
			if isCancelErr {
				content = "Tool execution canceled by user"
			} else if isPermissionErr {
				content = "Permission denied"
			}
			toolResult := message.ToolResult{
				ToolCallID: tc.ID,
				Name:       tc.Name,
				Content:    content,
				IsError:    true,
			}
			_, createErr := a.messages.Create(context.Background(), currentAssistant.SessionID, message.CreateMessageParams{
				Role: message.Tool,
				Parts: []message.ContentPart{
					toolResult,
				},
			})
			if createErr != nil {
				return nil, createErr
			}
		}
		if isCancelErr {
			currentAssistant.AddFinish(message.FinishReasonCanceled, "Request cancelled", "")
		} else if isPermissionErr {
			currentAssistant.AddFinish(message.FinishReasonPermissionDenied, "Permission denied", "")
		} else {
			currentAssistant.AddFinish(message.FinishReasonError, "API Error", err.Error())
		}
		// INFO: we use the parent context here because the genCtx has been cancelled
		updateErr := a.messages.Update(ctx, *currentAssistant)
		if updateErr != nil {
			return nil, updateErr
		}
		return nil, err
	}
	wg.Wait()

	if shouldSummarize {
		a.activeRequests.Del(call.SessionID)
		if summarizeErr := a.Summarize(genCtx, call.SessionID); summarizeErr != nil {
			return nil, summarizeErr
		}
	}

	queuedMessages, ok := a.messageQueue.Get(call.SessionID)
	if !ok || len(queuedMessages) == 0 {
		return result, err
	}
	// there are queued messages, restart the loop with the first one
	firstQueuedMessage := queuedMessages[0]
	a.messageQueue.Set(call.SessionID, queuedMessages[1:])
	return a.Run(genCtx, firstQueuedMessage)
}

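// Summarize generates a summary of the session's conversation so far, stores
// it as a summary message, and resets the session's token counts so that
// later requests continue from the summary instead of the full history.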
func (a *sessionAgent) Summarize(ctx context.Context, sessionID string) error {
	if a.IsSessionBusy(sessionID) {
		return ErrSessionBusy
	}

	currentSession, err := a.sessions.Get(ctx, sessionID)
	if err != nil {
		return fmt.Errorf("failed to get session: %w", err)
	}
	msgs, err := a.getSessionMessages(ctx, currentSession)
	if err != nil {
		return err
	}
	if len(msgs) == 0 {
		// nothing to summarize
		return nil
	}

	aiMsgs, _ := a.preparePrompt(msgs)

	genCtx, cancel := context.WithCancel(ctx)
	a.activeRequests.Set(sessionID, cancel)
	defer a.activeRequests.Del(sessionID)
	defer cancel()

	agent := ai.NewAgent(a.largeModel.Model,
		ai.WithSystemPrompt(string(summaryPrompt)),
	)
	summaryMessage, err := a.messages.Create(ctx, sessionID, message.CreateMessageParams{
		Role:             message.Assistant,
		Model:            a.largeModel.Model.Model(),
		Provider:         a.largeModel.Model.Provider(),
		IsSummaryMessage: true,
	})
	if err != nil {
		return err
	}

	resp, err := agent.Stream(genCtx, ai.AgentStreamCall{
		Prompt:   "Provide a detailed summary of our conversation above.",
		Messages: aiMsgs,
		OnReasoningDelta: func(id string, text string) error {
			summaryMessage.AppendReasoningContent(text)
			return a.messages.Update(genCtx, summaryMessage)
		},
		OnReasoningEnd: func(id string, reasoning ai.ReasoningContent) error {
			// handle anthropic signature
			if anthropicData, ok := reasoning.ProviderMetadata[anthropic.Name]; ok {
				if signature, ok := anthropicData.(*anthropic.ReasoningOptionMetadata); ok && signature.Signature != "" {
					summaryMessage.AppendReasoningSignature(signature.Signature)
				}
			}
			summaryMessage.FinishThinking()
			return a.messages.Update(genCtx, summaryMessage)
		},
		OnTextDelta: func(id, text string) error {
			summaryMessage.AppendContent(text)
			return a.messages.Update(genCtx, summaryMessage)
		},
	})
	if err != nil {
		isCancelErr := errors.Is(err, context.Canceled)
		if isCancelErr {
			// User cancelled summarize, so we need to remove the summary message
			deleteErr := a.messages.Delete(ctx, summaryMessage.ID)
			return deleteErr
		}
		return err
	}

	summaryMessage.AddFinish(message.FinishReasonEndTurn, "", "")
	err = a.messages.Update(genCtx, summaryMessage)
	if err != nil {
		return err
	}

	a.updateSessionUsage(a.largeModel, &currentSession, resp.TotalUsage)

	// just in case, use only the last response's usage
	usage := resp.Response.Usage
	currentSession.SummaryMessageID = summaryMessage.ID
	currentSession.CompletionTokens = usage.OutputTokens
	currentSession.PromptTokens = 0
	_, err = a.sessions.Save(genCtx, currentSession)
	return err
}

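// getCacheControlOptions returns the Anthropic ephemeral cache-control
// provider options used to mark prompt-caching breakpoints.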
func (a *sessionAgent) getCacheControlOptions() ai.ProviderOptions {
	return ai.ProviderOptions{
		anthropic.Name: &anthropic.ProviderCacheControlOptions{
			CacheControl: anthropic.CacheControl{Type: "ephemeral"},
		},
	}
}

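// createUserMessage persists the prompt and its attachments as a user
// message in the session.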
func (a *sessionAgent) createUserMessage(ctx context.Context, call SessionAgentCall) (message.Message, error) {
	var attachmentParts []message.ContentPart
	for _, attachment := range call.Attachments {
		attachmentParts = append(attachmentParts, message.BinaryContent{Path: attachment.FilePath, MIMEType: attachment.MimeType, Data: attachment.Content})
	}
	parts := []message.ContentPart{message.TextContent{Text: call.Prompt}}
	parts = append(parts, attachmentParts...)
	msg, err := a.messages.Create(ctx, call.SessionID, message.CreateMessageParams{
		Role:  message.User,
		Parts: parts,
	})
	if err != nil {
		return message.Message{}, fmt.Errorf("failed to create user message: %w", err)
	}
	return msg, nil
}

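// preparePrompt converts stored messages into the AI message history,
// skipping empty messages and assistant messages that were cancelled before
// producing anything, and converts attachments into file parts.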
func (a *sessionAgent) preparePrompt(msgs []message.Message, attachments ...message.Attachment) ([]ai.Message, []ai.FilePart) {
	var history []ai.Message
	for _, m := range msgs {
		if len(m.Parts) == 0 {
			continue
		}
		// Assistant message without content or tool calls (cancelled before it returned anything)
		if m.Role == message.Assistant && len(m.ToolCalls()) == 0 && m.Content().Text == "" && m.ReasoningContent().String() == "" {
			continue
		}
		history = append(history, m.ToAIMessage()...)
	}

	var files []ai.FilePart
	for _, attachment := range attachments {
		files = append(files, ai.FilePart{
			Filename:  attachment.FileName,
			Data:      attachment.Content,
			MediaType: attachment.MimeType,
		})
	}

	return history, files
}

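// getSessionMessages lists the session's messages. If the session has been
// summarized, only the summary message and everything after it are returned,
// with the summary relabelled as a user message.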
func (a *sessionAgent) getSessionMessages(ctx context.Context, session session.Session) ([]message.Message, error) {
	msgs, err := a.messages.List(ctx, session.ID)
	if err != nil {
		return nil, fmt.Errorf("failed to list messages: %w", err)
	}

	if session.SummaryMessageID != "" {
		summaryMsgIndex := -1
		for i, msg := range msgs {
			if msg.ID == session.SummaryMessageID {
				summaryMsgIndex = i
				break
			}
		}
		if summaryMsgIndex != -1 {
			msgs = msgs[summaryMsgIndex:]
			msgs[0].Role = message.User
		}
	}
	return msgs, nil
}

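// generateTitle asks the small model for a concise session title based on the
// first prompt and saves it, together with the usage it incurred, on the
// session. Failures are logged rather than returned.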
func (a *sessionAgent) generateTitle(ctx context.Context, session *session.Session, prompt string) {
	if prompt == "" {
		return
	}

	var maxOutput int64 = 40
	if a.smallModel.CatwalkCfg.CanReason {
		maxOutput = a.smallModel.CatwalkCfg.DefaultMaxTokens
	}

	agent := ai.NewAgent(a.smallModel.Model,
		ai.WithSystemPrompt(string(titlePrompt)+"\n /no_think"),
		ai.WithMaxOutputTokens(maxOutput),
	)

	resp, err := agent.Stream(ctx, ai.AgentStreamCall{
		Prompt: fmt.Sprintf("Generate a concise title for the following content:\n\n%s\n <think>\n\n</think>", prompt),
	})
	if err != nil {
		slog.Error("error generating title", "err", err)
		return
	}

	data, _ := json.Marshal(resp)
	slog.Debug("title response", "response", string(data))

	title := resp.Response.Content.Text()
	title = strings.ReplaceAll(title, "\n", " ")

	// remove thinking tags if present
	if idx := strings.Index(title, "</think>"); idx > 0 {
		title = title[idx+len("</think>"):]
	}

	title = strings.TrimSpace(title)
	if title == "" {
		slog.Warn("failed to generate title", "warn", "empty title")
		return
	}

	session.Title = title
	a.updateSessionUsage(a.smallModel, session, resp.TotalUsage)
	_, saveErr := a.sessions.Save(ctx, *session)
	if saveErr != nil {
		slog.Error("failed to save session title & usage", "error", saveErr)
		return
	}
}

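// updateSessionUsage adds the cost of the given usage (based on the model's
// per-million-token prices) to the session and records its latest prompt and
// completion token counts.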
func (a *sessionAgent) updateSessionUsage(model Model, session *session.Session, usage ai.Usage) {
	modelConfig := model.CatwalkCfg
	cost := modelConfig.CostPer1MInCached/1e6*float64(usage.CacheCreationTokens) +
		modelConfig.CostPer1MOutCached/1e6*float64(usage.CacheReadTokens) +
		modelConfig.CostPer1MIn/1e6*float64(usage.InputTokens) +
		modelConfig.CostPer1MOut/1e6*float64(usage.OutputTokens)
	session.Cost += cost
	session.CompletionTokens = usage.OutputTokens + usage.CacheReadTokens
	session.PromptTokens = usage.InputTokens + usage.CacheCreationTokens
}

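// Cancel aborts the session's active request and any running summarize
// request, and drops its queued prompts.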
func (a *sessionAgent) Cancel(sessionID string) {
	// Cancel regular requests
	if cancel, ok := a.activeRequests.Take(sessionID); ok && cancel != nil {
		slog.Info("Request cancellation initiated", "session_id", sessionID)
		cancel()
	}

	// Also check for summarize requests
	if cancel, ok := a.activeRequests.Take(sessionID + "-summarize"); ok && cancel != nil {
		slog.Info("Summarize cancellation initiated", "session_id", sessionID)
		cancel()
	}

	if a.QueuedPrompts(sessionID) > 0 {
		slog.Info("Clearing queued prompts", "session_id", sessionID)
		a.messageQueue.Del(sessionID)
	}
}

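// ClearQueue drops any prompts queued for the session.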
func (a *sessionAgent) ClearQueue(sessionID string) {
	if a.QueuedPrompts(sessionID) > 0 {
		slog.Info("Clearing queued prompts", "session_id", sessionID)
		a.messageQueue.Del(sessionID)
	}
}

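// CancelAll cancels every active request and polls for up to five seconds
// until all of them have finished.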
func (a *sessionAgent) CancelAll() {
	if !a.IsBusy() {
		return
	}
	for key := range a.activeRequests.Seq2() {
		a.Cancel(key) // key is sessionID
	}

	timeout := time.After(5 * time.Second)
	for a.IsBusy() {
		select {
		case <-timeout:
			return
		default:
			time.Sleep(200 * time.Millisecond)
		}
	}
}

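// IsBusy reports whether any session has an active request.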
func (a *sessionAgent) IsBusy() bool {
	var busy bool
	for cancelFunc := range a.activeRequests.Seq() {
		if cancelFunc != nil {
			busy = true
			break
		}
	}
	return busy
}

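// IsSessionBusy reports whether the session has an active request.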
func (a *sessionAgent) IsSessionBusy(sessionID string) bool {
	_, busy := a.activeRequests.Get(sessionID)
	return busy
}

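// QueuedPrompts returns the number of prompts queued for the session.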
func (a *sessionAgent) QueuedPrompts(sessionID string) int {
	l, ok := a.messageQueue.Get(sessionID)
	if !ok {
		return 0
	}
	return len(l)
}

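// SetModels swaps the large and small models used for future requests.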
func (a *sessionAgent) SetModels(large Model, small Model) {
	a.largeModel = large
	a.smallModel = small
}

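// SetTools replaces the tools made available to the agent.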
func (a *sessionAgent) SetTools(tools []ai.AgentTool) {
	a.tools = tools
}

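// Model returns the currently configured large model.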
func (a *sessionAgent) Model() Model {
	return a.largeModel
}