1package agent
   2
   3import (
   4	"context"
   5	"errors"
   6	"fmt"
   7	"log/slog"
   8	"slices"
   9	"strings"
  10	"time"
  11
  12	"github.com/charmbracelet/catwalk/pkg/catwalk"
  13	"github.com/charmbracelet/crush/internal/config"
  14	"github.com/charmbracelet/crush/internal/csync"
  15	"github.com/charmbracelet/crush/internal/history"
  16	"github.com/charmbracelet/crush/internal/llm/prompt"
  17	"github.com/charmbracelet/crush/internal/llm/provider"
  18	"github.com/charmbracelet/crush/internal/llm/tools"
  19	"github.com/charmbracelet/crush/internal/log"
  20	"github.com/charmbracelet/crush/internal/lsp"
  21	"github.com/charmbracelet/crush/internal/message"
  22	"github.com/charmbracelet/crush/internal/permission"
  23	"github.com/charmbracelet/crush/internal/pubsub"
  24	"github.com/charmbracelet/crush/internal/session"
  25	"github.com/charmbracelet/crush/internal/shell"
  26)
  27
  28// Common errors
  29var (
  30	ErrRequestCancelled = errors.New("request canceled by user")
  31	ErrSessionBusy      = errors.New("session is currently processing another request")
  32)
  33
  34type AgentEventType string
  35
  36const (
  37	AgentEventTypeError     AgentEventType = "error"
  38	AgentEventTypeResponse  AgentEventType = "response"
  39	AgentEventTypeSummarize AgentEventType = "summarize"
  40)
  41
// AgentEvent is the payload the agent publishes to subscribers and sends on
// the channel returned by Run. Type determines which fields are meaningful.
type AgentEvent struct {
	// Type is the kind of event (error, response or summarize).
	Type    AgentEventType
	// Message is the final assistant message for response events.
	Message message.Message
	// Error is set for error events.
	Error   error

	// When summarizing
	// SessionID is set on the final "Summary complete" event to the session
	// that received the summary message.
	SessionID string
	// Progress is a human-readable summarization status line.
	Progress  string
	// Done marks the last event of an operation (final response, terminal
	// summarize event, or a terminal summarize error).
	Done      bool
}
  52
// Service is the public interface of an agent. An agent runs prompts against
// its configured model, tracks per-session cancellation and prompt queues, and
// publishes AgentEvents to subscribers.
type Service interface {
	pubsub.Suscriber[AgentEvent]
	// Model returns the catwalk model configured for this agent.
	Model() catwalk.Model
	// Run processes content for sessionID. In the agent implementation, if the
	// session already has an active request the prompt is queued and Run
	// returns a nil channel and a nil error.
	Run(ctx context.Context, sessionID string, content string, attachments ...message.Attachment) (<-chan AgentEvent, error)
	// Cancel aborts the session's active request and summarization and clears
	// its prompt queue.
	Cancel(sessionID string)
	// CancelAll cancels every active request and waits (bounded) for them to
	// deregister.
	CancelAll()
	// IsSessionBusy reports whether sessionID has an active generation request.
	IsSessionBusy(sessionID string) bool
	// IsBusy reports whether any request is active.
	IsBusy() bool
	// Summarize starts summarizing the session in the background; progress
	// and the outcome are published as AgentEvents.
	Summarize(ctx context.Context, sessionID string) error
	// UpdateModel rebuilds providers after the model configuration changed.
	UpdateModel() error
	// QueuedPrompts returns how many prompts are waiting for sessionID.
	QueuedPrompts(sessionID string) int
	// ClearQueue discards the prompts waiting for sessionID.
	ClearQueue(sessionID string)
}
  66
