agent.go

  1package agent
  2
  3import (
  4	"context"
  5	"errors"
  6	"fmt"
  7	"slices"
  8	"strings"
  9	"sync"
 10	"time"
 11
 12	"github.com/charmbracelet/crush/internal/config"
 13	"github.com/charmbracelet/crush/internal/history"
 14	"github.com/charmbracelet/crush/internal/llm/prompt"
 15	"github.com/charmbracelet/crush/internal/llm/provider"
 16	"github.com/charmbracelet/crush/internal/llm/tools"
 17	"github.com/charmbracelet/crush/internal/logging"
 18	"github.com/charmbracelet/crush/internal/lsp"
 19	"github.com/charmbracelet/crush/internal/message"
 20	"github.com/charmbracelet/crush/internal/permission"
 21	"github.com/charmbracelet/crush/internal/pubsub"
 22	"github.com/charmbracelet/crush/internal/session"
 23)
 24
 25// Common errors
 26var (
 27	ErrRequestCancelled = errors.New("request cancelled by user")
 28	ErrSessionBusy      = errors.New("session is currently processing another request")
 29)
 30
 31type AgentEventType string
 32
 33const (
 34	AgentEventTypeError     AgentEventType = "error"
 35	AgentEventTypeResponse  AgentEventType = "response"
 36	AgentEventTypeSummarize AgentEventType = "summarize"
 37)
 38
// AgentEvent is emitted by the agent: it is published on the agent's broker
// and, for Run, also delivered on the channel Run returns.
type AgentEvent struct {
	Type    AgentEventType
	Message message.Message // the assistant message, for AgentEventTypeResponse
	Error   error           // the failure cause, for AgentEventTypeError

	// When summarizing: SessionID is the summarized session (set on the final
	// event) and Progress is a human-readable status line. Done marks the
	// last event of an operation; Run's response event and Summarize's final
	// and error events set it.
	SessionID string
	Progress  string
	Done      bool
}
 49
 50type Service interface {
 51	pubsub.Suscriber[AgentEvent]
 52	Model() config.Model
 53	Run(ctx context.Context, sessionID string, content string, attachments ...message.Attachment) (<-chan AgentEvent, error)
 54	Cancel(sessionID string)
 55	CancelAll()
 56	IsSessionBusy(sessionID string) bool
 57	IsBusy() bool
 58	Summarize(ctx context.Context, sessionID string) error
 59	UpdateModel() error
 60}
 61
