agent.go

  1package agent
  2
  3import (
  4	"context"
  5	_ "embed"
  6	"errors"
  7	"fmt"
  8	"log/slog"
  9	"os"
 10	"strconv"
 11	"strings"
 12	"sync"
 13	"time"
 14
 15	"charm.land/fantasy"
 16	"charm.land/fantasy/providers/anthropic"
 17	"charm.land/fantasy/providers/bedrock"
 18	"charm.land/fantasy/providers/google"
 19	"charm.land/fantasy/providers/openai"
 20	"charm.land/fantasy/providers/openrouter"
 21	"github.com/charmbracelet/catwalk/pkg/catwalk"
 22	"github.com/charmbracelet/crush/internal/agent/tools"
 23	"github.com/charmbracelet/crush/internal/config"
 24	"github.com/charmbracelet/crush/internal/csync"
 25	"github.com/charmbracelet/crush/internal/message"
 26	"github.com/charmbracelet/crush/internal/permission"
 27	"github.com/charmbracelet/crush/internal/session"
 28)
 29
 30//go:embed templates/title.md
 31var titlePrompt []byte
 32
 33//go:embed templates/summary.md
 34var summaryPrompt []byte
 35
 36type SessionAgentCall struct {
 37	SessionID        string
 38	Prompt           string
 39	ProviderOptions  fantasy.ProviderOptions
 40	Attachments      []message.Attachment
 41	MaxOutputTokens  int64
 42	Temperature      *float64
 43	TopP             *float64
 44	TopK             *int64
 45	FrequencyPenalty *float64
 46	PresencePenalty  *float64
 47}
 48
 49type SessionAgent interface {
 50	Run(context.Context, SessionAgentCall) (*fantasy.AgentResult, error)
 51	SetModels(large Model, small Model)
 52	SetTools(tools []fantasy.AgentTool)
 53	Cancel(sessionID string)
 54	CancelAll()
 55	IsSessionBusy(sessionID string) bool
 56	IsBusy() bool
 57	QueuedPrompts(sessionID string) int
 58	ClearQueue(sessionID string)
 59	Summarize(context.Context, string, fantasy.ProviderOptions) error
 60	Model() Model
 61}
 62
