agent.go

  1package agent
  2
  3import (
  4	"context"
  5	"errors"
  6	"fmt"
  7	"log/slog"
  8	"slices"
  9	"strings"
 10	"sync"
 11	"sync/atomic"
 12	"time"
 13
 14	"github.com/charmbracelet/crush/internal/config"
 15	fur "github.com/charmbracelet/crush/internal/fur/provider"
 16	"github.com/charmbracelet/crush/internal/history"
 17	"github.com/charmbracelet/crush/internal/llm/prompt"
 18	"github.com/charmbracelet/crush/internal/llm/provider"
 19	"github.com/charmbracelet/crush/internal/llm/tools"
 20	"github.com/charmbracelet/crush/internal/log"
 21	"github.com/charmbracelet/crush/internal/lsp"
 22	"github.com/charmbracelet/crush/internal/message"
 23	"github.com/charmbracelet/crush/internal/permission"
 24	"github.com/charmbracelet/crush/internal/pubsub"
 25	"github.com/charmbracelet/crush/internal/session"
 26)
 27
 28// Common errors
 29var (
 30	ErrRequestCancelled = errors.New("request canceled by user")
 31	ErrSessionBusy      = errors.New("session is currently processing another request")
 32)
 33
 34type AgentEventType string
 35
 36const (
 37	AgentEventTypeError     AgentEventType = "error"
 38	AgentEventTypeResponse  AgentEventType = "response"
 39	AgentEventTypeSummarize AgentEventType = "summarize"
 40)
 41
// AgentEvent is what the agent reports to its consumers: Run sends it on the
// returned channel and publishes it through the broker, and Summarize
// publishes its progress with it. Which fields are meaningful depends on Type.
type AgentEvent struct {
	// Type selects which of the remaining fields are meaningful.
	Type    AgentEventType
	// Message is the final assistant message of a Run (AgentEventTypeResponse).
	Message message.Message
	// Error is set on AgentEventTypeError events.
	Error   error

	// When summarizing: SessionID is set on the final summary event and
	// Progress carries a human-readable status line.
	SessionID string
	Progress  string
	// Done marks a terminal event: the last summarization event (success or
	// error) and the successful response of a Run.
	Done      bool
}
 52
 53type Service interface {
 54	pubsub.Suscriber[AgentEvent]
 55	Model() fur.Model
 56	Run(ctx context.Context, sessionID string, content string, attachments ...message.Attachment) (<-chan AgentEvent, error)
 57	Cancel(sessionID string)
 58	CancelAll()
 59	IsSessionBusy(sessionID string) bool
 60	IsBusy() bool
 61	Summarize(ctx context.Context, sessionID string) error
 62	UpdateModel() error
 63}
 64
