package agent

import (
	"context"
	_ "embed"
	"encoding/json"
	"errors"
	"fmt"
	"log/slog"
	"strings"
	"sync"
	"time"

	"github.com/charmbracelet/catwalk/pkg/catwalk"
	"github.com/charmbracelet/crush/internal/agent/tools"
	"github.com/charmbracelet/crush/internal/config"
	"github.com/charmbracelet/crush/internal/csync"
	"github.com/charmbracelet/crush/internal/message"
	"github.com/charmbracelet/crush/internal/permission"
	"github.com/charmbracelet/crush/internal/session"
	"github.com/charmbracelet/fantasy/ai"
	"github.com/charmbracelet/fantasy/anthropic"
)

//go:embed templates/title.md
var titlePrompt []byte

//go:embed templates/summary.md
var summaryPrompt []byte

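// SessionAgentCall describes a single prompt to run against a session,
// along with optional attachments and sampling parameters.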
type SessionAgentCall struct {
	SessionID        string
	Prompt           string
	ProviderOptions  ai.ProviderOptions
	Attachments      []message.Attachment
	MaxOutputTokens  int64
	Temperature      *float64
	TopP             *float64
	TopK             *int64
	FrequencyPenalty *float64
	PresencePenalty  *float64
}

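// SessionAgent runs prompts against a session, queues concurrent requests,
// and exposes controls for cancellation, summarization, and model selection.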
type SessionAgent interface {
	Run(context.Context, SessionAgentCall) (*ai.AgentResult, error)
	SetModels(large Model, small Model)
	SetTools(tools []ai.AgentTool)
	Cancel(sessionID string)
	CancelAll()
	IsSessionBusy(sessionID string) bool
	IsBusy() bool
	QueuedPrompts(sessionID string) int
	ClearQueue(sessionID string)
	Summarize(context.Context, string) error
	Model() Model
}

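// Model bundles a language model with its catwalk metadata and the
// user-selected model configuration.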
type Model struct {
	Model      ai.LanguageModel
	CatwalkCfg catwalk.Model
	ModelCfg   config.SelectedModel
}

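// sessionAgent is the default SessionAgent implementation. It tracks one
// active request per session and queues prompts that arrive while a session
// is busy.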
type sessionAgent struct {
	largeModel   Model
	smallModel   Model
	systemPrompt string
	tools        []ai.AgentTool
	sessions     session.Service
	messages     message.Service

	messageQueue   *csync.Map[string, []SessionAgentCall]
	activeRequests *csync.Map[string, context.CancelFunc]
}

type SessionAgentOption func(*sessionAgent)

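// NewSessionAgent creates a SessionAgent backed by the given large and small
// models, system prompt, persistence services, and tools.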
func NewSessionAgent(
	largeModel Model,
	smallModel Model,
	systemPrompt string,
	sessions session.Service,
	messages message.Service,
	tools ...ai.AgentTool,
) SessionAgent {
	return &sessionAgent{
		largeModel:     largeModel,
		smallModel:     smallModel,
		systemPrompt:   systemPrompt,
		sessions:       sessions,
		messages:       messages,
		tools:          tools,
		messageQueue:   csync.NewMap[string, []SessionAgentCall](),
		activeRequests: csync.NewMap[string, context.CancelFunc](),
	}
}

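// Run sends the prompt to the large model for the given session, streaming
// assistant content, tool calls, and tool results into the message store. If
// the session is already busy the call is queued and a nil result is returned.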
func (a *sessionAgent) Run(ctx context.Context, call SessionAgentCall) (*ai.AgentResult, error) {
	if call.Prompt == "" {
		return nil, ErrEmptyPrompt
	}
	if call.SessionID == "" {
		return nil, ErrSessionMissing
	}

	// Queue the message if the session is busy.
	if a.IsSessionBusy(call.SessionID) {
		existing, ok := a.messageQueue.Get(call.SessionID)
		if !ok {
			existing = []SessionAgentCall{}
		}
		existing = append(existing, call)
		a.messageQueue.Set(call.SessionID, existing)
		return nil, nil
	}

	if len(a.tools) > 0 {
		// Add Anthropic caching to the last tool.
		a.tools[len(a.tools)-1].SetProviderOptions(a.getCacheControlOptions())
	}

	agent := ai.NewAgent(
		a.largeModel.Model,
		ai.WithSystemPrompt(a.systemPrompt),
		ai.WithTools(a.tools...),
	)

	sessionLock := sync.Mutex{}
	currentSession, err := a.sessions.Get(ctx, call.SessionID)
	if err != nil {
		return nil, fmt.Errorf("failed to get session: %w", err)
	}

	msgs, err := a.getSessionMessages(ctx, currentSession)
	if err != nil {
		return nil, fmt.Errorf("failed to get session messages: %w", err)
	}

	var wg sync.WaitGroup
	// Generate a title if this is the first message.
	if len(msgs) == 0 {
		wg.Go(func() {
			sessionLock.Lock()
			a.generateTitle(ctx, &currentSession, call.Prompt)
			sessionLock.Unlock()
		})
	}

	// Add the user message to the session.
	_, err = a.createUserMessage(ctx, call)
	if err != nil {
		return nil, err
	}

	// Add the session ID to the context.
	ctx = context.WithValue(ctx, tools.SessionIDContextKey, call.SessionID)

	genCtx, cancel := context.WithCancel(ctx)
	a.activeRequests.Set(call.SessionID, cancel)

	defer cancel()
	defer a.activeRequests.Del(call.SessionID)

	history, files := a.preparePrompt(msgs, call.Attachments...)

	var currentAssistant *message.Message
	result, err := agent.Stream(genCtx, ai.AgentStreamCall{
		Prompt:           call.Prompt,
		Files:            files,
		Messages:         history,
		ProviderOptions:  call.ProviderOptions,
		MaxOutputTokens:  &call.MaxOutputTokens,
		TopP:             call.TopP,
		Temperature:      call.Temperature,
		PresencePenalty:  call.PresencePenalty,
		TopK:             call.TopK,
		FrequencyPenalty: call.FrequencyPenalty,
		// Before each step, create the new assistant message.
		PrepareStep: func(callContext context.Context, options ai.PrepareStepFunctionOptions) (_ context.Context, prepared ai.PrepareStepResult, err error) {
			var assistantMsg message.Message
			assistantMsg, err = a.messages.Create(callContext, call.SessionID, message.CreateMessageParams{
				Role:     message.Assistant,
				Parts:    []message.ContentPart{},
				Model:    a.largeModel.ModelCfg.Model,
				Provider: a.largeModel.ModelCfg.Provider,
			})
			if err != nil {
				return callContext, prepared, err
			}

			callContext = context.WithValue(callContext, tools.MessageIDContextKey, assistantMsg.ID)

			currentAssistant = &assistantMsg

			prepared.Messages = options.Messages
			// Reset all cached items.
			for i := range prepared.Messages {
				prepared.Messages[i].ProviderOptions = nil
			}

			queuedCalls, _ := a.messageQueue.Get(call.SessionID)
			a.messageQueue.Del(call.SessionID)
			for _, queued := range queuedCalls {
				userMessage, createErr := a.createUserMessage(callContext, queued)
				if createErr != nil {
					return callContext, prepared, createErr
				}
				prepared.Messages = append(prepared.Messages, userMessage.ToAIMessage()...)
			}

			lastSystemRoleIdx := 0
			systemMessageUpdated := false
			for i, msg := range prepared.Messages {
				// Only add cache control to the last system message...
				if msg.Role == ai.MessageRoleSystem {
					lastSystemRoleIdx = i
				} else if !systemMessageUpdated {
					prepared.Messages[lastSystemRoleIdx].ProviderOptions = a.getCacheControlOptions()
					systemMessageUpdated = true
				}
				// ...then add cache control to the last two messages.
				if i > len(prepared.Messages)-3 {
					prepared.Messages[i].ProviderOptions = a.getCacheControlOptions()
				}
			}
			return callContext, prepared, err
		},
		OnReasoningDelta: func(id string, text string) error {
			currentAssistant.AppendReasoningContent(text)
			return a.messages.Update(genCtx, *currentAssistant)
		},
		OnReasoningEnd: func(id string, reasoning ai.ReasoningContent) error {
			// Handle the Anthropic reasoning signature.
			if anthropicData, ok := reasoning.ProviderMetadata[anthropic.Name]; ok {
				if reasoning, ok := anthropicData.(*anthropic.ReasoningOptionMetadata); ok {
					currentAssistant.AppendReasoningSignature(reasoning.Signature)
				}
			}
			currentAssistant.FinishThinking()
			return a.messages.Update(genCtx, *currentAssistant)
		},
		OnTextDelta: func(id string, text string) error {
			currentAssistant.AppendContent(text)
			return a.messages.Update(genCtx, *currentAssistant)
		},
		OnToolInputStart: func(id string, toolName string) error {
			toolCall := message.ToolCall{
				ID:               id,
				Name:             toolName,
				ProviderExecuted: false,
				Finished:         false,
			}
			currentAssistant.AddToolCall(toolCall)
			return a.messages.Update(genCtx, *currentAssistant)
		},
		OnRetry: func(err *ai.APICallError, delay time.Duration) {
			// TODO: implement
		},
		OnToolCall: func(tc ai.ToolCallContent) error {
			toolCall := message.ToolCall{
				ID:               tc.ToolCallID,
				Name:             tc.ToolName,
				Input:            tc.Input,
				ProviderExecuted: false,
				Finished:         true,
			}
			currentAssistant.AddToolCall(toolCall)
			return a.messages.Update(genCtx, *currentAssistant)
		},
		OnToolResult: func(result ai.ToolResultContent) error {
			var resultContent string
			isError := false
			switch result.Result.GetType() {
			case ai.ToolResultContentTypeText:
				r, ok := ai.AsToolResultOutputType[ai.ToolResultOutputContentText](result.Result)
				if ok {
					resultContent = r.Text
				}
			case ai.ToolResultContentTypeError:
				r, ok := ai.AsToolResultOutputType[ai.ToolResultOutputContentError](result.Result)
				if ok {
					isError = true
					resultContent = r.Error.Error()
				}
			case ai.ToolResultContentTypeMedia:
				// TODO: handle this message type
			}
			toolResult := message.ToolResult{
				ToolCallID: result.ToolCallID,
				Name:       result.ToolName,
				Content:    resultContent,
				IsError:    isError,
				Metadata:   result.ClientMetadata,
			}
			_, err := a.messages.Create(context.Background(), currentAssistant.SessionID, message.CreateMessageParams{
				Role: message.Tool,
				Parts: []message.ContentPart{
					toolResult,
				},
			})
			if err != nil {
				return err
			}
			return a.messages.Update(genCtx, *currentAssistant)
		},
		OnStepFinish: func(stepResult ai.StepResult) error {
			finishReason := message.FinishReasonUnknown
			switch stepResult.FinishReason {
			case ai.FinishReasonLength:
				finishReason = message.FinishReasonMaxTokens
			case ai.FinishReasonStop:
				finishReason = message.FinishReasonEndTurn
			case ai.FinishReasonToolCalls:
				finishReason = message.FinishReasonToolUse
			}
			currentAssistant.AddFinish(finishReason, "", "")
			a.updateSessionUsage(a.largeModel, &currentSession, stepResult.Usage)
			sessionLock.Lock()
			_, sessionErr := a.sessions.Save(genCtx, currentSession)
			sessionLock.Unlock()
			if sessionErr != nil {
				return sessionErr
			}
			return a.messages.Update(genCtx, *currentAssistant)
		},
	})
	if err != nil {
		isCancelErr := errors.Is(err, context.Canceled)
		isPermissionErr := errors.Is(err, permission.ErrorPermissionDenied)
		if currentAssistant == nil {
			return result, err
		}
		toolCalls := currentAssistant.ToolCalls()
		toolResults := currentAssistant.ToolResults()
		// INFO: we use the parent context here because genCtx has been cancelled.
		msgs, createErr := a.messages.List(ctx, currentSession.ID)
		if createErr != nil {
			return nil, createErr
		}
		for _, tc := range toolCalls {
			if !tc.Finished {
				tc.Finished = true
				tc.Input = "{}"
			}
			currentAssistant.AddToolCall(tc)

			found := false
			for _, msg := range msgs {
				if msg.Role == message.Tool {
					for _, tr := range toolResults {
						if tr.ToolCallID == tc.ID {
							found = true
							break
						}
					}
				}
				if found {
					break
				}
			}
			if !found {
				content := "There was an error while executing the tool"
				if isCancelErr {
					content = "Tool execution canceled by user"
				} else if isPermissionErr {
					content = "Permission denied"
				}
				toolResult := message.ToolResult{
					ToolCallID: tc.ID,
					Name:       tc.Name,
					Content:    content,
					IsError:    true,
				}
				_, createErr = a.messages.Create(context.Background(), currentAssistant.SessionID, message.CreateMessageParams{
					Role: message.Tool,
					Parts: []message.ContentPart{
						toolResult,
					},
				})
				if createErr != nil {
					return nil, createErr
				}
			}
		}
		if isCancelErr {
			currentAssistant.AddFinish(message.FinishReasonCanceled, "Request cancelled", "")
		} else if isPermissionErr {
			currentAssistant.AddFinish(message.FinishReasonPermissionDenied, "Permission denied", "")
		} else {
			currentAssistant.AddFinish(message.FinishReasonError, "API Error", err.Error())
		}
		// INFO: we use the parent context here because genCtx has been cancelled.
		updateErr := a.messages.Update(ctx, *currentAssistant)
		if updateErr != nil {
			return nil, updateErr
		}
		return nil, err
	}
	wg.Wait()

	queuedMessages, ok := a.messageQueue.Get(call.SessionID)
	if !ok || len(queuedMessages) == 0 {
		return result, err
	}
	// There are queued messages; restart the loop.
	firstQueuedMessage := queuedMessages[0]
	a.messageQueue.Set(call.SessionID, queuedMessages[1:])
	return a.Run(genCtx, firstQueuedMessage)
}

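// Summarize generates a summary of the session's conversation with the large
// model, stores it as an assistant message, and records it as the session's
// summary message.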
func (a *sessionAgent) Summarize(ctx context.Context, sessionID string) error {
	if a.IsSessionBusy(sessionID) {
		return ErrSessionBusy
	}

	currentSession, err := a.sessions.Get(ctx, sessionID)
	if err != nil {
		return fmt.Errorf("failed to get session: %w", err)
	}
	msgs, err := a.getSessionMessages(ctx, currentSession)
	if err != nil {
		return err
	}
	if len(msgs) == 0 {
		// Nothing to summarize.
		return nil
	}

	aiMsgs, _ := a.preparePrompt(msgs)

	genCtx, cancel := context.WithCancel(ctx)
	a.activeRequests.Set(sessionID, cancel)
	defer a.activeRequests.Del(sessionID)
	defer cancel()

	agent := ai.NewAgent(a.largeModel.Model,
		ai.WithSystemPrompt(string(summaryPrompt)),
	)
	summaryMessage, err := a.messages.Create(ctx, sessionID, message.CreateMessageParams{
		Role:     message.Assistant,
		Model:    a.largeModel.Model.Model(),
		Provider: a.largeModel.Model.Provider(),
	})
	if err != nil {
		return err
	}

	resp, err := agent.Stream(ctx, ai.AgentStreamCall{
		Prompt:   "Provide a detailed summary of our conversation above.",
		Messages: aiMsgs,
		OnReasoningDelta: func(id string, text string) error {
			summaryMessage.AppendReasoningContent(text)
			return a.messages.Update(ctx, summaryMessage)
		},
		OnReasoningEnd: func(id string, reasoning ai.ReasoningContent) error {
			// Handle the Anthropic reasoning signature.
			if anthropicData, ok := reasoning.ProviderMetadata["anthropic"]; ok {
				if signature, ok := anthropicData.(*anthropic.ReasoningOptionMetadata); ok && signature.Signature != "" {
					summaryMessage.AppendReasoningSignature(signature.Signature)
				}
			}
			summaryMessage.FinishThinking()
			return a.messages.Update(ctx, summaryMessage)
		},
		OnTextDelta: func(id, text string) error {
			summaryMessage.AppendContent(text)
			return a.messages.Update(ctx, summaryMessage)
		},
	})
	if err != nil {
		return err
	}

	summaryMessage.AddFinish(message.FinishReasonEndTurn, "", "")
	err = a.messages.Update(genCtx, summaryMessage)
	if err != nil {
		return err
	}

	a.updateSessionUsage(a.largeModel, &currentSession, resp.TotalUsage)

	// Just in case, record only the last response's usage on the session.
	usage := resp.Response.Usage
	currentSession.SummaryMessageID = summaryMessage.ID
	currentSession.CompletionTokens = usage.OutputTokens
	currentSession.PromptTokens = 0
	_, err = a.sessions.Save(genCtx, currentSession)
	return err
}

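// getCacheControlOptions returns provider options that enable Anthropic
// ephemeral prompt caching.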
func (a *sessionAgent) getCacheControlOptions() ai.ProviderOptions {
	return ai.ProviderOptions{
		anthropic.Name: &anthropic.ProviderCacheControlOptions{
			CacheControl: anthropic.CacheControl{Type: "ephemeral"},
		},
	}
}

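// createUserMessage persists the prompt and any attachments as a user message
// in the session.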
func (a *sessionAgent) createUserMessage(ctx context.Context, call SessionAgentCall) (message.Message, error) {
	var attachmentParts []message.ContentPart
	for _, attachment := range call.Attachments {
		attachmentParts = append(attachmentParts, message.BinaryContent{Path: attachment.FilePath, MIMEType: attachment.MimeType, Data: attachment.Content})
	}
	parts := []message.ContentPart{message.TextContent{Text: call.Prompt}}
	parts = append(parts, attachmentParts...)
	msg, err := a.messages.Create(ctx, call.SessionID, message.CreateMessageParams{
		Role:  message.User,
		Parts: parts,
	})
	if err != nil {
		return message.Message{}, fmt.Errorf("failed to create user message: %w", err)
	}
	return msg, nil
}

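// preparePrompt converts stored messages into AI messages, skipping assistant
// messages that were cancelled before producing any content, and converts
// attachments into file parts.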
func (a *sessionAgent) preparePrompt(msgs []message.Message, attachments ...message.Attachment) ([]ai.Message, []ai.FilePart) {
	var history []ai.Message
	for _, m := range msgs {
		if len(m.Parts) == 0 {
			continue
		}
		// Assistant message without content or tool calls (cancelled before it returned anything).
		if m.Role == message.Assistant && len(m.ToolCalls()) == 0 && m.Content().Text == "" && m.ReasoningContent().String() == "" {
			continue
		}
		history = append(history, m.ToAIMessage()...)
	}

	var files []ai.FilePart
	for _, attachment := range attachments {
		files = append(files, ai.FilePart{
			Filename:  attachment.FileName,
			Data:      attachment.Content,
			MediaType: attachment.MimeType,
		})
	}

	return history, files
}

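// getSessionMessages lists the session's messages. If the session has been
// summarized, only the summary message and everything after it are returned,
// with the summary re-labelled as a user message.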
func (a *sessionAgent) getSessionMessages(ctx context.Context, session session.Session) ([]message.Message, error) {
	msgs, err := a.messages.List(ctx, session.ID)
	if err != nil {
		return nil, fmt.Errorf("failed to list messages: %w", err)
	}

	if session.SummaryMessageID != "" {
		summaryMsgIndex := -1
		for i, msg := range msgs {
			if msg.ID == session.SummaryMessageID {
				summaryMsgIndex = i
				break
			}
		}
		if summaryMsgIndex != -1 {
			msgs = msgs[summaryMsgIndex:]
			msgs[0].Role = message.User
		}
	}
	return msgs, nil
}

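// generateTitle asks the small model for a short session title based on the
// first prompt and saves it, along with the usage it incurred, on the session.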
func (a *sessionAgent) generateTitle(ctx context.Context, session *session.Session, prompt string) {
	if prompt == "" {
		return
	}

	var maxOutput int64 = 40
	if a.smallModel.CatwalkCfg.CanReason {
		maxOutput = a.smallModel.CatwalkCfg.DefaultMaxTokens
	}

	agent := ai.NewAgent(a.smallModel.Model,
		ai.WithSystemPrompt(string(titlePrompt)+"\n /no_think"),
		ai.WithMaxOutputTokens(maxOutput),
	)

	resp, err := agent.Stream(ctx, ai.AgentStreamCall{
		Prompt: fmt.Sprintf("Generate a concise title for the following content:\n\n%s\n <think>\n\n</think>", prompt),
	})
	if err != nil {
		slog.Error("error generating title", "err", err)
		return
	}

	data, _ := json.Marshal(resp)

	slog.Info("Title Response")
	slog.Info(string(data))
	title := resp.Response.Content.Text()

	title = strings.ReplaceAll(title, "\n", " ")
	slog.Info(title)

	// Remove thinking tags if present.
	if idx := strings.Index(title, "</think>"); idx > 0 {
		title = title[idx+len("</think>"):]
	}

	title = strings.TrimSpace(title)
	if title == "" {
		slog.Warn("failed to generate title", "warn", "empty title")
		return
	}

	session.Title = title
	a.updateSessionUsage(a.smallModel, session, resp.TotalUsage)
	_, saveErr := a.sessions.Save(ctx, *session)
	if saveErr != nil {
		slog.Error("failed to save session title & usage", "error", saveErr)
		return
	}
}

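// updateSessionUsage adds the cost of the given usage to the session and
// updates its prompt and completion token counts.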
func (a *sessionAgent) updateSessionUsage(model Model, session *session.Session, usage ai.Usage) {
	modelConfig := model.CatwalkCfg
	cost := modelConfig.CostPer1MInCached/1e6*float64(usage.CacheCreationTokens) +
		modelConfig.CostPer1MOutCached/1e6*float64(usage.CacheReadTokens) +
		modelConfig.CostPer1MIn/1e6*float64(usage.InputTokens) +
		modelConfig.CostPer1MOut/1e6*float64(usage.OutputTokens)
	session.Cost += cost
	session.CompletionTokens = usage.OutputTokens + usage.CacheReadTokens
	session.PromptTokens = usage.InputTokens + usage.CacheCreationTokens
}

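// Cancel stops the active request (and any in-flight summarize request) for
// the session and drops its queued prompts.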
func (a *sessionAgent) Cancel(sessionID string) {
	// Cancel regular requests.
	if cancel, ok := a.activeRequests.Take(sessionID); ok && cancel != nil {
		slog.Info("Request cancellation initiated", "session_id", sessionID)
		cancel()
	}

	// Also check for summarize requests.
	if cancel, ok := a.activeRequests.Take(sessionID + "-summarize"); ok && cancel != nil {
		slog.Info("Summarize cancellation initiated", "session_id", sessionID)
		cancel()
	}

	if a.QueuedPrompts(sessionID) > 0 {
		slog.Info("Clearing queued prompts", "session_id", sessionID)
		a.messageQueue.Del(sessionID)
	}
}

func (a *sessionAgent) ClearQueue(sessionID string) {
	if a.QueuedPrompts(sessionID) > 0 {
		slog.Info("Clearing queued prompts", "session_id", sessionID)
		a.messageQueue.Del(sessionID)
	}
}

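// CancelAll cancels every active request and waits up to five seconds for all
// sessions to become idle.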
func (a *sessionAgent) CancelAll() {
	if !a.IsBusy() {
		return
	}
	for key := range a.activeRequests.Seq2() {
		a.Cancel(key) // key is the session ID
	}

	timeout := time.After(5 * time.Second)
	for a.IsBusy() {
		select {
		case <-timeout:
			return
		default:
			time.Sleep(200 * time.Millisecond)
		}
	}
}

func (a *sessionAgent) IsBusy() bool {
	var busy bool
	for cancelFunc := range a.activeRequests.Seq() {
		if cancelFunc != nil {
			busy = true
			break
		}
	}
	return busy
}

func (a *sessionAgent) IsSessionBusy(sessionID string) bool {
	_, busy := a.activeRequests.Get(sessionID)
	return busy
}

func (a *sessionAgent) QueuedPrompts(sessionID string) int {
	l, ok := a.messageQueue.Get(sessionID)
	if !ok {
		return 0
	}
	return len(l)
}

func (a *sessionAgent) SetModels(large Model, small Model) {
	a.largeModel = large
	a.smallModel = small
}

func (a *sessionAgent) SetTools(tools []ai.AgentTool) {
	a.tools = tools
}

func (a *sessionAgent) Model() Model {
	return a.largeModel
}