agent.go

  1package agent
  2
  3import (
  4	"context"
  5	"errors"
  6	"fmt"
  7	"log/slog"
  8	"slices"
  9	"strings"
 10	"sync"
 11	"time"
 12
 13	"github.com/charmbracelet/crush/internal/config"
 14	fur "github.com/charmbracelet/crush/internal/fur/provider"
 15	"github.com/charmbracelet/crush/internal/history"
 16	"github.com/charmbracelet/crush/internal/llm/prompt"
 17	"github.com/charmbracelet/crush/internal/llm/provider"
 18	"github.com/charmbracelet/crush/internal/llm/tools"
 19	"github.com/charmbracelet/crush/internal/log"
 20	"github.com/charmbracelet/crush/internal/lsp"
 21	"github.com/charmbracelet/crush/internal/message"
 22	"github.com/charmbracelet/crush/internal/permission"
 23	"github.com/charmbracelet/crush/internal/pubsub"
 24	"github.com/charmbracelet/crush/internal/session"
 25)
 26
 27// Common errors
 28var (
 29	ErrRequestCancelled = errors.New("request canceled by user")
 30	ErrSessionBusy      = errors.New("session is currently processing another request")
 31)
 32
 33type AgentEventType string
 34
 35const (
 36	AgentEventTypeError     AgentEventType = "error"
 37	AgentEventTypeResponse  AgentEventType = "response"
 38	AgentEventTypeSummarize AgentEventType = "summarize"
 39)
 40
// AgentEvent is the value sent on the channel returned by Run and published
// through the agent's pubsub broker by Run and Summarize.
//
//   - Type selects which of the remaining fields are meaningful.
//   - Message holds the final assistant message of an AgentEventTypeResponse.
//   - Error is set on AgentEventTypeError events.
//   - SessionID, Progress and Done report summarization progress; Done is
//     also set on a run's final AgentEventTypeResponse and on summarization
//     error events.
type AgentEvent struct {
	Type    AgentEventType
	Message message.Message
	Error   error

	// Summarization progress (Done is also set on a run's final response).
	SessionID string
	Progress  string
	Done      bool
}
 51
 52type Service interface {
 53	pubsub.Suscriber[AgentEvent]
 54	Model() fur.Model
 55	Run(ctx context.Context, sessionID string, content string, attachments ...message.Attachment) (<-chan AgentEvent, error)
 56	Cancel(sessionID string)
 57	CancelAll()
 58	IsSessionBusy(sessionID string) bool
 59	IsBusy() bool
 60	Summarize(ctx context.Context, sessionID string) error
 61	UpdateModel() error
 62}
 63
