agent.go

  1package agent
  2
  3import (
  4	"context"
  5	_ "embed"
  6	"errors"
  7	"fmt"
  8	"log/slog"
  9	"os"
 10	"strconv"
 11	"strings"
 12	"sync"
 13	"time"
 14
 15	"charm.land/fantasy"
 16	"charm.land/fantasy/providers/anthropic"
 17	"charm.land/fantasy/providers/bedrock"
 18	"charm.land/fantasy/providers/google"
 19	"charm.land/fantasy/providers/openai"
 20	"charm.land/fantasy/providers/openrouter"
 21	"github.com/charmbracelet/catwalk/pkg/catwalk"
 22	"git.secluded.site/crush/internal/agent/tools"
 23	"git.secluded.site/crush/internal/config"
 24	"git.secluded.site/crush/internal/csync"
 25	"git.secluded.site/crush/internal/message"
 26	"git.secluded.site/crush/internal/notification"
 27	"git.secluded.site/crush/internal/permission"
 28	"git.secluded.site/crush/internal/session"
 29)
 30
// titlePrompt is the embedded contents of templates/title.md. It is used as
// the system prompt when generating a session title.
//
//go:embed templates/title.md
var titlePrompt []byte

// summaryPrompt is the embedded contents of templates/summary.md. It is used
// as the system prompt when summarizing a session.
//
//go:embed templates/summary.md
var summaryPrompt []byte
 36
// SessionAgentCall describes one prompt submitted to a session. Besides the
// session ID, prompt and attachments, the remaining fields are forwarded to
// the model stream call made by Run.
type SessionAgentCall struct {
	// SessionID identifies the session the prompt belongs to; Run rejects an
	// empty value with ErrSessionMissing.
	SessionID        string
	// Prompt is the user text; Run rejects an empty value with ErrEmptyPrompt.
	Prompt           string
	// ProviderOptions are passed through to the stream call (and to Summarize
	// when Run triggers a summary).
	ProviderOptions  fantasy.ProviderOptions
	// Attachments are stored on the user message and sent to the model as files.
	Attachments      []message.Attachment
	// MaxOutputTokens is passed to the stream call by address.
	MaxOutputTokens  int64
	// Optional sampling parameters, forwarded unchanged to the stream call.
	Temperature      *float64
	TopP             *float64
	TopK             *int64
	FrequencyPenalty *float64
	PresencePenalty  *float64
}
 49
// SessionAgent is the interface returned by NewSessionAgent. The notes below
// describe the behavior of the sessionAgent implementation in this file.
type SessionAgent interface {
	// Run executes a prompt for a session, or queues it when that session
	// already has an active request (in which case it returns nil, nil).
	Run(context.Context, SessionAgentCall) (*fantasy.AgentResult, error)
	// SetModels replaces the large and small models.
	SetModels(large Model, small Model)
	// SetTools replaces the tool list offered to the model.
	SetTools(tools []fantasy.AgentTool)
	// Cancel cancels the active request of a session and drops its queued prompts.
	Cancel(sessionID string)
	// CancelAll cancels every active request and pending completion notification.
	CancelAll()
	// IsSessionBusy reports whether the session has an active request.
	IsSessionBusy(sessionID string) bool
	// IsBusy reports whether any session has an active request.
	IsBusy() bool
	// QueuedPrompts returns the number of prompts queued for the session.
	QueuedPrompts(sessionID string) int
	// ClearQueue drops the prompts queued for the session.
	ClearQueue(sessionID string)
	// Summarize replaces the session history with a model-generated summary.
	Summarize(context.Context, string, fantasy.ProviderOptions) error
	// Model returns the large model.
	Model() Model
}
 63
// completionNotificationDelay is the delay handed to the notifier when a
// "turn completed" notification is scheduled.
const completionNotificationDelay = 5 * time.Second
 65
// Model bundles a language model with the metadata the agent reads about it.
type Model struct {
	// Model is the language model used to create fantasy agents.
	Model      fantasy.LanguageModel
	// CatwalkCfg supplies the context window, pricing and reasoning
	// capabilities consulted by the agent.
	CatwalkCfg catwalk.Model
	// ModelCfg is the selected model configuration; its Model and Provider
	// values are recorded on assistant messages created by Run.
	ModelCfg   config.SelectedModel
}
 71
