agent.go

  1package agent
  2
  3import (
  4	"context"
  5	"errors"
  6	"fmt"
  7	"log/slog"
  8	"slices"
  9	"strings"
 10	"sync"
 11	"sync/atomic"
 12	"time"
 13
 14	"github.com/charmbracelet/crush/internal/config"
 15	fur "github.com/charmbracelet/crush/internal/fur/provider"
 16	"github.com/charmbracelet/crush/internal/history"
 17	"github.com/charmbracelet/crush/internal/llm/prompt"
 18	"github.com/charmbracelet/crush/internal/llm/provider"
 19	"github.com/charmbracelet/crush/internal/llm/tools"
 20	"github.com/charmbracelet/crush/internal/log"
 21	"github.com/charmbracelet/crush/internal/lsp"
 22	"github.com/charmbracelet/crush/internal/message"
 23	"github.com/charmbracelet/crush/internal/permission"
 24	"github.com/charmbracelet/crush/internal/pubsub"
 25	"github.com/charmbracelet/crush/internal/session"
 26)
 27
 28// Common errors
 29var (
 30	ErrRequestCancelled = errors.New("request canceled by user")
 31	ErrSessionBusy      = errors.New("session is currently processing another request")
 32)
 33
 34type AgentEventType string
 35
 36const (
 37	AgentEventTypeError     AgentEventType = "error"
 38	AgentEventTypeResponse  AgentEventType = "response"
 39	AgentEventTypeSummarize AgentEventType = "summarize"
 40)
 41
// AgentEvent is an event produced by the agent. Run sends one on the channel
// it returns, and both Run and Summarize publish events through the agent's
// pubsub broker.
type AgentEvent struct {
	// Type selects which of the remaining fields are meaningful.
	Type    AgentEventType
	Message message.Message
	Error   error

	// Summarization fields: SessionID is set on the final summarize event and
	// Progress carries human-readable status. Done marks a terminal event; it
	// is also set on final response events and on summarization errors.
	SessionID string
	Progress  string
	Done      bool
}
 52
 53type Service interface {
 54	pubsub.Suscriber[AgentEvent]
 55	Model() fur.Model
 56	Run(ctx context.Context, sessionID string, content string, attachments ...message.Attachment) (<-chan AgentEvent, error)
 57	Cancel(sessionID string)
 58	CancelAll()
 59	IsSessionBusy(sessionID string) bool
 60	IsBusy() bool
 61	Summarize(ctx context.Context, sessionID string) error
 62	UpdateModel() error
 63}
 64
