agent.go

   1package agent
   2
   3import (
   4	"context"
   5	"errors"
   6	"fmt"
   7	"log/slog"
   8	"slices"
   9	"strings"
  10	"time"
  11
  12	"github.com/charmbracelet/catwalk/pkg/catwalk"
  13	"github.com/charmbracelet/crush/internal/config"
  14	"github.com/charmbracelet/crush/internal/csync"
  15	"github.com/charmbracelet/crush/internal/history"
  16	"github.com/charmbracelet/crush/internal/llm/prompt"
  17	"github.com/charmbracelet/crush/internal/llm/provider"
  18	"github.com/charmbracelet/crush/internal/llm/tools"
  19	"github.com/charmbracelet/crush/internal/log"
  20	"github.com/charmbracelet/crush/internal/lsp"
  21	"github.com/charmbracelet/crush/internal/message"
  22	"github.com/charmbracelet/crush/internal/permission"
  23	"github.com/charmbracelet/crush/internal/pubsub"
  24	"github.com/charmbracelet/crush/internal/session"
  25	"github.com/charmbracelet/crush/internal/shell"
  26)
  27
// Common errors
//
// ErrRequestCancelled is the error carried by the final AgentEvent of a
// generation that stopped because its context was cancelled.
// ErrSessionBusy is returned by Summarize when the session already has an
// active request.
var (
	ErrRequestCancelled = errors.New("request canceled by user")
	ErrSessionBusy      = errors.New("session is currently processing another request")
)
  33
// AgentEventType discriminates the kinds of AgentEvent the agent emits.
type AgentEventType string

// Event types used in this file: "error" for failures, "response" for the
// final assistant message of a generation, and "summarize" for progress and
// completion updates published by Summarize.
const (
	AgentEventTypeError     AgentEventType = "error"
	AgentEventTypeResponse  AgentEventType = "response"
	AgentEventTypeSummarize AgentEventType = "summarize"
)
  41
// AgentEvent is the value published through the agent's broker and delivered
// on the channel returned by Run. Type selects which of the other fields are
// meaningful: Error for error events, Message for response events.
type AgentEvent struct {
	Type    AgentEventType
	Message message.Message
	Error   error

	// When summarizing: SessionID and Progress are filled in by Summarize.
	// Done marks a terminal event; it is also set on the final response
	// event of a generation.
	SessionID string
	Progress  string
	Done      bool
}
  52
// Service is the public surface of an agent: it runs prompts against a
// session, reports and cancels in-flight work, and publishes AgentEvents to
// subscribers.
type Service interface {
	pubsub.Suscriber[AgentEvent]
	// Model returns the model configured for this agent.
	Model() catwalk.Model
	// Run starts a generation for sessionID in the background and returns a
	// channel that receives its final event. When the session is already
	// busy the prompt is queued instead and both return values are nil.
	Run(ctx context.Context, sessionID string, content string, attachments ...message.Attachment) (<-chan AgentEvent, error)
	// Cancel stops the active generation and summarization for sessionID and
	// discards its queued prompts.
	Cancel(sessionID string)
	// CancelAll cancels every active request and waits a bounded time for
	// them to finish.
	CancelAll()
	// IsSessionBusy reports whether sessionID has an active generation.
	IsSessionBusy(sessionID string) bool
	// IsBusy reports whether any request is active.
	IsBusy() bool
	// Summarize starts a background summarization of sessionID; progress and
	// errors are delivered as published events. It returns ErrSessionBusy
	// when the session has an active generation.
	Summarize(ctx context.Context, sessionID string) error
	// UpdateModel rebuilds the agent's providers from the current config.
	UpdateModel() error
	// QueuedPrompts returns the number of prompts waiting for sessionID.
	QueuedPrompts(sessionID string) int
	// ClearQueue discards the prompts waiting for sessionID.
	ClearQueue(sessionID string)
}
  66
// agent is the Service implementation constructed by NewAgent. It embeds a
// broker so that events can be published to subscribers.
type agent struct {
	*pubsub.Broker[AgentEvent]
	agentCfg config.Agent
	sessions session.Service
	messages message.Service
	mcpTools []McpTool

	// tools is populated through csync.NewLazySlice with the tool-building
	// function defined in NewAgent.
	tools *csync.LazySlice[tools.BaseTool]

	// provider streams the main conversation; providerID is recorded on the
	// assistant and tool messages the agent creates.
	provider   provider.Provider
	providerID string

	// titleProvider generates session titles; summarizeProvider and
	// summarizeProviderID are used by Summarize.
	titleProvider       provider.Provider
	summarizeProvider   provider.Provider
	summarizeProviderID string

	// activeRequests maps a session ID (or session ID + "-summarize") to the
	// cancel function of its in-flight request.
	activeRequests *csync.Map[string, context.CancelFunc]

	// promptQueue holds prompts submitted through Run while their session
	// was busy, keyed by session ID.
	promptQueue *csync.Map[string, []string]
}
  87
