1package agent
2
3import (
4 "context"
5 "errors"
6 "fmt"
7 "log/slog"
8 "slices"
9 "strings"
10 "sync"
11 "time"
12
13 "github.com/charmbracelet/crush/internal/config"
14 fur "github.com/charmbracelet/crush/internal/fur/provider"
15 "github.com/charmbracelet/crush/internal/history"
16 "github.com/charmbracelet/crush/internal/llm/prompt"
17 "github.com/charmbracelet/crush/internal/llm/provider"
18 "github.com/charmbracelet/crush/internal/llm/tools"
19 "github.com/charmbracelet/crush/internal/log"
20 "github.com/charmbracelet/crush/internal/lsp"
21 "github.com/charmbracelet/crush/internal/message"
22 "github.com/charmbracelet/crush/internal/permission"
23 "github.com/charmbracelet/crush/internal/pubsub"
24 "github.com/charmbracelet/crush/internal/session"
25)
26
// Sentinel errors returned by the agent service; compare with errors.Is.
var (
	// ErrRequestCancelled is the error result of a request whose context was
	// cancelled while it was being processed.
	ErrRequestCancelled = errors.New("request canceled by user")

	// ErrSessionBusy is returned when an operation is requested for a session
	// that already has one in flight.
	ErrSessionBusy = errors.New("session is currently processing another request")
)
32
// AgentEventType tells consumers which fields of an AgentEvent are populated.
type AgentEventType string

// Kinds of AgentEvent emitted by the agent.
const (
	// AgentEventTypeError carries a failure in AgentEvent.Error.
	AgentEventTypeError = AgentEventType("error")
	// AgentEventTypeResponse carries the finished assistant message.
	AgentEventTypeResponse = AgentEventType("response")
	// AgentEventTypeSummarize reports summarization progress.
	AgentEventTypeSummarize = AgentEventType("summarize")
)
40
// AgentEvent is delivered on the channel returned by Run and published on the
// agent's broker. Which fields are meaningful depends on Type.
type AgentEvent struct {
	Type    AgentEventType
	Message message.Message // final assistant message (AgentEventTypeResponse)
	Error   error           // failure cause (AgentEventTypeError)

	// When summarizing
	SessionID string // summarized session, set on the final summarize event
	Progress  string // human-readable status line for summarize events
	Done      bool   // set on completed responses and on the last event of a summarization
}
51
52type Service interface {
53 pubsub.Suscriber[AgentEvent]
54 Model() fur.Model
55 Run(ctx context.Context, sessionID string, content string, attachments ...message.Attachment) (<-chan AgentEvent, error)
56 Cancel(sessionID string)
57 CancelAll()
58 IsSessionBusy(sessionID string) bool
59 IsBusy() bool
60 Summarize(ctx context.Context, sessionID string) error
61 UpdateModel() error
62}
63
// agent implements Service. It holds a provider for the main conversation,
// small-model providers for session titles and summaries, and the cancel
// function of every in-flight operation.
type agent struct {
	// Broker publishes AgentEvents to subscribers.
	*pubsub.Broker[AgentEvent]
	agentCfg config.Agent
	sessions session.Service
	messages message.Service

	// tools are offered to the main provider (already filtered by
	// agentCfg.AllowedTools in NewAgent).
	tools    []tools.BaseTool
	provider provider.Provider
	// providerID is the configured provider behind provider; UpdateModel
	// compares it with the current config to detect a change.
	providerID string

	titleProvider     provider.Provider
	summarizeProvider provider.Provider
	// summarizeProviderID is the configured provider behind both
	// titleProvider and summarizeProvider.
	summarizeProviderID string

	// activeRequests maps a session ID (Run) or "<sessionID>-summarize"
	// (Summarize) to the context.CancelFunc of the in-flight operation.
	activeRequests sync.Map
}

