1package agent
   2
   3import (
   4	"context"
   5	"errors"
   6	"fmt"
   7	"log/slog"
   8	"maps"
   9	"slices"
  10	"strings"
  11	"time"
  12
  13	"github.com/charmbracelet/catwalk/pkg/catwalk"
  14	"github.com/charmbracelet/crush/internal/config"
  15	"github.com/charmbracelet/crush/internal/csync"
  16	"github.com/charmbracelet/crush/internal/event"
  17	"github.com/charmbracelet/crush/internal/history"
  18	"github.com/charmbracelet/crush/internal/llm/prompt"
  19	"github.com/charmbracelet/crush/internal/llm/provider"
  20	"github.com/charmbracelet/crush/internal/llm/tools"
  21	"github.com/charmbracelet/crush/internal/log"
  22	"github.com/charmbracelet/crush/internal/lsp"
  23	"github.com/charmbracelet/crush/internal/message"
  24	"github.com/charmbracelet/crush/internal/permission"
  25	"github.com/charmbracelet/crush/internal/pubsub"
  26	"github.com/charmbracelet/crush/internal/session"
  27	"github.com/charmbracelet/crush/internal/shell"
  28)
  29
  30type AgentEventType string
  31
  32const (
  33	AgentEventTypeError     AgentEventType = "error"
  34	AgentEventTypeResponse  AgentEventType = "response"
  35	AgentEventTypeSummarize AgentEventType = "summarize"
  36)
  37
  38type AgentEvent struct {
  39	Type    AgentEventType
  40	Message message.Message
  41	Error   error
  42
  43	// When summarizing
  44	SessionID string
  45	Progress  string
  46	Done      bool
  47}
  48
  49type Service interface {
  50	pubsub.Suscriber[AgentEvent]
  51	Model() catwalk.Model
  52	Run(ctx context.Context, sessionID string, content string, attachments ...message.Attachment) (<-chan AgentEvent, error)
  53	Cancel(sessionID string)
  54	CancelAll()
  55	IsSessionBusy(sessionID string) bool
  56	IsBusy() bool
  57	Summarize(ctx context.Context, sessionID string) error
  58	UpdateModel() error
  59	QueuedPrompts(sessionID string) int
  60	ClearQueue(sessionID string)
  61}
  62
  63type agent struct {
  64	cleanupFuncs []func()
  65
  66	*pubsub.Broker[AgentEvent]
  67	agentCfg config.Agent
  68	sessions session.Service
  69	messages message.Service
  70
  71	permissions permission.Service
  72	baseTools   *csync.Map[string, tools.BaseTool]
  73	mcpTools    *csync.Map[string, tools.BaseTool]
  74	lspClients  *csync.Map[string, *lsp.Client]
  75
  76	// We need this to be able to update it when model changes
  77	agentToolFn func() (tools.BaseTool, error)
  78
  79	provider   provider.Provider
  80	providerID string
  81
  82	titleProvider       provider.Provider
  83	summarizeProvider   provider.Provider
  84	summarizeProviderID string
  85
  86	activeRequests *csync.Map[string, context.CancelFunc]
  87	promptQueue    *csync.Map[string, []string]
  88}
  89
  90var agentPromptMap = map[string]prompt.PromptID{
  91	"coder": prompt.PromptCoder,
  92	"task":  prompt.PromptTask,
  93}
  94
  95func NewAgent(
  96	ctx context.Context,
  97	agentCfg config.Agent,
  98	// These services are needed in the tools
  99	permissions permission.Service,
 100	sessions session.Service,
 101	messages message.Service,
 102	history history.Service,
 103	lspClients *csync.Map[string, *lsp.Client],
 104) (Service, error) {
 105	cfg := config.Get()
 106
 107	var agentToolFn func() (tools.BaseTool, error)
 108	if agentCfg.ID == "coder" && slices.Contains(agentCfg.AllowedTools, AgentToolName) {
 109		agentToolFn = func() (tools.BaseTool, error) {
 110			taskAgentCfg := config.Get().Agents["task"]
 111			if taskAgentCfg.ID == "" {
 112				return nil, fmt.Errorf("task agent not found in config")
 113			}
 114			taskAgent, err := NewAgent(ctx, taskAgentCfg, permissions, sessions, messages, history, lspClients)
 115			if err != nil {
 116				return nil, fmt.Errorf("failed to create task agent: %w", err)
 117			}
 118			return NewAgentTool(taskAgent, sessions, messages), nil
 119		}
 120	}
 121
 122	providerCfg := config.Get().GetProviderForModel(agentCfg.Model)
 123	if providerCfg == nil {
 124		return nil, fmt.Errorf("provider for agent %s not found in config", agentCfg.Name)
 125	}
 126	model := config.Get().GetModelByType(agentCfg.Model)
 127
 128	if model == nil {
 129		return nil, fmt.Errorf("model not found for agent %s", agentCfg.Name)
 130	}
 131
 132	promptID := agentPromptMap[agentCfg.ID]
 133	if promptID == "" {
 134		promptID = prompt.PromptDefault
 135	}
 136	opts := []provider.ProviderClientOption{
 137		provider.WithModel(agentCfg.Model),
 138		provider.WithSystemMessage(prompt.GetPrompt(promptID, providerCfg.ID, config.Get().Options.ContextPaths...)),
 139	}
 140	agentProvider, err := provider.NewProvider(*providerCfg, opts...)
 141	if err != nil {
 142		return nil, err
 143	}
 144
 145	smallModelCfg := cfg.Models[config.SelectedModelTypeSmall]
 146	var smallModelProviderCfg *config.ProviderConfig
 147	if smallModelCfg.Provider == providerCfg.ID {
 148		smallModelProviderCfg = providerCfg
 149	} else {
 150		smallModelProviderCfg = cfg.GetProviderForModel(config.SelectedModelTypeSmall)
 151
 152		if smallModelProviderCfg.ID == "" {
 153			return nil, fmt.Errorf("provider %s not found in config", smallModelCfg.Provider)
 154		}
 155	}
 156	smallModel := cfg.GetModelByType(config.SelectedModelTypeSmall)
 157	if smallModel.ID == "" {
 158		return nil, fmt.Errorf("model %s not found in provider %s", smallModelCfg.Model, smallModelProviderCfg.ID)
 159	}
 160
 161	titleOpts := []provider.ProviderClientOption{
 162		provider.WithModel(config.SelectedModelTypeSmall),
 163		provider.WithSystemMessage(prompt.GetPrompt(prompt.PromptTitle, smallModelProviderCfg.ID)),
 164	}
 165	titleProvider, err := provider.NewProvider(*smallModelProviderCfg, titleOpts...)
 166	if err != nil {
 167		return nil, err
 168	}
 169
 170	summarizeOpts := []provider.ProviderClientOption{
 171		provider.WithModel(config.SelectedModelTypeLarge),
 172		provider.WithSystemMessage(prompt.GetPrompt(prompt.PromptSummarizer, providerCfg.ID)),
 173	}
 174	summarizeProvider, err := provider.NewProvider(*providerCfg, summarizeOpts...)
 175	if err != nil {
 176		return nil, err
 177	}
 178
 179	baseToolsFn := func() map[string]tools.BaseTool {
 180		slog.Info("Initializing agent base tools", "agent", agentCfg.ID)
 181		defer func() {
 182			slog.Info("Initialized agent base tools", "agent", agentCfg.ID)
 183		}()
 184
 185		// Base tools available to all agents
 186		cwd := cfg.WorkingDir()
 187		result := make(map[string]tools.BaseTool)
 188		for _, tool := range []tools.BaseTool{
 189			tools.NewBashTool(permissions, cwd, cfg.Options.Attribution),
 190			tools.NewDownloadTool(permissions, cwd),
 191			tools.NewEditTool(lspClients, permissions, history, cwd),
 192			tools.NewMultiEditTool(lspClients, permissions, history, cwd),
 193			tools.NewFetchTool(permissions, cwd),
 194			tools.NewGlobTool(cwd),
 195			tools.NewGrepTool(cwd),
 196			tools.NewLsTool(permissions, cwd),
 197			tools.NewSourcegraphTool(),
 198			tools.NewViewTool(lspClients, permissions, cwd),
 199			tools.NewWriteTool(lspClients, permissions, history, cwd),
 200		} {
 201			result[tool.Name()] = tool
 202		}
 203		return result
 204	}
 205	mcpToolsFn := func() map[string]tools.BaseTool {
 206		slog.Info("Initializing agent mcp tools", "agent", agentCfg.ID)
 207		defer func() {
 208			slog.Info("Initialized agent mcp tools", "agent", agentCfg.ID)
 209		}()
 210
 211		mcpToolsOnce.Do(func() {
 212			doGetMCPTools(ctx, permissions, cfg)
 213		})
 214
 215		return maps.Collect(mcpTools.Seq2())
 216	}
 217
 218	a := &agent{
 219		Broker:              pubsub.NewBroker[AgentEvent](),
 220		agentCfg:            agentCfg,
 221		provider:            agentProvider,
 222		providerID:          string(providerCfg.ID),
 223		messages:            messages,
 224		sessions:            sessions,
 225		titleProvider:       titleProvider,
 226		summarizeProvider:   summarizeProvider,
 227		summarizeProviderID: string(providerCfg.ID),
 228		agentToolFn:         agentToolFn,
 229		activeRequests:      csync.NewMap[string, context.CancelFunc](),
 230		mcpTools:            csync.NewLazyMap(mcpToolsFn),
 231		baseTools:           csync.NewLazyMap(baseToolsFn),
 232		promptQueue:         csync.NewMap[string, []string](),
 233		permissions:         permissions,
 234		lspClients:          lspClients,
 235	}
 236	a.setupEvents(ctx)
 237	return a, nil
 238}
 239
 240func (a *agent) Model() catwalk.Model {
 241	return *config.Get().GetModelByType(a.agentCfg.Model)
 242}
 243
 244func (a *agent) Cancel(sessionID string) {
 245	// Cancel regular requests
 246	if cancel, ok := a.activeRequests.Take(sessionID); ok && cancel != nil {
 247		slog.Info("Request cancellation initiated", "session_id", sessionID)
 248		cancel()
 249	}
 250
 251	// Also check for summarize requests
 252	if cancel, ok := a.activeRequests.Take(sessionID + "-summarize"); ok && cancel != nil {
 253		slog.Info("Summarize cancellation initiated", "session_id", sessionID)
 254		cancel()
 255	}
 256
 257	if a.QueuedPrompts(sessionID) > 0 {
 258		slog.Info("Clearing queued prompts", "session_id", sessionID)
 259		a.promptQueue.Del(sessionID)
 260	}
 261}
 262
 263func (a *agent) IsBusy() bool {
 264	var busy bool
 265	for cancelFunc := range a.activeRequests.Seq() {
 266		if cancelFunc != nil {
 267			busy = true
 268			break
 269		}
 270	}
 271	return busy
 272}
 273
 274func (a *agent) IsSessionBusy(sessionID string) bool {
 275	_, busy := a.activeRequests.Get(sessionID)
 276	return busy
 277}
 278
 279func (a *agent) QueuedPrompts(sessionID string) int {
 280	l, ok := a.promptQueue.Get(sessionID)
 281	if !ok {
 282		return 0
 283	}
 284	return len(l)
 285}
 286
 287func (a *agent) generateTitle(ctx context.Context, sessionID string, content string) error {
 288	if content == "" {
 289		return nil
 290	}
 291	if a.titleProvider == nil {
 292		return nil
 293	}
 294	session, err := a.sessions.Get(ctx, sessionID)
 295	if err != nil {
 296		return err
 297	}
 298	parts := []message.ContentPart{message.TextContent{
 299		Text: fmt.Sprintf("Generate a concise title for the following content:\n\n%s", content),
 300	}}
 301
 302	// Use streaming approach like summarization
 303	response := a.titleProvider.StreamResponse(
 304		ctx,
 305		[]message.Message{
 306			{
 307				Role:  message.User,
 308				Parts: parts,
 309			},
 310		},
 311		nil,
 312	)
 313
 314	var finalResponse *provider.ProviderResponse
 315	for r := range response {
 316		if r.Error != nil {
 317			return r.Error
 318		}
 319		finalResponse = r.Response
 320	}
 321
 322	if finalResponse == nil {
 323		return fmt.Errorf("no response received from title provider")
 324	}
 325
 326	title := strings.ReplaceAll(finalResponse.Content, "\n", " ")
 327
 328	if idx := strings.Index(title, "</think>"); idx > 0 {
 329		title = title[idx+len("</think>"):]
 330	}
 331
 332	title = strings.TrimSpace(title)
 333	if title == "" {
 334		return nil
 335	}
 336
 337	session.Title = title
 338	_, err = a.sessions.Save(ctx, session)
 339	return err
 340}
 341
 342func (a *agent) err(err error) AgentEvent {
 343	return AgentEvent{
 344		Type:  AgentEventTypeError,
 345		Error: err,
 346	}
 347}
 348
 349func (a *agent) Run(ctx context.Context, sessionID string, content string, attachments ...message.Attachment) (<-chan AgentEvent, error) {
 350	if !a.Model().SupportsImages && attachments != nil {
 351		attachments = nil
 352	}
 353	events := make(chan AgentEvent, 1)
 354	if a.IsSessionBusy(sessionID) {
 355		existing, ok := a.promptQueue.Get(sessionID)
 356		if !ok {
 357			existing = []string{}
 358		}
 359		existing = append(existing, content)
 360		a.promptQueue.Set(sessionID, existing)
 361		return nil, nil
 362	}
 363
 364	genCtx, cancel := context.WithCancel(ctx)
 365	a.activeRequests.Set(sessionID, cancel)
 366	startTime := time.Now()
 367
 368	go func() {
 369		slog.Debug("Request started", "sessionID", sessionID)
 370		defer log.RecoverPanic("agent.Run", func() {
 371			events <- a.err(fmt.Errorf("panic while running the agent"))
 372		})
 373		var attachmentParts []message.ContentPart
 374		for _, attachment := range attachments {
 375			attachmentParts = append(attachmentParts, message.BinaryContent{Path: attachment.FilePath, MIMEType: attachment.MimeType, Data: attachment.Content})
 376		}
 377		result := a.processGeneration(genCtx, sessionID, content, attachmentParts)
 378		if result.Error != nil {
 379			if isCancelledErr(result.Error) {
 380				slog.Error("Request canceled", "sessionID", sessionID)
 381			} else {
 382				slog.Error("Request errored", "sessionID", sessionID, "error", result.Error.Error())
 383				event.Error(result.Error)
 384			}
 385		} else {
 386			slog.Debug("Request completed", "sessionID", sessionID)
 387		}
 388		a.eventPromptResponded(sessionID, time.Since(startTime).Truncate(time.Second))
 389		a.activeRequests.Del(sessionID)
 390		cancel()
 391		a.Publish(pubsub.CreatedEvent, result)
 392		events <- result
 393		close(events)
 394	}()
 395	a.eventPromptSent(sessionID)
 396	return events, nil
 397}
 398
 399func (a *agent) processGeneration(ctx context.Context, sessionID, content string, attachmentParts []message.ContentPart) AgentEvent {
 400	cfg := config.Get()
 401	// List existing messages; if none, start title generation asynchronously.
 402	msgs, err := a.messages.List(ctx, sessionID)
 403	if err != nil {
 404		return a.err(fmt.Errorf("failed to list messages: %w", err))
 405	}
 406	if len(msgs) == 0 {
 407		go func() {
 408			defer log.RecoverPanic("agent.Run", func() {
 409				slog.Error("panic while generating title")
 410			})
 411			titleErr := a.generateTitle(ctx, sessionID, content)
 412			if titleErr != nil && !errors.Is(titleErr, context.Canceled) && !errors.Is(titleErr, context.DeadlineExceeded) {
 413				slog.Error("failed to generate title", "error", titleErr)
 414			}
 415		}()
 416	}
 417	session, err := a.sessions.Get(ctx, sessionID)
 418	if err != nil {
 419		return a.err(fmt.Errorf("failed to get session: %w", err))
 420	}
 421	if session.SummaryMessageID != "" {
 422		summaryMsgInex := -1
 423		for i, msg := range msgs {
 424			if msg.ID == session.SummaryMessageID {
 425				summaryMsgInex = i
 426				break
 427			}
 428		}
 429		if summaryMsgInex != -1 {
 430			msgs = msgs[summaryMsgInex:]
 431			msgs[0].Role = message.User
 432		}
 433	}
 434
 435	userMsg, err := a.createUserMessage(ctx, sessionID, content, attachmentParts)
 436	if err != nil {
 437		return a.err(fmt.Errorf("failed to create user message: %w", err))
 438	}
 439	// Append the new user message to the conversation history.
 440	msgHistory := append(msgs, userMsg)
 441
 442	for {
 443		// Check for cancellation before each iteration
 444		select {
 445		case <-ctx.Done():
 446			return a.err(ctx.Err())
 447		default:
 448			// Continue processing
 449		}
 450		agentMessage, toolResults, err := a.streamAndHandleEvents(ctx, sessionID, msgHistory)
 451		if err != nil {
 452			if errors.Is(err, context.Canceled) {
 453				agentMessage.AddFinish(message.FinishReasonCanceled, "Request cancelled", "")
 454				a.messages.Update(context.Background(), agentMessage)
 455				return a.err(ErrRequestCancelled)
 456			}
 457			return a.err(fmt.Errorf("failed to process events: %w", err))
 458		}
 459		if cfg.Options.Debug {
 460			slog.Info("Result", "message", agentMessage.FinishReason(), "toolResults", toolResults)
 461		}
 462		if (agentMessage.FinishReason() == message.FinishReasonToolUse) && toolResults != nil {
 463			// We are not done, we need to respond with the tool response
 464			msgHistory = append(msgHistory, agentMessage, *toolResults)
 465			// If there are queued prompts, process the next one
 466			nextPrompt, ok := a.promptQueue.Take(sessionID)
 467			if ok {
 468				for _, prompt := range nextPrompt {
 469					// Create a new user message for the queued prompt
 470					userMsg, err := a.createUserMessage(ctx, sessionID, prompt, nil)
 471					if err != nil {
 472						return a.err(fmt.Errorf("failed to create user message for queued prompt: %w", err))
 473					}
 474					// Append the new user message to the conversation history
 475					msgHistory = append(msgHistory, userMsg)
 476				}
 477			}
 478
 479			continue
 480		} else if agentMessage.FinishReason() == message.FinishReasonEndTurn {
 481			queuePrompts, ok := a.promptQueue.Take(sessionID)
 482			if ok {
 483				for _, prompt := range queuePrompts {
 484					if prompt == "" {
 485						continue
 486					}
 487					userMsg, err := a.createUserMessage(ctx, sessionID, prompt, nil)
 488					if err != nil {
 489						return a.err(fmt.Errorf("failed to create user message for queued prompt: %w", err))
 490					}
 491					msgHistory = append(msgHistory, userMsg)
 492				}
 493				continue
 494			}
 495		}
 496		if agentMessage.FinishReason() == "" {
 497			// Kujtim: could not track down where this is happening but this means its cancelled
 498			agentMessage.AddFinish(message.FinishReasonCanceled, "Request cancelled", "")
 499			_ = a.messages.Update(context.Background(), agentMessage)
 500			return a.err(ErrRequestCancelled)
 501		}
 502		return AgentEvent{
 503			Type:    AgentEventTypeResponse,
 504			Message: agentMessage,
 505			Done:    true,
 506		}
 507	}
 508}
 509
 510func (a *agent) createUserMessage(ctx context.Context, sessionID, content string, attachmentParts []message.ContentPart) (message.Message, error) {
 511	parts := []message.ContentPart{message.TextContent{Text: content}}
 512	parts = append(parts, attachmentParts...)
 513	return a.messages.Create(ctx, sessionID, message.CreateMessageParams{
 514		Role:  message.User,
 515		Parts: parts,
 516	})
 517}
 518
 519func (a *agent) getAllTools() ([]tools.BaseTool, error) {
 520	allTools := slices.Collect(a.baseTools.Seq())
 521
 522	withCoderTools := func(t []tools.BaseTool) []tools.BaseTool {
 523		if a.agentCfg.ID == "coder" {
 524			t = append(t, slices.Collect(a.mcpTools.Seq())...)
 525			if a.lspClients.Len() > 0 {
 526				t = append(t, tools.NewDiagnosticsTool(a.lspClients))
 527			}
 528		}
 529		return t
 530	}
 531
 532	if a.agentCfg.AllowedTools == nil {
 533		allTools = withCoderTools(allTools)
 534	} else {
 535		var filteredTools []tools.BaseTool
 536		for _, tool := range allTools {
 537			if slices.Contains(a.agentCfg.AllowedTools, tool.Name()) {
 538				filteredTools = append(filteredTools, tool)
 539			}
 540		}
 541		allTools = withCoderTools(filteredTools)
 542	}
 543
 544	if a.agentToolFn != nil {
 545		agentTool, agentToolErr := a.agentToolFn()
 546		if agentToolErr != nil {
 547			return nil, agentToolErr
 548		}
 549		allTools = append(allTools, agentTool)
 550	}
 551	return allTools, nil
 552}
 553
 554func (a *agent) streamAndHandleEvents(ctx context.Context, sessionID string, msgHistory []message.Message) (message.Message, *message.Message, error) {
 555	ctx = context.WithValue(ctx, tools.SessionIDContextKey, sessionID)
 556
 557	// Create the assistant message first so the spinner shows immediately
 558	assistantMsg, err := a.messages.Create(ctx, sessionID, message.CreateMessageParams{
 559		Role:     message.Assistant,
 560		Parts:    []message.ContentPart{},
 561		Model:    a.Model().ID,
 562		Provider: a.providerID,
 563	})
 564	if err != nil {
 565		return assistantMsg, nil, fmt.Errorf("failed to create assistant message: %w", err)
 566	}
 567
 568	allTools, toolsErr := a.getAllTools()
 569	if toolsErr != nil {
 570		return assistantMsg, nil, toolsErr
 571	}
 572	// Now collect tools (which may block on MCP initialization)
 573	eventChan := a.provider.StreamResponse(ctx, msgHistory, allTools)
 574
 575	// Add the session and message ID into the context if needed by tools.
 576	ctx = context.WithValue(ctx, tools.MessageIDContextKey, assistantMsg.ID)
 577
 578loop:
 579	for {
 580		select {
 581		case event, ok := <-eventChan:
 582			if !ok {
 583				break loop
 584			}
 585			if processErr := a.processEvent(ctx, sessionID, &assistantMsg, event); processErr != nil {
 586				if errors.Is(processErr, context.Canceled) {
 587					a.finishMessage(context.Background(), &assistantMsg, message.FinishReasonCanceled, "Request cancelled", "")
 588				} else {
 589					a.finishMessage(ctx, &assistantMsg, message.FinishReasonError, "API Error", processErr.Error())
 590				}
 591				return assistantMsg, nil, processErr
 592			}
 593		case <-ctx.Done():
 594			a.finishMessage(context.Background(), &assistantMsg, message.FinishReasonCanceled, "Request cancelled", "")
 595			return assistantMsg, nil, ctx.Err()
 596		}
 597	}
 598
 599	toolResults := make([]message.ToolResult, len(assistantMsg.ToolCalls()))
 600	toolCalls := assistantMsg.ToolCalls()
 601	for i, toolCall := range toolCalls {
 602		select {
 603		case <-ctx.Done():
 604			a.finishMessage(context.Background(), &assistantMsg, message.FinishReasonCanceled, "Request cancelled", "")
 605			// Make all future tool calls cancelled
 606			for j := i; j < len(toolCalls); j++ {
 607				toolResults[j] = message.ToolResult{
 608					ToolCallID: toolCalls[j].ID,
 609					Content:    "Tool execution canceled by user",
 610					IsError:    true,
 611				}
 612			}
 613			goto out
 614		default:
 615			// Continue processing
 616			var tool tools.BaseTool
 617			allTools, _ = a.getAllTools()
 618			for _, availableTool := range allTools {
 619				if availableTool.Info().Name == toolCall.Name {
 620					tool = availableTool
 621					break
 622				}
 623			}
 624
 625			// Tool not found
 626			if tool == nil {
 627				toolResults[i] = message.ToolResult{
 628					ToolCallID: toolCall.ID,
 629					Content:    fmt.Sprintf("Tool not found: %s", toolCall.Name),
 630					IsError:    true,
 631				}
 632				continue
 633			}
 634
 635			// Run tool in goroutine to allow cancellation
 636			type toolExecResult struct {
 637				response tools.ToolResponse
 638				err      error
 639			}
 640			resultChan := make(chan toolExecResult, 1)
 641
 642			go func() {
 643				response, err := tool.Run(ctx, tools.ToolCall{
 644					ID:    toolCall.ID,
 645					Name:  toolCall.Name,
 646					Input: toolCall.Input,
 647				})
 648				resultChan <- toolExecResult{response: response, err: err}
 649			}()
 650
 651			var toolResponse tools.ToolResponse
 652			var toolErr error
 653
 654			select {
 655			case <-ctx.Done():
 656				a.finishMessage(context.Background(), &assistantMsg, message.FinishReasonCanceled, "Request cancelled", "")
 657				// Mark remaining tool calls as cancelled
 658				for j := i; j < len(toolCalls); j++ {
 659					toolResults[j] = message.ToolResult{
 660						ToolCallID: toolCalls[j].ID,
 661						Content:    "Tool execution canceled by user",
 662						IsError:    true,
 663					}
 664				}
 665				goto out
 666			case result := <-resultChan:
 667				toolResponse = result.response
 668				toolErr = result.err
 669			}
 670
 671			if toolErr != nil {
 672				slog.Error("Tool execution error", "toolCall", toolCall.ID, "error", toolErr)
 673				if errors.Is(toolErr, permission.ErrorPermissionDenied) {
 674					toolResults[i] = message.ToolResult{
 675						ToolCallID: toolCall.ID,
 676						Content:    "Permission denied",
 677						IsError:    true,
 678					}
 679					for j := i + 1; j < len(toolCalls); j++ {
 680						toolResults[j] = message.ToolResult{
 681							ToolCallID: toolCalls[j].ID,
 682							Content:    "Tool execution canceled by user",
 683							IsError:    true,
 684						}
 685					}
 686					a.finishMessage(ctx, &assistantMsg, message.FinishReasonPermissionDenied, "Permission denied", "")
 687					break
 688				}
 689			}
 690			toolResults[i] = message.ToolResult{
 691				ToolCallID: toolCall.ID,
 692				Content:    toolResponse.Content,
 693				Metadata:   toolResponse.Metadata,
 694				IsError:    toolResponse.IsError,
 695			}
 696		}
 697	}
 698out:
 699	if len(toolResults) == 0 {
 700		return assistantMsg, nil, nil
 701	}
 702	parts := make([]message.ContentPart, 0)
 703	for _, tr := range toolResults {
 704		parts = append(parts, tr)
 705	}
 706	msg, err := a.messages.Create(context.Background(), assistantMsg.SessionID, message.CreateMessageParams{
 707		Role:     message.Tool,
 708		Parts:    parts,
 709		Provider: a.providerID,
 710	})
 711	if err != nil {
 712		return assistantMsg, nil, fmt.Errorf("failed to create cancelled tool message: %w", err)
 713	}
 714
 715	return assistantMsg, &msg, err
 716}
 717
 718func (a *agent) finishMessage(ctx context.Context, msg *message.Message, finishReason message.FinishReason, message, details string) {
 719	msg.AddFinish(finishReason, message, details)
 720	_ = a.messages.Update(ctx, *msg)
 721}
 722
 723func (a *agent) processEvent(ctx context.Context, sessionID string, assistantMsg *message.Message, event provider.ProviderEvent) error {
 724	select {
 725	case <-ctx.Done():
 726		return ctx.Err()
 727	default:
 728		// Continue processing.
 729	}
 730
 731	switch event.Type {
 732	case provider.EventThinkingDelta:
 733		assistantMsg.AppendReasoningContent(event.Thinking)
 734		return a.messages.Update(ctx, *assistantMsg)
 735	case provider.EventSignatureDelta:
 736		assistantMsg.AppendReasoningSignature(event.Signature)
 737		return a.messages.Update(ctx, *assistantMsg)
 738	case provider.EventContentDelta:
 739		assistantMsg.FinishThinking()
 740		assistantMsg.AppendContent(event.Content)
 741		return a.messages.Update(ctx, *assistantMsg)
 742	case provider.EventToolUseStart:
 743		assistantMsg.FinishThinking()
 744		slog.Info("Tool call started", "toolCall", event.ToolCall)
 745		assistantMsg.AddToolCall(*event.ToolCall)
 746		return a.messages.Update(ctx, *assistantMsg)
 747	case provider.EventToolUseDelta:
 748		assistantMsg.AppendToolCallInput(event.ToolCall.ID, event.ToolCall.Input)
 749		return a.messages.Update(ctx, *assistantMsg)
 750	case provider.EventToolUseStop:
 751		slog.Info("Finished tool call", "toolCall", event.ToolCall)
 752		assistantMsg.FinishToolCall(event.ToolCall.ID)
 753		return a.messages.Update(ctx, *assistantMsg)
 754	case provider.EventError:
 755		return event.Error
 756	case provider.EventComplete:
 757		assistantMsg.FinishThinking()
 758		assistantMsg.SetToolCalls(event.Response.ToolCalls)
 759		assistantMsg.AddFinish(event.Response.FinishReason, "", "")
 760		if err := a.messages.Update(ctx, *assistantMsg); err != nil {
 761			return fmt.Errorf("failed to update message: %w", err)
 762		}
 763		return a.trackUsage(ctx, sessionID, a.Model(), event.Response.Usage)
 764	}
 765
 766	return nil
 767}
 768
 769func (a *agent) trackUsage(ctx context.Context, sessionID string, model catwalk.Model, usage provider.TokenUsage) error {
 770	sess, err := a.sessions.Get(ctx, sessionID)
 771	if err != nil {
 772		return fmt.Errorf("failed to get session: %w", err)
 773	}
 774
 775	cost := model.CostPer1MInCached/1e6*float64(usage.CacheCreationTokens) +
 776		model.CostPer1MOutCached/1e6*float64(usage.CacheReadTokens) +
 777		model.CostPer1MIn/1e6*float64(usage.InputTokens) +
 778		model.CostPer1MOut/1e6*float64(usage.OutputTokens)
 779
 780	a.eventTokensUsed(sessionID, usage, cost)
 781
 782	sess.Cost += cost
 783	sess.CompletionTokens = usage.OutputTokens + usage.CacheReadTokens
 784	sess.PromptTokens = usage.InputTokens + usage.CacheCreationTokens
 785
 786	_, err = a.sessions.Save(ctx, sess)
 787	if err != nil {
 788		return fmt.Errorf("failed to save session: %w", err)
 789	}
 790	return nil
 791}
 792
 793func (a *agent) Summarize(ctx context.Context, sessionID string) error {
 794	if a.summarizeProvider == nil {
 795		return fmt.Errorf("summarize provider not available")
 796	}
 797
 798	// Check if session is busy
 799	if a.IsSessionBusy(sessionID) {
 800		return ErrSessionBusy
 801	}
 802
 803	// Create a new context with cancellation
 804	summarizeCtx, cancel := context.WithCancel(ctx)
 805
 806	// Store the cancel function in activeRequests to allow cancellation
 807	a.activeRequests.Set(sessionID+"-summarize", cancel)
 808
 809	go func() {
 810		defer a.activeRequests.Del(sessionID + "-summarize")
 811		defer cancel()
 812		event := AgentEvent{
 813			Type:     AgentEventTypeSummarize,
 814			Progress: "Starting summarization...",
 815		}
 816
 817		a.Publish(pubsub.CreatedEvent, event)
 818		// Get all messages from the session
 819		msgs, err := a.messages.List(summarizeCtx, sessionID)
 820		if err != nil {
 821			event = AgentEvent{
 822				Type:  AgentEventTypeError,
 823				Error: fmt.Errorf("failed to list messages: %w", err),
 824				Done:  true,
 825			}
 826			a.Publish(pubsub.CreatedEvent, event)
 827			return
 828		}
 829		summarizeCtx = context.WithValue(summarizeCtx, tools.SessionIDContextKey, sessionID)
 830
 831		if len(msgs) == 0 {
 832			event = AgentEvent{
 833				Type:  AgentEventTypeError,
 834				Error: fmt.Errorf("no messages to summarize"),
 835				Done:  true,
 836			}
 837			a.Publish(pubsub.CreatedEvent, event)
 838			return
 839		}
 840
 841		event = AgentEvent{
 842			Type:     AgentEventTypeSummarize,
 843			Progress: "Analyzing conversation...",
 844		}
 845		a.Publish(pubsub.CreatedEvent, event)
 846
 847		// Add a system message to guide the summarization
 848		summarizePrompt := "Provide a detailed but concise summary of our conversation above. Focus on information that would be helpful for continuing the conversation, including what we did, what we're doing, which files we're working on, and what we're going to do next."
 849
 850		// Create a new message with the summarize prompt
 851		promptMsg := message.Message{
 852			Role:  message.User,
 853			Parts: []message.ContentPart{message.TextContent{Text: summarizePrompt}},
 854		}
 855
 856		// Append the prompt to the messages
 857		msgsWithPrompt := append(msgs, promptMsg)
 858
 859		event = AgentEvent{
 860			Type:     AgentEventTypeSummarize,
 861			Progress: "Generating summary...",
 862		}
 863
 864		a.Publish(pubsub.CreatedEvent, event)
 865
 866		// Send the messages to the summarize provider
 867		response := a.summarizeProvider.StreamResponse(
 868			summarizeCtx,
 869			msgsWithPrompt,
 870			nil,
 871		)
 872		var finalResponse *provider.ProviderResponse
 873		for r := range response {
 874			if r.Error != nil {
 875				event = AgentEvent{
 876					Type:  AgentEventTypeError,
 877					Error: fmt.Errorf("failed to summarize: %w", r.Error),
 878					Done:  true,
 879				}
 880				a.Publish(pubsub.CreatedEvent, event)
 881				return
 882			}
 883			finalResponse = r.Response
 884		}
 885
 886		summary := strings.TrimSpace(finalResponse.Content)
 887		if summary == "" {
 888			event = AgentEvent{
 889				Type:  AgentEventTypeError,
 890				Error: fmt.Errorf("empty summary returned"),
 891				Done:  true,
 892			}
 893			a.Publish(pubsub.CreatedEvent, event)
 894			return
 895		}
 896		shell := shell.GetPersistentShell(config.Get().WorkingDir())
 897		summary += "\n\n**Current working directory of the persistent shell**\n\n" + shell.GetWorkingDir()
 898		event = AgentEvent{
 899			Type:     AgentEventTypeSummarize,
 900			Progress: "Creating new session...",
 901		}
 902
 903		a.Publish(pubsub.CreatedEvent, event)
 904		oldSession, err := a.sessions.Get(summarizeCtx, sessionID)
 905		if err != nil {
 906			event = AgentEvent{
 907				Type:  AgentEventTypeError,
 908				Error: fmt.Errorf("failed to get session: %w", err),
 909				Done:  true,
 910			}
 911
 912			a.Publish(pubsub.CreatedEvent, event)
 913			return
 914		}
 915		// Create a message in the new session with the summary
 916		msg, err := a.messages.Create(summarizeCtx, oldSession.ID, message.CreateMessageParams{
 917			Role: message.Assistant,
 918			Parts: []message.ContentPart{
 919				message.TextContent{Text: summary},
 920				message.Finish{
 921					Reason: message.FinishReasonEndTurn,
 922					Time:   time.Now().Unix(),
 923				},
 924			},
 925			Model:    a.summarizeProvider.Model().ID,
 926			Provider: a.summarizeProviderID,
 927		})
 928		if err != nil {
 929			event = AgentEvent{
 930				Type:  AgentEventTypeError,
 931				Error: fmt.Errorf("failed to create summary message: %w", err),
 932				Done:  true,
 933			}
 934
 935			a.Publish(pubsub.CreatedEvent, event)
 936			return
 937		}
 938		oldSession.SummaryMessageID = msg.ID
 939		oldSession.CompletionTokens = finalResponse.Usage.OutputTokens
 940		oldSession.PromptTokens = 0
 941		model := a.summarizeProvider.Model()
 942		usage := finalResponse.Usage
 943		cost := model.CostPer1MInCached/1e6*float64(usage.CacheCreationTokens) +
 944			model.CostPer1MOutCached/1e6*float64(usage.CacheReadTokens) +
 945			model.CostPer1MIn/1e6*float64(usage.InputTokens) +
 946			model.CostPer1MOut/1e6*float64(usage.OutputTokens)
 947		oldSession.Cost += cost
 948		_, err = a.sessions.Save(summarizeCtx, oldSession)
 949		if err != nil {
 950			event = AgentEvent{
 951				Type:  AgentEventTypeError,
 952				Error: fmt.Errorf("failed to save session: %w", err),
 953				Done:  true,
 954			}
 955			a.Publish(pubsub.CreatedEvent, event)
 956		}
 957
 958		event = AgentEvent{
 959			Type:      AgentEventTypeSummarize,
 960			SessionID: oldSession.ID,
 961			Progress:  "Summary complete",
 962			Done:      true,
 963		}
 964		a.Publish(pubsub.CreatedEvent, event)
 965		// Send final success event with the new session ID
 966	}()
 967
 968	return nil
 969}
 970
 971func (a *agent) ClearQueue(sessionID string) {
 972	if a.QueuedPrompts(sessionID) > 0 {
 973		slog.Info("Clearing queued prompts", "session_id", sessionID)
 974		a.promptQueue.Del(sessionID)
 975	}
 976}
 977
 978func (a *agent) CancelAll() {
 979	if !a.IsBusy() {
 980		return
 981	}
 982	for key := range a.activeRequests.Seq2() {
 983		a.Cancel(key) // key is sessionID
 984	}
 985
 986	for _, cleanup := range a.cleanupFuncs {
 987		if cleanup != nil {
 988			cleanup()
 989		}
 990	}
 991
 992	timeout := time.After(5 * time.Second)
 993	for a.IsBusy() {
 994		select {
 995		case <-timeout:
 996			return
 997		default:
 998			time.Sleep(200 * time.Millisecond)
 999		}
1000	}
1001}
1002
1003func (a *agent) UpdateModel() error {
1004	cfg := config.Get()
1005
1006	// Get current provider configuration
1007	currentProviderCfg := cfg.GetProviderForModel(a.agentCfg.Model)
1008	if currentProviderCfg == nil || currentProviderCfg.ID == "" {
1009		return fmt.Errorf("provider for agent %s not found in config", a.agentCfg.Name)
1010	}
1011
1012	// Check if provider has changed
1013	if string(currentProviderCfg.ID) != a.providerID {
1014		// Provider changed, need to recreate the main provider
1015		model := cfg.GetModelByType(a.agentCfg.Model)
1016		if model.ID == "" {
1017			return fmt.Errorf("model not found for agent %s", a.agentCfg.Name)
1018		}
1019
1020		promptID := agentPromptMap[a.agentCfg.ID]
1021		if promptID == "" {
1022			promptID = prompt.PromptDefault
1023		}
1024
1025		opts := []provider.ProviderClientOption{
1026			provider.WithModel(a.agentCfg.Model),
1027			provider.WithSystemMessage(prompt.GetPrompt(promptID, currentProviderCfg.ID, cfg.Options.ContextPaths...)),
1028		}
1029
1030		newProvider, err := provider.NewProvider(*currentProviderCfg, opts...)
1031		if err != nil {
1032			return fmt.Errorf("failed to create new provider: %w", err)
1033		}
1034
1035		// Update the provider and provider ID
1036		a.provider = newProvider
1037		a.providerID = string(currentProviderCfg.ID)
1038	}
1039
1040	// Check if providers have changed for title (small) and summarize (large)
1041	smallModelCfg := cfg.Models[config.SelectedModelTypeSmall]
1042	var smallModelProviderCfg config.ProviderConfig
1043	for p := range cfg.Providers.Seq() {
1044		if p.ID == smallModelCfg.Provider {
1045			smallModelProviderCfg = p
1046			break
1047		}
1048	}
1049	if smallModelProviderCfg.ID == "" {
1050		return fmt.Errorf("provider %s not found in config", smallModelCfg.Provider)
1051	}
1052
1053	largeModelCfg := cfg.Models[config.SelectedModelTypeLarge]
1054	var largeModelProviderCfg config.ProviderConfig
1055	for p := range cfg.Providers.Seq() {
1056		if p.ID == largeModelCfg.Provider {
1057			largeModelProviderCfg = p
1058			break
1059		}
1060	}
1061	if largeModelProviderCfg.ID == "" {
1062		return fmt.Errorf("provider %s not found in config", largeModelCfg.Provider)
1063	}
1064
1065	var maxTitleTokens int64 = 40
1066
1067	// if the max output is too low for the gemini provider it won't return anything
1068	if smallModelCfg.Provider == "gemini" {
1069		maxTitleTokens = 1000
1070	}
1071	// Recreate title provider
1072	titleOpts := []provider.ProviderClientOption{
1073		provider.WithModel(config.SelectedModelTypeSmall),
1074		provider.WithSystemMessage(prompt.GetPrompt(prompt.PromptTitle, smallModelProviderCfg.ID)),
1075		provider.WithMaxTokens(maxTitleTokens),
1076	}
1077	newTitleProvider, err := provider.NewProvider(smallModelProviderCfg, titleOpts...)
1078	if err != nil {
1079		return fmt.Errorf("failed to create new title provider: %w", err)
1080	}
1081	a.titleProvider = newTitleProvider
1082
1083	// Recreate summarize provider if provider changed (now large model)
1084	if string(largeModelProviderCfg.ID) != a.summarizeProviderID {
1085		largeModel := cfg.GetModelByType(config.SelectedModelTypeLarge)
1086		if largeModel == nil {
1087			return fmt.Errorf("model %s not found in provider %s", largeModelCfg.Model, largeModelProviderCfg.ID)
1088		}
1089		summarizeOpts := []provider.ProviderClientOption{
1090			provider.WithModel(config.SelectedModelTypeLarge),
1091			provider.WithSystemMessage(prompt.GetPrompt(prompt.PromptSummarizer, largeModelProviderCfg.ID)),
1092		}
1093		newSummarizeProvider, err := provider.NewProvider(largeModelProviderCfg, summarizeOpts...)
1094		if err != nil {
1095			return fmt.Errorf("failed to create new summarize provider: %w", err)
1096		}
1097		a.summarizeProvider = newSummarizeProvider
1098		a.summarizeProviderID = string(largeModelProviderCfg.ID)
1099	}
1100
1101	return nil
1102}
1103
1104func (a *agent) setupEvents(ctx context.Context) {
1105	ctx, cancel := context.WithCancel(ctx)
1106
1107	go func() {
1108		subCh := SubscribeMCPEvents(ctx)
1109
1110		for {
1111			select {
1112			case event, ok := <-subCh:
1113				if !ok {
1114					slog.Debug("MCPEvents subscription channel closed")
1115					return
1116				}
1117				switch event.Payload.Type {
1118				case MCPEventToolsListChanged:
1119					name := event.Payload.Name
1120					c, ok := mcpClients.Get(name)
1121					if !ok {
1122						slog.Warn("MCP client not found for tools update", "name", name)
1123						continue
1124					}
1125					cfg := config.Get()
1126					tools, err := getTools(ctx, name, a.permissions, c, cfg.WorkingDir())
1127					if err != nil {
1128						slog.Error("error listing tools", "error", err)
1129						updateMCPState(name, MCPStateError, err, nil, 0)
1130						_ = c.Close()
1131						continue
1132					}
1133					updateMcpTools(name, tools)
1134					// Update the lazy map with the new tools
1135					a.mcpTools = csync.NewMapFrom(maps.Collect(mcpTools.Seq2()))
1136					updateMCPState(name, MCPStateConnected, nil, c, a.mcpTools.Len())
1137				default:
1138					continue
1139				}
1140			case <-ctx.Done():
1141				slog.Debug("MCPEvents subscription cancelled")
1142				return
1143			}
1144		}
1145	}()
1146
1147	a.cleanupFuncs = append(a.cleanupFuncs, cancel)
1148}