1package agent
2
3import (
4 "context"
5 "errors"
6 "fmt"
7 "slices"
8 "strings"
9 "sync"
10 "time"
11
12 "github.com/charmbracelet/crush/internal/config"
13 "github.com/charmbracelet/crush/internal/history"
14 "github.com/charmbracelet/crush/internal/llm/prompt"
15 "github.com/charmbracelet/crush/internal/llm/provider"
16 "github.com/charmbracelet/crush/internal/llm/tools"
17 "github.com/charmbracelet/crush/internal/logging"
18 "github.com/charmbracelet/crush/internal/lsp"
19 "github.com/charmbracelet/crush/internal/message"
20 "github.com/charmbracelet/crush/internal/permission"
21 "github.com/charmbracelet/crush/internal/pubsub"
22 "github.com/charmbracelet/crush/internal/session"
23)
24
// Sentinel errors reported by the agent service; compare with errors.Is.
var (
	// ErrRequestCancelled is reported when a running request is aborted
	// through cancellation.
	ErrRequestCancelled = errors.New("request cancelled by user")

	// ErrSessionBusy is returned when work is requested for a session that
	// already has a request in flight.
	ErrSessionBusy = errors.New("session is currently processing another request")
)
30
// AgentEventType tells subscribers which kind of AgentEvent they received.
type AgentEventType string

