1package agent
2
3import (
4 "context"
5 "errors"
6 "fmt"
7 "log/slog"
8 "slices"
9 "strings"
10 "sync"
11 "time"
12
13 "github.com/charmbracelet/crush/internal/config"
14 fur "github.com/charmbracelet/crush/internal/fur/provider"
15 "github.com/charmbracelet/crush/internal/history"
16 "github.com/charmbracelet/crush/internal/llm/prompt"
17 "github.com/charmbracelet/crush/internal/llm/provider"
18 "github.com/charmbracelet/crush/internal/llm/tools"
19 "github.com/charmbracelet/crush/internal/log"
20 "github.com/charmbracelet/crush/internal/lsp"
21 "github.com/charmbracelet/crush/internal/message"
22 "github.com/charmbracelet/crush/internal/permission"
23 "github.com/charmbracelet/crush/internal/pubsub"
24 "github.com/charmbracelet/crush/internal/session"
25)
26
// ErrRequestCancelled is reported (inside an error AgentEvent) when a Run is
// stopped because its context was cancelled.
var ErrRequestCancelled = errors.New("request canceled by user")

// ErrSessionBusy is returned when a session already has a request in flight.
var ErrSessionBusy = errors.New("session is currently processing another request")
32
// AgentEventType identifies which kind of AgentEvent the agent emitted.
type AgentEventType string

