1package agent
2
3import (
4 "context"
5 "errors"
6 "fmt"
7 "log/slog"
8 "slices"
9 "strings"
10 "time"
11
12 "github.com/charmbracelet/catwalk/pkg/catwalk"
13 "github.com/charmbracelet/crush/internal/config"
14 "github.com/charmbracelet/crush/internal/csync"
15 "github.com/charmbracelet/crush/internal/history"
16 "github.com/charmbracelet/crush/internal/llm/prompt"
17 "github.com/charmbracelet/crush/internal/llm/provider"
18 "github.com/charmbracelet/crush/internal/llm/tools"
19 "github.com/charmbracelet/crush/internal/log"
20 "github.com/charmbracelet/crush/internal/lsp"
21 "github.com/charmbracelet/crush/internal/message"
22 "github.com/charmbracelet/crush/internal/permission"
23 "github.com/charmbracelet/crush/internal/pubsub"
24 "github.com/charmbracelet/crush/internal/session"
25 "github.com/charmbracelet/crush/internal/shell"
26)
27
// ErrRequestCancelled is reported (inside an error AgentEvent) when a running
// generation is cancelled before it completes.
var ErrRequestCancelled = errors.New("request canceled by user")

// ErrSessionBusy is returned by Run and Summarize when a request is already
// in flight for the session.
var ErrSessionBusy = errors.New("session is currently processing another request")
33
// AgentEventType tells consumers which kind of AgentEvent they received.
type AgentEventType string

