agent.go

  1package agent
  2
  3import (
  4	"context"
  5	_ "embed"
  6	"encoding/json"
  7	"errors"
  8	"fmt"
  9	"log/slog"
 10	"os"
 11	"strconv"
 12	"strings"
 13	"sync"
 14	"time"
 15
 16	"charm.land/fantasy"
 17	"charm.land/fantasy/providers/anthropic"
 18	"charm.land/fantasy/providers/bedrock"
 19	"charm.land/fantasy/providers/google"
 20	"charm.land/fantasy/providers/openai"
 21	"charm.land/fantasy/providers/openrouter"
 22	"github.com/charmbracelet/catwalk/pkg/catwalk"
 23	"github.com/charmbracelet/crush/internal/agent/tools"
 24	"github.com/charmbracelet/crush/internal/config"
 25	"github.com/charmbracelet/crush/internal/csync"
 26	"github.com/charmbracelet/crush/internal/hooks"
 27	"github.com/charmbracelet/crush/internal/message"
 28	"github.com/charmbracelet/crush/internal/permission"
 29	"github.com/charmbracelet/crush/internal/session"
 30)
 31
// titlePrompt is the system prompt given to the small model when
// generateTitle asks it for a session title.
//
//go:embed templates/title.md
var titlePrompt []byte

// summaryPrompt is the system prompt given to the large model when
// Summarize condenses a session's history.
//
//go:embed templates/summary.md
var summaryPrompt []byte
 37
 38type SessionAgentCall struct {
 39	SessionID        string
 40	Prompt           string
 41	ProviderOptions  fantasy.ProviderOptions
 42	Attachments      []message.Attachment
 43	MaxOutputTokens  int64
 44	Temperature      *float64
 45	TopP             *float64
 46	TopK             *int64
 47	FrequencyPenalty *float64
 48	PresencePenalty  *float64
 49}
 50
 51type SessionAgent interface {
 52	Run(context.Context, SessionAgentCall) (*fantasy.AgentResult, error)
 53	SetModels(large Model, small Model)
 54	SetTools(tools []fantasy.AgentTool)
 55	Cancel(sessionID string)
 56	CancelAll()
 57	IsSessionBusy(sessionID string) bool
 58	IsBusy() bool
 59	QueuedPrompts(sessionID string) int
 60	ClearQueue(sessionID string)
 61	Summarize(context.Context, string, fantasy.ProviderOptions) error
 62	Model() Model
 63}
 64
// Model bundles a language-model client with its catwalk catalog entry and
// the selected-model configuration.
type Model struct {
	// Model is the client used to issue requests.
	Model      fantasy.LanguageModel
	// CatwalkCfg is catalog metadata; this file reads its context window,
	// pricing, CanReason and DefaultMaxTokens fields.
	CatwalkCfg catwalk.Model
	// ModelCfg's Provider and Model strings are recorded on assistant
	// messages and passed to hooks.
	ModelCfg   config.SelectedModel
}
 70
 71type sessionAgent struct {
 72	largeModel           Model
 73	smallModel           Model
 74	systemPromptPrefix   string
 75	systemPrompt         string
 76	tools                []fantasy.AgentTool
 77	sessions             session.Service
 78	messages             message.Service
 79	disableAutoSummarize bool
 80	isYolo               bool
 81	hooks                *hooks.Executor
 82
 83	messageQueue   *csync.Map[string, []SessionAgentCall]
 84	activeRequests *csync.Map[string, context.CancelFunc]
 85}
 86
 87type SessionAgentOptions struct {
 88	LargeModel           Model
 89	SmallModel           Model
 90	SystemPromptPrefix   string
 91	SystemPrompt         string
 92	DisableAutoSummarize bool
 93	IsYolo               bool
 94	Sessions             session.Service
 95	Messages             message.Service
 96	Tools                []fantasy.AgentTool
 97	Hooks                *hooks.Executor
 98}
 99