const (
	// AgentEventTypeError marks an event whose Error field carries a failure.
	AgentEventTypeError AgentEventType = "error"
	// AgentEventTypeResponse marks the final event of a successful Run; its
	// Message field holds the last assistant message.
	AgentEventTypeResponse AgentEventType = "response"
	// AgentEventTypeSummarize marks a progress or completion event published
	// by Summarize.
	AgentEventTypeSummarize AgentEventType = "summarize"
)
40
// AgentEvent is emitted by the agent, both on the channel returned by Run and
// through the embedded pubsub broker.
type AgentEvent struct {
	Type    AgentEventType
	Message message.Message
	Error   error

	// Set by Summarize: SessionID of the summarized session and a
	// human-readable Progress message. Done is set on the final successful
	// response of Run and on the terminal events (success or error) of
	// Summarize.
	SessionID string
	Progress  string
	Done      bool
}
51
// Service is the public interface of an agent. It runs prompts against a
// session, tracks and cancels in-flight work, summarizes sessions, and
// publishes AgentEvents to subscribers.
type Service interface {
	pubsub.Suscriber[AgentEvent]
	// Model returns the model the agent is configured to use.
	Model() fur.Model
	// Run processes content (and optional attachments) for sessionID in the
	// background and returns a channel that receives the final event.
	Run(ctx context.Context, sessionID string, content string, attachments ...message.Attachment) (<-chan AgentEvent, error)
	// Cancel aborts any in-flight request or summarization for sessionID.
	Cancel(sessionID string)
	// CancelAll cancels all in-flight work and waits a bounded time for it to stop.
	CancelAll()
	// IsSessionBusy reports whether a Run request is in flight for sessionID.
	IsSessionBusy(sessionID string) bool
	// IsBusy reports whether any request is in flight.
	IsBusy() bool
	// Summarize starts an asynchronous summarization of sessionID; progress
	// is reported through published events.
	Summarize(ctx context.Context, sessionID string) error
	// UpdateModel rebuilds providers after a configuration change.
	UpdateModel() error
}
63
// agent is the Service implementation created by NewAgent.
type agent struct {
	*pubsub.Broker[AgentEvent]
	// agentCfg is the configuration this agent was created from.
	agentCfg config.Agent
	sessions session.Service
	messages message.Service

	// tools are the tools offered to the main provider (already filtered by
	// agentCfg.AllowedTools).
	tools []tools.BaseTool
	// provider serves agentCfg.Model; providerID is the ID of its config.
	provider   provider.Provider
	providerID string

	// titleProvider and summarizeProvider both use the small model;
	// summarizeProviderID is the ID of that model's provider config.
	titleProvider       provider.Provider
	summarizeProvider   provider.Provider
	summarizeProviderID string

	// activeRequests maps a session ID (Run) or session ID + "-summarize"
	// (Summarize) to the context.CancelFunc of the in-flight work.
	activeRequests sync.Map
}
80
81var agentPromptMap = map[string]prompt.PromptID{
82 "coder": prompt.PromptCoder,
83 "task": prompt.PromptTask,
84}
85
86func NewAgent(
87 agentCfg config.Agent,
88 // These services are needed in the tools
89 permissions permission.Service,
90 sessions session.Service,
91 messages message.Service,
92 history history.Service,
93 lspClients map[string]*lsp.Client,
94) (Service, error) {
95 ctx := context.Background()
96 cfg := config.Get()
97 otherTools := GetMCPTools(ctx, permissions, cfg)
98 if len(lspClients) > 0 {
99 otherTools = append(otherTools, tools.NewDiagnosticsTool(lspClients))
100 }
101
102 cwd := cfg.WorkingDir()
103 allTools := []tools.BaseTool{
104 tools.NewBashTool(permissions, cwd),
105 tools.NewEditTool(lspClients, permissions, history, cwd),
106 tools.NewFetchTool(permissions, cwd),
107 tools.NewGlobTool(cwd),
108 tools.NewGrepTool(cwd),
109 tools.NewLsTool(cwd),
110 tools.NewSourcegraphTool(),
111 tools.NewViewTool(lspClients, cwd),
112 tools.NewWriteTool(lspClients, permissions, history, cwd),
113 }
114
115 if agentCfg.ID == "coder" {
116 taskAgentCfg := config.Get().Agents["task"]
117 if taskAgentCfg.ID == "" {
118 return nil, fmt.Errorf("task agent not found in config")
119 }
120 taskAgent, err := NewAgent(taskAgentCfg, permissions, sessions, messages, history, lspClients)
121 if err != nil {
122 return nil, fmt.Errorf("failed to create task agent: %w", err)
123 }
124
125 allTools = append(
126 allTools,
127 NewAgentTool(
128 taskAgent,
129 sessions,
130 messages,
131 ),
132 )
133 }
134
135 allTools = append(allTools, otherTools...)
136 providerCfg := config.Get().GetProviderForModel(agentCfg.Model)
137 if providerCfg == nil {
138 return nil, fmt.Errorf("provider for agent %s not found in config", agentCfg.Name)
139 }
140 model := config.Get().GetModelByType(agentCfg.Model)
141
142 if model == nil {
143 return nil, fmt.Errorf("model not found for agent %s", agentCfg.Name)
144 }
145
146 promptID := agentPromptMap[agentCfg.ID]
147 if promptID == "" {
148 promptID = prompt.PromptDefault
149 }
150 opts := []provider.ProviderClientOption{
151 provider.WithModel(agentCfg.Model),
152 provider.WithSystemMessage(prompt.GetPrompt(promptID, providerCfg.ID, config.Get().Options.ContextPaths...)),
153 }
154 agentProvider, err := provider.NewProvider(*providerCfg, opts...)
155 if err != nil {
156 return nil, err
157 }
158
159 smallModelCfg := cfg.Models[config.SelectedModelTypeSmall]
160 var smallModelProviderCfg *config.ProviderConfig
161 if smallModelCfg.Provider == providerCfg.ID {
162 smallModelProviderCfg = providerCfg
163 } else {
164 smallModelProviderCfg = cfg.GetProviderForModel(config.SelectedModelTypeSmall)
165
166 if smallModelProviderCfg.ID == "" {
167 return nil, fmt.Errorf("provider %s not found in config", smallModelCfg.Provider)
168 }
169 }
170 smallModel := cfg.GetModelByType(config.SelectedModelTypeSmall)
171 if smallModel.ID == "" {
172 return nil, fmt.Errorf("model %s not found in provider %s", smallModelCfg.Model, smallModelProviderCfg.ID)
173 }
174
175 titleOpts := []provider.ProviderClientOption{
176 provider.WithModel(config.SelectedModelTypeSmall),
177 provider.WithSystemMessage(prompt.GetPrompt(prompt.PromptTitle, smallModelProviderCfg.ID)),
178 }
179 titleProvider, err := provider.NewProvider(*smallModelProviderCfg, titleOpts...)
180 if err != nil {
181 return nil, err
182 }
183 summarizeOpts := []provider.ProviderClientOption{
184 provider.WithModel(config.SelectedModelTypeSmall),
185 provider.WithSystemMessage(prompt.GetPrompt(prompt.PromptSummarizer, smallModelProviderCfg.ID)),
186 }
187 summarizeProvider, err := provider.NewProvider(*smallModelProviderCfg, summarizeOpts...)
188 if err != nil {
189 return nil, err
190 }
191
192 agentTools := []tools.BaseTool{}
193 if agentCfg.AllowedTools == nil {
194 agentTools = allTools
195 } else {
196 for _, tool := range allTools {
197 if slices.Contains(agentCfg.AllowedTools, tool.Name()) {
198 agentTools = append(agentTools, tool)
199 }
200 }
201 }
202
203 agent := &agent{
204 Broker: pubsub.NewBroker[AgentEvent](),
205 agentCfg: agentCfg,
206 provider: agentProvider,
207 providerID: string(providerCfg.ID),
208 messages: messages,
209 sessions: sessions,
210 tools: agentTools,
211 titleProvider: titleProvider,
212 summarizeProvider: summarizeProvider,
213 summarizeProviderID: string(smallModelProviderCfg.ID),
214 activeRequests: sync.Map{},
215 }
216
217 return agent, nil
218}
219
220func (a *agent) Model() fur.Model {
221 return *config.Get().GetModelByType(a.agentCfg.Model)
222}
223
224func (a *agent) Cancel(sessionID string) {
225 // Cancel regular requests
226 if cancelFunc, exists := a.activeRequests.LoadAndDelete(sessionID); exists {
227 if cancel, ok := cancelFunc.(context.CancelFunc); ok {
228 slog.Info(fmt.Sprintf("Request cancellation initiated for session: %s", sessionID))
229 cancel()
230 }
231 }
232
233 // Also check for summarize requests
234 if cancelFunc, exists := a.activeRequests.LoadAndDelete(sessionID + "-summarize"); exists {
235 if cancel, ok := cancelFunc.(context.CancelFunc); ok {
236 slog.Info(fmt.Sprintf("Summarize cancellation initiated for session: %s", sessionID))
237 cancel()
238 }
239 }
240}
241
242func (a *agent) IsBusy() bool {
243 busy := false
244 a.activeRequests.Range(func(key, value any) bool {
245 if cancelFunc, ok := value.(context.CancelFunc); ok {
246 if cancelFunc != nil {
247 busy = true
248 return false
249 }
250 }
251 return true
252 })
253 return busy
254}
255
256func (a *agent) IsSessionBusy(sessionID string) bool {
257 _, busy := a.activeRequests.Load(sessionID)
258 return busy
259}
260
261func (a *agent) generateTitle(ctx context.Context, sessionID string, content string) error {
262 if content == "" {
263 return nil
264 }
265 if a.titleProvider == nil {
266 return nil
267 }
268 session, err := a.sessions.Get(ctx, sessionID)
269 if err != nil {
270 return err
271 }
272 parts := []message.ContentPart{message.TextContent{
273 Text: fmt.Sprintf("Generate a concise title for the following content:\n\n%s", content),
274 }}
275
276 // Use streaming approach like summarization
277 response := a.titleProvider.StreamResponse(
278 ctx,
279 []message.Message{
280 {
281 Role: message.User,
282 Parts: parts,
283 },
284 },
285 make([]tools.BaseTool, 0),
286 )
287
288 var finalResponse *provider.ProviderResponse
289 for r := range response {
290 if r.Error != nil {
291 return r.Error
292 }
293 finalResponse = r.Response
294 }
295
296 if finalResponse == nil {
297 return fmt.Errorf("no response received from title provider")
298 }
299
300 title := strings.TrimSpace(strings.ReplaceAll(finalResponse.Content, "\n", " "))
301 if title == "" {
302 return nil
303 }
304
305 session.Title = title
306 _, err = a.sessions.Save(ctx, session)
307 return err
308}
309
310func (a *agent) err(err error) AgentEvent {
311 return AgentEvent{
312 Type: AgentEventTypeError,
313 Error: err,
314 }
315}
316
317func (a *agent) Run(ctx context.Context, sessionID string, content string, attachments ...message.Attachment) (<-chan AgentEvent, error) {
318 if !a.Model().SupportsImages && attachments != nil {
319 attachments = nil
320 }
321 events := make(chan AgentEvent)
322 if a.IsSessionBusy(sessionID) {
323 return nil, ErrSessionBusy
324 }
325
326 genCtx, cancel := context.WithCancel(ctx)
327
328 a.activeRequests.Store(sessionID, cancel)
329 go func() {
330 slog.Debug("Request started", "sessionID", sessionID)
331 defer log.RecoverPanic("agent.Run", func() {
332 events <- a.err(fmt.Errorf("panic while running the agent"))
333 })
334 var attachmentParts []message.ContentPart
335 for _, attachment := range attachments {
336 attachmentParts = append(attachmentParts, message.BinaryContent{Path: attachment.FilePath, MIMEType: attachment.MimeType, Data: attachment.Content})
337 }
338 result := a.processGeneration(genCtx, sessionID, content, attachmentParts)
339 if result.Error != nil && !errors.Is(result.Error, ErrRequestCancelled) && !errors.Is(result.Error, context.Canceled) {
340 slog.Error(result.Error.Error())
341 }
342 slog.Debug("Request completed", "sessionID", sessionID)
343 a.activeRequests.Delete(sessionID)
344 cancel()
345 a.Publish(pubsub.CreatedEvent, result)
346 events <- result
347 close(events)
348 }()
349 return events, nil
350}
351
352func (a *agent) processGeneration(ctx context.Context, sessionID, content string, attachmentParts []message.ContentPart) AgentEvent {
353 cfg := config.Get()
354 // List existing messages; if none, start title generation asynchronously.
355 msgs, err := a.messages.List(ctx, sessionID)
356 if err != nil {
357 return a.err(fmt.Errorf("failed to list messages: %w", err))
358 }
359 if len(msgs) == 0 {
360 go func() {
361 defer log.RecoverPanic("agent.Run", func() {
362 slog.Error("panic while generating title")
363 })
364 titleErr := a.generateTitle(context.Background(), sessionID, content)
365 if titleErr != nil && !errors.Is(titleErr, context.Canceled) && !errors.Is(titleErr, context.DeadlineExceeded) {
366 slog.Error(fmt.Sprintf("failed to generate title: %v", titleErr))
367 }
368 }()
369 }
370 session, err := a.sessions.Get(ctx, sessionID)
371 if err != nil {
372 return a.err(fmt.Errorf("failed to get session: %w", err))
373 }
374 if session.SummaryMessageID != "" {
375 summaryMsgInex := -1
376 for i, msg := range msgs {
377 if msg.ID == session.SummaryMessageID {
378 summaryMsgInex = i
379 break
380 }
381 }
382 if summaryMsgInex != -1 {
383 msgs = msgs[summaryMsgInex:]
384 msgs[0].Role = message.User
385 }
386 }
387
388 userMsg, err := a.createUserMessage(ctx, sessionID, content, attachmentParts)
389 if err != nil {
390 return a.err(fmt.Errorf("failed to create user message: %w", err))
391 }
392 // Append the new user message to the conversation history.
393 msgHistory := append(msgs, userMsg)
394
395 for {
396 // Check for cancellation before each iteration
397 select {
398 case <-ctx.Done():
399 return a.err(ctx.Err())
400 default:
401 // Continue processing
402 }
403 agentMessage, toolResults, err := a.streamAndHandleEvents(ctx, sessionID, msgHistory)
404 if err != nil {
405 if errors.Is(err, context.Canceled) {
406 agentMessage.AddFinish(message.FinishReasonCanceled, "Request cancelled", "")
407 a.messages.Update(context.Background(), agentMessage)
408 return a.err(ErrRequestCancelled)
409 }
410 return a.err(fmt.Errorf("failed to process events: %w", err))
411 }
412 if cfg.Options.Debug {
413 slog.Info("Result", "message", agentMessage.FinishReason(), "toolResults", toolResults)
414 }
415 if (agentMessage.FinishReason() == message.FinishReasonToolUse) && toolResults != nil {
416 // We are not done, we need to respond with the tool response
417 msgHistory = append(msgHistory, agentMessage, *toolResults)
418 continue
419 }
420 return AgentEvent{
421 Type: AgentEventTypeResponse,
422 Message: agentMessage,
423 Done: true,
424 }
425 }
426}
427
428func (a *agent) createUserMessage(ctx context.Context, sessionID, content string, attachmentParts []message.ContentPart) (message.Message, error) {
429 parts := []message.ContentPart{message.TextContent{Text: content}}
430 parts = append(parts, attachmentParts...)
431 return a.messages.Create(ctx, sessionID, message.CreateMessageParams{
432 Role: message.User,
433 Parts: parts,
434 })
435}
436
437func (a *agent) streamAndHandleEvents(ctx context.Context, sessionID string, msgHistory []message.Message) (message.Message, *message.Message, error) {
438 ctx = context.WithValue(ctx, tools.SessionIDContextKey, sessionID)
439 eventChan := a.provider.StreamResponse(ctx, msgHistory, a.tools)
440
441 assistantMsg, err := a.messages.Create(ctx, sessionID, message.CreateMessageParams{
442 Role: message.Assistant,
443 Parts: []message.ContentPart{},
444 Model: a.Model().ID,
445 Provider: a.providerID,
446 })
447 if err != nil {
448 return assistantMsg, nil, fmt.Errorf("failed to create assistant message: %w", err)
449 }
450
451 // Add the session and message ID into the context if needed by tools.
452 ctx = context.WithValue(ctx, tools.MessageIDContextKey, assistantMsg.ID)
453
454 // Process each event in the stream.
455 for event := range eventChan {
456 if processErr := a.processEvent(ctx, sessionID, &assistantMsg, event); processErr != nil {
457 if errors.Is(processErr, context.Canceled) {
458 a.finishMessage(context.Background(), &assistantMsg, message.FinishReasonCanceled, "Request cancelled", "")
459 } else {
460 a.finishMessage(ctx, &assistantMsg, message.FinishReasonError, "API Error", processErr.Error())
461 }
462 return assistantMsg, nil, processErr
463 }
464 if ctx.Err() != nil {
465 a.finishMessage(context.Background(), &assistantMsg, message.FinishReasonCanceled, "Request cancelled", "")
466 return assistantMsg, nil, ctx.Err()
467 }
468 }
469
470 toolResults := make([]message.ToolResult, len(assistantMsg.ToolCalls()))
471 toolCalls := assistantMsg.ToolCalls()
472 for i, toolCall := range toolCalls {
473 select {
474 case <-ctx.Done():
475 a.finishMessage(context.Background(), &assistantMsg, message.FinishReasonCanceled, "Request cancelled", "")
476 // Make all future tool calls cancelled
477 for j := i; j < len(toolCalls); j++ {
478 toolResults[j] = message.ToolResult{
479 ToolCallID: toolCalls[j].ID,
480 Content: "Tool execution canceled by user",
481 IsError: true,
482 }
483 }
484 goto out
485 default:
486 // Continue processing
487 var tool tools.BaseTool
488 for _, availableTool := range a.tools {
489 if availableTool.Info().Name == toolCall.Name {
490 tool = availableTool
491 break
492 }
493 }
494
495 // Tool not found
496 if tool == nil {
497 toolResults[i] = message.ToolResult{
498 ToolCallID: toolCall.ID,
499 Content: fmt.Sprintf("Tool not found: %s", toolCall.Name),
500 IsError: true,
501 }
502 continue
503 }
504
505 // Run tool in goroutine to allow cancellation
506 type toolExecResult struct {
507 response tools.ToolResponse
508 err error
509 }
510 resultChan := make(chan toolExecResult, 1)
511
512 go func() {
513 response, err := tool.Run(ctx, tools.ToolCall{
514 ID: toolCall.ID,
515 Name: toolCall.Name,
516 Input: toolCall.Input,
517 })
518 resultChan <- toolExecResult{response: response, err: err}
519 }()
520
521 var toolResponse tools.ToolResponse
522 var toolErr error
523
524 select {
525 case <-ctx.Done():
526 a.finishMessage(context.Background(), &assistantMsg, message.FinishReasonCanceled, "Request cancelled", "")
527 // Mark remaining tool calls as cancelled
528 for j := i; j < len(toolCalls); j++ {
529 toolResults[j] = message.ToolResult{
530 ToolCallID: toolCalls[j].ID,
531 Content: "Tool execution canceled by user",
532 IsError: true,
533 }
534 }
535 goto out
536 case result := <-resultChan:
537 toolResponse = result.response
538 toolErr = result.err
539 }
540
541 if toolErr != nil {
542 slog.Error("Tool execution error", "toolCall", toolCall.ID, "error", toolErr)
543 if errors.Is(toolErr, permission.ErrorPermissionDenied) {
544 toolResults[i] = message.ToolResult{
545 ToolCallID: toolCall.ID,
546 Content: "Permission denied",
547 IsError: true,
548 }
549 for j := i + 1; j < len(toolCalls); j++ {
550 toolResults[j] = message.ToolResult{
551 ToolCallID: toolCalls[j].ID,
552 Content: "Tool execution canceled by user",
553 IsError: true,
554 }
555 }
556 a.finishMessage(ctx, &assistantMsg, message.FinishReasonPermissionDenied, "Permission denied", "")
557 break
558 }
559 }
560 toolResults[i] = message.ToolResult{
561 ToolCallID: toolCall.ID,
562 Content: toolResponse.Content,
563 Metadata: toolResponse.Metadata,
564 IsError: toolResponse.IsError,
565 }
566 }
567 }
568out:
569 if len(toolResults) == 0 {
570 return assistantMsg, nil, nil
571 }
572 parts := make([]message.ContentPart, 0)
573 for _, tr := range toolResults {
574 parts = append(parts, tr)
575 }
576 msg, err := a.messages.Create(context.Background(), assistantMsg.SessionID, message.CreateMessageParams{
577 Role: message.Tool,
578 Parts: parts,
579 Provider: a.providerID,
580 })
581 if err != nil {
582 return assistantMsg, nil, fmt.Errorf("failed to create cancelled tool message: %w", err)
583 }
584
585 return assistantMsg, &msg, err
586}
587
588func (a *agent) finishMessage(ctx context.Context, msg *message.Message, finishReason message.FinishReason, message, details string) {
589 msg.AddFinish(finishReason, message, details)
590 _ = a.messages.Update(ctx, *msg)
591}
592
593func (a *agent) processEvent(ctx context.Context, sessionID string, assistantMsg *message.Message, event provider.ProviderEvent) error {
594 select {
595 case <-ctx.Done():
596 return ctx.Err()
597 default:
598 // Continue processing.
599 }
600
601 switch event.Type {
602 case provider.EventThinkingDelta:
603 assistantMsg.AppendReasoningContent(event.Thinking)
604 return a.messages.Update(ctx, *assistantMsg)
605 case provider.EventSignatureDelta:
606 assistantMsg.AppendReasoningSignature(event.Signature)
607 return a.messages.Update(ctx, *assistantMsg)
608 case provider.EventContentDelta:
609 assistantMsg.FinishThinking()
610 assistantMsg.AppendContent(event.Content)
611 return a.messages.Update(ctx, *assistantMsg)
612 case provider.EventToolUseStart:
613 assistantMsg.FinishThinking()
614 slog.Info("Tool call started", "toolCall", event.ToolCall)
615 assistantMsg.AddToolCall(*event.ToolCall)
616 return a.messages.Update(ctx, *assistantMsg)
617 case provider.EventToolUseDelta:
618 assistantMsg.AppendToolCallInput(event.ToolCall.ID, event.ToolCall.Input)
619 return a.messages.Update(ctx, *assistantMsg)
620 case provider.EventToolUseStop:
621 slog.Info("Finished tool call", "toolCall", event.ToolCall)
622 assistantMsg.FinishToolCall(event.ToolCall.ID)
623 return a.messages.Update(ctx, *assistantMsg)
624 case provider.EventError:
625 return event.Error
626 case provider.EventComplete:
627 assistantMsg.FinishThinking()
628 assistantMsg.SetToolCalls(event.Response.ToolCalls)
629 assistantMsg.AddFinish(event.Response.FinishReason, "", "")
630 if err := a.messages.Update(ctx, *assistantMsg); err != nil {
631 return fmt.Errorf("failed to update message: %w", err)
632 }
633 return a.TrackUsage(ctx, sessionID, a.Model(), event.Response.Usage)
634 }
635
636 return nil
637}
638
639func (a *agent) TrackUsage(ctx context.Context, sessionID string, model fur.Model, usage provider.TokenUsage) error {
640 sess, err := a.sessions.Get(ctx, sessionID)
641 if err != nil {
642 return fmt.Errorf("failed to get session: %w", err)
643 }
644
645 cost := model.CostPer1MInCached/1e6*float64(usage.CacheCreationTokens) +
646 model.CostPer1MOutCached/1e6*float64(usage.CacheReadTokens) +
647 model.CostPer1MIn/1e6*float64(usage.InputTokens) +
648 model.CostPer1MOut/1e6*float64(usage.OutputTokens)
649
650 sess.Cost += cost
651 sess.CompletionTokens = usage.OutputTokens + usage.CacheReadTokens
652 sess.PromptTokens = usage.InputTokens + usage.CacheCreationTokens
653
654 _, err = a.sessions.Save(ctx, sess)
655 if err != nil {
656 return fmt.Errorf("failed to save session: %w", err)
657 }
658 return nil
659}
660
661func (a *agent) Summarize(ctx context.Context, sessionID string) error {
662 if a.summarizeProvider == nil {
663 return fmt.Errorf("summarize provider not available")
664 }
665
666 // Check if session is busy
667 if a.IsSessionBusy(sessionID) {
668 return ErrSessionBusy
669 }
670
671 // Create a new context with cancellation
672 summarizeCtx, cancel := context.WithCancel(ctx)
673
674 // Store the cancel function in activeRequests to allow cancellation
675 a.activeRequests.Store(sessionID+"-summarize", cancel)
676
677 go func() {
678 defer a.activeRequests.Delete(sessionID + "-summarize")
679 defer cancel()
680 event := AgentEvent{
681 Type: AgentEventTypeSummarize,
682 Progress: "Starting summarization...",
683 }
684
685 a.Publish(pubsub.CreatedEvent, event)
686 // Get all messages from the session
687 msgs, err := a.messages.List(summarizeCtx, sessionID)
688 if err != nil {
689 event = AgentEvent{
690 Type: AgentEventTypeError,
691 Error: fmt.Errorf("failed to list messages: %w", err),
692 Done: true,
693 }
694 a.Publish(pubsub.CreatedEvent, event)
695 return
696 }
697 summarizeCtx = context.WithValue(summarizeCtx, tools.SessionIDContextKey, sessionID)
698
699 if len(msgs) == 0 {
700 event = AgentEvent{
701 Type: AgentEventTypeError,
702 Error: fmt.Errorf("no messages to summarize"),
703 Done: true,
704 }
705 a.Publish(pubsub.CreatedEvent, event)
706 return
707 }
708
709 event = AgentEvent{
710 Type: AgentEventTypeSummarize,
711 Progress: "Analyzing conversation...",
712 }
713 a.Publish(pubsub.CreatedEvent, event)
714
715 // Add a system message to guide the summarization
716 summarizePrompt := "Provide a detailed but concise summary of our conversation above. Focus on information that would be helpful for continuing the conversation, including what we did, what we're doing, which files we're working on, and what we're going to do next."
717
718 // Create a new message with the summarize prompt
719 promptMsg := message.Message{
720 Role: message.User,
721 Parts: []message.ContentPart{message.TextContent{Text: summarizePrompt}},
722 }
723
724 // Append the prompt to the messages
725 msgsWithPrompt := append(msgs, promptMsg)
726
727 event = AgentEvent{
728 Type: AgentEventTypeSummarize,
729 Progress: "Generating summary...",
730 }
731
732 a.Publish(pubsub.CreatedEvent, event)
733
734 // Send the messages to the summarize provider
735 response := a.summarizeProvider.StreamResponse(
736 summarizeCtx,
737 msgsWithPrompt,
738 make([]tools.BaseTool, 0),
739 )
740 var finalResponse *provider.ProviderResponse
741 for r := range response {
742 if r.Error != nil {
743 event = AgentEvent{
744 Type: AgentEventTypeError,
745 Error: fmt.Errorf("failed to summarize: %w", err),
746 Done: true,
747 }
748 a.Publish(pubsub.CreatedEvent, event)
749 return
750 }
751 finalResponse = r.Response
752 }
753
754 summary := strings.TrimSpace(finalResponse.Content)
755 if summary == "" {
756 event = AgentEvent{
757 Type: AgentEventTypeError,
758 Error: fmt.Errorf("empty summary returned"),
759 Done: true,
760 }
761 a.Publish(pubsub.CreatedEvent, event)
762 return
763 }
764 event = AgentEvent{
765 Type: AgentEventTypeSummarize,
766 Progress: "Creating new session...",
767 }
768
769 a.Publish(pubsub.CreatedEvent, event)
770 oldSession, err := a.sessions.Get(summarizeCtx, sessionID)
771 if err != nil {
772 event = AgentEvent{
773 Type: AgentEventTypeError,
774 Error: fmt.Errorf("failed to get session: %w", err),
775 Done: true,
776 }
777
778 a.Publish(pubsub.CreatedEvent, event)
779 return
780 }
781 // Create a message in the new session with the summary
782 msg, err := a.messages.Create(summarizeCtx, oldSession.ID, message.CreateMessageParams{
783 Role: message.Assistant,
784 Parts: []message.ContentPart{
785 message.TextContent{Text: summary},
786 message.Finish{
787 Reason: message.FinishReasonEndTurn,
788 Time: time.Now().Unix(),
789 },
790 },
791 Model: a.summarizeProvider.Model().ID,
792 Provider: a.summarizeProviderID,
793 })
794 if err != nil {
795 event = AgentEvent{
796 Type: AgentEventTypeError,
797 Error: fmt.Errorf("failed to create summary message: %w", err),
798 Done: true,
799 }
800
801 a.Publish(pubsub.CreatedEvent, event)
802 return
803 }
804 oldSession.SummaryMessageID = msg.ID
805 oldSession.CompletionTokens = finalResponse.Usage.OutputTokens
806 oldSession.PromptTokens = 0
807 model := a.summarizeProvider.Model()
808 usage := finalResponse.Usage
809 cost := model.CostPer1MInCached/1e6*float64(usage.CacheCreationTokens) +
810 model.CostPer1MOutCached/1e6*float64(usage.CacheReadTokens) +
811 model.CostPer1MIn/1e6*float64(usage.InputTokens) +
812 model.CostPer1MOut/1e6*float64(usage.OutputTokens)
813 oldSession.Cost += cost
814 _, err = a.sessions.Save(summarizeCtx, oldSession)
815 if err != nil {
816 event = AgentEvent{
817 Type: AgentEventTypeError,
818 Error: fmt.Errorf("failed to save session: %w", err),
819 Done: true,
820 }
821 a.Publish(pubsub.CreatedEvent, event)
822 }
823
824 event = AgentEvent{
825 Type: AgentEventTypeSummarize,
826 SessionID: oldSession.ID,
827 Progress: "Summary complete",
828 Done: true,
829 }
830 a.Publish(pubsub.CreatedEvent, event)
831 // Send final success event with the new session ID
832 }()
833
834 return nil
835}
836
837func (a *agent) CancelAll() {
838 if !a.IsBusy() {
839 return
840 }
841 a.activeRequests.Range(func(key, value any) bool {
842 a.Cancel(key.(string)) // key is sessionID
843 return true
844 })
845
846 timeout := time.After(5 * time.Second)
847 for a.IsBusy() {
848 select {
849 case <-timeout:
850 return
851 default:
852 time.Sleep(200 * time.Millisecond)
853 }
854 }
855}
856
857func (a *agent) UpdateModel() error {
858 cfg := config.Get()
859
860 // Get current provider configuration
861 currentProviderCfg := cfg.GetProviderForModel(a.agentCfg.Model)
862 if currentProviderCfg.ID == "" {
863 return fmt.Errorf("provider for agent %s not found in config", a.agentCfg.Name)
864 }
865
866 // Check if provider has changed
867 if string(currentProviderCfg.ID) != a.providerID {
868 // Provider changed, need to recreate the main provider
869 model := cfg.GetModelByType(a.agentCfg.Model)
870 if model.ID == "" {
871 return fmt.Errorf("model not found for agent %s", a.agentCfg.Name)
872 }
873
874 promptID := agentPromptMap[a.agentCfg.ID]
875 if promptID == "" {
876 promptID = prompt.PromptDefault
877 }
878
879 opts := []provider.ProviderClientOption{
880 provider.WithModel(a.agentCfg.Model),
881 provider.WithSystemMessage(prompt.GetPrompt(promptID, currentProviderCfg.ID, cfg.Options.ContextPaths...)),
882 }
883
884 newProvider, err := provider.NewProvider(*currentProviderCfg, opts...)
885 if err != nil {
886 return fmt.Errorf("failed to create new provider: %w", err)
887 }
888
889 // Update the provider and provider ID
890 a.provider = newProvider
891 a.providerID = string(currentProviderCfg.ID)
892 }
893
894 // Check if small model provider has changed (affects title and summarize providers)
895 smallModelCfg := cfg.Models[config.SelectedModelTypeSmall]
896 var smallModelProviderCfg config.ProviderConfig
897
898 for _, p := range cfg.Providers {
899 if p.ID == smallModelCfg.Provider {
900 smallModelProviderCfg = p
901 break
902 }
903 }
904
905 if smallModelProviderCfg.ID == "" {
906 return fmt.Errorf("provider %s not found in config", smallModelCfg.Provider)
907 }
908
909 // Check if summarize provider has changed
910 if string(smallModelProviderCfg.ID) != a.summarizeProviderID {
911 smallModel := cfg.GetModelByType(config.SelectedModelTypeSmall)
912 if smallModel == nil {
913 return fmt.Errorf("model %s not found in provider %s", smallModelCfg.Model, smallModelProviderCfg.ID)
914 }
915
916 // Recreate title provider
917 titleOpts := []provider.ProviderClientOption{
918 provider.WithModel(config.SelectedModelTypeSmall),
919 provider.WithSystemMessage(prompt.GetPrompt(prompt.PromptTitle, smallModelProviderCfg.ID)),
920 // We want the title to be short, so we limit the max tokens
921 provider.WithMaxTokens(40),
922 }
923 newTitleProvider, err := provider.NewProvider(smallModelProviderCfg, titleOpts...)
924 if err != nil {
925 return fmt.Errorf("failed to create new title provider: %w", err)
926 }
927
928 // Recreate summarize provider
929 summarizeOpts := []provider.ProviderClientOption{
930 provider.WithModel(config.SelectedModelTypeSmall),
931 provider.WithSystemMessage(prompt.GetPrompt(prompt.PromptSummarizer, smallModelProviderCfg.ID)),
932 }
933 newSummarizeProvider, err := provider.NewProvider(smallModelProviderCfg, summarizeOpts...)
934 if err != nil {
935 return fmt.Errorf("failed to create new summarize provider: %w", err)
936 }
937
938 // Update the providers and provider ID
939 a.titleProvider = newTitleProvider
940 a.summarizeProvider = newSummarizeProvider
941 a.summarizeProviderID = string(smallModelProviderCfg.ID)
942 }
943
944 return nil
945}