agent.go

  1package agent
  2
  3import (
  4	"context"
  5	"errors"
  6	"fmt"
  7	"log/slog"
  8	"slices"
  9	"strings"
 10	"time"
 11
 12	"github.com/charmbracelet/catwalk/pkg/catwalk"
 13	"github.com/charmbracelet/crush/internal/config"
 14	"github.com/charmbracelet/crush/internal/csync"
 15	"github.com/charmbracelet/crush/internal/history"
 16	"github.com/charmbracelet/crush/internal/llm/prompt"
 17	"github.com/charmbracelet/crush/internal/llm/provider"
 18	"github.com/charmbracelet/crush/internal/llm/tools"
 19	"github.com/charmbracelet/crush/internal/log"
 20	"github.com/charmbracelet/crush/internal/lsp"
 21	"github.com/charmbracelet/crush/internal/message"
 22	"github.com/charmbracelet/crush/internal/permission"
 23	"github.com/charmbracelet/crush/internal/pubsub"
 24	"github.com/charmbracelet/crush/internal/session"
 25	"github.com/charmbracelet/crush/internal/shell"
 26)
 27
// Common errors returned by the agent.
var (
	// ErrRequestCancelled is the error carried by the event that
	// processGeneration returns when a request is cancelled.
	ErrRequestCancelled = errors.New("request canceled by user")
	// ErrSessionBusy is returned by Run and Summarize when IsSessionBusy
	// reports that the session already has an active request.
	ErrSessionBusy      = errors.New("session is currently processing another request")
)
 33
// AgentEventType identifies the kind of event carried by an AgentEvent.
type AgentEventType string

const (
	// AgentEventTypeError marks an event whose Error field is populated.
	AgentEventTypeError     AgentEventType = "error"
	// AgentEventTypeResponse marks the final assistant response of a Run.
	AgentEventTypeResponse  AgentEventType = "response"
	// AgentEventTypeSummarize marks progress and completion events published
	// by Summarize.
	AgentEventTypeSummarize AgentEventType = "summarize"
)
 41
// AgentEvent is the value the agent publishes through its broker and sends on
// the channel returned by Run.
type AgentEvent struct {
	// Type says which of the fields below are meaningful.
	Type    AgentEventType
	// Message is the assistant message for AgentEventTypeResponse events.
	Message message.Message
	// Error is set for AgentEventTypeError events.
	Error   error

	// When summarizing
	// SessionID is set on the final "Summary complete" event.
	SessionID string
	// Progress is a human-readable status line for summarize events.
	Progress  string
	// Done is true on terminal events: the final response of a run, the last
	// summarize event, and errors published by Summarize.
	Done      bool
}
 52
// Service is the public surface of an agent. It is implemented by the
// unexported agent type in this file.
type Service interface {
	// Subscription to the AgentEvent values the agent publishes.
	pubsub.Suscriber[AgentEvent]
	// Model returns the model configured for this agent.
	Model() catwalk.Model
	// Run starts processing content for the session in the background and
	// returns a channel on which the resulting event is delivered. It returns
	// ErrSessionBusy when the session already has an active request.
	Run(ctx context.Context, sessionID string, content string, attachments ...message.Attachment) (<-chan AgentEvent, error)
	// Cancel cancels the active request and any running summarization of the
	// session.
	Cancel(sessionID string)
	// CancelAll cancels every active request and waits a bounded time for the
	// agent to become idle.
	CancelAll()
	// IsSessionBusy reports whether the session has an active request.
	IsSessionBusy(sessionID string) bool
	// IsBusy reports whether any request is active.
	IsBusy() bool
	// Summarize starts summarizing the session in the background; progress and
	// errors are published as events.
	Summarize(ctx context.Context, sessionID string) error
	// UpdateModel re-creates the providers whose configuration changed.
	UpdateModel() error
}
 64
// agent is the Service implementation. It owns the providers used for
// generation, title generation and summarization, and tracks in-flight
// requests so they can be cancelled.
type agent struct {
	// Embedded broker used to publish AgentEvent values.
	*pubsub.Broker[AgentEvent]
	// agentCfg is the configuration this agent was built from.
	agentCfg config.Agent
	sessions session.Service
	messages message.Service
	// NOTE(review): mcpTools is not read or written anywhere in the visible
	// part of this file — confirm whether it is still needed.
	mcpTools []McpTool

	// tools is built lazily by the function assembled in NewAgent.
	tools *csync.LazySlice[tools.BaseTool]

	// provider serves the main generation requests; providerID is recorded on
	// created messages and compared in UpdateModel to detect a provider change.
	provider   provider.Provider
	providerID string

	// titleProvider and summarizeProvider are both built on the "small" model.
	titleProvider       provider.Provider
	summarizeProvider   provider.Provider
	summarizeProviderID string

	// activeRequests maps a session ID (or "<sessionID>-summarize") to the
	// cancel function of the request running for it.
	activeRequests *csync.Map[string, context.CancelFunc]
}
 83
