agent.go

   1package agent
   2
   3import (
   4	"context"
   5	"errors"
   6	"fmt"
   7	"log/slog"
   8	"slices"
   9	"strings"
  10	"time"
  11
  12	"github.com/charmbracelet/catwalk/pkg/catwalk"
  13	"github.com/charmbracelet/crush/internal/config"
  14	"github.com/charmbracelet/crush/internal/csync"
  15	"github.com/charmbracelet/crush/internal/history"
  16	"github.com/charmbracelet/crush/internal/llm/prompt"
  17	"github.com/charmbracelet/crush/internal/llm/provider"
  18	"github.com/charmbracelet/crush/internal/llm/tools"
  19	"github.com/charmbracelet/crush/internal/log"
  20	"github.com/charmbracelet/crush/internal/lsp"
  21	"github.com/charmbracelet/crush/internal/message"
  22	"github.com/charmbracelet/crush/internal/permission"
  23	"github.com/charmbracelet/crush/internal/pubsub"
  24	"github.com/charmbracelet/crush/internal/session"
  25	"github.com/charmbracelet/crush/internal/shell"
  26)
  27
// Common errors
var (
	// ErrRequestCancelled is the error carried by the event returned when a
	// generation is cancelled before it finishes.
	ErrRequestCancelled = errors.New("request canceled by user")
	// ErrSessionBusy is returned by Summarize when the session already has an
	// active request.
	ErrSessionBusy      = errors.New("session is currently processing another request")
)
  33
// AgentEventType identifies the kind of an AgentEvent.
type AgentEventType string

const (
	// AgentEventTypeError marks an event whose Error field is set.
	AgentEventTypeError     AgentEventType = "error"
	// AgentEventTypeResponse marks the final response of a generation.
	AgentEventTypeResponse  AgentEventType = "response"
	// AgentEventTypeSummarize marks progress and completion events published
	// by Summarize.
	AgentEventTypeSummarize AgentEventType = "summarize"
)
  41
// AgentEvent is the value the agent publishes to subscribers and sends on the
// channel returned by Run.
type AgentEvent struct {
	// Type says which of the remaining fields are meaningful.
	Type    AgentEventType
	// Message is the assistant message for response events.
	Message message.Message
	// Error is set for error events.
	Error   error

	// When summarizing
	// SessionID is set on the final "Summary complete" event.
	SessionID string
	// Progress is a human-readable status line.
	Progress  string
	// Done is true on the last event of an operation.
	Done      bool
}
  52
// Service is the public surface of an agent. The comments below describe the
// behavior of the implementation in this file.
type Service interface {
	pubsub.Suscriber[AgentEvent]
	// Model returns the model selected for the agent's configured model type.
	Model() catwalk.Model
	// Run processes content for the session in a background goroutine and
	// returns a channel that receives the final event. If the session is
	// already busy the prompt is queued instead and both return values are nil.
	Run(ctx context.Context, sessionID string, content string, attachments ...message.Attachment) (<-chan AgentEvent, error)
	// Cancel cancels the session's active request and summarize request, if
	// any, and drops its queued prompts.
	Cancel(sessionID string)
	// CancelAll cancels every active request and waits up to five seconds for
	// the agent to become idle.
	CancelAll()
	// IsSessionBusy reports whether the session has an active request.
	IsSessionBusy(sessionID string) bool
	// IsBusy reports whether any request is active.
	IsBusy() bool
	// Summarize starts summarizing the session in the background; progress and
	// errors are published as events. It returns ErrSessionBusy if the session
	// has an active request.
	Summarize(ctx context.Context, sessionID string) error
	// UpdateModel re-reads the provider configuration and recreates the main
	// provider when the provider changed.
	// NOTE(review): only part of the implementation is visible here — confirm
	// how the title and summarize providers are refreshed.
	UpdateModel() error
	// QueuedPrompts returns the number of prompts queued for the session.
	QueuedPrompts(sessionID string) int
	// ClearQueue drops the prompts queued for the session.
	ClearQueue(sessionID string)
}
  66
// agent is the Service implementation. It embeds a pubsub broker so that
// results and summarize progress can be published to subscribers.
type agent struct {
	*pubsub.Broker[AgentEvent]
	// agentCfg is the configuration this agent was created from.
	agentCfg config.Agent
	// sessions and messages are the stores used to load and persist
	// conversation state.
	sessions session.Service
	messages message.Service
	// NOTE(review): mcpTools is not populated by NewAgent in this file —
	// confirm whether it is still used.
	mcpTools []McpTool

	// tools holds the agent's tool set, built by the function passed to
	// csync.NewLazySlice in NewAgent.
	tools *csync.LazySlice[tools.BaseTool]
	// We need this to be able to update it when model changes
	agentToolFn func() (tools.BaseTool, error)

	// provider serves the main generation requests; providerID is recorded on
	// the messages it produces.
	provider   provider.Provider
	providerID string

	// titleProvider generates session titles; summarizeProvider generates
	// conversation summaries.
	titleProvider       provider.Provider
	summarizeProvider   provider.Provider
	summarizeProviderID string

	// activeRequests maps a session ID (or sessionID+"-summarize") to the
	// cancel function of the request currently running for it.
	activeRequests *csync.Map[string, context.CancelFunc]

	// promptQueue holds prompts submitted while a session was busy, keyed by
	// session ID.
	promptQueue *csync.Map[string, []string]
}
  89
