agent.go

  1package agent
  2
  3import (
  4	"context"
  5	_ "embed"
  6	"errors"
  7	"fmt"
  8	"log/slog"
  9	"strings"
 10	"time"
 11
 12	"github.com/charmbracelet/catwalk/pkg/catwalk"
 13	"github.com/charmbracelet/crush/internal/csync"
 14	"github.com/charmbracelet/crush/internal/llm/tools"
 15	"github.com/charmbracelet/crush/internal/message"
 16	"github.com/charmbracelet/crush/internal/permission"
 17	"github.com/charmbracelet/crush/internal/session"
 18	"github.com/charmbracelet/fantasy/ai"
 19	"github.com/charmbracelet/fantasy/anthropic"
 20)
 21
 22//go:embed templates/title.md
 23var titlePrompt []byte
 24
 25//go:embed templates/summary.md
 26var summaryPrompt []byte
 27
 28type SessionAgentCall struct {
 29	SessionID        string
 30	Prompt           string
 31	ProviderOptions  ai.ProviderOptions
 32	Attachments      []message.Attachment
 33	MaxOutputTokens  int64
 34	Temperature      *float64
 35	TopP             *float64
 36	TopK             *int64
 37	FrequencyPenalty *float64
 38	PresencePenalty  *float64
 39}
 40
 41type SessionAgent interface {
 42	Run(context.Context, SessionAgentCall) (*ai.AgentResult, error)
 43	SetModels(large Model, small Model)
 44	SetTools(tools []ai.AgentTool)
 45	Cancel(sessionID string)
 46	CancelAll()
 47	IsSessionBusy(sessionID string) bool
 48	IsBusy() bool
 49	QueuedPrompts(sessionID string) int
 50	ClearQueue(sessionID string)
 51	Summarize(context.Context, string) error
 52}
 53
// Model pairs a language model client with its catalog configuration.
type Model struct {
	// model is the client used to build agents and to report the model and
	// provider names recorded on assistant messages.
	model ai.LanguageModel
	// config supplies the per-million-token prices read by updateSessionUsage.
	config catwalk.Model
}
 58
 59type sessionAgent struct {
 60	largeModel      Model
 61	smallModel      Model
 62	systemPrompt    string
 63	tools           []ai.AgentTool
 64	maxOutputTokens int64
 65	sessions        session.Service
 66	messages        message.Service
 67
 68	messageQueue   *csync.Map[string, []SessionAgentCall]
 69	activeRequests *csync.Map[string, context.CancelFunc]
 70}
 71