// agentPromptMap maps an agent ID to the system prompt it uses. Agents whose
// ID is missing here fall back to prompt.PromptDefault in NewAgent and
// UpdateModel.
var agentPromptMap = map[string]prompt.PromptID{
	"coder": prompt.PromptCoder,
	"task":  prompt.PromptTask,
}
 88
 89func NewAgent(
 90	ctx context.Context,
 91	agentCfg config.Agent,
 92	// These services are needed in the tools
 93	permissions permission.Service,
 94	sessions session.Service,
 95	messages message.Service,
 96	history history.Service,
 97	lspClients map[string]*lsp.Client,
 98) (Service, error) {
 99	cfg := config.Get()
100
101	var agentTool tools.BaseTool
102	if agentCfg.ID == "coder" {
103		taskAgentCfg := config.Get().Agents["task"]
104		if taskAgentCfg.ID == "" {
105			return nil, fmt.Errorf("task agent not found in config")
106		}
107		taskAgent, err := NewAgent(ctx, taskAgentCfg, permissions, sessions, messages, history, lspClients)
108		if err != nil {
109			return nil, fmt.Errorf("failed to create task agent: %w", err)
110		}
111
112		agentTool = NewAgentTool(taskAgent, sessions, messages)
113	}
114
115	providerCfg := config.Get().GetProviderForModel(agentCfg.Model)
116	if providerCfg == nil {
117		return nil, fmt.Errorf("provider for agent %s not found in config", agentCfg.Name)
118	}
119	model := config.Get().GetModelByType(agentCfg.Model)
120
121	if model == nil {
122		return nil, fmt.Errorf("model not found for agent %s", agentCfg.Name)
123	}
124
125	promptID := agentPromptMap[agentCfg.ID]
126	if promptID == "" {
127		promptID = prompt.PromptDefault
128	}
129	opts := []provider.ProviderClientOption{
130		provider.WithModel(agentCfg.Model),
131		provider.WithSystemMessage(prompt.GetPrompt(promptID, providerCfg.ID, config.Get().Options.ContextPaths...)),
132	}
133	agentProvider, err := provider.NewProvider(*providerCfg, opts...)
134	if err != nil {
135		return nil, err
136	}
137
138	smallModelCfg := cfg.Models[config.SelectedModelTypeSmall]
139	var smallModelProviderCfg *config.ProviderConfig
140	if smallModelCfg.Provider == providerCfg.ID {
141		smallModelProviderCfg = providerCfg
142	} else {
143		smallModelProviderCfg = cfg.GetProviderForModel(config.SelectedModelTypeSmall)
144
145		if smallModelProviderCfg.ID == "" {
146			return nil, fmt.Errorf("provider %s not found in config", smallModelCfg.Provider)
147		}
148	}
149	smallModel := cfg.GetModelByType(config.SelectedModelTypeSmall)
150	if smallModel.ID == "" {
151		return nil, fmt.Errorf("model %s not found in provider %s", smallModelCfg.Model, smallModelProviderCfg.ID)
152	}
153
154	titleOpts := []provider.ProviderClientOption{
155		provider.WithModel(config.SelectedModelTypeSmall),
156		provider.WithSystemMessage(prompt.GetPrompt(prompt.PromptTitle, smallModelProviderCfg.ID)),
157	}
158	titleProvider, err := provider.NewProvider(*smallModelProviderCfg, titleOpts...)
159	if err != nil {
160		return nil, err
161	}
162	summarizeOpts := []provider.ProviderClientOption{
163		provider.WithModel(config.SelectedModelTypeSmall),
164		provider.WithSystemMessage(prompt.GetPrompt(prompt.PromptSummarizer, smallModelProviderCfg.ID)),
165	}
166	summarizeProvider, err := provider.NewProvider(*smallModelProviderCfg, summarizeOpts...)
167	if err != nil {
168		return nil, err
169	}
170
171	toolFn := func() []tools.BaseTool {
172		slog.Info("Initializing agent tools", "agent", agentCfg.ID)
173		defer func() {
174			slog.Info("Initialized agent tools", "agent", agentCfg.ID)
175		}()
176
177		cwd := cfg.WorkingDir()
178		allTools := []tools.BaseTool{
179			tools.NewBashTool(permissions, cwd),
180			tools.NewDownloadTool(permissions, cwd),
181			tools.NewEditTool(lspClients, permissions, history, cwd),
182			tools.NewMultiEditTool(lspClients, permissions, history, cwd),
183			tools.NewFetchTool(permissions, cwd),
184			tools.NewGlobTool(cwd),
185			tools.NewGrepTool(cwd),
186			tools.NewLsTool(permissions, cwd),
187			tools.NewSourcegraphTool(),
188			tools.NewViewTool(lspClients, permissions, cwd),
189			tools.NewWriteTool(lspClients, permissions, history, cwd),
190		}
191
192		mcpTools := GetMCPTools(ctx, permissions, cfg)
193		allTools = append(allTools, mcpTools...)
194
195		if len(lspClients) > 0 {
196			allTools = append(allTools, tools.NewDiagnosticsTool(lspClients))
197		}
198
199		if agentTool != nil {
200			allTools = append(allTools, agentTool)
201		}
202
203		if agentCfg.AllowedTools == nil {
204			return allTools
205		}
206
207		var filteredTools []tools.BaseTool
208		for _, tool := range allTools {
209			if slices.Contains(agentCfg.AllowedTools, tool.Name()) {
210				filteredTools = append(filteredTools, tool)
211			}
212		}
213		return filteredTools
214	}
215
216	return &agent{
217		Broker:              pubsub.NewBroker[AgentEvent](),
218		agentCfg:            agentCfg,
219		provider:            agentProvider,
220		providerID:          string(providerCfg.ID),
221		messages:            messages,
222		sessions:            sessions,
223		titleProvider:       titleProvider,
224		summarizeProvider:   summarizeProvider,
225		summarizeProviderID: string(smallModelProviderCfg.ID),
226		activeRequests:      csync.NewMap[string, context.CancelFunc](),
227		tools:               csync.NewLazySlice(toolFn),
228	}, nil
229}
230
231func (a *agent) Model() catwalk.Model {
232	return *config.Get().GetModelByType(a.agentCfg.Model)
233}
234
235func (a *agent) Cancel(sessionID string) {
236	// Cancel regular requests
237	if cancel, ok := a.activeRequests.Take(sessionID); ok && cancel != nil {
238		slog.Info("Request cancellation initiated", "session_id", sessionID)
239		cancel()
240	}
241
242	// Also check for summarize requests
243	if cancel, ok := a.activeRequests.Take(sessionID + "-summarize"); ok && cancel != nil {
244		slog.Info("Summarize cancellation initiated", "session_id", sessionID)
245		cancel()
246	}
247}
248
249func (a *agent) IsBusy() bool {
250	var busy bool
251	for cancelFunc := range a.activeRequests.Seq() {
252		if cancelFunc != nil {
253			busy = true
254			break
255		}
256	}
257	return busy
258}
259
260func (a *agent) IsSessionBusy(sessionID string) bool {
261	_, busy := a.activeRequests.Get(sessionID)
262	return busy
263}
264
265func (a *agent) generateTitle(ctx context.Context, sessionID string, content string) error {
266	if content == "" {
267		return nil
268	}
269	if a.titleProvider == nil {
270		return nil
271	}
272	session, err := a.sessions.Get(ctx, sessionID)
273	if err != nil {
274		return err
275	}
276	parts := []message.ContentPart{message.TextContent{
277		Text: fmt.Sprintf("Generate a concise title for the following content:\n\n%s", content),
278	}}
279
280	// Use streaming approach like summarization
281	response := a.titleProvider.StreamResponse(
282		ctx,
283		[]message.Message{
284			{
285				Role:  message.User,
286				Parts: parts,
287			},
288		},
289		nil,
290	)
291
292	var finalResponse *provider.ProviderResponse
293	for r := range response {
294		if r.Error != nil {
295			return r.Error
296		}
297		finalResponse = r.Response
298	}
299
300	if finalResponse == nil {
301		return fmt.Errorf("no response received from title provider")
302	}
303
304	title := strings.TrimSpace(strings.ReplaceAll(finalResponse.Content, "\n", " "))
305	if title == "" {
306		return nil
307	}
308
309	session.Title = title
310	_, err = a.sessions.Save(ctx, session)
311	return err
312}
313
314func (a *agent) err(err error) AgentEvent {
315	return AgentEvent{
316		Type:  AgentEventTypeError,
317		Error: err,
318	}
319}
320
321func (a *agent) Run(ctx context.Context, sessionID string, content string, attachments ...message.Attachment) (<-chan AgentEvent, error) {
322	if !a.Model().SupportsImages && attachments != nil {
323		attachments = nil
324	}
325	events := make(chan AgentEvent)
326	if a.IsSessionBusy(sessionID) {
327		return nil, ErrSessionBusy
328	}
329
330	genCtx, cancel := context.WithCancel(ctx)
331
332	a.activeRequests.Set(sessionID, cancel)
333	go func() {
334		slog.Debug("Request started", "sessionID", sessionID)
335		defer log.RecoverPanic("agent.Run", func() {
336			events <- a.err(fmt.Errorf("panic while running the agent"))
337		})
338		var attachmentParts []message.ContentPart
339		for _, attachment := range attachments {
340			attachmentParts = append(attachmentParts, message.BinaryContent{Path: attachment.FilePath, MIMEType: attachment.MimeType, Data: attachment.Content})
341		}
342		result := a.processGeneration(genCtx, sessionID, content, attachmentParts)
343		if result.Error != nil && !errors.Is(result.Error, ErrRequestCancelled) && !errors.Is(result.Error, context.Canceled) {
344			slog.Error(result.Error.Error())
345		}
346		slog.Debug("Request completed", "sessionID", sessionID)
347		a.activeRequests.Del(sessionID)
348		cancel()
349		a.Publish(pubsub.CreatedEvent, result)
350		events <- result
351		close(events)
352	}()
353	return events, nil
354}
355
356func (a *agent) processGeneration(ctx context.Context, sessionID, content string, attachmentParts []message.ContentPart) AgentEvent {
357	cfg := config.Get()
358	// List existing messages; if none, start title generation asynchronously.
359	msgs, err := a.messages.List(ctx, sessionID)
360	if err != nil {
361		return a.err(fmt.Errorf("failed to list messages: %w", err))
362	}
363	if len(msgs) == 0 {
364		go func() {
365			defer log.RecoverPanic("agent.Run", func() {
366				slog.Error("panic while generating title")
367			})
368			titleErr := a.generateTitle(context.Background(), sessionID, content)
369			if titleErr != nil && !errors.Is(titleErr, context.Canceled) && !errors.Is(titleErr, context.DeadlineExceeded) {
370				slog.Error("failed to generate title", "error", titleErr)
371			}
372		}()
373	}
374	session, err := a.sessions.Get(ctx, sessionID)
375	if err != nil {
376		return a.err(fmt.Errorf("failed to get session: %w", err))
377	}
378	if session.SummaryMessageID != "" {
379		summaryMsgInex := -1
380		for i, msg := range msgs {
381			if msg.ID == session.SummaryMessageID {
382				summaryMsgInex = i
383				break
384			}
385		}
386		if summaryMsgInex != -1 {
387			msgs = msgs[summaryMsgInex:]
388			msgs[0].Role = message.User
389		}
390	}
391
392	userMsg, err := a.createUserMessage(ctx, sessionID, content, attachmentParts)
393	if err != nil {
394		return a.err(fmt.Errorf("failed to create user message: %w", err))
395	}
396	// Append the new user message to the conversation history.
397	msgHistory := append(msgs, userMsg)
398
399	for {
400		// Check for cancellation before each iteration
401		select {
402		case <-ctx.Done():
403			return a.err(ctx.Err())
404		default:
405			// Continue processing
406		}
407		agentMessage, toolResults, err := a.streamAndHandleEvents(ctx, sessionID, msgHistory)
408		if err != nil {
409			if errors.Is(err, context.Canceled) {
410				agentMessage.AddFinish(message.FinishReasonCanceled, "Request cancelled", "")
411				a.messages.Update(context.Background(), agentMessage)
412				return a.err(ErrRequestCancelled)
413			}
414			return a.err(fmt.Errorf("failed to process events: %w", err))
415		}
416		if cfg.Options.Debug {
417			slog.Info("Result", "message", agentMessage.FinishReason(), "toolResults", toolResults)
418		}
419		if (agentMessage.FinishReason() == message.FinishReasonToolUse) && toolResults != nil {
420			// We are not done, we need to respond with the tool response
421			msgHistory = append(msgHistory, agentMessage, *toolResults)
422			continue
423		}
424		if agentMessage.FinishReason() == "" {
425			// Kujtim: could not track down where this is happening but this means its cancelled
426			agentMessage.AddFinish(message.FinishReasonCanceled, "Request cancelled", "")
427			_ = a.messages.Update(context.Background(), agentMessage)
428			return a.err(ErrRequestCancelled)
429		}
430		return AgentEvent{
431			Type:    AgentEventTypeResponse,
432			Message: agentMessage,
433			Done:    true,
434		}
435	}
436}
437
438func (a *agent) createUserMessage(ctx context.Context, sessionID, content string, attachmentParts []message.ContentPart) (message.Message, error) {
439	parts := []message.ContentPart{message.TextContent{Text: content}}
440	parts = append(parts, attachmentParts...)
441	return a.messages.Create(ctx, sessionID, message.CreateMessageParams{
442		Role:  message.User,
443		Parts: parts,
444	})
445}
446
447func (a *agent) streamAndHandleEvents(ctx context.Context, sessionID string, msgHistory []message.Message) (message.Message, *message.Message, error) {
448	ctx = context.WithValue(ctx, tools.SessionIDContextKey, sessionID)
449	eventChan := a.provider.StreamResponse(ctx, msgHistory, slices.Collect(a.tools.Seq()))
450
451	assistantMsg, err := a.messages.Create(ctx, sessionID, message.CreateMessageParams{
452		Role:     message.Assistant,
453		Parts:    []message.ContentPart{},
454		Model:    a.Model().ID,
455		Provider: a.providerID,
456	})
457	if err != nil {
458		return assistantMsg, nil, fmt.Errorf("failed to create assistant message: %w", err)
459	}
460
461	// Add the session and message ID into the context if needed by tools.
462	ctx = context.WithValue(ctx, tools.MessageIDContextKey, assistantMsg.ID)
463
464	// Process each event in the stream.
465	for event := range eventChan {
466		if processErr := a.processEvent(ctx, sessionID, &assistantMsg, event); processErr != nil {
467			if errors.Is(processErr, context.Canceled) {
468				a.finishMessage(context.Background(), &assistantMsg, message.FinishReasonCanceled, "Request cancelled", "")
469			} else {
470				a.finishMessage(ctx, &assistantMsg, message.FinishReasonError, "API Error", processErr.Error())
471			}
472			return assistantMsg, nil, processErr
473		}
474		if ctx.Err() != nil {
475			a.finishMessage(context.Background(), &assistantMsg, message.FinishReasonCanceled, "Request cancelled", "")
476			return assistantMsg, nil, ctx.Err()
477		}
478	}
479
480	toolResults := make([]message.ToolResult, len(assistantMsg.ToolCalls()))
481	toolCalls := assistantMsg.ToolCalls()
482	for i, toolCall := range toolCalls {
483		select {
484		case <-ctx.Done():
485			a.finishMessage(context.Background(), &assistantMsg, message.FinishReasonCanceled, "Request cancelled", "")
486			// Make all future tool calls cancelled
487			for j := i; j < len(toolCalls); j++ {
488				toolResults[j] = message.ToolResult{
489					ToolCallID: toolCalls[j].ID,
490					Content:    "Tool execution canceled by user",
491					IsError:    true,
492				}
493			}
494			goto out
495		default:
496			// Continue processing
497			var tool tools.BaseTool
498			for availableTool := range a.tools.Seq() {
499				if availableTool.Info().Name == toolCall.Name {
500					tool = availableTool
501					break
502				}
503			}
504
505			// Tool not found
506			if tool == nil {
507				toolResults[i] = message.ToolResult{
508					ToolCallID: toolCall.ID,
509					Content:    fmt.Sprintf("Tool not found: %s", toolCall.Name),
510					IsError:    true,
511				}
512				continue
513			}
514
515			// Run tool in goroutine to allow cancellation
516			type toolExecResult struct {
517				response tools.ToolResponse
518				err      error
519			}
520			resultChan := make(chan toolExecResult, 1)
521
522			go func() {
523				response, err := tool.Run(ctx, tools.ToolCall{
524					ID:    toolCall.ID,
525					Name:  toolCall.Name,
526					Input: toolCall.Input,
527				})
528				resultChan <- toolExecResult{response: response, err: err}
529			}()
530
531			var toolResponse tools.ToolResponse
532			var toolErr error
533
534			select {
535			case <-ctx.Done():
536				a.finishMessage(context.Background(), &assistantMsg, message.FinishReasonCanceled, "Request cancelled", "")
537				// Mark remaining tool calls as cancelled
538				for j := i; j < len(toolCalls); j++ {
539					toolResults[j] = message.ToolResult{
540						ToolCallID: toolCalls[j].ID,
541						Content:    "Tool execution canceled by user",
542						IsError:    true,
543					}
544				}
545				goto out
546			case result := <-resultChan:
547				toolResponse = result.response
548				toolErr = result.err
549			}
550
551			if toolErr != nil {
552				slog.Error("Tool execution error", "toolCall", toolCall.ID, "error", toolErr)
553				if errors.Is(toolErr, permission.ErrorPermissionDenied) {
554					toolResults[i] = message.ToolResult{
555						ToolCallID: toolCall.ID,
556						Content:    "Permission denied",
557						IsError:    true,
558					}
559					for j := i + 1; j < len(toolCalls); j++ {
560						toolResults[j] = message.ToolResult{
561							ToolCallID: toolCalls[j].ID,
562							Content:    "Tool execution canceled by user",
563							IsError:    true,
564						}
565					}
566					a.finishMessage(ctx, &assistantMsg, message.FinishReasonPermissionDenied, "Permission denied", "")
567					break
568				}
569			}
570			toolResults[i] = message.ToolResult{
571				ToolCallID: toolCall.ID,
572				Content:    toolResponse.Content,
573				Metadata:   toolResponse.Metadata,
574				IsError:    toolResponse.IsError,
575			}
576		}
577	}
578out:
579	if len(toolResults) == 0 {
580		return assistantMsg, nil, nil
581	}
582	parts := make([]message.ContentPart, 0)
583	for _, tr := range toolResults {
584		parts = append(parts, tr)
585	}
586	msg, err := a.messages.Create(context.Background(), assistantMsg.SessionID, message.CreateMessageParams{
587		Role:     message.Tool,
588		Parts:    parts,
589		Provider: a.providerID,
590	})
591	if err != nil {
592		return assistantMsg, nil, fmt.Errorf("failed to create cancelled tool message: %w", err)
593	}
594
595	return assistantMsg, &msg, err
596}
597
598func (a *agent) finishMessage(ctx context.Context, msg *message.Message, finishReason message.FinishReason, message, details string) {
599	msg.AddFinish(finishReason, message, details)
600	_ = a.messages.Update(ctx, *msg)
601}
602
603func (a *agent) processEvent(ctx context.Context, sessionID string, assistantMsg *message.Message, event provider.ProviderEvent) error {
604	select {
605	case <-ctx.Done():
606		return ctx.Err()
607	default:
608		// Continue processing.
609	}
610
611	switch event.Type {
612	case provider.EventThinkingDelta:
613		assistantMsg.AppendReasoningContent(event.Thinking)
614		return a.messages.Update(ctx, *assistantMsg)
615	case provider.EventSignatureDelta:
616		assistantMsg.AppendReasoningSignature(event.Signature)
617		return a.messages.Update(ctx, *assistantMsg)
618	case provider.EventContentDelta:
619		assistantMsg.FinishThinking()
620		assistantMsg.AppendContent(event.Content)
621		return a.messages.Update(ctx, *assistantMsg)
622	case provider.EventToolUseStart:
623		assistantMsg.FinishThinking()
624		slog.Info("Tool call started", "toolCall", event.ToolCall)
625		assistantMsg.AddToolCall(*event.ToolCall)
626		return a.messages.Update(ctx, *assistantMsg)
627	case provider.EventToolUseDelta:
628		assistantMsg.AppendToolCallInput(event.ToolCall.ID, event.ToolCall.Input)
629		return a.messages.Update(ctx, *assistantMsg)
630	case provider.EventToolUseStop:
631		slog.Info("Finished tool call", "toolCall", event.ToolCall)
632		assistantMsg.FinishToolCall(event.ToolCall.ID)
633		return a.messages.Update(ctx, *assistantMsg)
634	case provider.EventError:
635		return event.Error
636	case provider.EventComplete:
637		assistantMsg.FinishThinking()
638		assistantMsg.SetToolCalls(event.Response.ToolCalls)
639		assistantMsg.AddFinish(event.Response.FinishReason, "", "")
640		if err := a.messages.Update(ctx, *assistantMsg); err != nil {
641			return fmt.Errorf("failed to update message: %w", err)
642		}
643		return a.TrackUsage(ctx, sessionID, a.Model(), event.Response.Usage)
644	}
645
646	return nil
647}
648
649func (a *agent) TrackUsage(ctx context.Context, sessionID string, model catwalk.Model, usage provider.TokenUsage) error {
650	sess, err := a.sessions.Get(ctx, sessionID)
651	if err != nil {
652		return fmt.Errorf("failed to get session: %w", err)
653	}
654
655	cost := model.CostPer1MInCached/1e6*float64(usage.CacheCreationTokens) +
656		model.CostPer1MOutCached/1e6*float64(usage.CacheReadTokens) +
657		model.CostPer1MIn/1e6*float64(usage.InputTokens) +
658		model.CostPer1MOut/1e6*float64(usage.OutputTokens)
659
660	sess.Cost += cost
661	sess.CompletionTokens = usage.OutputTokens + usage.CacheReadTokens
662	sess.PromptTokens = usage.InputTokens + usage.CacheCreationTokens
663
664	_, err = a.sessions.Save(ctx, sess)
665	if err != nil {
666		return fmt.Errorf("failed to save session: %w", err)
667	}
668	return nil
669}
670
671func (a *agent) Summarize(ctx context.Context, sessionID string) error {
672	if a.summarizeProvider == nil {
673		return fmt.Errorf("summarize provider not available")
674	}
675
676	// Check if session is busy
677	if a.IsSessionBusy(sessionID) {
678		return ErrSessionBusy
679	}
680
681	// Create a new context with cancellation
682	summarizeCtx, cancel := context.WithCancel(ctx)
683
684	// Store the cancel function in activeRequests to allow cancellation
685	a.activeRequests.Set(sessionID+"-summarize", cancel)
686
687	go func() {
688		defer a.activeRequests.Del(sessionID + "-summarize")
689		defer cancel()
690		event := AgentEvent{
691			Type:     AgentEventTypeSummarize,
692			Progress: "Starting summarization...",
693		}
694
695		a.Publish(pubsub.CreatedEvent, event)
696		// Get all messages from the session
697		msgs, err := a.messages.List(summarizeCtx, sessionID)
698		if err != nil {
699			event = AgentEvent{
700				Type:  AgentEventTypeError,
701				Error: fmt.Errorf("failed to list messages: %w", err),
702				Done:  true,
703			}
704			a.Publish(pubsub.CreatedEvent, event)
705			return
706		}
707		summarizeCtx = context.WithValue(summarizeCtx, tools.SessionIDContextKey, sessionID)
708
709		if len(msgs) == 0 {
710			event = AgentEvent{
711				Type:  AgentEventTypeError,
712				Error: fmt.Errorf("no messages to summarize"),
713				Done:  true,
714			}
715			a.Publish(pubsub.CreatedEvent, event)
716			return
717		}
718
719		event = AgentEvent{
720			Type:     AgentEventTypeSummarize,
721			Progress: "Analyzing conversation...",
722		}
723		a.Publish(pubsub.CreatedEvent, event)
724
725		// Add a system message to guide the summarization
726		summarizePrompt := "Provide a detailed but concise summary of our conversation above. Focus on information that would be helpful for continuing the conversation, including what we did, what we're doing, which files we're working on, and what we're going to do next."
727
728		// Create a new message with the summarize prompt
729		promptMsg := message.Message{
730			Role:  message.User,
731			Parts: []message.ContentPart{message.TextContent{Text: summarizePrompt}},
732		}
733
734		// Append the prompt to the messages
735		msgsWithPrompt := append(msgs, promptMsg)
736
737		event = AgentEvent{
738			Type:     AgentEventTypeSummarize,
739			Progress: "Generating summary...",
740		}
741
742		a.Publish(pubsub.CreatedEvent, event)
743
744		// Send the messages to the summarize provider
745		response := a.summarizeProvider.StreamResponse(
746			summarizeCtx,
747			msgsWithPrompt,
748			nil,
749		)
750		var finalResponse *provider.ProviderResponse
751		for r := range response {
752			if r.Error != nil {
753				event = AgentEvent{
754					Type:  AgentEventTypeError,
755					Error: fmt.Errorf("failed to summarize: %w", err),
756					Done:  true,
757				}
758				a.Publish(pubsub.CreatedEvent, event)
759				return
760			}
761			finalResponse = r.Response
762		}
763
764		summary := strings.TrimSpace(finalResponse.Content)
765		if summary == "" {
766			event = AgentEvent{
767				Type:  AgentEventTypeError,
768				Error: fmt.Errorf("empty summary returned"),
769				Done:  true,
770			}
771			a.Publish(pubsub.CreatedEvent, event)
772			return
773		}
774		shell := shell.GetPersistentShell(config.Get().WorkingDir())
775		summary += "\n\n**Current working directory of the persistent shell**\n\n" + shell.GetWorkingDir()
776		event = AgentEvent{
777			Type:     AgentEventTypeSummarize,
778			Progress: "Creating new session...",
779		}
780
781		a.Publish(pubsub.CreatedEvent, event)
782		oldSession, err := a.sessions.Get(summarizeCtx, sessionID)
783		if err != nil {
784			event = AgentEvent{
785				Type:  AgentEventTypeError,
786				Error: fmt.Errorf("failed to get session: %w", err),
787				Done:  true,
788			}
789
790			a.Publish(pubsub.CreatedEvent, event)
791			return
792		}
793		// Create a message in the new session with the summary
794		msg, err := a.messages.Create(summarizeCtx, oldSession.ID, message.CreateMessageParams{
795			Role: message.Assistant,
796			Parts: []message.ContentPart{
797				message.TextContent{Text: summary},
798				message.Finish{
799					Reason: message.FinishReasonEndTurn,
800					Time:   time.Now().Unix(),
801				},
802			},
803			Model:    a.summarizeProvider.Model().ID,
804			Provider: a.summarizeProviderID,
805		})
806		if err != nil {
807			event = AgentEvent{
808				Type:  AgentEventTypeError,
809				Error: fmt.Errorf("failed to create summary message: %w", err),
810				Done:  true,
811			}
812
813			a.Publish(pubsub.CreatedEvent, event)
814			return
815		}
816		oldSession.SummaryMessageID = msg.ID
817		oldSession.CompletionTokens = finalResponse.Usage.OutputTokens
818		oldSession.PromptTokens = 0
819		model := a.summarizeProvider.Model()
820		usage := finalResponse.Usage
821		cost := model.CostPer1MInCached/1e6*float64(usage.CacheCreationTokens) +
822			model.CostPer1MOutCached/1e6*float64(usage.CacheReadTokens) +
823			model.CostPer1MIn/1e6*float64(usage.InputTokens) +
824			model.CostPer1MOut/1e6*float64(usage.OutputTokens)
825		oldSession.Cost += cost
826		_, err = a.sessions.Save(summarizeCtx, oldSession)
827		if err != nil {
828			event = AgentEvent{
829				Type:  AgentEventTypeError,
830				Error: fmt.Errorf("failed to save session: %w", err),
831				Done:  true,
832			}
833			a.Publish(pubsub.CreatedEvent, event)
834		}
835
836		event = AgentEvent{
837			Type:      AgentEventTypeSummarize,
838			SessionID: oldSession.ID,
839			Progress:  "Summary complete",
840			Done:      true,
841		}
842		a.Publish(pubsub.CreatedEvent, event)
843		// Send final success event with the new session ID
844	}()
845
846	return nil
847}
848
849func (a *agent) CancelAll() {
850	if !a.IsBusy() {
851		return
852	}
853	for key := range a.activeRequests.Seq2() {
854		a.Cancel(key) // key is sessionID
855	}
856
857	timeout := time.After(5 * time.Second)
858	for a.IsBusy() {
859		select {
860		case <-timeout:
861			return
862		default:
863			time.Sleep(200 * time.Millisecond)
864		}
865	}
866}
867
868func (a *agent) UpdateModel() error {
869	cfg := config.Get()
870
871	// Get current provider configuration
872	currentProviderCfg := cfg.GetProviderForModel(a.agentCfg.Model)
873	if currentProviderCfg == nil || currentProviderCfg.ID == "" {
874		return fmt.Errorf("provider for agent %s not found in config", a.agentCfg.Name)
875	}
876
877	// Check if provider has changed
878	if string(currentProviderCfg.ID) != a.providerID {
879		// Provider changed, need to recreate the main provider
880		model := cfg.GetModelByType(a.agentCfg.Model)
881		if model.ID == "" {
882			return fmt.Errorf("model not found for agent %s", a.agentCfg.Name)
883		}
884
885		promptID := agentPromptMap[a.agentCfg.ID]
886		if promptID == "" {
887			promptID = prompt.PromptDefault
888		}
889
890		opts := []provider.ProviderClientOption{
891			provider.WithModel(a.agentCfg.Model),
892			provider.WithSystemMessage(prompt.GetPrompt(promptID, currentProviderCfg.ID, cfg.Options.ContextPaths...)),
893		}
894
895		newProvider, err := provider.NewProvider(*currentProviderCfg, opts...)
896		if err != nil {
897			return fmt.Errorf("failed to create new provider: %w", err)
898		}
899
900		// Update the provider and provider ID
901		a.provider = newProvider
902		a.providerID = string(currentProviderCfg.ID)
903	}
904
905	// Check if small model provider has changed (affects title and summarize providers)
906	smallModelCfg := cfg.Models[config.SelectedModelTypeSmall]
907	var smallModelProviderCfg config.ProviderConfig
908
909	for p := range cfg.Providers.Seq() {
910		if p.ID == smallModelCfg.Provider {
911			smallModelProviderCfg = p
912			break
913		}
914	}
915
916	if smallModelProviderCfg.ID == "" {
917		return fmt.Errorf("provider %s not found in config", smallModelCfg.Provider)
918	}
919
920	// Check if summarize provider has changed
921	if string(smallModelProviderCfg.ID) != a.summarizeProviderID {
922		smallModel := cfg.GetModelByType(config.SelectedModelTypeSmall)
923		if smallModel == nil {
924			return fmt.Errorf("model %s not found in provider %s", smallModelCfg.Model, smallModelProviderCfg.ID)
925		}
926
927		// Recreate title provider
928		titleOpts := []provider.ProviderClientOption{
929			provider.WithModel(config.SelectedModelTypeSmall),
930			provider.WithSystemMessage(prompt.GetPrompt(prompt.PromptTitle, smallModelProviderCfg.ID)),
931			// We want the title to be short, so we limit the max tokens
932			provider.WithMaxTokens(40),
933		}
934		newTitleProvider, err := provider.NewProvider(smallModelProviderCfg, titleOpts...)
935		if err != nil {
936			return fmt.Errorf("failed to create new title provider: %w", err)
937		}
938
939		// Recreate summarize provider
940		summarizeOpts := []provider.ProviderClientOption{
941			provider.WithModel(config.SelectedModelTypeSmall),
942			provider.WithSystemMessage(prompt.GetPrompt(prompt.PromptSummarizer, smallModelProviderCfg.ID)),
943		}
944		newSummarizeProvider, err := provider.NewProvider(smallModelProviderCfg, summarizeOpts...)
945		if err != nil {
946			return fmt.Errorf("failed to create new summarize provider: %w", err)
947		}
948
949		// Update the providers and provider ID
950		a.titleProvider = newTitleProvider
951		a.summarizeProvider = newSummarizeProvider
952		a.summarizeProviderID = string(smallModelProviderCfg.ID)
953	}
954
955	return nil
956}