package agent

import (
	"context"
	_ "embed"
	"errors"
	"fmt"
	"log/slog"
	"strings"
	"sync"
	"time"

	"github.com/charmbracelet/catwalk/pkg/catwalk"
	"github.com/charmbracelet/crush/internal/agent/tools"
	"github.com/charmbracelet/crush/internal/config"
	"github.com/charmbracelet/crush/internal/csync"
	"github.com/charmbracelet/crush/internal/message"
	"github.com/charmbracelet/crush/internal/permission"
	"github.com/charmbracelet/crush/internal/session"
	"github.com/charmbracelet/fantasy/ai"
	"github.com/charmbracelet/fantasy/anthropic"
)

//go:embed templates/title.md
var titlePrompt []byte

//go:embed templates/summary.md
var summaryPrompt []byte

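// SessionAgentCall describes a single prompt sent to a SessionAgent: the
// target session, the prompt text and attachments, and optional provider
// options and sampling parameters.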
type SessionAgentCall struct {
	SessionID        string
	Prompt           string
	ProviderOptions  ai.ProviderOptions
	Attachments      []message.Attachment
	MaxOutputTokens  int64
	Temperature      *float64
	TopP             *float64
	TopK             *int64
	FrequencyPenalty *float64
	PresencePenalty  *float64
}

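// SessionAgent runs prompts against sessions, queueing calls while a session
// is busy, and exposes helpers to cancel, summarize, and inspect in-flight
// work.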
type SessionAgent interface {
	Run(context.Context, SessionAgentCall) (*ai.AgentResult, error)
	SetModels(large Model, small Model)
	SetTools(tools []ai.AgentTool)
	Cancel(sessionID string)
	CancelAll()
	IsSessionBusy(sessionID string) bool
	IsBusy() bool
	QueuedPrompts(sessionID string) int
	ClearQueue(sessionID string)
	Summarize(context.Context, string) error
	Model() Model
}

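// Model bundles the runtime language model with its catwalk metadata and the
// user-selected model configuration.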
type Model struct {
	Model      ai.LanguageModel
	CatwalkCfg catwalk.Model
	ModelCfg   config.SelectedModel
}

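// sessionAgent is the default SessionAgent implementation: the large model
// drives the main conversation while the small model is used for title
// generation.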
type sessionAgent struct {
	largeModel   Model
	smallModel   Model
	systemPrompt string
	tools        []ai.AgentTool
	sessions     session.Service
	messages     message.Service

	messageQueue   *csync.Map[string, []SessionAgentCall]
	activeRequests *csync.Map[string, context.CancelFunc]
}

type SessionAgentOption func(*sessionAgent)

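// NewSessionAgent creates a SessionAgent with the given models, system
// prompt, services, and tools.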
func NewSessionAgent(
	largeModel Model,
	smallModel Model,
	systemPrompt string,
	sessions session.Service,
	messages message.Service,
	tools ...ai.AgentTool,
) SessionAgent {
	return &sessionAgent{
		largeModel:     largeModel,
		smallModel:     smallModel,
		systemPrompt:   systemPrompt,
		sessions:       sessions,
		messages:       messages,
		tools:          tools,
		messageQueue:   csync.NewMap[string, []SessionAgentCall](),
		activeRequests: csync.NewMap[string, context.CancelFunc](),
	}
}

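// Run executes a single call against a session. If the session is already
// busy the call is queued and (nil, nil) is returned; otherwise the agent
// streams the response, persisting assistant and tool messages as they
// arrive, and re-runs any prompts queued in the meantime once it finishes.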
func (a *sessionAgent) Run(ctx context.Context, call SessionAgentCall) (*ai.AgentResult, error) {
	if call.Prompt == "" {
		return nil, ErrEmptyPrompt
	}
	if call.SessionID == "" {
		return nil, ErrSessionMissing
	}

	// Queue the message if busy
	if a.IsSessionBusy(call.SessionID) {
		existing, ok := a.messageQueue.Get(call.SessionID)
		if !ok {
			existing = []SessionAgentCall{}
		}
		existing = append(existing, call)
		a.messageQueue.Set(call.SessionID, existing)
		return nil, nil
	}

	if len(a.tools) > 0 {
		// add anthropic caching to the last tool
		a.tools[len(a.tools)-1].SetProviderOptions(a.getCacheControlOptions())
	}

	agent := ai.NewAgent(
		a.largeModel.Model,
		ai.WithSystemPrompt(a.systemPrompt),
		ai.WithTools(a.tools...),
	)

	sessionLock := sync.Mutex{}
	currentSession, err := a.sessions.Get(ctx, call.SessionID)
	if err != nil {
		return nil, fmt.Errorf("failed to get session: %w", err)
	}

	msgs, err := a.getSessionMessages(ctx, currentSession)
	if err != nil {
		return nil, fmt.Errorf("failed to get session messages: %w", err)
	}

	var wg sync.WaitGroup
	// Generate title if first message
	if len(msgs) == 0 {
		wg.Go(func() {
			sessionLock.Lock()
			a.generateTitle(ctx, &currentSession, call.Prompt)
			sessionLock.Unlock()
		})
	}

	// Add the user message to the session
	_, err = a.createUserMessage(ctx, call)
	if err != nil {
		return nil, err
	}

	// add the session to the context
	ctx = context.WithValue(ctx, tools.SessionIDContextKey, call.SessionID)

	genCtx, cancel := context.WithCancel(ctx)
	a.activeRequests.Set(call.SessionID, cancel)

	defer cancel()
	defer a.activeRequests.Del(call.SessionID)

	history, files := a.preparePrompt(msgs, call.Attachments...)

	var currentAssistant *message.Message
	result, err := agent.Stream(genCtx, ai.AgentStreamCall{
		Prompt:           call.Prompt,
		Files:            files,
		Messages:         history,
		ProviderOptions:  call.ProviderOptions,
		MaxOutputTokens:  &call.MaxOutputTokens,
		TopP:             call.TopP,
		Temperature:      call.Temperature,
		PresencePenalty:  call.PresencePenalty,
		TopK:             call.TopK,
		FrequencyPenalty: call.FrequencyPenalty,
		// Before each step, create the new assistant message
		PrepareStep: func(callContext context.Context, options ai.PrepareStepFunctionOptions) (_ context.Context, prepared ai.PrepareStepResult, err error) {
			var assistantMsg message.Message
			assistantMsg, err = a.messages.Create(genCtx, call.SessionID, message.CreateMessageParams{
				Role:     message.Assistant,
				Parts:    []message.ContentPart{},
				Model:    a.largeModel.ModelCfg.Model,
				Provider: a.largeModel.ModelCfg.Provider,
			})
			if err != nil {
				return callContext, prepared, err
			}

			callContext = context.WithValue(ctx, tools.MessageIDContextKey, assistantMsg.ID)

			currentAssistant = &assistantMsg

			prepared.Messages = options.Messages
			// reset all cached items
			for i := range prepared.Messages {
				prepared.Messages[i].ProviderOptions = nil
			}

			queuedCalls, _ := a.messageQueue.Get(call.SessionID)
			a.messageQueue.Del(call.SessionID)
			for _, queued := range queuedCalls {
				userMessage, createErr := a.createUserMessage(genCtx, queued)
				if createErr != nil {
					return callContext, prepared, createErr
				}
				prepared.Messages = append(prepared.Messages, userMessage.ToAIMessage()...)
			}

			lastSystemRoleInx := 0
			systemMessageUpdated := false
			for i, msg := range prepared.Messages {
				// only add cache control to the last system message
				if msg.Role == ai.MessageRoleSystem {
					lastSystemRoleInx = i
				} else if !systemMessageUpdated {
					prepared.Messages[lastSystemRoleInx].ProviderOptions = a.getCacheControlOptions()
					systemMessageUpdated = true
				}
				// then add cache control to the last 2 messages
				if i > len(prepared.Messages)-3 {
					prepared.Messages[i].ProviderOptions = a.getCacheControlOptions()
				}
			}
			return callContext, prepared, err
		},
		OnReasoningDelta: func(id string, text string) error {
			currentAssistant.AppendReasoningContent(text)
			return a.messages.Update(genCtx, *currentAssistant)
		},
		OnReasoningEnd: func(id string, reasoning ai.ReasoningContent) error {
			// handle anthropic signature
			if anthropicData, ok := reasoning.ProviderMetadata[anthropic.Name]; ok {
				if reasoning, ok := anthropicData.(*anthropic.ReasoningOptionMetadata); ok {
					currentAssistant.AppendReasoningSignature(reasoning.Signature)
				}
			}
			currentAssistant.FinishThinking()
			return a.messages.Update(genCtx, *currentAssistant)
		},
		OnTextDelta: func(id string, text string) error {
			currentAssistant.AppendContent(text)
			return a.messages.Update(genCtx, *currentAssistant)
		},
		OnToolInputStart: func(id string, toolName string) error {
			toolCall := message.ToolCall{
				ID:               id,
				Name:             toolName,
				ProviderExecuted: false,
				Finished:         false,
			}
			currentAssistant.AddToolCall(toolCall)
			return a.messages.Update(genCtx, *currentAssistant)
		},
		OnRetry: func(err *ai.APICallError, delay time.Duration) {
			// TODO: implement
		},
		OnToolCall: func(tc ai.ToolCallContent) error {
			toolCall := message.ToolCall{
				ID:               tc.ToolCallID,
				Name:             tc.ToolName,
				Input:            tc.Input,
				ProviderExecuted: false,
				Finished:         true,
			}
			currentAssistant.AddToolCall(toolCall)
			return a.messages.Update(genCtx, *currentAssistant)
		},
		OnToolResult: func(result ai.ToolResultContent) error {
			var resultContent string
			isError := false
			switch result.Result.GetType() {
			case ai.ToolResultContentTypeText:
				r, ok := ai.AsToolResultOutputType[ai.ToolResultOutputContentText](result.Result)
				if ok {
					resultContent = r.Text
				}
			case ai.ToolResultContentTypeError:
				r, ok := ai.AsToolResultOutputType[ai.ToolResultOutputContentError](result.Result)
				if ok {
					isError = true
					resultContent = r.Error.Error()
				}
			case ai.ToolResultContentTypeMedia:
				// TODO: handle this message type
			}
			toolResult := message.ToolResult{
				ToolCallID: result.ToolCallID,
				Name:       result.ToolName,
				Content:    resultContent,
				IsError:    isError,
				Metadata:   result.ClientMetadata,
			}
			a.messages.Create(context.Background(), currentAssistant.SessionID, message.CreateMessageParams{
				Role: message.Tool,
				Parts: []message.ContentPart{
					toolResult,
				},
			})
			return a.messages.Update(genCtx, *currentAssistant)
		},
		OnStepFinish: func(stepResult ai.StepResult) error {
			finishReason := message.FinishReasonUnknown
			switch stepResult.FinishReason {
			case ai.FinishReasonLength:
				finishReason = message.FinishReasonMaxTokens
			case ai.FinishReasonStop:
				finishReason = message.FinishReasonEndTurn
			case ai.FinishReasonToolCalls:
				finishReason = message.FinishReasonToolUse
			}
			currentAssistant.AddFinish(finishReason, "", "")
			a.updateSessionUsage(a.largeModel, &currentSession, stepResult.Usage)
			sessionLock.Lock()
			_, sessionErr := a.sessions.Save(genCtx, currentSession)
			sessionLock.Unlock()
			if sessionErr != nil {
				return sessionErr
			}
			return a.messages.Update(genCtx, *currentAssistant)
		},
	})
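	// On failure, close out any unfinished tool calls with a synthetic error
	// (or cancellation) result and record the finish reason so the stored
	// conversation stays consistent.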
	if err != nil {
		isCancelErr := errors.Is(err, context.Canceled)
		isPermissionErr := errors.Is(err, permission.ErrorPermissionDenied)
		if currentAssistant == nil {
			return result, err
		}
		toolCalls := currentAssistant.ToolCalls()
		toolResults := currentAssistant.ToolResults()
		for _, tc := range toolCalls {
			if !tc.Finished {
				tc.Finished = true
				tc.Input = "{}"
			}
			currentAssistant.AddToolCall(tc)
			found := false
			for _, tr := range toolResults {
				if tr.ToolCallID == tc.ID {
					found = true
					break
				}
			}
			if !found {
				content := "There was an error while executing the tool"
				if isCancelErr {
					content = "Tool execution canceled by user"
				} else if isPermissionErr {
					content = "Permission denied"
				}
				currentAssistant.AddToolResult(message.ToolResult{
					ToolCallID: tc.ID,
					Name:       tc.Name,
					Content:    content,
					IsError:    true,
				})
			}
		}
		if isCancelErr {
			currentAssistant.AddFinish(message.FinishReasonCanceled, "Request cancelled", "")
		} else if isPermissionErr {
			currentAssistant.AddFinish(message.FinishReasonPermissionDenied, "Permission denied", "")
		} else {
			currentAssistant.AddFinish(message.FinishReasonError, "API Error", err.Error())
		}
		// INFO: we use the parent context here because genCtx might have been cancelled
		updateErr := a.messages.Update(ctx, *currentAssistant)
		if updateErr != nil {
			return nil, updateErr
		}
	}
	if err != nil {
		return nil, err
	}
	wg.Wait()

	queuedMessages, ok := a.messageQueue.Get(call.SessionID)
	if !ok || len(queuedMessages) == 0 {
		return result, err
	}
	// there are queued messages; restart the loop
	firstQueuedMessage := queuedMessages[0]
	a.messageQueue.Set(call.SessionID, queuedMessages[1:])
	return a.Run(genCtx, firstQueuedMessage)
}

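// Summarize generates a summary of the session's conversation with the large
// model, stores it as an assistant message, and marks it as the session's
// summary message so later runs start from it. It returns ErrSessionBusy if
// the session has an active request.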
func (a *sessionAgent) Summarize(ctx context.Context, sessionID string) error {
	if a.IsSessionBusy(sessionID) {
		return ErrSessionBusy
	}

	currentSession, err := a.sessions.Get(ctx, sessionID)
	if err != nil {
		return fmt.Errorf("failed to get session: %w", err)
	}
	msgs, err := a.getSessionMessages(ctx, currentSession)
	if err != nil {
		return err
	}
	if len(msgs) == 0 {
		// nothing to summarize
		return nil
	}

	aiMsgs, _ := a.preparePrompt(msgs)

	genCtx, cancel := context.WithCancel(ctx)
	a.activeRequests.Set(sessionID, cancel)
	defer a.activeRequests.Del(sessionID)
	defer cancel()

	agent := ai.NewAgent(a.largeModel.Model,
		ai.WithSystemPrompt(string(summaryPrompt)),
	)
	summaryMessage, err := a.messages.Create(ctx, sessionID, message.CreateMessageParams{
		Role:     message.Assistant,
		Model:    a.largeModel.Model.Model(),
		Provider: a.largeModel.Model.Provider(),
	})
	if err != nil {
		return err
	}

	resp, err := agent.Stream(ctx, ai.AgentStreamCall{
		Prompt:   "Provide a detailed summary of our conversation above.",
		Messages: aiMsgs,
		OnReasoningDelta: func(id string, text string) error {
			summaryMessage.AppendReasoningContent(text)
			return a.messages.Update(ctx, summaryMessage)
		},
		OnReasoningEnd: func(id string, reasoning ai.ReasoningContent) error {
			// handle anthropic signature
			if anthropicData, ok := reasoning.ProviderMetadata["anthropic"]; ok {
				if signature, ok := anthropicData.(*anthropic.ReasoningOptionMetadata); ok && signature.Signature != "" {
					summaryMessage.AppendReasoningSignature(signature.Signature)
				}
			}
			summaryMessage.FinishThinking()
			return a.messages.Update(ctx, summaryMessage)
		},
		OnTextDelta: func(id, text string) error {
			summaryMessage.AppendContent(text)
			return a.messages.Update(ctx, summaryMessage)
		},
	})
	if err != nil {
		return err
	}

	summaryMessage.AddFinish(message.FinishReasonEndTurn, "", "")
	err = a.messages.Update(genCtx, summaryMessage)
	if err != nil {
		return err
	}

	a.updateSessionUsage(a.largeModel, &currentSession, resp.TotalUsage)

	// just in case, record only the last response's usage
	usage := resp.Response.Usage
	currentSession.SummaryMessageID = summaryMessage.ID
	currentSession.CompletionTokens = usage.OutputTokens
	currentSession.PromptTokens = 0
	_, err = a.sessions.Save(genCtx, currentSession)
	return err
}

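// getCacheControlOptions returns the Anthropic provider options that mark
// content with ephemeral cache control for prompt caching.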
func (a *sessionAgent) getCacheControlOptions() ai.ProviderOptions {
	return ai.ProviderOptions{
		anthropic.Name: &anthropic.ProviderCacheControlOptions{
			CacheControl: anthropic.CacheControl{Type: "ephemeral"},
		},
	}
}

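// createUserMessage persists the prompt and its attachments as a user message
// in the session.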
func (a *sessionAgent) createUserMessage(ctx context.Context, call SessionAgentCall) (message.Message, error) {
	var attachmentParts []message.ContentPart
	for _, attachment := range call.Attachments {
		attachmentParts = append(attachmentParts, message.BinaryContent{Path: attachment.FilePath, MIMEType: attachment.MimeType, Data: attachment.Content})
	}
	parts := []message.ContentPart{message.TextContent{Text: call.Prompt}}
	parts = append(parts, attachmentParts...)
	msg, err := a.messages.Create(ctx, call.SessionID, message.CreateMessageParams{
		Role:  message.User,
		Parts: parts,
	})
	if err != nil {
		return message.Message{}, fmt.Errorf("failed to create user message: %w", err)
	}
	return msg, nil
}

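// preparePrompt converts stored messages into AI messages, skipping assistant
// messages that were cancelled before producing any content, and turns
// attachments into file parts.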
func (a *sessionAgent) preparePrompt(msgs []message.Message, attachments ...message.Attachment) ([]ai.Message, []ai.FilePart) {
	var history []ai.Message
	for _, m := range msgs {
		if len(m.Parts) == 0 {
			continue
		}
		// Assistant message without content or tool calls (cancelled before it returned anything)
		if m.Role == message.Assistant && len(m.ToolCalls()) == 0 && m.Content().Text == "" && m.ReasoningContent().String() == "" {
			continue
		}
		history = append(history, m.ToAIMessage()...)
	}

	var files []ai.FilePart
	for _, attachment := range attachments {
		files = append(files, ai.FilePart{
			Filename:  attachment.FileName,
			Data:      attachment.Content,
			MediaType: attachment.MimeType,
		})
	}

	return history, files
}

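// getSessionMessages lists the session's messages. If the session has a
// summary message, only the summary and the messages after it are returned,
// with the summary re-cast as a user message.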
func (a *sessionAgent) getSessionMessages(ctx context.Context, session session.Session) ([]message.Message, error) {
	msgs, err := a.messages.List(ctx, session.ID)
	if err != nil {
		return nil, fmt.Errorf("failed to list messages: %w", err)
	}

	if session.SummaryMessageID != "" {
		summaryMsgIndex := -1
		for i, msg := range msgs {
			if msg.ID == session.SummaryMessageID {
				summaryMsgIndex = i
				break
			}
		}
		if summaryMsgIndex != -1 {
			msgs = msgs[summaryMsgIndex:]
			msgs[0].Role = message.User
		}
	}
	return msgs, nil
}

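// generateTitle asks the small model for a concise session title based on the
// first prompt and saves it, along with the usage it incurred, to the session.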
func (a *sessionAgent) generateTitle(ctx context.Context, session *session.Session, prompt string) {
	if prompt == "" {
		return
	}

	agent := ai.NewAgent(a.smallModel.Model,
		ai.WithSystemPrompt(string(titlePrompt)),
		ai.WithMaxOutputTokens(40),
	)

	resp, err := agent.Stream(ctx, ai.AgentStreamCall{
		Prompt: fmt.Sprintf("Generate a concise title for the following content:\n\n%s", prompt),
	})
	if err != nil {
		slog.Error("error generating title", "err", err)
		return
	}

	title := resp.Response.Content.Text()

	title = strings.ReplaceAll(title, "\n", " ")

	// remove thinking tags if present
	if idx := strings.Index(title, "</think>"); idx > 0 {
		title = title[idx+len("</think>"):]
	}

	title = strings.TrimSpace(title)
	if title == "" {
		slog.Warn("failed to generate title", "warn", "empty title")
		return
	}

	session.Title = title
	a.updateSessionUsage(a.smallModel, session, resp.TotalUsage)
	_, saveErr := a.sessions.Save(ctx, *session)
	if saveErr != nil {
		slog.Error("failed to save session title & usage", "error", saveErr)
		return
	}
}

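// updateSessionUsage adds the cost of the given usage to the session, using
// the model's per-1M-token prices, and records the latest prompt and
// completion token counts including cache creation and cache read tokens.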
func (a *sessionAgent) updateSessionUsage(model Model, session *session.Session, usage ai.Usage) {
	modelConfig := model.CatwalkCfg
	cost := modelConfig.CostPer1MInCached/1e6*float64(usage.CacheCreationTokens) +
		modelConfig.CostPer1MOutCached/1e6*float64(usage.CacheReadTokens) +
		modelConfig.CostPer1MIn/1e6*float64(usage.InputTokens) +
		modelConfig.CostPer1MOut/1e6*float64(usage.OutputTokens)
	session.Cost += cost
	session.CompletionTokens = usage.OutputTokens + usage.CacheReadTokens
	session.PromptTokens = usage.InputTokens + usage.CacheCreationTokens
}

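// Cancel aborts the session's active request (and any in-flight summarize
// request) and drops its queued prompts.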
func (a *sessionAgent) Cancel(sessionID string) {
	// Cancel regular requests
	if cancel, ok := a.activeRequests.Take(sessionID); ok && cancel != nil {
		slog.Info("Request cancellation initiated", "session_id", sessionID)
		cancel()
	}

	// Also check for summarize requests
	if cancel, ok := a.activeRequests.Take(sessionID + "-summarize"); ok && cancel != nil {
		slog.Info("Summarize cancellation initiated", "session_id", sessionID)
		cancel()
	}

	if a.QueuedPrompts(sessionID) > 0 {
		slog.Info("Clearing queued prompts", "session_id", sessionID)
		a.messageQueue.Del(sessionID)
	}
}

func (a *sessionAgent) ClearQueue(sessionID string) {
	if a.QueuedPrompts(sessionID) > 0 {
		slog.Info("Clearing queued prompts", "session_id", sessionID)
		a.messageQueue.Del(sessionID)
	}
}

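// CancelAll cancels every active request and waits up to five seconds for
// them to wind down.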
func (a *sessionAgent) CancelAll() {
	if !a.IsBusy() {
		return
	}
	for key := range a.activeRequests.Seq2() {
		a.Cancel(key) // key is sessionID
	}

	timeout := time.After(5 * time.Second)
	for a.IsBusy() {
		select {
		case <-timeout:
			return
		default:
			time.Sleep(200 * time.Millisecond)
		}
	}
}

func (a *sessionAgent) IsBusy() bool {
	var busy bool
	for cancelFunc := range a.activeRequests.Seq() {
		if cancelFunc != nil {
			busy = true
			break
		}
	}
	return busy
}

func (a *sessionAgent) IsSessionBusy(sessionID string) bool {
	_, busy := a.activeRequests.Get(sessionID)
	return busy
}

func (a *sessionAgent) QueuedPrompts(sessionID string) int {
	l, ok := a.messageQueue.Get(sessionID)
	if !ok {
		return 0
	}
	return len(l)
}

func (a *sessionAgent) SetModels(large Model, small Model) {
	a.largeModel = large
	a.smallModel = small
}

func (a *sessionAgent) SetTools(tools []ai.AgentTool) {
	a.tools = tools
}

func (a *sessionAgent) Model() Model {
	return a.largeModel
}