agent.go

   1// Package agent contains the implementation of the AI agent service.
   2package agent
   3
   4import (
   5	"context"
   6	"errors"
   7	"fmt"
   8	"log/slog"
   9	"maps"
  10	"slices"
  11	"strings"
  12	"time"
  13
  14	"github.com/charmbracelet/catwalk/pkg/catwalk"
  15	"github.com/charmbracelet/crush/internal/config"
  16	"github.com/charmbracelet/crush/internal/csync"
  17	"github.com/charmbracelet/crush/internal/event"
  18	"github.com/charmbracelet/crush/internal/history"
  19	"github.com/charmbracelet/crush/internal/llm/prompt"
  20	"github.com/charmbracelet/crush/internal/llm/provider"
  21	"github.com/charmbracelet/crush/internal/llm/tools"
  22	"github.com/charmbracelet/crush/internal/log"
  23	"github.com/charmbracelet/crush/internal/lsp"
  24	"github.com/charmbracelet/crush/internal/message"
  25	"github.com/charmbracelet/crush/internal/permission"
  26	"github.com/charmbracelet/crush/internal/pubsub"
  27	"github.com/charmbracelet/crush/internal/session"
  28	"github.com/charmbracelet/crush/internal/shell"
  29)
  30
  31type AgentEventType string
  32
  33const (
  34	AgentEventTypeError     AgentEventType = "error"
  35	AgentEventTypeResponse  AgentEventType = "response"
  36	AgentEventTypeSummarize AgentEventType = "summarize"
  37)
  38
  39type AgentEvent struct {
  40	Type    AgentEventType
  41	Message message.Message
  42	Error   error
  43
  44	// When summarizing
  45	SessionID string
  46	Progress  string
  47	Done      bool
  48}
  49
  50type Service interface {
  51	pubsub.Suscriber[AgentEvent]
  52	Model() catwalk.Model
  53	Run(ctx context.Context, sessionID string, content string, attachments ...message.Attachment) (<-chan AgentEvent, error)
  54	Cancel(sessionID string)
  55	CancelAll()
  56	IsSessionBusy(sessionID string) bool
  57	IsBusy() bool
  58	Summarize(ctx context.Context, sessionID string) error
  59	UpdateModel() error
  60	QueuedPrompts(sessionID string) int
  61	ClearQueue(sessionID string)
  62}
  63
  64type agent struct {
  65	*pubsub.Broker[AgentEvent]
  66	agentCfg    config.Agent
  67	sessions    session.Service
  68	messages    message.Service
  69	permissions permission.Service
  70	baseTools   *csync.Map[string, tools.BaseTool]
  71	mcpTools    *csync.Map[string, tools.BaseTool]
  72	lspClients  *csync.Map[string, *lsp.Client]
  73
  74	// We need this to be able to update it when model changes
  75	agentToolFn  func() (tools.BaseTool, error)
  76	cleanupFuncs []func()
  77
  78	provider   provider.Provider
  79	providerID string
  80
  81	titleProvider       provider.Provider
  82	summarizeProvider   provider.Provider
  83	summarizeProviderID string
  84
  85	activeRequests *csync.Map[string, context.CancelFunc]
  86	promptQueue    *csync.Map[string, []string]
  87}
  88
  89var agentPromptMap = map[string]prompt.PromptID{
  90	"coder": prompt.PromptCoder,
  91	"task":  prompt.PromptTask,
  92}
  93
  94func NewAgent(
  95	ctx context.Context,
  96	agentCfg config.Agent,
  97	// These services are needed in the tools
  98	permissions permission.Service,
  99	sessions session.Service,
 100	messages message.Service,
 101	history history.Service,
 102	lspClients *csync.Map[string, *lsp.Client],
 103) (Service, error) {
 104	cfg := config.Get()
 105
 106	var agentToolFn func() (tools.BaseTool, error)
 107	if agentCfg.ID == "coder" && slices.Contains(agentCfg.AllowedTools, AgentToolName) {
 108		agentToolFn = func() (tools.BaseTool, error) {
 109			taskAgentCfg := config.Get().Agents["task"]
 110			if taskAgentCfg.ID == "" {
 111				return nil, fmt.Errorf("task agent not found in config")
 112			}
 113			taskAgent, err := NewAgent(ctx, taskAgentCfg, permissions, sessions, messages, history, lspClients)
 114			if err != nil {
 115				return nil, fmt.Errorf("failed to create task agent: %w", err)
 116			}
 117			return NewAgentTool(taskAgent, sessions, messages), nil
 118		}
 119	}
 120
 121	providerCfg := config.Get().GetProviderForModel(agentCfg.Model)
 122	if providerCfg == nil {
 123		return nil, fmt.Errorf("provider for agent %s not found in config", agentCfg.Name)
 124	}
 125	model := config.Get().GetModelByType(agentCfg.Model)
 126
 127	if model == nil {
 128		return nil, fmt.Errorf("model not found for agent %s", agentCfg.Name)
 129	}
 130
 131	promptID := agentPromptMap[agentCfg.ID]
 132	if promptID == "" {
 133		promptID = prompt.PromptDefault
 134	}
 135	opts := []provider.ProviderClientOption{
 136		provider.WithModel(agentCfg.Model),
 137		provider.WithSystemMessage(prompt.GetPrompt(promptID, providerCfg.ID, config.Get().Options.ContextPaths...)),
 138	}
 139	agentProvider, err := provider.NewProvider(*providerCfg, opts...)
 140	if err != nil {
 141		return nil, err
 142	}
 143
 144	smallModelCfg := cfg.Models[config.SelectedModelTypeSmall]
 145	var smallModelProviderCfg *config.ProviderConfig
 146	if smallModelCfg.Provider == providerCfg.ID {
 147		smallModelProviderCfg = providerCfg
 148	} else {
 149		smallModelProviderCfg = cfg.GetProviderForModel(config.SelectedModelTypeSmall)
 150
 151		if smallModelProviderCfg.ID == "" {
 152			return nil, fmt.Errorf("provider %s not found in config", smallModelCfg.Provider)
 153		}
 154	}
 155	smallModel := cfg.GetModelByType(config.SelectedModelTypeSmall)
 156	if smallModel.ID == "" {
 157		return nil, fmt.Errorf("model %s not found in provider %s", smallModelCfg.Model, smallModelProviderCfg.ID)
 158	}
 159
 160	titleOpts := []provider.ProviderClientOption{
 161		provider.WithModel(config.SelectedModelTypeSmall),
 162		provider.WithSystemMessage(prompt.GetPrompt(prompt.PromptTitle, smallModelProviderCfg.ID)),
 163	}
 164	titleProvider, err := provider.NewProvider(*smallModelProviderCfg, titleOpts...)
 165	if err != nil {
 166		return nil, err
 167	}
 168
 169	summarizeOpts := []provider.ProviderClientOption{
 170		provider.WithModel(config.SelectedModelTypeLarge),
 171		provider.WithSystemMessage(prompt.GetPrompt(prompt.PromptSummarizer, providerCfg.ID)),
 172	}
 173	summarizeProvider, err := provider.NewProvider(*providerCfg, summarizeOpts...)
 174	if err != nil {
 175		return nil, err
 176	}
 177
 178	baseToolsFn := func() map[string]tools.BaseTool {
 179		slog.Debug("Initializing agent base tools", "agent", agentCfg.ID)
 180		defer func() {
 181			slog.Debug("Initialized agent base tools", "agent", agentCfg.ID)
 182		}()
 183
 184		// Base tools available to all agents
 185		cwd := cfg.WorkingDir()
 186		result := make(map[string]tools.BaseTool)
 187		for _, tool := range []tools.BaseTool{
 188			tools.NewBashTool(permissions, cwd, cfg.Options.Attribution),
 189			tools.NewDownloadTool(permissions, cwd),
 190			tools.NewEditTool(lspClients, permissions, history, cwd),
 191			tools.NewMultiEditTool(lspClients, permissions, history, cwd),
 192			tools.NewFetchTool(permissions, cwd),
 193			tools.NewGlobTool(cwd),
 194			tools.NewGrepTool(cwd),
 195			tools.NewLsTool(permissions, cwd),
 196			tools.NewSourcegraphTool(),
 197			tools.NewViewTool(lspClients, permissions, cwd),
 198			tools.NewWriteTool(lspClients, permissions, history, cwd),
 199		} {
 200			result[tool.Name()] = tool
 201		}
 202		return result
 203	}
 204	mcpToolsFn := func() map[string]tools.BaseTool {
 205		slog.Debug("Initializing agent mcp tools", "agent", agentCfg.ID)
 206		defer func() {
 207			slog.Debug("Initialized agent mcp tools", "agent", agentCfg.ID)
 208		}()
 209
 210		mcpToolsOnce.Do(func() {
 211			doGetMCPTools(ctx, permissions, cfg)
 212		})
 213
 214		return maps.Collect(mcpTools.Seq2())
 215	}
 216
 217	a := &agent{
 218		Broker:              pubsub.NewBroker[AgentEvent](),
 219		agentCfg:            agentCfg,
 220		provider:            agentProvider,
 221		providerID:          string(providerCfg.ID),
 222		messages:            messages,
 223		sessions:            sessions,
 224		titleProvider:       titleProvider,
 225		summarizeProvider:   summarizeProvider,
 226		summarizeProviderID: string(providerCfg.ID),
 227		agentToolFn:         agentToolFn,
 228		activeRequests:      csync.NewMap[string, context.CancelFunc](),
 229		mcpTools:            csync.NewLazyMap(mcpToolsFn),
 230		baseTools:           csync.NewLazyMap(baseToolsFn),
 231		promptQueue:         csync.NewMap[string, []string](),
 232		permissions:         permissions,
 233		lspClients:          lspClients,
 234	}
 235	a.setupEvents(ctx)
 236	return a, nil
 237}
 238
 239func (a *agent) Model() catwalk.Model {
 240	return *config.Get().GetModelByType(a.agentCfg.Model)
 241}
 242
 243func (a *agent) Cancel(sessionID string) {
 244	// Cancel regular requests
 245	if cancel, ok := a.activeRequests.Take(sessionID); ok && cancel != nil {
 246		slog.Info("Request cancellation initiated", "session_id", sessionID)
 247		cancel()
 248	}
 249
 250	// Also check for summarize requests
 251	if cancel, ok := a.activeRequests.Take(sessionID + "-summarize"); ok && cancel != nil {
 252		slog.Info("Summarize cancellation initiated", "session_id", sessionID)
 253		cancel()
 254	}
 255
 256	if a.QueuedPrompts(sessionID) > 0 {
 257		slog.Info("Clearing queued prompts", "session_id", sessionID)
 258		a.promptQueue.Del(sessionID)
 259	}
 260}
 261
 262func (a *agent) IsBusy() bool {
 263	var busy bool
 264	for cancelFunc := range a.activeRequests.Seq() {
 265		if cancelFunc != nil {
 266			busy = true
 267			break
 268		}
 269	}
 270	return busy
 271}
 272
 273func (a *agent) IsSessionBusy(sessionID string) bool {
 274	_, busy := a.activeRequests.Get(sessionID)
 275	return busy
 276}
 277
 278func (a *agent) QueuedPrompts(sessionID string) int {
 279	l, ok := a.promptQueue.Get(sessionID)
 280	if !ok {
 281		return 0
 282	}
 283	return len(l)
 284}
 285
 286func (a *agent) generateTitle(ctx context.Context, sessionID string, content string) error {
 287	if content == "" {
 288		return nil
 289	}
 290	if a.titleProvider == nil {
 291		return nil
 292	}
 293	session, err := a.sessions.Get(ctx, sessionID)
 294	if err != nil {
 295		return err
 296	}
 297	parts := []message.ContentPart{message.TextContent{
 298		Text: fmt.Sprintf("Generate a concise title for the following content:\n\n%s", content),
 299	}}
 300
 301	// Use streaming approach like summarization
 302	response := a.titleProvider.StreamResponse(
 303		ctx,
 304		[]message.Message{
 305			{
 306				Role:  message.User,
 307				Parts: parts,
 308			},
 309		},
 310		nil,
 311	)
 312
 313	var finalResponse *provider.ProviderResponse
 314	for r := range response {
 315		if r.Error != nil {
 316			return r.Error
 317		}
 318		finalResponse = r.Response
 319	}
 320
 321	if finalResponse == nil {
 322		return fmt.Errorf("no response received from title provider")
 323	}
 324
 325	title := strings.ReplaceAll(finalResponse.Content, "\n", " ")
 326
 327	if idx := strings.Index(title, "</think>"); idx > 0 {
 328		title = title[idx+len("</think>"):]
 329	}
 330
 331	title = strings.TrimSpace(title)
 332	if title == "" {
 333		return nil
 334	}
 335
 336	session.Title = title
 337	_, err = a.sessions.Save(ctx, session)
 338	return err
 339}
 340
 341func (a *agent) err(err error) AgentEvent {
 342	return AgentEvent{
 343		Type:  AgentEventTypeError,
 344		Error: err,
 345	}
 346}
 347
 348func (a *agent) Run(ctx context.Context, sessionID string, content string, attachments ...message.Attachment) (<-chan AgentEvent, error) {
 349	events := make(chan AgentEvent, 1)
 350	if a.IsSessionBusy(sessionID) {
 351		existing, ok := a.promptQueue.Get(sessionID)
 352		if !ok {
 353			existing = []string{}
 354		}
 355		existing = append(existing, content)
 356		a.promptQueue.Set(sessionID, existing)
 357		return nil, nil
 358	}
 359
 360	genCtx, cancel := context.WithCancel(ctx)
 361	a.activeRequests.Set(sessionID, cancel)
 362	startTime := time.Now()
 363
 364	go func() {
 365		slog.Debug("Request started", "sessionID", sessionID)
 366		defer log.RecoverPanic("agent.Run", func() {
 367			events <- a.err(fmt.Errorf("panic while running the agent"))
 368		})
 369		var attachmentParts []message.ContentPart
 370		for _, attachment := range attachments {
 371			if !a.Model().SupportsImages && strings.HasPrefix(attachment.MimeType, "image/") {
 372				slog.Warn("Model does not support images, skipping attachment", "mimeType", attachment.MimeType, "fileName", attachment.FileName)
 373				continue
 374			}
 375			attachmentParts = append(attachmentParts, message.BinaryContent{
 376				Path:     attachment.FilePath,
 377				MIMEType: attachment.MimeType,
 378				Data:     attachment.Content,
 379			})
 380		}
 381		result := a.processGeneration(genCtx, sessionID, content, attachmentParts)
 382		if result.Error != nil {
 383			if isCancelledErr(result.Error) {
 384				slog.Error("Request canceled", "sessionID", sessionID)
 385			} else {
 386				slog.Error("Request errored", "sessionID", sessionID, "error", result.Error.Error())
 387				event.Error(result.Error)
 388			}
 389		} else {
 390			slog.Debug("Request completed", "sessionID", sessionID)
 391		}
 392		a.eventPromptResponded(sessionID, time.Since(startTime).Truncate(time.Second))
 393		a.activeRequests.Del(sessionID)
 394		cancel()
 395		a.Publish(pubsub.CreatedEvent, result)
 396		events <- result
 397		close(events)
 398	}()
 399	a.eventPromptSent(sessionID)
 400	return events, nil
 401}
 402
 403func (a *agent) processGeneration(ctx context.Context, sessionID, content string, attachmentParts []message.ContentPart) AgentEvent {
 404	cfg := config.Get()
 405	// List existing messages; if none, start title generation asynchronously.
 406	msgs, err := a.messages.List(ctx, sessionID)
 407	if err != nil {
 408		return a.err(fmt.Errorf("failed to list messages: %w", err))
 409	}
 410	if len(msgs) == 0 {
 411		go func() {
 412			defer log.RecoverPanic("agent.Run", func() {
 413				slog.Error("panic while generating title")
 414			})
 415			titleErr := a.generateTitle(ctx, sessionID, content)
 416			if titleErr != nil && !errors.Is(titleErr, context.Canceled) && !errors.Is(titleErr, context.DeadlineExceeded) {
 417				slog.Error("failed to generate title", "error", titleErr)
 418			}
 419		}()
 420	}
 421	session, err := a.sessions.Get(ctx, sessionID)
 422	if err != nil {
 423		return a.err(fmt.Errorf("failed to get session: %w", err))
 424	}
 425	if session.SummaryMessageID != "" {
 426		summaryMsgInex := -1
 427		for i, msg := range msgs {
 428			if msg.ID == session.SummaryMessageID {
 429				summaryMsgInex = i
 430				break
 431			}
 432		}
 433		if summaryMsgInex != -1 {
 434			msgs = msgs[summaryMsgInex:]
 435			msgs[0].Role = message.User
 436		}
 437	}
 438
 439	userMsg, err := a.createUserMessage(ctx, sessionID, content, attachmentParts)
 440	if err != nil {
 441		return a.err(fmt.Errorf("failed to create user message: %w", err))
 442	}
 443	// Append the new user message to the conversation history.
 444	msgHistory := append(msgs, userMsg)
 445
 446	for {
 447		// Check for cancellation before each iteration
 448		select {
 449		case <-ctx.Done():
 450			return a.err(ctx.Err())
 451		default:
 452			// Continue processing
 453		}
 454		agentMessage, toolResults, err := a.streamAndHandleEvents(ctx, sessionID, msgHistory)
 455		if err != nil {
 456			if errors.Is(err, context.Canceled) {
 457				agentMessage.AddFinish(message.FinishReasonCanceled, "Request cancelled", "")
 458				a.messages.Update(context.Background(), agentMessage)
 459				return a.err(ErrRequestCancelled)
 460			}
 461			return a.err(fmt.Errorf("failed to process events: %w", err))
 462		}
 463		if cfg.Options.Debug {
 464			slog.Info("Result", "message", agentMessage.FinishReason(), "toolResults", toolResults)
 465		}
 466		if (agentMessage.FinishReason() == message.FinishReasonToolUse) && toolResults != nil {
 467			// We are not done, we need to respond with the tool response
 468			msgHistory = append(msgHistory, agentMessage, *toolResults)
 469			// If there are queued prompts, process the next one
 470			nextPrompt, ok := a.promptQueue.Take(sessionID)
 471			if ok {
 472				for _, prompt := range nextPrompt {
 473					// Create a new user message for the queued prompt
 474					userMsg, err := a.createUserMessage(ctx, sessionID, prompt, nil)
 475					if err != nil {
 476						return a.err(fmt.Errorf("failed to create user message for queued prompt: %w", err))
 477					}
 478					// Append the new user message to the conversation history
 479					msgHistory = append(msgHistory, userMsg)
 480				}
 481			}
 482
 483			continue
 484		} else if agentMessage.FinishReason() == message.FinishReasonEndTurn {
 485			queuePrompts, ok := a.promptQueue.Take(sessionID)
 486			if ok {
 487				for _, prompt := range queuePrompts {
 488					if prompt == "" {
 489						continue
 490					}
 491					userMsg, err := a.createUserMessage(ctx, sessionID, prompt, nil)
 492					if err != nil {
 493						return a.err(fmt.Errorf("failed to create user message for queued prompt: %w", err))
 494					}
 495					msgHistory = append(msgHistory, userMsg)
 496				}
 497				continue
 498			}
 499		}
 500		if agentMessage.FinishReason() == "" {
 501			// Kujtim: could not track down where this is happening but this means its cancelled
 502			agentMessage.AddFinish(message.FinishReasonCanceled, "Request cancelled", "")
 503			_ = a.messages.Update(context.Background(), agentMessage)
 504			return a.err(ErrRequestCancelled)
 505		}
 506		return AgentEvent{
 507			Type:    AgentEventTypeResponse,
 508			Message: agentMessage,
 509			Done:    true,
 510		}
 511	}
 512}
 513
 514func (a *agent) createUserMessage(ctx context.Context, sessionID, content string, attachmentParts []message.ContentPart) (message.Message, error) {
 515	parts := []message.ContentPart{message.TextContent{Text: content}}
 516	parts = append(parts, attachmentParts...)
 517	return a.messages.Create(ctx, sessionID, message.CreateMessageParams{
 518		Role:  message.User,
 519		Parts: parts,
 520	})
 521}
 522
 523func (a *agent) getAllTools() ([]tools.BaseTool, error) {
 524	var allTools []tools.BaseTool
 525	for tool := range a.baseTools.Seq() {
 526		if a.agentCfg.AllowedTools == nil || slices.Contains(a.agentCfg.AllowedTools, tool.Name()) {
 527			allTools = append(allTools, tool)
 528		}
 529	}
 530	if a.agentCfg.ID == "coder" {
 531		allTools = slices.AppendSeq(allTools, a.mcpTools.Seq())
 532		if a.lspClients.Len() > 0 {
 533			allTools = append(allTools, tools.NewDiagnosticsTool(a.lspClients))
 534		}
 535	}
 536	if a.agentToolFn != nil {
 537		agentTool, agentToolErr := a.agentToolFn()
 538		if agentToolErr != nil {
 539			return nil, agentToolErr
 540		}
 541		allTools = append(allTools, agentTool)
 542	}
 543	return allTools, nil
 544}
 545
 546func (a *agent) streamAndHandleEvents(ctx context.Context, sessionID string, msgHistory []message.Message) (message.Message, *message.Message, error) {
 547	ctx = context.WithValue(ctx, tools.SessionIDContextKey, sessionID)
 548
 549	// Create the assistant message first so the spinner shows immediately
 550	assistantMsg, err := a.messages.Create(ctx, sessionID, message.CreateMessageParams{
 551		Role:     message.Assistant,
 552		Parts:    []message.ContentPart{},
 553		Model:    a.Model().ID,
 554		Provider: a.providerID,
 555	})
 556	if err != nil {
 557		return assistantMsg, nil, fmt.Errorf("failed to create assistant message: %w", err)
 558	}
 559
 560	allTools, toolsErr := a.getAllTools()
 561	if toolsErr != nil {
 562		return assistantMsg, nil, toolsErr
 563	}
 564	// Now collect tools (which may block on MCP initialization)
 565	eventChan := a.provider.StreamResponse(ctx, msgHistory, allTools)
 566
 567	// Add the session and message ID into the context if needed by tools.
 568	ctx = context.WithValue(ctx, tools.MessageIDContextKey, assistantMsg.ID)
 569
 570loop:
 571	for {
 572		select {
 573		case event, ok := <-eventChan:
 574			if !ok {
 575				break loop
 576			}
 577			if processErr := a.processEvent(ctx, sessionID, &assistantMsg, event); processErr != nil {
 578				if errors.Is(processErr, context.Canceled) {
 579					a.finishMessage(context.Background(), &assistantMsg, message.FinishReasonCanceled, "Request cancelled", "")
 580				} else {
 581					a.finishMessage(ctx, &assistantMsg, message.FinishReasonError, "API Error", processErr.Error())
 582				}
 583				return assistantMsg, nil, processErr
 584			}
 585		case <-ctx.Done():
 586			a.finishMessage(context.Background(), &assistantMsg, message.FinishReasonCanceled, "Request cancelled", "")
 587			return assistantMsg, nil, ctx.Err()
 588		}
 589	}
 590
 591	toolResults := make([]message.ToolResult, len(assistantMsg.ToolCalls()))
 592	toolCalls := assistantMsg.ToolCalls()
 593	for i, toolCall := range toolCalls {
 594		select {
 595		case <-ctx.Done():
 596			a.finishMessage(context.Background(), &assistantMsg, message.FinishReasonCanceled, "Request cancelled", "")
 597			// Make all future tool calls cancelled
 598			for j := i; j < len(toolCalls); j++ {
 599				toolResults[j] = message.ToolResult{
 600					ToolCallID: toolCalls[j].ID,
 601					Content:    "Tool execution canceled by user",
 602					IsError:    true,
 603				}
 604			}
 605			goto out
 606		default:
 607			// Continue processing
 608			var tool tools.BaseTool
 609			allTools, _ = a.getAllTools()
 610			for _, availableTool := range allTools {
 611				if availableTool.Info().Name == toolCall.Name {
 612					tool = availableTool
 613					break
 614				}
 615			}
 616
 617			// Tool not found
 618			if tool == nil {
 619				toolResults[i] = message.ToolResult{
 620					ToolCallID: toolCall.ID,
 621					Content:    fmt.Sprintf("Tool not found: %s", toolCall.Name),
 622					IsError:    true,
 623				}
 624				continue
 625			}
 626
 627			// Run tool in goroutine to allow cancellation
 628			type toolExecResult struct {
 629				response tools.ToolResponse
 630				err      error
 631			}
 632			resultChan := make(chan toolExecResult, 1)
 633
 634			go func() {
 635				response, err := tool.Run(ctx, tools.ToolCall{
 636					ID:    toolCall.ID,
 637					Name:  toolCall.Name,
 638					Input: toolCall.Input,
 639				})
 640				resultChan <- toolExecResult{response: response, err: err}
 641			}()
 642
 643			var toolResponse tools.ToolResponse
 644			var toolErr error
 645
 646			select {
 647			case <-ctx.Done():
 648				a.finishMessage(context.Background(), &assistantMsg, message.FinishReasonCanceled, "Request cancelled", "")
 649				// Mark remaining tool calls as cancelled
 650				for j := i; j < len(toolCalls); j++ {
 651					toolResults[j] = message.ToolResult{
 652						ToolCallID: toolCalls[j].ID,
 653						Content:    "Tool execution canceled by user",
 654						IsError:    true,
 655					}
 656				}
 657				goto out
 658			case result := <-resultChan:
 659				toolResponse = result.response
 660				toolErr = result.err
 661			}
 662
 663			if toolErr != nil {
 664				slog.Error("Tool execution error", "toolCall", toolCall.ID, "error", toolErr)
 665				if errors.Is(toolErr, permission.ErrorPermissionDenied) {
 666					toolResults[i] = message.ToolResult{
 667						ToolCallID: toolCall.ID,
 668						Content:    "Permission denied",
 669						IsError:    true,
 670					}
 671					for j := i + 1; j < len(toolCalls); j++ {
 672						toolResults[j] = message.ToolResult{
 673							ToolCallID: toolCalls[j].ID,
 674							Content:    "Tool execution canceled by user",
 675							IsError:    true,
 676						}
 677					}
 678					a.finishMessage(ctx, &assistantMsg, message.FinishReasonPermissionDenied, "Permission denied", "")
 679					break
 680				}
 681			}
 682			toolResults[i] = message.ToolResult{
 683				ToolCallID: toolCall.ID,
 684				Content:    toolResponse.Content,
 685				Metadata:   toolResponse.Metadata,
 686				IsError:    toolResponse.IsError,
 687			}
 688		}
 689	}
 690out:
 691	if len(toolResults) == 0 {
 692		return assistantMsg, nil, nil
 693	}
 694	parts := make([]message.ContentPart, 0)
 695	for _, tr := range toolResults {
 696		parts = append(parts, tr)
 697	}
 698	msg, err := a.messages.Create(context.Background(), assistantMsg.SessionID, message.CreateMessageParams{
 699		Role:     message.Tool,
 700		Parts:    parts,
 701		Provider: a.providerID,
 702	})
 703	if err != nil {
 704		return assistantMsg, nil, fmt.Errorf("failed to create cancelled tool message: %w", err)
 705	}
 706
 707	return assistantMsg, &msg, err
 708}
 709
 710func (a *agent) finishMessage(ctx context.Context, msg *message.Message, finishReason message.FinishReason, message, details string) {
 711	msg.AddFinish(finishReason, message, details)
 712	_ = a.messages.Update(ctx, *msg)
 713}
 714
 715func (a *agent) processEvent(ctx context.Context, sessionID string, assistantMsg *message.Message, event provider.ProviderEvent) error {
 716	select {
 717	case <-ctx.Done():
 718		return ctx.Err()
 719	default:
 720		// Continue processing.
 721	}
 722
 723	switch event.Type {
 724	case provider.EventThinkingDelta:
 725		assistantMsg.AppendReasoningContent(event.Thinking)
 726		return a.messages.Update(ctx, *assistantMsg)
 727	case provider.EventSignatureDelta:
 728		assistantMsg.AppendReasoningSignature(event.Signature)
 729		return a.messages.Update(ctx, *assistantMsg)
 730	case provider.EventContentDelta:
 731		assistantMsg.FinishThinking()
 732		assistantMsg.AppendContent(event.Content)
 733		return a.messages.Update(ctx, *assistantMsg)
 734	case provider.EventToolUseStart:
 735		assistantMsg.FinishThinking()
 736		slog.Info("Tool call started", "toolCall", event.ToolCall)
 737		assistantMsg.AddToolCall(*event.ToolCall)
 738		return a.messages.Update(ctx, *assistantMsg)
 739	case provider.EventToolUseDelta:
 740		assistantMsg.AppendToolCallInput(event.ToolCall.ID, event.ToolCall.Input)
 741		return a.messages.Update(ctx, *assistantMsg)
 742	case provider.EventToolUseStop:
 743		slog.Info("Finished tool call", "toolCall", event.ToolCall)
 744		assistantMsg.FinishToolCall(event.ToolCall.ID)
 745		return a.messages.Update(ctx, *assistantMsg)
 746	case provider.EventError:
 747		return event.Error
 748	case provider.EventComplete:
 749		assistantMsg.FinishThinking()
 750		assistantMsg.SetToolCalls(event.Response.ToolCalls)
 751		assistantMsg.AddFinish(event.Response.FinishReason, "", "")
 752		if err := a.messages.Update(ctx, *assistantMsg); err != nil {
 753			return fmt.Errorf("failed to update message: %w", err)
 754		}
 755		return a.trackUsage(ctx, sessionID, a.Model(), event.Response.Usage)
 756	}
 757
 758	return nil
 759}
 760
 761func (a *agent) trackUsage(ctx context.Context, sessionID string, model catwalk.Model, usage provider.TokenUsage) error {
 762	sess, err := a.sessions.Get(ctx, sessionID)
 763	if err != nil {
 764		return fmt.Errorf("failed to get session: %w", err)
 765	}
 766
 767	cost := model.CostPer1MInCached/1e6*float64(usage.CacheCreationTokens) +
 768		model.CostPer1MOutCached/1e6*float64(usage.CacheReadTokens) +
 769		model.CostPer1MIn/1e6*float64(usage.InputTokens) +
 770		model.CostPer1MOut/1e6*float64(usage.OutputTokens)
 771
 772	a.eventTokensUsed(sessionID, usage, cost)
 773
 774	sess.Cost += cost
 775	sess.CompletionTokens = usage.OutputTokens + usage.CacheReadTokens
 776	sess.PromptTokens = usage.InputTokens + usage.CacheCreationTokens
 777
 778	_, err = a.sessions.Save(ctx, sess)
 779	if err != nil {
 780		return fmt.Errorf("failed to save session: %w", err)
 781	}
 782	return nil
 783}
 784
 785func (a *agent) Summarize(ctx context.Context, sessionID string) error {
 786	if a.summarizeProvider == nil {
 787		return fmt.Errorf("summarize provider not available")
 788	}
 789
 790	// Check if session is busy
 791	if a.IsSessionBusy(sessionID) {
 792		return ErrSessionBusy
 793	}
 794
 795	// Create a new context with cancellation
 796	summarizeCtx, cancel := context.WithCancel(ctx)
 797
 798	// Store the cancel function in activeRequests to allow cancellation
 799	a.activeRequests.Set(sessionID+"-summarize", cancel)
 800
 801	go func() {
 802		defer a.activeRequests.Del(sessionID + "-summarize")
 803		defer cancel()
 804		event := AgentEvent{
 805			Type:     AgentEventTypeSummarize,
 806			Progress: "Starting summarization...",
 807		}
 808
 809		a.Publish(pubsub.CreatedEvent, event)
 810		// Get all messages from the session
 811		msgs, err := a.messages.List(summarizeCtx, sessionID)
 812		if err != nil {
 813			event = AgentEvent{
 814				Type:  AgentEventTypeError,
 815				Error: fmt.Errorf("failed to list messages: %w", err),
 816				Done:  true,
 817			}
 818			a.Publish(pubsub.CreatedEvent, event)
 819			return
 820		}
 821		summarizeCtx = context.WithValue(summarizeCtx, tools.SessionIDContextKey, sessionID)
 822
 823		if len(msgs) == 0 {
 824			event = AgentEvent{
 825				Type:  AgentEventTypeError,
 826				Error: fmt.Errorf("no messages to summarize"),
 827				Done:  true,
 828			}
 829			a.Publish(pubsub.CreatedEvent, event)
 830			return
 831		}
 832
 833		event = AgentEvent{
 834			Type:     AgentEventTypeSummarize,
 835			Progress: "Analyzing conversation...",
 836		}
 837		a.Publish(pubsub.CreatedEvent, event)
 838
 839		// Add a system message to guide the summarization
 840		summarizePrompt := "Provide a detailed but concise summary of our conversation above. Focus on information that would be helpful for continuing the conversation, including what we did, what we're doing, which files we're working on, and what we're going to do next."
 841
 842		// Create a new message with the summarize prompt
 843		promptMsg := message.Message{
 844			Role:  message.User,
 845			Parts: []message.ContentPart{message.TextContent{Text: summarizePrompt}},
 846		}
 847
 848		// Append the prompt to the messages
 849		msgsWithPrompt := append(msgs, promptMsg)
 850
 851		event = AgentEvent{
 852			Type:     AgentEventTypeSummarize,
 853			Progress: "Generating summary...",
 854		}
 855
 856		a.Publish(pubsub.CreatedEvent, event)
 857
 858		// Send the messages to the summarize provider
 859		response := a.summarizeProvider.StreamResponse(
 860			summarizeCtx,
 861			msgsWithPrompt,
 862			nil,
 863		)
 864		var finalResponse *provider.ProviderResponse
 865		for r := range response {
 866			if r.Error != nil {
 867				event = AgentEvent{
 868					Type:  AgentEventTypeError,
 869					Error: fmt.Errorf("failed to summarize: %w", r.Error),
 870					Done:  true,
 871				}
 872				a.Publish(pubsub.CreatedEvent, event)
 873				return
 874			}
 875			finalResponse = r.Response
 876		}
 877
 878		summary := strings.TrimSpace(finalResponse.Content)
 879		if summary == "" {
 880			event = AgentEvent{
 881				Type:  AgentEventTypeError,
 882				Error: fmt.Errorf("empty summary returned"),
 883				Done:  true,
 884			}
 885			a.Publish(pubsub.CreatedEvent, event)
 886			return
 887		}
 888		shell := shell.GetPersistentShell(config.Get().WorkingDir())
 889		summary += "\n\n**Current working directory of the persistent shell**\n\n" + shell.GetWorkingDir()
 890		event = AgentEvent{
 891			Type:     AgentEventTypeSummarize,
 892			Progress: "Creating new session...",
 893		}
 894
 895		a.Publish(pubsub.CreatedEvent, event)
 896		oldSession, err := a.sessions.Get(summarizeCtx, sessionID)
 897		if err != nil {
 898			event = AgentEvent{
 899				Type:  AgentEventTypeError,
 900				Error: fmt.Errorf("failed to get session: %w", err),
 901				Done:  true,
 902			}
 903
 904			a.Publish(pubsub.CreatedEvent, event)
 905			return
 906		}
 907		// Create a message in the new session with the summary
 908		msg, err := a.messages.Create(summarizeCtx, oldSession.ID, message.CreateMessageParams{
 909			Role: message.Assistant,
 910			Parts: []message.ContentPart{
 911				message.TextContent{Text: summary},
 912				message.Finish{
 913					Reason: message.FinishReasonEndTurn,
 914					Time:   time.Now().Unix(),
 915				},
 916			},
 917			Model:    a.summarizeProvider.Model().ID,
 918			Provider: a.summarizeProviderID,
 919		})
 920		if err != nil {
 921			event = AgentEvent{
 922				Type:  AgentEventTypeError,
 923				Error: fmt.Errorf("failed to create summary message: %w", err),
 924				Done:  true,
 925			}
 926
 927			a.Publish(pubsub.CreatedEvent, event)
 928			return
 929		}
 930		oldSession.SummaryMessageID = msg.ID
 931		oldSession.CompletionTokens = finalResponse.Usage.OutputTokens
 932		oldSession.PromptTokens = 0
 933		model := a.summarizeProvider.Model()
 934		usage := finalResponse.Usage
 935		cost := model.CostPer1MInCached/1e6*float64(usage.CacheCreationTokens) +
 936			model.CostPer1MOutCached/1e6*float64(usage.CacheReadTokens) +
 937			model.CostPer1MIn/1e6*float64(usage.InputTokens) +
 938			model.CostPer1MOut/1e6*float64(usage.OutputTokens)
 939		oldSession.Cost += cost
 940		_, err = a.sessions.Save(summarizeCtx, oldSession)
 941		if err != nil {
 942			event = AgentEvent{
 943				Type:  AgentEventTypeError,
 944				Error: fmt.Errorf("failed to save session: %w", err),
 945				Done:  true,
 946			}
 947			a.Publish(pubsub.CreatedEvent, event)
 948		}
 949
 950		event = AgentEvent{
 951			Type:      AgentEventTypeSummarize,
 952			SessionID: oldSession.ID,
 953			Progress:  "Summary complete",
 954			Done:      true,
 955		}
 956		a.Publish(pubsub.CreatedEvent, event)
 957		// Send final success event with the new session ID
 958	}()
 959
 960	return nil
 961}
 962
 963func (a *agent) ClearQueue(sessionID string) {
 964	if a.QueuedPrompts(sessionID) > 0 {
 965		slog.Info("Clearing queued prompts", "session_id", sessionID)
 966		a.promptQueue.Del(sessionID)
 967	}
 968}
 969
 970func (a *agent) CancelAll() {
 971	if !a.IsBusy() {
 972		return
 973	}
 974	for key := range a.activeRequests.Seq2() {
 975		a.Cancel(key) // key is sessionID
 976	}
 977
 978	for _, cleanup := range a.cleanupFuncs {
 979		if cleanup != nil {
 980			cleanup()
 981		}
 982	}
 983
 984	timeout := time.After(5 * time.Second)
 985	for a.IsBusy() {
 986		select {
 987		case <-timeout:
 988			return
 989		default:
 990			time.Sleep(200 * time.Millisecond)
 991		}
 992	}
 993}
 994
 995func (a *agent) UpdateModel() error {
 996	cfg := config.Get()
 997
 998	// Get current provider configuration
 999	currentProviderCfg := cfg.GetProviderForModel(a.agentCfg.Model)
1000	if currentProviderCfg == nil || currentProviderCfg.ID == "" {
1001		return fmt.Errorf("provider for agent %s not found in config", a.agentCfg.Name)
1002	}
1003
1004	// Check if provider has changed
1005	if string(currentProviderCfg.ID) != a.providerID {
1006		// Provider changed, need to recreate the main provider
1007		model := cfg.GetModelByType(a.agentCfg.Model)
1008		if model.ID == "" {
1009			return fmt.Errorf("model not found for agent %s", a.agentCfg.Name)
1010		}
1011
1012		promptID := agentPromptMap[a.agentCfg.ID]
1013		if promptID == "" {
1014			promptID = prompt.PromptDefault
1015		}
1016
1017		opts := []provider.ProviderClientOption{
1018			provider.WithModel(a.agentCfg.Model),
1019			provider.WithSystemMessage(prompt.GetPrompt(promptID, currentProviderCfg.ID, cfg.Options.ContextPaths...)),
1020		}
1021
1022		newProvider, err := provider.NewProvider(*currentProviderCfg, opts...)
1023		if err != nil {
1024			return fmt.Errorf("failed to create new provider: %w", err)
1025		}
1026
1027		// Update the provider and provider ID
1028		a.provider = newProvider
1029		a.providerID = string(currentProviderCfg.ID)
1030	}
1031
1032	// Check if providers have changed for title (small) and summarize (large)
1033	smallModelCfg := cfg.Models[config.SelectedModelTypeSmall]
1034	var smallModelProviderCfg config.ProviderConfig
1035	for p := range cfg.Providers.Seq() {
1036		if p.ID == smallModelCfg.Provider {
1037			smallModelProviderCfg = p
1038			break
1039		}
1040	}
1041	if smallModelProviderCfg.ID == "" {
1042		return fmt.Errorf("provider %s not found in config", smallModelCfg.Provider)
1043	}
1044
1045	largeModelCfg := cfg.Models[config.SelectedModelTypeLarge]
1046	var largeModelProviderCfg config.ProviderConfig
1047	for p := range cfg.Providers.Seq() {
1048		if p.ID == largeModelCfg.Provider {
1049			largeModelProviderCfg = p
1050			break
1051		}
1052	}
1053	if largeModelProviderCfg.ID == "" {
1054		return fmt.Errorf("provider %s not found in config", largeModelCfg.Provider)
1055	}
1056
1057	var maxTitleTokens int64 = 40
1058
1059	// if the max output is too low for the gemini provider it won't return anything
1060	if smallModelCfg.Provider == "gemini" {
1061		maxTitleTokens = 1000
1062	}
1063	// Recreate title provider
1064	titleOpts := []provider.ProviderClientOption{
1065		provider.WithModel(config.SelectedModelTypeSmall),
1066		provider.WithSystemMessage(prompt.GetPrompt(prompt.PromptTitle, smallModelProviderCfg.ID)),
1067		provider.WithMaxTokens(maxTitleTokens),
1068	}
1069	newTitleProvider, err := provider.NewProvider(smallModelProviderCfg, titleOpts...)
1070	if err != nil {
1071		return fmt.Errorf("failed to create new title provider: %w", err)
1072	}
1073	a.titleProvider = newTitleProvider
1074
1075	// Recreate summarize provider if provider changed (now large model)
1076	if string(largeModelProviderCfg.ID) != a.summarizeProviderID {
1077		largeModel := cfg.GetModelByType(config.SelectedModelTypeLarge)
1078		if largeModel == nil {
1079			return fmt.Errorf("model %s not found in provider %s", largeModelCfg.Model, largeModelProviderCfg.ID)
1080		}
1081		summarizeOpts := []provider.ProviderClientOption{
1082			provider.WithModel(config.SelectedModelTypeLarge),
1083			provider.WithSystemMessage(prompt.GetPrompt(prompt.PromptSummarizer, largeModelProviderCfg.ID)),
1084		}
1085		newSummarizeProvider, err := provider.NewProvider(largeModelProviderCfg, summarizeOpts...)
1086		if err != nil {
1087			return fmt.Errorf("failed to create new summarize provider: %w", err)
1088		}
1089		a.summarizeProvider = newSummarizeProvider
1090		a.summarizeProviderID = string(largeModelProviderCfg.ID)
1091	}
1092
1093	return nil
1094}
1095
1096func (a *agent) setupEvents(ctx context.Context) {
1097	ctx, cancel := context.WithCancel(ctx)
1098
1099	go func() {
1100		subCh := SubscribeMCPEvents(ctx)
1101
1102		for {
1103			select {
1104			case event, ok := <-subCh:
1105				if !ok {
1106					slog.Debug("MCPEvents subscription channel closed")
1107					return
1108				}
1109				name := event.Payload.Name
1110				c, ok := mcpClients.Get(name)
1111				if !ok {
1112					slog.Warn("MCP client not found for tools/prompts update", "name", name)
1113					continue
1114				}
1115				switch event.Payload.Type {
1116				case MCPEventToolsListChanged:
1117					cfg := config.Get()
1118					tools, err := getTools(ctx, name, a.permissions, c, cfg.WorkingDir())
1119					if err != nil {
1120						slog.Error("error listing tools", "error", err)
1121						updateMCPState(name, MCPStateError, err, nil, MCPCounts{})
1122						_ = c.Close()
1123						continue
1124					}
1125					updateMcpTools(name, tools)
1126					a.mcpTools.Reset(maps.Collect(mcpTools.Seq2()))
1127				case MCPEventPromptsListChanged:
1128					prompts, err := getPrompts(ctx, c)
1129					if err != nil {
1130						slog.Error("error listing prompts", "error", err)
1131						updateMCPState(name, MCPStateError, err, nil, MCPCounts{})
1132						_ = c.Close()
1133						continue
1134					}
1135					updateMcpPrompts(name, prompts)
1136				default:
1137					continue
1138				}
1139				updateMCPState(name, MCPStateConnected, nil, c, MCPCounts{
1140					Tools:   mcpTools.Len(),
1141					Prompts: mcpPrompts.Len(),
1142				})
1143			case <-ctx.Done():
1144				slog.Debug("MCPEvents subscription cancelled")
1145				return
1146			}
1147		}
1148	}()
1149
1150	a.cleanupFuncs = append(a.cleanupFuncs, cancel)
1151}