// Known AgentEvent kinds.
const (
	// AgentEventTypeError events carry a failure in AgentEvent.Error.
	AgentEventTypeError = AgentEventType("error")
	// AgentEventTypeResponse events carry the final assistant message of a run.
	AgentEventTypeResponse = AgentEventType("response")
	// AgentEventTypeSummarize events report summarization progress.
	AgentEventTypeSummarize = AgentEventType("summarize")
)
38
39type AgentEvent struct {
40 Type AgentEventType
41 Message message.Message
42 Error error
43
44 // When summarizing
45 SessionID string
46 Progress string
47 Done bool
48}
49
50type Service interface {
51 pubsub.Suscriber[AgentEvent]
52 Model() config.Model
53 EffectiveMaxTokens() int64
54 Run(ctx context.Context, sessionID string, content string, attachments ...message.Attachment) (<-chan AgentEvent, error)
55 Cancel(sessionID string)
56 CancelAll()
57 IsSessionBusy(sessionID string) bool
58 IsBusy() bool
59 Summarize(ctx context.Context, sessionID string) error
60 UpdateModel() error
61}
62
63type agent struct {
64 *pubsub.Broker[AgentEvent]
65 agentCfg config.Agent
66 sessions session.Service
67 messages message.Service
68
69 tools []tools.BaseTool
70 provider provider.Provider
71 providerID string
72
73 titleProvider provider.Provider
74 summarizeProvider provider.Provider
75 summarizeProviderID string
76
77 activeRequests sync.Map
78}
79
80var agentPromptMap = map[config.AgentID]prompt.PromptID{
81 config.AgentCoder: prompt.PromptCoder,
82 config.AgentTask: prompt.PromptTask,
83}
84
85func NewAgent(
86 agentCfg config.Agent,
87 // These services are needed in the tools
88 permissions permission.Service,
89 sessions session.Service,
90 messages message.Service,
91 history history.Service,
92 lspClients map[string]*lsp.Client,
93) (Service, error) {
94 ctx := context.Background()
95 cfg := config.Get()
96 otherTools := GetMcpTools(ctx, permissions)
97 if len(lspClients) > 0 {
98 otherTools = append(otherTools, tools.NewDiagnosticsTool(lspClients))
99 }
100
101 allTools := []tools.BaseTool{
102 tools.NewBashTool(permissions),
103 tools.NewEditTool(lspClients, permissions, history),
104 tools.NewFetchTool(permissions),
105 tools.NewGlobTool(),
106 tools.NewGrepTool(),
107 tools.NewLsTool(),
108 tools.NewSourcegraphTool(),
109 tools.NewViewTool(lspClients),
110 tools.NewWriteTool(lspClients, permissions, history),
111 }
112
113 if agentCfg.ID == config.AgentCoder {
114 taskAgentCfg := config.Get().Agents[config.AgentTask]
115 if taskAgentCfg.ID == "" {
116 return nil, fmt.Errorf("task agent not found in config")
117 }
118 taskAgent, err := NewAgent(taskAgentCfg, permissions, sessions, messages, history, lspClients)
119 if err != nil {
120 return nil, fmt.Errorf("failed to create task agent: %w", err)
121 }
122
123 allTools = append(
124 allTools,
125 NewAgentTool(
126 taskAgent,
127 sessions,
128 messages,
129 ),
130 )
131 }
132
133 allTools = append(allTools, otherTools...)
134 providerCfg := config.GetAgentProvider(agentCfg.ID)
135 if providerCfg.ID == "" {
136 return nil, fmt.Errorf("provider for agent %s not found in config", agentCfg.Name)
137 }
138 model := config.GetAgentModel(agentCfg.ID)
139
140 if model.ID == "" {
141 return nil, fmt.Errorf("model not found for agent %s", agentCfg.Name)
142 }
143
144 promptID := agentPromptMap[agentCfg.ID]
145 if promptID == "" {
146 promptID = prompt.PromptDefault
147 }
148 opts := []provider.ProviderClientOption{
149 provider.WithModel(agentCfg.Model),
150 provider.WithSystemMessage(prompt.GetPrompt(promptID, providerCfg.ID)),
151 }
152 agentProvider, err := provider.NewProvider(providerCfg, opts...)
153 if err != nil {
154 return nil, err
155 }
156
157 smallModelCfg := cfg.Models.Small
158 var smallModel config.Model
159
160 var smallModelProviderCfg config.ProviderConfig
161 if smallModelCfg.Provider == providerCfg.ID {
162 smallModelProviderCfg = providerCfg
163 } else {
164 for _, p := range cfg.Providers {
165 if p.ID == smallModelCfg.Provider {
166 smallModelProviderCfg = p
167 break
168 }
169 }
170 if smallModelProviderCfg.ID == "" {
171 return nil, fmt.Errorf("provider %s not found in config", smallModelCfg.Provider)
172 }
173 }
174 for _, m := range smallModelProviderCfg.Models {
175 if m.ID == smallModelCfg.ModelID {
176 smallModel = m
177 break
178 }
179 }
180 if smallModel.ID == "" {
181 return nil, fmt.Errorf("model %s not found in provider %s", smallModelCfg.ModelID, smallModelProviderCfg.ID)
182 }
183
184 titleOpts := []provider.ProviderClientOption{
185 provider.WithModel(config.SmallModel),
186 provider.WithSystemMessage(prompt.GetPrompt(prompt.PromptTitle, smallModelProviderCfg.ID)),
187 }
188 titleProvider, err := provider.NewProvider(smallModelProviderCfg, titleOpts...)
189 if err != nil {
190 return nil, err
191 }
192 summarizeOpts := []provider.ProviderClientOption{
193 provider.WithModel(config.SmallModel),
194 provider.WithSystemMessage(prompt.GetPrompt(prompt.PromptSummarizer, smallModelProviderCfg.ID)),
195 }
196 summarizeProvider, err := provider.NewProvider(smallModelProviderCfg, summarizeOpts...)
197 if err != nil {
198 return nil, err
199 }
200
201 agentTools := []tools.BaseTool{}
202 if agentCfg.AllowedTools == nil {
203 agentTools = allTools
204 } else {
205 for _, tool := range allTools {
206 if slices.Contains(agentCfg.AllowedTools, tool.Name()) {
207 agentTools = append(agentTools, tool)
208 }
209 }
210 }
211
212 agent := &agent{
213 Broker: pubsub.NewBroker[AgentEvent](),
214 agentCfg: agentCfg,
215 provider: agentProvider,
216 providerID: string(providerCfg.ID),
217 messages: messages,
218 sessions: sessions,
219 tools: agentTools,
220 titleProvider: titleProvider,
221 summarizeProvider: summarizeProvider,
222 summarizeProviderID: string(smallModelProviderCfg.ID),
223 activeRequests: sync.Map{},
224 }
225
226 return agent, nil
227}
228
229func (a *agent) Model() config.Model {
230 return config.GetAgentModel(a.agentCfg.ID)
231}
232
233func (a *agent) EffectiveMaxTokens() int64 {
234 return config.GetAgentEffectiveMaxTokens(a.agentCfg.ID)
235}
236
237func (a *agent) Cancel(sessionID string) {
238 // Cancel regular requests
239 if cancelFunc, exists := a.activeRequests.LoadAndDelete(sessionID); exists {
240 if cancel, ok := cancelFunc.(context.CancelFunc); ok {
241 logging.InfoPersist(fmt.Sprintf("Request cancellation initiated for session: %s", sessionID))
242 cancel()
243 }
244 }
245
246 // Also check for summarize requests
247 if cancelFunc, exists := a.activeRequests.LoadAndDelete(sessionID + "-summarize"); exists {
248 if cancel, ok := cancelFunc.(context.CancelFunc); ok {
249 logging.InfoPersist(fmt.Sprintf("Summarize cancellation initiated for session: %s", sessionID))
250 cancel()
251 }
252 }
253}
254
255func (a *agent) IsBusy() bool {
256 busy := false
257 a.activeRequests.Range(func(key, value any) bool {
258 if cancelFunc, ok := value.(context.CancelFunc); ok {
259 if cancelFunc != nil {
260 busy = true
261 return false
262 }
263 }
264 return true
265 })
266 return busy
267}
268
269func (a *agent) IsSessionBusy(sessionID string) bool {
270 _, busy := a.activeRequests.Load(sessionID)
271 return busy
272}
273
274func (a *agent) generateTitle(ctx context.Context, sessionID string, content string) error {
275 if content == "" {
276 return nil
277 }
278 if a.titleProvider == nil {
279 return nil
280 }
281 session, err := a.sessions.Get(ctx, sessionID)
282 if err != nil {
283 return err
284 }
285 parts := []message.ContentPart{message.TextContent{
286 Text: fmt.Sprintf("Generate a concise title for the following content:\n\n%s", content),
287 }}
288
289 // Use streaming approach like summarization
290 response := a.titleProvider.StreamResponse(
291 ctx,
292 []message.Message{
293 {
294 Role: message.User,
295 Parts: parts,
296 },
297 },
298 make([]tools.BaseTool, 0),
299 )
300
301 var finalResponse *provider.ProviderResponse
302 for r := range response {
303 if r.Error != nil {
304 return r.Error
305 }
306 finalResponse = r.Response
307 }
308
309 if finalResponse == nil {
310 return fmt.Errorf("no response received from title provider")
311 }
312
313 title := strings.TrimSpace(strings.ReplaceAll(finalResponse.Content, "\n", " "))
314 if title == "" {
315 return nil
316 }
317
318 session.Title = title
319 _, err = a.sessions.Save(ctx, session)
320 return err
321}
322
323func (a *agent) err(err error) AgentEvent {
324 return AgentEvent{
325 Type: AgentEventTypeError,
326 Error: err,
327 }
328}
329
330func (a *agent) Run(ctx context.Context, sessionID string, content string, attachments ...message.Attachment) (<-chan AgentEvent, error) {
331 if !a.Model().SupportsImages && attachments != nil {
332 attachments = nil
333 }
334 events := make(chan AgentEvent)
335 if a.IsSessionBusy(sessionID) {
336 return nil, ErrSessionBusy
337 }
338
339 genCtx, cancel := context.WithCancel(ctx)
340
341 a.activeRequests.Store(sessionID, cancel)
342 go func() {
343 logging.Debug("Request started", "sessionID", sessionID)
344 defer logging.RecoverPanic("agent.Run", func() {
345 events <- a.err(fmt.Errorf("panic while running the agent"))
346 })
347 var attachmentParts []message.ContentPart
348 for _, attachment := range attachments {
349 attachmentParts = append(attachmentParts, message.BinaryContent{Path: attachment.FilePath, MIMEType: attachment.MimeType, Data: attachment.Content})
350 }
351 result := a.processGeneration(genCtx, sessionID, content, attachmentParts)
352 if result.Error != nil && !errors.Is(result.Error, ErrRequestCancelled) && !errors.Is(result.Error, context.Canceled) {
353 logging.ErrorPersist(result.Error.Error())
354 }
355 logging.Debug("Request completed", "sessionID", sessionID)
356 a.activeRequests.Delete(sessionID)
357 cancel()
358 a.Publish(pubsub.CreatedEvent, result)
359 events <- result
360 close(events)
361 }()
362 return events, nil
363}
364
365func (a *agent) processGeneration(ctx context.Context, sessionID, content string, attachmentParts []message.ContentPart) AgentEvent {
366 cfg := config.Get()
367 // List existing messages; if none, start title generation asynchronously.
368 msgs, err := a.messages.List(ctx, sessionID)
369 if err != nil {
370 return a.err(fmt.Errorf("failed to list messages: %w", err))
371 }
372 if len(msgs) == 0 {
373 go func() {
374 defer logging.RecoverPanic("agent.Run", func() {
375 logging.ErrorPersist("panic while generating title")
376 })
377 titleErr := a.generateTitle(context.Background(), sessionID, content)
378 if titleErr != nil && !errors.Is(titleErr, context.Canceled) && !errors.Is(titleErr, context.DeadlineExceeded) {
379 logging.ErrorPersist(fmt.Sprintf("failed to generate title: %v", titleErr))
380 }
381 }()
382 }
383 session, err := a.sessions.Get(ctx, sessionID)
384 if err != nil {
385 return a.err(fmt.Errorf("failed to get session: %w", err))
386 }
387 if session.SummaryMessageID != "" {
388 summaryMsgInex := -1
389 for i, msg := range msgs {
390 if msg.ID == session.SummaryMessageID {
391 summaryMsgInex = i
392 break
393 }
394 }
395 if summaryMsgInex != -1 {
396 msgs = msgs[summaryMsgInex:]
397 msgs[0].Role = message.User
398 }
399 }
400
401 userMsg, err := a.createUserMessage(ctx, sessionID, content, attachmentParts)
402 if err != nil {
403 return a.err(fmt.Errorf("failed to create user message: %w", err))
404 }
405 // Append the new user message to the conversation history.
406 msgHistory := append(msgs, userMsg)
407
408 for {
409 // Check for cancellation before each iteration
410 select {
411 case <-ctx.Done():
412 return a.err(ctx.Err())
413 default:
414 // Continue processing
415 }
416 agentMessage, toolResults, err := a.streamAndHandleEvents(ctx, sessionID, msgHistory)
417 if err != nil {
418 if errors.Is(err, context.Canceled) {
419 agentMessage.AddFinish(message.FinishReasonCanceled)
420 a.messages.Update(context.Background(), agentMessage)
421 return a.err(ErrRequestCancelled)
422 }
423 return a.err(fmt.Errorf("failed to process events: %w", err))
424 }
425 if cfg.Options.Debug {
426 seqId := (len(msgHistory) + 1) / 2
427 toolResultFilepath := logging.WriteToolResultsJson(sessionID, seqId, toolResults)
428 logging.Info("Result", "message", agentMessage.FinishReason(), "toolResults", "{}", "filepath", toolResultFilepath)
429 } else {
430 logging.Info("Result", "message", agentMessage.FinishReason(), "toolResults", toolResults)
431 }
432 if (agentMessage.FinishReason() == message.FinishReasonToolUse) && toolResults != nil {
433 // We are not done, we need to respond with the tool response
434 msgHistory = append(msgHistory, agentMessage, *toolResults)
435 continue
436 }
437 return AgentEvent{
438 Type: AgentEventTypeResponse,
439 Message: agentMessage,
440 Done: true,
441 }
442 }
443}
444
445func (a *agent) createUserMessage(ctx context.Context, sessionID, content string, attachmentParts []message.ContentPart) (message.Message, error) {
446 parts := []message.ContentPart{message.TextContent{Text: content}}
447 parts = append(parts, attachmentParts...)
448 return a.messages.Create(ctx, sessionID, message.CreateMessageParams{
449 Role: message.User,
450 Parts: parts,
451 })
452}
453
454func (a *agent) streamAndHandleEvents(ctx context.Context, sessionID string, msgHistory []message.Message) (message.Message, *message.Message, error) {
455 ctx = context.WithValue(ctx, tools.SessionIDContextKey, sessionID)
456 eventChan := a.provider.StreamResponse(ctx, msgHistory, a.tools)
457
458 assistantMsg, err := a.messages.Create(ctx, sessionID, message.CreateMessageParams{
459 Role: message.Assistant,
460 Parts: []message.ContentPart{},
461 Model: a.Model().ID,
462 Provider: a.providerID,
463 })
464 if err != nil {
465 return assistantMsg, nil, fmt.Errorf("failed to create assistant message: %w", err)
466 }
467
468 // Add the session and message ID into the context if needed by tools.
469 ctx = context.WithValue(ctx, tools.MessageIDContextKey, assistantMsg.ID)
470
471 // Process each event in the stream.
472 for event := range eventChan {
473 if processErr := a.processEvent(ctx, sessionID, &assistantMsg, event); processErr != nil {
474 a.finishMessage(ctx, &assistantMsg, message.FinishReasonCanceled)
475 return assistantMsg, nil, processErr
476 }
477 if ctx.Err() != nil {
478 a.finishMessage(context.Background(), &assistantMsg, message.FinishReasonCanceled)
479 return assistantMsg, nil, ctx.Err()
480 }
481 }
482
483 toolResults := make([]message.ToolResult, len(assistantMsg.ToolCalls()))
484 toolCalls := assistantMsg.ToolCalls()
485 for i, toolCall := range toolCalls {
486 select {
487 case <-ctx.Done():
488 a.finishMessage(context.Background(), &assistantMsg, message.FinishReasonCanceled)
489 // Make all future tool calls cancelled
490 for j := i; j < len(toolCalls); j++ {
491 toolResults[j] = message.ToolResult{
492 ToolCallID: toolCalls[j].ID,
493 Content: "Tool execution canceled by user",
494 IsError: true,
495 }
496 }
497 goto out
498 default:
499 // Continue processing
500 var tool tools.BaseTool
501 for _, availableTool := range a.tools {
502 if availableTool.Info().Name == toolCall.Name {
503 tool = availableTool
504 break
505 }
506 // Monkey patch for Copilot Sonnet-4 tool repetition obfuscation
507 // if strings.HasPrefix(toolCall.Name, availableTool.Info().Name) &&
508 // strings.HasPrefix(toolCall.Name, availableTool.Info().Name+availableTool.Info().Name) {
509 // tool = availableTool
510 // break
511 // }
512 }
513
514 // Tool not found
515 if tool == nil {
516 toolResults[i] = message.ToolResult{
517 ToolCallID: toolCall.ID,
518 Content: fmt.Sprintf("Tool not found: %s", toolCall.Name),
519 IsError: true,
520 }
521 continue
522 }
523 toolResult, toolErr := tool.Run(ctx, tools.ToolCall{
524 ID: toolCall.ID,
525 Name: toolCall.Name,
526 Input: toolCall.Input,
527 })
528 if toolErr != nil {
529 if errors.Is(toolErr, permission.ErrorPermissionDenied) {
530 toolResults[i] = message.ToolResult{
531 ToolCallID: toolCall.ID,
532 Content: "Permission denied",
533 IsError: true,
534 }
535 for j := i + 1; j < len(toolCalls); j++ {
536 toolResults[j] = message.ToolResult{
537 ToolCallID: toolCalls[j].ID,
538 Content: "Tool execution canceled by user",
539 IsError: true,
540 }
541 }
542 a.finishMessage(ctx, &assistantMsg, message.FinishReasonPermissionDenied)
543 break
544 }
545 }
546 toolResults[i] = message.ToolResult{
547 ToolCallID: toolCall.ID,
548 Content: toolResult.Content,
549 Metadata: toolResult.Metadata,
550 IsError: toolResult.IsError,
551 }
552 }
553 }
554out:
555 if len(toolResults) == 0 {
556 return assistantMsg, nil, nil
557 }
558 parts := make([]message.ContentPart, 0)
559 for _, tr := range toolResults {
560 parts = append(parts, tr)
561 }
562 msg, err := a.messages.Create(context.Background(), assistantMsg.SessionID, message.CreateMessageParams{
563 Role: message.Tool,
564 Parts: parts,
565 Provider: a.providerID,
566 })
567 if err != nil {
568 return assistantMsg, nil, fmt.Errorf("failed to create cancelled tool message: %w", err)
569 }
570
571 return assistantMsg, &msg, err
572}
573
574func (a *agent) finishMessage(ctx context.Context, msg *message.Message, finishReson message.FinishReason) {
575 msg.AddFinish(finishReson)
576 _ = a.messages.Update(ctx, *msg)
577}
578
579func (a *agent) processEvent(ctx context.Context, sessionID string, assistantMsg *message.Message, event provider.ProviderEvent) error {
580 select {
581 case <-ctx.Done():
582 return ctx.Err()
583 default:
584 // Continue processing.
585 }
586
587 switch event.Type {
588 case provider.EventThinkingDelta:
589 assistantMsg.AppendReasoningContent(event.Content)
590 return a.messages.Update(ctx, *assistantMsg)
591 case provider.EventContentDelta:
592 assistantMsg.AppendContent(event.Content)
593 return a.messages.Update(ctx, *assistantMsg)
594 case provider.EventToolUseStart:
595 logging.Info("Tool call started", "toolCall", event.ToolCall)
596 assistantMsg.AddToolCall(*event.ToolCall)
597 return a.messages.Update(ctx, *assistantMsg)
598 case provider.EventToolUseDelta:
599 assistantMsg.AppendToolCallInput(event.ToolCall.ID, event.ToolCall.Input)
600 return a.messages.Update(ctx, *assistantMsg)
601 case provider.EventToolUseStop:
602 logging.Info("Finished tool call", "toolCall", event.ToolCall)
603 assistantMsg.FinishToolCall(event.ToolCall.ID)
604 return a.messages.Update(ctx, *assistantMsg)
605 case provider.EventError:
606 if errors.Is(event.Error, context.Canceled) {
607 logging.InfoPersist(fmt.Sprintf("Event processing canceled for session: %s", sessionID))
608 return context.Canceled
609 }
610 logging.ErrorPersist(event.Error.Error())
611 return event.Error
612 case provider.EventComplete:
613 assistantMsg.SetToolCalls(event.Response.ToolCalls)
614 assistantMsg.AddFinish(event.Response.FinishReason)
615 if err := a.messages.Update(ctx, *assistantMsg); err != nil {
616 return fmt.Errorf("failed to update message: %w", err)
617 }
618 return a.TrackUsage(ctx, sessionID, a.Model(), event.Response.Usage)
619 }
620
621 return nil
622}
623
624func (a *agent) TrackUsage(ctx context.Context, sessionID string, model config.Model, usage provider.TokenUsage) error {
625 sess, err := a.sessions.Get(ctx, sessionID)
626 if err != nil {
627 return fmt.Errorf("failed to get session: %w", err)
628 }
629
630 cost := model.CostPer1MInCached/1e6*float64(usage.CacheCreationTokens) +
631 model.CostPer1MOutCached/1e6*float64(usage.CacheReadTokens) +
632 model.CostPer1MIn/1e6*float64(usage.InputTokens) +
633 model.CostPer1MOut/1e6*float64(usage.OutputTokens)
634
635 sess.Cost += cost
636 sess.CompletionTokens = usage.OutputTokens + usage.CacheReadTokens
637 sess.PromptTokens = usage.InputTokens + usage.CacheCreationTokens
638
639 _, err = a.sessions.Save(ctx, sess)
640 if err != nil {
641 return fmt.Errorf("failed to save session: %w", err)
642 }
643 return nil
644}
645
646func (a *agent) Summarize(ctx context.Context, sessionID string) error {
647 if a.summarizeProvider == nil {
648 return fmt.Errorf("summarize provider not available")
649 }
650
651 // Check if session is busy
652 if a.IsSessionBusy(sessionID) {
653 return ErrSessionBusy
654 }
655
656 // Create a new context with cancellation
657 summarizeCtx, cancel := context.WithCancel(ctx)
658
659 // Store the cancel function in activeRequests to allow cancellation
660 a.activeRequests.Store(sessionID+"-summarize", cancel)
661
662 go func() {
663 defer a.activeRequests.Delete(sessionID + "-summarize")
664 defer cancel()
665 event := AgentEvent{
666 Type: AgentEventTypeSummarize,
667 Progress: "Starting summarization...",
668 }
669
670 a.Publish(pubsub.CreatedEvent, event)
671 // Get all messages from the session
672 msgs, err := a.messages.List(summarizeCtx, sessionID)
673 if err != nil {
674 event = AgentEvent{
675 Type: AgentEventTypeError,
676 Error: fmt.Errorf("failed to list messages: %w", err),
677 Done: true,
678 }
679 a.Publish(pubsub.CreatedEvent, event)
680 return
681 }
682 summarizeCtx = context.WithValue(summarizeCtx, tools.SessionIDContextKey, sessionID)
683
684 if len(msgs) == 0 {
685 event = AgentEvent{
686 Type: AgentEventTypeError,
687 Error: fmt.Errorf("no messages to summarize"),
688 Done: true,
689 }
690 a.Publish(pubsub.CreatedEvent, event)
691 return
692 }
693
694 event = AgentEvent{
695 Type: AgentEventTypeSummarize,
696 Progress: "Analyzing conversation...",
697 }
698 a.Publish(pubsub.CreatedEvent, event)
699
700 // Add a system message to guide the summarization
701 summarizePrompt := "Provide a detailed but concise summary of our conversation above. Focus on information that would be helpful for continuing the conversation, including what we did, what we're doing, which files we're working on, and what we're going to do next."
702
703 // Create a new message with the summarize prompt
704 promptMsg := message.Message{
705 Role: message.User,
706 Parts: []message.ContentPart{message.TextContent{Text: summarizePrompt}},
707 }
708
709 // Append the prompt to the messages
710 msgsWithPrompt := append(msgs, promptMsg)
711
712 event = AgentEvent{
713 Type: AgentEventTypeSummarize,
714 Progress: "Generating summary...",
715 }
716
717 a.Publish(pubsub.CreatedEvent, event)
718
719 // Send the messages to the summarize provider
720 response := a.summarizeProvider.StreamResponse(
721 summarizeCtx,
722 msgsWithPrompt,
723 make([]tools.BaseTool, 0),
724 )
725 var finalResponse *provider.ProviderResponse
726 for r := range response {
727 if r.Error != nil {
728 event = AgentEvent{
729 Type: AgentEventTypeError,
730 Error: fmt.Errorf("failed to summarize: %w", err),
731 Done: true,
732 }
733 a.Publish(pubsub.CreatedEvent, event)
734 return
735 }
736 finalResponse = r.Response
737 }
738
739 summary := strings.TrimSpace(finalResponse.Content)
740 if summary == "" {
741 event = AgentEvent{
742 Type: AgentEventTypeError,
743 Error: fmt.Errorf("empty summary returned"),
744 Done: true,
745 }
746 a.Publish(pubsub.CreatedEvent, event)
747 return
748 }
749 event = AgentEvent{
750 Type: AgentEventTypeSummarize,
751 Progress: "Creating new session...",
752 }
753
754 a.Publish(pubsub.CreatedEvent, event)
755 oldSession, err := a.sessions.Get(summarizeCtx, sessionID)
756 if err != nil {
757 event = AgentEvent{
758 Type: AgentEventTypeError,
759 Error: fmt.Errorf("failed to get session: %w", err),
760 Done: true,
761 }
762
763 a.Publish(pubsub.CreatedEvent, event)
764 return
765 }
766 // Create a message in the new session with the summary
767 msg, err := a.messages.Create(summarizeCtx, oldSession.ID, message.CreateMessageParams{
768 Role: message.Assistant,
769 Parts: []message.ContentPart{
770 message.TextContent{Text: summary},
771 message.Finish{
772 Reason: message.FinishReasonEndTurn,
773 Time: time.Now().Unix(),
774 },
775 },
776 Model: a.summarizeProvider.Model().ID,
777 Provider: a.summarizeProviderID,
778 })
779 if err != nil {
780 event = AgentEvent{
781 Type: AgentEventTypeError,
782 Error: fmt.Errorf("failed to create summary message: %w", err),
783 Done: true,
784 }
785
786 a.Publish(pubsub.CreatedEvent, event)
787 return
788 }
789 oldSession.SummaryMessageID = msg.ID
790 oldSession.CompletionTokens = finalResponse.Usage.OutputTokens
791 oldSession.PromptTokens = 0
792 model := a.summarizeProvider.Model()
793 usage := finalResponse.Usage
794 cost := model.CostPer1MInCached/1e6*float64(usage.CacheCreationTokens) +
795 model.CostPer1MOutCached/1e6*float64(usage.CacheReadTokens) +
796 model.CostPer1MIn/1e6*float64(usage.InputTokens) +
797 model.CostPer1MOut/1e6*float64(usage.OutputTokens)
798 oldSession.Cost += cost
799 _, err = a.sessions.Save(summarizeCtx, oldSession)
800 if err != nil {
801 event = AgentEvent{
802 Type: AgentEventTypeError,
803 Error: fmt.Errorf("failed to save session: %w", err),
804 Done: true,
805 }
806 a.Publish(pubsub.CreatedEvent, event)
807 }
808
809 event = AgentEvent{
810 Type: AgentEventTypeSummarize,
811 SessionID: oldSession.ID,
812 Progress: "Summary complete",
813 Done: true,
814 }
815 a.Publish(pubsub.CreatedEvent, event)
816 // Send final success event with the new session ID
817 }()
818
819 return nil
820}
821
822func (a *agent) CancelAll() {
823 a.activeRequests.Range(func(key, value any) bool {
824 a.Cancel(key.(string)) // key is sessionID
825 return true
826 })
827}
828
829func (a *agent) UpdateModel() error {
830 cfg := config.Get()
831
832 // Get current provider configuration
833 currentProviderCfg := config.GetAgentProvider(a.agentCfg.ID)
834 if currentProviderCfg.ID == "" {
835 return fmt.Errorf("provider for agent %s not found in config", a.agentCfg.Name)
836 }
837
838 // Check if provider has changed
839 if string(currentProviderCfg.ID) != a.providerID {
840 // Provider changed, need to recreate the main provider
841 model := config.GetAgentModel(a.agentCfg.ID)
842 if model.ID == "" {
843 return fmt.Errorf("model not found for agent %s", a.agentCfg.Name)
844 }
845
846 promptID := agentPromptMap[a.agentCfg.ID]
847 if promptID == "" {
848 promptID = prompt.PromptDefault
849 }
850
851 opts := []provider.ProviderClientOption{
852 provider.WithModel(a.agentCfg.Model),
853 provider.WithSystemMessage(prompt.GetPrompt(promptID, currentProviderCfg.ID)),
854 }
855
856 newProvider, err := provider.NewProvider(currentProviderCfg, opts...)
857 if err != nil {
858 return fmt.Errorf("failed to create new provider: %w", err)
859 }
860
861 // Update the provider and provider ID
862 a.provider = newProvider
863 a.providerID = string(currentProviderCfg.ID)
864 }
865
866 // Check if small model provider has changed (affects title and summarize providers)
867 smallModelCfg := cfg.Models.Small
868 var smallModelProviderCfg config.ProviderConfig
869
870 for _, p := range cfg.Providers {
871 if p.ID == smallModelCfg.Provider {
872 smallModelProviderCfg = p
873 break
874 }
875 }
876
877 if smallModelProviderCfg.ID == "" {
878 return fmt.Errorf("provider %s not found in config", smallModelCfg.Provider)
879 }
880
881 // Check if summarize provider has changed
882 if string(smallModelProviderCfg.ID) != a.summarizeProviderID {
883 var smallModel config.Model
884 for _, m := range smallModelProviderCfg.Models {
885 if m.ID == smallModelCfg.ModelID {
886 smallModel = m
887 break
888 }
889 }
890 if smallModel.ID == "" {
891 return fmt.Errorf("model %s not found in provider %s", smallModelCfg.ModelID, smallModelProviderCfg.ID)
892 }
893
894 // Recreate title provider
895 titleOpts := []provider.ProviderClientOption{
896 provider.WithModel(config.SmallModel),
897 provider.WithSystemMessage(prompt.GetPrompt(prompt.PromptTitle, smallModelProviderCfg.ID)),
898 // We want the title to be short, so we limit the max tokens
899 provider.WithMaxTokens(40),
900 }
901 newTitleProvider, err := provider.NewProvider(smallModelProviderCfg, titleOpts...)
902 if err != nil {
903 return fmt.Errorf("failed to create new title provider: %w", err)
904 }
905
906 // Recreate summarize provider
907 summarizeOpts := []provider.ProviderClientOption{
908 provider.WithModel(config.SmallModel),
909 provider.WithSystemMessage(prompt.GetPrompt(prompt.PromptSummarizer, smallModelProviderCfg.ID)),
910 }
911 newSummarizeProvider, err := provider.NewProvider(smallModelProviderCfg, summarizeOpts...)
912 if err != nil {
913 return fmt.Errorf("failed to create new summarize provider: %w", err)
914 }
915
916 // Update the providers and provider ID
917 a.titleProvider = newTitleProvider
918 a.summarizeProvider = newSummarizeProvider
919 a.summarizeProviderID = string(smallModelProviderCfg.ID)
920 }
921
922 return nil
923}