1package agent
2
3import (
4 "context"
5 "errors"
6 "fmt"
7 "log/slog"
8 "slices"
9 "strings"
10 "time"
11
12 "github.com/charmbracelet/catwalk/pkg/catwalk"
13 "github.com/charmbracelet/crush/internal/config"
14 "github.com/charmbracelet/crush/internal/csync"
15 "github.com/charmbracelet/crush/internal/history"
16 "github.com/charmbracelet/crush/internal/llm/prompt"
17 "github.com/charmbracelet/crush/internal/llm/provider"
18 "github.com/charmbracelet/crush/internal/llm/tools"
19 "github.com/charmbracelet/crush/internal/log"
20 "github.com/charmbracelet/crush/internal/lsp"
21 "github.com/charmbracelet/crush/internal/message"
22 "github.com/charmbracelet/crush/internal/permission"
23 "github.com/charmbracelet/crush/internal/pubsub"
24 "github.com/charmbracelet/crush/internal/session"
25 "github.com/charmbracelet/crush/internal/shell"
26)
27
// ErrRequestCancelled is reported (inside an error AgentEvent) when a request
// is aborted, e.g. through Cancel.
var ErrRequestCancelled = errors.New("request canceled by user")

// ErrSessionBusy is returned when a request targets a session that already
// has a request in flight.
var ErrSessionBusy = errors.New("session is currently processing another request")
33
// AgentEventType tells consumers how to interpret an AgentEvent.
type AgentEventType string

// AgentEventTypeError marks an event whose Error field holds a failure.
const AgentEventTypeError AgentEventType = "error"

// AgentEventTypeResponse marks an event carrying a finished assistant message.
const AgentEventTypeResponse AgentEventType = "response"

