1package agent
  2
  3import (
  4	"context"
  5	_ "embed"
  6	"errors"
  7	"fmt"
  8	"log/slog"
  9	"strings"
 10	"sync"
 11	"time"
 12
 13	"github.com/charmbracelet/catwalk/pkg/catwalk"
 14	"github.com/charmbracelet/crush/internal/csync"
 15	"github.com/charmbracelet/crush/internal/llm/tools"
 16	"github.com/charmbracelet/crush/internal/message"
 17	"github.com/charmbracelet/crush/internal/permission"
 18	"github.com/charmbracelet/crush/internal/session"
 19	"github.com/charmbracelet/fantasy/ai"
 20	"github.com/charmbracelet/fantasy/anthropic"
 21)
 22
// titlePrompt is the system prompt given to the small model by generateTitle.
//
//go:embed templates/title.md
var titlePrompt []byte

// summaryPrompt is the system prompt given to the large model by Summarize.
//
//go:embed templates/summary.md
var summaryPrompt []byte
 28
// SessionAgentCall describes one prompt to run against a session via
// SessionAgent.Run, together with optional per-call generation settings.
type SessionAgentCall struct {
	// SessionID is the session the prompt belongs to; Run rejects "".
	SessionID        string
	// Prompt is the user's text; Run rejects "".
	Prompt           string
	// ProviderOptions are forwarded unchanged to the model stream call.
	ProviderOptions  ai.ProviderOptions
	// Attachments are stored as binary parts on the user message and sent
	// to the model as file parts.
	Attachments      []message.Attachment
	// MaxOutputTokens is forwarded to the stream call. Run always passes a
	// pointer to it, even when it is zero.
	MaxOutputTokens  int64
	// The sampling parameters below are forwarded as-is to the stream call;
	// presumably nil means "use the provider default" — TODO confirm.
	Temperature      *float64
	TopP             *float64
	TopK             *int64
	FrequencyPenalty *float64
	PresencePenalty  *float64
}
 41
// SessionAgent runs prompts against chat sessions and tracks which sessions
// have a request in flight. The method notes below describe sessionAgent,
// the implementation returned by NewSessionAgent.
type SessionAgent interface {
	// Run sends a prompt for a session. If the session already has a request
	// in flight, the call is queued and (nil, nil) is returned.
	Run(context.Context, SessionAgentCall) (*ai.AgentResult, error)
	// SetModels replaces the large model (used by Run/Summarize) and the
	// small model (used for title generation).
	SetModels(large Model, small Model)
	// SetTools replaces the tools offered to the model by later Run calls.
	SetTools(tools []ai.AgentTool)
	// Cancel aborts the session's in-flight request and clears its queue.
	Cancel(sessionID string)
	// CancelAll cancels every in-flight request and waits up to 5s for them
	// to finish.
	CancelAll()
	// IsSessionBusy reports whether the session has a request in flight.
	IsSessionBusy(sessionID string) bool
	// IsBusy reports whether any session has a request in flight.
	IsBusy() bool
	// QueuedPrompts returns how many calls are waiting for the session.
	QueuedPrompts(sessionID string) int
	// ClearQueue drops the session's queued calls.
	ClearQueue(sessionID string)
	// Summarize stores a model-written summary of the session and marks it
	// as the point from which later history is loaded.
	Summarize(context.Context, string) error
}
 54
// Model pairs a language-model client with its catwalk.Model configuration.
// model is what requests are sent to; config supplies the per-million-token
// prices read by updateSessionUsage.
type Model struct {
	model  ai.LanguageModel
	config catwalk.Model
}
 59
// sessionAgent is the SessionAgent implementation returned by NewSessionAgent.
type sessionAgent struct {
	largeModel   Model // used by Run and Summarize
	smallModel   Model // used by generateTitle
	systemPrompt string
	tools        []ai.AgentTool
	sessions     session.Service
	messages     message.Service

	// messageQueue holds calls that arrived while their session was busy,
	// keyed by session ID.
	messageQueue   *csync.Map[string, []SessionAgentCall]
	// activeRequests maps a session ID to the cancel func of its in-flight
	// Run or Summarize; an entry's presence is what marks the session busy.
	activeRequests *csync.Map[string, context.CancelFunc]
}
 71
