agent.go

package agent

import (
	"context"
	_ "embed"
	"errors"
	"fmt"
	"log/slog"
	"strings"
	"sync"
	"time"

	"github.com/charmbracelet/catwalk/pkg/catwalk"
	"github.com/charmbracelet/crush/internal/agent/tools"
	"github.com/charmbracelet/crush/internal/config"
	"github.com/charmbracelet/crush/internal/csync"
	"github.com/charmbracelet/crush/internal/message"
	"github.com/charmbracelet/crush/internal/permission"
	"github.com/charmbracelet/crush/internal/session"
	"github.com/charmbracelet/fantasy/ai"
	"github.com/charmbracelet/fantasy/anthropic"
)

//go:embed templates/title.md
var titlePrompt []byte

//go:embed templates/summary.md
var summaryPrompt []byte

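// SessionAgentCall describes a single prompt to run against a session, along
// with its attachments and per-request sampling parameters.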
type SessionAgentCall struct {
	SessionID        string
	Prompt           string
	ProviderOptions  ai.ProviderOptions
	Attachments      []message.Attachment
	MaxOutputTokens  int64
	Temperature      *float64
	TopP             *float64
	TopK             *int64
	FrequencyPenalty *float64
	PresencePenalty  *float64
}

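// SessionAgent runs prompts against chat sessions: it streams model output,
// persists messages, queues prompts for busy sessions, and can summarize or
// cancel in-flight work.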
type SessionAgent interface {
	Run(context.Context, SessionAgentCall) (*ai.AgentResult, error)
	SetModels(large Model, small Model)
	SetTools(tools []ai.AgentTool)
	Cancel(sessionID string)
	CancelAll()
	IsSessionBusy(sessionID string) bool
	IsBusy() bool
	QueuedPrompts(sessionID string) int
	ClearQueue(sessionID string)
	Summarize(context.Context, string) error
	Model() Model
}

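// Model bundles a language model with its catwalk metadata and the
// user-selected configuration.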
type Model struct {
	Model      ai.LanguageModel
	CatwalkCfg catwalk.Model
	ModelCfg   config.SelectedModel
}

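// sessionAgent is the default SessionAgent implementation. It uses the large
// model for conversation and summarization and the small model for title
// generation.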
type sessionAgent struct {
	largeModel   Model
	smallModel   Model
	systemPrompt string
	tools        []ai.AgentTool
	sessions     session.Service
	messages     message.Service

	messageQueue   *csync.Map[string, []SessionAgentCall]
	activeRequests *csync.Map[string, context.CancelFunc]
}

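// SessionAgentOption configures a sessionAgent.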
type SessionAgentOption func(*sessionAgent)

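// NewSessionAgent returns the default SessionAgent implementation backed by
// the given models, system prompt, services, and tools.
//
// A rough usage sketch (largeModel, smallModel, systemPrompt, sessionSvc,
// msgSvc, and toolset are placeholders for values wired up elsewhere):
//
//	agent := NewSessionAgent(largeModel, smallModel, systemPrompt, sessionSvc, msgSvc, toolset...)
//	result, err := agent.Run(ctx, SessionAgentCall{SessionID: id, Prompt: "hello"})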
func NewSessionAgent(
	largeModel Model,
	smallModel Model,
	systemPrompt string,
	sessions session.Service,
	messages message.Service,
	tools ...ai.AgentTool,
) SessionAgent {
	return &sessionAgent{
		largeModel:     largeModel,
		smallModel:     smallModel,
		systemPrompt:   systemPrompt,
		sessions:       sessions,
		messages:       messages,
		tools:          tools,
		messageQueue:   csync.NewMap[string, []SessionAgentCall](),
		activeRequests: csync.NewMap[string, context.CancelFunc](),
	}
}

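// Run executes a prompt against the given session. If the session already has
// an active request, the call is queued and run once the current request
// finishes; in that case Run returns (nil, nil) immediately.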
func (a *sessionAgent) Run(ctx context.Context, call SessionAgentCall) (*ai.AgentResult, error) {
	if call.Prompt == "" {
		return nil, ErrEmptyPrompt
	}
	if call.SessionID == "" {
		return nil, ErrSessionMissing
	}

	// Queue the message if busy
	if a.IsSessionBusy(call.SessionID) {
		existing, ok := a.messageQueue.Get(call.SessionID)
		if !ok {
			existing = []SessionAgentCall{}
		}
		existing = append(existing, call)
		a.messageQueue.Set(call.SessionID, existing)
		return nil, nil
	}

	if len(a.tools) > 0 {
		// add anthropic caching to the last tool
		a.tools[len(a.tools)-1].SetProviderOptions(a.getCacheControlOptions())
	}

	agent := ai.NewAgent(
		a.largeModel.Model,
		ai.WithSystemPrompt(a.systemPrompt),
		ai.WithTools(a.tools...),
	)

	currentSession, err := a.sessions.Get(ctx, call.SessionID)
	if err != nil {
		return nil, fmt.Errorf("failed to get session: %w", err)
	}

	msgs, err := a.getSessionMessages(ctx, currentSession)
	if err != nil {
		return nil, fmt.Errorf("failed to get session messages: %w", err)
	}

	var wg sync.WaitGroup
	// Generate a title if this is the first message
	if len(msgs) == 0 {
		wg.Go(func() {
			a.generateTitle(ctx, currentSession, call.Prompt)
		})
	}

	// Add the user message to the session
	_, err = a.createUserMessage(ctx, call)
	if err != nil {
		return nil, err
	}

	// add the session ID to the context so tools can access it
	ctx = context.WithValue(ctx, tools.SessionIDContextKey, call.SessionID)

	genCtx, cancel := context.WithCancel(ctx)
	a.activeRequests.Set(call.SessionID, cancel)

	defer cancel()
	defer a.activeRequests.Del(call.SessionID)

	history, files := a.preparePrompt(msgs, call.Attachments...)

	var currentAssistant *message.Message
	result, err := agent.Stream(genCtx, ai.AgentStreamCall{
		Prompt:           call.Prompt,
		Files:            files,
		Messages:         history,
		ProviderOptions:  call.ProviderOptions,
		MaxOutputTokens:  &call.MaxOutputTokens,
		TopP:             call.TopP,
		Temperature:      call.Temperature,
		PresencePenalty:  call.PresencePenalty,
		TopK:             call.TopK,
		FrequencyPenalty: call.FrequencyPenalty,
		// Before each step, create a new assistant message
		PrepareStep: func(options ai.PrepareStepFunctionOptions) (prepared ai.PrepareStepResult, err error) {
			var assistantMsg message.Message
			assistantMsg, err = a.messages.Create(genCtx, call.SessionID, message.CreateMessageParams{
				Role:     message.Assistant,
				Parts:    []message.ContentPart{},
				Model:    a.largeModel.ModelCfg.Model,
				Provider: a.largeModel.ModelCfg.Provider,
			})
			if err != nil {
				return prepared, err
			}

			currentAssistant = &assistantMsg

			prepared.Messages = options.Messages
			// reset all cached items
			for i := range prepared.Messages {
				prepared.Messages[i].ProviderOptions = nil
			}

			queuedCalls, _ := a.messageQueue.Get(call.SessionID)
			a.messageQueue.Del(call.SessionID)
			for _, queued := range queuedCalls {
				userMessage, createErr := a.createUserMessage(genCtx, queued)
				if createErr != nil {
					return prepared, createErr
				}
				prepared.Messages = append(prepared.Messages, userMessage.ToAIMessage()...)
			}

			lastSystemRoleIdx := 0
			systemMessageUpdated := false
			for i, msg := range prepared.Messages {
				// add cache control only to the last system message
				if msg.Role == ai.MessageRoleSystem {
					lastSystemRoleIdx = i
				} else if !systemMessageUpdated {
					prepared.Messages[lastSystemRoleIdx].ProviderOptions = a.getCacheControlOptions()
					systemMessageUpdated = true
				}
				// then add cache control to the last two messages
				if i > len(prepared.Messages)-3 {
					prepared.Messages[i].ProviderOptions = a.getCacheControlOptions()
				}
			}
			return prepared, err
		},
		OnReasoningDelta: func(id string, text string) error {
			currentAssistant.AppendReasoningContent(text)
			return a.messages.Update(genCtx, *currentAssistant)
		},
		OnReasoningEnd: func(id string, reasoning ai.ReasoningContent) error {
			// handle anthropic signature
			if anthropicData, ok := reasoning.ProviderMetadata[anthropic.Name]; ok {
				if reasoning, ok := anthropicData.(*anthropic.ReasoningOptionMetadata); ok {
					currentAssistant.AppendReasoningSignature(reasoning.Signature)
				}
			}
			currentAssistant.FinishThinking()
			return a.messages.Update(genCtx, *currentAssistant)
		},
		OnTextDelta: func(id string, text string) error {
			currentAssistant.AppendContent(text)
			return a.messages.Update(genCtx, *currentAssistant)
		},
		OnToolInputStart: func(id string, toolName string) error {
			toolCall := message.ToolCall{
				ID:               id,
				Name:             toolName,
				ProviderExecuted: false,
				Finished:         false,
			}
			currentAssistant.AddToolCall(toolCall)
			return a.messages.Update(genCtx, *currentAssistant)
		},
		OnRetry: func(err *ai.APICallError, delay time.Duration) {
			// TODO: implement
		},
		OnToolCall: func(tc ai.ToolCallContent) error {
			toolCall := message.ToolCall{
				ID:               tc.ToolCallID,
				Name:             tc.ToolName,
				Input:            tc.Input,
				ProviderExecuted: false,
				Finished:         true,
			}
			currentAssistant.AddToolCall(toolCall)
			return a.messages.Update(genCtx, *currentAssistant)
		},
		OnToolResult: func(result ai.ToolResultContent) error {
			var resultContent string
			isError := false
			switch result.Result.GetType() {
			case ai.ToolResultContentTypeText:
				r, ok := ai.AsToolResultOutputType[ai.ToolResultOutputContentText](result.Result)
				if ok {
					resultContent = r.Text
				}
			case ai.ToolResultContentTypeError:
				r, ok := ai.AsToolResultOutputType[ai.ToolResultOutputContentError](result.Result)
				if ok {
					isError = true
					resultContent = r.Error.Error()
				}
			case ai.ToolResultContentTypeMedia:
				// TODO: handle this message type
			}
			toolResult := message.ToolResult{
				ToolCallID: result.ToolCallID,
				Name:       result.ToolName,
				Content:    resultContent,
				IsError:    isError,
				Metadata:   result.ClientMetadata,
			}
			a.messages.Create(context.Background(), currentAssistant.SessionID, message.CreateMessageParams{
				Role: message.Tool,
				Parts: []message.ContentPart{
					toolResult,
				},
			})
			return a.messages.Update(genCtx, *currentAssistant)
		},
		OnStepFinish: func(stepResult ai.StepResult) error {
			finishReason := message.FinishReasonUnknown
			switch stepResult.FinishReason {
			case ai.FinishReasonLength:
				finishReason = message.FinishReasonMaxTokens
			case ai.FinishReasonStop:
				finishReason = message.FinishReasonEndTurn
			case ai.FinishReasonToolCalls:
				finishReason = message.FinishReasonToolUse
			}
			currentAssistant.AddFinish(finishReason, "", "")
			a.updateSessionUsage(a.largeModel, &currentSession, stepResult.Usage)
			return a.messages.Update(genCtx, *currentAssistant)
		},
	})
	if err != nil {
		isCancelErr := errors.Is(err, context.Canceled)
		isPermissionErr := errors.Is(err, permission.ErrorPermissionDenied)
		if currentAssistant == nil {
			return result, err
		}
		toolCalls := currentAssistant.ToolCalls()
		toolResults := currentAssistant.ToolResults()
		for _, tc := range toolCalls {
			if !tc.Finished {
				tc.Finished = true
				tc.Input = "{}"
			}
			currentAssistant.AddToolCall(tc)
			found := false
			for _, tr := range toolResults {
				if tr.ToolCallID == tc.ID {
					found = true
					break
				}
			}
			if !found {
				content := "There was an error while executing the tool"
				if isCancelErr {
					content = "Tool execution canceled by user"
				} else if isPermissionErr {
					content = "Permission denied"
				}
				currentAssistant.AddToolResult(message.ToolResult{
					ToolCallID: tc.ID,
					Name:       tc.Name,
					Content:    content,
					IsError:    true,
				})
			}
		}
		if isCancelErr {
			currentAssistant.AddFinish(message.FinishReasonCanceled, "Request cancelled", "")
		} else if isPermissionErr {
			currentAssistant.AddFinish(message.FinishReasonPermissionDenied, "Permission denied", "")
		} else {
			currentAssistant.AddFinish(message.FinishReasonError, "API Error", err.Error())
		}
		// INFO: we use the parent context here because the genCtx might have been cancelled
		updateErr := a.messages.Update(ctx, *currentAssistant)
		if updateErr != nil {
			return nil, updateErr
		}
	}
	if err != nil {
		return nil, err
	}
	wg.Wait()

	queuedMessages, ok := a.messageQueue.Get(call.SessionID)
	if !ok || len(queuedMessages) == 0 {
		return result, err
	}
	// There are queued messages; run the first one next.
	firstQueuedMessage := queuedMessages[0]
	a.messageQueue.Set(call.SessionID, queuedMessages[1:])
	// Drop this session's active request before recursing so the recursive
	// Run does not see the session as busy and re-queue the prompt.
	a.activeRequests.Del(call.SessionID)
	return a.Run(genCtx, firstQueuedMessage)
}

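// Summarize generates a summary of the session's conversation so far and
// records it as the session's summary message; later runs replay history
// from that message onward.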
func (a *sessionAgent) Summarize(ctx context.Context, sessionID string) error {
	if a.IsSessionBusy(sessionID) {
		return ErrSessionBusy
	}

	currentSession, err := a.sessions.Get(ctx, sessionID)
	if err != nil {
		return fmt.Errorf("failed to get session: %w", err)
	}
	msgs, err := a.getSessionMessages(ctx, currentSession)
	if err != nil {
		return err
	}
	if len(msgs) == 0 {
		// nothing to summarize
		return nil
	}

	aiMsgs, _ := a.preparePrompt(msgs)

	genCtx, cancel := context.WithCancel(ctx)
	a.activeRequests.Set(sessionID, cancel)
	defer a.activeRequests.Del(sessionID)
	defer cancel()

	agent := ai.NewAgent(a.largeModel.Model,
		ai.WithSystemPrompt(string(summaryPrompt)),
	)
	summaryMessage, err := a.messages.Create(ctx, sessionID, message.CreateMessageParams{
		Role:     message.Assistant,
		Model:    a.largeModel.Model.Model(),
		Provider: a.largeModel.Model.Provider(),
	})
	if err != nil {
		return err
	}

	resp, err := agent.Stream(genCtx, ai.AgentStreamCall{
		Prompt:   "Provide a detailed summary of our conversation above.",
		Messages: aiMsgs,
		OnReasoningDelta: func(id string, text string) error {
			summaryMessage.AppendReasoningContent(text)
			return a.messages.Update(ctx, summaryMessage)
		},
		OnReasoningEnd: func(id string, reasoning ai.ReasoningContent) error {
			// handle anthropic signature
			if anthropicData, ok := reasoning.ProviderMetadata[anthropic.Name]; ok {
				if signature, ok := anthropicData.(*anthropic.ReasoningOptionMetadata); ok && signature.Signature != "" {
					summaryMessage.AppendReasoningSignature(signature.Signature)
				}
			}
			summaryMessage.FinishThinking()
			return a.messages.Update(ctx, summaryMessage)
		},
		OnTextDelta: func(id, text string) error {
			summaryMessage.AppendContent(text)
			return a.messages.Update(ctx, summaryMessage)
		},
	})
	if err != nil {
		return err
	}

	summaryMessage.AddFinish(message.FinishReasonEndTurn, "", "")
	err = a.messages.Update(genCtx, summaryMessage)
	if err != nil {
		return err
	}

	a.updateSessionUsage(a.largeModel, &currentSession, resp.TotalUsage)

	// The summary now replaces the prior context, so reset the session token
	// counts to the final response's usage only.
	usage := resp.Response.Usage
	currentSession.SummaryMessageID = summaryMessage.ID
	currentSession.CompletionTokens = usage.OutputTokens
	currentSession.PromptTokens = 0
	_, err = a.sessions.Save(genCtx, currentSession)
	return err
}

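// getCacheControlOptions returns Anthropic ephemeral cache-control provider
// options, used to mark prompt segments as cacheable.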
func (a *sessionAgent) getCacheControlOptions() ai.ProviderOptions {
	return ai.ProviderOptions{
		anthropic.Name: &anthropic.ProviderCacheControlOptions{
			CacheControl: anthropic.CacheControl{Type: "ephemeral"},
		},
	}
}

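// createUserMessage persists the prompt and its attachments as a user message
// in the session.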
func (a *sessionAgent) createUserMessage(ctx context.Context, call SessionAgentCall) (message.Message, error) {
	var attachmentParts []message.ContentPart
	for _, attachment := range call.Attachments {
		attachmentParts = append(attachmentParts, message.BinaryContent{Path: attachment.FilePath, MIMEType: attachment.MimeType, Data: attachment.Content})
	}
	parts := []message.ContentPart{message.TextContent{Text: call.Prompt}}
	parts = append(parts, attachmentParts...)
	msg, err := a.messages.Create(ctx, call.SessionID, message.CreateMessageParams{
		Role:  message.User,
		Parts: parts,
	})
	if err != nil {
		return message.Message{}, fmt.Errorf("failed to create user message: %w", err)
	}
	return msg, nil
}

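// preparePrompt converts stored messages into AI messages, skipping empty
// assistant messages, and converts attachments into file parts.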
func (a *sessionAgent) preparePrompt(msgs []message.Message, attachments ...message.Attachment) ([]ai.Message, []ai.FilePart) {
	var history []ai.Message
	for _, m := range msgs {
		if len(m.Parts) == 0 {
			continue
		}
		// Assistant message without content or tool calls (cancelled before it returned anything)
		if m.Role == message.Assistant && len(m.ToolCalls()) == 0 && m.Content().Text == "" && m.ReasoningContent().String() == "" {
			continue
		}
		history = append(history, m.ToAIMessage()...)
	}

	var files []ai.FilePart
	for _, attachment := range attachments {
		files = append(files, ai.FilePart{
			Filename:  attachment.FileName,
			Data:      attachment.Content,
			MediaType: attachment.MimeType,
		})
	}

	return history, files
}

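// getSessionMessages lists a session's messages. If the session has been
// summarized, only the summary message and everything after it are returned,
// with the summary re-cast as a user message.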
func (a *sessionAgent) getSessionMessages(ctx context.Context, session session.Session) ([]message.Message, error) {
	msgs, err := a.messages.List(ctx, session.ID)
	if err != nil {
		return nil, fmt.Errorf("failed to list messages: %w", err)
	}

	if session.SummaryMessageID != "" {
		summaryMsgIndex := -1
		for i, msg := range msgs {
			if msg.ID == session.SummaryMessageID {
				summaryMsgIndex = i
				break
			}
		}
		if summaryMsgIndex != -1 {
			msgs = msgs[summaryMsgIndex:]
			msgs[0].Role = message.User
		}
	}
	return msgs, nil
}

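// generateTitle asks the small model for a short session title based on the
// first prompt and saves it on the session.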
func (a *sessionAgent) generateTitle(ctx context.Context, session session.Session, prompt string) {
	if prompt == "" {
		return
	}

	agent := ai.NewAgent(a.smallModel.Model,
		ai.WithSystemPrompt(string(titlePrompt)),
		ai.WithMaxOutputTokens(40),
	)

	resp, err := agent.Stream(ctx, ai.AgentStreamCall{
		Prompt: fmt.Sprintf("Generate a concise title for the following content:\n\n%s", prompt),
	})
	if err != nil {
		slog.Error("error generating title", "err", err)
		return
	}

	title := resp.Response.Content.Text()

	title = strings.ReplaceAll(title, "\n", " ")

	// remove thinking tags if present
	if idx := strings.Index(title, "</think>"); idx != -1 {
		title = title[idx+len("</think>"):]
	}

	title = strings.TrimSpace(title)
	if title == "" {
		slog.Warn("failed to generate title", "warn", "empty title")
		return
	}

	session.Title = title
	a.updateSessionUsage(a.smallModel, &session, resp.TotalUsage)
	_, saveErr := a.sessions.Save(ctx, session)
	if saveErr != nil {
		slog.Error("failed to save session title & usage", "error", saveErr)
		return
	}
}

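// updateSessionUsage adds the request's cost to the session and records its
// prompt and completion token counts, including cache creation and cache read
// tokens.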
func (a *sessionAgent) updateSessionUsage(model Model, session *session.Session, usage ai.Usage) {
	modelConfig := model.CatwalkCfg
	cost := modelConfig.CostPer1MInCached/1e6*float64(usage.CacheCreationTokens) +
		modelConfig.CostPer1MOutCached/1e6*float64(usage.CacheReadTokens) +
		modelConfig.CostPer1MIn/1e6*float64(usage.InputTokens) +
		modelConfig.CostPer1MOut/1e6*float64(usage.OutputTokens)
	session.Cost += cost
	session.CompletionTokens = usage.OutputTokens + usage.CacheReadTokens
	session.PromptTokens = usage.InputTokens + usage.CacheCreationTokens
}

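// Cancel aborts the session's in-flight request (including a summarize
// request), if any, and clears its queued prompts.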
func (a *sessionAgent) Cancel(sessionID string) {
	// Cancel regular requests
	if cancel, ok := a.activeRequests.Take(sessionID); ok && cancel != nil {
		slog.Info("Request cancellation initiated", "session_id", sessionID)
		cancel()
	}

	// Also check for summarize requests
	if cancel, ok := a.activeRequests.Take(sessionID + "-summarize"); ok && cancel != nil {
		slog.Info("Summarize cancellation initiated", "session_id", sessionID)
		cancel()
	}

	if a.QueuedPrompts(sessionID) > 0 {
		slog.Info("Clearing queued prompts", "session_id", sessionID)
		a.messageQueue.Del(sessionID)
	}
}

func (a *sessionAgent) ClearQueue(sessionID string) {
	if a.QueuedPrompts(sessionID) > 0 {
		slog.Info("Clearing queued prompts", "session_id", sessionID)
		a.messageQueue.Del(sessionID)
	}
}

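// CancelAll cancels every active request and waits up to five seconds for
// them to wind down.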
func (a *sessionAgent) CancelAll() {
	if !a.IsBusy() {
		return
	}
	for key := range a.activeRequests.Seq2() {
		a.Cancel(key) // key is sessionID
	}

	timeout := time.After(5 * time.Second)
	for a.IsBusy() {
		select {
		case <-timeout:
			return
		default:
			time.Sleep(200 * time.Millisecond)
		}
	}
}

func (a *sessionAgent) IsBusy() bool {
	var busy bool
	for cancelFunc := range a.activeRequests.Seq() {
		if cancelFunc != nil {
			busy = true
			break
		}
	}
	return busy
}

func (a *sessionAgent) IsSessionBusy(sessionID string) bool {
	_, busy := a.activeRequests.Get(sessionID)
	return busy
}

func (a *sessionAgent) QueuedPrompts(sessionID string) int {
	l, ok := a.messageQueue.Get(sessionID)
	if !ok {
		return 0
	}
	return len(l)
}

func (a *sessionAgent) SetModels(large Model, small Model) {
	a.largeModel = large
	a.smallModel = small
}

func (a *sessionAgent) SetTools(tools []ai.AgentTool) {
	a.tools = tools
}

func (a *sessionAgent) Model() Model {
	return a.largeModel
}