// agent is the Service implementation: it drives an LLM provider with a set
// of tools and persists the conversation through the session and message
// services.
type agent struct {
	// Broker publishes AgentEvents (Run results, Summarize progress).
	*pubsub.Broker[AgentEvent]
	agentCfg config.Agent
	sessions session.Service
	messages message.Service

	// toolsDone is stored as true by the background goroutine started in
	// NewAgent after it has assigned tools; streamAndHandleEvents checks it
	// before using tools.
	toolsDone atomic.Bool
	tools     []tools.BaseTool

	// provider serves the agent's main model; providerID is the ID of the
	// provider config it was built from.
	provider   provider.Provider
	providerID string

	// titleProvider and summarizeProvider are built from the small model's
	// provider config, whose ID is kept in summarizeProviderID.
	titleProvider       provider.Provider
	summarizeProvider   provider.Provider
	summarizeProviderID string

	// activeRequests maps a session ID (Run) or "<sessionID>-summarize"
	// (Summarize) to the context.CancelFunc of the in-flight request.
	activeRequests sync.Map
}
 83
 84var agentPromptMap = map[string]prompt.PromptID{
 85	"coder": prompt.PromptCoder,
 86	"task":  prompt.PromptTask,
 87}
 88
 89func NewAgent(
 90	agentCfg config.Agent,
 91	// These services are needed in the tools
 92	permissions permission.Service,
 93	sessions session.Service,
 94	messages message.Service,
 95	history history.Service,
 96	lspClients map[string]*lsp.Client,
 97) (Service, error) {
 98	ctx := context.Background()
 99	cfg := config.Get()
100
101	var agentTool tools.BaseTool
102	if agentCfg.ID == "coder" {
103		taskAgentCfg := config.Get().Agents["task"]
104		if taskAgentCfg.ID == "" {
105			return nil, fmt.Errorf("task agent not found in config")
106		}
107		taskAgent, err := NewAgent(taskAgentCfg, permissions, sessions, messages, history, lspClients)
108		if err != nil {
109			return nil, fmt.Errorf("failed to create task agent: %w", err)
110		}
111
112		agentTool = NewAgentTool(taskAgent, sessions, messages)
113	}
114
115	providerCfg := config.Get().GetProviderForModel(agentCfg.Model)
116	if providerCfg == nil {
117		return nil, fmt.Errorf("provider for agent %s not found in config", agentCfg.Name)
118	}
119	model := config.Get().GetModelByType(agentCfg.Model)
120
121	if model == nil {
122		return nil, fmt.Errorf("model not found for agent %s", agentCfg.Name)
123	}
124
125	promptID := agentPromptMap[agentCfg.ID]
126	if promptID == "" {
127		promptID = prompt.PromptDefault
128	}
129	opts := []provider.ProviderClientOption{
130		provider.WithModel(agentCfg.Model),
131		provider.WithSystemMessage(prompt.GetPrompt(promptID, providerCfg.ID, config.Get().Options.ContextPaths...)),
132	}
133	agentProvider, err := provider.NewProvider(*providerCfg, opts...)
134	if err != nil {
135		return nil, err
136	}
137
138	smallModelCfg := cfg.Models[config.SelectedModelTypeSmall]
139	var smallModelProviderCfg *config.ProviderConfig
140	if smallModelCfg.Provider == providerCfg.ID {
141		smallModelProviderCfg = providerCfg
142	} else {
143		smallModelProviderCfg = cfg.GetProviderForModel(config.SelectedModelTypeSmall)
144
145		if smallModelProviderCfg.ID == "" {
146			return nil, fmt.Errorf("provider %s not found in config", smallModelCfg.Provider)
147		}
148	}
149	smallModel := cfg.GetModelByType(config.SelectedModelTypeSmall)
150	if smallModel.ID == "" {
151		return nil, fmt.Errorf("model %s not found in provider %s", smallModelCfg.Model, smallModelProviderCfg.ID)
152	}
153
154	titleOpts := []provider.ProviderClientOption{
155		provider.WithModel(config.SelectedModelTypeSmall),
156		provider.WithSystemMessage(prompt.GetPrompt(prompt.PromptTitle, smallModelProviderCfg.ID)),
157	}
158	titleProvider, err := provider.NewProvider(*smallModelProviderCfg, titleOpts...)
159	if err != nil {
160		return nil, err
161	}
162	summarizeOpts := []provider.ProviderClientOption{
163		provider.WithModel(config.SelectedModelTypeSmall),
164		provider.WithSystemMessage(prompt.GetPrompt(prompt.PromptSummarizer, smallModelProviderCfg.ID)),
165	}
166	summarizeProvider, err := provider.NewProvider(*smallModelProviderCfg, summarizeOpts...)
167	if err != nil {
168		return nil, err
169	}
170
171	agent := &agent{
172		Broker:              pubsub.NewBroker[AgentEvent](),
173		agentCfg:            agentCfg,
174		provider:            agentProvider,
175		providerID:          string(providerCfg.ID),
176		messages:            messages,
177		sessions:            sessions,
178		titleProvider:       titleProvider,
179		summarizeProvider:   summarizeProvider,
180		summarizeProviderID: string(smallModelProviderCfg.ID),
181		activeRequests:      sync.Map{},
182	}
183
184	go func() {
185		slog.Info("Initializing agent tools", "agent", agentCfg.ID)
186		defer func() {
187			slog.Info("Initialized agent tools", "agent", agentCfg.ID)
188			agent.toolsDone.Store(true)
189		}()
190
191		cwd := cfg.WorkingDir()
192		allTools := []tools.BaseTool{
193			tools.NewBashTool(permissions, cwd),
194			tools.NewDownloadTool(permissions, cwd),
195			tools.NewEditTool(lspClients, permissions, history, cwd),
196			tools.NewFetchTool(permissions, cwd),
197			tools.NewGlobTool(cwd),
198			tools.NewGrepTool(cwd),
199			tools.NewLsTool(cwd),
200			tools.NewSourcegraphTool(),
201			tools.NewViewTool(lspClients, cwd),
202			tools.NewWriteTool(lspClients, permissions, history, cwd),
203		}
204
205		mcpTools := GetMCPTools(ctx, permissions, cfg)
206		allTools = append(allTools, mcpTools...)
207
208		if len(lspClients) > 0 {
209			allTools = append(allTools, tools.NewDiagnosticsTool(lspClients))
210		}
211
212		if agentTool != nil {
213			allTools = append(allTools, agentTool)
214		}
215
216		if agentCfg.AllowedTools == nil {
217			agent.tools = allTools
218			return
219		}
220
221		var filteredTools []tools.BaseTool
222		for _, tool := range allTools {
223			if slices.Contains(agentCfg.AllowedTools, tool.Name()) {
224				filteredTools = append(filteredTools, tool)
225			}
226		}
227		agent.tools = filteredTools
228	}()
229
230	return agent, nil
231}
232
233func (a *agent) Model() fur.Model {
234	return *config.Get().GetModelByType(a.agentCfg.Model)
235}
236
237func (a *agent) Cancel(sessionID string) {
238	// Cancel regular requests
239	if cancelFunc, exists := a.activeRequests.LoadAndDelete(sessionID); exists {
240		if cancel, ok := cancelFunc.(context.CancelFunc); ok {
241			slog.Info(fmt.Sprintf("Request cancellation initiated for session: %s", sessionID))
242			cancel()
243		}
244	}
245
246	// Also check for summarize requests
247	if cancelFunc, exists := a.activeRequests.LoadAndDelete(sessionID + "-summarize"); exists {
248		if cancel, ok := cancelFunc.(context.CancelFunc); ok {
249			slog.Info(fmt.Sprintf("Summarize cancellation initiated for session: %s", sessionID))
250			cancel()
251		}
252	}
253}
254
255func (a *agent) IsBusy() bool {
256	busy := false
257	a.activeRequests.Range(func(key, value any) bool {
258		if cancelFunc, ok := value.(context.CancelFunc); ok {
259			if cancelFunc != nil {
260				busy = true
261				return false
262			}
263		}
264		return true
265	})
266	return busy
267}
268
269func (a *agent) IsSessionBusy(sessionID string) bool {
270	_, busy := a.activeRequests.Load(sessionID)
271	return busy
272}
273
274func (a *agent) generateTitle(ctx context.Context, sessionID string, content string) error {
275	if content == "" {
276		return nil
277	}
278	if a.titleProvider == nil {
279		return nil
280	}
281	session, err := a.sessions.Get(ctx, sessionID)
282	if err != nil {
283		return err
284	}
285	parts := []message.ContentPart{message.TextContent{
286		Text: fmt.Sprintf("Generate a concise title for the following content:\n\n%s", content),
287	}}
288
289	// Use streaming approach like summarization
290	response := a.titleProvider.StreamResponse(
291		ctx,
292		[]message.Message{
293			{
294				Role:  message.User,
295				Parts: parts,
296			},
297		},
298		make([]tools.BaseTool, 0),
299	)
300
301	var finalResponse *provider.ProviderResponse
302	for r := range response {
303		if r.Error != nil {
304			return r.Error
305		}
306		finalResponse = r.Response
307	}
308
309	if finalResponse == nil {
310		return fmt.Errorf("no response received from title provider")
311	}
312
313	title := strings.TrimSpace(strings.ReplaceAll(finalResponse.Content, "\n", " "))
314	if title == "" {
315		return nil
316	}
317
318	session.Title = title
319	_, err = a.sessions.Save(ctx, session)
320	return err
321}
322
323func (a *agent) err(err error) AgentEvent {
324	return AgentEvent{
325		Type:  AgentEventTypeError,
326		Error: err,
327	}
328}
329
330func (a *agent) Run(ctx context.Context, sessionID string, content string, attachments ...message.Attachment) (<-chan AgentEvent, error) {
331	if !a.Model().SupportsImages && attachments != nil {
332		attachments = nil
333	}
334	events := make(chan AgentEvent)
335	if a.IsSessionBusy(sessionID) {
336		return nil, ErrSessionBusy
337	}
338
339	genCtx, cancel := context.WithCancel(ctx)
340
341	a.activeRequests.Store(sessionID, cancel)
342	go func() {
343		slog.Debug("Request started", "sessionID", sessionID)
344		defer log.RecoverPanic("agent.Run", func() {
345			events <- a.err(fmt.Errorf("panic while running the agent"))
346		})
347		var attachmentParts []message.ContentPart
348		for _, attachment := range attachments {
349			attachmentParts = append(attachmentParts, message.BinaryContent{Path: attachment.FilePath, MIMEType: attachment.MimeType, Data: attachment.Content})
350		}
351		result := a.processGeneration(genCtx, sessionID, content, attachmentParts)
352		if result.Error != nil && !errors.Is(result.Error, ErrRequestCancelled) && !errors.Is(result.Error, context.Canceled) {
353			slog.Error(result.Error.Error())
354		}
355		slog.Debug("Request completed", "sessionID", sessionID)
356		a.activeRequests.Delete(sessionID)
357		cancel()
358		a.Publish(pubsub.CreatedEvent, result)
359		events <- result
360		close(events)
361	}()
362	return events, nil
363}
364
365func (a *agent) processGeneration(ctx context.Context, sessionID, content string, attachmentParts []message.ContentPart) AgentEvent {
366	cfg := config.Get()
367	// List existing messages; if none, start title generation asynchronously.
368	msgs, err := a.messages.List(ctx, sessionID)
369	if err != nil {
370		return a.err(fmt.Errorf("failed to list messages: %w", err))
371	}
372	if len(msgs) == 0 {
373		go func() {
374			defer log.RecoverPanic("agent.Run", func() {
375				slog.Error("panic while generating title")
376			})
377			titleErr := a.generateTitle(context.Background(), sessionID, content)
378			if titleErr != nil && !errors.Is(titleErr, context.Canceled) && !errors.Is(titleErr, context.DeadlineExceeded) {
379				slog.Error(fmt.Sprintf("failed to generate title: %v", titleErr))
380			}
381		}()
382	}
383	session, err := a.sessions.Get(ctx, sessionID)
384	if err != nil {
385		return a.err(fmt.Errorf("failed to get session: %w", err))
386	}
387	if session.SummaryMessageID != "" {
388		summaryMsgInex := -1
389		for i, msg := range msgs {
390			if msg.ID == session.SummaryMessageID {
391				summaryMsgInex = i
392				break
393			}
394		}
395		if summaryMsgInex != -1 {
396			msgs = msgs[summaryMsgInex:]
397			msgs[0].Role = message.User
398		}
399	}
400
401	userMsg, err := a.createUserMessage(ctx, sessionID, content, attachmentParts)
402	if err != nil {
403		return a.err(fmt.Errorf("failed to create user message: %w", err))
404	}
405	// Append the new user message to the conversation history.
406	msgHistory := append(msgs, userMsg)
407
408	for {
409		// Check for cancellation before each iteration
410		select {
411		case <-ctx.Done():
412			return a.err(ctx.Err())
413		default:
414			// Continue processing
415		}
416		agentMessage, toolResults, err := a.streamAndHandleEvents(ctx, sessionID, msgHistory)
417		if err != nil {
418			if errors.Is(err, context.Canceled) {
419				agentMessage.AddFinish(message.FinishReasonCanceled, "Request cancelled", "")
420				a.messages.Update(context.Background(), agentMessage)
421				return a.err(ErrRequestCancelled)
422			}
423			return a.err(fmt.Errorf("failed to process events: %w", err))
424		}
425		if cfg.Options.Debug {
426			slog.Info("Result", "message", agentMessage.FinishReason(), "toolResults", toolResults)
427		}
428		if (agentMessage.FinishReason() == message.FinishReasonToolUse) && toolResults != nil {
429			// We are not done, we need to respond with the tool response
430			msgHistory = append(msgHistory, agentMessage, *toolResults)
431			continue
432		}
433		return AgentEvent{
434			Type:    AgentEventTypeResponse,
435			Message: agentMessage,
436			Done:    true,
437		}
438	}
439}
440
441func (a *agent) createUserMessage(ctx context.Context, sessionID, content string, attachmentParts []message.ContentPart) (message.Message, error) {
442	parts := []message.ContentPart{message.TextContent{Text: content}}
443	parts = append(parts, attachmentParts...)
444	return a.messages.Create(ctx, sessionID, message.CreateMessageParams{
445		Role:  message.User,
446		Parts: parts,
447	})
448}
449
450func (a *agent) streamAndHandleEvents(ctx context.Context, sessionID string, msgHistory []message.Message) (message.Message, *message.Message, error) {
451	ctx = context.WithValue(ctx, tools.SessionIDContextKey, sessionID)
452	if !a.toolsDone.Load() {
453		return message.Message{}, nil, fmt.Errorf("Agent is still initializing, please wait a moment and try again")
454	}
455	eventChan := a.provider.StreamResponse(ctx, msgHistory, a.tools)
456
457	assistantMsg, err := a.messages.Create(ctx, sessionID, message.CreateMessageParams{
458		Role:     message.Assistant,
459		Parts:    []message.ContentPart{},
460		Model:    a.Model().ID,
461		Provider: a.providerID,
462	})
463	if err != nil {
464		return assistantMsg, nil, fmt.Errorf("failed to create assistant message: %w", err)
465	}
466
467	// Add the session and message ID into the context if needed by tools.
468	ctx = context.WithValue(ctx, tools.MessageIDContextKey, assistantMsg.ID)
469
470	// Process each event in the stream.
471	for event := range eventChan {
472		if processErr := a.processEvent(ctx, sessionID, &assistantMsg, event); processErr != nil {
473			if errors.Is(processErr, context.Canceled) {
474				a.finishMessage(context.Background(), &assistantMsg, message.FinishReasonCanceled, "Request cancelled", "")
475			} else {
476				a.finishMessage(ctx, &assistantMsg, message.FinishReasonError, "API Error", processErr.Error())
477			}
478			return assistantMsg, nil, processErr
479		}
480		if ctx.Err() != nil {
481			a.finishMessage(context.Background(), &assistantMsg, message.FinishReasonCanceled, "Request cancelled", "")
482			return assistantMsg, nil, ctx.Err()
483		}
484	}
485
486	toolResults := make([]message.ToolResult, len(assistantMsg.ToolCalls()))
487	toolCalls := assistantMsg.ToolCalls()
488	for i, toolCall := range toolCalls {
489		select {
490		case <-ctx.Done():
491			a.finishMessage(context.Background(), &assistantMsg, message.FinishReasonCanceled, "Request cancelled", "")
492			// Make all future tool calls cancelled
493			for j := i; j < len(toolCalls); j++ {
494				toolResults[j] = message.ToolResult{
495					ToolCallID: toolCalls[j].ID,
496					Content:    "Tool execution canceled by user",
497					IsError:    true,
498				}
499			}
500			goto out
501		default:
502			// Continue processing
503			var tool tools.BaseTool
504			for _, availableTool := range a.tools {
505				if availableTool.Info().Name == toolCall.Name {
506					tool = availableTool
507					break
508				}
509			}
510
511			// Tool not found
512			if tool == nil {
513				toolResults[i] = message.ToolResult{
514					ToolCallID: toolCall.ID,
515					Content:    fmt.Sprintf("Tool not found: %s", toolCall.Name),
516					IsError:    true,
517				}
518				continue
519			}
520
521			// Run tool in goroutine to allow cancellation
522			type toolExecResult struct {
523				response tools.ToolResponse
524				err      error
525			}
526			resultChan := make(chan toolExecResult, 1)
527
528			go func() {
529				response, err := tool.Run(ctx, tools.ToolCall{
530					ID:    toolCall.ID,
531					Name:  toolCall.Name,
532					Input: toolCall.Input,
533				})
534				resultChan <- toolExecResult{response: response, err: err}
535			}()
536
537			var toolResponse tools.ToolResponse
538			var toolErr error
539
540			select {
541			case <-ctx.Done():
542				a.finishMessage(context.Background(), &assistantMsg, message.FinishReasonCanceled, "Request cancelled", "")
543				// Mark remaining tool calls as cancelled
544				for j := i; j < len(toolCalls); j++ {
545					toolResults[j] = message.ToolResult{
546						ToolCallID: toolCalls[j].ID,
547						Content:    "Tool execution canceled by user",
548						IsError:    true,
549					}
550				}
551				goto out
552			case result := <-resultChan:
553				toolResponse = result.response
554				toolErr = result.err
555			}
556
557			if toolErr != nil {
558				slog.Error("Tool execution error", "toolCall", toolCall.ID, "error", toolErr)
559				if errors.Is(toolErr, permission.ErrorPermissionDenied) {
560					toolResults[i] = message.ToolResult{
561						ToolCallID: toolCall.ID,
562						Content:    "Permission denied",
563						IsError:    true,
564					}
565					for j := i + 1; j < len(toolCalls); j++ {
566						toolResults[j] = message.ToolResult{
567							ToolCallID: toolCalls[j].ID,
568							Content:    "Tool execution canceled by user",
569							IsError:    true,
570						}
571					}
572					a.finishMessage(ctx, &assistantMsg, message.FinishReasonPermissionDenied, "Permission denied", "")
573					break
574				}
575			}
576			toolResults[i] = message.ToolResult{
577				ToolCallID: toolCall.ID,
578				Content:    toolResponse.Content,
579				Metadata:   toolResponse.Metadata,
580				IsError:    toolResponse.IsError,
581			}
582		}
583	}
584out:
585	if len(toolResults) == 0 {
586		return assistantMsg, nil, nil
587	}
588	parts := make([]message.ContentPart, 0)
589	for _, tr := range toolResults {
590		parts = append(parts, tr)
591	}
592	msg, err := a.messages.Create(context.Background(), assistantMsg.SessionID, message.CreateMessageParams{
593		Role:     message.Tool,
594		Parts:    parts,
595		Provider: a.providerID,
596	})
597	if err != nil {
598		return assistantMsg, nil, fmt.Errorf("failed to create cancelled tool message: %w", err)
599	}
600
601	return assistantMsg, &msg, err
602}
603
604func (a *agent) finishMessage(ctx context.Context, msg *message.Message, finishReason message.FinishReason, message, details string) {
605	msg.AddFinish(finishReason, message, details)
606	_ = a.messages.Update(ctx, *msg)
607}
608
609func (a *agent) processEvent(ctx context.Context, sessionID string, assistantMsg *message.Message, event provider.ProviderEvent) error {
610	select {
611	case <-ctx.Done():
612		return ctx.Err()
613	default:
614		// Continue processing.
615	}
616
617	switch event.Type {
618	case provider.EventThinkingDelta:
619		assistantMsg.AppendReasoningContent(event.Thinking)
620		return a.messages.Update(ctx, *assistantMsg)
621	case provider.EventSignatureDelta:
622		assistantMsg.AppendReasoningSignature(event.Signature)
623		return a.messages.Update(ctx, *assistantMsg)
624	case provider.EventContentDelta:
625		assistantMsg.FinishThinking()
626		assistantMsg.AppendContent(event.Content)
627		return a.messages.Update(ctx, *assistantMsg)
628	case provider.EventToolUseStart:
629		assistantMsg.FinishThinking()
630		slog.Info("Tool call started", "toolCall", event.ToolCall)
631		assistantMsg.AddToolCall(*event.ToolCall)
632		return a.messages.Update(ctx, *assistantMsg)
633	case provider.EventToolUseDelta:
634		assistantMsg.AppendToolCallInput(event.ToolCall.ID, event.ToolCall.Input)
635		return a.messages.Update(ctx, *assistantMsg)
636	case provider.EventToolUseStop:
637		slog.Info("Finished tool call", "toolCall", event.ToolCall)
638		assistantMsg.FinishToolCall(event.ToolCall.ID)
639		return a.messages.Update(ctx, *assistantMsg)
640	case provider.EventError:
641		return event.Error
642	case provider.EventComplete:
643		assistantMsg.FinishThinking()
644		assistantMsg.SetToolCalls(event.Response.ToolCalls)
645		assistantMsg.AddFinish(event.Response.FinishReason, "", "")
646		if err := a.messages.Update(ctx, *assistantMsg); err != nil {
647			return fmt.Errorf("failed to update message: %w", err)
648		}
649		return a.TrackUsage(ctx, sessionID, a.Model(), event.Response.Usage)
650	}
651
652	return nil
653}
654
655func (a *agent) TrackUsage(ctx context.Context, sessionID string, model fur.Model, usage provider.TokenUsage) error {
656	sess, err := a.sessions.Get(ctx, sessionID)
657	if err != nil {
658		return fmt.Errorf("failed to get session: %w", err)
659	}
660
661	cost := model.CostPer1MInCached/1e6*float64(usage.CacheCreationTokens) +
662		model.CostPer1MOutCached/1e6*float64(usage.CacheReadTokens) +
663		model.CostPer1MIn/1e6*float64(usage.InputTokens) +
664		model.CostPer1MOut/1e6*float64(usage.OutputTokens)
665
666	sess.Cost += cost
667	sess.CompletionTokens = usage.OutputTokens + usage.CacheReadTokens
668	sess.PromptTokens = usage.InputTokens + usage.CacheCreationTokens
669
670	_, err = a.sessions.Save(ctx, sess)
671	if err != nil {
672		return fmt.Errorf("failed to save session: %w", err)
673	}
674	return nil
675}
676
677func (a *agent) Summarize(ctx context.Context, sessionID string) error {
678	if a.summarizeProvider == nil {
679		return fmt.Errorf("summarize provider not available")
680	}
681
682	// Check if session is busy
683	if a.IsSessionBusy(sessionID) {
684		return ErrSessionBusy
685	}
686
687	// Create a new context with cancellation
688	summarizeCtx, cancel := context.WithCancel(ctx)
689
690	// Store the cancel function in activeRequests to allow cancellation
691	a.activeRequests.Store(sessionID+"-summarize", cancel)
692
693	go func() {
694		defer a.activeRequests.Delete(sessionID + "-summarize")
695		defer cancel()
696		event := AgentEvent{
697			Type:     AgentEventTypeSummarize,
698			Progress: "Starting summarization...",
699		}
700
701		a.Publish(pubsub.CreatedEvent, event)
702		// Get all messages from the session
703		msgs, err := a.messages.List(summarizeCtx, sessionID)
704		if err != nil {
705			event = AgentEvent{
706				Type:  AgentEventTypeError,
707				Error: fmt.Errorf("failed to list messages: %w", err),
708				Done:  true,
709			}
710			a.Publish(pubsub.CreatedEvent, event)
711			return
712		}
713		summarizeCtx = context.WithValue(summarizeCtx, tools.SessionIDContextKey, sessionID)
714
715		if len(msgs) == 0 {
716			event = AgentEvent{
717				Type:  AgentEventTypeError,
718				Error: fmt.Errorf("no messages to summarize"),
719				Done:  true,
720			}
721			a.Publish(pubsub.CreatedEvent, event)
722			return
723		}
724
725		event = AgentEvent{
726			Type:     AgentEventTypeSummarize,
727			Progress: "Analyzing conversation...",
728		}
729		a.Publish(pubsub.CreatedEvent, event)
730
731		// Add a system message to guide the summarization
732		summarizePrompt := "Provide a detailed but concise summary of our conversation above. Focus on information that would be helpful for continuing the conversation, including what we did, what we're doing, which files we're working on, and what we're going to do next."
733
734		// Create a new message with the summarize prompt
735		promptMsg := message.Message{
736			Role:  message.User,
737			Parts: []message.ContentPart{message.TextContent{Text: summarizePrompt}},
738		}
739
740		// Append the prompt to the messages
741		msgsWithPrompt := append(msgs, promptMsg)
742
743		event = AgentEvent{
744			Type:     AgentEventTypeSummarize,
745			Progress: "Generating summary...",
746		}
747
748		a.Publish(pubsub.CreatedEvent, event)
749
750		// Send the messages to the summarize provider
751		response := a.summarizeProvider.StreamResponse(
752			summarizeCtx,
753			msgsWithPrompt,
754			make([]tools.BaseTool, 0),
755		)
756		var finalResponse *provider.ProviderResponse
757		for r := range response {
758			if r.Error != nil {
759				event = AgentEvent{
760					Type:  AgentEventTypeError,
761					Error: fmt.Errorf("failed to summarize: %w", err),
762					Done:  true,
763				}
764				a.Publish(pubsub.CreatedEvent, event)
765				return
766			}
767			finalResponse = r.Response
768		}
769
770		summary := strings.TrimSpace(finalResponse.Content)
771		if summary == "" {
772			event = AgentEvent{
773				Type:  AgentEventTypeError,
774				Error: fmt.Errorf("empty summary returned"),
775				Done:  true,
776			}
777			a.Publish(pubsub.CreatedEvent, event)
778			return
779		}
780		event = AgentEvent{
781			Type:     AgentEventTypeSummarize,
782			Progress: "Creating new session...",
783		}
784
785		a.Publish(pubsub.CreatedEvent, event)
786		oldSession, err := a.sessions.Get(summarizeCtx, sessionID)
787		if err != nil {
788			event = AgentEvent{
789				Type:  AgentEventTypeError,
790				Error: fmt.Errorf("failed to get session: %w", err),
791				Done:  true,
792			}
793
794			a.Publish(pubsub.CreatedEvent, event)
795			return
796		}
797		// Create a message in the new session with the summary
798		msg, err := a.messages.Create(summarizeCtx, oldSession.ID, message.CreateMessageParams{
799			Role: message.Assistant,
800			Parts: []message.ContentPart{
801				message.TextContent{Text: summary},
802				message.Finish{
803					Reason: message.FinishReasonEndTurn,
804					Time:   time.Now().Unix(),
805				},
806			},
807			Model:    a.summarizeProvider.Model().ID,
808			Provider: a.summarizeProviderID,
809		})
810		if err != nil {
811			event = AgentEvent{
812				Type:  AgentEventTypeError,
813				Error: fmt.Errorf("failed to create summary message: %w", err),
814				Done:  true,
815			}
816
817			a.Publish(pubsub.CreatedEvent, event)
818			return
819		}
820		oldSession.SummaryMessageID = msg.ID
821		oldSession.CompletionTokens = finalResponse.Usage.OutputTokens
822		oldSession.PromptTokens = 0
823		model := a.summarizeProvider.Model()
824		usage := finalResponse.Usage
825		cost := model.CostPer1MInCached/1e6*float64(usage.CacheCreationTokens) +
826			model.CostPer1MOutCached/1e6*float64(usage.CacheReadTokens) +
827			model.CostPer1MIn/1e6*float64(usage.InputTokens) +
828			model.CostPer1MOut/1e6*float64(usage.OutputTokens)
829		oldSession.Cost += cost
830		_, err = a.sessions.Save(summarizeCtx, oldSession)
831		if err != nil {
832			event = AgentEvent{
833				Type:  AgentEventTypeError,
834				Error: fmt.Errorf("failed to save session: %w", err),
835				Done:  true,
836			}
837			a.Publish(pubsub.CreatedEvent, event)
838		}
839
840		event = AgentEvent{
841			Type:      AgentEventTypeSummarize,
842			SessionID: oldSession.ID,
843			Progress:  "Summary complete",
844			Done:      true,
845		}
846		a.Publish(pubsub.CreatedEvent, event)
847		// Send final success event with the new session ID
848	}()
849
850	return nil
851}
852
853func (a *agent) CancelAll() {
854	if !a.IsBusy() {
855		return
856	}
857	a.activeRequests.Range(func(key, value any) bool {
858		a.Cancel(key.(string)) // key is sessionID
859		return true
860	})
861
862	timeout := time.After(5 * time.Second)
863	for a.IsBusy() {
864		select {
865		case <-timeout:
866			return
867		default:
868			time.Sleep(200 * time.Millisecond)
869		}
870	}
871}
872
873func (a *agent) UpdateModel() error {
874	cfg := config.Get()
875
876	// Get current provider configuration
877	currentProviderCfg := cfg.GetProviderForModel(a.agentCfg.Model)
878	if currentProviderCfg == nil || currentProviderCfg.ID == "" {
879		return fmt.Errorf("provider for agent %s not found in config", a.agentCfg.Name)
880	}
881
882	// Check if provider has changed
883	if string(currentProviderCfg.ID) != a.providerID {
884		// Provider changed, need to recreate the main provider
885		model := cfg.GetModelByType(a.agentCfg.Model)
886		if model.ID == "" {
887			return fmt.Errorf("model not found for agent %s", a.agentCfg.Name)
888		}
889
890		promptID := agentPromptMap[a.agentCfg.ID]
891		if promptID == "" {
892			promptID = prompt.PromptDefault
893		}
894
895		opts := []provider.ProviderClientOption{
896			provider.WithModel(a.agentCfg.Model),
897			provider.WithSystemMessage(prompt.GetPrompt(promptID, currentProviderCfg.ID, cfg.Options.ContextPaths...)),
898		}
899
900		newProvider, err := provider.NewProvider(*currentProviderCfg, opts...)
901		if err != nil {
902			return fmt.Errorf("failed to create new provider: %w", err)
903		}
904
905		// Update the provider and provider ID
906		a.provider = newProvider
907		a.providerID = string(currentProviderCfg.ID)
908	}
909
910	// Check if small model provider has changed (affects title and summarize providers)
911	smallModelCfg := cfg.Models[config.SelectedModelTypeSmall]
912	var smallModelProviderCfg config.ProviderConfig
913
914	for _, p := range cfg.Providers {
915		if p.ID == smallModelCfg.Provider {
916			smallModelProviderCfg = p
917			break
918		}
919	}
920
921	if smallModelProviderCfg.ID == "" {
922		return fmt.Errorf("provider %s not found in config", smallModelCfg.Provider)
923	}
924
925	// Check if summarize provider has changed
926	if string(smallModelProviderCfg.ID) != a.summarizeProviderID {
927		smallModel := cfg.GetModelByType(config.SelectedModelTypeSmall)
928		if smallModel == nil {
929			return fmt.Errorf("model %s not found in provider %s", smallModelCfg.Model, smallModelProviderCfg.ID)
930		}
931
932		// Recreate title provider
933		titleOpts := []provider.ProviderClientOption{
934			provider.WithModel(config.SelectedModelTypeSmall),
935			provider.WithSystemMessage(prompt.GetPrompt(prompt.PromptTitle, smallModelProviderCfg.ID)),
936			// We want the title to be short, so we limit the max tokens
937			provider.WithMaxTokens(40),
938		}
939		newTitleProvider, err := provider.NewProvider(smallModelProviderCfg, titleOpts...)
940		if err != nil {
941			return fmt.Errorf("failed to create new title provider: %w", err)
942		}
943
944		// Recreate summarize provider
945		summarizeOpts := []provider.ProviderClientOption{
946			provider.WithModel(config.SelectedModelTypeSmall),
947			provider.WithSystemMessage(prompt.GetPrompt(prompt.PromptSummarizer, smallModelProviderCfg.ID)),
948		}
949		newSummarizeProvider, err := provider.NewProvider(smallModelProviderCfg, summarizeOpts...)
950		if err != nil {
951			return fmt.Errorf("failed to create new summarize provider: %w", err)
952		}
953
954		// Update the providers and provider ID
955		a.titleProvider = newTitleProvider
956		a.summarizeProvider = newSummarizeProvider
957		a.summarizeProviderID = string(smallModelProviderCfg.ID)
958	}
959
960	return nil
961}