1package agent
2
3import (
4 "context"
5 "errors"
6 "fmt"
7 "log/slog"
8 "slices"
9 "strings"
10 "sync"
11 "sync/atomic"
12 "time"
13
14 "github.com/charmbracelet/crush/internal/config"
15 fur "github.com/charmbracelet/crush/internal/fur/provider"
16 "github.com/charmbracelet/crush/internal/history"
17 "github.com/charmbracelet/crush/internal/llm/prompt"
18 "github.com/charmbracelet/crush/internal/llm/provider"
19 "github.com/charmbracelet/crush/internal/llm/tools"
20 "github.com/charmbracelet/crush/internal/log"
21 "github.com/charmbracelet/crush/internal/lsp"
22 "github.com/charmbracelet/crush/internal/message"
23 "github.com/charmbracelet/crush/internal/permission"
24 "github.com/charmbracelet/crush/internal/pubsub"
25 "github.com/charmbracelet/crush/internal/session"
26)
27
// Common errors returned by the agent service.

// ErrRequestCancelled is reported when an in-flight request is canceled while
// the model response is being streamed.
var ErrRequestCancelled = errors.New("request canceled by user")

// ErrSessionBusy is returned when work is started for a session that already
// has a request in flight.
var ErrSessionBusy = errors.New("session is currently processing another request")
33
// AgentEventType identifies what an AgentEvent reports.
type AgentEventType string

// AgentEventTypeError marks an event whose Error field describes a failure.
const AgentEventTypeError = AgentEventType("error")

// AgentEventTypeResponse marks an event carrying the assistant's final message.
const AgentEventTypeResponse = AgentEventType("response")

