1package agent
2
3import (
4 "context"
5 "errors"
6 "fmt"
7 "log/slog"
8 "slices"
9 "strings"
10 "sync"
11 "time"
12
13 "github.com/charmbracelet/crush/internal/config"
14 fur "github.com/charmbracelet/crush/internal/fur/provider"
15 "github.com/charmbracelet/crush/internal/history"
16 "github.com/charmbracelet/crush/internal/llm/prompt"
17 "github.com/charmbracelet/crush/internal/llm/provider"
18 "github.com/charmbracelet/crush/internal/llm/tools"
19 "github.com/charmbracelet/crush/internal/log"
20 "github.com/charmbracelet/crush/internal/lsp"
21 "github.com/charmbracelet/crush/internal/message"
22 "github.com/charmbracelet/crush/internal/permission"
23 "github.com/charmbracelet/crush/internal/pubsub"
24 "github.com/charmbracelet/crush/internal/session"
25 "github.com/charmbracelet/crush/internal/shell"
26)
27
// ErrRequestCancelled is reported when a running request's context is
// cancelled, for example through Cancel.
var ErrRequestCancelled = errors.New("request canceled by user")

// ErrSessionBusy is returned when new work is requested for a session that
// already has work in flight.
var ErrSessionBusy = errors.New("session is currently processing another request")
33
// AgentEventType tells subscribers what kind of information an AgentEvent
// carries.
type AgentEventType string

