agent.go

  1package agent
  2
  3import (
  4	"context"
  5	"errors"
  6	"fmt"
  7	"log/slog"
  8	"slices"
  9	"strings"
 10	"sync"
 11	"time"
 12
 13	"github.com/charmbracelet/crush/internal/config"
 14	fur "github.com/charmbracelet/crush/internal/fur/provider"
 15	"github.com/charmbracelet/crush/internal/history"
 16	"github.com/charmbracelet/crush/internal/llm/prompt"
 17	"github.com/charmbracelet/crush/internal/llm/provider"
 18	"github.com/charmbracelet/crush/internal/llm/tools"
 19	"github.com/charmbracelet/crush/internal/log"
 20	"github.com/charmbracelet/crush/internal/lsp"
 21	"github.com/charmbracelet/crush/internal/message"
 22	"github.com/charmbracelet/crush/internal/permission"
 23	"github.com/charmbracelet/crush/internal/pubsub"
 24	"github.com/charmbracelet/crush/internal/session"
 25)
 26
 27// Common errors
 28var (
 29	ErrRequestCancelled = errors.New("request canceled by user")
 30	ErrSessionBusy      = errors.New("session is currently processing another request")
 31)
 32
 33type AgentEventType string
 34
 35const (
 36	AgentEventTypeError     AgentEventType = "error"
 37	AgentEventTypeResponse  AgentEventType = "response"
 38	AgentEventTypeSummarize AgentEventType = "summarize"
 39)
 40
 41type AgentEvent struct {
 42	Type    AgentEventType
 43	Message message.Message
 44	Error   error
 45
 46	// When summarizing
 47	SessionID string
 48	Progress  string
 49	Done      bool
 50}
 51
 52type Service interface {
 53	pubsub.Suscriber[AgentEvent]
 54	Model() fur.Model
 55	Run(ctx context.Context, sessionID string, content string, attachments ...message.Attachment) (<-chan AgentEvent, error)
 56	Cancel(sessionID string)
 57	CancelAll()
 58	IsSessionBusy(sessionID string) bool
 59	IsBusy() bool
 60	Summarize(ctx context.Context, sessionID string) error
 61	UpdateModel() error
 62}
 63
// agent is the Service implementation returned by NewAgent. It embeds a
// pubsub broker used to publish AgentEvents and holds the providers and tools
// it drives.
type agent struct {
	*pubsub.Broker[AgentEvent]
	agentCfg config.Agent
	sessions session.Service
	messages message.Service

	// tools is the tool set offered to the main provider (already filtered by
	// agentCfg.AllowedTools in NewAgent). provider serves agentCfg.Model and
	// providerID is the ID of that provider's config.
	tools      []tools.BaseTool
	provider   provider.Provider
	providerID string

	// titleProvider and summarizeProvider are both built for the small model;
	// summarizeProviderID is the ID of the small model's provider config.
	titleProvider       provider.Provider
	summarizeProvider   provider.Provider
	summarizeProviderID string

	// activeRequests maps a session ID (regular request) or
	// "<sessionID>-summarize" (summarization) to the context.CancelFunc of the
	// work in flight for it.
	activeRequests sync.Map
}
 80
