agent.go

  1// Package agent is the core orchestration layer for Crush AI agents.
  2//
  3// It provides session-based AI agent functionality for managing
  4// conversations, tool execution, and message handling. It coordinates
  5// interactions between language models, messages, sessions, and tools while
  6// handling features like automatic summarization, queuing, and token
  7// management.
  8package agent
  9
 10import (
 11	"context"
 12	_ "embed"
 13	"errors"
 14	"fmt"
 15	"log/slog"
 16	"os"
 17	"strconv"
 18	"strings"
 19	"sync"
 20	"time"
 21
 22	"charm.land/fantasy"
 23	"charm.land/fantasy/providers/anthropic"
 24	"charm.land/fantasy/providers/bedrock"
 25	"charm.land/fantasy/providers/google"
 26	"charm.land/fantasy/providers/openai"
 27	"charm.land/fantasy/providers/openrouter"
 28	"github.com/charmbracelet/catwalk/pkg/catwalk"
 29	"github.com/charmbracelet/crush/internal/agent/tools"
 30	"github.com/charmbracelet/crush/internal/config"
 31	"github.com/charmbracelet/crush/internal/csync"
 32	"github.com/charmbracelet/crush/internal/message"
 33	"github.com/charmbracelet/crush/internal/permission"
 34	"github.com/charmbracelet/crush/internal/session"
 35)
 36
// titlePrompt holds the embedded contents of templates/title.md. It is used
// as the system prompt when generating a session title.
//
//go:embed templates/title.md
var titlePrompt []byte

// summaryPrompt holds the embedded contents of templates/summary.md. It is
// used as the system prompt when summarizing a session.
//
//go:embed templates/summary.md
var summaryPrompt []byte
 42
// SessionAgentCall carries the inputs for a single SessionAgent.Run
// invocation.
//
// SessionID and Prompt are required: Run rejects a call where either one is
// empty. Attachments are stored on the user message and also sent to the
// model as files. ProviderOptions, MaxOutputTokens, Temperature, TopP, TopK,
// FrequencyPenalty and PresencePenalty are handed to the underlying fantasy
// agent stream call.
type SessionAgentCall struct {
	SessionID        string
	Prompt           string
	ProviderOptions  fantasy.ProviderOptions
	Attachments      []message.Attachment
	MaxOutputTokens  int64
	Temperature      *float64
	TopP             *float64
	TopK             *int64
	FrequencyPenalty *float64
	PresencePenalty  *float64
}
 55
// SessionAgent is a session-scoped AI agent. The method descriptions below
// reflect the implementation returned by NewSessionAgent.
type SessionAgent interface {
	// Run sends the prompt to the large model for the given session. If the
	// session already has an active request the call is queued and Run
	// returns (nil, nil).
	Run(context.Context, SessionAgentCall) (*fantasy.AgentResult, error)
	// SetModels replaces the large and small models.
	SetModels(large Model, small Model)
	// SetTools replaces the tools offered to the model.
	SetTools(tools []fantasy.AgentTool)
	// Cancel cancels the session's active request, if any, and drops the
	// prompts queued for it.
	Cancel(sessionID string)
	// CancelAll cancels every active request and waits up to five seconds
	// for the agent to become idle.
	CancelAll()
	// IsSessionBusy reports whether the session has an active request.
	IsSessionBusy(sessionID string) bool
	// IsBusy reports whether any session has an active request.
	IsBusy() bool
	// QueuedPrompts returns the number of prompts queued for the session.
	QueuedPrompts(sessionID string) int
	// ClearQueue drops the prompts queued for the session.
	ClearQueue(sessionID string)
	// Summarize generates a summary of the session and records it as the
	// session's summary message. It returns ErrSessionBusy when the session
	// has an active request.
	Summarize(context.Context, string, fantasy.ProviderOptions) error
	// Model returns the large model.
	Model() Model
}
 69
// Model bundles a language model with the two pieces of configuration this
// package reads alongside it.
//
// Model is the fantasy language model requests are sent to. CatwalkCfg
// supplies the context window, reasoning capability, default max tokens and
// per-million-token prices used in this file. ModelCfg supplies the model and
// provider identifiers recorded on assistant messages created by Run.
type Model struct {
	Model      fantasy.LanguageModel
	CatwalkCfg catwalk.Model
	ModelCfg   config.SelectedModel
}
 75
