1package agent
  2
  3import (
  4	"context"
  5	"errors"
  6	"fmt"
  7	"slices"
  8	"strings"
  9	"sync"
 10	"time"
 11
 12	"github.com/charmbracelet/crush/internal/config"
 13	"github.com/charmbracelet/crush/internal/history"
 14	"github.com/charmbracelet/crush/internal/llm/prompt"
 15	"github.com/charmbracelet/crush/internal/llm/provider"
 16	"github.com/charmbracelet/crush/internal/llm/tools"
 17	"github.com/charmbracelet/crush/internal/logging"
 18	"github.com/charmbracelet/crush/internal/lsp"
 19	"github.com/charmbracelet/crush/internal/message"
 20	"github.com/charmbracelet/crush/internal/permission"
 21	"github.com/charmbracelet/crush/internal/pubsub"
 22	"github.com/charmbracelet/crush/internal/session"
 23)
 24
 25// Common errors
 26var (
 27	ErrRequestCancelled = errors.New("request cancelled by user")
 28	ErrSessionBusy      = errors.New("session is currently processing another request")
 29)
 30
 31type AgentEventType string
 32
 33const (
 34	AgentEventTypeError     AgentEventType = "error"
 35	AgentEventTypeResponse  AgentEventType = "response"
 36	AgentEventTypeSummarize AgentEventType = "summarize"
 37)
 38
// AgentEvent is emitted by the agent, both on the channel returned by Run and
// through the embedded pubsub broker.
type AgentEvent struct {
	// Type says which of the fields below are meaningful: Message for
	// response events, Error for error events.
	Type    AgentEventType
	Message message.Message
	Error   error

	// SessionID and Progress are filled in by Summarize. Done is set on the
	// last event of a Summarize and on a successful Run response.
	SessionID string
	Progress  string
	Done      bool
}
 49
// Service is the public interface of an agent. Subscribers receive the
// AgentEvents the agent publishes (Run results and Summarize progress).
type Service interface {
	pubsub.Suscriber[AgentEvent]
	// Model returns the model currently configured for this agent.
	Model() config.Model
	// EffectiveMaxTokens returns the agent's max-token limit as computed by
	// config.GetAgentEffectiveMaxTokens.
	EffectiveMaxTokens() int64
	// Run starts processing content (plus optional attachments) as a new user
	// message in sessionID and returns a channel that delivers the final
	// AgentEvent. It fails with ErrSessionBusy if the session is busy.
	Run(ctx context.Context, sessionID string, content string, attachments ...message.Attachment) (<-chan AgentEvent, error)
	// Cancel aborts any in-flight Run or Summarize for sessionID.
	Cancel(sessionID string)
	// CancelAll aborts every in-flight request.
	CancelAll()
	// IsSessionBusy reports whether a request is in flight for sessionID.
	IsSessionBusy(sessionID string) bool
	// IsBusy reports whether any request is in flight.
	IsBusy() bool
	// Summarize starts summarizing sessionID in the background; progress and
	// the outcome are published as AgentEvents.
	Summarize(ctx context.Context, sessionID string) error
	// UpdateModel rebuilds provider clients after the configured providers change.
	UpdateModel() error
}
 62
// agent is the Service implementation. It embeds a broker on which it
// publishes AgentEvents.
type agent struct {
	*pubsub.Broker[AgentEvent]
	agentCfg config.Agent
	sessions session.Service
	messages message.Service

	// Main model: the tools offered to it, its provider client, and the ID of
	// that provider (recorded on assistant messages and compared by UpdateModel).
	tools      []tools.BaseTool
	provider   provider.Provider
	providerID string

	// Small-model clients used for session titles and summaries, plus the
	// provider ID they were built for.
	titleProvider       provider.Provider
	summarizeProvider   provider.Provider
	summarizeProviderID string

	// In-flight requests: the session ID (Run) or session ID + "-summarize"
	// (Summarize) mapped to the context.CancelFunc that aborts it.
	activeRequests sync.Map
}
 79
