1package agent
2
3import (
4 "context"
5 "errors"
6 "fmt"
7 "log/slog"
8 "slices"
9 "strings"
10 "sync"
11 "time"
12
13 "github.com/charmbracelet/crush/internal/config"
14 fur "github.com/charmbracelet/crush/internal/fur/provider"
15 "github.com/charmbracelet/crush/internal/history"
16 "github.com/charmbracelet/crush/internal/llm/prompt"
17 "github.com/charmbracelet/crush/internal/llm/provider"
18 "github.com/charmbracelet/crush/internal/llm/tools"
19 "github.com/charmbracelet/crush/internal/log"
20 "github.com/charmbracelet/crush/internal/lsp"
21 "github.com/charmbracelet/crush/internal/message"
22 "github.com/charmbracelet/crush/internal/permission"
23 "github.com/charmbracelet/crush/internal/pubsub"
24 "github.com/charmbracelet/crush/internal/session"
25)
26
// Common errors
var (
	// ErrRequestCancelled is reported in the resulting AgentEvent when a
	// generation stops because its context was canceled.
	ErrRequestCancelled = errors.New("request canceled by user")
	// ErrSessionBusy is returned by Run and Summarize when the session already
	// has an active request registered.
	ErrSessionBusy = errors.New("session is currently processing another request")
)
32
// AgentEventType identifies the kind of an AgentEvent.
type AgentEventType string

