1package agent
2
3import (
4 "context"
5 _ "embed"
6 "errors"
7 "fmt"
8 "log/slog"
9 "strings"
10 "sync"
11 "time"
12
13 "github.com/charmbracelet/catwalk/pkg/catwalk"
14 "github.com/charmbracelet/crush/internal/agent/tools"
15 "github.com/charmbracelet/crush/internal/config"
16 "github.com/charmbracelet/crush/internal/csync"
17 "github.com/charmbracelet/crush/internal/message"
18 "github.com/charmbracelet/crush/internal/permission"
19 "github.com/charmbracelet/crush/internal/session"
20 "github.com/charmbracelet/fantasy/ai"
21 "github.com/charmbracelet/fantasy/anthropic"
22)
23
// titlePrompt is the embedded system prompt used to generate a short
// session title from the user's first message (see generateTitle).
//
//go:embed templates/title.md
var titlePrompt []byte

// summaryPrompt is the embedded system prompt used to summarize a
// session's conversation history (see Summarize).
//
//go:embed templates/summary.md
var summaryPrompt []byte
29
// SessionAgentCall describes a single prompt to run against a session,
// along with optional sampling parameters forwarded to the provider.
type SessionAgentCall struct {
	SessionID        string               // target session; must be non-empty
	Prompt           string               // user prompt text; must be non-empty
	ProviderOptions  ai.ProviderOptions   // provider-specific options passed through to the stream call
	Attachments      []message.Attachment // files sent along with the prompt
	MaxOutputTokens  int64                // maximum tokens the model may emit
	Temperature      *float64             // nil leaves the provider default
	TopP             *float64             // nil leaves the provider default
	TopK             *int64               // nil leaves the provider default
	FrequencyPenalty *float64             // nil leaves the provider default
	PresencePenalty  *float64             // nil leaves the provider default
}
42
// SessionAgent runs prompts against chat sessions, persisting messages
// and usage as it goes. Prompts submitted while a session is busy are
// queued and replayed when the active run completes.
type SessionAgent interface {
	// Run executes (or queues, if the session is busy) a single prompt.
	// A nil result with a nil error means the call was queued.
	Run(context.Context, SessionAgentCall) (*ai.AgentResult, error)
	// SetModels swaps the large (conversation) and small (utility) models.
	SetModels(large Model, small Model)
	// SetTools replaces the tool set offered to the model.
	SetTools(tools []ai.AgentTool)
	// Cancel aborts the active request for a session, if any.
	Cancel(sessionID string)
	// CancelAll aborts every active request across all sessions.
	CancelAll()
	// IsSessionBusy reports whether a session has an active request.
	IsSessionBusy(sessionID string) bool
	// IsBusy reports whether any session has an active request.
	IsBusy() bool
	// QueuedPrompts returns the number of prompts queued for a session.
	QueuedPrompts(sessionID string) int
	// ClearQueue drops all queued prompts for a session.
	ClearQueue(sessionID string)
	// Summarize generates and stores a summary message for the session.
	Summarize(context.Context, string) error
	// Model returns the currently configured large model.
	Model() Model
}
56
// Model bundles a usable language model with its catwalk catalog entry
// and the user's selected-model configuration.
type Model struct {
	Model      ai.LanguageModel     // model used for inference
	CatwalkCfg catwalk.Model        // catalog metadata (per-1M-token costs, etc.)
	ModelCfg   config.SelectedModel // user selection (model and provider IDs)
}
62
// sessionAgent is the default SessionAgent implementation. Per-session
// runtime state lives in the two concurrent maps below.
type sessionAgent struct {
	largeModel   Model  // main model for conversation turns and summaries
	smallModel   Model  // utility model, used for title generation
	systemPrompt string // system prompt for conversation turns
	tools        []ai.AgentTool
	sessions     session.Service
	messages     message.Service

	// messageQueue holds prompts submitted while a session was busy,
	// keyed by session ID.
	messageQueue *csync.Map[string, []SessionAgentCall]
	// activeRequests maps a session ID to the cancel func of its
	// in-flight request; presence of a key marks the session as busy.
	activeRequests *csync.Map[string, context.CancelFunc]
}
74
// SessionAgentOption configures a sessionAgent. Reserved for functional
// options; NewSessionAgent does not currently accept any.
type SessionAgentOption func(*sessionAgent)
76
77func NewSessionAgent(
78 largeModel Model,
79 smallModel Model,
80 systemPrompt string,
81 sessions session.Service,
82 messages message.Service,
83 tools ...ai.AgentTool,
84) SessionAgent {
85 return &sessionAgent{
86 largeModel: largeModel,
87 smallModel: smallModel,
88 systemPrompt: systemPrompt,
89 sessions: sessions,
90 messages: messages,
91 tools: tools,
92 messageQueue: csync.NewMap[string, []SessionAgentCall](),
93 activeRequests: csync.NewMap[string, context.CancelFunc](),
94 }
95}
96
97func (a *sessionAgent) Run(ctx context.Context, call SessionAgentCall) (*ai.AgentResult, error) {
98 if call.Prompt == "" {
99 return nil, ErrEmptyPrompt
100 }
101 if call.SessionID == "" {
102 return nil, ErrSessionMissing
103 }
104
105 // Queue the message if busy
106 if a.IsSessionBusy(call.SessionID) {
107 existing, ok := a.messageQueue.Get(call.SessionID)
108 if !ok {
109 existing = []SessionAgentCall{}
110 }
111 existing = append(existing, call)
112 a.messageQueue.Set(call.SessionID, existing)
113 return nil, nil
114 }
115
116 if len(a.tools) > 0 {
117 // add anthropic caching to the last tool
118 a.tools[len(a.tools)-1].SetProviderOptions(a.getCacheControlOptions())
119 }
120
121 agent := ai.NewAgent(
122 a.largeModel.Model,
123 ai.WithSystemPrompt(a.systemPrompt),
124 ai.WithTools(a.tools...),
125 )
126
127 currentSession, err := a.sessions.Get(ctx, call.SessionID)
128 if err != nil {
129 return nil, fmt.Errorf("failed to get session: %w", err)
130 }
131
132 msgs, err := a.getSessionMessages(ctx, currentSession)
133 if err != nil {
134 return nil, fmt.Errorf("failed to get session messages: %w", err)
135 }
136
137 var wg sync.WaitGroup
138 // Generate title if first message
139 if len(msgs) == 0 {
140 wg.Go(func() {
141 a.generateTitle(ctx, currentSession, call.Prompt)
142 })
143 }
144
145 // Add the user message to the session
146 _, err = a.createUserMessage(ctx, call)
147 if err != nil {
148 return nil, err
149 }
150
151 // add the session to the context
152 ctx = context.WithValue(ctx, tools.SessionIDContextKey, call.SessionID)
153
154 genCtx, cancel := context.WithCancel(ctx)
155 a.activeRequests.Set(call.SessionID, cancel)
156
157 defer cancel()
158 defer a.activeRequests.Del(call.SessionID)
159
160 history, files := a.preparePrompt(msgs, call.Attachments...)
161
162 var currentAssistant *message.Message
163 result, err := agent.Stream(genCtx, ai.AgentStreamCall{
164 Prompt: call.Prompt,
165 Files: files,
166 Messages: history,
167 ProviderOptions: call.ProviderOptions,
168 MaxOutputTokens: &call.MaxOutputTokens,
169 TopP: call.TopP,
170 Temperature: call.Temperature,
171 PresencePenalty: call.PresencePenalty,
172 TopK: call.TopK,
173 FrequencyPenalty: call.FrequencyPenalty,
174 // Before each step create the new assistant message
175 PrepareStep: func(options ai.PrepareStepFunctionOptions) (prepared ai.PrepareStepResult, err error) {
176 var assistantMsg message.Message
177 assistantMsg, err = a.messages.Create(genCtx, call.SessionID, message.CreateMessageParams{
178 Role: message.Assistant,
179 Parts: []message.ContentPart{},
180 Model: a.largeModel.ModelCfg.Model,
181 Provider: a.largeModel.ModelCfg.Provider,
182 })
183 if err != nil {
184 return prepared, err
185 }
186
187 currentAssistant = &assistantMsg
188
189 prepared.Messages = options.Messages
190 // reset all cached items
191 for i := range prepared.Messages {
192 prepared.Messages[i].ProviderOptions = nil
193 }
194
195 queuedCalls, _ := a.messageQueue.Get(call.SessionID)
196 a.messageQueue.Del(call.SessionID)
197 for _, queued := range queuedCalls {
198 userMessage, createErr := a.createUserMessage(genCtx, queued)
199 if createErr != nil {
200 return prepared, createErr
201 }
202 prepared.Messages = append(prepared.Messages, userMessage.ToAIMessage()...)
203 }
204
205 lastSystemRoleInx := 0
206 systemMessageUpdated := false
207 for i, msg := range prepared.Messages {
208 // only add cache control to the last message
209 if msg.Role == ai.MessageRoleSystem {
210 lastSystemRoleInx = i
211 } else if !systemMessageUpdated {
212 prepared.Messages[lastSystemRoleInx].ProviderOptions = a.getCacheControlOptions()
213 systemMessageUpdated = true
214 }
215 // than add cache control to the last 2 messages
216 if i > len(prepared.Messages)-3 {
217 prepared.Messages[i].ProviderOptions = a.getCacheControlOptions()
218 }
219 }
220 return prepared, err
221 },
222 OnReasoningDelta: func(id string, text string) error {
223 currentAssistant.AppendReasoningContent(text)
224 return a.messages.Update(genCtx, *currentAssistant)
225 },
226 OnReasoningEnd: func(id string, reasoning ai.ReasoningContent) error {
227 // handle anthropic signature
228 if anthropicData, ok := reasoning.ProviderMetadata[anthropic.Name]; ok {
229 if reasoning, ok := anthropicData.(*anthropic.ReasoningOptionMetadata); ok {
230 currentAssistant.AppendReasoningSignature(reasoning.Signature)
231 }
232 }
233 currentAssistant.FinishThinking()
234 return a.messages.Update(genCtx, *currentAssistant)
235 },
236 OnTextDelta: func(id string, text string) error {
237 currentAssistant.AppendContent(text)
238 return a.messages.Update(genCtx, *currentAssistant)
239 },
240 OnToolInputStart: func(id string, toolName string) error {
241 toolCall := message.ToolCall{
242 ID: id,
243 Name: toolName,
244 ProviderExecuted: false,
245 Finished: false,
246 }
247 currentAssistant.AddToolCall(toolCall)
248 return a.messages.Update(genCtx, *currentAssistant)
249 },
250 OnRetry: func(err *ai.APICallError, delay time.Duration) {
251 // TODO: implement
252 },
253 OnToolCall: func(tc ai.ToolCallContent) error {
254 toolCall := message.ToolCall{
255 ID: tc.ToolCallID,
256 Name: tc.ToolName,
257 Input: tc.Input,
258 ProviderExecuted: false,
259 Finished: true,
260 }
261 currentAssistant.AddToolCall(toolCall)
262 return a.messages.Update(genCtx, *currentAssistant)
263 },
264 OnToolResult: func(result ai.ToolResultContent) error {
265 var resultContent string
266 isError := false
267 switch result.Result.GetType() {
268 case ai.ToolResultContentTypeText:
269 r, ok := ai.AsToolResultOutputType[ai.ToolResultOutputContentText](result.Result)
270 if ok {
271 resultContent = r.Text
272 }
273 case ai.ToolResultContentTypeError:
274 r, ok := ai.AsToolResultOutputType[ai.ToolResultOutputContentError](result.Result)
275 if ok {
276 isError = true
277 resultContent = r.Error.Error()
278 }
279 case ai.ToolResultContentTypeMedia:
280 // TODO: handle this message type
281 }
282 toolResult := message.ToolResult{
283 ToolCallID: result.ToolCallID,
284 Name: result.ToolName,
285 Content: resultContent,
286 IsError: isError,
287 Metadata: result.ClientMetadata,
288 }
289 a.messages.Create(context.Background(), currentAssistant.SessionID, message.CreateMessageParams{
290 Role: message.Tool,
291 Parts: []message.ContentPart{
292 toolResult,
293 },
294 })
295 return a.messages.Update(genCtx, *currentAssistant)
296 },
297 OnStepFinish: func(stepResult ai.StepResult) error {
298 finishReason := message.FinishReasonUnknown
299 switch stepResult.FinishReason {
300 case ai.FinishReasonLength:
301 finishReason = message.FinishReasonMaxTokens
302 case ai.FinishReasonStop:
303 finishReason = message.FinishReasonEndTurn
304 case ai.FinishReasonToolCalls:
305 finishReason = message.FinishReasonToolUse
306 }
307 currentAssistant.AddFinish(finishReason, "", "")
308 a.updateSessionUsage(a.largeModel, ¤tSession, stepResult.Usage)
309 return a.messages.Update(genCtx, *currentAssistant)
310 },
311 })
312 if err != nil {
313 isCancelErr := errors.Is(err, context.Canceled)
314 isPermissionErr := errors.Is(err, permission.ErrorPermissionDenied)
315 if currentAssistant == nil {
316 return result, err
317 }
318 toolCalls := currentAssistant.ToolCalls()
319 toolResults := currentAssistant.ToolResults()
320 for _, tc := range toolCalls {
321 if !tc.Finished {
322 tc.Finished = true
323 tc.Input = "{}"
324 }
325 currentAssistant.AddToolCall(tc)
326 found := false
327 for _, tr := range toolResults {
328 if tr.ToolCallID == tc.ID {
329 found = true
330 break
331 }
332 }
333 if !found {
334 content := "There was an error while executing the tool"
335 if isCancelErr {
336 content = "Tool execution canceled by user"
337 } else if isPermissionErr {
338 content = "Permission denied"
339 }
340 currentAssistant.AddToolResult(message.ToolResult{
341 ToolCallID: tc.ID,
342 Name: tc.Name,
343 Content: content,
344 IsError: true,
345 })
346 }
347 }
348 if isCancelErr {
349 currentAssistant.AddFinish(message.FinishReasonCanceled, "Request cancelled", "")
350 } else if isPermissionErr {
351 currentAssistant.AddFinish(message.FinishReasonPermissionDenied, "Permission denied", "")
352 } else {
353 currentAssistant.AddFinish(message.FinishReasonError, "API Error", err.Error())
354 }
355 // INFO: we use the parent context here because the genCtx might have been cancelled
356 updateErr := a.messages.Update(ctx, *currentAssistant)
357 if updateErr != nil {
358 return nil, updateErr
359 }
360 }
361 if err != nil {
362 return nil, err
363 }
364 wg.Wait()
365
366 queuedMessages, ok := a.messageQueue.Get(call.SessionID)
367 if !ok || len(queuedMessages) == 0 {
368 return result, err
369 }
370 // there are queued messages restart the loop
371 firstQueuedMessage := queuedMessages[0]
372 a.messageQueue.Set(call.SessionID, queuedMessages[1:])
373 return a.Run(genCtx, firstQueuedMessage)
374}
375
376func (a *sessionAgent) Summarize(ctx context.Context, sessionID string) error {
377 if a.IsSessionBusy(sessionID) {
378 return ErrSessionBusy
379 }
380
381 currentSession, err := a.sessions.Get(ctx, sessionID)
382 if err != nil {
383 return fmt.Errorf("failed to get session: %w", err)
384 }
385 msgs, err := a.getSessionMessages(ctx, currentSession)
386 if err != nil {
387 return err
388 }
389 if len(msgs) == 0 {
390 // nothing to summarize
391 return nil
392 }
393
394 aiMsgs, _ := a.preparePrompt(msgs)
395
396 genCtx, cancel := context.WithCancel(ctx)
397 a.activeRequests.Set(sessionID, cancel)
398 defer a.activeRequests.Del(sessionID)
399 defer cancel()
400
401 agent := ai.NewAgent(a.largeModel.Model,
402 ai.WithSystemPrompt(string(summaryPrompt)),
403 )
404 summaryMessage, err := a.messages.Create(ctx, sessionID, message.CreateMessageParams{
405 Role: message.Assistant,
406 Model: a.largeModel.Model.Model(),
407 Provider: a.largeModel.Model.Provider(),
408 })
409 if err != nil {
410 return err
411 }
412
413 resp, err := agent.Stream(ctx, ai.AgentStreamCall{
414 Prompt: "Provide a detailed summary of our conversation above.",
415 Messages: aiMsgs,
416 OnReasoningDelta: func(id string, text string) error {
417 summaryMessage.AppendReasoningContent(text)
418 return a.messages.Update(ctx, summaryMessage)
419 },
420 OnReasoningEnd: func(id string, reasoning ai.ReasoningContent) error {
421 // handle anthropic signature
422 if anthropicData, ok := reasoning.ProviderMetadata["anthropic"]; ok {
423 if signature, ok := anthropicData.(*anthropic.ReasoningOptionMetadata); ok && signature.Signature != "" {
424 summaryMessage.AppendReasoningSignature(signature.Signature)
425 }
426 }
427 summaryMessage.FinishThinking()
428 return a.messages.Update(ctx, summaryMessage)
429 },
430 OnTextDelta: func(id, text string) error {
431 summaryMessage.AppendContent(text)
432 return a.messages.Update(ctx, summaryMessage)
433 },
434 })
435 if err != nil {
436 return err
437 }
438
439 summaryMessage.AddFinish(message.FinishReasonEndTurn, "", "")
440 err = a.messages.Update(genCtx, summaryMessage)
441 if err != nil {
442 return err
443 }
444
445 a.updateSessionUsage(a.largeModel, ¤tSession, resp.TotalUsage)
446
447 // just in case get just the last usage
448 usage := resp.Response.Usage
449 currentSession.SummaryMessageID = summaryMessage.ID
450 currentSession.CompletionTokens = usage.OutputTokens
451 currentSession.PromptTokens = 0
452 _, err = a.sessions.Save(genCtx, currentSession)
453 return err
454}
455
456func (a *sessionAgent) getCacheControlOptions() ai.ProviderOptions {
457 return ai.ProviderOptions{
458 anthropic.Name: &anthropic.ProviderCacheControlOptions{
459 CacheControl: anthropic.CacheControl{Type: "ephemeral"},
460 },
461 }
462}
463
464func (a *sessionAgent) createUserMessage(ctx context.Context, call SessionAgentCall) (message.Message, error) {
465 var attachmentParts []message.ContentPart
466 for _, attachment := range call.Attachments {
467 attachmentParts = append(attachmentParts, message.BinaryContent{Path: attachment.FilePath, MIMEType: attachment.MimeType, Data: attachment.Content})
468 }
469 parts := []message.ContentPart{message.TextContent{Text: call.Prompt}}
470 parts = append(parts, attachmentParts...)
471 msg, err := a.messages.Create(ctx, call.SessionID, message.CreateMessageParams{
472 Role: message.User,
473 Parts: parts,
474 })
475 if err != nil {
476 return message.Message{}, fmt.Errorf("failed to create user message: %w", err)
477 }
478 return msg, nil
479}
480
481func (a *sessionAgent) preparePrompt(msgs []message.Message, attachments ...message.Attachment) ([]ai.Message, []ai.FilePart) {
482 var history []ai.Message
483 for _, m := range msgs {
484 if len(m.Parts) == 0 {
485 continue
486 }
487 // Assistant message without content or tool calls (cancelled before it returned anything)
488 if m.Role == message.Assistant && len(m.ToolCalls()) == 0 && m.Content().Text == "" && m.ReasoningContent().String() == "" {
489 continue
490 }
491 history = append(history, m.ToAIMessage()...)
492 }
493
494 var files []ai.FilePart
495 for _, attachment := range attachments {
496 files = append(files, ai.FilePart{
497 Filename: attachment.FileName,
498 Data: attachment.Content,
499 MediaType: attachment.MimeType,
500 })
501 }
502
503 return history, files
504}
505
506func (a *sessionAgent) getSessionMessages(ctx context.Context, session session.Session) ([]message.Message, error) {
507 msgs, err := a.messages.List(ctx, session.ID)
508 if err != nil {
509 return nil, fmt.Errorf("failed to list messages: %w", err)
510 }
511
512 if session.SummaryMessageID != "" {
513 summaryMsgInex := -1
514 for i, msg := range msgs {
515 if msg.ID == session.SummaryMessageID {
516 summaryMsgInex = i
517 break
518 }
519 }
520 if summaryMsgInex != -1 {
521 msgs = msgs[summaryMsgInex:]
522 msgs[0].Role = message.User
523 }
524 }
525 return msgs, nil
526}
527
528func (a *sessionAgent) generateTitle(ctx context.Context, session session.Session, prompt string) {
529 if prompt == "" {
530 return
531 }
532
533 agent := ai.NewAgent(a.smallModel.Model,
534 ai.WithSystemPrompt(string(titlePrompt)),
535 ai.WithMaxOutputTokens(40),
536 )
537
538 resp, err := agent.Stream(ctx, ai.AgentStreamCall{
539 Prompt: fmt.Sprintf("Generate a concise title for the following content:\n\n%s", prompt),
540 })
541 if err != nil {
542 slog.Error("error generating title", "err", err)
543 return
544 }
545
546 title := resp.Response.Content.Text()
547
548 title = strings.ReplaceAll(title, "\n", " ")
549
550 // remove thinking tags if present
551 if idx := strings.Index(title, "</think>"); idx > 0 {
552 title = title[idx+len("</think>"):]
553 }
554
555 title = strings.TrimSpace(title)
556 if title == "" {
557 slog.Warn("failed to generate title", "warn", "empty title")
558 return
559 }
560
561 session.Title = title
562 a.updateSessionUsage(a.smallModel, &session, resp.TotalUsage)
563 _, saveErr := a.sessions.Save(ctx, session)
564 if saveErr != nil {
565 slog.Error("failed to save session title & usage", "error", saveErr)
566 return
567 }
568}
569
570func (a *sessionAgent) updateSessionUsage(model Model, session *session.Session, usage ai.Usage) {
571 modelConfig := model.CatwalkCfg
572 cost := modelConfig.CostPer1MInCached/1e6*float64(usage.CacheCreationTokens) +
573 modelConfig.CostPer1MOutCached/1e6*float64(usage.CacheReadTokens) +
574 modelConfig.CostPer1MIn/1e6*float64(usage.InputTokens) +
575 modelConfig.CostPer1MOut/1e6*float64(usage.OutputTokens)
576 session.Cost += cost
577 session.CompletionTokens = usage.OutputTokens + usage.CacheReadTokens
578 session.PromptTokens = usage.InputTokens + usage.CacheCreationTokens
579}
580
581func (a *sessionAgent) Cancel(sessionID string) {
582 // Cancel regular requests
583 if cancel, ok := a.activeRequests.Take(sessionID); ok && cancel != nil {
584 slog.Info("Request cancellation initiated", "session_id", sessionID)
585 cancel()
586 }
587
588 // Also check for summarize requests
589 if cancel, ok := a.activeRequests.Take(sessionID + "-summarize"); ok && cancel != nil {
590 slog.Info("Summarize cancellation initiated", "session_id", sessionID)
591 cancel()
592 }
593
594 if a.QueuedPrompts(sessionID) > 0 {
595 slog.Info("Clearing queued prompts", "session_id", sessionID)
596 a.messageQueue.Del(sessionID)
597 }
598}
599
600func (a *sessionAgent) ClearQueue(sessionID string) {
601 if a.QueuedPrompts(sessionID) > 0 {
602 slog.Info("Clearing queued prompts", "session_id", sessionID)
603 a.messageQueue.Del(sessionID)
604 }
605}
606
607func (a *sessionAgent) CancelAll() {
608 if !a.IsBusy() {
609 return
610 }
611 for key := range a.activeRequests.Seq2() {
612 a.Cancel(key) // key is sessionID
613 }
614
615 timeout := time.After(5 * time.Second)
616 for a.IsBusy() {
617 select {
618 case <-timeout:
619 return
620 default:
621 time.Sleep(200 * time.Millisecond)
622 }
623 }
624}
625
626func (a *sessionAgent) IsBusy() bool {
627 var busy bool
628 for cancelFunc := range a.activeRequests.Seq() {
629 if cancelFunc != nil {
630 busy = true
631 break
632 }
633 }
634 return busy
635}
636
637func (a *sessionAgent) IsSessionBusy(sessionID string) bool {
638 _, busy := a.activeRequests.Get(sessionID)
639 return busy
640}
641
642func (a *sessionAgent) QueuedPrompts(sessionID string) int {
643 l, ok := a.messageQueue.Get(sessionID)
644 if !ok {
645 return 0
646 }
647 return len(l)
648}
649
// SetModels swaps the large and small models used for subsequent runs.
// NOTE(review): assignment is not synchronized with in-flight runs —
// confirm callers only invoke this while the agent is idle.
func (a *sessionAgent) SetModels(large Model, small Model) {
	a.largeModel = large
	a.smallModel = small
}
654
// SetTools replaces the tool set offered to the model on future runs.
// NOTE(review): not synchronized with in-flight runs — confirm callers
// only invoke this while the agent is idle.
func (a *sessionAgent) SetTools(tools []ai.AgentTool) {
	a.tools = tools
}
658
// Model returns the currently configured large (conversation) model.
func (a *sessionAgent) Model() Model {
	return a.largeModel
}