1package agent
2
3import (
4 "context"
5 "errors"
6 "fmt"
7 "slices"
8 "strings"
9 "sync"
10 "time"
11
12 "github.com/charmbracelet/crush/internal/config"
13 "github.com/charmbracelet/crush/internal/history"
14 "github.com/charmbracelet/crush/internal/llm/prompt"
15 "github.com/charmbracelet/crush/internal/llm/provider"
16 "github.com/charmbracelet/crush/internal/llm/tools"
17 "github.com/charmbracelet/crush/internal/logging"
18 "github.com/charmbracelet/crush/internal/lsp"
19 "github.com/charmbracelet/crush/internal/message"
20 "github.com/charmbracelet/crush/internal/permission"
21 "github.com/charmbracelet/crush/internal/pubsub"
22 "github.com/charmbracelet/crush/internal/session"
23)
24
// ErrRequestCancelled is reported (inside an error AgentEvent) when a running
// request is cancelled.
var ErrRequestCancelled = errors.New("request cancelled by user")

// ErrSessionBusy is returned when work is requested for a session that
// already has a request in flight.
var ErrSessionBusy = errors.New("session is currently processing another request")
30
// AgentEventType identifies what kind of AgentEvent is being delivered.
type AgentEventType string

// Kinds of events emitted by the agent.
const (
	// AgentEventTypeError events carry a failure in AgentEvent.Error.
	AgentEventTypeError AgentEventType = "error"
	// AgentEventTypeResponse events carry the assistant message in AgentEvent.Message.
	AgentEventTypeResponse AgentEventType = "response"
	// AgentEventTypeSummarize events report summarization progress.
	AgentEventTypeSummarize AgentEventType = "summarize"
)
38
// AgentEvent is the payload published to subscribers and sent on the channel
// returned by Run.
type AgentEvent struct {
	Type AgentEventType
	// Message is the assistant message; set on AgentEventTypeResponse events.
	Message message.Message
	// Error is the failure; set on AgentEventTypeError events.
	Error error

	// When summarizing
	// SessionID is the summarized session; set on the final successful
	// summarize event.
	SessionID string
	// Progress is a human-readable status line for summarize events.
	Progress string
	// Done is true when the event ends an operation (final response events,
	// and terminal summarize/error events published by Summarize). Error
	// events built by agent.err leave it false.
	Done bool
}
49
// Service is the interface exposed by an agent. Subscribers receive the
// AgentEvents the agent publishes while it works.
type Service interface {
	pubsub.Suscriber[AgentEvent]
	// Model returns the model configured for this agent.
	Model() config.Model
	// EffectiveMaxTokens returns the max-token limit configured for this agent.
	EffectiveMaxTokens() int64
	// Run starts a request for sessionID and returns a channel delivering its
	// final event. It fails with ErrSessionBusy if the session already has a
	// request running.
	Run(ctx context.Context, sessionID string, content string, attachments ...message.Attachment) (<-chan AgentEvent, error)
	// Cancel cancels the running request and summarization for sessionID, if any.
	Cancel(sessionID string)
	// CancelAll cancels every running request and summarization.
	CancelAll()
	// IsSessionBusy reports whether a (non-summarize) request is running for
	// sessionID.
	IsSessionBusy(sessionID string) bool
	// IsBusy reports whether any request is running.
	IsBusy() bool
	// Summarize starts summarizing sessionID in the background; progress and
	// the outcome are published as events.
	Summarize(ctx context.Context, sessionID string) error
	// UpdateModel rebuilds providers after the model configuration changed.
	UpdateModel() error
}
62
// agent implements Service; NewAgent constructs it.
type agent struct {
	*pubsub.Broker[AgentEvent]
	agentCfg config.Agent
	sessions session.Service
	messages message.Service

	// tools are offered to the main provider on every request.
	tools []tools.BaseTool
	// provider is the main chat provider; providerID is the ID of its
	// provider config, compared by UpdateModel to detect changes.
	provider   provider.Provider
	providerID string

	// titleProvider and summarizeProvider both use the small model;
	// summarizeProviderID is the ID of that small-model provider config.
	titleProvider       provider.Provider
	summarizeProvider   provider.Provider
	summarizeProviderID string

	// activeRequests maps a session ID (regular requests) or
	// sessionID+"-summarize" (summarizations) to the context.CancelFunc of
	// the work running for it.
	activeRequests sync.Map
}
79
// agentPromptMap selects the system prompt used for each agent. Agents that
// are not listed fall back to prompt.PromptDefault when providers are built.
var agentPromptMap = map[config.AgentID]prompt.PromptID{
	config.AgentCoder: prompt.PromptCoder,
	config.AgentTask:  prompt.PromptTask,
}
84
85func NewAgent(
86 agentCfg config.Agent,
87 // These services are needed in the tools
88 permissions permission.Service,
89 sessions session.Service,
90 messages message.Service,
91 history history.Service,
92 lspClients map[string]*lsp.Client,
93) (Service, error) {
94 ctx := context.Background()
95 cfg := config.Get()
96 otherTools := GetMcpTools(ctx, permissions)
97 if len(lspClients) > 0 {
98 otherTools = append(otherTools, tools.NewDiagnosticsTool(lspClients))
99 }
100
101 allTools := []tools.BaseTool{
102 tools.NewBashTool(permissions),
103 tools.NewEditTool(lspClients, permissions, history),
104 tools.NewFetchTool(permissions),
105 tools.NewGlobTool(),
106 tools.NewGrepTool(),
107 tools.NewLsTool(),
108 tools.NewSourcegraphTool(),
109 tools.NewViewTool(lspClients),
110 tools.NewWriteTool(lspClients, permissions, history),
111 }
112
113 if agentCfg.ID == config.AgentCoder {
114 taskAgentCfg := config.Get().Agents[config.AgentTask]
115 if taskAgentCfg.ID == "" {
116 return nil, fmt.Errorf("task agent not found in config")
117 }
118 taskAgent, err := NewAgent(taskAgentCfg, permissions, sessions, messages, history, lspClients)
119 if err != nil {
120 return nil, fmt.Errorf("failed to create task agent: %w", err)
121 }
122
123 allTools = append(
124 allTools,
125 NewAgentTool(
126 taskAgent,
127 sessions,
128 messages,
129 ),
130 )
131 }
132
133 allTools = append(allTools, otherTools...)
134 providerCfg := config.GetAgentProvider(agentCfg.ID)
135 if providerCfg.ID == "" {
136 return nil, fmt.Errorf("provider for agent %s not found in config", agentCfg.Name)
137 }
138 model := config.GetAgentModel(agentCfg.ID)
139
140 if model.ID == "" {
141 return nil, fmt.Errorf("model not found for agent %s", agentCfg.Name)
142 }
143
144 promptID := agentPromptMap[agentCfg.ID]
145 if promptID == "" {
146 promptID = prompt.PromptDefault
147 }
148 opts := []provider.ProviderClientOption{
149 provider.WithModel(agentCfg.Model),
150 provider.WithSystemMessage(prompt.GetPrompt(promptID, providerCfg.ID)),
151 }
152 agentProvider, err := provider.NewProvider(providerCfg, opts...)
153 if err != nil {
154 return nil, err
155 }
156
157 smallModelCfg := cfg.Models.Small
158 var smallModel config.Model
159
160 var smallModelProviderCfg config.ProviderConfig
161 if smallModelCfg.Provider == providerCfg.ID {
162 smallModelProviderCfg = providerCfg
163 } else {
164 for _, p := range cfg.Providers {
165 if p.ID == smallModelCfg.Provider {
166 smallModelProviderCfg = p
167 break
168 }
169 }
170 if smallModelProviderCfg.ID == "" {
171 return nil, fmt.Errorf("provider %s not found in config", smallModelCfg.Provider)
172 }
173 }
174 for _, m := range smallModelProviderCfg.Models {
175 if m.ID == smallModelCfg.ModelID {
176 smallModel = m
177 break
178 }
179 }
180 if smallModel.ID == "" {
181 return nil, fmt.Errorf("model %s not found in provider %s", smallModelCfg.ModelID, smallModelProviderCfg.ID)
182 }
183
184 titleOpts := []provider.ProviderClientOption{
185 provider.WithModel(config.SmallModel),
186 provider.WithSystemMessage(prompt.GetPrompt(prompt.PromptTitle, smallModelProviderCfg.ID)),
187 }
188 titleProvider, err := provider.NewProvider(smallModelProviderCfg, titleOpts...)
189 if err != nil {
190 return nil, err
191 }
192 summarizeOpts := []provider.ProviderClientOption{
193 provider.WithModel(config.SmallModel),
194 provider.WithSystemMessage(prompt.GetPrompt(prompt.PromptSummarizer, smallModelProviderCfg.ID)),
195 }
196 summarizeProvider, err := provider.NewProvider(smallModelProviderCfg, summarizeOpts...)
197 if err != nil {
198 return nil, err
199 }
200
201 agentTools := []tools.BaseTool{}
202 if agentCfg.AllowedTools == nil {
203 agentTools = allTools
204 } else {
205 for _, tool := range allTools {
206 if slices.Contains(agentCfg.AllowedTools, tool.Name()) {
207 agentTools = append(agentTools, tool)
208 }
209 }
210 }
211
212 agent := &agent{
213 Broker: pubsub.NewBroker[AgentEvent](),
214 agentCfg: agentCfg,
215 provider: agentProvider,
216 providerID: string(providerCfg.ID),
217 messages: messages,
218 sessions: sessions,
219 tools: agentTools,
220 titleProvider: titleProvider,
221 summarizeProvider: summarizeProvider,
222 summarizeProviderID: string(smallModelProviderCfg.ID),
223 activeRequests: sync.Map{},
224 }
225
226 return agent, nil
227}
228
229func (a *agent) Model() config.Model {
230 return config.GetAgentModel(a.agentCfg.ID)
231}
232
233func (a *agent) EffectiveMaxTokens() int64 {
234 return config.GetAgentEffectiveMaxTokens(a.agentCfg.ID)
235}
236
237func (a *agent) Cancel(sessionID string) {
238 // Cancel regular requests
239 if cancelFunc, exists := a.activeRequests.LoadAndDelete(sessionID); exists {
240 if cancel, ok := cancelFunc.(context.CancelFunc); ok {
241 logging.InfoPersist(fmt.Sprintf("Request cancellation initiated for session: %s", sessionID))
242 cancel()
243 }
244 }
245
246 // Also check for summarize requests
247 if cancelFunc, exists := a.activeRequests.LoadAndDelete(sessionID + "-summarize"); exists {
248 if cancel, ok := cancelFunc.(context.CancelFunc); ok {
249 logging.InfoPersist(fmt.Sprintf("Summarize cancellation initiated for session: %s", sessionID))
250 cancel()
251 }
252 }
253}
254
255func (a *agent) IsBusy() bool {
256 busy := false
257 a.activeRequests.Range(func(key, value any) bool {
258 if cancelFunc, ok := value.(context.CancelFunc); ok {
259 if cancelFunc != nil {
260 busy = true
261 return false
262 }
263 }
264 return true
265 })
266 return busy
267}
268
269func (a *agent) IsSessionBusy(sessionID string) bool {
270 _, busy := a.activeRequests.Load(sessionID)
271 return busy
272}
273
274func (a *agent) generateTitle(ctx context.Context, sessionID string, content string) error {
275 if content == "" {
276 return nil
277 }
278 if a.titleProvider == nil {
279 return nil
280 }
281 session, err := a.sessions.Get(ctx, sessionID)
282 if err != nil {
283 return err
284 }
285 parts := []message.ContentPart{message.TextContent{
286 Text: fmt.Sprintf("Generate a concise title for the following content:\n\n%s", content),
287 }}
288
289 // Use streaming approach like summarization
290 response := a.titleProvider.StreamResponse(
291 ctx,
292 []message.Message{
293 {
294 Role: message.User,
295 Parts: parts,
296 },
297 },
298 make([]tools.BaseTool, 0),
299 )
300
301 var finalResponse *provider.ProviderResponse
302 for r := range response {
303 if r.Error != nil {
304 return r.Error
305 }
306 finalResponse = r.Response
307 }
308
309 if finalResponse == nil {
310 return fmt.Errorf("no response received from title provider")
311 }
312
313 title := strings.TrimSpace(strings.ReplaceAll(finalResponse.Content, "\n", " "))
314 if title == "" {
315 return nil
316 }
317
318 session.Title = title
319 _, err = a.sessions.Save(ctx, session)
320 return err
321}
322
323func (a *agent) err(err error) AgentEvent {
324 return AgentEvent{
325 Type: AgentEventTypeError,
326 Error: err,
327 }
328}
329
330func (a *agent) Run(ctx context.Context, sessionID string, content string, attachments ...message.Attachment) (<-chan AgentEvent, error) {
331 if !a.Model().SupportsImages && attachments != nil {
332 attachments = nil
333 }
334 events := make(chan AgentEvent)
335 if a.IsSessionBusy(sessionID) {
336 return nil, ErrSessionBusy
337 }
338
339 genCtx, cancel := context.WithCancel(ctx)
340
341 a.activeRequests.Store(sessionID, cancel)
342 go func() {
343 logging.Debug("Request started", "sessionID", sessionID)
344 defer logging.RecoverPanic("agent.Run", func() {
345 events <- a.err(fmt.Errorf("panic while running the agent"))
346 })
347 var attachmentParts []message.ContentPart
348 for _, attachment := range attachments {
349 attachmentParts = append(attachmentParts, message.BinaryContent{Path: attachment.FilePath, MIMEType: attachment.MimeType, Data: attachment.Content})
350 }
351 result := a.processGeneration(genCtx, sessionID, content, attachmentParts)
352 if result.Error != nil && !errors.Is(result.Error, ErrRequestCancelled) && !errors.Is(result.Error, context.Canceled) {
353 logging.ErrorPersist(result.Error.Error())
354 }
355 logging.Debug("Request completed", "sessionID", sessionID)
356 a.activeRequests.Delete(sessionID)
357 cancel()
358 a.Publish(pubsub.CreatedEvent, result)
359 events <- result
360 close(events)
361 }()
362 return events, nil
363}
364
365func (a *agent) processGeneration(ctx context.Context, sessionID, content string, attachmentParts []message.ContentPart) AgentEvent {
366 cfg := config.Get()
367 // List existing messages; if none, start title generation asynchronously.
368 msgs, err := a.messages.List(ctx, sessionID)
369 if err != nil {
370 return a.err(fmt.Errorf("failed to list messages: %w", err))
371 }
372 if len(msgs) == 0 {
373 go func() {
374 defer logging.RecoverPanic("agent.Run", func() {
375 logging.ErrorPersist("panic while generating title")
376 })
377 titleErr := a.generateTitle(context.Background(), sessionID, content)
378 if titleErr != nil && !errors.Is(titleErr, context.Canceled) && !errors.Is(titleErr, context.DeadlineExceeded) {
379 logging.ErrorPersist(fmt.Sprintf("failed to generate title: %v", titleErr))
380 }
381 }()
382 }
383 session, err := a.sessions.Get(ctx, sessionID)
384 if err != nil {
385 return a.err(fmt.Errorf("failed to get session: %w", err))
386 }
387 if session.SummaryMessageID != "" {
388 summaryMsgInex := -1
389 for i, msg := range msgs {
390 if msg.ID == session.SummaryMessageID {
391 summaryMsgInex = i
392 break
393 }
394 }
395 if summaryMsgInex != -1 {
396 msgs = msgs[summaryMsgInex:]
397 msgs[0].Role = message.User
398 }
399 }
400
401 userMsg, err := a.createUserMessage(ctx, sessionID, content, attachmentParts)
402 if err != nil {
403 return a.err(fmt.Errorf("failed to create user message: %w", err))
404 }
405 // Append the new user message to the conversation history.
406 msgHistory := append(msgs, userMsg)
407
408 for {
409 // Check for cancellation before each iteration
410 select {
411 case <-ctx.Done():
412 return a.err(ctx.Err())
413 default:
414 // Continue processing
415 }
416 agentMessage, toolResults, err := a.streamAndHandleEvents(ctx, sessionID, msgHistory)
417 if err != nil {
418 if errors.Is(err, context.Canceled) {
419 agentMessage.AddFinish(message.FinishReasonCanceled)
420 a.messages.Update(context.Background(), agentMessage)
421 return a.err(ErrRequestCancelled)
422 }
423 return a.err(fmt.Errorf("failed to process events: %w", err))
424 }
425 if cfg.Options.Debug {
426 seqId := (len(msgHistory) + 1) / 2
427 toolResultFilepath := logging.WriteToolResultsJson(sessionID, seqId, toolResults)
428 logging.Info("Result", "message", agentMessage.FinishReason(), "toolResults", "{}", "filepath", toolResultFilepath)
429 } else {
430 logging.Info("Result", "message", agentMessage.FinishReason(), "toolResults", toolResults)
431 }
432 if (agentMessage.FinishReason() == message.FinishReasonToolUse) && toolResults != nil {
433 // We are not done, we need to respond with the tool response
434 msgHistory = append(msgHistory, agentMessage, *toolResults)
435 continue
436 }
437 return AgentEvent{
438 Type: AgentEventTypeResponse,
439 Message: agentMessage,
440 Done: true,
441 }
442 }
443}
444
445func (a *agent) createUserMessage(ctx context.Context, sessionID, content string, attachmentParts []message.ContentPart) (message.Message, error) {
446 parts := []message.ContentPart{message.TextContent{Text: content}}
447 parts = append(parts, attachmentParts...)
448 return a.messages.Create(ctx, sessionID, message.CreateMessageParams{
449 Role: message.User,
450 Parts: parts,
451 })
452}
453
454func (a *agent) streamAndHandleEvents(ctx context.Context, sessionID string, msgHistory []message.Message) (message.Message, *message.Message, error) {
455 ctx = context.WithValue(ctx, tools.SessionIDContextKey, sessionID)
456 eventChan := a.provider.StreamResponse(ctx, msgHistory, a.tools)
457
458 assistantMsg, err := a.messages.Create(ctx, sessionID, message.CreateMessageParams{
459 Role: message.Assistant,
460 Parts: []message.ContentPart{},
461 Model: a.Model().ID,
462 Provider: a.providerID,
463 })
464 if err != nil {
465 return assistantMsg, nil, fmt.Errorf("failed to create assistant message: %w", err)
466 }
467
468 // Add the session and message ID into the context if needed by tools.
469 ctx = context.WithValue(ctx, tools.MessageIDContextKey, assistantMsg.ID)
470
471 // Process each event in the stream.
472 for event := range eventChan {
473 if processErr := a.processEvent(ctx, sessionID, &assistantMsg, event); processErr != nil {
474 a.finishMessage(ctx, &assistantMsg, message.FinishReasonCanceled)
475 return assistantMsg, nil, processErr
476 }
477 if ctx.Err() != nil {
478 a.finishMessage(context.Background(), &assistantMsg, message.FinishReasonCanceled)
479 return assistantMsg, nil, ctx.Err()
480 }
481 }
482
483 toolResults := make([]message.ToolResult, len(assistantMsg.ToolCalls()))
484 toolCalls := assistantMsg.ToolCalls()
485 for i, toolCall := range toolCalls {
486 select {
487 case <-ctx.Done():
488 a.finishMessage(context.Background(), &assistantMsg, message.FinishReasonCanceled)
489 // Make all future tool calls cancelled
490 for j := i; j < len(toolCalls); j++ {
491 toolResults[j] = message.ToolResult{
492 ToolCallID: toolCalls[j].ID,
493 Content: "Tool execution canceled by user",
494 IsError: true,
495 }
496 }
497 goto out
498 default:
499 // Continue processing
500 var tool tools.BaseTool
501 for _, availableTool := range a.tools {
502 if availableTool.Info().Name == toolCall.Name {
503 tool = availableTool
504 break
505 }
506 }
507
508 // Tool not found
509 if tool == nil {
510 toolResults[i] = message.ToolResult{
511 ToolCallID: toolCall.ID,
512 Content: fmt.Sprintf("Tool not found: %s", toolCall.Name),
513 IsError: true,
514 }
515 continue
516 }
517 toolResult, toolErr := tool.Run(ctx, tools.ToolCall{
518 ID: toolCall.ID,
519 Name: toolCall.Name,
520 Input: toolCall.Input,
521 })
522 if toolErr != nil {
523 if errors.Is(toolErr, permission.ErrorPermissionDenied) {
524 toolResults[i] = message.ToolResult{
525 ToolCallID: toolCall.ID,
526 Content: "Permission denied",
527 IsError: true,
528 }
529 for j := i + 1; j < len(toolCalls); j++ {
530 toolResults[j] = message.ToolResult{
531 ToolCallID: toolCalls[j].ID,
532 Content: "Tool execution canceled by user",
533 IsError: true,
534 }
535 }
536 a.finishMessage(ctx, &assistantMsg, message.FinishReasonPermissionDenied)
537 break
538 }
539 }
540 toolResults[i] = message.ToolResult{
541 ToolCallID: toolCall.ID,
542 Content: toolResult.Content,
543 Metadata: toolResult.Metadata,
544 IsError: toolResult.IsError,
545 }
546 }
547 }
548out:
549 if len(toolResults) == 0 {
550 return assistantMsg, nil, nil
551 }
552 parts := make([]message.ContentPart, 0)
553 for _, tr := range toolResults {
554 parts = append(parts, tr)
555 }
556 msg, err := a.messages.Create(context.Background(), assistantMsg.SessionID, message.CreateMessageParams{
557 Role: message.Tool,
558 Parts: parts,
559 Provider: a.providerID,
560 })
561 if err != nil {
562 return assistantMsg, nil, fmt.Errorf("failed to create cancelled tool message: %w", err)
563 }
564
565 return assistantMsg, &msg, err
566}
567
568func (a *agent) finishMessage(ctx context.Context, msg *message.Message, finishReson message.FinishReason) {
569 msg.AddFinish(finishReson)
570 _ = a.messages.Update(ctx, *msg)
571}
572
573func (a *agent) processEvent(ctx context.Context, sessionID string, assistantMsg *message.Message, event provider.ProviderEvent) error {
574 select {
575 case <-ctx.Done():
576 return ctx.Err()
577 default:
578 // Continue processing.
579 }
580
581 switch event.Type {
582 case provider.EventThinkingDelta:
583 assistantMsg.AppendReasoningContent(event.Content)
584 return a.messages.Update(ctx, *assistantMsg)
585 case provider.EventContentDelta:
586 assistantMsg.AppendContent(event.Content)
587 return a.messages.Update(ctx, *assistantMsg)
588 case provider.EventToolUseStart:
589 logging.Info("Tool call started", "toolCall", event.ToolCall)
590 assistantMsg.AddToolCall(*event.ToolCall)
591 return a.messages.Update(ctx, *assistantMsg)
592 case provider.EventToolUseDelta:
593 assistantMsg.AppendToolCallInput(event.ToolCall.ID, event.ToolCall.Input)
594 return a.messages.Update(ctx, *assistantMsg)
595 case provider.EventToolUseStop:
596 logging.Info("Finished tool call", "toolCall", event.ToolCall)
597 assistantMsg.FinishToolCall(event.ToolCall.ID)
598 return a.messages.Update(ctx, *assistantMsg)
599 case provider.EventError:
600 if errors.Is(event.Error, context.Canceled) {
601 logging.InfoPersist(fmt.Sprintf("Event processing canceled for session: %s", sessionID))
602 return context.Canceled
603 }
604 logging.ErrorPersist(event.Error.Error())
605 return event.Error
606 case provider.EventComplete:
607 assistantMsg.SetToolCalls(event.Response.ToolCalls)
608 assistantMsg.AddFinish(event.Response.FinishReason)
609 if err := a.messages.Update(ctx, *assistantMsg); err != nil {
610 return fmt.Errorf("failed to update message: %w", err)
611 }
612 return a.TrackUsage(ctx, sessionID, a.Model(), event.Response.Usage)
613 }
614
615 return nil
616}
617
618func (a *agent) TrackUsage(ctx context.Context, sessionID string, model config.Model, usage provider.TokenUsage) error {
619 sess, err := a.sessions.Get(ctx, sessionID)
620 if err != nil {
621 return fmt.Errorf("failed to get session: %w", err)
622 }
623
624 cost := model.CostPer1MInCached/1e6*float64(usage.CacheCreationTokens) +
625 model.CostPer1MOutCached/1e6*float64(usage.CacheReadTokens) +
626 model.CostPer1MIn/1e6*float64(usage.InputTokens) +
627 model.CostPer1MOut/1e6*float64(usage.OutputTokens)
628
629 sess.Cost += cost
630 sess.CompletionTokens = usage.OutputTokens + usage.CacheReadTokens
631 sess.PromptTokens = usage.InputTokens + usage.CacheCreationTokens
632
633 _, err = a.sessions.Save(ctx, sess)
634 if err != nil {
635 return fmt.Errorf("failed to save session: %w", err)
636 }
637 return nil
638}
639
640func (a *agent) Summarize(ctx context.Context, sessionID string) error {
641 if a.summarizeProvider == nil {
642 return fmt.Errorf("summarize provider not available")
643 }
644
645 // Check if session is busy
646 if a.IsSessionBusy(sessionID) {
647 return ErrSessionBusy
648 }
649
650 // Create a new context with cancellation
651 summarizeCtx, cancel := context.WithCancel(ctx)
652
653 // Store the cancel function in activeRequests to allow cancellation
654 a.activeRequests.Store(sessionID+"-summarize", cancel)
655
656 go func() {
657 defer a.activeRequests.Delete(sessionID + "-summarize")
658 defer cancel()
659 event := AgentEvent{
660 Type: AgentEventTypeSummarize,
661 Progress: "Starting summarization...",
662 }
663
664 a.Publish(pubsub.CreatedEvent, event)
665 // Get all messages from the session
666 msgs, err := a.messages.List(summarizeCtx, sessionID)
667 if err != nil {
668 event = AgentEvent{
669 Type: AgentEventTypeError,
670 Error: fmt.Errorf("failed to list messages: %w", err),
671 Done: true,
672 }
673 a.Publish(pubsub.CreatedEvent, event)
674 return
675 }
676 summarizeCtx = context.WithValue(summarizeCtx, tools.SessionIDContextKey, sessionID)
677
678 if len(msgs) == 0 {
679 event = AgentEvent{
680 Type: AgentEventTypeError,
681 Error: fmt.Errorf("no messages to summarize"),
682 Done: true,
683 }
684 a.Publish(pubsub.CreatedEvent, event)
685 return
686 }
687
688 event = AgentEvent{
689 Type: AgentEventTypeSummarize,
690 Progress: "Analyzing conversation...",
691 }
692 a.Publish(pubsub.CreatedEvent, event)
693
694 // Add a system message to guide the summarization
695 summarizePrompt := "Provide a detailed but concise summary of our conversation above. Focus on information that would be helpful for continuing the conversation, including what we did, what we're doing, which files we're working on, and what we're going to do next."
696
697 // Create a new message with the summarize prompt
698 promptMsg := message.Message{
699 Role: message.User,
700 Parts: []message.ContentPart{message.TextContent{Text: summarizePrompt}},
701 }
702
703 // Append the prompt to the messages
704 msgsWithPrompt := append(msgs, promptMsg)
705
706 event = AgentEvent{
707 Type: AgentEventTypeSummarize,
708 Progress: "Generating summary...",
709 }
710
711 a.Publish(pubsub.CreatedEvent, event)
712
713 // Send the messages to the summarize provider
714 response := a.summarizeProvider.StreamResponse(
715 summarizeCtx,
716 msgsWithPrompt,
717 make([]tools.BaseTool, 0),
718 )
719 var finalResponse *provider.ProviderResponse
720 for r := range response {
721 if r.Error != nil {
722 event = AgentEvent{
723 Type: AgentEventTypeError,
724 Error: fmt.Errorf("failed to summarize: %w", err),
725 Done: true,
726 }
727 a.Publish(pubsub.CreatedEvent, event)
728 return
729 }
730 finalResponse = r.Response
731 }
732
733 summary := strings.TrimSpace(finalResponse.Content)
734 if summary == "" {
735 event = AgentEvent{
736 Type: AgentEventTypeError,
737 Error: fmt.Errorf("empty summary returned"),
738 Done: true,
739 }
740 a.Publish(pubsub.CreatedEvent, event)
741 return
742 }
743 event = AgentEvent{
744 Type: AgentEventTypeSummarize,
745 Progress: "Creating new session...",
746 }
747
748 a.Publish(pubsub.CreatedEvent, event)
749 oldSession, err := a.sessions.Get(summarizeCtx, sessionID)
750 if err != nil {
751 event = AgentEvent{
752 Type: AgentEventTypeError,
753 Error: fmt.Errorf("failed to get session: %w", err),
754 Done: true,
755 }
756
757 a.Publish(pubsub.CreatedEvent, event)
758 return
759 }
760 // Create a message in the new session with the summary
761 msg, err := a.messages.Create(summarizeCtx, oldSession.ID, message.CreateMessageParams{
762 Role: message.Assistant,
763 Parts: []message.ContentPart{
764 message.TextContent{Text: summary},
765 message.Finish{
766 Reason: message.FinishReasonEndTurn,
767 Time: time.Now().Unix(),
768 },
769 },
770 Model: a.summarizeProvider.Model().ID,
771 Provider: a.summarizeProviderID,
772 })
773 if err != nil {
774 event = AgentEvent{
775 Type: AgentEventTypeError,
776 Error: fmt.Errorf("failed to create summary message: %w", err),
777 Done: true,
778 }
779
780 a.Publish(pubsub.CreatedEvent, event)
781 return
782 }
783 oldSession.SummaryMessageID = msg.ID
784 oldSession.CompletionTokens = finalResponse.Usage.OutputTokens
785 oldSession.PromptTokens = 0
786 model := a.summarizeProvider.Model()
787 usage := finalResponse.Usage
788 cost := model.CostPer1MInCached/1e6*float64(usage.CacheCreationTokens) +
789 model.CostPer1MOutCached/1e6*float64(usage.CacheReadTokens) +
790 model.CostPer1MIn/1e6*float64(usage.InputTokens) +
791 model.CostPer1MOut/1e6*float64(usage.OutputTokens)
792 oldSession.Cost += cost
793 _, err = a.sessions.Save(summarizeCtx, oldSession)
794 if err != nil {
795 event = AgentEvent{
796 Type: AgentEventTypeError,
797 Error: fmt.Errorf("failed to save session: %w", err),
798 Done: true,
799 }
800 a.Publish(pubsub.CreatedEvent, event)
801 }
802
803 event = AgentEvent{
804 Type: AgentEventTypeSummarize,
805 SessionID: oldSession.ID,
806 Progress: "Summary complete",
807 Done: true,
808 }
809 a.Publish(pubsub.CreatedEvent, event)
810 // Send final success event with the new session ID
811 }()
812
813 return nil
814}
815
816func (a *agent) CancelAll() {
817 a.activeRequests.Range(func(key, value any) bool {
818 a.Cancel(key.(string)) // key is sessionID
819 return true
820 })
821}
822
823func (a *agent) UpdateModel() error {
824 cfg := config.Get()
825
826 // Get current provider configuration
827 currentProviderCfg := config.GetAgentProvider(a.agentCfg.ID)
828 if currentProviderCfg.ID == "" {
829 return fmt.Errorf("provider for agent %s not found in config", a.agentCfg.Name)
830 }
831
832 // Check if provider has changed
833 if string(currentProviderCfg.ID) != a.providerID {
834 // Provider changed, need to recreate the main provider
835 model := config.GetAgentModel(a.agentCfg.ID)
836 if model.ID == "" {
837 return fmt.Errorf("model not found for agent %s", a.agentCfg.Name)
838 }
839
840 promptID := agentPromptMap[a.agentCfg.ID]
841 if promptID == "" {
842 promptID = prompt.PromptDefault
843 }
844
845 opts := []provider.ProviderClientOption{
846 provider.WithModel(a.agentCfg.Model),
847 provider.WithSystemMessage(prompt.GetPrompt(promptID, currentProviderCfg.ID)),
848 }
849
850 newProvider, err := provider.NewProvider(currentProviderCfg, opts...)
851 if err != nil {
852 return fmt.Errorf("failed to create new provider: %w", err)
853 }
854
855 // Update the provider and provider ID
856 a.provider = newProvider
857 a.providerID = string(currentProviderCfg.ID)
858 }
859
860 // Check if small model provider has changed (affects title and summarize providers)
861 smallModelCfg := cfg.Models.Small
862 var smallModelProviderCfg config.ProviderConfig
863
864 for _, p := range cfg.Providers {
865 if p.ID == smallModelCfg.Provider {
866 smallModelProviderCfg = p
867 break
868 }
869 }
870
871 if smallModelProviderCfg.ID == "" {
872 return fmt.Errorf("provider %s not found in config", smallModelCfg.Provider)
873 }
874
875 // Check if summarize provider has changed
876 if string(smallModelProviderCfg.ID) != a.summarizeProviderID {
877 var smallModel config.Model
878 for _, m := range smallModelProviderCfg.Models {
879 if m.ID == smallModelCfg.ModelID {
880 smallModel = m
881 break
882 }
883 }
884 if smallModel.ID == "" {
885 return fmt.Errorf("model %s not found in provider %s", smallModelCfg.ModelID, smallModelProviderCfg.ID)
886 }
887
888 // Recreate title provider
889 titleOpts := []provider.ProviderClientOption{
890 provider.WithModel(config.SmallModel),
891 provider.WithSystemMessage(prompt.GetPrompt(prompt.PromptTitle, smallModelProviderCfg.ID)),
892 // We want the title to be short, so we limit the max tokens
893 provider.WithMaxTokens(40),
894 }
895 newTitleProvider, err := provider.NewProvider(smallModelProviderCfg, titleOpts...)
896 if err != nil {
897 return fmt.Errorf("failed to create new title provider: %w", err)
898 }
899
900 // Recreate summarize provider
901 summarizeOpts := []provider.ProviderClientOption{
902 provider.WithModel(config.SmallModel),
903 provider.WithSystemMessage(prompt.GetPrompt(prompt.PromptSummarizer, smallModelProviderCfg.ID)),
904 }
905 newSummarizeProvider, err := provider.NewProvider(smallModelProviderCfg, summarizeOpts...)
906 if err != nil {
907 return fmt.Errorf("failed to create new summarize provider: %w", err)
908 }
909
910 // Update the providers and provider ID
911 a.titleProvider = newTitleProvider
912 a.summarizeProvider = newSummarizeProvider
913 a.summarizeProviderID = string(smallModelProviderCfg.ID)
914 }
915
916 return nil
917}