1package agent
2
3import (
4 "context"
5 "errors"
6 "fmt"
7 "log/slog"
8 "slices"
9 "strings"
10 "sync"
11 "sync/atomic"
12 "time"
13
14 "github.com/charmbracelet/crush/internal/config"
15 fur "github.com/charmbracelet/crush/internal/fur/provider"
16 "github.com/charmbracelet/crush/internal/history"
17 "github.com/charmbracelet/crush/internal/llm/prompt"
18 "github.com/charmbracelet/crush/internal/llm/provider"
19 "github.com/charmbracelet/crush/internal/llm/tools"
20 "github.com/charmbracelet/crush/internal/log"
21 "github.com/charmbracelet/crush/internal/lsp"
22 "github.com/charmbracelet/crush/internal/message"
23 "github.com/charmbracelet/crush/internal/permission"
24 "github.com/charmbracelet/crush/internal/pubsub"
25 "github.com/charmbracelet/crush/internal/session"
26)
27
// ErrRequestCancelled is reported (inside an error AgentEvent) when a request
// is aborted through its context while it is being processed.
var ErrRequestCancelled = errors.New("request canceled by user")

// ErrSessionBusy is returned by Run and Summarize when the session already
// has work in flight.
var ErrSessionBusy = errors.New("session is currently processing another request")
33
// AgentEventType tells a consumer which fields of an AgentEvent are meaningful.
type AgentEventType string

// AgentEventTypeError marks an event whose Error field holds a failure.
const AgentEventTypeError AgentEventType = "error"

// AgentEventTypeResponse marks an event carrying the final assistant Message.
const AgentEventTypeResponse AgentEventType = "response"

