1package agent
   2
   3import (
   4	"context"
   5	"errors"
   6	"fmt"
   7	"log/slog"
   8	"slices"
   9	"strings"
  10	"time"
  11
  12	"github.com/charmbracelet/catwalk/pkg/catwalk"
  13	"github.com/charmbracelet/crush/internal/config"
  14	"github.com/charmbracelet/crush/internal/csync"
  15	"github.com/charmbracelet/crush/internal/event"
  16	"github.com/charmbracelet/crush/internal/history"
  17	"github.com/charmbracelet/crush/internal/llm/prompt"
  18	"github.com/charmbracelet/crush/internal/llm/provider"
  19	"github.com/charmbracelet/crush/internal/llm/tools"
  20	"github.com/charmbracelet/crush/internal/log"
  21	"github.com/charmbracelet/crush/internal/lsp"
  22	"github.com/charmbracelet/crush/internal/message"
  23	"github.com/charmbracelet/crush/internal/permission"
  24	"github.com/charmbracelet/crush/internal/proto"
  25	"github.com/charmbracelet/crush/internal/pubsub"
  26	"github.com/charmbracelet/crush/internal/session"
  27	"github.com/charmbracelet/crush/internal/shell"
  28)
  29
// Local aliases for the proto event types, used throughout this package for
// the events the agent publishes and returns from Run.
type (
	AgentEventType = proto.AgentEventType
	AgentEvent     = proto.AgentEvent
)

// Event kinds produced by the agent (aliases of the proto constants):
//   - AgentEventTypeError: a request or summarization failed.
//   - AgentEventTypeResponse: the final response of a Run.
//   - AgentEventTypeSummarize: summarization progress and completion.
const (
	AgentEventTypeError     = proto.AgentEventTypeError
	AgentEventTypeResponse  = proto.AgentEventTypeResponse
	AgentEventTypeSummarize = proto.AgentEventTypeSummarize
)
  40
// Service is the public interface of an agent. Events the agent produces
// (final responses, errors and summarization progress) are published to
// subscribers through the embedded pubsub subscriber.
type Service interface {
	pubsub.Suscriber[AgentEvent]
	// Model returns the model the agent is currently configured to use.
	Model() catwalk.Model
	// Run submits content (plus optional attachments) as a user prompt for
	// sessionID. In this package's implementation, when the session already
	// has a request in flight the prompt text is queued and (nil, nil) is
	// returned; otherwise the returned channel yields the request's final
	// event.
	Run(ctx context.Context, sessionID string, content string, attachments ...message.Attachment) (<-chan AgentEvent, error)
	// Cancel stops the in-flight request and summarization for sessionID and
	// discards any prompts queued for it.
	Cancel(sessionID string)
	// CancelAll cancels every in-flight request and waits briefly for them
	// to stop.
	CancelAll()
	// IsSessionBusy reports whether a request is running for sessionID.
	IsSessionBusy(sessionID string) bool
	// IsBusy reports whether any request is running.
	IsBusy() bool
	// Summarize starts summarizing sessionID's conversation; progress and
	// completion are reported through published events.
	Summarize(ctx context.Context, sessionID string) error
	// UpdateModel re-resolves the agent's provider configuration and
	// recreates provider clients when it has changed.
	// NOTE(review): only part of the implementation is visible here — confirm
	// the full set of effects.
	UpdateModel() error
	// QueuedPrompts returns how many prompts are waiting for sessionID.
	QueuedPrompts(sessionID string) int
	// ClearQueue discards the prompts waiting for sessionID.
	ClearQueue(sessionID string)
}
  54
// agent is this package's Service implementation. It owns a provider client
// for the main model plus separate clients for title generation and
// summarization, and tracks in-flight requests and queued prompts per session.
type agent struct {
	// Broker publishes AgentEvents to subscribers.
	*pubsub.Broker[AgentEvent]
	agentCfg    config.Agent
	sessions    session.Service
	messages    message.Service
	permissions permission.Service
	// NOTE(review): NewAgent never sets this field; the tool set it builds
	// uses a package-level mcpTools instead. Confirm whether it is still needed.
	mcpTools    []McpTool
	cfg         *config.Config

	// tools holds the agent's base tool set, produced by the closure built in
	// NewAgent.
	tools *csync.LazySlice[tools.BaseTool]
	// We need this to be able to update it when model changes
	// agentToolFn, when non-nil, builds the "agent" tool backed by a task
	// sub-agent; it is invoked by getAllTools.
	agentToolFn func() (tools.BaseTool, error)

	// provider is the client for the agent's main model; providerID is the
	// ID of its provider configuration.
	provider   provider.Provider
	providerID string

	// Dedicated clients for session titles and conversation summaries.
	titleProvider       provider.Provider
	summarizeProvider   provider.Provider
	summarizeProviderID string

	// activeRequests maps a session ID (generations) or a session ID with a
	// "-summarize" suffix (summarizations) to the cancel func of that run.
	activeRequests *csync.Map[string, context.CancelFunc]
	// promptQueue holds prompt texts submitted via Run while the session was
	// busy, keyed by session ID.
	promptQueue    *csync.Map[string, []string]
}
  78
