agent.go

  1package agent
  2
  3import (
  4	"context"
  5	"errors"
  6	"fmt"
  7	"slices"
  8	"strings"
  9	"sync"
 10	"time"
 11
 12	"github.com/charmbracelet/crush/internal/config"
 13	"github.com/charmbracelet/crush/internal/history"
 14	"github.com/charmbracelet/crush/internal/llm/prompt"
 15	"github.com/charmbracelet/crush/internal/llm/provider"
 16	"github.com/charmbracelet/crush/internal/llm/tools"
 17	"github.com/charmbracelet/crush/internal/logging"
 18	"github.com/charmbracelet/crush/internal/lsp"
 19	"github.com/charmbracelet/crush/internal/message"
 20	"github.com/charmbracelet/crush/internal/permission"
 21	"github.com/charmbracelet/crush/internal/pubsub"
 22	"github.com/charmbracelet/crush/internal/session"
 23)
 24
 25// Common errors
 26var (
 27	ErrRequestCancelled = errors.New("request cancelled by user")
 28	ErrSessionBusy      = errors.New("session is currently processing another request")
 29)
 30
 31type AgentEventType string
 32
 33const (
 34	AgentEventTypeError     AgentEventType = "error"
 35	AgentEventTypeResponse  AgentEventType = "response"
 36	AgentEventTypeSummarize AgentEventType = "summarize"
 37)
 38
 39type AgentEvent struct {
 40	Type    AgentEventType
 41	Message message.Message
 42	Error   error
 43
 44	// When summarizing
 45	SessionID string
 46	Progress  string
 47	Done      bool
 48}
 49
 50type Service interface {
 51	pubsub.Suscriber[AgentEvent]
 52	Model() config.Model
 53	Run(ctx context.Context, sessionID string, content string, attachments ...message.Attachment) (<-chan AgentEvent, error)
 54	Cancel(sessionID string)
 55	CancelAll()
 56	IsSessionBusy(sessionID string) bool
 57	IsBusy() bool
 58	Summarize(ctx context.Context, sessionID string) error
 59	UpdateModel() error
 60}
 61
 62type agent struct {
 63	*pubsub.Broker[AgentEvent]
 64	agentCfg config.Agent
 65	sessions session.Service
 66	messages message.Service
 67
 68	tools      []tools.BaseTool
 69	provider   provider.Provider
 70	providerID string
 71
 72	titleProvider       provider.Provider
 73	summarizeProvider   provider.Provider
 74	summarizeProviderID string
 75
 76	activeRequests sync.Map
 77}
 78
 79var agentPromptMap = map[config.AgentID]prompt.PromptID{
 80	config.AgentCoder: prompt.PromptCoder,
 81	config.AgentTask:  prompt.PromptTask,
 82}
 83
 84func NewAgent(
 85	agentCfg config.Agent,
 86	// These services are needed in the tools
 87	permissions permission.Service,
 88	sessions session.Service,
 89	messages message.Service,
 90	history history.Service,
 91	lspClients map[string]*lsp.Client,
 92) (Service, error) {
 93	ctx := context.Background()
 94	cfg := config.Get()
 95	otherTools := GetMcpTools(ctx, permissions)
 96	if len(lspClients) > 0 {
 97		otherTools = append(otherTools, tools.NewDiagnosticsTool(lspClients))
 98	}
 99
100	allTools := []tools.BaseTool{
101		tools.NewBashTool(permissions),
102		tools.NewEditTool(lspClients, permissions, history),
103		tools.NewFetchTool(permissions),
104		tools.NewGlobTool(),
105		tools.NewGrepTool(),
106		tools.NewLsTool(),
107		tools.NewSourcegraphTool(),
108		tools.NewViewTool(lspClients),
109		tools.NewWriteTool(lspClients, permissions, history),
110	}
111
112	if agentCfg.ID == config.AgentCoder {
113		taskAgentCfg := config.Get().Agents[config.AgentTask]
114		if taskAgentCfg.ID == "" {
115			return nil, fmt.Errorf("task agent not found in config")
116		}
117		taskAgent, err := NewAgent(taskAgentCfg, permissions, sessions, messages, history, lspClients)
118		if err != nil {
119			return nil, fmt.Errorf("failed to create task agent: %w", err)
120		}
121
122		allTools = append(
123			allTools,
124			NewAgentTool(
125				taskAgent,
126				sessions,
127				messages,
128			),
129		)
130	}
131
132	allTools = append(allTools, otherTools...)
133	providerCfg := config.GetAgentProvider(agentCfg.ID)
134	if providerCfg.ID == "" {
135		return nil, fmt.Errorf("provider for agent %s not found in config", agentCfg.Name)
136	}
137	model := config.GetAgentModel(agentCfg.ID)
138
139	if model.ID == "" {
140		return nil, fmt.Errorf("model not found for agent %s", agentCfg.Name)
141	}
142
143	promptID := agentPromptMap[agentCfg.ID]
144	if promptID == "" {
145		promptID = prompt.PromptDefault
146	}
147	opts := []provider.ProviderClientOption{
148		provider.WithModel(agentCfg.Model),
149		provider.WithSystemMessage(prompt.GetPrompt(promptID, providerCfg.ID)),
150		provider.WithMaxTokens(model.DefaultMaxTokens),
151	}
152	agentProvider, err := provider.NewProviderV2(providerCfg, opts...)
153	if err != nil {
154		return nil, err
155	}
156
157	smallModelCfg := cfg.Models.Small
158	var smallModel config.Model
159
160	var smallModelProviderCfg config.ProviderConfig
161	if smallModelCfg.Provider == providerCfg.ID {
162		smallModelProviderCfg = providerCfg
163	} else {
164		for _, p := range cfg.Providers {
165			if p.ID == smallModelCfg.Provider {
166				smallModelProviderCfg = p
167				break
168			}
169		}
170		if smallModelProviderCfg.ID == "" {
171			return nil, fmt.Errorf("provider %s not found in config", smallModelCfg.Provider)
172		}
173	}
174	for _, m := range smallModelProviderCfg.Models {
175		if m.ID == smallModelCfg.ModelID {
176			smallModel = m
177			break
178		}
179	}
180	if smallModel.ID == "" {
181		return nil, fmt.Errorf("model %s not found in provider %s", smallModelCfg.ModelID, smallModelProviderCfg.ID)
182	}
183
184	titleOpts := []provider.ProviderClientOption{
185		provider.WithModel(config.SmallModel),
186		provider.WithSystemMessage(prompt.GetPrompt(prompt.PromptTitle, smallModelProviderCfg.ID)),
187		provider.WithMaxTokens(40),
188	}
189	titleProvider, err := provider.NewProviderV2(smallModelProviderCfg, titleOpts...)
190	if err != nil {
191		return nil, err
192	}
193	summarizeOpts := []provider.ProviderClientOption{
194		provider.WithModel(config.SmallModel),
195		provider.WithSystemMessage(prompt.GetPrompt(prompt.PromptSummarizer, smallModelProviderCfg.ID)),
196		provider.WithMaxTokens(smallModel.DefaultMaxTokens),
197	}
198	summarizeProvider, err := provider.NewProviderV2(smallModelProviderCfg, summarizeOpts...)
199	if err != nil {
200		return nil, err
201	}
202
203	agentTools := []tools.BaseTool{}
204	if agentCfg.AllowedTools == nil {
205		agentTools = allTools
206	} else {
207		for _, tool := range allTools {
208			if slices.Contains(agentCfg.AllowedTools, tool.Name()) {
209				agentTools = append(agentTools, tool)
210			}
211		}
212	}
213
214	agent := &agent{
215		Broker:              pubsub.NewBroker[AgentEvent](),
216		agentCfg:            agentCfg,
217		provider:            agentProvider,
218		providerID:          string(providerCfg.ID),
219		messages:            messages,
220		sessions:            sessions,
221		tools:               agentTools,
222		titleProvider:       titleProvider,
223		summarizeProvider:   summarizeProvider,
224		summarizeProviderID: string(smallModelProviderCfg.ID),
225		activeRequests:      sync.Map{},
226	}
227
228	return agent, nil
229}
230
231func (a *agent) Model() config.Model {
232	return config.GetAgentModel(a.agentCfg.ID)
233}
234
235func (a *agent) Cancel(sessionID string) {
236	// Cancel regular requests
237	if cancelFunc, exists := a.activeRequests.LoadAndDelete(sessionID); exists {
238		if cancel, ok := cancelFunc.(context.CancelFunc); ok {
239			logging.InfoPersist(fmt.Sprintf("Request cancellation initiated for session: %s", sessionID))
240			cancel()
241		}
242	}
243
244	// Also check for summarize requests
245	if cancelFunc, exists := a.activeRequests.LoadAndDelete(sessionID + "-summarize"); exists {
246		if cancel, ok := cancelFunc.(context.CancelFunc); ok {
247			logging.InfoPersist(fmt.Sprintf("Summarize cancellation initiated for session: %s", sessionID))
248			cancel()
249		}
250	}
251}
252
253func (a *agent) IsBusy() bool {
254	busy := false
255	a.activeRequests.Range(func(key, value any) bool {
256		if cancelFunc, ok := value.(context.CancelFunc); ok {
257			if cancelFunc != nil {
258				busy = true
259				return false // Stop iterating
260			}
261		}
262		return true // Continue iterating
263	})
264	return busy
265}
266
267func (a *agent) IsSessionBusy(sessionID string) bool {
268	_, busy := a.activeRequests.Load(sessionID)
269	return busy
270}
271
272func (a *agent) generateTitle(ctx context.Context, sessionID string, content string) error {
273	if content == "" {
274		return nil
275	}
276	if a.titleProvider == nil {
277		return nil
278	}
279	session, err := a.sessions.Get(ctx, sessionID)
280	if err != nil {
281		return err
282	}
283	parts := []message.ContentPart{message.TextContent{Text: content}}
284
285	// Use streaming approach like summarization
286	response := a.titleProvider.StreamResponse(
287		ctx,
288		[]message.Message{
289			{
290				Role:  message.User,
291				Parts: parts,
292			},
293		},
294		make([]tools.BaseTool, 0),
295	)
296
297	var finalResponse *provider.ProviderResponse
298	for r := range response {
299		if r.Error != nil {
300			return r.Error
301		}
302		finalResponse = r.Response
303	}
304
305	if finalResponse == nil {
306		return fmt.Errorf("no response received from title provider")
307	}
308
309	title := strings.TrimSpace(strings.ReplaceAll(finalResponse.Content, "\n", " "))
310	if title == "" {
311		return nil
312	}
313
314	session.Title = title
315	_, err = a.sessions.Save(ctx, session)
316	return err
317}
318
319func (a *agent) err(err error) AgentEvent {
320	return AgentEvent{
321		Type:  AgentEventTypeError,
322		Error: err,
323	}
324}
325
326func (a *agent) Run(ctx context.Context, sessionID string, content string, attachments ...message.Attachment) (<-chan AgentEvent, error) {
327	if !a.Model().SupportsImages && attachments != nil {
328		attachments = nil
329	}
330	events := make(chan AgentEvent)
331	if a.IsSessionBusy(sessionID) {
332		return nil, ErrSessionBusy
333	}
334
335	genCtx, cancel := context.WithCancel(ctx)
336
337	a.activeRequests.Store(sessionID, cancel)
338	go func() {
339		logging.Debug("Request started", "sessionID", sessionID)
340		defer logging.RecoverPanic("agent.Run", func() {
341			events <- a.err(fmt.Errorf("panic while running the agent"))
342		})
343		var attachmentParts []message.ContentPart
344		for _, attachment := range attachments {
345			attachmentParts = append(attachmentParts, message.BinaryContent{Path: attachment.FilePath, MIMEType: attachment.MimeType, Data: attachment.Content})
346		}
347		result := a.processGeneration(genCtx, sessionID, content, attachmentParts)
348		if result.Error != nil && !errors.Is(result.Error, ErrRequestCancelled) && !errors.Is(result.Error, context.Canceled) {
349			logging.ErrorPersist(result.Error.Error())
350		}
351		logging.Debug("Request completed", "sessionID", sessionID)
352		a.activeRequests.Delete(sessionID)
353		cancel()
354		a.Publish(pubsub.CreatedEvent, result)
355		events <- result
356		close(events)
357	}()
358	return events, nil
359}
360
361func (a *agent) processGeneration(ctx context.Context, sessionID, content string, attachmentParts []message.ContentPart) AgentEvent {
362	// List existing messages; if none, start title generation asynchronously.
363	msgs, err := a.messages.List(ctx, sessionID)
364	if err != nil {
365		return a.err(fmt.Errorf("failed to list messages: %w", err))
366	}
367	if len(msgs) == 0 {
368		go func() {
369			defer logging.RecoverPanic("agent.Run", func() {
370				logging.ErrorPersist("panic while generating title")
371			})
372			titleErr := a.generateTitle(context.Background(), sessionID, content)
373			if titleErr != nil && !errors.Is(titleErr, context.Canceled) && !errors.Is(titleErr, context.DeadlineExceeded) {
374				logging.ErrorPersist(fmt.Sprintf("failed to generate title: %v", titleErr))
375			}
376		}()
377	}
378	session, err := a.sessions.Get(ctx, sessionID)
379	if err != nil {
380		return a.err(fmt.Errorf("failed to get session: %w", err))
381	}
382	if session.SummaryMessageID != "" {
383		summaryMsgInex := -1
384		for i, msg := range msgs {
385			if msg.ID == session.SummaryMessageID {
386				summaryMsgInex = i
387				break
388			}
389		}
390		if summaryMsgInex != -1 {
391			msgs = msgs[summaryMsgInex:]
392			msgs[0].Role = message.User
393		}
394	}
395
396	userMsg, err := a.createUserMessage(ctx, sessionID, content, attachmentParts)
397	if err != nil {
398		return a.err(fmt.Errorf("failed to create user message: %w", err))
399	}
400	// Append the new user message to the conversation history.
401	msgHistory := append(msgs, userMsg)
402
403	for {
404		// Check for cancellation before each iteration
405		select {
406		case <-ctx.Done():
407			return a.err(ctx.Err())
408		default:
409			// Continue processing
410		}
411		agentMessage, toolResults, err := a.streamAndHandleEvents(ctx, sessionID, msgHistory)
412		if err != nil {
413			if errors.Is(err, context.Canceled) {
414				agentMessage.AddFinish(message.FinishReasonCanceled)
415				a.messages.Update(context.Background(), agentMessage)
416				return a.err(ErrRequestCancelled)
417			}
418			return a.err(fmt.Errorf("failed to process events: %w", err))
419		}
420		logging.Info("Result", "message", agentMessage.FinishReason(), "toolResults", toolResults)
421		if (agentMessage.FinishReason() == message.FinishReasonToolUse) && toolResults != nil {
422			// We are not done, we need to respond with the tool response
423			msgHistory = append(msgHistory, agentMessage, *toolResults)
424			continue
425		}
426		return AgentEvent{
427			Type:    AgentEventTypeResponse,
428			Message: agentMessage,
429			Done:    true,
430		}
431	}
432}
433
434func (a *agent) createUserMessage(ctx context.Context, sessionID, content string, attachmentParts []message.ContentPart) (message.Message, error) {
435	parts := []message.ContentPart{message.TextContent{Text: content}}
436	parts = append(parts, attachmentParts...)
437	return a.messages.Create(ctx, sessionID, message.CreateMessageParams{
438		Role:  message.User,
439		Parts: parts,
440	})
441}
442
443func (a *agent) streamAndHandleEvents(ctx context.Context, sessionID string, msgHistory []message.Message) (message.Message, *message.Message, error) {
444	eventChan := a.provider.StreamResponse(ctx, msgHistory, a.tools)
445
446	assistantMsg, err := a.messages.Create(ctx, sessionID, message.CreateMessageParams{
447		Role:     message.Assistant,
448		Parts:    []message.ContentPart{},
449		Model:    a.Model().ID,
450		Provider: a.providerID,
451	})
452	if err != nil {
453		return assistantMsg, nil, fmt.Errorf("failed to create assistant message: %w", err)
454	}
455
456	// Add the session and message ID into the context if needed by tools.
457	ctx = context.WithValue(ctx, tools.MessageIDContextKey, assistantMsg.ID)
458	ctx = context.WithValue(ctx, tools.SessionIDContextKey, sessionID)
459
460	// Process each event in the stream.
461	for event := range eventChan {
462		if processErr := a.processEvent(ctx, sessionID, &assistantMsg, event); processErr != nil {
463			a.finishMessage(ctx, &assistantMsg, message.FinishReasonCanceled)
464			return assistantMsg, nil, processErr
465		}
466		if ctx.Err() != nil {
467			a.finishMessage(context.Background(), &assistantMsg, message.FinishReasonCanceled)
468			return assistantMsg, nil, ctx.Err()
469		}
470	}
471
472	toolResults := make([]message.ToolResult, len(assistantMsg.ToolCalls()))
473	toolCalls := assistantMsg.ToolCalls()
474	for i, toolCall := range toolCalls {
475		select {
476		case <-ctx.Done():
477			a.finishMessage(context.Background(), &assistantMsg, message.FinishReasonCanceled)
478			// Make all future tool calls cancelled
479			for j := i; j < len(toolCalls); j++ {
480				toolResults[j] = message.ToolResult{
481					ToolCallID: toolCalls[j].ID,
482					Content:    "Tool execution canceled by user",
483					IsError:    true,
484				}
485			}
486			goto out
487		default:
488			// Continue processing
489			var tool tools.BaseTool
490			for _, availableTools := range a.tools {
491				if availableTools.Info().Name == toolCall.Name {
492					tool = availableTools
493				}
494			}
495
496			// Tool not found
497			if tool == nil {
498				toolResults[i] = message.ToolResult{
499					ToolCallID: toolCall.ID,
500					Content:    fmt.Sprintf("Tool not found: %s", toolCall.Name),
501					IsError:    true,
502				}
503				continue
504			}
505			toolResult, toolErr := tool.Run(ctx, tools.ToolCall{
506				ID:    toolCall.ID,
507				Name:  toolCall.Name,
508				Input: toolCall.Input,
509			})
510			if toolErr != nil {
511				if errors.Is(toolErr, permission.ErrorPermissionDenied) {
512					toolResults[i] = message.ToolResult{
513						ToolCallID: toolCall.ID,
514						Content:    "Permission denied",
515						IsError:    true,
516					}
517					for j := i + 1; j < len(toolCalls); j++ {
518						toolResults[j] = message.ToolResult{
519							ToolCallID: toolCalls[j].ID,
520							Content:    "Tool execution canceled by user",
521							IsError:    true,
522						}
523					}
524					a.finishMessage(ctx, &assistantMsg, message.FinishReasonPermissionDenied)
525					break
526				}
527			}
528			toolResults[i] = message.ToolResult{
529				ToolCallID: toolCall.ID,
530				Content:    toolResult.Content,
531				Metadata:   toolResult.Metadata,
532				IsError:    toolResult.IsError,
533			}
534		}
535	}
536out:
537	if len(toolResults) == 0 {
538		return assistantMsg, nil, nil
539	}
540	parts := make([]message.ContentPart, 0)
541	for _, tr := range toolResults {
542		parts = append(parts, tr)
543	}
544	msg, err := a.messages.Create(context.Background(), assistantMsg.SessionID, message.CreateMessageParams{
545		Role:     message.Tool,
546		Parts:    parts,
547		Provider: a.providerID,
548	})
549	if err != nil {
550		return assistantMsg, nil, fmt.Errorf("failed to create cancelled tool message: %w", err)
551	}
552
553	return assistantMsg, &msg, err
554}
555
556func (a *agent) finishMessage(ctx context.Context, msg *message.Message, finishReson message.FinishReason) {
557	msg.AddFinish(finishReson)
558	_ = a.messages.Update(ctx, *msg)
559}
560
561func (a *agent) processEvent(ctx context.Context, sessionID string, assistantMsg *message.Message, event provider.ProviderEvent) error {
562	select {
563	case <-ctx.Done():
564		return ctx.Err()
565	default:
566		// Continue processing.
567	}
568
569	switch event.Type {
570	case provider.EventThinkingDelta:
571		assistantMsg.AppendReasoningContent(event.Content)
572		return a.messages.Update(ctx, *assistantMsg)
573	case provider.EventContentDelta:
574		assistantMsg.AppendContent(event.Content)
575		return a.messages.Update(ctx, *assistantMsg)
576	case provider.EventToolUseStart:
577		logging.Info("Tool call started", "toolCall", event.ToolCall)
578		assistantMsg.AddToolCall(*event.ToolCall)
579		return a.messages.Update(ctx, *assistantMsg)
580	case provider.EventToolUseDelta:
581		assistantMsg.AppendToolCallInput(event.ToolCall.ID, event.ToolCall.Input)
582		return a.messages.Update(ctx, *assistantMsg)
583	case provider.EventToolUseStop:
584		logging.Info("Finished tool call", "toolCall", event.ToolCall)
585		assistantMsg.FinishToolCall(event.ToolCall.ID)
586		return a.messages.Update(ctx, *assistantMsg)
587	case provider.EventError:
588		if errors.Is(event.Error, context.Canceled) {
589			logging.InfoPersist(fmt.Sprintf("Event processing canceled for session: %s", sessionID))
590			return context.Canceled
591		}
592		logging.ErrorPersist(event.Error.Error())
593		return event.Error
594	case provider.EventComplete:
595		assistantMsg.SetToolCalls(event.Response.ToolCalls)
596		assistantMsg.AddFinish(event.Response.FinishReason)
597		if err := a.messages.Update(ctx, *assistantMsg); err != nil {
598			return fmt.Errorf("failed to update message: %w", err)
599		}
600		return a.TrackUsage(ctx, sessionID, a.Model(), event.Response.Usage)
601	}
602
603	return nil
604}
605
606func (a *agent) TrackUsage(ctx context.Context, sessionID string, model config.Model, usage provider.TokenUsage) error {
607	sess, err := a.sessions.Get(ctx, sessionID)
608	if err != nil {
609		return fmt.Errorf("failed to get session: %w", err)
610	}
611
612	cost := model.CostPer1MInCached/1e6*float64(usage.CacheCreationTokens) +
613		model.CostPer1MOutCached/1e6*float64(usage.CacheReadTokens) +
614		model.CostPer1MIn/1e6*float64(usage.InputTokens) +
615		model.CostPer1MOut/1e6*float64(usage.OutputTokens)
616
617	sess.Cost += cost
618	sess.CompletionTokens = usage.OutputTokens + usage.CacheReadTokens
619	sess.PromptTokens = usage.InputTokens + usage.CacheCreationTokens
620
621	_, err = a.sessions.Save(ctx, sess)
622	if err != nil {
623		return fmt.Errorf("failed to save session: %w", err)
624	}
625	return nil
626}
627
628func (a *agent) Summarize(ctx context.Context, sessionID string) error {
629	if a.summarizeProvider == nil {
630		return fmt.Errorf("summarize provider not available")
631	}
632
633	// Check if session is busy
634	if a.IsSessionBusy(sessionID) {
635		return ErrSessionBusy
636	}
637
638	// Create a new context with cancellation
639	summarizeCtx, cancel := context.WithCancel(ctx)
640
641	// Store the cancel function in activeRequests to allow cancellation
642	a.activeRequests.Store(sessionID+"-summarize", cancel)
643
644	go func() {
645		defer a.activeRequests.Delete(sessionID + "-summarize")
646		defer cancel()
647		event := AgentEvent{
648			Type:     AgentEventTypeSummarize,
649			Progress: "Starting summarization...",
650		}
651
652		a.Publish(pubsub.CreatedEvent, event)
653		// Get all messages from the session
654		msgs, err := a.messages.List(summarizeCtx, sessionID)
655		if err != nil {
656			event = AgentEvent{
657				Type:  AgentEventTypeError,
658				Error: fmt.Errorf("failed to list messages: %w", err),
659				Done:  true,
660			}
661			a.Publish(pubsub.CreatedEvent, event)
662			return
663		}
664
665		if len(msgs) == 0 {
666			event = AgentEvent{
667				Type:  AgentEventTypeError,
668				Error: fmt.Errorf("no messages to summarize"),
669				Done:  true,
670			}
671			a.Publish(pubsub.CreatedEvent, event)
672			return
673		}
674
675		event = AgentEvent{
676			Type:     AgentEventTypeSummarize,
677			Progress: "Analyzing conversation...",
678		}
679		a.Publish(pubsub.CreatedEvent, event)
680
681		// Add a system message to guide the summarization
682		summarizePrompt := "Provide a detailed but concise summary of our conversation above. Focus on information that would be helpful for continuing the conversation, including what we did, what we're doing, which files we're working on, and what we're going to do next."
683
684		// Create a new message with the summarize prompt
685		promptMsg := message.Message{
686			Role:  message.User,
687			Parts: []message.ContentPart{message.TextContent{Text: summarizePrompt}},
688		}
689
690		// Append the prompt to the messages
691		msgsWithPrompt := append(msgs, promptMsg)
692
693		event = AgentEvent{
694			Type:     AgentEventTypeSummarize,
695			Progress: "Generating summary...",
696		}
697
698		a.Publish(pubsub.CreatedEvent, event)
699
700		// Send the messages to the summarize provider
701		response := a.summarizeProvider.StreamResponse(
702			summarizeCtx,
703			msgsWithPrompt,
704			make([]tools.BaseTool, 0),
705		)
706		var finalResponse *provider.ProviderResponse
707		for r := range response {
708			if r.Error != nil {
709				event = AgentEvent{
710					Type:  AgentEventTypeError,
711					Error: fmt.Errorf("failed to summarize: %w", err),
712					Done:  true,
713				}
714				a.Publish(pubsub.CreatedEvent, event)
715				return
716			}
717			finalResponse = r.Response
718		}
719
720		summary := strings.TrimSpace(finalResponse.Content)
721		if summary == "" {
722			event = AgentEvent{
723				Type:  AgentEventTypeError,
724				Error: fmt.Errorf("empty summary returned"),
725				Done:  true,
726			}
727			a.Publish(pubsub.CreatedEvent, event)
728			return
729		}
730		event = AgentEvent{
731			Type:     AgentEventTypeSummarize,
732			Progress: "Creating new session...",
733		}
734
735		a.Publish(pubsub.CreatedEvent, event)
736		oldSession, err := a.sessions.Get(summarizeCtx, sessionID)
737		if err != nil {
738			event = AgentEvent{
739				Type:  AgentEventTypeError,
740				Error: fmt.Errorf("failed to get session: %w", err),
741				Done:  true,
742			}
743
744			a.Publish(pubsub.CreatedEvent, event)
745			return
746		}
747		// Create a message in the new session with the summary
748		msg, err := a.messages.Create(summarizeCtx, oldSession.ID, message.CreateMessageParams{
749			Role: message.Assistant,
750			Parts: []message.ContentPart{
751				message.TextContent{Text: summary},
752				message.Finish{
753					Reason: message.FinishReasonEndTurn,
754					Time:   time.Now().Unix(),
755				},
756			},
757			Model:    a.summarizeProvider.Model().ID,
758			Provider: a.summarizeProviderID,
759		})
760		if err != nil {
761			event = AgentEvent{
762				Type:  AgentEventTypeError,
763				Error: fmt.Errorf("failed to create summary message: %w", err),
764				Done:  true,
765			}
766
767			a.Publish(pubsub.CreatedEvent, event)
768			return
769		}
770		oldSession.SummaryMessageID = msg.ID
771		oldSession.CompletionTokens = finalResponse.Usage.OutputTokens
772		oldSession.PromptTokens = 0
773		model := a.summarizeProvider.Model()
774		usage := finalResponse.Usage
775		cost := model.CostPer1MInCached/1e6*float64(usage.CacheCreationTokens) +
776			model.CostPer1MOutCached/1e6*float64(usage.CacheReadTokens) +
777			model.CostPer1MIn/1e6*float64(usage.InputTokens) +
778			model.CostPer1MOut/1e6*float64(usage.OutputTokens)
779		oldSession.Cost += cost
780		_, err = a.sessions.Save(summarizeCtx, oldSession)
781		if err != nil {
782			event = AgentEvent{
783				Type:  AgentEventTypeError,
784				Error: fmt.Errorf("failed to save session: %w", err),
785				Done:  true,
786			}
787			a.Publish(pubsub.CreatedEvent, event)
788		}
789
790		event = AgentEvent{
791			Type:      AgentEventTypeSummarize,
792			SessionID: oldSession.ID,
793			Progress:  "Summary complete",
794			Done:      true,
795		}
796		a.Publish(pubsub.CreatedEvent, event)
797		// Send final success event with the new session ID
798	}()
799
800	return nil
801}
802
803func (a *agent) CancelAll() {
804	a.activeRequests.Range(func(key, value any) bool {
805		a.Cancel(key.(string)) // key is sessionID
806		return true
807	})
808}
809
810func (a *agent) UpdateModel() error {
811	cfg := config.Get()
812
813	// Get current provider configuration
814	currentProviderCfg := config.GetAgentProvider(a.agentCfg.ID)
815	if currentProviderCfg.ID == "" {
816		return fmt.Errorf("provider for agent %s not found in config", a.agentCfg.Name)
817	}
818
819	// Check if provider has changed
820	if string(currentProviderCfg.ID) != a.providerID {
821		// Provider changed, need to recreate the main provider
822		model := config.GetAgentModel(a.agentCfg.ID)
823		if model.ID == "" {
824			return fmt.Errorf("model not found for agent %s", a.agentCfg.Name)
825		}
826
827		promptID := agentPromptMap[a.agentCfg.ID]
828		if promptID == "" {
829			promptID = prompt.PromptDefault
830		}
831
832		opts := []provider.ProviderClientOption{
833			provider.WithModel(a.agentCfg.Model),
834			provider.WithSystemMessage(prompt.GetPrompt(promptID, currentProviderCfg.ID)),
835			provider.WithMaxTokens(model.DefaultMaxTokens),
836		}
837
838		newProvider, err := provider.NewProviderV2(currentProviderCfg, opts...)
839		if err != nil {
840			return fmt.Errorf("failed to create new provider: %w", err)
841		}
842
843		// Update the provider and provider ID
844		a.provider = newProvider
845		a.providerID = string(currentProviderCfg.ID)
846	}
847
848	// Check if small model provider has changed (affects title and summarize providers)
849	smallModelCfg := cfg.Models.Small
850	var smallModelProviderCfg config.ProviderConfig
851
852	for _, p := range cfg.Providers {
853		if p.ID == smallModelCfg.Provider {
854			smallModelProviderCfg = p
855			break
856		}
857	}
858
859	if smallModelProviderCfg.ID == "" {
860		return fmt.Errorf("provider %s not found in config", smallModelCfg.Provider)
861	}
862
863	// Check if summarize provider has changed
864	if string(smallModelProviderCfg.ID) != a.summarizeProviderID {
865		var smallModel config.Model
866		for _, m := range smallModelProviderCfg.Models {
867			if m.ID == smallModelCfg.ModelID {
868				smallModel = m
869				break
870			}
871		}
872		if smallModel.ID == "" {
873			return fmt.Errorf("model %s not found in provider %s", smallModelCfg.ModelID, smallModelProviderCfg.ID)
874		}
875
876		// Recreate title provider
877		titleOpts := []provider.ProviderClientOption{
878			provider.WithModel(config.SmallModel),
879			provider.WithSystemMessage(prompt.GetPrompt(prompt.PromptTitle, smallModelProviderCfg.ID)),
880			provider.WithMaxTokens(40),
881		}
882		newTitleProvider, err := provider.NewProviderV2(smallModelProviderCfg, titleOpts...)
883		if err != nil {
884			return fmt.Errorf("failed to create new title provider: %w", err)
885		}
886
887		// Recreate summarize provider
888		summarizeOpts := []provider.ProviderClientOption{
889			provider.WithModel(config.SmallModel),
890			provider.WithSystemMessage(prompt.GetPrompt(prompt.PromptSummarizer, smallModelProviderCfg.ID)),
891			provider.WithMaxTokens(smallModel.DefaultMaxTokens),
892		}
893		newSummarizeProvider, err := provider.NewProviderV2(smallModelProviderCfg, summarizeOpts...)
894		if err != nil {
895			return fmt.Errorf("failed to create new summarize provider: %w", err)
896		}
897
898		// Update the providers and provider ID
899		a.titleProvider = newTitleProvider
900		a.summarizeProvider = newSummarizeProvider
901		a.summarizeProviderID = string(smallModelProviderCfg.ID)
902	}
903
904	return nil
905}