1package agent
2
3import (
4 "context"
5 "errors"
6 "fmt"
7 "slices"
8 "strings"
9 "sync"
10 "time"
11
12 "github.com/charmbracelet/crush/internal/config"
13 "github.com/charmbracelet/crush/internal/history"
14 "github.com/charmbracelet/crush/internal/llm/prompt"
15 "github.com/charmbracelet/crush/internal/llm/provider"
16 "github.com/charmbracelet/crush/internal/llm/tools"
17 "github.com/charmbracelet/crush/internal/logging"
18 "github.com/charmbracelet/crush/internal/lsp"
19 "github.com/charmbracelet/crush/internal/message"
20 "github.com/charmbracelet/crush/internal/permission"
21 "github.com/charmbracelet/crush/internal/pubsub"
22 "github.com/charmbracelet/crush/internal/session"
23)
24
// ErrRequestCancelled is returned (wrapped in an error event) when a running
// generation is aborted because its context was cancelled.
var ErrRequestCancelled = errors.New("request cancelled by user")

// ErrSessionBusy is returned by Run and Summarize when the session already
// has a request in flight.
var ErrSessionBusy = errors.New("session is currently processing another request")
30
// AgentEventType tells subscribers which fields of an AgentEvent are set.
type AgentEventType string

// Kinds of events produced by the agent.
const (
	// AgentEventTypeError carries a failure in AgentEvent.Error.
	AgentEventTypeError AgentEventType = "error"
	// AgentEventTypeResponse carries a completed assistant message in
	// AgentEvent.Message.
	AgentEventTypeResponse AgentEventType = "response"
	// AgentEventTypeSummarize reports summarization progress through
	// AgentEvent.Progress, AgentEvent.Done and AgentEvent.SessionID.
	AgentEventTypeSummarize AgentEventType = "summarize"
)
38
39type AgentEvent struct {
40 Type AgentEventType
41 Message message.Message
42 Error error
43
44 // When summarizing
45 SessionID string
46 Progress string
47 Done bool
48}
49
50type Service interface {
51 pubsub.Suscriber[AgentEvent]
52 Model() config.Model
53 EffectiveMaxTokens() int64
54 Run(ctx context.Context, sessionID string, content string, attachments ...message.Attachment) (<-chan AgentEvent, error)
55 Cancel(sessionID string)
56 CancelAll()
57 IsSessionBusy(sessionID string) bool
58 IsBusy() bool
59 Summarize(ctx context.Context, sessionID string) error
60 UpdateModel() error
61}
62
63type agent struct {
64 *pubsub.Broker[AgentEvent]
65 agentCfg config.Agent
66 sessions session.Service
67 messages message.Service
68
69 tools []tools.BaseTool
70 provider provider.Provider
71 providerID string
72
73 titleProvider provider.Provider
74 summarizeProvider provider.Provider
75 summarizeProviderID string
76
77 activeRequests sync.Map
78}
79
80var agentPromptMap = map[config.AgentID]prompt.PromptID{
81 config.AgentCoder: prompt.PromptCoder,
82 config.AgentTask: prompt.PromptTask,
83}
84
85func NewAgent(
86 agentCfg config.Agent,
87 // These services are needed in the tools
88 permissions permission.Service,
89 sessions session.Service,
90 messages message.Service,
91 history history.Service,
92 lspClients map[string]*lsp.Client,
93) (Service, error) {
94 ctx := context.Background()
95 cfg := config.Get()
96 otherTools := GetMcpTools(ctx, permissions)
97 if len(cfg.LSP) > 0 {
98 otherTools = append(otherTools, tools.NewDiagnosticsTool(lspClients))
99 otherTools = append(otherTools, tools.NewDefinitionsTool(lspClients))
100 }
101
102 allTools := []tools.BaseTool{
103 tools.NewBashTool(permissions),
104 tools.NewEditTool(lspClients, permissions, history),
105 tools.NewFetchTool(permissions),
106 tools.NewGlobTool(),
107 tools.NewGrepTool(),
108 tools.NewLsTool(),
109 tools.NewSourcegraphTool(),
110 tools.NewViewTool(lspClients),
111 tools.NewWriteTool(lspClients, permissions, history),
112 }
113
114 if agentCfg.ID == config.AgentCoder {
115 taskAgentCfg := config.Get().Agents[config.AgentTask]
116 if taskAgentCfg.ID == "" {
117 return nil, fmt.Errorf("task agent not found in config")
118 }
119 taskAgent, err := NewAgent(taskAgentCfg, permissions, sessions, messages, history, lspClients)
120 if err != nil {
121 return nil, fmt.Errorf("failed to create task agent: %w", err)
122 }
123
124 allTools = append(
125 allTools,
126 NewAgentTool(
127 taskAgent,
128 sessions,
129 messages,
130 ),
131 )
132 }
133
134 allTools = append(allTools, otherTools...)
135 providerCfg := config.GetAgentProvider(agentCfg.ID)
136 if providerCfg.ID == "" {
137 return nil, fmt.Errorf("provider for agent %s not found in config", agentCfg.Name)
138 }
139 model := config.GetAgentModel(agentCfg.ID)
140
141 if model.ID == "" {
142 return nil, fmt.Errorf("model not found for agent %s", agentCfg.Name)
143 }
144
145 promptID := agentPromptMap[agentCfg.ID]
146 if promptID == "" {
147 promptID = prompt.PromptDefault
148 }
149 opts := []provider.ProviderClientOption{
150 provider.WithModel(agentCfg.Model),
151 provider.WithSystemMessage(prompt.GetPrompt(promptID, providerCfg.ID)),
152 }
153 agentProvider, err := provider.NewProvider(providerCfg, opts...)
154 if err != nil {
155 return nil, err
156 }
157
158 smallModelCfg := cfg.Models.Small
159 var smallModel config.Model
160
161 var smallModelProviderCfg config.ProviderConfig
162 if smallModelCfg.Provider == providerCfg.ID {
163 smallModelProviderCfg = providerCfg
164 } else {
165 for _, p := range cfg.Providers {
166 if p.ID == smallModelCfg.Provider {
167 smallModelProviderCfg = p
168 break
169 }
170 }
171 if smallModelProviderCfg.ID == "" {
172 return nil, fmt.Errorf("provider %s not found in config", smallModelCfg.Provider)
173 }
174 }
175 for _, m := range smallModelProviderCfg.Models {
176 if m.ID == smallModelCfg.ModelID {
177 smallModel = m
178 break
179 }
180 }
181 if smallModel.ID == "" {
182 return nil, fmt.Errorf("model %s not found in provider %s", smallModelCfg.ModelID, smallModelProviderCfg.ID)
183 }
184
185 titleOpts := []provider.ProviderClientOption{
186 provider.WithModel(config.SmallModel),
187 provider.WithSystemMessage(prompt.GetPrompt(prompt.PromptTitle, smallModelProviderCfg.ID)),
188 }
189 titleProvider, err := provider.NewProvider(smallModelProviderCfg, titleOpts...)
190 if err != nil {
191 return nil, err
192 }
193 summarizeOpts := []provider.ProviderClientOption{
194 provider.WithModel(config.SmallModel),
195 provider.WithSystemMessage(prompt.GetPrompt(prompt.PromptSummarizer, smallModelProviderCfg.ID)),
196 }
197 summarizeProvider, err := provider.NewProvider(smallModelProviderCfg, summarizeOpts...)
198 if err != nil {
199 return nil, err
200 }
201
202 agentTools := []tools.BaseTool{}
203 if agentCfg.AllowedTools == nil {
204 agentTools = allTools
205 } else {
206 for _, tool := range allTools {
207 if slices.Contains(agentCfg.AllowedTools, tool.Name()) {
208 agentTools = append(agentTools, tool)
209 }
210 }
211 }
212
213 agent := &agent{
214 Broker: pubsub.NewBroker[AgentEvent](),
215 agentCfg: agentCfg,
216 provider: agentProvider,
217 providerID: string(providerCfg.ID),
218 messages: messages,
219 sessions: sessions,
220 tools: agentTools,
221 titleProvider: titleProvider,
222 summarizeProvider: summarizeProvider,
223 summarizeProviderID: string(smallModelProviderCfg.ID),
224 activeRequests: sync.Map{},
225 }
226
227 return agent, nil
228}
229
230func (a *agent) Model() config.Model {
231 return config.GetAgentModel(a.agentCfg.ID)
232}
233
234func (a *agent) EffectiveMaxTokens() int64 {
235 return config.GetAgentEffectiveMaxTokens(a.agentCfg.ID)
236}
237
238func (a *agent) Cancel(sessionID string) {
239 // Cancel regular requests
240 if cancelFunc, exists := a.activeRequests.LoadAndDelete(sessionID); exists {
241 if cancel, ok := cancelFunc.(context.CancelFunc); ok {
242 logging.InfoPersist(fmt.Sprintf("Request cancellation initiated for session: %s", sessionID))
243 cancel()
244 }
245 }
246
247 // Also check for summarize requests
248 if cancelFunc, exists := a.activeRequests.LoadAndDelete(sessionID + "-summarize"); exists {
249 if cancel, ok := cancelFunc.(context.CancelFunc); ok {
250 logging.InfoPersist(fmt.Sprintf("Summarize cancellation initiated for session: %s", sessionID))
251 cancel()
252 }
253 }
254}
255
256func (a *agent) IsBusy() bool {
257 busy := false
258 a.activeRequests.Range(func(key, value any) bool {
259 if cancelFunc, ok := value.(context.CancelFunc); ok {
260 if cancelFunc != nil {
261 busy = true
262 return false
263 }
264 }
265 return true
266 })
267 return busy
268}
269
270func (a *agent) IsSessionBusy(sessionID string) bool {
271 _, busy := a.activeRequests.Load(sessionID)
272 return busy
273}
274
275func (a *agent) generateTitle(ctx context.Context, sessionID string, content string) error {
276 if content == "" {
277 return nil
278 }
279 if a.titleProvider == nil {
280 return nil
281 }
282 session, err := a.sessions.Get(ctx, sessionID)
283 if err != nil {
284 return err
285 }
286 parts := []message.ContentPart{message.TextContent{
287 Text: fmt.Sprintf("Generate a concise title for the following content:\n\n%s", content),
288 }}
289
290 // Use streaming approach like summarization
291 response := a.titleProvider.StreamResponse(
292 ctx,
293 []message.Message{
294 {
295 Role: message.User,
296 Parts: parts,
297 },
298 },
299 make([]tools.BaseTool, 0),
300 )
301
302 var finalResponse *provider.ProviderResponse
303 for r := range response {
304 if r.Error != nil {
305 return r.Error
306 }
307 finalResponse = r.Response
308 }
309
310 if finalResponse == nil {
311 return fmt.Errorf("no response received from title provider")
312 }
313
314 title := strings.TrimSpace(strings.ReplaceAll(finalResponse.Content, "\n", " "))
315 if title == "" {
316 return nil
317 }
318
319 session.Title = title
320 _, err = a.sessions.Save(ctx, session)
321 return err
322}
323
324func (a *agent) err(err error) AgentEvent {
325 return AgentEvent{
326 Type: AgentEventTypeError,
327 Error: err,
328 }
329}
330
331func (a *agent) Run(ctx context.Context, sessionID string, content string, attachments ...message.Attachment) (<-chan AgentEvent, error) {
332 if !a.Model().SupportsImages && attachments != nil {
333 attachments = nil
334 }
335 events := make(chan AgentEvent)
336 if a.IsSessionBusy(sessionID) {
337 return nil, ErrSessionBusy
338 }
339
340 genCtx, cancel := context.WithCancel(ctx)
341
342 a.activeRequests.Store(sessionID, cancel)
343 go func() {
344 logging.Debug("Request started", "sessionID", sessionID)
345 defer logging.RecoverPanic("agent.Run", func() {
346 events <- a.err(fmt.Errorf("panic while running the agent"))
347 })
348 var attachmentParts []message.ContentPart
349 for _, attachment := range attachments {
350 attachmentParts = append(attachmentParts, message.BinaryContent{Path: attachment.FilePath, MIMEType: attachment.MimeType, Data: attachment.Content})
351 }
352 result := a.processGeneration(genCtx, sessionID, content, attachmentParts)
353 if result.Error != nil && !errors.Is(result.Error, ErrRequestCancelled) && !errors.Is(result.Error, context.Canceled) {
354 logging.ErrorPersist(result.Error.Error())
355 }
356 logging.Debug("Request completed", "sessionID", sessionID)
357 a.activeRequests.Delete(sessionID)
358 cancel()
359 a.Publish(pubsub.CreatedEvent, result)
360 events <- result
361 close(events)
362 }()
363 return events, nil
364}
365
366func (a *agent) processGeneration(ctx context.Context, sessionID, content string, attachmentParts []message.ContentPart) AgentEvent {
367 // List existing messages; if none, start title generation asynchronously.
368 msgs, err := a.messages.List(ctx, sessionID)
369 if err != nil {
370 return a.err(fmt.Errorf("failed to list messages: %w", err))
371 }
372 if len(msgs) == 0 {
373 go func() {
374 defer logging.RecoverPanic("agent.Run", func() {
375 logging.ErrorPersist("panic while generating title")
376 })
377 titleErr := a.generateTitle(context.Background(), sessionID, content)
378 if titleErr != nil && !errors.Is(titleErr, context.Canceled) && !errors.Is(titleErr, context.DeadlineExceeded) {
379 logging.ErrorPersist(fmt.Sprintf("failed to generate title: %v", titleErr))
380 }
381 }()
382 }
383 session, err := a.sessions.Get(ctx, sessionID)
384 if err != nil {
385 return a.err(fmt.Errorf("failed to get session: %w", err))
386 }
387 if session.SummaryMessageID != "" {
388 summaryMsgInex := -1
389 for i, msg := range msgs {
390 if msg.ID == session.SummaryMessageID {
391 summaryMsgInex = i
392 break
393 }
394 }
395 if summaryMsgInex != -1 {
396 msgs = msgs[summaryMsgInex:]
397 msgs[0].Role = message.User
398 }
399 }
400
401 userMsg, err := a.createUserMessage(ctx, sessionID, content, attachmentParts)
402 if err != nil {
403 return a.err(fmt.Errorf("failed to create user message: %w", err))
404 }
405 // Append the new user message to the conversation history.
406 msgHistory := append(msgs, userMsg)
407
408 for {
409 // Check for cancellation before each iteration
410 select {
411 case <-ctx.Done():
412 return a.err(ctx.Err())
413 default:
414 // Continue processing
415 }
416 agentMessage, toolResults, err := a.streamAndHandleEvents(ctx, sessionID, msgHistory)
417 if err != nil {
418 if errors.Is(err, context.Canceled) {
419 agentMessage.AddFinish(message.FinishReasonCanceled)
420 a.messages.Update(context.Background(), agentMessage)
421 return a.err(ErrRequestCancelled)
422 }
423 return a.err(fmt.Errorf("failed to process events: %w", err))
424 }
425 logging.Info("Result", "message", agentMessage.FinishReason(), "toolResults", toolResults)
426 if (agentMessage.FinishReason() == message.FinishReasonToolUse) && toolResults != nil {
427 // We are not done, we need to respond with the tool response
428 msgHistory = append(msgHistory, agentMessage, *toolResults)
429 continue
430 }
431 return AgentEvent{
432 Type: AgentEventTypeResponse,
433 Message: agentMessage,
434 Done: true,
435 }
436 }
437}
438
439func (a *agent) createUserMessage(ctx context.Context, sessionID, content string, attachmentParts []message.ContentPart) (message.Message, error) {
440 parts := []message.ContentPart{message.TextContent{Text: content}}
441 parts = append(parts, attachmentParts...)
442 return a.messages.Create(ctx, sessionID, message.CreateMessageParams{
443 Role: message.User,
444 Parts: parts,
445 })
446}
447
448func (a *agent) streamAndHandleEvents(ctx context.Context, sessionID string, msgHistory []message.Message) (message.Message, *message.Message, error) {
449 eventChan := a.provider.StreamResponse(ctx, msgHistory, a.tools)
450
451 assistantMsg, err := a.messages.Create(ctx, sessionID, message.CreateMessageParams{
452 Role: message.Assistant,
453 Parts: []message.ContentPart{},
454 Model: a.Model().ID,
455 Provider: a.providerID,
456 })
457 if err != nil {
458 return assistantMsg, nil, fmt.Errorf("failed to create assistant message: %w", err)
459 }
460
461 // Add the session and message ID into the context if needed by tools.
462 ctx = context.WithValue(ctx, tools.MessageIDContextKey, assistantMsg.ID)
463 ctx = context.WithValue(ctx, tools.SessionIDContextKey, sessionID)
464
465 // Process each event in the stream.
466 for event := range eventChan {
467 if processErr := a.processEvent(ctx, sessionID, &assistantMsg, event); processErr != nil {
468 a.finishMessage(ctx, &assistantMsg, message.FinishReasonCanceled)
469 return assistantMsg, nil, processErr
470 }
471 if ctx.Err() != nil {
472 a.finishMessage(context.Background(), &assistantMsg, message.FinishReasonCanceled)
473 return assistantMsg, nil, ctx.Err()
474 }
475 }
476
477 toolResults := make([]message.ToolResult, len(assistantMsg.ToolCalls()))
478 toolCalls := assistantMsg.ToolCalls()
479 for i, toolCall := range toolCalls {
480 select {
481 case <-ctx.Done():
482 a.finishMessage(context.Background(), &assistantMsg, message.FinishReasonCanceled)
483 // Make all future tool calls cancelled
484 for j := i; j < len(toolCalls); j++ {
485 toolResults[j] = message.ToolResult{
486 ToolCallID: toolCalls[j].ID,
487 Content: "Tool execution canceled by user",
488 IsError: true,
489 }
490 }
491 goto out
492 default:
493 // Continue processing
494 var tool tools.BaseTool
495 for _, availableTools := range a.tools {
496 if availableTools.Info().Name == toolCall.Name {
497 tool = availableTools
498 }
499 }
500
501 // Tool not found
502 if tool == nil {
503 toolResults[i] = message.ToolResult{
504 ToolCallID: toolCall.ID,
505 Content: fmt.Sprintf("Tool not found: %s", toolCall.Name),
506 IsError: true,
507 }
508 continue
509 }
510 toolResult, toolErr := tool.Run(ctx, tools.ToolCall{
511 ID: toolCall.ID,
512 Name: toolCall.Name,
513 Input: toolCall.Input,
514 })
515 if toolErr != nil {
516 if errors.Is(toolErr, permission.ErrorPermissionDenied) {
517 toolResults[i] = message.ToolResult{
518 ToolCallID: toolCall.ID,
519 Content: "Permission denied",
520 IsError: true,
521 }
522 for j := i + 1; j < len(toolCalls); j++ {
523 toolResults[j] = message.ToolResult{
524 ToolCallID: toolCalls[j].ID,
525 Content: "Tool execution canceled by user",
526 IsError: true,
527 }
528 }
529 a.finishMessage(ctx, &assistantMsg, message.FinishReasonPermissionDenied)
530 break
531 }
532 }
533 toolResults[i] = message.ToolResult{
534 ToolCallID: toolCall.ID,
535 Content: toolResult.Content,
536 Metadata: toolResult.Metadata,
537 IsError: toolResult.IsError,
538 }
539 }
540 }
541out:
542 if len(toolResults) == 0 {
543 return assistantMsg, nil, nil
544 }
545 parts := make([]message.ContentPart, 0)
546 for _, tr := range toolResults {
547 parts = append(parts, tr)
548 }
549 msg, err := a.messages.Create(context.Background(), assistantMsg.SessionID, message.CreateMessageParams{
550 Role: message.Tool,
551 Parts: parts,
552 Provider: a.providerID,
553 })
554 if err != nil {
555 return assistantMsg, nil, fmt.Errorf("failed to create cancelled tool message: %w", err)
556 }
557
558 return assistantMsg, &msg, err
559}
560
561func (a *agent) finishMessage(ctx context.Context, msg *message.Message, finishReson message.FinishReason) {
562 msg.AddFinish(finishReson)
563 _ = a.messages.Update(ctx, *msg)
564}
565
566func (a *agent) processEvent(ctx context.Context, sessionID string, assistantMsg *message.Message, event provider.ProviderEvent) error {
567 select {
568 case <-ctx.Done():
569 return ctx.Err()
570 default:
571 // Continue processing.
572 }
573
574 switch event.Type {
575 case provider.EventThinkingDelta:
576 assistantMsg.AppendReasoningContent(event.Content)
577 return a.messages.Update(ctx, *assistantMsg)
578 case provider.EventContentDelta:
579 assistantMsg.AppendContent(event.Content)
580 return a.messages.Update(ctx, *assistantMsg)
581 case provider.EventToolUseStart:
582 logging.Info("Tool call started", "toolCall", event.ToolCall)
583 assistantMsg.AddToolCall(*event.ToolCall)
584 return a.messages.Update(ctx, *assistantMsg)
585 case provider.EventToolUseDelta:
586 assistantMsg.AppendToolCallInput(event.ToolCall.ID, event.ToolCall.Input)
587 return a.messages.Update(ctx, *assistantMsg)
588 case provider.EventToolUseStop:
589 logging.Info("Finished tool call", "toolCall", event.ToolCall)
590 assistantMsg.FinishToolCall(event.ToolCall.ID)
591 return a.messages.Update(ctx, *assistantMsg)
592 case provider.EventError:
593 if errors.Is(event.Error, context.Canceled) {
594 logging.InfoPersist(fmt.Sprintf("Event processing canceled for session: %s", sessionID))
595 return context.Canceled
596 }
597 logging.ErrorPersist(event.Error.Error())
598 return event.Error
599 case provider.EventComplete:
600 assistantMsg.SetToolCalls(event.Response.ToolCalls)
601 assistantMsg.AddFinish(event.Response.FinishReason)
602 if err := a.messages.Update(ctx, *assistantMsg); err != nil {
603 return fmt.Errorf("failed to update message: %w", err)
604 }
605 return a.TrackUsage(ctx, sessionID, a.Model(), event.Response.Usage)
606 }
607
608 return nil
609}
610
611func (a *agent) TrackUsage(ctx context.Context, sessionID string, model config.Model, usage provider.TokenUsage) error {
612 sess, err := a.sessions.Get(ctx, sessionID)
613 if err != nil {
614 return fmt.Errorf("failed to get session: %w", err)
615 }
616
617 cost := model.CostPer1MInCached/1e6*float64(usage.CacheCreationTokens) +
618 model.CostPer1MOutCached/1e6*float64(usage.CacheReadTokens) +
619 model.CostPer1MIn/1e6*float64(usage.InputTokens) +
620 model.CostPer1MOut/1e6*float64(usage.OutputTokens)
621
622 sess.Cost += cost
623 sess.CompletionTokens = usage.OutputTokens + usage.CacheReadTokens
624 sess.PromptTokens = usage.InputTokens + usage.CacheCreationTokens
625
626 _, err = a.sessions.Save(ctx, sess)
627 if err != nil {
628 return fmt.Errorf("failed to save session: %w", err)
629 }
630 return nil
631}
632
633func (a *agent) Summarize(ctx context.Context, sessionID string) error {
634 if a.summarizeProvider == nil {
635 return fmt.Errorf("summarize provider not available")
636 }
637
638 // Check if session is busy
639 if a.IsSessionBusy(sessionID) {
640 return ErrSessionBusy
641 }
642
643 // Create a new context with cancellation
644 summarizeCtx, cancel := context.WithCancel(ctx)
645
646 // Store the cancel function in activeRequests to allow cancellation
647 a.activeRequests.Store(sessionID+"-summarize", cancel)
648
649 go func() {
650 defer a.activeRequests.Delete(sessionID + "-summarize")
651 defer cancel()
652 event := AgentEvent{
653 Type: AgentEventTypeSummarize,
654 Progress: "Starting summarization...",
655 }
656
657 a.Publish(pubsub.CreatedEvent, event)
658 // Get all messages from the session
659 msgs, err := a.messages.List(summarizeCtx, sessionID)
660 if err != nil {
661 event = AgentEvent{
662 Type: AgentEventTypeError,
663 Error: fmt.Errorf("failed to list messages: %w", err),
664 Done: true,
665 }
666 a.Publish(pubsub.CreatedEvent, event)
667 return
668 }
669
670 if len(msgs) == 0 {
671 event = AgentEvent{
672 Type: AgentEventTypeError,
673 Error: fmt.Errorf("no messages to summarize"),
674 Done: true,
675 }
676 a.Publish(pubsub.CreatedEvent, event)
677 return
678 }
679
680 event = AgentEvent{
681 Type: AgentEventTypeSummarize,
682 Progress: "Analyzing conversation...",
683 }
684 a.Publish(pubsub.CreatedEvent, event)
685
686 // Add a system message to guide the summarization
687 summarizePrompt := "Provide a detailed but concise summary of our conversation above. Focus on information that would be helpful for continuing the conversation, including what we did, what we're doing, which files we're working on, and what we're going to do next."
688
689 // Create a new message with the summarize prompt
690 promptMsg := message.Message{
691 Role: message.User,
692 Parts: []message.ContentPart{message.TextContent{Text: summarizePrompt}},
693 }
694
695 // Append the prompt to the messages
696 msgsWithPrompt := append(msgs, promptMsg)
697
698 event = AgentEvent{
699 Type: AgentEventTypeSummarize,
700 Progress: "Generating summary...",
701 }
702
703 a.Publish(pubsub.CreatedEvent, event)
704
705 // Send the messages to the summarize provider
706 response := a.summarizeProvider.StreamResponse(
707 summarizeCtx,
708 msgsWithPrompt,
709 make([]tools.BaseTool, 0),
710 )
711 var finalResponse *provider.ProviderResponse
712 for r := range response {
713 if r.Error != nil {
714 event = AgentEvent{
715 Type: AgentEventTypeError,
716 Error: fmt.Errorf("failed to summarize: %w", err),
717 Done: true,
718 }
719 a.Publish(pubsub.CreatedEvent, event)
720 return
721 }
722 finalResponse = r.Response
723 }
724
725 summary := strings.TrimSpace(finalResponse.Content)
726 if summary == "" {
727 event = AgentEvent{
728 Type: AgentEventTypeError,
729 Error: fmt.Errorf("empty summary returned"),
730 Done: true,
731 }
732 a.Publish(pubsub.CreatedEvent, event)
733 return
734 }
735 event = AgentEvent{
736 Type: AgentEventTypeSummarize,
737 Progress: "Creating new session...",
738 }
739
740 a.Publish(pubsub.CreatedEvent, event)
741 oldSession, err := a.sessions.Get(summarizeCtx, sessionID)
742 if err != nil {
743 event = AgentEvent{
744 Type: AgentEventTypeError,
745 Error: fmt.Errorf("failed to get session: %w", err),
746 Done: true,
747 }
748
749 a.Publish(pubsub.CreatedEvent, event)
750 return
751 }
752 // Create a message in the new session with the summary
753 msg, err := a.messages.Create(summarizeCtx, oldSession.ID, message.CreateMessageParams{
754 Role: message.Assistant,
755 Parts: []message.ContentPart{
756 message.TextContent{Text: summary},
757 message.Finish{
758 Reason: message.FinishReasonEndTurn,
759 Time: time.Now().Unix(),
760 },
761 },
762 Model: a.summarizeProvider.Model().ID,
763 Provider: a.summarizeProviderID,
764 })
765 if err != nil {
766 event = AgentEvent{
767 Type: AgentEventTypeError,
768 Error: fmt.Errorf("failed to create summary message: %w", err),
769 Done: true,
770 }
771
772 a.Publish(pubsub.CreatedEvent, event)
773 return
774 }
775 oldSession.SummaryMessageID = msg.ID
776 oldSession.CompletionTokens = finalResponse.Usage.OutputTokens
777 oldSession.PromptTokens = 0
778 model := a.summarizeProvider.Model()
779 usage := finalResponse.Usage
780 cost := model.CostPer1MInCached/1e6*float64(usage.CacheCreationTokens) +
781 model.CostPer1MOutCached/1e6*float64(usage.CacheReadTokens) +
782 model.CostPer1MIn/1e6*float64(usage.InputTokens) +
783 model.CostPer1MOut/1e6*float64(usage.OutputTokens)
784 oldSession.Cost += cost
785 _, err = a.sessions.Save(summarizeCtx, oldSession)
786 if err != nil {
787 event = AgentEvent{
788 Type: AgentEventTypeError,
789 Error: fmt.Errorf("failed to save session: %w", err),
790 Done: true,
791 }
792 a.Publish(pubsub.CreatedEvent, event)
793 }
794
795 event = AgentEvent{
796 Type: AgentEventTypeSummarize,
797 SessionID: oldSession.ID,
798 Progress: "Summary complete",
799 Done: true,
800 }
801 a.Publish(pubsub.CreatedEvent, event)
802 // Send final success event with the new session ID
803 }()
804
805 return nil
806}
807
808func (a *agent) CancelAll() {
809 a.activeRequests.Range(func(key, value any) bool {
810 a.Cancel(key.(string)) // key is sessionID
811 return true
812 })
813}
814
815func (a *agent) UpdateModel() error {
816 cfg := config.Get()
817
818 // Get current provider configuration
819 currentProviderCfg := config.GetAgentProvider(a.agentCfg.ID)
820 if currentProviderCfg.ID == "" {
821 return fmt.Errorf("provider for agent %s not found in config", a.agentCfg.Name)
822 }
823
824 // Check if provider has changed
825 if string(currentProviderCfg.ID) != a.providerID {
826 // Provider changed, need to recreate the main provider
827 model := config.GetAgentModel(a.agentCfg.ID)
828 if model.ID == "" {
829 return fmt.Errorf("model not found for agent %s", a.agentCfg.Name)
830 }
831
832 promptID := agentPromptMap[a.agentCfg.ID]
833 if promptID == "" {
834 promptID = prompt.PromptDefault
835 }
836
837 opts := []provider.ProviderClientOption{
838 provider.WithModel(a.agentCfg.Model),
839 provider.WithSystemMessage(prompt.GetPrompt(promptID, currentProviderCfg.ID)),
840 }
841
842 newProvider, err := provider.NewProvider(currentProviderCfg, opts...)
843 if err != nil {
844 return fmt.Errorf("failed to create new provider: %w", err)
845 }
846
847 // Update the provider and provider ID
848 a.provider = newProvider
849 a.providerID = string(currentProviderCfg.ID)
850 }
851
852 // Check if small model provider has changed (affects title and summarize providers)
853 smallModelCfg := cfg.Models.Small
854 var smallModelProviderCfg config.ProviderConfig
855
856 for _, p := range cfg.Providers {
857 if p.ID == smallModelCfg.Provider {
858 smallModelProviderCfg = p
859 break
860 }
861 }
862
863 if smallModelProviderCfg.ID == "" {
864 return fmt.Errorf("provider %s not found in config", smallModelCfg.Provider)
865 }
866
867 // Check if summarize provider has changed
868 if string(smallModelProviderCfg.ID) != a.summarizeProviderID {
869 var smallModel config.Model
870 for _, m := range smallModelProviderCfg.Models {
871 if m.ID == smallModelCfg.ModelID {
872 smallModel = m
873 break
874 }
875 }
876 if smallModel.ID == "" {
877 return fmt.Errorf("model %s not found in provider %s", smallModelCfg.ModelID, smallModelProviderCfg.ID)
878 }
879
880 // Recreate title provider
881 titleOpts := []provider.ProviderClientOption{
882 provider.WithModel(config.SmallModel),
883 provider.WithSystemMessage(prompt.GetPrompt(prompt.PromptTitle, smallModelProviderCfg.ID)),
884 // We want the title to be short, so we limit the max tokens
885 provider.WithMaxTokens(40),
886 }
887 newTitleProvider, err := provider.NewProvider(smallModelProviderCfg, titleOpts...)
888 if err != nil {
889 return fmt.Errorf("failed to create new title provider: %w", err)
890 }
891
892 // Recreate summarize provider
893 summarizeOpts := []provider.ProviderClientOption{
894 provider.WithModel(config.SmallModel),
895 provider.WithSystemMessage(prompt.GetPrompt(prompt.PromptSummarizer, smallModelProviderCfg.ID)),
896 }
897 newSummarizeProvider, err := provider.NewProvider(smallModelProviderCfg, summarizeOpts...)
898 if err != nil {
899 return fmt.Errorf("failed to create new summarize provider: %w", err)
900 }
901
902 // Update the providers and provider ID
903 a.titleProvider = newTitleProvider
904 a.summarizeProvider = newSummarizeProvider
905 a.summarizeProviderID = string(smallModelProviderCfg.ID)
906 }
907
908 return nil
909}