1package agent
2
3import (
4 "context"
5 "errors"
6 "fmt"
7 "slices"
8 "strings"
9 "sync"
10 "time"
11
12 "github.com/charmbracelet/crush/internal/config"
13 fur "github.com/charmbracelet/crush/internal/fur/provider"
14 "github.com/charmbracelet/crush/internal/history"
15 "github.com/charmbracelet/crush/internal/llm/prompt"
16 "github.com/charmbracelet/crush/internal/llm/provider"
17 "github.com/charmbracelet/crush/internal/llm/tools"
18 "github.com/charmbracelet/crush/internal/logging"
19 "github.com/charmbracelet/crush/internal/lsp"
20 "github.com/charmbracelet/crush/internal/message"
21 "github.com/charmbracelet/crush/internal/permission"
22 "github.com/charmbracelet/crush/internal/pubsub"
23 "github.com/charmbracelet/crush/internal/session"
24)
25
// Common errors
var (
	// ErrRequestCancelled is reported (as the Error of the final AgentEvent)
	// when a generation step of Run fails with context.Canceled.
	ErrRequestCancelled = errors.New("request cancelled by user")
	// ErrSessionBusy is returned by Run and Summarize when the session already
	// has a request in progress.
	ErrSessionBusy = errors.New("session is currently processing another request")
)
31
// AgentEventType identifies the kind of an AgentEvent.
type AgentEventType string

