1package agent
2
3import (
4 "context"
5 "errors"
6 "fmt"
7 "log/slog"
8 "slices"
9 "strings"
10 "sync"
11 "time"
12
13 "github.com/charmbracelet/crush/internal/config"
14 fur "github.com/charmbracelet/crush/internal/fur/provider"
15 "github.com/charmbracelet/crush/internal/history"
16 "github.com/charmbracelet/crush/internal/llm/prompt"
17 "github.com/charmbracelet/crush/internal/llm/provider"
18 "github.com/charmbracelet/crush/internal/llm/tools"
19 "github.com/charmbracelet/crush/internal/log"
20 "github.com/charmbracelet/crush/internal/lsp"
21 "github.com/charmbracelet/crush/internal/message"
22 "github.com/charmbracelet/crush/internal/permission"
23 "github.com/charmbracelet/crush/internal/pubsub"
24 "github.com/charmbracelet/crush/internal/session"
25)
26
// Sentinel errors shared by the agent API; compare with errors.Is.

// ErrRequestCancelled is reported when a running request was cancelled.
var ErrRequestCancelled = errors.New("request canceled by user")

// ErrSessionBusy is returned when a session already has a request in flight.
var ErrSessionBusy = errors.New("session is currently processing another request")
32
// AgentEventType names the kind of an AgentEvent.
type AgentEventType string

