1package agent
  2
  3import (
  4	"context"
  5	"errors"
  6	"fmt"
  7	"log/slog"
  8	"slices"
  9	"strings"
 10	"sync"
 11	"time"
 12
 13	"github.com/charmbracelet/catwalk/pkg/catwalk"
 14	"github.com/charmbracelet/crush/internal/config"
 15	"github.com/charmbracelet/crush/internal/history"
 16	"github.com/charmbracelet/crush/internal/llm/prompt"
 17	"github.com/charmbracelet/crush/internal/llm/provider"
 18	"github.com/charmbracelet/crush/internal/llm/tools"
 19	"github.com/charmbracelet/crush/internal/log"
 20	"github.com/charmbracelet/crush/internal/lsp"
 21	"github.com/charmbracelet/crush/internal/message"
 22	"github.com/charmbracelet/crush/internal/permission"
 23	"github.com/charmbracelet/crush/internal/pubsub"
 24	"github.com/charmbracelet/crush/internal/session"
 25	"github.com/charmbracelet/crush/internal/shell"
 26)
 27
 28// Common errors
 29var (
 30	ErrRequestCancelled = errors.New("request canceled by user")
 31	ErrSessionBusy      = errors.New("session is currently processing another request")
 32)
 33
 34type AgentEventType string
 35
 36const (
 37	AgentEventTypeError     AgentEventType = "error"
 38	AgentEventTypeResponse  AgentEventType = "response"
 39	AgentEventTypeSummarize AgentEventType = "summarize"
 40)
 41
// AgentEvent is produced by the agent. Run delivers its terminal event on the
// channel it returns and also publishes it on the broker; Summarize reports
// only through the broker.
type AgentEvent struct {
	// Type tells which of the fields below are meaningful.
	Type AgentEventType
	// Message is the final assistant message of a successful Run
	// (AgentEventTypeResponse).
	Message message.Message
	// Error is set on AgentEventTypeError events.
	Error error

	// When summarizing
	// SessionID is filled in on the final successful summarize event.
	SessionID string
	// Progress is a human-readable status line for AgentEventTypeSummarize events.
	Progress string
	// Done is set on Run's successful response and on Summarize's terminal
	// events (both success and error). Run's error events leave it false.
	Done bool
}
 52
 53type Service interface {
 54	pubsub.Suscriber[AgentEvent]
 55	Model() catwalk.Model
 56	Run(ctx context.Context, sessionID string, content string, attachments ...message.Attachment) (<-chan AgentEvent, error)
 57	Cancel(sessionID string)
 58	CancelAll()
 59	IsSessionBusy(sessionID string) bool
 60	IsBusy() bool
 61	Summarize(ctx context.Context, sessionID string) error
 62	UpdateModel() error
 63}
 64
