1package agent
2
3import (
4 "context"
5 _ "embed"
6 "errors"
7 "fmt"
8 "log/slog"
9 "strings"
10 "sync"
11 "time"
12
13 "github.com/charmbracelet/catwalk/pkg/catwalk"
14 "github.com/charmbracelet/crush/internal/agent/tools"
15 "github.com/charmbracelet/crush/internal/config"
16 "github.com/charmbracelet/crush/internal/csync"
17 "github.com/charmbracelet/crush/internal/message"
18 "github.com/charmbracelet/crush/internal/permission"
19 "github.com/charmbracelet/crush/internal/session"
20 "github.com/charmbracelet/fantasy/ai"
21 "github.com/charmbracelet/fantasy/anthropic"
22)
23
// titlePrompt is the system prompt used to generate session titles,
// embedded from templates/title.md at build time.
//
//go:embed templates/title.md
var titlePrompt []byte

// summaryPrompt is the system prompt used to summarize conversations,
// embedded from templates/summary.md at build time.
//
//go:embed templates/summary.md
var summaryPrompt []byte
29
// SessionAgentCall describes a single prompt to run against a session,
// together with optional attachments and generation parameters. Pointer
// fields are optional; nil means "use the provider default".
type SessionAgentCall struct {
	SessionID        string                // target session; required
	Prompt           string                // user prompt text; required
	ProviderOptions  ai.ProviderOptions    // provider-specific options forwarded to the stream call
	Attachments      []message.Attachment  // binary attachments sent alongside the prompt
	MaxOutputTokens  int64                 // maximum tokens the model may emit
	Temperature      *float64              // optional sampling temperature
	TopP             *float64              // optional nucleus sampling parameter
	TopK             *int64                // optional top-k sampling parameter
	FrequencyPenalty *float64              // optional frequency penalty
	PresencePenalty  *float64              // optional presence penalty
}
42
// SessionAgent runs AI agent turns scoped to chat sessions. It queues
// prompts for busy sessions, supports cancellation per session or globally,
// and can summarize a session's conversation.
type SessionAgent interface {
	// Run executes a prompt for a session; if the session is busy the call
	// is queued and (nil, nil) is returned.
	Run(context.Context, SessionAgentCall) (*ai.AgentResult, error)
	// SetModels swaps the large (main) and small (utility) models.
	SetModels(large Model, small Model)
	// SetTools replaces the tool set offered to the agent.
	SetTools(tools []ai.AgentTool)
	// Cancel aborts the active request (and any summarize) for a session
	// and drops its queued prompts.
	Cancel(sessionID string)
	// CancelAll cancels every active request and waits (bounded) for idle.
	CancelAll()
	// IsSessionBusy reports whether a request is in flight for the session.
	IsSessionBusy(sessionID string) bool
	// IsBusy reports whether any session has a request in flight.
	IsBusy() bool
	// QueuedPrompts returns the number of prompts queued for the session.
	QueuedPrompts(sessionID string) int
	// ClearQueue drops all queued prompts for the session.
	ClearQueue(sessionID string)
	// Summarize generates and persists a conversation summary message.
	Summarize(context.Context, string) error
	// Model returns the currently configured large model.
	Model() Model
}
56
// Model bundles a language model with its catwalk metadata (pricing, limits)
// and the user-selected configuration.
type Model struct {
	Model      ai.LanguageModel     // the underlying language model client
	CatwalkCfg catwalk.Model        // catalog metadata, e.g. per-token costs
	ModelCfg   config.SelectedModel // user configuration (model id, provider)
}
62
// sessionAgent is the default SessionAgent implementation. It tracks one
// active request per session and queues additional prompts until the
// active request finishes.
type sessionAgent struct {
	largeModel   Model          // main model used for agent turns and summaries
	smallModel   Model          // cheaper model used for title generation
	systemPrompt string         // system prompt for agent turns
	tools        []ai.AgentTool // tools offered to the agent
	sessions     session.Service
	messages     message.Service

	// messageQueue holds prompts that arrived while their session was busy,
	// keyed by session ID.
	messageQueue *csync.Map[string, []SessionAgentCall]
	// activeRequests maps a session ID to the cancel func of its in-flight
	// request; presence of a key means the session is busy.
	activeRequests *csync.Map[string, context.CancelFunc]
}

