agent.go

   1package agent
   2
   3import (
   4	"context"
   5	"errors"
   6	"fmt"
   7	"log/slog"
   8	"maps"
   9	"slices"
  10	"strings"
  11	"time"
  12
  13	"github.com/charmbracelet/catwalk/pkg/catwalk"
  14	"github.com/charmbracelet/crush/internal/config"
  15	"github.com/charmbracelet/crush/internal/csync"
  16	"github.com/charmbracelet/crush/internal/event"
  17	"github.com/charmbracelet/crush/internal/history"
  18	"github.com/charmbracelet/crush/internal/llm/prompt"
  19	"github.com/charmbracelet/crush/internal/llm/provider"
  20	"github.com/charmbracelet/crush/internal/llm/tools"
  21	"github.com/charmbracelet/crush/internal/log"
  22	"github.com/charmbracelet/crush/internal/lsp"
  23	"github.com/charmbracelet/crush/internal/message"
  24	"github.com/charmbracelet/crush/internal/permission"
  25	"github.com/charmbracelet/crush/internal/pubsub"
  26	"github.com/charmbracelet/crush/internal/session"
  27	"github.com/charmbracelet/crush/internal/shell"
  28)
  29
// streamChunkTimeout is the longest streamAndHandleEvents waits for the next
// provider stream event before it finishes the assistant message with an
// error and reports a stream chunk timeout.
const streamChunkTimeout = 80 * time.Second
  31
// AgentEventType identifies the kind of payload carried by an AgentEvent.
type AgentEventType string

const (
	// AgentEventTypeError marks an event whose Error field describes a failure.
	AgentEventTypeError     AgentEventType = "error"
	// AgentEventTypeResponse marks an event carrying the final assistant Message.
	AgentEventTypeResponse  AgentEventType = "response"
	// AgentEventTypeSummarize marks progress and completion events published
	// by Summarize.
	AgentEventTypeSummarize AgentEventType = "summarize"
)
  39
// AgentEvent is the value the agent publishes on its broker and delivers on
// the channel returned by Run.
type AgentEvent struct {
	// Type selects which of the fields below are meaningful.
	Type    AgentEventType
	// Message is the assistant message for response events.
	Message message.Message
	// Error is set for error events.
	Error   error

	// When summarizing
	// SessionID is set on the final summarize event, Progress carries a
	// human-readable status line, and Done marks the last event of an
	// operation (it is also set on final response and summarize error events).
	SessionID string
	Progress  string
	Done      bool
}
  50
// Service is the agent API used by the rest of the application. It embeds a
// pubsub subscriber so callers can observe the AgentEvents the agent publishes.
type Service interface {
	pubsub.Suscriber[AgentEvent]
	// Model returns the model resolved from the agent's configured model type.
	Model() catwalk.Model
	// Run starts a generation for the session and returns a channel that
	// receives the final event. When the session is already busy the prompt is
	// queued instead and both return values are nil.
	Run(ctx context.Context, sessionID string, content string, attachments ...message.Attachment) (<-chan AgentEvent, error)
	// Cancel cancels the session's active request and summarization, if any,
	// and drops its queued prompts.
	Cancel(sessionID string)
	// CancelAll cancels every active request.
	CancelAll()
	// IsSessionBusy reports whether the session has an active request.
	IsSessionBusy(sessionID string) bool
	// IsBusy reports whether any request is active.
	IsBusy() bool
	// Summarize starts an asynchronous summarization of the session; progress
	// and the outcome are published as events.
	Summarize(ctx context.Context, sessionID string) error
	// UpdateModel is implemented outside this part of the file; presumably it
	// rebuilds the providers after a model change — confirm against the
	// implementation.
	UpdateModel() error
	// QueuedPrompts returns the number of prompts queued for the session.
	QueuedPrompts(sessionID string) int
	// ClearQueue drops the prompts queued for the session.
	ClearQueue(sessionID string)
}
  64
// agent is the Service implementation. It owns the providers used for
// generation, title creation and summarization, the tool sets offered to the
// model, and the per-session bookkeeping for active requests and queued
// prompts.
type agent struct {
	// cleanupFuncs is not referenced in this part of the file; presumably
	// teardown callbacks registered elsewhere — confirm.
	cleanupFuncs []func()

	// Broker publishes AgentEvents to subscribers.
	*pubsub.Broker[AgentEvent]
	agentCfg config.Agent
	sessions session.Service
	messages message.Service

	permissions permission.Service
	// baseTools and mcpTools are created as lazy maps in NewAgent.
	baseTools   *csync.Map[string, tools.BaseTool]
	mcpTools    *csync.Map[string, tools.BaseTool]
	lspClients  *csync.Map[string, *lsp.Client]

	// We need this to be able to update it when model changes
	// When non-nil it builds the agent tool that getAllTools appends.
	agentToolFn func() (tools.BaseTool, error)

	// provider serves the main generation loop; providerID is recorded on the
	// messages the agent creates.
	provider   provider.Provider
	providerID string

	titleProvider       provider.Provider
	summarizeProvider   provider.Provider
	summarizeProviderID string

	// activeRequests maps a session ID (or sessionID+"-summarize") to the
	// cancel function of the request running for it.
	activeRequests *csync.Map[string, context.CancelFunc]
	// promptQueue holds prompts submitted while a session was busy.
	promptQueue    *csync.Map[string, []string]
}
  91