// agent is the Service implementation built by NewAgent.
//
// NOTE(review): UpdateModel reassigns the provider fields below without any
// locking while requests may be reading them from other goroutines — confirm
// callers only invoke UpdateModel while the agent is idle.
type agent struct {
	*pubsub.Broker[AgentEvent]
	agentCfg config.Agent
	sessions session.Service
	messages message.Service

	// tools are the tools exposed to the model, already filtered by
	// agentCfg.AllowedTools.
	tools      []tools.BaseTool
	provider   provider.Provider
	providerID string

	// Providers built on the small model, used for session titles and
	// summaries. summarizeProviderID records which provider config they came from.
	titleProvider       provider.Provider
	summarizeProvider   provider.Provider
	summarizeProviderID string

	// activeRequests maps a session ID (Run) or "<sessionID>-summarize"
	// (Summarize) to the context.CancelFunc of the in-flight request.
	activeRequests sync.Map
}
 81
 82var agentPromptMap = map[string]prompt.PromptID{
 83	"coder": prompt.PromptCoder,
 84	"task":  prompt.PromptTask,
 85}
 86
 87func NewAgent(
 88	agentCfg config.Agent,
 89	// These services are needed in the tools
 90	permissions permission.Service,
 91	sessions session.Service,
 92	messages message.Service,
 93	history history.Service,
 94	lspClients map[string]*lsp.Client,
 95) (Service, error) {
 96	ctx := context.Background()
 97	cfg := config.Get()
 98	otherTools := GetMCPTools(ctx, permissions, cfg)
 99	if len(lspClients) > 0 {
100		otherTools = append(otherTools, tools.NewDiagnosticsTool(lspClients))
101	}
102
103	cwd := cfg.WorkingDir()
104	allTools := []tools.BaseTool{
105		tools.NewBashTool(permissions, cwd),
106		tools.NewDownloadTool(permissions, cwd),
107		tools.NewEditTool(lspClients, permissions, history, cwd),
108		tools.NewFetchTool(permissions, cwd),
109		tools.NewGlobTool(cwd),
110		tools.NewGrepTool(cwd),
111		tools.NewLsTool(cwd),
112		tools.NewSourcegraphTool(),
113		tools.NewViewTool(lspClients, cwd),
114		tools.NewWriteTool(lspClients, permissions, history, cwd),
115	}
116
117	if agentCfg.ID == "coder" {
118		taskAgentCfg := config.Get().Agents["task"]
119		if taskAgentCfg.ID == "" {
120			return nil, fmt.Errorf("task agent not found in config")
121		}
122		taskAgent, err := NewAgent(taskAgentCfg, permissions, sessions, messages, history, lspClients)
123		if err != nil {
124			return nil, fmt.Errorf("failed to create task agent: %w", err)
125		}
126
127		allTools = append(
128			allTools,
129			NewAgentTool(
130				taskAgent,
131				sessions,
132				messages,
133			),
134		)
135	}
136
137	allTools = append(allTools, otherTools...)
138	providerCfg := config.Get().GetProviderForModel(agentCfg.Model)
139	if providerCfg == nil {
140		return nil, fmt.Errorf("provider for agent %s not found in config", agentCfg.Name)
141	}
142	model := config.Get().GetModelByType(agentCfg.Model)
143
144	if model == nil {
145		return nil, fmt.Errorf("model not found for agent %s", agentCfg.Name)
146	}
147
148	promptID := agentPromptMap[agentCfg.ID]
149	if promptID == "" {
150		promptID = prompt.PromptDefault
151	}
152	opts := []provider.ProviderClientOption{
153		provider.WithModel(agentCfg.Model),
154		provider.WithSystemMessage(prompt.GetPrompt(promptID, providerCfg.ID, config.Get().Options.ContextPaths...)),
155	}
156	agentProvider, err := provider.NewProvider(*providerCfg, opts...)
157	if err != nil {
158		return nil, err
159	}
160
161	smallModelCfg := cfg.Models[config.SelectedModelTypeSmall]
162	var smallModelProviderCfg *config.ProviderConfig
163	if smallModelCfg.Provider == providerCfg.ID {
164		smallModelProviderCfg = providerCfg
165	} else {
166		smallModelProviderCfg = cfg.GetProviderForModel(config.SelectedModelTypeSmall)
167
168		if smallModelProviderCfg.ID == "" {
169			return nil, fmt.Errorf("provider %s not found in config", smallModelCfg.Provider)
170		}
171	}
172	smallModel := cfg.GetModelByType(config.SelectedModelTypeSmall)
173	if smallModel.ID == "" {
174		return nil, fmt.Errorf("model %s not found in provider %s", smallModelCfg.Model, smallModelProviderCfg.ID)
175	}
176
177	titleOpts := []provider.ProviderClientOption{
178		provider.WithModel(config.SelectedModelTypeSmall),
179		provider.WithSystemMessage(prompt.GetPrompt(prompt.PromptTitle, smallModelProviderCfg.ID)),
180	}
181	titleProvider, err := provider.NewProvider(*smallModelProviderCfg, titleOpts...)
182	if err != nil {
183		return nil, err
184	}
185	summarizeOpts := []provider.ProviderClientOption{
186		provider.WithModel(config.SelectedModelTypeSmall),
187		provider.WithSystemMessage(prompt.GetPrompt(prompt.PromptSummarizer, smallModelProviderCfg.ID)),
188	}
189	summarizeProvider, err := provider.NewProvider(*smallModelProviderCfg, summarizeOpts...)
190	if err != nil {
191		return nil, err
192	}
193
194	agentTools := []tools.BaseTool{}
195	if agentCfg.AllowedTools == nil {
196		agentTools = allTools
197	} else {
198		for _, tool := range allTools {
199			if slices.Contains(agentCfg.AllowedTools, tool.Name()) {
200				agentTools = append(agentTools, tool)
201			}
202		}
203	}
204
205	agent := &agent{
206		Broker:              pubsub.NewBroker[AgentEvent](),
207		agentCfg:            agentCfg,
208		provider:            agentProvider,
209		providerID:          string(providerCfg.ID),
210		messages:            messages,
211		sessions:            sessions,
212		tools:               agentTools,
213		titleProvider:       titleProvider,
214		summarizeProvider:   summarizeProvider,
215		summarizeProviderID: string(smallModelProviderCfg.ID),
216		activeRequests:      sync.Map{},
217	}
218
219	return agent, nil
220}
221
222func (a *agent) Model() catwalk.Model {
223	return *config.Get().GetModelByType(a.agentCfg.Model)
224}
225
226func (a *agent) Cancel(sessionID string) {
227	// Cancel regular requests
228	if cancelFunc, exists := a.activeRequests.LoadAndDelete(sessionID); exists {
229		if cancel, ok := cancelFunc.(context.CancelFunc); ok {
230			slog.Info(fmt.Sprintf("Request cancellation initiated for session: %s", sessionID))
231			cancel()
232		}
233	}
234
235	// Also check for summarize requests
236	if cancelFunc, exists := a.activeRequests.LoadAndDelete(sessionID + "-summarize"); exists {
237		if cancel, ok := cancelFunc.(context.CancelFunc); ok {
238			slog.Info(fmt.Sprintf("Summarize cancellation initiated for session: %s", sessionID))
239			cancel()
240		}
241	}
242}
243
244func (a *agent) IsBusy() bool {
245	busy := false
246	a.activeRequests.Range(func(key, value any) bool {
247		if cancelFunc, ok := value.(context.CancelFunc); ok {
248			if cancelFunc != nil {
249				busy = true
250				return false
251			}
252		}
253		return true
254	})
255	return busy
256}
257
258func (a *agent) IsSessionBusy(sessionID string) bool {
259	_, busy := a.activeRequests.Load(sessionID)
260	return busy
261}
262
263func (a *agent) generateTitle(ctx context.Context, sessionID string, content string) error {
264	if content == "" {
265		return nil
266	}
267	if a.titleProvider == nil {
268		return nil
269	}
270	session, err := a.sessions.Get(ctx, sessionID)
271	if err != nil {
272		return err
273	}
274	parts := []message.ContentPart{message.TextContent{
275		Text: fmt.Sprintf("Generate a concise title for the following content:\n\n%s", content),
276	}}
277
278	// Use streaming approach like summarization
279	response := a.titleProvider.StreamResponse(
280		ctx,
281		[]message.Message{
282			{
283				Role:  message.User,
284				Parts: parts,
285			},
286		},
287		make([]tools.BaseTool, 0),
288	)
289
290	var finalResponse *provider.ProviderResponse
291	for r := range response {
292		if r.Error != nil {
293			return r.Error
294		}
295		finalResponse = r.Response
296	}
297
298	if finalResponse == nil {
299		return fmt.Errorf("no response received from title provider")
300	}
301
302	title := strings.TrimSpace(strings.ReplaceAll(finalResponse.Content, "\n", " "))
303	if title == "" {
304		return nil
305	}
306
307	session.Title = title
308	_, err = a.sessions.Save(ctx, session)
309	return err
310}
311
312func (a *agent) err(err error) AgentEvent {
313	return AgentEvent{
314		Type:  AgentEventTypeError,
315		Error: err,
316	}
317}
318
319func (a *agent) Run(ctx context.Context, sessionID string, content string, attachments ...message.Attachment) (<-chan AgentEvent, error) {
320	if !a.Model().SupportsImages && attachments != nil {
321		attachments = nil
322	}
323	events := make(chan AgentEvent)
324	if a.IsSessionBusy(sessionID) {
325		return nil, ErrSessionBusy
326	}
327
328	genCtx, cancel := context.WithCancel(ctx)
329
330	a.activeRequests.Store(sessionID, cancel)
331	go func() {
332		slog.Debug("Request started", "sessionID", sessionID)
333		defer log.RecoverPanic("agent.Run", func() {
334			events <- a.err(fmt.Errorf("panic while running the agent"))
335		})
336		var attachmentParts []message.ContentPart
337		for _, attachment := range attachments {
338			attachmentParts = append(attachmentParts, message.BinaryContent{Path: attachment.FilePath, MIMEType: attachment.MimeType, Data: attachment.Content})
339		}
340		result := a.processGeneration(genCtx, sessionID, content, attachmentParts)
341		if result.Error != nil && !errors.Is(result.Error, ErrRequestCancelled) && !errors.Is(result.Error, context.Canceled) {
342			slog.Error(result.Error.Error())
343		}
344		slog.Debug("Request completed", "sessionID", sessionID)
345		a.activeRequests.Delete(sessionID)
346		cancel()
347		a.Publish(pubsub.CreatedEvent, result)
348		events <- result
349		close(events)
350	}()
351	return events, nil
352}
353
354func (a *agent) processGeneration(ctx context.Context, sessionID, content string, attachmentParts []message.ContentPart) AgentEvent {
355	cfg := config.Get()
356	// List existing messages; if none, start title generation asynchronously.
357	msgs, err := a.messages.List(ctx, sessionID)
358	if err != nil {
359		return a.err(fmt.Errorf("failed to list messages: %w", err))
360	}
361	if len(msgs) == 0 {
362		go func() {
363			defer log.RecoverPanic("agent.Run", func() {
364				slog.Error("panic while generating title")
365			})
366			titleErr := a.generateTitle(context.Background(), sessionID, content)
367			if titleErr != nil && !errors.Is(titleErr, context.Canceled) && !errors.Is(titleErr, context.DeadlineExceeded) {
368				slog.Error(fmt.Sprintf("failed to generate title: %v", titleErr))
369			}
370		}()
371	}
372	session, err := a.sessions.Get(ctx, sessionID)
373	if err != nil {
374		return a.err(fmt.Errorf("failed to get session: %w", err))
375	}
376	if session.SummaryMessageID != "" {
377		summaryMsgInex := -1
378		for i, msg := range msgs {
379			if msg.ID == session.SummaryMessageID {
380				summaryMsgInex = i
381				break
382			}
383		}
384		if summaryMsgInex != -1 {
385			msgs = msgs[summaryMsgInex:]
386			msgs[0].Role = message.User
387		}
388	}
389
390	userMsg, err := a.createUserMessage(ctx, sessionID, content, attachmentParts)
391	if err != nil {
392		return a.err(fmt.Errorf("failed to create user message: %w", err))
393	}
394	// Append the new user message to the conversation history.
395	msgHistory := append(msgs, userMsg)
396
397	for {
398		// Check for cancellation before each iteration
399		select {
400		case <-ctx.Done():
401			return a.err(ctx.Err())
402		default:
403			// Continue processing
404		}
405		agentMessage, toolResults, err := a.streamAndHandleEvents(ctx, sessionID, msgHistory)
406		if err != nil {
407			if errors.Is(err, context.Canceled) {
408				agentMessage.AddFinish(message.FinishReasonCanceled, "Request cancelled", "")
409				a.messages.Update(context.Background(), agentMessage)
410				return a.err(ErrRequestCancelled)
411			}
412			return a.err(fmt.Errorf("failed to process events: %w", err))
413		}
414		if cfg.Options.Debug {
415			slog.Info("Result", "message", agentMessage.FinishReason(), "toolResults", toolResults)
416		}
417		if (agentMessage.FinishReason() == message.FinishReasonToolUse) && toolResults != nil {
418			// We are not done, we need to respond with the tool response
419			msgHistory = append(msgHistory, agentMessage, *toolResults)
420			continue
421		}
422		return AgentEvent{
423			Type:    AgentEventTypeResponse,
424			Message: agentMessage,
425			Done:    true,
426		}
427	}
428}
429
430func (a *agent) createUserMessage(ctx context.Context, sessionID, content string, attachmentParts []message.ContentPart) (message.Message, error) {
431	parts := []message.ContentPart{message.TextContent{Text: content}}
432	parts = append(parts, attachmentParts...)
433	return a.messages.Create(ctx, sessionID, message.CreateMessageParams{
434		Role:  message.User,
435		Parts: parts,
436	})
437}
438
439func (a *agent) streamAndHandleEvents(ctx context.Context, sessionID string, msgHistory []message.Message) (message.Message, *message.Message, error) {
440	ctx = context.WithValue(ctx, tools.SessionIDContextKey, sessionID)
441	eventChan := a.provider.StreamResponse(ctx, msgHistory, a.tools)
442
443	assistantMsg, err := a.messages.Create(ctx, sessionID, message.CreateMessageParams{
444		Role:     message.Assistant,
445		Parts:    []message.ContentPart{},
446		Model:    a.Model().ID,
447		Provider: a.providerID,
448	})
449	if err != nil {
450		return assistantMsg, nil, fmt.Errorf("failed to create assistant message: %w", err)
451	}
452
453	// Add the session and message ID into the context if needed by tools.
454	ctx = context.WithValue(ctx, tools.MessageIDContextKey, assistantMsg.ID)
455
456	// Process each event in the stream.
457	for event := range eventChan {
458		if processErr := a.processEvent(ctx, sessionID, &assistantMsg, event); processErr != nil {
459			if errors.Is(processErr, context.Canceled) {
460				a.finishMessage(context.Background(), &assistantMsg, message.FinishReasonCanceled, "Request cancelled", "")
461			} else {
462				a.finishMessage(ctx, &assistantMsg, message.FinishReasonError, "API Error", processErr.Error())
463			}
464			return assistantMsg, nil, processErr
465		}
466		if ctx.Err() != nil {
467			a.finishMessage(context.Background(), &assistantMsg, message.FinishReasonCanceled, "Request cancelled", "")
468			return assistantMsg, nil, ctx.Err()
469		}
470	}
471
472	toolResults := make([]message.ToolResult, len(assistantMsg.ToolCalls()))
473	toolCalls := assistantMsg.ToolCalls()
474	for i, toolCall := range toolCalls {
475		select {
476		case <-ctx.Done():
477			a.finishMessage(context.Background(), &assistantMsg, message.FinishReasonCanceled, "Request cancelled", "")
478			// Make all future tool calls cancelled
479			for j := i; j < len(toolCalls); j++ {
480				toolResults[j] = message.ToolResult{
481					ToolCallID: toolCalls[j].ID,
482					Content:    "Tool execution canceled by user",
483					IsError:    true,
484				}
485			}
486			goto out
487		default:
488			// Continue processing
489			var tool tools.BaseTool
490			for _, availableTool := range a.tools {
491				if availableTool.Info().Name == toolCall.Name {
492					tool = availableTool
493					break
494				}
495			}
496
497			// Tool not found
498			if tool == nil {
499				toolResults[i] = message.ToolResult{
500					ToolCallID: toolCall.ID,
501					Content:    fmt.Sprintf("Tool not found: %s", toolCall.Name),
502					IsError:    true,
503				}
504				continue
505			}
506
507			// Run tool in goroutine to allow cancellation
508			type toolExecResult struct {
509				response tools.ToolResponse
510				err      error
511			}
512			resultChan := make(chan toolExecResult, 1)
513
514			go func() {
515				response, err := tool.Run(ctx, tools.ToolCall{
516					ID:    toolCall.ID,
517					Name:  toolCall.Name,
518					Input: toolCall.Input,
519				})
520				resultChan <- toolExecResult{response: response, err: err}
521			}()
522
523			var toolResponse tools.ToolResponse
524			var toolErr error
525
526			select {
527			case <-ctx.Done():
528				a.finishMessage(context.Background(), &assistantMsg, message.FinishReasonCanceled, "Request cancelled", "")
529				// Mark remaining tool calls as cancelled
530				for j := i; j < len(toolCalls); j++ {
531					toolResults[j] = message.ToolResult{
532						ToolCallID: toolCalls[j].ID,
533						Content:    "Tool execution canceled by user",
534						IsError:    true,
535					}
536				}
537				goto out
538			case result := <-resultChan:
539				toolResponse = result.response
540				toolErr = result.err
541			}
542
543			if toolErr != nil {
544				slog.Error("Tool execution error", "toolCall", toolCall.ID, "error", toolErr)
545				if errors.Is(toolErr, permission.ErrorPermissionDenied) {
546					toolResults[i] = message.ToolResult{
547						ToolCallID: toolCall.ID,
548						Content:    "Permission denied",
549						IsError:    true,
550					}
551					for j := i + 1; j < len(toolCalls); j++ {
552						toolResults[j] = message.ToolResult{
553							ToolCallID: toolCalls[j].ID,
554							Content:    "Tool execution canceled by user",
555							IsError:    true,
556						}
557					}
558					a.finishMessage(ctx, &assistantMsg, message.FinishReasonPermissionDenied, "Permission denied", "")
559					break
560				}
561			}
562			toolResults[i] = message.ToolResult{
563				ToolCallID: toolCall.ID,
564				Content:    toolResponse.Content,
565				Metadata:   toolResponse.Metadata,
566				IsError:    toolResponse.IsError,
567			}
568		}
569	}
570out:
571	if len(toolResults) == 0 {
572		return assistantMsg, nil, nil
573	}
574	parts := make([]message.ContentPart, 0)
575	for _, tr := range toolResults {
576		parts = append(parts, tr)
577	}
578	msg, err := a.messages.Create(context.Background(), assistantMsg.SessionID, message.CreateMessageParams{
579		Role:     message.Tool,
580		Parts:    parts,
581		Provider: a.providerID,
582	})
583	if err != nil {
584		return assistantMsg, nil, fmt.Errorf("failed to create cancelled tool message: %w", err)
585	}
586
587	return assistantMsg, &msg, err
588}
589
590func (a *agent) finishMessage(ctx context.Context, msg *message.Message, finishReason message.FinishReason, message, details string) {
591	msg.AddFinish(finishReason, message, details)
592	_ = a.messages.Update(ctx, *msg)
593}
594
595func (a *agent) processEvent(ctx context.Context, sessionID string, assistantMsg *message.Message, event provider.ProviderEvent) error {
596	select {
597	case <-ctx.Done():
598		return ctx.Err()
599	default:
600		// Continue processing.
601	}
602
603	switch event.Type {
604	case provider.EventThinkingDelta:
605		assistantMsg.AppendReasoningContent(event.Thinking)
606		return a.messages.Update(ctx, *assistantMsg)
607	case provider.EventSignatureDelta:
608		assistantMsg.AppendReasoningSignature(event.Signature)
609		return a.messages.Update(ctx, *assistantMsg)
610	case provider.EventContentDelta:
611		assistantMsg.FinishThinking()
612		assistantMsg.AppendContent(event.Content)
613		return a.messages.Update(ctx, *assistantMsg)
614	case provider.EventToolUseStart:
615		assistantMsg.FinishThinking()
616		slog.Info("Tool call started", "toolCall", event.ToolCall)
617		assistantMsg.AddToolCall(*event.ToolCall)
618		return a.messages.Update(ctx, *assistantMsg)
619	case provider.EventToolUseDelta:
620		assistantMsg.AppendToolCallInput(event.ToolCall.ID, event.ToolCall.Input)
621		return a.messages.Update(ctx, *assistantMsg)
622	case provider.EventToolUseStop:
623		slog.Info("Finished tool call", "toolCall", event.ToolCall)
624		assistantMsg.FinishToolCall(event.ToolCall.ID)
625		return a.messages.Update(ctx, *assistantMsg)
626	case provider.EventError:
627		return event.Error
628	case provider.EventComplete:
629		assistantMsg.FinishThinking()
630		assistantMsg.SetToolCalls(event.Response.ToolCalls)
631		assistantMsg.AddFinish(event.Response.FinishReason, "", "")
632		if err := a.messages.Update(ctx, *assistantMsg); err != nil {
633			return fmt.Errorf("failed to update message: %w", err)
634		}
635		return a.TrackUsage(ctx, sessionID, a.Model(), event.Response.Usage)
636	}
637
638	return nil
639}
640
641func (a *agent) TrackUsage(ctx context.Context, sessionID string, model catwalk.Model, usage provider.TokenUsage) error {
642	sess, err := a.sessions.Get(ctx, sessionID)
643	if err != nil {
644		return fmt.Errorf("failed to get session: %w", err)
645	}
646
647	cost := model.CostPer1MInCached/1e6*float64(usage.CacheCreationTokens) +
648		model.CostPer1MOutCached/1e6*float64(usage.CacheReadTokens) +
649		model.CostPer1MIn/1e6*float64(usage.InputTokens) +
650		model.CostPer1MOut/1e6*float64(usage.OutputTokens)
651
652	sess.Cost += cost
653	sess.CompletionTokens = usage.OutputTokens + usage.CacheReadTokens
654	sess.PromptTokens = usage.InputTokens + usage.CacheCreationTokens
655
656	_, err = a.sessions.Save(ctx, sess)
657	if err != nil {
658		return fmt.Errorf("failed to save session: %w", err)
659	}
660	return nil
661}
662
663func (a *agent) Summarize(ctx context.Context, sessionID string) error {
664	if a.summarizeProvider == nil {
665		return fmt.Errorf("summarize provider not available")
666	}
667
668	// Check if session is busy
669	if a.IsSessionBusy(sessionID) {
670		return ErrSessionBusy
671	}
672
673	// Create a new context with cancellation
674	summarizeCtx, cancel := context.WithCancel(ctx)
675
676	// Store the cancel function in activeRequests to allow cancellation
677	a.activeRequests.Store(sessionID+"-summarize", cancel)
678
679	go func() {
680		defer a.activeRequests.Delete(sessionID + "-summarize")
681		defer cancel()
682		event := AgentEvent{
683			Type:     AgentEventTypeSummarize,
684			Progress: "Starting summarization...",
685		}
686
687		a.Publish(pubsub.CreatedEvent, event)
688		// Get all messages from the session
689		msgs, err := a.messages.List(summarizeCtx, sessionID)
690		if err != nil {
691			event = AgentEvent{
692				Type:  AgentEventTypeError,
693				Error: fmt.Errorf("failed to list messages: %w", err),
694				Done:  true,
695			}
696			a.Publish(pubsub.CreatedEvent, event)
697			return
698		}
699		summarizeCtx = context.WithValue(summarizeCtx, tools.SessionIDContextKey, sessionID)
700
701		if len(msgs) == 0 {
702			event = AgentEvent{
703				Type:  AgentEventTypeError,
704				Error: fmt.Errorf("no messages to summarize"),
705				Done:  true,
706			}
707			a.Publish(pubsub.CreatedEvent, event)
708			return
709		}
710
711		event = AgentEvent{
712			Type:     AgentEventTypeSummarize,
713			Progress: "Analyzing conversation...",
714		}
715		a.Publish(pubsub.CreatedEvent, event)
716
717		// Add a system message to guide the summarization
718		summarizePrompt := "Provide a detailed but concise summary of our conversation above. Focus on information that would be helpful for continuing the conversation, including what we did, what we're doing, which files we're working on, and what we're going to do next."
719
720		// Create a new message with the summarize prompt
721		promptMsg := message.Message{
722			Role:  message.User,
723			Parts: []message.ContentPart{message.TextContent{Text: summarizePrompt}},
724		}
725
726		// Append the prompt to the messages
727		msgsWithPrompt := append(msgs, promptMsg)
728
729		event = AgentEvent{
730			Type:     AgentEventTypeSummarize,
731			Progress: "Generating summary...",
732		}
733
734		a.Publish(pubsub.CreatedEvent, event)
735
736		// Send the messages to the summarize provider
737		response := a.summarizeProvider.StreamResponse(
738			summarizeCtx,
739			msgsWithPrompt,
740			make([]tools.BaseTool, 0),
741		)
742		var finalResponse *provider.ProviderResponse
743		for r := range response {
744			if r.Error != nil {
745				event = AgentEvent{
746					Type:  AgentEventTypeError,
747					Error: fmt.Errorf("failed to summarize: %w", err),
748					Done:  true,
749				}
750				a.Publish(pubsub.CreatedEvent, event)
751				return
752			}
753			finalResponse = r.Response
754		}
755
756		summary := strings.TrimSpace(finalResponse.Content)
757		if summary == "" {
758			event = AgentEvent{
759				Type:  AgentEventTypeError,
760				Error: fmt.Errorf("empty summary returned"),
761				Done:  true,
762			}
763			a.Publish(pubsub.CreatedEvent, event)
764			return
765		}
766		shell := shell.GetPersistentShell(config.Get().WorkingDir())
767		summary += "\n\n**Current working directory of the persistent shell**\n\n" + shell.GetWorkingDir()
768		event = AgentEvent{
769			Type:     AgentEventTypeSummarize,
770			Progress: "Creating new session...",
771		}
772
773		a.Publish(pubsub.CreatedEvent, event)
774		oldSession, err := a.sessions.Get(summarizeCtx, sessionID)
775		if err != nil {
776			event = AgentEvent{
777				Type:  AgentEventTypeError,
778				Error: fmt.Errorf("failed to get session: %w", err),
779				Done:  true,
780			}
781
782			a.Publish(pubsub.CreatedEvent, event)
783			return
784		}
785		// Create a message in the new session with the summary
786		msg, err := a.messages.Create(summarizeCtx, oldSession.ID, message.CreateMessageParams{
787			Role: message.Assistant,
788			Parts: []message.ContentPart{
789				message.TextContent{Text: summary},
790				message.Finish{
791					Reason: message.FinishReasonEndTurn,
792					Time:   time.Now().Unix(),
793				},
794			},
795			Model:    a.summarizeProvider.Model().ID,
796			Provider: a.summarizeProviderID,
797		})
798		if err != nil {
799			event = AgentEvent{
800				Type:  AgentEventTypeError,
801				Error: fmt.Errorf("failed to create summary message: %w", err),
802				Done:  true,
803			}
804
805			a.Publish(pubsub.CreatedEvent, event)
806			return
807		}
808		oldSession.SummaryMessageID = msg.ID
809		oldSession.CompletionTokens = finalResponse.Usage.OutputTokens
810		oldSession.PromptTokens = 0
811		model := a.summarizeProvider.Model()
812		usage := finalResponse.Usage
813		cost := model.CostPer1MInCached/1e6*float64(usage.CacheCreationTokens) +
814			model.CostPer1MOutCached/1e6*float64(usage.CacheReadTokens) +
815			model.CostPer1MIn/1e6*float64(usage.InputTokens) +
816			model.CostPer1MOut/1e6*float64(usage.OutputTokens)
817		oldSession.Cost += cost
818		_, err = a.sessions.Save(summarizeCtx, oldSession)
819		if err != nil {
820			event = AgentEvent{
821				Type:  AgentEventTypeError,
822				Error: fmt.Errorf("failed to save session: %w", err),
823				Done:  true,
824			}
825			a.Publish(pubsub.CreatedEvent, event)
826		}
827
828		event = AgentEvent{
829			Type:      AgentEventTypeSummarize,
830			SessionID: oldSession.ID,
831			Progress:  "Summary complete",
832			Done:      true,
833		}
834		a.Publish(pubsub.CreatedEvent, event)
835		// Send final success event with the new session ID
836	}()
837
838	return nil
839}
840
841func (a *agent) CancelAll() {
842	if !a.IsBusy() {
843		return
844	}
845	a.activeRequests.Range(func(key, value any) bool {
846		a.Cancel(key.(string)) // key is sessionID
847		return true
848	})
849
850	timeout := time.After(5 * time.Second)
851	for a.IsBusy() {
852		select {
853		case <-timeout:
854			return
855		default:
856			time.Sleep(200 * time.Millisecond)
857		}
858	}
859}
860
861func (a *agent) UpdateModel() error {
862	cfg := config.Get()
863
864	// Get current provider configuration
865	currentProviderCfg := cfg.GetProviderForModel(a.agentCfg.Model)
866	if currentProviderCfg == nil || currentProviderCfg.ID == "" {
867		return fmt.Errorf("provider for agent %s not found in config", a.agentCfg.Name)
868	}
869
870	// Check if provider has changed
871	if string(currentProviderCfg.ID) != a.providerID {
872		// Provider changed, need to recreate the main provider
873		model := cfg.GetModelByType(a.agentCfg.Model)
874		if model.ID == "" {
875			return fmt.Errorf("model not found for agent %s", a.agentCfg.Name)
876		}
877
878		promptID := agentPromptMap[a.agentCfg.ID]
879		if promptID == "" {
880			promptID = prompt.PromptDefault
881		}
882
883		opts := []provider.ProviderClientOption{
884			provider.WithModel(a.agentCfg.Model),
885			provider.WithSystemMessage(prompt.GetPrompt(promptID, currentProviderCfg.ID, cfg.Options.ContextPaths...)),
886		}
887
888		newProvider, err := provider.NewProvider(*currentProviderCfg, opts...)
889		if err != nil {
890			return fmt.Errorf("failed to create new provider: %w", err)
891		}
892
893		// Update the provider and provider ID
894		a.provider = newProvider
895		a.providerID = string(currentProviderCfg.ID)
896	}
897
898	// Check if small model provider has changed (affects title and summarize providers)
899	smallModelCfg := cfg.Models[config.SelectedModelTypeSmall]
900	var smallModelProviderCfg config.ProviderConfig
901
902	for _, p := range cfg.Providers {
903		if p.ID == smallModelCfg.Provider {
904			smallModelProviderCfg = p
905			break
906		}
907	}
908
909	if smallModelProviderCfg.ID == "" {
910		return fmt.Errorf("provider %s not found in config", smallModelCfg.Provider)
911	}
912
913	// Check if summarize provider has changed
914	if string(smallModelProviderCfg.ID) != a.summarizeProviderID {
915		smallModel := cfg.GetModelByType(config.SelectedModelTypeSmall)
916		if smallModel == nil {
917			return fmt.Errorf("model %s not found in provider %s", smallModelCfg.Model, smallModelProviderCfg.ID)
918		}
919
920		// Recreate title provider
921		titleOpts := []provider.ProviderClientOption{
922			provider.WithModel(config.SelectedModelTypeSmall),
923			provider.WithSystemMessage(prompt.GetPrompt(prompt.PromptTitle, smallModelProviderCfg.ID)),
924			// We want the title to be short, so we limit the max tokens
925			provider.WithMaxTokens(40),
926		}
927		newTitleProvider, err := provider.NewProvider(smallModelProviderCfg, titleOpts...)
928		if err != nil {
929			return fmt.Errorf("failed to create new title provider: %w", err)
930		}
931
932		// Recreate summarize provider
933		summarizeOpts := []provider.ProviderClientOption{
934			provider.WithModel(config.SelectedModelTypeSmall),
935			provider.WithSystemMessage(prompt.GetPrompt(prompt.PromptSummarizer, smallModelProviderCfg.ID)),
936		}
937		newSummarizeProvider, err := provider.NewProvider(smallModelProviderCfg, summarizeOpts...)
938		if err != nil {
939			return fmt.Errorf("failed to create new summarize provider: %w", err)
940		}
941
942		// Update the providers and provider ID
943		a.titleProvider = newTitleProvider
944		a.summarizeProvider = newSummarizeProvider
945		a.summarizeProviderID = string(smallModelProviderCfg.ID)
946	}
947
948	return nil
949}