// agentPromptMap maps an agent ID to the system prompt used for it. Agent IDs
// missing from the map fall back to prompt.PromptDefault at the lookup sites.
var agentPromptMap = map[string]prompt.PromptID{
	"coder": prompt.PromptCoder,
	"task":  prompt.PromptTask,
}
  92
  93func NewAgent(
  94	ctx context.Context,
  95	agentCfg config.Agent,
  96	// These services are needed in the tools
  97	permissions permission.Service,
  98	sessions session.Service,
  99	messages message.Service,
 100	history history.Service,
 101	lspClients map[string]*lsp.Client,
 102) (Service, error) {
 103	cfg := config.Get()
 104
 105	var agentTool tools.BaseTool
 106	if agentCfg.ID == "coder" {
 107		taskAgentCfg := config.Get().Agents["task"]
 108		if taskAgentCfg.ID == "" {
 109			return nil, fmt.Errorf("task agent not found in config")
 110		}
 111		taskAgent, err := NewAgent(ctx, taskAgentCfg, permissions, sessions, messages, history, lspClients)
 112		if err != nil {
 113			return nil, fmt.Errorf("failed to create task agent: %w", err)
 114		}
 115
 116		agentTool = NewAgentTool(taskAgent, sessions, messages)
 117	}
 118
 119	providerCfg := config.Get().GetProviderForModel(agentCfg.Model)
 120	if providerCfg == nil {
 121		return nil, fmt.Errorf("provider for agent %s not found in config", agentCfg.Name)
 122	}
 123	model := config.Get().GetModelByType(agentCfg.Model)
 124
 125	if model == nil {
 126		return nil, fmt.Errorf("model not found for agent %s", agentCfg.Name)
 127	}
 128
 129	promptID := agentPromptMap[agentCfg.ID]
 130	if promptID == "" {
 131		promptID = prompt.PromptDefault
 132	}
 133	opts := []provider.ProviderClientOption{
 134		provider.WithModel(agentCfg.Model),
 135		provider.WithSystemMessage(prompt.GetPrompt(promptID, providerCfg.ID, config.Get().Options.ContextPaths...)),
 136	}
 137	agentProvider, err := provider.NewProvider(*providerCfg, opts...)
 138	if err != nil {
 139		return nil, err
 140	}
 141
 142	smallModelCfg := cfg.Models[config.SelectedModelTypeSmall]
 143	var smallModelProviderCfg *config.ProviderConfig
 144	if smallModelCfg.Provider == providerCfg.ID {
 145		smallModelProviderCfg = providerCfg
 146	} else {
 147		smallModelProviderCfg = cfg.GetProviderForModel(config.SelectedModelTypeSmall)
 148
 149		if smallModelProviderCfg.ID == "" {
 150			return nil, fmt.Errorf("provider %s not found in config", smallModelCfg.Provider)
 151		}
 152	}
 153	smallModel := cfg.GetModelByType(config.SelectedModelTypeSmall)
 154	if smallModel.ID == "" {
 155		return nil, fmt.Errorf("model %s not found in provider %s", smallModelCfg.Model, smallModelProviderCfg.ID)
 156	}
 157
 158	titleOpts := []provider.ProviderClientOption{
 159		provider.WithModel(config.SelectedModelTypeSmall),
 160		provider.WithSystemMessage(prompt.GetPrompt(prompt.PromptTitle, smallModelProviderCfg.ID)),
 161	}
 162	titleProvider, err := provider.NewProvider(*smallModelProviderCfg, titleOpts...)
 163	if err != nil {
 164		return nil, err
 165	}
 166
 167	summarizeOpts := []provider.ProviderClientOption{
 168		provider.WithModel(config.SelectedModelTypeLarge),
 169		provider.WithSystemMessage(prompt.GetPrompt(prompt.PromptSummarizer, providerCfg.ID)),
 170	}
 171	summarizeProvider, err := provider.NewProvider(*providerCfg, summarizeOpts...)
 172	if err != nil {
 173		return nil, err
 174	}
 175
 176	toolFn := func() []tools.BaseTool {
 177		slog.Info("Initializing agent tools", "agent", agentCfg.ID)
 178		defer func() {
 179			slog.Info("Initialized agent tools", "agent", agentCfg.ID)
 180		}()
 181
 182		cwd := cfg.WorkingDir()
 183		allTools := []tools.BaseTool{
 184			tools.NewBashTool(permissions, cwd),
 185			tools.NewDownloadTool(permissions, cwd),
 186			tools.NewEditTool(lspClients, permissions, history, cwd),
 187			tools.NewMultiEditTool(lspClients, permissions, history, cwd),
 188			tools.NewFetchTool(permissions, cwd),
 189			tools.NewGlobTool(cwd),
 190			tools.NewGrepTool(cwd),
 191			tools.NewLsTool(permissions, cwd),
 192			tools.NewSourcegraphTool(),
 193			tools.NewViewTool(lspClients, permissions, cwd),
 194			tools.NewWriteTool(lspClients, permissions, history, cwd),
 195		}
 196
 197		mcpToolsOnce.Do(func() {
 198			mcpTools = doGetMCPTools(ctx, permissions, cfg)
 199		})
 200		allTools = append(allTools, mcpTools...)
 201
 202		if len(lspClients) > 0 {
 203			allTools = append(allTools, tools.NewDiagnosticsTool(lspClients))
 204		}
 205
 206		if agentTool != nil {
 207			allTools = append(allTools, agentTool)
 208		}
 209
 210		if agentCfg.AllowedTools == nil {
 211			return allTools
 212		}
 213
 214		var filteredTools []tools.BaseTool
 215		for _, tool := range allTools {
 216			if slices.Contains(agentCfg.AllowedTools, tool.Name()) {
 217				filteredTools = append(filteredTools, tool)
 218			}
 219		}
 220		return filteredTools
 221	}
 222
 223	return &agent{
 224		Broker:              pubsub.NewBroker[AgentEvent](),
 225		agentCfg:            agentCfg,
 226		provider:            agentProvider,
 227		providerID:          string(providerCfg.ID),
 228		messages:            messages,
 229		sessions:            sessions,
 230		titleProvider:       titleProvider,
 231		summarizeProvider:   summarizeProvider,
 232		summarizeProviderID: string(providerCfg.ID),
 233		activeRequests:      csync.NewMap[string, context.CancelFunc](),
 234		tools:               csync.NewLazySlice(toolFn),
 235		promptQueue:         csync.NewMap[string, []string](),
 236	}, nil
 237}
 238
 239func (a *agent) Model() catwalk.Model {
 240	return *config.Get().GetModelByType(a.agentCfg.Model)
 241}
 242
 243func (a *agent) Cancel(sessionID string) {
 244	// Cancel regular requests
 245	if cancel, ok := a.activeRequests.Take(sessionID); ok && cancel != nil {
 246		slog.Info("Request cancellation initiated", "session_id", sessionID)
 247		cancel()
 248	}
 249
 250	// Also check for summarize requests
 251	if cancel, ok := a.activeRequests.Take(sessionID + "-summarize"); ok && cancel != nil {
 252		slog.Info("Summarize cancellation initiated", "session_id", sessionID)
 253		cancel()
 254	}
 255
 256	if a.QueuedPrompts(sessionID) > 0 {
 257		slog.Info("Clearing queued prompts", "session_id", sessionID)
 258		a.promptQueue.Del(sessionID)
 259	}
 260}
 261
 262func (a *agent) IsBusy() bool {
 263	var busy bool
 264	for cancelFunc := range a.activeRequests.Seq() {
 265		if cancelFunc != nil {
 266			busy = true
 267			break
 268		}
 269	}
 270	return busy
 271}
 272
 273func (a *agent) IsSessionBusy(sessionID string) bool {
 274	_, busy := a.activeRequests.Get(sessionID)
 275	return busy
 276}
 277
 278func (a *agent) QueuedPrompts(sessionID string) int {
 279	l, ok := a.promptQueue.Get(sessionID)
 280	if !ok {
 281		return 0
 282	}
 283	return len(l)
 284}
 285
 286func (a *agent) generateTitle(ctx context.Context, sessionID string, content string) error {
 287	if content == "" {
 288		return nil
 289	}
 290	if a.titleProvider == nil {
 291		return nil
 292	}
 293	session, err := a.sessions.Get(ctx, sessionID)
 294	if err != nil {
 295		return err
 296	}
 297	parts := []message.ContentPart{message.TextContent{
 298		Text: fmt.Sprintf("Generate a concise title for the following content:\n\n%s", content),
 299	}}
 300
 301	// Use streaming approach like summarization
 302	response := a.titleProvider.StreamResponse(
 303		ctx,
 304		[]message.Message{
 305			{
 306				Role:  message.User,
 307				Parts: parts,
 308			},
 309		},
 310		nil,
 311	)
 312
 313	var finalResponse *provider.ProviderResponse
 314	for r := range response {
 315		if r.Error != nil {
 316			return r.Error
 317		}
 318		finalResponse = r.Response
 319	}
 320
 321	if finalResponse == nil {
 322		return fmt.Errorf("no response received from title provider")
 323	}
 324
 325	title := strings.TrimSpace(strings.ReplaceAll(finalResponse.Content, "\n", " "))
 326	if title == "" {
 327		return nil
 328	}
 329
 330	session.Title = title
 331	_, err = a.sessions.Save(ctx, session)
 332	return err
 333}
 334
 335func (a *agent) err(err error) AgentEvent {
 336	return AgentEvent{
 337		Type:  AgentEventTypeError,
 338		Error: err,
 339	}
 340}
 341
 342func (a *agent) Run(ctx context.Context, sessionID string, content string, attachments ...message.Attachment) (<-chan AgentEvent, error) {
 343	if !a.Model().SupportsImages && attachments != nil {
 344		attachments = nil
 345	}
 346	events := make(chan AgentEvent)
 347	if a.IsSessionBusy(sessionID) {
 348		existing, ok := a.promptQueue.Get(sessionID)
 349		if !ok {
 350			existing = []string{}
 351		}
 352		existing = append(existing, content)
 353		a.promptQueue.Set(sessionID, existing)
 354		return nil, nil
 355	}
 356
 357	genCtx, cancel := context.WithCancel(ctx)
 358
 359	a.activeRequests.Set(sessionID, cancel)
 360	go func() {
 361		slog.Debug("Request started", "sessionID", sessionID)
 362		defer log.RecoverPanic("agent.Run", func() {
 363			events <- a.err(fmt.Errorf("panic while running the agent"))
 364		})
 365		var attachmentParts []message.ContentPart
 366		for _, attachment := range attachments {
 367			attachmentParts = append(attachmentParts, message.BinaryContent{Path: attachment.FilePath, MIMEType: attachment.MimeType, Data: attachment.Content})
 368		}
 369		result := a.processGeneration(genCtx, sessionID, content, attachmentParts)
 370		if result.Error != nil && !errors.Is(result.Error, ErrRequestCancelled) && !errors.Is(result.Error, context.Canceled) {
 371			slog.Error(result.Error.Error())
 372		}
 373		slog.Debug("Request completed", "sessionID", sessionID)
 374		a.activeRequests.Del(sessionID)
 375		cancel()
 376		a.Publish(pubsub.CreatedEvent, result)
 377		events <- result
 378		close(events)
 379	}()
 380	return events, nil
 381}
 382
