agent.go

   1package agent
   2
   3import (
   4	"context"
   5	"errors"
   6	"fmt"
   7	"log/slog"
   8	"maps"
   9	"slices"
  10	"strings"
  11	"time"
  12
  13	"github.com/charmbracelet/catwalk/pkg/catwalk"
  14	"github.com/charmbracelet/crush/internal/config"
  15	"github.com/charmbracelet/crush/internal/csync"
  16	"github.com/charmbracelet/crush/internal/event"
  17	"github.com/charmbracelet/crush/internal/history"
  18	"github.com/charmbracelet/crush/internal/llm/prompt"
  19	"github.com/charmbracelet/crush/internal/llm/provider"
  20	"github.com/charmbracelet/crush/internal/llm/tools"
  21	"github.com/charmbracelet/crush/internal/log"
  22	"github.com/charmbracelet/crush/internal/lsp"
  23	"github.com/charmbracelet/crush/internal/message"
  24	"github.com/charmbracelet/crush/internal/permission"
  25	"github.com/charmbracelet/crush/internal/pubsub"
  26	"github.com/charmbracelet/crush/internal/session"
  27	"github.com/charmbracelet/crush/internal/shell"
  28)
  29
  30type AgentEventType string
  31
  32const (
  33	AgentEventTypeError     AgentEventType = "error"
  34	AgentEventTypeResponse  AgentEventType = "response"
  35	AgentEventTypeSummarize AgentEventType = "summarize"
  36)
  37
  38type AgentEvent struct {
  39	Type    AgentEventType
  40	Message message.Message
  41	Error   error
  42
  43	// When summarizing
  44	SessionID string
  45	Progress  string
  46	Done      bool
  47}
  48
  49type Service interface {
  50	pubsub.Suscriber[AgentEvent]
  51	Model() catwalk.Model
  52	Run(ctx context.Context, sessionID string, content string, attachments ...message.Attachment) (<-chan AgentEvent, error)
  53	Cancel(sessionID string)
  54	CancelAll()
  55	IsSessionBusy(sessionID string) bool
  56	IsBusy() bool
  57	Summarize(ctx context.Context, sessionID string) error
  58	UpdateModel() error
  59	QueuedPrompts(sessionID string) int
  60	ClearQueue(sessionID string)
  61}
  62
  63type agent struct {
  64	*pubsub.Broker[AgentEvent]
  65	agentCfg    config.Agent
  66	sessions    session.Service
  67	messages    message.Service
  68	permissions permission.Service
  69	baseTools   *csync.Map[string, tools.BaseTool]
  70	mcpTools    *csync.Map[string, tools.BaseTool]
  71	lspClients  *csync.Map[string, *lsp.Client]
  72
  73	// We need this to be able to update it when model changes
  74	agentToolFn  func() (tools.BaseTool, error)
  75	cleanupFuncs []func()
  76
  77	provider   provider.Provider
  78	providerID string
  79
  80	titleProvider       provider.Provider
  81	summarizeProvider   provider.Provider
  82	summarizeProviderID string
  83
  84	activeRequests *csync.Map[string, context.CancelFunc]
  85	promptQueue    *csync.Map[string, []string]
  86}
  87
  88var agentPromptMap = map[string]prompt.PromptID{
  89	"coder": prompt.PromptCoder,
  90	"task":  prompt.PromptTask,
  91}
  92
  93func NewAgent(
  94	ctx context.Context,
  95	agentCfg config.Agent,
  96	// These services are needed in the tools
  97	permissions permission.Service,
  98	sessions session.Service,
  99	messages message.Service,
 100	history history.Service,
 101	lspClients *csync.Map[string, *lsp.Client],
 102) (Service, error) {
 103	cfg := config.Get()
 104
 105	var agentToolFn func() (tools.BaseTool, error)
 106	if agentCfg.ID == "coder" && slices.Contains(agentCfg.AllowedTools, AgentToolName) {
 107		agentToolFn = func() (tools.BaseTool, error) {
 108			taskAgentCfg := config.Get().Agents["task"]
 109			if taskAgentCfg.ID == "" {
 110				return nil, fmt.Errorf("task agent not found in config")
 111			}
 112			taskAgent, err := NewAgent(ctx, taskAgentCfg, permissions, sessions, messages, history, lspClients)
 113			if err != nil {
 114				return nil, fmt.Errorf("failed to create task agent: %w", err)
 115			}
 116			return NewAgentTool(taskAgent, sessions, messages), nil
 117		}
 118	}
 119
 120	providerCfg := config.Get().GetProviderForModel(agentCfg.Model)
 121	if providerCfg == nil {
 122		return nil, fmt.Errorf("provider for agent %s not found in config", agentCfg.Name)
 123	}
 124	model := config.Get().GetModelByType(agentCfg.Model)
 125
 126	if model == nil {
 127		return nil, fmt.Errorf("model not found for agent %s", agentCfg.Name)
 128	}
 129
 130	promptID := agentPromptMap[agentCfg.ID]
 131	if promptID == "" {
 132		promptID = prompt.PromptDefault
 133	}
 134	opts := []provider.ProviderClientOption{
 135		provider.WithModel(agentCfg.Model),
 136		provider.WithSystemMessage(prompt.GetPrompt(promptID, providerCfg.ID, config.Get().Options.ContextPaths...)),
 137	}
 138	agentProvider, err := provider.NewProvider(*providerCfg, opts...)
 139	if err != nil {
 140		return nil, err
 141	}
 142
 143	smallModelCfg := cfg.Models[config.SelectedModelTypeSmall]
 144	var smallModelProviderCfg *config.ProviderConfig
 145	if smallModelCfg.Provider == providerCfg.ID {
 146		smallModelProviderCfg = providerCfg
 147	} else {
 148		smallModelProviderCfg = cfg.GetProviderForModel(config.SelectedModelTypeSmall)
 149
 150		if smallModelProviderCfg.ID == "" {
 151			return nil, fmt.Errorf("provider %s not found in config", smallModelCfg.Provider)
 152		}
 153	}
 154	smallModel := cfg.GetModelByType(config.SelectedModelTypeSmall)
 155	if smallModel.ID == "" {
 156		return nil, fmt.Errorf("model %s not found in provider %s", smallModelCfg.Model, smallModelProviderCfg.ID)
 157	}
 158
 159	titleOpts := []provider.ProviderClientOption{
 160		provider.WithModel(config.SelectedModelTypeSmall),
 161		provider.WithSystemMessage(prompt.GetPrompt(prompt.PromptTitle, smallModelProviderCfg.ID)),
 162	}
 163	titleProvider, err := provider.NewProvider(*smallModelProviderCfg, titleOpts...)
 164	if err != nil {
 165		return nil, err
 166	}
 167
 168	summarizeOpts := []provider.ProviderClientOption{
 169		provider.WithModel(config.SelectedModelTypeLarge),
 170		provider.WithSystemMessage(prompt.GetPrompt(prompt.PromptSummarizer, providerCfg.ID)),
 171	}
 172	summarizeProvider, err := provider.NewProvider(*providerCfg, summarizeOpts...)
 173	if err != nil {
 174		return nil, err
 175	}
 176
 177	baseToolsFn := func() map[string]tools.BaseTool {
 178		slog.Info("Initializing agent base tools", "agent", agentCfg.ID)
 179		defer func() {
 180			slog.Info("Initialized agent base tools", "agent", agentCfg.ID)
 181		}()
 182
 183		// Base tools available to all agents
 184		cwd := cfg.WorkingDir()
 185		result := make(map[string]tools.BaseTool)
 186		for _, tool := range []tools.BaseTool{
 187			tools.NewBashTool(permissions, cwd, cfg.Options.Attribution),
 188			tools.NewDownloadTool(permissions, cwd),
 189			tools.NewEditTool(lspClients, permissions, history, cwd),
 190			tools.NewMultiEditTool(lspClients, permissions, history, cwd),
 191			tools.NewFetchTool(permissions, cwd),
 192			tools.NewGlobTool(cwd),
 193			tools.NewGrepTool(cwd),
 194			tools.NewLsTool(permissions, cwd),
 195			tools.NewSourcegraphTool(),
 196			tools.NewViewTool(lspClients, permissions, cwd),
 197			tools.NewWriteTool(lspClients, permissions, history, cwd),
 198		} {
 199			result[tool.Name()] = tool
 200		}
 201		return result
 202	}
 203	mcpToolsFn := func() map[string]tools.BaseTool {
 204		slog.Info("Initializing agent mcp tools", "agent", agentCfg.ID)
 205		defer func() {
 206			slog.Info("Initialized agent mcp tools", "agent", agentCfg.ID)
 207		}()
 208
 209		mcpToolsOnce.Do(func() {
 210			doGetMCPTools(ctx, permissions, cfg)
 211		})
 212
 213		return maps.Collect(mcpTools.Seq2())
 214	}
 215
 216	a := &agent{
 217		Broker:              pubsub.NewBroker[AgentEvent](),
 218		agentCfg:            agentCfg,
 219		provider:            agentProvider,
 220		providerID:          string(providerCfg.ID),
 221		messages:            messages,
 222		sessions:            sessions,
 223		titleProvider:       titleProvider,
 224		summarizeProvider:   summarizeProvider,
 225		summarizeProviderID: string(providerCfg.ID),
 226		agentToolFn:         agentToolFn,
 227		activeRequests:      csync.NewMap[string, context.CancelFunc](),
 228		mcpTools:            csync.NewLazyMap(mcpToolsFn),
 229		baseTools:           csync.NewLazyMap(baseToolsFn),
 230		promptQueue:         csync.NewMap[string, []string](),
 231		permissions:         permissions,
 232		lspClients:          lspClients,
 233	}
 234	a.setupEvents(ctx)
 235	return a, nil
 236}
 237
 238func (a *agent) Model() catwalk.Model {
 239	return *config.Get().GetModelByType(a.agentCfg.Model)
 240}
 241
 242func (a *agent) Cancel(sessionID string) {
 243	// Cancel regular requests
 244	if cancel, ok := a.activeRequests.Take(sessionID); ok && cancel != nil {
 245		slog.Info("Request cancellation initiated", "session_id", sessionID)
 246		cancel()
 247	}
 248
 249	// Also check for summarize requests
 250	if cancel, ok := a.activeRequests.Take(sessionID + "-summarize"); ok && cancel != nil {
 251		slog.Info("Summarize cancellation initiated", "session_id", sessionID)
 252		cancel()
 253	}
 254
 255	if a.QueuedPrompts(sessionID) > 0 {
 256		slog.Info("Clearing queued prompts", "session_id", sessionID)
 257		a.promptQueue.Del(sessionID)
 258	}
 259}
 260
 261func (a *agent) IsBusy() bool {
 262	var busy bool
 263	for cancelFunc := range a.activeRequests.Seq() {
 264		if cancelFunc != nil {
 265			busy = true
 266			break
 267		}
 268	}
 269	return busy
 270}
 271
 272func (a *agent) IsSessionBusy(sessionID string) bool {
 273	_, busy := a.activeRequests.Get(sessionID)
 274	return busy
 275}
 276
 277func (a *agent) QueuedPrompts(sessionID string) int {
 278	l, ok := a.promptQueue.Get(sessionID)
 279	if !ok {
 280		return 0
 281	}
 282	return len(l)
 283}
 284
 285func (a *agent) generateTitle(ctx context.Context, sessionID string, content string) error {
 286	if content == "" {
 287		return nil
 288	}
 289	if a.titleProvider == nil {
 290		return nil
 291	}
 292	session, err := a.sessions.Get(ctx, sessionID)
 293	if err != nil {
 294		return err
 295	}
 296	parts := []message.ContentPart{message.TextContent{
 297		Text: fmt.Sprintf("Generate a concise title for the following content:\n\n%s", content),
 298	}}
 299
 300	// Use streaming approach like summarization
 301	response := a.titleProvider.StreamResponse(
 302		ctx,
 303		[]message.Message{
 304			{
 305				Role:  message.User,
 306				Parts: parts,
 307			},
 308		},
 309		nil,
 310	)
 311
 312	var finalResponse *provider.ProviderResponse
 313	for r := range response {
 314		if r.Error != nil {
 315			return r.Error
 316		}
 317		finalResponse = r.Response
 318	}
 319
 320	if finalResponse == nil {
 321		return fmt.Errorf("no response received from title provider")
 322	}
 323
 324	title := strings.ReplaceAll(finalResponse.Content, "\n", " ")
 325
 326	if idx := strings.Index(title, "</think>"); idx > 0 {
 327		title = title[idx+len("</think>"):]
 328	}
 329
 330	title = strings.TrimSpace(title)
 331	if title == "" {
 332		return nil
 333	}
 334
 335	session.Title = title
 336	_, err = a.sessions.Save(ctx, session)
 337	return err
 338}
 339
 340func (a *agent) err(err error) AgentEvent {
 341	return AgentEvent{
 342		Type:  AgentEventTypeError,
 343		Error: err,
 344	}
 345}
 346
 347func (a *agent) Run(ctx context.Context, sessionID string, content string, attachments ...message.Attachment) (<-chan AgentEvent, error) {
 348	if !a.Model().SupportsImages && attachments != nil {
 349		attachments = nil
 350	}
 351	events := make(chan AgentEvent, 1)
 352	if a.IsSessionBusy(sessionID) {
 353		existing, ok := a.promptQueue.Get(sessionID)
 354		if !ok {
 355			existing = []string{}
 356		}
 357		existing = append(existing, content)
 358		a.promptQueue.Set(sessionID, existing)
 359		return nil, nil
 360	}
 361
 362	genCtx, cancel := context.WithCancel(ctx)
 363	a.activeRequests.Set(sessionID, cancel)
 364	startTime := time.Now()
 365
 366	go func() {
 367		slog.Debug("Request started", "sessionID", sessionID)
 368		defer log.RecoverPanic("agent.Run", func() {
 369			events <- a.err(fmt.Errorf("panic while running the agent"))
 370		})
 371		var attachmentParts []message.ContentPart
 372		for _, attachment := range attachments {
 373			attachmentParts = append(attachmentParts, message.BinaryContent{Path: attachment.FilePath, MIMEType: attachment.MimeType, Data: attachment.Content})
 374		}
 375		result := a.processGeneration(genCtx, sessionID, content, attachmentParts)
 376		if result.Error != nil {
 377			if isCancelledErr(result.Error) {
 378				slog.Error("Request canceled", "sessionID", sessionID)
 379			} else {
 380				slog.Error("Request errored", "sessionID", sessionID, "error", result.Error.Error())
 381				event.Error(result.Error)
 382			}
 383		} else {
 384			slog.Debug("Request completed", "sessionID", sessionID)
 385		}
 386		a.eventPromptResponded(sessionID, time.Since(startTime).Truncate(time.Second))
 387		a.activeRequests.Del(sessionID)
 388		cancel()
 389		a.Publish(pubsub.CreatedEvent, result)
 390		events <- result
 391		close(events)
 392	}()
 393	a.eventPromptSent(sessionID)
 394	return events, nil
 395}
 396
 397func (a *agent) processGeneration(ctx context.Context, sessionID, content string, attachmentParts []message.ContentPart) AgentEvent {
 398	cfg := config.Get()
 399	// List existing messages; if none, start title generation asynchronously.
 400	msgs, err := a.messages.List(ctx, sessionID)
 401	if err != nil {
 402		return a.err(fmt.Errorf("failed to list messages: %w", err))
 403	}
 404	if len(msgs) == 0 {
 405		go func() {
 406			defer log.RecoverPanic("agent.Run", func() {
 407				slog.Error("panic while generating title")
 408			})
 409			titleErr := a.generateTitle(ctx, sessionID, content)
 410			if titleErr != nil && !errors.Is(titleErr, context.Canceled) && !errors.Is(titleErr, context.DeadlineExceeded) {
 411				slog.Error("failed to generate title", "error", titleErr)
 412			}
 413		}()
 414	}
 415	session, err := a.sessions.Get(ctx, sessionID)
 416	if err != nil {
 417		return a.err(fmt.Errorf("failed to get session: %w", err))
 418	}
 419	if session.SummaryMessageID != "" {
 420		summaryMsgInex := -1
 421		for i, msg := range msgs {
 422			if msg.ID == session.SummaryMessageID {
 423				summaryMsgInex = i
 424				break
 425			}
 426		}
 427		if summaryMsgInex != -1 {
 428			msgs = msgs[summaryMsgInex:]
 429			msgs[0].Role = message.User
 430		}
 431	}
 432
 433	userMsg, err := a.createUserMessage(ctx, sessionID, content, attachmentParts)
 434	if err != nil {
 435		return a.err(fmt.Errorf("failed to create user message: %w", err))
 436	}
 437	// Append the new user message to the conversation history.
 438	msgHistory := append(msgs, userMsg)
 439
 440	for {
 441		// Check for cancellation before each iteration
 442		select {
 443		case <-ctx.Done():
 444			return a.err(ctx.Err())
 445		default:
 446			// Continue processing
 447		}
 448		agentMessage, toolResults, err := a.streamAndHandleEvents(ctx, sessionID, msgHistory)
 449		if err != nil {
 450			if errors.Is(err, context.Canceled) {
 451				agentMessage.AddFinish(message.FinishReasonCanceled, "Request cancelled", "")
 452				a.messages.Update(context.Background(), agentMessage)
 453				return a.err(ErrRequestCancelled)
 454			}
 455			return a.err(fmt.Errorf("failed to process events: %w", err))
 456		}
 457		if cfg.Options.Debug {
 458			slog.Info("Result", "message", agentMessage.FinishReason(), "toolResults", toolResults)
 459		}
 460		if (agentMessage.FinishReason() == message.FinishReasonToolUse) && toolResults != nil {
 461			// We are not done, we need to respond with the tool response
 462			msgHistory = append(msgHistory, agentMessage, *toolResults)
 463			// If there are queued prompts, process the next one
 464			nextPrompt, ok := a.promptQueue.Take(sessionID)
 465			if ok {
 466				for _, prompt := range nextPrompt {
 467					// Create a new user message for the queued prompt
 468					userMsg, err := a.createUserMessage(ctx, sessionID, prompt, nil)
 469					if err != nil {
 470						return a.err(fmt.Errorf("failed to create user message for queued prompt: %w", err))
 471					}
 472					// Append the new user message to the conversation history
 473					msgHistory = append(msgHistory, userMsg)
 474				}
 475			}
 476
 477			continue
 478		} else if agentMessage.FinishReason() == message.FinishReasonEndTurn {
 479			queuePrompts, ok := a.promptQueue.Take(sessionID)
 480			if ok {
 481				for _, prompt := range queuePrompts {
 482					if prompt == "" {
 483						continue
 484					}
 485					userMsg, err := a.createUserMessage(ctx, sessionID, prompt, nil)
 486					if err != nil {
 487						return a.err(fmt.Errorf("failed to create user message for queued prompt: %w", err))
 488					}
 489					msgHistory = append(msgHistory, userMsg)
 490				}
 491				continue
 492			}
 493		}
 494		if agentMessage.FinishReason() == "" {
 495			// Kujtim: could not track down where this is happening but this means its cancelled
 496			agentMessage.AddFinish(message.FinishReasonCanceled, "Request cancelled", "")
 497			_ = a.messages.Update(context.Background(), agentMessage)
 498			return a.err(ErrRequestCancelled)
 499		}
 500		return AgentEvent{
 501			Type:    AgentEventTypeResponse,
 502			Message: agentMessage,
 503			Done:    true,
 504		}
 505	}
 506}
 507
 508func (a *agent) createUserMessage(ctx context.Context, sessionID, content string, attachmentParts []message.ContentPart) (message.Message, error) {
 509	parts := []message.ContentPart{message.TextContent{Text: content}}
 510	parts = append(parts, attachmentParts...)
 511	return a.messages.Create(ctx, sessionID, message.CreateMessageParams{
 512		Role:  message.User,
 513		Parts: parts,
 514	})
 515}
 516
 517func (a *agent) getAllTools() ([]tools.BaseTool, error) {
 518	allTools := slices.Collect(a.baseTools.Seq())
 519
 520	withCoderTools := func(t []tools.BaseTool) []tools.BaseTool {
 521		if a.agentCfg.ID == "coder" {
 522			t = append(t, slices.Collect(a.mcpTools.Seq())...)
 523			if a.lspClients.Len() > 0 {
 524				t = append(t, tools.NewDiagnosticsTool(a.lspClients))
 525			}
 526		}
 527		return t
 528	}
 529
 530	if a.agentCfg.AllowedTools == nil {
 531		allTools = withCoderTools(allTools)
 532	} else {
 533		var filteredTools []tools.BaseTool
 534		for _, tool := range allTools {
 535			if slices.Contains(a.agentCfg.AllowedTools, tool.Name()) {
 536				filteredTools = append(filteredTools, tool)
 537			}
 538		}
 539		allTools = withCoderTools(filteredTools)
 540	}
 541
 542	if a.agentToolFn != nil {
 543		agentTool, agentToolErr := a.agentToolFn()
 544		if agentToolErr != nil {
 545			return nil, agentToolErr
 546		}
 547		allTools = append(allTools, agentTool)
 548	}
 549	return allTools, nil
 550}
 551
 552func (a *agent) streamAndHandleEvents(ctx context.Context, sessionID string, msgHistory []message.Message) (message.Message, *message.Message, error) {
 553	ctx = context.WithValue(ctx, tools.SessionIDContextKey, sessionID)
 554
 555	// Create the assistant message first so the spinner shows immediately
 556	assistantMsg, err := a.messages.Create(ctx, sessionID, message.CreateMessageParams{
 557		Role:     message.Assistant,
 558		Parts:    []message.ContentPart{},
 559		Model:    a.Model().ID,
 560		Provider: a.providerID,
 561	})
 562	if err != nil {
 563		return assistantMsg, nil, fmt.Errorf("failed to create assistant message: %w", err)
 564	}
 565
 566	allTools, toolsErr := a.getAllTools()
 567	if toolsErr != nil {
 568		return assistantMsg, nil, toolsErr
 569	}
 570	// Now collect tools (which may block on MCP initialization)
 571	eventChan := a.provider.StreamResponse(ctx, msgHistory, allTools)
 572
 573	// Add the session and message ID into the context if needed by tools.
 574	ctx = context.WithValue(ctx, tools.MessageIDContextKey, assistantMsg.ID)
 575
 576loop:
 577	for {
 578		select {
 579		case event, ok := <-eventChan:
 580			if !ok {
 581				break loop
 582			}
 583			if processErr := a.processEvent(ctx, sessionID, &assistantMsg, event); processErr != nil {
 584				if errors.Is(processErr, context.Canceled) {
 585					a.finishMessage(context.Background(), &assistantMsg, message.FinishReasonCanceled, "Request cancelled", "")
 586				} else {
 587					a.finishMessage(ctx, &assistantMsg, message.FinishReasonError, "API Error", processErr.Error())
 588				}
 589				return assistantMsg, nil, processErr
 590			}
 591		case <-ctx.Done():
 592			a.finishMessage(context.Background(), &assistantMsg, message.FinishReasonCanceled, "Request cancelled", "")
 593			return assistantMsg, nil, ctx.Err()
 594		}
 595	}
 596
 597	toolResults := make([]message.ToolResult, len(assistantMsg.ToolCalls()))
 598	toolCalls := assistantMsg.ToolCalls()
 599	for i, toolCall := range toolCalls {
 600		select {
 601		case <-ctx.Done():
 602			a.finishMessage(context.Background(), &assistantMsg, message.FinishReasonCanceled, "Request cancelled", "")
 603			// Make all future tool calls cancelled
 604			for j := i; j < len(toolCalls); j++ {
 605				toolResults[j] = message.ToolResult{
 606					ToolCallID: toolCalls[j].ID,
 607					Content:    "Tool execution canceled by user",
 608					IsError:    true,
 609				}
 610			}
 611			goto out
 612		default:
 613			// Continue processing
 614			var tool tools.BaseTool
 615			allTools, _ = a.getAllTools()
 616			for _, availableTool := range allTools {
 617				if availableTool.Info().Name == toolCall.Name {
 618					tool = availableTool
 619					break
 620				}
 621			}
 622
 623			// Tool not found
 624			if tool == nil {
 625				toolResults[i] = message.ToolResult{
 626					ToolCallID: toolCall.ID,
 627					Content:    fmt.Sprintf("Tool not found: %s", toolCall.Name),
 628					IsError:    true,
 629				}
 630				continue
 631			}
 632
 633			// Run tool in goroutine to allow cancellation
 634			type toolExecResult struct {
 635				response tools.ToolResponse
 636				err      error
 637			}
 638			resultChan := make(chan toolExecResult, 1)
 639
 640			go func() {
 641				response, err := tool.Run(ctx, tools.ToolCall{
 642					ID:    toolCall.ID,
 643					Name:  toolCall.Name,
 644					Input: toolCall.Input,
 645				})
 646				resultChan <- toolExecResult{response: response, err: err}
 647			}()
 648
 649			var toolResponse tools.ToolResponse
 650			var toolErr error
 651
 652			select {
 653			case <-ctx.Done():
 654				a.finishMessage(context.Background(), &assistantMsg, message.FinishReasonCanceled, "Request cancelled", "")
 655				// Mark remaining tool calls as cancelled
 656				for j := i; j < len(toolCalls); j++ {
 657					toolResults[j] = message.ToolResult{
 658						ToolCallID: toolCalls[j].ID,
 659						Content:    "Tool execution canceled by user",
 660						IsError:    true,
 661					}
 662				}
 663				goto out
 664			case result := <-resultChan:
 665				toolResponse = result.response
 666				toolErr = result.err
 667			}
 668
 669			if toolErr != nil {
 670				slog.Error("Tool execution error", "toolCall", toolCall.ID, "error", toolErr)
 671				if errors.Is(toolErr, permission.ErrorPermissionDenied) {
 672					toolResults[i] = message.ToolResult{
 673						ToolCallID: toolCall.ID,
 674						Content:    "Permission denied",
 675						IsError:    true,
 676					}
 677					for j := i + 1; j < len(toolCalls); j++ {
 678						toolResults[j] = message.ToolResult{
 679							ToolCallID: toolCalls[j].ID,
 680							Content:    "Tool execution canceled by user",
 681							IsError:    true,
 682						}
 683					}
 684					a.finishMessage(ctx, &assistantMsg, message.FinishReasonPermissionDenied, "Permission denied", "")
 685					break
 686				}
 687			}
 688			toolResults[i] = message.ToolResult{
 689				ToolCallID: toolCall.ID,
 690				Content:    toolResponse.Content,
 691				Metadata:   toolResponse.Metadata,
 692				IsError:    toolResponse.IsError,
 693			}
 694		}
 695	}
 696out:
 697	if len(toolResults) == 0 {
 698		return assistantMsg, nil, nil
 699	}
 700	parts := make([]message.ContentPart, 0)
 701	for _, tr := range toolResults {
 702		parts = append(parts, tr)
 703	}
 704	msg, err := a.messages.Create(context.Background(), assistantMsg.SessionID, message.CreateMessageParams{
 705		Role:     message.Tool,
 706		Parts:    parts,
 707		Provider: a.providerID,
 708	})
 709	if err != nil {
 710		return assistantMsg, nil, fmt.Errorf("failed to create cancelled tool message: %w", err)
 711	}
 712
 713	return assistantMsg, &msg, err
 714}
 715
 716func (a *agent) finishMessage(ctx context.Context, msg *message.Message, finishReason message.FinishReason, message, details string) {
 717	msg.AddFinish(finishReason, message, details)
 718	_ = a.messages.Update(ctx, *msg)
 719}
 720
 721func (a *agent) processEvent(ctx context.Context, sessionID string, assistantMsg *message.Message, event provider.ProviderEvent) error {
 722	select {
 723	case <-ctx.Done():
 724		return ctx.Err()
 725	default:
 726		// Continue processing.
 727	}
 728
 729	switch event.Type {
 730	case provider.EventThinkingDelta:
 731		assistantMsg.AppendReasoningContent(event.Thinking)
 732		return a.messages.Update(ctx, *assistantMsg)
 733	case provider.EventSignatureDelta:
 734		assistantMsg.AppendReasoningSignature(event.Signature)
 735		return a.messages.Update(ctx, *assistantMsg)
 736	case provider.EventContentDelta:
 737		assistantMsg.FinishThinking()
 738		assistantMsg.AppendContent(event.Content)
 739		return a.messages.Update(ctx, *assistantMsg)
 740	case provider.EventToolUseStart:
 741		assistantMsg.FinishThinking()
 742		slog.Info("Tool call started", "toolCall", event.ToolCall)
 743		assistantMsg.AddToolCall(*event.ToolCall)
 744		return a.messages.Update(ctx, *assistantMsg)
 745	case provider.EventToolUseDelta:
 746		assistantMsg.AppendToolCallInput(event.ToolCall.ID, event.ToolCall.Input)
 747		return a.messages.Update(ctx, *assistantMsg)
 748	case provider.EventToolUseStop:
 749		slog.Info("Finished tool call", "toolCall", event.ToolCall)
 750		assistantMsg.FinishToolCall(event.ToolCall.ID)
 751		return a.messages.Update(ctx, *assistantMsg)
 752	case provider.EventError:
 753		return event.Error
 754	case provider.EventComplete:
 755		assistantMsg.FinishThinking()
 756		assistantMsg.SetToolCalls(event.Response.ToolCalls)
 757		assistantMsg.AddFinish(event.Response.FinishReason, "", "")
 758		if err := a.messages.Update(ctx, *assistantMsg); err != nil {
 759			return fmt.Errorf("failed to update message: %w", err)
 760		}
 761		return a.trackUsage(ctx, sessionID, a.Model(), event.Response.Usage)
 762	}
 763
 764	return nil
 765}
 766
 767func (a *agent) trackUsage(ctx context.Context, sessionID string, model catwalk.Model, usage provider.TokenUsage) error {
 768	sess, err := a.sessions.Get(ctx, sessionID)
 769	if err != nil {
 770		return fmt.Errorf("failed to get session: %w", err)
 771	}
 772
 773	cost := model.CostPer1MInCached/1e6*float64(usage.CacheCreationTokens) +
 774		model.CostPer1MOutCached/1e6*float64(usage.CacheReadTokens) +
 775		model.CostPer1MIn/1e6*float64(usage.InputTokens) +
 776		model.CostPer1MOut/1e6*float64(usage.OutputTokens)
 777
 778	a.eventTokensUsed(sessionID, usage, cost)
 779
 780	sess.Cost += cost
 781	sess.CompletionTokens = usage.OutputTokens + usage.CacheReadTokens
 782	sess.PromptTokens = usage.InputTokens + usage.CacheCreationTokens
 783
 784	_, err = a.sessions.Save(ctx, sess)
 785	if err != nil {
 786		return fmt.Errorf("failed to save session: %w", err)
 787	}
 788	return nil
 789}
 790
 791func (a *agent) Summarize(ctx context.Context, sessionID string) error {
 792	if a.summarizeProvider == nil {
 793		return fmt.Errorf("summarize provider not available")
 794	}
 795
 796	// Check if session is busy
 797	if a.IsSessionBusy(sessionID) {
 798		return ErrSessionBusy
 799	}
 800
 801	// Create a new context with cancellation
 802	summarizeCtx, cancel := context.WithCancel(ctx)
 803
 804	// Store the cancel function in activeRequests to allow cancellation
 805	a.activeRequests.Set(sessionID+"-summarize", cancel)
 806
 807	go func() {
 808		defer a.activeRequests.Del(sessionID + "-summarize")
 809		defer cancel()
 810		event := AgentEvent{
 811			Type:     AgentEventTypeSummarize,
 812			Progress: "Starting summarization...",
 813		}
 814
 815		a.Publish(pubsub.CreatedEvent, event)
 816		// Get all messages from the session
 817		msgs, err := a.messages.List(summarizeCtx, sessionID)
 818		if err != nil {
 819			event = AgentEvent{
 820				Type:  AgentEventTypeError,
 821				Error: fmt.Errorf("failed to list messages: %w", err),
 822				Done:  true,
 823			}
 824			a.Publish(pubsub.CreatedEvent, event)
 825			return
 826		}
 827		summarizeCtx = context.WithValue(summarizeCtx, tools.SessionIDContextKey, sessionID)
 828
 829		if len(msgs) == 0 {
 830			event = AgentEvent{
 831				Type:  AgentEventTypeError,
 832				Error: fmt.Errorf("no messages to summarize"),
 833				Done:  true,
 834			}
 835			a.Publish(pubsub.CreatedEvent, event)
 836			return
 837		}
 838
 839		event = AgentEvent{
 840			Type:     AgentEventTypeSummarize,
 841			Progress: "Analyzing conversation...",
 842		}
 843		a.Publish(pubsub.CreatedEvent, event)
 844
 845		// Add a system message to guide the summarization
 846		summarizePrompt := "Provide a detailed but concise summary of our conversation above. Focus on information that would be helpful for continuing the conversation, including what we did, what we're doing, which files we're working on, and what we're going to do next."
 847
 848		// Create a new message with the summarize prompt
 849		promptMsg := message.Message{
 850			Role:  message.User,
 851			Parts: []message.ContentPart{message.TextContent{Text: summarizePrompt}},
 852		}
 853
 854		// Append the prompt to the messages
 855		msgsWithPrompt := append(msgs, promptMsg)
 856
 857		event = AgentEvent{
 858			Type:     AgentEventTypeSummarize,
 859			Progress: "Generating summary...",
 860		}
 861
 862		a.Publish(pubsub.CreatedEvent, event)
 863
 864		// Send the messages to the summarize provider
 865		response := a.summarizeProvider.StreamResponse(
 866			summarizeCtx,
 867			msgsWithPrompt,
 868			nil,
 869		)
 870		var finalResponse *provider.ProviderResponse
 871		for r := range response {
 872			if r.Error != nil {
 873				event = AgentEvent{
 874					Type:  AgentEventTypeError,
 875					Error: fmt.Errorf("failed to summarize: %w", r.Error),
 876					Done:  true,
 877				}
 878				a.Publish(pubsub.CreatedEvent, event)
 879				return
 880			}
 881			finalResponse = r.Response
 882		}
 883
 884		summary := strings.TrimSpace(finalResponse.Content)
 885		if summary == "" {
 886			event = AgentEvent{
 887				Type:  AgentEventTypeError,
 888				Error: fmt.Errorf("empty summary returned"),
 889				Done:  true,
 890			}
 891			a.Publish(pubsub.CreatedEvent, event)
 892			return
 893		}
 894		shell := shell.GetPersistentShell(config.Get().WorkingDir())
 895		summary += "\n\n**Current working directory of the persistent shell**\n\n" + shell.GetWorkingDir()
 896		event = AgentEvent{
 897			Type:     AgentEventTypeSummarize,
 898			Progress: "Creating new session...",
 899		}
 900
 901		a.Publish(pubsub.CreatedEvent, event)
 902		oldSession, err := a.sessions.Get(summarizeCtx, sessionID)
 903		if err != nil {
 904			event = AgentEvent{
 905				Type:  AgentEventTypeError,
 906				Error: fmt.Errorf("failed to get session: %w", err),
 907				Done:  true,
 908			}
 909
 910			a.Publish(pubsub.CreatedEvent, event)
 911			return
 912		}
 913		// Create a message in the new session with the summary
 914		msg, err := a.messages.Create(summarizeCtx, oldSession.ID, message.CreateMessageParams{
 915			Role: message.Assistant,
 916			Parts: []message.ContentPart{
 917				message.TextContent{Text: summary},
 918				message.Finish{
 919					Reason: message.FinishReasonEndTurn,
 920					Time:   time.Now().Unix(),
 921				},
 922			},
 923			Model:    a.summarizeProvider.Model().ID,
 924			Provider: a.summarizeProviderID,
 925		})
 926		if err != nil {
 927			event = AgentEvent{
 928				Type:  AgentEventTypeError,
 929				Error: fmt.Errorf("failed to create summary message: %w", err),
 930				Done:  true,
 931			}
 932
 933			a.Publish(pubsub.CreatedEvent, event)
 934			return
 935		}
 936		oldSession.SummaryMessageID = msg.ID
 937		oldSession.CompletionTokens = finalResponse.Usage.OutputTokens
 938		oldSession.PromptTokens = 0
 939		model := a.summarizeProvider.Model()
 940		usage := finalResponse.Usage
 941		cost := model.CostPer1MInCached/1e6*float64(usage.CacheCreationTokens) +
 942			model.CostPer1MOutCached/1e6*float64(usage.CacheReadTokens) +
 943			model.CostPer1MIn/1e6*float64(usage.InputTokens) +
 944			model.CostPer1MOut/1e6*float64(usage.OutputTokens)
 945		oldSession.Cost += cost
 946		_, err = a.sessions.Save(summarizeCtx, oldSession)
 947		if err != nil {
 948			event = AgentEvent{
 949				Type:  AgentEventTypeError,
 950				Error: fmt.Errorf("failed to save session: %w", err),
 951				Done:  true,
 952			}
 953			a.Publish(pubsub.CreatedEvent, event)
 954		}
 955
 956		event = AgentEvent{
 957			Type:      AgentEventTypeSummarize,
 958			SessionID: oldSession.ID,
 959			Progress:  "Summary complete",
 960			Done:      true,
 961		}
 962		a.Publish(pubsub.CreatedEvent, event)
 963		// Send final success event with the new session ID
 964	}()
 965
 966	return nil
 967}
 968
 969func (a *agent) ClearQueue(sessionID string) {
 970	if a.QueuedPrompts(sessionID) > 0 {
 971		slog.Info("Clearing queued prompts", "session_id", sessionID)
 972		a.promptQueue.Del(sessionID)
 973	}
 974}
 975
 976func (a *agent) CancelAll() {
 977	if !a.IsBusy() {
 978		return
 979	}
 980	for key := range a.activeRequests.Seq2() {
 981		a.Cancel(key) // key is sessionID
 982	}
 983
 984	for _, cleanup := range a.cleanupFuncs {
 985		if cleanup != nil {
 986			cleanup()
 987		}
 988	}
 989
 990	timeout := time.After(5 * time.Second)
 991	for a.IsBusy() {
 992		select {
 993		case <-timeout:
 994			return
 995		default:
 996			time.Sleep(200 * time.Millisecond)
 997		}
 998	}
 999}