100func NewSessionAgent(
101	opts SessionAgentOptions,
102) SessionAgent {
103	return &sessionAgent{
104		largeModel:           opts.LargeModel,
105		smallModel:           opts.SmallModel,
106		systemPromptPrefix:   opts.SystemPromptPrefix,
107		systemPrompt:         opts.SystemPrompt,
108		sessions:             opts.Sessions,
109		messages:             opts.Messages,
110		disableAutoSummarize: opts.DisableAutoSummarize,
111		tools:                opts.Tools,
112		isYolo:               opts.IsYolo,
113		hooks:                opts.Hooks,
114		messageQueue:         csync.NewMap[string, []SessionAgentCall](),
115		activeRequests:       csync.NewMap[string, context.CancelFunc](),
116	}
117}
118
119func (a *sessionAgent) Run(ctx context.Context, call SessionAgentCall) (*fantasy.AgentResult, error) {
120	if call.Prompt == "" {
121		return nil, ErrEmptyPrompt
122	}
123	if call.SessionID == "" {
124		return nil, ErrSessionMissing
125	}
126
127	// Queue the message if busy
128	if a.IsSessionBusy(call.SessionID) {
129		existing, ok := a.messageQueue.Get(call.SessionID)
130		if !ok {
131			existing = []SessionAgentCall{}
132		}
133		existing = append(existing, call)
134		a.messageQueue.Set(call.SessionID, existing)
135		return nil, nil
136	}
137
138	if len(a.tools) > 0 {
139		// add anthropic caching to the last tool
140		a.tools[len(a.tools)-1].SetProviderOptions(a.getCacheControlOptions())
141	}
142
143	agent := fantasy.NewAgent(
144		a.largeModel.Model,
145		fantasy.WithSystemPrompt(a.systemPrompt),
146		fantasy.WithTools(a.tools...),
147	)
148
149	sessionLock := sync.Mutex{}
150	currentSession, err := a.sessions.Get(ctx, call.SessionID)
151	if err != nil {
152		return nil, fmt.Errorf("failed to get session: %w", err)
153	}
154
155	msgs, err := a.getSessionMessages(ctx, currentSession)
156	if err != nil {
157		return nil, fmt.Errorf("failed to get session messages: %w", err)
158	}
159
160	var wg sync.WaitGroup
161	// Generate title if first message
162	if len(msgs) == 0 {
163		wg.Go(func() {
164			sessionLock.Lock()
165			a.generateTitle(ctx, &currentSession, call.Prompt)
166			sessionLock.Unlock()
167		})
168	}
169
170	// Add the user message to the session
171	_, err = a.createUserMessage(ctx, call)
172	if err != nil {
173		return nil, err
174	}
175
176	// Execute UserPromptSubmit hook
177	if a.hooks != nil {
178		if err := a.hooks.Execute(ctx, hooks.HookContext{
179			EventType:  config.UserPromptSubmit,
180			SessionID:  call.SessionID,
181			UserPrompt: call.Prompt,
182			Provider:   a.largeModel.ModelCfg.Provider,
183			Model:      a.largeModel.ModelCfg.Model,
184		}); err != nil {
185			slog.Debug("user_prompt_submit hook execution failed", "error", err)
186		}
187	}
188
189	// add the session to the context
190	ctx = context.WithValue(ctx, tools.SessionIDContextKey, call.SessionID)
191
192	genCtx, cancel := context.WithCancel(ctx)
193	a.activeRequests.Set(call.SessionID, cancel)
194
195	defer cancel()
196	defer a.activeRequests.Del(call.SessionID)
197
198	history, files := a.preparePrompt(msgs, call.Attachments...)
199
200	startTime := time.Now()
201	a.eventPromptSent(call.SessionID)
202
203	var currentAssistant *message.Message
204	var shouldSummarize bool
205	result, err := agent.Stream(genCtx, fantasy.AgentStreamCall{
206		Prompt:           call.Prompt,
207		Files:            files,
208		Messages:         history,
209		ProviderOptions:  call.ProviderOptions,
210		MaxOutputTokens:  &call.MaxOutputTokens,
211		TopP:             call.TopP,
212		Temperature:      call.Temperature,
213		PresencePenalty:  call.PresencePenalty,
214		TopK:             call.TopK,
215		FrequencyPenalty: call.FrequencyPenalty,
216		// Before each step create the new assistant message
217		PrepareStep: func(callContext context.Context, options fantasy.PrepareStepFunctionOptions) (_ context.Context, prepared fantasy.PrepareStepResult, err error) {
218			prepared.Messages = options.Messages
219			// reset all cached items
220			for i := range prepared.Messages {
221				prepared.Messages[i].ProviderOptions = nil
222			}
223
224			queuedCalls, _ := a.messageQueue.Get(call.SessionID)
225			a.messageQueue.Del(call.SessionID)
226			for _, queued := range queuedCalls {
227				userMessage, createErr := a.createUserMessage(callContext, queued)
228				if createErr != nil {
229					return callContext, prepared, createErr
230				}
231				prepared.Messages = append(prepared.Messages, userMessage.ToAIMessage()...)
232			}
233
234			lastSystemRoleInx := 0
235			systemMessageUpdated := false
236			for i, msg := range prepared.Messages {
237				// only add cache control to the last message
238				if msg.Role == fantasy.MessageRoleSystem {
239					lastSystemRoleInx = i
240				} else if !systemMessageUpdated {
241					prepared.Messages[lastSystemRoleInx].ProviderOptions = a.getCacheControlOptions()
242					systemMessageUpdated = true
243				}
244				// than add cache control to the last 2 messages
245				if i > len(prepared.Messages)-3 {
246					prepared.Messages[i].ProviderOptions = a.getCacheControlOptions()
247				}
248			}
249
250			if a.systemPromptPrefix != "" {
251				prepared.Messages = append([]fantasy.Message{fantasy.NewSystemMessage(a.systemPromptPrefix)}, prepared.Messages...)
252			}
253
254			var assistantMsg message.Message
255			assistantMsg, err = a.messages.Create(callContext, call.SessionID, message.CreateMessageParams{
256				Role:     message.Assistant,
257				Parts:    []message.ContentPart{},
258				Model:    a.largeModel.ModelCfg.Model,
259				Provider: a.largeModel.ModelCfg.Provider,
260			})
261			if err != nil {
262				return callContext, prepared, err
263			}
264			callContext = context.WithValue(callContext, tools.MessageIDContextKey, assistantMsg.ID)
265			currentAssistant = &assistantMsg
266			return callContext, prepared, err
267		},
268		OnReasoningStart: func(id string, reasoning fantasy.ReasoningContent) error {
269			currentAssistant.AppendReasoningContent(reasoning.Text)
270			return a.messages.Update(genCtx, *currentAssistant)
271		},
272		OnReasoningDelta: func(id string, text string) error {
273			currentAssistant.AppendReasoningContent(text)
274			return a.messages.Update(genCtx, *currentAssistant)
275		},
276		OnReasoningEnd: func(id string, reasoning fantasy.ReasoningContent) error {
277			// handle anthropic signature
278			if anthropicData, ok := reasoning.ProviderMetadata[anthropic.Name]; ok {
279				if reasoning, ok := anthropicData.(*anthropic.ReasoningOptionMetadata); ok {
280					currentAssistant.AppendReasoningSignature(reasoning.Signature)
281				}
282			}
283			if googleData, ok := reasoning.ProviderMetadata[google.Name]; ok {
284				if reasoning, ok := googleData.(*google.ReasoningMetadata); ok {
285					currentAssistant.AppendReasoningSignature(reasoning.Signature)
286				}
287			}
288			if openaiData, ok := reasoning.ProviderMetadata[openai.Name]; ok {
289				if reasoning, ok := openaiData.(*openai.ResponsesReasoningMetadata); ok {
290					currentAssistant.SetReasoningResponsesData(reasoning)
291				}
292			}
293			currentAssistant.FinishThinking()
294			return a.messages.Update(genCtx, *currentAssistant)
295		},
296		OnTextDelta: func(id string, text string) error {
297			currentAssistant.AppendContent(text)
298			return a.messages.Update(genCtx, *currentAssistant)
299		},
300		OnToolInputStart: func(id string, toolName string) error {
301			toolCall := message.ToolCall{
302				ID:               id,
303				Name:             toolName,
304				ProviderExecuted: false,
305				Finished:         false,
306			}
307			currentAssistant.AddToolCall(toolCall)
308			return a.messages.Update(genCtx, *currentAssistant)
309		},
310		OnRetry: func(err *fantasy.APICallError, delay time.Duration) {
311			// TODO: implement
312		},
313		OnToolCall: func(tc fantasy.ToolCallContent) error {
314			// Execute PreToolUse hook - blocks tool execution on error
315			if a.hooks != nil {
316				toolInput := make(map[string]any)
317				if err := json.Unmarshal([]byte(tc.Input), &toolInput); err != nil {
318					slog.Warn("Failed to unmarshal tool input for PreToolUse hook", "error", err, "tool", tc.ToolName)
319				}
320				if err := a.hooks.Execute(genCtx, hooks.HookContext{
321					EventType: config.PreToolUse,
322					SessionID: call.SessionID,
323					ToolName:  tc.ToolName,
324					ToolInput: toolInput,
325					MessageID: currentAssistant.ID,
326					Provider:  a.largeModel.ModelCfg.Provider,
327					Model:     a.largeModel.ModelCfg.Model,
328				}); err != nil {
329					return fmt.Errorf("PreToolUse hook blocked tool execution: %w", err)
330				}
331			}
332
333			toolCall := message.ToolCall{
334				ID:               tc.ToolCallID,
335				Name:             tc.ToolName,
336				Input:            tc.Input,
337				ProviderExecuted: false,
338				Finished:         true,
339			}
340			currentAssistant.AddToolCall(toolCall)
341			return a.messages.Update(genCtx, *currentAssistant)
342		},
343		OnToolResult: func(result fantasy.ToolResultContent) error {
344			var resultContent string
345			isError := false
346			switch result.Result.GetType() {
347			case fantasy.ToolResultContentTypeText:
348				r, ok := fantasy.AsToolResultOutputType[fantasy.ToolResultOutputContentText](result.Result)
349				if ok {
350					resultContent = r.Text
351				}
352			case fantasy.ToolResultContentTypeError:
353				r, ok := fantasy.AsToolResultOutputType[fantasy.ToolResultOutputContentError](result.Result)
354				if ok {
355					isError = true
356					resultContent = r.Error.Error()
357				}
358			case fantasy.ToolResultContentTypeMedia:
359				// TODO: handle this message type
360			}
361
362			// Execute PostToolUse hook
363			if a.hooks != nil {
364				toolInput := make(map[string]any)
365				// Try to get tool input from the assistant message
366				toolCalls := currentAssistant.ToolCalls()
367				for _, tc := range toolCalls {
368					if tc.ID == result.ToolCallID {
369						if err := json.Unmarshal([]byte(tc.Input), &toolInput); err != nil {
370							slog.Debug("Failed to unmarshal tool input for PostToolUse hook", "error", err, "tool", result.ToolName)
371						}
372						break
373					}
374				}
375
376				if err := a.hooks.Execute(genCtx, hooks.HookContext{
377					EventType:  config.PostToolUse,
378					SessionID:  call.SessionID,
379					ToolName:   result.ToolName,
380					ToolInput:  toolInput,
381					ToolResult: resultContent,
382					ToolError:  isError,
383					MessageID:  currentAssistant.ID,
384					Provider:   a.largeModel.ModelCfg.Provider,
385					Model:      a.largeModel.ModelCfg.Model,
386				}); err != nil {
387					slog.Debug("post_tool_use hook execution failed", "error", err)
388				}
389			}
390
391			toolResult := message.ToolResult{
392				ToolCallID: result.ToolCallID,
393				Name:       result.ToolName,
394				Content:    resultContent,
395				IsError:    isError,
396				Metadata:   result.ClientMetadata,
397			}
398			_, createMsgErr := a.messages.Create(genCtx, currentAssistant.SessionID, message.CreateMessageParams{
399				Role: message.Tool,
400				Parts: []message.ContentPart{
401					toolResult,
402				},
403			})
404			if createMsgErr != nil {
405				return createMsgErr
406			}
407			return nil
408		},
409		OnStepFinish: func(stepResult fantasy.StepResult) error {
410			finishReason := message.FinishReasonUnknown
411			switch stepResult.FinishReason {
412			case fantasy.FinishReasonLength:
413				finishReason = message.FinishReasonMaxTokens
414			case fantasy.FinishReasonStop:
415				finishReason = message.FinishReasonEndTurn
416			case fantasy.FinishReasonToolCalls:
417				finishReason = message.FinishReasonToolUse
418			}
419			currentAssistant.AddFinish(finishReason, "", "")
420			a.updateSessionUsage(a.largeModel, &currentSession, stepResult.Usage, a.openrouterCost(stepResult.ProviderMetadata))
421			sessionLock.Lock()
422			_, sessionErr := a.sessions.Save(genCtx, currentSession)
423			sessionLock.Unlock()
424			if sessionErr != nil {
425				return sessionErr
426			}
427			return a.messages.Update(genCtx, *currentAssistant)
428		},
429		StopWhen: []fantasy.StopCondition{
430			func(_ []fantasy.StepResult) bool {
431				cw := int64(a.largeModel.CatwalkCfg.ContextWindow)
432				tokens := currentSession.CompletionTokens + currentSession.PromptTokens
433				remaining := cw - tokens
434				var threshold int64
435				if cw > 200_000 {
436					threshold = 20_000
437				} else {
438					threshold = int64(float64(cw) * 0.2)
439				}
440				if (remaining <= threshold) && !a.disableAutoSummarize {
441					shouldSummarize = true
442					return true
443				}
444				return false
445			},
446		},
447	})
448
449	a.eventPromptResponded(call.SessionID, time.Since(startTime).Truncate(time.Second))
450
451	if err != nil {
452		isCancelErr := errors.Is(err, context.Canceled)
453		isPermissionErr := errors.Is(err, permission.ErrorPermissionDenied)
454		if currentAssistant == nil {
455			return result, err
456		}
457		// Ensure we finish thinking on error to close the reasoning state
458		currentAssistant.FinishThinking()
459		toolCalls := currentAssistant.ToolCalls()
460		// INFO: we use the parent context here because the genCtx has been cancelled
461		msgs, createErr := a.messages.List(ctx, currentAssistant.SessionID)
462		if createErr != nil {
463			return nil, createErr
464		}
465		for _, tc := range toolCalls {
466			if !tc.Finished {
467				tc.Finished = true
468				tc.Input = "{}"
469				currentAssistant.AddToolCall(tc)
470				updateErr := a.messages.Update(ctx, *currentAssistant)
471				if updateErr != nil {
472					return nil, updateErr
473				}
474			}
475
476			found := false
477			for _, msg := range msgs {
478				if msg.Role == message.Tool {
479					for _, tr := range msg.ToolResults() {
480						if tr.ToolCallID == tc.ID {
481							found = true
482							break
483						}
484					}
485				}
486				if found {
487					break
488				}
489			}
490			if found {
491				continue
492			}
493			content := "There was an error while executing the tool"
494			if isCancelErr {
495				content = "Tool execution canceled by user"
496			} else if isPermissionErr {
497				content = "Permission denied"
498			}
499			toolResult := message.ToolResult{
500				ToolCallID: tc.ID,
501				Name:       tc.Name,
502				Content:    content,
503				IsError:    true,
504			}
505			_, createErr = a.messages.Create(context.Background(), currentAssistant.SessionID, message.CreateMessageParams{
506				Role: message.Tool,
507				Parts: []message.ContentPart{
508					toolResult,
509				},
510			})
511			if createErr != nil {
512				return nil, createErr
513			}
514		}
515		if isCancelErr {
516			currentAssistant.AddFinish(message.FinishReasonCanceled, "Request cancelled", "")
517		} else if isPermissionErr {
518			currentAssistant.AddFinish(message.FinishReasonPermissionDenied, "Permission denied", "")
519		} else {
520			currentAssistant.AddFinish(message.FinishReasonError, "API Error", err.Error())
521		}
522		// INFO: we use the parent context here because the genCtx has been cancelled
523		updateErr := a.messages.Update(ctx, *currentAssistant)
524		if updateErr != nil {
525			return nil, updateErr
526		}
527		return nil, err
528	}
529	wg.Wait()
530
531	// Execute Stop hook
532	if a.hooks != nil && result != nil {
533		var totalTokens, inputTokens int64
534		for _, step := range result.Steps {
535			totalTokens += step.Usage.TotalTokens
536			inputTokens += step.Usage.InputTokens
537		}
538
539		if err := a.hooks.Execute(ctx, hooks.HookContext{
540			EventType:   config.Stop,
541			SessionID:   call.SessionID,
542			MessageID:   currentAssistant.ID,
543			Provider:    a.largeModel.ModelCfg.Provider,
544			Model:       a.largeModel.ModelCfg.Model,
545			TokensUsed:  totalTokens,
546			TokensInput: inputTokens,
547		}); err != nil {
548			slog.Debug("stop hook execution failed", "error", err)
549		}
550	}
551
552	if shouldSummarize {
553		a.activeRequests.Del(call.SessionID)
554		if summarizeErr := a.Summarize(genCtx, call.SessionID, call.ProviderOptions); summarizeErr != nil {
555			return nil, summarizeErr
556		}
557		// if the agent was not done...
558		if len(currentAssistant.ToolCalls()) > 0 {
559			existing, ok := a.messageQueue.Get(call.SessionID)
560			if !ok {
561				existing = []SessionAgentCall{}
562			}
563			call.Prompt = fmt.Sprintf("The previous session was interrupted because it got too long, the initial user request was: `%s`", call.Prompt)
564			existing = append(existing, call)
565			a.messageQueue.Set(call.SessionID, existing)
566		}
567	}
568
569	// release active request before processing queued messages
570	a.activeRequests.Del(call.SessionID)
571	cancel()
572
573	queuedMessages, ok := a.messageQueue.Get(call.SessionID)
574	if !ok || len(queuedMessages) == 0 {
575		return result, err
576	}
577	// there are queued messages restart the loop
578	firstQueuedMessage := queuedMessages[0]
579	a.messageQueue.Set(call.SessionID, queuedMessages[1:])
580	return a.Run(ctx, firstQueuedMessage)
581}
582
583func (a *sessionAgent) Summarize(ctx context.Context, sessionID string, opts fantasy.ProviderOptions) error {
584	if a.IsSessionBusy(sessionID) {
585		return ErrSessionBusy
586	}
587
588	currentSession, err := a.sessions.Get(ctx, sessionID)
589	if err != nil {
590		return fmt.Errorf("failed to get session: %w", err)
591	}
592	msgs, err := a.getSessionMessages(ctx, currentSession)
593	if err != nil {
594		return err
595	}
596	if len(msgs) == 0 {
597		// nothing to summarize
598		return nil
599	}
600
601	// Execute PreCompact hook
602	if a.hooks != nil {
603		if err := a.hooks.Execute(ctx, hooks.HookContext{
604			EventType: config.PreCompact,
605			SessionID: sessionID,
606			Provider:  a.largeModel.ModelCfg.Provider,
607			Model:     a.largeModel.ModelCfg.Model,
608		}); err != nil {
609			slog.Debug("pre_compact hook execution failed", "error", err)
610		}
611	}
612
613	aiMsgs, _ := a.preparePrompt(msgs)
614
615	genCtx, cancel := context.WithCancel(ctx)
616	a.activeRequests.Set(sessionID, cancel)
617	defer a.activeRequests.Del(sessionID)
618	defer cancel()
619
620	agent := fantasy.NewAgent(a.largeModel.Model,
621		fantasy.WithSystemPrompt(string(summaryPrompt)),
622	)
623	summaryMessage, err := a.messages.Create(ctx, sessionID, message.CreateMessageParams{
624		Role:             message.Assistant,
625		Model:            a.largeModel.Model.Model(),
626		Provider:         a.largeModel.Model.Provider(),
627		IsSummaryMessage: true,
628	})
629	if err != nil {
630		return err
631	}
632
633	resp, err := agent.Stream(genCtx, fantasy.AgentStreamCall{
634		Prompt:          "Provide a detailed summary of our conversation above.",
635		Messages:        aiMsgs,
636		ProviderOptions: opts,
637		PrepareStep: func(callContext context.Context, options fantasy.PrepareStepFunctionOptions) (_ context.Context, prepared fantasy.PrepareStepResult, err error) {
638			prepared.Messages = options.Messages
639			if a.systemPromptPrefix != "" {
640				prepared.Messages = append([]fantasy.Message{fantasy.NewSystemMessage(a.systemPromptPrefix)}, prepared.Messages...)
641			}
642			return callContext, prepared, nil
643		},
644		OnReasoningDelta: func(id string, text string) error {
645			summaryMessage.AppendReasoningContent(text)
646			return a.messages.Update(genCtx, summaryMessage)
647		},
648		OnReasoningEnd: func(id string, reasoning fantasy.ReasoningContent) error {
649			// handle anthropic signature
650			if anthropicData, ok := reasoning.ProviderMetadata["anthropic"]; ok {
651				if signature, ok := anthropicData.(*anthropic.ReasoningOptionMetadata); ok && signature.Signature != "" {
652					summaryMessage.AppendReasoningSignature(signature.Signature)
653				}
654			}
655			summaryMessage.FinishThinking()
656			return a.messages.Update(genCtx, summaryMessage)
657		},
658		OnTextDelta: func(id, text string) error {
659			summaryMessage.AppendContent(text)
660			return a.messages.Update(genCtx, summaryMessage)
661		},
662	})
663	if err != nil {
664		isCancelErr := errors.Is(err, context.Canceled)
665		if isCancelErr {
666			// User cancelled summarize we need to remove the summary message
667			deleteErr := a.messages.Delete(ctx, summaryMessage.ID)
668			return deleteErr
669		}
670		return err
671	}
672
673	summaryMessage.AddFinish(message.FinishReasonEndTurn, "", "")
674	err = a.messages.Update(genCtx, summaryMessage)
675	if err != nil {
676		return err
677	}
678
679	var openrouterCost *float64
680	for _, step := range resp.Steps {
681		stepCost := a.openrouterCost(step.ProviderMetadata)
682		if stepCost != nil {
683			newCost := *stepCost
684			if openrouterCost != nil {
685				newCost += *openrouterCost
686			}
687			openrouterCost = &newCost
688		}
689	}
690
691	a.updateSessionUsage(a.largeModel, &currentSession, resp.TotalUsage, openrouterCost)
692
693	// just in case get just the last usage
694	usage := resp.Response.Usage
695	currentSession.SummaryMessageID = summaryMessage.ID
696	currentSession.CompletionTokens = usage.OutputTokens
697	currentSession.PromptTokens = 0
698	_, err = a.sessions.Save(genCtx, currentSession)
699	return err
700}
701
702func (a *sessionAgent) getCacheControlOptions() fantasy.ProviderOptions {
703	if t, _ := strconv.ParseBool(os.Getenv("CRUSH_DISABLE_ANTHROPIC_CACHE")); t {
704		return fantasy.ProviderOptions{}
705	}
706	return fantasy.ProviderOptions{
707		anthropic.Name: &anthropic.ProviderCacheControlOptions{
708			CacheControl: anthropic.CacheControl{Type: "ephemeral"},
709		},
710		bedrock.Name: &anthropic.ProviderCacheControlOptions{
711			CacheControl: anthropic.CacheControl{Type: "ephemeral"},
712		},
713	}
714}
715
716func (a *sessionAgent) createUserMessage(ctx context.Context, call SessionAgentCall) (message.Message, error) {
717	var attachmentParts []message.ContentPart
718	for _, attachment := range call.Attachments {
719		attachmentParts = append(attachmentParts, message.BinaryContent{Path: attachment.FilePath, MIMEType: attachment.MimeType, Data: attachment.Content})
720	}
721	parts := []message.ContentPart{message.TextContent{Text: call.Prompt}}
722	parts = append(parts, attachmentParts...)
723	msg, err := a.messages.Create(ctx, call.SessionID, message.CreateMessageParams{
724		Role:  message.User,
725		Parts: parts,
726	})
727	if err != nil {
728		return message.Message{}, fmt.Errorf("failed to create user message: %w", err)
729	}
730	return msg, nil
731}
732
733func (a *sessionAgent) preparePrompt(msgs []message.Message, attachments ...message.Attachment) ([]fantasy.Message, []fantasy.FilePart) {
734	var history []fantasy.Message
735	for _, m := range msgs {
736		if len(m.Parts) == 0 {
737			continue
738		}
739		// Assistant message without content or tool calls (cancelled before it returned anything)
740		if m.Role == message.Assistant && len(m.ToolCalls()) == 0 && m.Content().Text == "" && m.ReasoningContent().String() == "" {
741			continue
742		}
743		history = append(history, m.ToAIMessage()...)
744	}
745
746	var files []fantasy.FilePart
747	for _, attachment := range attachments {
748		files = append(files, fantasy.FilePart{
749			Filename:  attachment.FileName,
750			Data:      attachment.Content,
751			MediaType: attachment.MimeType,
752		})
753	}
754
755	return history, files
756}
757
758func (a *sessionAgent) getSessionMessages(ctx context.Context, session session.Session) ([]message.Message, error) {
759	msgs, err := a.messages.List(ctx, session.ID)
760	if err != nil {
761		return nil, fmt.Errorf("failed to list messages: %w", err)
762	}
763
764	if session.SummaryMessageID != "" {
765		summaryMsgInex := -1
766		for i, msg := range msgs {
767			if msg.ID == session.SummaryMessageID {
768				summaryMsgInex = i
769				break
770			}
771		}
772		if summaryMsgInex != -1 {
773			msgs = msgs[summaryMsgInex:]
774			msgs[0].Role = message.User
775		}
776	}
777	return msgs, nil
778}
779
780func (a *sessionAgent) generateTitle(ctx context.Context, session *session.Session, prompt string) {
781	if prompt == "" {
782		return
783	}
784
785	var maxOutput int64 = 40
786	if a.smallModel.CatwalkCfg.CanReason {
787		maxOutput = a.smallModel.CatwalkCfg.DefaultMaxTokens
788	}
789
790	agent := fantasy.NewAgent(a.smallModel.Model,
791		fantasy.WithSystemPrompt(string(titlePrompt)+"\n /no_think"),
792		fantasy.WithMaxOutputTokens(maxOutput),
793	)
794
795	resp, err := agent.Stream(ctx, fantasy.AgentStreamCall{
796		Prompt: fmt.Sprintf("Generate a concise title for the following content:\n\n%s\n <think>\n\n</think>", prompt),
797		PrepareStep: func(callContext context.Context, options fantasy.PrepareStepFunctionOptions) (_ context.Context, prepared fantasy.PrepareStepResult, err error) {
798			prepared.Messages = options.Messages
799			if a.systemPromptPrefix != "" {
800				prepared.Messages = append([]fantasy.Message{fantasy.NewSystemMessage(a.systemPromptPrefix)}, prepared.Messages...)
801			}
802			return callContext, prepared, nil
803		},
804	})
805	if err != nil {
806		slog.Error("error generating title", "err", err)
807		return
808	}
809
810	title := resp.Response.Content.Text()
811
812	title = strings.ReplaceAll(title, "\n", " ")
813
814	// remove thinking tags if present
815	if idx := strings.Index(title, "</think>"); idx > 0 {
816		title = title[idx+len("</think>"):]
817	}
818
819	title = strings.TrimSpace(title)
820	if title == "" {
821		slog.Warn("failed to generate title", "warn", "empty title")
822		return
823	}
824
825	session.Title = title
826
827	var openrouterCost *float64
828	for _, step := range resp.Steps {
829		stepCost := a.openrouterCost(step.ProviderMetadata)
830		if stepCost != nil {
831			newCost := *stepCost
832			if openrouterCost != nil {
833				newCost += *openrouterCost
834			}
835			openrouterCost = &newCost
836		}
837	}
838
839	a.updateSessionUsage(a.smallModel, session, resp.TotalUsage, openrouterCost)
840	_, saveErr := a.sessions.Save(ctx, *session)
841	if saveErr != nil {
842		slog.Error("failed to save session title & usage", "error", saveErr)
843		return
844	}
845}
846
847func (a *sessionAgent) openrouterCost(metadata fantasy.ProviderMetadata) *float64 {
848	openrouterMetadata, ok := metadata[openrouter.Name]
849	if !ok {
850		return nil
851	}
852
853	opts, ok := openrouterMetadata.(*openrouter.ProviderMetadata)
854	if !ok {
855		return nil
856	}
857	return &opts.Usage.Cost
858}
859
860func (a *sessionAgent) updateSessionUsage(model Model, session *session.Session, usage fantasy.Usage, overrideCost *float64) {
861	modelConfig := model.CatwalkCfg
862	cost := modelConfig.CostPer1MInCached/1e6*float64(usage.CacheCreationTokens) +
863		modelConfig.CostPer1MOutCached/1e6*float64(usage.CacheReadTokens) +
864		modelConfig.CostPer1MIn/1e6*float64(usage.InputTokens) +
865		modelConfig.CostPer1MOut/1e6*float64(usage.OutputTokens)
866
867	a.eventTokensUsed(session.ID, model, usage, cost)
868
869	if overrideCost != nil {
870		session.Cost += *overrideCost
871	} else {
872		session.Cost += cost
873	}
874
875	session.CompletionTokens = usage.OutputTokens + usage.CacheReadTokens
876	session.PromptTokens = usage.InputTokens + usage.CacheCreationTokens
877}
878
879func (a *sessionAgent) Cancel(sessionID string) {
880	// Cancel regular requests
881	if cancel, ok := a.activeRequests.Take(sessionID); ok && cancel != nil {
882		slog.Info("Request cancellation initiated", "session_id", sessionID)
883		cancel()
884	}
885
886	// Also check for summarize requests
887	if cancel, ok := a.activeRequests.Take(sessionID + "-summarize"); ok && cancel != nil {
888		slog.Info("Summarize cancellation initiated", "session_id", sessionID)
889		cancel()
890	}
891
892	if a.QueuedPrompts(sessionID) > 0 {
893		slog.Info("Clearing queued prompts", "session_id", sessionID)
894		a.messageQueue.Del(sessionID)
895	}
896}
897
898func (a *sessionAgent) ClearQueue(sessionID string) {
899	if a.QueuedPrompts(sessionID) > 0 {
900		slog.Info("Clearing queued prompts", "session_id", sessionID)
901		a.messageQueue.Del(sessionID)
902	}
903}
904
905func (a *sessionAgent) CancelAll() {
906	if !a.IsBusy() {
907		return
908	}
909	for key := range a.activeRequests.Seq2() {
910		a.Cancel(key) // key is sessionID
911	}
912
913	timeout := time.After(5 * time.Second)
914	for a.IsBusy() {
915		select {
916		case <-timeout:
917			return
918		default:
919			time.Sleep(200 * time.Millisecond)
920		}
921	}
922}
923
924func (a *sessionAgent) IsBusy() bool {
925	var busy bool
926	for cancelFunc := range a.activeRequests.Seq() {
927		if cancelFunc != nil {
928			busy = true
929			break
930		}
931	}
932	return busy
933}
934
935func (a *sessionAgent) IsSessionBusy(sessionID string) bool {
936	_, busy := a.activeRequests.Get(sessionID)
937	return busy
938}
939
940func (a *sessionAgent) QueuedPrompts(sessionID string) int {
941	l, ok := a.messageQueue.Get(sessionID)
942	if !ok {
943		return 0
944	}
945	return len(l)
946}
947
948func (a *sessionAgent) SetModels(large Model, small Model) {
949	a.largeModel = large
950	a.smallModel = small
951}
952
// SetTools replaces the tools offered to the model. The slice is stored as
// given, not copied; Run sets provider options on its last element.
func (a *sessionAgent) SetTools(tools []fantasy.AgentTool) {
	a.tools = tools
}
956
// Model returns the agent's current large model.
func (a *sessionAgent) Model() Model {
	return a.largeModel
}