agent.go

  1package agent
  2
  3import (
  4	"context"
  5	"errors"
  6	"fmt"
  7	"log/slog"
  8	"slices"
  9	"strings"
 10	"time"
 11
 12	"github.com/charmbracelet/catwalk/pkg/catwalk"
 13	"github.com/charmbracelet/crush/internal/config"
 14	"github.com/charmbracelet/crush/internal/csync"
 15	"github.com/charmbracelet/crush/internal/history"
 16	"github.com/charmbracelet/crush/internal/llm/prompt"
 17	"github.com/charmbracelet/crush/internal/llm/provider"
 18	"github.com/charmbracelet/crush/internal/llm/tools"
 19	"github.com/charmbracelet/crush/internal/log"
 20	"github.com/charmbracelet/crush/internal/lsp"
 21	"github.com/charmbracelet/crush/internal/message"
 22	"github.com/charmbracelet/crush/internal/permission"
 23	"github.com/charmbracelet/crush/internal/pubsub"
 24	"github.com/charmbracelet/crush/internal/session"
 25	"github.com/charmbracelet/crush/internal/shell"
 26)
 27
 28// Common errors
 29var (
 30	ErrRequestCancelled = errors.New("request canceled by user")
 31	ErrSessionBusy      = errors.New("session is currently processing another request")
 32)
 33
 34type AgentEventType string
 35
 36const (
 37	AgentEventTypeError     AgentEventType = "error"
 38	AgentEventTypeResponse  AgentEventType = "response"
 39	AgentEventTypeSummarize AgentEventType = "summarize"
 40)
 41
 42type AgentEvent struct {
 43	Type    AgentEventType
 44	Message message.Message
 45	Error   error
 46
 47	// When summarizing
 48	SessionID string
 49	Progress  string
 50	Done      bool
 51}
 52
 53type Service interface {
 54	pubsub.Suscriber[AgentEvent]
 55	Model() catwalk.Model
 56	Run(ctx context.Context, sessionID string, content string, attachments ...message.Attachment) (<-chan AgentEvent, error)
 57	Cancel(sessionID string)
 58	CancelAll()
 59	IsSessionBusy(sessionID string) bool
 60	IsBusy() bool
 61	Summarize(ctx context.Context, sessionID string) error
 62	UpdateModel() error
 63}
 64
 65type agent struct {
 66	*pubsub.Broker[AgentEvent]
 67	agentCfg config.Agent
 68	sessions session.Service
 69	messages message.Service
 70	mcpTools []McpTool
 71
 72	tools *csync.LazySlice[tools.BaseTool]
 73
 74	provider   provider.Provider
 75	providerID string
 76
 77	titleProvider       provider.Provider
 78	summarizeProvider   provider.Provider
 79	summarizeProviderID string
 80
 81	activeRequests *csync.Map[string, context.CancelFunc]
 82}
 83
 84var agentPromptMap = map[string]prompt.PromptID{
 85	"coder": prompt.PromptCoder,
 86	"task":  prompt.PromptTask,
 87}
 88
 89func NewAgent(
 90	ctx context.Context,
 91	agentCfg config.Agent,
 92	// These services are needed in the tools
 93	permissions permission.Service,
 94	sessions session.Service,
 95	messages message.Service,
 96	history history.Service,
 97	lspClients map[string]*lsp.Client,
 98) (Service, error) {
 99	cfg := config.Get()
100
101	var agentTool tools.BaseTool
102	if agentCfg.ID == "coder" {
103		taskAgentCfg := config.Get().Agents["task"]
104		if taskAgentCfg.ID == "" {
105			return nil, fmt.Errorf("task agent not found in config")
106		}
107		taskAgent, err := NewAgent(ctx, taskAgentCfg, permissions, sessions, messages, history, lspClients)
108		if err != nil {
109			return nil, fmt.Errorf("failed to create task agent: %w", err)
110		}
111
112		agentTool = NewAgentTool(taskAgent, sessions, messages)
113	}
114
115	providerCfg := config.Get().GetProviderForModel(agentCfg.Model)
116	if providerCfg == nil {
117		return nil, fmt.Errorf("provider for agent %s not found in config", agentCfg.Name)
118	}
119	model := config.Get().GetModelByType(agentCfg.Model)
120
121	if model == nil {
122		return nil, fmt.Errorf("model not found for agent %s", agentCfg.Name)
123	}
124
125	promptID := agentPromptMap[agentCfg.ID]
126	if promptID == "" {
127		promptID = prompt.PromptDefault
128	}
129	opts := []provider.ProviderClientOption{
130		provider.WithModel(agentCfg.Model),
131		provider.WithSystemMessage(prompt.GetPrompt(promptID, providerCfg.ID, config.Get().Options.ContextPaths...)),
132	}
133	agentProvider, err := provider.NewProvider(*providerCfg, opts...)
134	if err != nil {
135		return nil, err
136	}
137
138	smallModelCfg := cfg.Models[config.SelectedModelTypeSmall]
139	var smallModelProviderCfg *config.ProviderConfig
140	if smallModelCfg.Provider == providerCfg.ID {
141		smallModelProviderCfg = providerCfg
142	} else {
143		smallModelProviderCfg = cfg.GetProviderForModel(config.SelectedModelTypeSmall)
144
145		if smallModelProviderCfg.ID == "" {
146			return nil, fmt.Errorf("provider %s not found in config", smallModelCfg.Provider)
147		}
148	}
149	smallModel := cfg.GetModelByType(config.SelectedModelTypeSmall)
150	if smallModel.ID == "" {
151		return nil, fmt.Errorf("model %s not found in provider %s", smallModelCfg.Model, smallModelProviderCfg.ID)
152	}
153
154	titleOpts := []provider.ProviderClientOption{
155		provider.WithModel(config.SelectedModelTypeSmall),
156		provider.WithSystemMessage(prompt.GetPrompt(prompt.PromptTitle, smallModelProviderCfg.ID)),
157	}
158	titleProvider, err := provider.NewProvider(*smallModelProviderCfg, titleOpts...)
159	if err != nil {
160		return nil, err
161	}
162
163	summarizeOpts := []provider.ProviderClientOption{
164		provider.WithModel(config.SelectedModelTypeLarge),
165		provider.WithSystemMessage(prompt.GetPrompt(prompt.PromptSummarizer, providerCfg.ID)),
166	}
167	summarizeProvider, err := provider.NewProvider(*providerCfg, summarizeOpts...)
168	if err != nil {
169		return nil, err
170	}
171
172	toolFn := func() []tools.BaseTool {
173		slog.Info("Initializing agent tools", "agent", agentCfg.ID)
174		defer func() {
175			slog.Info("Initialized agent tools", "agent", agentCfg.ID)
176		}()
177
178		cwd := cfg.WorkingDir()
179		allTools := []tools.BaseTool{
180			tools.NewBashTool(permissions, cwd),
181			tools.NewDownloadTool(permissions, cwd),
182			tools.NewEditTool(lspClients, permissions, history, cwd),
183			tools.NewMultiEditTool(lspClients, permissions, history, cwd),
184			tools.NewFetchTool(permissions, cwd),
185			tools.NewGlobTool(cwd),
186			tools.NewGrepTool(cwd),
187			tools.NewLsTool(permissions, cwd),
188			tools.NewSourcegraphTool(),
189			tools.NewViewTool(lspClients, permissions, cwd),
190			tools.NewWriteTool(lspClients, permissions, history, cwd),
191		}
192
193		mcpToolsOnce.Do(func() {
194			mcpTools = doGetMCPTools(ctx, permissions, cfg)
195		})
196		allTools = append(allTools, mcpTools...)
197
198		if len(lspClients) > 0 {
199			allTools = append(allTools, tools.NewDiagnosticsTool(lspClients))
200		}
201
202		if agentTool != nil {
203			allTools = append(allTools, agentTool)
204		}
205
206		if agentCfg.AllowedTools == nil {
207			return allTools
208		}
209
210		var filteredTools []tools.BaseTool
211		for _, tool := range allTools {
212			if slices.Contains(agentCfg.AllowedTools, tool.Name()) {
213				filteredTools = append(filteredTools, tool)
214			}
215		}
216		return filteredTools
217	}
218
219	return &agent{
220		Broker:              pubsub.NewBroker[AgentEvent](),
221		agentCfg:            agentCfg,
222		provider:            agentProvider,
223		providerID:          string(providerCfg.ID),
224		messages:            messages,
225		sessions:            sessions,
226		titleProvider:       titleProvider,
227		summarizeProvider:   summarizeProvider,
228		summarizeProviderID: string(providerCfg.ID),
229		activeRequests:      csync.NewMap[string, context.CancelFunc](),
230		tools:               csync.NewLazySlice(toolFn),
231	}, nil
232}
233
234func (a *agent) Model() catwalk.Model {
235	return *config.Get().GetModelByType(a.agentCfg.Model)
236}
237
238func (a *agent) Cancel(sessionID string) {
239	// Cancel regular requests
240	if cancel, ok := a.activeRequests.Take(sessionID); ok && cancel != nil {
241		slog.Info("Request cancellation initiated", "session_id", sessionID)
242		cancel()
243	}
244
245	// Also check for summarize requests
246	if cancel, ok := a.activeRequests.Take(sessionID + "-summarize"); ok && cancel != nil {
247		slog.Info("Summarize cancellation initiated", "session_id", sessionID)
248		cancel()
249	}
250}
251
252func (a *agent) IsBusy() bool {
253	var busy bool
254	for cancelFunc := range a.activeRequests.Seq() {
255		if cancelFunc != nil {
256			busy = true
257			break
258		}
259	}
260	return busy
261}
262
263func (a *agent) IsSessionBusy(sessionID string) bool {
264	_, busy := a.activeRequests.Get(sessionID)
265	return busy
266}
267
268func (a *agent) generateTitle(ctx context.Context, sessionID string, content string) error {
269	if content == "" {
270		return nil
271	}
272	if a.titleProvider == nil {
273		return nil
274	}
275	session, err := a.sessions.Get(ctx, sessionID)
276	if err != nil {
277		return err
278	}
279	parts := []message.ContentPart{message.TextContent{
280		Text: fmt.Sprintf("Generate a concise title for the following content:\n\n%s", content),
281	}}
282
283	// Use streaming approach like summarization
284	response := a.titleProvider.StreamResponse(
285		ctx,
286		[]message.Message{
287			{
288				Role:  message.User,
289				Parts: parts,
290			},
291		},
292		nil,
293	)
294
295	var finalResponse *provider.ProviderResponse
296	for r := range response {
297		if r.Error != nil {
298			return r.Error
299		}
300		finalResponse = r.Response
301	}
302
303	if finalResponse == nil {
304		return fmt.Errorf("no response received from title provider")
305	}
306
307	title := strings.TrimSpace(strings.ReplaceAll(finalResponse.Content, "\n", " "))
308	if title == "" {
309		return nil
310	}
311
312	session.Title = title
313	_, err = a.sessions.Save(ctx, session)
314	return err
315}
316
317func (a *agent) err(err error) AgentEvent {
318	return AgentEvent{
319		Type:  AgentEventTypeError,
320		Error: err,
321	}
322}
323
324func (a *agent) Run(ctx context.Context, sessionID string, content string, attachments ...message.Attachment) (<-chan AgentEvent, error) {
325	if !a.Model().SupportsImages && attachments != nil {
326		attachments = nil
327	}
328	events := make(chan AgentEvent)
329	if a.IsSessionBusy(sessionID) {
330		return nil, ErrSessionBusy
331	}
332
333	genCtx, cancel := context.WithCancel(ctx)
334
335	a.activeRequests.Set(sessionID, cancel)
336	go func() {
337		slog.Debug("Request started", "sessionID", sessionID)
338		defer log.RecoverPanic("agent.Run", func() {
339			events <- a.err(fmt.Errorf("panic while running the agent"))
340		})
341		var attachmentParts []message.ContentPart
342		for _, attachment := range attachments {
343			attachmentParts = append(attachmentParts, message.BinaryContent{Path: attachment.FilePath, MIMEType: attachment.MimeType, Data: attachment.Content})
344		}
345		result := a.processGeneration(genCtx, sessionID, content, attachmentParts)
346		if result.Error != nil && !errors.Is(result.Error, ErrRequestCancelled) && !errors.Is(result.Error, context.Canceled) {
347			slog.Error(result.Error.Error())
348		}
349		slog.Debug("Request completed", "sessionID", sessionID)
350		a.activeRequests.Del(sessionID)
351		cancel()
352		a.Publish(pubsub.CreatedEvent, result)
353		events <- result
354		close(events)
355	}()
356	return events, nil
357}
358
359func (a *agent) processGeneration(ctx context.Context, sessionID, content string, attachmentParts []message.ContentPart) AgentEvent {
360	cfg := config.Get()
361	// List existing messages; if none, start title generation asynchronously.
362	msgs, err := a.messages.List(ctx, sessionID)
363	if err != nil {
364		return a.err(fmt.Errorf("failed to list messages: %w", err))
365	}
366	if len(msgs) == 0 {
367		go func() {
368			defer log.RecoverPanic("agent.Run", func() {
369				slog.Error("panic while generating title")
370			})
371			titleErr := a.generateTitle(context.Background(), sessionID, content)
372			if titleErr != nil && !errors.Is(titleErr, context.Canceled) && !errors.Is(titleErr, context.DeadlineExceeded) {
373				slog.Error("failed to generate title", "error", titleErr)
374			}
375		}()
376	}
377	session, err := a.sessions.Get(ctx, sessionID)
378	if err != nil {
379		return a.err(fmt.Errorf("failed to get session: %w", err))
380	}
381	if session.SummaryMessageID != "" {
382		summaryMsgInex := -1
383		for i, msg := range msgs {
384			if msg.ID == session.SummaryMessageID {
385				summaryMsgInex = i
386				break
387			}
388		}
389		if summaryMsgInex != -1 {
390			msgs = msgs[summaryMsgInex:]
391			msgs[0].Role = message.User
392		}
393	}
394
395	userMsg, err := a.createUserMessage(ctx, sessionID, content, attachmentParts)
396	if err != nil {
397		return a.err(fmt.Errorf("failed to create user message: %w", err))
398	}
399	// Append the new user message to the conversation history.
400	msgHistory := append(msgs, userMsg)
401
402	for {
403		// Check for cancellation before each iteration
404		select {
405		case <-ctx.Done():
406			return a.err(ctx.Err())
407		default:
408			// Continue processing
409		}
410		agentMessage, toolResults, err := a.streamAndHandleEvents(ctx, sessionID, msgHistory)
411		if err != nil {
412			if errors.Is(err, context.Canceled) {
413				agentMessage.AddFinish(message.FinishReasonCanceled, "Request cancelled", "")
414				a.messages.Update(context.Background(), agentMessage)
415				return a.err(ErrRequestCancelled)
416			}
417			return a.err(fmt.Errorf("failed to process events: %w", err))
418		}
419		if cfg.Options.Debug {
420			slog.Info("Result", "message", agentMessage.FinishReason(), "toolResults", toolResults)
421		}
422		if (agentMessage.FinishReason() == message.FinishReasonToolUse) && toolResults != nil {
423			// We are not done, we need to respond with the tool response
424			msgHistory = append(msgHistory, agentMessage, *toolResults)
425			continue
426		}
427		if agentMessage.FinishReason() == "" {
428			// Kujtim: could not track down where this is happening but this means its cancelled
429			agentMessage.AddFinish(message.FinishReasonCanceled, "Request cancelled", "")
430			_ = a.messages.Update(context.Background(), agentMessage)
431			return a.err(ErrRequestCancelled)
432		}
433		return AgentEvent{
434			Type:    AgentEventTypeResponse,
435			Message: agentMessage,
436			Done:    true,
437		}
438	}
439}
440
441func (a *agent) createUserMessage(ctx context.Context, sessionID, content string, attachmentParts []message.ContentPart) (message.Message, error) {
442	parts := []message.ContentPart{message.TextContent{Text: content}}
443	parts = append(parts, attachmentParts...)
444	return a.messages.Create(ctx, sessionID, message.CreateMessageParams{
445		Role:  message.User,
446		Parts: parts,
447	})
448}
449
450func (a *agent) streamAndHandleEvents(ctx context.Context, sessionID string, msgHistory []message.Message) (message.Message, *message.Message, error) {
451	ctx = context.WithValue(ctx, tools.SessionIDContextKey, sessionID)
452
453	// Create the assistant message first so the spinner shows immediately
454	assistantMsg, err := a.messages.Create(ctx, sessionID, message.CreateMessageParams{
455		Role:     message.Assistant,
456		Parts:    []message.ContentPart{},
457		Model:    a.Model().ID,
458		Provider: a.providerID,
459	})
460	if err != nil {
461		return assistantMsg, nil, fmt.Errorf("failed to create assistant message: %w", err)
462	}
463
464	// Now collect tools (which may block on MCP initialization)
465	eventChan := a.provider.StreamResponse(ctx, msgHistory, slices.Collect(a.tools.Seq()))
466
467	// Add the session and message ID into the context if needed by tools.
468	ctx = context.WithValue(ctx, tools.MessageIDContextKey, assistantMsg.ID)
469
470	// Process each event in the stream.
471	for event := range eventChan {
472		if processErr := a.processEvent(ctx, sessionID, &assistantMsg, event); processErr != nil {
473			if errors.Is(processErr, context.Canceled) {
474				a.finishMessage(context.Background(), &assistantMsg, message.FinishReasonCanceled, "Request cancelled", "")
475			} else {
476				a.finishMessage(ctx, &assistantMsg, message.FinishReasonError, "API Error", processErr.Error())
477			}
478			return assistantMsg, nil, processErr
479		}
480		if ctx.Err() != nil {
481			a.finishMessage(context.Background(), &assistantMsg, message.FinishReasonCanceled, "Request cancelled", "")
482			return assistantMsg, nil, ctx.Err()
483		}
484	}
485
486	toolResults := make([]message.ToolResult, len(assistantMsg.ToolCalls()))
487	toolCalls := assistantMsg.ToolCalls()
488	for i, toolCall := range toolCalls {
489		select {
490		case <-ctx.Done():
491			a.finishMessage(context.Background(), &assistantMsg, message.FinishReasonCanceled, "Request cancelled", "")
492			// Make all future tool calls cancelled
493			for j := i; j < len(toolCalls); j++ {
494				toolResults[j] = message.ToolResult{
495					ToolCallID: toolCalls[j].ID,
496					Content:    "Tool execution canceled by user",
497					IsError:    true,
498				}
499			}
500			goto out
501		default:
502			// Continue processing
503			var tool tools.BaseTool
504			for availableTool := range a.tools.Seq() {
505				if availableTool.Info().Name == toolCall.Name {
506					tool = availableTool
507					break
508				}
509			}
510
511			// Tool not found
512			if tool == nil {
513				toolResults[i] = message.ToolResult{
514					ToolCallID: toolCall.ID,
515					Content:    fmt.Sprintf("Tool not found: %s", toolCall.Name),
516					IsError:    true,
517				}
518				continue
519			}
520
521			// Run tool in goroutine to allow cancellation
522			type toolExecResult struct {
523				response tools.ToolResponse
524				err      error
525			}
526			resultChan := make(chan toolExecResult, 1)
527
528			go func() {
529				response, err := tool.Run(ctx, tools.ToolCall{
530					ID:    toolCall.ID,
531					Name:  toolCall.Name,
532					Input: toolCall.Input,
533				})
534				resultChan <- toolExecResult{response: response, err: err}
535			}()
536
537			var toolResponse tools.ToolResponse
538			var toolErr error
539
540			select {
541			case <-ctx.Done():
542				a.finishMessage(context.Background(), &assistantMsg, message.FinishReasonCanceled, "Request cancelled", "")
543				// Mark remaining tool calls as cancelled
544				for j := i; j < len(toolCalls); j++ {
545					toolResults[j] = message.ToolResult{
546						ToolCallID: toolCalls[j].ID,
547						Content:    "Tool execution canceled by user",
548						IsError:    true,
549					}
550				}
551				goto out
552			case result := <-resultChan:
553				toolResponse = result.response
554				toolErr = result.err
555			}
556
557			if toolErr != nil {
558				slog.Error("Tool execution error", "toolCall", toolCall.ID, "error", toolErr)
559				if errors.Is(toolErr, permission.ErrorPermissionDenied) {
560					toolResults[i] = message.ToolResult{
561						ToolCallID: toolCall.ID,
562						Content:    "Permission denied",
563						IsError:    true,
564					}
565					for j := i + 1; j < len(toolCalls); j++ {
566						toolResults[j] = message.ToolResult{
567							ToolCallID: toolCalls[j].ID,
568							Content:    "Tool execution canceled by user",
569							IsError:    true,
570						}
571					}
572					a.finishMessage(ctx, &assistantMsg, message.FinishReasonPermissionDenied, "Permission denied", "")
573					break
574				}
575			}
576			toolResults[i] = message.ToolResult{
577				ToolCallID: toolCall.ID,
578				Content:    toolResponse.Content,
579				Metadata:   toolResponse.Metadata,
580				IsError:    toolResponse.IsError,
581			}
582		}
583	}
584out:
585	if len(toolResults) == 0 {
586		return assistantMsg, nil, nil
587	}
588	parts := make([]message.ContentPart, 0)
589	for _, tr := range toolResults {
590		parts = append(parts, tr)
591	}
592	msg, err := a.messages.Create(context.Background(), assistantMsg.SessionID, message.CreateMessageParams{
593		Role:     message.Tool,
594		Parts:    parts,
595		Provider: a.providerID,
596	})
597	if err != nil {
598		return assistantMsg, nil, fmt.Errorf("failed to create cancelled tool message: %w", err)
599	}
600
601	return assistantMsg, &msg, err
602}
603
604func (a *agent) finishMessage(ctx context.Context, msg *message.Message, finishReason message.FinishReason, message, details string) {
605	msg.AddFinish(finishReason, message, details)
606	_ = a.messages.Update(ctx, *msg)
607}
608
609func (a *agent) processEvent(ctx context.Context, sessionID string, assistantMsg *message.Message, event provider.ProviderEvent) error {
610	select {
611	case <-ctx.Done():
612		return ctx.Err()
613	default:
614		// Continue processing.
615	}
616
617	switch event.Type {
618	case provider.EventThinkingDelta:
619		assistantMsg.AppendReasoningContent(event.Thinking)
620		return a.messages.Update(ctx, *assistantMsg)
621	case provider.EventSignatureDelta:
622		assistantMsg.AppendReasoningSignature(event.Signature)
623		return a.messages.Update(ctx, *assistantMsg)
624	case provider.EventContentDelta:
625		assistantMsg.FinishThinking()
626		assistantMsg.AppendContent(event.Content)
627		return a.messages.Update(ctx, *assistantMsg)
628	case provider.EventToolUseStart:
629		assistantMsg.FinishThinking()
630		slog.Info("Tool call started", "toolCall", event.ToolCall)
631		assistantMsg.AddToolCall(*event.ToolCall)
632		return a.messages.Update(ctx, *assistantMsg)
633	case provider.EventToolUseDelta:
634		assistantMsg.AppendToolCallInput(event.ToolCall.ID, event.ToolCall.Input)
635		return a.messages.Update(ctx, *assistantMsg)
636	case provider.EventToolUseStop:
637		slog.Info("Finished tool call", "toolCall", event.ToolCall)
638		assistantMsg.FinishToolCall(event.ToolCall.ID)
639		return a.messages.Update(ctx, *assistantMsg)
640	case provider.EventError:
641		return event.Error
642	case provider.EventComplete:
643		assistantMsg.FinishThinking()
644		assistantMsg.SetToolCalls(event.Response.ToolCalls)
645		assistantMsg.AddFinish(event.Response.FinishReason, "", "")
646		if err := a.messages.Update(ctx, *assistantMsg); err != nil {
647			return fmt.Errorf("failed to update message: %w", err)
648		}
649		return a.TrackUsage(ctx, sessionID, a.Model(), event.Response.Usage)
650	}
651
652	return nil
653}
654
655func (a *agent) TrackUsage(ctx context.Context, sessionID string, model catwalk.Model, usage provider.TokenUsage) error {
656	sess, err := a.sessions.Get(ctx, sessionID)
657	if err != nil {
658		return fmt.Errorf("failed to get session: %w", err)
659	}
660
661	cost := model.CostPer1MInCached/1e6*float64(usage.CacheCreationTokens) +
662		model.CostPer1MOutCached/1e6*float64(usage.CacheReadTokens) +
663		model.CostPer1MIn/1e6*float64(usage.InputTokens) +
664		model.CostPer1MOut/1e6*float64(usage.OutputTokens)
665
666	sess.Cost += cost
667	sess.CompletionTokens = usage.OutputTokens + usage.CacheReadTokens
668	sess.PromptTokens = usage.InputTokens + usage.CacheCreationTokens
669
670	_, err = a.sessions.Save(ctx, sess)
671	if err != nil {
672		return fmt.Errorf("failed to save session: %w", err)
673	}
674	return nil
675}
676
677func (a *agent) Summarize(ctx context.Context, sessionID string) error {
678	if a.summarizeProvider == nil {
679		return fmt.Errorf("summarize provider not available")
680	}
681
682	// Check if session is busy
683	if a.IsSessionBusy(sessionID) {
684		return ErrSessionBusy
685	}
686
687	// Create a new context with cancellation
688	summarizeCtx, cancel := context.WithCancel(ctx)
689
690	// Store the cancel function in activeRequests to allow cancellation
691	a.activeRequests.Set(sessionID+"-summarize", cancel)
692
693	go func() {
694		defer a.activeRequests.Del(sessionID + "-summarize")
695		defer cancel()
696		event := AgentEvent{
697			Type:     AgentEventTypeSummarize,
698			Progress: "Starting summarization...",
699		}
700
701		a.Publish(pubsub.CreatedEvent, event)
702		// Get all messages from the session
703		msgs, err := a.messages.List(summarizeCtx, sessionID)
704		if err != nil {
705			event = AgentEvent{
706				Type:  AgentEventTypeError,
707				Error: fmt.Errorf("failed to list messages: %w", err),
708				Done:  true,
709			}
710			a.Publish(pubsub.CreatedEvent, event)
711			return
712		}
713		summarizeCtx = context.WithValue(summarizeCtx, tools.SessionIDContextKey, sessionID)
714
715		if len(msgs) == 0 {
716			event = AgentEvent{
717				Type:  AgentEventTypeError,
718				Error: fmt.Errorf("no messages to summarize"),
719				Done:  true,
720			}
721			a.Publish(pubsub.CreatedEvent, event)
722			return
723		}
724
725		event = AgentEvent{
726			Type:     AgentEventTypeSummarize,
727			Progress: "Analyzing conversation...",
728		}
729		a.Publish(pubsub.CreatedEvent, event)
730
731		// Add a system message to guide the summarization
732		summarizePrompt := "Provide a detailed but concise summary of our conversation above. Focus on information that would be helpful for continuing the conversation, including what we did, what we're doing, which files we're working on, and what we're going to do next."
733
734		// Create a new message with the summarize prompt
735		promptMsg := message.Message{
736			Role:  message.User,
737			Parts: []message.ContentPart{message.TextContent{Text: summarizePrompt}},
738		}
739
740		// Append the prompt to the messages
741		msgsWithPrompt := append(msgs, promptMsg)
742
743		event = AgentEvent{
744			Type:     AgentEventTypeSummarize,
745			Progress: "Generating summary...",
746		}
747
748		a.Publish(pubsub.CreatedEvent, event)
749
750		// Send the messages to the summarize provider
751		response := a.summarizeProvider.StreamResponse(
752			summarizeCtx,
753			msgsWithPrompt,
754			nil,
755		)
756		var finalResponse *provider.ProviderResponse
757		for r := range response {
758			if r.Error != nil {
759				event = AgentEvent{
760					Type:  AgentEventTypeError,
761					Error: fmt.Errorf("failed to summarize: %w", err),
762					Done:  true,
763				}
764				a.Publish(pubsub.CreatedEvent, event)
765				return
766			}
767			finalResponse = r.Response
768		}
769
770		summary := strings.TrimSpace(finalResponse.Content)
771		if summary == "" {
772			event = AgentEvent{
773				Type:  AgentEventTypeError,
774				Error: fmt.Errorf("empty summary returned"),
775				Done:  true,
776			}
777			a.Publish(pubsub.CreatedEvent, event)
778			return
779		}
780		shell := shell.GetPersistentShell(config.Get().WorkingDir())
781		summary += "\n\n**Current working directory of the persistent shell**\n\n" + shell.GetWorkingDir()
782		event = AgentEvent{
783			Type:     AgentEventTypeSummarize,
784			Progress: "Creating new session...",
785		}
786
787		a.Publish(pubsub.CreatedEvent, event)
788		oldSession, err := a.sessions.Get(summarizeCtx, sessionID)
789		if err != nil {
790			event = AgentEvent{
791				Type:  AgentEventTypeError,
792				Error: fmt.Errorf("failed to get session: %w", err),
793				Done:  true,
794			}
795
796			a.Publish(pubsub.CreatedEvent, event)
797			return
798		}
799		// Create a message in the new session with the summary
800		msg, err := a.messages.Create(summarizeCtx, oldSession.ID, message.CreateMessageParams{
801			Role: message.Assistant,
802			Parts: []message.ContentPart{
803				message.TextContent{Text: summary},
804				message.Finish{
805					Reason: message.FinishReasonEndTurn,
806					Time:   time.Now().Unix(),
807				},
808			},
809			Model:    a.summarizeProvider.Model().ID,
810			Provider: a.summarizeProviderID,
811		})
812		if err != nil {
813			event = AgentEvent{
814				Type:  AgentEventTypeError,
815				Error: fmt.Errorf("failed to create summary message: %w", err),
816				Done:  true,
817			}
818
819			a.Publish(pubsub.CreatedEvent, event)
820			return
821		}
822		oldSession.SummaryMessageID = msg.ID
823		oldSession.CompletionTokens = finalResponse.Usage.OutputTokens
824		oldSession.PromptTokens = 0
825		model := a.summarizeProvider.Model()
826		usage := finalResponse.Usage
827		cost := model.CostPer1MInCached/1e6*float64(usage.CacheCreationTokens) +
828			model.CostPer1MOutCached/1e6*float64(usage.CacheReadTokens) +
829			model.CostPer1MIn/1e6*float64(usage.InputTokens) +
830			model.CostPer1MOut/1e6*float64(usage.OutputTokens)
831		oldSession.Cost += cost
832		_, err = a.sessions.Save(summarizeCtx, oldSession)
833		if err != nil {
834			event = AgentEvent{
835				Type:  AgentEventTypeError,
836				Error: fmt.Errorf("failed to save session: %w", err),
837				Done:  true,
838			}
839			a.Publish(pubsub.CreatedEvent, event)
840		}
841
842		event = AgentEvent{
843			Type:      AgentEventTypeSummarize,
844			SessionID: oldSession.ID,
845			Progress:  "Summary complete",
846			Done:      true,
847		}
848		a.Publish(pubsub.CreatedEvent, event)
849		// Send final success event with the new session ID
850	}()
851
852	return nil
853}
854
855func (a *agent) CancelAll() {
856	if !a.IsBusy() {
857		return
858	}
859	for key := range a.activeRequests.Seq2() {
860		a.Cancel(key) // key is sessionID
861	}
862
863	timeout := time.After(5 * time.Second)
864	for a.IsBusy() {
865		select {
866		case <-timeout:
867			return
868		default:
869			time.Sleep(200 * time.Millisecond)
870		}
871	}
872}
873
874func (a *agent) UpdateModel() error {
875	cfg := config.Get()
876
877	// Get current provider configuration
878	currentProviderCfg := cfg.GetProviderForModel(a.agentCfg.Model)
879	if currentProviderCfg == nil || currentProviderCfg.ID == "" {
880		return fmt.Errorf("provider for agent %s not found in config", a.agentCfg.Name)
881	}
882
883	// Check if provider has changed
884	if string(currentProviderCfg.ID) != a.providerID {
885		// Provider changed, need to recreate the main provider
886		model := cfg.GetModelByType(a.agentCfg.Model)
887		if model.ID == "" {
888			return fmt.Errorf("model not found for agent %s", a.agentCfg.Name)
889		}
890
891		promptID := agentPromptMap[a.agentCfg.ID]
892		if promptID == "" {
893			promptID = prompt.PromptDefault
894		}
895
896		opts := []provider.ProviderClientOption{
897			provider.WithModel(a.agentCfg.Model),
898			provider.WithSystemMessage(prompt.GetPrompt(promptID, currentProviderCfg.ID, cfg.Options.ContextPaths...)),
899		}
900
901		newProvider, err := provider.NewProvider(*currentProviderCfg, opts...)
902		if err != nil {
903			return fmt.Errorf("failed to create new provider: %w", err)
904		}
905
906		// Update the provider and provider ID
907		a.provider = newProvider
908		a.providerID = string(currentProviderCfg.ID)
909	}
910
911	// Check if providers have changed for title (small) and summarize (large)
912	smallModelCfg := cfg.Models[config.SelectedModelTypeSmall]
913	var smallModelProviderCfg config.ProviderConfig
914	for p := range cfg.Providers.Seq() {
915		if p.ID == smallModelCfg.Provider {
916			smallModelProviderCfg = p
917			break
918		}
919	}
920	if smallModelProviderCfg.ID == "" {
921		return fmt.Errorf("provider %s not found in config", smallModelCfg.Provider)
922	}
923
924	largeModelCfg := cfg.Models[config.SelectedModelTypeLarge]
925	var largeModelProviderCfg config.ProviderConfig
926	for p := range cfg.Providers.Seq() {
927		if p.ID == largeModelCfg.Provider {
928			largeModelProviderCfg = p
929			break
930		}
931	}
932	if largeModelProviderCfg.ID == "" {
933		return fmt.Errorf("provider %s not found in config", largeModelCfg.Provider)
934	}
935
936	// Recreate title provider
937	titleOpts := []provider.ProviderClientOption{
938		provider.WithModel(config.SelectedModelTypeSmall),
939		provider.WithSystemMessage(prompt.GetPrompt(prompt.PromptTitle, smallModelProviderCfg.ID)),
940		provider.WithMaxTokens(40),
941	}
942	newTitleProvider, err := provider.NewProvider(smallModelProviderCfg, titleOpts...)
943	if err != nil {
944		return fmt.Errorf("failed to create new title provider: %w", err)
945	}
946	a.titleProvider = newTitleProvider
947
948	// Recreate summarize provider if provider changed (now large model)
949	if string(largeModelProviderCfg.ID) != a.summarizeProviderID {
950		largeModel := cfg.GetModelByType(config.SelectedModelTypeLarge)
951		if largeModel == nil {
952			return fmt.Errorf("model %s not found in provider %s", largeModelCfg.Model, largeModelProviderCfg.ID)
953		}
954		summarizeOpts := []provider.ProviderClientOption{
955			provider.WithModel(config.SelectedModelTypeLarge),
956			provider.WithSystemMessage(prompt.GetPrompt(prompt.PromptSummarizer, largeModelProviderCfg.ID)),
957		}
958		newSummarizeProvider, err := provider.NewProvider(largeModelProviderCfg, summarizeOpts...)
959		if err != nil {
960			return fmt.Errorf("failed to create new summarize provider: %w", err)
961		}
962		a.summarizeProvider = newSummarizeProvider
963		a.summarizeProviderID = string(largeModelProviderCfg.ID)
964	}
965
966	return nil
967}