1000
1001func (a *agent) UpdateModel() error {
1002	cfg := config.Get()
1003
1004	// Get current provider configuration
1005	currentProviderCfg := cfg.GetProviderForModel(a.agentCfg.Model)
1006	if currentProviderCfg == nil || currentProviderCfg.ID == "" {
1007		return fmt.Errorf("provider for agent %s not found in config", a.agentCfg.Name)
1008	}
1009
1010	// Check if provider has changed
1011	if string(currentProviderCfg.ID) != a.providerID {
1012		// Provider changed, need to recreate the main provider
1013		model := cfg.GetModelByType(a.agentCfg.Model)
1014		if model.ID == "" {
1015			return fmt.Errorf("model not found for agent %s", a.agentCfg.Name)
1016		}
1017
1018		promptID := agentPromptMap[a.agentCfg.ID]
1019		if promptID == "" {
1020			promptID = prompt.PromptDefault
1021		}
1022
1023		opts := []provider.ProviderClientOption{
1024			provider.WithModel(a.agentCfg.Model),
1025			provider.WithSystemMessage(prompt.GetPrompt(promptID, currentProviderCfg.ID, cfg.Options.ContextPaths...)),
1026		}
1027
1028		newProvider, err := provider.NewProvider(*currentProviderCfg, opts...)
1029		if err != nil {
1030			return fmt.Errorf("failed to create new provider: %w", err)
1031		}
1032
1033		// Update the provider and provider ID
1034		a.provider = newProvider
1035		a.providerID = string(currentProviderCfg.ID)
1036	}
1037
1038	// Check if providers have changed for title (small) and summarize (large)
1039	smallModelCfg := cfg.Models[config.SelectedModelTypeSmall]
1040	var smallModelProviderCfg config.ProviderConfig
1041	for p := range cfg.Providers.Seq() {
1042		if p.ID == smallModelCfg.Provider {
1043			smallModelProviderCfg = p
1044			break
1045		}
1046	}
1047	if smallModelProviderCfg.ID == "" {
1048		return fmt.Errorf("provider %s not found in config", smallModelCfg.Provider)
1049	}
1050
1051	largeModelCfg := cfg.Models[config.SelectedModelTypeLarge]
1052	var largeModelProviderCfg config.ProviderConfig
1053	for p := range cfg.Providers.Seq() {
1054		if p.ID == largeModelCfg.Provider {
1055			largeModelProviderCfg = p
1056			break
1057		}
1058	}
1059	if largeModelProviderCfg.ID == "" {
1060		return fmt.Errorf("provider %s not found in config", largeModelCfg.Provider)
1061	}
1062
1063	var maxTitleTokens int64 = 40
1064
1065	// if the max output is too low for the gemini provider it won't return anything
1066	if smallModelCfg.Provider == "gemini" {
1067		maxTitleTokens = 1000
1068	}
1069	// Recreate title provider
1070	titleOpts := []provider.ProviderClientOption{
1071		provider.WithModel(config.SelectedModelTypeSmall),
1072		provider.WithSystemMessage(prompt.GetPrompt(prompt.PromptTitle, smallModelProviderCfg.ID)),
1073		provider.WithMaxTokens(maxTitleTokens),
1074	}
1075	newTitleProvider, err := provider.NewProvider(smallModelProviderCfg, titleOpts...)
1076	if err != nil {
1077		return fmt.Errorf("failed to create new title provider: %w", err)
1078	}
1079	a.titleProvider = newTitleProvider
1080
1081	// Recreate summarize provider if provider changed (now large model)
1082	if string(largeModelProviderCfg.ID) != a.summarizeProviderID {
1083		largeModel := cfg.GetModelByType(config.SelectedModelTypeLarge)
1084		if largeModel == nil {
1085			return fmt.Errorf("model %s not found in provider %s", largeModelCfg.Model, largeModelProviderCfg.ID)
1086		}
1087		summarizeOpts := []provider.ProviderClientOption{
1088			provider.WithModel(config.SelectedModelTypeLarge),
1089			provider.WithSystemMessage(prompt.GetPrompt(prompt.PromptSummarizer, largeModelProviderCfg.ID)),
1090		}
1091		newSummarizeProvider, err := provider.NewProvider(largeModelProviderCfg, summarizeOpts...)
1092		if err != nil {
1093			return fmt.Errorf("failed to create new summarize provider: %w", err)
1094		}
1095		a.summarizeProvider = newSummarizeProvider
1096		a.summarizeProviderID = string(largeModelProviderCfg.ID)
1097	}
1098
1099	return nil
1100}
1101
1102func (a *agent) setupEvents(ctx context.Context) {
1103	ctx, cancel := context.WithCancel(ctx)
1104
1105	go func() {
1106		subCh := SubscribeMCPEvents(ctx)
1107
1108		for {
1109			select {
1110			case event, ok := <-subCh:
1111				if !ok {
1112					slog.Debug("MCPEvents subscription channel closed")
1113					return
1114				}
1115				switch event.Payload.Type {
1116				case MCPEventToolsListChanged:
1117					name := event.Payload.Name
1118					c, ok := mcpClients.Get(name)
1119					if !ok {
1120						slog.Warn("MCP client not found for tools update", "name", name)
1121						continue
1122					}
1123					cfg := config.Get()
1124					tools, err := getTools(ctx, name, a.permissions, c, cfg.WorkingDir())
1125					if err != nil {
1126						slog.Error("error listing tools", "error", err)
1127						updateMCPState(name, MCPStateError, err, nil, 0)
1128						_ = c.Close()
1129						continue
1130					}
1131					updateMcpTools(name, tools)
1132					a.mcpTools.Reset(maps.Collect(mcpTools.Seq2()))
1133					updateMCPState(name, MCPStateConnected, nil, c, a.mcpTools.Len())
1134				default:
1135					continue
1136				}
1137			case <-ctx.Done():
1138				slog.Debug("MCPEvents subscription cancelled")
1139				return
1140			}
1141		}
1142	}()
1143
1144	a.cleanupFuncs = append(a.cleanupFuncs, cancel)
1145}