// agent is the Service implementation. It embeds a pubsub broker for
// publishing AgentEvents.
type agent struct {
	*pubsub.Broker[AgentEvent]
	agentCfg config.Agent
	sessions session.Service
	messages message.Service
	// NOTE(review): not referenced in this file; NewAgent's tool builder uses a
	// package-level mcpTools instead — confirm whether this field is still needed.
	mcpTools []McpTool

	// tools is built lazily from the tool factory set up in NewAgent.
	tools *csync.LazySlice[tools.BaseTool]

	// provider serves the agent's main model; providerID identifies its config.
	provider   provider.Provider
	providerID string

	// titleProvider uses the small model to name new sessions;
	// summarizeProvider uses the large model to compact sessions.
	titleProvider       provider.Provider
	summarizeProvider   provider.Provider
	summarizeProviderID string

	// activeRequests maps a session ID (or "<sessionID>-summarize" for a
	// running summarization) to the cancel func of its in-flight request.
	activeRequests *csync.Map[string, context.CancelFunc]

	// promptQueue holds prompts submitted while a session was busy.
	promptQueue *csync.Map[string, []string]
}
  87
  88var agentPromptMap = map[string]prompt.PromptID{
  89	"coder": prompt.PromptCoder,
  90	"task":  prompt.PromptTask,
  91}
  92
  93func NewAgent(
  94	ctx context.Context,
  95	agentCfg config.Agent,
  96	// These services are needed in the tools
  97	permissions permission.Service,
  98	sessions session.Service,
  99	messages message.Service,
 100	history history.Service,
 101	lspClients map[string]*lsp.Client,
 102) (Service, error) {
 103	cfg := config.Get()
 104
 105	var agentTool tools.BaseTool
 106	if agentCfg.ID == "coder" {
 107		taskAgentCfg := config.Get().Agents["task"]
 108		if taskAgentCfg.ID == "" {
 109			return nil, fmt.Errorf("task agent not found in config")
 110		}
 111		taskAgent, err := NewAgent(ctx, taskAgentCfg, permissions, sessions, messages, history, lspClients)
 112		if err != nil {
 113			return nil, fmt.Errorf("failed to create task agent: %w", err)
 114		}
 115
 116		agentTool = NewAgentTool(taskAgent, sessions, messages)
 117	}
 118
 119	providerCfg := config.Get().GetProviderForModel(agentCfg.Model)
 120	if providerCfg == nil {
 121		return nil, fmt.Errorf("provider for agent %s not found in config", agentCfg.Name)
 122	}
 123	model := config.Get().GetModelByType(agentCfg.Model)
 124
 125	if model == nil {
 126		return nil, fmt.Errorf("model not found for agent %s", agentCfg.Name)
 127	}
 128
 129	promptID := agentPromptMap[agentCfg.ID]
 130	if promptID == "" {
 131		promptID = prompt.PromptDefault
 132	}
 133	opts := []provider.ProviderClientOption{
 134		provider.WithModel(agentCfg.Model),
 135		provider.WithSystemMessage(prompt.GetPrompt(promptID, providerCfg.ID, config.Get().Options.ContextPaths...)),
 136	}
 137	agentProvider, err := provider.NewProvider(*providerCfg, opts...)
 138	if err != nil {
 139		return nil, err
 140	}
 141
 142	smallModelCfg := cfg.Models[config.SelectedModelTypeSmall]
 143	var smallModelProviderCfg *config.ProviderConfig
 144	if smallModelCfg.Provider == providerCfg.ID {
 145		smallModelProviderCfg = providerCfg
 146	} else {
 147		smallModelProviderCfg = cfg.GetProviderForModel(config.SelectedModelTypeSmall)
 148
 149		if smallModelProviderCfg.ID == "" {
 150			return nil, fmt.Errorf("provider %s not found in config", smallModelCfg.Provider)
 151		}
 152	}
 153	smallModel := cfg.GetModelByType(config.SelectedModelTypeSmall)
 154	if smallModel.ID == "" {
 155		return nil, fmt.Errorf("model %s not found in provider %s", smallModelCfg.Model, smallModelProviderCfg.ID)
 156	}
 157
 158	titleOpts := []provider.ProviderClientOption{
 159		provider.WithModel(config.SelectedModelTypeSmall),
 160		provider.WithSystemMessage(prompt.GetPrompt(prompt.PromptTitle, smallModelProviderCfg.ID)),
 161	}
 162	titleProvider, err := provider.NewProvider(*smallModelProviderCfg, titleOpts...)
 163	if err != nil {
 164		return nil, err
 165	}
 166
 167	summarizeOpts := []provider.ProviderClientOption{
 168		provider.WithModel(config.SelectedModelTypeLarge),
 169		provider.WithSystemMessage(prompt.GetPrompt(prompt.PromptSummarizer, providerCfg.ID)),
 170	}
 171	summarizeProvider, err := provider.NewProvider(*providerCfg, summarizeOpts...)
 172	if err != nil {
 173		return nil, err
 174	}
 175
 176	toolFn := func() []tools.BaseTool {
 177		slog.Info("Initializing agent tools", "agent", agentCfg.ID)
 178		defer func() {
 179			slog.Info("Initialized agent tools", "agent", agentCfg.ID)
 180		}()
 181
 182		cwd := cfg.WorkingDir()
 183		allTools := []tools.BaseTool{
 184			tools.NewBashTool(permissions, cwd),
 185			tools.NewDownloadTool(permissions, cwd),
 186			tools.NewEditTool(lspClients, permissions, history, cwd),
 187			tools.NewMultiEditTool(lspClients, permissions, history, cwd),
 188			tools.NewFetchTool(permissions, cwd),
 189			tools.NewGlobTool(cwd),
 190			tools.NewGrepTool(cwd),
 191			tools.NewLsTool(permissions, cwd),
 192			tools.NewSourcegraphTool(),
 193			tools.NewViewTool(lspClients, permissions, cwd),
 194			tools.NewWriteTool(lspClients, permissions, history, cwd),
 195		}
 196
 197		mcpToolsOnce.Do(func() {
 198			mcpTools = doGetMCPTools(ctx, permissions, cfg)
 199		})
 200		allTools = append(allTools, mcpTools...)
 201
 202		if len(lspClients) > 0 {
 203			allTools = append(allTools, tools.NewDiagnosticsTool(lspClients))
 204		}
 205
 206		if agentTool != nil {
 207			allTools = append(allTools, agentTool)
 208		}
 209
 210		if agentCfg.AllowedTools == nil {
 211			return allTools
 212		}
 213
 214		var filteredTools []tools.BaseTool
 215		for _, tool := range allTools {
 216			if slices.Contains(agentCfg.AllowedTools, tool.Name()) {
 217				filteredTools = append(filteredTools, tool)
 218			}
 219		}
 220		return filteredTools
 221	}
 222
 223	return &agent{
 224		Broker:              pubsub.NewBroker[AgentEvent](),
 225		agentCfg:            agentCfg,
 226		provider:            agentProvider,
 227		providerID:          string(providerCfg.ID),
 228		messages:            messages,
 229		sessions:            sessions,
 230		titleProvider:       titleProvider,
 231		summarizeProvider:   summarizeProvider,
 232		summarizeProviderID: string(providerCfg.ID),
 233		activeRequests:      csync.NewMap[string, context.CancelFunc](),
 234		tools:               csync.NewLazySlice(toolFn),
 235		promptQueue:         csync.NewMap[string, []string](),
 236	}, nil
 237}
 238
 239func (a *agent) Model() catwalk.Model {
 240	return *config.Get().GetModelByType(a.agentCfg.Model)
 241}
 242
 243func (a *agent) Cancel(sessionID string) {
 244	// Cancel regular requests
 245	if cancel, ok := a.activeRequests.Take(sessionID); ok && cancel != nil {
 246		slog.Info("Request cancellation initiated", "session_id", sessionID)
 247		cancel()
 248	}
 249
 250	// Also check for summarize requests
 251	if cancel, ok := a.activeRequests.Take(sessionID + "-summarize"); ok && cancel != nil {
 252		slog.Info("Summarize cancellation initiated", "session_id", sessionID)
 253		cancel()
 254	}
 255
 256	if a.QueuedPrompts(sessionID) > 0 {
 257		slog.Info("Clearing queued prompts", "session_id", sessionID)
 258		a.promptQueue.Del(sessionID)
 259	}
 260}
 261
 262func (a *agent) IsBusy() bool {
 263	var busy bool
 264	for cancelFunc := range a.activeRequests.Seq() {
 265		if cancelFunc != nil {
 266			busy = true
 267			break
 268		}
 269	}
 270	return busy
 271}
 272
 273func (a *agent) IsSessionBusy(sessionID string) bool {
 274	_, busy := a.activeRequests.Get(sessionID)
 275	return busy
 276}
 277
 278func (a *agent) QueuedPrompts(sessionID string) int {
 279	l, ok := a.promptQueue.Get(sessionID)
 280	if !ok {
 281		return 0
 282	}
 283	return len(l)
 284}
 285
 286func (a *agent) generateTitle(ctx context.Context, sessionID string, content string) error {
 287	if content == "" {
 288		return nil
 289	}
 290	if a.titleProvider == nil {
 291		return nil
 292	}
 293	session, err := a.sessions.Get(ctx, sessionID)
 294	if err != nil {
 295		return err
 296	}
 297	parts := []message.ContentPart{message.TextContent{
 298		Text: fmt.Sprintf("Generate a concise title for the following content:\n\n%s", content),
 299	}}
 300
 301	// Use streaming approach like summarization
 302	response := a.titleProvider.StreamResponse(
 303		ctx,
 304		[]message.Message{
 305			{
 306				Role:  message.User,
 307				Parts: parts,
 308			},
 309		},
 310		nil,
 311	)
 312
 313	var finalResponse *provider.ProviderResponse
 314	for r := range response {
 315		if r.Error != nil {
 316			return r.Error
 317		}
 318		finalResponse = r.Response
 319	}
 320
 321	if finalResponse == nil {
 322		return fmt.Errorf("no response received from title provider")
 323	}
 324
 325	title := strings.TrimSpace(strings.ReplaceAll(finalResponse.Content, "\n", " "))
 326	if title == "" {
 327		return nil
 328	}
 329
 330	session.Title = title
 331	_, err = a.sessions.Save(ctx, session)
 332	return err
 333}
 334
 335func (a *agent) err(err error) AgentEvent {
 336	return AgentEvent{
 337		Type:  AgentEventTypeError,
 338		Error: err,
 339	}
 340}
 341
 342func (a *agent) Run(ctx context.Context, sessionID string, content string, attachments ...message.Attachment) (<-chan AgentEvent, error) {
 343	if !a.Model().SupportsImages && attachments != nil {
 344		attachments = nil
 345	}
 346	events := make(chan AgentEvent)
 347	if a.IsSessionBusy(sessionID) {
 348		existing, ok := a.promptQueue.Get(sessionID)
 349		if !ok {
 350			existing = []string{}
 351		}
 352		existing = append(existing, content)
 353		a.promptQueue.Set(sessionID, existing)
 354		return nil, nil
 355	}
 356
 357	genCtx, cancel := context.WithCancel(ctx)
 358	a.activeRequests.Set(sessionID, cancel)
 359
 360	go func() {
 361		slog.Debug("Request started", "sessionID", sessionID)
 362		defer log.RecoverPanic("agent.Run", func() {
 363			events <- a.err(fmt.Errorf("panic while running the agent"))
 364		})
 365		var attachmentParts []message.ContentPart
 366		for _, attachment := range attachments {
 367			attachmentParts = append(attachmentParts, message.BinaryContent{Path: attachment.FilePath, MIMEType: attachment.MimeType, Data: attachment.Content})
 368		}
 369		result := a.processGeneration(genCtx, sessionID, content, attachmentParts)
 370		if result.Error != nil && !errors.Is(result.Error, ErrRequestCancelled) && !errors.Is(result.Error, context.Canceled) {
 371			slog.Error(result.Error.Error())
 372		}
 373		slog.Debug("Request completed", "sessionID", sessionID)
 374		a.activeRequests.Del(sessionID)
 375		cancel()
 376		a.Publish(pubsub.CreatedEvent, result)
 377		events <- result
 378		close(events)
 379	}()
 380	return events, nil
 381}
 382
 383func (a *agent) processGeneration(ctx context.Context, sessionID, content string, attachmentParts []message.ContentPart) AgentEvent {
 384	cfg := config.Get()
 385	// List existing messages; if none, start title generation asynchronously.
 386	msgs, err := a.messages.List(ctx, sessionID)
 387	if err != nil {
 388		return a.err(fmt.Errorf("failed to list messages: %w", err))
 389	}
 390
 391	// sliding window to limit message history
 392	maxMessagesInContext := cfg.Options.MaxMessages
 393	if maxMessagesInContext > 0 && len(msgs) > maxMessagesInContext {
 394		// Keep the first message (usually system/context) and the last N-1 messages
 395		msgs = append(msgs[:1], msgs[len(msgs)-maxMessagesInContext+1:]...)
 396	}
 397
 398	if len(msgs) == 0 {
 399		// Use a context with timeout for title generation
 400		titleCtx, titleCancel := context.WithTimeout(context.Background(), 30*time.Second)
 401		go func() {
 402			defer titleCancel()
 403			defer log.RecoverPanic("agent.Run", func() {
 404				slog.Error("panic while generating title")
 405			})
 406			titleErr := a.generateTitle(titleCtx, sessionID, content)
 407			if titleErr != nil && !errors.Is(titleErr, context.Canceled) && !errors.Is(titleErr, context.DeadlineExceeded) {
 408				slog.Error("failed to generate title", "error", titleErr)
 409			}
 410		}()
 411	}
 412	session, err := a.sessions.Get(ctx, sessionID)
 413	if err != nil {
 414		return a.err(fmt.Errorf("failed to get session: %w", err))
 415	}
 416	if session.SummaryMessageID != "" {
 417		summaryMsgInex := -1
 418		for i, msg := range msgs {
 419			if msg.ID == session.SummaryMessageID {
 420				summaryMsgInex = i
 421				break
 422			}
 423		}
 424		if summaryMsgInex != -1 {
 425			msgs = msgs[summaryMsgInex:]
 426			msgs[0].Role = message.User
 427		}
 428	}
 429
 430	userMsg, err := a.createUserMessage(ctx, sessionID, content, attachmentParts)
 431	if err != nil {
 432		return a.err(fmt.Errorf("failed to create user message: %w", err))
 433	}
 434	// Append the new user message to the conversation history.
 435	msgHistory := append(msgs, userMsg)
 436
 437	for {
 438		// Check for cancellation before each iteration
 439		select {
 440		case <-ctx.Done():
 441			return a.err(ctx.Err())
 442		default:
 443			// Continue processing
 444		}
 445		agentMessage, toolResults, err := a.streamAndHandleEvents(ctx, sessionID, msgHistory)
 446		if err != nil {
 447			if errors.Is(err, context.Canceled) {
 448				agentMessage.AddFinish(message.FinishReasonCanceled, "Request cancelled", "")
 449				a.messages.Update(context.Background(), agentMessage)
 450				return a.err(ErrRequestCancelled)
 451			}
 452			return a.err(fmt.Errorf("failed to process events: %w", err))
 453		}
 454		if cfg.Options.Debug {
 455			slog.Info("Result", "message", agentMessage.FinishReason(), "toolResults", toolResults)
 456		}
 457		if (agentMessage.FinishReason() == message.FinishReasonToolUse) && toolResults != nil {
 458			// We are not done, we need to respond with the tool response
 459			msgHistory = append(msgHistory, agentMessage, *toolResults)
 460			// If there are queued prompts, process the next one
 461			nextPrompt, ok := a.promptQueue.Take(sessionID)
 462			if ok {
 463				for _, prompt := range nextPrompt {
 464					// Create a new user message for the queued prompt
 465					userMsg, err := a.createUserMessage(ctx, sessionID, prompt, nil)
 466					if err != nil {
 467						return a.err(fmt.Errorf("failed to create user message for queued prompt: %w", err))
 468					}
 469					// Append the new user message to the conversation history
 470					msgHistory = append(msgHistory, userMsg)
 471				}
 472			}
 473
 474			continue
 475		} else if agentMessage.FinishReason() == message.FinishReasonEndTurn {
 476			queuePrompts, ok := a.promptQueue.Take(sessionID)
 477			if ok {
 478				for _, prompt := range queuePrompts {
 479					if prompt == "" {
 480						continue
 481					}
 482					userMsg, err := a.createUserMessage(ctx, sessionID, prompt, nil)
 483					if err != nil {
 484						return a.err(fmt.Errorf("failed to create user message for queued prompt: %w", err))
 485					}
 486					msgHistory = append(msgHistory, userMsg)
 487				}
 488				continue
 489			}
 490		}
 491		if agentMessage.FinishReason() == "" {
 492			// Kujtim: could not track down where this is happening but this means its cancelled
 493			agentMessage.AddFinish(message.FinishReasonCanceled, "Request cancelled", "")
 494			_ = a.messages.Update(context.Background(), agentMessage)
 495			return a.err(ErrRequestCancelled)
 496		}
 497		return AgentEvent{
 498			Type:    AgentEventTypeResponse,
 499			Message: agentMessage,
 500			Done:    true,
 501		}
 502	}
 503}
 504
 505func (a *agent) createUserMessage(ctx context.Context, sessionID, content string, attachmentParts []message.ContentPart) (message.Message, error) {
 506	parts := []message.ContentPart{message.TextContent{Text: content}}
 507	parts = append(parts, attachmentParts...)
 508	return a.messages.Create(ctx, sessionID, message.CreateMessageParams{
 509		Role:  message.User,
 510		Parts: parts,
 511	})
 512}
 513
 514func (a *agent) streamAndHandleEvents(ctx context.Context, sessionID string, msgHistory []message.Message) (message.Message, *message.Message, error) {
 515	ctx = context.WithValue(ctx, tools.SessionIDContextKey, sessionID)
 516
 517	// Create the assistant message first so the spinner shows immediately
 518	assistantMsg, err := a.messages.Create(ctx, sessionID, message.CreateMessageParams{
 519		Role:     message.Assistant,
 520		Parts:    []message.ContentPart{},
 521		Model:    a.Model().ID,
 522		Provider: a.providerID,
 523	})
 524	if err != nil {
 525		return assistantMsg, nil, fmt.Errorf("failed to create assistant message: %w", err)
 526	}
 527
 528	// Now collect tools (which may block on MCP initialization)
 529	eventChan := a.provider.StreamResponse(ctx, msgHistory, slices.Collect(a.tools.Seq()))
 530
 531	// Add the session and message ID into the context if needed by tools.
 532	ctx = context.WithValue(ctx, tools.MessageIDContextKey, assistantMsg.ID)
 533
 534	// Process each event in the stream.
 535	for event := range eventChan {
 536		if processErr := a.processEvent(ctx, sessionID, &assistantMsg, event); processErr != nil {
 537			if errors.Is(processErr, context.Canceled) {
 538				a.finishMessage(context.Background(), &assistantMsg, message.FinishReasonCanceled, "Request cancelled", "")
 539			} else {
 540				a.finishMessage(ctx, &assistantMsg, message.FinishReasonError, "API Error", processErr.Error())
 541			}
 542			return assistantMsg, nil, processErr
 543		}
 544		if ctx.Err() != nil {
 545			a.finishMessage(context.Background(), &assistantMsg, message.FinishReasonCanceled, "Request cancelled", "")
 546			return assistantMsg, nil, ctx.Err()
 547		}
 548	}
 549
 550	toolResults := make([]message.ToolResult, len(assistantMsg.ToolCalls()))
 551	toolCalls := assistantMsg.ToolCalls()
 552	for i, toolCall := range toolCalls {
 553		select {
 554		case <-ctx.Done():
 555			a.finishMessage(context.Background(), &assistantMsg, message.FinishReasonCanceled, "Request cancelled", "")
 556			// Make all future tool calls cancelled
 557			for j := i; j < len(toolCalls); j++ {
 558				toolResults[j] = message.ToolResult{
 559					ToolCallID: toolCalls[j].ID,
 560					Content:    "Tool execution canceled by user",
 561					IsError:    true,
 562				}
 563			}
 564			goto out
 565		default:
 566			// Continue processing
 567			var tool tools.BaseTool
 568			for availableTool := range a.tools.Seq() {
 569				if availableTool.Info().Name == toolCall.Name {
 570					tool = availableTool
 571					break
 572				}
 573			}
 574
 575			// Tool not found
 576			if tool == nil {
 577				toolResults[i] = message.ToolResult{
 578					ToolCallID: toolCall.ID,
 579					Content:    fmt.Sprintf("Tool not found: %s", toolCall.Name),
 580					IsError:    true,
 581				}
 582				continue
 583			}
 584
 585			// Run tool in goroutine to allow cancellation
 586			type toolExecResult struct {
 587				response tools.ToolResponse
 588				err      error
 589			}
 590			resultChan := make(chan toolExecResult, 1)
 591
 592			go func() {
 593				response, err := tool.Run(ctx, tools.ToolCall{
 594					ID:    toolCall.ID,
 595					Name:  toolCall.Name,
 596					Input: toolCall.Input,
 597				})
 598				resultChan <- toolExecResult{response: response, err: err}
 599			}()
 600
 601			var toolResponse tools.ToolResponse
 602			var toolErr error
 603
 604			select {
 605			case <-ctx.Done():
 606				a.finishMessage(context.Background(), &assistantMsg, message.FinishReasonCanceled, "Request cancelled", "")
 607				// Mark remaining tool calls as cancelled
 608				for j := i; j < len(toolCalls); j++ {
 609					toolResults[j] = message.ToolResult{
 610						ToolCallID: toolCalls[j].ID,
 611						Content:    "Tool execution canceled by user",
 612						IsError:    true,
 613					}
 614				}
 615				goto out
 616			case result := <-resultChan:
 617				toolResponse = result.response
 618				toolErr = result.err
 619			}
 620
 621			if toolErr != nil {
 622				slog.Error("Tool execution error", "toolCall", toolCall.ID, "error", toolErr)
 623				if errors.Is(toolErr, permission.ErrorPermissionDenied) {
 624					toolResults[i] = message.ToolResult{
 625						ToolCallID: toolCall.ID,
 626						Content:    "Permission denied",
 627						IsError:    true,
 628					}
 629					for j := i + 1; j < len(toolCalls); j++ {
 630						toolResults[j] = message.ToolResult{
 631							ToolCallID: toolCalls[j].ID,
 632							Content:    "Tool execution canceled by user",
 633							IsError:    true,
 634						}
 635					}
 636					a.finishMessage(ctx, &assistantMsg, message.FinishReasonPermissionDenied, "Permission denied", "")
 637					break
 638				}
 639			}
 640			toolResults[i] = message.ToolResult{
 641				ToolCallID: toolCall.ID,
 642				Content:    toolResponse.Content,
 643				Metadata:   toolResponse.Metadata,
 644				IsError:    toolResponse.IsError,
 645			}
 646		}
 647	}
 648out:
 649	if len(toolResults) == 0 {
 650		return assistantMsg, nil, nil
 651	}
 652	parts := make([]message.ContentPart, 0)
 653	for _, tr := range toolResults {
 654		parts = append(parts, tr)
 655	}
 656	msg, err := a.messages.Create(context.Background(), assistantMsg.SessionID, message.CreateMessageParams{
 657		Role:     message.Tool,
 658		Parts:    parts,
 659		Provider: a.providerID,
 660	})
 661	if err != nil {
 662		return assistantMsg, nil, fmt.Errorf("failed to create cancelled tool message: %w", err)
 663	}
 664
 665	return assistantMsg, &msg, err
 666}
 667
 668func (a *agent) finishMessage(ctx context.Context, msg *message.Message, finishReason message.FinishReason, message, details string) {
 669	msg.AddFinish(finishReason, message, details)
 670	_ = a.messages.Update(ctx, *msg)
 671}
 672
 673func (a *agent) processEvent(ctx context.Context, sessionID string, assistantMsg *message.Message, event provider.ProviderEvent) error {
 674	select {
 675	case <-ctx.Done():
 676		return ctx.Err()
 677	default:
 678		// Continue processing.
 679	}
 680
 681	switch event.Type {
 682	case provider.EventThinkingDelta:
 683		assistantMsg.AppendReasoningContent(event.Thinking)
 684		return a.messages.Update(ctx, *assistantMsg)
 685	case provider.EventSignatureDelta:
 686		assistantMsg.AppendReasoningSignature(event.Signature)
 687		return a.messages.Update(ctx, *assistantMsg)
 688	case provider.EventContentDelta:
 689		assistantMsg.FinishThinking()
 690		assistantMsg.AppendContent(event.Content)
 691		return a.messages.Update(ctx, *assistantMsg)
 692	case provider.EventToolUseStart:
 693		assistantMsg.FinishThinking()
 694		slog.Info("Tool call started", "toolCall", event.ToolCall)
 695		assistantMsg.AddToolCall(*event.ToolCall)
 696		return a.messages.Update(ctx, *assistantMsg)
 697	case provider.EventToolUseDelta:
 698		assistantMsg.AppendToolCallInput(event.ToolCall.ID, event.ToolCall.Input)
 699		return a.messages.Update(ctx, *assistantMsg)
 700	case provider.EventToolUseStop:
 701		slog.Info("Finished tool call", "toolCall", event.ToolCall)
 702		assistantMsg.FinishToolCall(event.ToolCall.ID)
 703		return a.messages.Update(ctx, *assistantMsg)
 704	case provider.EventError:
 705		return event.Error
 706	case provider.EventComplete:
 707		assistantMsg.FinishThinking()
 708		assistantMsg.SetToolCalls(event.Response.ToolCalls)
 709		assistantMsg.AddFinish(event.Response.FinishReason, "", "")
 710		if err := a.messages.Update(ctx, *assistantMsg); err != nil {
 711			return fmt.Errorf("failed to update message: %w", err)
 712		}
 713		return a.TrackUsage(ctx, sessionID, a.Model(), event.Response.Usage)
 714	}
 715
 716	return nil
 717}
 718
 719func (a *agent) TrackUsage(ctx context.Context, sessionID string, model catwalk.Model, usage provider.TokenUsage) error {
 720	sess, err := a.sessions.Get(ctx, sessionID)
 721	if err != nil {
 722		return fmt.Errorf("failed to get session: %w", err)
 723	}
 724
 725	cost := model.CostPer1MInCached/1e6*float64(usage.CacheCreationTokens) +
 726		model.CostPer1MOutCached/1e6*float64(usage.CacheReadTokens) +
 727		model.CostPer1MIn/1e6*float64(usage.InputTokens) +
 728		model.CostPer1MOut/1e6*float64(usage.OutputTokens)
 729
 730	sess.Cost += cost
 731	sess.CompletionTokens = usage.OutputTokens + usage.CacheReadTokens
 732	sess.PromptTokens = usage.InputTokens + usage.CacheCreationTokens
 733
 734	_, err = a.sessions.Save(ctx, sess)
 735	if err != nil {
 736		return fmt.Errorf("failed to save session: %w", err)
 737	}
 738	return nil
 739}
 740
 741func (a *agent) Summarize(ctx context.Context, sessionID string) error {
 742	if a.summarizeProvider == nil {
 743		return fmt.Errorf("summarize provider not available")
 744	}
 745
 746	// Check if session is busy
 747	if a.IsSessionBusy(sessionID) {
 748		return ErrSessionBusy
 749	}
 750
 751	// Create a new context with cancellation
 752	summarizeCtx, cancel := context.WithCancel(ctx)
 753
 754	// Store the cancel function in activeRequests to allow cancellation
 755	a.activeRequests.Set(sessionID+"-summarize", cancel)
 756
 757	go func() {
 758		defer a.activeRequests.Del(sessionID + "-summarize")
 759		defer cancel()
 760		event := AgentEvent{
 761			Type:     AgentEventTypeSummarize,
 762			Progress: "Starting summarization...",
 763		}
 764
 765		a.Publish(pubsub.CreatedEvent, event)
 766		// Get all messages from the session
 767		msgs, err := a.messages.List(summarizeCtx, sessionID)
 768		if err != nil {
 769			event = AgentEvent{
 770				Type:  AgentEventTypeError,
 771				Error: fmt.Errorf("failed to list messages: %w", err),
 772				Done:  true,
 773			}
 774			a.Publish(pubsub.CreatedEvent, event)
 775			return
 776		}
 777		summarizeCtx = context.WithValue(summarizeCtx, tools.SessionIDContextKey, sessionID)
 778
 779		if len(msgs) == 0 {
 780			event = AgentEvent{
 781				Type:  AgentEventTypeError,
 782				Error: fmt.Errorf("no messages to summarize"),
 783				Done:  true,
 784			}
 785			a.Publish(pubsub.CreatedEvent, event)
 786			return
 787		}
 788
 789		event = AgentEvent{
 790			Type:     AgentEventTypeSummarize,
 791			Progress: "Analyzing conversation...",
 792		}
 793		a.Publish(pubsub.CreatedEvent, event)
 794
 795		// Add a system message to guide the summarization
 796		summarizePrompt := "Provide a detailed but concise summary of our conversation above. Focus on information that would be helpful for continuing the conversation, including what we did, what we're doing, which files we're working on, and what we're going to do next."
 797
 798		// Create a new message with the summarize prompt
 799		promptMsg := message.Message{
 800			Role:  message.User,
 801			Parts: []message.ContentPart{message.TextContent{Text: summarizePrompt}},
 802		}
 803
 804		// Append the prompt to the messages
 805		msgsWithPrompt := append(msgs, promptMsg)
 806
 807		event = AgentEvent{
 808			Type:     AgentEventTypeSummarize,
 809			Progress: "Generating summary...",
 810		}
 811
 812		a.Publish(pubsub.CreatedEvent, event)
 813
 814		// Send the messages to the summarize provider
 815		response := a.summarizeProvider.StreamResponse(
 816			summarizeCtx,
 817			msgsWithPrompt,
 818			nil,
 819		)
 820		var finalResponse *provider.ProviderResponse
 821		for r := range response {
 822			if r.Error != nil {
 823				event = AgentEvent{
 824					Type:  AgentEventTypeError,
 825					Error: fmt.Errorf("failed to summarize: %w", err),
 826					Done:  true,
 827				}
 828				a.Publish(pubsub.CreatedEvent, event)
 829				return
 830			}
 831			finalResponse = r.Response
 832		}
 833
 834		summary := strings.TrimSpace(finalResponse.Content)
 835		if summary == "" {
 836			event = AgentEvent{
 837				Type:  AgentEventTypeError,
 838				Error: fmt.Errorf("empty summary returned"),
 839				Done:  true,
 840			}
 841			a.Publish(pubsub.CreatedEvent, event)
 842			return
 843		}
 844		shell := shell.GetPersistentShell(config.Get().WorkingDir())
 845		summary += "\n\n**Current working directory of the persistent shell**\n\n" + shell.GetWorkingDir()
 846		event = AgentEvent{
 847			Type:     AgentEventTypeSummarize,
 848			Progress: "Creating new session...",
 849		}
 850
 851		a.Publish(pubsub.CreatedEvent, event)
 852		oldSession, err := a.sessions.Get(summarizeCtx, sessionID)
 853		if err != nil {
 854			event = AgentEvent{
 855				Type:  AgentEventTypeError,
 856				Error: fmt.Errorf("failed to get session: %w", err),
 857				Done:  true,
 858			}
 859
 860			a.Publish(pubsub.CreatedEvent, event)
 861			return
 862		}
 863		// Create a message in the new session with the summary
 864		msg, err := a.messages.Create(summarizeCtx, oldSession.ID, message.CreateMessageParams{
 865			Role: message.Assistant,
 866			Parts: []message.ContentPart{
 867				message.TextContent{Text: summary},
 868				message.Finish{
 869					Reason: message.FinishReasonEndTurn,
 870					Time:   time.Now().Unix(),
 871				},
 872			},
 873			Model:    a.summarizeProvider.Model().ID,
 874			Provider: a.summarizeProviderID,
 875		})
 876		if err != nil {
 877			event = AgentEvent{
 878				Type:  AgentEventTypeError,
 879				Error: fmt.Errorf("failed to create summary message: %w", err),
 880				Done:  true,
 881			}
 882
 883			a.Publish(pubsub.CreatedEvent, event)
 884			return
 885		}
 886		oldSession.SummaryMessageID = msg.ID
 887		oldSession.CompletionTokens = finalResponse.Usage.OutputTokens
 888		oldSession.PromptTokens = 0
 889		model := a.summarizeProvider.Model()
 890		usage := finalResponse.Usage
 891		cost := model.CostPer1MInCached/1e6*float64(usage.CacheCreationTokens) +
 892			model.CostPer1MOutCached/1e6*float64(usage.CacheReadTokens) +
 893			model.CostPer1MIn/1e6*float64(usage.InputTokens) +
 894			model.CostPer1MOut/1e6*float64(usage.OutputTokens)
 895		oldSession.Cost += cost
 896		_, err = a.sessions.Save(summarizeCtx, oldSession)
 897		if err != nil {
 898			event = AgentEvent{
 899				Type:  AgentEventTypeError,
 900				Error: fmt.Errorf("failed to save session: %w", err),
 901				Done:  true,
 902			}
 903			a.Publish(pubsub.CreatedEvent, event)
 904		}
 905
 906		event = AgentEvent{
 907			Type:      AgentEventTypeSummarize,
 908			SessionID: oldSession.ID,
 909			Progress:  "Summary complete",
 910			Done:      true,
 911		}
 912		a.Publish(pubsub.CreatedEvent, event)
 913		// Send final success event with the new session ID
 914	}()
 915
 916	return nil
 917}
 918
 919func (a *agent) ClearQueue(sessionID string) {
 920	if a.QueuedPrompts(sessionID) > 0 {
 921		slog.Info("Clearing queued prompts", "session_id", sessionID)
 922		a.promptQueue.Del(sessionID)
 923	}
 924}
 925
 926func (a *agent) CancelAll() {
 927	if !a.IsBusy() {
 928		return
 929	}
 930	for key := range a.activeRequests.Seq2() {
 931		a.Cancel(key) // key is sessionID
 932	}
 933
 934	timeout := time.After(5 * time.Second)
 935	for a.IsBusy() {
 936		select {
 937		case <-timeout:
 938			return
 939		default:
 940			time.Sleep(200 * time.Millisecond)
 941		}
 942	}
 943}
 944
 945func (a *agent) UpdateModel() error {
 946	cfg := config.Get()
 947
 948	// Get current provider configuration
 949	currentProviderCfg := cfg.GetProviderForModel(a.agentCfg.Model)
 950	if currentProviderCfg == nil || currentProviderCfg.ID == "" {
 951		return fmt.Errorf("provider for agent %s not found in config", a.agentCfg.Name)
 952	}
 953
 954	// Check if provider has changed
 955	if string(currentProviderCfg.ID) != a.providerID {
 956		// Provider changed, need to recreate the main provider
 957		model := cfg.GetModelByType(a.agentCfg.Model)
 958		if model.ID == "" {
 959			return fmt.Errorf("model not found for agent %s", a.agentCfg.Name)
 960		}
 961
 962		promptID := agentPromptMap[a.agentCfg.ID]
 963		if promptID == "" {
 964			promptID = prompt.PromptDefault
 965		}
 966
 967		opts := []provider.ProviderClientOption{
 968			provider.WithModel(a.agentCfg.Model),
 969			provider.WithSystemMessage(prompt.GetPrompt(promptID, currentProviderCfg.ID, cfg.Options.ContextPaths...)),
 970		}
 971
 972		newProvider, err := provider.NewProvider(*currentProviderCfg, opts...)
 973		if err != nil {
 974			return fmt.Errorf("failed to create new provider: %w", err)
 975		}
 976
 977		// Update the provider and provider ID
 978		a.provider = newProvider
 979		a.providerID = string(currentProviderCfg.ID)
 980	}
 981
 982	// Check if providers have changed for title (small) and summarize (large)
 983	smallModelCfg := cfg.Models[config.SelectedModelTypeSmall]
 984	var smallModelProviderCfg config.ProviderConfig
 985	for p := range cfg.Providers.Seq() {
 986		if p.ID == smallModelCfg.Provider {
 987			smallModelProviderCfg = p
 988			break
 989		}
 990	}
 991	if smallModelProviderCfg.ID == "" {
 992		return fmt.Errorf("provider %s not found in config", smallModelCfg.Provider)
 993	}
 994
 995	largeModelCfg := cfg.Models[config.SelectedModelTypeLarge]
 996	var largeModelProviderCfg config.ProviderConfig
 997	for p := range cfg.Providers.Seq() {
 998		if p.ID == largeModelCfg.Provider {
 999			largeModelProviderCfg = p
1000			break
1001		}
1002	}
1003	if largeModelProviderCfg.ID == "" {
1004		return fmt.Errorf("provider %s not found in config", largeModelCfg.Provider)
1005	}
1006
1007	// Recreate title provider
1008	titleOpts := []provider.ProviderClientOption{
1009		provider.WithModel(config.SelectedModelTypeSmall),
1010		provider.WithSystemMessage(prompt.GetPrompt(prompt.PromptTitle, smallModelProviderCfg.ID)),
1011		provider.WithMaxTokens(40),
1012	}
1013	newTitleProvider, err := provider.NewProvider(smallModelProviderCfg, titleOpts...)
1014	if err != nil {
1015		return fmt.Errorf("failed to create new title provider: %w", err)
1016	}
1017	a.titleProvider = newTitleProvider
1018
1019	// Recreate summarize provider if provider changed (now large model)
1020	if string(largeModelProviderCfg.ID) != a.summarizeProviderID {
1021		largeModel := cfg.GetModelByType(config.SelectedModelTypeLarge)
1022		if largeModel == nil {
1023			return fmt.Errorf("model %s not found in provider %s", largeModelCfg.Model, largeModelProviderCfg.ID)
1024		}
1025		summarizeOpts := []provider.ProviderClientOption{
1026			provider.WithModel(config.SelectedModelTypeLarge),
1027			provider.WithSystemMessage(prompt.GetPrompt(prompt.PromptSummarizer, largeModelProviderCfg.ID)),
1028		}
1029		newSummarizeProvider, err := provider.NewProvider(largeModelProviderCfg, summarizeOpts...)
1030		if err != nil {
1031			return fmt.Errorf("failed to create new summarize provider: %w", err)
1032		}
1033		a.summarizeProvider = newSummarizeProvider
1034		a.summarizeProviderID = string(largeModelProviderCfg.ID)
1035	}
1036
1037	return nil
1038}