agent.go

package agent

import (
	"context"
	_ "embed"
	"errors"
	"fmt"
	"log/slog"
	"strings"
	"sync"
	"time"

	"github.com/charmbracelet/catwalk/pkg/catwalk"
	"github.com/charmbracelet/crush/internal/agent/tools"
	"github.com/charmbracelet/crush/internal/config"
	"github.com/charmbracelet/crush/internal/csync"
	"github.com/charmbracelet/crush/internal/message"
	"github.com/charmbracelet/crush/internal/permission"
	"github.com/charmbracelet/crush/internal/session"
	"github.com/charmbracelet/fantasy/ai"
	"github.com/charmbracelet/fantasy/anthropic"
)

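// titlePrompt is the embedded prompt template used to generate session titles.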
//go:embed templates/title.md
var titlePrompt []byte

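// summaryPrompt is the embedded prompt template used to summarize sessions.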
//go:embed templates/summary.md
var summaryPrompt []byte

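// SessionAgentCall describes a single prompt for a session, including any
// attachments and optional model sampling parameters.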
type SessionAgentCall struct {
	SessionID        string
	Prompt           string
	ProviderOptions  ai.ProviderOptions
	Attachments      []message.Attachment
	MaxOutputTokens  int64
	Temperature      *float64
	TopP             *float64
	TopK             *int64
	FrequencyPenalty *float64
	PresencePenalty  *float64
}

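// SessionAgent runs prompts against chat sessions, queuing prompts for busy
// sessions and supporting cancellation and summarization.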
type SessionAgent interface {
	Run(context.Context, SessionAgentCall) (*ai.AgentResult, error)
	SetModels(large Model, small Model)
	SetTools(tools []ai.AgentTool)
	Cancel(sessionID string)
	CancelAll()
	IsSessionBusy(sessionID string) bool
	IsBusy() bool
	QueuedPrompts(sessionID string) int
	ClearQueue(sessionID string)
	Summarize(context.Context, string) error
	Model() Model
}

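// Model bundles a language model with its catwalk metadata and the selected
// model configuration.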
type Model struct {
	Model      ai.LanguageModel
	CatwalkCfg catwalk.Model
	ModelCfg   config.SelectedModel
}

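// sessionAgent is the default SessionAgent implementation, backed by the
// session and message services.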
type sessionAgent struct {
	largeModel           Model
	smallModel           Model
	systemPrompt         string
	tools                []ai.AgentTool
	sessions             session.Service
	messages             message.Service
	disableAutoSummarize bool

	messageQueue   *csync.Map[string, []SessionAgentCall]
	activeRequests *csync.Map[string, context.CancelFunc]
}

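// SessionAgentOptions configures a SessionAgent created with NewSessionAgent.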
type SessionAgentOptions struct {
	LargeModel           Model
	SmallModel           Model
	SystemPrompt         string
	DisableAutoSummarize bool
	Sessions             session.Service
	Messages             message.Service
	Tools                []ai.AgentTool
}

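// NewSessionAgent creates a SessionAgent with the given options.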
func NewSessionAgent(
	opts SessionAgentOptions,
) SessionAgent {
	return &sessionAgent{
		largeModel:           opts.LargeModel,
		smallModel:           opts.SmallModel,
		systemPrompt:         opts.SystemPrompt,
		sessions:             opts.Sessions,
		messages:             opts.Messages,
		disableAutoSummarize: opts.DisableAutoSummarize,
		tools:                opts.Tools,
		messageQueue:         csync.NewMap[string, []SessionAgentCall](),
		activeRequests:       csync.NewMap[string, context.CancelFunc](),
	}
}

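// Run sends the given prompt to the large model within the given session.
// If the session is already busy the call is queued and (nil, nil) is
// returned; queued prompts are drained after the current run finishes.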
func (a *sessionAgent) Run(ctx context.Context, call SessionAgentCall) (*ai.AgentResult, error) {
	if call.Prompt == "" {
		return nil, ErrEmptyPrompt
	}
	if call.SessionID == "" {
		return nil, ErrSessionMissing
	}

	// Queue the message if busy
	if a.IsSessionBusy(call.SessionID) {
		existing, ok := a.messageQueue.Get(call.SessionID)
		if !ok {
			existing = []SessionAgentCall{}
		}
		existing = append(existing, call)
		a.messageQueue.Set(call.SessionID, existing)
		return nil, nil
	}

	if len(a.tools) > 0 {
		// add anthropic caching to the last tool
		a.tools[len(a.tools)-1].SetProviderOptions(a.getCacheControlOptions())
	}

	agent := ai.NewAgent(
		a.largeModel.Model,
		ai.WithSystemPrompt(a.systemPrompt),
		ai.WithTools(a.tools...),
	)

	sessionLock := sync.Mutex{}
	currentSession, err := a.sessions.Get(ctx, call.SessionID)
	if err != nil {
		return nil, fmt.Errorf("failed to get session: %w", err)
	}

	msgs, err := a.getSessionMessages(ctx, currentSession)
	if err != nil {
		return nil, fmt.Errorf("failed to get session messages: %w", err)
	}

	var wg sync.WaitGroup
	// Generate title if first message
	if len(msgs) == 0 {
		wg.Go(func() {
			sessionLock.Lock()
			a.generateTitle(ctx, &currentSession, call.Prompt)
			sessionLock.Unlock()
		})
	}

	// Add the user message to the session
	_, err = a.createUserMessage(ctx, call)
	if err != nil {
		return nil, err
	}

	// add the session to the context
	ctx = context.WithValue(ctx, tools.SessionIDContextKey, call.SessionID)

	genCtx, cancel := context.WithCancel(ctx)
	a.activeRequests.Set(call.SessionID, cancel)

	defer cancel()
	defer a.activeRequests.Del(call.SessionID)

	history, files := a.preparePrompt(msgs, call.Attachments...)

	var currentAssistant *message.Message
	var shouldSummarize bool
	result, err := agent.Stream(genCtx, ai.AgentStreamCall{
		Prompt:           call.Prompt,
		Files:            files,
		Messages:         history,
		ProviderOptions:  call.ProviderOptions,
		MaxOutputTokens:  &call.MaxOutputTokens,
		TopP:             call.TopP,
		Temperature:      call.Temperature,
		PresencePenalty:  call.PresencePenalty,
		TopK:             call.TopK,
		FrequencyPenalty: call.FrequencyPenalty,
		// Before each step, create the new assistant message
		PrepareStep: func(callContext context.Context, options ai.PrepareStepFunctionOptions) (_ context.Context, prepared ai.PrepareStepResult, err error) {
			var assistantMsg message.Message
			assistantMsg, err = a.messages.Create(callContext, call.SessionID, message.CreateMessageParams{
				Role:     message.Assistant,
				Parts:    []message.ContentPart{},
				Model:    a.largeModel.ModelCfg.Model,
				Provider: a.largeModel.ModelCfg.Provider,
			})
			if err != nil {
				return callContext, prepared, err
			}

			callContext = context.WithValue(callContext, tools.MessageIDContextKey, assistantMsg.ID)

			currentAssistant = &assistantMsg

			prepared.Messages = options.Messages
			// reset all cached items
			for i := range prepared.Messages {
				prepared.Messages[i].ProviderOptions = nil
			}

			queuedCalls, _ := a.messageQueue.Get(call.SessionID)
			a.messageQueue.Del(call.SessionID)
			for _, queued := range queuedCalls {
				userMessage, createErr := a.createUserMessage(callContext, queued)
				if createErr != nil {
					return callContext, prepared, createErr
				}
				prepared.Messages = append(prepared.Messages, userMessage.ToAIMessage()...)
			}

			lastSystemRoleIdx := 0
			systemMessageUpdated := false
			for i, msg := range prepared.Messages {
				// add cache control only to the last system message
				if msg.Role == ai.MessageRoleSystem {
					lastSystemRoleIdx = i
				} else if !systemMessageUpdated {
					prepared.Messages[lastSystemRoleIdx].ProviderOptions = a.getCacheControlOptions()
					systemMessageUpdated = true
				}
				// then add cache control to the last 2 messages
				if i > len(prepared.Messages)-3 {
					prepared.Messages[i].ProviderOptions = a.getCacheControlOptions()
				}
			}
			return callContext, prepared, err
		},
		OnReasoningDelta: func(id string, text string) error {
			currentAssistant.AppendReasoningContent(text)
			return a.messages.Update(genCtx, *currentAssistant)
		},
		OnReasoningEnd: func(id string, reasoning ai.ReasoningContent) error {
			// handle anthropic signature
			if anthropicData, ok := reasoning.ProviderMetadata[anthropic.Name]; ok {
				if reasoning, ok := anthropicData.(*anthropic.ReasoningOptionMetadata); ok {
					currentAssistant.AppendReasoningSignature(reasoning.Signature)
				}
			}
			currentAssistant.FinishThinking()
			return a.messages.Update(genCtx, *currentAssistant)
		},
		OnTextDelta: func(id string, text string) error {
			currentAssistant.AppendContent(text)
			return a.messages.Update(genCtx, *currentAssistant)
		},
		OnToolInputStart: func(id string, toolName string) error {
			toolCall := message.ToolCall{
				ID:               id,
				Name:             toolName,
				ProviderExecuted: false,
				Finished:         false,
			}
			currentAssistant.AddToolCall(toolCall)
			return a.messages.Update(genCtx, *currentAssistant)
		},
		OnRetry: func(err *ai.APICallError, delay time.Duration) {
			// TODO: implement
		},
		OnToolCall: func(tc ai.ToolCallContent) error {
			toolCall := message.ToolCall{
				ID:               tc.ToolCallID,
				Name:             tc.ToolName,
				Input:            tc.Input,
				ProviderExecuted: false,
				Finished:         true,
			}
			currentAssistant.AddToolCall(toolCall)
			return a.messages.Update(genCtx, *currentAssistant)
		},
		OnToolResult: func(result ai.ToolResultContent) error {
			var resultContent string
			isError := false
			switch result.Result.GetType() {
			case ai.ToolResultContentTypeText:
				r, ok := ai.AsToolResultOutputType[ai.ToolResultOutputContentText](result.Result)
				if ok {
					resultContent = r.Text
				}
			case ai.ToolResultContentTypeError:
				r, ok := ai.AsToolResultOutputType[ai.ToolResultOutputContentError](result.Result)
				if ok {
					isError = true
					resultContent = r.Error.Error()
				}
			case ai.ToolResultContentTypeMedia:
				// TODO: handle this message type
			}
			toolResult := message.ToolResult{
				ToolCallID: result.ToolCallID,
				Name:       result.ToolName,
				Content:    resultContent,
				IsError:    isError,
				Metadata:   result.ClientMetadata,
			}
			_, createMsgErr := a.messages.Create(genCtx, currentAssistant.SessionID, message.CreateMessageParams{
				Role: message.Tool,
				Parts: []message.ContentPart{
					toolResult,
				},
			})
			if createMsgErr != nil {
				return createMsgErr
			}
			return nil
		},
		OnStepFinish: func(stepResult ai.StepResult) error {
			finishReason := message.FinishReasonUnknown
			switch stepResult.FinishReason {
			case ai.FinishReasonLength:
				finishReason = message.FinishReasonMaxTokens
			case ai.FinishReasonStop:
				finishReason = message.FinishReasonEndTurn
			case ai.FinishReasonToolCalls:
				finishReason = message.FinishReasonToolUse
			}
			currentAssistant.AddFinish(finishReason, "", "")
			a.updateSessionUsage(a.largeModel, &currentSession, stepResult.Usage)
			sessionLock.Lock()
			_, sessionErr := a.sessions.Save(genCtx, currentSession)
			sessionLock.Unlock()
			if sessionErr != nil {
				return sessionErr
			}
			return a.messages.Update(genCtx, *currentAssistant)
		},
		StopWhen: []ai.StopCondition{
			func(_ []ai.StepResult) bool {
				contextWindow := a.largeModel.CatwalkCfg.ContextWindow
				tokens := currentSession.CompletionTokens + currentSession.PromptTokens
				percentage := (float64(tokens) / float64(contextWindow)) * 100
				if (percentage > 80) && !a.disableAutoSummarize {
					shouldSummarize = true
					return true
				}
				return false
			},
		},
	})
	if err != nil {
		isCancelErr := errors.Is(err, context.Canceled)
		isPermissionErr := errors.Is(err, permission.ErrorPermissionDenied)
		if currentAssistant == nil {
			return result, err
		}
		toolCalls := currentAssistant.ToolCalls()
		// INFO: we use the parent context here because the genCtx has been cancelled
		msgs, listErr := a.messages.List(ctx, currentAssistant.SessionID)
		if listErr != nil {
			return nil, listErr
		}
		for _, tc := range toolCalls {
			if !tc.Finished {
				tc.Finished = true
				tc.Input = "{}"
				currentAssistant.AddToolCall(tc)
				updateErr := a.messages.Update(ctx, *currentAssistant)
				if updateErr != nil {
					return nil, updateErr
				}
			}

			found := false
			for _, msg := range msgs {
				if msg.Role == message.Tool {
					for _, tr := range msg.ToolResults() {
						if tr.ToolCallID == tc.ID {
							found = true
							break
						}
					}
				}
				if found {
					break
				}
			}
			if found {
				continue
			}
			content := "There was an error while executing the tool"
			if isCancelErr {
				content = "Tool execution canceled by user"
			} else if isPermissionErr {
				content = "Permission denied"
			}
			toolResult := message.ToolResult{
				ToolCallID: tc.ID,
				Name:       tc.Name,
				Content:    content,
				IsError:    true,
			}
			_, createErr := a.messages.Create(context.Background(), currentAssistant.SessionID, message.CreateMessageParams{
				Role: message.Tool,
				Parts: []message.ContentPart{
					toolResult,
				},
			})
			if createErr != nil {
				return nil, createErr
			}
		}
		if isCancelErr {
			currentAssistant.AddFinish(message.FinishReasonCanceled, "Request cancelled", "")
		} else if isPermissionErr {
			currentAssistant.AddFinish(message.FinishReasonPermissionDenied, "Permission denied", "")
		} else {
			currentAssistant.AddFinish(message.FinishReasonError, "API Error", err.Error())
		}
		// INFO: we use the parent context here because the genCtx has been cancelled
		updateErr := a.messages.Update(ctx, *currentAssistant)
		if updateErr != nil {
			return nil, updateErr
		}
		return nil, err
	}
	wg.Wait()

	if shouldSummarize {
		a.activeRequests.Del(call.SessionID)
		if summarizeErr := a.Summarize(genCtx, call.SessionID); summarizeErr != nil {
			return nil, summarizeErr
		}
	}

	queuedMessages, ok := a.messageQueue.Get(call.SessionID)
	if !ok || len(queuedMessages) == 0 {
		return result, err
	}
	// there are queued messages; restart the loop with the first one
	firstQueuedMessage := queuedMessages[0]
	a.messageQueue.Set(call.SessionID, queuedMessages[1:])
	return a.Run(genCtx, firstQueuedMessage)
}

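// Summarize generates a summary message for the session and resets the
// session's token counts so that the summary becomes the new conversation
// root.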
func (a *sessionAgent) Summarize(ctx context.Context, sessionID string) error {
	if a.IsSessionBusy(sessionID) {
		return ErrSessionBusy
	}

	currentSession, err := a.sessions.Get(ctx, sessionID)
	if err != nil {
		return fmt.Errorf("failed to get session: %w", err)
	}
	msgs, err := a.getSessionMessages(ctx, currentSession)
	if err != nil {
		return err
	}
	if len(msgs) == 0 {
		// nothing to summarize
		return nil
	}

	aiMsgs, _ := a.preparePrompt(msgs)

	genCtx, cancel := context.WithCancel(ctx)
	a.activeRequests.Set(sessionID, cancel)
	defer a.activeRequests.Del(sessionID)
	defer cancel()

	agent := ai.NewAgent(a.largeModel.Model,
		ai.WithSystemPrompt(string(summaryPrompt)),
	)
	summaryMessage, err := a.messages.Create(ctx, sessionID, message.CreateMessageParams{
		Role:             message.Assistant,
		Model:            a.largeModel.Model.Model(),
		Provider:         a.largeModel.Model.Provider(),
		IsSummaryMessage: true,
	})
	if err != nil {
		return err
	}

	resp, err := agent.Stream(genCtx, ai.AgentStreamCall{
		Prompt:   "Provide a detailed summary of our conversation above.",
		Messages: aiMsgs,
		OnReasoningDelta: func(id string, text string) error {
			summaryMessage.AppendReasoningContent(text)
			return a.messages.Update(genCtx, summaryMessage)
		},
		OnReasoningEnd: func(id string, reasoning ai.ReasoningContent) error {
			// handle anthropic signature
			if anthropicData, ok := reasoning.ProviderMetadata[anthropic.Name]; ok {
				if signature, ok := anthropicData.(*anthropic.ReasoningOptionMetadata); ok && signature.Signature != "" {
					summaryMessage.AppendReasoningSignature(signature.Signature)
				}
			}
			summaryMessage.FinishThinking()
			return a.messages.Update(genCtx, summaryMessage)
		},
		OnTextDelta: func(id, text string) error {
			summaryMessage.AppendContent(text)
			return a.messages.Update(genCtx, summaryMessage)
		},
	})
	if err != nil {
		isCancelErr := errors.Is(err, context.Canceled)
		if isCancelErr {
			// user cancelled the summarize; remove the summary message
			deleteErr := a.messages.Delete(ctx, summaryMessage.ID)
			return deleteErr
		}
		return err
	}

	summaryMessage.AddFinish(message.FinishReasonEndTurn, "", "")
	err = a.messages.Update(genCtx, summaryMessage)
	if err != nil {
		return err
	}

	a.updateSessionUsage(a.largeModel, &currentSession, resp.TotalUsage)

	// use only the last response's usage for the reset token counts
	usage := resp.Response.Usage
	currentSession.SummaryMessageID = summaryMessage.ID
	currentSession.CompletionTokens = usage.OutputTokens
	currentSession.PromptTokens = 0
	_, err = a.sessions.Save(genCtx, currentSession)
	return err
}

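// getCacheControlOptions returns Anthropic ephemeral cache-control provider
// options, applied to tools and messages to mark them as cacheable.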
func (a *sessionAgent) getCacheControlOptions() ai.ProviderOptions {
	return ai.ProviderOptions{
		anthropic.Name: &anthropic.ProviderCacheControlOptions{
			CacheControl: anthropic.CacheControl{Type: "ephemeral"},
		},
	}
}

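// createUserMessage stores the prompt and its attachments as a new user
// message in the session.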
func (a *sessionAgent) createUserMessage(ctx context.Context, call SessionAgentCall) (message.Message, error) {
	var attachmentParts []message.ContentPart
	for _, attachment := range call.Attachments {
		attachmentParts = append(attachmentParts, message.BinaryContent{Path: attachment.FilePath, MIMEType: attachment.MimeType, Data: attachment.Content})
	}
	parts := []message.ContentPart{message.TextContent{Text: call.Prompt}}
	parts = append(parts, attachmentParts...)
	msg, err := a.messages.Create(ctx, call.SessionID, message.CreateMessageParams{
		Role:  message.User,
		Parts: parts,
	})
	if err != nil {
		return message.Message{}, fmt.Errorf("failed to create user message: %w", err)
	}
	return msg, nil
}

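// preparePrompt converts stored messages into AI messages, skipping empty
// messages and assistant messages that never produced content, and converts
// attachments into file parts.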
func (a *sessionAgent) preparePrompt(msgs []message.Message, attachments ...message.Attachment) ([]ai.Message, []ai.FilePart) {
	var history []ai.Message
	for _, m := range msgs {
		if len(m.Parts) == 0 {
			continue
		}
		// Skip assistant messages with no content or tool calls (cancelled before returning anything)
		if m.Role == message.Assistant && len(m.ToolCalls()) == 0 && m.Content().Text == "" && m.ReasoningContent().String() == "" {
			continue
		}
		history = append(history, m.ToAIMessage()...)
	}

	var files []ai.FilePart
	for _, attachment := range attachments {
		files = append(files, ai.FilePart{
			Filename:  attachment.FileName,
			Data:      attachment.Content,
			MediaType: attachment.MimeType,
		})
	}

	return history, files
}

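// getSessionMessages lists the session's messages. If the session has been
// summarized, only the summary message and everything after it is returned,
// with the summary re-rooted as a user message.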
func (a *sessionAgent) getSessionMessages(ctx context.Context, session session.Session) ([]message.Message, error) {
	msgs, err := a.messages.List(ctx, session.ID)
	if err != nil {
		return nil, fmt.Errorf("failed to list messages: %w", err)
	}

	if session.SummaryMessageID != "" {
		summaryMsgIndex := -1
		for i, msg := range msgs {
			if msg.ID == session.SummaryMessageID {
				summaryMsgIndex = i
				break
			}
		}
		if summaryMsgIndex != -1 {
			msgs = msgs[summaryMsgIndex:]
			msgs[0].Role = message.User
		}
	}
	return msgs, nil
}

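// generateTitle asks the small model for a short session title and saves the
// title together with the usage it incurred.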
func (a *sessionAgent) generateTitle(ctx context.Context, session *session.Session, prompt string) {
	if prompt == "" {
		return
	}

	var maxOutput int64 = 40
	if a.smallModel.CatwalkCfg.CanReason {
		maxOutput = a.smallModel.CatwalkCfg.DefaultMaxTokens
	}

	agent := ai.NewAgent(a.smallModel.Model,
		ai.WithSystemPrompt(string(titlePrompt)+"\n /no_think"),
		ai.WithMaxOutputTokens(maxOutput),
	)

	resp, err := agent.Stream(ctx, ai.AgentStreamCall{
		Prompt: fmt.Sprintf("Generate a concise title for the following content:\n\n%s\n <think>\n\n</think>", prompt),
	})
	if err != nil {
		slog.Error("error generating title", "err", err)
		return
	}

	title := resp.Response.Content.Text()

	title = strings.ReplaceAll(title, "\n", " ")

	// remove thinking tags if present
	if idx := strings.Index(title, "</think>"); idx > 0 {
		title = title[idx+len("</think>"):]
	}

	title = strings.TrimSpace(title)
	if title == "" {
		slog.Warn("failed to generate title", "warn", "empty title")
		return
	}

	session.Title = title
	a.updateSessionUsage(a.smallModel, session, resp.TotalUsage)
	_, saveErr := a.sessions.Save(ctx, *session)
	if saveErr != nil {
		slog.Error("failed to save session title & usage", "error", saveErr)
		return
	}
}

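// updateSessionUsage accumulates the cost of the given usage on the session
// and records the latest prompt and completion token counts.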
func (a *sessionAgent) updateSessionUsage(model Model, session *session.Session, usage ai.Usage) {
	modelConfig := model.CatwalkCfg
	cost := modelConfig.CostPer1MInCached/1e6*float64(usage.CacheCreationTokens) +
		modelConfig.CostPer1MOutCached/1e6*float64(usage.CacheReadTokens) +
		modelConfig.CostPer1MIn/1e6*float64(usage.InputTokens) +
		modelConfig.CostPer1MOut/1e6*float64(usage.OutputTokens)
	session.Cost += cost
	session.CompletionTokens = usage.OutputTokens + usage.CacheReadTokens
	session.PromptTokens = usage.InputTokens + usage.CacheCreationTokens
}

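// Cancel stops any active request for the session (including a running
// summarization) and drops its queued prompts.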
func (a *sessionAgent) Cancel(sessionID string) {
	// Cancel regular requests
	if cancel, ok := a.activeRequests.Take(sessionID); ok && cancel != nil {
		slog.Info("Request cancellation initiated", "session_id", sessionID)
		cancel()
	}

	// Also check for summarize requests
	if cancel, ok := a.activeRequests.Take(sessionID + "-summarize"); ok && cancel != nil {
		slog.Info("Summarize cancellation initiated", "session_id", sessionID)
		cancel()
	}

	if a.QueuedPrompts(sessionID) > 0 {
		slog.Info("Clearing queued prompts", "session_id", sessionID)
		a.messageQueue.Del(sessionID)
	}
}

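// ClearQueue drops any prompts queued for the session without cancelling the
// active request.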
func (a *sessionAgent) ClearQueue(sessionID string) {
	if a.QueuedPrompts(sessionID) > 0 {
		slog.Info("Clearing queued prompts", "session_id", sessionID)
		a.messageQueue.Del(sessionID)
	}
}

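// CancelAll cancels every active request and waits up to five seconds for
// them to finish.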
func (a *sessionAgent) CancelAll() {
	if !a.IsBusy() {
		return
	}
	for key := range a.activeRequests.Seq2() {
		a.Cancel(key) // key is sessionID
	}

	timeout := time.After(5 * time.Second)
	for a.IsBusy() {
		select {
		case <-timeout:
			return
		default:
			time.Sleep(200 * time.Millisecond)
		}
	}
}

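// IsBusy reports whether any session has an active request.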
func (a *sessionAgent) IsBusy() bool {
	var busy bool
	for cancelFunc := range a.activeRequests.Seq() {
		if cancelFunc != nil {
			busy = true
			break
		}
	}
	return busy
}

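// IsSessionBusy reports whether the given session has an active request.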
func (a *sessionAgent) IsSessionBusy(sessionID string) bool {
	_, busy := a.activeRequests.Get(sessionID)
	return busy
}

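// QueuedPrompts returns the number of prompts queued for the session.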
func (a *sessionAgent) QueuedPrompts(sessionID string) int {
	l, ok := a.messageQueue.Get(sessionID)
	if !ok {
		return 0
	}
	return len(l)
}

func (a *sessionAgent) SetModels(large Model, small Model) {
	a.largeModel = large
	a.smallModel = small
}

func (a *sessionAgent) SetTools(tools []ai.AgentTool) {
	a.tools = tools
}

func (a *sessionAgent) Model() Model {
	return a.largeModel
}