agent.go

  1// Package agent is the core orchestration layer for Crush AI agents.
  2//
  3// It provides session-based AI agent functionality for managing
  4// conversations, tool execution, and message handling. It coordinates
  5// interactions between language models, messages, sessions, and tools while
  6// handling features like automatic summarization, queuing, and token
  7// management.
  8package agent
  9
 10import (
 11	"cmp"
 12	"context"
 13	_ "embed"
 14	"errors"
 15	"fmt"
 16	"log/slog"
 17	"os"
 18	"strconv"
 19	"strings"
 20	"sync"
 21	"time"
 22
 23	"charm.land/fantasy"
 24	"charm.land/fantasy/providers/anthropic"
 25	"charm.land/fantasy/providers/bedrock"
 26	"charm.land/fantasy/providers/google"
 27	"charm.land/fantasy/providers/openai"
 28	"charm.land/fantasy/providers/openrouter"
 29	"github.com/charmbracelet/catwalk/pkg/catwalk"
 30	"github.com/charmbracelet/crush/internal/agent/tools"
 31	"github.com/charmbracelet/crush/internal/config"
 32	"github.com/charmbracelet/crush/internal/csync"
 33	"github.com/charmbracelet/crush/internal/hooks"
 34	"github.com/charmbracelet/crush/internal/message"
 35	"github.com/charmbracelet/crush/internal/permission"
 36	"github.com/charmbracelet/crush/internal/session"
 37	"github.com/charmbracelet/crush/internal/stringext"
 38)
 39
// titlePrompt holds the embedded contents of templates/title.md. It is used
// as the base system prompt when generating a session title.
//
//go:embed templates/title.md
var titlePrompt []byte

// summaryPrompt holds the embedded contents of templates/summary.md. It is
// used as the system prompt when summarizing a session.
//
//go:embed templates/summary.md
var summaryPrompt []byte
 45
// SessionAgentCall describes a single prompt submitted to a session agent.
//
// SessionID and Prompt are required: Run rejects an empty Prompt with
// ErrEmptyPrompt and an empty SessionID with ErrSessionMissing. Attachments
// are stored with the user message and also forwarded to the model as files.
// ProviderOptions, MaxOutputTokens, Temperature, TopP, TopK,
// FrequencyPenalty and PresencePenalty are passed through to the streaming
// call unchanged.
type SessionAgentCall struct {
	SessionID        string
	Prompt           string
	ProviderOptions  fantasy.ProviderOptions
	Attachments      []message.Attachment
	MaxOutputTokens  int64
	Temperature      *float64
	TopP             *float64
	TopK             *int64
	FrequencyPenalty *float64
	PresencePenalty  *float64
}
 58
// SessionAgent is the session-scoped agent API: it runs prompts, manages the
// per-session prompt queue, and allows cancelling or summarizing sessions.
type SessionAgent interface {
	// Run submits a prompt for a session. The sessionAgent implementation
	// queues the call and returns (nil, nil) when the session is busy.
	Run(context.Context, SessionAgentCall) (*fantasy.AgentResult, error)
	// SetModels replaces the large and small models used by the agent.
	SetModels(large Model, small Model)
	// SetTools replaces the tools offered to the model.
	SetTools(tools []fantasy.AgentTool)
	// Cancel cancels the active request of a session and drops its queued
	// prompts.
	Cancel(sessionID string)
	// CancelAll cancels the active requests of all sessions.
	CancelAll()
	// IsSessionBusy reports whether the session has an active request.
	IsSessionBusy(sessionID string) bool
	// IsBusy reports whether any session has an active request.
	IsBusy() bool
	// QueuedPrompts returns the number of prompts queued for the session.
	QueuedPrompts(sessionID string) int
	// ClearQueue drops the prompts queued for the session.
	ClearQueue(sessionID string)
	// Summarize replaces the session history with a model-written summary.
	Summarize(context.Context, string, fantasy.ProviderOptions) error
	// Model returns the large model currently in use.
	Model() Model
}
 72
// Model bundles a language model with its configuration.
//
// Model is the handle used to talk to the provider. CatwalkCfg supplies the
// catalog data read by this package (context window, per-million-token
// prices, reasoning capability, default max tokens). ModelCfg supplies the
// selected model and provider identifiers recorded on assistant messages.
type Model struct {
	Model      fantasy.LanguageModel
	CatwalkCfg catwalk.Model
	ModelCfg   config.SelectedModel
}
 78
