agent.go

  1package agent
  2
  3import (
  4	"context"
  5	_ "embed"
  6	"errors"
  7	"fmt"
  8	"log/slog"
  9	"strings"
 10	"sync"
 11	"time"
 12
 13	"github.com/charmbracelet/catwalk/pkg/catwalk"
 14	"github.com/charmbracelet/crush/internal/agent/tools"
 15	"github.com/charmbracelet/crush/internal/config"
 16	"github.com/charmbracelet/crush/internal/csync"
 17	"github.com/charmbracelet/crush/internal/message"
 18	"github.com/charmbracelet/crush/internal/permission"
 19	"github.com/charmbracelet/crush/internal/session"
 20	"github.com/charmbracelet/fantasy/ai"
 21	"github.com/charmbracelet/fantasy/anthropic"
 22)
 23
// titlePrompt holds the contents of templates/title.md, embedded at build
// time. generateTitle uses it as the system prompt of the title agent.
//
//go:embed templates/title.md
var titlePrompt []byte

// summaryPrompt holds the contents of templates/summary.md, embedded at build
// time. Summarize uses it as the system prompt of the summary agent.
//
//go:embed templates/summary.md
var summaryPrompt []byte
 29
// SessionAgentCall describes a single prompt submitted to a SessionAgent.
// Run rejects calls whose Prompt or SessionID is empty.
type SessionAgentCall struct {
	// SessionID identifies the session the prompt belongs to.
	SessionID string
	// Prompt is the user text; it is stored as the user message and sent to the model.
	Prompt string
	// ProviderOptions is forwarded unchanged to the underlying agent stream call.
	ProviderOptions ai.ProviderOptions
	// Attachments are stored with the user message and passed to the model as files.
	Attachments []message.Attachment
	// MaxOutputTokens is forwarded (by pointer) to the agent stream call.
	MaxOutputTokens int64
	// The sampling parameters below are optional and forwarded as-is to the
	// agent stream call.
	Temperature      *float64
	TopP             *float64
	TopK             *int64
	FrequencyPenalty *float64
	PresencePenalty  *float64
}
 42
// SessionAgent is the interface through which callers drive a session-bound
// agent. The per-method notes below describe the sessionAgent implementation
// in this file.
type SessionAgent interface {
	// Run executes the call, or queues it when the session already has an
	// active request (in which case it returns nil, nil).
	Run(context.Context, SessionAgentCall) (*ai.AgentResult, error)
	// SetModels replaces the large and small models.
	SetModels(large Model, small Model)
	// SetTools replaces the tool set handed to the agent.
	SetTools(tools []ai.AgentTool)
	// Cancel cancels the active request of the session and drops its queued prompts.
	Cancel(sessionID string)
	// CancelAll cancels every active request and waits (bounded) for them to finish.
	CancelAll()
	// IsSessionBusy reports whether the session has an active request.
	IsSessionBusy(sessionID string) bool
	// IsBusy reports whether any session has an active request.
	IsBusy() bool
	// QueuedPrompts returns the number of prompts queued for the session.
	QueuedPrompts(sessionID string) int
	// ClearQueue drops the prompts queued for the session.
	ClearQueue(sessionID string)
	// Summarize generates a summary message for the session identified by
	// the string argument and records it on the session.
	Summarize(context.Context, string) error
	// Model returns the large model.
	Model() Model
}
 56
// Model bundles a language model with the metadata and configuration that
// describe it.
type Model struct {
	// Model is the language model handed to ai.NewAgent.
	Model ai.LanguageModel
	// CatwalkCfg carries the catwalk model description; updateSessionUsage
	// reads its per-million-token cost fields.
	CatwalkCfg catwalk.Model
	// ModelCfg is the selected-model configuration; Run records its Model and
	// Provider values on assistant messages.
	ModelCfg config.SelectedModel
}
 62
// sessionAgent is the SessionAgent implementation returned by NewSessionAgent.
type sessionAgent struct {
	// largeModel drives Run and Summarize; smallModel drives title generation.
	largeModel Model
	smallModel Model
	// systemPrompt is the system prompt used by Run.
	systemPrompt string
	// tools is the tool set handed to the agent created in Run.
	tools []ai.AgentTool
	// sessions and messages are the services used to load and persist
	// sessions and their messages.
	sessions session.Service
	messages message.Service

	// messageQueue holds, per session ID, the calls received while that
	// session was busy.
	messageQueue *csync.Map[string, []SessionAgentCall]
	// activeRequests holds, per session ID, the cancel function of the request
	// currently in flight; presence of a key is what marks a session as busy.
	activeRequests *csync.Map[string, context.CancelFunc]
}
 74
