agent.go

   1// Package agent contains the implementation of the AI agent service.
   2package agent
   3
   4import (
   5	"context"
   6	"errors"
   7	"fmt"
   8	"log/slog"
   9	"maps"
  10	"slices"
  11	"strings"
  12	"time"
  13
  14	"github.com/charmbracelet/catwalk/pkg/catwalk"
  15	"github.com/charmbracelet/crush/internal/config"
  16	"github.com/charmbracelet/crush/internal/csync"
  17	"github.com/charmbracelet/crush/internal/event"
  18	"github.com/charmbracelet/crush/internal/history"
  19	"github.com/charmbracelet/crush/internal/llm/prompt"
  20	"github.com/charmbracelet/crush/internal/llm/provider"
  21	"github.com/charmbracelet/crush/internal/llm/tools"
  22	"github.com/charmbracelet/crush/internal/log"
  23	"github.com/charmbracelet/crush/internal/lsp"
  24	"github.com/charmbracelet/crush/internal/message"
  25	"github.com/charmbracelet/crush/internal/permission"
  26	"github.com/charmbracelet/crush/internal/pubsub"
  27	"github.com/charmbracelet/crush/internal/session"
  28	"github.com/charmbracelet/crush/internal/shell"
  29)
  30
  31type AgentEventType string
  32
  33const (
  34	AgentEventTypeError     AgentEventType = "error"
  35	AgentEventTypeResponse  AgentEventType = "response"
  36	AgentEventTypeSummarize AgentEventType = "summarize"
  37)
  38
  39type AgentEvent struct {
  40	Type    AgentEventType
  41	Message message.Message
  42	Error   error
  43
  44	// When summarizing
  45	SessionID string
  46	Progress  string
  47	Done      bool
  48}
  49
  50type Service interface {
  51	pubsub.Suscriber[AgentEvent]
  52	Model() catwalk.Model
  53	Run(ctx context.Context, sessionID string, content string, attachments ...message.Attachment) (<-chan AgentEvent, error)
  54	Cancel(sessionID string)
  55	CancelAll()
  56	IsSessionBusy(sessionID string) bool
  57	IsBusy() bool
  58	Summarize(ctx context.Context, sessionID string) error
  59	UpdateModel() error
  60	QueuedPrompts(sessionID string) int
  61	ClearQueue(sessionID string)
  62}
  63
  64type agent struct {
  65	*pubsub.Broker[AgentEvent]
  66	agentCfg    config.Agent
  67	sessions    session.Service
  68	messages    message.Service
  69	permissions permission.Service
  70	baseTools   *csync.Map[string, tools.BaseTool]
  71	mcpTools    *csync.Map[string, tools.BaseTool]
  72	lspClients  *csync.Map[string, *lsp.Client]
  73
  74	// We need this to be able to update it when model changes
  75	agentToolFn  func() (tools.BaseTool, error)
  76	cleanupFuncs []func()
  77
  78	provider   provider.Provider
  79	providerID string
  80
  81	titleProvider       provider.Provider
  82	summarizeProvider   provider.Provider
  83	summarizeProviderID string
  84
  85	activeRequests *csync.Map[string, context.CancelFunc]
  86	promptQueue    *csync.Map[string, []string]
  87}
  88
  89var agentPromptMap = map[string]prompt.PromptID{
  90	"coder": prompt.PromptCoder,
  91	"task":  prompt.PromptTask,
  92}
  93
  94func NewAgent(
  95	ctx context.Context,
  96	agentCfg config.Agent,
  97	// These services are needed in the tools
  98	permissions permission.Service,
  99	sessions session.Service,
 100	messages message.Service,
 101	history history.Service,
 102	lspClients *csync.Map[string, *lsp.Client],
 103) (Service, error) {
 104	cfg := config.Get()
 105
 106	var agentToolFn func() (tools.BaseTool, error)
 107	if agentCfg.ID == "coder" && slices.Contains(agentCfg.AllowedTools, AgentToolName) {
 108		agentToolFn = func() (tools.BaseTool, error) {
 109			taskAgentCfg := config.Get().Agents["task"]
 110			if taskAgentCfg.ID == "" {
 111				return nil, fmt.Errorf("task agent not found in config")
 112			}
 113			taskAgent, err := NewAgent(ctx, taskAgentCfg, permissions, sessions, messages, history, lspClients)
 114			if err != nil {
 115				return nil, fmt.Errorf("failed to create task agent: %w", err)
 116			}
 117			return NewAgentTool(taskAgent, sessions, messages), nil
 118		}
 119	}
 120
 121	providerCfg := config.Get().GetProviderForModel(agentCfg.Model)
 122	if providerCfg == nil {
 123		return nil, fmt.Errorf("provider for agent %s not found in config", agentCfg.Name)
 124	}
 125	model := config.Get().GetModelByType(agentCfg.Model)
 126
 127	if model == nil {
 128		return nil, fmt.Errorf("model not found for agent %s", agentCfg.Name)
 129	}
 130
 131	promptID := agentPromptMap[agentCfg.ID]
 132	if promptID == "" {
 133		promptID = prompt.PromptDefault
 134	}
 135	opts := []provider.ProviderClientOption{
 136		provider.WithModel(agentCfg.Model),
 137		provider.WithSystemMessage(prompt.GetPrompt(promptID, providerCfg.ID, config.Get().Options.ContextPaths...)),
 138	}
 139	agentProvider, err := provider.NewProvider(*providerCfg, opts...)
 140	if err != nil {
 141		return nil, err
 142	}
 143
 144	smallModelCfg := cfg.Models[config.SelectedModelTypeSmall]
 145	var smallModelProviderCfg *config.ProviderConfig
 146	if smallModelCfg.Provider == providerCfg.ID {
 147		smallModelProviderCfg = providerCfg
 148	} else {
 149		smallModelProviderCfg = cfg.GetProviderForModel(config.SelectedModelTypeSmall)
 150
 151		if smallModelProviderCfg.ID == "" {
 152			return nil, fmt.Errorf("provider %s not found in config", smallModelCfg.Provider)
 153		}
 154	}
 155	smallModel := cfg.GetModelByType(config.SelectedModelTypeSmall)
 156	if smallModel.ID == "" {
 157		return nil, fmt.Errorf("model %s not found in provider %s", smallModelCfg.Model, smallModelProviderCfg.ID)
 158	}
 159
 160	titleOpts := []provider.ProviderClientOption{
 161		provider.WithModel(config.SelectedModelTypeSmall),
 162		provider.WithSystemMessage(prompt.GetPrompt(prompt.PromptTitle, smallModelProviderCfg.ID)),
 163	}
 164	titleProvider, err := provider.NewProvider(*smallModelProviderCfg, titleOpts...)
 165	if err != nil {
 166		return nil, err
 167	}
 168
 169	summarizeOpts := []provider.ProviderClientOption{
 170		provider.WithModel(config.SelectedModelTypeLarge),
 171		provider.WithSystemMessage(prompt.GetPrompt(prompt.PromptSummarizer, providerCfg.ID)),
 172	}
 173	summarizeProvider, err := provider.NewProvider(*providerCfg, summarizeOpts...)
 174	if err != nil {
 175		return nil, err
 176	}
 177
 178	baseToolsFn := func() map[string]tools.BaseTool {
 179		slog.Debug("Initializing agent base tools", "agent", agentCfg.ID)
 180		defer func() {
 181			slog.Debug("Initialized agent base tools", "agent", agentCfg.ID)
 182		}()
 183
 184		// Base tools available to all agents
 185		cwd := cfg.WorkingDir()
 186		result := make(map[string]tools.BaseTool)
 187		for _, tool := range []tools.BaseTool{
 188			tools.NewBashTool(permissions, cwd, cfg.Options.Attribution),
 189			tools.NewDownloadTool(permissions, cwd),
 190			tools.NewEditTool(lspClients, permissions, history, cwd),
 191			tools.NewMultiEditTool(lspClients, permissions, history, cwd),
 192			tools.NewFetchTool(permissions, cwd),
 193			tools.NewGlobTool(cwd),
 194			tools.NewGrepTool(cwd),
 195			tools.NewLsTool(permissions, cwd),
 196			tools.NewSourcegraphTool(),
 197			tools.NewViewTool(lspClients, permissions, cwd),
 198			tools.NewWriteTool(lspClients, permissions, history, cwd),
 199		} {
 200			result[tool.Name()] = tool
 201		}
 202		return result
 203	}
 204	mcpToolsFn := func() map[string]tools.BaseTool {
 205		slog.Debug("Initializing agent mcp tools", "agent", agentCfg.ID)
 206		defer func() {
 207			slog.Debug("Initialized agent mcp tools", "agent", agentCfg.ID)
 208		}()
 209
 210		mcpToolsOnce.Do(func() {
 211			doGetMCPTools(ctx, permissions, cfg)
 212		})
 213
 214		return maps.Collect(mcpTools.Seq2())
 215	}
 216
 217	a := &agent{
 218		Broker:              pubsub.NewBroker[AgentEvent](),
 219		agentCfg:            agentCfg,
 220		provider:            agentProvider,
 221		providerID:          string(providerCfg.ID),
 222		messages:            messages,
 223		sessions:            sessions,
 224		titleProvider:       titleProvider,
 225		summarizeProvider:   summarizeProvider,
 226		summarizeProviderID: string(providerCfg.ID),
 227		agentToolFn:         agentToolFn,
 228		activeRequests:      csync.NewMap[string, context.CancelFunc](),
 229		mcpTools:            csync.NewLazyMap(mcpToolsFn),
 230		baseTools:           csync.NewLazyMap(baseToolsFn),
 231		promptQueue:         csync.NewMap[string, []string](),
 232		permissions:         permissions,
 233		lspClients:          lspClients,
 234	}
 235	a.setupEvents(ctx)
 236	return a, nil
 237}
 238
 239func (a *agent) Model() catwalk.Model {
 240	return *config.Get().GetModelByType(a.agentCfg.Model)
 241}
 242
 243func (a *agent) Cancel(sessionID string) {
 244	// Cancel regular requests
 245	if cancel, ok := a.activeRequests.Take(sessionID); ok && cancel != nil {
 246		slog.Info("Request cancellation initiated", "session_id", sessionID)
 247		cancel()
 248	}
 249
 250	// Also check for summarize requests
 251	if cancel, ok := a.activeRequests.Take(sessionID + "-summarize"); ok && cancel != nil {
 252		slog.Info("Summarize cancellation initiated", "session_id", sessionID)
 253		cancel()
 254	}
 255
 256	if a.QueuedPrompts(sessionID) > 0 {
 257		slog.Info("Clearing queued prompts", "session_id", sessionID)
 258		a.promptQueue.Del(sessionID)
 259	}
 260}
 261
 262func (a *agent) IsBusy() bool {
 263	var busy bool
 264	for cancelFunc := range a.activeRequests.Seq() {
 265		if cancelFunc != nil {
 266			busy = true
 267			break
 268		}
 269	}
 270	return busy
 271}
 272
 273func (a *agent) IsSessionBusy(sessionID string) bool {
 274	_, busy := a.activeRequests.Get(sessionID)
 275	return busy
 276}
 277
 278func (a *agent) QueuedPrompts(sessionID string) int {
 279	l, ok := a.promptQueue.Get(sessionID)
 280	if !ok {
 281		return 0
 282	}
 283	return len(l)
 284}
 285
 286func (a *agent) generateTitle(ctx context.Context, sessionID string, content string) error {
 287	if content == "" {
 288		return nil
 289	}
 290	if a.titleProvider == nil {
 291		return nil
 292	}
 293	session, err := a.sessions.Get(ctx, sessionID)
 294	if err != nil {
 295		return err
 296	}
 297	parts := []message.ContentPart{message.TextContent{
 298		Text: fmt.Sprintf("Generate a concise title for the following content:\n\n%s", content),
 299	}}
 300
 301	// Use streaming approach like summarization
 302	response := a.titleProvider.StreamResponse(
 303		ctx,
 304		[]message.Message{
 305			{
 306				Role:  message.User,
 307				Parts: parts,
 308			},
 309		},
 310		nil,
 311	)
 312
 313	var finalResponse *provider.ProviderResponse
 314	for r := range response {
 315		if r.Error != nil {
 316			return r.Error
 317		}
 318		finalResponse = r.Response
 319	}
 320
 321	if finalResponse == nil {
 322		return fmt.Errorf("no response received from title provider")
 323	}
 324
 325	title := strings.ReplaceAll(finalResponse.Content, "\n", " ")
 326
 327	if idx := strings.Index(title, "</think>"); idx > 0 {
 328		title = title[idx+len("</think>"):]
 329	}
 330
 331	title = strings.TrimSpace(title)
 332	if title == "" {
 333		return nil
 334	}
 335
 336	session.Title = title
 337	_, err = a.sessions.Save(ctx, session)
 338	return err
 339}
 340
 341func (a *agent) err(err error) AgentEvent {
 342	return AgentEvent{
 343		Type:  AgentEventTypeError,
 344		Error: err,
 345	}
 346}
 347
 348func (a *agent) Run(ctx context.Context, sessionID string, content string, attachments ...message.Attachment) (<-chan AgentEvent, error) {
 349	if !a.Model().SupportsImages && attachments != nil {
 350		attachments = nil
 351	}
 352	events := make(chan AgentEvent, 1)
 353	if a.IsSessionBusy(sessionID) {
 354		existing, ok := a.promptQueue.Get(sessionID)
 355		if !ok {
 356			existing = []string{}
 357		}
 358		existing = append(existing, content)
 359		a.promptQueue.Set(sessionID, existing)
 360		return nil, nil
 361	}
 362
 363	genCtx, cancel := context.WithCancel(ctx)
 364	a.activeRequests.Set(sessionID, cancel)
 365	startTime := time.Now()
 366
 367	go func() {
 368		slog.Debug("Request started", "sessionID", sessionID)
 369		defer log.RecoverPanic("agent.Run", func() {
 370			events <- a.err(fmt.Errorf("panic while running the agent"))
 371		})
 372		var attachmentParts []message.ContentPart
 373		for _, attachment := range attachments {
 374			attachmentParts = append(attachmentParts, message.BinaryContent{Path: attachment.FilePath, MIMEType: attachment.MimeType, Data: attachment.Content})
 375		}
 376		result := a.processGeneration(genCtx, sessionID, content, attachmentParts)
 377		if result.Error != nil {
 378			if isCancelledErr(result.Error) {
 379				slog.Error("Request canceled", "sessionID", sessionID)
 380			} else {
 381				slog.Error("Request errored", "sessionID", sessionID, "error", result.Error.Error())
 382				event.Error(result.Error)
 383			}
 384		} else {
 385			slog.Debug("Request completed", "sessionID", sessionID)
 386		}
 387		a.eventPromptResponded(sessionID, time.Since(startTime).Truncate(time.Second))
 388		a.activeRequests.Del(sessionID)
 389		cancel()
 390		a.Publish(pubsub.CreatedEvent, result)
 391		events <- result
 392		close(events)
 393	}()
 394	a.eventPromptSent(sessionID)
 395	return events, nil
 396}
 397
 398func (a *agent) processGeneration(ctx context.Context, sessionID, content string, attachmentParts []message.ContentPart) AgentEvent {
 399	cfg := config.Get()
 400	// List existing messages; if none, start title generation asynchronously.
 401	msgs, err := a.messages.List(ctx, sessionID)
 402	if err != nil {
 403		return a.err(fmt.Errorf("failed to list messages: %w", err))
 404	}
 405	if len(msgs) == 0 {
 406		go func() {
 407			defer log.RecoverPanic("agent.Run", func() {
 408				slog.Error("panic while generating title")
 409			})
 410			titleErr := a.generateTitle(ctx, sessionID, content)
 411			if titleErr != nil && !errors.Is(titleErr, context.Canceled) && !errors.Is(titleErr, context.DeadlineExceeded) {
 412				slog.Error("failed to generate title", "error", titleErr)
 413			}
 414		}()
 415	}
 416	session, err := a.sessions.Get(ctx, sessionID)
 417	if err != nil {
 418		return a.err(fmt.Errorf("failed to get session: %w", err))
 419	}
 420	if session.SummaryMessageID != "" {
 421		summaryMsgInex := -1
 422		for i, msg := range msgs {
 423			if msg.ID == session.SummaryMessageID {
 424				summaryMsgInex = i
 425				break
 426			}
 427		}
 428		if summaryMsgInex != -1 {
 429			msgs = msgs[summaryMsgInex:]
 430			msgs[0].Role = message.User
 431		}
 432	}
 433
 434	userMsg, err := a.createUserMessage(ctx, sessionID, content, attachmentParts)
 435	if err != nil {
 436		return a.err(fmt.Errorf("failed to create user message: %w", err))
 437	}
 438	// Append the new user message to the conversation history.
 439	msgHistory := append(msgs, userMsg)
 440
 441	for {
 442		// Check for cancellation before each iteration
 443		select {
 444		case <-ctx.Done():
 445			return a.err(ctx.Err())
 446		default:
 447			// Continue processing
 448		}
 449		agentMessage, toolResults, err := a.streamAndHandleEvents(ctx, sessionID, msgHistory)
 450		if err != nil {
 451			if errors.Is(err, context.Canceled) {
 452				agentMessage.AddFinish(message.FinishReasonCanceled, "Request cancelled", "")
 453				a.messages.Update(context.Background(), agentMessage)
 454				return a.err(ErrRequestCancelled)
 455			}
 456			return a.err(fmt.Errorf("failed to process events: %w", err))
 457		}
 458		if cfg.Options.Debug {
 459			slog.Info("Result", "message", agentMessage.FinishReason(), "toolResults", toolResults)
 460		}
 461		if (agentMessage.FinishReason() == message.FinishReasonToolUse) && toolResults != nil {
 462			// We are not done, we need to respond with the tool response
 463			msgHistory = append(msgHistory, agentMessage, *toolResults)
 464			// If there are queued prompts, process the next one
 465			nextPrompt, ok := a.promptQueue.Take(sessionID)
 466			if ok {
 467				for _, prompt := range nextPrompt {
 468					// Create a new user message for the queued prompt
 469					userMsg, err := a.createUserMessage(ctx, sessionID, prompt, nil)
 470					if err != nil {
 471						return a.err(fmt.Errorf("failed to create user message for queued prompt: %w", err))
 472					}
 473					// Append the new user message to the conversation history
 474					msgHistory = append(msgHistory, userMsg)
 475				}
 476			}
 477
 478			continue
 479		} else if agentMessage.FinishReason() == message.FinishReasonEndTurn {
 480			queuePrompts, ok := a.promptQueue.Take(sessionID)
 481			if ok {
 482				for _, prompt := range queuePrompts {
 483					if prompt == "" {
 484						continue
 485					}
 486					userMsg, err := a.createUserMessage(ctx, sessionID, prompt, nil)
 487					if err != nil {
 488						return a.err(fmt.Errorf("failed to create user message for queued prompt: %w", err))
 489					}
 490					msgHistory = append(msgHistory, userMsg)
 491				}
 492				continue
 493			}
 494		}
 495		if agentMessage.FinishReason() == "" {
 496			// Kujtim: could not track down where this is happening but this means its cancelled
 497			agentMessage.AddFinish(message.FinishReasonCanceled, "Request cancelled", "")
 498			_ = a.messages.Update(context.Background(), agentMessage)
 499			return a.err(ErrRequestCancelled)
 500		}
 501		return AgentEvent{
 502			Type:    AgentEventTypeResponse,
 503			Message: agentMessage,
 504			Done:    true,
 505		}
 506	}
 507}
 508
 509func (a *agent) createUserMessage(ctx context.Context, sessionID, content string, attachmentParts []message.ContentPart) (message.Message, error) {
 510	parts := []message.ContentPart{message.TextContent{Text: content}}
 511	parts = append(parts, attachmentParts...)
 512	return a.messages.Create(ctx, sessionID, message.CreateMessageParams{
 513		Role:  message.User,
 514		Parts: parts,
 515	})
 516}
 517
 518func (a *agent) getAllTools() ([]tools.BaseTool, error) {
 519	var allTools []tools.BaseTool
 520	for tool := range a.baseTools.Seq() {
 521		if a.agentCfg.AllowedTools == nil || slices.Contains(a.agentCfg.AllowedTools, tool.Name()) {
 522			allTools = append(allTools, tool)
 523		}
 524	}
 525	if a.agentCfg.ID == "coder" {
 526		allTools = slices.AppendSeq(allTools, a.mcpTools.Seq())
 527		if a.lspClients.Len() > 0 {
 528			allTools = append(allTools, tools.NewDiagnosticsTool(a.lspClients))
 529		}
 530	}
 531	if a.agentToolFn != nil {
 532		agentTool, agentToolErr := a.agentToolFn()
 533		if agentToolErr != nil {
 534			return nil, agentToolErr
 535		}
 536		allTools = append(allTools, agentTool)
 537	}
 538	return allTools, nil
 539}
 540
 541func (a *agent) streamAndHandleEvents(ctx context.Context, sessionID string, msgHistory []message.Message) (message.Message, *message.Message, error) {
 542	ctx = context.WithValue(ctx, tools.SessionIDContextKey, sessionID)
 543
 544	// Create the assistant message first so the spinner shows immediately
 545	assistantMsg, err := a.messages.Create(ctx, sessionID, message.CreateMessageParams{
 546		Role:     message.Assistant,
 547		Parts:    []message.ContentPart{},
 548		Model:    a.Model().ID,
 549		Provider: a.providerID,
 550	})
 551	if err != nil {
 552		return assistantMsg, nil, fmt.Errorf("failed to create assistant message: %w", err)
 553	}
 554
 555	allTools, toolsErr := a.getAllTools()
 556	if toolsErr != nil {
 557		return assistantMsg, nil, toolsErr
 558	}
 559	// Now collect tools (which may block on MCP initialization)
 560	eventChan := a.provider.StreamResponse(ctx, msgHistory, allTools)
 561
 562	// Add the session and message ID into the context if needed by tools.
 563	ctx = context.WithValue(ctx, tools.MessageIDContextKey, assistantMsg.ID)
 564
 565loop:
 566	for {
 567		select {
 568		case event, ok := <-eventChan:
 569			if !ok {
 570				break loop
 571			}
 572			if processErr := a.processEvent(ctx, sessionID, &assistantMsg, event); processErr != nil {
 573				if errors.Is(processErr, context.Canceled) {
 574					a.finishMessage(context.Background(), &assistantMsg, message.FinishReasonCanceled, "Request cancelled", "")
 575				} else {
 576					a.finishMessage(ctx, &assistantMsg, message.FinishReasonError, "API Error", processErr.Error())
 577				}
 578				return assistantMsg, nil, processErr
 579			}
 580		case <-ctx.Done():
 581			a.finishMessage(context.Background(), &assistantMsg, message.FinishReasonCanceled, "Request cancelled", "")
 582			return assistantMsg, nil, ctx.Err()
 583		}
 584	}
 585
 586	toolResults := make([]message.ToolResult, len(assistantMsg.ToolCalls()))
 587	toolCalls := assistantMsg.ToolCalls()
 588	for i, toolCall := range toolCalls {
 589		select {
 590		case <-ctx.Done():
 591			a.finishMessage(context.Background(), &assistantMsg, message.FinishReasonCanceled, "Request cancelled", "")
 592			// Make all future tool calls cancelled
 593			for j := i; j < len(toolCalls); j++ {
 594				toolResults[j] = message.ToolResult{
 595					ToolCallID: toolCalls[j].ID,
 596					Content:    "Tool execution canceled by user",
 597					IsError:    true,
 598				}
 599			}
 600			goto out
 601		default:
 602			// Continue processing
 603			var tool tools.BaseTool
 604			allTools, _ = a.getAllTools()
 605			for _, availableTool := range allTools {
 606				if availableTool.Info().Name == toolCall.Name {
 607					tool = availableTool
 608					break
 609				}
 610			}
 611
 612			// Tool not found
 613			if tool == nil {
 614				toolResults[i] = message.ToolResult{
 615					ToolCallID: toolCall.ID,
 616					Content:    fmt.Sprintf("Tool not found: %s", toolCall.Name),
 617					IsError:    true,
 618				}
 619				continue
 620			}
 621
 622			// Run tool in goroutine to allow cancellation
 623			type toolExecResult struct {
 624				response tools.ToolResponse
 625				err      error
 626			}
 627			resultChan := make(chan toolExecResult, 1)
 628
 629			go func() {
 630				response, err := tool.Run(ctx, tools.ToolCall{
 631					ID:    toolCall.ID,
 632					Name:  toolCall.Name,
 633					Input: toolCall.Input,
 634				})
 635				resultChan <- toolExecResult{response: response, err: err}
 636			}()
 637
 638			var toolResponse tools.ToolResponse
 639			var toolErr error
 640
 641			select {
 642			case <-ctx.Done():
 643				a.finishMessage(context.Background(), &assistantMsg, message.FinishReasonCanceled, "Request cancelled", "")
 644				// Mark remaining tool calls as cancelled
 645				for j := i; j < len(toolCalls); j++ {
 646					toolResults[j] = message.ToolResult{
 647						ToolCallID: toolCalls[j].ID,
 648						Content:    "Tool execution canceled by user",
 649						IsError:    true,
 650					}
 651				}
 652				goto out
 653			case result := <-resultChan:
 654				toolResponse = result.response
 655				toolErr = result.err
 656			}
 657
 658			if toolErr != nil {
 659				slog.Error("Tool execution error", "toolCall", toolCall.ID, "error", toolErr)
 660				if errors.Is(toolErr, permission.ErrorPermissionDenied) {
 661					toolResults[i] = message.ToolResult{
 662						ToolCallID: toolCall.ID,
 663						Content:    "Permission denied",
 664						IsError:    true,
 665					}
 666					for j := i + 1; j < len(toolCalls); j++ {
 667						toolResults[j] = message.ToolResult{
 668							ToolCallID: toolCalls[j].ID,
 669							Content:    "Tool execution canceled by user",
 670							IsError:    true,
 671						}
 672					}
 673					a.finishMessage(ctx, &assistantMsg, message.FinishReasonPermissionDenied, "Permission denied", "")
 674					break
 675				}
 676			}
 677			toolResults[i] = message.ToolResult{
 678				ToolCallID: toolCall.ID,
 679				Content:    toolResponse.Content,
 680				Metadata:   toolResponse.Metadata,
 681				IsError:    toolResponse.IsError,
 682			}
 683		}
 684	}
 685out:
 686	if len(toolResults) == 0 {
 687		return assistantMsg, nil, nil
 688	}
 689	parts := make([]message.ContentPart, 0)
 690	for _, tr := range toolResults {
 691		parts = append(parts, tr)
 692	}
 693	msg, err := a.messages.Create(context.Background(), assistantMsg.SessionID, message.CreateMessageParams{
 694		Role:     message.Tool,
 695		Parts:    parts,
 696		Provider: a.providerID,
 697	})
 698	if err != nil {
 699		return assistantMsg, nil, fmt.Errorf("failed to create cancelled tool message: %w", err)
 700	}
 701
 702	return assistantMsg, &msg, err
 703}
 704
 705func (a *agent) finishMessage(ctx context.Context, msg *message.Message, finishReason message.FinishReason, message, details string) {
 706	msg.AddFinish(finishReason, message, details)
 707	_ = a.messages.Update(ctx, *msg)
 708}
 709
 710func (a *agent) processEvent(ctx context.Context, sessionID string, assistantMsg *message.Message, event provider.ProviderEvent) error {
 711	select {
 712	case <-ctx.Done():
 713		return ctx.Err()
 714	default:
 715		// Continue processing.
 716	}
 717
 718	switch event.Type {
 719	case provider.EventThinkingDelta:
 720		assistantMsg.AppendReasoningContent(event.Thinking)
 721		return a.messages.Update(ctx, *assistantMsg)
 722	case provider.EventSignatureDelta:
 723		assistantMsg.AppendReasoningSignature(event.Signature)
 724		return a.messages.Update(ctx, *assistantMsg)
 725	case provider.EventContentDelta:
 726		assistantMsg.FinishThinking()
 727		assistantMsg.AppendContent(event.Content)
 728		return a.messages.Update(ctx, *assistantMsg)
 729	case provider.EventToolUseStart:
 730		assistantMsg.FinishThinking()
 731		slog.Info("Tool call started", "toolCall", event.ToolCall)
 732		assistantMsg.AddToolCall(*event.ToolCall)
 733		return a.messages.Update(ctx, *assistantMsg)
 734	case provider.EventToolUseDelta:
 735		assistantMsg.AppendToolCallInput(event.ToolCall.ID, event.ToolCall.Input)
 736		return a.messages.Update(ctx, *assistantMsg)
 737	case provider.EventToolUseStop:
 738		slog.Info("Finished tool call", "toolCall", event.ToolCall)
 739		assistantMsg.FinishToolCall(event.ToolCall.ID)
 740		return a.messages.Update(ctx, *assistantMsg)
 741	case provider.EventError:
 742		return event.Error
 743	case provider.EventComplete:
 744		assistantMsg.FinishThinking()
 745		assistantMsg.SetToolCalls(event.Response.ToolCalls)
 746		assistantMsg.AddFinish(event.Response.FinishReason, "", "")
 747		if err := a.messages.Update(ctx, *assistantMsg); err != nil {
 748			return fmt.Errorf("failed to update message: %w", err)
 749		}
 750		return a.trackUsage(ctx, sessionID, a.Model(), event.Response.Usage)
 751	}
 752
 753	return nil
 754}
 755
 756func (a *agent) trackUsage(ctx context.Context, sessionID string, model catwalk.Model, usage provider.TokenUsage) error {
 757	sess, err := a.sessions.Get(ctx, sessionID)
 758	if err != nil {
 759		return fmt.Errorf("failed to get session: %w", err)
 760	}
 761
 762	cost := model.CostPer1MInCached/1e6*float64(usage.CacheCreationTokens) +
 763		model.CostPer1MOutCached/1e6*float64(usage.CacheReadTokens) +
 764		model.CostPer1MIn/1e6*float64(usage.InputTokens) +
 765		model.CostPer1MOut/1e6*float64(usage.OutputTokens)
 766
 767	a.eventTokensUsed(sessionID, usage, cost)
 768
 769	sess.Cost += cost
 770	sess.CompletionTokens = usage.OutputTokens + usage.CacheReadTokens
 771	sess.PromptTokens = usage.InputTokens + usage.CacheCreationTokens
 772
 773	_, err = a.sessions.Save(ctx, sess)
 774	if err != nil {
 775		return fmt.Errorf("failed to save session: %w", err)
 776	}
 777	return nil
 778}
 779
 780func (a *agent) Summarize(ctx context.Context, sessionID string) error {
 781	if a.summarizeProvider == nil {
 782		return fmt.Errorf("summarize provider not available")
 783	}
 784
 785	// Check if session is busy
 786	if a.IsSessionBusy(sessionID) {
 787		return ErrSessionBusy
 788	}
 789
 790	// Create a new context with cancellation
 791	summarizeCtx, cancel := context.WithCancel(ctx)
 792
 793	// Store the cancel function in activeRequests to allow cancellation
 794	a.activeRequests.Set(sessionID+"-summarize", cancel)
 795
 796	go func() {
 797		defer a.activeRequests.Del(sessionID + "-summarize")
 798		defer cancel()
 799		event := AgentEvent{
 800			Type:     AgentEventTypeSummarize,
 801			Progress: "Starting summarization...",
 802		}
 803
 804		a.Publish(pubsub.CreatedEvent, event)
 805		// Get all messages from the session
 806		msgs, err := a.messages.List(summarizeCtx, sessionID)
 807		if err != nil {
 808			event = AgentEvent{
 809				Type:  AgentEventTypeError,
 810				Error: fmt.Errorf("failed to list messages: %w", err),
 811				Done:  true,
 812			}
 813			a.Publish(pubsub.CreatedEvent, event)
 814			return
 815		}
 816		summarizeCtx = context.WithValue(summarizeCtx, tools.SessionIDContextKey, sessionID)
 817
 818		if len(msgs) == 0 {
 819			event = AgentEvent{
 820				Type:  AgentEventTypeError,
 821				Error: fmt.Errorf("no messages to summarize"),
 822				Done:  true,
 823			}
 824			a.Publish(pubsub.CreatedEvent, event)
 825			return
 826		}
 827
 828		event = AgentEvent{
 829			Type:     AgentEventTypeSummarize,
 830			Progress: "Analyzing conversation...",
 831		}
 832		a.Publish(pubsub.CreatedEvent, event)
 833
 834		// Add a system message to guide the summarization
 835		summarizePrompt := "Provide a detailed but concise summary of our conversation above. Focus on information that would be helpful for continuing the conversation, including what we did, what we're doing, which files we're working on, and what we're going to do next."
 836
 837		// Create a new message with the summarize prompt
 838		promptMsg := message.Message{
 839			Role:  message.User,
 840			Parts: []message.ContentPart{message.TextContent{Text: summarizePrompt}},
 841		}
 842
 843		// Append the prompt to the messages
 844		msgsWithPrompt := append(msgs, promptMsg)
 845
 846		event = AgentEvent{
 847			Type:     AgentEventTypeSummarize,
 848			Progress: "Generating summary...",
 849		}
 850
 851		a.Publish(pubsub.CreatedEvent, event)
 852
 853		// Send the messages to the summarize provider
 854		response := a.summarizeProvider.StreamResponse(
 855			summarizeCtx,
 856			msgsWithPrompt,
 857			nil,
 858		)
 859		var finalResponse *provider.ProviderResponse
 860		for r := range response {
 861			if r.Error != nil {
 862				event = AgentEvent{
 863					Type:  AgentEventTypeError,
 864					Error: fmt.Errorf("failed to summarize: %w", r.Error),
 865					Done:  true,
 866				}
 867				a.Publish(pubsub.CreatedEvent, event)
 868				return
 869			}
 870			finalResponse = r.Response
 871		}
 872
 873		summary := strings.TrimSpace(finalResponse.Content)
 874		if summary == "" {
 875			event = AgentEvent{
 876				Type:  AgentEventTypeError,
 877				Error: fmt.Errorf("empty summary returned"),
 878				Done:  true,
 879			}
 880			a.Publish(pubsub.CreatedEvent, event)
 881			return
 882		}
 883		shell := shell.GetPersistentShell(config.Get().WorkingDir())
 884		summary += "\n\n**Current working directory of the persistent shell**\n\n" + shell.GetWorkingDir()
 885		event = AgentEvent{
 886			Type:     AgentEventTypeSummarize,
 887			Progress: "Creating new session...",
 888		}
 889
 890		a.Publish(pubsub.CreatedEvent, event)
 891		oldSession, err := a.sessions.Get(summarizeCtx, sessionID)
 892		if err != nil {
 893			event = AgentEvent{
 894				Type:  AgentEventTypeError,
 895				Error: fmt.Errorf("failed to get session: %w", err),
 896				Done:  true,
 897			}
 898
 899			a.Publish(pubsub.CreatedEvent, event)
 900			return
 901		}
 902		// Create a message in the new session with the summary
 903		msg, err := a.messages.Create(summarizeCtx, oldSession.ID, message.CreateMessageParams{
 904			Role: message.Assistant,
 905			Parts: []message.ContentPart{
 906				message.TextContent{Text: summary},
 907				message.Finish{
 908					Reason: message.FinishReasonEndTurn,
 909					Time:   time.Now().Unix(),
 910				},
 911			},
 912			Model:    a.summarizeProvider.Model().ID,
 913			Provider: a.summarizeProviderID,
 914		})
 915		if err != nil {
 916			event = AgentEvent{
 917				Type:  AgentEventTypeError,
 918				Error: fmt.Errorf("failed to create summary message: %w", err),
 919				Done:  true,
 920			}
 921
 922			a.Publish(pubsub.CreatedEvent, event)
 923			return
 924		}
 925		oldSession.SummaryMessageID = msg.ID
 926		oldSession.CompletionTokens = finalResponse.Usage.OutputTokens
 927		oldSession.PromptTokens = 0
 928		model := a.summarizeProvider.Model()
 929		usage := finalResponse.Usage
 930		cost := model.CostPer1MInCached/1e6*float64(usage.CacheCreationTokens) +
 931			model.CostPer1MOutCached/1e6*float64(usage.CacheReadTokens) +
 932			model.CostPer1MIn/1e6*float64(usage.InputTokens) +
 933			model.CostPer1MOut/1e6*float64(usage.OutputTokens)
 934		oldSession.Cost += cost
 935		_, err = a.sessions.Save(summarizeCtx, oldSession)
 936		if err != nil {
 937			event = AgentEvent{
 938				Type:  AgentEventTypeError,
 939				Error: fmt.Errorf("failed to save session: %w", err),
 940				Done:  true,
 941			}
 942			a.Publish(pubsub.CreatedEvent, event)
 943		}
 944
 945		event = AgentEvent{
 946			Type:      AgentEventTypeSummarize,
 947			SessionID: oldSession.ID,
 948			Progress:  "Summary complete",
 949			Done:      true,
 950		}
 951		a.Publish(pubsub.CreatedEvent, event)
 952		// Send final success event with the new session ID
 953	}()
 954
 955	return nil
 956}
 957
 958func (a *agent) ClearQueue(sessionID string) {
 959	if a.QueuedPrompts(sessionID) > 0 {
 960		slog.Info("Clearing queued prompts", "session_id", sessionID)
 961		a.promptQueue.Del(sessionID)
 962	}
 963}
 964
 965func (a *agent) CancelAll() {
 966	if !a.IsBusy() {
 967		return
 968	}
 969	for key := range a.activeRequests.Seq2() {
 970		a.Cancel(key) // key is sessionID
 971	}
 972
 973	for _, cleanup := range a.cleanupFuncs {
 974		if cleanup != nil {
 975			cleanup()
 976		}
 977	}
 978
 979	timeout := time.After(5 * time.Second)
 980	for a.IsBusy() {
 981		select {
 982		case <-timeout:
 983			return
 984		default:
 985			time.Sleep(200 * time.Millisecond)
 986		}
 987	}
 988}
 989
 990func (a *agent) UpdateModel() error {
 991	cfg := config.Get()
 992
 993	// Get current provider configuration
 994	currentProviderCfg := cfg.GetProviderForModel(a.agentCfg.Model)
 995	if currentProviderCfg == nil || currentProviderCfg.ID == "" {
 996		return fmt.Errorf("provider for agent %s not found in config", a.agentCfg.Name)
 997	}
 998
 999	// Check if provider has changed
1000	if string(currentProviderCfg.ID) != a.providerID {
1001		// Provider changed, need to recreate the main provider
1002		model := cfg.GetModelByType(a.agentCfg.Model)
1003		if model.ID == "" {
1004			return fmt.Errorf("model not found for agent %s", a.agentCfg.Name)
1005		}
1006
1007		promptID := agentPromptMap[a.agentCfg.ID]
1008		if promptID == "" {
1009			promptID = prompt.PromptDefault
1010		}
1011
1012		opts := []provider.ProviderClientOption{
1013			provider.WithModel(a.agentCfg.Model),
1014			provider.WithSystemMessage(prompt.GetPrompt(promptID, currentProviderCfg.ID, cfg.Options.ContextPaths...)),
1015		}
1016
1017		newProvider, err := provider.NewProvider(*currentProviderCfg, opts...)
1018		if err != nil {
1019			return fmt.Errorf("failed to create new provider: %w", err)
1020		}
1021
1022		// Update the provider and provider ID
1023		a.provider = newProvider
1024		a.providerID = string(currentProviderCfg.ID)
1025	}
1026
1027	// Check if providers have changed for title (small) and summarize (large)
1028	smallModelCfg := cfg.Models[config.SelectedModelTypeSmall]
1029	var smallModelProviderCfg config.ProviderConfig
1030	for p := range cfg.Providers.Seq() {
1031		if p.ID == smallModelCfg.Provider {
1032			smallModelProviderCfg = p
1033			break
1034		}
1035	}
1036	if smallModelProviderCfg.ID == "" {
1037		return fmt.Errorf("provider %s not found in config", smallModelCfg.Provider)
1038	}
1039
1040	largeModelCfg := cfg.Models[config.SelectedModelTypeLarge]
1041	var largeModelProviderCfg config.ProviderConfig
1042	for p := range cfg.Providers.Seq() {
1043		if p.ID == largeModelCfg.Provider {
1044			largeModelProviderCfg = p
1045			break
1046		}
1047	}
1048	if largeModelProviderCfg.ID == "" {
1049		return fmt.Errorf("provider %s not found in config", largeModelCfg.Provider)
1050	}
1051
1052	var maxTitleTokens int64 = 40
1053
1054	// if the max output is too low for the gemini provider it won't return anything
1055	if smallModelCfg.Provider == "gemini" {
1056		maxTitleTokens = 1000
1057	}
1058	// Recreate title provider
1059	titleOpts := []provider.ProviderClientOption{
1060		provider.WithModel(config.SelectedModelTypeSmall),
1061		provider.WithSystemMessage(prompt.GetPrompt(prompt.PromptTitle, smallModelProviderCfg.ID)),
1062		provider.WithMaxTokens(maxTitleTokens),
1063	}
1064	newTitleProvider, err := provider.NewProvider(smallModelProviderCfg, titleOpts...)
1065	if err != nil {
1066		return fmt.Errorf("failed to create new title provider: %w", err)
1067	}
1068	a.titleProvider = newTitleProvider
1069
1070	// Recreate summarize provider if provider changed (now large model)
1071	if string(largeModelProviderCfg.ID) != a.summarizeProviderID {
1072		largeModel := cfg.GetModelByType(config.SelectedModelTypeLarge)
1073		if largeModel == nil {
1074			return fmt.Errorf("model %s not found in provider %s", largeModelCfg.Model, largeModelProviderCfg.ID)
1075		}
1076		summarizeOpts := []provider.ProviderClientOption{
1077			provider.WithModel(config.SelectedModelTypeLarge),
1078			provider.WithSystemMessage(prompt.GetPrompt(prompt.PromptSummarizer, largeModelProviderCfg.ID)),
1079		}
1080		newSummarizeProvider, err := provider.NewProvider(largeModelProviderCfg, summarizeOpts...)
1081		if err != nil {
1082			return fmt.Errorf("failed to create new summarize provider: %w", err)
1083		}
1084		a.summarizeProvider = newSummarizeProvider
1085		a.summarizeProviderID = string(largeModelProviderCfg.ID)
1086	}
1087
1088	return nil
1089}
1090
// setupEvents starts a background goroutine that listens for MCP
// "list changed" notifications. For each notification it refreshes the
// tools or prompts of the MCP client that emitted it, then records the
// client's state and counts via updateMCPState.
//
// If listing fails, the client is marked as errored and closed
// (best-effort).
//
// The goroutine exits when the subscription channel is closed or when its
// derived context is done. That happens when the parent ctx is done or when
// the cancel func appended to a.cleanupFuncs is called.
func (a *agent) setupEvents(ctx context.Context) {
	ctx, cancel := context.WithCancel(ctx)

	go func() {
		subCh := SubscribeMCPEvents(ctx)

		for {
			select {
			// Named ev (not event) to avoid shadowing the imported event package.
			case ev, ok := <-subCh:
				if !ok {
					slog.Debug("MCPEvents subscription channel closed")
					return
				}
				name := ev.Payload.Name
				c, ok := mcpClients.Get(name)
				if !ok {
					slog.Warn("MCP client not found for tools/prompts update", "name", name)
					continue
				}
				switch ev.Payload.Type {
				case MCPEventToolsListChanged:
					cfg := config.Get()
					// Named updatedTools (not tools) to avoid shadowing the
					// imported tools package.
					updatedTools, err := getTools(ctx, name, a.permissions, c, cfg.WorkingDir())
					if err != nil {
						slog.Error("error listing tools", "error", err)
						updateMCPState(name, MCPStateError, err, nil, MCPCounts{})
						// Best-effort close: the listing failure is already
						// recorded above, so only log a close failure.
						if closeErr := c.Close(); closeErr != nil {
							slog.Warn("error closing MCP client", "name", name, "error", closeErr)
						}
						continue
					}
					updateMcpTools(name, updatedTools)
					// Refresh the agent's own copy of the MCP tool set.
					a.mcpTools.Reset(maps.Collect(mcpTools.Seq2()))
				case MCPEventPromptsListChanged:
					prompts, err := getPrompts(ctx, c)
					if err != nil {
						slog.Error("error listing prompts", "error", err)
						updateMCPState(name, MCPStateError, err, nil, MCPCounts{})
						// Best-effort close: the listing failure is already
						// recorded above, so only log a close failure.
						if closeErr := c.Close(); closeErr != nil {
							slog.Warn("error closing MCP client", "name", name, "error", closeErr)
						}
						continue
					}
					updateMcpPrompts(name, prompts)
				default:
					// Event types this loop does not handle leave the
					// client state untouched.
					continue
				}
				updateMCPState(name, MCPStateConnected, nil, c, MCPCounts{
					Tools:   mcpTools.Len(),
					Prompts: mcpPrompts.Len(),
				})
			case <-ctx.Done():
				slog.Debug("MCPEvents subscription cancelled")
				return
			}
		}
	}()

	a.cleanupFuncs = append(a.cleanupFuncs, cancel)
}