// AgentEventTypeSummarize marks a progress or completion event of a session
// summarization.
const AgentEventTypeSummarize = AgentEventType("summarize")
41
// AgentEvent is delivered on the channel returned by Run and published through
// the agent's pubsub broker. Which fields are meaningful depends on Type.
type AgentEvent struct {
	// Type tells what kind of event this is.
	Type AgentEventType
	// Message is the assistant message of a completed AgentEventTypeResponse.
	Message message.Message
	// Error is set for AgentEventTypeError events.
	Error error

	// When summarizing:
	// SessionID is the summarized session (set on the final summarize event),
	// Progress is a human-readable status line, and Done is true on a
	// completed response and on the terminal events (success or error) of a
	// summarization.
	SessionID string
	Progress  string
	Done      bool
}
52
53type Service interface {
54 pubsub.Suscriber[AgentEvent]
55 Model() fur.Model
56 Run(ctx context.Context, sessionID string, content string, attachments ...message.Attachment) (<-chan AgentEvent, error)
57 Cancel(sessionID string)
58 CancelAll()
59 IsSessionBusy(sessionID string) bool
60 IsBusy() bool
61 Summarize(ctx context.Context, sessionID string) error
62 UpdateModel() error
63}
64
// agent is the Service implementation returned by NewAgent. The embedded
// broker publishes AgentEvents to subscribers.
type agent struct {
	*pubsub.Broker[AgentEvent]
	agentCfg config.Agent
	sessions session.Service
	messages message.Service

	// toolsDone is set by the goroutine started in NewAgent after tools has
	// been populated; tools must not be read before it reports true.
	toolsDone atomic.Bool
	tools     []tools.BaseTool

	// provider serves the main conversation; providerID is recorded on the
	// messages it produces and compared by UpdateModel to detect changes.
	provider   provider.Provider
	providerID string

	// titleProvider and summarizeProvider are built on the small model;
	// summarizeProviderID is the provider ID they were built for.
	titleProvider       provider.Provider
	summarizeProvider   provider.Provider
	summarizeProviderID string

	// activeRequests maps a session ID (regular requests) or
	// sessionID+"-summarize" (summarizations) to the context.CancelFunc of
	// the work in flight.
	activeRequests sync.Map
}
83
84var agentPromptMap = map[string]prompt.PromptID{
85 "coder": prompt.PromptCoder,
86 "task": prompt.PromptTask,
87}
88
89func NewAgent(
90 agentCfg config.Agent,
91 // These services are needed in the tools
92 permissions permission.Service,
93 sessions session.Service,
94 messages message.Service,
95 history history.Service,
96 lspClients map[string]*lsp.Client,
97) (Service, error) {
98 ctx := context.Background()
99 cfg := config.Get()
100
101 providerCfg := config.Get().GetProviderForModel(agentCfg.Model)
102 if providerCfg == nil {
103 return nil, fmt.Errorf("provider for agent %s not found in config", agentCfg.Name)
104 }
105 model := config.Get().GetModelByType(agentCfg.Model)
106
107 if model == nil {
108 return nil, fmt.Errorf("model not found for agent %s", agentCfg.Name)
109 }
110
111 promptID := agentPromptMap[agentCfg.ID]
112 if promptID == "" {
113 promptID = prompt.PromptDefault
114 }
115 opts := []provider.ProviderClientOption{
116 provider.WithModel(agentCfg.Model),
117 provider.WithSystemMessage(prompt.GetPrompt(promptID, providerCfg.ID, config.Get().Options.ContextPaths...)),
118 }
119 agentProvider, err := provider.NewProvider(*providerCfg, opts...)
120 if err != nil {
121 return nil, err
122 }
123
124 smallModelCfg := cfg.Models[config.SelectedModelTypeSmall]
125 var smallModelProviderCfg *config.ProviderConfig
126 if smallModelCfg.Provider == providerCfg.ID {
127 smallModelProviderCfg = providerCfg
128 } else {
129 smallModelProviderCfg = cfg.GetProviderForModel(config.SelectedModelTypeSmall)
130
131 if smallModelProviderCfg.ID == "" {
132 return nil, fmt.Errorf("provider %s not found in config", smallModelCfg.Provider)
133 }
134 }
135 smallModel := cfg.GetModelByType(config.SelectedModelTypeSmall)
136 if smallModel.ID == "" {
137 return nil, fmt.Errorf("model %s not found in provider %s", smallModelCfg.Model, smallModelProviderCfg.ID)
138 }
139
140 titleOpts := []provider.ProviderClientOption{
141 provider.WithModel(config.SelectedModelTypeSmall),
142 provider.WithSystemMessage(prompt.GetPrompt(prompt.PromptTitle, smallModelProviderCfg.ID)),
143 }
144 titleProvider, err := provider.NewProvider(*smallModelProviderCfg, titleOpts...)
145 if err != nil {
146 return nil, err
147 }
148 summarizeOpts := []provider.ProviderClientOption{
149 provider.WithModel(config.SelectedModelTypeSmall),
150 provider.WithSystemMessage(prompt.GetPrompt(prompt.PromptSummarizer, smallModelProviderCfg.ID)),
151 }
152 summarizeProvider, err := provider.NewProvider(*smallModelProviderCfg, summarizeOpts...)
153 if err != nil {
154 return nil, err
155 }
156
157 var agentTool tools.BaseTool
158 if agentCfg.ID == "coder" {
159 taskAgentCfg := config.Get().Agents["task"]
160 if taskAgentCfg.ID == "" {
161 return nil, fmt.Errorf("task agent not found in config")
162 }
163 taskAgent, err := NewAgent(taskAgentCfg, permissions, sessions, messages, history, lspClients)
164 if err != nil {
165 return nil, fmt.Errorf("failed to create task agent: %w", err)
166 }
167
168 agentTool = NewAgentTool(
169 taskAgent,
170 sessions,
171 messages,
172 )
173 }
174
175 agent := &agent{
176 Broker: pubsub.NewBroker[AgentEvent](),
177 agentCfg: agentCfg,
178 provider: agentProvider,
179 providerID: string(providerCfg.ID),
180 messages: messages,
181 sessions: sessions,
182 titleProvider: titleProvider,
183 summarizeProvider: summarizeProvider,
184 summarizeProviderID: string(smallModelProviderCfg.ID),
185 activeRequests: sync.Map{},
186 }
187
188 go func() {
189 slog.Info("Initializing agent tools", "agent", agentCfg.ID)
190
191 cwd := cfg.WorkingDir()
192 allTools := []tools.BaseTool{
193 tools.NewBashTool(permissions, cwd),
194 tools.NewDownloadTool(permissions, cwd),
195 tools.NewEditTool(lspClients, permissions, history, cwd),
196 tools.NewFetchTool(permissions, cwd),
197 tools.NewGlobTool(cwd),
198 tools.NewGrepTool(cwd),
199 tools.NewLsTool(cwd),
200 tools.NewSourcegraphTool(),
201 tools.NewViewTool(lspClients, cwd),
202 tools.NewWriteTool(lspClients, permissions, history, cwd),
203 }
204
205 mcpTools := GetMCPTools(ctx, permissions, cfg)
206 if len(lspClients) > 0 {
207 mcpTools = append(mcpTools, tools.NewDiagnosticsTool(lspClients))
208 }
209 allTools = append(allTools, mcpTools...)
210
211 if agentTool != nil {
212 allTools = append(allTools, agentTool)
213 }
214
215 agentTools := []tools.BaseTool{}
216 if agentCfg.AllowedTools == nil {
217 agentTools = allTools
218 } else {
219 for _, tool := range allTools {
220 if slices.Contains(agentCfg.AllowedTools, tool.Name()) {
221 agentTools = append(agentTools, tool)
222 }
223 }
224 }
225
226 slog.Info("Initialized agent tools", "agent", agentCfg.ID)
227 agent.tools = agentTools
228 agent.toolsDone.Store(true)
229 }()
230
231 return agent, nil
232}
233
234func (a *agent) Model() fur.Model {
235 return *config.Get().GetModelByType(a.agentCfg.Model)
236}
237
238func (a *agent) Cancel(sessionID string) {
239 // Cancel regular requests
240 if cancelFunc, exists := a.activeRequests.LoadAndDelete(sessionID); exists {
241 if cancel, ok := cancelFunc.(context.CancelFunc); ok {
242 slog.Info(fmt.Sprintf("Request cancellation initiated for session: %s", sessionID))
243 cancel()
244 }
245 }
246
247 // Also check for summarize requests
248 if cancelFunc, exists := a.activeRequests.LoadAndDelete(sessionID + "-summarize"); exists {
249 if cancel, ok := cancelFunc.(context.CancelFunc); ok {
250 slog.Info(fmt.Sprintf("Summarize cancellation initiated for session: %s", sessionID))
251 cancel()
252 }
253 }
254}
255
256func (a *agent) IsBusy() bool {
257 busy := false
258 a.activeRequests.Range(func(key, value any) bool {
259 if cancelFunc, ok := value.(context.CancelFunc); ok {
260 if cancelFunc != nil {
261 busy = true
262 return false
263 }
264 }
265 return true
266 })
267 return busy
268}
269
270func (a *agent) IsSessionBusy(sessionID string) bool {
271 _, busy := a.activeRequests.Load(sessionID)
272 return busy
273}
274
275func (a *agent) generateTitle(ctx context.Context, sessionID string, content string) error {
276 if content == "" {
277 return nil
278 }
279 if a.titleProvider == nil {
280 return nil
281 }
282 session, err := a.sessions.Get(ctx, sessionID)
283 if err != nil {
284 return err
285 }
286 parts := []message.ContentPart{message.TextContent{
287 Text: fmt.Sprintf("Generate a concise title for the following content:\n\n%s", content),
288 }}
289
290 // Use streaming approach like summarization
291 response := a.titleProvider.StreamResponse(
292 ctx,
293 []message.Message{
294 {
295 Role: message.User,
296 Parts: parts,
297 },
298 },
299 make([]tools.BaseTool, 0),
300 )
301
302 var finalResponse *provider.ProviderResponse
303 for r := range response {
304 if r.Error != nil {
305 return r.Error
306 }
307 finalResponse = r.Response
308 }
309
310 if finalResponse == nil {
311 return fmt.Errorf("no response received from title provider")
312 }
313
314 title := strings.TrimSpace(strings.ReplaceAll(finalResponse.Content, "\n", " "))
315 if title == "" {
316 return nil
317 }
318
319 session.Title = title
320 _, err = a.sessions.Save(ctx, session)
321 return err
322}
323
324func (a *agent) err(err error) AgentEvent {
325 return AgentEvent{
326 Type: AgentEventTypeError,
327 Error: err,
328 }
329}
330
331func (a *agent) Run(ctx context.Context, sessionID string, content string, attachments ...message.Attachment) (<-chan AgentEvent, error) {
332 if !a.Model().SupportsImages && attachments != nil {
333 attachments = nil
334 }
335 events := make(chan AgentEvent)
336 if a.IsSessionBusy(sessionID) {
337 return nil, ErrSessionBusy
338 }
339
340 genCtx, cancel := context.WithCancel(ctx)
341
342 a.activeRequests.Store(sessionID, cancel)
343 go func() {
344 slog.Debug("Request started", "sessionID", sessionID)
345 defer log.RecoverPanic("agent.Run", func() {
346 events <- a.err(fmt.Errorf("panic while running the agent"))
347 })
348 var attachmentParts []message.ContentPart
349 for _, attachment := range attachments {
350 attachmentParts = append(attachmentParts, message.BinaryContent{Path: attachment.FilePath, MIMEType: attachment.MimeType, Data: attachment.Content})
351 }
352 result := a.processGeneration(genCtx, sessionID, content, attachmentParts)
353 if result.Error != nil && !errors.Is(result.Error, ErrRequestCancelled) && !errors.Is(result.Error, context.Canceled) {
354 slog.Error(result.Error.Error())
355 }
356 slog.Debug("Request completed", "sessionID", sessionID)
357 a.activeRequests.Delete(sessionID)
358 cancel()
359 a.Publish(pubsub.CreatedEvent, result)
360 events <- result
361 close(events)
362 }()
363 return events, nil
364}
365
366func (a *agent) processGeneration(ctx context.Context, sessionID, content string, attachmentParts []message.ContentPart) AgentEvent {
367 cfg := config.Get()
368 // List existing messages; if none, start title generation asynchronously.
369 msgs, err := a.messages.List(ctx, sessionID)
370 if err != nil {
371 return a.err(fmt.Errorf("failed to list messages: %w", err))
372 }
373 if len(msgs) == 0 {
374 go func() {
375 defer log.RecoverPanic("agent.Run", func() {
376 slog.Error("panic while generating title")
377 })
378 titleErr := a.generateTitle(context.Background(), sessionID, content)
379 if titleErr != nil && !errors.Is(titleErr, context.Canceled) && !errors.Is(titleErr, context.DeadlineExceeded) {
380 slog.Error(fmt.Sprintf("failed to generate title: %v", titleErr))
381 }
382 }()
383 }
384 session, err := a.sessions.Get(ctx, sessionID)
385 if err != nil {
386 return a.err(fmt.Errorf("failed to get session: %w", err))
387 }
388 if session.SummaryMessageID != "" {
389 summaryMsgInex := -1
390 for i, msg := range msgs {
391 if msg.ID == session.SummaryMessageID {
392 summaryMsgInex = i
393 break
394 }
395 }
396 if summaryMsgInex != -1 {
397 msgs = msgs[summaryMsgInex:]
398 msgs[0].Role = message.User
399 }
400 }
401
402 userMsg, err := a.createUserMessage(ctx, sessionID, content, attachmentParts)
403 if err != nil {
404 return a.err(fmt.Errorf("failed to create user message: %w", err))
405 }
406 // Append the new user message to the conversation history.
407 msgHistory := append(msgs, userMsg)
408
409 for {
410 // Check for cancellation before each iteration
411 select {
412 case <-ctx.Done():
413 return a.err(ctx.Err())
414 default:
415 // Continue processing
416 }
417 agentMessage, toolResults, err := a.streamAndHandleEvents(ctx, sessionID, msgHistory)
418 if err != nil {
419 if errors.Is(err, context.Canceled) {
420 agentMessage.AddFinish(message.FinishReasonCanceled, "Request cancelled", "")
421 a.messages.Update(context.Background(), agentMessage)
422 return a.err(ErrRequestCancelled)
423 }
424 return a.err(fmt.Errorf("failed to process events: %w", err))
425 }
426 if cfg.Options.Debug {
427 slog.Info("Result", "message", agentMessage.FinishReason(), "toolResults", toolResults)
428 }
429 if (agentMessage.FinishReason() == message.FinishReasonToolUse) && toolResults != nil {
430 // We are not done, we need to respond with the tool response
431 msgHistory = append(msgHistory, agentMessage, *toolResults)
432 continue
433 }
434 return AgentEvent{
435 Type: AgentEventTypeResponse,
436 Message: agentMessage,
437 Done: true,
438 }
439 }
440}
441
442func (a *agent) createUserMessage(ctx context.Context, sessionID, content string, attachmentParts []message.ContentPart) (message.Message, error) {
443 parts := []message.ContentPart{message.TextContent{Text: content}}
444 parts = append(parts, attachmentParts...)
445 return a.messages.Create(ctx, sessionID, message.CreateMessageParams{
446 Role: message.User,
447 Parts: parts,
448 })
449}
450
451func (a *agent) streamAndHandleEvents(ctx context.Context, sessionID string, msgHistory []message.Message) (message.Message, *message.Message, error) {
452 ctx = context.WithValue(ctx, tools.SessionIDContextKey, sessionID)
453 if !a.toolsDone.Load() {
454 return message.Message{}, nil, fmt.Errorf("tools not initialized yet")
455 }
456 eventChan := a.provider.StreamResponse(ctx, msgHistory, a.tools)
457
458 assistantMsg, err := a.messages.Create(ctx, sessionID, message.CreateMessageParams{
459 Role: message.Assistant,
460 Parts: []message.ContentPart{},
461 Model: a.Model().ID,
462 Provider: a.providerID,
463 })
464 if err != nil {
465 return assistantMsg, nil, fmt.Errorf("failed to create assistant message: %w", err)
466 }
467
468 // Add the session and message ID into the context if needed by tools.
469 ctx = context.WithValue(ctx, tools.MessageIDContextKey, assistantMsg.ID)
470
471 // Process each event in the stream.
472 for event := range eventChan {
473 if processErr := a.processEvent(ctx, sessionID, &assistantMsg, event); processErr != nil {
474 if errors.Is(processErr, context.Canceled) {
475 a.finishMessage(context.Background(), &assistantMsg, message.FinishReasonCanceled, "Request cancelled", "")
476 } else {
477 a.finishMessage(ctx, &assistantMsg, message.FinishReasonError, "API Error", processErr.Error())
478 }
479 return assistantMsg, nil, processErr
480 }
481 if ctx.Err() != nil {
482 a.finishMessage(context.Background(), &assistantMsg, message.FinishReasonCanceled, "Request cancelled", "")
483 return assistantMsg, nil, ctx.Err()
484 }
485 }
486
487 toolResults := make([]message.ToolResult, len(assistantMsg.ToolCalls()))
488 toolCalls := assistantMsg.ToolCalls()
489 for i, toolCall := range toolCalls {
490 select {
491 case <-ctx.Done():
492 a.finishMessage(context.Background(), &assistantMsg, message.FinishReasonCanceled, "Request cancelled", "")
493 // Make all future tool calls cancelled
494 for j := i; j < len(toolCalls); j++ {
495 toolResults[j] = message.ToolResult{
496 ToolCallID: toolCalls[j].ID,
497 Content: "Tool execution canceled by user",
498 IsError: true,
499 }
500 }
501 goto out
502 default:
503 // Continue processing
504 var tool tools.BaseTool
505 for _, availableTool := range a.tools {
506 if availableTool.Info().Name == toolCall.Name {
507 tool = availableTool
508 break
509 }
510 }
511
512 // Tool not found
513 if tool == nil {
514 toolResults[i] = message.ToolResult{
515 ToolCallID: toolCall.ID,
516 Content: fmt.Sprintf("Tool not found: %s", toolCall.Name),
517 IsError: true,
518 }
519 continue
520 }
521
522 // Run tool in goroutine to allow cancellation
523 type toolExecResult struct {
524 response tools.ToolResponse
525 err error
526 }
527 resultChan := make(chan toolExecResult, 1)
528
529 go func() {
530 response, err := tool.Run(ctx, tools.ToolCall{
531 ID: toolCall.ID,
532 Name: toolCall.Name,
533 Input: toolCall.Input,
534 })
535 resultChan <- toolExecResult{response: response, err: err}
536 }()
537
538 var toolResponse tools.ToolResponse
539 var toolErr error
540
541 select {
542 case <-ctx.Done():
543 a.finishMessage(context.Background(), &assistantMsg, message.FinishReasonCanceled, "Request cancelled", "")
544 // Mark remaining tool calls as cancelled
545 for j := i; j < len(toolCalls); j++ {
546 toolResults[j] = message.ToolResult{
547 ToolCallID: toolCalls[j].ID,
548 Content: "Tool execution canceled by user",
549 IsError: true,
550 }
551 }
552 goto out
553 case result := <-resultChan:
554 toolResponse = result.response
555 toolErr = result.err
556 }
557
558 if toolErr != nil {
559 slog.Error("Tool execution error", "toolCall", toolCall.ID, "error", toolErr)
560 if errors.Is(toolErr, permission.ErrorPermissionDenied) {
561 toolResults[i] = message.ToolResult{
562 ToolCallID: toolCall.ID,
563 Content: "Permission denied",
564 IsError: true,
565 }
566 for j := i + 1; j < len(toolCalls); j++ {
567 toolResults[j] = message.ToolResult{
568 ToolCallID: toolCalls[j].ID,
569 Content: "Tool execution canceled by user",
570 IsError: true,
571 }
572 }
573 a.finishMessage(ctx, &assistantMsg, message.FinishReasonPermissionDenied, "Permission denied", "")
574 break
575 }
576 }
577 toolResults[i] = message.ToolResult{
578 ToolCallID: toolCall.ID,
579 Content: toolResponse.Content,
580 Metadata: toolResponse.Metadata,
581 IsError: toolResponse.IsError,
582 }
583 }
584 }
585out:
586 if len(toolResults) == 0 {
587 return assistantMsg, nil, nil
588 }
589 parts := make([]message.ContentPart, 0)
590 for _, tr := range toolResults {
591 parts = append(parts, tr)
592 }
593 msg, err := a.messages.Create(context.Background(), assistantMsg.SessionID, message.CreateMessageParams{
594 Role: message.Tool,
595 Parts: parts,
596 Provider: a.providerID,
597 })
598 if err != nil {
599 return assistantMsg, nil, fmt.Errorf("failed to create cancelled tool message: %w", err)
600 }
601
602 return assistantMsg, &msg, err
603}
604
605func (a *agent) finishMessage(ctx context.Context, msg *message.Message, finishReason message.FinishReason, message, details string) {
606 msg.AddFinish(finishReason, message, details)
607 _ = a.messages.Update(ctx, *msg)
608}
609
610func (a *agent) processEvent(ctx context.Context, sessionID string, assistantMsg *message.Message, event provider.ProviderEvent) error {
611 select {
612 case <-ctx.Done():
613 return ctx.Err()
614 default:
615 // Continue processing.
616 }
617
618 switch event.Type {
619 case provider.EventThinkingDelta:
620 assistantMsg.AppendReasoningContent(event.Thinking)
621 return a.messages.Update(ctx, *assistantMsg)
622 case provider.EventSignatureDelta:
623 assistantMsg.AppendReasoningSignature(event.Signature)
624 return a.messages.Update(ctx, *assistantMsg)
625 case provider.EventContentDelta:
626 assistantMsg.FinishThinking()
627 assistantMsg.AppendContent(event.Content)
628 return a.messages.Update(ctx, *assistantMsg)
629 case provider.EventToolUseStart:
630 assistantMsg.FinishThinking()
631 slog.Info("Tool call started", "toolCall", event.ToolCall)
632 assistantMsg.AddToolCall(*event.ToolCall)
633 return a.messages.Update(ctx, *assistantMsg)
634 case provider.EventToolUseDelta:
635 assistantMsg.AppendToolCallInput(event.ToolCall.ID, event.ToolCall.Input)
636 return a.messages.Update(ctx, *assistantMsg)
637 case provider.EventToolUseStop:
638 slog.Info("Finished tool call", "toolCall", event.ToolCall)
639 assistantMsg.FinishToolCall(event.ToolCall.ID)
640 return a.messages.Update(ctx, *assistantMsg)
641 case provider.EventError:
642 return event.Error
643 case provider.EventComplete:
644 assistantMsg.FinishThinking()
645 assistantMsg.SetToolCalls(event.Response.ToolCalls)
646 assistantMsg.AddFinish(event.Response.FinishReason, "", "")
647 if err := a.messages.Update(ctx, *assistantMsg); err != nil {
648 return fmt.Errorf("failed to update message: %w", err)
649 }
650 return a.TrackUsage(ctx, sessionID, a.Model(), event.Response.Usage)
651 }
652
653 return nil
654}
655
656func (a *agent) TrackUsage(ctx context.Context, sessionID string, model fur.Model, usage provider.TokenUsage) error {
657 sess, err := a.sessions.Get(ctx, sessionID)
658 if err != nil {
659 return fmt.Errorf("failed to get session: %w", err)
660 }
661
662 cost := model.CostPer1MInCached/1e6*float64(usage.CacheCreationTokens) +
663 model.CostPer1MOutCached/1e6*float64(usage.CacheReadTokens) +
664 model.CostPer1MIn/1e6*float64(usage.InputTokens) +
665 model.CostPer1MOut/1e6*float64(usage.OutputTokens)
666
667 sess.Cost += cost
668 sess.CompletionTokens = usage.OutputTokens + usage.CacheReadTokens
669 sess.PromptTokens = usage.InputTokens + usage.CacheCreationTokens
670
671 _, err = a.sessions.Save(ctx, sess)
672 if err != nil {
673 return fmt.Errorf("failed to save session: %w", err)
674 }
675 return nil
676}
677
678func (a *agent) Summarize(ctx context.Context, sessionID string) error {
679 if a.summarizeProvider == nil {
680 return fmt.Errorf("summarize provider not available")
681 }
682
683 // Check if session is busy
684 if a.IsSessionBusy(sessionID) {
685 return ErrSessionBusy
686 }
687
688 // Create a new context with cancellation
689 summarizeCtx, cancel := context.WithCancel(ctx)
690
691 // Store the cancel function in activeRequests to allow cancellation
692 a.activeRequests.Store(sessionID+"-summarize", cancel)
693
694 go func() {
695 defer a.activeRequests.Delete(sessionID + "-summarize")
696 defer cancel()
697 event := AgentEvent{
698 Type: AgentEventTypeSummarize,
699 Progress: "Starting summarization...",
700 }
701
702 a.Publish(pubsub.CreatedEvent, event)
703 // Get all messages from the session
704 msgs, err := a.messages.List(summarizeCtx, sessionID)
705 if err != nil {
706 event = AgentEvent{
707 Type: AgentEventTypeError,
708 Error: fmt.Errorf("failed to list messages: %w", err),
709 Done: true,
710 }
711 a.Publish(pubsub.CreatedEvent, event)
712 return
713 }
714 summarizeCtx = context.WithValue(summarizeCtx, tools.SessionIDContextKey, sessionID)
715
716 if len(msgs) == 0 {
717 event = AgentEvent{
718 Type: AgentEventTypeError,
719 Error: fmt.Errorf("no messages to summarize"),
720 Done: true,
721 }
722 a.Publish(pubsub.CreatedEvent, event)
723 return
724 }
725
726 event = AgentEvent{
727 Type: AgentEventTypeSummarize,
728 Progress: "Analyzing conversation...",
729 }
730 a.Publish(pubsub.CreatedEvent, event)
731
732 // Add a system message to guide the summarization
733 summarizePrompt := "Provide a detailed but concise summary of our conversation above. Focus on information that would be helpful for continuing the conversation, including what we did, what we're doing, which files we're working on, and what we're going to do next."
734
735 // Create a new message with the summarize prompt
736 promptMsg := message.Message{
737 Role: message.User,
738 Parts: []message.ContentPart{message.TextContent{Text: summarizePrompt}},
739 }
740
741 // Append the prompt to the messages
742 msgsWithPrompt := append(msgs, promptMsg)
743
744 event = AgentEvent{
745 Type: AgentEventTypeSummarize,
746 Progress: "Generating summary...",
747 }
748
749 a.Publish(pubsub.CreatedEvent, event)
750
751 // Send the messages to the summarize provider
752 response := a.summarizeProvider.StreamResponse(
753 summarizeCtx,
754 msgsWithPrompt,
755 make([]tools.BaseTool, 0),
756 )
757 var finalResponse *provider.ProviderResponse
758 for r := range response {
759 if r.Error != nil {
760 event = AgentEvent{
761 Type: AgentEventTypeError,
762 Error: fmt.Errorf("failed to summarize: %w", err),
763 Done: true,
764 }
765 a.Publish(pubsub.CreatedEvent, event)
766 return
767 }
768 finalResponse = r.Response
769 }
770
771 summary := strings.TrimSpace(finalResponse.Content)
772 if summary == "" {
773 event = AgentEvent{
774 Type: AgentEventTypeError,
775 Error: fmt.Errorf("empty summary returned"),
776 Done: true,
777 }
778 a.Publish(pubsub.CreatedEvent, event)
779 return
780 }
781 event = AgentEvent{
782 Type: AgentEventTypeSummarize,
783 Progress: "Creating new session...",
784 }
785
786 a.Publish(pubsub.CreatedEvent, event)
787 oldSession, err := a.sessions.Get(summarizeCtx, sessionID)
788 if err != nil {
789 event = AgentEvent{
790 Type: AgentEventTypeError,
791 Error: fmt.Errorf("failed to get session: %w", err),
792 Done: true,
793 }
794
795 a.Publish(pubsub.CreatedEvent, event)
796 return
797 }
798 // Create a message in the new session with the summary
799 msg, err := a.messages.Create(summarizeCtx, oldSession.ID, message.CreateMessageParams{
800 Role: message.Assistant,
801 Parts: []message.ContentPart{
802 message.TextContent{Text: summary},
803 message.Finish{
804 Reason: message.FinishReasonEndTurn,
805 Time: time.Now().Unix(),
806 },
807 },
808 Model: a.summarizeProvider.Model().ID,
809 Provider: a.summarizeProviderID,
810 })
811 if err != nil {
812 event = AgentEvent{
813 Type: AgentEventTypeError,
814 Error: fmt.Errorf("failed to create summary message: %w", err),
815 Done: true,
816 }
817
818 a.Publish(pubsub.CreatedEvent, event)
819 return
820 }
821 oldSession.SummaryMessageID = msg.ID
822 oldSession.CompletionTokens = finalResponse.Usage.OutputTokens
823 oldSession.PromptTokens = 0
824 model := a.summarizeProvider.Model()
825 usage := finalResponse.Usage
826 cost := model.CostPer1MInCached/1e6*float64(usage.CacheCreationTokens) +
827 model.CostPer1MOutCached/1e6*float64(usage.CacheReadTokens) +
828 model.CostPer1MIn/1e6*float64(usage.InputTokens) +
829 model.CostPer1MOut/1e6*float64(usage.OutputTokens)
830 oldSession.Cost += cost
831 _, err = a.sessions.Save(summarizeCtx, oldSession)
832 if err != nil {
833 event = AgentEvent{
834 Type: AgentEventTypeError,
835 Error: fmt.Errorf("failed to save session: %w", err),
836 Done: true,
837 }
838 a.Publish(pubsub.CreatedEvent, event)
839 }
840
841 event = AgentEvent{
842 Type: AgentEventTypeSummarize,
843 SessionID: oldSession.ID,
844 Progress: "Summary complete",
845 Done: true,
846 }
847 a.Publish(pubsub.CreatedEvent, event)
848 // Send final success event with the new session ID
849 }()
850
851 return nil
852}
853
854func (a *agent) CancelAll() {
855 if !a.IsBusy() {
856 return
857 }
858 a.activeRequests.Range(func(key, value any) bool {
859 a.Cancel(key.(string)) // key is sessionID
860 return true
861 })
862
863 timeout := time.After(5 * time.Second)
864 for a.IsBusy() {
865 select {
866 case <-timeout:
867 return
868 default:
869 time.Sleep(200 * time.Millisecond)
870 }
871 }
872}
873
874func (a *agent) UpdateModel() error {
875 cfg := config.Get()
876
877 // Get current provider configuration
878 currentProviderCfg := cfg.GetProviderForModel(a.agentCfg.Model)
879 if currentProviderCfg == nil || currentProviderCfg.ID == "" {
880 return fmt.Errorf("provider for agent %s not found in config", a.agentCfg.Name)
881 }
882
883 // Check if provider has changed
884 if string(currentProviderCfg.ID) != a.providerID {
885 // Provider changed, need to recreate the main provider
886 model := cfg.GetModelByType(a.agentCfg.Model)
887 if model.ID == "" {
888 return fmt.Errorf("model not found for agent %s", a.agentCfg.Name)
889 }
890
891 promptID := agentPromptMap[a.agentCfg.ID]
892 if promptID == "" {
893 promptID = prompt.PromptDefault
894 }
895
896 opts := []provider.ProviderClientOption{
897 provider.WithModel(a.agentCfg.Model),
898 provider.WithSystemMessage(prompt.GetPrompt(promptID, currentProviderCfg.ID, cfg.Options.ContextPaths...)),
899 }
900
901 newProvider, err := provider.NewProvider(*currentProviderCfg, opts...)
902 if err != nil {
903 return fmt.Errorf("failed to create new provider: %w", err)
904 }
905
906 // Update the provider and provider ID
907 a.provider = newProvider
908 a.providerID = string(currentProviderCfg.ID)
909 }
910
911 // Check if small model provider has changed (affects title and summarize providers)
912 smallModelCfg := cfg.Models[config.SelectedModelTypeSmall]
913 var smallModelProviderCfg config.ProviderConfig
914
915 for _, p := range cfg.Providers {
916 if p.ID == smallModelCfg.Provider {
917 smallModelProviderCfg = p
918 break
919 }
920 }
921
922 if smallModelProviderCfg.ID == "" {
923 return fmt.Errorf("provider %s not found in config", smallModelCfg.Provider)
924 }
925
926 // Check if summarize provider has changed
927 if string(smallModelProviderCfg.ID) != a.summarizeProviderID {
928 smallModel := cfg.GetModelByType(config.SelectedModelTypeSmall)
929 if smallModel == nil {
930 return fmt.Errorf("model %s not found in provider %s", smallModelCfg.Model, smallModelProviderCfg.ID)
931 }
932
933 // Recreate title provider
934 titleOpts := []provider.ProviderClientOption{
935 provider.WithModel(config.SelectedModelTypeSmall),
936 provider.WithSystemMessage(prompt.GetPrompt(prompt.PromptTitle, smallModelProviderCfg.ID)),
937 // We want the title to be short, so we limit the max tokens
938 provider.WithMaxTokens(40),
939 }
940 newTitleProvider, err := provider.NewProvider(smallModelProviderCfg, titleOpts...)
941 if err != nil {
942 return fmt.Errorf("failed to create new title provider: %w", err)
943 }
944
945 // Recreate summarize provider
946 summarizeOpts := []provider.ProviderClientOption{
947 provider.WithModel(config.SelectedModelTypeSmall),
948 provider.WithSystemMessage(prompt.GetPrompt(prompt.PromptSummarizer, smallModelProviderCfg.ID)),
949 }
950 newSummarizeProvider, err := provider.NewProvider(smallModelProviderCfg, summarizeOpts...)
951 if err != nil {
952 return fmt.Errorf("failed to create new summarize provider: %w", err)
953 }
954
955 // Update the providers and provider ID
956 a.titleProvider = newTitleProvider
957 a.summarizeProvider = newSummarizeProvider
958 a.summarizeProviderID = string(smallModelProviderCfg.ID)
959 }
960
961 return nil
962}