1package agent
   2
   3import (
   4	"context"
   5	"errors"
   6	"fmt"
   7	"log/slog"
   8	"slices"
   9	"strings"
  10	"time"
  11
  12	"github.com/charmbracelet/catwalk/pkg/catwalk"
  13	"github.com/charmbracelet/crush/internal/config"
  14	"github.com/charmbracelet/crush/internal/csync"
  15	"github.com/charmbracelet/crush/internal/event"
  16	"github.com/charmbracelet/crush/internal/history"
  17	"github.com/charmbracelet/crush/internal/llm/prompt"
  18	"github.com/charmbracelet/crush/internal/llm/provider"
  19	"github.com/charmbracelet/crush/internal/llm/tools"
  20	"github.com/charmbracelet/crush/internal/log"
  21	"github.com/charmbracelet/crush/internal/lsp"
  22	"github.com/charmbracelet/crush/internal/message"
  23	"github.com/charmbracelet/crush/internal/permission"
  24	"github.com/charmbracelet/crush/internal/pubsub"
  25	"github.com/charmbracelet/crush/internal/session"
  26	"github.com/charmbracelet/crush/internal/shell"
  27)
  28
  29type AgentEventType string
  30
  31const (
  32	AgentEventTypeError     AgentEventType = "error"
  33	AgentEventTypeResponse  AgentEventType = "response"
  34	AgentEventTypeSummarize AgentEventType = "summarize"
  35)
  36
// AgentEvent is what the agent publishes to subscribers and sends on the
// channel returned by Run. Type determines which fields are populated.
type AgentEvent struct {
	// Type selects which of the fields below are meaningful: Message holds the
	// assistant's reply for response events, Error is set for error events.
	Type    AgentEventType
	Message message.Message
	Error   error

	// When summarizing: Progress carries a human-readable status line and
	// SessionID is set on the final "Summary complete" event. Done marks the
	// last event of a summarization and is also set on successful Run
	// responses.
	SessionID string
	Progress  string
	Done      bool
}
  47
// Service is the public interface of an agent: it runs prompts against a
// model, manages per-session work, and publishes AgentEvents to subscribers.
type Service interface {
	// Subscribers receive the events the agent publishes (Run results and
	// Summarize progress).
	pubsub.Suscriber[AgentEvent]
	// Model returns the model currently configured for the agent.
	Model() catwalk.Model
	// Run sends content (plus attachments, when the model supports images) to
	// the session. The returned channel yields the final result event and is
	// then closed. If the session already has a request in flight, content is
	// queued for it (attachments are not) and Run returns a nil channel and a
	// nil error.
	Run(ctx context.Context, sessionID string, content string, attachments ...message.Attachment) (<-chan AgentEvent, error)
	// Cancel stops the session's running request and summarization and drops
	// its queued prompts.
	Cancel(sessionID string)
	// CancelAll cancels all running work and waits up to five seconds for it
	// to stop.
	CancelAll()
	// IsSessionBusy reports whether a Run request is in flight for the
	// session; summarizations are not counted.
	IsSessionBusy(sessionID string) bool
	// IsBusy reports whether any request or summarization is in flight.
	IsBusy() bool
	// Summarize condenses the session's conversation in the background and
	// reports progress through published events. It fails with ErrSessionBusy
	// while a Run request is in flight.
	Summarize(ctx context.Context, sessionID string) error
	// UpdateModel re-reads the agent's model configuration and refreshes its
	// provider accordingly.
	UpdateModel() error
	// QueuedPrompts returns how many prompts are waiting for the session.
	QueuedPrompts(sessionID string) int
	// ClearQueue discards the session's queued prompts.
	ClearQueue(sessionID string)
}
  61
// agent is the Service implementation in this file: it drives a model
// provider, executes tools, and publishes AgentEvents on its embedded broker.
type agent struct {
	*pubsub.Broker[AgentEvent]
	// agentCfg is the configuration this agent was built from.
	agentCfg    config.Agent
	sessions    session.Service
	messages    message.Service
	permissions permission.Service
	// NOTE(review): NewAgent does not populate this field; its tool setup
	// assigns the package-level mcpTools instead. Possibly unused — confirm.
	mcpTools    []McpTool

	// tools is built lazily on first use; MCP tools and the diagnostics tool
	// are only added for the "coder" agent.
	tools *csync.LazySlice[tools.BaseTool]
	// agentToolFn, when non-nil, builds the task sub-agent tool on demand.
	// We need this to be able to update it when model changes
	agentToolFn func() (tools.BaseTool, error)

	// provider serves the main conversation; providerID names its provider
	// config and is recorded on the messages the agent creates.
	provider   provider.Provider
	providerID string

	// titleProvider is configured with the small model and the title prompt;
	// summarizeProvider with the large model and the summarizer prompt.
	titleProvider       provider.Provider
	summarizeProvider   provider.Provider
	summarizeProviderID string

	// activeRequests maps a session ID (or "<session ID>-summarize") to the
	// cancel function of the work currently running for it.
	activeRequests *csync.Map[string, context.CancelFunc]
	// promptQueue holds prompts submitted via Run while the session was busy.
	promptQueue    *csync.Map[string, []string]
}
  84