// agentPromptMap maps an agent ID to the system prompt it uses. NewAgent
// falls back to prompt.PromptDefault for IDs that are not listed here.
var agentPromptMap = map[string]prompt.PromptID{
	"coder": prompt.PromptCoder,
	"task":  prompt.PromptTask,
}
  96
  97func NewAgent(
  98	ctx context.Context,
  99	agentCfg config.Agent,
 100	// These services are needed in the tools
 101	permissions permission.Service,
 102	sessions session.Service,
 103	messages message.Service,
 104	history history.Service,
 105	lspClients *csync.Map[string, *lsp.Client],
 106) (Service, error) {
 107	cfg := config.Get()
 108
 109	var agentToolFn func() (tools.BaseTool, error)
 110	if agentCfg.ID == "coder" && slices.Contains(agentCfg.AllowedTools, AgentToolName) {
 111		agentToolFn = func() (tools.BaseTool, error) {
 112			taskAgentCfg := config.Get().Agents["task"]
 113			if taskAgentCfg.ID == "" {
 114				return nil, fmt.Errorf("task agent not found in config")
 115			}
 116			taskAgent, err := NewAgent(ctx, taskAgentCfg, permissions, sessions, messages, history, lspClients)
 117			if err != nil {
 118				return nil, fmt.Errorf("failed to create task agent: %w", err)
 119			}
 120			return NewAgentTool(taskAgent, sessions, messages), nil
 121		}
 122	}
 123
 124	providerCfg := config.Get().GetProviderForModel(agentCfg.Model)
 125	if providerCfg == nil {
 126		return nil, fmt.Errorf("provider for agent %s not found in config", agentCfg.Name)
 127	}
 128	model := config.Get().GetModelByType(agentCfg.Model)
 129
 130	if model == nil {
 131		return nil, fmt.Errorf("model not found for agent %s", agentCfg.Name)
 132	}
 133
 134	promptID := agentPromptMap[agentCfg.ID]
 135	if promptID == "" {
 136		promptID = prompt.PromptDefault
 137	}
 138	opts := []provider.ProviderClientOption{
 139		provider.WithModel(agentCfg.Model),
 140		provider.WithSystemMessage(prompt.GetPrompt(promptID, providerCfg.ID, config.Get().Options.ContextPaths...)),
 141	}
 142	agentProvider, err := provider.NewProvider(*providerCfg, opts...)
 143	if err != nil {
 144		return nil, err
 145	}
 146
 147	smallModelCfg := cfg.Models[config.SelectedModelTypeSmall]
 148	var smallModelProviderCfg *config.ProviderConfig
 149	if smallModelCfg.Provider == providerCfg.ID {
 150		smallModelProviderCfg = providerCfg
 151	} else {
 152		smallModelProviderCfg = cfg.GetProviderForModel(config.SelectedModelTypeSmall)
 153
 154		if smallModelProviderCfg.ID == "" {
 155			return nil, fmt.Errorf("provider %s not found in config", smallModelCfg.Provider)
 156		}
 157	}
 158	smallModel := cfg.GetModelByType(config.SelectedModelTypeSmall)
 159	if smallModel.ID == "" {
 160		return nil, fmt.Errorf("model %s not found in provider %s", smallModelCfg.Model, smallModelProviderCfg.ID)
 161	}
 162
 163	titleOpts := []provider.ProviderClientOption{
 164		provider.WithModel(config.SelectedModelTypeSmall),
 165		provider.WithSystemMessage(prompt.GetPrompt(prompt.PromptTitle, smallModelProviderCfg.ID)),
 166	}
 167	titleProvider, err := provider.NewProvider(*smallModelProviderCfg, titleOpts...)
 168	if err != nil {
 169		return nil, err
 170	}
 171
 172	summarizeOpts := []provider.ProviderClientOption{
 173		provider.WithModel(config.SelectedModelTypeLarge),
 174		provider.WithSystemMessage(prompt.GetPrompt(prompt.PromptSummarizer, providerCfg.ID)),
 175	}
 176	summarizeProvider, err := provider.NewProvider(*providerCfg, summarizeOpts...)
 177	if err != nil {
 178		return nil, err
 179	}
 180
 181	baseToolsFn := func() map[string]tools.BaseTool {
 182		slog.Info("Initializing agent base tools", "agent", agentCfg.ID)
 183		defer func() {
 184			slog.Info("Initialized agent base tools", "agent", agentCfg.ID)
 185		}()
 186
 187		// Base tools available to all agents
 188		cwd := cfg.WorkingDir()
 189		result := make(map[string]tools.BaseTool)
 190		for _, tool := range []tools.BaseTool{
 191			tools.NewBashTool(permissions, cwd, cfg.Options.Attribution),
 192			tools.NewDownloadTool(permissions, cwd),
 193			tools.NewEditTool(lspClients, permissions, history, cwd),
 194			tools.NewMultiEditTool(lspClients, permissions, history, cwd),
 195			tools.NewFetchTool(permissions, cwd),
 196			tools.NewGlobTool(cwd),
 197			tools.NewGrepTool(cwd),
 198			tools.NewLsTool(permissions, cwd),
 199			tools.NewSourcegraphTool(),
 200			tools.NewViewTool(lspClients, permissions, cwd),
 201			tools.NewWriteTool(lspClients, permissions, history, cwd),
 202		} {
 203			result[tool.Name()] = tool
 204		}
 205		return result
 206	}
 207	mcpToolsFn := func() map[string]tools.BaseTool {
 208		slog.Info("Initializing agent mcp tools", "agent", agentCfg.ID)
 209		defer func() {
 210			slog.Info("Initialized agent mcp tools", "agent", agentCfg.ID)
 211		}()
 212
 213		mcpToolsOnce.Do(func() {
 214			doGetMCPTools(ctx, permissions, cfg)
 215		})
 216		result := make(map[string]tools.BaseTool)
 217		for _, mcpTool := range mcpTools.Seq2() {
 218			result[mcpTool.Name()] = mcpTool
 219		}
 220		return result
 221	}
 222
 223	a := &agent{
 224		Broker:              pubsub.NewBroker[AgentEvent](),
 225		agentCfg:            agentCfg,
 226		provider:            agentProvider,
 227		providerID:          string(providerCfg.ID),
 228		messages:            messages,
 229		sessions:            sessions,
 230		titleProvider:       titleProvider,
 231		summarizeProvider:   summarizeProvider,
 232		summarizeProviderID: string(providerCfg.ID),
 233		agentToolFn:         agentToolFn,
 234		activeRequests:      csync.NewMap[string, context.CancelFunc](),
 235		mcpTools:            csync.NewLazyMap(mcpToolsFn),
 236		baseTools:           csync.NewLazyMap(baseToolsFn),
 237		promptQueue:         csync.NewMap[string, []string](),
 238		permissions:         permissions,
 239		lspClients:          lspClients,
 240	}
 241	a.setupEvents(ctx)
 242	return a, nil
 243}
 244
 245func (a *agent) Model() catwalk.Model {
 246	return *config.Get().GetModelByType(a.agentCfg.Model)
 247}
 248
 249func (a *agent) Cancel(sessionID string) {
 250	// Cancel regular requests
 251	if cancel, ok := a.activeRequests.Take(sessionID); ok && cancel != nil {
 252		slog.Info("Request cancellation initiated", "session_id", sessionID)
 253		cancel()
 254	}
 255
 256	// Also check for summarize requests
 257	if cancel, ok := a.activeRequests.Take(sessionID + "-summarize"); ok && cancel != nil {
 258		slog.Info("Summarize cancellation initiated", "session_id", sessionID)
 259		cancel()
 260	}
 261
 262	if a.QueuedPrompts(sessionID) > 0 {
 263		slog.Info("Clearing queued prompts", "session_id", sessionID)
 264		a.promptQueue.Del(sessionID)
 265	}
 266}
 267
 268func (a *agent) IsBusy() bool {
 269	var busy bool
 270	for cancelFunc := range a.activeRequests.Seq() {
 271		if cancelFunc != nil {
 272			busy = true
 273			break
 274		}
 275	}
 276	return busy
 277}
 278
 279func (a *agent) IsSessionBusy(sessionID string) bool {
 280	_, busy := a.activeRequests.Get(sessionID)
 281	return busy
 282}
 283
 284func (a *agent) QueuedPrompts(sessionID string) int {
 285	l, ok := a.promptQueue.Get(sessionID)
 286	if !ok {
 287		return 0
 288	}
 289	return len(l)
 290}
 291
 292func (a *agent) generateTitle(ctx context.Context, sessionID string, content string) error {
 293	if content == "" {
 294		return nil
 295	}
 296	if a.titleProvider == nil {
 297		return nil
 298	}
 299	session, err := a.sessions.Get(ctx, sessionID)
 300	if err != nil {
 301		return err
 302	}
 303	parts := []message.ContentPart{message.TextContent{
 304		Text: fmt.Sprintf("Generate a concise title for the following content:\n\n%s", content),
 305	}}
 306
 307	// Use streaming approach like summarization
 308	response := a.titleProvider.StreamResponse(
 309		ctx,
 310		[]message.Message{
 311			{
 312				Role:  message.User,
 313				Parts: parts,
 314			},
 315		},
 316		nil,
 317	)
 318
 319	var finalResponse *provider.ProviderResponse
 320	for r := range response {
 321		if r.Error != nil {
 322			return r.Error
 323		}
 324		finalResponse = r.Response
 325	}
 326
 327	if finalResponse == nil {
 328		return fmt.Errorf("no response received from title provider")
 329	}
 330
 331	title := strings.ReplaceAll(finalResponse.Content, "\n", " ")
 332
 333	if idx := strings.Index(title, "</think>"); idx > 0 {
 334		title = title[idx+len("</think>"):]
 335	}
 336
 337	title = strings.TrimSpace(title)
 338	if title == "" {
 339		return nil
 340	}
 341
 342	session.Title = title
 343	_, err = a.sessions.Save(ctx, session)
 344	return err
 345}
 346
 347func (a *agent) err(err error) AgentEvent {
 348	return AgentEvent{
 349		Type:  AgentEventTypeError,
 350		Error: err,
 351	}
 352}
 353
 354func (a *agent) Run(ctx context.Context, sessionID string, content string, attachments ...message.Attachment) (<-chan AgentEvent, error) {
 355	if !a.Model().SupportsImages && attachments != nil {
 356		attachments = nil
 357	}
 358	events := make(chan AgentEvent, 1)
 359	if a.IsSessionBusy(sessionID) {
 360		existing, ok := a.promptQueue.Get(sessionID)
 361		if !ok {
 362			existing = []string{}
 363		}
 364		existing = append(existing, content)
 365		a.promptQueue.Set(sessionID, existing)
 366		return nil, nil
 367	}
 368
 369	genCtx, cancel := context.WithCancel(ctx)
 370	a.activeRequests.Set(sessionID, cancel)
 371	startTime := time.Now()
 372
 373	go func() {
 374		slog.Debug("Request started", "sessionID", sessionID)
 375		defer log.RecoverPanic("agent.Run", func() {
 376			events <- a.err(fmt.Errorf("panic while running the agent"))
 377		})
 378		var attachmentParts []message.ContentPart
 379		for _, attachment := range attachments {
 380			attachmentParts = append(attachmentParts, message.BinaryContent{Path: attachment.FilePath, MIMEType: attachment.MimeType, Data: attachment.Content})
 381		}
 382		result := a.processGeneration(genCtx, sessionID, content, attachmentParts)
 383		if result.Error != nil {
 384			if isCancelledErr(result.Error) {
 385				slog.Error("Request canceled", "sessionID", sessionID)
 386			} else {
 387				slog.Error("Request errored", "sessionID", sessionID, "error", result.Error.Error())
 388				event.Error(result.Error)
 389			}
 390		} else {
 391			slog.Debug("Request completed", "sessionID", sessionID)
 392		}
 393		a.eventPromptResponded(sessionID, time.Since(startTime).Truncate(time.Second))
 394		a.activeRequests.Del(sessionID)
 395		cancel()
 396		a.Publish(pubsub.CreatedEvent, result)
 397		events <- result
 398		close(events)
 399	}()
 400	a.eventPromptSent(sessionID)
 401	return events, nil
 402}
 403
 404func (a *agent) processGeneration(ctx context.Context, sessionID, content string, attachmentParts []message.ContentPart) AgentEvent {
 405	cfg := config.Get()
 406	// List existing messages; if none, start title generation asynchronously.
 407	msgs, err := a.messages.List(ctx, sessionID)
 408	if err != nil {
 409		return a.err(fmt.Errorf("failed to list messages: %w", err))
 410	}
 411	if len(msgs) == 0 {
 412		go func() {
 413			defer log.RecoverPanic("agent.Run", func() {
 414				slog.Error("panic while generating title")
 415			})
 416			titleErr := a.generateTitle(ctx, sessionID, content)
 417			if titleErr != nil && !errors.Is(titleErr, context.Canceled) && !errors.Is(titleErr, context.DeadlineExceeded) {
 418				slog.Error("failed to generate title", "error", titleErr)
 419			}
 420		}()
 421	}
 422	session, err := a.sessions.Get(ctx, sessionID)
 423	if err != nil {
 424		return a.err(fmt.Errorf("failed to get session: %w", err))
 425	}
 426	if session.SummaryMessageID != "" {
 427		summaryMsgInex := -1
 428		for i, msg := range msgs {
 429			if msg.ID == session.SummaryMessageID {
 430				summaryMsgInex = i
 431				break
 432			}
 433		}
 434		if summaryMsgInex != -1 {
 435			msgs = msgs[summaryMsgInex:]
 436			msgs[0].Role = message.User
 437		}
 438	}
 439
 440	userMsg, err := a.createUserMessage(ctx, sessionID, content, attachmentParts)
 441	if err != nil {
 442		return a.err(fmt.Errorf("failed to create user message: %w", err))
 443	}
 444	// Append the new user message to the conversation history.
 445	msgHistory := append(msgs, userMsg)
 446
 447	for {
 448		// Check for cancellation before each iteration
 449		select {
 450		case <-ctx.Done():
 451			return a.err(ctx.Err())
 452		default:
 453			// Continue processing
 454		}
 455		agentMessage, toolResults, err := a.streamAndHandleEvents(ctx, sessionID, msgHistory)
 456		if err != nil {
 457			if errors.Is(err, context.Canceled) {
 458				agentMessage.AddFinish(message.FinishReasonCanceled, "Request cancelled", "")
 459				a.messages.Update(context.Background(), agentMessage)
 460				return a.err(ErrRequestCancelled)
 461			}
 462			return a.err(fmt.Errorf("failed to process events: %w", err))
 463		}
 464		if cfg.Options.Debug {
 465			slog.Info("Result", "message", agentMessage.FinishReason(), "toolResults", toolResults)
 466		}
 467		if (agentMessage.FinishReason() == message.FinishReasonToolUse) && toolResults != nil {
 468			// We are not done, we need to respond with the tool response
 469			msgHistory = append(msgHistory, agentMessage, *toolResults)
 470			// If there are queued prompts, process the next one
 471			nextPrompt, ok := a.promptQueue.Take(sessionID)
 472			if ok {
 473				for _, prompt := range nextPrompt {
 474					// Create a new user message for the queued prompt
 475					userMsg, err := a.createUserMessage(ctx, sessionID, prompt, nil)
 476					if err != nil {
 477						return a.err(fmt.Errorf("failed to create user message for queued prompt: %w", err))
 478					}
 479					// Append the new user message to the conversation history
 480					msgHistory = append(msgHistory, userMsg)
 481				}
 482			}
 483
 484			continue
 485		} else if agentMessage.FinishReason() == message.FinishReasonEndTurn {
 486			queuePrompts, ok := a.promptQueue.Take(sessionID)
 487			if ok {
 488				for _, prompt := range queuePrompts {
 489					if prompt == "" {
 490						continue
 491					}
 492					userMsg, err := a.createUserMessage(ctx, sessionID, prompt, nil)
 493					if err != nil {
 494						return a.err(fmt.Errorf("failed to create user message for queued prompt: %w", err))
 495					}
 496					msgHistory = append(msgHistory, userMsg)
 497				}
 498				continue
 499			}
 500		}
 501		if agentMessage.FinishReason() == "" {
 502			// Kujtim: could not track down where this is happening but this means its cancelled
 503			agentMessage.AddFinish(message.FinishReasonCanceled, "Request cancelled", "")
 504			_ = a.messages.Update(context.Background(), agentMessage)
 505			return a.err(ErrRequestCancelled)
 506		}
 507		return AgentEvent{
 508			Type:    AgentEventTypeResponse,
 509			Message: agentMessage,
 510			Done:    true,
 511		}
 512	}
 513}
 514
 515func (a *agent) createUserMessage(ctx context.Context, sessionID, content string, attachmentParts []message.ContentPart) (message.Message, error) {
 516	parts := []message.ContentPart{message.TextContent{Text: content}}
 517	parts = append(parts, attachmentParts...)
 518	return a.messages.Create(ctx, sessionID, message.CreateMessageParams{
 519		Role:  message.User,
 520		Parts: parts,
 521	})
 522}
 523
 524func (a *agent) getAllTools() ([]tools.BaseTool, error) {
 525	allTools := slices.Collect(a.baseTools.Seq())
 526
 527	withCoderTools := func(t []tools.BaseTool) []tools.BaseTool {
 528		if a.agentCfg.ID == "coder" {
 529			t = append(t, slices.Collect(a.mcpTools.Seq())...)
 530			if a.lspClients.Len() > 0 {
 531				t = append(t, tools.NewDiagnosticsTool(a.lspClients))
 532			}
 533		}
 534		return t
 535	}
 536
 537	if a.agentCfg.AllowedTools == nil {
 538		allTools = withCoderTools(allTools)
 539	} else {
 540		var filteredTools []tools.BaseTool
 541		for _, tool := range allTools {
 542			if slices.Contains(a.agentCfg.AllowedTools, tool.Name()) {
 543				filteredTools = append(filteredTools, tool)
 544			}
 545		}
 546		allTools = withCoderTools(filteredTools)
 547	}
 548
 549	if a.agentToolFn != nil {
 550		agentTool, agentToolErr := a.agentToolFn()
 551		if agentToolErr != nil {
 552			return nil, agentToolErr
 553		}
 554		allTools = append(allTools, agentTool)
 555	}
 556	return allTools, nil
 557}
 558
 559func (a *agent) streamAndHandleEvents(ctx context.Context, sessionID string, msgHistory []message.Message) (message.Message, *message.Message, error) {
 560	ctx = context.WithValue(ctx, tools.SessionIDContextKey, sessionID)
 561
 562	// Create the assistant message first so the spinner shows immediately
 563	assistantMsg, err := a.messages.Create(ctx, sessionID, message.CreateMessageParams{
 564		Role:     message.Assistant,
 565		Parts:    []message.ContentPart{},
 566		Model:    a.Model().ID,
 567		Provider: a.providerID,
 568	})
 569	if err != nil {
 570		return assistantMsg, nil, fmt.Errorf("failed to create assistant message: %w", err)
 571	}
 572
 573	allTools, toolsErr := a.getAllTools()
 574	if toolsErr != nil {
 575		return assistantMsg, nil, toolsErr
 576	}
 577	// Now collect tools (which may block on MCP initialization)
 578	eventChan := a.provider.StreamResponse(ctx, msgHistory, allTools)
 579
 580	// Add the session and message ID into the context if needed by tools.
 581	ctx = context.WithValue(ctx, tools.MessageIDContextKey, assistantMsg.ID)
 582
 583	// Process each event in the stream.
 584loop:
 585	for {
 586		select {
 587		case event, ok := <-eventChan:
 588			if !ok {
 589				break loop
 590			}
 591			if processErr := a.processEvent(ctx, sessionID, &assistantMsg, event); processErr != nil {
 592				if errors.Is(processErr, context.Canceled) {
 593					a.finishMessage(context.Background(), &assistantMsg, message.FinishReasonCanceled, "Request cancelled", "")
 594				} else {
 595					a.finishMessage(ctx, &assistantMsg, message.FinishReasonError, "API Error", processErr.Error())
 596				}
 597				return assistantMsg, nil, processErr
 598			}
 599		case <-time.After(streamChunkTimeout):
 600			a.finishMessage(ctx, &assistantMsg, message.FinishReasonError, "Stream timeout", "No chunk received within timeout")
 601			return assistantMsg, nil, fmt.Errorf("stream chunk timeout")
 602		case <-ctx.Done():
 603			a.finishMessage(context.Background(), &assistantMsg, message.FinishReasonCanceled, "Request cancelled", "")
 604			return assistantMsg, nil, ctx.Err()
 605		}
 606	}
 607
 608	toolResults := make([]message.ToolResult, len(assistantMsg.ToolCalls()))
 609	toolCalls := assistantMsg.ToolCalls()
 610	for i, toolCall := range toolCalls {
 611		select {
 612		case <-ctx.Done():
 613			a.finishMessage(context.Background(), &assistantMsg, message.FinishReasonCanceled, "Request cancelled", "")
 614			// Make all future tool calls cancelled
 615			for j := i; j < len(toolCalls); j++ {
 616				toolResults[j] = message.ToolResult{
 617					ToolCallID: toolCalls[j].ID,
 618					Content:    "Tool execution canceled by user",
 619					IsError:    true,
 620				}
 621			}
 622			goto out
 623		default:
 624			// Continue processing
 625			var tool tools.BaseTool
 626			allTools, _ = a.getAllTools()
 627			for _, availableTool := range allTools {
 628				if availableTool.Info().Name == toolCall.Name {
 629					tool = availableTool
 630					break
 631				}
 632			}
 633
 634			// Tool not found
 635			if tool == nil {
 636				toolResults[i] = message.ToolResult{
 637					ToolCallID: toolCall.ID,
 638					Content:    fmt.Sprintf("Tool not found: %s", toolCall.Name),
 639					IsError:    true,
 640				}
 641				continue
 642			}
 643
 644			// Run tool in goroutine to allow cancellation
 645			type toolExecResult struct {
 646				response tools.ToolResponse
 647				err      error
 648			}
 649			resultChan := make(chan toolExecResult, 1)
 650
 651			go func() {
 652				response, err := tool.Run(ctx, tools.ToolCall{
 653					ID:    toolCall.ID,
 654					Name:  toolCall.Name,
 655					Input: toolCall.Input,
 656				})
 657				resultChan <- toolExecResult{response: response, err: err}
 658			}()
 659
 660			var toolResponse tools.ToolResponse
 661			var toolErr error
 662
 663			select {
 664			case <-ctx.Done():
 665				a.finishMessage(context.Background(), &assistantMsg, message.FinishReasonCanceled, "Request cancelled", "")
 666				// Mark remaining tool calls as cancelled
 667				for j := i; j < len(toolCalls); j++ {
 668					toolResults[j] = message.ToolResult{
 669						ToolCallID: toolCalls[j].ID,
 670						Content:    "Tool execution canceled by user",
 671						IsError:    true,
 672					}
 673				}
 674				goto out
 675			case result := <-resultChan:
 676				toolResponse = result.response
 677				toolErr = result.err
 678			}
 679
 680			if toolErr != nil {
 681				slog.Error("Tool execution error", "toolCall", toolCall.ID, "error", toolErr)
 682				if errors.Is(toolErr, permission.ErrorPermissionDenied) {
 683					toolResults[i] = message.ToolResult{
 684						ToolCallID: toolCall.ID,
 685						Content:    "Permission denied",
 686						IsError:    true,
 687					}
 688					for j := i + 1; j < len(toolCalls); j++ {
 689						toolResults[j] = message.ToolResult{
 690							ToolCallID: toolCalls[j].ID,
 691							Content:    "Tool execution canceled by user",
 692							IsError:    true,
 693						}
 694					}
 695					a.finishMessage(ctx, &assistantMsg, message.FinishReasonPermissionDenied, "Permission denied", "")
 696					break
 697				}
 698			}
 699			toolResults[i] = message.ToolResult{
 700				ToolCallID: toolCall.ID,
 701				Content:    toolResponse.Content,
 702				Metadata:   toolResponse.Metadata,
 703				IsError:    toolResponse.IsError,
 704			}
 705		}
 706	}
 707out:
 708	if len(toolResults) == 0 {
 709		return assistantMsg, nil, nil
 710	}
 711	parts := make([]message.ContentPart, 0)
 712	for _, tr := range toolResults {
 713		parts = append(parts, tr)
 714	}
 715	msg, err := a.messages.Create(context.Background(), assistantMsg.SessionID, message.CreateMessageParams{
 716		Role:     message.Tool,
 717		Parts:    parts,
 718		Provider: a.providerID,
 719	})
 720	if err != nil {
 721		return assistantMsg, nil, fmt.Errorf("failed to create cancelled tool message: %w", err)
 722	}
 723
 724	return assistantMsg, &msg, err
 725}
 726
 727func (a *agent) finishMessage(ctx context.Context, msg *message.Message, finishReason message.FinishReason, message, details string) {
 728	msg.AddFinish(finishReason, message, details)
 729	_ = a.messages.Update(ctx, *msg)
 730}
 731
 732func (a *agent) processEvent(ctx context.Context, sessionID string, assistantMsg *message.Message, event provider.ProviderEvent) error {
 733	select {
 734	case <-ctx.Done():
 735		return ctx.Err()
 736	default:
 737		// Continue processing.
 738	}
 739
 740	switch event.Type {
 741	case provider.EventThinkingDelta:
 742		assistantMsg.AppendReasoningContent(event.Thinking)
 743		return a.messages.Update(ctx, *assistantMsg)
 744	case provider.EventSignatureDelta:
 745		assistantMsg.AppendReasoningSignature(event.Signature)
 746		return a.messages.Update(ctx, *assistantMsg)
 747	case provider.EventContentDelta:
 748		assistantMsg.FinishThinking()
 749		assistantMsg.AppendContent(event.Content)
 750		return a.messages.Update(ctx, *assistantMsg)
 751	case provider.EventToolUseStart:
 752		assistantMsg.FinishThinking()
 753		slog.Info("Tool call started", "toolCall", event.ToolCall)
 754		assistantMsg.AddToolCall(*event.ToolCall)
 755		return a.messages.Update(ctx, *assistantMsg)
 756	case provider.EventToolUseDelta:
 757		assistantMsg.AppendToolCallInput(event.ToolCall.ID, event.ToolCall.Input)
 758		return a.messages.Update(ctx, *assistantMsg)
 759	case provider.EventToolUseStop:
 760		slog.Info("Finished tool call", "toolCall", event.ToolCall)
 761		assistantMsg.FinishToolCall(event.ToolCall.ID)
 762		return a.messages.Update(ctx, *assistantMsg)
 763	case provider.EventError:
 764		return event.Error
 765	case provider.EventComplete:
 766		assistantMsg.FinishThinking()
 767		assistantMsg.SetToolCalls(event.Response.ToolCalls)
 768		assistantMsg.AddFinish(event.Response.FinishReason, "", "")
 769		if err := a.messages.Update(ctx, *assistantMsg); err != nil {
 770			return fmt.Errorf("failed to update message: %w", err)
 771		}
 772		return a.trackUsage(ctx, sessionID, a.Model(), event.Response.Usage)
 773	}
 774
 775	return nil
 776}
 777
 778func (a *agent) trackUsage(ctx context.Context, sessionID string, model catwalk.Model, usage provider.TokenUsage) error {
 779	sess, err := a.sessions.Get(ctx, sessionID)
 780	if err != nil {
 781		return fmt.Errorf("failed to get session: %w", err)
 782	}
 783
 784	cost := model.CostPer1MInCached/1e6*float64(usage.CacheCreationTokens) +
 785		model.CostPer1MOutCached/1e6*float64(usage.CacheReadTokens) +
 786		model.CostPer1MIn/1e6*float64(usage.InputTokens) +
 787		model.CostPer1MOut/1e6*float64(usage.OutputTokens)
 788
 789	a.eventTokensUsed(sessionID, usage, cost)
 790
 791	sess.Cost += cost
 792	sess.CompletionTokens = usage.OutputTokens + usage.CacheReadTokens
 793	sess.PromptTokens = usage.InputTokens + usage.CacheCreationTokens
 794
 795	_, err = a.sessions.Save(ctx, sess)
 796	if err != nil {
 797		return fmt.Errorf("failed to save session: %w", err)
 798	}
 799	return nil
 800}
 801
 802func (a *agent) Summarize(ctx context.Context, sessionID string) error {
 803	if a.summarizeProvider == nil {
 804		return fmt.Errorf("summarize provider not available")
 805	}
 806
 807	// Check if session is busy
 808	if a.IsSessionBusy(sessionID) {
 809		return ErrSessionBusy
 810	}
 811
 812	// Create a new context with cancellation
 813	summarizeCtx, cancel := context.WithCancel(ctx)
 814
 815	// Store the cancel function in activeRequests to allow cancellation
 816	a.activeRequests.Set(sessionID+"-summarize", cancel)
 817
 818	go func() {
 819		defer a.activeRequests.Del(sessionID + "-summarize")
 820		defer cancel()
 821		event := AgentEvent{
 822			Type:     AgentEventTypeSummarize,
 823			Progress: "Starting summarization...",
 824		}
 825
 826		a.Publish(pubsub.CreatedEvent, event)
 827		// Get all messages from the session
 828		msgs, err := a.messages.List(summarizeCtx, sessionID)
 829		if err != nil {
 830			event = AgentEvent{
 831				Type:  AgentEventTypeError,
 832				Error: fmt.Errorf("failed to list messages: %w", err),
 833				Done:  true,
 834			}
 835			a.Publish(pubsub.CreatedEvent, event)
 836			return
 837		}
 838		summarizeCtx = context.WithValue(summarizeCtx, tools.SessionIDContextKey, sessionID)
 839
 840		if len(msgs) == 0 {
 841			event = AgentEvent{
 842				Type:  AgentEventTypeError,
 843				Error: fmt.Errorf("no messages to summarize"),
 844				Done:  true,
 845			}
 846			a.Publish(pubsub.CreatedEvent, event)
 847			return
 848		}
 849
 850		event = AgentEvent{
 851			Type:     AgentEventTypeSummarize,
 852			Progress: "Analyzing conversation...",
 853		}
 854		a.Publish(pubsub.CreatedEvent, event)
 855
 856		// Add a system message to guide the summarization
 857		summarizePrompt := "Provide a detailed but concise summary of our conversation above. Focus on information that would be helpful for continuing the conversation, including what we did, what we're doing, which files we're working on, and what we're going to do next."
 858
 859		// Create a new message with the summarize prompt
 860		promptMsg := message.Message{
 861			Role:  message.User,
 862			Parts: []message.ContentPart{message.TextContent{Text: summarizePrompt}},
 863		}
 864
 865		// Append the prompt to the messages
 866		msgsWithPrompt := append(msgs, promptMsg)
 867
 868		event = AgentEvent{
 869			Type:     AgentEventTypeSummarize,
 870			Progress: "Generating summary...",
 871		}
 872
 873		a.Publish(pubsub.CreatedEvent, event)
 874
 875		// Send the messages to the summarize provider
 876		response := a.summarizeProvider.StreamResponse(
 877			summarizeCtx,
 878			msgsWithPrompt,
 879			nil,
 880		)
 881		var finalResponse *provider.ProviderResponse
 882		for r := range response {
 883			if r.Error != nil {
 884				event = AgentEvent{
 885					Type:  AgentEventTypeError,
 886					Error: fmt.Errorf("failed to summarize: %w", r.Error),
 887					Done:  true,
 888				}
 889				a.Publish(pubsub.CreatedEvent, event)
 890				return
 891			}
 892			finalResponse = r.Response
 893		}
 894
 895		summary := strings.TrimSpace(finalResponse.Content)
 896		if summary == "" {
 897			event = AgentEvent{
 898				Type:  AgentEventTypeError,
 899				Error: fmt.Errorf("empty summary returned"),
 900				Done:  true,
 901			}
 902			a.Publish(pubsub.CreatedEvent, event)
 903			return
 904		}
 905		shell := shell.GetPersistentShell(config.Get().WorkingDir())
 906		summary += "\n\n**Current working directory of the persistent shell**\n\n" + shell.GetWorkingDir()
 907		event = AgentEvent{
 908			Type:     AgentEventTypeSummarize,
 909			Progress: "Creating new session...",
 910		}
 911
 912		a.Publish(pubsub.CreatedEvent, event)
 913		oldSession, err := a.sessions.Get(summarizeCtx, sessionID)
 914		if err != nil {
 915			event = AgentEvent{
 916				Type:  AgentEventTypeError,
 917				Error: fmt.Errorf("failed to get session: %w", err),
 918				Done:  true,
 919			}
 920
 921			a.Publish(pubsub.CreatedEvent, event)
 922			return
 923		}
 924		// Create a message in the new session with the summary
 925		msg, err := a.messages.Create(summarizeCtx, oldSession.ID, message.CreateMessageParams{
 926			Role: message.Assistant,
 927			Parts: []message.ContentPart{
 928				message.TextContent{Text: summary},
 929				message.Finish{
 930					Reason: message.FinishReasonEndTurn,
 931					Time:   time.Now().Unix(),
 932				},
 933			},
 934			Model:    a.summarizeProvider.Model().ID,
 935			Provider: a.summarizeProviderID,
 936		})
 937		if err != nil {
 938			event = AgentEvent{
 939				Type:  AgentEventTypeError,
 940				Error: fmt.Errorf("failed to create summary message: %w", err),
 941				Done:  true,
 942			}
 943
 944			a.Publish(pubsub.CreatedEvent, event)
 945			return
 946		}
 947		oldSession.SummaryMessageID = msg.ID
 948		oldSession.CompletionTokens = finalResponse.Usage.OutputTokens
 949		oldSession.PromptTokens = 0
 950		model := a.summarizeProvider.Model()
 951		usage := finalResponse.Usage
 952		cost := model.CostPer1MInCached/1e6*float64(usage.CacheCreationTokens) +
 953			model.CostPer1MOutCached/1e6*float64(usage.CacheReadTokens) +
 954			model.CostPer1MIn/1e6*float64(usage.InputTokens) +
 955			model.CostPer1MOut/1e6*float64(usage.OutputTokens)
 956		oldSession.Cost += cost
 957		_, err = a.sessions.Save(summarizeCtx, oldSession)
 958		if err != nil {
 959			event = AgentEvent{
 960				Type:  AgentEventTypeError,
 961				Error: fmt.Errorf("failed to save session: %w", err),
 962				Done:  true,
 963			}
 964			a.Publish(pubsub.CreatedEvent, event)
 965		}
 966
 967		event = AgentEvent{
 968			Type:      AgentEventTypeSummarize,
 969			SessionID: oldSession.ID,
 970			Progress:  "Summary complete",
 971			Done:      true,
 972		}
 973		a.Publish(pubsub.CreatedEvent, event)
 974		// Send final success event with the new session ID
 975	}()
 976
 977	return nil
 978}
 979
 980func (a *agent) ClearQueue(sessionID string) {
 981	if a.QueuedPrompts(sessionID) > 0 {
 982		slog.Info("Clearing queued prompts", "session_id", sessionID)
 983		a.promptQueue.Del(sessionID)
 984	}
 985}
 986
 987func (a *agent) CancelAll() {
 988	if !a.IsBusy() {
 989		return
 990	}
 991	for key := range a.activeRequests.Seq2() {
 992		a.Cancel(key) // key is sessionID
 993	}
 994
 995	for _, cleanup := range a.cleanupFuncs {
 996		if cleanup != nil {
 997			cleanup()
 998		}
 999	}
1000
1001	timeout := time.After(5 * time.Second)
1002	for a.IsBusy() {
1003		select {
1004		case <-timeout:
1005			return
1006		default:
1007			time.Sleep(200 * time.Millisecond)
1008		}
1009	}
1010}
1011
1012func (a *agent) UpdateModel() error {
1013	cfg := config.Get()
1014
1015	// Get current provider configuration
1016	currentProviderCfg := cfg.GetProviderForModel(a.agentCfg.Model)
1017	if currentProviderCfg == nil || currentProviderCfg.ID == "" {
1018		return fmt.Errorf("provider for agent %s not found in config", a.agentCfg.Name)
1019	}
1020
1021	// Check if provider has changed
1022	if string(currentProviderCfg.ID) != a.providerID {
1023		// Provider changed, need to recreate the main provider
1024		model := cfg.GetModelByType(a.agentCfg.Model)
1025		if model.ID == "" {
1026			return fmt.Errorf("model not found for agent %s", a.agentCfg.Name)
1027		}
1028
1029		promptID := agentPromptMap[a.agentCfg.ID]
1030		if promptID == "" {
1031			promptID = prompt.PromptDefault
1032		}
1033
1034		opts := []provider.ProviderClientOption{
1035			provider.WithModel(a.agentCfg.Model),
1036			provider.WithSystemMessage(prompt.GetPrompt(promptID, currentProviderCfg.ID, cfg.Options.ContextPaths...)),
1037		}
1038
1039		newProvider, err := provider.NewProvider(*currentProviderCfg, opts...)
1040		if err != nil {
1041			return fmt.Errorf("failed to create new provider: %w", err)
1042		}
1043
1044		// Update the provider and provider ID
1045		a.provider = newProvider
1046		a.providerID = string(currentProviderCfg.ID)
1047	}
1048
1049	// Check if providers have changed for title (small) and summarize (large)
1050	smallModelCfg := cfg.Models[config.SelectedModelTypeSmall]
1051	var smallModelProviderCfg config.ProviderConfig
1052	for p := range cfg.Providers.Seq() {
1053		if p.ID == smallModelCfg.Provider {
1054			smallModelProviderCfg = p
1055			break
1056		}
1057	}
1058	if smallModelProviderCfg.ID == "" {
1059		return fmt.Errorf("provider %s not found in config", smallModelCfg.Provider)
1060	}
1061
1062	largeModelCfg := cfg.Models[config.SelectedModelTypeLarge]
1063	var largeModelProviderCfg config.ProviderConfig
1064	for p := range cfg.Providers.Seq() {
1065		if p.ID == largeModelCfg.Provider {
1066			largeModelProviderCfg = p
1067			break
1068		}
1069	}
1070	if largeModelProviderCfg.ID == "" {
1071		return fmt.Errorf("provider %s not found in config", largeModelCfg.Provider)
1072	}
1073
1074	var maxTitleTokens int64 = 40
1075
1076	// if the max output is too low for the gemini provider it won't return anything
1077	if smallModelCfg.Provider == "gemini" {
1078		maxTitleTokens = 1000
1079	}
1080	// Recreate title provider
1081	titleOpts := []provider.ProviderClientOption{
1082		provider.WithModel(config.SelectedModelTypeSmall),
1083		provider.WithSystemMessage(prompt.GetPrompt(prompt.PromptTitle, smallModelProviderCfg.ID)),
1084		provider.WithMaxTokens(maxTitleTokens),
1085	}
1086	newTitleProvider, err := provider.NewProvider(smallModelProviderCfg, titleOpts...)
1087	if err != nil {
1088		return fmt.Errorf("failed to create new title provider: %w", err)
1089	}
1090	a.titleProvider = newTitleProvider
1091
1092	// Recreate summarize provider if provider changed (now large model)
1093	if string(largeModelProviderCfg.ID) != a.summarizeProviderID {
1094		largeModel := cfg.GetModelByType(config.SelectedModelTypeLarge)
1095		if largeModel == nil {
1096			return fmt.Errorf("model %s not found in provider %s", largeModelCfg.Model, largeModelProviderCfg.ID)
1097		}
1098		summarizeOpts := []provider.ProviderClientOption{
1099			provider.WithModel(config.SelectedModelTypeLarge),
1100			provider.WithSystemMessage(prompt.GetPrompt(prompt.PromptSummarizer, largeModelProviderCfg.ID)),
1101		}
1102		newSummarizeProvider, err := provider.NewProvider(largeModelProviderCfg, summarizeOpts...)
1103		if err != nil {
1104			return fmt.Errorf("failed to create new summarize provider: %w", err)
1105		}
1106		a.summarizeProvider = newSummarizeProvider
1107		a.summarizeProviderID = string(largeModelProviderCfg.ID)
1108	}
1109
1110	return nil
1111}
1112
1113func (a *agent) setupEvents(ctx context.Context) {
1114	ctx, cancel := context.WithCancel(ctx)
1115
1116	go func() {
1117		subCh := SubscribeMCPEvents(ctx)
1118
1119		for {
1120			select {
1121			case event, ok := <-subCh:
1122				if !ok {
1123					slog.Debug("MCPEvents subscription channel closed")
1124					return
1125				}
1126				switch event.Payload.Type {
1127				case MCPEventToolsListChanged:
1128					name := event.Payload.Name
1129					c, ok := mcpClients.Get(name)
1130					if !ok {
1131						slog.Warn("MCP client not found for tools update", "name", name)
1132						continue
1133					}
1134					cfg := config.Get()
1135					tools := getTools(ctx, name, a.permissions, c, cfg.WorkingDir())
1136					updateMcpTools(name, tools)
1137					// Update the lazy map with the new tools
1138					a.mcpTools = csync.NewMapFrom(maps.Collect(mcpTools.Seq2()))
1139					updateMCPState(name, MCPStateConnected, nil, c, a.mcpTools.Len())
1140				default:
1141					continue
1142				}
1143			case <-ctx.Done():
1144				slog.Debug("MCPEvents subscription cancelled")
1145				return
1146			}
1147		}
1148	}()
1149
1150	a.cleanupFuncs = append(a.cleanupFuncs, func() {
1151		cancel()
1152	})
1153}