1package agent
  2
  3import (
  4	"context"
  5	"errors"
  6	"fmt"
  7	"log/slog"
  8	"slices"
  9	"strings"
 10	"time"
 11
 12	"github.com/charmbracelet/catwalk/pkg/catwalk"
 13	"github.com/charmbracelet/crush/internal/config"
 14	"github.com/charmbracelet/crush/internal/csync"
 15	"github.com/charmbracelet/crush/internal/history"
 16	"github.com/charmbracelet/crush/internal/llm/prompt"
 17	"github.com/charmbracelet/crush/internal/llm/provider"
 18	"github.com/charmbracelet/crush/internal/llm/tools"
 19	"github.com/charmbracelet/crush/internal/log"
 20	"github.com/charmbracelet/crush/internal/lsp"
 21	"github.com/charmbracelet/crush/internal/message"
 22	"github.com/charmbracelet/crush/internal/permission"
 23	"github.com/charmbracelet/crush/internal/pubsub"
 24	"github.com/charmbracelet/crush/internal/session"
 25	"github.com/charmbracelet/crush/internal/shell"
 26)
 27
 28// Common errors
 29var (
 30	ErrRequestCancelled = errors.New("request canceled by user")
 31	ErrSessionBusy      = errors.New("session is currently processing another request")
 32)
 33
 34type AgentEventType string
 35
 36const (
 37	AgentEventTypeError     AgentEventType = "error"
 38	AgentEventTypeResponse  AgentEventType = "response"
 39	AgentEventTypeSummarize AgentEventType = "summarize"
 40)
 41
// AgentEvent is what the agent reports to its subscribers. It is published on
// the agent's broker and, for Run, also sent on the channel Run returns.
type AgentEvent struct {
	// Type selects which of the remaining fields are meaningful.
	Type    AgentEventType
	// Message is the final assistant message of a Response event.
	Message message.Message
	// Error is set on Error events.
	Error   error

	// When summarizing
	// SessionID is set on the final "Summary complete" event.
	SessionID string
	// Progress is a human-readable status line for Summarize events.
	Progress  string
	// Done is true on the last event of a summarization (success or failure)
	// and on the Response event of a completed run.
	Done      bool
}
 52
// Service is the public interface of an agent. It runs user requests against
// the configured model, supports cancellation and conversation summarization,
// and publishes AgentEvents to subscribers.
type Service interface {
	pubsub.Suscriber[AgentEvent]
	// Model returns the model currently selected for this agent.
	Model() catwalk.Model
	// Run processes content (plus optional attachments) for sessionID in the
	// background and returns a channel that receives the result.
	Run(ctx context.Context, sessionID string, content string, attachments ...message.Attachment) (<-chan AgentEvent, error)
	// Cancel stops the active request and/or summarization for sessionID.
	Cancel(sessionID string)
	// CancelAll cancels every active request and waits briefly for them to stop.
	CancelAll()
	// IsSessionBusy reports whether sessionID has an active request. In the
	// agent implementation an in-flight summarization does not count.
	IsSessionBusy(sessionID string) bool
	// IsBusy reports whether any request or summarization is active.
	IsBusy() bool
	// Summarize starts an asynchronous summarization of sessionID; progress
	// is reported through published AgentEvents.
	Summarize(ctx context.Context, sessionID string) error
	// UpdateModel rebuilds the agent's providers from the current configuration.
	UpdateModel() error
}
 64
