agent.go

  1package agent
  2
  3import (
  4	"context"
  5	"errors"
  6	"fmt"
  7	"log/slog"
  8	"slices"
  9	"strings"
 10	"sync"
 11	"sync/atomic"
 12	"time"
 13
 14	"github.com/charmbracelet/crush/internal/config"
 15	fur "github.com/charmbracelet/crush/internal/fur/provider"
 16	"github.com/charmbracelet/crush/internal/history"
 17	"github.com/charmbracelet/crush/internal/llm/prompt"
 18	"github.com/charmbracelet/crush/internal/llm/provider"
 19	"github.com/charmbracelet/crush/internal/llm/tools"
 20	"github.com/charmbracelet/crush/internal/log"
 21	"github.com/charmbracelet/crush/internal/lsp"
 22	"github.com/charmbracelet/crush/internal/message"
 23	"github.com/charmbracelet/crush/internal/permission"
 24	"github.com/charmbracelet/crush/internal/pubsub"
 25	"github.com/charmbracelet/crush/internal/session"
 26)
 27
 28// Common errors
 29var (
 30	ErrRequestCancelled = errors.New("request canceled by user")
 31	ErrSessionBusy      = errors.New("session is currently processing another request")
 32)
 33
 34type AgentEventType string
 35
 36const (
 37	AgentEventTypeError     AgentEventType = "error"
 38	AgentEventTypeResponse  AgentEventType = "response"
 39	AgentEventTypeSummarize AgentEventType = "summarize"
 40)
 41
// AgentEvent is delivered on the channel returned by Run and published to
// subscribers of the agent's broker.
type AgentEvent struct {
	// Type selects which fields below are meaningful. Response events carry the
	// final assistant Message; error events carry Error.
	Type    AgentEventType
	Message message.Message
	Error   error

	// When summarizing: Progress is a human-readable status line, SessionID is
	// set on the final successful event, and Done marks the last event of the
	// operation. Successful response events from Run also set Done.
	SessionID string
	Progress  string
	Done      bool
}
 52
 53type Service interface {
 54	pubsub.Suscriber[AgentEvent]
 55	Model() fur.Model
 56	Run(ctx context.Context, sessionID string, content string, attachments ...message.Attachment) (<-chan AgentEvent, error)
 57	Cancel(sessionID string)
 58	CancelAll()
 59	IsSessionBusy(sessionID string) bool
 60	IsBusy() bool
 61	Summarize(ctx context.Context, sessionID string) error
 62	UpdateModel() error
 63}
 64
