agent.go

  1package agent
  2
  3import (
  4	"context"
  5	_ "embed"
  6	"errors"
  7	"fmt"
  8	"log/slog"
  9	"os"
 10	"strconv"
 11	"strings"
 12	"sync"
 13	"time"
 14
 15	"charm.land/fantasy"
 16	"charm.land/fantasy/providers/anthropic"
 17	"charm.land/fantasy/providers/bedrock"
 18	"charm.land/fantasy/providers/google"
 19	"charm.land/fantasy/providers/openai"
 20	"github.com/charmbracelet/catwalk/pkg/catwalk"
 21	"github.com/charmbracelet/crush/internal/agent/tools"
 22	"github.com/charmbracelet/crush/internal/config"
 23	"github.com/charmbracelet/crush/internal/csync"
 24	"github.com/charmbracelet/crush/internal/message"
 25	"github.com/charmbracelet/crush/internal/permission"
 26	"github.com/charmbracelet/crush/internal/session"
 27)
 28
 29//go:embed templates/title.md
 30var titlePrompt []byte
 31
 32//go:embed templates/summary.md
 33var summaryPrompt []byte
 34
// SessionAgentCall carries the inputs for one SessionAgent.Run invocation.
//
// SessionID and Prompt are required: Run returns ErrSessionMissing or
// ErrEmptyPrompt when they are empty. The remaining fields are forwarded to
// the fantasy stream call as-is.
type SessionAgentCall struct {
	// SessionID is the session the prompt is added to.
	SessionID        string
	// Prompt is the user text; it is stored as the user message and sent to
	// the model.
	Prompt           string
	// ProviderOptions is passed through to the stream call (and to Summarize
	// when Run triggers an automatic summary).
	ProviderOptions  fantasy.ProviderOptions
	// Attachments are stored as binary parts on the user message and sent to
	// the model as file parts.
	Attachments      []message.Attachment
	// MaxOutputTokens is passed to the stream call by pointer.
	MaxOutputTokens  int64
	// Sampling parameters, passed through unchanged.
	Temperature      *float64
	TopP             *float64
	TopK             *int64
	FrequencyPenalty *float64
	PresencePenalty  *float64
}
 47
// SessionAgent drives model conversations on a per-session basis. The method
// notes below describe the sessionAgent implementation in this file.
type SessionAgent interface {
	// Run sends a prompt for a session. If that session already has an
	// active request the call is queued and (nil, nil) is returned.
	Run(context.Context, SessionAgentCall) (*fantasy.AgentResult, error)
	// SetModels replaces the large and small models.
	SetModels(large Model, small Model)
	// SetTools replaces the tool set handed to the model.
	SetTools(tools []fantasy.AgentTool)
	// Cancel cancels the session's active request and drops its queued
	// prompts.
	Cancel(sessionID string)
	// CancelAll cancels every active request and waits (bounded) for them to
	// finish.
	CancelAll()
	// IsSessionBusy reports whether the session has an active request.
	IsSessionBusy(sessionID string) bool
	// IsBusy reports whether any session has an active request.
	IsBusy() bool
	// QueuedPrompts returns the number of prompts queued for the session.
	QueuedPrompts(sessionID string) int
	// ClearQueue drops the session's queued prompts without cancelling the
	// active request.
	ClearQueue(sessionID string)
	// Summarize generates a summary message for the session; it returns
	// ErrSessionBusy if the session has an active request.
	Summarize(context.Context, string, fantasy.ProviderOptions) error
	// Model returns the large model.
	Model() Model
}
 61
// Model bundles a language model with the catalog entry and the selected
// configuration that describe it.
type Model struct {
	// Model is the language model handed to fantasy.NewAgent.
	Model      fantasy.LanguageModel
	// CatwalkCfg supplies the context window, per-token costs, and reasoning
	// and default-max-token settings read elsewhere in this file.
	CatwalkCfg catwalk.Model
	// ModelCfg supplies the model and provider names recorded on assistant
	// messages created by Run.
	ModelCfg   config.SelectedModel
}
 67
