agent.go

  1package agent
  2
  3import (
  4	"context"
  5	"errors"
  6	"fmt"
  7	"slices"
  8	"strings"
  9	"sync"
 10	"time"
 11
 12	"github.com/charmbracelet/crush/internal/config"
 13	"github.com/charmbracelet/crush/internal/history"
 14	"github.com/charmbracelet/crush/internal/llm/prompt"
 15	"github.com/charmbracelet/crush/internal/llm/provider"
 16	"github.com/charmbracelet/crush/internal/llm/tools"
 17	"github.com/charmbracelet/crush/internal/logging"
 18	"github.com/charmbracelet/crush/internal/lsp"
 19	"github.com/charmbracelet/crush/internal/message"
 20	"github.com/charmbracelet/crush/internal/permission"
 21	"github.com/charmbracelet/crush/internal/pubsub"
 22	"github.com/charmbracelet/crush/internal/session"
 23)
 24
// Common errors returned by the agent service.
var (
	// ErrRequestCancelled is reported (inside an error AgentEvent) when a
	// generation stops because its context was cancelled.
	ErrRequestCancelled = errors.New("request cancelled by user")
	// ErrSessionBusy is returned by Run and Summarize when the session already
	// has a request in flight.
	ErrSessionBusy      = errors.New("session is currently processing another request")
)
 30
// AgentEventType identifies what an AgentEvent reports.
type AgentEventType string

const (
	// AgentEventTypeError carries a failure in AgentEvent.Error.
	AgentEventTypeError     AgentEventType = "error"
	// AgentEventTypeResponse carries the finished assistant message in AgentEvent.Message.
	AgentEventTypeResponse  AgentEventType = "response"
	// AgentEventTypeSummarize reports summarization progress via AgentEvent.Progress and Done.
	AgentEventTypeSummarize AgentEventType = "summarize"
)
 38
// AgentEvent is sent on the channel returned by Run and published to the
// agent's subscribers (Run results and all Summarize progress/outcomes).
type AgentEvent struct {
	Type    AgentEventType
	Message message.Message // set for AgentEventTypeResponse
	Error   error           // set for AgentEventTypeError

	// When summarizing
	SessionID string // ID of the summarized session, set on the final successful summarize event
	Progress  string // human-readable progress text
	Done      bool   // set on a successful Run response and on every terminal Summarize event
}
 49
// Service is the public interface of an agent: it runs user prompts against
// the configured model (executing tools as needed), tracks per-session
// in-flight requests so they can be cancelled, and summarizes sessions.
type Service interface {
	pubsub.Suscriber[AgentEvent]
	// Model returns the model currently configured for this agent.
	Model() config.Model
	// EffectiveMaxTokens returns the token limit configured for this agent.
	EffectiveMaxTokens() int64
	// Run processes content (and attachments) for sessionID in the background
	// and returns a channel that delivers the outcome.
	Run(ctx context.Context, sessionID string, content string, attachments ...message.Attachment) (<-chan AgentEvent, error)
	// Cancel aborts the in-flight run and/or summarization for sessionID.
	Cancel(sessionID string)
	// CancelAll aborts every in-flight request of this agent.
	CancelAll()
	// IsSessionBusy reports whether a Run request is in flight for sessionID.
	IsSessionBusy(sessionID string) bool
	// IsBusy reports whether any request is in flight.
	IsBusy() bool
	// Summarize starts an asynchronous summarization of sessionID; progress is
	// published as AgentEvents.
	Summarize(ctx context.Context, sessionID string) error
	// UpdateModel rebuilds providers whose configured provider has changed.
	UpdateModel() error
}
 62
// agent is the Service implementation returned by NewAgent.
type agent struct {
	*pubsub.Broker[AgentEvent]
	agentCfg config.Agent
	sessions session.Service
	messages message.Service

	// Tools offered to the main provider, already filtered by agentCfg.AllowedTools.
	tools      []tools.BaseTool
	provider   provider.Provider
	providerID string

	// Providers built on the configured small model.
	titleProvider       provider.Provider
	summarizeProvider   provider.Provider
	summarizeProviderID string

	// activeRequests maps a session ID (Run) or "<sessionID>-summarize"
	// (Summarize) to the context.CancelFunc of the in-flight request.
	activeRequests sync.Map
}
 79