// agentPromptMap selects the system prompt used for each agent's main
// provider. Agents without an entry fall back to prompt.PromptDefault.
var agentPromptMap = map[config.AgentID]prompt.PromptID{
	config.AgentCoder: prompt.PromptCoder,
	config.AgentTask:  prompt.PromptTask,
}
 84
 85func NewAgent(
 86	agentCfg config.Agent,
 87	// These services are needed in the tools
 88	permissions permission.Service,
 89	sessions session.Service,
 90	messages message.Service,
 91	history history.Service,
 92	lspClients map[string]*lsp.Client,
 93) (Service, error) {
 94	ctx := context.Background()
 95	cfg := config.Get()
 96	otherTools := GetMcpTools(ctx, permissions)
 97	if len(cfg.LSP) > 0 {
 98		otherTools = append(otherTools, tools.NewDiagnosticsTool(lspClients))
 99		otherTools = append(otherTools, tools.NewDefinitionsTool(lspClients))
100	}
101
102	allTools := []tools.BaseTool{
103		tools.NewBashTool(permissions),
104		tools.NewEditTool(lspClients, permissions, history),
105		tools.NewFetchTool(permissions),
106		tools.NewGlobTool(),
107		tools.NewGrepTool(),
108		tools.NewLsTool(),
109		tools.NewSourcegraphTool(),
110		tools.NewViewTool(lspClients),
111		tools.NewWriteTool(lspClients, permissions, history),
112	}
113
114	if agentCfg.ID == config.AgentCoder {
115		taskAgentCfg := config.Get().Agents[config.AgentTask]
116		if taskAgentCfg.ID == "" {
117			return nil, fmt.Errorf("task agent not found in config")
118		}
119		taskAgent, err := NewAgent(taskAgentCfg, permissions, sessions, messages, history, lspClients)
120		if err != nil {
121			return nil, fmt.Errorf("failed to create task agent: %w", err)
122		}
123
124		allTools = append(
125			allTools,
126			NewAgentTool(
127				taskAgent,
128				sessions,
129				messages,
130			),
131		)
132	}
133
134	allTools = append(allTools, otherTools...)
135	providerCfg := config.GetAgentProvider(agentCfg.ID)
136	if providerCfg.ID == "" {
137		return nil, fmt.Errorf("provider for agent %s not found in config", agentCfg.Name)
138	}
139	model := config.GetAgentModel(agentCfg.ID)
140
141	if model.ID == "" {
142		return nil, fmt.Errorf("model not found for agent %s", agentCfg.Name)
143	}
144
145	promptID := agentPromptMap[agentCfg.ID]
146	if promptID == "" {
147		promptID = prompt.PromptDefault
148	}
149	opts := []provider.ProviderClientOption{
150		provider.WithModel(agentCfg.Model),
151		provider.WithSystemMessage(prompt.GetPrompt(promptID, providerCfg.ID)),
152	}
153	agentProvider, err := provider.NewProvider(providerCfg, opts...)
154	if err != nil {
155		return nil, err
156	}
157
158	smallModelCfg := cfg.Models.Small
159	var smallModel config.Model
160
161	var smallModelProviderCfg config.ProviderConfig
162	if smallModelCfg.Provider == providerCfg.ID {
163		smallModelProviderCfg = providerCfg
164	} else {
165		for _, p := range cfg.Providers {
166			if p.ID == smallModelCfg.Provider {
167				smallModelProviderCfg = p
168				break
169			}
170		}
171		if smallModelProviderCfg.ID == "" {
172			return nil, fmt.Errorf("provider %s not found in config", smallModelCfg.Provider)
173		}
174	}
175	for _, m := range smallModelProviderCfg.Models {
176		if m.ID == smallModelCfg.ModelID {
177			smallModel = m
178			break
179		}
180	}
181	if smallModel.ID == "" {
182		return nil, fmt.Errorf("model %s not found in provider %s", smallModelCfg.ModelID, smallModelProviderCfg.ID)
183	}
184
185	titleOpts := []provider.ProviderClientOption{
186		provider.WithModel(config.SmallModel),
187		provider.WithSystemMessage(prompt.GetPrompt(prompt.PromptTitle, smallModelProviderCfg.ID)),
188	}
189	titleProvider, err := provider.NewProvider(smallModelProviderCfg, titleOpts...)
190	if err != nil {
191		return nil, err
192	}
193	summarizeOpts := []provider.ProviderClientOption{
194		provider.WithModel(config.SmallModel),
195		provider.WithSystemMessage(prompt.GetPrompt(prompt.PromptSummarizer, smallModelProviderCfg.ID)),
196	}
197	summarizeProvider, err := provider.NewProvider(smallModelProviderCfg, summarizeOpts...)
198	if err != nil {
199		return nil, err
200	}
201
202	agentTools := []tools.BaseTool{}
203	if agentCfg.AllowedTools == nil {
204		agentTools = allTools
205	} else {
206		for _, tool := range allTools {
207			if slices.Contains(agentCfg.AllowedTools, tool.Name()) {
208				agentTools = append(agentTools, tool)
209			}
210		}
211	}
212
213	agent := &agent{
214		Broker:              pubsub.NewBroker[AgentEvent](),
215		agentCfg:            agentCfg,
216		provider:            agentProvider,
217		providerID:          string(providerCfg.ID),
218		messages:            messages,
219		sessions:            sessions,
220		tools:               agentTools,
221		titleProvider:       titleProvider,
222		summarizeProvider:   summarizeProvider,
223		summarizeProviderID: string(smallModelProviderCfg.ID),
224		activeRequests:      sync.Map{},
225	}
226
227	return agent, nil
228}
229
230func (a *agent) Model() config.Model {
231	return config.GetAgentModel(a.agentCfg.ID)
232}
233
234func (a *agent) EffectiveMaxTokens() int64 {
235	return config.GetAgentEffectiveMaxTokens(a.agentCfg.ID)
236}
237
238func (a *agent) Cancel(sessionID string) {
239	// Cancel regular requests
240	if cancelFunc, exists := a.activeRequests.LoadAndDelete(sessionID); exists {
241		if cancel, ok := cancelFunc.(context.CancelFunc); ok {
242			logging.InfoPersist(fmt.Sprintf("Request cancellation initiated for session: %s", sessionID))
243			cancel()
244		}
245	}
246
247	// Also check for summarize requests
248	if cancelFunc, exists := a.activeRequests.LoadAndDelete(sessionID + "-summarize"); exists {
249		if cancel, ok := cancelFunc.(context.CancelFunc); ok {
250			logging.InfoPersist(fmt.Sprintf("Summarize cancellation initiated for session: %s", sessionID))
251			cancel()
252		}
253	}
254}
255
256func (a *agent) IsBusy() bool {
257	busy := false
258	a.activeRequests.Range(func(key, value any) bool {
259		if cancelFunc, ok := value.(context.CancelFunc); ok {
260			if cancelFunc != nil {
261				busy = true
262				return false
263			}
264		}
265		return true
266	})
267	return busy
268}
269
270func (a *agent) IsSessionBusy(sessionID string) bool {
271	_, busy := a.activeRequests.Load(sessionID)
272	return busy
273}
274
275func (a *agent) generateTitle(ctx context.Context, sessionID string, content string) error {
276	if content == "" {
277		return nil
278	}
279	if a.titleProvider == nil {
280		return nil
281	}
282	session, err := a.sessions.Get(ctx, sessionID)
283	if err != nil {
284		return err
285	}
286	parts := []message.ContentPart{message.TextContent{
287		Text: fmt.Sprintf("Generate a concise title for the following content:\n\n%s", content),
288	}}
289
290	// Use streaming approach like summarization
291	response := a.titleProvider.StreamResponse(
292		ctx,
293		[]message.Message{
294			{
295				Role:  message.User,
296				Parts: parts,
297			},
298		},
299		make([]tools.BaseTool, 0),
300	)
301
302	var finalResponse *provider.ProviderResponse
303	for r := range response {
304		if r.Error != nil {
305			return r.Error
306		}
307		finalResponse = r.Response
308	}
309
310	if finalResponse == nil {
311		return fmt.Errorf("no response received from title provider")
312	}
313
314	title := strings.TrimSpace(strings.ReplaceAll(finalResponse.Content, "\n", " "))
315	if title == "" {
316		return nil
317	}
318
319	session.Title = title
320	_, err = a.sessions.Save(ctx, session)
321	return err
322}
323
324func (a *agent) err(err error) AgentEvent {
325	return AgentEvent{
326		Type:  AgentEventTypeError,
327		Error: err,
328	}
329}
330
331func (a *agent) Run(ctx context.Context, sessionID string, content string, attachments ...message.Attachment) (<-chan AgentEvent, error) {
332	if !a.Model().SupportsImages && attachments != nil {
333		attachments = nil
334	}
335	events := make(chan AgentEvent)
336	if a.IsSessionBusy(sessionID) {
337		return nil, ErrSessionBusy
338	}
339
340	genCtx, cancel := context.WithCancel(ctx)
341
342	a.activeRequests.Store(sessionID, cancel)
343	go func() {
344		logging.Debug("Request started", "sessionID", sessionID)
345		defer logging.RecoverPanic("agent.Run", func() {
346			events <- a.err(fmt.Errorf("panic while running the agent"))
347		})
348		var attachmentParts []message.ContentPart
349		for _, attachment := range attachments {
350			attachmentParts = append(attachmentParts, message.BinaryContent{Path: attachment.FilePath, MIMEType: attachment.MimeType, Data: attachment.Content})
351		}
352		result := a.processGeneration(genCtx, sessionID, content, attachmentParts)
353		if result.Error != nil && !errors.Is(result.Error, ErrRequestCancelled) && !errors.Is(result.Error, context.Canceled) {
354			logging.ErrorPersist(result.Error.Error())
355		}
356		logging.Debug("Request completed", "sessionID", sessionID)
357		a.activeRequests.Delete(sessionID)
358		cancel()
359		a.Publish(pubsub.CreatedEvent, result)
360		events <- result
361		close(events)
362	}()
363	return events, nil
364}
365
366func (a *agent) processGeneration(ctx context.Context, sessionID, content string, attachmentParts []message.ContentPart) AgentEvent {
367	// List existing messages; if none, start title generation asynchronously.
368	msgs, err := a.messages.List(ctx, sessionID)
369	if err != nil {
370		return a.err(fmt.Errorf("failed to list messages: %w", err))
371	}
372	if len(msgs) == 0 {
373		go func() {
374			defer logging.RecoverPanic("agent.Run", func() {
375				logging.ErrorPersist("panic while generating title")
376			})
377			titleErr := a.generateTitle(context.Background(), sessionID, content)
378			if titleErr != nil && !errors.Is(titleErr, context.Canceled) && !errors.Is(titleErr, context.DeadlineExceeded) {
379				logging.ErrorPersist(fmt.Sprintf("failed to generate title: %v", titleErr))
380			}
381		}()
382	}
383	session, err := a.sessions.Get(ctx, sessionID)
384	if err != nil {
385		return a.err(fmt.Errorf("failed to get session: %w", err))
386	}
387	if session.SummaryMessageID != "" {
388		summaryMsgInex := -1
389		for i, msg := range msgs {
390			if msg.ID == session.SummaryMessageID {
391				summaryMsgInex = i
392				break
393			}
394		}
395		if summaryMsgInex != -1 {
396			msgs = msgs[summaryMsgInex:]
397			msgs[0].Role = message.User
398		}
399	}
400
401	userMsg, err := a.createUserMessage(ctx, sessionID, content, attachmentParts)
402	if err != nil {
403		return a.err(fmt.Errorf("failed to create user message: %w", err))
404	}
405	// Append the new user message to the conversation history.
406	msgHistory := append(msgs, userMsg)
407
408	for {
409		// Check for cancellation before each iteration
410		select {
411		case <-ctx.Done():
412			return a.err(ctx.Err())
413		default:
414			// Continue processing
415		}
416		agentMessage, toolResults, err := a.streamAndHandleEvents(ctx, sessionID, msgHistory)
417		if err != nil {
418			if errors.Is(err, context.Canceled) {
419				agentMessage.AddFinish(message.FinishReasonCanceled)
420				a.messages.Update(context.Background(), agentMessage)
421				return a.err(ErrRequestCancelled)
422			}
423			return a.err(fmt.Errorf("failed to process events: %w", err))
424		}
425		logging.Info("Result", "message", agentMessage.FinishReason(), "toolResults", toolResults)
426		if (agentMessage.FinishReason() == message.FinishReasonToolUse) && toolResults != nil {
427			// We are not done, we need to respond with the tool response
428			msgHistory = append(msgHistory, agentMessage, *toolResults)
429			continue
430		}
431		return AgentEvent{
432			Type:    AgentEventTypeResponse,
433			Message: agentMessage,
434			Done:    true,
435		}
436	}
437}
438
439func (a *agent) createUserMessage(ctx context.Context, sessionID, content string, attachmentParts []message.ContentPart) (message.Message, error) {
440	parts := []message.ContentPart{message.TextContent{Text: content}}
441	parts = append(parts, attachmentParts...)
442	return a.messages.Create(ctx, sessionID, message.CreateMessageParams{
443		Role:  message.User,
444		Parts: parts,
445	})
446}
447
448func (a *agent) streamAndHandleEvents(ctx context.Context, sessionID string, msgHistory []message.Message) (message.Message, *message.Message, error) {
449	eventChan := a.provider.StreamResponse(ctx, msgHistory, a.tools)
450
451	assistantMsg, err := a.messages.Create(ctx, sessionID, message.CreateMessageParams{
452		Role:     message.Assistant,
453		Parts:    []message.ContentPart{},
454		Model:    a.Model().ID,
455		Provider: a.providerID,
456	})
457	if err != nil {
458		return assistantMsg, nil, fmt.Errorf("failed to create assistant message: %w", err)
459	}
460
461	// Add the session and message ID into the context if needed by tools.
462	ctx = context.WithValue(ctx, tools.MessageIDContextKey, assistantMsg.ID)
463	ctx = context.WithValue(ctx, tools.SessionIDContextKey, sessionID)
464
465	// Process each event in the stream.
466	for event := range eventChan {
467		if processErr := a.processEvent(ctx, sessionID, &assistantMsg, event); processErr != nil {
468			a.finishMessage(ctx, &assistantMsg, message.FinishReasonCanceled)
469			return assistantMsg, nil, processErr
470		}
471		if ctx.Err() != nil {
472			a.finishMessage(context.Background(), &assistantMsg, message.FinishReasonCanceled)
473			return assistantMsg, nil, ctx.Err()
474		}
475	}
476
477	toolResults := make([]message.ToolResult, len(assistantMsg.ToolCalls()))
478	toolCalls := assistantMsg.ToolCalls()
479	for i, toolCall := range toolCalls {
480		select {
481		case <-ctx.Done():
482			a.finishMessage(context.Background(), &assistantMsg, message.FinishReasonCanceled)
483			// Make all future tool calls cancelled
484			for j := i; j < len(toolCalls); j++ {
485				toolResults[j] = message.ToolResult{
486					ToolCallID: toolCalls[j].ID,
487					Content:    "Tool execution canceled by user",
488					IsError:    true,
489				}
490			}
491			goto out
492		default:
493			// Continue processing
494			var tool tools.BaseTool
495			for _, availableTools := range a.tools {
496				if availableTools.Info().Name == toolCall.Name {
497					tool = availableTools
498				}
499			}
500
501			// Tool not found
502			if tool == nil {
503				toolResults[i] = message.ToolResult{
504					ToolCallID: toolCall.ID,
505					Content:    fmt.Sprintf("Tool not found: %s", toolCall.Name),
506					IsError:    true,
507				}
508				continue
509			}
510			toolResult, toolErr := tool.Run(ctx, tools.ToolCall{
511				ID:    toolCall.ID,
512				Name:  toolCall.Name,
513				Input: toolCall.Input,
514			})
515			if toolErr != nil {
516				if errors.Is(toolErr, permission.ErrorPermissionDenied) {
517					toolResults[i] = message.ToolResult{
518						ToolCallID: toolCall.ID,
519						Content:    "Permission denied",
520						IsError:    true,
521					}
522					for j := i + 1; j < len(toolCalls); j++ {
523						toolResults[j] = message.ToolResult{
524							ToolCallID: toolCalls[j].ID,
525							Content:    "Tool execution canceled by user",
526							IsError:    true,
527						}
528					}
529					a.finishMessage(ctx, &assistantMsg, message.FinishReasonPermissionDenied)
530					break
531				}
532			}
533			toolResults[i] = message.ToolResult{
534				ToolCallID: toolCall.ID,
535				Content:    toolResult.Content,
536				Metadata:   toolResult.Metadata,
537				IsError:    toolResult.IsError,
538			}
539		}
540	}
541out:
542	if len(toolResults) == 0 {
543		return assistantMsg, nil, nil
544	}
545	parts := make([]message.ContentPart, 0)
546	for _, tr := range toolResults {
547		parts = append(parts, tr)
548	}
549	msg, err := a.messages.Create(context.Background(), assistantMsg.SessionID, message.CreateMessageParams{
550		Role:     message.Tool,
551		Parts:    parts,
552		Provider: a.providerID,
553	})
554	if err != nil {
555		return assistantMsg, nil, fmt.Errorf("failed to create cancelled tool message: %w", err)
556	}
557
558	return assistantMsg, &msg, err
559}
560
561func (a *agent) finishMessage(ctx context.Context, msg *message.Message, finishReson message.FinishReason) {
562	msg.AddFinish(finishReson)
563	_ = a.messages.Update(ctx, *msg)
564}
565
566func (a *agent) processEvent(ctx context.Context, sessionID string, assistantMsg *message.Message, event provider.ProviderEvent) error {
567	select {
568	case <-ctx.Done():
569		return ctx.Err()
570	default:
571		// Continue processing.
572	}
573
574	switch event.Type {
575	case provider.EventThinkingDelta:
576		assistantMsg.AppendReasoningContent(event.Content)
577		return a.messages.Update(ctx, *assistantMsg)
578	case provider.EventContentDelta:
579		assistantMsg.AppendContent(event.Content)
580		return a.messages.Update(ctx, *assistantMsg)
581	case provider.EventToolUseStart:
582		logging.Info("Tool call started", "toolCall", event.ToolCall)
583		assistantMsg.AddToolCall(*event.ToolCall)
584		return a.messages.Update(ctx, *assistantMsg)
585	case provider.EventToolUseDelta:
586		assistantMsg.AppendToolCallInput(event.ToolCall.ID, event.ToolCall.Input)
587		return a.messages.Update(ctx, *assistantMsg)
588	case provider.EventToolUseStop:
589		logging.Info("Finished tool call", "toolCall", event.ToolCall)
590		assistantMsg.FinishToolCall(event.ToolCall.ID)
591		return a.messages.Update(ctx, *assistantMsg)
592	case provider.EventError:
593		if errors.Is(event.Error, context.Canceled) {
594			logging.InfoPersist(fmt.Sprintf("Event processing canceled for session: %s", sessionID))
595			return context.Canceled
596		}
597		logging.ErrorPersist(event.Error.Error())
598		return event.Error
599	case provider.EventComplete:
600		assistantMsg.SetToolCalls(event.Response.ToolCalls)
601		assistantMsg.AddFinish(event.Response.FinishReason)
602		if err := a.messages.Update(ctx, *assistantMsg); err != nil {
603			return fmt.Errorf("failed to update message: %w", err)
604		}
605		return a.TrackUsage(ctx, sessionID, a.Model(), event.Response.Usage)
606	}
607
608	return nil
609}
610
611func (a *agent) TrackUsage(ctx context.Context, sessionID string, model config.Model, usage provider.TokenUsage) error {
612	sess, err := a.sessions.Get(ctx, sessionID)
613	if err != nil {
614		return fmt.Errorf("failed to get session: %w", err)
615	}
616
617	cost := model.CostPer1MInCached/1e6*float64(usage.CacheCreationTokens) +
618		model.CostPer1MOutCached/1e6*float64(usage.CacheReadTokens) +
619		model.CostPer1MIn/1e6*float64(usage.InputTokens) +
620		model.CostPer1MOut/1e6*float64(usage.OutputTokens)
621
622	sess.Cost += cost
623	sess.CompletionTokens = usage.OutputTokens + usage.CacheReadTokens
624	sess.PromptTokens = usage.InputTokens + usage.CacheCreationTokens
625
626	_, err = a.sessions.Save(ctx, sess)
627	if err != nil {
628		return fmt.Errorf("failed to save session: %w", err)
629	}
630	return nil
631}
632
633func (a *agent) Summarize(ctx context.Context, sessionID string) error {
634	if a.summarizeProvider == nil {
635		return fmt.Errorf("summarize provider not available")
636	}
637
638	// Check if session is busy
639	if a.IsSessionBusy(sessionID) {
640		return ErrSessionBusy
641	}
642
643	// Create a new context with cancellation
644	summarizeCtx, cancel := context.WithCancel(ctx)
645
646	// Store the cancel function in activeRequests to allow cancellation
647	a.activeRequests.Store(sessionID+"-summarize", cancel)
648
649	go func() {
650		defer a.activeRequests.Delete(sessionID + "-summarize")
651		defer cancel()
652		event := AgentEvent{
653			Type:     AgentEventTypeSummarize,
654			Progress: "Starting summarization...",
655		}
656
657		a.Publish(pubsub.CreatedEvent, event)
658		// Get all messages from the session
659		msgs, err := a.messages.List(summarizeCtx, sessionID)
660		if err != nil {
661			event = AgentEvent{
662				Type:  AgentEventTypeError,
663				Error: fmt.Errorf("failed to list messages: %w", err),
664				Done:  true,
665			}
666			a.Publish(pubsub.CreatedEvent, event)
667			return
668		}
669
670		if len(msgs) == 0 {
671			event = AgentEvent{
672				Type:  AgentEventTypeError,
673				Error: fmt.Errorf("no messages to summarize"),
674				Done:  true,
675			}
676			a.Publish(pubsub.CreatedEvent, event)
677			return
678		}
679
680		event = AgentEvent{
681			Type:     AgentEventTypeSummarize,
682			Progress: "Analyzing conversation...",
683		}
684		a.Publish(pubsub.CreatedEvent, event)
685
686		// Add a system message to guide the summarization
687		summarizePrompt := "Provide a detailed but concise summary of our conversation above. Focus on information that would be helpful for continuing the conversation, including what we did, what we're doing, which files we're working on, and what we're going to do next."
688
689		// Create a new message with the summarize prompt
690		promptMsg := message.Message{
691			Role:  message.User,
692			Parts: []message.ContentPart{message.TextContent{Text: summarizePrompt}},
693		}
694
695		// Append the prompt to the messages
696		msgsWithPrompt := append(msgs, promptMsg)
697
698		event = AgentEvent{
699			Type:     AgentEventTypeSummarize,
700			Progress: "Generating summary...",
701		}
702
703		a.Publish(pubsub.CreatedEvent, event)
704
705		// Send the messages to the summarize provider
706		response := a.summarizeProvider.StreamResponse(
707			summarizeCtx,
708			msgsWithPrompt,
709			make([]tools.BaseTool, 0),
710		)
711		var finalResponse *provider.ProviderResponse
712		for r := range response {
713			if r.Error != nil {
714				event = AgentEvent{
715					Type:  AgentEventTypeError,
716					Error: fmt.Errorf("failed to summarize: %w", err),
717					Done:  true,
718				}
719				a.Publish(pubsub.CreatedEvent, event)
720				return
721			}
722			finalResponse = r.Response
723		}
724
725		summary := strings.TrimSpace(finalResponse.Content)
726		if summary == "" {
727			event = AgentEvent{
728				Type:  AgentEventTypeError,
729				Error: fmt.Errorf("empty summary returned"),
730				Done:  true,
731			}
732			a.Publish(pubsub.CreatedEvent, event)
733			return
734		}
735		event = AgentEvent{
736			Type:     AgentEventTypeSummarize,
737			Progress: "Creating new session...",
738		}
739
740		a.Publish(pubsub.CreatedEvent, event)
741		oldSession, err := a.sessions.Get(summarizeCtx, sessionID)
742		if err != nil {
743			event = AgentEvent{
744				Type:  AgentEventTypeError,
745				Error: fmt.Errorf("failed to get session: %w", err),
746				Done:  true,
747			}
748
749			a.Publish(pubsub.CreatedEvent, event)
750			return
751		}
752		// Create a message in the new session with the summary
753		msg, err := a.messages.Create(summarizeCtx, oldSession.ID, message.CreateMessageParams{
754			Role: message.Assistant,
755			Parts: []message.ContentPart{
756				message.TextContent{Text: summary},
757				message.Finish{
758					Reason: message.FinishReasonEndTurn,
759					Time:   time.Now().Unix(),
760				},
761			},
762			Model:    a.summarizeProvider.Model().ID,
763			Provider: a.summarizeProviderID,
764		})
765		if err != nil {
766			event = AgentEvent{
767				Type:  AgentEventTypeError,
768				Error: fmt.Errorf("failed to create summary message: %w", err),
769				Done:  true,
770			}
771
772			a.Publish(pubsub.CreatedEvent, event)
773			return
774		}
775		oldSession.SummaryMessageID = msg.ID
776		oldSession.CompletionTokens = finalResponse.Usage.OutputTokens
777		oldSession.PromptTokens = 0
778		model := a.summarizeProvider.Model()
779		usage := finalResponse.Usage
780		cost := model.CostPer1MInCached/1e6*float64(usage.CacheCreationTokens) +
781			model.CostPer1MOutCached/1e6*float64(usage.CacheReadTokens) +
782			model.CostPer1MIn/1e6*float64(usage.InputTokens) +
783			model.CostPer1MOut/1e6*float64(usage.OutputTokens)
784		oldSession.Cost += cost
785		_, err = a.sessions.Save(summarizeCtx, oldSession)
786		if err != nil {
787			event = AgentEvent{
788				Type:  AgentEventTypeError,
789				Error: fmt.Errorf("failed to save session: %w", err),
790				Done:  true,
791			}
792			a.Publish(pubsub.CreatedEvent, event)
793		}
794
795		event = AgentEvent{
796			Type:      AgentEventTypeSummarize,
797			SessionID: oldSession.ID,
798			Progress:  "Summary complete",
799			Done:      true,
800		}
801		a.Publish(pubsub.CreatedEvent, event)
802		// Send final success event with the new session ID
803	}()
804
805	return nil
806}
807
808func (a *agent) CancelAll() {
809	a.activeRequests.Range(func(key, value any) bool {
810		a.Cancel(key.(string)) // key is sessionID
811		return true
812	})
813}
814
815func (a *agent) UpdateModel() error {
816	cfg := config.Get()
817
818	// Get current provider configuration
819	currentProviderCfg := config.GetAgentProvider(a.agentCfg.ID)
820	if currentProviderCfg.ID == "" {
821		return fmt.Errorf("provider for agent %s not found in config", a.agentCfg.Name)
822	}
823
824	// Check if provider has changed
825	if string(currentProviderCfg.ID) != a.providerID {
826		// Provider changed, need to recreate the main provider
827		model := config.GetAgentModel(a.agentCfg.ID)
828		if model.ID == "" {
829			return fmt.Errorf("model not found for agent %s", a.agentCfg.Name)
830		}
831
832		promptID := agentPromptMap[a.agentCfg.ID]
833		if promptID == "" {
834			promptID = prompt.PromptDefault
835		}
836
837		opts := []provider.ProviderClientOption{
838			provider.WithModel(a.agentCfg.Model),
839			provider.WithSystemMessage(prompt.GetPrompt(promptID, currentProviderCfg.ID)),
840		}
841
842		newProvider, err := provider.NewProvider(currentProviderCfg, opts...)
843		if err != nil {
844			return fmt.Errorf("failed to create new provider: %w", err)
845		}
846
847		// Update the provider and provider ID
848		a.provider = newProvider
849		a.providerID = string(currentProviderCfg.ID)
850	}
851
852	// Check if small model provider has changed (affects title and summarize providers)
853	smallModelCfg := cfg.Models.Small
854	var smallModelProviderCfg config.ProviderConfig
855
856	for _, p := range cfg.Providers {
857		if p.ID == smallModelCfg.Provider {
858			smallModelProviderCfg = p
859			break
860		}
861	}
862
863	if smallModelProviderCfg.ID == "" {
864		return fmt.Errorf("provider %s not found in config", smallModelCfg.Provider)
865	}
866
867	// Check if summarize provider has changed
868	if string(smallModelProviderCfg.ID) != a.summarizeProviderID {
869		var smallModel config.Model
870		for _, m := range smallModelProviderCfg.Models {
871			if m.ID == smallModelCfg.ModelID {
872				smallModel = m
873				break
874			}
875		}
876		if smallModel.ID == "" {
877			return fmt.Errorf("model %s not found in provider %s", smallModelCfg.ModelID, smallModelProviderCfg.ID)
878		}
879
880		// Recreate title provider
881		titleOpts := []provider.ProviderClientOption{
882			provider.WithModel(config.SmallModel),
883			provider.WithSystemMessage(prompt.GetPrompt(prompt.PromptTitle, smallModelProviderCfg.ID)),
884			// We want the title to be short, so we limit the max tokens
885			provider.WithMaxTokens(40),
886		}
887		newTitleProvider, err := provider.NewProvider(smallModelProviderCfg, titleOpts...)
888		if err != nil {
889			return fmt.Errorf("failed to create new title provider: %w", err)
890		}
891
892		// Recreate summarize provider
893		summarizeOpts := []provider.ProviderClientOption{
894			provider.WithModel(config.SmallModel),
895			provider.WithSystemMessage(prompt.GetPrompt(prompt.PromptSummarizer, smallModelProviderCfg.ID)),
896		}
897		newSummarizeProvider, err := provider.NewProvider(smallModelProviderCfg, summarizeOpts...)
898		if err != nil {
899			return fmt.Errorf("failed to create new summarize provider: %w", err)
900		}
901
902		// Update the providers and provider ID
903		a.titleProvider = newTitleProvider
904		a.summarizeProvider = newSummarizeProvider
905		a.summarizeProviderID = string(smallModelProviderCfg.ID)
906	}
907
908	return nil
909}