1package agent
2
3import (
4 "context"
5 _ "embed"
6 "errors"
7 "fmt"
8 "log/slog"
9 "os"
10 "strconv"
11 "strings"
12 "sync"
13 "time"
14
15 "charm.land/fantasy"
16 "charm.land/fantasy/providers/anthropic"
17 "charm.land/fantasy/providers/bedrock"
18 "charm.land/fantasy/providers/google"
19 "charm.land/fantasy/providers/openai"
20 "github.com/charmbracelet/catwalk/pkg/catwalk"
21 "github.com/charmbracelet/crush/internal/agent/tools"
22 "github.com/charmbracelet/crush/internal/config"
23 "github.com/charmbracelet/crush/internal/csync"
24 "github.com/charmbracelet/crush/internal/message"
25 "github.com/charmbracelet/crush/internal/permission"
26 "github.com/charmbracelet/crush/internal/session"
27)
28
// titlePrompt holds the embedded contents of templates/title.md. It is used
// as the system prompt when generating a session title.
//
//go:embed templates/title.md
var titlePrompt []byte

// summaryPrompt holds the embedded contents of templates/summary.md. It is
// used as the system prompt when summarizing a session.
//
//go:embed templates/summary.md
var summaryPrompt []byte
34
// SessionAgentCall describes a single prompt submitted to a SessionAgent for
// a given session, together with the generation settings that are forwarded
// to the model's stream call.
type SessionAgentCall struct {
	// SessionID identifies the session the prompt belongs to. Run rejects an
	// empty value with ErrSessionMissing.
	SessionID string
	// Prompt is the user text. Run rejects an empty value with ErrEmptyPrompt.
	Prompt string
	// ProviderOptions is passed through unchanged to the stream call (and to
	// Summarize when auto-summarization kicks in).
	ProviderOptions fantasy.ProviderOptions
	// Attachments are stored as binary parts of the user message and sent to
	// the model as files.
	Attachments []message.Attachment
	// MaxOutputTokens is forwarded (by pointer) as the stream call's output
	// token limit.
	MaxOutputTokens int64
	// The remaining fields are optional sampling settings forwarded as-is to
	// the stream call; nil pointers are passed through as nil.
	Temperature *float64
	TopP *float64
	TopK *int64
	FrequencyPenalty *float64
	PresencePenalty *float64
}
47
// SessionAgent runs prompts against a language model on a per-session basis.
// The notes on each method describe what the sessionAgent implementation in
// this file does.
type SessionAgent interface {
	// Run sends the call's prompt to the large model. If the session already
	// has an active request the call is queued and Run returns (nil, nil).
	Run(context.Context, SessionAgentCall) (*fantasy.AgentResult, error)
	// SetModels replaces the large and small models.
	SetModels(large Model, small Model)
	// SetTools replaces the set of tools offered to the model.
	SetTools(tools []fantasy.AgentTool)
	// Cancel cancels the active request of the session, if any, and drops the
	// session's queued prompts.
	Cancel(sessionID string)
	// CancelAll cancels every active request and waits (up to five seconds)
	// for the agent to stop being busy.
	CancelAll()
	// IsSessionBusy reports whether the session has an active request.
	IsSessionBusy(sessionID string) bool
	// IsBusy reports whether any session has an active request.
	IsBusy() bool
	// QueuedPrompts returns the number of calls queued for the session.
	QueuedPrompts(sessionID string) int
	// ClearQueue drops the calls queued for the session.
	ClearQueue(sessionID string)
	// Summarize asks the large model for a summary of the session and records
	// it as the session's summary message. It returns ErrSessionBusy when the
	// session has an active request.
	Summarize(context.Context, string, fantasy.ProviderOptions) error
	// Model returns the large model.
	Model() Model
}
61
// Model bundles a language model with the catalog and configuration metadata
// the agent needs alongside it.
type Model struct {
	// Model is the language model used to create fantasy agents.
	Model fantasy.LanguageModel
	// CatwalkCfg supplies the context window, per-token costs, reasoning
	// capability and default max tokens read by the agent.
	CatwalkCfg catwalk.Model
	// ModelCfg supplies the model and provider names recorded on assistant
	// messages created by Run.
	ModelCfg config.SelectedModel
}
67
// sessionAgent is the SessionAgent implementation returned by
// NewSessionAgent.
type sessionAgent struct {
	// largeModel handles prompts and summaries; smallModel generates titles.
	largeModel Model
	smallModel Model
	// systemPromptPrefix, when non-empty, is prepended as its own system
	// message before every step.
	systemPromptPrefix string
	// systemPrompt is the system prompt of the main agent.
	systemPrompt string
	// tools are offered to the model on every Run.
	tools []fantasy.AgentTool
	// sessions and messages persist session state and conversation messages.
	sessions session.Service
	messages message.Service
	// disableAutoSummarize turns off the context-window based summarization
	// in Run.
	disableAutoSummarize bool
	// isYolo is not read by any code visible in this part of the file.
	// NOTE(review): presumably consulted elsewhere in the package — confirm.
	isYolo bool

	// messageQueue holds, per session ID, calls received while that session
	// was busy.
	messageQueue *csync.Map[string, []SessionAgentCall]
	// activeRequests maps a session ID to the cancel function of its running
	// request; presence of a key is what marks the session as busy.
	activeRequests *csync.Map[string, context.CancelFunc]
}
82
// SessionAgentOptions carries the dependencies and settings NewSessionAgent
// copies into a new agent. Each field maps one-to-one onto the sessionAgent
// field of the same (lower-cased) name.
type SessionAgentOptions struct {
	// LargeModel handles prompts and summaries; SmallModel generates titles.
	LargeModel Model
	SmallModel Model
	// SystemPromptPrefix, when non-empty, is prepended as a separate system
	// message before every step.
	SystemPromptPrefix string
	// SystemPrompt is the system prompt of the main agent.
	SystemPrompt string
	// DisableAutoSummarize turns off context-window based summarization.
	DisableAutoSummarize bool
	// IsYolo is stored on the agent; its effect is not visible in this part
	// of the file. NOTE(review): confirm where it is consumed.
	IsYolo bool
	// Sessions and Messages are the persistence services used by the agent.
	Sessions session.Service
	Messages message.Service
	// Tools are offered to the model on every Run.
	Tools []fantasy.AgentTool
}
94
95func NewSessionAgent(
96 opts SessionAgentOptions,
97) SessionAgent {
98 return &sessionAgent{
99 largeModel: opts.LargeModel,
100 smallModel: opts.SmallModel,
101 systemPromptPrefix: opts.SystemPromptPrefix,
102 systemPrompt: opts.SystemPrompt,
103 sessions: opts.Sessions,
104 messages: opts.Messages,
105 disableAutoSummarize: opts.DisableAutoSummarize,
106 tools: opts.Tools,
107 isYolo: opts.IsYolo,
108 messageQueue: csync.NewMap[string, []SessionAgentCall](),
109 activeRequests: csync.NewMap[string, context.CancelFunc](),
110 }
111}
112
113func (a *sessionAgent) Run(ctx context.Context, call SessionAgentCall) (*fantasy.AgentResult, error) {
114 if call.Prompt == "" {
115 return nil, ErrEmptyPrompt
116 }
117 if call.SessionID == "" {
118 return nil, ErrSessionMissing
119 }
120
121 // Queue the message if busy
122 if a.IsSessionBusy(call.SessionID) {
123 existing, ok := a.messageQueue.Get(call.SessionID)
124 if !ok {
125 existing = []SessionAgentCall{}
126 }
127 existing = append(existing, call)
128 a.messageQueue.Set(call.SessionID, existing)
129 return nil, nil
130 }
131
132 if len(a.tools) > 0 {
133 // add anthropic caching to the last tool
134 a.tools[len(a.tools)-1].SetProviderOptions(a.getCacheControlOptions())
135 }
136
137 agent := fantasy.NewAgent(
138 a.largeModel.Model,
139 fantasy.WithSystemPrompt(a.systemPrompt),
140 fantasy.WithTools(a.tools...),
141 )
142
143 sessionLock := sync.Mutex{}
144 currentSession, err := a.sessions.Get(ctx, call.SessionID)
145 if err != nil {
146 return nil, fmt.Errorf("failed to get session: %w", err)
147 }
148
149 msgs, err := a.getSessionMessages(ctx, currentSession)
150 if err != nil {
151 return nil, fmt.Errorf("failed to get session messages: %w", err)
152 }
153
154 var wg sync.WaitGroup
155 // Generate title if first message
156 if len(msgs) == 0 {
157 wg.Go(func() {
158 sessionLock.Lock()
159 a.generateTitle(ctx, ¤tSession, call.Prompt)
160 sessionLock.Unlock()
161 })
162 }
163
164 // Add the user message to the session
165 _, err = a.createUserMessage(ctx, call)
166 if err != nil {
167 return nil, err
168 }
169
170 // add the session to the context
171 ctx = context.WithValue(ctx, tools.SessionIDContextKey, call.SessionID)
172
173 genCtx, cancel := context.WithCancel(ctx)
174 a.activeRequests.Set(call.SessionID, cancel)
175
176 defer cancel()
177 defer a.activeRequests.Del(call.SessionID)
178
179 history, files := a.preparePrompt(msgs, call.Attachments...)
180
181 startTime := time.Now()
182 a.eventPromptSent(call.SessionID)
183
184 var currentAssistant *message.Message
185 var shouldSummarize bool
186 result, err := agent.Stream(genCtx, fantasy.AgentStreamCall{
187 Prompt: call.Prompt,
188 Files: files,
189 Messages: history,
190 ProviderOptions: call.ProviderOptions,
191 MaxOutputTokens: &call.MaxOutputTokens,
192 TopP: call.TopP,
193 Temperature: call.Temperature,
194 PresencePenalty: call.PresencePenalty,
195 TopK: call.TopK,
196 FrequencyPenalty: call.FrequencyPenalty,
197 // Before each step create the new assistant message
198 PrepareStep: func(callContext context.Context, options fantasy.PrepareStepFunctionOptions) (_ context.Context, prepared fantasy.PrepareStepResult, err error) {
199 prepared.Messages = options.Messages
200 // reset all cached items
201 for i := range prepared.Messages {
202 prepared.Messages[i].ProviderOptions = nil
203 }
204
205 queuedCalls, _ := a.messageQueue.Get(call.SessionID)
206 a.messageQueue.Del(call.SessionID)
207 for _, queued := range queuedCalls {
208 userMessage, createErr := a.createUserMessage(callContext, queued)
209 if createErr != nil {
210 return callContext, prepared, createErr
211 }
212 prepared.Messages = append(prepared.Messages, userMessage.ToAIMessage()...)
213 }
214
215 lastSystemRoleInx := 0
216 systemMessageUpdated := false
217 for i, msg := range prepared.Messages {
218 // only add cache control to the last message
219 if msg.Role == fantasy.MessageRoleSystem {
220 lastSystemRoleInx = i
221 } else if !systemMessageUpdated {
222 prepared.Messages[lastSystemRoleInx].ProviderOptions = a.getCacheControlOptions()
223 systemMessageUpdated = true
224 }
225 // than add cache control to the last 2 messages
226 if i > len(prepared.Messages)-3 {
227 prepared.Messages[i].ProviderOptions = a.getCacheControlOptions()
228 }
229 }
230
231 if a.systemPromptPrefix != "" {
232 prepared.Messages = append([]fantasy.Message{fantasy.NewSystemMessage(a.systemPromptPrefix)}, prepared.Messages...)
233 }
234
235 var assistantMsg message.Message
236 assistantMsg, err = a.messages.Create(callContext, call.SessionID, message.CreateMessageParams{
237 Role: message.Assistant,
238 Parts: []message.ContentPart{},
239 Model: a.largeModel.ModelCfg.Model,
240 Provider: a.largeModel.ModelCfg.Provider,
241 })
242 if err != nil {
243 return callContext, prepared, err
244 }
245 callContext = context.WithValue(callContext, tools.MessageIDContextKey, assistantMsg.ID)
246 currentAssistant = &assistantMsg
247 return callContext, prepared, err
248 },
249 OnReasoningStart: func(id string, reasoning fantasy.ReasoningContent) error {
250 currentAssistant.AppendReasoningContent(reasoning.Text)
251 return a.messages.Update(genCtx, *currentAssistant)
252 },
253 OnReasoningDelta: func(id string, text string) error {
254 currentAssistant.AppendReasoningContent(text)
255 return a.messages.Update(genCtx, *currentAssistant)
256 },
257 OnReasoningEnd: func(id string, reasoning fantasy.ReasoningContent) error {
258 // handle anthropic signature
259 if anthropicData, ok := reasoning.ProviderMetadata[anthropic.Name]; ok {
260 if reasoning, ok := anthropicData.(*anthropic.ReasoningOptionMetadata); ok {
261 currentAssistant.AppendReasoningSignature(reasoning.Signature)
262 }
263 }
264 if googleData, ok := reasoning.ProviderMetadata[google.Name]; ok {
265 if reasoning, ok := googleData.(*google.ReasoningMetadata); ok {
266 currentAssistant.AppendReasoningSignature(reasoning.Signature)
267 }
268 }
269 if openaiData, ok := reasoning.ProviderMetadata[openai.Name]; ok {
270 if reasoning, ok := openaiData.(*openai.ResponsesReasoningMetadata); ok {
271 currentAssistant.SetReasoningResponsesData(reasoning)
272 }
273 }
274 currentAssistant.FinishThinking()
275 return a.messages.Update(genCtx, *currentAssistant)
276 },
277 OnTextDelta: func(id string, text string) error {
278 currentAssistant.AppendContent(text)
279 return a.messages.Update(genCtx, *currentAssistant)
280 },
281 OnToolInputStart: func(id string, toolName string) error {
282 toolCall := message.ToolCall{
283 ID: id,
284 Name: toolName,
285 ProviderExecuted: false,
286 Finished: false,
287 }
288 currentAssistant.AddToolCall(toolCall)
289 return a.messages.Update(genCtx, *currentAssistant)
290 },
291 OnRetry: func(err *fantasy.APICallError, delay time.Duration) {
292 // TODO: implement
293 },
294 OnToolCall: func(tc fantasy.ToolCallContent) error {
295 toolCall := message.ToolCall{
296 ID: tc.ToolCallID,
297 Name: tc.ToolName,
298 Input: tc.Input,
299 ProviderExecuted: false,
300 Finished: true,
301 }
302 currentAssistant.AddToolCall(toolCall)
303 return a.messages.Update(genCtx, *currentAssistant)
304 },
305 OnToolResult: func(result fantasy.ToolResultContent) error {
306 var resultContent string
307 isError := false
308 switch result.Result.GetType() {
309 case fantasy.ToolResultContentTypeText:
310 r, ok := fantasy.AsToolResultOutputType[fantasy.ToolResultOutputContentText](result.Result)
311 if ok {
312 resultContent = r.Text
313 }
314 case fantasy.ToolResultContentTypeError:
315 r, ok := fantasy.AsToolResultOutputType[fantasy.ToolResultOutputContentError](result.Result)
316 if ok {
317 isError = true
318 resultContent = r.Error.Error()
319 }
320 case fantasy.ToolResultContentTypeMedia:
321 // TODO: handle this message type
322 }
323 toolResult := message.ToolResult{
324 ToolCallID: result.ToolCallID,
325 Name: result.ToolName,
326 Content: resultContent,
327 IsError: isError,
328 Metadata: result.ClientMetadata,
329 }
330 _, createMsgErr := a.messages.Create(genCtx, currentAssistant.SessionID, message.CreateMessageParams{
331 Role: message.Tool,
332 Parts: []message.ContentPart{
333 toolResult,
334 },
335 })
336 if createMsgErr != nil {
337 return createMsgErr
338 }
339 return nil
340 },
341 OnStepFinish: func(stepResult fantasy.StepResult) error {
342 finishReason := message.FinishReasonUnknown
343 switch stepResult.FinishReason {
344 case fantasy.FinishReasonLength:
345 finishReason = message.FinishReasonMaxTokens
346 case fantasy.FinishReasonStop:
347 finishReason = message.FinishReasonEndTurn
348 case fantasy.FinishReasonToolCalls:
349 finishReason = message.FinishReasonToolUse
350 }
351 currentAssistant.AddFinish(finishReason, "", "")
352 a.updateSessionUsage(a.largeModel, ¤tSession, stepResult.Usage)
353 sessionLock.Lock()
354 _, sessionErr := a.sessions.Save(genCtx, currentSession)
355 sessionLock.Unlock()
356 if sessionErr != nil {
357 return sessionErr
358 }
359 return a.messages.Update(genCtx, *currentAssistant)
360 },
361 StopWhen: []fantasy.StopCondition{
362 func(_ []fantasy.StepResult) bool {
363 contextWindow := a.largeModel.CatwalkCfg.ContextWindow
364 tokens := currentSession.CompletionTokens + currentSession.PromptTokens
365 percentage := (float64(tokens) / float64(contextWindow)) * 100
366 if (percentage > 80) && !a.disableAutoSummarize {
367 shouldSummarize = true
368 return true
369 }
370 return false
371 },
372 },
373 })
374
375 a.eventPromptResponded(call.SessionID, time.Since(startTime).Truncate(time.Second))
376
377 if err != nil {
378 isCancelErr := errors.Is(err, context.Canceled)
379 isPermissionErr := errors.Is(err, permission.ErrorPermissionDenied)
380 if currentAssistant == nil {
381 return result, err
382 }
383 // Ensure we finish thinking on error to close the reasoning state
384 currentAssistant.FinishThinking()
385 toolCalls := currentAssistant.ToolCalls()
386 // INFO: we use the parent context here because the genCtx has been cancelled
387 msgs, createErr := a.messages.List(ctx, currentAssistant.SessionID)
388 if createErr != nil {
389 return nil, createErr
390 }
391 for _, tc := range toolCalls {
392 if !tc.Finished {
393 tc.Finished = true
394 tc.Input = "{}"
395 currentAssistant.AddToolCall(tc)
396 updateErr := a.messages.Update(ctx, *currentAssistant)
397 if updateErr != nil {
398 return nil, updateErr
399 }
400 }
401
402 found := false
403 for _, msg := range msgs {
404 if msg.Role == message.Tool {
405 for _, tr := range msg.ToolResults() {
406 if tr.ToolCallID == tc.ID {
407 found = true
408 break
409 }
410 }
411 }
412 if found {
413 break
414 }
415 }
416 if found {
417 continue
418 }
419 content := "There was an error while executing the tool"
420 if isCancelErr {
421 content = "Tool execution canceled by user"
422 } else if isPermissionErr {
423 content = "Permission denied"
424 }
425 toolResult := message.ToolResult{
426 ToolCallID: tc.ID,
427 Name: tc.Name,
428 Content: content,
429 IsError: true,
430 }
431 _, createErr = a.messages.Create(context.Background(), currentAssistant.SessionID, message.CreateMessageParams{
432 Role: message.Tool,
433 Parts: []message.ContentPart{
434 toolResult,
435 },
436 })
437 if createErr != nil {
438 return nil, createErr
439 }
440 }
441 if isCancelErr {
442 currentAssistant.AddFinish(message.FinishReasonCanceled, "Request cancelled", "")
443 } else if isPermissionErr {
444 currentAssistant.AddFinish(message.FinishReasonPermissionDenied, "Permission denied", "")
445 } else {
446 currentAssistant.AddFinish(message.FinishReasonError, "API Error", err.Error())
447 }
448 // INFO: we use the parent context here because the genCtx has been cancelled
449 updateErr := a.messages.Update(ctx, *currentAssistant)
450 if updateErr != nil {
451 return nil, updateErr
452 }
453 return nil, err
454 }
455 wg.Wait()
456
457 if shouldSummarize {
458 a.activeRequests.Del(call.SessionID)
459 if summarizeErr := a.Summarize(genCtx, call.SessionID, call.ProviderOptions); summarizeErr != nil {
460 return nil, summarizeErr
461 }
462 // if the agent was not done...
463 if len(currentAssistant.ToolCalls()) > 0 {
464 existing, ok := a.messageQueue.Get(call.SessionID)
465 if !ok {
466 existing = []SessionAgentCall{}
467 }
468 call.Prompt = fmt.Sprintf("The previous session was interrupted because it got too long, the initial user request was: `%s`", call.Prompt)
469 existing = append(existing, call)
470 a.messageQueue.Set(call.SessionID, existing)
471 }
472 }
473
474 // release active request before processing queued messages
475 a.activeRequests.Del(call.SessionID)
476 cancel()
477
478 queuedMessages, ok := a.messageQueue.Get(call.SessionID)
479 if !ok || len(queuedMessages) == 0 {
480 return result, err
481 }
482 // there are queued messages restart the loop
483 firstQueuedMessage := queuedMessages[0]
484 a.messageQueue.Set(call.SessionID, queuedMessages[1:])
485 return a.Run(ctx, firstQueuedMessage)
486}
487
488func (a *sessionAgent) Summarize(ctx context.Context, sessionID string, opts fantasy.ProviderOptions) error {
489 if a.IsSessionBusy(sessionID) {
490 return ErrSessionBusy
491 }
492
493 currentSession, err := a.sessions.Get(ctx, sessionID)
494 if err != nil {
495 return fmt.Errorf("failed to get session: %w", err)
496 }
497 msgs, err := a.getSessionMessages(ctx, currentSession)
498 if err != nil {
499 return err
500 }
501 if len(msgs) == 0 {
502 // nothing to summarize
503 return nil
504 }
505
506 aiMsgs, _ := a.preparePrompt(msgs)
507
508 genCtx, cancel := context.WithCancel(ctx)
509 a.activeRequests.Set(sessionID, cancel)
510 defer a.activeRequests.Del(sessionID)
511 defer cancel()
512
513 agent := fantasy.NewAgent(a.largeModel.Model,
514 fantasy.WithSystemPrompt(string(summaryPrompt)),
515 )
516 summaryMessage, err := a.messages.Create(ctx, sessionID, message.CreateMessageParams{
517 Role: message.Assistant,
518 Model: a.largeModel.Model.Model(),
519 Provider: a.largeModel.Model.Provider(),
520 IsSummaryMessage: true,
521 })
522 if err != nil {
523 return err
524 }
525
526 resp, err := agent.Stream(genCtx, fantasy.AgentStreamCall{
527 Prompt: "Provide a detailed summary of our conversation above.",
528 Messages: aiMsgs,
529 ProviderOptions: opts,
530 OnReasoningDelta: func(id string, text string) error {
531 summaryMessage.AppendReasoningContent(text)
532 return a.messages.Update(genCtx, summaryMessage)
533 },
534 OnReasoningEnd: func(id string, reasoning fantasy.ReasoningContent) error {
535 // handle anthropic signature
536 if anthropicData, ok := reasoning.ProviderMetadata["anthropic"]; ok {
537 if signature, ok := anthropicData.(*anthropic.ReasoningOptionMetadata); ok && signature.Signature != "" {
538 summaryMessage.AppendReasoningSignature(signature.Signature)
539 }
540 }
541 summaryMessage.FinishThinking()
542 return a.messages.Update(genCtx, summaryMessage)
543 },
544 OnTextDelta: func(id, text string) error {
545 summaryMessage.AppendContent(text)
546 return a.messages.Update(genCtx, summaryMessage)
547 },
548 })
549 if err != nil {
550 isCancelErr := errors.Is(err, context.Canceled)
551 if isCancelErr {
552 // User cancelled summarize we need to remove the summary message
553 deleteErr := a.messages.Delete(ctx, summaryMessage.ID)
554 return deleteErr
555 }
556 return err
557 }
558
559 summaryMessage.AddFinish(message.FinishReasonEndTurn, "", "")
560 err = a.messages.Update(genCtx, summaryMessage)
561 if err != nil {
562 return err
563 }
564
565 a.updateSessionUsage(a.largeModel, ¤tSession, resp.TotalUsage)
566
567 // just in case get just the last usage
568 usage := resp.Response.Usage
569 currentSession.SummaryMessageID = summaryMessage.ID
570 currentSession.CompletionTokens = usage.OutputTokens
571 currentSession.PromptTokens = 0
572 _, err = a.sessions.Save(genCtx, currentSession)
573 return err
574}
575
576func (a *sessionAgent) getCacheControlOptions() fantasy.ProviderOptions {
577 if t, _ := strconv.ParseBool(os.Getenv("CRUSH_DISABLE_ANTHROPIC_CACHE")); t {
578 return fantasy.ProviderOptions{}
579 }
580 return fantasy.ProviderOptions{
581 anthropic.Name: &anthropic.ProviderCacheControlOptions{
582 CacheControl: anthropic.CacheControl{Type: "ephemeral"},
583 },
584 bedrock.Name: &anthropic.ProviderCacheControlOptions{
585 CacheControl: anthropic.CacheControl{Type: "ephemeral"},
586 },
587 }
588}
589
590func (a *sessionAgent) createUserMessage(ctx context.Context, call SessionAgentCall) (message.Message, error) {
591 var attachmentParts []message.ContentPart
592 for _, attachment := range call.Attachments {
593 attachmentParts = append(attachmentParts, message.BinaryContent{Path: attachment.FilePath, MIMEType: attachment.MimeType, Data: attachment.Content})
594 }
595 parts := []message.ContentPart{message.TextContent{Text: call.Prompt}}
596 parts = append(parts, attachmentParts...)
597 msg, err := a.messages.Create(ctx, call.SessionID, message.CreateMessageParams{
598 Role: message.User,
599 Parts: parts,
600 })
601 if err != nil {
602 return message.Message{}, fmt.Errorf("failed to create user message: %w", err)
603 }
604 return msg, nil
605}
606
607func (a *sessionAgent) preparePrompt(msgs []message.Message, attachments ...message.Attachment) ([]fantasy.Message, []fantasy.FilePart) {
608 var history []fantasy.Message
609 for _, m := range msgs {
610 if len(m.Parts) == 0 {
611 continue
612 }
613 // Assistant message without content or tool calls (cancelled before it returned anything)
614 if m.Role == message.Assistant && len(m.ToolCalls()) == 0 && m.Content().Text == "" && m.ReasoningContent().String() == "" {
615 continue
616 }
617 history = append(history, m.ToAIMessage()...)
618 }
619
620 var files []fantasy.FilePart
621 for _, attachment := range attachments {
622 files = append(files, fantasy.FilePart{
623 Filename: attachment.FileName,
624 Data: attachment.Content,
625 MediaType: attachment.MimeType,
626 })
627 }
628
629 return history, files
630}
631
632func (a *sessionAgent) getSessionMessages(ctx context.Context, session session.Session) ([]message.Message, error) {
633 msgs, err := a.messages.List(ctx, session.ID)
634 if err != nil {
635 return nil, fmt.Errorf("failed to list messages: %w", err)
636 }
637
638 if session.SummaryMessageID != "" {
639 summaryMsgInex := -1
640 for i, msg := range msgs {
641 if msg.ID == session.SummaryMessageID {
642 summaryMsgInex = i
643 break
644 }
645 }
646 if summaryMsgInex != -1 {
647 msgs = msgs[summaryMsgInex:]
648 msgs[0].Role = message.User
649 }
650 }
651 return msgs, nil
652}
653
654func (a *sessionAgent) generateTitle(ctx context.Context, session *session.Session, prompt string) {
655 if prompt == "" {
656 return
657 }
658
659 var maxOutput int64 = 40
660 if a.smallModel.CatwalkCfg.CanReason {
661 maxOutput = a.smallModel.CatwalkCfg.DefaultMaxTokens
662 }
663
664 agent := fantasy.NewAgent(a.smallModel.Model,
665 fantasy.WithSystemPrompt(string(titlePrompt)+"\n /no_think"),
666 fantasy.WithMaxOutputTokens(maxOutput),
667 )
668
669 resp, err := agent.Stream(ctx, fantasy.AgentStreamCall{
670 Prompt: fmt.Sprintf("Generate a concise title for the following content:\n\n%s\n <think>\n\n</think>", prompt),
671 })
672 if err != nil {
673 slog.Error("error generating title", "err", err)
674 return
675 }
676
677 title := resp.Response.Content.Text()
678
679 title = strings.ReplaceAll(title, "\n", " ")
680
681 // remove thinking tags if present
682 if idx := strings.Index(title, "</think>"); idx > 0 {
683 title = title[idx+len("</think>"):]
684 }
685
686 title = strings.TrimSpace(title)
687 if title == "" {
688 slog.Warn("failed to generate title", "warn", "empty title")
689 return
690 }
691
692 session.Title = title
693 a.updateSessionUsage(a.smallModel, session, resp.TotalUsage)
694 _, saveErr := a.sessions.Save(ctx, *session)
695 if saveErr != nil {
696 slog.Error("failed to save session title & usage", "error", saveErr)
697 return
698 }
699}
700
701func (a *sessionAgent) updateSessionUsage(model Model, session *session.Session, usage fantasy.Usage) {
702 modelConfig := model.CatwalkCfg
703 cost := modelConfig.CostPer1MInCached/1e6*float64(usage.CacheCreationTokens) +
704 modelConfig.CostPer1MOutCached/1e6*float64(usage.CacheReadTokens) +
705 modelConfig.CostPer1MIn/1e6*float64(usage.InputTokens) +
706 modelConfig.CostPer1MOut/1e6*float64(usage.OutputTokens)
707
708 a.eventTokensUsed(session.ID, model, usage, cost)
709
710 session.Cost += cost
711 session.CompletionTokens = usage.OutputTokens + usage.CacheReadTokens
712 session.PromptTokens = usage.InputTokens + usage.CacheCreationTokens
713}
714
715func (a *sessionAgent) Cancel(sessionID string) {
716 // Cancel regular requests
717 if cancel, ok := a.activeRequests.Take(sessionID); ok && cancel != nil {
718 slog.Info("Request cancellation initiated", "session_id", sessionID)
719 cancel()
720 }
721
722 // Also check for summarize requests
723 if cancel, ok := a.activeRequests.Take(sessionID + "-summarize"); ok && cancel != nil {
724 slog.Info("Summarize cancellation initiated", "session_id", sessionID)
725 cancel()
726 }
727
728 if a.QueuedPrompts(sessionID) > 0 {
729 slog.Info("Clearing queued prompts", "session_id", sessionID)
730 a.messageQueue.Del(sessionID)
731 }
732}
733
734func (a *sessionAgent) ClearQueue(sessionID string) {
735 if a.QueuedPrompts(sessionID) > 0 {
736 slog.Info("Clearing queued prompts", "session_id", sessionID)
737 a.messageQueue.Del(sessionID)
738 }
739}
740
741func (a *sessionAgent) CancelAll() {
742 if !a.IsBusy() {
743 return
744 }
745 for key := range a.activeRequests.Seq2() {
746 a.Cancel(key) // key is sessionID
747 }
748
749 timeout := time.After(5 * time.Second)
750 for a.IsBusy() {
751 select {
752 case <-timeout:
753 return
754 default:
755 time.Sleep(200 * time.Millisecond)
756 }
757 }
758}
759
760func (a *sessionAgent) IsBusy() bool {
761 var busy bool
762 for cancelFunc := range a.activeRequests.Seq() {
763 if cancelFunc != nil {
764 busy = true
765 break
766 }
767 }
768 return busy
769}
770
771func (a *sessionAgent) IsSessionBusy(sessionID string) bool {
772 _, busy := a.activeRequests.Get(sessionID)
773 return busy
774}
775
776func (a *sessionAgent) QueuedPrompts(sessionID string) int {
777 l, ok := a.messageQueue.Get(sessionID)
778 if !ok {
779 return 0
780 }
781 return len(l)
782}
783
784func (a *sessionAgent) SetModels(large Model, small Model) {
785 a.largeModel = large
786 a.smallModel = small
787}
788
789func (a *sessionAgent) SetTools(tools []fantasy.AgentTool) {
790 a.tools = tools
791}
792
793func (a *sessionAgent) Model() Model {
794 return a.largeModel
795}