agent.go

   1package agent
   2
   3import (
   4	"context"
   5	"errors"
   6	"fmt"
   7	"log/slog"
   8	"maps"
   9	"slices"
  10	"strings"
  11	"time"
  12
  13	"github.com/charmbracelet/catwalk/pkg/catwalk"
  14	"github.com/charmbracelet/crush/internal/config"
  15	"github.com/charmbracelet/crush/internal/csync"
  16	"github.com/charmbracelet/crush/internal/history"
  17	"github.com/charmbracelet/crush/internal/llm/prompt"
  18	"github.com/charmbracelet/crush/internal/llm/provider"
  19	"github.com/charmbracelet/crush/internal/llm/tools"
  20	"github.com/charmbracelet/crush/internal/log"
  21	"github.com/charmbracelet/crush/internal/lsp"
  22	"github.com/charmbracelet/crush/internal/message"
  23	"github.com/charmbracelet/crush/internal/permission"
  24	"github.com/charmbracelet/crush/internal/pubsub"
  25	"github.com/charmbracelet/crush/internal/session"
  26	"github.com/charmbracelet/crush/internal/shell"
  27)
  28
  29// Common errors
  30var (
  31	ErrRequestCancelled = errors.New("request canceled by user")
  32	ErrSessionBusy      = errors.New("session is currently processing another request")
  33)
  34
  35type AgentEventType string
  36
  37const (
  38	AgentEventTypeError     AgentEventType = "error"
  39	AgentEventTypeResponse  AgentEventType = "response"
  40	AgentEventTypeSummarize AgentEventType = "summarize"
  41)
  42
  43type AgentEvent struct {
  44	Type    AgentEventType
  45	Message message.Message
  46	Error   error
  47
  48	// When summarizing
  49	SessionID string
  50	Progress  string
  51	Done      bool
  52}
  53
  54type Service interface {
  55	pubsub.Suscriber[AgentEvent]
  56	Model() catwalk.Model
  57	Run(ctx context.Context, sessionID string, content string, attachments ...message.Attachment) (<-chan AgentEvent, error)
  58	Cancel(sessionID string)
  59	CancelAll()
  60	IsSessionBusy(sessionID string) bool
  61	IsBusy() bool
  62	Summarize(ctx context.Context, sessionID string) error
  63	UpdateModel() error
  64	QueuedPrompts(sessionID string) int
  65	ClearQueue(sessionID string)
  66}
  67
  68type agent struct {
  69	globalCtx    context.Context
  70	cleanupFuncs []func()
  71
  72	*pubsub.Broker[AgentEvent]
  73	agentCfg config.Agent
  74	sessions session.Service
  75	messages message.Service
  76
  77	baseTools *csync.Map[string, tools.BaseTool]
  78	mcpTools  *csync.Map[string, tools.BaseTool]
  79
  80	provider   provider.Provider
  81	providerID string
  82
  83	titleProvider       provider.Provider
  84	summarizeProvider   provider.Provider
  85	summarizeProviderID string
  86
  87	activeRequests *csync.Map[string, context.CancelFunc]
  88
  89	promptQueue *csync.Map[string, []string]
  90}
  91
  92var agentPromptMap = map[string]prompt.PromptID{
  93	"coder": prompt.PromptCoder,
  94	"task":  prompt.PromptTask,
  95}
  96
  97func NewAgent(
  98	ctx context.Context,
  99	agentCfg config.Agent,
 100	// These services are needed in the tools
 101	permissions permission.Service,
 102	sessions session.Service,
 103	messages message.Service,
 104	history history.Service,
 105	lspClients map[string]*lsp.Client,
 106) (Service, error) {
 107	cfg := config.Get()
 108
 109	var agentTool tools.BaseTool
 110	if agentCfg.ID == "coder" {
 111		taskAgentCfg := config.Get().Agents["task"]
 112		if taskAgentCfg.ID == "" {
 113			return nil, fmt.Errorf("task agent not found in config")
 114		}
 115		taskAgent, err := NewAgent(ctx, taskAgentCfg, permissions, sessions, messages, history, lspClients)
 116		if err != nil {
 117			return nil, fmt.Errorf("failed to create task agent: %w", err)
 118		}
 119
 120		agentTool = NewAgentTool(taskAgent, sessions, messages)
 121	}
 122
 123	providerCfg := config.Get().GetProviderForModel(agentCfg.Model)
 124	if providerCfg == nil {
 125		return nil, fmt.Errorf("provider for agent %s not found in config", agentCfg.Name)
 126	}
 127	model := config.Get().GetModelByType(agentCfg.Model)
 128
 129	if model == nil {
 130		return nil, fmt.Errorf("model not found for agent %s", agentCfg.Name)
 131	}
 132
 133	promptID := agentPromptMap[agentCfg.ID]
 134	if promptID == "" {
 135		promptID = prompt.PromptDefault
 136	}
 137	opts := []provider.ProviderClientOption{
 138		provider.WithModel(agentCfg.Model),
 139		provider.WithSystemMessage(prompt.GetPrompt(promptID, providerCfg.ID, config.Get().Options.ContextPaths...)),
 140	}
 141	agentProvider, err := provider.NewProvider(*providerCfg, opts...)
 142	if err != nil {
 143		return nil, err
 144	}
 145
 146	smallModelCfg := cfg.Models[config.SelectedModelTypeSmall]
 147	var smallModelProviderCfg *config.ProviderConfig
 148	if smallModelCfg.Provider == providerCfg.ID {
 149		smallModelProviderCfg = providerCfg
 150	} else {
 151		smallModelProviderCfg = cfg.GetProviderForModel(config.SelectedModelTypeSmall)
 152
 153		if smallModelProviderCfg.ID == "" {
 154			return nil, fmt.Errorf("provider %s not found in config", smallModelCfg.Provider)
 155		}
 156	}
 157	smallModel := cfg.GetModelByType(config.SelectedModelTypeSmall)
 158	if smallModel.ID == "" {
 159		return nil, fmt.Errorf("model %s not found in provider %s", smallModelCfg.Model, smallModelProviderCfg.ID)
 160	}
 161
 162	titleOpts := []provider.ProviderClientOption{
 163		provider.WithModel(config.SelectedModelTypeSmall),
 164		provider.WithSystemMessage(prompt.GetPrompt(prompt.PromptTitle, smallModelProviderCfg.ID)),
 165	}
 166	titleProvider, err := provider.NewProvider(*smallModelProviderCfg, titleOpts...)
 167	if err != nil {
 168		return nil, err
 169	}
 170
 171	summarizeOpts := []provider.ProviderClientOption{
 172		provider.WithModel(config.SelectedModelTypeLarge),
 173		provider.WithSystemMessage(prompt.GetPrompt(prompt.PromptSummarizer, providerCfg.ID)),
 174	}
 175	summarizeProvider, err := provider.NewProvider(*providerCfg, summarizeOpts...)
 176	if err != nil {
 177		return nil, err
 178	}
 179
 180	baseToolsFn := func() map[string]tools.BaseTool {
 181		slog.Info("Initializing agent base tools", "agent", agentCfg.ID)
 182		defer func() {
 183			slog.Info("Initialized agent base tools", "agent", agentCfg.ID)
 184		}()
 185
 186		// Base tools available to all agents
 187		cwd := cfg.WorkingDir()
 188		result := make(map[string]tools.BaseTool)
 189		for _, tool := range []tools.BaseTool{
 190			tools.NewBashTool(permissions, cwd),
 191			tools.NewDownloadTool(permissions, cwd),
 192			tools.NewEditTool(lspClients, permissions, history, cwd),
 193			tools.NewMultiEditTool(lspClients, permissions, history, cwd),
 194			tools.NewFetchTool(permissions, cwd),
 195			tools.NewGlobTool(cwd),
 196			tools.NewGrepTool(cwd),
 197			tools.NewLsTool(permissions, cwd),
 198			tools.NewSourcegraphTool(),
 199			tools.NewViewTool(lspClients, permissions, cwd),
 200			tools.NewWriteTool(lspClients, permissions, history, cwd),
 201		} {
 202			result[tool.Name()] = tool
 203		}
 204
 205		if len(lspClients) > 0 {
 206			diagnosticsTool := tools.NewDiagnosticsTool(lspClients)
 207			result[diagnosticsTool.Name()] = diagnosticsTool
 208		}
 209
 210		if agentTool != nil {
 211			result[agentTool.Name()] = agentTool
 212		}
 213
 214		return result
 215	}
 216	mcpToolsFn := func() map[string]tools.BaseTool {
 217		slog.Info("Initializing agent mcp tools", "agent", agentCfg.ID)
 218		defer func() {
 219			slog.Info("Initialized agent mcp tools", "agent", agentCfg.ID)
 220		}()
 221
 222		mcpToolsOnce.Do(func() {
 223			doGetMCPTools(ctx, permissions, cfg)
 224		})
 225
 226		result := make(map[string]tools.BaseTool)
 227		for _, mcpTool := range mcpTools.Seq2() {
 228			result[mcpTool.Name()] = mcpTool
 229		}
 230		return result
 231	}
 232
 233	a := &agent{
 234		globalCtx:           ctx,
 235		Broker:              pubsub.NewBroker[AgentEvent](),
 236		agentCfg:            agentCfg,
 237		provider:            agentProvider,
 238		providerID:          string(providerCfg.ID),
 239		messages:            messages,
 240		sessions:            sessions,
 241		titleProvider:       titleProvider,
 242		summarizeProvider:   summarizeProvider,
 243		summarizeProviderID: string(providerCfg.ID),
 244		activeRequests:      csync.NewMap[string, context.CancelFunc](),
 245		mcpTools:            csync.NewLazyMap(mcpToolsFn),
 246		baseTools:           csync.NewLazyMap(baseToolsFn),
 247		promptQueue:         csync.NewMap[string, []string](),
 248	}
 249	a.setupEvents()
 250	return a, nil
 251}
 252
 253func (a *agent) Model() catwalk.Model {
 254	return *config.Get().GetModelByType(a.agentCfg.Model)
 255}
 256
 257func (a *agent) Cancel(sessionID string) {
 258	// Cancel regular requests
 259	if cancel, ok := a.activeRequests.Take(sessionID); ok && cancel != nil {
 260		slog.Info("Request cancellation initiated", "session_id", sessionID)
 261		cancel()
 262	}
 263
 264	// Also check for summarize requests
 265	if cancel, ok := a.activeRequests.Take(sessionID + "-summarize"); ok && cancel != nil {
 266		slog.Info("Summarize cancellation initiated", "session_id", sessionID)
 267		cancel()
 268	}
 269
 270	if a.QueuedPrompts(sessionID) > 0 {
 271		slog.Info("Clearing queued prompts", "session_id", sessionID)
 272		a.promptQueue.Del(sessionID)
 273	}
 274}
 275
 276func (a *agent) IsBusy() bool {
 277	var busy bool
 278	for cancelFunc := range a.activeRequests.Seq() {
 279		if cancelFunc != nil {
 280			busy = true
 281			break
 282		}
 283	}
 284	return busy
 285}
 286
 287func (a *agent) IsSessionBusy(sessionID string) bool {
 288	_, busy := a.activeRequests.Get(sessionID)
 289	return busy
 290}
 291
 292func (a *agent) QueuedPrompts(sessionID string) int {
 293	l, ok := a.promptQueue.Get(sessionID)
 294	if !ok {
 295		return 0
 296	}
 297	return len(l)
 298}
 299
 300func (a *agent) generateTitle(ctx context.Context, sessionID string, content string) error {
 301	if content == "" {
 302		return nil
 303	}
 304	if a.titleProvider == nil {
 305		return nil
 306	}
 307	session, err := a.sessions.Get(ctx, sessionID)
 308	if err != nil {
 309		return err
 310	}
 311	parts := []message.ContentPart{message.TextContent{
 312		Text: fmt.Sprintf("Generate a concise title for the following content:\n\n%s", content),
 313	}}
 314
 315	// Use streaming approach like summarization
 316	response := a.titleProvider.StreamResponse(
 317		ctx,
 318		[]message.Message{
 319			{
 320				Role:  message.User,
 321				Parts: parts,
 322			},
 323		},
 324		nil,
 325	)
 326
 327	var finalResponse *provider.ProviderResponse
 328	for r := range response {
 329		if r.Error != nil {
 330			return r.Error
 331		}
 332		finalResponse = r.Response
 333	}
 334
 335	if finalResponse == nil {
 336		return fmt.Errorf("no response received from title provider")
 337	}
 338
 339	title := strings.TrimSpace(strings.ReplaceAll(finalResponse.Content, "\n", " "))
 340	if title == "" {
 341		return nil
 342	}
 343
 344	session.Title = title
 345	_, err = a.sessions.Save(ctx, session)
 346	return err
 347}
 348
 349func (a *agent) err(err error) AgentEvent {
 350	return AgentEvent{
 351		Type:  AgentEventTypeError,
 352		Error: err,
 353	}
 354}
 355
 356func (a *agent) Run(ctx context.Context, sessionID string, content string, attachments ...message.Attachment) (<-chan AgentEvent, error) {
 357	if !a.Model().SupportsImages && attachments != nil {
 358		attachments = nil
 359	}
 360	events := make(chan AgentEvent)
 361	if a.IsSessionBusy(sessionID) {
 362		existing, ok := a.promptQueue.Get(sessionID)
 363		if !ok {
 364			existing = []string{}
 365		}
 366		existing = append(existing, content)
 367		a.promptQueue.Set(sessionID, existing)
 368		return nil, nil
 369	}
 370
 371	genCtx, cancel := context.WithCancel(ctx)
 372
 373	a.activeRequests.Set(sessionID, cancel)
 374	go func() {
 375		slog.Debug("Request started", "sessionID", sessionID)
 376		defer log.RecoverPanic("agent.Run", func() {
 377			events <- a.err(fmt.Errorf("panic while running the agent"))
 378		})
 379		var attachmentParts []message.ContentPart
 380		for _, attachment := range attachments {
 381			attachmentParts = append(attachmentParts, message.BinaryContent{Path: attachment.FilePath, MIMEType: attachment.MimeType, Data: attachment.Content})
 382		}
 383		result := a.processGeneration(genCtx, sessionID, content, attachmentParts)
 384		if result.Error != nil && !errors.Is(result.Error, ErrRequestCancelled) && !errors.Is(result.Error, context.Canceled) {
 385			slog.Error(result.Error.Error())
 386		}
 387		slog.Debug("Request completed", "sessionID", sessionID)
 388		a.activeRequests.Del(sessionID)
 389		cancel()
 390		a.Publish(pubsub.CreatedEvent, result)
 391		select {
 392		case events <- result:
 393		case <-genCtx.Done():
 394		}
 395		close(events)
 396	}()
 397	return events, nil
 398}
 399
 400func (a *agent) processGeneration(ctx context.Context, sessionID, content string, attachmentParts []message.ContentPart) AgentEvent {
 401	cfg := config.Get()
 402	// List existing messages; if none, start title generation asynchronously.
 403	msgs, err := a.messages.List(ctx, sessionID)
 404	if err != nil {
 405		return a.err(fmt.Errorf("failed to list messages: %w", err))
 406	}
 407	if len(msgs) == 0 {
 408		go func() {
 409			defer log.RecoverPanic("agent.Run", func() {
 410				slog.Error("panic while generating title")
 411			})
 412			titleErr := a.generateTitle(context.Background(), sessionID, content)
 413			if titleErr != nil && !errors.Is(titleErr, context.Canceled) && !errors.Is(titleErr, context.DeadlineExceeded) {
 414				slog.Error("failed to generate title", "error", titleErr)
 415			}
 416		}()
 417	}
 418	session, err := a.sessions.Get(ctx, sessionID)
 419	if err != nil {
 420		return a.err(fmt.Errorf("failed to get session: %w", err))
 421	}
 422	if session.SummaryMessageID != "" {
 423		summaryMsgInex := -1
 424		for i, msg := range msgs {
 425			if msg.ID == session.SummaryMessageID {
 426				summaryMsgInex = i
 427				break
 428			}
 429		}
 430		if summaryMsgInex != -1 {
 431			msgs = msgs[summaryMsgInex:]
 432			msgs[0].Role = message.User
 433		}
 434	}
 435
 436	userMsg, err := a.createUserMessage(ctx, sessionID, content, attachmentParts)
 437	if err != nil {
 438		return a.err(fmt.Errorf("failed to create user message: %w", err))
 439	}
 440	// Append the new user message to the conversation history.
 441	msgHistory := append(msgs, userMsg)
 442
 443	for {
 444		// Check for cancellation before each iteration
 445		select {
 446		case <-ctx.Done():
 447			return a.err(ctx.Err())
 448		default:
 449			// Continue processing
 450		}
 451		agentMessage, toolResults, err := a.streamAndHandleEvents(ctx, sessionID, msgHistory)
 452		if err != nil {
 453			if errors.Is(err, context.Canceled) {
 454				agentMessage.AddFinish(message.FinishReasonCanceled, "Request cancelled", "")
 455				a.messages.Update(context.Background(), agentMessage)
 456				return a.err(ErrRequestCancelled)
 457			}
 458			return a.err(fmt.Errorf("failed to process events: %w", err))
 459		}
 460		if cfg.Options.Debug {
 461			slog.Info("Result", "message", agentMessage.FinishReason(), "toolResults", toolResults)
 462		}
 463		if (agentMessage.FinishReason() == message.FinishReasonToolUse) && toolResults != nil {
 464			// We are not done, we need to respond with the tool response
 465			msgHistory = append(msgHistory, agentMessage, *toolResults)
 466			// If there are queued prompts, process the next one
 467			nextPrompt, ok := a.promptQueue.Take(sessionID)
 468			if ok {
 469				for _, prompt := range nextPrompt {
 470					// Create a new user message for the queued prompt
 471					userMsg, err := a.createUserMessage(ctx, sessionID, prompt, nil)
 472					if err != nil {
 473						return a.err(fmt.Errorf("failed to create user message for queued prompt: %w", err))
 474					}
 475					// Append the new user message to the conversation history
 476					msgHistory = append(msgHistory, userMsg)
 477				}
 478			}
 479
 480			continue
 481		} else if agentMessage.FinishReason() == message.FinishReasonEndTurn {
 482			queuePrompts, ok := a.promptQueue.Take(sessionID)
 483			if ok {
 484				for _, prompt := range queuePrompts {
 485					if prompt == "" {
 486						continue
 487					}
 488					userMsg, err := a.createUserMessage(ctx, sessionID, prompt, nil)
 489					if err != nil {
 490						return a.err(fmt.Errorf("failed to create user message for queued prompt: %w", err))
 491					}
 492					msgHistory = append(msgHistory, userMsg)
 493				}
 494				continue
 495			}
 496		}
 497		if agentMessage.FinishReason() == "" {
 498			// Kujtim: could not track down where this is happening but this means its cancelled
 499			agentMessage.AddFinish(message.FinishReasonCanceled, "Request cancelled", "")
 500			_ = a.messages.Update(context.Background(), agentMessage)
 501			return a.err(ErrRequestCancelled)
 502		}
 503		return AgentEvent{
 504			Type:    AgentEventTypeResponse,
 505			Message: agentMessage,
 506			Done:    true,
 507		}
 508	}
 509}
 510
 511func (a *agent) createUserMessage(ctx context.Context, sessionID, content string, attachmentParts []message.ContentPart) (message.Message, error) {
 512	parts := []message.ContentPart{message.TextContent{Text: content}}
 513	parts = append(parts, attachmentParts...)
 514	return a.messages.Create(ctx, sessionID, message.CreateMessageParams{
 515		Role:  message.User,
 516		Parts: parts,
 517	})
 518}
 519
 520func (a *agent) streamAndHandleEvents(ctx context.Context, sessionID string, msgHistory []message.Message) (message.Message, *message.Message, error) {
 521	ctx = context.WithValue(ctx, tools.SessionIDContextKey, sessionID)
 522
 523	// Create the assistant message first so the spinner shows immediately
 524	assistantMsg, err := a.messages.Create(ctx, sessionID, message.CreateMessageParams{
 525		Role:     message.Assistant,
 526		Parts:    []message.ContentPart{},
 527		Model:    a.Model().ID,
 528		Provider: a.providerID,
 529	})
 530	if err != nil {
 531		return assistantMsg, nil, fmt.Errorf("failed to create assistant message: %w", err)
 532	}
 533
 534	// Now collect tools (which may block on MCP initialization)
 535	eventChan := a.provider.StreamResponse(ctx, msgHistory, a.allTools())
 536
 537	// Add the session and message ID into the context if needed by tools.
 538	ctx = context.WithValue(ctx, tools.MessageIDContextKey, assistantMsg.ID)
 539
 540	// Process each event in the stream.
 541	for event := range eventChan {
 542		if processErr := a.processEvent(ctx, sessionID, &assistantMsg, event); processErr != nil {
 543			if errors.Is(processErr, context.Canceled) {
 544				a.finishMessage(context.Background(), &assistantMsg, message.FinishReasonCanceled, "Request cancelled", "")
 545			} else {
 546				a.finishMessage(ctx, &assistantMsg, message.FinishReasonError, "API Error", processErr.Error())
 547			}
 548			return assistantMsg, nil, processErr
 549		}
 550		if ctx.Err() != nil {
 551			a.finishMessage(context.Background(), &assistantMsg, message.FinishReasonCanceled, "Request cancelled", "")
 552			return assistantMsg, nil, ctx.Err()
 553		}
 554	}
 555
 556	toolResults := make([]message.ToolResult, len(assistantMsg.ToolCalls()))
 557	toolCalls := assistantMsg.ToolCalls()
 558	for i, toolCall := range toolCalls {
 559		select {
 560		case <-ctx.Done():
 561			a.finishMessage(context.Background(), &assistantMsg, message.FinishReasonCanceled, "Request cancelled", "")
 562			// Make all future tool calls cancelled
 563			for j := i; j < len(toolCalls); j++ {
 564				toolResults[j] = message.ToolResult{
 565					ToolCallID: toolCalls[j].ID,
 566					Content:    "Tool execution canceled by user",
 567					IsError:    true,
 568				}
 569			}
 570			goto out
 571		default:
 572			// Continue processing
 573			tool, ok := a.getToolByName(toolCall.Name)
 574
 575			// Tool not found
 576			if !ok {
 577				toolResults[i] = message.ToolResult{
 578					ToolCallID: toolCall.ID,
 579					Content:    fmt.Sprintf("Tool not found: %s", toolCall.Name),
 580					IsError:    true,
 581				}
 582				continue
 583			}
 584
 585			// Run tool in goroutine to allow cancellation
 586			type toolExecResult struct {
 587				response tools.ToolResponse
 588				err      error
 589			}
 590			resultChan := make(chan toolExecResult, 1)
 591
 592			go func() {
 593				response, err := tool.Run(ctx, tools.ToolCall{
 594					ID:    toolCall.ID,
 595					Name:  toolCall.Name,
 596					Input: toolCall.Input,
 597				})
 598				resultChan <- toolExecResult{response: response, err: err}
 599			}()
 600
 601			var toolResponse tools.ToolResponse
 602			var toolErr error
 603
 604			select {
 605			case <-ctx.Done():
 606				a.finishMessage(context.Background(), &assistantMsg, message.FinishReasonCanceled, "Request cancelled", "")
 607				// Mark remaining tool calls as cancelled
 608				for j := i; j < len(toolCalls); j++ {
 609					toolResults[j] = message.ToolResult{
 610						ToolCallID: toolCalls[j].ID,
 611						Content:    "Tool execution canceled by user",
 612						IsError:    true,
 613					}
 614				}
 615				goto out
 616			case result := <-resultChan:
 617				toolResponse = result.response
 618				toolErr = result.err
 619			}
 620
 621			if toolErr != nil {
 622				slog.Error("Tool execution error", "toolCall", toolCall.ID, "error", toolErr)
 623				if errors.Is(toolErr, permission.ErrorPermissionDenied) {
 624					toolResults[i] = message.ToolResult{
 625						ToolCallID: toolCall.ID,
 626						Content:    "Permission denied",
 627						IsError:    true,
 628					}
 629					for j := i + 1; j < len(toolCalls); j++ {
 630						toolResults[j] = message.ToolResult{
 631							ToolCallID: toolCalls[j].ID,
 632							Content:    "Tool execution canceled by user",
 633							IsError:    true,
 634						}
 635					}
 636					a.finishMessage(ctx, &assistantMsg, message.FinishReasonPermissionDenied, "Permission denied", "")
 637					break
 638				}
 639			}
 640			toolResults[i] = message.ToolResult{
 641				ToolCallID: toolCall.ID,
 642				Content:    toolResponse.Content,
 643				Metadata:   toolResponse.Metadata,
 644				IsError:    toolResponse.IsError,
 645			}
 646		}
 647	}
 648out:
 649	if len(toolResults) == 0 {
 650		return assistantMsg, nil, nil
 651	}
 652	parts := make([]message.ContentPart, 0)
 653	for _, tr := range toolResults {
 654		parts = append(parts, tr)
 655	}
 656	msg, err := a.messages.Create(context.Background(), assistantMsg.SessionID, message.CreateMessageParams{
 657		Role:     message.Tool,
 658		Parts:    parts,
 659		Provider: a.providerID,
 660	})
 661	if err != nil {
 662		return assistantMsg, nil, fmt.Errorf("failed to create cancelled tool message: %w", err)
 663	}
 664
 665	return assistantMsg, &msg, err
 666}
 667
 668func (a *agent) finishMessage(ctx context.Context, msg *message.Message, finishReason message.FinishReason, message, details string) {
 669	msg.AddFinish(finishReason, message, details)
 670	_ = a.messages.Update(ctx, *msg)
 671}
 672
 673func (a *agent) processEvent(ctx context.Context, sessionID string, assistantMsg *message.Message, event provider.ProviderEvent) error {
 674	select {
 675	case <-ctx.Done():
 676		return ctx.Err()
 677	default:
 678		// Continue processing.
 679	}
 680
 681	switch event.Type {
 682	case provider.EventThinkingDelta:
 683		assistantMsg.AppendReasoningContent(event.Thinking)
 684		return a.messages.Update(ctx, *assistantMsg)
 685	case provider.EventSignatureDelta:
 686		assistantMsg.AppendReasoningSignature(event.Signature)
 687		return a.messages.Update(ctx, *assistantMsg)
 688	case provider.EventContentDelta:
 689		assistantMsg.FinishThinking()
 690		assistantMsg.AppendContent(event.Content)
 691		return a.messages.Update(ctx, *assistantMsg)
 692	case provider.EventToolUseStart:
 693		assistantMsg.FinishThinking()
 694		slog.Info("Tool call started", "toolCall", event.ToolCall)
 695		assistantMsg.AddToolCall(*event.ToolCall)
 696		return a.messages.Update(ctx, *assistantMsg)
 697	case provider.EventToolUseDelta:
 698		assistantMsg.AppendToolCallInput(event.ToolCall.ID, event.ToolCall.Input)
 699		return a.messages.Update(ctx, *assistantMsg)
 700	case provider.EventToolUseStop:
 701		slog.Info("Finished tool call", "toolCall", event.ToolCall)
 702		assistantMsg.FinishToolCall(event.ToolCall.ID)
 703		return a.messages.Update(ctx, *assistantMsg)
 704	case provider.EventError:
 705		return event.Error
 706	case provider.EventComplete:
 707		assistantMsg.FinishThinking()
 708		assistantMsg.SetToolCalls(event.Response.ToolCalls)
 709		assistantMsg.AddFinish(event.Response.FinishReason, "", "")
 710		if err := a.messages.Update(ctx, *assistantMsg); err != nil {
 711			return fmt.Errorf("failed to update message: %w", err)
 712		}
 713		return a.TrackUsage(ctx, sessionID, a.Model(), event.Response.Usage)
 714	}
 715
 716	return nil
 717}
 718
 719func (a *agent) TrackUsage(ctx context.Context, sessionID string, model catwalk.Model, usage provider.TokenUsage) error {
 720	sess, err := a.sessions.Get(ctx, sessionID)
 721	if err != nil {
 722		return fmt.Errorf("failed to get session: %w", err)
 723	}
 724
 725	cost := model.CostPer1MInCached/1e6*float64(usage.CacheCreationTokens) +
 726		model.CostPer1MOutCached/1e6*float64(usage.CacheReadTokens) +
 727		model.CostPer1MIn/1e6*float64(usage.InputTokens) +
 728		model.CostPer1MOut/1e6*float64(usage.OutputTokens)
 729
 730	sess.Cost += cost
 731	sess.CompletionTokens = usage.OutputTokens + usage.CacheReadTokens
 732	sess.PromptTokens = usage.InputTokens + usage.CacheCreationTokens
 733
 734	_, err = a.sessions.Save(ctx, sess)
 735	if err != nil {
 736		return fmt.Errorf("failed to save session: %w", err)
 737	}
 738	return nil
 739}
 740
 741func (a *agent) Summarize(ctx context.Context, sessionID string) error {
 742	if a.summarizeProvider == nil {
 743		return fmt.Errorf("summarize provider not available")
 744	}
 745
 746	// Check if session is busy
 747	if a.IsSessionBusy(sessionID) {
 748		return ErrSessionBusy
 749	}
 750
 751	// Create a new context with cancellation
 752	summarizeCtx, cancel := context.WithCancel(ctx)
 753
 754	// Store the cancel function in activeRequests to allow cancellation
 755	a.activeRequests.Set(sessionID+"-summarize", cancel)
 756
 757	go func() {
 758		defer a.activeRequests.Del(sessionID + "-summarize")
 759		defer cancel()
 760		event := AgentEvent{
 761			Type:     AgentEventTypeSummarize,
 762			Progress: "Starting summarization...",
 763		}
 764
 765		a.Publish(pubsub.CreatedEvent, event)
 766		// Get all messages from the session
 767		msgs, err := a.messages.List(summarizeCtx, sessionID)
 768		if err != nil {
 769			event = AgentEvent{
 770				Type:  AgentEventTypeError,
 771				Error: fmt.Errorf("failed to list messages: %w", err),
 772				Done:  true,
 773			}
 774			a.Publish(pubsub.CreatedEvent, event)
 775			return
 776		}
 777		summarizeCtx = context.WithValue(summarizeCtx, tools.SessionIDContextKey, sessionID)
 778
 779		if len(msgs) == 0 {
 780			event = AgentEvent{
 781				Type:  AgentEventTypeError,
 782				Error: fmt.Errorf("no messages to summarize"),
 783				Done:  true,
 784			}
 785			a.Publish(pubsub.CreatedEvent, event)
 786			return
 787		}
 788
 789		event = AgentEvent{
 790			Type:     AgentEventTypeSummarize,
 791			Progress: "Analyzing conversation...",
 792		}
 793		a.Publish(pubsub.CreatedEvent, event)
 794
 795		// Add a system message to guide the summarization
 796		summarizePrompt := "Provide a detailed but concise summary of our conversation above. Focus on information that would be helpful for continuing the conversation, including what we did, what we're doing, which files we're working on, and what we're going to do next."
 797
 798		// Create a new message with the summarize prompt
 799		promptMsg := message.Message{
 800			Role:  message.User,
 801			Parts: []message.ContentPart{message.TextContent{Text: summarizePrompt}},
 802		}
 803
 804		// Append the prompt to the messages
 805		msgsWithPrompt := append(msgs, promptMsg)
 806
 807		event = AgentEvent{
 808			Type:     AgentEventTypeSummarize,
 809			Progress: "Generating summary...",
 810		}
 811
 812		a.Publish(pubsub.CreatedEvent, event)
 813
 814		// Send the messages to the summarize provider
 815		response := a.summarizeProvider.StreamResponse(
 816			summarizeCtx,
 817			msgsWithPrompt,
 818			nil,
 819		)
 820		var finalResponse *provider.ProviderResponse
 821		for r := range response {
 822			if r.Error != nil {
 823				event = AgentEvent{
 824					Type:  AgentEventTypeError,
 825					Error: fmt.Errorf("failed to summarize: %w", err),
 826					Done:  true,
 827				}
 828				a.Publish(pubsub.CreatedEvent, event)
 829				return
 830			}
 831			finalResponse = r.Response
 832		}
 833
 834		summary := strings.TrimSpace(finalResponse.Content)
 835		if summary == "" {
 836			event = AgentEvent{
 837				Type:  AgentEventTypeError,
 838				Error: fmt.Errorf("empty summary returned"),
 839				Done:  true,
 840			}
 841			a.Publish(pubsub.CreatedEvent, event)
 842			return
 843		}
 844		shell := shell.GetPersistentShell(config.Get().WorkingDir())
 845		summary += "\n\n**Current working directory of the persistent shell**\n\n" + shell.GetWorkingDir()
 846		event = AgentEvent{
 847			Type:     AgentEventTypeSummarize,
 848			Progress: "Creating new session...",
 849		}
 850
 851		a.Publish(pubsub.CreatedEvent, event)
 852		oldSession, err := a.sessions.Get(summarizeCtx, sessionID)
 853		if err != nil {
 854			event = AgentEvent{
 855				Type:  AgentEventTypeError,
 856				Error: fmt.Errorf("failed to get session: %w", err),
 857				Done:  true,
 858			}
 859
 860			a.Publish(pubsub.CreatedEvent, event)
 861			return
 862		}
 863		// Create a message in the new session with the summary
 864		msg, err := a.messages.Create(summarizeCtx, oldSession.ID, message.CreateMessageParams{
 865			Role: message.Assistant,
 866			Parts: []message.ContentPart{
 867				message.TextContent{Text: summary},
 868				message.Finish{
 869					Reason: message.FinishReasonEndTurn,
 870					Time:   time.Now().Unix(),
 871				},
 872			},
 873			Model:    a.summarizeProvider.Model().ID,
 874			Provider: a.summarizeProviderID,
 875		})
 876		if err != nil {
 877			event = AgentEvent{
 878				Type:  AgentEventTypeError,
 879				Error: fmt.Errorf("failed to create summary message: %w", err),
 880				Done:  true,
 881			}
 882
 883			a.Publish(pubsub.CreatedEvent, event)
 884			return
 885		}
 886		oldSession.SummaryMessageID = msg.ID
 887		oldSession.CompletionTokens = finalResponse.Usage.OutputTokens
 888		oldSession.PromptTokens = 0
 889		model := a.summarizeProvider.Model()
 890		usage := finalResponse.Usage
 891		cost := model.CostPer1MInCached/1e6*float64(usage.CacheCreationTokens) +
 892			model.CostPer1MOutCached/1e6*float64(usage.CacheReadTokens) +
 893			model.CostPer1MIn/1e6*float64(usage.InputTokens) +
 894			model.CostPer1MOut/1e6*float64(usage.OutputTokens)
 895		oldSession.Cost += cost
 896		_, err = a.sessions.Save(summarizeCtx, oldSession)
 897		if err != nil {
 898			event = AgentEvent{
 899				Type:  AgentEventTypeError,
 900				Error: fmt.Errorf("failed to save session: %w", err),
 901				Done:  true,
 902			}
 903			a.Publish(pubsub.CreatedEvent, event)
 904		}
 905
 906		event = AgentEvent{
 907			Type:      AgentEventTypeSummarize,
 908			SessionID: oldSession.ID,
 909			Progress:  "Summary complete",
 910			Done:      true,
 911		}
 912		a.Publish(pubsub.CreatedEvent, event)
 913		// Send final success event with the new session ID
 914	}()
 915
 916	return nil
 917}
 918
 919func (a *agent) ClearQueue(sessionID string) {
 920	if a.QueuedPrompts(sessionID) > 0 {
 921		slog.Info("Clearing queued prompts", "session_id", sessionID)
 922		a.promptQueue.Del(sessionID)
 923	}
 924}
 925
 926func (a *agent) CancelAll() {
 927	if !a.IsBusy() {
 928		return
 929	}
 930	for key := range a.activeRequests.Seq2() {
 931		a.Cancel(key) // key is sessionID
 932	}
 933
 934	for _, cleanup := range a.cleanupFuncs {
 935		if cleanup != nil {
 936			cleanup()
 937		}
 938	}
 939
 940	timeout := time.After(5 * time.Second)
 941	for a.IsBusy() {
 942		select {
 943		case <-timeout:
 944			return
 945		default:
 946			time.Sleep(200 * time.Millisecond)
 947		}
 948	}
 949}
 950
 951func (a *agent) UpdateModel() error {
 952	cfg := config.Get()
 953
 954	// Get current provider configuration
 955	currentProviderCfg := cfg.GetProviderForModel(a.agentCfg.Model)
 956	if currentProviderCfg == nil || currentProviderCfg.ID == "" {
 957		return fmt.Errorf("provider for agent %s not found in config", a.agentCfg.Name)
 958	}
 959
 960	// Check if provider has changed
 961	if string(currentProviderCfg.ID) != a.providerID {
 962		// Provider changed, need to recreate the main provider
 963		model := cfg.GetModelByType(a.agentCfg.Model)
 964		if model.ID == "" {
 965			return fmt.Errorf("model not found for agent %s", a.agentCfg.Name)
 966		}
 967
 968		promptID := agentPromptMap[a.agentCfg.ID]
 969		if promptID == "" {
 970			promptID = prompt.PromptDefault
 971		}
 972
 973		opts := []provider.ProviderClientOption{
 974			provider.WithModel(a.agentCfg.Model),
 975			provider.WithSystemMessage(prompt.GetPrompt(promptID, currentProviderCfg.ID, cfg.Options.ContextPaths...)),
 976		}
 977
 978		newProvider, err := provider.NewProvider(*currentProviderCfg, opts...)
 979		if err != nil {
 980			return fmt.Errorf("failed to create new provider: %w", err)
 981		}
 982
 983		// Update the provider and provider ID
 984		a.provider = newProvider
 985		a.providerID = string(currentProviderCfg.ID)
 986	}
 987
 988	// Check if providers have changed for title (small) and summarize (large)
 989	smallModelCfg := cfg.Models[config.SelectedModelTypeSmall]
 990	var smallModelProviderCfg config.ProviderConfig
 991	for p := range cfg.Providers.Seq() {
 992		if p.ID == smallModelCfg.Provider {
 993			smallModelProviderCfg = p
 994			break
 995		}
 996	}
 997	if smallModelProviderCfg.ID == "" {
 998		return fmt.Errorf("provider %s not found in config", smallModelCfg.Provider)
 999	}
1000
1001	largeModelCfg := cfg.Models[config.SelectedModelTypeLarge]
1002	var largeModelProviderCfg config.ProviderConfig
1003	for p := range cfg.Providers.Seq() {
1004		if p.ID == largeModelCfg.Provider {
1005			largeModelProviderCfg = p
1006			break
1007		}
1008	}
1009	if largeModelProviderCfg.ID == "" {
1010		return fmt.Errorf("provider %s not found in config", largeModelCfg.Provider)
1011	}
1012
1013	// Recreate title provider
1014	titleOpts := []provider.ProviderClientOption{
1015		provider.WithModel(config.SelectedModelTypeSmall),
1016		provider.WithSystemMessage(prompt.GetPrompt(prompt.PromptTitle, smallModelProviderCfg.ID)),
1017		provider.WithMaxTokens(40),
1018	}
1019	newTitleProvider, err := provider.NewProvider(smallModelProviderCfg, titleOpts...)
1020	if err != nil {
1021		return fmt.Errorf("failed to create new title provider: %w", err)
1022	}
1023	a.titleProvider = newTitleProvider
1024
1025	// Recreate summarize provider if provider changed (now large model)
1026	if string(largeModelProviderCfg.ID) != a.summarizeProviderID {
1027		largeModel := cfg.GetModelByType(config.SelectedModelTypeLarge)
1028		if largeModel == nil {
1029			return fmt.Errorf("model %s not found in provider %s", largeModelCfg.Model, largeModelProviderCfg.ID)
1030		}
1031		summarizeOpts := []provider.ProviderClientOption{
1032			provider.WithModel(config.SelectedModelTypeLarge),
1033			provider.WithSystemMessage(prompt.GetPrompt(prompt.PromptSummarizer, largeModelProviderCfg.ID)),
1034		}
1035		newSummarizeProvider, err := provider.NewProvider(largeModelProviderCfg, summarizeOpts...)
1036		if err != nil {
1037			return fmt.Errorf("failed to create new summarize provider: %w", err)
1038		}
1039		a.summarizeProvider = newSummarizeProvider
1040		a.summarizeProviderID = string(largeModelProviderCfg.ID)
1041	}
1042
1043	return nil
1044}
1045
1046func (a *agent) allTools() []tools.BaseTool {
1047	result := slices.Collect(a.baseTools.Seq())
1048	result = slices.AppendSeq(result, a.mcpTools.Seq())
1049	return result
1050}
1051
1052func (a *agent) getToolByName(name string) (tools.BaseTool, bool) {
1053	tool, ok := a.baseTools.Get(name)
1054	if !ok {
1055		tool, ok = a.mcpTools.Get(name)
1056	}
1057	return tool, ok
1058}
1059
1060func (a *agent) setupEvents() {
1061	ctx, cancel := context.WithCancel(a.globalCtx)
1062
1063	go func() {
1064		for event := range SubscribeMCPEvents(ctx) {
1065			switch event.Payload.Type {
1066			case MCPEventToolsListChanged:
1067				name := event.Payload.Name
1068				c, ok := mcpClients.Get(name)
1069				if !ok {
1070					slog.Warn("MCP client not found for tools update", "name", name)
1071					continue
1072				}
1073				tools := getTools(ctx, name, c)
1074				updateMcpTools(name, tools)
1075				// Update the lazy map with the new tools
1076				a.mcpTools = csync.NewMapFrom(maps.Collect(mcpTools.Seq2()))
1077				updateMCPState(name, MCPStateConnected, nil, c, a.mcpTools.Len())
1078			default:
1079				continue
1080			}
1081		}
1082	}()
1083
1084	a.cleanupFuncs = append(a.cleanupFuncs, func() {
1085		cancel()
1086	})
1087}