1package agent
2
3import (
4 "context"
5 "errors"
6 "fmt"
7 "slices"
8 "strings"
9 "sync"
10 "time"
11
12 "github.com/charmbracelet/crush/internal/config"
13 "github.com/charmbracelet/crush/internal/history"
14 "github.com/charmbracelet/crush/internal/llm/prompt"
15 "github.com/charmbracelet/crush/internal/llm/provider"
16 "github.com/charmbracelet/crush/internal/llm/tools"
17 "github.com/charmbracelet/crush/internal/logging"
18 "github.com/charmbracelet/crush/internal/lsp"
19 "github.com/charmbracelet/crush/internal/message"
20 "github.com/charmbracelet/crush/internal/permission"
21 "github.com/charmbracelet/crush/internal/pubsub"
22 "github.com/charmbracelet/crush/internal/session"
23)
24
// Sentinel errors reported by the agent service.

// ErrRequestCancelled is reported when a running request was cancelled.
var ErrRequestCancelled = errors.New("request cancelled by user")

// ErrSessionBusy is returned when the session already has a request in flight.
var ErrSessionBusy = errors.New("session is currently processing another request")
30
// AgentEventType identifies the kind of an AgentEvent.
type AgentEventType string

// Event kinds emitted by the agent.
const (
	// AgentEventTypeError marks an event that carries an error.
	AgentEventTypeError AgentEventType = "error"
	// AgentEventTypeResponse marks an event that carries the assistant's message.
	AgentEventTypeResponse AgentEventType = "response"
	// AgentEventTypeSummarize marks a progress event of a summarization run.
	AgentEventTypeSummarize AgentEventType = "summarize"
)
38
39type AgentEvent struct {
40 Type AgentEventType
41 Message message.Message
42 Error error
43
44 // When summarizing
45 SessionID string
46 Progress string
47 Done bool
48}
49
50type Service interface {
51 pubsub.Suscriber[AgentEvent]
52 Model() config.Model
53 Run(ctx context.Context, sessionID string, content string, attachments ...message.Attachment) (<-chan AgentEvent, error)
54 Cancel(sessionID string)
55 CancelAll()
56 IsSessionBusy(sessionID string) bool
57 IsBusy() bool
58 Summarize(ctx context.Context, sessionID string) error
59 UpdateModel() error
60}
61
62type agent struct {
63 *pubsub.Broker[AgentEvent]
64 agentCfg config.Agent
65 sessions session.Service
66 messages message.Service
67
68 tools []tools.BaseTool
69 provider provider.Provider
70 providerID string
71
72 titleProvider provider.Provider
73 summarizeProvider provider.Provider
74 summarizeProviderID string
75
76 activeRequests sync.Map
77}
78
79var agentPromptMap = map[config.AgentID]prompt.PromptID{
80 config.AgentCoder: prompt.PromptCoder,
81 config.AgentTask: prompt.PromptTask,
82}
83
84func NewAgent(
85 agentCfg config.Agent,
86 // These services are needed in the tools
87 permissions permission.Service,
88 sessions session.Service,
89 messages message.Service,
90 history history.Service,
91 lspClients map[string]*lsp.Client,
92) (Service, error) {
93 ctx := context.Background()
94 cfg := config.Get()
95 otherTools := GetMcpTools(ctx, permissions)
96 if len(lspClients) > 0 {
97 otherTools = append(otherTools, tools.NewDiagnosticsTool(lspClients))
98 }
99
100 allTools := []tools.BaseTool{
101 tools.NewBashTool(permissions),
102 tools.NewEditTool(lspClients, permissions, history),
103 tools.NewFetchTool(permissions),
104 tools.NewGlobTool(),
105 tools.NewGrepTool(),
106 tools.NewLsTool(),
107 tools.NewSourcegraphTool(),
108 tools.NewViewTool(lspClients),
109 tools.NewWriteTool(lspClients, permissions, history),
110 }
111
112 if agentCfg.ID == config.AgentCoder {
113 taskAgentCfg := config.Get().Agents[config.AgentTask]
114 if taskAgentCfg.ID == "" {
115 return nil, fmt.Errorf("task agent not found in config")
116 }
117 taskAgent, err := NewAgent(taskAgentCfg, permissions, sessions, messages, history, lspClients)
118 if err != nil {
119 return nil, fmt.Errorf("failed to create task agent: %w", err)
120 }
121
122 allTools = append(
123 allTools,
124 NewAgentTool(
125 taskAgent,
126 sessions,
127 messages,
128 ),
129 )
130 }
131
132 allTools = append(allTools, otherTools...)
133 providerCfg := config.GetAgentProvider(agentCfg.ID)
134 if providerCfg.ID == "" {
135 return nil, fmt.Errorf("provider for agent %s not found in config", agentCfg.Name)
136 }
137 model := config.GetAgentModel(agentCfg.ID)
138
139 if model.ID == "" {
140 return nil, fmt.Errorf("model not found for agent %s", agentCfg.Name)
141 }
142
143 promptID := agentPromptMap[agentCfg.ID]
144 if promptID == "" {
145 promptID = prompt.PromptDefault
146 }
147 opts := []provider.ProviderClientOption{
148 provider.WithModel(agentCfg.Model),
149 provider.WithSystemMessage(prompt.GetPrompt(promptID, providerCfg.ID)),
150 }
151 agentProvider, err := provider.NewProvider(providerCfg, opts...)
152 if err != nil {
153 return nil, err
154 }
155
156 smallModelCfg := cfg.Models.Small
157 var smallModel config.Model
158
159 var smallModelProviderCfg config.ProviderConfig
160 if smallModelCfg.Provider == providerCfg.ID {
161 smallModelProviderCfg = providerCfg
162 } else {
163 for _, p := range cfg.Providers {
164 if p.ID == smallModelCfg.Provider {
165 smallModelProviderCfg = p
166 break
167 }
168 }
169 if smallModelProviderCfg.ID == "" {
170 return nil, fmt.Errorf("provider %s not found in config", smallModelCfg.Provider)
171 }
172 }
173 for _, m := range smallModelProviderCfg.Models {
174 if m.ID == smallModelCfg.ModelID {
175 smallModel = m
176 break
177 }
178 }
179 if smallModel.ID == "" {
180 return nil, fmt.Errorf("model %s not found in provider %s", smallModelCfg.ModelID, smallModelProviderCfg.ID)
181 }
182
183 titleOpts := []provider.ProviderClientOption{
184 provider.WithModel(config.SmallModel),
185 provider.WithSystemMessage(prompt.GetPrompt(prompt.PromptTitle, smallModelProviderCfg.ID)),
186 }
187 titleProvider, err := provider.NewProvider(smallModelProviderCfg, titleOpts...)
188 if err != nil {
189 return nil, err
190 }
191 summarizeOpts := []provider.ProviderClientOption{
192 provider.WithModel(config.SmallModel),
193 provider.WithSystemMessage(prompt.GetPrompt(prompt.PromptSummarizer, smallModelProviderCfg.ID)),
194 }
195 summarizeProvider, err := provider.NewProvider(smallModelProviderCfg, summarizeOpts...)
196 if err != nil {
197 return nil, err
198 }
199
200 agentTools := []tools.BaseTool{}
201 if agentCfg.AllowedTools == nil {
202 agentTools = allTools
203 } else {
204 for _, tool := range allTools {
205 if slices.Contains(agentCfg.AllowedTools, tool.Name()) {
206 agentTools = append(agentTools, tool)
207 }
208 }
209 }
210
211 agent := &agent{
212 Broker: pubsub.NewBroker[AgentEvent](),
213 agentCfg: agentCfg,
214 provider: agentProvider,
215 providerID: string(providerCfg.ID),
216 messages: messages,
217 sessions: sessions,
218 tools: agentTools,
219 titleProvider: titleProvider,
220 summarizeProvider: summarizeProvider,
221 summarizeProviderID: string(smallModelProviderCfg.ID),
222 activeRequests: sync.Map{},
223 }
224
225 return agent, nil
226}
227
228func (a *agent) Model() config.Model {
229 return config.GetAgentModel(a.agentCfg.ID)
230}
231
232func (a *agent) Cancel(sessionID string) {
233 // Cancel regular requests
234 if cancelFunc, exists := a.activeRequests.LoadAndDelete(sessionID); exists {
235 if cancel, ok := cancelFunc.(context.CancelFunc); ok {
236 logging.InfoPersist(fmt.Sprintf("Request cancellation initiated for session: %s", sessionID))
237 cancel()
238 }
239 }
240
241 // Also check for summarize requests
242 if cancelFunc, exists := a.activeRequests.LoadAndDelete(sessionID + "-summarize"); exists {
243 if cancel, ok := cancelFunc.(context.CancelFunc); ok {
244 logging.InfoPersist(fmt.Sprintf("Summarize cancellation initiated for session: %s", sessionID))
245 cancel()
246 }
247 }
248}
249
250func (a *agent) IsBusy() bool {
251 busy := false
252 a.activeRequests.Range(func(key, value any) bool {
253 if cancelFunc, ok := value.(context.CancelFunc); ok {
254 if cancelFunc != nil {
255 busy = true
256 return false
257 }
258 }
259 return true
260 })
261 return busy
262}
263
264func (a *agent) IsSessionBusy(sessionID string) bool {
265 _, busy := a.activeRequests.Load(sessionID)
266 return busy
267}
268
269func (a *agent) generateTitle(ctx context.Context, sessionID string, content string) error {
270 if content == "" {
271 return nil
272 }
273 if a.titleProvider == nil {
274 return nil
275 }
276 session, err := a.sessions.Get(ctx, sessionID)
277 if err != nil {
278 return err
279 }
280 parts := []message.ContentPart{message.TextContent{
281 Text: fmt.Sprintf("Generate a concise title for the following content:\n\n%s", content),
282 }}
283
284 // Use streaming approach like summarization
285 response := a.titleProvider.StreamResponse(
286 ctx,
287 []message.Message{
288 {
289 Role: message.User,
290 Parts: parts,
291 },
292 },
293 make([]tools.BaseTool, 0),
294 )
295
296 var finalResponse *provider.ProviderResponse
297 for r := range response {
298 if r.Error != nil {
299 return r.Error
300 }
301 finalResponse = r.Response
302 }
303
304 if finalResponse == nil {
305 return fmt.Errorf("no response received from title provider")
306 }
307
308 title := strings.TrimSpace(strings.ReplaceAll(finalResponse.Content, "\n", " "))
309 if title == "" {
310 return nil
311 }
312
313 session.Title = title
314 _, err = a.sessions.Save(ctx, session)
315 return err
316}
317
318func (a *agent) err(err error) AgentEvent {
319 return AgentEvent{
320 Type: AgentEventTypeError,
321 Error: err,
322 }
323}
324
325func (a *agent) Run(ctx context.Context, sessionID string, content string, attachments ...message.Attachment) (<-chan AgentEvent, error) {
326 if !a.Model().SupportsImages && attachments != nil {
327 attachments = nil
328 }
329 events := make(chan AgentEvent)
330 if a.IsSessionBusy(sessionID) {
331 return nil, ErrSessionBusy
332 }
333
334 genCtx, cancel := context.WithCancel(ctx)
335
336 a.activeRequests.Store(sessionID, cancel)
337 go func() {
338 logging.Debug("Request started", "sessionID", sessionID)
339 defer logging.RecoverPanic("agent.Run", func() {
340 events <- a.err(fmt.Errorf("panic while running the agent"))
341 })
342 var attachmentParts []message.ContentPart
343 for _, attachment := range attachments {
344 attachmentParts = append(attachmentParts, message.BinaryContent{Path: attachment.FilePath, MIMEType: attachment.MimeType, Data: attachment.Content})
345 }
346 result := a.processGeneration(genCtx, sessionID, content, attachmentParts)
347 if result.Error != nil && !errors.Is(result.Error, ErrRequestCancelled) && !errors.Is(result.Error, context.Canceled) {
348 logging.ErrorPersist(result.Error.Error())
349 }
350 logging.Debug("Request completed", "sessionID", sessionID)
351 a.activeRequests.Delete(sessionID)
352 cancel()
353 a.Publish(pubsub.CreatedEvent, result)
354 events <- result
355 close(events)
356 }()
357 return events, nil
358}
359
360func (a *agent) processGeneration(ctx context.Context, sessionID, content string, attachmentParts []message.ContentPart) AgentEvent {
361 cfg := config.Get()
362 // List existing messages; if none, start title generation asynchronously.
363 msgs, err := a.messages.List(ctx, sessionID)
364 if err != nil {
365 return a.err(fmt.Errorf("failed to list messages: %w", err))
366 }
367 if len(msgs) == 0 {
368 go func() {
369 defer logging.RecoverPanic("agent.Run", func() {
370 logging.ErrorPersist("panic while generating title")
371 })
372 titleErr := a.generateTitle(context.Background(), sessionID, content)
373 if titleErr != nil && !errors.Is(titleErr, context.Canceled) && !errors.Is(titleErr, context.DeadlineExceeded) {
374 logging.ErrorPersist(fmt.Sprintf("failed to generate title: %v", titleErr))
375 }
376 }()
377 }
378 session, err := a.sessions.Get(ctx, sessionID)
379 if err != nil {
380 return a.err(fmt.Errorf("failed to get session: %w", err))
381 }
382 if session.SummaryMessageID != "" {
383 summaryMsgInex := -1
384 for i, msg := range msgs {
385 if msg.ID == session.SummaryMessageID {
386 summaryMsgInex = i
387 break
388 }
389 }
390 if summaryMsgInex != -1 {
391 msgs = msgs[summaryMsgInex:]
392 msgs[0].Role = message.User
393 }
394 }
395
396 userMsg, err := a.createUserMessage(ctx, sessionID, content, attachmentParts)
397 if err != nil {
398 return a.err(fmt.Errorf("failed to create user message: %w", err))
399 }
400 // Append the new user message to the conversation history.
401 msgHistory := append(msgs, userMsg)
402
403 for {
404 // Check for cancellation before each iteration
405 select {
406 case <-ctx.Done():
407 return a.err(ctx.Err())
408 default:
409 // Continue processing
410 }
411 agentMessage, toolResults, err := a.streamAndHandleEvents(ctx, sessionID, msgHistory)
412 if err != nil {
413 if errors.Is(err, context.Canceled) {
414 agentMessage.AddFinish(message.FinishReasonCanceled)
415 a.messages.Update(context.Background(), agentMessage)
416 return a.err(ErrRequestCancelled)
417 }
418 return a.err(fmt.Errorf("failed to process events: %w", err))
419 }
420 if cfg.Options.Debug {
421 seqId := (len(msgHistory) + 1) / 2
422 toolResultFilepath := logging.WriteToolResultsJson(sessionID, seqId, toolResults)
423 logging.Info("Result", "message", agentMessage.FinishReason(), "toolResults", "{}", "filepath", toolResultFilepath)
424 } else {
425 logging.Info("Result", "message", agentMessage.FinishReason(), "toolResults", toolResults)
426 }
427 if (agentMessage.FinishReason() == message.FinishReasonToolUse) && toolResults != nil {
428 // We are not done, we need to respond with the tool response
429 msgHistory = append(msgHistory, agentMessage, *toolResults)
430 continue
431 }
432 return AgentEvent{
433 Type: AgentEventTypeResponse,
434 Message: agentMessage,
435 Done: true,
436 }
437 }
438}
439
440func (a *agent) createUserMessage(ctx context.Context, sessionID, content string, attachmentParts []message.ContentPart) (message.Message, error) {
441 parts := []message.ContentPart{message.TextContent{Text: content}}
442 parts = append(parts, attachmentParts...)
443 return a.messages.Create(ctx, sessionID, message.CreateMessageParams{
444 Role: message.User,
445 Parts: parts,
446 })
447}
448
449func (a *agent) streamAndHandleEvents(ctx context.Context, sessionID string, msgHistory []message.Message) (message.Message, *message.Message, error) {
450 ctx = context.WithValue(ctx, tools.SessionIDContextKey, sessionID)
451 eventChan := a.provider.StreamResponse(ctx, msgHistory, a.tools)
452
453 assistantMsg, err := a.messages.Create(ctx, sessionID, message.CreateMessageParams{
454 Role: message.Assistant,
455 Parts: []message.ContentPart{},
456 Model: a.Model().ID,
457 Provider: a.providerID,
458 })
459 if err != nil {
460 return assistantMsg, nil, fmt.Errorf("failed to create assistant message: %w", err)
461 }
462
463 // Add the session and message ID into the context if needed by tools.
464 ctx = context.WithValue(ctx, tools.MessageIDContextKey, assistantMsg.ID)
465
466 // Process each event in the stream.
467 for event := range eventChan {
468 if processErr := a.processEvent(ctx, sessionID, &assistantMsg, event); processErr != nil {
469 a.finishMessage(ctx, &assistantMsg, message.FinishReasonCanceled)
470 return assistantMsg, nil, processErr
471 }
472 if ctx.Err() != nil {
473 a.finishMessage(context.Background(), &assistantMsg, message.FinishReasonCanceled)
474 return assistantMsg, nil, ctx.Err()
475 }
476 }
477
478 toolResults := make([]message.ToolResult, len(assistantMsg.ToolCalls()))
479 toolCalls := assistantMsg.ToolCalls()
480 for i, toolCall := range toolCalls {
481 select {
482 case <-ctx.Done():
483 a.finishMessage(context.Background(), &assistantMsg, message.FinishReasonCanceled)
484 // Make all future tool calls cancelled
485 for j := i; j < len(toolCalls); j++ {
486 toolResults[j] = message.ToolResult{
487 ToolCallID: toolCalls[j].ID,
488 Content: "Tool execution canceled by user",
489 IsError: true,
490 }
491 }
492 goto out
493 default:
494 // Continue processing
495 var tool tools.BaseTool
496 for _, availableTool := range a.tools {
497 if availableTool.Info().Name == toolCall.Name {
498 tool = availableTool
499 break
500 }
501 }
502
503 // Tool not found
504 if tool == nil {
505 toolResults[i] = message.ToolResult{
506 ToolCallID: toolCall.ID,
507 Content: fmt.Sprintf("Tool not found: %s", toolCall.Name),
508 IsError: true,
509 }
510 continue
511 }
512 toolResult, toolErr := tool.Run(ctx, tools.ToolCall{
513 ID: toolCall.ID,
514 Name: toolCall.Name,
515 Input: toolCall.Input,
516 })
517 if toolErr != nil {
518 if errors.Is(toolErr, permission.ErrorPermissionDenied) {
519 toolResults[i] = message.ToolResult{
520 ToolCallID: toolCall.ID,
521 Content: "Permission denied",
522 IsError: true,
523 }
524 for j := i + 1; j < len(toolCalls); j++ {
525 toolResults[j] = message.ToolResult{
526 ToolCallID: toolCalls[j].ID,
527 Content: "Tool execution canceled by user",
528 IsError: true,
529 }
530 }
531 a.finishMessage(ctx, &assistantMsg, message.FinishReasonPermissionDenied)
532 break
533 }
534 }
535 toolResults[i] = message.ToolResult{
536 ToolCallID: toolCall.ID,
537 Content: toolResult.Content,
538 Metadata: toolResult.Metadata,
539 IsError: toolResult.IsError,
540 }
541 }
542 }
543out:
544 if len(toolResults) == 0 {
545 return assistantMsg, nil, nil
546 }
547 parts := make([]message.ContentPart, 0)
548 for _, tr := range toolResults {
549 parts = append(parts, tr)
550 }
551 msg, err := a.messages.Create(context.Background(), assistantMsg.SessionID, message.CreateMessageParams{
552 Role: message.Tool,
553 Parts: parts,
554 Provider: a.providerID,
555 })
556 if err != nil {
557 return assistantMsg, nil, fmt.Errorf("failed to create cancelled tool message: %w", err)
558 }
559
560 return assistantMsg, &msg, err
561}
562
563func (a *agent) finishMessage(ctx context.Context, msg *message.Message, finishReson message.FinishReason) {
564 msg.AddFinish(finishReson)
565 _ = a.messages.Update(ctx, *msg)
566}
567
568func (a *agent) processEvent(ctx context.Context, sessionID string, assistantMsg *message.Message, event provider.ProviderEvent) error {
569 select {
570 case <-ctx.Done():
571 return ctx.Err()
572 default:
573 // Continue processing.
574 }
575
576 switch event.Type {
577 case provider.EventThinkingDelta:
578 assistantMsg.AppendReasoningContent(event.Content)
579 return a.messages.Update(ctx, *assistantMsg)
580 case provider.EventContentDelta:
581 assistantMsg.AppendContent(event.Content)
582 return a.messages.Update(ctx, *assistantMsg)
583 case provider.EventToolUseStart:
584 logging.Info("Tool call started", "toolCall", event.ToolCall)
585 assistantMsg.AddToolCall(*event.ToolCall)
586 return a.messages.Update(ctx, *assistantMsg)
587 case provider.EventToolUseDelta:
588 assistantMsg.AppendToolCallInput(event.ToolCall.ID, event.ToolCall.Input)
589 return a.messages.Update(ctx, *assistantMsg)
590 case provider.EventToolUseStop:
591 logging.Info("Finished tool call", "toolCall", event.ToolCall)
592 assistantMsg.FinishToolCall(event.ToolCall.ID)
593 return a.messages.Update(ctx, *assistantMsg)
594 case provider.EventError:
595 if errors.Is(event.Error, context.Canceled) {
596 logging.InfoPersist(fmt.Sprintf("Event processing canceled for session: %s", sessionID))
597 return context.Canceled
598 }
599 logging.ErrorPersist(event.Error.Error())
600 return event.Error
601 case provider.EventComplete:
602 assistantMsg.SetToolCalls(event.Response.ToolCalls)
603 assistantMsg.AddFinish(event.Response.FinishReason)
604 if err := a.messages.Update(ctx, *assistantMsg); err != nil {
605 return fmt.Errorf("failed to update message: %w", err)
606 }
607 return a.TrackUsage(ctx, sessionID, a.Model(), event.Response.Usage)
608 }
609
610 return nil
611}
612
613func (a *agent) TrackUsage(ctx context.Context, sessionID string, model config.Model, usage provider.TokenUsage) error {
614 sess, err := a.sessions.Get(ctx, sessionID)
615 if err != nil {
616 return fmt.Errorf("failed to get session: %w", err)
617 }
618
619 cost := model.CostPer1MInCached/1e6*float64(usage.CacheCreationTokens) +
620 model.CostPer1MOutCached/1e6*float64(usage.CacheReadTokens) +
621 model.CostPer1MIn/1e6*float64(usage.InputTokens) +
622 model.CostPer1MOut/1e6*float64(usage.OutputTokens)
623
624 sess.Cost += cost
625 sess.CompletionTokens = usage.OutputTokens + usage.CacheReadTokens
626 sess.PromptTokens = usage.InputTokens + usage.CacheCreationTokens
627
628 _, err = a.sessions.Save(ctx, sess)
629 if err != nil {
630 return fmt.Errorf("failed to save session: %w", err)
631 }
632 return nil
633}
634
635func (a *agent) Summarize(ctx context.Context, sessionID string) error {
636 if a.summarizeProvider == nil {
637 return fmt.Errorf("summarize provider not available")
638 }
639
640 // Check if session is busy
641 if a.IsSessionBusy(sessionID) {
642 return ErrSessionBusy
643 }
644
645 // Create a new context with cancellation
646 summarizeCtx, cancel := context.WithCancel(ctx)
647
648 // Store the cancel function in activeRequests to allow cancellation
649 a.activeRequests.Store(sessionID+"-summarize", cancel)
650
651 go func() {
652 defer a.activeRequests.Delete(sessionID + "-summarize")
653 defer cancel()
654 event := AgentEvent{
655 Type: AgentEventTypeSummarize,
656 Progress: "Starting summarization...",
657 }
658
659 a.Publish(pubsub.CreatedEvent, event)
660 // Get all messages from the session
661 msgs, err := a.messages.List(summarizeCtx, sessionID)
662 if err != nil {
663 event = AgentEvent{
664 Type: AgentEventTypeError,
665 Error: fmt.Errorf("failed to list messages: %w", err),
666 Done: true,
667 }
668 a.Publish(pubsub.CreatedEvent, event)
669 return
670 }
671 summarizeCtx = context.WithValue(summarizeCtx, tools.SessionIDContextKey, sessionID)
672
673 if len(msgs) == 0 {
674 event = AgentEvent{
675 Type: AgentEventTypeError,
676 Error: fmt.Errorf("no messages to summarize"),
677 Done: true,
678 }
679 a.Publish(pubsub.CreatedEvent, event)
680 return
681 }
682
683 event = AgentEvent{
684 Type: AgentEventTypeSummarize,
685 Progress: "Analyzing conversation...",
686 }
687 a.Publish(pubsub.CreatedEvent, event)
688
689 // Add a system message to guide the summarization
690 summarizePrompt := "Provide a detailed but concise summary of our conversation above. Focus on information that would be helpful for continuing the conversation, including what we did, what we're doing, which files we're working on, and what we're going to do next."
691
692 // Create a new message with the summarize prompt
693 promptMsg := message.Message{
694 Role: message.User,
695 Parts: []message.ContentPart{message.TextContent{Text: summarizePrompt}},
696 }
697
698 // Append the prompt to the messages
699 msgsWithPrompt := append(msgs, promptMsg)
700
701 event = AgentEvent{
702 Type: AgentEventTypeSummarize,
703 Progress: "Generating summary...",
704 }
705
706 a.Publish(pubsub.CreatedEvent, event)
707
708 // Send the messages to the summarize provider
709 response := a.summarizeProvider.StreamResponse(
710 summarizeCtx,
711 msgsWithPrompt,
712 make([]tools.BaseTool, 0),
713 )
714 var finalResponse *provider.ProviderResponse
715 for r := range response {
716 if r.Error != nil {
717 event = AgentEvent{
718 Type: AgentEventTypeError,
719 Error: fmt.Errorf("failed to summarize: %w", err),
720 Done: true,
721 }
722 a.Publish(pubsub.CreatedEvent, event)
723 return
724 }
725 finalResponse = r.Response
726 }
727
728 summary := strings.TrimSpace(finalResponse.Content)
729 if summary == "" {
730 event = AgentEvent{
731 Type: AgentEventTypeError,
732 Error: fmt.Errorf("empty summary returned"),
733 Done: true,
734 }
735 a.Publish(pubsub.CreatedEvent, event)
736 return
737 }
738 event = AgentEvent{
739 Type: AgentEventTypeSummarize,
740 Progress: "Creating new session...",
741 }
742
743 a.Publish(pubsub.CreatedEvent, event)
744 oldSession, err := a.sessions.Get(summarizeCtx, sessionID)
745 if err != nil {
746 event = AgentEvent{
747 Type: AgentEventTypeError,
748 Error: fmt.Errorf("failed to get session: %w", err),
749 Done: true,
750 }
751
752 a.Publish(pubsub.CreatedEvent, event)
753 return
754 }
755 // Create a message in the new session with the summary
756 msg, err := a.messages.Create(summarizeCtx, oldSession.ID, message.CreateMessageParams{
757 Role: message.Assistant,
758 Parts: []message.ContentPart{
759 message.TextContent{Text: summary},
760 message.Finish{
761 Reason: message.FinishReasonEndTurn,
762 Time: time.Now().Unix(),
763 },
764 },
765 Model: a.summarizeProvider.Model().ID,
766 Provider: a.summarizeProviderID,
767 })
768 if err != nil {
769 event = AgentEvent{
770 Type: AgentEventTypeError,
771 Error: fmt.Errorf("failed to create summary message: %w", err),
772 Done: true,
773 }
774
775 a.Publish(pubsub.CreatedEvent, event)
776 return
777 }
778 oldSession.SummaryMessageID = msg.ID
779 oldSession.CompletionTokens = finalResponse.Usage.OutputTokens
780 oldSession.PromptTokens = 0
781 model := a.summarizeProvider.Model()
782 usage := finalResponse.Usage
783 cost := model.CostPer1MInCached/1e6*float64(usage.CacheCreationTokens) +
784 model.CostPer1MOutCached/1e6*float64(usage.CacheReadTokens) +
785 model.CostPer1MIn/1e6*float64(usage.InputTokens) +
786 model.CostPer1MOut/1e6*float64(usage.OutputTokens)
787 oldSession.Cost += cost
788 _, err = a.sessions.Save(summarizeCtx, oldSession)
789 if err != nil {
790 event = AgentEvent{
791 Type: AgentEventTypeError,
792 Error: fmt.Errorf("failed to save session: %w", err),
793 Done: true,
794 }
795 a.Publish(pubsub.CreatedEvent, event)
796 }
797
798 event = AgentEvent{
799 Type: AgentEventTypeSummarize,
800 SessionID: oldSession.ID,
801 Progress: "Summary complete",
802 Done: true,
803 }
804 a.Publish(pubsub.CreatedEvent, event)
805 // Send final success event with the new session ID
806 }()
807
808 return nil
809}
810
811func (a *agent) CancelAll() {
812 a.activeRequests.Range(func(key, value any) bool {
813 a.Cancel(key.(string)) // key is sessionID
814 return true
815 })
816}
817
818func (a *agent) UpdateModel() error {
819 cfg := config.Get()
820
821 // Get current provider configuration
822 currentProviderCfg := config.GetAgentProvider(a.agentCfg.ID)
823 if currentProviderCfg.ID == "" {
824 return fmt.Errorf("provider for agent %s not found in config", a.agentCfg.Name)
825 }
826
827 // Check if provider has changed
828 if string(currentProviderCfg.ID) != a.providerID {
829 // Provider changed, need to recreate the main provider
830 model := config.GetAgentModel(a.agentCfg.ID)
831 if model.ID == "" {
832 return fmt.Errorf("model not found for agent %s", a.agentCfg.Name)
833 }
834
835 promptID := agentPromptMap[a.agentCfg.ID]
836 if promptID == "" {
837 promptID = prompt.PromptDefault
838 }
839
840 opts := []provider.ProviderClientOption{
841 provider.WithModel(a.agentCfg.Model),
842 provider.WithSystemMessage(prompt.GetPrompt(promptID, currentProviderCfg.ID)),
843 }
844
845 newProvider, err := provider.NewProvider(currentProviderCfg, opts...)
846 if err != nil {
847 return fmt.Errorf("failed to create new provider: %w", err)
848 }
849
850 // Update the provider and provider ID
851 a.provider = newProvider
852 a.providerID = string(currentProviderCfg.ID)
853 }
854
855 // Check if small model provider has changed (affects title and summarize providers)
856 smallModelCfg := cfg.Models.Small
857 var smallModelProviderCfg config.ProviderConfig
858
859 for _, p := range cfg.Providers {
860 if p.ID == smallModelCfg.Provider {
861 smallModelProviderCfg = p
862 break
863 }
864 }
865
866 if smallModelProviderCfg.ID == "" {
867 return fmt.Errorf("provider %s not found in config", smallModelCfg.Provider)
868 }
869
870 // Check if summarize provider has changed
871 if string(smallModelProviderCfg.ID) != a.summarizeProviderID {
872 var smallModel config.Model
873 for _, m := range smallModelProviderCfg.Models {
874 if m.ID == smallModelCfg.ModelID {
875 smallModel = m
876 break
877 }
878 }
879 if smallModel.ID == "" {
880 return fmt.Errorf("model %s not found in provider %s", smallModelCfg.ModelID, smallModelProviderCfg.ID)
881 }
882
883 // Recreate title provider
884 titleOpts := []provider.ProviderClientOption{
885 provider.WithModel(config.SmallModel),
886 provider.WithSystemMessage(prompt.GetPrompt(prompt.PromptTitle, smallModelProviderCfg.ID)),
887 // We want the title to be short, so we limit the max tokens
888 provider.WithMaxTokens(40),
889 }
890 newTitleProvider, err := provider.NewProvider(smallModelProviderCfg, titleOpts...)
891 if err != nil {
892 return fmt.Errorf("failed to create new title provider: %w", err)
893 }
894
895 // Recreate summarize provider
896 summarizeOpts := []provider.ProviderClientOption{
897 provider.WithModel(config.SmallModel),
898 provider.WithSystemMessage(prompt.GetPrompt(prompt.PromptSummarizer, smallModelProviderCfg.ID)),
899 }
900 newSummarizeProvider, err := provider.NewProvider(smallModelProviderCfg, summarizeOpts...)
901 if err != nil {
902 return fmt.Errorf("failed to create new summarize provider: %w", err)
903 }
904
905 // Update the providers and provider ID
906 a.titleProvider = newTitleProvider
907 a.summarizeProvider = newSummarizeProvider
908 a.summarizeProviderID = string(smallModelProviderCfg.ID)
909 }
910
911 return nil
912}