package agent

import (
	"context"
	_ "embed"
	"errors"
	"fmt"
	"log/slog"
	"strings"
	"sync"
	"time"

	"github.com/charmbracelet/catwalk/pkg/catwalk"
	"github.com/charmbracelet/crush/internal/agent/tools"
	"github.com/charmbracelet/crush/internal/csync"
	"github.com/charmbracelet/crush/internal/message"
	"github.com/charmbracelet/crush/internal/permission"
	"github.com/charmbracelet/crush/internal/session"
	"github.com/charmbracelet/fantasy/ai"
	"github.com/charmbracelet/fantasy/anthropic"
)

//go:embed templates/title.md
var titlePrompt []byte

//go:embed templates/summary.md
var summaryPrompt []byte

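// SessionAgentCall describes a single prompt sent to a session: the prompt
// text, optional file attachments, and per-call sampling and provider options.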
type SessionAgentCall struct {
	SessionID        string
	Prompt           string
	ProviderOptions  ai.ProviderOptions
	Attachments      []message.Attachment
	MaxOutputTokens  int64
	Temperature      *float64
	TopP             *float64
	TopK             *int64
	FrequencyPenalty *float64
	PresencePenalty  *float64
}

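// SessionAgent runs prompts against sessions and streams the results into the
// message store. Prompts sent to a busy session are queued and replayed once
// the active request finishes. It also supports cancellation and conversation
// summarization.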
type SessionAgent interface {
	Run(context.Context, SessionAgentCall) (*ai.AgentResult, error)
	SetModels(large Model, small Model)
	SetTools(tools []ai.AgentTool)
	Cancel(sessionID string)
	CancelAll()
	IsSessionBusy(sessionID string) bool
	IsBusy() bool
	QueuedPrompts(sessionID string) int
	ClearQueue(sessionID string)
	Summarize(context.Context, string) error
}

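// Model pairs a language model with its catwalk configuration, which supplies
// metadata such as per-token pricing used for usage accounting.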
type Model struct {
	model  ai.LanguageModel
	config catwalk.Model
}

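// sessionAgent is the default SessionAgent implementation. It tracks one
// active request and one prompt queue per session ID.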
type sessionAgent struct {
	largeModel   Model
	smallModel   Model
	systemPrompt string
	tools        []ai.AgentTool
	sessions     session.Service
	messages     message.Service

	messageQueue   *csync.Map[string, []SessionAgentCall]
	activeRequests *csync.Map[string, context.CancelFunc]
}

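// SessionAgentOption configures a sessionAgent.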
type SessionAgentOption func(*sessionAgent)

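// NewSessionAgent creates a SessionAgent that uses the large model for the
// main conversation and summarization, and the small model for auxiliary work
// such as title generation.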
func NewSessionAgent(
	largeModel Model,
	smallModel Model,
	systemPrompt string,
	sessions session.Service,
	messages message.Service,
	tools ...ai.AgentTool,
) SessionAgent {
	return &sessionAgent{
		largeModel:     largeModel,
		smallModel:     smallModel,
		systemPrompt:   systemPrompt,
		sessions:       sessions,
		messages:       messages,
		tools:          tools,
		messageQueue:   csync.NewMap[string, []SessionAgentCall](),
		activeRequests: csync.NewMap[string, context.CancelFunc](),
	}
}

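// Run sends a prompt to the given session. If the session is already busy the
// call is queued and a nil result is returned; otherwise the agent streams the
// response, persisting assistant content, tool calls, and tool results as they
// arrive, and finally replays any prompts queued while it was running.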
func (a *sessionAgent) Run(ctx context.Context, call SessionAgentCall) (*ai.AgentResult, error) {
	if call.Prompt == "" {
		return nil, ErrEmptyPrompt
	}
	if call.SessionID == "" {
		return nil, ErrSessionMissing
	}

	// Queue the message if busy
	if a.IsSessionBusy(call.SessionID) {
		existing, ok := a.messageQueue.Get(call.SessionID)
		if !ok {
			existing = []SessionAgentCall{}
		}
		existing = append(existing, call)
		a.messageQueue.Set(call.SessionID, existing)
		return nil, nil
	}

	if len(a.tools) > 0 {
		// add anthropic caching to the last tool
		a.tools[len(a.tools)-1].SetProviderOptions(a.getCacheControlOptions())
	}

	agent := ai.NewAgent(
		a.largeModel.model,
		ai.WithSystemPrompt(a.systemPrompt),
		ai.WithTools(a.tools...),
	)

	currentSession, err := a.sessions.Get(ctx, call.SessionID)
	if err != nil {
		return nil, fmt.Errorf("failed to get session: %w", err)
	}

	msgs, err := a.getSessionMessages(ctx, currentSession)
	if err != nil {
		return nil, fmt.Errorf("failed to get session messages: %w", err)
	}

	var wg sync.WaitGroup
	// Generate title if first message
	if len(msgs) == 0 {
		wg.Go(func() {
			a.generateTitle(ctx, currentSession, call.Prompt)
		})
	}

	// Add the user message to the session
	_, err = a.createUserMessage(ctx, call)
	if err != nil {
		return nil, err
	}

	// add the session to the context
	ctx = context.WithValue(ctx, tools.SessionIDContextKey, call.SessionID)

	genCtx, cancel := context.WithCancel(ctx)
	a.activeRequests.Set(call.SessionID, cancel)

	defer cancel()
	defer a.activeRequests.Del(call.SessionID)

	history, files := a.preparePrompt(msgs, call.Attachments...)

	var currentAssistant *message.Message
	result, err := agent.Stream(genCtx, ai.AgentStreamCall{
		Prompt:           call.Prompt,
		Files:            files,
		Messages:         history,
		ProviderOptions:  call.ProviderOptions,
		MaxOutputTokens:  &call.MaxOutputTokens,
		TopP:             call.TopP,
		Temperature:      call.Temperature,
		PresencePenalty:  call.PresencePenalty,
		TopK:             call.TopK,
		FrequencyPenalty: call.FrequencyPenalty,
		// Before each step create the new assistant message
		PrepareStep: func(options ai.PrepareStepFunctionOptions) (prepared ai.PrepareStepResult, err error) {
			var assistantMsg message.Message
			assistantMsg, err = a.messages.Create(genCtx, call.SessionID, message.CreateMessageParams{
				Role:     message.Assistant,
				Parts:    []message.ContentPart{},
				Model:    a.largeModel.model.Model(),
				Provider: a.largeModel.model.Provider(),
			})
			if err != nil {
				return prepared, err
			}

			currentAssistant = &assistantMsg

			prepared.Messages = options.Messages
			// reset all cached items
			for i := range prepared.Messages {
				prepared.Messages[i].ProviderOptions = nil
			}

			queuedCalls, _ := a.messageQueue.Get(call.SessionID)
			a.messageQueue.Del(call.SessionID)
			for _, queued := range queuedCalls {
				userMessage, createErr := a.createUserMessage(genCtx, queued)
				if createErr != nil {
					return prepared, createErr
				}
				prepared.Messages = append(prepared.Messages, userMessage.ToAIMessage()...)
			}

			lastSystemRoleIdx := 0
			systemMessageUpdated := false
			for i, msg := range prepared.Messages {
				// add cache control only to the last system message
				if msg.Role == ai.MessageRoleSystem {
					lastSystemRoleIdx = i
				} else if !systemMessageUpdated {
					prepared.Messages[lastSystemRoleIdx].ProviderOptions = a.getCacheControlOptions()
					systemMessageUpdated = true
				}
				// then add cache control to the last 2 messages
				if i > len(prepared.Messages)-3 {
					prepared.Messages[i].ProviderOptions = a.getCacheControlOptions()
				}
			}
			return prepared, err
		},
		OnReasoningDelta: func(id string, text string) error {
			currentAssistant.AppendReasoningContent(text)
			return a.messages.Update(genCtx, *currentAssistant)
		},
		OnReasoningEnd: func(id string, reasoning ai.ReasoningContent) error {
			// handle anthropic signature
			if anthropicData, ok := reasoning.ProviderMetadata[anthropic.Name]; ok {
				if reasoning, ok := anthropicData.(*anthropic.ReasoningOptionMetadata); ok {
					currentAssistant.AppendReasoningSignature(reasoning.Signature)
				}
			}
			currentAssistant.FinishThinking()
			return a.messages.Update(genCtx, *currentAssistant)
		},
		OnTextDelta: func(id string, text string) error {
			currentAssistant.AppendContent(text)
			return a.messages.Update(genCtx, *currentAssistant)
		},
		OnToolInputStart: func(id string, toolName string) error {
			toolCall := message.ToolCall{
				ID:               id,
				Name:             toolName,
				ProviderExecuted: false,
				Finished:         false,
			}
			currentAssistant.AddToolCall(toolCall)
			return a.messages.Update(genCtx, *currentAssistant)
		},
		OnRetry: func(err *ai.APICallError, delay time.Duration) {
			// TODO: implement
		},
		OnToolResult: func(result ai.ToolResultContent) error {
			var resultContent string
			isError := false
			switch result.Result.GetType() {
			case ai.ToolResultContentTypeText:
				r, ok := ai.AsToolResultOutputType[ai.ToolResultOutputContentText](result.Result)
				if ok {
					resultContent = r.Text
				}
			case ai.ToolResultContentTypeError:
				r, ok := ai.AsToolResultOutputType[ai.ToolResultOutputContentError](result.Result)
				if ok {
					isError = true
					resultContent = r.Error.Error()
				}
			case ai.ToolResultContentTypeMedia:
				// TODO: handle this message type
			}
			toolResult := message.ToolResult{
				ToolCallID: result.ToolCallID,
				Name:       result.ToolName,
				Content:    resultContent,
				IsError:    isError,
				Metadata:   result.ClientMetadata,
			}
			// persist the tool result with a background context so it survives
			// cancellation of genCtx
			if _, createErr := a.messages.Create(context.Background(), currentAssistant.SessionID, message.CreateMessageParams{
				Role: message.Tool,
				Parts: []message.ContentPart{
					toolResult,
				},
			}); createErr != nil {
				return createErr
			}
			return a.messages.Update(genCtx, *currentAssistant)
		},
		OnStepFinish: func(stepResult ai.StepResult) error {
			finishReason := message.FinishReasonUnknown
			switch stepResult.FinishReason {
			case ai.FinishReasonLength:
				finishReason = message.FinishReasonMaxTokens
			case ai.FinishReasonStop:
				finishReason = message.FinishReasonEndTurn
			case ai.FinishReasonToolCalls:
				finishReason = message.FinishReasonToolUse
			}
			currentAssistant.AddFinish(finishReason, "", "")
			a.updateSessionUsage(a.largeModel, &currentSession, stepResult.Usage)
			return a.messages.Update(genCtx, *currentAssistant)
		},
	})
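	// If streaming failed, close out any unfinished tool calls with a synthetic
	// error result and finish the assistant message with a reason matching what
	// happened (cancellation, permission denial, or a generic error).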
	if err != nil {
		isCancelErr := errors.Is(err, context.Canceled)
		isPermissionErr := errors.Is(err, permission.ErrorPermissionDenied)
		if currentAssistant == nil {
			return result, err
		}
		toolCalls := currentAssistant.ToolCalls()
		toolResults := currentAssistant.ToolResults()
		for _, tc := range toolCalls {
			if !tc.Finished {
				tc.Finished = true
				tc.Input = "{}"
			}
			currentAssistant.AddToolCall(tc)
			found := false
			for _, tr := range toolResults {
				if tr.ToolCallID == tc.ID {
					found = true
					break
				}
			}
			if !found {
				content := "There was an error while executing the tool"
				if isCancelErr {
					content = "Tool execution canceled by user"
				} else if isPermissionErr {
					content = "Permission denied"
				}
				currentAssistant.AddToolResult(message.ToolResult{
					ToolCallID: tc.ID,
					Name:       tc.Name,
					Content:    content,
					IsError:    true,
				})
			}
		}
		if isCancelErr {
			currentAssistant.AddFinish(message.FinishReasonCanceled, "Request cancelled", "")
		} else if isPermissionErr {
			currentAssistant.AddFinish(message.FinishReasonPermissionDenied, "Permission denied", "")
		} else {
			currentAssistant.AddFinish(message.FinishReasonError, "API Error", err.Error())
		}
		// INFO: we use the parent context here because the genCtx might have been cancelled
		updateErr := a.messages.Update(ctx, *currentAssistant)
		if updateErr != nil {
			return nil, updateErr
		}
	}
	if err != nil {
		return nil, err
	}
	wg.Wait()

	queuedMessages, ok := a.messageQueue.Get(call.SessionID)
	if !ok || len(queuedMessages) == 0 {
		return result, err
	}
	// There are queued messages; restart the loop with the first one.
	firstQueuedMessage := queuedMessages[0]
	a.messageQueue.Set(call.SessionID, queuedMessages[1:])
	return a.Run(genCtx, firstQueuedMessage)
}

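// Summarize generates a summary of the session's conversation using the large
// model, stores it as an assistant message, and records it as the session's
// summary message so later runs only include history from that point on.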
func (a *sessionAgent) Summarize(ctx context.Context, sessionID string) error {
	if a.IsSessionBusy(sessionID) {
		return ErrSessionBusy
	}

	currentSession, err := a.sessions.Get(ctx, sessionID)
	if err != nil {
		return fmt.Errorf("failed to get session: %w", err)
	}
	msgs, err := a.getSessionMessages(ctx, currentSession)
	if err != nil {
		return err
	}
	if len(msgs) == 0 {
		// nothing to summarize
		return nil
	}

	aiMsgs, _ := a.preparePrompt(msgs)

	genCtx, cancel := context.WithCancel(ctx)
	a.activeRequests.Set(sessionID, cancel)
	defer a.activeRequests.Del(sessionID)
	defer cancel()

	agent := ai.NewAgent(a.largeModel.model,
		ai.WithSystemPrompt(string(summaryPrompt)),
	)
	summaryMessage, err := a.messages.Create(ctx, sessionID, message.CreateMessageParams{
		Role:     message.Assistant,
		Model:    a.largeModel.model.Model(),
		Provider: a.largeModel.model.Provider(),
	})
	if err != nil {
		return err
	}

	// Stream with the cancellable context so Cancel(sessionID) can stop the summary.
	resp, err := agent.Stream(genCtx, ai.AgentStreamCall{
		Prompt:   "Provide a detailed summary of our conversation above.",
		Messages: aiMsgs,
		OnReasoningDelta: func(id string, text string) error {
			summaryMessage.AppendReasoningContent(text)
			return a.messages.Update(ctx, summaryMessage)
		},
		OnReasoningEnd: func(id string, reasoning ai.ReasoningContent) error {
			// handle anthropic signature
			if anthropicData, ok := reasoning.ProviderMetadata[anthropic.Name]; ok {
				if signature, ok := anthropicData.(*anthropic.ReasoningOptionMetadata); ok && signature.Signature != "" {
					summaryMessage.AppendReasoningSignature(signature.Signature)
				}
			}
			summaryMessage.FinishThinking()
			return a.messages.Update(ctx, summaryMessage)
		},
		OnTextDelta: func(id, text string) error {
			summaryMessage.AppendContent(text)
			return a.messages.Update(ctx, summaryMessage)
		},
	})
	if err != nil {
		return err
	}

	summaryMessage.AddFinish(message.FinishReasonEndTurn, "", "")
	err = a.messages.Update(genCtx, summaryMessage)
	if err != nil {
		return err
	}

	a.updateSessionUsage(a.largeModel, &currentSession, resp.TotalUsage)

	// Reset the session token counts to reflect only the final summary response.
	usage := resp.Response.Usage
	currentSession.SummaryMessageID = summaryMessage.ID
	currentSession.CompletionTokens = usage.OutputTokens
	currentSession.PromptTokens = 0
	_, err = a.sessions.Save(genCtx, currentSession)
	return err
}

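// getCacheControlOptions returns provider options that mark content as
// cacheable ("ephemeral") for Anthropic prompt caching.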
func (a *sessionAgent) getCacheControlOptions() ai.ProviderOptions {
	return ai.ProviderOptions{
		anthropic.Name: &anthropic.ProviderCacheControlOptions{
			CacheControl: anthropic.CacheControl{Type: "ephemeral"},
		},
	}
}

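// createUserMessage persists the prompt and its attachments as a user message
// in the session.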
func (a *sessionAgent) createUserMessage(ctx context.Context, call SessionAgentCall) (message.Message, error) {
	var attachmentParts []message.ContentPart
	for _, attachment := range call.Attachments {
		attachmentParts = append(attachmentParts, message.BinaryContent{Path: attachment.FilePath, MIMEType: attachment.MimeType, Data: attachment.Content})
	}
	parts := []message.ContentPart{message.TextContent{Text: call.Prompt}}
	parts = append(parts, attachmentParts...)
	msg, err := a.messages.Create(ctx, call.SessionID, message.CreateMessageParams{
		Role:  message.User,
		Parts: parts,
	})
	if err != nil {
		return message.Message{}, fmt.Errorf("failed to create user message: %w", err)
	}
	return msg, nil
}

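// preparePrompt converts stored messages into the AI message history and turns
// attachments into file parts, skipping messages that carry no usable content.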
func (a *sessionAgent) preparePrompt(msgs []message.Message, attachments ...message.Attachment) ([]ai.Message, []ai.FilePart) {
	var history []ai.Message
	for _, m := range msgs {
		if len(m.Parts) == 0 {
			continue
		}
		// Assistant message without content or tool calls (cancelled before it returned anything)
		if m.Role == message.Assistant && len(m.ToolCalls()) == 0 && m.Content().Text == "" && m.ReasoningContent().String() == "" {
			continue
		}
		history = append(history, m.ToAIMessage()...)
	}

	var files []ai.FilePart
	for _, attachment := range attachments {
		files = append(files, ai.FilePart{
			Filename:  attachment.FileName,
			Data:      attachment.Content,
			MediaType: attachment.MimeType,
		})
	}

	return history, files
}

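// getSessionMessages lists the session's messages. If the session has been
// summarized, only the summary message and everything after it are returned,
// with the summary re-labeled as a user message.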
func (a *sessionAgent) getSessionMessages(ctx context.Context, session session.Session) ([]message.Message, error) {
	msgs, err := a.messages.List(ctx, session.ID)
	if err != nil {
		return nil, fmt.Errorf("failed to list messages: %w", err)
	}

	if session.SummaryMessageID != "" {
		summaryMsgIndex := -1
		for i, msg := range msgs {
			if msg.ID == session.SummaryMessageID {
				summaryMsgIndex = i
				break
			}
		}
		if summaryMsgIndex != -1 {
			msgs = msgs[summaryMsgIndex:]
			msgs[0].Role = message.User
		}
	}
	return msgs, nil
}

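// generateTitle asks the small model for a concise session title based on the
// first prompt and saves it along with the usage it incurred.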
func (a *sessionAgent) generateTitle(ctx context.Context, session session.Session, prompt string) {
	if prompt == "" {
		return
	}

	agent := ai.NewAgent(a.smallModel.model,
		ai.WithSystemPrompt(string(titlePrompt)),
		ai.WithMaxOutputTokens(40),
	)

	resp, err := agent.Stream(ctx, ai.AgentStreamCall{
		Prompt: fmt.Sprintf("Generate a concise title for the following content:\n\n%s", prompt),
	})
	if err != nil {
		slog.Error("error generating title", "err", err)
		return
	}

	title := resp.Response.Content.Text()

	title = strings.ReplaceAll(title, "\n", " ")

	// remove thinking tags if present
	if idx := strings.Index(title, "</think>"); idx > 0 {
		title = title[idx+len("</think>"):]
	}

	title = strings.TrimSpace(title)
	if title == "" {
		slog.Warn("failed to generate title", "warn", "empty title")
		return
	}

	session.Title = title
	a.updateSessionUsage(a.smallModel, &session, resp.TotalUsage)
	_, saveErr := a.sessions.Save(ctx, session)
	if saveErr != nil {
		slog.Error("failed to save session title & usage", "error", saveErr)
		return
	}
}

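// updateSessionUsage adds the cost of the given usage to the session and
// updates its prompt/completion token counts based on the model's pricing.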
func (a *sessionAgent) updateSessionUsage(model Model, session *session.Session, usage ai.Usage) {
	modelConfig := model.config
	cost := modelConfig.CostPer1MInCached/1e6*float64(usage.CacheCreationTokens) +
		modelConfig.CostPer1MOutCached/1e6*float64(usage.CacheReadTokens) +
		modelConfig.CostPer1MIn/1e6*float64(usage.InputTokens) +
		modelConfig.CostPer1MOut/1e6*float64(usage.OutputTokens)
	session.Cost += cost
	session.CompletionTokens = usage.OutputTokens + usage.CacheReadTokens
	session.PromptTokens = usage.InputTokens + usage.CacheCreationTokens
}

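// Cancel stops the active request for the session, if any, and drops any
// prompts still queued for it.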
func (a *sessionAgent) Cancel(sessionID string) {
	// Cancel regular requests
	if cancel, ok := a.activeRequests.Take(sessionID); ok && cancel != nil {
		slog.Info("Request cancellation initiated", "session_id", sessionID)
		cancel()
	}

	// Also check for summarize requests
	if cancel, ok := a.activeRequests.Take(sessionID + "-summarize"); ok && cancel != nil {
		slog.Info("Summarize cancellation initiated", "session_id", sessionID)
		cancel()
	}

	if a.QueuedPrompts(sessionID) > 0 {
		slog.Info("Clearing queued prompts", "session_id", sessionID)
		a.messageQueue.Del(sessionID)
	}
}

func (a *sessionAgent) ClearQueue(sessionID string) {
	if a.QueuedPrompts(sessionID) > 0 {
		slog.Info("Clearing queued prompts", "session_id", sessionID)
		a.messageQueue.Del(sessionID)
	}
}

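// CancelAll cancels every active request and waits up to five seconds for all
// of them to wind down.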
func (a *sessionAgent) CancelAll() {
	if !a.IsBusy() {
		return
	}
	for key := range a.activeRequests.Seq2() {
		a.Cancel(key) // key is sessionID
	}

	timeout := time.After(5 * time.Second)
	for a.IsBusy() {
		select {
		case <-timeout:
			return
		default:
			time.Sleep(200 * time.Millisecond)
		}
	}
}

func (a *sessionAgent) IsBusy() bool {
	var busy bool
	for cancelFunc := range a.activeRequests.Seq() {
		if cancelFunc != nil {
			busy = true
			break
		}
	}
	return busy
}

func (a *sessionAgent) IsSessionBusy(sessionID string) bool {
	_, busy := a.activeRequests.Get(sessionID)
	return busy
}

func (a *sessionAgent) QueuedPrompts(sessionID string) int {
	l, ok := a.messageQueue.Get(sessionID)
	if !ok {
		return 0
	}
	return len(l)
}

func (a *sessionAgent) SetModels(large Model, small Model) {
	a.largeModel = large
	a.smallModel = small
}

func (a *sessionAgent) SetTools(tools []ai.AgentTool) {
	a.tools = tools
}