agent.go

package agent

import (
	"context"
	_ "embed"
	"errors"
	"fmt"
	"log/slog"
	"strings"
	"sync"
	"time"

	"github.com/charmbracelet/catwalk/pkg/catwalk"
	"github.com/charmbracelet/crush/internal/agent/tools"
	"github.com/charmbracelet/crush/internal/csync"
	"github.com/charmbracelet/crush/internal/message"
	"github.com/charmbracelet/crush/internal/permission"
	"github.com/charmbracelet/crush/internal/session"
	"github.com/charmbracelet/fantasy/ai"
	"github.com/charmbracelet/fantasy/anthropic"
)

//go:embed templates/title.md
var titlePrompt []byte

//go:embed templates/summary.md
var summaryPrompt []byte

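// SessionAgentCall describes a single prompt for a session, together with
// optional attachments and per-request sampling parameters.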
type SessionAgentCall struct {
	SessionID        string
	Prompt           string
	ProviderOptions  ai.ProviderOptions
	Attachments      []message.Attachment
	MaxOutputTokens  int64
	Temperature      *float64
	TopP             *float64
	TopK             *int64
	FrequencyPenalty *float64
	PresencePenalty  *float64
}

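// SessionAgent runs prompts against a session: it streams model output into
// persisted messages, queues prompts while the session is busy, and supports
// cancellation and summarization.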
type SessionAgent interface {
	Run(context.Context, SessionAgentCall) (*ai.AgentResult, error)
	SetModels(large Model, small Model)
	SetTools(tools []ai.AgentTool)
	Cancel(sessionID string)
	CancelAll()
	IsSessionBusy(sessionID string) bool
	IsBusy() bool
	QueuedPrompts(sessionID string) int
	ClearQueue(sessionID string)
	Summarize(context.Context, string) error
}

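// Model pairs a language model with its catwalk configuration, which is used
// for cost accounting.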
type Model struct {
	model  ai.LanguageModel
	config catwalk.Model
}

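// sessionAgent is the default SessionAgent implementation.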
type sessionAgent struct {
	largeModel   Model
	smallModel   Model
	systemPrompt string
	tools        []ai.AgentTool
	sessions     session.Service
	messages     message.Service

	messageQueue   *csync.Map[string, []SessionAgentCall]
	activeRequests *csync.Map[string, context.CancelFunc]
}

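// SessionAgentOption configures a sessionAgent.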
type SessionAgentOption func(*sessionAgent)

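// NewSessionAgent creates a SessionAgent that uses the large model for the
// conversation itself and the small model for auxiliary work such as title
// generation.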
func NewSessionAgent(
	largeModel Model,
	smallModel Model,
	systemPrompt string,
	sessions session.Service,
	messages message.Service,
	tools ...ai.AgentTool,
) SessionAgent {
	return &sessionAgent{
		largeModel:     largeModel,
		smallModel:     smallModel,
		systemPrompt:   systemPrompt,
		sessions:       sessions,
		messages:       messages,
		tools:          tools,
		messageQueue:   csync.NewMap[string, []SessionAgentCall](),
		activeRequests: csync.NewMap[string, context.CancelFunc](),
	}
}

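// Run sends the prompt to the large model and streams the result into
// persisted session messages. If the session is already busy the call is
// queued and (nil, nil) is returned; queued prompts are drained once the
// current request finishes.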
func (a *sessionAgent) Run(ctx context.Context, call SessionAgentCall) (*ai.AgentResult, error) {
	if call.Prompt == "" {
		return nil, ErrEmptyPrompt
	}
	if call.SessionID == "" {
		return nil, ErrSessionMissing
	}

	// Queue the message if busy
	if a.IsSessionBusy(call.SessionID) {
		existing, ok := a.messageQueue.Get(call.SessionID)
		if !ok {
			existing = []SessionAgentCall{}
		}
		existing = append(existing, call)
		a.messageQueue.Set(call.SessionID, existing)
		return nil, nil
	}

	if len(a.tools) > 0 {
		// add Anthropic prompt caching to the last tool
		a.tools[len(a.tools)-1].SetProviderOptions(a.getCacheControlOptions())
	}

	agent := ai.NewAgent(
		a.largeModel.model,
		ai.WithSystemPrompt(a.systemPrompt),
		ai.WithTools(a.tools...),
	)

	currentSession, err := a.sessions.Get(ctx, call.SessionID)
	if err != nil {
		return nil, fmt.Errorf("failed to get session: %w", err)
	}

	msgs, err := a.getSessionMessages(ctx, currentSession)
	if err != nil {
		return nil, fmt.Errorf("failed to get session messages: %w", err)
	}

	var wg sync.WaitGroup
	// Generate a title if this is the first message
	if len(msgs) == 0 {
		wg.Go(func() {
			a.generateTitle(ctx, currentSession, call.Prompt)
		})
	}

	// Add the user message to the session
	_, err = a.createUserMessage(ctx, call)
	if err != nil {
		return nil, err
	}

	// add the session ID to the context
	ctx = context.WithValue(ctx, tools.SessionIDContextKey, call.SessionID)

	genCtx, cancel := context.WithCancel(ctx)
	a.activeRequests.Set(call.SessionID, cancel)

	defer cancel()
	defer a.activeRequests.Del(call.SessionID)

	history, files := a.preparePrompt(msgs, call.Attachments...)

	var currentAssistant *message.Message
	result, err := agent.Stream(genCtx, ai.AgentStreamCall{
		Prompt:           call.Prompt,
		Files:            files,
		Messages:         history,
		ProviderOptions:  call.ProviderOptions,
		MaxOutputTokens:  &call.MaxOutputTokens,
		TopP:             call.TopP,
		Temperature:      call.Temperature,
		PresencePenalty:  call.PresencePenalty,
		TopK:             call.TopK,
		FrequencyPenalty: call.FrequencyPenalty,
		// Before each step, create a new assistant message
		PrepareStep: func(options ai.PrepareStepFunctionOptions) (prepared ai.PrepareStepResult, err error) {
			var assistantMsg message.Message
			assistantMsg, err = a.messages.Create(genCtx, call.SessionID, message.CreateMessageParams{
				Role:     message.Assistant,
				Parts:    []message.ContentPart{},
				Model:    a.largeModel.model.Model(),
				Provider: a.largeModel.model.Provider(),
			})
			if err != nil {
				return prepared, err
			}

			currentAssistant = &assistantMsg

			prepared.Messages = options.Messages
			// reset all cached items
			for i := range prepared.Messages {
				prepared.Messages[i].ProviderOptions = nil
			}

			// drain queued prompts and append them to the conversation
			queuedCalls, _ := a.messageQueue.Get(call.SessionID)
			a.messageQueue.Del(call.SessionID)
			for _, queued := range queuedCalls {
				userMessage, createErr := a.createUserMessage(genCtx, queued)
				if createErr != nil {
					return prepared, createErr
				}
				prepared.Messages = append(prepared.Messages, userMessage.ToAIMessage()...)
			}

			lastSystemRoleIdx := 0
			systemMessageUpdated := false
			for i, msg := range prepared.Messages {
				// add cache control only to the last system message
				if msg.Role == ai.MessageRoleSystem {
					lastSystemRoleIdx = i
				} else if !systemMessageUpdated {
					prepared.Messages[lastSystemRoleIdx].ProviderOptions = a.getCacheControlOptions()
					systemMessageUpdated = true
				}
				// then add cache control to the last 2 messages
				if i > len(prepared.Messages)-3 {
					prepared.Messages[i].ProviderOptions = a.getCacheControlOptions()
				}
			}
			return prepared, err
		},
		OnReasoningDelta: func(id string, text string) error {
			currentAssistant.AppendReasoningContent(text)
			return a.messages.Update(genCtx, *currentAssistant)
		},
		OnReasoningEnd: func(id string, reasoning ai.ReasoningContent) error {
			// handle anthropic signature
			if anthropicData, ok := reasoning.ProviderMetadata[anthropic.Name]; ok {
				if reasoning, ok := anthropicData.(*anthropic.ReasoningOptionMetadata); ok {
					currentAssistant.AppendReasoningSignature(reasoning.Signature)
				}
			}
			currentAssistant.FinishThinking()
			return a.messages.Update(genCtx, *currentAssistant)
		},
		OnTextDelta: func(id string, text string) error {
			currentAssistant.AppendContent(text)
			return a.messages.Update(genCtx, *currentAssistant)
		},
		OnToolInputStart: func(id string, toolName string) error {
			toolCall := message.ToolCall{
				ID:               id,
				Name:             toolName,
				ProviderExecuted: false,
				Finished:         false,
			}
			currentAssistant.AddToolCall(toolCall)
			return a.messages.Update(genCtx, *currentAssistant)
		},
		OnRetry: func(err *ai.APICallError, delay time.Duration) {
			// TODO: implement
		},
		OnToolResult: func(result ai.ToolResultContent) error {
			var resultContent string
			isError := false
			switch result.Result.GetType() {
			case ai.ToolResultContentTypeText:
				r, ok := ai.AsToolResultOutputType[ai.ToolResultOutputContentText](result.Result)
				if ok {
					resultContent = r.Text
				}
			case ai.ToolResultContentTypeError:
				r, ok := ai.AsToolResultOutputType[ai.ToolResultOutputContentError](result.Result)
				if ok {
					isError = true
					resultContent = r.Error.Error()
				}
			case ai.ToolResultContentTypeMedia:
				// TODO: handle this message type
			}
			toolResult := message.ToolResult{
				ToolCallID: result.ToolCallID,
				Name:       result.ToolName,
				Content:    resultContent,
				IsError:    isError,
				Metadata:   result.ClientMetadata,
			}
			a.messages.Create(context.Background(), currentAssistant.SessionID, message.CreateMessageParams{
				Role: message.Tool,
				Parts: []message.ContentPart{
					toolResult,
				},
			})
			return a.messages.Update(genCtx, *currentAssistant)
		},
		OnStepFinish: func(stepResult ai.StepResult) error {
			finishReason := message.FinishReasonUnknown
			switch stepResult.FinishReason {
			case ai.FinishReasonLength:
				finishReason = message.FinishReasonMaxTokens
			case ai.FinishReasonStop:
				finishReason = message.FinishReasonEndTurn
			case ai.FinishReasonToolCalls:
				finishReason = message.FinishReasonToolUse
			}
			currentAssistant.AddFinish(finishReason, "", "")
			a.updateSessionUsage(a.largeModel, &currentSession, stepResult.Usage)
			return a.messages.Update(genCtx, *currentAssistant)
		},
	})
	if err != nil {
		isCancelErr := errors.Is(err, context.Canceled)
		isPermissionErr := errors.Is(err, permission.ErrorPermissionDenied)
		if currentAssistant == nil {
			return result, err
		}
		toolCalls := currentAssistant.ToolCalls()
		toolResults := currentAssistant.ToolResults()
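		// Repair dangling tool calls: mark unfinished calls as finished and give
		// any call without a result an error result so the stored conversation
		// stays consistent.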
		for _, tc := range toolCalls {
			if !tc.Finished {
				tc.Finished = true
				tc.Input = "{}"
			}
			currentAssistant.AddToolCall(tc)
			found := false
			for _, tr := range toolResults {
				if tr.ToolCallID == tc.ID {
					found = true
					break
				}
			}
			if !found {
				content := "There was an error while executing the tool"
				if isCancelErr {
					content = "Tool execution canceled by user"
				} else if isPermissionErr {
					content = "Permission denied"
				}
				currentAssistant.AddToolResult(message.ToolResult{
					ToolCallID: tc.ID,
					Name:       tc.Name,
					Content:    content,
					IsError:    true,
				})
			}
		}
		if isCancelErr {
			currentAssistant.AddFinish(message.FinishReasonCanceled, "Request cancelled", "")
		} else if isPermissionErr {
			currentAssistant.AddFinish(message.FinishReasonPermissionDenied, "Permission denied", "")
		} else {
			currentAssistant.AddFinish(message.FinishReasonError, "API Error", err.Error())
		}
		// INFO: we use the parent context here because the genCtx might have been cancelled
		updateErr := a.messages.Update(ctx, *currentAssistant)
		if updateErr != nil {
			return nil, updateErr
		}
	}
	if err != nil {
		return nil, err
	}
	wg.Wait()

	queuedMessages, ok := a.messageQueue.Get(call.SessionID)
	if !ok || len(queuedMessages) == 0 {
		return result, err
	}
	// there are queued messages; restart the loop
	firstQueuedMessage := queuedMessages[0]
	a.messageQueue.Set(call.SessionID, queuedMessages[1:])
	return a.Run(genCtx, firstQueuedMessage)
}

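// Summarize asks the large model for a summary of the session's conversation,
// stores it as an assistant message, and records it as the session's summary.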
func (a *sessionAgent) Summarize(ctx context.Context, sessionID string) error {
	if a.IsSessionBusy(sessionID) {
		return ErrSessionBusy
	}

	currentSession, err := a.sessions.Get(ctx, sessionID)
	if err != nil {
		return fmt.Errorf("failed to get session: %w", err)
	}
	msgs, err := a.getSessionMessages(ctx, currentSession)
	if err != nil {
		return err
	}
	if len(msgs) == 0 {
		// nothing to summarize
		return nil
	}

	aiMsgs, _ := a.preparePrompt(msgs)

	genCtx, cancel := context.WithCancel(ctx)
	a.activeRequests.Set(sessionID, cancel)
	defer a.activeRequests.Del(sessionID)
	defer cancel()

	agent := ai.NewAgent(a.largeModel.model,
		ai.WithSystemPrompt(string(summaryPrompt)),
	)
	summaryMessage, err := a.messages.Create(ctx, sessionID, message.CreateMessageParams{
		Role:     message.Assistant,
		Model:    a.largeModel.model.Model(),
		Provider: a.largeModel.model.Provider(),
	})
	if err != nil {
		return err
	}

	resp, err := agent.Stream(ctx, ai.AgentStreamCall{
		Prompt:   "Provide a detailed summary of our conversation above.",
		Messages: aiMsgs,
		OnReasoningDelta: func(id string, text string) error {
			summaryMessage.AppendReasoningContent(text)
			return a.messages.Update(ctx, summaryMessage)
		},
		OnReasoningEnd: func(id string, reasoning ai.ReasoningContent) error {
			// handle anthropic signature
			if anthropicData, ok := reasoning.ProviderMetadata[anthropic.Name]; ok {
				if signature, ok := anthropicData.(*anthropic.ReasoningOptionMetadata); ok && signature.Signature != "" {
					summaryMessage.AppendReasoningSignature(signature.Signature)
				}
			}
			summaryMessage.FinishThinking()
			return a.messages.Update(ctx, summaryMessage)
		},
		OnTextDelta: func(id, text string) error {
			summaryMessage.AppendContent(text)
			return a.messages.Update(ctx, summaryMessage)
		},
	})
	if err != nil {
		return err
	}

	summaryMessage.AddFinish(message.FinishReasonEndTurn, "", "")
	err = a.messages.Update(genCtx, summaryMessage)
	if err != nil {
		return err
	}

	a.updateSessionUsage(a.largeModel, &currentSession, resp.TotalUsage)

	// override the token counts with only the final response's usage: the
	// summary output becomes the completion count and the prompt count resets
	usage := resp.Response.Usage
	currentSession.SummaryMessageID = summaryMessage.ID
	currentSession.CompletionTokens = usage.OutputTokens
	currentSession.PromptTokens = 0
	_, err = a.sessions.Save(genCtx, currentSession)
	return err
}

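// getCacheControlOptions returns provider options that enable Anthropic's
// ephemeral prompt caching.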
func (a *sessionAgent) getCacheControlOptions() ai.ProviderOptions {
	return ai.ProviderOptions{
		anthropic.Name: &anthropic.ProviderCacheControlOptions{
			CacheControl: anthropic.CacheControl{Type: "ephemeral"},
		},
	}
}

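// createUserMessage persists the prompt and its attachments as a user message
// in the session.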
func (a *sessionAgent) createUserMessage(ctx context.Context, call SessionAgentCall) (message.Message, error) {
	var attachmentParts []message.ContentPart
	for _, attachment := range call.Attachments {
		attachmentParts = append(attachmentParts, message.BinaryContent{Path: attachment.FilePath, MIMEType: attachment.MimeType, Data: attachment.Content})
	}
	parts := []message.ContentPart{message.TextContent{Text: call.Prompt}}
	parts = append(parts, attachmentParts...)
	msg, err := a.messages.Create(ctx, call.SessionID, message.CreateMessageParams{
		Role:  message.User,
		Parts: parts,
	})
	if err != nil {
		return message.Message{}, fmt.Errorf("failed to create user message: %w", err)
	}
	return msg, nil
}

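// preparePrompt converts stored messages into AI messages, skipping empty and
// cancelled assistant messages, and converts attachments into file parts.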
func (a *sessionAgent) preparePrompt(msgs []message.Message, attachments ...message.Attachment) ([]ai.Message, []ai.FilePart) {
	var history []ai.Message
	for _, m := range msgs {
		if len(m.Parts) == 0 {
			continue
		}
		// Assistant message without content or tool calls (cancelled before it returned anything)
		if m.Role == message.Assistant && len(m.ToolCalls()) == 0 && m.Content().Text == "" && m.ReasoningContent().String() == "" {
			continue
		}
		history = append(history, m.ToAIMessage()...)
	}

	var files []ai.FilePart
	for _, attachment := range attachments {
		files = append(files, ai.FilePart{
			Filename:  attachment.FileName,
			Data:      attachment.Content,
			MediaType: attachment.MimeType,
		})
	}

	return history, files
}

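// getSessionMessages lists the session's messages. If the session has been
// summarized, only the summary message and everything after it are returned.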
func (a *sessionAgent) getSessionMessages(ctx context.Context, session session.Session) ([]message.Message, error) {
	msgs, err := a.messages.List(ctx, session.ID)
	if err != nil {
		return nil, fmt.Errorf("failed to list messages: %w", err)
	}

	if session.SummaryMessageID != "" {
		summaryMsgIndex := -1
		for i, msg := range msgs {
			if msg.ID == session.SummaryMessageID {
				summaryMsgIndex = i
				break
			}
		}
		if summaryMsgIndex != -1 {
			msgs = msgs[summaryMsgIndex:]
			msgs[0].Role = message.User
		}
	}
	return msgs, nil
}

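// generateTitle asks the small model for a short session title based on the
// first prompt and saves it along with the usage it incurred.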
func (a *sessionAgent) generateTitle(ctx context.Context, session session.Session, prompt string) {
	if prompt == "" {
		return
	}

	agent := ai.NewAgent(a.smallModel.model,
		ai.WithSystemPrompt(string(titlePrompt)),
		ai.WithMaxOutputTokens(40),
	)

	resp, err := agent.Stream(ctx, ai.AgentStreamCall{
		Prompt: fmt.Sprintf("Generate a concise title for the following content:\n\n%s", prompt),
	})
	if err != nil {
		slog.Error("error generating title", "err", err)
		return
	}

	title := resp.Response.Content.Text()

	title = strings.ReplaceAll(title, "\n", " ")

	// remove thinking tags if present
	if idx := strings.Index(title, "</think>"); idx >= 0 {
		title = title[idx+len("</think>"):]
	}

	title = strings.TrimSpace(title)
	if title == "" {
		slog.Warn("failed to generate title", "warn", "empty title")
		return
	}

	session.Title = title
	a.updateSessionUsage(a.smallModel, &session, resp.TotalUsage)
	_, saveErr := a.sessions.Save(ctx, session)
	if saveErr != nil {
		slog.Error("failed to save session title & usage", "error", saveErr)
		return
	}
}

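// updateSessionUsage adds the cost of the given usage to the session and
// updates its token counts.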
func (a *sessionAgent) updateSessionUsage(model Model, session *session.Session, usage ai.Usage) {
	modelConfig := model.config
	cost := modelConfig.CostPer1MInCached/1e6*float64(usage.CacheCreationTokens) +
		modelConfig.CostPer1MOutCached/1e6*float64(usage.CacheReadTokens) +
		modelConfig.CostPer1MIn/1e6*float64(usage.InputTokens) +
		modelConfig.CostPer1MOut/1e6*float64(usage.OutputTokens)
	session.Cost += cost
	session.CompletionTokens = usage.OutputTokens + usage.CacheReadTokens
	session.PromptTokens = usage.InputTokens + usage.CacheCreationTokens
}

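// Cancel stops the active request (and any summarize request) for the session
// and drops its queued prompts.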
func (a *sessionAgent) Cancel(sessionID string) {
	// Cancel regular requests
	if cancel, ok := a.activeRequests.Take(sessionID); ok && cancel != nil {
		slog.Info("Request cancellation initiated", "session_id", sessionID)
		cancel()
	}

	// Also check for summarize requests
	if cancel, ok := a.activeRequests.Take(sessionID + "-summarize"); ok && cancel != nil {
		slog.Info("Summarize cancellation initiated", "session_id", sessionID)
		cancel()
	}

	if a.QueuedPrompts(sessionID) > 0 {
		slog.Info("Clearing queued prompts", "session_id", sessionID)
		a.messageQueue.Del(sessionID)
	}
}

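// ClearQueue drops any prompts queued for the session.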
func (a *sessionAgent) ClearQueue(sessionID string) {
	if a.QueuedPrompts(sessionID) > 0 {
		slog.Info("Clearing queued prompts", "session_id", sessionID)
		a.messageQueue.Del(sessionID)
	}
}

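// CancelAll cancels every active request and waits up to five seconds for
// them to finish.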
func (a *sessionAgent) CancelAll() {
	if !a.IsBusy() {
		return
	}
	for key := range a.activeRequests.Seq2() {
		a.Cancel(key) // key is sessionID
	}

	timeout := time.After(5 * time.Second)
	for a.IsBusy() {
		select {
		case <-timeout:
			return
		default:
			time.Sleep(200 * time.Millisecond)
		}
	}
}

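// IsBusy reports whether any session has an active request.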
func (a *sessionAgent) IsBusy() bool {
	var busy bool
	for cancelFunc := range a.activeRequests.Seq() {
		if cancelFunc != nil {
			busy = true
			break
		}
	}
	return busy
}

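// IsSessionBusy reports whether the given session has an active request.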
func (a *sessionAgent) IsSessionBusy(sessionID string) bool {
	_, busy := a.activeRequests.Get(sessionID)
	return busy
}

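// QueuedPrompts returns the number of prompts queued for the session.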
func (a *sessionAgent) QueuedPrompts(sessionID string) int {
	l, ok := a.messageQueue.Get(sessionID)
	if !ok {
		return 0
	}
	return len(l)
}

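// SetModels replaces the large and small models.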
func (a *sessionAgent) SetModels(large Model, small Model) {
	a.largeModel = large
	a.smallModel = small
}

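// SetTools replaces the agent's tool set.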
func (a *sessionAgent) SetTools(tools []ai.AgentTool) {
	a.tools = tools
}