const (
	// AgentEventTypeError tags events that carry an error.
	AgentEventTypeError = AgentEventType("error")
	// AgentEventTypeResponse tags the event carrying the assistant's reply.
	AgentEventTypeResponse = AgentEventType("response")
	// AgentEventTypeSummarize tags progress events published while summarizing.
	AgentEventTypeSummarize = AgentEventType("summarize")
)
40
41type AgentEvent struct {
42 Type AgentEventType
43 Message message.Message
44 Error error
45
46 // When summarizing
47 SessionID string
48 Progress string
49 Done bool
50}
51
52type Service interface {
53 pubsub.Suscriber[AgentEvent]
54 Model() fur.Model
55 Run(ctx context.Context, sessionID string, content string, attachments ...message.Attachment) (<-chan AgentEvent, error)
56 Cancel(sessionID string)
57 CancelAll()
58 IsSessionBusy(sessionID string) bool
59 IsBusy() bool
60 Summarize(ctx context.Context, sessionID string) error
61 UpdateModel() error
62}
63
64type agent struct {
65 *pubsub.Broker[AgentEvent]
66 agentCfg config.Agent
67 sessions session.Service
68 messages message.Service
69
70 tools []tools.BaseTool
71 provider provider.Provider
72 providerID string
73
74 titleProvider provider.Provider
75 summarizeProvider provider.Provider
76 summarizeProviderID string
77
78 activeRequests sync.Map
79}
80
81var agentPromptMap = map[string]prompt.PromptID{
82 "coder": prompt.PromptCoder,
83 "task": prompt.PromptTask,
84}
85
86func NewAgent(
87 agentCfg config.Agent,
88 // These services are needed in the tools
89 permissions permission.Service,
90 sessions session.Service,
91 messages message.Service,
92 history history.Service,
93 lspClients map[string]*lsp.Client,
94) (Service, error) {
95 ctx := context.Background()
96 cfg := config.Get()
97 otherTools := GetMcpTools(ctx, permissions, cfg)
98 if len(lspClients) > 0 {
99 otherTools = append(otherTools, tools.NewDiagnosticsTool(lspClients))
100 }
101
102 cwd := cfg.WorkingDir()
103 allTools := []tools.BaseTool{
104 tools.NewBashTool(permissions, cwd),
105 tools.NewEditTool(lspClients, permissions, history, cwd),
106 tools.NewFetchTool(permissions, cwd),
107 tools.NewGlobTool(cwd),
108 tools.NewGrepTool(cwd),
109 tools.NewLsTool(cwd),
110 tools.NewSourcegraphTool(),
111 tools.NewViewTool(lspClients, cwd),
112 tools.NewWriteTool(lspClients, permissions, history, cwd),
113 }
114
115 if agentCfg.ID == "coder" {
116 taskAgentCfg := config.Get().Agents["task"]
117 if taskAgentCfg.ID == "" {
118 return nil, fmt.Errorf("task agent not found in config")
119 }
120 taskAgent, err := NewAgent(taskAgentCfg, permissions, sessions, messages, history, lspClients)
121 if err != nil {
122 return nil, fmt.Errorf("failed to create task agent: %w", err)
123 }
124
125 allTools = append(
126 allTools,
127 NewAgentTool(
128 taskAgent,
129 sessions,
130 messages,
131 ),
132 )
133 }
134
135 allTools = append(allTools, otherTools...)
136 providerCfg := config.Get().GetProviderForModel(agentCfg.Model)
137 if providerCfg == nil {
138 return nil, fmt.Errorf("provider for agent %s not found in config", agentCfg.Name)
139 }
140 model := config.Get().GetModelByType(agentCfg.Model)
141
142 if model == nil {
143 return nil, fmt.Errorf("model not found for agent %s", agentCfg.Name)
144 }
145
146 promptID := agentPromptMap[agentCfg.ID]
147 if promptID == "" {
148 promptID = prompt.PromptDefault
149 }
150 opts := []provider.ProviderClientOption{
151 provider.WithModel(agentCfg.Model),
152 provider.WithSystemMessage(prompt.GetPrompt(promptID, providerCfg.ID, config.Get().Options.ContextPaths...)),
153 }
154 agentProvider, err := provider.NewProvider(*providerCfg, opts...)
155 if err != nil {
156 return nil, err
157 }
158
159 smallModelCfg := cfg.Models[config.SelectedModelTypeSmall]
160 var smallModelProviderCfg *config.ProviderConfig
161 if smallModelCfg.Provider == providerCfg.ID {
162 smallModelProviderCfg = providerCfg
163 } else {
164 smallModelProviderCfg = cfg.GetProviderForModel(config.SelectedModelTypeSmall)
165
166 if smallModelProviderCfg.ID == "" {
167 return nil, fmt.Errorf("provider %s not found in config", smallModelCfg.Provider)
168 }
169 }
170 smallModel := cfg.GetModelByType(config.SelectedModelTypeSmall)
171 if smallModel.ID == "" {
172 return nil, fmt.Errorf("model %s not found in provider %s", smallModelCfg.Model, smallModelProviderCfg.ID)
173 }
174
175 titleOpts := []provider.ProviderClientOption{
176 provider.WithModel(config.SelectedModelTypeSmall),
177 provider.WithSystemMessage(prompt.GetPrompt(prompt.PromptTitle, smallModelProviderCfg.ID)),
178 }
179 titleProvider, err := provider.NewProvider(*smallModelProviderCfg, titleOpts...)
180 if err != nil {
181 return nil, err
182 }
183 summarizeOpts := []provider.ProviderClientOption{
184 provider.WithModel(config.SelectedModelTypeSmall),
185 provider.WithSystemMessage(prompt.GetPrompt(prompt.PromptSummarizer, smallModelProviderCfg.ID)),
186 }
187 summarizeProvider, err := provider.NewProvider(*smallModelProviderCfg, summarizeOpts...)
188 if err != nil {
189 return nil, err
190 }
191
192 agentTools := []tools.BaseTool{}
193 if agentCfg.AllowedTools == nil {
194 agentTools = allTools
195 } else {
196 for _, tool := range allTools {
197 if slices.Contains(agentCfg.AllowedTools, tool.Name()) {
198 agentTools = append(agentTools, tool)
199 }
200 }
201 }
202
203 agent := &agent{
204 Broker: pubsub.NewBroker[AgentEvent](),
205 agentCfg: agentCfg,
206 provider: agentProvider,
207 providerID: string(providerCfg.ID),
208 messages: messages,
209 sessions: sessions,
210 tools: agentTools,
211 titleProvider: titleProvider,
212 summarizeProvider: summarizeProvider,
213 summarizeProviderID: string(smallModelProviderCfg.ID),
214 activeRequests: sync.Map{},
215 }
216
217 return agent, nil
218}
219
220func (a *agent) Model() fur.Model {
221 return *config.Get().GetModelByType(a.agentCfg.Model)
222}
223
224func (a *agent) Cancel(sessionID string) {
225 // Cancel regular requests
226 if cancelFunc, exists := a.activeRequests.LoadAndDelete(sessionID); exists {
227 if cancel, ok := cancelFunc.(context.CancelFunc); ok {
228 slog.Info(fmt.Sprintf("Request cancellation initiated for session: %s", sessionID))
229 cancel()
230 }
231 }
232
233 // Also check for summarize requests
234 if cancelFunc, exists := a.activeRequests.LoadAndDelete(sessionID + "-summarize"); exists {
235 if cancel, ok := cancelFunc.(context.CancelFunc); ok {
236 slog.Info(fmt.Sprintf("Summarize cancellation initiated for session: %s", sessionID))
237 cancel()
238 }
239 }
240}
241
242func (a *agent) IsBusy() bool {
243 busy := false
244 a.activeRequests.Range(func(key, value any) bool {
245 if cancelFunc, ok := value.(context.CancelFunc); ok {
246 if cancelFunc != nil {
247 busy = true
248 return false
249 }
250 }
251 return true
252 })
253 return busy
254}
255
256func (a *agent) IsSessionBusy(sessionID string) bool {
257 _, busy := a.activeRequests.Load(sessionID)
258 return busy
259}
260
261func (a *agent) generateTitle(ctx context.Context, sessionID string, content string) error {
262 if content == "" {
263 return nil
264 }
265 if a.titleProvider == nil {
266 return nil
267 }
268 session, err := a.sessions.Get(ctx, sessionID)
269 if err != nil {
270 return err
271 }
272 parts := []message.ContentPart{message.TextContent{
273 Text: fmt.Sprintf("Generate a concise title for the following content:\n\n%s", content),
274 }}
275
276 // Use streaming approach like summarization
277 response := a.titleProvider.StreamResponse(
278 ctx,
279 []message.Message{
280 {
281 Role: message.User,
282 Parts: parts,
283 },
284 },
285 make([]tools.BaseTool, 0),
286 )
287
288 var finalResponse *provider.ProviderResponse
289 for r := range response {
290 if r.Error != nil {
291 return r.Error
292 }
293 finalResponse = r.Response
294 }
295
296 if finalResponse == nil {
297 return fmt.Errorf("no response received from title provider")
298 }
299
300 title := strings.TrimSpace(strings.ReplaceAll(finalResponse.Content, "\n", " "))
301 if title == "" {
302 return nil
303 }
304
305 session.Title = title
306 _, err = a.sessions.Save(ctx, session)
307 return err
308}
309
310func (a *agent) err(err error) AgentEvent {
311 return AgentEvent{
312 Type: AgentEventTypeError,
313 Error: err,
314 }
315}
316
317func (a *agent) Run(ctx context.Context, sessionID string, content string, attachments ...message.Attachment) (<-chan AgentEvent, error) {
318 if !a.Model().SupportsImages && attachments != nil {
319 attachments = nil
320 }
321 events := make(chan AgentEvent)
322 if a.IsSessionBusy(sessionID) {
323 return nil, ErrSessionBusy
324 }
325
326 genCtx, cancel := context.WithCancel(ctx)
327
328 a.activeRequests.Store(sessionID, cancel)
329 go func() {
330 slog.Debug("Request started", "sessionID", sessionID)
331 defer log.RecoverPanic("agent.Run", func() {
332 events <- a.err(fmt.Errorf("panic while running the agent"))
333 })
334 var attachmentParts []message.ContentPart
335 for _, attachment := range attachments {
336 attachmentParts = append(attachmentParts, message.BinaryContent{Path: attachment.FilePath, MIMEType: attachment.MimeType, Data: attachment.Content})
337 }
338 result := a.processGeneration(genCtx, sessionID, content, attachmentParts)
339 if result.Error != nil && !errors.Is(result.Error, ErrRequestCancelled) && !errors.Is(result.Error, context.Canceled) {
340 slog.Error(result.Error.Error())
341 }
342 slog.Debug("Request completed", "sessionID", sessionID)
343 a.activeRequests.Delete(sessionID)
344 cancel()
345 a.Publish(pubsub.CreatedEvent, result)
346 events <- result
347 close(events)
348 }()
349 return events, nil
350}
351
352func (a *agent) processGeneration(ctx context.Context, sessionID, content string, attachmentParts []message.ContentPart) AgentEvent {
353 cfg := config.Get()
354 // List existing messages; if none, start title generation asynchronously.
355 msgs, err := a.messages.List(ctx, sessionID)
356 if err != nil {
357 return a.err(fmt.Errorf("failed to list messages: %w", err))
358 }
359 if len(msgs) == 0 {
360 go func() {
361 defer log.RecoverPanic("agent.Run", func() {
362 slog.Error("panic while generating title")
363 })
364 titleErr := a.generateTitle(context.Background(), sessionID, content)
365 if titleErr != nil && !errors.Is(titleErr, context.Canceled) && !errors.Is(titleErr, context.DeadlineExceeded) {
366 slog.Error(fmt.Sprintf("failed to generate title: %v", titleErr))
367 }
368 }()
369 }
370 session, err := a.sessions.Get(ctx, sessionID)
371 if err != nil {
372 return a.err(fmt.Errorf("failed to get session: %w", err))
373 }
374 if session.SummaryMessageID != "" {
375 summaryMsgInex := -1
376 for i, msg := range msgs {
377 if msg.ID == session.SummaryMessageID {
378 summaryMsgInex = i
379 break
380 }
381 }
382 if summaryMsgInex != -1 {
383 msgs = msgs[summaryMsgInex:]
384 msgs[0].Role = message.User
385 }
386 }
387
388 userMsg, err := a.createUserMessage(ctx, sessionID, content, attachmentParts)
389 if err != nil {
390 return a.err(fmt.Errorf("failed to create user message: %w", err))
391 }
392 // Append the new user message to the conversation history.
393 msgHistory := append(msgs, userMsg)
394
395 for {
396 // Check for cancellation before each iteration
397 select {
398 case <-ctx.Done():
399 return a.err(ctx.Err())
400 default:
401 // Continue processing
402 }
403 agentMessage, toolResults, err := a.streamAndHandleEvents(ctx, sessionID, msgHistory)
404 if err != nil {
405 if errors.Is(err, context.Canceled) {
406 agentMessage.AddFinish(message.FinishReasonCanceled)
407 a.messages.Update(context.Background(), agentMessage)
408 return a.err(ErrRequestCancelled)
409 }
410 return a.err(fmt.Errorf("failed to process events: %w", err))
411 }
412 if cfg.Options.Debug {
413 slog.Info("Result", "message", agentMessage.FinishReason(), "toolResults", toolResults)
414 }
415 if (agentMessage.FinishReason() == message.FinishReasonToolUse) && toolResults != nil {
416 // We are not done, we need to respond with the tool response
417 msgHistory = append(msgHistory, agentMessage, *toolResults)
418 continue
419 }
420 return AgentEvent{
421 Type: AgentEventTypeResponse,
422 Message: agentMessage,
423 Done: true,
424 }
425 }
426}
427
428func (a *agent) createUserMessage(ctx context.Context, sessionID, content string, attachmentParts []message.ContentPart) (message.Message, error) {
429 parts := []message.ContentPart{message.TextContent{Text: content}}
430 parts = append(parts, attachmentParts...)
431 return a.messages.Create(ctx, sessionID, message.CreateMessageParams{
432 Role: message.User,
433 Parts: parts,
434 })
435}
436
437func (a *agent) streamAndHandleEvents(ctx context.Context, sessionID string, msgHistory []message.Message) (message.Message, *message.Message, error) {
438 ctx = context.WithValue(ctx, tools.SessionIDContextKey, sessionID)
439 eventChan := a.provider.StreamResponse(ctx, msgHistory, a.tools)
440
441 assistantMsg, err := a.messages.Create(ctx, sessionID, message.CreateMessageParams{
442 Role: message.Assistant,
443 Parts: []message.ContentPart{},
444 Model: a.Model().ID,
445 Provider: a.providerID,
446 })
447 if err != nil {
448 return assistantMsg, nil, fmt.Errorf("failed to create assistant message: %w", err)
449 }
450
451 // Add the session and message ID into the context if needed by tools.
452 ctx = context.WithValue(ctx, tools.MessageIDContextKey, assistantMsg.ID)
453
454 // Process each event in the stream.
455 for event := range eventChan {
456 if processErr := a.processEvent(ctx, sessionID, &assistantMsg, event); processErr != nil {
457 a.finishMessage(ctx, &assistantMsg, message.FinishReasonCanceled)
458 return assistantMsg, nil, processErr
459 }
460 if ctx.Err() != nil {
461 a.finishMessage(context.Background(), &assistantMsg, message.FinishReasonCanceled)
462 return assistantMsg, nil, ctx.Err()
463 }
464 }
465
466 toolResults := make([]message.ToolResult, len(assistantMsg.ToolCalls()))
467 toolCalls := assistantMsg.ToolCalls()
468 for i, toolCall := range toolCalls {
469 select {
470 case <-ctx.Done():
471 a.finishMessage(context.Background(), &assistantMsg, message.FinishReasonCanceled)
472 // Make all future tool calls cancelled
473 for j := i; j < len(toolCalls); j++ {
474 toolResults[j] = message.ToolResult{
475 ToolCallID: toolCalls[j].ID,
476 Content: "Tool execution canceled by user",
477 IsError: true,
478 }
479 }
480 goto out
481 default:
482 // Continue processing
483 var tool tools.BaseTool
484 for _, availableTool := range a.tools {
485 if availableTool.Info().Name == toolCall.Name {
486 tool = availableTool
487 break
488 }
489 }
490
491 // Tool not found
492 if tool == nil {
493 toolResults[i] = message.ToolResult{
494 ToolCallID: toolCall.ID,
495 Content: fmt.Sprintf("Tool not found: %s", toolCall.Name),
496 IsError: true,
497 }
498 continue
499 }
500 toolResult, toolErr := tool.Run(ctx, tools.ToolCall{
501 ID: toolCall.ID,
502 Name: toolCall.Name,
503 Input: toolCall.Input,
504 })
505 if toolErr != nil {
506 if errors.Is(toolErr, permission.ErrorPermissionDenied) {
507 toolResults[i] = message.ToolResult{
508 ToolCallID: toolCall.ID,
509 Content: "Permission denied",
510 IsError: true,
511 }
512 for j := i + 1; j < len(toolCalls); j++ {
513 toolResults[j] = message.ToolResult{
514 ToolCallID: toolCalls[j].ID,
515 Content: "Tool execution canceled by user",
516 IsError: true,
517 }
518 }
519 a.finishMessage(ctx, &assistantMsg, message.FinishReasonPermissionDenied)
520 break
521 }
522 }
523 toolResults[i] = message.ToolResult{
524 ToolCallID: toolCall.ID,
525 Content: toolResult.Content,
526 Metadata: toolResult.Metadata,
527 IsError: toolResult.IsError,
528 }
529 }
530 }
531out:
532 if len(toolResults) == 0 {
533 return assistantMsg, nil, nil
534 }
535 parts := make([]message.ContentPart, 0)
536 for _, tr := range toolResults {
537 parts = append(parts, tr)
538 }
539 msg, err := a.messages.Create(context.Background(), assistantMsg.SessionID, message.CreateMessageParams{
540 Role: message.Tool,
541 Parts: parts,
542 Provider: a.providerID,
543 })
544 if err != nil {
545 return assistantMsg, nil, fmt.Errorf("failed to create cancelled tool message: %w", err)
546 }
547
548 return assistantMsg, &msg, err
549}
550
551func (a *agent) finishMessage(ctx context.Context, msg *message.Message, finishReson message.FinishReason) {
552 msg.AddFinish(finishReson)
553 _ = a.messages.Update(ctx, *msg)
554}
555
556func (a *agent) processEvent(ctx context.Context, sessionID string, assistantMsg *message.Message, event provider.ProviderEvent) error {
557 select {
558 case <-ctx.Done():
559 return ctx.Err()
560 default:
561 // Continue processing.
562 }
563
564 switch event.Type {
565 case provider.EventThinkingDelta:
566 assistantMsg.AppendReasoningContent(event.Content)
567 return a.messages.Update(ctx, *assistantMsg)
568 case provider.EventContentDelta:
569 assistantMsg.AppendContent(event.Content)
570 return a.messages.Update(ctx, *assistantMsg)
571 case provider.EventToolUseStart:
572 slog.Info("Tool call started", "toolCall", event.ToolCall)
573 assistantMsg.AddToolCall(*event.ToolCall)
574 return a.messages.Update(ctx, *assistantMsg)
575 case provider.EventToolUseDelta:
576 assistantMsg.AppendToolCallInput(event.ToolCall.ID, event.ToolCall.Input)
577 return a.messages.Update(ctx, *assistantMsg)
578 case provider.EventToolUseStop:
579 slog.Info("Finished tool call", "toolCall", event.ToolCall)
580 assistantMsg.FinishToolCall(event.ToolCall.ID)
581 return a.messages.Update(ctx, *assistantMsg)
582 case provider.EventError:
583 if errors.Is(event.Error, context.Canceled) {
584 slog.Info(fmt.Sprintf("Event processing canceled for session: %s", sessionID))
585 return context.Canceled
586 }
587 slog.Error(event.Error.Error())
588 return event.Error
589 case provider.EventComplete:
590 assistantMsg.SetToolCalls(event.Response.ToolCalls)
591 assistantMsg.AddFinish(event.Response.FinishReason)
592 if err := a.messages.Update(ctx, *assistantMsg); err != nil {
593 return fmt.Errorf("failed to update message: %w", err)
594 }
595 return a.TrackUsage(ctx, sessionID, a.Model(), event.Response.Usage)
596 }
597
598 return nil
599}
600
601func (a *agent) TrackUsage(ctx context.Context, sessionID string, model fur.Model, usage provider.TokenUsage) error {
602 sess, err := a.sessions.Get(ctx, sessionID)
603 if err != nil {
604 return fmt.Errorf("failed to get session: %w", err)
605 }
606
607 cost := model.CostPer1MInCached/1e6*float64(usage.CacheCreationTokens) +
608 model.CostPer1MOutCached/1e6*float64(usage.CacheReadTokens) +
609 model.CostPer1MIn/1e6*float64(usage.InputTokens) +
610 model.CostPer1MOut/1e6*float64(usage.OutputTokens)
611
612 sess.Cost += cost
613 sess.CompletionTokens = usage.OutputTokens + usage.CacheReadTokens
614 sess.PromptTokens = usage.InputTokens + usage.CacheCreationTokens
615
616 _, err = a.sessions.Save(ctx, sess)
617 if err != nil {
618 return fmt.Errorf("failed to save session: %w", err)
619 }
620 return nil
621}
622
623func (a *agent) Summarize(ctx context.Context, sessionID string) error {
624 if a.summarizeProvider == nil {
625 return fmt.Errorf("summarize provider not available")
626 }
627
628 // Check if session is busy
629 if a.IsSessionBusy(sessionID) {
630 return ErrSessionBusy
631 }
632
633 // Create a new context with cancellation
634 summarizeCtx, cancel := context.WithCancel(ctx)
635
636 // Store the cancel function in activeRequests to allow cancellation
637 a.activeRequests.Store(sessionID+"-summarize", cancel)
638
639 go func() {
640 defer a.activeRequests.Delete(sessionID + "-summarize")
641 defer cancel()
642 event := AgentEvent{
643 Type: AgentEventTypeSummarize,
644 Progress: "Starting summarization...",
645 }
646
647 a.Publish(pubsub.CreatedEvent, event)
648 // Get all messages from the session
649 msgs, err := a.messages.List(summarizeCtx, sessionID)
650 if err != nil {
651 event = AgentEvent{
652 Type: AgentEventTypeError,
653 Error: fmt.Errorf("failed to list messages: %w", err),
654 Done: true,
655 }
656 a.Publish(pubsub.CreatedEvent, event)
657 return
658 }
659 summarizeCtx = context.WithValue(summarizeCtx, tools.SessionIDContextKey, sessionID)
660
661 if len(msgs) == 0 {
662 event = AgentEvent{
663 Type: AgentEventTypeError,
664 Error: fmt.Errorf("no messages to summarize"),
665 Done: true,
666 }
667 a.Publish(pubsub.CreatedEvent, event)
668 return
669 }
670
671 event = AgentEvent{
672 Type: AgentEventTypeSummarize,
673 Progress: "Analyzing conversation...",
674 }
675 a.Publish(pubsub.CreatedEvent, event)
676
677 // Add a system message to guide the summarization
678 summarizePrompt := "Provide a detailed but concise summary of our conversation above. Focus on information that would be helpful for continuing the conversation, including what we did, what we're doing, which files we're working on, and what we're going to do next."
679
680 // Create a new message with the summarize prompt
681 promptMsg := message.Message{
682 Role: message.User,
683 Parts: []message.ContentPart{message.TextContent{Text: summarizePrompt}},
684 }
685
686 // Append the prompt to the messages
687 msgsWithPrompt := append(msgs, promptMsg)
688
689 event = AgentEvent{
690 Type: AgentEventTypeSummarize,
691 Progress: "Generating summary...",
692 }
693
694 a.Publish(pubsub.CreatedEvent, event)
695
696 // Send the messages to the summarize provider
697 response := a.summarizeProvider.StreamResponse(
698 summarizeCtx,
699 msgsWithPrompt,
700 make([]tools.BaseTool, 0),
701 )
702 var finalResponse *provider.ProviderResponse
703 for r := range response {
704 if r.Error != nil {
705 event = AgentEvent{
706 Type: AgentEventTypeError,
707 Error: fmt.Errorf("failed to summarize: %w", err),
708 Done: true,
709 }
710 a.Publish(pubsub.CreatedEvent, event)
711 return
712 }
713 finalResponse = r.Response
714 }
715
716 summary := strings.TrimSpace(finalResponse.Content)
717 if summary == "" {
718 event = AgentEvent{
719 Type: AgentEventTypeError,
720 Error: fmt.Errorf("empty summary returned"),
721 Done: true,
722 }
723 a.Publish(pubsub.CreatedEvent, event)
724 return
725 }
726 event = AgentEvent{
727 Type: AgentEventTypeSummarize,
728 Progress: "Creating new session...",
729 }
730
731 a.Publish(pubsub.CreatedEvent, event)
732 oldSession, err := a.sessions.Get(summarizeCtx, sessionID)
733 if err != nil {
734 event = AgentEvent{
735 Type: AgentEventTypeError,
736 Error: fmt.Errorf("failed to get session: %w", err),
737 Done: true,
738 }
739
740 a.Publish(pubsub.CreatedEvent, event)
741 return
742 }
743 // Create a message in the new session with the summary
744 msg, err := a.messages.Create(summarizeCtx, oldSession.ID, message.CreateMessageParams{
745 Role: message.Assistant,
746 Parts: []message.ContentPart{
747 message.TextContent{Text: summary},
748 message.Finish{
749 Reason: message.FinishReasonEndTurn,
750 Time: time.Now().Unix(),
751 },
752 },
753 Model: a.summarizeProvider.Model().ID,
754 Provider: a.summarizeProviderID,
755 })
756 if err != nil {
757 event = AgentEvent{
758 Type: AgentEventTypeError,
759 Error: fmt.Errorf("failed to create summary message: %w", err),
760 Done: true,
761 }
762
763 a.Publish(pubsub.CreatedEvent, event)
764 return
765 }
766 oldSession.SummaryMessageID = msg.ID
767 oldSession.CompletionTokens = finalResponse.Usage.OutputTokens
768 oldSession.PromptTokens = 0
769 model := a.summarizeProvider.Model()
770 usage := finalResponse.Usage
771 cost := model.CostPer1MInCached/1e6*float64(usage.CacheCreationTokens) +
772 model.CostPer1MOutCached/1e6*float64(usage.CacheReadTokens) +
773 model.CostPer1MIn/1e6*float64(usage.InputTokens) +
774 model.CostPer1MOut/1e6*float64(usage.OutputTokens)
775 oldSession.Cost += cost
776 _, err = a.sessions.Save(summarizeCtx, oldSession)
777 if err != nil {
778 event = AgentEvent{
779 Type: AgentEventTypeError,
780 Error: fmt.Errorf("failed to save session: %w", err),
781 Done: true,
782 }
783 a.Publish(pubsub.CreatedEvent, event)
784 }
785
786 event = AgentEvent{
787 Type: AgentEventTypeSummarize,
788 SessionID: oldSession.ID,
789 Progress: "Summary complete",
790 Done: true,
791 }
792 a.Publish(pubsub.CreatedEvent, event)
793 // Send final success event with the new session ID
794 }()
795
796 return nil
797}
798
799func (a *agent) CancelAll() {
800 a.activeRequests.Range(func(key, value any) bool {
801 a.Cancel(key.(string)) // key is sessionID
802 return true
803 })
804 for {
805 if a.IsBusy() {
806 time.Sleep(200 * time.Millisecond)
807 } else {
808 break
809 }
810 }
811}
812
813func (a *agent) UpdateModel() error {
814 cfg := config.Get()
815
816 // Get current provider configuration
817 currentProviderCfg := cfg.GetProviderForModel(a.agentCfg.Model)
818 if currentProviderCfg.ID == "" {
819 return fmt.Errorf("provider for agent %s not found in config", a.agentCfg.Name)
820 }
821
822 // Check if provider has changed
823 if string(currentProviderCfg.ID) != a.providerID {
824 // Provider changed, need to recreate the main provider
825 model := cfg.GetModelByType(a.agentCfg.Model)
826 if model.ID == "" {
827 return fmt.Errorf("model not found for agent %s", a.agentCfg.Name)
828 }
829
830 promptID := agentPromptMap[a.agentCfg.ID]
831 if promptID == "" {
832 promptID = prompt.PromptDefault
833 }
834
835 opts := []provider.ProviderClientOption{
836 provider.WithModel(a.agentCfg.Model),
837 provider.WithSystemMessage(prompt.GetPrompt(promptID, currentProviderCfg.ID, cfg.Options.ContextPaths...)),
838 }
839
840 newProvider, err := provider.NewProvider(*currentProviderCfg, opts...)
841 if err != nil {
842 return fmt.Errorf("failed to create new provider: %w", err)
843 }
844
845 // Update the provider and provider ID
846 a.provider = newProvider
847 a.providerID = string(currentProviderCfg.ID)
848 }
849
850 // Check if small model provider has changed (affects title and summarize providers)
851 smallModelCfg := cfg.Models[config.SelectedModelTypeSmall]
852 var smallModelProviderCfg config.ProviderConfig
853
854 for _, p := range cfg.Providers {
855 if p.ID == smallModelCfg.Provider {
856 smallModelProviderCfg = p
857 break
858 }
859 }
860
861 if smallModelProviderCfg.ID == "" {
862 return fmt.Errorf("provider %s not found in config", smallModelCfg.Provider)
863 }
864
865 // Check if summarize provider has changed
866 if string(smallModelProviderCfg.ID) != a.summarizeProviderID {
867 smallModel := cfg.GetModelByType(config.SelectedModelTypeSmall)
868 if smallModel == nil {
869 return fmt.Errorf("model %s not found in provider %s", smallModelCfg.Model, smallModelProviderCfg.ID)
870 }
871
872 // Recreate title provider
873 titleOpts := []provider.ProviderClientOption{
874 provider.WithModel(config.SelectedModelTypeSmall),
875 provider.WithSystemMessage(prompt.GetPrompt(prompt.PromptTitle, smallModelProviderCfg.ID)),
876 // We want the title to be short, so we limit the max tokens
877 provider.WithMaxTokens(40),
878 }
879 newTitleProvider, err := provider.NewProvider(smallModelProviderCfg, titleOpts...)
880 if err != nil {
881 return fmt.Errorf("failed to create new title provider: %w", err)
882 }
883
884 // Recreate summarize provider
885 summarizeOpts := []provider.ProviderClientOption{
886 provider.WithModel(config.SelectedModelTypeSmall),
887 provider.WithSystemMessage(prompt.GetPrompt(prompt.PromptSummarizer, smallModelProviderCfg.ID)),
888 }
889 newSummarizeProvider, err := provider.NewProvider(smallModelProviderCfg, summarizeOpts...)
890 if err != nil {
891 return fmt.Errorf("failed to create new summarize provider: %w", err)
892 }
893
894 // Update the providers and provider ID
895 a.titleProvider = newTitleProvider
896 a.summarizeProvider = newSummarizeProvider
897 a.summarizeProviderID = string(smallModelProviderCfg.ID)
898 }
899
900 return nil
901}