agent.go

   1package agent
   2
   3import (
   4	"context"
   5	"errors"
   6	"fmt"
   7	"log/slog"
   8	"slices"
   9	"strings"
  10	"time"
  11
  12	"github.com/charmbracelet/catwalk/pkg/catwalk"
  13	"github.com/charmbracelet/crush/internal/config"
  14	"github.com/charmbracelet/crush/internal/csync"
  15	"github.com/charmbracelet/crush/internal/history"
  16	"github.com/charmbracelet/crush/internal/llm/prompt"
  17	"github.com/charmbracelet/crush/internal/llm/provider"
  18	"github.com/charmbracelet/crush/internal/llm/tools"
  19	"github.com/charmbracelet/crush/internal/log"
  20	"github.com/charmbracelet/crush/internal/lsp"
  21	"github.com/charmbracelet/crush/internal/message"
  22	"github.com/charmbracelet/crush/internal/permission"
  23	"github.com/charmbracelet/crush/internal/pubsub"
  24	"github.com/charmbracelet/crush/internal/session"
  25	"github.com/charmbracelet/crush/internal/shell"
  26)
  27
  28// Common errors
  29var (
  30	ErrRequestCancelled = errors.New("request canceled by user")
  31	ErrSessionBusy      = errors.New("session is currently processing another request")
  32)
  33
  34type AgentEventType string
  35
  36const (
  37	AgentEventTypeError     AgentEventType = "error"
  38	AgentEventTypeResponse  AgentEventType = "response"
  39	AgentEventTypeSummarize AgentEventType = "summarize"
  40)
  41
  42type AgentEvent struct {
  43	Type    AgentEventType
  44	Message message.Message
  45	Error   error
  46
  47	// When summarizing
  48	SessionID string
  49	Progress  string
  50	Done      bool
  51}
  52
  53type Service interface {
  54	pubsub.Suscriber[AgentEvent]
  55	Model() catwalk.Model
  56	Run(ctx context.Context, sessionID string, content string, attachments ...message.Attachment) (<-chan AgentEvent, error)
  57	Cancel(sessionID string)
  58	CancelAll()
  59	IsSessionBusy(sessionID string) bool
  60	IsBusy() bool
  61	Summarize(ctx context.Context, sessionID string) error
  62	UpdateModel() error
  63	QueuedPrompts(sessionID string) int
  64	ClearQueue(sessionID string)
  65}
  66
  67type agent struct {
  68	*pubsub.Broker[AgentEvent]
  69	agentCfg config.Agent
  70	sessions session.Service
  71	messages message.Service
  72	mcpTools []McpTool
  73
  74	tools *csync.LazySlice[tools.BaseTool]
  75
  76	provider   provider.Provider
  77	providerID string
  78
  79	titleProvider       provider.Provider
  80	summarizeProvider   provider.Provider
  81	summarizeProviderID string
  82
  83	activeRequests *csync.Map[string, context.CancelFunc]
  84
  85	promptQueue *csync.Map[string, []string]
  86}
  87
  88var agentPromptMap = map[string]prompt.PromptID{
  89	"coder": prompt.PromptCoder,
  90	"task":  prompt.PromptTask,
  91}
  92
  93func NewAgent(
  94	ctx context.Context,
  95	agentCfg config.Agent,
  96	// These services are needed in the tools
  97	permissions permission.Service,
  98	sessions session.Service,
  99	messages message.Service,
 100	history history.Service,
 101	lspClients map[string]*lsp.Client,
 102) (Service, error) {
 103	cfg := config.Get()
 104
 105	var agentTool tools.BaseTool
 106	if agentCfg.ID == "coder" {
 107		taskAgentCfg := config.Get().Agents["task"]
 108		if taskAgentCfg.ID == "" {
 109			return nil, fmt.Errorf("task agent not found in config")
 110		}
 111		taskAgent, err := NewAgent(ctx, taskAgentCfg, permissions, sessions, messages, history, lspClients)
 112		if err != nil {
 113			return nil, fmt.Errorf("failed to create task agent: %w", err)
 114		}
 115
 116		agentTool = NewAgentTool(taskAgent, sessions, messages)
 117	}
 118
 119	providerCfg := config.Get().GetProviderForModel(agentCfg.Model)
 120	if providerCfg == nil {
 121		return nil, fmt.Errorf("provider for agent %s not found in config", agentCfg.Name)
 122	}
 123	model := config.Get().GetModelByType(agentCfg.Model)
 124
 125	if model == nil {
 126		return nil, fmt.Errorf("model not found for agent %s", agentCfg.Name)
 127	}
 128
 129	promptID := agentPromptMap[agentCfg.ID]
 130	if promptID == "" {
 131		promptID = prompt.PromptDefault
 132	}
 133	opts := []provider.ProviderClientOption{
 134		provider.WithModel(agentCfg.Model),
 135		provider.WithSystemMessage(prompt.GetPrompt(promptID, providerCfg.ID, config.Get().Options.ContextPaths...)),
 136	}
 137	agentProvider, err := provider.NewProvider(*providerCfg, opts...)
 138	if err != nil {
 139		return nil, err
 140	}
 141
 142	smallModelCfg := cfg.Models[config.SelectedModelTypeSmall]
 143	var smallModelProviderCfg *config.ProviderConfig
 144	if smallModelCfg.Provider == providerCfg.ID {
 145		smallModelProviderCfg = providerCfg
 146	} else {
 147		smallModelProviderCfg = cfg.GetProviderForModel(config.SelectedModelTypeSmall)
 148
 149		if smallModelProviderCfg.ID == "" {
 150			return nil, fmt.Errorf("provider %s not found in config", smallModelCfg.Provider)
 151		}
 152	}
 153	smallModel := cfg.GetModelByType(config.SelectedModelTypeSmall)
 154	if smallModel.ID == "" {
 155		return nil, fmt.Errorf("model %s not found in provider %s", smallModelCfg.Model, smallModelProviderCfg.ID)
 156	}
 157
 158	titleOpts := []provider.ProviderClientOption{
 159		provider.WithModel(config.SelectedModelTypeSmall),
 160		provider.WithSystemMessage(prompt.GetPrompt(prompt.PromptTitle, smallModelProviderCfg.ID)),
 161	}
 162	titleProvider, err := provider.NewProvider(*smallModelProviderCfg, titleOpts...)
 163	if err != nil {
 164		return nil, err
 165	}
 166
 167	summarizeOpts := []provider.ProviderClientOption{
 168		provider.WithModel(config.SelectedModelTypeLarge),
 169		provider.WithSystemMessage(prompt.GetPrompt(prompt.PromptSummarizer, providerCfg.ID)),
 170	}
 171	summarizeProvider, err := provider.NewProvider(*providerCfg, summarizeOpts...)
 172	if err != nil {
 173		return nil, err
 174	}
 175
 176	toolFn := func() []tools.BaseTool {
 177		slog.Info("Initializing agent tools", "agent", agentCfg.ID)
 178		defer func() {
 179			slog.Info("Initialized agent tools", "agent", agentCfg.ID)
 180		}()
 181
 182		cwd := cfg.WorkingDir()
 183		allTools := []tools.BaseTool{
 184			tools.NewBashTool(permissions, cwd),
 185			tools.NewDownloadTool(permissions, cwd),
 186			tools.NewEditTool(lspClients, permissions, history, cwd),
 187			tools.NewMultiEditTool(lspClients, permissions, history, cwd),
 188			tools.NewFetchTool(permissions, cwd),
 189			tools.NewGlobTool(cwd),
 190			tools.NewGrepTool(cwd),
 191			tools.NewLsTool(permissions, cwd),
 192			tools.NewSourcegraphTool(),
 193			tools.NewViewTool(lspClients, permissions, cwd),
 194			tools.NewWriteTool(lspClients, permissions, history, cwd),
 195		}
 196
 197		mcpToolsOnce.Do(func() {
 198			mcpTools = doGetMCPTools(ctx, permissions, cfg)
 199		})
 200		allTools = append(allTools, mcpTools...)
 201
 202		if len(lspClients) > 0 {
 203			allTools = append(allTools, tools.NewDiagnosticsTool(lspClients))
 204		}
 205
 206		if agentTool != nil {
 207			allTools = append(allTools, agentTool)
 208		}
 209
 210		if agentCfg.AllowedTools == nil {
 211			return allTools
 212		}
 213
 214		var filteredTools []tools.BaseTool
 215		for _, tool := range allTools {
 216			if slices.Contains(agentCfg.AllowedTools, tool.Name()) {
 217				filteredTools = append(filteredTools, tool)
 218			}
 219		}
 220		return filteredTools
 221	}
 222
 223	return &agent{
 224		Broker:              pubsub.NewBroker[AgentEvent](),
 225		agentCfg:            agentCfg,
 226		provider:            agentProvider,
 227		providerID:          string(providerCfg.ID),
 228		messages:            messages,
 229		sessions:            sessions,
 230		titleProvider:       titleProvider,
 231		summarizeProvider:   summarizeProvider,
 232		summarizeProviderID: string(providerCfg.ID),
 233		activeRequests:      csync.NewMap[string, context.CancelFunc](),
 234		tools:               csync.NewLazySlice(toolFn),
 235		promptQueue:         csync.NewMap[string, []string](),
 236	}, nil
 237}
 238
 239func (a *agent) Model() catwalk.Model {
 240	return *config.Get().GetModelByType(a.agentCfg.Model)
 241}
 242
 243func (a *agent) Cancel(sessionID string) {
 244	// Cancel regular requests
 245	if cancel, ok := a.activeRequests.Take(sessionID); ok && cancel != nil {
 246		slog.Info("Request cancellation initiated", "session_id", sessionID)
 247		cancel()
 248	}
 249
 250	// Also check for summarize requests
 251	if cancel, ok := a.activeRequests.Take(sessionID + "-summarize"); ok && cancel != nil {
 252		slog.Info("Summarize cancellation initiated", "session_id", sessionID)
 253		cancel()
 254	}
 255
 256	if a.QueuedPrompts(sessionID) > 0 {
 257		slog.Info("Clearing queued prompts", "session_id", sessionID)
 258		a.promptQueue.Del(sessionID)
 259	}
 260}
 261
 262func (a *agent) IsBusy() bool {
 263	var busy bool
 264	for cancelFunc := range a.activeRequests.Seq() {
 265		if cancelFunc != nil {
 266			busy = true
 267			break
 268		}
 269	}
 270	return busy
 271}
 272
 273func (a *agent) IsSessionBusy(sessionID string) bool {
 274	_, busy := a.activeRequests.Get(sessionID)
 275	return busy
 276}
 277
 278func (a *agent) QueuedPrompts(sessionID string) int {
 279	l, ok := a.promptQueue.Get(sessionID)
 280	if !ok {
 281		return 0
 282	}
 283	return len(l)
 284}
 285
 286func (a *agent) generateTitle(ctx context.Context, sessionID string, content string) error {
 287	if content == "" {
 288		return nil
 289	}
 290	if a.titleProvider == nil {
 291		return nil
 292	}
 293	session, err := a.sessions.Get(ctx, sessionID)
 294	if err != nil {
 295		return err
 296	}
 297	parts := []message.ContentPart{message.TextContent{
 298		Text: fmt.Sprintf("Generate a concise title for the following content:\n\n%s", content),
 299	}}
 300
 301	// Use streaming approach like summarization
 302	response := a.titleProvider.StreamResponse(
 303		ctx,
 304		[]message.Message{
 305			{
 306				Role:  message.User,
 307				Parts: parts,
 308			},
 309		},
 310		nil,
 311	)
 312
 313	var finalResponse *provider.ProviderResponse
 314	for r := range response {
 315		if r.Error != nil {
 316			return r.Error
 317		}
 318		finalResponse = r.Response
 319	}
 320
 321	if finalResponse == nil {
 322		return fmt.Errorf("no response received from title provider")
 323	}
 324
 325	title := strings.TrimSpace(strings.ReplaceAll(finalResponse.Content, "\n", " "))
 326	if title == "" {
 327		return nil
 328	}
 329
 330	session.Title = title
 331	_, err = a.sessions.Save(ctx, session)
 332	return err
 333}
 334
 335func (a *agent) err(err error) AgentEvent {
 336	return AgentEvent{
 337		Type:  AgentEventTypeError,
 338		Error: err,
 339	}
 340}
 341
 342func (a *agent) Run(ctx context.Context, sessionID string, content string, attachments ...message.Attachment) (<-chan AgentEvent, error) {
 343	if !a.Model().SupportsImages && attachments != nil {
 344		attachments = nil
 345	}
 346	events := make(chan AgentEvent)
 347	if a.IsSessionBusy(sessionID) {
 348		existing, ok := a.promptQueue.Get(sessionID)
 349		if !ok {
 350			existing = []string{}
 351		}
 352		existing = append(existing, content)
 353		a.promptQueue.Set(sessionID, existing)
 354		return nil, nil
 355	}
 356
 357	genCtx, cancel := context.WithCancel(ctx)
 358	defer cancel() // Ensure cancel is always called
 359
 360	a.activeRequests.Set(sessionID, cancel)
 361	defer a.activeRequests.Del(sessionID) // Clean up on exit
 362
 363	go func() {
 364		slog.Debug("Request started", "sessionID", sessionID)
 365		defer log.RecoverPanic("agent.Run", func() {
 366			events <- a.err(fmt.Errorf("panic while running the agent"))
 367		})
 368		var attachmentParts []message.ContentPart
 369		for _, attachment := range attachments {
 370			attachmentParts = append(attachmentParts, message.BinaryContent{Path: attachment.FilePath, MIMEType: attachment.MimeType, Data: attachment.Content})
 371		}
 372		result := a.processGeneration(genCtx, sessionID, content, attachmentParts)
 373		if result.Error != nil && !errors.Is(result.Error, ErrRequestCancelled) && !errors.Is(result.Error, context.Canceled) {
 374			slog.Error(result.Error.Error())
 375		}
 376		slog.Debug("Request completed", "sessionID", sessionID)
 377		a.Publish(pubsub.CreatedEvent, result)
 378		events <- result
 379		close(events)
 380	}()
 381	return events, nil
 382}
 383
 384func (a *agent) processGeneration(ctx context.Context, sessionID, content string, attachmentParts []message.ContentPart) AgentEvent {
 385	cfg := config.Get()
 386	// List existing messages; if none, start title generation asynchronously.
 387	msgs, err := a.messages.List(ctx, sessionID)
 388	if err != nil {
 389		return a.err(fmt.Errorf("failed to list messages: %w", err))
 390	}
 391
 392	// sliding window to limit message history
 393	maxMessagesInContext := cfg.Options.MaxMessages
 394	if maxMessagesInContext > 0 && len(msgs) > maxMessagesInContext {
 395		// Keep the first message (usually system/context) and the last N-1 messages
 396		msgs = append(msgs[:1], msgs[len(msgs)-maxMessagesInContext+1:]...)
 397	}
 398
 399	if len(msgs) == 0 {
 400		// Use a context with timeout for title generation
 401		titleCtx, titleCancel := context.WithTimeout(context.Background(), 30*time.Second)
 402		go func() {
 403			defer titleCancel()
 404			defer log.RecoverPanic("agent.Run", func() {
 405				slog.Error("panic while generating title")
 406			})
 407			titleErr := a.generateTitle(titleCtx, sessionID, content)
 408			if titleErr != nil && !errors.Is(titleErr, context.Canceled) && !errors.Is(titleErr, context.DeadlineExceeded) {
 409				slog.Error("failed to generate title", "error", titleErr)
 410			}
 411		}()
 412	}
 413	session, err := a.sessions.Get(ctx, sessionID)
 414	if err != nil {
 415		return a.err(fmt.Errorf("failed to get session: %w", err))
 416	}
 417	if session.SummaryMessageID != "" {
 418		summaryMsgInex := -1
 419		for i, msg := range msgs {
 420			if msg.ID == session.SummaryMessageID {
 421				summaryMsgInex = i
 422				break
 423			}
 424		}
 425		if summaryMsgInex != -1 {
 426			msgs = msgs[summaryMsgInex:]
 427			msgs[0].Role = message.User
 428		}
 429	}
 430
 431	userMsg, err := a.createUserMessage(ctx, sessionID, content, attachmentParts)
 432	if err != nil {
 433		return a.err(fmt.Errorf("failed to create user message: %w", err))
 434	}
 435	// Append the new user message to the conversation history.
 436	msgHistory := append(msgs, userMsg)
 437
 438	for {
 439		// Check for cancellation before each iteration
 440		select {
 441		case <-ctx.Done():
 442			return a.err(ctx.Err())
 443		default:
 444			// Continue processing
 445		}
 446		agentMessage, toolResults, err := a.streamAndHandleEvents(ctx, sessionID, msgHistory)
 447		if err != nil {
 448			if errors.Is(err, context.Canceled) {
 449				agentMessage.AddFinish(message.FinishReasonCanceled, "Request cancelled", "")
 450				a.messages.Update(context.Background(), agentMessage)
 451				return a.err(ErrRequestCancelled)
 452			}
 453			return a.err(fmt.Errorf("failed to process events: %w", err))
 454		}
 455		if cfg.Options.Debug {
 456			slog.Info("Result", "message", agentMessage.FinishReason(), "toolResults", toolResults)
 457		}
 458		if (agentMessage.FinishReason() == message.FinishReasonToolUse) && toolResults != nil {
 459			// We are not done, we need to respond with the tool response
 460			msgHistory = append(msgHistory, agentMessage, *toolResults)
 461			// If there are queued prompts, process the next one
 462			nextPrompt, ok := a.promptQueue.Take(sessionID)
 463			if ok {
 464				for _, prompt := range nextPrompt {
 465					// Create a new user message for the queued prompt
 466					userMsg, err := a.createUserMessage(ctx, sessionID, prompt, nil)
 467					if err != nil {
 468						return a.err(fmt.Errorf("failed to create user message for queued prompt: %w", err))
 469					}
 470					// Append the new user message to the conversation history
 471					msgHistory = append(msgHistory, userMsg)
 472				}
 473			}
 474
 475			continue
 476		} else if agentMessage.FinishReason() == message.FinishReasonEndTurn {
 477			queuePrompts, ok := a.promptQueue.Take(sessionID)
 478			if ok {
 479				for _, prompt := range queuePrompts {
 480					if prompt == "" {
 481						continue
 482					}
 483					userMsg, err := a.createUserMessage(ctx, sessionID, prompt, nil)
 484					if err != nil {
 485						return a.err(fmt.Errorf("failed to create user message for queued prompt: %w", err))
 486					}
 487					msgHistory = append(msgHistory, userMsg)
 488				}
 489				continue
 490			}
 491		}
 492		if agentMessage.FinishReason() == "" {
 493			// Kujtim: could not track down where this is happening but this means its cancelled
 494			agentMessage.AddFinish(message.FinishReasonCanceled, "Request cancelled", "")
 495			_ = a.messages.Update(context.Background(), agentMessage)
 496			return a.err(ErrRequestCancelled)
 497		}
 498		return AgentEvent{
 499			Type:    AgentEventTypeResponse,
 500			Message: agentMessage,
 501			Done:    true,
 502		}
 503	}
 504}
 505
 506func (a *agent) createUserMessage(ctx context.Context, sessionID, content string, attachmentParts []message.ContentPart) (message.Message, error) {
 507	parts := []message.ContentPart{message.TextContent{Text: content}}
 508	parts = append(parts, attachmentParts...)
 509	return a.messages.Create(ctx, sessionID, message.CreateMessageParams{
 510		Role:  message.User,
 511		Parts: parts,
 512	})
 513}
 514
 515func (a *agent) streamAndHandleEvents(ctx context.Context, sessionID string, msgHistory []message.Message) (message.Message, *message.Message, error) {
 516	ctx = context.WithValue(ctx, tools.SessionIDContextKey, sessionID)
 517
 518	// Create the assistant message first so the spinner shows immediately
 519	assistantMsg, err := a.messages.Create(ctx, sessionID, message.CreateMessageParams{
 520		Role:     message.Assistant,
 521		Parts:    []message.ContentPart{},
 522		Model:    a.Model().ID,
 523		Provider: a.providerID,
 524	})
 525	if err != nil {
 526		return assistantMsg, nil, fmt.Errorf("failed to create assistant message: %w", err)
 527	}
 528
 529	// Now collect tools (which may block on MCP initialization)
 530	eventChan := a.provider.StreamResponse(ctx, msgHistory, slices.Collect(a.tools.Seq()))
 531
 532	// Add the session and message ID into the context if needed by tools.
 533	ctx = context.WithValue(ctx, tools.MessageIDContextKey, assistantMsg.ID)
 534
 535	// Process each event in the stream.
 536	for event := range eventChan {
 537		if processErr := a.processEvent(ctx, sessionID, &assistantMsg, event); processErr != nil {
 538			if errors.Is(processErr, context.Canceled) {
 539				a.finishMessage(context.Background(), &assistantMsg, message.FinishReasonCanceled, "Request cancelled", "")
 540			} else {
 541				a.finishMessage(ctx, &assistantMsg, message.FinishReasonError, "API Error", processErr.Error())
 542			}
 543			return assistantMsg, nil, processErr
 544		}
 545		if ctx.Err() != nil {
 546			a.finishMessage(context.Background(), &assistantMsg, message.FinishReasonCanceled, "Request cancelled", "")
 547			return assistantMsg, nil, ctx.Err()
 548		}
 549	}
 550
 551	toolResults := make([]message.ToolResult, len(assistantMsg.ToolCalls()))
 552	toolCalls := assistantMsg.ToolCalls()
 553	for i, toolCall := range toolCalls {
 554		select {
 555		case <-ctx.Done():
 556			a.finishMessage(context.Background(), &assistantMsg, message.FinishReasonCanceled, "Request cancelled", "")
 557			// Make all future tool calls cancelled
 558			for j := i; j < len(toolCalls); j++ {
 559				toolResults[j] = message.ToolResult{
 560					ToolCallID: toolCalls[j].ID,
 561					Content:    "Tool execution canceled by user",
 562					IsError:    true,
 563				}
 564			}
 565			goto out
 566		default:
 567			// Continue processing
 568			var tool tools.BaseTool
 569			for availableTool := range a.tools.Seq() {
 570				if availableTool.Info().Name == toolCall.Name {
 571					tool = availableTool
 572					break
 573				}
 574			}
 575
 576			// Tool not found
 577			if tool == nil {
 578				toolResults[i] = message.ToolResult{
 579					ToolCallID: toolCall.ID,
 580					Content:    fmt.Sprintf("Tool not found: %s", toolCall.Name),
 581					IsError:    true,
 582				}
 583				continue
 584			}
 585
 586			// Run tool in goroutine to allow cancellation
 587			type toolExecResult struct {
 588				response tools.ToolResponse
 589				err      error
 590			}
 591			resultChan := make(chan toolExecResult, 1)
 592
 593			go func() {
 594				response, err := tool.Run(ctx, tools.ToolCall{
 595					ID:    toolCall.ID,
 596					Name:  toolCall.Name,
 597					Input: toolCall.Input,
 598				})
 599				resultChan <- toolExecResult{response: response, err: err}
 600			}()
 601
 602			var toolResponse tools.ToolResponse
 603			var toolErr error
 604
 605			select {
 606			case <-ctx.Done():
 607				a.finishMessage(context.Background(), &assistantMsg, message.FinishReasonCanceled, "Request cancelled", "")
 608				// Mark remaining tool calls as cancelled
 609				for j := i; j < len(toolCalls); j++ {
 610					toolResults[j] = message.ToolResult{
 611						ToolCallID: toolCalls[j].ID,
 612						Content:    "Tool execution canceled by user",
 613						IsError:    true,
 614					}
 615				}
 616				goto out
 617			case result := <-resultChan:
 618				toolResponse = result.response
 619				toolErr = result.err
 620			}
 621
 622			if toolErr != nil {
 623				slog.Error("Tool execution error", "toolCall", toolCall.ID, "error", toolErr)
 624				if errors.Is(toolErr, permission.ErrorPermissionDenied) {
 625					toolResults[i] = message.ToolResult{
 626						ToolCallID: toolCall.ID,
 627						Content:    "Permission denied",
 628						IsError:    true,
 629					}
 630					for j := i + 1; j < len(toolCalls); j++ {
 631						toolResults[j] = message.ToolResult{
 632							ToolCallID: toolCalls[j].ID,
 633							Content:    "Tool execution canceled by user",
 634							IsError:    true,
 635						}
 636					}
 637					a.finishMessage(ctx, &assistantMsg, message.FinishReasonPermissionDenied, "Permission denied", "")
 638					break
 639				}
 640			}
 641			toolResults[i] = message.ToolResult{
 642				ToolCallID: toolCall.ID,
 643				Content:    toolResponse.Content,
 644				Metadata:   toolResponse.Metadata,
 645				IsError:    toolResponse.IsError,
 646			}
 647		}
 648	}
 649out:
 650	if len(toolResults) == 0 {
 651		return assistantMsg, nil, nil
 652	}
 653	parts := make([]message.ContentPart, 0)
 654	for _, tr := range toolResults {
 655		parts = append(parts, tr)
 656	}
 657	msg, err := a.messages.Create(context.Background(), assistantMsg.SessionID, message.CreateMessageParams{
 658		Role:     message.Tool,
 659		Parts:    parts,
 660		Provider: a.providerID,
 661	})
 662	if err != nil {
 663		return assistantMsg, nil, fmt.Errorf("failed to create cancelled tool message: %w", err)
 664	}
 665
 666	return assistantMsg, &msg, err
 667}
 668
 669func (a *agent) finishMessage(ctx context.Context, msg *message.Message, finishReason message.FinishReason, message, details string) {
 670	msg.AddFinish(finishReason, message, details)
 671	_ = a.messages.Update(ctx, *msg)
 672}
 673
 674func (a *agent) processEvent(ctx context.Context, sessionID string, assistantMsg *message.Message, event provider.ProviderEvent) error {
 675	select {
 676	case <-ctx.Done():
 677		return ctx.Err()
 678	default:
 679		// Continue processing.
 680	}
 681
 682	switch event.Type {
 683	case provider.EventThinkingDelta:
 684		assistantMsg.AppendReasoningContent(event.Thinking)
 685		return a.messages.Update(ctx, *assistantMsg)
 686	case provider.EventSignatureDelta:
 687		assistantMsg.AppendReasoningSignature(event.Signature)
 688		return a.messages.Update(ctx, *assistantMsg)
 689	case provider.EventContentDelta:
 690		assistantMsg.FinishThinking()
 691		assistantMsg.AppendContent(event.Content)
 692		return a.messages.Update(ctx, *assistantMsg)
 693	case provider.EventToolUseStart:
 694		assistantMsg.FinishThinking()
 695		slog.Info("Tool call started", "toolCall", event.ToolCall)
 696		assistantMsg.AddToolCall(*event.ToolCall)
 697		return a.messages.Update(ctx, *assistantMsg)
 698	case provider.EventToolUseDelta:
 699		assistantMsg.AppendToolCallInput(event.ToolCall.ID, event.ToolCall.Input)
 700		return a.messages.Update(ctx, *assistantMsg)
 701	case provider.EventToolUseStop:
 702		slog.Info("Finished tool call", "toolCall", event.ToolCall)
 703		assistantMsg.FinishToolCall(event.ToolCall.ID)
 704		return a.messages.Update(ctx, *assistantMsg)
 705	case provider.EventError:
 706		return event.Error
 707	case provider.EventComplete:
 708		assistantMsg.FinishThinking()
 709		assistantMsg.SetToolCalls(event.Response.ToolCalls)
 710		assistantMsg.AddFinish(event.Response.FinishReason, "", "")
 711		if err := a.messages.Update(ctx, *assistantMsg); err != nil {
 712			return fmt.Errorf("failed to update message: %w", err)
 713		}
 714		return a.TrackUsage(ctx, sessionID, a.Model(), event.Response.Usage)
 715	}
 716
 717	return nil
 718}
 719
 720func (a *agent) TrackUsage(ctx context.Context, sessionID string, model catwalk.Model, usage provider.TokenUsage) error {
 721	sess, err := a.sessions.Get(ctx, sessionID)
 722	if err != nil {
 723		return fmt.Errorf("failed to get session: %w", err)
 724	}
 725
 726	cost := model.CostPer1MInCached/1e6*float64(usage.CacheCreationTokens) +
 727		model.CostPer1MOutCached/1e6*float64(usage.CacheReadTokens) +
 728		model.CostPer1MIn/1e6*float64(usage.InputTokens) +
 729		model.CostPer1MOut/1e6*float64(usage.OutputTokens)
 730
 731	sess.Cost += cost
 732	sess.CompletionTokens = usage.OutputTokens + usage.CacheReadTokens
 733	sess.PromptTokens = usage.InputTokens + usage.CacheCreationTokens
 734
 735	_, err = a.sessions.Save(ctx, sess)
 736	if err != nil {
 737		return fmt.Errorf("failed to save session: %w", err)
 738	}
 739	return nil
 740}
 741
 742func (a *agent) Summarize(ctx context.Context, sessionID string) error {
 743	if a.summarizeProvider == nil {
 744		return fmt.Errorf("summarize provider not available")
 745	}
 746
 747	// Check if session is busy
 748	if a.IsSessionBusy(sessionID) {
 749		return ErrSessionBusy
 750	}
 751
 752	// Create a new context with cancellation
 753	summarizeCtx, cancel := context.WithCancel(ctx)
 754
 755	// Store the cancel function in activeRequests to allow cancellation
 756	a.activeRequests.Set(sessionID+"-summarize", cancel)
 757
 758	go func() {
 759		defer a.activeRequests.Del(sessionID + "-summarize")
 760		defer cancel()
 761		event := AgentEvent{
 762			Type:     AgentEventTypeSummarize,
 763			Progress: "Starting summarization...",
 764		}
 765
 766		a.Publish(pubsub.CreatedEvent, event)
 767		// Get all messages from the session
 768		msgs, err := a.messages.List(summarizeCtx, sessionID)
 769		if err != nil {
 770			event = AgentEvent{
 771				Type:  AgentEventTypeError,
 772				Error: fmt.Errorf("failed to list messages: %w", err),
 773				Done:  true,
 774			}
 775			a.Publish(pubsub.CreatedEvent, event)
 776			return
 777		}
 778		summarizeCtx = context.WithValue(summarizeCtx, tools.SessionIDContextKey, sessionID)
 779
 780		if len(msgs) == 0 {
 781			event = AgentEvent{
 782				Type:  AgentEventTypeError,
 783				Error: fmt.Errorf("no messages to summarize"),
 784				Done:  true,
 785			}
 786			a.Publish(pubsub.CreatedEvent, event)
 787			return
 788		}
 789
 790		event = AgentEvent{
 791			Type:     AgentEventTypeSummarize,
 792			Progress: "Analyzing conversation...",
 793		}
 794		a.Publish(pubsub.CreatedEvent, event)
 795
 796		// Add a system message to guide the summarization
 797		summarizePrompt := "Provide a detailed but concise summary of our conversation above. Focus on information that would be helpful for continuing the conversation, including what we did, what we're doing, which files we're working on, and what we're going to do next."
 798
 799		// Create a new message with the summarize prompt
 800		promptMsg := message.Message{
 801			Role:  message.User,
 802			Parts: []message.ContentPart{message.TextContent{Text: summarizePrompt}},
 803		}
 804
 805		// Append the prompt to the messages
 806		msgsWithPrompt := append(msgs, promptMsg)
 807
 808		event = AgentEvent{
 809			Type:     AgentEventTypeSummarize,
 810			Progress: "Generating summary...",
 811		}
 812
 813		a.Publish(pubsub.CreatedEvent, event)
 814
 815		// Send the messages to the summarize provider
 816		response := a.summarizeProvider.StreamResponse(
 817			summarizeCtx,
 818			msgsWithPrompt,
 819			nil,
 820		)
 821		var finalResponse *provider.ProviderResponse
 822		for r := range response {
 823			if r.Error != nil {
 824				event = AgentEvent{
 825					Type:  AgentEventTypeError,
 826					Error: fmt.Errorf("failed to summarize: %w", err),
 827					Done:  true,
 828				}
 829				a.Publish(pubsub.CreatedEvent, event)
 830				return
 831			}
 832			finalResponse = r.Response
 833		}
 834
 835		summary := strings.TrimSpace(finalResponse.Content)
 836		if summary == "" {
 837			event = AgentEvent{
 838				Type:  AgentEventTypeError,
 839				Error: fmt.Errorf("empty summary returned"),
 840				Done:  true,
 841			}
 842			a.Publish(pubsub.CreatedEvent, event)
 843			return
 844		}
 845		shell := shell.GetPersistentShell(config.Get().WorkingDir())
 846		summary += "\n\n**Current working directory of the persistent shell**\n\n" + shell.GetWorkingDir()
 847		event = AgentEvent{
 848			Type:     AgentEventTypeSummarize,
 849			Progress: "Creating new session...",
 850		}
 851
 852		a.Publish(pubsub.CreatedEvent, event)
 853		oldSession, err := a.sessions.Get(summarizeCtx, sessionID)
 854		if err != nil {
 855			event = AgentEvent{
 856				Type:  AgentEventTypeError,
 857				Error: fmt.Errorf("failed to get session: %w", err),
 858				Done:  true,
 859			}
 860
 861			a.Publish(pubsub.CreatedEvent, event)
 862			return
 863		}
 864		// Create a message in the new session with the summary
 865		msg, err := a.messages.Create(summarizeCtx, oldSession.ID, message.CreateMessageParams{
 866			Role: message.Assistant,
 867			Parts: []message.ContentPart{
 868				message.TextContent{Text: summary},
 869				message.Finish{
 870					Reason: message.FinishReasonEndTurn,
 871					Time:   time.Now().Unix(),
 872				},
 873			},
 874			Model:    a.summarizeProvider.Model().ID,
 875			Provider: a.summarizeProviderID,
 876		})
 877		if err != nil {
 878			event = AgentEvent{
 879				Type:  AgentEventTypeError,
 880				Error: fmt.Errorf("failed to create summary message: %w", err),
 881				Done:  true,
 882			}
 883
 884			a.Publish(pubsub.CreatedEvent, event)
 885			return
 886		}
 887		oldSession.SummaryMessageID = msg.ID
 888		oldSession.CompletionTokens = finalResponse.Usage.OutputTokens
 889		oldSession.PromptTokens = 0
 890		model := a.summarizeProvider.Model()
 891		usage := finalResponse.Usage
 892		cost := model.CostPer1MInCached/1e6*float64(usage.CacheCreationTokens) +
 893			model.CostPer1MOutCached/1e6*float64(usage.CacheReadTokens) +
 894			model.CostPer1MIn/1e6*float64(usage.InputTokens) +
 895			model.CostPer1MOut/1e6*float64(usage.OutputTokens)
 896		oldSession.Cost += cost
 897		_, err = a.sessions.Save(summarizeCtx, oldSession)
 898		if err != nil {
 899			event = AgentEvent{
 900				Type:  AgentEventTypeError,
 901				Error: fmt.Errorf("failed to save session: %w", err),
 902				Done:  true,
 903			}
 904			a.Publish(pubsub.CreatedEvent, event)
 905		}
 906
 907		event = AgentEvent{
 908			Type:      AgentEventTypeSummarize,
 909			SessionID: oldSession.ID,
 910			Progress:  "Summary complete",
 911			Done:      true,
 912		}
 913		a.Publish(pubsub.CreatedEvent, event)
 914		// Send final success event with the new session ID
 915	}()
 916
 917	return nil
 918}
 919
 920func (a *agent) ClearQueue(sessionID string) {
 921	if a.QueuedPrompts(sessionID) > 0 {
 922		slog.Info("Clearing queued prompts", "session_id", sessionID)
 923		a.promptQueue.Del(sessionID)
 924	}
 925}
 926
 927func (a *agent) CancelAll() {
 928	if !a.IsBusy() {
 929		return
 930	}
 931	for key := range a.activeRequests.Seq2() {
 932		a.Cancel(key) // key is sessionID
 933	}
 934
 935	timeout := time.After(5 * time.Second)
 936	for a.IsBusy() {
 937		select {
 938		case <-timeout:
 939			return
 940		default:
 941			time.Sleep(200 * time.Millisecond)
 942		}
 943	}
 944}
 945
 946func (a *agent) UpdateModel() error {
 947	cfg := config.Get()
 948
 949	// Get current provider configuration
 950	currentProviderCfg := cfg.GetProviderForModel(a.agentCfg.Model)
 951	if currentProviderCfg == nil || currentProviderCfg.ID == "" {
 952		return fmt.Errorf("provider for agent %s not found in config", a.agentCfg.Name)
 953	}
 954
 955	// Check if provider has changed
 956	if string(currentProviderCfg.ID) != a.providerID {
 957		// Provider changed, need to recreate the main provider
 958		model := cfg.GetModelByType(a.agentCfg.Model)
 959		if model.ID == "" {
 960			return fmt.Errorf("model not found for agent %s", a.agentCfg.Name)
 961		}
 962
 963		promptID := agentPromptMap[a.agentCfg.ID]
 964		if promptID == "" {
 965			promptID = prompt.PromptDefault
 966		}
 967
 968		opts := []provider.ProviderClientOption{
 969			provider.WithModel(a.agentCfg.Model),
 970			provider.WithSystemMessage(prompt.GetPrompt(promptID, currentProviderCfg.ID, cfg.Options.ContextPaths...)),
 971		}
 972
 973		newProvider, err := provider.NewProvider(*currentProviderCfg, opts...)
 974		if err != nil {
 975			return fmt.Errorf("failed to create new provider: %w", err)
 976		}
 977
 978		// Update the provider and provider ID
 979		a.provider = newProvider
 980		a.providerID = string(currentProviderCfg.ID)
 981	}
 982
 983	// Check if providers have changed for title (small) and summarize (large)
 984	smallModelCfg := cfg.Models[config.SelectedModelTypeSmall]
 985	var smallModelProviderCfg config.ProviderConfig
 986	for p := range cfg.Providers.Seq() {
 987		if p.ID == smallModelCfg.Provider {
 988			smallModelProviderCfg = p
 989			break
 990		}
 991	}
 992	if smallModelProviderCfg.ID == "" {
 993		return fmt.Errorf("provider %s not found in config", smallModelCfg.Provider)
 994	}
 995
 996	largeModelCfg := cfg.Models[config.SelectedModelTypeLarge]
 997	var largeModelProviderCfg config.ProviderConfig
 998	for p := range cfg.Providers.Seq() {
 999		if p.ID == largeModelCfg.Provider {
1000			largeModelProviderCfg = p
1001			break
1002		}
1003	}
1004	if largeModelProviderCfg.ID == "" {
1005		return fmt.Errorf("provider %s not found in config", largeModelCfg.Provider)
1006	}
1007
1008	// Recreate title provider
1009	titleOpts := []provider.ProviderClientOption{
1010		provider.WithModel(config.SelectedModelTypeSmall),
1011		provider.WithSystemMessage(prompt.GetPrompt(prompt.PromptTitle, smallModelProviderCfg.ID)),
1012		provider.WithMaxTokens(40),
1013	}
1014	newTitleProvider, err := provider.NewProvider(smallModelProviderCfg, titleOpts...)
1015	if err != nil {
1016		return fmt.Errorf("failed to create new title provider: %w", err)
1017	}
1018	a.titleProvider = newTitleProvider
1019
1020	// Recreate summarize provider if provider changed (now large model)
1021	if string(largeModelProviderCfg.ID) != a.summarizeProviderID {
1022		largeModel := cfg.GetModelByType(config.SelectedModelTypeLarge)
1023		if largeModel == nil {
1024			return fmt.Errorf("model %s not found in provider %s", largeModelCfg.Model, largeModelProviderCfg.ID)
1025		}
1026		summarizeOpts := []provider.ProviderClientOption{
1027			provider.WithModel(config.SelectedModelTypeLarge),
1028			provider.WithSystemMessage(prompt.GetPrompt(prompt.PromptSummarizer, largeModelProviderCfg.ID)),
1029		}
1030		newSummarizeProvider, err := provider.NewProvider(largeModelProviderCfg, summarizeOpts...)
1031		if err != nil {
1032			return fmt.Errorf("failed to create new summarize provider: %w", err)
1033		}
1034		a.summarizeProvider = newSummarizeProvider
1035		a.summarizeProviderID = string(largeModelProviderCfg.ID)
1036	}
1037
1038	return nil
1039}