// SessionAgentOption is a function that configures a sessionAgent.
// NOTE(review): nothing in the visible part of this file accepts or applies
// these options — confirm whether it is used elsewhere.
type SessionAgentOption func(*sessionAgent)
 76
 77func NewSessionAgent(
 78	largeModel Model,
 79	smallModel Model,
 80	systemPrompt string,
 81	sessions session.Service,
 82	messages message.Service,
 83	tools ...ai.AgentTool,
 84) SessionAgent {
 85	return &sessionAgent{
 86		largeModel:     largeModel,
 87		smallModel:     smallModel,
 88		systemPrompt:   systemPrompt,
 89		sessions:       sessions,
 90		messages:       messages,
 91		tools:          tools,
 92		messageQueue:   csync.NewMap[string, []SessionAgentCall](),
 93		activeRequests: csync.NewMap[string, context.CancelFunc](),
 94	}
 95}
 96
 97func (a *sessionAgent) Run(ctx context.Context, call SessionAgentCall) (*ai.AgentResult, error) {
 98	if call.Prompt == "" {
 99		return nil, ErrEmptyPrompt
100	}
101	if call.SessionID == "" {
102		return nil, ErrSessionMissing
103	}
104
105	// Queue the message if busy
106	if a.IsSessionBusy(call.SessionID) {
107		existing, ok := a.messageQueue.Get(call.SessionID)
108		if !ok {
109			existing = []SessionAgentCall{}
110		}
111		existing = append(existing, call)
112		a.messageQueue.Set(call.SessionID, existing)
113		return nil, nil
114	}
115
116	if len(a.tools) > 0 {
117		// add anthropic caching to the last tool
118		a.tools[len(a.tools)-1].SetProviderOptions(a.getCacheControlOptions())
119	}
120
121	agent := ai.NewAgent(
122		a.largeModel.Model,
123		ai.WithSystemPrompt(a.systemPrompt),
124		ai.WithTools(a.tools...),
125	)
126
127	sessionLock := sync.Mutex{}
128	currentSession, err := a.sessions.Get(ctx, call.SessionID)
129	if err != nil {
130		return nil, fmt.Errorf("failed to get session: %w", err)
131	}
132
133	msgs, err := a.getSessionMessages(ctx, currentSession)
134	if err != nil {
135		return nil, fmt.Errorf("failed to get session messages: %w", err)
136	}
137
138	var wg sync.WaitGroup
139	// Generate title if first message
140	if len(msgs) == 0 {
141		wg.Go(func() {
142			sessionLock.Lock()
143			a.generateTitle(ctx, &currentSession, call.Prompt)
144			sessionLock.Unlock()
145		})
146	}
147
148	// Add the user message to the session
149	_, err = a.createUserMessage(ctx, call)
150	if err != nil {
151		return nil, err
152	}
153
154	// add the session to the context
155	ctx = context.WithValue(ctx, tools.SessionIDContextKey, call.SessionID)
156
157	genCtx, cancel := context.WithCancel(ctx)
158	a.activeRequests.Set(call.SessionID, cancel)
159
160	defer cancel()
161	defer a.activeRequests.Del(call.SessionID)
162
163	history, files := a.preparePrompt(msgs, call.Attachments...)
164
165	var currentAssistant *message.Message
166	result, err := agent.Stream(genCtx, ai.AgentStreamCall{
167		Prompt:           call.Prompt,
168		Files:            files,
169		Messages:         history,
170		ProviderOptions:  call.ProviderOptions,
171		MaxOutputTokens:  &call.MaxOutputTokens,
172		TopP:             call.TopP,
173		Temperature:      call.Temperature,
174		PresencePenalty:  call.PresencePenalty,
175		TopK:             call.TopK,
176		FrequencyPenalty: call.FrequencyPenalty,
177		// Before each step create the new assistant message
178		PrepareStep: func(callContext context.Context, options ai.PrepareStepFunctionOptions) (_ context.Context, prepared ai.PrepareStepResult, err error) {
179			var assistantMsg message.Message
180			assistantMsg, err = a.messages.Create(genCtx, call.SessionID, message.CreateMessageParams{
181				Role:     message.Assistant,
182				Parts:    []message.ContentPart{},
183				Model:    a.largeModel.ModelCfg.Model,
184				Provider: a.largeModel.ModelCfg.Provider,
185			})
186			if err != nil {
187				return callContext, prepared, err
188			}
189
190			callContext = context.WithValue(ctx, tools.MessageIDContextKey, assistantMsg.ID)
191
192			currentAssistant = &assistantMsg
193
194			prepared.Messages = options.Messages
195			// reset all cached items
196			for i := range prepared.Messages {
197				prepared.Messages[i].ProviderOptions = nil
198			}
199
200			queuedCalls, _ := a.messageQueue.Get(call.SessionID)
201			a.messageQueue.Del(call.SessionID)
202			for _, queued := range queuedCalls {
203				userMessage, createErr := a.createUserMessage(genCtx, queued)
204				if createErr != nil {
205					return callContext, prepared, createErr
206				}
207				prepared.Messages = append(prepared.Messages, userMessage.ToAIMessage()...)
208			}
209
210			lastSystemRoleInx := 0
211			systemMessageUpdated := false
212			for i, msg := range prepared.Messages {
213				// only add cache control to the last message
214				if msg.Role == ai.MessageRoleSystem {
215					lastSystemRoleInx = i
216				} else if !systemMessageUpdated {
217					prepared.Messages[lastSystemRoleInx].ProviderOptions = a.getCacheControlOptions()
218					systemMessageUpdated = true
219				}
220				// than add cache control to the last 2 messages
221				if i > len(prepared.Messages)-3 {
222					prepared.Messages[i].ProviderOptions = a.getCacheControlOptions()
223				}
224			}
225			return callContext, prepared, err
226		},
227		OnReasoningDelta: func(id string, text string) error {
228			currentAssistant.AppendReasoningContent(text)
229			return a.messages.Update(genCtx, *currentAssistant)
230		},
231		OnReasoningEnd: func(id string, reasoning ai.ReasoningContent) error {
232			// handle anthropic signature
233			if anthropicData, ok := reasoning.ProviderMetadata[anthropic.Name]; ok {
234				if reasoning, ok := anthropicData.(*anthropic.ReasoningOptionMetadata); ok {
235					currentAssistant.AppendReasoningSignature(reasoning.Signature)
236				}
237			}
238			currentAssistant.FinishThinking()
239			return a.messages.Update(genCtx, *currentAssistant)
240		},
241		OnTextDelta: func(id string, text string) error {
242			currentAssistant.AppendContent(text)
243			return a.messages.Update(genCtx, *currentAssistant)
244		},
245		OnToolInputStart: func(id string, toolName string) error {
246			toolCall := message.ToolCall{
247				ID:               id,
248				Name:             toolName,
249				ProviderExecuted: false,
250				Finished:         false,
251			}
252			currentAssistant.AddToolCall(toolCall)
253			return a.messages.Update(genCtx, *currentAssistant)
254		},
255		OnRetry: func(err *ai.APICallError, delay time.Duration) {
256			// TODO: implement
257		},
258		OnToolCall: func(tc ai.ToolCallContent) error {
259			toolCall := message.ToolCall{
260				ID:               tc.ToolCallID,
261				Name:             tc.ToolName,
262				Input:            tc.Input,
263				ProviderExecuted: false,
264				Finished:         true,
265			}
266			currentAssistant.AddToolCall(toolCall)
267			return a.messages.Update(genCtx, *currentAssistant)
268		},
269		OnToolResult: func(result ai.ToolResultContent) error {
270			var resultContent string
271			isError := false
272			switch result.Result.GetType() {
273			case ai.ToolResultContentTypeText:
274				r, ok := ai.AsToolResultOutputType[ai.ToolResultOutputContentText](result.Result)
275				if ok {
276					resultContent = r.Text
277				}
278			case ai.ToolResultContentTypeError:
279				r, ok := ai.AsToolResultOutputType[ai.ToolResultOutputContentError](result.Result)
280				if ok {
281					isError = true
282					resultContent = r.Error.Error()
283				}
284			case ai.ToolResultContentTypeMedia:
285				// TODO: handle this message type
286			}
287			toolResult := message.ToolResult{
288				ToolCallID: result.ToolCallID,
289				Name:       result.ToolName,
290				Content:    resultContent,
291				IsError:    isError,
292				Metadata:   result.ClientMetadata,
293			}
294			a.messages.Create(context.Background(), currentAssistant.SessionID, message.CreateMessageParams{
295				Role: message.Tool,
296				Parts: []message.ContentPart{
297					toolResult,
298				},
299			})
300			return a.messages.Update(genCtx, *currentAssistant)
301		},
302		OnStepFinish: func(stepResult ai.StepResult) error {
303			finishReason := message.FinishReasonUnknown
304			switch stepResult.FinishReason {
305			case ai.FinishReasonLength:
306				finishReason = message.FinishReasonMaxTokens
307			case ai.FinishReasonStop:
308				finishReason = message.FinishReasonEndTurn
309			case ai.FinishReasonToolCalls:
310				finishReason = message.FinishReasonToolUse
311			}
312			currentAssistant.AddFinish(finishReason, "", "")
313			a.updateSessionUsage(a.largeModel, &currentSession, stepResult.Usage)
314			sessionLock.Lock()
315			_, sessionErr := a.sessions.Save(genCtx, currentSession)
316			sessionLock.Unlock()
317			if sessionErr != nil {
318				return sessionErr
319			}
320			return a.messages.Update(genCtx, *currentAssistant)
321		},
322	})
323	if err != nil {
324		isCancelErr := errors.Is(err, context.Canceled)
325		isPermissionErr := errors.Is(err, permission.ErrorPermissionDenied)
326		if currentAssistant == nil {
327			return result, err
328		}
329		toolCalls := currentAssistant.ToolCalls()
330		toolResults := currentAssistant.ToolResults()
331		for _, tc := range toolCalls {
332			if !tc.Finished {
333				tc.Finished = true
334				tc.Input = "{}"
335			}
336			currentAssistant.AddToolCall(tc)
337			found := false
338			for _, tr := range toolResults {
339				if tr.ToolCallID == tc.ID {
340					found = true
341					break
342				}
343			}
344			if !found {
345				content := "There was an error while executing the tool"
346				if isCancelErr {
347					content = "Tool execution canceled by user"
348				} else if isPermissionErr {
349					content = "Permission denied"
350				}
351				currentAssistant.AddToolResult(message.ToolResult{
352					ToolCallID: tc.ID,
353					Name:       tc.Name,
354					Content:    content,
355					IsError:    true,
356				})
357			}
358		}
359		if isCancelErr {
360			currentAssistant.AddFinish(message.FinishReasonCanceled, "Request cancelled", "")
361		} else if isPermissionErr {
362			currentAssistant.AddFinish(message.FinishReasonPermissionDenied, "Permission denied", "")
363		} else {
364			currentAssistant.AddFinish(message.FinishReasonError, "API Error", err.Error())
365		}
366		// INFO: we use the parent context here because the genCtx might have been cancelled
367		updateErr := a.messages.Update(ctx, *currentAssistant)
368		if updateErr != nil {
369			return nil, updateErr
370		}
371	}
372	if err != nil {
373		return nil, err
374	}
375	wg.Wait()
376
377	queuedMessages, ok := a.messageQueue.Get(call.SessionID)
378	if !ok || len(queuedMessages) == 0 {
379		return result, err
380	}
381	// there are queued messages restart the loop
382	firstQueuedMessage := queuedMessages[0]
383	a.messageQueue.Set(call.SessionID, queuedMessages[1:])
384	return a.Run(genCtx, firstQueuedMessage)
385}
386
387func (a *sessionAgent) Summarize(ctx context.Context, sessionID string) error {
388	if a.IsSessionBusy(sessionID) {
389		return ErrSessionBusy
390	}
391
392	currentSession, err := a.sessions.Get(ctx, sessionID)
393	if err != nil {
394		return fmt.Errorf("failed to get session: %w", err)
395	}
396	msgs, err := a.getSessionMessages(ctx, currentSession)
397	if err != nil {
398		return err
399	}
400	if len(msgs) == 0 {
401		// nothing to summarize
402		return nil
403	}
404
405	aiMsgs, _ := a.preparePrompt(msgs)
406
407	genCtx, cancel := context.WithCancel(ctx)
408	a.activeRequests.Set(sessionID, cancel)
409	defer a.activeRequests.Del(sessionID)
410	defer cancel()
411
412	agent := ai.NewAgent(a.largeModel.Model,
413		ai.WithSystemPrompt(string(summaryPrompt)),
414	)
415	summaryMessage, err := a.messages.Create(ctx, sessionID, message.CreateMessageParams{
416		Role:     message.Assistant,
417		Model:    a.largeModel.Model.Model(),
418		Provider: a.largeModel.Model.Provider(),
419	})
420	if err != nil {
421		return err
422	}
423
424	resp, err := agent.Stream(ctx, ai.AgentStreamCall{
425		Prompt:   "Provide a detailed summary of our conversation above.",
426		Messages: aiMsgs,
427		OnReasoningDelta: func(id string, text string) error {
428			summaryMessage.AppendReasoningContent(text)
429			return a.messages.Update(ctx, summaryMessage)
430		},
431		OnReasoningEnd: func(id string, reasoning ai.ReasoningContent) error {
432			// handle anthropic signature
433			if anthropicData, ok := reasoning.ProviderMetadata["anthropic"]; ok {
434				if signature, ok := anthropicData.(*anthropic.ReasoningOptionMetadata); ok && signature.Signature != "" {
435					summaryMessage.AppendReasoningSignature(signature.Signature)
436				}
437			}
438			summaryMessage.FinishThinking()
439			return a.messages.Update(ctx, summaryMessage)
440		},
441		OnTextDelta: func(id, text string) error {
442			summaryMessage.AppendContent(text)
443			return a.messages.Update(ctx, summaryMessage)
444		},
445	})
446	if err != nil {
447		return err
448	}
449
450	summaryMessage.AddFinish(message.FinishReasonEndTurn, "", "")
451	err = a.messages.Update(genCtx, summaryMessage)
452	if err != nil {
453		return err
454	}
455
456	a.updateSessionUsage(a.largeModel, &currentSession, resp.TotalUsage)
457
458	// just in case get just the last usage
459	usage := resp.Response.Usage
460	currentSession.SummaryMessageID = summaryMessage.ID
461	currentSession.CompletionTokens = usage.OutputTokens
462	currentSession.PromptTokens = 0
463	_, err = a.sessions.Save(genCtx, currentSession)
464	return err
465}
466
467func (a *sessionAgent) getCacheControlOptions() ai.ProviderOptions {
468	return ai.ProviderOptions{
469		anthropic.Name: &anthropic.ProviderCacheControlOptions{
470			CacheControl: anthropic.CacheControl{Type: "ephemeral"},
471		},
472	}
473}
474
475func (a *sessionAgent) createUserMessage(ctx context.Context, call SessionAgentCall) (message.Message, error) {
476	var attachmentParts []message.ContentPart
477	for _, attachment := range call.Attachments {
478		attachmentParts = append(attachmentParts, message.BinaryContent{Path: attachment.FilePath, MIMEType: attachment.MimeType, Data: attachment.Content})
479	}
480	parts := []message.ContentPart{message.TextContent{Text: call.Prompt}}
481	parts = append(parts, attachmentParts...)
482	msg, err := a.messages.Create(ctx, call.SessionID, message.CreateMessageParams{
483		Role:  message.User,
484		Parts: parts,
485	})
486	if err != nil {
487		return message.Message{}, fmt.Errorf("failed to create user message: %w", err)
488	}
489	return msg, nil
490}
491
492func (a *sessionAgent) preparePrompt(msgs []message.Message, attachments ...message.Attachment) ([]ai.Message, []ai.FilePart) {
493	var history []ai.Message
494	for _, m := range msgs {
495		if len(m.Parts) == 0 {
496			continue
497		}
498		// Assistant message without content or tool calls (cancelled before it returned anything)
499		if m.Role == message.Assistant && len(m.ToolCalls()) == 0 && m.Content().Text == "" && m.ReasoningContent().String() == "" {
500			continue
501		}
502		history = append(history, m.ToAIMessage()...)
503	}
504
505	var files []ai.FilePart
506	for _, attachment := range attachments {
507		files = append(files, ai.FilePart{
508			Filename:  attachment.FileName,
509			Data:      attachment.Content,
510			MediaType: attachment.MimeType,
511		})
512	}
513
514	return history, files
515}
516
517func (a *sessionAgent) getSessionMessages(ctx context.Context, session session.Session) ([]message.Message, error) {
518	msgs, err := a.messages.List(ctx, session.ID)
519	if err != nil {
520		return nil, fmt.Errorf("failed to list messages: %w", err)
521	}
522
523	if session.SummaryMessageID != "" {
524		summaryMsgInex := -1
525		for i, msg := range msgs {
526			if msg.ID == session.SummaryMessageID {
527				summaryMsgInex = i
528				break
529			}
530		}
531		if summaryMsgInex != -1 {
532			msgs = msgs[summaryMsgInex:]
533			msgs[0].Role = message.User
534		}
535	}
536	return msgs, nil
537}
538
539func (a *sessionAgent) generateTitle(ctx context.Context, session *session.Session, prompt string) {
540	if prompt == "" {
541		return
542	}
543
544	agent := ai.NewAgent(a.smallModel.Model,
545		ai.WithSystemPrompt(string(titlePrompt)),
546		ai.WithMaxOutputTokens(40),
547	)
548
549	resp, err := agent.Stream(ctx, ai.AgentStreamCall{
550		Prompt: fmt.Sprintf("Generate a concise title for the following content:\n\n%s", prompt),
551	})
552	if err != nil {
553		slog.Error("error generating title", "err", err)
554		return
555	}
556
557	title := resp.Response.Content.Text()
558
559	title = strings.ReplaceAll(title, "\n", " ")
560
561	// remove thinking tags if present
562	if idx := strings.Index(title, "</think>"); idx > 0 {
563		title = title[idx+len("</think>"):]
564	}
565
566	title = strings.TrimSpace(title)
567	if title == "" {
568		slog.Warn("failed to generate title", "warn", "empty title")
569		return
570	}
571
572	session.Title = title
573	a.updateSessionUsage(a.smallModel, session, resp.TotalUsage)
574	_, saveErr := a.sessions.Save(ctx, *session)
575	if saveErr != nil {
576		slog.Error("failed to save session title & usage", "error", saveErr)
577		return
578	}
579}
580
581func (a *sessionAgent) updateSessionUsage(model Model, session *session.Session, usage ai.Usage) {
582	modelConfig := model.CatwalkCfg
583	cost := modelConfig.CostPer1MInCached/1e6*float64(usage.CacheCreationTokens) +
584		modelConfig.CostPer1MOutCached/1e6*float64(usage.CacheReadTokens) +
585		modelConfig.CostPer1MIn/1e6*float64(usage.InputTokens) +
586		modelConfig.CostPer1MOut/1e6*float64(usage.OutputTokens)
587	session.Cost += cost
588	session.CompletionTokens = usage.OutputTokens + usage.CacheReadTokens
589	session.PromptTokens = usage.InputTokens + usage.CacheCreationTokens
590}
591
592func (a *sessionAgent) Cancel(sessionID string) {
593	// Cancel regular requests
594	if cancel, ok := a.activeRequests.Take(sessionID); ok && cancel != nil {
595		slog.Info("Request cancellation initiated", "session_id", sessionID)
596		cancel()
597	}
598
599	// Also check for summarize requests
600	if cancel, ok := a.activeRequests.Take(sessionID + "-summarize"); ok && cancel != nil {
601		slog.Info("Summarize cancellation initiated", "session_id", sessionID)
602		cancel()
603	}
604
605	if a.QueuedPrompts(sessionID) > 0 {
606		slog.Info("Clearing queued prompts", "session_id", sessionID)
607		a.messageQueue.Del(sessionID)
608	}
609}
610
611func (a *sessionAgent) ClearQueue(sessionID string) {
612	if a.QueuedPrompts(sessionID) > 0 {
613		slog.Info("Clearing queued prompts", "session_id", sessionID)
614		a.messageQueue.Del(sessionID)
615	}
616}
617
618func (a *sessionAgent) CancelAll() {
619	if !a.IsBusy() {
620		return
621	}
622	for key := range a.activeRequests.Seq2() {
623		a.Cancel(key) // key is sessionID
624	}
625
626	timeout := time.After(5 * time.Second)
627	for a.IsBusy() {
628		select {
629		case <-timeout:
630			return
631		default:
632			time.Sleep(200 * time.Millisecond)
633		}
634	}
635}
636
637func (a *sessionAgent) IsBusy() bool {
638	var busy bool
639	for cancelFunc := range a.activeRequests.Seq() {
640		if cancelFunc != nil {
641			busy = true
642			break
643		}
644	}
645	return busy
646}
647
648func (a *sessionAgent) IsSessionBusy(sessionID string) bool {
649	_, busy := a.activeRequests.Get(sessionID)
650	return busy
651}
652
653func (a *sessionAgent) QueuedPrompts(sessionID string) int {
654	l, ok := a.messageQueue.Get(sessionID)
655	if !ok {
656		return 0
657	}
658	return len(l)
659}
660
661func (a *sessionAgent) SetModels(large Model, small Model) {
662	a.largeModel = large
663	a.smallModel = small
664}
665
666func (a *sessionAgent) SetTools(tools []ai.AgentTool) {
667	a.tools = tools
668}
669
670func (a *sessionAgent) Model() Model {
671	return a.largeModel
672}