// agent implements Service. It holds the main provider used for runs, two
// providers built on the small model (titles and summaries), and the cancel
// functions of every in-flight request.
type agent struct {
	*pubsub.Broker[AgentEvent]
	agentCfg config.Agent
	sessions session.Service
	messages message.Service

	// Main provider, its ID, and the tools it is allowed to call.
	tools      []tools.BaseTool
	provider   provider.Provider
	providerID string

	// Small-model providers used for session titles and summaries.
	titleProvider       provider.Provider
	summarizeProvider   provider.Provider
	summarizeProviderID string

	// activeRequests maps a session ID (regular run) or "<sessionID>-summarize"
	// (summarization) to the context.CancelFunc of the in-flight work.
	activeRequests sync.Map
}
 80
 81var agentPromptMap = map[string]prompt.PromptID{
 82	"coder": prompt.PromptCoder,
 83	"task":  prompt.PromptTask,
 84}
 85
 86func NewAgent(
 87	agentCfg config.Agent,
 88	// These services are needed in the tools
 89	permissions permission.Service,
 90	sessions session.Service,
 91	messages message.Service,
 92	history history.Service,
 93	lspClients map[string]*lsp.Client,
 94) (Service, error) {
 95	ctx := context.Background()
 96	cfg := config.Get()
 97	otherTools := GetMCPTools(ctx, permissions, cfg)
 98	if len(lspClients) > 0 {
 99		otherTools = append(otherTools, tools.NewDiagnosticsTool(lspClients))
100	}
101
102	cwd := cfg.WorkingDir()
103	allTools := []tools.BaseTool{
104		tools.NewBashTool(permissions, cwd),
105		tools.NewDownloadTool(permissions, cwd),
106		tools.NewEditTool(lspClients, permissions, history, cwd),
107		tools.NewFetchTool(permissions, cwd),
108		tools.NewGlobTool(cwd),
109		tools.NewGrepTool(cwd),
110		tools.NewLsTool(cwd),
111		tools.NewSourcegraphTool(),
112		tools.NewViewTool(lspClients, cwd),
113		tools.NewWriteTool(lspClients, permissions, history, cwd),
114	}
115
116	if agentCfg.ID == "coder" {
117		taskAgentCfg := config.Get().Agents["task"]
118		if taskAgentCfg.ID == "" {
119			return nil, fmt.Errorf("task agent not found in config")
120		}
121		taskAgent, err := NewAgent(taskAgentCfg, permissions, sessions, messages, history, lspClients)
122		if err != nil {
123			return nil, fmt.Errorf("failed to create task agent: %w", err)
124		}
125
126		allTools = append(
127			allTools,
128			NewAgentTool(
129				taskAgent,
130				sessions,
131				messages,
132			),
133		)
134	}
135
136	allTools = append(allTools, otherTools...)
137	providerCfg := config.Get().GetProviderForModel(agentCfg.Model)
138	if providerCfg == nil {
139		return nil, fmt.Errorf("provider for agent %s not found in config", agentCfg.Name)
140	}
141	model := config.Get().GetModelByType(agentCfg.Model)
142
143	if model == nil {
144		return nil, fmt.Errorf("model not found for agent %s", agentCfg.Name)
145	}
146
147	promptID := agentPromptMap[agentCfg.ID]
148	if promptID == "" {
149		promptID = prompt.PromptDefault
150	}
151	opts := []provider.ProviderClientOption{
152		provider.WithModel(agentCfg.Model),
153		provider.WithSystemMessage(prompt.GetPrompt(promptID, providerCfg.ID, config.Get().Options.ContextPaths...)),
154	}
155	agentProvider, err := provider.NewProvider(*providerCfg, opts...)
156	if err != nil {
157		return nil, err
158	}
159
160	smallModelCfg := cfg.Models[config.SelectedModelTypeSmall]
161	var smallModelProviderCfg *config.ProviderConfig
162	if smallModelCfg.Provider == providerCfg.ID {
163		smallModelProviderCfg = providerCfg
164	} else {
165		smallModelProviderCfg = cfg.GetProviderForModel(config.SelectedModelTypeSmall)
166
167		if smallModelProviderCfg.ID == "" {
168			return nil, fmt.Errorf("provider %s not found in config", smallModelCfg.Provider)
169		}
170	}
171	smallModel := cfg.GetModelByType(config.SelectedModelTypeSmall)
172	if smallModel.ID == "" {
173		return nil, fmt.Errorf("model %s not found in provider %s", smallModelCfg.Model, smallModelProviderCfg.ID)
174	}
175
176	titleOpts := []provider.ProviderClientOption{
177		provider.WithModel(config.SelectedModelTypeSmall),
178		provider.WithSystemMessage(prompt.GetPrompt(prompt.PromptTitle, smallModelProviderCfg.ID)),
179	}
180	titleProvider, err := provider.NewProvider(*smallModelProviderCfg, titleOpts...)
181	if err != nil {
182		return nil, err
183	}
184	summarizeOpts := []provider.ProviderClientOption{
185		provider.WithModel(config.SelectedModelTypeSmall),
186		provider.WithSystemMessage(prompt.GetPrompt(prompt.PromptSummarizer, smallModelProviderCfg.ID)),
187	}
188	summarizeProvider, err := provider.NewProvider(*smallModelProviderCfg, summarizeOpts...)
189	if err != nil {
190		return nil, err
191	}
192
193	agentTools := []tools.BaseTool{}
194	if agentCfg.AllowedTools == nil {
195		agentTools = allTools
196	} else {
197		for _, tool := range allTools {
198			if slices.Contains(agentCfg.AllowedTools, tool.Name()) {
199				agentTools = append(agentTools, tool)
200			}
201		}
202	}
203
204	agent := &agent{
205		Broker:              pubsub.NewBroker[AgentEvent](),
206		agentCfg:            agentCfg,
207		provider:            agentProvider,
208		providerID:          string(providerCfg.ID),
209		messages:            messages,
210		sessions:            sessions,
211		tools:               agentTools,
212		titleProvider:       titleProvider,
213		summarizeProvider:   summarizeProvider,
214		summarizeProviderID: string(smallModelProviderCfg.ID),
215		activeRequests:      sync.Map{},
216	}
217
218	return agent, nil
219}
220
221func (a *agent) Model() fur.Model {
222	return *config.Get().GetModelByType(a.agentCfg.Model)
223}
224
225func (a *agent) Cancel(sessionID string) {
226	// Cancel regular requests
227	if cancelFunc, exists := a.activeRequests.LoadAndDelete(sessionID); exists {
228		if cancel, ok := cancelFunc.(context.CancelFunc); ok {
229			slog.Info(fmt.Sprintf("Request cancellation initiated for session: %s", sessionID))
230			cancel()
231		}
232	}
233
234	// Also check for summarize requests
235	if cancelFunc, exists := a.activeRequests.LoadAndDelete(sessionID + "-summarize"); exists {
236		if cancel, ok := cancelFunc.(context.CancelFunc); ok {
237			slog.Info(fmt.Sprintf("Summarize cancellation initiated for session: %s", sessionID))
238			cancel()
239		}
240	}
241}
242
243func (a *agent) IsBusy() bool {
244	busy := false
245	a.activeRequests.Range(func(key, value any) bool {
246		if cancelFunc, ok := value.(context.CancelFunc); ok {
247			if cancelFunc != nil {
248				busy = true
249				return false
250			}
251		}
252		return true
253	})
254	return busy
255}
256
257func (a *agent) IsSessionBusy(sessionID string) bool {
258	_, busy := a.activeRequests.Load(sessionID)
259	return busy
260}
261
262func (a *agent) generateTitle(ctx context.Context, sessionID string, content string) error {
263	if content == "" {
264		return nil
265	}
266	if a.titleProvider == nil {
267		return nil
268	}
269	session, err := a.sessions.Get(ctx, sessionID)
270	if err != nil {
271		return err
272	}
273	parts := []message.ContentPart{message.TextContent{
274		Text: fmt.Sprintf("Generate a concise title for the following content:\n\n%s", content),
275	}}
276
277	// Use streaming approach like summarization
278	response := a.titleProvider.StreamResponse(
279		ctx,
280		[]message.Message{
281			{
282				Role:  message.User,
283				Parts: parts,
284			},
285		},
286		make([]tools.BaseTool, 0),
287	)
288
289	var finalResponse *provider.ProviderResponse
290	for r := range response {
291		if r.Error != nil {
292			return r.Error
293		}
294		finalResponse = r.Response
295	}
296
297	if finalResponse == nil {
298		return fmt.Errorf("no response received from title provider")
299	}
300
301	title := strings.TrimSpace(strings.ReplaceAll(finalResponse.Content, "\n", " "))
302	if title == "" {
303		return nil
304	}
305
306	session.Title = title
307	_, err = a.sessions.Save(ctx, session)
308	return err
309}
310
311func (a *agent) err(err error) AgentEvent {
312	return AgentEvent{
313		Type:  AgentEventTypeError,
314		Error: err,
315	}
316}
317
318func (a *agent) Run(ctx context.Context, sessionID string, content string, attachments ...message.Attachment) (<-chan AgentEvent, error) {
319	if !a.Model().SupportsImages && attachments != nil {
320		attachments = nil
321	}
322	events := make(chan AgentEvent)
323	if a.IsSessionBusy(sessionID) {
324		return nil, ErrSessionBusy
325	}
326
327	genCtx, cancel := context.WithCancel(ctx)
328
329	a.activeRequests.Store(sessionID, cancel)
330	go func() {
331		slog.Debug("Request started", "sessionID", sessionID)
332		defer log.RecoverPanic("agent.Run", func() {
333			events <- a.err(fmt.Errorf("panic while running the agent"))
334		})
335		var attachmentParts []message.ContentPart
336		for _, attachment := range attachments {
337			attachmentParts = append(attachmentParts, message.BinaryContent{Path: attachment.FilePath, MIMEType: attachment.MimeType, Data: attachment.Content})
338		}
339		result := a.processGeneration(genCtx, sessionID, content, attachmentParts)
340		if result.Error != nil && !errors.Is(result.Error, ErrRequestCancelled) && !errors.Is(result.Error, context.Canceled) {
341			slog.Error(result.Error.Error())
342		}
343		slog.Debug("Request completed", "sessionID", sessionID)
344		a.activeRequests.Delete(sessionID)
345		cancel()
346		a.Publish(pubsub.CreatedEvent, result)
347		events <- result
348		close(events)
349	}()
350	return events, nil
351}
352
353func (a *agent) processGeneration(ctx context.Context, sessionID, content string, attachmentParts []message.ContentPart) AgentEvent {
354	cfg := config.Get()
355	// List existing messages; if none, start title generation asynchronously.
356	msgs, err := a.messages.List(ctx, sessionID)
357	if err != nil {
358		return a.err(fmt.Errorf("failed to list messages: %w", err))
359	}
360	if len(msgs) == 0 {
361		go func() {
362			defer log.RecoverPanic("agent.Run", func() {
363				slog.Error("panic while generating title")
364			})
365			titleErr := a.generateTitle(context.Background(), sessionID, content)
366			if titleErr != nil && !errors.Is(titleErr, context.Canceled) && !errors.Is(titleErr, context.DeadlineExceeded) {
367				slog.Error(fmt.Sprintf("failed to generate title: %v", titleErr))
368			}
369		}()
370	}
371	session, err := a.sessions.Get(ctx, sessionID)
372	if err != nil {
373		return a.err(fmt.Errorf("failed to get session: %w", err))
374	}
375	if session.SummaryMessageID != "" {
376		summaryMsgInex := -1
377		for i, msg := range msgs {
378			if msg.ID == session.SummaryMessageID {
379				summaryMsgInex = i
380				break
381			}
382		}
383		if summaryMsgInex != -1 {
384			msgs = msgs[summaryMsgInex:]
385			msgs[0].Role = message.User
386		}
387	}
388
389	userMsg, err := a.createUserMessage(ctx, sessionID, content, attachmentParts)
390	if err != nil {
391		return a.err(fmt.Errorf("failed to create user message: %w", err))
392	}
393	// Append the new user message to the conversation history.
394	msgHistory := append(msgs, userMsg)
395
396	for {
397		// Check for cancellation before each iteration
398		select {
399		case <-ctx.Done():
400			return a.err(ctx.Err())
401		default:
402			// Continue processing
403		}
404		agentMessage, toolResults, err := a.streamAndHandleEvents(ctx, sessionID, msgHistory)
405		if err != nil {
406			if errors.Is(err, context.Canceled) {
407				agentMessage.AddFinish(message.FinishReasonCanceled, "Request cancelled", "")
408				a.messages.Update(context.Background(), agentMessage)
409				return a.err(ErrRequestCancelled)
410			}
411			return a.err(fmt.Errorf("failed to process events: %w", err))
412		}
413		if cfg.Options.Debug {
414			slog.Info("Result", "message", agentMessage.FinishReason(), "toolResults", toolResults)
415		}
416		if (agentMessage.FinishReason() == message.FinishReasonToolUse) && toolResults != nil {
417			// We are not done, we need to respond with the tool response
418			msgHistory = append(msgHistory, agentMessage, *toolResults)
419			continue
420		}
421		return AgentEvent{
422			Type:    AgentEventTypeResponse,
423			Message: agentMessage,
424			Done:    true,
425		}
426	}
427}
428
429func (a *agent) createUserMessage(ctx context.Context, sessionID, content string, attachmentParts []message.ContentPart) (message.Message, error) {
430	parts := []message.ContentPart{message.TextContent{Text: content}}
431	parts = append(parts, attachmentParts...)
432	return a.messages.Create(ctx, sessionID, message.CreateMessageParams{
433		Role:  message.User,
434		Parts: parts,
435	})
436}
437
438func (a *agent) streamAndHandleEvents(ctx context.Context, sessionID string, msgHistory []message.Message) (message.Message, *message.Message, error) {
439	ctx = context.WithValue(ctx, tools.SessionIDContextKey, sessionID)
440	eventChan := a.provider.StreamResponse(ctx, msgHistory, a.tools)
441
442	assistantMsg, err := a.messages.Create(ctx, sessionID, message.CreateMessageParams{
443		Role:     message.Assistant,
444		Parts:    []message.ContentPart{},
445		Model:    a.Model().ID,
446		Provider: a.providerID,
447	})
448	if err != nil {
449		return assistantMsg, nil, fmt.Errorf("failed to create assistant message: %w", err)
450	}
451
452	// Add the session and message ID into the context if needed by tools.
453	ctx = context.WithValue(ctx, tools.MessageIDContextKey, assistantMsg.ID)
454
455	// Process each event in the stream.
456	for event := range eventChan {
457		if processErr := a.processEvent(ctx, sessionID, &assistantMsg, event); processErr != nil {
458			if errors.Is(processErr, context.Canceled) {
459				a.finishMessage(context.Background(), &assistantMsg, message.FinishReasonCanceled, "Request cancelled", "")
460			} else {
461				a.finishMessage(ctx, &assistantMsg, message.FinishReasonError, "API Error", processErr.Error())
462			}
463			return assistantMsg, nil, processErr
464		}
465		if ctx.Err() != nil {
466			a.finishMessage(context.Background(), &assistantMsg, message.FinishReasonCanceled, "Request cancelled", "")
467			return assistantMsg, nil, ctx.Err()
468		}
469	}
470
471	toolResults := make([]message.ToolResult, len(assistantMsg.ToolCalls()))
472	toolCalls := assistantMsg.ToolCalls()
473	for i, toolCall := range toolCalls {
474		select {
475		case <-ctx.Done():
476			a.finishMessage(context.Background(), &assistantMsg, message.FinishReasonCanceled, "Request cancelled", "")
477			// Make all future tool calls cancelled
478			for j := i; j < len(toolCalls); j++ {
479				toolResults[j] = message.ToolResult{
480					ToolCallID: toolCalls[j].ID,
481					Content:    "Tool execution canceled by user",
482					IsError:    true,
483				}
484			}
485			goto out
486		default:
487			// Continue processing
488			var tool tools.BaseTool
489			for _, availableTool := range a.tools {
490				if availableTool.Info().Name == toolCall.Name {
491					tool = availableTool
492					break
493				}
494			}
495
496			// Tool not found
497			if tool == nil {
498				toolResults[i] = message.ToolResult{
499					ToolCallID: toolCall.ID,
500					Content:    fmt.Sprintf("Tool not found: %s", toolCall.Name),
501					IsError:    true,
502				}
503				continue
504			}
505
506			// Run tool in goroutine to allow cancellation
507			type toolExecResult struct {
508				response tools.ToolResponse
509				err      error
510			}
511			resultChan := make(chan toolExecResult, 1)
512
513			go func() {
514				response, err := tool.Run(ctx, tools.ToolCall{
515					ID:    toolCall.ID,
516					Name:  toolCall.Name,
517					Input: toolCall.Input,
518				})
519				resultChan <- toolExecResult{response: response, err: err}
520			}()
521
522			var toolResponse tools.ToolResponse
523			var toolErr error
524
525			select {
526			case <-ctx.Done():
527				a.finishMessage(context.Background(), &assistantMsg, message.FinishReasonCanceled, "Request cancelled", "")
528				// Mark remaining tool calls as cancelled
529				for j := i; j < len(toolCalls); j++ {
530					toolResults[j] = message.ToolResult{
531						ToolCallID: toolCalls[j].ID,
532						Content:    "Tool execution canceled by user",
533						IsError:    true,
534					}
535				}
536				goto out
537			case result := <-resultChan:
538				toolResponse = result.response
539				toolErr = result.err
540			}
541
542			if toolErr != nil {
543				slog.Error("Tool execution error", "toolCall", toolCall.ID, "error", toolErr)
544				if errors.Is(toolErr, permission.ErrorPermissionDenied) {
545					toolResults[i] = message.ToolResult{
546						ToolCallID: toolCall.ID,
547						Content:    "Permission denied",
548						IsError:    true,
549					}
550					for j := i + 1; j < len(toolCalls); j++ {
551						toolResults[j] = message.ToolResult{
552							ToolCallID: toolCalls[j].ID,
553							Content:    "Tool execution canceled by user",
554							IsError:    true,
555						}
556					}
557					a.finishMessage(ctx, &assistantMsg, message.FinishReasonPermissionDenied, "Permission denied", "")
558					break
559				}
560			}
561			toolResults[i] = message.ToolResult{
562				ToolCallID: toolCall.ID,
563				Content:    toolResponse.Content,
564				Metadata:   toolResponse.Metadata,
565				IsError:    toolResponse.IsError,
566			}
567		}
568	}
569out:
570	if len(toolResults) == 0 {
571		return assistantMsg, nil, nil
572	}
573	parts := make([]message.ContentPart, 0)
574	for _, tr := range toolResults {
575		parts = append(parts, tr)
576	}
577	msg, err := a.messages.Create(context.Background(), assistantMsg.SessionID, message.CreateMessageParams{
578		Role:     message.Tool,
579		Parts:    parts,
580		Provider: a.providerID,
581	})
582	if err != nil {
583		return assistantMsg, nil, fmt.Errorf("failed to create cancelled tool message: %w", err)
584	}
585
586	return assistantMsg, &msg, err
587}
588
589func (a *agent) finishMessage(ctx context.Context, msg *message.Message, finishReason message.FinishReason, message, details string) {
590	msg.AddFinish(finishReason, message, details)
591	_ = a.messages.Update(ctx, *msg)
592}
593
594func (a *agent) processEvent(ctx context.Context, sessionID string, assistantMsg *message.Message, event provider.ProviderEvent) error {
595	select {
596	case <-ctx.Done():
597		return ctx.Err()
598	default:
599		// Continue processing.
600	}
601
602	switch event.Type {
603	case provider.EventThinkingDelta:
604		assistantMsg.AppendReasoningContent(event.Thinking)
605		return a.messages.Update(ctx, *assistantMsg)
606	case provider.EventSignatureDelta:
607		assistantMsg.AppendReasoningSignature(event.Signature)
608		return a.messages.Update(ctx, *assistantMsg)
609	case provider.EventContentDelta:
610		assistantMsg.FinishThinking()
611		assistantMsg.AppendContent(event.Content)
612		return a.messages.Update(ctx, *assistantMsg)
613	case provider.EventToolUseStart:
614		assistantMsg.FinishThinking()
615		slog.Info("Tool call started", "toolCall", event.ToolCall)
616		assistantMsg.AddToolCall(*event.ToolCall)
617		return a.messages.Update(ctx, *assistantMsg)
618	case provider.EventToolUseDelta:
619		assistantMsg.AppendToolCallInput(event.ToolCall.ID, event.ToolCall.Input)
620		return a.messages.Update(ctx, *assistantMsg)
621	case provider.EventToolUseStop:
622		slog.Info("Finished tool call", "toolCall", event.ToolCall)
623		assistantMsg.FinishToolCall(event.ToolCall.ID)
624		return a.messages.Update(ctx, *assistantMsg)
625	case provider.EventError:
626		return event.Error
627	case provider.EventComplete:
628		assistantMsg.FinishThinking()
629		assistantMsg.SetToolCalls(event.Response.ToolCalls)
630		assistantMsg.AddFinish(event.Response.FinishReason, "", "")
631		if err := a.messages.Update(ctx, *assistantMsg); err != nil {
632			return fmt.Errorf("failed to update message: %w", err)
633		}
634		return a.TrackUsage(ctx, sessionID, a.Model(), event.Response.Usage)
635	}
636
637	return nil
638}
639
640func (a *agent) TrackUsage(ctx context.Context, sessionID string, model fur.Model, usage provider.TokenUsage) error {
641	sess, err := a.sessions.Get(ctx, sessionID)
642	if err != nil {
643		return fmt.Errorf("failed to get session: %w", err)
644	}
645
646	cost := model.CostPer1MInCached/1e6*float64(usage.CacheCreationTokens) +
647		model.CostPer1MOutCached/1e6*float64(usage.CacheReadTokens) +
648		model.CostPer1MIn/1e6*float64(usage.InputTokens) +
649		model.CostPer1MOut/1e6*float64(usage.OutputTokens)
650
651	sess.Cost += cost
652	sess.CompletionTokens = usage.OutputTokens + usage.CacheReadTokens
653	sess.PromptTokens = usage.InputTokens + usage.CacheCreationTokens
654
655	_, err = a.sessions.Save(ctx, sess)
656	if err != nil {
657		return fmt.Errorf("failed to save session: %w", err)
658	}
659	return nil
660}
661
662func (a *agent) Summarize(ctx context.Context, sessionID string) error {
663	if a.summarizeProvider == nil {
664		return fmt.Errorf("summarize provider not available")
665	}
666
667	// Check if session is busy
668	if a.IsSessionBusy(sessionID) {
669		return ErrSessionBusy
670	}
671
672	// Create a new context with cancellation
673	summarizeCtx, cancel := context.WithCancel(ctx)
674
675	// Store the cancel function in activeRequests to allow cancellation
676	a.activeRequests.Store(sessionID+"-summarize", cancel)
677
678	go func() {
679		defer a.activeRequests.Delete(sessionID + "-summarize")
680		defer cancel()
681		event := AgentEvent{
682			Type:     AgentEventTypeSummarize,
683			Progress: "Starting summarization...",
684		}
685
686		a.Publish(pubsub.CreatedEvent, event)
687		// Get all messages from the session
688		msgs, err := a.messages.List(summarizeCtx, sessionID)
689		if err != nil {
690			event = AgentEvent{
691				Type:  AgentEventTypeError,
692				Error: fmt.Errorf("failed to list messages: %w", err),
693				Done:  true,
694			}
695			a.Publish(pubsub.CreatedEvent, event)
696			return
697		}
698		summarizeCtx = context.WithValue(summarizeCtx, tools.SessionIDContextKey, sessionID)
699
700		if len(msgs) == 0 {
701			event = AgentEvent{
702				Type:  AgentEventTypeError,
703				Error: fmt.Errorf("no messages to summarize"),
704				Done:  true,
705			}
706			a.Publish(pubsub.CreatedEvent, event)
707			return
708		}
709
710		event = AgentEvent{
711			Type:     AgentEventTypeSummarize,
712			Progress: "Analyzing conversation...",
713		}
714		a.Publish(pubsub.CreatedEvent, event)
715
716		// Add a system message to guide the summarization
717		summarizePrompt := "Provide a detailed but concise summary of our conversation above. Focus on information that would be helpful for continuing the conversation, including what we did, what we're doing, which files we're working on, and what we're going to do next."
718
719		// Create a new message with the summarize prompt
720		promptMsg := message.Message{
721			Role:  message.User,
722			Parts: []message.ContentPart{message.TextContent{Text: summarizePrompt}},
723		}
724
725		// Append the prompt to the messages
726		msgsWithPrompt := append(msgs, promptMsg)
727
728		event = AgentEvent{
729			Type:     AgentEventTypeSummarize,
730			Progress: "Generating summary...",
731		}
732
733		a.Publish(pubsub.CreatedEvent, event)
734
735		// Send the messages to the summarize provider
736		response := a.summarizeProvider.StreamResponse(
737			summarizeCtx,
738			msgsWithPrompt,
739			make([]tools.BaseTool, 0),
740		)
741		var finalResponse *provider.ProviderResponse
742		for r := range response {
743			if r.Error != nil {
744				event = AgentEvent{
745					Type:  AgentEventTypeError,
746					Error: fmt.Errorf("failed to summarize: %w", err),
747					Done:  true,
748				}
749				a.Publish(pubsub.CreatedEvent, event)
750				return
751			}
752			finalResponse = r.Response
753		}
754
755		summary := strings.TrimSpace(finalResponse.Content)
756		if summary == "" {
757			event = AgentEvent{
758				Type:  AgentEventTypeError,
759				Error: fmt.Errorf("empty summary returned"),
760				Done:  true,
761			}
762			a.Publish(pubsub.CreatedEvent, event)
763			return
764		}
765		event = AgentEvent{
766			Type:     AgentEventTypeSummarize,
767			Progress: "Creating new session...",
768		}
769
770		a.Publish(pubsub.CreatedEvent, event)
771		oldSession, err := a.sessions.Get(summarizeCtx, sessionID)
772		if err != nil {
773			event = AgentEvent{
774				Type:  AgentEventTypeError,
775				Error: fmt.Errorf("failed to get session: %w", err),
776				Done:  true,
777			}
778
779			a.Publish(pubsub.CreatedEvent, event)
780			return
781		}
782		// Create a message in the new session with the summary
783		msg, err := a.messages.Create(summarizeCtx, oldSession.ID, message.CreateMessageParams{
784			Role: message.Assistant,
785			Parts: []message.ContentPart{
786				message.TextContent{Text: summary},
787				message.Finish{
788					Reason: message.FinishReasonEndTurn,
789					Time:   time.Now().Unix(),
790				},
791			},
792			Model:    a.summarizeProvider.Model().ID,
793			Provider: a.summarizeProviderID,
794		})
795		if err != nil {
796			event = AgentEvent{
797				Type:  AgentEventTypeError,
798				Error: fmt.Errorf("failed to create summary message: %w", err),
799				Done:  true,
800			}
801
802			a.Publish(pubsub.CreatedEvent, event)
803			return
804		}
805		oldSession.SummaryMessageID = msg.ID
806		oldSession.CompletionTokens = finalResponse.Usage.OutputTokens
807		oldSession.PromptTokens = 0
808		model := a.summarizeProvider.Model()
809		usage := finalResponse.Usage
810		cost := model.CostPer1MInCached/1e6*float64(usage.CacheCreationTokens) +
811			model.CostPer1MOutCached/1e6*float64(usage.CacheReadTokens) +
812			model.CostPer1MIn/1e6*float64(usage.InputTokens) +
813			model.CostPer1MOut/1e6*float64(usage.OutputTokens)
814		oldSession.Cost += cost
815		_, err = a.sessions.Save(summarizeCtx, oldSession)
816		if err != nil {
817			event = AgentEvent{
818				Type:  AgentEventTypeError,
819				Error: fmt.Errorf("failed to save session: %w", err),
820				Done:  true,
821			}
822			a.Publish(pubsub.CreatedEvent, event)
823		}
824
825		event = AgentEvent{
826			Type:      AgentEventTypeSummarize,
827			SessionID: oldSession.ID,
828			Progress:  "Summary complete",
829			Done:      true,
830		}
831		a.Publish(pubsub.CreatedEvent, event)
832		// Send final success event with the new session ID
833	}()
834
835	return nil
836}
837
838func (a *agent) CancelAll() {
839	if !a.IsBusy() {
840		return
841	}
842	a.activeRequests.Range(func(key, value any) bool {
843		a.Cancel(key.(string)) // key is sessionID
844		return true
845	})
846
847	timeout := time.After(5 * time.Second)
848	for a.IsBusy() {
849		select {
850		case <-timeout:
851			return
852		default:
853			time.Sleep(200 * time.Millisecond)
854		}
855	}
856}
857
858func (a *agent) UpdateModel() error {
859	cfg := config.Get()
860
861	// Get current provider configuration
862	currentProviderCfg := cfg.GetProviderForModel(a.agentCfg.Model)
863	if currentProviderCfg == nil || currentProviderCfg.ID == "" {
864		return fmt.Errorf("provider for agent %s not found in config", a.agentCfg.Name)
865	}
866
867	// Check if provider has changed
868	if string(currentProviderCfg.ID) != a.providerID {
869		// Provider changed, need to recreate the main provider
870		model := cfg.GetModelByType(a.agentCfg.Model)
871		if model.ID == "" {
872			return fmt.Errorf("model not found for agent %s", a.agentCfg.Name)
873		}
874
875		promptID := agentPromptMap[a.agentCfg.ID]
876		if promptID == "" {
877			promptID = prompt.PromptDefault
878		}
879
880		opts := []provider.ProviderClientOption{
881			provider.WithModel(a.agentCfg.Model),
882			provider.WithSystemMessage(prompt.GetPrompt(promptID, currentProviderCfg.ID, cfg.Options.ContextPaths...)),
883		}
884
885		newProvider, err := provider.NewProvider(*currentProviderCfg, opts...)
886		if err != nil {
887			return fmt.Errorf("failed to create new provider: %w", err)
888		}
889
890		// Update the provider and provider ID
891		a.provider = newProvider
892		a.providerID = string(currentProviderCfg.ID)
893	}
894
895	// Check if small model provider has changed (affects title and summarize providers)
896	smallModelCfg := cfg.Models[config.SelectedModelTypeSmall]
897	var smallModelProviderCfg config.ProviderConfig
898
899	for _, p := range cfg.Providers {
900		if p.ID == smallModelCfg.Provider {
901			smallModelProviderCfg = p
902			break
903		}
904	}
905
906	if smallModelProviderCfg.ID == "" {
907		return fmt.Errorf("provider %s not found in config", smallModelCfg.Provider)
908	}
909
910	// Check if summarize provider has changed
911	if string(smallModelProviderCfg.ID) != a.summarizeProviderID {
912		smallModel := cfg.GetModelByType(config.SelectedModelTypeSmall)
913		if smallModel == nil {
914			return fmt.Errorf("model %s not found in provider %s", smallModelCfg.Model, smallModelProviderCfg.ID)
915		}
916
917		// Recreate title provider
918		titleOpts := []provider.ProviderClientOption{
919			provider.WithModel(config.SelectedModelTypeSmall),
920			provider.WithSystemMessage(prompt.GetPrompt(prompt.PromptTitle, smallModelProviderCfg.ID)),
921			// We want the title to be short, so we limit the max tokens
922			provider.WithMaxTokens(40),
923		}
924		newTitleProvider, err := provider.NewProvider(smallModelProviderCfg, titleOpts...)
925		if err != nil {
926			return fmt.Errorf("failed to create new title provider: %w", err)
927		}
928
929		// Recreate summarize provider
930		summarizeOpts := []provider.ProviderClientOption{
931			provider.WithModel(config.SelectedModelTypeSmall),
932			provider.WithSystemMessage(prompt.GetPrompt(prompt.PromptSummarizer, smallModelProviderCfg.ID)),
933		}
934		newSummarizeProvider, err := provider.NewProvider(smallModelProviderCfg, summarizeOpts...)
935		if err != nil {
936			return fmt.Errorf("failed to create new summarize provider: %w", err)
937		}
938
939		// Update the providers and provider ID
940		a.titleProvider = newTitleProvider
941		a.summarizeProvider = newSummarizeProvider
942		a.summarizeProviderID = string(smallModelProviderCfg.ID)
943	}
944
945	return nil
946}