package agent

import (
	"context"
	_ "embed"
	"errors"
	"fmt"
	"log/slog"
	"strings"
	"time"

	"github.com/charmbracelet/catwalk/pkg/catwalk"
	"github.com/charmbracelet/crush/internal/csync"
	"github.com/charmbracelet/crush/internal/llm/tools"
	"github.com/charmbracelet/crush/internal/message"
	"github.com/charmbracelet/crush/internal/permission"
	"github.com/charmbracelet/crush/internal/session"
	"github.com/charmbracelet/fantasy/ai"
	"github.com/charmbracelet/fantasy/anthropic"
)

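// titlePrompt is the embedded system prompt used to generate session titles.
//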
//go:embed templates/title.md
var titlePrompt []byte

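// summaryPrompt is the embedded system prompt used to summarize a session's
// conversation history.
//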
//go:embed templates/summary.md
var summaryPrompt []byte

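// SessionAgentCall describes a single prompt sent to a session, along with
// optional attachments and per-call sampling parameters.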
type SessionAgentCall struct {
	SessionID        string
	Prompt           string
	ProviderOptions  ai.ProviderOptions
	Attachments      []message.Attachment
	MaxOutputTokens  int64
	Temperature      *float64
	TopP             *float64
	TopK             *int64
	FrequencyPenalty *float64
	PresencePenalty  *float64
}

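// SessionAgent runs prompts against a session and streams results into the
// message store. Calls made while a session is busy are queued and replayed
// once the active request finishes.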
type SessionAgent interface {
	Run(context.Context, SessionAgentCall) (*ai.AgentResult, error)
	SetModels(large Model, small Model)
	SetTools(tools []ai.AgentTool)
	Cancel(sessionID string)
	CancelAll()
	IsSessionBusy(sessionID string) bool
	IsBusy() bool
	QueuedPrompts(sessionID string) int
	ClearQueue(sessionID string)
	Summarize(context.Context, string) error
}

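// Model pairs a language model with its catwalk configuration, which carries
// the pricing metadata used for usage tracking.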
type Model struct {
	model  ai.LanguageModel
	config catwalk.Model
}

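// sessionAgent is the default SessionAgent implementation. It tracks one
// active request per session and keeps a per-session queue of pending calls.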
type sessionAgent struct {
	largeModel      Model
	smallModel      Model
	systemPrompt    string
	tools           []ai.AgentTool
	maxOutputTokens int64
	sessions        session.Service
	messages        message.Service

	messageQueue   *csync.Map[string, []SessionAgentCall]
	activeRequests *csync.Map[string, context.CancelFunc]
}

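// SessionAgentOption configures a sessionAgent created by NewSessionAgent.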
type SessionAgentOption func(*sessionAgent)

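// NewSessionAgent returns a SessionAgent with its queue and active-request
// maps initialized, customized through the given options.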
func NewSessionAgent(opts ...SessionAgentOption) SessionAgent {
	// Initialize the queue and request maps so queueing and cancellation work
	// immediately (assumes csync.NewMap is the package's map constructor).
	a := &sessionAgent{
		messageQueue:   csync.NewMap[string, []SessionAgentCall](),
		activeRequests: csync.NewMap[string, context.CancelFunc](),
	}
	for _, opt := range opts {
		opt(a)
	}
	return a
}

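// Run sends a prompt to the session identified by call.SessionID. If the
// session already has an active request, the call is queued and (nil, nil) is
// returned; otherwise Run streams the response, persisting assistant, tool,
// and usage updates as they arrive, and then drains any queued prompts.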
func (a *sessionAgent) Run(ctx context.Context, call SessionAgentCall) (*ai.AgentResult, error) {
	if call.Prompt == "" {
		return nil, ErrEmptyPrompt
	}
	if call.SessionID == "" {
		return nil, ErrSessionMissing
	}

	// Queue the message if busy
	if a.IsSessionBusy(call.SessionID) {
		existing, ok := a.messageQueue.Get(call.SessionID)
		if !ok {
			existing = []SessionAgentCall{}
		}
		existing = append(existing, call)
		a.messageQueue.Set(call.SessionID, existing)
		return nil, nil
	}

	if len(a.tools) > 0 {
		// add anthropic caching to the last tool
		a.tools[len(a.tools)-1].SetProviderOptions(a.getCacheControlOptions())
	}

	agent := ai.NewAgent(
		a.largeModel.model,
		ai.WithSystemPrompt(a.systemPrompt),
		ai.WithTools(a.tools...),
		ai.WithMaxOutputTokens(a.maxOutputTokens),
	)

	currentSession, err := a.sessions.Get(ctx, call.SessionID)
	if err != nil {
		return nil, fmt.Errorf("failed to get session: %w", err)
	}

	msgs, err := a.getSessionMessages(ctx, currentSession)
	if err != nil {
		return nil, fmt.Errorf("failed to get session messages: %w", err)
	}

	// Generate title if first message
	if len(msgs) == 0 {
		go a.generateTitle(ctx, currentSession, call.Prompt)
	}

	// Add the user message to the session
	_, err = a.createUserMessage(ctx, call)
	if err != nil {
		return nil, err
	}

	// add the session to the context
	ctx = context.WithValue(ctx, tools.SessionIDContextKey, call.SessionID)

	genCtx, cancel := context.WithCancel(ctx)
	a.activeRequests.Set(call.SessionID, cancel)

	defer cancel()
	defer a.activeRequests.Del(call.SessionID)

	history, files := a.preparePrompt(msgs, call.Attachments...)

	var currentAssistant *message.Message
	result, err := agent.Stream(genCtx, ai.AgentStreamCall{
		Prompt:           call.Prompt,
		Files:            files,
		Messages:         history,
		ProviderOptions:  call.ProviderOptions,
		MaxOutputTokens:  &call.MaxOutputTokens,
		TopP:             call.TopP,
		Temperature:      call.Temperature,
		PresencePenalty:  call.PresencePenalty,
		TopK:             call.TopK,
		FrequencyPenalty: call.FrequencyPenalty,
		// Before each step create the new assistant message
		PrepareStep: func(options ai.PrepareStepFunctionOptions) (prepared ai.PrepareStepResult, err error) {
			var assistantMsg message.Message
			assistantMsg, err = a.messages.Create(genCtx, call.SessionID, message.CreateMessageParams{
				Role:     message.Assistant,
				Parts:    []message.ContentPart{},
				Model:    a.largeModel.model.Model(),
				Provider: a.largeModel.model.Provider(),
			})
			if err != nil {
				return prepared, err
			}

			currentAssistant = &assistantMsg

			prepared.Messages = options.Messages

			// Flush any prompts that were queued while this request was running.
			queuedCalls, _ := a.messageQueue.Get(call.SessionID)
			a.messageQueue.Del(call.SessionID)
			for _, queued := range queuedCalls {
				userMessage, createErr := a.createUserMessage(genCtx, queued)
				if createErr != nil {
					return prepared, createErr
				}
				prepared.Messages = append(prepared.Messages, userMessage.ToAIMessage()...)
			}

			lastSystemRoleIdx := 0
			systemMessageUpdated := false
			for i, msg := range prepared.Messages {
				// only add cache control to the last system message
				if msg.Role == ai.MessageRoleSystem {
					lastSystemRoleIdx = i
				} else if !systemMessageUpdated {
					prepared.Messages[lastSystemRoleIdx].ProviderOptions = a.getCacheControlOptions()
					systemMessageUpdated = true
				}
				// then add cache control to the last 2 messages
				if i > len(prepared.Messages)-3 {
					prepared.Messages[i].ProviderOptions = a.getCacheControlOptions()
				}
			}
			return prepared, err
		},
		OnReasoningDelta: func(id string, text string) error {
			currentAssistant.AppendReasoningContent(text)
			return a.messages.Update(genCtx, *currentAssistant)
		},
		OnReasoningEnd: func(id string, reasoning ai.ReasoningContent) error {
			// handle anthropic signature
			if anthropicData, ok := reasoning.ProviderMetadata[anthropic.Name]; ok {
				if metadata, ok := anthropicData.(*anthropic.ReasoningOptionMetadata); ok {
					currentAssistant.AppendReasoningSignature(metadata.Signature)
				}
			}
			currentAssistant.FinishThinking()
			return a.messages.Update(genCtx, *currentAssistant)
		},
		OnTextDelta: func(id string, text string) error {
			currentAssistant.AppendContent(text)
			return a.messages.Update(genCtx, *currentAssistant)
		},
		OnToolInputStart: func(id string, toolName string) error {
			toolCall := message.ToolCall{
				ID:               id,
				Name:             toolName,
				ProviderExecuted: false,
				Finished:         false,
			}
			currentAssistant.AddToolCall(toolCall)
			return a.messages.Update(genCtx, *currentAssistant)
		},
		OnRetry: func(err *ai.APICallError, delay time.Duration) {
			// TODO: implement
		},
		OnToolResult: func(result ai.ToolResultContent) error {
			var resultContent string
			isError := false
			switch result.Result.GetType() {
			case ai.ToolResultContentTypeText:
				r, ok := ai.AsToolResultOutputType[ai.ToolResultOutputContentText](result.Result)
				if ok {
					resultContent = r.Text
				}
			case ai.ToolResultContentTypeError:
				r, ok := ai.AsToolResultOutputType[ai.ToolResultOutputContentError](result.Result)
				if ok {
					isError = true
					resultContent = r.Error.Error()
				}
			case ai.ToolResultContentTypeMedia:
				// TODO: handle this message type
			}
			toolResult := message.ToolResult{
				ToolCallID: result.ToolCallID,
				Name:       result.ToolName,
				Content:    resultContent,
				IsError:    isError,
				Metadata:   result.ClientMetadata,
			}
			// Persist the tool result using the background context rather than
			// the cancellable request context.
			if _, createErr := a.messages.Create(context.Background(), currentAssistant.SessionID, message.CreateMessageParams{
				Role: message.Tool,
				Parts: []message.ContentPart{
					toolResult,
				},
			}); createErr != nil {
				slog.Error("failed to persist tool result", "error", createErr)
			}
			return a.messages.Update(genCtx, *currentAssistant)
		},
		OnStepFinish: func(stepResult ai.StepResult) error {
			finishReason := message.FinishReasonUnknown
			switch stepResult.FinishReason {
			case ai.FinishReasonLength:
				finishReason = message.FinishReasonMaxTokens
			case ai.FinishReasonStop:
				finishReason = message.FinishReasonEndTurn
			case ai.FinishReasonToolCalls:
				finishReason = message.FinishReasonToolUse
			}
			currentAssistant.AddFinish(finishReason, "", "")
			a.updateSessionUsage(a.largeModel, &currentSession, stepResult.Usage)
			return a.messages.Update(genCtx, *currentAssistant)
		},
	})
	if err != nil {
		isCancelErr := errors.Is(err, context.Canceled)
		isPermissionErr := errors.Is(err, permission.ErrorPermissionDenied)
		if currentAssistant == nil {
			return result, err
		}
		toolCalls := currentAssistant.ToolCalls()
		toolResults := currentAssistant.ToolResults()
		for _, tc := range toolCalls {
			if !tc.Finished {
				tc.Finished = true
				tc.Input = "{}"
			}
			currentAssistant.AddToolCall(tc)
			found := false
			for _, tr := range toolResults {
				if tr.ToolCallID == tc.ID {
					found = true
					break
				}
			}
			if !found {
				content := "There was an error while executing the tool"
				if isCancelErr {
					content = "Tool execution canceled by user"
				} else if isPermissionErr {
					content = "Permission denied"
				}
				currentAssistant.AddToolResult(message.ToolResult{
					ToolCallID: tc.ID,
					Name:       tc.Name,
					Content:    content,
					IsError:    true,
				})
			}
		}
		if isCancelErr {
			currentAssistant.AddFinish(message.FinishReasonCanceled, "Request cancelled", "")
		} else if isPermissionErr {
			currentAssistant.AddFinish(message.FinishReasonPermissionDenied, "Permission denied", "")
		} else {
			currentAssistant.AddFinish(message.FinishReasonError, "API Error", err.Error())
		}
		// INFO: we use the parent context here because the genCtx might have been cancelled
		updateErr := a.messages.Update(ctx, *currentAssistant)
		if updateErr != nil {
			return nil, updateErr
		}
	}
	if err != nil {
		return nil, err
	}

	queuedMessages, ok := a.messageQueue.Get(call.SessionID)
	if !ok || len(queuedMessages) == 0 {
		return result, err
	}
	// There are queued messages: release this session and restart the loop.
	// Clear the active request first so the recursive Run is not treated as
	// busy (and re-queued), and recurse with the parent context since genCtx
	// is about to be cancelled by the deferred cancel.
	firstQueuedMessage := queuedMessages[0]
	a.messageQueue.Set(call.SessionID, queuedMessages[1:])
	a.activeRequests.Del(call.SessionID)
	cancel()
	return a.Run(ctx, firstQueuedMessage)
}

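// Summarize streams a summary of the session's conversation from the large
// model, saves it as an assistant message, and marks it as the session's
// summary message so subsequent runs start from it. It returns ErrSessionBusy
// if the session has an active request.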
func (a *sessionAgent) Summarize(ctx context.Context, sessionID string) error {
	if a.IsSessionBusy(sessionID) {
		return ErrSessionBusy
	}

	currentSession, err := a.sessions.Get(ctx, sessionID)
	if err != nil {
		return fmt.Errorf("failed to get session: %w", err)
	}
	msgs, err := a.getSessionMessages(ctx, currentSession)
	if err != nil {
		return err
	}
	if len(msgs) == 0 {
		// nothing to summarize
		return nil
	}

	aiMsgs, _ := a.preparePrompt(msgs)

	genCtx, cancel := context.WithCancel(ctx)
	a.activeRequests.Set(sessionID, cancel)
	defer a.activeRequests.Del(sessionID)
	defer cancel()

	agent := ai.NewAgent(a.largeModel.model,
		ai.WithSystemPrompt(string(summaryPrompt)),
	)
	summaryMessage, err := a.messages.Create(ctx, sessionID, message.CreateMessageParams{
		Role:     message.Assistant,
		Model:    a.largeModel.model.Model(),
		Provider: a.largeModel.model.Provider(),
	})
	if err != nil {
		return err
	}

	// Stream with genCtx so Cancel can abort the summarization.
	resp, err := agent.Stream(genCtx, ai.AgentStreamCall{
		Prompt:   "Provide a detailed summary of our conversation above.",
		Messages: aiMsgs,
		OnReasoningDelta: func(id string, text string) error {
			summaryMessage.AppendReasoningContent(text)
			return a.messages.Update(ctx, summaryMessage)
		},
		OnReasoningEnd: func(id string, reasoning ai.ReasoningContent) error {
			// handle anthropic signature
			if anthropicData, ok := reasoning.ProviderMetadata[anthropic.Name]; ok {
				if signature, ok := anthropicData.(*anthropic.ReasoningOptionMetadata); ok && signature.Signature != "" {
					summaryMessage.AppendReasoningSignature(signature.Signature)
				}
			}
			summaryMessage.FinishThinking()
			return a.messages.Update(ctx, summaryMessage)
		},
		OnTextDelta: func(id, text string) error {
			summaryMessage.AppendContent(text)
			return a.messages.Update(ctx, summaryMessage)
		},
	})
	if err != nil {
		return err
	}

	summaryMessage.AddFinish(message.FinishReasonEndTurn, "", "")
	err = a.messages.Update(genCtx, summaryMessage)
	if err != nil {
		return err
	}

	a.updateSessionUsage(a.largeModel, &currentSession, resp.TotalUsage)

	// record only the final response's usage on the session token counts
	usage := resp.Response.Usage
	currentSession.SummaryMessageID = summaryMessage.ID
	currentSession.CompletionTokens = usage.OutputTokens
	currentSession.PromptTokens = 0
	_, err = a.sessions.Save(genCtx, currentSession)
	return err
}

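// getCacheControlOptions returns the Anthropic provider options that mark a
// message or tool definition as an ephemeral cache breakpoint.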
func (a *sessionAgent) getCacheControlOptions() ai.ProviderOptions {
	return ai.ProviderOptions{
		anthropic.Name: &anthropic.ProviderCacheControlOptions{
			CacheControl: anthropic.CacheControl{Type: "ephemeral"},
		},
	}
}

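// createUserMessage stores the prompt and any attachments as a user message
// in the session.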
func (a *sessionAgent) createUserMessage(ctx context.Context, call SessionAgentCall) (message.Message, error) {
	var attachmentParts []message.ContentPart
	for _, attachment := range call.Attachments {
		attachmentParts = append(attachmentParts, message.BinaryContent{Path: attachment.FilePath, MIMEType: attachment.MimeType, Data: attachment.Content})
	}
	parts := []message.ContentPart{message.TextContent{Text: call.Prompt}}
	parts = append(parts, attachmentParts...)
	msg, err := a.messages.Create(ctx, call.SessionID, message.CreateMessageParams{
		Role:  message.User,
		Parts: parts,
	})
	if err != nil {
		return message.Message{}, fmt.Errorf("failed to create user message: %w", err)
	}
	return msg, nil
}

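// preparePrompt converts stored session messages into the AI message history,
// skipping assistant messages that were cancelled before producing any
// content, and converts attachments into file parts.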
func (a *sessionAgent) preparePrompt(msgs []message.Message, attachments ...message.Attachment) ([]ai.Message, []ai.FilePart) {
	var history []ai.Message
	for _, m := range msgs {
		if len(m.Parts) == 0 {
			continue
		}
		// Assistant message without content or tool calls (cancelled before it returned anything)
		if m.Role == message.Assistant && len(m.ToolCalls()) == 0 && m.Content().Text == "" && m.ReasoningContent().String() == "" {
			continue
		}
		history = append(history, m.ToAIMessage()...)
	}

	var files []ai.FilePart
	for _, attachment := range attachments {
		files = append(files, ai.FilePart{
			Filename:  attachment.FileName,
			Data:      attachment.Content,
			MediaType: attachment.MimeType,
		})
	}

	return history, files
}

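// getSessionMessages lists the session's messages. If the session has a
// summary message, only the summary and everything after it are returned,
// with the summary's role switched to user.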
func (a *sessionAgent) getSessionMessages(ctx context.Context, session session.Session) ([]message.Message, error) {
	msgs, err := a.messages.List(ctx, session.ID)
	if err != nil {
		return nil, fmt.Errorf("failed to list messages: %w", err)
	}

	if session.SummaryMessageID != "" {
		summaryMsgIndex := -1
		for i, msg := range msgs {
			if msg.ID == session.SummaryMessageID {
				summaryMsgIndex = i
				break
			}
		}
		if summaryMsgIndex != -1 {
			msgs = msgs[summaryMsgIndex:]
			msgs[0].Role = message.User
		}
	}
	return msgs, nil
}

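// generateTitle asks the small model for a concise session title based on the
// first prompt and saves it, along with the usage it incurred, on the session.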
func (a *sessionAgent) generateTitle(ctx context.Context, session session.Session, prompt string) {
	if prompt == "" {
		return
	}

	agent := ai.NewAgent(a.smallModel.model,
		ai.WithSystemPrompt(string(titlePrompt)),
		ai.WithMaxOutputTokens(40),
	)

	resp, err := agent.Stream(ctx, ai.AgentStreamCall{
		Prompt: fmt.Sprintf("Generate a concise title for the following content:\n\n%s", prompt),
	})
	if err != nil {
		slog.Error("error generating title", "err", err)
		return
	}

	title := resp.Response.Content.Text()

	title = strings.ReplaceAll(title, "\n", " ")

	// remove thinking tags if present
	if idx := strings.Index(title, "</think>"); idx > 0 {
		title = title[idx+len("</think>"):]
	}

	title = strings.TrimSpace(title)
	if title == "" {
		slog.Warn("failed to generate title", "warn", "empty title")
		return
	}

	session.Title = title
	a.updateSessionUsage(a.smallModel, &session, resp.TotalUsage)
	_, saveErr := a.sessions.Save(ctx, session)
	if saveErr != nil {
		slog.Error("failed to save session title & usage", "error", saveErr)
		return
	}
}

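// updateSessionUsage accumulates the cost of a response on the session using
// the model's per-1M-token prices (cache writes, cache reads, input, and
// output) and refreshes the session's prompt/completion token counts.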
func (a *sessionAgent) updateSessionUsage(model Model, session *session.Session, usage ai.Usage) {
	modelConfig := model.config
	cost := modelConfig.CostPer1MInCached/1e6*float64(usage.CacheCreationTokens) +
		modelConfig.CostPer1MOutCached/1e6*float64(usage.CacheReadTokens) +
		modelConfig.CostPer1MIn/1e6*float64(usage.InputTokens) +
		modelConfig.CostPer1MOut/1e6*float64(usage.OutputTokens)
	session.Cost += cost
	session.CompletionTokens = usage.OutputTokens + usage.CacheReadTokens
	session.PromptTokens = usage.InputTokens + usage.CacheCreationTokens
}

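// Cancel aborts the active request (and any in-flight summarization) for the
// session and drops its queued prompts.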
func (a *sessionAgent) Cancel(sessionID string) {
	// Cancel regular requests
	if cancel, ok := a.activeRequests.Take(sessionID); ok && cancel != nil {
		slog.Info("Request cancellation initiated", "session_id", sessionID)
		cancel()
	}

	// Also check for summarize requests
	if cancel, ok := a.activeRequests.Take(sessionID + "-summarize"); ok && cancel != nil {
		slog.Info("Summarize cancellation initiated", "session_id", sessionID)
		cancel()
	}

	if a.QueuedPrompts(sessionID) > 0 {
		slog.Info("Clearing queued prompts", "session_id", sessionID)
		a.messageQueue.Del(sessionID)
	}
}

func (a *sessionAgent) ClearQueue(sessionID string) {
	if a.QueuedPrompts(sessionID) > 0 {
		slog.Info("Clearing queued prompts", "session_id", sessionID)
		a.messageQueue.Del(sessionID)
	}
}

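// CancelAll cancels every active request and waits up to five seconds for
// them to wind down.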
func (a *sessionAgent) CancelAll() {
	if !a.IsBusy() {
		return
	}
	for key := range a.activeRequests.Seq2() {
		a.Cancel(key) // key is sessionID
	}

	timeout := time.After(5 * time.Second)
	for a.IsBusy() {
		select {
		case <-timeout:
			return
		default:
			time.Sleep(200 * time.Millisecond)
		}
	}
}

func (a *sessionAgent) IsBusy() bool {
	var busy bool
	for cancelFunc := range a.activeRequests.Seq() {
		if cancelFunc != nil {
			busy = true
			break
		}
	}
	return busy
}

func (a *sessionAgent) IsSessionBusy(sessionID string) bool {
	_, busy := a.activeRequests.Get(sessionID)
	return busy
}

func (a *sessionAgent) QueuedPrompts(sessionID string) int {
	l, ok := a.messageQueue.Get(sessionID)
	if !ok {
		return 0
	}
	return len(l)
}

func (a *sessionAgent) SetModels(large Model, small Model) {
	a.largeModel = large
	a.smallModel = small
}

func (a *sessionAgent) SetTools(tools []ai.AgentTool) {
	a.tools = tools
}