1package agent
2
3import (
4 "context"
5 _ "embed"
6 "errors"
7 "fmt"
8 "log/slog"
9 "os"
10 "strconv"
11 "strings"
12 "sync"
13 "time"
14
15 "charm.land/fantasy"
16 "charm.land/fantasy/providers/anthropic"
17 "charm.land/fantasy/providers/bedrock"
18 "charm.land/fantasy/providers/openai"
19 "github.com/charmbracelet/catwalk/pkg/catwalk"
20 "github.com/charmbracelet/crush/internal/agent/tools"
21 "github.com/charmbracelet/crush/internal/config"
22 "github.com/charmbracelet/crush/internal/csync"
23 "github.com/charmbracelet/crush/internal/message"
24 "github.com/charmbracelet/crush/internal/permission"
25 "github.com/charmbracelet/crush/internal/session"
26)
27
// titlePrompt holds the embedded contents of templates/title.md. It is used as
// the system prompt when generating a session title.
//
//go:embed templates/title.md
var titlePrompt []byte

// summaryPrompt holds the embedded contents of templates/summary.md. It is
// used as the system prompt when summarizing a session.
//
//go:embed templates/summary.md
var summaryPrompt []byte
33
// SessionAgentCall describes a single prompt submitted to a session through
// SessionAgent.Run.
type SessionAgentCall struct {
	// SessionID is the session the prompt belongs to. Run rejects an empty
	// value with ErrSessionMissing.
	SessionID string
	// Prompt is the user text. Run rejects an empty value with ErrEmptyPrompt.
	Prompt string
	// ProviderOptions is forwarded unchanged to the model stream call.
	ProviderOptions fantasy.ProviderOptions
	// Attachments are stored as binary parts of the user message and sent to
	// the model as file parts.
	Attachments []message.Attachment
	// MaxOutputTokens is forwarded to the model stream call.
	MaxOutputTokens int64
	// Temperature is forwarded to the model stream call; nil leaves it unset.
	Temperature *float64
	// TopP is forwarded to the model stream call; nil leaves it unset.
	TopP *float64
	// TopK is forwarded to the model stream call; nil leaves it unset.
	TopK *int64
	// FrequencyPenalty is forwarded to the model stream call; nil leaves it
	// unset.
	FrequencyPenalty *float64
	// PresencePenalty is forwarded to the model stream call; nil leaves it
	// unset.
	PresencePenalty *float64
}
46
// SessionAgent runs prompts against a language model on a per-session basis.
// The method notes below describe the sessionAgent implementation in this
// file.
type SessionAgent interface {
	// Run executes the call for its session, or queues it (returning a nil
	// result and nil error) when that session already has an active request.
	Run(context.Context, SessionAgentCall) (*fantasy.AgentResult, error)
	// SetModels replaces the large and small models.
	SetModels(large Model, small Model)
	// SetTools replaces the tool set offered to the model.
	SetTools(tools []fantasy.AgentTool)
	// Cancel cancels the session's active request and drops its queued
	// prompts.
	Cancel(sessionID string)
	// CancelAll cancels every active request and waits up to five seconds
	// for the agent to become idle.
	CancelAll()
	// IsSessionBusy reports whether the session has an active request.
	IsSessionBusy(sessionID string) bool
	// IsBusy reports whether any session has an active request.
	IsBusy() bool
	// QueuedPrompts returns the number of prompts queued for the session.
	QueuedPrompts(sessionID string) int
	// ClearQueue drops the prompts queued for the session.
	ClearQueue(sessionID string)
	// Summarize replaces the session's history with a model-written summary
	// message. The string argument is the session ID.
	Summarize(context.Context, string, fantasy.ProviderOptions) error
	// Model returns the large model.
	Model() Model
}
60
// Model bundles a language model with the metadata the agent reads alongside
// it.
type Model struct {
	// Model is the language model used for streaming calls.
	Model fantasy.LanguageModel
	// CatwalkCfg supplies the cost rates, context window and reasoning
	// capability read by the agent.
	CatwalkCfg catwalk.Model
	// ModelCfg supplies the model and provider identifiers recorded on
	// assistant messages.
	ModelCfg config.SelectedModel
}
66
// sessionAgent is the SessionAgent implementation returned by
// NewSessionAgent.
type sessionAgent struct {
	// largeModel handles conversation turns and summaries.
	largeModel Model
	// smallModel handles title generation.
	smallModel Model
	// systemPrompt is the system prompt for conversation turns.
	systemPrompt string
	// tools is the tool set passed to the model on each run.
	tools []fantasy.AgentTool
	// sessions loads and saves session records.
	sessions session.Service
	// messages creates, lists, updates and deletes session messages.
	messages message.Service
	// disableAutoSummarize turns off the context-window based auto summary
	// in Run.
	disableAutoSummarize bool

	// messageQueue holds the calls that arrived while a session was busy,
	// keyed by session ID.
	messageQueue *csync.Map[string, []SessionAgentCall]
	// activeRequests holds the cancel function of each in-flight request,
	// keyed by session ID. Presence of a key marks the session as busy.
	activeRequests *csync.Map[string, context.CancelFunc]
}
79
// SessionAgentOptions carries the dependencies and settings consumed by
// NewSessionAgent.
type SessionAgentOptions struct {
	// LargeModel is the model used for conversation turns and summaries.
	LargeModel Model
	// SmallModel is the model used for title generation.
	SmallModel Model
	// SystemPrompt is the system prompt for conversation turns.
	SystemPrompt string
	// DisableAutoSummarize turns off the automatic summary that Run triggers
	// when the context window is nearly full.
	DisableAutoSummarize bool
	// Sessions is the session store.
	Sessions session.Service
	// Messages is the message store.
	Messages message.Service
	// Tools is the initial tool set offered to the model.
	Tools []fantasy.AgentTool
}
89
90func NewSessionAgent(
91 opts SessionAgentOptions,
92) SessionAgent {
93 return &sessionAgent{
94 largeModel: opts.LargeModel,
95 smallModel: opts.SmallModel,
96 systemPrompt: opts.SystemPrompt,
97 sessions: opts.Sessions,
98 messages: opts.Messages,
99 disableAutoSummarize: opts.DisableAutoSummarize,
100 tools: opts.Tools,
101 messageQueue: csync.NewMap[string, []SessionAgentCall](),
102 activeRequests: csync.NewMap[string, context.CancelFunc](),
103 }
104}
105
106func (a *sessionAgent) Run(ctx context.Context, call SessionAgentCall) (*fantasy.AgentResult, error) {
107 if call.Prompt == "" {
108 return nil, ErrEmptyPrompt
109 }
110 if call.SessionID == "" {
111 return nil, ErrSessionMissing
112 }
113
114 // Queue the message if busy
115 if a.IsSessionBusy(call.SessionID) {
116 existing, ok := a.messageQueue.Get(call.SessionID)
117 if !ok {
118 existing = []SessionAgentCall{}
119 }
120 existing = append(existing, call)
121 a.messageQueue.Set(call.SessionID, existing)
122 return nil, nil
123 }
124
125 if len(a.tools) > 0 {
126 // add anthropic caching to the last tool
127 a.tools[len(a.tools)-1].SetProviderOptions(a.getCacheControlOptions())
128 }
129
130 agent := fantasy.NewAgent(
131 a.largeModel.Model,
132 fantasy.WithSystemPrompt(a.systemPrompt),
133 fantasy.WithTools(a.tools...),
134 )
135
136 sessionLock := sync.Mutex{}
137 currentSession, err := a.sessions.Get(ctx, call.SessionID)
138 if err != nil {
139 return nil, fmt.Errorf("failed to get session: %w", err)
140 }
141
142 msgs, err := a.getSessionMessages(ctx, currentSession)
143 if err != nil {
144 return nil, fmt.Errorf("failed to get session messages: %w", err)
145 }
146
147 var wg sync.WaitGroup
148 // Generate title if first message
149 if len(msgs) == 0 {
150 wg.Go(func() {
151 sessionLock.Lock()
152 a.generateTitle(ctx, ¤tSession, call.Prompt)
153 sessionLock.Unlock()
154 })
155 }
156
157 // Add the user message to the session
158 _, err = a.createUserMessage(ctx, call)
159 if err != nil {
160 return nil, err
161 }
162
163 // add the session to the context
164 ctx = context.WithValue(ctx, tools.SessionIDContextKey, call.SessionID)
165
166 genCtx, cancel := context.WithCancel(ctx)
167 a.activeRequests.Set(call.SessionID, cancel)
168
169 defer cancel()
170 defer a.activeRequests.Del(call.SessionID)
171
172 history, files := a.preparePrompt(msgs, call.Attachments...)
173
174 var currentAssistant *message.Message
175 var shouldSummarize bool
176 result, err := agent.Stream(genCtx, fantasy.AgentStreamCall{
177 Prompt: call.Prompt,
178 Files: files,
179 Messages: history,
180 ProviderOptions: call.ProviderOptions,
181 MaxOutputTokens: &call.MaxOutputTokens,
182 TopP: call.TopP,
183 Temperature: call.Temperature,
184 PresencePenalty: call.PresencePenalty,
185 TopK: call.TopK,
186 FrequencyPenalty: call.FrequencyPenalty,
187 // Before each step create the new assistant message
188 PrepareStep: func(callContext context.Context, options fantasy.PrepareStepFunctionOptions) (_ context.Context, prepared fantasy.PrepareStepResult, err error) {
189 prepared.Messages = options.Messages
190 // reset all cached items
191 for i := range prepared.Messages {
192 prepared.Messages[i].ProviderOptions = nil
193 }
194
195 queuedCalls, _ := a.messageQueue.Get(call.SessionID)
196 a.messageQueue.Del(call.SessionID)
197 for _, queued := range queuedCalls {
198 userMessage, createErr := a.createUserMessage(callContext, queued)
199 if createErr != nil {
200 return callContext, prepared, createErr
201 }
202 prepared.Messages = append(prepared.Messages, userMessage.ToAIMessage()...)
203 }
204
205 lastSystemRoleInx := 0
206 systemMessageUpdated := false
207 for i, msg := range prepared.Messages {
208 // only add cache control to the last message
209 if msg.Role == fantasy.MessageRoleSystem {
210 lastSystemRoleInx = i
211 } else if !systemMessageUpdated {
212 prepared.Messages[lastSystemRoleInx].ProviderOptions = a.getCacheControlOptions()
213 systemMessageUpdated = true
214 }
215 // than add cache control to the last 2 messages
216 if i > len(prepared.Messages)-3 {
217 prepared.Messages[i].ProviderOptions = a.getCacheControlOptions()
218 }
219 }
220
221 var assistantMsg message.Message
222 assistantMsg, err = a.messages.Create(callContext, call.SessionID, message.CreateMessageParams{
223 Role: message.Assistant,
224 Parts: []message.ContentPart{},
225 Model: a.largeModel.ModelCfg.Model,
226 Provider: a.largeModel.ModelCfg.Provider,
227 })
228 if err != nil {
229 return callContext, prepared, err
230 }
231 callContext = context.WithValue(callContext, tools.MessageIDContextKey, assistantMsg.ID)
232 currentAssistant = &assistantMsg
233 return callContext, prepared, err
234 },
235 OnReasoningDelta: func(id string, text string) error {
236 currentAssistant.AppendReasoningContent(text)
237 return a.messages.Update(genCtx, *currentAssistant)
238 },
239 OnReasoningEnd: func(id string, reasoning fantasy.ReasoningContent) error {
240 // handle anthropic signature
241 if anthropicData, ok := reasoning.ProviderMetadata[anthropic.Name]; ok {
242 if reasoning, ok := anthropicData.(*anthropic.ReasoningOptionMetadata); ok {
243 currentAssistant.AppendReasoningSignature(reasoning.Signature)
244 }
245 }
246 if openaiData, ok := reasoning.ProviderMetadata[openai.Name]; ok {
247 if reasoning, ok := openaiData.(*openai.ResponsesReasoningMetadata); ok {
248 currentAssistant.SetReasoningResponsesData(reasoning)
249 }
250 }
251 currentAssistant.FinishThinking()
252 return a.messages.Update(genCtx, *currentAssistant)
253 },
254 OnTextDelta: func(id string, text string) error {
255 currentAssistant.AppendContent(text)
256 return a.messages.Update(genCtx, *currentAssistant)
257 },
258 OnToolInputStart: func(id string, toolName string) error {
259 toolCall := message.ToolCall{
260 ID: id,
261 Name: toolName,
262 ProviderExecuted: false,
263 Finished: false,
264 }
265 currentAssistant.AddToolCall(toolCall)
266 return a.messages.Update(genCtx, *currentAssistant)
267 },
268 OnRetry: func(err *fantasy.APICallError, delay time.Duration) {
269 // TODO: implement
270 },
271 OnToolCall: func(tc fantasy.ToolCallContent) error {
272 toolCall := message.ToolCall{
273 ID: tc.ToolCallID,
274 Name: tc.ToolName,
275 Input: tc.Input,
276 ProviderExecuted: false,
277 Finished: true,
278 }
279 currentAssistant.AddToolCall(toolCall)
280 return a.messages.Update(genCtx, *currentAssistant)
281 },
282 OnToolResult: func(result fantasy.ToolResultContent) error {
283 var resultContent string
284 isError := false
285 switch result.Result.GetType() {
286 case fantasy.ToolResultContentTypeText:
287 r, ok := fantasy.AsToolResultOutputType[fantasy.ToolResultOutputContentText](result.Result)
288 if ok {
289 resultContent = r.Text
290 }
291 case fantasy.ToolResultContentTypeError:
292 r, ok := fantasy.AsToolResultOutputType[fantasy.ToolResultOutputContentError](result.Result)
293 if ok {
294 isError = true
295 resultContent = r.Error.Error()
296 }
297 case fantasy.ToolResultContentTypeMedia:
298 // TODO: handle this message type
299 }
300 toolResult := message.ToolResult{
301 ToolCallID: result.ToolCallID,
302 Name: result.ToolName,
303 Content: resultContent,
304 IsError: isError,
305 Metadata: result.ClientMetadata,
306 }
307 _, createMsgErr := a.messages.Create(genCtx, currentAssistant.SessionID, message.CreateMessageParams{
308 Role: message.Tool,
309 Parts: []message.ContentPart{
310 toolResult,
311 },
312 })
313 if createMsgErr != nil {
314 return createMsgErr
315 }
316 return nil
317 },
318 OnStepFinish: func(stepResult fantasy.StepResult) error {
319 finishReason := message.FinishReasonUnknown
320 switch stepResult.FinishReason {
321 case fantasy.FinishReasonLength:
322 finishReason = message.FinishReasonMaxTokens
323 case fantasy.FinishReasonStop:
324 finishReason = message.FinishReasonEndTurn
325 case fantasy.FinishReasonToolCalls:
326 finishReason = message.FinishReasonToolUse
327 }
328 currentAssistant.AddFinish(finishReason, "", "")
329 a.updateSessionUsage(a.largeModel, ¤tSession, stepResult.Usage)
330 sessionLock.Lock()
331 _, sessionErr := a.sessions.Save(genCtx, currentSession)
332 sessionLock.Unlock()
333 if sessionErr != nil {
334 return sessionErr
335 }
336 return a.messages.Update(genCtx, *currentAssistant)
337 },
338 StopWhen: []fantasy.StopCondition{
339 func(_ []fantasy.StepResult) bool {
340 contextWindow := a.largeModel.CatwalkCfg.ContextWindow
341 tokens := currentSession.CompletionTokens + currentSession.PromptTokens
342 percentage := (float64(tokens) / float64(contextWindow)) * 100
343 if (percentage > 80) && !a.disableAutoSummarize {
344 shouldSummarize = true
345 return true
346 }
347 return false
348 },
349 },
350 })
351 if err != nil {
352 isCancelErr := errors.Is(err, context.Canceled)
353 isPermissionErr := errors.Is(err, permission.ErrorPermissionDenied)
354 if currentAssistant == nil {
355 return result, err
356 }
357 // Ensure we finish thinking on error to close the reasoning state
358 currentAssistant.FinishThinking()
359 toolCalls := currentAssistant.ToolCalls()
360 // INFO: we use the parent context here because the genCtx has been cancelled
361 msgs, createErr := a.messages.List(ctx, currentAssistant.SessionID)
362 if createErr != nil {
363 return nil, createErr
364 }
365 for _, tc := range toolCalls {
366 if !tc.Finished {
367 tc.Finished = true
368 tc.Input = "{}"
369 currentAssistant.AddToolCall(tc)
370 updateErr := a.messages.Update(ctx, *currentAssistant)
371 if updateErr != nil {
372 return nil, updateErr
373 }
374 }
375
376 found := false
377 for _, msg := range msgs {
378 if msg.Role == message.Tool {
379 for _, tr := range msg.ToolResults() {
380 if tr.ToolCallID == tc.ID {
381 found = true
382 break
383 }
384 }
385 }
386 if found {
387 break
388 }
389 }
390 if found {
391 continue
392 }
393 content := "There was an error while executing the tool"
394 if isCancelErr {
395 content = "Tool execution canceled by user"
396 } else if isPermissionErr {
397 content = "Permission denied"
398 }
399 toolResult := message.ToolResult{
400 ToolCallID: tc.ID,
401 Name: tc.Name,
402 Content: content,
403 IsError: true,
404 }
405 _, createErr = a.messages.Create(context.Background(), currentAssistant.SessionID, message.CreateMessageParams{
406 Role: message.Tool,
407 Parts: []message.ContentPart{
408 toolResult,
409 },
410 })
411 if createErr != nil {
412 return nil, createErr
413 }
414 }
415 if isCancelErr {
416 currentAssistant.AddFinish(message.FinishReasonCanceled, "Request cancelled", "")
417 } else if isPermissionErr {
418 currentAssistant.AddFinish(message.FinishReasonPermissionDenied, "Permission denied", "")
419 } else {
420 currentAssistant.AddFinish(message.FinishReasonError, "API Error", err.Error())
421 }
422 // INFO: we use the parent context here because the genCtx has been cancelled
423 updateErr := a.messages.Update(ctx, *currentAssistant)
424 if updateErr != nil {
425 return nil, updateErr
426 }
427 return nil, err
428 }
429 wg.Wait()
430
431 if shouldSummarize {
432 a.activeRequests.Del(call.SessionID)
433 if summarizeErr := a.Summarize(genCtx, call.SessionID, call.ProviderOptions); summarizeErr != nil {
434 return nil, summarizeErr
435 }
436 // if the agent was not done...
437 if len(currentAssistant.ToolCalls()) > 0 {
438 existing, ok := a.messageQueue.Get(call.SessionID)
439 if !ok {
440 existing = []SessionAgentCall{}
441 }
442 call.Prompt = fmt.Sprintf("The previous session was interrupted because it got too long, the initial user request was: `%s`", call.Prompt)
443 existing = append(existing, call)
444 a.messageQueue.Set(call.SessionID, existing)
445 }
446 }
447
448 // release active request before processing queued messages
449 a.activeRequests.Del(call.SessionID)
450 cancel()
451
452 queuedMessages, ok := a.messageQueue.Get(call.SessionID)
453 if !ok || len(queuedMessages) == 0 {
454 return result, err
455 }
456 // there are queued messages restart the loop
457 firstQueuedMessage := queuedMessages[0]
458 a.messageQueue.Set(call.SessionID, queuedMessages[1:])
459 return a.Run(ctx, firstQueuedMessage)
460}
461
462func (a *sessionAgent) Summarize(ctx context.Context, sessionID string, opts fantasy.ProviderOptions) error {
463 if a.IsSessionBusy(sessionID) {
464 return ErrSessionBusy
465 }
466
467 currentSession, err := a.sessions.Get(ctx, sessionID)
468 if err != nil {
469 return fmt.Errorf("failed to get session: %w", err)
470 }
471 msgs, err := a.getSessionMessages(ctx, currentSession)
472 if err != nil {
473 return err
474 }
475 if len(msgs) == 0 {
476 // nothing to summarize
477 return nil
478 }
479
480 aiMsgs, _ := a.preparePrompt(msgs)
481
482 genCtx, cancel := context.WithCancel(ctx)
483 a.activeRequests.Set(sessionID, cancel)
484 defer a.activeRequests.Del(sessionID)
485 defer cancel()
486
487 agent := fantasy.NewAgent(a.largeModel.Model,
488 fantasy.WithSystemPrompt(string(summaryPrompt)),
489 )
490 summaryMessage, err := a.messages.Create(ctx, sessionID, message.CreateMessageParams{
491 Role: message.Assistant,
492 Model: a.largeModel.Model.Model(),
493 Provider: a.largeModel.Model.Provider(),
494 IsSummaryMessage: true,
495 })
496 if err != nil {
497 return err
498 }
499
500 resp, err := agent.Stream(genCtx, fantasy.AgentStreamCall{
501 Prompt: "Provide a detailed summary of our conversation above.",
502 Messages: aiMsgs,
503 ProviderOptions: opts,
504 OnReasoningDelta: func(id string, text string) error {
505 summaryMessage.AppendReasoningContent(text)
506 return a.messages.Update(genCtx, summaryMessage)
507 },
508 OnReasoningEnd: func(id string, reasoning fantasy.ReasoningContent) error {
509 // handle anthropic signature
510 if anthropicData, ok := reasoning.ProviderMetadata["anthropic"]; ok {
511 if signature, ok := anthropicData.(*anthropic.ReasoningOptionMetadata); ok && signature.Signature != "" {
512 summaryMessage.AppendReasoningSignature(signature.Signature)
513 }
514 }
515 summaryMessage.FinishThinking()
516 return a.messages.Update(genCtx, summaryMessage)
517 },
518 OnTextDelta: func(id, text string) error {
519 summaryMessage.AppendContent(text)
520 return a.messages.Update(genCtx, summaryMessage)
521 },
522 })
523 if err != nil {
524 isCancelErr := errors.Is(err, context.Canceled)
525 if isCancelErr {
526 // User cancelled summarize we need to remove the summary message
527 deleteErr := a.messages.Delete(ctx, summaryMessage.ID)
528 return deleteErr
529 }
530 return err
531 }
532
533 summaryMessage.AddFinish(message.FinishReasonEndTurn, "", "")
534 err = a.messages.Update(genCtx, summaryMessage)
535 if err != nil {
536 return err
537 }
538
539 a.updateSessionUsage(a.largeModel, ¤tSession, resp.TotalUsage)
540
541 // just in case get just the last usage
542 usage := resp.Response.Usage
543 currentSession.SummaryMessageID = summaryMessage.ID
544 currentSession.CompletionTokens = usage.OutputTokens
545 currentSession.PromptTokens = 0
546 _, err = a.sessions.Save(genCtx, currentSession)
547 return err
548}
549
550func (a *sessionAgent) getCacheControlOptions() fantasy.ProviderOptions {
551 if t, _ := strconv.ParseBool(os.Getenv("CRUSH_DISABLE_ANTHROPIC_CACHE")); t {
552 return fantasy.ProviderOptions{}
553 }
554 return fantasy.ProviderOptions{
555 anthropic.Name: &anthropic.ProviderCacheControlOptions{
556 CacheControl: anthropic.CacheControl{Type: "ephemeral"},
557 },
558 bedrock.Name: &anthropic.ProviderCacheControlOptions{
559 CacheControl: anthropic.CacheControl{Type: "ephemeral"},
560 },
561 }
562}
563
564func (a *sessionAgent) createUserMessage(ctx context.Context, call SessionAgentCall) (message.Message, error) {
565 var attachmentParts []message.ContentPart
566 for _, attachment := range call.Attachments {
567 attachmentParts = append(attachmentParts, message.BinaryContent{Path: attachment.FilePath, MIMEType: attachment.MimeType, Data: attachment.Content})
568 }
569 parts := []message.ContentPart{message.TextContent{Text: call.Prompt}}
570 parts = append(parts, attachmentParts...)
571 msg, err := a.messages.Create(ctx, call.SessionID, message.CreateMessageParams{
572 Role: message.User,
573 Parts: parts,
574 })
575 if err != nil {
576 return message.Message{}, fmt.Errorf("failed to create user message: %w", err)
577 }
578 return msg, nil
579}
580
581func (a *sessionAgent) preparePrompt(msgs []message.Message, attachments ...message.Attachment) ([]fantasy.Message, []fantasy.FilePart) {
582 var history []fantasy.Message
583 for _, m := range msgs {
584 if len(m.Parts) == 0 {
585 continue
586 }
587 // Assistant message without content or tool calls (cancelled before it returned anything)
588 if m.Role == message.Assistant && len(m.ToolCalls()) == 0 && m.Content().Text == "" && m.ReasoningContent().String() == "" {
589 continue
590 }
591 history = append(history, m.ToAIMessage()...)
592 }
593
594 var files []fantasy.FilePart
595 for _, attachment := range attachments {
596 files = append(files, fantasy.FilePart{
597 Filename: attachment.FileName,
598 Data: attachment.Content,
599 MediaType: attachment.MimeType,
600 })
601 }
602
603 return history, files
604}
605
606func (a *sessionAgent) getSessionMessages(ctx context.Context, session session.Session) ([]message.Message, error) {
607 msgs, err := a.messages.List(ctx, session.ID)
608 if err != nil {
609 return nil, fmt.Errorf("failed to list messages: %w", err)
610 }
611
612 if session.SummaryMessageID != "" {
613 summaryMsgInex := -1
614 for i, msg := range msgs {
615 if msg.ID == session.SummaryMessageID {
616 summaryMsgInex = i
617 break
618 }
619 }
620 if summaryMsgInex != -1 {
621 msgs = msgs[summaryMsgInex:]
622 msgs[0].Role = message.User
623 }
624 }
625 return msgs, nil
626}
627
628func (a *sessionAgent) generateTitle(ctx context.Context, session *session.Session, prompt string) {
629 if prompt == "" {
630 return
631 }
632
633 var maxOutput int64 = 40
634 if a.smallModel.CatwalkCfg.CanReason {
635 maxOutput = a.smallModel.CatwalkCfg.DefaultMaxTokens
636 }
637
638 agent := fantasy.NewAgent(a.smallModel.Model,
639 fantasy.WithSystemPrompt(string(titlePrompt)+"\n /no_think"),
640 fantasy.WithMaxOutputTokens(maxOutput),
641 )
642
643 resp, err := agent.Stream(ctx, fantasy.AgentStreamCall{
644 Prompt: fmt.Sprintf("Generate a concise title for the following content:\n\n%s\n <think>\n\n</think>", prompt),
645 })
646 if err != nil {
647 slog.Error("error generating title", "err", err)
648 return
649 }
650
651 title := resp.Response.Content.Text()
652
653 title = strings.ReplaceAll(title, "\n", " ")
654
655 // remove thinking tags if present
656 if idx := strings.Index(title, "</think>"); idx > 0 {
657 title = title[idx+len("</think>"):]
658 }
659
660 title = strings.TrimSpace(title)
661 if title == "" {
662 slog.Warn("failed to generate title", "warn", "empty title")
663 return
664 }
665
666 session.Title = title
667 a.updateSessionUsage(a.smallModel, session, resp.TotalUsage)
668 _, saveErr := a.sessions.Save(ctx, *session)
669 if saveErr != nil {
670 slog.Error("failed to save session title & usage", "error", saveErr)
671 return
672 }
673}
674
675func (a *sessionAgent) updateSessionUsage(model Model, session *session.Session, usage fantasy.Usage) {
676 modelConfig := model.CatwalkCfg
677 cost := modelConfig.CostPer1MInCached/1e6*float64(usage.CacheCreationTokens) +
678 modelConfig.CostPer1MOutCached/1e6*float64(usage.CacheReadTokens) +
679 modelConfig.CostPer1MIn/1e6*float64(usage.InputTokens) +
680 modelConfig.CostPer1MOut/1e6*float64(usage.OutputTokens)
681 session.Cost += cost
682 session.CompletionTokens = usage.OutputTokens + usage.CacheReadTokens
683 session.PromptTokens = usage.InputTokens + usage.CacheCreationTokens
684}
685
686func (a *sessionAgent) Cancel(sessionID string) {
687 // Cancel regular requests
688 if cancel, ok := a.activeRequests.Take(sessionID); ok && cancel != nil {
689 slog.Info("Request cancellation initiated", "session_id", sessionID)
690 cancel()
691 }
692
693 // Also check for summarize requests
694 if cancel, ok := a.activeRequests.Take(sessionID + "-summarize"); ok && cancel != nil {
695 slog.Info("Summarize cancellation initiated", "session_id", sessionID)
696 cancel()
697 }
698
699 if a.QueuedPrompts(sessionID) > 0 {
700 slog.Info("Clearing queued prompts", "session_id", sessionID)
701 a.messageQueue.Del(sessionID)
702 }
703}
704
705func (a *sessionAgent) ClearQueue(sessionID string) {
706 if a.QueuedPrompts(sessionID) > 0 {
707 slog.Info("Clearing queued prompts", "session_id", sessionID)
708 a.messageQueue.Del(sessionID)
709 }
710}
711
712func (a *sessionAgent) CancelAll() {
713 if !a.IsBusy() {
714 return
715 }
716 for key := range a.activeRequests.Seq2() {
717 a.Cancel(key) // key is sessionID
718 }
719
720 timeout := time.After(5 * time.Second)
721 for a.IsBusy() {
722 select {
723 case <-timeout:
724 return
725 default:
726 time.Sleep(200 * time.Millisecond)
727 }
728 }
729}
730
731func (a *sessionAgent) IsBusy() bool {
732 var busy bool
733 for cancelFunc := range a.activeRequests.Seq() {
734 if cancelFunc != nil {
735 busy = true
736 break
737 }
738 }
739 return busy
740}
741
742func (a *sessionAgent) IsSessionBusy(sessionID string) bool {
743 _, busy := a.activeRequests.Get(sessionID)
744 return busy
745}
746
747func (a *sessionAgent) QueuedPrompts(sessionID string) int {
748 l, ok := a.messageQueue.Get(sessionID)
749 if !ok {
750 return 0
751 }
752 return len(l)
753}
754
// SetModels replaces the large and small models used by the agent. The
// assignment is not synchronized with other methods.
func (a *sessionAgent) SetModels(large Model, small Model) {
	a.largeModel = large
	a.smallModel = small
}
759
// SetTools replaces the agent's tool set. The slice is stored as-is, not
// copied.
func (a *sessionAgent) SetTools(tools []fantasy.AgentTool) {
	a.tools = tools
}
763
// Model returns the agent's large model.
func (a *sessionAgent) Model() Model {
	return a.largeModel
}