1package agent
  2
  3import (
  4	"context"
  5	"errors"
  6	"fmt"
  7	"log/slog"
  8	"slices"
  9	"strings"
 10	"time"
 11
 12	"github.com/charmbracelet/catwalk/pkg/catwalk"
 13	"github.com/charmbracelet/crush/internal/config"
 14	"github.com/charmbracelet/crush/internal/csync"
 15	"github.com/charmbracelet/crush/internal/history"
 16	"github.com/charmbracelet/crush/internal/llm/prompt"
 17	"github.com/charmbracelet/crush/internal/llm/provider"
 18	"github.com/charmbracelet/crush/internal/llm/tools"
 19	"github.com/charmbracelet/crush/internal/log"
 20	"github.com/charmbracelet/crush/internal/lsp"
 21	"github.com/charmbracelet/crush/internal/message"
 22	"github.com/charmbracelet/crush/internal/permission"
 23	"github.com/charmbracelet/crush/internal/pubsub"
 24	"github.com/charmbracelet/crush/internal/session"
 25	"github.com/charmbracelet/crush/internal/shell"
 26)
 27
// Sentinel errors reported by the agent service.
var (
	// ErrRequestCancelled is the error carried by the resulting AgentEvent
	// when a generation request ends because it was cancelled.
	ErrRequestCancelled = errors.New("request canceled by user")
	// ErrSessionBusy is returned by Run and Summarize when the session
	// already has an active request.
	ErrSessionBusy      = errors.New("session is currently processing another request")
)
 33
// AgentEventType identifies the kind of an AgentEvent.
type AgentEventType string

// Event kinds stored in AgentEvent.Type.
const (
	// AgentEventTypeError marks an event whose Error field is set.
	AgentEventTypeError     AgentEventType = "error"
	// AgentEventTypeResponse marks the final response of a generation run.
	AgentEventTypeResponse  AgentEventType = "response"
	// AgentEventTypeSummarize marks progress updates published by Summarize.
	AgentEventTypeSummarize AgentEventType = "summarize"
)
 41
// AgentEvent is the value the agent publishes to subscribers and sends on
// the channel returned by Run.
type AgentEvent struct {
	// Type selects which of the remaining fields are meaningful.
	Type    AgentEventType
	// Message is the assistant message of a response event.
	Message message.Message
	// Error is set on error events.
	Error   error

	// When summarizing
	// SessionID is set on the final summarize event; Progress carries a
	// human-readable status line; Done marks the last event of an operation.
	SessionID string
	Progress  string
	Done      bool
}
 52
// Service is the interface of the agent returned by NewAgent.
type Service interface {
	pubsub.Suscriber[AgentEvent]
	// Model returns the model configured for this agent.
	Model() catwalk.Model
	// Run starts a generation for the session in the background and returns
	// a channel on which the final event is delivered. It returns
	// ErrSessionBusy when the session already has an active request.
	Run(ctx context.Context, sessionID string, content string, attachments ...message.Attachment) (<-chan AgentEvent, error)
	// Cancel cancels the active generation and summarization of a session.
	Cancel(sessionID string)
	// CancelAll cancels every active request and waits a bounded time for
	// them to finish.
	CancelAll()
	// IsSessionBusy reports whether the session has an active generation.
	IsSessionBusy(sessionID string) bool
	// IsBusy reports whether any request is active.
	IsBusy() bool
	// Summarize starts summarizing the session in the background; progress
	// and errors are published as events.
	Summarize(ctx context.Context, sessionID string) error
	// UpdateModel rebuilds the providers whose configuration changed.
	UpdateModel() error
}
 64
// agent is the Service implementation created by NewAgent.
type agent struct {
	// Broker provides the publish/subscribe side of the service.
	*pubsub.Broker[AgentEvent]
	agentCfg config.Agent
	sessions session.Service
	messages message.Service
	// NOTE(review): mcpTools is not assigned anywhere in the visible part of
	// this file; confirm whether it is still needed.
	mcpTools []McpTool

	// tools is built lazily from the tool factory set up in NewAgent.
	tools *csync.LazySlice[tools.BaseTool]

	// provider serves the main generation requests; providerID is the ID of
	// its provider configuration.
	provider   provider.Provider
	providerID string

	// titleProvider and summarizeProvider are both created from the small
	// model's provider configuration, whose ID is summarizeProviderID.
	titleProvider       provider.Provider
	summarizeProvider   provider.Provider
	summarizeProviderID string

	// activeRequests maps a session ID (or sessionID+"-summarize") to the
	// cancel function of the request currently running for it.
	activeRequests *csync.Map[string, context.CancelFunc]
}
 83