// agent is the Service implementation backed by LLM providers.
type agent struct {
	*pubsub.Broker[AgentEvent]
	agentCfg config.Agent
	sessions session.Service
	messages message.Service

	// tools is written once by the initialisation goroutine started in
	// NewAgent, and toolsDone is stored after that write. Readers must check
	// toolsDone before touching tools.
	toolsDone atomic.Bool
	tools     []tools.BaseTool

	// Main provider for the agent and the ID of the provider config it was built from.
	provider   provider.Provider
	providerID string

	// Providers built on the small model: one for session titles, one for
	// summaries; summarizeProviderID is the provider config ID they share.
	titleProvider       provider.Provider
	summarizeProvider   provider.Provider
	summarizeProviderID string

	// activeRequests maps a session ID (or "<sessionID>-summarize") to the
	// context.CancelFunc of the work currently running for it.
	activeRequests sync.Map
}
 83
 84var agentPromptMap = map[string]prompt.PromptID{
 85	"coder": prompt.PromptCoder,
 86	"task":  prompt.PromptTask,
 87}
 88
 89func NewAgent(
 90	agentCfg config.Agent,
 91	// These services are needed in the tools
 92	permissions permission.Service,
 93	sessions session.Service,
 94	messages message.Service,
 95	history history.Service,
 96	lspClients map[string]*lsp.Client,
 97) (Service, error) {
 98	ctx := context.Background()
 99	cfg := config.Get()
100
101	providerCfg := config.Get().GetProviderForModel(agentCfg.Model)
102	if providerCfg == nil {
103		return nil, fmt.Errorf("provider for agent %s not found in config", agentCfg.Name)
104	}
105	model := config.Get().GetModelByType(agentCfg.Model)
106
107	if model == nil {
108		return nil, fmt.Errorf("model not found for agent %s", agentCfg.Name)
109	}
110
111	promptID := agentPromptMap[agentCfg.ID]
112	if promptID == "" {
113		promptID = prompt.PromptDefault
114	}
115	opts := []provider.ProviderClientOption{
116		provider.WithModel(agentCfg.Model),
117		provider.WithSystemMessage(prompt.GetPrompt(promptID, providerCfg.ID, config.Get().Options.ContextPaths...)),
118	}
119	agentProvider, err := provider.NewProvider(*providerCfg, opts...)
120	if err != nil {
121		return nil, err
122	}
123
124	smallModelCfg := cfg.Models[config.SelectedModelTypeSmall]
125	var smallModelProviderCfg *config.ProviderConfig
126	if smallModelCfg.Provider == providerCfg.ID {
127		smallModelProviderCfg = providerCfg
128	} else {
129		smallModelProviderCfg = cfg.GetProviderForModel(config.SelectedModelTypeSmall)
130
131		if smallModelProviderCfg.ID == "" {
132			return nil, fmt.Errorf("provider %s not found in config", smallModelCfg.Provider)
133		}
134	}
135	smallModel := cfg.GetModelByType(config.SelectedModelTypeSmall)
136	if smallModel.ID == "" {
137		return nil, fmt.Errorf("model %s not found in provider %s", smallModelCfg.Model, smallModelProviderCfg.ID)
138	}
139
140	titleOpts := []provider.ProviderClientOption{
141		provider.WithModel(config.SelectedModelTypeSmall),
142		provider.WithSystemMessage(prompt.GetPrompt(prompt.PromptTitle, smallModelProviderCfg.ID)),
143	}
144	titleProvider, err := provider.NewProvider(*smallModelProviderCfg, titleOpts...)
145	if err != nil {
146		return nil, err
147	}
148	summarizeOpts := []provider.ProviderClientOption{
149		provider.WithModel(config.SelectedModelTypeSmall),
150		provider.WithSystemMessage(prompt.GetPrompt(prompt.PromptSummarizer, smallModelProviderCfg.ID)),
151	}
152	summarizeProvider, err := provider.NewProvider(*smallModelProviderCfg, summarizeOpts...)
153	if err != nil {
154		return nil, err
155	}
156
157	var agentTool tools.BaseTool
158	if agentCfg.ID == "coder" {
159		taskAgentCfg := config.Get().Agents["task"]
160		if taskAgentCfg.ID == "" {
161			return nil, fmt.Errorf("task agent not found in config")
162		}
163		taskAgent, err := NewAgent(taskAgentCfg, permissions, sessions, messages, history, lspClients)
164		if err != nil {
165			return nil, fmt.Errorf("failed to create task agent: %w", err)
166		}
167
168		agentTool = NewAgentTool(
169			taskAgent,
170			sessions,
171			messages,
172		)
173	}
174
175	agent := &agent{
176		Broker:              pubsub.NewBroker[AgentEvent](),
177		agentCfg:            agentCfg,
178		provider:            agentProvider,
179		providerID:          string(providerCfg.ID),
180		messages:            messages,
181		sessions:            sessions,
182		titleProvider:       titleProvider,
183		summarizeProvider:   summarizeProvider,
184		summarizeProviderID: string(smallModelProviderCfg.ID),
185		activeRequests:      sync.Map{},
186	}
187
188	go func() {
189		slog.Info("Initializing agent tools", "agent", agentCfg.ID)
190
191		cwd := cfg.WorkingDir()
192		allTools := []tools.BaseTool{
193			tools.NewBashTool(permissions, cwd),
194			tools.NewDownloadTool(permissions, cwd),
195			tools.NewEditTool(lspClients, permissions, history, cwd),
196			tools.NewFetchTool(permissions, cwd),
197			tools.NewGlobTool(cwd),
198			tools.NewGrepTool(cwd),
199			tools.NewLsTool(cwd),
200			tools.NewSourcegraphTool(),
201			tools.NewViewTool(lspClients, cwd),
202			tools.NewWriteTool(lspClients, permissions, history, cwd),
203		}
204
205		mcpTools := GetMCPTools(ctx, permissions, cfg)
206		if len(lspClients) > 0 {
207			mcpTools = append(mcpTools, tools.NewDiagnosticsTool(lspClients))
208		}
209		allTools = append(allTools, mcpTools...)
210
211		if agentTool != nil {
212			allTools = append(allTools, agentTool)
213		}
214
215		agentTools := []tools.BaseTool{}
216		if agentCfg.AllowedTools == nil {
217			agentTools = allTools
218		} else {
219			for _, tool := range allTools {
220				if slices.Contains(agentCfg.AllowedTools, tool.Name()) {
221					agentTools = append(agentTools, tool)
222				}
223			}
224		}
225
226		slog.Info("Initialized agent tools", "agent", agentCfg.ID)
227		agent.tools = agentTools
228		agent.toolsDone.Store(true)
229	}()
230
231	return agent, nil
232}
233
234func (a *agent) Model() fur.Model {
235	return *config.Get().GetModelByType(a.agentCfg.Model)
236}
237
238func (a *agent) Cancel(sessionID string) {
239	// Cancel regular requests
240	if cancelFunc, exists := a.activeRequests.LoadAndDelete(sessionID); exists {
241		if cancel, ok := cancelFunc.(context.CancelFunc); ok {
242			slog.Info(fmt.Sprintf("Request cancellation initiated for session: %s", sessionID))
243			cancel()
244		}
245	}
246
247	// Also check for summarize requests
248	if cancelFunc, exists := a.activeRequests.LoadAndDelete(sessionID + "-summarize"); exists {
249		if cancel, ok := cancelFunc.(context.CancelFunc); ok {
250			slog.Info(fmt.Sprintf("Summarize cancellation initiated for session: %s", sessionID))
251			cancel()
252		}
253	}
254}
255
256func (a *agent) IsBusy() bool {
257	busy := false
258	a.activeRequests.Range(func(key, value any) bool {
259		if cancelFunc, ok := value.(context.CancelFunc); ok {
260			if cancelFunc != nil {
261				busy = true
262				return false
263			}
264		}
265		return true
266	})
267	return busy
268}
269
270func (a *agent) IsSessionBusy(sessionID string) bool {
271	_, busy := a.activeRequests.Load(sessionID)
272	return busy
273}
274
275func (a *agent) generateTitle(ctx context.Context, sessionID string, content string) error {
276	if content == "" {
277		return nil
278	}
279	if a.titleProvider == nil {
280		return nil
281	}
282	session, err := a.sessions.Get(ctx, sessionID)
283	if err != nil {
284		return err
285	}
286	parts := []message.ContentPart{message.TextContent{
287		Text: fmt.Sprintf("Generate a concise title for the following content:\n\n%s", content),
288	}}
289
290	// Use streaming approach like summarization
291	response := a.titleProvider.StreamResponse(
292		ctx,
293		[]message.Message{
294			{
295				Role:  message.User,
296				Parts: parts,
297			},
298		},
299		make([]tools.BaseTool, 0),
300	)
301
302	var finalResponse *provider.ProviderResponse
303	for r := range response {
304		if r.Error != nil {
305			return r.Error
306		}
307		finalResponse = r.Response
308	}
309
310	if finalResponse == nil {
311		return fmt.Errorf("no response received from title provider")
312	}
313
314	title := strings.TrimSpace(strings.ReplaceAll(finalResponse.Content, "\n", " "))
315	if title == "" {
316		return nil
317	}
318
319	session.Title = title
320	_, err = a.sessions.Save(ctx, session)
321	return err
322}
323
324func (a *agent) err(err error) AgentEvent {
325	return AgentEvent{
326		Type:  AgentEventTypeError,
327		Error: err,
328	}
329}
330
331func (a *agent) Run(ctx context.Context, sessionID string, content string, attachments ...message.Attachment) (<-chan AgentEvent, error) {
332	if !a.Model().SupportsImages && attachments != nil {
333		attachments = nil
334	}
335	events := make(chan AgentEvent)
336	if a.IsSessionBusy(sessionID) {
337		return nil, ErrSessionBusy
338	}
339
340	genCtx, cancel := context.WithCancel(ctx)
341
342	a.activeRequests.Store(sessionID, cancel)
343	go func() {
344		slog.Debug("Request started", "sessionID", sessionID)
345		defer log.RecoverPanic("agent.Run", func() {
346			events <- a.err(fmt.Errorf("panic while running the agent"))
347		})
348		var attachmentParts []message.ContentPart
349		for _, attachment := range attachments {
350			attachmentParts = append(attachmentParts, message.BinaryContent{Path: attachment.FilePath, MIMEType: attachment.MimeType, Data: attachment.Content})
351		}
352		result := a.processGeneration(genCtx, sessionID, content, attachmentParts)
353		if result.Error != nil && !errors.Is(result.Error, ErrRequestCancelled) && !errors.Is(result.Error, context.Canceled) {
354			slog.Error(result.Error.Error())
355		}
356		slog.Debug("Request completed", "sessionID", sessionID)
357		a.activeRequests.Delete(sessionID)
358		cancel()
359		a.Publish(pubsub.CreatedEvent, result)
360		events <- result
361		close(events)
362	}()
363	return events, nil
364}
365
366func (a *agent) processGeneration(ctx context.Context, sessionID, content string, attachmentParts []message.ContentPart) AgentEvent {
367	cfg := config.Get()
368	// List existing messages; if none, start title generation asynchronously.
369	msgs, err := a.messages.List(ctx, sessionID)
370	if err != nil {
371		return a.err(fmt.Errorf("failed to list messages: %w", err))
372	}
373	if len(msgs) == 0 {
374		go func() {
375			defer log.RecoverPanic("agent.Run", func() {
376				slog.Error("panic while generating title")
377			})
378			titleErr := a.generateTitle(context.Background(), sessionID, content)
379			if titleErr != nil && !errors.Is(titleErr, context.Canceled) && !errors.Is(titleErr, context.DeadlineExceeded) {
380				slog.Error(fmt.Sprintf("failed to generate title: %v", titleErr))
381			}
382		}()
383	}
384	session, err := a.sessions.Get(ctx, sessionID)
385	if err != nil {
386		return a.err(fmt.Errorf("failed to get session: %w", err))
387	}
388	if session.SummaryMessageID != "" {
389		summaryMsgInex := -1
390		for i, msg := range msgs {
391			if msg.ID == session.SummaryMessageID {
392				summaryMsgInex = i
393				break
394			}
395		}
396		if summaryMsgInex != -1 {
397			msgs = msgs[summaryMsgInex:]
398			msgs[0].Role = message.User
399		}
400	}
401
402	userMsg, err := a.createUserMessage(ctx, sessionID, content, attachmentParts)
403	if err != nil {
404		return a.err(fmt.Errorf("failed to create user message: %w", err))
405	}
406	// Append the new user message to the conversation history.
407	msgHistory := append(msgs, userMsg)
408
409	for {
410		// Check for cancellation before each iteration
411		select {
412		case <-ctx.Done():
413			return a.err(ctx.Err())
414		default:
415			// Continue processing
416		}
417		agentMessage, toolResults, err := a.streamAndHandleEvents(ctx, sessionID, msgHistory)
418		if err != nil {
419			if errors.Is(err, context.Canceled) {
420				agentMessage.AddFinish(message.FinishReasonCanceled, "Request cancelled", "")
421				a.messages.Update(context.Background(), agentMessage)
422				return a.err(ErrRequestCancelled)
423			}
424			return a.err(fmt.Errorf("failed to process events: %w", err))
425		}
426		if cfg.Options.Debug {
427			slog.Info("Result", "message", agentMessage.FinishReason(), "toolResults", toolResults)
428		}
429		if (agentMessage.FinishReason() == message.FinishReasonToolUse) && toolResults != nil {
430			// We are not done, we need to respond with the tool response
431			msgHistory = append(msgHistory, agentMessage, *toolResults)
432			continue
433		}
434		return AgentEvent{
435			Type:    AgentEventTypeResponse,
436			Message: agentMessage,
437			Done:    true,
438		}
439	}
440}
441
442func (a *agent) createUserMessage(ctx context.Context, sessionID, content string, attachmentParts []message.ContentPart) (message.Message, error) {
443	parts := []message.ContentPart{message.TextContent{Text: content}}
444	parts = append(parts, attachmentParts...)
445	return a.messages.Create(ctx, sessionID, message.CreateMessageParams{
446		Role:  message.User,
447		Parts: parts,
448	})
449}
450
451func (a *agent) streamAndHandleEvents(ctx context.Context, sessionID string, msgHistory []message.Message) (message.Message, *message.Message, error) {
452	ctx = context.WithValue(ctx, tools.SessionIDContextKey, sessionID)
453	if !a.toolsDone.Load() {
454		return message.Message{}, nil, fmt.Errorf("tools not initialized yet")
455	}
456	eventChan := a.provider.StreamResponse(ctx, msgHistory, a.tools)
457
458	assistantMsg, err := a.messages.Create(ctx, sessionID, message.CreateMessageParams{
459		Role:     message.Assistant,
460		Parts:    []message.ContentPart{},
461		Model:    a.Model().ID,
462		Provider: a.providerID,
463	})
464	if err != nil {
465		return assistantMsg, nil, fmt.Errorf("failed to create assistant message: %w", err)
466	}
467
468	// Add the session and message ID into the context if needed by tools.
469	ctx = context.WithValue(ctx, tools.MessageIDContextKey, assistantMsg.ID)
470
471	// Process each event in the stream.
472	for event := range eventChan {
473		if processErr := a.processEvent(ctx, sessionID, &assistantMsg, event); processErr != nil {
474			if errors.Is(processErr, context.Canceled) {
475				a.finishMessage(context.Background(), &assistantMsg, message.FinishReasonCanceled, "Request cancelled", "")
476			} else {
477				a.finishMessage(ctx, &assistantMsg, message.FinishReasonError, "API Error", processErr.Error())
478			}
479			return assistantMsg, nil, processErr
480		}
481		if ctx.Err() != nil {
482			a.finishMessage(context.Background(), &assistantMsg, message.FinishReasonCanceled, "Request cancelled", "")
483			return assistantMsg, nil, ctx.Err()
484		}
485	}
486
487	toolResults := make([]message.ToolResult, len(assistantMsg.ToolCalls()))
488	toolCalls := assistantMsg.ToolCalls()
489	for i, toolCall := range toolCalls {
490		select {
491		case <-ctx.Done():
492			a.finishMessage(context.Background(), &assistantMsg, message.FinishReasonCanceled, "Request cancelled", "")
493			// Make all future tool calls cancelled
494			for j := i; j < len(toolCalls); j++ {
495				toolResults[j] = message.ToolResult{
496					ToolCallID: toolCalls[j].ID,
497					Content:    "Tool execution canceled by user",
498					IsError:    true,
499				}
500			}
501			goto out
502		default:
503			// Continue processing
504			var tool tools.BaseTool
505			for _, availableTool := range a.tools {
506				if availableTool.Info().Name == toolCall.Name {
507					tool = availableTool
508					break
509				}
510			}
511
512			// Tool not found
513			if tool == nil {
514				toolResults[i] = message.ToolResult{
515					ToolCallID: toolCall.ID,
516					Content:    fmt.Sprintf("Tool not found: %s", toolCall.Name),
517					IsError:    true,
518				}
519				continue
520			}
521
522			// Run tool in goroutine to allow cancellation
523			type toolExecResult struct {
524				response tools.ToolResponse
525				err      error
526			}
527			resultChan := make(chan toolExecResult, 1)
528
529			go func() {
530				response, err := tool.Run(ctx, tools.ToolCall{
531					ID:    toolCall.ID,
532					Name:  toolCall.Name,
533					Input: toolCall.Input,
534				})
535				resultChan <- toolExecResult{response: response, err: err}
536			}()
537
538			var toolResponse tools.ToolResponse
539			var toolErr error
540
541			select {
542			case <-ctx.Done():
543				a.finishMessage(context.Background(), &assistantMsg, message.FinishReasonCanceled, "Request cancelled", "")
544				// Mark remaining tool calls as cancelled
545				for j := i; j < len(toolCalls); j++ {
546					toolResults[j] = message.ToolResult{
547						ToolCallID: toolCalls[j].ID,
548						Content:    "Tool execution canceled by user",
549						IsError:    true,
550					}
551				}
552				goto out
553			case result := <-resultChan:
554				toolResponse = result.response
555				toolErr = result.err
556			}
557
558			if toolErr != nil {
559				slog.Error("Tool execution error", "toolCall", toolCall.ID, "error", toolErr)
560				if errors.Is(toolErr, permission.ErrorPermissionDenied) {
561					toolResults[i] = message.ToolResult{
562						ToolCallID: toolCall.ID,
563						Content:    "Permission denied",
564						IsError:    true,
565					}
566					for j := i + 1; j < len(toolCalls); j++ {
567						toolResults[j] = message.ToolResult{
568							ToolCallID: toolCalls[j].ID,
569							Content:    "Tool execution canceled by user",
570							IsError:    true,
571						}
572					}
573					a.finishMessage(ctx, &assistantMsg, message.FinishReasonPermissionDenied, "Permission denied", "")
574					break
575				}
576			}
577			toolResults[i] = message.ToolResult{
578				ToolCallID: toolCall.ID,
579				Content:    toolResponse.Content,
580				Metadata:   toolResponse.Metadata,
581				IsError:    toolResponse.IsError,
582			}
583		}
584	}
585out:
586	if len(toolResults) == 0 {
587		return assistantMsg, nil, nil
588	}
589	parts := make([]message.ContentPart, 0)
590	for _, tr := range toolResults {
591		parts = append(parts, tr)
592	}
593	msg, err := a.messages.Create(context.Background(), assistantMsg.SessionID, message.CreateMessageParams{
594		Role:     message.Tool,
595		Parts:    parts,
596		Provider: a.providerID,
597	})
598	if err != nil {
599		return assistantMsg, nil, fmt.Errorf("failed to create cancelled tool message: %w", err)
600	}
601
602	return assistantMsg, &msg, err
603}
604
605func (a *agent) finishMessage(ctx context.Context, msg *message.Message, finishReason message.FinishReason, message, details string) {
606	msg.AddFinish(finishReason, message, details)
607	_ = a.messages.Update(ctx, *msg)
608}
609
610func (a *agent) processEvent(ctx context.Context, sessionID string, assistantMsg *message.Message, event provider.ProviderEvent) error {
611	select {
612	case <-ctx.Done():
613		return ctx.Err()
614	default:
615		// Continue processing.
616	}
617
618	switch event.Type {
619	case provider.EventThinkingDelta:
620		assistantMsg.AppendReasoningContent(event.Thinking)
621		return a.messages.Update(ctx, *assistantMsg)
622	case provider.EventSignatureDelta:
623		assistantMsg.AppendReasoningSignature(event.Signature)
624		return a.messages.Update(ctx, *assistantMsg)
625	case provider.EventContentDelta:
626		assistantMsg.FinishThinking()
627		assistantMsg.AppendContent(event.Content)
628		return a.messages.Update(ctx, *assistantMsg)
629	case provider.EventToolUseStart:
630		assistantMsg.FinishThinking()
631		slog.Info("Tool call started", "toolCall", event.ToolCall)
632		assistantMsg.AddToolCall(*event.ToolCall)
633		return a.messages.Update(ctx, *assistantMsg)
634	case provider.EventToolUseDelta:
635		assistantMsg.AppendToolCallInput(event.ToolCall.ID, event.ToolCall.Input)
636		return a.messages.Update(ctx, *assistantMsg)
637	case provider.EventToolUseStop:
638		slog.Info("Finished tool call", "toolCall", event.ToolCall)
639		assistantMsg.FinishToolCall(event.ToolCall.ID)
640		return a.messages.Update(ctx, *assistantMsg)
641	case provider.EventError:
642		return event.Error
643	case provider.EventComplete:
644		assistantMsg.FinishThinking()
645		assistantMsg.SetToolCalls(event.Response.ToolCalls)
646		assistantMsg.AddFinish(event.Response.FinishReason, "", "")
647		if err := a.messages.Update(ctx, *assistantMsg); err != nil {
648			return fmt.Errorf("failed to update message: %w", err)
649		}
650		return a.TrackUsage(ctx, sessionID, a.Model(), event.Response.Usage)
651	}
652
653	return nil
654}
655
656func (a *agent) TrackUsage(ctx context.Context, sessionID string, model fur.Model, usage provider.TokenUsage) error {
657	sess, err := a.sessions.Get(ctx, sessionID)
658	if err != nil {
659		return fmt.Errorf("failed to get session: %w", err)
660	}
661
662	cost := model.CostPer1MInCached/1e6*float64(usage.CacheCreationTokens) +
663		model.CostPer1MOutCached/1e6*float64(usage.CacheReadTokens) +
664		model.CostPer1MIn/1e6*float64(usage.InputTokens) +
665		model.CostPer1MOut/1e6*float64(usage.OutputTokens)
666
667	sess.Cost += cost
668	sess.CompletionTokens = usage.OutputTokens + usage.CacheReadTokens
669	sess.PromptTokens = usage.InputTokens + usage.CacheCreationTokens
670
671	_, err = a.sessions.Save(ctx, sess)
672	if err != nil {
673		return fmt.Errorf("failed to save session: %w", err)
674	}
675	return nil
676}
677
678func (a *agent) Summarize(ctx context.Context, sessionID string) error {
679	if a.summarizeProvider == nil {
680		return fmt.Errorf("summarize provider not available")
681	}
682
683	// Check if session is busy
684	if a.IsSessionBusy(sessionID) {
685		return ErrSessionBusy
686	}
687
688	// Create a new context with cancellation
689	summarizeCtx, cancel := context.WithCancel(ctx)
690
691	// Store the cancel function in activeRequests to allow cancellation
692	a.activeRequests.Store(sessionID+"-summarize", cancel)
693
694	go func() {
695		defer a.activeRequests.Delete(sessionID + "-summarize")
696		defer cancel()
697		event := AgentEvent{
698			Type:     AgentEventTypeSummarize,
699			Progress: "Starting summarization...",
700		}
701
702		a.Publish(pubsub.CreatedEvent, event)
703		// Get all messages from the session
704		msgs, err := a.messages.List(summarizeCtx, sessionID)
705		if err != nil {
706			event = AgentEvent{
707				Type:  AgentEventTypeError,
708				Error: fmt.Errorf("failed to list messages: %w", err),
709				Done:  true,
710			}
711			a.Publish(pubsub.CreatedEvent, event)
712			return
713		}
714		summarizeCtx = context.WithValue(summarizeCtx, tools.SessionIDContextKey, sessionID)
715
716		if len(msgs) == 0 {
717			event = AgentEvent{
718				Type:  AgentEventTypeError,
719				Error: fmt.Errorf("no messages to summarize"),
720				Done:  true,
721			}
722			a.Publish(pubsub.CreatedEvent, event)
723			return
724		}
725
726		event = AgentEvent{
727			Type:     AgentEventTypeSummarize,
728			Progress: "Analyzing conversation...",
729		}
730		a.Publish(pubsub.CreatedEvent, event)
731
732		// Add a system message to guide the summarization
733		summarizePrompt := "Provide a detailed but concise summary of our conversation above. Focus on information that would be helpful for continuing the conversation, including what we did, what we're doing, which files we're working on, and what we're going to do next."
734
735		// Create a new message with the summarize prompt
736		promptMsg := message.Message{
737			Role:  message.User,
738			Parts: []message.ContentPart{message.TextContent{Text: summarizePrompt}},
739		}
740
741		// Append the prompt to the messages
742		msgsWithPrompt := append(msgs, promptMsg)
743
744		event = AgentEvent{
745			Type:     AgentEventTypeSummarize,
746			Progress: "Generating summary...",
747		}
748
749		a.Publish(pubsub.CreatedEvent, event)
750
751		// Send the messages to the summarize provider
752		response := a.summarizeProvider.StreamResponse(
753			summarizeCtx,
754			msgsWithPrompt,
755			make([]tools.BaseTool, 0),
756		)
757		var finalResponse *provider.ProviderResponse
758		for r := range response {
759			if r.Error != nil {
760				event = AgentEvent{
761					Type:  AgentEventTypeError,
762					Error: fmt.Errorf("failed to summarize: %w", err),
763					Done:  true,
764				}
765				a.Publish(pubsub.CreatedEvent, event)
766				return
767			}
768			finalResponse = r.Response
769		}
770
771		summary := strings.TrimSpace(finalResponse.Content)
772		if summary == "" {
773			event = AgentEvent{
774				Type:  AgentEventTypeError,
775				Error: fmt.Errorf("empty summary returned"),
776				Done:  true,
777			}
778			a.Publish(pubsub.CreatedEvent, event)
779			return
780		}
781		event = AgentEvent{
782			Type:     AgentEventTypeSummarize,
783			Progress: "Creating new session...",
784		}
785
786		a.Publish(pubsub.CreatedEvent, event)
787		oldSession, err := a.sessions.Get(summarizeCtx, sessionID)
788		if err != nil {
789			event = AgentEvent{
790				Type:  AgentEventTypeError,
791				Error: fmt.Errorf("failed to get session: %w", err),
792				Done:  true,
793			}
794
795			a.Publish(pubsub.CreatedEvent, event)
796			return
797		}
798		// Create a message in the new session with the summary
799		msg, err := a.messages.Create(summarizeCtx, oldSession.ID, message.CreateMessageParams{
800			Role: message.Assistant,
801			Parts: []message.ContentPart{
802				message.TextContent{Text: summary},
803				message.Finish{
804					Reason: message.FinishReasonEndTurn,
805					Time:   time.Now().Unix(),
806				},
807			},
808			Model:    a.summarizeProvider.Model().ID,
809			Provider: a.summarizeProviderID,
810		})
811		if err != nil {
812			event = AgentEvent{
813				Type:  AgentEventTypeError,
814				Error: fmt.Errorf("failed to create summary message: %w", err),
815				Done:  true,
816			}
817
818			a.Publish(pubsub.CreatedEvent, event)
819			return
820		}
821		oldSession.SummaryMessageID = msg.ID
822		oldSession.CompletionTokens = finalResponse.Usage.OutputTokens
823		oldSession.PromptTokens = 0
824		model := a.summarizeProvider.Model()
825		usage := finalResponse.Usage
826		cost := model.CostPer1MInCached/1e6*float64(usage.CacheCreationTokens) +
827			model.CostPer1MOutCached/1e6*float64(usage.CacheReadTokens) +
828			model.CostPer1MIn/1e6*float64(usage.InputTokens) +
829			model.CostPer1MOut/1e6*float64(usage.OutputTokens)
830		oldSession.Cost += cost
831		_, err = a.sessions.Save(summarizeCtx, oldSession)
832		if err != nil {
833			event = AgentEvent{
834				Type:  AgentEventTypeError,
835				Error: fmt.Errorf("failed to save session: %w", err),
836				Done:  true,
837			}
838			a.Publish(pubsub.CreatedEvent, event)
839		}
840
841		event = AgentEvent{
842			Type:      AgentEventTypeSummarize,
843			SessionID: oldSession.ID,
844			Progress:  "Summary complete",
845			Done:      true,
846		}
847		a.Publish(pubsub.CreatedEvent, event)
848		// Send final success event with the new session ID
849	}()
850
851	return nil
852}
853
854func (a *agent) CancelAll() {
855	if !a.IsBusy() {
856		return
857	}
858	a.activeRequests.Range(func(key, value any) bool {
859		a.Cancel(key.(string)) // key is sessionID
860		return true
861	})
862
863	timeout := time.After(5 * time.Second)
864	for a.IsBusy() {
865		select {
866		case <-timeout:
867			return
868		default:
869			time.Sleep(200 * time.Millisecond)
870		}
871	}
872}
873
874func (a *agent) UpdateModel() error {
875	cfg := config.Get()
876
877	// Get current provider configuration
878	currentProviderCfg := cfg.GetProviderForModel(a.agentCfg.Model)
879	if currentProviderCfg == nil || currentProviderCfg.ID == "" {
880		return fmt.Errorf("provider for agent %s not found in config", a.agentCfg.Name)
881	}
882
883	// Check if provider has changed
884	if string(currentProviderCfg.ID) != a.providerID {
885		// Provider changed, need to recreate the main provider
886		model := cfg.GetModelByType(a.agentCfg.Model)
887		if model.ID == "" {
888			return fmt.Errorf("model not found for agent %s", a.agentCfg.Name)
889		}
890
891		promptID := agentPromptMap[a.agentCfg.ID]
892		if promptID == "" {
893			promptID = prompt.PromptDefault
894		}
895
896		opts := []provider.ProviderClientOption{
897			provider.WithModel(a.agentCfg.Model),
898			provider.WithSystemMessage(prompt.GetPrompt(promptID, currentProviderCfg.ID, cfg.Options.ContextPaths...)),
899		}
900
901		newProvider, err := provider.NewProvider(*currentProviderCfg, opts...)
902		if err != nil {
903			return fmt.Errorf("failed to create new provider: %w", err)
904		}
905
906		// Update the provider and provider ID
907		a.provider = newProvider
908		a.providerID = string(currentProviderCfg.ID)
909	}
910
911	// Check if small model provider has changed (affects title and summarize providers)
912	smallModelCfg := cfg.Models[config.SelectedModelTypeSmall]
913	var smallModelProviderCfg config.ProviderConfig
914
915	for _, p := range cfg.Providers {
916		if p.ID == smallModelCfg.Provider {
917			smallModelProviderCfg = p
918			break
919		}
920	}
921
922	if smallModelProviderCfg.ID == "" {
923		return fmt.Errorf("provider %s not found in config", smallModelCfg.Provider)
924	}
925
926	// Check if summarize provider has changed
927	if string(smallModelProviderCfg.ID) != a.summarizeProviderID {
928		smallModel := cfg.GetModelByType(config.SelectedModelTypeSmall)
929		if smallModel == nil {
930			return fmt.Errorf("model %s not found in provider %s", smallModelCfg.Model, smallModelProviderCfg.ID)
931		}
932
933		// Recreate title provider
934		titleOpts := []provider.ProviderClientOption{
935			provider.WithModel(config.SelectedModelTypeSmall),
936			provider.WithSystemMessage(prompt.GetPrompt(prompt.PromptTitle, smallModelProviderCfg.ID)),
937			// We want the title to be short, so we limit the max tokens
938			provider.WithMaxTokens(40),
939		}
940		newTitleProvider, err := provider.NewProvider(smallModelProviderCfg, titleOpts...)
941		if err != nil {
942			return fmt.Errorf("failed to create new title provider: %w", err)
943		}
944
945		// Recreate summarize provider
946		summarizeOpts := []provider.ProviderClientOption{
947			provider.WithModel(config.SelectedModelTypeSmall),
948			provider.WithSystemMessage(prompt.GetPrompt(prompt.PromptSummarizer, smallModelProviderCfg.ID)),
949		}
950		newSummarizeProvider, err := provider.NewProvider(smallModelProviderCfg, summarizeOpts...)
951		if err != nil {
952			return fmt.Errorf("failed to create new summarize provider: %w", err)
953		}
954
955		// Update the providers and provider ID
956		a.titleProvider = newTitleProvider
957		a.summarizeProvider = newSummarizeProvider
958		a.summarizeProviderID = string(smallModelProviderCfg.ID)
959	}
960
961	return nil
962}