1package agent
2
3import (
4 "context"
5 "errors"
6 "fmt"
7 "slices"
8 "strings"
9 "sync"
10 "time"
11
12 "github.com/charmbracelet/crush/internal/config"
13 "github.com/charmbracelet/crush/internal/history"
14 "github.com/charmbracelet/crush/internal/llm/prompt"
15 "github.com/charmbracelet/crush/internal/llm/provider"
16 "github.com/charmbracelet/crush/internal/llm/tools"
17 "github.com/charmbracelet/crush/internal/logging"
18 "github.com/charmbracelet/crush/internal/lsp"
19 "github.com/charmbracelet/crush/internal/message"
20 "github.com/charmbracelet/crush/internal/permission"
21 "github.com/charmbracelet/crush/internal/pubsub"
22 "github.com/charmbracelet/crush/internal/session"
23)
24
// Common errors returned by the agent.
var (
	// ErrRequestCancelled is reported as the error of a Run result when the
	// generation was cancelled (e.g. via Cancel or CancelAll).
	ErrRequestCancelled = errors.New("request cancelled by user")

	// ErrSessionBusy is returned by Run and Summarize when the session
	// already has an active request.
	ErrSessionBusy = errors.New("session is currently processing another request")
)
30
// AgentEventType identifies the kind of an AgentEvent.
type AgentEventType string