// agentPromptMap maps an agent ID to the system prompt used for it. Agent
// IDs without an entry fall back to prompt.PromptDefault at the lookup sites.
var agentPromptMap = map[string]prompt.PromptID{
	"coder": prompt.PromptCoder,
	"task":  prompt.PromptTask,
}
 88
 89func NewAgent(
 90	ctx context.Context,
 91	agentCfg config.Agent,
 92	// These services are needed in the tools
 93	permissions permission.Service,
 94	sessions session.Service,
 95	messages message.Service,
 96	history history.Service,
 97	lspClients map[string]*lsp.Client,
 98) (Service, error) {
 99	cfg := config.Get()
100
101	var agentTool tools.BaseTool
102	if agentCfg.ID == "coder" {
103		taskAgentCfg := config.Get().Agents["task"]
104		if taskAgentCfg.ID == "" {
105			return nil, fmt.Errorf("task agent not found in config")
106		}
107		taskAgent, err := NewAgent(ctx, taskAgentCfg, permissions, sessions, messages, history, lspClients)
108		if err != nil {
109			return nil, fmt.Errorf("failed to create task agent: %w", err)
110		}
111
112		agentTool = NewAgentTool(taskAgent, sessions, messages)
113	}
114
115	providerCfg := config.Get().GetProviderForModel(agentCfg.Model)
116	if providerCfg == nil {
117		return nil, fmt.Errorf("provider for agent %s not found in config", agentCfg.Name)
118	}
119	model := config.Get().GetModelByType(agentCfg.Model)
120
121	if model == nil {
122		return nil, fmt.Errorf("model not found for agent %s", agentCfg.Name)
123	}
124
125	promptID := agentPromptMap[agentCfg.ID]
126	if promptID == "" {
127		promptID = prompt.PromptDefault
128	}
129	opts := []provider.ProviderClientOption{
130		provider.WithModel(agentCfg.Model),
131		provider.WithSystemMessage(prompt.GetPrompt(promptID, providerCfg.ID, config.Get().Options.ContextPaths...)),
132	}
133	agentProvider, err := provider.NewProvider(*providerCfg, opts...)
134	if err != nil {
135		return nil, err
136	}
137
138	smallModelCfg := cfg.Models[config.SelectedModelTypeSmall]
139	var smallModelProviderCfg *config.ProviderConfig
140	if smallModelCfg.Provider == providerCfg.ID {
141		smallModelProviderCfg = providerCfg
142	} else {
143		smallModelProviderCfg = cfg.GetProviderForModel(config.SelectedModelTypeSmall)
144
145		if smallModelProviderCfg.ID == "" {
146			return nil, fmt.Errorf("provider %s not found in config", smallModelCfg.Provider)
147		}
148	}
149	smallModel := cfg.GetModelByType(config.SelectedModelTypeSmall)
150	if smallModel.ID == "" {
151		return nil, fmt.Errorf("model %s not found in provider %s", smallModelCfg.Model, smallModelProviderCfg.ID)
152	}
153
154	titleOpts := []provider.ProviderClientOption{
155		provider.WithModel(config.SelectedModelTypeSmall),
156		provider.WithSystemMessage(prompt.GetPrompt(prompt.PromptTitle, smallModelProviderCfg.ID)),
157	}
158	titleProvider, err := provider.NewProvider(*smallModelProviderCfg, titleOpts...)
159	if err != nil {
160		return nil, err
161	}
162	summarizeOpts := []provider.ProviderClientOption{
163		provider.WithModel(config.SelectedModelTypeSmall),
164		provider.WithSystemMessage(prompt.GetPrompt(prompt.PromptSummarizer, smallModelProviderCfg.ID)),
165	}
166	summarizeProvider, err := provider.NewProvider(*smallModelProviderCfg, summarizeOpts...)
167	if err != nil {
168		return nil, err
169	}
170
171	toolFn := func() []tools.BaseTool {
172		slog.Info("Initializing agent tools", "agent", agentCfg.ID)
173		defer func() {
174			slog.Info("Initialized agent tools", "agent", agentCfg.ID)
175		}()
176
177		cwd := cfg.WorkingDir()
178		allTools := []tools.BaseTool{
179			tools.NewBashTool(permissions, cwd),
180			tools.NewDownloadTool(permissions, cwd),
181			tools.NewEditTool(lspClients, permissions, history, cwd),
182			tools.NewMultiEditTool(lspClients, permissions, history, cwd),
183			tools.NewFetchTool(permissions, cwd),
184			tools.NewGlobTool(cwd),
185			tools.NewGrepTool(cwd),
186			tools.NewLsTool(permissions, cwd),
187			tools.NewSourcegraphTool(),
188			tools.NewViewTool(lspClients, permissions, cwd),
189			tools.NewWriteTool(lspClients, permissions, history, cwd),
190		}
191
192		mcpToolsOnce.Do(func() {
193			mcpTools = doGetMCPTools(ctx, permissions, cfg)
194		})
195		allTools = append(allTools, mcpTools...)
196
197		if len(lspClients) > 0 {
198			allTools = append(allTools, tools.NewDiagnosticsTool(lspClients))
199		}
200
201		if agentTool != nil {
202			allTools = append(allTools, agentTool)
203		}
204
205		if agentCfg.AllowedTools == nil {
206			return allTools
207		}
208
209		var filteredTools []tools.BaseTool
210		for _, tool := range allTools {
211			if slices.Contains(agentCfg.AllowedTools, tool.Name()) {
212				filteredTools = append(filteredTools, tool)
213			}
214		}
215		return filteredTools
216	}
217
218	return &agent{
219		Broker:              pubsub.NewBroker[AgentEvent](),
220		agentCfg:            agentCfg,
221		provider:            agentProvider,
222		providerID:          string(providerCfg.ID),
223		messages:            messages,
224		sessions:            sessions,
225		titleProvider:       titleProvider,
226		summarizeProvider:   summarizeProvider,
227		summarizeProviderID: string(smallModelProviderCfg.ID),
228		activeRequests:      csync.NewMap[string, context.CancelFunc](),
229		tools:               csync.NewLazySlice(toolFn),
230	}, nil
231}
232
233func (a *agent) Model() catwalk.Model {
234	return *config.Get().GetModelByType(a.agentCfg.Model)
235}
236
237func (a *agent) Cancel(sessionID string) {
238	// Cancel regular requests
239	if cancel, ok := a.activeRequests.Take(sessionID); ok && cancel != nil {
240		slog.Info("Request cancellation initiated", "session_id", sessionID)
241		cancel()
242	}
243
244	// Also check for summarize requests
245	if cancel, ok := a.activeRequests.Take(sessionID + "-summarize"); ok && cancel != nil {
246		slog.Info("Summarize cancellation initiated", "session_id", sessionID)
247		cancel()
248	}
249}
250
251func (a *agent) IsBusy() bool {
252	var busy bool
253	for cancelFunc := range a.activeRequests.Seq() {
254		if cancelFunc != nil {
255			busy = true
256			break
257		}
258	}
259	return busy
260}
261
262func (a *agent) IsSessionBusy(sessionID string) bool {
263	_, busy := a.activeRequests.Get(sessionID)
264	return busy
265}
266
267func (a *agent) generateTitle(ctx context.Context, sessionID string, content string) error {
268	if content == "" {
269		return nil
270	}
271	if a.titleProvider == nil {
272		return nil
273	}
274	session, err := a.sessions.Get(ctx, sessionID)
275	if err != nil {
276		return err
277	}
278	parts := []message.ContentPart{message.TextContent{
279		Text: fmt.Sprintf("Generate a concise title for the following content:\n\n%s", content),
280	}}
281
282	// Use streaming approach like summarization
283	response := a.titleProvider.StreamResponse(
284		ctx,
285		[]message.Message{
286			{
287				Role:  message.User,
288				Parts: parts,
289			},
290		},
291		nil,
292	)
293
294	var finalResponse *provider.ProviderResponse
295	for r := range response {
296		if r.Error != nil {
297			return r.Error
298		}
299		finalResponse = r.Response
300	}
301
302	if finalResponse == nil {
303		return fmt.Errorf("no response received from title provider")
304	}
305
306	title := strings.TrimSpace(strings.ReplaceAll(finalResponse.Content, "\n", " "))
307	if title == "" {
308		return nil
309	}
310
311	session.Title = title
312	_, err = a.sessions.Save(ctx, session)
313	return err
314}
315
316func (a *agent) err(err error) AgentEvent {
317	return AgentEvent{
318		Type:  AgentEventTypeError,
319		Error: err,
320	}
321}
322
323func (a *agent) Run(ctx context.Context, sessionID string, content string, attachments ...message.Attachment) (<-chan AgentEvent, error) {
324	if !a.Model().SupportsImages && attachments != nil {
325		attachments = nil
326	}
327	events := make(chan AgentEvent)
328	if a.IsSessionBusy(sessionID) {
329		return nil, ErrSessionBusy
330	}
331
332	genCtx, cancel := context.WithCancel(ctx)
333
334	a.activeRequests.Set(sessionID, cancel)
335	go func() {
336		slog.Debug("Request started", "sessionID", sessionID)
337		defer log.RecoverPanic("agent.Run", func() {
338			events <- a.err(fmt.Errorf("panic while running the agent"))
339		})
340		var attachmentParts []message.ContentPart
341		for _, attachment := range attachments {
342			attachmentParts = append(attachmentParts, message.BinaryContent{Path: attachment.FilePath, MIMEType: attachment.MimeType, Data: attachment.Content})
343		}
344		result := a.processGeneration(genCtx, sessionID, content, attachmentParts)
345		if result.Error != nil && !errors.Is(result.Error, ErrRequestCancelled) && !errors.Is(result.Error, context.Canceled) {
346			slog.Error(result.Error.Error())
347		}
348		slog.Debug("Request completed", "sessionID", sessionID)
349		a.activeRequests.Del(sessionID)
350		cancel()
351		a.Publish(pubsub.CreatedEvent, result)
352		events <- result
353		close(events)
354	}()
355	return events, nil
356}
357
358func (a *agent) processGeneration(ctx context.Context, sessionID, content string, attachmentParts []message.ContentPart) AgentEvent {
359	cfg := config.Get()
360	// List existing messages; if none, start title generation asynchronously.
361	msgs, err := a.messages.List(ctx, sessionID)
362	if err != nil {
363		return a.err(fmt.Errorf("failed to list messages: %w", err))
364	}
365	if len(msgs) == 0 {
366		go func() {
367			defer log.RecoverPanic("agent.Run", func() {
368				slog.Error("panic while generating title")
369			})
370			titleErr := a.generateTitle(context.Background(), sessionID, content)
371			if titleErr != nil && !errors.Is(titleErr, context.Canceled) && !errors.Is(titleErr, context.DeadlineExceeded) {
372				slog.Error("failed to generate title", "error", titleErr)
373			}
374		}()
375	}
376	session, err := a.sessions.Get(ctx, sessionID)
377	if err != nil {
378		return a.err(fmt.Errorf("failed to get session: %w", err))
379	}
380	if session.SummaryMessageID != "" {
381		summaryMsgInex := -1
382		for i, msg := range msgs {
383			if msg.ID == session.SummaryMessageID {
384				summaryMsgInex = i
385				break
386			}
387		}
388		if summaryMsgInex != -1 {
389			msgs = msgs[summaryMsgInex:]
390			msgs[0].Role = message.User
391		}
392	}
393
394	userMsg, err := a.createUserMessage(ctx, sessionID, content, attachmentParts)
395	if err != nil {
396		return a.err(fmt.Errorf("failed to create user message: %w", err))
397	}
398	// Append the new user message to the conversation history.
399	msgHistory := append(msgs, userMsg)
400
401	for {
402		// Check for cancellation before each iteration
403		select {
404		case <-ctx.Done():
405			return a.err(ctx.Err())
406		default:
407			// Continue processing
408		}
409		agentMessage, toolResults, err := a.streamAndHandleEvents(ctx, sessionID, msgHistory)
410		if err != nil {
411			if errors.Is(err, context.Canceled) {
412				agentMessage.AddFinish(message.FinishReasonCanceled, "Request cancelled", "")
413				a.messages.Update(context.Background(), agentMessage)
414				return a.err(ErrRequestCancelled)
415			}
416			return a.err(fmt.Errorf("failed to process events: %w", err))
417		}
418		if cfg.Options.Debug {
419			slog.Info("Result", "message", agentMessage.FinishReason(), "toolResults", toolResults)
420		}
421		if (agentMessage.FinishReason() == message.FinishReasonToolUse) && toolResults != nil {
422			// We are not done, we need to respond with the tool response
423			msgHistory = append(msgHistory, agentMessage, *toolResults)
424			continue
425		}
426		if agentMessage.FinishReason() == "" {
427			// Kujtim: could not track down where this is happening but this means its cancelled
428			agentMessage.AddFinish(message.FinishReasonCanceled, "Request cancelled", "")
429			_ = a.messages.Update(context.Background(), agentMessage)
430			return a.err(ErrRequestCancelled)
431		}
432		return AgentEvent{
433			Type:    AgentEventTypeResponse,
434			Message: agentMessage,
435			Done:    true,
436		}
437	}
438}
439
440func (a *agent) createUserMessage(ctx context.Context, sessionID, content string, attachmentParts []message.ContentPart) (message.Message, error) {
441	parts := []message.ContentPart{message.TextContent{Text: content}}
442	parts = append(parts, attachmentParts...)
443	return a.messages.Create(ctx, sessionID, message.CreateMessageParams{
444		Role:  message.User,
445		Parts: parts,
446	})
447}
448
449func (a *agent) streamAndHandleEvents(ctx context.Context, sessionID string, msgHistory []message.Message) (message.Message, *message.Message, error) {
450	ctx = context.WithValue(ctx, tools.SessionIDContextKey, sessionID)
451	eventChan := a.provider.StreamResponse(ctx, msgHistory, slices.Collect(a.tools.Seq()))
452
453	assistantMsg, err := a.messages.Create(ctx, sessionID, message.CreateMessageParams{
454		Role:     message.Assistant,
455		Parts:    []message.ContentPart{},
456		Model:    a.Model().ID,
457		Provider: a.providerID,
458	})
459	if err != nil {
460		return assistantMsg, nil, fmt.Errorf("failed to create assistant message: %w", err)
461	}
462
463	// Add the session and message ID into the context if needed by tools.
464	ctx = context.WithValue(ctx, tools.MessageIDContextKey, assistantMsg.ID)
465
466	// Process each event in the stream.
467	for event := range eventChan {
468		if processErr := a.processEvent(ctx, sessionID, &assistantMsg, event); processErr != nil {
469			if errors.Is(processErr, context.Canceled) {
470				a.finishMessage(context.Background(), &assistantMsg, message.FinishReasonCanceled, "Request cancelled", "")
471			} else {
472				a.finishMessage(ctx, &assistantMsg, message.FinishReasonError, "API Error", processErr.Error())
473			}
474			return assistantMsg, nil, processErr
475		}
476		if ctx.Err() != nil {
477			a.finishMessage(context.Background(), &assistantMsg, message.FinishReasonCanceled, "Request cancelled", "")
478			return assistantMsg, nil, ctx.Err()
479		}
480	}
481
482	toolResults := make([]message.ToolResult, len(assistantMsg.ToolCalls()))
483	toolCalls := assistantMsg.ToolCalls()
484	for i, toolCall := range toolCalls {
485		select {
486		case <-ctx.Done():
487			a.finishMessage(context.Background(), &assistantMsg, message.FinishReasonCanceled, "Request cancelled", "")
488			// Make all future tool calls cancelled
489			for j := i; j < len(toolCalls); j++ {
490				toolResults[j] = message.ToolResult{
491					ToolCallID: toolCalls[j].ID,
492					Content:    "Tool execution canceled by user",
493					IsError:    true,
494				}
495			}
496			goto out
497		default:
498			// Continue processing
499			var tool tools.BaseTool
500			for availableTool := range a.tools.Seq() {
501				if availableTool.Info().Name == toolCall.Name {
502					tool = availableTool
503					break
504				}
505			}
506
507			// Tool not found
508			if tool == nil {
509				toolResults[i] = message.ToolResult{
510					ToolCallID: toolCall.ID,
511					Content:    fmt.Sprintf("Tool not found: %s", toolCall.Name),
512					IsError:    true,
513				}
514				continue
515			}
516
517			// Run tool in goroutine to allow cancellation
518			type toolExecResult struct {
519				response tools.ToolResponse
520				err      error
521			}
522			resultChan := make(chan toolExecResult, 1)
523
524			go func() {
525				response, err := tool.Run(ctx, tools.ToolCall{
526					ID:    toolCall.ID,
527					Name:  toolCall.Name,
528					Input: toolCall.Input,
529				})
530				resultChan <- toolExecResult{response: response, err: err}
531			}()
532
533			var toolResponse tools.ToolResponse
534			var toolErr error
535
536			select {
537			case <-ctx.Done():
538				a.finishMessage(context.Background(), &assistantMsg, message.FinishReasonCanceled, "Request cancelled", "")
539				// Mark remaining tool calls as cancelled
540				for j := i; j < len(toolCalls); j++ {
541					toolResults[j] = message.ToolResult{
542						ToolCallID: toolCalls[j].ID,
543						Content:    "Tool execution canceled by user",
544						IsError:    true,
545					}
546				}
547				goto out
548			case result := <-resultChan:
549				toolResponse = result.response
550				toolErr = result.err
551			}
552
553			if toolErr != nil {
554				slog.Error("Tool execution error", "toolCall", toolCall.ID, "error", toolErr)
555				if errors.Is(toolErr, permission.ErrorPermissionDenied) {
556					toolResults[i] = message.ToolResult{
557						ToolCallID: toolCall.ID,
558						Content:    "Permission denied",
559						IsError:    true,
560					}
561					for j := i + 1; j < len(toolCalls); j++ {
562						toolResults[j] = message.ToolResult{
563							ToolCallID: toolCalls[j].ID,
564							Content:    "Tool execution canceled by user",
565							IsError:    true,
566						}
567					}
568					a.finishMessage(ctx, &assistantMsg, message.FinishReasonPermissionDenied, "Permission denied", "")
569					break
570				}
571			}
572			toolResults[i] = message.ToolResult{
573				ToolCallID: toolCall.ID,
574				Content:    toolResponse.Content,
575				Metadata:   toolResponse.Metadata,
576				IsError:    toolResponse.IsError,
577			}
578		}
579	}
580out:
581	if len(toolResults) == 0 {
582		return assistantMsg, nil, nil
583	}
584	parts := make([]message.ContentPart, 0)
585	for _, tr := range toolResults {
586		parts = append(parts, tr)
587	}
588	msg, err := a.messages.Create(context.Background(), assistantMsg.SessionID, message.CreateMessageParams{
589		Role:     message.Tool,
590		Parts:    parts,
591		Provider: a.providerID,
592	})
593	if err != nil {
594		return assistantMsg, nil, fmt.Errorf("failed to create cancelled tool message: %w", err)
595	}
596
597	return assistantMsg, &msg, err
598}
599
600func (a *agent) finishMessage(ctx context.Context, msg *message.Message, finishReason message.FinishReason, message, details string) {
601	msg.AddFinish(finishReason, message, details)
602	_ = a.messages.Update(ctx, *msg)
603}
604
605func (a *agent) processEvent(ctx context.Context, sessionID string, assistantMsg *message.Message, event provider.ProviderEvent) error {
606	select {
607	case <-ctx.Done():
608		return ctx.Err()
609	default:
610		// Continue processing.
611	}
612
613	switch event.Type {
614	case provider.EventThinkingDelta:
615		assistantMsg.AppendReasoningContent(event.Thinking)
616		return a.messages.Update(ctx, *assistantMsg)
617	case provider.EventSignatureDelta:
618		assistantMsg.AppendReasoningSignature(event.Signature)
619		return a.messages.Update(ctx, *assistantMsg)
620	case provider.EventContentDelta:
621		assistantMsg.FinishThinking()
622		assistantMsg.AppendContent(event.Content)
623		return a.messages.Update(ctx, *assistantMsg)
624	case provider.EventToolUseStart:
625		assistantMsg.FinishThinking()
626		slog.Info("Tool call started", "toolCall", event.ToolCall)
627		assistantMsg.AddToolCall(*event.ToolCall)
628		return a.messages.Update(ctx, *assistantMsg)
629	case provider.EventToolUseDelta:
630		assistantMsg.AppendToolCallInput(event.ToolCall.ID, event.ToolCall.Input)
631		return a.messages.Update(ctx, *assistantMsg)
632	case provider.EventToolUseStop:
633		slog.Info("Finished tool call", "toolCall", event.ToolCall)
634		assistantMsg.FinishToolCall(event.ToolCall.ID)
635		return a.messages.Update(ctx, *assistantMsg)
636	case provider.EventError:
637		return event.Error
638	case provider.EventComplete:
639		assistantMsg.FinishThinking()
640		assistantMsg.SetToolCalls(event.Response.ToolCalls)
641		assistantMsg.AddFinish(event.Response.FinishReason, "", "")
642		if err := a.messages.Update(ctx, *assistantMsg); err != nil {
643			return fmt.Errorf("failed to update message: %w", err)
644		}
645		return a.TrackUsage(ctx, sessionID, a.Model(), event.Response.Usage)
646	}
647
648	return nil
649}
650
651func (a *agent) TrackUsage(ctx context.Context, sessionID string, model catwalk.Model, usage provider.TokenUsage) error {
652	sess, err := a.sessions.Get(ctx, sessionID)
653	if err != nil {
654		return fmt.Errorf("failed to get session: %w", err)
655	}
656
657	cost := model.CostPer1MInCached/1e6*float64(usage.CacheCreationTokens) +
658		model.CostPer1MOutCached/1e6*float64(usage.CacheReadTokens) +
659		model.CostPer1MIn/1e6*float64(usage.InputTokens) +
660		model.CostPer1MOut/1e6*float64(usage.OutputTokens)
661
662	sess.Cost += cost
663	sess.CompletionTokens = usage.OutputTokens + usage.CacheReadTokens
664	sess.PromptTokens = usage.InputTokens + usage.CacheCreationTokens
665
666	_, err = a.sessions.Save(ctx, sess)
667	if err != nil {
668		return fmt.Errorf("failed to save session: %w", err)
669	}
670	return nil
671}
672
673func (a *agent) Summarize(ctx context.Context, sessionID string) error {
674	if a.summarizeProvider == nil {
675		return fmt.Errorf("summarize provider not available")
676	}
677
678	// Check if session is busy
679	if a.IsSessionBusy(sessionID) {
680		return ErrSessionBusy
681	}
682
683	// Create a new context with cancellation
684	summarizeCtx, cancel := context.WithCancel(ctx)
685
686	// Store the cancel function in activeRequests to allow cancellation
687	a.activeRequests.Set(sessionID+"-summarize", cancel)
688
689	go func() {
690		defer a.activeRequests.Del(sessionID + "-summarize")
691		defer cancel()
692		event := AgentEvent{
693			Type:     AgentEventTypeSummarize,
694			Progress: "Starting summarization...",
695		}
696
697		a.Publish(pubsub.CreatedEvent, event)
698		// Get all messages from the session
699		msgs, err := a.messages.List(summarizeCtx, sessionID)
700		if err != nil {
701			event = AgentEvent{
702				Type:  AgentEventTypeError,
703				Error: fmt.Errorf("failed to list messages: %w", err),
704				Done:  true,
705			}
706			a.Publish(pubsub.CreatedEvent, event)
707			return
708		}
709		summarizeCtx = context.WithValue(summarizeCtx, tools.SessionIDContextKey, sessionID)
710
711		if len(msgs) == 0 {
712			event = AgentEvent{
713				Type:  AgentEventTypeError,
714				Error: fmt.Errorf("no messages to summarize"),
715				Done:  true,
716			}
717			a.Publish(pubsub.CreatedEvent, event)
718			return
719		}
720
721		event = AgentEvent{
722			Type:     AgentEventTypeSummarize,
723			Progress: "Analyzing conversation...",
724		}
725		a.Publish(pubsub.CreatedEvent, event)
726
727		// Add a system message to guide the summarization
728		summarizePrompt := "Provide a detailed but concise summary of our conversation above. Focus on information that would be helpful for continuing the conversation, including what we did, what we're doing, which files we're working on, and what we're going to do next."
729
730		// Create a new message with the summarize prompt
731		promptMsg := message.Message{
732			Role:  message.User,
733			Parts: []message.ContentPart{message.TextContent{Text: summarizePrompt}},
734		}
735
736		// Append the prompt to the messages
737		msgsWithPrompt := append(msgs, promptMsg)
738
739		event = AgentEvent{
740			Type:     AgentEventTypeSummarize,
741			Progress: "Generating summary...",
742		}
743
744		a.Publish(pubsub.CreatedEvent, event)
745
746		// Send the messages to the summarize provider
747		response := a.summarizeProvider.StreamResponse(
748			summarizeCtx,
749			msgsWithPrompt,
750			nil,
751		)
752		var finalResponse *provider.ProviderResponse
753		for r := range response {
754			if r.Error != nil {
755				event = AgentEvent{
756					Type:  AgentEventTypeError,
757					Error: fmt.Errorf("failed to summarize: %w", err),
758					Done:  true,
759				}
760				a.Publish(pubsub.CreatedEvent, event)
761				return
762			}
763			finalResponse = r.Response
764		}
765
766		summary := strings.TrimSpace(finalResponse.Content)
767		if summary == "" {
768			event = AgentEvent{
769				Type:  AgentEventTypeError,
770				Error: fmt.Errorf("empty summary returned"),
771				Done:  true,
772			}
773			a.Publish(pubsub.CreatedEvent, event)
774			return
775		}
776		shell := shell.GetPersistentShell(config.Get().WorkingDir())
777		summary += "\n\n**Current working directory of the persistent shell**\n\n" + shell.GetWorkingDir()
778		event = AgentEvent{
779			Type:     AgentEventTypeSummarize,
780			Progress: "Creating new session...",
781		}
782
783		a.Publish(pubsub.CreatedEvent, event)
784		oldSession, err := a.sessions.Get(summarizeCtx, sessionID)
785		if err != nil {
786			event = AgentEvent{
787				Type:  AgentEventTypeError,
788				Error: fmt.Errorf("failed to get session: %w", err),
789				Done:  true,
790			}
791
792			a.Publish(pubsub.CreatedEvent, event)
793			return
794		}
795		// Create a message in the new session with the summary
796		msg, err := a.messages.Create(summarizeCtx, oldSession.ID, message.CreateMessageParams{
797			Role: message.Assistant,
798			Parts: []message.ContentPart{
799				message.TextContent{Text: summary},
800				message.Finish{
801					Reason: message.FinishReasonEndTurn,
802					Time:   time.Now().Unix(),
803				},
804			},
805			Model:    a.summarizeProvider.Model().ID,
806			Provider: a.summarizeProviderID,
807		})
808		if err != nil {
809			event = AgentEvent{
810				Type:  AgentEventTypeError,
811				Error: fmt.Errorf("failed to create summary message: %w", err),
812				Done:  true,
813			}
814
815			a.Publish(pubsub.CreatedEvent, event)
816			return
817		}
818		oldSession.SummaryMessageID = msg.ID
819		oldSession.CompletionTokens = finalResponse.Usage.OutputTokens
820		oldSession.PromptTokens = 0
821		model := a.summarizeProvider.Model()
822		usage := finalResponse.Usage
823		cost := model.CostPer1MInCached/1e6*float64(usage.CacheCreationTokens) +
824			model.CostPer1MOutCached/1e6*float64(usage.CacheReadTokens) +
825			model.CostPer1MIn/1e6*float64(usage.InputTokens) +
826			model.CostPer1MOut/1e6*float64(usage.OutputTokens)
827		oldSession.Cost += cost
828		_, err = a.sessions.Save(summarizeCtx, oldSession)
829		if err != nil {
830			event = AgentEvent{
831				Type:  AgentEventTypeError,
832				Error: fmt.Errorf("failed to save session: %w", err),
833				Done:  true,
834			}
835			a.Publish(pubsub.CreatedEvent, event)
836		}
837
838		event = AgentEvent{
839			Type:      AgentEventTypeSummarize,
840			SessionID: oldSession.ID,
841			Progress:  "Summary complete",
842			Done:      true,
843		}
844		a.Publish(pubsub.CreatedEvent, event)
845		// Send final success event with the new session ID
846	}()
847
848	return nil
849}
850
851func (a *agent) CancelAll() {
852	if !a.IsBusy() {
853		return
854	}
855	for key := range a.activeRequests.Seq2() {
856		a.Cancel(key) // key is sessionID
857	}
858
859	timeout := time.After(5 * time.Second)
860	for a.IsBusy() {
861		select {
862		case <-timeout:
863			return
864		default:
865			time.Sleep(200 * time.Millisecond)
866		}
867	}
868}
869
870func (a *agent) UpdateModel() error {
871	cfg := config.Get()
872
873	// Get current provider configuration
874	currentProviderCfg := cfg.GetProviderForModel(a.agentCfg.Model)
875	if currentProviderCfg == nil || currentProviderCfg.ID == "" {
876		return fmt.Errorf("provider for agent %s not found in config", a.agentCfg.Name)
877	}
878
879	// Check if provider has changed
880	if string(currentProviderCfg.ID) != a.providerID {
881		// Provider changed, need to recreate the main provider
882		model := cfg.GetModelByType(a.agentCfg.Model)
883		if model.ID == "" {
884			return fmt.Errorf("model not found for agent %s", a.agentCfg.Name)
885		}
886
887		promptID := agentPromptMap[a.agentCfg.ID]
888		if promptID == "" {
889			promptID = prompt.PromptDefault
890		}
891
892		opts := []provider.ProviderClientOption{
893			provider.WithModel(a.agentCfg.Model),
894			provider.WithSystemMessage(prompt.GetPrompt(promptID, currentProviderCfg.ID, cfg.Options.ContextPaths...)),
895		}
896
897		newProvider, err := provider.NewProvider(*currentProviderCfg, opts...)
898		if err != nil {
899			return fmt.Errorf("failed to create new provider: %w", err)
900		}
901
902		// Update the provider and provider ID
903		a.provider = newProvider
904		a.providerID = string(currentProviderCfg.ID)
905	}
906
907	// Check if small model provider has changed (affects title and summarize providers)
908	smallModelCfg := cfg.Models[config.SelectedModelTypeSmall]
909	var smallModelProviderCfg config.ProviderConfig
910
911	for p := range cfg.Providers.Seq() {
912		if p.ID == smallModelCfg.Provider {
913			smallModelProviderCfg = p
914			break
915		}
916	}
917
918	if smallModelProviderCfg.ID == "" {
919		return fmt.Errorf("provider %s not found in config", smallModelCfg.Provider)
920	}
921
922	// Check if summarize provider has changed
923	if string(smallModelProviderCfg.ID) != a.summarizeProviderID {
924		smallModel := cfg.GetModelByType(config.SelectedModelTypeSmall)
925		if smallModel == nil {
926			return fmt.Errorf("model %s not found in provider %s", smallModelCfg.Model, smallModelProviderCfg.ID)
927		}
928
929		// Recreate title provider
930		titleOpts := []provider.ProviderClientOption{
931			provider.WithModel(config.SelectedModelTypeSmall),
932			provider.WithSystemMessage(prompt.GetPrompt(prompt.PromptTitle, smallModelProviderCfg.ID)),
933			// We want the title to be short, so we limit the max tokens
934			provider.WithMaxTokens(40),
935		}
936		newTitleProvider, err := provider.NewProvider(smallModelProviderCfg, titleOpts...)
937		if err != nil {
938			return fmt.Errorf("failed to create new title provider: %w", err)
939		}
940
941		// Recreate summarize provider
942		summarizeOpts := []provider.ProviderClientOption{
943			provider.WithModel(config.SelectedModelTypeSmall),
944			provider.WithSystemMessage(prompt.GetPrompt(prompt.PromptSummarizer, smallModelProviderCfg.ID)),
945		}
946		newSummarizeProvider, err := provider.NewProvider(smallModelProviderCfg, summarizeOpts...)
947		if err != nil {
948			return fmt.Errorf("failed to create new summarize provider: %w", err)
949		}
950
951		// Update the providers and provider ID
952		a.titleProvider = newTitleProvider
953		a.summarizeProvider = newSummarizeProvider
954		a.summarizeProviderID = string(smallModelProviderCfg.ID)
955	}
956
957	return nil
958}