// processGeneration drives one generation for sessionID until the model stops
// asking for tools and no queued prompts remain.
//
// It loads the session's message history (trimmed to start at the summary
// message when the session has one), stores content as a new user message,
// and then loops: stream a model response, and if the response finished with
// tool use, append the assistant message and tool results and go around
// again. Prompts queued while the generation was running are pulled in as
// additional user messages. The returned event is either the final assistant
// message or an error event; cancellation is reported as ErrRequestCancelled
// or the context's error.
func (a *agent) processGeneration(ctx context.Context, sessionID, content string, attachmentParts []message.ContentPart) AgentEvent {
	cfg := config.Get()
	// List existing messages; if none, start title generation asynchronously.
	msgs, err := a.messages.List(ctx, sessionID)
	if err != nil {
		return a.err(fmt.Errorf("failed to list messages: %w", err))
	}
	if len(msgs) == 0 {
		go func() {
			defer log.RecoverPanic("agent.Run", func() {
				slog.Error("panic while generating title")
			})
			titleErr := a.generateTitle(ctx, sessionID, content)
			// Cancellation and deadline errors are expected here and not logged.
			if titleErr != nil && !errors.Is(titleErr, context.Canceled) && !errors.Is(titleErr, context.DeadlineExceeded) {
				slog.Error("failed to generate title", "error", titleErr)
			}
		}()
	}
	session, err := a.sessions.Get(ctx, sessionID)
	if err != nil {
		return a.err(fmt.Errorf("failed to get session: %w", err))
	}
	// When the session has been summarized, send only the summary message and
	// what follows it, presenting the summary to the model as a user message.
	if session.SummaryMessageID != "" {
		summaryMsgInex := -1
		for i, msg := range msgs {
			if msg.ID == session.SummaryMessageID {
				summaryMsgInex = i
				break
			}
		}
		if summaryMsgInex != -1 {
			msgs = msgs[summaryMsgInex:]
			msgs[0].Role = message.User
		}
	}

	userMsg, err := a.createUserMessage(ctx, sessionID, content, attachmentParts)
	if err != nil {
		return a.err(fmt.Errorf("failed to create user message: %w", err))
	}
	// Append the new user message to the conversation history.
	msgHistory := append(msgs, userMsg)

	for {
		// Check for cancellation before each iteration
		select {
		case <-ctx.Done():
			return a.err(ctx.Err())
		default:
			// Continue processing
		}
		agentMessage, toolResults, err := a.streamAndHandleEvents(ctx, sessionID, msgHistory)
		if err != nil {
			if errors.Is(err, context.Canceled) {
				// Persist the cancelled state with a fresh context, since ctx
				// is already done at this point.
				agentMessage.AddFinish(message.FinishReasonCanceled, "Request cancelled", "")
				a.messages.Update(context.Background(), agentMessage)
				return a.err(ErrRequestCancelled)
			}
			return a.err(fmt.Errorf("failed to process events: %w", err))
		}
		if cfg.Options.Debug {
			slog.Info("Result", "message", agentMessage.FinishReason(), "toolResults", toolResults)
		}
		if (agentMessage.FinishReason() == message.FinishReasonToolUse) && toolResults != nil {
			// We are not done, we need to respond with the tool response
			msgHistory = append(msgHistory, agentMessage, *toolResults)
			// If there are queued prompts, process the next one
			nextPrompt, ok := a.promptQueue.Take(sessionID)
			if ok {
				for _, prompt := range nextPrompt {
					// Create a new user message for the queued prompt
					userMsg, err := a.createUserMessage(ctx, sessionID, prompt, nil)
					if err != nil {
						return a.err(fmt.Errorf("failed to create user message for queued prompt: %w", err))
					}
					// Append the new user message to the conversation history
					msgHistory = append(msgHistory, userMsg)
				}
			}

			continue
		} else if agentMessage.FinishReason() == message.FinishReasonEndTurn {
			// The model finished its turn; if prompts were queued meanwhile,
			// feed them in and keep going instead of returning. Unlike the
			// tool-use branch above, empty prompts are skipped here.
			queuePrompts, ok := a.promptQueue.Take(sessionID)
			if ok {
				for _, prompt := range queuePrompts {
					if prompt == "" {
						continue
					}
					userMsg, err := a.createUserMessage(ctx, sessionID, prompt, nil)
					if err != nil {
						return a.err(fmt.Errorf("failed to create user message for queued prompt: %w", err))
					}
					msgHistory = append(msgHistory, userMsg)
				}
				continue
			}
		}
		if agentMessage.FinishReason() == "" {
			// Kujtim: could not track down where this is happening, but a
			// missing finish reason means the request was cancelled.
			agentMessage.AddFinish(message.FinishReasonCanceled, "Request cancelled", "")
			_ = a.messages.Update(context.Background(), agentMessage)
			return a.err(ErrRequestCancelled)
		}
		return AgentEvent{
			Type:    AgentEventTypeResponse,
			Message: agentMessage,
			Done:    true,
		}
	}
}
 493
 494func (a *agent) createUserMessage(ctx context.Context, sessionID, content string, attachmentParts []message.ContentPart) (message.Message, error) {
 495	parts := []message.ContentPart{message.TextContent{Text: content}}
 496	parts = append(parts, attachmentParts...)
 497	return a.messages.Create(ctx, sessionID, message.CreateMessageParams{
 498		Role:  message.User,
 499		Parts: parts,
 500	})
 501}
 502
 503func (a *agent) streamAndHandleEvents(ctx context.Context, sessionID string, msgHistory []message.Message) (message.Message, *message.Message, error) {
 504	ctx = context.WithValue(ctx, tools.SessionIDContextKey, sessionID)
 505
 506	// Create the assistant message first so the spinner shows immediately
 507	assistantMsg, err := a.messages.Create(ctx, sessionID, message.CreateMessageParams{
 508		Role:     message.Assistant,
 509		Parts:    []message.ContentPart{},
 510		Model:    a.Model().ID,
 511		Provider: a.providerID,
 512	})
 513	if err != nil {
 514		return assistantMsg, nil, fmt.Errorf("failed to create assistant message: %w", err)
 515	}
 516
 517	// Now collect tools (which may block on MCP initialization)
 518	eventChan := a.provider.StreamResponse(ctx, msgHistory, slices.Collect(a.tools.Seq()))
 519
 520	// Add the session and message ID into the context if needed by tools.
 521	ctx = context.WithValue(ctx, tools.MessageIDContextKey, assistantMsg.ID)
 522
 523	// Process each event in the stream.
 524	for event := range eventChan {
 525		if processErr := a.processEvent(ctx, sessionID, &assistantMsg, event); processErr != nil {
 526			if errors.Is(processErr, context.Canceled) {
 527				a.finishMessage(context.Background(), &assistantMsg, message.FinishReasonCanceled, "Request cancelled", "")
 528			} else {
 529				a.finishMessage(ctx, &assistantMsg, message.FinishReasonError, "API Error", processErr.Error())
 530			}
 531			return assistantMsg, nil, processErr
 532		}
 533		if ctx.Err() != nil {
 534			a.finishMessage(context.Background(), &assistantMsg, message.FinishReasonCanceled, "Request cancelled", "")
 535			return assistantMsg, nil, ctx.Err()
 536		}
 537	}
 538
 539	toolResults := make([]message.ToolResult, len(assistantMsg.ToolCalls()))
 540	toolCalls := assistantMsg.ToolCalls()
 541	for i, toolCall := range toolCalls {
 542		select {
 543		case <-ctx.Done():
 544			a.finishMessage(context.Background(), &assistantMsg, message.FinishReasonCanceled, "Request cancelled", "")
 545			// Make all future tool calls cancelled
 546			for j := i; j < len(toolCalls); j++ {
 547				toolResults[j] = message.ToolResult{
 548					ToolCallID: toolCalls[j].ID,
 549					Content:    "Tool execution canceled by user",
 550					IsError:    true,
 551				}
 552			}
 553			goto out
 554		default:
 555			// Continue processing
 556			var tool tools.BaseTool
 557			for availableTool := range a.tools.Seq() {
 558				if availableTool.Info().Name == toolCall.Name {
 559					tool = availableTool
 560					break
 561				}
 562			}
 563
 564			// Tool not found
 565			if tool == nil {
 566				toolResults[i] = message.ToolResult{
 567					ToolCallID: toolCall.ID,
 568					Content:    fmt.Sprintf("Tool not found: %s", toolCall.Name),
 569					IsError:    true,
 570				}
 571				continue
 572			}
 573
 574			// Run tool in goroutine to allow cancellation
 575			type toolExecResult struct {
 576				response tools.ToolResponse
 577				err      error
 578			}
 579			resultChan := make(chan toolExecResult, 1)
 580
 581			go func() {
 582				response, err := tool.Run(ctx, tools.ToolCall{
 583					ID:    toolCall.ID,
 584					Name:  toolCall.Name,
 585					Input: toolCall.Input,
 586				})
 587				resultChan <- toolExecResult{response: response, err: err}
 588			}()
 589
 590			var toolResponse tools.ToolResponse
 591			var toolErr error
 592
 593			select {
 594			case <-ctx.Done():
 595				a.finishMessage(context.Background(), &assistantMsg, message.FinishReasonCanceled, "Request cancelled", "")
 596				// Mark remaining tool calls as cancelled
 597				for j := i; j < len(toolCalls); j++ {
 598					toolResults[j] = message.ToolResult{
 599						ToolCallID: toolCalls[j].ID,
 600						Content:    "Tool execution canceled by user",
 601						IsError:    true,
 602					}
 603				}
 604				goto out
 605			case result := <-resultChan:
 606				toolResponse = result.response
 607				toolErr = result.err
 608			}
 609
 610			if toolErr != nil {
 611				slog.Error("Tool execution error", "toolCall", toolCall.ID, "error", toolErr)
 612				if errors.Is(toolErr, permission.ErrorPermissionDenied) {
 613					toolResults[i] = message.ToolResult{
 614						ToolCallID: toolCall.ID,
 615						Content:    "Permission denied",
 616						IsError:    true,
 617					}
 618					for j := i + 1; j < len(toolCalls); j++ {
 619						toolResults[j] = message.ToolResult{
 620							ToolCallID: toolCalls[j].ID,
 621							Content:    "Tool execution canceled by user",
 622							IsError:    true,
 623						}
 624					}
 625					a.finishMessage(ctx, &assistantMsg, message.FinishReasonPermissionDenied, "Permission denied", "")
 626					break
 627				}
 628			}
 629			toolResults[i] = message.ToolResult{
 630				ToolCallID: toolCall.ID,
 631				Content:    toolResponse.Content,
 632				Metadata:   toolResponse.Metadata,
 633				IsError:    toolResponse.IsError,
 634			}
 635		}
 636	}
 637out:
 638	if len(toolResults) == 0 {
 639		return assistantMsg, nil, nil
 640	}
 641	parts := make([]message.ContentPart, 0)
 642	for _, tr := range toolResults {
 643		parts = append(parts, tr)
 644	}
 645	msg, err := a.messages.Create(context.Background(), assistantMsg.SessionID, message.CreateMessageParams{
 646		Role:     message.Tool,
 647		Parts:    parts,
 648		Provider: a.providerID,
 649	})
 650	if err != nil {
 651		return assistantMsg, nil, fmt.Errorf("failed to create cancelled tool message: %w", err)
 652	}
 653
 654	return assistantMsg, &msg, err
 655}
 656
 657func (a *agent) finishMessage(ctx context.Context, msg *message.Message, finishReason message.FinishReason, message, details string) {
 658	msg.AddFinish(finishReason, message, details)
 659	_ = a.messages.Update(ctx, *msg)
 660}
 661
 662func (a *agent) processEvent(ctx context.Context, sessionID string, assistantMsg *message.Message, event provider.ProviderEvent) error {
 663	select {
 664	case <-ctx.Done():
 665		return ctx.Err()
 666	default:
 667		// Continue processing.
 668	}
 669
 670	switch event.Type {
 671	case provider.EventThinkingDelta:
 672		assistantMsg.AppendReasoningContent(event.Thinking)
 673		return a.messages.Update(ctx, *assistantMsg)
 674	case provider.EventSignatureDelta:
 675		assistantMsg.AppendReasoningSignature(event.Signature)
 676		return a.messages.Update(ctx, *assistantMsg)
 677	case provider.EventContentDelta:
 678		assistantMsg.FinishThinking()
 679		assistantMsg.AppendContent(event.Content)
 680		return a.messages.Update(ctx, *assistantMsg)
 681	case provider.EventToolUseStart:
 682		assistantMsg.FinishThinking()
 683		slog.Info("Tool call started", "toolCall", event.ToolCall)
 684		assistantMsg.AddToolCall(*event.ToolCall)
 685		return a.messages.Update(ctx, *assistantMsg)
 686	case provider.EventToolUseDelta:
 687		assistantMsg.AppendToolCallInput(event.ToolCall.ID, event.ToolCall.Input)
 688		return a.messages.Update(ctx, *assistantMsg)
 689	case provider.EventToolUseStop:
 690		slog.Info("Finished tool call", "toolCall", event.ToolCall)
 691		assistantMsg.FinishToolCall(event.ToolCall.ID)
 692		return a.messages.Update(ctx, *assistantMsg)
 693	case provider.EventError:
 694		return event.Error
 695	case provider.EventComplete:
 696		assistantMsg.FinishThinking()
 697		assistantMsg.SetToolCalls(event.Response.ToolCalls)
 698		assistantMsg.AddFinish(event.Response.FinishReason, "", "")
 699		if err := a.messages.Update(ctx, *assistantMsg); err != nil {
 700			return fmt.Errorf("failed to update message: %w", err)
 701		}
 702		return a.TrackUsage(ctx, sessionID, a.Model(), event.Response.Usage)
 703	}
 704
 705	return nil
 706}
 707
 708func (a *agent) TrackUsage(ctx context.Context, sessionID string, model catwalk.Model, usage provider.TokenUsage) error {
 709	sess, err := a.sessions.Get(ctx, sessionID)
 710	if err != nil {
 711		return fmt.Errorf("failed to get session: %w", err)
 712	}
 713
 714	cost := model.CostPer1MInCached/1e6*float64(usage.CacheCreationTokens) +
 715		model.CostPer1MOutCached/1e6*float64(usage.CacheReadTokens) +
 716		model.CostPer1MIn/1e6*float64(usage.InputTokens) +
 717		model.CostPer1MOut/1e6*float64(usage.OutputTokens)
 718
 719	sess.Cost += cost
 720	sess.CompletionTokens = usage.OutputTokens + usage.CacheReadTokens
 721	sess.PromptTokens = usage.InputTokens + usage.CacheCreationTokens
 722
 723	_, err = a.sessions.Save(ctx, sess)
 724	if err != nil {
 725		return fmt.Errorf("failed to save session: %w", err)
 726	}
 727	return nil
 728}
 729
 730func (a *agent) Summarize(ctx context.Context, sessionID string) error {
 731	if a.summarizeProvider == nil {
 732		return fmt.Errorf("summarize provider not available")
 733	}
 734
 735	// Check if session is busy
 736	if a.IsSessionBusy(sessionID) {
 737		return ErrSessionBusy
 738	}
 739
 740	// Create a new context with cancellation
 741	summarizeCtx, cancel := context.WithCancel(ctx)
 742
 743	// Store the cancel function in activeRequests to allow cancellation
 744	a.activeRequests.Set(sessionID+"-summarize", cancel)
 745
 746	go func() {
 747		defer a.activeRequests.Del(sessionID + "-summarize")
 748		defer cancel()
 749		event := AgentEvent{
 750			Type:     AgentEventTypeSummarize,
 751			Progress: "Starting summarization...",
 752		}
 753
 754		a.Publish(pubsub.CreatedEvent, event)
 755		// Get all messages from the session
 756		msgs, err := a.messages.List(summarizeCtx, sessionID)
 757		if err != nil {
 758			event = AgentEvent{
 759				Type:  AgentEventTypeError,
 760				Error: fmt.Errorf("failed to list messages: %w", err),
 761				Done:  true,
 762			}
 763			a.Publish(pubsub.CreatedEvent, event)
 764			return
 765		}
 766		summarizeCtx = context.WithValue(summarizeCtx, tools.SessionIDContextKey, sessionID)
 767
 768		if len(msgs) == 0 {
 769			event = AgentEvent{
 770				Type:  AgentEventTypeError,
 771				Error: fmt.Errorf("no messages to summarize"),
 772				Done:  true,
 773			}
 774			a.Publish(pubsub.CreatedEvent, event)
 775			return
 776		}
 777
 778		event = AgentEvent{
 779			Type:     AgentEventTypeSummarize,
 780			Progress: "Analyzing conversation...",
 781		}
 782		a.Publish(pubsub.CreatedEvent, event)
 783
 784		// Add a system message to guide the summarization
 785		summarizePrompt := "Provide a detailed but concise summary of our conversation above. Focus on information that would be helpful for continuing the conversation, including what we did, what we're doing, which files we're working on, and what we're going to do next."
 786
 787		// Create a new message with the summarize prompt
 788		promptMsg := message.Message{
 789			Role:  message.User,
 790			Parts: []message.ContentPart{message.TextContent{Text: summarizePrompt}},
 791		}
 792
 793		// Append the prompt to the messages
 794		msgsWithPrompt := append(msgs, promptMsg)
 795
 796		event = AgentEvent{
 797			Type:     AgentEventTypeSummarize,
 798			Progress: "Generating summary...",
 799		}
 800
 801		a.Publish(pubsub.CreatedEvent, event)
 802
 803		// Send the messages to the summarize provider
 804		response := a.summarizeProvider.StreamResponse(
 805			summarizeCtx,
 806			msgsWithPrompt,
 807			nil,
 808		)
 809		var finalResponse *provider.ProviderResponse
 810		for r := range response {
 811			if r.Error != nil {
 812				event = AgentEvent{
 813					Type:  AgentEventTypeError,
 814					Error: fmt.Errorf("failed to summarize: %w", err),
 815					Done:  true,
 816				}
 817				a.Publish(pubsub.CreatedEvent, event)
 818				return
 819			}
 820			finalResponse = r.Response
 821		}
 822
 823		summary := strings.TrimSpace(finalResponse.Content)
 824		if summary == "" {
 825			event = AgentEvent{
 826				Type:  AgentEventTypeError,
 827				Error: fmt.Errorf("empty summary returned"),
 828				Done:  true,
 829			}
 830			a.Publish(pubsub.CreatedEvent, event)
 831			return
 832		}
 833		shell := shell.GetPersistentShell(config.Get().WorkingDir())
 834		summary += "\n\n**Current working directory of the persistent shell**\n\n" + shell.GetWorkingDir()
 835		event = AgentEvent{
 836			Type:     AgentEventTypeSummarize,
 837			Progress: "Creating new session...",
 838		}
 839
 840		a.Publish(pubsub.CreatedEvent, event)
 841		oldSession, err := a.sessions.Get(summarizeCtx, sessionID)
 842		if err != nil {
 843			event = AgentEvent{
 844				Type:  AgentEventTypeError,
 845				Error: fmt.Errorf("failed to get session: %w", err),
 846				Done:  true,
 847			}
 848
 849			a.Publish(pubsub.CreatedEvent, event)
 850			return
 851		}
 852		// Create a message in the new session with the summary
 853		msg, err := a.messages.Create(summarizeCtx, oldSession.ID, message.CreateMessageParams{
 854			Role: message.Assistant,
 855			Parts: []message.ContentPart{
 856				message.TextContent{Text: summary},
 857				message.Finish{
 858					Reason: message.FinishReasonEndTurn,
 859					Time:   time.Now().Unix(),
 860				},
 861			},
 862			Model:    a.summarizeProvider.Model().ID,
 863			Provider: a.summarizeProviderID,
 864		})
 865		if err != nil {
 866			event = AgentEvent{
 867				Type:  AgentEventTypeError,
 868				Error: fmt.Errorf("failed to create summary message: %w", err),
 869				Done:  true,
 870			}
 871
 872			a.Publish(pubsub.CreatedEvent, event)
 873			return
 874		}
 875		oldSession.SummaryMessageID = msg.ID
 876		oldSession.CompletionTokens = finalResponse.Usage.OutputTokens
 877		oldSession.PromptTokens = 0
 878		model := a.summarizeProvider.Model()
 879		usage := finalResponse.Usage
 880		cost := model.CostPer1MInCached/1e6*float64(usage.CacheCreationTokens) +
 881			model.CostPer1MOutCached/1e6*float64(usage.CacheReadTokens) +
 882			model.CostPer1MIn/1e6*float64(usage.InputTokens) +
 883			model.CostPer1MOut/1e6*float64(usage.OutputTokens)
 884		oldSession.Cost += cost
 885		_, err = a.sessions.Save(summarizeCtx, oldSession)
 886		if err != nil {
 887			event = AgentEvent{
 888				Type:  AgentEventTypeError,
 889				Error: fmt.Errorf("failed to save session: %w", err),
 890				Done:  true,
 891			}
 892			a.Publish(pubsub.CreatedEvent, event)
 893		}
 894
 895		event = AgentEvent{
 896			Type:      AgentEventTypeSummarize,
 897			SessionID: oldSession.ID,
 898			Progress:  "Summary complete",
 899			Done:      true,
 900		}
 901		a.Publish(pubsub.CreatedEvent, event)
 902		// Send final success event with the new session ID
 903	}()
 904
 905	return nil
 906}
 907
 908func (a *agent) ClearQueue(sessionID string) {
 909	if a.QueuedPrompts(sessionID) > 0 {
 910		slog.Info("Clearing queued prompts", "session_id", sessionID)
 911		a.promptQueue.Del(sessionID)
 912	}
 913}
 914
 915func (a *agent) CancelAll() {
 916	if !a.IsBusy() {
 917		return
 918	}
 919	for key := range a.activeRequests.Seq2() {
 920		a.Cancel(key) // key is sessionID
 921	}
 922
 923	timeout := time.After(5 * time.Second)
 924	for a.IsBusy() {
 925		select {
 926		case <-timeout:
 927			return
 928		default:
 929			time.Sleep(200 * time.Millisecond)
 930		}
 931	}
 932}
 933
 934func (a *agent) UpdateModel() error {
 935	cfg := config.Get()
 936
 937	// Get current provider configuration
 938	currentProviderCfg := cfg.GetProviderForModel(a.agentCfg.Model)
 939	if currentProviderCfg == nil || currentProviderCfg.ID == "" {
 940		return fmt.Errorf("provider for agent %s not found in config", a.agentCfg.Name)
 941	}
 942
 943	// Check if provider has changed
 944	if string(currentProviderCfg.ID) != a.providerID {
 945		// Provider changed, need to recreate the main provider
 946		model := cfg.GetModelByType(a.agentCfg.Model)
 947		if model.ID == "" {
 948			return fmt.Errorf("model not found for agent %s", a.agentCfg.Name)
 949		}
 950
 951		promptID := agentPromptMap[a.agentCfg.ID]
 952		if promptID == "" {
 953			promptID = prompt.PromptDefault
 954		}
 955
 956		opts := []provider.ProviderClientOption{
 957			provider.WithModel(a.agentCfg.Model),
 958			provider.WithSystemMessage(prompt.GetPrompt(promptID, currentProviderCfg.ID, cfg.Options.ContextPaths...)),
 959		}
 960
 961		newProvider, err := provider.NewProvider(*currentProviderCfg, opts...)
 962		if err != nil {
 963			return fmt.Errorf("failed to create new provider: %w", err)
 964		}
 965
 966		// Update the provider and provider ID
 967		a.provider = newProvider
 968		a.providerID = string(currentProviderCfg.ID)
 969	}
 970
 971	// Check if providers have changed for title (small) and summarize (large)
 972	smallModelCfg := cfg.Models[config.SelectedModelTypeSmall]
 973	var smallModelProviderCfg config.ProviderConfig
 974	for p := range cfg.Providers.Seq() {
 975		if p.ID == smallModelCfg.Provider {
 976			smallModelProviderCfg = p
 977			break
 978		}
 979	}
 980	if smallModelProviderCfg.ID == "" {
 981		return fmt.Errorf("provider %s not found in config", smallModelCfg.Provider)
 982	}
 983
 984	largeModelCfg := cfg.Models[config.SelectedModelTypeLarge]
 985	var largeModelProviderCfg config.ProviderConfig
 986	for p := range cfg.Providers.Seq() {
 987		if p.ID == largeModelCfg.Provider {
 988			largeModelProviderCfg = p
 989			break
 990		}
 991	}
 992	if largeModelProviderCfg.ID == "" {
 993		return fmt.Errorf("provider %s not found in config", largeModelCfg.Provider)
 994	}
 995
 996	var maxTitleTokens int64 = 40
 997
 998	// if the max output is too low for the gemini provider it won't return anything
 999	if smallModelCfg.Provider == "gemini" {
1000		maxTitleTokens = 1000
1001	}
1002	// Recreate title provider
1003	titleOpts := []provider.ProviderClientOption{
1004		provider.WithModel(config.SelectedModelTypeSmall),
1005		provider.WithSystemMessage(prompt.GetPrompt(prompt.PromptTitle, smallModelProviderCfg.ID)),
1006		provider.WithMaxTokens(maxTitleTokens),
1007	}
1008	newTitleProvider, err := provider.NewProvider(smallModelProviderCfg, titleOpts...)
1009	if err != nil {
1010		return fmt.Errorf("failed to create new title provider: %w", err)
1011	}
1012	a.titleProvider = newTitleProvider
1013
1014	// Recreate summarize provider if provider changed (now large model)
1015	if string(largeModelProviderCfg.ID) != a.summarizeProviderID {
1016		largeModel := cfg.GetModelByType(config.SelectedModelTypeLarge)
1017		if largeModel == nil {
1018			return fmt.Errorf("model %s not found in provider %s", largeModelCfg.Model, largeModelProviderCfg.ID)
1019		}
1020		summarizeOpts := []provider.ProviderClientOption{
1021			provider.WithModel(config.SelectedModelTypeLarge),
1022			provider.WithSystemMessage(prompt.GetPrompt(prompt.PromptSummarizer, largeModelProviderCfg.ID)),
1023		}
1024		newSummarizeProvider, err := provider.NewProvider(largeModelProviderCfg, summarizeOpts...)
1025		if err != nil {
1026			return fmt.Errorf("failed to create new summarize provider: %w", err)
1027		}
1028		a.summarizeProvider = newSummarizeProvider
1029		a.summarizeProviderID = string(largeModelProviderCfg.ID)
1030	}
1031
1032	return nil
1033}