const (
	// AgentEventTypeError marks an event whose Error field is set.
	AgentEventTypeError AgentEventType = "error"
	// AgentEventTypeResponse marks the final event of a successful Run.
	AgentEventTypeResponse AgentEventType = "response"
	// AgentEventTypeSummarize marks a Summarize progress or completion event.
	AgentEventTypeSummarize AgentEventType = "summarize"
)
39
// AgentEvent is sent on the channel returned by Run and published to the
// agent's subscribers by Run and Summarize.
type AgentEvent struct {
	// Type tells which of the fields below are meaningful.
	Type AgentEventType
	// Message is the final assistant message of a response event.
	Message message.Message
	// Error is set on error events.
	Error error

	// SessionID and Progress are only set by Summarize: Progress carries a
	// human-readable status line and SessionID is set on the completion event.
	SessionID string
	Progress  string
	// Done marks the last event of an operation; it is set on Run's final
	// response and on Summarize's terminal (error or completion) events.
	Done bool
}
50
// Service is the interface the rest of the application uses to drive an
// agent. Subscribers receive the AgentEvents the agent publishes.
type Service interface {
	pubsub.Suscriber[AgentEvent]
	// Model returns the model currently configured for the agent.
	Model() fur.Model
	// Run starts processing content (and optional attachments) as a new user
	// message in sessionID and returns a channel that delivers the final event.
	Run(ctx context.Context, sessionID string, content string, attachments ...message.Attachment) (<-chan AgentEvent, error)
	// Cancel stops the in-flight request and summarization for sessionID.
	Cancel(sessionID string)
	// CancelAll stops every in-flight request and summarization.
	CancelAll()
	// IsSessionBusy reports whether a Run request is active for sessionID.
	IsSessionBusy(sessionID string) bool
	// IsBusy reports whether any request is active.
	IsBusy() bool
	// Summarize starts summarizing sessionID in the background; progress and
	// results are reported through published events.
	Summarize(ctx context.Context, sessionID string) error
	// UpdateModel recreates providers whose configured provider changed.
	UpdateModel() error
}
62
// agent is the provider-backed implementation of Service.
type agent struct {
	*pubsub.Broker[AgentEvent]
	agentCfg config.Agent
	sessions session.Service
	messages message.Service

	// tools are offered to the main provider on every turn.
	tools []tools.BaseTool
	// provider generates the agent's responses; providerID is the ID of the
	// provider config it was built from.
	provider   provider.Provider
	providerID string

	// titleProvider and summarizeProvider are built on the small model's
	// provider config; summarizeProviderID is that config's ID.
	titleProvider       provider.Provider
	summarizeProvider   provider.Provider
	summarizeProviderID string

	// activeRequests maps a session ID (for Run) or "<sessionID>-summarize"
	// (for Summarize) to the context.CancelFunc of the in-flight work.
	activeRequests sync.Map
}
79
// agentPromptMap selects the system prompt for an agent by its config ID.
// Agents not listed here fall back to prompt.PromptDefault.
var agentPromptMap = map[string]prompt.PromptID{
	"coder": prompt.PromptCoder,
	"task":  prompt.PromptTask,
}
84
85func NewAgent(
86 agentCfg config.Agent,
87 // These services are needed in the tools
88 permissions permission.Service,
89 sessions session.Service,
90 messages message.Service,
91 history history.Service,
92 lspClients map[string]*lsp.Client,
93) (Service, error) {
94 ctx := context.Background()
95 cfg := config.Get()
96 otherTools := GetMcpTools(ctx, permissions)
97 if len(lspClients) > 0 {
98 otherTools = append(otherTools, tools.NewDiagnosticsTool(lspClients))
99 }
100
101 allTools := []tools.BaseTool{
102 tools.NewBashTool(permissions),
103 tools.NewEditTool(lspClients, permissions, history),
104 tools.NewFetchTool(permissions),
105 tools.NewGlobTool(),
106 tools.NewGrepTool(),
107 tools.NewLsTool(),
108 tools.NewSourcegraphTool(),
109 tools.NewViewTool(lspClients),
110 tools.NewWriteTool(lspClients, permissions, history),
111 }
112
113 if agentCfg.ID == "coder" {
114 taskAgentCfg := config.Get().Agents["task"]
115 if taskAgentCfg.ID == "" {
116 return nil, fmt.Errorf("task agent not found in config")
117 }
118 taskAgent, err := NewAgent(taskAgentCfg, permissions, sessions, messages, history, lspClients)
119 if err != nil {
120 return nil, fmt.Errorf("failed to create task agent: %w", err)
121 }
122
123 allTools = append(
124 allTools,
125 NewAgentTool(
126 taskAgent,
127 sessions,
128 messages,
129 ),
130 )
131 }
132
133 allTools = append(allTools, otherTools...)
134 providerCfg := config.Get().GetProviderForModel(agentCfg.Model)
135 if providerCfg == nil {
136 return nil, fmt.Errorf("provider for agent %s not found in config", agentCfg.Name)
137 }
138 model := config.Get().GetModelByType(agentCfg.Model)
139
140 if model == nil {
141 return nil, fmt.Errorf("model not found for agent %s", agentCfg.Name)
142 }
143
144 promptID := agentPromptMap[agentCfg.ID]
145 if promptID == "" {
146 promptID = prompt.PromptDefault
147 }
148 opts := []provider.ProviderClientOption{
149 provider.WithModel(agentCfg.Model),
150 provider.WithSystemMessage(prompt.GetPrompt(promptID, providerCfg.ID)),
151 }
152 agentProvider, err := provider.NewProvider(*providerCfg, opts...)
153 if err != nil {
154 return nil, err
155 }
156
157 smallModelCfg := cfg.Models[config.SelectedModelTypeSmall]
158 var smallModelProviderCfg *config.ProviderConfig
159 if smallModelCfg.Provider == providerCfg.ID {
160 smallModelProviderCfg = providerCfg
161 } else {
162 smallModelProviderCfg = cfg.GetProviderForModel(config.SelectedModelTypeSmall)
163
164 if smallModelProviderCfg.ID == "" {
165 return nil, fmt.Errorf("provider %s not found in config", smallModelCfg.Provider)
166 }
167 }
168 smallModel := cfg.GetModelByType(config.SelectedModelTypeSmall)
169 if smallModel.ID == "" {
170 return nil, fmt.Errorf("model %s not found in provider %s", smallModelCfg.Model, smallModelProviderCfg.ID)
171 }
172
173 titleOpts := []provider.ProviderClientOption{
174 provider.WithModel(config.SelectedModelTypeSmall),
175 provider.WithSystemMessage(prompt.GetPrompt(prompt.PromptTitle, smallModelProviderCfg.ID)),
176 }
177 titleProvider, err := provider.NewProvider(*smallModelProviderCfg, titleOpts...)
178 if err != nil {
179 return nil, err
180 }
181 summarizeOpts := []provider.ProviderClientOption{
182 provider.WithModel(config.SelectedModelTypeSmall),
183 provider.WithSystemMessage(prompt.GetPrompt(prompt.PromptSummarizer, smallModelProviderCfg.ID)),
184 }
185 summarizeProvider, err := provider.NewProvider(*smallModelProviderCfg, summarizeOpts...)
186 if err != nil {
187 return nil, err
188 }
189
190 agentTools := []tools.BaseTool{}
191 if agentCfg.AllowedTools == nil {
192 agentTools = allTools
193 } else {
194 for _, tool := range allTools {
195 if slices.Contains(agentCfg.AllowedTools, tool.Name()) {
196 agentTools = append(agentTools, tool)
197 }
198 }
199 }
200
201 agent := &agent{
202 Broker: pubsub.NewBroker[AgentEvent](),
203 agentCfg: agentCfg,
204 provider: agentProvider,
205 providerID: string(providerCfg.ID),
206 messages: messages,
207 sessions: sessions,
208 tools: agentTools,
209 titleProvider: titleProvider,
210 summarizeProvider: summarizeProvider,
211 summarizeProviderID: string(smallModelProviderCfg.ID),
212 activeRequests: sync.Map{},
213 }
214
215 return agent, nil
216}
217
218func (a *agent) Model() fur.Model {
219 return *config.Get().GetModelByType(a.agentCfg.Model)
220}
221
222func (a *agent) Cancel(sessionID string) {
223 // Cancel regular requests
224 if cancelFunc, exists := a.activeRequests.LoadAndDelete(sessionID); exists {
225 if cancel, ok := cancelFunc.(context.CancelFunc); ok {
226 logging.InfoPersist(fmt.Sprintf("Request cancellation initiated for session: %s", sessionID))
227 cancel()
228 }
229 }
230
231 // Also check for summarize requests
232 if cancelFunc, exists := a.activeRequests.LoadAndDelete(sessionID + "-summarize"); exists {
233 if cancel, ok := cancelFunc.(context.CancelFunc); ok {
234 logging.InfoPersist(fmt.Sprintf("Summarize cancellation initiated for session: %s", sessionID))
235 cancel()
236 }
237 }
238}
239
240func (a *agent) IsBusy() bool {
241 busy := false
242 a.activeRequests.Range(func(key, value any) bool {
243 if cancelFunc, ok := value.(context.CancelFunc); ok {
244 if cancelFunc != nil {
245 busy = true
246 return false
247 }
248 }
249 return true
250 })
251 return busy
252}
253
254func (a *agent) IsSessionBusy(sessionID string) bool {
255 _, busy := a.activeRequests.Load(sessionID)
256 return busy
257}
258
259func (a *agent) generateTitle(ctx context.Context, sessionID string, content string) error {
260 if content == "" {
261 return nil
262 }
263 if a.titleProvider == nil {
264 return nil
265 }
266 session, err := a.sessions.Get(ctx, sessionID)
267 if err != nil {
268 return err
269 }
270 parts := []message.ContentPart{message.TextContent{
271 Text: fmt.Sprintf("Generate a concise title for the following content:\n\n%s", content),
272 }}
273
274 // Use streaming approach like summarization
275 response := a.titleProvider.StreamResponse(
276 ctx,
277 []message.Message{
278 {
279 Role: message.User,
280 Parts: parts,
281 },
282 },
283 make([]tools.BaseTool, 0),
284 )
285
286 var finalResponse *provider.ProviderResponse
287 for r := range response {
288 if r.Error != nil {
289 return r.Error
290 }
291 finalResponse = r.Response
292 }
293
294 if finalResponse == nil {
295 return fmt.Errorf("no response received from title provider")
296 }
297
298 title := strings.TrimSpace(strings.ReplaceAll(finalResponse.Content, "\n", " "))
299 if title == "" {
300 return nil
301 }
302
303 session.Title = title
304 _, err = a.sessions.Save(ctx, session)
305 return err
306}
307
308func (a *agent) err(err error) AgentEvent {
309 return AgentEvent{
310 Type: AgentEventTypeError,
311 Error: err,
312 }
313}
314
315func (a *agent) Run(ctx context.Context, sessionID string, content string, attachments ...message.Attachment) (<-chan AgentEvent, error) {
316 if !a.Model().SupportsImages && attachments != nil {
317 attachments = nil
318 }
319 events := make(chan AgentEvent)
320 if a.IsSessionBusy(sessionID) {
321 return nil, ErrSessionBusy
322 }
323
324 genCtx, cancel := context.WithCancel(ctx)
325
326 a.activeRequests.Store(sessionID, cancel)
327 go func() {
328 logging.Debug("Request started", "sessionID", sessionID)
329 defer logging.RecoverPanic("agent.Run", func() {
330 events <- a.err(fmt.Errorf("panic while running the agent"))
331 })
332 var attachmentParts []message.ContentPart
333 for _, attachment := range attachments {
334 attachmentParts = append(attachmentParts, message.BinaryContent{Path: attachment.FilePath, MIMEType: attachment.MimeType, Data: attachment.Content})
335 }
336 result := a.processGeneration(genCtx, sessionID, content, attachmentParts)
337 if result.Error != nil && !errors.Is(result.Error, ErrRequestCancelled) && !errors.Is(result.Error, context.Canceled) {
338 logging.ErrorPersist(result.Error.Error())
339 }
340 logging.Debug("Request completed", "sessionID", sessionID)
341 a.activeRequests.Delete(sessionID)
342 cancel()
343 a.Publish(pubsub.CreatedEvent, result)
344 events <- result
345 close(events)
346 }()
347 return events, nil
348}
349
350func (a *agent) processGeneration(ctx context.Context, sessionID, content string, attachmentParts []message.ContentPart) AgentEvent {
351 cfg := config.Get()
352 // List existing messages; if none, start title generation asynchronously.
353 msgs, err := a.messages.List(ctx, sessionID)
354 if err != nil {
355 return a.err(fmt.Errorf("failed to list messages: %w", err))
356 }
357 if len(msgs) == 0 {
358 go func() {
359 defer logging.RecoverPanic("agent.Run", func() {
360 logging.ErrorPersist("panic while generating title")
361 })
362 titleErr := a.generateTitle(context.Background(), sessionID, content)
363 if titleErr != nil && !errors.Is(titleErr, context.Canceled) && !errors.Is(titleErr, context.DeadlineExceeded) {
364 logging.ErrorPersist(fmt.Sprintf("failed to generate title: %v", titleErr))
365 }
366 }()
367 }
368 session, err := a.sessions.Get(ctx, sessionID)
369 if err != nil {
370 return a.err(fmt.Errorf("failed to get session: %w", err))
371 }
372 if session.SummaryMessageID != "" {
373 summaryMsgInex := -1
374 for i, msg := range msgs {
375 if msg.ID == session.SummaryMessageID {
376 summaryMsgInex = i
377 break
378 }
379 }
380 if summaryMsgInex != -1 {
381 msgs = msgs[summaryMsgInex:]
382 msgs[0].Role = message.User
383 }
384 }
385
386 userMsg, err := a.createUserMessage(ctx, sessionID, content, attachmentParts)
387 if err != nil {
388 return a.err(fmt.Errorf("failed to create user message: %w", err))
389 }
390 // Append the new user message to the conversation history.
391 msgHistory := append(msgs, userMsg)
392
393 for {
394 // Check for cancellation before each iteration
395 select {
396 case <-ctx.Done():
397 return a.err(ctx.Err())
398 default:
399 // Continue processing
400 }
401 agentMessage, toolResults, err := a.streamAndHandleEvents(ctx, sessionID, msgHistory)
402 if err != nil {
403 if errors.Is(err, context.Canceled) {
404 agentMessage.AddFinish(message.FinishReasonCanceled)
405 a.messages.Update(context.Background(), agentMessage)
406 return a.err(ErrRequestCancelled)
407 }
408 return a.err(fmt.Errorf("failed to process events: %w", err))
409 }
410 if cfg.Options.Debug {
411 seqId := (len(msgHistory) + 1) / 2
412 toolResultFilepath := logging.WriteToolResultsJson(sessionID, seqId, toolResults)
413 logging.Info("Result", "message", agentMessage.FinishReason(), "toolResults", "{}", "filepath", toolResultFilepath)
414 } else {
415 logging.Info("Result", "message", agentMessage.FinishReason(), "toolResults", toolResults)
416 }
417 if (agentMessage.FinishReason() == message.FinishReasonToolUse) && toolResults != nil {
418 // We are not done, we need to respond with the tool response
419 msgHistory = append(msgHistory, agentMessage, *toolResults)
420 continue
421 }
422 return AgentEvent{
423 Type: AgentEventTypeResponse,
424 Message: agentMessage,
425 Done: true,
426 }
427 }
428}
429
430func (a *agent) createUserMessage(ctx context.Context, sessionID, content string, attachmentParts []message.ContentPart) (message.Message, error) {
431 parts := []message.ContentPart{message.TextContent{Text: content}}
432 parts = append(parts, attachmentParts...)
433 return a.messages.Create(ctx, sessionID, message.CreateMessageParams{
434 Role: message.User,
435 Parts: parts,
436 })
437}
438
439func (a *agent) streamAndHandleEvents(ctx context.Context, sessionID string, msgHistory []message.Message) (message.Message, *message.Message, error) {
440 ctx = context.WithValue(ctx, tools.SessionIDContextKey, sessionID)
441 eventChan := a.provider.StreamResponse(ctx, msgHistory, a.tools)
442
443 assistantMsg, err := a.messages.Create(ctx, sessionID, message.CreateMessageParams{
444 Role: message.Assistant,
445 Parts: []message.ContentPart{},
446 Model: a.Model().ID,
447 Provider: a.providerID,
448 })
449 if err != nil {
450 return assistantMsg, nil, fmt.Errorf("failed to create assistant message: %w", err)
451 }
452
453 // Add the session and message ID into the context if needed by tools.
454 ctx = context.WithValue(ctx, tools.MessageIDContextKey, assistantMsg.ID)
455
456 // Process each event in the stream.
457 for event := range eventChan {
458 if processErr := a.processEvent(ctx, sessionID, &assistantMsg, event); processErr != nil {
459 a.finishMessage(ctx, &assistantMsg, message.FinishReasonCanceled)
460 return assistantMsg, nil, processErr
461 }
462 if ctx.Err() != nil {
463 a.finishMessage(context.Background(), &assistantMsg, message.FinishReasonCanceled)
464 return assistantMsg, nil, ctx.Err()
465 }
466 }
467
468 toolResults := make([]message.ToolResult, len(assistantMsg.ToolCalls()))
469 toolCalls := assistantMsg.ToolCalls()
470 for i, toolCall := range toolCalls {
471 select {
472 case <-ctx.Done():
473 a.finishMessage(context.Background(), &assistantMsg, message.FinishReasonCanceled)
474 // Make all future tool calls cancelled
475 for j := i; j < len(toolCalls); j++ {
476 toolResults[j] = message.ToolResult{
477 ToolCallID: toolCalls[j].ID,
478 Content: "Tool execution canceled by user",
479 IsError: true,
480 }
481 }
482 goto out
483 default:
484 // Continue processing
485 var tool tools.BaseTool
486 for _, availableTool := range a.tools {
487 if availableTool.Info().Name == toolCall.Name {
488 tool = availableTool
489 break
490 }
491 }
492
493 // Tool not found
494 if tool == nil {
495 toolResults[i] = message.ToolResult{
496 ToolCallID: toolCall.ID,
497 Content: fmt.Sprintf("Tool not found: %s", toolCall.Name),
498 IsError: true,
499 }
500 continue
501 }
502 toolResult, toolErr := tool.Run(ctx, tools.ToolCall{
503 ID: toolCall.ID,
504 Name: toolCall.Name,
505 Input: toolCall.Input,
506 })
507 if toolErr != nil {
508 if errors.Is(toolErr, permission.ErrorPermissionDenied) {
509 toolResults[i] = message.ToolResult{
510 ToolCallID: toolCall.ID,
511 Content: "Permission denied",
512 IsError: true,
513 }
514 for j := i + 1; j < len(toolCalls); j++ {
515 toolResults[j] = message.ToolResult{
516 ToolCallID: toolCalls[j].ID,
517 Content: "Tool execution canceled by user",
518 IsError: true,
519 }
520 }
521 a.finishMessage(ctx, &assistantMsg, message.FinishReasonPermissionDenied)
522 break
523 }
524 }
525 toolResults[i] = message.ToolResult{
526 ToolCallID: toolCall.ID,
527 Content: toolResult.Content,
528 Metadata: toolResult.Metadata,
529 IsError: toolResult.IsError,
530 }
531 }
532 }
533out:
534 if len(toolResults) == 0 {
535 return assistantMsg, nil, nil
536 }
537 parts := make([]message.ContentPart, 0)
538 for _, tr := range toolResults {
539 parts = append(parts, tr)
540 }
541 msg, err := a.messages.Create(context.Background(), assistantMsg.SessionID, message.CreateMessageParams{
542 Role: message.Tool,
543 Parts: parts,
544 Provider: a.providerID,
545 })
546 if err != nil {
547 return assistantMsg, nil, fmt.Errorf("failed to create cancelled tool message: %w", err)
548 }
549
550 return assistantMsg, &msg, err
551}
552
553func (a *agent) finishMessage(ctx context.Context, msg *message.Message, finishReson message.FinishReason) {
554 msg.AddFinish(finishReson)
555 _ = a.messages.Update(ctx, *msg)
556}
557
558func (a *agent) processEvent(ctx context.Context, sessionID string, assistantMsg *message.Message, event provider.ProviderEvent) error {
559 select {
560 case <-ctx.Done():
561 return ctx.Err()
562 default:
563 // Continue processing.
564 }
565
566 switch event.Type {
567 case provider.EventThinkingDelta:
568 assistantMsg.AppendReasoningContent(event.Content)
569 return a.messages.Update(ctx, *assistantMsg)
570 case provider.EventContentDelta:
571 assistantMsg.AppendContent(event.Content)
572 return a.messages.Update(ctx, *assistantMsg)
573 case provider.EventToolUseStart:
574 logging.Info("Tool call started", "toolCall", event.ToolCall)
575 assistantMsg.AddToolCall(*event.ToolCall)
576 return a.messages.Update(ctx, *assistantMsg)
577 case provider.EventToolUseDelta:
578 assistantMsg.AppendToolCallInput(event.ToolCall.ID, event.ToolCall.Input)
579 return a.messages.Update(ctx, *assistantMsg)
580 case provider.EventToolUseStop:
581 logging.Info("Finished tool call", "toolCall", event.ToolCall)
582 assistantMsg.FinishToolCall(event.ToolCall.ID)
583 return a.messages.Update(ctx, *assistantMsg)
584 case provider.EventError:
585 if errors.Is(event.Error, context.Canceled) {
586 logging.InfoPersist(fmt.Sprintf("Event processing canceled for session: %s", sessionID))
587 return context.Canceled
588 }
589 logging.ErrorPersist(event.Error.Error())
590 return event.Error
591 case provider.EventComplete:
592 assistantMsg.SetToolCalls(event.Response.ToolCalls)
593 assistantMsg.AddFinish(event.Response.FinishReason)
594 if err := a.messages.Update(ctx, *assistantMsg); err != nil {
595 return fmt.Errorf("failed to update message: %w", err)
596 }
597 return a.TrackUsage(ctx, sessionID, a.Model(), event.Response.Usage)
598 }
599
600 return nil
601}
602
603func (a *agent) TrackUsage(ctx context.Context, sessionID string, model fur.Model, usage provider.TokenUsage) error {
604 sess, err := a.sessions.Get(ctx, sessionID)
605 if err != nil {
606 return fmt.Errorf("failed to get session: %w", err)
607 }
608
609 cost := model.CostPer1MInCached/1e6*float64(usage.CacheCreationTokens) +
610 model.CostPer1MOutCached/1e6*float64(usage.CacheReadTokens) +
611 model.CostPer1MIn/1e6*float64(usage.InputTokens) +
612 model.CostPer1MOut/1e6*float64(usage.OutputTokens)
613
614 sess.Cost += cost
615 sess.CompletionTokens = usage.OutputTokens + usage.CacheReadTokens
616 sess.PromptTokens = usage.InputTokens + usage.CacheCreationTokens
617
618 _, err = a.sessions.Save(ctx, sess)
619 if err != nil {
620 return fmt.Errorf("failed to save session: %w", err)
621 }
622 return nil
623}
624
625func (a *agent) Summarize(ctx context.Context, sessionID string) error {
626 if a.summarizeProvider == nil {
627 return fmt.Errorf("summarize provider not available")
628 }
629
630 // Check if session is busy
631 if a.IsSessionBusy(sessionID) {
632 return ErrSessionBusy
633 }
634
635 // Create a new context with cancellation
636 summarizeCtx, cancel := context.WithCancel(ctx)
637
638 // Store the cancel function in activeRequests to allow cancellation
639 a.activeRequests.Store(sessionID+"-summarize", cancel)
640
641 go func() {
642 defer a.activeRequests.Delete(sessionID + "-summarize")
643 defer cancel()
644 event := AgentEvent{
645 Type: AgentEventTypeSummarize,
646 Progress: "Starting summarization...",
647 }
648
649 a.Publish(pubsub.CreatedEvent, event)
650 // Get all messages from the session
651 msgs, err := a.messages.List(summarizeCtx, sessionID)
652 if err != nil {
653 event = AgentEvent{
654 Type: AgentEventTypeError,
655 Error: fmt.Errorf("failed to list messages: %w", err),
656 Done: true,
657 }
658 a.Publish(pubsub.CreatedEvent, event)
659 return
660 }
661 summarizeCtx = context.WithValue(summarizeCtx, tools.SessionIDContextKey, sessionID)
662
663 if len(msgs) == 0 {
664 event = AgentEvent{
665 Type: AgentEventTypeError,
666 Error: fmt.Errorf("no messages to summarize"),
667 Done: true,
668 }
669 a.Publish(pubsub.CreatedEvent, event)
670 return
671 }
672
673 event = AgentEvent{
674 Type: AgentEventTypeSummarize,
675 Progress: "Analyzing conversation...",
676 }
677 a.Publish(pubsub.CreatedEvent, event)
678
679 // Add a system message to guide the summarization
680 summarizePrompt := "Provide a detailed but concise summary of our conversation above. Focus on information that would be helpful for continuing the conversation, including what we did, what we're doing, which files we're working on, and what we're going to do next."
681
682 // Create a new message with the summarize prompt
683 promptMsg := message.Message{
684 Role: message.User,
685 Parts: []message.ContentPart{message.TextContent{Text: summarizePrompt}},
686 }
687
688 // Append the prompt to the messages
689 msgsWithPrompt := append(msgs, promptMsg)
690
691 event = AgentEvent{
692 Type: AgentEventTypeSummarize,
693 Progress: "Generating summary...",
694 }
695
696 a.Publish(pubsub.CreatedEvent, event)
697
698 // Send the messages to the summarize provider
699 response := a.summarizeProvider.StreamResponse(
700 summarizeCtx,
701 msgsWithPrompt,
702 make([]tools.BaseTool, 0),
703 )
704 var finalResponse *provider.ProviderResponse
705 for r := range response {
706 if r.Error != nil {
707 event = AgentEvent{
708 Type: AgentEventTypeError,
709 Error: fmt.Errorf("failed to summarize: %w", err),
710 Done: true,
711 }
712 a.Publish(pubsub.CreatedEvent, event)
713 return
714 }
715 finalResponse = r.Response
716 }
717
718 summary := strings.TrimSpace(finalResponse.Content)
719 if summary == "" {
720 event = AgentEvent{
721 Type: AgentEventTypeError,
722 Error: fmt.Errorf("empty summary returned"),
723 Done: true,
724 }
725 a.Publish(pubsub.CreatedEvent, event)
726 return
727 }
728 event = AgentEvent{
729 Type: AgentEventTypeSummarize,
730 Progress: "Creating new session...",
731 }
732
733 a.Publish(pubsub.CreatedEvent, event)
734 oldSession, err := a.sessions.Get(summarizeCtx, sessionID)
735 if err != nil {
736 event = AgentEvent{
737 Type: AgentEventTypeError,
738 Error: fmt.Errorf("failed to get session: %w", err),
739 Done: true,
740 }
741
742 a.Publish(pubsub.CreatedEvent, event)
743 return
744 }
745 // Create a message in the new session with the summary
746 msg, err := a.messages.Create(summarizeCtx, oldSession.ID, message.CreateMessageParams{
747 Role: message.Assistant,
748 Parts: []message.ContentPart{
749 message.TextContent{Text: summary},
750 message.Finish{
751 Reason: message.FinishReasonEndTurn,
752 Time: time.Now().Unix(),
753 },
754 },
755 Model: a.summarizeProvider.Model().ID,
756 Provider: a.summarizeProviderID,
757 })
758 if err != nil {
759 event = AgentEvent{
760 Type: AgentEventTypeError,
761 Error: fmt.Errorf("failed to create summary message: %w", err),
762 Done: true,
763 }
764
765 a.Publish(pubsub.CreatedEvent, event)
766 return
767 }
768 oldSession.SummaryMessageID = msg.ID
769 oldSession.CompletionTokens = finalResponse.Usage.OutputTokens
770 oldSession.PromptTokens = 0
771 model := a.summarizeProvider.Model()
772 usage := finalResponse.Usage
773 cost := model.CostPer1MInCached/1e6*float64(usage.CacheCreationTokens) +
774 model.CostPer1MOutCached/1e6*float64(usage.CacheReadTokens) +
775 model.CostPer1MIn/1e6*float64(usage.InputTokens) +
776 model.CostPer1MOut/1e6*float64(usage.OutputTokens)
777 oldSession.Cost += cost
778 _, err = a.sessions.Save(summarizeCtx, oldSession)
779 if err != nil {
780 event = AgentEvent{
781 Type: AgentEventTypeError,
782 Error: fmt.Errorf("failed to save session: %w", err),
783 Done: true,
784 }
785 a.Publish(pubsub.CreatedEvent, event)
786 }
787
788 event = AgentEvent{
789 Type: AgentEventTypeSummarize,
790 SessionID: oldSession.ID,
791 Progress: "Summary complete",
792 Done: true,
793 }
794 a.Publish(pubsub.CreatedEvent, event)
795 // Send final success event with the new session ID
796 }()
797
798 return nil
799}
800
801func (a *agent) CancelAll() {
802 a.activeRequests.Range(func(key, value any) bool {
803 a.Cancel(key.(string)) // key is sessionID
804 return true
805 })
806}
807
808func (a *agent) UpdateModel() error {
809 cfg := config.Get()
810
811 // Get current provider configuration
812 currentProviderCfg := cfg.GetProviderForModel(a.agentCfg.Model)
813 if currentProviderCfg.ID == "" {
814 return fmt.Errorf("provider for agent %s not found in config", a.agentCfg.Name)
815 }
816
817 // Check if provider has changed
818 if string(currentProviderCfg.ID) != a.providerID {
819 // Provider changed, need to recreate the main provider
820 model := cfg.GetModelByType(a.agentCfg.Model)
821 if model.ID == "" {
822 return fmt.Errorf("model not found for agent %s", a.agentCfg.Name)
823 }
824
825 promptID := agentPromptMap[a.agentCfg.ID]
826 if promptID == "" {
827 promptID = prompt.PromptDefault
828 }
829
830 opts := []provider.ProviderClientOption{
831 provider.WithModel(a.agentCfg.Model),
832 provider.WithSystemMessage(prompt.GetPrompt(promptID, currentProviderCfg.ID)),
833 }
834
835 newProvider, err := provider.NewProvider(*currentProviderCfg, opts...)
836 if err != nil {
837 return fmt.Errorf("failed to create new provider: %w", err)
838 }
839
840 // Update the provider and provider ID
841 a.provider = newProvider
842 a.providerID = string(currentProviderCfg.ID)
843 }
844
845 // Check if small model provider has changed (affects title and summarize providers)
846 smallModelCfg := cfg.Models[config.SelectedModelTypeSmall]
847 var smallModelProviderCfg config.ProviderConfig
848
849 for _, p := range cfg.Providers {
850 if p.ID == smallModelCfg.Provider {
851 smallModelProviderCfg = p
852 break
853 }
854 }
855
856 if smallModelProviderCfg.ID == "" {
857 return fmt.Errorf("provider %s not found in config", smallModelCfg.Provider)
858 }
859
860 // Check if summarize provider has changed
861 if string(smallModelProviderCfg.ID) != a.summarizeProviderID {
862 smallModel := cfg.GetModelByType(config.SelectedModelTypeSmall)
863 if smallModel == nil {
864 return fmt.Errorf("model %s not found in provider %s", smallModelCfg.Model, smallModelProviderCfg.ID)
865 }
866
867 // Recreate title provider
868 titleOpts := []provider.ProviderClientOption{
869 provider.WithModel(config.SelectedModelTypeSmall),
870 provider.WithSystemMessage(prompt.GetPrompt(prompt.PromptTitle, smallModelProviderCfg.ID)),
871 // We want the title to be short, so we limit the max tokens
872 provider.WithMaxTokens(40),
873 }
874 newTitleProvider, err := provider.NewProvider(smallModelProviderCfg, titleOpts...)
875 if err != nil {
876 return fmt.Errorf("failed to create new title provider: %w", err)
877 }
878
879 // Recreate summarize provider
880 summarizeOpts := []provider.ProviderClientOption{
881 provider.WithModel(config.SelectedModelTypeSmall),
882 provider.WithSystemMessage(prompt.GetPrompt(prompt.PromptSummarizer, smallModelProviderCfg.ID)),
883 }
884 newSummarizeProvider, err := provider.NewProvider(smallModelProviderCfg, summarizeOpts...)
885 if err != nil {
886 return fmt.Errorf("failed to create new summarize provider: %w", err)
887 }
888
889 // Update the providers and provider ID
890 a.titleProvider = newTitleProvider
891 a.summarizeProvider = newSummarizeProvider
892 a.summarizeProviderID = string(smallModelProviderCfg.ID)
893 }
894
895 return nil
896}