1package agent
  2
  3import (
  4	"context"
  5	"errors"
  6	"fmt"
  7	"log/slog"
  8	"slices"
  9	"strings"
 10	"time"
 11
 12	"github.com/charmbracelet/catwalk/pkg/catwalk"
 13	"github.com/charmbracelet/crush/internal/config"
 14	"github.com/charmbracelet/crush/internal/csync"
 15	"github.com/charmbracelet/crush/internal/history"
 16	"github.com/charmbracelet/crush/internal/llm/prompt"
 17	"github.com/charmbracelet/crush/internal/llm/provider"
 18	"github.com/charmbracelet/crush/internal/llm/tools"
 19	"github.com/charmbracelet/crush/internal/log"
 20	"github.com/charmbracelet/crush/internal/lsp"
 21	"github.com/charmbracelet/crush/internal/message"
 22	"github.com/charmbracelet/crush/internal/permission"
 23	"github.com/charmbracelet/crush/internal/pubsub"
 24	"github.com/charmbracelet/crush/internal/session"
 25	"github.com/charmbracelet/crush/internal/shell"
 26)
 27
// Sentinel errors returned by the agent; compare with errors.Is.
//
// ErrRequestCancelled is reported (as the Error of an AgentEventTypeError
// event) when a Run generation is cancelled before it completes.
// ErrSessionBusy is returned by Run and Summarize when the session already has
// an in-flight request.
var (
	ErrRequestCancelled = errors.New("request canceled by user")
	ErrSessionBusy      = errors.New("session is currently processing another request")
)
 33
// AgentEventType identifies the kind of AgentEvent emitted by the agent.
type AgentEventType string

// Agent event kinds:
//   - AgentEventTypeError: the event's Error field describes a failure.
//   - AgentEventTypeResponse: Message holds the final assistant message of a Run.
//   - AgentEventTypeSummarize: a progress or completion update from Summarize.
const (
	AgentEventTypeError     AgentEventType = "error"
	AgentEventTypeResponse  AgentEventType = "response"
	AgentEventTypeSummarize AgentEventType = "summarize"
)
 41
// AgentEvent is delivered on the channel returned by Run and published to the
// agent's subscribers.
//
// Type selects which fields are meaningful: Error for AgentEventTypeError,
// Message for AgentEventTypeResponse, and Progress/SessionID for
// AgentEventTypeSummarize.
type AgentEvent struct {
	Type    AgentEventType
	Message message.Message
	Error   error

	// Progress is a human-readable status line from Summarize, and SessionID
	// is filled on Summarize's final "Summary complete" event. Done marks a
	// final event: it is set on successful Run responses and on Summarize's
	// completion and error events.
	SessionID string
	Progress  string
	Done      bool
}
 52
// Service is the public interface of an agent: it runs generations against the
// configured model, tracks in-flight requests per session, and publishes
// AgentEvents to subscribers.
type Service interface {
	pubsub.Suscriber[AgentEvent]
	// Model returns the model currently configured for this agent.
	Model() catwalk.Model
	// Run starts an asynchronous generation for sessionID. It fails with
	// ErrSessionBusy if the session already has a request in flight; otherwise
	// the returned channel receives the final AgentEvent of the generation.
	Run(ctx context.Context, sessionID string, content string, attachments ...message.Attachment) (<-chan AgentEvent, error)
	// Cancel cancels the regular and the summarize request of sessionID, if any.
	Cancel(sessionID string)
	// CancelAll cancels every active request and waits a bounded time for
	// them to finish.
	CancelAll()
	// IsSessionBusy reports whether a Run request is active for sessionID.
	IsSessionBusy(sessionID string) bool
	// IsBusy reports whether any request is active.
	IsBusy() bool
	// Summarize asynchronously summarizes the session, reporting progress via
	// published AgentEventTypeSummarize events.
	Summarize(ctx context.Context, sessionID string) error
	// UpdateModel re-reads the configuration and rebuilds providers whose
	// provider has changed.
	UpdateModel() error
}
 64
// agent is the concrete Service implementation returned by NewAgent.
type agent struct {
	*pubsub.Broker[AgentEvent]
	agentCfg config.Agent
	sessions session.Service
	messages message.Service
	mcpTools []McpTool

	// tools is created with csync.NewLazySlice from the tool factory built in
	// NewAgent (presumably evaluated on first use — see csync.LazySlice).
	tools *csync.LazySlice[tools.BaseTool]

	provider   provider.Provider
	providerID string

	titleProvider       provider.Provider
	summarizeProvider   provider.Provider
	summarizeProviderID string

	// activeRequests maps a request key to its cancel func: the session ID for
	// Run requests and sessionID+"-summarize" for Summarize requests.
	activeRequests *csync.Map[string, context.CancelFunc]
}
 83