// agentPromptMap selects the system prompt for an agent by its config ID.
// Agents not listed here fall back to prompt.PromptDefault in NewAgent.
var agentPromptMap = map[string]prompt.PromptID{
	"coder": prompt.PromptCoder,
	"task":  prompt.PromptTask,
}
  89
  90func NewAgent(
  91	ctx context.Context,
  92	agentCfg config.Agent,
  93	// These services are needed in the tools
  94	permissions permission.Service,
  95	sessions session.Service,
  96	messages message.Service,
  97	history history.Service,
  98	lspClients *csync.Map[string, *lsp.Client],
  99) (Service, error) {
 100	cfg := config.Get()
 101
 102	var agentToolFn func() (tools.BaseTool, error)
 103	if agentCfg.ID == "coder" && slices.Contains(agentCfg.AllowedTools, AgentToolName) {
 104		agentToolFn = func() (tools.BaseTool, error) {
 105			taskAgentCfg := config.Get().Agents["task"]
 106			if taskAgentCfg.ID == "" {
 107				return nil, fmt.Errorf("task agent not found in config")
 108			}
 109			taskAgent, err := NewAgent(ctx, taskAgentCfg, permissions, sessions, messages, history, lspClients)
 110			if err != nil {
 111				return nil, fmt.Errorf("failed to create task agent: %w", err)
 112			}
 113			return NewAgentTool(taskAgent, sessions, messages), nil
 114		}
 115	}
 116
 117	providerCfg := config.Get().GetProviderForModel(agentCfg.Model)
 118	if providerCfg == nil {
 119		return nil, fmt.Errorf("provider for agent %s not found in config", agentCfg.Name)
 120	}
 121	model := config.Get().GetModelByType(agentCfg.Model)
 122
 123	if model == nil {
 124		return nil, fmt.Errorf("model not found for agent %s", agentCfg.Name)
 125	}
 126
 127	promptID := agentPromptMap[agentCfg.ID]
 128	if promptID == "" {
 129		promptID = prompt.PromptDefault
 130	}
 131	opts := []provider.ProviderClientOption{
 132		provider.WithModel(agentCfg.Model),
 133		provider.WithSystemMessage(prompt.GetPrompt(promptID, providerCfg.ID, config.Get().Options.ContextPaths...)),
 134	}
 135	agentProvider, err := provider.NewProvider(*providerCfg, opts...)
 136	if err != nil {
 137		return nil, err
 138	}
 139
 140	smallModelCfg := cfg.Models[config.SelectedModelTypeSmall]
 141	var smallModelProviderCfg *config.ProviderConfig
 142	if smallModelCfg.Provider == providerCfg.ID {
 143		smallModelProviderCfg = providerCfg
 144	} else {
 145		smallModelProviderCfg = cfg.GetProviderForModel(config.SelectedModelTypeSmall)
 146
 147		if smallModelProviderCfg.ID == "" {
 148			return nil, fmt.Errorf("provider %s not found in config", smallModelCfg.Provider)
 149		}
 150	}
 151	smallModel := cfg.GetModelByType(config.SelectedModelTypeSmall)
 152	if smallModel.ID == "" {
 153		return nil, fmt.Errorf("model %s not found in provider %s", smallModelCfg.Model, smallModelProviderCfg.ID)
 154	}
 155
 156	titleOpts := []provider.ProviderClientOption{
 157		provider.WithModel(config.SelectedModelTypeSmall),
 158		provider.WithSystemMessage(prompt.GetPrompt(prompt.PromptTitle, smallModelProviderCfg.ID)),
 159	}
 160	titleProvider, err := provider.NewProvider(*smallModelProviderCfg, titleOpts...)
 161	if err != nil {
 162		return nil, err
 163	}
 164
 165	summarizeOpts := []provider.ProviderClientOption{
 166		provider.WithModel(config.SelectedModelTypeLarge),
 167		provider.WithSystemMessage(prompt.GetPrompt(prompt.PromptSummarizer, providerCfg.ID)),
 168	}
 169	summarizeProvider, err := provider.NewProvider(*providerCfg, summarizeOpts...)
 170	if err != nil {
 171		return nil, err
 172	}
 173
 174	toolFn := func() []tools.BaseTool {
 175		slog.Info("Initializing agent tools", "agent", agentCfg.ID)
 176		defer func() {
 177			slog.Info("Initialized agent tools", "agent", agentCfg.ID)
 178		}()
 179
 180		cwd := cfg.WorkingDir()
 181		allTools := []tools.BaseTool{
 182			tools.NewBashTool(permissions, cwd, cfg.Options.Attribution),
 183			tools.NewDownloadTool(permissions, cwd),
 184			tools.NewEditTool(lspClients, permissions, history, cwd),
 185			tools.NewMultiEditTool(lspClients, permissions, history, cwd),
 186			tools.NewFetchTool(permissions, cwd),
 187			tools.NewGlobTool(cwd),
 188			tools.NewGrepTool(cwd),
 189			tools.NewLsTool(permissions, cwd),
 190			tools.NewSourcegraphTool(),
 191			tools.NewViewTool(lspClients, permissions, cwd),
 192			tools.NewWriteTool(lspClients, permissions, history, cwd),
 193		}
 194
 195		mcpToolsOnce.Do(func() {
 196			mcpTools = doGetMCPTools(ctx, permissions, cfg)
 197		})
 198
 199		withCoderTools := func(t []tools.BaseTool) []tools.BaseTool {
 200			if agentCfg.ID == "coder" {
 201				t = append(t, mcpTools...)
 202				if lspClients.Len() > 0 {
 203					t = append(t, tools.NewDiagnosticsTool(lspClients))
 204				}
 205			}
 206			return t
 207		}
 208
 209		if agentCfg.AllowedTools == nil {
 210			return withCoderTools(allTools)
 211		}
 212
 213		var filteredTools []tools.BaseTool
 214		for _, tool := range allTools {
 215			if slices.Contains(agentCfg.AllowedTools, tool.Name()) {
 216				filteredTools = append(filteredTools, tool)
 217			}
 218		}
 219		return withCoderTools(filteredTools)
 220	}
 221
 222	return &agent{
 223		Broker:              pubsub.NewBroker[AgentEvent](),
 224		agentCfg:            agentCfg,
 225		provider:            agentProvider,
 226		providerID:          string(providerCfg.ID),
 227		messages:            messages,
 228		sessions:            sessions,
 229		titleProvider:       titleProvider,
 230		summarizeProvider:   summarizeProvider,
 231		summarizeProviderID: string(providerCfg.ID),
 232		agentToolFn:         agentToolFn,
 233		activeRequests:      csync.NewMap[string, context.CancelFunc](),
 234		tools:               csync.NewLazySlice(toolFn),
 235		promptQueue:         csync.NewMap[string, []string](),
 236		permissions:         permissions,
 237	}, nil
 238}
 239
 240func (a *agent) Model() catwalk.Model {
 241	return *config.Get().GetModelByType(a.agentCfg.Model)
 242}
 243
 244func (a *agent) Cancel(sessionID string) {
 245	// Cancel regular requests
 246	if cancel, ok := a.activeRequests.Take(sessionID); ok && cancel != nil {
 247		slog.Info("Request cancellation initiated", "session_id", sessionID)
 248		cancel()
 249	}
 250
 251	// Also check for summarize requests
 252	if cancel, ok := a.activeRequests.Take(sessionID + "-summarize"); ok && cancel != nil {
 253		slog.Info("Summarize cancellation initiated", "session_id", sessionID)
 254		cancel()
 255	}
 256
 257	if a.QueuedPrompts(sessionID) > 0 {
 258		slog.Info("Clearing queued prompts", "session_id", sessionID)
 259		a.promptQueue.Del(sessionID)
 260	}
 261}
 262
 263func (a *agent) IsBusy() bool {
 264	var busy bool
 265	for cancelFunc := range a.activeRequests.Seq() {
 266		if cancelFunc != nil {
 267			busy = true
 268			break
 269		}
 270	}
 271	return busy
 272}
 273
 274func (a *agent) IsSessionBusy(sessionID string) bool {
 275	_, busy := a.activeRequests.Get(sessionID)
 276	return busy
 277}
 278
 279func (a *agent) QueuedPrompts(sessionID string) int {
 280	l, ok := a.promptQueue.Get(sessionID)
 281	if !ok {
 282		return 0
 283	}
 284	return len(l)
 285}
 286
 287func (a *agent) generateTitle(ctx context.Context, sessionID string, content string) error {
 288	if content == "" {
 289		return nil
 290	}
 291	if a.titleProvider == nil {
 292		return nil
 293	}
 294	session, err := a.sessions.Get(ctx, sessionID)
 295	if err != nil {
 296		return err
 297	}
 298	parts := []message.ContentPart{message.TextContent{
 299		Text: fmt.Sprintf("Generate a concise title for the following content:\n\n%s", content),
 300	}}
 301
 302	// Use streaming approach like summarization
 303	response := a.titleProvider.StreamResponse(
 304		ctx,
 305		[]message.Message{
 306			{
 307				Role:  message.User,
 308				Parts: parts,
 309			},
 310		},
 311		nil,
 312	)
 313
 314	var finalResponse *provider.ProviderResponse
 315	for r := range response {
 316		if r.Error != nil {
 317			return r.Error
 318		}
 319		finalResponse = r.Response
 320	}
 321
 322	if finalResponse == nil {
 323		return fmt.Errorf("no response received from title provider")
 324	}
 325
 326	title := strings.ReplaceAll(finalResponse.Content, "\n", " ")
 327
 328	if idx := strings.Index(title, "</think>"); idx > 0 {
 329		title = title[idx+len("</think>"):]
 330	}
 331
 332	title = strings.TrimSpace(title)
 333	if title == "" {
 334		return nil
 335	}
 336
 337	session.Title = title
 338	_, err = a.sessions.Save(ctx, session)
 339	return err
 340}
 341
 342func (a *agent) err(err error) AgentEvent {
 343	return AgentEvent{
 344		Type:  AgentEventTypeError,
 345		Error: err,
 346	}
 347}
 348
 349func (a *agent) Run(ctx context.Context, sessionID string, content string, attachments ...message.Attachment) (<-chan AgentEvent, error) {
 350	if !a.Model().SupportsImages && attachments != nil {
 351		attachments = nil
 352	}
 353	events := make(chan AgentEvent, 1)
 354	if a.IsSessionBusy(sessionID) {
 355		existing, ok := a.promptQueue.Get(sessionID)
 356		if !ok {
 357			existing = []string{}
 358		}
 359		existing = append(existing, content)
 360		a.promptQueue.Set(sessionID, existing)
 361		return nil, nil
 362	}
 363
 364	genCtx, cancel := context.WithCancel(ctx)
 365	a.activeRequests.Set(sessionID, cancel)
 366	startTime := time.Now()
 367
 368	go func() {
 369		slog.Debug("Request started", "sessionID", sessionID)
 370		defer log.RecoverPanic("agent.Run", func() {
 371			events <- a.err(fmt.Errorf("panic while running the agent"))
 372		})
 373		var attachmentParts []message.ContentPart
 374		for _, attachment := range attachments {
 375			attachmentParts = append(attachmentParts, message.BinaryContent{Path: attachment.FilePath, MIMEType: attachment.MimeType, Data: attachment.Content})
 376		}
 377		result := a.processGeneration(genCtx, sessionID, content, attachmentParts)
 378		if result.Error != nil {
 379			if isCancelledErr(result.Error) {
 380				slog.Error("Request canceled", "sessionID", sessionID)
 381			} else {
 382				slog.Error("Request errored", "sessionID", sessionID, "error", result.Error.Error())
 383				event.Error(result.Error)
 384			}
 385		} else {
 386			slog.Debug("Request completed", "sessionID", sessionID)
 387		}
 388		a.eventPromptResponded(sessionID, time.Since(startTime).Truncate(time.Second))
 389		a.activeRequests.Del(sessionID)
 390		cancel()
 391		a.Publish(pubsub.CreatedEvent, result)
 392		events <- result
 393		close(events)
 394	}()
 395	a.eventPromptSent(sessionID)
 396	return events, nil
 397}
 398
 399func (a *agent) processGeneration(ctx context.Context, sessionID, content string, attachmentParts []message.ContentPart) AgentEvent {
 400	cfg := config.Get()
 401	// List existing messages; if none, start title generation asynchronously.
 402	msgs, err := a.messages.List(ctx, sessionID)
 403	if err != nil {
 404		return a.err(fmt.Errorf("failed to list messages: %w", err))
 405	}
 406	if len(msgs) == 0 {
 407		go func() {
 408			defer log.RecoverPanic("agent.Run", func() {
 409				slog.Error("panic while generating title")
 410			})
 411			titleErr := a.generateTitle(ctx, sessionID, content)
 412			if titleErr != nil && !errors.Is(titleErr, context.Canceled) && !errors.Is(titleErr, context.DeadlineExceeded) {
 413				slog.Error("failed to generate title", "error", titleErr)
 414			}
 415		}()
 416	}
 417	session, err := a.sessions.Get(ctx, sessionID)
 418	if err != nil {
 419		return a.err(fmt.Errorf("failed to get session: %w", err))
 420	}
 421	if session.SummaryMessageID != "" {
 422		summaryMsgInex := -1
 423		for i, msg := range msgs {
 424			if msg.ID == session.SummaryMessageID {
 425				summaryMsgInex = i
 426				break
 427			}
 428		}
 429		if summaryMsgInex != -1 {
 430			msgs = msgs[summaryMsgInex:]
 431			msgs[0].Role = message.User
 432		}
 433	}
 434
 435	userMsg, err := a.createUserMessage(ctx, sessionID, content, attachmentParts)
 436	if err != nil {
 437		return a.err(fmt.Errorf("failed to create user message: %w", err))
 438	}
 439	// Append the new user message to the conversation history.
 440	msgHistory := append(msgs, userMsg)
 441
 442	for {
 443		// Check for cancellation before each iteration
 444		select {
 445		case <-ctx.Done():
 446			return a.err(ctx.Err())
 447		default:
 448			// Continue processing
 449		}
 450		agentMessage, toolResults, err := a.streamAndHandleEvents(ctx, sessionID, msgHistory)
 451		if err != nil {
 452			if errors.Is(err, context.Canceled) {
 453				agentMessage.AddFinish(message.FinishReasonCanceled, "Request cancelled", "")
 454				a.messages.Update(context.Background(), agentMessage)
 455				return a.err(ErrRequestCancelled)
 456			}
 457			return a.err(fmt.Errorf("failed to process events: %w", err))
 458		}
 459		if cfg.Options.Debug {
 460			slog.Info("Result", "message", agentMessage.FinishReason(), "toolResults", toolResults)
 461		}
 462		if (agentMessage.FinishReason() == message.FinishReasonToolUse) && toolResults != nil {
 463			// We are not done, we need to respond with the tool response
 464			msgHistory = append(msgHistory, agentMessage, *toolResults)
 465			// If there are queued prompts, process the next one
 466			nextPrompt, ok := a.promptQueue.Take(sessionID)
 467			if ok {
 468				for _, prompt := range nextPrompt {
 469					// Create a new user message for the queued prompt
 470					userMsg, err := a.createUserMessage(ctx, sessionID, prompt, nil)
 471					if err != nil {
 472						return a.err(fmt.Errorf("failed to create user message for queued prompt: %w", err))
 473					}
 474					// Append the new user message to the conversation history
 475					msgHistory = append(msgHistory, userMsg)
 476				}
 477			}
 478
 479			continue
 480		} else if agentMessage.FinishReason() == message.FinishReasonEndTurn {
 481			queuePrompts, ok := a.promptQueue.Take(sessionID)
 482			if ok {
 483				for _, prompt := range queuePrompts {
 484					if prompt == "" {
 485						continue
 486					}
 487					userMsg, err := a.createUserMessage(ctx, sessionID, prompt, nil)
 488					if err != nil {
 489						return a.err(fmt.Errorf("failed to create user message for queued prompt: %w", err))
 490					}
 491					msgHistory = append(msgHistory, userMsg)
 492				}
 493				continue
 494			}
 495		}
 496		if agentMessage.FinishReason() == "" {
 497			// Kujtim: could not track down where this is happening but this means its cancelled
 498			agentMessage.AddFinish(message.FinishReasonCanceled, "Request cancelled", "")
 499			_ = a.messages.Update(context.Background(), agentMessage)
 500			return a.err(ErrRequestCancelled)
 501		}
 502		return AgentEvent{
 503			Type:    AgentEventTypeResponse,
 504			Message: agentMessage,
 505			Done:    true,
 506		}
 507	}
 508}
 509
 510func (a *agent) createUserMessage(ctx context.Context, sessionID, content string, attachmentParts []message.ContentPart) (message.Message, error) {
 511	parts := []message.ContentPart{message.TextContent{Text: content}}
 512	parts = append(parts, attachmentParts...)
 513	return a.messages.Create(ctx, sessionID, message.CreateMessageParams{
 514		Role:  message.User,
 515		Parts: parts,
 516	})
 517}
 518
 519func (a *agent) getAllTools() ([]tools.BaseTool, error) {
 520	allTools := slices.Collect(a.tools.Seq())
 521	if a.agentToolFn != nil {
 522		agentTool, agentToolErr := a.agentToolFn()
 523		if agentToolErr != nil {
 524			return nil, agentToolErr
 525		}
 526		allTools = append(allTools, agentTool)
 527	}
 528	return allTools, nil
 529}
 530
 531func (a *agent) streamAndHandleEvents(ctx context.Context, sessionID string, msgHistory []message.Message) (message.Message, *message.Message, error) {
 532	ctx = context.WithValue(ctx, tools.SessionIDContextKey, sessionID)
 533
 534	// Create the assistant message first so the spinner shows immediately
 535	assistantMsg, err := a.messages.Create(ctx, sessionID, message.CreateMessageParams{
 536		Role:     message.Assistant,
 537		Parts:    []message.ContentPart{},
 538		Model:    a.Model().ID,
 539		Provider: a.providerID,
 540	})
 541	if err != nil {
 542		return assistantMsg, nil, fmt.Errorf("failed to create assistant message: %w", err)
 543	}
 544
 545	allTools, toolsErr := a.getAllTools()
 546	if toolsErr != nil {
 547		return assistantMsg, nil, toolsErr
 548	}
 549	// Now collect tools (which may block on MCP initialization)
 550	eventChan := a.provider.StreamResponse(ctx, msgHistory, allTools)
 551
 552	// Add the session and message ID into the context if needed by tools.
 553	ctx = context.WithValue(ctx, tools.MessageIDContextKey, assistantMsg.ID)
 554
 555	// Process each event in the stream.
 556	for event := range eventChan {
 557		if processErr := a.processEvent(ctx, sessionID, &assistantMsg, event); processErr != nil {
 558			if errors.Is(processErr, context.Canceled) {
 559				a.finishMessage(context.Background(), &assistantMsg, message.FinishReasonCanceled, "Request cancelled", "")
 560			} else {
 561				a.finishMessage(ctx, &assistantMsg, message.FinishReasonError, "API Error", processErr.Error())
 562			}
 563			return assistantMsg, nil, processErr
 564		}
 565		if ctx.Err() != nil {
 566			a.finishMessage(context.Background(), &assistantMsg, message.FinishReasonCanceled, "Request cancelled", "")
 567			return assistantMsg, nil, ctx.Err()
 568		}
 569	}
 570
 571	toolResults := make([]message.ToolResult, len(assistantMsg.ToolCalls()))
 572	toolCalls := assistantMsg.ToolCalls()
 573	for i, toolCall := range toolCalls {
 574		select {
 575		case <-ctx.Done():
 576			a.finishMessage(context.Background(), &assistantMsg, message.FinishReasonCanceled, "Request cancelled", "")
 577			// Make all future tool calls cancelled
 578			for j := i; j < len(toolCalls); j++ {
 579				toolResults[j] = message.ToolResult{
 580					ToolCallID: toolCalls[j].ID,
 581					Content:    "Tool execution canceled by user",
 582					IsError:    true,
 583				}
 584			}
 585			goto out
 586		default:
 587			// Continue processing
 588			var tool tools.BaseTool
 589			allTools, _ := a.getAllTools()
 590			for _, availableTool := range allTools {
 591				if availableTool.Info().Name == toolCall.Name {
 592					tool = availableTool
 593					break
 594				}
 595			}
 596
 597			// Tool not found
 598			if tool == nil {
 599				toolResults[i] = message.ToolResult{
 600					ToolCallID: toolCall.ID,
 601					Content:    fmt.Sprintf("Tool not found: %s", toolCall.Name),
 602					IsError:    true,
 603				}
 604				continue
 605			}
 606
 607			// Run tool in goroutine to allow cancellation
 608			type toolExecResult struct {
 609				response tools.ToolResponse
 610				err      error
 611			}
 612			resultChan := make(chan toolExecResult, 1)
 613
 614			go func() {
 615				response, err := tool.Run(ctx, tools.ToolCall{
 616					ID:    toolCall.ID,
 617					Name:  toolCall.Name,
 618					Input: toolCall.Input,
 619				})
 620				resultChan <- toolExecResult{response: response, err: err}
 621			}()
 622
 623			var toolResponse tools.ToolResponse
 624			var toolErr error
 625
 626			select {
 627			case <-ctx.Done():
 628				a.finishMessage(context.Background(), &assistantMsg, message.FinishReasonCanceled, "Request cancelled", "")
 629				// Mark remaining tool calls as cancelled
 630				for j := i; j < len(toolCalls); j++ {
 631					toolResults[j] = message.ToolResult{
 632						ToolCallID: toolCalls[j].ID,
 633						Content:    "Tool execution canceled by user",
 634						IsError:    true,
 635					}
 636				}
 637				goto out
 638			case result := <-resultChan:
 639				toolResponse = result.response
 640				toolErr = result.err
 641			}
 642
 643			if toolErr != nil {
 644				slog.Error("Tool execution error", "toolCall", toolCall.ID, "error", toolErr)
 645				if errors.Is(toolErr, permission.ErrorPermissionDenied) {
 646					toolResults[i] = message.ToolResult{
 647						ToolCallID: toolCall.ID,
 648						Content:    "Permission denied",
 649						IsError:    true,
 650					}
 651					for j := i + 1; j < len(toolCalls); j++ {
 652						toolResults[j] = message.ToolResult{
 653							ToolCallID: toolCalls[j].ID,
 654							Content:    "Tool execution canceled by user",
 655							IsError:    true,
 656						}
 657					}
 658					a.finishMessage(ctx, &assistantMsg, message.FinishReasonPermissionDenied, "Permission denied", "")
 659					break
 660				}
 661			}
 662			toolResults[i] = message.ToolResult{
 663				ToolCallID: toolCall.ID,
 664				Content:    toolResponse.Content,
 665				Metadata:   toolResponse.Metadata,
 666				IsError:    toolResponse.IsError,
 667			}
 668		}
 669	}
 670out:
 671	if len(toolResults) == 0 {
 672		return assistantMsg, nil, nil
 673	}
 674	parts := make([]message.ContentPart, 0)
 675	for _, tr := range toolResults {
 676		parts = append(parts, tr)
 677	}
 678	msg, err := a.messages.Create(context.Background(), assistantMsg.SessionID, message.CreateMessageParams{
 679		Role:     message.Tool,
 680		Parts:    parts,
 681		Provider: a.providerID,
 682	})
 683	if err != nil {
 684		return assistantMsg, nil, fmt.Errorf("failed to create cancelled tool message: %w", err)
 685	}
 686
 687	return assistantMsg, &msg, err
 688}
 689
 690func (a *agent) finishMessage(ctx context.Context, msg *message.Message, finishReason message.FinishReason, message, details string) {
 691	msg.AddFinish(finishReason, message, details)
 692	_ = a.messages.Update(ctx, *msg)
 693}
 694
 695func (a *agent) processEvent(ctx context.Context, sessionID string, assistantMsg *message.Message, event provider.ProviderEvent) error {
 696	select {
 697	case <-ctx.Done():
 698		return ctx.Err()
 699	default:
 700		// Continue processing.
 701	}
 702
 703	switch event.Type {
 704	case provider.EventThinkingDelta:
 705		assistantMsg.AppendReasoningContent(event.Thinking)
 706		return a.messages.Update(ctx, *assistantMsg)
 707	case provider.EventSignatureDelta:
 708		assistantMsg.AppendReasoningSignature(event.Signature)
 709		return a.messages.Update(ctx, *assistantMsg)
 710	case provider.EventContentDelta:
 711		assistantMsg.FinishThinking()
 712		assistantMsg.AppendContent(event.Content)
 713		return a.messages.Update(ctx, *assistantMsg)
 714	case provider.EventToolUseStart:
 715		assistantMsg.FinishThinking()
 716		slog.Info("Tool call started", "toolCall", event.ToolCall)
 717		assistantMsg.AddToolCall(*event.ToolCall)
 718		return a.messages.Update(ctx, *assistantMsg)
 719	case provider.EventToolUseDelta:
 720		assistantMsg.AppendToolCallInput(event.ToolCall.ID, event.ToolCall.Input)
 721		return a.messages.Update(ctx, *assistantMsg)
 722	case provider.EventToolUseStop:
 723		slog.Info("Finished tool call", "toolCall", event.ToolCall)
 724		assistantMsg.FinishToolCall(event.ToolCall.ID)
 725		return a.messages.Update(ctx, *assistantMsg)
 726	case provider.EventError:
 727		return event.Error
 728	case provider.EventComplete:
 729		assistantMsg.FinishThinking()
 730		assistantMsg.SetToolCalls(event.Response.ToolCalls)
 731		assistantMsg.AddFinish(event.Response.FinishReason, "", "")
 732		if err := a.messages.Update(ctx, *assistantMsg); err != nil {
 733			return fmt.Errorf("failed to update message: %w", err)
 734		}
 735		return a.trackUsage(ctx, sessionID, a.Model(), event.Response.Usage)
 736	}
 737
 738	return nil
 739}
 740
 741func (a *agent) trackUsage(ctx context.Context, sessionID string, model catwalk.Model, usage provider.TokenUsage) error {
 742	sess, err := a.sessions.Get(ctx, sessionID)
 743	if err != nil {
 744		return fmt.Errorf("failed to get session: %w", err)
 745	}
 746
 747	cost := model.CostPer1MInCached/1e6*float64(usage.CacheCreationTokens) +
 748		model.CostPer1MOutCached/1e6*float64(usage.CacheReadTokens) +
 749		model.CostPer1MIn/1e6*float64(usage.InputTokens) +
 750		model.CostPer1MOut/1e6*float64(usage.OutputTokens)
 751
 752	a.eventTokensUsed(sessionID, usage, cost)
 753
 754	sess.Cost += cost
 755	sess.CompletionTokens = usage.OutputTokens + usage.CacheReadTokens
 756	sess.PromptTokens = usage.InputTokens + usage.CacheCreationTokens
 757
 758	_, err = a.sessions.Save(ctx, sess)
 759	if err != nil {
 760		return fmt.Errorf("failed to save session: %w", err)
 761	}
 762	return nil
 763}
 764
 765func (a *agent) Summarize(ctx context.Context, sessionID string) error {
 766	if a.summarizeProvider == nil {
 767		return fmt.Errorf("summarize provider not available")
 768	}
 769
 770	// Check if session is busy
 771	if a.IsSessionBusy(sessionID) {
 772		return ErrSessionBusy
 773	}
 774
 775	// Create a new context with cancellation
 776	summarizeCtx, cancel := context.WithCancel(ctx)
 777
 778	// Store the cancel function in activeRequests to allow cancellation
 779	a.activeRequests.Set(sessionID+"-summarize", cancel)
 780
 781	go func() {
 782		defer a.activeRequests.Del(sessionID + "-summarize")
 783		defer cancel()
 784		event := AgentEvent{
 785			Type:     AgentEventTypeSummarize,
 786			Progress: "Starting summarization...",
 787		}
 788
 789		a.Publish(pubsub.CreatedEvent, event)
 790		// Get all messages from the session
 791		msgs, err := a.messages.List(summarizeCtx, sessionID)
 792		if err != nil {
 793			event = AgentEvent{
 794				Type:  AgentEventTypeError,
 795				Error: fmt.Errorf("failed to list messages: %w", err),
 796				Done:  true,
 797			}
 798			a.Publish(pubsub.CreatedEvent, event)
 799			return
 800		}
 801		summarizeCtx = context.WithValue(summarizeCtx, tools.SessionIDContextKey, sessionID)
 802
 803		if len(msgs) == 0 {
 804			event = AgentEvent{
 805				Type:  AgentEventTypeError,
 806				Error: fmt.Errorf("no messages to summarize"),
 807				Done:  true,
 808			}
 809			a.Publish(pubsub.CreatedEvent, event)
 810			return
 811		}
 812
 813		event = AgentEvent{
 814			Type:     AgentEventTypeSummarize,
 815			Progress: "Analyzing conversation...",
 816		}
 817		a.Publish(pubsub.CreatedEvent, event)
 818
 819		// Add a system message to guide the summarization
 820		summarizePrompt := "Provide a detailed but concise summary of our conversation above. Focus on information that would be helpful for continuing the conversation, including what we did, what we're doing, which files we're working on, and what we're going to do next."
 821
 822		// Create a new message with the summarize prompt
 823		promptMsg := message.Message{
 824			Role:  message.User,
 825			Parts: []message.ContentPart{message.TextContent{Text: summarizePrompt}},
 826		}
 827
 828		// Append the prompt to the messages
 829		msgsWithPrompt := append(msgs, promptMsg)
 830
 831		event = AgentEvent{
 832			Type:     AgentEventTypeSummarize,
 833			Progress: "Generating summary...",
 834		}
 835
 836		a.Publish(pubsub.CreatedEvent, event)
 837
 838		// Send the messages to the summarize provider
 839		response := a.summarizeProvider.StreamResponse(
 840			summarizeCtx,
 841			msgsWithPrompt,
 842			nil,
 843		)
 844		var finalResponse *provider.ProviderResponse
 845		for r := range response {
 846			if r.Error != nil {
 847				event = AgentEvent{
 848					Type:  AgentEventTypeError,
 849					Error: fmt.Errorf("failed to summarize: %w", r.Error),
 850					Done:  true,
 851				}
 852				a.Publish(pubsub.CreatedEvent, event)
 853				return
 854			}
 855			finalResponse = r.Response
 856		}
 857
 858		summary := strings.TrimSpace(finalResponse.Content)
 859		if summary == "" {
 860			event = AgentEvent{
 861				Type:  AgentEventTypeError,
 862				Error: fmt.Errorf("empty summary returned"),
 863				Done:  true,
 864			}
 865			a.Publish(pubsub.CreatedEvent, event)
 866			return
 867		}
 868		shell := shell.GetPersistentShell(config.Get().WorkingDir())
 869		summary += "\n\n**Current working directory of the persistent shell**\n\n" + shell.GetWorkingDir()
 870		event = AgentEvent{
 871			Type:     AgentEventTypeSummarize,
 872			Progress: "Creating new session...",
 873		}
 874
 875		a.Publish(pubsub.CreatedEvent, event)
 876		oldSession, err := a.sessions.Get(summarizeCtx, sessionID)
 877		if err != nil {
 878			event = AgentEvent{
 879				Type:  AgentEventTypeError,
 880				Error: fmt.Errorf("failed to get session: %w", err),
 881				Done:  true,
 882			}
 883
 884			a.Publish(pubsub.CreatedEvent, event)
 885			return
 886		}
 887		// Create a message in the new session with the summary
 888		msg, err := a.messages.Create(summarizeCtx, oldSession.ID, message.CreateMessageParams{
 889			Role: message.Assistant,
 890			Parts: []message.ContentPart{
 891				message.TextContent{Text: summary},
 892				message.Finish{
 893					Reason: message.FinishReasonEndTurn,
 894					Time:   time.Now().Unix(),
 895				},
 896			},
 897			Model:    a.summarizeProvider.Model().ID,
 898			Provider: a.summarizeProviderID,
 899		})
 900		if err != nil {
 901			event = AgentEvent{
 902				Type:  AgentEventTypeError,
 903				Error: fmt.Errorf("failed to create summary message: %w", err),
 904				Done:  true,
 905			}
 906
 907			a.Publish(pubsub.CreatedEvent, event)
 908			return
 909		}
 910		oldSession.SummaryMessageID = msg.ID
 911		oldSession.CompletionTokens = finalResponse.Usage.OutputTokens
 912		oldSession.PromptTokens = 0
 913		model := a.summarizeProvider.Model()
 914		usage := finalResponse.Usage
 915		cost := model.CostPer1MInCached/1e6*float64(usage.CacheCreationTokens) +
 916			model.CostPer1MOutCached/1e6*float64(usage.CacheReadTokens) +
 917			model.CostPer1MIn/1e6*float64(usage.InputTokens) +
 918			model.CostPer1MOut/1e6*float64(usage.OutputTokens)
 919		oldSession.Cost += cost
 920		_, err = a.sessions.Save(summarizeCtx, oldSession)
 921		if err != nil {
 922			event = AgentEvent{
 923				Type:  AgentEventTypeError,
 924				Error: fmt.Errorf("failed to save session: %w", err),
 925				Done:  true,
 926			}
 927			a.Publish(pubsub.CreatedEvent, event)
 928		}
 929
 930		event = AgentEvent{
 931			Type:      AgentEventTypeSummarize,
 932			SessionID: oldSession.ID,
 933			Progress:  "Summary complete",
 934			Done:      true,
 935		}
 936		a.Publish(pubsub.CreatedEvent, event)
 937		// Send final success event with the new session ID
 938	}()
 939
 940	return nil
 941}
 942
 943func (a *agent) ClearQueue(sessionID string) {
 944	if a.QueuedPrompts(sessionID) > 0 {
 945		slog.Info("Clearing queued prompts", "session_id", sessionID)
 946		a.promptQueue.Del(sessionID)
 947	}
 948}
 949
 950func (a *agent) CancelAll() {
 951	if !a.IsBusy() {
 952		return
 953	}
 954	for key := range a.activeRequests.Seq2() {
 955		a.Cancel(key) // key is sessionID
 956	}
 957
 958	timeout := time.After(5 * time.Second)
 959	for a.IsBusy() {
 960		select {
 961		case <-timeout:
 962			return
 963		default:
 964			time.Sleep(200 * time.Millisecond)
 965		}
 966	}
 967}
 968
 969func (a *agent) UpdateModel() error {
 970	cfg := config.Get()
 971
 972	// Get current provider configuration
 973	currentProviderCfg := cfg.GetProviderForModel(a.agentCfg.Model)
 974	if currentProviderCfg == nil || currentProviderCfg.ID == "" {
 975		return fmt.Errorf("provider for agent %s not found in config", a.agentCfg.Name)
 976	}
 977
 978	// Check if provider has changed
 979	if string(currentProviderCfg.ID) != a.providerID {
 980		// Provider changed, need to recreate the main provider
 981		model := cfg.GetModelByType(a.agentCfg.Model)
 982		if model.ID == "" {
 983			return fmt.Errorf("model not found for agent %s", a.agentCfg.Name)
 984		}
 985
 986		promptID := agentPromptMap[a.agentCfg.ID]
 987		if promptID == "" {
 988			promptID = prompt.PromptDefault
 989		}
 990
 991		opts := []provider.ProviderClientOption{
 992			provider.WithModel(a.agentCfg.Model),
 993			provider.WithSystemMessage(prompt.GetPrompt(promptID, currentProviderCfg.ID, cfg.Options.ContextPaths...)),
 994		}
 995
 996		newProvider, err := provider.NewProvider(*currentProviderCfg, opts...)
 997		if err != nil {
 998			return fmt.Errorf("failed to create new provider: %w", err)
 999		}
1000
1001		// Update the provider and provider ID
1002		a.provider = newProvider
1003		a.providerID = string(currentProviderCfg.ID)
1004	}
1005
1006	// Check if providers have changed for title (small) and summarize (large)
1007	smallModelCfg := cfg.Models[config.SelectedModelTypeSmall]
1008	var smallModelProviderCfg config.ProviderConfig
1009	for p := range cfg.Providers.Seq() {
1010		if p.ID == smallModelCfg.Provider {
1011			smallModelProviderCfg = p
1012			break
1013		}
1014	}
1015	if smallModelProviderCfg.ID == "" {
1016		return fmt.Errorf("provider %s not found in config", smallModelCfg.Provider)
1017	}
1018
1019	largeModelCfg := cfg.Models[config.SelectedModelTypeLarge]
1020	var largeModelProviderCfg config.ProviderConfig
1021	for p := range cfg.Providers.Seq() {
1022		if p.ID == largeModelCfg.Provider {
1023			largeModelProviderCfg = p
1024			break
1025		}
1026	}
1027	if largeModelProviderCfg.ID == "" {
1028		return fmt.Errorf("provider %s not found in config", largeModelCfg.Provider)
1029	}
1030
1031	var maxTitleTokens int64 = 40
1032
1033	// if the max output is too low for the gemini provider it won't return anything
1034	if smallModelCfg.Provider == "gemini" {
1035		maxTitleTokens = 1000
1036	}
1037	// Recreate title provider
1038	titleOpts := []provider.ProviderClientOption{
1039		provider.WithModel(config.SelectedModelTypeSmall),
1040		provider.WithSystemMessage(prompt.GetPrompt(prompt.PromptTitle, smallModelProviderCfg.ID)),
1041		provider.WithMaxTokens(maxTitleTokens),
1042	}
1043	newTitleProvider, err := provider.NewProvider(smallModelProviderCfg, titleOpts...)
1044	if err != nil {
1045		return fmt.Errorf("failed to create new title provider: %w", err)
1046	}
1047	a.titleProvider = newTitleProvider
1048
1049	// Recreate summarize provider if provider changed (now large model)
1050	if string(largeModelProviderCfg.ID) != a.summarizeProviderID {
1051		largeModel := cfg.GetModelByType(config.SelectedModelTypeLarge)
1052		if largeModel == nil {
1053			return fmt.Errorf("model %s not found in provider %s", largeModelCfg.Model, largeModelProviderCfg.ID)
1054		}
1055		summarizeOpts := []provider.ProviderClientOption{
1056			provider.WithModel(config.SelectedModelTypeLarge),
1057			provider.WithSystemMessage(prompt.GetPrompt(prompt.PromptSummarizer, largeModelProviderCfg.ID)),
1058		}
1059		newSummarizeProvider, err := provider.NewProvider(largeModelProviderCfg, summarizeOpts...)
1060		if err != nil {
1061			return fmt.Errorf("failed to create new summarize provider: %w", err)
1062		}
1063		a.summarizeProvider = newSummarizeProvider
1064		a.summarizeProviderID = string(largeModelProviderCfg.ID)
1065	}
1066
1067	return nil
1068}