1package agent
2
3import (
4 "context"
5 _ "embed"
6 "errors"
7 "fmt"
8 "log/slog"
9 "strings"
10 "sync"
11 "time"
12
13 "charm.land/fantasy"
14 "charm.land/fantasy/providers/anthropic"
15 "charm.land/fantasy/providers/openai"
16 "github.com/charmbracelet/catwalk/pkg/catwalk"
17 "github.com/charmbracelet/crush/internal/agent/tools"
18 "github.com/charmbracelet/crush/internal/config"
19 "github.com/charmbracelet/crush/internal/csync"
20 "github.com/charmbracelet/crush/internal/message"
21 "github.com/charmbracelet/crush/internal/permission"
22 "github.com/charmbracelet/crush/internal/session"
23)
24
// titlePrompt holds the embedded contents of templates/title.md; it is used
// as the system prompt when generating a session title.
//
//go:embed templates/title.md
var titlePrompt []byte
27
// summaryPrompt holds the embedded contents of templates/summary.md; it is
// used as the system prompt when summarizing a session.
//
//go:embed templates/summary.md
var summaryPrompt []byte
30
// SessionAgentCall describes a single prompt submitted to a SessionAgent
// together with the per-call generation settings.
type SessionAgentCall struct {
	// SessionID identifies the session the prompt belongs to. Run rejects an
	// empty value with ErrSessionMissing.
	SessionID string
	// Prompt is the user's text. Run rejects an empty value with
	// ErrEmptyPrompt.
	Prompt string
	// ProviderOptions is forwarded as-is to the agent stream call.
	ProviderOptions fantasy.ProviderOptions
	// Attachments are stored on the user message and sent to the model as
	// file parts.
	Attachments []message.Attachment
	// MaxOutputTokens is forwarded to the agent stream call.
	MaxOutputTokens int64
	// The sampling settings below are forwarded to the agent stream call
	// unchanged; a nil pointer is passed through as nil.
	Temperature      *float64
	TopP             *float64
	TopK             *int64
	FrequencyPenalty *float64
	PresencePenalty  *float64
}
43
// SessionAgent runs prompts against a language model on behalf of sessions
// and tracks which sessions currently have a request in flight.
type SessionAgent interface {
	// Run executes the call, or queues it when the session is already busy.
	Run(context.Context, SessionAgentCall) (*fantasy.AgentResult, error)
	// SetModels replaces the large and small models used by the agent.
	SetModels(large Model, small Model)
	// SetTools replaces the tools offered to the model.
	SetTools(tools []fantasy.AgentTool)
	// Cancel cancels the active request of the session, if any, and drops
	// its queued prompts.
	Cancel(sessionID string)
	// CancelAll cancels every active request.
	CancelAll()
	// IsSessionBusy reports whether the session has an active request.
	IsSessionBusy(sessionID string) bool
	// IsBusy reports whether any session has an active request.
	IsBusy() bool
	// QueuedPrompts returns the number of prompts queued for the session.
	QueuedPrompts(sessionID string) int
	// ClearQueue drops the prompts queued for the session.
	ClearQueue(sessionID string)
	// Summarize summarizes the conversation of the session with the given ID.
	Summarize(context.Context, string) error
	// Model returns the large model.
	Model() Model
}
57
// Model bundles a language model with the metadata the agent needs about it.
type Model struct {
	// Model is the language model that requests are sent to.
	Model fantasy.LanguageModel
	// CatwalkCfg is the catalog entry for the model; the agent reads its
	// pricing, context window, reasoning flag and default max tokens.
	CatwalkCfg catwalk.Model
	// ModelCfg is the selected-model configuration; its Model and Provider
	// values are recorded on assistant messages.
	ModelCfg config.SelectedModel
}
63
// sessionAgent is the SessionAgent implementation returned by
// NewSessionAgent.
type sessionAgent struct {
	// largeModel serves prompts and summaries; smallModel generates titles.
	largeModel Model
	smallModel Model
	// systemPrompt is the system prompt used for regular prompts.
	systemPrompt string
	// tools are the tools offered to the large model.
	tools []fantasy.AgentTool
	// sessions and messages are the stores for session and message state.
	sessions session.Service
	messages message.Service
	// disableAutoSummarize turns off summarizing when the context window is
	// nearly full.
	disableAutoSummarize bool

	// messageQueue holds, per session ID, the calls received while the
	// session was busy.
	messageQueue *csync.Map[string, []SessionAgentCall]
	// activeRequests maps a session ID to the cancel function of its
	// in-flight request; presence of a key marks the session as busy.
	activeRequests *csync.Map[string, context.CancelFunc]
}
76
// SessionAgentOptions carries the dependencies and settings that
// NewSessionAgent copies into the agent it creates.
type SessionAgentOptions struct {
	// LargeModel serves prompts and summaries.
	LargeModel Model
	// SmallModel generates session titles.
	SmallModel Model
	// SystemPrompt is the system prompt used for regular prompts.
	SystemPrompt string
	// DisableAutoSummarize turns off summarizing when the context window is
	// nearly full.
	DisableAutoSummarize bool
	// Sessions is the session store.
	Sessions session.Service
	// Messages is the message store.
	Messages message.Service
	// Tools are the tools offered to the large model.
	Tools []fantasy.AgentTool
}
86
87func NewSessionAgent(
88 opts SessionAgentOptions,
89) SessionAgent {
90 return &sessionAgent{
91 largeModel: opts.LargeModel,
92 smallModel: opts.SmallModel,
93 systemPrompt: opts.SystemPrompt,
94 sessions: opts.Sessions,
95 messages: opts.Messages,
96 disableAutoSummarize: opts.DisableAutoSummarize,
97 tools: opts.Tools,
98 messageQueue: csync.NewMap[string, []SessionAgentCall](),
99 activeRequests: csync.NewMap[string, context.CancelFunc](),
100 }
101}
102
103func (a *sessionAgent) Run(ctx context.Context, call SessionAgentCall) (*fantasy.AgentResult, error) {
104 if call.Prompt == "" {
105 return nil, ErrEmptyPrompt
106 }
107 if call.SessionID == "" {
108 return nil, ErrSessionMissing
109 }
110
111 // Queue the message if busy
112 if a.IsSessionBusy(call.SessionID) {
113 existing, ok := a.messageQueue.Get(call.SessionID)
114 if !ok {
115 existing = []SessionAgentCall{}
116 }
117 existing = append(existing, call)
118 a.messageQueue.Set(call.SessionID, existing)
119 return nil, nil
120 }
121
122 if len(a.tools) > 0 {
123 // add anthropic caching to the last tool
124 a.tools[len(a.tools)-1].SetProviderOptions(a.getCacheControlOptions())
125 }
126
127 agent := fantasy.NewAgent(
128 a.largeModel.Model,
129 fantasy.WithSystemPrompt(a.systemPrompt),
130 fantasy.WithTools(a.tools...),
131 )
132
133 sessionLock := sync.Mutex{}
134 currentSession, err := a.sessions.Get(ctx, call.SessionID)
135 if err != nil {
136 return nil, fmt.Errorf("failed to get session: %w", err)
137 }
138
139 msgs, err := a.getSessionMessages(ctx, currentSession)
140 if err != nil {
141 return nil, fmt.Errorf("failed to get session messages: %w", err)
142 }
143
144 var wg sync.WaitGroup
145 // Generate title if first message
146 if len(msgs) == 0 {
147 wg.Go(func() {
148 sessionLock.Lock()
149 a.generateTitle(ctx, ¤tSession, call.Prompt)
150 sessionLock.Unlock()
151 })
152 }
153
154 // Add the user message to the session
155 _, err = a.createUserMessage(ctx, call)
156 if err != nil {
157 return nil, err
158 }
159
160 // add the session to the context
161 ctx = context.WithValue(ctx, tools.SessionIDContextKey, call.SessionID)
162
163 genCtx, cancel := context.WithCancel(ctx)
164 a.activeRequests.Set(call.SessionID, cancel)
165
166 defer cancel()
167 defer a.activeRequests.Del(call.SessionID)
168
169 history, files := a.preparePrompt(msgs, call.Attachments...)
170
171 var currentAssistant *message.Message
172 var shouldSummarize bool
173 result, err := agent.Stream(genCtx, fantasy.AgentStreamCall{
174 Prompt: call.Prompt,
175 Files: files,
176 Messages: history,
177 ProviderOptions: call.ProviderOptions,
178 MaxOutputTokens: &call.MaxOutputTokens,
179 TopP: call.TopP,
180 Temperature: call.Temperature,
181 PresencePenalty: call.PresencePenalty,
182 TopK: call.TopK,
183 FrequencyPenalty: call.FrequencyPenalty,
184 // Before each step create the new assistant message
185 PrepareStep: func(callContext context.Context, options fantasy.PrepareStepFunctionOptions) (_ context.Context, prepared fantasy.PrepareStepResult, err error) {
186 var assistantMsg message.Message
187 assistantMsg, err = a.messages.Create(callContext, call.SessionID, message.CreateMessageParams{
188 Role: message.Assistant,
189 Parts: []message.ContentPart{},
190 Model: a.largeModel.ModelCfg.Model,
191 Provider: a.largeModel.ModelCfg.Provider,
192 })
193 if err != nil {
194 return callContext, prepared, err
195 }
196
197 callContext = context.WithValue(callContext, tools.MessageIDContextKey, assistantMsg.ID)
198
199 currentAssistant = &assistantMsg
200
201 prepared.Messages = options.Messages
202 // reset all cached items
203 for i := range prepared.Messages {
204 prepared.Messages[i].ProviderOptions = nil
205 }
206
207 queuedCalls, _ := a.messageQueue.Get(call.SessionID)
208 a.messageQueue.Del(call.SessionID)
209 for _, queued := range queuedCalls {
210 userMessage, createErr := a.createUserMessage(callContext, queued)
211 if createErr != nil {
212 return callContext, prepared, createErr
213 }
214 prepared.Messages = append(prepared.Messages, userMessage.ToAIMessage()...)
215 }
216
217 lastSystemRoleInx := 0
218 systemMessageUpdated := false
219 for i, msg := range prepared.Messages {
220 // only add cache control to the last message
221 if msg.Role == fantasy.MessageRoleSystem {
222 lastSystemRoleInx = i
223 } else if !systemMessageUpdated {
224 prepared.Messages[lastSystemRoleInx].ProviderOptions = a.getCacheControlOptions()
225 systemMessageUpdated = true
226 }
227 // than add cache control to the last 2 messages
228 if i > len(prepared.Messages)-3 {
229 prepared.Messages[i].ProviderOptions = a.getCacheControlOptions()
230 }
231 }
232 return callContext, prepared, err
233 },
234 OnReasoningDelta: func(id string, text string) error {
235 currentAssistant.AppendReasoningContent(text)
236 return a.messages.Update(genCtx, *currentAssistant)
237 },
238 OnReasoningEnd: func(id string, reasoning fantasy.ReasoningContent) error {
239 // handle anthropic signature
240 if anthropicData, ok := reasoning.ProviderMetadata[anthropic.Name]; ok {
241 if reasoning, ok := anthropicData.(*anthropic.ReasoningOptionMetadata); ok {
242 currentAssistant.AppendReasoningSignature(reasoning.Signature)
243 }
244 }
245 if openaiData, ok := reasoning.ProviderMetadata[openai.Name]; ok {
246 if reasoning, ok := openaiData.(*openai.ResponsesReasoningMetadata); ok {
247 currentAssistant.SetReasoningResponsesData(reasoning)
248 }
249 }
250 currentAssistant.FinishThinking()
251 return a.messages.Update(genCtx, *currentAssistant)
252 },
253 OnTextDelta: func(id string, text string) error {
254 currentAssistant.AppendContent(text)
255 return a.messages.Update(genCtx, *currentAssistant)
256 },
257 OnToolInputStart: func(id string, toolName string) error {
258 toolCall := message.ToolCall{
259 ID: id,
260 Name: toolName,
261 ProviderExecuted: false,
262 Finished: false,
263 }
264 currentAssistant.AddToolCall(toolCall)
265 return a.messages.Update(genCtx, *currentAssistant)
266 },
267 OnRetry: func(err *fantasy.APICallError, delay time.Duration) {
268 // TODO: implement
269 },
270 OnToolCall: func(tc fantasy.ToolCallContent) error {
271 toolCall := message.ToolCall{
272 ID: tc.ToolCallID,
273 Name: tc.ToolName,
274 Input: tc.Input,
275 ProviderExecuted: false,
276 Finished: true,
277 }
278 currentAssistant.AddToolCall(toolCall)
279 return a.messages.Update(genCtx, *currentAssistant)
280 },
281 OnToolResult: func(result fantasy.ToolResultContent) error {
282 var resultContent string
283 isError := false
284 switch result.Result.GetType() {
285 case fantasy.ToolResultContentTypeText:
286 r, ok := fantasy.AsToolResultOutputType[fantasy.ToolResultOutputContentText](result.Result)
287 if ok {
288 resultContent = r.Text
289 }
290 case fantasy.ToolResultContentTypeError:
291 r, ok := fantasy.AsToolResultOutputType[fantasy.ToolResultOutputContentError](result.Result)
292 if ok {
293 isError = true
294 resultContent = r.Error.Error()
295 }
296 case fantasy.ToolResultContentTypeMedia:
297 // TODO: handle this message type
298 }
299 toolResult := message.ToolResult{
300 ToolCallID: result.ToolCallID,
301 Name: result.ToolName,
302 Content: resultContent,
303 IsError: isError,
304 Metadata: result.ClientMetadata,
305 }
306 _, createMsgErr := a.messages.Create(genCtx, currentAssistant.SessionID, message.CreateMessageParams{
307 Role: message.Tool,
308 Parts: []message.ContentPart{
309 toolResult,
310 },
311 })
312 if createMsgErr != nil {
313 return createMsgErr
314 }
315 return nil
316 },
317 OnStepFinish: func(stepResult fantasy.StepResult) error {
318 finishReason := message.FinishReasonUnknown
319 switch stepResult.FinishReason {
320 case fantasy.FinishReasonLength:
321 finishReason = message.FinishReasonMaxTokens
322 case fantasy.FinishReasonStop:
323 finishReason = message.FinishReasonEndTurn
324 case fantasy.FinishReasonToolCalls:
325 finishReason = message.FinishReasonToolUse
326 }
327 currentAssistant.AddFinish(finishReason, "", "")
328 a.updateSessionUsage(a.largeModel, ¤tSession, stepResult.Usage)
329 sessionLock.Lock()
330 _, sessionErr := a.sessions.Save(genCtx, currentSession)
331 sessionLock.Unlock()
332 if sessionErr != nil {
333 return sessionErr
334 }
335 return a.messages.Update(genCtx, *currentAssistant)
336 },
337 StopWhen: []fantasy.StopCondition{
338 func(_ []fantasy.StepResult) bool {
339 contextWindow := a.largeModel.CatwalkCfg.ContextWindow
340 tokens := currentSession.CompletionTokens + currentSession.PromptTokens
341 percentage := (float64(tokens) / float64(contextWindow)) * 100
342 if (percentage > 80) && !a.disableAutoSummarize {
343 shouldSummarize = true
344 return true
345 }
346 return false
347 },
348 },
349 })
350 if err != nil {
351 isCancelErr := errors.Is(err, context.Canceled)
352 isPermissionErr := errors.Is(err, permission.ErrorPermissionDenied)
353 if currentAssistant == nil {
354 return result, err
355 }
356 toolCalls := currentAssistant.ToolCalls()
357 // INFO: we use the parent context here because the genCtx has been cancelled
358 msgs, createErr := a.messages.List(ctx, currentAssistant.SessionID)
359 if createErr != nil {
360 return nil, createErr
361 }
362 for _, tc := range toolCalls {
363 if !tc.Finished {
364 tc.Finished = true
365 tc.Input = "{}"
366 currentAssistant.AddToolCall(tc)
367 updateErr := a.messages.Update(ctx, *currentAssistant)
368 if updateErr != nil {
369 return nil, updateErr
370 }
371 }
372
373 found := false
374 for _, msg := range msgs {
375 if msg.Role == message.Tool {
376 for _, tr := range msg.ToolResults() {
377 if tr.ToolCallID == tc.ID {
378 found = true
379 break
380 }
381 }
382 }
383 if found {
384 break
385 }
386 }
387 if found {
388 continue
389 }
390 content := "There was an error while executing the tool"
391 if isCancelErr {
392 content = "Tool execution canceled by user"
393 } else if isPermissionErr {
394 content = "Permission denied"
395 }
396 toolResult := message.ToolResult{
397 ToolCallID: tc.ID,
398 Name: tc.Name,
399 Content: content,
400 IsError: true,
401 }
402 _, createErr = a.messages.Create(context.Background(), currentAssistant.SessionID, message.CreateMessageParams{
403 Role: message.Tool,
404 Parts: []message.ContentPart{
405 toolResult,
406 },
407 })
408 if createErr != nil {
409 return nil, createErr
410 }
411 }
412 if isCancelErr {
413 currentAssistant.AddFinish(message.FinishReasonCanceled, "Request cancelled", "")
414 } else if isPermissionErr {
415 currentAssistant.AddFinish(message.FinishReasonPermissionDenied, "Permission denied", "")
416 } else {
417 currentAssistant.AddFinish(message.FinishReasonError, "API Error", err.Error())
418 }
419 // INFO: we use the parent context here because the genCtx has been cancelled
420 updateErr := a.messages.Update(ctx, *currentAssistant)
421 if updateErr != nil {
422 return nil, updateErr
423 }
424 return nil, err
425 }
426 wg.Wait()
427
428 if shouldSummarize {
429 a.activeRequests.Del(call.SessionID)
430 if summarizeErr := a.Summarize(genCtx, call.SessionID); summarizeErr != nil {
431 return nil, summarizeErr
432 }
433 }
434
435 queuedMessages, ok := a.messageQueue.Get(call.SessionID)
436 if !ok || len(queuedMessages) == 0 {
437 return result, err
438 }
439 // there are queued messages restart the loop
440 firstQueuedMessage := queuedMessages[0]
441 a.messageQueue.Set(call.SessionID, queuedMessages[1:])
442 return a.Run(genCtx, firstQueuedMessage)
443}
444
445func (a *sessionAgent) Summarize(ctx context.Context, sessionID string) error {
446 if a.IsSessionBusy(sessionID) {
447 return ErrSessionBusy
448 }
449
450 currentSession, err := a.sessions.Get(ctx, sessionID)
451 if err != nil {
452 return fmt.Errorf("failed to get session: %w", err)
453 }
454 msgs, err := a.getSessionMessages(ctx, currentSession)
455 if err != nil {
456 return err
457 }
458 if len(msgs) == 0 {
459 // nothing to summarize
460 return nil
461 }
462
463 aiMsgs, _ := a.preparePrompt(msgs)
464
465 genCtx, cancel := context.WithCancel(ctx)
466 a.activeRequests.Set(sessionID, cancel)
467 defer a.activeRequests.Del(sessionID)
468 defer cancel()
469
470 agent := fantasy.NewAgent(a.largeModel.Model,
471 fantasy.WithSystemPrompt(string(summaryPrompt)),
472 )
473 summaryMessage, err := a.messages.Create(ctx, sessionID, message.CreateMessageParams{
474 Role: message.Assistant,
475 Model: a.largeModel.Model.Model(),
476 Provider: a.largeModel.Model.Provider(),
477 IsSummaryMessage: true,
478 })
479 if err != nil {
480 return err
481 }
482
483 resp, err := agent.Stream(genCtx, fantasy.AgentStreamCall{
484 Prompt: "Provide a detailed summary of our conversation above.",
485 Messages: aiMsgs,
486 OnReasoningDelta: func(id string, text string) error {
487 summaryMessage.AppendReasoningContent(text)
488 return a.messages.Update(genCtx, summaryMessage)
489 },
490 OnReasoningEnd: func(id string, reasoning fantasy.ReasoningContent) error {
491 // handle anthropic signature
492 if anthropicData, ok := reasoning.ProviderMetadata["anthropic"]; ok {
493 if signature, ok := anthropicData.(*anthropic.ReasoningOptionMetadata); ok && signature.Signature != "" {
494 summaryMessage.AppendReasoningSignature(signature.Signature)
495 }
496 }
497 summaryMessage.FinishThinking()
498 return a.messages.Update(genCtx, summaryMessage)
499 },
500 OnTextDelta: func(id, text string) error {
501 summaryMessage.AppendContent(text)
502 return a.messages.Update(genCtx, summaryMessage)
503 },
504 })
505 if err != nil {
506 isCancelErr := errors.Is(err, context.Canceled)
507 if isCancelErr {
508 // User cancelled summarize we need to remove the summary message
509 deleteErr := a.messages.Delete(ctx, summaryMessage.ID)
510 return deleteErr
511 }
512 return err
513 }
514
515 summaryMessage.AddFinish(message.FinishReasonEndTurn, "", "")
516 err = a.messages.Update(genCtx, summaryMessage)
517 if err != nil {
518 return err
519 }
520
521 a.updateSessionUsage(a.largeModel, ¤tSession, resp.TotalUsage)
522
523 // just in case get just the last usage
524 usage := resp.Response.Usage
525 currentSession.SummaryMessageID = summaryMessage.ID
526 currentSession.CompletionTokens = usage.OutputTokens
527 currentSession.PromptTokens = 0
528 _, err = a.sessions.Save(genCtx, currentSession)
529 return err
530}
531
532func (a *sessionAgent) getCacheControlOptions() fantasy.ProviderOptions {
533 return fantasy.ProviderOptions{
534 anthropic.Name: &anthropic.ProviderCacheControlOptions{
535 CacheControl: anthropic.CacheControl{Type: "ephemeral"},
536 },
537 }
538}
539
540func (a *sessionAgent) createUserMessage(ctx context.Context, call SessionAgentCall) (message.Message, error) {
541 var attachmentParts []message.ContentPart
542 for _, attachment := range call.Attachments {
543 attachmentParts = append(attachmentParts, message.BinaryContent{Path: attachment.FilePath, MIMEType: attachment.MimeType, Data: attachment.Content})
544 }
545 parts := []message.ContentPart{message.TextContent{Text: call.Prompt}}
546 parts = append(parts, attachmentParts...)
547 msg, err := a.messages.Create(ctx, call.SessionID, message.CreateMessageParams{
548 Role: message.User,
549 Parts: parts,
550 })
551 if err != nil {
552 return message.Message{}, fmt.Errorf("failed to create user message: %w", err)
553 }
554 return msg, nil
555}
556
557func (a *sessionAgent) preparePrompt(msgs []message.Message, attachments ...message.Attachment) ([]fantasy.Message, []fantasy.FilePart) {
558 var history []fantasy.Message
559 for _, m := range msgs {
560 if len(m.Parts) == 0 {
561 continue
562 }
563 // Assistant message without content or tool calls (cancelled before it returned anything)
564 if m.Role == message.Assistant && len(m.ToolCalls()) == 0 && m.Content().Text == "" && m.ReasoningContent().String() == "" {
565 continue
566 }
567 history = append(history, m.ToAIMessage()...)
568 }
569
570 var files []fantasy.FilePart
571 for _, attachment := range attachments {
572 files = append(files, fantasy.FilePart{
573 Filename: attachment.FileName,
574 Data: attachment.Content,
575 MediaType: attachment.MimeType,
576 })
577 }
578
579 return history, files
580}
581
582func (a *sessionAgent) getSessionMessages(ctx context.Context, session session.Session) ([]message.Message, error) {
583 msgs, err := a.messages.List(ctx, session.ID)
584 if err != nil {
585 return nil, fmt.Errorf("failed to list messages: %w", err)
586 }
587
588 if session.SummaryMessageID != "" {
589 summaryMsgInex := -1
590 for i, msg := range msgs {
591 if msg.ID == session.SummaryMessageID {
592 summaryMsgInex = i
593 break
594 }
595 }
596 if summaryMsgInex != -1 {
597 msgs = msgs[summaryMsgInex:]
598 msgs[0].Role = message.User
599 }
600 }
601 return msgs, nil
602}
603
604func (a *sessionAgent) generateTitle(ctx context.Context, session *session.Session, prompt string) {
605 if prompt == "" {
606 return
607 }
608
609 var maxOutput int64 = 40
610 if a.smallModel.CatwalkCfg.CanReason {
611 maxOutput = a.smallModel.CatwalkCfg.DefaultMaxTokens
612 }
613
614 agent := fantasy.NewAgent(a.smallModel.Model,
615 fantasy.WithSystemPrompt(string(titlePrompt)+"\n /no_think"),
616 fantasy.WithMaxOutputTokens(maxOutput),
617 )
618
619 resp, err := agent.Stream(ctx, fantasy.AgentStreamCall{
620 Prompt: fmt.Sprintf("Generate a concise title for the following content:\n\n%s\n <think>\n\n</think>", prompt),
621 })
622 if err != nil {
623 slog.Error("error generating title", "err", err)
624 return
625 }
626
627 title := resp.Response.Content.Text()
628
629 title = strings.ReplaceAll(title, "\n", " ")
630
631 // remove thinking tags if present
632 if idx := strings.Index(title, "</think>"); idx > 0 {
633 title = title[idx+len("</think>"):]
634 }
635
636 title = strings.TrimSpace(title)
637 if title == "" {
638 slog.Warn("failed to generate title", "warn", "empty title")
639 return
640 }
641
642 session.Title = title
643 a.updateSessionUsage(a.smallModel, session, resp.TotalUsage)
644 _, saveErr := a.sessions.Save(ctx, *session)
645 if saveErr != nil {
646 slog.Error("failed to save session title & usage", "error", saveErr)
647 return
648 }
649}
650
651func (a *sessionAgent) updateSessionUsage(model Model, session *session.Session, usage fantasy.Usage) {
652 modelConfig := model.CatwalkCfg
653 cost := modelConfig.CostPer1MInCached/1e6*float64(usage.CacheCreationTokens) +
654 modelConfig.CostPer1MOutCached/1e6*float64(usage.CacheReadTokens) +
655 modelConfig.CostPer1MIn/1e6*float64(usage.InputTokens) +
656 modelConfig.CostPer1MOut/1e6*float64(usage.OutputTokens)
657 session.Cost += cost
658 session.CompletionTokens = usage.OutputTokens + usage.CacheReadTokens
659 session.PromptTokens = usage.InputTokens + usage.CacheCreationTokens
660}
661
662func (a *sessionAgent) Cancel(sessionID string) {
663 // Cancel regular requests
664 if cancel, ok := a.activeRequests.Take(sessionID); ok && cancel != nil {
665 slog.Info("Request cancellation initiated", "session_id", sessionID)
666 cancel()
667 }
668
669 // Also check for summarize requests
670 if cancel, ok := a.activeRequests.Take(sessionID + "-summarize"); ok && cancel != nil {
671 slog.Info("Summarize cancellation initiated", "session_id", sessionID)
672 cancel()
673 }
674
675 if a.QueuedPrompts(sessionID) > 0 {
676 slog.Info("Clearing queued prompts", "session_id", sessionID)
677 a.messageQueue.Del(sessionID)
678 }
679}
680
681func (a *sessionAgent) ClearQueue(sessionID string) {
682 if a.QueuedPrompts(sessionID) > 0 {
683 slog.Info("Clearing queued prompts", "session_id", sessionID)
684 a.messageQueue.Del(sessionID)
685 }
686}
687
688func (a *sessionAgent) CancelAll() {
689 if !a.IsBusy() {
690 return
691 }
692 for key := range a.activeRequests.Seq2() {
693 a.Cancel(key) // key is sessionID
694 }
695
696 timeout := time.After(5 * time.Second)
697 for a.IsBusy() {
698 select {
699 case <-timeout:
700 return
701 default:
702 time.Sleep(200 * time.Millisecond)
703 }
704 }
705}
706
707func (a *sessionAgent) IsBusy() bool {
708 var busy bool
709 for cancelFunc := range a.activeRequests.Seq() {
710 if cancelFunc != nil {
711 busy = true
712 break
713 }
714 }
715 return busy
716}
717
718func (a *sessionAgent) IsSessionBusy(sessionID string) bool {
719 _, busy := a.activeRequests.Get(sessionID)
720 return busy
721}
722
723func (a *sessionAgent) QueuedPrompts(sessionID string) int {
724 l, ok := a.messageQueue.Get(sessionID)
725 if !ok {
726 return 0
727 }
728 return len(l)
729}
730
731func (a *sessionAgent) SetModels(large Model, small Model) {
732 a.largeModel = large
733 a.smallModel = small
734}
735
// SetTools replaces the tools offered to the model. The slice is stored as
// given, not copied.
func (a *sessionAgent) SetTools(tools []fantasy.AgentTool) {
	a.tools = tools
}
739
// Model returns the large model.
func (a *sessionAgent) Model() Model {
	return a.largeModel
}