1package agent
2
3import (
4 "context"
5 "errors"
6 "fmt"
7 "log/slog"
8 "slices"
9 "strings"
10 "time"
11
12 "github.com/charmbracelet/catwalk/pkg/catwalk"
13 "github.com/charmbracelet/crush/internal/config"
14 "github.com/charmbracelet/crush/internal/csync"
15 "github.com/charmbracelet/crush/internal/history"
16 "github.com/charmbracelet/crush/internal/llm/prompt"
17 "github.com/charmbracelet/crush/internal/llm/provider"
18 "github.com/charmbracelet/crush/internal/llm/tools"
19 "github.com/charmbracelet/crush/internal/log"
20 "github.com/charmbracelet/crush/internal/lsp"
21 "github.com/charmbracelet/crush/internal/message"
22 "github.com/charmbracelet/crush/internal/permission"
23 "github.com/charmbracelet/crush/internal/pubsub"
24 "github.com/charmbracelet/crush/internal/session"
25 "github.com/charmbracelet/crush/internal/shell"
26)
27
// Common errors
var (
	// ErrRequestCancelled is the error carried by the AgentEvent that
	// processGeneration returns when a generation is cancelled.
	ErrRequestCancelled = errors.New("request canceled by user")
	// ErrSessionBusy is returned by Run and Summarize when the session
	// already has an active request registered.
	ErrSessionBusy = errors.New("session is currently processing another request")
)
33
// AgentEventType identifies the kind of payload carried by an AgentEvent.
type AgentEventType string

