1package agent
2
3import (
4 "context"
5 "errors"
6 "fmt"
7 "log/slog"
8 "slices"
9 "strings"
10 "time"
11
12 "github.com/charmbracelet/catwalk/pkg/catwalk"
13 "github.com/charmbracelet/crush/internal/config"
14 "github.com/charmbracelet/crush/internal/csync"
15 "github.com/charmbracelet/crush/internal/history"
16 "github.com/charmbracelet/crush/internal/llm/prompt"
17 "github.com/charmbracelet/crush/internal/llm/provider"
18 "github.com/charmbracelet/crush/internal/llm/tools"
19 "github.com/charmbracelet/crush/internal/log"
20 "github.com/charmbracelet/crush/internal/lsp"
21 "github.com/charmbracelet/crush/internal/message"
22 "github.com/charmbracelet/crush/internal/permission"
23 "github.com/charmbracelet/crush/internal/pubsub"
24 "github.com/charmbracelet/crush/internal/session"
25 "github.com/charmbracelet/crush/internal/shell"
26)
27
// Common errors returned by the agent.
var (
	// ErrRequestCancelled is reported (inside an error AgentEvent) when a
	// generation request is cancelled while streaming.
	ErrRequestCancelled = errors.New("request canceled by user")
	// ErrSessionBusy is returned by Run and Summarize when the session already
	// has a request in flight.
	ErrSessionBusy = errors.New("session is currently processing another request")
)
33
// AgentEventType identifies the kind of an AgentEvent.
type AgentEventType string

