1package agent
2
3import (
4 "context"
5 "errors"
6 "fmt"
7 "log/slog"
8 "slices"
9 "strings"
10 "time"
11
12 "github.com/charmbracelet/catwalk/pkg/catwalk"
13 "github.com/charmbracelet/crush/internal/config"
14 "github.com/charmbracelet/crush/internal/csync"
15 "github.com/charmbracelet/crush/internal/history"
16 "github.com/charmbracelet/crush/internal/llm/prompt"
17 "github.com/charmbracelet/crush/internal/llm/provider"
18 "github.com/charmbracelet/crush/internal/llm/tools"
19 "github.com/charmbracelet/crush/internal/log"
20 "github.com/charmbracelet/crush/internal/lsp"
21 "github.com/charmbracelet/crush/internal/message"
22 "github.com/charmbracelet/crush/internal/permission"
23 "github.com/charmbracelet/crush/internal/pubsub"
24 "github.com/charmbracelet/crush/internal/session"
25 "github.com/charmbracelet/crush/internal/shell"
26)
27
// Sentinel errors reported by the agent; compare with errors.Is.
var (
	// ErrSessionBusy is returned by Run and Summarize when the session already
	// has a request in flight.
	ErrSessionBusy = errors.New("session is currently processing another request")

	// ErrRequestCancelled is the error carried by the final event of a Run
	// whose context was cancelled.
	ErrRequestCancelled = errors.New("request canceled by user")
)
33
// AgentEventType tags an AgentEvent with the kind of update it carries.
type AgentEventType string