// sessionAgent is the SessionAgent implementation created by NewSessionAgent.
type sessionAgent struct {
	// largeModel serves prompts and summaries; smallModel generates titles.
	largeModel           Model
	smallModel           Model
	// systemPromptPrefix, when non-empty, is prepended as an extra system
	// message to every step.
	systemPromptPrefix   string
	systemPrompt         string
	tools                []fantasy.AgentTool
	sessions             session.Service
	messages             message.Service
	// disableAutoSummarize turns off the context-window stop condition in Run.
	disableAutoSummarize bool
	isYolo               bool

	// messageQueue holds prompts received while a session was busy, keyed by
	// session ID.
	messageQueue      *csync.Map[string, []SessionAgentCall]
	// activeRequests maps a session ID to the cancel func of its in-flight
	// request; presence of a key is what marks a session as busy.
	activeRequests    *csync.Map[string, context.CancelFunc]
	// notifier may be nil, in which case no completion notifications are sent.
	notifier          *notification.Notifier
	// notifyCtx is the context passed to the notifier.
	notifyCtx         context.Context
	// completionCancels maps a session ID to the cancel func of its pending
	// completion notification.
	completionCancels *csync.Map[string, context.CancelFunc]
}
 89
// SessionAgentOptions carries the dependencies and settings consumed by
// NewSessionAgent. Each field is copied to the sessionAgent field of the
// corresponding name.
type SessionAgentOptions struct {
	LargeModel           Model
	SmallModel           Model
	SystemPromptPrefix   string
	SystemPrompt         string
	DisableAutoSummarize bool
	IsYolo               bool
	Sessions             session.Service
	Messages             message.Service
	Tools                []fantasy.AgentTool
	// Notifier is optional; when nil no completion notifications are scheduled.
	Notifier             *notification.Notifier
	// NotificationCtx is optional; NewSessionAgent falls back to
	// context.Background() when it is nil.
	NotificationCtx      context.Context
}
103
104func NewSessionAgent(
105	opts SessionAgentOptions,
106) SessionAgent {
107	notifyCtx := opts.NotificationCtx
108	if notifyCtx == nil {
109		notifyCtx = context.Background()
110	}
111
112	return &sessionAgent{
113		largeModel:           opts.LargeModel,
114		smallModel:           opts.SmallModel,
115		systemPromptPrefix:   opts.SystemPromptPrefix,
116		systemPrompt:         opts.SystemPrompt,
117		sessions:             opts.Sessions,
118		messages:             opts.Messages,
119		disableAutoSummarize: opts.DisableAutoSummarize,
120		tools:                opts.Tools,
121		isYolo:               opts.IsYolo,
122		messageQueue:         csync.NewMap[string, []SessionAgentCall](),
123		activeRequests:       csync.NewMap[string, context.CancelFunc](),
124		notifier:             opts.Notifier,
125		notifyCtx:            notifyCtx,
126		completionCancels:    csync.NewMap[string, context.CancelFunc](),
127	}
128}
129
130func (a *sessionAgent) Run(ctx context.Context, call SessionAgentCall) (*fantasy.AgentResult, error) {
131	if call.Prompt == "" {
132		return nil, ErrEmptyPrompt
133	}
134	if call.SessionID == "" {
135		return nil, ErrSessionMissing
136	}
137
138	a.cancelCompletionNotification(call.SessionID)
139
140	// Queue the message if busy
141	if a.IsSessionBusy(call.SessionID) {
142		existing, ok := a.messageQueue.Get(call.SessionID)
143		if !ok {
144			existing = []SessionAgentCall{}
145		}
146		existing = append(existing, call)
147		a.messageQueue.Set(call.SessionID, existing)
148		return nil, nil
149	}
150
151	if len(a.tools) > 0 {
152		// add anthropic caching to the last tool
153		a.tools[len(a.tools)-1].SetProviderOptions(a.getCacheControlOptions())
154	}
155
156	agent := fantasy.NewAgent(
157		a.largeModel.Model,
158		fantasy.WithSystemPrompt(a.systemPrompt),
159		fantasy.WithTools(a.tools...),
160	)
161
162	sessionLock := sync.Mutex{}
163	currentSession, err := a.sessions.Get(ctx, call.SessionID)
164	if err != nil {
165		return nil, fmt.Errorf("failed to get session: %w", err)
166	}
167
168	msgs, err := a.getSessionMessages(ctx, currentSession)
169	if err != nil {
170		return nil, fmt.Errorf("failed to get session messages: %w", err)
171	}
172
173	var wg sync.WaitGroup
174	// Generate title if first message
175	if len(msgs) == 0 {
176		wg.Go(func() {
177			sessionLock.Lock()
178			a.generateTitle(ctx, &currentSession, call.Prompt)
179			sessionLock.Unlock()
180		})
181	}
182
183	// Add the user message to the session
184	_, err = a.createUserMessage(ctx, call)
185	if err != nil {
186		return nil, err
187	}
188
189	// add the session to the context
190	ctx = context.WithValue(ctx, tools.SessionIDContextKey, call.SessionID)
191
192	genCtx, cancel := context.WithCancel(ctx)
193	a.activeRequests.Set(call.SessionID, cancel)
194
195	defer cancel()
196	defer a.activeRequests.Del(call.SessionID)
197
198	history, files := a.preparePrompt(msgs, call.Attachments...)
199
200	startTime := time.Now()
201	a.eventPromptSent(call.SessionID)
202
203	var currentAssistant *message.Message
204	var shouldSummarize bool
205	result, err := agent.Stream(genCtx, fantasy.AgentStreamCall{
206		Prompt:           call.Prompt,
207		Files:            files,
208		Messages:         history,
209		ProviderOptions:  call.ProviderOptions,
210		MaxOutputTokens:  &call.MaxOutputTokens,
211		TopP:             call.TopP,
212		Temperature:      call.Temperature,
213		PresencePenalty:  call.PresencePenalty,
214		TopK:             call.TopK,
215		FrequencyPenalty: call.FrequencyPenalty,
216		// Before each step create the new assistant message
217		PrepareStep: func(callContext context.Context, options fantasy.PrepareStepFunctionOptions) (_ context.Context, prepared fantasy.PrepareStepResult, err error) {
218			prepared.Messages = options.Messages
219			// reset all cached items
220			for i := range prepared.Messages {
221				prepared.Messages[i].ProviderOptions = nil
222			}
223
224			queuedCalls, _ := a.messageQueue.Get(call.SessionID)
225			a.messageQueue.Del(call.SessionID)
226			for _, queued := range queuedCalls {
227				userMessage, createErr := a.createUserMessage(callContext, queued)
228				if createErr != nil {
229					return callContext, prepared, createErr
230				}
231				prepared.Messages = append(prepared.Messages, userMessage.ToAIMessage()...)
232			}
233
234			lastSystemRoleInx := 0
235			systemMessageUpdated := false
236			for i, msg := range prepared.Messages {
237				// only add cache control to the last message
238				if msg.Role == fantasy.MessageRoleSystem {
239					lastSystemRoleInx = i
240				} else if !systemMessageUpdated {
241					prepared.Messages[lastSystemRoleInx].ProviderOptions = a.getCacheControlOptions()
242					systemMessageUpdated = true
243				}
244				// than add cache control to the last 2 messages
245				if i > len(prepared.Messages)-3 {
246					prepared.Messages[i].ProviderOptions = a.getCacheControlOptions()
247				}
248			}
249
250			if a.systemPromptPrefix != "" {
251				prepared.Messages = append([]fantasy.Message{fantasy.NewSystemMessage(a.systemPromptPrefix)}, prepared.Messages...)
252			}
253
254			var assistantMsg message.Message
255			assistantMsg, err = a.messages.Create(callContext, call.SessionID, message.CreateMessageParams{
256				Role:     message.Assistant,
257				Parts:    []message.ContentPart{},
258				Model:    a.largeModel.ModelCfg.Model,
259				Provider: a.largeModel.ModelCfg.Provider,
260			})
261			if err != nil {
262				return callContext, prepared, err
263			}
264			callContext = context.WithValue(callContext, tools.MessageIDContextKey, assistantMsg.ID)
265			currentAssistant = &assistantMsg
266			return callContext, prepared, err
267		},
268		OnReasoningStart: func(id string, reasoning fantasy.ReasoningContent) error {
269			currentAssistant.AppendReasoningContent(reasoning.Text)
270			return a.messages.Update(genCtx, *currentAssistant)
271		},
272		OnReasoningDelta: func(id string, text string) error {
273			currentAssistant.AppendReasoningContent(text)
274			return a.messages.Update(genCtx, *currentAssistant)
275		},
276		OnReasoningEnd: func(id string, reasoning fantasy.ReasoningContent) error {
277			// handle anthropic signature
278			if anthropicData, ok := reasoning.ProviderMetadata[anthropic.Name]; ok {
279				if reasoning, ok := anthropicData.(*anthropic.ReasoningOptionMetadata); ok {
280					currentAssistant.AppendReasoningSignature(reasoning.Signature)
281				}
282			}
283			if googleData, ok := reasoning.ProviderMetadata[google.Name]; ok {
284				if reasoning, ok := googleData.(*google.ReasoningMetadata); ok {
285					currentAssistant.AppendReasoningSignature(reasoning.Signature)
286				}
287			}
288			if openaiData, ok := reasoning.ProviderMetadata[openai.Name]; ok {
289				if reasoning, ok := openaiData.(*openai.ResponsesReasoningMetadata); ok {
290					currentAssistant.SetReasoningResponsesData(reasoning)
291				}
292			}
293			currentAssistant.FinishThinking()
294			return a.messages.Update(genCtx, *currentAssistant)
295		},
296		OnTextDelta: func(id string, text string) error {
297			currentAssistant.AppendContent(text)
298			return a.messages.Update(genCtx, *currentAssistant)
299		},
300		OnToolInputStart: func(id string, toolName string) error {
301			toolCall := message.ToolCall{
302				ID:               id,
303				Name:             toolName,
304				ProviderExecuted: false,
305				Finished:         false,
306			}
307			currentAssistant.AddToolCall(toolCall)
308			return a.messages.Update(genCtx, *currentAssistant)
309		},
310		OnRetry: func(err *fantasy.APICallError, delay time.Duration) {
311			// TODO: implement
312		},
313		OnToolCall: func(tc fantasy.ToolCallContent) error {
314			toolCall := message.ToolCall{
315				ID:               tc.ToolCallID,
316				Name:             tc.ToolName,
317				Input:            tc.Input,
318				ProviderExecuted: false,
319				Finished:         true,
320			}
321			currentAssistant.AddToolCall(toolCall)
322			return a.messages.Update(genCtx, *currentAssistant)
323		},
324		OnToolResult: func(result fantasy.ToolResultContent) error {
325			var resultContent string
326			isError := false
327			switch result.Result.GetType() {
328			case fantasy.ToolResultContentTypeText:
329				r, ok := fantasy.AsToolResultOutputType[fantasy.ToolResultOutputContentText](result.Result)
330				if ok {
331					resultContent = r.Text
332				}
333			case fantasy.ToolResultContentTypeError:
334				r, ok := fantasy.AsToolResultOutputType[fantasy.ToolResultOutputContentError](result.Result)
335				if ok {
336					isError = true
337					resultContent = r.Error.Error()
338				}
339			case fantasy.ToolResultContentTypeMedia:
340				// TODO: handle this message type
341			}
342			toolResult := message.ToolResult{
343				ToolCallID: result.ToolCallID,
344				Name:       result.ToolName,
345				Content:    resultContent,
346				IsError:    isError,
347				Metadata:   result.ClientMetadata,
348			}
349			_, createMsgErr := a.messages.Create(genCtx, currentAssistant.SessionID, message.CreateMessageParams{
350				Role: message.Tool,
351				Parts: []message.ContentPart{
352					toolResult,
353				},
354			})
355			if createMsgErr != nil {
356				return createMsgErr
357			}
358			return nil
359		},
360		OnStepFinish: func(stepResult fantasy.StepResult) error {
361			finishReason := message.FinishReasonUnknown
362			switch stepResult.FinishReason {
363			case fantasy.FinishReasonLength:
364				finishReason = message.FinishReasonMaxTokens
365			case fantasy.FinishReasonStop:
366				finishReason = message.FinishReasonEndTurn
367			case fantasy.FinishReasonToolCalls:
368				finishReason = message.FinishReasonToolUse
369			}
370			currentAssistant.AddFinish(finishReason, "", "")
371			a.updateSessionUsage(a.largeModel, &currentSession, stepResult.Usage, a.openrouterCost(stepResult.ProviderMetadata))
372			sessionLock.Lock()
373			_, sessionErr := a.sessions.Save(genCtx, currentSession)
374			sessionLock.Unlock()
375			if sessionErr != nil {
376				return sessionErr
377			}
378			if finishReason == message.FinishReasonEndTurn {
379				a.scheduleCompletionNotification(call.SessionID, currentSession.Title)
380			}
381			return a.messages.Update(genCtx, *currentAssistant)
382		},
383		StopWhen: []fantasy.StopCondition{
384			func(_ []fantasy.StepResult) bool {
385				cw := int64(a.largeModel.CatwalkCfg.ContextWindow)
386				tokens := currentSession.CompletionTokens + currentSession.PromptTokens
387				remaining := cw - tokens
388				var threshold int64
389				if cw > 200_000 {
390					threshold = 20_000
391				} else {
392					threshold = int64(float64(cw) * 0.2)
393				}
394				if (remaining <= threshold) && !a.disableAutoSummarize {
395					shouldSummarize = true
396					return true
397				}
398				return false
399			},
400		},
401	})
402
403	a.eventPromptResponded(call.SessionID, time.Since(startTime).Truncate(time.Second))
404
405	if err != nil {
406		isCancelErr := errors.Is(err, context.Canceled)
407		isPermissionErr := errors.Is(err, permission.ErrorPermissionDenied)
408		if currentAssistant == nil {
409			return result, err
410		}
411		// Ensure we finish thinking on error to close the reasoning state
412		currentAssistant.FinishThinking()
413		toolCalls := currentAssistant.ToolCalls()
414		// INFO: we use the parent context here because the genCtx has been cancelled
415		msgs, createErr := a.messages.List(ctx, currentAssistant.SessionID)
416		if createErr != nil {
417			return nil, createErr
418		}
419		for _, tc := range toolCalls {
420			if !tc.Finished {
421				tc.Finished = true
422				tc.Input = "{}"
423				currentAssistant.AddToolCall(tc)
424				updateErr := a.messages.Update(ctx, *currentAssistant)
425				if updateErr != nil {
426					return nil, updateErr
427				}
428			}
429
430			found := false
431			for _, msg := range msgs {
432				if msg.Role == message.Tool {
433					for _, tr := range msg.ToolResults() {
434						if tr.ToolCallID == tc.ID {
435							found = true
436							break
437						}
438					}
439				}
440				if found {
441					break
442				}
443			}
444			if found {
445				continue
446			}
447			content := "There was an error while executing the tool"
448			if isCancelErr {
449				content = "Tool execution canceled by user"
450			} else if isPermissionErr {
451				content = "Permission denied"
452			}
453			toolResult := message.ToolResult{
454				ToolCallID: tc.ID,
455				Name:       tc.Name,
456				Content:    content,
457				IsError:    true,
458			}
459			_, createErr = a.messages.Create(context.Background(), currentAssistant.SessionID, message.CreateMessageParams{
460				Role: message.Tool,
461				Parts: []message.ContentPart{
462					toolResult,
463				},
464			})
465			if createErr != nil {
466				return nil, createErr
467			}
468		}
469		if isCancelErr {
470			currentAssistant.AddFinish(message.FinishReasonCanceled, "Request cancelled", "")
471		} else if isPermissionErr {
472			currentAssistant.AddFinish(message.FinishReasonPermissionDenied, "Permission denied", "")
473		} else {
474			currentAssistant.AddFinish(message.FinishReasonError, "API Error", err.Error())
475		}
476		// INFO: we use the parent context here because the genCtx has been cancelled
477		updateErr := a.messages.Update(ctx, *currentAssistant)
478		if updateErr != nil {
479			return nil, updateErr
480		}
481		return nil, err
482	}
483	wg.Wait()
484
485	if shouldSummarize {
486		a.cancelCompletionNotification(call.SessionID)
487		a.activeRequests.Del(call.SessionID)
488		if summarizeErr := a.Summarize(genCtx, call.SessionID, call.ProviderOptions); summarizeErr != nil {
489			return nil, summarizeErr
490		}
491		// if the agent was not done...
492		if len(currentAssistant.ToolCalls()) > 0 {
493			existing, ok := a.messageQueue.Get(call.SessionID)
494			if !ok {
495				existing = []SessionAgentCall{}
496			}
497			call.Prompt = fmt.Sprintf("The previous session was interrupted because it got too long, the initial user request was: `%s`", call.Prompt)
498			existing = append(existing, call)
499			a.messageQueue.Set(call.SessionID, existing)
500		}
501	}
502
503	// release active request before processing queued messages
504	a.activeRequests.Del(call.SessionID)
505	cancel()
506
507	queuedMessages, ok := a.messageQueue.Get(call.SessionID)
508	if !ok || len(queuedMessages) == 0 {
509		return result, err
510	}
511	// there are queued messages restart the loop
512	firstQueuedMessage := queuedMessages[0]
513	a.messageQueue.Set(call.SessionID, queuedMessages[1:])
514	return a.Run(ctx, firstQueuedMessage)
515}
516
517func (a *sessionAgent) scheduleCompletionNotification(sessionID, sessionTitle string) {
518	if a.notifier == nil {
519		return
520	}
521
522	if sessionTitle == "" {
523		sessionTitle = sessionID
524	}
525
526	if cancel, ok := a.completionCancels.Take(sessionID); ok && cancel != nil {
527		cancel()
528	}
529
530	title := "💘 Crush is waiting"
531	message := fmt.Sprintf("Agent's turn completed in session \"%s\"", sessionTitle)
532	cancel := a.notifier.NotifyTaskComplete(a.notifyCtx, title, message, completionNotificationDelay)
533	if cancel == nil {
534		cancel = func() {}
535	}
536	a.completionCancels.Set(sessionID, cancel)
537}
538
539func (a *sessionAgent) cancelCompletionNotification(sessionID string) {
540	if a.notifier == nil {
541		return
542	}
543
544	if cancel, ok := a.completionCancels.Take(sessionID); ok && cancel != nil {
545		cancel()
546	}
547}
548
549func (a *sessionAgent) Summarize(ctx context.Context, sessionID string, opts fantasy.ProviderOptions) error {
550	if a.IsSessionBusy(sessionID) {
551		return ErrSessionBusy
552	}
553
554	currentSession, err := a.sessions.Get(ctx, sessionID)
555	if err != nil {
556		return fmt.Errorf("failed to get session: %w", err)
557	}
558	msgs, err := a.getSessionMessages(ctx, currentSession)
559	if err != nil {
560		return err
561	}
562	if len(msgs) == 0 {
563		// nothing to summarize
564		return nil
565	}
566
567	aiMsgs, _ := a.preparePrompt(msgs)
568
569	genCtx, cancel := context.WithCancel(ctx)
570	a.activeRequests.Set(sessionID, cancel)
571	defer a.activeRequests.Del(sessionID)
572	defer cancel()
573
574	agent := fantasy.NewAgent(a.largeModel.Model,
575		fantasy.WithSystemPrompt(string(summaryPrompt)),
576	)
577	summaryMessage, err := a.messages.Create(ctx, sessionID, message.CreateMessageParams{
578		Role:             message.Assistant,
579		Model:            a.largeModel.Model.Model(),
580		Provider:         a.largeModel.Model.Provider(),
581		IsSummaryMessage: true,
582	})
583	if err != nil {
584		return err
585	}
586
587	resp, err := agent.Stream(genCtx, fantasy.AgentStreamCall{
588		Prompt:          "Provide a detailed summary of our conversation above.",
589		Messages:        aiMsgs,
590		ProviderOptions: opts,
591		PrepareStep: func(callContext context.Context, options fantasy.PrepareStepFunctionOptions) (_ context.Context, prepared fantasy.PrepareStepResult, err error) {
592			prepared.Messages = options.Messages
593			if a.systemPromptPrefix != "" {
594				prepared.Messages = append([]fantasy.Message{fantasy.NewSystemMessage(a.systemPromptPrefix)}, prepared.Messages...)
595			}
596			return callContext, prepared, nil
597		},
598		OnReasoningDelta: func(id string, text string) error {
599			summaryMessage.AppendReasoningContent(text)
600			return a.messages.Update(genCtx, summaryMessage)
601		},
602		OnReasoningEnd: func(id string, reasoning fantasy.ReasoningContent) error {
603			// handle anthropic signature
604			if anthropicData, ok := reasoning.ProviderMetadata["anthropic"]; ok {
605				if signature, ok := anthropicData.(*anthropic.ReasoningOptionMetadata); ok && signature.Signature != "" {
606					summaryMessage.AppendReasoningSignature(signature.Signature)
607				}
608			}
609			summaryMessage.FinishThinking()
610			return a.messages.Update(genCtx, summaryMessage)
611		},
612		OnTextDelta: func(id, text string) error {
613			summaryMessage.AppendContent(text)
614			return a.messages.Update(genCtx, summaryMessage)
615		},
616	})
617	if err != nil {
618		isCancelErr := errors.Is(err, context.Canceled)
619		if isCancelErr {
620			// User cancelled summarize we need to remove the summary message
621			deleteErr := a.messages.Delete(ctx, summaryMessage.ID)
622			return deleteErr
623		}
624		return err
625	}
626
627	summaryMessage.AddFinish(message.FinishReasonEndTurn, "", "")
628	err = a.messages.Update(genCtx, summaryMessage)
629	if err != nil {
630		return err
631	}
632
633	var openrouterCost *float64
634	for _, step := range resp.Steps {
635		stepCost := a.openrouterCost(step.ProviderMetadata)
636		if stepCost != nil {
637			newCost := *stepCost
638			if openrouterCost != nil {
639				newCost += *openrouterCost
640			}
641			openrouterCost = &newCost
642		}
643	}
644
645	a.updateSessionUsage(a.largeModel, &currentSession, resp.TotalUsage, openrouterCost)
646
647	// just in case get just the last usage
648	usage := resp.Response.Usage
649	currentSession.SummaryMessageID = summaryMessage.ID
650	currentSession.CompletionTokens = usage.OutputTokens
651	currentSession.PromptTokens = 0
652	_, err = a.sessions.Save(genCtx, currentSession)
653	return err
654}
655
656func (a *sessionAgent) getCacheControlOptions() fantasy.ProviderOptions {
657	if t, _ := strconv.ParseBool(os.Getenv("CRUSH_DISABLE_ANTHROPIC_CACHE")); t {
658		return fantasy.ProviderOptions{}
659	}
660	return fantasy.ProviderOptions{
661		anthropic.Name: &anthropic.ProviderCacheControlOptions{
662			CacheControl: anthropic.CacheControl{Type: "ephemeral"},
663		},
664		bedrock.Name: &anthropic.ProviderCacheControlOptions{
665			CacheControl: anthropic.CacheControl{Type: "ephemeral"},
666		},
667	}
668}
669
670func (a *sessionAgent) createUserMessage(ctx context.Context, call SessionAgentCall) (message.Message, error) {
671	var attachmentParts []message.ContentPart
672	for _, attachment := range call.Attachments {
673		attachmentParts = append(attachmentParts, message.BinaryContent{Path: attachment.FilePath, MIMEType: attachment.MimeType, Data: attachment.Content})
674	}
675	parts := []message.ContentPart{message.TextContent{Text: call.Prompt}}
676	parts = append(parts, attachmentParts...)
677	msg, err := a.messages.Create(ctx, call.SessionID, message.CreateMessageParams{
678		Role:  message.User,
679		Parts: parts,
680	})
681	if err != nil {
682		return message.Message{}, fmt.Errorf("failed to create user message: %w", err)
683	}
684	return msg, nil
685}
686
687func (a *sessionAgent) preparePrompt(msgs []message.Message, attachments ...message.Attachment) ([]fantasy.Message, []fantasy.FilePart) {
688	var history []fantasy.Message
689	for _, m := range msgs {
690		if len(m.Parts) == 0 {
691			continue
692		}
693		// Assistant message without content or tool calls (cancelled before it returned anything)
694		if m.Role == message.Assistant && len(m.ToolCalls()) == 0 && m.Content().Text == "" && m.ReasoningContent().String() == "" {
695			continue
696		}
697		history = append(history, m.ToAIMessage()...)
698	}
699
700	var files []fantasy.FilePart
701	for _, attachment := range attachments {
702		files = append(files, fantasy.FilePart{
703			Filename:  attachment.FileName,
704			Data:      attachment.Content,
705			MediaType: attachment.MimeType,
706		})
707	}
708
709	return history, files
710}
711
712func (a *sessionAgent) getSessionMessages(ctx context.Context, session session.Session) ([]message.Message, error) {
713	msgs, err := a.messages.List(ctx, session.ID)
714	if err != nil {
715		return nil, fmt.Errorf("failed to list messages: %w", err)
716	}
717
718	if session.SummaryMessageID != "" {
719		summaryMsgInex := -1
720		for i, msg := range msgs {
721			if msg.ID == session.SummaryMessageID {
722				summaryMsgInex = i
723				break
724			}
725		}
726		if summaryMsgInex != -1 {
727			msgs = msgs[summaryMsgInex:]
728			msgs[0].Role = message.User
729		}
730	}
731	return msgs, nil
732}
733
734func (a *sessionAgent) generateTitle(ctx context.Context, session *session.Session, prompt string) {
735	if prompt == "" {
736		return
737	}
738
739	var maxOutput int64 = 40
740	if a.smallModel.CatwalkCfg.CanReason {
741		maxOutput = a.smallModel.CatwalkCfg.DefaultMaxTokens
742	}
743
744	agent := fantasy.NewAgent(a.smallModel.Model,
745		fantasy.WithSystemPrompt(string(titlePrompt)+"\n /no_think"),
746		fantasy.WithMaxOutputTokens(maxOutput),
747	)
748
749	resp, err := agent.Stream(ctx, fantasy.AgentStreamCall{
750		Prompt: fmt.Sprintf("Generate a concise title for the following content:\n\n%s\n <think>\n\n</think>", prompt),
751		PrepareStep: func(callContext context.Context, options fantasy.PrepareStepFunctionOptions) (_ context.Context, prepared fantasy.PrepareStepResult, err error) {
752			prepared.Messages = options.Messages
753			if a.systemPromptPrefix != "" {
754				prepared.Messages = append([]fantasy.Message{fantasy.NewSystemMessage(a.systemPromptPrefix)}, prepared.Messages...)
755			}
756			return callContext, prepared, nil
757		},
758	})
759	if err != nil {
760		slog.Error("error generating title", "err", err)
761		return
762	}
763
764	title := resp.Response.Content.Text()
765
766	title = strings.ReplaceAll(title, "\n", " ")
767
768	// remove thinking tags if present
769	if idx := strings.Index(title, "</think>"); idx > 0 {
770		title = title[idx+len("</think>"):]
771	}
772
773	title = strings.TrimSpace(title)
774	if title == "" {
775		slog.Warn("failed to generate title", "warn", "empty title")
776		return
777	}
778
779	session.Title = title
780
781	var openrouterCost *float64
782	for _, step := range resp.Steps {
783		stepCost := a.openrouterCost(step.ProviderMetadata)
784		if stepCost != nil {
785			newCost := *stepCost
786			if openrouterCost != nil {
787				newCost += *openrouterCost
788			}
789			openrouterCost = &newCost
790		}
791	}
792
793	a.updateSessionUsage(a.smallModel, session, resp.TotalUsage, openrouterCost)
794	_, saveErr := a.sessions.Save(ctx, *session)
795	if saveErr != nil {
796		slog.Error("failed to save session title & usage", "error", saveErr)
797		return
798	}
799}
800
801func (a *sessionAgent) openrouterCost(metadata fantasy.ProviderMetadata) *float64 {
802	openrouterMetadata, ok := metadata[openrouter.Name]
803	if !ok {
804		return nil
805	}
806
807	opts, ok := openrouterMetadata.(*openrouter.ProviderMetadata)
808	if !ok {
809		return nil
810	}
811	return &opts.Usage.Cost
812}
813
814func (a *sessionAgent) updateSessionUsage(model Model, session *session.Session, usage fantasy.Usage, overrideCost *float64) {
815	modelConfig := model.CatwalkCfg
816	cost := modelConfig.CostPer1MInCached/1e6*float64(usage.CacheCreationTokens) +
817		modelConfig.CostPer1MOutCached/1e6*float64(usage.CacheReadTokens) +
818		modelConfig.CostPer1MIn/1e6*float64(usage.InputTokens) +
819		modelConfig.CostPer1MOut/1e6*float64(usage.OutputTokens)
820
821	a.eventTokensUsed(session.ID, model, usage, cost)
822
823	if overrideCost != nil {
824		session.Cost += *overrideCost
825	} else {
826		session.Cost += cost
827	}
828
829	session.CompletionTokens = usage.OutputTokens + usage.CacheReadTokens
830	session.PromptTokens = usage.InputTokens + usage.CacheCreationTokens
831}
832
833func (a *sessionAgent) Cancel(sessionID string) {
834	// Cancel regular requests
835	if cancel, ok := a.activeRequests.Take(sessionID); ok && cancel != nil {
836		slog.Info("Request cancellation initiated", "session_id", sessionID)
837		cancel()
838	}
839
840	// Also check for summarize requests
841	if cancel, ok := a.activeRequests.Take(sessionID + "-summarize"); ok && cancel != nil {
842		slog.Info("Summarize cancellation initiated", "session_id", sessionID)
843		cancel()
844	}
845
846	if a.QueuedPrompts(sessionID) > 0 {
847		slog.Info("Clearing queued prompts", "session_id", sessionID)
848		a.messageQueue.Del(sessionID)
849	}
850}
851
852func (a *sessionAgent) ClearQueue(sessionID string) {
853	if a.QueuedPrompts(sessionID) > 0 {
854		slog.Info("Clearing queued prompts", "session_id", sessionID)
855		a.messageQueue.Del(sessionID)
856	}
857}
858
859func (a *sessionAgent) CancelAll() {
860	if !a.IsBusy() {
861		// still ensure notifications are cancelled even when not busy
862		for cancel := range a.completionCancels.Seq() {
863			if cancel != nil {
864				cancel()
865			}
866		}
867		a.completionCancels.Reset(make(map[string]context.CancelFunc))
868		return
869	}
870	for key := range a.activeRequests.Seq2() {
871		a.Cancel(key) // key is sessionID
872	}
873
874	timeout := time.After(5 * time.Second)
875	for a.IsBusy() {
876		select {
877		case <-timeout:
878			return
879		default:
880			time.Sleep(200 * time.Millisecond)
881		}
882	}
883
884	for cancel := range a.completionCancels.Seq() {
885		if cancel != nil {
886			cancel()
887		}
888	}
889	a.completionCancels.Reset(make(map[string]context.CancelFunc))
890}
891
892func (a *sessionAgent) IsBusy() bool {
893	var busy bool
894	for cancelFunc := range a.activeRequests.Seq() {
895		if cancelFunc != nil {
896			busy = true
897			break
898		}
899	}
900	return busy
901}
902
903func (a *sessionAgent) IsSessionBusy(sessionID string) bool {
904	_, busy := a.activeRequests.Get(sessionID)
905	return busy
906}
907
908func (a *sessionAgent) QueuedPrompts(sessionID string) int {
909	l, ok := a.messageQueue.Get(sessionID)
910	if !ok {
911		return 0
912	}
913	return len(l)
914}
915
916func (a *sessionAgent) SetModels(large Model, small Model) {
917	a.largeModel = large
918	a.smallModel = small
919}
920
// SetTools replaces the agent's tool list. The slice is stored as given, not
// copied.
func (a *sessionAgent) SetTools(tools []fantasy.AgentTool) {
	a.tools = tools
}
924
// Model returns the agent's large model.
func (a *sessionAgent) Model() Model {
	return a.largeModel
}