// sessionAgent is the SessionAgent implementation returned by
// NewSessionAgent.
//
// largeModel serves prompts and summaries; smallModel generates titles.
// systemPromptPrefix, when non-empty, is prepended as an extra system message
// to every model call. hooks may be nil, in which case no user-prompt-submit
// hooks run. disableAutoSummarize turns off the automatic summarization that
// Run triggers when the context window is nearly exhausted.
type sessionAgent struct {
	largeModel           Model
	smallModel           Model
	systemPromptPrefix   string
	systemPrompt         string
	tools                []fantasy.AgentTool
	sessions             session.Service
	messages             message.Service
	hooks                hooks.Service
	disableAutoSummarize bool
	isYolo               bool

	// messageQueue holds prompts submitted while a session was busy, keyed by
	// session ID. activeRequests maps a session ID to the cancel function of
	// its in-flight request; presence of a key marks the session as busy.
	messageQueue   *csync.Map[string, []SessionAgentCall]
	activeRequests *csync.Map[string, context.CancelFunc]
}
 94
// SessionAgentOptions carries the dependencies and settings used by
// NewSessionAgent to build a session agent. Each field is copied verbatim
// into the corresponding field of the agent.
type SessionAgentOptions struct {
	LargeModel           Model
	SmallModel           Model
	SystemPromptPrefix   string
	SystemPrompt         string
	DisableAutoSummarize bool
	IsYolo               bool
	Sessions             session.Service
	Messages             message.Service
	Hooks                hooks.Service
	Tools                []fantasy.AgentTool
}
107
108func NewSessionAgent(
109	opts SessionAgentOptions,
110) SessionAgent {
111	return &sessionAgent{
112		largeModel:           opts.LargeModel,
113		smallModel:           opts.SmallModel,
114		systemPromptPrefix:   opts.SystemPromptPrefix,
115		systemPrompt:         opts.SystemPrompt,
116		sessions:             opts.Sessions,
117		messages:             opts.Messages,
118		disableAutoSummarize: opts.DisableAutoSummarize,
119		hooks:                opts.Hooks,
120		tools:                opts.Tools,
121		isYolo:               opts.IsYolo,
122		messageQueue:         csync.NewMap[string, []SessionAgentCall](),
123		activeRequests:       csync.NewMap[string, context.CancelFunc](),
124	}
125}
126
127func (a *sessionAgent) Run(ctx context.Context, call SessionAgentCall) (*fantasy.AgentResult, error) {
128	if call.Prompt == "" {
129		return nil, ErrEmptyPrompt
130	}
131	if call.SessionID == "" {
132		return nil, ErrSessionMissing
133	}
134
135	// Queue the message if busy
136	if a.IsSessionBusy(call.SessionID) {
137		existing, ok := a.messageQueue.Get(call.SessionID)
138		if !ok {
139			existing = []SessionAgentCall{}
140		}
141		existing = append(existing, call)
142		a.messageQueue.Set(call.SessionID, existing)
143		return nil, nil
144	}
145
146	if len(a.tools) > 0 {
147		// Add Anthropic caching to the last tool.
148		a.tools[len(a.tools)-1].SetProviderOptions(a.getCacheControlOptions())
149	}
150
151	agent := fantasy.NewAgent(
152		a.largeModel.Model,
153		fantasy.WithSystemPrompt(a.systemPrompt),
154		fantasy.WithTools(a.tools...),
155	)
156
157	sessionLock := sync.Mutex{}
158	currentSession, err := a.sessions.Get(ctx, call.SessionID)
159	if err != nil {
160		return nil, fmt.Errorf("failed to get session: %w", err)
161	}
162
163	msgs, err := a.getSessionMessages(ctx, currentSession)
164	if err != nil {
165		return nil, fmt.Errorf("failed to get session messages: %w", err)
166	}
167
168	var wg sync.WaitGroup
169	// Generate title if first message.
170	if len(msgs) == 0 {
171		wg.Go(func() {
172			sessionLock.Lock()
173			a.generateTitle(ctx, &currentSession, call.Prompt)
174			sessionLock.Unlock()
175		})
176	}
177
178	// Add the user message to the session.
179	userMsg, err := a.createUserMessage(ctx, call)
180	if err != nil {
181		return nil, err
182	}
183
184	// Add the session to the context.
185	ctx = context.WithValue(ctx, tools.SessionIDContextKey, call.SessionID)
186
187	genCtx, cancel := context.WithCancel(ctx)
188	a.activeRequests.Set(call.SessionID, cancel)
189	defer cancel()
190	defer a.activeRequests.Del(call.SessionID)
191
192	// create the agent message asap to show loading
193	var currentAssistant *message.Message
194	assistantMessage, err := a.messages.Create(genCtx, call.SessionID, message.CreateMessageParams{
195		Role:     message.Assistant,
196		Parts:    []message.ContentPart{},
197		Model:    a.largeModel.ModelCfg.Model,
198		Provider: a.largeModel.ModelCfg.Provider,
199	})
200	if err != nil {
201		return nil, err
202	}
203	currentAssistant = &assistantMessage
204
205	// run hooks after assistant message is created
206	// this way we show loading for the user
207	hooks, err := a.executeUserPromptSubmitHook(genCtx, call.SessionID, call.Prompt)
208	if err != nil {
209		return nil, err
210	}
211	userMsg.AddHookOutputs(hooks...)
212	if updateErr := a.messages.Update(genCtx, userMsg); updateErr != nil {
213		return nil, updateErr
214	}
215
216	for _, hook := range hooks {
217		// execution stopped
218		if hook.Stop {
219			deleteErr := a.messages.Delete(genCtx, assistantMessage.ID)
220			if deleteErr != nil {
221				return nil, deleteErr
222			}
223			return nil, nil
224		}
225	}
226
227	history, files := a.preparePrompt(msgs, call.Attachments...)
228
229	startTime := time.Now()
230	a.eventPromptSent(call.SessionID)
231
232	var shouldSummarize bool
233	result, err := agent.Stream(genCtx, fantasy.AgentStreamCall{
234		Prompt:           userMsg.ContentWithHooksContext(),
235		Files:            files,
236		Messages:         history,
237		ProviderOptions:  call.ProviderOptions,
238		MaxOutputTokens:  &call.MaxOutputTokens,
239		TopP:             call.TopP,
240		Temperature:      call.Temperature,
241		PresencePenalty:  call.PresencePenalty,
242		TopK:             call.TopK,
243		FrequencyPenalty: call.FrequencyPenalty,
244		// Before each step create a new assistant message.
245		PrepareStep: func(callContext context.Context, options fantasy.PrepareStepFunctionOptions) (_ context.Context, prepared fantasy.PrepareStepResult, err error) {
246			// only add new assistant message when its not the first step
247			if options.StepNumber != 0 {
248				var assistantMsg message.Message
249				assistantMsg, err = a.messages.Create(callContext, call.SessionID, message.CreateMessageParams{
250					Role:     message.Assistant,
251					Parts:    []message.ContentPart{},
252					Model:    a.largeModel.ModelCfg.Model,
253					Provider: a.largeModel.ModelCfg.Provider,
254				})
255				currentAssistant = &assistantMsg
256				// create the message first so we show loading asap
257				if err != nil {
258					return callContext, prepared, err
259				}
260			}
261
262			prepared.Messages = options.Messages
263			// Reset all cached items.
264			for i := range prepared.Messages {
265				prepared.Messages[i].ProviderOptions = nil
266			}
267
268			queuedCalls, _ := a.messageQueue.Get(call.SessionID)
269			a.messageQueue.Del(call.SessionID)
270			for _, queued := range queuedCalls {
271				userMessage, createErr := a.createUserMessage(callContext, queued)
272				if createErr != nil {
273					return callContext, prepared, createErr
274				}
275
276				// run hooks after assistant message is created
277				// this way we show loading for the user
278				hooks, hookErr := a.executeUserPromptSubmitHook(genCtx, call.SessionID, call.Prompt)
279				if hookErr != nil {
280					return callContext, prepared, hookErr
281				}
282				userMsg.AddHookOutputs(hooks...)
283				for _, hook := range hooks {
284					// execution stopped
285					if hook.Stop {
286						deleteErr := a.messages.Delete(genCtx, assistantMessage.ID)
287						if deleteErr != nil {
288							return callContext, prepared, deleteErr
289						}
290						return callContext, prepared, ErrHookCancellation
291					}
292				}
293				if updateErr := a.messages.Update(genCtx, userMsg); updateErr != nil {
294					return callContext, prepared, updateErr
295				}
296				prepared.Messages = append(prepared.Messages, userMessage.ToAIMessage()...)
297			}
298
299			lastSystemRoleInx := 0
300			systemMessageUpdated := false
301			for i, msg := range prepared.Messages {
302				// Only add cache control to the last message.
303				if msg.Role == fantasy.MessageRoleSystem {
304					lastSystemRoleInx = i
305				} else if !systemMessageUpdated {
306					prepared.Messages[lastSystemRoleInx].ProviderOptions = a.getCacheControlOptions()
307					systemMessageUpdated = true
308				}
309				// Than add cache control to the last 2 messages.
310				if i > len(prepared.Messages)-3 {
311					prepared.Messages[i].ProviderOptions = a.getCacheControlOptions()
312				}
313			}
314
315			if a.systemPromptPrefix != "" {
316				prepared.Messages = append([]fantasy.Message{fantasy.NewSystemMessage(a.systemPromptPrefix)}, prepared.Messages...)
317			}
318
319			callContext = context.WithValue(callContext, tools.MessageIDContextKey, currentAssistant.ID)
320			return callContext, prepared, err
321		},
322		OnReasoningStart: func(id string, reasoning fantasy.ReasoningContent) error {
323			currentAssistant.AppendReasoningContent(reasoning.Text)
324			return a.messages.Update(genCtx, *currentAssistant)
325		},
326		OnReasoningDelta: func(id string, text string) error {
327			currentAssistant.AppendReasoningContent(text)
328			return a.messages.Update(genCtx, *currentAssistant)
329		},
330		OnReasoningEnd: func(id string, reasoning fantasy.ReasoningContent) error {
331			// handle anthropic signature
332			if anthropicData, ok := reasoning.ProviderMetadata[anthropic.Name]; ok {
333				if reasoning, ok := anthropicData.(*anthropic.ReasoningOptionMetadata); ok {
334					currentAssistant.AppendReasoningSignature(reasoning.Signature)
335				}
336			}
337			if googleData, ok := reasoning.ProviderMetadata[google.Name]; ok {
338				if reasoning, ok := googleData.(*google.ReasoningMetadata); ok {
339					currentAssistant.AppendReasoningSignature(reasoning.Signature)
340				}
341			}
342			if openaiData, ok := reasoning.ProviderMetadata[openai.Name]; ok {
343				if reasoning, ok := openaiData.(*openai.ResponsesReasoningMetadata); ok {
344					currentAssistant.SetReasoningResponsesData(reasoning)
345				}
346			}
347			currentAssistant.FinishThinking()
348			return a.messages.Update(genCtx, *currentAssistant)
349		},
350		OnTextDelta: func(id string, text string) error {
351			// Strip leading newline from initial text content. This is is
352			// particularly important in non-interactive mode where leading
353			// newlines are very visible.
354			if len(currentAssistant.Parts) == 0 {
355				text = strings.TrimPrefix(text, "\n")
356			}
357
358			currentAssistant.AppendContent(text)
359			return a.messages.Update(genCtx, *currentAssistant)
360		},
361		OnToolInputStart: func(id string, toolName string) error {
362			toolCall := message.ToolCall{
363				ID:               id,
364				Name:             toolName,
365				ProviderExecuted: false,
366				Finished:         false,
367			}
368			currentAssistant.AddToolCall(toolCall)
369			return a.messages.Update(genCtx, *currentAssistant)
370		},
371		OnRetry: func(err *fantasy.ProviderError, delay time.Duration) {
372			// TODO: implement
373		},
374		OnToolCall: func(tc fantasy.ToolCallContent) error {
375			toolCall := message.ToolCall{
376				ID:               tc.ToolCallID,
377				Name:             tc.ToolName,
378				Input:            tc.Input,
379				ProviderExecuted: false,
380				Finished:         true,
381			}
382			currentAssistant.AddToolCall(toolCall)
383			return a.messages.Update(genCtx, *currentAssistant)
384		},
385		OnToolResult: func(result fantasy.ToolResultContent) error {
386			var resultContent string
387			isError := false
388			switch result.Result.GetType() {
389			case fantasy.ToolResultContentTypeText:
390				r, ok := fantasy.AsToolResultOutputType[fantasy.ToolResultOutputContentText](result.Result)
391				if ok {
392					resultContent = r.Text
393				}
394			case fantasy.ToolResultContentTypeError:
395				r, ok := fantasy.AsToolResultOutputType[fantasy.ToolResultOutputContentError](result.Result)
396				if ok {
397					isError = true
398					resultContent = r.Error.Error()
399				}
400			case fantasy.ToolResultContentTypeMedia:
401				// TODO: handle this message type
402			}
403			toolResult := message.ToolResult{
404				ToolCallID: result.ToolCallID,
405				Name:       result.ToolName,
406				Content:    resultContent,
407				IsError:    isError,
408				Metadata:   result.ClientMetadata,
409			}
410			_, createMsgErr := a.messages.Create(genCtx, currentAssistant.SessionID, message.CreateMessageParams{
411				Role: message.Tool,
412				Parts: []message.ContentPart{
413					toolResult,
414				},
415			})
416			if createMsgErr != nil {
417				return createMsgErr
418			}
419			return nil
420		},
421		OnStepFinish: func(stepResult fantasy.StepResult) error {
422			finishReason := message.FinishReasonUnknown
423			switch stepResult.FinishReason {
424			case fantasy.FinishReasonLength:
425				finishReason = message.FinishReasonMaxTokens
426			case fantasy.FinishReasonStop:
427				finishReason = message.FinishReasonEndTurn
428			case fantasy.FinishReasonToolCalls:
429				finishReason = message.FinishReasonToolUse
430			}
431			currentAssistant.AddFinish(finishReason, "", "")
432			a.updateSessionUsage(a.largeModel, &currentSession, stepResult.Usage, a.openrouterCost(stepResult.ProviderMetadata))
433			sessionLock.Lock()
434			_, sessionErr := a.sessions.Save(genCtx, currentSession)
435			sessionLock.Unlock()
436			if sessionErr != nil {
437				return sessionErr
438			}
439			return a.messages.Update(genCtx, *currentAssistant)
440		},
441		StopWhen: []fantasy.StopCondition{
442			func(_ []fantasy.StepResult) bool {
443				cw := int64(a.largeModel.CatwalkCfg.ContextWindow)
444				tokens := currentSession.CompletionTokens + currentSession.PromptTokens
445				remaining := cw - tokens
446				var threshold int64
447				if cw > 200_000 {
448					threshold = 20_000
449				} else {
450					threshold = int64(float64(cw) * 0.2)
451				}
452				if (remaining <= threshold) && !a.disableAutoSummarize {
453					shouldSummarize = true
454					return true
455				}
456				return false
457			},
458		},
459	})
460
461	a.eventPromptResponded(call.SessionID, time.Since(startTime).Truncate(time.Second))
462
463	if err != nil {
464		isCancelErr := errors.Is(err, context.Canceled)
465		isPermissionErr := errors.Is(err, permission.ErrorPermissionDenied)
466		isHookCancelled := errors.Is(err, ErrHookCancellation)
467		if currentAssistant == nil {
468			return result, err
469		}
470		// Ensure we finish thinking on error to close the reasoning state.
471		currentAssistant.FinishThinking()
472		toolCalls := currentAssistant.ToolCalls()
473		// INFO: we use the parent context here because the genCtx has been cancelled.
474		msgs, createErr := a.messages.List(ctx, currentAssistant.SessionID)
475		if createErr != nil {
476			return nil, createErr
477		}
478		for _, tc := range toolCalls {
479			if !tc.Finished {
480				tc.Finished = true
481				tc.Input = "{}"
482				currentAssistant.AddToolCall(tc)
483				updateErr := a.messages.Update(ctx, *currentAssistant)
484				if updateErr != nil {
485					return nil, updateErr
486				}
487			}
488
489			found := false
490			for _, msg := range msgs {
491				if msg.Role == message.Tool {
492					for _, tr := range msg.ToolResults() {
493						if tr.ToolCallID == tc.ID {
494							found = true
495							break
496						}
497					}
498				}
499				if found {
500					break
501				}
502			}
503			if found {
504				continue
505			}
506			content := "There was an error while executing the tool"
507			if isCancelErr {
508				content = "Tool execution canceled by user"
509			} else if isPermissionErr {
510				content = "User denied permission"
511			}
512			toolResult := message.ToolResult{
513				ToolCallID: tc.ID,
514				Name:       tc.Name,
515				Content:    content,
516				IsError:    true,
517			}
518			_, createErr = a.messages.Create(context.Background(), currentAssistant.SessionID, message.CreateMessageParams{
519				Role: message.Tool,
520				Parts: []message.ContentPart{
521					toolResult,
522				},
523			})
524			if createErr != nil {
525				return nil, createErr
526			}
527		}
528		var fantasyErr *fantasy.Error
529		var providerErr *fantasy.ProviderError
530		const defaultTitle = "Provider Error"
531		if isCancelErr {
532			currentAssistant.AddFinish(message.FinishReasonCanceled, "User canceled request", "")
533		} else if isPermissionErr {
534			currentAssistant.AddFinish(message.FinishReasonPermissionDenied, "User denied permission", "")
535		} else if isHookCancelled {
536			currentAssistant.AddFinish(message.FinishReasonPermissionDenied, "Hook canceled request", "")
537		} else if errors.As(err, &providerErr) {
538			currentAssistant.AddFinish(message.FinishReasonError, cmp.Or(stringext.Capitalize(providerErr.Title), defaultTitle), providerErr.Message)
539		} else if errors.As(err, &fantasyErr) {
540			currentAssistant.AddFinish(message.FinishReasonError, cmp.Or(stringext.Capitalize(fantasyErr.Title), defaultTitle), fantasyErr.Message)
541		} else {
542			currentAssistant.AddFinish(message.FinishReasonError, defaultTitle, err.Error())
543		}
544		// Note: we use the parent context here because the genCtx has been
545		// cancelled.
546		updateErr := a.messages.Update(ctx, *currentAssistant)
547		if updateErr != nil {
548			return nil, updateErr
549		}
550		return nil, err
551	}
552	wg.Wait()
553
554	if shouldSummarize {
555		a.activeRequests.Del(call.SessionID)
556		if summarizeErr := a.Summarize(genCtx, call.SessionID, call.ProviderOptions); summarizeErr != nil {
557			return nil, summarizeErr
558		}
559		// If the agent wasn't done...
560		if len(currentAssistant.ToolCalls()) > 0 {
561			existing, ok := a.messageQueue.Get(call.SessionID)
562			if !ok {
563				existing = []SessionAgentCall{}
564			}
565			call.Prompt = fmt.Sprintf("The previous session was interrupted because it got too long, the initial user request was: `%s`", call.Prompt)
566			existing = append(existing, call)
567			a.messageQueue.Set(call.SessionID, existing)
568		}
569	}
570
571	// Release active request before processing queued messages.
572	a.activeRequests.Del(call.SessionID)
573	cancel()
574
575	queuedMessages, ok := a.messageQueue.Get(call.SessionID)
576	if !ok || len(queuedMessages) == 0 {
577		return result, err
578	}
579	// There are queued messages restart the loop.
580	firstQueuedMessage := queuedMessages[0]
581	a.messageQueue.Set(call.SessionID, queuedMessages[1:])
582	return a.Run(ctx, firstQueuedMessage)
583}
584
585func (a *sessionAgent) Summarize(ctx context.Context, sessionID string, opts fantasy.ProviderOptions) error {
586	if a.IsSessionBusy(sessionID) {
587		return ErrSessionBusy
588	}
589
590	currentSession, err := a.sessions.Get(ctx, sessionID)
591	if err != nil {
592		return fmt.Errorf("failed to get session: %w", err)
593	}
594	msgs, err := a.getSessionMessages(ctx, currentSession)
595	if err != nil {
596		return err
597	}
598	if len(msgs) == 0 {
599		// Nothing to summarize.
600		return nil
601	}
602
603	aiMsgs, _ := a.preparePrompt(msgs)
604
605	genCtx, cancel := context.WithCancel(ctx)
606	a.activeRequests.Set(sessionID, cancel)
607	defer a.activeRequests.Del(sessionID)
608	defer cancel()
609
610	agent := fantasy.NewAgent(a.largeModel.Model,
611		fantasy.WithSystemPrompt(string(summaryPrompt)),
612	)
613	summaryMessage, err := a.messages.Create(ctx, sessionID, message.CreateMessageParams{
614		Role:             message.Assistant,
615		Model:            a.largeModel.Model.Model(),
616		Provider:         a.largeModel.Model.Provider(),
617		IsSummaryMessage: true,
618	})
619	if err != nil {
620		return err
621	}
622
623	resp, err := agent.Stream(genCtx, fantasy.AgentStreamCall{
624		Prompt:          "Provide a detailed summary of our conversation above.",
625		Messages:        aiMsgs,
626		ProviderOptions: opts,
627		PrepareStep: func(callContext context.Context, options fantasy.PrepareStepFunctionOptions) (_ context.Context, prepared fantasy.PrepareStepResult, err error) {
628			prepared.Messages = options.Messages
629			if a.systemPromptPrefix != "" {
630				prepared.Messages = append([]fantasy.Message{fantasy.NewSystemMessage(a.systemPromptPrefix)}, prepared.Messages...)
631			}
632			return callContext, prepared, nil
633		},
634		OnReasoningDelta: func(id string, text string) error {
635			summaryMessage.AppendReasoningContent(text)
636			return a.messages.Update(genCtx, summaryMessage)
637		},
638		OnReasoningEnd: func(id string, reasoning fantasy.ReasoningContent) error {
639			// Handle anthropic signature.
640			if anthropicData, ok := reasoning.ProviderMetadata["anthropic"]; ok {
641				if signature, ok := anthropicData.(*anthropic.ReasoningOptionMetadata); ok && signature.Signature != "" {
642					summaryMessage.AppendReasoningSignature(signature.Signature)
643				}
644			}
645			summaryMessage.FinishThinking()
646			return a.messages.Update(genCtx, summaryMessage)
647		},
648		OnTextDelta: func(id, text string) error {
649			summaryMessage.AppendContent(text)
650			return a.messages.Update(genCtx, summaryMessage)
651		},
652	})
653	if err != nil {
654		isCancelErr := errors.Is(err, context.Canceled)
655		if isCancelErr {
656			// User cancelled summarize we need to remove the summary message.
657			deleteErr := a.messages.Delete(ctx, summaryMessage.ID)
658			return deleteErr
659		}
660		return err
661	}
662
663	summaryMessage.AddFinish(message.FinishReasonEndTurn, "", "")
664	err = a.messages.Update(genCtx, summaryMessage)
665	if err != nil {
666		return err
667	}
668
669	var openrouterCost *float64
670	for _, step := range resp.Steps {
671		stepCost := a.openrouterCost(step.ProviderMetadata)
672		if stepCost != nil {
673			newCost := *stepCost
674			if openrouterCost != nil {
675				newCost += *openrouterCost
676			}
677			openrouterCost = &newCost
678		}
679	}
680
681	a.updateSessionUsage(a.largeModel, &currentSession, resp.TotalUsage, openrouterCost)
682
683	// Just in case, get just the last usage info.
684	usage := resp.Response.Usage
685	currentSession.SummaryMessageID = summaryMessage.ID
686	currentSession.CompletionTokens = usage.OutputTokens
687	currentSession.PromptTokens = 0
688	_, err = a.sessions.Save(genCtx, currentSession)
689	return err
690}
691
692func (a *sessionAgent) getCacheControlOptions() fantasy.ProviderOptions {
693	if t, _ := strconv.ParseBool(os.Getenv("CRUSH_DISABLE_ANTHROPIC_CACHE")); t {
694		return fantasy.ProviderOptions{}
695	}
696	return fantasy.ProviderOptions{
697		anthropic.Name: &anthropic.ProviderCacheControlOptions{
698			CacheControl: anthropic.CacheControl{Type: "ephemeral"},
699		},
700		bedrock.Name: &anthropic.ProviderCacheControlOptions{
701			CacheControl: anthropic.CacheControl{Type: "ephemeral"},
702		},
703	}
704}
705
706func (a *sessionAgent) createUserMessage(ctx context.Context, call SessionAgentCall) (message.Message, error) {
707	var attachmentParts []message.ContentPart
708	for _, attachment := range call.Attachments {
709		attachmentParts = append(attachmentParts, message.BinaryContent{Path: attachment.FilePath, MIMEType: attachment.MimeType, Data: attachment.Content})
710	}
711	parts := []message.ContentPart{message.TextContent{Text: call.Prompt}}
712	parts = append(parts, attachmentParts...)
713	msg, err := a.messages.Create(ctx, call.SessionID, message.CreateMessageParams{
714		Role:  message.User,
715		Parts: parts,
716	})
717	if err != nil {
718		return message.Message{}, fmt.Errorf("failed to create user message: %w", err)
719	}
720	return msg, nil
721}
722
723func (a *sessionAgent) preparePrompt(msgs []message.Message, attachments ...message.Attachment) ([]fantasy.Message, []fantasy.FilePart) {
724	var history []fantasy.Message
725	for _, m := range msgs {
726		if len(m.Parts) == 0 {
727			continue
728		}
729		// Assistant message without content or tool calls (cancelled before it
730		// returned anything).
731		if m.Role == message.Assistant && len(m.ToolCalls()) == 0 && m.Content().Text == "" && m.ReasoningContent().String() == "" {
732			continue
733		}
734		history = append(history, m.ToAIMessage()...)
735	}
736
737	var files []fantasy.FilePart
738	for _, attachment := range attachments {
739		files = append(files, fantasy.FilePart{
740			Filename:  attachment.FileName,
741			Data:      attachment.Content,
742			MediaType: attachment.MimeType,
743		})
744	}
745
746	return history, files
747}
748
749func (a *sessionAgent) getSessionMessages(ctx context.Context, session session.Session) ([]message.Message, error) {
750	msgs, err := a.messages.List(ctx, session.ID)
751	if err != nil {
752		return nil, fmt.Errorf("failed to list messages: %w", err)
753	}
754
755	if session.SummaryMessageID != "" {
756		summaryMsgInex := -1
757		for i, msg := range msgs {
758			if msg.ID == session.SummaryMessageID {
759				summaryMsgInex = i
760				break
761			}
762		}
763		if summaryMsgInex != -1 {
764			msgs = msgs[summaryMsgInex:]
765			msgs[0].Role = message.User
766		}
767	}
768	return msgs, nil
769}
770
771func (a *sessionAgent) generateTitle(ctx context.Context, session *session.Session, prompt string) {
772	if prompt == "" {
773		return
774	}
775
776	var maxOutput int64 = 40
777	if a.smallModel.CatwalkCfg.CanReason {
778		maxOutput = a.smallModel.CatwalkCfg.DefaultMaxTokens
779	}
780
781	agent := fantasy.NewAgent(a.smallModel.Model,
782		fantasy.WithSystemPrompt(string(titlePrompt)+"\n /no_think"),
783		fantasy.WithMaxOutputTokens(maxOutput),
784	)
785
786	resp, err := agent.Stream(ctx, fantasy.AgentStreamCall{
787		Prompt: fmt.Sprintf("Generate a concise title for the following content:\n\n%s\n <think>\n\n</think>", prompt),
788		PrepareStep: func(callContext context.Context, options fantasy.PrepareStepFunctionOptions) (_ context.Context, prepared fantasy.PrepareStepResult, err error) {
789			prepared.Messages = options.Messages
790			if a.systemPromptPrefix != "" {
791				prepared.Messages = append([]fantasy.Message{fantasy.NewSystemMessage(a.systemPromptPrefix)}, prepared.Messages...)
792			}
793			return callContext, prepared, nil
794		},
795	})
796	if err != nil {
797		slog.Error("error generating title", "err", err)
798		return
799	}
800
801	title := resp.Response.Content.Text()
802
803	title = strings.ReplaceAll(title, "\n", " ")
804
805	// Remove thinking tags if present.
806	if idx := strings.Index(title, "</think>"); idx > 0 {
807		title = title[idx+len("</think>"):]
808	}
809
810	title = strings.TrimSpace(title)
811	if title == "" {
812		slog.Warn("failed to generate title", "warn", "empty title")
813		return
814	}
815
816	session.Title = title
817
818	var openrouterCost *float64
819	for _, step := range resp.Steps {
820		stepCost := a.openrouterCost(step.ProviderMetadata)
821		if stepCost != nil {
822			newCost := *stepCost
823			if openrouterCost != nil {
824				newCost += *openrouterCost
825			}
826			openrouterCost = &newCost
827		}
828	}
829
830	a.updateSessionUsage(a.smallModel, session, resp.TotalUsage, openrouterCost)
831	_, saveErr := a.sessions.Save(ctx, *session)
832	if saveErr != nil {
833		slog.Error("failed to save session title & usage", "error", saveErr)
834		return
835	}
836}
837
838func (a *sessionAgent) openrouterCost(metadata fantasy.ProviderMetadata) *float64 {
839	openrouterMetadata, ok := metadata[openrouter.Name]
840	if !ok {
841		return nil
842	}
843
844	opts, ok := openrouterMetadata.(*openrouter.ProviderMetadata)
845	if !ok {
846		return nil
847	}
848	return &opts.Usage.Cost
849}
850
851func (a *sessionAgent) updateSessionUsage(model Model, session *session.Session, usage fantasy.Usage, overrideCost *float64) {
852	modelConfig := model.CatwalkCfg
853	cost := modelConfig.CostPer1MInCached/1e6*float64(usage.CacheCreationTokens) +
854		modelConfig.CostPer1MOutCached/1e6*float64(usage.CacheReadTokens) +
855		modelConfig.CostPer1MIn/1e6*float64(usage.InputTokens) +
856		modelConfig.CostPer1MOut/1e6*float64(usage.OutputTokens)
857
858	a.eventTokensUsed(session.ID, model, usage, cost)
859
860	if overrideCost != nil {
861		session.Cost += *overrideCost
862	} else {
863		session.Cost += cost
864	}
865
866	session.CompletionTokens = usage.OutputTokens + usage.CacheReadTokens
867	session.PromptTokens = usage.InputTokens + usage.CacheCreationTokens
868}
869
870func (a *sessionAgent) Cancel(sessionID string) {
871	// Cancel regular requests.
872	if cancel, ok := a.activeRequests.Take(sessionID); ok && cancel != nil {
873		slog.Info("Request cancellation initiated", "session_id", sessionID)
874		cancel()
875	}
876
877	// Also check for summarize requests.
878	if cancel, ok := a.activeRequests.Take(sessionID + "-summarize"); ok && cancel != nil {
879		slog.Info("Summarize cancellation initiated", "session_id", sessionID)
880		cancel()
881	}
882
883	if a.QueuedPrompts(sessionID) > 0 {
884		slog.Info("Clearing queued prompts", "session_id", sessionID)
885		a.messageQueue.Del(sessionID)
886	}
887}
888
889func (a *sessionAgent) ClearQueue(sessionID string) {
890	if a.QueuedPrompts(sessionID) > 0 {
891		slog.Info("Clearing queued prompts", "session_id", sessionID)
892		a.messageQueue.Del(sessionID)
893	}
894}
895
896func (a *sessionAgent) CancelAll() {
897	if !a.IsBusy() {
898		return
899	}
900	for key := range a.activeRequests.Seq2() {
901		a.Cancel(key) // key is sessionID
902	}
903
904	timeout := time.After(5 * time.Second)
905	for a.IsBusy() {
906		select {
907		case <-timeout:
908			return
909		default:
910			time.Sleep(200 * time.Millisecond)
911		}
912	}
913}
914
915func (a *sessionAgent) IsBusy() bool {
916	var busy bool
917	for cancelFunc := range a.activeRequests.Seq() {
918		if cancelFunc != nil {
919			busy = true
920			break
921		}
922	}
923	return busy
924}
925
926func (a *sessionAgent) IsSessionBusy(sessionID string) bool {
927	_, busy := a.activeRequests.Get(sessionID)
928	return busy
929}
930
931func (a *sessionAgent) QueuedPrompts(sessionID string) int {
932	l, ok := a.messageQueue.Get(sessionID)
933	if !ok {
934		return 0
935	}
936	return len(l)
937}
938
939func (a *sessionAgent) SetModels(large Model, small Model) {
940	a.largeModel = large
941	a.smallModel = small
942}
943
// SetTools replaces the tools offered to the model. The slice is stored as
// given, without copying.
func (a *sessionAgent) SetTools(tools []fantasy.AgentTool) {
	a.tools = tools
}
947
// Model returns the large model currently used by the agent.
func (a *sessionAgent) Model() Model {
	return a.largeModel
}
951
952func (a *sessionAgent) executeUserPromptSubmitHook(
953	ctx context.Context,
954	sessionID string,
955	prompt string,
956) ([]message.HookOutput, error) {
957	if a.hooks == nil {
958		return nil, nil
959	}
960	return a.hooks.Execute(ctx, hooks.HookContext{
961		EventType:  config.UserPromptSubmit,
962		SessionID:  sessionID,
963		UserPrompt: prompt,
964		Provider:   a.largeModel.ModelCfg.Provider,
965		Model:      a.largeModel.ModelCfg.Model,
966	})
967}