// AgentEventTypeSummarize marks a summarization progress/completion event.
const AgentEventTypeSummarize AgentEventType = "summarize"
41
42type AgentEvent struct {
43 Type AgentEventType
44 Message message.Message
45 Error error
46
47 // When summarizing
48 SessionID string
49 Progress string
50 Done bool
51}
52
53type Service interface {
54 pubsub.Suscriber[AgentEvent]
55 Model() catwalk.Model
56 Run(ctx context.Context, sessionID string, content string, attachments ...message.Attachment) (<-chan AgentEvent, error)
57 Cancel(sessionID string)
58 CancelAll()
59 IsSessionBusy(sessionID string) bool
60 IsBusy() bool
61 Summarize(ctx context.Context, sessionID string) error
62 UpdateModel() error
63}
64
65type agent struct {
66 *pubsub.Broker[AgentEvent]
67 agentCfg config.Agent
68 sessions session.Service
69 messages message.Service
70 mcpTools []McpTool
71
72 tools *csync.LazySlice[tools.BaseTool]
73
74 provider provider.Provider
75 providerID string
76
77 titleProvider provider.Provider
78 summarizeProvider provider.Provider
79 summarizeProviderID string
80
81 activeRequests *csync.Map[string, context.CancelFunc]
82}
83
84var agentPromptMap = map[string]prompt.PromptID{
85 "coder": prompt.PromptCoder,
86 "task": prompt.PromptTask,
87}
88
89func NewAgent(
90 ctx context.Context,
91 agentCfg config.Agent,
92 // These services are needed in the tools
93 permissions permission.Service,
94 sessions session.Service,
95 messages message.Service,
96 history history.Service,
97 lspClients map[string]*lsp.Client,
98) (Service, error) {
99 cfg := config.Get()
100
101 var agentTool tools.BaseTool
102 if agentCfg.ID == "coder" {
103 taskAgentCfg := config.Get().Agents["task"]
104 if taskAgentCfg.ID == "" {
105 return nil, fmt.Errorf("task agent not found in config")
106 }
107 taskAgent, err := NewAgent(ctx, taskAgentCfg, permissions, sessions, messages, history, lspClients)
108 if err != nil {
109 return nil, fmt.Errorf("failed to create task agent: %w", err)
110 }
111
112 agentTool = NewAgentTool(taskAgent, sessions, messages)
113 }
114
115 providerCfg := config.Get().GetProviderForModel(agentCfg.Model)
116 if providerCfg == nil {
117 return nil, fmt.Errorf("provider for agent %s not found in config", agentCfg.Name)
118 }
119 model := config.Get().GetModelByType(agentCfg.Model)
120
121 if model == nil {
122 return nil, fmt.Errorf("model not found for agent %s", agentCfg.Name)
123 }
124
125 promptID := agentPromptMap[agentCfg.ID]
126 if promptID == "" {
127 promptID = prompt.PromptDefault
128 }
129 opts := []provider.ProviderClientOption{
130 provider.WithModel(agentCfg.Model),
131 provider.WithSystemMessage(prompt.GetPrompt(promptID, providerCfg.ID, config.Get().Options.ContextPaths...)),
132 }
133 agentProvider, err := provider.NewProvider(*providerCfg, opts...)
134 if err != nil {
135 return nil, err
136 }
137
138 smallModelCfg := cfg.Models[config.SelectedModelTypeSmall]
139 var smallModelProviderCfg *config.ProviderConfig
140 if smallModelCfg.Provider == providerCfg.ID {
141 smallModelProviderCfg = providerCfg
142 } else {
143 smallModelProviderCfg = cfg.GetProviderForModel(config.SelectedModelTypeSmall)
144
145 if smallModelProviderCfg.ID == "" {
146 return nil, fmt.Errorf("provider %s not found in config", smallModelCfg.Provider)
147 }
148 }
149 smallModel := cfg.GetModelByType(config.SelectedModelTypeSmall)
150 if smallModel.ID == "" {
151 return nil, fmt.Errorf("model %s not found in provider %s", smallModelCfg.Model, smallModelProviderCfg.ID)
152 }
153
154 titleOpts := []provider.ProviderClientOption{
155 provider.WithModel(config.SelectedModelTypeSmall),
156 provider.WithSystemMessage(prompt.GetPrompt(prompt.PromptTitle, smallModelProviderCfg.ID)),
157 }
158 titleProvider, err := provider.NewProvider(*smallModelProviderCfg, titleOpts...)
159 if err != nil {
160 return nil, err
161 }
162 summarizeOpts := []provider.ProviderClientOption{
163 provider.WithModel(config.SelectedModelTypeSmall),
164 provider.WithSystemMessage(prompt.GetPrompt(prompt.PromptSummarizer, smallModelProviderCfg.ID)),
165 }
166 summarizeProvider, err := provider.NewProvider(*smallModelProviderCfg, summarizeOpts...)
167 if err != nil {
168 return nil, err
169 }
170
171 toolFn := func() []tools.BaseTool {
172 slog.Info("Initializing agent tools", "agent", agentCfg.ID)
173 defer func() {
174 slog.Info("Initialized agent tools", "agent", agentCfg.ID)
175 }()
176
177 cwd := cfg.WorkingDir()
178 allTools := []tools.BaseTool{
179 tools.NewBashTool(permissions, cwd),
180 tools.NewDownloadTool(permissions, cwd),
181 tools.NewEditTool(lspClients, permissions, history, cwd),
182 tools.NewMultiEditTool(lspClients, permissions, history, cwd),
183 tools.NewFetchTool(permissions, cwd),
184 tools.NewGlobTool(cwd),
185 tools.NewGrepTool(cwd),
186 tools.NewLsTool(permissions, cwd),
187 tools.NewSourcegraphTool(),
188 tools.NewViewTool(lspClients, permissions, cwd),
189 tools.NewWriteTool(lspClients, permissions, history, cwd),
190 }
191
192 mcpToolsOnce.Do(func() {
193 mcpTools = doGetMCPTools(ctx, permissions, cfg)
194 })
195 allTools = append(allTools, mcpTools...)
196
197 if len(lspClients) > 0 {
198 allTools = append(allTools, tools.NewDiagnosticsTool(lspClients))
199 }
200
201 if agentTool != nil {
202 allTools = append(allTools, agentTool)
203 }
204
205 if agentCfg.AllowedTools == nil {
206 return allTools
207 }
208
209 var filteredTools []tools.BaseTool
210 for _, tool := range allTools {
211 if slices.Contains(agentCfg.AllowedTools, tool.Name()) {
212 filteredTools = append(filteredTools, tool)
213 }
214 }
215 return filteredTools
216 }
217
218 return &agent{
219 Broker: pubsub.NewBroker[AgentEvent](),
220 agentCfg: agentCfg,
221 provider: agentProvider,
222 providerID: string(providerCfg.ID),
223 messages: messages,
224 sessions: sessions,
225 titleProvider: titleProvider,
226 summarizeProvider: summarizeProvider,
227 summarizeProviderID: string(smallModelProviderCfg.ID),
228 activeRequests: csync.NewMap[string, context.CancelFunc](),
229 tools: csync.NewLazySlice(toolFn),
230 }, nil
231}
232
233func (a *agent) Model() catwalk.Model {
234 return *config.Get().GetModelByType(a.agentCfg.Model)
235}
236
237func (a *agent) Cancel(sessionID string) {
238 // Cancel regular requests
239 if cancel, ok := a.activeRequests.Take(sessionID); ok && cancel != nil {
240 slog.Info("Request cancellation initiated", "session_id", sessionID)
241 cancel()
242 }
243
244 // Also check for summarize requests
245 if cancel, ok := a.activeRequests.Take(sessionID + "-summarize"); ok && cancel != nil {
246 slog.Info("Summarize cancellation initiated", "session_id", sessionID)
247 cancel()
248 }
249}
250
251func (a *agent) IsBusy() bool {
252 var busy bool
253 for cancelFunc := range a.activeRequests.Seq() {
254 if cancelFunc != nil {
255 busy = true
256 break
257 }
258 }
259 return busy
260}
261
262func (a *agent) IsSessionBusy(sessionID string) bool {
263 _, busy := a.activeRequests.Get(sessionID)
264 return busy
265}
266
267func (a *agent) generateTitle(ctx context.Context, sessionID string, content string) error {
268 if content == "" {
269 return nil
270 }
271 if a.titleProvider == nil {
272 return nil
273 }
274 session, err := a.sessions.Get(ctx, sessionID)
275 if err != nil {
276 return err
277 }
278 parts := []message.ContentPart{message.TextContent{
279 Text: fmt.Sprintf("Generate a concise title for the following content:\n\n%s", content),
280 }}
281
282 // Use streaming approach like summarization
283 response := a.titleProvider.StreamResponse(
284 ctx,
285 []message.Message{
286 {
287 Role: message.User,
288 Parts: parts,
289 },
290 },
291 nil,
292 )
293
294 var finalResponse *provider.ProviderResponse
295 for r := range response {
296 if r.Error != nil {
297 return r.Error
298 }
299 finalResponse = r.Response
300 }
301
302 if finalResponse == nil {
303 return fmt.Errorf("no response received from title provider")
304 }
305
306 title := strings.TrimSpace(strings.ReplaceAll(finalResponse.Content, "\n", " "))
307 if title == "" {
308 return nil
309 }
310
311 session.Title = title
312 _, err = a.sessions.Save(ctx, session)
313 return err
314}
315
316func (a *agent) err(err error) AgentEvent {
317 return AgentEvent{
318 Type: AgentEventTypeError,
319 Error: err,
320 }
321}
322
323func (a *agent) Run(ctx context.Context, sessionID string, content string, attachments ...message.Attachment) (<-chan AgentEvent, error) {
324 if !a.Model().SupportsImages && attachments != nil {
325 attachments = nil
326 }
327 events := make(chan AgentEvent)
328 if a.IsSessionBusy(sessionID) {
329 return nil, ErrSessionBusy
330 }
331
332 genCtx, cancel := context.WithCancel(ctx)
333
334 a.activeRequests.Set(sessionID, cancel)
335 go func() {
336 slog.Debug("Request started", "sessionID", sessionID)
337 defer log.RecoverPanic("agent.Run", func() {
338 events <- a.err(fmt.Errorf("panic while running the agent"))
339 })
340 var attachmentParts []message.ContentPart
341 for _, attachment := range attachments {
342 attachmentParts = append(attachmentParts, message.BinaryContent{Path: attachment.FilePath, MIMEType: attachment.MimeType, Data: attachment.Content})
343 }
344 result := a.processGeneration(genCtx, sessionID, content, attachmentParts)
345 if result.Error != nil && !errors.Is(result.Error, ErrRequestCancelled) && !errors.Is(result.Error, context.Canceled) {
346 slog.Error(result.Error.Error())
347 }
348 slog.Debug("Request completed", "sessionID", sessionID)
349 a.activeRequests.Del(sessionID)
350 cancel()
351 a.Publish(pubsub.CreatedEvent, result)
352 events <- result
353 close(events)
354 }()
355 return events, nil
356}
357
358func (a *agent) processGeneration(ctx context.Context, sessionID, content string, attachmentParts []message.ContentPart) AgentEvent {
359 cfg := config.Get()
360 // List existing messages; if none, start title generation asynchronously.
361 msgs, err := a.messages.List(ctx, sessionID)
362 if err != nil {
363 return a.err(fmt.Errorf("failed to list messages: %w", err))
364 }
365 if len(msgs) == 0 {
366 go func() {
367 defer log.RecoverPanic("agent.Run", func() {
368 slog.Error("panic while generating title")
369 })
370 titleErr := a.generateTitle(context.Background(), sessionID, content)
371 if titleErr != nil && !errors.Is(titleErr, context.Canceled) && !errors.Is(titleErr, context.DeadlineExceeded) {
372 slog.Error("failed to generate title", "error", titleErr)
373 }
374 }()
375 }
376 session, err := a.sessions.Get(ctx, sessionID)
377 if err != nil {
378 return a.err(fmt.Errorf("failed to get session: %w", err))
379 }
380 if session.SummaryMessageID != "" {
381 summaryMsgInex := -1
382 for i, msg := range msgs {
383 if msg.ID == session.SummaryMessageID {
384 summaryMsgInex = i
385 break
386 }
387 }
388 if summaryMsgInex != -1 {
389 msgs = msgs[summaryMsgInex:]
390 msgs[0].Role = message.User
391 }
392 }
393
394 userMsg, err := a.createUserMessage(ctx, sessionID, content, attachmentParts)
395 if err != nil {
396 return a.err(fmt.Errorf("failed to create user message: %w", err))
397 }
398 // Append the new user message to the conversation history.
399 msgHistory := append(msgs, userMsg)
400
401 for {
402 // Check for cancellation before each iteration
403 select {
404 case <-ctx.Done():
405 return a.err(ctx.Err())
406 default:
407 // Continue processing
408 }
409 agentMessage, toolResults, err := a.streamAndHandleEvents(ctx, sessionID, msgHistory)
410 if err != nil {
411 if errors.Is(err, context.Canceled) {
412 agentMessage.AddFinish(message.FinishReasonCanceled, "Request cancelled", "")
413 a.messages.Update(context.Background(), agentMessage)
414 return a.err(ErrRequestCancelled)
415 }
416 return a.err(fmt.Errorf("failed to process events: %w", err))
417 }
418 if cfg.Options.Debug {
419 slog.Info("Result", "message", agentMessage.FinishReason(), "toolResults", toolResults)
420 }
421 if (agentMessage.FinishReason() == message.FinishReasonToolUse) && toolResults != nil {
422 // We are not done, we need to respond with the tool response
423 msgHistory = append(msgHistory, agentMessage, *toolResults)
424 continue
425 }
426 if agentMessage.FinishReason() == "" {
427 // Kujtim: could not track down where this is happening but this means its cancelled
428 agentMessage.AddFinish(message.FinishReasonCanceled, "Request cancelled", "")
429 _ = a.messages.Update(context.Background(), agentMessage)
430 return a.err(ErrRequestCancelled)
431 }
432 return AgentEvent{
433 Type: AgentEventTypeResponse,
434 Message: agentMessage,
435 Done: true,
436 }
437 }
438}
439
440func (a *agent) createUserMessage(ctx context.Context, sessionID, content string, attachmentParts []message.ContentPart) (message.Message, error) {
441 parts := []message.ContentPart{message.TextContent{Text: content}}
442 parts = append(parts, attachmentParts...)
443 return a.messages.Create(ctx, sessionID, message.CreateMessageParams{
444 Role: message.User,
445 Parts: parts,
446 })
447}
448
449func (a *agent) streamAndHandleEvents(ctx context.Context, sessionID string, msgHistory []message.Message) (message.Message, *message.Message, error) {
450 ctx = context.WithValue(ctx, tools.SessionIDContextKey, sessionID)
451 eventChan := a.provider.StreamResponse(ctx, msgHistory, slices.Collect(a.tools.Seq()))
452
453 assistantMsg, err := a.messages.Create(ctx, sessionID, message.CreateMessageParams{
454 Role: message.Assistant,
455 Parts: []message.ContentPart{},
456 Model: a.Model().ID,
457 Provider: a.providerID,
458 })
459 if err != nil {
460 return assistantMsg, nil, fmt.Errorf("failed to create assistant message: %w", err)
461 }
462
463 // Add the session and message ID into the context if needed by tools.
464 ctx = context.WithValue(ctx, tools.MessageIDContextKey, assistantMsg.ID)
465
466 // Process each event in the stream.
467 for event := range eventChan {
468 if processErr := a.processEvent(ctx, sessionID, &assistantMsg, event); processErr != nil {
469 if errors.Is(processErr, context.Canceled) {
470 a.finishMessage(context.Background(), &assistantMsg, message.FinishReasonCanceled, "Request cancelled", "")
471 } else {
472 a.finishMessage(ctx, &assistantMsg, message.FinishReasonError, "API Error", processErr.Error())
473 }
474 return assistantMsg, nil, processErr
475 }
476 if ctx.Err() != nil {
477 a.finishMessage(context.Background(), &assistantMsg, message.FinishReasonCanceled, "Request cancelled", "")
478 return assistantMsg, nil, ctx.Err()
479 }
480 }
481
482 toolResults := make([]message.ToolResult, len(assistantMsg.ToolCalls()))
483 toolCalls := assistantMsg.ToolCalls()
484 for i, toolCall := range toolCalls {
485 select {
486 case <-ctx.Done():
487 a.finishMessage(context.Background(), &assistantMsg, message.FinishReasonCanceled, "Request cancelled", "")
488 // Make all future tool calls cancelled
489 for j := i; j < len(toolCalls); j++ {
490 toolResults[j] = message.ToolResult{
491 ToolCallID: toolCalls[j].ID,
492 Content: "Tool execution canceled by user",
493 IsError: true,
494 }
495 }
496 goto out
497 default:
498 // Continue processing
499 var tool tools.BaseTool
500 for availableTool := range a.tools.Seq() {
501 if availableTool.Info().Name == toolCall.Name {
502 tool = availableTool
503 break
504 }
505 }
506
507 // Tool not found
508 if tool == nil {
509 toolResults[i] = message.ToolResult{
510 ToolCallID: toolCall.ID,
511 Content: fmt.Sprintf("Tool not found: %s", toolCall.Name),
512 IsError: true,
513 }
514 continue
515 }
516
517 // Run tool in goroutine to allow cancellation
518 type toolExecResult struct {
519 response tools.ToolResponse
520 err error
521 }
522 resultChan := make(chan toolExecResult, 1)
523
524 go func() {
525 response, err := tool.Run(ctx, tools.ToolCall{
526 ID: toolCall.ID,
527 Name: toolCall.Name,
528 Input: toolCall.Input,
529 })
530 resultChan <- toolExecResult{response: response, err: err}
531 }()
532
533 var toolResponse tools.ToolResponse
534 var toolErr error
535
536 select {
537 case <-ctx.Done():
538 a.finishMessage(context.Background(), &assistantMsg, message.FinishReasonCanceled, "Request cancelled", "")
539 // Mark remaining tool calls as cancelled
540 for j := i; j < len(toolCalls); j++ {
541 toolResults[j] = message.ToolResult{
542 ToolCallID: toolCalls[j].ID,
543 Content: "Tool execution canceled by user",
544 IsError: true,
545 }
546 }
547 goto out
548 case result := <-resultChan:
549 toolResponse = result.response
550 toolErr = result.err
551 }
552
553 if toolErr != nil {
554 slog.Error("Tool execution error", "toolCall", toolCall.ID, "error", toolErr)
555 if errors.Is(toolErr, permission.ErrorPermissionDenied) {
556 toolResults[i] = message.ToolResult{
557 ToolCallID: toolCall.ID,
558 Content: "Permission denied",
559 IsError: true,
560 }
561 for j := i + 1; j < len(toolCalls); j++ {
562 toolResults[j] = message.ToolResult{
563 ToolCallID: toolCalls[j].ID,
564 Content: "Tool execution canceled by user",
565 IsError: true,
566 }
567 }
568 a.finishMessage(ctx, &assistantMsg, message.FinishReasonPermissionDenied, "Permission denied", "")
569 break
570 }
571 }
572 toolResults[i] = message.ToolResult{
573 ToolCallID: toolCall.ID,
574 Content: toolResponse.Content,
575 Metadata: toolResponse.Metadata,
576 IsError: toolResponse.IsError,
577 }
578 }
579 }
580out:
581 if len(toolResults) == 0 {
582 return assistantMsg, nil, nil
583 }
584 parts := make([]message.ContentPart, 0)
585 for _, tr := range toolResults {
586 parts = append(parts, tr)
587 }
588 msg, err := a.messages.Create(context.Background(), assistantMsg.SessionID, message.CreateMessageParams{
589 Role: message.Tool,
590 Parts: parts,
591 Provider: a.providerID,
592 })
593 if err != nil {
594 return assistantMsg, nil, fmt.Errorf("failed to create cancelled tool message: %w", err)
595 }
596
597 return assistantMsg, &msg, err
598}
599
600func (a *agent) finishMessage(ctx context.Context, msg *message.Message, finishReason message.FinishReason, message, details string) {
601 msg.AddFinish(finishReason, message, details)
602 _ = a.messages.Update(ctx, *msg)
603}
604
605func (a *agent) processEvent(ctx context.Context, sessionID string, assistantMsg *message.Message, event provider.ProviderEvent) error {
606 select {
607 case <-ctx.Done():
608 return ctx.Err()
609 default:
610 // Continue processing.
611 }
612
613 switch event.Type {
614 case provider.EventThinkingDelta:
615 assistantMsg.AppendReasoningContent(event.Thinking)
616 return a.messages.Update(ctx, *assistantMsg)
617 case provider.EventSignatureDelta:
618 assistantMsg.AppendReasoningSignature(event.Signature)
619 return a.messages.Update(ctx, *assistantMsg)
620 case provider.EventContentDelta:
621 assistantMsg.FinishThinking()
622 assistantMsg.AppendContent(event.Content)
623 return a.messages.Update(ctx, *assistantMsg)
624 case provider.EventToolUseStart:
625 assistantMsg.FinishThinking()
626 slog.Info("Tool call started", "toolCall", event.ToolCall)
627 assistantMsg.AddToolCall(*event.ToolCall)
628 return a.messages.Update(ctx, *assistantMsg)
629 case provider.EventToolUseDelta:
630 assistantMsg.AppendToolCallInput(event.ToolCall.ID, event.ToolCall.Input)
631 return a.messages.Update(ctx, *assistantMsg)
632 case provider.EventToolUseStop:
633 slog.Info("Finished tool call", "toolCall", event.ToolCall)
634 assistantMsg.FinishToolCall(event.ToolCall.ID)
635 return a.messages.Update(ctx, *assistantMsg)
636 case provider.EventError:
637 assistantMsg.SetRetrying(false)
638 if err := a.messages.Update(ctx, *assistantMsg); err != nil {
639 return fmt.Errorf("failed to update message: %w", err)
640 }
641 return event.Error
642 case provider.EventRetry:
643 errMsg := ""
644 if event.Error != nil {
645 errMsg = event.Error.Error()
646 }
647 assistantMsg.SetRetrying(false)
648 assistantMsg.AddRetry(errMsg, event.Retry)
649 return a.messages.Update(ctx, *assistantMsg)
650 case provider.EventRetrying:
651 assistantMsg.SetRetrying(true)
652 return a.messages.Update(ctx, *assistantMsg)
653
654 case provider.EventComplete:
655 assistantMsg.FinishThinking()
656 assistantMsg.SetToolCalls(event.Response.ToolCalls)
657 assistantMsg.AddFinish(event.Response.FinishReason, "", "")
658 if err := a.messages.Update(ctx, *assistantMsg); err != nil {
659 return fmt.Errorf("failed to update message: %w", err)
660 }
661 return a.TrackUsage(ctx, sessionID, a.Model(), event.Response.Usage)
662 }
663
664 return nil
665}
666
667func (a *agent) TrackUsage(ctx context.Context, sessionID string, model catwalk.Model, usage provider.TokenUsage) error {
668 sess, err := a.sessions.Get(ctx, sessionID)
669 if err != nil {
670 return fmt.Errorf("failed to get session: %w", err)
671 }
672
673 cost := model.CostPer1MInCached/1e6*float64(usage.CacheCreationTokens) +
674 model.CostPer1MOutCached/1e6*float64(usage.CacheReadTokens) +
675 model.CostPer1MIn/1e6*float64(usage.InputTokens) +
676 model.CostPer1MOut/1e6*float64(usage.OutputTokens)
677
678 sess.Cost += cost
679 sess.CompletionTokens = usage.OutputTokens + usage.CacheReadTokens
680 sess.PromptTokens = usage.InputTokens + usage.CacheCreationTokens
681
682 _, err = a.sessions.Save(ctx, sess)
683 if err != nil {
684 return fmt.Errorf("failed to save session: %w", err)
685 }
686 return nil
687}
688
689func (a *agent) Summarize(ctx context.Context, sessionID string) error {
690 if a.summarizeProvider == nil {
691 return fmt.Errorf("summarize provider not available")
692 }
693
694 // Check if session is busy
695 if a.IsSessionBusy(sessionID) {
696 return ErrSessionBusy
697 }
698
699 // Create a new context with cancellation
700 summarizeCtx, cancel := context.WithCancel(ctx)
701
702 // Store the cancel function in activeRequests to allow cancellation
703 a.activeRequests.Set(sessionID+"-summarize", cancel)
704
705 go func() {
706 defer a.activeRequests.Del(sessionID + "-summarize")
707 defer cancel()
708 event := AgentEvent{
709 Type: AgentEventTypeSummarize,
710 Progress: "Starting summarization...",
711 }
712
713 a.Publish(pubsub.CreatedEvent, event)
714 // Get all messages from the session
715 msgs, err := a.messages.List(summarizeCtx, sessionID)
716 if err != nil {
717 event = AgentEvent{
718 Type: AgentEventTypeError,
719 Error: fmt.Errorf("failed to list messages: %w", err),
720 Done: true,
721 }
722 a.Publish(pubsub.CreatedEvent, event)
723 return
724 }
725 summarizeCtx = context.WithValue(summarizeCtx, tools.SessionIDContextKey, sessionID)
726
727 if len(msgs) == 0 {
728 event = AgentEvent{
729 Type: AgentEventTypeError,
730 Error: fmt.Errorf("no messages to summarize"),
731 Done: true,
732 }
733 a.Publish(pubsub.CreatedEvent, event)
734 return
735 }
736
737 event = AgentEvent{
738 Type: AgentEventTypeSummarize,
739 Progress: "Analyzing conversation...",
740 }
741 a.Publish(pubsub.CreatedEvent, event)
742
743 // Add a system message to guide the summarization
744 summarizePrompt := "Provide a detailed but concise summary of our conversation above. Focus on information that would be helpful for continuing the conversation, including what we did, what we're doing, which files we're working on, and what we're going to do next."
745
746 // Create a new message with the summarize prompt
747 promptMsg := message.Message{
748 Role: message.User,
749 Parts: []message.ContentPart{message.TextContent{Text: summarizePrompt}},
750 }
751
752 // Append the prompt to the messages
753 msgsWithPrompt := append(msgs, promptMsg)
754
755 event = AgentEvent{
756 Type: AgentEventTypeSummarize,
757 Progress: "Generating summary...",
758 }
759
760 a.Publish(pubsub.CreatedEvent, event)
761
762 // Send the messages to the summarize provider
763 response := a.summarizeProvider.StreamResponse(
764 summarizeCtx,
765 msgsWithPrompt,
766 nil,
767 )
768 var finalResponse *provider.ProviderResponse
769 for r := range response {
770 if r.Error != nil {
771 event = AgentEvent{
772 Type: AgentEventTypeError,
773 Error: fmt.Errorf("failed to summarize: %w", err),
774 Done: true,
775 }
776 a.Publish(pubsub.CreatedEvent, event)
777 return
778 }
779 finalResponse = r.Response
780 }
781
782 summary := strings.TrimSpace(finalResponse.Content)
783 if summary == "" {
784 event = AgentEvent{
785 Type: AgentEventTypeError,
786 Error: fmt.Errorf("empty summary returned"),
787 Done: true,
788 }
789 a.Publish(pubsub.CreatedEvent, event)
790 return
791 }
792 shell := shell.GetPersistentShell(config.Get().WorkingDir())
793 summary += "\n\n**Current working directory of the persistent shell**\n\n" + shell.GetWorkingDir()
794 event = AgentEvent{
795 Type: AgentEventTypeSummarize,
796 Progress: "Creating new session...",
797 }
798
799 a.Publish(pubsub.CreatedEvent, event)
800 oldSession, err := a.sessions.Get(summarizeCtx, sessionID)
801 if err != nil {
802 event = AgentEvent{
803 Type: AgentEventTypeError,
804 Error: fmt.Errorf("failed to get session: %w", err),
805 Done: true,
806 }
807
808 a.Publish(pubsub.CreatedEvent, event)
809 return
810 }
811 // Create a message in the new session with the summary
812 msg, err := a.messages.Create(summarizeCtx, oldSession.ID, message.CreateMessageParams{
813 Role: message.Assistant,
814 Parts: []message.ContentPart{
815 message.TextContent{Text: summary},
816 message.Finish{
817 Reason: message.FinishReasonEndTurn,
818 Time: time.Now().Unix(),
819 },
820 },
821 Model: a.summarizeProvider.Model().ID,
822 Provider: a.summarizeProviderID,
823 })
824 if err != nil {
825 event = AgentEvent{
826 Type: AgentEventTypeError,
827 Error: fmt.Errorf("failed to create summary message: %w", err),
828 Done: true,
829 }
830
831 a.Publish(pubsub.CreatedEvent, event)
832 return
833 }
834 oldSession.SummaryMessageID = msg.ID
835 oldSession.CompletionTokens = finalResponse.Usage.OutputTokens
836 oldSession.PromptTokens = 0
837 model := a.summarizeProvider.Model()
838 usage := finalResponse.Usage
839 cost := model.CostPer1MInCached/1e6*float64(usage.CacheCreationTokens) +
840 model.CostPer1MOutCached/1e6*float64(usage.CacheReadTokens) +
841 model.CostPer1MIn/1e6*float64(usage.InputTokens) +
842 model.CostPer1MOut/1e6*float64(usage.OutputTokens)
843 oldSession.Cost += cost
844 _, err = a.sessions.Save(summarizeCtx, oldSession)
845 if err != nil {
846 event = AgentEvent{
847 Type: AgentEventTypeError,
848 Error: fmt.Errorf("failed to save session: %w", err),
849 Done: true,
850 }
851 a.Publish(pubsub.CreatedEvent, event)
852 }
853
854 event = AgentEvent{
855 Type: AgentEventTypeSummarize,
856 SessionID: oldSession.ID,
857 Progress: "Summary complete",
858 Done: true,
859 }
860 a.Publish(pubsub.CreatedEvent, event)
861 // Send final success event with the new session ID
862 }()
863
864 return nil
865}
866
867func (a *agent) CancelAll() {
868 if !a.IsBusy() {
869 return
870 }
871 for key := range a.activeRequests.Seq2() {
872 a.Cancel(key) // key is sessionID
873 }
874
875 timeout := time.After(5 * time.Second)
876 for a.IsBusy() {
877 select {
878 case <-timeout:
879 return
880 default:
881 time.Sleep(200 * time.Millisecond)
882 }
883 }
884}
885
886func (a *agent) UpdateModel() error {
887 cfg := config.Get()
888
889 // Get current provider configuration
890 currentProviderCfg := cfg.GetProviderForModel(a.agentCfg.Model)
891 if currentProviderCfg == nil || currentProviderCfg.ID == "" {
892 return fmt.Errorf("provider for agent %s not found in config", a.agentCfg.Name)
893 }
894
895 // Check if provider has changed
896 if string(currentProviderCfg.ID) != a.providerID {
897 // Provider changed, need to recreate the main provider
898 model := cfg.GetModelByType(a.agentCfg.Model)
899 if model.ID == "" {
900 return fmt.Errorf("model not found for agent %s", a.agentCfg.Name)
901 }
902
903 promptID := agentPromptMap[a.agentCfg.ID]
904 if promptID == "" {
905 promptID = prompt.PromptDefault
906 }
907
908 opts := []provider.ProviderClientOption{
909 provider.WithModel(a.agentCfg.Model),
910 provider.WithSystemMessage(prompt.GetPrompt(promptID, currentProviderCfg.ID, cfg.Options.ContextPaths...)),
911 }
912
913 newProvider, err := provider.NewProvider(*currentProviderCfg, opts...)
914 if err != nil {
915 return fmt.Errorf("failed to create new provider: %w", err)
916 }
917
918 // Update the provider and provider ID
919 a.provider = newProvider
920 a.providerID = string(currentProviderCfg.ID)
921 }
922
923 // Check if small model provider has changed (affects title and summarize providers)
924 smallModelCfg := cfg.Models[config.SelectedModelTypeSmall]
925 var smallModelProviderCfg config.ProviderConfig
926
927 for p := range cfg.Providers.Seq() {
928 if p.ID == smallModelCfg.Provider {
929 smallModelProviderCfg = p
930 break
931 }
932 }
933
934 if smallModelProviderCfg.ID == "" {
935 return fmt.Errorf("provider %s not found in config", smallModelCfg.Provider)
936 }
937
938 // Check if summarize provider has changed
939 if string(smallModelProviderCfg.ID) != a.summarizeProviderID {
940 smallModel := cfg.GetModelByType(config.SelectedModelTypeSmall)
941 if smallModel == nil {
942 return fmt.Errorf("model %s not found in provider %s", smallModelCfg.Model, smallModelProviderCfg.ID)
943 }
944
945 // Recreate title provider
946 titleOpts := []provider.ProviderClientOption{
947 provider.WithModel(config.SelectedModelTypeSmall),
948 provider.WithSystemMessage(prompt.GetPrompt(prompt.PromptTitle, smallModelProviderCfg.ID)),
949 // We want the title to be short, so we limit the max tokens
950 provider.WithMaxTokens(40),
951 }
952 newTitleProvider, err := provider.NewProvider(smallModelProviderCfg, titleOpts...)
953 if err != nil {
954 return fmt.Errorf("failed to create new title provider: %w", err)
955 }
956
957 // Recreate summarize provider
958 summarizeOpts := []provider.ProviderClientOption{
959 provider.WithModel(config.SelectedModelTypeSmall),
960 provider.WithSystemMessage(prompt.GetPrompt(prompt.PromptSummarizer, smallModelProviderCfg.ID)),
961 }
962 newSummarizeProvider, err := provider.NewProvider(smallModelProviderCfg, summarizeOpts...)
963 if err != nil {
964 return fmt.Errorf("failed to create new summarize provider: %w", err)
965 }
966
967 // Update the providers and provider ID
968 a.titleProvider = newTitleProvider
969 a.summarizeProvider = newSummarizeProvider
970 a.summarizeProviderID = string(smallModelProviderCfg.ID)
971 }
972
973 return nil
974}