1package agent
  2
  3import (
  4	"context"
  5	"errors"
  6	"fmt"
  7	"log/slog"
  8	"slices"
  9	"strings"
 10	"sync"
 11	"time"
 12
 13	"github.com/charmbracelet/crush/internal/config"
 14	fur "github.com/charmbracelet/crush/internal/fur/provider"
 15	"github.com/charmbracelet/crush/internal/history"
 16	"github.com/charmbracelet/crush/internal/llm/prompt"
 17	"github.com/charmbracelet/crush/internal/llm/provider"
 18	"github.com/charmbracelet/crush/internal/llm/tools"
 19	"github.com/charmbracelet/crush/internal/log"
 20	"github.com/charmbracelet/crush/internal/lsp"
 21	"github.com/charmbracelet/crush/internal/message"
 22	"github.com/charmbracelet/crush/internal/permission"
 23	"github.com/charmbracelet/crush/internal/pubsub"
 24	"github.com/charmbracelet/crush/internal/session"
 25)
 26
 27// Common errors
 28var (
 29	ErrRequestCancelled = errors.New("request canceled by user")
 30	ErrSessionBusy      = errors.New("session is currently processing another request")
 31)
 32
 33type AgentEventType string
 34
 35const (
 36	AgentEventTypeError     AgentEventType = "error"
 37	AgentEventTypeResponse  AgentEventType = "response"
 38	AgentEventTypeSummarize AgentEventType = "summarize"
 39)
 40
// AgentEvent is emitted by the agent: Run delivers it on the channel it
// returns, and Run and Summarize publish it to subscribers.
//
// Type selects which fields are meaningful: Message holds the final assistant
// message of an AgentEventTypeResponse event, Error holds the cause of an
// AgentEventTypeError event.
type AgentEvent struct {
	Type    AgentEventType
	Message message.Message
	Error   error

	// Summarize reporting: Progress is a human-readable status line and
	// SessionID is set on the final "Summary complete" event. Done is true on
	// Run's response event and on Summarize's error and completion events.
	SessionID string
	Progress  string
	Done      bool
}
 51
 52type Service interface {
 53	pubsub.Suscriber[AgentEvent]
 54	Model() fur.Model
 55	Run(ctx context.Context, sessionID string, content string, attachments ...message.Attachment) (<-chan AgentEvent, error)
 56	Cancel(sessionID string)
 57	CancelAll()
 58	IsSessionBusy(sessionID string) bool
 59	IsBusy() bool
 60	Summarize(ctx context.Context, sessionID string) error
 61	UpdateModel() error
 62}
 63
