1package agent
2
3import (
4 "context"
5 "errors"
6 "fmt"
7 "log/slog"
8 "slices"
9 "strings"
10 "sync"
11 "time"
12
13 "github.com/charmbracelet/catwalk/pkg/catwalk"
14 "github.com/charmbracelet/crush/internal/config"
15 "github.com/charmbracelet/crush/internal/csync"
16 "github.com/charmbracelet/crush/internal/history"
17 "github.com/charmbracelet/crush/internal/llm/prompt"
18 "github.com/charmbracelet/crush/internal/llm/provider"
19 "github.com/charmbracelet/crush/internal/llm/tools"
20 "github.com/charmbracelet/crush/internal/log"
21 "github.com/charmbracelet/crush/internal/lsp"
22 "github.com/charmbracelet/crush/internal/message"
23 "github.com/charmbracelet/crush/internal/permission"
24 "github.com/charmbracelet/crush/internal/pubsub"
25 "github.com/charmbracelet/crush/internal/session"
26 "github.com/charmbracelet/crush/internal/shell"
27)
28
// Common errors surfaced by the agent service.
var (
	// ErrRequestCancelled is reported in the resulting AgentEvent when a
	// generation stops because its context was canceled.
	ErrRequestCancelled = errors.New("request canceled by user")
	// ErrSessionBusy is returned by Run and Summarize when the session already
	// has a request registered as active.
	ErrSessionBusy = errors.New("session is currently processing another request")
)
34
// AgentEventType identifies the kind of an AgentEvent.
type AgentEventType string