// agentPromptMap selects the system prompt for an agent by its config ID.
// IDs not listed here fall back to prompt.PromptDefault.
var agentPromptMap = map[string]prompt.PromptID{
	"coder": prompt.PromptCoder,
	"task":  prompt.PromptTask,
}
 85
 86func NewAgent(
 87	agentCfg config.Agent,
 88	// These services are needed in the tools
 89	permissions permission.Service,
 90	sessions session.Service,
 91	messages message.Service,
 92	history history.Service,
 93	lspClients map[string]*lsp.Client,
 94) (Service, error) {
 95	ctx := context.Background()
 96	cfg := config.Get()
 97	otherTools := GetMcpTools(ctx, permissions, cfg)
 98	if len(lspClients) > 0 {
 99		otherTools = append(otherTools, tools.NewDiagnosticsTool(lspClients))
100	}
101
102	cwd := cfg.WorkingDir()
103	allTools := []tools.BaseTool{
104		tools.NewBashTool(permissions, cwd),
105		tools.NewEditTool(lspClients, permissions, history, cwd),
106		tools.NewFetchTool(permissions, cwd),
107		tools.NewGlobTool(cwd),
108		tools.NewGrepTool(cwd),
109		tools.NewLsTool(cwd),
110		tools.NewSourcegraphTool(),
111		tools.NewViewTool(lspClients, cwd),
112		tools.NewWriteTool(lspClients, permissions, history, cwd),
113	}
114
115	if agentCfg.ID == "coder" {
116		taskAgentCfg := config.Get().Agents["task"]
117		if taskAgentCfg.ID == "" {
118			return nil, fmt.Errorf("task agent not found in config")
119		}
120		taskAgent, err := NewAgent(taskAgentCfg, permissions, sessions, messages, history, lspClients)
121		if err != nil {
122			return nil, fmt.Errorf("failed to create task agent: %w", err)
123		}
124
125		allTools = append(
126			allTools,
127			NewAgentTool(
128				taskAgent,
129				sessions,
130				messages,
131			),
132		)
133	}
134
135	allTools = append(allTools, otherTools...)
136	providerCfg := config.Get().GetProviderForModel(agentCfg.Model)
137	if providerCfg == nil {
138		return nil, fmt.Errorf("provider for agent %s not found in config", agentCfg.Name)
139	}
140	model := config.Get().GetModelByType(agentCfg.Model)
141
142	if model == nil {
143		return nil, fmt.Errorf("model not found for agent %s", agentCfg.Name)
144	}
145
146	promptID := agentPromptMap[agentCfg.ID]
147	if promptID == "" {
148		promptID = prompt.PromptDefault
149	}
150	opts := []provider.ProviderClientOption{
151		provider.WithModel(agentCfg.Model),
152		provider.WithSystemMessage(prompt.GetPrompt(promptID, providerCfg.ID)),
153	}
154	agentProvider, err := provider.NewProvider(*providerCfg, opts...)
155	if err != nil {
156		return nil, err
157	}
158
159	smallModelCfg := cfg.Models[config.SelectedModelTypeSmall]
160	var smallModelProviderCfg *config.ProviderConfig
161	if smallModelCfg.Provider == providerCfg.ID {
162		smallModelProviderCfg = providerCfg
163	} else {
164		smallModelProviderCfg = cfg.GetProviderForModel(config.SelectedModelTypeSmall)
165
166		if smallModelProviderCfg.ID == "" {
167			return nil, fmt.Errorf("provider %s not found in config", smallModelCfg.Provider)
168		}
169	}
170	smallModel := cfg.GetModelByType(config.SelectedModelTypeSmall)
171	if smallModel.ID == "" {
172		return nil, fmt.Errorf("model %s not found in provider %s", smallModelCfg.Model, smallModelProviderCfg.ID)
173	}
174
175	titleOpts := []provider.ProviderClientOption{
176		provider.WithModel(config.SelectedModelTypeSmall),
177		provider.WithSystemMessage(prompt.GetPrompt(prompt.PromptTitle, smallModelProviderCfg.ID)),
178	}
179	titleProvider, err := provider.NewProvider(*smallModelProviderCfg, titleOpts...)
180	if err != nil {
181		return nil, err
182	}
183	summarizeOpts := []provider.ProviderClientOption{
184		provider.WithModel(config.SelectedModelTypeSmall),
185		provider.WithSystemMessage(prompt.GetPrompt(prompt.PromptSummarizer, smallModelProviderCfg.ID)),
186	}
187	summarizeProvider, err := provider.NewProvider(*smallModelProviderCfg, summarizeOpts...)
188	if err != nil {
189		return nil, err
190	}
191
192	agentTools := []tools.BaseTool{}
193	if agentCfg.AllowedTools == nil {
194		agentTools = allTools
195	} else {
196		for _, tool := range allTools {
197			if slices.Contains(agentCfg.AllowedTools, tool.Name()) {
198				agentTools = append(agentTools, tool)
199			}
200		}
201	}
202
203	agent := &agent{
204		Broker:              pubsub.NewBroker[AgentEvent](),
205		agentCfg:            agentCfg,
206		provider:            agentProvider,
207		providerID:          string(providerCfg.ID),
208		messages:            messages,
209		sessions:            sessions,
210		tools:               agentTools,
211		titleProvider:       titleProvider,
212		summarizeProvider:   summarizeProvider,
213		summarizeProviderID: string(smallModelProviderCfg.ID),
214		activeRequests:      sync.Map{},
215	}
216
217	return agent, nil
218}
219
220func (a *agent) Model() fur.Model {
221	return *config.Get().GetModelByType(a.agentCfg.Model)
222}
223
224func (a *agent) Cancel(sessionID string) {
225	// Cancel regular requests
226	if cancelFunc, exists := a.activeRequests.LoadAndDelete(sessionID); exists {
227		if cancel, ok := cancelFunc.(context.CancelFunc); ok {
228			slog.Info(fmt.Sprintf("Request cancellation initiated for session: %s", sessionID))
229			cancel()
230		}
231	}
232
233	// Also check for summarize requests
234	if cancelFunc, exists := a.activeRequests.LoadAndDelete(sessionID + "-summarize"); exists {
235		if cancel, ok := cancelFunc.(context.CancelFunc); ok {
236			slog.Info(fmt.Sprintf("Summarize cancellation initiated for session: %s", sessionID))
237			cancel()
238		}
239	}
240}
241
242func (a *agent) IsBusy() bool {
243	busy := false
244	a.activeRequests.Range(func(key, value any) bool {
245		if cancelFunc, ok := value.(context.CancelFunc); ok {
246			if cancelFunc != nil {
247				busy = true
248				return false
249			}
250		}
251		return true
252	})
253	return busy
254}
255
256func (a *agent) IsSessionBusy(sessionID string) bool {
257	_, busy := a.activeRequests.Load(sessionID)
258	return busy
259}
260
261func (a *agent) generateTitle(ctx context.Context, sessionID string, content string) error {
262	if content == "" {
263		return nil
264	}
265	if a.titleProvider == nil {
266		return nil
267	}
268	session, err := a.sessions.Get(ctx, sessionID)
269	if err != nil {
270		return err
271	}
272	parts := []message.ContentPart{message.TextContent{
273		Text: fmt.Sprintf("Generate a concise title for the following content:\n\n%s", content),
274	}}
275
276	// Use streaming approach like summarization
277	response := a.titleProvider.StreamResponse(
278		ctx,
279		[]message.Message{
280			{
281				Role:  message.User,
282				Parts: parts,
283			},
284		},
285		make([]tools.BaseTool, 0),
286	)
287
288	var finalResponse *provider.ProviderResponse
289	for r := range response {
290		if r.Error != nil {
291			return r.Error
292		}
293		finalResponse = r.Response
294	}
295
296	if finalResponse == nil {
297		return fmt.Errorf("no response received from title provider")
298	}
299
300	title := strings.TrimSpace(strings.ReplaceAll(finalResponse.Content, "\n", " "))
301	if title == "" {
302		return nil
303	}
304
305	session.Title = title
306	_, err = a.sessions.Save(ctx, session)
307	return err
308}
309
310func (a *agent) err(err error) AgentEvent {
311	return AgentEvent{
312		Type:  AgentEventTypeError,
313		Error: err,
314	}
315}
316
317func (a *agent) Run(ctx context.Context, sessionID string, content string, attachments ...message.Attachment) (<-chan AgentEvent, error) {
318	if !a.Model().SupportsImages && attachments != nil {
319		attachments = nil
320	}
321	events := make(chan AgentEvent)
322	if a.IsSessionBusy(sessionID) {
323		return nil, ErrSessionBusy
324	}
325
326	genCtx, cancel := context.WithCancel(ctx)
327
328	a.activeRequests.Store(sessionID, cancel)
329	go func() {
330		slog.Debug("Request started", "sessionID", sessionID)
331		defer log.RecoverPanic("agent.Run", func() {
332			events <- a.err(fmt.Errorf("panic while running the agent"))
333		})
334		var attachmentParts []message.ContentPart
335		for _, attachment := range attachments {
336			attachmentParts = append(attachmentParts, message.BinaryContent{Path: attachment.FilePath, MIMEType: attachment.MimeType, Data: attachment.Content})
337		}
338		result := a.processGeneration(genCtx, sessionID, content, attachmentParts)
339		if result.Error != nil && !errors.Is(result.Error, ErrRequestCancelled) && !errors.Is(result.Error, context.Canceled) {
340			slog.Error(result.Error.Error())
341		}
342		slog.Debug("Request completed", "sessionID", sessionID)
343		a.activeRequests.Delete(sessionID)
344		cancel()
345		a.Publish(pubsub.CreatedEvent, result)
346		events <- result
347		close(events)
348	}()
349	return events, nil
350}
351
352func (a *agent) processGeneration(ctx context.Context, sessionID, content string, attachmentParts []message.ContentPart) AgentEvent {
353	cfg := config.Get()
354	// List existing messages; if none, start title generation asynchronously.
355	msgs, err := a.messages.List(ctx, sessionID)
356	if err != nil {
357		return a.err(fmt.Errorf("failed to list messages: %w", err))
358	}
359	if len(msgs) == 0 {
360		go func() {
361			defer log.RecoverPanic("agent.Run", func() {
362				slog.Error("panic while generating title")
363			})
364			titleErr := a.generateTitle(context.Background(), sessionID, content)
365			if titleErr != nil && !errors.Is(titleErr, context.Canceled) && !errors.Is(titleErr, context.DeadlineExceeded) {
366				slog.Error(fmt.Sprintf("failed to generate title: %v", titleErr))
367			}
368		}()
369	}
370	session, err := a.sessions.Get(ctx, sessionID)
371	if err != nil {
372		return a.err(fmt.Errorf("failed to get session: %w", err))
373	}
374	if session.SummaryMessageID != "" {
375		summaryMsgInex := -1
376		for i, msg := range msgs {
377			if msg.ID == session.SummaryMessageID {
378				summaryMsgInex = i
379				break
380			}
381		}
382		if summaryMsgInex != -1 {
383			msgs = msgs[summaryMsgInex:]
384			msgs[0].Role = message.User
385		}
386	}
387
388	userMsg, err := a.createUserMessage(ctx, sessionID, content, attachmentParts)
389	if err != nil {
390		return a.err(fmt.Errorf("failed to create user message: %w", err))
391	}
392	// Append the new user message to the conversation history.
393	msgHistory := append(msgs, userMsg)
394
395	for {
396		// Check for cancellation before each iteration
397		select {
398		case <-ctx.Done():
399			return a.err(ctx.Err())
400		default:
401			// Continue processing
402		}
403		agentMessage, toolResults, err := a.streamAndHandleEvents(ctx, sessionID, msgHistory)
404		if err != nil {
405			if errors.Is(err, context.Canceled) {
406				agentMessage.AddFinish(message.FinishReasonCanceled)
407				a.messages.Update(context.Background(), agentMessage)
408				return a.err(ErrRequestCancelled)
409			}
410			return a.err(fmt.Errorf("failed to process events: %w", err))
411		}
412		if cfg.Options.Debug {
413			slog.Info("Result", "message", agentMessage.FinishReason(), "toolResults", toolResults)
414		}
415		if (agentMessage.FinishReason() == message.FinishReasonToolUse) && toolResults != nil {
416			// We are not done, we need to respond with the tool response
417			msgHistory = append(msgHistory, agentMessage, *toolResults)
418			continue
419		}
420		return AgentEvent{
421			Type:    AgentEventTypeResponse,
422			Message: agentMessage,
423			Done:    true,
424		}
425	}
426}
427
428func (a *agent) createUserMessage(ctx context.Context, sessionID, content string, attachmentParts []message.ContentPart) (message.Message, error) {
429	parts := []message.ContentPart{message.TextContent{Text: content}}
430	parts = append(parts, attachmentParts...)
431	return a.messages.Create(ctx, sessionID, message.CreateMessageParams{
432		Role:  message.User,
433		Parts: parts,
434	})
435}
436
437func (a *agent) streamAndHandleEvents(ctx context.Context, sessionID string, msgHistory []message.Message) (message.Message, *message.Message, error) {
438	ctx = context.WithValue(ctx, tools.SessionIDContextKey, sessionID)
439	eventChan := a.provider.StreamResponse(ctx, msgHistory, a.tools)
440
441	assistantMsg, err := a.messages.Create(ctx, sessionID, message.CreateMessageParams{
442		Role:     message.Assistant,
443		Parts:    []message.ContentPart{},
444		Model:    a.Model().ID,
445		Provider: a.providerID,
446	})
447	if err != nil {
448		return assistantMsg, nil, fmt.Errorf("failed to create assistant message: %w", err)
449	}
450
451	// Add the session and message ID into the context if needed by tools.
452	ctx = context.WithValue(ctx, tools.MessageIDContextKey, assistantMsg.ID)
453
454	// Process each event in the stream.
455	for event := range eventChan {
456		if processErr := a.processEvent(ctx, sessionID, &assistantMsg, event); processErr != nil {
457			a.finishMessage(ctx, &assistantMsg, message.FinishReasonCanceled)
458			return assistantMsg, nil, processErr
459		}
460		if ctx.Err() != nil {
461			a.finishMessage(context.Background(), &assistantMsg, message.FinishReasonCanceled)
462			return assistantMsg, nil, ctx.Err()
463		}
464	}
465
466	toolResults := make([]message.ToolResult, len(assistantMsg.ToolCalls()))
467	toolCalls := assistantMsg.ToolCalls()
468	for i, toolCall := range toolCalls {
469		select {
470		case <-ctx.Done():
471			a.finishMessage(context.Background(), &assistantMsg, message.FinishReasonCanceled)
472			// Make all future tool calls cancelled
473			for j := i; j < len(toolCalls); j++ {
474				toolResults[j] = message.ToolResult{
475					ToolCallID: toolCalls[j].ID,
476					Content:    "Tool execution canceled by user",
477					IsError:    true,
478				}
479			}
480			goto out
481		default:
482			// Continue processing
483			var tool tools.BaseTool
484			for _, availableTool := range a.tools {
485				if availableTool.Info().Name == toolCall.Name {
486					tool = availableTool
487					break
488				}
489			}
490
491			// Tool not found
492			if tool == nil {
493				toolResults[i] = message.ToolResult{
494					ToolCallID: toolCall.ID,
495					Content:    fmt.Sprintf("Tool not found: %s", toolCall.Name),
496					IsError:    true,
497				}
498				continue
499			}
500			toolResult, toolErr := tool.Run(ctx, tools.ToolCall{
501				ID:    toolCall.ID,
502				Name:  toolCall.Name,
503				Input: toolCall.Input,
504			})
505			if toolErr != nil {
506				if errors.Is(toolErr, permission.ErrorPermissionDenied) {
507					toolResults[i] = message.ToolResult{
508						ToolCallID: toolCall.ID,
509						Content:    "Permission denied",
510						IsError:    true,
511					}
512					for j := i + 1; j < len(toolCalls); j++ {
513						toolResults[j] = message.ToolResult{
514							ToolCallID: toolCalls[j].ID,
515							Content:    "Tool execution canceled by user",
516							IsError:    true,
517						}
518					}
519					a.finishMessage(ctx, &assistantMsg, message.FinishReasonPermissionDenied)
520					break
521				}
522			}
523			toolResults[i] = message.ToolResult{
524				ToolCallID: toolCall.ID,
525				Content:    toolResult.Content,
526				Metadata:   toolResult.Metadata,
527				IsError:    toolResult.IsError,
528			}
529		}
530	}
531out:
532	if len(toolResults) == 0 {
533		return assistantMsg, nil, nil
534	}
535	parts := make([]message.ContentPart, 0)
536	for _, tr := range toolResults {
537		parts = append(parts, tr)
538	}
539	msg, err := a.messages.Create(context.Background(), assistantMsg.SessionID, message.CreateMessageParams{
540		Role:     message.Tool,
541		Parts:    parts,
542		Provider: a.providerID,
543	})
544	if err != nil {
545		return assistantMsg, nil, fmt.Errorf("failed to create cancelled tool message: %w", err)
546	}
547
548	return assistantMsg, &msg, err
549}
550
551func (a *agent) finishMessage(ctx context.Context, msg *message.Message, finishReson message.FinishReason) {
552	msg.AddFinish(finishReson)
553	_ = a.messages.Update(ctx, *msg)
554}
555
556func (a *agent) processEvent(ctx context.Context, sessionID string, assistantMsg *message.Message, event provider.ProviderEvent) error {
557	select {
558	case <-ctx.Done():
559		return ctx.Err()
560	default:
561		// Continue processing.
562	}
563
564	switch event.Type {
565	case provider.EventThinkingDelta:
566		assistantMsg.AppendReasoningContent(event.Content)
567		return a.messages.Update(ctx, *assistantMsg)
568	case provider.EventContentDelta:
569		assistantMsg.AppendContent(event.Content)
570		return a.messages.Update(ctx, *assistantMsg)
571	case provider.EventToolUseStart:
572		slog.Info("Tool call started", "toolCall", event.ToolCall)
573		assistantMsg.AddToolCall(*event.ToolCall)
574		return a.messages.Update(ctx, *assistantMsg)
575	case provider.EventToolUseDelta:
576		assistantMsg.AppendToolCallInput(event.ToolCall.ID, event.ToolCall.Input)
577		return a.messages.Update(ctx, *assistantMsg)
578	case provider.EventToolUseStop:
579		slog.Info("Finished tool call", "toolCall", event.ToolCall)
580		assistantMsg.FinishToolCall(event.ToolCall.ID)
581		return a.messages.Update(ctx, *assistantMsg)
582	case provider.EventError:
583		if errors.Is(event.Error, context.Canceled) {
584			slog.Info(fmt.Sprintf("Event processing canceled for session: %s", sessionID))
585			return context.Canceled
586		}
587		slog.Error(event.Error.Error())
588		return event.Error
589	case provider.EventComplete:
590		assistantMsg.SetToolCalls(event.Response.ToolCalls)
591		assistantMsg.AddFinish(event.Response.FinishReason)
592		if err := a.messages.Update(ctx, *assistantMsg); err != nil {
593			return fmt.Errorf("failed to update message: %w", err)
594		}
595		return a.TrackUsage(ctx, sessionID, a.Model(), event.Response.Usage)
596	}
597
598	return nil
599}
600
601func (a *agent) TrackUsage(ctx context.Context, sessionID string, model fur.Model, usage provider.TokenUsage) error {
602	sess, err := a.sessions.Get(ctx, sessionID)
603	if err != nil {
604		return fmt.Errorf("failed to get session: %w", err)
605	}
606
607	cost := model.CostPer1MInCached/1e6*float64(usage.CacheCreationTokens) +
608		model.CostPer1MOutCached/1e6*float64(usage.CacheReadTokens) +
609		model.CostPer1MIn/1e6*float64(usage.InputTokens) +
610		model.CostPer1MOut/1e6*float64(usage.OutputTokens)
611
612	sess.Cost += cost
613	sess.CompletionTokens = usage.OutputTokens + usage.CacheReadTokens
614	sess.PromptTokens = usage.InputTokens + usage.CacheCreationTokens
615
616	_, err = a.sessions.Save(ctx, sess)
617	if err != nil {
618		return fmt.Errorf("failed to save session: %w", err)
619	}
620	return nil
621}
622
623func (a *agent) Summarize(ctx context.Context, sessionID string) error {
624	if a.summarizeProvider == nil {
625		return fmt.Errorf("summarize provider not available")
626	}
627
628	// Check if session is busy
629	if a.IsSessionBusy(sessionID) {
630		return ErrSessionBusy
631	}
632
633	// Create a new context with cancellation
634	summarizeCtx, cancel := context.WithCancel(ctx)
635
636	// Store the cancel function in activeRequests to allow cancellation
637	a.activeRequests.Store(sessionID+"-summarize", cancel)
638
639	go func() {
640		defer a.activeRequests.Delete(sessionID + "-summarize")
641		defer cancel()
642		event := AgentEvent{
643			Type:     AgentEventTypeSummarize,
644			Progress: "Starting summarization...",
645		}
646
647		a.Publish(pubsub.CreatedEvent, event)
648		// Get all messages from the session
649		msgs, err := a.messages.List(summarizeCtx, sessionID)
650		if err != nil {
651			event = AgentEvent{
652				Type:  AgentEventTypeError,
653				Error: fmt.Errorf("failed to list messages: %w", err),
654				Done:  true,
655			}
656			a.Publish(pubsub.CreatedEvent, event)
657			return
658		}
659		summarizeCtx = context.WithValue(summarizeCtx, tools.SessionIDContextKey, sessionID)
660
661		if len(msgs) == 0 {
662			event = AgentEvent{
663				Type:  AgentEventTypeError,
664				Error: fmt.Errorf("no messages to summarize"),
665				Done:  true,
666			}
667			a.Publish(pubsub.CreatedEvent, event)
668			return
669		}
670
671		event = AgentEvent{
672			Type:     AgentEventTypeSummarize,
673			Progress: "Analyzing conversation...",
674		}
675		a.Publish(pubsub.CreatedEvent, event)
676
677		// Add a system message to guide the summarization
678		summarizePrompt := "Provide a detailed but concise summary of our conversation above. Focus on information that would be helpful for continuing the conversation, including what we did, what we're doing, which files we're working on, and what we're going to do next."
679
680		// Create a new message with the summarize prompt
681		promptMsg := message.Message{
682			Role:  message.User,
683			Parts: []message.ContentPart{message.TextContent{Text: summarizePrompt}},
684		}
685
686		// Append the prompt to the messages
687		msgsWithPrompt := append(msgs, promptMsg)
688
689		event = AgentEvent{
690			Type:     AgentEventTypeSummarize,
691			Progress: "Generating summary...",
692		}
693
694		a.Publish(pubsub.CreatedEvent, event)
695
696		// Send the messages to the summarize provider
697		response := a.summarizeProvider.StreamResponse(
698			summarizeCtx,
699			msgsWithPrompt,
700			make([]tools.BaseTool, 0),
701		)
702		var finalResponse *provider.ProviderResponse
703		for r := range response {
704			if r.Error != nil {
705				event = AgentEvent{
706					Type:  AgentEventTypeError,
707					Error: fmt.Errorf("failed to summarize: %w", err),
708					Done:  true,
709				}
710				a.Publish(pubsub.CreatedEvent, event)
711				return
712			}
713			finalResponse = r.Response
714		}
715
716		summary := strings.TrimSpace(finalResponse.Content)
717		if summary == "" {
718			event = AgentEvent{
719				Type:  AgentEventTypeError,
720				Error: fmt.Errorf("empty summary returned"),
721				Done:  true,
722			}
723			a.Publish(pubsub.CreatedEvent, event)
724			return
725		}
726		event = AgentEvent{
727			Type:     AgentEventTypeSummarize,
728			Progress: "Creating new session...",
729		}
730
731		a.Publish(pubsub.CreatedEvent, event)
732		oldSession, err := a.sessions.Get(summarizeCtx, sessionID)
733		if err != nil {
734			event = AgentEvent{
735				Type:  AgentEventTypeError,
736				Error: fmt.Errorf("failed to get session: %w", err),
737				Done:  true,
738			}
739
740			a.Publish(pubsub.CreatedEvent, event)
741			return
742		}
743		// Create a message in the new session with the summary
744		msg, err := a.messages.Create(summarizeCtx, oldSession.ID, message.CreateMessageParams{
745			Role: message.Assistant,
746			Parts: []message.ContentPart{
747				message.TextContent{Text: summary},
748				message.Finish{
749					Reason: message.FinishReasonEndTurn,
750					Time:   time.Now().Unix(),
751				},
752			},
753			Model:    a.summarizeProvider.Model().ID,
754			Provider: a.summarizeProviderID,
755		})
756		if err != nil {
757			event = AgentEvent{
758				Type:  AgentEventTypeError,
759				Error: fmt.Errorf("failed to create summary message: %w", err),
760				Done:  true,
761			}
762
763			a.Publish(pubsub.CreatedEvent, event)
764			return
765		}
766		oldSession.SummaryMessageID = msg.ID
767		oldSession.CompletionTokens = finalResponse.Usage.OutputTokens
768		oldSession.PromptTokens = 0
769		model := a.summarizeProvider.Model()
770		usage := finalResponse.Usage
771		cost := model.CostPer1MInCached/1e6*float64(usage.CacheCreationTokens) +
772			model.CostPer1MOutCached/1e6*float64(usage.CacheReadTokens) +
773			model.CostPer1MIn/1e6*float64(usage.InputTokens) +
774			model.CostPer1MOut/1e6*float64(usage.OutputTokens)
775		oldSession.Cost += cost
776		_, err = a.sessions.Save(summarizeCtx, oldSession)
777		if err != nil {
778			event = AgentEvent{
779				Type:  AgentEventTypeError,
780				Error: fmt.Errorf("failed to save session: %w", err),
781				Done:  true,
782			}
783			a.Publish(pubsub.CreatedEvent, event)
784		}
785
786		event = AgentEvent{
787			Type:      AgentEventTypeSummarize,
788			SessionID: oldSession.ID,
789			Progress:  "Summary complete",
790			Done:      true,
791		}
792		a.Publish(pubsub.CreatedEvent, event)
793		// Send final success event with the new session ID
794	}()
795
796	return nil
797}
798
799func (a *agent) CancelAll() {
800	a.activeRequests.Range(func(key, value any) bool {
801		a.Cancel(key.(string)) // key is sessionID
802		return true
803	})
804}
805
806func (a *agent) UpdateModel() error {
807	cfg := config.Get()
808
809	// Get current provider configuration
810	currentProviderCfg := cfg.GetProviderForModel(a.agentCfg.Model)
811	if currentProviderCfg.ID == "" {
812		return fmt.Errorf("provider for agent %s not found in config", a.agentCfg.Name)
813	}
814
815	// Check if provider has changed
816	if string(currentProviderCfg.ID) != a.providerID {
817		// Provider changed, need to recreate the main provider
818		model := cfg.GetModelByType(a.agentCfg.Model)
819		if model.ID == "" {
820			return fmt.Errorf("model not found for agent %s", a.agentCfg.Name)
821		}
822
823		promptID := agentPromptMap[a.agentCfg.ID]
824		if promptID == "" {
825			promptID = prompt.PromptDefault
826		}
827
828		opts := []provider.ProviderClientOption{
829			provider.WithModel(a.agentCfg.Model),
830			provider.WithSystemMessage(prompt.GetPrompt(promptID, currentProviderCfg.ID)),
831		}
832
833		newProvider, err := provider.NewProvider(*currentProviderCfg, opts...)
834		if err != nil {
835			return fmt.Errorf("failed to create new provider: %w", err)
836		}
837
838		// Update the provider and provider ID
839		a.provider = newProvider
840		a.providerID = string(currentProviderCfg.ID)
841	}
842
843	// Check if small model provider has changed (affects title and summarize providers)
844	smallModelCfg := cfg.Models[config.SelectedModelTypeSmall]
845	var smallModelProviderCfg config.ProviderConfig
846
847	for _, p := range cfg.Providers {
848		if p.ID == smallModelCfg.Provider {
849			smallModelProviderCfg = p
850			break
851		}
852	}
853
854	if smallModelProviderCfg.ID == "" {
855		return fmt.Errorf("provider %s not found in config", smallModelCfg.Provider)
856	}
857
858	// Check if summarize provider has changed
859	if string(smallModelProviderCfg.ID) != a.summarizeProviderID {
860		smallModel := cfg.GetModelByType(config.SelectedModelTypeSmall)
861		if smallModel == nil {
862			return fmt.Errorf("model %s not found in provider %s", smallModelCfg.Model, smallModelProviderCfg.ID)
863		}
864
865		// Recreate title provider
866		titleOpts := []provider.ProviderClientOption{
867			provider.WithModel(config.SelectedModelTypeSmall),
868			provider.WithSystemMessage(prompt.GetPrompt(prompt.PromptTitle, smallModelProviderCfg.ID)),
869			// We want the title to be short, so we limit the max tokens
870			provider.WithMaxTokens(40),
871		}
872		newTitleProvider, err := provider.NewProvider(smallModelProviderCfg, titleOpts...)
873		if err != nil {
874			return fmt.Errorf("failed to create new title provider: %w", err)
875		}
876
877		// Recreate summarize provider
878		summarizeOpts := []provider.ProviderClientOption{
879			provider.WithModel(config.SelectedModelTypeSmall),
880			provider.WithSystemMessage(prompt.GetPrompt(prompt.PromptSummarizer, smallModelProviderCfg.ID)),
881		}
882		newSummarizeProvider, err := provider.NewProvider(smallModelProviderCfg, summarizeOpts...)
883		if err != nil {
884			return fmt.Errorf("failed to create new summarize provider: %w", err)
885		}
886
887		// Update the providers and provider ID
888		a.titleProvider = newTitleProvider
889		a.summarizeProvider = newSummarizeProvider
890		a.summarizeProviderID = string(smallModelProviderCfg.ID)
891	}
892
893	return nil
894}