agent.go

  1package agent
  2
  3import (
  4	"context"
  5	_ "embed"
  6	"encoding/json"
  7	"errors"
  8	"fmt"
  9	"log/slog"
 10	"os"
 11	"strconv"
 12	"strings"
 13	"sync"
 14	"time"
 15
 16	"charm.land/fantasy"
 17	"charm.land/fantasy/providers/anthropic"
 18	"charm.land/fantasy/providers/bedrock"
 19	"charm.land/fantasy/providers/google"
 20	"charm.land/fantasy/providers/openai"
 21	"charm.land/fantasy/providers/openrouter"
 22	"github.com/charmbracelet/catwalk/pkg/catwalk"
 23	"github.com/charmbracelet/crush/internal/agent/tools"
 24	"github.com/charmbracelet/crush/internal/config"
 25	"github.com/charmbracelet/crush/internal/csync"
 26	"github.com/charmbracelet/crush/internal/hooks"
 27	"github.com/charmbracelet/crush/internal/message"
 28	"github.com/charmbracelet/crush/internal/permission"
 29	"github.com/charmbracelet/crush/internal/session"
 30)
 31
// titlePrompt is the system prompt used by generateTitle to ask the small
// model for a session title (embedded from templates/title.md).
//
//go:embed templates/title.md
var titlePrompt []byte

// summaryPrompt is the system prompt used by Summarize to ask the large model
// for a conversation summary (embedded from templates/summary.md).
//
//go:embed templates/summary.md
var summaryPrompt []byte
 37
// SessionAgentCall describes a single prompt submitted to a SessionAgent.
//
// SessionID and Prompt are required. Run rejects an empty Prompt with
// ErrEmptyPrompt and an empty SessionID with ErrSessionMissing.
//
// Fields:
//   - ProviderOptions is passed through to the model stream. When an automatic
//     summary is triggered, it is also passed to Summarize.
//   - Attachments are stored on the persisted user message and sent to the
//     model as file parts.
//   - MaxOutputTokens is always forwarded to the stream by address, including
//     when it is zero. NOTE(review): confirm how providers treat 0.
//   - Temperature, TopP, TopK, FrequencyPenalty and PresencePenalty are
//     forwarded as-is. Presumably nil leaves the option at the provider
//     default; verify against fantasy.
type SessionAgentCall struct {
	SessionID        string
	Prompt           string
	ProviderOptions  fantasy.ProviderOptions
	Attachments      []message.Attachment
	MaxOutputTokens  int64
	Temperature      *float64
	TopP             *float64
	TopK             *int64
	FrequencyPenalty *float64
	PresencePenalty  *float64
}
 50
// SessionAgent runs prompts against a language model on behalf of chat
// sessions. It persists the conversation through the session and message
// services. Prompts submitted while a session already has an active request
// are queued rather than run concurrently.
type SessionAgent interface {
	// Run processes the call. If the session already has an active request,
	// Run queues the call instead and returns (nil, nil).
	Run(context.Context, SessionAgentCall) (*fantasy.AgentResult, error)
	// SetModels replaces the large (conversation) and small (title) models.
	SetModels(large Model, small Model)
	// SetTools replaces the tools offered to the model on subsequent runs.
	SetTools(tools []fantasy.AgentTool)
	// Cancel aborts the active request for sessionID and drops its queue.
	Cancel(sessionID string)
	// CancelAll cancels every active request and waits, up to a timeout, for
	// the agent to become idle.
	CancelAll()
	// IsSessionBusy reports whether sessionID has an active request.
	IsSessionBusy(sessionID string) bool
	// IsBusy reports whether any session has an active request.
	IsBusy() bool
	// QueuedPrompts returns the number of prompts waiting for sessionID.
	QueuedPrompts(sessionID string) int
	// ClearQueue drops the prompts waiting for sessionID.
	ClearQueue(sessionID string)
	// Summarize stores a model-written summary of the session. Later prompts
	// start from that summary instead of the full history.
	Summarize(context.Context, string, fantasy.ProviderOptions) error
	// Model returns the large model used for conversations.
	Model() Model
}
 64
// Model bundles a language model with the configuration describing it.
//
// Fields:
//   - Model is the fantasy model that generates responses.
//   - CatwalkCfg supplies model metadata read by this package: the context
//     window, per-million-token prices, CanReason and DefaultMaxTokens.
//   - ModelCfg is the user's selection. Its Provider and Model identifiers
//     are recorded on messages and passed to hooks.
type Model struct {
	Model      fantasy.LanguageModel
	CatwalkCfg catwalk.Model
	ModelCfg   config.SelectedModel
}
 70
