agent.go

  1package agent
  2
  3import (
  4	"context"
  5	"errors"
  6	"fmt"
  7	"slices"
  8	"strings"
  9	"sync"
 10	"time"
 11
 12	"github.com/charmbracelet/crush/internal/config"
 13	fur "github.com/charmbracelet/crush/internal/fur/provider"
 14	"github.com/charmbracelet/crush/internal/history"
 15	"github.com/charmbracelet/crush/internal/llm/prompt"
 16	"github.com/charmbracelet/crush/internal/llm/provider"
 17	"github.com/charmbracelet/crush/internal/llm/tools"
 18	"github.com/charmbracelet/crush/internal/logging"
 19	"github.com/charmbracelet/crush/internal/lsp"
 20	"github.com/charmbracelet/crush/internal/message"
 21	"github.com/charmbracelet/crush/internal/permission"
 22	"github.com/charmbracelet/crush/internal/pubsub"
 23	"github.com/charmbracelet/crush/internal/session"
 24)
 25
 26// Common errors
 27var (
 28	ErrRequestCancelled = errors.New("request cancelled by user")
 29	ErrSessionBusy      = errors.New("session is currently processing another request")
 30)
 31
 32type AgentEventType string
 33
 34const (
 35	AgentEventTypeError     AgentEventType = "error"
 36	AgentEventTypeResponse  AgentEventType = "response"
 37	AgentEventTypeSummarize AgentEventType = "summarize"
 38)
 39
// AgentEvent is the payload Run delivers on its returned channel and that the
// agent publishes on its broker (Summarize publishes its progress this way too).
type AgentEvent struct {
	Type    AgentEventType  // selects which of the other fields are meaningful
	Message message.Message // final assistant message on response events
	Error   error           // failure carried by error events

	// When summarizing: SessionID is set on the final successful summarize
	// event, Progress is a human-readable status line, and Done marks a
	// terminal event (also set on successful Run responses).
	SessionID string
	Progress  string
	Done      bool
}
 50
 51type Service interface {
 52	pubsub.Suscriber[AgentEvent]
 53	Model() fur.Model
 54	Run(ctx context.Context, sessionID string, content string, attachments ...message.Attachment) (<-chan AgentEvent, error)
 55	Cancel(sessionID string)
 56	CancelAll()
 57	IsSessionBusy(sessionID string) bool
 58	IsBusy() bool
 59	Summarize(ctx context.Context, sessionID string) error
 60	UpdateModel() error
 61}
 62
