agent.go

   1package agent
   2
   3import (
   4	"context"
   5	"errors"
   6	"fmt"
   7	"log/slog"
   8	"slices"
   9	"strings"
  10	"time"
  11
  12	"github.com/charmbracelet/catwalk/pkg/catwalk"
  13	"github.com/charmbracelet/crush/internal/config"
  14	"github.com/charmbracelet/crush/internal/csync"
  15	"github.com/charmbracelet/crush/internal/history"
  16	"github.com/charmbracelet/crush/internal/llm/prompt"
  17	"github.com/charmbracelet/crush/internal/llm/provider"
  18	"github.com/charmbracelet/crush/internal/llm/tools"
  19	"github.com/charmbracelet/crush/internal/log"
  20	"github.com/charmbracelet/crush/internal/lsp"
  21	"github.com/charmbracelet/crush/internal/message"
  22	"github.com/charmbracelet/crush/internal/permission"
  23	"github.com/charmbracelet/crush/internal/pubsub"
  24	"github.com/charmbracelet/crush/internal/session"
  25	"github.com/charmbracelet/crush/internal/shell"
  26)
  27
  28// Common errors
  29var (
  30	ErrRequestCancelled = errors.New("request canceled by user")
  31	ErrSessionBusy      = errors.New("session is currently processing another request")
  32)
  33
  34type AgentEventType string
  35
  36const (
  37	AgentEventTypeError     AgentEventType = "error"
  38	AgentEventTypeResponse  AgentEventType = "response"
  39	AgentEventTypeSummarize AgentEventType = "summarize"
  40)
  41
  42type AgentEvent struct {
  43	Type    AgentEventType
  44	Message message.Message
  45	Error   error
  46
  47	// When summarizing
  48	SessionID string
  49	Progress  string
  50	Done      bool
  51}
  52
  53type Service interface {
  54	pubsub.Suscriber[AgentEvent]
  55	Model() catwalk.Model
  56	Run(ctx context.Context, sessionID string, content string, attachments ...message.Attachment) (<-chan AgentEvent, error)
  57	Cancel(sessionID string)
  58	CancelAll()
  59	IsSessionBusy(sessionID string) bool
  60	IsBusy() bool
  61	Summarize(ctx context.Context, sessionID string) error
  62	UpdateModel() error
  63	QueuedPrompts(sessionID string) int
  64	ClearQueue(sessionID string)
  65}
  66
  67type agent struct {
  68	*pubsub.Broker[AgentEvent]
  69	agentCfg config.Agent
  70	sessions session.Service
  71	messages message.Service
  72	mcpTools []McpTool
  73
  74	tools *csync.LazySlice[tools.BaseTool]
  75	// We need this to be able to update it when model changes
  76	agentToolFn func() (tools.BaseTool, error)
  77
  78	provider   provider.Provider
  79	providerID string
  80
  81	titleProvider       provider.Provider
  82	summarizeProvider   provider.Provider
  83	summarizeProviderID string
  84
  85	activeRequests *csync.Map[string, context.CancelFunc]
  86
  87	promptQueue *csync.Map[string, []string]
  88}
  89
  90var agentPromptMap = map[string]prompt.PromptID{
  91	"coder": prompt.PromptCoder,
  92	"task":  prompt.PromptTask,
  93}
  94
  95func NewAgent(
  96	agentCfg config.Agent,
  97	// These services are needed in the tools
  98	permissions permission.Service,
  99	sessions session.Service,
 100	messages message.Service,
 101	history history.Service,
 102	lspClients map[string]*lsp.Client,
 103) (Service, error) {
 104	cfg := config.Get()
 105
 106	var agentToolFn func() (tools.BaseTool, error)
 107	if agentCfg.ID == "coder" {
 108		agentToolFn = func() (tools.BaseTool, error) {
 109			taskAgentCfg := config.Get().Agents["task"]
 110			if taskAgentCfg.ID == "" {
 111				return nil, fmt.Errorf("task agent not found in config")
 112			}
 113			taskAgent, err := NewAgent(taskAgentCfg, permissions, sessions, messages, history, lspClients)
 114			if err != nil {
 115				return nil, fmt.Errorf("failed to create task agent: %w", err)
 116			}
 117			return NewAgentTool(taskAgent, sessions, messages), nil
 118		}
 119	}
 120
 121	providerCfg := config.Get().GetProviderForModel(agentCfg.Model)
 122	if providerCfg == nil {
 123		return nil, fmt.Errorf("provider for agent %s not found in config", agentCfg.Name)
 124	}
 125	model := config.Get().GetModelByType(agentCfg.Model)
 126
 127	if model == nil {
 128		return nil, fmt.Errorf("model not found for agent %s", agentCfg.Name)
 129	}
 130
 131	promptID := agentPromptMap[agentCfg.ID]
 132	if promptID == "" {
 133		promptID = prompt.PromptDefault
 134	}
 135	opts := []provider.ProviderClientOption{
 136		provider.WithModel(agentCfg.Model),
 137		provider.WithSystemMessage(prompt.GetPrompt(promptID, providerCfg.ID, config.Get().Options.ContextPaths...)),
 138	}
 139	agentProvider, err := provider.NewProvider(*providerCfg, opts...)
 140	if err != nil {
 141		return nil, err
 142	}
 143
 144	smallModelCfg := cfg.Models[config.SelectedModelTypeSmall]
 145	var smallModelProviderCfg *config.ProviderConfig
 146	if smallModelCfg.Provider == providerCfg.ID {
 147		smallModelProviderCfg = providerCfg
 148	} else {
 149		smallModelProviderCfg = cfg.GetProviderForModel(config.SelectedModelTypeSmall)
 150
 151		if smallModelProviderCfg.ID == "" {
 152			return nil, fmt.Errorf("provider %s not found in config", smallModelCfg.Provider)
 153		}
 154	}
 155	smallModel := cfg.GetModelByType(config.SelectedModelTypeSmall)
 156	if smallModel.ID == "" {
 157		return nil, fmt.Errorf("model %s not found in provider %s", smallModelCfg.Model, smallModelProviderCfg.ID)
 158	}
 159
 160	titleOpts := []provider.ProviderClientOption{
 161		provider.WithModel(config.SelectedModelTypeSmall),
 162		provider.WithSystemMessage(prompt.GetPrompt(prompt.PromptTitle, smallModelProviderCfg.ID)),
 163	}
 164	titleProvider, err := provider.NewProvider(*smallModelProviderCfg, titleOpts...)
 165	if err != nil {
 166		return nil, err
 167	}
 168
 169	summarizeOpts := []provider.ProviderClientOption{
 170		provider.WithModel(config.SelectedModelTypeLarge),
 171		provider.WithSystemMessage(prompt.GetPrompt(prompt.PromptSummarizer, providerCfg.ID)),
 172	}
 173	summarizeProvider, err := provider.NewProvider(*providerCfg, summarizeOpts...)
 174	if err != nil {
 175		return nil, err
 176	}
 177
 178	toolFn := func() []tools.BaseTool {
 179		slog.Info("Initializing agent tools", "agent", agentCfg.ID)
 180		defer func() {
 181			slog.Info("Initialized agent tools", "agent", agentCfg.ID)
 182		}()
 183
 184		cwd := cfg.WorkingDir()
 185		allTools := []tools.BaseTool{
 186			tools.NewBashTool(permissions, cwd),
 187			tools.NewDownloadTool(permissions, cwd),
 188			tools.NewEditTool(lspClients, permissions, history, cwd),
 189			tools.NewMultiEditTool(lspClients, permissions, history, cwd),
 190			tools.NewFetchTool(permissions, cwd),
 191			tools.NewGlobTool(cwd),
 192			tools.NewGrepTool(cwd),
 193			tools.NewLsTool(permissions, cwd),
 194			tools.NewSourcegraphTool(),
 195			tools.NewViewTool(lspClients, permissions, cwd),
 196			tools.NewWriteTool(lspClients, permissions, history, cwd),
 197		}
 198
 199		mcpToolsOnce.Do(func() {
 200			mcpTools = doGetMCPTools(permissions, cfg)
 201		})
 202		allTools = append(allTools, mcpTools...)
 203
 204		if len(lspClients) > 0 {
 205			allTools = append(allTools, tools.NewDiagnosticsTool(lspClients))
 206		}
 207
 208		if agentCfg.AllowedTools == nil {
 209			return allTools
 210		}
 211
 212		var filteredTools []tools.BaseTool
 213		for _, tool := range allTools {
 214			if slices.Contains(agentCfg.AllowedTools, tool.Name()) {
 215				filteredTools = append(filteredTools, tool)
 216			}
 217		}
 218		return filteredTools
 219	}
 220
 221	return &agent{
 222		Broker:              pubsub.NewBroker[AgentEvent](),
 223		agentCfg:            agentCfg,
 224		provider:            agentProvider,
 225		providerID:          string(providerCfg.ID),
 226		messages:            messages,
 227		sessions:            sessions,
 228		titleProvider:       titleProvider,
 229		summarizeProvider:   summarizeProvider,
 230		summarizeProviderID: string(providerCfg.ID),
 231		agentToolFn:         agentToolFn,
 232		activeRequests:      csync.NewMap[string, context.CancelFunc](),
 233		tools:               csync.NewLazySlice(toolFn),
 234		promptQueue:         csync.NewMap[string, []string](),
 235	}, nil
 236}
 237
 238func (a *agent) Model() catwalk.Model {
 239	return *config.Get().GetModelByType(a.agentCfg.Model)
 240}
 241
 242func (a *agent) Cancel(sessionID string) {
 243	// Cancel regular requests
 244	if cancel, ok := a.activeRequests.Take(sessionID); ok && cancel != nil {
 245		slog.Info("Request cancellation initiated", "session_id", sessionID)
 246		cancel()
 247	}
 248
 249	// Also check for summarize requests
 250	if cancel, ok := a.activeRequests.Take(sessionID + "-summarize"); ok && cancel != nil {
 251		slog.Info("Summarize cancellation initiated", "session_id", sessionID)
 252		cancel()
 253	}
 254
 255	if a.QueuedPrompts(sessionID) > 0 {
 256		slog.Info("Clearing queued prompts", "session_id", sessionID)
 257		a.promptQueue.Del(sessionID)
 258	}
 259}
 260
 261func (a *agent) IsBusy() bool {
 262	var busy bool
 263	for cancelFunc := range a.activeRequests.Seq() {
 264		if cancelFunc != nil {
 265			busy = true
 266			break
 267		}
 268	}
 269	return busy
 270}
 271
 272func (a *agent) IsSessionBusy(sessionID string) bool {
 273	_, busy := a.activeRequests.Get(sessionID)
 274	return busy
 275}
 276
 277func (a *agent) QueuedPrompts(sessionID string) int {
 278	l, ok := a.promptQueue.Get(sessionID)
 279	if !ok {
 280		return 0
 281	}
 282	return len(l)
 283}
 284
 285func (a *agent) generateTitle(ctx context.Context, sessionID string, content string) error {
 286	if content == "" {
 287		return nil
 288	}
 289	if a.titleProvider == nil {
 290		return nil
 291	}
 292	session, err := a.sessions.Get(ctx, sessionID)
 293	if err != nil {
 294		return err
 295	}
 296	parts := []message.ContentPart{message.TextContent{
 297		Text: fmt.Sprintf("Generate a concise title for the following content:\n\n%s", content),
 298	}}
 299
 300	// Use streaming approach like summarization
 301	response := a.titleProvider.StreamResponse(
 302		ctx,
 303		[]message.Message{
 304			{
 305				Role:  message.User,
 306				Parts: parts,
 307			},
 308		},
 309		nil,
 310	)
 311
 312	var finalResponse *provider.ProviderResponse
 313	for r := range response {
 314		if r.Error != nil {
 315			return r.Error
 316		}
 317		finalResponse = r.Response
 318	}
 319
 320	if finalResponse == nil {
 321		return fmt.Errorf("no response received from title provider")
 322	}
 323
 324	title := strings.TrimSpace(strings.ReplaceAll(finalResponse.Content, "\n", " "))
 325	if title == "" {
 326		return nil
 327	}
 328
 329	session.Title = title
 330	_, err = a.sessions.Save(ctx, session)
 331	return err
 332}
 333
 334func (a *agent) err(err error) AgentEvent {
 335	return AgentEvent{
 336		Type:  AgentEventTypeError,
 337		Error: err,
 338	}
 339}
 340
 341func (a *agent) Run(ctx context.Context, sessionID string, content string, attachments ...message.Attachment) (<-chan AgentEvent, error) {
 342	if !a.Model().SupportsImages && attachments != nil {
 343		attachments = nil
 344	}
 345	events := make(chan AgentEvent, 1)
 346	if a.IsSessionBusy(sessionID) {
 347		existing, ok := a.promptQueue.Get(sessionID)
 348		if !ok {
 349			existing = []string{}
 350		}
 351		existing = append(existing, content)
 352		a.promptQueue.Set(sessionID, existing)
 353		return nil, nil
 354	}
 355
 356	genCtx, cancel := context.WithCancel(ctx)
 357
 358	a.activeRequests.Set(sessionID, cancel)
 359	go func() {
 360		slog.Debug("Request started", "sessionID", sessionID)
 361		defer log.RecoverPanic("agent.Run", func() {
 362			events <- a.err(fmt.Errorf("panic while running the agent"))
 363		})
 364		var attachmentParts []message.ContentPart
 365		for _, attachment := range attachments {
 366			attachmentParts = append(attachmentParts, message.BinaryContent{Path: attachment.FilePath, MIMEType: attachment.MimeType, Data: attachment.Content})
 367		}
 368		result := a.processGeneration(genCtx, sessionID, content, attachmentParts)
 369		if result.Error != nil && !errors.Is(result.Error, ErrRequestCancelled) && !errors.Is(result.Error, context.Canceled) {
 370			slog.Error(result.Error.Error())
 371		}
 372		slog.Debug("Request completed", "sessionID", sessionID)
 373		a.activeRequests.Del(sessionID)
 374		cancel()
 375		a.Publish(pubsub.CreatedEvent, result)
 376		events <- result
 377		close(events)
 378	}()
 379	return events, nil
 380}
 381
 382func (a *agent) processGeneration(ctx context.Context, sessionID, content string, attachmentParts []message.ContentPart) AgentEvent {
 383	cfg := config.Get()
 384	// List existing messages; if none, start title generation asynchronously.
 385	msgs, err := a.messages.List(ctx, sessionID)
 386	if err != nil {
 387		return a.err(fmt.Errorf("failed to list messages: %w", err))
 388	}
 389	if len(msgs) == 0 {
 390		go func() {
 391			defer log.RecoverPanic("agent.Run", func() {
 392				slog.Error("panic while generating title")
 393			})
 394			titleErr := a.generateTitle(ctx, sessionID, content)
 395			if titleErr != nil && !errors.Is(titleErr, context.Canceled) && !errors.Is(titleErr, context.DeadlineExceeded) {
 396				slog.Error("failed to generate title", "error", titleErr)
 397			}
 398		}()
 399	}
 400	session, err := a.sessions.Get(ctx, sessionID)
 401	if err != nil {
 402		return a.err(fmt.Errorf("failed to get session: %w", err))
 403	}
 404	if session.SummaryMessageID != "" {
 405		summaryMsgInex := -1
 406		for i, msg := range msgs {
 407			if msg.ID == session.SummaryMessageID {
 408				summaryMsgInex = i
 409				break
 410			}
 411		}
 412		if summaryMsgInex != -1 {
 413			msgs = msgs[summaryMsgInex:]
 414			msgs[0].Role = message.User
 415		}
 416	}
 417
 418	userMsg, err := a.createUserMessage(ctx, sessionID, content, attachmentParts)
 419	if err != nil {
 420		return a.err(fmt.Errorf("failed to create user message: %w", err))
 421	}
 422	// Append the new user message to the conversation history.
 423	msgHistory := append(msgs, userMsg)
 424
 425	for {
 426		// Check for cancellation before each iteration
 427		select {
 428		case <-ctx.Done():
 429			return a.err(ctx.Err())
 430		default:
 431			// Continue processing
 432		}
 433		agentMessage, toolResults, err := a.streamAndHandleEvents(ctx, sessionID, msgHistory)
 434		if err != nil {
 435			if errors.Is(err, context.Canceled) {
 436				agentMessage.AddFinish(message.FinishReasonCanceled, "Request cancelled", "")
 437				a.messages.Update(context.Background(), agentMessage)
 438				return a.err(ErrRequestCancelled)
 439			}
 440			return a.err(fmt.Errorf("failed to process events: %w", err))
 441		}
 442		if cfg.Options.Debug {
 443			slog.Info("Result", "message", agentMessage.FinishReason(), "toolResults", toolResults)
 444		}
 445		if (agentMessage.FinishReason() == message.FinishReasonToolUse) && toolResults != nil {
 446			// We are not done, we need to respond with the tool response
 447			msgHistory = append(msgHistory, agentMessage, *toolResults)
 448			// If there are queued prompts, process the next one
 449			nextPrompt, ok := a.promptQueue.Take(sessionID)
 450			if ok {
 451				for _, prompt := range nextPrompt {
 452					// Create a new user message for the queued prompt
 453					userMsg, err := a.createUserMessage(ctx, sessionID, prompt, nil)
 454					if err != nil {
 455						return a.err(fmt.Errorf("failed to create user message for queued prompt: %w", err))
 456					}
 457					// Append the new user message to the conversation history
 458					msgHistory = append(msgHistory, userMsg)
 459				}
 460			}
 461
 462			continue
 463		} else if agentMessage.FinishReason() == message.FinishReasonEndTurn {
 464			queuePrompts, ok := a.promptQueue.Take(sessionID)
 465			if ok {
 466				for _, prompt := range queuePrompts {
 467					if prompt == "" {
 468						continue
 469					}
 470					userMsg, err := a.createUserMessage(ctx, sessionID, prompt, nil)
 471					if err != nil {
 472						return a.err(fmt.Errorf("failed to create user message for queued prompt: %w", err))
 473					}
 474					msgHistory = append(msgHistory, userMsg)
 475				}
 476				continue
 477			}
 478		}
 479		if agentMessage.FinishReason() == "" {
 480			// Kujtim: could not track down where this is happening but this means its cancelled
 481			agentMessage.AddFinish(message.FinishReasonCanceled, "Request cancelled", "")
 482			_ = a.messages.Update(context.Background(), agentMessage)
 483			return a.err(ErrRequestCancelled)
 484		}
 485		return AgentEvent{
 486			Type:    AgentEventTypeResponse,
 487			Message: agentMessage,
 488			Done:    true,
 489		}
 490	}
 491}
 492
 493func (a *agent) createUserMessage(ctx context.Context, sessionID, content string, attachmentParts []message.ContentPart) (message.Message, error) {
 494	parts := []message.ContentPart{message.TextContent{Text: content}}
 495	parts = append(parts, attachmentParts...)
 496	return a.messages.Create(ctx, sessionID, message.CreateMessageParams{
 497		Role:  message.User,
 498		Parts: parts,
 499	})
 500}
 501
 502func (a *agent) getAllTools() ([]tools.BaseTool, error) {
 503	allTools := slices.Collect(a.tools.Seq())
 504	if a.agentToolFn != nil {
 505		agentTool, agentToolErr := a.agentToolFn()
 506		if agentToolErr != nil {
 507			return nil, agentToolErr
 508		}
 509		allTools = append(allTools, agentTool)
 510	}
 511	return allTools, nil
 512}
 513
 514func (a *agent) streamAndHandleEvents(ctx context.Context, sessionID string, msgHistory []message.Message) (message.Message, *message.Message, error) {
 515	ctx = context.WithValue(ctx, tools.SessionIDContextKey, sessionID)
 516
 517	// Create the assistant message first so the spinner shows immediately
 518	assistantMsg, err := a.messages.Create(ctx, sessionID, message.CreateMessageParams{
 519		Role:     message.Assistant,
 520		Parts:    []message.ContentPart{},
 521		Model:    a.Model().ID,
 522		Provider: a.providerID,
 523	})
 524	if err != nil {
 525		return assistantMsg, nil, fmt.Errorf("failed to create assistant message: %w", err)
 526	}
 527
 528	allTools, toolsErr := a.getAllTools()
 529	if toolsErr != nil {
 530		return assistantMsg, nil, toolsErr
 531	}
 532	// Now collect tools (which may block on MCP initialization)
 533	eventChan := a.provider.StreamResponse(ctx, msgHistory, allTools)
 534
 535	// Add the session and message ID into the context if needed by tools.
 536	ctx = context.WithValue(ctx, tools.MessageIDContextKey, assistantMsg.ID)
 537
 538	// Process each event in the stream.
 539	for event := range eventChan {
 540		if processErr := a.processEvent(ctx, sessionID, &assistantMsg, event); processErr != nil {
 541			if errors.Is(processErr, context.Canceled) {
 542				a.finishMessage(context.Background(), &assistantMsg, message.FinishReasonCanceled, "Request cancelled", "")
 543			} else {
 544				a.finishMessage(ctx, &assistantMsg, message.FinishReasonError, "API Error", processErr.Error())
 545			}
 546			return assistantMsg, nil, processErr
 547		}
 548		if ctx.Err() != nil {
 549			a.finishMessage(context.Background(), &assistantMsg, message.FinishReasonCanceled, "Request cancelled", "")
 550			return assistantMsg, nil, ctx.Err()
 551		}
 552	}
 553
 554	toolResults := make([]message.ToolResult, len(assistantMsg.ToolCalls()))
 555	toolCalls := assistantMsg.ToolCalls()
 556	for i, toolCall := range toolCalls {
 557		select {
 558		case <-ctx.Done():
 559			a.finishMessage(context.Background(), &assistantMsg, message.FinishReasonCanceled, "Request cancelled", "")
 560			// Make all future tool calls cancelled
 561			for j := i; j < len(toolCalls); j++ {
 562				toolResults[j] = message.ToolResult{
 563					ToolCallID: toolCalls[j].ID,
 564					Content:    "Tool execution canceled by user",
 565					IsError:    true,
 566				}
 567			}
 568			goto out
 569		default:
 570			// Continue processing
 571			var tool tools.BaseTool
 572			allTools, _ := a.getAllTools()
 573			for _, availableTool := range allTools {
 574				if availableTool.Info().Name == toolCall.Name {
 575					tool = availableTool
 576					break
 577				}
 578			}
 579
 580			// Tool not found
 581			if tool == nil {
 582				toolResults[i] = message.ToolResult{
 583					ToolCallID: toolCall.ID,
 584					Content:    fmt.Sprintf("Tool not found: %s", toolCall.Name),
 585					IsError:    true,
 586				}
 587				continue
 588			}
 589
 590			// Run tool in goroutine to allow cancellation
 591			type toolExecResult struct {
 592				response tools.ToolResponse
 593				err      error
 594			}
 595			resultChan := make(chan toolExecResult, 1)
 596
 597			go func() {
 598				response, err := tool.Run(ctx, tools.ToolCall{
 599					ID:    toolCall.ID,
 600					Name:  toolCall.Name,
 601					Input: toolCall.Input,
 602				})
 603				resultChan <- toolExecResult{response: response, err: err}
 604			}()
 605
 606			var toolResponse tools.ToolResponse
 607			var toolErr error
 608
 609			select {
 610			case <-ctx.Done():
 611				a.finishMessage(context.Background(), &assistantMsg, message.FinishReasonCanceled, "Request cancelled", "")
 612				// Mark remaining tool calls as cancelled
 613				for j := i; j < len(toolCalls); j++ {
 614					toolResults[j] = message.ToolResult{
 615						ToolCallID: toolCalls[j].ID,
 616						Content:    "Tool execution canceled by user",
 617						IsError:    true,
 618					}
 619				}
 620				goto out
 621			case result := <-resultChan:
 622				toolResponse = result.response
 623				toolErr = result.err
 624			}
 625
 626			if toolErr != nil {
 627				slog.Error("Tool execution error", "toolCall", toolCall.ID, "error", toolErr)
 628				if errors.Is(toolErr, permission.ErrorPermissionDenied) {
 629					toolResults[i] = message.ToolResult{
 630						ToolCallID: toolCall.ID,
 631						Content:    "Permission denied",
 632						IsError:    true,
 633					}
 634					for j := i + 1; j < len(toolCalls); j++ {
 635						toolResults[j] = message.ToolResult{
 636							ToolCallID: toolCalls[j].ID,
 637							Content:    "Tool execution canceled by user",
 638							IsError:    true,
 639						}
 640					}
 641					a.finishMessage(ctx, &assistantMsg, message.FinishReasonPermissionDenied, "Permission denied", "")
 642					break
 643				}
 644			}
 645			toolResults[i] = message.ToolResult{
 646				ToolCallID: toolCall.ID,
 647				Content:    toolResponse.Content,
 648				Metadata:   toolResponse.Metadata,
 649				IsError:    toolResponse.IsError,
 650			}
 651		}
 652	}
 653out:
 654	if len(toolResults) == 0 {
 655		return assistantMsg, nil, nil
 656	}
 657	parts := make([]message.ContentPart, 0)
 658	for _, tr := range toolResults {
 659		parts = append(parts, tr)
 660	}
 661	msg, err := a.messages.Create(context.Background(), assistantMsg.SessionID, message.CreateMessageParams{
 662		Role:     message.Tool,
 663		Parts:    parts,
 664		Provider: a.providerID,
 665	})
 666	if err != nil {
 667		return assistantMsg, nil, fmt.Errorf("failed to create cancelled tool message: %w", err)
 668	}
 669
 670	return assistantMsg, &msg, err
 671}
 672
 673func (a *agent) finishMessage(ctx context.Context, msg *message.Message, finishReason message.FinishReason, message, details string) {
 674	msg.AddFinish(finishReason, message, details)
 675	_ = a.messages.Update(ctx, *msg)
 676}
 677
 678func (a *agent) processEvent(ctx context.Context, sessionID string, assistantMsg *message.Message, event provider.ProviderEvent) error {
 679	select {
 680	case <-ctx.Done():
 681		return ctx.Err()
 682	default:
 683		// Continue processing.
 684	}
 685
 686	switch event.Type {
 687	case provider.EventThinkingDelta:
 688		assistantMsg.AppendReasoningContent(event.Thinking)
 689		return a.messages.Update(ctx, *assistantMsg)
 690	case provider.EventSignatureDelta:
 691		assistantMsg.AppendReasoningSignature(event.Signature)
 692		return a.messages.Update(ctx, *assistantMsg)
 693	case provider.EventContentDelta:
 694		assistantMsg.FinishThinking()
 695		assistantMsg.AppendContent(event.Content)
 696		return a.messages.Update(ctx, *assistantMsg)
 697	case provider.EventToolUseStart:
 698		assistantMsg.FinishThinking()
 699		slog.Info("Tool call started", "toolCall", event.ToolCall)
 700		assistantMsg.AddToolCall(*event.ToolCall)
 701		return a.messages.Update(ctx, *assistantMsg)
 702	case provider.EventToolUseDelta:
 703		assistantMsg.AppendToolCallInput(event.ToolCall.ID, event.ToolCall.Input)
 704		return a.messages.Update(ctx, *assistantMsg)
 705	case provider.EventToolUseStop:
 706		slog.Info("Finished tool call", "toolCall", event.ToolCall)
 707		assistantMsg.FinishToolCall(event.ToolCall.ID)
 708		return a.messages.Update(ctx, *assistantMsg)
 709	case provider.EventError:
 710		return event.Error
 711	case provider.EventComplete:
 712		assistantMsg.FinishThinking()
 713		assistantMsg.SetToolCalls(event.Response.ToolCalls)
 714		assistantMsg.AddFinish(event.Response.FinishReason, "", "")
 715		if err := a.messages.Update(ctx, *assistantMsg); err != nil {
 716			return fmt.Errorf("failed to update message: %w", err)
 717		}
 718		return a.TrackUsage(ctx, sessionID, a.Model(), event.Response.Usage)
 719	}
 720
 721	return nil
 722}
 723
 724func (a *agent) TrackUsage(ctx context.Context, sessionID string, model catwalk.Model, usage provider.TokenUsage) error {
 725	sess, err := a.sessions.Get(ctx, sessionID)
 726	if err != nil {
 727		return fmt.Errorf("failed to get session: %w", err)
 728	}
 729
 730	cost := model.CostPer1MInCached/1e6*float64(usage.CacheCreationTokens) +
 731		model.CostPer1MOutCached/1e6*float64(usage.CacheReadTokens) +
 732		model.CostPer1MIn/1e6*float64(usage.InputTokens) +
 733		model.CostPer1MOut/1e6*float64(usage.OutputTokens)
 734
 735	sess.Cost += cost
 736	sess.CompletionTokens = usage.OutputTokens + usage.CacheReadTokens
 737	sess.PromptTokens = usage.InputTokens + usage.CacheCreationTokens
 738
 739	_, err = a.sessions.Save(ctx, sess)
 740	if err != nil {
 741		return fmt.Errorf("failed to save session: %w", err)
 742	}
 743	return nil
 744}
 745
 746func (a *agent) Summarize(ctx context.Context, sessionID string) error {
 747	if a.summarizeProvider == nil {
 748		return fmt.Errorf("summarize provider not available")
 749	}
 750
 751	// Check if session is busy
 752	if a.IsSessionBusy(sessionID) {
 753		return ErrSessionBusy
 754	}
 755
 756	// Create a new context with cancellation
 757	summarizeCtx, cancel := context.WithCancel(ctx)
 758
 759	// Store the cancel function in activeRequests to allow cancellation
 760	a.activeRequests.Set(sessionID+"-summarize", cancel)
 761
 762	go func() {
 763		defer a.activeRequests.Del(sessionID + "-summarize")
 764		defer cancel()
 765		event := AgentEvent{
 766			Type:     AgentEventTypeSummarize,
 767			Progress: "Starting summarization...",
 768		}
 769
 770		a.Publish(pubsub.CreatedEvent, event)
 771		// Get all messages from the session
 772		msgs, err := a.messages.List(summarizeCtx, sessionID)
 773		if err != nil {
 774			event = AgentEvent{
 775				Type:  AgentEventTypeError,
 776				Error: fmt.Errorf("failed to list messages: %w", err),
 777				Done:  true,
 778			}
 779			a.Publish(pubsub.CreatedEvent, event)
 780			return
 781		}
 782		summarizeCtx = context.WithValue(summarizeCtx, tools.SessionIDContextKey, sessionID)
 783
 784		if len(msgs) == 0 {
 785			event = AgentEvent{
 786				Type:  AgentEventTypeError,
 787				Error: fmt.Errorf("no messages to summarize"),
 788				Done:  true,
 789			}
 790			a.Publish(pubsub.CreatedEvent, event)
 791			return
 792		}
 793
 794		event = AgentEvent{
 795			Type:     AgentEventTypeSummarize,
 796			Progress: "Analyzing conversation...",
 797		}
 798		a.Publish(pubsub.CreatedEvent, event)
 799
 800		// Add a system message to guide the summarization
 801		summarizePrompt := "Provide a detailed but concise summary of our conversation above. Focus on information that would be helpful for continuing the conversation, including what we did, what we're doing, which files we're working on, and what we're going to do next."
 802
 803		// Create a new message with the summarize prompt
 804		promptMsg := message.Message{
 805			Role:  message.User,
 806			Parts: []message.ContentPart{message.TextContent{Text: summarizePrompt}},
 807		}
 808
 809		// Append the prompt to the messages
 810		msgsWithPrompt := append(msgs, promptMsg)
 811
 812		event = AgentEvent{
 813			Type:     AgentEventTypeSummarize,
 814			Progress: "Generating summary...",
 815		}
 816
 817		a.Publish(pubsub.CreatedEvent, event)
 818
 819		// Send the messages to the summarize provider
 820		response := a.summarizeProvider.StreamResponse(
 821			summarizeCtx,
 822			msgsWithPrompt,
 823			nil,
 824		)
 825		var finalResponse *provider.ProviderResponse
 826		for r := range response {
 827			if r.Error != nil {
 828				event = AgentEvent{
 829					Type:  AgentEventTypeError,
 830					Error: fmt.Errorf("failed to summarize: %w", err),
 831					Done:  true,
 832				}
 833				a.Publish(pubsub.CreatedEvent, event)
 834				return
 835			}
 836			finalResponse = r.Response
 837		}
 838
 839		summary := strings.TrimSpace(finalResponse.Content)
 840		if summary == "" {
 841			event = AgentEvent{
 842				Type:  AgentEventTypeError,
 843				Error: fmt.Errorf("empty summary returned"),
 844				Done:  true,
 845			}
 846			a.Publish(pubsub.CreatedEvent, event)
 847			return
 848		}
 849		shell := shell.GetPersistentShell(config.Get().WorkingDir())
 850		summary += "\n\n**Current working directory of the persistent shell**\n\n" + shell.GetWorkingDir()
 851		event = AgentEvent{
 852			Type:     AgentEventTypeSummarize,
 853			Progress: "Creating new session...",
 854		}
 855
 856		a.Publish(pubsub.CreatedEvent, event)
 857		oldSession, err := a.sessions.Get(summarizeCtx, sessionID)
 858		if err != nil {
 859			event = AgentEvent{
 860				Type:  AgentEventTypeError,
 861				Error: fmt.Errorf("failed to get session: %w", err),
 862				Done:  true,
 863			}
 864
 865			a.Publish(pubsub.CreatedEvent, event)
 866			return
 867		}
 868		// Create a message in the new session with the summary
 869		msg, err := a.messages.Create(summarizeCtx, oldSession.ID, message.CreateMessageParams{
 870			Role: message.Assistant,
 871			Parts: []message.ContentPart{
 872				message.TextContent{Text: summary},
 873				message.Finish{
 874					Reason: message.FinishReasonEndTurn,
 875					Time:   time.Now().Unix(),
 876				},
 877			},
 878			Model:    a.summarizeProvider.Model().ID,
 879			Provider: a.summarizeProviderID,
 880		})
 881		if err != nil {
 882			event = AgentEvent{
 883				Type:  AgentEventTypeError,
 884				Error: fmt.Errorf("failed to create summary message: %w", err),
 885				Done:  true,
 886			}
 887
 888			a.Publish(pubsub.CreatedEvent, event)
 889			return
 890		}
 891		oldSession.SummaryMessageID = msg.ID
 892		oldSession.CompletionTokens = finalResponse.Usage.OutputTokens
 893		oldSession.PromptTokens = 0
 894		model := a.summarizeProvider.Model()
 895		usage := finalResponse.Usage
 896		cost := model.CostPer1MInCached/1e6*float64(usage.CacheCreationTokens) +
 897			model.CostPer1MOutCached/1e6*float64(usage.CacheReadTokens) +
 898			model.CostPer1MIn/1e6*float64(usage.InputTokens) +
 899			model.CostPer1MOut/1e6*float64(usage.OutputTokens)
 900		oldSession.Cost += cost
 901		_, err = a.sessions.Save(summarizeCtx, oldSession)
 902		if err != nil {
 903			event = AgentEvent{
 904				Type:  AgentEventTypeError,
 905				Error: fmt.Errorf("failed to save session: %w", err),
 906				Done:  true,
 907			}
 908			a.Publish(pubsub.CreatedEvent, event)
 909		}
 910
 911		event = AgentEvent{
 912			Type:      AgentEventTypeSummarize,
 913			SessionID: oldSession.ID,
 914			Progress:  "Summary complete",
 915			Done:      true,
 916		}
 917		a.Publish(pubsub.CreatedEvent, event)
 918		// Send final success event with the new session ID
 919	}()
 920
 921	return nil
 922}
 923
 924func (a *agent) ClearQueue(sessionID string) {
 925	if a.QueuedPrompts(sessionID) > 0 {
 926		slog.Info("Clearing queued prompts", "session_id", sessionID)
 927		a.promptQueue.Del(sessionID)
 928	}
 929}
 930
 931func (a *agent) CancelAll() {
 932	if !a.IsBusy() {
 933		return
 934	}
 935	for key := range a.activeRequests.Seq2() {
 936		a.Cancel(key) // key is sessionID
 937	}
 938
 939	timeout := time.After(5 * time.Second)
 940	for a.IsBusy() {
 941		select {
 942		case <-timeout:
 943			return
 944		default:
 945			time.Sleep(200 * time.Millisecond)
 946		}
 947	}
 948}
 949
 950func (a *agent) UpdateModel() error {
 951	cfg := config.Get()
 952
 953	// Get current provider configuration
 954	currentProviderCfg := cfg.GetProviderForModel(a.agentCfg.Model)
 955	if currentProviderCfg == nil || currentProviderCfg.ID == "" {
 956		return fmt.Errorf("provider for agent %s not found in config", a.agentCfg.Name)
 957	}
 958
 959	// Check if provider has changed
 960	if string(currentProviderCfg.ID) != a.providerID {
 961		// Provider changed, need to recreate the main provider
 962		model := cfg.GetModelByType(a.agentCfg.Model)
 963		if model.ID == "" {
 964			return fmt.Errorf("model not found for agent %s", a.agentCfg.Name)
 965		}
 966
 967		promptID := agentPromptMap[a.agentCfg.ID]
 968		if promptID == "" {
 969			promptID = prompt.PromptDefault
 970		}
 971
 972		opts := []provider.ProviderClientOption{
 973			provider.WithModel(a.agentCfg.Model),
 974			provider.WithSystemMessage(prompt.GetPrompt(promptID, currentProviderCfg.ID, cfg.Options.ContextPaths...)),
 975		}
 976
 977		newProvider, err := provider.NewProvider(*currentProviderCfg, opts...)
 978		if err != nil {
 979			return fmt.Errorf("failed to create new provider: %w", err)
 980		}
 981
 982		// Update the provider and provider ID
 983		a.provider = newProvider
 984		a.providerID = string(currentProviderCfg.ID)
 985	}
 986
 987	// Check if providers have changed for title (small) and summarize (large)
 988	smallModelCfg := cfg.Models[config.SelectedModelTypeSmall]
 989	var smallModelProviderCfg config.ProviderConfig
 990	for p := range cfg.Providers.Seq() {
 991		if p.ID == smallModelCfg.Provider {
 992			smallModelProviderCfg = p
 993			break
 994		}
 995	}
 996	if smallModelProviderCfg.ID == "" {
 997		return fmt.Errorf("provider %s not found in config", smallModelCfg.Provider)
 998	}
 999
1000	largeModelCfg := cfg.Models[config.SelectedModelTypeLarge]
1001	var largeModelProviderCfg config.ProviderConfig
1002	for p := range cfg.Providers.Seq() {
1003		if p.ID == largeModelCfg.Provider {
1004			largeModelProviderCfg = p
1005			break
1006		}
1007	}
1008	if largeModelProviderCfg.ID == "" {
1009		return fmt.Errorf("provider %s not found in config", largeModelCfg.Provider)
1010	}
1011
1012	var maxTitleTokens int64 = 40
1013
1014	// if the max output is too low for the gemini provider it won't return anything
1015	if smallModelCfg.Provider == "gemini" {
1016		maxTitleTokens = 1000
1017	}
1018	// Recreate title provider
1019	titleOpts := []provider.ProviderClientOption{
1020		provider.WithModel(config.SelectedModelTypeSmall),
1021		provider.WithSystemMessage(prompt.GetPrompt(prompt.PromptTitle, smallModelProviderCfg.ID)),
1022		provider.WithMaxTokens(maxTitleTokens),
1023	}
1024	newTitleProvider, err := provider.NewProvider(smallModelProviderCfg, titleOpts...)
1025	if err != nil {
1026		return fmt.Errorf("failed to create new title provider: %w", err)
1027	}
1028	a.titleProvider = newTitleProvider
1029
1030	// Recreate summarize provider if provider changed (now large model)
1031	if string(largeModelProviderCfg.ID) != a.summarizeProviderID {
1032		largeModel := cfg.GetModelByType(config.SelectedModelTypeLarge)
1033		if largeModel == nil {
1034			return fmt.Errorf("model %s not found in provider %s", largeModelCfg.Model, largeModelProviderCfg.ID)
1035		}
1036		summarizeOpts := []provider.ProviderClientOption{
1037			provider.WithModel(config.SelectedModelTypeLarge),
1038			provider.WithSystemMessage(prompt.GetPrompt(prompt.PromptSummarizer, largeModelProviderCfg.ID)),
1039		}
1040		newSummarizeProvider, err := provider.NewProvider(largeModelProviderCfg, summarizeOpts...)
1041		if err != nil {
1042			return fmt.Errorf("failed to create new summarize provider: %w", err)
1043		}
1044		a.summarizeProvider = newSummarizeProvider
1045		a.summarizeProviderID = string(largeModelProviderCfg.ID)
1046	}
1047
1048	return nil
1049}