// AgentEventTypeSummarize marks a summarization progress/completion event.
const AgentEventTypeSummarize AgentEventType = "summarize"
41
// AgentEvent is what the agent emits: Run delivers one on the channel it
// returns, and events are also published through the agent's pubsub broker.
type AgentEvent struct {
	// Type selects which of the remaining fields are meaningful.
	Type AgentEventType
	// Message is the final assistant message of an AgentEventTypeResponse event.
	Message message.Message
	// Error is set on AgentEventTypeError events.
	Error error

	// When summarizing
	// SessionID is set on the final, successful summarize event.
	SessionID string
	// Progress is a human-readable status line for summarize events.
	Progress string
	// Done is set on Run's final response event and on the terminal
	// (success or failure) events of a summarization.
	Done bool
}
52
53type Service interface {
54 pubsub.Suscriber[AgentEvent]
55 Model() fur.Model
56 Run(ctx context.Context, sessionID string, content string, attachments ...message.Attachment) (<-chan AgentEvent, error)
57 Cancel(sessionID string)
58 CancelAll()
59 IsSessionBusy(sessionID string) bool
60 IsBusy() bool
61 Summarize(ctx context.Context, sessionID string) error
62 UpdateModel() error
63}
64
// agent is the concrete Service implementation.
type agent struct {
	// Broker publishes AgentEvents to subscribers.
	*pubsub.Broker[AgentEvent]
	agentCfg config.Agent
	sessions session.Service
	messages message.Service

	// toolsDone is set once the background tool setup started by NewAgent has
	// assigned tools; tools must not be read before it is true.
	toolsDone atomic.Bool
	tools []tools.BaseTool

	// provider serves the main conversation; providerID is its config ID.
	provider provider.Provider
	providerID string

	// titleProvider and summarizeProvider both use the small model;
	// summarizeProviderID is the config ID of the provider backing them.
	titleProvider provider.Provider
	summarizeProvider provider.Provider
	summarizeProviderID string

	// activeRequests maps a session ID (or "<sessionID>-summarize") to the
	// context.CancelFunc of the request currently running for it.
	activeRequests sync.Map
}
83
84var agentPromptMap = map[string]prompt.PromptID{
85 "coder": prompt.PromptCoder,
86 "task": prompt.PromptTask,
87}
88
89func NewAgent(
90 agentCfg config.Agent,
91 // These services are needed in the tools
92 permissions permission.Service,
93 sessions session.Service,
94 messages message.Service,
95 history history.Service,
96 lspClients map[string]*lsp.Client,
97) (Service, error) {
98 ctx := context.Background()
99 cfg := config.Get()
100
101 var agentTool tools.BaseTool
102 if agentCfg.ID == "coder" {
103 taskAgentCfg := config.Get().Agents["task"]
104 if taskAgentCfg.ID == "" {
105 return nil, fmt.Errorf("task agent not found in config")
106 }
107 taskAgent, err := NewAgent(taskAgentCfg, permissions, sessions, messages, history, lspClients)
108 if err != nil {
109 return nil, fmt.Errorf("failed to create task agent: %w", err)
110 }
111
112 agentTool = NewAgentTool(taskAgent, sessions, messages)
113 }
114
115 providerCfg := config.Get().GetProviderForModel(agentCfg.Model)
116 if providerCfg == nil {
117 return nil, fmt.Errorf("provider for agent %s not found in config", agentCfg.Name)
118 }
119 model := config.Get().GetModelByType(agentCfg.Model)
120
121 if model == nil {
122 return nil, fmt.Errorf("model not found for agent %s", agentCfg.Name)
123 }
124
125 promptID := agentPromptMap[agentCfg.ID]
126 if promptID == "" {
127 promptID = prompt.PromptDefault
128 }
129 opts := []provider.ProviderClientOption{
130 provider.WithModel(agentCfg.Model),
131 provider.WithSystemMessage(prompt.GetPrompt(promptID, providerCfg.ID, config.Get().Options.ContextPaths...)),
132 }
133 agentProvider, err := provider.NewProvider(*providerCfg, opts...)
134 if err != nil {
135 return nil, err
136 }
137
138 smallModelCfg := cfg.Models[config.SelectedModelTypeSmall]
139 var smallModelProviderCfg *config.ProviderConfig
140 if smallModelCfg.Provider == providerCfg.ID {
141 smallModelProviderCfg = providerCfg
142 } else {
143 smallModelProviderCfg = cfg.GetProviderForModel(config.SelectedModelTypeSmall)
144
145 if smallModelProviderCfg.ID == "" {
146 return nil, fmt.Errorf("provider %s not found in config", smallModelCfg.Provider)
147 }
148 }
149 smallModel := cfg.GetModelByType(config.SelectedModelTypeSmall)
150 if smallModel.ID == "" {
151 return nil, fmt.Errorf("model %s not found in provider %s", smallModelCfg.Model, smallModelProviderCfg.ID)
152 }
153
154 titleOpts := []provider.ProviderClientOption{
155 provider.WithModel(config.SelectedModelTypeSmall),
156 provider.WithSystemMessage(prompt.GetPrompt(prompt.PromptTitle, smallModelProviderCfg.ID)),
157 }
158 titleProvider, err := provider.NewProvider(*smallModelProviderCfg, titleOpts...)
159 if err != nil {
160 return nil, err
161 }
162 summarizeOpts := []provider.ProviderClientOption{
163 provider.WithModel(config.SelectedModelTypeSmall),
164 provider.WithSystemMessage(prompt.GetPrompt(prompt.PromptSummarizer, smallModelProviderCfg.ID)),
165 }
166 summarizeProvider, err := provider.NewProvider(*smallModelProviderCfg, summarizeOpts...)
167 if err != nil {
168 return nil, err
169 }
170
171 agent := &agent{
172 Broker: pubsub.NewBroker[AgentEvent](),
173 agentCfg: agentCfg,
174 provider: agentProvider,
175 providerID: string(providerCfg.ID),
176 messages: messages,
177 sessions: sessions,
178 titleProvider: titleProvider,
179 summarizeProvider: summarizeProvider,
180 summarizeProviderID: string(smallModelProviderCfg.ID),
181 activeRequests: sync.Map{},
182 }
183
184 go func() {
185 slog.Info("Initializing agent tools", "agent", agentCfg.ID)
186 defer func() {
187 slog.Info("Initialized agent tools", "agent", agentCfg.ID)
188 agent.toolsDone.Store(true)
189 }()
190
191 cwd := cfg.WorkingDir()
192 allTools := []tools.BaseTool{
193 tools.NewBashTool(permissions, cwd),
194 tools.NewDownloadTool(permissions, cwd),
195 tools.NewEditTool(lspClients, permissions, history, cwd),
196 tools.NewFetchTool(permissions, cwd),
197 tools.NewGlobTool(cwd),
198 tools.NewGrepTool(cwd),
199 tools.NewLsTool(cwd),
200 tools.NewSourcegraphTool(),
201 tools.NewViewTool(lspClients, cwd),
202 tools.NewWriteTool(lspClients, permissions, history, cwd),
203 }
204
205 mcpTools := GetMCPTools(ctx, permissions, cfg)
206 allTools = append(allTools, mcpTools...)
207
208 if len(lspClients) > 0 {
209 allTools = append(allTools, tools.NewDiagnosticsTool(lspClients))
210 }
211
212 if agentTool != nil {
213 allTools = append(allTools, agentTool)
214 }
215
216 if agentCfg.AllowedTools == nil {
217 agent.tools = allTools
218 return
219 }
220
221 var filteredTools []tools.BaseTool
222 for _, tool := range allTools {
223 if slices.Contains(agentCfg.AllowedTools, tool.Name()) {
224 filteredTools = append(filteredTools, tool)
225 }
226 }
227 agent.tools = filteredTools
228 }()
229
230 return agent, nil
231}
232
233func (a *agent) Model() fur.Model {
234 return *config.Get().GetModelByType(a.agentCfg.Model)
235}
236
237func (a *agent) Cancel(sessionID string) {
238 // Cancel regular requests
239 if cancelFunc, exists := a.activeRequests.LoadAndDelete(sessionID); exists {
240 if cancel, ok := cancelFunc.(context.CancelFunc); ok {
241 slog.Info(fmt.Sprintf("Request cancellation initiated for session: %s", sessionID))
242 cancel()
243 }
244 }
245
246 // Also check for summarize requests
247 if cancelFunc, exists := a.activeRequests.LoadAndDelete(sessionID + "-summarize"); exists {
248 if cancel, ok := cancelFunc.(context.CancelFunc); ok {
249 slog.Info(fmt.Sprintf("Summarize cancellation initiated for session: %s", sessionID))
250 cancel()
251 }
252 }
253}
254
255func (a *agent) IsBusy() bool {
256 busy := false
257 a.activeRequests.Range(func(key, value any) bool {
258 if cancelFunc, ok := value.(context.CancelFunc); ok {
259 if cancelFunc != nil {
260 busy = true
261 return false
262 }
263 }
264 return true
265 })
266 return busy
267}
268
269func (a *agent) IsSessionBusy(sessionID string) bool {
270 _, busy := a.activeRequests.Load(sessionID)
271 return busy
272}
273
274func (a *agent) generateTitle(ctx context.Context, sessionID string, content string) error {
275 if content == "" {
276 return nil
277 }
278 if a.titleProvider == nil {
279 return nil
280 }
281 session, err := a.sessions.Get(ctx, sessionID)
282 if err != nil {
283 return err
284 }
285 parts := []message.ContentPart{message.TextContent{
286 Text: fmt.Sprintf("Generate a concise title for the following content:\n\n%s", content),
287 }}
288
289 // Use streaming approach like summarization
290 response := a.titleProvider.StreamResponse(
291 ctx,
292 []message.Message{
293 {
294 Role: message.User,
295 Parts: parts,
296 },
297 },
298 make([]tools.BaseTool, 0),
299 )
300
301 var finalResponse *provider.ProviderResponse
302 for r := range response {
303 if r.Error != nil {
304 return r.Error
305 }
306 finalResponse = r.Response
307 }
308
309 if finalResponse == nil {
310 return fmt.Errorf("no response received from title provider")
311 }
312
313 title := strings.TrimSpace(strings.ReplaceAll(finalResponse.Content, "\n", " "))
314 if title == "" {
315 return nil
316 }
317
318 session.Title = title
319 _, err = a.sessions.Save(ctx, session)
320 return err
321}
322
323func (a *agent) err(err error) AgentEvent {
324 return AgentEvent{
325 Type: AgentEventTypeError,
326 Error: err,
327 }
328}
329
330func (a *agent) Run(ctx context.Context, sessionID string, content string, attachments ...message.Attachment) (<-chan AgentEvent, error) {
331 if !a.Model().SupportsImages && attachments != nil {
332 attachments = nil
333 }
334 events := make(chan AgentEvent)
335 if a.IsSessionBusy(sessionID) {
336 return nil, ErrSessionBusy
337 }
338
339 genCtx, cancel := context.WithCancel(ctx)
340
341 a.activeRequests.Store(sessionID, cancel)
342 go func() {
343 slog.Debug("Request started", "sessionID", sessionID)
344 defer log.RecoverPanic("agent.Run", func() {
345 events <- a.err(fmt.Errorf("panic while running the agent"))
346 })
347 var attachmentParts []message.ContentPart
348 for _, attachment := range attachments {
349 attachmentParts = append(attachmentParts, message.BinaryContent{Path: attachment.FilePath, MIMEType: attachment.MimeType, Data: attachment.Content})
350 }
351 result := a.processGeneration(genCtx, sessionID, content, attachmentParts)
352 if result.Error != nil && !errors.Is(result.Error, ErrRequestCancelled) && !errors.Is(result.Error, context.Canceled) {
353 slog.Error(result.Error.Error())
354 }
355 slog.Debug("Request completed", "sessionID", sessionID)
356 a.activeRequests.Delete(sessionID)
357 cancel()
358 a.Publish(pubsub.CreatedEvent, result)
359 events <- result
360 close(events)
361 }()
362 return events, nil
363}
364
365func (a *agent) processGeneration(ctx context.Context, sessionID, content string, attachmentParts []message.ContentPart) AgentEvent {
366 cfg := config.Get()
367 // List existing messages; if none, start title generation asynchronously.
368 msgs, err := a.messages.List(ctx, sessionID)
369 if err != nil {
370 return a.err(fmt.Errorf("failed to list messages: %w", err))
371 }
372 if len(msgs) == 0 {
373 go func() {
374 defer log.RecoverPanic("agent.Run", func() {
375 slog.Error("panic while generating title")
376 })
377 titleErr := a.generateTitle(context.Background(), sessionID, content)
378 if titleErr != nil && !errors.Is(titleErr, context.Canceled) && !errors.Is(titleErr, context.DeadlineExceeded) {
379 slog.Error(fmt.Sprintf("failed to generate title: %v", titleErr))
380 }
381 }()
382 }
383 session, err := a.sessions.Get(ctx, sessionID)
384 if err != nil {
385 return a.err(fmt.Errorf("failed to get session: %w", err))
386 }
387 if session.SummaryMessageID != "" {
388 summaryMsgInex := -1
389 for i, msg := range msgs {
390 if msg.ID == session.SummaryMessageID {
391 summaryMsgInex = i
392 break
393 }
394 }
395 if summaryMsgInex != -1 {
396 msgs = msgs[summaryMsgInex:]
397 msgs[0].Role = message.User
398 }
399 }
400
401 userMsg, err := a.createUserMessage(ctx, sessionID, content, attachmentParts)
402 if err != nil {
403 return a.err(fmt.Errorf("failed to create user message: %w", err))
404 }
405 // Append the new user message to the conversation history.
406 msgHistory := append(msgs, userMsg)
407
408 for {
409 // Check for cancellation before each iteration
410 select {
411 case <-ctx.Done():
412 return a.err(ctx.Err())
413 default:
414 // Continue processing
415 }
416 agentMessage, toolResults, err := a.streamAndHandleEvents(ctx, sessionID, msgHistory)
417 if err != nil {
418 if errors.Is(err, context.Canceled) {
419 agentMessage.AddFinish(message.FinishReasonCanceled, "Request cancelled", "")
420 a.messages.Update(context.Background(), agentMessage)
421 return a.err(ErrRequestCancelled)
422 }
423 return a.err(fmt.Errorf("failed to process events: %w", err))
424 }
425 if cfg.Options.Debug {
426 slog.Info("Result", "message", agentMessage.FinishReason(), "toolResults", toolResults)
427 }
428 if (agentMessage.FinishReason() == message.FinishReasonToolUse) && toolResults != nil {
429 // We are not done, we need to respond with the tool response
430 msgHistory = append(msgHistory, agentMessage, *toolResults)
431 continue
432 }
433 return AgentEvent{
434 Type: AgentEventTypeResponse,
435 Message: agentMessage,
436 Done: true,
437 }
438 }
439}
440
441func (a *agent) createUserMessage(ctx context.Context, sessionID, content string, attachmentParts []message.ContentPart) (message.Message, error) {
442 parts := []message.ContentPart{message.TextContent{Text: content}}
443 parts = append(parts, attachmentParts...)
444 return a.messages.Create(ctx, sessionID, message.CreateMessageParams{
445 Role: message.User,
446 Parts: parts,
447 })
448}
449
450func (a *agent) streamAndHandleEvents(ctx context.Context, sessionID string, msgHistory []message.Message) (message.Message, *message.Message, error) {
451 ctx = context.WithValue(ctx, tools.SessionIDContextKey, sessionID)
452 if !a.toolsDone.Load() {
453 return message.Message{}, nil, fmt.Errorf("agent is still initializing, please wait a moment and try again")
454 }
455 eventChan := a.provider.StreamResponse(ctx, msgHistory, a.tools)
456
457 assistantMsg, err := a.messages.Create(ctx, sessionID, message.CreateMessageParams{
458 Role: message.Assistant,
459 Parts: []message.ContentPart{},
460 Model: a.Model().ID,
461 Provider: a.providerID,
462 })
463 if err != nil {
464 return assistantMsg, nil, fmt.Errorf("failed to create assistant message: %w", err)
465 }
466
467 // Add the session and message ID into the context if needed by tools.
468 ctx = context.WithValue(ctx, tools.MessageIDContextKey, assistantMsg.ID)
469
470 // Process each event in the stream.
471 for event := range eventChan {
472 if processErr := a.processEvent(ctx, sessionID, &assistantMsg, event); processErr != nil {
473 if errors.Is(processErr, context.Canceled) {
474 a.finishMessage(context.Background(), &assistantMsg, message.FinishReasonCanceled, "Request cancelled", "")
475 } else {
476 a.finishMessage(ctx, &assistantMsg, message.FinishReasonError, "API Error", processErr.Error())
477 }
478 return assistantMsg, nil, processErr
479 }
480 if ctx.Err() != nil {
481 a.finishMessage(context.Background(), &assistantMsg, message.FinishReasonCanceled, "Request cancelled", "")
482 return assistantMsg, nil, ctx.Err()
483 }
484 }
485
486 toolResults := make([]message.ToolResult, len(assistantMsg.ToolCalls()))
487 toolCalls := assistantMsg.ToolCalls()
488 for i, toolCall := range toolCalls {
489 select {
490 case <-ctx.Done():
491 a.finishMessage(context.Background(), &assistantMsg, message.FinishReasonCanceled, "Request cancelled", "")
492 // Make all future tool calls cancelled
493 for j := i; j < len(toolCalls); j++ {
494 toolResults[j] = message.ToolResult{
495 ToolCallID: toolCalls[j].ID,
496 Content: "Tool execution canceled by user",
497 IsError: true,
498 }
499 }
500 goto out
501 default:
502 // Continue processing
503 var tool tools.BaseTool
504 for _, availableTool := range a.tools {
505 if availableTool.Info().Name == toolCall.Name {
506 tool = availableTool
507 break
508 }
509 }
510
511 // Tool not found
512 if tool == nil {
513 toolResults[i] = message.ToolResult{
514 ToolCallID: toolCall.ID,
515 Content: fmt.Sprintf("Tool not found: %s", toolCall.Name),
516 IsError: true,
517 }
518 continue
519 }
520
521 // Run tool in goroutine to allow cancellation
522 type toolExecResult struct {
523 response tools.ToolResponse
524 err error
525 }
526 resultChan := make(chan toolExecResult, 1)
527
528 go func() {
529 response, err := tool.Run(ctx, tools.ToolCall{
530 ID: toolCall.ID,
531 Name: toolCall.Name,
532 Input: toolCall.Input,
533 })
534 resultChan <- toolExecResult{response: response, err: err}
535 }()
536
537 var toolResponse tools.ToolResponse
538 var toolErr error
539
540 select {
541 case <-ctx.Done():
542 a.finishMessage(context.Background(), &assistantMsg, message.FinishReasonCanceled, "Request cancelled", "")
543 // Mark remaining tool calls as cancelled
544 for j := i; j < len(toolCalls); j++ {
545 toolResults[j] = message.ToolResult{
546 ToolCallID: toolCalls[j].ID,
547 Content: "Tool execution canceled by user",
548 IsError: true,
549 }
550 }
551 goto out
552 case result := <-resultChan:
553 toolResponse = result.response
554 toolErr = result.err
555 }
556
557 if toolErr != nil {
558 slog.Error("Tool execution error", "toolCall", toolCall.ID, "error", toolErr)
559 if errors.Is(toolErr, permission.ErrorPermissionDenied) {
560 toolResults[i] = message.ToolResult{
561 ToolCallID: toolCall.ID,
562 Content: "Permission denied",
563 IsError: true,
564 }
565 for j := i + 1; j < len(toolCalls); j++ {
566 toolResults[j] = message.ToolResult{
567 ToolCallID: toolCalls[j].ID,
568 Content: "Tool execution canceled by user",
569 IsError: true,
570 }
571 }
572 a.finishMessage(ctx, &assistantMsg, message.FinishReasonPermissionDenied, "Permission denied", "")
573 break
574 }
575 }
576 toolResults[i] = message.ToolResult{
577 ToolCallID: toolCall.ID,
578 Content: toolResponse.Content,
579 Metadata: toolResponse.Metadata,
580 IsError: toolResponse.IsError,
581 }
582 }
583 }
584out:
585 if len(toolResults) == 0 {
586 return assistantMsg, nil, nil
587 }
588 parts := make([]message.ContentPart, 0)
589 for _, tr := range toolResults {
590 parts = append(parts, tr)
591 }
592 msg, err := a.messages.Create(context.Background(), assistantMsg.SessionID, message.CreateMessageParams{
593 Role: message.Tool,
594 Parts: parts,
595 Provider: a.providerID,
596 })
597 if err != nil {
598 return assistantMsg, nil, fmt.Errorf("failed to create cancelled tool message: %w", err)
599 }
600
601 return assistantMsg, &msg, err
602}
603
604func (a *agent) finishMessage(ctx context.Context, msg *message.Message, finishReason message.FinishReason, message, details string) {
605 msg.AddFinish(finishReason, message, details)
606 _ = a.messages.Update(ctx, *msg)
607}
608
609func (a *agent) processEvent(ctx context.Context, sessionID string, assistantMsg *message.Message, event provider.ProviderEvent) error {
610 select {
611 case <-ctx.Done():
612 return ctx.Err()
613 default:
614 // Continue processing.
615 }
616
617 switch event.Type {
618 case provider.EventThinkingDelta:
619 assistantMsg.AppendReasoningContent(event.Thinking)
620 return a.messages.Update(ctx, *assistantMsg)
621 case provider.EventSignatureDelta:
622 assistantMsg.AppendReasoningSignature(event.Signature)
623 return a.messages.Update(ctx, *assistantMsg)
624 case provider.EventContentDelta:
625 assistantMsg.FinishThinking()
626 assistantMsg.AppendContent(event.Content)
627 return a.messages.Update(ctx, *assistantMsg)
628 case provider.EventToolUseStart:
629 assistantMsg.FinishThinking()
630 slog.Info("Tool call started", "toolCall", event.ToolCall)
631 assistantMsg.AddToolCall(*event.ToolCall)
632 return a.messages.Update(ctx, *assistantMsg)
633 case provider.EventToolUseDelta:
634 assistantMsg.AppendToolCallInput(event.ToolCall.ID, event.ToolCall.Input)
635 return a.messages.Update(ctx, *assistantMsg)
636 case provider.EventToolUseStop:
637 slog.Info("Finished tool call", "toolCall", event.ToolCall)
638 assistantMsg.FinishToolCall(event.ToolCall.ID)
639 return a.messages.Update(ctx, *assistantMsg)
640 case provider.EventError:
641 return event.Error
642 case provider.EventComplete:
643 assistantMsg.FinishThinking()
644 assistantMsg.SetToolCalls(event.Response.ToolCalls)
645 assistantMsg.AddFinish(event.Response.FinishReason, "", "")
646 if err := a.messages.Update(ctx, *assistantMsg); err != nil {
647 return fmt.Errorf("failed to update message: %w", err)
648 }
649 return a.TrackUsage(ctx, sessionID, a.Model(), event.Response.Usage)
650 }
651
652 return nil
653}
654
655func (a *agent) TrackUsage(ctx context.Context, sessionID string, model fur.Model, usage provider.TokenUsage) error {
656 sess, err := a.sessions.Get(ctx, sessionID)
657 if err != nil {
658 return fmt.Errorf("failed to get session: %w", err)
659 }
660
661 cost := model.CostPer1MInCached/1e6*float64(usage.CacheCreationTokens) +
662 model.CostPer1MOutCached/1e6*float64(usage.CacheReadTokens) +
663 model.CostPer1MIn/1e6*float64(usage.InputTokens) +
664 model.CostPer1MOut/1e6*float64(usage.OutputTokens)
665
666 sess.Cost += cost
667 sess.CompletionTokens = usage.OutputTokens + usage.CacheReadTokens
668 sess.PromptTokens = usage.InputTokens + usage.CacheCreationTokens
669
670 _, err = a.sessions.Save(ctx, sess)
671 if err != nil {
672 return fmt.Errorf("failed to save session: %w", err)
673 }
674 return nil
675}
676
677func (a *agent) Summarize(ctx context.Context, sessionID string) error {
678 if a.summarizeProvider == nil {
679 return fmt.Errorf("summarize provider not available")
680 }
681
682 // Check if session is busy
683 if a.IsSessionBusy(sessionID) {
684 return ErrSessionBusy
685 }
686
687 // Create a new context with cancellation
688 summarizeCtx, cancel := context.WithCancel(ctx)
689
690 // Store the cancel function in activeRequests to allow cancellation
691 a.activeRequests.Store(sessionID+"-summarize", cancel)
692
693 go func() {
694 defer a.activeRequests.Delete(sessionID + "-summarize")
695 defer cancel()
696 event := AgentEvent{
697 Type: AgentEventTypeSummarize,
698 Progress: "Starting summarization...",
699 }
700
701 a.Publish(pubsub.CreatedEvent, event)
702 // Get all messages from the session
703 msgs, err := a.messages.List(summarizeCtx, sessionID)
704 if err != nil {
705 event = AgentEvent{
706 Type: AgentEventTypeError,
707 Error: fmt.Errorf("failed to list messages: %w", err),
708 Done: true,
709 }
710 a.Publish(pubsub.CreatedEvent, event)
711 return
712 }
713 summarizeCtx = context.WithValue(summarizeCtx, tools.SessionIDContextKey, sessionID)
714
715 if len(msgs) == 0 {
716 event = AgentEvent{
717 Type: AgentEventTypeError,
718 Error: fmt.Errorf("no messages to summarize"),
719 Done: true,
720 }
721 a.Publish(pubsub.CreatedEvent, event)
722 return
723 }
724
725 event = AgentEvent{
726 Type: AgentEventTypeSummarize,
727 Progress: "Analyzing conversation...",
728 }
729 a.Publish(pubsub.CreatedEvent, event)
730
731 // Add a system message to guide the summarization
732 summarizePrompt := "Provide a detailed but concise summary of our conversation above. Focus on information that would be helpful for continuing the conversation, including what we did, what we're doing, which files we're working on, and what we're going to do next."
733
734 // Create a new message with the summarize prompt
735 promptMsg := message.Message{
736 Role: message.User,
737 Parts: []message.ContentPart{message.TextContent{Text: summarizePrompt}},
738 }
739
740 // Append the prompt to the messages
741 msgsWithPrompt := append(msgs, promptMsg)
742
743 event = AgentEvent{
744 Type: AgentEventTypeSummarize,
745 Progress: "Generating summary...",
746 }
747
748 a.Publish(pubsub.CreatedEvent, event)
749
750 // Send the messages to the summarize provider
751 response := a.summarizeProvider.StreamResponse(
752 summarizeCtx,
753 msgsWithPrompt,
754 make([]tools.BaseTool, 0),
755 )
756 var finalResponse *provider.ProviderResponse
757 for r := range response {
758 if r.Error != nil {
759 event = AgentEvent{
760 Type: AgentEventTypeError,
761 Error: fmt.Errorf("failed to summarize: %w", err),
762 Done: true,
763 }
764 a.Publish(pubsub.CreatedEvent, event)
765 return
766 }
767 finalResponse = r.Response
768 }
769
770 summary := strings.TrimSpace(finalResponse.Content)
771 if summary == "" {
772 event = AgentEvent{
773 Type: AgentEventTypeError,
774 Error: fmt.Errorf("empty summary returned"),
775 Done: true,
776 }
777 a.Publish(pubsub.CreatedEvent, event)
778 return
779 }
780 event = AgentEvent{
781 Type: AgentEventTypeSummarize,
782 Progress: "Creating new session...",
783 }
784
785 a.Publish(pubsub.CreatedEvent, event)
786 oldSession, err := a.sessions.Get(summarizeCtx, sessionID)
787 if err != nil {
788 event = AgentEvent{
789 Type: AgentEventTypeError,
790 Error: fmt.Errorf("failed to get session: %w", err),
791 Done: true,
792 }
793
794 a.Publish(pubsub.CreatedEvent, event)
795 return
796 }
797 // Create a message in the new session with the summary
798 msg, err := a.messages.Create(summarizeCtx, oldSession.ID, message.CreateMessageParams{
799 Role: message.Assistant,
800 Parts: []message.ContentPart{
801 message.TextContent{Text: summary},
802 message.Finish{
803 Reason: message.FinishReasonEndTurn,
804 Time: time.Now().Unix(),
805 },
806 },
807 Model: a.summarizeProvider.Model().ID,
808 Provider: a.summarizeProviderID,
809 })
810 if err != nil {
811 event = AgentEvent{
812 Type: AgentEventTypeError,
813 Error: fmt.Errorf("failed to create summary message: %w", err),
814 Done: true,
815 }
816
817 a.Publish(pubsub.CreatedEvent, event)
818 return
819 }
820 oldSession.SummaryMessageID = msg.ID
821 oldSession.CompletionTokens = finalResponse.Usage.OutputTokens
822 oldSession.PromptTokens = 0
823 model := a.summarizeProvider.Model()
824 usage := finalResponse.Usage
825 cost := model.CostPer1MInCached/1e6*float64(usage.CacheCreationTokens) +
826 model.CostPer1MOutCached/1e6*float64(usage.CacheReadTokens) +
827 model.CostPer1MIn/1e6*float64(usage.InputTokens) +
828 model.CostPer1MOut/1e6*float64(usage.OutputTokens)
829 oldSession.Cost += cost
830 _, err = a.sessions.Save(summarizeCtx, oldSession)
831 if err != nil {
832 event = AgentEvent{
833 Type: AgentEventTypeError,
834 Error: fmt.Errorf("failed to save session: %w", err),
835 Done: true,
836 }
837 a.Publish(pubsub.CreatedEvent, event)
838 }
839
840 event = AgentEvent{
841 Type: AgentEventTypeSummarize,
842 SessionID: oldSession.ID,
843 Progress: "Summary complete",
844 Done: true,
845 }
846 a.Publish(pubsub.CreatedEvent, event)
847 // Send final success event with the new session ID
848 }()
849
850 return nil
851}
852
853func (a *agent) CancelAll() {
854 if !a.IsBusy() {
855 return
856 }
857 a.activeRequests.Range(func(key, value any) bool {
858 a.Cancel(key.(string)) // key is sessionID
859 return true
860 })
861
862 timeout := time.After(5 * time.Second)
863 for a.IsBusy() {
864 select {
865 case <-timeout:
866 return
867 default:
868 time.Sleep(200 * time.Millisecond)
869 }
870 }
871}
872
873func (a *agent) UpdateModel() error {
874 cfg := config.Get()
875
876 // Get current provider configuration
877 currentProviderCfg := cfg.GetProviderForModel(a.agentCfg.Model)
878 if currentProviderCfg == nil || currentProviderCfg.ID == "" {
879 return fmt.Errorf("provider for agent %s not found in config", a.agentCfg.Name)
880 }
881
882 // Check if provider has changed
883 if string(currentProviderCfg.ID) != a.providerID {
884 // Provider changed, need to recreate the main provider
885 model := cfg.GetModelByType(a.agentCfg.Model)
886 if model.ID == "" {
887 return fmt.Errorf("model not found for agent %s", a.agentCfg.Name)
888 }
889
890 promptID := agentPromptMap[a.agentCfg.ID]
891 if promptID == "" {
892 promptID = prompt.PromptDefault
893 }
894
895 opts := []provider.ProviderClientOption{
896 provider.WithModel(a.agentCfg.Model),
897 provider.WithSystemMessage(prompt.GetPrompt(promptID, currentProviderCfg.ID, cfg.Options.ContextPaths...)),
898 }
899
900 newProvider, err := provider.NewProvider(*currentProviderCfg, opts...)
901 if err != nil {
902 return fmt.Errorf("failed to create new provider: %w", err)
903 }
904
905 // Update the provider and provider ID
906 a.provider = newProvider
907 a.providerID = string(currentProviderCfg.ID)
908 }
909
910 // Check if small model provider has changed (affects title and summarize providers)
911 smallModelCfg := cfg.Models[config.SelectedModelTypeSmall]
912 var smallModelProviderCfg config.ProviderConfig
913
914 for _, p := range cfg.Providers.Seq2() {
915 if p.ID == smallModelCfg.Provider {
916 smallModelProviderCfg = p
917 break
918 }
919 }
920
921 if smallModelProviderCfg.ID == "" {
922 return fmt.Errorf("provider %s not found in config", smallModelCfg.Provider)
923 }
924
925 // Check if summarize provider has changed
926 if string(smallModelProviderCfg.ID) != a.summarizeProviderID {
927 smallModel := cfg.GetModelByType(config.SelectedModelTypeSmall)
928 if smallModel == nil {
929 return fmt.Errorf("model %s not found in provider %s", smallModelCfg.Model, smallModelProviderCfg.ID)
930 }
931
932 // Recreate title provider
933 titleOpts := []provider.ProviderClientOption{
934 provider.WithModel(config.SelectedModelTypeSmall),
935 provider.WithSystemMessage(prompt.GetPrompt(prompt.PromptTitle, smallModelProviderCfg.ID)),
936 // We want the title to be short, so we limit the max tokens
937 provider.WithMaxTokens(40),
938 }
939 newTitleProvider, err := provider.NewProvider(smallModelProviderCfg, titleOpts...)
940 if err != nil {
941 return fmt.Errorf("failed to create new title provider: %w", err)
942 }
943
944 // Recreate summarize provider
945 summarizeOpts := []provider.ProviderClientOption{
946 provider.WithModel(config.SelectedModelTypeSmall),
947 provider.WithSystemMessage(prompt.GetPrompt(prompt.PromptSummarizer, smallModelProviderCfg.ID)),
948 }
949 newSummarizeProvider, err := provider.NewProvider(smallModelProviderCfg, summarizeOpts...)
950 if err != nil {
951 return fmt.Errorf("failed to create new summarize provider: %w", err)
952 }
953
954 // Update the providers and provider ID
955 a.titleProvider = newTitleProvider
956 a.summarizeProvider = newSummarizeProvider
957 a.summarizeProviderID = string(smallModelProviderCfg.ID)
958 }
959
960 return nil
961}