1package agent
  2
  3import (
  4	"context"
  5	"errors"
  6	"fmt"
  7	"log/slog"
  8	"slices"
  9	"strings"
 10	"sync"
 11	"time"
 12
 13	"github.com/charmbracelet/catwalk/pkg/catwalk"
 14	"github.com/charmbracelet/crush/internal/config"
 15	"github.com/charmbracelet/crush/internal/csync"
 16	"github.com/charmbracelet/crush/internal/history"
 17	"github.com/charmbracelet/crush/internal/llm/prompt"
 18	"github.com/charmbracelet/crush/internal/llm/provider"
 19	"github.com/charmbracelet/crush/internal/llm/tools"
 20	"github.com/charmbracelet/crush/internal/log"
 21	"github.com/charmbracelet/crush/internal/lsp"
 22	"github.com/charmbracelet/crush/internal/message"
 23	"github.com/charmbracelet/crush/internal/permission"
 24	"github.com/charmbracelet/crush/internal/pubsub"
 25	"github.com/charmbracelet/crush/internal/session"
 26	"github.com/charmbracelet/crush/internal/shell"
 27)
 28
// Common errors
//
// Sentinel errors surfaced by the agent; compare with errors.Is.
var (
	// ErrRequestCancelled is reported in the resulting AgentEvent when a
	// generation stops because its context was cancelled.
	ErrRequestCancelled = errors.New("request canceled by user")
	// ErrSessionBusy is returned by Run and Summarize when the session
	// already has an active request registered.
	ErrSessionBusy      = errors.New("session is currently processing another request")
)
 34
// AgentEventType identifies the kind of payload carried by an AgentEvent.
type AgentEventType string

// Event kinds emitted by the agent.
const (
	// AgentEventTypeError marks an event whose Error field is set.
	AgentEventTypeError     AgentEventType = "error"
	// AgentEventTypeResponse marks the final event of a generation; its
	// Message field holds the assistant message.
	AgentEventTypeResponse  AgentEventType = "response"
	// AgentEventTypeSummarize marks progress and completion events
	// published while a session is being summarized.
	AgentEventTypeSummarize AgentEventType = "summarize"
)
 42
// AgentEvent is the value published through the agent's broker and sent
// on the channel returned by Run.
type AgentEvent struct {
	// Type selects which of the fields below are meaningful.
	Type    AgentEventType
	// Message is the assistant message for response events.
	Message message.Message
	// Error is the failure for error events.
	Error   error

	// When summarizing
	// SessionID is set on the final summarize event.
	SessionID string
	// Progress is a human-readable status line.
	Progress  string
	// Done is true on the last event of an operation.
	Done      bool
}
 53
// Service is the public surface of an agent: it runs generations for
// sessions, supports cancellation, and publishes AgentEvents to
// subscribers.
type Service interface {
	pubsub.Suscriber[AgentEvent]
	// Model returns the model currently configured for this agent.
	Model() catwalk.Model
	// Run starts a generation for sessionID in the background and returns
	// a channel that receives the final event. It returns ErrSessionBusy
	// when the session already has an active request.
	Run(ctx context.Context, sessionID string, content string, attachments ...message.Attachment) (<-chan AgentEvent, error)
	// Cancel cancels the active request and any running summarization
	// for sessionID.
	Cancel(sessionID string)
	// CancelAll cancels every active request and waits a bounded time
	// for the agent to become idle.
	CancelAll()
	// IsSessionBusy reports whether sessionID has an active request.
	IsSessionBusy(sessionID string) bool
	// IsBusy reports whether any request is active.
	IsBusy() bool
	// Summarize starts summarizing the session in the background;
	// progress and the outcome are published as AgentEvents.
	Summarize(ctx context.Context, sessionID string) error
	// UpdateModel re-creates the underlying providers when the configured
	// provider has changed.
	UpdateModel() error
}
 65
// agent is the Service implementation returned by NewAgent.
type agent struct {
	// Broker publishes AgentEvents to subscribers.
	*pubsub.Broker[AgentEvent]
	agentCfg config.Agent
	sessions session.Service
	messages message.Service

	// tools is built lazily by the function passed to csync.NewLazySlice
	// in NewAgent.
	tools *csync.LazySlice[tools.BaseTool]

	// provider serves the main generation requests; providerID is the ID
	// of the provider config it was created from.
	provider   provider.Provider
	providerID string

	// titleProvider and summarizeProvider are both created from the
	// small-model provider config; summarizeProviderID records its ID.
	titleProvider       provider.Provider
	summarizeProvider   provider.Provider
	summarizeProviderID string

	// activeRequests maps a session ID (or session ID + "-summarize") to
	// the context.CancelFunc of the request currently running for it.
	activeRequests sync.Map
}
 83
