1package agent
  2
  3import (
  4	"context"
  5	"errors"
  6	"fmt"
  7	"log/slog"
  8	"slices"
  9	"strings"
 10	"sync"
 11	"time"
 12
 13	"github.com/charmbracelet/crush/internal/config"
 14	"github.com/charmbracelet/crush/internal/csync"
 15	fur "github.com/charmbracelet/crush/internal/fur/provider"
 16	"github.com/charmbracelet/crush/internal/history"
 17	"github.com/charmbracelet/crush/internal/llm/prompt"
 18	"github.com/charmbracelet/crush/internal/llm/provider"
 19	"github.com/charmbracelet/crush/internal/llm/tools"
 20	"github.com/charmbracelet/crush/internal/log"
 21	"github.com/charmbracelet/crush/internal/lsp"
 22	"github.com/charmbracelet/crush/internal/message"
 23	"github.com/charmbracelet/crush/internal/permission"
 24	"github.com/charmbracelet/crush/internal/pubsub"
 25	"github.com/charmbracelet/crush/internal/session"
 26)
 27
 28// Common errors
 29var (
 30	ErrRequestCancelled = errors.New("request canceled by user")
 31	ErrSessionBusy      = errors.New("session is currently processing another request")
 32)
 33
 34type AgentEventType string
 35
 36const (
 37	AgentEventTypeError     AgentEventType = "error"
 38	AgentEventTypeResponse  AgentEventType = "response"
 39	AgentEventTypeSummarize AgentEventType = "summarize"
 40)
 41
// AgentEvent is the payload an agent publishes on its broker and delivers on
// the channel returned by Run.
type AgentEvent struct {
	// Type says which kind of event this is and which fields below are set.
	Type    AgentEventType
	// Message is the final assistant message of a successful Run.
	Message message.Message
	// Error describes the failure for AgentEventTypeError events.
	Error   error

	// When summarizing
	// SessionID is the summarized session; it is set on the final, successful
	// summarization event.
	SessionID string
	// Progress is a human-readable status for AgentEventTypeSummarize events.
	Progress  string
	// Done is set on a successful Run response and on the terminal event of a
	// summarization (whether it succeeded or failed).
	Done      bool
}
 52
// Service is the public interface of an agent. It embeds the subscriber side
// of the agent's AgentEvent broker so callers can observe published events.
type Service interface {
	pubsub.Suscriber[AgentEvent]
	// Model returns the model currently configured for the agent.
	Model() fur.Model
	// Run processes content (plus optional attachments) as a new user message
	// in sessionID. It returns ErrSessionBusy if the session already has a
	// request in flight; otherwise the returned channel delivers the final
	// event of the request.
	Run(ctx context.Context, sessionID string, content string, attachments ...message.Attachment) (<-chan AgentEvent, error)
	// Cancel stops any in-flight request or summarization for sessionID.
	Cancel(sessionID string)
	// CancelAll cancels all in-flight work and waits a bounded time for it to stop.
	CancelAll()
	// IsSessionBusy reports whether sessionID has a request in flight.
	IsSessionBusy(sessionID string) bool
	// IsBusy reports whether any work is in flight.
	IsBusy() bool
	// Summarize starts summarizing sessionID; progress is published as events.
	Summarize(ctx context.Context, sessionID string) error
	// UpdateModel rebuilds the agent's providers from the current configuration.
	UpdateModel() error
}
 64