// agentPromptMap selects the system prompt for an agent by its ID. Agents not
// listed here fall back to prompt.PromptDefault (see NewAgent and UpdateModel).
var agentPromptMap = map[string]prompt.PromptID{
	"coder": prompt.PromptCoder,
	"task":  prompt.PromptTask,
}
  83
  84func NewAgent(
  85	ctx context.Context,
  86	cfg *config.Config,
  87	agentCfg config.Agent,
  88	// These services are needed in the tools
  89	permissions permission.Service,
  90	sessions session.Service,
  91	messages message.Service,
  92	history history.Service,
  93	lspClients *csync.Map[string, *lsp.Client],
  94) (Service, error) {
  95	var agentToolFn func() (tools.BaseTool, error)
  96	if agentCfg.ID == "coder" && slices.Contains(agentCfg.AllowedTools, AgentToolName) {
  97		agentToolFn = func() (tools.BaseTool, error) {
  98			taskAgentCfg := cfg.Agents["task"]
  99			if taskAgentCfg.ID == "" {
 100				return nil, fmt.Errorf("task agent not found in config")
 101			}
 102			taskAgent, err := NewAgent(ctx, cfg, taskAgentCfg, permissions, sessions, messages, history, lspClients)
 103			if err != nil {
 104				return nil, fmt.Errorf("failed to create task agent: %w", err)
 105			}
 106			return NewAgentTool(taskAgent, sessions, messages), nil
 107		}
 108	}
 109
 110	providerCfg := cfg.GetProviderForModel(agentCfg.Model)
 111	if providerCfg == nil {
 112		return nil, fmt.Errorf("provider for agent %s not found in config", agentCfg.Name)
 113	}
 114	model := cfg.GetModelByType(agentCfg.Model)
 115
 116	if model == nil {
 117		return nil, fmt.Errorf("model not found for agent %s", agentCfg.Name)
 118	}
 119
 120	promptID := agentPromptMap[agentCfg.ID]
 121	if promptID == "" {
 122		promptID = prompt.PromptDefault
 123	}
 124	opts := []provider.ProviderClientOption{
 125		provider.WithModel(agentCfg.Model),
 126		provider.WithSystemMessage(prompt.GetPrompt(cfg, promptID, providerCfg.ID, cfg.Options.ContextPaths...)),
 127	}
 128	agentProvider, err := provider.NewProvider(cfg, *providerCfg, opts...)
 129	if err != nil {
 130		return nil, err
 131	}
 132
 133	smallModelCfg := cfg.Models[config.SelectedModelTypeSmall]
 134	var smallModelProviderCfg *config.ProviderConfig
 135	if smallModelCfg.Provider == providerCfg.ID {
 136		smallModelProviderCfg = providerCfg
 137	} else {
 138		smallModelProviderCfg = cfg.GetProviderForModel(config.SelectedModelTypeSmall)
 139
 140		if smallModelProviderCfg.ID == "" {
 141			return nil, fmt.Errorf("provider %s not found in config", smallModelCfg.Provider)
 142		}
 143	}
 144	smallModel := cfg.GetModelByType(config.SelectedModelTypeSmall)
 145	if smallModel.ID == "" {
 146		return nil, fmt.Errorf("model %s not found in provider %s", smallModelCfg.Model, smallModelProviderCfg.ID)
 147	}
 148
 149	titleOpts := []provider.ProviderClientOption{
 150		provider.WithModel(config.SelectedModelTypeSmall),
 151		provider.WithSystemMessage(prompt.GetPrompt(cfg, prompt.PromptTitle, smallModelProviderCfg.ID)),
 152	}
 153	titleProvider, err := provider.NewProvider(cfg, *smallModelProviderCfg, titleOpts...)
 154	if err != nil {
 155		return nil, err
 156	}
 157
 158	summarizeOpts := []provider.ProviderClientOption{
 159		provider.WithModel(config.SelectedModelTypeLarge),
 160		provider.WithSystemMessage(prompt.GetPrompt(cfg, prompt.PromptSummarizer, providerCfg.ID)),
 161	}
 162	summarizeProvider, err := provider.NewProvider(cfg, *providerCfg, summarizeOpts...)
 163	if err != nil {
 164		return nil, err
 165	}
 166
 167	toolFn := func() []tools.BaseTool {
 168		slog.Info("Initializing agent tools", "agent", agentCfg.ID)
 169		defer func() {
 170			slog.Info("Initialized agent tools", "agent", agentCfg.ID)
 171		}()
 172
 173		cwd := cfg.WorkingDir()
 174		allTools := []tools.BaseTool{
 175			tools.NewBashTool(permissions, cwd, cfg.Options.Attribution),
 176			tools.NewDownloadTool(permissions, cwd),
 177			tools.NewEditTool(lspClients, permissions, history, cwd),
 178			tools.NewMultiEditTool(lspClients, permissions, history, cwd),
 179			tools.NewFetchTool(permissions, cwd),
 180			tools.NewGlobTool(cwd),
 181			tools.NewGrepTool(cwd),
 182			tools.NewLsTool(permissions, cwd),
 183			tools.NewSourcegraphTool(),
 184			tools.NewViewTool(lspClients, permissions, cwd),
 185			tools.NewWriteTool(lspClients, permissions, history, cwd),
 186		}
 187
 188		mcpToolsOnce.Do(func() {
 189			mcpTools = doGetMCPTools(ctx, permissions, cfg)
 190		})
 191
 192		withCoderTools := func(t []tools.BaseTool) []tools.BaseTool {
 193			if agentCfg.ID == "coder" {
 194				t = append(t, mcpTools...)
 195				if lspClients.Len() > 0 {
 196					t = append(t, tools.NewDiagnosticsTool(lspClients))
 197				}
 198			}
 199			return t
 200		}
 201
 202		if agentCfg.AllowedTools == nil {
 203			return withCoderTools(allTools)
 204		}
 205
 206		var filteredTools []tools.BaseTool
 207		for _, tool := range allTools {
 208			if slices.Contains(agentCfg.AllowedTools, tool.Name()) {
 209				filteredTools = append(filteredTools, tool)
 210			}
 211		}
 212		return withCoderTools(filteredTools)
 213	}
 214
 215	return &agent{
 216		Broker:              pubsub.NewBroker[AgentEvent](),
 217		agentCfg:            agentCfg,
 218		provider:            agentProvider,
 219		providerID:          string(providerCfg.ID),
 220		messages:            messages,
 221		sessions:            sessions,
 222		titleProvider:       titleProvider,
 223		summarizeProvider:   summarizeProvider,
 224		summarizeProviderID: string(providerCfg.ID),
 225		agentToolFn:         agentToolFn,
 226		activeRequests:      csync.NewMap[string, context.CancelFunc](),
 227		tools:               csync.NewLazySlice(toolFn),
 228		promptQueue:         csync.NewMap[string, []string](),
 229		cfg:                 cfg,
 230		permissions:         permissions,
 231	}, nil
 232}
 233
 234func (a *agent) Model() catwalk.Model {
 235	return *a.cfg.GetModelByType(a.agentCfg.Model)
 236}
 237
 238func (a *agent) Cancel(sessionID string) {
 239	// Cancel regular requests
 240	if cancel, ok := a.activeRequests.Take(sessionID); ok && cancel != nil {
 241		slog.Info("Request cancellation initiated", "session_id", sessionID)
 242		cancel()
 243	}
 244
 245	// Also check for summarize requests
 246	if cancel, ok := a.activeRequests.Take(sessionID + "-summarize"); ok && cancel != nil {
 247		slog.Info("Summarize cancellation initiated", "session_id", sessionID)
 248		cancel()
 249	}
 250
 251	if a.QueuedPrompts(sessionID) > 0 {
 252		slog.Info("Clearing queued prompts", "session_id", sessionID)
 253		a.promptQueue.Del(sessionID)
 254	}
 255}
 256
 257func (a *agent) IsBusy() bool {
 258	var busy bool
 259	for cancelFunc := range a.activeRequests.Seq() {
 260		if cancelFunc != nil {
 261			busy = true
 262			break
 263		}
 264	}
 265	return busy
 266}
 267
 268func (a *agent) IsSessionBusy(sessionID string) bool {
 269	_, busy := a.activeRequests.Get(sessionID)
 270	return busy
 271}
 272
 273func (a *agent) QueuedPrompts(sessionID string) int {
 274	l, ok := a.promptQueue.Get(sessionID)
 275	if !ok {
 276		return 0
 277	}
 278	return len(l)
 279}
 280
 281func (a *agent) generateTitle(ctx context.Context, sessionID string, content string) error {
 282	if content == "" {
 283		return nil
 284	}
 285	if a.titleProvider == nil {
 286		return nil
 287	}
 288	session, err := a.sessions.Get(ctx, sessionID)
 289	if err != nil {
 290		return err
 291	}
 292	parts := []message.ContentPart{message.TextContent{
 293		Text: fmt.Sprintf("Generate a concise title for the following content:\n\n%s", content),
 294	}}
 295
 296	// Use streaming approach like summarization
 297	response := a.titleProvider.StreamResponse(
 298		ctx,
 299		[]message.Message{
 300			{
 301				Role:  message.User,
 302				Parts: parts,
 303			},
 304		},
 305		nil,
 306	)
 307
 308	var finalResponse *provider.ProviderResponse
 309	for r := range response {
 310		if r.Error != nil {
 311			return r.Error
 312		}
 313		finalResponse = r.Response
 314	}
 315
 316	if finalResponse == nil {
 317		return fmt.Errorf("no response received from title provider")
 318	}
 319
 320	title := strings.ReplaceAll(finalResponse.Content, "\n", " ")
 321
 322	if idx := strings.Index(title, "</think>"); idx > 0 {
 323		title = title[idx+len("</think>"):]
 324	}
 325
 326	title = strings.TrimSpace(title)
 327	if title == "" {
 328		return nil
 329	}
 330
 331	session.Title = title
 332	_, err = a.sessions.Save(ctx, session)
 333	return err
 334}
 335
 336func (a *agent) err(err error) AgentEvent {
 337	return AgentEvent{
 338		Type:  AgentEventTypeError,
 339		Error: err,
 340	}
 341}
 342
 343func (a *agent) Run(ctx context.Context, sessionID string, content string, attachments ...message.Attachment) (<-chan AgentEvent, error) {
 344	if !a.Model().SupportsImages && attachments != nil {
 345		attachments = nil
 346	}
 347	events := make(chan AgentEvent, 1)
 348	if a.IsSessionBusy(sessionID) {
 349		existing, ok := a.promptQueue.Get(sessionID)
 350		if !ok {
 351			existing = []string{}
 352		}
 353		existing = append(existing, content)
 354		a.promptQueue.Set(sessionID, existing)
 355		return nil, nil
 356	}
 357
 358	genCtx, cancel := context.WithCancel(ctx)
 359	a.activeRequests.Set(sessionID, cancel)
 360	startTime := time.Now()
 361
 362	go func() {
 363		slog.Debug("Request started", "sessionID", sessionID)
 364		defer log.RecoverPanic("agent.Run", func() {
 365			events <- a.err(fmt.Errorf("panic while running the agent"))
 366		})
 367		var attachmentParts []message.ContentPart
 368		for _, attachment := range attachments {
 369			attachmentParts = append(attachmentParts, message.BinaryContent{Path: attachment.FilePath, MIMEType: attachment.MimeType, Data: attachment.Content})
 370		}
 371		result := a.processGeneration(genCtx, sessionID, content, attachmentParts)
 372		if result.Error != nil {
 373			if isCancelledErr(result.Error) {
 374				slog.Error("Request canceled", "sessionID", sessionID)
 375			} else {
 376				slog.Error("Request errored", "sessionID", sessionID, "error", result.Error.Error())
 377				event.Error(result.Error)
 378			}
 379		} else {
 380			slog.Debug("Request completed", "sessionID", sessionID)
 381		}
 382		a.eventPromptResponded(sessionID, time.Since(startTime).Truncate(time.Second))
 383		a.activeRequests.Del(sessionID)
 384		cancel()
 385		a.Publish(pubsub.CreatedEvent, result)
 386		events <- result
 387		close(events)
 388	}()
 389	a.eventPromptSent(sessionID)
 390	return events, nil
 391}
 392
 393func (a *agent) processGeneration(ctx context.Context, sessionID, content string, attachmentParts []message.ContentPart) AgentEvent {
 394	// List existing messages; if none, start title generation asynchronously.
 395	msgs, err := a.messages.List(ctx, sessionID)
 396	if err != nil {
 397		return a.err(fmt.Errorf("failed to list messages: %w", err))
 398	}
 399	if len(msgs) == 0 {
 400		go func() {
 401			defer log.RecoverPanic("agent.Run", func() {
 402				slog.Error("panic while generating title")
 403			})
 404			titleErr := a.generateTitle(ctx, sessionID, content)
 405			if titleErr != nil && !errors.Is(titleErr, context.Canceled) && !errors.Is(titleErr, context.DeadlineExceeded) {
 406				slog.Error("failed to generate title", "error", titleErr)
 407			}
 408		}()
 409	}
 410	session, err := a.sessions.Get(ctx, sessionID)
 411	if err != nil {
 412		return a.err(fmt.Errorf("failed to get session: %w", err))
 413	}
 414	if session.SummaryMessageID != "" {
 415		summaryMsgInex := -1
 416		for i, msg := range msgs {
 417			if msg.ID == session.SummaryMessageID {
 418				summaryMsgInex = i
 419				break
 420			}
 421		}
 422		if summaryMsgInex != -1 {
 423			msgs = msgs[summaryMsgInex:]
 424			msgs[0].Role = message.User
 425		}
 426	}
 427
 428	userMsg, err := a.createUserMessage(ctx, sessionID, content, attachmentParts)
 429	if err != nil {
 430		return a.err(fmt.Errorf("failed to create user message: %w", err))
 431	}
 432	// Append the new user message to the conversation history.
 433	msgHistory := append(msgs, userMsg)
 434
 435	for {
 436		// Check for cancellation before each iteration
 437		select {
 438		case <-ctx.Done():
 439			return a.err(ctx.Err())
 440		default:
 441			// Continue processing
 442		}
 443		agentMessage, toolResults, err := a.streamAndHandleEvents(ctx, sessionID, msgHistory)
 444		if err != nil {
 445			if errors.Is(err, context.Canceled) {
 446				agentMessage.AddFinish(message.FinishReasonCanceled, "Request cancelled", "")
 447				a.messages.Update(context.Background(), agentMessage)
 448				return a.err(ErrRequestCancelled)
 449			}
 450			return a.err(fmt.Errorf("failed to process events: %w", err))
 451		}
 452		if a.cfg.Options.Debug {
 453			slog.Info("Result", "message", agentMessage.FinishReason(), "toolResults", toolResults)
 454		}
 455		if (agentMessage.FinishReason() == message.FinishReasonToolUse) && toolResults != nil {
 456			// We are not done, we need to respond with the tool response
 457			msgHistory = append(msgHistory, agentMessage, *toolResults)
 458			// If there are queued prompts, process the next one
 459			nextPrompt, ok := a.promptQueue.Take(sessionID)
 460			if ok {
 461				for _, prompt := range nextPrompt {
 462					// Create a new user message for the queued prompt
 463					userMsg, err := a.createUserMessage(ctx, sessionID, prompt, nil)
 464					if err != nil {
 465						return a.err(fmt.Errorf("failed to create user message for queued prompt: %w", err))
 466					}
 467					// Append the new user message to the conversation history
 468					msgHistory = append(msgHistory, userMsg)
 469				}
 470			}
 471
 472			continue
 473		} else if agentMessage.FinishReason() == message.FinishReasonEndTurn {
 474			queuePrompts, ok := a.promptQueue.Take(sessionID)
 475			if ok {
 476				for _, prompt := range queuePrompts {
 477					if prompt == "" {
 478						continue
 479					}
 480					userMsg, err := a.createUserMessage(ctx, sessionID, prompt, nil)
 481					if err != nil {
 482						return a.err(fmt.Errorf("failed to create user message for queued prompt: %w", err))
 483					}
 484					msgHistory = append(msgHistory, userMsg)
 485				}
 486				continue
 487			}
 488		}
 489		if agentMessage.FinishReason() == "" {
 490			// Kujtim: could not track down where this is happening but this means its cancelled
 491			agentMessage.AddFinish(message.FinishReasonCanceled, "Request cancelled", "")
 492			_ = a.messages.Update(context.Background(), agentMessage)
 493			return a.err(ErrRequestCancelled)
 494		}
 495		return AgentEvent{
 496			Type:    AgentEventTypeResponse,
 497			Message: agentMessage,
 498			Done:    true,
 499		}
 500	}
 501}
 502
 503func (a *agent) createUserMessage(ctx context.Context, sessionID, content string, attachmentParts []message.ContentPart) (message.Message, error) {
 504	parts := []message.ContentPart{message.TextContent{Text: content}}
 505	parts = append(parts, attachmentParts...)
 506	return a.messages.Create(ctx, sessionID, message.CreateMessageParams{
 507		Role:  message.User,
 508		Parts: parts,
 509	})
 510}
 511
 512func (a *agent) getAllTools() ([]tools.BaseTool, error) {
 513	allTools := slices.Collect(a.tools.Seq())
 514	if a.agentToolFn != nil {
 515		agentTool, agentToolErr := a.agentToolFn()
 516		if agentToolErr != nil {
 517			return nil, agentToolErr
 518		}
 519		allTools = append(allTools, agentTool)
 520	}
 521	return allTools, nil
 522}
 523
 524func (a *agent) streamAndHandleEvents(ctx context.Context, sessionID string, msgHistory []message.Message) (message.Message, *message.Message, error) {
 525	ctx = context.WithValue(ctx, tools.SessionIDContextKey, sessionID)
 526
 527	// Create the assistant message first so the spinner shows immediately
 528	assistantMsg, err := a.messages.Create(ctx, sessionID, message.CreateMessageParams{
 529		Role:     message.Assistant,
 530		Parts:    []message.ContentPart{},
 531		Model:    a.Model().ID,
 532		Provider: a.providerID,
 533	})
 534	if err != nil {
 535		return assistantMsg, nil, fmt.Errorf("failed to create assistant message: %w", err)
 536	}
 537
 538	allTools, toolsErr := a.getAllTools()
 539	if toolsErr != nil {
 540		return assistantMsg, nil, toolsErr
 541	}
 542	// Now collect tools (which may block on MCP initialization)
 543	eventChan := a.provider.StreamResponse(ctx, msgHistory, allTools)
 544
 545	// Add the session and message ID into the context if needed by tools.
 546	ctx = context.WithValue(ctx, tools.MessageIDContextKey, assistantMsg.ID)
 547
 548	// Process each event in the stream.
 549	for event := range eventChan {
 550		if processErr := a.processEvent(ctx, sessionID, &assistantMsg, event); processErr != nil {
 551			if errors.Is(processErr, context.Canceled) {
 552				a.finishMessage(context.Background(), &assistantMsg, message.FinishReasonCanceled, "Request cancelled", "")
 553			} else {
 554				a.finishMessage(ctx, &assistantMsg, message.FinishReasonError, "API Error", processErr.Error())
 555			}
 556			return assistantMsg, nil, processErr
 557		}
 558		if ctx.Err() != nil {
 559			a.finishMessage(context.Background(), &assistantMsg, message.FinishReasonCanceled, "Request cancelled", "")
 560			return assistantMsg, nil, ctx.Err()
 561		}
 562	}
 563
 564	toolResults := make([]message.ToolResult, len(assistantMsg.ToolCalls()))
 565	toolCalls := assistantMsg.ToolCalls()
 566	for i, toolCall := range toolCalls {
 567		select {
 568		case <-ctx.Done():
 569			a.finishMessage(context.Background(), &assistantMsg, message.FinishReasonCanceled, "Request cancelled", "")
 570			// Make all future tool calls cancelled
 571			for j := i; j < len(toolCalls); j++ {
 572				toolResults[j] = message.ToolResult{
 573					ToolCallID: toolCalls[j].ID,
 574					Content:    "Tool execution canceled by user",
 575					IsError:    true,
 576				}
 577			}
 578			goto out
 579		default:
 580			// Continue processing
 581			var tool tools.BaseTool
 582			allTools, _ := a.getAllTools()
 583			for _, availableTool := range allTools {
 584				if availableTool.Info().Name == toolCall.Name {
 585					tool = availableTool
 586					break
 587				}
 588			}
 589
 590			// Tool not found
 591			if tool == nil {
 592				toolResults[i] = message.ToolResult{
 593					ToolCallID: toolCall.ID,
 594					Content:    fmt.Sprintf("Tool not found: %s", toolCall.Name),
 595					IsError:    true,
 596				}
 597				continue
 598			}
 599
 600			// Run tool in goroutine to allow cancellation
 601			type toolExecResult struct {
 602				response tools.ToolResponse
 603				err      error
 604			}
 605			resultChan := make(chan toolExecResult, 1)
 606
 607			go func() {
 608				response, err := tool.Run(ctx, tools.ToolCall{
 609					ID:    toolCall.ID,
 610					Name:  toolCall.Name,
 611					Input: toolCall.Input,
 612				})
 613				resultChan <- toolExecResult{response: response, err: err}
 614			}()
 615
 616			var toolResponse tools.ToolResponse
 617			var toolErr error
 618
 619			select {
 620			case <-ctx.Done():
 621				a.finishMessage(context.Background(), &assistantMsg, message.FinishReasonCanceled, "Request cancelled", "")
 622				// Mark remaining tool calls as cancelled
 623				for j := i; j < len(toolCalls); j++ {
 624					toolResults[j] = message.ToolResult{
 625						ToolCallID: toolCalls[j].ID,
 626						Content:    "Tool execution canceled by user",
 627						IsError:    true,
 628					}
 629				}
 630				goto out
 631			case result := <-resultChan:
 632				toolResponse = result.response
 633				toolErr = result.err
 634			}
 635
 636			if toolErr != nil {
 637				slog.Error("Tool execution error", "toolCall", toolCall.ID, "error", toolErr)
 638				if errors.Is(toolErr, permission.ErrorPermissionDenied) {
 639					toolResults[i] = message.ToolResult{
 640						ToolCallID: toolCall.ID,
 641						Content:    "Permission denied",
 642						IsError:    true,
 643					}
 644					for j := i + 1; j < len(toolCalls); j++ {
 645						toolResults[j] = message.ToolResult{
 646							ToolCallID: toolCalls[j].ID,
 647							Content:    "Tool execution canceled by user",
 648							IsError:    true,
 649						}
 650					}
 651					a.finishMessage(ctx, &assistantMsg, message.FinishReasonPermissionDenied, "Permission denied", "")
 652					break
 653				}
 654			}
 655			toolResults[i] = message.ToolResult{
 656				ToolCallID: toolCall.ID,
 657				Content:    toolResponse.Content,
 658				Metadata:   toolResponse.Metadata,
 659				IsError:    toolResponse.IsError,
 660			}
 661		}
 662	}
 663out:
 664	if len(toolResults) == 0 {
 665		return assistantMsg, nil, nil
 666	}
 667	parts := make([]message.ContentPart, 0)
 668	for _, tr := range toolResults {
 669		parts = append(parts, tr)
 670	}
 671	msg, err := a.messages.Create(context.Background(), assistantMsg.SessionID, message.CreateMessageParams{
 672		Role:     message.Tool,
 673		Parts:    parts,
 674		Provider: a.providerID,
 675	})
 676	if err != nil {
 677		return assistantMsg, nil, fmt.Errorf("failed to create cancelled tool message: %w", err)
 678	}
 679
 680	return assistantMsg, &msg, err
 681}
 682
 683func (a *agent) finishMessage(ctx context.Context, msg *message.Message, finishReason message.FinishReason, message, details string) {
 684	msg.AddFinish(finishReason, message, details)
 685	_ = a.messages.Update(ctx, *msg)
 686}
 687
 688func (a *agent) processEvent(ctx context.Context, sessionID string, assistantMsg *message.Message, event provider.ProviderEvent) error {
 689	select {
 690	case <-ctx.Done():
 691		return ctx.Err()
 692	default:
 693		// Continue processing.
 694	}
 695
 696	switch event.Type {
 697	case provider.EventThinkingDelta:
 698		assistantMsg.AppendReasoningContent(event.Thinking)
 699		return a.messages.Update(ctx, *assistantMsg)
 700	case provider.EventSignatureDelta:
 701		assistantMsg.AppendReasoningSignature(event.Signature)
 702		return a.messages.Update(ctx, *assistantMsg)
 703	case provider.EventContentDelta:
 704		assistantMsg.FinishThinking()
 705		assistantMsg.AppendContent(event.Content)
 706		return a.messages.Update(ctx, *assistantMsg)
 707	case provider.EventToolUseStart:
 708		assistantMsg.FinishThinking()
 709		slog.Info("Tool call started", "toolCall", event.ToolCall)
 710		assistantMsg.AddToolCall(*event.ToolCall)
 711		return a.messages.Update(ctx, *assistantMsg)
 712	case provider.EventToolUseDelta:
 713		assistantMsg.AppendToolCallInput(event.ToolCall.ID, event.ToolCall.Input)
 714		return a.messages.Update(ctx, *assistantMsg)
 715	case provider.EventToolUseStop:
 716		slog.Info("Finished tool call", "toolCall", event.ToolCall)
 717		assistantMsg.FinishToolCall(event.ToolCall.ID)
 718		return a.messages.Update(ctx, *assistantMsg)
 719	case provider.EventError:
 720		return event.Error
 721	case provider.EventComplete:
 722		assistantMsg.FinishThinking()
 723		assistantMsg.SetToolCalls(event.Response.ToolCalls)
 724		assistantMsg.AddFinish(event.Response.FinishReason, "", "")
 725		if err := a.messages.Update(ctx, *assistantMsg); err != nil {
 726			return fmt.Errorf("failed to update message: %w", err)
 727		}
 728		return a.trackUsage(ctx, sessionID, a.Model(), event.Response.Usage)
 729	}
 730
 731	return nil
 732}
 733
 734func (a *agent) trackUsage(ctx context.Context, sessionID string, model catwalk.Model, usage provider.TokenUsage) error {
 735	sess, err := a.sessions.Get(ctx, sessionID)
 736	if err != nil {
 737		return fmt.Errorf("failed to get session: %w", err)
 738	}
 739
 740	cost := model.CostPer1MInCached/1e6*float64(usage.CacheCreationTokens) +
 741		model.CostPer1MOutCached/1e6*float64(usage.CacheReadTokens) +
 742		model.CostPer1MIn/1e6*float64(usage.InputTokens) +
 743		model.CostPer1MOut/1e6*float64(usage.OutputTokens)
 744
 745	a.eventTokensUsed(sessionID, usage, cost)
 746
 747	sess.Cost += cost
 748	sess.CompletionTokens = usage.OutputTokens + usage.CacheReadTokens
 749	sess.PromptTokens = usage.InputTokens + usage.CacheCreationTokens
 750
 751	_, err = a.sessions.Save(ctx, sess)
 752	if err != nil {
 753		return fmt.Errorf("failed to save session: %w", err)
 754	}
 755	return nil
 756}
 757
 758func (a *agent) Summarize(ctx context.Context, sessionID string) error {
 759	if a.summarizeProvider == nil {
 760		return fmt.Errorf("summarize provider not available")
 761	}
 762
 763	// Check if session is busy
 764	if a.IsSessionBusy(sessionID) {
 765		return ErrSessionBusy
 766	}
 767
 768	// Create a new context with cancellation
 769	summarizeCtx, cancel := context.WithCancel(ctx)
 770
 771	// Store the cancel function in activeRequests to allow cancellation
 772	a.activeRequests.Set(sessionID+"-summarize", cancel)
 773
 774	go func() {
 775		defer a.activeRequests.Del(sessionID + "-summarize")
 776		defer cancel()
 777		event := AgentEvent{
 778			Type:     AgentEventTypeSummarize,
 779			Progress: "Starting summarization...",
 780		}
 781
 782		a.Publish(pubsub.CreatedEvent, event)
 783		// Get all messages from the session
 784		msgs, err := a.messages.List(summarizeCtx, sessionID)
 785		if err != nil {
 786			event = AgentEvent{
 787				Type:  AgentEventTypeError,
 788				Error: fmt.Errorf("failed to list messages: %w", err),
 789				Done:  true,
 790			}
 791			a.Publish(pubsub.CreatedEvent, event)
 792			return
 793		}
 794		summarizeCtx = context.WithValue(summarizeCtx, tools.SessionIDContextKey, sessionID)
 795
 796		if len(msgs) == 0 {
 797			event = AgentEvent{
 798				Type:  AgentEventTypeError,
 799				Error: fmt.Errorf("no messages to summarize"),
 800				Done:  true,
 801			}
 802			a.Publish(pubsub.CreatedEvent, event)
 803			return
 804		}
 805
 806		event = AgentEvent{
 807			Type:     AgentEventTypeSummarize,
 808			Progress: "Analyzing conversation...",
 809		}
 810		a.Publish(pubsub.CreatedEvent, event)
 811
 812		// Add a system message to guide the summarization
 813		summarizePrompt := "Provide a detailed but concise summary of our conversation above. Focus on information that would be helpful for continuing the conversation, including what we did, what we're doing, which files we're working on, and what we're going to do next."
 814
 815		// Create a new message with the summarize prompt
 816		promptMsg := message.Message{
 817			Role:  message.User,
 818			Parts: []message.ContentPart{message.TextContent{Text: summarizePrompt}},
 819		}
 820
 821		// Append the prompt to the messages
 822		msgsWithPrompt := append(msgs, promptMsg)
 823
 824		event = AgentEvent{
 825			Type:     AgentEventTypeSummarize,
 826			Progress: "Generating summary...",
 827		}
 828
 829		a.Publish(pubsub.CreatedEvent, event)
 830
 831		// Send the messages to the summarize provider
 832		response := a.summarizeProvider.StreamResponse(
 833			summarizeCtx,
 834			msgsWithPrompt,
 835			nil,
 836		)
 837		var finalResponse *provider.ProviderResponse
 838		for r := range response {
 839			if r.Error != nil {
 840				event = AgentEvent{
 841					Type:  AgentEventTypeError,
 842					Error: fmt.Errorf("failed to summarize: %w", r.Error),
 843					Done:  true,
 844				}
 845				a.Publish(pubsub.CreatedEvent, event)
 846				return
 847			}
 848			finalResponse = r.Response
 849		}
 850
 851		summary := strings.TrimSpace(finalResponse.Content)
 852		if summary == "" {
 853			event = AgentEvent{
 854				Type:  AgentEventTypeError,
 855				Error: fmt.Errorf("empty summary returned"),
 856				Done:  true,
 857			}
 858			a.Publish(pubsub.CreatedEvent, event)
 859			return
 860		}
 861		shell := shell.GetPersistentShell(a.cfg.WorkingDir())
 862		summary += "\n\n**Current working directory of the persistent shell**\n\n" + shell.GetWorkingDir()
 863		event = AgentEvent{
 864			Type:     AgentEventTypeSummarize,
 865			Progress: "Creating new session...",
 866		}
 867
 868		a.Publish(pubsub.CreatedEvent, event)
 869		oldSession, err := a.sessions.Get(summarizeCtx, sessionID)
 870		if err != nil {
 871			event = AgentEvent{
 872				Type:  AgentEventTypeError,
 873				Error: fmt.Errorf("failed to get session: %w", err),
 874				Done:  true,
 875			}
 876
 877			a.Publish(pubsub.CreatedEvent, event)
 878			return
 879		}
 880		// Create a message in the new session with the summary
 881		msg, err := a.messages.Create(summarizeCtx, oldSession.ID, message.CreateMessageParams{
 882			Role: message.Assistant,
 883			Parts: []message.ContentPart{
 884				message.TextContent{Text: summary},
 885				message.Finish{
 886					Reason: message.FinishReasonEndTurn,
 887					Time:   time.Now().Unix(),
 888				},
 889			},
 890			Model:    a.summarizeProvider.Model().ID,
 891			Provider: a.summarizeProviderID,
 892		})
 893		if err != nil {
 894			event = AgentEvent{
 895				Type:  AgentEventTypeError,
 896				Error: fmt.Errorf("failed to create summary message: %w", err),
 897				Done:  true,
 898			}
 899
 900			a.Publish(pubsub.CreatedEvent, event)
 901			return
 902		}
 903		oldSession.SummaryMessageID = msg.ID
 904		oldSession.CompletionTokens = finalResponse.Usage.OutputTokens
 905		oldSession.PromptTokens = 0
 906		model := a.summarizeProvider.Model()
 907		usage := finalResponse.Usage
 908		cost := model.CostPer1MInCached/1e6*float64(usage.CacheCreationTokens) +
 909			model.CostPer1MOutCached/1e6*float64(usage.CacheReadTokens) +
 910			model.CostPer1MIn/1e6*float64(usage.InputTokens) +
 911			model.CostPer1MOut/1e6*float64(usage.OutputTokens)
 912		oldSession.Cost += cost
 913		_, err = a.sessions.Save(summarizeCtx, oldSession)
 914		if err != nil {
 915			event = AgentEvent{
 916				Type:  AgentEventTypeError,
 917				Error: fmt.Errorf("failed to save session: %w", err),
 918				Done:  true,
 919			}
 920			a.Publish(pubsub.CreatedEvent, event)
 921		}
 922
 923		event = AgentEvent{
 924			Type:      AgentEventTypeSummarize,
 925			SessionID: oldSession.ID,
 926			Progress:  "Summary complete",
 927			Done:      true,
 928		}
 929		a.Publish(pubsub.CreatedEvent, event)
 930		// Send final success event with the new session ID
 931	}()
 932
 933	return nil
 934}
 935
 936func (a *agent) ClearQueue(sessionID string) {
 937	if a.QueuedPrompts(sessionID) > 0 {
 938		slog.Info("Clearing queued prompts", "session_id", sessionID)
 939		a.promptQueue.Del(sessionID)
 940	}
 941}
 942
 943func (a *agent) CancelAll() {
 944	if !a.IsBusy() {
 945		return
 946	}
 947	for key := range a.activeRequests.Seq2() {
 948		a.Cancel(key) // key is sessionID
 949	}
 950
 951	timeout := time.After(5 * time.Second)
 952	for a.IsBusy() {
 953		select {
 954		case <-timeout:
 955			return
 956		default:
 957			time.Sleep(200 * time.Millisecond)
 958		}
 959	}
 960}
 961
 962func (a *agent) UpdateModel() error {
 963	// Get current provider configuration
 964	currentProviderCfg := a.cfg.GetProviderForModel(a.agentCfg.Model)
 965	if currentProviderCfg == nil || currentProviderCfg.ID == "" {
 966		return fmt.Errorf("provider for agent %s not found in config", a.agentCfg.Name)
 967	}
 968
 969	// Check if provider has changed
 970	if string(currentProviderCfg.ID) != a.providerID {
 971		// Provider changed, need to recreate the main provider
 972		model := a.cfg.GetModelByType(a.agentCfg.Model)
 973		if model.ID == "" {
 974			return fmt.Errorf("model not found for agent %s", a.agentCfg.Name)
 975		}
 976
 977		promptID := agentPromptMap[a.agentCfg.ID]
 978		if promptID == "" {
 979			promptID = prompt.PromptDefault
 980		}
 981
 982		opts := []provider.ProviderClientOption{
 983			provider.WithModel(a.agentCfg.Model),
 984			provider.WithSystemMessage(prompt.GetPrompt(a.cfg, promptID, currentProviderCfg.ID, a.cfg.Options.ContextPaths...)),
 985		}
 986
 987		newProvider, err := provider.NewProvider(a.cfg, *currentProviderCfg, opts...)
 988		if err != nil {
 989			return fmt.Errorf("failed to create new provider: %w", err)
 990		}
 991
 992		// Update the provider and provider ID
 993		a.provider = newProvider
 994		a.providerID = string(currentProviderCfg.ID)
 995	}
 996
 997	// Check if providers have changed for title (small) and summarize (large)
 998	smallModelCfg := a.cfg.Models[config.SelectedModelTypeSmall]
 999	var smallModelProviderCfg config.ProviderConfig
1000	for p := range a.cfg.Providers.Seq() {
1001		if p.ID == smallModelCfg.Provider {
1002			smallModelProviderCfg = p
1003			break
1004		}
1005	}
1006	if smallModelProviderCfg.ID == "" {
1007		return fmt.Errorf("provider %s not found in config", smallModelCfg.Provider)
1008	}
1009
1010	largeModelCfg := a.cfg.Models[config.SelectedModelTypeLarge]
1011	var largeModelProviderCfg config.ProviderConfig
1012	for p := range a.cfg.Providers.Seq() {
1013		if p.ID == largeModelCfg.Provider {
1014			largeModelProviderCfg = p
1015			break
1016		}
1017	}
1018	if largeModelProviderCfg.ID == "" {
1019		return fmt.Errorf("provider %s not found in config", largeModelCfg.Provider)
1020	}
1021
1022	var maxTitleTokens int64 = 40
1023
1024	// if the max output is too low for the gemini provider it won't return anything
1025	if smallModelCfg.Provider == "gemini" {
1026		maxTitleTokens = 1000
1027	}
1028	// Recreate title provider
1029	titleOpts := []provider.ProviderClientOption{
1030		provider.WithModel(config.SelectedModelTypeSmall),
1031		provider.WithSystemMessage(prompt.GetPrompt(a.cfg, prompt.PromptTitle, smallModelProviderCfg.ID)),
1032		provider.WithMaxTokens(maxTitleTokens),
1033	}
1034	newTitleProvider, err := provider.NewProvider(a.cfg, smallModelProviderCfg, titleOpts...)
1035	if err != nil {
1036		return fmt.Errorf("failed to create new title provider: %w", err)
1037	}
1038	a.titleProvider = newTitleProvider
1039
1040	// Recreate summarize provider if provider changed (now large model)
1041	if string(largeModelProviderCfg.ID) != a.summarizeProviderID {
1042		largeModel := a.cfg.GetModelByType(config.SelectedModelTypeLarge)
1043		if largeModel == nil {
1044			return fmt.Errorf("model %s not found in provider %s", largeModelCfg.Model, largeModelProviderCfg.ID)
1045		}
1046		summarizeOpts := []provider.ProviderClientOption{
1047			provider.WithModel(config.SelectedModelTypeLarge),
1048			provider.WithSystemMessage(prompt.GetPrompt(a.cfg, prompt.PromptSummarizer, largeModelProviderCfg.ID)),
1049		}
1050		newSummarizeProvider, err := provider.NewProvider(a.cfg, largeModelProviderCfg, summarizeOpts...)
1051		if err != nil {
1052			return fmt.Errorf("failed to create new summarize provider: %w", err)
1053		}
1054		a.summarizeProvider = newSummarizeProvider
1055		a.summarizeProviderID = string(largeModelProviderCfg.ID)
1056	}
1057
1058	return nil
1059}