agent.go

   1// Package agent contains the implementation of the AI agent service.
   2package agent
   3
   4import (
   5	"context"
   6	"errors"
   7	"fmt"
   8	"log/slog"
   9	"maps"
  10	"slices"
  11	"strings"
  12	"time"
  13
  14	"github.com/charmbracelet/catwalk/pkg/catwalk"
  15	"github.com/charmbracelet/crush/internal/config"
  16	"github.com/charmbracelet/crush/internal/csync"
  17	"github.com/charmbracelet/crush/internal/event"
  18	"github.com/charmbracelet/crush/internal/history"
  19	"github.com/charmbracelet/crush/internal/llm/prompt"
  20	"github.com/charmbracelet/crush/internal/llm/provider"
  21	"github.com/charmbracelet/crush/internal/llm/tools"
  22	"github.com/charmbracelet/crush/internal/log"
  23	"github.com/charmbracelet/crush/internal/lsp"
  24	"github.com/charmbracelet/crush/internal/message"
  25	"github.com/charmbracelet/crush/internal/permission"
  26	"github.com/charmbracelet/crush/internal/pubsub"
  27	"github.com/charmbracelet/crush/internal/session"
  28	"github.com/charmbracelet/crush/internal/shell"
  29)
  30
  31type AgentEventType string
  32
  33const (
  34	AgentEventTypeError     AgentEventType = "error"
  35	AgentEventTypeResponse  AgentEventType = "response"
  36	AgentEventTypeSummarize AgentEventType = "summarize"
  37)
  38
// AgentEvent is what the agent emits, both on the channel returned by Run and
// through the embedded pubsub broker. Type says which fields are meaningful.
type AgentEvent struct {
	// Type selects the kind of event (error, response or summarize).
	Type    AgentEventType
	// Message is the assistant message for AgentEventTypeResponse events.
	Message message.Message
	// Error is set for AgentEventTypeError events.
	Error   error

	// When summarizing
	// SessionID is set on the final summarize event; Progress carries a
	// human-readable status line. Done marks the last event of an operation
	// (it is also set on the response event returned by a completed Run).
	SessionID string
	Progress  string
	Done      bool
}
  49
// Service is the public interface of an AI agent. It embeds a pubsub
// subscriber so callers can listen to the AgentEvents the agent publishes.
type Service interface {
	pubsub.Suscriber[AgentEvent]
	// Model returns the model currently configured for the agent.
	Model() catwalk.Model
	// Run submits content as a new user prompt for sessionID. If the session
	// is already busy the prompt is queued and (nil, nil) is returned;
	// otherwise the returned channel delivers the request's result.
	Run(ctx context.Context, sessionID string, content string, attachments ...message.Attachment) (<-chan AgentEvent, error)
	// Cancel stops the session's running request and summarization and
	// clears its prompt queue.
	Cancel(sessionID string)
	// CancelAll cancels every active request.
	CancelAll()
	// IsSessionBusy reports whether a generation request is running for sessionID.
	IsSessionBusy(sessionID string) bool
	// IsBusy reports whether any request is running.
	IsBusy() bool
	// Summarize starts an asynchronous summarization of sessionID; progress
	// is reported through published AgentEvents.
	Summarize(ctx context.Context, sessionID string) error
	// UpdateModel re-reads the model/provider configuration for the agent.
	// NOTE(review): implementation not fully visible here — confirm semantics.
	UpdateModel() error
	// QueuedPrompts returns the number of prompts waiting for sessionID.
	QueuedPrompts(sessionID string) int
	// ClearQueue drops all prompts waiting for sessionID.
	ClearQueue(sessionID string)
}
  63