// agent is the Service implementation. The embedded broker publishes
// AgentEvents to subscribers.
type agent struct {
	*pubsub.Broker[AgentEvent]
	agentCfg config.Agent
	sessions session.Service
	messages message.Service

	// toolsDone is set to true by the goroutine started in NewAgent once tools
	// has been populated; streamAndHandleEvents refuses to run until then.
	toolsDone atomic.Bool
	tools     []tools.BaseTool

	// provider serves the main conversation; providerID is the ID of its
	// provider config, compared against the current config in UpdateModel.
	provider   provider.Provider
	providerID string

	// Providers created with the small model type, used for session titles and
	// summaries. summarizeProviderID is the ID of the small model's provider.
	titleProvider       provider.Provider
	summarizeProvider   provider.Provider
	summarizeProviderID string

	// activeRequests maps a session ID (Run) or session ID + "-summarize"
	// (Summarize) to the context.CancelFunc of the work in flight.
	activeRequests sync.Map
}
 83
 84var agentPromptMap = map[string]prompt.PromptID{
 85	"coder": prompt.PromptCoder,
 86	"task":  prompt.PromptTask,
 87}
 88
 89func NewAgent(
 90	agentCfg config.Agent,
 91	// These services are needed in the tools
 92	permissions permission.Service,
 93	sessions session.Service,
 94	messages message.Service,
 95	history history.Service,
 96	lspClients map[string]*lsp.Client,
 97) (Service, error) {
 98	ctx := context.Background()
 99	cfg := config.Get()
100
101	var agentTool tools.BaseTool
102	if agentCfg.ID == "coder" {
103		taskAgentCfg := config.Get().Agents["task"]
104		if taskAgentCfg.ID == "" {
105			return nil, fmt.Errorf("task agent not found in config")
106		}
107		taskAgent, err := NewAgent(taskAgentCfg, permissions, sessions, messages, history, lspClients)
108		if err != nil {
109			return nil, fmt.Errorf("failed to create task agent: %w", err)
110		}
111
112		agentTool = NewAgentTool(
113			taskAgent,
114			sessions,
115			messages,
116		)
117	}
118
119	providerCfg := config.Get().GetProviderForModel(agentCfg.Model)
120	if providerCfg == nil {
121		return nil, fmt.Errorf("provider for agent %s not found in config", agentCfg.Name)
122	}
123	model := config.Get().GetModelByType(agentCfg.Model)
124
125	if model == nil {
126		return nil, fmt.Errorf("model not found for agent %s", agentCfg.Name)
127	}
128
129	promptID := agentPromptMap[agentCfg.ID]
130	if promptID == "" {
131		promptID = prompt.PromptDefault
132	}
133	opts := []provider.ProviderClientOption{
134		provider.WithModel(agentCfg.Model),
135		provider.WithSystemMessage(prompt.GetPrompt(promptID, providerCfg.ID, config.Get().Options.ContextPaths...)),
136	}
137	agentProvider, err := provider.NewProvider(*providerCfg, opts...)
138	if err != nil {
139		return nil, err
140	}
141
142	smallModelCfg := cfg.Models[config.SelectedModelTypeSmall]
143	var smallModelProviderCfg *config.ProviderConfig
144	if smallModelCfg.Provider == providerCfg.ID {
145		smallModelProviderCfg = providerCfg
146	} else {
147		smallModelProviderCfg = cfg.GetProviderForModel(config.SelectedModelTypeSmall)
148
149		if smallModelProviderCfg.ID == "" {
150			return nil, fmt.Errorf("provider %s not found in config", smallModelCfg.Provider)
151		}
152	}
153	smallModel := cfg.GetModelByType(config.SelectedModelTypeSmall)
154	if smallModel.ID == "" {
155		return nil, fmt.Errorf("model %s not found in provider %s", smallModelCfg.Model, smallModelProviderCfg.ID)
156	}
157
158	titleOpts := []provider.ProviderClientOption{
159		provider.WithModel(config.SelectedModelTypeSmall),
160		provider.WithSystemMessage(prompt.GetPrompt(prompt.PromptTitle, smallModelProviderCfg.ID)),
161	}
162	titleProvider, err := provider.NewProvider(*smallModelProviderCfg, titleOpts...)
163	if err != nil {
164		return nil, err
165	}
166	summarizeOpts := []provider.ProviderClientOption{
167		provider.WithModel(config.SelectedModelTypeSmall),
168		provider.WithSystemMessage(prompt.GetPrompt(prompt.PromptSummarizer, smallModelProviderCfg.ID)),
169	}
170	summarizeProvider, err := provider.NewProvider(*smallModelProviderCfg, summarizeOpts...)
171	if err != nil {
172		return nil, err
173	}
174
175	agent := &agent{
176		Broker:              pubsub.NewBroker[AgentEvent](),
177		agentCfg:            agentCfg,
178		provider:            agentProvider,
179		providerID:          string(providerCfg.ID),
180		messages:            messages,
181		sessions:            sessions,
182		titleProvider:       titleProvider,
183		summarizeProvider:   summarizeProvider,
184		summarizeProviderID: string(smallModelProviderCfg.ID),
185		activeRequests:      sync.Map{},
186	}
187
188	go func() {
189		slog.Info("Initializing agent tools", "agent", agentCfg.ID)
190		defer func() {
191			slog.Info("Initialized agent tools", "agent", agentCfg.ID)
192			agent.toolsDone.Store(true)
193		}()
194
195		cwd := cfg.WorkingDir()
196		allTools := []tools.BaseTool{
197			tools.NewBashTool(permissions, cwd),
198			tools.NewDownloadTool(permissions, cwd),
199			tools.NewEditTool(lspClients, permissions, history, cwd),
200			tools.NewFetchTool(permissions, cwd),
201			tools.NewGlobTool(cwd),
202			tools.NewGrepTool(cwd),
203			tools.NewLsTool(cwd),
204			tools.NewSourcegraphTool(),
205			tools.NewViewTool(lspClients, cwd),
206			tools.NewWriteTool(lspClients, permissions, history, cwd),
207		}
208
209		mcpTools := GetMCPTools(ctx, permissions, cfg)
210		allTools = append(allTools, mcpTools...)
211
212		if len(lspClients) > 0 {
213			allTools = append(allTools, tools.NewDiagnosticsTool(lspClients))
214		}
215
216		if agentTool != nil {
217			allTools = append(allTools, agentTool)
218		}
219
220		if agentCfg.AllowedTools == nil {
221			agent.tools = allTools
222			return
223		}
224
225		var filteredTools []tools.BaseTool
226		for _, tool := range allTools {
227			if slices.Contains(agentCfg.AllowedTools, tool.Name()) {
228				filteredTools = append(filteredTools, tool)
229			}
230		}
231		agent.tools = filteredTools
232	}()
233
234	return agent, nil
235}
236
// Model returns the model currently configured for this agent's model type.
// If the configuration no longer resolves it, a warning is logged and the
// zero fur.Model is returned instead of panicking on a nil pointer.
func (a *agent) Model() fur.Model {
	model := config.Get().GetModelByType(a.agentCfg.Model)
	if model == nil {
		slog.Warn("model not found for agent", "agent", a.agentCfg.ID)
		return fur.Model{}
	}
	return *model
}
240
// Cancel aborts whatever is running for sessionID. That covers the Run request
// registered under sessionID and the summarization registered under
// sessionID+"-summarize". Each entry is removed from activeRequests before its
// cancel function is called.
func (a *agent) Cancel(sessionID string) {
	stop := func(key, logFormat string) {
		value, found := a.activeRequests.LoadAndDelete(key)
		if !found {
			return
		}
		cancelFn, isFunc := value.(context.CancelFunc)
		if !isFunc {
			return
		}
		slog.Info(fmt.Sprintf(logFormat, sessionID))
		cancelFn()
	}

	stop(sessionID, "Request cancellation initiated for session: %s")
	stop(sessionID+"-summarize", "Summarize cancellation initiated for session: %s")
}
258
// IsBusy reports whether activeRequests holds at least one non-nil
// context.CancelFunc. This covers both Run requests and summarizations.
func (a *agent) IsBusy() bool {
	found := false
	a.activeRequests.Range(func(_, value any) bool {
		cancelFn, isFunc := value.(context.CancelFunc)
		found = isFunc && cancelFn != nil
		// Stop iterating as soon as one active request is seen.
		return !found
	})
	return found
}
272
273func (a *agent) IsSessionBusy(sessionID string) bool {
274	_, busy := a.activeRequests.Load(sessionID)
275	return busy
276}
277
// generateTitle asks the title provider for a concise title for content and
// saves it on the session. The call is a no-op in three cases: content is
// empty, no title provider is configured, or the model returns a blank title.
// Newlines in the generated title are replaced with spaces.
//
// The session is loaded only after the model has answered. The model call can
// be slow, and this function runs concurrently with the main request, which
// updates the session (e.g. TrackUsage). Saving a snapshot taken before the
// call would overwrite those updates.
func (a *agent) generateTitle(ctx context.Context, sessionID string, content string) error {
	if content == "" || a.titleProvider == nil {
		return nil
	}

	request := []message.Message{{
		Role: message.User,
		Parts: []message.ContentPart{message.TextContent{
			Text: fmt.Sprintf("Generate a concise title for the following content:\n\n%s", content),
		}},
	}}

	var finalResponse *provider.ProviderResponse
	for r := range a.titleProvider.StreamResponse(ctx, request, make([]tools.BaseTool, 0)) {
		if r.Error != nil {
			return r.Error
		}
		// Only events that carry a response are kept, so a trailing event
		// without one cannot discard an already-received response.
		if r.Response != nil {
			finalResponse = r.Response
		}
	}
	if finalResponse == nil {
		return fmt.Errorf("no response received from title provider")
	}

	title := strings.TrimSpace(strings.ReplaceAll(finalResponse.Content, "\n", " "))
	if title == "" {
		return nil
	}

	session, err := a.sessions.Get(ctx, sessionID)
	if err != nil {
		return err
	}
	session.Title = title
	_, err = a.sessions.Save(ctx, session)
	return err
}
326
327func (a *agent) err(err error) AgentEvent {
328	return AgentEvent{
329		Type:  AgentEventTypeError,
330		Error: err,
331	}
332}
333
// Run processes content for sessionID in a background goroutine. Attachments
// are dropped when the model does not support images. Run returns
// ErrSessionBusy if the session already has a request in flight.
//
// The returned channel receives one AgentEvent with the outcome and is then
// closed. The same event is also published to subscribers. Cancel(sessionID)
// aborts the request.
func (a *agent) Run(ctx context.Context, sessionID string, content string, attachments ...message.Attachment) (<-chan AgentEvent, error) {
	if !a.Model().SupportsImages && attachments != nil {
		attachments = nil
	}

	genCtx, cancel := context.WithCancel(ctx)
	// LoadOrStore checks and registers in one atomic step, so two concurrent
	// calls for the same session cannot both start a request.
	if _, busy := a.activeRequests.LoadOrStore(sessionID, cancel); busy {
		cancel()
		return nil, ErrSessionBusy
	}

	events := make(chan AgentEvent)
	go func() {
		// Registered first so it runs last, after any panic-recovery event is sent.
		defer close(events)
		slog.Debug("Request started", "sessionID", sessionID)
		defer log.RecoverPanic("agent.Run", func() {
			// Release the session; otherwise it would stay busy forever.
			a.activeRequests.Delete(sessionID)
			cancel()
			events <- a.err(fmt.Errorf("panic while running the agent"))
		})

		var attachmentParts []message.ContentPart
		for _, attachment := range attachments {
			attachmentParts = append(attachmentParts, message.BinaryContent{Path: attachment.FilePath, MIMEType: attachment.MimeType, Data: attachment.Content})
		}
		result := a.processGeneration(genCtx, sessionID, content, attachmentParts)
		if result.Error != nil && !errors.Is(result.Error, ErrRequestCancelled) && !errors.Is(result.Error, context.Canceled) {
			slog.Error(result.Error.Error())
		}
		slog.Debug("Request completed", "sessionID", sessionID)
		// Unregister before delivering the result, so IsSessionBusy is already
		// false when the caller receives it.
		a.activeRequests.Delete(sessionID)
		cancel()
		a.Publish(pubsub.CreatedEvent, result)
		events <- result
	}()
	return events, nil
}
368
// processGeneration runs one user turn for sessionID. It stores content plus
// attachmentParts as a new user message, then streams assistant replies. While
// the model finishes with a tool-use reason, it appends the tool results and
// asks again.
//
// When the session has no messages yet, a title is generated in the
// background. When the session has a summary message, history starts at that
// message, which is re-labelled with the user role.
//
// Cancellation yields an ErrRequestCancelled error event. Success yields a
// response event carrying the final assistant message.
func (a *agent) processGeneration(ctx context.Context, sessionID, content string, attachmentParts []message.ContentPart) AgentEvent {
	cfg := config.Get()
	msgs, err := a.messages.List(ctx, sessionID)
	if err != nil {
		return a.err(fmt.Errorf("failed to list messages: %w", err))
	}
	if len(msgs) == 0 {
		// First turn of a new session. The title uses its own background
		// context, so cancelling this request does not abort it.
		go func() {
			defer log.RecoverPanic("agent.Run", func() {
				slog.Error("panic while generating title")
			})
			titleErr := a.generateTitle(context.Background(), sessionID, content)
			if titleErr != nil && !errors.Is(titleErr, context.Canceled) && !errors.Is(titleErr, context.DeadlineExceeded) {
				slog.Error(fmt.Sprintf("failed to generate title: %v", titleErr))
			}
		}()
	}

	session, err := a.sessions.Get(ctx, sessionID)
	if err != nil {
		return a.err(fmt.Errorf("failed to get session: %w", err))
	}
	if session.SummaryMessageID != "" {
		// Resume from the summary: drop earlier messages and send the summary
		// itself as a user message.
		summaryIdx := slices.IndexFunc(msgs, func(m message.Message) bool {
			return m.ID == session.SummaryMessageID
		})
		if summaryIdx != -1 {
			msgs = msgs[summaryIdx:]
			msgs[0].Role = message.User
		}
	}

	userMsg, err := a.createUserMessage(ctx, sessionID, content, attachmentParts)
	if err != nil {
		return a.err(fmt.Errorf("failed to create user message: %w", err))
	}
	msgHistory := append(msgs, userMsg)

	for {
		if err := ctx.Err(); err != nil {
			return a.err(err)
		}
		agentMessage, toolResults, err := a.streamAndHandleEvents(ctx, sessionID, msgHistory)
		if err != nil {
			if errors.Is(err, context.Canceled) {
				agentMessage.AddFinish(message.FinishReasonCanceled, "Request cancelled", "")
				// ctx is already cancelled, so persist with a background context.
				// Skip the update when there is no message to update (empty ID),
				// e.g. when creating the assistant message failed.
				if agentMessage.ID != "" {
					if updateErr := a.messages.Update(context.Background(), agentMessage); updateErr != nil {
						slog.Error("failed to mark message as cancelled", "messageID", agentMessage.ID, "error", updateErr)
					}
				}
				return a.err(ErrRequestCancelled)
			}
			return a.err(fmt.Errorf("failed to process events: %w", err))
		}
		if cfg.Options.Debug {
			slog.Info("Result", "message", agentMessage.FinishReason(), "toolResults", toolResults)
		}
		if agentMessage.FinishReason() == message.FinishReasonToolUse && toolResults != nil {
			// The model asked for tools: send it their results and continue.
			msgHistory = append(msgHistory, agentMessage, *toolResults)
			continue
		}
		return AgentEvent{
			Type:    AgentEventTypeResponse,
			Message: agentMessage,
			Done:    true,
		}
	}
}
444
// createUserMessage persists a user-role message in sessionID. Its parts are
// the text content followed by attachmentParts, in order.
func (a *agent) createUserMessage(ctx context.Context, sessionID, content string, attachmentParts []message.ContentPart) (message.Message, error) {
	parts := make([]message.ContentPart, 0, 1+len(attachmentParts))
	parts = append(parts, message.TextContent{Text: content})
	parts = append(parts, attachmentParts...)

	params := message.CreateMessageParams{
		Role:  message.User,
		Parts: parts,
	}
	return a.messages.Create(ctx, sessionID, params)
}
453
// streamAndHandleEvents sends msgHistory to the main provider and records the
// streamed reply as a new assistant message. It then executes the reply's tool
// calls one after another.
//
// It returns:
//   - the assistant message;
//   - a tool-role message holding one result per tool call, or nil when there
//     were no tool calls;
//   - an error.
//
// Cancellation and failures are handled as follows:
//   - If ctx is cancelled while streaming, the assistant message is finished
//     as cancelled and ctx.Err() is returned.
//   - If ctx is cancelled during tool execution, the remaining calls get
//     "canceled" results.
//   - If a tool fails with permission.ErrorPermissionDenied, no further tools
//     run and their calls are marked as canceled.
//   - Any other tool error is reported to the model as an error result.
func (a *agent) streamAndHandleEvents(ctx context.Context, sessionID string, msgHistory []message.Message) (message.Message, *message.Message, error) {
	ctx = context.WithValue(ctx, tools.SessionIDContextKey, sessionID)
	if !a.toolsDone.Load() {
		return message.Message{}, nil, fmt.Errorf("tools not initialized yet")
	}
	eventChan := a.provider.StreamResponse(ctx, msgHistory, a.tools)

	assistantMsg, err := a.messages.Create(ctx, sessionID, message.CreateMessageParams{
		Role:     message.Assistant,
		Parts:    []message.ContentPart{},
		Model:    a.Model().ID,
		Provider: a.providerID,
	})
	if err != nil {
		return assistantMsg, nil, fmt.Errorf("failed to create assistant message: %w", err)
	}

	// Add the message ID into the context if needed by tools.
	ctx = context.WithValue(ctx, tools.MessageIDContextKey, assistantMsg.ID)

	for event := range eventChan {
		if processErr := a.processEvent(ctx, sessionID, &assistantMsg, event); processErr != nil {
			if errors.Is(processErr, context.Canceled) {
				a.finishMessage(context.Background(), &assistantMsg, message.FinishReasonCanceled, "Request cancelled", "")
			} else {
				a.finishMessage(ctx, &assistantMsg, message.FinishReasonError, "API Error", processErr.Error())
			}
			return assistantMsg, nil, processErr
		}
		if ctx.Err() != nil {
			a.finishMessage(context.Background(), &assistantMsg, message.FinishReasonCanceled, "Request cancelled", "")
			return assistantMsg, nil, ctx.Err()
		}
	}

	toolCalls := assistantMsg.ToolCalls()
	toolResults := make([]message.ToolResult, len(toolCalls))

	// cancelFrom records a "canceled" result for every tool call from start on.
	cancelFrom := func(start int) {
		for j := start; j < len(toolCalls); j++ {
			toolResults[j] = message.ToolResult{
				ToolCallID: toolCalls[j].ID,
				Content:    "Tool execution canceled by user",
				IsError:    true,
			}
		}
	}

	type toolExecResult struct {
		response tools.ToolResponse
		err      error
	}

toolLoop:
	for i, toolCall := range toolCalls {
		if ctx.Err() != nil {
			a.finishMessage(context.Background(), &assistantMsg, message.FinishReasonCanceled, "Request cancelled", "")
			cancelFrom(i)
			break toolLoop
		}

		var tool tools.BaseTool
		for _, candidate := range a.tools {
			if candidate.Info().Name == toolCall.Name {
				tool = candidate
				break
			}
		}
		if tool == nil {
			toolResults[i] = message.ToolResult{
				ToolCallID: toolCall.ID,
				Content:    fmt.Sprintf("Tool not found: %s", toolCall.Name),
				IsError:    true,
			}
			continue
		}

		// Run the tool in its own goroutine so a cancellation can interrupt the
		// wait. The channel is buffered, so the goroutine never blocks if its
		// result is abandoned.
		resultChan := make(chan toolExecResult, 1)
		go func() {
			response, err := tool.Run(ctx, tools.ToolCall{
				ID:    toolCall.ID,
				Name:  toolCall.Name,
				Input: toolCall.Input,
			})
			resultChan <- toolExecResult{response: response, err: err}
		}()

		var result toolExecResult
		select {
		case <-ctx.Done():
			a.finishMessage(context.Background(), &assistantMsg, message.FinishReasonCanceled, "Request cancelled", "")
			cancelFrom(i)
			break toolLoop
		case result = <-resultChan:
		}

		if result.err != nil {
			slog.Error("Tool execution error", "toolCall", toolCall.ID, "error", result.err)
			if errors.Is(result.err, permission.ErrorPermissionDenied) {
				toolResults[i] = message.ToolResult{
					ToolCallID: toolCall.ID,
					Content:    "Permission denied",
					IsError:    true,
				}
				cancelFrom(i + 1)
				a.finishMessage(ctx, &assistantMsg, message.FinishReasonPermissionDenied, "Permission denied", "")
				// Leave the whole loop (a bare break would only leave a select
				// or switch), so no further tools run after a denial.
				break toolLoop
			}
			// Surface the failure to the model instead of an empty success.
			content := result.response.Content
			if content == "" {
				content = fmt.Sprintf("Tool execution failed: %v", result.err)
			}
			toolResults[i] = message.ToolResult{
				ToolCallID: toolCall.ID,
				Content:    content,
				Metadata:   result.response.Metadata,
				IsError:    true,
			}
			continue
		}

		toolResults[i] = message.ToolResult{
			ToolCallID: toolCall.ID,
			Content:    result.response.Content,
			Metadata:   result.response.Metadata,
			IsError:    result.response.IsError,
		}
	}

	if len(toolResults) == 0 {
		return assistantMsg, nil, nil
	}
	parts := make([]message.ContentPart, 0, len(toolResults))
	for _, tr := range toolResults {
		parts = append(parts, tr)
	}
	// Background context: results must be stored even if ctx was cancelled.
	msg, err := a.messages.Create(context.Background(), assistantMsg.SessionID, message.CreateMessageParams{
		Role:     message.Tool,
		Parts:    parts,
		Provider: a.providerID,
	})
	if err != nil {
		return assistantMsg, nil, fmt.Errorf("failed to create tool message: %w", err)
	}

	return assistantMsg, &msg, nil
}
607
// finishMessage appends a finish part to msg and persists the message. The
// part records finishReason, text, and details. A failed update is logged,
// not returned.
//
// The text parameter is not named "message", so that the message package is
// not shadowed inside this function.
func (a *agent) finishMessage(ctx context.Context, msg *message.Message, finishReason message.FinishReason, text, details string) {
	msg.AddFinish(finishReason, text, details)
	if err := a.messages.Update(ctx, *msg); err != nil {
		slog.Error("failed to persist finished message", "messageID", msg.ID, "error", err)
	}
}
612
// processEvent applies one provider stream event to assistantMsg and persists
// the updated message.
//
// A cancelled ctx is reported before anything is applied. EventError returns
// the provider's error. EventComplete also records the final tool calls and
// finish reason, then tracks token usage. Unrecognized event types are ignored.
func (a *agent) processEvent(ctx context.Context, sessionID string, assistantMsg *message.Message, event provider.ProviderEvent) error {
	if err := ctx.Err(); err != nil {
		return err
	}

	switch event.Type {
	case provider.EventThinkingDelta:
		assistantMsg.AppendReasoningContent(event.Thinking)
	case provider.EventSignatureDelta:
		assistantMsg.AppendReasoningSignature(event.Signature)
	case provider.EventContentDelta:
		assistantMsg.FinishThinking()
		assistantMsg.AppendContent(event.Content)
	case provider.EventToolUseStart:
		assistantMsg.FinishThinking()
		slog.Info("Tool call started", "toolCall", event.ToolCall)
		assistantMsg.AddToolCall(*event.ToolCall)
	case provider.EventToolUseDelta:
		assistantMsg.AppendToolCallInput(event.ToolCall.ID, event.ToolCall.Input)
	case provider.EventToolUseStop:
		slog.Info("Finished tool call", "toolCall", event.ToolCall)
		assistantMsg.FinishToolCall(event.ToolCall.ID)
	case provider.EventError:
		return event.Error
	case provider.EventComplete:
		assistantMsg.FinishThinking()
		assistantMsg.SetToolCalls(event.Response.ToolCalls)
		assistantMsg.AddFinish(event.Response.FinishReason, "", "")
		if err := a.messages.Update(ctx, *assistantMsg); err != nil {
			return fmt.Errorf("failed to update message: %w", err)
		}
		return a.TrackUsage(ctx, sessionID, a.Model(), event.Response.Usage)
	default:
		return nil
	}

	// Every delta/tool event above mutated the message; persist it.
	return a.messages.Update(ctx, *assistantMsg)
}
658
// TrackUsage updates the session with the token counts and cost from usage.
// The cost is priced with model's per-million-token rates and added to the
// session's running total. The token counts are overwritten, not accumulated:
// PromptTokens becomes input plus cache-creation tokens, and CompletionTokens
// becomes output plus cache-read tokens.
func (a *agent) TrackUsage(ctx context.Context, sessionID string, model fur.Model, usage provider.TokenUsage) error {
	sess, err := a.sessions.Get(ctx, sessionID)
	if err != nil {
		return fmt.Errorf("failed to get session: %w", err)
	}

	sess.PromptTokens = usage.InputTokens + usage.CacheCreationTokens
	sess.CompletionTokens = usage.OutputTokens + usage.CacheReadTokens
	sess.Cost += model.CostPer1MInCached/1e6*float64(usage.CacheCreationTokens) +
		model.CostPer1MOutCached/1e6*float64(usage.CacheReadTokens) +
		model.CostPer1MIn/1e6*float64(usage.InputTokens) +
		model.CostPer1MOut/1e6*float64(usage.OutputTokens)

	if _, err := a.sessions.Save(ctx, sess); err != nil {
		return fmt.Errorf("failed to save session: %w", err)
	}
	return nil
}
680
// Summarize condenses the conversation in sessionID into one assistant message
// using the summarize provider. It returns immediately and does the work in a
// goroutine, which can be aborted with Cancel(sessionID).
//
// The goroutine publishes AgentEventTypeSummarize progress events. It ends
// with either a "Summary complete" event or an AgentEventTypeError event; both
// final events have Done set. On success the summary is stored in the same
// session and recorded as its SummaryMessageID.
//
// ErrSessionBusy is returned when a Run request or another summarization is
// already active for the session.
func (a *agent) Summarize(ctx context.Context, sessionID string) error {
	if a.summarizeProvider == nil {
		return fmt.Errorf("summarize provider not available")
	}

	// Refuse while a Run request is active for this session.
	if a.IsSessionBusy(sessionID) {
		return ErrSessionBusy
	}

	summarizeKey := sessionID + "-summarize"
	summarizeCtx, cancel := context.WithCancel(ctx)
	// Registering the cancel function lets Cancel stop this run. Doing it with
	// LoadOrStore also refuses a second concurrent summarization of the session.
	if _, running := a.activeRequests.LoadOrStore(summarizeKey, cancel); running {
		cancel()
		return ErrSessionBusy
	}

	go func() {
		defer a.activeRequests.Delete(summarizeKey)
		defer cancel()

		// progress publishes a non-final summarize event.
		progress := func(text string) {
			a.Publish(pubsub.CreatedEvent, AgentEvent{
				Type:     AgentEventTypeSummarize,
				Progress: text,
			})
		}
		// fail publishes the final error event of this run.
		fail := func(err error) {
			a.Publish(pubsub.CreatedEvent, AgentEvent{
				Type:  AgentEventTypeError,
				Error: err,
				Done:  true,
			})
		}

		progress("Starting summarization...")

		msgs, err := a.messages.List(summarizeCtx, sessionID)
		if err != nil {
			fail(fmt.Errorf("failed to list messages: %w", err))
			return
		}
		if len(msgs) == 0 {
			fail(fmt.Errorf("no messages to summarize"))
			return
		}
		workCtx := context.WithValue(summarizeCtx, tools.SessionIDContextKey, sessionID)

		progress("Analyzing conversation...")

		// Ask for the summary as a final user turn after the whole conversation.
		summarizePrompt := "Provide a detailed but concise summary of our conversation above. Focus on information that would be helpful for continuing the conversation, including what we did, what we're doing, which files we're working on, and what we're going to do next."
		msgsWithPrompt := append(msgs, message.Message{
			Role:  message.User,
			Parts: []message.ContentPart{message.TextContent{Text: summarizePrompt}},
		})

		progress("Generating summary...")

		var finalResponse *provider.ProviderResponse
		for r := range a.summarizeProvider.StreamResponse(workCtx, msgsWithPrompt, make([]tools.BaseTool, 0)) {
			if r.Error != nil {
				// Wrap the provider's own error for the event.
				fail(fmt.Errorf("failed to summarize: %w", r.Error))
				return
			}
			if r.Response != nil {
				finalResponse = r.Response
			}
		}
		if finalResponse == nil {
			fail(fmt.Errorf("no response received from summarize provider"))
			return
		}

		summary := strings.TrimSpace(finalResponse.Content)
		if summary == "" {
			fail(fmt.Errorf("empty summary returned"))
			return
		}

		progress("Creating new session...")

		oldSession, err := a.sessions.Get(workCtx, sessionID)
		if err != nil {
			fail(fmt.Errorf("failed to get session: %w", err))
			return
		}

		// Store the summary as a finished assistant message in the same session;
		// processGeneration starts history from SummaryMessageID.
		msg, err := a.messages.Create(workCtx, oldSession.ID, message.CreateMessageParams{
			Role: message.Assistant,
			Parts: []message.ContentPart{
				message.TextContent{Text: summary},
				message.Finish{
					Reason: message.FinishReasonEndTurn,
					Time:   time.Now().Unix(),
				},
			},
			Model:    a.summarizeProvider.Model().ID,
			Provider: a.summarizeProviderID,
		})
		if err != nil {
			fail(fmt.Errorf("failed to create summary message: %w", err))
			return
		}

		usage := finalResponse.Usage
		model := a.summarizeProvider.Model()
		oldSession.SummaryMessageID = msg.ID
		oldSession.CompletionTokens = usage.OutputTokens
		oldSession.PromptTokens = 0
		oldSession.Cost += model.CostPer1MInCached/1e6*float64(usage.CacheCreationTokens) +
			model.CostPer1MOutCached/1e6*float64(usage.CacheReadTokens) +
			model.CostPer1MIn/1e6*float64(usage.InputTokens) +
			model.CostPer1MOut/1e6*float64(usage.OutputTokens)
		if _, err := a.sessions.Save(workCtx, oldSession); err != nil {
			// Do not also report success after a failed save.
			fail(fmt.Errorf("failed to save session: %w", err))
			return
		}

		a.Publish(pubsub.CreatedEvent, AgentEvent{
			Type:      AgentEventTypeSummarize,
			SessionID: oldSession.ID,
			Progress:  "Summary complete",
			Done:      true,
		})
	}()

	return nil
}
856
// CancelAll calls Cancel for every key in activeRequests. It then polls IsBusy
// every 200ms and gives up after 5 seconds.
//
// NOTE(review): Cancel removes entries with LoadAndDelete, so IsBusy is likely
// false right after the cancel pass. The wait loop therefore usually exits
// immediately instead of waiting for the request goroutines to finish.
func (a *agent) CancelAll() {
	if !a.IsBusy() {
		return
	}

	a.activeRequests.Range(func(key, _ any) bool {
		// Keys are session IDs, or session IDs with a "-summarize" suffix.
		a.Cancel(key.(string))
		return true
	})

	giveUp := time.After(5 * time.Second)
	for a.IsBusy() {
		select {
		case <-giveUp:
			return
		default:
		}
		time.Sleep(200 * time.Millisecond)
	}
}
876
// UpdateModel re-reads the configuration and rebuilds providers whose
// provider ID changed:
//   - the main provider, when the provider for agentCfg.Model differs from
//     a.providerID;
//   - the title and summarize providers, when the small model's provider
//     differs from a.summarizeProviderID.
//
// Providers whose provider ID is unchanged are left as they are.
//
// NOTE(review): the provider fields are reassigned without synchronization,
// while request goroutines read them; confirm callers only invoke this when idle.
func (a *agent) UpdateModel() error {
	cfg := config.Get()

	currentProviderCfg := cfg.GetProviderForModel(a.agentCfg.Model)
	if currentProviderCfg == nil || currentProviderCfg.ID == "" {
		return fmt.Errorf("provider for agent %s not found in config", a.agentCfg.Name)
	}

	if string(currentProviderCfg.ID) != a.providerID {
		// GetModelByType can return nil, so check it before reading ID.
		model := cfg.GetModelByType(a.agentCfg.Model)
		if model == nil || model.ID == "" {
			return fmt.Errorf("model not found for agent %s", a.agentCfg.Name)
		}

		promptID := agentPromptMap[a.agentCfg.ID]
		if promptID == "" {
			promptID = prompt.PromptDefault
		}

		opts := []provider.ProviderClientOption{
			provider.WithModel(a.agentCfg.Model),
			provider.WithSystemMessage(prompt.GetPrompt(promptID, currentProviderCfg.ID, cfg.Options.ContextPaths...)),
		}

		newProvider, err := provider.NewProvider(*currentProviderCfg, opts...)
		if err != nil {
			return fmt.Errorf("failed to create new provider: %w", err)
		}

		a.provider = newProvider
		a.providerID = string(currentProviderCfg.ID)
	}

	// The small model's provider backs both the title and summarize providers.
	smallModelCfg := cfg.Models[config.SelectedModelTypeSmall]
	var smallModelProviderCfg config.ProviderConfig
	for _, p := range cfg.Providers {
		if p.ID == smallModelCfg.Provider {
			smallModelProviderCfg = p
			break
		}
	}
	if smallModelProviderCfg.ID == "" {
		return fmt.Errorf("provider %s not found in config", smallModelCfg.Provider)
	}

	if string(smallModelProviderCfg.ID) != a.summarizeProviderID {
		smallModel := cfg.GetModelByType(config.SelectedModelTypeSmall)
		if smallModel == nil || smallModel.ID == "" {
			return fmt.Errorf("model %s not found in provider %s", smallModelCfg.Model, smallModelProviderCfg.ID)
		}

		titleOpts := []provider.ProviderClientOption{
			provider.WithModel(config.SelectedModelTypeSmall),
			provider.WithSystemMessage(prompt.GetPrompt(prompt.PromptTitle, smallModelProviderCfg.ID)),
			// We want the title to be short, so we limit the max tokens
			provider.WithMaxTokens(40),
		}
		newTitleProvider, err := provider.NewProvider(smallModelProviderCfg, titleOpts...)
		if err != nil {
			return fmt.Errorf("failed to create new title provider: %w", err)
		}

		summarizeOpts := []provider.ProviderClientOption{
			provider.WithModel(config.SelectedModelTypeSmall),
			provider.WithSystemMessage(prompt.GetPrompt(prompt.PromptSummarizer, smallModelProviderCfg.ID)),
		}
		newSummarizeProvider, err := provider.NewProvider(smallModelProviderCfg, summarizeOpts...)
		if err != nil {
			return fmt.Errorf("failed to create new summarize provider: %w", err)
		}

		a.titleProvider = newTitleProvider
		a.summarizeProvider = newSummarizeProvider
		a.summarizeProviderID = string(smallModelProviderCfg.ID)
	}

	return nil
}