// agentPromptMap maps an agent ID to its system prompt ID. Agents that
// are not listed here fall back to prompt.PromptDefault at the lookup
// sites.
var agentPromptMap = map[string]prompt.PromptID{
	"coder": prompt.PromptCoder,
	"task":  prompt.PromptTask,
}
 88
 89func NewAgent(
 90	agentCfg config.Agent,
 91	// These services are needed in the tools
 92	permissions permission.Service,
 93	sessions session.Service,
 94	messages message.Service,
 95	history history.Service,
 96	lspClients map[string]*lsp.Client,
 97) (Service, error) {
 98	ctx := context.Background()
 99	cfg := config.Get()
100
101	var agentTool tools.BaseTool
102	if agentCfg.ID == "coder" {
103		taskAgentCfg := config.Get().Agents["task"]
104		if taskAgentCfg.ID == "" {
105			return nil, fmt.Errorf("task agent not found in config")
106		}
107		taskAgent, err := NewAgent(taskAgentCfg, permissions, sessions, messages, history, lspClients)
108		if err != nil {
109			return nil, fmt.Errorf("failed to create task agent: %w", err)
110		}
111
112		agentTool = NewAgentTool(taskAgent, sessions, messages)
113	}
114
115	providerCfg := config.Get().GetProviderForModel(agentCfg.Model)
116	if providerCfg == nil {
117		return nil, fmt.Errorf("provider for agent %s not found in config", agentCfg.Name)
118	}
119	model := config.Get().GetModelByType(agentCfg.Model)
120
121	if model == nil {
122		return nil, fmt.Errorf("model not found for agent %s", agentCfg.Name)
123	}
124
125	promptID := agentPromptMap[agentCfg.ID]
126	if promptID == "" {
127		promptID = prompt.PromptDefault
128	}
129	opts := []provider.ProviderClientOption{
130		provider.WithModel(agentCfg.Model),
131		provider.WithSystemMessage(prompt.GetPrompt(promptID, providerCfg.ID, config.Get().Options.ContextPaths...)),
132	}
133	agentProvider, err := provider.NewProvider(*providerCfg, opts...)
134	if err != nil {
135		return nil, err
136	}
137
138	smallModelCfg := cfg.Models[config.SelectedModelTypeSmall]
139	var smallModelProviderCfg *config.ProviderConfig
140	if smallModelCfg.Provider == providerCfg.ID {
141		smallModelProviderCfg = providerCfg
142	} else {
143		smallModelProviderCfg = cfg.GetProviderForModel(config.SelectedModelTypeSmall)
144
145		if smallModelProviderCfg.ID == "" {
146			return nil, fmt.Errorf("provider %s not found in config", smallModelCfg.Provider)
147		}
148	}
149	smallModel := cfg.GetModelByType(config.SelectedModelTypeSmall)
150	if smallModel.ID == "" {
151		return nil, fmt.Errorf("model %s not found in provider %s", smallModelCfg.Model, smallModelProviderCfg.ID)
152	}
153
154	titleOpts := []provider.ProviderClientOption{
155		provider.WithModel(config.SelectedModelTypeSmall),
156		provider.WithSystemMessage(prompt.GetPrompt(prompt.PromptTitle, smallModelProviderCfg.ID)),
157	}
158	titleProvider, err := provider.NewProvider(*smallModelProviderCfg, titleOpts...)
159	if err != nil {
160		return nil, err
161	}
162	summarizeOpts := []provider.ProviderClientOption{
163		provider.WithModel(config.SelectedModelTypeSmall),
164		provider.WithSystemMessage(prompt.GetPrompt(prompt.PromptSummarizer, smallModelProviderCfg.ID)),
165	}
166	summarizeProvider, err := provider.NewProvider(*smallModelProviderCfg, summarizeOpts...)
167	if err != nil {
168		return nil, err
169	}
170
171	toolFn := func() []tools.BaseTool {
172		slog.Info("Initializing agent tools", "agent", agentCfg.ID)
173		defer func() {
174			slog.Info("Initialized agent tools", "agent", agentCfg.ID)
175		}()
176
177		cwd := cfg.WorkingDir()
178		allTools := []tools.BaseTool{
179			tools.NewBashTool(permissions, cwd),
180			tools.NewDownloadTool(permissions, cwd),
181			tools.NewEditTool(lspClients, permissions, history, cwd),
182			tools.NewFetchTool(permissions, cwd),
183			tools.NewGlobTool(cwd),
184			tools.NewGrepTool(cwd),
185			tools.NewLsTool(cwd),
186			tools.NewSourcegraphTool(),
187			tools.NewViewTool(lspClients, cwd),
188			tools.NewWriteTool(lspClients, permissions, history, cwd),
189		}
190
191		mcpTools := GetMCPTools(ctx, permissions, cfg)
192		allTools = append(allTools, mcpTools...)
193
194		if len(lspClients) > 0 {
195			allTools = append(allTools, tools.NewDiagnosticsTool(lspClients))
196		}
197
198		if agentTool != nil {
199			allTools = append(allTools, agentTool)
200		}
201
202		if agentCfg.AllowedTools == nil {
203			return allTools
204		}
205
206		var filteredTools []tools.BaseTool
207		for _, tool := range allTools {
208			if slices.Contains(agentCfg.AllowedTools, tool.Name()) {
209				filteredTools = append(filteredTools, tool)
210			}
211		}
212		return filteredTools
213	}
214
215	return &agent{
216		Broker:              pubsub.NewBroker[AgentEvent](),
217		agentCfg:            agentCfg,
218		provider:            agentProvider,
219		providerID:          string(providerCfg.ID),
220		messages:            messages,
221		sessions:            sessions,
222		titleProvider:       titleProvider,
223		summarizeProvider:   summarizeProvider,
224		summarizeProviderID: string(smallModelProviderCfg.ID),
225		activeRequests:      sync.Map{},
226		tools:               csync.NewLazySlice(toolFn),
227	}, nil
228}
229
230func (a *agent) Model() catwalk.Model {
231	return *config.Get().GetModelByType(a.agentCfg.Model)
232}
233
234func (a *agent) Cancel(sessionID string) {
235	// Cancel regular requests
236	if cancelFunc, exists := a.activeRequests.LoadAndDelete(sessionID); exists {
237		if cancel, ok := cancelFunc.(context.CancelFunc); ok {
238			slog.Info("Request cancellation initiated", "session_id", sessionID)
239			cancel()
240		}
241	}
242
243	// Also check for summarize requests
244	if cancelFunc, exists := a.activeRequests.LoadAndDelete(sessionID + "-summarize"); exists {
245		if cancel, ok := cancelFunc.(context.CancelFunc); ok {
246			slog.Info("Summarize cancellation initiated", "session_id", sessionID)
247			cancel()
248		}
249	}
250}
251
252func (a *agent) IsBusy() bool {
253	busy := false
254	a.activeRequests.Range(func(key, value any) bool {
255		if cancelFunc, ok := value.(context.CancelFunc); ok {
256			if cancelFunc != nil {
257				busy = true
258				return false
259			}
260		}
261		return true
262	})
263	return busy
264}
265
266func (a *agent) IsSessionBusy(sessionID string) bool {
267	_, busy := a.activeRequests.Load(sessionID)
268	return busy
269}
270
271func (a *agent) generateTitle(ctx context.Context, sessionID string, content string) error {
272	if content == "" {
273		return nil
274	}
275	if a.titleProvider == nil {
276		return nil
277	}
278	session, err := a.sessions.Get(ctx, sessionID)
279	if err != nil {
280		return err
281	}
282	parts := []message.ContentPart{message.TextContent{
283		Text: fmt.Sprintf("Generate a concise title for the following content:\n\n%s", content),
284	}}
285
286	// Use streaming approach like summarization
287	response := a.titleProvider.StreamResponse(
288		ctx,
289		[]message.Message{
290			{
291				Role:  message.User,
292				Parts: parts,
293			},
294		},
295		nil,
296	)
297
298	var finalResponse *provider.ProviderResponse
299	for r := range response {
300		if r.Error != nil {
301			return r.Error
302		}
303		finalResponse = r.Response
304	}
305
306	if finalResponse == nil {
307		return fmt.Errorf("no response received from title provider")
308	}
309
310	title := strings.TrimSpace(strings.ReplaceAll(finalResponse.Content, "\n", " "))
311	if title == "" {
312		return nil
313	}
314
315	session.Title = title
316	_, err = a.sessions.Save(ctx, session)
317	return err
318}
319
320func (a *agent) err(err error) AgentEvent {
321	return AgentEvent{
322		Type:  AgentEventTypeError,
323		Error: err,
324	}
325}
326
327func (a *agent) Run(ctx context.Context, sessionID string, content string, attachments ...message.Attachment) (<-chan AgentEvent, error) {
328	if !a.Model().SupportsImages && attachments != nil {
329		attachments = nil
330	}
331	events := make(chan AgentEvent)
332	if a.IsSessionBusy(sessionID) {
333		return nil, ErrSessionBusy
334	}
335
336	genCtx, cancel := context.WithCancel(ctx)
337
338	a.activeRequests.Store(sessionID, cancel)
339	go func() {
340		slog.Debug("Request started", "sessionID", sessionID)
341		defer log.RecoverPanic("agent.Run", func() {
342			events <- a.err(fmt.Errorf("panic while running the agent"))
343		})
344		var attachmentParts []message.ContentPart
345		for _, attachment := range attachments {
346			attachmentParts = append(attachmentParts, message.BinaryContent{Path: attachment.FilePath, MIMEType: attachment.MimeType, Data: attachment.Content})
347		}
348		result := a.processGeneration(genCtx, sessionID, content, attachmentParts)
349		if result.Error != nil && !errors.Is(result.Error, ErrRequestCancelled) && !errors.Is(result.Error, context.Canceled) {
350			slog.Error(result.Error.Error())
351		}
352		slog.Debug("Request completed", "sessionID", sessionID)
353		a.activeRequests.Delete(sessionID)
354		cancel()
355		a.Publish(pubsub.CreatedEvent, result)
356		events <- result
357		close(events)
358	}()
359	return events, nil
360}
361
362func (a *agent) processGeneration(ctx context.Context, sessionID, content string, attachmentParts []message.ContentPart) AgentEvent {
363	cfg := config.Get()
364	// List existing messages; if none, start title generation asynchronously.
365	msgs, err := a.messages.List(ctx, sessionID)
366	if err != nil {
367		return a.err(fmt.Errorf("failed to list messages: %w", err))
368	}
369	if len(msgs) == 0 {
370		go func() {
371			defer log.RecoverPanic("agent.Run", func() {
372				slog.Error("panic while generating title")
373			})
374			titleErr := a.generateTitle(context.Background(), sessionID, content)
375			if titleErr != nil && !errors.Is(titleErr, context.Canceled) && !errors.Is(titleErr, context.DeadlineExceeded) {
376				slog.Error("failed to generate title", "error", titleErr)
377			}
378		}()
379	}
380	session, err := a.sessions.Get(ctx, sessionID)
381	if err != nil {
382		return a.err(fmt.Errorf("failed to get session: %w", err))
383	}
384	if session.SummaryMessageID != "" {
385		summaryMsgInex := -1
386		for i, msg := range msgs {
387			if msg.ID == session.SummaryMessageID {
388				summaryMsgInex = i
389				break
390			}
391		}
392		if summaryMsgInex != -1 {
393			msgs = msgs[summaryMsgInex:]
394			msgs[0].Role = message.User
395		}
396	}
397
398	userMsg, err := a.createUserMessage(ctx, sessionID, content, attachmentParts)
399	if err != nil {
400		return a.err(fmt.Errorf("failed to create user message: %w", err))
401	}
402	// Append the new user message to the conversation history.
403	msgHistory := append(msgs, userMsg)
404
405	for {
406		// Check for cancellation before each iteration
407		select {
408		case <-ctx.Done():
409			return a.err(ctx.Err())
410		default:
411			// Continue processing
412		}
413		agentMessage, toolResults, err := a.streamAndHandleEvents(ctx, sessionID, msgHistory)
414		if err != nil {
415			if errors.Is(err, context.Canceled) {
416				agentMessage.AddFinish(message.FinishReasonCanceled, "Request cancelled", "")
417				a.messages.Update(context.Background(), agentMessage)
418				return a.err(ErrRequestCancelled)
419			}
420			return a.err(fmt.Errorf("failed to process events: %w", err))
421		}
422		if cfg.Options.Debug {
423			slog.Info("Result", "message", agentMessage.FinishReason(), "toolResults", toolResults)
424		}
425		if (agentMessage.FinishReason() == message.FinishReasonToolUse) && toolResults != nil {
426			// We are not done, we need to respond with the tool response
427			msgHistory = append(msgHistory, agentMessage, *toolResults)
428			continue
429		}
430		return AgentEvent{
431			Type:    AgentEventTypeResponse,
432			Message: agentMessage,
433			Done:    true,
434		}
435	}
436}
437
438func (a *agent) createUserMessage(ctx context.Context, sessionID, content string, attachmentParts []message.ContentPart) (message.Message, error) {
439	parts := []message.ContentPart{message.TextContent{Text: content}}
440	parts = append(parts, attachmentParts...)
441	return a.messages.Create(ctx, sessionID, message.CreateMessageParams{
442		Role:  message.User,
443		Parts: parts,
444	})
445}
446
447func (a *agent) streamAndHandleEvents(ctx context.Context, sessionID string, msgHistory []message.Message) (message.Message, *message.Message, error) {
448	ctx = context.WithValue(ctx, tools.SessionIDContextKey, sessionID)
449	eventChan := a.provider.StreamResponse(ctx, msgHistory, slices.Collect(a.tools.Seq()))
450
451	assistantMsg, err := a.messages.Create(ctx, sessionID, message.CreateMessageParams{
452		Role:     message.Assistant,
453		Parts:    []message.ContentPart{},
454		Model:    a.Model().ID,
455		Provider: a.providerID,
456	})
457	if err != nil {
458		return assistantMsg, nil, fmt.Errorf("failed to create assistant message: %w", err)
459	}
460
461	// Add the session and message ID into the context if needed by tools.
462	ctx = context.WithValue(ctx, tools.MessageIDContextKey, assistantMsg.ID)
463
464	// Process each event in the stream.
465	for event := range eventChan {
466		if processErr := a.processEvent(ctx, sessionID, &assistantMsg, event); processErr != nil {
467			if errors.Is(processErr, context.Canceled) {
468				a.finishMessage(context.Background(), &assistantMsg, message.FinishReasonCanceled, "Request cancelled", "")
469			} else {
470				a.finishMessage(ctx, &assistantMsg, message.FinishReasonError, "API Error", processErr.Error())
471			}
472			return assistantMsg, nil, processErr
473		}
474		if ctx.Err() != nil {
475			a.finishMessage(context.Background(), &assistantMsg, message.FinishReasonCanceled, "Request cancelled", "")
476			return assistantMsg, nil, ctx.Err()
477		}
478	}
479
480	toolResults := make([]message.ToolResult, len(assistantMsg.ToolCalls()))
481	toolCalls := assistantMsg.ToolCalls()
482	for i, toolCall := range toolCalls {
483		select {
484		case <-ctx.Done():
485			a.finishMessage(context.Background(), &assistantMsg, message.FinishReasonCanceled, "Request cancelled", "")
486			// Make all future tool calls cancelled
487			for j := i; j < len(toolCalls); j++ {
488				toolResults[j] = message.ToolResult{
489					ToolCallID: toolCalls[j].ID,
490					Content:    "Tool execution canceled by user",
491					IsError:    true,
492				}
493			}
494			goto out
495		default:
496			// Continue processing
497			var tool tools.BaseTool
498			for availableTool := range a.tools.Seq() {
499				if availableTool.Info().Name == toolCall.Name {
500					tool = availableTool
501					break
502				}
503			}
504
505			// Tool not found
506			if tool == nil {
507				toolResults[i] = message.ToolResult{
508					ToolCallID: toolCall.ID,
509					Content:    fmt.Sprintf("Tool not found: %s", toolCall.Name),
510					IsError:    true,
511				}
512				continue
513			}
514
515			// Run tool in goroutine to allow cancellation
516			type toolExecResult struct {
517				response tools.ToolResponse
518				err      error
519			}
520			resultChan := make(chan toolExecResult, 1)
521
522			go func() {
523				response, err := tool.Run(ctx, tools.ToolCall{
524					ID:    toolCall.ID,
525					Name:  toolCall.Name,
526					Input: toolCall.Input,
527				})
528				resultChan <- toolExecResult{response: response, err: err}
529			}()
530
531			var toolResponse tools.ToolResponse
532			var toolErr error
533
534			select {
535			case <-ctx.Done():
536				a.finishMessage(context.Background(), &assistantMsg, message.FinishReasonCanceled, "Request cancelled", "")
537				// Mark remaining tool calls as cancelled
538				for j := i; j < len(toolCalls); j++ {
539					toolResults[j] = message.ToolResult{
540						ToolCallID: toolCalls[j].ID,
541						Content:    "Tool execution canceled by user",
542						IsError:    true,
543					}
544				}
545				goto out
546			case result := <-resultChan:
547				toolResponse = result.response
548				toolErr = result.err
549			}
550
551			if toolErr != nil {
552				slog.Error("Tool execution error", "toolCall", toolCall.ID, "error", toolErr)
553				if errors.Is(toolErr, permission.ErrorPermissionDenied) {
554					toolResults[i] = message.ToolResult{
555						ToolCallID: toolCall.ID,
556						Content:    "Permission denied",
557						IsError:    true,
558					}
559					for j := i + 1; j < len(toolCalls); j++ {
560						toolResults[j] = message.ToolResult{
561							ToolCallID: toolCalls[j].ID,
562							Content:    "Tool execution canceled by user",
563							IsError:    true,
564						}
565					}
566					a.finishMessage(ctx, &assistantMsg, message.FinishReasonPermissionDenied, "Permission denied", "")
567					break
568				}
569			}
570			toolResults[i] = message.ToolResult{
571				ToolCallID: toolCall.ID,
572				Content:    toolResponse.Content,
573				Metadata:   toolResponse.Metadata,
574				IsError:    toolResponse.IsError,
575			}
576		}
577	}
578out:
579	if len(toolResults) == 0 {
580		return assistantMsg, nil, nil
581	}
582	parts := make([]message.ContentPart, 0)
583	for _, tr := range toolResults {
584		parts = append(parts, tr)
585	}
586	msg, err := a.messages.Create(context.Background(), assistantMsg.SessionID, message.CreateMessageParams{
587		Role:     message.Tool,
588		Parts:    parts,
589		Provider: a.providerID,
590	})
591	if err != nil {
592		return assistantMsg, nil, fmt.Errorf("failed to create cancelled tool message: %w", err)
593	}
594
595	return assistantMsg, &msg, err
596}
597
598func (a *agent) finishMessage(ctx context.Context, msg *message.Message, finishReason message.FinishReason, message, details string) {
599	msg.AddFinish(finishReason, message, details)
600	_ = a.messages.Update(ctx, *msg)
601}
602
603func (a *agent) processEvent(ctx context.Context, sessionID string, assistantMsg *message.Message, event provider.ProviderEvent) error {
604	select {
605	case <-ctx.Done():
606		return ctx.Err()
607	default:
608		// Continue processing.
609	}
610
611	switch event.Type {
612	case provider.EventThinkingDelta:
613		assistantMsg.AppendReasoningContent(event.Thinking)
614		return a.messages.Update(ctx, *assistantMsg)
615	case provider.EventSignatureDelta:
616		assistantMsg.AppendReasoningSignature(event.Signature)
617		return a.messages.Update(ctx, *assistantMsg)
618	case provider.EventContentDelta:
619		assistantMsg.FinishThinking()
620		assistantMsg.AppendContent(event.Content)
621		return a.messages.Update(ctx, *assistantMsg)
622	case provider.EventToolUseStart:
623		assistantMsg.FinishThinking()
624		slog.Info("Tool call started", "toolCall", event.ToolCall)
625		assistantMsg.AddToolCall(*event.ToolCall)
626		return a.messages.Update(ctx, *assistantMsg)
627	case provider.EventToolUseDelta:
628		assistantMsg.AppendToolCallInput(event.ToolCall.ID, event.ToolCall.Input)
629		return a.messages.Update(ctx, *assistantMsg)
630	case provider.EventToolUseStop:
631		slog.Info("Finished tool call", "toolCall", event.ToolCall)
632		assistantMsg.FinishToolCall(event.ToolCall.ID)
633		return a.messages.Update(ctx, *assistantMsg)
634	case provider.EventError:
635		return event.Error
636	case provider.EventComplete:
637		assistantMsg.FinishThinking()
638		assistantMsg.SetToolCalls(event.Response.ToolCalls)
639		assistantMsg.AddFinish(event.Response.FinishReason, "", "")
640		if err := a.messages.Update(ctx, *assistantMsg); err != nil {
641			return fmt.Errorf("failed to update message: %w", err)
642		}
643		return a.TrackUsage(ctx, sessionID, a.Model(), event.Response.Usage)
644	}
645
646	return nil
647}
648
649func (a *agent) TrackUsage(ctx context.Context, sessionID string, model catwalk.Model, usage provider.TokenUsage) error {
650	sess, err := a.sessions.Get(ctx, sessionID)
651	if err != nil {
652		return fmt.Errorf("failed to get session: %w", err)
653	}
654
655	cost := model.CostPer1MInCached/1e6*float64(usage.CacheCreationTokens) +
656		model.CostPer1MOutCached/1e6*float64(usage.CacheReadTokens) +
657		model.CostPer1MIn/1e6*float64(usage.InputTokens) +
658		model.CostPer1MOut/1e6*float64(usage.OutputTokens)
659
660	sess.Cost += cost
661	sess.CompletionTokens = usage.OutputTokens + usage.CacheReadTokens
662	sess.PromptTokens = usage.InputTokens + usage.CacheCreationTokens
663
664	_, err = a.sessions.Save(ctx, sess)
665	if err != nil {
666		return fmt.Errorf("failed to save session: %w", err)
667	}
668	return nil
669}
670
671func (a *agent) Summarize(ctx context.Context, sessionID string) error {
672	if a.summarizeProvider == nil {
673		return fmt.Errorf("summarize provider not available")
674	}
675
676	// Check if session is busy
677	if a.IsSessionBusy(sessionID) {
678		return ErrSessionBusy
679	}
680
681	// Create a new context with cancellation
682	summarizeCtx, cancel := context.WithCancel(ctx)
683
684	// Store the cancel function in activeRequests to allow cancellation
685	a.activeRequests.Store(sessionID+"-summarize", cancel)
686
687	go func() {
688		defer a.activeRequests.Delete(sessionID + "-summarize")
689		defer cancel()
690		event := AgentEvent{
691			Type:     AgentEventTypeSummarize,
692			Progress: "Starting summarization...",
693		}
694
695		a.Publish(pubsub.CreatedEvent, event)
696		// Get all messages from the session
697		msgs, err := a.messages.List(summarizeCtx, sessionID)
698		if err != nil {
699			event = AgentEvent{
700				Type:  AgentEventTypeError,
701				Error: fmt.Errorf("failed to list messages: %w", err),
702				Done:  true,
703			}
704			a.Publish(pubsub.CreatedEvent, event)
705			return
706		}
707		summarizeCtx = context.WithValue(summarizeCtx, tools.SessionIDContextKey, sessionID)
708
709		if len(msgs) == 0 {
710			event = AgentEvent{
711				Type:  AgentEventTypeError,
712				Error: fmt.Errorf("no messages to summarize"),
713				Done:  true,
714			}
715			a.Publish(pubsub.CreatedEvent, event)
716			return
717		}
718
719		event = AgentEvent{
720			Type:     AgentEventTypeSummarize,
721			Progress: "Analyzing conversation...",
722		}
723		a.Publish(pubsub.CreatedEvent, event)
724
725		// Add a system message to guide the summarization
726		summarizePrompt := "Provide a detailed but concise summary of our conversation above. Focus on information that would be helpful for continuing the conversation, including what we did, what we're doing, which files we're working on, and what we're going to do next."
727
728		// Create a new message with the summarize prompt
729		promptMsg := message.Message{
730			Role:  message.User,
731			Parts: []message.ContentPart{message.TextContent{Text: summarizePrompt}},
732		}
733
734		// Append the prompt to the messages
735		msgsWithPrompt := append(msgs, promptMsg)
736
737		event = AgentEvent{
738			Type:     AgentEventTypeSummarize,
739			Progress: "Generating summary...",
740		}
741
742		a.Publish(pubsub.CreatedEvent, event)
743
744		// Send the messages to the summarize provider
745		response := a.summarizeProvider.StreamResponse(
746			summarizeCtx,
747			msgsWithPrompt,
748			nil,
749		)
750		var finalResponse *provider.ProviderResponse
751		for r := range response {
752			if r.Error != nil {
753				event = AgentEvent{
754					Type:  AgentEventTypeError,
755					Error: fmt.Errorf("failed to summarize: %w", err),
756					Done:  true,
757				}
758				a.Publish(pubsub.CreatedEvent, event)
759				return
760			}
761			finalResponse = r.Response
762		}
763
764		summary := strings.TrimSpace(finalResponse.Content)
765		if summary == "" {
766			event = AgentEvent{
767				Type:  AgentEventTypeError,
768				Error: fmt.Errorf("empty summary returned"),
769				Done:  true,
770			}
771			a.Publish(pubsub.CreatedEvent, event)
772			return
773		}
774		shell := shell.GetPersistentShell(config.Get().WorkingDir())
775		summary += "\n\n**Current working directory of the persistent shell**\n\n" + shell.GetWorkingDir()
776		event = AgentEvent{
777			Type:     AgentEventTypeSummarize,
778			Progress: "Creating new session...",
779		}
780
781		a.Publish(pubsub.CreatedEvent, event)
782		oldSession, err := a.sessions.Get(summarizeCtx, sessionID)
783		if err != nil {
784			event = AgentEvent{
785				Type:  AgentEventTypeError,
786				Error: fmt.Errorf("failed to get session: %w", err),
787				Done:  true,
788			}
789
790			a.Publish(pubsub.CreatedEvent, event)
791			return
792		}
793		// Create a message in the new session with the summary
794		msg, err := a.messages.Create(summarizeCtx, oldSession.ID, message.CreateMessageParams{
795			Role: message.Assistant,
796			Parts: []message.ContentPart{
797				message.TextContent{Text: summary},
798				message.Finish{
799					Reason: message.FinishReasonEndTurn,
800					Time:   time.Now().Unix(),
801				},
802			},
803			Model:    a.summarizeProvider.Model().ID,
804			Provider: a.summarizeProviderID,
805		})
806		if err != nil {
807			event = AgentEvent{
808				Type:  AgentEventTypeError,
809				Error: fmt.Errorf("failed to create summary message: %w", err),
810				Done:  true,
811			}
812
813			a.Publish(pubsub.CreatedEvent, event)
814			return
815		}
816		oldSession.SummaryMessageID = msg.ID
817		oldSession.CompletionTokens = finalResponse.Usage.OutputTokens
818		oldSession.PromptTokens = 0
819		model := a.summarizeProvider.Model()
820		usage := finalResponse.Usage
821		cost := model.CostPer1MInCached/1e6*float64(usage.CacheCreationTokens) +
822			model.CostPer1MOutCached/1e6*float64(usage.CacheReadTokens) +
823			model.CostPer1MIn/1e6*float64(usage.InputTokens) +
824			model.CostPer1MOut/1e6*float64(usage.OutputTokens)
825		oldSession.Cost += cost
826		_, err = a.sessions.Save(summarizeCtx, oldSession)
827		if err != nil {
828			event = AgentEvent{
829				Type:  AgentEventTypeError,
830				Error: fmt.Errorf("failed to save session: %w", err),
831				Done:  true,
832			}
833			a.Publish(pubsub.CreatedEvent, event)
834		}
835
836		event = AgentEvent{
837			Type:      AgentEventTypeSummarize,
838			SessionID: oldSession.ID,
839			Progress:  "Summary complete",
840			Done:      true,
841		}
842		a.Publish(pubsub.CreatedEvent, event)
843		// Send final success event with the new session ID
844	}()
845
846	return nil
847}
848
849func (a *agent) CancelAll() {
850	if !a.IsBusy() {
851		return
852	}
853	a.activeRequests.Range(func(key, value any) bool {
854		a.Cancel(key.(string)) // key is sessionID
855		return true
856	})
857
858	timeout := time.After(5 * time.Second)
859	for a.IsBusy() {
860		select {
861		case <-timeout:
862			return
863		default:
864			time.Sleep(200 * time.Millisecond)
865		}
866	}
867}
868
869func (a *agent) UpdateModel() error {
870	cfg := config.Get()
871
872	// Get current provider configuration
873	currentProviderCfg := cfg.GetProviderForModel(a.agentCfg.Model)
874	if currentProviderCfg == nil || currentProviderCfg.ID == "" {
875		return fmt.Errorf("provider for agent %s not found in config", a.agentCfg.Name)
876	}
877
878	// Check if provider has changed
879	if string(currentProviderCfg.ID) != a.providerID {
880		// Provider changed, need to recreate the main provider
881		model := cfg.GetModelByType(a.agentCfg.Model)
882		if model.ID == "" {
883			return fmt.Errorf("model not found for agent %s", a.agentCfg.Name)
884		}
885
886		promptID := agentPromptMap[a.agentCfg.ID]
887		if promptID == "" {
888			promptID = prompt.PromptDefault
889		}
890
891		opts := []provider.ProviderClientOption{
892			provider.WithModel(a.agentCfg.Model),
893			provider.WithSystemMessage(prompt.GetPrompt(promptID, currentProviderCfg.ID, cfg.Options.ContextPaths...)),
894		}
895
896		newProvider, err := provider.NewProvider(*currentProviderCfg, opts...)
897		if err != nil {
898			return fmt.Errorf("failed to create new provider: %w", err)
899		}
900
901		// Update the provider and provider ID
902		a.provider = newProvider
903		a.providerID = string(currentProviderCfg.ID)
904	}
905
906	// Check if small model provider has changed (affects title and summarize providers)
907	smallModelCfg := cfg.Models[config.SelectedModelTypeSmall]
908	var smallModelProviderCfg config.ProviderConfig
909
910	for _, p := range cfg.Providers.Seq2() {
911		if p.ID == smallModelCfg.Provider {
912			smallModelProviderCfg = p
913			break
914		}
915	}
916
917	if smallModelProviderCfg.ID == "" {
918		return fmt.Errorf("provider %s not found in config", smallModelCfg.Provider)
919	}
920
921	// Check if summarize provider has changed
922	if string(smallModelProviderCfg.ID) != a.summarizeProviderID {
923		smallModel := cfg.GetModelByType(config.SelectedModelTypeSmall)
924		if smallModel == nil {
925			return fmt.Errorf("model %s not found in provider %s", smallModelCfg.Model, smallModelProviderCfg.ID)
926		}
927
928		// Recreate title provider
929		titleOpts := []provider.ProviderClientOption{
930			provider.WithModel(config.SelectedModelTypeSmall),
931			provider.WithSystemMessage(prompt.GetPrompt(prompt.PromptTitle, smallModelProviderCfg.ID)),
932			// We want the title to be short, so we limit the max tokens
933			provider.WithMaxTokens(40),
934		}
935		newTitleProvider, err := provider.NewProvider(smallModelProviderCfg, titleOpts...)
936		if err != nil {
937			return fmt.Errorf("failed to create new title provider: %w", err)
938		}
939
940		// Recreate summarize provider
941		summarizeOpts := []provider.ProviderClientOption{
942			provider.WithModel(config.SelectedModelTypeSmall),
943			provider.WithSystemMessage(prompt.GetPrompt(prompt.PromptSummarizer, smallModelProviderCfg.ID)),
944		}
945		newSummarizeProvider, err := provider.NewProvider(smallModelProviderCfg, summarizeOpts...)
946		if err != nil {
947			return fmt.Errorf("failed to create new summarize provider: %w", err)
948		}
949
950		// Update the providers and provider ID
951		a.titleProvider = newTitleProvider
952		a.summarizeProvider = newSummarizeProvider
953		a.summarizeProviderID = string(smallModelProviderCfg.ID)
954	}
955
956	return nil
957}