// agent is the Service implementation. It holds the provider clients for one
// configured agent (main, title and summarize) and tracks in-flight work.
type agent struct {
	// Embedded broker used to publish AgentEvents, the agent's configuration,
	// and the session and message stores.
	*pubsub.Broker[AgentEvent]
	agentCfg config.Agent
	sessions session.Service
	messages message.Service

	// Tools offered to the model and the main provider client. providerID is
	// the provider config ID the client was built from; UpdateModel compares
	// it against the config to decide whether to rebuild the client.
	tools      []tools.BaseTool
	provider   provider.Provider
	providerID string

	// Clients backed by the small model, used for session titles and
	// summaries. summarizeProviderID is the provider config ID both were
	// built from; UpdateModel compares it to decide whether to rebuild them.
	titleProvider       provider.Provider
	summarizeProvider   provider.Provider
	summarizeProviderID string

	// activeRequests maps a request key to its context.CancelFunc: the bare
	// session ID for Run, "<sessionID>-summarize" for Summarize.
	activeRequests sync.Map
}
 80
 81var agentPromptMap = map[string]prompt.PromptID{
 82	"coder": prompt.PromptCoder,
 83	"task":  prompt.PromptTask,
 84}
 85
 86func NewAgent(
 87	agentCfg config.Agent,
 88	// These services are needed in the tools
 89	permissions permission.Service,
 90	sessions session.Service,
 91	messages message.Service,
 92	history history.Service,
 93	lspClients map[string]*lsp.Client,
 94) (Service, error) {
 95	ctx := context.Background()
 96	cfg := config.Get()
 97	otherTools := GetMCPTools(ctx, permissions, cfg)
 98	if len(lspClients) > 0 {
 99		otherTools = append(otherTools, tools.NewDiagnosticsTool(lspClients))
100	}
101
102	cwd := cfg.WorkingDir()
103	allTools := []tools.BaseTool{
104		tools.NewBashTool(permissions, cwd),
105		tools.NewEditTool(lspClients, permissions, history, cwd),
106		tools.NewFetchTool(permissions, cwd),
107		tools.NewGlobTool(cwd),
108		tools.NewGrepTool(cwd),
109		tools.NewLsTool(cwd),
110		tools.NewSourcegraphTool(),
111		tools.NewViewTool(lspClients, cwd),
112		tools.NewWriteTool(lspClients, permissions, history, cwd),
113	}
114
115	if agentCfg.ID == "coder" {
116		taskAgentCfg := config.Get().Agents["task"]
117		if taskAgentCfg.ID == "" {
118			return nil, fmt.Errorf("task agent not found in config")
119		}
120		taskAgent, err := NewAgent(taskAgentCfg, permissions, sessions, messages, history, lspClients)
121		if err != nil {
122			return nil, fmt.Errorf("failed to create task agent: %w", err)
123		}
124
125		allTools = append(
126			allTools,
127			NewAgentTool(
128				taskAgent,
129				sessions,
130				messages,
131			),
132		)
133	}
134
135	allTools = append(allTools, otherTools...)
136	providerCfg := config.Get().GetProviderForModel(agentCfg.Model)
137	if providerCfg == nil {
138		return nil, fmt.Errorf("provider for agent %s not found in config", agentCfg.Name)
139	}
140	model := config.Get().GetModelByType(agentCfg.Model)
141
142	if model == nil {
143		return nil, fmt.Errorf("model not found for agent %s", agentCfg.Name)
144	}
145
146	promptID := agentPromptMap[agentCfg.ID]
147	if promptID == "" {
148		promptID = prompt.PromptDefault
149	}
150	opts := []provider.ProviderClientOption{
151		provider.WithModel(agentCfg.Model),
152		provider.WithSystemMessage(prompt.GetPrompt(promptID, providerCfg.ID, config.Get().Options.ContextPaths...)),
153	}
154	agentProvider, err := provider.NewProvider(*providerCfg, opts...)
155	if err != nil {
156		return nil, err
157	}
158
159	smallModelCfg := cfg.Models[config.SelectedModelTypeSmall]
160	var smallModelProviderCfg *config.ProviderConfig
161	if smallModelCfg.Provider == providerCfg.ID {
162		smallModelProviderCfg = providerCfg
163	} else {
164		smallModelProviderCfg = cfg.GetProviderForModel(config.SelectedModelTypeSmall)
165
166		if smallModelProviderCfg.ID == "" {
167			return nil, fmt.Errorf("provider %s not found in config", smallModelCfg.Provider)
168		}
169	}
170	smallModel := cfg.GetModelByType(config.SelectedModelTypeSmall)
171	if smallModel.ID == "" {
172		return nil, fmt.Errorf("model %s not found in provider %s", smallModelCfg.Model, smallModelProviderCfg.ID)
173	}
174
175	titleOpts := []provider.ProviderClientOption{
176		provider.WithModel(config.SelectedModelTypeSmall),
177		provider.WithSystemMessage(prompt.GetPrompt(prompt.PromptTitle, smallModelProviderCfg.ID)),
178	}
179	titleProvider, err := provider.NewProvider(*smallModelProviderCfg, titleOpts...)
180	if err != nil {
181		return nil, err
182	}
183	summarizeOpts := []provider.ProviderClientOption{
184		provider.WithModel(config.SelectedModelTypeSmall),
185		provider.WithSystemMessage(prompt.GetPrompt(prompt.PromptSummarizer, smallModelProviderCfg.ID)),
186	}
187	summarizeProvider, err := provider.NewProvider(*smallModelProviderCfg, summarizeOpts...)
188	if err != nil {
189		return nil, err
190	}
191
192	agentTools := []tools.BaseTool{}
193	if agentCfg.AllowedTools == nil {
194		agentTools = allTools
195	} else {
196		for _, tool := range allTools {
197			if slices.Contains(agentCfg.AllowedTools, tool.Name()) {
198				agentTools = append(agentTools, tool)
199			}
200		}
201	}
202
203	agent := &agent{
204		Broker:              pubsub.NewBroker[AgentEvent](),
205		agentCfg:            agentCfg,
206		provider:            agentProvider,
207		providerID:          string(providerCfg.ID),
208		messages:            messages,
209		sessions:            sessions,
210		tools:               agentTools,
211		titleProvider:       titleProvider,
212		summarizeProvider:   summarizeProvider,
213		summarizeProviderID: string(smallModelProviderCfg.ID),
214		activeRequests:      sync.Map{},
215	}
216
217	return agent, nil
218}
219
220func (a *agent) Model() fur.Model {
221	return *config.Get().GetModelByType(a.agentCfg.Model)
222}
223
224func (a *agent) Cancel(sessionID string) {
225	// Cancel regular requests
226	if cancelFunc, exists := a.activeRequests.LoadAndDelete(sessionID); exists {
227		if cancel, ok := cancelFunc.(context.CancelFunc); ok {
228			slog.Info(fmt.Sprintf("Request cancellation initiated for session: %s", sessionID))
229			cancel()
230		}
231	}
232
233	// Also check for summarize requests
234	if cancelFunc, exists := a.activeRequests.LoadAndDelete(sessionID + "-summarize"); exists {
235		if cancel, ok := cancelFunc.(context.CancelFunc); ok {
236			slog.Info(fmt.Sprintf("Summarize cancellation initiated for session: %s", sessionID))
237			cancel()
238		}
239	}
240}
241
242func (a *agent) IsBusy() bool {
243	busy := false
244	a.activeRequests.Range(func(key, value any) bool {
245		if cancelFunc, ok := value.(context.CancelFunc); ok {
246			if cancelFunc != nil {
247				busy = true
248				return false
249			}
250		}
251		return true
252	})
253	return busy
254}
255
256func (a *agent) IsSessionBusy(sessionID string) bool {
257	_, busy := a.activeRequests.Load(sessionID)
258	return busy
259}
260
261func (a *agent) generateTitle(ctx context.Context, sessionID string, content string) error {
262	if content == "" {
263		return nil
264	}
265	if a.titleProvider == nil {
266		return nil
267	}
268	session, err := a.sessions.Get(ctx, sessionID)
269	if err != nil {
270		return err
271	}
272	parts := []message.ContentPart{message.TextContent{
273		Text: fmt.Sprintf("Generate a concise title for the following content:\n\n%s", content),
274	}}
275
276	// Use streaming approach like summarization
277	response := a.titleProvider.StreamResponse(
278		ctx,
279		[]message.Message{
280			{
281				Role:  message.User,
282				Parts: parts,
283			},
284		},
285		make([]tools.BaseTool, 0),
286	)
287
288	var finalResponse *provider.ProviderResponse
289	for r := range response {
290		if r.Error != nil {
291			return r.Error
292		}
293		finalResponse = r.Response
294	}
295
296	if finalResponse == nil {
297		return fmt.Errorf("no response received from title provider")
298	}
299
300	title := strings.TrimSpace(strings.ReplaceAll(finalResponse.Content, "\n", " "))
301	if title == "" {
302		return nil
303	}
304
305	session.Title = title
306	_, err = a.sessions.Save(ctx, session)
307	return err
308}
309
310func (a *agent) err(err error) AgentEvent {
311	return AgentEvent{
312		Type:  AgentEventTypeError,
313		Error: err,
314	}
315}
316
317func (a *agent) Run(ctx context.Context, sessionID string, content string, attachments ...message.Attachment) (<-chan AgentEvent, error) {
318	if !a.Model().SupportsImages && attachments != nil {
319		attachments = nil
320	}
321	events := make(chan AgentEvent)
322	if a.IsSessionBusy(sessionID) {
323		return nil, ErrSessionBusy
324	}
325
326	genCtx, cancel := context.WithCancel(ctx)
327
328	a.activeRequests.Store(sessionID, cancel)
329	go func() {
330		slog.Debug("Request started", "sessionID", sessionID)
331		defer log.RecoverPanic("agent.Run", func() {
332			events <- a.err(fmt.Errorf("panic while running the agent"))
333		})
334		var attachmentParts []message.ContentPart
335		for _, attachment := range attachments {
336			attachmentParts = append(attachmentParts, message.BinaryContent{Path: attachment.FilePath, MIMEType: attachment.MimeType, Data: attachment.Content})
337		}
338		result := a.processGeneration(genCtx, sessionID, content, attachmentParts)
339		if result.Error != nil && !errors.Is(result.Error, ErrRequestCancelled) && !errors.Is(result.Error, context.Canceled) {
340			slog.Error(result.Error.Error())
341		}
342		slog.Debug("Request completed", "sessionID", sessionID)
343		a.activeRequests.Delete(sessionID)
344		cancel()
345		a.Publish(pubsub.CreatedEvent, result)
346		events <- result
347		close(events)
348	}()
349	return events, nil
350}
351
352func (a *agent) processGeneration(ctx context.Context, sessionID, content string, attachmentParts []message.ContentPart) AgentEvent {
353	cfg := config.Get()
354	// List existing messages; if none, start title generation asynchronously.
355	msgs, err := a.messages.List(ctx, sessionID)
356	if err != nil {
357		return a.err(fmt.Errorf("failed to list messages: %w", err))
358	}
359	if len(msgs) == 0 {
360		go func() {
361			defer log.RecoverPanic("agent.Run", func() {
362				slog.Error("panic while generating title")
363			})
364			titleErr := a.generateTitle(context.Background(), sessionID, content)
365			if titleErr != nil && !errors.Is(titleErr, context.Canceled) && !errors.Is(titleErr, context.DeadlineExceeded) {
366				slog.Error(fmt.Sprintf("failed to generate title: %v", titleErr))
367			}
368		}()
369	}
370	session, err := a.sessions.Get(ctx, sessionID)
371	if err != nil {
372		return a.err(fmt.Errorf("failed to get session: %w", err))
373	}
374	if session.SummaryMessageID != "" {
375		summaryMsgInex := -1
376		for i, msg := range msgs {
377			if msg.ID == session.SummaryMessageID {
378				summaryMsgInex = i
379				break
380			}
381		}
382		if summaryMsgInex != -1 {
383			msgs = msgs[summaryMsgInex:]
384			msgs[0].Role = message.User
385		}
386	}
387
388	userMsg, err := a.createUserMessage(ctx, sessionID, content, attachmentParts)
389	if err != nil {
390		return a.err(fmt.Errorf("failed to create user message: %w", err))
391	}
392	// Append the new user message to the conversation history.
393	msgHistory := append(msgs, userMsg)
394
395	for {
396		// Check for cancellation before each iteration
397		select {
398		case <-ctx.Done():
399			return a.err(ctx.Err())
400		default:
401			// Continue processing
402		}
403		agentMessage, toolResults, err := a.streamAndHandleEvents(ctx, sessionID, msgHistory)
404		if err != nil {
405			if errors.Is(err, context.Canceled) {
406				agentMessage.AddFinish(message.FinishReasonCanceled, "Request cancelled", "")
407				a.messages.Update(context.Background(), agentMessage)
408				return a.err(ErrRequestCancelled)
409			}
410			return a.err(fmt.Errorf("failed to process events: %w", err))
411		}
412		if cfg.Options.Debug {
413			slog.Info("Result", "message", agentMessage.FinishReason(), "toolResults", toolResults)
414		}
415		if (agentMessage.FinishReason() == message.FinishReasonToolUse) && toolResults != nil {
416			// We are not done, we need to respond with the tool response
417			msgHistory = append(msgHistory, agentMessage, *toolResults)
418			continue
419		}
420		return AgentEvent{
421			Type:    AgentEventTypeResponse,
422			Message: agentMessage,
423			Done:    true,
424		}
425	}
426}
427
428func (a *agent) createUserMessage(ctx context.Context, sessionID, content string, attachmentParts []message.ContentPart) (message.Message, error) {
429	parts := []message.ContentPart{message.TextContent{Text: content}}
430	parts = append(parts, attachmentParts...)
431	return a.messages.Create(ctx, sessionID, message.CreateMessageParams{
432		Role:  message.User,
433		Parts: parts,
434	})
435}
436
437func (a *agent) streamAndHandleEvents(ctx context.Context, sessionID string, msgHistory []message.Message) (message.Message, *message.Message, error) {
438	ctx = context.WithValue(ctx, tools.SessionIDContextKey, sessionID)
439	eventChan := a.provider.StreamResponse(ctx, msgHistory, a.tools)
440
441	assistantMsg, err := a.messages.Create(ctx, sessionID, message.CreateMessageParams{
442		Role:     message.Assistant,
443		Parts:    []message.ContentPart{},
444		Model:    a.Model().ID,
445		Provider: a.providerID,
446	})
447	if err != nil {
448		return assistantMsg, nil, fmt.Errorf("failed to create assistant message: %w", err)
449	}
450
451	// Add the session and message ID into the context if needed by tools.
452	ctx = context.WithValue(ctx, tools.MessageIDContextKey, assistantMsg.ID)
453
454	// Process each event in the stream.
455	for event := range eventChan {
456		if processErr := a.processEvent(ctx, sessionID, &assistantMsg, event); processErr != nil {
457			if errors.Is(processErr, context.Canceled) {
458				a.finishMessage(context.Background(), &assistantMsg, message.FinishReasonCanceled, "Request cancelled", "")
459			} else {
460				a.finishMessage(ctx, &assistantMsg, message.FinishReasonError, "API Error", processErr.Error())
461			}
462			return assistantMsg, nil, processErr
463		}
464		if ctx.Err() != nil {
465			a.finishMessage(context.Background(), &assistantMsg, message.FinishReasonCanceled, "Request cancelled", "")
466			return assistantMsg, nil, ctx.Err()
467		}
468	}
469
470	toolResults := make([]message.ToolResult, len(assistantMsg.ToolCalls()))
471	toolCalls := assistantMsg.ToolCalls()
472	for i, toolCall := range toolCalls {
473		select {
474		case <-ctx.Done():
475			a.finishMessage(context.Background(), &assistantMsg, message.FinishReasonCanceled, "Request cancelled", "")
476			// Make all future tool calls cancelled
477			for j := i; j < len(toolCalls); j++ {
478				toolResults[j] = message.ToolResult{
479					ToolCallID: toolCalls[j].ID,
480					Content:    "Tool execution canceled by user",
481					IsError:    true,
482				}
483			}
484			goto out
485		default:
486			// Continue processing
487			var tool tools.BaseTool
488			for _, availableTool := range a.tools {
489				if availableTool.Info().Name == toolCall.Name {
490					tool = availableTool
491					break
492				}
493			}
494
495			// Tool not found
496			if tool == nil {
497				toolResults[i] = message.ToolResult{
498					ToolCallID: toolCall.ID,
499					Content:    fmt.Sprintf("Tool not found: %s", toolCall.Name),
500					IsError:    true,
501				}
502				continue
503			}
504
505			// Run tool in goroutine to allow cancellation
506			type toolExecResult struct {
507				response tools.ToolResponse
508				err      error
509			}
510			resultChan := make(chan toolExecResult, 1)
511
512			go func() {
513				response, err := tool.Run(ctx, tools.ToolCall{
514					ID:    toolCall.ID,
515					Name:  toolCall.Name,
516					Input: toolCall.Input,
517				})
518				resultChan <- toolExecResult{response: response, err: err}
519			}()
520
521			var toolResponse tools.ToolResponse
522			var toolErr error
523
524			select {
525			case <-ctx.Done():
526				a.finishMessage(context.Background(), &assistantMsg, message.FinishReasonCanceled, "Request cancelled", "")
527				// Mark remaining tool calls as cancelled
528				for j := i; j < len(toolCalls); j++ {
529					toolResults[j] = message.ToolResult{
530						ToolCallID: toolCalls[j].ID,
531						Content:    "Tool execution canceled by user",
532						IsError:    true,
533					}
534				}
535				goto out
536			case result := <-resultChan:
537				toolResponse = result.response
538				toolErr = result.err
539			}
540
541			if toolErr != nil {
542				slog.Error("Tool execution error", "toolCall", toolCall.ID, "error", toolErr)
543				if errors.Is(toolErr, permission.ErrorPermissionDenied) {
544					toolResults[i] = message.ToolResult{
545						ToolCallID: toolCall.ID,
546						Content:    "Permission denied",
547						IsError:    true,
548					}
549					for j := i + 1; j < len(toolCalls); j++ {
550						toolResults[j] = message.ToolResult{
551							ToolCallID: toolCalls[j].ID,
552							Content:    "Tool execution canceled by user",
553							IsError:    true,
554						}
555					}
556					a.finishMessage(ctx, &assistantMsg, message.FinishReasonPermissionDenied, "Permission denied", "")
557					break
558				}
559			}
560			toolResults[i] = message.ToolResult{
561				ToolCallID: toolCall.ID,
562				Content:    toolResponse.Content,
563				Metadata:   toolResponse.Metadata,
564				IsError:    toolResponse.IsError,
565			}
566		}
567	}
568out:
569	if len(toolResults) == 0 {
570		return assistantMsg, nil, nil
571	}
572	parts := make([]message.ContentPart, 0)
573	for _, tr := range toolResults {
574		parts = append(parts, tr)
575	}
576	msg, err := a.messages.Create(context.Background(), assistantMsg.SessionID, message.CreateMessageParams{
577		Role:     message.Tool,
578		Parts:    parts,
579		Provider: a.providerID,
580	})
581	if err != nil {
582		return assistantMsg, nil, fmt.Errorf("failed to create cancelled tool message: %w", err)
583	}
584
585	return assistantMsg, &msg, err
586}
587
588func (a *agent) finishMessage(ctx context.Context, msg *message.Message, finishReason message.FinishReason, message, details string) {
589	msg.AddFinish(finishReason, message, details)
590	_ = a.messages.Update(ctx, *msg)
591}
592
593func (a *agent) processEvent(ctx context.Context, sessionID string, assistantMsg *message.Message, event provider.ProviderEvent) error {
594	select {
595	case <-ctx.Done():
596		return ctx.Err()
597	default:
598		// Continue processing.
599	}
600
601	switch event.Type {
602	case provider.EventThinkingDelta:
603		assistantMsg.AppendReasoningContent(event.Thinking)
604		return a.messages.Update(ctx, *assistantMsg)
605	case provider.EventSignatureDelta:
606		assistantMsg.AppendReasoningSignature(event.Signature)
607		return a.messages.Update(ctx, *assistantMsg)
608	case provider.EventContentDelta:
609		assistantMsg.FinishThinking()
610		assistantMsg.AppendContent(event.Content)
611		return a.messages.Update(ctx, *assistantMsg)
612	case provider.EventToolUseStart:
613		assistantMsg.FinishThinking()
614		slog.Info("Tool call started", "toolCall", event.ToolCall)
615		assistantMsg.AddToolCall(*event.ToolCall)
616		return a.messages.Update(ctx, *assistantMsg)
617	case provider.EventToolUseDelta:
618		assistantMsg.AppendToolCallInput(event.ToolCall.ID, event.ToolCall.Input)
619		return a.messages.Update(ctx, *assistantMsg)
620	case provider.EventToolUseStop:
621		slog.Info("Finished tool call", "toolCall", event.ToolCall)
622		assistantMsg.FinishToolCall(event.ToolCall.ID)
623		return a.messages.Update(ctx, *assistantMsg)
624	case provider.EventError:
625		return event.Error
626	case provider.EventComplete:
627		assistantMsg.FinishThinking()
628		assistantMsg.SetToolCalls(event.Response.ToolCalls)
629		assistantMsg.AddFinish(event.Response.FinishReason, "", "")
630		if err := a.messages.Update(ctx, *assistantMsg); err != nil {
631			return fmt.Errorf("failed to update message: %w", err)
632		}
633		return a.TrackUsage(ctx, sessionID, a.Model(), event.Response.Usage)
634	}
635
636	return nil
637}
638
639func (a *agent) TrackUsage(ctx context.Context, sessionID string, model fur.Model, usage provider.TokenUsage) error {
640	sess, err := a.sessions.Get(ctx, sessionID)
641	if err != nil {
642		return fmt.Errorf("failed to get session: %w", err)
643	}
644
645	cost := model.CostPer1MInCached/1e6*float64(usage.CacheCreationTokens) +
646		model.CostPer1MOutCached/1e6*float64(usage.CacheReadTokens) +
647		model.CostPer1MIn/1e6*float64(usage.InputTokens) +
648		model.CostPer1MOut/1e6*float64(usage.OutputTokens)
649
650	sess.Cost += cost
651	sess.CompletionTokens = usage.OutputTokens + usage.CacheReadTokens
652	sess.PromptTokens = usage.InputTokens + usage.CacheCreationTokens
653
654	_, err = a.sessions.Save(ctx, sess)
655	if err != nil {
656		return fmt.Errorf("failed to save session: %w", err)
657	}
658	return nil
659}
660
661func (a *agent) Summarize(ctx context.Context, sessionID string) error {
662	if a.summarizeProvider == nil {
663		return fmt.Errorf("summarize provider not available")
664	}
665
666	// Check if session is busy
667	if a.IsSessionBusy(sessionID) {
668		return ErrSessionBusy
669	}
670
671	// Create a new context with cancellation
672	summarizeCtx, cancel := context.WithCancel(ctx)
673
674	// Store the cancel function in activeRequests to allow cancellation
675	a.activeRequests.Store(sessionID+"-summarize", cancel)
676
677	go func() {
678		defer a.activeRequests.Delete(sessionID + "-summarize")
679		defer cancel()
680		event := AgentEvent{
681			Type:     AgentEventTypeSummarize,
682			Progress: "Starting summarization...",
683		}
684
685		a.Publish(pubsub.CreatedEvent, event)
686		// Get all messages from the session
687		msgs, err := a.messages.List(summarizeCtx, sessionID)
688		if err != nil {
689			event = AgentEvent{
690				Type:  AgentEventTypeError,
691				Error: fmt.Errorf("failed to list messages: %w", err),
692				Done:  true,
693			}
694			a.Publish(pubsub.CreatedEvent, event)
695			return
696		}
697		summarizeCtx = context.WithValue(summarizeCtx, tools.SessionIDContextKey, sessionID)
698
699		if len(msgs) == 0 {
700			event = AgentEvent{
701				Type:  AgentEventTypeError,
702				Error: fmt.Errorf("no messages to summarize"),
703				Done:  true,
704			}
705			a.Publish(pubsub.CreatedEvent, event)
706			return
707		}
708
709		event = AgentEvent{
710			Type:     AgentEventTypeSummarize,
711			Progress: "Analyzing conversation...",
712		}
713		a.Publish(pubsub.CreatedEvent, event)
714
715		// Add a system message to guide the summarization
716		summarizePrompt := "Provide a detailed but concise summary of our conversation above. Focus on information that would be helpful for continuing the conversation, including what we did, what we're doing, which files we're working on, and what we're going to do next."
717
718		// Create a new message with the summarize prompt
719		promptMsg := message.Message{
720			Role:  message.User,
721			Parts: []message.ContentPart{message.TextContent{Text: summarizePrompt}},
722		}
723
724		// Append the prompt to the messages
725		msgsWithPrompt := append(msgs, promptMsg)
726
727		event = AgentEvent{
728			Type:     AgentEventTypeSummarize,
729			Progress: "Generating summary...",
730		}
731
732		a.Publish(pubsub.CreatedEvent, event)
733
734		// Send the messages to the summarize provider
735		response := a.summarizeProvider.StreamResponse(
736			summarizeCtx,
737			msgsWithPrompt,
738			make([]tools.BaseTool, 0),
739		)
740		var finalResponse *provider.ProviderResponse
741		for r := range response {
742			if r.Error != nil {
743				event = AgentEvent{
744					Type:  AgentEventTypeError,
745					Error: fmt.Errorf("failed to summarize: %w", err),
746					Done:  true,
747				}
748				a.Publish(pubsub.CreatedEvent, event)
749				return
750			}
751			finalResponse = r.Response
752		}
753
754		summary := strings.TrimSpace(finalResponse.Content)
755		if summary == "" {
756			event = AgentEvent{
757				Type:  AgentEventTypeError,
758				Error: fmt.Errorf("empty summary returned"),
759				Done:  true,
760			}
761			a.Publish(pubsub.CreatedEvent, event)
762			return
763		}
764		event = AgentEvent{
765			Type:     AgentEventTypeSummarize,
766			Progress: "Creating new session...",
767		}
768
769		a.Publish(pubsub.CreatedEvent, event)
770		oldSession, err := a.sessions.Get(summarizeCtx, sessionID)
771		if err != nil {
772			event = AgentEvent{
773				Type:  AgentEventTypeError,
774				Error: fmt.Errorf("failed to get session: %w", err),
775				Done:  true,
776			}
777
778			a.Publish(pubsub.CreatedEvent, event)
779			return
780		}
781		// Create a message in the new session with the summary
782		msg, err := a.messages.Create(summarizeCtx, oldSession.ID, message.CreateMessageParams{
783			Role: message.Assistant,
784			Parts: []message.ContentPart{
785				message.TextContent{Text: summary},
786				message.Finish{
787					Reason: message.FinishReasonEndTurn,
788					Time:   time.Now().Unix(),
789				},
790			},
791			Model:    a.summarizeProvider.Model().ID,
792			Provider: a.summarizeProviderID,
793		})
794		if err != nil {
795			event = AgentEvent{
796				Type:  AgentEventTypeError,
797				Error: fmt.Errorf("failed to create summary message: %w", err),
798				Done:  true,
799			}
800
801			a.Publish(pubsub.CreatedEvent, event)
802			return
803		}
804		oldSession.SummaryMessageID = msg.ID
805		oldSession.CompletionTokens = finalResponse.Usage.OutputTokens
806		oldSession.PromptTokens = 0
807		model := a.summarizeProvider.Model()
808		usage := finalResponse.Usage
809		cost := model.CostPer1MInCached/1e6*float64(usage.CacheCreationTokens) +
810			model.CostPer1MOutCached/1e6*float64(usage.CacheReadTokens) +
811			model.CostPer1MIn/1e6*float64(usage.InputTokens) +
812			model.CostPer1MOut/1e6*float64(usage.OutputTokens)
813		oldSession.Cost += cost
814		_, err = a.sessions.Save(summarizeCtx, oldSession)
815		if err != nil {
816			event = AgentEvent{
817				Type:  AgentEventTypeError,
818				Error: fmt.Errorf("failed to save session: %w", err),
819				Done:  true,
820			}
821			a.Publish(pubsub.CreatedEvent, event)
822		}
823
824		event = AgentEvent{
825			Type:      AgentEventTypeSummarize,
826			SessionID: oldSession.ID,
827			Progress:  "Summary complete",
828			Done:      true,
829		}
830		a.Publish(pubsub.CreatedEvent, event)
831		// Send final success event with the new session ID
832	}()
833
834	return nil
835}
836
837func (a *agent) CancelAll() {
838	if !a.IsBusy() {
839		return
840	}
841	a.activeRequests.Range(func(key, value any) bool {
842		a.Cancel(key.(string)) // key is sessionID
843		return true
844	})
845
846	timeout := time.After(5 * time.Second)
847	for a.IsBusy() {
848		select {
849		case <-timeout:
850			return
851		default:
852			time.Sleep(200 * time.Millisecond)
853		}
854	}
855}
856
857func (a *agent) UpdateModel() error {
858	cfg := config.Get()
859
860	// Get current provider configuration
861	currentProviderCfg := cfg.GetProviderForModel(a.agentCfg.Model)
862	if currentProviderCfg == nil || currentProviderCfg.ID == "" {
863		return fmt.Errorf("provider for agent %s not found in config", a.agentCfg.Name)
864	}
865
866	// Check if provider has changed
867	if string(currentProviderCfg.ID) != a.providerID {
868		// Provider changed, need to recreate the main provider
869		model := cfg.GetModelByType(a.agentCfg.Model)
870		if model.ID == "" {
871			return fmt.Errorf("model not found for agent %s", a.agentCfg.Name)
872		}
873
874		promptID := agentPromptMap[a.agentCfg.ID]
875		if promptID == "" {
876			promptID = prompt.PromptDefault
877		}
878
879		opts := []provider.ProviderClientOption{
880			provider.WithModel(a.agentCfg.Model),
881			provider.WithSystemMessage(prompt.GetPrompt(promptID, currentProviderCfg.ID, cfg.Options.ContextPaths...)),
882		}
883
884		newProvider, err := provider.NewProvider(*currentProviderCfg, opts...)
885		if err != nil {
886			return fmt.Errorf("failed to create new provider: %w", err)
887		}
888
889		// Update the provider and provider ID
890		a.provider = newProvider
891		a.providerID = string(currentProviderCfg.ID)
892	}
893
894	// Check if small model provider has changed (affects title and summarize providers)
895	smallModelCfg := cfg.Models[config.SelectedModelTypeSmall]
896	var smallModelProviderCfg config.ProviderConfig
897
898	for _, p := range cfg.Providers {
899		if p.ID == smallModelCfg.Provider {
900			smallModelProviderCfg = p
901			break
902		}
903	}
904
905	if smallModelProviderCfg.ID == "" {
906		return fmt.Errorf("provider %s not found in config", smallModelCfg.Provider)
907	}
908
909	// Check if summarize provider has changed
910	if string(smallModelProviderCfg.ID) != a.summarizeProviderID {
911		smallModel := cfg.GetModelByType(config.SelectedModelTypeSmall)
912		if smallModel == nil {
913			return fmt.Errorf("model %s not found in provider %s", smallModelCfg.Model, smallModelProviderCfg.ID)
914		}
915
916		// Recreate title provider
917		titleOpts := []provider.ProviderClientOption{
918			provider.WithModel(config.SelectedModelTypeSmall),
919			provider.WithSystemMessage(prompt.GetPrompt(prompt.PromptTitle, smallModelProviderCfg.ID)),
920			// We want the title to be short, so we limit the max tokens
921			provider.WithMaxTokens(40),
922		}
923		newTitleProvider, err := provider.NewProvider(smallModelProviderCfg, titleOpts...)
924		if err != nil {
925			return fmt.Errorf("failed to create new title provider: %w", err)
926		}
927
928		// Recreate summarize provider
929		summarizeOpts := []provider.ProviderClientOption{
930			provider.WithModel(config.SelectedModelTypeSmall),
931			provider.WithSystemMessage(prompt.GetPrompt(prompt.PromptSummarizer, smallModelProviderCfg.ID)),
932		}
933		newSummarizeProvider, err := provider.NewProvider(smallModelProviderCfg, summarizeOpts...)
934		if err != nil {
935			return fmt.Errorf("failed to create new summarize provider: %w", err)
936		}
937
938		// Update the providers and provider ID
939		a.titleProvider = newTitleProvider
940		a.summarizeProvider = newSummarizeProvider
941		a.summarizeProviderID = string(smallModelProviderCfg.ID)
942	}
943
944	return nil
945}