const (
	// AgentEventTypeError marks an event whose Error field carries a failure.
	AgentEventTypeError AgentEventType = "error"

	// AgentEventTypeResponse marks the final assistant response of a Run.
	AgentEventTypeResponse AgentEventType = "response"

	// AgentEventTypeSummarize marks a progress or completion event emitted by
	// Summarize.
	AgentEventTypeSummarize AgentEventType = "summarize"
)
38
// AgentEvent is sent on the channel returned by Run and published to
// subscribers. Which fields are meaningful depends on Type.
type AgentEvent struct {
	// Type selects the kind of event. Message holds the final assistant
	// message of a response event; Error is set on error events.
	Type    AgentEventType
	Message message.Message
	Error   error

	// When summarizing: Progress carries a human-readable status and
	// SessionID is set on the completion event. Done is set on the last
	// Summarize event (success or error) and also on the final response
	// event of a Run.
	SessionID string
	Progress  string
	Done      bool
}
49
// Service is the interface used to drive an agent. The method descriptions
// below follow the agent implementation in this package.
type Service interface {
	// Events published by the agent (Run results and Summarize progress) are
	// delivered through the embedded subscriber interface.
	pubsub.Suscriber[AgentEvent]
	// Model returns the model currently configured for the agent.
	Model() config.Model
	// EffectiveMaxTokens returns the max-token limit the config resolves for
	// the agent.
	EffectiveMaxTokens() int64
	// Run processes content (with optional attachments) for sessionID in the
	// background and returns a channel that delivers the final event. It
	// fails with ErrSessionBusy if a request is already running for the
	// session.
	Run(ctx context.Context, sessionID string, content string, attachments ...message.Attachment) (<-chan AgentEvent, error)
	// Cancel aborts the running request and/or summarization of sessionID.
	Cancel(sessionID string)
	// CancelAll aborts every running request and summarization.
	CancelAll()
	// IsSessionBusy reports whether a Run request is active for sessionID.
	IsSessionBusy(sessionID string) bool
	// IsBusy reports whether any request is active on the agent.
	IsBusy() bool
	// Summarize starts summarizing sessionID in the background; progress and
	// the outcome are reported as published events.
	Summarize(ctx context.Context, sessionID string) error
	// UpdateModel rebuilds provider clients after a configuration change.
	UpdateModel() error
}
62
// agent is the Service implementation. It embeds a broker used to publish
// AgentEvents to subscribers.
type agent struct {
	*pubsub.Broker[AgentEvent]
	agentCfg config.Agent
	sessions session.Service
	messages message.Service

	// tools are offered to the main provider; providerID is the ID of the
	// provider config that provider was built from.
	tools      []tools.BaseTool
	provider   provider.Provider
	providerID string

	// Clients on the small model's provider, used for session titles and
	// summarization; summarizeProviderID records which provider they use.
	titleProvider       provider.Provider
	summarizeProvider   provider.Provider
	summarizeProviderID string

	// activeRequests maps sessionID (Run) and sessionID+"-summarize"
	// (Summarize) to the context.CancelFunc of the in-flight request.
	activeRequests sync.Map
}
79
// agentPromptMap selects the system prompt for each agent. Agents that are
// not listed fall back to prompt.PromptDefault when providers are built.
var agentPromptMap = map[config.AgentID]prompt.PromptID{
	config.AgentCoder: prompt.PromptCoder,
	config.AgentTask:  prompt.PromptTask,
}
84
// NewAgent builds an agent for agentCfg.
//
// It assembles the built-in tools, then appends the MCP tools and (when LSP
// clients are available) the diagnostics tool. It resolves the agent's
// provider and model from the global config and creates separate title and
// summarize providers backed by the configured small model. For the coder
// agent it also builds the task agent (via a recursive NewAgent call) and
// exposes it as an extra tool.
//
// If agentCfg.AllowedTools is non-nil, only tools whose Name() is listed are
// kept. An error is returned when the task agent config, a provider or a
// model cannot be found in the config, or when a provider client cannot be
// created.
func NewAgent(
	agentCfg config.Agent,
	// These services are needed in the tools
	permissions permission.Service,
	sessions session.Service,
	messages message.Service,
	history history.Service,
	lspClients map[string]*lsp.Client,
) (Service, error) {
	ctx := context.Background()
	cfg := config.Get()
	// MCP tools and the optional diagnostics tool are appended after the
	// built-in (and agent) tools below.
	otherTools := GetMcpTools(ctx, permissions)
	if len(lspClients) > 0 {
		otherTools = append(otherTools, tools.NewDiagnosticsTool(lspClients))
	}

	allTools := []tools.BaseTool{
		tools.NewBashTool(permissions),
		tools.NewEditTool(lspClients, permissions, history),
		tools.NewFetchTool(permissions),
		tools.NewGlobTool(),
		tools.NewGrepTool(),
		tools.NewLsTool(),
		tools.NewSourcegraphTool(),
		tools.NewViewTool(lspClients),
		tools.NewWriteTool(lspClients, permissions, history),
	}

	// Only the coder agent gets the task-agent tool. The task agent itself is
	// built recursively from the AgentTask config (which is not AgentCoder,
	// so the recursion stops there).
	if agentCfg.ID == config.AgentCoder {
		taskAgentCfg := config.Get().Agents[config.AgentTask]
		if taskAgentCfg.ID == "" {
			return nil, fmt.Errorf("task agent not found in config")
		}
		taskAgent, err := NewAgent(taskAgentCfg, permissions, sessions, messages, history, lspClients)
		if err != nil {
			return nil, fmt.Errorf("failed to create task agent: %w", err)
		}

		allTools = append(
			allTools,
			NewAgentTool(
				taskAgent,
				sessions,
				messages,
			),
		)
	}

	allTools = append(allTools, otherTools...)
	providerCfg := config.GetAgentProvider(agentCfg.ID)
	if providerCfg.ID == "" {
		return nil, fmt.Errorf("provider for agent %s not found in config", agentCfg.Name)
	}
	model := config.GetAgentModel(agentCfg.ID)

	if model.ID == "" {
		return nil, fmt.Errorf("model not found for agent %s", agentCfg.Name)
	}

	// Agents without a dedicated prompt use the default one.
	promptID := agentPromptMap[agentCfg.ID]
	if promptID == "" {
		promptID = prompt.PromptDefault
	}
	opts := []provider.ProviderClientOption{
		provider.WithModel(agentCfg.Model),
		provider.WithSystemMessage(prompt.GetPrompt(promptID, providerCfg.ID)),
	}
	agentProvider, err := provider.NewProvider(providerCfg, opts...)
	if err != nil {
		return nil, err
	}

	// Resolve the small model's provider: reuse the agent's provider config
	// when the IDs match, otherwise look it up in cfg.Providers.
	smallModelCfg := cfg.Models.Small
	var smallModel config.Model

	var smallModelProviderCfg config.ProviderConfig
	if smallModelCfg.Provider == providerCfg.ID {
		smallModelProviderCfg = providerCfg
	} else {
		for _, p := range cfg.Providers {
			if p.ID == smallModelCfg.Provider {
				smallModelProviderCfg = p
				break
			}
		}
		if smallModelProviderCfg.ID == "" {
			return nil, fmt.Errorf("provider %s not found in config", smallModelCfg.Provider)
		}
	}
	// smallModel is only used to validate that the configured small model
	// exists in its provider.
	for _, m := range smallModelProviderCfg.Models {
		if m.ID == smallModelCfg.ModelID {
			smallModel = m
			break
		}
	}
	if smallModel.ID == "" {
		return nil, fmt.Errorf("model %s not found in provider %s", smallModelCfg.ModelID, smallModelProviderCfg.ID)
	}

	// NOTE(review): UpdateModel rebuilds the title provider with
	// provider.WithMaxTokens(40), but it is created here without that
	// option — confirm which is intended.
	titleOpts := []provider.ProviderClientOption{
		provider.WithModel(config.SmallModel),
		provider.WithSystemMessage(prompt.GetPrompt(prompt.PromptTitle, smallModelProviderCfg.ID)),
	}
	titleProvider, err := provider.NewProvider(smallModelProviderCfg, titleOpts...)
	if err != nil {
		return nil, err
	}
	summarizeOpts := []provider.ProviderClientOption{
		provider.WithModel(config.SmallModel),
		provider.WithSystemMessage(prompt.GetPrompt(prompt.PromptSummarizer, smallModelProviderCfg.ID)),
	}
	summarizeProvider, err := provider.NewProvider(smallModelProviderCfg, summarizeOpts...)
	if err != nil {
		return nil, err
	}

	// A nil AllowedTools means "no restriction"; a non-nil (even empty) list
	// is an allow-list of tool names.
	agentTools := []tools.BaseTool{}
	if agentCfg.AllowedTools == nil {
		agentTools = allTools
	} else {
		for _, tool := range allTools {
			if slices.Contains(agentCfg.AllowedTools, tool.Name()) {
				agentTools = append(agentTools, tool)
			}
		}
	}

	agent := &agent{
		Broker:              pubsub.NewBroker[AgentEvent](),
		agentCfg:            agentCfg,
		provider:            agentProvider,
		providerID:          string(providerCfg.ID),
		messages:            messages,
		sessions:            sessions,
		tools:               agentTools,
		titleProvider:       titleProvider,
		summarizeProvider:   summarizeProvider,
		summarizeProviderID: string(smallModelProviderCfg.ID),
		activeRequests:      sync.Map{},
	}

	return agent, nil
}
228
229func (a *agent) Model() config.Model {
230 return config.GetAgentModel(a.agentCfg.ID)
231}
232
233func (a *agent) EffectiveMaxTokens() int64 {
234 return config.GetAgentEffectiveMaxTokens(a.agentCfg.ID)
235}
236
237func (a *agent) Cancel(sessionID string) {
238 // Cancel regular requests
239 if cancelFunc, exists := a.activeRequests.LoadAndDelete(sessionID); exists {
240 if cancel, ok := cancelFunc.(context.CancelFunc); ok {
241 logging.InfoPersist(fmt.Sprintf("Request cancellation initiated for session: %s", sessionID))
242 cancel()
243 }
244 }
245
246 // Also check for summarize requests
247 if cancelFunc, exists := a.activeRequests.LoadAndDelete(sessionID + "-summarize"); exists {
248 if cancel, ok := cancelFunc.(context.CancelFunc); ok {
249 logging.InfoPersist(fmt.Sprintf("Summarize cancellation initiated for session: %s", sessionID))
250 cancel()
251 }
252 }
253}
254
255func (a *agent) IsBusy() bool {
256 busy := false
257 a.activeRequests.Range(func(key, value any) bool {
258 if cancelFunc, ok := value.(context.CancelFunc); ok {
259 if cancelFunc != nil {
260 busy = true
261 return false
262 }
263 }
264 return true
265 })
266 return busy
267}
268
269func (a *agent) IsSessionBusy(sessionID string) bool {
270 _, busy := a.activeRequests.Load(sessionID)
271 return busy
272}
273
274func (a *agent) generateTitle(ctx context.Context, sessionID string, content string) error {
275 if content == "" {
276 return nil
277 }
278 if a.titleProvider == nil {
279 return nil
280 }
281 session, err := a.sessions.Get(ctx, sessionID)
282 if err != nil {
283 return err
284 }
285 parts := []message.ContentPart{message.TextContent{
286 Text: fmt.Sprintf("Generate a concise title for the following content:\n\n%s", content),
287 }}
288
289 // Use streaming approach like summarization
290 response := a.titleProvider.StreamResponse(
291 ctx,
292 []message.Message{
293 {
294 Role: message.User,
295 Parts: parts,
296 },
297 },
298 make([]tools.BaseTool, 0),
299 )
300
301 var finalResponse *provider.ProviderResponse
302 for r := range response {
303 if r.Error != nil {
304 return r.Error
305 }
306 finalResponse = r.Response
307 }
308
309 if finalResponse == nil {
310 return fmt.Errorf("no response received from title provider")
311 }
312
313 title := strings.TrimSpace(strings.ReplaceAll(finalResponse.Content, "\n", " "))
314 if title == "" {
315 return nil
316 }
317
318 session.Title = title
319 _, err = a.sessions.Save(ctx, session)
320 return err
321}
322
323func (a *agent) err(err error) AgentEvent {
324 return AgentEvent{
325 Type: AgentEventTypeError,
326 Error: err,
327 }
328}
329
330func (a *agent) Run(ctx context.Context, sessionID string, content string, attachments ...message.Attachment) (<-chan AgentEvent, error) {
331 if !a.Model().SupportsImages && attachments != nil {
332 attachments = nil
333 }
334 events := make(chan AgentEvent)
335 if a.IsSessionBusy(sessionID) {
336 return nil, ErrSessionBusy
337 }
338
339 genCtx, cancel := context.WithCancel(ctx)
340
341 a.activeRequests.Store(sessionID, cancel)
342 go func() {
343 logging.Debug("Request started", "sessionID", sessionID)
344 defer logging.RecoverPanic("agent.Run", func() {
345 events <- a.err(fmt.Errorf("panic while running the agent"))
346 })
347 var attachmentParts []message.ContentPart
348 for _, attachment := range attachments {
349 attachmentParts = append(attachmentParts, message.BinaryContent{Path: attachment.FilePath, MIMEType: attachment.MimeType, Data: attachment.Content})
350 }
351 result := a.processGeneration(genCtx, sessionID, content, attachmentParts)
352 if result.Error != nil && !errors.Is(result.Error, ErrRequestCancelled) && !errors.Is(result.Error, context.Canceled) {
353 logging.ErrorPersist(result.Error.Error())
354 }
355 logging.Debug("Request completed", "sessionID", sessionID)
356 a.activeRequests.Delete(sessionID)
357 cancel()
358 a.Publish(pubsub.CreatedEvent, result)
359 events <- result
360 close(events)
361 }()
362 return events, nil
363}
364
365func (a *agent) processGeneration(ctx context.Context, sessionID, content string, attachmentParts []message.ContentPart) AgentEvent {
366 // List existing messages; if none, start title generation asynchronously.
367 msgs, err := a.messages.List(ctx, sessionID)
368 if err != nil {
369 return a.err(fmt.Errorf("failed to list messages: %w", err))
370 }
371 if len(msgs) == 0 {
372 go func() {
373 defer logging.RecoverPanic("agent.Run", func() {
374 logging.ErrorPersist("panic while generating title")
375 })
376 titleErr := a.generateTitle(context.Background(), sessionID, content)
377 if titleErr != nil && !errors.Is(titleErr, context.Canceled) && !errors.Is(titleErr, context.DeadlineExceeded) {
378 logging.ErrorPersist(fmt.Sprintf("failed to generate title: %v", titleErr))
379 }
380 }()
381 }
382 session, err := a.sessions.Get(ctx, sessionID)
383 if err != nil {
384 return a.err(fmt.Errorf("failed to get session: %w", err))
385 }
386 if session.SummaryMessageID != "" {
387 summaryMsgInex := -1
388 for i, msg := range msgs {
389 if msg.ID == session.SummaryMessageID {
390 summaryMsgInex = i
391 break
392 }
393 }
394 if summaryMsgInex != -1 {
395 msgs = msgs[summaryMsgInex:]
396 msgs[0].Role = message.User
397 }
398 }
399
400 userMsg, err := a.createUserMessage(ctx, sessionID, content, attachmentParts)
401 if err != nil {
402 return a.err(fmt.Errorf("failed to create user message: %w", err))
403 }
404 // Append the new user message to the conversation history.
405 msgHistory := append(msgs, userMsg)
406
407 for {
408 // Check for cancellation before each iteration
409 select {
410 case <-ctx.Done():
411 return a.err(ctx.Err())
412 default:
413 // Continue processing
414 }
415 agentMessage, toolResults, err := a.streamAndHandleEvents(ctx, sessionID, msgHistory)
416 if err != nil {
417 if errors.Is(err, context.Canceled) {
418 agentMessage.AddFinish(message.FinishReasonCanceled)
419 a.messages.Update(context.Background(), agentMessage)
420 return a.err(ErrRequestCancelled)
421 }
422 return a.err(fmt.Errorf("failed to process events: %w", err))
423 }
424 logging.Info("Result", "message", agentMessage.FinishReason(), "toolResults", toolResults)
425 if (agentMessage.FinishReason() == message.FinishReasonToolUse) && toolResults != nil {
426 // We are not done, we need to respond with the tool response
427 msgHistory = append(msgHistory, agentMessage, *toolResults)
428 continue
429 }
430 return AgentEvent{
431 Type: AgentEventTypeResponse,
432 Message: agentMessage,
433 Done: true,
434 }
435 }
436}
437
438func (a *agent) createUserMessage(ctx context.Context, sessionID, content string, attachmentParts []message.ContentPart) (message.Message, error) {
439 parts := []message.ContentPart{message.TextContent{Text: content}}
440 parts = append(parts, attachmentParts...)
441 return a.messages.Create(ctx, sessionID, message.CreateMessageParams{
442 Role: message.User,
443 Parts: parts,
444 })
445}
446
447func (a *agent) streamAndHandleEvents(ctx context.Context, sessionID string, msgHistory []message.Message) (message.Message, *message.Message, error) {
448 eventChan := a.provider.StreamResponse(ctx, msgHistory, a.tools)
449
450 assistantMsg, err := a.messages.Create(ctx, sessionID, message.CreateMessageParams{
451 Role: message.Assistant,
452 Parts: []message.ContentPart{},
453 Model: a.Model().ID,
454 Provider: a.providerID,
455 })
456 if err != nil {
457 return assistantMsg, nil, fmt.Errorf("failed to create assistant message: %w", err)
458 }
459
460 // Add the session and message ID into the context if needed by tools.
461 ctx = context.WithValue(ctx, tools.MessageIDContextKey, assistantMsg.ID)
462 ctx = context.WithValue(ctx, tools.SessionIDContextKey, sessionID)
463
464 // Process each event in the stream.
465 for event := range eventChan {
466 if processErr := a.processEvent(ctx, sessionID, &assistantMsg, event); processErr != nil {
467 a.finishMessage(ctx, &assistantMsg, message.FinishReasonCanceled)
468 return assistantMsg, nil, processErr
469 }
470 if ctx.Err() != nil {
471 a.finishMessage(context.Background(), &assistantMsg, message.FinishReasonCanceled)
472 return assistantMsg, nil, ctx.Err()
473 }
474 }
475
476 toolResults := make([]message.ToolResult, len(assistantMsg.ToolCalls()))
477 toolCalls := assistantMsg.ToolCalls()
478 for i, toolCall := range toolCalls {
479 select {
480 case <-ctx.Done():
481 a.finishMessage(context.Background(), &assistantMsg, message.FinishReasonCanceled)
482 // Make all future tool calls cancelled
483 for j := i; j < len(toolCalls); j++ {
484 toolResults[j] = message.ToolResult{
485 ToolCallID: toolCalls[j].ID,
486 Content: "Tool execution canceled by user",
487 IsError: true,
488 }
489 }
490 goto out
491 default:
492 // Continue processing
493 var tool tools.BaseTool
494 for _, availableTools := range a.tools {
495 if availableTools.Info().Name == toolCall.Name {
496 tool = availableTools
497 }
498 }
499
500 // Tool not found
501 if tool == nil {
502 toolResults[i] = message.ToolResult{
503 ToolCallID: toolCall.ID,
504 Content: fmt.Sprintf("Tool not found: %s", toolCall.Name),
505 IsError: true,
506 }
507 continue
508 }
509 toolResult, toolErr := tool.Run(ctx, tools.ToolCall{
510 ID: toolCall.ID,
511 Name: toolCall.Name,
512 Input: toolCall.Input,
513 })
514 if toolErr != nil {
515 if errors.Is(toolErr, permission.ErrorPermissionDenied) {
516 toolResults[i] = message.ToolResult{
517 ToolCallID: toolCall.ID,
518 Content: "Permission denied",
519 IsError: true,
520 }
521 for j := i + 1; j < len(toolCalls); j++ {
522 toolResults[j] = message.ToolResult{
523 ToolCallID: toolCalls[j].ID,
524 Content: "Tool execution canceled by user",
525 IsError: true,
526 }
527 }
528 a.finishMessage(ctx, &assistantMsg, message.FinishReasonPermissionDenied)
529 break
530 }
531 }
532 toolResults[i] = message.ToolResult{
533 ToolCallID: toolCall.ID,
534 Content: toolResult.Content,
535 Metadata: toolResult.Metadata,
536 IsError: toolResult.IsError,
537 }
538 }
539 }
540out:
541 if len(toolResults) == 0 {
542 return assistantMsg, nil, nil
543 }
544 parts := make([]message.ContentPart, 0)
545 for _, tr := range toolResults {
546 parts = append(parts, tr)
547 }
548 msg, err := a.messages.Create(context.Background(), assistantMsg.SessionID, message.CreateMessageParams{
549 Role: message.Tool,
550 Parts: parts,
551 Provider: a.providerID,
552 })
553 if err != nil {
554 return assistantMsg, nil, fmt.Errorf("failed to create cancelled tool message: %w", err)
555 }
556
557 return assistantMsg, &msg, err
558}
559
560func (a *agent) finishMessage(ctx context.Context, msg *message.Message, finishReson message.FinishReason) {
561 msg.AddFinish(finishReson)
562 _ = a.messages.Update(ctx, *msg)
563}
564
565func (a *agent) processEvent(ctx context.Context, sessionID string, assistantMsg *message.Message, event provider.ProviderEvent) error {
566 select {
567 case <-ctx.Done():
568 return ctx.Err()
569 default:
570 // Continue processing.
571 }
572
573 switch event.Type {
574 case provider.EventThinkingDelta:
575 assistantMsg.AppendReasoningContent(event.Content)
576 return a.messages.Update(ctx, *assistantMsg)
577 case provider.EventContentDelta:
578 assistantMsg.AppendContent(event.Content)
579 return a.messages.Update(ctx, *assistantMsg)
580 case provider.EventToolUseStart:
581 logging.Info("Tool call started", "toolCall", event.ToolCall)
582 assistantMsg.AddToolCall(*event.ToolCall)
583 return a.messages.Update(ctx, *assistantMsg)
584 case provider.EventToolUseDelta:
585 assistantMsg.AppendToolCallInput(event.ToolCall.ID, event.ToolCall.Input)
586 return a.messages.Update(ctx, *assistantMsg)
587 case provider.EventToolUseStop:
588 logging.Info("Finished tool call", "toolCall", event.ToolCall)
589 assistantMsg.FinishToolCall(event.ToolCall.ID)
590 return a.messages.Update(ctx, *assistantMsg)
591 case provider.EventError:
592 if errors.Is(event.Error, context.Canceled) {
593 logging.InfoPersist(fmt.Sprintf("Event processing canceled for session: %s", sessionID))
594 return context.Canceled
595 }
596 logging.ErrorPersist(event.Error.Error())
597 return event.Error
598 case provider.EventComplete:
599 assistantMsg.SetToolCalls(event.Response.ToolCalls)
600 assistantMsg.AddFinish(event.Response.FinishReason)
601 if err := a.messages.Update(ctx, *assistantMsg); err != nil {
602 return fmt.Errorf("failed to update message: %w", err)
603 }
604 return a.TrackUsage(ctx, sessionID, a.Model(), event.Response.Usage)
605 }
606
607 return nil
608}
609
610func (a *agent) TrackUsage(ctx context.Context, sessionID string, model config.Model, usage provider.TokenUsage) error {
611 sess, err := a.sessions.Get(ctx, sessionID)
612 if err != nil {
613 return fmt.Errorf("failed to get session: %w", err)
614 }
615
616 cost := model.CostPer1MInCached/1e6*float64(usage.CacheCreationTokens) +
617 model.CostPer1MOutCached/1e6*float64(usage.CacheReadTokens) +
618 model.CostPer1MIn/1e6*float64(usage.InputTokens) +
619 model.CostPer1MOut/1e6*float64(usage.OutputTokens)
620
621 sess.Cost += cost
622 sess.CompletionTokens = usage.OutputTokens + usage.CacheReadTokens
623 sess.PromptTokens = usage.InputTokens + usage.CacheCreationTokens
624
625 _, err = a.sessions.Save(ctx, sess)
626 if err != nil {
627 return fmt.Errorf("failed to save session: %w", err)
628 }
629 return nil
630}
631
632func (a *agent) Summarize(ctx context.Context, sessionID string) error {
633 if a.summarizeProvider == nil {
634 return fmt.Errorf("summarize provider not available")
635 }
636
637 // Check if session is busy
638 if a.IsSessionBusy(sessionID) {
639 return ErrSessionBusy
640 }
641
642 // Create a new context with cancellation
643 summarizeCtx, cancel := context.WithCancel(ctx)
644
645 // Store the cancel function in activeRequests to allow cancellation
646 a.activeRequests.Store(sessionID+"-summarize", cancel)
647
648 go func() {
649 defer a.activeRequests.Delete(sessionID + "-summarize")
650 defer cancel()
651 event := AgentEvent{
652 Type: AgentEventTypeSummarize,
653 Progress: "Starting summarization...",
654 }
655
656 a.Publish(pubsub.CreatedEvent, event)
657 // Get all messages from the session
658 msgs, err := a.messages.List(summarizeCtx, sessionID)
659 if err != nil {
660 event = AgentEvent{
661 Type: AgentEventTypeError,
662 Error: fmt.Errorf("failed to list messages: %w", err),
663 Done: true,
664 }
665 a.Publish(pubsub.CreatedEvent, event)
666 return
667 }
668
669 if len(msgs) == 0 {
670 event = AgentEvent{
671 Type: AgentEventTypeError,
672 Error: fmt.Errorf("no messages to summarize"),
673 Done: true,
674 }
675 a.Publish(pubsub.CreatedEvent, event)
676 return
677 }
678
679 event = AgentEvent{
680 Type: AgentEventTypeSummarize,
681 Progress: "Analyzing conversation...",
682 }
683 a.Publish(pubsub.CreatedEvent, event)
684
685 // Add a system message to guide the summarization
686 summarizePrompt := "Provide a detailed but concise summary of our conversation above. Focus on information that would be helpful for continuing the conversation, including what we did, what we're doing, which files we're working on, and what we're going to do next."
687
688 // Create a new message with the summarize prompt
689 promptMsg := message.Message{
690 Role: message.User,
691 Parts: []message.ContentPart{message.TextContent{Text: summarizePrompt}},
692 }
693
694 // Append the prompt to the messages
695 msgsWithPrompt := append(msgs, promptMsg)
696
697 event = AgentEvent{
698 Type: AgentEventTypeSummarize,
699 Progress: "Generating summary...",
700 }
701
702 a.Publish(pubsub.CreatedEvent, event)
703
704 // Send the messages to the summarize provider
705 response := a.summarizeProvider.StreamResponse(
706 summarizeCtx,
707 msgsWithPrompt,
708 make([]tools.BaseTool, 0),
709 )
710 var finalResponse *provider.ProviderResponse
711 for r := range response {
712 if r.Error != nil {
713 event = AgentEvent{
714 Type: AgentEventTypeError,
715 Error: fmt.Errorf("failed to summarize: %w", err),
716 Done: true,
717 }
718 a.Publish(pubsub.CreatedEvent, event)
719 return
720 }
721 finalResponse = r.Response
722 }
723
724 summary := strings.TrimSpace(finalResponse.Content)
725 if summary == "" {
726 event = AgentEvent{
727 Type: AgentEventTypeError,
728 Error: fmt.Errorf("empty summary returned"),
729 Done: true,
730 }
731 a.Publish(pubsub.CreatedEvent, event)
732 return
733 }
734 event = AgentEvent{
735 Type: AgentEventTypeSummarize,
736 Progress: "Creating new session...",
737 }
738
739 a.Publish(pubsub.CreatedEvent, event)
740 oldSession, err := a.sessions.Get(summarizeCtx, sessionID)
741 if err != nil {
742 event = AgentEvent{
743 Type: AgentEventTypeError,
744 Error: fmt.Errorf("failed to get session: %w", err),
745 Done: true,
746 }
747
748 a.Publish(pubsub.CreatedEvent, event)
749 return
750 }
751 // Create a message in the new session with the summary
752 msg, err := a.messages.Create(summarizeCtx, oldSession.ID, message.CreateMessageParams{
753 Role: message.Assistant,
754 Parts: []message.ContentPart{
755 message.TextContent{Text: summary},
756 message.Finish{
757 Reason: message.FinishReasonEndTurn,
758 Time: time.Now().Unix(),
759 },
760 },
761 Model: a.summarizeProvider.Model().ID,
762 Provider: a.summarizeProviderID,
763 })
764 if err != nil {
765 event = AgentEvent{
766 Type: AgentEventTypeError,
767 Error: fmt.Errorf("failed to create summary message: %w", err),
768 Done: true,
769 }
770
771 a.Publish(pubsub.CreatedEvent, event)
772 return
773 }
774 oldSession.SummaryMessageID = msg.ID
775 oldSession.CompletionTokens = finalResponse.Usage.OutputTokens
776 oldSession.PromptTokens = 0
777 model := a.summarizeProvider.Model()
778 usage := finalResponse.Usage
779 cost := model.CostPer1MInCached/1e6*float64(usage.CacheCreationTokens) +
780 model.CostPer1MOutCached/1e6*float64(usage.CacheReadTokens) +
781 model.CostPer1MIn/1e6*float64(usage.InputTokens) +
782 model.CostPer1MOut/1e6*float64(usage.OutputTokens)
783 oldSession.Cost += cost
784 _, err = a.sessions.Save(summarizeCtx, oldSession)
785 if err != nil {
786 event = AgentEvent{
787 Type: AgentEventTypeError,
788 Error: fmt.Errorf("failed to save session: %w", err),
789 Done: true,
790 }
791 a.Publish(pubsub.CreatedEvent, event)
792 }
793
794 event = AgentEvent{
795 Type: AgentEventTypeSummarize,
796 SessionID: oldSession.ID,
797 Progress: "Summary complete",
798 Done: true,
799 }
800 a.Publish(pubsub.CreatedEvent, event)
801 // Send final success event with the new session ID
802 }()
803
804 return nil
805}
806
807func (a *agent) CancelAll() {
808 a.activeRequests.Range(func(key, value any) bool {
809 a.Cancel(key.(string)) // key is sessionID
810 return true
811 })
812}
813
814func (a *agent) UpdateModel() error {
815 cfg := config.Get()
816
817 // Get current provider configuration
818 currentProviderCfg := config.GetAgentProvider(a.agentCfg.ID)
819 if currentProviderCfg.ID == "" {
820 return fmt.Errorf("provider for agent %s not found in config", a.agentCfg.Name)
821 }
822
823 // Check if provider has changed
824 if string(currentProviderCfg.ID) != a.providerID {
825 // Provider changed, need to recreate the main provider
826 model := config.GetAgentModel(a.agentCfg.ID)
827 if model.ID == "" {
828 return fmt.Errorf("model not found for agent %s", a.agentCfg.Name)
829 }
830
831 promptID := agentPromptMap[a.agentCfg.ID]
832 if promptID == "" {
833 promptID = prompt.PromptDefault
834 }
835
836 opts := []provider.ProviderClientOption{
837 provider.WithModel(a.agentCfg.Model),
838 provider.WithSystemMessage(prompt.GetPrompt(promptID, currentProviderCfg.ID)),
839 }
840
841 newProvider, err := provider.NewProvider(currentProviderCfg, opts...)
842 if err != nil {
843 return fmt.Errorf("failed to create new provider: %w", err)
844 }
845
846 // Update the provider and provider ID
847 a.provider = newProvider
848 a.providerID = string(currentProviderCfg.ID)
849 }
850
851 // Check if small model provider has changed (affects title and summarize providers)
852 smallModelCfg := cfg.Models.Small
853 var smallModelProviderCfg config.ProviderConfig
854
855 for _, p := range cfg.Providers {
856 if p.ID == smallModelCfg.Provider {
857 smallModelProviderCfg = p
858 break
859 }
860 }
861
862 if smallModelProviderCfg.ID == "" {
863 return fmt.Errorf("provider %s not found in config", smallModelCfg.Provider)
864 }
865
866 // Check if summarize provider has changed
867 if string(smallModelProviderCfg.ID) != a.summarizeProviderID {
868 var smallModel config.Model
869 for _, m := range smallModelProviderCfg.Models {
870 if m.ID == smallModelCfg.ModelID {
871 smallModel = m
872 break
873 }
874 }
875 if smallModel.ID == "" {
876 return fmt.Errorf("model %s not found in provider %s", smallModelCfg.ModelID, smallModelProviderCfg.ID)
877 }
878
879 // Recreate title provider
880 titleOpts := []provider.ProviderClientOption{
881 provider.WithModel(config.SmallModel),
882 provider.WithSystemMessage(prompt.GetPrompt(prompt.PromptTitle, smallModelProviderCfg.ID)),
883 // We want the title to be short, so we limit the max tokens
884 provider.WithMaxTokens(40),
885 }
886 newTitleProvider, err := provider.NewProvider(smallModelProviderCfg, titleOpts...)
887 if err != nil {
888 return fmt.Errorf("failed to create new title provider: %w", err)
889 }
890
891 // Recreate summarize provider
892 summarizeOpts := []provider.ProviderClientOption{
893 provider.WithModel(config.SmallModel),
894 provider.WithSystemMessage(prompt.GetPrompt(prompt.PromptSummarizer, smallModelProviderCfg.ID)),
895 }
896 newSummarizeProvider, err := provider.NewProvider(smallModelProviderCfg, summarizeOpts...)
897 if err != nil {
898 return fmt.Errorf("failed to create new summarize provider: %w", err)
899 }
900
901 // Update the providers and provider ID
902 a.titleProvider = newTitleProvider
903 a.summarizeProvider = newSummarizeProvider
904 a.summarizeProviderID = string(smallModelProviderCfg.ID)
905 }
906
907 return nil
908}