// Kinds of AgentEvent emitted by the agent.
const (
	// AgentEventTypeError marks an event whose Error field describes a failure.
	AgentEventTypeError AgentEventType = "error"
	// AgentEventTypeResponse marks the final assistant response of a Run.
	AgentEventTypeResponse AgentEventType = "response"
	// AgentEventTypeSummarize marks a progress or completion update from Summarize.
	AgentEventTypeSummarize AgentEventType = "summarize"
)
41
42type AgentEvent struct {
43 Type AgentEventType
44 Message message.Message
45 Error error
46
47 // When summarizing
48 SessionID string
49 Progress string
50 Done bool
51}
52
53type Service interface {
54 pubsub.Suscriber[AgentEvent]
55 Model() catwalk.Model
56 Run(ctx context.Context, sessionID string, content string, attachments ...message.Attachment) (<-chan AgentEvent, error)
57 Cancel(sessionID string)
58 CancelAll()
59 IsSessionBusy(sessionID string) bool
60 IsBusy() bool
61 Summarize(ctx context.Context, sessionID string) error
62 UpdateModel() error
63}
64
65type agent struct {
66 *pubsub.Broker[AgentEvent]
67 agentCfg config.Agent
68 sessions session.Service
69 messages message.Service
70 mcpTools []McpTool
71
72 tools *csync.LazySlice[tools.BaseTool]
73
74 provider provider.Provider
75 providerID string
76
77 titleProvider provider.Provider
78 summarizeProvider provider.Provider
79 summarizeProviderID string
80
81 activeRequests *csync.Map[string, context.CancelFunc]
82}
83
84var agentPromptMap = map[string]prompt.PromptID{
85 "coder": prompt.PromptCoder,
86 "task": prompt.PromptTask,
87}
88
89func NewAgent(
90 ctx context.Context,
91 agentCfg config.Agent,
92 // These services are needed in the tools
93 permissions permission.Service,
94 sessions session.Service,
95 messages message.Service,
96 history history.Service,
97 lspClients map[string]*lsp.Client,
98) (Service, error) {
99 cfg := config.Get()
100
101 var agentTool tools.BaseTool
102 if agentCfg.ID == "coder" {
103 taskAgentCfg := config.Get().Agents["task"]
104 if taskAgentCfg.ID == "" {
105 return nil, fmt.Errorf("task agent not found in config")
106 }
107 taskAgent, err := NewAgent(ctx, taskAgentCfg, permissions, sessions, messages, history, lspClients)
108 if err != nil {
109 return nil, fmt.Errorf("failed to create task agent: %w", err)
110 }
111
112 agentTool = NewAgentTool(taskAgent, sessions, messages)
113 }
114
115 providerCfg := config.Get().GetProviderForModel(agentCfg.Model)
116 if providerCfg == nil {
117 return nil, fmt.Errorf("provider for agent %s not found in config", agentCfg.Name)
118 }
119 model := config.Get().GetModelByType(agentCfg.Model)
120
121 if model == nil {
122 return nil, fmt.Errorf("model not found for agent %s", agentCfg.Name)
123 }
124
125 promptID := agentPromptMap[agentCfg.ID]
126 if promptID == "" {
127 promptID = prompt.PromptDefault
128 }
129 opts := []provider.ProviderClientOption{
130 provider.WithModel(agentCfg.Model),
131 provider.WithSystemMessage(prompt.GetPrompt(promptID, providerCfg.ID, config.Get().Options.ContextPaths...)),
132 }
133 agentProvider, err := provider.NewProvider(*providerCfg, opts...)
134 if err != nil {
135 return nil, err
136 }
137
138 smallModelCfg := cfg.Models[config.SelectedModelTypeSmall]
139 var smallModelProviderCfg *config.ProviderConfig
140 if smallModelCfg.Provider == providerCfg.ID {
141 smallModelProviderCfg = providerCfg
142 } else {
143 smallModelProviderCfg = cfg.GetProviderForModel(config.SelectedModelTypeSmall)
144
145 if smallModelProviderCfg.ID == "" {
146 return nil, fmt.Errorf("provider %s not found in config", smallModelCfg.Provider)
147 }
148 }
149 smallModel := cfg.GetModelByType(config.SelectedModelTypeSmall)
150 if smallModel.ID == "" {
151 return nil, fmt.Errorf("model %s not found in provider %s", smallModelCfg.Model, smallModelProviderCfg.ID)
152 }
153
154 titleOpts := []provider.ProviderClientOption{
155 provider.WithModel(config.SelectedModelTypeSmall),
156 provider.WithSystemMessage(prompt.GetPrompt(prompt.PromptTitle, smallModelProviderCfg.ID)),
157 }
158 titleProvider, err := provider.NewProvider(*smallModelProviderCfg, titleOpts...)
159 if err != nil {
160 return nil, err
161 }
162 summarizeOpts := []provider.ProviderClientOption{
163 provider.WithModel(config.SelectedModelTypeSmall),
164 provider.WithSystemMessage(prompt.GetPrompt(prompt.PromptSummarizer, smallModelProviderCfg.ID)),
165 }
166 summarizeProvider, err := provider.NewProvider(*smallModelProviderCfg, summarizeOpts...)
167 if err != nil {
168 return nil, err
169 }
170
171 toolFn := func() []tools.BaseTool {
172 slog.Info("Initializing agent tools", "agent", agentCfg.ID)
173 defer func() {
174 slog.Info("Initialized agent tools", "agent", agentCfg.ID)
175 }()
176
177 cwd := cfg.WorkingDir()
178 allTools := []tools.BaseTool{
179 tools.NewBashTool(permissions, cwd),
180 tools.NewDownloadTool(permissions, cwd),
181 tools.NewEditTool(lspClients, permissions, history, cwd),
182 tools.NewMultiEditTool(lspClients, permissions, history, cwd),
183 tools.NewFetchTool(permissions, cwd),
184 tools.NewGlobTool(cwd),
185 tools.NewGrepTool(cwd),
186 tools.NewLsTool(permissions, cwd),
187 tools.NewSourcegraphTool(),
188 tools.NewViewTool(lspClients, permissions, cwd),
189 tools.NewWriteTool(lspClients, permissions, history, cwd),
190 }
191
192 mcpToolsOnce.Do(func() {
193 mcpTools = doGetMCPTools(ctx, permissions, cfg)
194 })
195 allTools = append(allTools, mcpTools...)
196
197 if len(lspClients) > 0 {
198 allTools = append(allTools, tools.NewDiagnosticsTool(lspClients))
199 }
200
201 if agentTool != nil {
202 allTools = append(allTools, agentTool)
203 }
204
205 if agentCfg.AllowedTools == nil {
206 return allTools
207 }
208
209 var filteredTools []tools.BaseTool
210 for _, tool := range allTools {
211 if slices.Contains(agentCfg.AllowedTools, tool.Name()) {
212 filteredTools = append(filteredTools, tool)
213 }
214 }
215 return filteredTools
216 }
217
218 return &agent{
219 Broker: pubsub.NewBroker[AgentEvent](),
220 agentCfg: agentCfg,
221 provider: agentProvider,
222 providerID: string(providerCfg.ID),
223 messages: messages,
224 sessions: sessions,
225 titleProvider: titleProvider,
226 summarizeProvider: summarizeProvider,
227 summarizeProviderID: string(smallModelProviderCfg.ID),
228 activeRequests: csync.NewMap[string, context.CancelFunc](),
229 tools: csync.NewLazySlice(toolFn),
230 }, nil
231}
232
233func (a *agent) Model() catwalk.Model {
234 return *config.Get().GetModelByType(a.agentCfg.Model)
235}
236
237func (a *agent) Cancel(sessionID string) {
238 // Cancel regular requests
239 if cancel, ok := a.activeRequests.Take(sessionID); ok && cancel != nil {
240 slog.Info("Request cancellation initiated", "session_id", sessionID)
241 cancel()
242 }
243
244 // Also check for summarize requests
245 if cancel, ok := a.activeRequests.Take(sessionID + "-summarize"); ok && cancel != nil {
246 slog.Info("Summarize cancellation initiated", "session_id", sessionID)
247 cancel()
248 }
249}
250
251func (a *agent) IsBusy() bool {
252 var busy bool
253 for cancelFunc := range a.activeRequests.Seq() {
254 if cancelFunc != nil {
255 busy = true
256 break
257 }
258 }
259 return busy
260}
261
262func (a *agent) IsSessionBusy(sessionID string) bool {
263 _, busy := a.activeRequests.Get(sessionID)
264 return busy
265}
266
267func (a *agent) generateTitle(ctx context.Context, sessionID string, content string) error {
268 if content == "" {
269 return nil
270 }
271 if a.titleProvider == nil {
272 return nil
273 }
274 session, err := a.sessions.Get(ctx, sessionID)
275 if err != nil {
276 return err
277 }
278 parts := []message.ContentPart{message.TextContent{
279 Text: fmt.Sprintf("Generate a concise title for the following content:\n\n%s", content),
280 }}
281
282 // Use streaming approach like summarization
283 response := a.titleProvider.StreamResponse(
284 ctx,
285 []message.Message{
286 {
287 Role: message.User,
288 Parts: parts,
289 },
290 },
291 nil,
292 )
293
294 var finalResponse *provider.ProviderResponse
295 for r := range response {
296 if r.Error != nil {
297 return r.Error
298 }
299 finalResponse = r.Response
300 }
301
302 if finalResponse == nil {
303 return fmt.Errorf("no response received from title provider")
304 }
305
306 title := strings.TrimSpace(strings.ReplaceAll(finalResponse.Content, "\n", " "))
307 if title == "" {
308 return nil
309 }
310
311 session.Title = title
312 _, err = a.sessions.Save(ctx, session)
313 return err
314}
315
316func (a *agent) err(err error) AgentEvent {
317 return AgentEvent{
318 Type: AgentEventTypeError,
319 Error: err,
320 }
321}
322
323func (a *agent) Run(ctx context.Context, sessionID string, content string, attachments ...message.Attachment) (<-chan AgentEvent, error) {
324 if !a.Model().SupportsImages && attachments != nil {
325 attachments = nil
326 }
327 events := make(chan AgentEvent)
328 if a.IsSessionBusy(sessionID) {
329 return nil, ErrSessionBusy
330 }
331
332 genCtx, cancel := context.WithCancel(ctx)
333
334 a.activeRequests.Set(sessionID, cancel)
335 go func() {
336 slog.Debug("Request started", "sessionID", sessionID)
337 defer log.RecoverPanic("agent.Run", func() {
338 events <- a.err(fmt.Errorf("panic while running the agent"))
339 })
340 var attachmentParts []message.ContentPart
341 for _, attachment := range attachments {
342 attachmentParts = append(attachmentParts, message.BinaryContent{Path: attachment.FilePath, MIMEType: attachment.MimeType, Data: attachment.Content})
343 }
344 result := a.processGeneration(genCtx, sessionID, content, attachmentParts)
345 if result.Error != nil && !errors.Is(result.Error, ErrRequestCancelled) && !errors.Is(result.Error, context.Canceled) {
346 slog.Error(result.Error.Error())
347 }
348 slog.Debug("Request completed", "sessionID", sessionID)
349 a.activeRequests.Del(sessionID)
350 cancel()
351 a.Publish(pubsub.CreatedEvent, result)
352 events <- result
353 close(events)
354 }()
355 return events, nil
356}
357
358func (a *agent) processGeneration(ctx context.Context, sessionID, content string, attachmentParts []message.ContentPart) AgentEvent {
359 cfg := config.Get()
360 // List existing messages; if none, start title generation asynchronously.
361 msgs, err := a.messages.List(ctx, sessionID)
362 if err != nil {
363 return a.err(fmt.Errorf("failed to list messages: %w", err))
364 }
365
366 // Implement sliding window to limit message history
367 maxMessages := cfg.Options.MaxMessagesInContext
368 if maxMessages > 0 && len(msgs) > maxMessages {
369 // Keep the first message (usually system/context) and the last N-1 messages
370 msgs = append(msgs[:1], msgs[len(msgs)-maxMessages+1:]...)
371 }
372 if len(msgs) == 0 {
373 go func() {
374 defer log.RecoverPanic("agent.Run", func() {
375 slog.Error("panic while generating title")
376 })
377 titleErr := a.generateTitle(context.Background(), sessionID, content)
378 if titleErr != nil && !errors.Is(titleErr, context.Canceled) && !errors.Is(titleErr, context.DeadlineExceeded) {
379 slog.Error("failed to generate title", "error", titleErr)
380 }
381 }()
382 }
383 session, err := a.sessions.Get(ctx, sessionID)
384 if err != nil {
385 return a.err(fmt.Errorf("failed to get session: %w", err))
386 }
387 if session.SummaryMessageID != "" {
388 summaryMsgInex := -1
389 for i, msg := range msgs {
390 if msg.ID == session.SummaryMessageID {
391 summaryMsgInex = i
392 break
393 }
394 }
395 if summaryMsgInex != -1 {
396 msgs = msgs[summaryMsgInex:]
397 msgs[0].Role = message.User
398 }
399 }
400
401 userMsg, err := a.createUserMessage(ctx, sessionID, content, attachmentParts)
402 if err != nil {
403 return a.err(fmt.Errorf("failed to create user message: %w", err))
404 }
405 // Append the new user message to the conversation history.
406 msgHistory := append(msgs, userMsg)
407
408 for {
409 // Check for cancellation before each iteration
410 select {
411 case <-ctx.Done():
412 return a.err(ctx.Err())
413 default:
414 // Continue processing
415 }
416 agentMessage, toolResults, err := a.streamAndHandleEvents(ctx, sessionID, msgHistory)
417 if err != nil {
418 if errors.Is(err, context.Canceled) {
419 agentMessage.AddFinish(message.FinishReasonCanceled, "Request cancelled", "")
420 a.messages.Update(context.Background(), agentMessage)
421 return a.err(ErrRequestCancelled)
422 }
423 return a.err(fmt.Errorf("failed to process events: %w", err))
424 }
425 if cfg.Options.Debug {
426 slog.Info("Result", "message", agentMessage.FinishReason(), "toolResults", toolResults)
427 }
428 if (agentMessage.FinishReason() == message.FinishReasonToolUse) && toolResults != nil {
429 // We are not done, we need to respond with the tool response
430 msgHistory = append(msgHistory, agentMessage, *toolResults)
431 continue
432 }
433 if agentMessage.FinishReason() == "" {
434 // Kujtim: could not track down where this is happening but this means its cancelled
435 agentMessage.AddFinish(message.FinishReasonCanceled, "Request cancelled", "")
436 _ = a.messages.Update(context.Background(), agentMessage)
437 return a.err(ErrRequestCancelled)
438 }
439 return AgentEvent{
440 Type: AgentEventTypeResponse,
441 Message: agentMessage,
442 Done: true,
443 }
444 }
445}
446
447func (a *agent) createUserMessage(ctx context.Context, sessionID, content string, attachmentParts []message.ContentPart) (message.Message, error) {
448 parts := []message.ContentPart{message.TextContent{Text: content}}
449 parts = append(parts, attachmentParts...)
450 return a.messages.Create(ctx, sessionID, message.CreateMessageParams{
451 Role: message.User,
452 Parts: parts,
453 })
454}
455
456func (a *agent) streamAndHandleEvents(ctx context.Context, sessionID string, msgHistory []message.Message) (message.Message, *message.Message, error) {
457 ctx = context.WithValue(ctx, tools.SessionIDContextKey, sessionID)
458 eventChan := a.provider.StreamResponse(ctx, msgHistory, slices.Collect(a.tools.Seq()))
459
460 assistantMsg, err := a.messages.Create(ctx, sessionID, message.CreateMessageParams{
461 Role: message.Assistant,
462 Parts: []message.ContentPart{},
463 Model: a.Model().ID,
464 Provider: a.providerID,
465 })
466 if err != nil {
467 return assistantMsg, nil, fmt.Errorf("failed to create assistant message: %w", err)
468 }
469
470 // Add the session and message ID into the context if needed by tools.
471 ctx = context.WithValue(ctx, tools.MessageIDContextKey, assistantMsg.ID)
472
473 // Process each event in the stream.
474 for event := range eventChan {
475 if processErr := a.processEvent(ctx, sessionID, &assistantMsg, event); processErr != nil {
476 if errors.Is(processErr, context.Canceled) {
477 a.finishMessage(context.Background(), &assistantMsg, message.FinishReasonCanceled, "Request cancelled", "")
478 } else {
479 a.finishMessage(ctx, &assistantMsg, message.FinishReasonError, "API Error", processErr.Error())
480 }
481 return assistantMsg, nil, processErr
482 }
483 if ctx.Err() != nil {
484 a.finishMessage(context.Background(), &assistantMsg, message.FinishReasonCanceled, "Request cancelled", "")
485 return assistantMsg, nil, ctx.Err()
486 }
487 }
488
489 toolResults := make([]message.ToolResult, len(assistantMsg.ToolCalls()))
490 toolCalls := assistantMsg.ToolCalls()
491 for i, toolCall := range toolCalls {
492 select {
493 case <-ctx.Done():
494 a.finishMessage(context.Background(), &assistantMsg, message.FinishReasonCanceled, "Request cancelled", "")
495 // Make all future tool calls cancelled
496 for j := i; j < len(toolCalls); j++ {
497 toolResults[j] = message.ToolResult{
498 ToolCallID: toolCalls[j].ID,
499 Content: "Tool execution canceled by user",
500 IsError: true,
501 }
502 }
503 goto out
504 default:
505 // Continue processing
506 var tool tools.BaseTool
507 for availableTool := range a.tools.Seq() {
508 if availableTool.Info().Name == toolCall.Name {
509 tool = availableTool
510 break
511 }
512 }
513
514 // Tool not found
515 if tool == nil {
516 toolResults[i] = message.ToolResult{
517 ToolCallID: toolCall.ID,
518 Content: fmt.Sprintf("Tool not found: %s", toolCall.Name),
519 IsError: true,
520 }
521 continue
522 }
523
524 // Run tool in goroutine to allow cancellation
525 type toolExecResult struct {
526 response tools.ToolResponse
527 err error
528 }
529 resultChan := make(chan toolExecResult, 1)
530
531 go func() {
532 response, err := tool.Run(ctx, tools.ToolCall{
533 ID: toolCall.ID,
534 Name: toolCall.Name,
535 Input: toolCall.Input,
536 })
537 resultChan <- toolExecResult{response: response, err: err}
538 }()
539
540 var toolResponse tools.ToolResponse
541 var toolErr error
542
543 select {
544 case <-ctx.Done():
545 a.finishMessage(context.Background(), &assistantMsg, message.FinishReasonCanceled, "Request cancelled", "")
546 // Mark remaining tool calls as cancelled
547 for j := i; j < len(toolCalls); j++ {
548 toolResults[j] = message.ToolResult{
549 ToolCallID: toolCalls[j].ID,
550 Content: "Tool execution canceled by user",
551 IsError: true,
552 }
553 }
554 goto out
555 case result := <-resultChan:
556 toolResponse = result.response
557 toolErr = result.err
558 }
559
560 if toolErr != nil {
561 slog.Error("Tool execution error", "toolCall", toolCall.ID, "error", toolErr)
562 if errors.Is(toolErr, permission.ErrorPermissionDenied) {
563 toolResults[i] = message.ToolResult{
564 ToolCallID: toolCall.ID,
565 Content: "Permission denied",
566 IsError: true,
567 }
568 for j := i + 1; j < len(toolCalls); j++ {
569 toolResults[j] = message.ToolResult{
570 ToolCallID: toolCalls[j].ID,
571 Content: "Tool execution canceled by user",
572 IsError: true,
573 }
574 }
575 a.finishMessage(ctx, &assistantMsg, message.FinishReasonPermissionDenied, "Permission denied", "")
576 break
577 }
578 }
579 toolResults[i] = message.ToolResult{
580 ToolCallID: toolCall.ID,
581 Content: toolResponse.Content,
582 Metadata: toolResponse.Metadata,
583 IsError: toolResponse.IsError,
584 }
585 }
586 }
587out:
588 if len(toolResults) == 0 {
589 return assistantMsg, nil, nil
590 }
591 parts := make([]message.ContentPart, 0)
592 for _, tr := range toolResults {
593 parts = append(parts, tr)
594 }
595 msg, err := a.messages.Create(context.Background(), assistantMsg.SessionID, message.CreateMessageParams{
596 Role: message.Tool,
597 Parts: parts,
598 Provider: a.providerID,
599 })
600 if err != nil {
601 return assistantMsg, nil, fmt.Errorf("failed to create cancelled tool message: %w", err)
602 }
603
604 return assistantMsg, &msg, err
605}
606
607func (a *agent) finishMessage(ctx context.Context, msg *message.Message, finishReason message.FinishReason, message, details string) {
608 msg.AddFinish(finishReason, message, details)
609 _ = a.messages.Update(ctx, *msg)
610}
611
612func (a *agent) processEvent(ctx context.Context, sessionID string, assistantMsg *message.Message, event provider.ProviderEvent) error {
613 select {
614 case <-ctx.Done():
615 return ctx.Err()
616 default:
617 // Continue processing.
618 }
619
620 switch event.Type {
621 case provider.EventThinkingDelta:
622 assistantMsg.AppendReasoningContent(event.Thinking)
623 return a.messages.Update(ctx, *assistantMsg)
624 case provider.EventSignatureDelta:
625 assistantMsg.AppendReasoningSignature(event.Signature)
626 return a.messages.Update(ctx, *assistantMsg)
627 case provider.EventContentDelta:
628 assistantMsg.FinishThinking()
629 assistantMsg.AppendContent(event.Content)
630 return a.messages.Update(ctx, *assistantMsg)
631 case provider.EventToolUseStart:
632 assistantMsg.FinishThinking()
633 slog.Info("Tool call started", "toolCall", event.ToolCall)
634 assistantMsg.AddToolCall(*event.ToolCall)
635 return a.messages.Update(ctx, *assistantMsg)
636 case provider.EventToolUseDelta:
637 assistantMsg.AppendToolCallInput(event.ToolCall.ID, event.ToolCall.Input)
638 return a.messages.Update(ctx, *assistantMsg)
639 case provider.EventToolUseStop:
640 slog.Info("Finished tool call", "toolCall", event.ToolCall)
641 assistantMsg.FinishToolCall(event.ToolCall.ID)
642 return a.messages.Update(ctx, *assistantMsg)
643 case provider.EventError:
644 return event.Error
645 case provider.EventComplete:
646 assistantMsg.FinishThinking()
647 assistantMsg.SetToolCalls(event.Response.ToolCalls)
648 assistantMsg.AddFinish(event.Response.FinishReason, "", "")
649 if err := a.messages.Update(ctx, *assistantMsg); err != nil {
650 return fmt.Errorf("failed to update message: %w", err)
651 }
652 return a.TrackUsage(ctx, sessionID, a.Model(), event.Response.Usage)
653 }
654
655 return nil
656}
657
658func (a *agent) TrackUsage(ctx context.Context, sessionID string, model catwalk.Model, usage provider.TokenUsage) error {
659 sess, err := a.sessions.Get(ctx, sessionID)
660 if err != nil {
661 return fmt.Errorf("failed to get session: %w", err)
662 }
663
664 cost := model.CostPer1MInCached/1e6*float64(usage.CacheCreationTokens) +
665 model.CostPer1MOutCached/1e6*float64(usage.CacheReadTokens) +
666 model.CostPer1MIn/1e6*float64(usage.InputTokens) +
667 model.CostPer1MOut/1e6*float64(usage.OutputTokens)
668
669 sess.Cost += cost
670 sess.CompletionTokens = usage.OutputTokens + usage.CacheReadTokens
671 sess.PromptTokens = usage.InputTokens + usage.CacheCreationTokens
672
673 _, err = a.sessions.Save(ctx, sess)
674 if err != nil {
675 return fmt.Errorf("failed to save session: %w", err)
676 }
677 return nil
678}
679
680func (a *agent) Summarize(ctx context.Context, sessionID string) error {
681 if a.summarizeProvider == nil {
682 return fmt.Errorf("summarize provider not available")
683 }
684
685 // Check if session is busy
686 if a.IsSessionBusy(sessionID) {
687 return ErrSessionBusy
688 }
689
690 // Create a new context with cancellation
691 summarizeCtx, cancel := context.WithCancel(ctx)
692
693 // Store the cancel function in activeRequests to allow cancellation
694 a.activeRequests.Set(sessionID+"-summarize", cancel)
695
696 go func() {
697 defer a.activeRequests.Del(sessionID + "-summarize")
698 defer cancel()
699 event := AgentEvent{
700 Type: AgentEventTypeSummarize,
701 Progress: "Starting summarization...",
702 }
703
704 a.Publish(pubsub.CreatedEvent, event)
705 // Get all messages from the session
706 msgs, err := a.messages.List(summarizeCtx, sessionID)
707 if err != nil {
708 event = AgentEvent{
709 Type: AgentEventTypeError,
710 Error: fmt.Errorf("failed to list messages: %w", err),
711 Done: true,
712 }
713 a.Publish(pubsub.CreatedEvent, event)
714 return
715 }
716 summarizeCtx = context.WithValue(summarizeCtx, tools.SessionIDContextKey, sessionID)
717
718 if len(msgs) == 0 {
719 event = AgentEvent{
720 Type: AgentEventTypeError,
721 Error: fmt.Errorf("no messages to summarize"),
722 Done: true,
723 }
724 a.Publish(pubsub.CreatedEvent, event)
725 return
726 }
727
728 event = AgentEvent{
729 Type: AgentEventTypeSummarize,
730 Progress: "Analyzing conversation...",
731 }
732 a.Publish(pubsub.CreatedEvent, event)
733
734 // Add a system message to guide the summarization
735 summarizePrompt := "Provide a detailed but concise summary of our conversation above. Focus on information that would be helpful for continuing the conversation, including what we did, what we're doing, which files we're working on, and what we're going to do next."
736
737 // Create a new message with the summarize prompt
738 promptMsg := message.Message{
739 Role: message.User,
740 Parts: []message.ContentPart{message.TextContent{Text: summarizePrompt}},
741 }
742
743 // Append the prompt to the messages
744 msgsWithPrompt := append(msgs, promptMsg)
745
746 event = AgentEvent{
747 Type: AgentEventTypeSummarize,
748 Progress: "Generating summary...",
749 }
750
751 a.Publish(pubsub.CreatedEvent, event)
752
753 // Send the messages to the summarize provider
754 response := a.summarizeProvider.StreamResponse(
755 summarizeCtx,
756 msgsWithPrompt,
757 nil,
758 )
759 var finalResponse *provider.ProviderResponse
760 for r := range response {
761 if r.Error != nil {
762 event = AgentEvent{
763 Type: AgentEventTypeError,
764 Error: fmt.Errorf("failed to summarize: %w", err),
765 Done: true,
766 }
767 a.Publish(pubsub.CreatedEvent, event)
768 return
769 }
770 finalResponse = r.Response
771 }
772
773 summary := strings.TrimSpace(finalResponse.Content)
774 if summary == "" {
775 event = AgentEvent{
776 Type: AgentEventTypeError,
777 Error: fmt.Errorf("empty summary returned"),
778 Done: true,
779 }
780 a.Publish(pubsub.CreatedEvent, event)
781 return
782 }
783 shell := shell.GetPersistentShell(config.Get().WorkingDir())
784 summary += "\n\n**Current working directory of the persistent shell**\n\n" + shell.GetWorkingDir()
785 event = AgentEvent{
786 Type: AgentEventTypeSummarize,
787 Progress: "Creating new session...",
788 }
789
790 a.Publish(pubsub.CreatedEvent, event)
791 oldSession, err := a.sessions.Get(summarizeCtx, sessionID)
792 if err != nil {
793 event = AgentEvent{
794 Type: AgentEventTypeError,
795 Error: fmt.Errorf("failed to get session: %w", err),
796 Done: true,
797 }
798
799 a.Publish(pubsub.CreatedEvent, event)
800 return
801 }
802 // Create a message in the new session with the summary
803 msg, err := a.messages.Create(summarizeCtx, oldSession.ID, message.CreateMessageParams{
804 Role: message.Assistant,
805 Parts: []message.ContentPart{
806 message.TextContent{Text: summary},
807 message.Finish{
808 Reason: message.FinishReasonEndTurn,
809 Time: time.Now().Unix(),
810 },
811 },
812 Model: a.summarizeProvider.Model().ID,
813 Provider: a.summarizeProviderID,
814 })
815 if err != nil {
816 event = AgentEvent{
817 Type: AgentEventTypeError,
818 Error: fmt.Errorf("failed to create summary message: %w", err),
819 Done: true,
820 }
821
822 a.Publish(pubsub.CreatedEvent, event)
823 return
824 }
825 oldSession.SummaryMessageID = msg.ID
826 oldSession.CompletionTokens = finalResponse.Usage.OutputTokens
827 oldSession.PromptTokens = 0
828 model := a.summarizeProvider.Model()
829 usage := finalResponse.Usage
830 cost := model.CostPer1MInCached/1e6*float64(usage.CacheCreationTokens) +
831 model.CostPer1MOutCached/1e6*float64(usage.CacheReadTokens) +
832 model.CostPer1MIn/1e6*float64(usage.InputTokens) +
833 model.CostPer1MOut/1e6*float64(usage.OutputTokens)
834 oldSession.Cost += cost
835 _, err = a.sessions.Save(summarizeCtx, oldSession)
836 if err != nil {
837 event = AgentEvent{
838 Type: AgentEventTypeError,
839 Error: fmt.Errorf("failed to save session: %w", err),
840 Done: true,
841 }
842 a.Publish(pubsub.CreatedEvent, event)
843 }
844
845 event = AgentEvent{
846 Type: AgentEventTypeSummarize,
847 SessionID: oldSession.ID,
848 Progress: "Summary complete",
849 Done: true,
850 }
851 a.Publish(pubsub.CreatedEvent, event)
852 // Send final success event with the new session ID
853 }()
854
855 return nil
856}
857
858func (a *agent) CancelAll() {
859 if !a.IsBusy() {
860 return
861 }
862 for key := range a.activeRequests.Seq2() {
863 a.Cancel(key) // key is sessionID
864 }
865
866 timeout := time.After(5 * time.Second)
867 for a.IsBusy() {
868 select {
869 case <-timeout:
870 return
871 default:
872 time.Sleep(200 * time.Millisecond)
873 }
874 }
875}
876
877func (a *agent) UpdateModel() error {
878 cfg := config.Get()
879
880 // Get current provider configuration
881 currentProviderCfg := cfg.GetProviderForModel(a.agentCfg.Model)
882 if currentProviderCfg == nil || currentProviderCfg.ID == "" {
883 return fmt.Errorf("provider for agent %s not found in config", a.agentCfg.Name)
884 }
885
886 // Check if provider has changed
887 if string(currentProviderCfg.ID) != a.providerID {
888 // Provider changed, need to recreate the main provider
889 model := cfg.GetModelByType(a.agentCfg.Model)
890 if model.ID == "" {
891 return fmt.Errorf("model not found for agent %s", a.agentCfg.Name)
892 }
893
894 promptID := agentPromptMap[a.agentCfg.ID]
895 if promptID == "" {
896 promptID = prompt.PromptDefault
897 }
898
899 opts := []provider.ProviderClientOption{
900 provider.WithModel(a.agentCfg.Model),
901 provider.WithSystemMessage(prompt.GetPrompt(promptID, currentProviderCfg.ID, cfg.Options.ContextPaths...)),
902 }
903
904 newProvider, err := provider.NewProvider(*currentProviderCfg, opts...)
905 if err != nil {
906 return fmt.Errorf("failed to create new provider: %w", err)
907 }
908
909 // Update the provider and provider ID
910 a.provider = newProvider
911 a.providerID = string(currentProviderCfg.ID)
912 }
913
914 // Check if small model provider has changed (affects title and summarize providers)
915 smallModelCfg := cfg.Models[config.SelectedModelTypeSmall]
916 var smallModelProviderCfg config.ProviderConfig
917
918 for p := range cfg.Providers.Seq() {
919 if p.ID == smallModelCfg.Provider {
920 smallModelProviderCfg = p
921 break
922 }
923 }
924
925 if smallModelProviderCfg.ID == "" {
926 return fmt.Errorf("provider %s not found in config", smallModelCfg.Provider)
927 }
928
929 // Check if summarize provider has changed
930 if string(smallModelProviderCfg.ID) != a.summarizeProviderID {
931 smallModel := cfg.GetModelByType(config.SelectedModelTypeSmall)
932 if smallModel == nil {
933 return fmt.Errorf("model %s not found in provider %s", smallModelCfg.Model, smallModelProviderCfg.ID)
934 }
935
936 // Recreate title provider
937 titleOpts := []provider.ProviderClientOption{
938 provider.WithModel(config.SelectedModelTypeSmall),
939 provider.WithSystemMessage(prompt.GetPrompt(prompt.PromptTitle, smallModelProviderCfg.ID)),
940 // We want the title to be short, so we limit the max tokens
941 provider.WithMaxTokens(40),
942 }
943 newTitleProvider, err := provider.NewProvider(smallModelProviderCfg, titleOpts...)
944 if err != nil {
945 return fmt.Errorf("failed to create new title provider: %w", err)
946 }
947
948 // Recreate summarize provider
949 summarizeOpts := []provider.ProviderClientOption{
950 provider.WithModel(config.SelectedModelTypeSmall),
951 provider.WithSystemMessage(prompt.GetPrompt(prompt.PromptSummarizer, smallModelProviderCfg.ID)),
952 }
953 newSummarizeProvider, err := provider.NewProvider(smallModelProviderCfg, summarizeOpts...)
954 if err != nil {
955 return fmt.Errorf("failed to create new summarize provider: %w", err)
956 }
957
958 // Update the providers and provider ID
959 a.titleProvider = newTitleProvider
960 a.summarizeProvider = newSummarizeProvider
961 a.summarizeProviderID = string(smallModelProviderCfg.ID)
962 }
963
964 return nil
965}