1package agent
  2
  3import (
  4	"context"
  5	"errors"
  6	"fmt"
  7	"slices"
  8	"strings"
  9	"sync"
 10	"time"
 11
 12	"github.com/charmbracelet/crush/internal/config"
 13	"github.com/charmbracelet/crush/internal/history"
 14	"github.com/charmbracelet/crush/internal/llm/prompt"
 15	"github.com/charmbracelet/crush/internal/llm/provider"
 16	"github.com/charmbracelet/crush/internal/llm/tools"
 17	"github.com/charmbracelet/crush/internal/logging"
 18	"github.com/charmbracelet/crush/internal/lsp"
 19	"github.com/charmbracelet/crush/internal/message"
 20	"github.com/charmbracelet/crush/internal/permission"
 21	"github.com/charmbracelet/crush/internal/pubsub"
 22	"github.com/charmbracelet/crush/internal/session"
 23)
 24
 25// Common errors
 26var (
 27	ErrRequestCancelled = errors.New("request cancelled by user")
 28	ErrSessionBusy      = errors.New("session is currently processing another request")
 29)
 30
 31type AgentEventType string
 32
 33const (
 34	AgentEventTypeError     AgentEventType = "error"
 35	AgentEventTypeResponse  AgentEventType = "response"
 36	AgentEventTypeSummarize AgentEventType = "summarize"
 37)
 38
// AgentEvent is the payload the agent publishes on its pubsub broker and
// sends on the channel returned by Run.
//
// Type selects which fields are meaningful: Message accompanies
// AgentEventTypeResponse, Error accompanies AgentEventTypeError, and
// Progress (plus SessionID on the final success event) accompanies
// AgentEventTypeSummarize. Done is set on the final response of a Run and on
// the terminal events of a Summarize call.
type AgentEvent struct {
	Type    AgentEventType
	Message message.Message
	Error   error

	// Summarization progress fields (Done is also set on Run responses).
	SessionID string
	Progress  string
	Done      bool
}
 49
// Service is the public surface of an agent. It runs prompts against a
// session, supports cancellation and summarization, and lets subscribers
// receive the AgentEvents it publishes.
type Service interface {
	pubsub.Suscriber[AgentEvent]
	// Model returns the model currently configured for this agent.
	Model() config.Model
	// EffectiveMaxTokens returns the token limit configured for this agent.
	EffectiveMaxTokens() int64
	// Run processes content (and optional attachments) as a new user turn in
	// sessionID in the background; the returned channel delivers the outcome.
	Run(ctx context.Context, sessionID string, content string, attachments ...message.Attachment) (<-chan AgentEvent, error)
	// Cancel aborts any request or summarization running for sessionID.
	Cancel(sessionID string)
	// CancelAll aborts everything currently running.
	CancelAll()
	// IsSessionBusy reports whether sessionID has an active request.
	IsSessionBusy(sessionID string) bool
	// IsBusy reports whether any request is active.
	IsBusy() bool
	// Summarize starts summarizing sessionID in the background; progress is
	// published as AgentEvents.
	Summarize(ctx context.Context, sessionID string) error
	// UpdateModel re-reads the configuration and rebuilds providers whose
	// configuration changed.
	UpdateModel() error
}
 62
// agent is the Service implementation returned by NewAgent. It embeds a
// pubsub broker on which it publishes AgentEvents.
type agent struct {
	*pubsub.Broker[AgentEvent]
	agentCfg config.Agent
	sessions session.Service
	messages message.Service

	// Main provider, its config ID, and the tools offered to it.
	tools      []tools.BaseTool
	provider   provider.Provider
	providerID string

	// Providers built from the configured small model, used to generate
	// session titles and conversation summaries.
	titleProvider       provider.Provider
	summarizeProvider   provider.Provider
	summarizeProviderID string

	// activeRequests maps a session ID (regular requests) or
	// "<sessionID>-summarize" (summarizations) to the context.CancelFunc of
	// the work currently running for it.
	activeRequests sync.Map
}
 79
// agentPromptMap selects the system prompt for each agent. Agents that are
// not listed fall back to prompt.PromptDefault (see NewAgent/UpdateModel).
var agentPromptMap = map[config.AgentID]prompt.PromptID{
	config.AgentCoder: prompt.PromptCoder,
	config.AgentTask:  prompt.PromptTask,
}
 84
