agent.go

  1package agent
  2
  3import (
  4	"context"
  5	_ "embed"
  6	"encoding/json"
  7	"errors"
  8	"fmt"
  9	"log/slog"
 10	"strings"
 11	"sync"
 12	"time"
 13
 14	"github.com/charmbracelet/catwalk/pkg/catwalk"
 15	"github.com/charmbracelet/crush/internal/agent/tools"
 16	"github.com/charmbracelet/crush/internal/config"
 17	"github.com/charmbracelet/crush/internal/csync"
 18	"github.com/charmbracelet/crush/internal/message"
 19	"github.com/charmbracelet/crush/internal/permission"
 20	"github.com/charmbracelet/crush/internal/session"
 21	"github.com/charmbracelet/fantasy/ai"
 22	"github.com/charmbracelet/fantasy/anthropic"
 23)
 24
 25//go:embed templates/title.md
 26var titlePrompt []byte
 27
 28//go:embed templates/summary.md
 29var summaryPrompt []byte
 30
// SessionAgentCall describes one prompt submitted to a session agent.
type SessionAgentCall struct {
	// SessionID is the session the prompt belongs to. Run rejects an empty value.
	SessionID        string
	// Prompt is the user's text. Run rejects an empty value.
	Prompt           string
	// ProviderOptions is forwarded unchanged to the stream call.
	ProviderOptions  ai.ProviderOptions
	// Attachments are stored as binary parts on the user message and sent to
	// the model as files.
	Attachments      []message.Attachment
	// MaxOutputTokens is forwarded (by pointer) to the stream call.
	MaxOutputTokens  int64
	// The remaining fields are forwarded unchanged to the stream call.
	// NOTE(review): presumably nil means "use the provider default" — confirm
	// against the ai package.
	Temperature      *float64
	TopP             *float64
	TopK             *int64
	FrequencyPenalty *float64
	PresencePenalty  *float64
}
 43
// SessionAgent is the session-scoped agent API. The per-method notes describe
// the sessionAgent implementation found in this file.
type SessionAgent interface {
	// Run processes a prompt for a session. When the session already has an
	// active request the call is queued and (nil, nil) is returned.
	Run(context.Context, SessionAgentCall) (*ai.AgentResult, error)
	// SetModels replaces the large and the small model.
	SetModels(large Model, small Model)
	// SetTools replaces the tool set.
	SetTools(tools []ai.AgentTool)
	// Cancel cancels the active request of a session and drops its queued prompts.
	Cancel(sessionID string)
	// CancelAll cancels every active request and waits a bounded time for
	// them to finish.
	CancelAll()
	// IsSessionBusy reports whether the session has an active request.
	IsSessionBusy(sessionID string) bool
	// IsBusy reports whether any session has an active request.
	IsBusy() bool
	// QueuedPrompts returns how many prompts are queued for the session.
	QueuedPrompts(sessionID string) int
	// ClearQueue drops the prompts queued for the session.
	ClearQueue(sessionID string)
	// Summarize generates a summary message for the session whose ID is given
	// as the string argument.
	Summarize(context.Context, string) error
	// Model returns the large model.
	Model() Model
}
 57
// Model bundles a language model with the configuration this package reads
// alongside it.
type Model struct {
	// Model is the language model handed to ai.NewAgent.
	Model      ai.LanguageModel
	// CatwalkCfg supplies the per-token prices used for cost accounting as
	// well as CanReason and DefaultMaxTokens used for title generation.
	CatwalkCfg catwalk.Model
	// ModelCfg supplies the model and provider names recorded on assistant
	// messages created by Run.
	ModelCfg   config.SelectedModel
}
 63