// sessionAgent is the SessionAgent implementation in this file.
type sessionAgent struct {
	// largeModel handles prompts and summaries; smallModel generates titles.
	largeModel           Model
	smallModel           Model
	// systemPrompt is the system prompt used for Run.
	systemPrompt         string
	tools                []fantasy.AgentTool
	sessions             session.Service
	messages             message.Service
	// disableAutoSummarize turns off the context-window check that makes Run
	// stop and summarize.
	disableAutoSummarize bool
	// isYolo is stored from the options; it is not read in this file.
	isYolo               bool

	// messageQueue holds, per session ID, the calls received while that
	// session was busy.
	messageQueue   *csync.Map[string, []SessionAgentCall]
	// activeRequests holds, per session ID, the cancel function of the
	// request in flight. Presence of a key is what marks a session as busy.
	activeRequests *csync.Map[string, context.CancelFunc]
}
 81
// SessionAgentOptions is the configuration accepted by NewSessionAgent. Each
// field is copied onto the matching sessionAgent field.
type SessionAgentOptions struct {
	LargeModel           Model
	SmallModel           Model
	SystemPrompt         string
	// DisableAutoSummarize turns off automatic summarization in Run.
	DisableAutoSummarize bool
	IsYolo               bool
	Sessions             session.Service
	Messages             message.Service
	Tools                []fantasy.AgentTool
}
 92
 93func NewSessionAgent(
 94	opts SessionAgentOptions,
 95) SessionAgent {
 96	return &sessionAgent{
 97		largeModel:           opts.LargeModel,
 98		smallModel:           opts.SmallModel,
 99		systemPrompt:         opts.SystemPrompt,
100		sessions:             opts.Sessions,
101		messages:             opts.Messages,
102		disableAutoSummarize: opts.DisableAutoSummarize,
103		tools:                opts.Tools,
104		isYolo:               opts.IsYolo,
105		messageQueue:         csync.NewMap[string, []SessionAgentCall](),
106		activeRequests:       csync.NewMap[string, context.CancelFunc](),
107	}
108}
109
110func (a *sessionAgent) Run(ctx context.Context, call SessionAgentCall) (*fantasy.AgentResult, error) {
111	if call.Prompt == "" {
112		return nil, ErrEmptyPrompt
113	}
114	if call.SessionID == "" {
115		return nil, ErrSessionMissing
116	}
117
118	// Queue the message if busy
119	if a.IsSessionBusy(call.SessionID) {
120		existing, ok := a.messageQueue.Get(call.SessionID)
121		if !ok {
122			existing = []SessionAgentCall{}
123		}
124		existing = append(existing, call)
125		a.messageQueue.Set(call.SessionID, existing)
126		return nil, nil
127	}
128
129	if len(a.tools) > 0 {
130		// add anthropic caching to the last tool
131		a.tools[len(a.tools)-1].SetProviderOptions(a.getCacheControlOptions())
132	}
133
134	agent := fantasy.NewAgent(
135		a.largeModel.Model,
136		fantasy.WithSystemPrompt(a.systemPrompt),
137		fantasy.WithTools(a.tools...),
138	)
139
140	sessionLock := sync.Mutex{}
141	currentSession, err := a.sessions.Get(ctx, call.SessionID)
142	if err != nil {
143		return nil, fmt.Errorf("failed to get session: %w", err)
144	}
145
146	msgs, err := a.getSessionMessages(ctx, currentSession)
147	if err != nil {
148		return nil, fmt.Errorf("failed to get session messages: %w", err)
149	}
150
151	var wg sync.WaitGroup
152	// Generate title if first message
153	if len(msgs) == 0 {
154		wg.Go(func() {
155			sessionLock.Lock()
156			a.generateTitle(ctx, &currentSession, call.Prompt)
157			sessionLock.Unlock()
158		})
159	}
160
161	// Add the user message to the session
162	_, err = a.createUserMessage(ctx, call)
163	if err != nil {
164		return nil, err
165	}
166
167	// add the session to the context
168	ctx = context.WithValue(ctx, tools.SessionIDContextKey, call.SessionID)
169
170	genCtx, cancel := context.WithCancel(ctx)
171	a.activeRequests.Set(call.SessionID, cancel)
172
173	defer cancel()
174	defer a.activeRequests.Del(call.SessionID)
175
176	history, files := a.preparePrompt(msgs, call.Attachments...)
177
178	startTime := time.Now()
179	a.eventPromptSent(call.SessionID)
180
181	var currentAssistant *message.Message
182	var shouldSummarize bool
183	result, err := agent.Stream(genCtx, fantasy.AgentStreamCall{
184		Prompt:           call.Prompt,
185		Files:            files,
186		Messages:         history,
187		ProviderOptions:  call.ProviderOptions,
188		MaxOutputTokens:  &call.MaxOutputTokens,
189		TopP:             call.TopP,
190		Temperature:      call.Temperature,
191		PresencePenalty:  call.PresencePenalty,
192		TopK:             call.TopK,
193		FrequencyPenalty: call.FrequencyPenalty,
194		// Before each step create the new assistant message
195		PrepareStep: func(callContext context.Context, options fantasy.PrepareStepFunctionOptions) (_ context.Context, prepared fantasy.PrepareStepResult, err error) {
196			prepared.Messages = options.Messages
197			// reset all cached items
198			for i := range prepared.Messages {
199				prepared.Messages[i].ProviderOptions = nil
200			}
201
202			queuedCalls, _ := a.messageQueue.Get(call.SessionID)
203			a.messageQueue.Del(call.SessionID)
204			for _, queued := range queuedCalls {
205				userMessage, createErr := a.createUserMessage(callContext, queued)
206				if createErr != nil {
207					return callContext, prepared, createErr
208				}
209				prepared.Messages = append(prepared.Messages, userMessage.ToAIMessage()...)
210			}
211
212			lastSystemRoleInx := 0
213			systemMessageUpdated := false
214			for i, msg := range prepared.Messages {
215				// only add cache control to the last message
216				if msg.Role == fantasy.MessageRoleSystem {
217					lastSystemRoleInx = i
218				} else if !systemMessageUpdated {
219					prepared.Messages[lastSystemRoleInx].ProviderOptions = a.getCacheControlOptions()
220					systemMessageUpdated = true
221				}
222				// than add cache control to the last 2 messages
223				if i > len(prepared.Messages)-3 {
224					prepared.Messages[i].ProviderOptions = a.getCacheControlOptions()
225				}
226			}
227
228			var assistantMsg message.Message
229			assistantMsg, err = a.messages.Create(callContext, call.SessionID, message.CreateMessageParams{
230				Role:     message.Assistant,
231				Parts:    []message.ContentPart{},
232				Model:    a.largeModel.ModelCfg.Model,
233				Provider: a.largeModel.ModelCfg.Provider,
234			})
235			if err != nil {
236				return callContext, prepared, err
237			}
238			callContext = context.WithValue(callContext, tools.MessageIDContextKey, assistantMsg.ID)
239			currentAssistant = &assistantMsg
240			return callContext, prepared, err
241		},
242		OnReasoningStart: func(id string, reasoning fantasy.ReasoningContent) error {
243			currentAssistant.AppendReasoningContent(reasoning.Text)
244			return a.messages.Update(genCtx, *currentAssistant)
245		},
246		OnReasoningDelta: func(id string, text string) error {
247			currentAssistant.AppendReasoningContent(text)
248			return a.messages.Update(genCtx, *currentAssistant)
249		},
250		OnReasoningEnd: func(id string, reasoning fantasy.ReasoningContent) error {
251			// handle anthropic signature
252			if anthropicData, ok := reasoning.ProviderMetadata[anthropic.Name]; ok {
253				if reasoning, ok := anthropicData.(*anthropic.ReasoningOptionMetadata); ok {
254					currentAssistant.AppendReasoningSignature(reasoning.Signature)
255				}
256			}
257			if googleData, ok := reasoning.ProviderMetadata[google.Name]; ok {
258				if reasoning, ok := googleData.(*google.ReasoningMetadata); ok {
259					currentAssistant.AppendReasoningSignature(reasoning.Signature)
260				}
261			}
262			if openaiData, ok := reasoning.ProviderMetadata[openai.Name]; ok {
263				if reasoning, ok := openaiData.(*openai.ResponsesReasoningMetadata); ok {
264					currentAssistant.SetReasoningResponsesData(reasoning)
265				}
266			}
267			currentAssistant.FinishThinking()
268			return a.messages.Update(genCtx, *currentAssistant)
269		},
270		OnTextDelta: func(id string, text string) error {
271			currentAssistant.AppendContent(text)
272			return a.messages.Update(genCtx, *currentAssistant)
273		},
274		OnToolInputStart: func(id string, toolName string) error {
275			toolCall := message.ToolCall{
276				ID:               id,
277				Name:             toolName,
278				ProviderExecuted: false,
279				Finished:         false,
280			}
281			currentAssistant.AddToolCall(toolCall)
282			return a.messages.Update(genCtx, *currentAssistant)
283		},
284		OnRetry: func(err *fantasy.APICallError, delay time.Duration) {
285			// TODO: implement
286		},
287		OnToolCall: func(tc fantasy.ToolCallContent) error {
288			toolCall := message.ToolCall{
289				ID:               tc.ToolCallID,
290				Name:             tc.ToolName,
291				Input:            tc.Input,
292				ProviderExecuted: false,
293				Finished:         true,
294			}
295			currentAssistant.AddToolCall(toolCall)
296			return a.messages.Update(genCtx, *currentAssistant)
297		},
298		OnToolResult: func(result fantasy.ToolResultContent) error {
299			var resultContent string
300			isError := false
301			switch result.Result.GetType() {
302			case fantasy.ToolResultContentTypeText:
303				r, ok := fantasy.AsToolResultOutputType[fantasy.ToolResultOutputContentText](result.Result)
304				if ok {
305					resultContent = r.Text
306				}
307			case fantasy.ToolResultContentTypeError:
308				r, ok := fantasy.AsToolResultOutputType[fantasy.ToolResultOutputContentError](result.Result)
309				if ok {
310					isError = true
311					resultContent = r.Error.Error()
312				}
313			case fantasy.ToolResultContentTypeMedia:
314				// TODO: handle this message type
315			}
316			toolResult := message.ToolResult{
317				ToolCallID: result.ToolCallID,
318				Name:       result.ToolName,
319				Content:    resultContent,
320				IsError:    isError,
321				Metadata:   result.ClientMetadata,
322			}
323			_, createMsgErr := a.messages.Create(genCtx, currentAssistant.SessionID, message.CreateMessageParams{
324				Role: message.Tool,
325				Parts: []message.ContentPart{
326					toolResult,
327				},
328			})
329			if createMsgErr != nil {
330				return createMsgErr
331			}
332			return nil
333		},
334		OnStepFinish: func(stepResult fantasy.StepResult) error {
335			finishReason := message.FinishReasonUnknown
336			switch stepResult.FinishReason {
337			case fantasy.FinishReasonLength:
338				finishReason = message.FinishReasonMaxTokens
339			case fantasy.FinishReasonStop:
340				finishReason = message.FinishReasonEndTurn
341			case fantasy.FinishReasonToolCalls:
342				finishReason = message.FinishReasonToolUse
343			}
344			currentAssistant.AddFinish(finishReason, "", "")
345			a.updateSessionUsage(a.largeModel, &currentSession, stepResult.Usage)
346			sessionLock.Lock()
347			_, sessionErr := a.sessions.Save(genCtx, currentSession)
348			sessionLock.Unlock()
349			if sessionErr != nil {
350				return sessionErr
351			}
352			return a.messages.Update(genCtx, *currentAssistant)
353		},
354		StopWhen: []fantasy.StopCondition{
355			func(_ []fantasy.StepResult) bool {
356				contextWindow := a.largeModel.CatwalkCfg.ContextWindow
357				tokens := currentSession.CompletionTokens + currentSession.PromptTokens
358				percentage := (float64(tokens) / float64(contextWindow)) * 100
359				if (percentage > 80) && !a.disableAutoSummarize {
360					shouldSummarize = true
361					return true
362				}
363				return false
364			},
365		},
366	})
367
368	a.eventPromptResponded(call.SessionID, time.Since(startTime).Truncate(time.Second))
369
370	if err != nil {
371		isCancelErr := errors.Is(err, context.Canceled)
372		isPermissionErr := errors.Is(err, permission.ErrorPermissionDenied)
373		if currentAssistant == nil {
374			return result, err
375		}
376		// Ensure we finish thinking on error to close the reasoning state
377		currentAssistant.FinishThinking()
378		toolCalls := currentAssistant.ToolCalls()
379		// INFO: we use the parent context here because the genCtx has been cancelled
380		msgs, createErr := a.messages.List(ctx, currentAssistant.SessionID)
381		if createErr != nil {
382			return nil, createErr
383		}
384		for _, tc := range toolCalls {
385			if !tc.Finished {
386				tc.Finished = true
387				tc.Input = "{}"
388				currentAssistant.AddToolCall(tc)
389				updateErr := a.messages.Update(ctx, *currentAssistant)
390				if updateErr != nil {
391					return nil, updateErr
392				}
393			}
394
395			found := false
396			for _, msg := range msgs {
397				if msg.Role == message.Tool {
398					for _, tr := range msg.ToolResults() {
399						if tr.ToolCallID == tc.ID {
400							found = true
401							break
402						}
403					}
404				}
405				if found {
406					break
407				}
408			}
409			if found {
410				continue
411			}
412			content := "There was an error while executing the tool"
413			if isCancelErr {
414				content = "Tool execution canceled by user"
415			} else if isPermissionErr {
416				content = "Permission denied"
417			}
418			toolResult := message.ToolResult{
419				ToolCallID: tc.ID,
420				Name:       tc.Name,
421				Content:    content,
422				IsError:    true,
423			}
424			_, createErr = a.messages.Create(context.Background(), currentAssistant.SessionID, message.CreateMessageParams{
425				Role: message.Tool,
426				Parts: []message.ContentPart{
427					toolResult,
428				},
429			})
430			if createErr != nil {
431				return nil, createErr
432			}
433		}
434		if isCancelErr {
435			currentAssistant.AddFinish(message.FinishReasonCanceled, "Request cancelled", "")
436		} else if isPermissionErr {
437			currentAssistant.AddFinish(message.FinishReasonPermissionDenied, "Permission denied", "")
438		} else {
439			currentAssistant.AddFinish(message.FinishReasonError, "API Error", err.Error())
440		}
441		// INFO: we use the parent context here because the genCtx has been cancelled
442		updateErr := a.messages.Update(ctx, *currentAssistant)
443		if updateErr != nil {
444			return nil, updateErr
445		}
446		return nil, err
447	}
448	wg.Wait()
449
450	if shouldSummarize {
451		a.activeRequests.Del(call.SessionID)
452		if summarizeErr := a.Summarize(genCtx, call.SessionID, call.ProviderOptions); summarizeErr != nil {
453			return nil, summarizeErr
454		}
455		// if the agent was not done...
456		if len(currentAssistant.ToolCalls()) > 0 {
457			existing, ok := a.messageQueue.Get(call.SessionID)
458			if !ok {
459				existing = []SessionAgentCall{}
460			}
461			call.Prompt = fmt.Sprintf("The previous session was interrupted because it got too long, the initial user request was: `%s`", call.Prompt)
462			existing = append(existing, call)
463			a.messageQueue.Set(call.SessionID, existing)
464		}
465	}
466
467	// release active request before processing queued messages
468	a.activeRequests.Del(call.SessionID)
469	cancel()
470
471	queuedMessages, ok := a.messageQueue.Get(call.SessionID)
472	if !ok || len(queuedMessages) == 0 {
473		return result, err
474	}
475	// there are queued messages restart the loop
476	firstQueuedMessage := queuedMessages[0]
477	a.messageQueue.Set(call.SessionID, queuedMessages[1:])
478	return a.Run(ctx, firstQueuedMessage)
479}
480
481func (a *sessionAgent) Summarize(ctx context.Context, sessionID string, opts fantasy.ProviderOptions) error {
482	if a.IsSessionBusy(sessionID) {
483		return ErrSessionBusy
484	}
485
486	currentSession, err := a.sessions.Get(ctx, sessionID)
487	if err != nil {
488		return fmt.Errorf("failed to get session: %w", err)
489	}
490	msgs, err := a.getSessionMessages(ctx, currentSession)
491	if err != nil {
492		return err
493	}
494	if len(msgs) == 0 {
495		// nothing to summarize
496		return nil
497	}
498
499	aiMsgs, _ := a.preparePrompt(msgs)
500
501	genCtx, cancel := context.WithCancel(ctx)
502	a.activeRequests.Set(sessionID, cancel)
503	defer a.activeRequests.Del(sessionID)
504	defer cancel()
505
506	agent := fantasy.NewAgent(a.largeModel.Model,
507		fantasy.WithSystemPrompt(string(summaryPrompt)),
508	)
509	summaryMessage, err := a.messages.Create(ctx, sessionID, message.CreateMessageParams{
510		Role:             message.Assistant,
511		Model:            a.largeModel.Model.Model(),
512		Provider:         a.largeModel.Model.Provider(),
513		IsSummaryMessage: true,
514	})
515	if err != nil {
516		return err
517	}
518
519	resp, err := agent.Stream(genCtx, fantasy.AgentStreamCall{
520		Prompt:          "Provide a detailed summary of our conversation above.",
521		Messages:        aiMsgs,
522		ProviderOptions: opts,
523		OnReasoningDelta: func(id string, text string) error {
524			summaryMessage.AppendReasoningContent(text)
525			return a.messages.Update(genCtx, summaryMessage)
526		},
527		OnReasoningEnd: func(id string, reasoning fantasy.ReasoningContent) error {
528			// handle anthropic signature
529			if anthropicData, ok := reasoning.ProviderMetadata["anthropic"]; ok {
530				if signature, ok := anthropicData.(*anthropic.ReasoningOptionMetadata); ok && signature.Signature != "" {
531					summaryMessage.AppendReasoningSignature(signature.Signature)
532				}
533			}
534			summaryMessage.FinishThinking()
535			return a.messages.Update(genCtx, summaryMessage)
536		},
537		OnTextDelta: func(id, text string) error {
538			summaryMessage.AppendContent(text)
539			return a.messages.Update(genCtx, summaryMessage)
540		},
541	})
542	if err != nil {
543		isCancelErr := errors.Is(err, context.Canceled)
544		if isCancelErr {
545			// User cancelled summarize we need to remove the summary message
546			deleteErr := a.messages.Delete(ctx, summaryMessage.ID)
547			return deleteErr
548		}
549		return err
550	}
551
552	summaryMessage.AddFinish(message.FinishReasonEndTurn, "", "")
553	err = a.messages.Update(genCtx, summaryMessage)
554	if err != nil {
555		return err
556	}
557
558	a.updateSessionUsage(a.largeModel, &currentSession, resp.TotalUsage)
559
560	// just in case get just the last usage
561	usage := resp.Response.Usage
562	currentSession.SummaryMessageID = summaryMessage.ID
563	currentSession.CompletionTokens = usage.OutputTokens
564	currentSession.PromptTokens = 0
565	_, err = a.sessions.Save(genCtx, currentSession)
566	return err
567}
568
569func (a *sessionAgent) getCacheControlOptions() fantasy.ProviderOptions {
570	if t, _ := strconv.ParseBool(os.Getenv("CRUSH_DISABLE_ANTHROPIC_CACHE")); t {
571		return fantasy.ProviderOptions{}
572	}
573	return fantasy.ProviderOptions{
574		anthropic.Name: &anthropic.ProviderCacheControlOptions{
575			CacheControl: anthropic.CacheControl{Type: "ephemeral"},
576		},
577		bedrock.Name: &anthropic.ProviderCacheControlOptions{
578			CacheControl: anthropic.CacheControl{Type: "ephemeral"},
579		},
580	}
581}
582
583func (a *sessionAgent) createUserMessage(ctx context.Context, call SessionAgentCall) (message.Message, error) {
584	var attachmentParts []message.ContentPart
585	for _, attachment := range call.Attachments {
586		attachmentParts = append(attachmentParts, message.BinaryContent{Path: attachment.FilePath, MIMEType: attachment.MimeType, Data: attachment.Content})
587	}
588	parts := []message.ContentPart{message.TextContent{Text: call.Prompt}}
589	parts = append(parts, attachmentParts...)
590	msg, err := a.messages.Create(ctx, call.SessionID, message.CreateMessageParams{
591		Role:  message.User,
592		Parts: parts,
593	})
594	if err != nil {
595		return message.Message{}, fmt.Errorf("failed to create user message: %w", err)
596	}
597	return msg, nil
598}
599
600func (a *sessionAgent) preparePrompt(msgs []message.Message, attachments ...message.Attachment) ([]fantasy.Message, []fantasy.FilePart) {
601	var history []fantasy.Message
602	for _, m := range msgs {
603		if len(m.Parts) == 0 {
604			continue
605		}
606		// Assistant message without content or tool calls (cancelled before it returned anything)
607		if m.Role == message.Assistant && len(m.ToolCalls()) == 0 && m.Content().Text == "" && m.ReasoningContent().String() == "" {
608			continue
609		}
610		history = append(history, m.ToAIMessage()...)
611	}
612
613	var files []fantasy.FilePart
614	for _, attachment := range attachments {
615		files = append(files, fantasy.FilePart{
616			Filename:  attachment.FileName,
617			Data:      attachment.Content,
618			MediaType: attachment.MimeType,
619		})
620	}
621
622	return history, files
623}
624
625func (a *sessionAgent) getSessionMessages(ctx context.Context, session session.Session) ([]message.Message, error) {
626	msgs, err := a.messages.List(ctx, session.ID)
627	if err != nil {
628		return nil, fmt.Errorf("failed to list messages: %w", err)
629	}
630
631	if session.SummaryMessageID != "" {
632		summaryMsgInex := -1
633		for i, msg := range msgs {
634			if msg.ID == session.SummaryMessageID {
635				summaryMsgInex = i
636				break
637			}
638		}
639		if summaryMsgInex != -1 {
640			msgs = msgs[summaryMsgInex:]
641			msgs[0].Role = message.User
642		}
643	}
644	return msgs, nil
645}
646
647func (a *sessionAgent) generateTitle(ctx context.Context, session *session.Session, prompt string) {
648	if prompt == "" {
649		return
650	}
651
652	var maxOutput int64 = 40
653	if a.smallModel.CatwalkCfg.CanReason {
654		maxOutput = a.smallModel.CatwalkCfg.DefaultMaxTokens
655	}
656
657	agent := fantasy.NewAgent(a.smallModel.Model,
658		fantasy.WithSystemPrompt(string(titlePrompt)+"\n /no_think"),
659		fantasy.WithMaxOutputTokens(maxOutput),
660	)
661
662	resp, err := agent.Stream(ctx, fantasy.AgentStreamCall{
663		Prompt: fmt.Sprintf("Generate a concise title for the following content:\n\n%s\n <think>\n\n</think>", prompt),
664	})
665	if err != nil {
666		slog.Error("error generating title", "err", err)
667		return
668	}
669
670	title := resp.Response.Content.Text()
671
672	title = strings.ReplaceAll(title, "\n", " ")
673
674	// remove thinking tags if present
675	if idx := strings.Index(title, "</think>"); idx > 0 {
676		title = title[idx+len("</think>"):]
677	}
678
679	title = strings.TrimSpace(title)
680	if title == "" {
681		slog.Warn("failed to generate title", "warn", "empty title")
682		return
683	}
684
685	session.Title = title
686	a.updateSessionUsage(a.smallModel, session, resp.TotalUsage)
687	_, saveErr := a.sessions.Save(ctx, *session)
688	if saveErr != nil {
689		slog.Error("failed to save session title & usage", "error", saveErr)
690		return
691	}
692}
693
694func (a *sessionAgent) updateSessionUsage(model Model, session *session.Session, usage fantasy.Usage) {
695	modelConfig := model.CatwalkCfg
696	cost := modelConfig.CostPer1MInCached/1e6*float64(usage.CacheCreationTokens) +
697		modelConfig.CostPer1MOutCached/1e6*float64(usage.CacheReadTokens) +
698		modelConfig.CostPer1MIn/1e6*float64(usage.InputTokens) +
699		modelConfig.CostPer1MOut/1e6*float64(usage.OutputTokens)
700
701	a.eventTokensUsed(session.ID, model, usage, cost)
702
703	session.Cost += cost
704	session.CompletionTokens = usage.OutputTokens + usage.CacheReadTokens
705	session.PromptTokens = usage.InputTokens + usage.CacheCreationTokens
706}
707
708func (a *sessionAgent) Cancel(sessionID string) {
709	// Cancel regular requests
710	if cancel, ok := a.activeRequests.Take(sessionID); ok && cancel != nil {
711		slog.Info("Request cancellation initiated", "session_id", sessionID)
712		cancel()
713	}
714
715	// Also check for summarize requests
716	if cancel, ok := a.activeRequests.Take(sessionID + "-summarize"); ok && cancel != nil {
717		slog.Info("Summarize cancellation initiated", "session_id", sessionID)
718		cancel()
719	}
720
721	if a.QueuedPrompts(sessionID) > 0 {
722		slog.Info("Clearing queued prompts", "session_id", sessionID)
723		a.messageQueue.Del(sessionID)
724	}
725}
726
727func (a *sessionAgent) ClearQueue(sessionID string) {
728	if a.QueuedPrompts(sessionID) > 0 {
729		slog.Info("Clearing queued prompts", "session_id", sessionID)
730		a.messageQueue.Del(sessionID)
731	}
732}
733
734func (a *sessionAgent) CancelAll() {
735	if !a.IsBusy() {
736		return
737	}
738	for key := range a.activeRequests.Seq2() {
739		a.Cancel(key) // key is sessionID
740	}
741
742	timeout := time.After(5 * time.Second)
743	for a.IsBusy() {
744		select {
745		case <-timeout:
746			return
747		default:
748			time.Sleep(200 * time.Millisecond)
749		}
750	}
751}
752
753func (a *sessionAgent) IsBusy() bool {
754	var busy bool
755	for cancelFunc := range a.activeRequests.Seq() {
756		if cancelFunc != nil {
757			busy = true
758			break
759		}
760	}
761	return busy
762}
763
764func (a *sessionAgent) IsSessionBusy(sessionID string) bool {
765	_, busy := a.activeRequests.Get(sessionID)
766	return busy
767}
768
769func (a *sessionAgent) QueuedPrompts(sessionID string) int {
770	l, ok := a.messageQueue.Get(sessionID)
771	if !ok {
772		return 0
773	}
774	return len(l)
775}
776
777func (a *sessionAgent) SetModels(large Model, small Model) {
778	a.largeModel = large
779	a.smallModel = small
780}
781
// SetTools replaces the agent's tool set. The slice is stored as given, not
// copied.
func (a *sessionAgent) SetTools(tools []fantasy.AgentTool) {
	a.tools = tools
}
785
// Model returns the agent's large model.
func (a *sessionAgent) Model() Model {
	return a.largeModel
}