1package agent
2
3import (
4 "context"
5 _ "embed"
6 "errors"
7 "fmt"
8 "log/slog"
9 "strings"
10 "sync"
11 "time"
12
13 "charm.land/fantasy"
14 "charm.land/fantasy/providers/anthropic"
15 "charm.land/fantasy/providers/openai"
16 "github.com/charmbracelet/catwalk/pkg/catwalk"
17 "github.com/charmbracelet/crush/internal/agent/tools"
18 "github.com/charmbracelet/crush/internal/config"
19 "github.com/charmbracelet/crush/internal/csync"
20 "github.com/charmbracelet/crush/internal/message"
21 "github.com/charmbracelet/crush/internal/permission"
22 "github.com/charmbracelet/crush/internal/session"
23)
24
// titlePrompt is the system prompt generateTitle sends to the small model when
// asking for a session title. It is embedded at build time from
// templates/title.md.
//
//go:embed templates/title.md
var titlePrompt []byte

// summaryPrompt is the system prompt Summarize sends to the large model when
// condensing a conversation. It is embedded at build time from
// templates/summary.md.
//
//go:embed templates/summary.md
var summaryPrompt []byte
30
31type SessionAgentCall struct {
32 SessionID string
33 Prompt string
34 ProviderOptions fantasy.ProviderOptions
35 Attachments []message.Attachment
36 MaxOutputTokens int64
37 Temperature *float64
38 TopP *float64
39 TopK *int64
40 FrequencyPenalty *float64
41 PresencePenalty *float64
42}
43
44type SessionAgent interface {
45 Run(context.Context, SessionAgentCall) (*fantasy.AgentResult, error)
46 SetModels(large Model, small Model)
47 SetTools(tools []fantasy.AgentTool)
48 Cancel(sessionID string)
49 CancelAll()
50 IsSessionBusy(sessionID string) bool
51 IsBusy() bool
52 QueuedPrompts(sessionID string) int
53 ClearQueue(sessionID string)
54 Summarize(context.Context, string, fantasy.ProviderOptions) error
55 Model() Model
56}
57
58type Model struct {
59 Model fantasy.LanguageModel
60 CatwalkCfg catwalk.Model
61 ModelCfg config.SelectedModel
62}
63
64type sessionAgent struct {
65 largeModel Model
66 smallModel Model
67 systemPrompt string
68 tools []fantasy.AgentTool
69 sessions session.Service
70 messages message.Service
71 disableAutoSummarize bool
72
73 messageQueue *csync.Map[string, []SessionAgentCall]
74 activeRequests *csync.Map[string, context.CancelFunc]
75}
76
77type SessionAgentOptions struct {
78 LargeModel Model
79 SmallModel Model
80 SystemPrompt string
81 DisableAutoSummarize bool
82 Sessions session.Service
83 Messages message.Service
84 Tools []fantasy.AgentTool
85}
86
87func NewSessionAgent(
88 opts SessionAgentOptions,
89) SessionAgent {
90 return &sessionAgent{
91 largeModel: opts.LargeModel,
92 smallModel: opts.SmallModel,
93 systemPrompt: opts.SystemPrompt,
94 sessions: opts.Sessions,
95 messages: opts.Messages,
96 disableAutoSummarize: opts.DisableAutoSummarize,
97 tools: opts.Tools,
98 messageQueue: csync.NewMap[string, []SessionAgentCall](),
99 activeRequests: csync.NewMap[string, context.CancelFunc](),
100 }
101}
102
103func (a *sessionAgent) Run(ctx context.Context, call SessionAgentCall) (*fantasy.AgentResult, error) {
104 if call.Prompt == "" {
105 return nil, ErrEmptyPrompt
106 }
107 if call.SessionID == "" {
108 return nil, ErrSessionMissing
109 }
110
111 // Queue the message if busy
112 if a.IsSessionBusy(call.SessionID) {
113 existing, ok := a.messageQueue.Get(call.SessionID)
114 if !ok {
115 existing = []SessionAgentCall{}
116 }
117 existing = append(existing, call)
118 a.messageQueue.Set(call.SessionID, existing)
119 return nil, nil
120 }
121
122 if len(a.tools) > 0 {
123 // add anthropic caching to the last tool
124 a.tools[len(a.tools)-1].SetProviderOptions(a.getCacheControlOptions())
125 }
126
127 agent := fantasy.NewAgent(
128 a.largeModel.Model,
129 fantasy.WithSystemPrompt(a.systemPrompt),
130 fantasy.WithTools(a.tools...),
131 )
132
133 sessionLock := sync.Mutex{}
134 currentSession, err := a.sessions.Get(ctx, call.SessionID)
135 if err != nil {
136 return nil, fmt.Errorf("failed to get session: %w", err)
137 }
138
139 msgs, err := a.getSessionMessages(ctx, currentSession)
140 if err != nil {
141 return nil, fmt.Errorf("failed to get session messages: %w", err)
142 }
143
144 var wg sync.WaitGroup
145 // Generate title if first message
146 if len(msgs) == 0 {
147 wg.Go(func() {
148 sessionLock.Lock()
149 a.generateTitle(ctx, ¤tSession, call.Prompt)
150 sessionLock.Unlock()
151 })
152 }
153
154 // Add the user message to the session
155 _, err = a.createUserMessage(ctx, call)
156 if err != nil {
157 return nil, err
158 }
159
160 // add the session to the context
161 ctx = context.WithValue(ctx, tools.SessionIDContextKey, call.SessionID)
162
163 genCtx, cancel := context.WithCancel(ctx)
164 a.activeRequests.Set(call.SessionID, cancel)
165
166 defer cancel()
167 defer a.activeRequests.Del(call.SessionID)
168
169 history, files := a.preparePrompt(msgs, call.Attachments...)
170
171 var currentAssistant *message.Message
172 var shouldSummarize bool
173 result, err := agent.Stream(genCtx, fantasy.AgentStreamCall{
174 Prompt: call.Prompt,
175 Files: files,
176 Messages: history,
177 ProviderOptions: call.ProviderOptions,
178 MaxOutputTokens: &call.MaxOutputTokens,
179 TopP: call.TopP,
180 Temperature: call.Temperature,
181 PresencePenalty: call.PresencePenalty,
182 TopK: call.TopK,
183 FrequencyPenalty: call.FrequencyPenalty,
184 // Before each step create the new assistant message
185 PrepareStep: func(callContext context.Context, options fantasy.PrepareStepFunctionOptions) (_ context.Context, prepared fantasy.PrepareStepResult, err error) {
186 prepared.Messages = options.Messages
187 // reset all cached items
188 for i := range prepared.Messages {
189 prepared.Messages[i].ProviderOptions = nil
190 }
191
192 queuedCalls, _ := a.messageQueue.Get(call.SessionID)
193 a.messageQueue.Del(call.SessionID)
194 for _, queued := range queuedCalls {
195 userMessage, createErr := a.createUserMessage(callContext, queued)
196 if createErr != nil {
197 return callContext, prepared, createErr
198 }
199 prepared.Messages = append(prepared.Messages, userMessage.ToAIMessage()...)
200 }
201
202 lastSystemRoleInx := 0
203 systemMessageUpdated := false
204 for i, msg := range prepared.Messages {
205 // only add cache control to the last message
206 if msg.Role == fantasy.MessageRoleSystem {
207 lastSystemRoleInx = i
208 } else if !systemMessageUpdated {
209 prepared.Messages[lastSystemRoleInx].ProviderOptions = a.getCacheControlOptions()
210 systemMessageUpdated = true
211 }
212 // than add cache control to the last 2 messages
213 if i > len(prepared.Messages)-3 {
214 prepared.Messages[i].ProviderOptions = a.getCacheControlOptions()
215 }
216 }
217
218 var assistantMsg message.Message
219 assistantMsg, err = a.messages.Create(callContext, call.SessionID, message.CreateMessageParams{
220 Role: message.Assistant,
221 Parts: []message.ContentPart{},
222 Model: a.largeModel.ModelCfg.Model,
223 Provider: a.largeModel.ModelCfg.Provider,
224 })
225 if err != nil {
226 return callContext, prepared, err
227 }
228 callContext = context.WithValue(callContext, tools.MessageIDContextKey, assistantMsg.ID)
229 currentAssistant = &assistantMsg
230 return callContext, prepared, err
231 },
232 OnReasoningDelta: func(id string, text string) error {
233 currentAssistant.AppendReasoningContent(text)
234 return a.messages.Update(genCtx, *currentAssistant)
235 },
236 OnReasoningEnd: func(id string, reasoning fantasy.ReasoningContent) error {
237 // handle anthropic signature
238 if anthropicData, ok := reasoning.ProviderMetadata[anthropic.Name]; ok {
239 if reasoning, ok := anthropicData.(*anthropic.ReasoningOptionMetadata); ok {
240 currentAssistant.AppendReasoningSignature(reasoning.Signature)
241 }
242 }
243 if openaiData, ok := reasoning.ProviderMetadata[openai.Name]; ok {
244 if reasoning, ok := openaiData.(*openai.ResponsesReasoningMetadata); ok {
245 currentAssistant.SetReasoningResponsesData(reasoning)
246 }
247 }
248 currentAssistant.FinishThinking()
249 return a.messages.Update(genCtx, *currentAssistant)
250 },
251 OnTextDelta: func(id string, text string) error {
252 currentAssistant.AppendContent(text)
253 return a.messages.Update(genCtx, *currentAssistant)
254 },
255 OnToolInputStart: func(id string, toolName string) error {
256 toolCall := message.ToolCall{
257 ID: id,
258 Name: toolName,
259 ProviderExecuted: false,
260 Finished: false,
261 }
262 currentAssistant.AddToolCall(toolCall)
263 return a.messages.Update(genCtx, *currentAssistant)
264 },
265 OnRetry: func(err *fantasy.APICallError, delay time.Duration) {
266 // TODO: implement
267 },
268 OnToolCall: func(tc fantasy.ToolCallContent) error {
269 toolCall := message.ToolCall{
270 ID: tc.ToolCallID,
271 Name: tc.ToolName,
272 Input: tc.Input,
273 ProviderExecuted: false,
274 Finished: true,
275 }
276 currentAssistant.AddToolCall(toolCall)
277 return a.messages.Update(genCtx, *currentAssistant)
278 },
279 OnToolResult: func(result fantasy.ToolResultContent) error {
280 var resultContent string
281 isError := false
282 switch result.Result.GetType() {
283 case fantasy.ToolResultContentTypeText:
284 r, ok := fantasy.AsToolResultOutputType[fantasy.ToolResultOutputContentText](result.Result)
285 if ok {
286 resultContent = r.Text
287 }
288 case fantasy.ToolResultContentTypeError:
289 r, ok := fantasy.AsToolResultOutputType[fantasy.ToolResultOutputContentError](result.Result)
290 if ok {
291 isError = true
292 resultContent = r.Error.Error()
293 }
294 case fantasy.ToolResultContentTypeMedia:
295 // TODO: handle this message type
296 }
297 toolResult := message.ToolResult{
298 ToolCallID: result.ToolCallID,
299 Name: result.ToolName,
300 Content: resultContent,
301 IsError: isError,
302 Metadata: result.ClientMetadata,
303 }
304 _, createMsgErr := a.messages.Create(genCtx, currentAssistant.SessionID, message.CreateMessageParams{
305 Role: message.Tool,
306 Parts: []message.ContentPart{
307 toolResult,
308 },
309 })
310 if createMsgErr != nil {
311 return createMsgErr
312 }
313 return nil
314 },
315 OnStepFinish: func(stepResult fantasy.StepResult) error {
316 finishReason := message.FinishReasonUnknown
317 switch stepResult.FinishReason {
318 case fantasy.FinishReasonLength:
319 finishReason = message.FinishReasonMaxTokens
320 case fantasy.FinishReasonStop:
321 finishReason = message.FinishReasonEndTurn
322 case fantasy.FinishReasonToolCalls:
323 finishReason = message.FinishReasonToolUse
324 }
325 currentAssistant.AddFinish(finishReason, "", "")
326 a.updateSessionUsage(a.largeModel, ¤tSession, stepResult.Usage)
327 sessionLock.Lock()
328 _, sessionErr := a.sessions.Save(genCtx, currentSession)
329 sessionLock.Unlock()
330 if sessionErr != nil {
331 return sessionErr
332 }
333 return a.messages.Update(genCtx, *currentAssistant)
334 },
335 StopWhen: []fantasy.StopCondition{
336 func(_ []fantasy.StepResult) bool {
337 contextWindow := a.largeModel.CatwalkCfg.ContextWindow
338 tokens := currentSession.CompletionTokens + currentSession.PromptTokens
339 percentage := (float64(tokens) / float64(contextWindow)) * 100
340 if (percentage > 80) && !a.disableAutoSummarize {
341 shouldSummarize = true
342 return true
343 }
344 return false
345 },
346 },
347 })
348 if err != nil {
349 isCancelErr := errors.Is(err, context.Canceled)
350 isPermissionErr := errors.Is(err, permission.ErrorPermissionDenied)
351 if currentAssistant == nil {
352 return result, err
353 }
354 toolCalls := currentAssistant.ToolCalls()
355 // INFO: we use the parent context here because the genCtx has been cancelled
356 msgs, createErr := a.messages.List(ctx, currentAssistant.SessionID)
357 if createErr != nil {
358 return nil, createErr
359 }
360 for _, tc := range toolCalls {
361 if !tc.Finished {
362 tc.Finished = true
363 tc.Input = "{}"
364 currentAssistant.AddToolCall(tc)
365 updateErr := a.messages.Update(ctx, *currentAssistant)
366 if updateErr != nil {
367 return nil, updateErr
368 }
369 }
370
371 found := false
372 for _, msg := range msgs {
373 if msg.Role == message.Tool {
374 for _, tr := range msg.ToolResults() {
375 if tr.ToolCallID == tc.ID {
376 found = true
377 break
378 }
379 }
380 }
381 if found {
382 break
383 }
384 }
385 if found {
386 continue
387 }
388 content := "There was an error while executing the tool"
389 if isCancelErr {
390 content = "Tool execution canceled by user"
391 } else if isPermissionErr {
392 content = "Permission denied"
393 }
394 toolResult := message.ToolResult{
395 ToolCallID: tc.ID,
396 Name: tc.Name,
397 Content: content,
398 IsError: true,
399 }
400 _, createErr = a.messages.Create(context.Background(), currentAssistant.SessionID, message.CreateMessageParams{
401 Role: message.Tool,
402 Parts: []message.ContentPart{
403 toolResult,
404 },
405 })
406 if createErr != nil {
407 return nil, createErr
408 }
409 }
410 if isCancelErr {
411 currentAssistant.AddFinish(message.FinishReasonCanceled, "Request cancelled", "")
412 } else if isPermissionErr {
413 currentAssistant.AddFinish(message.FinishReasonPermissionDenied, "Permission denied", "")
414 } else {
415 currentAssistant.AddFinish(message.FinishReasonError, "API Error", err.Error())
416 }
417 // INFO: we use the parent context here because the genCtx has been cancelled
418 updateErr := a.messages.Update(ctx, *currentAssistant)
419 if updateErr != nil {
420 return nil, updateErr
421 }
422 return nil, err
423 }
424 wg.Wait()
425
426 if shouldSummarize {
427 a.activeRequests.Del(call.SessionID)
428 if summarizeErr := a.Summarize(genCtx, call.SessionID, call.ProviderOptions); summarizeErr != nil {
429 return nil, summarizeErr
430 }
431 }
432
433 // release active request before processing queued messages
434 a.activeRequests.Del(call.SessionID)
435 cancel()
436
437 queuedMessages, ok := a.messageQueue.Get(call.SessionID)
438 if !ok || len(queuedMessages) == 0 {
439 return result, err
440 }
441 // there are queued messages restart the loop
442 firstQueuedMessage := queuedMessages[0]
443 a.messageQueue.Set(call.SessionID, queuedMessages[1:])
444 return a.Run(ctx, firstQueuedMessage)
445}
446
447func (a *sessionAgent) Summarize(ctx context.Context, sessionID string, opts fantasy.ProviderOptions) error {
448 if a.IsSessionBusy(sessionID) {
449 return ErrSessionBusy
450 }
451
452 currentSession, err := a.sessions.Get(ctx, sessionID)
453 if err != nil {
454 return fmt.Errorf("failed to get session: %w", err)
455 }
456 msgs, err := a.getSessionMessages(ctx, currentSession)
457 if err != nil {
458 return err
459 }
460 if len(msgs) == 0 {
461 // nothing to summarize
462 return nil
463 }
464
465 aiMsgs, _ := a.preparePrompt(msgs)
466
467 genCtx, cancel := context.WithCancel(ctx)
468 a.activeRequests.Set(sessionID, cancel)
469 defer a.activeRequests.Del(sessionID)
470 defer cancel()
471
472 agent := fantasy.NewAgent(a.largeModel.Model,
473 fantasy.WithSystemPrompt(string(summaryPrompt)),
474 )
475 summaryMessage, err := a.messages.Create(ctx, sessionID, message.CreateMessageParams{
476 Role: message.Assistant,
477 Model: a.largeModel.Model.Model(),
478 Provider: a.largeModel.Model.Provider(),
479 IsSummaryMessage: true,
480 })
481 if err != nil {
482 return err
483 }
484
485 resp, err := agent.Stream(genCtx, fantasy.AgentStreamCall{
486 Prompt: "Provide a detailed summary of our conversation above.",
487 Messages: aiMsgs,
488 ProviderOptions: opts,
489 OnReasoningDelta: func(id string, text string) error {
490 summaryMessage.AppendReasoningContent(text)
491 return a.messages.Update(genCtx, summaryMessage)
492 },
493 OnReasoningEnd: func(id string, reasoning fantasy.ReasoningContent) error {
494 // handle anthropic signature
495 if anthropicData, ok := reasoning.ProviderMetadata["anthropic"]; ok {
496 if signature, ok := anthropicData.(*anthropic.ReasoningOptionMetadata); ok && signature.Signature != "" {
497 summaryMessage.AppendReasoningSignature(signature.Signature)
498 }
499 }
500 summaryMessage.FinishThinking()
501 return a.messages.Update(genCtx, summaryMessage)
502 },
503 OnTextDelta: func(id, text string) error {
504 summaryMessage.AppendContent(text)
505 return a.messages.Update(genCtx, summaryMessage)
506 },
507 })
508 if err != nil {
509 isCancelErr := errors.Is(err, context.Canceled)
510 if isCancelErr {
511 // User cancelled summarize we need to remove the summary message
512 deleteErr := a.messages.Delete(ctx, summaryMessage.ID)
513 return deleteErr
514 }
515 return err
516 }
517
518 summaryMessage.AddFinish(message.FinishReasonEndTurn, "", "")
519 err = a.messages.Update(genCtx, summaryMessage)
520 if err != nil {
521 return err
522 }
523
524 a.updateSessionUsage(a.largeModel, ¤tSession, resp.TotalUsage)
525
526 // just in case get just the last usage
527 usage := resp.Response.Usage
528 currentSession.SummaryMessageID = summaryMessage.ID
529 currentSession.CompletionTokens = usage.OutputTokens
530 currentSession.PromptTokens = 0
531 _, err = a.sessions.Save(genCtx, currentSession)
532 return err
533}
534
535func (a *sessionAgent) getCacheControlOptions() fantasy.ProviderOptions {
536 return fantasy.ProviderOptions{
537 anthropic.Name: &anthropic.ProviderCacheControlOptions{
538 CacheControl: anthropic.CacheControl{Type: "ephemeral"},
539 },
540 }
541}
542
543func (a *sessionAgent) createUserMessage(ctx context.Context, call SessionAgentCall) (message.Message, error) {
544 var attachmentParts []message.ContentPart
545 for _, attachment := range call.Attachments {
546 attachmentParts = append(attachmentParts, message.BinaryContent{Path: attachment.FilePath, MIMEType: attachment.MimeType, Data: attachment.Content})
547 }
548 parts := []message.ContentPart{message.TextContent{Text: call.Prompt}}
549 parts = append(parts, attachmentParts...)
550 msg, err := a.messages.Create(ctx, call.SessionID, message.CreateMessageParams{
551 Role: message.User,
552 Parts: parts,
553 })
554 if err != nil {
555 return message.Message{}, fmt.Errorf("failed to create user message: %w", err)
556 }
557 return msg, nil
558}
559
560func (a *sessionAgent) preparePrompt(msgs []message.Message, attachments ...message.Attachment) ([]fantasy.Message, []fantasy.FilePart) {
561 var history []fantasy.Message
562 for _, m := range msgs {
563 if len(m.Parts) == 0 {
564 continue
565 }
566 // Assistant message without content or tool calls (cancelled before it returned anything)
567 if m.Role == message.Assistant && len(m.ToolCalls()) == 0 && m.Content().Text == "" && m.ReasoningContent().String() == "" {
568 continue
569 }
570 history = append(history, m.ToAIMessage()...)
571 }
572
573 var files []fantasy.FilePart
574 for _, attachment := range attachments {
575 files = append(files, fantasy.FilePart{
576 Filename: attachment.FileName,
577 Data: attachment.Content,
578 MediaType: attachment.MimeType,
579 })
580 }
581
582 return history, files
583}
584
585func (a *sessionAgent) getSessionMessages(ctx context.Context, session session.Session) ([]message.Message, error) {
586 msgs, err := a.messages.List(ctx, session.ID)
587 if err != nil {
588 return nil, fmt.Errorf("failed to list messages: %w", err)
589 }
590
591 if session.SummaryMessageID != "" {
592 summaryMsgInex := -1
593 for i, msg := range msgs {
594 if msg.ID == session.SummaryMessageID {
595 summaryMsgInex = i
596 break
597 }
598 }
599 if summaryMsgInex != -1 {
600 msgs = msgs[summaryMsgInex:]
601 msgs[0].Role = message.User
602 }
603 }
604 return msgs, nil
605}
606
607func (a *sessionAgent) generateTitle(ctx context.Context, session *session.Session, prompt string) {
608 if prompt == "" {
609 return
610 }
611
612 var maxOutput int64 = 40
613 if a.smallModel.CatwalkCfg.CanReason {
614 maxOutput = a.smallModel.CatwalkCfg.DefaultMaxTokens
615 }
616
617 agent := fantasy.NewAgent(a.smallModel.Model,
618 fantasy.WithSystemPrompt(string(titlePrompt)+"\n /no_think"),
619 fantasy.WithMaxOutputTokens(maxOutput),
620 )
621
622 resp, err := agent.Stream(ctx, fantasy.AgentStreamCall{
623 Prompt: fmt.Sprintf("Generate a concise title for the following content:\n\n%s\n <think>\n\n</think>", prompt),
624 })
625 if err != nil {
626 slog.Error("error generating title", "err", err)
627 return
628 }
629
630 title := resp.Response.Content.Text()
631
632 title = strings.ReplaceAll(title, "\n", " ")
633
634 // remove thinking tags if present
635 if idx := strings.Index(title, "</think>"); idx > 0 {
636 title = title[idx+len("</think>"):]
637 }
638
639 title = strings.TrimSpace(title)
640 if title == "" {
641 slog.Warn("failed to generate title", "warn", "empty title")
642 return
643 }
644
645 session.Title = title
646 a.updateSessionUsage(a.smallModel, session, resp.TotalUsage)
647 _, saveErr := a.sessions.Save(ctx, *session)
648 if saveErr != nil {
649 slog.Error("failed to save session title & usage", "error", saveErr)
650 return
651 }
652}
653
654func (a *sessionAgent) updateSessionUsage(model Model, session *session.Session, usage fantasy.Usage) {
655 modelConfig := model.CatwalkCfg
656 cost := modelConfig.CostPer1MInCached/1e6*float64(usage.CacheCreationTokens) +
657 modelConfig.CostPer1MOutCached/1e6*float64(usage.CacheReadTokens) +
658 modelConfig.CostPer1MIn/1e6*float64(usage.InputTokens) +
659 modelConfig.CostPer1MOut/1e6*float64(usage.OutputTokens)
660 session.Cost += cost
661 session.CompletionTokens = usage.OutputTokens + usage.CacheReadTokens
662 session.PromptTokens = usage.InputTokens + usage.CacheCreationTokens
663}
664
665func (a *sessionAgent) Cancel(sessionID string) {
666 // Cancel regular requests
667 if cancel, ok := a.activeRequests.Take(sessionID); ok && cancel != nil {
668 slog.Info("Request cancellation initiated", "session_id", sessionID)
669 cancel()
670 }
671
672 // Also check for summarize requests
673 if cancel, ok := a.activeRequests.Take(sessionID + "-summarize"); ok && cancel != nil {
674 slog.Info("Summarize cancellation initiated", "session_id", sessionID)
675 cancel()
676 }
677
678 if a.QueuedPrompts(sessionID) > 0 {
679 slog.Info("Clearing queued prompts", "session_id", sessionID)
680 a.messageQueue.Del(sessionID)
681 }
682}
683
684func (a *sessionAgent) ClearQueue(sessionID string) {
685 if a.QueuedPrompts(sessionID) > 0 {
686 slog.Info("Clearing queued prompts", "session_id", sessionID)
687 a.messageQueue.Del(sessionID)
688 }
689}
690
691func (a *sessionAgent) CancelAll() {
692 if !a.IsBusy() {
693 return
694 }
695 for key := range a.activeRequests.Seq2() {
696 a.Cancel(key) // key is sessionID
697 }
698
699 timeout := time.After(5 * time.Second)
700 for a.IsBusy() {
701 select {
702 case <-timeout:
703 return
704 default:
705 time.Sleep(200 * time.Millisecond)
706 }
707 }
708}
709
710func (a *sessionAgent) IsBusy() bool {
711 var busy bool
712 for cancelFunc := range a.activeRequests.Seq() {
713 if cancelFunc != nil {
714 busy = true
715 break
716 }
717 }
718 return busy
719}
720
721func (a *sessionAgent) IsSessionBusy(sessionID string) bool {
722 _, busy := a.activeRequests.Get(sessionID)
723 return busy
724}
725
726func (a *sessionAgent) QueuedPrompts(sessionID string) int {
727 l, ok := a.messageQueue.Get(sessionID)
728 if !ok {
729 return 0
730 }
731 return len(l)
732}
733
734func (a *sessionAgent) SetModels(large Model, small Model) {
735 a.largeModel = large
736 a.smallModel = small
737}
738
739func (a *sessionAgent) SetTools(tools []fantasy.AgentTool) {
740 a.tools = tools
741}
742
743func (a *sessionAgent) Model() Model {
744 return a.largeModel
745}