// agentPromptMap selects the system prompt for an agent by its config ID.
// IDs missing from the map fall back to prompt.PromptDefault.
var agentPromptMap = map[string]prompt.PromptID{
	"coder": prompt.PromptCoder,
	"task":  prompt.PromptTask,
}
 88
 89func NewAgent(
 90	ctx context.Context,
 91	agentCfg config.Agent,
 92	// These services are needed in the tools
 93	permissions permission.Service,
 94	sessions session.Service,
 95	messages message.Service,
 96	history history.Service,
 97	lspClients map[string]*lsp.Client,
 98) (Service, error) {
 99	cfg := config.Get()
100
101	var agentTool tools.BaseTool
102	if agentCfg.ID == "coder" {
103		taskAgentCfg := config.Get().Agents["task"]
104		if taskAgentCfg.ID == "" {
105			return nil, fmt.Errorf("task agent not found in config")
106		}
107		taskAgent, err := NewAgent(ctx, taskAgentCfg, permissions, sessions, messages, history, lspClients)
108		if err != nil {
109			return nil, fmt.Errorf("failed to create task agent: %w", err)
110		}
111
112		agentTool = NewAgentTool(taskAgent, sessions, messages)
113	}
114
115	providerCfg := config.Get().GetProviderForModel(agentCfg.Model)
116	if providerCfg == nil {
117		return nil, fmt.Errorf("provider for agent %s not found in config", agentCfg.Name)
118	}
119	model := config.Get().GetModelByType(agentCfg.Model)
120
121	if model == nil {
122		return nil, fmt.Errorf("model not found for agent %s", agentCfg.Name)
123	}
124
125	promptID := agentPromptMap[agentCfg.ID]
126	if promptID == "" {
127		promptID = prompt.PromptDefault
128	}
129	opts := []provider.ProviderClientOption{
130		provider.WithModel(agentCfg.Model),
131		provider.WithSystemMessage(prompt.GetPrompt(promptID, providerCfg.ID, config.Get().Options.ContextPaths...)),
132	}
133	agentProvider, err := provider.NewProvider(*providerCfg, opts...)
134	if err != nil {
135		return nil, err
136	}
137
138	smallModelCfg := cfg.Models[config.SelectedModelTypeSmall]
139	var smallModelProviderCfg *config.ProviderConfig
140	if smallModelCfg.Provider == providerCfg.ID {
141		smallModelProviderCfg = providerCfg
142	} else {
143		smallModelProviderCfg = cfg.GetProviderForModel(config.SelectedModelTypeSmall)
144
145		if smallModelProviderCfg.ID == "" {
146			return nil, fmt.Errorf("provider %s not found in config", smallModelCfg.Provider)
147		}
148	}
149	smallModel := cfg.GetModelByType(config.SelectedModelTypeSmall)
150	if smallModel.ID == "" {
151		return nil, fmt.Errorf("model %s not found in provider %s", smallModelCfg.Model, smallModelProviderCfg.ID)
152	}
153
154	titleOpts := []provider.ProviderClientOption{
155		provider.WithModel(config.SelectedModelTypeSmall),
156		provider.WithSystemMessage(prompt.GetPrompt(prompt.PromptTitle, smallModelProviderCfg.ID)),
157	}
158	titleProvider, err := provider.NewProvider(*smallModelProviderCfg, titleOpts...)
159	if err != nil {
160		return nil, err
161	}
162	summarizeOpts := []provider.ProviderClientOption{
163		provider.WithModel(config.SelectedModelTypeSmall),
164		provider.WithSystemMessage(prompt.GetPrompt(prompt.PromptSummarizer, smallModelProviderCfg.ID)),
165	}
166	summarizeProvider, err := provider.NewProvider(*smallModelProviderCfg, summarizeOpts...)
167	if err != nil {
168		return nil, err
169	}
170
171	toolFn := func() []tools.BaseTool {
172		slog.Info("Initializing agent tools", "agent", agentCfg.ID)
173		defer func() {
174			slog.Info("Initialized agent tools", "agent", agentCfg.ID)
175		}()
176
177		cwd := cfg.WorkingDir()
178		allTools := []tools.BaseTool{
179			tools.NewBashTool(permissions, cwd),
180			tools.NewDownloadTool(permissions, cwd),
181			tools.NewEditTool(lspClients, permissions, history, cwd),
182			tools.NewMultiEditTool(lspClients, permissions, history, cwd),
183			tools.NewFetchTool(permissions, cwd),
184			tools.NewGlobTool(cwd),
185			tools.NewGrepTool(cwd),
186			tools.NewLsTool(permissions, cwd),
187			tools.NewSourcegraphTool(),
188			tools.NewViewTool(lspClients, permissions, cwd),
189			tools.NewWriteTool(lspClients, permissions, history, cwd),
190		}
191
192		mcpToolsOnce.Do(func() {
193			mcpTools = doGetMCPTools(ctx, permissions, cfg)
194		})
195		allTools = append(allTools, mcpTools...)
196
197		if len(lspClients) > 0 {
198			allTools = append(allTools, tools.NewDiagnosticsTool(lspClients))
199		}
200
201		if agentTool != nil {
202			allTools = append(allTools, agentTool)
203		}
204
205		if agentCfg.AllowedTools == nil {
206			return allTools
207		}
208
209		var filteredTools []tools.BaseTool
210		for _, tool := range allTools {
211			if slices.Contains(agentCfg.AllowedTools, tool.Name()) {
212				filteredTools = append(filteredTools, tool)
213			}
214		}
215		return filteredTools
216	}
217
218	return &agent{
219		Broker:              pubsub.NewBroker[AgentEvent](),
220		agentCfg:            agentCfg,
221		provider:            agentProvider,
222		providerID:          string(providerCfg.ID),
223		messages:            messages,
224		sessions:            sessions,
225		titleProvider:       titleProvider,
226		summarizeProvider:   summarizeProvider,
227		summarizeProviderID: string(smallModelProviderCfg.ID),
228		activeRequests:      csync.NewMap[string, context.CancelFunc](),
229		tools:               csync.NewLazySlice(toolFn),
230	}, nil
231}
232
233func (a *agent) Model() catwalk.Model {
234	return *config.Get().GetModelByType(a.agentCfg.Model)
235}
236
237func (a *agent) Cancel(sessionID string) {
238	// Cancel regular requests
239	if cancel, ok := a.activeRequests.Take(sessionID); ok && cancel != nil {
240		slog.Info("Request cancellation initiated", "session_id", sessionID)
241		cancel()
242	}
243
244	// Also check for summarize requests
245	if cancel, ok := a.activeRequests.Take(sessionID + "-summarize"); ok && cancel != nil {
246		slog.Info("Summarize cancellation initiated", "session_id", sessionID)
247		cancel()
248	}
249}
250
251func (a *agent) IsBusy() bool {
252	var busy bool
253	for cancelFunc := range a.activeRequests.Seq() {
254		if cancelFunc != nil {
255			busy = true
256			break
257		}
258	}
259	return busy
260}
261
262func (a *agent) IsSessionBusy(sessionID string) bool {
263	_, busy := a.activeRequests.Get(sessionID)
264	return busy
265}
266
267func (a *agent) generateTitle(ctx context.Context, sessionID string, content string) error {
268	if content == "" {
269		return nil
270	}
271	if a.titleProvider == nil {
272		return nil
273	}
274	session, err := a.sessions.Get(ctx, sessionID)
275	if err != nil {
276		return err
277	}
278	parts := []message.ContentPart{message.TextContent{
279		Text: fmt.Sprintf("Generate a concise title for the following content:\n\n%s", content),
280	}}
281
282	// Use streaming approach like summarization
283	response := a.titleProvider.StreamResponse(
284		ctx,
285		[]message.Message{
286			{
287				Role:  message.User,
288				Parts: parts,
289			},
290		},
291		nil,
292	)
293
294	var finalResponse *provider.ProviderResponse
295	for r := range response {
296		if r.Error != nil {
297			return r.Error
298		}
299		finalResponse = r.Response
300	}
301
302	if finalResponse == nil {
303		return fmt.Errorf("no response received from title provider")
304	}
305
306	title := strings.TrimSpace(strings.ReplaceAll(finalResponse.Content, "\n", " "))
307	if title == "" {
308		return nil
309	}
310
311	session.Title = title
312	_, err = a.sessions.Save(ctx, session)
313	return err
314}
315
316func (a *agent) err(err error) AgentEvent {
317	return AgentEvent{
318		Type:  AgentEventTypeError,
319		Error: err,
320	}
321}
322
323func (a *agent) Run(ctx context.Context, sessionID string, content string, attachments ...message.Attachment) (<-chan AgentEvent, error) {
324	if !a.Model().SupportsImages && attachments != nil {
325		attachments = nil
326	}
327	events := make(chan AgentEvent)
328	if a.IsSessionBusy(sessionID) {
329		return nil, ErrSessionBusy
330	}
331
332	genCtx, cancel := context.WithCancel(ctx)
333
334	a.activeRequests.Set(sessionID, cancel)
335	go func() {
336		slog.Debug("Request started", "sessionID", sessionID)
337		defer log.RecoverPanic("agent.Run", func() {
338			events <- a.err(fmt.Errorf("panic while running the agent"))
339		})
340		var attachmentParts []message.ContentPart
341		for _, attachment := range attachments {
342			attachmentParts = append(attachmentParts, message.BinaryContent{Path: attachment.FilePath, MIMEType: attachment.MimeType, Data: attachment.Content})
343		}
344		result := a.processGeneration(genCtx, sessionID, content, attachmentParts)
345		if result.Error != nil && !errors.Is(result.Error, ErrRequestCancelled) && !errors.Is(result.Error, context.Canceled) {
346			slog.Error(result.Error.Error())
347		}
348		slog.Debug("Request completed", "sessionID", sessionID)
349		a.activeRequests.Del(sessionID)
350		cancel()
351		a.Publish(pubsub.CreatedEvent, result)
352		events <- result
353		close(events)
354	}()
355	return events, nil
356}
357
358func (a *agent) processGeneration(ctx context.Context, sessionID, content string, attachmentParts []message.ContentPart) AgentEvent {
359	cfg := config.Get()
360	// List existing messages; if none, start title generation asynchronously.
361	msgs, err := a.messages.List(ctx, sessionID)
362	if err != nil {
363		return a.err(fmt.Errorf("failed to list messages: %w", err))
364	}
365	if len(msgs) == 0 {
366		go func() {
367			defer log.RecoverPanic("agent.Run", func() {
368				slog.Error("panic while generating title")
369			})
370			titleErr := a.generateTitle(context.Background(), sessionID, content)
371			if titleErr != nil && !errors.Is(titleErr, context.Canceled) && !errors.Is(titleErr, context.DeadlineExceeded) {
372				slog.Error("failed to generate title", "error", titleErr)
373			}
374		}()
375	}
376	session, err := a.sessions.Get(ctx, sessionID)
377	if err != nil {
378		return a.err(fmt.Errorf("failed to get session: %w", err))
379	}
380	if session.SummaryMessageID != "" {
381		summaryMsgInex := -1
382		for i, msg := range msgs {
383			if msg.ID == session.SummaryMessageID {
384				summaryMsgInex = i
385				break
386			}
387		}
388		if summaryMsgInex != -1 {
389			msgs = msgs[summaryMsgInex:]
390			msgs[0].Role = message.User
391		}
392	}
393
394	// clean messages
395	// if there is a tool call that has no tool response here we need to mark it as cancelled this could have happened by a crash or something similar
396	resultsMap := make(map[string]bool, 0) // toolCallId=>true
397	for _, msg := range msgs {
398		if msg.Role == message.Tool {
399			results := msg.ToolResults()
400			for _, result := range results {
401				resultsMap[result.ToolCallID] = true
402			}
403		}
404	}
405	for _, msg := range msgs {
406		if msg.Role == message.Assistant && len(msg.ToolCalls()) > 0 {
407			for _, tc := range msg.ToolCalls() {
408				if _, ok := resultsMap[tc.ID]; !ok {
409					a.finishMessage(context.Background(), &msg, message.FinishReasonCanceled, "Request cancelled", "")
410					goto next
411				}
412			}
413		next:
414		}
415	}
416
417	userMsg, err := a.createUserMessage(ctx, sessionID, content, attachmentParts)
418	if err != nil {
419		return a.err(fmt.Errorf("failed to create user message: %w", err))
420	}
421	// Append the new user message to the conversation history.
422	msgHistory := append(msgs, userMsg)
423
424	for {
425		// Check for cancellation before each iteration
426		select {
427		case <-ctx.Done():
428			return a.err(ctx.Err())
429		default:
430			// Continue processing
431		}
432		agentMessage, toolResults, err := a.streamAndHandleEvents(ctx, sessionID, msgHistory)
433		if err != nil {
434			if errors.Is(err, context.Canceled) {
435				agentMessage.AddFinish(message.FinishReasonCanceled, "Request cancelled", "")
436				a.messages.Update(context.Background(), agentMessage)
437				return a.err(ErrRequestCancelled)
438			}
439			return a.err(fmt.Errorf("failed to process events: %w", err))
440		}
441		if cfg.Options.Debug {
442			slog.Info("Result", "message", agentMessage.FinishReason(), "toolResults", toolResults)
443		}
444		if (agentMessage.FinishReason() == message.FinishReasonToolUse) && toolResults != nil {
445			// We are not done, we need to respond with the tool response
446			msgHistory = append(msgHistory, agentMessage, *toolResults)
447			continue
448		}
449		if agentMessage.FinishReason() == "" {
450			// Kujtim: could not track down where this is happening but this means its cancelled
451			agentMessage.AddFinish(message.FinishReasonCanceled, "Request cancelled", "")
452			_ = a.messages.Update(context.Background(), agentMessage)
453			return a.err(ErrRequestCancelled)
454		}
455		return AgentEvent{
456			Type:    AgentEventTypeResponse,
457			Message: agentMessage,
458			Done:    true,
459		}
460	}
461}
462
463func (a *agent) createUserMessage(ctx context.Context, sessionID, content string, attachmentParts []message.ContentPart) (message.Message, error) {
464	parts := []message.ContentPart{message.TextContent{Text: content}}
465	parts = append(parts, attachmentParts...)
466	return a.messages.Create(ctx, sessionID, message.CreateMessageParams{
467		Role:  message.User,
468		Parts: parts,
469	})
470}
471
472func (a *agent) streamAndHandleEvents(ctx context.Context, sessionID string, msgHistory []message.Message) (message.Message, *message.Message, error) {
473	ctx = context.WithValue(ctx, tools.SessionIDContextKey, sessionID)
474	eventChan := a.provider.StreamResponse(ctx, msgHistory, slices.Collect(a.tools.Seq()))
475
476	assistantMsg, err := a.messages.Create(ctx, sessionID, message.CreateMessageParams{
477		Role:     message.Assistant,
478		Parts:    []message.ContentPart{},
479		Model:    a.Model().ID,
480		Provider: a.providerID,
481	})
482	if err != nil {
483		return assistantMsg, nil, fmt.Errorf("failed to create assistant message: %w", err)
484	}
485
486	// Add the session and message ID into the context if needed by tools.
487	ctx = context.WithValue(ctx, tools.MessageIDContextKey, assistantMsg.ID)
488
489	// Process each event in the stream.
490	for event := range eventChan {
491		if processErr := a.processEvent(ctx, sessionID, &assistantMsg, event); processErr != nil {
492			if errors.Is(processErr, context.Canceled) {
493				a.finishMessage(context.Background(), &assistantMsg, message.FinishReasonCanceled, "Request cancelled", "")
494			} else {
495				a.finishMessage(ctx, &assistantMsg, message.FinishReasonError, "API Error", processErr.Error())
496			}
497			return assistantMsg, nil, processErr
498		}
499		if ctx.Err() != nil {
500			a.finishMessage(context.Background(), &assistantMsg, message.FinishReasonCanceled, "Request cancelled", "")
501			return assistantMsg, nil, ctx.Err()
502		}
503	}
504
505	toolResults := make([]message.ToolResult, len(assistantMsg.ToolCalls()))
506	toolCalls := assistantMsg.ToolCalls()
507	for i, toolCall := range toolCalls {
508		select {
509		case <-ctx.Done():
510			a.finishMessage(context.Background(), &assistantMsg, message.FinishReasonCanceled, "Request cancelled", "")
511			// Make all future tool calls cancelled
512			for j := i; j < len(toolCalls); j++ {
513				toolResults[j] = message.ToolResult{
514					ToolCallID: toolCalls[j].ID,
515					Content:    "Tool execution canceled by user",
516					IsError:    true,
517				}
518			}
519			goto out
520		default:
521			// Continue processing
522			var tool tools.BaseTool
523			for availableTool := range a.tools.Seq() {
524				if availableTool.Info().Name == toolCall.Name {
525					tool = availableTool
526					break
527				}
528			}
529
530			// Tool not found
531			if tool == nil {
532				toolResults[i] = message.ToolResult{
533					ToolCallID: toolCall.ID,
534					Content:    fmt.Sprintf("Tool not found: %s", toolCall.Name),
535					IsError:    true,
536				}
537				continue
538			}
539
540			// Run tool in goroutine to allow cancellation
541			type toolExecResult struct {
542				response tools.ToolResponse
543				err      error
544			}
545			resultChan := make(chan toolExecResult, 1)
546
547			go func() {
548				response, err := tool.Run(ctx, tools.ToolCall{
549					ID:    toolCall.ID,
550					Name:  toolCall.Name,
551					Input: toolCall.Input,
552				})
553				resultChan <- toolExecResult{response: response, err: err}
554			}()
555
556			var toolResponse tools.ToolResponse
557			var toolErr error
558
559			select {
560			case <-ctx.Done():
561				a.finishMessage(context.Background(), &assistantMsg, message.FinishReasonCanceled, "Request cancelled", "")
562				// Mark remaining tool calls as cancelled
563				for j := i; j < len(toolCalls); j++ {
564					toolResults[j] = message.ToolResult{
565						ToolCallID: toolCalls[j].ID,
566						Content:    "Tool execution canceled by user",
567						IsError:    true,
568					}
569				}
570				goto out
571			case result := <-resultChan:
572				toolResponse = result.response
573				toolErr = result.err
574			}
575
576			if toolErr != nil {
577				slog.Error("Tool execution error", "toolCall", toolCall.ID, "error", toolErr)
578				if errors.Is(toolErr, permission.ErrorPermissionDenied) {
579					toolResults[i] = message.ToolResult{
580						ToolCallID: toolCall.ID,
581						Content:    "Permission denied",
582						IsError:    true,
583					}
584					for j := i + 1; j < len(toolCalls); j++ {
585						toolResults[j] = message.ToolResult{
586							ToolCallID: toolCalls[j].ID,
587							Content:    "Tool execution canceled by user",
588							IsError:    true,
589						}
590					}
591					a.finishMessage(ctx, &assistantMsg, message.FinishReasonPermissionDenied, "Permission denied", "")
592					break
593				}
594			}
595			toolResults[i] = message.ToolResult{
596				ToolCallID: toolCall.ID,
597				Content:    toolResponse.Content,
598				Metadata:   toolResponse.Metadata,
599				IsError:    toolResponse.IsError,
600			}
601		}
602	}
603out:
604	if len(toolResults) == 0 {
605		return assistantMsg, nil, nil
606	}
607	parts := make([]message.ContentPart, 0)
608	for _, tr := range toolResults {
609		parts = append(parts, tr)
610	}
611	msg, err := a.messages.Create(context.Background(), assistantMsg.SessionID, message.CreateMessageParams{
612		Role:     message.Tool,
613		Parts:    parts,
614		Provider: a.providerID,
615	})
616	if err != nil {
617		return assistantMsg, nil, fmt.Errorf("failed to create cancelled tool message: %w", err)
618	}
619
620	return assistantMsg, &msg, err
621}
622
623func (a *agent) finishMessage(ctx context.Context, msg *message.Message, finishReason message.FinishReason, message, details string) {
624	msg.AddFinish(finishReason, message, details)
625	_ = a.messages.Update(ctx, *msg)
626}
627
628func (a *agent) processEvent(ctx context.Context, sessionID string, assistantMsg *message.Message, event provider.ProviderEvent) error {
629	select {
630	case <-ctx.Done():
631		return ctx.Err()
632	default:
633		// Continue processing.
634	}
635
636	switch event.Type {
637	case provider.EventThinkingDelta:
638		assistantMsg.AppendReasoningContent(event.Thinking)
639		return a.messages.Update(ctx, *assistantMsg)
640	case provider.EventSignatureDelta:
641		assistantMsg.AppendReasoningSignature(event.Signature)
642		return a.messages.Update(ctx, *assistantMsg)
643	case provider.EventContentDelta:
644		assistantMsg.FinishThinking()
645		assistantMsg.AppendContent(event.Content)
646		return a.messages.Update(ctx, *assistantMsg)
647	case provider.EventToolUseStart:
648		assistantMsg.FinishThinking()
649		slog.Info("Tool call started", "toolCall", event.ToolCall)
650		assistantMsg.AddToolCall(*event.ToolCall)
651		return a.messages.Update(ctx, *assistantMsg)
652	case provider.EventToolUseDelta:
653		assistantMsg.AppendToolCallInput(event.ToolCall.ID, event.ToolCall.Input)
654		return a.messages.Update(ctx, *assistantMsg)
655	case provider.EventToolUseStop:
656		slog.Info("Finished tool call", "toolCall", event.ToolCall)
657		assistantMsg.FinishToolCall(event.ToolCall.ID)
658		return a.messages.Update(ctx, *assistantMsg)
659	case provider.EventError:
660		return event.Error
661	case provider.EventComplete:
662		assistantMsg.FinishThinking()
663		assistantMsg.SetToolCalls(event.Response.ToolCalls)
664		assistantMsg.AddFinish(event.Response.FinishReason, "", "")
665		if err := a.messages.Update(ctx, *assistantMsg); err != nil {
666			return fmt.Errorf("failed to update message: %w", err)
667		}
668		return a.TrackUsage(ctx, sessionID, a.Model(), event.Response.Usage)
669	}
670
671	return nil
672}
673
674func (a *agent) TrackUsage(ctx context.Context, sessionID string, model catwalk.Model, usage provider.TokenUsage) error {
675	sess, err := a.sessions.Get(ctx, sessionID)
676	if err != nil {
677		return fmt.Errorf("failed to get session: %w", err)
678	}
679
680	cost := model.CostPer1MInCached/1e6*float64(usage.CacheCreationTokens) +
681		model.CostPer1MOutCached/1e6*float64(usage.CacheReadTokens) +
682		model.CostPer1MIn/1e6*float64(usage.InputTokens) +
683		model.CostPer1MOut/1e6*float64(usage.OutputTokens)
684
685	sess.Cost += cost
686	sess.CompletionTokens = usage.OutputTokens + usage.CacheReadTokens
687	sess.PromptTokens = usage.InputTokens + usage.CacheCreationTokens
688
689	_, err = a.sessions.Save(ctx, sess)
690	if err != nil {
691		return fmt.Errorf("failed to save session: %w", err)
692	}
693	return nil
694}
695
696func (a *agent) Summarize(ctx context.Context, sessionID string) error {
697	if a.summarizeProvider == nil {
698		return fmt.Errorf("summarize provider not available")
699	}
700
701	// Check if session is busy
702	if a.IsSessionBusy(sessionID) {
703		return ErrSessionBusy
704	}
705
706	// Create a new context with cancellation
707	summarizeCtx, cancel := context.WithCancel(ctx)
708
709	// Store the cancel function in activeRequests to allow cancellation
710	a.activeRequests.Set(sessionID+"-summarize", cancel)
711
712	go func() {
713		defer a.activeRequests.Del(sessionID + "-summarize")
714		defer cancel()
715		event := AgentEvent{
716			Type:     AgentEventTypeSummarize,
717			Progress: "Starting summarization...",
718		}
719
720		a.Publish(pubsub.CreatedEvent, event)
721		// Get all messages from the session
722		msgs, err := a.messages.List(summarizeCtx, sessionID)
723		if err != nil {
724			event = AgentEvent{
725				Type:  AgentEventTypeError,
726				Error: fmt.Errorf("failed to list messages: %w", err),
727				Done:  true,
728			}
729			a.Publish(pubsub.CreatedEvent, event)
730			return
731		}
732		summarizeCtx = context.WithValue(summarizeCtx, tools.SessionIDContextKey, sessionID)
733
734		if len(msgs) == 0 {
735			event = AgentEvent{
736				Type:  AgentEventTypeError,
737				Error: fmt.Errorf("no messages to summarize"),
738				Done:  true,
739			}
740			a.Publish(pubsub.CreatedEvent, event)
741			return
742		}
743
744		event = AgentEvent{
745			Type:     AgentEventTypeSummarize,
746			Progress: "Analyzing conversation...",
747		}
748		a.Publish(pubsub.CreatedEvent, event)
749
750		// Add a system message to guide the summarization
751		summarizePrompt := "Provide a detailed but concise summary of our conversation above. Focus on information that would be helpful for continuing the conversation, including what we did, what we're doing, which files we're working on, and what we're going to do next."
752
753		// Create a new message with the summarize prompt
754		promptMsg := message.Message{
755			Role:  message.User,
756			Parts: []message.ContentPart{message.TextContent{Text: summarizePrompt}},
757		}
758
759		// Append the prompt to the messages
760		msgsWithPrompt := append(msgs, promptMsg)
761
762		event = AgentEvent{
763			Type:     AgentEventTypeSummarize,
764			Progress: "Generating summary...",
765		}
766
767		a.Publish(pubsub.CreatedEvent, event)
768
769		// Send the messages to the summarize provider
770		response := a.summarizeProvider.StreamResponse(
771			summarizeCtx,
772			msgsWithPrompt,
773			nil,
774		)
775		var finalResponse *provider.ProviderResponse
776		for r := range response {
777			if r.Error != nil {
778				event = AgentEvent{
779					Type:  AgentEventTypeError,
780					Error: fmt.Errorf("failed to summarize: %w", err),
781					Done:  true,
782				}
783				a.Publish(pubsub.CreatedEvent, event)
784				return
785			}
786			finalResponse = r.Response
787		}
788
789		summary := strings.TrimSpace(finalResponse.Content)
790		if summary == "" {
791			event = AgentEvent{
792				Type:  AgentEventTypeError,
793				Error: fmt.Errorf("empty summary returned"),
794				Done:  true,
795			}
796			a.Publish(pubsub.CreatedEvent, event)
797			return
798		}
799		shell := shell.GetPersistentShell(config.Get().WorkingDir())
800		summary += "\n\n**Current working directory of the persistent shell**\n\n" + shell.GetWorkingDir()
801		event = AgentEvent{
802			Type:     AgentEventTypeSummarize,
803			Progress: "Creating new session...",
804		}
805
806		a.Publish(pubsub.CreatedEvent, event)
807		oldSession, err := a.sessions.Get(summarizeCtx, sessionID)
808		if err != nil {
809			event = AgentEvent{
810				Type:  AgentEventTypeError,
811				Error: fmt.Errorf("failed to get session: %w", err),
812				Done:  true,
813			}
814
815			a.Publish(pubsub.CreatedEvent, event)
816			return
817		}
818		// Create a message in the new session with the summary
819		msg, err := a.messages.Create(summarizeCtx, oldSession.ID, message.CreateMessageParams{
820			Role: message.Assistant,
821			Parts: []message.ContentPart{
822				message.TextContent{Text: summary},
823				message.Finish{
824					Reason: message.FinishReasonEndTurn,
825					Time:   time.Now().Unix(),
826				},
827			},
828			Model:    a.summarizeProvider.Model().ID,
829			Provider: a.summarizeProviderID,
830		})
831		if err != nil {
832			event = AgentEvent{
833				Type:  AgentEventTypeError,
834				Error: fmt.Errorf("failed to create summary message: %w", err),
835				Done:  true,
836			}
837
838			a.Publish(pubsub.CreatedEvent, event)
839			return
840		}
841		oldSession.SummaryMessageID = msg.ID
842		oldSession.CompletionTokens = finalResponse.Usage.OutputTokens
843		oldSession.PromptTokens = 0
844		model := a.summarizeProvider.Model()
845		usage := finalResponse.Usage
846		cost := model.CostPer1MInCached/1e6*float64(usage.CacheCreationTokens) +
847			model.CostPer1MOutCached/1e6*float64(usage.CacheReadTokens) +
848			model.CostPer1MIn/1e6*float64(usage.InputTokens) +
849			model.CostPer1MOut/1e6*float64(usage.OutputTokens)
850		oldSession.Cost += cost
851		_, err = a.sessions.Save(summarizeCtx, oldSession)
852		if err != nil {
853			event = AgentEvent{
854				Type:  AgentEventTypeError,
855				Error: fmt.Errorf("failed to save session: %w", err),
856				Done:  true,
857			}
858			a.Publish(pubsub.CreatedEvent, event)
859		}
860
861		event = AgentEvent{
862			Type:      AgentEventTypeSummarize,
863			SessionID: oldSession.ID,
864			Progress:  "Summary complete",
865			Done:      true,
866		}
867		a.Publish(pubsub.CreatedEvent, event)
868		// Send final success event with the new session ID
869	}()
870
871	return nil
872}
873
874func (a *agent) CancelAll() {
875	if !a.IsBusy() {
876		return
877	}
878	for key := range a.activeRequests.Seq2() {
879		a.Cancel(key) // key is sessionID
880	}
881
882	timeout := time.After(5 * time.Second)
883	for a.IsBusy() {
884		select {
885		case <-timeout:
886			return
887		default:
888			time.Sleep(200 * time.Millisecond)
889		}
890	}
891}
892
893func (a *agent) UpdateModel() error {
894	cfg := config.Get()
895
896	// Get current provider configuration
897	currentProviderCfg := cfg.GetProviderForModel(a.agentCfg.Model)
898	if currentProviderCfg == nil || currentProviderCfg.ID == "" {
899		return fmt.Errorf("provider for agent %s not found in config", a.agentCfg.Name)
900	}
901
902	// Check if provider has changed
903	if string(currentProviderCfg.ID) != a.providerID {
904		// Provider changed, need to recreate the main provider
905		model := cfg.GetModelByType(a.agentCfg.Model)
906		if model.ID == "" {
907			return fmt.Errorf("model not found for agent %s", a.agentCfg.Name)
908		}
909
910		promptID := agentPromptMap[a.agentCfg.ID]
911		if promptID == "" {
912			promptID = prompt.PromptDefault
913		}
914
915		opts := []provider.ProviderClientOption{
916			provider.WithModel(a.agentCfg.Model),
917			provider.WithSystemMessage(prompt.GetPrompt(promptID, currentProviderCfg.ID, cfg.Options.ContextPaths...)),
918		}
919
920		newProvider, err := provider.NewProvider(*currentProviderCfg, opts...)
921		if err != nil {
922			return fmt.Errorf("failed to create new provider: %w", err)
923		}
924
925		// Update the provider and provider ID
926		a.provider = newProvider
927		a.providerID = string(currentProviderCfg.ID)
928	}
929
930	// Check if small model provider has changed (affects title and summarize providers)
931	smallModelCfg := cfg.Models[config.SelectedModelTypeSmall]
932	var smallModelProviderCfg config.ProviderConfig
933
934	for p := range cfg.Providers.Seq() {
935		if p.ID == smallModelCfg.Provider {
936			smallModelProviderCfg = p
937			break
938		}
939	}
940
941	if smallModelProviderCfg.ID == "" {
942		return fmt.Errorf("provider %s not found in config", smallModelCfg.Provider)
943	}
944
945	// Check if summarize provider has changed
946	if string(smallModelProviderCfg.ID) != a.summarizeProviderID {
947		smallModel := cfg.GetModelByType(config.SelectedModelTypeSmall)
948		if smallModel == nil {
949			return fmt.Errorf("model %s not found in provider %s", smallModelCfg.Model, smallModelProviderCfg.ID)
950		}
951
952		// Recreate title provider
953		titleOpts := []provider.ProviderClientOption{
954			provider.WithModel(config.SelectedModelTypeSmall),
955			provider.WithSystemMessage(prompt.GetPrompt(prompt.PromptTitle, smallModelProviderCfg.ID)),
956			// We want the title to be short, so we limit the max tokens
957			provider.WithMaxTokens(40),
958		}
959		newTitleProvider, err := provider.NewProvider(smallModelProviderCfg, titleOpts...)
960		if err != nil {
961			return fmt.Errorf("failed to create new title provider: %w", err)
962		}
963
964		// Recreate summarize provider
965		summarizeOpts := []provider.ProviderClientOption{
966			provider.WithModel(config.SelectedModelTypeSmall),
967			provider.WithSystemMessage(prompt.GetPrompt(prompt.PromptSummarizer, smallModelProviderCfg.ID)),
968		}
969		newSummarizeProvider, err := provider.NewProvider(smallModelProviderCfg, summarizeOpts...)
970		if err != nil {
971			return fmt.Errorf("failed to create new summarize provider: %w", err)
972		}
973
974		// Update the providers and provider ID
975		a.titleProvider = newTitleProvider
976		a.summarizeProvider = newSummarizeProvider
977		a.summarizeProviderID = string(smallModelProviderCfg.ID)
978	}
979
980	return nil
981}