1package agent
   2
   3import (
   4	"context"
   5	"errors"
   6	"fmt"
   7	"log/slog"
   8	"slices"
   9	"strings"
  10	"time"
  11
  12	"github.com/charmbracelet/catwalk/pkg/catwalk"
  13	"github.com/charmbracelet/crush/internal/config"
  14	"github.com/charmbracelet/crush/internal/csync"
  15	"github.com/charmbracelet/crush/internal/history"
  16	"github.com/charmbracelet/crush/internal/llm/prompt"
  17	"github.com/charmbracelet/crush/internal/llm/provider"
  18	"github.com/charmbracelet/crush/internal/llm/tools"
  19	"github.com/charmbracelet/crush/internal/log"
  20	"github.com/charmbracelet/crush/internal/lsp"
  21	"github.com/charmbracelet/crush/internal/message"
  22	"github.com/charmbracelet/crush/internal/permission"
  23	"github.com/charmbracelet/crush/internal/pubsub"
  24	"github.com/charmbracelet/crush/internal/session"
  25	"github.com/charmbracelet/crush/internal/shell"
  26)
  27
  28// Common errors
  29var (
  30	ErrRequestCancelled = errors.New("request canceled by user")
  31	ErrSessionBusy      = errors.New("session is currently processing another request")
  32)
  33
  34type AgentEventType string
  35
  36const (
  37	AgentEventTypeError     AgentEventType = "error"
  38	AgentEventTypeResponse  AgentEventType = "response"
  39	AgentEventTypeSummarize AgentEventType = "summarize"
  40)
  41
  42func (t AgentEventType) MarshalText() ([]byte, error) {
  43	return []byte(t), nil
  44}
  45
  46func (t *AgentEventType) UnmarshalText(text []byte) error {
  47	*t = AgentEventType(text)
  48	return nil
  49}
  50
// AgentEvent is the value delivered on the channel returned by Run and
// published through the agent's pubsub broker.
//
// Type selects which fields are meaningful: Error for AgentEventTypeError,
// Message for AgentEventTypeResponse, and Progress (plus SessionID on the
// final event) for AgentEventTypeSummarize. Done marks a terminal event: the
// final Run response, a Summarize error, or a completed Summarize (error
// events produced by Run do not set it).
//
// NOTE(review): encoding/json marshals most error values (e.g. those from
// errors.New or fmt.Errorf) as {}, so Error likely loses its text when the
// event is serialized — confirm whether any consumer relies on it.
type AgentEvent struct {
	Type    AgentEventType  `json:"type"`
	Message message.Message `json:"message"`
	Error   error           `json:"error,omitempty"`

	// Set by Summarize; Done is also set on the final Run response.
	SessionID string `json:"session_id,omitempty"`
	Progress  string `json:"progress,omitempty"`
	Done      bool   `json:"done,omitempty"`
}
  61
// Service is the agent API used by the rest of the application. It embeds a
// pubsub subscriber for the AgentEvents published during Run and Summarize.
type Service interface {
	pubsub.Suscriber[AgentEvent]
	// Model returns the catwalk model currently configured for the agent.
	Model() catwalk.Model
	// Run processes content for sessionID in the background. In the agent
	// implementation, when the session is already busy the content is queued
	// and Run returns a nil channel and a nil error.
	Run(ctx context.Context, sessionID string, content string, attachments ...message.Attachment) (<-chan AgentEvent, error)
	// Cancel stops the session's active request and summarization and clears
	// its queued prompts.
	Cancel(sessionID string)
	// CancelAll cancels every active request and waits (bounded) for them to stop.
	CancelAll()
	// IsSessionBusy reports whether sessionID has an active, non-summarize request.
	IsSessionBusy(sessionID string) bool
	// IsBusy reports whether any request, including summarizations, is active.
	IsBusy() bool
	// Summarize starts an asynchronous summarization of sessionID; progress and
	// completion are reported through published AgentEvents.
	Summarize(ctx context.Context, sessionID string) error
	// UpdateModel re-reads the model configuration and, when the provider has
	// changed, recreates the agent's provider client.
	UpdateModel() error
	// QueuedPrompts returns how many prompts are queued for sessionID.
	QueuedPrompts(sessionID string) int
	// ClearQueue discards the prompts queued for sessionID.
	ClearQueue(sessionID string)
}
  75
// agent is the Service implementation. Regular requests are tracked in
// activeRequests under their session ID, summarizations under
// sessionID+"-summarize"; prompts submitted via Run while a session is busy
// wait in promptQueue until the running request picks them up.
type agent struct {
	*pubsub.Broker[AgentEvent]
	agentCfg config.Agent
	sessions session.Service
	messages message.Service
	// NOTE(review): not assigned in NewAgent (tools read the package-level
	// mcpTools variable instead) — possibly unused; confirm before removing.
	mcpTools []McpTool

	// tools is populated on demand by the toolFn built in NewAgent.
	tools *csync.LazySlice[tools.BaseTool]
	// agentToolFn builds the task-agent tool (set only for the "coder" agent).
	// We need this to be able to update it when model changes
	agentToolFn func() (tools.BaseTool, error)

	// provider is the main chat client; providerID names its provider config.
	provider   provider.Provider
	providerID string

	// Clients for title generation (small model) and summarization (large model).
	titleProvider       provider.Provider
	summarizeProvider   provider.Provider
	summarizeProviderID string

	// activeRequests maps a request key to the cancel func of its context;
	// promptQueue maps a session ID to prompts waiting to be processed.
	activeRequests *csync.Map[string, context.CancelFunc]
	promptQueue    *csync.Map[string, []string]
}
  97