const (
	// AgentEventTypeError marks an event whose Error field describes a failure.
	AgentEventTypeError AgentEventType = "error"
	// AgentEventTypeResponse marks the final assistant response produced by Run.
	AgentEventTypeResponse AgentEventType = "response"
	// AgentEventTypeSummarize marks progress and completion events emitted by Summarize.
	AgentEventTypeSummarize AgentEventType = "summarize"
)
41
// AgentEvent is the value the agent publishes through its broker and sends on
// the channel returned by Run.
type AgentEvent struct {
	// Type selects which of the remaining fields are meaningful.
	Type AgentEventType
	// Message is the final assistant message for response events.
	Message message.Message
	// Error is set for error events.
	Error error

	// When summarizing
	// SessionID is set on the final summarize event; Progress carries a
	// human-readable status line. Done is set on terminal events (it is also
	// set on the final response event of a Run).
	SessionID string
	Progress  string
	Done      bool
}
52
// Service is the public surface of an agent: it runs generations for a
// session, exposes their state, and lets callers subscribe to AgentEvents.
type Service interface {
	pubsub.Suscriber[AgentEvent]
	// Model returns the model configured for this agent.
	Model() catwalk.Model
	// Run starts an asynchronous generation for sessionID and returns a
	// channel that receives the resulting event.
	Run(ctx context.Context, sessionID string, content string, attachments ...message.Attachment) (<-chan AgentEvent, error)
	// Cancel cancels the active generation and summarization of a session, if any.
	Cancel(sessionID string)
	// CancelAll cancels every active request and waits briefly for them to stop.
	CancelAll()
	// IsSessionBusy reports whether a generation is registered for sessionID.
	IsSessionBusy(sessionID string) bool
	// IsBusy reports whether any request is currently registered.
	IsBusy() bool
	// Summarize starts an asynchronous summarization of the session.
	Summarize(ctx context.Context, sessionID string) error
	// UpdateModel rebuilds the providers after a configuration change.
	UpdateModel() error
}
64
// agent is the default Service implementation.
type agent struct {
	// Broker provides Publish/Subscribe for AgentEvents.
	*pubsub.Broker[AgentEvent]
	agentCfg config.Agent
	sessions session.Service
	messages message.Service
	// NOTE(review): mcpTools is not assigned by NewAgent in this file; the
	// tool list built there uses the package-level mcpTools instead.
	mcpTools []McpTool

	// tools is built lazily by the function passed to csync.NewLazySlice in NewAgent.
	tools *csync.LazySlice[tools.BaseTool]

	// provider serves the main generation loop; providerID is the ID of the
	// provider configuration it was created from.
	provider   provider.Provider
	providerID string

	// titleProvider and summarizeProvider are both created from the small
	// model's provider configuration, whose ID is kept in summarizeProviderID.
	titleProvider       provider.Provider
	summarizeProvider   provider.Provider
	summarizeProviderID string

	// activeRequests maps a session ID (or sessionID+"-summarize") to the
	// cancel function of the request currently running for it.
	activeRequests *csync.Map[string, context.CancelFunc]
}
83
// agentPromptMap maps an agent ID to the system prompt it uses. Agents whose
// ID is not listed here fall back to prompt.PromptDefault.
var agentPromptMap = map[string]prompt.PromptID{
	"coder": prompt.PromptCoder,
	"task":  prompt.PromptTask,
}
88
89func NewAgent(
90 ctx context.Context,
91 agentCfg config.Agent,
92 // These services are needed in the tools
93 permissions permission.Service,
94 sessions session.Service,
95 messages message.Service,
96 history history.Service,
97 lspClients map[string]*lsp.Client,
98) (Service, error) {
99 cfg := config.Get()
100
101 var agentTool tools.BaseTool
102 if agentCfg.ID == "coder" {
103 taskAgentCfg := config.Get().Agents["task"]
104 if taskAgentCfg.ID == "" {
105 return nil, fmt.Errorf("task agent not found in config")
106 }
107 taskAgent, err := NewAgent(ctx, taskAgentCfg, permissions, sessions, messages, history, lspClients)
108 if err != nil {
109 return nil, fmt.Errorf("failed to create task agent: %w", err)
110 }
111
112 agentTool = NewAgentTool(taskAgent, sessions, messages)
113 }
114
115 providerCfg := config.Get().GetProviderForModel(agentCfg.Model)
116 if providerCfg == nil {
117 return nil, fmt.Errorf("provider for agent %s not found in config", agentCfg.Name)
118 }
119 model := config.Get().GetModelByType(agentCfg.Model)
120
121 if model == nil {
122 return nil, fmt.Errorf("model not found for agent %s", agentCfg.Name)
123 }
124
125 promptID := agentPromptMap[agentCfg.ID]
126 if promptID == "" {
127 promptID = prompt.PromptDefault
128 }
129 opts := []provider.ProviderClientOption{
130 provider.WithModel(agentCfg.Model),
131 provider.WithSystemMessage(prompt.GetPrompt(promptID, providerCfg.ID, config.Get().Options.ContextPaths...)),
132 }
133 agentProvider, err := provider.NewProvider(*providerCfg, opts...)
134 if err != nil {
135 return nil, err
136 }
137
138 smallModelCfg := cfg.Models[config.SelectedModelTypeSmall]
139 var smallModelProviderCfg *config.ProviderConfig
140 if smallModelCfg.Provider == providerCfg.ID {
141 smallModelProviderCfg = providerCfg
142 } else {
143 smallModelProviderCfg = cfg.GetProviderForModel(config.SelectedModelTypeSmall)
144
145 if smallModelProviderCfg.ID == "" {
146 return nil, fmt.Errorf("provider %s not found in config", smallModelCfg.Provider)
147 }
148 }
149 smallModel := cfg.GetModelByType(config.SelectedModelTypeSmall)
150 if smallModel.ID == "" {
151 return nil, fmt.Errorf("model %s not found in provider %s", smallModelCfg.Model, smallModelProviderCfg.ID)
152 }
153
154 titleOpts := []provider.ProviderClientOption{
155 provider.WithModel(config.SelectedModelTypeSmall),
156 provider.WithSystemMessage(prompt.GetPrompt(prompt.PromptTitle, smallModelProviderCfg.ID)),
157 }
158 titleProvider, err := provider.NewProvider(*smallModelProviderCfg, titleOpts...)
159 if err != nil {
160 return nil, err
161 }
162 summarizeOpts := []provider.ProviderClientOption{
163 provider.WithModel(config.SelectedModelTypeSmall),
164 provider.WithSystemMessage(prompt.GetPrompt(prompt.PromptSummarizer, smallModelProviderCfg.ID)),
165 }
166 summarizeProvider, err := provider.NewProvider(*smallModelProviderCfg, summarizeOpts...)
167 if err != nil {
168 return nil, err
169 }
170
171 toolFn := func() []tools.BaseTool {
172 slog.Info("Initializing agent tools", "agent", agentCfg.ID)
173 defer func() {
174 slog.Info("Initialized agent tools", "agent", agentCfg.ID)
175 }()
176
177 cwd := cfg.WorkingDir()
178 allTools := []tools.BaseTool{
179 tools.NewBashTool(permissions, cwd),
180 tools.NewDownloadTool(permissions, cwd),
181 tools.NewEditTool(lspClients, permissions, history, cwd),
182 tools.NewMultiEditTool(lspClients, permissions, history, cwd),
183 tools.NewFetchTool(permissions, cwd),
184 tools.NewGlobTool(cwd),
185 tools.NewGrepTool(cwd),
186 tools.NewLsTool(permissions, cwd),
187 tools.NewSourcegraphTool(),
188 tools.NewViewTool(lspClients, permissions, cwd),
189 tools.NewWriteTool(lspClients, permissions, history, cwd),
190 }
191
192 mcpToolsOnce.Do(func() {
193 mcpTools = doGetMCPTools(ctx, permissions, cfg)
194 })
195 allTools = append(allTools, mcpTools...)
196
197 if len(lspClients) > 0 {
198 allTools = append(allTools, tools.NewDiagnosticsTool(lspClients))
199 }
200
201 if agentTool != nil {
202 allTools = append(allTools, agentTool)
203 }
204
205 if agentCfg.AllowedTools == nil {
206 return allTools
207 }
208
209 var filteredTools []tools.BaseTool
210 for _, tool := range allTools {
211 if slices.Contains(agentCfg.AllowedTools, tool.Name()) {
212 filteredTools = append(filteredTools, tool)
213 }
214 }
215 return filteredTools
216 }
217
218 return &agent{
219 Broker: pubsub.NewBroker[AgentEvent](),
220 agentCfg: agentCfg,
221 provider: agentProvider,
222 providerID: string(providerCfg.ID),
223 messages: messages,
224 sessions: sessions,
225 titleProvider: titleProvider,
226 summarizeProvider: summarizeProvider,
227 summarizeProviderID: string(smallModelProviderCfg.ID),
228 activeRequests: csync.NewMap[string, context.CancelFunc](),
229 tools: csync.NewLazySlice(toolFn),
230 }, nil
231}
232
233func (a *agent) Model() catwalk.Model {
234 return *config.Get().GetModelByType(a.agentCfg.Model)
235}
236
237func (a *agent) Cancel(sessionID string) {
238 // Cancel regular requests
239 if cancel, ok := a.activeRequests.Take(sessionID); ok && cancel != nil {
240 slog.Info("Request cancellation initiated", "session_id", sessionID)
241 cancel()
242 }
243
244 // Also check for summarize requests
245 if cancel, ok := a.activeRequests.Take(sessionID + "-summarize"); ok && cancel != nil {
246 slog.Info("Summarize cancellation initiated", "session_id", sessionID)
247 cancel()
248 }
249}
250
251func (a *agent) IsBusy() bool {
252 var busy bool
253 for cancelFunc := range a.activeRequests.Seq() {
254 if cancelFunc != nil {
255 busy = true
256 break
257 }
258 }
259 return busy
260}
261
262func (a *agent) IsSessionBusy(sessionID string) bool {
263 _, busy := a.activeRequests.Get(sessionID)
264 return busy
265}
266
267func (a *agent) generateTitle(ctx context.Context, sessionID string, content string) error {
268 if content == "" {
269 return nil
270 }
271 if a.titleProvider == nil {
272 return nil
273 }
274 session, err := a.sessions.Get(ctx, sessionID)
275 if err != nil {
276 return err
277 }
278 parts := []message.ContentPart{message.TextContent{
279 Text: fmt.Sprintf("Generate a concise title for the following content:\n\n%s", content),
280 }}
281
282 // Use streaming approach like summarization
283 response := a.titleProvider.StreamResponse(
284 ctx,
285 []message.Message{
286 {
287 Role: message.User,
288 Parts: parts,
289 },
290 },
291 nil,
292 )
293
294 var finalResponse *provider.ProviderResponse
295 for r := range response {
296 if r.Error != nil {
297 return r.Error
298 }
299 finalResponse = r.Response
300 }
301
302 if finalResponse == nil {
303 return fmt.Errorf("no response received from title provider")
304 }
305
306 title := strings.TrimSpace(strings.ReplaceAll(finalResponse.Content, "\n", " "))
307 if title == "" {
308 return nil
309 }
310
311 session.Title = title
312 _, err = a.sessions.Save(ctx, session)
313 return err
314}
315
316func (a *agent) err(err error) AgentEvent {
317 return AgentEvent{
318 Type: AgentEventTypeError,
319 Error: err,
320 }
321}
322
323func (a *agent) Run(ctx context.Context, sessionID string, content string, attachments ...message.Attachment) (<-chan AgentEvent, error) {
324 if !a.Model().SupportsImages && attachments != nil {
325 attachments = nil
326 }
327 events := make(chan AgentEvent)
328 if a.IsSessionBusy(sessionID) {
329 return nil, ErrSessionBusy
330 }
331
332 genCtx, cancel := context.WithCancel(ctx)
333
334 a.activeRequests.Set(sessionID, cancel)
335 go func() {
336 slog.Debug("Request started", "sessionID", sessionID)
337 defer log.RecoverPanic("agent.Run", func() {
338 events <- a.err(fmt.Errorf("panic while running the agent"))
339 })
340 var attachmentParts []message.ContentPart
341 for _, attachment := range attachments {
342 attachmentParts = append(attachmentParts, message.BinaryContent{Path: attachment.FilePath, MIMEType: attachment.MimeType, Data: attachment.Content})
343 }
344 result := a.processGeneration(genCtx, sessionID, content, attachmentParts)
345 if result.Error != nil && !errors.Is(result.Error, ErrRequestCancelled) && !errors.Is(result.Error, context.Canceled) {
346 slog.Error(result.Error.Error())
347 }
348 slog.Debug("Request completed", "sessionID", sessionID)
349 a.activeRequests.Del(sessionID)
350 cancel()
351 a.Publish(pubsub.CreatedEvent, result)
352 events <- result
353 close(events)
354 }()
355 return events, nil
356}
357
358func (a *agent) processGeneration(ctx context.Context, sessionID, content string, attachmentParts []message.ContentPart) AgentEvent {
359 cfg := config.Get()
360 // List existing messages; if none, start title generation asynchronously.
361 msgs, err := a.messages.List(ctx, sessionID)
362 if err != nil {
363 return a.err(fmt.Errorf("failed to list messages: %w", err))
364 }
365 if len(msgs) == 0 {
366 go func() {
367 defer log.RecoverPanic("agent.Run", func() {
368 slog.Error("panic while generating title")
369 })
370 titleErr := a.generateTitle(context.Background(), sessionID, content)
371 if titleErr != nil && !errors.Is(titleErr, context.Canceled) && !errors.Is(titleErr, context.DeadlineExceeded) {
372 slog.Error("failed to generate title", "error", titleErr)
373 }
374 }()
375 }
376 session, err := a.sessions.Get(ctx, sessionID)
377 if err != nil {
378 return a.err(fmt.Errorf("failed to get session: %w", err))
379 }
380 if session.SummaryMessageID != "" {
381 summaryMsgInex := -1
382 for i, msg := range msgs {
383 if msg.ID == session.SummaryMessageID {
384 summaryMsgInex = i
385 break
386 }
387 }
388 if summaryMsgInex != -1 {
389 msgs = msgs[summaryMsgInex:]
390 msgs[0].Role = message.User
391 }
392 }
393
394 // clean messages
395 // if there is a tool call that has no tool response here we need to mark it as cancelled this could have happened by a crash or something similar
396 resultsMap := make(map[string]bool, 0) // toolCallId=>true
397 for _, msg := range msgs {
398 if msg.Role == message.Tool {
399 results := msg.ToolResults()
400 for _, result := range results {
401 resultsMap[result.ToolCallID] = true
402 }
403 }
404 }
405 for _, msg := range msgs {
406 if msg.Role == message.Assistant && len(msg.ToolCalls()) > 0 {
407 for _, tc := range msg.ToolCalls() {
408 if _, ok := resultsMap[tc.ID]; !ok {
409 a.finishMessage(context.Background(), &msg, message.FinishReasonCanceled, "Request cancelled", "")
410 goto next
411 }
412 }
413 next:
414 }
415 }
416
417 userMsg, err := a.createUserMessage(ctx, sessionID, content, attachmentParts)
418 if err != nil {
419 return a.err(fmt.Errorf("failed to create user message: %w", err))
420 }
421 // Append the new user message to the conversation history.
422 msgHistory := append(msgs, userMsg)
423
424 for {
425 // Check for cancellation before each iteration
426 select {
427 case <-ctx.Done():
428 return a.err(ctx.Err())
429 default:
430 // Continue processing
431 }
432 agentMessage, toolResults, err := a.streamAndHandleEvents(ctx, sessionID, msgHistory)
433 if err != nil {
434 if errors.Is(err, context.Canceled) {
435 agentMessage.AddFinish(message.FinishReasonCanceled, "Request cancelled", "")
436 a.messages.Update(context.Background(), agentMessage)
437 return a.err(ErrRequestCancelled)
438 }
439 return a.err(fmt.Errorf("failed to process events: %w", err))
440 }
441 if cfg.Options.Debug {
442 slog.Info("Result", "message", agentMessage.FinishReason(), "toolResults", toolResults)
443 }
444 if (agentMessage.FinishReason() == message.FinishReasonToolUse) && toolResults != nil {
445 // We are not done, we need to respond with the tool response
446 msgHistory = append(msgHistory, agentMessage, *toolResults)
447 continue
448 }
449 if agentMessage.FinishReason() == "" {
450 // Kujtim: could not track down where this is happening but this means its cancelled
451 agentMessage.AddFinish(message.FinishReasonCanceled, "Request cancelled", "")
452 _ = a.messages.Update(context.Background(), agentMessage)
453 return a.err(ErrRequestCancelled)
454 }
455 return AgentEvent{
456 Type: AgentEventTypeResponse,
457 Message: agentMessage,
458 Done: true,
459 }
460 }
461}
462
463func (a *agent) createUserMessage(ctx context.Context, sessionID, content string, attachmentParts []message.ContentPart) (message.Message, error) {
464 parts := []message.ContentPart{message.TextContent{Text: content}}
465 parts = append(parts, attachmentParts...)
466 return a.messages.Create(ctx, sessionID, message.CreateMessageParams{
467 Role: message.User,
468 Parts: parts,
469 })
470}
471
472func (a *agent) streamAndHandleEvents(ctx context.Context, sessionID string, msgHistory []message.Message) (message.Message, *message.Message, error) {
473 ctx = context.WithValue(ctx, tools.SessionIDContextKey, sessionID)
474 eventChan := a.provider.StreamResponse(ctx, msgHistory, slices.Collect(a.tools.Seq()))
475
476 assistantMsg, err := a.messages.Create(ctx, sessionID, message.CreateMessageParams{
477 Role: message.Assistant,
478 Parts: []message.ContentPart{},
479 Model: a.Model().ID,
480 Provider: a.providerID,
481 })
482 if err != nil {
483 return assistantMsg, nil, fmt.Errorf("failed to create assistant message: %w", err)
484 }
485
486 // Add the session and message ID into the context if needed by tools.
487 ctx = context.WithValue(ctx, tools.MessageIDContextKey, assistantMsg.ID)
488
489 // Process each event in the stream.
490 for event := range eventChan {
491 if processErr := a.processEvent(ctx, sessionID, &assistantMsg, event); processErr != nil {
492 if errors.Is(processErr, context.Canceled) {
493 a.finishMessage(context.Background(), &assistantMsg, message.FinishReasonCanceled, "Request cancelled", "")
494 } else {
495 a.finishMessage(ctx, &assistantMsg, message.FinishReasonError, "API Error", processErr.Error())
496 }
497 return assistantMsg, nil, processErr
498 }
499 if ctx.Err() != nil {
500 a.finishMessage(context.Background(), &assistantMsg, message.FinishReasonCanceled, "Request cancelled", "")
501 return assistantMsg, nil, ctx.Err()
502 }
503 }
504
505 toolResults := make([]message.ToolResult, len(assistantMsg.ToolCalls()))
506 toolCalls := assistantMsg.ToolCalls()
507 for i, toolCall := range toolCalls {
508 select {
509 case <-ctx.Done():
510 a.finishMessage(context.Background(), &assistantMsg, message.FinishReasonCanceled, "Request cancelled", "")
511 // Make all future tool calls cancelled
512 for j := i; j < len(toolCalls); j++ {
513 toolResults[j] = message.ToolResult{
514 ToolCallID: toolCalls[j].ID,
515 Content: "Tool execution canceled by user",
516 IsError: true,
517 }
518 }
519 goto out
520 default:
521 // Continue processing
522 var tool tools.BaseTool
523 for availableTool := range a.tools.Seq() {
524 if availableTool.Info().Name == toolCall.Name {
525 tool = availableTool
526 break
527 }
528 }
529
530 // Tool not found
531 if tool == nil {
532 toolResults[i] = message.ToolResult{
533 ToolCallID: toolCall.ID,
534 Content: fmt.Sprintf("Tool not found: %s", toolCall.Name),
535 IsError: true,
536 }
537 continue
538 }
539
540 // Run tool in goroutine to allow cancellation
541 type toolExecResult struct {
542 response tools.ToolResponse
543 err error
544 }
545 resultChan := make(chan toolExecResult, 1)
546
547 go func() {
548 response, err := tool.Run(ctx, tools.ToolCall{
549 ID: toolCall.ID,
550 Name: toolCall.Name,
551 Input: toolCall.Input,
552 })
553 resultChan <- toolExecResult{response: response, err: err}
554 }()
555
556 var toolResponse tools.ToolResponse
557 var toolErr error
558
559 select {
560 case <-ctx.Done():
561 a.finishMessage(context.Background(), &assistantMsg, message.FinishReasonCanceled, "Request cancelled", "")
562 // Mark remaining tool calls as cancelled
563 for j := i; j < len(toolCalls); j++ {
564 toolResults[j] = message.ToolResult{
565 ToolCallID: toolCalls[j].ID,
566 Content: "Tool execution canceled by user",
567 IsError: true,
568 }
569 }
570 goto out
571 case result := <-resultChan:
572 toolResponse = result.response
573 toolErr = result.err
574 }
575
576 if toolErr != nil {
577 slog.Error("Tool execution error", "toolCall", toolCall.ID, "error", toolErr)
578 if errors.Is(toolErr, permission.ErrorPermissionDenied) {
579 toolResults[i] = message.ToolResult{
580 ToolCallID: toolCall.ID,
581 Content: "Permission denied",
582 IsError: true,
583 }
584 for j := i + 1; j < len(toolCalls); j++ {
585 toolResults[j] = message.ToolResult{
586 ToolCallID: toolCalls[j].ID,
587 Content: "Tool execution canceled by user",
588 IsError: true,
589 }
590 }
591 a.finishMessage(ctx, &assistantMsg, message.FinishReasonPermissionDenied, "Permission denied", "")
592 break
593 }
594 }
595 toolResults[i] = message.ToolResult{
596 ToolCallID: toolCall.ID,
597 Content: toolResponse.Content,
598 Metadata: toolResponse.Metadata,
599 IsError: toolResponse.IsError,
600 }
601 }
602 }
603out:
604 if len(toolResults) == 0 {
605 return assistantMsg, nil, nil
606 }
607 parts := make([]message.ContentPart, 0)
608 for _, tr := range toolResults {
609 parts = append(parts, tr)
610 }
611 msg, err := a.messages.Create(context.Background(), assistantMsg.SessionID, message.CreateMessageParams{
612 Role: message.Tool,
613 Parts: parts,
614 Provider: a.providerID,
615 })
616 if err != nil {
617 return assistantMsg, nil, fmt.Errorf("failed to create cancelled tool message: %w", err)
618 }
619
620 return assistantMsg, &msg, err
621}
622
623func (a *agent) finishMessage(ctx context.Context, msg *message.Message, finishReason message.FinishReason, message, details string) {
624 msg.AddFinish(finishReason, message, details)
625 _ = a.messages.Update(ctx, *msg)
626}
627
628func (a *agent) processEvent(ctx context.Context, sessionID string, assistantMsg *message.Message, event provider.ProviderEvent) error {
629 select {
630 case <-ctx.Done():
631 return ctx.Err()
632 default:
633 // Continue processing.
634 }
635
636 switch event.Type {
637 case provider.EventThinkingDelta:
638 assistantMsg.AppendReasoningContent(event.Thinking)
639 return a.messages.Update(ctx, *assistantMsg)
640 case provider.EventSignatureDelta:
641 assistantMsg.AppendReasoningSignature(event.Signature)
642 return a.messages.Update(ctx, *assistantMsg)
643 case provider.EventContentDelta:
644 assistantMsg.FinishThinking()
645 assistantMsg.AppendContent(event.Content)
646 return a.messages.Update(ctx, *assistantMsg)
647 case provider.EventToolUseStart:
648 assistantMsg.FinishThinking()
649 slog.Info("Tool call started", "toolCall", event.ToolCall)
650 assistantMsg.AddToolCall(*event.ToolCall)
651 return a.messages.Update(ctx, *assistantMsg)
652 case provider.EventToolUseDelta:
653 assistantMsg.AppendToolCallInput(event.ToolCall.ID, event.ToolCall.Input)
654 return a.messages.Update(ctx, *assistantMsg)
655 case provider.EventToolUseStop:
656 slog.Info("Finished tool call", "toolCall", event.ToolCall)
657 assistantMsg.FinishToolCall(event.ToolCall.ID)
658 return a.messages.Update(ctx, *assistantMsg)
659 case provider.EventError:
660 return event.Error
661 case provider.EventComplete:
662 assistantMsg.FinishThinking()
663 assistantMsg.SetToolCalls(event.Response.ToolCalls)
664 assistantMsg.AddFinish(event.Response.FinishReason, "", "")
665 if err := a.messages.Update(ctx, *assistantMsg); err != nil {
666 return fmt.Errorf("failed to update message: %w", err)
667 }
668 return a.TrackUsage(ctx, sessionID, a.Model(), event.Response.Usage)
669 }
670
671 return nil
672}
673
674func (a *agent) TrackUsage(ctx context.Context, sessionID string, model catwalk.Model, usage provider.TokenUsage) error {
675 sess, err := a.sessions.Get(ctx, sessionID)
676 if err != nil {
677 return fmt.Errorf("failed to get session: %w", err)
678 }
679
680 cost := model.CostPer1MInCached/1e6*float64(usage.CacheCreationTokens) +
681 model.CostPer1MOutCached/1e6*float64(usage.CacheReadTokens) +
682 model.CostPer1MIn/1e6*float64(usage.InputTokens) +
683 model.CostPer1MOut/1e6*float64(usage.OutputTokens)
684
685 sess.Cost += cost
686 sess.CompletionTokens = usage.OutputTokens + usage.CacheReadTokens
687 sess.PromptTokens = usage.InputTokens + usage.CacheCreationTokens
688
689 _, err = a.sessions.Save(ctx, sess)
690 if err != nil {
691 return fmt.Errorf("failed to save session: %w", err)
692 }
693 return nil
694}
695
696func (a *agent) Summarize(ctx context.Context, sessionID string) error {
697 if a.summarizeProvider == nil {
698 return fmt.Errorf("summarize provider not available")
699 }
700
701 // Check if session is busy
702 if a.IsSessionBusy(sessionID) {
703 return ErrSessionBusy
704 }
705
706 // Create a new context with cancellation
707 summarizeCtx, cancel := context.WithCancel(ctx)
708
709 // Store the cancel function in activeRequests to allow cancellation
710 a.activeRequests.Set(sessionID+"-summarize", cancel)
711
712 go func() {
713 defer a.activeRequests.Del(sessionID + "-summarize")
714 defer cancel()
715 event := AgentEvent{
716 Type: AgentEventTypeSummarize,
717 Progress: "Starting summarization...",
718 }
719
720 a.Publish(pubsub.CreatedEvent, event)
721 // Get all messages from the session
722 msgs, err := a.messages.List(summarizeCtx, sessionID)
723 if err != nil {
724 event = AgentEvent{
725 Type: AgentEventTypeError,
726 Error: fmt.Errorf("failed to list messages: %w", err),
727 Done: true,
728 }
729 a.Publish(pubsub.CreatedEvent, event)
730 return
731 }
732 summarizeCtx = context.WithValue(summarizeCtx, tools.SessionIDContextKey, sessionID)
733
734 if len(msgs) == 0 {
735 event = AgentEvent{
736 Type: AgentEventTypeError,
737 Error: fmt.Errorf("no messages to summarize"),
738 Done: true,
739 }
740 a.Publish(pubsub.CreatedEvent, event)
741 return
742 }
743
744 event = AgentEvent{
745 Type: AgentEventTypeSummarize,
746 Progress: "Analyzing conversation...",
747 }
748 a.Publish(pubsub.CreatedEvent, event)
749
750 // Add a system message to guide the summarization
751 summarizePrompt := "Provide a detailed but concise summary of our conversation above. Focus on information that would be helpful for continuing the conversation, including what we did, what we're doing, which files we're working on, and what we're going to do next."
752
753 // Create a new message with the summarize prompt
754 promptMsg := message.Message{
755 Role: message.User,
756 Parts: []message.ContentPart{message.TextContent{Text: summarizePrompt}},
757 }
758
759 // Append the prompt to the messages
760 msgsWithPrompt := append(msgs, promptMsg)
761
762 event = AgentEvent{
763 Type: AgentEventTypeSummarize,
764 Progress: "Generating summary...",
765 }
766
767 a.Publish(pubsub.CreatedEvent, event)
768
769 // Send the messages to the summarize provider
770 response := a.summarizeProvider.StreamResponse(
771 summarizeCtx,
772 msgsWithPrompt,
773 nil,
774 )
775 var finalResponse *provider.ProviderResponse
776 for r := range response {
777 if r.Error != nil {
778 event = AgentEvent{
779 Type: AgentEventTypeError,
780 Error: fmt.Errorf("failed to summarize: %w", err),
781 Done: true,
782 }
783 a.Publish(pubsub.CreatedEvent, event)
784 return
785 }
786 finalResponse = r.Response
787 }
788
789 summary := strings.TrimSpace(finalResponse.Content)
790 if summary == "" {
791 event = AgentEvent{
792 Type: AgentEventTypeError,
793 Error: fmt.Errorf("empty summary returned"),
794 Done: true,
795 }
796 a.Publish(pubsub.CreatedEvent, event)
797 return
798 }
799 shell := shell.GetPersistentShell(config.Get().WorkingDir())
800 summary += "\n\n**Current working directory of the persistent shell**\n\n" + shell.GetWorkingDir()
801 event = AgentEvent{
802 Type: AgentEventTypeSummarize,
803 Progress: "Creating new session...",
804 }
805
806 a.Publish(pubsub.CreatedEvent, event)
807 oldSession, err := a.sessions.Get(summarizeCtx, sessionID)
808 if err != nil {
809 event = AgentEvent{
810 Type: AgentEventTypeError,
811 Error: fmt.Errorf("failed to get session: %w", err),
812 Done: true,
813 }
814
815 a.Publish(pubsub.CreatedEvent, event)
816 return
817 }
818 // Create a message in the new session with the summary
819 msg, err := a.messages.Create(summarizeCtx, oldSession.ID, message.CreateMessageParams{
820 Role: message.Assistant,
821 Parts: []message.ContentPart{
822 message.TextContent{Text: summary},
823 message.Finish{
824 Reason: message.FinishReasonEndTurn,
825 Time: time.Now().Unix(),
826 },
827 },
828 Model: a.summarizeProvider.Model().ID,
829 Provider: a.summarizeProviderID,
830 })
831 if err != nil {
832 event = AgentEvent{
833 Type: AgentEventTypeError,
834 Error: fmt.Errorf("failed to create summary message: %w", err),
835 Done: true,
836 }
837
838 a.Publish(pubsub.CreatedEvent, event)
839 return
840 }
841 oldSession.SummaryMessageID = msg.ID
842 oldSession.CompletionTokens = finalResponse.Usage.OutputTokens
843 oldSession.PromptTokens = 0
844 model := a.summarizeProvider.Model()
845 usage := finalResponse.Usage
846 cost := model.CostPer1MInCached/1e6*float64(usage.CacheCreationTokens) +
847 model.CostPer1MOutCached/1e6*float64(usage.CacheReadTokens) +
848 model.CostPer1MIn/1e6*float64(usage.InputTokens) +
849 model.CostPer1MOut/1e6*float64(usage.OutputTokens)
850 oldSession.Cost += cost
851 _, err = a.sessions.Save(summarizeCtx, oldSession)
852 if err != nil {
853 event = AgentEvent{
854 Type: AgentEventTypeError,
855 Error: fmt.Errorf("failed to save session: %w", err),
856 Done: true,
857 }
858 a.Publish(pubsub.CreatedEvent, event)
859 }
860
861 event = AgentEvent{
862 Type: AgentEventTypeSummarize,
863 SessionID: oldSession.ID,
864 Progress: "Summary complete",
865 Done: true,
866 }
867 a.Publish(pubsub.CreatedEvent, event)
868 // Send final success event with the new session ID
869 }()
870
871 return nil
872}
873
874func (a *agent) CancelAll() {
875 if !a.IsBusy() {
876 return
877 }
878 for key := range a.activeRequests.Seq2() {
879 a.Cancel(key) // key is sessionID
880 }
881
882 timeout := time.After(5 * time.Second)
883 for a.IsBusy() {
884 select {
885 case <-timeout:
886 return
887 default:
888 time.Sleep(200 * time.Millisecond)
889 }
890 }
891}
892
893func (a *agent) UpdateModel() error {
894 cfg := config.Get()
895
896 // Get current provider configuration
897 currentProviderCfg := cfg.GetProviderForModel(a.agentCfg.Model)
898 if currentProviderCfg == nil || currentProviderCfg.ID == "" {
899 return fmt.Errorf("provider for agent %s not found in config", a.agentCfg.Name)
900 }
901
902 // Check if provider has changed
903 if string(currentProviderCfg.ID) != a.providerID {
904 // Provider changed, need to recreate the main provider
905 model := cfg.GetModelByType(a.agentCfg.Model)
906 if model.ID == "" {
907 return fmt.Errorf("model not found for agent %s", a.agentCfg.Name)
908 }
909
910 promptID := agentPromptMap[a.agentCfg.ID]
911 if promptID == "" {
912 promptID = prompt.PromptDefault
913 }
914
915 opts := []provider.ProviderClientOption{
916 provider.WithModel(a.agentCfg.Model),
917 provider.WithSystemMessage(prompt.GetPrompt(promptID, currentProviderCfg.ID, cfg.Options.ContextPaths...)),
918 }
919
920 newProvider, err := provider.NewProvider(*currentProviderCfg, opts...)
921 if err != nil {
922 return fmt.Errorf("failed to create new provider: %w", err)
923 }
924
925 // Update the provider and provider ID
926 a.provider = newProvider
927 a.providerID = string(currentProviderCfg.ID)
928 }
929
930 // Check if small model provider has changed (affects title and summarize providers)
931 smallModelCfg := cfg.Models[config.SelectedModelTypeSmall]
932 var smallModelProviderCfg config.ProviderConfig
933
934 for p := range cfg.Providers.Seq() {
935 if p.ID == smallModelCfg.Provider {
936 smallModelProviderCfg = p
937 break
938 }
939 }
940
941 if smallModelProviderCfg.ID == "" {
942 return fmt.Errorf("provider %s not found in config", smallModelCfg.Provider)
943 }
944
945 // Check if summarize provider has changed
946 if string(smallModelProviderCfg.ID) != a.summarizeProviderID {
947 smallModel := cfg.GetModelByType(config.SelectedModelTypeSmall)
948 if smallModel == nil {
949 return fmt.Errorf("model %s not found in provider %s", smallModelCfg.Model, smallModelProviderCfg.ID)
950 }
951
952 // Recreate title provider
953 titleOpts := []provider.ProviderClientOption{
954 provider.WithModel(config.SelectedModelTypeSmall),
955 provider.WithSystemMessage(prompt.GetPrompt(prompt.PromptTitle, smallModelProviderCfg.ID)),
956 // We want the title to be short, so we limit the max tokens
957 provider.WithMaxTokens(40),
958 }
959 newTitleProvider, err := provider.NewProvider(smallModelProviderCfg, titleOpts...)
960 if err != nil {
961 return fmt.Errorf("failed to create new title provider: %w", err)
962 }
963
964 // Recreate summarize provider
965 summarizeOpts := []provider.ProviderClientOption{
966 provider.WithModel(config.SelectedModelTypeSmall),
967 provider.WithSystemMessage(prompt.GetPrompt(prompt.PromptSummarizer, smallModelProviderCfg.ID)),
968 }
969 newSummarizeProvider, err := provider.NewProvider(smallModelProviderCfg, summarizeOpts...)
970 if err != nil {
971 return fmt.Errorf("failed to create new summarize provider: %w", err)
972 }
973
974 // Update the providers and provider ID
975 a.titleProvider = newTitleProvider
976 a.summarizeProvider = newSummarizeProvider
977 a.summarizeProviderID = string(smallModelProviderCfg.ID)
978 }
979
980 return nil
981}