agent.go

  1package agent
  2
  3import (
  4	"context"
  5	_ "embed"
  6	"errors"
  7	"fmt"
  8	"log/slog"
  9	"os"
 10	"strconv"
 11	"strings"
 12	"sync"
 13	"time"
 14
 15	"charm.land/fantasy"
 16	"charm.land/fantasy/providers/anthropic"
 17	"charm.land/fantasy/providers/bedrock"
 18	"charm.land/fantasy/providers/google"
 19	"charm.land/fantasy/providers/openai"
 20	"github.com/charmbracelet/catwalk/pkg/catwalk"
 21	"github.com/charmbracelet/crush/internal/agent/tools"
 22	"github.com/charmbracelet/crush/internal/config"
 23	"github.com/charmbracelet/crush/internal/csync"
 24	"github.com/charmbracelet/crush/internal/message"
 25	"github.com/charmbracelet/crush/internal/permission"
 26	"github.com/charmbracelet/crush/internal/session"
 27)
 28
// titlePrompt is the system prompt used by generateTitle.
//
//go:embed templates/title.md
var titlePrompt []byte

// summaryPrompt is the system prompt used by Summarize.
//
//go:embed templates/summary.md
var summaryPrompt []byte
 34
// SessionAgentCall is a single prompt submitted to a SessionAgent.
type SessionAgentCall struct {
	// SessionID identifies the session the prompt belongs to. Required: Run
	// returns ErrSessionMissing when it is empty.
	SessionID string
	// Prompt is the user's text. Required: Run returns ErrEmptyPrompt when it
	// is empty.
	Prompt string
	// ProviderOptions are passed through unchanged to the model stream call.
	ProviderOptions fantasy.ProviderOptions
	// Attachments are stored as binary parts on the user message and sent to
	// the model as files.
	Attachments []message.Attachment
	// MaxOutputTokens is forwarded to the stream call by pointer, so a zero
	// value is sent as an explicit 0. NOTE(review): confirm how fantasy
	// treats 0.
	MaxOutputTokens int64
	// Sampling parameters, forwarded as-is. A nil pointer is forwarded as nil
	// (presumably meaning "provider default" — TODO confirm in fantasy).
	Temperature      *float64
	TopP             *float64
	TopK             *int64
	FrequencyPenalty *float64
	PresencePenalty  *float64
}
 47
// SessionAgent runs prompts against a language model on behalf of chat
// sessions, persisting messages and session usage as it goes.
type SessionAgent interface {
	// Run sends a prompt for a session. If the session already has an active
	// request, the call is queued and Run returns (nil, nil).
	Run(context.Context, SessionAgentCall) (*fantasy.AgentResult, error)
	// SetModels replaces the large model (prompts, summaries) and the small
	// model (session titles).
	SetModels(large Model, small Model)
	// SetTools replaces the tools offered to the model on subsequent runs.
	SetTools(tools []fantasy.AgentTool)
	// Cancel aborts in-flight work for a session and clears its queue.
	Cancel(sessionID string)
	// CancelAll cancels every active request, then waits a bounded time for
	// the agent to become idle.
	CancelAll()
	// IsSessionBusy reports whether the session has an active request.
	IsSessionBusy(sessionID string) bool
	// IsBusy reports whether any session has an active request.
	IsBusy() bool
	// QueuedPrompts returns how many prompts are waiting for the session.
	QueuedPrompts(sessionID string) int
	// ClearQueue drops the session's queued prompts.
	ClearQueue(sessionID string)
	// Summarize writes a model-generated summary of the session and makes it
	// the starting point of the history replayed for later prompts.
	Summarize(context.Context, string, fantasy.ProviderOptions) error
	// Model returns the large model currently in use.
	Model() Model
}
 61
// Model bundles a language model with the metadata the agent needs to drive
// it and account for its usage.
type Model struct {
	// Model is the fantasy model used to stream completions.
	Model fantasy.LanguageModel
	// CatwalkCfg supplies catalog metadata: context window, per-token costs,
	// reasoning support and default max tokens.
	CatwalkCfg catwalk.Model
	// ModelCfg is the selected model; its Model and Provider fields are
	// recorded on assistant messages created by Run.
	ModelCfg config.SelectedModel
}
 67