const (
	// AgentEventTypeError marks an event whose Error field carries a failure.
	AgentEventTypeError AgentEventType = "error"
	// AgentEventTypeResponse marks an event carrying the final assistant message
	// of a generation.
	AgentEventTypeResponse AgentEventType = "response"
	// AgentEventTypeSummarize marks a progress or completion event emitted while
	// summarizing a session.
	AgentEventTypeSummarize AgentEventType = "summarize"
)
42
// AgentEvent is the value published by the agent and delivered on the channel
// returned by Run.
type AgentEvent struct {
	// Type says which of the remaining fields are meaningful.
	Type AgentEventType
	// Message is the assistant message for response events.
	Message message.Message
	// Error is set for error events.
	Error error

	// The fields below are populated while summarizing.

	// SessionID is the session that was summarized (set on the final event).
	SessionID string
	// Progress is a human-readable status line.
	Progress string
	// Done is true on the last event of an operation.
	Done bool
}
53
// Service is the public surface of an agent.
type Service interface {
	// Embedded subscriber interface for AgentEvent values.
	pubsub.Suscriber[AgentEvent]
	// Model returns the model currently configured for the agent.
	Model() catwalk.Model
	// Run starts a generation for the session and returns a channel that
	// receives the final event; it fails with ErrSessionBusy when the session
	// already has an active request.
	Run(ctx context.Context, sessionID string, content string, attachments ...message.Attachment) (<-chan AgentEvent, error)
	// Cancel cancels the active request and any active summarization of the
	// given session.
	Cancel(sessionID string)
	// CancelAll cancels every active request.
	CancelAll()
	// IsSessionBusy reports whether the session has an active request.
	IsSessionBusy(sessionID string) bool
	// IsBusy reports whether any request is active.
	IsBusy() bool
	// Summarize starts an asynchronous summarization of the session; progress
	// and results are delivered as published events.
	Summarize(ctx context.Context, sessionID string) error
	// UpdateModel rebuilds the providers whose configuration has changed.
	UpdateModel() error
}
65
// agent is the Service implementation.
type agent struct {
	// Broker is used to publish AgentEvent values.
	*pubsub.Broker[AgentEvent]
	// agentCfg is the configuration this agent was created from.
	agentCfg config.Agent
	sessions session.Service
	messages message.Service

	// tools is built lazily by the function supplied in NewAgent.
	tools *csync.LazySlice[tools.BaseTool]

	// provider streams the main generation; providerID is recorded on the
	// messages it produces and used by UpdateModel to detect changes.
	provider   provider.Provider
	providerID string

	// titleProvider and summarizeProvider both use the small model.
	titleProvider       provider.Provider
	summarizeProvider   provider.Provider
	summarizeProviderID string

	// activeRequests maps a session ID (or session ID + "-summarize") to the
	// context.CancelFunc of the request in flight.
	activeRequests sync.Map
}
83
// agentPromptMap selects the system prompt by agent ID. Agents whose ID is not
// listed here fall back to prompt.PromptDefault at the lookup sites.
var agentPromptMap = map[string]prompt.PromptID{
	"coder": prompt.PromptCoder,
	"task":  prompt.PromptTask,
}
88
89func NewAgent(
90 agentCfg config.Agent,
91 // These services are needed in the tools
92 permissions permission.Service,
93 sessions session.Service,
94 messages message.Service,
95 history history.Service,
96 lspClients map[string]*lsp.Client,
97) (Service, error) {
98 ctx := context.Background()
99 cfg := config.Get()
100
101 var agentTool tools.BaseTool
102 if agentCfg.ID == "coder" {
103 taskAgentCfg := config.Get().Agents["task"]
104 if taskAgentCfg.ID == "" {
105 return nil, fmt.Errorf("task agent not found in config")
106 }
107 taskAgent, err := NewAgent(taskAgentCfg, permissions, sessions, messages, history, lspClients)
108 if err != nil {
109 return nil, fmt.Errorf("failed to create task agent: %w", err)
110 }
111
112 agentTool = NewAgentTool(taskAgent, sessions, messages)
113 }
114
115 providerCfg := config.Get().GetProviderForModel(agentCfg.Model)
116 if providerCfg == nil {
117 return nil, fmt.Errorf("provider for agent %s not found in config", agentCfg.Name)
118 }
119 model := config.Get().GetModelByType(agentCfg.Model)
120
121 if model == nil {
122 return nil, fmt.Errorf("model not found for agent %s", agentCfg.Name)
123 }
124
125 promptID := agentPromptMap[agentCfg.ID]
126 if promptID == "" {
127 promptID = prompt.PromptDefault
128 }
129 opts := []provider.ProviderClientOption{
130 provider.WithModel(agentCfg.Model),
131 provider.WithSystemMessage(prompt.GetPrompt(promptID, providerCfg.ID, config.Get().Options.ContextPaths...)),
132 }
133 agentProvider, err := provider.NewProvider(*providerCfg, opts...)
134 if err != nil {
135 return nil, err
136 }
137
138 smallModelCfg := cfg.Models[config.SelectedModelTypeSmall]
139 var smallModelProviderCfg *config.ProviderConfig
140 if smallModelCfg.Provider == providerCfg.ID {
141 smallModelProviderCfg = providerCfg
142 } else {
143 smallModelProviderCfg = cfg.GetProviderForModel(config.SelectedModelTypeSmall)
144
145 if smallModelProviderCfg.ID == "" {
146 return nil, fmt.Errorf("provider %s not found in config", smallModelCfg.Provider)
147 }
148 }
149 smallModel := cfg.GetModelByType(config.SelectedModelTypeSmall)
150 if smallModel.ID == "" {
151 return nil, fmt.Errorf("model %s not found in provider %s", smallModelCfg.Model, smallModelProviderCfg.ID)
152 }
153
154 titleOpts := []provider.ProviderClientOption{
155 provider.WithModel(config.SelectedModelTypeSmall),
156 provider.WithSystemMessage(prompt.GetPrompt(prompt.PromptTitle, smallModelProviderCfg.ID)),
157 }
158 titleProvider, err := provider.NewProvider(*smallModelProviderCfg, titleOpts...)
159 if err != nil {
160 return nil, err
161 }
162 summarizeOpts := []provider.ProviderClientOption{
163 provider.WithModel(config.SelectedModelTypeSmall),
164 provider.WithSystemMessage(prompt.GetPrompt(prompt.PromptSummarizer, smallModelProviderCfg.ID)),
165 }
166 summarizeProvider, err := provider.NewProvider(*smallModelProviderCfg, summarizeOpts...)
167 if err != nil {
168 return nil, err
169 }
170
171 toolFn := func() []tools.BaseTool {
172 slog.Info("Initializing agent tools", "agent", agentCfg.ID)
173 defer func() {
174 slog.Info("Initialized agent tools", "agent", agentCfg.ID)
175 }()
176
177 cwd := cfg.WorkingDir()
178 allTools := []tools.BaseTool{
179 tools.NewBashTool(permissions, cwd),
180 tools.NewDownloadTool(permissions, cwd),
181 tools.NewEditTool(lspClients, permissions, history, cwd),
182 tools.NewFetchTool(permissions, cwd),
183 tools.NewGlobTool(cwd),
184 tools.NewGrepTool(cwd),
185 tools.NewLsTool(cwd),
186 tools.NewSourcegraphTool(),
187 tools.NewViewTool(lspClients, cwd),
188 tools.NewWriteTool(lspClients, permissions, history, cwd),
189 }
190
191 mcpTools := GetMCPTools(ctx, permissions, cfg)
192 allTools = append(allTools, mcpTools...)
193
194 if len(lspClients) > 0 {
195 allTools = append(allTools, tools.NewDiagnosticsTool(lspClients))
196 }
197
198 if agentTool != nil {
199 allTools = append(allTools, agentTool)
200 }
201
202 if agentCfg.AllowedTools == nil {
203 return allTools
204 }
205
206 var filteredTools []tools.BaseTool
207 for _, tool := range allTools {
208 if slices.Contains(agentCfg.AllowedTools, tool.Name()) {
209 filteredTools = append(filteredTools, tool)
210 }
211 }
212 return filteredTools
213 }
214
215 return &agent{
216 Broker: pubsub.NewBroker[AgentEvent](),
217 agentCfg: agentCfg,
218 provider: agentProvider,
219 providerID: string(providerCfg.ID),
220 messages: messages,
221 sessions: sessions,
222 titleProvider: titleProvider,
223 summarizeProvider: summarizeProvider,
224 summarizeProviderID: string(smallModelProviderCfg.ID),
225 activeRequests: sync.Map{},
226 tools: csync.NewLazySlice(toolFn),
227 }, nil
228}
229
230func (a *agent) Model() catwalk.Model {
231 return *config.Get().GetModelByType(a.agentCfg.Model)
232}
233
234func (a *agent) Cancel(sessionID string) {
235 // Cancel regular requests
236 if cancelFunc, exists := a.activeRequests.LoadAndDelete(sessionID); exists {
237 if cancel, ok := cancelFunc.(context.CancelFunc); ok {
238 slog.Info("Request cancellation initiated", "session_id", sessionID)
239 cancel()
240 }
241 }
242
243 // Also check for summarize requests
244 if cancelFunc, exists := a.activeRequests.LoadAndDelete(sessionID + "-summarize"); exists {
245 if cancel, ok := cancelFunc.(context.CancelFunc); ok {
246 slog.Info("Summarize cancellation initiated", "session_id", sessionID)
247 cancel()
248 }
249 }
250}
251
252func (a *agent) IsBusy() bool {
253 busy := false
254 a.activeRequests.Range(func(key, value any) bool {
255 if cancelFunc, ok := value.(context.CancelFunc); ok {
256 if cancelFunc != nil {
257 busy = true
258 return false
259 }
260 }
261 return true
262 })
263 return busy
264}
265
266func (a *agent) IsSessionBusy(sessionID string) bool {
267 _, busy := a.activeRequests.Load(sessionID)
268 return busy
269}
270
271func (a *agent) generateTitle(ctx context.Context, sessionID string, content string) error {
272 if content == "" {
273 return nil
274 }
275 if a.titleProvider == nil {
276 return nil
277 }
278 session, err := a.sessions.Get(ctx, sessionID)
279 if err != nil {
280 return err
281 }
282 parts := []message.ContentPart{message.TextContent{
283 Text: fmt.Sprintf("Generate a concise title for the following content:\n\n%s", content),
284 }}
285
286 // Use streaming approach like summarization
287 response := a.titleProvider.StreamResponse(
288 ctx,
289 []message.Message{
290 {
291 Role: message.User,
292 Parts: parts,
293 },
294 },
295 nil,
296 )
297
298 var finalResponse *provider.ProviderResponse
299 for r := range response {
300 if r.Error != nil {
301 return r.Error
302 }
303 finalResponse = r.Response
304 }
305
306 if finalResponse == nil {
307 return fmt.Errorf("no response received from title provider")
308 }
309
310 title := strings.TrimSpace(strings.ReplaceAll(finalResponse.Content, "\n", " "))
311 if title == "" {
312 return nil
313 }
314
315 session.Title = title
316 _, err = a.sessions.Save(ctx, session)
317 return err
318}
319
320func (a *agent) err(err error) AgentEvent {
321 return AgentEvent{
322 Type: AgentEventTypeError,
323 Error: err,
324 }
325}
326
327func (a *agent) Run(ctx context.Context, sessionID string, content string, attachments ...message.Attachment) (<-chan AgentEvent, error) {
328 if !a.Model().SupportsImages && attachments != nil {
329 attachments = nil
330 }
331 events := make(chan AgentEvent)
332 if a.IsSessionBusy(sessionID) {
333 return nil, ErrSessionBusy
334 }
335
336 genCtx, cancel := context.WithCancel(ctx)
337
338 a.activeRequests.Store(sessionID, cancel)
339 go func() {
340 slog.Debug("Request started", "sessionID", sessionID)
341 defer log.RecoverPanic("agent.Run", func() {
342 events <- a.err(fmt.Errorf("panic while running the agent"))
343 })
344 var attachmentParts []message.ContentPart
345 for _, attachment := range attachments {
346 attachmentParts = append(attachmentParts, message.BinaryContent{Path: attachment.FilePath, MIMEType: attachment.MimeType, Data: attachment.Content})
347 }
348 result := a.processGeneration(genCtx, sessionID, content, attachmentParts)
349 if result.Error != nil && !errors.Is(result.Error, ErrRequestCancelled) && !errors.Is(result.Error, context.Canceled) {
350 slog.Error(result.Error.Error())
351 }
352 slog.Debug("Request completed", "sessionID", sessionID)
353 a.activeRequests.Delete(sessionID)
354 cancel()
355 a.Publish(pubsub.CreatedEvent, result)
356 events <- result
357 close(events)
358 }()
359 return events, nil
360}
361
362func (a *agent) processGeneration(ctx context.Context, sessionID, content string, attachmentParts []message.ContentPart) AgentEvent {
363 cfg := config.Get()
364 // List existing messages; if none, start title generation asynchronously.
365 msgs, err := a.messages.List(ctx, sessionID)
366 if err != nil {
367 return a.err(fmt.Errorf("failed to list messages: %w", err))
368 }
369 if len(msgs) == 0 {
370 go func() {
371 defer log.RecoverPanic("agent.Run", func() {
372 slog.Error("panic while generating title")
373 })
374 titleErr := a.generateTitle(context.Background(), sessionID, content)
375 if titleErr != nil && !errors.Is(titleErr, context.Canceled) && !errors.Is(titleErr, context.DeadlineExceeded) {
376 slog.Error("failed to generate title", "error", titleErr)
377 }
378 }()
379 }
380 session, err := a.sessions.Get(ctx, sessionID)
381 if err != nil {
382 return a.err(fmt.Errorf("failed to get session: %w", err))
383 }
384 if session.SummaryMessageID != "" {
385 summaryMsgInex := -1
386 for i, msg := range msgs {
387 if msg.ID == session.SummaryMessageID {
388 summaryMsgInex = i
389 break
390 }
391 }
392 if summaryMsgInex != -1 {
393 msgs = msgs[summaryMsgInex:]
394 msgs[0].Role = message.User
395 }
396 }
397
398 userMsg, err := a.createUserMessage(ctx, sessionID, content, attachmentParts)
399 if err != nil {
400 return a.err(fmt.Errorf("failed to create user message: %w", err))
401 }
402 // Append the new user message to the conversation history.
403 msgHistory := append(msgs, userMsg)
404
405 for {
406 // Check for cancellation before each iteration
407 select {
408 case <-ctx.Done():
409 return a.err(ctx.Err())
410 default:
411 // Continue processing
412 }
413 agentMessage, toolResults, err := a.streamAndHandleEvents(ctx, sessionID, msgHistory)
414 if err != nil {
415 if errors.Is(err, context.Canceled) {
416 agentMessage.AddFinish(message.FinishReasonCanceled, "Request cancelled", "")
417 a.messages.Update(context.Background(), agentMessage)
418 return a.err(ErrRequestCancelled)
419 }
420 return a.err(fmt.Errorf("failed to process events: %w", err))
421 }
422 if cfg.Options.Debug {
423 slog.Info("Result", "message", agentMessage.FinishReason(), "toolResults", toolResults)
424 }
425 if (agentMessage.FinishReason() == message.FinishReasonToolUse) && toolResults != nil {
426 // We are not done, we need to respond with the tool response
427 msgHistory = append(msgHistory, agentMessage, *toolResults)
428 continue
429 }
430 return AgentEvent{
431 Type: AgentEventTypeResponse,
432 Message: agentMessage,
433 Done: true,
434 }
435 }
436}
437
438func (a *agent) createUserMessage(ctx context.Context, sessionID, content string, attachmentParts []message.ContentPart) (message.Message, error) {
439 parts := []message.ContentPart{message.TextContent{Text: content}}
440 parts = append(parts, attachmentParts...)
441 return a.messages.Create(ctx, sessionID, message.CreateMessageParams{
442 Role: message.User,
443 Parts: parts,
444 })
445}
446
447func (a *agent) streamAndHandleEvents(ctx context.Context, sessionID string, msgHistory []message.Message) (message.Message, *message.Message, error) {
448 ctx = context.WithValue(ctx, tools.SessionIDContextKey, sessionID)
449 eventChan := a.provider.StreamResponse(ctx, msgHistory, slices.Collect(a.tools.Seq()))
450
451 assistantMsg, err := a.messages.Create(ctx, sessionID, message.CreateMessageParams{
452 Role: message.Assistant,
453 Parts: []message.ContentPart{},
454 Model: a.Model().ID,
455 Provider: a.providerID,
456 })
457 if err != nil {
458 return assistantMsg, nil, fmt.Errorf("failed to create assistant message: %w", err)
459 }
460
461 // Add the session and message ID into the context if needed by tools.
462 ctx = context.WithValue(ctx, tools.MessageIDContextKey, assistantMsg.ID)
463
464 // Process each event in the stream.
465 for event := range eventChan {
466 if processErr := a.processEvent(ctx, sessionID, &assistantMsg, event); processErr != nil {
467 if errors.Is(processErr, context.Canceled) {
468 a.finishMessage(context.Background(), &assistantMsg, message.FinishReasonCanceled, "Request cancelled", "")
469 } else {
470 a.finishMessage(ctx, &assistantMsg, message.FinishReasonError, "API Error", processErr.Error())
471 }
472 return assistantMsg, nil, processErr
473 }
474 if ctx.Err() != nil {
475 a.finishMessage(context.Background(), &assistantMsg, message.FinishReasonCanceled, "Request cancelled", "")
476 return assistantMsg, nil, ctx.Err()
477 }
478 }
479
480 toolResults := make([]message.ToolResult, len(assistantMsg.ToolCalls()))
481 toolCalls := assistantMsg.ToolCalls()
482 for i, toolCall := range toolCalls {
483 select {
484 case <-ctx.Done():
485 a.finishMessage(context.Background(), &assistantMsg, message.FinishReasonCanceled, "Request cancelled", "")
486 // Make all future tool calls cancelled
487 for j := i; j < len(toolCalls); j++ {
488 toolResults[j] = message.ToolResult{
489 ToolCallID: toolCalls[j].ID,
490 Content: "Tool execution canceled by user",
491 IsError: true,
492 }
493 }
494 goto out
495 default:
496 // Continue processing
497 var tool tools.BaseTool
498 for availableTool := range a.tools.Seq() {
499 if availableTool.Info().Name == toolCall.Name {
500 tool = availableTool
501 break
502 }
503 }
504
505 // Tool not found
506 if tool == nil {
507 toolResults[i] = message.ToolResult{
508 ToolCallID: toolCall.ID,
509 Content: fmt.Sprintf("Tool not found: %s", toolCall.Name),
510 IsError: true,
511 }
512 continue
513 }
514
515 // Run tool in goroutine to allow cancellation
516 type toolExecResult struct {
517 response tools.ToolResponse
518 err error
519 }
520 resultChan := make(chan toolExecResult, 1)
521
522 go func() {
523 response, err := tool.Run(ctx, tools.ToolCall{
524 ID: toolCall.ID,
525 Name: toolCall.Name,
526 Input: toolCall.Input,
527 })
528 resultChan <- toolExecResult{response: response, err: err}
529 }()
530
531 var toolResponse tools.ToolResponse
532 var toolErr error
533
534 select {
535 case <-ctx.Done():
536 a.finishMessage(context.Background(), &assistantMsg, message.FinishReasonCanceled, "Request cancelled", "")
537 // Mark remaining tool calls as cancelled
538 for j := i; j < len(toolCalls); j++ {
539 toolResults[j] = message.ToolResult{
540 ToolCallID: toolCalls[j].ID,
541 Content: "Tool execution canceled by user",
542 IsError: true,
543 }
544 }
545 goto out
546 case result := <-resultChan:
547 toolResponse = result.response
548 toolErr = result.err
549 }
550
551 if toolErr != nil {
552 slog.Error("Tool execution error", "toolCall", toolCall.ID, "error", toolErr)
553 if errors.Is(toolErr, permission.ErrorPermissionDenied) {
554 toolResults[i] = message.ToolResult{
555 ToolCallID: toolCall.ID,
556 Content: "Permission denied",
557 IsError: true,
558 }
559 for j := i + 1; j < len(toolCalls); j++ {
560 toolResults[j] = message.ToolResult{
561 ToolCallID: toolCalls[j].ID,
562 Content: "Tool execution canceled by user",
563 IsError: true,
564 }
565 }
566 a.finishMessage(ctx, &assistantMsg, message.FinishReasonPermissionDenied, "Permission denied", "")
567 break
568 }
569 }
570 toolResults[i] = message.ToolResult{
571 ToolCallID: toolCall.ID,
572 Content: toolResponse.Content,
573 Metadata: toolResponse.Metadata,
574 IsError: toolResponse.IsError,
575 }
576 }
577 }
578out:
579 if len(toolResults) == 0 {
580 return assistantMsg, nil, nil
581 }
582 parts := make([]message.ContentPart, 0)
583 for _, tr := range toolResults {
584 parts = append(parts, tr)
585 }
586 msg, err := a.messages.Create(context.Background(), assistantMsg.SessionID, message.CreateMessageParams{
587 Role: message.Tool,
588 Parts: parts,
589 Provider: a.providerID,
590 })
591 if err != nil {
592 return assistantMsg, nil, fmt.Errorf("failed to create cancelled tool message: %w", err)
593 }
594
595 return assistantMsg, &msg, err
596}
597
598func (a *agent) finishMessage(ctx context.Context, msg *message.Message, finishReason message.FinishReason, message, details string) {
599 msg.AddFinish(finishReason, message, details)
600 _ = a.messages.Update(ctx, *msg)
601}
602
603func (a *agent) processEvent(ctx context.Context, sessionID string, assistantMsg *message.Message, event provider.ProviderEvent) error {
604 select {
605 case <-ctx.Done():
606 return ctx.Err()
607 default:
608 // Continue processing.
609 }
610
611 switch event.Type {
612 case provider.EventThinkingDelta:
613 assistantMsg.AppendReasoningContent(event.Thinking)
614 return a.messages.Update(ctx, *assistantMsg)
615 case provider.EventSignatureDelta:
616 assistantMsg.AppendReasoningSignature(event.Signature)
617 return a.messages.Update(ctx, *assistantMsg)
618 case provider.EventContentDelta:
619 assistantMsg.FinishThinking()
620 assistantMsg.AppendContent(event.Content)
621 return a.messages.Update(ctx, *assistantMsg)
622 case provider.EventToolUseStart:
623 assistantMsg.FinishThinking()
624 slog.Info("Tool call started", "toolCall", event.ToolCall)
625 assistantMsg.AddToolCall(*event.ToolCall)
626 return a.messages.Update(ctx, *assistantMsg)
627 case provider.EventToolUseDelta:
628 assistantMsg.AppendToolCallInput(event.ToolCall.ID, event.ToolCall.Input)
629 return a.messages.Update(ctx, *assistantMsg)
630 case provider.EventToolUseStop:
631 slog.Info("Finished tool call", "toolCall", event.ToolCall)
632 assistantMsg.FinishToolCall(event.ToolCall.ID)
633 return a.messages.Update(ctx, *assistantMsg)
634 case provider.EventError:
635 return event.Error
636 case provider.EventComplete:
637 assistantMsg.FinishThinking()
638 assistantMsg.SetToolCalls(event.Response.ToolCalls)
639 assistantMsg.AddFinish(event.Response.FinishReason, "", "")
640 if err := a.messages.Update(ctx, *assistantMsg); err != nil {
641 return fmt.Errorf("failed to update message: %w", err)
642 }
643 return a.TrackUsage(ctx, sessionID, a.Model(), event.Response.Usage)
644 }
645
646 return nil
647}
648
649func (a *agent) TrackUsage(ctx context.Context, sessionID string, model catwalk.Model, usage provider.TokenUsage) error {
650 sess, err := a.sessions.Get(ctx, sessionID)
651 if err != nil {
652 return fmt.Errorf("failed to get session: %w", err)
653 }
654
655 cost := model.CostPer1MInCached/1e6*float64(usage.CacheCreationTokens) +
656 model.CostPer1MOutCached/1e6*float64(usage.CacheReadTokens) +
657 model.CostPer1MIn/1e6*float64(usage.InputTokens) +
658 model.CostPer1MOut/1e6*float64(usage.OutputTokens)
659
660 sess.Cost += cost
661 sess.CompletionTokens = usage.OutputTokens + usage.CacheReadTokens
662 sess.PromptTokens = usage.InputTokens + usage.CacheCreationTokens
663
664 _, err = a.sessions.Save(ctx, sess)
665 if err != nil {
666 return fmt.Errorf("failed to save session: %w", err)
667 }
668 return nil
669}
670
671func (a *agent) Summarize(ctx context.Context, sessionID string) error {
672 if a.summarizeProvider == nil {
673 return fmt.Errorf("summarize provider not available")
674 }
675
676 // Check if session is busy
677 if a.IsSessionBusy(sessionID) {
678 return ErrSessionBusy
679 }
680
681 // Create a new context with cancellation
682 summarizeCtx, cancel := context.WithCancel(ctx)
683
684 // Store the cancel function in activeRequests to allow cancellation
685 a.activeRequests.Store(sessionID+"-summarize", cancel)
686
687 go func() {
688 defer a.activeRequests.Delete(sessionID + "-summarize")
689 defer cancel()
690 event := AgentEvent{
691 Type: AgentEventTypeSummarize,
692 Progress: "Starting summarization...",
693 }
694
695 a.Publish(pubsub.CreatedEvent, event)
696 // Get all messages from the session
697 msgs, err := a.messages.List(summarizeCtx, sessionID)
698 if err != nil {
699 event = AgentEvent{
700 Type: AgentEventTypeError,
701 Error: fmt.Errorf("failed to list messages: %w", err),
702 Done: true,
703 }
704 a.Publish(pubsub.CreatedEvent, event)
705 return
706 }
707 summarizeCtx = context.WithValue(summarizeCtx, tools.SessionIDContextKey, sessionID)
708
709 if len(msgs) == 0 {
710 event = AgentEvent{
711 Type: AgentEventTypeError,
712 Error: fmt.Errorf("no messages to summarize"),
713 Done: true,
714 }
715 a.Publish(pubsub.CreatedEvent, event)
716 return
717 }
718
719 event = AgentEvent{
720 Type: AgentEventTypeSummarize,
721 Progress: "Analyzing conversation...",
722 }
723 a.Publish(pubsub.CreatedEvent, event)
724
725 // Add a system message to guide the summarization
726 summarizePrompt := "Provide a detailed but concise summary of our conversation above. Focus on information that would be helpful for continuing the conversation, including what we did, what we're doing, which files we're working on, and what we're going to do next."
727
728 // Create a new message with the summarize prompt
729 promptMsg := message.Message{
730 Role: message.User,
731 Parts: []message.ContentPart{message.TextContent{Text: summarizePrompt}},
732 }
733
734 // Append the prompt to the messages
735 msgsWithPrompt := append(msgs, promptMsg)
736
737 event = AgentEvent{
738 Type: AgentEventTypeSummarize,
739 Progress: "Generating summary...",
740 }
741
742 a.Publish(pubsub.CreatedEvent, event)
743
744 // Send the messages to the summarize provider
745 response := a.summarizeProvider.StreamResponse(
746 summarizeCtx,
747 msgsWithPrompt,
748 nil,
749 )
750 var finalResponse *provider.ProviderResponse
751 for r := range response {
752 if r.Error != nil {
753 event = AgentEvent{
754 Type: AgentEventTypeError,
755 Error: fmt.Errorf("failed to summarize: %w", err),
756 Done: true,
757 }
758 a.Publish(pubsub.CreatedEvent, event)
759 return
760 }
761 finalResponse = r.Response
762 }
763
764 summary := strings.TrimSpace(finalResponse.Content)
765 if summary == "" {
766 event = AgentEvent{
767 Type: AgentEventTypeError,
768 Error: fmt.Errorf("empty summary returned"),
769 Done: true,
770 }
771 a.Publish(pubsub.CreatedEvent, event)
772 return
773 }
774 shell := shell.GetPersistentShell(config.Get().WorkingDir())
775 summary += "\n\n**Current working directory of the persistent shell**\n\n" + shell.GetWorkingDir()
776 event = AgentEvent{
777 Type: AgentEventTypeSummarize,
778 Progress: "Creating new session...",
779 }
780
781 a.Publish(pubsub.CreatedEvent, event)
782 oldSession, err := a.sessions.Get(summarizeCtx, sessionID)
783 if err != nil {
784 event = AgentEvent{
785 Type: AgentEventTypeError,
786 Error: fmt.Errorf("failed to get session: %w", err),
787 Done: true,
788 }
789
790 a.Publish(pubsub.CreatedEvent, event)
791 return
792 }
793 // Create a message in the new session with the summary
794 msg, err := a.messages.Create(summarizeCtx, oldSession.ID, message.CreateMessageParams{
795 Role: message.Assistant,
796 Parts: []message.ContentPart{
797 message.TextContent{Text: summary},
798 message.Finish{
799 Reason: message.FinishReasonEndTurn,
800 Time: time.Now().Unix(),
801 },
802 },
803 Model: a.summarizeProvider.Model().ID,
804 Provider: a.summarizeProviderID,
805 })
806 if err != nil {
807 event = AgentEvent{
808 Type: AgentEventTypeError,
809 Error: fmt.Errorf("failed to create summary message: %w", err),
810 Done: true,
811 }
812
813 a.Publish(pubsub.CreatedEvent, event)
814 return
815 }
816 oldSession.SummaryMessageID = msg.ID
817 oldSession.CompletionTokens = finalResponse.Usage.OutputTokens
818 oldSession.PromptTokens = 0
819 model := a.summarizeProvider.Model()
820 usage := finalResponse.Usage
821 cost := model.CostPer1MInCached/1e6*float64(usage.CacheCreationTokens) +
822 model.CostPer1MOutCached/1e6*float64(usage.CacheReadTokens) +
823 model.CostPer1MIn/1e6*float64(usage.InputTokens) +
824 model.CostPer1MOut/1e6*float64(usage.OutputTokens)
825 oldSession.Cost += cost
826 _, err = a.sessions.Save(summarizeCtx, oldSession)
827 if err != nil {
828 event = AgentEvent{
829 Type: AgentEventTypeError,
830 Error: fmt.Errorf("failed to save session: %w", err),
831 Done: true,
832 }
833 a.Publish(pubsub.CreatedEvent, event)
834 }
835
836 event = AgentEvent{
837 Type: AgentEventTypeSummarize,
838 SessionID: oldSession.ID,
839 Progress: "Summary complete",
840 Done: true,
841 }
842 a.Publish(pubsub.CreatedEvent, event)
843 // Send final success event with the new session ID
844 }()
845
846 return nil
847}
848
849func (a *agent) CancelAll() {
850 if !a.IsBusy() {
851 return
852 }
853 a.activeRequests.Range(func(key, value any) bool {
854 a.Cancel(key.(string)) // key is sessionID
855 return true
856 })
857
858 timeout := time.After(5 * time.Second)
859 for a.IsBusy() {
860 select {
861 case <-timeout:
862 return
863 default:
864 time.Sleep(200 * time.Millisecond)
865 }
866 }
867}
868
869func (a *agent) UpdateModel() error {
870 cfg := config.Get()
871
872 // Get current provider configuration
873 currentProviderCfg := cfg.GetProviderForModel(a.agentCfg.Model)
874 if currentProviderCfg == nil || currentProviderCfg.ID == "" {
875 return fmt.Errorf("provider for agent %s not found in config", a.agentCfg.Name)
876 }
877
878 // Check if provider has changed
879 if string(currentProviderCfg.ID) != a.providerID {
880 // Provider changed, need to recreate the main provider
881 model := cfg.GetModelByType(a.agentCfg.Model)
882 if model.ID == "" {
883 return fmt.Errorf("model not found for agent %s", a.agentCfg.Name)
884 }
885
886 promptID := agentPromptMap[a.agentCfg.ID]
887 if promptID == "" {
888 promptID = prompt.PromptDefault
889 }
890
891 opts := []provider.ProviderClientOption{
892 provider.WithModel(a.agentCfg.Model),
893 provider.WithSystemMessage(prompt.GetPrompt(promptID, currentProviderCfg.ID, cfg.Options.ContextPaths...)),
894 }
895
896 newProvider, err := provider.NewProvider(*currentProviderCfg, opts...)
897 if err != nil {
898 return fmt.Errorf("failed to create new provider: %w", err)
899 }
900
901 // Update the provider and provider ID
902 a.provider = newProvider
903 a.providerID = string(currentProviderCfg.ID)
904 }
905
906 // Check if small model provider has changed (affects title and summarize providers)
907 smallModelCfg := cfg.Models[config.SelectedModelTypeSmall]
908 var smallModelProviderCfg config.ProviderConfig
909
910 for _, p := range cfg.Providers.Seq2() {
911 if p.ID == smallModelCfg.Provider {
912 smallModelProviderCfg = p
913 break
914 }
915 }
916
917 if smallModelProviderCfg.ID == "" {
918 return fmt.Errorf("provider %s not found in config", smallModelCfg.Provider)
919 }
920
921 // Check if summarize provider has changed
922 if string(smallModelProviderCfg.ID) != a.summarizeProviderID {
923 smallModel := cfg.GetModelByType(config.SelectedModelTypeSmall)
924 if smallModel == nil {
925 return fmt.Errorf("model %s not found in provider %s", smallModelCfg.Model, smallModelProviderCfg.ID)
926 }
927
928 // Recreate title provider
929 titleOpts := []provider.ProviderClientOption{
930 provider.WithModel(config.SelectedModelTypeSmall),
931 provider.WithSystemMessage(prompt.GetPrompt(prompt.PromptTitle, smallModelProviderCfg.ID)),
932 // We want the title to be short, so we limit the max tokens
933 provider.WithMaxTokens(40),
934 }
935 newTitleProvider, err := provider.NewProvider(smallModelProviderCfg, titleOpts...)
936 if err != nil {
937 return fmt.Errorf("failed to create new title provider: %w", err)
938 }
939
940 // Recreate summarize provider
941 summarizeOpts := []provider.ProviderClientOption{
942 provider.WithModel(config.SelectedModelTypeSmall),
943 provider.WithSystemMessage(prompt.GetPrompt(prompt.PromptSummarizer, smallModelProviderCfg.ID)),
944 }
945 newSummarizeProvider, err := provider.NewProvider(smallModelProviderCfg, summarizeOpts...)
946 if err != nil {
947 return fmt.Errorf("failed to create new summarize provider: %w", err)
948 }
949
950 // Update the providers and provider ID
951 a.titleProvider = newTitleProvider
952 a.summarizeProvider = newSummarizeProvider
953 a.summarizeProviderID = string(smallModelProviderCfg.ID)
954 }
955
956 return nil
957}