1package agent
  2
  3import (
  4	"context"
  5	"errors"
  6	"fmt"
  7	"log/slog"
  8	"slices"
  9	"strings"
 10	"time"
 11
 12	"github.com/charmbracelet/catwalk/pkg/catwalk"
 13	"github.com/charmbracelet/crush/internal/config"
 14	"github.com/charmbracelet/crush/internal/csync"
 15	"github.com/charmbracelet/crush/internal/history"
 16	"github.com/charmbracelet/crush/internal/llm/prompt"
 17	"github.com/charmbracelet/crush/internal/llm/provider"
 18	"github.com/charmbracelet/crush/internal/llm/tools"
 19	"github.com/charmbracelet/crush/internal/log"
 20	"github.com/charmbracelet/crush/internal/lsp"
 21	"github.com/charmbracelet/crush/internal/message"
 22	"github.com/charmbracelet/crush/internal/permission"
 23	"github.com/charmbracelet/crush/internal/pubsub"
 24	"github.com/charmbracelet/crush/internal/session"
 25	"github.com/charmbracelet/crush/internal/shell"
 26)
 27
 28// Common errors
 29var (
 30	ErrRequestCancelled = errors.New("request canceled by user")
 31	ErrSessionBusy      = errors.New("session is currently processing another request")
 32)
 33
 34type AgentEventType string
 35
 36const (
 37	AgentEventTypeError     AgentEventType = "error"
 38	AgentEventTypeResponse  AgentEventType = "response"
 39	AgentEventTypeSummarize AgentEventType = "summarize"
 40)
 41
 42type AgentEvent struct {
 43	Type    AgentEventType
 44	Message message.Message
 45	Error   error
 46
 47	// When summarizing
 48	SessionID string
 49	Progress  string
 50	Done      bool
 51}
 52
 53type Service interface {
 54	pubsub.Suscriber[AgentEvent]
 55	Model() catwalk.Model
 56	Run(ctx context.Context, sessionID string, content string, attachments ...message.Attachment) (<-chan AgentEvent, error)
 57	Cancel(sessionID string)
 58	CancelAll()
 59	IsSessionBusy(sessionID string) bool
 60	IsBusy() bool
 61	Summarize(ctx context.Context, sessionID string) error
 62	UpdateModel() error
 63}
 64
 65type agent struct {
 66	*pubsub.Broker[AgentEvent]
 67	agentCfg config.Agent
 68	sessions session.Service
 69	messages message.Service
 70
 71	tools *csync.LazySlice[tools.BaseTool]
 72
 73	provider   provider.Provider
 74	providerID string
 75
 76	titleProvider       provider.Provider
 77	summarizeProvider   provider.Provider
 78	summarizeProviderID string
 79
 80	activeRequests *csync.Map[string, context.CancelFunc]
 81}
 82
 83var agentPromptMap = map[string]prompt.PromptID{
 84	"coder": prompt.PromptCoder,
 85	"task":  prompt.PromptTask,
 86}
 87
 88func NewAgent(
 89	agentCfg config.Agent,
 90	// These services are needed in the tools
 91	permissions permission.Service,
 92	sessions session.Service,
 93	messages message.Service,
 94	history history.Service,
 95	lspClients map[string]*lsp.Client,
 96) (Service, error) {
 97	ctx := context.Background()
 98	cfg := config.Get()
 99
100	var agentTool tools.BaseTool
101	if agentCfg.ID == "coder" {
102		taskAgentCfg := config.Get().Agents["task"]
103		if taskAgentCfg.ID == "" {
104			return nil, fmt.Errorf("task agent not found in config")
105		}
106		taskAgent, err := NewAgent(taskAgentCfg, permissions, sessions, messages, history, lspClients)
107		if err != nil {
108			return nil, fmt.Errorf("failed to create task agent: %w", err)
109		}
110
111		agentTool = NewAgentTool(taskAgent, sessions, messages)
112	}
113
114	providerCfg := config.Get().GetProviderForModel(agentCfg.Model)
115	if providerCfg == nil {
116		return nil, fmt.Errorf("provider for agent %s not found in config", agentCfg.Name)
117	}
118	model := config.Get().GetModelByType(agentCfg.Model)
119
120	if model == nil {
121		return nil, fmt.Errorf("model not found for agent %s", agentCfg.Name)
122	}
123
124	promptID := agentPromptMap[agentCfg.ID]
125	if promptID == "" {
126		promptID = prompt.PromptDefault
127	}
128	opts := []provider.ProviderClientOption{
129		provider.WithModel(agentCfg.Model),
130		provider.WithSystemMessage(prompt.GetPrompt(promptID, providerCfg.ID, config.Get().Options.ContextPaths...)),
131	}
132	agentProvider, err := provider.NewProvider(*providerCfg, opts...)
133	if err != nil {
134		return nil, err
135	}
136
137	smallModelCfg := cfg.Models[config.SelectedModelTypeSmall]
138	var smallModelProviderCfg *config.ProviderConfig
139	if smallModelCfg.Provider == providerCfg.ID {
140		smallModelProviderCfg = providerCfg
141	} else {
142		smallModelProviderCfg = cfg.GetProviderForModel(config.SelectedModelTypeSmall)
143
144		if smallModelProviderCfg.ID == "" {
145			return nil, fmt.Errorf("provider %s not found in config", smallModelCfg.Provider)
146		}
147	}
148	smallModel := cfg.GetModelByType(config.SelectedModelTypeSmall)
149	if smallModel.ID == "" {
150		return nil, fmt.Errorf("model %s not found in provider %s", smallModelCfg.Model, smallModelProviderCfg.ID)
151	}
152
153	titleOpts := []provider.ProviderClientOption{
154		provider.WithModel(config.SelectedModelTypeSmall),
155		provider.WithSystemMessage(prompt.GetPrompt(prompt.PromptTitle, smallModelProviderCfg.ID)),
156	}
157	titleProvider, err := provider.NewProvider(*smallModelProviderCfg, titleOpts...)
158	if err != nil {
159		return nil, err
160	}
161	summarizeOpts := []provider.ProviderClientOption{
162		provider.WithModel(config.SelectedModelTypeSmall),
163		provider.WithSystemMessage(prompt.GetPrompt(prompt.PromptSummarizer, smallModelProviderCfg.ID)),
164	}
165	summarizeProvider, err := provider.NewProvider(*smallModelProviderCfg, summarizeOpts...)
166	if err != nil {
167		return nil, err
168	}
169
170	toolFn := func() []tools.BaseTool {
171		slog.Info("Initializing agent tools", "agent", agentCfg.ID)
172		defer func() {
173			slog.Info("Initialized agent tools", "agent", agentCfg.ID)
174		}()
175
176		cwd := cfg.WorkingDir()
177		allTools := []tools.BaseTool{
178			tools.NewBashTool(permissions, cwd),
179			tools.NewDownloadTool(permissions, cwd),
180			tools.NewEditTool(lspClients, permissions, history, cwd),
181			tools.NewMultiEditTool(lspClients, permissions, history, cwd),
182			tools.NewFetchTool(permissions, cwd),
183			tools.NewGlobTool(cwd),
184			tools.NewGrepTool(cwd),
185			tools.NewLsTool(permissions, cwd),
186			tools.NewSourcegraphTool(),
187			tools.NewViewTool(lspClients, permissions, cwd),
188			tools.NewWriteTool(lspClients, permissions, history, cwd),
189		}
190
191		mcpTools := GetMCPTools(ctx, permissions, cfg)
192		allTools = append(allTools, mcpTools...)
193
194		if len(lspClients) > 0 {
195			allTools = append(allTools, tools.NewDiagnosticsTool(lspClients))
196		}
197
198		if agentTool != nil {
199			allTools = append(allTools, agentTool)
200		}
201
202		if agentCfg.AllowedTools == nil {
203			return allTools
204		}
205
206		var filteredTools []tools.BaseTool
207		for _, tool := range allTools {
208			if slices.Contains(agentCfg.AllowedTools, tool.Name()) {
209				filteredTools = append(filteredTools, tool)
210			}
211		}
212		return filteredTools
213	}
214
215	return &agent{
216		Broker:              pubsub.NewBroker[AgentEvent](),
217		agentCfg:            agentCfg,
218		provider:            agentProvider,
219		providerID:          string(providerCfg.ID),
220		messages:            messages,
221		sessions:            sessions,
222		titleProvider:       titleProvider,
223		summarizeProvider:   summarizeProvider,
224		summarizeProviderID: string(smallModelProviderCfg.ID),
225		activeRequests:      csync.NewMap[string, context.CancelFunc](),
226		tools:               csync.NewLazySlice(toolFn),
227	}, nil
228}
229
230func (a *agent) Model() catwalk.Model {
231	return *config.Get().GetModelByType(a.agentCfg.Model)
232}
233
234func (a *agent) Cancel(sessionID string) {
235	// Cancel regular requests
236	if cancel, ok := a.activeRequests.Take(sessionID); ok && cancel != nil {
237		slog.Info("Request cancellation initiated", "session_id", sessionID)
238		cancel()
239	}
240
241	// Also check for summarize requests
242	if cancel, ok := a.activeRequests.Take(sessionID + "-summarize"); ok && cancel != nil {
243		slog.Info("Summarize cancellation initiated", "session_id", sessionID)
244		cancel()
245	}
246}
247
248func (a *agent) IsBusy() bool {
249	var busy bool
250	for cancelFunc := range a.activeRequests.Seq() {
251		if cancelFunc != nil {
252			busy = true
253			break
254		}
255	}
256	return busy
257}
258
259func (a *agent) IsSessionBusy(sessionID string) bool {
260	_, busy := a.activeRequests.Get(sessionID)
261	return busy
262}
263
264func (a *agent) generateTitle(ctx context.Context, sessionID string, content string) error {
265	if content == "" {
266		return nil
267	}
268	if a.titleProvider == nil {
269		return nil
270	}
271	session, err := a.sessions.Get(ctx, sessionID)
272	if err != nil {
273		return err
274	}
275	parts := []message.ContentPart{message.TextContent{
276		Text: fmt.Sprintf("Generate a concise title for the following content:\n\n%s", content),
277	}}
278
279	// Use streaming approach like summarization
280	response := a.titleProvider.StreamResponse(
281		ctx,
282		[]message.Message{
283			{
284				Role:  message.User,
285				Parts: parts,
286			},
287		},
288		nil,
289	)
290
291	var finalResponse *provider.ProviderResponse
292	for r := range response {
293		if r.Error != nil {
294			return r.Error
295		}
296		finalResponse = r.Response
297	}
298
299	if finalResponse == nil {
300		return fmt.Errorf("no response received from title provider")
301	}
302
303	title := strings.TrimSpace(strings.ReplaceAll(finalResponse.Content, "\n", " "))
304	if title == "" {
305		return nil
306	}
307
308	session.Title = title
309	_, err = a.sessions.Save(ctx, session)
310	return err
311}
312
313func (a *agent) err(err error) AgentEvent {
314	return AgentEvent{
315		Type:  AgentEventTypeError,
316		Error: err,
317	}
318}
319
320func (a *agent) Run(ctx context.Context, sessionID string, content string, attachments ...message.Attachment) (<-chan AgentEvent, error) {
321	if !a.Model().SupportsImages && attachments != nil {
322		attachments = nil
323	}
324	events := make(chan AgentEvent)
325	if a.IsSessionBusy(sessionID) {
326		return nil, ErrSessionBusy
327	}
328
329	genCtx, cancel := context.WithCancel(ctx)
330
331	a.activeRequests.Set(sessionID, cancel)
332	go func() {
333		slog.Debug("Request started", "sessionID", sessionID)
334		defer log.RecoverPanic("agent.Run", func() {
335			events <- a.err(fmt.Errorf("panic while running the agent"))
336		})
337		var attachmentParts []message.ContentPart
338		for _, attachment := range attachments {
339			attachmentParts = append(attachmentParts, message.BinaryContent{Path: attachment.FilePath, MIMEType: attachment.MimeType, Data: attachment.Content})
340		}
341		result := a.processGeneration(genCtx, sessionID, content, attachmentParts)
342		if result.Error != nil && !errors.Is(result.Error, ErrRequestCancelled) && !errors.Is(result.Error, context.Canceled) {
343			slog.Error(result.Error.Error())
344		}
345		slog.Debug("Request completed", "sessionID", sessionID)
346		a.activeRequests.Del(sessionID)
347		cancel()
348		a.Publish(pubsub.CreatedEvent, result)
349		events <- result
350		close(events)
351	}()
352	return events, nil
353}
354
355func (a *agent) processGeneration(ctx context.Context, sessionID, content string, attachmentParts []message.ContentPart) AgentEvent {
356	cfg := config.Get()
357	// List existing messages; if none, start title generation asynchronously.
358	msgs, err := a.messages.List(ctx, sessionID)
359	if err != nil {
360		return a.err(fmt.Errorf("failed to list messages: %w", err))
361	}
362	if len(msgs) == 0 {
363		go func() {
364			defer log.RecoverPanic("agent.Run", func() {
365				slog.Error("panic while generating title")
366			})
367			titleErr := a.generateTitle(context.Background(), sessionID, content)
368			if titleErr != nil && !errors.Is(titleErr, context.Canceled) && !errors.Is(titleErr, context.DeadlineExceeded) {
369				slog.Error("failed to generate title", "error", titleErr)
370			}
371		}()
372	}
373	session, err := a.sessions.Get(ctx, sessionID)
374	if err != nil {
375		return a.err(fmt.Errorf("failed to get session: %w", err))
376	}
377	if session.SummaryMessageID != "" {
378		summaryMsgInex := -1
379		for i, msg := range msgs {
380			if msg.ID == session.SummaryMessageID {
381				summaryMsgInex = i
382				break
383			}
384		}
385		if summaryMsgInex != -1 {
386			msgs = msgs[summaryMsgInex:]
387			msgs[0].Role = message.User
388		}
389	}
390
391	userMsg, err := a.createUserMessage(ctx, sessionID, content, attachmentParts)
392	if err != nil {
393		return a.err(fmt.Errorf("failed to create user message: %w", err))
394	}
395	// Append the new user message to the conversation history.
396	msgHistory := append(msgs, userMsg)
397
398	for {
399		// Check for cancellation before each iteration
400		select {
401		case <-ctx.Done():
402			return a.err(ctx.Err())
403		default:
404			// Continue processing
405		}
406		agentMessage, toolResults, err := a.streamAndHandleEvents(ctx, sessionID, msgHistory)
407		if err != nil {
408			if errors.Is(err, context.Canceled) {
409				agentMessage.AddFinish(message.FinishReasonCanceled, "Request cancelled", "")
410				a.messages.Update(context.Background(), agentMessage)
411				return a.err(ErrRequestCancelled)
412			}
413			return a.err(fmt.Errorf("failed to process events: %w", err))
414		}
415		if cfg.Options.Debug {
416			slog.Info("Result", "message", agentMessage.FinishReason(), "toolResults", toolResults)
417		}
418		if (agentMessage.FinishReason() == message.FinishReasonToolUse) && toolResults != nil {
419			// We are not done, we need to respond with the tool response
420			msgHistory = append(msgHistory, agentMessage, *toolResults)
421			continue
422		}
423		if agentMessage.FinishReason() == "" {
424			// Kujtim: could not track down where this is happening but this means its cancelled
425			agentMessage.AddFinish(message.FinishReasonCanceled, "Request cancelled", "")
426			_ = a.messages.Update(context.Background(), agentMessage)
427			return a.err(ErrRequestCancelled)
428		}
429		return AgentEvent{
430			Type:    AgentEventTypeResponse,
431			Message: agentMessage,
432			Done:    true,
433		}
434	}
435}
436
437func (a *agent) createUserMessage(ctx context.Context, sessionID, content string, attachmentParts []message.ContentPart) (message.Message, error) {
438	parts := []message.ContentPart{message.TextContent{Text: content}}
439	parts = append(parts, attachmentParts...)
440	return a.messages.Create(ctx, sessionID, message.CreateMessageParams{
441		Role:  message.User,
442		Parts: parts,
443	})
444}
445
446func (a *agent) streamAndHandleEvents(ctx context.Context, sessionID string, msgHistory []message.Message) (message.Message, *message.Message, error) {
447	ctx = context.WithValue(ctx, tools.SessionIDContextKey, sessionID)
448	eventChan := a.provider.StreamResponse(ctx, msgHistory, slices.Collect(a.tools.Seq()))
449
450	assistantMsg, err := a.messages.Create(ctx, sessionID, message.CreateMessageParams{
451		Role:     message.Assistant,
452		Parts:    []message.ContentPart{},
453		Model:    a.Model().ID,
454		Provider: a.providerID,
455	})
456	if err != nil {
457		return assistantMsg, nil, fmt.Errorf("failed to create assistant message: %w", err)
458	}
459
460	// Add the session and message ID into the context if needed by tools.
461	ctx = context.WithValue(ctx, tools.MessageIDContextKey, assistantMsg.ID)
462
463	// Process each event in the stream.
464	for event := range eventChan {
465		if processErr := a.processEvent(ctx, sessionID, &assistantMsg, event); processErr != nil {
466			if errors.Is(processErr, context.Canceled) {
467				a.finishMessage(context.Background(), &assistantMsg, message.FinishReasonCanceled, "Request cancelled", "")
468			} else {
469				a.finishMessage(ctx, &assistantMsg, message.FinishReasonError, "API Error", processErr.Error())
470			}
471			return assistantMsg, nil, processErr
472		}
473		if ctx.Err() != nil {
474			a.finishMessage(context.Background(), &assistantMsg, message.FinishReasonCanceled, "Request cancelled", "")
475			return assistantMsg, nil, ctx.Err()
476		}
477	}
478
479	toolResults := make([]message.ToolResult, len(assistantMsg.ToolCalls()))
480	toolCalls := assistantMsg.ToolCalls()
481	for i, toolCall := range toolCalls {
482		select {
483		case <-ctx.Done():
484			a.finishMessage(context.Background(), &assistantMsg, message.FinishReasonCanceled, "Request cancelled", "")
485			// Make all future tool calls cancelled
486			for j := i; j < len(toolCalls); j++ {
487				toolResults[j] = message.ToolResult{
488					ToolCallID: toolCalls[j].ID,
489					Content:    "Tool execution canceled by user",
490					IsError:    true,
491				}
492			}
493			goto out
494		default:
495			// Continue processing
496			var tool tools.BaseTool
497			for availableTool := range a.tools.Seq() {
498				if availableTool.Info().Name == toolCall.Name {
499					tool = availableTool
500					break
501				}
502			}
503
504			// Tool not found
505			if tool == nil {
506				toolResults[i] = message.ToolResult{
507					ToolCallID: toolCall.ID,
508					Content:    fmt.Sprintf("Tool not found: %s", toolCall.Name),
509					IsError:    true,
510				}
511				continue
512			}
513
514			// Run tool in goroutine to allow cancellation
515			type toolExecResult struct {
516				response tools.ToolResponse
517				err      error
518			}
519			resultChan := make(chan toolExecResult, 1)
520
521			go func() {
522				response, err := tool.Run(ctx, tools.ToolCall{
523					ID:    toolCall.ID,
524					Name:  toolCall.Name,
525					Input: toolCall.Input,
526				})
527				resultChan <- toolExecResult{response: response, err: err}
528			}()
529
530			var toolResponse tools.ToolResponse
531			var toolErr error
532
533			select {
534			case <-ctx.Done():
535				a.finishMessage(context.Background(), &assistantMsg, message.FinishReasonCanceled, "Request cancelled", "")
536				// Mark remaining tool calls as cancelled
537				for j := i; j < len(toolCalls); j++ {
538					toolResults[j] = message.ToolResult{
539						ToolCallID: toolCalls[j].ID,
540						Content:    "Tool execution canceled by user",
541						IsError:    true,
542					}
543				}
544				goto out
545			case result := <-resultChan:
546				toolResponse = result.response
547				toolErr = result.err
548			}
549
550			if toolErr != nil {
551				slog.Error("Tool execution error", "toolCall", toolCall.ID, "error", toolErr)
552				if errors.Is(toolErr, permission.ErrorPermissionDenied) {
553					toolResults[i] = message.ToolResult{
554						ToolCallID: toolCall.ID,
555						Content:    "Permission denied",
556						IsError:    true,
557					}
558					for j := i + 1; j < len(toolCalls); j++ {
559						toolResults[j] = message.ToolResult{
560							ToolCallID: toolCalls[j].ID,
561							Content:    "Tool execution canceled by user",
562							IsError:    true,
563						}
564					}
565					a.finishMessage(ctx, &assistantMsg, message.FinishReasonPermissionDenied, "Permission denied", "")
566					break
567				}
568			}
569			toolResults[i] = message.ToolResult{
570				ToolCallID: toolCall.ID,
571				Content:    toolResponse.Content,
572				Metadata:   toolResponse.Metadata,
573				IsError:    toolResponse.IsError,
574			}
575		}
576	}
577out:
578	if len(toolResults) == 0 {
579		return assistantMsg, nil, nil
580	}
581	parts := make([]message.ContentPart, 0)
582	for _, tr := range toolResults {
583		parts = append(parts, tr)
584	}
585	msg, err := a.messages.Create(context.Background(), assistantMsg.SessionID, message.CreateMessageParams{
586		Role:     message.Tool,
587		Parts:    parts,
588		Provider: a.providerID,
589	})
590	if err != nil {
591		return assistantMsg, nil, fmt.Errorf("failed to create cancelled tool message: %w", err)
592	}
593
594	return assistantMsg, &msg, err
595}
596
597func (a *agent) finishMessage(ctx context.Context, msg *message.Message, finishReason message.FinishReason, message, details string) {
598	msg.AddFinish(finishReason, message, details)
599	_ = a.messages.Update(ctx, *msg)
600}
601
602func (a *agent) processEvent(ctx context.Context, sessionID string, assistantMsg *message.Message, event provider.ProviderEvent) error {
603	select {
604	case <-ctx.Done():
605		return ctx.Err()
606	default:
607		// Continue processing.
608	}
609
610	switch event.Type {
611	case provider.EventThinkingDelta:
612		assistantMsg.AppendReasoningContent(event.Thinking)
613		return a.messages.Update(ctx, *assistantMsg)
614	case provider.EventSignatureDelta:
615		assistantMsg.AppendReasoningSignature(event.Signature)
616		return a.messages.Update(ctx, *assistantMsg)
617	case provider.EventContentDelta:
618		assistantMsg.FinishThinking()
619		assistantMsg.AppendContent(event.Content)
620		return a.messages.Update(ctx, *assistantMsg)
621	case provider.EventToolUseStart:
622		assistantMsg.FinishThinking()
623		slog.Info("Tool call started", "toolCall", event.ToolCall)
624		assistantMsg.AddToolCall(*event.ToolCall)
625		return a.messages.Update(ctx, *assistantMsg)
626	case provider.EventToolUseDelta:
627		assistantMsg.AppendToolCallInput(event.ToolCall.ID, event.ToolCall.Input)
628		return a.messages.Update(ctx, *assistantMsg)
629	case provider.EventToolUseStop:
630		slog.Info("Finished tool call", "toolCall", event.ToolCall)
631		assistantMsg.FinishToolCall(event.ToolCall.ID)
632		return a.messages.Update(ctx, *assistantMsg)
633	case provider.EventError:
634		return event.Error
635	case provider.EventComplete:
636		assistantMsg.FinishThinking()
637		assistantMsg.SetToolCalls(event.Response.ToolCalls)
638		assistantMsg.AddFinish(event.Response.FinishReason, "", "")
639		if err := a.messages.Update(ctx, *assistantMsg); err != nil {
640			return fmt.Errorf("failed to update message: %w", err)
641		}
642		return a.TrackUsage(ctx, sessionID, a.Model(), event.Response.Usage)
643	}
644
645	return nil
646}
647
648func (a *agent) TrackUsage(ctx context.Context, sessionID string, model catwalk.Model, usage provider.TokenUsage) error {
649	sess, err := a.sessions.Get(ctx, sessionID)
650	if err != nil {
651		return fmt.Errorf("failed to get session: %w", err)
652	}
653
654	cost := model.CostPer1MInCached/1e6*float64(usage.CacheCreationTokens) +
655		model.CostPer1MOutCached/1e6*float64(usage.CacheReadTokens) +
656		model.CostPer1MIn/1e6*float64(usage.InputTokens) +
657		model.CostPer1MOut/1e6*float64(usage.OutputTokens)
658
659	sess.Cost += cost
660	sess.CompletionTokens = usage.OutputTokens + usage.CacheReadTokens
661	sess.PromptTokens = usage.InputTokens + usage.CacheCreationTokens
662
663	_, err = a.sessions.Save(ctx, sess)
664	if err != nil {
665		return fmt.Errorf("failed to save session: %w", err)
666	}
667	return nil
668}
669
670func (a *agent) Summarize(ctx context.Context, sessionID string) error {
671	if a.summarizeProvider == nil {
672		return fmt.Errorf("summarize provider not available")
673	}
674
675	// Check if session is busy
676	if a.IsSessionBusy(sessionID) {
677		return ErrSessionBusy
678	}
679
680	// Create a new context with cancellation
681	summarizeCtx, cancel := context.WithCancel(ctx)
682
683	// Store the cancel function in activeRequests to allow cancellation
684	a.activeRequests.Set(sessionID+"-summarize", cancel)
685
686	go func() {
687		defer a.activeRequests.Del(sessionID + "-summarize")
688		defer cancel()
689		event := AgentEvent{
690			Type:     AgentEventTypeSummarize,
691			Progress: "Starting summarization...",
692		}
693
694		a.Publish(pubsub.CreatedEvent, event)
695		// Get all messages from the session
696		msgs, err := a.messages.List(summarizeCtx, sessionID)
697		if err != nil {
698			event = AgentEvent{
699				Type:  AgentEventTypeError,
700				Error: fmt.Errorf("failed to list messages: %w", err),
701				Done:  true,
702			}
703			a.Publish(pubsub.CreatedEvent, event)
704			return
705		}
706		summarizeCtx = context.WithValue(summarizeCtx, tools.SessionIDContextKey, sessionID)
707
708		if len(msgs) == 0 {
709			event = AgentEvent{
710				Type:  AgentEventTypeError,
711				Error: fmt.Errorf("no messages to summarize"),
712				Done:  true,
713			}
714			a.Publish(pubsub.CreatedEvent, event)
715			return
716		}
717
718		event = AgentEvent{
719			Type:     AgentEventTypeSummarize,
720			Progress: "Analyzing conversation...",
721		}
722		a.Publish(pubsub.CreatedEvent, event)
723
724		// Add a system message to guide the summarization
725		summarizePrompt := "Provide a detailed but concise summary of our conversation above. Focus on information that would be helpful for continuing the conversation, including what we did, what we're doing, which files we're working on, and what we're going to do next."
726
727		// Create a new message with the summarize prompt
728		promptMsg := message.Message{
729			Role:  message.User,
730			Parts: []message.ContentPart{message.TextContent{Text: summarizePrompt}},
731		}
732
733		// Append the prompt to the messages
734		msgsWithPrompt := append(msgs, promptMsg)
735
736		event = AgentEvent{
737			Type:     AgentEventTypeSummarize,
738			Progress: "Generating summary...",
739		}
740
741		a.Publish(pubsub.CreatedEvent, event)
742
743		// Send the messages to the summarize provider
744		response := a.summarizeProvider.StreamResponse(
745			summarizeCtx,
746			msgsWithPrompt,
747			nil,
748		)
749		var finalResponse *provider.ProviderResponse
750		for r := range response {
751			if r.Error != nil {
752				event = AgentEvent{
753					Type:  AgentEventTypeError,
754					Error: fmt.Errorf("failed to summarize: %w", err),
755					Done:  true,
756				}
757				a.Publish(pubsub.CreatedEvent, event)
758				return
759			}
760			finalResponse = r.Response
761		}
762
763		summary := strings.TrimSpace(finalResponse.Content)
764		if summary == "" {
765			event = AgentEvent{
766				Type:  AgentEventTypeError,
767				Error: fmt.Errorf("empty summary returned"),
768				Done:  true,
769			}
770			a.Publish(pubsub.CreatedEvent, event)
771			return
772		}
773		shell := shell.GetPersistentShell(config.Get().WorkingDir())
774		summary += "\n\n**Current working directory of the persistent shell**\n\n" + shell.GetWorkingDir()
775		event = AgentEvent{
776			Type:     AgentEventTypeSummarize,
777			Progress: "Creating new session...",
778		}
779
780		a.Publish(pubsub.CreatedEvent, event)
781		oldSession, err := a.sessions.Get(summarizeCtx, sessionID)
782		if err != nil {
783			event = AgentEvent{
784				Type:  AgentEventTypeError,
785				Error: fmt.Errorf("failed to get session: %w", err),
786				Done:  true,
787			}
788
789			a.Publish(pubsub.CreatedEvent, event)
790			return
791		}
792		// Create a message in the new session with the summary
793		msg, err := a.messages.Create(summarizeCtx, oldSession.ID, message.CreateMessageParams{
794			Role: message.Assistant,
795			Parts: []message.ContentPart{
796				message.TextContent{Text: summary},
797				message.Finish{
798					Reason: message.FinishReasonEndTurn,
799					Time:   time.Now().Unix(),
800				},
801			},
802			Model:    a.summarizeProvider.Model().ID,
803			Provider: a.summarizeProviderID,
804		})
805		if err != nil {
806			event = AgentEvent{
807				Type:  AgentEventTypeError,
808				Error: fmt.Errorf("failed to create summary message: %w", err),
809				Done:  true,
810			}
811
812			a.Publish(pubsub.CreatedEvent, event)
813			return
814		}
815		oldSession.SummaryMessageID = msg.ID
816		oldSession.CompletionTokens = finalResponse.Usage.OutputTokens
817		oldSession.PromptTokens = 0
818		model := a.summarizeProvider.Model()
819		usage := finalResponse.Usage
820		cost := model.CostPer1MInCached/1e6*float64(usage.CacheCreationTokens) +
821			model.CostPer1MOutCached/1e6*float64(usage.CacheReadTokens) +
822			model.CostPer1MIn/1e6*float64(usage.InputTokens) +
823			model.CostPer1MOut/1e6*float64(usage.OutputTokens)
824		oldSession.Cost += cost
825		_, err = a.sessions.Save(summarizeCtx, oldSession)
826		if err != nil {
827			event = AgentEvent{
828				Type:  AgentEventTypeError,
829				Error: fmt.Errorf("failed to save session: %w", err),
830				Done:  true,
831			}
832			a.Publish(pubsub.CreatedEvent, event)
833		}
834
835		event = AgentEvent{
836			Type:      AgentEventTypeSummarize,
837			SessionID: oldSession.ID,
838			Progress:  "Summary complete",
839			Done:      true,
840		}
841		a.Publish(pubsub.CreatedEvent, event)
842		// Send final success event with the new session ID
843	}()
844
845	return nil
846}
847
848func (a *agent) CancelAll() {
849	if !a.IsBusy() {
850		return
851	}
852	for key := range a.activeRequests.Seq2() {
853		a.Cancel(key) // key is sessionID
854	}
855
856	timeout := time.After(5 * time.Second)
857	for a.IsBusy() {
858		select {
859		case <-timeout:
860			return
861		default:
862			time.Sleep(200 * time.Millisecond)
863		}
864	}
865}
866
867func (a *agent) UpdateModel() error {
868	cfg := config.Get()
869
870	// Get current provider configuration
871	currentProviderCfg := cfg.GetProviderForModel(a.agentCfg.Model)
872	if currentProviderCfg == nil || currentProviderCfg.ID == "" {
873		return fmt.Errorf("provider for agent %s not found in config", a.agentCfg.Name)
874	}
875
876	// Check if provider has changed
877	if string(currentProviderCfg.ID) != a.providerID {
878		// Provider changed, need to recreate the main provider
879		model := cfg.GetModelByType(a.agentCfg.Model)
880		if model.ID == "" {
881			return fmt.Errorf("model not found for agent %s", a.agentCfg.Name)
882		}
883
884		promptID := agentPromptMap[a.agentCfg.ID]
885		if promptID == "" {
886			promptID = prompt.PromptDefault
887		}
888
889		opts := []provider.ProviderClientOption{
890			provider.WithModel(a.agentCfg.Model),
891			provider.WithSystemMessage(prompt.GetPrompt(promptID, currentProviderCfg.ID, cfg.Options.ContextPaths...)),
892		}
893
894		newProvider, err := provider.NewProvider(*currentProviderCfg, opts...)
895		if err != nil {
896			return fmt.Errorf("failed to create new provider: %w", err)
897		}
898
899		// Update the provider and provider ID
900		a.provider = newProvider
901		a.providerID = string(currentProviderCfg.ID)
902	}
903
904	// Check if small model provider has changed (affects title and summarize providers)
905	smallModelCfg := cfg.Models[config.SelectedModelTypeSmall]
906	var smallModelProviderCfg config.ProviderConfig
907
908	for p := range cfg.Providers.Seq() {
909		if p.ID == smallModelCfg.Provider {
910			smallModelProviderCfg = p
911			break
912		}
913	}
914
915	if smallModelProviderCfg.ID == "" {
916		return fmt.Errorf("provider %s not found in config", smallModelCfg.Provider)
917	}
918
919	// Check if summarize provider has changed
920	if string(smallModelProviderCfg.ID) != a.summarizeProviderID {
921		smallModel := cfg.GetModelByType(config.SelectedModelTypeSmall)
922		if smallModel == nil {
923			return fmt.Errorf("model %s not found in provider %s", smallModelCfg.Model, smallModelProviderCfg.ID)
924		}
925
926		// Recreate title provider
927		titleOpts := []provider.ProviderClientOption{
928			provider.WithModel(config.SelectedModelTypeSmall),
929			provider.WithSystemMessage(prompt.GetPrompt(prompt.PromptTitle, smallModelProviderCfg.ID)),
930			// We want the title to be short, so we limit the max tokens
931			provider.WithMaxTokens(40),
932		}
933		newTitleProvider, err := provider.NewProvider(smallModelProviderCfg, titleOpts...)
934		if err != nil {
935			return fmt.Errorf("failed to create new title provider: %w", err)
936		}
937
938		// Recreate summarize provider
939		summarizeOpts := []provider.ProviderClientOption{
940			provider.WithModel(config.SelectedModelTypeSmall),
941			provider.WithSystemMessage(prompt.GetPrompt(prompt.PromptSummarizer, smallModelProviderCfg.ID)),
942		}
943		newSummarizeProvider, err := provider.NewProvider(smallModelProviderCfg, summarizeOpts...)
944		if err != nil {
945			return fmt.Errorf("failed to create new summarize provider: %w", err)
946		}
947
948		// Update the providers and provider ID
949		a.titleProvider = newTitleProvider
950		a.summarizeProvider = newSummarizeProvider
951		a.summarizeProviderID = string(smallModelProviderCfg.ID)
952	}
953
954	return nil
955}