// sessionAgent is the SessionAgent implementation returned by NewSessionAgent.
type sessionAgent struct {
	// largeModel serves Run and Summarize; smallModel serves title generation.
	largeModel   Model
	smallModel   Model
	// systemPrompt is the system prompt used by Run.
	systemPrompt string
	// tools are the tools offered to the model by Run.
	tools        []ai.AgentTool
	// sessions and messages are the persistence services for sessions and
	// their messages.
	sessions     session.Service
	messages     message.Service

	// messageQueue holds, per session ID, the calls that arrived while the
	// session was busy.
	messageQueue   *csync.Map[string, []SessionAgentCall]
	// activeRequests holds, per session ID, the cancel function of the
	// request currently in flight; presence of a key marks the session busy.
	activeRequests *csync.Map[string, context.CancelFunc]
}
 75
// SessionAgentOption is a function that modifies a sessionAgent.
// NOTE(review): nothing in this file accepts or applies it — presumably
// reserved for functional options on the constructor; confirm or remove.
type SessionAgentOption func(*sessionAgent)
 77
 78func NewSessionAgent(
 79	largeModel Model,
 80	smallModel Model,
 81	systemPrompt string,
 82	sessions session.Service,
 83	messages message.Service,
 84	tools ...ai.AgentTool,
 85) SessionAgent {
 86	return &sessionAgent{
 87		largeModel:     largeModel,
 88		smallModel:     smallModel,
 89		systemPrompt:   systemPrompt,
 90		sessions:       sessions,
 91		messages:       messages,
 92		tools:          tools,
 93		messageQueue:   csync.NewMap[string, []SessionAgentCall](),
 94		activeRequests: csync.NewMap[string, context.CancelFunc](),
 95	}
 96}
 97
 98func (a *sessionAgent) Run(ctx context.Context, call SessionAgentCall) (*ai.AgentResult, error) {
 99	if call.Prompt == "" {
100		return nil, ErrEmptyPrompt
101	}
102	if call.SessionID == "" {
103		return nil, ErrSessionMissing
104	}
105
106	// Queue the message if busy
107	if a.IsSessionBusy(call.SessionID) {
108		existing, ok := a.messageQueue.Get(call.SessionID)
109		if !ok {
110			existing = []SessionAgentCall{}
111		}
112		existing = append(existing, call)
113		a.messageQueue.Set(call.SessionID, existing)
114		return nil, nil
115	}
116
117	if len(a.tools) > 0 {
118		// add anthropic caching to the last tool
119		a.tools[len(a.tools)-1].SetProviderOptions(a.getCacheControlOptions())
120	}
121
122	agent := ai.NewAgent(
123		a.largeModel.Model,
124		ai.WithSystemPrompt(a.systemPrompt),
125		ai.WithTools(a.tools...),
126	)
127
128	sessionLock := sync.Mutex{}
129	currentSession, err := a.sessions.Get(ctx, call.SessionID)
130	if err != nil {
131		return nil, fmt.Errorf("failed to get session: %w", err)
132	}
133
134	msgs, err := a.getSessionMessages(ctx, currentSession)
135	if err != nil {
136		return nil, fmt.Errorf("failed to get session messages: %w", err)
137	}
138
139	var wg sync.WaitGroup
140	// Generate title if first message
141	if len(msgs) == 0 {
142		wg.Go(func() {
143			sessionLock.Lock()
144			a.generateTitle(ctx, &currentSession, call.Prompt)
145			sessionLock.Unlock()
146		})
147	}
148
149	// Add the user message to the session
150	_, err = a.createUserMessage(ctx, call)
151	if err != nil {
152		return nil, err
153	}
154
155	// add the session to the context
156	ctx = context.WithValue(ctx, tools.SessionIDContextKey, call.SessionID)
157
158	genCtx, cancel := context.WithCancel(ctx)
159	a.activeRequests.Set(call.SessionID, cancel)
160
161	defer cancel()
162	defer a.activeRequests.Del(call.SessionID)
163
164	history, files := a.preparePrompt(msgs, call.Attachments...)
165
166	var currentAssistant *message.Message
167	result, err := agent.Stream(genCtx, ai.AgentStreamCall{
168		Prompt:           call.Prompt,
169		Files:            files,
170		Messages:         history,
171		ProviderOptions:  call.ProviderOptions,
172		MaxOutputTokens:  &call.MaxOutputTokens,
173		TopP:             call.TopP,
174		Temperature:      call.Temperature,
175		PresencePenalty:  call.PresencePenalty,
176		TopK:             call.TopK,
177		FrequencyPenalty: call.FrequencyPenalty,
178		// Before each step create the new assistant message
179		PrepareStep: func(callContext context.Context, options ai.PrepareStepFunctionOptions) (_ context.Context, prepared ai.PrepareStepResult, err error) {
180			var assistantMsg message.Message
181			assistantMsg, err = a.messages.Create(callContext, call.SessionID, message.CreateMessageParams{
182				Role:     message.Assistant,
183				Parts:    []message.ContentPart{},
184				Model:    a.largeModel.ModelCfg.Model,
185				Provider: a.largeModel.ModelCfg.Provider,
186			})
187			if err != nil {
188				return callContext, prepared, err
189			}
190
191			callContext = context.WithValue(callContext, tools.MessageIDContextKey, assistantMsg.ID)
192
193			currentAssistant = &assistantMsg
194
195			prepared.Messages = options.Messages
196			// reset all cached items
197			for i := range prepared.Messages {
198				prepared.Messages[i].ProviderOptions = nil
199			}
200
201			queuedCalls, _ := a.messageQueue.Get(call.SessionID)
202			a.messageQueue.Del(call.SessionID)
203			for _, queued := range queuedCalls {
204				userMessage, createErr := a.createUserMessage(callContext, queued)
205				if createErr != nil {
206					return callContext, prepared, createErr
207				}
208				prepared.Messages = append(prepared.Messages, userMessage.ToAIMessage()...)
209			}
210
211			lastSystemRoleInx := 0
212			systemMessageUpdated := false
213			for i, msg := range prepared.Messages {
214				// only add cache control to the last message
215				if msg.Role == ai.MessageRoleSystem {
216					lastSystemRoleInx = i
217				} else if !systemMessageUpdated {
218					prepared.Messages[lastSystemRoleInx].ProviderOptions = a.getCacheControlOptions()
219					systemMessageUpdated = true
220				}
221				// than add cache control to the last 2 messages
222				if i > len(prepared.Messages)-3 {
223					prepared.Messages[i].ProviderOptions = a.getCacheControlOptions()
224				}
225			}
226			return callContext, prepared, err
227		},
228		OnReasoningDelta: func(id string, text string) error {
229			currentAssistant.AppendReasoningContent(text)
230			return a.messages.Update(genCtx, *currentAssistant)
231		},
232		OnReasoningEnd: func(id string, reasoning ai.ReasoningContent) error {
233			// handle anthropic signature
234			if anthropicData, ok := reasoning.ProviderMetadata[anthropic.Name]; ok {
235				if reasoning, ok := anthropicData.(*anthropic.ReasoningOptionMetadata); ok {
236					currentAssistant.AppendReasoningSignature(reasoning.Signature)
237				}
238			}
239			currentAssistant.FinishThinking()
240			return a.messages.Update(genCtx, *currentAssistant)
241		},
242		OnTextDelta: func(id string, text string) error {
243			currentAssistant.AppendContent(text)
244			return a.messages.Update(genCtx, *currentAssistant)
245		},
246		OnToolInputStart: func(id string, toolName string) error {
247			toolCall := message.ToolCall{
248				ID:               id,
249				Name:             toolName,
250				ProviderExecuted: false,
251				Finished:         false,
252			}
253			currentAssistant.AddToolCall(toolCall)
254			return a.messages.Update(genCtx, *currentAssistant)
255		},
256		OnRetry: func(err *ai.APICallError, delay time.Duration) {
257			// TODO: implement
258		},
259		OnToolCall: func(tc ai.ToolCallContent) error {
260			toolCall := message.ToolCall{
261				ID:               tc.ToolCallID,
262				Name:             tc.ToolName,
263				Input:            tc.Input,
264				ProviderExecuted: false,
265				Finished:         true,
266			}
267			currentAssistant.AddToolCall(toolCall)
268			return a.messages.Update(genCtx, *currentAssistant)
269		},
270		OnToolResult: func(result ai.ToolResultContent) error {
271			var resultContent string
272			isError := false
273			switch result.Result.GetType() {
274			case ai.ToolResultContentTypeText:
275				r, ok := ai.AsToolResultOutputType[ai.ToolResultOutputContentText](result.Result)
276				if ok {
277					resultContent = r.Text
278				}
279			case ai.ToolResultContentTypeError:
280				r, ok := ai.AsToolResultOutputType[ai.ToolResultOutputContentError](result.Result)
281				if ok {
282					isError = true
283					resultContent = r.Error.Error()
284				}
285			case ai.ToolResultContentTypeMedia:
286				// TODO: handle this message type
287			}
288			toolResult := message.ToolResult{
289				ToolCallID: result.ToolCallID,
290				Name:       result.ToolName,
291				Content:    resultContent,
292				IsError:    isError,
293				Metadata:   result.ClientMetadata,
294			}
295			_, err := a.messages.Create(context.Background(), currentAssistant.SessionID, message.CreateMessageParams{
296				Role: message.Tool,
297				Parts: []message.ContentPart{
298					toolResult,
299				},
300			})
301			if err != nil {
302				return err
303			}
304			return a.messages.Update(genCtx, *currentAssistant)
305		},
306		OnStepFinish: func(stepResult ai.StepResult) error {
307			finishReason := message.FinishReasonUnknown
308			switch stepResult.FinishReason {
309			case ai.FinishReasonLength:
310				finishReason = message.FinishReasonMaxTokens
311			case ai.FinishReasonStop:
312				finishReason = message.FinishReasonEndTurn
313			case ai.FinishReasonToolCalls:
314				finishReason = message.FinishReasonToolUse
315			}
316			currentAssistant.AddFinish(finishReason, "", "")
317			a.updateSessionUsage(a.largeModel, &currentSession, stepResult.Usage)
318			sessionLock.Lock()
319			_, sessionErr := a.sessions.Save(genCtx, currentSession)
320			sessionLock.Unlock()
321			if sessionErr != nil {
322				return sessionErr
323			}
324			return a.messages.Update(genCtx, *currentAssistant)
325		},
326	})
327	if err != nil {
328		isCancelErr := errors.Is(err, context.Canceled)
329		isPermissionErr := errors.Is(err, permission.ErrorPermissionDenied)
330		if currentAssistant == nil {
331			return result, err
332		}
333		toolCalls := currentAssistant.ToolCalls()
334		toolResults := currentAssistant.ToolResults()
335		// INFO: we use the parent context here because the genCtx has been cancelled
336		msgs, createErr := a.messages.List(ctx, currentSession.ID)
337		if createErr != nil {
338			return nil, createErr
339		}
340		for _, tc := range toolCalls {
341			if !tc.Finished {
342				tc.Finished = true
343				tc.Input = "{}"
344			}
345			currentAssistant.AddToolCall(tc)
346
347			found := false
348			for _, msg := range msgs {
349				if msg.Role == message.Tool {
350					for _, tr := range toolResults {
351						if tr.ToolCallID == tc.ID {
352							found = true
353							break
354						}
355					}
356				}
357				if found {
358					break
359				}
360			}
361			if !found {
362				content := "There was an error while executing the tool"
363				if isCancelErr {
364					content = "Tool execution canceled by user"
365				} else if isPermissionErr {
366					content = "Permission denied"
367				}
368				toolResult := message.ToolResult{
369					ToolCallID: tc.ID,
370					Name:       tc.Name,
371					Content:    content,
372					IsError:    true,
373				}
374				_, createErr = a.messages.Create(context.Background(), currentAssistant.SessionID, message.CreateMessageParams{
375					Role: message.Tool,
376					Parts: []message.ContentPart{
377						toolResult,
378					},
379				})
380				if createErr != nil {
381					return nil, createErr
382				}
383			}
384		}
385		if isCancelErr {
386			currentAssistant.AddFinish(message.FinishReasonCanceled, "Request cancelled", "")
387		} else if isPermissionErr {
388			currentAssistant.AddFinish(message.FinishReasonPermissionDenied, "Permission denied", "")
389		} else {
390			currentAssistant.AddFinish(message.FinishReasonError, "API Error", err.Error())
391		}
392		// INFO: we use the parent context here because the genCtx has been cancelled
393		updateErr := a.messages.Update(ctx, *currentAssistant)
394		if updateErr != nil {
395			return nil, updateErr
396		}
397		return nil, err
398	}
399	wg.Wait()
400
401	queuedMessages, ok := a.messageQueue.Get(call.SessionID)
402	if !ok || len(queuedMessages) == 0 {
403		return result, err
404	}
405	// there are queued messages restart the loop
406	firstQueuedMessage := queuedMessages[0]
407	a.messageQueue.Set(call.SessionID, queuedMessages[1:])
408	return a.Run(genCtx, firstQueuedMessage)
409}
410
411func (a *sessionAgent) Summarize(ctx context.Context, sessionID string) error {
412	if a.IsSessionBusy(sessionID) {
413		return ErrSessionBusy
414	}
415
416	currentSession, err := a.sessions.Get(ctx, sessionID)
417	if err != nil {
418		return fmt.Errorf("failed to get session: %w", err)
419	}
420	msgs, err := a.getSessionMessages(ctx, currentSession)
421	if err != nil {
422		return err
423	}
424	if len(msgs) == 0 {
425		// nothing to summarize
426		return nil
427	}
428
429	aiMsgs, _ := a.preparePrompt(msgs)
430
431	genCtx, cancel := context.WithCancel(ctx)
432	a.activeRequests.Set(sessionID, cancel)
433	defer a.activeRequests.Del(sessionID)
434	defer cancel()
435
436	agent := ai.NewAgent(a.largeModel.Model,
437		ai.WithSystemPrompt(string(summaryPrompt)),
438	)
439	summaryMessage, err := a.messages.Create(ctx, sessionID, message.CreateMessageParams{
440		Role:     message.Assistant,
441		Model:    a.largeModel.Model.Model(),
442		Provider: a.largeModel.Model.Provider(),
443	})
444	if err != nil {
445		return err
446	}
447
448	resp, err := agent.Stream(ctx, ai.AgentStreamCall{
449		Prompt:   "Provide a detailed summary of our conversation above.",
450		Messages: aiMsgs,
451		OnReasoningDelta: func(id string, text string) error {
452			summaryMessage.AppendReasoningContent(text)
453			return a.messages.Update(ctx, summaryMessage)
454		},
455		OnReasoningEnd: func(id string, reasoning ai.ReasoningContent) error {
456			// handle anthropic signature
457			if anthropicData, ok := reasoning.ProviderMetadata["anthropic"]; ok {
458				if signature, ok := anthropicData.(*anthropic.ReasoningOptionMetadata); ok && signature.Signature != "" {
459					summaryMessage.AppendReasoningSignature(signature.Signature)
460				}
461			}
462			summaryMessage.FinishThinking()
463			return a.messages.Update(ctx, summaryMessage)
464		},
465		OnTextDelta: func(id, text string) error {
466			summaryMessage.AppendContent(text)
467			return a.messages.Update(ctx, summaryMessage)
468		},
469	})
470	if err != nil {
471		return err
472	}
473
474	summaryMessage.AddFinish(message.FinishReasonEndTurn, "", "")
475	err = a.messages.Update(genCtx, summaryMessage)
476	if err != nil {
477		return err
478	}
479
480	a.updateSessionUsage(a.largeModel, &currentSession, resp.TotalUsage)
481
482	// just in case get just the last usage
483	usage := resp.Response.Usage
484	currentSession.SummaryMessageID = summaryMessage.ID
485	currentSession.CompletionTokens = usage.OutputTokens
486	currentSession.PromptTokens = 0
487	_, err = a.sessions.Save(genCtx, currentSession)
488	return err
489}
490
491func (a *sessionAgent) getCacheControlOptions() ai.ProviderOptions {
492	return ai.ProviderOptions{
493		anthropic.Name: &anthropic.ProviderCacheControlOptions{
494			CacheControl: anthropic.CacheControl{Type: "ephemeral"},
495		},
496	}
497}
498
499func (a *sessionAgent) createUserMessage(ctx context.Context, call SessionAgentCall) (message.Message, error) {
500	var attachmentParts []message.ContentPart
501	for _, attachment := range call.Attachments {
502		attachmentParts = append(attachmentParts, message.BinaryContent{Path: attachment.FilePath, MIMEType: attachment.MimeType, Data: attachment.Content})
503	}
504	parts := []message.ContentPart{message.TextContent{Text: call.Prompt}}
505	parts = append(parts, attachmentParts...)
506	msg, err := a.messages.Create(ctx, call.SessionID, message.CreateMessageParams{
507		Role:  message.User,
508		Parts: parts,
509	})
510	if err != nil {
511		return message.Message{}, fmt.Errorf("failed to create user message: %w", err)
512	}
513	return msg, nil
514}
515
516func (a *sessionAgent) preparePrompt(msgs []message.Message, attachments ...message.Attachment) ([]ai.Message, []ai.FilePart) {
517	var history []ai.Message
518	for _, m := range msgs {
519		if len(m.Parts) == 0 {
520			continue
521		}
522		// Assistant message without content or tool calls (cancelled before it returned anything)
523		if m.Role == message.Assistant && len(m.ToolCalls()) == 0 && m.Content().Text == "" && m.ReasoningContent().String() == "" {
524			continue
525		}
526		history = append(history, m.ToAIMessage()...)
527	}
528
529	var files []ai.FilePart
530	for _, attachment := range attachments {
531		files = append(files, ai.FilePart{
532			Filename:  attachment.FileName,
533			Data:      attachment.Content,
534			MediaType: attachment.MimeType,
535		})
536	}
537
538	return history, files
539}
540
541func (a *sessionAgent) getSessionMessages(ctx context.Context, session session.Session) ([]message.Message, error) {
542	msgs, err := a.messages.List(ctx, session.ID)
543	if err != nil {
544		return nil, fmt.Errorf("failed to list messages: %w", err)
545	}
546
547	if session.SummaryMessageID != "" {
548		summaryMsgInex := -1
549		for i, msg := range msgs {
550			if msg.ID == session.SummaryMessageID {
551				summaryMsgInex = i
552				break
553			}
554		}
555		if summaryMsgInex != -1 {
556			msgs = msgs[summaryMsgInex:]
557			msgs[0].Role = message.User
558		}
559	}
560	return msgs, nil
561}
562
563func (a *sessionAgent) generateTitle(ctx context.Context, session *session.Session, prompt string) {
564	if prompt == "" {
565		return
566	}
567
568	var maxOutput int64 = 40
569	if a.smallModel.CatwalkCfg.CanReason {
570		maxOutput = a.smallModel.CatwalkCfg.DefaultMaxTokens
571	}
572
573	agent := ai.NewAgent(a.smallModel.Model,
574		ai.WithSystemPrompt(string(titlePrompt)+"\n /no_think"),
575		ai.WithMaxOutputTokens(maxOutput),
576	)
577
578	resp, err := agent.Stream(ctx, ai.AgentStreamCall{
579		Prompt: fmt.Sprintf("Generate a concise title for the following content:\n\n%s\n <think>\n\n</think>", prompt),
580	})
581	if err != nil {
582		slog.Error("error generating title", "err", err)
583		return
584	}
585
586	data, _ := json.Marshal(resp)
587
588	slog.Info("Title Response")
589	slog.Info(string(data))
590	title := resp.Response.Content.Text()
591
592	title = strings.ReplaceAll(title, "\n", " ")
593	slog.Info(title)
594
595	// remove thinking tags if present
596	if idx := strings.Index(title, "</think>"); idx > 0 {
597		title = title[idx+len("</think>"):]
598	}
599
600	title = strings.TrimSpace(title)
601	if title == "" {
602		slog.Warn("failed to generate title", "warn", "empty title")
603		return
604	}
605
606	session.Title = title
607	a.updateSessionUsage(a.smallModel, session, resp.TotalUsage)
608	_, saveErr := a.sessions.Save(ctx, *session)
609	if saveErr != nil {
610		slog.Error("failed to save session title & usage", "error", saveErr)
611		return
612	}
613}
614
615func (a *sessionAgent) updateSessionUsage(model Model, session *session.Session, usage ai.Usage) {
616	modelConfig := model.CatwalkCfg
617	cost := modelConfig.CostPer1MInCached/1e6*float64(usage.CacheCreationTokens) +
618		modelConfig.CostPer1MOutCached/1e6*float64(usage.CacheReadTokens) +
619		modelConfig.CostPer1MIn/1e6*float64(usage.InputTokens) +
620		modelConfig.CostPer1MOut/1e6*float64(usage.OutputTokens)
621	session.Cost += cost
622	session.CompletionTokens = usage.OutputTokens + usage.CacheReadTokens
623	session.PromptTokens = usage.InputTokens + usage.CacheCreationTokens
624}
625
626func (a *sessionAgent) Cancel(sessionID string) {
627	// Cancel regular requests
628	if cancel, ok := a.activeRequests.Take(sessionID); ok && cancel != nil {
629		slog.Info("Request cancellation initiated", "session_id", sessionID)
630		cancel()
631	}
632
633	// Also check for summarize requests
634	if cancel, ok := a.activeRequests.Take(sessionID + "-summarize"); ok && cancel != nil {
635		slog.Info("Summarize cancellation initiated", "session_id", sessionID)
636		cancel()
637	}
638
639	if a.QueuedPrompts(sessionID) > 0 {
640		slog.Info("Clearing queued prompts", "session_id", sessionID)
641		a.messageQueue.Del(sessionID)
642	}
643}
644
645func (a *sessionAgent) ClearQueue(sessionID string) {
646	if a.QueuedPrompts(sessionID) > 0 {
647		slog.Info("Clearing queued prompts", "session_id", sessionID)
648		a.messageQueue.Del(sessionID)
649	}
650}
651
652func (a *sessionAgent) CancelAll() {
653	if !a.IsBusy() {
654		return
655	}
656	for key := range a.activeRequests.Seq2() {
657		a.Cancel(key) // key is sessionID
658	}
659
660	timeout := time.After(5 * time.Second)
661	for a.IsBusy() {
662		select {
663		case <-timeout:
664			return
665		default:
666			time.Sleep(200 * time.Millisecond)
667		}
668	}
669}
670
671func (a *sessionAgent) IsBusy() bool {
672	var busy bool
673	for cancelFunc := range a.activeRequests.Seq() {
674		if cancelFunc != nil {
675			busy = true
676			break
677		}
678	}
679	return busy
680}
681
682func (a *sessionAgent) IsSessionBusy(sessionID string) bool {
683	_, busy := a.activeRequests.Get(sessionID)
684	return busy
685}
686
687func (a *sessionAgent) QueuedPrompts(sessionID string) int {
688	l, ok := a.messageQueue.Get(sessionID)
689	if !ok {
690		return 0
691	}
692	return len(l)
693}
694
695func (a *sessionAgent) SetModels(large Model, small Model) {
696	a.largeModel = large
697	a.smallModel = small
698}
699
// SetTools replaces the agent's tools. The slice is stored as given, without
// copying.
func (a *sessionAgent) SetTools(tools []ai.AgentTool) {
	a.tools = tools
}
703
// Model returns the large model.
func (a *sessionAgent) Model() Model {
	return a.largeModel
}