const (
	// AgentEventTypeError marks an event whose Error field carries a failure.
	AgentEventTypeError AgentEventType = "error"
	// AgentEventTypeResponse marks the final result of a generation run.
	AgentEventTypeResponse AgentEventType = "response"
	// AgentEventTypeSummarize marks progress and completion events published
	// by Summarize.
	AgentEventTypeSummarize AgentEventType = "summarize"
)
40
// AgentEvent is the value the agent publishes through its broker and sends on
// the channel returned by Run.
type AgentEvent struct {
	// Type tells consumers which of the remaining fields are meaningful.
	Type AgentEventType
	// Message is the assistant message of a response event.
	Message message.Message
	// Error is the failure carried by an error event.
	Error error

	// When summarizing
	// (Done is also set to true on the final response event of a run.)
	SessionID string
	Progress  string
	Done      bool
}
51
// Service is the public surface of an agent: it runs and cancels generations
// per session, summarizes sessions, and publishes AgentEvents to subscribers.
type Service interface {
	pubsub.Suscriber[AgentEvent]
	// Model returns the model currently configured for this agent.
	Model() fur.Model
	// Run starts a generation for the session in the background and returns a
	// channel that delivers its final event.
	Run(ctx context.Context, sessionID string, content string, attachments ...message.Attachment) (<-chan AgentEvent, error)
	// Cancel stops the active generation and summarization of one session.
	Cancel(sessionID string)
	// CancelAll stops every active request.
	CancelAll()
	// IsSessionBusy reports whether the session has an active generation.
	IsSessionBusy(sessionID string) bool
	// IsBusy reports whether any request is active.
	IsBusy() bool
	// Summarize starts summarizing the session in the background; progress
	// and results are delivered as published events.
	Summarize(ctx context.Context, sessionID string) error
	// UpdateModel re-creates the providers whose configuration has changed.
	UpdateModel() error
}
63
// agent is the Service implementation. It embeds a broker so that events can
// be published to subscribers.
type agent struct {
	*pubsub.Broker[AgentEvent]
	agentCfg config.Agent
	sessions session.Service
	messages message.Service

	// tools are the tools offered to the main provider, already filtered by
	// the agent's AllowedTools.
	tools      []tools.BaseTool
	provider   provider.Provider
	providerID string

	// titleProvider and summarizeProvider are both built from the small
	// model's provider configuration; summarizeProviderID records its ID.
	titleProvider       provider.Provider
	summarizeProvider   provider.Provider
	summarizeProviderID string

	// activeRequests maps a session ID (or sessionID+"-summarize") to the
	// context.CancelFunc of the request currently in flight.
	activeRequests sync.Map
}
80
// agentPromptMap maps an agent ID to the system prompt it uses. Callers fall
// back to prompt.PromptDefault for IDs that are not listed here.
var agentPromptMap = map[string]prompt.PromptID{
	"coder": prompt.PromptCoder,
	"task":  prompt.PromptTask,
}
85
86func NewAgent(
87 agentCfg config.Agent,
88 // These services are needed in the tools
89 permissions permission.Service,
90 sessions session.Service,
91 messages message.Service,
92 history history.Service,
93 lspClients map[string]*lsp.Client,
94) (Service, error) {
95 ctx := context.Background()
96 cfg := config.Get()
97 otherTools := GetMcpTools(ctx, permissions, cfg)
98 if len(lspClients) > 0 {
99 otherTools = append(otherTools, tools.NewDiagnosticsTool(lspClients))
100 }
101
102 cwd := cfg.WorkingDir()
103 allTools := []tools.BaseTool{
104 tools.NewBashTool(permissions, cwd),
105 tools.NewEditTool(lspClients, permissions, history, cwd),
106 tools.NewFetchTool(permissions, cwd),
107 tools.NewGlobTool(cwd),
108 tools.NewGrepTool(cwd),
109 tools.NewLsTool(cwd),
110 tools.NewSourcegraphTool(),
111 tools.NewViewTool(lspClients, cwd),
112 tools.NewWriteTool(lspClients, permissions, history, cwd),
113 }
114
115 if agentCfg.ID == "coder" {
116 taskAgentCfg := config.Get().Agents["task"]
117 if taskAgentCfg.ID == "" {
118 return nil, fmt.Errorf("task agent not found in config")
119 }
120 taskAgent, err := NewAgent(taskAgentCfg, permissions, sessions, messages, history, lspClients)
121 if err != nil {
122 return nil, fmt.Errorf("failed to create task agent: %w", err)
123 }
124
125 allTools = append(
126 allTools,
127 NewAgentTool(
128 taskAgent,
129 sessions,
130 messages,
131 ),
132 )
133 }
134
135 allTools = append(allTools, otherTools...)
136 providerCfg := config.Get().GetProviderForModel(agentCfg.Model)
137 if providerCfg == nil {
138 return nil, fmt.Errorf("provider for agent %s not found in config", agentCfg.Name)
139 }
140 model := config.Get().GetModelByType(agentCfg.Model)
141
142 if model == nil {
143 return nil, fmt.Errorf("model not found for agent %s", agentCfg.Name)
144 }
145
146 promptID := agentPromptMap[agentCfg.ID]
147 if promptID == "" {
148 promptID = prompt.PromptDefault
149 }
150 opts := []provider.ProviderClientOption{
151 provider.WithModel(agentCfg.Model),
152 provider.WithSystemMessage(prompt.GetPrompt(promptID, providerCfg.ID)),
153 }
154 agentProvider, err := provider.NewProvider(*providerCfg, opts...)
155 if err != nil {
156 return nil, err
157 }
158
159 smallModelCfg := cfg.Models[config.SelectedModelTypeSmall]
160 var smallModelProviderCfg *config.ProviderConfig
161 if smallModelCfg.Provider == providerCfg.ID {
162 smallModelProviderCfg = providerCfg
163 } else {
164 smallModelProviderCfg = cfg.GetProviderForModel(config.SelectedModelTypeSmall)
165
166 if smallModelProviderCfg.ID == "" {
167 return nil, fmt.Errorf("provider %s not found in config", smallModelCfg.Provider)
168 }
169 }
170 smallModel := cfg.GetModelByType(config.SelectedModelTypeSmall)
171 if smallModel.ID == "" {
172 return nil, fmt.Errorf("model %s not found in provider %s", smallModelCfg.Model, smallModelProviderCfg.ID)
173 }
174
175 titleOpts := []provider.ProviderClientOption{
176 provider.WithModel(config.SelectedModelTypeSmall),
177 provider.WithSystemMessage(prompt.GetPrompt(prompt.PromptTitle, smallModelProviderCfg.ID)),
178 }
179 titleProvider, err := provider.NewProvider(*smallModelProviderCfg, titleOpts...)
180 if err != nil {
181 return nil, err
182 }
183 summarizeOpts := []provider.ProviderClientOption{
184 provider.WithModel(config.SelectedModelTypeSmall),
185 provider.WithSystemMessage(prompt.GetPrompt(prompt.PromptSummarizer, smallModelProviderCfg.ID)),
186 }
187 summarizeProvider, err := provider.NewProvider(*smallModelProviderCfg, summarizeOpts...)
188 if err != nil {
189 return nil, err
190 }
191
192 agentTools := []tools.BaseTool{}
193 if agentCfg.AllowedTools == nil {
194 agentTools = allTools
195 } else {
196 for _, tool := range allTools {
197 if slices.Contains(agentCfg.AllowedTools, tool.Name()) {
198 agentTools = append(agentTools, tool)
199 }
200 }
201 }
202
203 agent := &agent{
204 Broker: pubsub.NewBroker[AgentEvent](),
205 agentCfg: agentCfg,
206 provider: agentProvider,
207 providerID: string(providerCfg.ID),
208 messages: messages,
209 sessions: sessions,
210 tools: agentTools,
211 titleProvider: titleProvider,
212 summarizeProvider: summarizeProvider,
213 summarizeProviderID: string(smallModelProviderCfg.ID),
214 activeRequests: sync.Map{},
215 }
216
217 return agent, nil
218}
219
220func (a *agent) Model() fur.Model {
221 return *config.Get().GetModelByType(a.agentCfg.Model)
222}
223
224func (a *agent) Cancel(sessionID string) {
225 // Cancel regular requests
226 if cancelFunc, exists := a.activeRequests.LoadAndDelete(sessionID); exists {
227 if cancel, ok := cancelFunc.(context.CancelFunc); ok {
228 slog.Info(fmt.Sprintf("Request cancellation initiated for session: %s", sessionID))
229 cancel()
230 }
231 }
232
233 // Also check for summarize requests
234 if cancelFunc, exists := a.activeRequests.LoadAndDelete(sessionID + "-summarize"); exists {
235 if cancel, ok := cancelFunc.(context.CancelFunc); ok {
236 slog.Info(fmt.Sprintf("Summarize cancellation initiated for session: %s", sessionID))
237 cancel()
238 }
239 }
240}
241
242func (a *agent) IsBusy() bool {
243 busy := false
244 a.activeRequests.Range(func(key, value any) bool {
245 if cancelFunc, ok := value.(context.CancelFunc); ok {
246 if cancelFunc != nil {
247 busy = true
248 return false
249 }
250 }
251 return true
252 })
253 return busy
254}
255
256func (a *agent) IsSessionBusy(sessionID string) bool {
257 _, busy := a.activeRequests.Load(sessionID)
258 return busy
259}
260
261func (a *agent) generateTitle(ctx context.Context, sessionID string, content string) error {
262 if content == "" {
263 return nil
264 }
265 if a.titleProvider == nil {
266 return nil
267 }
268 session, err := a.sessions.Get(ctx, sessionID)
269 if err != nil {
270 return err
271 }
272 parts := []message.ContentPart{message.TextContent{
273 Text: fmt.Sprintf("Generate a concise title for the following content:\n\n%s", content),
274 }}
275
276 // Use streaming approach like summarization
277 response := a.titleProvider.StreamResponse(
278 ctx,
279 []message.Message{
280 {
281 Role: message.User,
282 Parts: parts,
283 },
284 },
285 make([]tools.BaseTool, 0),
286 )
287
288 var finalResponse *provider.ProviderResponse
289 for r := range response {
290 if r.Error != nil {
291 return r.Error
292 }
293 finalResponse = r.Response
294 }
295
296 if finalResponse == nil {
297 return fmt.Errorf("no response received from title provider")
298 }
299
300 title := strings.TrimSpace(strings.ReplaceAll(finalResponse.Content, "\n", " "))
301 if title == "" {
302 return nil
303 }
304
305 session.Title = title
306 _, err = a.sessions.Save(ctx, session)
307 return err
308}
309
310func (a *agent) err(err error) AgentEvent {
311 return AgentEvent{
312 Type: AgentEventTypeError,
313 Error: err,
314 }
315}
316
317func (a *agent) Run(ctx context.Context, sessionID string, content string, attachments ...message.Attachment) (<-chan AgentEvent, error) {
318 if !a.Model().SupportsImages && attachments != nil {
319 attachments = nil
320 }
321 events := make(chan AgentEvent)
322 if a.IsSessionBusy(sessionID) {
323 return nil, ErrSessionBusy
324 }
325
326 genCtx, cancel := context.WithCancel(ctx)
327
328 a.activeRequests.Store(sessionID, cancel)
329 go func() {
330 slog.Debug("Request started", "sessionID", sessionID)
331 defer log.RecoverPanic("agent.Run", func() {
332 events <- a.err(fmt.Errorf("panic while running the agent"))
333 })
334 var attachmentParts []message.ContentPart
335 for _, attachment := range attachments {
336 attachmentParts = append(attachmentParts, message.BinaryContent{Path: attachment.FilePath, MIMEType: attachment.MimeType, Data: attachment.Content})
337 }
338 result := a.processGeneration(genCtx, sessionID, content, attachmentParts)
339 if result.Error != nil && !errors.Is(result.Error, ErrRequestCancelled) && !errors.Is(result.Error, context.Canceled) {
340 slog.Error(result.Error.Error())
341 }
342 slog.Debug("Request completed", "sessionID", sessionID)
343 a.activeRequests.Delete(sessionID)
344 cancel()
345 a.Publish(pubsub.CreatedEvent, result)
346 events <- result
347 close(events)
348 }()
349 return events, nil
350}
351
352func (a *agent) processGeneration(ctx context.Context, sessionID, content string, attachmentParts []message.ContentPart) AgentEvent {
353 cfg := config.Get()
354 // List existing messages; if none, start title generation asynchronously.
355 msgs, err := a.messages.List(ctx, sessionID)
356 if err != nil {
357 return a.err(fmt.Errorf("failed to list messages: %w", err))
358 }
359 if len(msgs) == 0 {
360 go func() {
361 defer log.RecoverPanic("agent.Run", func() {
362 slog.Error("panic while generating title")
363 })
364 titleErr := a.generateTitle(context.Background(), sessionID, content)
365 if titleErr != nil && !errors.Is(titleErr, context.Canceled) && !errors.Is(titleErr, context.DeadlineExceeded) {
366 slog.Error(fmt.Sprintf("failed to generate title: %v", titleErr))
367 }
368 }()
369 }
370 session, err := a.sessions.Get(ctx, sessionID)
371 if err != nil {
372 return a.err(fmt.Errorf("failed to get session: %w", err))
373 }
374 if session.SummaryMessageID != "" {
375 summaryMsgInex := -1
376 for i, msg := range msgs {
377 if msg.ID == session.SummaryMessageID {
378 summaryMsgInex = i
379 break
380 }
381 }
382 if summaryMsgInex != -1 {
383 msgs = msgs[summaryMsgInex:]
384 msgs[0].Role = message.User
385 }
386 }
387
388 userMsg, err := a.createUserMessage(ctx, sessionID, content, attachmentParts)
389 if err != nil {
390 return a.err(fmt.Errorf("failed to create user message: %w", err))
391 }
392 // Append the new user message to the conversation history.
393 msgHistory := append(msgs, userMsg)
394
395 for {
396 // Check for cancellation before each iteration
397 select {
398 case <-ctx.Done():
399 return a.err(ctx.Err())
400 default:
401 // Continue processing
402 }
403 agentMessage, toolResults, err := a.streamAndHandleEvents(ctx, sessionID, msgHistory)
404 if err != nil {
405 if errors.Is(err, context.Canceled) {
406 agentMessage.AddFinish(message.FinishReasonCanceled)
407 a.messages.Update(context.Background(), agentMessage)
408 return a.err(ErrRequestCancelled)
409 }
410 return a.err(fmt.Errorf("failed to process events: %w", err))
411 }
412 if cfg.Options.Debug {
413 slog.Info("Result", "message", agentMessage.FinishReason(), "toolResults", toolResults)
414 }
415 if (agentMessage.FinishReason() == message.FinishReasonToolUse) && toolResults != nil {
416 // We are not done, we need to respond with the tool response
417 msgHistory = append(msgHistory, agentMessage, *toolResults)
418 continue
419 }
420 return AgentEvent{
421 Type: AgentEventTypeResponse,
422 Message: agentMessage,
423 Done: true,
424 }
425 }
426}
427
428func (a *agent) createUserMessage(ctx context.Context, sessionID, content string, attachmentParts []message.ContentPart) (message.Message, error) {
429 parts := []message.ContentPart{message.TextContent{Text: content}}
430 parts = append(parts, attachmentParts...)
431 return a.messages.Create(ctx, sessionID, message.CreateMessageParams{
432 Role: message.User,
433 Parts: parts,
434 })
435}
436
437func (a *agent) streamAndHandleEvents(ctx context.Context, sessionID string, msgHistory []message.Message) (message.Message, *message.Message, error) {
438 ctx = context.WithValue(ctx, tools.SessionIDContextKey, sessionID)
439 eventChan := a.provider.StreamResponse(ctx, msgHistory, a.tools)
440
441 assistantMsg, err := a.messages.Create(ctx, sessionID, message.CreateMessageParams{
442 Role: message.Assistant,
443 Parts: []message.ContentPart{},
444 Model: a.Model().ID,
445 Provider: a.providerID,
446 })
447 if err != nil {
448 return assistantMsg, nil, fmt.Errorf("failed to create assistant message: %w", err)
449 }
450
451 // Add the session and message ID into the context if needed by tools.
452 ctx = context.WithValue(ctx, tools.MessageIDContextKey, assistantMsg.ID)
453
454 // Process each event in the stream.
455 for event := range eventChan {
456 if processErr := a.processEvent(ctx, sessionID, &assistantMsg, event); processErr != nil {
457 a.finishMessage(ctx, &assistantMsg, message.FinishReasonCanceled)
458 return assistantMsg, nil, processErr
459 }
460 if ctx.Err() != nil {
461 a.finishMessage(context.Background(), &assistantMsg, message.FinishReasonCanceled)
462 return assistantMsg, nil, ctx.Err()
463 }
464 }
465
466 toolResults := make([]message.ToolResult, len(assistantMsg.ToolCalls()))
467 toolCalls := assistantMsg.ToolCalls()
468 for i, toolCall := range toolCalls {
469 select {
470 case <-ctx.Done():
471 a.finishMessage(context.Background(), &assistantMsg, message.FinishReasonCanceled)
472 // Make all future tool calls cancelled
473 for j := i; j < len(toolCalls); j++ {
474 toolResults[j] = message.ToolResult{
475 ToolCallID: toolCalls[j].ID,
476 Content: "Tool execution canceled by user",
477 IsError: true,
478 }
479 }
480 goto out
481 default:
482 // Continue processing
483 var tool tools.BaseTool
484 for _, availableTool := range a.tools {
485 if availableTool.Info().Name == toolCall.Name {
486 tool = availableTool
487 break
488 }
489 }
490
491 // Tool not found
492 if tool == nil {
493 toolResults[i] = message.ToolResult{
494 ToolCallID: toolCall.ID,
495 Content: fmt.Sprintf("Tool not found: %s", toolCall.Name),
496 IsError: true,
497 }
498 continue
499 }
500 toolResult, toolErr := tool.Run(ctx, tools.ToolCall{
501 ID: toolCall.ID,
502 Name: toolCall.Name,
503 Input: toolCall.Input,
504 })
505 if toolErr != nil {
506 if errors.Is(toolErr, permission.ErrorPermissionDenied) {
507 toolResults[i] = message.ToolResult{
508 ToolCallID: toolCall.ID,
509 Content: "Permission denied",
510 IsError: true,
511 }
512 for j := i + 1; j < len(toolCalls); j++ {
513 toolResults[j] = message.ToolResult{
514 ToolCallID: toolCalls[j].ID,
515 Content: "Tool execution canceled by user",
516 IsError: true,
517 }
518 }
519 a.finishMessage(ctx, &assistantMsg, message.FinishReasonPermissionDenied)
520 break
521 }
522 }
523 toolResults[i] = message.ToolResult{
524 ToolCallID: toolCall.ID,
525 Content: toolResult.Content,
526 Metadata: toolResult.Metadata,
527 IsError: toolResult.IsError,
528 }
529 }
530 }
531out:
532 if len(toolResults) == 0 {
533 return assistantMsg, nil, nil
534 }
535 parts := make([]message.ContentPart, 0)
536 for _, tr := range toolResults {
537 parts = append(parts, tr)
538 }
539 msg, err := a.messages.Create(context.Background(), assistantMsg.SessionID, message.CreateMessageParams{
540 Role: message.Tool,
541 Parts: parts,
542 Provider: a.providerID,
543 })
544 if err != nil {
545 return assistantMsg, nil, fmt.Errorf("failed to create cancelled tool message: %w", err)
546 }
547
548 return assistantMsg, &msg, err
549}
550
551func (a *agent) finishMessage(ctx context.Context, msg *message.Message, finishReson message.FinishReason) {
552 msg.AddFinish(finishReson)
553 _ = a.messages.Update(ctx, *msg)
554}
555
556func (a *agent) processEvent(ctx context.Context, sessionID string, assistantMsg *message.Message, event provider.ProviderEvent) error {
557 select {
558 case <-ctx.Done():
559 return ctx.Err()
560 default:
561 // Continue processing.
562 }
563
564 switch event.Type {
565 case provider.EventThinkingDelta:
566 assistantMsg.AppendReasoningContent(event.Content)
567 return a.messages.Update(ctx, *assistantMsg)
568 case provider.EventContentDelta:
569 assistantMsg.AppendContent(event.Content)
570 return a.messages.Update(ctx, *assistantMsg)
571 case provider.EventToolUseStart:
572 slog.Info("Tool call started", "toolCall", event.ToolCall)
573 assistantMsg.AddToolCall(*event.ToolCall)
574 return a.messages.Update(ctx, *assistantMsg)
575 case provider.EventToolUseDelta:
576 assistantMsg.AppendToolCallInput(event.ToolCall.ID, event.ToolCall.Input)
577 return a.messages.Update(ctx, *assistantMsg)
578 case provider.EventToolUseStop:
579 slog.Info("Finished tool call", "toolCall", event.ToolCall)
580 assistantMsg.FinishToolCall(event.ToolCall.ID)
581 return a.messages.Update(ctx, *assistantMsg)
582 case provider.EventError:
583 if errors.Is(event.Error, context.Canceled) {
584 slog.Info(fmt.Sprintf("Event processing canceled for session: %s", sessionID))
585 return context.Canceled
586 }
587 slog.Error(event.Error.Error())
588 return event.Error
589 case provider.EventComplete:
590 assistantMsg.SetToolCalls(event.Response.ToolCalls)
591 assistantMsg.AddFinish(event.Response.FinishReason)
592 if err := a.messages.Update(ctx, *assistantMsg); err != nil {
593 return fmt.Errorf("failed to update message: %w", err)
594 }
595 return a.TrackUsage(ctx, sessionID, a.Model(), event.Response.Usage)
596 }
597
598 return nil
599}
600
601func (a *agent) TrackUsage(ctx context.Context, sessionID string, model fur.Model, usage provider.TokenUsage) error {
602 sess, err := a.sessions.Get(ctx, sessionID)
603 if err != nil {
604 return fmt.Errorf("failed to get session: %w", err)
605 }
606
607 cost := model.CostPer1MInCached/1e6*float64(usage.CacheCreationTokens) +
608 model.CostPer1MOutCached/1e6*float64(usage.CacheReadTokens) +
609 model.CostPer1MIn/1e6*float64(usage.InputTokens) +
610 model.CostPer1MOut/1e6*float64(usage.OutputTokens)
611
612 sess.Cost += cost
613 sess.CompletionTokens = usage.OutputTokens + usage.CacheReadTokens
614 sess.PromptTokens = usage.InputTokens + usage.CacheCreationTokens
615
616 _, err = a.sessions.Save(ctx, sess)
617 if err != nil {
618 return fmt.Errorf("failed to save session: %w", err)
619 }
620 return nil
621}
622
623func (a *agent) Summarize(ctx context.Context, sessionID string) error {
624 if a.summarizeProvider == nil {
625 return fmt.Errorf("summarize provider not available")
626 }
627
628 // Check if session is busy
629 if a.IsSessionBusy(sessionID) {
630 return ErrSessionBusy
631 }
632
633 // Create a new context with cancellation
634 summarizeCtx, cancel := context.WithCancel(ctx)
635
636 // Store the cancel function in activeRequests to allow cancellation
637 a.activeRequests.Store(sessionID+"-summarize", cancel)
638
639 go func() {
640 defer a.activeRequests.Delete(sessionID + "-summarize")
641 defer cancel()
642 event := AgentEvent{
643 Type: AgentEventTypeSummarize,
644 Progress: "Starting summarization...",
645 }
646
647 a.Publish(pubsub.CreatedEvent, event)
648 // Get all messages from the session
649 msgs, err := a.messages.List(summarizeCtx, sessionID)
650 if err != nil {
651 event = AgentEvent{
652 Type: AgentEventTypeError,
653 Error: fmt.Errorf("failed to list messages: %w", err),
654 Done: true,
655 }
656 a.Publish(pubsub.CreatedEvent, event)
657 return
658 }
659 summarizeCtx = context.WithValue(summarizeCtx, tools.SessionIDContextKey, sessionID)
660
661 if len(msgs) == 0 {
662 event = AgentEvent{
663 Type: AgentEventTypeError,
664 Error: fmt.Errorf("no messages to summarize"),
665 Done: true,
666 }
667 a.Publish(pubsub.CreatedEvent, event)
668 return
669 }
670
671 event = AgentEvent{
672 Type: AgentEventTypeSummarize,
673 Progress: "Analyzing conversation...",
674 }
675 a.Publish(pubsub.CreatedEvent, event)
676
677 // Add a system message to guide the summarization
678 summarizePrompt := "Provide a detailed but concise summary of our conversation above. Focus on information that would be helpful for continuing the conversation, including what we did, what we're doing, which files we're working on, and what we're going to do next."
679
680 // Create a new message with the summarize prompt
681 promptMsg := message.Message{
682 Role: message.User,
683 Parts: []message.ContentPart{message.TextContent{Text: summarizePrompt}},
684 }
685
686 // Append the prompt to the messages
687 msgsWithPrompt := append(msgs, promptMsg)
688
689 event = AgentEvent{
690 Type: AgentEventTypeSummarize,
691 Progress: "Generating summary...",
692 }
693
694 a.Publish(pubsub.CreatedEvent, event)
695
696 // Send the messages to the summarize provider
697 response := a.summarizeProvider.StreamResponse(
698 summarizeCtx,
699 msgsWithPrompt,
700 make([]tools.BaseTool, 0),
701 )
702 var finalResponse *provider.ProviderResponse
703 for r := range response {
704 if r.Error != nil {
705 event = AgentEvent{
706 Type: AgentEventTypeError,
707 Error: fmt.Errorf("failed to summarize: %w", err),
708 Done: true,
709 }
710 a.Publish(pubsub.CreatedEvent, event)
711 return
712 }
713 finalResponse = r.Response
714 }
715
716 summary := strings.TrimSpace(finalResponse.Content)
717 if summary == "" {
718 event = AgentEvent{
719 Type: AgentEventTypeError,
720 Error: fmt.Errorf("empty summary returned"),
721 Done: true,
722 }
723 a.Publish(pubsub.CreatedEvent, event)
724 return
725 }
726 event = AgentEvent{
727 Type: AgentEventTypeSummarize,
728 Progress: "Creating new session...",
729 }
730
731 a.Publish(pubsub.CreatedEvent, event)
732 oldSession, err := a.sessions.Get(summarizeCtx, sessionID)
733 if err != nil {
734 event = AgentEvent{
735 Type: AgentEventTypeError,
736 Error: fmt.Errorf("failed to get session: %w", err),
737 Done: true,
738 }
739
740 a.Publish(pubsub.CreatedEvent, event)
741 return
742 }
743 // Create a message in the new session with the summary
744 msg, err := a.messages.Create(summarizeCtx, oldSession.ID, message.CreateMessageParams{
745 Role: message.Assistant,
746 Parts: []message.ContentPart{
747 message.TextContent{Text: summary},
748 message.Finish{
749 Reason: message.FinishReasonEndTurn,
750 Time: time.Now().Unix(),
751 },
752 },
753 Model: a.summarizeProvider.Model().ID,
754 Provider: a.summarizeProviderID,
755 })
756 if err != nil {
757 event = AgentEvent{
758 Type: AgentEventTypeError,
759 Error: fmt.Errorf("failed to create summary message: %w", err),
760 Done: true,
761 }
762
763 a.Publish(pubsub.CreatedEvent, event)
764 return
765 }
766 oldSession.SummaryMessageID = msg.ID
767 oldSession.CompletionTokens = finalResponse.Usage.OutputTokens
768 oldSession.PromptTokens = 0
769 model := a.summarizeProvider.Model()
770 usage := finalResponse.Usage
771 cost := model.CostPer1MInCached/1e6*float64(usage.CacheCreationTokens) +
772 model.CostPer1MOutCached/1e6*float64(usage.CacheReadTokens) +
773 model.CostPer1MIn/1e6*float64(usage.InputTokens) +
774 model.CostPer1MOut/1e6*float64(usage.OutputTokens)
775 oldSession.Cost += cost
776 _, err = a.sessions.Save(summarizeCtx, oldSession)
777 if err != nil {
778 event = AgentEvent{
779 Type: AgentEventTypeError,
780 Error: fmt.Errorf("failed to save session: %w", err),
781 Done: true,
782 }
783 a.Publish(pubsub.CreatedEvent, event)
784 }
785
786 event = AgentEvent{
787 Type: AgentEventTypeSummarize,
788 SessionID: oldSession.ID,
789 Progress: "Summary complete",
790 Done: true,
791 }
792 a.Publish(pubsub.CreatedEvent, event)
793 // Send final success event with the new session ID
794 }()
795
796 return nil
797}
798
799func (a *agent) CancelAll() {
800 a.activeRequests.Range(func(key, value any) bool {
801 a.Cancel(key.(string)) // key is sessionID
802 return true
803 })
804}
805
806func (a *agent) UpdateModel() error {
807 cfg := config.Get()
808
809 // Get current provider configuration
810 currentProviderCfg := cfg.GetProviderForModel(a.agentCfg.Model)
811 if currentProviderCfg.ID == "" {
812 return fmt.Errorf("provider for agent %s not found in config", a.agentCfg.Name)
813 }
814
815 // Check if provider has changed
816 if string(currentProviderCfg.ID) != a.providerID {
817 // Provider changed, need to recreate the main provider
818 model := cfg.GetModelByType(a.agentCfg.Model)
819 if model.ID == "" {
820 return fmt.Errorf("model not found for agent %s", a.agentCfg.Name)
821 }
822
823 promptID := agentPromptMap[a.agentCfg.ID]
824 if promptID == "" {
825 promptID = prompt.PromptDefault
826 }
827
828 opts := []provider.ProviderClientOption{
829 provider.WithModel(a.agentCfg.Model),
830 provider.WithSystemMessage(prompt.GetPrompt(promptID, currentProviderCfg.ID)),
831 }
832
833 newProvider, err := provider.NewProvider(*currentProviderCfg, opts...)
834 if err != nil {
835 return fmt.Errorf("failed to create new provider: %w", err)
836 }
837
838 // Update the provider and provider ID
839 a.provider = newProvider
840 a.providerID = string(currentProviderCfg.ID)
841 }
842
843 // Check if small model provider has changed (affects title and summarize providers)
844 smallModelCfg := cfg.Models[config.SelectedModelTypeSmall]
845 var smallModelProviderCfg config.ProviderConfig
846
847 for _, p := range cfg.Providers {
848 if p.ID == smallModelCfg.Provider {
849 smallModelProviderCfg = p
850 break
851 }
852 }
853
854 if smallModelProviderCfg.ID == "" {
855 return fmt.Errorf("provider %s not found in config", smallModelCfg.Provider)
856 }
857
858 // Check if summarize provider has changed
859 if string(smallModelProviderCfg.ID) != a.summarizeProviderID {
860 smallModel := cfg.GetModelByType(config.SelectedModelTypeSmall)
861 if smallModel == nil {
862 return fmt.Errorf("model %s not found in provider %s", smallModelCfg.Model, smallModelProviderCfg.ID)
863 }
864
865 // Recreate title provider
866 titleOpts := []provider.ProviderClientOption{
867 provider.WithModel(config.SelectedModelTypeSmall),
868 provider.WithSystemMessage(prompt.GetPrompt(prompt.PromptTitle, smallModelProviderCfg.ID)),
869 // We want the title to be short, so we limit the max tokens
870 provider.WithMaxTokens(40),
871 }
872 newTitleProvider, err := provider.NewProvider(smallModelProviderCfg, titleOpts...)
873 if err != nil {
874 return fmt.Errorf("failed to create new title provider: %w", err)
875 }
876
877 // Recreate summarize provider
878 summarizeOpts := []provider.ProviderClientOption{
879 provider.WithModel(config.SelectedModelTypeSmall),
880 provider.WithSystemMessage(prompt.GetPrompt(prompt.PromptSummarizer, smallModelProviderCfg.ID)),
881 }
882 newSummarizeProvider, err := provider.NewProvider(smallModelProviderCfg, summarizeOpts...)
883 if err != nil {
884 return fmt.Errorf("failed to create new summarize provider: %w", err)
885 }
886
887 // Update the providers and provider ID
888 a.titleProvider = newTitleProvider
889 a.summarizeProvider = newSummarizeProvider
890 a.summarizeProviderID = string(smallModelProviderCfg.ID)
891 }
892
893 return nil
894}