// sessionAgent is the SessionAgent implementation returned by
// NewSessionAgent.
//
// largeModel serves Run and Summarize; smallModel serves title generation.
// systemPrompt is the system prompt for Run, and systemPromptPrefix, when
// non-empty, is prepended as an extra system message to every request.
// messageQueue holds, per session ID, the calls that arrived while that
// session was busy. activeRequests holds, per session ID, the cancel function
// of the request currently in flight; a session is considered busy while it
// has an entry there.
type sessionAgent struct {
	largeModel           Model
	smallModel           Model
	systemPromptPrefix   string
	systemPrompt         string
	tools                []fantasy.AgentTool
	sessions             session.Service
	messages             message.Service
	disableAutoSummarize bool
	isYolo               bool

	messageQueue   *csync.Map[string, []SessionAgentCall]
	activeRequests *csync.Map[string, context.CancelFunc]
}
 90
// SessionAgentOptions holds the dependencies and settings NewSessionAgent
// copies into a new agent.
//
// Setting DisableAutoSummarize stops Run from interrupting a request to
// summarize the session when the context window is nearly exhausted.
type SessionAgentOptions struct {
	LargeModel           Model
	SmallModel           Model
	SystemPromptPrefix   string
	SystemPrompt         string
	DisableAutoSummarize bool
	IsYolo               bool
	Sessions             session.Service
	Messages             message.Service
	Tools                []fantasy.AgentTool
}
102
103func NewSessionAgent(
104	opts SessionAgentOptions,
105) SessionAgent {
106	return &sessionAgent{
107		largeModel:           opts.LargeModel,
108		smallModel:           opts.SmallModel,
109		systemPromptPrefix:   opts.SystemPromptPrefix,
110		systemPrompt:         opts.SystemPrompt,
111		sessions:             opts.Sessions,
112		messages:             opts.Messages,
113		disableAutoSummarize: opts.DisableAutoSummarize,
114		tools:                opts.Tools,
115		isYolo:               opts.IsYolo,
116		messageQueue:         csync.NewMap[string, []SessionAgentCall](),
117		activeRequests:       csync.NewMap[string, context.CancelFunc](),
118	}
119}
120
121func (a *sessionAgent) Run(ctx context.Context, call SessionAgentCall) (*fantasy.AgentResult, error) {
122	if call.Prompt == "" {
123		return nil, ErrEmptyPrompt
124	}
125	if call.SessionID == "" {
126		return nil, ErrSessionMissing
127	}
128
129	// Queue the message if busy
130	if a.IsSessionBusy(call.SessionID) {
131		existing, ok := a.messageQueue.Get(call.SessionID)
132		if !ok {
133			existing = []SessionAgentCall{}
134		}
135		existing = append(existing, call)
136		a.messageQueue.Set(call.SessionID, existing)
137		return nil, nil
138	}
139
140	if len(a.tools) > 0 {
141		// Add anthropic caching to the last tool.
142		a.tools[len(a.tools)-1].SetProviderOptions(a.getCacheControlOptions())
143	}
144
145	agent := fantasy.NewAgent(
146		a.largeModel.Model,
147		fantasy.WithSystemPrompt(a.systemPrompt),
148		fantasy.WithTools(a.tools...),
149	)
150
151	sessionLock := sync.Mutex{}
152	currentSession, err := a.sessions.Get(ctx, call.SessionID)
153	if err != nil {
154		return nil, fmt.Errorf("failed to get session: %w", err)
155	}
156
157	msgs, err := a.getSessionMessages(ctx, currentSession)
158	if err != nil {
159		return nil, fmt.Errorf("failed to get session messages: %w", err)
160	}
161
162	var wg sync.WaitGroup
163	// Generate title if first message.
164	if len(msgs) == 0 {
165		wg.Go(func() {
166			sessionLock.Lock()
167			a.generateTitle(ctx, &currentSession, call.Prompt)
168			sessionLock.Unlock()
169		})
170	}
171
172	// Add the user message to the session.
173	_, err = a.createUserMessage(ctx, call)
174	if err != nil {
175		return nil, err
176	}
177
178	// Add the session to the context.
179	ctx = context.WithValue(ctx, tools.SessionIDContextKey, call.SessionID)
180
181	genCtx, cancel := context.WithCancel(ctx)
182	a.activeRequests.Set(call.SessionID, cancel)
183
184	defer cancel()
185	defer a.activeRequests.Del(call.SessionID)
186
187	history, files := a.preparePrompt(msgs, call.Attachments...)
188
189	startTime := time.Now()
190	a.eventPromptSent(call.SessionID)
191
192	var currentAssistant *message.Message
193	var shouldSummarize bool
194	result, err := agent.Stream(genCtx, fantasy.AgentStreamCall{
195		Prompt:           call.Prompt,
196		Files:            files,
197		Messages:         history,
198		ProviderOptions:  call.ProviderOptions,
199		MaxOutputTokens:  &call.MaxOutputTokens,
200		TopP:             call.TopP,
201		Temperature:      call.Temperature,
202		PresencePenalty:  call.PresencePenalty,
203		TopK:             call.TopK,
204		FrequencyPenalty: call.FrequencyPenalty,
205		// Before each step create a new assistant message.
206		PrepareStep: func(callContext context.Context, options fantasy.PrepareStepFunctionOptions) (_ context.Context, prepared fantasy.PrepareStepResult, err error) {
207			prepared.Messages = options.Messages
208			// Reset all cached items.
209			for i := range prepared.Messages {
210				prepared.Messages[i].ProviderOptions = nil
211			}
212
213			queuedCalls, _ := a.messageQueue.Get(call.SessionID)
214			a.messageQueue.Del(call.SessionID)
215			for _, queued := range queuedCalls {
216				userMessage, createErr := a.createUserMessage(callContext, queued)
217				if createErr != nil {
218					return callContext, prepared, createErr
219				}
220				prepared.Messages = append(prepared.Messages, userMessage.ToAIMessage()...)
221			}
222
223			lastSystemRoleInx := 0
224			systemMessageUpdated := false
225			for i, msg := range prepared.Messages {
226				// Only add cache control to the last message.
227				if msg.Role == fantasy.MessageRoleSystem {
228					lastSystemRoleInx = i
229				} else if !systemMessageUpdated {
230					prepared.Messages[lastSystemRoleInx].ProviderOptions = a.getCacheControlOptions()
231					systemMessageUpdated = true
232				}
233				// Than add cache control to the last 2 messages.
234				if i > len(prepared.Messages)-3 {
235					prepared.Messages[i].ProviderOptions = a.getCacheControlOptions()
236				}
237			}
238
239			if a.systemPromptPrefix != "" {
240				prepared.Messages = append([]fantasy.Message{fantasy.NewSystemMessage(a.systemPromptPrefix)}, prepared.Messages...)
241			}
242
243			var assistantMsg message.Message
244			assistantMsg, err = a.messages.Create(callContext, call.SessionID, message.CreateMessageParams{
245				Role:     message.Assistant,
246				Parts:    []message.ContentPart{},
247				Model:    a.largeModel.ModelCfg.Model,
248				Provider: a.largeModel.ModelCfg.Provider,
249			})
250			if err != nil {
251				return callContext, prepared, err
252			}
253			callContext = context.WithValue(callContext, tools.MessageIDContextKey, assistantMsg.ID)
254			currentAssistant = &assistantMsg
255			return callContext, prepared, err
256		},
257		OnReasoningStart: func(id string, reasoning fantasy.ReasoningContent) error {
258			currentAssistant.AppendReasoningContent(reasoning.Text)
259			return a.messages.Update(genCtx, *currentAssistant)
260		},
261		OnReasoningDelta: func(id string, text string) error {
262			currentAssistant.AppendReasoningContent(text)
263			return a.messages.Update(genCtx, *currentAssistant)
264		},
265		OnReasoningEnd: func(id string, reasoning fantasy.ReasoningContent) error {
266			// handle anthropic signature
267			if anthropicData, ok := reasoning.ProviderMetadata[anthropic.Name]; ok {
268				if reasoning, ok := anthropicData.(*anthropic.ReasoningOptionMetadata); ok {
269					currentAssistant.AppendReasoningSignature(reasoning.Signature)
270				}
271			}
272			if googleData, ok := reasoning.ProviderMetadata[google.Name]; ok {
273				if reasoning, ok := googleData.(*google.ReasoningMetadata); ok {
274					currentAssistant.AppendReasoningSignature(reasoning.Signature)
275				}
276			}
277			if openaiData, ok := reasoning.ProviderMetadata[openai.Name]; ok {
278				if reasoning, ok := openaiData.(*openai.ResponsesReasoningMetadata); ok {
279					currentAssistant.SetReasoningResponsesData(reasoning)
280				}
281			}
282			currentAssistant.FinishThinking()
283			return a.messages.Update(genCtx, *currentAssistant)
284		},
285		OnTextDelta: func(id string, text string) error {
286			// Strip leading newline from initial text content. This is is
287			// particularly important in non-interactive mode where leading
288			// newlines are very visible.
289			if len(currentAssistant.Parts) == 0 && strings.HasPrefix(text, "\n") {
290				text = strings.TrimPrefix(text, "\n")
291			}
292
293			currentAssistant.AppendContent(text)
294			return a.messages.Update(genCtx, *currentAssistant)
295		},
296		OnToolInputStart: func(id string, toolName string) error {
297			toolCall := message.ToolCall{
298				ID:               id,
299				Name:             toolName,
300				ProviderExecuted: false,
301				Finished:         false,
302			}
303			currentAssistant.AddToolCall(toolCall)
304			return a.messages.Update(genCtx, *currentAssistant)
305		},
306		OnRetry: func(err *fantasy.APICallError, delay time.Duration) {
307			// TODO: implement
308		},
309		OnToolCall: func(tc fantasy.ToolCallContent) error {
310			toolCall := message.ToolCall{
311				ID:               tc.ToolCallID,
312				Name:             tc.ToolName,
313				Input:            tc.Input,
314				ProviderExecuted: false,
315				Finished:         true,
316			}
317			currentAssistant.AddToolCall(toolCall)
318			return a.messages.Update(genCtx, *currentAssistant)
319		},
320		OnToolResult: func(result fantasy.ToolResultContent) error {
321			var resultContent string
322			isError := false
323			switch result.Result.GetType() {
324			case fantasy.ToolResultContentTypeText:
325				r, ok := fantasy.AsToolResultOutputType[fantasy.ToolResultOutputContentText](result.Result)
326				if ok {
327					resultContent = r.Text
328				}
329			case fantasy.ToolResultContentTypeError:
330				r, ok := fantasy.AsToolResultOutputType[fantasy.ToolResultOutputContentError](result.Result)
331				if ok {
332					isError = true
333					resultContent = r.Error.Error()
334				}
335			case fantasy.ToolResultContentTypeMedia:
336				// TODO: handle this message type
337			}
338			toolResult := message.ToolResult{
339				ToolCallID: result.ToolCallID,
340				Name:       result.ToolName,
341				Content:    resultContent,
342				IsError:    isError,
343				Metadata:   result.ClientMetadata,
344			}
345			_, createMsgErr := a.messages.Create(genCtx, currentAssistant.SessionID, message.CreateMessageParams{
346				Role: message.Tool,
347				Parts: []message.ContentPart{
348					toolResult,
349				},
350			})
351			if createMsgErr != nil {
352				return createMsgErr
353			}
354			return nil
355		},
356		OnStepFinish: func(stepResult fantasy.StepResult) error {
357			finishReason := message.FinishReasonUnknown
358			switch stepResult.FinishReason {
359			case fantasy.FinishReasonLength:
360				finishReason = message.FinishReasonMaxTokens
361			case fantasy.FinishReasonStop:
362				finishReason = message.FinishReasonEndTurn
363			case fantasy.FinishReasonToolCalls:
364				finishReason = message.FinishReasonToolUse
365			}
366			currentAssistant.AddFinish(finishReason, "", "")
367			a.updateSessionUsage(a.largeModel, &currentSession, stepResult.Usage, a.openrouterCost(stepResult.ProviderMetadata))
368			sessionLock.Lock()
369			_, sessionErr := a.sessions.Save(genCtx, currentSession)
370			sessionLock.Unlock()
371			if sessionErr != nil {
372				return sessionErr
373			}
374			return a.messages.Update(genCtx, *currentAssistant)
375		},
376		StopWhen: []fantasy.StopCondition{
377			func(_ []fantasy.StepResult) bool {
378				cw := int64(a.largeModel.CatwalkCfg.ContextWindow)
379				tokens := currentSession.CompletionTokens + currentSession.PromptTokens
380				remaining := cw - tokens
381				var threshold int64
382				if cw > 200_000 {
383					threshold = 20_000
384				} else {
385					threshold = int64(float64(cw) * 0.2)
386				}
387				if (remaining <= threshold) && !a.disableAutoSummarize {
388					shouldSummarize = true
389					return true
390				}
391				return false
392			},
393		},
394	})
395
396	a.eventPromptResponded(call.SessionID, time.Since(startTime).Truncate(time.Second))
397
398	if err != nil {
399		isCancelErr := errors.Is(err, context.Canceled)
400		isPermissionErr := errors.Is(err, permission.ErrorPermissionDenied)
401		if currentAssistant == nil {
402			return result, err
403		}
404		// Ensure we finish thinking on error to close the reasoning state.
405		currentAssistant.FinishThinking()
406		toolCalls := currentAssistant.ToolCalls()
407		// INFO: we use the parent context here because the genCtx has been cancelled.
408		msgs, createErr := a.messages.List(ctx, currentAssistant.SessionID)
409		if createErr != nil {
410			return nil, createErr
411		}
412		for _, tc := range toolCalls {
413			if !tc.Finished {
414				tc.Finished = true
415				tc.Input = "{}"
416				currentAssistant.AddToolCall(tc)
417				updateErr := a.messages.Update(ctx, *currentAssistant)
418				if updateErr != nil {
419					return nil, updateErr
420				}
421			}
422
423			found := false
424			for _, msg := range msgs {
425				if msg.Role == message.Tool {
426					for _, tr := range msg.ToolResults() {
427						if tr.ToolCallID == tc.ID {
428							found = true
429							break
430						}
431					}
432				}
433				if found {
434					break
435				}
436			}
437			if found {
438				continue
439			}
440			content := "There was an error while executing the tool"
441			if isCancelErr {
442				content = "Tool execution canceled by user"
443			} else if isPermissionErr {
444				content = "Permission denied"
445			}
446			toolResult := message.ToolResult{
447				ToolCallID: tc.ID,
448				Name:       tc.Name,
449				Content:    content,
450				IsError:    true,
451			}
452			_, createErr = a.messages.Create(context.Background(), currentAssistant.SessionID, message.CreateMessageParams{
453				Role: message.Tool,
454				Parts: []message.ContentPart{
455					toolResult,
456				},
457			})
458			if createErr != nil {
459				return nil, createErr
460			}
461		}
462		if isCancelErr {
463			currentAssistant.AddFinish(message.FinishReasonCanceled, "Request cancelled", "")
464		} else if isPermissionErr {
465			currentAssistant.AddFinish(message.FinishReasonPermissionDenied, "Permission denied", "")
466		} else {
467			currentAssistant.AddFinish(message.FinishReasonError, "API Error", err.Error())
468		}
469		// Note: we use the parent context here because the genCtx has been
470		// cancelled.
471		updateErr := a.messages.Update(ctx, *currentAssistant)
472		if updateErr != nil {
473			return nil, updateErr
474		}
475		return nil, err
476	}
477	wg.Wait()
478
479	if shouldSummarize {
480		a.activeRequests.Del(call.SessionID)
481		if summarizeErr := a.Summarize(genCtx, call.SessionID, call.ProviderOptions); summarizeErr != nil {
482			return nil, summarizeErr
483		}
484		// If the agent wasn't done...
485		if len(currentAssistant.ToolCalls()) > 0 {
486			existing, ok := a.messageQueue.Get(call.SessionID)
487			if !ok {
488				existing = []SessionAgentCall{}
489			}
490			call.Prompt = fmt.Sprintf("The previous session was interrupted because it got too long, the initial user request was: `%s`", call.Prompt)
491			existing = append(existing, call)
492			a.messageQueue.Set(call.SessionID, existing)
493		}
494	}
495
496	// Release active request before processing queued messages.
497	a.activeRequests.Del(call.SessionID)
498	cancel()
499
500	queuedMessages, ok := a.messageQueue.Get(call.SessionID)
501	if !ok || len(queuedMessages) == 0 {
502		return result, err
503	}
504	// There are queued messages restart the loop.
505	firstQueuedMessage := queuedMessages[0]
506	a.messageQueue.Set(call.SessionID, queuedMessages[1:])
507	return a.Run(ctx, firstQueuedMessage)
508}
509
510func (a *sessionAgent) Summarize(ctx context.Context, sessionID string, opts fantasy.ProviderOptions) error {
511	if a.IsSessionBusy(sessionID) {
512		return ErrSessionBusy
513	}
514
515	currentSession, err := a.sessions.Get(ctx, sessionID)
516	if err != nil {
517		return fmt.Errorf("failed to get session: %w", err)
518	}
519	msgs, err := a.getSessionMessages(ctx, currentSession)
520	if err != nil {
521		return err
522	}
523	if len(msgs) == 0 {
524		// Nothing to summarize.
525		return nil
526	}
527
528	aiMsgs, _ := a.preparePrompt(msgs)
529
530	genCtx, cancel := context.WithCancel(ctx)
531	a.activeRequests.Set(sessionID, cancel)
532	defer a.activeRequests.Del(sessionID)
533	defer cancel()
534
535	agent := fantasy.NewAgent(a.largeModel.Model,
536		fantasy.WithSystemPrompt(string(summaryPrompt)),
537	)
538	summaryMessage, err := a.messages.Create(ctx, sessionID, message.CreateMessageParams{
539		Role:             message.Assistant,
540		Model:            a.largeModel.Model.Model(),
541		Provider:         a.largeModel.Model.Provider(),
542		IsSummaryMessage: true,
543	})
544	if err != nil {
545		return err
546	}
547
548	resp, err := agent.Stream(genCtx, fantasy.AgentStreamCall{
549		Prompt:          "Provide a detailed summary of our conversation above.",
550		Messages:        aiMsgs,
551		ProviderOptions: opts,
552		PrepareStep: func(callContext context.Context, options fantasy.PrepareStepFunctionOptions) (_ context.Context, prepared fantasy.PrepareStepResult, err error) {
553			prepared.Messages = options.Messages
554			if a.systemPromptPrefix != "" {
555				prepared.Messages = append([]fantasy.Message{fantasy.NewSystemMessage(a.systemPromptPrefix)}, prepared.Messages...)
556			}
557			return callContext, prepared, nil
558		},
559		OnReasoningDelta: func(id string, text string) error {
560			summaryMessage.AppendReasoningContent(text)
561			return a.messages.Update(genCtx, summaryMessage)
562		},
563		OnReasoningEnd: func(id string, reasoning fantasy.ReasoningContent) error {
564			// Handle anthropic signature.
565			if anthropicData, ok := reasoning.ProviderMetadata["anthropic"]; ok {
566				if signature, ok := anthropicData.(*anthropic.ReasoningOptionMetadata); ok && signature.Signature != "" {
567					summaryMessage.AppendReasoningSignature(signature.Signature)
568				}
569			}
570			summaryMessage.FinishThinking()
571			return a.messages.Update(genCtx, summaryMessage)
572		},
573		OnTextDelta: func(id, text string) error {
574			summaryMessage.AppendContent(text)
575			return a.messages.Update(genCtx, summaryMessage)
576		},
577	})
578	if err != nil {
579		isCancelErr := errors.Is(err, context.Canceled)
580		if isCancelErr {
581			// User cancelled summarize we need to remove the summary message.
582			deleteErr := a.messages.Delete(ctx, summaryMessage.ID)
583			return deleteErr
584		}
585		return err
586	}
587
588	summaryMessage.AddFinish(message.FinishReasonEndTurn, "", "")
589	err = a.messages.Update(genCtx, summaryMessage)
590	if err != nil {
591		return err
592	}
593
594	var openrouterCost *float64
595	for _, step := range resp.Steps {
596		stepCost := a.openrouterCost(step.ProviderMetadata)
597		if stepCost != nil {
598			newCost := *stepCost
599			if openrouterCost != nil {
600				newCost += *openrouterCost
601			}
602			openrouterCost = &newCost
603		}
604	}
605
606	a.updateSessionUsage(a.largeModel, &currentSession, resp.TotalUsage, openrouterCost)
607
608	// Just in case, get just the last usage info.
609	usage := resp.Response.Usage
610	currentSession.SummaryMessageID = summaryMessage.ID
611	currentSession.CompletionTokens = usage.OutputTokens
612	currentSession.PromptTokens = 0
613	_, err = a.sessions.Save(genCtx, currentSession)
614	return err
615}
616
617func (a *sessionAgent) getCacheControlOptions() fantasy.ProviderOptions {
618	if t, _ := strconv.ParseBool(os.Getenv("CRUSH_DISABLE_ANTHROPIC_CACHE")); t {
619		return fantasy.ProviderOptions{}
620	}
621	return fantasy.ProviderOptions{
622		anthropic.Name: &anthropic.ProviderCacheControlOptions{
623			CacheControl: anthropic.CacheControl{Type: "ephemeral"},
624		},
625		bedrock.Name: &anthropic.ProviderCacheControlOptions{
626			CacheControl: anthropic.CacheControl{Type: "ephemeral"},
627		},
628	}
629}
630
631func (a *sessionAgent) createUserMessage(ctx context.Context, call SessionAgentCall) (message.Message, error) {
632	var attachmentParts []message.ContentPart
633	for _, attachment := range call.Attachments {
634		attachmentParts = append(attachmentParts, message.BinaryContent{Path: attachment.FilePath, MIMEType: attachment.MimeType, Data: attachment.Content})
635	}
636	parts := []message.ContentPart{message.TextContent{Text: call.Prompt}}
637	parts = append(parts, attachmentParts...)
638	msg, err := a.messages.Create(ctx, call.SessionID, message.CreateMessageParams{
639		Role:  message.User,
640		Parts: parts,
641	})
642	if err != nil {
643		return message.Message{}, fmt.Errorf("failed to create user message: %w", err)
644	}
645	return msg, nil
646}
647
648func (a *sessionAgent) preparePrompt(msgs []message.Message, attachments ...message.Attachment) ([]fantasy.Message, []fantasy.FilePart) {
649	var history []fantasy.Message
650	for _, m := range msgs {
651		if len(m.Parts) == 0 {
652			continue
653		}
654		// Assistant message without content or tool calls (cancelled before it
655		// returned anything).
656		if m.Role == message.Assistant && len(m.ToolCalls()) == 0 && m.Content().Text == "" && m.ReasoningContent().String() == "" {
657			continue
658		}
659		history = append(history, m.ToAIMessage()...)
660	}
661
662	var files []fantasy.FilePart
663	for _, attachment := range attachments {
664		files = append(files, fantasy.FilePart{
665			Filename:  attachment.FileName,
666			Data:      attachment.Content,
667			MediaType: attachment.MimeType,
668		})
669	}
670
671	return history, files
672}
673
674func (a *sessionAgent) getSessionMessages(ctx context.Context, session session.Session) ([]message.Message, error) {
675	msgs, err := a.messages.List(ctx, session.ID)
676	if err != nil {
677		return nil, fmt.Errorf("failed to list messages: %w", err)
678	}
679
680	if session.SummaryMessageID != "" {
681		summaryMsgInex := -1
682		for i, msg := range msgs {
683			if msg.ID == session.SummaryMessageID {
684				summaryMsgInex = i
685				break
686			}
687		}
688		if summaryMsgInex != -1 {
689			msgs = msgs[summaryMsgInex:]
690			msgs[0].Role = message.User
691		}
692	}
693	return msgs, nil
694}
695
696func (a *sessionAgent) generateTitle(ctx context.Context, session *session.Session, prompt string) {
697	if prompt == "" {
698		return
699	}
700
701	var maxOutput int64 = 40
702	if a.smallModel.CatwalkCfg.CanReason {
703		maxOutput = a.smallModel.CatwalkCfg.DefaultMaxTokens
704	}
705
706	agent := fantasy.NewAgent(a.smallModel.Model,
707		fantasy.WithSystemPrompt(string(titlePrompt)+"\n /no_think"),
708		fantasy.WithMaxOutputTokens(maxOutput),
709	)
710
711	resp, err := agent.Stream(ctx, fantasy.AgentStreamCall{
712		Prompt: fmt.Sprintf("Generate a concise title for the following content:\n\n%s\n <think>\n\n</think>", prompt),
713		PrepareStep: func(callContext context.Context, options fantasy.PrepareStepFunctionOptions) (_ context.Context, prepared fantasy.PrepareStepResult, err error) {
714			prepared.Messages = options.Messages
715			if a.systemPromptPrefix != "" {
716				prepared.Messages = append([]fantasy.Message{fantasy.NewSystemMessage(a.systemPromptPrefix)}, prepared.Messages...)
717			}
718			return callContext, prepared, nil
719		},
720	})
721	if err != nil {
722		slog.Error("error generating title", "err", err)
723		return
724	}
725
726	title := resp.Response.Content.Text()
727
728	title = strings.ReplaceAll(title, "\n", " ")
729
730	// Remove thinking tags if present.
731	if idx := strings.Index(title, "</think>"); idx > 0 {
732		title = title[idx+len("</think>"):]
733	}
734
735	title = strings.TrimSpace(title)
736	if title == "" {
737		slog.Warn("failed to generate title", "warn", "empty title")
738		return
739	}
740
741	session.Title = title
742
743	var openrouterCost *float64
744	for _, step := range resp.Steps {
745		stepCost := a.openrouterCost(step.ProviderMetadata)
746		if stepCost != nil {
747			newCost := *stepCost
748			if openrouterCost != nil {
749				newCost += *openrouterCost
750			}
751			openrouterCost = &newCost
752		}
753	}
754
755	a.updateSessionUsage(a.smallModel, session, resp.TotalUsage, openrouterCost)
756	_, saveErr := a.sessions.Save(ctx, *session)
757	if saveErr != nil {
758		slog.Error("failed to save session title & usage", "error", saveErr)
759		return
760	}
761}
762
763func (a *sessionAgent) openrouterCost(metadata fantasy.ProviderMetadata) *float64 {
764	openrouterMetadata, ok := metadata[openrouter.Name]
765	if !ok {
766		return nil
767	}
768
769	opts, ok := openrouterMetadata.(*openrouter.ProviderMetadata)
770	if !ok {
771		return nil
772	}
773	return &opts.Usage.Cost
774}
775
776func (a *sessionAgent) updateSessionUsage(model Model, session *session.Session, usage fantasy.Usage, overrideCost *float64) {
777	modelConfig := model.CatwalkCfg
778	cost := modelConfig.CostPer1MInCached/1e6*float64(usage.CacheCreationTokens) +
779		modelConfig.CostPer1MOutCached/1e6*float64(usage.CacheReadTokens) +
780		modelConfig.CostPer1MIn/1e6*float64(usage.InputTokens) +
781		modelConfig.CostPer1MOut/1e6*float64(usage.OutputTokens)
782
783	a.eventTokensUsed(session.ID, model, usage, cost)
784
785	if overrideCost != nil {
786		session.Cost += *overrideCost
787	} else {
788		session.Cost += cost
789	}
790
791	session.CompletionTokens = usage.OutputTokens + usage.CacheReadTokens
792	session.PromptTokens = usage.InputTokens + usage.CacheCreationTokens
793}
794
795func (a *sessionAgent) Cancel(sessionID string) {
796	// Cancel regular requests.
797	if cancel, ok := a.activeRequests.Take(sessionID); ok && cancel != nil {
798		slog.Info("Request cancellation initiated", "session_id", sessionID)
799		cancel()
800	}
801
802	// Also check for summarize requests.
803	if cancel, ok := a.activeRequests.Take(sessionID + "-summarize"); ok && cancel != nil {
804		slog.Info("Summarize cancellation initiated", "session_id", sessionID)
805		cancel()
806	}
807
808	if a.QueuedPrompts(sessionID) > 0 {
809		slog.Info("Clearing queued prompts", "session_id", sessionID)
810		a.messageQueue.Del(sessionID)
811	}
812}
813
814func (a *sessionAgent) ClearQueue(sessionID string) {
815	if a.QueuedPrompts(sessionID) > 0 {
816		slog.Info("Clearing queued prompts", "session_id", sessionID)
817		a.messageQueue.Del(sessionID)
818	}
819}
820
821func (a *sessionAgent) CancelAll() {
822	if !a.IsBusy() {
823		return
824	}
825	for key := range a.activeRequests.Seq2() {
826		a.Cancel(key) // key is sessionID
827	}
828
829	timeout := time.After(5 * time.Second)
830	for a.IsBusy() {
831		select {
832		case <-timeout:
833			return
834		default:
835			time.Sleep(200 * time.Millisecond)
836		}
837	}
838}
839
840func (a *sessionAgent) IsBusy() bool {
841	var busy bool
842	for cancelFunc := range a.activeRequests.Seq() {
843		if cancelFunc != nil {
844			busy = true
845			break
846		}
847	}
848	return busy
849}
850
851func (a *sessionAgent) IsSessionBusy(sessionID string) bool {
852	_, busy := a.activeRequests.Get(sessionID)
853	return busy
854}
855
856func (a *sessionAgent) QueuedPrompts(sessionID string) int {
857	l, ok := a.messageQueue.Get(sessionID)
858	if !ok {
859		return 0
860	}
861	return len(l)
862}
863
864func (a *sessionAgent) SetModels(large Model, small Model) {
865	a.largeModel = large
866	a.smallModel = small
867}
868
// SetTools replaces the agent's tools. The slice is stored as given, not
// copied, and Run sets provider options on its last element.
func (a *sessionAgent) SetTools(tools []fantasy.AgentTool) {
	a.tools = tools
}
872
// Model returns the agent's large model.
func (a *sessionAgent) Model() Model {
	return a.largeModel
}