agent.go

   1package agent
   2
   3import (
   4	"context"
   5	"errors"
   6	"fmt"
   7	"log/slog"
   8	"maps"
   9	"slices"
  10	"strings"
  11	"time"
  12
  13	"github.com/charmbracelet/catwalk/pkg/catwalk"
  14	"github.com/charmbracelet/crush/internal/config"
  15	"github.com/charmbracelet/crush/internal/csync"
  16	"github.com/charmbracelet/crush/internal/event"
  17	"github.com/charmbracelet/crush/internal/history"
  18	"github.com/charmbracelet/crush/internal/llm/prompt"
  19	"github.com/charmbracelet/crush/internal/llm/provider"
  20	"github.com/charmbracelet/crush/internal/llm/tools"
  21	"github.com/charmbracelet/crush/internal/log"
  22	"github.com/charmbracelet/crush/internal/lsp"
  23	"github.com/charmbracelet/crush/internal/message"
  24	"github.com/charmbracelet/crush/internal/permission"
  25	"github.com/charmbracelet/crush/internal/pubsub"
  26	"github.com/charmbracelet/crush/internal/session"
  27	"github.com/charmbracelet/crush/internal/shell"
  28)
  29
  30type AgentEventType string
  31
  32const (
  33	AgentEventTypeError     AgentEventType = "error"
  34	AgentEventTypeResponse  AgentEventType = "response"
  35	AgentEventTypeSummarize AgentEventType = "summarize"
  36)
  37
  38type AgentEvent struct {
  39	Type    AgentEventType
  40	Message message.Message
  41	Error   error
  42
  43	// When summarizing
  44	SessionID string
  45	Progress  string
  46	Done      bool
  47}
  48
  49type Service interface {
  50	pubsub.Suscriber[AgentEvent]
  51	Model() catwalk.Model
  52	Run(ctx context.Context, sessionID string, content string, attachments ...message.Attachment) (<-chan AgentEvent, error)
  53	Cancel(sessionID string)
  54	CancelAll()
  55	IsSessionBusy(sessionID string) bool
  56	IsBusy() bool
  57	Summarize(ctx context.Context, sessionID string) error
  58	UpdateModel() error
  59	QueuedPrompts(sessionID string) int
  60	ClearQueue(sessionID string)
  61}
  62
  63type agent struct {
  64	*pubsub.Broker[AgentEvent]
  65	agentCfg    config.Agent
  66	sessions    session.Service
  67	messages    message.Service
  68	permissions permission.Service
  69	baseTools   *csync.Map[string, tools.BaseTool]
  70	mcpTools    *csync.Map[string, tools.BaseTool]
  71	lspClients  *csync.Map[string, *lsp.Client]
  72
  73	// We need this to be able to update it when model changes
  74	agentToolFn  func() (tools.BaseTool, error)
  75	cleanupFuncs []func()
  76
  77	provider   provider.Provider
  78	providerID string
  79
  80	titleProvider       provider.Provider
  81	summarizeProvider   provider.Provider
  82	summarizeProviderID string
  83
  84	activeRequests *csync.Map[string, context.CancelFunc]
  85	promptQueue    *csync.Map[string, []string]
  86}
  87
  88var agentPromptMap = map[string]prompt.PromptID{
  89	"coder": prompt.PromptCoder,
  90	"task":  prompt.PromptTask,
  91}
  92
  93func NewAgent(
  94	ctx context.Context,
  95	agentCfg config.Agent,
  96	// These services are needed in the tools
  97	permissions permission.Service,
  98	sessions session.Service,
  99	messages message.Service,
 100	history history.Service,
 101	lspClients *csync.Map[string, *lsp.Client],
 102) (Service, error) {
 103	cfg := config.Get()
 104
 105	var agentToolFn func() (tools.BaseTool, error)
 106	if agentCfg.ID == "coder" && slices.Contains(agentCfg.AllowedTools, AgentToolName) {
 107		agentToolFn = func() (tools.BaseTool, error) {
 108			taskAgentCfg := config.Get().Agents["task"]
 109			if taskAgentCfg.ID == "" {
 110				return nil, fmt.Errorf("task agent not found in config")
 111			}
 112			taskAgent, err := NewAgent(ctx, taskAgentCfg, permissions, sessions, messages, history, lspClients)
 113			if err != nil {
 114				return nil, fmt.Errorf("failed to create task agent: %w", err)
 115			}
 116			return NewAgentTool(taskAgent, sessions, messages), nil
 117		}
 118	}
 119
 120	providerCfg := config.Get().GetProviderForModel(agentCfg.Model)
 121	if providerCfg == nil {
 122		return nil, fmt.Errorf("provider for agent %s not found in config", agentCfg.Name)
 123	}
 124	model := config.Get().GetModelByType(agentCfg.Model)
 125
 126	if model == nil {
 127		return nil, fmt.Errorf("model not found for agent %s", agentCfg.Name)
 128	}
 129
 130	promptID := agentPromptMap[agentCfg.ID]
 131	if promptID == "" {
 132		promptID = prompt.PromptDefault
 133	}
 134	opts := []provider.ProviderClientOption{
 135		provider.WithModel(agentCfg.Model),
 136		provider.WithSystemMessage(prompt.GetPrompt(promptID, providerCfg.ID, config.Get().Options.ContextPaths...)),
 137	}
 138	agentProvider, err := provider.NewProvider(*providerCfg, opts...)
 139	if err != nil {
 140		return nil, err
 141	}
 142
 143	smallModelCfg := cfg.Models[config.SelectedModelTypeSmall]
 144	var smallModelProviderCfg *config.ProviderConfig
 145	if smallModelCfg.Provider == providerCfg.ID {
 146		smallModelProviderCfg = providerCfg
 147	} else {
 148		smallModelProviderCfg = cfg.GetProviderForModel(config.SelectedModelTypeSmall)
 149
 150		if smallModelProviderCfg.ID == "" {
 151			return nil, fmt.Errorf("provider %s not found in config", smallModelCfg.Provider)
 152		}
 153	}
 154	smallModel := cfg.GetModelByType(config.SelectedModelTypeSmall)
 155	if smallModel.ID == "" {
 156		return nil, fmt.Errorf("model %s not found in provider %s", smallModelCfg.Model, smallModelProviderCfg.ID)
 157	}
 158
 159	titleOpts := []provider.ProviderClientOption{
 160		provider.WithModel(config.SelectedModelTypeSmall),
 161		provider.WithSystemMessage(prompt.GetPrompt(prompt.PromptTitle, smallModelProviderCfg.ID)),
 162	}
 163	titleProvider, err := provider.NewProvider(*smallModelProviderCfg, titleOpts...)
 164	if err != nil {
 165		return nil, err
 166	}
 167
 168	summarizeOpts := []provider.ProviderClientOption{
 169		provider.WithModel(config.SelectedModelTypeLarge),
 170		provider.WithSystemMessage(prompt.GetPrompt(prompt.PromptSummarizer, providerCfg.ID)),
 171	}
 172	summarizeProvider, err := provider.NewProvider(*providerCfg, summarizeOpts...)
 173	if err != nil {
 174		return nil, err
 175	}
 176
 177	baseToolsFn := func() map[string]tools.BaseTool {
 178		slog.Info("Initializing agent base tools", "agent", agentCfg.ID)
 179		defer func() {
 180			slog.Info("Initialized agent base tools", "agent", agentCfg.ID)
 181		}()
 182
 183		// Base tools available to all agents
 184		cwd := cfg.WorkingDir()
 185		result := make(map[string]tools.BaseTool)
 186		for _, tool := range []tools.BaseTool{
 187			tools.NewBashTool(permissions, cwd, cfg.Options.Attribution),
 188			tools.NewDownloadTool(permissions, cwd),
 189			tools.NewEditTool(lspClients, permissions, history, cwd),
 190			tools.NewMultiEditTool(lspClients, permissions, history, cwd),
 191			tools.NewFetchTool(permissions, cwd),
 192			tools.NewGlobTool(cwd),
 193			tools.NewGrepTool(cwd),
 194			tools.NewLsTool(permissions, cwd),
 195			tools.NewSourcegraphTool(),
 196			tools.NewViewTool(lspClients, permissions, cwd),
 197			tools.NewWriteTool(lspClients, permissions, history, cwd),
 198		} {
 199			result[tool.Name()] = tool
 200		}
 201		return result
 202	}
 203	mcpToolsFn := func() map[string]tools.BaseTool {
 204		slog.Info("Initializing agent mcp tools", "agent", agentCfg.ID)
 205		defer func() {
 206			slog.Info("Initialized agent mcp tools", "agent", agentCfg.ID)
 207		}()
 208
 209		mcpToolsOnce.Do(func() {
 210			doGetMCPTools(ctx, permissions, cfg)
 211		})
 212
 213		return maps.Collect(mcpTools.Seq2())
 214	}
 215
 216	a := &agent{
 217		Broker:              pubsub.NewBroker[AgentEvent](),
 218		agentCfg:            agentCfg,
 219		provider:            agentProvider,
 220		providerID:          string(providerCfg.ID),
 221		messages:            messages,
 222		sessions:            sessions,
 223		titleProvider:       titleProvider,
 224		summarizeProvider:   summarizeProvider,
 225		summarizeProviderID: string(providerCfg.ID),
 226		agentToolFn:         agentToolFn,
 227		activeRequests:      csync.NewMap[string, context.CancelFunc](),
 228		mcpTools:            csync.NewLazyMap(mcpToolsFn),
 229		baseTools:           csync.NewLazyMap(baseToolsFn),
 230		promptQueue:         csync.NewMap[string, []string](),
 231		permissions:         permissions,
 232		lspClients:          lspClients,
 233	}
 234	a.setupEvents(ctx)
 235	return a, nil
 236}
 237
 238func (a *agent) Model() catwalk.Model {
 239	return *config.Get().GetModelByType(a.agentCfg.Model)
 240}
 241
 242func (a *agent) Cancel(sessionID string) {
 243	// Cancel regular requests
 244	if cancel, ok := a.activeRequests.Take(sessionID); ok && cancel != nil {
 245		slog.Info("Request cancellation initiated", "session_id", sessionID)
 246		cancel()
 247	}
 248
 249	// Also check for summarize requests
 250	if cancel, ok := a.activeRequests.Take(sessionID + "-summarize"); ok && cancel != nil {
 251		slog.Info("Summarize cancellation initiated", "session_id", sessionID)
 252		cancel()
 253	}
 254
 255	if a.QueuedPrompts(sessionID) > 0 {
 256		slog.Info("Clearing queued prompts", "session_id", sessionID)
 257		a.promptQueue.Del(sessionID)
 258	}
 259}
 260
 261func (a *agent) IsBusy() bool {
 262	var busy bool
 263	for cancelFunc := range a.activeRequests.Seq() {
 264		if cancelFunc != nil {
 265			busy = true
 266			break
 267		}
 268	}
 269	return busy
 270}
 271
 272func (a *agent) IsSessionBusy(sessionID string) bool {
 273	_, busy := a.activeRequests.Get(sessionID)
 274	return busy
 275}
 276
 277func (a *agent) QueuedPrompts(sessionID string) int {
 278	l, ok := a.promptQueue.Get(sessionID)
 279	if !ok {
 280		return 0
 281	}
 282	return len(l)
 283}
 284
 285func (a *agent) generateTitle(ctx context.Context, sessionID string, content string) error {
 286	if content == "" {
 287		return nil
 288	}
 289	if a.titleProvider == nil {
 290		return nil
 291	}
 292	session, err := a.sessions.Get(ctx, sessionID)
 293	if err != nil {
 294		return err
 295	}
 296	parts := []message.ContentPart{message.TextContent{
 297		Text: fmt.Sprintf("Generate a concise title for the following content:\n\n%s", content),
 298	}}
 299
 300	// Use streaming approach like summarization
 301	response := a.titleProvider.StreamResponse(
 302		ctx,
 303		[]message.Message{
 304			{
 305				Role:  message.User,
 306				Parts: parts,
 307			},
 308		},
 309		nil,
 310	)
 311
 312	var finalResponse *provider.ProviderResponse
 313	for r := range response {
 314		if r.Error != nil {
 315			return r.Error
 316		}
 317		finalResponse = r.Response
 318	}
 319
 320	if finalResponse == nil {
 321		return fmt.Errorf("no response received from title provider")
 322	}
 323
 324	title := strings.ReplaceAll(finalResponse.Content, "\n", " ")
 325
 326	if idx := strings.Index(title, "</think>"); idx > 0 {
 327		title = title[idx+len("</think>"):]
 328	}
 329
 330	title = strings.TrimSpace(title)
 331	if title == "" {
 332		return nil
 333	}
 334
 335	session.Title = title
 336	_, err = a.sessions.Save(ctx, session)
 337	return err
 338}
 339
 340func (a *agent) err(err error) AgentEvent {
 341	return AgentEvent{
 342		Type:  AgentEventTypeError,
 343		Error: err,
 344	}
 345}
 346
 347func (a *agent) Run(ctx context.Context, sessionID string, content string, attachments ...message.Attachment) (<-chan AgentEvent, error) {
 348	if !a.Model().SupportsImages && attachments != nil {
 349		attachments = nil
 350	}
 351	events := make(chan AgentEvent, 1)
 352	if a.IsSessionBusy(sessionID) {
 353		existing, ok := a.promptQueue.Get(sessionID)
 354		if !ok {
 355			existing = []string{}
 356		}
 357		existing = append(existing, content)
 358		a.promptQueue.Set(sessionID, existing)
 359		return nil, nil
 360	}
 361
 362	genCtx, cancel := context.WithCancel(ctx)
 363	a.activeRequests.Set(sessionID, cancel)
 364	startTime := time.Now()
 365
 366	go func() {
 367		slog.Debug("Request started", "sessionID", sessionID)
 368		defer log.RecoverPanic("agent.Run", func() {
 369			events <- a.err(fmt.Errorf("panic while running the agent"))
 370		})
 371		var attachmentParts []message.ContentPart
 372		for _, attachment := range attachments {
 373			attachmentParts = append(attachmentParts, message.BinaryContent{Path: attachment.FilePath, MIMEType: attachment.MimeType, Data: attachment.Content})
 374		}
 375		result := a.processGeneration(genCtx, sessionID, content, attachmentParts)
 376		if result.Error != nil {
 377			if isCancelledErr(result.Error) {
 378				slog.Error("Request canceled", "sessionID", sessionID)
 379			} else {
 380				slog.Error("Request errored", "sessionID", sessionID, "error", result.Error.Error())
 381				event.Error(result.Error)
 382			}
 383		} else {
 384			slog.Debug("Request completed", "sessionID", sessionID)
 385		}
 386		a.eventPromptResponded(sessionID, time.Since(startTime).Truncate(time.Second))
 387		a.activeRequests.Del(sessionID)
 388		cancel()
 389		a.Publish(pubsub.CreatedEvent, result)
 390		events <- result
 391		close(events)
 392	}()
 393	a.eventPromptSent(sessionID)
 394	return events, nil
 395}
 396
 397func (a *agent) processGeneration(ctx context.Context, sessionID, content string, attachmentParts []message.ContentPart) AgentEvent {
 398	cfg := config.Get()
 399	// List existing messages; if none, start title generation asynchronously.
 400	msgs, err := a.messages.List(ctx, sessionID)
 401	if err != nil {
 402		return a.err(fmt.Errorf("failed to list messages: %w", err))
 403	}
 404	if len(msgs) == 0 {
 405		go func() {
 406			defer log.RecoverPanic("agent.Run", func() {
 407				slog.Error("panic while generating title")
 408			})
 409			titleErr := a.generateTitle(ctx, sessionID, content)
 410			if titleErr != nil && !errors.Is(titleErr, context.Canceled) && !errors.Is(titleErr, context.DeadlineExceeded) {
 411				slog.Error("failed to generate title", "error", titleErr)
 412			}
 413		}()
 414	}
 415	session, err := a.sessions.Get(ctx, sessionID)
 416	if err != nil {
 417		return a.err(fmt.Errorf("failed to get session: %w", err))
 418	}
 419	if session.SummaryMessageID != "" {
 420		summaryMsgInex := -1
 421		for i, msg := range msgs {
 422			if msg.ID == session.SummaryMessageID {
 423				summaryMsgInex = i
 424				break
 425			}
 426		}
 427		if summaryMsgInex != -1 {
 428			msgs = msgs[summaryMsgInex:]
 429			msgs[0].Role = message.User
 430		}
 431	}
 432
 433	userMsg, err := a.createUserMessage(ctx, sessionID, content, attachmentParts)
 434	if err != nil {
 435		return a.err(fmt.Errorf("failed to create user message: %w", err))
 436	}
 437	// Append the new user message to the conversation history.
 438	msgHistory := append(msgs, userMsg)
 439
 440	for {
 441		// Check for cancellation before each iteration
 442		select {
 443		case <-ctx.Done():
 444			return a.err(ctx.Err())
 445		default:
 446			// Continue processing
 447		}
 448		agentMessage, toolResults, err := a.streamAndHandleEvents(ctx, sessionID, msgHistory)
 449		if err != nil {
 450			if errors.Is(err, context.Canceled) {
 451				agentMessage.AddFinish(message.FinishReasonCanceled, "Request cancelled", "")
 452				a.messages.Update(context.Background(), agentMessage)
 453				return a.err(ErrRequestCancelled)
 454			}
 455			return a.err(fmt.Errorf("failed to process events: %w", err))
 456		}
 457		if cfg.Options.Debug {
 458			slog.Info("Result", "message", agentMessage.FinishReason(), "toolResults", toolResults)
 459		}
 460		if (agentMessage.FinishReason() == message.FinishReasonToolUse) && toolResults != nil {
 461			// We are not done, we need to respond with the tool response
 462			msgHistory = append(msgHistory, agentMessage, *toolResults)
 463			// If there are queued prompts, process the next one
 464			nextPrompt, ok := a.promptQueue.Take(sessionID)
 465			if ok {
 466				for _, prompt := range nextPrompt {
 467					// Create a new user message for the queued prompt
 468					userMsg, err := a.createUserMessage(ctx, sessionID, prompt, nil)
 469					if err != nil {
 470						return a.err(fmt.Errorf("failed to create user message for queued prompt: %w", err))
 471					}
 472					// Append the new user message to the conversation history
 473					msgHistory = append(msgHistory, userMsg)
 474				}
 475			}
 476
 477			continue
 478		} else if agentMessage.FinishReason() == message.FinishReasonEndTurn {
 479			queuePrompts, ok := a.promptQueue.Take(sessionID)
 480			if ok {
 481				for _, prompt := range queuePrompts {
 482					if prompt == "" {
 483						continue
 484					}
 485					userMsg, err := a.createUserMessage(ctx, sessionID, prompt, nil)
 486					if err != nil {
 487						return a.err(fmt.Errorf("failed to create user message for queued prompt: %w", err))
 488					}
 489					msgHistory = append(msgHistory, userMsg)
 490				}
 491				continue
 492			}
 493		}
 494		if agentMessage.FinishReason() == "" {
 495			// Kujtim: could not track down where this is happening but this means its cancelled
 496			agentMessage.AddFinish(message.FinishReasonCanceled, "Request cancelled", "")
 497			_ = a.messages.Update(context.Background(), agentMessage)
 498			return a.err(ErrRequestCancelled)
 499		}
 500		return AgentEvent{
 501			Type:    AgentEventTypeResponse,
 502			Message: agentMessage,
 503			Done:    true,
 504		}
 505	}
 506}
 507
 508func (a *agent) createUserMessage(ctx context.Context, sessionID, content string, attachmentParts []message.ContentPart) (message.Message, error) {
 509	parts := []message.ContentPart{message.TextContent{Text: content}}
 510	parts = append(parts, attachmentParts...)
 511	return a.messages.Create(ctx, sessionID, message.CreateMessageParams{
 512		Role:  message.User,
 513		Parts: parts,
 514	})
 515}
 516
 517func (a *agent) getAllTools() ([]tools.BaseTool, error) {
 518	var allTools []tools.BaseTool
 519	for tool := range a.baseTools.Seq() {
 520		if a.agentCfg.AllowedTools == nil || slices.Contains(a.agentCfg.AllowedTools, tool.Name()) {
 521			allTools = append(allTools, tool)
 522		}
 523	}
 524	if a.agentCfg.ID == "coder" {
 525		allTools = slices.AppendSeq(allTools, a.mcpTools.Seq())
 526		if a.lspClients.Len() > 0 {
 527			allTools = append(allTools, tools.NewDiagnosticsTool(a.lspClients))
 528		}
 529	}
 530	if a.agentToolFn != nil {
 531		agentTool, agentToolErr := a.agentToolFn()
 532		if agentToolErr != nil {
 533			return nil, agentToolErr
 534		}
 535		allTools = append(allTools, agentTool)
 536	}
 537	return allTools, nil
 538}
 539
 540func (a *agent) streamAndHandleEvents(ctx context.Context, sessionID string, msgHistory []message.Message) (message.Message, *message.Message, error) {
 541	ctx = context.WithValue(ctx, tools.SessionIDContextKey, sessionID)
 542
 543	// Create the assistant message first so the spinner shows immediately
 544	assistantMsg, err := a.messages.Create(ctx, sessionID, message.CreateMessageParams{
 545		Role:     message.Assistant,
 546		Parts:    []message.ContentPart{},
 547		Model:    a.Model().ID,
 548		Provider: a.providerID,
 549	})
 550	if err != nil {
 551		return assistantMsg, nil, fmt.Errorf("failed to create assistant message: %w", err)
 552	}
 553
 554	allTools, toolsErr := a.getAllTools()
 555	if toolsErr != nil {
 556		return assistantMsg, nil, toolsErr
 557	}
 558	// Now collect tools (which may block on MCP initialization)
 559	eventChan := a.provider.StreamResponse(ctx, msgHistory, allTools)
 560
 561	// Add the session and message ID into the context if needed by tools.
 562	ctx = context.WithValue(ctx, tools.MessageIDContextKey, assistantMsg.ID)
 563
 564loop:
 565	for {
 566		select {
 567		case event, ok := <-eventChan:
 568			if !ok {
 569				break loop
 570			}
 571			if processErr := a.processEvent(ctx, sessionID, &assistantMsg, event); processErr != nil {
 572				if errors.Is(processErr, context.Canceled) {
 573					a.finishMessage(context.Background(), &assistantMsg, message.FinishReasonCanceled, "Request cancelled", "")
 574				} else {
 575					a.finishMessage(ctx, &assistantMsg, message.FinishReasonError, "API Error", processErr.Error())
 576				}
 577				return assistantMsg, nil, processErr
 578			}
 579		case <-ctx.Done():
 580			a.finishMessage(context.Background(), &assistantMsg, message.FinishReasonCanceled, "Request cancelled", "")
 581			return assistantMsg, nil, ctx.Err()
 582		}
 583	}
 584
 585	toolResults := make([]message.ToolResult, len(assistantMsg.ToolCalls()))
 586	toolCalls := assistantMsg.ToolCalls()
 587	for i, toolCall := range toolCalls {
 588		select {
 589		case <-ctx.Done():
 590			a.finishMessage(context.Background(), &assistantMsg, message.FinishReasonCanceled, "Request cancelled", "")
 591			// Make all future tool calls cancelled
 592			for j := i; j < len(toolCalls); j++ {
 593				toolResults[j] = message.ToolResult{
 594					ToolCallID: toolCalls[j].ID,
 595					Content:    "Tool execution canceled by user",
 596					IsError:    true,
 597				}
 598			}
 599			goto out
 600		default:
 601			// Continue processing
 602			var tool tools.BaseTool
 603			allTools, _ = a.getAllTools()
 604			for _, availableTool := range allTools {
 605				if availableTool.Info().Name == toolCall.Name {
 606					tool = availableTool
 607					break
 608				}
 609			}
 610
 611			// Tool not found
 612			if tool == nil {
 613				toolResults[i] = message.ToolResult{
 614					ToolCallID: toolCall.ID,
 615					Content:    fmt.Sprintf("Tool not found: %s", toolCall.Name),
 616					IsError:    true,
 617				}
 618				continue
 619			}
 620
 621			// Run tool in goroutine to allow cancellation
 622			type toolExecResult struct {
 623				response tools.ToolResponse
 624				err      error
 625			}
 626			resultChan := make(chan toolExecResult, 1)
 627
 628			go func() {
 629				response, err := tool.Run(ctx, tools.ToolCall{
 630					ID:    toolCall.ID,
 631					Name:  toolCall.Name,
 632					Input: toolCall.Input,
 633				})
 634				resultChan <- toolExecResult{response: response, err: err}
 635			}()
 636
 637			var toolResponse tools.ToolResponse
 638			var toolErr error
 639
 640			select {
 641			case <-ctx.Done():
 642				a.finishMessage(context.Background(), &assistantMsg, message.FinishReasonCanceled, "Request cancelled", "")
 643				// Mark remaining tool calls as cancelled
 644				for j := i; j < len(toolCalls); j++ {
 645					toolResults[j] = message.ToolResult{
 646						ToolCallID: toolCalls[j].ID,
 647						Content:    "Tool execution canceled by user",
 648						IsError:    true,
 649					}
 650				}
 651				goto out
 652			case result := <-resultChan:
 653				toolResponse = result.response
 654				toolErr = result.err
 655			}
 656
 657			if toolErr != nil {
 658				slog.Error("Tool execution error", "toolCall", toolCall.ID, "error", toolErr)
 659				if errors.Is(toolErr, permission.ErrorPermissionDenied) {
 660					toolResults[i] = message.ToolResult{
 661						ToolCallID: toolCall.ID,
 662						Content:    "Permission denied",
 663						IsError:    true,
 664					}
 665					for j := i + 1; j < len(toolCalls); j++ {
 666						toolResults[j] = message.ToolResult{
 667							ToolCallID: toolCalls[j].ID,
 668							Content:    "Tool execution canceled by user",
 669							IsError:    true,
 670						}
 671					}
 672					a.finishMessage(ctx, &assistantMsg, message.FinishReasonPermissionDenied, "Permission denied", "")
 673					break
 674				}
 675			}
 676			toolResults[i] = message.ToolResult{
 677				ToolCallID: toolCall.ID,
 678				Content:    toolResponse.Content,
 679				Metadata:   toolResponse.Metadata,
 680				IsError:    toolResponse.IsError,
 681			}
 682		}
 683	}
 684out:
 685	if len(toolResults) == 0 {
 686		return assistantMsg, nil, nil
 687	}
 688	parts := make([]message.ContentPart, 0)
 689	for _, tr := range toolResults {
 690		parts = append(parts, tr)
 691	}
 692	msg, err := a.messages.Create(context.Background(), assistantMsg.SessionID, message.CreateMessageParams{
 693		Role:     message.Tool,
 694		Parts:    parts,
 695		Provider: a.providerID,
 696	})
 697	if err != nil {
 698		return assistantMsg, nil, fmt.Errorf("failed to create cancelled tool message: %w", err)
 699	}
 700
 701	return assistantMsg, &msg, err
 702}
 703
 704func (a *agent) finishMessage(ctx context.Context, msg *message.Message, finishReason message.FinishReason, message, details string) {
 705	msg.AddFinish(finishReason, message, details)
 706	_ = a.messages.Update(ctx, *msg)
 707}
 708
 709func (a *agent) processEvent(ctx context.Context, sessionID string, assistantMsg *message.Message, event provider.ProviderEvent) error {
 710	select {
 711	case <-ctx.Done():
 712		return ctx.Err()
 713	default:
 714		// Continue processing.
 715	}
 716
 717	switch event.Type {
 718	case provider.EventThinkingDelta:
 719		assistantMsg.AppendReasoningContent(event.Thinking)
 720		return a.messages.Update(ctx, *assistantMsg)
 721	case provider.EventSignatureDelta:
 722		assistantMsg.AppendReasoningSignature(event.Signature)
 723		return a.messages.Update(ctx, *assistantMsg)
 724	case provider.EventContentDelta:
 725		assistantMsg.FinishThinking()
 726		assistantMsg.AppendContent(event.Content)
 727		return a.messages.Update(ctx, *assistantMsg)
 728	case provider.EventToolUseStart:
 729		assistantMsg.FinishThinking()
 730		slog.Info("Tool call started", "toolCall", event.ToolCall)
 731		assistantMsg.AddToolCall(*event.ToolCall)
 732		return a.messages.Update(ctx, *assistantMsg)
 733	case provider.EventToolUseDelta:
 734		assistantMsg.AppendToolCallInput(event.ToolCall.ID, event.ToolCall.Input)
 735		return a.messages.Update(ctx, *assistantMsg)
 736	case provider.EventToolUseStop:
 737		slog.Info("Finished tool call", "toolCall", event.ToolCall)
 738		assistantMsg.FinishToolCall(event.ToolCall.ID)
 739		return a.messages.Update(ctx, *assistantMsg)
 740	case provider.EventError:
 741		return event.Error
 742	case provider.EventComplete:
 743		assistantMsg.FinishThinking()
 744		assistantMsg.SetToolCalls(event.Response.ToolCalls)
 745		assistantMsg.AddFinish(event.Response.FinishReason, "", "")
 746		if err := a.messages.Update(ctx, *assistantMsg); err != nil {
 747			return fmt.Errorf("failed to update message: %w", err)
 748		}
 749		return a.trackUsage(ctx, sessionID, a.Model(), event.Response.Usage)
 750	}
 751
 752	return nil
 753}
 754
 755func (a *agent) trackUsage(ctx context.Context, sessionID string, model catwalk.Model, usage provider.TokenUsage) error {
 756	sess, err := a.sessions.Get(ctx, sessionID)
 757	if err != nil {
 758		return fmt.Errorf("failed to get session: %w", err)
 759	}
 760
 761	cost := model.CostPer1MInCached/1e6*float64(usage.CacheCreationTokens) +
 762		model.CostPer1MOutCached/1e6*float64(usage.CacheReadTokens) +
 763		model.CostPer1MIn/1e6*float64(usage.InputTokens) +
 764		model.CostPer1MOut/1e6*float64(usage.OutputTokens)
 765
 766	a.eventTokensUsed(sessionID, usage, cost)
 767
 768	sess.Cost += cost
 769	sess.CompletionTokens = usage.OutputTokens + usage.CacheReadTokens
 770	sess.PromptTokens = usage.InputTokens + usage.CacheCreationTokens
 771
 772	_, err = a.sessions.Save(ctx, sess)
 773	if err != nil {
 774		return fmt.Errorf("failed to save session: %w", err)
 775	}
 776	return nil
 777}
 778
 779func (a *agent) Summarize(ctx context.Context, sessionID string) error {
 780	if a.summarizeProvider == nil {
 781		return fmt.Errorf("summarize provider not available")
 782	}
 783
 784	// Check if session is busy
 785	if a.IsSessionBusy(sessionID) {
 786		return ErrSessionBusy
 787	}
 788
 789	// Create a new context with cancellation
 790	summarizeCtx, cancel := context.WithCancel(ctx)
 791
 792	// Store the cancel function in activeRequests to allow cancellation
 793	a.activeRequests.Set(sessionID+"-summarize", cancel)
 794
 795	go func() {
 796		defer a.activeRequests.Del(sessionID + "-summarize")
 797		defer cancel()
 798		event := AgentEvent{
 799			Type:     AgentEventTypeSummarize,
 800			Progress: "Starting summarization...",
 801		}
 802
 803		a.Publish(pubsub.CreatedEvent, event)
 804		// Get all messages from the session
 805		msgs, err := a.messages.List(summarizeCtx, sessionID)
 806		if err != nil {
 807			event = AgentEvent{
 808				Type:  AgentEventTypeError,
 809				Error: fmt.Errorf("failed to list messages: %w", err),
 810				Done:  true,
 811			}
 812			a.Publish(pubsub.CreatedEvent, event)
 813			return
 814		}
 815		summarizeCtx = context.WithValue(summarizeCtx, tools.SessionIDContextKey, sessionID)
 816
 817		if len(msgs) == 0 {
 818			event = AgentEvent{
 819				Type:  AgentEventTypeError,
 820				Error: fmt.Errorf("no messages to summarize"),
 821				Done:  true,
 822			}
 823			a.Publish(pubsub.CreatedEvent, event)
 824			return
 825		}
 826
 827		event = AgentEvent{
 828			Type:     AgentEventTypeSummarize,
 829			Progress: "Analyzing conversation...",
 830		}
 831		a.Publish(pubsub.CreatedEvent, event)
 832
 833		// Add a system message to guide the summarization
 834		summarizePrompt := "Provide a detailed but concise summary of our conversation above. Focus on information that would be helpful for continuing the conversation, including what we did, what we're doing, which files we're working on, and what we're going to do next."
 835
 836		// Create a new message with the summarize prompt
 837		promptMsg := message.Message{
 838			Role:  message.User,
 839			Parts: []message.ContentPart{message.TextContent{Text: summarizePrompt}},
 840		}
 841
 842		// Append the prompt to the messages
 843		msgsWithPrompt := append(msgs, promptMsg)
 844
 845		event = AgentEvent{
 846			Type:     AgentEventTypeSummarize,
 847			Progress: "Generating summary...",
 848		}
 849
 850		a.Publish(pubsub.CreatedEvent, event)
 851
 852		// Send the messages to the summarize provider
 853		response := a.summarizeProvider.StreamResponse(
 854			summarizeCtx,
 855			msgsWithPrompt,
 856			nil,
 857		)
 858		var finalResponse *provider.ProviderResponse
 859		for r := range response {
 860			if r.Error != nil {
 861				event = AgentEvent{
 862					Type:  AgentEventTypeError,
 863					Error: fmt.Errorf("failed to summarize: %w", r.Error),
 864					Done:  true,
 865				}
 866				a.Publish(pubsub.CreatedEvent, event)
 867				return
 868			}
 869			finalResponse = r.Response
 870		}
 871
 872		summary := strings.TrimSpace(finalResponse.Content)
 873		if summary == "" {
 874			event = AgentEvent{
 875				Type:  AgentEventTypeError,
 876				Error: fmt.Errorf("empty summary returned"),
 877				Done:  true,
 878			}
 879			a.Publish(pubsub.CreatedEvent, event)
 880			return
 881		}
 882		shell := shell.GetPersistentShell(config.Get().WorkingDir())
 883		summary += "\n\n**Current working directory of the persistent shell**\n\n" + shell.GetWorkingDir()
 884		event = AgentEvent{
 885			Type:     AgentEventTypeSummarize,
 886			Progress: "Creating new session...",
 887		}
 888
 889		a.Publish(pubsub.CreatedEvent, event)
 890		oldSession, err := a.sessions.Get(summarizeCtx, sessionID)
 891		if err != nil {
 892			event = AgentEvent{
 893				Type:  AgentEventTypeError,
 894				Error: fmt.Errorf("failed to get session: %w", err),
 895				Done:  true,
 896			}
 897
 898			a.Publish(pubsub.CreatedEvent, event)
 899			return
 900		}
 901		// Create a message in the new session with the summary
 902		msg, err := a.messages.Create(summarizeCtx, oldSession.ID, message.CreateMessageParams{
 903			Role: message.Assistant,
 904			Parts: []message.ContentPart{
 905				message.TextContent{Text: summary},
 906				message.Finish{
 907					Reason: message.FinishReasonEndTurn,
 908					Time:   time.Now().Unix(),
 909				},
 910			},
 911			Model:    a.summarizeProvider.Model().ID,
 912			Provider: a.summarizeProviderID,
 913		})
 914		if err != nil {
 915			event = AgentEvent{
 916				Type:  AgentEventTypeError,
 917				Error: fmt.Errorf("failed to create summary message: %w", err),
 918				Done:  true,
 919			}
 920
 921			a.Publish(pubsub.CreatedEvent, event)
 922			return
 923		}
 924		oldSession.SummaryMessageID = msg.ID
 925		oldSession.CompletionTokens = finalResponse.Usage.OutputTokens
 926		oldSession.PromptTokens = 0
 927		model := a.summarizeProvider.Model()
 928		usage := finalResponse.Usage
 929		cost := model.CostPer1MInCached/1e6*float64(usage.CacheCreationTokens) +
 930			model.CostPer1MOutCached/1e6*float64(usage.CacheReadTokens) +
 931			model.CostPer1MIn/1e6*float64(usage.InputTokens) +
 932			model.CostPer1MOut/1e6*float64(usage.OutputTokens)
 933		oldSession.Cost += cost
 934		_, err = a.sessions.Save(summarizeCtx, oldSession)
 935		if err != nil {
 936			event = AgentEvent{
 937				Type:  AgentEventTypeError,
 938				Error: fmt.Errorf("failed to save session: %w", err),
 939				Done:  true,
 940			}
 941			a.Publish(pubsub.CreatedEvent, event)
 942		}
 943
 944		event = AgentEvent{
 945			Type:      AgentEventTypeSummarize,
 946			SessionID: oldSession.ID,
 947			Progress:  "Summary complete",
 948			Done:      true,
 949		}
 950		a.Publish(pubsub.CreatedEvent, event)
 951		// Send final success event with the new session ID
 952	}()
 953
 954	return nil
 955}
 956
 957func (a *agent) ClearQueue(sessionID string) {
 958	if a.QueuedPrompts(sessionID) > 0 {
 959		slog.Info("Clearing queued prompts", "session_id", sessionID)
 960		a.promptQueue.Del(sessionID)
 961	}
 962}
 963
 964func (a *agent) CancelAll() {
 965	if !a.IsBusy() {
 966		return
 967	}
 968	for key := range a.activeRequests.Seq2() {
 969		a.Cancel(key) // key is sessionID
 970	}
 971
 972	for _, cleanup := range a.cleanupFuncs {
 973		if cleanup != nil {
 974			cleanup()
 975		}
 976	}
 977
 978	timeout := time.After(5 * time.Second)
 979	for a.IsBusy() {
 980		select {
 981		case <-timeout:
 982			return
 983		default:
 984			time.Sleep(200 * time.Millisecond)
 985		}
 986	}
 987}
 988
 989func (a *agent) UpdateModel() error {
 990	cfg := config.Get()
 991
 992	// Get current provider configuration
 993	currentProviderCfg := cfg.GetProviderForModel(a.agentCfg.Model)
 994	if currentProviderCfg == nil || currentProviderCfg.ID == "" {
 995		return fmt.Errorf("provider for agent %s not found in config", a.agentCfg.Name)
 996	}
 997
 998	// Check if provider has changed
 999	if string(currentProviderCfg.ID) != a.providerID {
1000		// Provider changed, need to recreate the main provider
1001		model := cfg.GetModelByType(a.agentCfg.Model)
1002		if model.ID == "" {
1003			return fmt.Errorf("model not found for agent %s", a.agentCfg.Name)
1004		}
1005
1006		promptID := agentPromptMap[a.agentCfg.ID]
1007		if promptID == "" {
1008			promptID = prompt.PromptDefault
1009		}
1010
1011		opts := []provider.ProviderClientOption{
1012			provider.WithModel(a.agentCfg.Model),
1013			provider.WithSystemMessage(prompt.GetPrompt(promptID, currentProviderCfg.ID, cfg.Options.ContextPaths...)),
1014		}
1015
1016		newProvider, err := provider.NewProvider(*currentProviderCfg, opts...)
1017		if err != nil {
1018			return fmt.Errorf("failed to create new provider: %w", err)
1019		}
1020
1021		// Update the provider and provider ID
1022		a.provider = newProvider
1023		a.providerID = string(currentProviderCfg.ID)
1024	}
1025
1026	// Check if providers have changed for title (small) and summarize (large)
1027	smallModelCfg := cfg.Models[config.SelectedModelTypeSmall]
1028	var smallModelProviderCfg config.ProviderConfig
1029	for p := range cfg.Providers.Seq() {
1030		if p.ID == smallModelCfg.Provider {
1031			smallModelProviderCfg = p
1032			break
1033		}
1034	}
1035	if smallModelProviderCfg.ID == "" {
1036		return fmt.Errorf("provider %s not found in config", smallModelCfg.Provider)
1037	}
1038
1039	largeModelCfg := cfg.Models[config.SelectedModelTypeLarge]
1040	var largeModelProviderCfg config.ProviderConfig
1041	for p := range cfg.Providers.Seq() {
1042		if p.ID == largeModelCfg.Provider {
1043			largeModelProviderCfg = p
1044			break
1045		}
1046	}
1047	if largeModelProviderCfg.ID == "" {
1048		return fmt.Errorf("provider %s not found in config", largeModelCfg.Provider)
1049	}
1050
1051	var maxTitleTokens int64 = 40
1052
1053	// if the max output is too low for the gemini provider it won't return anything
1054	if smallModelCfg.Provider == "gemini" {
1055		maxTitleTokens = 1000
1056	}
1057	// Recreate title provider
1058	titleOpts := []provider.ProviderClientOption{
1059		provider.WithModel(config.SelectedModelTypeSmall),
1060		provider.WithSystemMessage(prompt.GetPrompt(prompt.PromptTitle, smallModelProviderCfg.ID)),
1061		provider.WithMaxTokens(maxTitleTokens),
1062	}
1063	newTitleProvider, err := provider.NewProvider(smallModelProviderCfg, titleOpts...)
1064	if err != nil {
1065		return fmt.Errorf("failed to create new title provider: %w", err)
1066	}
1067	a.titleProvider = newTitleProvider
1068
1069	// Recreate summarize provider if provider changed (now large model)
1070	if string(largeModelProviderCfg.ID) != a.summarizeProviderID {
1071		largeModel := cfg.GetModelByType(config.SelectedModelTypeLarge)
1072		if largeModel == nil {
1073			return fmt.Errorf("model %s not found in provider %s", largeModelCfg.Model, largeModelProviderCfg.ID)
1074		}
1075		summarizeOpts := []provider.ProviderClientOption{
1076			provider.WithModel(config.SelectedModelTypeLarge),
1077			provider.WithSystemMessage(prompt.GetPrompt(prompt.PromptSummarizer, largeModelProviderCfg.ID)),
1078		}
1079		newSummarizeProvider, err := provider.NewProvider(largeModelProviderCfg, summarizeOpts...)
1080		if err != nil {
1081			return fmt.Errorf("failed to create new summarize provider: %w", err)
1082		}
1083		a.summarizeProvider = newSummarizeProvider
1084		a.summarizeProviderID = string(largeModelProviderCfg.ID)
1085	}
1086
1087	return nil
1088}
1089
// setupEvents subscribes to MCP events and starts a goroutine that refreshes
// the agent's MCP tool set whenever a server reports MCPEventToolsListChanged.
//
// The goroutine exits when the subscription channel closes or the derived
// context is done. The context's cancel function is appended to
// a.cleanupFuncs so the goroutine can be stopped later.
func (a *agent) setupEvents(ctx context.Context) {
	ctx, cancel := context.WithCancel(ctx)

	// Subscribe synchronously, before spawning the goroutine, so that events
	// published right after setupEvents returns are not missed while the
	// goroutine is still being scheduled.
	subCh := SubscribeMCPEvents(ctx)

	go func() {
		for {
			select {
			case ev, ok := <-subCh:
				if !ok {
					slog.Debug("MCPEvents subscription channel closed")
					return
				}
				if ev.Payload.Type != MCPEventToolsListChanged {
					continue
				}

				name := ev.Payload.Name
				c, found := mcpClients.Get(name)
				if !found {
					slog.Warn("MCP client not found for tools update", "name", name)
					continue
				}

				clientTools, err := getTools(ctx, name, a.permissions, c, config.Get().WorkingDir())
				if err != nil {
					slog.Error("error listing tools", "error", err)
					updateMCPState(name, MCPStateError, err, nil, 0)
					// Best-effort close of the failing client; record a failure
					// instead of silently discarding it.
					if closeErr := c.Close(); closeErr != nil {
						slog.Warn("error closing MCP client", "name", name, "error", closeErr)
					}
					continue
				}

				updateMcpTools(name, clientTools)
				a.mcpTools.Reset(maps.Collect(mcpTools.Seq2()))
				// NOTE(review): this reports the agent-wide tool count
				// (a.mcpTools.Len()) rather than len(clientTools) for this one
				// client; confirm which value updateMCPState expects.
				updateMCPState(name, MCPStateConnected, nil, c, a.mcpTools.Len())
			case <-ctx.Done():
				slog.Debug("MCPEvents subscription cancelled")
				return
			}
		}
	}()

	a.cleanupFuncs = append(a.cleanupFuncs, cancel)
}