// Model bundles a language model client with the configuration it was
// selected from.
type Model struct {
	// Model is the client used to issue requests.
	Model      fantasy.LanguageModel
	// CatwalkCfg is catalog metadata for the model; this file reads its
	// context window, per-million-token costs, CanReason and
	// DefaultMaxTokens fields.
	CatwalkCfg catwalk.Model
	// ModelCfg is the selected model (its Model and Provider fields are
	// recorded on assistant messages).
	ModelCfg   config.SelectedModel
}
 68
 69type sessionAgent struct {
 70	largeModel           Model
 71	smallModel           Model
 72	systemPromptPrefix   string
 73	systemPrompt         string
 74	tools                []fantasy.AgentTool
 75	sessions             session.Service
 76	messages             message.Service
 77	disableAutoSummarize bool
 78	isYolo               bool
 79
 80	messageQueue   *csync.Map[string, []SessionAgentCall]
 81	activeRequests *csync.Map[string, context.CancelFunc]
 82}
 83
 84type SessionAgentOptions struct {
 85	LargeModel           Model
 86	SmallModel           Model
 87	SystemPromptPrefix   string
 88	SystemPrompt         string
 89	DisableAutoSummarize bool
 90	IsYolo               bool
 91	Sessions             session.Service
 92	Messages             message.Service
 93	Tools                []fantasy.AgentTool
 94}
 95
 96func NewSessionAgent(
 97	opts SessionAgentOptions,
 98) SessionAgent {
 99	return &sessionAgent{
100		largeModel:           opts.LargeModel,
101		smallModel:           opts.SmallModel,
102		systemPromptPrefix:   opts.SystemPromptPrefix,
103		systemPrompt:         opts.SystemPrompt,
104		sessions:             opts.Sessions,
105		messages:             opts.Messages,
106		disableAutoSummarize: opts.DisableAutoSummarize,
107		tools:                opts.Tools,
108		isYolo:               opts.IsYolo,
109		messageQueue:         csync.NewMap[string, []SessionAgentCall](),
110		activeRequests:       csync.NewMap[string, context.CancelFunc](),
111	}
112}
113
114func (a *sessionAgent) Run(ctx context.Context, call SessionAgentCall) (*fantasy.AgentResult, error) {
115	if call.Prompt == "" {
116		return nil, ErrEmptyPrompt
117	}
118	if call.SessionID == "" {
119		return nil, ErrSessionMissing
120	}
121
122	// Queue the message if busy
123	if a.IsSessionBusy(call.SessionID) {
124		existing, ok := a.messageQueue.Get(call.SessionID)
125		if !ok {
126			existing = []SessionAgentCall{}
127		}
128		existing = append(existing, call)
129		a.messageQueue.Set(call.SessionID, existing)
130		return nil, nil
131	}
132
133	if len(a.tools) > 0 {
134		// Add anthropic caching to the last tool.
135		a.tools[len(a.tools)-1].SetProviderOptions(a.getCacheControlOptions())
136	}
137
138	agent := fantasy.NewAgent(
139		a.largeModel.Model,
140		fantasy.WithSystemPrompt(a.systemPrompt),
141		fantasy.WithTools(a.tools...),
142	)
143
144	sessionLock := sync.Mutex{}
145	currentSession, err := a.sessions.Get(ctx, call.SessionID)
146	if err != nil {
147		return nil, fmt.Errorf("failed to get session: %w", err)
148	}
149
150	msgs, err := a.getSessionMessages(ctx, currentSession)
151	if err != nil {
152		return nil, fmt.Errorf("failed to get session messages: %w", err)
153	}
154
155	var wg sync.WaitGroup
156	// Generate title if first message.
157	if len(msgs) == 0 {
158		wg.Go(func() {
159			sessionLock.Lock()
160			a.generateTitle(ctx, &currentSession, call.Prompt)
161			sessionLock.Unlock()
162		})
163	}
164
165	// Add the user message to the session.
166	_, err = a.createUserMessage(ctx, call)
167	if err != nil {
168		return nil, err
169	}
170
171	// Add the session to the context.
172	ctx = context.WithValue(ctx, tools.SessionIDContextKey, call.SessionID)
173
174	genCtx, cancel := context.WithCancel(ctx)
175	a.activeRequests.Set(call.SessionID, cancel)
176
177	defer cancel()
178	defer a.activeRequests.Del(call.SessionID)
179
180	history, files := a.preparePrompt(msgs, call.Attachments...)
181
182	startTime := time.Now()
183	a.eventPromptSent(call.SessionID)
184
185	var currentAssistant *message.Message
186	var shouldSummarize bool
187	result, err := agent.Stream(genCtx, fantasy.AgentStreamCall{
188		Prompt:           call.Prompt,
189		Files:            files,
190		Messages:         history,
191		ProviderOptions:  call.ProviderOptions,
192		MaxOutputTokens:  &call.MaxOutputTokens,
193		TopP:             call.TopP,
194		Temperature:      call.Temperature,
195		PresencePenalty:  call.PresencePenalty,
196		TopK:             call.TopK,
197		FrequencyPenalty: call.FrequencyPenalty,
198		// Before each step create a new assistant message.
199		PrepareStep: func(callContext context.Context, options fantasy.PrepareStepFunctionOptions) (_ context.Context, prepared fantasy.PrepareStepResult, err error) {
200			prepared.Messages = options.Messages
201			// Reset all cached items.
202			for i := range prepared.Messages {
203				prepared.Messages[i].ProviderOptions = nil
204			}
205
206			queuedCalls, _ := a.messageQueue.Get(call.SessionID)
207			a.messageQueue.Del(call.SessionID)
208			for _, queued := range queuedCalls {
209				userMessage, createErr := a.createUserMessage(callContext, queued)
210				if createErr != nil {
211					return callContext, prepared, createErr
212				}
213				prepared.Messages = append(prepared.Messages, userMessage.ToAIMessage()...)
214			}
215
216			lastSystemRoleInx := 0
217			systemMessageUpdated := false
218			for i, msg := range prepared.Messages {
219				// Only add cache control to the last message.
220				if msg.Role == fantasy.MessageRoleSystem {
221					lastSystemRoleInx = i
222				} else if !systemMessageUpdated {
223					prepared.Messages[lastSystemRoleInx].ProviderOptions = a.getCacheControlOptions()
224					systemMessageUpdated = true
225				}
226				// Than add cache control to the last 2 messages.
227				if i > len(prepared.Messages)-3 {
228					prepared.Messages[i].ProviderOptions = a.getCacheControlOptions()
229				}
230			}
231
232			if a.systemPromptPrefix != "" {
233				prepared.Messages = append([]fantasy.Message{fantasy.NewSystemMessage(a.systemPromptPrefix)}, prepared.Messages...)
234			}
235
236			var assistantMsg message.Message
237			assistantMsg, err = a.messages.Create(callContext, call.SessionID, message.CreateMessageParams{
238				Role:     message.Assistant,
239				Parts:    []message.ContentPart{},
240				Model:    a.largeModel.ModelCfg.Model,
241				Provider: a.largeModel.ModelCfg.Provider,
242			})
243			if err != nil {
244				return callContext, prepared, err
245			}
246			callContext = context.WithValue(callContext, tools.MessageIDContextKey, assistantMsg.ID)
247			currentAssistant = &assistantMsg
248			return callContext, prepared, err
249		},
250		OnReasoningStart: func(id string, reasoning fantasy.ReasoningContent) error {
251			currentAssistant.AppendReasoningContent(reasoning.Text)
252			return a.messages.Update(genCtx, *currentAssistant)
253		},
254		OnReasoningDelta: func(id string, text string) error {
255			currentAssistant.AppendReasoningContent(text)
256			return a.messages.Update(genCtx, *currentAssistant)
257		},
258		OnReasoningEnd: func(id string, reasoning fantasy.ReasoningContent) error {
259			// handle anthropic signature
260			if anthropicData, ok := reasoning.ProviderMetadata[anthropic.Name]; ok {
261				if reasoning, ok := anthropicData.(*anthropic.ReasoningOptionMetadata); ok {
262					currentAssistant.AppendReasoningSignature(reasoning.Signature)
263				}
264			}
265			if googleData, ok := reasoning.ProviderMetadata[google.Name]; ok {
266				if reasoning, ok := googleData.(*google.ReasoningMetadata); ok {
267					currentAssistant.AppendReasoningSignature(reasoning.Signature)
268				}
269			}
270			if openaiData, ok := reasoning.ProviderMetadata[openai.Name]; ok {
271				if reasoning, ok := openaiData.(*openai.ResponsesReasoningMetadata); ok {
272					currentAssistant.SetReasoningResponsesData(reasoning)
273				}
274			}
275			currentAssistant.FinishThinking()
276			return a.messages.Update(genCtx, *currentAssistant)
277		},
278		OnTextDelta: func(id string, text string) error {
279			// Strip leading newline from initial text content. This is is
280			// particularly important in non-interactive mode where leading
281			// newlines are very visible.
282			if len(currentAssistant.Parts) == 0 && strings.HasPrefix(text, "\n") {
283				text = strings.TrimPrefix(text, "\n")
284			}
285
286			currentAssistant.AppendContent(text)
287			return a.messages.Update(genCtx, *currentAssistant)
288		},
289		OnToolInputStart: func(id string, toolName string) error {
290			toolCall := message.ToolCall{
291				ID:               id,
292				Name:             toolName,
293				ProviderExecuted: false,
294				Finished:         false,
295			}
296			currentAssistant.AddToolCall(toolCall)
297			return a.messages.Update(genCtx, *currentAssistant)
298		},
299		OnRetry: func(err *fantasy.APICallError, delay time.Duration) {
300			// TODO: implement
301		},
302		OnToolCall: func(tc fantasy.ToolCallContent) error {
303			toolCall := message.ToolCall{
304				ID:               tc.ToolCallID,
305				Name:             tc.ToolName,
306				Input:            tc.Input,
307				ProviderExecuted: false,
308				Finished:         true,
309			}
310			currentAssistant.AddToolCall(toolCall)
311			return a.messages.Update(genCtx, *currentAssistant)
312		},
313		OnToolResult: func(result fantasy.ToolResultContent) error {
314			var resultContent string
315			isError := false
316			switch result.Result.GetType() {
317			case fantasy.ToolResultContentTypeText:
318				r, ok := fantasy.AsToolResultOutputType[fantasy.ToolResultOutputContentText](result.Result)
319				if ok {
320					resultContent = r.Text
321				}
322			case fantasy.ToolResultContentTypeError:
323				r, ok := fantasy.AsToolResultOutputType[fantasy.ToolResultOutputContentError](result.Result)
324				if ok {
325					isError = true
326					resultContent = r.Error.Error()
327				}
328			case fantasy.ToolResultContentTypeMedia:
329				// TODO: handle this message type
330			}
331			toolResult := message.ToolResult{
332				ToolCallID: result.ToolCallID,
333				Name:       result.ToolName,
334				Content:    resultContent,
335				IsError:    isError,
336				Metadata:   result.ClientMetadata,
337			}
338			_, createMsgErr := a.messages.Create(genCtx, currentAssistant.SessionID, message.CreateMessageParams{
339				Role: message.Tool,
340				Parts: []message.ContentPart{
341					toolResult,
342				},
343			})
344			if createMsgErr != nil {
345				return createMsgErr
346			}
347			return nil
348		},
349		OnStepFinish: func(stepResult fantasy.StepResult) error {
350			finishReason := message.FinishReasonUnknown
351			switch stepResult.FinishReason {
352			case fantasy.FinishReasonLength:
353				finishReason = message.FinishReasonMaxTokens
354			case fantasy.FinishReasonStop:
355				finishReason = message.FinishReasonEndTurn
356			case fantasy.FinishReasonToolCalls:
357				finishReason = message.FinishReasonToolUse
358			}
359			currentAssistant.AddFinish(finishReason, "", "")
360			a.updateSessionUsage(a.largeModel, &currentSession, stepResult.Usage, a.openrouterCost(stepResult.ProviderMetadata))
361			sessionLock.Lock()
362			_, sessionErr := a.sessions.Save(genCtx, currentSession)
363			sessionLock.Unlock()
364			if sessionErr != nil {
365				return sessionErr
366			}
367			return a.messages.Update(genCtx, *currentAssistant)
368		},
369		StopWhen: []fantasy.StopCondition{
370			func(_ []fantasy.StepResult) bool {
371				cw := int64(a.largeModel.CatwalkCfg.ContextWindow)
372				tokens := currentSession.CompletionTokens + currentSession.PromptTokens
373				remaining := cw - tokens
374				var threshold int64
375				if cw > 200_000 {
376					threshold = 20_000
377				} else {
378					threshold = int64(float64(cw) * 0.2)
379				}
380				if (remaining <= threshold) && !a.disableAutoSummarize {
381					shouldSummarize = true
382					return true
383				}
384				return false
385			},
386		},
387	})
388
389	a.eventPromptResponded(call.SessionID, time.Since(startTime).Truncate(time.Second))
390
391	if err != nil {
392		isCancelErr := errors.Is(err, context.Canceled)
393		isPermissionErr := errors.Is(err, permission.ErrorPermissionDenied)
394		if currentAssistant == nil {
395			return result, err
396		}
397		// Ensure we finish thinking on error to close the reasoning state.
398		currentAssistant.FinishThinking()
399		toolCalls := currentAssistant.ToolCalls()
400		// INFO: we use the parent context here because the genCtx has been cancelled.
401		msgs, createErr := a.messages.List(ctx, currentAssistant.SessionID)
402		if createErr != nil {
403			return nil, createErr
404		}
405		for _, tc := range toolCalls {
406			if !tc.Finished {
407				tc.Finished = true
408				tc.Input = "{}"
409				currentAssistant.AddToolCall(tc)
410				updateErr := a.messages.Update(ctx, *currentAssistant)
411				if updateErr != nil {
412					return nil, updateErr
413				}
414			}
415
416			found := false
417			for _, msg := range msgs {
418				if msg.Role == message.Tool {
419					for _, tr := range msg.ToolResults() {
420						if tr.ToolCallID == tc.ID {
421							found = true
422							break
423						}
424					}
425				}
426				if found {
427					break
428				}
429			}
430			if found {
431				continue
432			}
433			content := "There was an error while executing the tool"
434			if isCancelErr {
435				content = "Tool execution canceled by user"
436			} else if isPermissionErr {
437				content = "User denied permission"
438			}
439			toolResult := message.ToolResult{
440				ToolCallID: tc.ID,
441				Name:       tc.Name,
442				Content:    content,
443				IsError:    true,
444			}
445			_, createErr = a.messages.Create(context.Background(), currentAssistant.SessionID, message.CreateMessageParams{
446				Role: message.Tool,
447				Parts: []message.ContentPart{
448					toolResult,
449				},
450			})
451			if createErr != nil {
452				return nil, createErr
453			}
454		}
455		if isCancelErr {
456			currentAssistant.AddFinish(message.FinishReasonCanceled, "User canceled request", "")
457		} else if isPermissionErr {
458			currentAssistant.AddFinish(message.FinishReasonPermissionDenied, "User denied permission", "")
459		} else {
460			currentAssistant.AddFinish(message.FinishReasonError, "API Error", err.Error())
461		}
462		// Note: we use the parent context here because the genCtx has been
463		// cancelled.
464		updateErr := a.messages.Update(ctx, *currentAssistant)
465		if updateErr != nil {
466			return nil, updateErr
467		}
468		return nil, err
469	}
470	wg.Wait()
471
472	if shouldSummarize {
473		a.activeRequests.Del(call.SessionID)
474		if summarizeErr := a.Summarize(genCtx, call.SessionID, call.ProviderOptions); summarizeErr != nil {
475			return nil, summarizeErr
476		}
477		// If the agent wasn't done...
478		if len(currentAssistant.ToolCalls()) > 0 {
479			existing, ok := a.messageQueue.Get(call.SessionID)
480			if !ok {
481				existing = []SessionAgentCall{}
482			}
483			call.Prompt = fmt.Sprintf("The previous session was interrupted because it got too long, the initial user request was: `%s`", call.Prompt)
484			existing = append(existing, call)
485			a.messageQueue.Set(call.SessionID, existing)
486		}
487	}
488
489	// Release active request before processing queued messages.
490	a.activeRequests.Del(call.SessionID)
491	cancel()
492
493	queuedMessages, ok := a.messageQueue.Get(call.SessionID)
494	if !ok || len(queuedMessages) == 0 {
495		return result, err
496	}
497	// There are queued messages restart the loop.
498	firstQueuedMessage := queuedMessages[0]
499	a.messageQueue.Set(call.SessionID, queuedMessages[1:])
500	return a.Run(ctx, firstQueuedMessage)
501}
502
503func (a *sessionAgent) Summarize(ctx context.Context, sessionID string, opts fantasy.ProviderOptions) error {
504	if a.IsSessionBusy(sessionID) {
505		return ErrSessionBusy
506	}
507
508	currentSession, err := a.sessions.Get(ctx, sessionID)
509	if err != nil {
510		return fmt.Errorf("failed to get session: %w", err)
511	}
512	msgs, err := a.getSessionMessages(ctx, currentSession)
513	if err != nil {
514		return err
515	}
516	if len(msgs) == 0 {
517		// Nothing to summarize.
518		return nil
519	}
520
521	aiMsgs, _ := a.preparePrompt(msgs)
522
523	genCtx, cancel := context.WithCancel(ctx)
524	a.activeRequests.Set(sessionID, cancel)
525	defer a.activeRequests.Del(sessionID)
526	defer cancel()
527
528	agent := fantasy.NewAgent(a.largeModel.Model,
529		fantasy.WithSystemPrompt(string(summaryPrompt)),
530	)
531	summaryMessage, err := a.messages.Create(ctx, sessionID, message.CreateMessageParams{
532		Role:             message.Assistant,
533		Model:            a.largeModel.Model.Model(),
534		Provider:         a.largeModel.Model.Provider(),
535		IsSummaryMessage: true,
536	})
537	if err != nil {
538		return err
539	}
540
541	resp, err := agent.Stream(genCtx, fantasy.AgentStreamCall{
542		Prompt:          "Provide a detailed summary of our conversation above.",
543		Messages:        aiMsgs,
544		ProviderOptions: opts,
545		PrepareStep: func(callContext context.Context, options fantasy.PrepareStepFunctionOptions) (_ context.Context, prepared fantasy.PrepareStepResult, err error) {
546			prepared.Messages = options.Messages
547			if a.systemPromptPrefix != "" {
548				prepared.Messages = append([]fantasy.Message{fantasy.NewSystemMessage(a.systemPromptPrefix)}, prepared.Messages...)
549			}
550			return callContext, prepared, nil
551		},
552		OnReasoningDelta: func(id string, text string) error {
553			summaryMessage.AppendReasoningContent(text)
554			return a.messages.Update(genCtx, summaryMessage)
555		},
556		OnReasoningEnd: func(id string, reasoning fantasy.ReasoningContent) error {
557			// Handle anthropic signature.
558			if anthropicData, ok := reasoning.ProviderMetadata["anthropic"]; ok {
559				if signature, ok := anthropicData.(*anthropic.ReasoningOptionMetadata); ok && signature.Signature != "" {
560					summaryMessage.AppendReasoningSignature(signature.Signature)
561				}
562			}
563			summaryMessage.FinishThinking()
564			return a.messages.Update(genCtx, summaryMessage)
565		},
566		OnTextDelta: func(id, text string) error {
567			summaryMessage.AppendContent(text)
568			return a.messages.Update(genCtx, summaryMessage)
569		},
570	})
571	if err != nil {
572		isCancelErr := errors.Is(err, context.Canceled)
573		if isCancelErr {
574			// User cancelled summarize we need to remove the summary message.
575			deleteErr := a.messages.Delete(ctx, summaryMessage.ID)
576			return deleteErr
577		}
578		return err
579	}
580
581	summaryMessage.AddFinish(message.FinishReasonEndTurn, "", "")
582	err = a.messages.Update(genCtx, summaryMessage)
583	if err != nil {
584		return err
585	}
586
587	var openrouterCost *float64
588	for _, step := range resp.Steps {
589		stepCost := a.openrouterCost(step.ProviderMetadata)
590		if stepCost != nil {
591			newCost := *stepCost
592			if openrouterCost != nil {
593				newCost += *openrouterCost
594			}
595			openrouterCost = &newCost
596		}
597	}
598
599	a.updateSessionUsage(a.largeModel, &currentSession, resp.TotalUsage, openrouterCost)
600
601	// Just in case, get just the last usage info.
602	usage := resp.Response.Usage
603	currentSession.SummaryMessageID = summaryMessage.ID
604	currentSession.CompletionTokens = usage.OutputTokens
605	currentSession.PromptTokens = 0
606	_, err = a.sessions.Save(genCtx, currentSession)
607	return err
608}
609
610func (a *sessionAgent) getCacheControlOptions() fantasy.ProviderOptions {
611	if t, _ := strconv.ParseBool(os.Getenv("CRUSH_DISABLE_ANTHROPIC_CACHE")); t {
612		return fantasy.ProviderOptions{}
613	}
614	return fantasy.ProviderOptions{
615		anthropic.Name: &anthropic.ProviderCacheControlOptions{
616			CacheControl: anthropic.CacheControl{Type: "ephemeral"},
617		},
618		bedrock.Name: &anthropic.ProviderCacheControlOptions{
619			CacheControl: anthropic.CacheControl{Type: "ephemeral"},
620		},
621	}
622}
623
624func (a *sessionAgent) createUserMessage(ctx context.Context, call SessionAgentCall) (message.Message, error) {
625	var attachmentParts []message.ContentPart
626	for _, attachment := range call.Attachments {
627		attachmentParts = append(attachmentParts, message.BinaryContent{Path: attachment.FilePath, MIMEType: attachment.MimeType, Data: attachment.Content})
628	}
629	parts := []message.ContentPart{message.TextContent{Text: call.Prompt}}
630	parts = append(parts, attachmentParts...)
631	msg, err := a.messages.Create(ctx, call.SessionID, message.CreateMessageParams{
632		Role:  message.User,
633		Parts: parts,
634	})
635	if err != nil {
636		return message.Message{}, fmt.Errorf("failed to create user message: %w", err)
637	}
638	return msg, nil
639}
640
641func (a *sessionAgent) preparePrompt(msgs []message.Message, attachments ...message.Attachment) ([]fantasy.Message, []fantasy.FilePart) {
642	var history []fantasy.Message
643	for _, m := range msgs {
644		if len(m.Parts) == 0 {
645			continue
646		}
647		// Assistant message without content or tool calls (cancelled before it
648		// returned anything).
649		if m.Role == message.Assistant && len(m.ToolCalls()) == 0 && m.Content().Text == "" && m.ReasoningContent().String() == "" {
650			continue
651		}
652		history = append(history, m.ToAIMessage()...)
653	}
654
655	var files []fantasy.FilePart
656	for _, attachment := range attachments {
657		files = append(files, fantasy.FilePart{
658			Filename:  attachment.FileName,
659			Data:      attachment.Content,
660			MediaType: attachment.MimeType,
661		})
662	}
663
664	return history, files
665}
666
667func (a *sessionAgent) getSessionMessages(ctx context.Context, session session.Session) ([]message.Message, error) {
668	msgs, err := a.messages.List(ctx, session.ID)
669	if err != nil {
670		return nil, fmt.Errorf("failed to list messages: %w", err)
671	}
672
673	if session.SummaryMessageID != "" {
674		summaryMsgInex := -1
675		for i, msg := range msgs {
676			if msg.ID == session.SummaryMessageID {
677				summaryMsgInex = i
678				break
679			}
680		}
681		if summaryMsgInex != -1 {
682			msgs = msgs[summaryMsgInex:]
683			msgs[0].Role = message.User
684		}
685	}
686	return msgs, nil
687}
688
689func (a *sessionAgent) generateTitle(ctx context.Context, session *session.Session, prompt string) {
690	if prompt == "" {
691		return
692	}
693
694	var maxOutput int64 = 40
695	if a.smallModel.CatwalkCfg.CanReason {
696		maxOutput = a.smallModel.CatwalkCfg.DefaultMaxTokens
697	}
698
699	agent := fantasy.NewAgent(a.smallModel.Model,
700		fantasy.WithSystemPrompt(string(titlePrompt)+"\n /no_think"),
701		fantasy.WithMaxOutputTokens(maxOutput),
702	)
703
704	resp, err := agent.Stream(ctx, fantasy.AgentStreamCall{
705		Prompt: fmt.Sprintf("Generate a concise title for the following content:\n\n%s\n <think>\n\n</think>", prompt),
706		PrepareStep: func(callContext context.Context, options fantasy.PrepareStepFunctionOptions) (_ context.Context, prepared fantasy.PrepareStepResult, err error) {
707			prepared.Messages = options.Messages
708			if a.systemPromptPrefix != "" {
709				prepared.Messages = append([]fantasy.Message{fantasy.NewSystemMessage(a.systemPromptPrefix)}, prepared.Messages...)
710			}
711			return callContext, prepared, nil
712		},
713	})
714	if err != nil {
715		slog.Error("error generating title", "err", err)
716		return
717	}
718
719	title := resp.Response.Content.Text()
720
721	title = strings.ReplaceAll(title, "\n", " ")
722
723	// Remove thinking tags if present.
724	if idx := strings.Index(title, "</think>"); idx > 0 {
725		title = title[idx+len("</think>"):]
726	}
727
728	title = strings.TrimSpace(title)
729	if title == "" {
730		slog.Warn("failed to generate title", "warn", "empty title")
731		return
732	}
733
734	session.Title = title
735
736	var openrouterCost *float64
737	for _, step := range resp.Steps {
738		stepCost := a.openrouterCost(step.ProviderMetadata)
739		if stepCost != nil {
740			newCost := *stepCost
741			if openrouterCost != nil {
742				newCost += *openrouterCost
743			}
744			openrouterCost = &newCost
745		}
746	}
747
748	a.updateSessionUsage(a.smallModel, session, resp.TotalUsage, openrouterCost)
749	_, saveErr := a.sessions.Save(ctx, *session)
750	if saveErr != nil {
751		slog.Error("failed to save session title & usage", "error", saveErr)
752		return
753	}
754}
755
756func (a *sessionAgent) openrouterCost(metadata fantasy.ProviderMetadata) *float64 {
757	openrouterMetadata, ok := metadata[openrouter.Name]
758	if !ok {
759		return nil
760	}
761
762	opts, ok := openrouterMetadata.(*openrouter.ProviderMetadata)
763	if !ok {
764		return nil
765	}
766	return &opts.Usage.Cost
767}
768
769func (a *sessionAgent) updateSessionUsage(model Model, session *session.Session, usage fantasy.Usage, overrideCost *float64) {
770	modelConfig := model.CatwalkCfg
771	cost := modelConfig.CostPer1MInCached/1e6*float64(usage.CacheCreationTokens) +
772		modelConfig.CostPer1MOutCached/1e6*float64(usage.CacheReadTokens) +
773		modelConfig.CostPer1MIn/1e6*float64(usage.InputTokens) +
774		modelConfig.CostPer1MOut/1e6*float64(usage.OutputTokens)
775
776	a.eventTokensUsed(session.ID, model, usage, cost)
777
778	if overrideCost != nil {
779		session.Cost += *overrideCost
780	} else {
781		session.Cost += cost
782	}
783
784	session.CompletionTokens = usage.OutputTokens + usage.CacheReadTokens
785	session.PromptTokens = usage.InputTokens + usage.CacheCreationTokens
786}
787
788func (a *sessionAgent) Cancel(sessionID string) {
789	// Cancel regular requests.
790	if cancel, ok := a.activeRequests.Take(sessionID); ok && cancel != nil {
791		slog.Info("Request cancellation initiated", "session_id", sessionID)
792		cancel()
793	}
794
795	// Also check for summarize requests.
796	if cancel, ok := a.activeRequests.Take(sessionID + "-summarize"); ok && cancel != nil {
797		slog.Info("Summarize cancellation initiated", "session_id", sessionID)
798		cancel()
799	}
800
801	if a.QueuedPrompts(sessionID) > 0 {
802		slog.Info("Clearing queued prompts", "session_id", sessionID)
803		a.messageQueue.Del(sessionID)
804	}
805}
806
807func (a *sessionAgent) ClearQueue(sessionID string) {
808	if a.QueuedPrompts(sessionID) > 0 {
809		slog.Info("Clearing queued prompts", "session_id", sessionID)
810		a.messageQueue.Del(sessionID)
811	}
812}
813
814func (a *sessionAgent) CancelAll() {
815	if !a.IsBusy() {
816		return
817	}
818	for key := range a.activeRequests.Seq2() {
819		a.Cancel(key) // key is sessionID
820	}
821
822	timeout := time.After(5 * time.Second)
823	for a.IsBusy() {
824		select {
825		case <-timeout:
826			return
827		default:
828			time.Sleep(200 * time.Millisecond)
829		}
830	}
831}
832
833func (a *sessionAgent) IsBusy() bool {
834	var busy bool
835	for cancelFunc := range a.activeRequests.Seq() {
836		if cancelFunc != nil {
837			busy = true
838			break
839		}
840	}
841	return busy
842}
843
844func (a *sessionAgent) IsSessionBusy(sessionID string) bool {
845	_, busy := a.activeRequests.Get(sessionID)
846	return busy
847}
848
849func (a *sessionAgent) QueuedPrompts(sessionID string) int {
850	l, ok := a.messageQueue.Get(sessionID)
851	if !ok {
852		return 0
853	}
854	return len(l)
855}
856
857func (a *sessionAgent) SetModels(large Model, small Model) {
858	a.largeModel = large
859	a.smallModel = small
860}
861
862func (a *sessionAgent) SetTools(tools []fantasy.AgentTool) {
863	a.tools = tools
864}
865
// Model returns the large model, i.e. the one Run and Summarize use.
func (a *sessionAgent) Model() Model {
	return a.largeModel
}