// agent implements Service on top of an LLM provider and a fixed set of tools.
type agent struct {
	// Broker publishes AgentEvents to subscribers.
	*pubsub.Broker[AgentEvent]
	agentCfg config.Agent
	sessions session.Service
	messages message.Service

	// Tools exposed to the main provider; providerID is recorded on the
	// messages this agent creates.
	tools      []tools.BaseTool
	provider   provider.Provider
	providerID string

	// Small-model providers: one primed with the title prompt, one with the
	// summarizer prompt. summarizeProviderID is recorded on summary messages.
	titleProvider       provider.Provider
	summarizeProvider   provider.Provider
	summarizeProviderID string

	// activeRequests maps a session ID (regular requests) or
	// "<sessionID>-summarize" (summarizations) to the context.CancelFunc of
	// the work running for it.
	activeRequests sync.Map
}
 79
 80var agentPromptMap = map[string]prompt.PromptID{
 81	"coder": prompt.PromptCoder,
 82	"task":  prompt.PromptTask,
 83}
 84
 85func NewAgent(
 86	agentCfg config.Agent,
 87	// These services are needed in the tools
 88	permissions permission.Service,
 89	sessions session.Service,
 90	messages message.Service,
 91	history history.Service,
 92	lspClients map[string]*lsp.Client,
 93) (Service, error) {
 94	ctx := context.Background()
 95	cfg := config.Get()
 96	otherTools := GetMcpTools(ctx, permissions)
 97	if len(lspClients) > 0 {
 98		otherTools = append(otherTools, tools.NewDiagnosticsTool(lspClients))
 99	}
100
101	allTools := []tools.BaseTool{
102		tools.NewBashTool(permissions),
103		tools.NewEditTool(lspClients, permissions, history),
104		tools.NewFetchTool(permissions),
105		tools.NewGlobTool(),
106		tools.NewGrepTool(),
107		tools.NewLsTool(),
108		tools.NewSourcegraphTool(),
109		tools.NewViewTool(lspClients),
110		tools.NewWriteTool(lspClients, permissions, history),
111	}
112
113	if agentCfg.ID == "coder" {
114		taskAgentCfg := config.Get().Agents["task"]
115		if taskAgentCfg.ID == "" {
116			return nil, fmt.Errorf("task agent not found in config")
117		}
118		taskAgent, err := NewAgent(taskAgentCfg, permissions, sessions, messages, history, lspClients)
119		if err != nil {
120			return nil, fmt.Errorf("failed to create task agent: %w", err)
121		}
122
123		allTools = append(
124			allTools,
125			NewAgentTool(
126				taskAgent,
127				sessions,
128				messages,
129			),
130		)
131	}
132
133	allTools = append(allTools, otherTools...)
134	providerCfg := config.Get().GetProviderForModel(agentCfg.Model)
135	if providerCfg == nil {
136		return nil, fmt.Errorf("provider for agent %s not found in config", agentCfg.Name)
137	}
138	model := config.Get().GetModelByType(agentCfg.Model)
139
140	if model == nil {
141		return nil, fmt.Errorf("model not found for agent %s", agentCfg.Name)
142	}
143
144	promptID := agentPromptMap[agentCfg.ID]
145	if promptID == "" {
146		promptID = prompt.PromptDefault
147	}
148	opts := []provider.ProviderClientOption{
149		provider.WithModel(agentCfg.Model),
150		provider.WithSystemMessage(prompt.GetPrompt(promptID, providerCfg.ID)),
151	}
152	agentProvider, err := provider.NewProvider(*providerCfg, opts...)
153	if err != nil {
154		return nil, err
155	}
156
157	smallModelCfg := cfg.Models[config.SelectedModelTypeSmall]
158	var smallModelProviderCfg *config.ProviderConfig
159	if smallModelCfg.Provider == providerCfg.ID {
160		smallModelProviderCfg = providerCfg
161	} else {
162		smallModelProviderCfg = cfg.GetProviderForModel(config.SelectedModelTypeSmall)
163
164		if smallModelProviderCfg.ID == "" {
165			return nil, fmt.Errorf("provider %s not found in config", smallModelCfg.Provider)
166		}
167	}
168	smallModel := cfg.GetModelByType(config.SelectedModelTypeSmall)
169	if smallModel.ID == "" {
170		return nil, fmt.Errorf("model %s not found in provider %s", smallModelCfg.Model, smallModelProviderCfg.ID)
171	}
172
173	titleOpts := []provider.ProviderClientOption{
174		provider.WithModel(config.SelectedModelTypeSmall),
175		provider.WithSystemMessage(prompt.GetPrompt(prompt.PromptTitle, smallModelProviderCfg.ID)),
176	}
177	titleProvider, err := provider.NewProvider(*smallModelProviderCfg, titleOpts...)
178	if err != nil {
179		return nil, err
180	}
181	summarizeOpts := []provider.ProviderClientOption{
182		provider.WithModel(config.SelectedModelTypeSmall),
183		provider.WithSystemMessage(prompt.GetPrompt(prompt.PromptSummarizer, smallModelProviderCfg.ID)),
184	}
185	summarizeProvider, err := provider.NewProvider(*smallModelProviderCfg, summarizeOpts...)
186	if err != nil {
187		return nil, err
188	}
189
190	agentTools := []tools.BaseTool{}
191	if agentCfg.AllowedTools == nil {
192		agentTools = allTools
193	} else {
194		for _, tool := range allTools {
195			if slices.Contains(agentCfg.AllowedTools, tool.Name()) {
196				agentTools = append(agentTools, tool)
197			}
198		}
199	}
200
201	agent := &agent{
202		Broker:              pubsub.NewBroker[AgentEvent](),
203		agentCfg:            agentCfg,
204		provider:            agentProvider,
205		providerID:          string(providerCfg.ID),
206		messages:            messages,
207		sessions:            sessions,
208		tools:               agentTools,
209		titleProvider:       titleProvider,
210		summarizeProvider:   summarizeProvider,
211		summarizeProviderID: string(smallModelProviderCfg.ID),
212		activeRequests:      sync.Map{},
213	}
214
215	return agent, nil
216}
217
218func (a *agent) Model() fur.Model {
219	return *config.Get().GetModelByType(a.agentCfg.Model)
220}
221
222func (a *agent) Cancel(sessionID string) {
223	// Cancel regular requests
224	if cancelFunc, exists := a.activeRequests.LoadAndDelete(sessionID); exists {
225		if cancel, ok := cancelFunc.(context.CancelFunc); ok {
226			logging.InfoPersist(fmt.Sprintf("Request cancellation initiated for session: %s", sessionID))
227			cancel()
228		}
229	}
230
231	// Also check for summarize requests
232	if cancelFunc, exists := a.activeRequests.LoadAndDelete(sessionID + "-summarize"); exists {
233		if cancel, ok := cancelFunc.(context.CancelFunc); ok {
234			logging.InfoPersist(fmt.Sprintf("Summarize cancellation initiated for session: %s", sessionID))
235			cancel()
236		}
237	}
238}
239
240func (a *agent) IsBusy() bool {
241	busy := false
242	a.activeRequests.Range(func(key, value any) bool {
243		if cancelFunc, ok := value.(context.CancelFunc); ok {
244			if cancelFunc != nil {
245				busy = true
246				return false
247			}
248		}
249		return true
250	})
251	return busy
252}
253
254func (a *agent) IsSessionBusy(sessionID string) bool {
255	_, busy := a.activeRequests.Load(sessionID)
256	return busy
257}
258
259func (a *agent) generateTitle(ctx context.Context, sessionID string, content string) error {
260	if content == "" {
261		return nil
262	}
263	if a.titleProvider == nil {
264		return nil
265	}
266	session, err := a.sessions.Get(ctx, sessionID)
267	if err != nil {
268		return err
269	}
270	parts := []message.ContentPart{message.TextContent{
271		Text: fmt.Sprintf("Generate a concise title for the following content:\n\n%s", content),
272	}}
273
274	// Use streaming approach like summarization
275	response := a.titleProvider.StreamResponse(
276		ctx,
277		[]message.Message{
278			{
279				Role:  message.User,
280				Parts: parts,
281			},
282		},
283		make([]tools.BaseTool, 0),
284	)
285
286	var finalResponse *provider.ProviderResponse
287	for r := range response {
288		if r.Error != nil {
289			return r.Error
290		}
291		finalResponse = r.Response
292	}
293
294	if finalResponse == nil {
295		return fmt.Errorf("no response received from title provider")
296	}
297
298	title := strings.TrimSpace(strings.ReplaceAll(finalResponse.Content, "\n", " "))
299	if title == "" {
300		return nil
301	}
302
303	session.Title = title
304	_, err = a.sessions.Save(ctx, session)
305	return err
306}
307
308func (a *agent) err(err error) AgentEvent {
309	return AgentEvent{
310		Type:  AgentEventTypeError,
311		Error: err,
312	}
313}
314
315func (a *agent) Run(ctx context.Context, sessionID string, content string, attachments ...message.Attachment) (<-chan AgentEvent, error) {
316	if !a.Model().SupportsImages && attachments != nil {
317		attachments = nil
318	}
319	events := make(chan AgentEvent)
320	if a.IsSessionBusy(sessionID) {
321		return nil, ErrSessionBusy
322	}
323
324	genCtx, cancel := context.WithCancel(ctx)
325
326	a.activeRequests.Store(sessionID, cancel)
327	go func() {
328		logging.Debug("Request started", "sessionID", sessionID)
329		defer logging.RecoverPanic("agent.Run", func() {
330			events <- a.err(fmt.Errorf("panic while running the agent"))
331		})
332		var attachmentParts []message.ContentPart
333		for _, attachment := range attachments {
334			attachmentParts = append(attachmentParts, message.BinaryContent{Path: attachment.FilePath, MIMEType: attachment.MimeType, Data: attachment.Content})
335		}
336		result := a.processGeneration(genCtx, sessionID, content, attachmentParts)
337		if result.Error != nil && !errors.Is(result.Error, ErrRequestCancelled) && !errors.Is(result.Error, context.Canceled) {
338			logging.ErrorPersist(result.Error.Error())
339		}
340		logging.Debug("Request completed", "sessionID", sessionID)
341		a.activeRequests.Delete(sessionID)
342		cancel()
343		a.Publish(pubsub.CreatedEvent, result)
344		events <- result
345		close(events)
346	}()
347	return events, nil
348}
349
// processGeneration runs one full user turn for sessionID and returns the
// event that Run delivers to its caller.
//
// It stores content (plus attachmentParts) as a new user message, then calls
// streamAndHandleEvents in a loop. While the assistant finishes with
// FinishReasonToolUse and tool results were produced, the assistant message
// and the tool-result message are appended to the history and the model is
// called again. Otherwise an AgentEventTypeResponse event carrying the
// assistant message is returned. For the first message of a session, a title
// is generated on a separate goroutine.
func (a *agent) processGeneration(ctx context.Context, sessionID, content string, attachmentParts []message.ContentPart) AgentEvent {
	cfg := config.Get()
	// List existing messages; if none, start title generation asynchronously.
	msgs, err := a.messages.List(ctx, sessionID)
	if err != nil {
		return a.err(fmt.Errorf("failed to list messages: %w", err))
	}
	if len(msgs) == 0 {
		go func() {
			defer logging.RecoverPanic("agent.Run", func() {
				logging.ErrorPersist("panic while generating title")
			})
			// Uses context.Background() so title generation is not tied to
			// ctx, which Run cancels once this function returns.
			titleErr := a.generateTitle(context.Background(), sessionID, content)
			if titleErr != nil && !errors.Is(titleErr, context.Canceled) && !errors.Is(titleErr, context.DeadlineExceeded) {
				logging.ErrorPersist(fmt.Sprintf("failed to generate title: %v", titleErr))
			}
		}()
	}
	session, err := a.sessions.Get(ctx, sessionID)
	if err != nil {
		return a.err(fmt.Errorf("failed to get session: %w", err))
	}
	// If the session has been summarized, drop everything before the summary
	// message and present the summary to the model as a user message.
	if session.SummaryMessageID != "" {
		summaryMsgInex := -1
		for i, msg := range msgs {
			if msg.ID == session.SummaryMessageID {
				summaryMsgInex = i
				break
			}
		}
		if summaryMsgInex != -1 {
			msgs = msgs[summaryMsgInex:]
			msgs[0].Role = message.User
		}
	}

	userMsg, err := a.createUserMessage(ctx, sessionID, content, attachmentParts)
	if err != nil {
		return a.err(fmt.Errorf("failed to create user message: %w", err))
	}
	// Append the new user message to the conversation history.
	msgHistory := append(msgs, userMsg)

	for {
		// Check for cancellation before each iteration
		select {
		case <-ctx.Done():
			return a.err(ctx.Err())
		default:
			// Continue processing
		}
		agentMessage, toolResults, err := a.streamAndHandleEvents(ctx, sessionID, msgHistory)
		if err != nil {
			if errors.Is(err, context.Canceled) {
				// Persist the cancellation with a fresh context because ctx is
				// already done. NOTE(review): the Update error is ignored here.
				agentMessage.AddFinish(message.FinishReasonCanceled)
				a.messages.Update(context.Background(), agentMessage)
				return a.err(ErrRequestCancelled)
			}
			return a.err(fmt.Errorf("failed to process events: %w", err))
		}
		if cfg.Options.Debug {
			// Debug mode writes tool results to a JSON file and logs its path;
			// the sequence number is derived from the history length.
			seqId := (len(msgHistory) + 1) / 2
			toolResultFilepath := logging.WriteToolResultsJson(sessionID, seqId, toolResults)
			logging.Info("Result", "message", agentMessage.FinishReason(), "toolResults", "{}", "filepath", toolResultFilepath)
		} else {
			logging.Info("Result", "message", agentMessage.FinishReason(), "toolResults", toolResults)
		}
		if (agentMessage.FinishReason() == message.FinishReasonToolUse) && toolResults != nil {
			// We are not done, we need to respond with the tool response
			msgHistory = append(msgHistory, agentMessage, *toolResults)
			continue
		}
		return AgentEvent{
			Type:    AgentEventTypeResponse,
			Message: agentMessage,
			Done:    true,
		}
	}
}
429
430func (a *agent) createUserMessage(ctx context.Context, sessionID, content string, attachmentParts []message.ContentPart) (message.Message, error) {
431	parts := []message.ContentPart{message.TextContent{Text: content}}
432	parts = append(parts, attachmentParts...)
433	return a.messages.Create(ctx, sessionID, message.CreateMessageParams{
434		Role:  message.User,
435		Parts: parts,
436	})
437}
438
439func (a *agent) streamAndHandleEvents(ctx context.Context, sessionID string, msgHistory []message.Message) (message.Message, *message.Message, error) {
440	ctx = context.WithValue(ctx, tools.SessionIDContextKey, sessionID)
441	eventChan := a.provider.StreamResponse(ctx, msgHistory, a.tools)
442
443	assistantMsg, err := a.messages.Create(ctx, sessionID, message.CreateMessageParams{
444		Role:     message.Assistant,
445		Parts:    []message.ContentPart{},
446		Model:    a.Model().ID,
447		Provider: a.providerID,
448	})
449	if err != nil {
450		return assistantMsg, nil, fmt.Errorf("failed to create assistant message: %w", err)
451	}
452
453	// Add the session and message ID into the context if needed by tools.
454	ctx = context.WithValue(ctx, tools.MessageIDContextKey, assistantMsg.ID)
455
456	// Process each event in the stream.
457	for event := range eventChan {
458		if processErr := a.processEvent(ctx, sessionID, &assistantMsg, event); processErr != nil {
459			a.finishMessage(ctx, &assistantMsg, message.FinishReasonCanceled)
460			return assistantMsg, nil, processErr
461		}
462		if ctx.Err() != nil {
463			a.finishMessage(context.Background(), &assistantMsg, message.FinishReasonCanceled)
464			return assistantMsg, nil, ctx.Err()
465		}
466	}
467
468	toolResults := make([]message.ToolResult, len(assistantMsg.ToolCalls()))
469	toolCalls := assistantMsg.ToolCalls()
470	for i, toolCall := range toolCalls {
471		select {
472		case <-ctx.Done():
473			a.finishMessage(context.Background(), &assistantMsg, message.FinishReasonCanceled)
474			// Make all future tool calls cancelled
475			for j := i; j < len(toolCalls); j++ {
476				toolResults[j] = message.ToolResult{
477					ToolCallID: toolCalls[j].ID,
478					Content:    "Tool execution canceled by user",
479					IsError:    true,
480				}
481			}
482			goto out
483		default:
484			// Continue processing
485			var tool tools.BaseTool
486			for _, availableTool := range a.tools {
487				if availableTool.Info().Name == toolCall.Name {
488					tool = availableTool
489					break
490				}
491			}
492
493			// Tool not found
494			if tool == nil {
495				toolResults[i] = message.ToolResult{
496					ToolCallID: toolCall.ID,
497					Content:    fmt.Sprintf("Tool not found: %s", toolCall.Name),
498					IsError:    true,
499				}
500				continue
501			}
502			toolResult, toolErr := tool.Run(ctx, tools.ToolCall{
503				ID:    toolCall.ID,
504				Name:  toolCall.Name,
505				Input: toolCall.Input,
506			})
507			if toolErr != nil {
508				if errors.Is(toolErr, permission.ErrorPermissionDenied) {
509					toolResults[i] = message.ToolResult{
510						ToolCallID: toolCall.ID,
511						Content:    "Permission denied",
512						IsError:    true,
513					}
514					for j := i + 1; j < len(toolCalls); j++ {
515						toolResults[j] = message.ToolResult{
516							ToolCallID: toolCalls[j].ID,
517							Content:    "Tool execution canceled by user",
518							IsError:    true,
519						}
520					}
521					a.finishMessage(ctx, &assistantMsg, message.FinishReasonPermissionDenied)
522					break
523				}
524			}
525			toolResults[i] = message.ToolResult{
526				ToolCallID: toolCall.ID,
527				Content:    toolResult.Content,
528				Metadata:   toolResult.Metadata,
529				IsError:    toolResult.IsError,
530			}
531		}
532	}
533out:
534	if len(toolResults) == 0 {
535		return assistantMsg, nil, nil
536	}
537	parts := make([]message.ContentPart, 0)
538	for _, tr := range toolResults {
539		parts = append(parts, tr)
540	}
541	msg, err := a.messages.Create(context.Background(), assistantMsg.SessionID, message.CreateMessageParams{
542		Role:     message.Tool,
543		Parts:    parts,
544		Provider: a.providerID,
545	})
546	if err != nil {
547		return assistantMsg, nil, fmt.Errorf("failed to create cancelled tool message: %w", err)
548	}
549
550	return assistantMsg, &msg, err
551}
552
553func (a *agent) finishMessage(ctx context.Context, msg *message.Message, finishReson message.FinishReason) {
554	msg.AddFinish(finishReson)
555	_ = a.messages.Update(ctx, *msg)
556}
557
558func (a *agent) processEvent(ctx context.Context, sessionID string, assistantMsg *message.Message, event provider.ProviderEvent) error {
559	select {
560	case <-ctx.Done():
561		return ctx.Err()
562	default:
563		// Continue processing.
564	}
565
566	switch event.Type {
567	case provider.EventThinkingDelta:
568		assistantMsg.AppendReasoningContent(event.Content)
569		return a.messages.Update(ctx, *assistantMsg)
570	case provider.EventContentDelta:
571		assistantMsg.AppendContent(event.Content)
572		return a.messages.Update(ctx, *assistantMsg)
573	case provider.EventToolUseStart:
574		logging.Info("Tool call started", "toolCall", event.ToolCall)
575		assistantMsg.AddToolCall(*event.ToolCall)
576		return a.messages.Update(ctx, *assistantMsg)
577	case provider.EventToolUseDelta:
578		assistantMsg.AppendToolCallInput(event.ToolCall.ID, event.ToolCall.Input)
579		return a.messages.Update(ctx, *assistantMsg)
580	case provider.EventToolUseStop:
581		logging.Info("Finished tool call", "toolCall", event.ToolCall)
582		assistantMsg.FinishToolCall(event.ToolCall.ID)
583		return a.messages.Update(ctx, *assistantMsg)
584	case provider.EventError:
585		if errors.Is(event.Error, context.Canceled) {
586			logging.InfoPersist(fmt.Sprintf("Event processing canceled for session: %s", sessionID))
587			return context.Canceled
588		}
589		logging.ErrorPersist(event.Error.Error())
590		return event.Error
591	case provider.EventComplete:
592		assistantMsg.SetToolCalls(event.Response.ToolCalls)
593		assistantMsg.AddFinish(event.Response.FinishReason)
594		if err := a.messages.Update(ctx, *assistantMsg); err != nil {
595			return fmt.Errorf("failed to update message: %w", err)
596		}
597		return a.TrackUsage(ctx, sessionID, a.Model(), event.Response.Usage)
598	}
599
600	return nil
601}
602
603func (a *agent) TrackUsage(ctx context.Context, sessionID string, model fur.Model, usage provider.TokenUsage) error {
604	sess, err := a.sessions.Get(ctx, sessionID)
605	if err != nil {
606		return fmt.Errorf("failed to get session: %w", err)
607	}
608
609	cost := model.CostPer1MInCached/1e6*float64(usage.CacheCreationTokens) +
610		model.CostPer1MOutCached/1e6*float64(usage.CacheReadTokens) +
611		model.CostPer1MIn/1e6*float64(usage.InputTokens) +
612		model.CostPer1MOut/1e6*float64(usage.OutputTokens)
613
614	sess.Cost += cost
615	sess.CompletionTokens = usage.OutputTokens + usage.CacheReadTokens
616	sess.PromptTokens = usage.InputTokens + usage.CacheCreationTokens
617
618	_, err = a.sessions.Save(ctx, sess)
619	if err != nil {
620		return fmt.Errorf("failed to save session: %w", err)
621	}
622	return nil
623}
624
625func (a *agent) Summarize(ctx context.Context, sessionID string) error {
626	if a.summarizeProvider == nil {
627		return fmt.Errorf("summarize provider not available")
628	}
629
630	// Check if session is busy
631	if a.IsSessionBusy(sessionID) {
632		return ErrSessionBusy
633	}
634
635	// Create a new context with cancellation
636	summarizeCtx, cancel := context.WithCancel(ctx)
637
638	// Store the cancel function in activeRequests to allow cancellation
639	a.activeRequests.Store(sessionID+"-summarize", cancel)
640
641	go func() {
642		defer a.activeRequests.Delete(sessionID + "-summarize")
643		defer cancel()
644		event := AgentEvent{
645			Type:     AgentEventTypeSummarize,
646			Progress: "Starting summarization...",
647		}
648
649		a.Publish(pubsub.CreatedEvent, event)
650		// Get all messages from the session
651		msgs, err := a.messages.List(summarizeCtx, sessionID)
652		if err != nil {
653			event = AgentEvent{
654				Type:  AgentEventTypeError,
655				Error: fmt.Errorf("failed to list messages: %w", err),
656				Done:  true,
657			}
658			a.Publish(pubsub.CreatedEvent, event)
659			return
660		}
661		summarizeCtx = context.WithValue(summarizeCtx, tools.SessionIDContextKey, sessionID)
662
663		if len(msgs) == 0 {
664			event = AgentEvent{
665				Type:  AgentEventTypeError,
666				Error: fmt.Errorf("no messages to summarize"),
667				Done:  true,
668			}
669			a.Publish(pubsub.CreatedEvent, event)
670			return
671		}
672
673		event = AgentEvent{
674			Type:     AgentEventTypeSummarize,
675			Progress: "Analyzing conversation...",
676		}
677		a.Publish(pubsub.CreatedEvent, event)
678
679		// Add a system message to guide the summarization
680		summarizePrompt := "Provide a detailed but concise summary of our conversation above. Focus on information that would be helpful for continuing the conversation, including what we did, what we're doing, which files we're working on, and what we're going to do next."
681
682		// Create a new message with the summarize prompt
683		promptMsg := message.Message{
684			Role:  message.User,
685			Parts: []message.ContentPart{message.TextContent{Text: summarizePrompt}},
686		}
687
688		// Append the prompt to the messages
689		msgsWithPrompt := append(msgs, promptMsg)
690
691		event = AgentEvent{
692			Type:     AgentEventTypeSummarize,
693			Progress: "Generating summary...",
694		}
695
696		a.Publish(pubsub.CreatedEvent, event)
697
698		// Send the messages to the summarize provider
699		response := a.summarizeProvider.StreamResponse(
700			summarizeCtx,
701			msgsWithPrompt,
702			make([]tools.BaseTool, 0),
703		)
704		var finalResponse *provider.ProviderResponse
705		for r := range response {
706			if r.Error != nil {
707				event = AgentEvent{
708					Type:  AgentEventTypeError,
709					Error: fmt.Errorf("failed to summarize: %w", err),
710					Done:  true,
711				}
712				a.Publish(pubsub.CreatedEvent, event)
713				return
714			}
715			finalResponse = r.Response
716		}
717
718		summary := strings.TrimSpace(finalResponse.Content)
719		if summary == "" {
720			event = AgentEvent{
721				Type:  AgentEventTypeError,
722				Error: fmt.Errorf("empty summary returned"),
723				Done:  true,
724			}
725			a.Publish(pubsub.CreatedEvent, event)
726			return
727		}
728		event = AgentEvent{
729			Type:     AgentEventTypeSummarize,
730			Progress: "Creating new session...",
731		}
732
733		a.Publish(pubsub.CreatedEvent, event)
734		oldSession, err := a.sessions.Get(summarizeCtx, sessionID)
735		if err != nil {
736			event = AgentEvent{
737				Type:  AgentEventTypeError,
738				Error: fmt.Errorf("failed to get session: %w", err),
739				Done:  true,
740			}
741
742			a.Publish(pubsub.CreatedEvent, event)
743			return
744		}
745		// Create a message in the new session with the summary
746		msg, err := a.messages.Create(summarizeCtx, oldSession.ID, message.CreateMessageParams{
747			Role: message.Assistant,
748			Parts: []message.ContentPart{
749				message.TextContent{Text: summary},
750				message.Finish{
751					Reason: message.FinishReasonEndTurn,
752					Time:   time.Now().Unix(),
753				},
754			},
755			Model:    a.summarizeProvider.Model().ID,
756			Provider: a.summarizeProviderID,
757		})
758		if err != nil {
759			event = AgentEvent{
760				Type:  AgentEventTypeError,
761				Error: fmt.Errorf("failed to create summary message: %w", err),
762				Done:  true,
763			}
764
765			a.Publish(pubsub.CreatedEvent, event)
766			return
767		}
768		oldSession.SummaryMessageID = msg.ID
769		oldSession.CompletionTokens = finalResponse.Usage.OutputTokens
770		oldSession.PromptTokens = 0
771		model := a.summarizeProvider.Model()
772		usage := finalResponse.Usage
773		cost := model.CostPer1MInCached/1e6*float64(usage.CacheCreationTokens) +
774			model.CostPer1MOutCached/1e6*float64(usage.CacheReadTokens) +
775			model.CostPer1MIn/1e6*float64(usage.InputTokens) +
776			model.CostPer1MOut/1e6*float64(usage.OutputTokens)
777		oldSession.Cost += cost
778		_, err = a.sessions.Save(summarizeCtx, oldSession)
779		if err != nil {
780			event = AgentEvent{
781				Type:  AgentEventTypeError,
782				Error: fmt.Errorf("failed to save session: %w", err),
783				Done:  true,
784			}
785			a.Publish(pubsub.CreatedEvent, event)
786		}
787
788		event = AgentEvent{
789			Type:      AgentEventTypeSummarize,
790			SessionID: oldSession.ID,
791			Progress:  "Summary complete",
792			Done:      true,
793		}
794		a.Publish(pubsub.CreatedEvent, event)
795		// Send final success event with the new session ID
796	}()
797
798	return nil
799}
800
801func (a *agent) CancelAll() {
802	a.activeRequests.Range(func(key, value any) bool {
803		a.Cancel(key.(string)) // key is sessionID
804		return true
805	})
806}
807
808func (a *agent) UpdateModel() error {
809	cfg := config.Get()
810
811	// Get current provider configuration
812	currentProviderCfg := cfg.GetProviderForModel(a.agentCfg.Model)
813	if currentProviderCfg.ID == "" {
814		return fmt.Errorf("provider for agent %s not found in config", a.agentCfg.Name)
815	}
816
817	// Check if provider has changed
818	if string(currentProviderCfg.ID) != a.providerID {
819		// Provider changed, need to recreate the main provider
820		model := cfg.GetModelByType(a.agentCfg.Model)
821		if model.ID == "" {
822			return fmt.Errorf("model not found for agent %s", a.agentCfg.Name)
823		}
824
825		promptID := agentPromptMap[a.agentCfg.ID]
826		if promptID == "" {
827			promptID = prompt.PromptDefault
828		}
829
830		opts := []provider.ProviderClientOption{
831			provider.WithModel(a.agentCfg.Model),
832			provider.WithSystemMessage(prompt.GetPrompt(promptID, currentProviderCfg.ID)),
833		}
834
835		newProvider, err := provider.NewProvider(*currentProviderCfg, opts...)
836		if err != nil {
837			return fmt.Errorf("failed to create new provider: %w", err)
838		}
839
840		// Update the provider and provider ID
841		a.provider = newProvider
842		a.providerID = string(currentProviderCfg.ID)
843	}
844
845	// Check if small model provider has changed (affects title and summarize providers)
846	smallModelCfg := cfg.Models[config.SelectedModelTypeSmall]
847	var smallModelProviderCfg config.ProviderConfig
848
849	for _, p := range cfg.Providers {
850		if p.ID == smallModelCfg.Provider {
851			smallModelProviderCfg = p
852			break
853		}
854	}
855
856	if smallModelProviderCfg.ID == "" {
857		return fmt.Errorf("provider %s not found in config", smallModelCfg.Provider)
858	}
859
860	// Check if summarize provider has changed
861	if string(smallModelProviderCfg.ID) != a.summarizeProviderID {
862		smallModel := cfg.GetModelByType(config.SelectedModelTypeSmall)
863		if smallModel == nil {
864			return fmt.Errorf("model %s not found in provider %s", smallModelCfg.Model, smallModelProviderCfg.ID)
865		}
866
867		// Recreate title provider
868		titleOpts := []provider.ProviderClientOption{
869			provider.WithModel(config.SelectedModelTypeSmall),
870			provider.WithSystemMessage(prompt.GetPrompt(prompt.PromptTitle, smallModelProviderCfg.ID)),
871			// We want the title to be short, so we limit the max tokens
872			provider.WithMaxTokens(40),
873		}
874		newTitleProvider, err := provider.NewProvider(smallModelProviderCfg, titleOpts...)
875		if err != nil {
876			return fmt.Errorf("failed to create new title provider: %w", err)
877		}
878
879		// Recreate summarize provider
880		summarizeOpts := []provider.ProviderClientOption{
881			provider.WithModel(config.SelectedModelTypeSmall),
882			provider.WithSystemMessage(prompt.GetPrompt(prompt.PromptSummarizer, smallModelProviderCfg.ID)),
883		}
884		newSummarizeProvider, err := provider.NewProvider(smallModelProviderCfg, summarizeOpts...)
885		if err != nil {
886			return fmt.Errorf("failed to create new summarize provider: %w", err)
887		}
888
889		// Update the providers and provider ID
890		a.titleProvider = newTitleProvider
891		a.summarizeProvider = newSummarizeProvider
892		a.summarizeProviderID = string(smallModelProviderCfg.ID)
893	}
894
895	return nil
896}