1package agent
2
3import (
4 "context"
5 "errors"
6 "fmt"
7 "log/slog"
8 "slices"
9 "strings"
10 "sync"
11 "time"
12
13 "github.com/charmbracelet/crush/internal/config"
14 "github.com/charmbracelet/crush/internal/csync"
15 fur "github.com/charmbracelet/crush/internal/fur/provider"
16 "github.com/charmbracelet/crush/internal/history"
17 "github.com/charmbracelet/crush/internal/llm/prompt"
18 "github.com/charmbracelet/crush/internal/llm/provider"
19 "github.com/charmbracelet/crush/internal/llm/tools"
20 "github.com/charmbracelet/crush/internal/log"
21 "github.com/charmbracelet/crush/internal/lsp"
22 "github.com/charmbracelet/crush/internal/message"
23 "github.com/charmbracelet/crush/internal/permission"
24 "github.com/charmbracelet/crush/internal/pubsub"
25 "github.com/charmbracelet/crush/internal/session"
26)
27
// Common errors returned by the agent.
var (
	// ErrRequestCancelled is reported (as the Error of an error event) when a
	// Run request's context is cancelled, e.g. via Cancel or CancelAll.
	ErrRequestCancelled = errors.New("request canceled by user")
	// ErrSessionBusy is returned by Run and Summarize when the session already
	// has a request in flight.
	ErrSessionBusy = errors.New("session is currently processing another request")
)
33
// AgentEventType identifies the kind of an AgentEvent.
type AgentEventType string