// agent is the Service implementation backed by LLM providers resolved from
// the global configuration.
type agent struct {
	*pubsub.Broker[AgentEvent]
	agentCfg config.Agent
	sessions session.Service
	messages message.Service
	// NOTE(review): NewAgent never sets this field (its tool loader assigns
	// the package-level mcpTools instead) — confirm whether it is still used.
	mcpTools []McpTool

	// tools is produced by toolFn through a csync.LazySlice (presumably built
	// on first access — confirm against csync).
	tools *csync.LazySlice[tools.BaseTool]

	// provider serves the agent's own model; providerID is its config ID.
	provider   provider.Provider
	providerID string

	// titleProvider generates session titles with the small model.
	titleProvider       provider.Provider
	// summarizeProvider produces conversation summaries with the large model;
	// summarizeProviderID is recorded on the summary message.
	summarizeProvider   provider.Provider
	summarizeProviderID string

	// activeRequests maps a session ID (regular requests) or
	// "<sessionID>-summarize" (summarizations) to the cancel func of the
	// in-flight work.
	activeRequests *csync.Map[string, context.CancelFunc]
}
 83
 84var agentPromptMap = map[string]prompt.PromptID{
 85	"coder": prompt.PromptCoder,
 86	"task":  prompt.PromptTask,
 87}
 88
 89func NewAgent(
 90	ctx context.Context,
 91	agentCfg config.Agent,
 92	// These services are needed in the tools
 93	permissions permission.Service,
 94	sessions session.Service,
 95	messages message.Service,
 96	history history.Service,
 97	lspClients map[string]*lsp.Client,
 98) (Service, error) {
 99	cfg := config.Get()
100
101	var agentTool tools.BaseTool
102	if agentCfg.ID == "coder" {
103		taskAgentCfg := config.Get().Agents["task"]
104		if taskAgentCfg.ID == "" {
105			return nil, fmt.Errorf("task agent not found in config")
106		}
107		taskAgent, err := NewAgent(ctx, taskAgentCfg, permissions, sessions, messages, history, lspClients)
108		if err != nil {
109			return nil, fmt.Errorf("failed to create task agent: %w", err)
110		}
111
112		agentTool = NewAgentTool(taskAgent, sessions, messages)
113	}
114
115	providerCfg := config.Get().GetProviderForModel(agentCfg.Model)
116	if providerCfg == nil {
117		return nil, fmt.Errorf("provider for agent %s not found in config", agentCfg.Name)
118	}
119	model := config.Get().GetModelByType(agentCfg.Model)
120
121	if model == nil {
122		return nil, fmt.Errorf("model not found for agent %s", agentCfg.Name)
123	}
124
125	promptID := agentPromptMap[agentCfg.ID]
126	if promptID == "" {
127		promptID = prompt.PromptDefault
128	}
129	opts := []provider.ProviderClientOption{
130		provider.WithModel(agentCfg.Model),
131		provider.WithSystemMessage(prompt.GetPrompt(promptID, providerCfg.ID, config.Get().Options.ContextPaths...)),
132	}
133	agentProvider, err := provider.NewProvider(*providerCfg, opts...)
134	if err != nil {
135		return nil, err
136	}
137
138	smallModelCfg := cfg.Models[config.SelectedModelTypeSmall]
139	var smallModelProviderCfg *config.ProviderConfig
140	if smallModelCfg.Provider == providerCfg.ID {
141		smallModelProviderCfg = providerCfg
142	} else {
143		smallModelProviderCfg = cfg.GetProviderForModel(config.SelectedModelTypeSmall)
144
145		if smallModelProviderCfg.ID == "" {
146			return nil, fmt.Errorf("provider %s not found in config", smallModelCfg.Provider)
147		}
148	}
149	smallModel := cfg.GetModelByType(config.SelectedModelTypeSmall)
150	if smallModel.ID == "" {
151		return nil, fmt.Errorf("model %s not found in provider %s", smallModelCfg.Model, smallModelProviderCfg.ID)
152	}
153
154	titleOpts := []provider.ProviderClientOption{
155		provider.WithModel(config.SelectedModelTypeSmall),
156		provider.WithSystemMessage(prompt.GetPrompt(prompt.PromptTitle, smallModelProviderCfg.ID)),
157	}
158	titleProvider, err := provider.NewProvider(*smallModelProviderCfg, titleOpts...)
159	if err != nil {
160		return nil, err
161	}
162
163	summarizeOpts := []provider.ProviderClientOption{
164		provider.WithModel(config.SelectedModelTypeLarge),
165		provider.WithSystemMessage(prompt.GetPrompt(prompt.PromptSummarizer, providerCfg.ID)),
166	}
167	summarizeProvider, err := provider.NewProvider(*providerCfg, summarizeOpts...)
168	if err != nil {
169		return nil, err
170	}
171
172	toolFn := func() []tools.BaseTool {
173		slog.Info("Initializing agent tools", "agent", agentCfg.ID)
174		defer func() {
175			slog.Info("Initialized agent tools", "agent", agentCfg.ID)
176		}()
177
178		cwd := cfg.WorkingDir()
179		allTools := []tools.BaseTool{
180			tools.NewBashTool(permissions, cwd),
181			tools.NewDownloadTool(permissions, cwd),
182			tools.NewEditTool(lspClients, permissions, history, cwd),
183			tools.NewMultiEditTool(lspClients, permissions, history, cwd),
184			tools.NewFetchTool(permissions, cwd),
185			tools.NewGlobTool(cwd),
186			tools.NewGrepTool(cwd),
187			tools.NewLsTool(permissions, cwd),
188			tools.NewSourcegraphTool(),
189			tools.NewViewTool(lspClients, permissions, cwd),
190			tools.NewWriteTool(lspClients, permissions, history, cwd),
191		}
192
193		mcpToolsOnce.Do(func() {
194			mcpTools = doGetMCPTools(ctx, permissions, cfg)
195		})
196		allTools = append(allTools, mcpTools...)
197
198		if len(lspClients) > 0 {
199			allTools = append(allTools, tools.NewDiagnosticsTool(lspClients))
200		}
201
202		if agentTool != nil {
203			allTools = append(allTools, agentTool)
204		}
205
206		if agentCfg.AllowedTools == nil {
207			return allTools
208		}
209
210		var filteredTools []tools.BaseTool
211		for _, tool := range allTools {
212			if slices.Contains(agentCfg.AllowedTools, tool.Name()) {
213				filteredTools = append(filteredTools, tool)
214			}
215		}
216		return filteredTools
217	}
218
219	return &agent{
220		Broker:              pubsub.NewBroker[AgentEvent](),
221		agentCfg:            agentCfg,
222		provider:            agentProvider,
223		providerID:          string(providerCfg.ID),
224		messages:            messages,
225		sessions:            sessions,
226		titleProvider:       titleProvider,
227		summarizeProvider:   summarizeProvider,
228		summarizeProviderID: string(providerCfg.ID),
229		activeRequests:      csync.NewMap[string, context.CancelFunc](),
230		tools:               csync.NewLazySlice(toolFn),
231	}, nil
232}
233
234func (a *agent) Model() catwalk.Model {
235	return *config.Get().GetModelByType(a.agentCfg.Model)
236}
237
238func (a *agent) Cancel(sessionID string) {
239	// Cancel regular requests
240	if cancel, ok := a.activeRequests.Take(sessionID); ok && cancel != nil {
241		slog.Info("Request cancellation initiated", "session_id", sessionID)
242		cancel()
243	}
244
245	// Also check for summarize requests
246	if cancel, ok := a.activeRequests.Take(sessionID + "-summarize"); ok && cancel != nil {
247		slog.Info("Summarize cancellation initiated", "session_id", sessionID)
248		cancel()
249	}
250}
251
252func (a *agent) IsBusy() bool {
253	var busy bool
254	for cancelFunc := range a.activeRequests.Seq() {
255		if cancelFunc != nil {
256			busy = true
257			break
258		}
259	}
260	return busy
261}
262
263func (a *agent) IsSessionBusy(sessionID string) bool {
264	_, busy := a.activeRequests.Get(sessionID)
265	return busy
266}
267
268func (a *agent) generateTitle(ctx context.Context, sessionID string, content string) error {
269	if content == "" {
270		return nil
271	}
272	if a.titleProvider == nil {
273		return nil
274	}
275	session, err := a.sessions.Get(ctx, sessionID)
276	if err != nil {
277		return err
278	}
279	parts := []message.ContentPart{message.TextContent{
280		Text: fmt.Sprintf("Generate a concise title for the following content:\n\n%s", content),
281	}}
282
283	// Use streaming approach like summarization
284	response := a.titleProvider.StreamResponse(
285		ctx,
286		[]message.Message{
287			{
288				Role:  message.User,
289				Parts: parts,
290			},
291		},
292		nil,
293	)
294
295	var finalResponse *provider.ProviderResponse
296	for r := range response {
297		if r.Error != nil {
298			return r.Error
299		}
300		finalResponse = r.Response
301	}
302
303	if finalResponse == nil {
304		return fmt.Errorf("no response received from title provider")
305	}
306
307	title := strings.TrimSpace(strings.ReplaceAll(finalResponse.Content, "\n", " "))
308	if title == "" {
309		return nil
310	}
311
312	session.Title = title
313	_, err = a.sessions.Save(ctx, session)
314	return err
315}
316
317func (a *agent) err(err error) AgentEvent {
318	return AgentEvent{
319		Type:  AgentEventTypeError,
320		Error: err,
321	}
322}
323
324func (a *agent) Run(ctx context.Context, sessionID string, content string, attachments ...message.Attachment) (<-chan AgentEvent, error) {
325	if !a.Model().SupportsImages && attachments != nil {
326		attachments = nil
327	}
328	events := make(chan AgentEvent)
329	if a.IsSessionBusy(sessionID) {
330		return nil, ErrSessionBusy
331	}
332
333	genCtx, cancel := context.WithCancel(ctx)
334
335	a.activeRequests.Set(sessionID, cancel)
336	go func() {
337		slog.Debug("Request started", "sessionID", sessionID)
338		defer log.RecoverPanic("agent.Run", func() {
339			events <- a.err(fmt.Errorf("panic while running the agent"))
340		})
341		var attachmentParts []message.ContentPart
342		for _, attachment := range attachments {
343			attachmentParts = append(attachmentParts, message.BinaryContent{Path: attachment.FilePath, MIMEType: attachment.MimeType, Data: attachment.Content})
344		}
345		result := a.processGeneration(genCtx, sessionID, content, attachmentParts)
346		if result.Error != nil && !errors.Is(result.Error, ErrRequestCancelled) && !errors.Is(result.Error, context.Canceled) {
347			slog.Error(result.Error.Error())
348		}
349		slog.Debug("Request completed", "sessionID", sessionID)
350		a.activeRequests.Del(sessionID)
351		cancel()
352		a.Publish(pubsub.CreatedEvent, result)
353		events <- result
354		close(events)
355	}()
356	return events, nil
357}
358
359func (a *agent) processGeneration(ctx context.Context, sessionID, content string, attachmentParts []message.ContentPart) AgentEvent {
360	cfg := config.Get()
361	// List existing messages; if none, start title generation asynchronously.
362	msgs, err := a.messages.List(ctx, sessionID)
363	if err != nil {
364		return a.err(fmt.Errorf("failed to list messages: %w", err))
365	}
366	if len(msgs) == 0 {
367		go func() {
368			defer log.RecoverPanic("agent.Run", func() {
369				slog.Error("panic while generating title")
370			})
371			titleErr := a.generateTitle(context.Background(), sessionID, content)
372			if titleErr != nil && !errors.Is(titleErr, context.Canceled) && !errors.Is(titleErr, context.DeadlineExceeded) {
373				slog.Error("failed to generate title", "error", titleErr)
374			}
375		}()
376	}
377	session, err := a.sessions.Get(ctx, sessionID)
378	if err != nil {
379		return a.err(fmt.Errorf("failed to get session: %w", err))
380	}
381	if session.SummaryMessageID != "" {
382		summaryMsgInex := -1
383		for i, msg := range msgs {
384			if msg.ID == session.SummaryMessageID {
385				summaryMsgInex = i
386				break
387			}
388		}
389		if summaryMsgInex != -1 {
390			msgs = msgs[summaryMsgInex:]
391			msgs[0].Role = message.User
392		}
393	}
394
395	userMsg, err := a.createUserMessage(ctx, sessionID, content, attachmentParts)
396	if err != nil {
397		return a.err(fmt.Errorf("failed to create user message: %w", err))
398	}
399	// Append the new user message to the conversation history.
400	msgHistory := append(msgs, userMsg)
401
402	for {
403		// Check for cancellation before each iteration
404		select {
405		case <-ctx.Done():
406			return a.err(ctx.Err())
407		default:
408			// Continue processing
409		}
410		agentMessage, toolResults, err := a.streamAndHandleEvents(ctx, sessionID, msgHistory)
411		if err != nil {
412			if errors.Is(err, context.Canceled) {
413				agentMessage.AddFinish(message.FinishReasonCanceled, "Request cancelled", "")
414				a.messages.Update(context.Background(), agentMessage)
415				return a.err(ErrRequestCancelled)
416			}
417			return a.err(fmt.Errorf("failed to process events: %w", err))
418		}
419		if cfg.Options.Debug {
420			slog.Info("Result", "message", agentMessage.FinishReason(), "toolResults", toolResults)
421		}
422		if (agentMessage.FinishReason() == message.FinishReasonToolUse) && toolResults != nil {
423			// We are not done, we need to respond with the tool response
424			msgHistory = append(msgHistory, agentMessage, *toolResults)
425			continue
426		}
427		if agentMessage.FinishReason() == "" {
428			// Kujtim: could not track down where this is happening but this means its cancelled
429			agentMessage.AddFinish(message.FinishReasonCanceled, "Request cancelled", "")
430			_ = a.messages.Update(context.Background(), agentMessage)
431			return a.err(ErrRequestCancelled)
432		}
433		return AgentEvent{
434			Type:    AgentEventTypeResponse,
435			Message: agentMessage,
436			Done:    true,
437		}
438	}
439}
440
441func (a *agent) createUserMessage(ctx context.Context, sessionID, content string, attachmentParts []message.ContentPart) (message.Message, error) {
442	parts := []message.ContentPart{message.TextContent{Text: content}}
443	parts = append(parts, attachmentParts...)
444	return a.messages.Create(ctx, sessionID, message.CreateMessageParams{
445		Role:  message.User,
446		Parts: parts,
447	})
448}
449
450func (a *agent) streamAndHandleEvents(ctx context.Context, sessionID string, msgHistory []message.Message) (message.Message, *message.Message, error) {
451	ctx = context.WithValue(ctx, tools.SessionIDContextKey, sessionID)
452	eventChan := a.provider.StreamResponse(ctx, msgHistory, slices.Collect(a.tools.Seq()))
453
454	assistantMsg, err := a.messages.Create(ctx, sessionID, message.CreateMessageParams{
455		Role:     message.Assistant,
456		Parts:    []message.ContentPart{},
457		Model:    a.Model().ID,
458		Provider: a.providerID,
459	})
460	if err != nil {
461		return assistantMsg, nil, fmt.Errorf("failed to create assistant message: %w", err)
462	}
463
464	// Add the session and message ID into the context if needed by tools.
465	ctx = context.WithValue(ctx, tools.MessageIDContextKey, assistantMsg.ID)
466
467	// Process each event in the stream.
468	for event := range eventChan {
469		if processErr := a.processEvent(ctx, sessionID, &assistantMsg, event); processErr != nil {
470			if errors.Is(processErr, context.Canceled) {
471				a.finishMessage(context.Background(), &assistantMsg, message.FinishReasonCanceled, "Request cancelled", "")
472			} else {
473				a.finishMessage(ctx, &assistantMsg, message.FinishReasonError, "API Error", processErr.Error())
474			}
475			return assistantMsg, nil, processErr
476		}
477		if ctx.Err() != nil {
478			a.finishMessage(context.Background(), &assistantMsg, message.FinishReasonCanceled, "Request cancelled", "")
479			return assistantMsg, nil, ctx.Err()
480		}
481	}
482
483	toolResults := make([]message.ToolResult, len(assistantMsg.ToolCalls()))
484	toolCalls := assistantMsg.ToolCalls()
485	for i, toolCall := range toolCalls {
486		select {
487		case <-ctx.Done():
488			a.finishMessage(context.Background(), &assistantMsg, message.FinishReasonCanceled, "Request cancelled", "")
489			// Make all future tool calls cancelled
490			for j := i; j < len(toolCalls); j++ {
491				toolResults[j] = message.ToolResult{
492					ToolCallID: toolCalls[j].ID,
493					Content:    "Tool execution canceled by user",
494					IsError:    true,
495				}
496			}
497			goto out
498		default:
499			// Continue processing
500			var tool tools.BaseTool
501			for availableTool := range a.tools.Seq() {
502				if availableTool.Info().Name == toolCall.Name {
503					tool = availableTool
504					break
505				}
506			}
507
508			// Tool not found
509			if tool == nil {
510				toolResults[i] = message.ToolResult{
511					ToolCallID: toolCall.ID,
512					Content:    fmt.Sprintf("Tool not found: %s", toolCall.Name),
513					IsError:    true,
514				}
515				continue
516			}
517
518			// Run tool in goroutine to allow cancellation
519			type toolExecResult struct {
520				response tools.ToolResponse
521				err      error
522			}
523			resultChan := make(chan toolExecResult, 1)
524
525			go func() {
526				response, err := tool.Run(ctx, tools.ToolCall{
527					ID:    toolCall.ID,
528					Name:  toolCall.Name,
529					Input: toolCall.Input,
530				})
531				resultChan <- toolExecResult{response: response, err: err}
532			}()
533
534			var toolResponse tools.ToolResponse
535			var toolErr error
536
537			select {
538			case <-ctx.Done():
539				a.finishMessage(context.Background(), &assistantMsg, message.FinishReasonCanceled, "Request cancelled", "")
540				// Mark remaining tool calls as cancelled
541				for j := i; j < len(toolCalls); j++ {
542					toolResults[j] = message.ToolResult{
543						ToolCallID: toolCalls[j].ID,
544						Content:    "Tool execution canceled by user",
545						IsError:    true,
546					}
547				}
548				goto out
549			case result := <-resultChan:
550				toolResponse = result.response
551				toolErr = result.err
552			}
553
554			if toolErr != nil {
555				slog.Error("Tool execution error", "toolCall", toolCall.ID, "error", toolErr)
556				if errors.Is(toolErr, permission.ErrorPermissionDenied) {
557					toolResults[i] = message.ToolResult{
558						ToolCallID: toolCall.ID,
559						Content:    "Permission denied",
560						IsError:    true,
561					}
562					for j := i + 1; j < len(toolCalls); j++ {
563						toolResults[j] = message.ToolResult{
564							ToolCallID: toolCalls[j].ID,
565							Content:    "Tool execution canceled by user",
566							IsError:    true,
567						}
568					}
569					a.finishMessage(ctx, &assistantMsg, message.FinishReasonPermissionDenied, "Permission denied", "")
570					break
571				}
572			}
573			toolResults[i] = message.ToolResult{
574				ToolCallID: toolCall.ID,
575				Content:    toolResponse.Content,
576				Metadata:   toolResponse.Metadata,
577				IsError:    toolResponse.IsError,
578			}
579		}
580	}
581out:
582	if len(toolResults) == 0 {
583		return assistantMsg, nil, nil
584	}
585	parts := make([]message.ContentPart, 0)
586	for _, tr := range toolResults {
587		parts = append(parts, tr)
588	}
589	msg, err := a.messages.Create(context.Background(), assistantMsg.SessionID, message.CreateMessageParams{
590		Role:     message.Tool,
591		Parts:    parts,
592		Provider: a.providerID,
593	})
594	if err != nil {
595		return assistantMsg, nil, fmt.Errorf("failed to create cancelled tool message: %w", err)
596	}
597
598	return assistantMsg, &msg, err
599}
600
601func (a *agent) finishMessage(ctx context.Context, msg *message.Message, finishReason message.FinishReason, message, details string) {
602	msg.AddFinish(finishReason, message, details)
603	_ = a.messages.Update(ctx, *msg)
604}
605
606func (a *agent) processEvent(ctx context.Context, sessionID string, assistantMsg *message.Message, event provider.ProviderEvent) error {
607	select {
608	case <-ctx.Done():
609		return ctx.Err()
610	default:
611		// Continue processing.
612	}
613
614	switch event.Type {
615	case provider.EventThinkingDelta:
616		assistantMsg.AppendReasoningContent(event.Thinking)
617		return a.messages.Update(ctx, *assistantMsg)
618	case provider.EventSignatureDelta:
619		assistantMsg.AppendReasoningSignature(event.Signature)
620		return a.messages.Update(ctx, *assistantMsg)
621	case provider.EventContentDelta:
622		assistantMsg.FinishThinking()
623		assistantMsg.AppendContent(event.Content)
624		return a.messages.Update(ctx, *assistantMsg)
625	case provider.EventToolUseStart:
626		assistantMsg.FinishThinking()
627		slog.Info("Tool call started", "toolCall", event.ToolCall)
628		assistantMsg.AddToolCall(*event.ToolCall)
629		return a.messages.Update(ctx, *assistantMsg)
630	case provider.EventToolUseDelta:
631		assistantMsg.AppendToolCallInput(event.ToolCall.ID, event.ToolCall.Input)
632		return a.messages.Update(ctx, *assistantMsg)
633	case provider.EventToolUseStop:
634		slog.Info("Finished tool call", "toolCall", event.ToolCall)
635		assistantMsg.FinishToolCall(event.ToolCall.ID)
636		return a.messages.Update(ctx, *assistantMsg)
637	case provider.EventError:
638		return event.Error
639	case provider.EventComplete:
640		assistantMsg.FinishThinking()
641		assistantMsg.SetToolCalls(event.Response.ToolCalls)
642		assistantMsg.AddFinish(event.Response.FinishReason, "", "")
643		if err := a.messages.Update(ctx, *assistantMsg); err != nil {
644			return fmt.Errorf("failed to update message: %w", err)
645		}
646		return a.TrackUsage(ctx, sessionID, a.Model(), event.Response.Usage)
647	}
648
649	return nil
650}
651
652func (a *agent) TrackUsage(ctx context.Context, sessionID string, model catwalk.Model, usage provider.TokenUsage) error {
653	sess, err := a.sessions.Get(ctx, sessionID)
654	if err != nil {
655		return fmt.Errorf("failed to get session: %w", err)
656	}
657
658	cost := model.CostPer1MInCached/1e6*float64(usage.CacheCreationTokens) +
659		model.CostPer1MOutCached/1e6*float64(usage.CacheReadTokens) +
660		model.CostPer1MIn/1e6*float64(usage.InputTokens) +
661		model.CostPer1MOut/1e6*float64(usage.OutputTokens)
662
663	sess.Cost += cost
664	sess.CompletionTokens = usage.OutputTokens + usage.CacheReadTokens
665	sess.PromptTokens = usage.InputTokens + usage.CacheCreationTokens
666
667	_, err = a.sessions.Save(ctx, sess)
668	if err != nil {
669		return fmt.Errorf("failed to save session: %w", err)
670	}
671	return nil
672}
673
674func (a *agent) Summarize(ctx context.Context, sessionID string) error {
675	if a.summarizeProvider == nil {
676		return fmt.Errorf("summarize provider not available")
677	}
678
679	// Check if session is busy
680	if a.IsSessionBusy(sessionID) {
681		return ErrSessionBusy
682	}
683
684	// Create a new context with cancellation
685	summarizeCtx, cancel := context.WithCancel(ctx)
686
687	// Store the cancel function in activeRequests to allow cancellation
688	a.activeRequests.Set(sessionID+"-summarize", cancel)
689
690	go func() {
691		defer a.activeRequests.Del(sessionID + "-summarize")
692		defer cancel()
693		event := AgentEvent{
694			Type:     AgentEventTypeSummarize,
695			Progress: "Starting summarization...",
696		}
697
698		a.Publish(pubsub.CreatedEvent, event)
699		// Get all messages from the session
700		msgs, err := a.messages.List(summarizeCtx, sessionID)
701		if err != nil {
702			event = AgentEvent{
703				Type:  AgentEventTypeError,
704				Error: fmt.Errorf("failed to list messages: %w", err),
705				Done:  true,
706			}
707			a.Publish(pubsub.CreatedEvent, event)
708			return
709		}
710		summarizeCtx = context.WithValue(summarizeCtx, tools.SessionIDContextKey, sessionID)
711
712		if len(msgs) == 0 {
713			event = AgentEvent{
714				Type:  AgentEventTypeError,
715				Error: fmt.Errorf("no messages to summarize"),
716				Done:  true,
717			}
718			a.Publish(pubsub.CreatedEvent, event)
719			return
720		}
721
722		event = AgentEvent{
723			Type:     AgentEventTypeSummarize,
724			Progress: "Analyzing conversation...",
725		}
726		a.Publish(pubsub.CreatedEvent, event)
727
728		// Add a system message to guide the summarization
729		summarizePrompt := "Provide a detailed but concise summary of our conversation above. Focus on information that would be helpful for continuing the conversation, including what we did, what we're doing, which files we're working on, and what we're going to do next."
730
731		// Create a new message with the summarize prompt
732		promptMsg := message.Message{
733			Role:  message.User,
734			Parts: []message.ContentPart{message.TextContent{Text: summarizePrompt}},
735		}
736
737		// Append the prompt to the messages
738		msgsWithPrompt := append(msgs, promptMsg)
739
740		event = AgentEvent{
741			Type:     AgentEventTypeSummarize,
742			Progress: "Generating summary...",
743		}
744
745		a.Publish(pubsub.CreatedEvent, event)
746
747		// Send the messages to the summarize provider
748		response := a.summarizeProvider.StreamResponse(
749			summarizeCtx,
750			msgsWithPrompt,
751			nil,
752		)
753		var finalResponse *provider.ProviderResponse
754		for r := range response {
755			if r.Error != nil {
756				event = AgentEvent{
757					Type:  AgentEventTypeError,
758					Error: fmt.Errorf("failed to summarize: %w", err),
759					Done:  true,
760				}
761				a.Publish(pubsub.CreatedEvent, event)
762				return
763			}
764			finalResponse = r.Response
765		}
766
767		summary := strings.TrimSpace(finalResponse.Content)
768		if summary == "" {
769			event = AgentEvent{
770				Type:  AgentEventTypeError,
771				Error: fmt.Errorf("empty summary returned"),
772				Done:  true,
773			}
774			a.Publish(pubsub.CreatedEvent, event)
775			return
776		}
777		shell := shell.GetPersistentShell(config.Get().WorkingDir())
778		summary += "\n\n**Current working directory of the persistent shell**\n\n" + shell.GetWorkingDir()
779		event = AgentEvent{
780			Type:     AgentEventTypeSummarize,
781			Progress: "Creating new session...",
782		}
783
784		a.Publish(pubsub.CreatedEvent, event)
785		oldSession, err := a.sessions.Get(summarizeCtx, sessionID)
786		if err != nil {
787			event = AgentEvent{
788				Type:  AgentEventTypeError,
789				Error: fmt.Errorf("failed to get session: %w", err),
790				Done:  true,
791			}
792
793			a.Publish(pubsub.CreatedEvent, event)
794			return
795		}
796		// Create a message in the new session with the summary
797		msg, err := a.messages.Create(summarizeCtx, oldSession.ID, message.CreateMessageParams{
798			Role: message.Assistant,
799			Parts: []message.ContentPart{
800				message.TextContent{Text: summary},
801				message.Finish{
802					Reason: message.FinishReasonEndTurn,
803					Time:   time.Now().Unix(),
804				},
805			},
806			Model:    a.summarizeProvider.Model().ID,
807			Provider: a.summarizeProviderID,
808		})
809		if err != nil {
810			event = AgentEvent{
811				Type:  AgentEventTypeError,
812				Error: fmt.Errorf("failed to create summary message: %w", err),
813				Done:  true,
814			}
815
816			a.Publish(pubsub.CreatedEvent, event)
817			return
818		}
819		oldSession.SummaryMessageID = msg.ID
820		oldSession.CompletionTokens = finalResponse.Usage.OutputTokens
821		oldSession.PromptTokens = 0
822		model := a.summarizeProvider.Model()
823		usage := finalResponse.Usage
824		cost := model.CostPer1MInCached/1e6*float64(usage.CacheCreationTokens) +
825			model.CostPer1MOutCached/1e6*float64(usage.CacheReadTokens) +
826			model.CostPer1MIn/1e6*float64(usage.InputTokens) +
827			model.CostPer1MOut/1e6*float64(usage.OutputTokens)
828		oldSession.Cost += cost
829		_, err = a.sessions.Save(summarizeCtx, oldSession)
830		if err != nil {
831			event = AgentEvent{
832				Type:  AgentEventTypeError,
833				Error: fmt.Errorf("failed to save session: %w", err),
834				Done:  true,
835			}
836			a.Publish(pubsub.CreatedEvent, event)
837		}
838
839		event = AgentEvent{
840			Type:      AgentEventTypeSummarize,
841			SessionID: oldSession.ID,
842			Progress:  "Summary complete",
843			Done:      true,
844		}
845		a.Publish(pubsub.CreatedEvent, event)
846		// Send final success event with the new session ID
847	}()
848
849	return nil
850}
851
852func (a *agent) CancelAll() {
853	if !a.IsBusy() {
854		return
855	}
856	for key := range a.activeRequests.Seq2() {
857		a.Cancel(key) // key is sessionID
858	}
859
860	timeout := time.After(5 * time.Second)
861	for a.IsBusy() {
862		select {
863		case <-timeout:
864			return
865		default:
866			time.Sleep(200 * time.Millisecond)
867		}
868	}
869}
870
871func (a *agent) UpdateModel() error {
872	cfg := config.Get()
873
874	// Get current provider configuration
875	currentProviderCfg := cfg.GetProviderForModel(a.agentCfg.Model)
876	if currentProviderCfg == nil || currentProviderCfg.ID == "" {
877		return fmt.Errorf("provider for agent %s not found in config", a.agentCfg.Name)
878	}
879
880	// Check if provider has changed
881	if string(currentProviderCfg.ID) != a.providerID {
882		// Provider changed, need to recreate the main provider
883		model := cfg.GetModelByType(a.agentCfg.Model)
884		if model.ID == "" {
885			return fmt.Errorf("model not found for agent %s", a.agentCfg.Name)
886		}
887
888		promptID := agentPromptMap[a.agentCfg.ID]
889		if promptID == "" {
890			promptID = prompt.PromptDefault
891		}
892
893		opts := []provider.ProviderClientOption{
894			provider.WithModel(a.agentCfg.Model),
895			provider.WithSystemMessage(prompt.GetPrompt(promptID, currentProviderCfg.ID, cfg.Options.ContextPaths...)),
896		}
897
898		newProvider, err := provider.NewProvider(*currentProviderCfg, opts...)
899		if err != nil {
900			return fmt.Errorf("failed to create new provider: %w", err)
901		}
902
903		// Update the provider and provider ID
904		a.provider = newProvider
905		a.providerID = string(currentProviderCfg.ID)
906	}
907
908	// Check if providers have changed for title (small) and summarize (large)
909	smallModelCfg := cfg.Models[config.SelectedModelTypeSmall]
910	var smallModelProviderCfg config.ProviderConfig
911	for p := range cfg.Providers.Seq() {
912		if p.ID == smallModelCfg.Provider {
913			smallModelProviderCfg = p
914			break
915		}
916	}
917	if smallModelProviderCfg.ID == "" {
918		return fmt.Errorf("provider %s not found in config", smallModelCfg.Provider)
919	}
920
921	largeModelCfg := cfg.Models[config.SelectedModelTypeLarge]
922	var largeModelProviderCfg config.ProviderConfig
923	for p := range cfg.Providers.Seq() {
924		if p.ID == largeModelCfg.Provider {
925			largeModelProviderCfg = p
926			break
927		}
928	}
929	if largeModelProviderCfg.ID == "" {
930		return fmt.Errorf("provider %s not found in config", largeModelCfg.Provider)
931	}
932
933	// Recreate title provider
934	titleOpts := []provider.ProviderClientOption{
935		provider.WithModel(config.SelectedModelTypeSmall),
936		provider.WithSystemMessage(prompt.GetPrompt(prompt.PromptTitle, smallModelProviderCfg.ID)),
937		provider.WithMaxTokens(40),
938	}
939	newTitleProvider, err := provider.NewProvider(smallModelProviderCfg, titleOpts...)
940	if err != nil {
941		return fmt.Errorf("failed to create new title provider: %w", err)
942	}
943	a.titleProvider = newTitleProvider
944
945	// Recreate summarize provider if provider changed (now large model)
946	if string(largeModelProviderCfg.ID) != a.summarizeProviderID {
947		largeModel := cfg.GetModelByType(config.SelectedModelTypeLarge)
948		if largeModel == nil {
949			return fmt.Errorf("model %s not found in provider %s", largeModelCfg.Model, largeModelProviderCfg.ID)
950		}
951		summarizeOpts := []provider.ProviderClientOption{
952			provider.WithModel(config.SelectedModelTypeLarge),
953			provider.WithSystemMessage(prompt.GetPrompt(prompt.PromptSummarizer, largeModelProviderCfg.ID)),
954		}
955		newSummarizeProvider, err := provider.NewProvider(largeModelProviderCfg, summarizeOpts...)
956		if err != nil {
957			return fmt.Errorf("failed to create new summarize provider: %w", err)
958		}
959		a.summarizeProvider = newSummarizeProvider
960		a.summarizeProviderID = string(largeModelProviderCfg.ID)
961	}
962
963	return nil
964}