const (
	// AgentEventTypeError marks an event whose Error field holds a failure.
	AgentEventTypeError AgentEventType = "error"

	// AgentEventTypeResponse marks an event carrying the assistant's final
	// message for a Run.
	AgentEventTypeResponse AgentEventType = "response"

	// AgentEventTypeSummarize marks a summarization progress or completion
	// event.
	AgentEventTypeSummarize AgentEventType = "summarize"
)
41
// AgentEvent reports the outcome or progress of agent work. It is published
// on the agent's broker and, for Run, also delivered on the returned channel.
type AgentEvent struct {
	// Type says which of the fields below are meaningful.
	Type AgentEventType
	// Message is the assistant's final message (AgentEventTypeResponse).
	Message message.Message
	// Error is the failure being reported (AgentEventTypeError).
	Error error

	// Fields used while summarizing. Done is also set on Run's final
	// response event.
	SessionID string
	Progress  string
	Done      bool
}
52
53type Service interface {
54 pubsub.Suscriber[AgentEvent]
55 Model() fur.Model
56 Run(ctx context.Context, sessionID string, content string, attachments ...message.Attachment) (<-chan AgentEvent, error)
57 Cancel(sessionID string)
58 CancelAll()
59 IsSessionBusy(sessionID string) bool
60 IsBusy() bool
61 Summarize(ctx context.Context, sessionID string) error
62 UpdateModel() error
63}
64
// agent is the Service implementation returned by NewAgent.
type agent struct {
	*pubsub.Broker[AgentEvent]
	agentCfg config.Agent
	sessions session.Service
	messages message.Service

	// tools are offered to the main provider (already filtered by
	// agentCfg.AllowedTools); provider serves agentCfg.Model.
	tools      []tools.BaseTool
	provider   provider.Provider
	providerID string

	// Title and summarize providers both run on the small model.
	titleProvider       provider.Provider
	summarizeProvider   provider.Provider
	summarizeProviderID string

	// activeRequests maps a session ID (Run) or session ID + "-summarize"
	// (Summarize) to the context.CancelFunc of the work in flight for it.
	activeRequests sync.Map
}
81
82var agentPromptMap = map[string]prompt.PromptID{
83 "coder": prompt.PromptCoder,
84 "task": prompt.PromptTask,
85}
86
87func NewAgent(
88 agentCfg config.Agent,
89 // These services are needed in the tools
90 permissions permission.Service,
91 sessions session.Service,
92 messages message.Service,
93 history history.Service,
94 lspClients map[string]*lsp.Client,
95) (Service, error) {
96 ctx := context.Background()
97 cfg := config.Get()
98 otherTools := GetMCPTools(ctx, permissions, cfg)
99 if len(lspClients) > 0 {
100 otherTools = append(otherTools, tools.NewDiagnosticsTool(lspClients))
101 }
102
103 cwd := cfg.WorkingDir()
104 allTools := []tools.BaseTool{
105 tools.NewBashTool(permissions, cwd),
106 tools.NewDownloadTool(permissions, cwd),
107 tools.NewEditTool(lspClients, permissions, history, cwd),
108 tools.NewFetchTool(permissions, cwd),
109 tools.NewGlobTool(cwd),
110 tools.NewGrepTool(cwd),
111 tools.NewLsTool(cwd),
112 tools.NewSourcegraphTool(),
113 tools.NewViewTool(lspClients, cwd),
114 tools.NewWriteTool(lspClients, permissions, history, cwd),
115 }
116
117 if agentCfg.ID == "coder" {
118 taskAgentCfg := config.Get().Agents["task"]
119 if taskAgentCfg.ID == "" {
120 return nil, fmt.Errorf("task agent not found in config")
121 }
122 taskAgent, err := NewAgent(taskAgentCfg, permissions, sessions, messages, history, lspClients)
123 if err != nil {
124 return nil, fmt.Errorf("failed to create task agent: %w", err)
125 }
126
127 allTools = append(
128 allTools,
129 NewAgentTool(
130 taskAgent,
131 sessions,
132 messages,
133 ),
134 )
135 }
136
137 allTools = append(allTools, otherTools...)
138 providerCfg := config.Get().GetProviderForModel(agentCfg.Model)
139 if providerCfg == nil {
140 return nil, fmt.Errorf("provider for agent %s not found in config", agentCfg.Name)
141 }
142 model := config.Get().GetModelByType(agentCfg.Model)
143
144 if model == nil {
145 return nil, fmt.Errorf("model not found for agent %s", agentCfg.Name)
146 }
147
148 promptID := agentPromptMap[agentCfg.ID]
149 if promptID == "" {
150 promptID = prompt.PromptDefault
151 }
152 opts := []provider.ProviderClientOption{
153 provider.WithModel(agentCfg.Model),
154 provider.WithSystemMessage(prompt.GetPrompt(promptID, providerCfg.ID, config.Get().Options.ContextPaths...)),
155 }
156 agentProvider, err := provider.NewProvider(*providerCfg, opts...)
157 if err != nil {
158 return nil, err
159 }
160
161 smallModelCfg := cfg.Models[config.SelectedModelTypeSmall]
162 var smallModelProviderCfg *config.ProviderConfig
163 if smallModelCfg.Provider == providerCfg.ID {
164 smallModelProviderCfg = providerCfg
165 } else {
166 smallModelProviderCfg = cfg.GetProviderForModel(config.SelectedModelTypeSmall)
167
168 if smallModelProviderCfg.ID == "" {
169 return nil, fmt.Errorf("provider %s not found in config", smallModelCfg.Provider)
170 }
171 }
172 smallModel := cfg.GetModelByType(config.SelectedModelTypeSmall)
173 if smallModel.ID == "" {
174 return nil, fmt.Errorf("model %s not found in provider %s", smallModelCfg.Model, smallModelProviderCfg.ID)
175 }
176
177 titleOpts := []provider.ProviderClientOption{
178 provider.WithModel(config.SelectedModelTypeSmall),
179 provider.WithSystemMessage(prompt.GetPrompt(prompt.PromptTitle, smallModelProviderCfg.ID)),
180 }
181 titleProvider, err := provider.NewProvider(*smallModelProviderCfg, titleOpts...)
182 if err != nil {
183 return nil, err
184 }
185 summarizeOpts := []provider.ProviderClientOption{
186 provider.WithModel(config.SelectedModelTypeSmall),
187 provider.WithSystemMessage(prompt.GetPrompt(prompt.PromptSummarizer, smallModelProviderCfg.ID)),
188 }
189 summarizeProvider, err := provider.NewProvider(*smallModelProviderCfg, summarizeOpts...)
190 if err != nil {
191 return nil, err
192 }
193
194 agentTools := []tools.BaseTool{}
195 if agentCfg.AllowedTools == nil {
196 agentTools = allTools
197 } else {
198 for _, tool := range allTools {
199 if slices.Contains(agentCfg.AllowedTools, tool.Name()) {
200 agentTools = append(agentTools, tool)
201 }
202 }
203 }
204
205 agent := &agent{
206 Broker: pubsub.NewBroker[AgentEvent](),
207 agentCfg: agentCfg,
208 provider: agentProvider,
209 providerID: string(providerCfg.ID),
210 messages: messages,
211 sessions: sessions,
212 tools: agentTools,
213 titleProvider: titleProvider,
214 summarizeProvider: summarizeProvider,
215 summarizeProviderID: string(smallModelProviderCfg.ID),
216 activeRequests: sync.Map{},
217 }
218
219 return agent, nil
220}
221
222func (a *agent) Model() fur.Model {
223 return *config.Get().GetModelByType(a.agentCfg.Model)
224}
225
226func (a *agent) Cancel(sessionID string) {
227 // Cancel regular requests
228 if cancelFunc, exists := a.activeRequests.LoadAndDelete(sessionID); exists {
229 if cancel, ok := cancelFunc.(context.CancelFunc); ok {
230 slog.Info(fmt.Sprintf("Request cancellation initiated for session: %s", sessionID))
231 cancel()
232 }
233 }
234
235 // Also check for summarize requests
236 if cancelFunc, exists := a.activeRequests.LoadAndDelete(sessionID + "-summarize"); exists {
237 if cancel, ok := cancelFunc.(context.CancelFunc); ok {
238 slog.Info(fmt.Sprintf("Summarize cancellation initiated for session: %s", sessionID))
239 cancel()
240 }
241 }
242}
243
244func (a *agent) IsBusy() bool {
245 busy := false
246 a.activeRequests.Range(func(key, value any) bool {
247 if cancelFunc, ok := value.(context.CancelFunc); ok {
248 if cancelFunc != nil {
249 busy = true
250 return false
251 }
252 }
253 return true
254 })
255 return busy
256}
257
258func (a *agent) IsSessionBusy(sessionID string) bool {
259 _, busy := a.activeRequests.Load(sessionID)
260 return busy
261}
262
263func (a *agent) generateTitle(ctx context.Context, sessionID string, content string) error {
264 if content == "" {
265 return nil
266 }
267 if a.titleProvider == nil {
268 return nil
269 }
270 session, err := a.sessions.Get(ctx, sessionID)
271 if err != nil {
272 return err
273 }
274 parts := []message.ContentPart{message.TextContent{
275 Text: fmt.Sprintf("Generate a concise title for the following content:\n\n%s", content),
276 }}
277
278 // Use streaming approach like summarization
279 response := a.titleProvider.StreamResponse(
280 ctx,
281 []message.Message{
282 {
283 Role: message.User,
284 Parts: parts,
285 },
286 },
287 make([]tools.BaseTool, 0),
288 )
289
290 var finalResponse *provider.ProviderResponse
291 for r := range response {
292 if r.Error != nil {
293 return r.Error
294 }
295 finalResponse = r.Response
296 }
297
298 if finalResponse == nil {
299 return fmt.Errorf("no response received from title provider")
300 }
301
302 title := strings.TrimSpace(strings.ReplaceAll(finalResponse.Content, "\n", " "))
303 if title == "" {
304 return nil
305 }
306
307 session.Title = title
308 _, err = a.sessions.Save(ctx, session)
309 return err
310}
311
312func (a *agent) err(err error) AgentEvent {
313 return AgentEvent{
314 Type: AgentEventTypeError,
315 Error: err,
316 }
317}
318
319func (a *agent) Run(ctx context.Context, sessionID string, content string, attachments ...message.Attachment) (<-chan AgentEvent, error) {
320 if !a.Model().SupportsImages && attachments != nil {
321 attachments = nil
322 }
323 events := make(chan AgentEvent)
324 if a.IsSessionBusy(sessionID) {
325 return nil, ErrSessionBusy
326 }
327
328 genCtx, cancel := context.WithCancel(ctx)
329
330 a.activeRequests.Store(sessionID, cancel)
331 go func() {
332 slog.Debug("Request started", "sessionID", sessionID)
333 defer log.RecoverPanic("agent.Run", func() {
334 events <- a.err(fmt.Errorf("panic while running the agent"))
335 })
336 var attachmentParts []message.ContentPart
337 for _, attachment := range attachments {
338 attachmentParts = append(attachmentParts, message.BinaryContent{Path: attachment.FilePath, MIMEType: attachment.MimeType, Data: attachment.Content})
339 }
340 result := a.processGeneration(genCtx, sessionID, content, attachmentParts)
341 if result.Error != nil && !errors.Is(result.Error, ErrRequestCancelled) && !errors.Is(result.Error, context.Canceled) {
342 slog.Error(result.Error.Error())
343 }
344 slog.Debug("Request completed", "sessionID", sessionID)
345 a.activeRequests.Delete(sessionID)
346 cancel()
347 a.Publish(pubsub.CreatedEvent, result)
348 events <- result
349 close(events)
350 }()
351 return events, nil
352}
353
354func (a *agent) processGeneration(ctx context.Context, sessionID, content string, attachmentParts []message.ContentPart) AgentEvent {
355 cfg := config.Get()
356 // List existing messages; if none, start title generation asynchronously.
357 msgs, err := a.messages.List(ctx, sessionID)
358 if err != nil {
359 return a.err(fmt.Errorf("failed to list messages: %w", err))
360 }
361 if len(msgs) == 0 {
362 go func() {
363 defer log.RecoverPanic("agent.Run", func() {
364 slog.Error("panic while generating title")
365 })
366 titleErr := a.generateTitle(context.Background(), sessionID, content)
367 if titleErr != nil && !errors.Is(titleErr, context.Canceled) && !errors.Is(titleErr, context.DeadlineExceeded) {
368 slog.Error(fmt.Sprintf("failed to generate title: %v", titleErr))
369 }
370 }()
371 }
372 session, err := a.sessions.Get(ctx, sessionID)
373 if err != nil {
374 return a.err(fmt.Errorf("failed to get session: %w", err))
375 }
376 if session.SummaryMessageID != "" {
377 summaryMsgInex := -1
378 for i, msg := range msgs {
379 if msg.ID == session.SummaryMessageID {
380 summaryMsgInex = i
381 break
382 }
383 }
384 if summaryMsgInex != -1 {
385 msgs = msgs[summaryMsgInex:]
386 msgs[0].Role = message.User
387 }
388 }
389
390 userMsg, err := a.createUserMessage(ctx, sessionID, content, attachmentParts)
391 if err != nil {
392 return a.err(fmt.Errorf("failed to create user message: %w", err))
393 }
394 // Append the new user message to the conversation history.
395 msgHistory := append(msgs, userMsg)
396
397 for {
398 // Check for cancellation before each iteration
399 select {
400 case <-ctx.Done():
401 return a.err(ctx.Err())
402 default:
403 // Continue processing
404 }
405 agentMessage, toolResults, err := a.streamAndHandleEvents(ctx, sessionID, msgHistory)
406 if err != nil {
407 if errors.Is(err, context.Canceled) {
408 agentMessage.AddFinish(message.FinishReasonCanceled, "Request cancelled", "")
409 a.messages.Update(context.Background(), agentMessage)
410 return a.err(ErrRequestCancelled)
411 }
412 return a.err(fmt.Errorf("failed to process events: %w", err))
413 }
414 if cfg.Options.Debug {
415 slog.Info("Result", "message", agentMessage.FinishReason(), "toolResults", toolResults)
416 }
417 if (agentMessage.FinishReason() == message.FinishReasonToolUse) && toolResults != nil {
418 // We are not done, we need to respond with the tool response
419 msgHistory = append(msgHistory, agentMessage, *toolResults)
420 continue
421 }
422 return AgentEvent{
423 Type: AgentEventTypeResponse,
424 Message: agentMessage,
425 Done: true,
426 }
427 }
428}
429
430func (a *agent) createUserMessage(ctx context.Context, sessionID, content string, attachmentParts []message.ContentPart) (message.Message, error) {
431 parts := []message.ContentPart{message.TextContent{Text: content}}
432 parts = append(parts, attachmentParts...)
433 return a.messages.Create(ctx, sessionID, message.CreateMessageParams{
434 Role: message.User,
435 Parts: parts,
436 })
437}
438
439func (a *agent) streamAndHandleEvents(ctx context.Context, sessionID string, msgHistory []message.Message) (message.Message, *message.Message, error) {
440 ctx = context.WithValue(ctx, tools.SessionIDContextKey, sessionID)
441 eventChan := a.provider.StreamResponse(ctx, msgHistory, a.tools)
442
443 assistantMsg, err := a.messages.Create(ctx, sessionID, message.CreateMessageParams{
444 Role: message.Assistant,
445 Parts: []message.ContentPart{},
446 Model: a.Model().ID,
447 Provider: a.providerID,
448 })
449 if err != nil {
450 return assistantMsg, nil, fmt.Errorf("failed to create assistant message: %w", err)
451 }
452
453 // Add the session and message ID into the context if needed by tools.
454 ctx = context.WithValue(ctx, tools.MessageIDContextKey, assistantMsg.ID)
455
456 // Process each event in the stream.
457 for event := range eventChan {
458 if processErr := a.processEvent(ctx, sessionID, &assistantMsg, event); processErr != nil {
459 if errors.Is(processErr, context.Canceled) {
460 a.finishMessage(context.Background(), &assistantMsg, message.FinishReasonCanceled, "Request cancelled", "")
461 } else {
462 a.finishMessage(ctx, &assistantMsg, message.FinishReasonError, "API Error", processErr.Error())
463 }
464 return assistantMsg, nil, processErr
465 }
466 if ctx.Err() != nil {
467 a.finishMessage(context.Background(), &assistantMsg, message.FinishReasonCanceled, "Request cancelled", "")
468 return assistantMsg, nil, ctx.Err()
469 }
470 }
471
472 toolResults := make([]message.ToolResult, len(assistantMsg.ToolCalls()))
473 toolCalls := assistantMsg.ToolCalls()
474 for i, toolCall := range toolCalls {
475 select {
476 case <-ctx.Done():
477 a.finishMessage(context.Background(), &assistantMsg, message.FinishReasonCanceled, "Request cancelled", "")
478 // Make all future tool calls cancelled
479 for j := i; j < len(toolCalls); j++ {
480 toolResults[j] = message.ToolResult{
481 ToolCallID: toolCalls[j].ID,
482 Content: "Tool execution canceled by user",
483 IsError: true,
484 }
485 }
486 goto out
487 default:
488 // Continue processing
489 var tool tools.BaseTool
490 for _, availableTool := range a.tools {
491 if availableTool.Info().Name == toolCall.Name {
492 tool = availableTool
493 break
494 }
495 }
496
497 // Tool not found
498 if tool == nil {
499 toolResults[i] = message.ToolResult{
500 ToolCallID: toolCall.ID,
501 Content: fmt.Sprintf("Tool not found: %s", toolCall.Name),
502 IsError: true,
503 }
504 continue
505 }
506
507 // Run tool in goroutine to allow cancellation
508 type toolExecResult struct {
509 response tools.ToolResponse
510 err error
511 }
512 resultChan := make(chan toolExecResult, 1)
513
514 go func() {
515 response, err := tool.Run(ctx, tools.ToolCall{
516 ID: toolCall.ID,
517 Name: toolCall.Name,
518 Input: toolCall.Input,
519 })
520 resultChan <- toolExecResult{response: response, err: err}
521 }()
522
523 var toolResponse tools.ToolResponse
524 var toolErr error
525
526 select {
527 case <-ctx.Done():
528 a.finishMessage(context.Background(), &assistantMsg, message.FinishReasonCanceled, "Request cancelled", "")
529 // Mark remaining tool calls as cancelled
530 for j := i; j < len(toolCalls); j++ {
531 toolResults[j] = message.ToolResult{
532 ToolCallID: toolCalls[j].ID,
533 Content: "Tool execution canceled by user",
534 IsError: true,
535 }
536 }
537 goto out
538 case result := <-resultChan:
539 toolResponse = result.response
540 toolErr = result.err
541 }
542
543 if toolErr != nil {
544 slog.Error("Tool execution error", "toolCall", toolCall.ID, "error", toolErr)
545 if errors.Is(toolErr, permission.ErrorPermissionDenied) {
546 toolResults[i] = message.ToolResult{
547 ToolCallID: toolCall.ID,
548 Content: "Permission denied",
549 IsError: true,
550 }
551 for j := i + 1; j < len(toolCalls); j++ {
552 toolResults[j] = message.ToolResult{
553 ToolCallID: toolCalls[j].ID,
554 Content: "Tool execution canceled by user",
555 IsError: true,
556 }
557 }
558 a.finishMessage(ctx, &assistantMsg, message.FinishReasonPermissionDenied, "Permission denied", "")
559 break
560 }
561 }
562 toolResults[i] = message.ToolResult{
563 ToolCallID: toolCall.ID,
564 Content: toolResponse.Content,
565 Metadata: toolResponse.Metadata,
566 IsError: toolResponse.IsError,
567 }
568 }
569 }
570out:
571 if len(toolResults) == 0 {
572 return assistantMsg, nil, nil
573 }
574 parts := make([]message.ContentPart, 0)
575 for _, tr := range toolResults {
576 parts = append(parts, tr)
577 }
578 msg, err := a.messages.Create(context.Background(), assistantMsg.SessionID, message.CreateMessageParams{
579 Role: message.Tool,
580 Parts: parts,
581 Provider: a.providerID,
582 })
583 if err != nil {
584 return assistantMsg, nil, fmt.Errorf("failed to create cancelled tool message: %w", err)
585 }
586
587 return assistantMsg, &msg, err
588}
589
590func (a *agent) finishMessage(ctx context.Context, msg *message.Message, finishReason message.FinishReason, message, details string) {
591 msg.AddFinish(finishReason, message, details)
592 _ = a.messages.Update(ctx, *msg)
593}
594
595func (a *agent) processEvent(ctx context.Context, sessionID string, assistantMsg *message.Message, event provider.ProviderEvent) error {
596 select {
597 case <-ctx.Done():
598 return ctx.Err()
599 default:
600 // Continue processing.
601 }
602
603 switch event.Type {
604 case provider.EventThinkingDelta:
605 assistantMsg.AppendReasoningContent(event.Thinking)
606 return a.messages.Update(ctx, *assistantMsg)
607 case provider.EventSignatureDelta:
608 assistantMsg.AppendReasoningSignature(event.Signature)
609 return a.messages.Update(ctx, *assistantMsg)
610 case provider.EventContentDelta:
611 assistantMsg.FinishThinking()
612 assistantMsg.AppendContent(event.Content)
613 return a.messages.Update(ctx, *assistantMsg)
614 case provider.EventToolUseStart:
615 assistantMsg.FinishThinking()
616 slog.Info("Tool call started", "toolCall", event.ToolCall)
617 assistantMsg.AddToolCall(*event.ToolCall)
618 return a.messages.Update(ctx, *assistantMsg)
619 case provider.EventToolUseDelta:
620 assistantMsg.AppendToolCallInput(event.ToolCall.ID, event.ToolCall.Input)
621 return a.messages.Update(ctx, *assistantMsg)
622 case provider.EventToolUseStop:
623 slog.Info("Finished tool call", "toolCall", event.ToolCall)
624 assistantMsg.FinishToolCall(event.ToolCall.ID)
625 return a.messages.Update(ctx, *assistantMsg)
626 case provider.EventError:
627 return event.Error
628 case provider.EventComplete:
629 assistantMsg.FinishThinking()
630 assistantMsg.SetToolCalls(event.Response.ToolCalls)
631 assistantMsg.AddFinish(event.Response.FinishReason, "", "")
632 if err := a.messages.Update(ctx, *assistantMsg); err != nil {
633 return fmt.Errorf("failed to update message: %w", err)
634 }
635 return a.TrackUsage(ctx, sessionID, a.Model(), event.Response.Usage)
636 }
637
638 return nil
639}
640
641func (a *agent) TrackUsage(ctx context.Context, sessionID string, model fur.Model, usage provider.TokenUsage) error {
642 sess, err := a.sessions.Get(ctx, sessionID)
643 if err != nil {
644 return fmt.Errorf("failed to get session: %w", err)
645 }
646
647 cost := model.CostPer1MInCached/1e6*float64(usage.CacheCreationTokens) +
648 model.CostPer1MOutCached/1e6*float64(usage.CacheReadTokens) +
649 model.CostPer1MIn/1e6*float64(usage.InputTokens) +
650 model.CostPer1MOut/1e6*float64(usage.OutputTokens)
651
652 sess.Cost += cost
653 sess.CompletionTokens = usage.OutputTokens + usage.CacheReadTokens
654 sess.PromptTokens = usage.InputTokens + usage.CacheCreationTokens
655
656 _, err = a.sessions.Save(ctx, sess)
657 if err != nil {
658 return fmt.Errorf("failed to save session: %w", err)
659 }
660 return nil
661}
662
663func (a *agent) Summarize(ctx context.Context, sessionID string) error {
664 if a.summarizeProvider == nil {
665 return fmt.Errorf("summarize provider not available")
666 }
667
668 // Check if session is busy
669 if a.IsSessionBusy(sessionID) {
670 return ErrSessionBusy
671 }
672
673 // Create a new context with cancellation
674 summarizeCtx, cancel := context.WithCancel(ctx)
675
676 // Store the cancel function in activeRequests to allow cancellation
677 a.activeRequests.Store(sessionID+"-summarize", cancel)
678
679 go func() {
680 defer a.activeRequests.Delete(sessionID + "-summarize")
681 defer cancel()
682 event := AgentEvent{
683 Type: AgentEventTypeSummarize,
684 Progress: "Starting summarization...",
685 }
686
687 a.Publish(pubsub.CreatedEvent, event)
688 // Get all messages from the session
689 msgs, err := a.messages.List(summarizeCtx, sessionID)
690 if err != nil {
691 event = AgentEvent{
692 Type: AgentEventTypeError,
693 Error: fmt.Errorf("failed to list messages: %w", err),
694 Done: true,
695 }
696 a.Publish(pubsub.CreatedEvent, event)
697 return
698 }
699 summarizeCtx = context.WithValue(summarizeCtx, tools.SessionIDContextKey, sessionID)
700
701 if len(msgs) == 0 {
702 event = AgentEvent{
703 Type: AgentEventTypeError,
704 Error: fmt.Errorf("no messages to summarize"),
705 Done: true,
706 }
707 a.Publish(pubsub.CreatedEvent, event)
708 return
709 }
710
711 event = AgentEvent{
712 Type: AgentEventTypeSummarize,
713 Progress: "Analyzing conversation...",
714 }
715 a.Publish(pubsub.CreatedEvent, event)
716
717 // Add a system message to guide the summarization
718 summarizePrompt := "Provide a detailed but concise summary of our conversation above. Focus on information that would be helpful for continuing the conversation, including what we did, what we're doing, which files we're working on, and what we're going to do next."
719
720 // Create a new message with the summarize prompt
721 promptMsg := message.Message{
722 Role: message.User,
723 Parts: []message.ContentPart{message.TextContent{Text: summarizePrompt}},
724 }
725
726 // Append the prompt to the messages
727 msgsWithPrompt := append(msgs, promptMsg)
728
729 event = AgentEvent{
730 Type: AgentEventTypeSummarize,
731 Progress: "Generating summary...",
732 }
733
734 a.Publish(pubsub.CreatedEvent, event)
735
736 // Send the messages to the summarize provider
737 response := a.summarizeProvider.StreamResponse(
738 summarizeCtx,
739 msgsWithPrompt,
740 make([]tools.BaseTool, 0),
741 )
742 var finalResponse *provider.ProviderResponse
743 for r := range response {
744 if r.Error != nil {
745 event = AgentEvent{
746 Type: AgentEventTypeError,
747 Error: fmt.Errorf("failed to summarize: %w", err),
748 Done: true,
749 }
750 a.Publish(pubsub.CreatedEvent, event)
751 return
752 }
753 finalResponse = r.Response
754 }
755
756 summary := strings.TrimSpace(finalResponse.Content)
757 if summary == "" {
758 event = AgentEvent{
759 Type: AgentEventTypeError,
760 Error: fmt.Errorf("empty summary returned"),
761 Done: true,
762 }
763 a.Publish(pubsub.CreatedEvent, event)
764 return
765 }
766 shell := shell.GetPersistentShell(config.Get().WorkingDir())
767 summary += "\n\n**Current working directory of the persistent shell**\n\n" + shell.GetWorkingDir()
768 event = AgentEvent{
769 Type: AgentEventTypeSummarize,
770 Progress: "Creating new session...",
771 }
772
773 a.Publish(pubsub.CreatedEvent, event)
774 oldSession, err := a.sessions.Get(summarizeCtx, sessionID)
775 if err != nil {
776 event = AgentEvent{
777 Type: AgentEventTypeError,
778 Error: fmt.Errorf("failed to get session: %w", err),
779 Done: true,
780 }
781
782 a.Publish(pubsub.CreatedEvent, event)
783 return
784 }
785 // Create a message in the new session with the summary
786 msg, err := a.messages.Create(summarizeCtx, oldSession.ID, message.CreateMessageParams{
787 Role: message.Assistant,
788 Parts: []message.ContentPart{
789 message.TextContent{Text: summary},
790 message.Finish{
791 Reason: message.FinishReasonEndTurn,
792 Time: time.Now().Unix(),
793 },
794 },
795 Model: a.summarizeProvider.Model().ID,
796 Provider: a.summarizeProviderID,
797 })
798 if err != nil {
799 event = AgentEvent{
800 Type: AgentEventTypeError,
801 Error: fmt.Errorf("failed to create summary message: %w", err),
802 Done: true,
803 }
804
805 a.Publish(pubsub.CreatedEvent, event)
806 return
807 }
808 oldSession.SummaryMessageID = msg.ID
809 oldSession.CompletionTokens = finalResponse.Usage.OutputTokens
810 oldSession.PromptTokens = 0
811 model := a.summarizeProvider.Model()
812 usage := finalResponse.Usage
813 cost := model.CostPer1MInCached/1e6*float64(usage.CacheCreationTokens) +
814 model.CostPer1MOutCached/1e6*float64(usage.CacheReadTokens) +
815 model.CostPer1MIn/1e6*float64(usage.InputTokens) +
816 model.CostPer1MOut/1e6*float64(usage.OutputTokens)
817 oldSession.Cost += cost
818 _, err = a.sessions.Save(summarizeCtx, oldSession)
819 if err != nil {
820 event = AgentEvent{
821 Type: AgentEventTypeError,
822 Error: fmt.Errorf("failed to save session: %w", err),
823 Done: true,
824 }
825 a.Publish(pubsub.CreatedEvent, event)
826 }
827
828 event = AgentEvent{
829 Type: AgentEventTypeSummarize,
830 SessionID: oldSession.ID,
831 Progress: "Summary complete",
832 Done: true,
833 }
834 a.Publish(pubsub.CreatedEvent, event)
835 // Send final success event with the new session ID
836 }()
837
838 return nil
839}
840
841func (a *agent) CancelAll() {
842 if !a.IsBusy() {
843 return
844 }
845 a.activeRequests.Range(func(key, value any) bool {
846 a.Cancel(key.(string)) // key is sessionID
847 return true
848 })
849
850 timeout := time.After(5 * time.Second)
851 for a.IsBusy() {
852 select {
853 case <-timeout:
854 return
855 default:
856 time.Sleep(200 * time.Millisecond)
857 }
858 }
859}
860
861func (a *agent) UpdateModel() error {
862 cfg := config.Get()
863
864 // Get current provider configuration
865 currentProviderCfg := cfg.GetProviderForModel(a.agentCfg.Model)
866 if currentProviderCfg == nil || currentProviderCfg.ID == "" {
867 return fmt.Errorf("provider for agent %s not found in config", a.agentCfg.Name)
868 }
869
870 // Check if provider has changed
871 if string(currentProviderCfg.ID) != a.providerID {
872 // Provider changed, need to recreate the main provider
873 model := cfg.GetModelByType(a.agentCfg.Model)
874 if model.ID == "" {
875 return fmt.Errorf("model not found for agent %s", a.agentCfg.Name)
876 }
877
878 promptID := agentPromptMap[a.agentCfg.ID]
879 if promptID == "" {
880 promptID = prompt.PromptDefault
881 }
882
883 opts := []provider.ProviderClientOption{
884 provider.WithModel(a.agentCfg.Model),
885 provider.WithSystemMessage(prompt.GetPrompt(promptID, currentProviderCfg.ID, cfg.Options.ContextPaths...)),
886 }
887
888 newProvider, err := provider.NewProvider(*currentProviderCfg, opts...)
889 if err != nil {
890 return fmt.Errorf("failed to create new provider: %w", err)
891 }
892
893 // Update the provider and provider ID
894 a.provider = newProvider
895 a.providerID = string(currentProviderCfg.ID)
896 }
897
898 // Check if small model provider has changed (affects title and summarize providers)
899 smallModelCfg := cfg.Models[config.SelectedModelTypeSmall]
900 var smallModelProviderCfg config.ProviderConfig
901
902 for _, p := range cfg.Providers {
903 if p.ID == smallModelCfg.Provider {
904 smallModelProviderCfg = p
905 break
906 }
907 }
908
909 if smallModelProviderCfg.ID == "" {
910 return fmt.Errorf("provider %s not found in config", smallModelCfg.Provider)
911 }
912
913 // Check if summarize provider has changed
914 if string(smallModelProviderCfg.ID) != a.summarizeProviderID {
915 smallModel := cfg.GetModelByType(config.SelectedModelTypeSmall)
916 if smallModel == nil {
917 return fmt.Errorf("model %s not found in provider %s", smallModelCfg.Model, smallModelProviderCfg.ID)
918 }
919
920 // Recreate title provider
921 titleOpts := []provider.ProviderClientOption{
922 provider.WithModel(config.SelectedModelTypeSmall),
923 provider.WithSystemMessage(prompt.GetPrompt(prompt.PromptTitle, smallModelProviderCfg.ID)),
924 // We want the title to be short, so we limit the max tokens
925 provider.WithMaxTokens(40),
926 }
927 newTitleProvider, err := provider.NewProvider(smallModelProviderCfg, titleOpts...)
928 if err != nil {
929 return fmt.Errorf("failed to create new title provider: %w", err)
930 }
931
932 // Recreate summarize provider
933 summarizeOpts := []provider.ProviderClientOption{
934 provider.WithModel(config.SelectedModelTypeSmall),
935 provider.WithSystemMessage(prompt.GetPrompt(prompt.PromptSummarizer, smallModelProviderCfg.ID)),
936 }
937 newSummarizeProvider, err := provider.NewProvider(smallModelProviderCfg, summarizeOpts...)
938 if err != nil {
939 return fmt.Errorf("failed to create new summarize provider: %w", err)
940 }
941
942 // Update the providers and provider ID
943 a.titleProvider = newTitleProvider
944 a.summarizeProvider = newSummarizeProvider
945 a.summarizeProviderID = string(smallModelProviderCfg.ID)
946 }
947
948 return nil
949}