// SessionAgentOption is a function that mutates a sessionAgent in place.
// NOTE(review): NewSessionAgent in this file accepts no options, so nothing
// visible here applies one — confirm the intended use.
type SessionAgentOption func(*sessionAgent)
 73
 74func NewSessionAgent() SessionAgent {
 75	return &sessionAgent{}
 76}
 77
 78func (a *sessionAgent) Run(ctx context.Context, call SessionAgentCall) (*ai.AgentResult, error) {
 79	if call.Prompt == "" {
 80		return nil, ErrEmptyPrompt
 81	}
 82	if call.SessionID == "" {
 83		return nil, ErrSessionMissing
 84	}
 85
 86	// Queue the message if busy
 87	if a.IsSessionBusy(call.SessionID) {
 88		existing, ok := a.messageQueue.Get(call.SessionID)
 89		if !ok {
 90			existing = []SessionAgentCall{}
 91		}
 92		existing = append(existing, call)
 93		a.messageQueue.Set(call.SessionID, existing)
 94		return nil, nil
 95	}
 96
 97	if len(a.tools) > 0 {
 98		// add anthropic caching to the last tool
 99		a.tools[len(a.tools)-1].SetProviderOptions(a.getCacheControlOptions())
100	}
101
102	agent := ai.NewAgent(
103		a.largeModel.model,
104		ai.WithSystemPrompt(a.systemPrompt),
105		ai.WithTools(a.tools...),
106		ai.WithMaxOutputTokens(a.maxOutputTokens),
107	)
108
109	currentSession, err := a.sessions.Get(ctx, call.SessionID)
110	if err != nil {
111		return nil, fmt.Errorf("failed to get session: %w", err)
112	}
113
114	msgs, err := a.getSessionMessages(ctx, currentSession)
115	if err != nil {
116		return nil, fmt.Errorf("failed to get session messages: %w", err)
117	}
118
119	// Generate title if first message
120	if len(msgs) == 0 {
121		go a.generateTitle(ctx, currentSession, call.Prompt)
122	}
123
124	// Add the user message to the session
125	_, err = a.createUserMessage(ctx, call)
126	if err != nil {
127		return nil, err
128	}
129
130	// add the session to the context
131	ctx = context.WithValue(ctx, tools.SessionIDContextKey, call.SessionID)
132
133	genCtx, cancel := context.WithCancel(ctx)
134	a.activeRequests.Set(call.SessionID, cancel)
135
136	defer cancel()
137	defer a.activeRequests.Del(call.SessionID)
138
139	history, files := a.preparePrompt(msgs, call.Attachments...)
140
141	var currentAssistant *message.Message
142	result, err := agent.Stream(genCtx, ai.AgentStreamCall{
143		Prompt:           call.Prompt,
144		Files:            files,
145		Messages:         history,
146		ProviderOptions:  call.ProviderOptions,
147		MaxOutputTokens:  &call.MaxOutputTokens,
148		TopP:             call.TopP,
149		Temperature:      call.Temperature,
150		PresencePenalty:  call.PresencePenalty,
151		TopK:             call.TopK,
152		FrequencyPenalty: call.FrequencyPenalty,
153		// Before each step create the new assistant message
154		PrepareStep: func(options ai.PrepareStepFunctionOptions) (prepared ai.PrepareStepResult, err error) {
155			var assistantMsg message.Message
156			assistantMsg, err = a.messages.Create(genCtx, call.SessionID, message.CreateMessageParams{
157				Role:     message.Assistant,
158				Parts:    []message.ContentPart{},
159				Model:    a.largeModel.model.Model(),
160				Provider: a.largeModel.model.Provider(),
161			})
162			if err != nil {
163				return prepared, err
164			}
165
166			currentAssistant = &assistantMsg
167
168			prepared.Messages = options.Messages
169
170			queuedCalls, _ := a.messageQueue.Get(call.SessionID)
171			a.messageQueue.Del(call.SessionID)
172			for _, queued := range queuedCalls {
173				userMessage, createErr := a.createUserMessage(genCtx, queued)
174				if createErr != nil {
175					return prepared, createErr
176				}
177				prepared.Messages = append(prepared.Messages, userMessage.ToAIMessage()...)
178			}
179
180			lastSystemRoleInx := 0
181			systemMessageUpdated := false
182			for i, msg := range prepared.Messages {
183				// only add cache control to the last message
184				if msg.Role == ai.MessageRoleSystem {
185					lastSystemRoleInx = i
186				} else if !systemMessageUpdated {
187					prepared.Messages[lastSystemRoleInx].ProviderOptions = a.getCacheControlOptions()
188					systemMessageUpdated = true
189				}
190				// than add cache control to the last 2 messages
191				if i > len(msgs)-3 {
192					prepared.Messages[i].ProviderOptions = a.getCacheControlOptions()
193				}
194			}
195			return prepared, err
196		},
197		OnReasoningDelta: func(id string, text string) error {
198			currentAssistant.AppendReasoningContent(text)
199			return a.messages.Update(genCtx, *currentAssistant)
200		},
201		OnReasoningEnd: func(id string, reasoning ai.ReasoningContent) error {
202			// handle anthropic signature
203			if anthropicData, ok := reasoning.ProviderMetadata[anthropic.Name]; ok {
204				if reasoning, ok := anthropicData.(*anthropic.ReasoningOptionMetadata); ok {
205					currentAssistant.AppendReasoningSignature(reasoning.Signature)
206				}
207			}
208			currentAssistant.FinishThinking()
209			return a.messages.Update(genCtx, *currentAssistant)
210		},
211		OnTextDelta: func(id string, text string) error {
212			currentAssistant.AppendContent(text)
213			return a.messages.Update(genCtx, *currentAssistant)
214		},
215		OnToolInputStart: func(id string, toolName string) error {
216			toolCall := message.ToolCall{
217				ID:               id,
218				Name:             toolName,
219				ProviderExecuted: false,
220				Finished:         false,
221			}
222			currentAssistant.AddToolCall(toolCall)
223			return a.messages.Update(genCtx, *currentAssistant)
224		},
225		OnRetry: func(err *ai.APICallError, delay time.Duration) {
226			// TODO: implement
227		},
228		OnToolResult: func(result ai.ToolResultContent) error {
229			var resultContent string
230			isError := false
231			switch result.Result.GetType() {
232			case ai.ToolResultContentTypeText:
233				r, ok := ai.AsToolResultOutputType[ai.ToolResultOutputContentText](result.Result)
234				if ok {
235					resultContent = r.Text
236				}
237			case ai.ToolResultContentTypeError:
238				r, ok := ai.AsToolResultOutputType[ai.ToolResultOutputContentError](result.Result)
239				if ok {
240					isError = true
241					resultContent = r.Error.Error()
242				}
243			case ai.ToolResultContentTypeMedia:
244				// TODO: handle this message type
245			}
246			toolResult := message.ToolResult{
247				ToolCallID: result.ToolCallID,
248				Name:       result.ToolName,
249				Content:    resultContent,
250				IsError:    isError,
251				Metadata:   result.ClientMetadata,
252			}
253			a.messages.Create(context.Background(), currentAssistant.SessionID, message.CreateMessageParams{
254				Role: message.Tool,
255				Parts: []message.ContentPart{
256					toolResult,
257				},
258			})
259			return a.messages.Update(genCtx, *currentAssistant)
260		},
261		OnStepFinish: func(stepResult ai.StepResult) error {
262			finishReason := message.FinishReasonUnknown
263			switch stepResult.FinishReason {
264			case ai.FinishReasonLength:
265				finishReason = message.FinishReasonMaxTokens
266			case ai.FinishReasonStop:
267				finishReason = message.FinishReasonEndTurn
268			case ai.FinishReasonToolCalls:
269				finishReason = message.FinishReasonToolUse
270			}
271			currentAssistant.AddFinish(finishReason, "", "")
272			a.updateSessionUsage(a.largeModel, &currentSession, stepResult.Usage)
273			return a.messages.Update(genCtx, *currentAssistant)
274		},
275	})
276	if err != nil {
277		isCancelErr := errors.Is(err, context.Canceled)
278		isPermissionErr := errors.Is(err, permission.ErrorPermissionDenied)
279		if currentAssistant == nil {
280			return result, err
281		}
282		toolCalls := currentAssistant.ToolCalls()
283		toolResults := currentAssistant.ToolResults()
284		for _, tc := range toolCalls {
285			if !tc.Finished {
286				tc.Finished = true
287				tc.Input = "{}"
288			}
289			currentAssistant.AddToolCall(tc)
290			found := false
291			for _, tr := range toolResults {
292				if tr.ToolCallID == tc.ID {
293					found = true
294					break
295				}
296			}
297			if !found {
298				content := "There was an error while executing the tool"
299				if isCancelErr {
300					content = "Tool execution canceled by user"
301				} else if isPermissionErr {
302					content = "Permission denied"
303				}
304				currentAssistant.AddToolResult(message.ToolResult{
305					ToolCallID: tc.ID,
306					Name:       tc.Name,
307					Content:    content,
308					IsError:    true,
309				})
310			}
311		}
312		if isCancelErr {
313			currentAssistant.AddFinish(message.FinishReasonCanceled, "Request cancelled", "")
314		} else if isPermissionErr {
315			currentAssistant.AddFinish(message.FinishReasonPermissionDenied, "Permission denied", "")
316		} else {
317			currentAssistant.AddFinish(message.FinishReasonError, "API Error", err.Error())
318		}
319		// INFO: we use the parent context here because the genCtx might have been cancelled
320		updateErr := a.messages.Update(ctx, *currentAssistant)
321		if updateErr != nil {
322			return nil, updateErr
323		}
324	}
325	if err != nil {
326		return nil, err
327	}
328
329	queuedMessages, ok := a.messageQueue.Get(call.SessionID)
330	if !ok || len(queuedMessages) == 0 {
331		return result, err
332	}
333	// there are queued messages restart the loop
334	firstQueuedMessage := queuedMessages[0]
335	a.messageQueue.Set(call.SessionID, queuedMessages[1:])
336	return a.Run(genCtx, firstQueuedMessage)
337}
338
339func (a *sessionAgent) Summarize(ctx context.Context, sessionID string) error {
340	if a.IsSessionBusy(sessionID) {
341		return ErrSessionBusy
342	}
343
344	currentSession, err := a.sessions.Get(ctx, sessionID)
345	if err != nil {
346		return fmt.Errorf("failed to get session: %w", err)
347	}
348	msgs, err := a.getSessionMessages(ctx, currentSession)
349	if err != nil {
350		return err
351	}
352	if len(msgs) == 0 {
353		// nothing to summarize
354		return nil
355	}
356
357	aiMsgs, _ := a.preparePrompt(msgs)
358
359	genCtx, cancel := context.WithCancel(ctx)
360	a.activeRequests.Set(sessionID, cancel)
361	defer a.activeRequests.Del(sessionID)
362	defer cancel()
363
364	agent := ai.NewAgent(a.largeModel.model,
365		ai.WithSystemPrompt(string(summaryPrompt)),
366	)
367	summaryMessage, err := a.messages.Create(ctx, sessionID, message.CreateMessageParams{
368		Role:     message.Assistant,
369		Model:    a.largeModel.model.Model(),
370		Provider: a.largeModel.model.Provider(),
371	})
372	if err != nil {
373		return err
374	}
375
376	resp, err := agent.Stream(ctx, ai.AgentStreamCall{
377		Prompt:   "Provide a detailed summary of our conversation above.",
378		Messages: aiMsgs,
379		OnReasoningDelta: func(id string, text string) error {
380			summaryMessage.AppendReasoningContent(text)
381			return a.messages.Update(ctx, summaryMessage)
382		},
383		OnReasoningEnd: func(id string, reasoning ai.ReasoningContent) error {
384			// handle anthropic signature
385			if anthropicData, ok := reasoning.ProviderMetadata["anthropic"]; ok {
386				if signature, ok := anthropicData.(*anthropic.ReasoningOptionMetadata); ok && signature.Signature != "" {
387					summaryMessage.AppendReasoningSignature(signature.Signature)
388				}
389			}
390			summaryMessage.FinishThinking()
391			return a.messages.Update(ctx, summaryMessage)
392		},
393		OnTextDelta: func(id, text string) error {
394			summaryMessage.AppendContent(text)
395			return a.messages.Update(ctx, summaryMessage)
396		},
397	})
398	if err != nil {
399		return err
400	}
401
402	summaryMessage.AddFinish(message.FinishReasonEndTurn, "", "")
403	err = a.messages.Update(genCtx, summaryMessage)
404	if err != nil {
405		return err
406	}
407
408	a.updateSessionUsage(a.largeModel, &currentSession, resp.TotalUsage)
409
410	// just in case get just the last usage
411	usage := resp.Response.Usage
412	currentSession.SummaryMessageID = summaryMessage.ID
413	currentSession.CompletionTokens = usage.OutputTokens
414	currentSession.PromptTokens = 0
415	_, err = a.sessions.Save(genCtx, currentSession)
416	return err
417}
418
419func (a *sessionAgent) getCacheControlOptions() ai.ProviderOptions {
420	return ai.ProviderOptions{
421		anthropic.Name: &anthropic.ProviderCacheControlOptions{
422			CacheControl: anthropic.CacheControl{Type: "ephemeral"},
423		},
424	}
425}
426
427func (a *sessionAgent) createUserMessage(ctx context.Context, call SessionAgentCall) (message.Message, error) {
428	var attachmentParts []message.ContentPart
429	for _, attachment := range call.Attachments {
430		attachmentParts = append(attachmentParts, message.BinaryContent{Path: attachment.FilePath, MIMEType: attachment.MimeType, Data: attachment.Content})
431	}
432	parts := []message.ContentPart{message.TextContent{Text: call.Prompt}}
433	parts = append(parts, attachmentParts...)
434	msg, err := a.messages.Create(ctx, call.SessionID, message.CreateMessageParams{
435		Role:  message.User,
436		Parts: parts,
437	})
438	if err != nil {
439		return message.Message{}, fmt.Errorf("failed to create user message: %w", err)
440	}
441	return msg, nil
442}
443
444func (a *sessionAgent) preparePrompt(msgs []message.Message, attachments ...message.Attachment) ([]ai.Message, []ai.FilePart) {
445	var history []ai.Message
446	for _, m := range msgs {
447		if len(m.Parts) == 0 {
448			continue
449		}
450		// Assistant message without content or tool calls (cancelled before it returned anything)
451		if m.Role == message.Assistant && len(m.ToolCalls()) == 0 && m.Content().Text == "" && m.ReasoningContent().String() == "" {
452			continue
453		}
454		history = append(history, m.ToAIMessage()...)
455	}
456
457	var files []ai.FilePart
458	for _, attachment := range attachments {
459		files = append(files, ai.FilePart{
460			Filename:  attachment.FileName,
461			Data:      attachment.Content,
462			MediaType: attachment.MimeType,
463		})
464	}
465
466	return history, files
467}
468
469func (a *sessionAgent) getSessionMessages(ctx context.Context, session session.Session) ([]message.Message, error) {
470	msgs, err := a.messages.List(ctx, session.ID)
471	if err != nil {
472		return nil, fmt.Errorf("failed to list messages: %w", err)
473	}
474
475	if session.SummaryMessageID != "" {
476		summaryMsgInex := -1
477		for i, msg := range msgs {
478			if msg.ID == session.SummaryMessageID {
479				summaryMsgInex = i
480				break
481			}
482		}
483		if summaryMsgInex != -1 {
484			msgs = msgs[summaryMsgInex:]
485			msgs[0].Role = message.User
486		}
487	}
488	return msgs, nil
489}
490
491func (a *sessionAgent) generateTitle(ctx context.Context, session session.Session, prompt string) {
492	if prompt == "" {
493		return
494	}
495
496	agent := ai.NewAgent(a.smallModel.model,
497		ai.WithSystemPrompt(string(titlePrompt)),
498		ai.WithMaxOutputTokens(40),
499	)
500
501	resp, err := agent.Stream(ctx, ai.AgentStreamCall{
502		Prompt: fmt.Sprintf("Generate a concise title for the following content:\n\n%s", prompt),
503	})
504	if err != nil {
505		slog.Error("error generating title", "err", err)
506		return
507	}
508
509	title := resp.Response.Content.Text()
510
511	title = strings.ReplaceAll(title, "\n", " ")
512
513	// remove thinking tags if present
514	if idx := strings.Index(title, "</think>"); idx > 0 {
515		title = title[idx+len("</think>"):]
516	}
517
518	title = strings.TrimSpace(title)
519	if title == "" {
520		slog.Warn("failed to generate title", "warn", "empty title")
521		return
522	}
523
524	session.Title = title
525	a.updateSessionUsage(a.smallModel, &session, resp.TotalUsage)
526	_, saveErr := a.sessions.Save(ctx, session)
527	if saveErr != nil {
528		slog.Error("failed to save session title & usage", "error", saveErr)
529		return
530	}
531}
532
533func (a *sessionAgent) updateSessionUsage(model Model, session *session.Session, usage ai.Usage) {
534	modelConfig := model.config
535	cost := modelConfig.CostPer1MInCached/1e6*float64(usage.CacheCreationTokens) +
536		modelConfig.CostPer1MOutCached/1e6*float64(usage.CacheReadTokens) +
537		modelConfig.CostPer1MIn/1e6*float64(usage.InputTokens) +
538		modelConfig.CostPer1MOut/1e6*float64(usage.OutputTokens)
539	session.Cost += cost
540	session.CompletionTokens = usage.OutputTokens + usage.CacheReadTokens
541	session.PromptTokens = usage.InputTokens + usage.CacheCreationTokens
542}
543
544func (a *sessionAgent) Cancel(sessionID string) {
545	// Cancel regular requests
546	if cancel, ok := a.activeRequests.Take(sessionID); ok && cancel != nil {
547		slog.Info("Request cancellation initiated", "session_id", sessionID)
548		cancel()
549	}
550
551	// Also check for summarize requests
552	if cancel, ok := a.activeRequests.Take(sessionID + "-summarize"); ok && cancel != nil {
553		slog.Info("Summarize cancellation initiated", "session_id", sessionID)
554		cancel()
555	}
556
557	if a.QueuedPrompts(sessionID) > 0 {
558		slog.Info("Clearing queued prompts", "session_id", sessionID)
559		a.messageQueue.Del(sessionID)
560	}
561}
562
563func (a *sessionAgent) ClearQueue(sessionID string) {
564	if a.QueuedPrompts(sessionID) > 0 {
565		slog.Info("Clearing queued prompts", "session_id", sessionID)
566		a.messageQueue.Del(sessionID)
567	}
568}
569
570func (a *sessionAgent) CancelAll() {
571	if !a.IsBusy() {
572		return
573	}
574	for key := range a.activeRequests.Seq2() {
575		a.Cancel(key) // key is sessionID
576	}
577
578	timeout := time.After(5 * time.Second)
579	for a.IsBusy() {
580		select {
581		case <-timeout:
582			return
583		default:
584			time.Sleep(200 * time.Millisecond)
585		}
586	}
587}
588
589func (a *sessionAgent) IsBusy() bool {
590	var busy bool
591	for cancelFunc := range a.activeRequests.Seq() {
592		if cancelFunc != nil {
593			busy = true
594			break
595		}
596	}
597	return busy
598}
599
600func (a *sessionAgent) IsSessionBusy(sessionID string) bool {
601	_, busy := a.activeRequests.Get(sessionID)
602	return busy
603}
604
605func (a *sessionAgent) QueuedPrompts(sessionID string) int {
606	l, ok := a.messageQueue.Get(sessionID)
607	if !ok {
608		return 0
609	}
610	return len(l)
611}
612
613func (a *sessionAgent) SetModels(large Model, small Model) {
614	a.largeModel = large
615	a.smallModel = small
616}
617
// SetTools replaces the tool set that Run passes to the agent it builds.
// The slice is stored as given, without copying; Run later sets provider
// options on its last element.
func (a *sessionAgent) SetTools(tools []ai.AgentTool) {
	a.tools = tools
}