const (
	// AgentEventTypeError reports a failure; the event's Error field is set.
	AgentEventTypeError AgentEventType = "error"
	// AgentEventTypeResponse carries the final assistant message of a Run.
	AgentEventTypeResponse AgentEventType = "response"
	// AgentEventTypeSummarize reports progress of a Summarize call.
	AgentEventTypeSummarize AgentEventType = "summarize"
)
41
// AgentEvent is published by the agent (and delivered on the channel returned
// by Run) to report responses, errors, and summarization progress.
type AgentEvent struct {
	Type    AgentEventType
	Message message.Message
	Error   error

	// When summarizing. Done is also set on the final Run response and on
	// terminal Summarize error events.
	SessionID string
	Progress  string
	Done      bool
}
52
// Service is the public interface of an agent. Responses, errors, and
// summarization progress are published as AgentEvents to subscribers.
type Service interface {
	pubsub.Suscriber[AgentEvent]
	// Model returns the model configured for this agent.
	Model() fur.Model
	// Run processes content (plus optional attachments) as a new user message
	// in sessionID; the returned channel delivers the request's final event.
	Run(ctx context.Context, sessionID string, content string, attachments ...message.Attachment) (<-chan AgentEvent, error)
	// Cancel aborts the running request and/or summarization for sessionID.
	Cancel(sessionID string)
	// CancelAll aborts every running request and waits briefly for them.
	CancelAll()
	// IsSessionBusy reports whether a Run request is in flight for sessionID.
	IsSessionBusy(sessionID string) bool
	// IsBusy reports whether any request is in flight.
	IsBusy() bool
	// Summarize starts an asynchronous summarization of sessionID; progress
	// is reported through published events.
	Summarize(ctx context.Context, sessionID string) error
	// UpdateModel recreates providers whose configured provider changed.
	UpdateModel() error
}
64
// agent is the Service implementation returned by NewAgent.
type agent struct {
	*pubsub.Broker[AgentEvent]
	agentCfg config.Agent
	sessions session.Service
	messages message.Service

	// tools is produced by toolFn through csync.NewLazySlice (see NewAgent).
	tools *csync.LazySlice[tools.BaseTool]

	// provider serves the agent's model; providerID is its provider config ID.
	provider   provider.Provider
	providerID string

	// titleProvider and summarizeProvider are built for the small model;
	// summarizeProviderID is the provider config ID they were built from.
	titleProvider       provider.Provider
	summarizeProvider   provider.Provider
	summarizeProviderID string

	// activeRequests maps a session ID (or "<sessionID>-summarize") to the
	// context.CancelFunc of the request currently running for it.
	activeRequests sync.Map
}
82
// agentPromptMap selects the system prompt by agent ID. Agents not listed
// here fall back to prompt.PromptDefault.
var agentPromptMap = map[string]prompt.PromptID{
	"coder": prompt.PromptCoder,
	"task":  prompt.PromptTask,
}
87
88func NewAgent(
89 agentCfg config.Agent,
90 // These services are needed in the tools
91 permissions permission.Service,
92 sessions session.Service,
93 messages message.Service,
94 history history.Service,
95 lspClients map[string]*lsp.Client,
96) (Service, error) {
97 ctx := context.Background()
98 cfg := config.Get()
99
100 var agentTool tools.BaseTool
101 if agentCfg.ID == "coder" {
102 taskAgentCfg := config.Get().Agents["task"]
103 if taskAgentCfg.ID == "" {
104 return nil, fmt.Errorf("task agent not found in config")
105 }
106 taskAgent, err := NewAgent(taskAgentCfg, permissions, sessions, messages, history, lspClients)
107 if err != nil {
108 return nil, fmt.Errorf("failed to create task agent: %w", err)
109 }
110
111 agentTool = NewAgentTool(taskAgent, sessions, messages)
112 }
113
114 providerCfg := config.Get().GetProviderForModel(agentCfg.Model)
115 if providerCfg == nil {
116 return nil, fmt.Errorf("provider for agent %s not found in config", agentCfg.Name)
117 }
118 model := config.Get().GetModelByType(agentCfg.Model)
119
120 if model == nil {
121 return nil, fmt.Errorf("model not found for agent %s", agentCfg.Name)
122 }
123
124 promptID := agentPromptMap[agentCfg.ID]
125 if promptID == "" {
126 promptID = prompt.PromptDefault
127 }
128 opts := []provider.ProviderClientOption{
129 provider.WithModel(agentCfg.Model),
130 provider.WithSystemMessage(prompt.GetPrompt(promptID, providerCfg.ID, config.Get().Options.ContextPaths...)),
131 }
132 agentProvider, err := provider.NewProvider(*providerCfg, opts...)
133 if err != nil {
134 return nil, err
135 }
136
137 smallModelCfg := cfg.Models[config.SelectedModelTypeSmall]
138 var smallModelProviderCfg *config.ProviderConfig
139 if smallModelCfg.Provider == providerCfg.ID {
140 smallModelProviderCfg = providerCfg
141 } else {
142 smallModelProviderCfg = cfg.GetProviderForModel(config.SelectedModelTypeSmall)
143
144 if smallModelProviderCfg.ID == "" {
145 return nil, fmt.Errorf("provider %s not found in config", smallModelCfg.Provider)
146 }
147 }
148 smallModel := cfg.GetModelByType(config.SelectedModelTypeSmall)
149 if smallModel.ID == "" {
150 return nil, fmt.Errorf("model %s not found in provider %s", smallModelCfg.Model, smallModelProviderCfg.ID)
151 }
152
153 titleOpts := []provider.ProviderClientOption{
154 provider.WithModel(config.SelectedModelTypeSmall),
155 provider.WithSystemMessage(prompt.GetPrompt(prompt.PromptTitle, smallModelProviderCfg.ID)),
156 }
157 titleProvider, err := provider.NewProvider(*smallModelProviderCfg, titleOpts...)
158 if err != nil {
159 return nil, err
160 }
161 summarizeOpts := []provider.ProviderClientOption{
162 provider.WithModel(config.SelectedModelTypeSmall),
163 provider.WithSystemMessage(prompt.GetPrompt(prompt.PromptSummarizer, smallModelProviderCfg.ID)),
164 }
165 summarizeProvider, err := provider.NewProvider(*smallModelProviderCfg, summarizeOpts...)
166 if err != nil {
167 return nil, err
168 }
169
170 toolFn := func() []tools.BaseTool {
171 slog.Info("Initializing agent tools", "agent", agentCfg.ID)
172 defer func() {
173 slog.Info("Initialized agent tools", "agent", agentCfg.ID)
174 }()
175
176 cwd := cfg.WorkingDir()
177 allTools := []tools.BaseTool{
178 tools.NewBashTool(permissions, cwd),
179 tools.NewDownloadTool(permissions, cwd),
180 tools.NewEditTool(lspClients, permissions, history, cwd),
181 tools.NewFetchTool(permissions, cwd),
182 tools.NewGlobTool(cwd),
183 tools.NewGrepTool(cwd),
184 tools.NewLsTool(cwd),
185 tools.NewSourcegraphTool(),
186 tools.NewViewTool(lspClients, cwd),
187 tools.NewWriteTool(lspClients, permissions, history, cwd),
188 }
189
190 mcpTools := GetMCPTools(ctx, permissions, cfg)
191 allTools = append(allTools, mcpTools...)
192
193 if len(lspClients) > 0 {
194 allTools = append(allTools, tools.NewDiagnosticsTool(lspClients))
195 }
196
197 if agentTool != nil {
198 allTools = append(allTools, agentTool)
199 }
200
201 if agentCfg.AllowedTools == nil {
202 return allTools
203 }
204
205 var filteredTools []tools.BaseTool
206 for _, tool := range allTools {
207 if slices.Contains(agentCfg.AllowedTools, tool.Name()) {
208 filteredTools = append(filteredTools, tool)
209 }
210 }
211 return filteredTools
212 }
213
214 return &agent{
215 Broker: pubsub.NewBroker[AgentEvent](),
216 agentCfg: agentCfg,
217 provider: agentProvider,
218 providerID: string(providerCfg.ID),
219 messages: messages,
220 sessions: sessions,
221 titleProvider: titleProvider,
222 summarizeProvider: summarizeProvider,
223 summarizeProviderID: string(smallModelProviderCfg.ID),
224 activeRequests: sync.Map{},
225 tools: csync.NewLazySlice(toolFn),
226 }, nil
227}
228
229func (a *agent) Model() fur.Model {
230 return *config.Get().GetModelByType(a.agentCfg.Model)
231}
232
233func (a *agent) Cancel(sessionID string) {
234 // Cancel regular requests
235 if cancelFunc, exists := a.activeRequests.LoadAndDelete(sessionID); exists {
236 if cancel, ok := cancelFunc.(context.CancelFunc); ok {
237 slog.Info(fmt.Sprintf("Request cancellation initiated for session: %s", sessionID))
238 cancel()
239 }
240 }
241
242 // Also check for summarize requests
243 if cancelFunc, exists := a.activeRequests.LoadAndDelete(sessionID + "-summarize"); exists {
244 if cancel, ok := cancelFunc.(context.CancelFunc); ok {
245 slog.Info(fmt.Sprintf("Summarize cancellation initiated for session: %s", sessionID))
246 cancel()
247 }
248 }
249}
250
251func (a *agent) IsBusy() bool {
252 busy := false
253 a.activeRequests.Range(func(key, value any) bool {
254 if cancelFunc, ok := value.(context.CancelFunc); ok {
255 if cancelFunc != nil {
256 busy = true
257 return false
258 }
259 }
260 return true
261 })
262 return busy
263}
264
265func (a *agent) IsSessionBusy(sessionID string) bool {
266 _, busy := a.activeRequests.Load(sessionID)
267 return busy
268}
269
270func (a *agent) generateTitle(ctx context.Context, sessionID string, content string) error {
271 if content == "" {
272 return nil
273 }
274 if a.titleProvider == nil {
275 return nil
276 }
277 session, err := a.sessions.Get(ctx, sessionID)
278 if err != nil {
279 return err
280 }
281 parts := []message.ContentPart{message.TextContent{
282 Text: fmt.Sprintf("Generate a concise title for the following content:\n\n%s", content),
283 }}
284
285 // Use streaming approach like summarization
286 response := a.titleProvider.StreamResponse(
287 ctx,
288 []message.Message{
289 {
290 Role: message.User,
291 Parts: parts,
292 },
293 },
294 make([]tools.BaseTool, 0),
295 )
296
297 var finalResponse *provider.ProviderResponse
298 for r := range response {
299 if r.Error != nil {
300 return r.Error
301 }
302 finalResponse = r.Response
303 }
304
305 if finalResponse == nil {
306 return fmt.Errorf("no response received from title provider")
307 }
308
309 title := strings.TrimSpace(strings.ReplaceAll(finalResponse.Content, "\n", " "))
310 if title == "" {
311 return nil
312 }
313
314 session.Title = title
315 _, err = a.sessions.Save(ctx, session)
316 return err
317}
318
319func (a *agent) err(err error) AgentEvent {
320 return AgentEvent{
321 Type: AgentEventTypeError,
322 Error: err,
323 }
324}
325
326func (a *agent) Run(ctx context.Context, sessionID string, content string, attachments ...message.Attachment) (<-chan AgentEvent, error) {
327 if !a.Model().SupportsImages && attachments != nil {
328 attachments = nil
329 }
330 events := make(chan AgentEvent)
331 if a.IsSessionBusy(sessionID) {
332 return nil, ErrSessionBusy
333 }
334
335 genCtx, cancel := context.WithCancel(ctx)
336
337 a.activeRequests.Store(sessionID, cancel)
338 go func() {
339 slog.Debug("Request started", "sessionID", sessionID)
340 defer log.RecoverPanic("agent.Run", func() {
341 events <- a.err(fmt.Errorf("panic while running the agent"))
342 })
343 var attachmentParts []message.ContentPart
344 for _, attachment := range attachments {
345 attachmentParts = append(attachmentParts, message.BinaryContent{Path: attachment.FilePath, MIMEType: attachment.MimeType, Data: attachment.Content})
346 }
347 result := a.processGeneration(genCtx, sessionID, content, attachmentParts)
348 if result.Error != nil && !errors.Is(result.Error, ErrRequestCancelled) && !errors.Is(result.Error, context.Canceled) {
349 slog.Error(result.Error.Error())
350 }
351 slog.Debug("Request completed", "sessionID", sessionID)
352 a.activeRequests.Delete(sessionID)
353 cancel()
354 a.Publish(pubsub.CreatedEvent, result)
355 events <- result
356 close(events)
357 }()
358 return events, nil
359}
360
361func (a *agent) processGeneration(ctx context.Context, sessionID, content string, attachmentParts []message.ContentPart) AgentEvent {
362 cfg := config.Get()
363 // List existing messages; if none, start title generation asynchronously.
364 msgs, err := a.messages.List(ctx, sessionID)
365 if err != nil {
366 return a.err(fmt.Errorf("failed to list messages: %w", err))
367 }
368 if len(msgs) == 0 {
369 go func() {
370 defer log.RecoverPanic("agent.Run", func() {
371 slog.Error("panic while generating title")
372 })
373 titleErr := a.generateTitle(context.Background(), sessionID, content)
374 if titleErr != nil && !errors.Is(titleErr, context.Canceled) && !errors.Is(titleErr, context.DeadlineExceeded) {
375 slog.Error(fmt.Sprintf("failed to generate title: %v", titleErr))
376 }
377 }()
378 }
379 session, err := a.sessions.Get(ctx, sessionID)
380 if err != nil {
381 return a.err(fmt.Errorf("failed to get session: %w", err))
382 }
383 if session.SummaryMessageID != "" {
384 summaryMsgInex := -1
385 for i, msg := range msgs {
386 if msg.ID == session.SummaryMessageID {
387 summaryMsgInex = i
388 break
389 }
390 }
391 if summaryMsgInex != -1 {
392 msgs = msgs[summaryMsgInex:]
393 msgs[0].Role = message.User
394 }
395 }
396
397 userMsg, err := a.createUserMessage(ctx, sessionID, content, attachmentParts)
398 if err != nil {
399 return a.err(fmt.Errorf("failed to create user message: %w", err))
400 }
401 // Append the new user message to the conversation history.
402 msgHistory := append(msgs, userMsg)
403
404 for {
405 // Check for cancellation before each iteration
406 select {
407 case <-ctx.Done():
408 return a.err(ctx.Err())
409 default:
410 // Continue processing
411 }
412 agentMessage, toolResults, err := a.streamAndHandleEvents(ctx, sessionID, msgHistory)
413 if err != nil {
414 if errors.Is(err, context.Canceled) {
415 agentMessage.AddFinish(message.FinishReasonCanceled, "Request cancelled", "")
416 a.messages.Update(context.Background(), agentMessage)
417 return a.err(ErrRequestCancelled)
418 }
419 return a.err(fmt.Errorf("failed to process events: %w", err))
420 }
421 if cfg.Options.Debug {
422 slog.Info("Result", "message", agentMessage.FinishReason(), "toolResults", toolResults)
423 }
424 if (agentMessage.FinishReason() == message.FinishReasonToolUse) && toolResults != nil {
425 // We are not done, we need to respond with the tool response
426 msgHistory = append(msgHistory, agentMessage, *toolResults)
427 continue
428 }
429 return AgentEvent{
430 Type: AgentEventTypeResponse,
431 Message: agentMessage,
432 Done: true,
433 }
434 }
435}
436
437func (a *agent) createUserMessage(ctx context.Context, sessionID, content string, attachmentParts []message.ContentPart) (message.Message, error) {
438 parts := []message.ContentPart{message.TextContent{Text: content}}
439 parts = append(parts, attachmentParts...)
440 return a.messages.Create(ctx, sessionID, message.CreateMessageParams{
441 Role: message.User,
442 Parts: parts,
443 })
444}
445
446func (a *agent) streamAndHandleEvents(ctx context.Context, sessionID string, msgHistory []message.Message) (message.Message, *message.Message, error) {
447 ctx = context.WithValue(ctx, tools.SessionIDContextKey, sessionID)
448 eventChan := a.provider.StreamResponse(ctx, msgHistory, slices.Collect(a.tools.Iter()))
449
450 assistantMsg, err := a.messages.Create(ctx, sessionID, message.CreateMessageParams{
451 Role: message.Assistant,
452 Parts: []message.ContentPart{},
453 Model: a.Model().ID,
454 Provider: a.providerID,
455 })
456 if err != nil {
457 return assistantMsg, nil, fmt.Errorf("failed to create assistant message: %w", err)
458 }
459
460 // Add the session and message ID into the context if needed by tools.
461 ctx = context.WithValue(ctx, tools.MessageIDContextKey, assistantMsg.ID)
462
463 // Process each event in the stream.
464 for event := range eventChan {
465 if processErr := a.processEvent(ctx, sessionID, &assistantMsg, event); processErr != nil {
466 if errors.Is(processErr, context.Canceled) {
467 a.finishMessage(context.Background(), &assistantMsg, message.FinishReasonCanceled, "Request cancelled", "")
468 } else {
469 a.finishMessage(ctx, &assistantMsg, message.FinishReasonError, "API Error", processErr.Error())
470 }
471 return assistantMsg, nil, processErr
472 }
473 if ctx.Err() != nil {
474 a.finishMessage(context.Background(), &assistantMsg, message.FinishReasonCanceled, "Request cancelled", "")
475 return assistantMsg, nil, ctx.Err()
476 }
477 }
478
479 toolResults := make([]message.ToolResult, len(assistantMsg.ToolCalls()))
480 toolCalls := assistantMsg.ToolCalls()
481 for i, toolCall := range toolCalls {
482 select {
483 case <-ctx.Done():
484 a.finishMessage(context.Background(), &assistantMsg, message.FinishReasonCanceled, "Request cancelled", "")
485 // Make all future tool calls cancelled
486 for j := i; j < len(toolCalls); j++ {
487 toolResults[j] = message.ToolResult{
488 ToolCallID: toolCalls[j].ID,
489 Content: "Tool execution canceled by user",
490 IsError: true,
491 }
492 }
493 goto out
494 default:
495 // Continue processing
496 var tool tools.BaseTool
497 for availableTool := range a.tools.Iter() {
498 if availableTool.Info().Name == toolCall.Name {
499 tool = availableTool
500 break
501 }
502 }
503
504 // Tool not found
505 if tool == nil {
506 toolResults[i] = message.ToolResult{
507 ToolCallID: toolCall.ID,
508 Content: fmt.Sprintf("Tool not found: %s", toolCall.Name),
509 IsError: true,
510 }
511 continue
512 }
513
514 // Run tool in goroutine to allow cancellation
515 type toolExecResult struct {
516 response tools.ToolResponse
517 err error
518 }
519 resultChan := make(chan toolExecResult, 1)
520
521 go func() {
522 response, err := tool.Run(ctx, tools.ToolCall{
523 ID: toolCall.ID,
524 Name: toolCall.Name,
525 Input: toolCall.Input,
526 })
527 resultChan <- toolExecResult{response: response, err: err}
528 }()
529
530 var toolResponse tools.ToolResponse
531 var toolErr error
532
533 select {
534 case <-ctx.Done():
535 a.finishMessage(context.Background(), &assistantMsg, message.FinishReasonCanceled, "Request cancelled", "")
536 // Mark remaining tool calls as cancelled
537 for j := i; j < len(toolCalls); j++ {
538 toolResults[j] = message.ToolResult{
539 ToolCallID: toolCalls[j].ID,
540 Content: "Tool execution canceled by user",
541 IsError: true,
542 }
543 }
544 goto out
545 case result := <-resultChan:
546 toolResponse = result.response
547 toolErr = result.err
548 }
549
550 if toolErr != nil {
551 slog.Error("Tool execution error", "toolCall", toolCall.ID, "error", toolErr)
552 if errors.Is(toolErr, permission.ErrorPermissionDenied) {
553 toolResults[i] = message.ToolResult{
554 ToolCallID: toolCall.ID,
555 Content: "Permission denied",
556 IsError: true,
557 }
558 for j := i + 1; j < len(toolCalls); j++ {
559 toolResults[j] = message.ToolResult{
560 ToolCallID: toolCalls[j].ID,
561 Content: "Tool execution canceled by user",
562 IsError: true,
563 }
564 }
565 a.finishMessage(ctx, &assistantMsg, message.FinishReasonPermissionDenied, "Permission denied", "")
566 break
567 }
568 }
569 toolResults[i] = message.ToolResult{
570 ToolCallID: toolCall.ID,
571 Content: toolResponse.Content,
572 Metadata: toolResponse.Metadata,
573 IsError: toolResponse.IsError,
574 }
575 }
576 }
577out:
578 if len(toolResults) == 0 {
579 return assistantMsg, nil, nil
580 }
581 parts := make([]message.ContentPart, 0)
582 for _, tr := range toolResults {
583 parts = append(parts, tr)
584 }
585 msg, err := a.messages.Create(context.Background(), assistantMsg.SessionID, message.CreateMessageParams{
586 Role: message.Tool,
587 Parts: parts,
588 Provider: a.providerID,
589 })
590 if err != nil {
591 return assistantMsg, nil, fmt.Errorf("failed to create cancelled tool message: %w", err)
592 }
593
594 return assistantMsg, &msg, err
595}
596
597func (a *agent) finishMessage(ctx context.Context, msg *message.Message, finishReason message.FinishReason, message, details string) {
598 msg.AddFinish(finishReason, message, details)
599 _ = a.messages.Update(ctx, *msg)
600}
601
602func (a *agent) processEvent(ctx context.Context, sessionID string, assistantMsg *message.Message, event provider.ProviderEvent) error {
603 select {
604 case <-ctx.Done():
605 return ctx.Err()
606 default:
607 // Continue processing.
608 }
609
610 switch event.Type {
611 case provider.EventThinkingDelta:
612 assistantMsg.AppendReasoningContent(event.Thinking)
613 return a.messages.Update(ctx, *assistantMsg)
614 case provider.EventSignatureDelta:
615 assistantMsg.AppendReasoningSignature(event.Signature)
616 return a.messages.Update(ctx, *assistantMsg)
617 case provider.EventContentDelta:
618 assistantMsg.FinishThinking()
619 assistantMsg.AppendContent(event.Content)
620 return a.messages.Update(ctx, *assistantMsg)
621 case provider.EventToolUseStart:
622 assistantMsg.FinishThinking()
623 slog.Info("Tool call started", "toolCall", event.ToolCall)
624 assistantMsg.AddToolCall(*event.ToolCall)
625 return a.messages.Update(ctx, *assistantMsg)
626 case provider.EventToolUseDelta:
627 assistantMsg.AppendToolCallInput(event.ToolCall.ID, event.ToolCall.Input)
628 return a.messages.Update(ctx, *assistantMsg)
629 case provider.EventToolUseStop:
630 slog.Info("Finished tool call", "toolCall", event.ToolCall)
631 assistantMsg.FinishToolCall(event.ToolCall.ID)
632 return a.messages.Update(ctx, *assistantMsg)
633 case provider.EventError:
634 return event.Error
635 case provider.EventComplete:
636 assistantMsg.FinishThinking()
637 assistantMsg.SetToolCalls(event.Response.ToolCalls)
638 assistantMsg.AddFinish(event.Response.FinishReason, "", "")
639 if err := a.messages.Update(ctx, *assistantMsg); err != nil {
640 return fmt.Errorf("failed to update message: %w", err)
641 }
642 return a.TrackUsage(ctx, sessionID, a.Model(), event.Response.Usage)
643 }
644
645 return nil
646}
647
648func (a *agent) TrackUsage(ctx context.Context, sessionID string, model fur.Model, usage provider.TokenUsage) error {
649 sess, err := a.sessions.Get(ctx, sessionID)
650 if err != nil {
651 return fmt.Errorf("failed to get session: %w", err)
652 }
653
654 cost := model.CostPer1MInCached/1e6*float64(usage.CacheCreationTokens) +
655 model.CostPer1MOutCached/1e6*float64(usage.CacheReadTokens) +
656 model.CostPer1MIn/1e6*float64(usage.InputTokens) +
657 model.CostPer1MOut/1e6*float64(usage.OutputTokens)
658
659 sess.Cost += cost
660 sess.CompletionTokens = usage.OutputTokens + usage.CacheReadTokens
661 sess.PromptTokens = usage.InputTokens + usage.CacheCreationTokens
662
663 _, err = a.sessions.Save(ctx, sess)
664 if err != nil {
665 return fmt.Errorf("failed to save session: %w", err)
666 }
667 return nil
668}
669
670func (a *agent) Summarize(ctx context.Context, sessionID string) error {
671 if a.summarizeProvider == nil {
672 return fmt.Errorf("summarize provider not available")
673 }
674
675 // Check if session is busy
676 if a.IsSessionBusy(sessionID) {
677 return ErrSessionBusy
678 }
679
680 // Create a new context with cancellation
681 summarizeCtx, cancel := context.WithCancel(ctx)
682
683 // Store the cancel function in activeRequests to allow cancellation
684 a.activeRequests.Store(sessionID+"-summarize", cancel)
685
686 go func() {
687 defer a.activeRequests.Delete(sessionID + "-summarize")
688 defer cancel()
689 event := AgentEvent{
690 Type: AgentEventTypeSummarize,
691 Progress: "Starting summarization...",
692 }
693
694 a.Publish(pubsub.CreatedEvent, event)
695 // Get all messages from the session
696 msgs, err := a.messages.List(summarizeCtx, sessionID)
697 if err != nil {
698 event = AgentEvent{
699 Type: AgentEventTypeError,
700 Error: fmt.Errorf("failed to list messages: %w", err),
701 Done: true,
702 }
703 a.Publish(pubsub.CreatedEvent, event)
704 return
705 }
706 summarizeCtx = context.WithValue(summarizeCtx, tools.SessionIDContextKey, sessionID)
707
708 if len(msgs) == 0 {
709 event = AgentEvent{
710 Type: AgentEventTypeError,
711 Error: fmt.Errorf("no messages to summarize"),
712 Done: true,
713 }
714 a.Publish(pubsub.CreatedEvent, event)
715 return
716 }
717
718 event = AgentEvent{
719 Type: AgentEventTypeSummarize,
720 Progress: "Analyzing conversation...",
721 }
722 a.Publish(pubsub.CreatedEvent, event)
723
724 // Add a system message to guide the summarization
725 summarizePrompt := "Provide a detailed but concise summary of our conversation above. Focus on information that would be helpful for continuing the conversation, including what we did, what we're doing, which files we're working on, and what we're going to do next."
726
727 // Create a new message with the summarize prompt
728 promptMsg := message.Message{
729 Role: message.User,
730 Parts: []message.ContentPart{message.TextContent{Text: summarizePrompt}},
731 }
732
733 // Append the prompt to the messages
734 msgsWithPrompt := append(msgs, promptMsg)
735
736 event = AgentEvent{
737 Type: AgentEventTypeSummarize,
738 Progress: "Generating summary...",
739 }
740
741 a.Publish(pubsub.CreatedEvent, event)
742
743 // Send the messages to the summarize provider
744 response := a.summarizeProvider.StreamResponse(
745 summarizeCtx,
746 msgsWithPrompt,
747 make([]tools.BaseTool, 0),
748 )
749 var finalResponse *provider.ProviderResponse
750 for r := range response {
751 if r.Error != nil {
752 event = AgentEvent{
753 Type: AgentEventTypeError,
754 Error: fmt.Errorf("failed to summarize: %w", err),
755 Done: true,
756 }
757 a.Publish(pubsub.CreatedEvent, event)
758 return
759 }
760 finalResponse = r.Response
761 }
762
763 summary := strings.TrimSpace(finalResponse.Content)
764 if summary == "" {
765 event = AgentEvent{
766 Type: AgentEventTypeError,
767 Error: fmt.Errorf("empty summary returned"),
768 Done: true,
769 }
770 a.Publish(pubsub.CreatedEvent, event)
771 return
772 }
773 event = AgentEvent{
774 Type: AgentEventTypeSummarize,
775 Progress: "Creating new session...",
776 }
777
778 a.Publish(pubsub.CreatedEvent, event)
779 oldSession, err := a.sessions.Get(summarizeCtx, sessionID)
780 if err != nil {
781 event = AgentEvent{
782 Type: AgentEventTypeError,
783 Error: fmt.Errorf("failed to get session: %w", err),
784 Done: true,
785 }
786
787 a.Publish(pubsub.CreatedEvent, event)
788 return
789 }
790 // Create a message in the new session with the summary
791 msg, err := a.messages.Create(summarizeCtx, oldSession.ID, message.CreateMessageParams{
792 Role: message.Assistant,
793 Parts: []message.ContentPart{
794 message.TextContent{Text: summary},
795 message.Finish{
796 Reason: message.FinishReasonEndTurn,
797 Time: time.Now().Unix(),
798 },
799 },
800 Model: a.summarizeProvider.Model().ID,
801 Provider: a.summarizeProviderID,
802 })
803 if err != nil {
804 event = AgentEvent{
805 Type: AgentEventTypeError,
806 Error: fmt.Errorf("failed to create summary message: %w", err),
807 Done: true,
808 }
809
810 a.Publish(pubsub.CreatedEvent, event)
811 return
812 }
813 oldSession.SummaryMessageID = msg.ID
814 oldSession.CompletionTokens = finalResponse.Usage.OutputTokens
815 oldSession.PromptTokens = 0
816 model := a.summarizeProvider.Model()
817 usage := finalResponse.Usage
818 cost := model.CostPer1MInCached/1e6*float64(usage.CacheCreationTokens) +
819 model.CostPer1MOutCached/1e6*float64(usage.CacheReadTokens) +
820 model.CostPer1MIn/1e6*float64(usage.InputTokens) +
821 model.CostPer1MOut/1e6*float64(usage.OutputTokens)
822 oldSession.Cost += cost
823 _, err = a.sessions.Save(summarizeCtx, oldSession)
824 if err != nil {
825 event = AgentEvent{
826 Type: AgentEventTypeError,
827 Error: fmt.Errorf("failed to save session: %w", err),
828 Done: true,
829 }
830 a.Publish(pubsub.CreatedEvent, event)
831 }
832
833 event = AgentEvent{
834 Type: AgentEventTypeSummarize,
835 SessionID: oldSession.ID,
836 Progress: "Summary complete",
837 Done: true,
838 }
839 a.Publish(pubsub.CreatedEvent, event)
840 // Send final success event with the new session ID
841 }()
842
843 return nil
844}
845
846func (a *agent) CancelAll() {
847 if !a.IsBusy() {
848 return
849 }
850 a.activeRequests.Range(func(key, value any) bool {
851 a.Cancel(key.(string)) // key is sessionID
852 return true
853 })
854
855 timeout := time.After(5 * time.Second)
856 for a.IsBusy() {
857 select {
858 case <-timeout:
859 return
860 default:
861 time.Sleep(200 * time.Millisecond)
862 }
863 }
864}
865
866func (a *agent) UpdateModel() error {
867 cfg := config.Get()
868
869 // Get current provider configuration
870 currentProviderCfg := cfg.GetProviderForModel(a.agentCfg.Model)
871 if currentProviderCfg == nil || currentProviderCfg.ID == "" {
872 return fmt.Errorf("provider for agent %s not found in config", a.agentCfg.Name)
873 }
874
875 // Check if provider has changed
876 if string(currentProviderCfg.ID) != a.providerID {
877 // Provider changed, need to recreate the main provider
878 model := cfg.GetModelByType(a.agentCfg.Model)
879 if model.ID == "" {
880 return fmt.Errorf("model not found for agent %s", a.agentCfg.Name)
881 }
882
883 promptID := agentPromptMap[a.agentCfg.ID]
884 if promptID == "" {
885 promptID = prompt.PromptDefault
886 }
887
888 opts := []provider.ProviderClientOption{
889 provider.WithModel(a.agentCfg.Model),
890 provider.WithSystemMessage(prompt.GetPrompt(promptID, currentProviderCfg.ID, cfg.Options.ContextPaths...)),
891 }
892
893 newProvider, err := provider.NewProvider(*currentProviderCfg, opts...)
894 if err != nil {
895 return fmt.Errorf("failed to create new provider: %w", err)
896 }
897
898 // Update the provider and provider ID
899 a.provider = newProvider
900 a.providerID = string(currentProviderCfg.ID)
901 }
902
903 // Check if small model provider has changed (affects title and summarize providers)
904 smallModelCfg := cfg.Models[config.SelectedModelTypeSmall]
905 var smallModelProviderCfg config.ProviderConfig
906
907 for _, p := range cfg.Providers.Seq2() {
908 if p.ID == smallModelCfg.Provider {
909 smallModelProviderCfg = p
910 break
911 }
912 }
913
914 if smallModelProviderCfg.ID == "" {
915 return fmt.Errorf("provider %s not found in config", smallModelCfg.Provider)
916 }
917
918 // Check if summarize provider has changed
919 if string(smallModelProviderCfg.ID) != a.summarizeProviderID {
920 smallModel := cfg.GetModelByType(config.SelectedModelTypeSmall)
921 if smallModel == nil {
922 return fmt.Errorf("model %s not found in provider %s", smallModelCfg.Model, smallModelProviderCfg.ID)
923 }
924
925 // Recreate title provider
926 titleOpts := []provider.ProviderClientOption{
927 provider.WithModel(config.SelectedModelTypeSmall),
928 provider.WithSystemMessage(prompt.GetPrompt(prompt.PromptTitle, smallModelProviderCfg.ID)),
929 // We want the title to be short, so we limit the max tokens
930 provider.WithMaxTokens(40),
931 }
932 newTitleProvider, err := provider.NewProvider(smallModelProviderCfg, titleOpts...)
933 if err != nil {
934 return fmt.Errorf("failed to create new title provider: %w", err)
935 }
936
937 // Recreate summarize provider
938 summarizeOpts := []provider.ProviderClientOption{
939 provider.WithModel(config.SelectedModelTypeSmall),
940 provider.WithSystemMessage(prompt.GetPrompt(prompt.PromptSummarizer, smallModelProviderCfg.ID)),
941 }
942 newSummarizeProvider, err := provider.NewProvider(smallModelProviderCfg, summarizeOpts...)
943 if err != nil {
944 return fmt.Errorf("failed to create new summarize provider: %w", err)
945 }
946
947 // Update the providers and provider ID
948 a.titleProvider = newTitleProvider
949 a.summarizeProvider = newSummarizeProvider
950 a.summarizeProviderID = string(smallModelProviderCfg.ID)
951 }
952
953 return nil
954}