1package agent
  2
  3import (
  4	"context"
  5	"errors"
  6	"fmt"
  7	"slices"
  8	"strings"
  9	"sync"
 10	"time"
 11
 12	"github.com/charmbracelet/crush/internal/config"
 13	"github.com/charmbracelet/crush/internal/history"
 14	"github.com/charmbracelet/crush/internal/llm/prompt"
 15	"github.com/charmbracelet/crush/internal/llm/provider"
 16	"github.com/charmbracelet/crush/internal/llm/tools"
 17	"github.com/charmbracelet/crush/internal/logging"
 18	"github.com/charmbracelet/crush/internal/lsp"
 19	"github.com/charmbracelet/crush/internal/message"
 20	"github.com/charmbracelet/crush/internal/permission"
 21	"github.com/charmbracelet/crush/internal/pubsub"
 22	"github.com/charmbracelet/crush/internal/session"
 23)
 24
// Common errors returned by the agent.
//
// ErrRequestCancelled is reported (inside an AgentEventTypeError event) when a Run's
// generation is cancelled while the response is being streamed. ErrSessionBusy is
// returned by Run and Summarize when the session already has an active request.
var (
	ErrRequestCancelled = errors.New("request cancelled by user")
	ErrSessionBusy      = errors.New("session is currently processing another request")
)
 30
// AgentEventType identifies what an AgentEvent reports.
type AgentEventType string

// AgentEvent kinds: Error carries a failure in AgentEvent.Error, Response carries the
// final assistant message of a Run, and Summarize reports progress/completion of
// Summarize.
const (
	AgentEventTypeError     AgentEventType = "error"
	AgentEventTypeResponse  AgentEventType = "response"
	AgentEventTypeSummarize AgentEventType = "summarize"
)
 38
// AgentEvent is delivered on the channel returned by Run and published on the agent's
// broker (by Run and Summarize).
type AgentEvent struct {
	// Type determines which of the fields below are meaningful.
	Type    AgentEventType
	Message message.Message
	Error   error

	// Mainly used when summarizing: SessionID is set on the final summarize event and
	// Progress carries a human-readable status. Done marks a terminal event; it is
	// also set on the final response event of Run.
	SessionID string
	Progress  string
	Done      bool
}
 49
// Service is the public interface of an agent: it runs user requests for a session
// against the configured model and tools and publishes AgentEvents to subscribers.
type Service interface {
	pubsub.Suscriber[AgentEvent]
	// Model returns the model currently configured for the agent.
	Model() config.Model
	// Run starts processing content (plus optional attachments) for sessionID in the
	// background and returns a channel that delivers the outcome.
	Run(ctx context.Context, sessionID string, content string, attachments ...message.Attachment) (<-chan AgentEvent, error)
	// Cancel aborts the in-flight request and summarization, if any, of sessionID.
	Cancel(sessionID string)
	// CancelAll aborts every in-flight request and summarization.
	CancelAll()
	// IsSessionBusy reports whether sessionID has an in-flight request.
	IsSessionBusy(sessionID string) bool
	// IsBusy reports whether any work is in flight.
	IsBusy() bool
	// Summarize starts condensing the conversation of sessionID in the background;
	// progress is reported through published events.
	Summarize(ctx context.Context, sessionID string) error
	// UpdateModel rebuilds providers after the configuration changed.
	UpdateModel() error
}
 61
// agent is the Service implementation built by NewAgent.
type agent struct {
	*pubsub.Broker[AgentEvent]
	agentCfg config.Agent
	sessions session.Service
	messages message.Service

	// Main model: the tools offered to it, its provider and that provider's config ID.
	tools      []tools.BaseTool
	provider   provider.Provider
	providerID string

	// Providers built for the configured small model. summarizeProviderID is their
	// provider's config ID; UpdateModel compares it to detect a provider change.
	titleProvider       provider.Provider
	summarizeProvider   provider.Provider
	summarizeProviderID string

	// context.CancelFunc of in-flight work, keyed by session ID for Run and by
	// session ID + "-summarize" for Summarize.
	activeRequests sync.Map
}
 78