// SessionAgentOption is a functional option that mutates a sessionAgent.
// NOTE(review): NewSessionAgent in this file does not accept options.
type SessionAgentOption func(*sessionAgent)
 73
 74func NewSessionAgent(
 75	largeModel Model,
 76	smallModel Model,
 77	systemPrompt string,
 78	sessions session.Service,
 79	messages message.Service,
 80	tools ...ai.AgentTool,
 81) SessionAgent {
 82	return &sessionAgent{
 83		largeModel:     largeModel,
 84		smallModel:     smallModel,
 85		systemPrompt:   systemPrompt,
 86		sessions:       sessions,
 87		messages:       messages,
 88		tools:          tools,
 89		messageQueue:   csync.NewMap[string, []SessionAgentCall](),
 90		activeRequests: csync.NewMap[string, context.CancelFunc](),
 91	}
 92}
 93
 94func (a *sessionAgent) Run(ctx context.Context, call SessionAgentCall) (*ai.AgentResult, error) {
 95	if call.Prompt == "" {
 96		return nil, ErrEmptyPrompt
 97	}
 98	if call.SessionID == "" {
 99		return nil, ErrSessionMissing
100	}
101
102	// Queue the message if busy
103	if a.IsSessionBusy(call.SessionID) {
104		existing, ok := a.messageQueue.Get(call.SessionID)
105		if !ok {
106			existing = []SessionAgentCall{}
107		}
108		existing = append(existing, call)
109		a.messageQueue.Set(call.SessionID, existing)
110		return nil, nil
111	}
112
113	if len(a.tools) > 0 {
114		// add anthropic caching to the last tool
115		a.tools[len(a.tools)-1].SetProviderOptions(a.getCacheControlOptions())
116	}
117
118	agent := ai.NewAgent(
119		a.largeModel.model,
120		ai.WithSystemPrompt(a.systemPrompt),
121		ai.WithTools(a.tools...),
122	)
123
124	currentSession, err := a.sessions.Get(ctx, call.SessionID)
125	if err != nil {
126		return nil, fmt.Errorf("failed to get session: %w", err)
127	}
128
129	msgs, err := a.getSessionMessages(ctx, currentSession)
130	if err != nil {
131		return nil, fmt.Errorf("failed to get session messages: %w", err)
132	}
133
134	var wg sync.WaitGroup
135	// Generate title if first message
136	if len(msgs) == 0 {
137		wg.Go(func() {
138			a.generateTitle(ctx, currentSession, call.Prompt)
139		})
140	}
141
142	// Add the user message to the session
143	_, err = a.createUserMessage(ctx, call)
144	if err != nil {
145		return nil, err
146	}
147
148	// add the session to the context
149	ctx = context.WithValue(ctx, tools.SessionIDContextKey, call.SessionID)
150
151	genCtx, cancel := context.WithCancel(ctx)
152	a.activeRequests.Set(call.SessionID, cancel)
153
154	defer cancel()
155	defer a.activeRequests.Del(call.SessionID)
156
157	history, files := a.preparePrompt(msgs, call.Attachments...)
158
159	var currentAssistant *message.Message
160	result, err := agent.Stream(genCtx, ai.AgentStreamCall{
161		Prompt:           call.Prompt,
162		Files:            files,
163		Messages:         history,
164		ProviderOptions:  call.ProviderOptions,
165		MaxOutputTokens:  &call.MaxOutputTokens,
166		TopP:             call.TopP,
167		Temperature:      call.Temperature,
168		PresencePenalty:  call.PresencePenalty,
169		TopK:             call.TopK,
170		FrequencyPenalty: call.FrequencyPenalty,
171		// Before each step create the new assistant message
172		PrepareStep: func(options ai.PrepareStepFunctionOptions) (prepared ai.PrepareStepResult, err error) {
173			var assistantMsg message.Message
174			assistantMsg, err = a.messages.Create(genCtx, call.SessionID, message.CreateMessageParams{
175				Role:     message.Assistant,
176				Parts:    []message.ContentPart{},
177				Model:    a.largeModel.model.Model(),
178				Provider: a.largeModel.model.Provider(),
179			})
180			if err != nil {
181				return prepared, err
182			}
183
184			currentAssistant = &assistantMsg
185
186			prepared.Messages = options.Messages
187
188			queuedCalls, _ := a.messageQueue.Get(call.SessionID)
189			a.messageQueue.Del(call.SessionID)
190			for _, queued := range queuedCalls {
191				userMessage, createErr := a.createUserMessage(genCtx, queued)
192				if createErr != nil {
193					return prepared, createErr
194				}
195				prepared.Messages = append(prepared.Messages, userMessage.ToAIMessage()...)
196			}
197
198			lastSystemRoleInx := 0
199			systemMessageUpdated := false
200			for i, msg := range prepared.Messages {
201				// only add cache control to the last message
202				if msg.Role == ai.MessageRoleSystem {
203					lastSystemRoleInx = i
204				} else if !systemMessageUpdated {
205					prepared.Messages[lastSystemRoleInx].ProviderOptions = a.getCacheControlOptions()
206					systemMessageUpdated = true
207				}
208				// than add cache control to the last 2 messages
209				if i > len(msgs)-3 {
210					prepared.Messages[i].ProviderOptions = a.getCacheControlOptions()
211				}
212			}
213			return prepared, err
214		},
215		OnReasoningDelta: func(id string, text string) error {
216			currentAssistant.AppendReasoningContent(text)
217			return a.messages.Update(genCtx, *currentAssistant)
218		},
219		OnReasoningEnd: func(id string, reasoning ai.ReasoningContent) error {
220			// handle anthropic signature
221			if anthropicData, ok := reasoning.ProviderMetadata[anthropic.Name]; ok {
222				if reasoning, ok := anthropicData.(*anthropic.ReasoningOptionMetadata); ok {
223					currentAssistant.AppendReasoningSignature(reasoning.Signature)
224				}
225			}
226			currentAssistant.FinishThinking()
227			return a.messages.Update(genCtx, *currentAssistant)
228		},
229		OnTextDelta: func(id string, text string) error {
230			currentAssistant.AppendContent(text)
231			return a.messages.Update(genCtx, *currentAssistant)
232		},
233		OnToolInputStart: func(id string, toolName string) error {
234			toolCall := message.ToolCall{
235				ID:               id,
236				Name:             toolName,
237				ProviderExecuted: false,
238				Finished:         false,
239			}
240			currentAssistant.AddToolCall(toolCall)
241			return a.messages.Update(genCtx, *currentAssistant)
242		},
243		OnRetry: func(err *ai.APICallError, delay time.Duration) {
244			// TODO: implement
245		},
246		OnToolResult: func(result ai.ToolResultContent) error {
247			var resultContent string
248			isError := false
249			switch result.Result.GetType() {
250			case ai.ToolResultContentTypeText:
251				r, ok := ai.AsToolResultOutputType[ai.ToolResultOutputContentText](result.Result)
252				if ok {
253					resultContent = r.Text
254				}
255			case ai.ToolResultContentTypeError:
256				r, ok := ai.AsToolResultOutputType[ai.ToolResultOutputContentError](result.Result)
257				if ok {
258					isError = true
259					resultContent = r.Error.Error()
260				}
261			case ai.ToolResultContentTypeMedia:
262				// TODO: handle this message type
263			}
264			toolResult := message.ToolResult{
265				ToolCallID: result.ToolCallID,
266				Name:       result.ToolName,
267				Content:    resultContent,
268				IsError:    isError,
269				Metadata:   result.ClientMetadata,
270			}
271			a.messages.Create(context.Background(), currentAssistant.SessionID, message.CreateMessageParams{
272				Role: message.Tool,
273				Parts: []message.ContentPart{
274					toolResult,
275				},
276			})
277			return a.messages.Update(genCtx, *currentAssistant)
278		},
279		OnStepFinish: func(stepResult ai.StepResult) error {
280			finishReason := message.FinishReasonUnknown
281			switch stepResult.FinishReason {
282			case ai.FinishReasonLength:
283				finishReason = message.FinishReasonMaxTokens
284			case ai.FinishReasonStop:
285				finishReason = message.FinishReasonEndTurn
286			case ai.FinishReasonToolCalls:
287				finishReason = message.FinishReasonToolUse
288			}
289			currentAssistant.AddFinish(finishReason, "", "")
290			a.updateSessionUsage(a.largeModel, ¤tSession, stepResult.Usage)
291			return a.messages.Update(genCtx, *currentAssistant)
292		},
293	})
294	if err != nil {
295		isCancelErr := errors.Is(err, context.Canceled)
296		isPermissionErr := errors.Is(err, permission.ErrorPermissionDenied)
297		if currentAssistant == nil {
298			return result, err
299		}
300		toolCalls := currentAssistant.ToolCalls()
301		toolResults := currentAssistant.ToolResults()
302		for _, tc := range toolCalls {
303			if !tc.Finished {
304				tc.Finished = true
305				tc.Input = "{}"
306			}
307			currentAssistant.AddToolCall(tc)
308			found := false
309			for _, tr := range toolResults {
310				if tr.ToolCallID == tc.ID {
311					found = true
312					break
313				}
314			}
315			if !found {
316				content := "There was an error while executing the tool"
317				if isCancelErr {
318					content = "Tool execution canceled by user"
319				} else if isPermissionErr {
320					content = "Permission denied"
321				}
322				currentAssistant.AddToolResult(message.ToolResult{
323					ToolCallID: tc.ID,
324					Name:       tc.Name,
325					Content:    content,
326					IsError:    true,
327				})
328			}
329		}
330		if isCancelErr {
331			currentAssistant.AddFinish(message.FinishReasonCanceled, "Request cancelled", "")
332		} else if isPermissionErr {
333			currentAssistant.AddFinish(message.FinishReasonPermissionDenied, "Permission denied", "")
334		} else {
335			currentAssistant.AddFinish(message.FinishReasonError, "API Error", err.Error())
336		}
337		// INFO: we use the parent context here because the genCtx might have been cancelled
338		updateErr := a.messages.Update(ctx, *currentAssistant)
339		if updateErr != nil {
340			return nil, updateErr
341		}
342	}
343	if err != nil {
344		return nil, err
345	}
346	wg.Wait()
347
348	queuedMessages, ok := a.messageQueue.Get(call.SessionID)
349	if !ok || len(queuedMessages) == 0 {
350		return result, err
351	}
352	// there are queued messages restart the loop
353	firstQueuedMessage := queuedMessages[0]
354	a.messageQueue.Set(call.SessionID, queuedMessages[1:])
355	return a.Run(genCtx, firstQueuedMessage)
356}
357
358func (a *sessionAgent) Summarize(ctx context.Context, sessionID string) error {
359	if a.IsSessionBusy(sessionID) {
360		return ErrSessionBusy
361	}
362
363	currentSession, err := a.sessions.Get(ctx, sessionID)
364	if err != nil {
365		return fmt.Errorf("failed to get session: %w", err)
366	}
367	msgs, err := a.getSessionMessages(ctx, currentSession)
368	if err != nil {
369		return err
370	}
371	if len(msgs) == 0 {
372		// nothing to summarize
373		return nil
374	}
375
376	aiMsgs, _ := a.preparePrompt(msgs)
377
378	genCtx, cancel := context.WithCancel(ctx)
379	a.activeRequests.Set(sessionID, cancel)
380	defer a.activeRequests.Del(sessionID)
381	defer cancel()
382
383	agent := ai.NewAgent(a.largeModel.model,
384		ai.WithSystemPrompt(string(summaryPrompt)),
385	)
386	summaryMessage, err := a.messages.Create(ctx, sessionID, message.CreateMessageParams{
387		Role:     message.Assistant,
388		Model:    a.largeModel.model.Model(),
389		Provider: a.largeModel.model.Provider(),
390	})
391	if err != nil {
392		return err
393	}
394
395	resp, err := agent.Stream(ctx, ai.AgentStreamCall{
396		Prompt:   "Provide a detailed summary of our conversation above.",
397		Messages: aiMsgs,
398		OnReasoningDelta: func(id string, text string) error {
399			summaryMessage.AppendReasoningContent(text)
400			return a.messages.Update(ctx, summaryMessage)
401		},
402		OnReasoningEnd: func(id string, reasoning ai.ReasoningContent) error {
403			// handle anthropic signature
404			if anthropicData, ok := reasoning.ProviderMetadata["anthropic"]; ok {
405				if signature, ok := anthropicData.(*anthropic.ReasoningOptionMetadata); ok && signature.Signature != "" {
406					summaryMessage.AppendReasoningSignature(signature.Signature)
407				}
408			}
409			summaryMessage.FinishThinking()
410			return a.messages.Update(ctx, summaryMessage)
411		},
412		OnTextDelta: func(id, text string) error {
413			summaryMessage.AppendContent(text)
414			return a.messages.Update(ctx, summaryMessage)
415		},
416	})
417	if err != nil {
418		return err
419	}
420
421	summaryMessage.AddFinish(message.FinishReasonEndTurn, "", "")
422	err = a.messages.Update(genCtx, summaryMessage)
423	if err != nil {
424		return err
425	}
426
427	a.updateSessionUsage(a.largeModel, ¤tSession, resp.TotalUsage)
428
429	// just in case get just the last usage
430	usage := resp.Response.Usage
431	currentSession.SummaryMessageID = summaryMessage.ID
432	currentSession.CompletionTokens = usage.OutputTokens
433	currentSession.PromptTokens = 0
434	_, err = a.sessions.Save(genCtx, currentSession)
435	return err
436}
437
438func (a *sessionAgent) getCacheControlOptions() ai.ProviderOptions {
439	return ai.ProviderOptions{
440		anthropic.Name: &anthropic.ProviderCacheControlOptions{
441			CacheControl: anthropic.CacheControl{Type: "ephemeral"},
442		},
443	}
444}
445
446func (a *sessionAgent) createUserMessage(ctx context.Context, call SessionAgentCall) (message.Message, error) {
447	var attachmentParts []message.ContentPart
448	for _, attachment := range call.Attachments {
449		attachmentParts = append(attachmentParts, message.BinaryContent{Path: attachment.FilePath, MIMEType: attachment.MimeType, Data: attachment.Content})
450	}
451	parts := []message.ContentPart{message.TextContent{Text: call.Prompt}}
452	parts = append(parts, attachmentParts...)
453	msg, err := a.messages.Create(ctx, call.SessionID, message.CreateMessageParams{
454		Role:  message.User,
455		Parts: parts,
456	})
457	if err != nil {
458		return message.Message{}, fmt.Errorf("failed to create user message: %w", err)
459	}
460	return msg, nil
461}
462
463func (a *sessionAgent) preparePrompt(msgs []message.Message, attachments ...message.Attachment) ([]ai.Message, []ai.FilePart) {
464	var history []ai.Message
465	for _, m := range msgs {
466		if len(m.Parts) == 0 {
467			continue
468		}
469		// Assistant message without content or tool calls (cancelled before it returned anything)
470		if m.Role == message.Assistant && len(m.ToolCalls()) == 0 && m.Content().Text == "" && m.ReasoningContent().String() == "" {
471			continue
472		}
473		history = append(history, m.ToAIMessage()...)
474	}
475
476	var files []ai.FilePart
477	for _, attachment := range attachments {
478		files = append(files, ai.FilePart{
479			Filename:  attachment.FileName,
480			Data:      attachment.Content,
481			MediaType: attachment.MimeType,
482		})
483	}
484
485	return history, files
486}
487
488func (a *sessionAgent) getSessionMessages(ctx context.Context, session session.Session) ([]message.Message, error) {
489	msgs, err := a.messages.List(ctx, session.ID)
490	if err != nil {
491		return nil, fmt.Errorf("failed to list messages: %w", err)
492	}
493
494	if session.SummaryMessageID != "" {
495		summaryMsgInex := -1
496		for i, msg := range msgs {
497			if msg.ID == session.SummaryMessageID {
498				summaryMsgInex = i
499				break
500			}
501		}
502		if summaryMsgInex != -1 {
503			msgs = msgs[summaryMsgInex:]
504			msgs[0].Role = message.User
505		}
506	}
507	return msgs, nil
508}
509
510func (a *sessionAgent) generateTitle(ctx context.Context, session session.Session, prompt string) {
511	if prompt == "" {
512		return
513	}
514
515	agent := ai.NewAgent(a.smallModel.model,
516		ai.WithSystemPrompt(string(titlePrompt)),
517		ai.WithMaxOutputTokens(40),
518	)
519
520	resp, err := agent.Stream(ctx, ai.AgentStreamCall{
521		Prompt: fmt.Sprintf("Generate a concise title for the following content:\n\n%s", prompt),
522	})
523	if err != nil {
524		slog.Error("error generating title", "err", err)
525		return
526	}
527
528	title := resp.Response.Content.Text()
529
530	title = strings.ReplaceAll(title, "\n", " ")
531
532	// remove thinking tags if present
533	if idx := strings.Index(title, "</think>"); idx > 0 {
534		title = title[idx+len("</think>"):]
535	}
536
537	title = strings.TrimSpace(title)
538	if title == "" {
539		slog.Warn("failed to generate title", "warn", "empty title")
540		return
541	}
542
543	session.Title = title
544	a.updateSessionUsage(a.smallModel, &session, resp.TotalUsage)
545	_, saveErr := a.sessions.Save(ctx, session)
546	if saveErr != nil {
547		slog.Error("failed to save session title & usage", "error", saveErr)
548		return
549	}
550}
551
552func (a *sessionAgent) updateSessionUsage(model Model, session *session.Session, usage ai.Usage) {
553	modelConfig := model.config
554	cost := modelConfig.CostPer1MInCached/1e6*float64(usage.CacheCreationTokens) +
555		modelConfig.CostPer1MOutCached/1e6*float64(usage.CacheReadTokens) +
556		modelConfig.CostPer1MIn/1e6*float64(usage.InputTokens) +
557		modelConfig.CostPer1MOut/1e6*float64(usage.OutputTokens)
558	session.Cost += cost
559	session.CompletionTokens = usage.OutputTokens + usage.CacheReadTokens
560	session.PromptTokens = usage.InputTokens + usage.CacheCreationTokens
561}
562
563func (a *sessionAgent) Cancel(sessionID string) {
564	// Cancel regular requests
565	if cancel, ok := a.activeRequests.Take(sessionID); ok && cancel != nil {
566		slog.Info("Request cancellation initiated", "session_id", sessionID)
567		cancel()
568	}
569
570	// Also check for summarize requests
571	if cancel, ok := a.activeRequests.Take(sessionID + "-summarize"); ok && cancel != nil {
572		slog.Info("Summarize cancellation initiated", "session_id", sessionID)
573		cancel()
574	}
575
576	if a.QueuedPrompts(sessionID) > 0 {
577		slog.Info("Clearing queued prompts", "session_id", sessionID)
578		a.messageQueue.Del(sessionID)
579	}
580}
581
582func (a *sessionAgent) ClearQueue(sessionID string) {
583	if a.QueuedPrompts(sessionID) > 0 {
584		slog.Info("Clearing queued prompts", "session_id", sessionID)
585		a.messageQueue.Del(sessionID)
586	}
587}
588
589func (a *sessionAgent) CancelAll() {
590	if !a.IsBusy() {
591		return
592	}
593	for key := range a.activeRequests.Seq2() {
594		a.Cancel(key) // key is sessionID
595	}
596
597	timeout := time.After(5 * time.Second)
598	for a.IsBusy() {
599		select {
600		case <-timeout:
601			return
602		default:
603			time.Sleep(200 * time.Millisecond)
604		}
605	}
606}
607
608func (a *sessionAgent) IsBusy() bool {
609	var busy bool
610	for cancelFunc := range a.activeRequests.Seq() {
611		if cancelFunc != nil {
612			busy = true
613			break
614		}
615	}
616	return busy
617}
618
619func (a *sessionAgent) IsSessionBusy(sessionID string) bool {
620	_, busy := a.activeRequests.Get(sessionID)
621	return busy
622}
623
624func (a *sessionAgent) QueuedPrompts(sessionID string) int {
625	l, ok := a.messageQueue.Get(sessionID)
626	if !ok {
627		return 0
628	}
629	return len(l)
630}
631
632func (a *sessionAgent) SetModels(large Model, small Model) {
633	a.largeModel = large
634	a.smallModel = small
635}
636
637func (a *sessionAgent) SetTools(tools []ai.AgentTool) {
638	a.tools = tools
639}