// SessionAgentOption configures a sessionAgent (reserved for future options).
type SessionAgentOption func(*sessionAgent)
76
77func NewSessionAgent(
78 largeModel Model,
79 smallModel Model,
80 systemPrompt string,
81 sessions session.Service,
82 messages message.Service,
83 tools ...ai.AgentTool,
84) SessionAgent {
85 return &sessionAgent{
86 largeModel: largeModel,
87 smallModel: smallModel,
88 systemPrompt: systemPrompt,
89 sessions: sessions,
90 messages: messages,
91 tools: tools,
92 messageQueue: csync.NewMap[string, []SessionAgentCall](),
93 activeRequests: csync.NewMap[string, context.CancelFunc](),
94 }
95}
96
97func (a *sessionAgent) Run(ctx context.Context, call SessionAgentCall) (*ai.AgentResult, error) {
98 if call.Prompt == "" {
99 return nil, ErrEmptyPrompt
100 }
101 if call.SessionID == "" {
102 return nil, ErrSessionMissing
103 }
104
105 // Queue the message if busy
106 if a.IsSessionBusy(call.SessionID) {
107 existing, ok := a.messageQueue.Get(call.SessionID)
108 if !ok {
109 existing = []SessionAgentCall{}
110 }
111 existing = append(existing, call)
112 a.messageQueue.Set(call.SessionID, existing)
113 return nil, nil
114 }
115
116 if len(a.tools) > 0 {
117 // add anthropic caching to the last tool
118 a.tools[len(a.tools)-1].SetProviderOptions(a.getCacheControlOptions())
119 }
120
121 agent := ai.NewAgent(
122 a.largeModel.Model,
123 ai.WithSystemPrompt(a.systemPrompt),
124 ai.WithTools(a.tools...),
125 )
126
127 currentSession, err := a.sessions.Get(ctx, call.SessionID)
128 if err != nil {
129 return nil, fmt.Errorf("failed to get session: %w", err)
130 }
131
132 msgs, err := a.getSessionMessages(ctx, currentSession)
133 if err != nil {
134 return nil, fmt.Errorf("failed to get session messages: %w", err)
135 }
136
137 var wg sync.WaitGroup
138 // Generate title if first message
139 if len(msgs) == 0 {
140 wg.Go(func() {
141 a.generateTitle(ctx, currentSession, call.Prompt)
142 })
143 }
144
145 // Add the user message to the session
146 _, err = a.createUserMessage(ctx, call)
147 if err != nil {
148 return nil, err
149 }
150
151 // add the session to the context
152 ctx = context.WithValue(ctx, tools.SessionIDContextKey, call.SessionID)
153
154 genCtx, cancel := context.WithCancel(ctx)
155 a.activeRequests.Set(call.SessionID, cancel)
156
157 defer cancel()
158 defer a.activeRequests.Del(call.SessionID)
159
160 history, files := a.preparePrompt(msgs, call.Attachments...)
161
162 var currentAssistant *message.Message
163 result, err := agent.Stream(genCtx, ai.AgentStreamCall{
164 Prompt: call.Prompt,
165 Files: files,
166 Messages: history,
167 ProviderOptions: call.ProviderOptions,
168 MaxOutputTokens: &call.MaxOutputTokens,
169 TopP: call.TopP,
170 Temperature: call.Temperature,
171 PresencePenalty: call.PresencePenalty,
172 TopK: call.TopK,
173 FrequencyPenalty: call.FrequencyPenalty,
174 // Before each step create the new assistant message
175 PrepareStep: func(options ai.PrepareStepFunctionOptions) (prepared ai.PrepareStepResult, err error) {
176 var assistantMsg message.Message
177 assistantMsg, err = a.messages.Create(genCtx, call.SessionID, message.CreateMessageParams{
178 Role: message.Assistant,
179 Parts: []message.ContentPart{},
180 Model: a.largeModel.ModelCfg.Model,
181 Provider: a.largeModel.ModelCfg.Provider,
182 })
183 if err != nil {
184 return prepared, err
185 }
186
187 currentAssistant = &assistantMsg
188
189 prepared.Messages = options.Messages
190 // reset all cached items
191 for i := range prepared.Messages {
192 prepared.Messages[i].ProviderOptions = nil
193 }
194
195 queuedCalls, _ := a.messageQueue.Get(call.SessionID)
196 a.messageQueue.Del(call.SessionID)
197 for _, queued := range queuedCalls {
198 userMessage, createErr := a.createUserMessage(genCtx, queued)
199 if createErr != nil {
200 return prepared, createErr
201 }
202 prepared.Messages = append(prepared.Messages, userMessage.ToAIMessage()...)
203 }
204
205 lastSystemRoleInx := 0
206 systemMessageUpdated := false
207 for i, msg := range prepared.Messages {
208 // only add cache control to the last message
209 if msg.Role == ai.MessageRoleSystem {
210 lastSystemRoleInx = i
211 } else if !systemMessageUpdated {
212 prepared.Messages[lastSystemRoleInx].ProviderOptions = a.getCacheControlOptions()
213 systemMessageUpdated = true
214 }
215 // than add cache control to the last 2 messages
216 if i > len(prepared.Messages)-3 {
217 prepared.Messages[i].ProviderOptions = a.getCacheControlOptions()
218 }
219 }
220 return prepared, err
221 },
222 OnReasoningDelta: func(id string, text string) error {
223 currentAssistant.AppendReasoningContent(text)
224 return a.messages.Update(genCtx, *currentAssistant)
225 },
226 OnReasoningEnd: func(id string, reasoning ai.ReasoningContent) error {
227 // handle anthropic signature
228 if anthropicData, ok := reasoning.ProviderMetadata[anthropic.Name]; ok {
229 if reasoning, ok := anthropicData.(*anthropic.ReasoningOptionMetadata); ok {
230 currentAssistant.AppendReasoningSignature(reasoning.Signature)
231 }
232 }
233 currentAssistant.FinishThinking()
234 return a.messages.Update(genCtx, *currentAssistant)
235 },
236 OnTextDelta: func(id string, text string) error {
237 currentAssistant.AppendContent(text)
238 return a.messages.Update(genCtx, *currentAssistant)
239 },
240 OnToolInputStart: func(id string, toolName string) error {
241 toolCall := message.ToolCall{
242 ID: id,
243 Name: toolName,
244 ProviderExecuted: false,
245 Finished: false,
246 }
247 currentAssistant.AddToolCall(toolCall)
248 return a.messages.Update(genCtx, *currentAssistant)
249 },
250 OnRetry: func(err *ai.APICallError, delay time.Duration) {
251 // TODO: implement
252 },
253 OnToolResult: func(result ai.ToolResultContent) error {
254 var resultContent string
255 isError := false
256 switch result.Result.GetType() {
257 case ai.ToolResultContentTypeText:
258 r, ok := ai.AsToolResultOutputType[ai.ToolResultOutputContentText](result.Result)
259 if ok {
260 resultContent = r.Text
261 }
262 case ai.ToolResultContentTypeError:
263 r, ok := ai.AsToolResultOutputType[ai.ToolResultOutputContentError](result.Result)
264 if ok {
265 isError = true
266 resultContent = r.Error.Error()
267 }
268 case ai.ToolResultContentTypeMedia:
269 // TODO: handle this message type
270 }
271 toolResult := message.ToolResult{
272 ToolCallID: result.ToolCallID,
273 Name: result.ToolName,
274 Content: resultContent,
275 IsError: isError,
276 Metadata: result.ClientMetadata,
277 }
278 a.messages.Create(context.Background(), currentAssistant.SessionID, message.CreateMessageParams{
279 Role: message.Tool,
280 Parts: []message.ContentPart{
281 toolResult,
282 },
283 })
284 return a.messages.Update(genCtx, *currentAssistant)
285 },
286 OnStepFinish: func(stepResult ai.StepResult) error {
287 finishReason := message.FinishReasonUnknown
288 switch stepResult.FinishReason {
289 case ai.FinishReasonLength:
290 finishReason = message.FinishReasonMaxTokens
291 case ai.FinishReasonStop:
292 finishReason = message.FinishReasonEndTurn
293 case ai.FinishReasonToolCalls:
294 finishReason = message.FinishReasonToolUse
295 }
296 currentAssistant.AddFinish(finishReason, "", "")
297 a.updateSessionUsage(a.largeModel, ¤tSession, stepResult.Usage)
298 return a.messages.Update(genCtx, *currentAssistant)
299 },
300 })
301 if err != nil {
302 isCancelErr := errors.Is(err, context.Canceled)
303 isPermissionErr := errors.Is(err, permission.ErrorPermissionDenied)
304 if currentAssistant == nil {
305 return result, err
306 }
307 toolCalls := currentAssistant.ToolCalls()
308 toolResults := currentAssistant.ToolResults()
309 for _, tc := range toolCalls {
310 if !tc.Finished {
311 tc.Finished = true
312 tc.Input = "{}"
313 }
314 currentAssistant.AddToolCall(tc)
315 found := false
316 for _, tr := range toolResults {
317 if tr.ToolCallID == tc.ID {
318 found = true
319 break
320 }
321 }
322 if !found {
323 content := "There was an error while executing the tool"
324 if isCancelErr {
325 content = "Tool execution canceled by user"
326 } else if isPermissionErr {
327 content = "Permission denied"
328 }
329 currentAssistant.AddToolResult(message.ToolResult{
330 ToolCallID: tc.ID,
331 Name: tc.Name,
332 Content: content,
333 IsError: true,
334 })
335 }
336 }
337 if isCancelErr {
338 currentAssistant.AddFinish(message.FinishReasonCanceled, "Request cancelled", "")
339 } else if isPermissionErr {
340 currentAssistant.AddFinish(message.FinishReasonPermissionDenied, "Permission denied", "")
341 } else {
342 currentAssistant.AddFinish(message.FinishReasonError, "API Error", err.Error())
343 }
344 // INFO: we use the parent context here because the genCtx might have been cancelled
345 updateErr := a.messages.Update(ctx, *currentAssistant)
346 if updateErr != nil {
347 return nil, updateErr
348 }
349 }
350 if err != nil {
351 return nil, err
352 }
353 wg.Wait()
354
355 queuedMessages, ok := a.messageQueue.Get(call.SessionID)
356 if !ok || len(queuedMessages) == 0 {
357 return result, err
358 }
359 // there are queued messages restart the loop
360 firstQueuedMessage := queuedMessages[0]
361 a.messageQueue.Set(call.SessionID, queuedMessages[1:])
362 return a.Run(genCtx, firstQueuedMessage)
363}
364
365func (a *sessionAgent) Summarize(ctx context.Context, sessionID string) error {
366 if a.IsSessionBusy(sessionID) {
367 return ErrSessionBusy
368 }
369
370 currentSession, err := a.sessions.Get(ctx, sessionID)
371 if err != nil {
372 return fmt.Errorf("failed to get session: %w", err)
373 }
374 msgs, err := a.getSessionMessages(ctx, currentSession)
375 if err != nil {
376 return err
377 }
378 if len(msgs) == 0 {
379 // nothing to summarize
380 return nil
381 }
382
383 aiMsgs, _ := a.preparePrompt(msgs)
384
385 genCtx, cancel := context.WithCancel(ctx)
386 a.activeRequests.Set(sessionID, cancel)
387 defer a.activeRequests.Del(sessionID)
388 defer cancel()
389
390 agent := ai.NewAgent(a.largeModel.Model,
391 ai.WithSystemPrompt(string(summaryPrompt)),
392 )
393 summaryMessage, err := a.messages.Create(ctx, sessionID, message.CreateMessageParams{
394 Role: message.Assistant,
395 Model: a.largeModel.Model.Model(),
396 Provider: a.largeModel.Model.Provider(),
397 })
398 if err != nil {
399 return err
400 }
401
402 resp, err := agent.Stream(ctx, ai.AgentStreamCall{
403 Prompt: "Provide a detailed summary of our conversation above.",
404 Messages: aiMsgs,
405 OnReasoningDelta: func(id string, text string) error {
406 summaryMessage.AppendReasoningContent(text)
407 return a.messages.Update(ctx, summaryMessage)
408 },
409 OnReasoningEnd: func(id string, reasoning ai.ReasoningContent) error {
410 // handle anthropic signature
411 if anthropicData, ok := reasoning.ProviderMetadata["anthropic"]; ok {
412 if signature, ok := anthropicData.(*anthropic.ReasoningOptionMetadata); ok && signature.Signature != "" {
413 summaryMessage.AppendReasoningSignature(signature.Signature)
414 }
415 }
416 summaryMessage.FinishThinking()
417 return a.messages.Update(ctx, summaryMessage)
418 },
419 OnTextDelta: func(id, text string) error {
420 summaryMessage.AppendContent(text)
421 return a.messages.Update(ctx, summaryMessage)
422 },
423 })
424 if err != nil {
425 return err
426 }
427
428 summaryMessage.AddFinish(message.FinishReasonEndTurn, "", "")
429 err = a.messages.Update(genCtx, summaryMessage)
430 if err != nil {
431 return err
432 }
433
434 a.updateSessionUsage(a.largeModel, ¤tSession, resp.TotalUsage)
435
436 // just in case get just the last usage
437 usage := resp.Response.Usage
438 currentSession.SummaryMessageID = summaryMessage.ID
439 currentSession.CompletionTokens = usage.OutputTokens
440 currentSession.PromptTokens = 0
441 _, err = a.sessions.Save(genCtx, currentSession)
442 return err
443}
444
445func (a *sessionAgent) getCacheControlOptions() ai.ProviderOptions {
446 return ai.ProviderOptions{
447 anthropic.Name: &anthropic.ProviderCacheControlOptions{
448 CacheControl: anthropic.CacheControl{Type: "ephemeral"},
449 },
450 }
451}
452
453func (a *sessionAgent) createUserMessage(ctx context.Context, call SessionAgentCall) (message.Message, error) {
454 var attachmentParts []message.ContentPart
455 for _, attachment := range call.Attachments {
456 attachmentParts = append(attachmentParts, message.BinaryContent{Path: attachment.FilePath, MIMEType: attachment.MimeType, Data: attachment.Content})
457 }
458 parts := []message.ContentPart{message.TextContent{Text: call.Prompt}}
459 parts = append(parts, attachmentParts...)
460 msg, err := a.messages.Create(ctx, call.SessionID, message.CreateMessageParams{
461 Role: message.User,
462 Parts: parts,
463 })
464 if err != nil {
465 return message.Message{}, fmt.Errorf("failed to create user message: %w", err)
466 }
467 return msg, nil
468}
469
470func (a *sessionAgent) preparePrompt(msgs []message.Message, attachments ...message.Attachment) ([]ai.Message, []ai.FilePart) {
471 var history []ai.Message
472 for _, m := range msgs {
473 if len(m.Parts) == 0 {
474 continue
475 }
476 // Assistant message without content or tool calls (cancelled before it returned anything)
477 if m.Role == message.Assistant && len(m.ToolCalls()) == 0 && m.Content().Text == "" && m.ReasoningContent().String() == "" {
478 continue
479 }
480 history = append(history, m.ToAIMessage()...)
481 }
482
483 var files []ai.FilePart
484 for _, attachment := range attachments {
485 files = append(files, ai.FilePart{
486 Filename: attachment.FileName,
487 Data: attachment.Content,
488 MediaType: attachment.MimeType,
489 })
490 }
491
492 return history, files
493}
494
495func (a *sessionAgent) getSessionMessages(ctx context.Context, session session.Session) ([]message.Message, error) {
496 msgs, err := a.messages.List(ctx, session.ID)
497 if err != nil {
498 return nil, fmt.Errorf("failed to list messages: %w", err)
499 }
500
501 if session.SummaryMessageID != "" {
502 summaryMsgInex := -1
503 for i, msg := range msgs {
504 if msg.ID == session.SummaryMessageID {
505 summaryMsgInex = i
506 break
507 }
508 }
509 if summaryMsgInex != -1 {
510 msgs = msgs[summaryMsgInex:]
511 msgs[0].Role = message.User
512 }
513 }
514 return msgs, nil
515}
516
517func (a *sessionAgent) generateTitle(ctx context.Context, session session.Session, prompt string) {
518 if prompt == "" {
519 return
520 }
521
522 agent := ai.NewAgent(a.smallModel.Model,
523 ai.WithSystemPrompt(string(titlePrompt)),
524 ai.WithMaxOutputTokens(40),
525 )
526
527 resp, err := agent.Stream(ctx, ai.AgentStreamCall{
528 Prompt: fmt.Sprintf("Generate a concise title for the following content:\n\n%s", prompt),
529 })
530 if err != nil {
531 slog.Error("error generating title", "err", err)
532 return
533 }
534
535 title := resp.Response.Content.Text()
536
537 title = strings.ReplaceAll(title, "\n", " ")
538
539 // remove thinking tags if present
540 if idx := strings.Index(title, "</think>"); idx > 0 {
541 title = title[idx+len("</think>"):]
542 }
543
544 title = strings.TrimSpace(title)
545 if title == "" {
546 slog.Warn("failed to generate title", "warn", "empty title")
547 return
548 }
549
550 session.Title = title
551 a.updateSessionUsage(a.smallModel, &session, resp.TotalUsage)
552 _, saveErr := a.sessions.Save(ctx, session)
553 if saveErr != nil {
554 slog.Error("failed to save session title & usage", "error", saveErr)
555 return
556 }
557}
558
559func (a *sessionAgent) updateSessionUsage(model Model, session *session.Session, usage ai.Usage) {
560 modelConfig := model.CatwalkCfg
561 cost := modelConfig.CostPer1MInCached/1e6*float64(usage.CacheCreationTokens) +
562 modelConfig.CostPer1MOutCached/1e6*float64(usage.CacheReadTokens) +
563 modelConfig.CostPer1MIn/1e6*float64(usage.InputTokens) +
564 modelConfig.CostPer1MOut/1e6*float64(usage.OutputTokens)
565 session.Cost += cost
566 session.CompletionTokens = usage.OutputTokens + usage.CacheReadTokens
567 session.PromptTokens = usage.InputTokens + usage.CacheCreationTokens
568}
569
570func (a *sessionAgent) Cancel(sessionID string) {
571 // Cancel regular requests
572 if cancel, ok := a.activeRequests.Take(sessionID); ok && cancel != nil {
573 slog.Info("Request cancellation initiated", "session_id", sessionID)
574 cancel()
575 }
576
577 // Also check for summarize requests
578 if cancel, ok := a.activeRequests.Take(sessionID + "-summarize"); ok && cancel != nil {
579 slog.Info("Summarize cancellation initiated", "session_id", sessionID)
580 cancel()
581 }
582
583 if a.QueuedPrompts(sessionID) > 0 {
584 slog.Info("Clearing queued prompts", "session_id", sessionID)
585 a.messageQueue.Del(sessionID)
586 }
587}
588
589func (a *sessionAgent) ClearQueue(sessionID string) {
590 if a.QueuedPrompts(sessionID) > 0 {
591 slog.Info("Clearing queued prompts", "session_id", sessionID)
592 a.messageQueue.Del(sessionID)
593 }
594}
595
596func (a *sessionAgent) CancelAll() {
597 if !a.IsBusy() {
598 return
599 }
600 for key := range a.activeRequests.Seq2() {
601 a.Cancel(key) // key is sessionID
602 }
603
604 timeout := time.After(5 * time.Second)
605 for a.IsBusy() {
606 select {
607 case <-timeout:
608 return
609 default:
610 time.Sleep(200 * time.Millisecond)
611 }
612 }
613}
614
615func (a *sessionAgent) IsBusy() bool {
616 var busy bool
617 for cancelFunc := range a.activeRequests.Seq() {
618 if cancelFunc != nil {
619 busy = true
620 break
621 }
622 }
623 return busy
624}
625
// IsSessionBusy reports whether a request is currently in flight for the
// given session (i.e. a cancel func is registered for it).
func (a *sessionAgent) IsSessionBusy(sessionID string) bool {
	_, busy := a.activeRequests.Get(sessionID)
	return busy
}
630
631func (a *sessionAgent) QueuedPrompts(sessionID string) int {
632 l, ok := a.messageQueue.Get(sessionID)
633 if !ok {
634 return 0
635 }
636 return len(l)
637}
638
// SetModels replaces the large (main turn) and small (title generation)
// models. Not safe to call concurrently with an in-flight Run.
func (a *sessionAgent) SetModels(large Model, small Model) {
	a.largeModel = large
	a.smallModel = small
}
643
// SetTools replaces the tool set offered to the agent on subsequent runs.
func (a *sessionAgent) SetTools(tools []ai.AgentTool) {
	a.tools = tools
}
647
// Model returns the currently configured large model.
func (a *sessionAgent) Model() Model {
	return a.largeModel
}