// agent is the Service implementation returned by NewAgent.
type agent struct {
	// Broker publishes AgentEvents to subscribers.
	*pubsub.Broker[AgentEvent]
	agentCfg config.Agent
	sessions session.Service
	messages message.Service

	// tools wraps the tool constructor passed to csync.NewLazySlice in NewAgent.
	tools *csync.LazySlice[tools.BaseTool]

	// provider serves the agent's main model; providerID identifies its
	// provider config so UpdateModel can detect a provider change.
	// NOTE(review): UpdateModel reassigns these provider fields without
	// synchronization while Run/Summarize goroutines may read them — confirm
	// callers only update the model while the agent is idle.
	provider   provider.Provider
	providerID string

	// titleProvider and summarizeProvider both run on the small model;
	// summarizeProviderID identifies the provider config they were built from.
	titleProvider       provider.Provider
	summarizeProvider   provider.Provider
	summarizeProviderID string

	// activeRequests maps a request key to its context.CancelFunc. Run uses the
	// session ID as key; Summarize uses the session ID plus "-summarize".
	activeRequests sync.Map
}
 82
 83var agentPromptMap = map[string]prompt.PromptID{
 84	"coder": prompt.PromptCoder,
 85	"task":  prompt.PromptTask,
 86}
 87
 88func NewAgent(
 89	agentCfg config.Agent,
 90	// These services are needed in the tools
 91	permissions permission.Service,
 92	sessions session.Service,
 93	messages message.Service,
 94	history history.Service,
 95	lspClients map[string]*lsp.Client,
 96) (Service, error) {
 97	ctx := context.Background()
 98	cfg := config.Get()
 99
100	var agentTool tools.BaseTool
101	if agentCfg.ID == "coder" {
102		taskAgentCfg := config.Get().Agents["task"]
103		if taskAgentCfg.ID == "" {
104			return nil, fmt.Errorf("task agent not found in config")
105		}
106		taskAgent, err := NewAgent(taskAgentCfg, permissions, sessions, messages, history, lspClients)
107		if err != nil {
108			return nil, fmt.Errorf("failed to create task agent: %w", err)
109		}
110
111		agentTool = NewAgentTool(taskAgent, sessions, messages)
112	}
113
114	providerCfg := config.Get().GetProviderForModel(agentCfg.Model)
115	if providerCfg == nil {
116		return nil, fmt.Errorf("provider for agent %s not found in config", agentCfg.Name)
117	}
118	model := config.Get().GetModelByType(agentCfg.Model)
119
120	if model == nil {
121		return nil, fmt.Errorf("model not found for agent %s", agentCfg.Name)
122	}
123
124	promptID := agentPromptMap[agentCfg.ID]
125	if promptID == "" {
126		promptID = prompt.PromptDefault
127	}
128	opts := []provider.ProviderClientOption{
129		provider.WithModel(agentCfg.Model),
130		provider.WithSystemMessage(prompt.GetPrompt(promptID, providerCfg.ID, config.Get().Options.ContextPaths...)),
131	}
132	agentProvider, err := provider.NewProvider(*providerCfg, opts...)
133	if err != nil {
134		return nil, err
135	}
136
137	smallModelCfg := cfg.Models[config.SelectedModelTypeSmall]
138	var smallModelProviderCfg *config.ProviderConfig
139	if smallModelCfg.Provider == providerCfg.ID {
140		smallModelProviderCfg = providerCfg
141	} else {
142		smallModelProviderCfg = cfg.GetProviderForModel(config.SelectedModelTypeSmall)
143
144		if smallModelProviderCfg.ID == "" {
145			return nil, fmt.Errorf("provider %s not found in config", smallModelCfg.Provider)
146		}
147	}
148	smallModel := cfg.GetModelByType(config.SelectedModelTypeSmall)
149	if smallModel.ID == "" {
150		return nil, fmt.Errorf("model %s not found in provider %s", smallModelCfg.Model, smallModelProviderCfg.ID)
151	}
152
153	titleOpts := []provider.ProviderClientOption{
154		provider.WithModel(config.SelectedModelTypeSmall),
155		provider.WithSystemMessage(prompt.GetPrompt(prompt.PromptTitle, smallModelProviderCfg.ID)),
156	}
157	titleProvider, err := provider.NewProvider(*smallModelProviderCfg, titleOpts...)
158	if err != nil {
159		return nil, err
160	}
161	summarizeOpts := []provider.ProviderClientOption{
162		provider.WithModel(config.SelectedModelTypeSmall),
163		provider.WithSystemMessage(prompt.GetPrompt(prompt.PromptSummarizer, smallModelProviderCfg.ID)),
164	}
165	summarizeProvider, err := provider.NewProvider(*smallModelProviderCfg, summarizeOpts...)
166	if err != nil {
167		return nil, err
168	}
169
170	toolFn := func() []tools.BaseTool {
171		slog.Info("Initializing agent tools", "agent", agentCfg.ID)
172		defer func() {
173			slog.Info("Initialized agent tools", "agent", agentCfg.ID)
174		}()
175
176		cwd := cfg.WorkingDir()
177		allTools := []tools.BaseTool{
178			tools.NewBashTool(permissions, cwd),
179			tools.NewDownloadTool(permissions, cwd),
180			tools.NewEditTool(lspClients, permissions, history, cwd),
181			tools.NewFetchTool(permissions, cwd),
182			tools.NewGlobTool(cwd),
183			tools.NewGrepTool(cwd),
184			tools.NewLsTool(cwd),
185			tools.NewSourcegraphTool(),
186			tools.NewViewTool(lspClients, cwd),
187			tools.NewWriteTool(lspClients, permissions, history, cwd),
188		}
189
190		mcpTools := GetMCPTools(ctx, permissions, cfg)
191		allTools = append(allTools, mcpTools...)
192
193		if len(lspClients) > 0 {
194			allTools = append(allTools, tools.NewDiagnosticsTool(lspClients))
195		}
196
197		if agentTool != nil {
198			allTools = append(allTools, agentTool)
199		}
200
201		if agentCfg.AllowedTools == nil {
202			return allTools
203		}
204
205		var filteredTools []tools.BaseTool
206		for _, tool := range allTools {
207			if slices.Contains(agentCfg.AllowedTools, tool.Name()) {
208				filteredTools = append(filteredTools, tool)
209			}
210		}
211		return filteredTools
212	}
213
214	return &agent{
215		Broker:              pubsub.NewBroker[AgentEvent](),
216		agentCfg:            agentCfg,
217		provider:            agentProvider,
218		providerID:          string(providerCfg.ID),
219		messages:            messages,
220		sessions:            sessions,
221		titleProvider:       titleProvider,
222		summarizeProvider:   summarizeProvider,
223		summarizeProviderID: string(smallModelProviderCfg.ID),
224		activeRequests:      sync.Map{},
225		tools:               csync.NewLazySlice(toolFn),
226	}, nil
227}
228
229func (a *agent) Model() fur.Model {
230	return *config.Get().GetModelByType(a.agentCfg.Model)
231}
232
233func (a *agent) Cancel(sessionID string) {
234	// Cancel regular requests
235	if cancelFunc, exists := a.activeRequests.LoadAndDelete(sessionID); exists {
236		if cancel, ok := cancelFunc.(context.CancelFunc); ok {
237			slog.Info(fmt.Sprintf("Request cancellation initiated for session: %s", sessionID))
238			cancel()
239		}
240	}
241
242	// Also check for summarize requests
243	if cancelFunc, exists := a.activeRequests.LoadAndDelete(sessionID + "-summarize"); exists {
244		if cancel, ok := cancelFunc.(context.CancelFunc); ok {
245			slog.Info(fmt.Sprintf("Summarize cancellation initiated for session: %s", sessionID))
246			cancel()
247		}
248	}
249}
250
251func (a *agent) IsBusy() bool {
252	busy := false
253	a.activeRequests.Range(func(key, value any) bool {
254		if cancelFunc, ok := value.(context.CancelFunc); ok {
255			if cancelFunc != nil {
256				busy = true
257				return false
258			}
259		}
260		return true
261	})
262	return busy
263}
264
265func (a *agent) IsSessionBusy(sessionID string) bool {
266	_, busy := a.activeRequests.Load(sessionID)
267	return busy
268}
269
270func (a *agent) generateTitle(ctx context.Context, sessionID string, content string) error {
271	if content == "" {
272		return nil
273	}
274	if a.titleProvider == nil {
275		return nil
276	}
277	session, err := a.sessions.Get(ctx, sessionID)
278	if err != nil {
279		return err
280	}
281	parts := []message.ContentPart{message.TextContent{
282		Text: fmt.Sprintf("Generate a concise title for the following content:\n\n%s", content),
283	}}
284
285	// Use streaming approach like summarization
286	response := a.titleProvider.StreamResponse(
287		ctx,
288		[]message.Message{
289			{
290				Role:  message.User,
291				Parts: parts,
292			},
293		},
294		make([]tools.BaseTool, 0),
295	)
296
297	var finalResponse *provider.ProviderResponse
298	for r := range response {
299		if r.Error != nil {
300			return r.Error
301		}
302		finalResponse = r.Response
303	}
304
305	if finalResponse == nil {
306		return fmt.Errorf("no response received from title provider")
307	}
308
309	title := strings.TrimSpace(strings.ReplaceAll(finalResponse.Content, "\n", " "))
310	if title == "" {
311		return nil
312	}
313
314	session.Title = title
315	_, err = a.sessions.Save(ctx, session)
316	return err
317}
318
319func (a *agent) err(err error) AgentEvent {
320	return AgentEvent{
321		Type:  AgentEventTypeError,
322		Error: err,
323	}
324}
325
326func (a *agent) Run(ctx context.Context, sessionID string, content string, attachments ...message.Attachment) (<-chan AgentEvent, error) {
327	if !a.Model().SupportsImages && attachments != nil {
328		attachments = nil
329	}
330	events := make(chan AgentEvent)
331	if a.IsSessionBusy(sessionID) {
332		return nil, ErrSessionBusy
333	}
334
335	genCtx, cancel := context.WithCancel(ctx)
336
337	a.activeRequests.Store(sessionID, cancel)
338	go func() {
339		slog.Debug("Request started", "sessionID", sessionID)
340		defer log.RecoverPanic("agent.Run", func() {
341			events <- a.err(fmt.Errorf("panic while running the agent"))
342		})
343		var attachmentParts []message.ContentPart
344		for _, attachment := range attachments {
345			attachmentParts = append(attachmentParts, message.BinaryContent{Path: attachment.FilePath, MIMEType: attachment.MimeType, Data: attachment.Content})
346		}
347		result := a.processGeneration(genCtx, sessionID, content, attachmentParts)
348		if result.Error != nil && !errors.Is(result.Error, ErrRequestCancelled) && !errors.Is(result.Error, context.Canceled) {
349			slog.Error(result.Error.Error())
350		}
351		slog.Debug("Request completed", "sessionID", sessionID)
352		a.activeRequests.Delete(sessionID)
353		cancel()
354		a.Publish(pubsub.CreatedEvent, result)
355		events <- result
356		close(events)
357	}()
358	return events, nil
359}
360
361func (a *agent) processGeneration(ctx context.Context, sessionID, content string, attachmentParts []message.ContentPart) AgentEvent {
362	cfg := config.Get()
363	// List existing messages; if none, start title generation asynchronously.
364	msgs, err := a.messages.List(ctx, sessionID)
365	if err != nil {
366		return a.err(fmt.Errorf("failed to list messages: %w", err))
367	}
368	if len(msgs) == 0 {
369		go func() {
370			defer log.RecoverPanic("agent.Run", func() {
371				slog.Error("panic while generating title")
372			})
373			titleErr := a.generateTitle(context.Background(), sessionID, content)
374			if titleErr != nil && !errors.Is(titleErr, context.Canceled) && !errors.Is(titleErr, context.DeadlineExceeded) {
375				slog.Error(fmt.Sprintf("failed to generate title: %v", titleErr))
376			}
377		}()
378	}
379	session, err := a.sessions.Get(ctx, sessionID)
380	if err != nil {
381		return a.err(fmt.Errorf("failed to get session: %w", err))
382	}
383	if session.SummaryMessageID != "" {
384		summaryMsgInex := -1
385		for i, msg := range msgs {
386			if msg.ID == session.SummaryMessageID {
387				summaryMsgInex = i
388				break
389			}
390		}
391		if summaryMsgInex != -1 {
392			msgs = msgs[summaryMsgInex:]
393			msgs[0].Role = message.User
394		}
395	}
396
397	userMsg, err := a.createUserMessage(ctx, sessionID, content, attachmentParts)
398	if err != nil {
399		return a.err(fmt.Errorf("failed to create user message: %w", err))
400	}
401	// Append the new user message to the conversation history.
402	msgHistory := append(msgs, userMsg)
403
404	for {
405		// Check for cancellation before each iteration
406		select {
407		case <-ctx.Done():
408			return a.err(ctx.Err())
409		default:
410			// Continue processing
411		}
412		agentMessage, toolResults, err := a.streamAndHandleEvents(ctx, sessionID, msgHistory)
413		if err != nil {
414			if errors.Is(err, context.Canceled) {
415				agentMessage.AddFinish(message.FinishReasonCanceled, "Request cancelled", "")
416				a.messages.Update(context.Background(), agentMessage)
417				return a.err(ErrRequestCancelled)
418			}
419			return a.err(fmt.Errorf("failed to process events: %w", err))
420		}
421		if cfg.Options.Debug {
422			slog.Info("Result", "message", agentMessage.FinishReason(), "toolResults", toolResults)
423		}
424		if (agentMessage.FinishReason() == message.FinishReasonToolUse) && toolResults != nil {
425			// We are not done, we need to respond with the tool response
426			msgHistory = append(msgHistory, agentMessage, *toolResults)
427			continue
428		}
429		return AgentEvent{
430			Type:    AgentEventTypeResponse,
431			Message: agentMessage,
432			Done:    true,
433		}
434	}
435}
436
437func (a *agent) createUserMessage(ctx context.Context, sessionID, content string, attachmentParts []message.ContentPart) (message.Message, error) {
438	parts := []message.ContentPart{message.TextContent{Text: content}}
439	parts = append(parts, attachmentParts...)
440	return a.messages.Create(ctx, sessionID, message.CreateMessageParams{
441		Role:  message.User,
442		Parts: parts,
443	})
444}
445
446func (a *agent) streamAndHandleEvents(ctx context.Context, sessionID string, msgHistory []message.Message) (message.Message, *message.Message, error) {
447	ctx = context.WithValue(ctx, tools.SessionIDContextKey, sessionID)
448	eventChan := a.provider.StreamResponse(ctx, msgHistory, slices.Collect(a.tools.Iter()))
449
450	assistantMsg, err := a.messages.Create(ctx, sessionID, message.CreateMessageParams{
451		Role:     message.Assistant,
452		Parts:    []message.ContentPart{},
453		Model:    a.Model().ID,
454		Provider: a.providerID,
455	})
456	if err != nil {
457		return assistantMsg, nil, fmt.Errorf("failed to create assistant message: %w", err)
458	}
459
460	// Add the session and message ID into the context if needed by tools.
461	ctx = context.WithValue(ctx, tools.MessageIDContextKey, assistantMsg.ID)
462
463	// Process each event in the stream.
464	for event := range eventChan {
465		if processErr := a.processEvent(ctx, sessionID, &assistantMsg, event); processErr != nil {
466			if errors.Is(processErr, context.Canceled) {
467				a.finishMessage(context.Background(), &assistantMsg, message.FinishReasonCanceled, "Request cancelled", "")
468			} else {
469				a.finishMessage(ctx, &assistantMsg, message.FinishReasonError, "API Error", processErr.Error())
470			}
471			return assistantMsg, nil, processErr
472		}
473		if ctx.Err() != nil {
474			a.finishMessage(context.Background(), &assistantMsg, message.FinishReasonCanceled, "Request cancelled", "")
475			return assistantMsg, nil, ctx.Err()
476		}
477	}
478
479	toolResults := make([]message.ToolResult, len(assistantMsg.ToolCalls()))
480	toolCalls := assistantMsg.ToolCalls()
481	for i, toolCall := range toolCalls {
482		select {
483		case <-ctx.Done():
484			a.finishMessage(context.Background(), &assistantMsg, message.FinishReasonCanceled, "Request cancelled", "")
485			// Make all future tool calls cancelled
486			for j := i; j < len(toolCalls); j++ {
487				toolResults[j] = message.ToolResult{
488					ToolCallID: toolCalls[j].ID,
489					Content:    "Tool execution canceled by user",
490					IsError:    true,
491				}
492			}
493			goto out
494		default:
495			// Continue processing
496			var tool tools.BaseTool
497			for availableTool := range a.tools.Iter() {
498				if availableTool.Info().Name == toolCall.Name {
499					tool = availableTool
500					break
501				}
502			}
503
504			// Tool not found
505			if tool == nil {
506				toolResults[i] = message.ToolResult{
507					ToolCallID: toolCall.ID,
508					Content:    fmt.Sprintf("Tool not found: %s", toolCall.Name),
509					IsError:    true,
510				}
511				continue
512			}
513
514			// Run tool in goroutine to allow cancellation
515			type toolExecResult struct {
516				response tools.ToolResponse
517				err      error
518			}
519			resultChan := make(chan toolExecResult, 1)
520
521			go func() {
522				response, err := tool.Run(ctx, tools.ToolCall{
523					ID:    toolCall.ID,
524					Name:  toolCall.Name,
525					Input: toolCall.Input,
526				})
527				resultChan <- toolExecResult{response: response, err: err}
528			}()
529
530			var toolResponse tools.ToolResponse
531			var toolErr error
532
533			select {
534			case <-ctx.Done():
535				a.finishMessage(context.Background(), &assistantMsg, message.FinishReasonCanceled, "Request cancelled", "")
536				// Mark remaining tool calls as cancelled
537				for j := i; j < len(toolCalls); j++ {
538					toolResults[j] = message.ToolResult{
539						ToolCallID: toolCalls[j].ID,
540						Content:    "Tool execution canceled by user",
541						IsError:    true,
542					}
543				}
544				goto out
545			case result := <-resultChan:
546				toolResponse = result.response
547				toolErr = result.err
548			}
549
550			if toolErr != nil {
551				slog.Error("Tool execution error", "toolCall", toolCall.ID, "error", toolErr)
552				if errors.Is(toolErr, permission.ErrorPermissionDenied) {
553					toolResults[i] = message.ToolResult{
554						ToolCallID: toolCall.ID,
555						Content:    "Permission denied",
556						IsError:    true,
557					}
558					for j := i + 1; j < len(toolCalls); j++ {
559						toolResults[j] = message.ToolResult{
560							ToolCallID: toolCalls[j].ID,
561							Content:    "Tool execution canceled by user",
562							IsError:    true,
563						}
564					}
565					a.finishMessage(ctx, &assistantMsg, message.FinishReasonPermissionDenied, "Permission denied", "")
566					break
567				}
568			}
569			toolResults[i] = message.ToolResult{
570				ToolCallID: toolCall.ID,
571				Content:    toolResponse.Content,
572				Metadata:   toolResponse.Metadata,
573				IsError:    toolResponse.IsError,
574			}
575		}
576	}
577out:
578	if len(toolResults) == 0 {
579		return assistantMsg, nil, nil
580	}
581	parts := make([]message.ContentPart, 0)
582	for _, tr := range toolResults {
583		parts = append(parts, tr)
584	}
585	msg, err := a.messages.Create(context.Background(), assistantMsg.SessionID, message.CreateMessageParams{
586		Role:     message.Tool,
587		Parts:    parts,
588		Provider: a.providerID,
589	})
590	if err != nil {
591		return assistantMsg, nil, fmt.Errorf("failed to create cancelled tool message: %w", err)
592	}
593
594	return assistantMsg, &msg, err
595}
596
597func (a *agent) finishMessage(ctx context.Context, msg *message.Message, finishReason message.FinishReason, message, details string) {
598	msg.AddFinish(finishReason, message, details)
599	_ = a.messages.Update(ctx, *msg)
600}
601
602func (a *agent) processEvent(ctx context.Context, sessionID string, assistantMsg *message.Message, event provider.ProviderEvent) error {
603	select {
604	case <-ctx.Done():
605		return ctx.Err()
606	default:
607		// Continue processing.
608	}
609
610	switch event.Type {
611	case provider.EventThinkingDelta:
612		assistantMsg.AppendReasoningContent(event.Thinking)
613		return a.messages.Update(ctx, *assistantMsg)
614	case provider.EventSignatureDelta:
615		assistantMsg.AppendReasoningSignature(event.Signature)
616		return a.messages.Update(ctx, *assistantMsg)
617	case provider.EventContentDelta:
618		assistantMsg.FinishThinking()
619		assistantMsg.AppendContent(event.Content)
620		return a.messages.Update(ctx, *assistantMsg)
621	case provider.EventToolUseStart:
622		assistantMsg.FinishThinking()
623		slog.Info("Tool call started", "toolCall", event.ToolCall)
624		assistantMsg.AddToolCall(*event.ToolCall)
625		return a.messages.Update(ctx, *assistantMsg)
626	case provider.EventToolUseDelta:
627		assistantMsg.AppendToolCallInput(event.ToolCall.ID, event.ToolCall.Input)
628		return a.messages.Update(ctx, *assistantMsg)
629	case provider.EventToolUseStop:
630		slog.Info("Finished tool call", "toolCall", event.ToolCall)
631		assistantMsg.FinishToolCall(event.ToolCall.ID)
632		return a.messages.Update(ctx, *assistantMsg)
633	case provider.EventError:
634		return event.Error
635	case provider.EventComplete:
636		assistantMsg.FinishThinking()
637		assistantMsg.SetToolCalls(event.Response.ToolCalls)
638		assistantMsg.AddFinish(event.Response.FinishReason, "", "")
639		if err := a.messages.Update(ctx, *assistantMsg); err != nil {
640			return fmt.Errorf("failed to update message: %w", err)
641		}
642		return a.TrackUsage(ctx, sessionID, a.Model(), event.Response.Usage)
643	}
644
645	return nil
646}
647
648func (a *agent) TrackUsage(ctx context.Context, sessionID string, model fur.Model, usage provider.TokenUsage) error {
649	sess, err := a.sessions.Get(ctx, sessionID)
650	if err != nil {
651		return fmt.Errorf("failed to get session: %w", err)
652	}
653
654	cost := model.CostPer1MInCached/1e6*float64(usage.CacheCreationTokens) +
655		model.CostPer1MOutCached/1e6*float64(usage.CacheReadTokens) +
656		model.CostPer1MIn/1e6*float64(usage.InputTokens) +
657		model.CostPer1MOut/1e6*float64(usage.OutputTokens)
658
659	sess.Cost += cost
660	sess.CompletionTokens = usage.OutputTokens + usage.CacheReadTokens
661	sess.PromptTokens = usage.InputTokens + usage.CacheCreationTokens
662
663	_, err = a.sessions.Save(ctx, sess)
664	if err != nil {
665		return fmt.Errorf("failed to save session: %w", err)
666	}
667	return nil
668}
669
670func (a *agent) Summarize(ctx context.Context, sessionID string) error {
671	if a.summarizeProvider == nil {
672		return fmt.Errorf("summarize provider not available")
673	}
674
675	// Check if session is busy
676	if a.IsSessionBusy(sessionID) {
677		return ErrSessionBusy
678	}
679
680	// Create a new context with cancellation
681	summarizeCtx, cancel := context.WithCancel(ctx)
682
683	// Store the cancel function in activeRequests to allow cancellation
684	a.activeRequests.Store(sessionID+"-summarize", cancel)
685
686	go func() {
687		defer a.activeRequests.Delete(sessionID + "-summarize")
688		defer cancel()
689		event := AgentEvent{
690			Type:     AgentEventTypeSummarize,
691			Progress: "Starting summarization...",
692		}
693
694		a.Publish(pubsub.CreatedEvent, event)
695		// Get all messages from the session
696		msgs, err := a.messages.List(summarizeCtx, sessionID)
697		if err != nil {
698			event = AgentEvent{
699				Type:  AgentEventTypeError,
700				Error: fmt.Errorf("failed to list messages: %w", err),
701				Done:  true,
702			}
703			a.Publish(pubsub.CreatedEvent, event)
704			return
705		}
706		summarizeCtx = context.WithValue(summarizeCtx, tools.SessionIDContextKey, sessionID)
707
708		if len(msgs) == 0 {
709			event = AgentEvent{
710				Type:  AgentEventTypeError,
711				Error: fmt.Errorf("no messages to summarize"),
712				Done:  true,
713			}
714			a.Publish(pubsub.CreatedEvent, event)
715			return
716		}
717
718		event = AgentEvent{
719			Type:     AgentEventTypeSummarize,
720			Progress: "Analyzing conversation...",
721		}
722		a.Publish(pubsub.CreatedEvent, event)
723
724		// Add a system message to guide the summarization
725		summarizePrompt := "Provide a detailed but concise summary of our conversation above. Focus on information that would be helpful for continuing the conversation, including what we did, what we're doing, which files we're working on, and what we're going to do next."
726
727		// Create a new message with the summarize prompt
728		promptMsg := message.Message{
729			Role:  message.User,
730			Parts: []message.ContentPart{message.TextContent{Text: summarizePrompt}},
731		}
732
733		// Append the prompt to the messages
734		msgsWithPrompt := append(msgs, promptMsg)
735
736		event = AgentEvent{
737			Type:     AgentEventTypeSummarize,
738			Progress: "Generating summary...",
739		}
740
741		a.Publish(pubsub.CreatedEvent, event)
742
743		// Send the messages to the summarize provider
744		response := a.summarizeProvider.StreamResponse(
745			summarizeCtx,
746			msgsWithPrompt,
747			make([]tools.BaseTool, 0),
748		)
749		var finalResponse *provider.ProviderResponse
750		for r := range response {
751			if r.Error != nil {
752				event = AgentEvent{
753					Type:  AgentEventTypeError,
754					Error: fmt.Errorf("failed to summarize: %w", err),
755					Done:  true,
756				}
757				a.Publish(pubsub.CreatedEvent, event)
758				return
759			}
760			finalResponse = r.Response
761		}
762
763		summary := strings.TrimSpace(finalResponse.Content)
764		if summary == "" {
765			event = AgentEvent{
766				Type:  AgentEventTypeError,
767				Error: fmt.Errorf("empty summary returned"),
768				Done:  true,
769			}
770			a.Publish(pubsub.CreatedEvent, event)
771			return
772		}
773		event = AgentEvent{
774			Type:     AgentEventTypeSummarize,
775			Progress: "Creating new session...",
776		}
777
778		a.Publish(pubsub.CreatedEvent, event)
779		oldSession, err := a.sessions.Get(summarizeCtx, sessionID)
780		if err != nil {
781			event = AgentEvent{
782				Type:  AgentEventTypeError,
783				Error: fmt.Errorf("failed to get session: %w", err),
784				Done:  true,
785			}
786
787			a.Publish(pubsub.CreatedEvent, event)
788			return
789		}
790		// Create a message in the new session with the summary
791		msg, err := a.messages.Create(summarizeCtx, oldSession.ID, message.CreateMessageParams{
792			Role: message.Assistant,
793			Parts: []message.ContentPart{
794				message.TextContent{Text: summary},
795				message.Finish{
796					Reason: message.FinishReasonEndTurn,
797					Time:   time.Now().Unix(),
798				},
799			},
800			Model:    a.summarizeProvider.Model().ID,
801			Provider: a.summarizeProviderID,
802		})
803		if err != nil {
804			event = AgentEvent{
805				Type:  AgentEventTypeError,
806				Error: fmt.Errorf("failed to create summary message: %w", err),
807				Done:  true,
808			}
809
810			a.Publish(pubsub.CreatedEvent, event)
811			return
812		}
813		oldSession.SummaryMessageID = msg.ID
814		oldSession.CompletionTokens = finalResponse.Usage.OutputTokens
815		oldSession.PromptTokens = 0
816		model := a.summarizeProvider.Model()
817		usage := finalResponse.Usage
818		cost := model.CostPer1MInCached/1e6*float64(usage.CacheCreationTokens) +
819			model.CostPer1MOutCached/1e6*float64(usage.CacheReadTokens) +
820			model.CostPer1MIn/1e6*float64(usage.InputTokens) +
821			model.CostPer1MOut/1e6*float64(usage.OutputTokens)
822		oldSession.Cost += cost
823		_, err = a.sessions.Save(summarizeCtx, oldSession)
824		if err != nil {
825			event = AgentEvent{
826				Type:  AgentEventTypeError,
827				Error: fmt.Errorf("failed to save session: %w", err),
828				Done:  true,
829			}
830			a.Publish(pubsub.CreatedEvent, event)
831		}
832
833		event = AgentEvent{
834			Type:      AgentEventTypeSummarize,
835			SessionID: oldSession.ID,
836			Progress:  "Summary complete",
837			Done:      true,
838		}
839		a.Publish(pubsub.CreatedEvent, event)
840		// Send final success event with the new session ID
841	}()
842
843	return nil
844}
845
846func (a *agent) CancelAll() {
847	if !a.IsBusy() {
848		return
849	}
850	a.activeRequests.Range(func(key, value any) bool {
851		a.Cancel(key.(string)) // key is sessionID
852		return true
853	})
854
855	timeout := time.After(5 * time.Second)
856	for a.IsBusy() {
857		select {
858		case <-timeout:
859			return
860		default:
861			time.Sleep(200 * time.Millisecond)
862		}
863	}
864}
865
866func (a *agent) UpdateModel() error {
867	cfg := config.Get()
868
869	// Get current provider configuration
870	currentProviderCfg := cfg.GetProviderForModel(a.agentCfg.Model)
871	if currentProviderCfg == nil || currentProviderCfg.ID == "" {
872		return fmt.Errorf("provider for agent %s not found in config", a.agentCfg.Name)
873	}
874
875	// Check if provider has changed
876	if string(currentProviderCfg.ID) != a.providerID {
877		// Provider changed, need to recreate the main provider
878		model := cfg.GetModelByType(a.agentCfg.Model)
879		if model.ID == "" {
880			return fmt.Errorf("model not found for agent %s", a.agentCfg.Name)
881		}
882
883		promptID := agentPromptMap[a.agentCfg.ID]
884		if promptID == "" {
885			promptID = prompt.PromptDefault
886		}
887
888		opts := []provider.ProviderClientOption{
889			provider.WithModel(a.agentCfg.Model),
890			provider.WithSystemMessage(prompt.GetPrompt(promptID, currentProviderCfg.ID, cfg.Options.ContextPaths...)),
891		}
892
893		newProvider, err := provider.NewProvider(*currentProviderCfg, opts...)
894		if err != nil {
895			return fmt.Errorf("failed to create new provider: %w", err)
896		}
897
898		// Update the provider and provider ID
899		a.provider = newProvider
900		a.providerID = string(currentProviderCfg.ID)
901	}
902
903	// Check if small model provider has changed (affects title and summarize providers)
904	smallModelCfg := cfg.Models[config.SelectedModelTypeSmall]
905	var smallModelProviderCfg config.ProviderConfig
906
907	for _, p := range cfg.Providers.Seq2() {
908		if p.ID == smallModelCfg.Provider {
909			smallModelProviderCfg = p
910			break
911		}
912	}
913
914	if smallModelProviderCfg.ID == "" {
915		return fmt.Errorf("provider %s not found in config", smallModelCfg.Provider)
916	}
917
918	// Check if summarize provider has changed
919	if string(smallModelProviderCfg.ID) != a.summarizeProviderID {
920		smallModel := cfg.GetModelByType(config.SelectedModelTypeSmall)
921		if smallModel == nil {
922			return fmt.Errorf("model %s not found in provider %s", smallModelCfg.Model, smallModelProviderCfg.ID)
923		}
924
925		// Recreate title provider
926		titleOpts := []provider.ProviderClientOption{
927			provider.WithModel(config.SelectedModelTypeSmall),
928			provider.WithSystemMessage(prompt.GetPrompt(prompt.PromptTitle, smallModelProviderCfg.ID)),
929			// We want the title to be short, so we limit the max tokens
930			provider.WithMaxTokens(40),
931		}
932		newTitleProvider, err := provider.NewProvider(smallModelProviderCfg, titleOpts...)
933		if err != nil {
934			return fmt.Errorf("failed to create new title provider: %w", err)
935		}
936
937		// Recreate summarize provider
938		summarizeOpts := []provider.ProviderClientOption{
939			provider.WithModel(config.SelectedModelTypeSmall),
940			provider.WithSystemMessage(prompt.GetPrompt(prompt.PromptSummarizer, smallModelProviderCfg.ID)),
941		}
942		newSummarizeProvider, err := provider.NewProvider(smallModelProviderCfg, summarizeOpts...)
943		if err != nil {
944			return fmt.Errorf("failed to create new summarize provider: %w", err)
945		}
946
947		// Update the providers and provider ID
948		a.titleProvider = newTitleProvider
949		a.summarizeProvider = newSummarizeProvider
950		a.summarizeProviderID = string(smallModelProviderCfg.ID)
951	}
952
953	return nil
954}