// agentPromptMap selects the system prompt for known agent IDs. Agent IDs not
// listed here fall back to prompt.PromptDefault when the provider is created.
var agentPromptMap = map[string]prompt.PromptID{
	"coder": prompt.PromptCoder,
	"task":  prompt.PromptTask,
}
 102
 103func NewAgent(
 104	ctx context.Context,
 105	agentCfg config.Agent,
 106	// These services are needed in the tools
 107	permissions permission.Service,
 108	sessions session.Service,
 109	messages message.Service,
 110	history history.Service,
 111	lspClients *csync.Map[string, *lsp.Client],
 112) (Service, error) {
 113	cfg := config.Get()
 114
 115	var agentToolFn func() (tools.BaseTool, error)
 116	if agentCfg.ID == "coder" {
 117		agentToolFn = func() (tools.BaseTool, error) {
 118			taskAgentCfg := config.Get().Agents["task"]
 119			if taskAgentCfg.ID == "" {
 120				return nil, fmt.Errorf("task agent not found in config")
 121			}
 122			taskAgent, err := NewAgent(ctx, taskAgentCfg, permissions, sessions, messages, history, lspClients)
 123			if err != nil {
 124				return nil, fmt.Errorf("failed to create task agent: %w", err)
 125			}
 126			return NewAgentTool(taskAgent, sessions, messages), nil
 127		}
 128	}
 129
 130	providerCfg := config.Get().GetProviderForModel(agentCfg.Model)
 131	if providerCfg == nil {
 132		return nil, fmt.Errorf("provider for agent %s not found in config", agentCfg.Name)
 133	}
 134	model := config.Get().GetModelByType(agentCfg.Model)
 135
 136	if model == nil {
 137		return nil, fmt.Errorf("model not found for agent %s", agentCfg.Name)
 138	}
 139
 140	promptID := agentPromptMap[agentCfg.ID]
 141	if promptID == "" {
 142		promptID = prompt.PromptDefault
 143	}
 144	opts := []provider.ProviderClientOption{
 145		provider.WithModel(agentCfg.Model),
 146		provider.WithSystemMessage(prompt.GetPrompt(promptID, providerCfg.ID, config.Get().Options.ContextPaths...)),
 147	}
 148	agentProvider, err := provider.NewProvider(*providerCfg, opts...)
 149	if err != nil {
 150		return nil, err
 151	}
 152
 153	smallModelCfg := cfg.Models[config.SelectedModelTypeSmall]
 154	var smallModelProviderCfg *config.ProviderConfig
 155	if smallModelCfg.Provider == providerCfg.ID {
 156		smallModelProviderCfg = providerCfg
 157	} else {
 158		smallModelProviderCfg = cfg.GetProviderForModel(config.SelectedModelTypeSmall)
 159
 160		if smallModelProviderCfg.ID == "" {
 161			return nil, fmt.Errorf("provider %s not found in config", smallModelCfg.Provider)
 162		}
 163	}
 164	smallModel := cfg.GetModelByType(config.SelectedModelTypeSmall)
 165	if smallModel.ID == "" {
 166		return nil, fmt.Errorf("model %s not found in provider %s", smallModelCfg.Model, smallModelProviderCfg.ID)
 167	}
 168
 169	titleOpts := []provider.ProviderClientOption{
 170		provider.WithModel(config.SelectedModelTypeSmall),
 171		provider.WithSystemMessage(prompt.GetPrompt(prompt.PromptTitle, smallModelProviderCfg.ID)),
 172	}
 173	titleProvider, err := provider.NewProvider(*smallModelProviderCfg, titleOpts...)
 174	if err != nil {
 175		return nil, err
 176	}
 177
 178	summarizeOpts := []provider.ProviderClientOption{
 179		provider.WithModel(config.SelectedModelTypeLarge),
 180		provider.WithSystemMessage(prompt.GetPrompt(prompt.PromptSummarizer, providerCfg.ID)),
 181	}
 182	summarizeProvider, err := provider.NewProvider(*providerCfg, summarizeOpts...)
 183	if err != nil {
 184		return nil, err
 185	}
 186
 187	toolFn := func() []tools.BaseTool {
 188		slog.Info("Initializing agent tools", "agent", agentCfg.ID)
 189		defer func() {
 190			slog.Info("Initialized agent tools", "agent", agentCfg.ID)
 191		}()
 192
 193		cwd := cfg.WorkingDir()
 194		allTools := []tools.BaseTool{
 195			tools.NewBashTool(permissions, cwd, cfg.Options.Attribution),
 196			tools.NewDownloadTool(permissions, cwd),
 197			tools.NewEditTool(lspClients, permissions, history, cwd),
 198			tools.NewMultiEditTool(lspClients, permissions, history, cwd),
 199			tools.NewFetchTool(permissions, cwd),
 200			tools.NewGlobTool(cwd),
 201			tools.NewGrepTool(cwd),
 202			tools.NewLsTool(permissions, cwd),
 203			tools.NewSourcegraphTool(),
 204			tools.NewViewTool(lspClients, permissions, cwd),
 205			tools.NewWriteTool(lspClients, permissions, history, cwd),
 206		}
 207
 208		mcpToolsOnce.Do(func() {
 209			mcpTools = doGetMCPTools(ctx, permissions, cfg)
 210		})
 211
 212		withCoderTools := func(t []tools.BaseTool) []tools.BaseTool {
 213			if agentCfg.ID == "coder" {
 214				t = append(t, mcpTools...)
 215				if lspClients.Len() > 0 {
 216					t = append(t, tools.NewDiagnosticsTool(lspClients))
 217				}
 218			}
 219			return t
 220		}
 221
 222		if agentCfg.AllowedTools == nil {
 223			return withCoderTools(allTools)
 224		}
 225
 226		var filteredTools []tools.BaseTool
 227		for _, tool := range allTools {
 228			if slices.Contains(agentCfg.AllowedTools, tool.Name()) {
 229				filteredTools = append(filteredTools, tool)
 230			}
 231		}
 232		return withCoderTools(filteredTools)
 233	}
 234
 235	return &agent{
 236		Broker:              pubsub.NewBroker[AgentEvent](),
 237		agentCfg:            agentCfg,
 238		provider:            agentProvider,
 239		providerID:          string(providerCfg.ID),
 240		messages:            messages,
 241		sessions:            sessions,
 242		titleProvider:       titleProvider,
 243		summarizeProvider:   summarizeProvider,
 244		summarizeProviderID: string(providerCfg.ID),
 245		agentToolFn:         agentToolFn,
 246		activeRequests:      csync.NewMap[string, context.CancelFunc](),
 247		tools:               csync.NewLazySlice(toolFn),
 248		promptQueue:         csync.NewMap[string, []string](),
 249	}, nil
 250}
 251
 252func (a *agent) Model() catwalk.Model {
 253	return *config.Get().GetModelByType(a.agentCfg.Model)
 254}
 255
 256func (a *agent) Cancel(sessionID string) {
 257	// Cancel regular requests
 258	if cancel, ok := a.activeRequests.Take(sessionID); ok && cancel != nil {
 259		slog.Info("Request cancellation initiated", "session_id", sessionID)
 260		cancel()
 261	}
 262
 263	// Also check for summarize requests
 264	if cancel, ok := a.activeRequests.Take(sessionID + "-summarize"); ok && cancel != nil {
 265		slog.Info("Summarize cancellation initiated", "session_id", sessionID)
 266		cancel()
 267	}
 268
 269	if a.QueuedPrompts(sessionID) > 0 {
 270		slog.Info("Clearing queued prompts", "session_id", sessionID)
 271		a.promptQueue.Del(sessionID)
 272	}
 273}
 274
 275func (a *agent) IsBusy() bool {
 276	var busy bool
 277	for cancelFunc := range a.activeRequests.Seq() {
 278		if cancelFunc != nil {
 279			busy = true
 280			break
 281		}
 282	}
 283	return busy
 284}
 285
 286func (a *agent) IsSessionBusy(sessionID string) bool {
 287	_, busy := a.activeRequests.Get(sessionID)
 288	return busy
 289}
 290
 291func (a *agent) QueuedPrompts(sessionID string) int {
 292	l, ok := a.promptQueue.Get(sessionID)
 293	if !ok {
 294		return 0
 295	}
 296	return len(l)
 297}
 298
 299func (a *agent) generateTitle(ctx context.Context, sessionID string, content string) error {
 300	if content == "" {
 301		return nil
 302	}
 303	if a.titleProvider == nil {
 304		return nil
 305	}
 306	session, err := a.sessions.Get(ctx, sessionID)
 307	if err != nil {
 308		return err
 309	}
 310	parts := []message.ContentPart{message.TextContent{
 311		Text: fmt.Sprintf("Generate a concise title for the following content:\n\n%s", content),
 312	}}
 313
 314	// Use streaming approach like summarization
 315	response := a.titleProvider.StreamResponse(
 316		ctx,
 317		[]message.Message{
 318			{
 319				Role:  message.User,
 320				Parts: parts,
 321			},
 322		},
 323		nil,
 324	)
 325
 326	var finalResponse *provider.ProviderResponse
 327	for r := range response {
 328		if r.Error != nil {
 329			return r.Error
 330		}
 331		finalResponse = r.Response
 332	}
 333
 334	if finalResponse == nil {
 335		return fmt.Errorf("no response received from title provider")
 336	}
 337
 338	title := strings.ReplaceAll(finalResponse.Content, "\n", " ")
 339
 340	if idx := strings.Index(title, "</think>"); idx > 0 {
 341		title = title[idx+len("</think>"):]
 342	}
 343
 344	title = strings.TrimSpace(title)
 345	if title == "" {
 346		return nil
 347	}
 348
 349	session.Title = title
 350	_, err = a.sessions.Save(ctx, session)
 351	return err
 352}
 353
 354func (a *agent) err(err error) AgentEvent {
 355	return AgentEvent{
 356		Type:  AgentEventTypeError,
 357		Error: err,
 358	}
 359}
 360
 361func (a *agent) Run(ctx context.Context, sessionID string, content string, attachments ...message.Attachment) (<-chan AgentEvent, error) {
 362	if !a.Model().SupportsImages && attachments != nil {
 363		attachments = nil
 364	}
 365	events := make(chan AgentEvent, 1)
 366	if a.IsSessionBusy(sessionID) {
 367		existing, ok := a.promptQueue.Get(sessionID)
 368		if !ok {
 369			existing = []string{}
 370		}
 371		existing = append(existing, content)
 372		a.promptQueue.Set(sessionID, existing)
 373		return nil, nil
 374	}
 375
 376	genCtx, cancel := context.WithCancel(ctx)
 377
 378	a.activeRequests.Set(sessionID, cancel)
 379	go func() {
 380		slog.Debug("Request started", "sessionID", sessionID)
 381		defer log.RecoverPanic("agent.Run", func() {
 382			events <- a.err(fmt.Errorf("panic while running the agent"))
 383		})
 384		var attachmentParts []message.ContentPart
 385		for _, attachment := range attachments {
 386			attachmentParts = append(attachmentParts, message.BinaryContent{Path: attachment.FilePath, MIMEType: attachment.MimeType, Data: attachment.Content})
 387		}
 388		result := a.processGeneration(genCtx, sessionID, content, attachmentParts)
 389		if result.Error != nil && !errors.Is(result.Error, ErrRequestCancelled) && !errors.Is(result.Error, context.Canceled) {
 390			slog.Error(result.Error.Error())
 391		}
 392		slog.Debug("Request completed", "sessionID", sessionID)
 393		a.activeRequests.Del(sessionID)
 394		cancel()
 395		a.Publish(pubsub.CreatedEvent, result)
 396		events <- result
 397		close(events)
 398	}()
 399	return events, nil
 400}
 401
 402func (a *agent) processGeneration(ctx context.Context, sessionID, content string, attachmentParts []message.ContentPart) AgentEvent {
 403	cfg := config.Get()
 404	// List existing messages; if none, start title generation asynchronously.
 405	msgs, err := a.messages.List(ctx, sessionID)
 406	if err != nil {
 407		return a.err(fmt.Errorf("failed to list messages: %w", err))
 408	}
 409	if len(msgs) == 0 {
 410		go func() {
 411			defer log.RecoverPanic("agent.Run", func() {
 412				slog.Error("panic while generating title")
 413			})
 414			titleErr := a.generateTitle(ctx, sessionID, content)
 415			if titleErr != nil && !errors.Is(titleErr, context.Canceled) && !errors.Is(titleErr, context.DeadlineExceeded) {
 416				slog.Error("failed to generate title", "error", titleErr)
 417			}
 418		}()
 419	}
 420	session, err := a.sessions.Get(ctx, sessionID)
 421	if err != nil {
 422		return a.err(fmt.Errorf("failed to get session: %w", err))
 423	}
 424	if session.SummaryMessageID != "" {
 425		summaryMsgInex := -1
 426		for i, msg := range msgs {
 427			if msg.ID == session.SummaryMessageID {
 428				summaryMsgInex = i
 429				break
 430			}
 431		}
 432		if summaryMsgInex != -1 {
 433			msgs = msgs[summaryMsgInex:]
 434			msgs[0].Role = message.User
 435		}
 436	}
 437
 438	userMsg, err := a.createUserMessage(ctx, sessionID, content, attachmentParts)
 439	if err != nil {
 440		return a.err(fmt.Errorf("failed to create user message: %w", err))
 441	}
 442	// Append the new user message to the conversation history.
 443	msgHistory := append(msgs, userMsg)
 444
 445	for {
 446		// Check for cancellation before each iteration
 447		select {
 448		case <-ctx.Done():
 449			return a.err(ctx.Err())
 450		default:
 451			// Continue processing
 452		}
 453		agentMessage, toolResults, err := a.streamAndHandleEvents(ctx, sessionID, msgHistory)
 454		if err != nil {
 455			if errors.Is(err, context.Canceled) {
 456				agentMessage.AddFinish(message.FinishReasonCanceled, "Request cancelled", "")
 457				a.messages.Update(context.Background(), agentMessage)
 458				return a.err(ErrRequestCancelled)
 459			}
 460			return a.err(fmt.Errorf("failed to process events: %w", err))
 461		}
 462		if cfg.Options.Debug {
 463			slog.Info("Result", "message", agentMessage.FinishReason(), "toolResults", toolResults)
 464		}
 465		if (agentMessage.FinishReason() == message.FinishReasonToolUse) && toolResults != nil {
 466			// We are not done, we need to respond with the tool response
 467			msgHistory = append(msgHistory, agentMessage, *toolResults)
 468			// If there are queued prompts, process the next one
 469			nextPrompt, ok := a.promptQueue.Take(sessionID)
 470			if ok {
 471				for _, prompt := range nextPrompt {
 472					// Create a new user message for the queued prompt
 473					userMsg, err := a.createUserMessage(ctx, sessionID, prompt, nil)
 474					if err != nil {
 475						return a.err(fmt.Errorf("failed to create user message for queued prompt: %w", err))
 476					}
 477					// Append the new user message to the conversation history
 478					msgHistory = append(msgHistory, userMsg)
 479				}
 480			}
 481
 482			continue
 483		} else if agentMessage.FinishReason() == message.FinishReasonEndTurn {
 484			queuePrompts, ok := a.promptQueue.Take(sessionID)
 485			if ok {
 486				for _, prompt := range queuePrompts {
 487					if prompt == "" {
 488						continue
 489					}
 490					userMsg, err := a.createUserMessage(ctx, sessionID, prompt, nil)
 491					if err != nil {
 492						return a.err(fmt.Errorf("failed to create user message for queued prompt: %w", err))
 493					}
 494					msgHistory = append(msgHistory, userMsg)
 495				}
 496				continue
 497			}
 498		}
 499		if agentMessage.FinishReason() == "" {
 500			// Kujtim: could not track down where this is happening but this means its cancelled
 501			agentMessage.AddFinish(message.FinishReasonCanceled, "Request cancelled", "")
 502			_ = a.messages.Update(context.Background(), agentMessage)
 503			return a.err(ErrRequestCancelled)
 504		}
 505		return AgentEvent{
 506			Type:    AgentEventTypeResponse,
 507			Message: agentMessage,
 508			Done:    true,
 509		}
 510	}
 511}
 512
 513func (a *agent) createUserMessage(ctx context.Context, sessionID, content string, attachmentParts []message.ContentPart) (message.Message, error) {
 514	parts := []message.ContentPart{message.TextContent{Text: content}}
 515	parts = append(parts, attachmentParts...)
 516	return a.messages.Create(ctx, sessionID, message.CreateMessageParams{
 517		Role:  message.User,
 518		Parts: parts,
 519	})
 520}
 521
 522func (a *agent) getAllTools() ([]tools.BaseTool, error) {
 523	allTools := slices.Collect(a.tools.Seq())
 524	if a.agentToolFn != nil {
 525		agentTool, agentToolErr := a.agentToolFn()
 526		if agentToolErr != nil {
 527			return nil, agentToolErr
 528		}
 529		allTools = append(allTools, agentTool)
 530	}
 531	return allTools, nil
 532}
 533
 534func (a *agent) streamAndHandleEvents(ctx context.Context, sessionID string, msgHistory []message.Message) (message.Message, *message.Message, error) {
 535	ctx = context.WithValue(ctx, tools.SessionIDContextKey, sessionID)
 536
 537	// Create the assistant message first so the spinner shows immediately
 538	assistantMsg, err := a.messages.Create(ctx, sessionID, message.CreateMessageParams{
 539		Role:     message.Assistant,
 540		Parts:    []message.ContentPart{},
 541		Model:    a.Model().ID,
 542		Provider: a.providerID,
 543	})
 544	if err != nil {
 545		return assistantMsg, nil, fmt.Errorf("failed to create assistant message: %w", err)
 546	}
 547
 548	allTools, toolsErr := a.getAllTools()
 549	if toolsErr != nil {
 550		return assistantMsg, nil, toolsErr
 551	}
 552	// Now collect tools (which may block on MCP initialization)
 553	eventChan := a.provider.StreamResponse(ctx, msgHistory, allTools)
 554
 555	// Add the session and message ID into the context if needed by tools.
 556	ctx = context.WithValue(ctx, tools.MessageIDContextKey, assistantMsg.ID)
 557
 558	// Process each event in the stream.
 559	for event := range eventChan {
 560		if processErr := a.processEvent(ctx, sessionID, &assistantMsg, event); processErr != nil {
 561			if errors.Is(processErr, context.Canceled) {
 562				a.finishMessage(context.Background(), &assistantMsg, message.FinishReasonCanceled, "Request cancelled", "")
 563			} else {
 564				a.finishMessage(ctx, &assistantMsg, message.FinishReasonError, "API Error", processErr.Error())
 565			}
 566			return assistantMsg, nil, processErr
 567		}
 568		if ctx.Err() != nil {
 569			a.finishMessage(context.Background(), &assistantMsg, message.FinishReasonCanceled, "Request cancelled", "")
 570			return assistantMsg, nil, ctx.Err()
 571		}
 572	}
 573
 574	toolResults := make([]message.ToolResult, len(assistantMsg.ToolCalls()))
 575	toolCalls := assistantMsg.ToolCalls()
 576	for i, toolCall := range toolCalls {
 577		select {
 578		case <-ctx.Done():
 579			a.finishMessage(context.Background(), &assistantMsg, message.FinishReasonCanceled, "Request cancelled", "")
 580			// Make all future tool calls cancelled
 581			for j := i; j < len(toolCalls); j++ {
 582				toolResults[j] = message.ToolResult{
 583					ToolCallID: toolCalls[j].ID,
 584					Content:    "Tool execution canceled by user",
 585					IsError:    true,
 586				}
 587			}
 588			goto out
 589		default:
 590			// Continue processing
 591			var tool tools.BaseTool
 592			allTools, _ := a.getAllTools()
 593			for _, availableTool := range allTools {
 594				if availableTool.Info().Name == toolCall.Name {
 595					tool = availableTool
 596					break
 597				}
 598			}
 599
 600			// Tool not found
 601			if tool == nil {
 602				toolResults[i] = message.ToolResult{
 603					ToolCallID: toolCall.ID,
 604					Content:    fmt.Sprintf("Tool not found: %s", toolCall.Name),
 605					IsError:    true,
 606				}
 607				continue
 608			}
 609
 610			// Run tool in goroutine to allow cancellation
 611			type toolExecResult struct {
 612				response tools.ToolResponse
 613				err      error
 614			}
 615			resultChan := make(chan toolExecResult, 1)
 616
 617			go func() {
 618				response, err := tool.Run(ctx, tools.ToolCall{
 619					ID:    toolCall.ID,
 620					Name:  toolCall.Name,
 621					Input: toolCall.Input,
 622				})
 623				resultChan <- toolExecResult{response: response, err: err}
 624			}()
 625
 626			var toolResponse tools.ToolResponse
 627			var toolErr error
 628
 629			select {
 630			case <-ctx.Done():
 631				a.finishMessage(context.Background(), &assistantMsg, message.FinishReasonCanceled, "Request cancelled", "")
 632				// Mark remaining tool calls as cancelled
 633				for j := i; j < len(toolCalls); j++ {
 634					toolResults[j] = message.ToolResult{
 635						ToolCallID: toolCalls[j].ID,
 636						Content:    "Tool execution canceled by user",
 637						IsError:    true,
 638					}
 639				}
 640				goto out
 641			case result := <-resultChan:
 642				toolResponse = result.response
 643				toolErr = result.err
 644			}
 645
 646			if toolErr != nil {
 647				slog.Error("Tool execution error", "toolCall", toolCall.ID, "error", toolErr)
 648				if errors.Is(toolErr, permission.ErrorPermissionDenied) {
 649					toolResults[i] = message.ToolResult{
 650						ToolCallID: toolCall.ID,
 651						Content:    "Permission denied",
 652						IsError:    true,
 653					}
 654					for j := i + 1; j < len(toolCalls); j++ {
 655						toolResults[j] = message.ToolResult{
 656							ToolCallID: toolCalls[j].ID,
 657							Content:    "Tool execution canceled by user",
 658							IsError:    true,
 659						}
 660					}
 661					a.finishMessage(ctx, &assistantMsg, message.FinishReasonPermissionDenied, "Permission denied", "")
 662					break
 663				}
 664			}
 665			toolResults[i] = message.ToolResult{
 666				ToolCallID: toolCall.ID,
 667				Content:    toolResponse.Content,
 668				Metadata:   toolResponse.Metadata,
 669				IsError:    toolResponse.IsError,
 670			}
 671		}
 672	}
 673out:
 674	if len(toolResults) == 0 {
 675		return assistantMsg, nil, nil
 676	}
 677	parts := make([]message.ContentPart, 0)
 678	for _, tr := range toolResults {
 679		parts = append(parts, tr)
 680	}
 681	msg, err := a.messages.Create(context.Background(), assistantMsg.SessionID, message.CreateMessageParams{
 682		Role:     message.Tool,
 683		Parts:    parts,
 684		Provider: a.providerID,
 685	})
 686	if err != nil {
 687		return assistantMsg, nil, fmt.Errorf("failed to create cancelled tool message: %w", err)
 688	}
 689
 690	return assistantMsg, &msg, err
 691}
 692
 693func (a *agent) finishMessage(ctx context.Context, msg *message.Message, finishReason message.FinishReason, message, details string) {
 694	msg.AddFinish(finishReason, message, details)
 695	_ = a.messages.Update(ctx, *msg)
 696}
 697
 698func (a *agent) processEvent(ctx context.Context, sessionID string, assistantMsg *message.Message, event provider.ProviderEvent) error {
 699	select {
 700	case <-ctx.Done():
 701		return ctx.Err()
 702	default:
 703		// Continue processing.
 704	}
 705
 706	switch event.Type {
 707	case provider.EventThinkingDelta:
 708		assistantMsg.AppendReasoningContent(event.Thinking)
 709		return a.messages.Update(ctx, *assistantMsg)
 710	case provider.EventSignatureDelta:
 711		assistantMsg.AppendReasoningSignature(event.Signature)
 712		return a.messages.Update(ctx, *assistantMsg)
 713	case provider.EventContentDelta:
 714		assistantMsg.FinishThinking()
 715		assistantMsg.AppendContent(event.Content)
 716		return a.messages.Update(ctx, *assistantMsg)
 717	case provider.EventToolUseStart:
 718		assistantMsg.FinishThinking()
 719		slog.Info("Tool call started", "toolCall", event.ToolCall)
 720		assistantMsg.AddToolCall(*event.ToolCall)
 721		return a.messages.Update(ctx, *assistantMsg)
 722	case provider.EventToolUseDelta:
 723		assistantMsg.AppendToolCallInput(event.ToolCall.ID, event.ToolCall.Input)
 724		return a.messages.Update(ctx, *assistantMsg)
 725	case provider.EventToolUseStop:
 726		slog.Info("Finished tool call", "toolCall", event.ToolCall)
 727		assistantMsg.FinishToolCall(event.ToolCall.ID)
 728		return a.messages.Update(ctx, *assistantMsg)
 729	case provider.EventError:
 730		return event.Error
 731	case provider.EventComplete:
 732		assistantMsg.FinishThinking()
 733		assistantMsg.SetToolCalls(event.Response.ToolCalls)
 734		assistantMsg.AddFinish(event.Response.FinishReason, "", "")
 735		if err := a.messages.Update(ctx, *assistantMsg); err != nil {
 736			return fmt.Errorf("failed to update message: %w", err)
 737		}
 738		return a.TrackUsage(ctx, sessionID, a.Model(), event.Response.Usage)
 739	}
 740
 741	return nil
 742}
 743
 744func (a *agent) TrackUsage(ctx context.Context, sessionID string, model catwalk.Model, usage provider.TokenUsage) error {
 745	sess, err := a.sessions.Get(ctx, sessionID)
 746	if err != nil {
 747		return fmt.Errorf("failed to get session: %w", err)
 748	}
 749
 750	cost := model.CostPer1MInCached/1e6*float64(usage.CacheCreationTokens) +
 751		model.CostPer1MOutCached/1e6*float64(usage.CacheReadTokens) +
 752		model.CostPer1MIn/1e6*float64(usage.InputTokens) +
 753		model.CostPer1MOut/1e6*float64(usage.OutputTokens)
 754
 755	sess.Cost += cost
 756	sess.CompletionTokens = usage.OutputTokens + usage.CacheReadTokens
 757	sess.PromptTokens = usage.InputTokens + usage.CacheCreationTokens
 758
 759	_, err = a.sessions.Save(ctx, sess)
 760	if err != nil {
 761		return fmt.Errorf("failed to save session: %w", err)
 762	}
 763	return nil
 764}
 765
 766func (a *agent) Summarize(ctx context.Context, sessionID string) error {
 767	if a.summarizeProvider == nil {
 768		return fmt.Errorf("summarize provider not available")
 769	}
 770
 771	// Check if session is busy
 772	if a.IsSessionBusy(sessionID) {
 773		return ErrSessionBusy
 774	}
 775
 776	// Create a new context with cancellation
 777	summarizeCtx, cancel := context.WithCancel(ctx)
 778
 779	// Store the cancel function in activeRequests to allow cancellation
 780	a.activeRequests.Set(sessionID+"-summarize", cancel)
 781
 782	go func() {
 783		defer a.activeRequests.Del(sessionID + "-summarize")
 784		defer cancel()
 785		event := AgentEvent{
 786			Type:     AgentEventTypeSummarize,
 787			Progress: "Starting summarization...",
 788		}
 789
 790		a.Publish(pubsub.CreatedEvent, event)
 791		// Get all messages from the session
 792		msgs, err := a.messages.List(summarizeCtx, sessionID)
 793		if err != nil {
 794			event = AgentEvent{
 795				Type:  AgentEventTypeError,
 796				Error: fmt.Errorf("failed to list messages: %w", err),
 797				Done:  true,
 798			}
 799			a.Publish(pubsub.CreatedEvent, event)
 800			return
 801		}
 802		summarizeCtx = context.WithValue(summarizeCtx, tools.SessionIDContextKey, sessionID)
 803
 804		if len(msgs) == 0 {
 805			event = AgentEvent{
 806				Type:  AgentEventTypeError,
 807				Error: fmt.Errorf("no messages to summarize"),
 808				Done:  true,
 809			}
 810			a.Publish(pubsub.CreatedEvent, event)
 811			return
 812		}
 813
 814		event = AgentEvent{
 815			Type:     AgentEventTypeSummarize,
 816			Progress: "Analyzing conversation...",
 817		}
 818		a.Publish(pubsub.CreatedEvent, event)
 819
 820		// Add a system message to guide the summarization
 821		summarizePrompt := "Provide a detailed but concise summary of our conversation above. Focus on information that would be helpful for continuing the conversation, including what we did, what we're doing, which files we're working on, and what we're going to do next."
 822
 823		// Create a new message with the summarize prompt
 824		promptMsg := message.Message{
 825			Role:  message.User,
 826			Parts: []message.ContentPart{message.TextContent{Text: summarizePrompt}},
 827		}
 828
 829		// Append the prompt to the messages
 830		msgsWithPrompt := append(msgs, promptMsg)
 831
 832		event = AgentEvent{
 833			Type:     AgentEventTypeSummarize,
 834			Progress: "Generating summary...",
 835		}
 836
 837		a.Publish(pubsub.CreatedEvent, event)
 838
 839		// Send the messages to the summarize provider
 840		response := a.summarizeProvider.StreamResponse(
 841			summarizeCtx,
 842			msgsWithPrompt,
 843			nil,
 844		)
 845		var finalResponse *provider.ProviderResponse
 846		for r := range response {
 847			if r.Error != nil {
 848				event = AgentEvent{
 849					Type:  AgentEventTypeError,
 850					Error: fmt.Errorf("failed to summarize: %w", r.Error),
 851					Done:  true,
 852				}
 853				a.Publish(pubsub.CreatedEvent, event)
 854				return
 855			}
 856			finalResponse = r.Response
 857		}
 858
 859		summary := strings.TrimSpace(finalResponse.Content)
 860		if summary == "" {
 861			event = AgentEvent{
 862				Type:  AgentEventTypeError,
 863				Error: fmt.Errorf("empty summary returned"),
 864				Done:  true,
 865			}
 866			a.Publish(pubsub.CreatedEvent, event)
 867			return
 868		}
 869		shell := shell.GetPersistentShell(config.Get().WorkingDir())
 870		summary += "\n\n**Current working directory of the persistent shell**\n\n" + shell.GetWorkingDir()
 871		event = AgentEvent{
 872			Type:     AgentEventTypeSummarize,
 873			Progress: "Creating new session...",
 874		}
 875
 876		a.Publish(pubsub.CreatedEvent, event)
 877		oldSession, err := a.sessions.Get(summarizeCtx, sessionID)
 878		if err != nil {
 879			event = AgentEvent{
 880				Type:  AgentEventTypeError,
 881				Error: fmt.Errorf("failed to get session: %w", err),
 882				Done:  true,
 883			}
 884
 885			a.Publish(pubsub.CreatedEvent, event)
 886			return
 887		}
 888		// Create a message in the new session with the summary
 889		msg, err := a.messages.Create(summarizeCtx, oldSession.ID, message.CreateMessageParams{
 890			Role: message.Assistant,
 891			Parts: []message.ContentPart{
 892				message.TextContent{Text: summary},
 893				message.Finish{
 894					Reason: message.FinishReasonEndTurn,
 895					Time:   time.Now().Unix(),
 896				},
 897			},
 898			Model:    a.summarizeProvider.Model().ID,
 899			Provider: a.summarizeProviderID,
 900		})
 901		if err != nil {
 902			event = AgentEvent{
 903				Type:  AgentEventTypeError,
 904				Error: fmt.Errorf("failed to create summary message: %w", err),
 905				Done:  true,
 906			}
 907
 908			a.Publish(pubsub.CreatedEvent, event)
 909			return
 910		}
 911		oldSession.SummaryMessageID = msg.ID
 912		oldSession.CompletionTokens = finalResponse.Usage.OutputTokens
 913		oldSession.PromptTokens = 0
 914		model := a.summarizeProvider.Model()
 915		usage := finalResponse.Usage
 916		cost := model.CostPer1MInCached/1e6*float64(usage.CacheCreationTokens) +
 917			model.CostPer1MOutCached/1e6*float64(usage.CacheReadTokens) +
 918			model.CostPer1MIn/1e6*float64(usage.InputTokens) +
 919			model.CostPer1MOut/1e6*float64(usage.OutputTokens)
 920		oldSession.Cost += cost
 921		_, err = a.sessions.Save(summarizeCtx, oldSession)
 922		if err != nil {
 923			event = AgentEvent{
 924				Type:  AgentEventTypeError,
 925				Error: fmt.Errorf("failed to save session: %w", err),
 926				Done:  true,
 927			}
 928			a.Publish(pubsub.CreatedEvent, event)
 929		}
 930
 931		event = AgentEvent{
 932			Type:      AgentEventTypeSummarize,
 933			SessionID: oldSession.ID,
 934			Progress:  "Summary complete",
 935			Done:      true,
 936		}
 937		a.Publish(pubsub.CreatedEvent, event)
 938		// Send final success event with the new session ID
 939	}()
 940
 941	return nil
 942}
 943
 944func (a *agent) ClearQueue(sessionID string) {
 945	if a.QueuedPrompts(sessionID) > 0 {
 946		slog.Info("Clearing queued prompts", "session_id", sessionID)
 947		a.promptQueue.Del(sessionID)
 948	}
 949}
 950
 951func (a *agent) CancelAll() {
 952	if !a.IsBusy() {
 953		return
 954	}
 955	for key := range a.activeRequests.Seq2() {
 956		a.Cancel(key) // key is sessionID
 957	}
 958
 959	timeout := time.After(5 * time.Second)
 960	for a.IsBusy() {
 961		select {
 962		case <-timeout:
 963			return
 964		default:
 965			time.Sleep(200 * time.Millisecond)
 966		}
 967	}
 968}
 969
 970func (a *agent) UpdateModel() error {
 971	cfg := config.Get()
 972
 973	// Get current provider configuration
 974	currentProviderCfg := cfg.GetProviderForModel(a.agentCfg.Model)
 975	if currentProviderCfg == nil || currentProviderCfg.ID == "" {
 976		return fmt.Errorf("provider for agent %s not found in config", a.agentCfg.Name)
 977	}
 978
 979	// Check if provider has changed
 980	if string(currentProviderCfg.ID) != a.providerID {
 981		// Provider changed, need to recreate the main provider
 982		model := cfg.GetModelByType(a.agentCfg.Model)
 983		if model.ID == "" {
 984			return fmt.Errorf("model not found for agent %s", a.agentCfg.Name)
 985		}
 986
 987		promptID := agentPromptMap[a.agentCfg.ID]
 988		if promptID == "" {
 989			promptID = prompt.PromptDefault
 990		}
 991
 992		opts := []provider.ProviderClientOption{
 993			provider.WithModel(a.agentCfg.Model),
 994			provider.WithSystemMessage(prompt.GetPrompt(promptID, currentProviderCfg.ID, cfg.Options.ContextPaths...)),
 995		}
 996
 997		newProvider, err := provider.NewProvider(*currentProviderCfg, opts...)
 998		if err != nil {
 999			return fmt.Errorf("failed to create new provider: %w", err)
1000		}
1001
1002		// Update the provider and provider ID
1003		a.provider = newProvider
1004		a.providerID = string(currentProviderCfg.ID)
1005	}
1006
1007	// Check if providers have changed for title (small) and summarize (large)
1008	smallModelCfg := cfg.Models[config.SelectedModelTypeSmall]
1009	var smallModelProviderCfg config.ProviderConfig
1010	for p := range cfg.Providers.Seq() {
1011		if p.ID == smallModelCfg.Provider {
1012			smallModelProviderCfg = p
1013			break
1014		}
1015	}
1016	if smallModelProviderCfg.ID == "" {
1017		return fmt.Errorf("provider %s not found in config", smallModelCfg.Provider)
1018	}
1019
1020	largeModelCfg := cfg.Models[config.SelectedModelTypeLarge]
1021	var largeModelProviderCfg config.ProviderConfig
1022	for p := range cfg.Providers.Seq() {
1023		if p.ID == largeModelCfg.Provider {
1024			largeModelProviderCfg = p
1025			break
1026		}
1027	}
1028	if largeModelProviderCfg.ID == "" {
1029		return fmt.Errorf("provider %s not found in config", largeModelCfg.Provider)
1030	}
1031
1032	var maxTitleTokens int64 = 40
1033
1034	// if the max output is too low for the gemini provider it won't return anything
1035	if smallModelCfg.Provider == "gemini" {
1036		maxTitleTokens = 1000
1037	}
1038	// Recreate title provider
1039	titleOpts := []provider.ProviderClientOption{
1040		provider.WithModel(config.SelectedModelTypeSmall),
1041		provider.WithSystemMessage(prompt.GetPrompt(prompt.PromptTitle, smallModelProviderCfg.ID)),
1042		provider.WithMaxTokens(maxTitleTokens),
1043	}
1044	newTitleProvider, err := provider.NewProvider(smallModelProviderCfg, titleOpts...)
1045	if err != nil {
1046		return fmt.Errorf("failed to create new title provider: %w", err)
1047	}
1048	a.titleProvider = newTitleProvider
1049
1050	// Recreate summarize provider if provider changed (now large model)
1051	if string(largeModelProviderCfg.ID) != a.summarizeProviderID {
1052		largeModel := cfg.GetModelByType(config.SelectedModelTypeLarge)
1053		if largeModel == nil {
1054			return fmt.Errorf("model %s not found in provider %s", largeModelCfg.Model, largeModelProviderCfg.ID)
1055		}
1056		summarizeOpts := []provider.ProviderClientOption{
1057			provider.WithModel(config.SelectedModelTypeLarge),
1058			provider.WithSystemMessage(prompt.GetPrompt(prompt.PromptSummarizer, largeModelProviderCfg.ID)),
1059		}
1060		newSummarizeProvider, err := provider.NewProvider(largeModelProviderCfg, summarizeOpts...)
1061		if err != nil {
1062			return fmt.Errorf("failed to create new summarize provider: %w", err)
1063		}
1064		a.summarizeProvider = newSummarizeProvider
1065		a.summarizeProviderID = string(largeModelProviderCfg.ID)
1066	}
1067
1068	return nil
1069}