// sessionAgent is the SessionAgent implementation returned by NewSessionAgent.
//
// Fields:
//   - largeModel handles conversations and summaries.
//   - smallModel generates session titles.
//   - systemPromptPrefix, when non-empty, is prepended as an extra system
//     message on every request.
//   - disableAutoSummarize turns off the context-window stop condition in Run.
//   - hooks, when non-nil, receives UserPromptSubmit, PreToolUse, PostToolUse
//     and Stop events.
//   - messageQueue holds prompts submitted while a session was busy, keyed by
//     session ID.
//   - activeRequests maps a session ID to the cancel func of its in-flight
//     request. Presence of a key is what marks a session as busy.
type sessionAgent struct {
	largeModel           Model
	smallModel           Model
	systemPromptPrefix   string
	systemPrompt         string
	tools                []fantasy.AgentTool
	sessions             session.Service
	messages             message.Service
	disableAutoSummarize bool
	isYolo               bool
	hooks                *hooks.Executor

	messageQueue   *csync.Map[string, []SessionAgentCall]
	activeRequests *csync.Map[string, context.CancelFunc]
}
 86
// SessionAgentOptions configures NewSessionAgent. Each field is copied into
// the corresponding sessionAgent field. Hooks may be nil to disable hook
// execution.
type SessionAgentOptions struct {
	LargeModel           Model
	SmallModel           Model
	SystemPromptPrefix   string
	SystemPrompt         string
	DisableAutoSummarize bool
	IsYolo               bool
	Sessions             session.Service
	Messages             message.Service
	Tools                []fantasy.AgentTool
	Hooks                *hooks.Executor
}
 99