// Event types emitted by the agent:
//   - AgentEventTypeError: the event's Error field holds the failure.
//   - AgentEventTypeResponse: the event's Message holds the final assistant message of a Run.
//   - AgentEventTypeSummarize: summarization progress (Progress) or completion (Done).
const (
	AgentEventTypeError     AgentEventType = "error"
	AgentEventTypeResponse  AgentEventType = "response"
	AgentEventTypeSummarize AgentEventType = "summarize"
)
41
// AgentEvent is produced by the agent: Run delivers it on the returned channel
// and both Run and Summarize publish it through the embedded pubsub broker.
type AgentEvent struct {
	Type    AgentEventType
	Message message.Message
	Error   error

	// When summarizing: SessionID and Progress are only filled in by
	// Summarize. Done marks a final event; it is also set on the response
	// event returned by a completed Run.
	SessionID string
	Progress  string
	Done      bool
}
52
// Service is the public interface of an agent. Events are delivered through
// the embedded subscriber.
type Service interface {
	pubsub.Suscriber[AgentEvent]
	// Model returns the model currently configured for the agent.
	Model() catwalk.Model
	// Run processes content for a session in the background; the returned
	// channel yields the outcome of the request.
	Run(ctx context.Context, sessionID string, content string, attachments ...message.Attachment) (<-chan AgentEvent, error)
	// Cancel aborts the active request and any summarization for a session.
	Cancel(sessionID string)
	// CancelAll cancels every active request and waits briefly for them to stop.
	CancelAll()
	// IsSessionBusy reports whether a request is active for the session.
	IsSessionBusy(sessionID string) bool
	// IsBusy reports whether any request is active.
	IsBusy() bool
	// Summarize starts summarizing a session's conversation in the background.
	Summarize(ctx context.Context, sessionID string) error
	// UpdateModel rebuilds provider clients after the configured providers change.
	UpdateModel() error
}
64
// agent is the Service implementation returned by NewAgent.
type agent struct {
	*pubsub.Broker[AgentEvent]
	agentCfg config.Agent
	sessions session.Service
	messages message.Service

	// tools is built from the toolFn defined in NewAgent via
	// csync.NewLazySlice (presumably on first access — confirm in csync).
	tools *csync.LazySlice[tools.BaseTool]

	// provider is the main chat provider; providerID is stored on created
	// messages and compared by UpdateModel to detect provider changes.
	provider   provider.Provider
	providerID string

	// titleProvider and summarizeProvider are built for the small model.
	titleProvider       provider.Provider
	summarizeProvider   provider.Provider
	summarizeProviderID string

	// activeRequests maps a session ID (or sessionID + "-summarize") to the
	// cancel function of the request running for it.
	activeRequests *csync.Map[string, context.CancelFunc]
}
82
// agentPromptMap selects the system prompt for an agent by its config ID.
// IDs not listed here fall back to prompt.PromptDefault (see NewAgent and
// UpdateModel).
var agentPromptMap = map[string]prompt.PromptID{
	"coder": prompt.PromptCoder,
	"task":  prompt.PromptTask,
}
87
88func NewAgent(
89 agentCfg config.Agent,
90 // These services are needed in the tools
91 permissions permission.Service,
92 sessions session.Service,
93 messages message.Service,
94 history history.Service,
95 lspClients map[string]*lsp.Client,
96) (Service, error) {
97 ctx := context.Background()
98 cfg := config.Get()
99
100 var agentTool tools.BaseTool
101 if agentCfg.ID == "coder" {
102 taskAgentCfg := config.Get().Agents["task"]
103 if taskAgentCfg.ID == "" {
104 return nil, fmt.Errorf("task agent not found in config")
105 }
106 taskAgent, err := NewAgent(taskAgentCfg, permissions, sessions, messages, history, lspClients)
107 if err != nil {
108 return nil, fmt.Errorf("failed to create task agent: %w", err)
109 }
110
111 agentTool = NewAgentTool(taskAgent, sessions, messages)
112 }
113
114 providerCfg := config.Get().GetProviderForModel(agentCfg.Model)
115 if providerCfg == nil {
116 return nil, fmt.Errorf("provider for agent %s not found in config", agentCfg.Name)
117 }
118 model := config.Get().GetModelByType(agentCfg.Model)
119
120 if model == nil {
121 return nil, fmt.Errorf("model not found for agent %s", agentCfg.Name)
122 }
123
124 promptID := agentPromptMap[agentCfg.ID]
125 if promptID == "" {
126 promptID = prompt.PromptDefault
127 }
128 opts := []provider.ProviderClientOption{
129 provider.WithModel(agentCfg.Model),
130 provider.WithSystemMessage(prompt.GetPrompt(promptID, providerCfg.ID, config.Get().Options.ContextPaths...)),
131 }
132 agentProvider, err := provider.NewProvider(*providerCfg, opts...)
133 if err != nil {
134 return nil, err
135 }
136
137 smallModelCfg := cfg.Models[config.SelectedModelTypeSmall]
138 var smallModelProviderCfg *config.ProviderConfig
139 if smallModelCfg.Provider == providerCfg.ID {
140 smallModelProviderCfg = providerCfg
141 } else {
142 smallModelProviderCfg = cfg.GetProviderForModel(config.SelectedModelTypeSmall)
143
144 if smallModelProviderCfg.ID == "" {
145 return nil, fmt.Errorf("provider %s not found in config", smallModelCfg.Provider)
146 }
147 }
148 smallModel := cfg.GetModelByType(config.SelectedModelTypeSmall)
149 if smallModel.ID == "" {
150 return nil, fmt.Errorf("model %s not found in provider %s", smallModelCfg.Model, smallModelProviderCfg.ID)
151 }
152
153 titleOpts := []provider.ProviderClientOption{
154 provider.WithModel(config.SelectedModelTypeSmall),
155 provider.WithSystemMessage(prompt.GetPrompt(prompt.PromptTitle, smallModelProviderCfg.ID)),
156 }
157 titleProvider, err := provider.NewProvider(*smallModelProviderCfg, titleOpts...)
158 if err != nil {
159 return nil, err
160 }
161 summarizeOpts := []provider.ProviderClientOption{
162 provider.WithModel(config.SelectedModelTypeSmall),
163 provider.WithSystemMessage(prompt.GetPrompt(prompt.PromptSummarizer, smallModelProviderCfg.ID)),
164 }
165 summarizeProvider, err := provider.NewProvider(*smallModelProviderCfg, summarizeOpts...)
166 if err != nil {
167 return nil, err
168 }
169
170 toolFn := func() []tools.BaseTool {
171 slog.Info("Initializing agent tools", "agent", agentCfg.ID)
172 defer func() {
173 slog.Info("Initialized agent tools", "agent", agentCfg.ID)
174 }()
175
176 cwd := cfg.WorkingDir()
177 allTools := []tools.BaseTool{
178 tools.NewBashTool(permissions, cwd),
179 tools.NewDownloadTool(permissions, cwd),
180 tools.NewEditTool(lspClients, permissions, history, cwd),
181 tools.NewMultiEditTool(lspClients, permissions, history, cwd),
182 tools.NewFetchTool(permissions, cwd),
183 tools.NewGlobTool(cwd),
184 tools.NewGrepTool(cwd),
185 tools.NewLsTool(cwd),
186 tools.NewSourcegraphTool(),
187 tools.NewViewTool(lspClients, cwd),
188 tools.NewWriteTool(lspClients, permissions, history, cwd),
189 }
190
191 mcpTools := GetMCPTools(ctx, permissions, cfg)
192 allTools = append(allTools, mcpTools...)
193
194 if len(lspClients) > 0 {
195 allTools = append(allTools, tools.NewDiagnosticsTool(lspClients))
196 }
197
198 if agentTool != nil {
199 allTools = append(allTools, agentTool)
200 }
201
202 if agentCfg.AllowedTools == nil {
203 return allTools
204 }
205
206 var filteredTools []tools.BaseTool
207 for _, tool := range allTools {
208 if slices.Contains(agentCfg.AllowedTools, tool.Name()) {
209 filteredTools = append(filteredTools, tool)
210 }
211 }
212 return filteredTools
213 }
214
215 return &agent{
216 Broker: pubsub.NewBroker[AgentEvent](),
217 agentCfg: agentCfg,
218 provider: agentProvider,
219 providerID: string(providerCfg.ID),
220 messages: messages,
221 sessions: sessions,
222 titleProvider: titleProvider,
223 summarizeProvider: summarizeProvider,
224 summarizeProviderID: string(smallModelProviderCfg.ID),
225 activeRequests: csync.NewMap[string, context.CancelFunc](),
226 tools: csync.NewLazySlice(toolFn),
227 }, nil
228}
229
230func (a *agent) Model() catwalk.Model {
231 return *config.Get().GetModelByType(a.agentCfg.Model)
232}
233
234func (a *agent) Cancel(sessionID string) {
235 // Cancel regular requests
236 if cancel, ok := a.activeRequests.Take(sessionID); ok && cancel != nil {
237 slog.Info("Request cancellation initiated", "session_id", sessionID)
238 cancel()
239 }
240
241 // Also check for summarize requests
242 if cancel, ok := a.activeRequests.Take(sessionID + "-summarize"); ok && cancel != nil {
243 slog.Info("Summarize cancellation initiated", "session_id", sessionID)
244 cancel()
245 }
246}
247
248func (a *agent) IsBusy() bool {
249 var busy bool
250 for cancelFunc := range a.activeRequests.Seq() {
251 if cancelFunc != nil {
252 busy = true
253 break
254 }
255 }
256 return busy
257}
258
259func (a *agent) IsSessionBusy(sessionID string) bool {
260 _, busy := a.activeRequests.Get(sessionID)
261 return busy
262}
263
264func (a *agent) generateTitle(ctx context.Context, sessionID string, content string) error {
265 if content == "" {
266 return nil
267 }
268 if a.titleProvider == nil {
269 return nil
270 }
271 session, err := a.sessions.Get(ctx, sessionID)
272 if err != nil {
273 return err
274 }
275 parts := []message.ContentPart{message.TextContent{
276 Text: fmt.Sprintf("Generate a concise title for the following content:\n\n%s", content),
277 }}
278
279 // Use streaming approach like summarization
280 response := a.titleProvider.StreamResponse(
281 ctx,
282 []message.Message{
283 {
284 Role: message.User,
285 Parts: parts,
286 },
287 },
288 nil,
289 )
290
291 var finalResponse *provider.ProviderResponse
292 for r := range response {
293 if r.Error != nil {
294 return r.Error
295 }
296 finalResponse = r.Response
297 }
298
299 if finalResponse == nil {
300 return fmt.Errorf("no response received from title provider")
301 }
302
303 title := strings.TrimSpace(strings.ReplaceAll(finalResponse.Content, "\n", " "))
304 if title == "" {
305 return nil
306 }
307
308 session.Title = title
309 _, err = a.sessions.Save(ctx, session)
310 return err
311}
312
313func (a *agent) err(err error) AgentEvent {
314 return AgentEvent{
315 Type: AgentEventTypeError,
316 Error: err,
317 }
318}
319
320func (a *agent) Run(ctx context.Context, sessionID string, content string, attachments ...message.Attachment) (<-chan AgentEvent, error) {
321 if !a.Model().SupportsImages && attachments != nil {
322 attachments = nil
323 }
324 events := make(chan AgentEvent)
325 if a.IsSessionBusy(sessionID) {
326 return nil, ErrSessionBusy
327 }
328
329 genCtx, cancel := context.WithCancel(ctx)
330
331 a.activeRequests.Set(sessionID, cancel)
332 go func() {
333 slog.Debug("Request started", "sessionID", sessionID)
334 defer log.RecoverPanic("agent.Run", func() {
335 events <- a.err(fmt.Errorf("panic while running the agent"))
336 })
337 var attachmentParts []message.ContentPart
338 for _, attachment := range attachments {
339 attachmentParts = append(attachmentParts, message.BinaryContent{Path: attachment.FilePath, MIMEType: attachment.MimeType, Data: attachment.Content})
340 }
341 result := a.processGeneration(genCtx, sessionID, content, attachmentParts)
342 if result.Error != nil && !errors.Is(result.Error, ErrRequestCancelled) && !errors.Is(result.Error, context.Canceled) {
343 slog.Error(result.Error.Error())
344 }
345 slog.Debug("Request completed", "sessionID", sessionID)
346 a.activeRequests.Del(sessionID)
347 cancel()
348 a.Publish(pubsub.CreatedEvent, result)
349 events <- result
350 close(events)
351 }()
352 return events, nil
353}
354
355func (a *agent) processGeneration(ctx context.Context, sessionID, content string, attachmentParts []message.ContentPart) AgentEvent {
356 cfg := config.Get()
357 // List existing messages; if none, start title generation asynchronously.
358 msgs, err := a.messages.List(ctx, sessionID)
359 if err != nil {
360 return a.err(fmt.Errorf("failed to list messages: %w", err))
361 }
362 if len(msgs) == 0 {
363 go func() {
364 defer log.RecoverPanic("agent.Run", func() {
365 slog.Error("panic while generating title")
366 })
367 titleErr := a.generateTitle(context.Background(), sessionID, content)
368 if titleErr != nil && !errors.Is(titleErr, context.Canceled) && !errors.Is(titleErr, context.DeadlineExceeded) {
369 slog.Error("failed to generate title", "error", titleErr)
370 }
371 }()
372 }
373 session, err := a.sessions.Get(ctx, sessionID)
374 if err != nil {
375 return a.err(fmt.Errorf("failed to get session: %w", err))
376 }
377 if session.SummaryMessageID != "" {
378 summaryMsgInex := -1
379 for i, msg := range msgs {
380 if msg.ID == session.SummaryMessageID {
381 summaryMsgInex = i
382 break
383 }
384 }
385 if summaryMsgInex != -1 {
386 msgs = msgs[summaryMsgInex:]
387 msgs[0].Role = message.User
388 }
389 }
390
391 userMsg, err := a.createUserMessage(ctx, sessionID, content, attachmentParts)
392 if err != nil {
393 return a.err(fmt.Errorf("failed to create user message: %w", err))
394 }
395 // Append the new user message to the conversation history.
396 msgHistory := append(msgs, userMsg)
397
398 for {
399 // Check for cancellation before each iteration
400 select {
401 case <-ctx.Done():
402 return a.err(ctx.Err())
403 default:
404 // Continue processing
405 }
406 agentMessage, toolResults, err := a.streamAndHandleEvents(ctx, sessionID, msgHistory)
407 if err != nil {
408 if errors.Is(err, context.Canceled) {
409 agentMessage.AddFinish(message.FinishReasonCanceled, "Request cancelled", "")
410 a.messages.Update(context.Background(), agentMessage)
411 return a.err(ErrRequestCancelled)
412 }
413 return a.err(fmt.Errorf("failed to process events: %w", err))
414 }
415 if cfg.Options.Debug {
416 slog.Info("Result", "message", agentMessage.FinishReason(), "toolResults", toolResults)
417 }
418 if (agentMessage.FinishReason() == message.FinishReasonToolUse) && toolResults != nil {
419 // We are not done, we need to respond with the tool response
420 msgHistory = append(msgHistory, agentMessage, *toolResults)
421 continue
422 }
423 return AgentEvent{
424 Type: AgentEventTypeResponse,
425 Message: agentMessage,
426 Done: true,
427 }
428 }
429}
430
431func (a *agent) createUserMessage(ctx context.Context, sessionID, content string, attachmentParts []message.ContentPart) (message.Message, error) {
432 parts := []message.ContentPart{message.TextContent{Text: content}}
433 parts = append(parts, attachmentParts...)
434 return a.messages.Create(ctx, sessionID, message.CreateMessageParams{
435 Role: message.User,
436 Parts: parts,
437 })
438}
439
440func (a *agent) streamAndHandleEvents(ctx context.Context, sessionID string, msgHistory []message.Message) (message.Message, *message.Message, error) {
441 ctx = context.WithValue(ctx, tools.SessionIDContextKey, sessionID)
442 eventChan := a.provider.StreamResponse(ctx, msgHistory, slices.Collect(a.tools.Seq()))
443
444 assistantMsg, err := a.messages.Create(ctx, sessionID, message.CreateMessageParams{
445 Role: message.Assistant,
446 Parts: []message.ContentPart{},
447 Model: a.Model().ID,
448 Provider: a.providerID,
449 })
450 if err != nil {
451 return assistantMsg, nil, fmt.Errorf("failed to create assistant message: %w", err)
452 }
453
454 // Add the session and message ID into the context if needed by tools.
455 ctx = context.WithValue(ctx, tools.MessageIDContextKey, assistantMsg.ID)
456
457 // Process each event in the stream.
458 for event := range eventChan {
459 if processErr := a.processEvent(ctx, sessionID, &assistantMsg, event); processErr != nil {
460 if errors.Is(processErr, context.Canceled) {
461 a.finishMessage(context.Background(), &assistantMsg, message.FinishReasonCanceled, "Request cancelled", "")
462 } else {
463 a.finishMessage(ctx, &assistantMsg, message.FinishReasonError, "API Error", processErr.Error())
464 }
465 return assistantMsg, nil, processErr
466 }
467 if ctx.Err() != nil {
468 a.finishMessage(context.Background(), &assistantMsg, message.FinishReasonCanceled, "Request cancelled", "")
469 return assistantMsg, nil, ctx.Err()
470 }
471 }
472
473 toolResults := make([]message.ToolResult, len(assistantMsg.ToolCalls()))
474 toolCalls := assistantMsg.ToolCalls()
475 for i, toolCall := range toolCalls {
476 select {
477 case <-ctx.Done():
478 a.finishMessage(context.Background(), &assistantMsg, message.FinishReasonCanceled, "Request cancelled", "")
479 // Make all future tool calls cancelled
480 for j := i; j < len(toolCalls); j++ {
481 toolResults[j] = message.ToolResult{
482 ToolCallID: toolCalls[j].ID,
483 Content: "Tool execution canceled by user",
484 IsError: true,
485 }
486 }
487 goto out
488 default:
489 // Continue processing
490 var tool tools.BaseTool
491 for availableTool := range a.tools.Seq() {
492 if availableTool.Info().Name == toolCall.Name {
493 tool = availableTool
494 break
495 }
496 }
497
498 // Tool not found
499 if tool == nil {
500 toolResults[i] = message.ToolResult{
501 ToolCallID: toolCall.ID,
502 Content: fmt.Sprintf("Tool not found: %s", toolCall.Name),
503 IsError: true,
504 }
505 continue
506 }
507
508 // Run tool in goroutine to allow cancellation
509 type toolExecResult struct {
510 response tools.ToolResponse
511 err error
512 }
513 resultChan := make(chan toolExecResult, 1)
514
515 go func() {
516 response, err := tool.Run(ctx, tools.ToolCall{
517 ID: toolCall.ID,
518 Name: toolCall.Name,
519 Input: toolCall.Input,
520 })
521 resultChan <- toolExecResult{response: response, err: err}
522 }()
523
524 var toolResponse tools.ToolResponse
525 var toolErr error
526
527 select {
528 case <-ctx.Done():
529 a.finishMessage(context.Background(), &assistantMsg, message.FinishReasonCanceled, "Request cancelled", "")
530 // Mark remaining tool calls as cancelled
531 for j := i; j < len(toolCalls); j++ {
532 toolResults[j] = message.ToolResult{
533 ToolCallID: toolCalls[j].ID,
534 Content: "Tool execution canceled by user",
535 IsError: true,
536 }
537 }
538 goto out
539 case result := <-resultChan:
540 toolResponse = result.response
541 toolErr = result.err
542 }
543
544 if toolErr != nil {
545 slog.Error("Tool execution error", "toolCall", toolCall.ID, "error", toolErr)
546 if errors.Is(toolErr, permission.ErrorPermissionDenied) {
547 toolResults[i] = message.ToolResult{
548 ToolCallID: toolCall.ID,
549 Content: "Permission denied",
550 IsError: true,
551 }
552 for j := i + 1; j < len(toolCalls); j++ {
553 toolResults[j] = message.ToolResult{
554 ToolCallID: toolCalls[j].ID,
555 Content: "Tool execution canceled by user",
556 IsError: true,
557 }
558 }
559 a.finishMessage(ctx, &assistantMsg, message.FinishReasonPermissionDenied, "Permission denied", "")
560 break
561 }
562 }
563 toolResults[i] = message.ToolResult{
564 ToolCallID: toolCall.ID,
565 Content: toolResponse.Content,
566 Metadata: toolResponse.Metadata,
567 IsError: toolResponse.IsError,
568 }
569 }
570 }
571out:
572 if len(toolResults) == 0 {
573 return assistantMsg, nil, nil
574 }
575 parts := make([]message.ContentPart, 0)
576 for _, tr := range toolResults {
577 parts = append(parts, tr)
578 }
579 msg, err := a.messages.Create(context.Background(), assistantMsg.SessionID, message.CreateMessageParams{
580 Role: message.Tool,
581 Parts: parts,
582 Provider: a.providerID,
583 })
584 if err != nil {
585 return assistantMsg, nil, fmt.Errorf("failed to create cancelled tool message: %w", err)
586 }
587
588 return assistantMsg, &msg, err
589}
590
591func (a *agent) finishMessage(ctx context.Context, msg *message.Message, finishReason message.FinishReason, message, details string) {
592 msg.AddFinish(finishReason, message, details)
593 _ = a.messages.Update(ctx, *msg)
594}
595
596func (a *agent) processEvent(ctx context.Context, sessionID string, assistantMsg *message.Message, event provider.ProviderEvent) error {
597 select {
598 case <-ctx.Done():
599 return ctx.Err()
600 default:
601 // Continue processing.
602 }
603
604 switch event.Type {
605 case provider.EventThinkingDelta:
606 assistantMsg.AppendReasoningContent(event.Thinking)
607 return a.messages.Update(ctx, *assistantMsg)
608 case provider.EventSignatureDelta:
609 assistantMsg.AppendReasoningSignature(event.Signature)
610 return a.messages.Update(ctx, *assistantMsg)
611 case provider.EventContentDelta:
612 assistantMsg.FinishThinking()
613 assistantMsg.AppendContent(event.Content)
614 return a.messages.Update(ctx, *assistantMsg)
615 case provider.EventToolUseStart:
616 assistantMsg.FinishThinking()
617 slog.Info("Tool call started", "toolCall", event.ToolCall)
618 assistantMsg.AddToolCall(*event.ToolCall)
619 return a.messages.Update(ctx, *assistantMsg)
620 case provider.EventToolUseDelta:
621 assistantMsg.AppendToolCallInput(event.ToolCall.ID, event.ToolCall.Input)
622 return a.messages.Update(ctx, *assistantMsg)
623 case provider.EventToolUseStop:
624 slog.Info("Finished tool call", "toolCall", event.ToolCall)
625 assistantMsg.FinishToolCall(event.ToolCall.ID)
626 return a.messages.Update(ctx, *assistantMsg)
627 case provider.EventError:
628 return event.Error
629 case provider.EventComplete:
630 assistantMsg.FinishThinking()
631 assistantMsg.SetToolCalls(event.Response.ToolCalls)
632 assistantMsg.AddFinish(event.Response.FinishReason, "", "")
633 if err := a.messages.Update(ctx, *assistantMsg); err != nil {
634 return fmt.Errorf("failed to update message: %w", err)
635 }
636 return a.TrackUsage(ctx, sessionID, a.Model(), event.Response.Usage)
637 }
638
639 return nil
640}
641
642func (a *agent) TrackUsage(ctx context.Context, sessionID string, model catwalk.Model, usage provider.TokenUsage) error {
643 sess, err := a.sessions.Get(ctx, sessionID)
644 if err != nil {
645 return fmt.Errorf("failed to get session: %w", err)
646 }
647
648 cost := model.CostPer1MInCached/1e6*float64(usage.CacheCreationTokens) +
649 model.CostPer1MOutCached/1e6*float64(usage.CacheReadTokens) +
650 model.CostPer1MIn/1e6*float64(usage.InputTokens) +
651 model.CostPer1MOut/1e6*float64(usage.OutputTokens)
652
653 sess.Cost += cost
654 sess.CompletionTokens = usage.OutputTokens + usage.CacheReadTokens
655 sess.PromptTokens = usage.InputTokens + usage.CacheCreationTokens
656
657 _, err = a.sessions.Save(ctx, sess)
658 if err != nil {
659 return fmt.Errorf("failed to save session: %w", err)
660 }
661 return nil
662}
663
664func (a *agent) Summarize(ctx context.Context, sessionID string) error {
665 if a.summarizeProvider == nil {
666 return fmt.Errorf("summarize provider not available")
667 }
668
669 // Check if session is busy
670 if a.IsSessionBusy(sessionID) {
671 return ErrSessionBusy
672 }
673
674 // Create a new context with cancellation
675 summarizeCtx, cancel := context.WithCancel(ctx)
676
677 // Store the cancel function in activeRequests to allow cancellation
678 a.activeRequests.Set(sessionID+"-summarize", cancel)
679
680 go func() {
681 defer a.activeRequests.Del(sessionID + "-summarize")
682 defer cancel()
683 event := AgentEvent{
684 Type: AgentEventTypeSummarize,
685 Progress: "Starting summarization...",
686 }
687
688 a.Publish(pubsub.CreatedEvent, event)
689 // Get all messages from the session
690 msgs, err := a.messages.List(summarizeCtx, sessionID)
691 if err != nil {
692 event = AgentEvent{
693 Type: AgentEventTypeError,
694 Error: fmt.Errorf("failed to list messages: %w", err),
695 Done: true,
696 }
697 a.Publish(pubsub.CreatedEvent, event)
698 return
699 }
700 summarizeCtx = context.WithValue(summarizeCtx, tools.SessionIDContextKey, sessionID)
701
702 if len(msgs) == 0 {
703 event = AgentEvent{
704 Type: AgentEventTypeError,
705 Error: fmt.Errorf("no messages to summarize"),
706 Done: true,
707 }
708 a.Publish(pubsub.CreatedEvent, event)
709 return
710 }
711
712 event = AgentEvent{
713 Type: AgentEventTypeSummarize,
714 Progress: "Analyzing conversation...",
715 }
716 a.Publish(pubsub.CreatedEvent, event)
717
718 // Add a system message to guide the summarization
719 summarizePrompt := "Provide a detailed but concise summary of our conversation above. Focus on information that would be helpful for continuing the conversation, including what we did, what we're doing, which files we're working on, and what we're going to do next."
720
721 // Create a new message with the summarize prompt
722 promptMsg := message.Message{
723 Role: message.User,
724 Parts: []message.ContentPart{message.TextContent{Text: summarizePrompt}},
725 }
726
727 // Append the prompt to the messages
728 msgsWithPrompt := append(msgs, promptMsg)
729
730 event = AgentEvent{
731 Type: AgentEventTypeSummarize,
732 Progress: "Generating summary...",
733 }
734
735 a.Publish(pubsub.CreatedEvent, event)
736
737 // Send the messages to the summarize provider
738 response := a.summarizeProvider.StreamResponse(
739 summarizeCtx,
740 msgsWithPrompt,
741 nil,
742 )
743 var finalResponse *provider.ProviderResponse
744 for r := range response {
745 if r.Error != nil {
746 event = AgentEvent{
747 Type: AgentEventTypeError,
748 Error: fmt.Errorf("failed to summarize: %w", err),
749 Done: true,
750 }
751 a.Publish(pubsub.CreatedEvent, event)
752 return
753 }
754 finalResponse = r.Response
755 }
756
757 summary := strings.TrimSpace(finalResponse.Content)
758 if summary == "" {
759 event = AgentEvent{
760 Type: AgentEventTypeError,
761 Error: fmt.Errorf("empty summary returned"),
762 Done: true,
763 }
764 a.Publish(pubsub.CreatedEvent, event)
765 return
766 }
767 shell := shell.GetPersistentShell(config.Get().WorkingDir())
768 summary += "\n\n**Current working directory of the persistent shell**\n\n" + shell.GetWorkingDir()
769 event = AgentEvent{
770 Type: AgentEventTypeSummarize,
771 Progress: "Creating new session...",
772 }
773
774 a.Publish(pubsub.CreatedEvent, event)
775 oldSession, err := a.sessions.Get(summarizeCtx, sessionID)
776 if err != nil {
777 event = AgentEvent{
778 Type: AgentEventTypeError,
779 Error: fmt.Errorf("failed to get session: %w", err),
780 Done: true,
781 }
782
783 a.Publish(pubsub.CreatedEvent, event)
784 return
785 }
786 // Create a message in the new session with the summary
787 msg, err := a.messages.Create(summarizeCtx, oldSession.ID, message.CreateMessageParams{
788 Role: message.Assistant,
789 Parts: []message.ContentPart{
790 message.TextContent{Text: summary},
791 message.Finish{
792 Reason: message.FinishReasonEndTurn,
793 Time: time.Now().Unix(),
794 },
795 },
796 Model: a.summarizeProvider.Model().ID,
797 Provider: a.summarizeProviderID,
798 })
799 if err != nil {
800 event = AgentEvent{
801 Type: AgentEventTypeError,
802 Error: fmt.Errorf("failed to create summary message: %w", err),
803 Done: true,
804 }
805
806 a.Publish(pubsub.CreatedEvent, event)
807 return
808 }
809 oldSession.SummaryMessageID = msg.ID
810 oldSession.CompletionTokens = finalResponse.Usage.OutputTokens
811 oldSession.PromptTokens = 0
812 model := a.summarizeProvider.Model()
813 usage := finalResponse.Usage
814 cost := model.CostPer1MInCached/1e6*float64(usage.CacheCreationTokens) +
815 model.CostPer1MOutCached/1e6*float64(usage.CacheReadTokens) +
816 model.CostPer1MIn/1e6*float64(usage.InputTokens) +
817 model.CostPer1MOut/1e6*float64(usage.OutputTokens)
818 oldSession.Cost += cost
819 _, err = a.sessions.Save(summarizeCtx, oldSession)
820 if err != nil {
821 event = AgentEvent{
822 Type: AgentEventTypeError,
823 Error: fmt.Errorf("failed to save session: %w", err),
824 Done: true,
825 }
826 a.Publish(pubsub.CreatedEvent, event)
827 }
828
829 event = AgentEvent{
830 Type: AgentEventTypeSummarize,
831 SessionID: oldSession.ID,
832 Progress: "Summary complete",
833 Done: true,
834 }
835 a.Publish(pubsub.CreatedEvent, event)
836 // Send final success event with the new session ID
837 }()
838
839 return nil
840}
841
842func (a *agent) CancelAll() {
843 if !a.IsBusy() {
844 return
845 }
846 for key := range a.activeRequests.Seq2() {
847 a.Cancel(key) // key is sessionID
848 }
849
850 timeout := time.After(5 * time.Second)
851 for a.IsBusy() {
852 select {
853 case <-timeout:
854 return
855 default:
856 time.Sleep(200 * time.Millisecond)
857 }
858 }
859}
860
861func (a *agent) UpdateModel() error {
862 cfg := config.Get()
863
864 // Get current provider configuration
865 currentProviderCfg := cfg.GetProviderForModel(a.agentCfg.Model)
866 if currentProviderCfg == nil || currentProviderCfg.ID == "" {
867 return fmt.Errorf("provider for agent %s not found in config", a.agentCfg.Name)
868 }
869
870 // Check if provider has changed
871 if string(currentProviderCfg.ID) != a.providerID {
872 // Provider changed, need to recreate the main provider
873 model := cfg.GetModelByType(a.agentCfg.Model)
874 if model.ID == "" {
875 return fmt.Errorf("model not found for agent %s", a.agentCfg.Name)
876 }
877
878 promptID := agentPromptMap[a.agentCfg.ID]
879 if promptID == "" {
880 promptID = prompt.PromptDefault
881 }
882
883 opts := []provider.ProviderClientOption{
884 provider.WithModel(a.agentCfg.Model),
885 provider.WithSystemMessage(prompt.GetPrompt(promptID, currentProviderCfg.ID, cfg.Options.ContextPaths...)),
886 }
887
888 newProvider, err := provider.NewProvider(*currentProviderCfg, opts...)
889 if err != nil {
890 return fmt.Errorf("failed to create new provider: %w", err)
891 }
892
893 // Update the provider and provider ID
894 a.provider = newProvider
895 a.providerID = string(currentProviderCfg.ID)
896 }
897
898 // Check if small model provider has changed (affects title and summarize providers)
899 smallModelCfg := cfg.Models[config.SelectedModelTypeSmall]
900 var smallModelProviderCfg config.ProviderConfig
901
902 for p := range cfg.Providers.Seq() {
903 if p.ID == smallModelCfg.Provider {
904 smallModelProviderCfg = p
905 break
906 }
907 }
908
909 if smallModelProviderCfg.ID == "" {
910 return fmt.Errorf("provider %s not found in config", smallModelCfg.Provider)
911 }
912
913 // Check if summarize provider has changed
914 if string(smallModelProviderCfg.ID) != a.summarizeProviderID {
915 smallModel := cfg.GetModelByType(config.SelectedModelTypeSmall)
916 if smallModel == nil {
917 return fmt.Errorf("model %s not found in provider %s", smallModelCfg.Model, smallModelProviderCfg.ID)
918 }
919
920 // Recreate title provider
921 titleOpts := []provider.ProviderClientOption{
922 provider.WithModel(config.SelectedModelTypeSmall),
923 provider.WithSystemMessage(prompt.GetPrompt(prompt.PromptTitle, smallModelProviderCfg.ID)),
924 // We want the title to be short, so we limit the max tokens
925 provider.WithMaxTokens(40),
926 }
927 newTitleProvider, err := provider.NewProvider(smallModelProviderCfg, titleOpts...)
928 if err != nil {
929 return fmt.Errorf("failed to create new title provider: %w", err)
930 }
931
932 // Recreate summarize provider
933 summarizeOpts := []provider.ProviderClientOption{
934 provider.WithModel(config.SelectedModelTypeSmall),
935 provider.WithSystemMessage(prompt.GetPrompt(prompt.PromptSummarizer, smallModelProviderCfg.ID)),
936 }
937 newSummarizeProvider, err := provider.NewProvider(smallModelProviderCfg, summarizeOpts...)
938 if err != nil {
939 return fmt.Errorf("failed to create new summarize provider: %w", err)
940 }
941
942 // Update the providers and provider ID
943 a.titleProvider = newTitleProvider
944 a.summarizeProvider = newSummarizeProvider
945 a.summarizeProviderID = string(smallModelProviderCfg.ID)
946 }
947
948 return nil
949}