package agent

import (
	"context"
	_ "embed"
	"encoding/json"
	"errors"
	"fmt"
	"log/slog"
	"strings"
	"sync"
	"time"

	"github.com/charmbracelet/catwalk/pkg/catwalk"
	"github.com/charmbracelet/crush/internal/agent/tools"
	"github.com/charmbracelet/crush/internal/config"
	"github.com/charmbracelet/crush/internal/csync"
	"github.com/charmbracelet/crush/internal/message"
	"github.com/charmbracelet/crush/internal/permission"
	"github.com/charmbracelet/crush/internal/session"
	"github.com/charmbracelet/fantasy/ai"
	"github.com/charmbracelet/fantasy/anthropic"
)

//go:embed templates/title.md
var titlePrompt []byte

//go:embed templates/summary.md
var summaryPrompt []byte

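// SessionAgentCall describes a single prompt for a session, together with its
// attachments and optional sampling parameters.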
type SessionAgentCall struct {
	SessionID        string
	Prompt           string
	ProviderOptions  ai.ProviderOptions
	Attachments      []message.Attachment
	MaxOutputTokens  int64
	Temperature      *float64
	TopP             *float64
	TopK             *int64
	FrequencyPenalty *float64
	PresencePenalty  *float64
}

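// SessionAgent runs prompts against sessions, queueing prompts while a
// session is busy and exposing cancellation, queue, and summarization
// controls.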
type SessionAgent interface {
	Run(context.Context, SessionAgentCall) (*ai.AgentResult, error)
	SetModels(large Model, small Model)
	SetTools(tools []ai.AgentTool)
	Cancel(sessionID string)
	CancelAll()
	IsSessionBusy(sessionID string) bool
	IsBusy() bool
	QueuedPrompts(sessionID string) int
	ClearQueue(sessionID string)
	Summarize(context.Context, string) error
	Model() Model
}

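// Model pairs a language model with its catwalk metadata and the selected
// model configuration.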
type Model struct {
	Model      ai.LanguageModel
	CatwalkCfg catwalk.Model
	ModelCfg   config.SelectedModel
}

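// sessionAgent is the default SessionAgent implementation, backed by the
// session and message services.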
type sessionAgent struct {
	largeModel           Model
	smallModel           Model
	systemPrompt         string
	tools                []ai.AgentTool
	sessions             session.Service
	messages             message.Service
	disableAutoSummarize bool

	messageQueue   *csync.Map[string, []SessionAgentCall]
	activeRequests *csync.Map[string, context.CancelFunc]
}

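// SessionAgentOptions holds the dependencies and settings used to build a
// SessionAgent.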
type SessionAgentOptions struct {
	LargeModel           Model
	SmallModel           Model
	SystemPrompt         string
	DisableAutoSummarize bool
	Sessions             session.Service
	Messages             message.Service
	Tools                []ai.AgentTool
}

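// NewSessionAgent builds a SessionAgent from the given options.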
func NewSessionAgent(
	opts SessionAgentOptions,
) SessionAgent {
	return &sessionAgent{
		largeModel:           opts.LargeModel,
		smallModel:           opts.SmallModel,
		systemPrompt:         opts.SystemPrompt,
		sessions:             opts.Sessions,
		messages:             opts.Messages,
		disableAutoSummarize: opts.DisableAutoSummarize,
		tools:                opts.Tools,
		messageQueue:         csync.NewMap[string, []SessionAgentCall](),
		activeRequests:       csync.NewMap[string, context.CancelFunc](),
	}
}

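// Run sends the prompt to the large model for the given session, streaming
// assistant content, tool calls, and tool results into the message store as
// they arrive. If the session is already busy the call is queued and replayed
// once the current request finishes; the first prompt of a session also
// triggers title generation.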
func (a *sessionAgent) Run(ctx context.Context, call SessionAgentCall) (*ai.AgentResult, error) {
	if call.Prompt == "" {
		return nil, ErrEmptyPrompt
	}
	if call.SessionID == "" {
		return nil, ErrSessionMissing
	}

	// Queue the message if busy
	if a.IsSessionBusy(call.SessionID) {
		existing, ok := a.messageQueue.Get(call.SessionID)
		if !ok {
			existing = []SessionAgentCall{}
		}
		existing = append(existing, call)
		a.messageQueue.Set(call.SessionID, existing)
		return nil, nil
	}

	if len(a.tools) > 0 {
		// add anthropic caching to the last tool
		a.tools[len(a.tools)-1].SetProviderOptions(a.getCacheControlOptions())
	}

	agent := ai.NewAgent(
		a.largeModel.Model,
		ai.WithSystemPrompt(a.systemPrompt),
		ai.WithTools(a.tools...),
	)

	sessionLock := sync.Mutex{}
	currentSession, err := a.sessions.Get(ctx, call.SessionID)
	if err != nil {
		return nil, fmt.Errorf("failed to get session: %w", err)
	}

	msgs, err := a.getSessionMessages(ctx, currentSession)
	if err != nil {
		return nil, fmt.Errorf("failed to get session messages: %w", err)
	}

	var wg sync.WaitGroup
	// Generate title if first message
	if len(msgs) == 0 {
		wg.Go(func() {
			sessionLock.Lock()
			a.generateTitle(ctx, &currentSession, call.Prompt)
			sessionLock.Unlock()
		})
	}

	// Add the user message to the session
	_, err = a.createUserMessage(ctx, call)
	if err != nil {
		return nil, err
	}

	// add the session ID to the context
	ctx = context.WithValue(ctx, tools.SessionIDContextKey, call.SessionID)

	genCtx, cancel := context.WithCancel(ctx)
	a.activeRequests.Set(call.SessionID, cancel)

	defer cancel()
	defer a.activeRequests.Del(call.SessionID)

	history, files := a.preparePrompt(msgs, call.Attachments...)

	var currentAssistant *message.Message
	var shouldSummarize bool
	result, err := agent.Stream(genCtx, ai.AgentStreamCall{
		Prompt:           call.Prompt,
		Files:            files,
		Messages:         history,
		ProviderOptions:  call.ProviderOptions,
		MaxOutputTokens:  &call.MaxOutputTokens,
		TopP:             call.TopP,
		Temperature:      call.Temperature,
		PresencePenalty:  call.PresencePenalty,
		TopK:             call.TopK,
		FrequencyPenalty: call.FrequencyPenalty,
		// Before each step, create the new assistant message
		PrepareStep: func(callContext context.Context, options ai.PrepareStepFunctionOptions) (_ context.Context, prepared ai.PrepareStepResult, err error) {
			var assistantMsg message.Message
			assistantMsg, err = a.messages.Create(callContext, call.SessionID, message.CreateMessageParams{
				Role:     message.Assistant,
				Parts:    []message.ContentPart{},
				Model:    a.largeModel.ModelCfg.Model,
				Provider: a.largeModel.ModelCfg.Provider,
			})
			if err != nil {
				return callContext, prepared, err
			}

			callContext = context.WithValue(callContext, tools.MessageIDContextKey, assistantMsg.ID)

			currentAssistant = &assistantMsg

			prepared.Messages = options.Messages
			// reset all cached items
			for i := range prepared.Messages {
				prepared.Messages[i].ProviderOptions = nil
			}

			// flush any queued prompts into this step's message history
			queuedCalls, _ := a.messageQueue.Get(call.SessionID)
			a.messageQueue.Del(call.SessionID)
			for _, queued := range queuedCalls {
				userMessage, createErr := a.createUserMessage(callContext, queued)
				if createErr != nil {
					return callContext, prepared, createErr
				}
				prepared.Messages = append(prepared.Messages, userMessage.ToAIMessage()...)
			}

			lastSystemRoleIdx := 0
			systemMessageUpdated := false
			for i, msg := range prepared.Messages {
				// add cache control only to the last system message...
				if msg.Role == ai.MessageRoleSystem {
					lastSystemRoleIdx = i
				} else if !systemMessageUpdated {
					prepared.Messages[lastSystemRoleIdx].ProviderOptions = a.getCacheControlOptions()
					systemMessageUpdated = true
				}
				// ...then add cache control to the last 2 messages
				if i > len(prepared.Messages)-3 {
					prepared.Messages[i].ProviderOptions = a.getCacheControlOptions()
				}
			}
			return callContext, prepared, err
		},
		OnReasoningDelta: func(id string, text string) error {
			currentAssistant.AppendReasoningContent(text)
			return a.messages.Update(genCtx, *currentAssistant)
		},
		OnReasoningEnd: func(id string, reasoning ai.ReasoningContent) error {
			// handle anthropic signature
			if anthropicData, ok := reasoning.ProviderMetadata[anthropic.Name]; ok {
				if reasoning, ok := anthropicData.(*anthropic.ReasoningOptionMetadata); ok {
					currentAssistant.AppendReasoningSignature(reasoning.Signature)
				}
			}
			currentAssistant.FinishThinking()
			return a.messages.Update(genCtx, *currentAssistant)
		},
		OnTextDelta: func(id string, text string) error {
			currentAssistant.AppendContent(text)
			return a.messages.Update(genCtx, *currentAssistant)
		},
		OnToolInputStart: func(id string, toolName string) error {
			toolCall := message.ToolCall{
				ID:               id,
				Name:             toolName,
				ProviderExecuted: false,
				Finished:         false,
			}
			currentAssistant.AddToolCall(toolCall)
			return a.messages.Update(genCtx, *currentAssistant)
		},
		OnRetry: func(err *ai.APICallError, delay time.Duration) {
			// TODO: implement
		},
		OnToolCall: func(tc ai.ToolCallContent) error {
			toolCall := message.ToolCall{
				ID:               tc.ToolCallID,
				Name:             tc.ToolName,
				Input:            tc.Input,
				ProviderExecuted: false,
				Finished:         true,
			}
			currentAssistant.AddToolCall(toolCall)
			return a.messages.Update(genCtx, *currentAssistant)
		},
		OnToolResult: func(result ai.ToolResultContent) error {
			var resultContent string
			isError := false
			switch result.Result.GetType() {
			case ai.ToolResultContentTypeText:
				r, ok := ai.AsToolResultOutputType[ai.ToolResultOutputContentText](result.Result)
				if ok {
					resultContent = r.Text
				}
			case ai.ToolResultContentTypeError:
				r, ok := ai.AsToolResultOutputType[ai.ToolResultOutputContentError](result.Result)
				if ok {
					isError = true
					resultContent = r.Error.Error()
				}
			case ai.ToolResultContentTypeMedia:
				// TODO: handle this message type
			}
			toolResult := message.ToolResult{
				ToolCallID: result.ToolCallID,
				Name:       result.ToolName,
				Content:    resultContent,
				IsError:    isError,
				Metadata:   result.ClientMetadata,
			}
			_, createMsgErr := a.messages.Create(genCtx, currentAssistant.SessionID, message.CreateMessageParams{
				Role: message.Tool,
				Parts: []message.ContentPart{
					toolResult,
				},
			})
			if createMsgErr != nil {
				return createMsgErr
			}
			return nil
		},
		OnStepFinish: func(stepResult ai.StepResult) error {
			finishReason := message.FinishReasonUnknown
			switch stepResult.FinishReason {
			case ai.FinishReasonLength:
				finishReason = message.FinishReasonMaxTokens
			case ai.FinishReasonStop:
				finishReason = message.FinishReasonEndTurn
			case ai.FinishReasonToolCalls:
				finishReason = message.FinishReasonToolUse
			}
			currentAssistant.AddFinish(finishReason, "", "")
			a.updateSessionUsage(a.largeModel, &currentSession, stepResult.Usage)
			sessionLock.Lock()
			_, sessionErr := a.sessions.Save(genCtx, currentSession)
			sessionLock.Unlock()
			if sessionErr != nil {
				return sessionErr
			}
			return a.messages.Update(genCtx, *currentAssistant)
		},
		StopWhen: []ai.StopCondition{
			func(_ []ai.StepResult) bool {
				// stop and summarize once the context window is more than 80% full
				contextWindow := a.largeModel.CatwalkCfg.ContextWindow
				tokens := currentSession.CompletionTokens + currentSession.PromptTokens
				percentage := (float64(tokens) / float64(contextWindow)) * 100
				if (percentage > 80) && !a.disableAutoSummarize {
					shouldSummarize = true
					return true
				}
				return false
			},
		},
	})
	if err != nil {
		isCancelErr := errors.Is(err, context.Canceled)
		isPermissionErr := errors.Is(err, permission.ErrorPermissionDenied)
		if currentAssistant == nil {
			return result, err
		}
		toolCalls := currentAssistant.ToolCalls()
		toolResults := currentAssistant.ToolResults()
		// INFO: we use the parent context here because the genCtx has been cancelled
		msgs, createErr := a.messages.List(ctx, currentSession.ID)
		if createErr != nil {
			return nil, createErr
		}
		// close out unfinished tool calls and make sure each one has a result
		for _, tc := range toolCalls {
			if !tc.Finished {
				tc.Finished = true
				tc.Input = "{}"
			}
			currentAssistant.AddToolCall(tc)

			found := false
			for _, msg := range msgs {
				if msg.Role == message.Tool {
					for _, tr := range toolResults {
						if tr.ToolCallID == tc.ID {
							found = true
							break
						}
					}
				}
				if found {
					break
				}
			}
			if found {
				continue
			}
			content := "There was an error while executing the tool"
			if isCancelErr {
				content = "Tool execution canceled by user"
			} else if isPermissionErr {
				content = "Permission denied"
			}
			toolResult := message.ToolResult{
				ToolCallID: tc.ID,
				Name:       tc.Name,
				Content:    content,
				IsError:    true,
			}
			_, createErr = a.messages.Create(context.Background(), currentAssistant.SessionID, message.CreateMessageParams{
				Role: message.Tool,
				Parts: []message.ContentPart{
					toolResult,
				},
			})
			if createErr != nil {
				return nil, createErr
			}
		}
		if isCancelErr {
			currentAssistant.AddFinish(message.FinishReasonCanceled, "Request cancelled", "")
		} else if isPermissionErr {
			currentAssistant.AddFinish(message.FinishReasonPermissionDenied, "Permission denied", "")
		} else {
			currentAssistant.AddFinish(message.FinishReasonError, "API Error", err.Error())
		}
		// INFO: we use the parent context here because the genCtx has been cancelled
		updateErr := a.messages.Update(ctx, *currentAssistant)
		if updateErr != nil {
			return nil, updateErr
		}
		return nil, err
	}
	wg.Wait()

	if shouldSummarize {
		a.activeRequests.Del(call.SessionID)
		if summarizeErr := a.Summarize(genCtx, call.SessionID); summarizeErr != nil {
			return nil, summarizeErr
		}
	}

	queuedMessages, ok := a.messageQueue.Get(call.SessionID)
	if !ok || len(queuedMessages) == 0 {
		return result, err
	}
	// there are queued messages; restart the loop with the first one
	firstQueuedMessage := queuedMessages[0]
	a.messageQueue.Set(call.SessionID, queuedMessages[1:])
	return a.Run(genCtx, firstQueuedMessage)
}

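// Summarize generates a summary of the session's conversation with the large
// model, stores it as a summary message, and resets the session's token
// counts so the summary becomes the new starting point of the history.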
func (a *sessionAgent) Summarize(ctx context.Context, sessionID string) error {
	if a.IsSessionBusy(sessionID) {
		return ErrSessionBusy
	}

	currentSession, err := a.sessions.Get(ctx, sessionID)
	if err != nil {
		return fmt.Errorf("failed to get session: %w", err)
	}
	msgs, err := a.getSessionMessages(ctx, currentSession)
	if err != nil {
		return err
	}
	if len(msgs) == 0 {
		// nothing to summarize
		return nil
	}

	aiMsgs, _ := a.preparePrompt(msgs)

	genCtx, cancel := context.WithCancel(ctx)
	a.activeRequests.Set(sessionID, cancel)
	defer a.activeRequests.Del(sessionID)
	defer cancel()

	agent := ai.NewAgent(a.largeModel.Model,
		ai.WithSystemPrompt(string(summaryPrompt)),
	)
	summaryMessage, err := a.messages.Create(ctx, sessionID, message.CreateMessageParams{
		Role:             message.Assistant,
		Model:            a.largeModel.Model.Model(),
		Provider:         a.largeModel.Model.Provider(),
		IsSummaryMessage: true,
	})
	if err != nil {
		return err
	}

	resp, err := agent.Stream(genCtx, ai.AgentStreamCall{
		Prompt:   "Provide a detailed summary of our conversation above.",
		Messages: aiMsgs,
		OnReasoningDelta: func(id string, text string) error {
			summaryMessage.AppendReasoningContent(text)
			return a.messages.Update(genCtx, summaryMessage)
		},
		OnReasoningEnd: func(id string, reasoning ai.ReasoningContent) error {
			// handle anthropic signature
			if anthropicData, ok := reasoning.ProviderMetadata[anthropic.Name]; ok {
				if signature, ok := anthropicData.(*anthropic.ReasoningOptionMetadata); ok && signature.Signature != "" {
					summaryMessage.AppendReasoningSignature(signature.Signature)
				}
			}
			summaryMessage.FinishThinking()
			return a.messages.Update(genCtx, summaryMessage)
		},
		OnTextDelta: func(id, text string) error {
			summaryMessage.AppendContent(text)
			return a.messages.Update(genCtx, summaryMessage)
		},
	})
	if err != nil {
		if errors.Is(err, context.Canceled) {
			// the user cancelled the summarization; remove the summary message
			return a.messages.Delete(ctx, summaryMessage.ID)
		}
		return err
	}

	summaryMessage.AddFinish(message.FinishReasonEndTurn, "", "")
	err = a.messages.Update(genCtx, summaryMessage)
	if err != nil {
		return err
	}

	a.updateSessionUsage(a.largeModel, &currentSession, resp.TotalUsage)

	// just in case, base the session counters on only the last response's usage
	usage := resp.Response.Usage
	currentSession.SummaryMessageID = summaryMessage.ID
	currentSession.CompletionTokens = usage.OutputTokens
	currentSession.PromptTokens = 0
	_, err = a.sessions.Save(genCtx, currentSession)
	return err
}

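// getCacheControlOptions returns the Anthropic provider options that mark a
// message or tool as an ephemeral prompt-cache breakpoint.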
func (a *sessionAgent) getCacheControlOptions() ai.ProviderOptions {
	return ai.ProviderOptions{
		anthropic.Name: &anthropic.ProviderCacheControlOptions{
			CacheControl: anthropic.CacheControl{Type: "ephemeral"},
		},
	}
}

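// createUserMessage stores the prompt and its attachments as a new user
// message in the session.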
func (a *sessionAgent) createUserMessage(ctx context.Context, call SessionAgentCall) (message.Message, error) {
	var attachmentParts []message.ContentPart
	for _, attachment := range call.Attachments {
		attachmentParts = append(attachmentParts, message.BinaryContent{Path: attachment.FilePath, MIMEType: attachment.MimeType, Data: attachment.Content})
	}
	parts := []message.ContentPart{message.TextContent{Text: call.Prompt}}
	parts = append(parts, attachmentParts...)
	msg, err := a.messages.Create(ctx, call.SessionID, message.CreateMessageParams{
		Role:  message.User,
		Parts: parts,
	})
	if err != nil {
		return message.Message{}, fmt.Errorf("failed to create user message: %w", err)
	}
	return msg, nil
}

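// preparePrompt converts stored messages into AI messages, skipping empty or
// cancelled assistant turns, and converts attachments into file parts.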
func (a *sessionAgent) preparePrompt(msgs []message.Message, attachments ...message.Attachment) ([]ai.Message, []ai.FilePart) {
	var history []ai.Message
	for _, m := range msgs {
		if len(m.Parts) == 0 {
			continue
		}
		// Assistant message without content or tool calls (cancelled before it returned anything)
		if m.Role == message.Assistant && len(m.ToolCalls()) == 0 && m.Content().Text == "" && m.ReasoningContent().String() == "" {
			continue
		}
		history = append(history, m.ToAIMessage()...)
	}

	var files []ai.FilePart
	for _, attachment := range attachments {
		files = append(files, ai.FilePart{
			Filename:  attachment.FileName,
			Data:      attachment.Content,
			MediaType: attachment.MimeType,
		})
	}

	return history, files
}

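// getSessionMessages lists the session's messages, truncating the history so
// it starts at the latest summary message when one exists.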
func (a *sessionAgent) getSessionMessages(ctx context.Context, session session.Session) ([]message.Message, error) {
	msgs, err := a.messages.List(ctx, session.ID)
	if err != nil {
		return nil, fmt.Errorf("failed to list messages: %w", err)
	}

	if session.SummaryMessageID != "" {
		summaryMsgIndex := -1
		for i, msg := range msgs {
			if msg.ID == session.SummaryMessageID {
				summaryMsgIndex = i
				break
			}
		}
		if summaryMsgIndex != -1 {
			msgs = msgs[summaryMsgIndex:]
			// treat the summary message as a user message in the truncated history
			msgs[0].Role = message.User
		}
	}
	return msgs, nil
}

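// generateTitle asks the small model for a short title based on the first
// prompt and saves it, along with the usage, on the session.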
func (a *sessionAgent) generateTitle(ctx context.Context, session *session.Session, prompt string) {
	if prompt == "" {
		return
	}

	var maxOutput int64 = 40
	if a.smallModel.CatwalkCfg.CanReason {
		maxOutput = a.smallModel.CatwalkCfg.DefaultMaxTokens
	}

	agent := ai.NewAgent(a.smallModel.Model,
		ai.WithSystemPrompt(string(titlePrompt)+"\n /no_think"),
		ai.WithMaxOutputTokens(maxOutput),
	)

	resp, err := agent.Stream(ctx, ai.AgentStreamCall{
		Prompt: fmt.Sprintf("Generate a concise title for the following content:\n\n%s\n <think>\n\n</think>", prompt),
	})
	if err != nil {
		slog.Error("error generating title", "err", err)
		return
	}

	data, _ := json.Marshal(resp)

	slog.Info("Title Response")
	slog.Info(string(data))
	title := resp.Response.Content.Text()

	title = strings.ReplaceAll(title, "\n", " ")
	slog.Info(title)

	// remove thinking tags if present
	if idx := strings.Index(title, "</think>"); idx > 0 {
		title = title[idx+len("</think>"):]
	}

	title = strings.TrimSpace(title)
	if title == "" {
		slog.Warn("failed to generate title", "warn", "empty title")
		return
	}

	session.Title = title
	a.updateSessionUsage(a.smallModel, session, resp.TotalUsage)
	_, saveErr := a.sessions.Save(ctx, *session)
	if saveErr != nil {
		slog.Error("failed to save session title & usage", "error", saveErr)
		return
	}
}

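// updateSessionUsage adds the cost of the given usage to the session using the
// model's catwalk pricing and records the latest prompt and completion token
// counts.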
func (a *sessionAgent) updateSessionUsage(model Model, session *session.Session, usage ai.Usage) {
	modelConfig := model.CatwalkCfg
	cost := modelConfig.CostPer1MInCached/1e6*float64(usage.CacheCreationTokens) +
		modelConfig.CostPer1MOutCached/1e6*float64(usage.CacheReadTokens) +
		modelConfig.CostPer1MIn/1e6*float64(usage.InputTokens) +
		modelConfig.CostPer1MOut/1e6*float64(usage.OutputTokens)
	session.Cost += cost
	session.CompletionTokens = usage.OutputTokens + usage.CacheReadTokens
	session.PromptTokens = usage.InputTokens + usage.CacheCreationTokens
}

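// Cancel aborts any in-flight request (including summarization) for the given
// session and drops its queued prompts.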
func (a *sessionAgent) Cancel(sessionID string) {
	// Cancel regular requests
	if cancel, ok := a.activeRequests.Take(sessionID); ok && cancel != nil {
		slog.Info("Request cancellation initiated", "session_id", sessionID)
		cancel()
	}

	// Also check for summarize requests
	if cancel, ok := a.activeRequests.Take(sessionID + "-summarize"); ok && cancel != nil {
		slog.Info("Summarize cancellation initiated", "session_id", sessionID)
		cancel()
	}

	if a.QueuedPrompts(sessionID) > 0 {
		slog.Info("Clearing queued prompts", "session_id", sessionID)
		a.messageQueue.Del(sessionID)
	}
}

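// ClearQueue drops any prompts queued for the given session.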
func (a *sessionAgent) ClearQueue(sessionID string) {
	if a.QueuedPrompts(sessionID) > 0 {
		slog.Info("Clearing queued prompts", "session_id", sessionID)
		a.messageQueue.Del(sessionID)
	}
}

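// CancelAll cancels every active request and waits up to five seconds for
// them to wind down.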
func (a *sessionAgent) CancelAll() {
	if !a.IsBusy() {
		return
	}
	for key := range a.activeRequests.Seq2() {
		a.Cancel(key) // key is sessionID
	}

	timeout := time.After(5 * time.Second)
	for a.IsBusy() {
		select {
		case <-timeout:
			return
		default:
			time.Sleep(200 * time.Millisecond)
		}
	}
}

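// IsBusy reports whether any session currently has an active request.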
func (a *sessionAgent) IsBusy() bool {
	var busy bool
	for cancelFunc := range a.activeRequests.Seq() {
		if cancelFunc != nil {
			busy = true
			break
		}
	}
	return busy
}

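// IsSessionBusy reports whether the given session has an active request.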
func (a *sessionAgent) IsSessionBusy(sessionID string) bool {
	_, busy := a.activeRequests.Get(sessionID)
	return busy
}

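// QueuedPrompts returns the number of prompts queued for the given session.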
func (a *sessionAgent) QueuedPrompts(sessionID string) int {
	l, ok := a.messageQueue.Get(sessionID)
	if !ok {
		return 0
	}
	return len(l)
}

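// SetModels swaps the large and small models used by the agent.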
func (a *sessionAgent) SetModels(large Model, small Model) {
	a.largeModel = large
	a.smallModel = small
}

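// SetTools replaces the agent's tool set.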
func (a *sessionAgent) SetTools(tools []ai.AgentTool) {
	a.tools = tools
}

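// Model returns the currently configured large model.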
func (a *sessionAgent) Model() Model {
	return a.largeModel
}