100func NewSessionAgent(
101	opts SessionAgentOptions,
102) SessionAgent {
103	return &sessionAgent{
104		largeModel:           opts.LargeModel,
105		smallModel:           opts.SmallModel,
106		systemPromptPrefix:   opts.SystemPromptPrefix,
107		systemPrompt:         opts.SystemPrompt,
108		sessions:             opts.Sessions,
109		messages:             opts.Messages,
110		disableAutoSummarize: opts.DisableAutoSummarize,
111		tools:                opts.Tools,
112		isYolo:               opts.IsYolo,
113		hooks:                opts.Hooks,
114		messageQueue:         csync.NewMap[string, []SessionAgentCall](),
115		activeRequests:       csync.NewMap[string, context.CancelFunc](),
116	}
117}
118
119func (a *sessionAgent) Run(ctx context.Context, call SessionAgentCall) (*fantasy.AgentResult, error) {
120	if call.Prompt == "" {
121		return nil, ErrEmptyPrompt
122	}
123	if call.SessionID == "" {
124		return nil, ErrSessionMissing
125	}
126
127	// Queue the message if busy
128	if a.IsSessionBusy(call.SessionID) {
129		existing, ok := a.messageQueue.Get(call.SessionID)
130		if !ok {
131			existing = []SessionAgentCall{}
132		}
133		existing = append(existing, call)
134		a.messageQueue.Set(call.SessionID, existing)
135		return nil, nil
136	}
137
138	if len(a.tools) > 0 {
139		// add anthropic caching to the last tool
140		a.tools[len(a.tools)-1].SetProviderOptions(a.getCacheControlOptions())
141	}
142
143	agent := fantasy.NewAgent(
144		a.largeModel.Model,
145		fantasy.WithSystemPrompt(a.systemPrompt),
146		fantasy.WithTools(a.tools...),
147	)
148
149	sessionLock := sync.Mutex{}
150	currentSession, err := a.sessions.Get(ctx, call.SessionID)
151	if err != nil {
152		return nil, fmt.Errorf("failed to get session: %w", err)
153	}
154
155	msgs, err := a.getSessionMessages(ctx, currentSession)
156	if err != nil {
157		return nil, fmt.Errorf("failed to get session messages: %w", err)
158	}
159
160	var wg sync.WaitGroup
161	// Generate title if first message
162	if len(msgs) == 0 {
163		wg.Go(func() {
164			sessionLock.Lock()
165			a.generateTitle(ctx, &currentSession, call.Prompt)
166			sessionLock.Unlock()
167		})
168	}
169
170	// Add the user message to the session
171	_, err = a.createUserMessage(ctx, call)
172	if err != nil {
173		return nil, err
174	}
175
176	// Execute UserPromptSubmit hook
177	if a.hooks != nil {
178		_ = a.hooks.Execute(ctx, hooks.HookContext{
179			EventType:  config.UserPromptSubmit,
180			SessionID:  call.SessionID,
181			UserPrompt: call.Prompt,
182			Provider:   a.largeModel.ModelCfg.Provider,
183			Model:      a.largeModel.ModelCfg.Model,
184		})
185	}
186
187	// add the session to the context
188	ctx = context.WithValue(ctx, tools.SessionIDContextKey, call.SessionID)
189
190	genCtx, cancel := context.WithCancel(ctx)
191	a.activeRequests.Set(call.SessionID, cancel)
192
193	defer cancel()
194	defer a.activeRequests.Del(call.SessionID)
195
196	history, files := a.preparePrompt(msgs, call.Attachments...)
197
198	startTime := time.Now()
199	a.eventPromptSent(call.SessionID)
200
201	var currentAssistant *message.Message
202	var shouldSummarize bool
203	result, err := agent.Stream(genCtx, fantasy.AgentStreamCall{
204		Prompt:           call.Prompt,
205		Files:            files,
206		Messages:         history,
207		ProviderOptions:  call.ProviderOptions,
208		MaxOutputTokens:  &call.MaxOutputTokens,
209		TopP:             call.TopP,
210		Temperature:      call.Temperature,
211		PresencePenalty:  call.PresencePenalty,
212		TopK:             call.TopK,
213		FrequencyPenalty: call.FrequencyPenalty,
214		// Before each step create the new assistant message
215		PrepareStep: func(callContext context.Context, options fantasy.PrepareStepFunctionOptions) (_ context.Context, prepared fantasy.PrepareStepResult, err error) {
216			prepared.Messages = options.Messages
217			// reset all cached items
218			for i := range prepared.Messages {
219				prepared.Messages[i].ProviderOptions = nil
220			}
221
222			queuedCalls, _ := a.messageQueue.Get(call.SessionID)
223			a.messageQueue.Del(call.SessionID)
224			for _, queued := range queuedCalls {
225				userMessage, createErr := a.createUserMessage(callContext, queued)
226				if createErr != nil {
227					return callContext, prepared, createErr
228				}
229				prepared.Messages = append(prepared.Messages, userMessage.ToAIMessage()...)
230			}
231
232			lastSystemRoleInx := 0
233			systemMessageUpdated := false
234			for i, msg := range prepared.Messages {
235				// only add cache control to the last message
236				if msg.Role == fantasy.MessageRoleSystem {
237					lastSystemRoleInx = i
238				} else if !systemMessageUpdated {
239					prepared.Messages[lastSystemRoleInx].ProviderOptions = a.getCacheControlOptions()
240					systemMessageUpdated = true
241				}
242				// than add cache control to the last 2 messages
243				if i > len(prepared.Messages)-3 {
244					prepared.Messages[i].ProviderOptions = a.getCacheControlOptions()
245				}
246			}
247
248			if a.systemPromptPrefix != "" {
249				prepared.Messages = append([]fantasy.Message{fantasy.NewSystemMessage(a.systemPromptPrefix)}, prepared.Messages...)
250			}
251
252			var assistantMsg message.Message
253			assistantMsg, err = a.messages.Create(callContext, call.SessionID, message.CreateMessageParams{
254				Role:     message.Assistant,
255				Parts:    []message.ContentPart{},
256				Model:    a.largeModel.ModelCfg.Model,
257				Provider: a.largeModel.ModelCfg.Provider,
258			})
259			if err != nil {
260				return callContext, prepared, err
261			}
262			callContext = context.WithValue(callContext, tools.MessageIDContextKey, assistantMsg.ID)
263			currentAssistant = &assistantMsg
264			return callContext, prepared, err
265		},
266		OnReasoningStart: func(id string, reasoning fantasy.ReasoningContent) error {
267			currentAssistant.AppendReasoningContent(reasoning.Text)
268			return a.messages.Update(genCtx, *currentAssistant)
269		},
270		OnReasoningDelta: func(id string, text string) error {
271			currentAssistant.AppendReasoningContent(text)
272			return a.messages.Update(genCtx, *currentAssistant)
273		},
274		OnReasoningEnd: func(id string, reasoning fantasy.ReasoningContent) error {
275			// handle anthropic signature
276			if anthropicData, ok := reasoning.ProviderMetadata[anthropic.Name]; ok {
277				if reasoning, ok := anthropicData.(*anthropic.ReasoningOptionMetadata); ok {
278					currentAssistant.AppendReasoningSignature(reasoning.Signature)
279				}
280			}
281			if googleData, ok := reasoning.ProviderMetadata[google.Name]; ok {
282				if reasoning, ok := googleData.(*google.ReasoningMetadata); ok {
283					currentAssistant.AppendReasoningSignature(reasoning.Signature)
284				}
285			}
286			if openaiData, ok := reasoning.ProviderMetadata[openai.Name]; ok {
287				if reasoning, ok := openaiData.(*openai.ResponsesReasoningMetadata); ok {
288					currentAssistant.SetReasoningResponsesData(reasoning)
289				}
290			}
291			currentAssistant.FinishThinking()
292			return a.messages.Update(genCtx, *currentAssistant)
293		},
294		OnTextDelta: func(id string, text string) error {
295			currentAssistant.AppendContent(text)
296			return a.messages.Update(genCtx, *currentAssistant)
297		},
298		OnToolInputStart: func(id string, toolName string) error {
299			toolCall := message.ToolCall{
300				ID:               id,
301				Name:             toolName,
302				ProviderExecuted: false,
303				Finished:         false,
304			}
305			currentAssistant.AddToolCall(toolCall)
306			return a.messages.Update(genCtx, *currentAssistant)
307		},
308		OnRetry: func(err *fantasy.APICallError, delay time.Duration) {
309			// TODO: implement
310		},
311		OnToolCall: func(tc fantasy.ToolCallContent) error {
312			// Execute PreToolUse hook - blocks tool execution on error
313			if a.hooks != nil {
314				toolInput := make(map[string]any)
315				if err := json.Unmarshal([]byte(tc.Input), &toolInput); err == nil {
316					if err := a.hooks.Execute(genCtx, hooks.HookContext{
317						EventType: config.PreToolUse,
318						SessionID: call.SessionID,
319						ToolName:  tc.ToolName,
320						ToolInput: toolInput,
321						MessageID: currentAssistant.ID,
322						Provider:  a.largeModel.ModelCfg.Provider,
323						Model:     a.largeModel.ModelCfg.Model,
324					}); err != nil {
325						return fmt.Errorf("PreToolUse hook blocked tool execution: %w", err)
326					}
327				}
328			}
329
330			toolCall := message.ToolCall{
331				ID:               tc.ToolCallID,
332				Name:             tc.ToolName,
333				Input:            tc.Input,
334				ProviderExecuted: false,
335				Finished:         true,
336			}
337			currentAssistant.AddToolCall(toolCall)
338			return a.messages.Update(genCtx, *currentAssistant)
339		},
340		OnToolResult: func(result fantasy.ToolResultContent) error {
341			var resultContent string
342			isError := false
343			switch result.Result.GetType() {
344			case fantasy.ToolResultContentTypeText:
345				r, ok := fantasy.AsToolResultOutputType[fantasy.ToolResultOutputContentText](result.Result)
346				if ok {
347					resultContent = r.Text
348				}
349			case fantasy.ToolResultContentTypeError:
350				r, ok := fantasy.AsToolResultOutputType[fantasy.ToolResultOutputContentError](result.Result)
351				if ok {
352					isError = true
353					resultContent = r.Error.Error()
354				}
355			case fantasy.ToolResultContentTypeMedia:
356				// TODO: handle this message type
357			}
358
359			// Execute PostToolUse hook
360			if a.hooks != nil {
361				toolInput := make(map[string]any)
362				// Try to get tool input from the assistant message
363				toolCalls := currentAssistant.ToolCalls()
364				for _, tc := range toolCalls {
365					if tc.ID == result.ToolCallID {
366						_ = json.Unmarshal([]byte(tc.Input), &toolInput)
367						break
368					}
369				}
370
371				_ = a.hooks.Execute(genCtx, hooks.HookContext{
372					EventType:  config.PostToolUse,
373					SessionID:  call.SessionID,
374					ToolName:   result.ToolName,
375					ToolInput:  toolInput,
376					ToolResult: resultContent,
377					ToolError:  isError,
378					MessageID:  currentAssistant.ID,
379					Provider:   a.largeModel.ModelCfg.Provider,
380					Model:      a.largeModel.ModelCfg.Model,
381				})
382			}
383
384			toolResult := message.ToolResult{
385				ToolCallID: result.ToolCallID,
386				Name:       result.ToolName,
387				Content:    resultContent,
388				IsError:    isError,
389				Metadata:   result.ClientMetadata,
390			}
391			_, createMsgErr := a.messages.Create(genCtx, currentAssistant.SessionID, message.CreateMessageParams{
392				Role: message.Tool,
393				Parts: []message.ContentPart{
394					toolResult,
395				},
396			})
397			if createMsgErr != nil {
398				return createMsgErr
399			}
400			return nil
401		},
402		OnStepFinish: func(stepResult fantasy.StepResult) error {
403			finishReason := message.FinishReasonUnknown
404			switch stepResult.FinishReason {
405			case fantasy.FinishReasonLength:
406				finishReason = message.FinishReasonMaxTokens
407			case fantasy.FinishReasonStop:
408				finishReason = message.FinishReasonEndTurn
409			case fantasy.FinishReasonToolCalls:
410				finishReason = message.FinishReasonToolUse
411			}
412			currentAssistant.AddFinish(finishReason, "", "")
413			a.updateSessionUsage(a.largeModel, &currentSession, stepResult.Usage, a.openrouterCost(stepResult.ProviderMetadata))
414			sessionLock.Lock()
415			_, sessionErr := a.sessions.Save(genCtx, currentSession)
416			sessionLock.Unlock()
417			if sessionErr != nil {
418				return sessionErr
419			}
420			return a.messages.Update(genCtx, *currentAssistant)
421		},
422		StopWhen: []fantasy.StopCondition{
423			func(_ []fantasy.StepResult) bool {
424				cw := int64(a.largeModel.CatwalkCfg.ContextWindow)
425				tokens := currentSession.CompletionTokens + currentSession.PromptTokens
426				remaining := cw - tokens
427				var threshold int64
428				if cw > 200_000 {
429					threshold = 20_000
430				} else {
431					threshold = int64(float64(cw) * 0.2)
432				}
433				if (remaining <= threshold) && !a.disableAutoSummarize {
434					shouldSummarize = true
435					return true
436				}
437				return false
438			},
439		},
440	})
441
442	a.eventPromptResponded(call.SessionID, time.Since(startTime).Truncate(time.Second))
443
444	if err != nil {
445		isCancelErr := errors.Is(err, context.Canceled)
446		isPermissionErr := errors.Is(err, permission.ErrorPermissionDenied)
447		if currentAssistant == nil {
448			return result, err
449		}
450		// Ensure we finish thinking on error to close the reasoning state
451		currentAssistant.FinishThinking()
452		toolCalls := currentAssistant.ToolCalls()
453		// INFO: we use the parent context here because the genCtx has been cancelled
454		msgs, createErr := a.messages.List(ctx, currentAssistant.SessionID)
455		if createErr != nil {
456			return nil, createErr
457		}
458		for _, tc := range toolCalls {
459			if !tc.Finished {
460				tc.Finished = true
461				tc.Input = "{}"
462				currentAssistant.AddToolCall(tc)
463				updateErr := a.messages.Update(ctx, *currentAssistant)
464				if updateErr != nil {
465					return nil, updateErr
466				}
467			}
468
469			found := false
470			for _, msg := range msgs {
471				if msg.Role == message.Tool {
472					for _, tr := range msg.ToolResults() {
473						if tr.ToolCallID == tc.ID {
474							found = true
475							break
476						}
477					}
478				}
479				if found {
480					break
481				}
482			}
483			if found {
484				continue
485			}
486			content := "There was an error while executing the tool"
487			if isCancelErr {
488				content = "Tool execution canceled by user"
489			} else if isPermissionErr {
490				content = "Permission denied"
491			}
492			toolResult := message.ToolResult{
493				ToolCallID: tc.ID,
494				Name:       tc.Name,
495				Content:    content,
496				IsError:    true,
497			}
498			_, createErr = a.messages.Create(context.Background(), currentAssistant.SessionID, message.CreateMessageParams{
499				Role: message.Tool,
500				Parts: []message.ContentPart{
501					toolResult,
502				},
503			})
504			if createErr != nil {
505				return nil, createErr
506			}
507		}
508		if isCancelErr {
509			currentAssistant.AddFinish(message.FinishReasonCanceled, "Request cancelled", "")
510		} else if isPermissionErr {
511			currentAssistant.AddFinish(message.FinishReasonPermissionDenied, "Permission denied", "")
512		} else {
513			currentAssistant.AddFinish(message.FinishReasonError, "API Error", err.Error())
514		}
515		// INFO: we use the parent context here because the genCtx has been cancelled
516		updateErr := a.messages.Update(ctx, *currentAssistant)
517		if updateErr != nil {
518			return nil, updateErr
519		}
520		return nil, err
521	}
522	wg.Wait()
523
524	// Execute Stop hook
525	if a.hooks != nil && result != nil {
526		var totalTokens, inputTokens int64
527		for _, step := range result.Steps {
528			totalTokens += step.Usage.TotalTokens
529			inputTokens += step.Usage.InputTokens
530		}
531
532		_ = a.hooks.Execute(ctx, hooks.HookContext{
533			EventType:   config.Stop,
534			SessionID:   call.SessionID,
535			MessageID:   currentAssistant.ID,
536			Provider:    a.largeModel.ModelCfg.Provider,
537			Model:       a.largeModel.ModelCfg.Model,
538			TokensUsed:  totalTokens,
539			TokensInput: inputTokens,
540		})
541	}
542
543	if shouldSummarize {
544		a.activeRequests.Del(call.SessionID)
545		if summarizeErr := a.Summarize(genCtx, call.SessionID, call.ProviderOptions); summarizeErr != nil {
546			return nil, summarizeErr
547		}
548		// if the agent was not done...
549		if len(currentAssistant.ToolCalls()) > 0 {
550			existing, ok := a.messageQueue.Get(call.SessionID)
551			if !ok {
552				existing = []SessionAgentCall{}
553			}
554			call.Prompt = fmt.Sprintf("The previous session was interrupted because it got too long, the initial user request was: `%s`", call.Prompt)
555			existing = append(existing, call)
556			a.messageQueue.Set(call.SessionID, existing)
557		}
558	}
559
560	// release active request before processing queued messages
561	a.activeRequests.Del(call.SessionID)
562	cancel()
563
564	queuedMessages, ok := a.messageQueue.Get(call.SessionID)
565	if !ok || len(queuedMessages) == 0 {
566		return result, err
567	}
568	// there are queued messages restart the loop
569	firstQueuedMessage := queuedMessages[0]
570	a.messageQueue.Set(call.SessionID, queuedMessages[1:])
571	return a.Run(ctx, firstQueuedMessage)
572}
573
574func (a *sessionAgent) Summarize(ctx context.Context, sessionID string, opts fantasy.ProviderOptions) error {
575	if a.IsSessionBusy(sessionID) {
576		return ErrSessionBusy
577	}
578
579	currentSession, err := a.sessions.Get(ctx, sessionID)
580	if err != nil {
581		return fmt.Errorf("failed to get session: %w", err)
582	}
583	msgs, err := a.getSessionMessages(ctx, currentSession)
584	if err != nil {
585		return err
586	}
587	if len(msgs) == 0 {
588		// nothing to summarize
589		return nil
590	}
591
592	aiMsgs, _ := a.preparePrompt(msgs)
593
594	genCtx, cancel := context.WithCancel(ctx)
595	a.activeRequests.Set(sessionID, cancel)
596	defer a.activeRequests.Del(sessionID)
597	defer cancel()
598
599	agent := fantasy.NewAgent(a.largeModel.Model,
600		fantasy.WithSystemPrompt(string(summaryPrompt)),
601	)
602	summaryMessage, err := a.messages.Create(ctx, sessionID, message.CreateMessageParams{
603		Role:             message.Assistant,
604		Model:            a.largeModel.Model.Model(),
605		Provider:         a.largeModel.Model.Provider(),
606		IsSummaryMessage: true,
607	})
608	if err != nil {
609		return err
610	}
611
612	resp, err := agent.Stream(genCtx, fantasy.AgentStreamCall{
613		Prompt:          "Provide a detailed summary of our conversation above.",
614		Messages:        aiMsgs,
615		ProviderOptions: opts,
616		PrepareStep: func(callContext context.Context, options fantasy.PrepareStepFunctionOptions) (_ context.Context, prepared fantasy.PrepareStepResult, err error) {
617			prepared.Messages = options.Messages
618			if a.systemPromptPrefix != "" {
619				prepared.Messages = append([]fantasy.Message{fantasy.NewSystemMessage(a.systemPromptPrefix)}, prepared.Messages...)
620			}
621			return callContext, prepared, nil
622		},
623		OnReasoningDelta: func(id string, text string) error {
624			summaryMessage.AppendReasoningContent(text)
625			return a.messages.Update(genCtx, summaryMessage)
626		},
627		OnReasoningEnd: func(id string, reasoning fantasy.ReasoningContent) error {
628			// handle anthropic signature
629			if anthropicData, ok := reasoning.ProviderMetadata["anthropic"]; ok {
630				if signature, ok := anthropicData.(*anthropic.ReasoningOptionMetadata); ok && signature.Signature != "" {
631					summaryMessage.AppendReasoningSignature(signature.Signature)
632				}
633			}
634			summaryMessage.FinishThinking()
635			return a.messages.Update(genCtx, summaryMessage)
636		},
637		OnTextDelta: func(id, text string) error {
638			summaryMessage.AppendContent(text)
639			return a.messages.Update(genCtx, summaryMessage)
640		},
641	})
642	if err != nil {
643		isCancelErr := errors.Is(err, context.Canceled)
644		if isCancelErr {
645			// User cancelled summarize we need to remove the summary message
646			deleteErr := a.messages.Delete(ctx, summaryMessage.ID)
647			return deleteErr
648		}
649		return err
650	}
651
652	summaryMessage.AddFinish(message.FinishReasonEndTurn, "", "")
653	err = a.messages.Update(genCtx, summaryMessage)
654	if err != nil {
655		return err
656	}
657
658	var openrouterCost *float64
659	for _, step := range resp.Steps {
660		stepCost := a.openrouterCost(step.ProviderMetadata)
661		if stepCost != nil {
662			newCost := *stepCost
663			if openrouterCost != nil {
664				newCost += *openrouterCost
665			}
666			openrouterCost = &newCost
667		}
668	}
669
670	a.updateSessionUsage(a.largeModel, &currentSession, resp.TotalUsage, openrouterCost)
671
672	// just in case get just the last usage
673	usage := resp.Response.Usage
674	currentSession.SummaryMessageID = summaryMessage.ID
675	currentSession.CompletionTokens = usage.OutputTokens
676	currentSession.PromptTokens = 0
677	_, err = a.sessions.Save(genCtx, currentSession)
678	return err
679}
680
681func (a *sessionAgent) getCacheControlOptions() fantasy.ProviderOptions {
682	if t, _ := strconv.ParseBool(os.Getenv("CRUSH_DISABLE_ANTHROPIC_CACHE")); t {
683		return fantasy.ProviderOptions{}
684	}
685	return fantasy.ProviderOptions{
686		anthropic.Name: &anthropic.ProviderCacheControlOptions{
687			CacheControl: anthropic.CacheControl{Type: "ephemeral"},
688		},
689		bedrock.Name: &anthropic.ProviderCacheControlOptions{
690			CacheControl: anthropic.CacheControl{Type: "ephemeral"},
691		},
692	}
693}
694
695func (a *sessionAgent) createUserMessage(ctx context.Context, call SessionAgentCall) (message.Message, error) {
696	var attachmentParts []message.ContentPart
697	for _, attachment := range call.Attachments {
698		attachmentParts = append(attachmentParts, message.BinaryContent{Path: attachment.FilePath, MIMEType: attachment.MimeType, Data: attachment.Content})
699	}
700	parts := []message.ContentPart{message.TextContent{Text: call.Prompt}}
701	parts = append(parts, attachmentParts...)
702	msg, err := a.messages.Create(ctx, call.SessionID, message.CreateMessageParams{
703		Role:  message.User,
704		Parts: parts,
705	})
706	if err != nil {
707		return message.Message{}, fmt.Errorf("failed to create user message: %w", err)
708	}
709	return msg, nil
710}
711
712func (a *sessionAgent) preparePrompt(msgs []message.Message, attachments ...message.Attachment) ([]fantasy.Message, []fantasy.FilePart) {
713	var history []fantasy.Message
714	for _, m := range msgs {
715		if len(m.Parts) == 0 {
716			continue
717		}
718		// Assistant message without content or tool calls (cancelled before it returned anything)
719		if m.Role == message.Assistant && len(m.ToolCalls()) == 0 && m.Content().Text == "" && m.ReasoningContent().String() == "" {
720			continue
721		}
722		history = append(history, m.ToAIMessage()...)
723	}
724
725	var files []fantasy.FilePart
726	for _, attachment := range attachments {
727		files = append(files, fantasy.FilePart{
728			Filename:  attachment.FileName,
729			Data:      attachment.Content,
730			MediaType: attachment.MimeType,
731		})
732	}
733
734	return history, files
735}
736
737func (a *sessionAgent) getSessionMessages(ctx context.Context, session session.Session) ([]message.Message, error) {
738	msgs, err := a.messages.List(ctx, session.ID)
739	if err != nil {
740		return nil, fmt.Errorf("failed to list messages: %w", err)
741	}
742
743	if session.SummaryMessageID != "" {
744		summaryMsgInex := -1
745		for i, msg := range msgs {
746			if msg.ID == session.SummaryMessageID {
747				summaryMsgInex = i
748				break
749			}
750		}
751		if summaryMsgInex != -1 {
752			msgs = msgs[summaryMsgInex:]
753			msgs[0].Role = message.User
754		}
755	}
756	return msgs, nil
757}
758
759func (a *sessionAgent) generateTitle(ctx context.Context, session *session.Session, prompt string) {
760	if prompt == "" {
761		return
762	}
763
764	var maxOutput int64 = 40
765	if a.smallModel.CatwalkCfg.CanReason {
766		maxOutput = a.smallModel.CatwalkCfg.DefaultMaxTokens
767	}
768
769	agent := fantasy.NewAgent(a.smallModel.Model,
770		fantasy.WithSystemPrompt(string(titlePrompt)+"\n /no_think"),
771		fantasy.WithMaxOutputTokens(maxOutput),
772	)
773
774	resp, err := agent.Stream(ctx, fantasy.AgentStreamCall{
775		Prompt: fmt.Sprintf("Generate a concise title for the following content:\n\n%s\n <think>\n\n</think>", prompt),
776		PrepareStep: func(callContext context.Context, options fantasy.PrepareStepFunctionOptions) (_ context.Context, prepared fantasy.PrepareStepResult, err error) {
777			prepared.Messages = options.Messages
778			if a.systemPromptPrefix != "" {
779				prepared.Messages = append([]fantasy.Message{fantasy.NewSystemMessage(a.systemPromptPrefix)}, prepared.Messages...)
780			}
781			return callContext, prepared, nil
782		},
783	})
784	if err != nil {
785		slog.Error("error generating title", "err", err)
786		return
787	}
788
789	title := resp.Response.Content.Text()
790
791	title = strings.ReplaceAll(title, "\n", " ")
792
793	// remove thinking tags if present
794	if idx := strings.Index(title, "</think>"); idx > 0 {
795		title = title[idx+len("</think>"):]
796	}
797
798	title = strings.TrimSpace(title)
799	if title == "" {
800		slog.Warn("failed to generate title", "warn", "empty title")
801		return
802	}
803
804	session.Title = title
805
806	var openrouterCost *float64
807	for _, step := range resp.Steps {
808		stepCost := a.openrouterCost(step.ProviderMetadata)
809		if stepCost != nil {
810			newCost := *stepCost
811			if openrouterCost != nil {
812				newCost += *openrouterCost
813			}
814			openrouterCost = &newCost
815		}
816	}
817
818	a.updateSessionUsage(a.smallModel, session, resp.TotalUsage, openrouterCost)
819	_, saveErr := a.sessions.Save(ctx, *session)
820	if saveErr != nil {
821		slog.Error("failed to save session title & usage", "error", saveErr)
822		return
823	}
824}
825
826func (a *sessionAgent) openrouterCost(metadata fantasy.ProviderMetadata) *float64 {
827	openrouterMetadata, ok := metadata[openrouter.Name]
828	if !ok {
829		return nil
830	}
831
832	opts, ok := openrouterMetadata.(*openrouter.ProviderMetadata)
833	if !ok {
834		return nil
835	}
836	return &opts.Usage.Cost
837}
838
839func (a *sessionAgent) updateSessionUsage(model Model, session *session.Session, usage fantasy.Usage, overrideCost *float64) {
840	modelConfig := model.CatwalkCfg
841	cost := modelConfig.CostPer1MInCached/1e6*float64(usage.CacheCreationTokens) +
842		modelConfig.CostPer1MOutCached/1e6*float64(usage.CacheReadTokens) +
843		modelConfig.CostPer1MIn/1e6*float64(usage.InputTokens) +
844		modelConfig.CostPer1MOut/1e6*float64(usage.OutputTokens)
845
846	a.eventTokensUsed(session.ID, model, usage, cost)
847
848	if overrideCost != nil {
849		session.Cost += *overrideCost
850	} else {
851		session.Cost += cost
852	}
853
854	session.CompletionTokens = usage.OutputTokens + usage.CacheReadTokens
855	session.PromptTokens = usage.InputTokens + usage.CacheCreationTokens
856}
857
858func (a *sessionAgent) Cancel(sessionID string) {
859	// Cancel regular requests
860	if cancel, ok := a.activeRequests.Take(sessionID); ok && cancel != nil {
861		slog.Info("Request cancellation initiated", "session_id", sessionID)
862		cancel()
863	}
864
865	// Also check for summarize requests
866	if cancel, ok := a.activeRequests.Take(sessionID + "-summarize"); ok && cancel != nil {
867		slog.Info("Summarize cancellation initiated", "session_id", sessionID)
868		cancel()
869	}
870
871	if a.QueuedPrompts(sessionID) > 0 {
872		slog.Info("Clearing queued prompts", "session_id", sessionID)
873		a.messageQueue.Del(sessionID)
874	}
875}
876
877func (a *sessionAgent) ClearQueue(sessionID string) {
878	if a.QueuedPrompts(sessionID) > 0 {
879		slog.Info("Clearing queued prompts", "session_id", sessionID)
880		a.messageQueue.Del(sessionID)
881	}
882}
883
884func (a *sessionAgent) CancelAll() {
885	if !a.IsBusy() {
886		return
887	}
888	for key := range a.activeRequests.Seq2() {
889		a.Cancel(key) // key is sessionID
890	}
891
892	timeout := time.After(5 * time.Second)
893	for a.IsBusy() {
894		select {
895		case <-timeout:
896			return
897		default:
898			time.Sleep(200 * time.Millisecond)
899		}
900	}
901}
902
903func (a *sessionAgent) IsBusy() bool {
904	var busy bool
905	for cancelFunc := range a.activeRequests.Seq() {
906		if cancelFunc != nil {
907			busy = true
908			break
909		}
910	}
911	return busy
912}
913
914func (a *sessionAgent) IsSessionBusy(sessionID string) bool {
915	_, busy := a.activeRequests.Get(sessionID)
916	return busy
917}
918
919func (a *sessionAgent) QueuedPrompts(sessionID string) int {
920	l, ok := a.messageQueue.Get(sessionID)
921	if !ok {
922		return 0
923	}
924	return len(l)
925}
926
927func (a *sessionAgent) SetModels(large Model, small Model) {
928	a.largeModel = large
929	a.smallModel = small
930}
931
932func (a *sessionAgent) SetTools(tools []fantasy.AgentTool) {
933	a.tools = tools
934}
935
936func (a *sessionAgent) Model() Model {
937	return a.largeModel
938}