// sessionAgent is the SessionAgent implementation returned by NewSessionAgent.
type sessionAgent struct {
	largeModel           Model  // used by Run and Summarize
	smallModel           Model  // used by generateTitle
	systemPromptPrefix   string // if non-empty, prepended as an extra system message on every step
	systemPrompt         string
	tools                []fantasy.AgentTool
	sessions             session.Service
	messages             message.Service
	disableAutoSummarize bool // skips auto-summarize when context usage passes 80%
	isYolo               bool // NOTE(review): not read by the code in this file

	// messageQueue holds prompts submitted while a session was busy, keyed
	// by session ID.
	messageQueue *csync.Map[string, []SessionAgentCall]
	// activeRequests maps a session ID to the cancel func of its in-flight
	// request; presence of a key is what makes a session "busy".
	activeRequests *csync.Map[string, context.CancelFunc]
}
 82
// SessionAgentOptions configures NewSessionAgent. Each field initializes the
// sessionAgent field of the same (lower-camel-cased) name.
type SessionAgentOptions struct {
	LargeModel           Model
	SmallModel           Model
	SystemPromptPrefix   string
	SystemPrompt         string
	DisableAutoSummarize bool
	IsYolo               bool
	Sessions             session.Service
	Messages             message.Service
	Tools                []fantasy.AgentTool
}
 94
 95func NewSessionAgent(
 96	opts SessionAgentOptions,
 97) SessionAgent {
 98	return &sessionAgent{
 99		largeModel:           opts.LargeModel,
100		smallModel:           opts.SmallModel,
101		systemPromptPrefix:   opts.SystemPromptPrefix,
102		systemPrompt:         opts.SystemPrompt,
103		sessions:             opts.Sessions,
104		messages:             opts.Messages,
105		disableAutoSummarize: opts.DisableAutoSummarize,
106		tools:                opts.Tools,
107		isYolo:               opts.IsYolo,
108		messageQueue:         csync.NewMap[string, []SessionAgentCall](),
109		activeRequests:       csync.NewMap[string, context.CancelFunc](),
110	}
111}
112
113func (a *sessionAgent) Run(ctx context.Context, call SessionAgentCall) (*fantasy.AgentResult, error) {
114	if call.Prompt == "" {
115		return nil, ErrEmptyPrompt
116	}
117	if call.SessionID == "" {
118		return nil, ErrSessionMissing
119	}
120
121	// Queue the message if busy
122	if a.IsSessionBusy(call.SessionID) {
123		existing, ok := a.messageQueue.Get(call.SessionID)
124		if !ok {
125			existing = []SessionAgentCall{}
126		}
127		existing = append(existing, call)
128		a.messageQueue.Set(call.SessionID, existing)
129		return nil, nil
130	}
131
132	if len(a.tools) > 0 {
133		// add anthropic caching to the last tool
134		a.tools[len(a.tools)-1].SetProviderOptions(a.getCacheControlOptions())
135	}
136
137	agent := fantasy.NewAgent(
138		a.largeModel.Model,
139		fantasy.WithSystemPrompt(a.systemPrompt),
140		fantasy.WithTools(a.tools...),
141	)
142
143	sessionLock := sync.Mutex{}
144	currentSession, err := a.sessions.Get(ctx, call.SessionID)
145	if err != nil {
146		return nil, fmt.Errorf("failed to get session: %w", err)
147	}
148
149	msgs, err := a.getSessionMessages(ctx, currentSession)
150	if err != nil {
151		return nil, fmt.Errorf("failed to get session messages: %w", err)
152	}
153
154	var wg sync.WaitGroup
155	// Generate title if first message
156	if len(msgs) == 0 {
157		wg.Go(func() {
158			sessionLock.Lock()
159			a.generateTitle(ctx, &currentSession, call.Prompt)
160			sessionLock.Unlock()
161		})
162	}
163
164	// Add the user message to the session
165	_, err = a.createUserMessage(ctx, call)
166	if err != nil {
167		return nil, err
168	}
169
170	// add the session to the context
171	ctx = context.WithValue(ctx, tools.SessionIDContextKey, call.SessionID)
172
173	genCtx, cancel := context.WithCancel(ctx)
174	a.activeRequests.Set(call.SessionID, cancel)
175
176	defer cancel()
177	defer a.activeRequests.Del(call.SessionID)
178
179	history, files := a.preparePrompt(msgs, call.Attachments...)
180
181	startTime := time.Now()
182	a.eventPromptSent(call.SessionID)
183
184	var currentAssistant *message.Message
185	var shouldSummarize bool
186	result, err := agent.Stream(genCtx, fantasy.AgentStreamCall{
187		Prompt:           call.Prompt,
188		Files:            files,
189		Messages:         history,
190		ProviderOptions:  call.ProviderOptions,
191		MaxOutputTokens:  &call.MaxOutputTokens,
192		TopP:             call.TopP,
193		Temperature:      call.Temperature,
194		PresencePenalty:  call.PresencePenalty,
195		TopK:             call.TopK,
196		FrequencyPenalty: call.FrequencyPenalty,
197		// Before each step create the new assistant message
198		PrepareStep: func(callContext context.Context, options fantasy.PrepareStepFunctionOptions) (_ context.Context, prepared fantasy.PrepareStepResult, err error) {
199			prepared.Messages = options.Messages
200			// reset all cached items
201			for i := range prepared.Messages {
202				prepared.Messages[i].ProviderOptions = nil
203			}
204
205			queuedCalls, _ := a.messageQueue.Get(call.SessionID)
206			a.messageQueue.Del(call.SessionID)
207			for _, queued := range queuedCalls {
208				userMessage, createErr := a.createUserMessage(callContext, queued)
209				if createErr != nil {
210					return callContext, prepared, createErr
211				}
212				prepared.Messages = append(prepared.Messages, userMessage.ToAIMessage()...)
213			}
214
215			lastSystemRoleInx := 0
216			systemMessageUpdated := false
217			for i, msg := range prepared.Messages {
218				// only add cache control to the last message
219				if msg.Role == fantasy.MessageRoleSystem {
220					lastSystemRoleInx = i
221				} else if !systemMessageUpdated {
222					prepared.Messages[lastSystemRoleInx].ProviderOptions = a.getCacheControlOptions()
223					systemMessageUpdated = true
224				}
225				// than add cache control to the last 2 messages
226				if i > len(prepared.Messages)-3 {
227					prepared.Messages[i].ProviderOptions = a.getCacheControlOptions()
228				}
229			}
230
231			if a.systemPromptPrefix != "" {
232				prepared.Messages = append([]fantasy.Message{fantasy.NewSystemMessage(a.systemPromptPrefix)}, prepared.Messages...)
233			}
234
235			var assistantMsg message.Message
236			assistantMsg, err = a.messages.Create(callContext, call.SessionID, message.CreateMessageParams{
237				Role:     message.Assistant,
238				Parts:    []message.ContentPart{},
239				Model:    a.largeModel.ModelCfg.Model,
240				Provider: a.largeModel.ModelCfg.Provider,
241			})
242			if err != nil {
243				return callContext, prepared, err
244			}
245			callContext = context.WithValue(callContext, tools.MessageIDContextKey, assistantMsg.ID)
246			currentAssistant = &assistantMsg
247			return callContext, prepared, err
248		},
249		OnReasoningStart: func(id string, reasoning fantasy.ReasoningContent) error {
250			currentAssistant.AppendReasoningContent(reasoning.Text)
251			return a.messages.Update(genCtx, *currentAssistant)
252		},
253		OnReasoningDelta: func(id string, text string) error {
254			currentAssistant.AppendReasoningContent(text)
255			return a.messages.Update(genCtx, *currentAssistant)
256		},
257		OnReasoningEnd: func(id string, reasoning fantasy.ReasoningContent) error {
258			// handle anthropic signature
259			if anthropicData, ok := reasoning.ProviderMetadata[anthropic.Name]; ok {
260				if reasoning, ok := anthropicData.(*anthropic.ReasoningOptionMetadata); ok {
261					currentAssistant.AppendReasoningSignature(reasoning.Signature)
262				}
263			}
264			if googleData, ok := reasoning.ProviderMetadata[google.Name]; ok {
265				if reasoning, ok := googleData.(*google.ReasoningMetadata); ok {
266					currentAssistant.AppendReasoningSignature(reasoning.Signature)
267				}
268			}
269			if openaiData, ok := reasoning.ProviderMetadata[openai.Name]; ok {
270				if reasoning, ok := openaiData.(*openai.ResponsesReasoningMetadata); ok {
271					currentAssistant.SetReasoningResponsesData(reasoning)
272				}
273			}
274			currentAssistant.FinishThinking()
275			return a.messages.Update(genCtx, *currentAssistant)
276		},
277		OnTextDelta: func(id string, text string) error {
278			currentAssistant.AppendContent(text)
279			return a.messages.Update(genCtx, *currentAssistant)
280		},
281		OnToolInputStart: func(id string, toolName string) error {
282			toolCall := message.ToolCall{
283				ID:               id,
284				Name:             toolName,
285				ProviderExecuted: false,
286				Finished:         false,
287			}
288			currentAssistant.AddToolCall(toolCall)
289			return a.messages.Update(genCtx, *currentAssistant)
290		},
291		OnRetry: func(err *fantasy.APICallError, delay time.Duration) {
292			// TODO: implement
293		},
294		OnToolCall: func(tc fantasy.ToolCallContent) error {
295			toolCall := message.ToolCall{
296				ID:               tc.ToolCallID,
297				Name:             tc.ToolName,
298				Input:            tc.Input,
299				ProviderExecuted: false,
300				Finished:         true,
301			}
302			currentAssistant.AddToolCall(toolCall)
303			return a.messages.Update(genCtx, *currentAssistant)
304		},
305		OnToolResult: func(result fantasy.ToolResultContent) error {
306			var resultContent string
307			isError := false
308			switch result.Result.GetType() {
309			case fantasy.ToolResultContentTypeText:
310				r, ok := fantasy.AsToolResultOutputType[fantasy.ToolResultOutputContentText](result.Result)
311				if ok {
312					resultContent = r.Text
313				}
314			case fantasy.ToolResultContentTypeError:
315				r, ok := fantasy.AsToolResultOutputType[fantasy.ToolResultOutputContentError](result.Result)
316				if ok {
317					isError = true
318					resultContent = r.Error.Error()
319				}
320			case fantasy.ToolResultContentTypeMedia:
321				// TODO: handle this message type
322			}
323			toolResult := message.ToolResult{
324				ToolCallID: result.ToolCallID,
325				Name:       result.ToolName,
326				Content:    resultContent,
327				IsError:    isError,
328				Metadata:   result.ClientMetadata,
329			}
330			_, createMsgErr := a.messages.Create(genCtx, currentAssistant.SessionID, message.CreateMessageParams{
331				Role: message.Tool,
332				Parts: []message.ContentPart{
333					toolResult,
334				},
335			})
336			if createMsgErr != nil {
337				return createMsgErr
338			}
339			return nil
340		},
341		OnStepFinish: func(stepResult fantasy.StepResult) error {
342			finishReason := message.FinishReasonUnknown
343			switch stepResult.FinishReason {
344			case fantasy.FinishReasonLength:
345				finishReason = message.FinishReasonMaxTokens
346			case fantasy.FinishReasonStop:
347				finishReason = message.FinishReasonEndTurn
348			case fantasy.FinishReasonToolCalls:
349				finishReason = message.FinishReasonToolUse
350			}
351			currentAssistant.AddFinish(finishReason, "", "")
352			a.updateSessionUsage(a.largeModel, &currentSession, stepResult.Usage)
353			sessionLock.Lock()
354			_, sessionErr := a.sessions.Save(genCtx, currentSession)
355			sessionLock.Unlock()
356			if sessionErr != nil {
357				return sessionErr
358			}
359			return a.messages.Update(genCtx, *currentAssistant)
360		},
361		StopWhen: []fantasy.StopCondition{
362			func(_ []fantasy.StepResult) bool {
363				contextWindow := a.largeModel.CatwalkCfg.ContextWindow
364				tokens := currentSession.CompletionTokens + currentSession.PromptTokens
365				percentage := (float64(tokens) / float64(contextWindow)) * 100
366				if (percentage > 80) && !a.disableAutoSummarize {
367					shouldSummarize = true
368					return true
369				}
370				return false
371			},
372		},
373	})
374
375	a.eventPromptResponded(call.SessionID, time.Since(startTime).Truncate(time.Second))
376
377	if err != nil {
378		isCancelErr := errors.Is(err, context.Canceled)
379		isPermissionErr := errors.Is(err, permission.ErrorPermissionDenied)
380		if currentAssistant == nil {
381			return result, err
382		}
383		// Ensure we finish thinking on error to close the reasoning state
384		currentAssistant.FinishThinking()
385		toolCalls := currentAssistant.ToolCalls()
386		// INFO: we use the parent context here because the genCtx has been cancelled
387		msgs, createErr := a.messages.List(ctx, currentAssistant.SessionID)
388		if createErr != nil {
389			return nil, createErr
390		}
391		for _, tc := range toolCalls {
392			if !tc.Finished {
393				tc.Finished = true
394				tc.Input = "{}"
395				currentAssistant.AddToolCall(tc)
396				updateErr := a.messages.Update(ctx, *currentAssistant)
397				if updateErr != nil {
398					return nil, updateErr
399				}
400			}
401
402			found := false
403			for _, msg := range msgs {
404				if msg.Role == message.Tool {
405					for _, tr := range msg.ToolResults() {
406						if tr.ToolCallID == tc.ID {
407							found = true
408							break
409						}
410					}
411				}
412				if found {
413					break
414				}
415			}
416			if found {
417				continue
418			}
419			content := "There was an error while executing the tool"
420			if isCancelErr {
421				content = "Tool execution canceled by user"
422			} else if isPermissionErr {
423				content = "Permission denied"
424			}
425			toolResult := message.ToolResult{
426				ToolCallID: tc.ID,
427				Name:       tc.Name,
428				Content:    content,
429				IsError:    true,
430			}
431			_, createErr = a.messages.Create(context.Background(), currentAssistant.SessionID, message.CreateMessageParams{
432				Role: message.Tool,
433				Parts: []message.ContentPart{
434					toolResult,
435				},
436			})
437			if createErr != nil {
438				return nil, createErr
439			}
440		}
441		if isCancelErr {
442			currentAssistant.AddFinish(message.FinishReasonCanceled, "Request cancelled", "")
443		} else if isPermissionErr {
444			currentAssistant.AddFinish(message.FinishReasonPermissionDenied, "Permission denied", "")
445		} else {
446			currentAssistant.AddFinish(message.FinishReasonError, "API Error", err.Error())
447		}
448		// INFO: we use the parent context here because the genCtx has been cancelled
449		updateErr := a.messages.Update(ctx, *currentAssistant)
450		if updateErr != nil {
451			return nil, updateErr
452		}
453		return nil, err
454	}
455	wg.Wait()
456
457	if shouldSummarize {
458		a.activeRequests.Del(call.SessionID)
459		if summarizeErr := a.Summarize(genCtx, call.SessionID, call.ProviderOptions); summarizeErr != nil {
460			return nil, summarizeErr
461		}
462		// if the agent was not done...
463		if len(currentAssistant.ToolCalls()) > 0 {
464			existing, ok := a.messageQueue.Get(call.SessionID)
465			if !ok {
466				existing = []SessionAgentCall{}
467			}
468			call.Prompt = fmt.Sprintf("The previous session was interrupted because it got too long, the initial user request was: `%s`", call.Prompt)
469			existing = append(existing, call)
470			a.messageQueue.Set(call.SessionID, existing)
471		}
472	}
473
474	// release active request before processing queued messages
475	a.activeRequests.Del(call.SessionID)
476	cancel()
477
478	queuedMessages, ok := a.messageQueue.Get(call.SessionID)
479	if !ok || len(queuedMessages) == 0 {
480		return result, err
481	}
482	// there are queued messages restart the loop
483	firstQueuedMessage := queuedMessages[0]
484	a.messageQueue.Set(call.SessionID, queuedMessages[1:])
485	return a.Run(ctx, firstQueuedMessage)
486}
487
488func (a *sessionAgent) Summarize(ctx context.Context, sessionID string, opts fantasy.ProviderOptions) error {
489	if a.IsSessionBusy(sessionID) {
490		return ErrSessionBusy
491	}
492
493	currentSession, err := a.sessions.Get(ctx, sessionID)
494	if err != nil {
495		return fmt.Errorf("failed to get session: %w", err)
496	}
497	msgs, err := a.getSessionMessages(ctx, currentSession)
498	if err != nil {
499		return err
500	}
501	if len(msgs) == 0 {
502		// nothing to summarize
503		return nil
504	}
505
506	aiMsgs, _ := a.preparePrompt(msgs)
507
508	genCtx, cancel := context.WithCancel(ctx)
509	a.activeRequests.Set(sessionID, cancel)
510	defer a.activeRequests.Del(sessionID)
511	defer cancel()
512
513	agent := fantasy.NewAgent(a.largeModel.Model,
514		fantasy.WithSystemPrompt(string(summaryPrompt)),
515	)
516	summaryMessage, err := a.messages.Create(ctx, sessionID, message.CreateMessageParams{
517		Role:             message.Assistant,
518		Model:            a.largeModel.Model.Model(),
519		Provider:         a.largeModel.Model.Provider(),
520		IsSummaryMessage: true,
521	})
522	if err != nil {
523		return err
524	}
525
526	resp, err := agent.Stream(genCtx, fantasy.AgentStreamCall{
527		Prompt:          "Provide a detailed summary of our conversation above.",
528		Messages:        aiMsgs,
529		ProviderOptions: opts,
530		OnReasoningDelta: func(id string, text string) error {
531			summaryMessage.AppendReasoningContent(text)
532			return a.messages.Update(genCtx, summaryMessage)
533		},
534		OnReasoningEnd: func(id string, reasoning fantasy.ReasoningContent) error {
535			// handle anthropic signature
536			if anthropicData, ok := reasoning.ProviderMetadata["anthropic"]; ok {
537				if signature, ok := anthropicData.(*anthropic.ReasoningOptionMetadata); ok && signature.Signature != "" {
538					summaryMessage.AppendReasoningSignature(signature.Signature)
539				}
540			}
541			summaryMessage.FinishThinking()
542			return a.messages.Update(genCtx, summaryMessage)
543		},
544		OnTextDelta: func(id, text string) error {
545			summaryMessage.AppendContent(text)
546			return a.messages.Update(genCtx, summaryMessage)
547		},
548	})
549	if err != nil {
550		isCancelErr := errors.Is(err, context.Canceled)
551		if isCancelErr {
552			// User cancelled summarize we need to remove the summary message
553			deleteErr := a.messages.Delete(ctx, summaryMessage.ID)
554			return deleteErr
555		}
556		return err
557	}
558
559	summaryMessage.AddFinish(message.FinishReasonEndTurn, "", "")
560	err = a.messages.Update(genCtx, summaryMessage)
561	if err != nil {
562		return err
563	}
564
565	a.updateSessionUsage(a.largeModel, &currentSession, resp.TotalUsage)
566
567	// just in case get just the last usage
568	usage := resp.Response.Usage
569	currentSession.SummaryMessageID = summaryMessage.ID
570	currentSession.CompletionTokens = usage.OutputTokens
571	currentSession.PromptTokens = 0
572	_, err = a.sessions.Save(genCtx, currentSession)
573	return err
574}
575
576func (a *sessionAgent) getCacheControlOptions() fantasy.ProviderOptions {
577	if t, _ := strconv.ParseBool(os.Getenv("CRUSH_DISABLE_ANTHROPIC_CACHE")); t {
578		return fantasy.ProviderOptions{}
579	}
580	return fantasy.ProviderOptions{
581		anthropic.Name: &anthropic.ProviderCacheControlOptions{
582			CacheControl: anthropic.CacheControl{Type: "ephemeral"},
583		},
584		bedrock.Name: &anthropic.ProviderCacheControlOptions{
585			CacheControl: anthropic.CacheControl{Type: "ephemeral"},
586		},
587	}
588}
589
590func (a *sessionAgent) createUserMessage(ctx context.Context, call SessionAgentCall) (message.Message, error) {
591	var attachmentParts []message.ContentPart
592	for _, attachment := range call.Attachments {
593		attachmentParts = append(attachmentParts, message.BinaryContent{Path: attachment.FilePath, MIMEType: attachment.MimeType, Data: attachment.Content})
594	}
595	parts := []message.ContentPart{message.TextContent{Text: call.Prompt}}
596	parts = append(parts, attachmentParts...)
597	msg, err := a.messages.Create(ctx, call.SessionID, message.CreateMessageParams{
598		Role:  message.User,
599		Parts: parts,
600	})
601	if err != nil {
602		return message.Message{}, fmt.Errorf("failed to create user message: %w", err)
603	}
604	return msg, nil
605}
606
607func (a *sessionAgent) preparePrompt(msgs []message.Message, attachments ...message.Attachment) ([]fantasy.Message, []fantasy.FilePart) {
608	var history []fantasy.Message
609	for _, m := range msgs {
610		if len(m.Parts) == 0 {
611			continue
612		}
613		// Assistant message without content or tool calls (cancelled before it returned anything)
614		if m.Role == message.Assistant && len(m.ToolCalls()) == 0 && m.Content().Text == "" && m.ReasoningContent().String() == "" {
615			continue
616		}
617		history = append(history, m.ToAIMessage()...)
618	}
619
620	var files []fantasy.FilePart
621	for _, attachment := range attachments {
622		files = append(files, fantasy.FilePart{
623			Filename:  attachment.FileName,
624			Data:      attachment.Content,
625			MediaType: attachment.MimeType,
626		})
627	}
628
629	return history, files
630}
631
632func (a *sessionAgent) getSessionMessages(ctx context.Context, session session.Session) ([]message.Message, error) {
633	msgs, err := a.messages.List(ctx, session.ID)
634	if err != nil {
635		return nil, fmt.Errorf("failed to list messages: %w", err)
636	}
637
638	if session.SummaryMessageID != "" {
639		summaryMsgInex := -1
640		for i, msg := range msgs {
641			if msg.ID == session.SummaryMessageID {
642				summaryMsgInex = i
643				break
644			}
645		}
646		if summaryMsgInex != -1 {
647			msgs = msgs[summaryMsgInex:]
648			msgs[0].Role = message.User
649		}
650	}
651	return msgs, nil
652}
653
654func (a *sessionAgent) generateTitle(ctx context.Context, session *session.Session, prompt string) {
655	if prompt == "" {
656		return
657	}
658
659	var maxOutput int64 = 40
660	if a.smallModel.CatwalkCfg.CanReason {
661		maxOutput = a.smallModel.CatwalkCfg.DefaultMaxTokens
662	}
663
664	agent := fantasy.NewAgent(a.smallModel.Model,
665		fantasy.WithSystemPrompt(string(titlePrompt)+"\n /no_think"),
666		fantasy.WithMaxOutputTokens(maxOutput),
667	)
668
669	resp, err := agent.Stream(ctx, fantasy.AgentStreamCall{
670		Prompt: fmt.Sprintf("Generate a concise title for the following content:\n\n%s\n <think>\n\n</think>", prompt),
671	})
672	if err != nil {
673		slog.Error("error generating title", "err", err)
674		return
675	}
676
677	title := resp.Response.Content.Text()
678
679	title = strings.ReplaceAll(title, "\n", " ")
680
681	// remove thinking tags if present
682	if idx := strings.Index(title, "</think>"); idx > 0 {
683		title = title[idx+len("</think>"):]
684	}
685
686	title = strings.TrimSpace(title)
687	if title == "" {
688		slog.Warn("failed to generate title", "warn", "empty title")
689		return
690	}
691
692	session.Title = title
693	a.updateSessionUsage(a.smallModel, session, resp.TotalUsage)
694	_, saveErr := a.sessions.Save(ctx, *session)
695	if saveErr != nil {
696		slog.Error("failed to save session title & usage", "error", saveErr)
697		return
698	}
699}
700
701func (a *sessionAgent) updateSessionUsage(model Model, session *session.Session, usage fantasy.Usage) {
702	modelConfig := model.CatwalkCfg
703	cost := modelConfig.CostPer1MInCached/1e6*float64(usage.CacheCreationTokens) +
704		modelConfig.CostPer1MOutCached/1e6*float64(usage.CacheReadTokens) +
705		modelConfig.CostPer1MIn/1e6*float64(usage.InputTokens) +
706		modelConfig.CostPer1MOut/1e6*float64(usage.OutputTokens)
707
708	a.eventTokensUsed(session.ID, model, usage, cost)
709
710	session.Cost += cost
711	session.CompletionTokens = usage.OutputTokens + usage.CacheReadTokens
712	session.PromptTokens = usage.InputTokens + usage.CacheCreationTokens
713}
714
715func (a *sessionAgent) Cancel(sessionID string) {
716	// Cancel regular requests
717	if cancel, ok := a.activeRequests.Take(sessionID); ok && cancel != nil {
718		slog.Info("Request cancellation initiated", "session_id", sessionID)
719		cancel()
720	}
721
722	// Also check for summarize requests
723	if cancel, ok := a.activeRequests.Take(sessionID + "-summarize"); ok && cancel != nil {
724		slog.Info("Summarize cancellation initiated", "session_id", sessionID)
725		cancel()
726	}
727
728	if a.QueuedPrompts(sessionID) > 0 {
729		slog.Info("Clearing queued prompts", "session_id", sessionID)
730		a.messageQueue.Del(sessionID)
731	}
732}
733
734func (a *sessionAgent) ClearQueue(sessionID string) {
735	if a.QueuedPrompts(sessionID) > 0 {
736		slog.Info("Clearing queued prompts", "session_id", sessionID)
737		a.messageQueue.Del(sessionID)
738	}
739}
740
741func (a *sessionAgent) CancelAll() {
742	if !a.IsBusy() {
743		return
744	}
745	for key := range a.activeRequests.Seq2() {
746		a.Cancel(key) // key is sessionID
747	}
748
749	timeout := time.After(5 * time.Second)
750	for a.IsBusy() {
751		select {
752		case <-timeout:
753			return
754		default:
755			time.Sleep(200 * time.Millisecond)
756		}
757	}
758}
759
760func (a *sessionAgent) IsBusy() bool {
761	var busy bool
762	for cancelFunc := range a.activeRequests.Seq() {
763		if cancelFunc != nil {
764			busy = true
765			break
766		}
767	}
768	return busy
769}
770
771func (a *sessionAgent) IsSessionBusy(sessionID string) bool {
772	_, busy := a.activeRequests.Get(sessionID)
773	return busy
774}
775
776func (a *sessionAgent) QueuedPrompts(sessionID string) int {
777	l, ok := a.messageQueue.Get(sessionID)
778	if !ok {
779		return 0
780	}
781	return len(l)
782}
783
784func (a *sessionAgent) SetModels(large Model, small Model) {
785	a.largeModel = large
786	a.smallModel = small
787}
788
789func (a *sessionAgent) SetTools(tools []fantasy.AgentTool) {
790	a.tools = tools
791}
792
793func (a *sessionAgent) Model() Model {
794	return a.largeModel
795}