// agentPromptMap selects the system prompt per agent ID; agents not listed
// here fall back to prompt.PromptDefault.
var agentPromptMap = map[config.AgentID]prompt.PromptID{
	config.AgentCoder: prompt.PromptCoder,
	config.AgentTask:  prompt.PromptTask,
}
 84
 85func NewAgent(
 86	agentCfg config.Agent,
 87	// These services are needed in the tools
 88	permissions permission.Service,
 89	sessions session.Service,
 90	messages message.Service,
 91	history history.Service,
 92	lspClients map[string]*lsp.Client,
 93) (Service, error) {
 94	ctx := context.Background()
 95	cfg := config.Get()
 96	otherTools := GetMcpTools(ctx, permissions)
 97	if len(lspClients) > 0 {
 98		otherTools = append(otherTools, tools.NewDiagnosticsTool(lspClients))
 99	}
100
101	allTools := []tools.BaseTool{
102		tools.NewBashTool(permissions),
103		tools.NewEditTool(lspClients, permissions, history),
104		tools.NewFetchTool(permissions),
105		tools.NewGlobTool(),
106		tools.NewGrepTool(),
107		tools.NewLsTool(),
108		tools.NewSourcegraphTool(),
109		tools.NewViewTool(lspClients),
110		tools.NewWriteTool(lspClients, permissions, history),
111	}
112
113	if agentCfg.ID == config.AgentCoder {
114		taskAgentCfg := config.Get().Agents[config.AgentTask]
115		if taskAgentCfg.ID == "" {
116			return nil, fmt.Errorf("task agent not found in config")
117		}
118		taskAgent, err := NewAgent(taskAgentCfg, permissions, sessions, messages, history, lspClients)
119		if err != nil {
120			return nil, fmt.Errorf("failed to create task agent: %w", err)
121		}
122
123		allTools = append(
124			allTools,
125			NewAgentTool(
126				taskAgent,
127				sessions,
128				messages,
129			),
130		)
131	}
132
133	allTools = append(allTools, otherTools...)
134	providerCfg := config.GetAgentProvider(agentCfg.ID)
135	if providerCfg.ID == "" {
136		return nil, fmt.Errorf("provider for agent %s not found in config", agentCfg.Name)
137	}
138	model := config.GetAgentModel(agentCfg.ID)
139
140	if model.ID == "" {
141		return nil, fmt.Errorf("model not found for agent %s", agentCfg.Name)
142	}
143
144	promptID := agentPromptMap[agentCfg.ID]
145	if promptID == "" {
146		promptID = prompt.PromptDefault
147	}
148	opts := []provider.ProviderClientOption{
149		provider.WithModel(agentCfg.Model),
150		provider.WithSystemMessage(prompt.GetPrompt(promptID, providerCfg.ID)),
151	}
152	agentProvider, err := provider.NewProvider(providerCfg, opts...)
153	if err != nil {
154		return nil, err
155	}
156
157	smallModelCfg := cfg.Models.Small
158	var smallModel config.Model
159
160	var smallModelProviderCfg config.ProviderConfig
161	if smallModelCfg.Provider == providerCfg.ID {
162		smallModelProviderCfg = providerCfg
163	} else {
164		for _, p := range cfg.Providers {
165			if p.ID == smallModelCfg.Provider {
166				smallModelProviderCfg = p
167				break
168			}
169		}
170		if smallModelProviderCfg.ID == "" {
171			return nil, fmt.Errorf("provider %s not found in config", smallModelCfg.Provider)
172		}
173	}
174	for _, m := range smallModelProviderCfg.Models {
175		if m.ID == smallModelCfg.ModelID {
176			smallModel = m
177			break
178		}
179	}
180	if smallModel.ID == "" {
181		return nil, fmt.Errorf("model %s not found in provider %s", smallModelCfg.ModelID, smallModelProviderCfg.ID)
182	}
183
184	titleOpts := []provider.ProviderClientOption{
185		provider.WithModel(config.SmallModel),
186		provider.WithSystemMessage(prompt.GetPrompt(prompt.PromptTitle, smallModelProviderCfg.ID)),
187	}
188	titleProvider, err := provider.NewProvider(smallModelProviderCfg, titleOpts...)
189	if err != nil {
190		return nil, err
191	}
192	summarizeOpts := []provider.ProviderClientOption{
193		provider.WithModel(config.SmallModel),
194		provider.WithSystemMessage(prompt.GetPrompt(prompt.PromptSummarizer, smallModelProviderCfg.ID)),
195	}
196	summarizeProvider, err := provider.NewProvider(smallModelProviderCfg, summarizeOpts...)
197	if err != nil {
198		return nil, err
199	}
200
201	agentTools := []tools.BaseTool{}
202	if agentCfg.AllowedTools == nil {
203		agentTools = allTools
204	} else {
205		for _, tool := range allTools {
206			if slices.Contains(agentCfg.AllowedTools, tool.Name()) {
207				agentTools = append(agentTools, tool)
208			}
209		}
210	}
211
212	agent := &agent{
213		Broker:              pubsub.NewBroker[AgentEvent](),
214		agentCfg:            agentCfg,
215		provider:            agentProvider,
216		providerID:          string(providerCfg.ID),
217		messages:            messages,
218		sessions:            sessions,
219		tools:               agentTools,
220		titleProvider:       titleProvider,
221		summarizeProvider:   summarizeProvider,
222		summarizeProviderID: string(smallModelProviderCfg.ID),
223		activeRequests:      sync.Map{},
224	}
225
226	return agent, nil
227}
228
229func (a *agent) Model() config.Model {
230	return config.GetAgentModel(a.agentCfg.ID)
231}
232
233func (a *agent) EffectiveMaxTokens() int64 {
234	return config.GetAgentEffectiveMaxTokens(a.agentCfg.ID)
235}
236
237func (a *agent) Cancel(sessionID string) {
238	// Cancel regular requests
239	if cancelFunc, exists := a.activeRequests.LoadAndDelete(sessionID); exists {
240		if cancel, ok := cancelFunc.(context.CancelFunc); ok {
241			logging.InfoPersist(fmt.Sprintf("Request cancellation initiated for session: %s", sessionID))
242			cancel()
243		}
244	}
245
246	// Also check for summarize requests
247	if cancelFunc, exists := a.activeRequests.LoadAndDelete(sessionID + "-summarize"); exists {
248		if cancel, ok := cancelFunc.(context.CancelFunc); ok {
249			logging.InfoPersist(fmt.Sprintf("Summarize cancellation initiated for session: %s", sessionID))
250			cancel()
251		}
252	}
253}
254
255func (a *agent) IsBusy() bool {
256	busy := false
257	a.activeRequests.Range(func(key, value any) bool {
258		if cancelFunc, ok := value.(context.CancelFunc); ok {
259			if cancelFunc != nil {
260				busy = true
261				return false
262			}
263		}
264		return true
265	})
266	return busy
267}
268
269func (a *agent) IsSessionBusy(sessionID string) bool {
270	_, busy := a.activeRequests.Load(sessionID)
271	return busy
272}
273
274func (a *agent) generateTitle(ctx context.Context, sessionID string, content string) error {
275	if content == "" {
276		return nil
277	}
278	if a.titleProvider == nil {
279		return nil
280	}
281	session, err := a.sessions.Get(ctx, sessionID)
282	if err != nil {
283		return err
284	}
285	parts := []message.ContentPart{message.TextContent{
286		Text: fmt.Sprintf("Generate a concise title for the following content:\n\n%s", content),
287	}}
288
289	// Use streaming approach like summarization
290	response := a.titleProvider.StreamResponse(
291		ctx,
292		[]message.Message{
293			{
294				Role:  message.User,
295				Parts: parts,
296			},
297		},
298		make([]tools.BaseTool, 0),
299	)
300
301	var finalResponse *provider.ProviderResponse
302	for r := range response {
303		if r.Error != nil {
304			return r.Error
305		}
306		finalResponse = r.Response
307	}
308
309	if finalResponse == nil {
310		return fmt.Errorf("no response received from title provider")
311	}
312
313	title := strings.TrimSpace(strings.ReplaceAll(finalResponse.Content, "\n", " "))
314	if title == "" {
315		return nil
316	}
317
318	session.Title = title
319	_, err = a.sessions.Save(ctx, session)
320	return err
321}
322
323func (a *agent) err(err error) AgentEvent {
324	return AgentEvent{
325		Type:  AgentEventTypeError,
326		Error: err,
327	}
328}
329
330func (a *agent) Run(ctx context.Context, sessionID string, content string, attachments ...message.Attachment) (<-chan AgentEvent, error) {
331	if !a.Model().SupportsImages && attachments != nil {
332		attachments = nil
333	}
334	events := make(chan AgentEvent)
335	if a.IsSessionBusy(sessionID) {
336		return nil, ErrSessionBusy
337	}
338
339	genCtx, cancel := context.WithCancel(ctx)
340
341	a.activeRequests.Store(sessionID, cancel)
342	go func() {
343		logging.Debug("Request started", "sessionID", sessionID)
344		defer logging.RecoverPanic("agent.Run", func() {
345			events <- a.err(fmt.Errorf("panic while running the agent"))
346		})
347		var attachmentParts []message.ContentPart
348		for _, attachment := range attachments {
349			attachmentParts = append(attachmentParts, message.BinaryContent{Path: attachment.FilePath, MIMEType: attachment.MimeType, Data: attachment.Content})
350		}
351		result := a.processGeneration(genCtx, sessionID, content, attachmentParts)
352		if result.Error != nil && !errors.Is(result.Error, ErrRequestCancelled) && !errors.Is(result.Error, context.Canceled) {
353			logging.ErrorPersist(result.Error.Error())
354		}
355		logging.Debug("Request completed", "sessionID", sessionID)
356		a.activeRequests.Delete(sessionID)
357		cancel()
358		a.Publish(pubsub.CreatedEvent, result)
359		events <- result
360		close(events)
361	}()
362	return events, nil
363}
364
365func (a *agent) processGeneration(ctx context.Context, sessionID, content string, attachmentParts []message.ContentPart) AgentEvent {
366	cfg := config.Get()
367	// List existing messages; if none, start title generation asynchronously.
368	msgs, err := a.messages.List(ctx, sessionID)
369	if err != nil {
370		return a.err(fmt.Errorf("failed to list messages: %w", err))
371	}
372	if len(msgs) == 0 {
373		go func() {
374			defer logging.RecoverPanic("agent.Run", func() {
375				logging.ErrorPersist("panic while generating title")
376			})
377			titleErr := a.generateTitle(context.Background(), sessionID, content)
378			if titleErr != nil && !errors.Is(titleErr, context.Canceled) && !errors.Is(titleErr, context.DeadlineExceeded) {
379				logging.ErrorPersist(fmt.Sprintf("failed to generate title: %v", titleErr))
380			}
381		}()
382	}
383	session, err := a.sessions.Get(ctx, sessionID)
384	if err != nil {
385		return a.err(fmt.Errorf("failed to get session: %w", err))
386	}
387	if session.SummaryMessageID != "" {
388		summaryMsgInex := -1
389		for i, msg := range msgs {
390			if msg.ID == session.SummaryMessageID {
391				summaryMsgInex = i
392				break
393			}
394		}
395		if summaryMsgInex != -1 {
396			msgs = msgs[summaryMsgInex:]
397			msgs[0].Role = message.User
398		}
399	}
400
401	userMsg, err := a.createUserMessage(ctx, sessionID, content, attachmentParts)
402	if err != nil {
403		return a.err(fmt.Errorf("failed to create user message: %w", err))
404	}
405	// Append the new user message to the conversation history.
406	msgHistory := append(msgs, userMsg)
407
408	for {
409		// Check for cancellation before each iteration
410		select {
411		case <-ctx.Done():
412			return a.err(ctx.Err())
413		default:
414			// Continue processing
415		}
416		agentMessage, toolResults, err := a.streamAndHandleEvents(ctx, sessionID, msgHistory)
417		if err != nil {
418			if errors.Is(err, context.Canceled) {
419				agentMessage.AddFinish(message.FinishReasonCanceled)
420				a.messages.Update(context.Background(), agentMessage)
421				return a.err(ErrRequestCancelled)
422			}
423			return a.err(fmt.Errorf("failed to process events: %w", err))
424		}
425		if cfg.Options.Debug {
426			seqId := (len(msgHistory) + 1) / 2
427			toolResultFilepath := logging.WriteToolResultsJson(sessionID, seqId, toolResults)
428			logging.Info("Result", "message", agentMessage.FinishReason(), "toolResults", "{}", "filepath", toolResultFilepath)
429		} else {
430			logging.Info("Result", "message", agentMessage.FinishReason(), "toolResults", toolResults)
431		}
432		if (agentMessage.FinishReason() == message.FinishReasonToolUse) && toolResults != nil {
433			// We are not done, we need to respond with the tool response
434			msgHistory = append(msgHistory, agentMessage, *toolResults)
435			continue
436		}
437		return AgentEvent{
438			Type:    AgentEventTypeResponse,
439			Message: agentMessage,
440			Done:    true,
441		}
442	}
443}
444
445func (a *agent) createUserMessage(ctx context.Context, sessionID, content string, attachmentParts []message.ContentPart) (message.Message, error) {
446	parts := []message.ContentPart{message.TextContent{Text: content}}
447	parts = append(parts, attachmentParts...)
448	return a.messages.Create(ctx, sessionID, message.CreateMessageParams{
449		Role:  message.User,
450		Parts: parts,
451	})
452}
453
454func (a *agent) streamAndHandleEvents(ctx context.Context, sessionID string, msgHistory []message.Message) (message.Message, *message.Message, error) {
455	ctx = context.WithValue(ctx, tools.SessionIDContextKey, sessionID)
456	eventChan := a.provider.StreamResponse(ctx, msgHistory, a.tools)
457
458	assistantMsg, err := a.messages.Create(ctx, sessionID, message.CreateMessageParams{
459		Role:     message.Assistant,
460		Parts:    []message.ContentPart{},
461		Model:    a.Model().ID,
462		Provider: a.providerID,
463	})
464	if err != nil {
465		return assistantMsg, nil, fmt.Errorf("failed to create assistant message: %w", err)
466	}
467
468	// Add the session and message ID into the context if needed by tools.
469	ctx = context.WithValue(ctx, tools.MessageIDContextKey, assistantMsg.ID)
470
471	// Process each event in the stream.
472	for event := range eventChan {
473		if processErr := a.processEvent(ctx, sessionID, &assistantMsg, event); processErr != nil {
474			a.finishMessage(ctx, &assistantMsg, message.FinishReasonCanceled)
475			return assistantMsg, nil, processErr
476		}
477		if ctx.Err() != nil {
478			a.finishMessage(context.Background(), &assistantMsg, message.FinishReasonCanceled)
479			return assistantMsg, nil, ctx.Err()
480		}
481	}
482
483	toolResults := make([]message.ToolResult, len(assistantMsg.ToolCalls()))
484	toolCalls := assistantMsg.ToolCalls()
485	for i, toolCall := range toolCalls {
486		select {
487		case <-ctx.Done():
488			a.finishMessage(context.Background(), &assistantMsg, message.FinishReasonCanceled)
489			// Make all future tool calls cancelled
490			for j := i; j < len(toolCalls); j++ {
491				toolResults[j] = message.ToolResult{
492					ToolCallID: toolCalls[j].ID,
493					Content:    "Tool execution canceled by user",
494					IsError:    true,
495				}
496			}
497			goto out
498		default:
499			// Continue processing
500			var tool tools.BaseTool
501			for _, availableTool := range a.tools {
502				if availableTool.Info().Name == toolCall.Name {
503					tool = availableTool
504					break
505				}
506			}
507
508			// Tool not found
509			if tool == nil {
510				toolResults[i] = message.ToolResult{
511					ToolCallID: toolCall.ID,
512					Content:    fmt.Sprintf("Tool not found: %s", toolCall.Name),
513					IsError:    true,
514				}
515				continue
516			}
517			toolResult, toolErr := tool.Run(ctx, tools.ToolCall{
518				ID:    toolCall.ID,
519				Name:  toolCall.Name,
520				Input: toolCall.Input,
521			})
522			if toolErr != nil {
523				if errors.Is(toolErr, permission.ErrorPermissionDenied) {
524					toolResults[i] = message.ToolResult{
525						ToolCallID: toolCall.ID,
526						Content:    "Permission denied",
527						IsError:    true,
528					}
529					for j := i + 1; j < len(toolCalls); j++ {
530						toolResults[j] = message.ToolResult{
531							ToolCallID: toolCalls[j].ID,
532							Content:    "Tool execution canceled by user",
533							IsError:    true,
534						}
535					}
536					a.finishMessage(ctx, &assistantMsg, message.FinishReasonPermissionDenied)
537					break
538				}
539			}
540			toolResults[i] = message.ToolResult{
541				ToolCallID: toolCall.ID,
542				Content:    toolResult.Content,
543				Metadata:   toolResult.Metadata,
544				IsError:    toolResult.IsError,
545			}
546		}
547	}
548out:
549	if len(toolResults) == 0 {
550		return assistantMsg, nil, nil
551	}
552	parts := make([]message.ContentPart, 0)
553	for _, tr := range toolResults {
554		parts = append(parts, tr)
555	}
556	msg, err := a.messages.Create(context.Background(), assistantMsg.SessionID, message.CreateMessageParams{
557		Role:     message.Tool,
558		Parts:    parts,
559		Provider: a.providerID,
560	})
561	if err != nil {
562		return assistantMsg, nil, fmt.Errorf("failed to create cancelled tool message: %w", err)
563	}
564
565	return assistantMsg, &msg, err
566}
567
568func (a *agent) finishMessage(ctx context.Context, msg *message.Message, finishReson message.FinishReason) {
569	msg.AddFinish(finishReson)
570	_ = a.messages.Update(ctx, *msg)
571}
572
573func (a *agent) processEvent(ctx context.Context, sessionID string, assistantMsg *message.Message, event provider.ProviderEvent) error {
574	select {
575	case <-ctx.Done():
576		return ctx.Err()
577	default:
578		// Continue processing.
579	}
580
581	switch event.Type {
582	case provider.EventThinkingDelta:
583		assistantMsg.AppendReasoningContent(event.Content)
584		return a.messages.Update(ctx, *assistantMsg)
585	case provider.EventContentDelta:
586		assistantMsg.AppendContent(event.Content)
587		return a.messages.Update(ctx, *assistantMsg)
588	case provider.EventToolUseStart:
589		logging.Info("Tool call started", "toolCall", event.ToolCall)
590		assistantMsg.AddToolCall(*event.ToolCall)
591		return a.messages.Update(ctx, *assistantMsg)
592	case provider.EventToolUseDelta:
593		assistantMsg.AppendToolCallInput(event.ToolCall.ID, event.ToolCall.Input)
594		return a.messages.Update(ctx, *assistantMsg)
595	case provider.EventToolUseStop:
596		logging.Info("Finished tool call", "toolCall", event.ToolCall)
597		assistantMsg.FinishToolCall(event.ToolCall.ID)
598		return a.messages.Update(ctx, *assistantMsg)
599	case provider.EventError:
600		if errors.Is(event.Error, context.Canceled) {
601			logging.InfoPersist(fmt.Sprintf("Event processing canceled for session: %s", sessionID))
602			return context.Canceled
603		}
604		logging.ErrorPersist(event.Error.Error())
605		return event.Error
606	case provider.EventComplete:
607		assistantMsg.SetToolCalls(event.Response.ToolCalls)
608		assistantMsg.AddFinish(event.Response.FinishReason)
609		if err := a.messages.Update(ctx, *assistantMsg); err != nil {
610			return fmt.Errorf("failed to update message: %w", err)
611		}
612		return a.TrackUsage(ctx, sessionID, a.Model(), event.Response.Usage)
613	}
614
615	return nil
616}
617
618func (a *agent) TrackUsage(ctx context.Context, sessionID string, model config.Model, usage provider.TokenUsage) error {
619	sess, err := a.sessions.Get(ctx, sessionID)
620	if err != nil {
621		return fmt.Errorf("failed to get session: %w", err)
622	}
623
624	cost := model.CostPer1MInCached/1e6*float64(usage.CacheCreationTokens) +
625		model.CostPer1MOutCached/1e6*float64(usage.CacheReadTokens) +
626		model.CostPer1MIn/1e6*float64(usage.InputTokens) +
627		model.CostPer1MOut/1e6*float64(usage.OutputTokens)
628
629	sess.Cost += cost
630	sess.CompletionTokens = usage.OutputTokens + usage.CacheReadTokens
631	sess.PromptTokens = usage.InputTokens + usage.CacheCreationTokens
632
633	_, err = a.sessions.Save(ctx, sess)
634	if err != nil {
635		return fmt.Errorf("failed to save session: %w", err)
636	}
637	return nil
638}
639
640func (a *agent) Summarize(ctx context.Context, sessionID string) error {
641	if a.summarizeProvider == nil {
642		return fmt.Errorf("summarize provider not available")
643	}
644
645	// Check if session is busy
646	if a.IsSessionBusy(sessionID) {
647		return ErrSessionBusy
648	}
649
650	// Create a new context with cancellation
651	summarizeCtx, cancel := context.WithCancel(ctx)
652
653	// Store the cancel function in activeRequests to allow cancellation
654	a.activeRequests.Store(sessionID+"-summarize", cancel)
655
656	go func() {
657		defer a.activeRequests.Delete(sessionID + "-summarize")
658		defer cancel()
659		event := AgentEvent{
660			Type:     AgentEventTypeSummarize,
661			Progress: "Starting summarization...",
662		}
663
664		a.Publish(pubsub.CreatedEvent, event)
665		// Get all messages from the session
666		msgs, err := a.messages.List(summarizeCtx, sessionID)
667		if err != nil {
668			event = AgentEvent{
669				Type:  AgentEventTypeError,
670				Error: fmt.Errorf("failed to list messages: %w", err),
671				Done:  true,
672			}
673			a.Publish(pubsub.CreatedEvent, event)
674			return
675		}
676		summarizeCtx = context.WithValue(summarizeCtx, tools.SessionIDContextKey, sessionID)
677
678		if len(msgs) == 0 {
679			event = AgentEvent{
680				Type:  AgentEventTypeError,
681				Error: fmt.Errorf("no messages to summarize"),
682				Done:  true,
683			}
684			a.Publish(pubsub.CreatedEvent, event)
685			return
686		}
687
688		event = AgentEvent{
689			Type:     AgentEventTypeSummarize,
690			Progress: "Analyzing conversation...",
691		}
692		a.Publish(pubsub.CreatedEvent, event)
693
694		// Add a system message to guide the summarization
695		summarizePrompt := "Provide a detailed but concise summary of our conversation above. Focus on information that would be helpful for continuing the conversation, including what we did, what we're doing, which files we're working on, and what we're going to do next."
696
697		// Create a new message with the summarize prompt
698		promptMsg := message.Message{
699			Role:  message.User,
700			Parts: []message.ContentPart{message.TextContent{Text: summarizePrompt}},
701		}
702
703		// Append the prompt to the messages
704		msgsWithPrompt := append(msgs, promptMsg)
705
706		event = AgentEvent{
707			Type:     AgentEventTypeSummarize,
708			Progress: "Generating summary...",
709		}
710
711		a.Publish(pubsub.CreatedEvent, event)
712
713		// Send the messages to the summarize provider
714		response := a.summarizeProvider.StreamResponse(
715			summarizeCtx,
716			msgsWithPrompt,
717			make([]tools.BaseTool, 0),
718		)
719		var finalResponse *provider.ProviderResponse
720		for r := range response {
721			if r.Error != nil {
722				event = AgentEvent{
723					Type:  AgentEventTypeError,
724					Error: fmt.Errorf("failed to summarize: %w", err),
725					Done:  true,
726				}
727				a.Publish(pubsub.CreatedEvent, event)
728				return
729			}
730			finalResponse = r.Response
731		}
732
733		summary := strings.TrimSpace(finalResponse.Content)
734		if summary == "" {
735			event = AgentEvent{
736				Type:  AgentEventTypeError,
737				Error: fmt.Errorf("empty summary returned"),
738				Done:  true,
739			}
740			a.Publish(pubsub.CreatedEvent, event)
741			return
742		}
743		event = AgentEvent{
744			Type:     AgentEventTypeSummarize,
745			Progress: "Creating new session...",
746		}
747
748		a.Publish(pubsub.CreatedEvent, event)
749		oldSession, err := a.sessions.Get(summarizeCtx, sessionID)
750		if err != nil {
751			event = AgentEvent{
752				Type:  AgentEventTypeError,
753				Error: fmt.Errorf("failed to get session: %w", err),
754				Done:  true,
755			}
756
757			a.Publish(pubsub.CreatedEvent, event)
758			return
759		}
760		// Create a message in the new session with the summary
761		msg, err := a.messages.Create(summarizeCtx, oldSession.ID, message.CreateMessageParams{
762			Role: message.Assistant,
763			Parts: []message.ContentPart{
764				message.TextContent{Text: summary},
765				message.Finish{
766					Reason: message.FinishReasonEndTurn,
767					Time:   time.Now().Unix(),
768				},
769			},
770			Model:    a.summarizeProvider.Model().ID,
771			Provider: a.summarizeProviderID,
772		})
773		if err != nil {
774			event = AgentEvent{
775				Type:  AgentEventTypeError,
776				Error: fmt.Errorf("failed to create summary message: %w", err),
777				Done:  true,
778			}
779
780			a.Publish(pubsub.CreatedEvent, event)
781			return
782		}
783		oldSession.SummaryMessageID = msg.ID
784		oldSession.CompletionTokens = finalResponse.Usage.OutputTokens
785		oldSession.PromptTokens = 0
786		model := a.summarizeProvider.Model()
787		usage := finalResponse.Usage
788		cost := model.CostPer1MInCached/1e6*float64(usage.CacheCreationTokens) +
789			model.CostPer1MOutCached/1e6*float64(usage.CacheReadTokens) +
790			model.CostPer1MIn/1e6*float64(usage.InputTokens) +
791			model.CostPer1MOut/1e6*float64(usage.OutputTokens)
792		oldSession.Cost += cost
793		_, err = a.sessions.Save(summarizeCtx, oldSession)
794		if err != nil {
795			event = AgentEvent{
796				Type:  AgentEventTypeError,
797				Error: fmt.Errorf("failed to save session: %w", err),
798				Done:  true,
799			}
800			a.Publish(pubsub.CreatedEvent, event)
801		}
802
803		event = AgentEvent{
804			Type:      AgentEventTypeSummarize,
805			SessionID: oldSession.ID,
806			Progress:  "Summary complete",
807			Done:      true,
808		}
809		a.Publish(pubsub.CreatedEvent, event)
810		// Send final success event with the new session ID
811	}()
812
813	return nil
814}
815
816func (a *agent) CancelAll() {
817	a.activeRequests.Range(func(key, value any) bool {
818		a.Cancel(key.(string)) // key is sessionID
819		return true
820	})
821}
822
823func (a *agent) UpdateModel() error {
824	cfg := config.Get()
825
826	// Get current provider configuration
827	currentProviderCfg := config.GetAgentProvider(a.agentCfg.ID)
828	if currentProviderCfg.ID == "" {
829		return fmt.Errorf("provider for agent %s not found in config", a.agentCfg.Name)
830	}
831
832	// Check if provider has changed
833	if string(currentProviderCfg.ID) != a.providerID {
834		// Provider changed, need to recreate the main provider
835		model := config.GetAgentModel(a.agentCfg.ID)
836		if model.ID == "" {
837			return fmt.Errorf("model not found for agent %s", a.agentCfg.Name)
838		}
839
840		promptID := agentPromptMap[a.agentCfg.ID]
841		if promptID == "" {
842			promptID = prompt.PromptDefault
843		}
844
845		opts := []provider.ProviderClientOption{
846			provider.WithModel(a.agentCfg.Model),
847			provider.WithSystemMessage(prompt.GetPrompt(promptID, currentProviderCfg.ID)),
848		}
849
850		newProvider, err := provider.NewProvider(currentProviderCfg, opts...)
851		if err != nil {
852			return fmt.Errorf("failed to create new provider: %w", err)
853		}
854
855		// Update the provider and provider ID
856		a.provider = newProvider
857		a.providerID = string(currentProviderCfg.ID)
858	}
859
860	// Check if small model provider has changed (affects title and summarize providers)
861	smallModelCfg := cfg.Models.Small
862	var smallModelProviderCfg config.ProviderConfig
863
864	for _, p := range cfg.Providers {
865		if p.ID == smallModelCfg.Provider {
866			smallModelProviderCfg = p
867			break
868		}
869	}
870
871	if smallModelProviderCfg.ID == "" {
872		return fmt.Errorf("provider %s not found in config", smallModelCfg.Provider)
873	}
874
875	// Check if summarize provider has changed
876	if string(smallModelProviderCfg.ID) != a.summarizeProviderID {
877		var smallModel config.Model
878		for _, m := range smallModelProviderCfg.Models {
879			if m.ID == smallModelCfg.ModelID {
880				smallModel = m
881				break
882			}
883		}
884		if smallModel.ID == "" {
885			return fmt.Errorf("model %s not found in provider %s", smallModelCfg.ModelID, smallModelProviderCfg.ID)
886		}
887
888		// Recreate title provider
889		titleOpts := []provider.ProviderClientOption{
890			provider.WithModel(config.SmallModel),
891			provider.WithSystemMessage(prompt.GetPrompt(prompt.PromptTitle, smallModelProviderCfg.ID)),
892			// We want the title to be short, so we limit the max tokens
893			provider.WithMaxTokens(40),
894		}
895		newTitleProvider, err := provider.NewProvider(smallModelProviderCfg, titleOpts...)
896		if err != nil {
897			return fmt.Errorf("failed to create new title provider: %w", err)
898		}
899
900		// Recreate summarize provider
901		summarizeOpts := []provider.ProviderClientOption{
902			provider.WithModel(config.SmallModel),
903			provider.WithSystemMessage(prompt.GetPrompt(prompt.PromptSummarizer, smallModelProviderCfg.ID)),
904		}
905		newSummarizeProvider, err := provider.NewProvider(smallModelProviderCfg, summarizeOpts...)
906		if err != nil {
907			return fmt.Errorf("failed to create new summarize provider: %w", err)
908		}
909
910		// Update the providers and provider ID
911		a.titleProvider = newTitleProvider
912		a.summarizeProvider = newSummarizeProvider
913		a.summarizeProviderID = string(smallModelProviderCfg.ID)
914	}
915
916	return nil
917}