// agent is the Service implementation. Its embedded broker publishes the
// AgentEvents produced by Run and Summarize.
type agent struct {
	*pubsub.Broker[AgentEvent]
	agentCfg    config.Agent
	sessions    session.Service
	messages    message.Service
	permissions permission.Service
	// baseTools and mcpTools are lazily populated maps keyed by tool name
	// (built by NewAgent via csync.NewLazyMap).
	baseTools   *csync.Map[string, tools.BaseTool]
	mcpTools    *csync.Map[string, tools.BaseTool]
	lspClients  *csync.Map[string, *lsp.Client]

	// We need this to be able to update it when model changes
	agentToolFn  func() (tools.BaseTool, error)
	// cleanupFuncs are invoked by CancelAll; nothing in this file appends to it.
	cleanupFuncs []func()

	// provider serves the agent's main conversation; providerID is recorded
	// on the messages it produces.
	provider   provider.Provider
	providerID string

	// titleProvider generates session titles; summarizeProvider (with its ID)
	// backs Summarize.
	titleProvider       provider.Provider
	summarizeProvider   provider.Provider
	summarizeProviderID string

	// activeRequests maps a session ID (or "<sessionID>-summarize" for a
	// summarization) to the cancel func of the work in flight.
	activeRequests *csync.Map[string, context.CancelFunc]
	// promptQueue holds, per session ID, prompts submitted while it was busy.
	promptQueue    *csync.Map[string, []string]
}
  88
  89var agentPromptMap = map[string]prompt.PromptID{
  90	"coder": prompt.PromptCoder,
  91	"task":  prompt.PromptTask,
  92}
  93
  94func NewAgent(
  95	ctx context.Context,
  96	agentCfg config.Agent,
  97	// These services are needed in the tools
  98	permissions permission.Service,
  99	sessions session.Service,
 100	messages message.Service,
 101	history history.Service,
 102	lspClients *csync.Map[string, *lsp.Client],
 103) (Service, error) {
 104	cfg := config.Get()
 105
 106	var agentToolFn func() (tools.BaseTool, error)
 107	if agentCfg.ID == "coder" && slices.Contains(agentCfg.AllowedTools, AgentToolName) {
 108		agentToolFn = func() (tools.BaseTool, error) {
 109			taskAgentCfg := config.Get().Agents["task"]
 110			if taskAgentCfg.ID == "" {
 111				return nil, fmt.Errorf("task agent not found in config")
 112			}
 113			taskAgent, err := NewAgent(ctx, taskAgentCfg, permissions, sessions, messages, history, lspClients)
 114			if err != nil {
 115				return nil, fmt.Errorf("failed to create task agent: %w", err)
 116			}
 117			return NewAgentTool(taskAgent, sessions, messages), nil
 118		}
 119	}
 120
 121	providerCfg := config.Get().GetProviderForModel(agentCfg.Model)
 122	if providerCfg == nil {
 123		return nil, fmt.Errorf("provider for agent %s not found in config", agentCfg.Name)
 124	}
 125	model := config.Get().GetModelByType(agentCfg.Model)
 126
 127	if model == nil {
 128		return nil, fmt.Errorf("model not found for agent %s", agentCfg.Name)
 129	}
 130
 131	promptID := agentPromptMap[agentCfg.ID]
 132	if promptID == "" {
 133		promptID = prompt.PromptDefault
 134	}
 135	opts := []provider.ProviderClientOption{
 136		provider.WithModel(agentCfg.Model),
 137		provider.WithSystemMessage(prompt.GetPrompt(promptID, providerCfg.ID, config.Get().Options.ContextPaths...)),
 138	}
 139	agentProvider, err := provider.NewProvider(*providerCfg, opts...)
 140	if err != nil {
 141		return nil, err
 142	}
 143
 144	smallModelCfg := cfg.Models[config.SelectedModelTypeSmall]
 145	var smallModelProviderCfg *config.ProviderConfig
 146	if smallModelCfg.Provider == providerCfg.ID {
 147		smallModelProviderCfg = providerCfg
 148	} else {
 149		smallModelProviderCfg = cfg.GetProviderForModel(config.SelectedModelTypeSmall)
 150
 151		if smallModelProviderCfg.ID == "" {
 152			return nil, fmt.Errorf("provider %s not found in config", smallModelCfg.Provider)
 153		}
 154	}
 155	smallModel := cfg.GetModelByType(config.SelectedModelTypeSmall)
 156	if smallModel.ID == "" {
 157		return nil, fmt.Errorf("model %s not found in provider %s", smallModelCfg.Model, smallModelProviderCfg.ID)
 158	}
 159
 160	titleOpts := []provider.ProviderClientOption{
 161		provider.WithModel(config.SelectedModelTypeSmall),
 162		provider.WithSystemMessage(prompt.GetPrompt(prompt.PromptTitle, smallModelProviderCfg.ID)),
 163	}
 164	titleProvider, err := provider.NewProvider(*smallModelProviderCfg, titleOpts...)
 165	if err != nil {
 166		return nil, err
 167	}
 168
 169	summarizeOpts := []provider.ProviderClientOption{
 170		provider.WithModel(config.SelectedModelTypeLarge),
 171		provider.WithSystemMessage(prompt.GetPrompt(prompt.PromptSummarizer, providerCfg.ID)),
 172	}
 173	summarizeProvider, err := provider.NewProvider(*providerCfg, summarizeOpts...)
 174	if err != nil {
 175		return nil, err
 176	}
 177
 178	baseToolsFn := func() map[string]tools.BaseTool {
 179		slog.Debug("Initializing agent base tools", "agent", agentCfg.ID)
 180		defer func() {
 181			slog.Debug("Initialized agent base tools", "agent", agentCfg.ID)
 182		}()
 183
 184		// Base tools available to all agents
 185		cwd := cfg.WorkingDir()
 186		result := make(map[string]tools.BaseTool)
 187		for _, tool := range []tools.BaseTool{
 188			tools.NewBashTool(permissions, cwd, cfg.Options.Attribution),
 189			tools.NewDownloadTool(permissions, cwd),
 190			tools.NewEditTool(lspClients, permissions, history, cwd),
 191			tools.NewMultiEditTool(lspClients, permissions, history, cwd),
 192			tools.NewFetchTool(permissions, cwd),
 193			tools.NewGlobTool(cwd),
 194			tools.NewGrepTool(cwd),
 195			tools.NewLsTool(permissions, cwd),
 196			tools.NewSourcegraphTool(),
 197			tools.NewViewTool(lspClients, permissions, cwd),
 198			tools.NewWriteTool(lspClients, permissions, history, cwd),
 199		} {
 200			result[tool.Name()] = tool
 201		}
 202		return result
 203	}
 204	mcpToolsFn := func() map[string]tools.BaseTool {
 205		slog.Debug("Initializing agent mcp tools", "agent", agentCfg.ID)
 206		defer func() {
 207			slog.Debug("Initialized agent mcp tools", "agent", agentCfg.ID)
 208		}()
 209
 210		mcpToolsOnce.Do(func() {
 211			doGetMCPTools(ctx, permissions, cfg)
 212		})
 213
 214		return maps.Collect(mcpTools.Seq2())
 215	}
 216
 217	a := &agent{
 218		Broker:              pubsub.NewBroker[AgentEvent](),
 219		agentCfg:            agentCfg,
 220		provider:            agentProvider,
 221		providerID:          string(providerCfg.ID),
 222		messages:            messages,
 223		sessions:            sessions,
 224		titleProvider:       titleProvider,
 225		summarizeProvider:   summarizeProvider,
 226		summarizeProviderID: string(providerCfg.ID),
 227		agentToolFn:         agentToolFn,
 228		activeRequests:      csync.NewMap[string, context.CancelFunc](),
 229		mcpTools:            csync.NewLazyMap(mcpToolsFn),
 230		baseTools:           csync.NewLazyMap(baseToolsFn),
 231		promptQueue:         csync.NewMap[string, []string](),
 232		permissions:         permissions,
 233		lspClients:          lspClients,
 234	}
 235	a.setupEvents(ctx)
 236	return a, nil
 237}
 238
 239func (a *agent) Model() catwalk.Model {
 240	return *config.Get().GetModelByType(a.agentCfg.Model)
 241}
 242
 243func (a *agent) Cancel(sessionID string) {
 244	// Cancel regular requests
 245	if cancel, ok := a.activeRequests.Take(sessionID); ok && cancel != nil {
 246		slog.Info("Request cancellation initiated", "session_id", sessionID)
 247		cancel()
 248	}
 249
 250	// Also check for summarize requests
 251	if cancel, ok := a.activeRequests.Take(sessionID + "-summarize"); ok && cancel != nil {
 252		slog.Info("Summarize cancellation initiated", "session_id", sessionID)
 253		cancel()
 254	}
 255
 256	if a.QueuedPrompts(sessionID) > 0 {
 257		slog.Info("Clearing queued prompts", "session_id", sessionID)
 258		a.promptQueue.Del(sessionID)
 259	}
 260}
 261
 262func (a *agent) IsBusy() bool {
 263	var busy bool
 264	for cancelFunc := range a.activeRequests.Seq() {
 265		if cancelFunc != nil {
 266			busy = true
 267			break
 268		}
 269	}
 270	return busy
 271}
 272
 273func (a *agent) IsSessionBusy(sessionID string) bool {
 274	_, busy := a.activeRequests.Get(sessionID)
 275	return busy
 276}
 277
 278func (a *agent) QueuedPrompts(sessionID string) int {
 279	l, ok := a.promptQueue.Get(sessionID)
 280	if !ok {
 281		return 0
 282	}
 283	return len(l)
 284}
 285
 286func (a *agent) generateTitle(ctx context.Context, sessionID string, content string) error {
 287	if content == "" {
 288		return nil
 289	}
 290	if a.titleProvider == nil {
 291		return nil
 292	}
 293	session, err := a.sessions.Get(ctx, sessionID)
 294	if err != nil {
 295		return err
 296	}
 297	parts := []message.ContentPart{message.TextContent{
 298		Text: fmt.Sprintf("Generate a concise title for the following content:\n\n%s", content),
 299	}}
 300
 301	// Use streaming approach like summarization
 302	response := a.titleProvider.StreamResponse(
 303		ctx,
 304		[]message.Message{
 305			{
 306				Role:  message.User,
 307				Parts: parts,
 308			},
 309		},
 310		nil,
 311	)
 312
 313	var finalResponse *provider.ProviderResponse
 314	for r := range response {
 315		if r.Error != nil {
 316			return r.Error
 317		}
 318		finalResponse = r.Response
 319	}
 320
 321	if finalResponse == nil {
 322		return fmt.Errorf("no response received from title provider")
 323	}
 324
 325	title := strings.ReplaceAll(finalResponse.Content, "\n", " ")
 326
 327	if idx := strings.Index(title, "</think>"); idx > 0 {
 328		title = title[idx+len("</think>"):]
 329	}
 330
 331	title = strings.TrimSpace(title)
 332	if title == "" {
 333		return nil
 334	}
 335
 336	session.Title = title
 337	_, err = a.sessions.Save(ctx, session)
 338	return err
 339}
 340
 341func (a *agent) err(err error) AgentEvent {
 342	return AgentEvent{
 343		Type:  AgentEventTypeError,
 344		Error: err,
 345	}
 346}
 347
 348func (a *agent) Run(ctx context.Context, sessionID string, content string, attachments ...message.Attachment) (<-chan AgentEvent, error) {
 349	if !a.Model().SupportsImages && attachments != nil {
 350		attachments = nil
 351	}
 352	events := make(chan AgentEvent, 1)
 353	if a.IsSessionBusy(sessionID) {
 354		existing, ok := a.promptQueue.Get(sessionID)
 355		if !ok {
 356			existing = []string{}
 357		}
 358		existing = append(existing, content)
 359		a.promptQueue.Set(sessionID, existing)
 360		return nil, nil
 361	}
 362
 363	genCtx, cancel := context.WithCancel(ctx)
 364	a.activeRequests.Set(sessionID, cancel)
 365	startTime := time.Now()
 366
 367	go func() {
 368		slog.Debug("Request started", "sessionID", sessionID)
 369		defer log.RecoverPanic("agent.Run", func() {
 370			events <- a.err(fmt.Errorf("panic while running the agent"))
 371		})
 372		var attachmentParts []message.ContentPart
 373		for _, attachment := range attachments {
 374			attachmentParts = append(attachmentParts, message.BinaryContent{Path: attachment.FilePath, MIMEType: attachment.MimeType, Data: attachment.Content})
 375		}
 376		result := a.processGeneration(genCtx, sessionID, content, attachmentParts)
 377		if result.Error != nil {
 378			if isCancelledErr(result.Error) {
 379				slog.Error("Request canceled", "sessionID", sessionID)
 380			} else {
 381				slog.Error("Request errored", "sessionID", sessionID, "error", result.Error.Error())
 382				event.Error(result.Error)
 383			}
 384		} else {
 385			slog.Debug("Request completed", "sessionID", sessionID)
 386		}
 387		a.eventPromptResponded(sessionID, time.Since(startTime).Truncate(time.Second))
 388		a.activeRequests.Del(sessionID)
 389		cancel()
 390		a.Publish(pubsub.CreatedEvent, result)
 391		events <- result
 392		close(events)
 393	}()
 394	a.eventPromptSent(sessionID)
 395	return events, nil
 396}
 397
 398func (a *agent) processGeneration(ctx context.Context, sessionID, content string, attachmentParts []message.ContentPart) AgentEvent {
 399	cfg := config.Get()
 400	// List existing messages; if none, start title generation asynchronously.
 401	msgs, err := a.messages.List(ctx, sessionID)
 402	if err != nil {
 403		return a.err(fmt.Errorf("failed to list messages: %w", err))
 404	}
 405	if len(msgs) == 0 {
 406		go func() {
 407			defer log.RecoverPanic("agent.Run", func() {
 408				slog.Error("panic while generating title")
 409			})
 410			titleErr := a.generateTitle(ctx, sessionID, content)
 411			if titleErr != nil && !errors.Is(titleErr, context.Canceled) && !errors.Is(titleErr, context.DeadlineExceeded) {
 412				slog.Error("failed to generate title", "error", titleErr)
 413			}
 414		}()
 415	}
 416	session, err := a.sessions.Get(ctx, sessionID)
 417	if err != nil {
 418		return a.err(fmt.Errorf("failed to get session: %w", err))
 419	}
 420	if session.SummaryMessageID != "" {
 421		summaryMsgInex := -1
 422		for i, msg := range msgs {
 423			if msg.ID == session.SummaryMessageID {
 424				summaryMsgInex = i
 425				break
 426			}
 427		}
 428		if summaryMsgInex != -1 {
 429			msgs = msgs[summaryMsgInex:]
 430			msgs[0].Role = message.User
 431		}
 432	}
 433
 434	userMsg, err := a.createUserMessage(ctx, sessionID, content, attachmentParts)
 435	if err != nil {
 436		return a.err(fmt.Errorf("failed to create user message: %w", err))
 437	}
 438	// Append the new user message to the conversation history.
 439	msgHistory := append(msgs, userMsg)
 440
 441	for {
 442		// Check for cancellation before each iteration
 443		select {
 444		case <-ctx.Done():
 445			return a.err(ctx.Err())
 446		default:
 447			// Continue processing
 448		}
 449		agentMessage, toolResults, err := a.streamAndHandleEvents(ctx, sessionID, msgHistory)
 450		if err != nil {
 451			if errors.Is(err, context.Canceled) {
 452				agentMessage.AddFinish(message.FinishReasonCanceled, "Request cancelled", "")
 453				a.messages.Update(context.Background(), agentMessage)
 454				return a.err(ErrRequestCancelled)
 455			}
 456			return a.err(fmt.Errorf("failed to process events: %w", err))
 457		}
 458		if cfg.Options.Debug {
 459			slog.Info("Result", "message", agentMessage.FinishReason(), "toolResults", toolResults)
 460		}
 461		if (agentMessage.FinishReason() == message.FinishReasonToolUse) && toolResults != nil {
 462			// We are not done, we need to respond with the tool response
 463			msgHistory = append(msgHistory, agentMessage, *toolResults)
 464			// If there are queued prompts, process the next one
 465			nextPrompt, ok := a.promptQueue.Take(sessionID)
 466			if ok {
 467				for _, prompt := range nextPrompt {
 468					// Create a new user message for the queued prompt
 469					userMsg, err := a.createUserMessage(ctx, sessionID, prompt, nil)
 470					if err != nil {
 471						return a.err(fmt.Errorf("failed to create user message for queued prompt: %w", err))
 472					}
 473					// Append the new user message to the conversation history
 474					msgHistory = append(msgHistory, userMsg)
 475				}
 476			}
 477
 478			continue
 479		} else if agentMessage.FinishReason() == message.FinishReasonEndTurn {
 480			queuePrompts, ok := a.promptQueue.Take(sessionID)
 481			if ok {
 482				for _, prompt := range queuePrompts {
 483					if prompt == "" {
 484						continue
 485					}
 486					userMsg, err := a.createUserMessage(ctx, sessionID, prompt, nil)
 487					if err != nil {
 488						return a.err(fmt.Errorf("failed to create user message for queued prompt: %w", err))
 489					}
 490					msgHistory = append(msgHistory, userMsg)
 491				}
 492				continue
 493			}
 494		}
 495		if agentMessage.FinishReason() == "" {
 496			// Kujtim: could not track down where this is happening but this means its cancelled
 497			agentMessage.AddFinish(message.FinishReasonCanceled, "Request cancelled", "")
 498			_ = a.messages.Update(context.Background(), agentMessage)
 499			return a.err(ErrRequestCancelled)
 500		}
 501		return AgentEvent{
 502			Type:    AgentEventTypeResponse,
 503			Message: agentMessage,
 504			Done:    true,
 505		}
 506	}
 507}
 508
 509func (a *agent) createUserMessage(ctx context.Context, sessionID, content string, attachmentParts []message.ContentPart) (message.Message, error) {
 510	parts := []message.ContentPart{message.TextContent{Text: content}}
 511	parts = append(parts, attachmentParts...)
 512	return a.messages.Create(ctx, sessionID, message.CreateMessageParams{
 513		Role:  message.User,
 514		Parts: parts,
 515	})
 516}
 517
 518func (a *agent) getAllTools() ([]tools.BaseTool, error) {
 519	var allTools []tools.BaseTool
 520	for tool := range a.baseTools.Seq() {
 521		if a.agentCfg.AllowedTools == nil || slices.Contains(a.agentCfg.AllowedTools, tool.Name()) {
 522			allTools = append(allTools, tool)
 523		}
 524	}
 525	if a.agentCfg.ID == "coder" {
 526		allTools = slices.AppendSeq(allTools, a.mcpTools.Seq())
 527		if a.lspClients.Len() > 0 {
 528			allTools = append(allTools, tools.NewDiagnosticsTool(a.lspClients), tools.NewReferencesTool(a.lspClients))
 529		}
 530	}
 531	if a.agentToolFn != nil {
 532		agentTool, agentToolErr := a.agentToolFn()
 533		if agentToolErr != nil {
 534			return nil, agentToolErr
 535		}
 536		allTools = append(allTools, agentTool)
 537	}
 538
 539	slices.SortFunc(allTools, func(a, b tools.BaseTool) int {
 540		return strings.Compare(a.Name(), b.Name())
 541	})
 542	return allTools, nil
 543}
 544
 545func (a *agent) streamAndHandleEvents(ctx context.Context, sessionID string, msgHistory []message.Message) (message.Message, *message.Message, error) {
 546	ctx = context.WithValue(ctx, tools.SessionIDContextKey, sessionID)
 547
 548	// Create the assistant message first so the spinner shows immediately
 549	assistantMsg, err := a.messages.Create(ctx, sessionID, message.CreateMessageParams{
 550		Role:     message.Assistant,
 551		Parts:    []message.ContentPart{},
 552		Model:    a.Model().ID,
 553		Provider: a.providerID,
 554	})
 555	if err != nil {
 556		return assistantMsg, nil, fmt.Errorf("failed to create assistant message: %w", err)
 557	}
 558
 559	allTools, toolsErr := a.getAllTools()
 560	if toolsErr != nil {
 561		return assistantMsg, nil, toolsErr
 562	}
 563	// Now collect tools (which may block on MCP initialization)
 564	eventChan := a.provider.StreamResponse(ctx, msgHistory, allTools)
 565
 566	// Add the session and message ID into the context if needed by tools.
 567	ctx = context.WithValue(ctx, tools.MessageIDContextKey, assistantMsg.ID)
 568
 569loop:
 570	for {
 571		select {
 572		case event, ok := <-eventChan:
 573			if !ok {
 574				break loop
 575			}
 576			if processErr := a.processEvent(ctx, sessionID, &assistantMsg, event); processErr != nil {
 577				if errors.Is(processErr, context.Canceled) {
 578					a.finishMessage(context.Background(), &assistantMsg, message.FinishReasonCanceled, "Request cancelled", "")
 579				} else {
 580					a.finishMessage(ctx, &assistantMsg, message.FinishReasonError, "API Error", processErr.Error())
 581				}
 582				return assistantMsg, nil, processErr
 583			}
 584		case <-ctx.Done():
 585			a.finishMessage(context.Background(), &assistantMsg, message.FinishReasonCanceled, "Request cancelled", "")
 586			return assistantMsg, nil, ctx.Err()
 587		}
 588	}
 589
 590	toolResults := make([]message.ToolResult, len(assistantMsg.ToolCalls()))
 591	toolCalls := assistantMsg.ToolCalls()
 592	for i, toolCall := range toolCalls {
 593		select {
 594		case <-ctx.Done():
 595			a.finishMessage(context.Background(), &assistantMsg, message.FinishReasonCanceled, "Request cancelled", "")
 596			// Make all future tool calls cancelled
 597			for j := i; j < len(toolCalls); j++ {
 598				toolResults[j] = message.ToolResult{
 599					ToolCallID: toolCalls[j].ID,
 600					Content:    "Tool execution canceled by user",
 601					IsError:    true,
 602				}
 603			}
 604			goto out
 605		default:
 606			// Continue processing
 607			var tool tools.BaseTool
 608			allTools, _ = a.getAllTools()
 609			for _, availableTool := range allTools {
 610				if availableTool.Info().Name == toolCall.Name {
 611					tool = availableTool
 612					break
 613				}
 614			}
 615
 616			// Tool not found
 617			if tool == nil {
 618				toolResults[i] = message.ToolResult{
 619					ToolCallID: toolCall.ID,
 620					Content:    fmt.Sprintf("Tool not found: %s", toolCall.Name),
 621					IsError:    true,
 622				}
 623				continue
 624			}
 625
 626			// Run tool in goroutine to allow cancellation
 627			type toolExecResult struct {
 628				response tools.ToolResponse
 629				err      error
 630			}
 631			resultChan := make(chan toolExecResult, 1)
 632
 633			go func() {
 634				response, err := tool.Run(ctx, tools.ToolCall{
 635					ID:    toolCall.ID,
 636					Name:  toolCall.Name,
 637					Input: toolCall.Input,
 638				})
 639				resultChan <- toolExecResult{response: response, err: err}
 640			}()
 641
 642			var toolResponse tools.ToolResponse
 643			var toolErr error
 644
 645			select {
 646			case <-ctx.Done():
 647				a.finishMessage(context.Background(), &assistantMsg, message.FinishReasonCanceled, "Request cancelled", "")
 648				// Mark remaining tool calls as cancelled
 649				for j := i; j < len(toolCalls); j++ {
 650					toolResults[j] = message.ToolResult{
 651						ToolCallID: toolCalls[j].ID,
 652						Content:    "Tool execution canceled by user",
 653						IsError:    true,
 654					}
 655				}
 656				goto out
 657			case result := <-resultChan:
 658				toolResponse = result.response
 659				toolErr = result.err
 660			}
 661
 662			if toolErr != nil {
 663				slog.Error("Tool execution error", "toolCall", toolCall.ID, "error", toolErr)
 664				if errors.Is(toolErr, permission.ErrorPermissionDenied) {
 665					toolResults[i] = message.ToolResult{
 666						ToolCallID: toolCall.ID,
 667						Content:    "Permission denied",
 668						IsError:    true,
 669					}
 670					for j := i + 1; j < len(toolCalls); j++ {
 671						toolResults[j] = message.ToolResult{
 672							ToolCallID: toolCalls[j].ID,
 673							Content:    "Tool execution canceled by user",
 674							IsError:    true,
 675						}
 676					}
 677					a.finishMessage(ctx, &assistantMsg, message.FinishReasonPermissionDenied, "Permission denied", "")
 678					break
 679				}
 680			}
 681			toolResults[i] = message.ToolResult{
 682				ToolCallID: toolCall.ID,
 683				Content:    toolResponse.Content,
 684				Metadata:   toolResponse.Metadata,
 685				IsError:    toolResponse.IsError,
 686			}
 687		}
 688	}
 689out:
 690	if len(toolResults) == 0 {
 691		return assistantMsg, nil, nil
 692	}
 693	parts := make([]message.ContentPart, 0)
 694	for _, tr := range toolResults {
 695		parts = append(parts, tr)
 696	}
 697	msg, err := a.messages.Create(context.Background(), assistantMsg.SessionID, message.CreateMessageParams{
 698		Role:     message.Tool,
 699		Parts:    parts,
 700		Provider: a.providerID,
 701	})
 702	if err != nil {
 703		return assistantMsg, nil, fmt.Errorf("failed to create cancelled tool message: %w", err)
 704	}
 705
 706	return assistantMsg, &msg, err
 707}
 708
 709func (a *agent) finishMessage(ctx context.Context, msg *message.Message, finishReason message.FinishReason, message, details string) {
 710	msg.AddFinish(finishReason, message, details)
 711	_ = a.messages.Update(ctx, *msg)
 712}
 713
 714func (a *agent) processEvent(ctx context.Context, sessionID string, assistantMsg *message.Message, event provider.ProviderEvent) error {
 715	select {
 716	case <-ctx.Done():
 717		return ctx.Err()
 718	default:
 719		// Continue processing.
 720	}
 721
 722	switch event.Type {
 723	case provider.EventThinkingDelta:
 724		assistantMsg.AppendReasoningContent(event.Thinking)
 725		return a.messages.Update(ctx, *assistantMsg)
 726	case provider.EventSignatureDelta:
 727		assistantMsg.AppendReasoningSignature(event.Signature)
 728		return a.messages.Update(ctx, *assistantMsg)
 729	case provider.EventContentDelta:
 730		assistantMsg.FinishThinking()
 731		assistantMsg.AppendContent(event.Content)
 732		return a.messages.Update(ctx, *assistantMsg)
 733	case provider.EventToolUseStart:
 734		assistantMsg.FinishThinking()
 735		slog.Info("Tool call started", "toolCall", event.ToolCall)
 736		assistantMsg.AddToolCall(*event.ToolCall)
 737		return a.messages.Update(ctx, *assistantMsg)
 738	case provider.EventToolUseDelta:
 739		assistantMsg.AppendToolCallInput(event.ToolCall.ID, event.ToolCall.Input)
 740		return a.messages.Update(ctx, *assistantMsg)
 741	case provider.EventToolUseStop:
 742		slog.Info("Finished tool call", "toolCall", event.ToolCall)
 743		assistantMsg.FinishToolCall(event.ToolCall.ID)
 744		return a.messages.Update(ctx, *assistantMsg)
 745	case provider.EventError:
 746		return event.Error
 747	case provider.EventComplete:
 748		assistantMsg.FinishThinking()
 749		assistantMsg.SetToolCalls(event.Response.ToolCalls)
 750		assistantMsg.AddFinish(event.Response.FinishReason, "", "")
 751		if err := a.messages.Update(ctx, *assistantMsg); err != nil {
 752			return fmt.Errorf("failed to update message: %w", err)
 753		}
 754		return a.trackUsage(ctx, sessionID, a.Model(), event.Response.Usage)
 755	}
 756
 757	return nil
 758}
 759
 760func (a *agent) trackUsage(ctx context.Context, sessionID string, model catwalk.Model, usage provider.TokenUsage) error {
 761	sess, err := a.sessions.Get(ctx, sessionID)
 762	if err != nil {
 763		return fmt.Errorf("failed to get session: %w", err)
 764	}
 765
 766	cost := model.CostPer1MInCached/1e6*float64(usage.CacheCreationTokens) +
 767		model.CostPer1MOutCached/1e6*float64(usage.CacheReadTokens) +
 768		model.CostPer1MIn/1e6*float64(usage.InputTokens) +
 769		model.CostPer1MOut/1e6*float64(usage.OutputTokens)
 770
 771	a.eventTokensUsed(sessionID, usage, cost)
 772
 773	sess.Cost += cost
 774	sess.CompletionTokens = usage.OutputTokens + usage.CacheReadTokens
 775	sess.PromptTokens = usage.InputTokens + usage.CacheCreationTokens
 776
 777	_, err = a.sessions.Save(ctx, sess)
 778	if err != nil {
 779		return fmt.Errorf("failed to save session: %w", err)
 780	}
 781	return nil
 782}
 783
 784func (a *agent) Summarize(ctx context.Context, sessionID string) error {
 785	if a.summarizeProvider == nil {
 786		return fmt.Errorf("summarize provider not available")
 787	}
 788
 789	// Check if session is busy
 790	if a.IsSessionBusy(sessionID) {
 791		return ErrSessionBusy
 792	}
 793
 794	// Create a new context with cancellation
 795	summarizeCtx, cancel := context.WithCancel(ctx)
 796
 797	// Store the cancel function in activeRequests to allow cancellation
 798	a.activeRequests.Set(sessionID+"-summarize", cancel)
 799
 800	go func() {
 801		defer a.activeRequests.Del(sessionID + "-summarize")
 802		defer cancel()
 803		event := AgentEvent{
 804			Type:     AgentEventTypeSummarize,
 805			Progress: "Starting summarization...",
 806		}
 807
 808		a.Publish(pubsub.CreatedEvent, event)
 809		// Get all messages from the session
 810		msgs, err := a.messages.List(summarizeCtx, sessionID)
 811		if err != nil {
 812			event = AgentEvent{
 813				Type:  AgentEventTypeError,
 814				Error: fmt.Errorf("failed to list messages: %w", err),
 815				Done:  true,
 816			}
 817			a.Publish(pubsub.CreatedEvent, event)
 818			return
 819		}
 820		summarizeCtx = context.WithValue(summarizeCtx, tools.SessionIDContextKey, sessionID)
 821
 822		if len(msgs) == 0 {
 823			event = AgentEvent{
 824				Type:  AgentEventTypeError,
 825				Error: fmt.Errorf("no messages to summarize"),
 826				Done:  true,
 827			}
 828			a.Publish(pubsub.CreatedEvent, event)
 829			return
 830		}
 831
 832		event = AgentEvent{
 833			Type:     AgentEventTypeSummarize,
 834			Progress: "Analyzing conversation...",
 835		}
 836		a.Publish(pubsub.CreatedEvent, event)
 837
 838		// Add a system message to guide the summarization
 839		summarizePrompt := "Provide a detailed but concise summary of our conversation above. Focus on information that would be helpful for continuing the conversation, including what we did, what we're doing, which files we're working on, and what we're going to do next."
 840
 841		// Create a new message with the summarize prompt
 842		promptMsg := message.Message{
 843			Role:  message.User,
 844			Parts: []message.ContentPart{message.TextContent{Text: summarizePrompt}},
 845		}
 846
 847		// Append the prompt to the messages
 848		msgsWithPrompt := append(msgs, promptMsg)
 849
 850		event = AgentEvent{
 851			Type:     AgentEventTypeSummarize,
 852			Progress: "Generating summary...",
 853		}
 854
 855		a.Publish(pubsub.CreatedEvent, event)
 856
 857		// Send the messages to the summarize provider
 858		response := a.summarizeProvider.StreamResponse(
 859			summarizeCtx,
 860			msgsWithPrompt,
 861			nil,
 862		)
 863		var finalResponse *provider.ProviderResponse
 864		for r := range response {
 865			if r.Error != nil {
 866				event = AgentEvent{
 867					Type:  AgentEventTypeError,
 868					Error: fmt.Errorf("failed to summarize: %w", r.Error),
 869					Done:  true,
 870				}
 871				a.Publish(pubsub.CreatedEvent, event)
 872				return
 873			}
 874			finalResponse = r.Response
 875		}
 876
 877		summary := strings.TrimSpace(finalResponse.Content)
 878		if summary == "" {
 879			event = AgentEvent{
 880				Type:  AgentEventTypeError,
 881				Error: fmt.Errorf("empty summary returned"),
 882				Done:  true,
 883			}
 884			a.Publish(pubsub.CreatedEvent, event)
 885			return
 886		}
 887		shell := shell.GetPersistentShell(config.Get().WorkingDir())
 888		summary += "\n\n**Current working directory of the persistent shell**\n\n" + shell.GetWorkingDir()
 889		event = AgentEvent{
 890			Type:     AgentEventTypeSummarize,
 891			Progress: "Creating new session...",
 892		}
 893
 894		a.Publish(pubsub.CreatedEvent, event)
 895		oldSession, err := a.sessions.Get(summarizeCtx, sessionID)
 896		if err != nil {
 897			event = AgentEvent{
 898				Type:  AgentEventTypeError,
 899				Error: fmt.Errorf("failed to get session: %w", err),
 900				Done:  true,
 901			}
 902
 903			a.Publish(pubsub.CreatedEvent, event)
 904			return
 905		}
 906		// Create a message in the new session with the summary
 907		msg, err := a.messages.Create(summarizeCtx, oldSession.ID, message.CreateMessageParams{
 908			Role: message.Assistant,
 909			Parts: []message.ContentPart{
 910				message.TextContent{Text: summary},
 911				message.Finish{
 912					Reason: message.FinishReasonEndTurn,
 913					Time:   time.Now().Unix(),
 914				},
 915			},
 916			Model:    a.summarizeProvider.Model().ID,
 917			Provider: a.summarizeProviderID,
 918		})
 919		if err != nil {
 920			event = AgentEvent{
 921				Type:  AgentEventTypeError,
 922				Error: fmt.Errorf("failed to create summary message: %w", err),
 923				Done:  true,
 924			}
 925
 926			a.Publish(pubsub.CreatedEvent, event)
 927			return
 928		}
 929		oldSession.SummaryMessageID = msg.ID
 930		oldSession.CompletionTokens = finalResponse.Usage.OutputTokens
 931		oldSession.PromptTokens = 0
 932		model := a.summarizeProvider.Model()
 933		usage := finalResponse.Usage
 934		cost := model.CostPer1MInCached/1e6*float64(usage.CacheCreationTokens) +
 935			model.CostPer1MOutCached/1e6*float64(usage.CacheReadTokens) +
 936			model.CostPer1MIn/1e6*float64(usage.InputTokens) +
 937			model.CostPer1MOut/1e6*float64(usage.OutputTokens)
 938		oldSession.Cost += cost
 939		_, err = a.sessions.Save(summarizeCtx, oldSession)
 940		if err != nil {
 941			event = AgentEvent{
 942				Type:  AgentEventTypeError,
 943				Error: fmt.Errorf("failed to save session: %w", err),
 944				Done:  true,
 945			}
 946			a.Publish(pubsub.CreatedEvent, event)
 947		}
 948
 949		event = AgentEvent{
 950			Type:      AgentEventTypeSummarize,
 951			SessionID: oldSession.ID,
 952			Progress:  "Summary complete",
 953			Done:      true,
 954		}
 955		a.Publish(pubsub.CreatedEvent, event)
 956		// Send final success event with the new session ID
 957	}()
 958
 959	return nil
 960}
 961
 962func (a *agent) ClearQueue(sessionID string) {
 963	if a.QueuedPrompts(sessionID) > 0 {
 964		slog.Info("Clearing queued prompts", "session_id", sessionID)
 965		a.promptQueue.Del(sessionID)
 966	}
 967}
 968
 969func (a *agent) CancelAll() {
 970	if !a.IsBusy() {
 971		return
 972	}
 973	for key := range a.activeRequests.Seq2() {
 974		a.Cancel(key) // key is sessionID
 975	}
 976
 977	for _, cleanup := range a.cleanupFuncs {
 978		if cleanup != nil {
 979			cleanup()
 980		}
 981	}
 982
 983	timeout := time.After(5 * time.Second)
 984	for a.IsBusy() {
 985		select {
 986		case <-timeout:
 987			return
 988		default:
 989			time.Sleep(200 * time.Millisecond)
 990		}
 991	}
 992}
 993
 994func (a *agent) UpdateModel() error {
 995	cfg := config.Get()
 996
 997	// Get current provider configuration
 998	currentProviderCfg := cfg.GetProviderForModel(a.agentCfg.Model)
 999	if currentProviderCfg == nil || currentProviderCfg.ID == "" {
1000		return fmt.Errorf("provider for agent %s not found in config", a.agentCfg.Name)
1001	}
1002
1003	// Check if provider has changed
1004	if string(currentProviderCfg.ID) != a.providerID {
1005		// Provider changed, need to recreate the main provider
1006		model := cfg.GetModelByType(a.agentCfg.Model)
1007		if model.ID == "" {
1008			return fmt.Errorf("model not found for agent %s", a.agentCfg.Name)
1009		}
1010
1011		promptID := agentPromptMap[a.agentCfg.ID]
1012		if promptID == "" {
1013			promptID = prompt.PromptDefault
1014		}
1015
1016		opts := []provider.ProviderClientOption{
1017			provider.WithModel(a.agentCfg.Model),
1018			provider.WithSystemMessage(prompt.GetPrompt(promptID, currentProviderCfg.ID, cfg.Options.ContextPaths...)),
1019		}
1020
1021		newProvider, err := provider.NewProvider(*currentProviderCfg, opts...)
1022		if err != nil {
1023			return fmt.Errorf("failed to create new provider: %w", err)
1024		}
1025
1026		// Update the provider and provider ID
1027		a.provider = newProvider
1028		a.providerID = string(currentProviderCfg.ID)
1029	}
1030
1031	// Check if providers have changed for title (small) and summarize (large)
1032	smallModelCfg := cfg.Models[config.SelectedModelTypeSmall]
1033	var smallModelProviderCfg config.ProviderConfig
1034	for p := range cfg.Providers.Seq() {
1035		if p.ID == smallModelCfg.Provider {
1036			smallModelProviderCfg = p
1037			break
1038		}
1039	}
1040	if smallModelProviderCfg.ID == "" {
1041		return fmt.Errorf("provider %s not found in config", smallModelCfg.Provider)
1042	}
1043
1044	largeModelCfg := cfg.Models[config.SelectedModelTypeLarge]
1045	var largeModelProviderCfg config.ProviderConfig
1046	for p := range cfg.Providers.Seq() {
1047		if p.ID == largeModelCfg.Provider {
1048			largeModelProviderCfg = p
1049			break
1050		}
1051	}
1052	if largeModelProviderCfg.ID == "" {
1053		return fmt.Errorf("provider %s not found in config", largeModelCfg.Provider)
1054	}
1055
1056	var maxTitleTokens int64 = 40
1057
1058	// if the max output is too low for the gemini provider it won't return anything
1059	if smallModelCfg.Provider == "gemini" {
1060		maxTitleTokens = 1000
1061	}
1062	// Recreate title provider
1063	titleOpts := []provider.ProviderClientOption{
1064		provider.WithModel(config.SelectedModelTypeSmall),
1065		provider.WithSystemMessage(prompt.GetPrompt(prompt.PromptTitle, smallModelProviderCfg.ID)),
1066		provider.WithMaxTokens(maxTitleTokens),
1067	}
1068	newTitleProvider, err := provider.NewProvider(smallModelProviderCfg, titleOpts...)
1069	if err != nil {
1070		return fmt.Errorf("failed to create new title provider: %w", err)
1071	}
1072	a.titleProvider = newTitleProvider
1073
1074	// Recreate summarize provider if provider changed (now large model)
1075	if string(largeModelProviderCfg.ID) != a.summarizeProviderID {
1076		largeModel := cfg.GetModelByType(config.SelectedModelTypeLarge)
1077		if largeModel == nil {
1078			return fmt.Errorf("model %s not found in provider %s", largeModelCfg.Model, largeModelProviderCfg.ID)
1079		}
1080		summarizeOpts := []provider.ProviderClientOption{
1081			provider.WithModel(config.SelectedModelTypeLarge),
1082			provider.WithSystemMessage(prompt.GetPrompt(prompt.PromptSummarizer, largeModelProviderCfg.ID)),
1083		}
1084		newSummarizeProvider, err := provider.NewProvider(largeModelProviderCfg, summarizeOpts...)
1085		if err != nil {
1086			return fmt.Errorf("failed to create new summarize provider: %w", err)
1087		}
1088		a.summarizeProvider = newSummarizeProvider
1089		a.summarizeProviderID = string(largeModelProviderCfg.ID)
1090	}
1091
1092	return nil
1093}
1094
1095func (a *agent) setupEvents(ctx context.Context) {
1096	ctx, cancel := context.WithCancel(ctx)
1097
1098	go func() {
1099		subCh := SubscribeMCPEvents(ctx)
1100
1101		for {
1102			select {
1103			case event, ok := <-subCh:
1104				if !ok {
1105					slog.Debug("MCPEvents subscription channel closed")
1106					return
1107				}
1108				switch event.Payload.Type {
1109				case MCPEventToolsListChanged:
1110					name := event.Payload.Name
1111					c, ok := mcpClients.Get(name)
1112					if !ok {
1113						slog.Warn("MCP client not found for tools update", "name", name)
1114						continue
1115					}
1116					cfg := config.Get()
1117					tools, err := getTools(ctx, name, a.permissions, c, cfg.WorkingDir())
1118					if err != nil {
1119						slog.Error("error listing tools", "error", err)
1120						updateMCPState(name, MCPStateError, err, nil, 0)
1121						_ = c.Close()
1122						continue
1123					}
1124					updateMcpTools(name, tools)
1125					a.mcpTools.Reset(maps.Collect(mcpTools.Seq2()))
1126					updateMCPState(name, MCPStateConnected, nil, c, a.mcpTools.Len())
1127				default:
1128					continue
1129				}
1130			case <-ctx.Done():
1131				slog.Debug("MCPEvents subscription cancelled")
1132				return
1133			}
1134		}
1135	}()
1136
1137	a.cleanupFuncs = append(a.cleanupFuncs, cancel)
1138}