// agentPromptMap selects the system prompt for each agent ID. Agents that are not
// listed fall back to prompt.PromptDefault (see NewAgent and UpdateModel).
var agentPromptMap = map[config.AgentID]prompt.PromptID{
	config.AgentCoder: prompt.PromptCoder,
	config.AgentTask:  prompt.PromptTask,
}
 83
 84func NewAgent(
 85	agentCfg config.Agent,
 86	// These services are needed in the tools
 87	permissions permission.Service,
 88	sessions session.Service,
 89	messages message.Service,
 90	history history.Service,
 91	lspClients map[string]*lsp.Client,
 92) (Service, error) {
 93	ctx := context.Background()
 94	cfg := config.Get()
 95	otherTools := GetMcpTools(ctx, permissions)
 96	if len(lspClients) > 0 {
 97		otherTools = append(otherTools, tools.NewDiagnosticsTool(lspClients))
 98	}
 99
100	allTools := []tools.BaseTool{
101		tools.NewBashTool(permissions),
102		tools.NewEditTool(lspClients, permissions, history),
103		tools.NewFetchTool(permissions),
104		tools.NewGlobTool(),
105		tools.NewGrepTool(),
106		tools.NewLsTool(),
107		tools.NewSourcegraphTool(),
108		tools.NewViewTool(lspClients),
109		tools.NewWriteTool(lspClients, permissions, history),
110	}
111
112	if agentCfg.ID == config.AgentCoder {
113		taskAgentCfg := config.Get().Agents[config.AgentTask]
114		if taskAgentCfg.ID == "" {
115			return nil, fmt.Errorf("task agent not found in config")
116		}
117		taskAgent, err := NewAgent(taskAgentCfg, permissions, sessions, messages, history, lspClients)
118		if err != nil {
119			return nil, fmt.Errorf("failed to create task agent: %w", err)
120		}
121
122		allTools = append(
123			allTools,
124			NewAgentTool(
125				taskAgent,
126				sessions,
127				messages,
128			),
129		)
130	}
131
132	allTools = append(allTools, otherTools...)
133	providerCfg := config.GetAgentProvider(agentCfg.ID)
134	if providerCfg.ID == "" {
135		return nil, fmt.Errorf("provider for agent %s not found in config", agentCfg.Name)
136	}
137	model := config.GetAgentModel(agentCfg.ID)
138
139	if model.ID == "" {
140		return nil, fmt.Errorf("model not found for agent %s", agentCfg.Name)
141	}
142
143	promptID := agentPromptMap[agentCfg.ID]
144	if promptID == "" {
145		promptID = prompt.PromptDefault
146	}
147	opts := []provider.ProviderClientOption{
148		provider.WithModel(agentCfg.Model),
149		provider.WithSystemMessage(prompt.GetPrompt(promptID, providerCfg.ID)),
150	}
151	agentProvider, err := provider.NewProvider(providerCfg, opts...)
152	if err != nil {
153		return nil, err
154	}
155
156	smallModelCfg := cfg.Models.Small
157	var smallModel config.Model
158
159	var smallModelProviderCfg config.ProviderConfig
160	if smallModelCfg.Provider == providerCfg.ID {
161		smallModelProviderCfg = providerCfg
162	} else {
163		for _, p := range cfg.Providers {
164			if p.ID == smallModelCfg.Provider {
165				smallModelProviderCfg = p
166				break
167			}
168		}
169		if smallModelProviderCfg.ID == "" {
170			return nil, fmt.Errorf("provider %s not found in config", smallModelCfg.Provider)
171		}
172	}
173	for _, m := range smallModelProviderCfg.Models {
174		if m.ID == smallModelCfg.ModelID {
175			smallModel = m
176			break
177		}
178	}
179	if smallModel.ID == "" {
180		return nil, fmt.Errorf("model %s not found in provider %s", smallModelCfg.ModelID, smallModelProviderCfg.ID)
181	}
182
183	titleOpts := []provider.ProviderClientOption{
184		provider.WithModel(config.SmallModel),
185		provider.WithSystemMessage(prompt.GetPrompt(prompt.PromptTitle, smallModelProviderCfg.ID)),
186	}
187	titleProvider, err := provider.NewProvider(smallModelProviderCfg, titleOpts...)
188	if err != nil {
189		return nil, err
190	}
191	summarizeOpts := []provider.ProviderClientOption{
192		provider.WithModel(config.SmallModel),
193		provider.WithSystemMessage(prompt.GetPrompt(prompt.PromptSummarizer, smallModelProviderCfg.ID)),
194	}
195	summarizeProvider, err := provider.NewProvider(smallModelProviderCfg, summarizeOpts...)
196	if err != nil {
197		return nil, err
198	}
199
200	agentTools := []tools.BaseTool{}
201	if agentCfg.AllowedTools == nil {
202		agentTools = allTools
203	} else {
204		for _, tool := range allTools {
205			if slices.Contains(agentCfg.AllowedTools, tool.Name()) {
206				agentTools = append(agentTools, tool)
207			}
208		}
209	}
210
211	agent := &agent{
212		Broker:              pubsub.NewBroker[AgentEvent](),
213		agentCfg:            agentCfg,
214		provider:            agentProvider,
215		providerID:          string(providerCfg.ID),
216		messages:            messages,
217		sessions:            sessions,
218		tools:               agentTools,
219		titleProvider:       titleProvider,
220		summarizeProvider:   summarizeProvider,
221		summarizeProviderID: string(smallModelProviderCfg.ID),
222		activeRequests:      sync.Map{},
223	}
224
225	return agent, nil
226}
227
228func (a *agent) Model() config.Model {
229	return config.GetAgentModel(a.agentCfg.ID)
230}
231
232func (a *agent) Cancel(sessionID string) {
233	// Cancel regular requests
234	if cancelFunc, exists := a.activeRequests.LoadAndDelete(sessionID); exists {
235		if cancel, ok := cancelFunc.(context.CancelFunc); ok {
236			logging.InfoPersist(fmt.Sprintf("Request cancellation initiated for session: %s", sessionID))
237			cancel()
238		}
239	}
240
241	// Also check for summarize requests
242	if cancelFunc, exists := a.activeRequests.LoadAndDelete(sessionID + "-summarize"); exists {
243		if cancel, ok := cancelFunc.(context.CancelFunc); ok {
244			logging.InfoPersist(fmt.Sprintf("Summarize cancellation initiated for session: %s", sessionID))
245			cancel()
246		}
247	}
248}
249
250func (a *agent) IsBusy() bool {
251	busy := false
252	a.activeRequests.Range(func(key, value any) bool {
253		if cancelFunc, ok := value.(context.CancelFunc); ok {
254			if cancelFunc != nil {
255				busy = true
256				return false // Stop iterating
257			}
258		}
259		return true // Continue iterating
260	})
261	return busy
262}
263
264func (a *agent) IsSessionBusy(sessionID string) bool {
265	_, busy := a.activeRequests.Load(sessionID)
266	return busy
267}
268
269func (a *agent) generateTitle(ctx context.Context, sessionID string, content string) error {
270	if content == "" {
271		return nil
272	}
273	if a.titleProvider == nil {
274		return nil
275	}
276	session, err := a.sessions.Get(ctx, sessionID)
277	if err != nil {
278		return err
279	}
280	parts := []message.ContentPart{message.TextContent{
281		Text: fmt.Sprintf("Generate a concise title for the following content:\n\n%s", content),
282	}}
283
284	// Use streaming approach like summarization
285	response := a.titleProvider.StreamResponse(
286		ctx,
287		[]message.Message{
288			{
289				Role:  message.User,
290				Parts: parts,
291			},
292		},
293		make([]tools.BaseTool, 0),
294	)
295
296	var finalResponse *provider.ProviderResponse
297	for r := range response {
298		if r.Error != nil {
299			return r.Error
300		}
301		finalResponse = r.Response
302	}
303
304	if finalResponse == nil {
305		return fmt.Errorf("no response received from title provider")
306	}
307
308	title := strings.TrimSpace(strings.ReplaceAll(finalResponse.Content, "\n", " "))
309	if title == "" {
310		return nil
311	}
312
313	session.Title = title
314	_, err = a.sessions.Save(ctx, session)
315	return err
316}
317
318func (a *agent) err(err error) AgentEvent {
319	return AgentEvent{
320		Type:  AgentEventTypeError,
321		Error: err,
322	}
323}
324
325func (a *agent) Run(ctx context.Context, sessionID string, content string, attachments ...message.Attachment) (<-chan AgentEvent, error) {
326	if !a.Model().SupportsImages && attachments != nil {
327		attachments = nil
328	}
329	events := make(chan AgentEvent)
330	if a.IsSessionBusy(sessionID) {
331		return nil, ErrSessionBusy
332	}
333
334	genCtx, cancel := context.WithCancel(ctx)
335
336	a.activeRequests.Store(sessionID, cancel)
337	go func() {
338		logging.Debug("Request started", "sessionID", sessionID)
339		defer logging.RecoverPanic("agent.Run", func() {
340			events <- a.err(fmt.Errorf("panic while running the agent"))
341		})
342		var attachmentParts []message.ContentPart
343		for _, attachment := range attachments {
344			attachmentParts = append(attachmentParts, message.BinaryContent{Path: attachment.FilePath, MIMEType: attachment.MimeType, Data: attachment.Content})
345		}
346		result := a.processGeneration(genCtx, sessionID, content, attachmentParts)
347		if result.Error != nil && !errors.Is(result.Error, ErrRequestCancelled) && !errors.Is(result.Error, context.Canceled) {
348			logging.ErrorPersist(result.Error.Error())
349		}
350		logging.Debug("Request completed", "sessionID", sessionID)
351		a.activeRequests.Delete(sessionID)
352		cancel()
353		a.Publish(pubsub.CreatedEvent, result)
354		events <- result
355		close(events)
356	}()
357	return events, nil
358}
359
360func (a *agent) processGeneration(ctx context.Context, sessionID, content string, attachmentParts []message.ContentPart) AgentEvent {
361	// List existing messages; if none, start title generation asynchronously.
362	msgs, err := a.messages.List(ctx, sessionID)
363	if err != nil {
364		return a.err(fmt.Errorf("failed to list messages: %w", err))
365	}
366	if len(msgs) == 0 {
367		go func() {
368			defer logging.RecoverPanic("agent.Run", func() {
369				logging.ErrorPersist("panic while generating title")
370			})
371			titleErr := a.generateTitle(context.Background(), sessionID, content)
372			if titleErr != nil && !errors.Is(titleErr, context.Canceled) && !errors.Is(titleErr, context.DeadlineExceeded) {
373				logging.ErrorPersist(fmt.Sprintf("failed to generate title: %v", titleErr))
374			}
375		}()
376	}
377	session, err := a.sessions.Get(ctx, sessionID)
378	if err != nil {
379		return a.err(fmt.Errorf("failed to get session: %w", err))
380	}
381	if session.SummaryMessageID != "" {
382		summaryMsgInex := -1
383		for i, msg := range msgs {
384			if msg.ID == session.SummaryMessageID {
385				summaryMsgInex = i
386				break
387			}
388		}
389		if summaryMsgInex != -1 {
390			msgs = msgs[summaryMsgInex:]
391			msgs[0].Role = message.User
392		}
393	}
394
395	userMsg, err := a.createUserMessage(ctx, sessionID, content, attachmentParts)
396	if err != nil {
397		return a.err(fmt.Errorf("failed to create user message: %w", err))
398	}
399	// Append the new user message to the conversation history.
400	msgHistory := append(msgs, userMsg)
401
402	for {
403		// Check for cancellation before each iteration
404		select {
405		case <-ctx.Done():
406			return a.err(ctx.Err())
407		default:
408			// Continue processing
409		}
410		agentMessage, toolResults, err := a.streamAndHandleEvents(ctx, sessionID, msgHistory)
411		if err != nil {
412			if errors.Is(err, context.Canceled) {
413				agentMessage.AddFinish(message.FinishReasonCanceled)
414				a.messages.Update(context.Background(), agentMessage)
415				return a.err(ErrRequestCancelled)
416			}
417			return a.err(fmt.Errorf("failed to process events: %w", err))
418		}
419		logging.Info("Result", "message", agentMessage.FinishReason(), "toolResults", toolResults)
420		if (agentMessage.FinishReason() == message.FinishReasonToolUse) && toolResults != nil {
421			// We are not done, we need to respond with the tool response
422			msgHistory = append(msgHistory, agentMessage, *toolResults)
423			continue
424		}
425		return AgentEvent{
426			Type:    AgentEventTypeResponse,
427			Message: agentMessage,
428			Done:    true,
429		}
430	}
431}
432
433func (a *agent) createUserMessage(ctx context.Context, sessionID, content string, attachmentParts []message.ContentPart) (message.Message, error) {
434	parts := []message.ContentPart{message.TextContent{Text: content}}
435	parts = append(parts, attachmentParts...)
436	return a.messages.Create(ctx, sessionID, message.CreateMessageParams{
437		Role:  message.User,
438		Parts: parts,
439	})
440}
441
442func (a *agent) streamAndHandleEvents(ctx context.Context, sessionID string, msgHistory []message.Message) (message.Message, *message.Message, error) {
443	eventChan := a.provider.StreamResponse(ctx, msgHistory, a.tools)
444
445	assistantMsg, err := a.messages.Create(ctx, sessionID, message.CreateMessageParams{
446		Role:     message.Assistant,
447		Parts:    []message.ContentPart{},
448		Model:    a.Model().ID,
449		Provider: a.providerID,
450	})
451	if err != nil {
452		return assistantMsg, nil, fmt.Errorf("failed to create assistant message: %w", err)
453	}
454
455	// Add the session and message ID into the context if needed by tools.
456	ctx = context.WithValue(ctx, tools.MessageIDContextKey, assistantMsg.ID)
457	ctx = context.WithValue(ctx, tools.SessionIDContextKey, sessionID)
458
459	// Process each event in the stream.
460	for event := range eventChan {
461		if processErr := a.processEvent(ctx, sessionID, &assistantMsg, event); processErr != nil {
462			a.finishMessage(ctx, &assistantMsg, message.FinishReasonCanceled)
463			return assistantMsg, nil, processErr
464		}
465		if ctx.Err() != nil {
466			a.finishMessage(context.Background(), &assistantMsg, message.FinishReasonCanceled)
467			return assistantMsg, nil, ctx.Err()
468		}
469	}
470
471	toolResults := make([]message.ToolResult, len(assistantMsg.ToolCalls()))
472	toolCalls := assistantMsg.ToolCalls()
473	for i, toolCall := range toolCalls {
474		select {
475		case <-ctx.Done():
476			a.finishMessage(context.Background(), &assistantMsg, message.FinishReasonCanceled)
477			// Make all future tool calls cancelled
478			for j := i; j < len(toolCalls); j++ {
479				toolResults[j] = message.ToolResult{
480					ToolCallID: toolCalls[j].ID,
481					Content:    "Tool execution canceled by user",
482					IsError:    true,
483				}
484			}
485			goto out
486		default:
487			// Continue processing
488			var tool tools.BaseTool
489			for _, availableTools := range a.tools {
490				if availableTools.Info().Name == toolCall.Name {
491					tool = availableTools
492				}
493			}
494
495			// Tool not found
496			if tool == nil {
497				toolResults[i] = message.ToolResult{
498					ToolCallID: toolCall.ID,
499					Content:    fmt.Sprintf("Tool not found: %s", toolCall.Name),
500					IsError:    true,
501				}
502				continue
503			}
504			toolResult, toolErr := tool.Run(ctx, tools.ToolCall{
505				ID:    toolCall.ID,
506				Name:  toolCall.Name,
507				Input: toolCall.Input,
508			})
509			if toolErr != nil {
510				if errors.Is(toolErr, permission.ErrorPermissionDenied) {
511					toolResults[i] = message.ToolResult{
512						ToolCallID: toolCall.ID,
513						Content:    "Permission denied",
514						IsError:    true,
515					}
516					for j := i + 1; j < len(toolCalls); j++ {
517						toolResults[j] = message.ToolResult{
518							ToolCallID: toolCalls[j].ID,
519							Content:    "Tool execution canceled by user",
520							IsError:    true,
521						}
522					}
523					a.finishMessage(ctx, &assistantMsg, message.FinishReasonPermissionDenied)
524					break
525				}
526			}
527			toolResults[i] = message.ToolResult{
528				ToolCallID: toolCall.ID,
529				Content:    toolResult.Content,
530				Metadata:   toolResult.Metadata,
531				IsError:    toolResult.IsError,
532			}
533		}
534	}
535out:
536	if len(toolResults) == 0 {
537		return assistantMsg, nil, nil
538	}
539	parts := make([]message.ContentPart, 0)
540	for _, tr := range toolResults {
541		parts = append(parts, tr)
542	}
543	msg, err := a.messages.Create(context.Background(), assistantMsg.SessionID, message.CreateMessageParams{
544		Role:     message.Tool,
545		Parts:    parts,
546		Provider: a.providerID,
547	})
548	if err != nil {
549		return assistantMsg, nil, fmt.Errorf("failed to create cancelled tool message: %w", err)
550	}
551
552	return assistantMsg, &msg, err
553}
554
555func (a *agent) finishMessage(ctx context.Context, msg *message.Message, finishReson message.FinishReason) {
556	msg.AddFinish(finishReson)
557	_ = a.messages.Update(ctx, *msg)
558}
559
560func (a *agent) processEvent(ctx context.Context, sessionID string, assistantMsg *message.Message, event provider.ProviderEvent) error {
561	select {
562	case <-ctx.Done():
563		return ctx.Err()
564	default:
565		// Continue processing.
566	}
567
568	switch event.Type {
569	case provider.EventThinkingDelta:
570		assistantMsg.AppendReasoningContent(event.Content)
571		return a.messages.Update(ctx, *assistantMsg)
572	case provider.EventContentDelta:
573		assistantMsg.AppendContent(event.Content)
574		return a.messages.Update(ctx, *assistantMsg)
575	case provider.EventToolUseStart:
576		logging.Info("Tool call started", "toolCall", event.ToolCall)
577		assistantMsg.AddToolCall(*event.ToolCall)
578		return a.messages.Update(ctx, *assistantMsg)
579	case provider.EventToolUseDelta:
580		assistantMsg.AppendToolCallInput(event.ToolCall.ID, event.ToolCall.Input)
581		return a.messages.Update(ctx, *assistantMsg)
582	case provider.EventToolUseStop:
583		logging.Info("Finished tool call", "toolCall", event.ToolCall)
584		assistantMsg.FinishToolCall(event.ToolCall.ID)
585		return a.messages.Update(ctx, *assistantMsg)
586	case provider.EventError:
587		if errors.Is(event.Error, context.Canceled) {
588			logging.InfoPersist(fmt.Sprintf("Event processing canceled for session: %s", sessionID))
589			return context.Canceled
590		}
591		logging.ErrorPersist(event.Error.Error())
592		return event.Error
593	case provider.EventComplete:
594		assistantMsg.SetToolCalls(event.Response.ToolCalls)
595		assistantMsg.AddFinish(event.Response.FinishReason)
596		if err := a.messages.Update(ctx, *assistantMsg); err != nil {
597			return fmt.Errorf("failed to update message: %w", err)
598		}
599		return a.TrackUsage(ctx, sessionID, a.Model(), event.Response.Usage)
600	}
601
602	return nil
603}
604
605func (a *agent) TrackUsage(ctx context.Context, sessionID string, model config.Model, usage provider.TokenUsage) error {
606	sess, err := a.sessions.Get(ctx, sessionID)
607	if err != nil {
608		return fmt.Errorf("failed to get session: %w", err)
609	}
610
611	cost := model.CostPer1MInCached/1e6*float64(usage.CacheCreationTokens) +
612		model.CostPer1MOutCached/1e6*float64(usage.CacheReadTokens) +
613		model.CostPer1MIn/1e6*float64(usage.InputTokens) +
614		model.CostPer1MOut/1e6*float64(usage.OutputTokens)
615
616	sess.Cost += cost
617	sess.CompletionTokens = usage.OutputTokens + usage.CacheReadTokens
618	sess.PromptTokens = usage.InputTokens + usage.CacheCreationTokens
619
620	_, err = a.sessions.Save(ctx, sess)
621	if err != nil {
622		return fmt.Errorf("failed to save session: %w", err)
623	}
624	return nil
625}
626
627func (a *agent) Summarize(ctx context.Context, sessionID string) error {
628	if a.summarizeProvider == nil {
629		return fmt.Errorf("summarize provider not available")
630	}
631
632	// Check if session is busy
633	if a.IsSessionBusy(sessionID) {
634		return ErrSessionBusy
635	}
636
637	// Create a new context with cancellation
638	summarizeCtx, cancel := context.WithCancel(ctx)
639
640	// Store the cancel function in activeRequests to allow cancellation
641	a.activeRequests.Store(sessionID+"-summarize", cancel)
642
643	go func() {
644		defer a.activeRequests.Delete(sessionID + "-summarize")
645		defer cancel()
646		event := AgentEvent{
647			Type:     AgentEventTypeSummarize,
648			Progress: "Starting summarization...",
649		}
650
651		a.Publish(pubsub.CreatedEvent, event)
652		// Get all messages from the session
653		msgs, err := a.messages.List(summarizeCtx, sessionID)
654		if err != nil {
655			event = AgentEvent{
656				Type:  AgentEventTypeError,
657				Error: fmt.Errorf("failed to list messages: %w", err),
658				Done:  true,
659			}
660			a.Publish(pubsub.CreatedEvent, event)
661			return
662		}
663
664		if len(msgs) == 0 {
665			event = AgentEvent{
666				Type:  AgentEventTypeError,
667				Error: fmt.Errorf("no messages to summarize"),
668				Done:  true,
669			}
670			a.Publish(pubsub.CreatedEvent, event)
671			return
672		}
673
674		event = AgentEvent{
675			Type:     AgentEventTypeSummarize,
676			Progress: "Analyzing conversation...",
677		}
678		a.Publish(pubsub.CreatedEvent, event)
679
680		// Add a system message to guide the summarization
681		summarizePrompt := "Provide a detailed but concise summary of our conversation above. Focus on information that would be helpful for continuing the conversation, including what we did, what we're doing, which files we're working on, and what we're going to do next."
682
683		// Create a new message with the summarize prompt
684		promptMsg := message.Message{
685			Role:  message.User,
686			Parts: []message.ContentPart{message.TextContent{Text: summarizePrompt}},
687		}
688
689		// Append the prompt to the messages
690		msgsWithPrompt := append(msgs, promptMsg)
691
692		event = AgentEvent{
693			Type:     AgentEventTypeSummarize,
694			Progress: "Generating summary...",
695		}
696
697		a.Publish(pubsub.CreatedEvent, event)
698
699		// Send the messages to the summarize provider
700		response := a.summarizeProvider.StreamResponse(
701			summarizeCtx,
702			msgsWithPrompt,
703			make([]tools.BaseTool, 0),
704		)
705		var finalResponse *provider.ProviderResponse
706		for r := range response {
707			if r.Error != nil {
708				event = AgentEvent{
709					Type:  AgentEventTypeError,
710					Error: fmt.Errorf("failed to summarize: %w", err),
711					Done:  true,
712				}
713				a.Publish(pubsub.CreatedEvent, event)
714				return
715			}
716			finalResponse = r.Response
717		}
718
719		summary := strings.TrimSpace(finalResponse.Content)
720		if summary == "" {
721			event = AgentEvent{
722				Type:  AgentEventTypeError,
723				Error: fmt.Errorf("empty summary returned"),
724				Done:  true,
725			}
726			a.Publish(pubsub.CreatedEvent, event)
727			return
728		}
729		event = AgentEvent{
730			Type:     AgentEventTypeSummarize,
731			Progress: "Creating new session...",
732		}
733
734		a.Publish(pubsub.CreatedEvent, event)
735		oldSession, err := a.sessions.Get(summarizeCtx, sessionID)
736		if err != nil {
737			event = AgentEvent{
738				Type:  AgentEventTypeError,
739				Error: fmt.Errorf("failed to get session: %w", err),
740				Done:  true,
741			}
742
743			a.Publish(pubsub.CreatedEvent, event)
744			return
745		}
746		// Create a message in the new session with the summary
747		msg, err := a.messages.Create(summarizeCtx, oldSession.ID, message.CreateMessageParams{
748			Role: message.Assistant,
749			Parts: []message.ContentPart{
750				message.TextContent{Text: summary},
751				message.Finish{
752					Reason: message.FinishReasonEndTurn,
753					Time:   time.Now().Unix(),
754				},
755			},
756			Model:    a.summarizeProvider.Model().ID,
757			Provider: a.summarizeProviderID,
758		})
759		if err != nil {
760			event = AgentEvent{
761				Type:  AgentEventTypeError,
762				Error: fmt.Errorf("failed to create summary message: %w", err),
763				Done:  true,
764			}
765
766			a.Publish(pubsub.CreatedEvent, event)
767			return
768		}
769		oldSession.SummaryMessageID = msg.ID
770		oldSession.CompletionTokens = finalResponse.Usage.OutputTokens
771		oldSession.PromptTokens = 0
772		model := a.summarizeProvider.Model()
773		usage := finalResponse.Usage
774		cost := model.CostPer1MInCached/1e6*float64(usage.CacheCreationTokens) +
775			model.CostPer1MOutCached/1e6*float64(usage.CacheReadTokens) +
776			model.CostPer1MIn/1e6*float64(usage.InputTokens) +
777			model.CostPer1MOut/1e6*float64(usage.OutputTokens)
778		oldSession.Cost += cost
779		_, err = a.sessions.Save(summarizeCtx, oldSession)
780		if err != nil {
781			event = AgentEvent{
782				Type:  AgentEventTypeError,
783				Error: fmt.Errorf("failed to save session: %w", err),
784				Done:  true,
785			}
786			a.Publish(pubsub.CreatedEvent, event)
787		}
788
789		event = AgentEvent{
790			Type:      AgentEventTypeSummarize,
791			SessionID: oldSession.ID,
792			Progress:  "Summary complete",
793			Done:      true,
794		}
795		a.Publish(pubsub.CreatedEvent, event)
796		// Send final success event with the new session ID
797	}()
798
799	return nil
800}
801
802func (a *agent) CancelAll() {
803	a.activeRequests.Range(func(key, value any) bool {
804		a.Cancel(key.(string)) // key is sessionID
805		return true
806	})
807}
808
809func (a *agent) UpdateModel() error {
810	cfg := config.Get()
811
812	// Get current provider configuration
813	currentProviderCfg := config.GetAgentProvider(a.agentCfg.ID)
814	if currentProviderCfg.ID == "" {
815		return fmt.Errorf("provider for agent %s not found in config", a.agentCfg.Name)
816	}
817
818	// Check if provider has changed
819	if string(currentProviderCfg.ID) != a.providerID {
820		// Provider changed, need to recreate the main provider
821		model := config.GetAgentModel(a.agentCfg.ID)
822		if model.ID == "" {
823			return fmt.Errorf("model not found for agent %s", a.agentCfg.Name)
824		}
825
826		promptID := agentPromptMap[a.agentCfg.ID]
827		if promptID == "" {
828			promptID = prompt.PromptDefault
829		}
830
831		opts := []provider.ProviderClientOption{
832			provider.WithModel(a.agentCfg.Model),
833			provider.WithSystemMessage(prompt.GetPrompt(promptID, currentProviderCfg.ID)),
834		}
835
836		newProvider, err := provider.NewProvider(currentProviderCfg, opts...)
837		if err != nil {
838			return fmt.Errorf("failed to create new provider: %w", err)
839		}
840
841		// Update the provider and provider ID
842		a.provider = newProvider
843		a.providerID = string(currentProviderCfg.ID)
844	}
845
846	// Check if small model provider has changed (affects title and summarize providers)
847	smallModelCfg := cfg.Models.Small
848	var smallModelProviderCfg config.ProviderConfig
849
850	for _, p := range cfg.Providers {
851		if p.ID == smallModelCfg.Provider {
852			smallModelProviderCfg = p
853			break
854		}
855	}
856
857	if smallModelProviderCfg.ID == "" {
858		return fmt.Errorf("provider %s not found in config", smallModelCfg.Provider)
859	}
860
861	// Check if summarize provider has changed
862	if string(smallModelProviderCfg.ID) != a.summarizeProviderID {
863		var smallModel config.Model
864		for _, m := range smallModelProviderCfg.Models {
865			if m.ID == smallModelCfg.ModelID {
866				smallModel = m
867				break
868			}
869		}
870		if smallModel.ID == "" {
871			return fmt.Errorf("model %s not found in provider %s", smallModelCfg.ModelID, smallModelProviderCfg.ID)
872		}
873
874		// Recreate title provider
875		titleOpts := []provider.ProviderClientOption{
876			provider.WithModel(config.SmallModel),
877			provider.WithSystemMessage(prompt.GetPrompt(prompt.PromptTitle, smallModelProviderCfg.ID)),
878			// We want the title to be short, so we limit the max tokens
879			provider.WithMaxTokens(40),
880		}
881		newTitleProvider, err := provider.NewProvider(smallModelProviderCfg, titleOpts...)
882		if err != nil {
883			return fmt.Errorf("failed to create new title provider: %w", err)
884		}
885
886		// Recreate summarize provider
887		summarizeOpts := []provider.ProviderClientOption{
888			provider.WithModel(config.SmallModel),
889			provider.WithSystemMessage(prompt.GetPrompt(prompt.PromptSummarizer, smallModelProviderCfg.ID)),
890		}
891		newSummarizeProvider, err := provider.NewProvider(smallModelProviderCfg, summarizeOpts...)
892		if err != nil {
893			return fmt.Errorf("failed to create new summarize provider: %w", err)
894		}
895
896		// Update the providers and provider ID
897		a.titleProvider = newTitleProvider
898		a.summarizeProvider = newSummarizeProvider
899		a.summarizeProviderID = string(smallModelProviderCfg.ID)
900	}
901
902	return nil
903}