// NewAgent builds the agent described by agentCfg.
//
// It assembles the tool set in three parts. First come the built-in tools.
// Then, for the coder agent only, comes a tool wrapping a recursively built
// task agent. Last come the MCP tools, plus a diagnostics tool when LSP
// clients are provided. When agentCfg.AllowedTools is non-nil, only tools
// whose names appear in it are kept; a nil list keeps every tool.
//
// It then creates three providers. The main provider uses the agent's
// configured provider, model type and system prompt. The title and summarize
// providers use the configured small model and its provider.
//
// Returns an error if any of the following is missing from the configuration:
// the task agent (for the coder), the agent's provider or model, or the small
// model or its provider. Provider construction errors are also returned.
//
// NOTE(review): the title provider here is created without
// provider.WithMaxTokens(40), while UpdateModel adds that option when it
// rebuilds the title provider — confirm which is intended.
func NewAgent(
	agentCfg config.Agent,
	// These services are needed in the tools
	permissions permission.Service,
	sessions session.Service,
	messages message.Service,
	history history.Service,
	lspClients map[string]*lsp.Client,
) (Service, error) {
	ctx := context.Background()
	cfg := config.Get()
	// MCP tools are fetched on every NewAgent call, including the recursive
	// call that builds the task agent.
	otherTools := GetMcpTools(ctx, permissions)
	if len(lspClients) > 0 {
		otherTools = append(otherTools, tools.NewDiagnosticsTool(lspClients))
	}

	allTools := []tools.BaseTool{
		tools.NewBashTool(permissions),
		tools.NewEditTool(lspClients, permissions, history),
		tools.NewFetchTool(permissions),
		tools.NewGlobTool(),
		tools.NewGrepTool(),
		tools.NewLsTool(),
		tools.NewSourcegraphTool(),
		tools.NewViewTool(lspClients),
		tools.NewVSCodeDiffTool(permissions),
		tools.NewWriteTool(lspClients, permissions, history),
	}

	// Only the coder agent gets a tool that delegates to a sub-agent; the
	// task agent itself is built through this same function.
	if agentCfg.ID == config.AgentCoder {
		taskAgentCfg := config.Get().Agents[config.AgentTask]
		if taskAgentCfg.ID == "" {
			return nil, fmt.Errorf("task agent not found in config")
		}
		taskAgent, err := NewAgent(taskAgentCfg, permissions, sessions, messages, history, lspClients)
		if err != nil {
			return nil, fmt.Errorf("failed to create task agent: %w", err)
		}

		allTools = append(
			allTools,
			NewAgentTool(
				taskAgent,
				sessions,
				messages,
			),
		)
	}

	allTools = append(allTools, otherTools...)
	providerCfg := config.GetAgentProvider(agentCfg.ID)
	if providerCfg.ID == "" {
		return nil, fmt.Errorf("provider for agent %s not found in config", agentCfg.Name)
	}
	// The resolved model is only validated here; the provider receives the
	// model type (agentCfg.Model) below.
	model := config.GetAgentModel(agentCfg.ID)

	if model.ID == "" {
		return nil, fmt.Errorf("model not found for agent %s", agentCfg.Name)
	}

	promptID := agentPromptMap[agentCfg.ID]
	if promptID == "" {
		promptID = prompt.PromptDefault
	}
	opts := []provider.ProviderClientOption{
		provider.WithModel(agentCfg.Model),
		provider.WithSystemMessage(prompt.GetPrompt(promptID, providerCfg.ID)),
	}
	agentProvider, err := provider.NewProvider(providerCfg, opts...)
	if err != nil {
		return nil, err
	}

	// Resolve the small model's provider, reusing the main provider config
	// when they are the same.
	smallModelCfg := cfg.Models.Small
	var smallModel config.Model

	var smallModelProviderCfg config.ProviderConfig
	if smallModelCfg.Provider == providerCfg.ID {
		smallModelProviderCfg = providerCfg
	} else {
		for _, p := range cfg.Providers {
			if p.ID == smallModelCfg.Provider {
				smallModelProviderCfg = p
				break
			}
		}
		if smallModelProviderCfg.ID == "" {
			return nil, fmt.Errorf("provider %s not found in config", smallModelCfg.Provider)
		}
	}
	// smallModel is only used to check that the configured model exists.
	for _, m := range smallModelProviderCfg.Models {
		if m.ID == smallModelCfg.ModelID {
			smallModel = m
			break
		}
	}
	if smallModel.ID == "" {
		return nil, fmt.Errorf("model %s not found in provider %s", smallModelCfg.ModelID, smallModelProviderCfg.ID)
	}

	titleOpts := []provider.ProviderClientOption{
		provider.WithModel(config.SmallModel),
		provider.WithSystemMessage(prompt.GetPrompt(prompt.PromptTitle, smallModelProviderCfg.ID)),
	}
	titleProvider, err := provider.NewProvider(smallModelProviderCfg, titleOpts...)
	if err != nil {
		return nil, err
	}
	summarizeOpts := []provider.ProviderClientOption{
		provider.WithModel(config.SmallModel),
		provider.WithSystemMessage(prompt.GetPrompt(prompt.PromptSummarizer, smallModelProviderCfg.ID)),
	}
	summarizeProvider, err := provider.NewProvider(smallModelProviderCfg, summarizeOpts...)
	if err != nil {
		return nil, err
	}

	// A nil AllowedTools means "no restriction"; a non-nil (even empty) list
	// keeps only the named tools.
	agentTools := []tools.BaseTool{}
	if agentCfg.AllowedTools == nil {
		agentTools = allTools
	} else {
		for _, tool := range allTools {
			if slices.Contains(agentCfg.AllowedTools, tool.Name()) {
				agentTools = append(agentTools, tool)
			}
		}
	}

	agent := &agent{
		Broker:              pubsub.NewBroker[AgentEvent](),
		agentCfg:            agentCfg,
		provider:            agentProvider,
		providerID:          string(providerCfg.ID),
		messages:            messages,
		sessions:            sessions,
		tools:               agentTools,
		titleProvider:       titleProvider,
		summarizeProvider:   summarizeProvider,
		summarizeProviderID: string(smallModelProviderCfg.ID),
		activeRequests:      sync.Map{},
	}

	return agent, nil
}
229
230func (a *agent) Model() config.Model {
231	return config.GetAgentModel(a.agentCfg.ID)
232}
233
234func (a *agent) EffectiveMaxTokens() int64 {
235	return config.GetAgentEffectiveMaxTokens(a.agentCfg.ID)
236}
237
238func (a *agent) Cancel(sessionID string) {
239	// Cancel regular requests
240	if cancelFunc, exists := a.activeRequests.LoadAndDelete(sessionID); exists {
241		if cancel, ok := cancelFunc.(context.CancelFunc); ok {
242			logging.InfoPersist(fmt.Sprintf("Request cancellation initiated for session: %s", sessionID))
243			cancel()
244		}
245	}
246
247	// Also check for summarize requests
248	if cancelFunc, exists := a.activeRequests.LoadAndDelete(sessionID + "-summarize"); exists {
249		if cancel, ok := cancelFunc.(context.CancelFunc); ok {
250			logging.InfoPersist(fmt.Sprintf("Summarize cancellation initiated for session: %s", sessionID))
251			cancel()
252		}
253	}
254}
255
256func (a *agent) IsBusy() bool {
257	busy := false
258	a.activeRequests.Range(func(key, value any) bool {
259		if cancelFunc, ok := value.(context.CancelFunc); ok {
260			if cancelFunc != nil {
261				busy = true
262				return false
263			}
264		}
265		return true
266	})
267	return busy
268}
269
270func (a *agent) IsSessionBusy(sessionID string) bool {
271	_, busy := a.activeRequests.Load(sessionID)
272	return busy
273}
274
275func (a *agent) generateTitle(ctx context.Context, sessionID string, content string) error {
276	if content == "" {
277		return nil
278	}
279	if a.titleProvider == nil {
280		return nil
281	}
282	session, err := a.sessions.Get(ctx, sessionID)
283	if err != nil {
284		return err
285	}
286	parts := []message.ContentPart{message.TextContent{
287		Text: fmt.Sprintf("Generate a concise title for the following content:\n\n%s", content),
288	}}
289
290	// Use streaming approach like summarization
291	response := a.titleProvider.StreamResponse(
292		ctx,
293		[]message.Message{
294			{
295				Role:  message.User,
296				Parts: parts,
297			},
298		},
299		make([]tools.BaseTool, 0),
300	)
301
302	var finalResponse *provider.ProviderResponse
303	for r := range response {
304		if r.Error != nil {
305			return r.Error
306		}
307		finalResponse = r.Response
308	}
309
310	if finalResponse == nil {
311		return fmt.Errorf("no response received from title provider")
312	}
313
314	title := strings.TrimSpace(strings.ReplaceAll(finalResponse.Content, "\n", " "))
315	if title == "" {
316		return nil
317	}
318
319	session.Title = title
320	_, err = a.sessions.Save(ctx, session)
321	return err
322}
323
324func (a *agent) err(err error) AgentEvent {
325	return AgentEvent{
326		Type:  AgentEventTypeError,
327		Error: err,
328	}
329}
330
331func (a *agent) Run(ctx context.Context, sessionID string, content string, attachments ...message.Attachment) (<-chan AgentEvent, error) {
332	if !a.Model().SupportsImages && attachments != nil {
333		attachments = nil
334	}
335	events := make(chan AgentEvent)
336	if a.IsSessionBusy(sessionID) {
337		return nil, ErrSessionBusy
338	}
339
340	genCtx, cancel := context.WithCancel(ctx)
341
342	a.activeRequests.Store(sessionID, cancel)
343	go func() {
344		logging.Debug("Request started", "sessionID", sessionID)
345		defer logging.RecoverPanic("agent.Run", func() {
346			events <- a.err(fmt.Errorf("panic while running the agent"))
347		})
348		var attachmentParts []message.ContentPart
349		for _, attachment := range attachments {
350			attachmentParts = append(attachmentParts, message.BinaryContent{Path: attachment.FilePath, MIMEType: attachment.MimeType, Data: attachment.Content})
351		}
352		result := a.processGeneration(genCtx, sessionID, content, attachmentParts)
353		if result.Error != nil && !errors.Is(result.Error, ErrRequestCancelled) && !errors.Is(result.Error, context.Canceled) {
354			logging.ErrorPersist(result.Error.Error())
355		}
356		logging.Debug("Request completed", "sessionID", sessionID)
357		a.activeRequests.Delete(sessionID)
358		cancel()
359		a.Publish(pubsub.CreatedEvent, result)
360		events <- result
361		close(events)
362	}()
363	return events, nil
364}
365
// processGeneration runs one user turn for sessionID and returns the
// AgentEvent that Run delivers to its caller.
//
// Steps:
//  1. Load the session history. When a summary message is set, the history
//     starts at that message.
//  2. Persist content and attachmentParts as a new user message.
//  3. Repeatedly stream a model response and execute the requested tools
//     until the model finishes for a reason other than tool use.
//
// On the first turn of a session it also starts title generation in the
// background.
//
// Returns an AgentEventTypeResponse event with the final assistant message on
// success, or an AgentEventTypeError event. Cancellation surfacing from
// streaming is reported as ErrRequestCancelled; cancellation detected between
// iterations is reported as ctx.Err().
func (a *agent) processGeneration(ctx context.Context, sessionID, content string, attachmentParts []message.ContentPart) AgentEvent {
	cfg := config.Get()
	// List existing messages; if none, start title generation asynchronously.
	msgs, err := a.messages.List(ctx, sessionID)
	if err != nil {
		return a.err(fmt.Errorf("failed to list messages: %w", err))
	}
	if len(msgs) == 0 {
		// The title goroutine uses context.Background(), so cancelling this
		// request does not stop it.
		go func() {
			defer logging.RecoverPanic("agent.Run", func() {
				logging.ErrorPersist("panic while generating title")
			})
			titleErr := a.generateTitle(context.Background(), sessionID, content)
			if titleErr != nil && !errors.Is(titleErr, context.Canceled) && !errors.Is(titleErr, context.DeadlineExceeded) {
				logging.ErrorPersist(fmt.Sprintf("failed to generate title: %v", titleErr))
			}
		}()
	}
	session, err := a.sessions.Get(ctx, sessionID)
	if err != nil {
		return a.err(fmt.Errorf("failed to get session: %w", err))
	}
	if session.SummaryMessageID != "" {
		// Drop everything before the summary message and send the summary
		// with the User role (presumably so it is presented to the model as
		// context for the conversation — confirm). If the summary message is
		// not found, the full history is kept.
		summaryMsgInex := -1
		for i, msg := range msgs {
			if msg.ID == session.SummaryMessageID {
				summaryMsgInex = i
				break
			}
		}
		if summaryMsgInex != -1 {
			msgs = msgs[summaryMsgInex:]
			msgs[0].Role = message.User
		}
	}

	userMsg, err := a.createUserMessage(ctx, sessionID, content, attachmentParts)
	if err != nil {
		return a.err(fmt.Errorf("failed to create user message: %w", err))
	}
	// Append the new user message to the conversation history.
	msgHistory := append(msgs, userMsg)

	for {
		// Check for cancellation before each iteration
		select {
		case <-ctx.Done():
			return a.err(ctx.Err())
		default:
			// Continue processing
		}
		agentMessage, toolResults, err := a.streamAndHandleEvents(ctx, sessionID, msgHistory)
		if err != nil {
			if errors.Is(err, context.Canceled) {
				// ctx is most likely cancelled already, so persist with a fresh
				// context. NOTE(review): the Update error is not checked.
				agentMessage.AddFinish(message.FinishReasonCanceled)
				a.messages.Update(context.Background(), agentMessage)
				return a.err(ErrRequestCancelled)
			}
			return a.err(fmt.Errorf("failed to process events: %w", err))
		}
		// In debug mode the tool results go to a JSON file and the log line
		// only references its path; otherwise they are logged inline.
		if cfg.Options.Debug {
			seqId := (len(msgHistory) + 1) / 2
			toolResultFilepath := logging.WriteToolResultsJson(sessionID, seqId, toolResults)
			logging.Info("Result", "message", agentMessage.FinishReason(), "toolResults", "{}", "filepath", toolResultFilepath)
		} else {
			logging.Info("Result", "message", agentMessage.FinishReason(), "toolResults", toolResults)
		}
		if (agentMessage.FinishReason() == message.FinishReasonToolUse) && toolResults != nil {
			// We are not done, we need to respond with the tool response
			msgHistory = append(msgHistory, agentMessage, *toolResults)
			continue
		}
		return AgentEvent{
			Type:    AgentEventTypeResponse,
			Message: agentMessage,
			Done:    true,
		}
	}
}
445
446func (a *agent) createUserMessage(ctx context.Context, sessionID, content string, attachmentParts []message.ContentPart) (message.Message, error) {
447	parts := []message.ContentPart{message.TextContent{Text: content}}
448	parts = append(parts, attachmentParts...)
449	return a.messages.Create(ctx, sessionID, message.CreateMessageParams{
450		Role:  message.User,
451		Parts: parts,
452	})
453}
454
455func (a *agent) streamAndHandleEvents(ctx context.Context, sessionID string, msgHistory []message.Message) (message.Message, *message.Message, error) {
456	ctx = context.WithValue(ctx, tools.SessionIDContextKey, sessionID)
457	eventChan := a.provider.StreamResponse(ctx, msgHistory, a.tools)
458
459	assistantMsg, err := a.messages.Create(ctx, sessionID, message.CreateMessageParams{
460		Role:     message.Assistant,
461		Parts:    []message.ContentPart{},
462		Model:    a.Model().ID,
463		Provider: a.providerID,
464	})
465	if err != nil {
466		return assistantMsg, nil, fmt.Errorf("failed to create assistant message: %w", err)
467	}
468
469	// Add the session and message ID into the context if needed by tools.
470	ctx = context.WithValue(ctx, tools.MessageIDContextKey, assistantMsg.ID)
471
472	// Process each event in the stream.
473	for event := range eventChan {
474		if processErr := a.processEvent(ctx, sessionID, &assistantMsg, event); processErr != nil {
475			a.finishMessage(ctx, &assistantMsg, message.FinishReasonCanceled)
476			return assistantMsg, nil, processErr
477		}
478		if ctx.Err() != nil {
479			a.finishMessage(context.Background(), &assistantMsg, message.FinishReasonCanceled)
480			return assistantMsg, nil, ctx.Err()
481		}
482	}
483
484	toolResults := make([]message.ToolResult, len(assistantMsg.ToolCalls()))
485	toolCalls := assistantMsg.ToolCalls()
486	for i, toolCall := range toolCalls {
487		select {
488		case <-ctx.Done():
489			a.finishMessage(context.Background(), &assistantMsg, message.FinishReasonCanceled)
490			// Make all future tool calls cancelled
491			for j := i; j < len(toolCalls); j++ {
492				toolResults[j] = message.ToolResult{
493					ToolCallID: toolCalls[j].ID,
494					Content:    "Tool execution canceled by user",
495					IsError:    true,
496				}
497			}
498			goto out
499		default:
500			// Continue processing
501			var tool tools.BaseTool
502			for _, availableTool := range a.tools {
503				if availableTool.Info().Name == toolCall.Name {
504					tool = availableTool
505					break
506				}
507			}
508
509			// Tool not found
510			if tool == nil {
511				toolResults[i] = message.ToolResult{
512					ToolCallID: toolCall.ID,
513					Content:    fmt.Sprintf("Tool not found: %s", toolCall.Name),
514					IsError:    true,
515				}
516				continue
517			}
518			toolResult, toolErr := tool.Run(ctx, tools.ToolCall{
519				ID:    toolCall.ID,
520				Name:  toolCall.Name,
521				Input: toolCall.Input,
522			})
523			if toolErr != nil {
524				if errors.Is(toolErr, permission.ErrorPermissionDenied) {
525					toolResults[i] = message.ToolResult{
526						ToolCallID: toolCall.ID,
527						Content:    "Permission denied",
528						IsError:    true,
529					}
530					for j := i + 1; j < len(toolCalls); j++ {
531						toolResults[j] = message.ToolResult{
532							ToolCallID: toolCalls[j].ID,
533							Content:    "Tool execution canceled by user",
534							IsError:    true,
535						}
536					}
537					a.finishMessage(ctx, &assistantMsg, message.FinishReasonPermissionDenied)
538					break
539				}
540			}
541			toolResults[i] = message.ToolResult{
542				ToolCallID: toolCall.ID,
543				Content:    toolResult.Content,
544				Metadata:   toolResult.Metadata,
545				IsError:    toolResult.IsError,
546			}
547		}
548	}
549out:
550	if len(toolResults) == 0 {
551		return assistantMsg, nil, nil
552	}
553	parts := make([]message.ContentPart, 0)
554	for _, tr := range toolResults {
555		parts = append(parts, tr)
556	}
557	msg, err := a.messages.Create(context.Background(), assistantMsg.SessionID, message.CreateMessageParams{
558		Role:     message.Tool,
559		Parts:    parts,
560		Provider: a.providerID,
561	})
562	if err != nil {
563		return assistantMsg, nil, fmt.Errorf("failed to create cancelled tool message: %w", err)
564	}
565
566	return assistantMsg, &msg, err
567}
568
569func (a *agent) finishMessage(ctx context.Context, msg *message.Message, finishReson message.FinishReason) {
570	msg.AddFinish(finishReson)
571	_ = a.messages.Update(ctx, *msg)
572}
573
574func (a *agent) processEvent(ctx context.Context, sessionID string, assistantMsg *message.Message, event provider.ProviderEvent) error {
575	select {
576	case <-ctx.Done():
577		return ctx.Err()
578	default:
579		// Continue processing.
580	}
581
582	switch event.Type {
583	case provider.EventThinkingDelta:
584		assistantMsg.AppendReasoningContent(event.Content)
585		return a.messages.Update(ctx, *assistantMsg)
586	case provider.EventContentDelta:
587		assistantMsg.AppendContent(event.Content)
588		return a.messages.Update(ctx, *assistantMsg)
589	case provider.EventToolUseStart:
590		logging.Info("Tool call started", "toolCall", event.ToolCall)
591		assistantMsg.AddToolCall(*event.ToolCall)
592		return a.messages.Update(ctx, *assistantMsg)
593	case provider.EventToolUseDelta:
594		assistantMsg.AppendToolCallInput(event.ToolCall.ID, event.ToolCall.Input)
595		return a.messages.Update(ctx, *assistantMsg)
596	case provider.EventToolUseStop:
597		logging.Info("Finished tool call", "toolCall", event.ToolCall)
598		assistantMsg.FinishToolCall(event.ToolCall.ID)
599		return a.messages.Update(ctx, *assistantMsg)
600	case provider.EventError:
601		if errors.Is(event.Error, context.Canceled) {
602			logging.InfoPersist(fmt.Sprintf("Event processing canceled for session: %s", sessionID))
603			return context.Canceled
604		}
605		logging.ErrorPersist(event.Error.Error())
606		return event.Error
607	case provider.EventComplete:
608		assistantMsg.SetToolCalls(event.Response.ToolCalls)
609		assistantMsg.AddFinish(event.Response.FinishReason)
610		if err := a.messages.Update(ctx, *assistantMsg); err != nil {
611			return fmt.Errorf("failed to update message: %w", err)
612		}
613		return a.TrackUsage(ctx, sessionID, a.Model(), event.Response.Usage)
614	}
615
616	return nil
617}
618
619func (a *agent) TrackUsage(ctx context.Context, sessionID string, model config.Model, usage provider.TokenUsage) error {
620	sess, err := a.sessions.Get(ctx, sessionID)
621	if err != nil {
622		return fmt.Errorf("failed to get session: %w", err)
623	}
624
625	cost := model.CostPer1MInCached/1e6*float64(usage.CacheCreationTokens) +
626		model.CostPer1MOutCached/1e6*float64(usage.CacheReadTokens) +
627		model.CostPer1MIn/1e6*float64(usage.InputTokens) +
628		model.CostPer1MOut/1e6*float64(usage.OutputTokens)
629
630	sess.Cost += cost
631	sess.CompletionTokens = usage.OutputTokens + usage.CacheReadTokens
632	sess.PromptTokens = usage.InputTokens + usage.CacheCreationTokens
633
634	_, err = a.sessions.Save(ctx, sess)
635	if err != nil {
636		return fmt.Errorf("failed to save session: %w", err)
637	}
638	return nil
639}
640
641func (a *agent) Summarize(ctx context.Context, sessionID string) error {
642	if a.summarizeProvider == nil {
643		return fmt.Errorf("summarize provider not available")
644	}
645
646	// Check if session is busy
647	if a.IsSessionBusy(sessionID) {
648		return ErrSessionBusy
649	}
650
651	// Create a new context with cancellation
652	summarizeCtx, cancel := context.WithCancel(ctx)
653
654	// Store the cancel function in activeRequests to allow cancellation
655	a.activeRequests.Store(sessionID+"-summarize", cancel)
656
657	go func() {
658		defer a.activeRequests.Delete(sessionID + "-summarize")
659		defer cancel()
660		event := AgentEvent{
661			Type:     AgentEventTypeSummarize,
662			Progress: "Starting summarization...",
663		}
664
665		a.Publish(pubsub.CreatedEvent, event)
666		// Get all messages from the session
667		msgs, err := a.messages.List(summarizeCtx, sessionID)
668		if err != nil {
669			event = AgentEvent{
670				Type:  AgentEventTypeError,
671				Error: fmt.Errorf("failed to list messages: %w", err),
672				Done:  true,
673			}
674			a.Publish(pubsub.CreatedEvent, event)
675			return
676		}
677		summarizeCtx = context.WithValue(summarizeCtx, tools.SessionIDContextKey, sessionID)
678
679		if len(msgs) == 0 {
680			event = AgentEvent{
681				Type:  AgentEventTypeError,
682				Error: fmt.Errorf("no messages to summarize"),
683				Done:  true,
684			}
685			a.Publish(pubsub.CreatedEvent, event)
686			return
687		}
688
689		event = AgentEvent{
690			Type:     AgentEventTypeSummarize,
691			Progress: "Analyzing conversation...",
692		}
693		a.Publish(pubsub.CreatedEvent, event)
694
695		// Add a system message to guide the summarization
696		summarizePrompt := "Provide a detailed but concise summary of our conversation above. Focus on information that would be helpful for continuing the conversation, including what we did, what we're doing, which files we're working on, and what we're going to do next."
697
698		// Create a new message with the summarize prompt
699		promptMsg := message.Message{
700			Role:  message.User,
701			Parts: []message.ContentPart{message.TextContent{Text: summarizePrompt}},
702		}
703
704		// Append the prompt to the messages
705		msgsWithPrompt := append(msgs, promptMsg)
706
707		event = AgentEvent{
708			Type:     AgentEventTypeSummarize,
709			Progress: "Generating summary...",
710		}
711
712		a.Publish(pubsub.CreatedEvent, event)
713
714		// Send the messages to the summarize provider
715		response := a.summarizeProvider.StreamResponse(
716			summarizeCtx,
717			msgsWithPrompt,
718			make([]tools.BaseTool, 0),
719		)
720		var finalResponse *provider.ProviderResponse
721		for r := range response {
722			if r.Error != nil {
723				event = AgentEvent{
724					Type:  AgentEventTypeError,
725					Error: fmt.Errorf("failed to summarize: %w", err),
726					Done:  true,
727				}
728				a.Publish(pubsub.CreatedEvent, event)
729				return
730			}
731			finalResponse = r.Response
732		}
733
734		summary := strings.TrimSpace(finalResponse.Content)
735		if summary == "" {
736			event = AgentEvent{
737				Type:  AgentEventTypeError,
738				Error: fmt.Errorf("empty summary returned"),
739				Done:  true,
740			}
741			a.Publish(pubsub.CreatedEvent, event)
742			return
743		}
744		event = AgentEvent{
745			Type:     AgentEventTypeSummarize,
746			Progress: "Creating new session...",
747		}
748
749		a.Publish(pubsub.CreatedEvent, event)
750		oldSession, err := a.sessions.Get(summarizeCtx, sessionID)
751		if err != nil {
752			event = AgentEvent{
753				Type:  AgentEventTypeError,
754				Error: fmt.Errorf("failed to get session: %w", err),
755				Done:  true,
756			}
757
758			a.Publish(pubsub.CreatedEvent, event)
759			return
760		}
761		// Create a message in the new session with the summary
762		msg, err := a.messages.Create(summarizeCtx, oldSession.ID, message.CreateMessageParams{
763			Role: message.Assistant,
764			Parts: []message.ContentPart{
765				message.TextContent{Text: summary},
766				message.Finish{
767					Reason: message.FinishReasonEndTurn,
768					Time:   time.Now().Unix(),
769				},
770			},
771			Model:    a.summarizeProvider.Model().ID,
772			Provider: a.summarizeProviderID,
773		})
774		if err != nil {
775			event = AgentEvent{
776				Type:  AgentEventTypeError,
777				Error: fmt.Errorf("failed to create summary message: %w", err),
778				Done:  true,
779			}
780
781			a.Publish(pubsub.CreatedEvent, event)
782			return
783		}
784		oldSession.SummaryMessageID = msg.ID
785		oldSession.CompletionTokens = finalResponse.Usage.OutputTokens
786		oldSession.PromptTokens = 0
787		model := a.summarizeProvider.Model()
788		usage := finalResponse.Usage
789		cost := model.CostPer1MInCached/1e6*float64(usage.CacheCreationTokens) +
790			model.CostPer1MOutCached/1e6*float64(usage.CacheReadTokens) +
791			model.CostPer1MIn/1e6*float64(usage.InputTokens) +
792			model.CostPer1MOut/1e6*float64(usage.OutputTokens)
793		oldSession.Cost += cost
794		_, err = a.sessions.Save(summarizeCtx, oldSession)
795		if err != nil {
796			event = AgentEvent{
797				Type:  AgentEventTypeError,
798				Error: fmt.Errorf("failed to save session: %w", err),
799				Done:  true,
800			}
801			a.Publish(pubsub.CreatedEvent, event)
802		}
803
804		event = AgentEvent{
805			Type:      AgentEventTypeSummarize,
806			SessionID: oldSession.ID,
807			Progress:  "Summary complete",
808			Done:      true,
809		}
810		a.Publish(pubsub.CreatedEvent, event)
811		// Send final success event with the new session ID
812	}()
813
814	return nil
815}
816
817func (a *agent) CancelAll() {
818	a.activeRequests.Range(func(key, value any) bool {
819		a.Cancel(key.(string)) // key is sessionID
820		return true
821	})
822}
823
// UpdateModel re-reads the configuration and rebuilds the agent's providers
// when their provider configuration changed.
//
// The main provider is rebuilt only when the agent's provider ID differs from
// a.providerID. The title and summarize providers are rebuilt only when the
// small model's provider ID differs from a.summarizeProviderID. A model change
// within the same provider rebuilds nothing here (presumably the providers
// resolve the concrete model from the configuration themselves — confirm).
//
// Returns an error if a required provider or model is missing from the
// configuration, or if creating a provider fails. In that case any provider
// already swapped in earlier in this call stays in place.
//
// NOTE(review): a.provider and the other provider fields are reassigned
// without locking, while requests started by Run read a.provider — confirm
// callers only invoke this while the agent is idle.
// NOTE(review): the title provider rebuilt here gets WithMaxTokens(40), but
// NewAgent creates it without that option — confirm which is intended.
func (a *agent) UpdateModel() error {
	cfg := config.Get()

	// Get current provider configuration
	currentProviderCfg := config.GetAgentProvider(a.agentCfg.ID)
	if currentProviderCfg.ID == "" {
		return fmt.Errorf("provider for agent %s not found in config", a.agentCfg.Name)
	}

	// Check if provider has changed
	if string(currentProviderCfg.ID) != a.providerID {
		// Provider changed, need to recreate the main provider
		model := config.GetAgentModel(a.agentCfg.ID)
		if model.ID == "" {
			return fmt.Errorf("model not found for agent %s", a.agentCfg.Name)
		}

		promptID := agentPromptMap[a.agentCfg.ID]
		if promptID == "" {
			promptID = prompt.PromptDefault
		}

		opts := []provider.ProviderClientOption{
			provider.WithModel(a.agentCfg.Model),
			provider.WithSystemMessage(prompt.GetPrompt(promptID, currentProviderCfg.ID)),
		}

		newProvider, err := provider.NewProvider(currentProviderCfg, opts...)
		if err != nil {
			return fmt.Errorf("failed to create new provider: %w", err)
		}

		// Update the provider and provider ID
		a.provider = newProvider
		a.providerID = string(currentProviderCfg.ID)
	}

	// Check if small model provider has changed (affects title and summarize providers)
	smallModelCfg := cfg.Models.Small
	var smallModelProviderCfg config.ProviderConfig

	for _, p := range cfg.Providers {
		if p.ID == smallModelCfg.Provider {
			smallModelProviderCfg = p
			break
		}
	}

	if smallModelProviderCfg.ID == "" {
		return fmt.Errorf("provider %s not found in config", smallModelCfg.Provider)
	}

	// Check if summarize provider has changed
	if string(smallModelProviderCfg.ID) != a.summarizeProviderID {
		// smallModel is only used to validate that the configured model exists.
		var smallModel config.Model
		for _, m := range smallModelProviderCfg.Models {
			if m.ID == smallModelCfg.ModelID {
				smallModel = m
				break
			}
		}
		if smallModel.ID == "" {
			return fmt.Errorf("model %s not found in provider %s", smallModelCfg.ModelID, smallModelProviderCfg.ID)
		}

		// Recreate title provider
		titleOpts := []provider.ProviderClientOption{
			provider.WithModel(config.SmallModel),
			provider.WithSystemMessage(prompt.GetPrompt(prompt.PromptTitle, smallModelProviderCfg.ID)),
			// We want the title to be short, so we limit the max tokens
			provider.WithMaxTokens(40),
		}
		newTitleProvider, err := provider.NewProvider(smallModelProviderCfg, titleOpts...)
		if err != nil {
			return fmt.Errorf("failed to create new title provider: %w", err)
		}

		// Recreate summarize provider
		summarizeOpts := []provider.ProviderClientOption{
			provider.WithModel(config.SmallModel),
			provider.WithSystemMessage(prompt.GetPrompt(prompt.PromptSummarizer, smallModelProviderCfg.ID)),
		}
		newSummarizeProvider, err := provider.NewProvider(smallModelProviderCfg, summarizeOpts...)
		if err != nil {
			return fmt.Errorf("failed to create new summarize provider: %w", err)
		}

		// Update the providers and provider ID
		a.titleProvider = newTitleProvider
		a.summarizeProvider = newSummarizeProvider
		a.summarizeProviderID = string(smallModelProviderCfg.ID)
	}

	return nil
}