// agentPromptMap selects the system prompt by agent ID; IDs not listed here
// fall back to prompt.PromptDefault.
var agentPromptMap = map[string]prompt.PromptID{
	"coder": prompt.PromptCoder,
	"task":  prompt.PromptTask,
}
85
86func NewAgent(
87 agentCfg config.Agent,
88 // These services are needed in the tools
89 permissions permission.Service,
90 sessions session.Service,
91 messages message.Service,
92 history history.Service,
93 lspClients map[string]*lsp.Client,
94) (Service, error) {
95 ctx := context.Background()
96 cfg := config.Get()
97 otherTools := GetMCPTools(ctx, permissions, cfg)
98 if len(lspClients) > 0 {
99 otherTools = append(otherTools, tools.NewDiagnosticsTool(lspClients))
100 }
101
102 cwd := cfg.WorkingDir()
103 allTools := []tools.BaseTool{
104 tools.NewBashTool(permissions, cwd),
105 tools.NewEditTool(lspClients, permissions, history, cwd),
106 tools.NewFetchTool(permissions, cwd),
107 tools.NewGlobTool(cwd),
108 tools.NewGrepTool(cwd),
109 tools.NewLsTool(cwd),
110 tools.NewSourcegraphTool(),
111 tools.NewViewTool(lspClients, cwd),
112 tools.NewWriteTool(lspClients, permissions, history, cwd),
113 }
114
115 if agentCfg.ID == "coder" {
116 taskAgentCfg := config.Get().Agents["task"]
117 if taskAgentCfg.ID == "" {
118 return nil, fmt.Errorf("task agent not found in config")
119 }
120 taskAgent, err := NewAgent(taskAgentCfg, permissions, sessions, messages, history, lspClients)
121 if err != nil {
122 return nil, fmt.Errorf("failed to create task agent: %w", err)
123 }
124
125 allTools = append(
126 allTools,
127 NewAgentTool(
128 taskAgent,
129 sessions,
130 messages,
131 ),
132 )
133 }
134
135 allTools = append(allTools, otherTools...)
136 providerCfg := config.Get().GetProviderForModel(agentCfg.Model)
137 if providerCfg == nil {
138 return nil, fmt.Errorf("provider for agent %s not found in config", agentCfg.Name)
139 }
140 model := config.Get().GetModelByType(agentCfg.Model)
141
142 if model == nil {
143 return nil, fmt.Errorf("model not found for agent %s", agentCfg.Name)
144 }
145
146 promptID := agentPromptMap[agentCfg.ID]
147 if promptID == "" {
148 promptID = prompt.PromptDefault
149 }
150 opts := []provider.ProviderClientOption{
151 provider.WithModel(agentCfg.Model),
152 provider.WithSystemMessage(prompt.GetPrompt(promptID, providerCfg.ID, config.Get().Options.ContextPaths...)),
153 }
154 agentProvider, err := provider.NewProvider(*providerCfg, opts...)
155 if err != nil {
156 return nil, err
157 }
158
159 smallModelCfg := cfg.Models[config.SelectedModelTypeSmall]
160 var smallModelProviderCfg *config.ProviderConfig
161 if smallModelCfg.Provider == providerCfg.ID {
162 smallModelProviderCfg = providerCfg
163 } else {
164 smallModelProviderCfg = cfg.GetProviderForModel(config.SelectedModelTypeSmall)
165
166 if smallModelProviderCfg.ID == "" {
167 return nil, fmt.Errorf("provider %s not found in config", smallModelCfg.Provider)
168 }
169 }
170 smallModel := cfg.GetModelByType(config.SelectedModelTypeSmall)
171 if smallModel.ID == "" {
172 return nil, fmt.Errorf("model %s not found in provider %s", smallModelCfg.Model, smallModelProviderCfg.ID)
173 }
174
175 titleOpts := []provider.ProviderClientOption{
176 provider.WithModel(config.SelectedModelTypeSmall),
177 provider.WithSystemMessage(prompt.GetPrompt(prompt.PromptTitle, smallModelProviderCfg.ID)),
178 }
179 titleProvider, err := provider.NewProvider(*smallModelProviderCfg, titleOpts...)
180 if err != nil {
181 return nil, err
182 }
183 summarizeOpts := []provider.ProviderClientOption{
184 provider.WithModel(config.SelectedModelTypeSmall),
185 provider.WithSystemMessage(prompt.GetPrompt(prompt.PromptSummarizer, smallModelProviderCfg.ID)),
186 }
187 summarizeProvider, err := provider.NewProvider(*smallModelProviderCfg, summarizeOpts...)
188 if err != nil {
189 return nil, err
190 }
191
192 agentTools := []tools.BaseTool{}
193 if agentCfg.AllowedTools == nil {
194 agentTools = allTools
195 } else {
196 for _, tool := range allTools {
197 if slices.Contains(agentCfg.AllowedTools, tool.Name()) {
198 agentTools = append(agentTools, tool)
199 }
200 }
201 }
202
203 agent := &agent{
204 Broker: pubsub.NewBroker[AgentEvent](),
205 agentCfg: agentCfg,
206 provider: agentProvider,
207 providerID: string(providerCfg.ID),
208 messages: messages,
209 sessions: sessions,
210 tools: agentTools,
211 titleProvider: titleProvider,
212 summarizeProvider: summarizeProvider,
213 summarizeProviderID: string(smallModelProviderCfg.ID),
214 activeRequests: sync.Map{},
215 }
216
217 return agent, nil
218}
219
220func (a *agent) Model() fur.Model {
221 return *config.Get().GetModelByType(a.agentCfg.Model)
222}
223
224func (a *agent) Cancel(sessionID string) {
225 // Cancel regular requests
226 if cancelFunc, exists := a.activeRequests.LoadAndDelete(sessionID); exists {
227 if cancel, ok := cancelFunc.(context.CancelFunc); ok {
228 slog.Info(fmt.Sprintf("Request cancellation initiated for session: %s", sessionID))
229 cancel()
230 }
231 }
232
233 // Also check for summarize requests
234 if cancelFunc, exists := a.activeRequests.LoadAndDelete(sessionID + "-summarize"); exists {
235 if cancel, ok := cancelFunc.(context.CancelFunc); ok {
236 slog.Info(fmt.Sprintf("Summarize cancellation initiated for session: %s", sessionID))
237 cancel()
238 }
239 }
240}
241
242func (a *agent) IsBusy() bool {
243 busy := false
244 a.activeRequests.Range(func(key, value any) bool {
245 if cancelFunc, ok := value.(context.CancelFunc); ok {
246 if cancelFunc != nil {
247 busy = true
248 return false
249 }
250 }
251 return true
252 })
253 return busy
254}
255
256func (a *agent) IsSessionBusy(sessionID string) bool {
257 _, busy := a.activeRequests.Load(sessionID)
258 return busy
259}
260
261func (a *agent) generateTitle(ctx context.Context, sessionID string, content string) error {
262 if content == "" {
263 return nil
264 }
265 if a.titleProvider == nil {
266 return nil
267 }
268 session, err := a.sessions.Get(ctx, sessionID)
269 if err != nil {
270 return err
271 }
272 parts := []message.ContentPart{message.TextContent{
273 Text: fmt.Sprintf("Generate a concise title for the following content:\n\n%s", content),
274 }}
275
276 // Use streaming approach like summarization
277 response := a.titleProvider.StreamResponse(
278 ctx,
279 []message.Message{
280 {
281 Role: message.User,
282 Parts: parts,
283 },
284 },
285 make([]tools.BaseTool, 0),
286 )
287
288 var finalResponse *provider.ProviderResponse
289 for r := range response {
290 if r.Error != nil {
291 return r.Error
292 }
293 finalResponse = r.Response
294 }
295
296 if finalResponse == nil {
297 return fmt.Errorf("no response received from title provider")
298 }
299
300 title := strings.TrimSpace(strings.ReplaceAll(finalResponse.Content, "\n", " "))
301 if title == "" {
302 return nil
303 }
304
305 session.Title = title
306 _, err = a.sessions.Save(ctx, session)
307 return err
308}
309
310func (a *agent) err(err error) AgentEvent {
311 return AgentEvent{
312 Type: AgentEventTypeError,
313 Error: err,
314 }
315}
316
317func (a *agent) Run(ctx context.Context, sessionID string, content string, attachments ...message.Attachment) (<-chan AgentEvent, error) {
318 if !a.Model().SupportsImages && attachments != nil {
319 attachments = nil
320 }
321 events := make(chan AgentEvent)
322 if a.IsSessionBusy(sessionID) {
323 return nil, ErrSessionBusy
324 }
325
326 genCtx, cancel := context.WithCancel(ctx)
327
328 a.activeRequests.Store(sessionID, cancel)
329 go func() {
330 slog.Debug("Request started", "sessionID", sessionID)
331 defer log.RecoverPanic("agent.Run", func() {
332 events <- a.err(fmt.Errorf("panic while running the agent"))
333 })
334 var attachmentParts []message.ContentPart
335 for _, attachment := range attachments {
336 attachmentParts = append(attachmentParts, message.BinaryContent{Path: attachment.FilePath, MIMEType: attachment.MimeType, Data: attachment.Content})
337 }
338 result := a.processGeneration(genCtx, sessionID, content, attachmentParts)
339 if result.Error != nil && !errors.Is(result.Error, ErrRequestCancelled) && !errors.Is(result.Error, context.Canceled) {
340 slog.Error(result.Error.Error())
341 }
342 slog.Debug("Request completed", "sessionID", sessionID)
343 a.activeRequests.Delete(sessionID)
344 cancel()
345 a.Publish(pubsub.CreatedEvent, result)
346 events <- result
347 close(events)
348 }()
349 return events, nil
350}
351
352func (a *agent) processGeneration(ctx context.Context, sessionID, content string, attachmentParts []message.ContentPart) AgentEvent {
353 cfg := config.Get()
354 // List existing messages; if none, start title generation asynchronously.
355 msgs, err := a.messages.List(ctx, sessionID)
356 if err != nil {
357 return a.err(fmt.Errorf("failed to list messages: %w", err))
358 }
359 if len(msgs) == 0 {
360 go func() {
361 defer log.RecoverPanic("agent.Run", func() {
362 slog.Error("panic while generating title")
363 })
364 titleErr := a.generateTitle(context.Background(), sessionID, content)
365 if titleErr != nil && !errors.Is(titleErr, context.Canceled) && !errors.Is(titleErr, context.DeadlineExceeded) {
366 slog.Error(fmt.Sprintf("failed to generate title: %v", titleErr))
367 }
368 }()
369 }
370 session, err := a.sessions.Get(ctx, sessionID)
371 if err != nil {
372 return a.err(fmt.Errorf("failed to get session: %w", err))
373 }
374 if session.SummaryMessageID != "" {
375 summaryMsgInex := -1
376 for i, msg := range msgs {
377 if msg.ID == session.SummaryMessageID {
378 summaryMsgInex = i
379 break
380 }
381 }
382 if summaryMsgInex != -1 {
383 msgs = msgs[summaryMsgInex:]
384 msgs[0].Role = message.User
385 }
386 }
387
388 userMsg, err := a.createUserMessage(ctx, sessionID, content, attachmentParts)
389 if err != nil {
390 return a.err(fmt.Errorf("failed to create user message: %w", err))
391 }
392 // Append the new user message to the conversation history.
393 msgHistory := append(msgs, userMsg)
394
395 for {
396 // Check for cancellation before each iteration
397 select {
398 case <-ctx.Done():
399 return a.err(ctx.Err())
400 default:
401 // Continue processing
402 }
403 agentMessage, toolResults, err := a.streamAndHandleEvents(ctx, sessionID, msgHistory)
404 if err != nil {
405 if errors.Is(err, context.Canceled) {
406 agentMessage.AddFinish(message.FinishReasonCanceled, "Request cancelled", "")
407 a.messages.Update(context.Background(), agentMessage)
408 return a.err(ErrRequestCancelled)
409 }
410 return a.err(fmt.Errorf("failed to process events: %w", err))
411 }
412 if cfg.Options.Debug {
413 slog.Info("Result", "message", agentMessage.FinishReason(), "toolResults", toolResults)
414 }
415 if (agentMessage.FinishReason() == message.FinishReasonToolUse) && toolResults != nil {
416 // We are not done, we need to respond with the tool response
417 msgHistory = append(msgHistory, agentMessage, *toolResults)
418 continue
419 }
420 return AgentEvent{
421 Type: AgentEventTypeResponse,
422 Message: agentMessage,
423 Done: true,
424 }
425 }
426}
427
428func (a *agent) createUserMessage(ctx context.Context, sessionID, content string, attachmentParts []message.ContentPart) (message.Message, error) {
429 parts := []message.ContentPart{message.TextContent{Text: content}}
430 parts = append(parts, attachmentParts...)
431 return a.messages.Create(ctx, sessionID, message.CreateMessageParams{
432 Role: message.User,
433 Parts: parts,
434 })
435}
436
437func (a *agent) streamAndHandleEvents(ctx context.Context, sessionID string, msgHistory []message.Message) (message.Message, *message.Message, error) {
438 ctx = context.WithValue(ctx, tools.SessionIDContextKey, sessionID)
439 eventChan := a.provider.StreamResponse(ctx, msgHistory, a.tools)
440
441 assistantMsg, err := a.messages.Create(ctx, sessionID, message.CreateMessageParams{
442 Role: message.Assistant,
443 Parts: []message.ContentPart{},
444 Model: a.Model().ID,
445 Provider: a.providerID,
446 })
447 if err != nil {
448 return assistantMsg, nil, fmt.Errorf("failed to create assistant message: %w", err)
449 }
450
451 // Add the session and message ID into the context if needed by tools.
452 ctx = context.WithValue(ctx, tools.MessageIDContextKey, assistantMsg.ID)
453
454 // Process each event in the stream.
455 for event := range eventChan {
456 if processErr := a.processEvent(ctx, sessionID, &assistantMsg, event); processErr != nil {
457 if errors.Is(processErr, context.Canceled) {
458 a.finishMessage(context.Background(), &assistantMsg, message.FinishReasonCanceled, "Request cancelled", "")
459 } else {
460 a.finishMessage(ctx, &assistantMsg, message.FinishReasonError, "API Error", processErr.Error())
461 }
462 return assistantMsg, nil, processErr
463 }
464 if ctx.Err() != nil {
465 a.finishMessage(context.Background(), &assistantMsg, message.FinishReasonCanceled, "Request cancelled", "")
466 return assistantMsg, nil, ctx.Err()
467 }
468 }
469
470 toolResults := make([]message.ToolResult, len(assistantMsg.ToolCalls()))
471 toolCalls := assistantMsg.ToolCalls()
472 for i, toolCall := range toolCalls {
473 select {
474 case <-ctx.Done():
475 a.finishMessage(context.Background(), &assistantMsg, message.FinishReasonCanceled, "Request cancelled", "")
476 // Make all future tool calls cancelled
477 for j := i; j < len(toolCalls); j++ {
478 toolResults[j] = message.ToolResult{
479 ToolCallID: toolCalls[j].ID,
480 Content: "Tool execution canceled by user",
481 IsError: true,
482 }
483 }
484 goto out
485 default:
486 // Continue processing
487 var tool tools.BaseTool
488 for _, availableTool := range a.tools {
489 if availableTool.Info().Name == toolCall.Name {
490 tool = availableTool
491 break
492 }
493 }
494
495 // Tool not found
496 if tool == nil {
497 toolResults[i] = message.ToolResult{
498 ToolCallID: toolCall.ID,
499 Content: fmt.Sprintf("Tool not found: %s", toolCall.Name),
500 IsError: true,
501 }
502 continue
503 }
504
505 // Run tool in goroutine to allow cancellation
506 type toolExecResult struct {
507 response tools.ToolResponse
508 err error
509 }
510 resultChan := make(chan toolExecResult, 1)
511
512 go func() {
513 response, err := tool.Run(ctx, tools.ToolCall{
514 ID: toolCall.ID,
515 Name: toolCall.Name,
516 Input: toolCall.Input,
517 })
518 resultChan <- toolExecResult{response: response, err: err}
519 }()
520
521 var toolResponse tools.ToolResponse
522 var toolErr error
523
524 select {
525 case <-ctx.Done():
526 a.finishMessage(context.Background(), &assistantMsg, message.FinishReasonCanceled, "Request cancelled", "")
527 // Mark remaining tool calls as cancelled
528 for j := i; j < len(toolCalls); j++ {
529 toolResults[j] = message.ToolResult{
530 ToolCallID: toolCalls[j].ID,
531 Content: "Tool execution canceled by user",
532 IsError: true,
533 }
534 }
535 goto out
536 case result := <-resultChan:
537 toolResponse = result.response
538 toolErr = result.err
539 }
540
541 if toolErr != nil {
542 slog.Error("Tool execution error", "toolCall", toolCall.ID, "error", toolErr)
543 if errors.Is(toolErr, permission.ErrorPermissionDenied) {
544 toolResults[i] = message.ToolResult{
545 ToolCallID: toolCall.ID,
546 Content: "Permission denied",
547 IsError: true,
548 }
549 for j := i + 1; j < len(toolCalls); j++ {
550 toolResults[j] = message.ToolResult{
551 ToolCallID: toolCalls[j].ID,
552 Content: "Tool execution canceled by user",
553 IsError: true,
554 }
555 }
556 a.finishMessage(ctx, &assistantMsg, message.FinishReasonPermissionDenied, "Permission denied", "")
557 break
558 }
559 }
560 toolResults[i] = message.ToolResult{
561 ToolCallID: toolCall.ID,
562 Content: toolResponse.Content,
563 Metadata: toolResponse.Metadata,
564 IsError: toolResponse.IsError,
565 }
566 }
567 }
568out:
569 if len(toolResults) == 0 {
570 return assistantMsg, nil, nil
571 }
572 parts := make([]message.ContentPart, 0)
573 for _, tr := range toolResults {
574 parts = append(parts, tr)
575 }
576 msg, err := a.messages.Create(context.Background(), assistantMsg.SessionID, message.CreateMessageParams{
577 Role: message.Tool,
578 Parts: parts,
579 Provider: a.providerID,
580 })
581 if err != nil {
582 return assistantMsg, nil, fmt.Errorf("failed to create cancelled tool message: %w", err)
583 }
584
585 return assistantMsg, &msg, err
586}
587
588func (a *agent) finishMessage(ctx context.Context, msg *message.Message, finishReason message.FinishReason, message, details string) {
589 msg.AddFinish(finishReason, message, details)
590 _ = a.messages.Update(ctx, *msg)
591}
592
593func (a *agent) processEvent(ctx context.Context, sessionID string, assistantMsg *message.Message, event provider.ProviderEvent) error {
594 select {
595 case <-ctx.Done():
596 return ctx.Err()
597 default:
598 // Continue processing.
599 }
600
601 switch event.Type {
602 case provider.EventThinkingDelta:
603 assistantMsg.AppendReasoningContent(event.Content)
604 return a.messages.Update(ctx, *assistantMsg)
605 case provider.EventContentDelta:
606 assistantMsg.AppendContent(event.Content)
607 return a.messages.Update(ctx, *assistantMsg)
608 case provider.EventToolUseStart:
609 slog.Info("Tool call started", "toolCall", event.ToolCall)
610 assistantMsg.AddToolCall(*event.ToolCall)
611 return a.messages.Update(ctx, *assistantMsg)
612 case provider.EventToolUseDelta:
613 assistantMsg.AppendToolCallInput(event.ToolCall.ID, event.ToolCall.Input)
614 return a.messages.Update(ctx, *assistantMsg)
615 case provider.EventToolUseStop:
616 slog.Info("Finished tool call", "toolCall", event.ToolCall)
617 assistantMsg.FinishToolCall(event.ToolCall.ID)
618 return a.messages.Update(ctx, *assistantMsg)
619 case provider.EventError:
620 return event.Error
621 case provider.EventComplete:
622 assistantMsg.SetToolCalls(event.Response.ToolCalls)
623 assistantMsg.AddFinish(event.Response.FinishReason, "", "")
624 if err := a.messages.Update(ctx, *assistantMsg); err != nil {
625 return fmt.Errorf("failed to update message: %w", err)
626 }
627 return a.TrackUsage(ctx, sessionID, a.Model(), event.Response.Usage)
628 }
629
630 return nil
631}
632
633func (a *agent) TrackUsage(ctx context.Context, sessionID string, model fur.Model, usage provider.TokenUsage) error {
634 sess, err := a.sessions.Get(ctx, sessionID)
635 if err != nil {
636 return fmt.Errorf("failed to get session: %w", err)
637 }
638
639 cost := model.CostPer1MInCached/1e6*float64(usage.CacheCreationTokens) +
640 model.CostPer1MOutCached/1e6*float64(usage.CacheReadTokens) +
641 model.CostPer1MIn/1e6*float64(usage.InputTokens) +
642 model.CostPer1MOut/1e6*float64(usage.OutputTokens)
643
644 sess.Cost += cost
645 sess.CompletionTokens = usage.OutputTokens + usage.CacheReadTokens
646 sess.PromptTokens = usage.InputTokens + usage.CacheCreationTokens
647
648 _, err = a.sessions.Save(ctx, sess)
649 if err != nil {
650 return fmt.Errorf("failed to save session: %w", err)
651 }
652 return nil
653}
654
655func (a *agent) Summarize(ctx context.Context, sessionID string) error {
656 if a.summarizeProvider == nil {
657 return fmt.Errorf("summarize provider not available")
658 }
659
660 // Check if session is busy
661 if a.IsSessionBusy(sessionID) {
662 return ErrSessionBusy
663 }
664
665 // Create a new context with cancellation
666 summarizeCtx, cancel := context.WithCancel(ctx)
667
668 // Store the cancel function in activeRequests to allow cancellation
669 a.activeRequests.Store(sessionID+"-summarize", cancel)
670
671 go func() {
672 defer a.activeRequests.Delete(sessionID + "-summarize")
673 defer cancel()
674 event := AgentEvent{
675 Type: AgentEventTypeSummarize,
676 Progress: "Starting summarization...",
677 }
678
679 a.Publish(pubsub.CreatedEvent, event)
680 // Get all messages from the session
681 msgs, err := a.messages.List(summarizeCtx, sessionID)
682 if err != nil {
683 event = AgentEvent{
684 Type: AgentEventTypeError,
685 Error: fmt.Errorf("failed to list messages: %w", err),
686 Done: true,
687 }
688 a.Publish(pubsub.CreatedEvent, event)
689 return
690 }
691 summarizeCtx = context.WithValue(summarizeCtx, tools.SessionIDContextKey, sessionID)
692
693 if len(msgs) == 0 {
694 event = AgentEvent{
695 Type: AgentEventTypeError,
696 Error: fmt.Errorf("no messages to summarize"),
697 Done: true,
698 }
699 a.Publish(pubsub.CreatedEvent, event)
700 return
701 }
702
703 event = AgentEvent{
704 Type: AgentEventTypeSummarize,
705 Progress: "Analyzing conversation...",
706 }
707 a.Publish(pubsub.CreatedEvent, event)
708
709 // Add a system message to guide the summarization
710 summarizePrompt := "Provide a detailed but concise summary of our conversation above. Focus on information that would be helpful for continuing the conversation, including what we did, what we're doing, which files we're working on, and what we're going to do next."
711
712 // Create a new message with the summarize prompt
713 promptMsg := message.Message{
714 Role: message.User,
715 Parts: []message.ContentPart{message.TextContent{Text: summarizePrompt}},
716 }
717
718 // Append the prompt to the messages
719 msgsWithPrompt := append(msgs, promptMsg)
720
721 event = AgentEvent{
722 Type: AgentEventTypeSummarize,
723 Progress: "Generating summary...",
724 }
725
726 a.Publish(pubsub.CreatedEvent, event)
727
728 // Send the messages to the summarize provider
729 response := a.summarizeProvider.StreamResponse(
730 summarizeCtx,
731 msgsWithPrompt,
732 make([]tools.BaseTool, 0),
733 )
734 var finalResponse *provider.ProviderResponse
735 for r := range response {
736 if r.Error != nil {
737 event = AgentEvent{
738 Type: AgentEventTypeError,
739 Error: fmt.Errorf("failed to summarize: %w", err),
740 Done: true,
741 }
742 a.Publish(pubsub.CreatedEvent, event)
743 return
744 }
745 finalResponse = r.Response
746 }
747
748 summary := strings.TrimSpace(finalResponse.Content)
749 if summary == "" {
750 event = AgentEvent{
751 Type: AgentEventTypeError,
752 Error: fmt.Errorf("empty summary returned"),
753 Done: true,
754 }
755 a.Publish(pubsub.CreatedEvent, event)
756 return
757 }
758 event = AgentEvent{
759 Type: AgentEventTypeSummarize,
760 Progress: "Creating new session...",
761 }
762
763 a.Publish(pubsub.CreatedEvent, event)
764 oldSession, err := a.sessions.Get(summarizeCtx, sessionID)
765 if err != nil {
766 event = AgentEvent{
767 Type: AgentEventTypeError,
768 Error: fmt.Errorf("failed to get session: %w", err),
769 Done: true,
770 }
771
772 a.Publish(pubsub.CreatedEvent, event)
773 return
774 }
775 // Create a message in the new session with the summary
776 msg, err := a.messages.Create(summarizeCtx, oldSession.ID, message.CreateMessageParams{
777 Role: message.Assistant,
778 Parts: []message.ContentPart{
779 message.TextContent{Text: summary},
780 message.Finish{
781 Reason: message.FinishReasonEndTurn,
782 Time: time.Now().Unix(),
783 },
784 },
785 Model: a.summarizeProvider.Model().ID,
786 Provider: a.summarizeProviderID,
787 })
788 if err != nil {
789 event = AgentEvent{
790 Type: AgentEventTypeError,
791 Error: fmt.Errorf("failed to create summary message: %w", err),
792 Done: true,
793 }
794
795 a.Publish(pubsub.CreatedEvent, event)
796 return
797 }
798 oldSession.SummaryMessageID = msg.ID
799 oldSession.CompletionTokens = finalResponse.Usage.OutputTokens
800 oldSession.PromptTokens = 0
801 model := a.summarizeProvider.Model()
802 usage := finalResponse.Usage
803 cost := model.CostPer1MInCached/1e6*float64(usage.CacheCreationTokens) +
804 model.CostPer1MOutCached/1e6*float64(usage.CacheReadTokens) +
805 model.CostPer1MIn/1e6*float64(usage.InputTokens) +
806 model.CostPer1MOut/1e6*float64(usage.OutputTokens)
807 oldSession.Cost += cost
808 _, err = a.sessions.Save(summarizeCtx, oldSession)
809 if err != nil {
810 event = AgentEvent{
811 Type: AgentEventTypeError,
812 Error: fmt.Errorf("failed to save session: %w", err),
813 Done: true,
814 }
815 a.Publish(pubsub.CreatedEvent, event)
816 }
817
818 event = AgentEvent{
819 Type: AgentEventTypeSummarize,
820 SessionID: oldSession.ID,
821 Progress: "Summary complete",
822 Done: true,
823 }
824 a.Publish(pubsub.CreatedEvent, event)
825 // Send final success event with the new session ID
826 }()
827
828 return nil
829}
830
831func (a *agent) CancelAll() {
832 if !a.IsBusy() {
833 return
834 }
835 a.activeRequests.Range(func(key, value any) bool {
836 a.Cancel(key.(string)) // key is sessionID
837 return true
838 })
839
840 timeout := time.After(5 * time.Second)
841 for a.IsBusy() {
842 select {
843 case <-timeout:
844 return
845 default:
846 time.Sleep(200 * time.Millisecond)
847 }
848 }
849}
850
851func (a *agent) UpdateModel() error {
852 cfg := config.Get()
853
854 // Get current provider configuration
855 currentProviderCfg := cfg.GetProviderForModel(a.agentCfg.Model)
856 if currentProviderCfg.ID == "" {
857 return fmt.Errorf("provider for agent %s not found in config", a.agentCfg.Name)
858 }
859
860 // Check if provider has changed
861 if string(currentProviderCfg.ID) != a.providerID {
862 // Provider changed, need to recreate the main provider
863 model := cfg.GetModelByType(a.agentCfg.Model)
864 if model.ID == "" {
865 return fmt.Errorf("model not found for agent %s", a.agentCfg.Name)
866 }
867
868 promptID := agentPromptMap[a.agentCfg.ID]
869 if promptID == "" {
870 promptID = prompt.PromptDefault
871 }
872
873 opts := []provider.ProviderClientOption{
874 provider.WithModel(a.agentCfg.Model),
875 provider.WithSystemMessage(prompt.GetPrompt(promptID, currentProviderCfg.ID, cfg.Options.ContextPaths...)),
876 }
877
878 newProvider, err := provider.NewProvider(*currentProviderCfg, opts...)
879 if err != nil {
880 return fmt.Errorf("failed to create new provider: %w", err)
881 }
882
883 // Update the provider and provider ID
884 a.provider = newProvider
885 a.providerID = string(currentProviderCfg.ID)
886 }
887
888 // Check if small model provider has changed (affects title and summarize providers)
889 smallModelCfg := cfg.Models[config.SelectedModelTypeSmall]
890 var smallModelProviderCfg config.ProviderConfig
891
892 for _, p := range cfg.Providers {
893 if p.ID == smallModelCfg.Provider {
894 smallModelProviderCfg = p
895 break
896 }
897 }
898
899 if smallModelProviderCfg.ID == "" {
900 return fmt.Errorf("provider %s not found in config", smallModelCfg.Provider)
901 }
902
903 // Check if summarize provider has changed
904 if string(smallModelProviderCfg.ID) != a.summarizeProviderID {
905 smallModel := cfg.GetModelByType(config.SelectedModelTypeSmall)
906 if smallModel == nil {
907 return fmt.Errorf("model %s not found in provider %s", smallModelCfg.Model, smallModelProviderCfg.ID)
908 }
909
910 // Recreate title provider
911 titleOpts := []provider.ProviderClientOption{
912 provider.WithModel(config.SelectedModelTypeSmall),
913 provider.WithSystemMessage(prompt.GetPrompt(prompt.PromptTitle, smallModelProviderCfg.ID)),
914 // We want the title to be short, so we limit the max tokens
915 provider.WithMaxTokens(40),
916 }
917 newTitleProvider, err := provider.NewProvider(smallModelProviderCfg, titleOpts...)
918 if err != nil {
919 return fmt.Errorf("failed to create new title provider: %w", err)
920 }
921
922 // Recreate summarize provider
923 summarizeOpts := []provider.ProviderClientOption{
924 provider.WithModel(config.SelectedModelTypeSmall),
925 provider.WithSystemMessage(prompt.GetPrompt(prompt.PromptSummarizer, smallModelProviderCfg.ID)),
926 }
927 newSummarizeProvider, err := provider.NewProvider(smallModelProviderCfg, summarizeOpts...)
928 if err != nil {
929 return fmt.Errorf("failed to create new summarize provider: %w", err)
930 }
931
932 // Update the providers and provider ID
933 a.titleProvider = newTitleProvider
934 a.summarizeProvider = newSummarizeProvider
935 a.summarizeProviderID = string(smallModelProviderCfg.ID)
936 }
937
938 return nil
939}