// agentPromptMap maps an agent ID to the system prompt it uses. Agents whose
// ID is not listed fall back to prompt.PromptDefault.
var agentPromptMap = map[string]prompt.PromptID{
	"coder": prompt.PromptCoder,
	"task":  prompt.PromptTask,
}
  94
  95func NewAgent(
  96	ctx context.Context,
  97	agentCfg config.Agent,
  98	// These services are needed in the tools
  99	permissions permission.Service,
 100	sessions session.Service,
 101	messages message.Service,
 102	history history.Service,
 103	lspClients map[string]*lsp.Client,
 104) (Service, error) {
 105	cfg := config.Get()
 106
 107	var agentToolFn func() (tools.BaseTool, error)
 108	if agentCfg.ID == "coder" {
 109		agentToolFn = func() (tools.BaseTool, error) {
 110			taskAgentCfg := config.Get().Agents["task"]
 111			if taskAgentCfg.ID == "" {
 112				return nil, fmt.Errorf("task agent not found in config")
 113			}
 114			taskAgent, err := NewAgent(ctx, taskAgentCfg, permissions, sessions, messages, history, lspClients)
 115			if err != nil {
 116				return nil, fmt.Errorf("failed to create task agent: %w", err)
 117			}
 118			return NewAgentTool(taskAgent, sessions, messages), nil
 119		}
 120	}
 121
 122	providerCfg := config.Get().GetProviderForModel(agentCfg.Model)
 123	if providerCfg == nil {
 124		return nil, fmt.Errorf("provider for agent %s not found in config", agentCfg.Name)
 125	}
 126	model := config.Get().GetModelByType(agentCfg.Model)
 127
 128	if model == nil {
 129		return nil, fmt.Errorf("model not found for agent %s", agentCfg.Name)
 130	}
 131
 132	promptID := agentPromptMap[agentCfg.ID]
 133	if promptID == "" {
 134		promptID = prompt.PromptDefault
 135	}
 136	opts := []provider.ProviderClientOption{
 137		provider.WithModel(agentCfg.Model),
 138		provider.WithSystemMessage(prompt.GetPrompt(promptID, providerCfg.ID, config.Get().Options.ContextPaths...)),
 139	}
 140	agentProvider, err := provider.NewProvider(*providerCfg, opts...)
 141	if err != nil {
 142		return nil, err
 143	}
 144
 145	smallModelCfg := cfg.Models[config.SelectedModelTypeSmall]
 146	var smallModelProviderCfg *config.ProviderConfig
 147	if smallModelCfg.Provider == providerCfg.ID {
 148		smallModelProviderCfg = providerCfg
 149	} else {
 150		smallModelProviderCfg = cfg.GetProviderForModel(config.SelectedModelTypeSmall)
 151
 152		if smallModelProviderCfg.ID == "" {
 153			return nil, fmt.Errorf("provider %s not found in config", smallModelCfg.Provider)
 154		}
 155	}
 156	smallModel := cfg.GetModelByType(config.SelectedModelTypeSmall)
 157	if smallModel.ID == "" {
 158		return nil, fmt.Errorf("model %s not found in provider %s", smallModelCfg.Model, smallModelProviderCfg.ID)
 159	}
 160
 161	titleOpts := []provider.ProviderClientOption{
 162		provider.WithModel(config.SelectedModelTypeSmall),
 163		provider.WithSystemMessage(prompt.GetPrompt(prompt.PromptTitle, smallModelProviderCfg.ID)),
 164	}
 165	titleProvider, err := provider.NewProvider(*smallModelProviderCfg, titleOpts...)
 166	if err != nil {
 167		return nil, err
 168	}
 169
 170	summarizeOpts := []provider.ProviderClientOption{
 171		provider.WithModel(config.SelectedModelTypeLarge),
 172		provider.WithSystemMessage(prompt.GetPrompt(prompt.PromptSummarizer, providerCfg.ID)),
 173	}
 174	summarizeProvider, err := provider.NewProvider(*providerCfg, summarizeOpts...)
 175	if err != nil {
 176		return nil, err
 177	}
 178
 179	toolFn := func() []tools.BaseTool {
 180		slog.Info("Initializing agent tools", "agent", agentCfg.ID)
 181		defer func() {
 182			slog.Info("Initialized agent tools", "agent", agentCfg.ID)
 183		}()
 184
 185		cwd := cfg.WorkingDir()
 186		allTools := []tools.BaseTool{
 187			tools.NewBashTool(permissions, cwd),
 188			tools.NewDownloadTool(permissions, cwd),
 189			tools.NewEditTool(lspClients, permissions, history, cwd),
 190			tools.NewMultiEditTool(lspClients, permissions, history, cwd),
 191			tools.NewFetchTool(permissions, cwd),
 192			tools.NewGlobTool(cwd),
 193			tools.NewGrepTool(cwd),
 194			tools.NewLsTool(permissions, cwd),
 195			tools.NewSourcegraphTool(),
 196			tools.NewViewTool(lspClients, permissions, cwd),
 197			tools.NewWriteTool(lspClients, permissions, history, cwd),
 198		}
 199
 200		mcpToolsOnce.Do(func() {
 201			mcpTools = doGetMCPTools(ctx, permissions, cfg)
 202		})
 203		allTools = append(allTools, mcpTools...)
 204
 205		if len(lspClients) > 0 {
 206			allTools = append(allTools, tools.NewDiagnosticsTool(lspClients))
 207		}
 208
 209		if agentCfg.AllowedTools == nil {
 210			return allTools
 211		}
 212
 213		var filteredTools []tools.BaseTool
 214		for _, tool := range allTools {
 215			if slices.Contains(agentCfg.AllowedTools, tool.Name()) {
 216				filteredTools = append(filteredTools, tool)
 217			}
 218		}
 219		return filteredTools
 220	}
 221
 222	return &agent{
 223		Broker:              pubsub.NewBroker[AgentEvent](),
 224		agentCfg:            agentCfg,
 225		provider:            agentProvider,
 226		providerID:          string(providerCfg.ID),
 227		messages:            messages,
 228		sessions:            sessions,
 229		titleProvider:       titleProvider,
 230		summarizeProvider:   summarizeProvider,
 231		summarizeProviderID: string(providerCfg.ID),
 232		agentToolFn:         agentToolFn,
 233		activeRequests:      csync.NewMap[string, context.CancelFunc](),
 234		tools:               csync.NewLazySlice(toolFn),
 235		promptQueue:         csync.NewMap[string, []string](),
 236	}, nil
 237}
 238
 239func (a *agent) Model() catwalk.Model {
 240	return *config.Get().GetModelByType(a.agentCfg.Model)
 241}
 242
 243func (a *agent) Cancel(sessionID string) {
 244	// Cancel regular requests
 245	if cancel, ok := a.activeRequests.Take(sessionID); ok && cancel != nil {
 246		slog.Info("Request cancellation initiated", "session_id", sessionID)
 247		cancel()
 248	}
 249
 250	// Also check for summarize requests
 251	if cancel, ok := a.activeRequests.Take(sessionID + "-summarize"); ok && cancel != nil {
 252		slog.Info("Summarize cancellation initiated", "session_id", sessionID)
 253		cancel()
 254	}
 255
 256	if a.QueuedPrompts(sessionID) > 0 {
 257		slog.Info("Clearing queued prompts", "session_id", sessionID)
 258		a.promptQueue.Del(sessionID)
 259	}
 260}
 261
 262func (a *agent) IsBusy() bool {
 263	var busy bool
 264	for cancelFunc := range a.activeRequests.Seq() {
 265		if cancelFunc != nil {
 266			busy = true
 267			break
 268		}
 269	}
 270	return busy
 271}
 272
 273func (a *agent) IsSessionBusy(sessionID string) bool {
 274	_, busy := a.activeRequests.Get(sessionID)
 275	return busy
 276}
 277
 278func (a *agent) QueuedPrompts(sessionID string) int {
 279	l, ok := a.promptQueue.Get(sessionID)
 280	if !ok {
 281		return 0
 282	}
 283	return len(l)
 284}
 285
 286func (a *agent) generateTitle(ctx context.Context, sessionID string, content string) error {
 287	if content == "" {
 288		return nil
 289	}
 290	if a.titleProvider == nil {
 291		return nil
 292	}
 293	session, err := a.sessions.Get(ctx, sessionID)
 294	if err != nil {
 295		return err
 296	}
 297	parts := []message.ContentPart{message.TextContent{
 298		Text: fmt.Sprintf("Generate a concise title for the following content:\n\n%s", content),
 299	}}
 300
 301	// Use streaming approach like summarization
 302	response := a.titleProvider.StreamResponse(
 303		ctx,
 304		[]message.Message{
 305			{
 306				Role:  message.User,
 307				Parts: parts,
 308			},
 309		},
 310		nil,
 311	)
 312
 313	var finalResponse *provider.ProviderResponse
 314	for r := range response {
 315		if r.Error != nil {
 316			return r.Error
 317		}
 318		finalResponse = r.Response
 319	}
 320
 321	if finalResponse == nil {
 322		return fmt.Errorf("no response received from title provider")
 323	}
 324
 325	title := strings.TrimSpace(strings.ReplaceAll(finalResponse.Content, "\n", " "))
 326	if title == "" {
 327		return nil
 328	}
 329
 330	session.Title = title
 331	_, err = a.sessions.Save(ctx, session)
 332	return err
 333}
 334
 335func (a *agent) err(err error) AgentEvent {
 336	return AgentEvent{
 337		Type:  AgentEventTypeError,
 338		Error: err,
 339	}
 340}
 341
 342func (a *agent) Run(ctx context.Context, sessionID string, content string, attachments ...message.Attachment) (<-chan AgentEvent, error) {
 343	if !a.Model().SupportsImages && attachments != nil {
 344		attachments = nil
 345	}
 346	events := make(chan AgentEvent, 1)
 347	if a.IsSessionBusy(sessionID) {
 348		existing, ok := a.promptQueue.Get(sessionID)
 349		if !ok {
 350			existing = []string{}
 351		}
 352		existing = append(existing, content)
 353		a.promptQueue.Set(sessionID, existing)
 354		return nil, nil
 355	}
 356
 357	genCtx, cancel := context.WithCancel(ctx)
 358
 359	a.activeRequests.Set(sessionID, cancel)
 360	go func() {
 361		slog.Debug("Request started", "sessionID", sessionID)
 362		defer log.RecoverPanic("agent.Run", func() {
 363			events <- a.err(fmt.Errorf("panic while running the agent"))
 364		})
 365		var attachmentParts []message.ContentPart
 366		for _, attachment := range attachments {
 367			attachmentParts = append(attachmentParts, message.BinaryContent{Path: attachment.FilePath, MIMEType: attachment.MimeType, Data: attachment.Content})
 368		}
 369		result := a.processGeneration(genCtx, sessionID, content, attachmentParts)
 370		if result.Error != nil && !errors.Is(result.Error, ErrRequestCancelled) && !errors.Is(result.Error, context.Canceled) {
 371			slog.Error(result.Error.Error())
 372		}
 373		slog.Debug("Request completed", "sessionID", sessionID)
 374		a.activeRequests.Del(sessionID)
 375		cancel()
 376		a.Publish(pubsub.CreatedEvent, result)
 377		events <- result
 378		close(events)
 379	}()
 380	return events, nil
 381}
 382
 383func (a *agent) processGeneration(ctx context.Context, sessionID, content string, attachmentParts []message.ContentPart) AgentEvent {
 384	cfg := config.Get()
 385	// List existing messages; if none, start title generation asynchronously.
 386	msgs, err := a.messages.List(ctx, sessionID)
 387	if err != nil {
 388		return a.err(fmt.Errorf("failed to list messages: %w", err))
 389	}
 390	if len(msgs) == 0 {
 391		go func() {
 392			defer log.RecoverPanic("agent.Run", func() {
 393				slog.Error("panic while generating title")
 394			})
 395			titleErr := a.generateTitle(ctx, sessionID, content)
 396			if titleErr != nil && !errors.Is(titleErr, context.Canceled) && !errors.Is(titleErr, context.DeadlineExceeded) {
 397				slog.Error("failed to generate title", "error", titleErr)
 398			}
 399		}()
 400	}
 401	session, err := a.sessions.Get(ctx, sessionID)
 402	if err != nil {
 403		return a.err(fmt.Errorf("failed to get session: %w", err))
 404	}
 405	if session.SummaryMessageID != "" {
 406		summaryMsgInex := -1
 407		for i, msg := range msgs {
 408			if msg.ID == session.SummaryMessageID {
 409				summaryMsgInex = i
 410				break
 411			}
 412		}
 413		if summaryMsgInex != -1 {
 414			msgs = msgs[summaryMsgInex:]
 415			msgs[0].Role = message.User
 416		}
 417	}
 418
 419	userMsg, err := a.createUserMessage(ctx, sessionID, content, attachmentParts)
 420	if err != nil {
 421		return a.err(fmt.Errorf("failed to create user message: %w", err))
 422	}
 423	// Append the new user message to the conversation history.
 424	msgHistory := append(msgs, userMsg)
 425
 426	for {
 427		// Check for cancellation before each iteration
 428		select {
 429		case <-ctx.Done():
 430			return a.err(ctx.Err())
 431		default:
 432			// Continue processing
 433		}
 434		agentMessage, toolResults, err := a.streamAndHandleEvents(ctx, sessionID, msgHistory)
 435		if err != nil {
 436			if errors.Is(err, context.Canceled) {
 437				agentMessage.AddFinish(message.FinishReasonCanceled, "Request cancelled", "")
 438				a.messages.Update(context.Background(), agentMessage)
 439				return a.err(ErrRequestCancelled)
 440			}
 441			return a.err(fmt.Errorf("failed to process events: %w", err))
 442		}
 443		if cfg.Options.Debug {
 444			slog.Info("Result", "message", agentMessage.FinishReason(), "toolResults", toolResults)
 445		}
 446		if (agentMessage.FinishReason() == message.FinishReasonToolUse) && toolResults != nil {
 447			// We are not done, we need to respond with the tool response
 448			msgHistory = append(msgHistory, agentMessage, *toolResults)
 449			// If there are queued prompts, process the next one
 450			nextPrompt, ok := a.promptQueue.Take(sessionID)
 451			if ok {
 452				for _, prompt := range nextPrompt {
 453					// Create a new user message for the queued prompt
 454					userMsg, err := a.createUserMessage(ctx, sessionID, prompt, nil)
 455					if err != nil {
 456						return a.err(fmt.Errorf("failed to create user message for queued prompt: %w", err))
 457					}
 458					// Append the new user message to the conversation history
 459					msgHistory = append(msgHistory, userMsg)
 460				}
 461			}
 462
 463			continue
 464		} else if agentMessage.FinishReason() == message.FinishReasonEndTurn {
 465			queuePrompts, ok := a.promptQueue.Take(sessionID)
 466			if ok {
 467				for _, prompt := range queuePrompts {
 468					if prompt == "" {
 469						continue
 470					}
 471					userMsg, err := a.createUserMessage(ctx, sessionID, prompt, nil)
 472					if err != nil {
 473						return a.err(fmt.Errorf("failed to create user message for queued prompt: %w", err))
 474					}
 475					msgHistory = append(msgHistory, userMsg)
 476				}
 477				continue
 478			}
 479		}
 480		if agentMessage.FinishReason() == "" {
 481			// Kujtim: could not track down where this is happening but this means its cancelled
 482			agentMessage.AddFinish(message.FinishReasonCanceled, "Request cancelled", "")
 483			_ = a.messages.Update(context.Background(), agentMessage)
 484			return a.err(ErrRequestCancelled)
 485		}
 486		return AgentEvent{
 487			Type:    AgentEventTypeResponse,
 488			Message: agentMessage,
 489			Done:    true,
 490		}
 491	}
 492}
 493
 494func (a *agent) createUserMessage(ctx context.Context, sessionID, content string, attachmentParts []message.ContentPart) (message.Message, error) {
 495	parts := []message.ContentPart{message.TextContent{Text: content}}
 496	parts = append(parts, attachmentParts...)
 497	return a.messages.Create(ctx, sessionID, message.CreateMessageParams{
 498		Role:  message.User,
 499		Parts: parts,
 500	})
 501}
 502
 503func (a *agent) getAllTools() ([]tools.BaseTool, error) {
 504	allTools := slices.Collect(a.tools.Seq())
 505	if a.agentToolFn != nil {
 506		agentTool, agentToolErr := a.agentToolFn()
 507		if agentToolErr != nil {
 508			return nil, agentToolErr
 509		}
 510		allTools = append(allTools, agentTool)
 511	}
 512	return allTools, nil
 513}
 514
 515func (a *agent) streamAndHandleEvents(ctx context.Context, sessionID string, msgHistory []message.Message) (message.Message, *message.Message, error) {
 516	ctx = context.WithValue(ctx, tools.SessionIDContextKey, sessionID)
 517
 518	// Create the assistant message first so the spinner shows immediately
 519	assistantMsg, err := a.messages.Create(ctx, sessionID, message.CreateMessageParams{
 520		Role:     message.Assistant,
 521		Parts:    []message.ContentPart{},
 522		Model:    a.Model().ID,
 523		Provider: a.providerID,
 524	})
 525	if err != nil {
 526		return assistantMsg, nil, fmt.Errorf("failed to create assistant message: %w", err)
 527	}
 528
 529	allTools, toolsErr := a.getAllTools()
 530	if toolsErr != nil {
 531		return assistantMsg, nil, toolsErr
 532	}
 533	// Now collect tools (which may block on MCP initialization)
 534	eventChan := a.provider.StreamResponse(ctx, msgHistory, allTools)
 535
 536	// Add the session and message ID into the context if needed by tools.
 537	ctx = context.WithValue(ctx, tools.MessageIDContextKey, assistantMsg.ID)
 538
 539	// Process each event in the stream.
 540	for event := range eventChan {
 541		if processErr := a.processEvent(ctx, sessionID, &assistantMsg, event); processErr != nil {
 542			if errors.Is(processErr, context.Canceled) {
 543				a.finishMessage(context.Background(), &assistantMsg, message.FinishReasonCanceled, "Request cancelled", "")
 544			} else {
 545				a.finishMessage(ctx, &assistantMsg, message.FinishReasonError, "API Error", processErr.Error())
 546			}
 547			return assistantMsg, nil, processErr
 548		}
 549		if ctx.Err() != nil {
 550			a.finishMessage(context.Background(), &assistantMsg, message.FinishReasonCanceled, "Request cancelled", "")
 551			return assistantMsg, nil, ctx.Err()
 552		}
 553	}
 554
 555	toolResults := make([]message.ToolResult, len(assistantMsg.ToolCalls()))
 556	toolCalls := assistantMsg.ToolCalls()
 557	for i, toolCall := range toolCalls {
 558		select {
 559		case <-ctx.Done():
 560			a.finishMessage(context.Background(), &assistantMsg, message.FinishReasonCanceled, "Request cancelled", "")
 561			// Make all future tool calls cancelled
 562			for j := i; j < len(toolCalls); j++ {
 563				toolResults[j] = message.ToolResult{
 564					ToolCallID: toolCalls[j].ID,
 565					Content:    "Tool execution canceled by user",
 566					IsError:    true,
 567				}
 568			}
 569			goto out
 570		default:
 571			// Continue processing
 572			var tool tools.BaseTool
 573			allTools, _ := a.getAllTools()
 574			for _, availableTool := range allTools {
 575				if availableTool.Info().Name == toolCall.Name {
 576					tool = availableTool
 577					break
 578				}
 579			}
 580
 581			// Tool not found
 582			if tool == nil {
 583				toolResults[i] = message.ToolResult{
 584					ToolCallID: toolCall.ID,
 585					Content:    fmt.Sprintf("Tool not found: %s", toolCall.Name),
 586					IsError:    true,
 587				}
 588				continue
 589			}
 590
 591			// Run tool in goroutine to allow cancellation
 592			type toolExecResult struct {
 593				response tools.ToolResponse
 594				err      error
 595			}
 596			resultChan := make(chan toolExecResult, 1)
 597
 598			go func() {
 599				response, err := tool.Run(ctx, tools.ToolCall{
 600					ID:    toolCall.ID,
 601					Name:  toolCall.Name,
 602					Input: toolCall.Input,
 603				})
 604				resultChan <- toolExecResult{response: response, err: err}
 605			}()
 606
 607			var toolResponse tools.ToolResponse
 608			var toolErr error
 609
 610			select {
 611			case <-ctx.Done():
 612				a.finishMessage(context.Background(), &assistantMsg, message.FinishReasonCanceled, "Request cancelled", "")
 613				// Mark remaining tool calls as cancelled
 614				for j := i; j < len(toolCalls); j++ {
 615					toolResults[j] = message.ToolResult{
 616						ToolCallID: toolCalls[j].ID,
 617						Content:    "Tool execution canceled by user",
 618						IsError:    true,
 619					}
 620				}
 621				goto out
 622			case result := <-resultChan:
 623				toolResponse = result.response
 624				toolErr = result.err
 625			}
 626
 627			if toolErr != nil {
 628				slog.Error("Tool execution error", "toolCall", toolCall.ID, "error", toolErr)
 629				if errors.Is(toolErr, permission.ErrorPermissionDenied) {
 630					toolResults[i] = message.ToolResult{
 631						ToolCallID: toolCall.ID,
 632						Content:    "Permission denied",
 633						IsError:    true,
 634					}
 635					for j := i + 1; j < len(toolCalls); j++ {
 636						toolResults[j] = message.ToolResult{
 637							ToolCallID: toolCalls[j].ID,
 638							Content:    "Tool execution canceled by user",
 639							IsError:    true,
 640						}
 641					}
 642					a.finishMessage(ctx, &assistantMsg, message.FinishReasonPermissionDenied, "Permission denied", "")
 643					break
 644				}
 645			}
 646			toolResults[i] = message.ToolResult{
 647				ToolCallID: toolCall.ID,
 648				Content:    toolResponse.Content,
 649				Metadata:   toolResponse.Metadata,
 650				IsError:    toolResponse.IsError,
 651			}
 652		}
 653	}
 654out:
 655	if len(toolResults) == 0 {
 656		return assistantMsg, nil, nil
 657	}
 658	parts := make([]message.ContentPart, 0)
 659	for _, tr := range toolResults {
 660		parts = append(parts, tr)
 661	}
 662	msg, err := a.messages.Create(context.Background(), assistantMsg.SessionID, message.CreateMessageParams{
 663		Role:     message.Tool,
 664		Parts:    parts,
 665		Provider: a.providerID,
 666	})
 667	if err != nil {
 668		return assistantMsg, nil, fmt.Errorf("failed to create cancelled tool message: %w", err)
 669	}
 670
 671	return assistantMsg, &msg, err
 672}
 673
 674func (a *agent) finishMessage(ctx context.Context, msg *message.Message, finishReason message.FinishReason, message, details string) {
 675	msg.AddFinish(finishReason, message, details)
 676	_ = a.messages.Update(ctx, *msg)
 677}
 678
 679func (a *agent) processEvent(ctx context.Context, sessionID string, assistantMsg *message.Message, event provider.ProviderEvent) error {
 680	select {
 681	case <-ctx.Done():
 682		return ctx.Err()
 683	default:
 684		// Continue processing.
 685	}
 686
 687	switch event.Type {
 688	case provider.EventThinkingDelta:
 689		assistantMsg.AppendReasoningContent(event.Thinking)
 690		return a.messages.Update(ctx, *assistantMsg)
 691	case provider.EventSignatureDelta:
 692		assistantMsg.AppendReasoningSignature(event.Signature)
 693		return a.messages.Update(ctx, *assistantMsg)
 694	case provider.EventContentDelta:
 695		assistantMsg.FinishThinking()
 696		assistantMsg.AppendContent(event.Content)
 697		return a.messages.Update(ctx, *assistantMsg)
 698	case provider.EventToolUseStart:
 699		assistantMsg.FinishThinking()
 700		slog.Info("Tool call started", "toolCall", event.ToolCall)
 701		assistantMsg.AddToolCall(*event.ToolCall)
 702		return a.messages.Update(ctx, *assistantMsg)
 703	case provider.EventToolUseDelta:
 704		assistantMsg.AppendToolCallInput(event.ToolCall.ID, event.ToolCall.Input)
 705		return a.messages.Update(ctx, *assistantMsg)
 706	case provider.EventToolUseStop:
 707		slog.Info("Finished tool call", "toolCall", event.ToolCall)
 708		assistantMsg.FinishToolCall(event.ToolCall.ID)
 709		return a.messages.Update(ctx, *assistantMsg)
 710	case provider.EventError:
 711		return event.Error
 712	case provider.EventComplete:
 713		assistantMsg.FinishThinking()
 714		assistantMsg.SetToolCalls(event.Response.ToolCalls)
 715		assistantMsg.AddFinish(event.Response.FinishReason, "", "")
 716		if err := a.messages.Update(ctx, *assistantMsg); err != nil {
 717			return fmt.Errorf("failed to update message: %w", err)
 718		}
 719		return a.TrackUsage(ctx, sessionID, a.Model(), event.Response.Usage)
 720	}
 721
 722	return nil
 723}
 724
 725func (a *agent) TrackUsage(ctx context.Context, sessionID string, model catwalk.Model, usage provider.TokenUsage) error {
 726	sess, err := a.sessions.Get(ctx, sessionID)
 727	if err != nil {
 728		return fmt.Errorf("failed to get session: %w", err)
 729	}
 730
 731	cost := model.CostPer1MInCached/1e6*float64(usage.CacheCreationTokens) +
 732		model.CostPer1MOutCached/1e6*float64(usage.CacheReadTokens) +
 733		model.CostPer1MIn/1e6*float64(usage.InputTokens) +
 734		model.CostPer1MOut/1e6*float64(usage.OutputTokens)
 735
 736	sess.Cost += cost
 737	sess.CompletionTokens = usage.OutputTokens + usage.CacheReadTokens
 738	sess.PromptTokens = usage.InputTokens + usage.CacheCreationTokens
 739
 740	_, err = a.sessions.Save(ctx, sess)
 741	if err != nil {
 742		return fmt.Errorf("failed to save session: %w", err)
 743	}
 744	return nil
 745}
 746
 747func (a *agent) Summarize(ctx context.Context, sessionID string) error {
 748	if a.summarizeProvider == nil {
 749		return fmt.Errorf("summarize provider not available")
 750	}
 751
 752	// Check if session is busy
 753	if a.IsSessionBusy(sessionID) {
 754		return ErrSessionBusy
 755	}
 756
 757	// Create a new context with cancellation
 758	summarizeCtx, cancel := context.WithCancel(ctx)
 759
 760	// Store the cancel function in activeRequests to allow cancellation
 761	a.activeRequests.Set(sessionID+"-summarize", cancel)
 762
 763	go func() {
 764		defer a.activeRequests.Del(sessionID + "-summarize")
 765		defer cancel()
 766		event := AgentEvent{
 767			Type:     AgentEventTypeSummarize,
 768			Progress: "Starting summarization...",
 769		}
 770
 771		a.Publish(pubsub.CreatedEvent, event)
 772		// Get all messages from the session
 773		msgs, err := a.messages.List(summarizeCtx, sessionID)
 774		if err != nil {
 775			event = AgentEvent{
 776				Type:  AgentEventTypeError,
 777				Error: fmt.Errorf("failed to list messages: %w", err),
 778				Done:  true,
 779			}
 780			a.Publish(pubsub.CreatedEvent, event)
 781			return
 782		}
 783		summarizeCtx = context.WithValue(summarizeCtx, tools.SessionIDContextKey, sessionID)
 784
 785		if len(msgs) == 0 {
 786			event = AgentEvent{
 787				Type:  AgentEventTypeError,
 788				Error: fmt.Errorf("no messages to summarize"),
 789				Done:  true,
 790			}
 791			a.Publish(pubsub.CreatedEvent, event)
 792			return
 793		}
 794
 795		event = AgentEvent{
 796			Type:     AgentEventTypeSummarize,
 797			Progress: "Analyzing conversation...",
 798		}
 799		a.Publish(pubsub.CreatedEvent, event)
 800
 801		// Add a system message to guide the summarization
 802		summarizePrompt := "Provide a detailed but concise summary of our conversation above. Focus on information that would be helpful for continuing the conversation, including what we did, what we're doing, which files we're working on, and what we're going to do next."
 803
 804		// Create a new message with the summarize prompt
 805		promptMsg := message.Message{
 806			Role:  message.User,
 807			Parts: []message.ContentPart{message.TextContent{Text: summarizePrompt}},
 808		}
 809
 810		// Append the prompt to the messages
 811		msgsWithPrompt := append(msgs, promptMsg)
 812
 813		event = AgentEvent{
 814			Type:     AgentEventTypeSummarize,
 815			Progress: "Generating summary...",
 816		}
 817
 818		a.Publish(pubsub.CreatedEvent, event)
 819
 820		// Send the messages to the summarize provider
 821		response := a.summarizeProvider.StreamResponse(
 822			summarizeCtx,
 823			msgsWithPrompt,
 824			nil,
 825		)
 826		var finalResponse *provider.ProviderResponse
 827		for r := range response {
 828			if r.Error != nil {
 829				event = AgentEvent{
 830					Type:  AgentEventTypeError,
 831					Error: fmt.Errorf("failed to summarize: %w", err),
 832					Done:  true,
 833				}
 834				a.Publish(pubsub.CreatedEvent, event)
 835				return
 836			}
 837			finalResponse = r.Response
 838		}
 839
 840		summary := strings.TrimSpace(finalResponse.Content)
 841		if summary == "" {
 842			event = AgentEvent{
 843				Type:  AgentEventTypeError,
 844				Error: fmt.Errorf("empty summary returned"),
 845				Done:  true,
 846			}
 847			a.Publish(pubsub.CreatedEvent, event)
 848			return
 849		}
 850		shell := shell.GetPersistentShell(config.Get().WorkingDir())
 851		summary += "\n\n**Current working directory of the persistent shell**\n\n" + shell.GetWorkingDir()
 852		event = AgentEvent{
 853			Type:     AgentEventTypeSummarize,
 854			Progress: "Creating new session...",
 855		}
 856
 857		a.Publish(pubsub.CreatedEvent, event)
 858		oldSession, err := a.sessions.Get(summarizeCtx, sessionID)
 859		if err != nil {
 860			event = AgentEvent{
 861				Type:  AgentEventTypeError,
 862				Error: fmt.Errorf("failed to get session: %w", err),
 863				Done:  true,
 864			}
 865
 866			a.Publish(pubsub.CreatedEvent, event)
 867			return
 868		}
 869		// Create a message in the new session with the summary
 870		msg, err := a.messages.Create(summarizeCtx, oldSession.ID, message.CreateMessageParams{
 871			Role: message.Assistant,
 872			Parts: []message.ContentPart{
 873				message.TextContent{Text: summary},
 874				message.Finish{
 875					Reason: message.FinishReasonEndTurn,
 876					Time:   time.Now().Unix(),
 877				},
 878			},
 879			Model:    a.summarizeProvider.Model().ID,
 880			Provider: a.summarizeProviderID,
 881		})
 882		if err != nil {
 883			event = AgentEvent{
 884				Type:  AgentEventTypeError,
 885				Error: fmt.Errorf("failed to create summary message: %w", err),
 886				Done:  true,
 887			}
 888
 889			a.Publish(pubsub.CreatedEvent, event)
 890			return
 891		}
 892		oldSession.SummaryMessageID = msg.ID
 893		oldSession.CompletionTokens = finalResponse.Usage.OutputTokens
 894		oldSession.PromptTokens = 0
 895		model := a.summarizeProvider.Model()
 896		usage := finalResponse.Usage
 897		cost := model.CostPer1MInCached/1e6*float64(usage.CacheCreationTokens) +
 898			model.CostPer1MOutCached/1e6*float64(usage.CacheReadTokens) +
 899			model.CostPer1MIn/1e6*float64(usage.InputTokens) +
 900			model.CostPer1MOut/1e6*float64(usage.OutputTokens)
 901		oldSession.Cost += cost
 902		_, err = a.sessions.Save(summarizeCtx, oldSession)
 903		if err != nil {
 904			event = AgentEvent{
 905				Type:  AgentEventTypeError,
 906				Error: fmt.Errorf("failed to save session: %w", err),
 907				Done:  true,
 908			}
 909			a.Publish(pubsub.CreatedEvent, event)
 910		}
 911
 912		event = AgentEvent{
 913			Type:      AgentEventTypeSummarize,
 914			SessionID: oldSession.ID,
 915			Progress:  "Summary complete",
 916			Done:      true,
 917		}
 918		a.Publish(pubsub.CreatedEvent, event)
 919		// Send final success event with the new session ID
 920	}()
 921
 922	return nil
 923}
 924
 925func (a *agent) ClearQueue(sessionID string) {
 926	if a.QueuedPrompts(sessionID) > 0 {
 927		slog.Info("Clearing queued prompts", "session_id", sessionID)
 928		a.promptQueue.Del(sessionID)
 929	}
 930}
 931
 932func (a *agent) CancelAll() {
 933	if !a.IsBusy() {
 934		return
 935	}
 936	for key := range a.activeRequests.Seq2() {
 937		a.Cancel(key) // key is sessionID
 938	}
 939
 940	timeout := time.After(5 * time.Second)
 941	for a.IsBusy() {
 942		select {
 943		case <-timeout:
 944			return
 945		default:
 946			time.Sleep(200 * time.Millisecond)
 947		}
 948	}
 949}
 950
 951func (a *agent) UpdateModel() error {
 952	cfg := config.Get()
 953
 954	// Get current provider configuration
 955	currentProviderCfg := cfg.GetProviderForModel(a.agentCfg.Model)
 956	if currentProviderCfg == nil || currentProviderCfg.ID == "" {
 957		return fmt.Errorf("provider for agent %s not found in config", a.agentCfg.Name)
 958	}
 959
 960	// Check if provider has changed
 961	if string(currentProviderCfg.ID) != a.providerID {
 962		// Provider changed, need to recreate the main provider
 963		model := cfg.GetModelByType(a.agentCfg.Model)
 964		if model.ID == "" {
 965			return fmt.Errorf("model not found for agent %s", a.agentCfg.Name)
 966		}
 967
 968		promptID := agentPromptMap[a.agentCfg.ID]
 969		if promptID == "" {
 970			promptID = prompt.PromptDefault
 971		}
 972
 973		opts := []provider.ProviderClientOption{
 974			provider.WithModel(a.agentCfg.Model),
 975			provider.WithSystemMessage(prompt.GetPrompt(promptID, currentProviderCfg.ID, cfg.Options.ContextPaths...)),
 976		}
 977
 978		newProvider, err := provider.NewProvider(*currentProviderCfg, opts...)
 979		if err != nil {
 980			return fmt.Errorf("failed to create new provider: %w", err)
 981		}
 982
 983		// Update the provider and provider ID
 984		a.provider = newProvider
 985		a.providerID = string(currentProviderCfg.ID)
 986	}
 987
 988	// Check if providers have changed for title (small) and summarize (large)
 989	smallModelCfg := cfg.Models[config.SelectedModelTypeSmall]
 990	var smallModelProviderCfg config.ProviderConfig
 991	for p := range cfg.Providers.Seq() {
 992		if p.ID == smallModelCfg.Provider {
 993			smallModelProviderCfg = p
 994			break
 995		}
 996	}
 997	if smallModelProviderCfg.ID == "" {
 998		return fmt.Errorf("provider %s not found in config", smallModelCfg.Provider)
 999	}
1000
1001	largeModelCfg := cfg.Models[config.SelectedModelTypeLarge]
1002	var largeModelProviderCfg config.ProviderConfig
1003	for p := range cfg.Providers.Seq() {
1004		if p.ID == largeModelCfg.Provider {
1005			largeModelProviderCfg = p
1006			break
1007		}
1008	}
1009	if largeModelProviderCfg.ID == "" {
1010		return fmt.Errorf("provider %s not found in config", largeModelCfg.Provider)
1011	}
1012
1013	var maxTitleTokens int64 = 40
1014
1015	// if the max output is too low for the gemini provider it won't return anything
1016	if smallModelCfg.Provider == "gemini" {
1017		maxTitleTokens = 1000
1018	}
1019	// Recreate title provider
1020	titleOpts := []provider.ProviderClientOption{
1021		provider.WithModel(config.SelectedModelTypeSmall),
1022		provider.WithSystemMessage(prompt.GetPrompt(prompt.PromptTitle, smallModelProviderCfg.ID)),
1023		provider.WithMaxTokens(maxTitleTokens),
1024	}
1025	newTitleProvider, err := provider.NewProvider(smallModelProviderCfg, titleOpts...)
1026	if err != nil {
1027		return fmt.Errorf("failed to create new title provider: %w", err)
1028	}
1029	a.titleProvider = newTitleProvider
1030
1031	// Recreate summarize provider if provider changed (now large model)
1032	if string(largeModelProviderCfg.ID) != a.summarizeProviderID {
1033		largeModel := cfg.GetModelByType(config.SelectedModelTypeLarge)
1034		if largeModel == nil {
1035			return fmt.Errorf("model %s not found in provider %s", largeModelCfg.Model, largeModelProviderCfg.ID)
1036		}
1037		summarizeOpts := []provider.ProviderClientOption{
1038			provider.WithModel(config.SelectedModelTypeLarge),
1039			provider.WithSystemMessage(prompt.GetPrompt(prompt.PromptSummarizer, largeModelProviderCfg.ID)),
1040		}
1041		newSummarizeProvider, err := provider.NewProvider(largeModelProviderCfg, summarizeOpts...)
1042		if err != nil {
1043			return fmt.Errorf("failed to create new summarize provider: %w", err)
1044		}
1045		a.summarizeProvider = newSummarizeProvider
1046		a.summarizeProviderID = string(largeModelProviderCfg.ID)
1047	}
1048
1049	return nil
1050}