// Kinds of events emitted by the agent.
const (
	// AgentEventTypeError carries a failure in AgentEvent.Error.
	AgentEventTypeError AgentEventType = "error"
	// AgentEventTypeResponse carries the final assistant message of a run.
	AgentEventTypeResponse AgentEventType = "response"
	// AgentEventTypeSummarize reports progress or completion of Summarize.
	AgentEventTypeSummarize AgentEventType = "summarize"
)
41
// AgentEvent is emitted by the agent, both on the channel returned by Run and
// through the agent's pubsub broker, to report the result or progress of a
// request.
type AgentEvent struct {
	// Type identifies which of the fields below are meaningful.
	Type AgentEventType
	// Message holds the final assistant message of a successful run
	// (AgentEventTypeResponse).
	Message message.Message
	// Error holds the failure for AgentEventTypeError events.
	Error error

	// When summarizing
	// SessionID is set on the final "Summary complete" event.
	SessionID string
	// Progress is a human-readable status line for summarize events.
	Progress string
	// Done marks the last event of a summarization; it is also set on a
	// successful run response.
	Done bool
}
52
53type Service interface {
54 pubsub.Suscriber[AgentEvent]
55 Model() catwalk.Model
56 Run(ctx context.Context, sessionID string, content string, attachments ...message.Attachment) (<-chan AgentEvent, error)
57 Cancel(sessionID string)
58 CancelAll()
59 IsSessionBusy(sessionID string) bool
60 IsBusy() bool
61 Summarize(ctx context.Context, sessionID string) error
62 UpdateModel() error
63}
64
// agent is the Service implementation in this file. It drives a main
// provider for the conversation plus small-model providers for titles and
// summaries, and records in-flight requests so they can be cancelled.
type agent struct {
	*pubsub.Broker[AgentEvent]
	agentCfg config.Agent
	sessions session.Service
	messages message.Service
	// NOTE(review): this field is neither read nor written in this file; the
	// tool factory in NewAgent assigns the package-level mcpTools instead —
	// confirm whether the field is still needed.
	mcpTools []McpTool

	// tools is filled from the tool factory built in NewAgent.
	tools *csync.LazySlice[tools.BaseTool]

	// provider serves the main conversation; providerID is the ID of its
	// provider config and is compared in UpdateModel to detect changes.
	provider   provider.Provider
	providerID string

	// titleProvider and summarizeProvider are bound to the small model;
	// summarizeProviderID is the ID of the provider config they share.
	titleProvider       provider.Provider
	summarizeProvider   provider.Provider
	summarizeProviderID string

	// activeRequests maps a session ID (Run) or "<sessionID>-summarize"
	// (Summarize) to the cancel function of the in-flight request.
	activeRequests *csync.Map[string, context.CancelFunc]
}
83
84var agentPromptMap = map[string]prompt.PromptID{
85 "coder": prompt.PromptCoder,
86 "task": prompt.PromptTask,
87}
88
89func NewAgent(
90 ctx context.Context,
91 agentCfg config.Agent,
92 // These services are needed in the tools
93 permissions permission.Service,
94 sessions session.Service,
95 messages message.Service,
96 history history.Service,
97 lspClients map[string]*lsp.Client,
98) (Service, error) {
99 cfg := config.Get()
100
101 var agentTool tools.BaseTool
102 if agentCfg.ID == "coder" {
103 taskAgentCfg := config.Get().Agents["task"]
104 if taskAgentCfg.ID == "" {
105 return nil, fmt.Errorf("task agent not found in config")
106 }
107 taskAgent, err := NewAgent(ctx, taskAgentCfg, permissions, sessions, messages, history, lspClients)
108 if err != nil {
109 return nil, fmt.Errorf("failed to create task agent: %w", err)
110 }
111
112 agentTool = NewAgentTool(taskAgent, sessions, messages)
113 }
114
115 providerCfg := config.Get().GetProviderForModel(agentCfg.Model)
116 if providerCfg == nil {
117 return nil, fmt.Errorf("provider for agent %s not found in config", agentCfg.Name)
118 }
119 model := config.Get().GetModelByType(agentCfg.Model)
120
121 if model == nil {
122 return nil, fmt.Errorf("model not found for agent %s", agentCfg.Name)
123 }
124
125 promptID := agentPromptMap[agentCfg.ID]
126 if promptID == "" {
127 promptID = prompt.PromptDefault
128 }
129 opts := []provider.ProviderClientOption{
130 provider.WithModel(agentCfg.Model),
131 provider.WithSystemMessage(prompt.GetPrompt(promptID, providerCfg.ID, config.Get().Options.ContextPaths...)),
132 }
133 agentProvider, err := provider.NewProvider(*providerCfg, opts...)
134 if err != nil {
135 return nil, err
136 }
137
138 smallModelCfg := cfg.Models[config.SelectedModelTypeSmall]
139 var smallModelProviderCfg *config.ProviderConfig
140 if smallModelCfg.Provider == providerCfg.ID {
141 smallModelProviderCfg = providerCfg
142 } else {
143 smallModelProviderCfg = cfg.GetProviderForModel(config.SelectedModelTypeSmall)
144
145 if smallModelProviderCfg.ID == "" {
146 return nil, fmt.Errorf("provider %s not found in config", smallModelCfg.Provider)
147 }
148 }
149 smallModel := cfg.GetModelByType(config.SelectedModelTypeSmall)
150 if smallModel.ID == "" {
151 return nil, fmt.Errorf("model %s not found in provider %s", smallModelCfg.Model, smallModelProviderCfg.ID)
152 }
153
154 titleOpts := []provider.ProviderClientOption{
155 provider.WithModel(config.SelectedModelTypeSmall),
156 provider.WithSystemMessage(prompt.GetPrompt(prompt.PromptTitle, smallModelProviderCfg.ID)),
157 }
158 titleProvider, err := provider.NewProvider(*smallModelProviderCfg, titleOpts...)
159 if err != nil {
160 return nil, err
161 }
162 summarizeOpts := []provider.ProviderClientOption{
163 provider.WithModel(config.SelectedModelTypeSmall),
164 provider.WithSystemMessage(prompt.GetPrompt(prompt.PromptSummarizer, smallModelProviderCfg.ID)),
165 }
166 summarizeProvider, err := provider.NewProvider(*smallModelProviderCfg, summarizeOpts...)
167 if err != nil {
168 return nil, err
169 }
170
171 toolFn := func() []tools.BaseTool {
172 slog.Info("Initializing agent tools", "agent", agentCfg.ID)
173 defer func() {
174 slog.Info("Initialized agent tools", "agent", agentCfg.ID)
175 }()
176
177 cwd := cfg.WorkingDir()
178 allTools := []tools.BaseTool{
179 tools.NewBashTool(permissions, cwd),
180 tools.NewDownloadTool(permissions, cwd),
181 tools.NewEditTool(lspClients, permissions, history, cwd),
182 tools.NewMultiEditTool(lspClients, permissions, history, cwd),
183 tools.NewFetchTool(permissions, cwd),
184 tools.NewGlobTool(cwd),
185 tools.NewGrepTool(cwd),
186 tools.NewLsTool(permissions, cwd),
187 tools.NewSourcegraphTool(),
188 tools.NewViewTool(lspClients, permissions, cwd),
189 tools.NewWriteTool(lspClients, permissions, history, cwd),
190 }
191
192 mcpToolsOnce.Do(func() {
193 mcpTools = doGetMCPTools(ctx, permissions, cfg)
194 })
195 allTools = append(allTools, mcpTools...)
196
197 if len(lspClients) > 0 {
198 allTools = append(allTools, tools.NewDiagnosticsTool(lspClients))
199 }
200
201 if agentTool != nil {
202 allTools = append(allTools, agentTool)
203 }
204
205 if agentCfg.AllowedTools == nil {
206 return allTools
207 }
208
209 var filteredTools []tools.BaseTool
210 for _, tool := range allTools {
211 if slices.Contains(agentCfg.AllowedTools, tool.Name()) {
212 filteredTools = append(filteredTools, tool)
213 }
214 }
215 return filteredTools
216 }
217
218 return &agent{
219 Broker: pubsub.NewBroker[AgentEvent](),
220 agentCfg: agentCfg,
221 provider: agentProvider,
222 providerID: string(providerCfg.ID),
223 messages: messages,
224 sessions: sessions,
225 titleProvider: titleProvider,
226 summarizeProvider: summarizeProvider,
227 summarizeProviderID: string(smallModelProviderCfg.ID),
228 activeRequests: csync.NewMap[string, context.CancelFunc](),
229 tools: csync.NewLazySlice(toolFn),
230 }, nil
231}
232
233func (a *agent) Model() catwalk.Model {
234 return *config.Get().GetModelByType(a.agentCfg.Model)
235}
236
237func (a *agent) Cancel(sessionID string) {
238 // Cancel regular requests
239 if cancel, ok := a.activeRequests.Take(sessionID); ok && cancel != nil {
240 slog.Info("Request cancellation initiated", "session_id", sessionID)
241 cancel()
242 }
243
244 // Also check for summarize requests
245 if cancel, ok := a.activeRequests.Take(sessionID + "-summarize"); ok && cancel != nil {
246 slog.Info("Summarize cancellation initiated", "session_id", sessionID)
247 cancel()
248 }
249}
250
251func (a *agent) IsBusy() bool {
252 var busy bool
253 for cancelFunc := range a.activeRequests.Seq() {
254 if cancelFunc != nil {
255 busy = true
256 break
257 }
258 }
259 return busy
260}
261
262func (a *agent) IsSessionBusy(sessionID string) bool {
263 _, busy := a.activeRequests.Get(sessionID)
264 return busy
265}
266
267func (a *agent) generateTitle(ctx context.Context, sessionID string, content string) error {
268 if content == "" {
269 return nil
270 }
271 if a.titleProvider == nil {
272 return nil
273 }
274 session, err := a.sessions.Get(ctx, sessionID)
275 if err != nil {
276 return err
277 }
278 parts := []message.ContentPart{message.TextContent{
279 Text: fmt.Sprintf("Generate a concise title for the following content:\n\n%s", content),
280 }}
281
282 // Use streaming approach like summarization
283 response := a.titleProvider.StreamResponse(
284 ctx,
285 []message.Message{
286 {
287 Role: message.User,
288 Parts: parts,
289 },
290 },
291 nil,
292 )
293
294 var finalResponse *provider.ProviderResponse
295 for r := range response {
296 if r.Error != nil {
297 return r.Error
298 }
299 finalResponse = r.Response
300 }
301
302 if finalResponse == nil {
303 return fmt.Errorf("no response received from title provider")
304 }
305
306 title := strings.TrimSpace(strings.ReplaceAll(finalResponse.Content, "\n", " "))
307 if title == "" {
308 return nil
309 }
310
311 session.Title = title
312 _, err = a.sessions.Save(ctx, session)
313 return err
314}
315
316func (a *agent) err(err error) AgentEvent {
317 return AgentEvent{
318 Type: AgentEventTypeError,
319 Error: err,
320 }
321}
322
323func (a *agent) Run(ctx context.Context, sessionID string, content string, attachments ...message.Attachment) (<-chan AgentEvent, error) {
324 if !a.Model().SupportsImages && attachments != nil {
325 attachments = nil
326 }
327 events := make(chan AgentEvent)
328 if a.IsSessionBusy(sessionID) {
329 return nil, ErrSessionBusy
330 }
331
332 genCtx, cancel := context.WithCancel(ctx)
333
334 a.activeRequests.Set(sessionID, cancel)
335 go func() {
336 slog.Debug("Request started", "sessionID", sessionID)
337 defer log.RecoverPanic("agent.Run", func() {
338 events <- a.err(fmt.Errorf("panic while running the agent"))
339 })
340 var attachmentParts []message.ContentPart
341 for _, attachment := range attachments {
342 attachmentParts = append(attachmentParts, message.BinaryContent{Path: attachment.FilePath, MIMEType: attachment.MimeType, Data: attachment.Content})
343 }
344 result := a.processGeneration(genCtx, sessionID, content, attachmentParts)
345 if result.Error != nil && !errors.Is(result.Error, ErrRequestCancelled) && !errors.Is(result.Error, context.Canceled) {
346 slog.Error(result.Error.Error())
347 }
348 slog.Debug("Request completed", "sessionID", sessionID)
349 a.activeRequests.Del(sessionID)
350 cancel()
351 a.Publish(pubsub.CreatedEvent, result)
352 events <- result
353 close(events)
354 }()
355 return events, nil
356}
357
358func (a *agent) processGeneration(ctx context.Context, sessionID, content string, attachmentParts []message.ContentPart) AgentEvent {
359 cfg := config.Get()
360 // List existing messages; if none, start title generation asynchronously.
361 msgs, err := a.messages.List(ctx, sessionID)
362 if err != nil {
363 return a.err(fmt.Errorf("failed to list messages: %w", err))
364 }
365 if len(msgs) == 0 {
366 go func() {
367 defer log.RecoverPanic("agent.Run", func() {
368 slog.Error("panic while generating title")
369 })
370 titleErr := a.generateTitle(context.Background(), sessionID, content)
371 if titleErr != nil && !errors.Is(titleErr, context.Canceled) && !errors.Is(titleErr, context.DeadlineExceeded) {
372 slog.Error("failed to generate title", "error", titleErr)
373 }
374 }()
375 }
376 session, err := a.sessions.Get(ctx, sessionID)
377 if err != nil {
378 return a.err(fmt.Errorf("failed to get session: %w", err))
379 }
380 if session.SummaryMessageID != "" {
381 summaryMsgInex := -1
382 for i, msg := range msgs {
383 if msg.ID == session.SummaryMessageID {
384 summaryMsgInex = i
385 break
386 }
387 }
388 if summaryMsgInex != -1 {
389 msgs = msgs[summaryMsgInex:]
390 msgs[0].Role = message.User
391 }
392 }
393
394 userMsg, err := a.createUserMessage(ctx, sessionID, content, attachmentParts)
395 if err != nil {
396 return a.err(fmt.Errorf("failed to create user message: %w", err))
397 }
398 // Append the new user message to the conversation history.
399 msgHistory := append(msgs, userMsg)
400
401 for {
402 // Check for cancellation before each iteration
403 select {
404 case <-ctx.Done():
405 return a.err(ctx.Err())
406 default:
407 // Continue processing
408 }
409 agentMessage, toolResults, err := a.streamAndHandleEvents(ctx, sessionID, msgHistory)
410 if err != nil {
411 if errors.Is(err, context.Canceled) {
412 agentMessage.AddFinish(message.FinishReasonCanceled, "Request cancelled", "")
413 a.messages.Update(context.Background(), agentMessage)
414 return a.err(ErrRequestCancelled)
415 }
416 return a.err(fmt.Errorf("failed to process events: %w", err))
417 }
418 if cfg.Options.Debug {
419 slog.Info("Result", "message", agentMessage.FinishReason(), "toolResults", toolResults)
420 }
421 if (agentMessage.FinishReason() == message.FinishReasonToolUse) && toolResults != nil {
422 // We are not done, we need to respond with the tool response
423 msgHistory = append(msgHistory, agentMessage, *toolResults)
424 continue
425 }
426 if agentMessage.FinishReason() == "" {
427 // Kujtim: could not track down where this is happening but this means its cancelled
428 agentMessage.AddFinish(message.FinishReasonCanceled, "Request cancelled", "")
429 _ = a.messages.Update(context.Background(), agentMessage)
430 return a.err(ErrRequestCancelled)
431 }
432 return AgentEvent{
433 Type: AgentEventTypeResponse,
434 Message: agentMessage,
435 Done: true,
436 }
437 }
438}
439
440func (a *agent) createUserMessage(ctx context.Context, sessionID, content string, attachmentParts []message.ContentPart) (message.Message, error) {
441 parts := []message.ContentPart{message.TextContent{Text: content}}
442 parts = append(parts, attachmentParts...)
443 return a.messages.Create(ctx, sessionID, message.CreateMessageParams{
444 Role: message.User,
445 Parts: parts,
446 })
447}
448
449func (a *agent) streamAndHandleEvents(ctx context.Context, sessionID string, msgHistory []message.Message) (message.Message, *message.Message, error) {
450 ctx = context.WithValue(ctx, tools.SessionIDContextKey, sessionID)
451 eventChan := a.provider.StreamResponse(ctx, msgHistory, slices.Collect(a.tools.Seq()))
452
453 assistantMsg, err := a.messages.Create(ctx, sessionID, message.CreateMessageParams{
454 Role: message.Assistant,
455 Parts: []message.ContentPart{},
456 Model: a.Model().ID,
457 Provider: a.providerID,
458 })
459 if err != nil {
460 return assistantMsg, nil, fmt.Errorf("failed to create assistant message: %w", err)
461 }
462
463 // Add the session and message ID into the context if needed by tools.
464 ctx = context.WithValue(ctx, tools.MessageIDContextKey, assistantMsg.ID)
465
466 // Process each event in the stream.
467 for event := range eventChan {
468 if processErr := a.processEvent(ctx, sessionID, &assistantMsg, event); processErr != nil {
469 if errors.Is(processErr, context.Canceled) {
470 a.finishMessage(context.Background(), &assistantMsg, message.FinishReasonCanceled, "Request cancelled", "")
471 } else {
472 a.finishMessage(ctx, &assistantMsg, message.FinishReasonError, "API Error", processErr.Error())
473 }
474 return assistantMsg, nil, processErr
475 }
476 if ctx.Err() != nil {
477 a.finishMessage(context.Background(), &assistantMsg, message.FinishReasonCanceled, "Request cancelled", "")
478 return assistantMsg, nil, ctx.Err()
479 }
480 }
481
482 toolResults := make([]message.ToolResult, len(assistantMsg.ToolCalls()))
483 toolCalls := assistantMsg.ToolCalls()
484 for i, toolCall := range toolCalls {
485 select {
486 case <-ctx.Done():
487 a.finishMessage(context.Background(), &assistantMsg, message.FinishReasonCanceled, "Request cancelled", "")
488 // Make all future tool calls cancelled
489 for j := i; j < len(toolCalls); j++ {
490 toolResults[j] = message.ToolResult{
491 ToolCallID: toolCalls[j].ID,
492 Content: "Tool execution canceled by user",
493 IsError: true,
494 }
495 }
496 goto out
497 default:
498 // Continue processing
499 var tool tools.BaseTool
500 for availableTool := range a.tools.Seq() {
501 if availableTool.Info().Name == toolCall.Name {
502 tool = availableTool
503 break
504 }
505 }
506
507 // Tool not found
508 if tool == nil {
509 toolResults[i] = message.ToolResult{
510 ToolCallID: toolCall.ID,
511 Content: fmt.Sprintf("Tool not found: %s", toolCall.Name),
512 IsError: true,
513 }
514 continue
515 }
516
517 // Run tool in goroutine to allow cancellation
518 type toolExecResult struct {
519 response tools.ToolResponse
520 err error
521 }
522 resultChan := make(chan toolExecResult, 1)
523
524 go func() {
525 response, err := tool.Run(ctx, tools.ToolCall{
526 ID: toolCall.ID,
527 Name: toolCall.Name,
528 Input: toolCall.Input,
529 })
530 resultChan <- toolExecResult{response: response, err: err}
531 }()
532
533 var toolResponse tools.ToolResponse
534 var toolErr error
535
536 select {
537 case <-ctx.Done():
538 a.finishMessage(context.Background(), &assistantMsg, message.FinishReasonCanceled, "Request cancelled", "")
539 // Mark remaining tool calls as cancelled
540 for j := i; j < len(toolCalls); j++ {
541 toolResults[j] = message.ToolResult{
542 ToolCallID: toolCalls[j].ID,
543 Content: "Tool execution canceled by user",
544 IsError: true,
545 }
546 }
547 goto out
548 case result := <-resultChan:
549 toolResponse = result.response
550 toolErr = result.err
551 }
552
553 if toolErr != nil {
554 slog.Error("Tool execution error", "toolCall", toolCall.ID, "error", toolErr)
555 if errors.Is(toolErr, permission.ErrorPermissionDenied) {
556 toolResults[i] = message.ToolResult{
557 ToolCallID: toolCall.ID,
558 Content: "Permission denied",
559 IsError: true,
560 }
561 for j := i + 1; j < len(toolCalls); j++ {
562 toolResults[j] = message.ToolResult{
563 ToolCallID: toolCalls[j].ID,
564 Content: "Tool execution canceled by user",
565 IsError: true,
566 }
567 }
568 a.finishMessage(ctx, &assistantMsg, message.FinishReasonPermissionDenied, "Permission denied", "")
569 break
570 }
571 }
572 toolResults[i] = message.ToolResult{
573 ToolCallID: toolCall.ID,
574 Content: toolResponse.Content,
575 Metadata: toolResponse.Metadata,
576 IsError: toolResponse.IsError,
577 }
578 }
579 }
580out:
581 if len(toolResults) == 0 {
582 return assistantMsg, nil, nil
583 }
584 parts := make([]message.ContentPart, 0)
585 for _, tr := range toolResults {
586 parts = append(parts, tr)
587 }
588 msg, err := a.messages.Create(context.Background(), assistantMsg.SessionID, message.CreateMessageParams{
589 Role: message.Tool,
590 Parts: parts,
591 Provider: a.providerID,
592 })
593 if err != nil {
594 return assistantMsg, nil, fmt.Errorf("failed to create cancelled tool message: %w", err)
595 }
596
597 return assistantMsg, &msg, err
598}
599
600func (a *agent) finishMessage(ctx context.Context, msg *message.Message, finishReason message.FinishReason, message, details string) {
601 msg.AddFinish(finishReason, message, details)
602 _ = a.messages.Update(ctx, *msg)
603}
604
605func (a *agent) processEvent(ctx context.Context, sessionID string, assistantMsg *message.Message, event provider.ProviderEvent) error {
606 select {
607 case <-ctx.Done():
608 return ctx.Err()
609 default:
610 // Continue processing.
611 }
612
613 switch event.Type {
614 case provider.EventThinkingDelta:
615 assistantMsg.AppendReasoningContent(event.Thinking)
616 return a.messages.Update(ctx, *assistantMsg)
617 case provider.EventSignatureDelta:
618 assistantMsg.AppendReasoningSignature(event.Signature)
619 return a.messages.Update(ctx, *assistantMsg)
620 case provider.EventContentDelta:
621 assistantMsg.FinishThinking()
622 assistantMsg.AppendContent(event.Content)
623 return a.messages.Update(ctx, *assistantMsg)
624 case provider.EventToolUseStart:
625 assistantMsg.FinishThinking()
626 slog.Info("Tool call started", "toolCall", event.ToolCall)
627 assistantMsg.AddToolCall(*event.ToolCall)
628 return a.messages.Update(ctx, *assistantMsg)
629 case provider.EventToolUseDelta:
630 assistantMsg.AppendToolCallInput(event.ToolCall.ID, event.ToolCall.Input)
631 return a.messages.Update(ctx, *assistantMsg)
632 case provider.EventToolUseStop:
633 slog.Info("Finished tool call", "toolCall", event.ToolCall)
634 assistantMsg.FinishToolCall(event.ToolCall.ID)
635 return a.messages.Update(ctx, *assistantMsg)
636 case provider.EventError:
637 return event.Error
638 case provider.EventComplete:
639 assistantMsg.FinishThinking()
640 assistantMsg.SetToolCalls(event.Response.ToolCalls)
641 assistantMsg.AddFinish(event.Response.FinishReason, "", "")
642 if err := a.messages.Update(ctx, *assistantMsg); err != nil {
643 return fmt.Errorf("failed to update message: %w", err)
644 }
645 return a.TrackUsage(ctx, sessionID, a.Model(), event.Response.Usage)
646 }
647
648 return nil
649}
650
651func (a *agent) TrackUsage(ctx context.Context, sessionID string, model catwalk.Model, usage provider.TokenUsage) error {
652 sess, err := a.sessions.Get(ctx, sessionID)
653 if err != nil {
654 return fmt.Errorf("failed to get session: %w", err)
655 }
656
657 cost := model.CostPer1MInCached/1e6*float64(usage.CacheCreationTokens) +
658 model.CostPer1MOutCached/1e6*float64(usage.CacheReadTokens) +
659 model.CostPer1MIn/1e6*float64(usage.InputTokens) +
660 model.CostPer1MOut/1e6*float64(usage.OutputTokens)
661
662 sess.Cost += cost
663 sess.CompletionTokens = usage.OutputTokens + usage.CacheReadTokens
664 sess.PromptTokens = usage.InputTokens + usage.CacheCreationTokens
665
666 _, err = a.sessions.Save(ctx, sess)
667 if err != nil {
668 return fmt.Errorf("failed to save session: %w", err)
669 }
670 return nil
671}
672
673func (a *agent) Summarize(ctx context.Context, sessionID string) error {
674 if a.summarizeProvider == nil {
675 return fmt.Errorf("summarize provider not available")
676 }
677
678 // Check if session is busy
679 if a.IsSessionBusy(sessionID) {
680 return ErrSessionBusy
681 }
682
683 // Create a new context with cancellation
684 summarizeCtx, cancel := context.WithCancel(ctx)
685
686 // Store the cancel function in activeRequests to allow cancellation
687 a.activeRequests.Set(sessionID+"-summarize", cancel)
688
689 go func() {
690 defer a.activeRequests.Del(sessionID + "-summarize")
691 defer cancel()
692 event := AgentEvent{
693 Type: AgentEventTypeSummarize,
694 Progress: "Starting summarization...",
695 }
696
697 a.Publish(pubsub.CreatedEvent, event)
698 // Get all messages from the session
699 msgs, err := a.messages.List(summarizeCtx, sessionID)
700 if err != nil {
701 event = AgentEvent{
702 Type: AgentEventTypeError,
703 Error: fmt.Errorf("failed to list messages: %w", err),
704 Done: true,
705 }
706 a.Publish(pubsub.CreatedEvent, event)
707 return
708 }
709 summarizeCtx = context.WithValue(summarizeCtx, tools.SessionIDContextKey, sessionID)
710
711 if len(msgs) == 0 {
712 event = AgentEvent{
713 Type: AgentEventTypeError,
714 Error: fmt.Errorf("no messages to summarize"),
715 Done: true,
716 }
717 a.Publish(pubsub.CreatedEvent, event)
718 return
719 }
720
721 event = AgentEvent{
722 Type: AgentEventTypeSummarize,
723 Progress: "Analyzing conversation...",
724 }
725 a.Publish(pubsub.CreatedEvent, event)
726
727 // Add a system message to guide the summarization
728 summarizePrompt := "Provide a detailed but concise summary of our conversation above. Focus on information that would be helpful for continuing the conversation, including what we did, what we're doing, which files we're working on, and what we're going to do next."
729
730 // Create a new message with the summarize prompt
731 promptMsg := message.Message{
732 Role: message.User,
733 Parts: []message.ContentPart{message.TextContent{Text: summarizePrompt}},
734 }
735
736 // Append the prompt to the messages
737 msgsWithPrompt := append(msgs, promptMsg)
738
739 event = AgentEvent{
740 Type: AgentEventTypeSummarize,
741 Progress: "Generating summary...",
742 }
743
744 a.Publish(pubsub.CreatedEvent, event)
745
746 // Send the messages to the summarize provider
747 response := a.summarizeProvider.StreamResponse(
748 summarizeCtx,
749 msgsWithPrompt,
750 nil,
751 )
752 var finalResponse *provider.ProviderResponse
753 for r := range response {
754 if r.Error != nil {
755 event = AgentEvent{
756 Type: AgentEventTypeError,
757 Error: fmt.Errorf("failed to summarize: %w", err),
758 Done: true,
759 }
760 a.Publish(pubsub.CreatedEvent, event)
761 return
762 }
763 finalResponse = r.Response
764 }
765
766 summary := strings.TrimSpace(finalResponse.Content)
767 if summary == "" {
768 event = AgentEvent{
769 Type: AgentEventTypeError,
770 Error: fmt.Errorf("empty summary returned"),
771 Done: true,
772 }
773 a.Publish(pubsub.CreatedEvent, event)
774 return
775 }
776 shell := shell.GetPersistentShell(config.Get().WorkingDir())
777 summary += "\n\n**Current working directory of the persistent shell**\n\n" + shell.GetWorkingDir()
778 event = AgentEvent{
779 Type: AgentEventTypeSummarize,
780 Progress: "Creating new session...",
781 }
782
783 a.Publish(pubsub.CreatedEvent, event)
784 oldSession, err := a.sessions.Get(summarizeCtx, sessionID)
785 if err != nil {
786 event = AgentEvent{
787 Type: AgentEventTypeError,
788 Error: fmt.Errorf("failed to get session: %w", err),
789 Done: true,
790 }
791
792 a.Publish(pubsub.CreatedEvent, event)
793 return
794 }
795 // Create a message in the new session with the summary
796 msg, err := a.messages.Create(summarizeCtx, oldSession.ID, message.CreateMessageParams{
797 Role: message.Assistant,
798 Parts: []message.ContentPart{
799 message.TextContent{Text: summary},
800 message.Finish{
801 Reason: message.FinishReasonEndTurn,
802 Time: time.Now().Unix(),
803 },
804 },
805 Model: a.summarizeProvider.Model().ID,
806 Provider: a.summarizeProviderID,
807 })
808 if err != nil {
809 event = AgentEvent{
810 Type: AgentEventTypeError,
811 Error: fmt.Errorf("failed to create summary message: %w", err),
812 Done: true,
813 }
814
815 a.Publish(pubsub.CreatedEvent, event)
816 return
817 }
818 oldSession.SummaryMessageID = msg.ID
819 oldSession.CompletionTokens = finalResponse.Usage.OutputTokens
820 oldSession.PromptTokens = 0
821 model := a.summarizeProvider.Model()
822 usage := finalResponse.Usage
823 cost := model.CostPer1MInCached/1e6*float64(usage.CacheCreationTokens) +
824 model.CostPer1MOutCached/1e6*float64(usage.CacheReadTokens) +
825 model.CostPer1MIn/1e6*float64(usage.InputTokens) +
826 model.CostPer1MOut/1e6*float64(usage.OutputTokens)
827 oldSession.Cost += cost
828 _, err = a.sessions.Save(summarizeCtx, oldSession)
829 if err != nil {
830 event = AgentEvent{
831 Type: AgentEventTypeError,
832 Error: fmt.Errorf("failed to save session: %w", err),
833 Done: true,
834 }
835 a.Publish(pubsub.CreatedEvent, event)
836 }
837
838 event = AgentEvent{
839 Type: AgentEventTypeSummarize,
840 SessionID: oldSession.ID,
841 Progress: "Summary complete",
842 Done: true,
843 }
844 a.Publish(pubsub.CreatedEvent, event)
845 // Send final success event with the new session ID
846 }()
847
848 return nil
849}
850
851func (a *agent) CancelAll() {
852 if !a.IsBusy() {
853 return
854 }
855 for key := range a.activeRequests.Seq2() {
856 a.Cancel(key) // key is sessionID
857 }
858
859 timeout := time.After(5 * time.Second)
860 for a.IsBusy() {
861 select {
862 case <-timeout:
863 return
864 default:
865 time.Sleep(200 * time.Millisecond)
866 }
867 }
868}
869
870func (a *agent) UpdateModel() error {
871 cfg := config.Get()
872
873 // Get current provider configuration
874 currentProviderCfg := cfg.GetProviderForModel(a.agentCfg.Model)
875 if currentProviderCfg == nil || currentProviderCfg.ID == "" {
876 return fmt.Errorf("provider for agent %s not found in config", a.agentCfg.Name)
877 }
878
879 // Check if provider has changed
880 if string(currentProviderCfg.ID) != a.providerID {
881 // Provider changed, need to recreate the main provider
882 model := cfg.GetModelByType(a.agentCfg.Model)
883 if model.ID == "" {
884 return fmt.Errorf("model not found for agent %s", a.agentCfg.Name)
885 }
886
887 promptID := agentPromptMap[a.agentCfg.ID]
888 if promptID == "" {
889 promptID = prompt.PromptDefault
890 }
891
892 opts := []provider.ProviderClientOption{
893 provider.WithModel(a.agentCfg.Model),
894 provider.WithSystemMessage(prompt.GetPrompt(promptID, currentProviderCfg.ID, cfg.Options.ContextPaths...)),
895 }
896
897 newProvider, err := provider.NewProvider(*currentProviderCfg, opts...)
898 if err != nil {
899 return fmt.Errorf("failed to create new provider: %w", err)
900 }
901
902 // Update the provider and provider ID
903 a.provider = newProvider
904 a.providerID = string(currentProviderCfg.ID)
905 }
906
907 // Check if small model provider has changed (affects title and summarize providers)
908 smallModelCfg := cfg.Models[config.SelectedModelTypeSmall]
909 var smallModelProviderCfg config.ProviderConfig
910
911 for p := range cfg.Providers.Seq() {
912 if p.ID == smallModelCfg.Provider {
913 smallModelProviderCfg = p
914 break
915 }
916 }
917
918 if smallModelProviderCfg.ID == "" {
919 return fmt.Errorf("provider %s not found in config", smallModelCfg.Provider)
920 }
921
922 // Check if summarize provider has changed
923 if string(smallModelProviderCfg.ID) != a.summarizeProviderID {
924 smallModel := cfg.GetModelByType(config.SelectedModelTypeSmall)
925 if smallModel == nil {
926 return fmt.Errorf("model %s not found in provider %s", smallModelCfg.Model, smallModelProviderCfg.ID)
927 }
928
929 // Recreate title provider
930 titleOpts := []provider.ProviderClientOption{
931 provider.WithModel(config.SelectedModelTypeSmall),
932 provider.WithSystemMessage(prompt.GetPrompt(prompt.PromptTitle, smallModelProviderCfg.ID)),
933 // We want the title to be short, so we limit the max tokens
934 provider.WithMaxTokens(40),
935 }
936 newTitleProvider, err := provider.NewProvider(smallModelProviderCfg, titleOpts...)
937 if err != nil {
938 return fmt.Errorf("failed to create new title provider: %w", err)
939 }
940
941 // Recreate summarize provider
942 summarizeOpts := []provider.ProviderClientOption{
943 provider.WithModel(config.SelectedModelTypeSmall),
944 provider.WithSystemMessage(prompt.GetPrompt(prompt.PromptSummarizer, smallModelProviderCfg.ID)),
945 }
946 newSummarizeProvider, err := provider.NewProvider(smallModelProviderCfg, summarizeOpts...)
947 if err != nil {
948 return fmt.Errorf("failed to create new summarize provider: %w", err)
949 }
950
951 // Update the providers and provider ID
952 a.titleProvider = newTitleProvider
953 a.summarizeProvider = newSummarizeProvider
954 a.summarizeProviderID = string(smallModelProviderCfg.ID)
955 }
956
957 return nil
958}