// agent is the Service implementation. It owns the main provider used for
// generation plus two providers built from the small model configuration:
// one for session titles and one for summaries.
//
// NOTE(review): UpdateModel reassigns the provider fields without any
// synchronization; confirm it is never called concurrently with Run or
// Summarize.
type agent struct {
	*pubsub.Broker[AgentEvent]
	agentCfg config.Agent
	sessions session.Service
	messages message.Service

	// tools are offered to the main provider; provider/providerID identify
	// the provider used for generation.
	tools      []tools.BaseTool
	provider   provider.Provider
	providerID string

	// Providers built from the small model configuration.
	titleProvider       provider.Provider
	summarizeProvider   provider.Provider
	summarizeProviderID string

	// activeRequests maps a session ID (for Run) or "<sessionID>-summarize"
	// (for Summarize) to the context.CancelFunc of the in-flight work.
	activeRequests sync.Map
}
 78
 79var agentPromptMap = map[config.AgentID]prompt.PromptID{
 80	config.AgentCoder: prompt.PromptCoder,
 81	config.AgentTask:  prompt.PromptTask,
 82}
 83
 84func NewAgent(
 85	agentCfg config.Agent,
 86	// These services are needed in the tools
 87	permissions permission.Service,
 88	sessions session.Service,
 89	messages message.Service,
 90	history history.Service,
 91	lspClients map[string]*lsp.Client,
 92) (Service, error) {
 93	ctx := context.Background()
 94	cfg := config.Get()
 95	otherTools := GetMcpTools(ctx, permissions)
 96	if len(lspClients) > 0 {
 97		otherTools = append(otherTools, tools.NewDiagnosticsTool(lspClients))
 98	}
 99
100	allTools := []tools.BaseTool{
101		tools.NewBashTool(permissions),
102		tools.NewEditTool(lspClients, permissions, history),
103		tools.NewFetchTool(permissions),
104		tools.NewGlobTool(),
105		tools.NewGrepTool(),
106		tools.NewLsTool(),
107		tools.NewSourcegraphTool(),
108		tools.NewViewTool(lspClients),
109		tools.NewWriteTool(lspClients, permissions, history),
110	}
111
112	if agentCfg.ID == config.AgentCoder {
113		taskAgentCfg := config.Get().Agents[config.AgentTask]
114		if taskAgentCfg.ID == "" {
115			return nil, fmt.Errorf("task agent not found in config")
116		}
117		taskAgent, err := NewAgent(taskAgentCfg, permissions, sessions, messages, history, lspClients)
118		if err != nil {
119			return nil, fmt.Errorf("failed to create task agent: %w", err)
120		}
121
122		allTools = append(
123			allTools,
124			NewAgentTool(
125				taskAgent,
126				sessions,
127				messages,
128			),
129		)
130	}
131
132	allTools = append(allTools, otherTools...)
133	providerCfg := config.GetAgentProvider(agentCfg.ID)
134	if providerCfg.ID == "" {
135		return nil, fmt.Errorf("provider for agent %s not found in config", agentCfg.Name)
136	}
137	model := config.GetAgentModel(agentCfg.ID)
138
139	if model.ID == "" {
140		return nil, fmt.Errorf("model not found for agent %s", agentCfg.Name)
141	}
142
143	promptID := agentPromptMap[agentCfg.ID]
144	if promptID == "" {
145		promptID = prompt.PromptDefault
146	}
147	opts := []provider.ProviderClientOption{
148		provider.WithModel(agentCfg.Model),
149		provider.WithSystemMessage(prompt.GetPrompt(promptID, providerCfg.ID)),
150	}
151	agentProvider, err := provider.NewProvider(providerCfg, opts...)
152	if err != nil {
153		return nil, err
154	}
155
156	smallModelCfg := cfg.Models.Small
157	var smallModel config.Model
158
159	var smallModelProviderCfg config.ProviderConfig
160	if smallModelCfg.Provider == providerCfg.ID {
161		smallModelProviderCfg = providerCfg
162	} else {
163		for _, p := range cfg.Providers {
164			if p.ID == smallModelCfg.Provider {
165				smallModelProviderCfg = p
166				break
167			}
168		}
169		if smallModelProviderCfg.ID == "" {
170			return nil, fmt.Errorf("provider %s not found in config", smallModelCfg.Provider)
171		}
172	}
173	for _, m := range smallModelProviderCfg.Models {
174		if m.ID == smallModelCfg.ModelID {
175			smallModel = m
176			break
177		}
178	}
179	if smallModel.ID == "" {
180		return nil, fmt.Errorf("model %s not found in provider %s", smallModelCfg.ModelID, smallModelProviderCfg.ID)
181	}
182
183	titleOpts := []provider.ProviderClientOption{
184		provider.WithModel(config.SmallModel),
185		provider.WithSystemMessage(prompt.GetPrompt(prompt.PromptTitle, smallModelProviderCfg.ID)),
186	}
187	titleProvider, err := provider.NewProvider(smallModelProviderCfg, titleOpts...)
188	if err != nil {
189		return nil, err
190	}
191	summarizeOpts := []provider.ProviderClientOption{
192		provider.WithModel(config.SmallModel),
193		provider.WithSystemMessage(prompt.GetPrompt(prompt.PromptSummarizer, smallModelProviderCfg.ID)),
194	}
195	summarizeProvider, err := provider.NewProvider(smallModelProviderCfg, summarizeOpts...)
196	if err != nil {
197		return nil, err
198	}
199
200	agentTools := []tools.BaseTool{}
201	if agentCfg.AllowedTools == nil {
202		agentTools = allTools
203	} else {
204		for _, tool := range allTools {
205			if slices.Contains(agentCfg.AllowedTools, tool.Name()) {
206				agentTools = append(agentTools, tool)
207			}
208		}
209	}
210
211	agent := &agent{
212		Broker:              pubsub.NewBroker[AgentEvent](),
213		agentCfg:            agentCfg,
214		provider:            agentProvider,
215		providerID:          string(providerCfg.ID),
216		messages:            messages,
217		sessions:            sessions,
218		tools:               agentTools,
219		titleProvider:       titleProvider,
220		summarizeProvider:   summarizeProvider,
221		summarizeProviderID: string(smallModelProviderCfg.ID),
222		activeRequests:      sync.Map{},
223	}
224
225	return agent, nil
226}
227
228func (a *agent) Model() config.Model {
229	return config.GetAgentModel(a.agentCfg.ID)
230}
231
232func (a *agent) Cancel(sessionID string) {
233	// Cancel regular requests
234	if cancelFunc, exists := a.activeRequests.LoadAndDelete(sessionID); exists {
235		if cancel, ok := cancelFunc.(context.CancelFunc); ok {
236			logging.InfoPersist(fmt.Sprintf("Request cancellation initiated for session: %s", sessionID))
237			cancel()
238		}
239	}
240
241	// Also check for summarize requests
242	if cancelFunc, exists := a.activeRequests.LoadAndDelete(sessionID + "-summarize"); exists {
243		if cancel, ok := cancelFunc.(context.CancelFunc); ok {
244			logging.InfoPersist(fmt.Sprintf("Summarize cancellation initiated for session: %s", sessionID))
245			cancel()
246		}
247	}
248}
249
250func (a *agent) IsBusy() bool {
251	busy := false
252	a.activeRequests.Range(func(key, value any) bool {
253		if cancelFunc, ok := value.(context.CancelFunc); ok {
254			if cancelFunc != nil {
255				busy = true
256				return false
257			}
258		}
259		return true
260	})
261	return busy
262}
263
264func (a *agent) IsSessionBusy(sessionID string) bool {
265	_, busy := a.activeRequests.Load(sessionID)
266	return busy
267}
268
269func (a *agent) generateTitle(ctx context.Context, sessionID string, content string) error {
270	if content == "" {
271		return nil
272	}
273	if a.titleProvider == nil {
274		return nil
275	}
276	session, err := a.sessions.Get(ctx, sessionID)
277	if err != nil {
278		return err
279	}
280	parts := []message.ContentPart{message.TextContent{
281		Text: fmt.Sprintf("Generate a concise title for the following content:\n\n%s", content),
282	}}
283
284	// Use streaming approach like summarization
285	response := a.titleProvider.StreamResponse(
286		ctx,
287		[]message.Message{
288			{
289				Role:  message.User,
290				Parts: parts,
291			},
292		},
293		make([]tools.BaseTool, 0),
294	)
295
296	var finalResponse *provider.ProviderResponse
297	for r := range response {
298		if r.Error != nil {
299			return r.Error
300		}
301		finalResponse = r.Response
302	}
303
304	if finalResponse == nil {
305		return fmt.Errorf("no response received from title provider")
306	}
307
308	title := strings.TrimSpace(strings.ReplaceAll(finalResponse.Content, "\n", " "))
309	if title == "" {
310		return nil
311	}
312
313	session.Title = title
314	_, err = a.sessions.Save(ctx, session)
315	return err
316}
317
318func (a *agent) err(err error) AgentEvent {
319	return AgentEvent{
320		Type:  AgentEventTypeError,
321		Error: err,
322	}
323}
324
325func (a *agent) Run(ctx context.Context, sessionID string, content string, attachments ...message.Attachment) (<-chan AgentEvent, error) {
326	if !a.Model().SupportsImages && attachments != nil {
327		attachments = nil
328	}
329	events := make(chan AgentEvent)
330	if a.IsSessionBusy(sessionID) {
331		return nil, ErrSessionBusy
332	}
333
334	genCtx, cancel := context.WithCancel(ctx)
335
336	a.activeRequests.Store(sessionID, cancel)
337	go func() {
338		logging.Debug("Request started", "sessionID", sessionID)
339		defer logging.RecoverPanic("agent.Run", func() {
340			events <- a.err(fmt.Errorf("panic while running the agent"))
341		})
342		var attachmentParts []message.ContentPart
343		for _, attachment := range attachments {
344			attachmentParts = append(attachmentParts, message.BinaryContent{Path: attachment.FilePath, MIMEType: attachment.MimeType, Data: attachment.Content})
345		}
346		result := a.processGeneration(genCtx, sessionID, content, attachmentParts)
347		if result.Error != nil && !errors.Is(result.Error, ErrRequestCancelled) && !errors.Is(result.Error, context.Canceled) {
348			logging.ErrorPersist(result.Error.Error())
349		}
350		logging.Debug("Request completed", "sessionID", sessionID)
351		a.activeRequests.Delete(sessionID)
352		cancel()
353		a.Publish(pubsub.CreatedEvent, result)
354		events <- result
355		close(events)
356	}()
357	return events, nil
358}
359
360func (a *agent) processGeneration(ctx context.Context, sessionID, content string, attachmentParts []message.ContentPart) AgentEvent {
361	cfg := config.Get()
362	// List existing messages; if none, start title generation asynchronously.
363	msgs, err := a.messages.List(ctx, sessionID)
364	if err != nil {
365		return a.err(fmt.Errorf("failed to list messages: %w", err))
366	}
367	if len(msgs) == 0 {
368		go func() {
369			defer logging.RecoverPanic("agent.Run", func() {
370				logging.ErrorPersist("panic while generating title")
371			})
372			titleErr := a.generateTitle(context.Background(), sessionID, content)
373			if titleErr != nil && !errors.Is(titleErr, context.Canceled) && !errors.Is(titleErr, context.DeadlineExceeded) {
374				logging.ErrorPersist(fmt.Sprintf("failed to generate title: %v", titleErr))
375			}
376		}()
377	}
378	session, err := a.sessions.Get(ctx, sessionID)
379	if err != nil {
380		return a.err(fmt.Errorf("failed to get session: %w", err))
381	}
382	if session.SummaryMessageID != "" {
383		summaryMsgInex := -1
384		for i, msg := range msgs {
385			if msg.ID == session.SummaryMessageID {
386				summaryMsgInex = i
387				break
388			}
389		}
390		if summaryMsgInex != -1 {
391			msgs = msgs[summaryMsgInex:]
392			msgs[0].Role = message.User
393		}
394	}
395
396	userMsg, err := a.createUserMessage(ctx, sessionID, content, attachmentParts)
397	if err != nil {
398		return a.err(fmt.Errorf("failed to create user message: %w", err))
399	}
400	// Append the new user message to the conversation history.
401	msgHistory := append(msgs, userMsg)
402
403	for {
404		// Check for cancellation before each iteration
405		select {
406		case <-ctx.Done():
407			return a.err(ctx.Err())
408		default:
409			// Continue processing
410		}
411		agentMessage, toolResults, err := a.streamAndHandleEvents(ctx, sessionID, msgHistory)
412		if err != nil {
413			if errors.Is(err, context.Canceled) {
414				agentMessage.AddFinish(message.FinishReasonCanceled)
415				a.messages.Update(context.Background(), agentMessage)
416				return a.err(ErrRequestCancelled)
417			}
418			return a.err(fmt.Errorf("failed to process events: %w", err))
419		}
420		if cfg.Options.Debug {
421			seqId := (len(msgHistory) + 1) / 2
422			toolResultFilepath := logging.WriteToolResultsJson(sessionID, seqId, toolResults)
423			logging.Info("Result", "message", agentMessage.FinishReason(), "toolResults", "{}", "filepath", toolResultFilepath)
424		} else {
425			logging.Info("Result", "message", agentMessage.FinishReason(), "toolResults", toolResults)
426		}
427		if (agentMessage.FinishReason() == message.FinishReasonToolUse) && toolResults != nil {
428			// We are not done, we need to respond with the tool response
429			msgHistory = append(msgHistory, agentMessage, *toolResults)
430			continue
431		}
432		return AgentEvent{
433			Type:    AgentEventTypeResponse,
434			Message: agentMessage,
435			Done:    true,
436		}
437	}
438}
439
440func (a *agent) createUserMessage(ctx context.Context, sessionID, content string, attachmentParts []message.ContentPart) (message.Message, error) {
441	parts := []message.ContentPart{message.TextContent{Text: content}}
442	parts = append(parts, attachmentParts...)
443	return a.messages.Create(ctx, sessionID, message.CreateMessageParams{
444		Role:  message.User,
445		Parts: parts,
446	})
447}
448
449func (a *agent) streamAndHandleEvents(ctx context.Context, sessionID string, msgHistory []message.Message) (message.Message, *message.Message, error) {
450	ctx = context.WithValue(ctx, tools.SessionIDContextKey, sessionID)
451	eventChan := a.provider.StreamResponse(ctx, msgHistory, a.tools)
452
453	assistantMsg, err := a.messages.Create(ctx, sessionID, message.CreateMessageParams{
454		Role:     message.Assistant,
455		Parts:    []message.ContentPart{},
456		Model:    a.Model().ID,
457		Provider: a.providerID,
458	})
459	if err != nil {
460		return assistantMsg, nil, fmt.Errorf("failed to create assistant message: %w", err)
461	}
462
463	// Add the session and message ID into the context if needed by tools.
464	ctx = context.WithValue(ctx, tools.MessageIDContextKey, assistantMsg.ID)
465
466	// Process each event in the stream.
467	for event := range eventChan {
468		if processErr := a.processEvent(ctx, sessionID, &assistantMsg, event); processErr != nil {
469			a.finishMessage(ctx, &assistantMsg, message.FinishReasonCanceled)
470			return assistantMsg, nil, processErr
471		}
472		if ctx.Err() != nil {
473			a.finishMessage(context.Background(), &assistantMsg, message.FinishReasonCanceled)
474			return assistantMsg, nil, ctx.Err()
475		}
476	}
477
478	toolResults := make([]message.ToolResult, len(assistantMsg.ToolCalls()))
479	toolCalls := assistantMsg.ToolCalls()
480	for i, toolCall := range toolCalls {
481		select {
482		case <-ctx.Done():
483			a.finishMessage(context.Background(), &assistantMsg, message.FinishReasonCanceled)
484			// Make all future tool calls cancelled
485			for j := i; j < len(toolCalls); j++ {
486				toolResults[j] = message.ToolResult{
487					ToolCallID: toolCalls[j].ID,
488					Content:    "Tool execution canceled by user",
489					IsError:    true,
490				}
491			}
492			goto out
493		default:
494			// Continue processing
495			var tool tools.BaseTool
496			for _, availableTool := range a.tools {
497				if availableTool.Info().Name == toolCall.Name {
498					tool = availableTool
499					break
500				}
501			}
502
503			// Tool not found
504			if tool == nil {
505				toolResults[i] = message.ToolResult{
506					ToolCallID: toolCall.ID,
507					Content:    fmt.Sprintf("Tool not found: %s", toolCall.Name),
508					IsError:    true,
509				}
510				continue
511			}
512			toolResult, toolErr := tool.Run(ctx, tools.ToolCall{
513				ID:    toolCall.ID,
514				Name:  toolCall.Name,
515				Input: toolCall.Input,
516			})
517			if toolErr != nil {
518				if errors.Is(toolErr, permission.ErrorPermissionDenied) {
519					toolResults[i] = message.ToolResult{
520						ToolCallID: toolCall.ID,
521						Content:    "Permission denied",
522						IsError:    true,
523					}
524					for j := i + 1; j < len(toolCalls); j++ {
525						toolResults[j] = message.ToolResult{
526							ToolCallID: toolCalls[j].ID,
527							Content:    "Tool execution canceled by user",
528							IsError:    true,
529						}
530					}
531					a.finishMessage(ctx, &assistantMsg, message.FinishReasonPermissionDenied)
532					break
533				}
534			}
535			toolResults[i] = message.ToolResult{
536				ToolCallID: toolCall.ID,
537				Content:    toolResult.Content,
538				Metadata:   toolResult.Metadata,
539				IsError:    toolResult.IsError,
540			}
541		}
542	}
543out:
544	if len(toolResults) == 0 {
545		return assistantMsg, nil, nil
546	}
547	parts := make([]message.ContentPart, 0)
548	for _, tr := range toolResults {
549		parts = append(parts, tr)
550	}
551	msg, err := a.messages.Create(context.Background(), assistantMsg.SessionID, message.CreateMessageParams{
552		Role:     message.Tool,
553		Parts:    parts,
554		Provider: a.providerID,
555	})
556	if err != nil {
557		return assistantMsg, nil, fmt.Errorf("failed to create cancelled tool message: %w", err)
558	}
559
560	return assistantMsg, &msg, err
561}
562
563func (a *agent) finishMessage(ctx context.Context, msg *message.Message, finishReson message.FinishReason) {
564	msg.AddFinish(finishReson)
565	_ = a.messages.Update(ctx, *msg)
566}
567
568func (a *agent) processEvent(ctx context.Context, sessionID string, assistantMsg *message.Message, event provider.ProviderEvent) error {
569	select {
570	case <-ctx.Done():
571		return ctx.Err()
572	default:
573		// Continue processing.
574	}
575
576	switch event.Type {
577	case provider.EventThinkingDelta:
578		assistantMsg.AppendReasoningContent(event.Content)
579		return a.messages.Update(ctx, *assistantMsg)
580	case provider.EventContentDelta:
581		assistantMsg.AppendContent(event.Content)
582		return a.messages.Update(ctx, *assistantMsg)
583	case provider.EventToolUseStart:
584		logging.Info("Tool call started", "toolCall", event.ToolCall)
585		assistantMsg.AddToolCall(*event.ToolCall)
586		return a.messages.Update(ctx, *assistantMsg)
587	case provider.EventToolUseDelta:
588		assistantMsg.AppendToolCallInput(event.ToolCall.ID, event.ToolCall.Input)
589		return a.messages.Update(ctx, *assistantMsg)
590	case provider.EventToolUseStop:
591		logging.Info("Finished tool call", "toolCall", event.ToolCall)
592		assistantMsg.FinishToolCall(event.ToolCall.ID)
593		return a.messages.Update(ctx, *assistantMsg)
594	case provider.EventError:
595		if errors.Is(event.Error, context.Canceled) {
596			logging.InfoPersist(fmt.Sprintf("Event processing canceled for session: %s", sessionID))
597			return context.Canceled
598		}
599		logging.ErrorPersist(event.Error.Error())
600		return event.Error
601	case provider.EventComplete:
602		assistantMsg.SetToolCalls(event.Response.ToolCalls)
603		assistantMsg.AddFinish(event.Response.FinishReason)
604		if err := a.messages.Update(ctx, *assistantMsg); err != nil {
605			return fmt.Errorf("failed to update message: %w", err)
606		}
607		return a.TrackUsage(ctx, sessionID, a.Model(), event.Response.Usage)
608	}
609
610	return nil
611}
612
613func (a *agent) TrackUsage(ctx context.Context, sessionID string, model config.Model, usage provider.TokenUsage) error {
614	sess, err := a.sessions.Get(ctx, sessionID)
615	if err != nil {
616		return fmt.Errorf("failed to get session: %w", err)
617	}
618
619	cost := model.CostPer1MInCached/1e6*float64(usage.CacheCreationTokens) +
620		model.CostPer1MOutCached/1e6*float64(usage.CacheReadTokens) +
621		model.CostPer1MIn/1e6*float64(usage.InputTokens) +
622		model.CostPer1MOut/1e6*float64(usage.OutputTokens)
623
624	sess.Cost += cost
625	sess.CompletionTokens = usage.OutputTokens + usage.CacheReadTokens
626	sess.PromptTokens = usage.InputTokens + usage.CacheCreationTokens
627
628	_, err = a.sessions.Save(ctx, sess)
629	if err != nil {
630		return fmt.Errorf("failed to save session: %w", err)
631	}
632	return nil
633}
634
635func (a *agent) Summarize(ctx context.Context, sessionID string) error {
636	if a.summarizeProvider == nil {
637		return fmt.Errorf("summarize provider not available")
638	}
639
640	// Check if session is busy
641	if a.IsSessionBusy(sessionID) {
642		return ErrSessionBusy
643	}
644
645	// Create a new context with cancellation
646	summarizeCtx, cancel := context.WithCancel(ctx)
647
648	// Store the cancel function in activeRequests to allow cancellation
649	a.activeRequests.Store(sessionID+"-summarize", cancel)
650
651	go func() {
652		defer a.activeRequests.Delete(sessionID + "-summarize")
653		defer cancel()
654		event := AgentEvent{
655			Type:     AgentEventTypeSummarize,
656			Progress: "Starting summarization...",
657		}
658
659		a.Publish(pubsub.CreatedEvent, event)
660		// Get all messages from the session
661		msgs, err := a.messages.List(summarizeCtx, sessionID)
662		if err != nil {
663			event = AgentEvent{
664				Type:  AgentEventTypeError,
665				Error: fmt.Errorf("failed to list messages: %w", err),
666				Done:  true,
667			}
668			a.Publish(pubsub.CreatedEvent, event)
669			return
670		}
671		summarizeCtx = context.WithValue(summarizeCtx, tools.SessionIDContextKey, sessionID)
672
673		if len(msgs) == 0 {
674			event = AgentEvent{
675				Type:  AgentEventTypeError,
676				Error: fmt.Errorf("no messages to summarize"),
677				Done:  true,
678			}
679			a.Publish(pubsub.CreatedEvent, event)
680			return
681		}
682
683		event = AgentEvent{
684			Type:     AgentEventTypeSummarize,
685			Progress: "Analyzing conversation...",
686		}
687		a.Publish(pubsub.CreatedEvent, event)
688
689		// Add a system message to guide the summarization
690		summarizePrompt := "Provide a detailed but concise summary of our conversation above. Focus on information that would be helpful for continuing the conversation, including what we did, what we're doing, which files we're working on, and what we're going to do next."
691
692		// Create a new message with the summarize prompt
693		promptMsg := message.Message{
694			Role:  message.User,
695			Parts: []message.ContentPart{message.TextContent{Text: summarizePrompt}},
696		}
697
698		// Append the prompt to the messages
699		msgsWithPrompt := append(msgs, promptMsg)
700
701		event = AgentEvent{
702			Type:     AgentEventTypeSummarize,
703			Progress: "Generating summary...",
704		}
705
706		a.Publish(pubsub.CreatedEvent, event)
707
708		// Send the messages to the summarize provider
709		response := a.summarizeProvider.StreamResponse(
710			summarizeCtx,
711			msgsWithPrompt,
712			make([]tools.BaseTool, 0),
713		)
714		var finalResponse *provider.ProviderResponse
715		for r := range response {
716			if r.Error != nil {
717				event = AgentEvent{
718					Type:  AgentEventTypeError,
719					Error: fmt.Errorf("failed to summarize: %w", err),
720					Done:  true,
721				}
722				a.Publish(pubsub.CreatedEvent, event)
723				return
724			}
725			finalResponse = r.Response
726		}
727
728		summary := strings.TrimSpace(finalResponse.Content)
729		if summary == "" {
730			event = AgentEvent{
731				Type:  AgentEventTypeError,
732				Error: fmt.Errorf("empty summary returned"),
733				Done:  true,
734			}
735			a.Publish(pubsub.CreatedEvent, event)
736			return
737		}
738		event = AgentEvent{
739			Type:     AgentEventTypeSummarize,
740			Progress: "Creating new session...",
741		}
742
743		a.Publish(pubsub.CreatedEvent, event)
744		oldSession, err := a.sessions.Get(summarizeCtx, sessionID)
745		if err != nil {
746			event = AgentEvent{
747				Type:  AgentEventTypeError,
748				Error: fmt.Errorf("failed to get session: %w", err),
749				Done:  true,
750			}
751
752			a.Publish(pubsub.CreatedEvent, event)
753			return
754		}
755		// Create a message in the new session with the summary
756		msg, err := a.messages.Create(summarizeCtx, oldSession.ID, message.CreateMessageParams{
757			Role: message.Assistant,
758			Parts: []message.ContentPart{
759				message.TextContent{Text: summary},
760				message.Finish{
761					Reason: message.FinishReasonEndTurn,
762					Time:   time.Now().Unix(),
763				},
764			},
765			Model:    a.summarizeProvider.Model().ID,
766			Provider: a.summarizeProviderID,
767		})
768		if err != nil {
769			event = AgentEvent{
770				Type:  AgentEventTypeError,
771				Error: fmt.Errorf("failed to create summary message: %w", err),
772				Done:  true,
773			}
774
775			a.Publish(pubsub.CreatedEvent, event)
776			return
777		}
778		oldSession.SummaryMessageID = msg.ID
779		oldSession.CompletionTokens = finalResponse.Usage.OutputTokens
780		oldSession.PromptTokens = 0
781		model := a.summarizeProvider.Model()
782		usage := finalResponse.Usage
783		cost := model.CostPer1MInCached/1e6*float64(usage.CacheCreationTokens) +
784			model.CostPer1MOutCached/1e6*float64(usage.CacheReadTokens) +
785			model.CostPer1MIn/1e6*float64(usage.InputTokens) +
786			model.CostPer1MOut/1e6*float64(usage.OutputTokens)
787		oldSession.Cost += cost
788		_, err = a.sessions.Save(summarizeCtx, oldSession)
789		if err != nil {
790			event = AgentEvent{
791				Type:  AgentEventTypeError,
792				Error: fmt.Errorf("failed to save session: %w", err),
793				Done:  true,
794			}
795			a.Publish(pubsub.CreatedEvent, event)
796		}
797
798		event = AgentEvent{
799			Type:      AgentEventTypeSummarize,
800			SessionID: oldSession.ID,
801			Progress:  "Summary complete",
802			Done:      true,
803		}
804		a.Publish(pubsub.CreatedEvent, event)
805		// Send final success event with the new session ID
806	}()
807
808	return nil
809}
810
811func (a *agent) CancelAll() {
812	a.activeRequests.Range(func(key, value any) bool {
813		a.Cancel(key.(string)) // key is sessionID
814		return true
815	})
816}
817
818func (a *agent) UpdateModel() error {
819	cfg := config.Get()
820
821	// Get current provider configuration
822	currentProviderCfg := config.GetAgentProvider(a.agentCfg.ID)
823	if currentProviderCfg.ID == "" {
824		return fmt.Errorf("provider for agent %s not found in config", a.agentCfg.Name)
825	}
826
827	// Check if provider has changed
828	if string(currentProviderCfg.ID) != a.providerID {
829		// Provider changed, need to recreate the main provider
830		model := config.GetAgentModel(a.agentCfg.ID)
831		if model.ID == "" {
832			return fmt.Errorf("model not found for agent %s", a.agentCfg.Name)
833		}
834
835		promptID := agentPromptMap[a.agentCfg.ID]
836		if promptID == "" {
837			promptID = prompt.PromptDefault
838		}
839
840		opts := []provider.ProviderClientOption{
841			provider.WithModel(a.agentCfg.Model),
842			provider.WithSystemMessage(prompt.GetPrompt(promptID, currentProviderCfg.ID)),
843		}
844
845		newProvider, err := provider.NewProvider(currentProviderCfg, opts...)
846		if err != nil {
847			return fmt.Errorf("failed to create new provider: %w", err)
848		}
849
850		// Update the provider and provider ID
851		a.provider = newProvider
852		a.providerID = string(currentProviderCfg.ID)
853	}
854
855	// Check if small model provider has changed (affects title and summarize providers)
856	smallModelCfg := cfg.Models.Small
857	var smallModelProviderCfg config.ProviderConfig
858
859	for _, p := range cfg.Providers {
860		if p.ID == smallModelCfg.Provider {
861			smallModelProviderCfg = p
862			break
863		}
864	}
865
866	if smallModelProviderCfg.ID == "" {
867		return fmt.Errorf("provider %s not found in config", smallModelCfg.Provider)
868	}
869
870	// Check if summarize provider has changed
871	if string(smallModelProviderCfg.ID) != a.summarizeProviderID {
872		var smallModel config.Model
873		for _, m := range smallModelProviderCfg.Models {
874			if m.ID == smallModelCfg.ModelID {
875				smallModel = m
876				break
877			}
878		}
879		if smallModel.ID == "" {
880			return fmt.Errorf("model %s not found in provider %s", smallModelCfg.ModelID, smallModelProviderCfg.ID)
881		}
882
883		// Recreate title provider
884		titleOpts := []provider.ProviderClientOption{
885			provider.WithModel(config.SmallModel),
886			provider.WithSystemMessage(prompt.GetPrompt(prompt.PromptTitle, smallModelProviderCfg.ID)),
887			// We want the title to be short, so we limit the max tokens
888			provider.WithMaxTokens(40),
889		}
890		newTitleProvider, err := provider.NewProvider(smallModelProviderCfg, titleOpts...)
891		if err != nil {
892			return fmt.Errorf("failed to create new title provider: %w", err)
893		}
894
895		// Recreate summarize provider
896		summarizeOpts := []provider.ProviderClientOption{
897			provider.WithModel(config.SmallModel),
898			provider.WithSystemMessage(prompt.GetPrompt(prompt.PromptSummarizer, smallModelProviderCfg.ID)),
899		}
900		newSummarizeProvider, err := provider.NewProvider(smallModelProviderCfg, summarizeOpts...)
901		if err != nil {
902			return fmt.Errorf("failed to create new summarize provider: %w", err)
903		}
904
905		// Update the providers and provider ID
906		a.titleProvider = newTitleProvider
907		a.summarizeProvider = newSummarizeProvider
908		a.summarizeProviderID = string(smallModelProviderCfg.ID)
909	}
910
911	return nil
912}