From d26c477d868dc86f135653f90bd296647260f138 Mon Sep 17 00:00:00 2001 From: Agus Zubiaga Date: Tue, 1 Apr 2025 18:48:56 -0300 Subject: [PATCH] assistant2: Summarize threads in context and continue long ones (#27851) We'll now prompt the user to start a new thread when the active one gets too long. When they click "Start New Thread", we'll create a new one with the previous one added as context. Instead of including the full thread text, we'll now add summarized versions of threads to the context, allowing you to continue the conversation even if it was near the token limit. - Thread summaries are cached and persisted. - A cached summary is invalidated if the thread is continued. - We start generating the thread summary as soon as it's selected from the picker. Most of the time, the summary will be ready by the time the user sends the message. - If the summary isn't ready by the time a message is sent, the user message will be displayed in the thread immediately, and a "Summarizing context..." indicator will appear. After the summaries are ready, we'll start generating the response and show the usual "Generating..." indicator. 
Release Notes: - N/A --------- Co-authored-by: Danilo Leal Co-authored-by: Marshall Bowers --- crates/assistant2/src/active_thread.rs | 4 +- crates/assistant2/src/assistant.rs | 10 +- crates/assistant2/src/assistant_panel.rs | 86 +++++++-- crates/assistant2/src/context.rs | 2 - .../src/context_picker/completion_provider.rs | 2 +- crates/assistant2/src/context_store.rs | 49 ++++- crates/assistant2/src/inline_assistant.rs | 6 +- crates/assistant2/src/message_editor.rs | 177 +++++++++++++++--- .../src/terminal_inline_assistant.rs | 3 +- crates/assistant2/src/thread.rs | 138 +++++++++++++- crates/assistant2/src/thread_store.rs | 9 +- crates/assistant2/src/ui/context_pill.rs | 31 ++- 12 files changed, 453 insertions(+), 64 deletions(-) diff --git a/crates/assistant2/src/active_thread.rs b/crates/assistant2/src/active_thread.rs index 37283f85901ca41607b219960418ac5cd905a507..dae556bd70ae0f0eb835a8296e7e93a1af72b778 100644 --- a/crates/assistant2/src/active_thread.rs +++ b/crates/assistant2/src/active_thread.rs @@ -945,7 +945,7 @@ impl ActiveThread { .map(|(_, state)| state.editor.clone()); let first_message = ix == 0; - let is_last_message = ix == self.messages.len() - 1; + let show_feedback = ix == self.messages.len() - 1 && message.role != Role::User; let colors = cx.theme().colors(); let active_color = colors.element_active; @@ -1311,7 +1311,7 @@ impl ActiveThread { }) .child(styled_message) .when( - is_last_message && !self.thread.read(cx).is_generating(), + show_feedback && !self.thread.read(cx).is_generating(), |parent| parent.child(feedback_items), ) .into_any() diff --git a/crates/assistant2/src/assistant.rs b/crates/assistant2/src/assistant.rs index 3d8dcafdc82610f6befd6ba742e94d194ecbc5b5..3a2fb29659e3475698748f0c50a4ea4311740359 100644 --- a/crates/assistant2/src/assistant.rs +++ b/crates/assistant2/src/assistant.rs @@ -33,6 +33,7 @@ use prompt_store::PromptBuilder; use schemars::JsonSchema; use serde::Deserialize; use settings::Settings as _; +use 
thread::ThreadId; pub use crate::active_thread::ActiveThread; use crate::assistant_configuration::{AddContextServerModal, ManageProfilesModal}; @@ -45,7 +46,6 @@ pub use assistant_diff::{AssistantDiff, AssistantDiffToolbar}; actions!( agent, [ - NewThread, NewPromptEditor, ToggleContextPicker, ToggleProfileSelector, @@ -73,6 +73,12 @@ actions!( ] ); +#[derive(Default, Clone, PartialEq, Deserialize, JsonSchema)] +pub struct NewThread { + #[serde(default)] + from_thread_id: Option, +} + #[derive(PartialEq, Clone, Default, Debug, Deserialize, JsonSchema)] pub struct ManageProfiles { #[serde(default)] @@ -87,7 +93,7 @@ impl ManageProfiles { } } -impl_actions!(agent, [ManageProfiles]); +impl_actions!(agent, [NewThread, ManageProfiles]); const NAMESPACE: &str = "agent"; diff --git a/crates/assistant2/src/assistant_panel.rs b/crates/assistant2/src/assistant_panel.rs index 0d4f1109292f322614bd65ff65ccc0d5a71e62ed..70d51769023306260633ba77bdf5ce0a1faabf11 100644 --- a/crates/assistant2/src/assistant_panel.rs +++ b/crates/assistant2/src/assistant_panel.rs @@ -56,9 +56,9 @@ pub fn init(cx: &mut App) { cx.observe_new( |workspace: &mut Workspace, _window, _cx: &mut Context| { workspace - .register_action(|workspace, _: &NewThread, window, cx| { + .register_action(|workspace, action: &NewThread, window, cx| { if let Some(panel) = workspace.panel::(cx) { - panel.update(cx, |panel, cx| panel.new_thread(window, cx)); + panel.update(cx, |panel, cx| panel.new_thread(action, window, cx)); workspace.focus_panel::(window, cx); } }) @@ -181,8 +181,12 @@ impl AssistantPanel { let workspace = workspace.weak_handle(); let weak_self = cx.entity().downgrade(); - let message_editor_context_store = - cx.new(|_cx| crate::context_store::ContextStore::new(workspace.clone())); + let message_editor_context_store = cx.new(|_cx| { + crate::context_store::ContextStore::new( + workspace.clone(), + Some(thread_store.downgrade()), + ) + }); let message_editor = cx.new(|cx| { MessageEditor::new( @@ -268,15 
+272,39 @@ impl AssistantPanel { .update(cx, |thread, cx| thread.cancel_last_completion(cx)); } - fn new_thread(&mut self, window: &mut Window, cx: &mut Context) { + fn new_thread(&mut self, action: &NewThread, window: &mut Window, cx: &mut Context) { let thread = self .thread_store .update(cx, |this, cx| this.create_thread(cx)); self.active_view = ActiveView::Thread; - let message_editor_context_store = - cx.new(|_cx| crate::context_store::ContextStore::new(self.workspace.clone())); + let message_editor_context_store = cx.new(|_cx| { + crate::context_store::ContextStore::new( + self.workspace.clone(), + Some(self.thread_store.downgrade()), + ) + }); + + if let Some(other_thread_id) = action.from_thread_id.clone() { + let other_thread_task = self + .thread_store + .update(cx, |this, cx| this.open_thread(&other_thread_id, cx)); + + cx.spawn({ + let context_store = message_editor_context_store.clone(); + + async move |_panel, cx| { + let other_thread = other_thread_task.await?; + + context_store.update(cx, |this, cx| { + this.add_thread(other_thread, false, cx); + })?; + anyhow::Ok(()) + } + }) + .detach_and_log_err(cx); + } self.thread = cx.new(|cx| { ActiveThread::new( @@ -414,8 +442,12 @@ impl AssistantPanel { let thread = open_thread_task.await?; this.update_in(cx, |this, window, cx| { this.active_view = ActiveView::Thread; - let message_editor_context_store = - cx.new(|_cx| crate::context_store::ContextStore::new(this.workspace.clone())); + let message_editor_context_store = cx.new(|_cx| { + crate::context_store::ContextStore::new( + this.workspace.clone(), + Some(this.thread_store.downgrade()), + ) + }); this.thread = cx.new(|cx| { ActiveThread::new( thread.clone(), @@ -556,7 +588,7 @@ impl AssistantPanel { } } - self.new_thread(window, cx); + self.new_thread(&NewThread::default(), window, cx); } } } @@ -688,11 +720,14 @@ impl Panel for AssistantPanel { impl AssistantPanel { fn render_toolbar(&self, _window: &mut Window, cx: &mut Context) -> impl IntoElement { 
let thread = self.thread.read(cx); + let is_empty = thread.is_empty(); + + let thread_id = thread.thread().read(cx).id().clone(); let focus_handle = self.focus_handle(cx); let title = match self.active_view { ActiveView::Thread => { - if thread.is_empty() { + if is_empty { thread.summary_or_default(cx) } else { thread @@ -754,14 +789,17 @@ impl AssistantPanel { .tooltip(move |window, cx| { Tooltip::for_action_in( "New Thread", - &NewThread, + &NewThread::default(), &focus_handle, window, cx, ) }) .on_click(move |_event, window, cx| { - window.dispatch_action(NewThread.boxed_clone(), cx); + window.dispatch_action( + NewThread::default().boxed_clone(), + cx, + ); }), ) .child( @@ -780,9 +818,23 @@ impl AssistantPanel { cx, |menu, _window, _cx| { menu.action( + "New Thread", + Box::new(NewThread { + from_thread_id: None, + }), + ) + .action( "New Prompt Editor", NewPromptEditor.boxed_clone(), ) + .when(!is_empty, |menu| { + menu.action( + "Continue in New Thread", + Box::new(NewThread { + from_thread_id: Some(thread_id.clone()), + }), + ) + }) .separator() .action("History", OpenHistory.boxed_clone()) .action("Settings", OpenConfiguration.boxed_clone()) @@ -871,13 +923,13 @@ impl AssistantPanel { .icon_color(Color::Muted) .full_width() .key_binding(KeyBinding::for_action_in( - &NewThread, + &NewThread::default(), &focus_handle, window, cx, )) .on_click(|_event, window, cx| { - window.dispatch_action(NewThread.boxed_clone(), cx) + window.dispatch_action(NewThread::default().boxed_clone(), cx) }), ) .child( @@ -1267,8 +1319,8 @@ impl Render for AssistantPanel { .justify_between() .size_full() .on_action(cx.listener(Self::cancel)) - .on_action(cx.listener(|this, _: &NewThread, window, cx| { - this.new_thread(window, cx); + .on_action(cx.listener(|this, action: &NewThread, window, cx| { + this.new_thread(action, window, cx); })) .on_action(cx.listener(|this, _: &OpenHistory, window, cx| { this.open_history(window, cx); diff --git a/crates/assistant2/src/context.rs 
b/crates/assistant2/src/context.rs index 6d8583579d2d648751fa3ba0696019e90cd3d58f..3c933717b94e4e6035d014fd732726b442f80dfa 100644 --- a/crates/assistant2/src/context.rs +++ b/crates/assistant2/src/context.rs @@ -19,8 +19,6 @@ impl ContextId { Self(post_inc(&mut self.0)) } } - -#[derive(Debug, Clone, Copy, PartialEq, Eq)] pub enum ContextKind { File, Directory, diff --git a/crates/assistant2/src/context_picker/completion_provider.rs b/crates/assistant2/src/context_picker/completion_provider.rs index f1b11fff47a7455d97f71d5e9728ac78b5c618b0..016d3e8567adec7521c3987a624d7ca416468c29 100644 --- a/crates/assistant2/src/context_picker/completion_provider.rs +++ b/crates/assistant2/src/context_picker/completion_provider.rs @@ -862,7 +862,7 @@ mod tests { .expect("Opened test file wasn't an editor") }); - let context_store = cx.new(|_| ContextStore::new(workspace.downgrade())); + let context_store = cx.new(|_| ContextStore::new(workspace.downgrade(), None)); let editor_entity = editor.downgrade(); editor.update_in(&mut cx, |editor, window, cx| { diff --git a/crates/assistant2/src/context_store.rs b/crates/assistant2/src/context_store.rs index b9af3c1ce59e955f135edb612c8f83b754753077..009e6388168cdc8abb37aec9b9ba0a23ade3ee14 100644 --- a/crates/assistant2/src/context_store.rs +++ b/crates/assistant2/src/context_store.rs @@ -4,15 +4,17 @@ use std::sync::Arc; use anyhow::{Context as _, Result, anyhow}; use collections::{BTreeMap, HashMap, HashSet}; +use futures::future::join_all; use futures::{self, Future, FutureExt, future}; use gpui::{App, AppContext as _, AsyncApp, Context, Entity, SharedString, Task, WeakEntity}; use language::{Buffer, File}; use project::{ProjectItem, ProjectPath, Worktree}; use rope::Rope; use text::{Anchor, BufferId, OffsetRangeExt}; -use util::{ResultExt, maybe}; +use util::{ResultExt as _, maybe}; use workspace::Workspace; +use crate::ThreadStore; use crate::context::{ AssistantContext, ContextBuffer, ContextId, ContextSymbol, ContextSymbolId, 
DirectoryContext, FetchedUrlContext, FileContext, SymbolContext, ThreadContext, @@ -23,6 +25,7 @@ use crate::thread::{Thread, ThreadId}; pub struct ContextStore { workspace: WeakEntity, context: Vec, + thread_store: Option>, // TODO: If an EntityId is used for all context types (like BufferId), can remove ContextId. next_context_id: ContextId, files: BTreeMap, @@ -31,13 +34,18 @@ pub struct ContextStore { symbol_buffers: HashMap>, symbols_by_path: HashMap>, threads: HashMap, + thread_summary_tasks: Vec>, fetched_urls: HashMap, } impl ContextStore { - pub fn new(workspace: WeakEntity) -> Self { + pub fn new( + workspace: WeakEntity, + thread_store: Option>, + ) -> Self { Self { workspace, + thread_store, context: Vec::new(), next_context_id: ContextId(0), files: BTreeMap::default(), @@ -46,6 +54,7 @@ impl ContextStore { symbol_buffers: HashMap::default(), symbols_by_path: HashMap::default(), threads: HashMap::default(), + thread_summary_tasks: Vec::new(), fetched_urls: HashMap::default(), } } @@ -375,9 +384,39 @@ impl ContextStore { } } - fn insert_thread(&mut self, thread: Entity, cx: &App) { + pub fn wait_for_summaries(&mut self, cx: &App) -> Task<()> { + let tasks = std::mem::take(&mut self.thread_summary_tasks); + + cx.spawn(async move |_cx| { + join_all(tasks).await; + }) + } + + fn insert_thread(&mut self, thread: Entity, cx: &mut App) { + if let Some(summary_task) = + thread.update(cx, |thread, cx| thread.generate_detailed_summary(cx)) + { + let thread = thread.clone(); + let thread_store = self.thread_store.clone(); + + self.thread_summary_tasks.push(cx.spawn(async move |cx| { + summary_task.await; + + if let Some(thread_store) = thread_store { + // Save thread so its summary can be reused later + let save_task = thread_store + .update(cx, |thread_store, cx| thread_store.save_thread(&thread, cx)); + + if let Some(save_task) = save_task.ok() { + save_task.await.log_err(); + } + } + })); + } + let id = self.next_context_id.post_inc(); - let text = 
thread.read(cx).text().into(); + + let text = thread.read(cx).latest_detailed_summary_or_text(); self.threads.insert(thread.read(cx).id().clone(), id); self.context @@ -865,7 +904,7 @@ fn refresh_thread_text( cx.spawn(async move |cx| { context_store .update(cx, |context_store, cx| { - let text = thread.read(cx).text().into(); + let text = thread.read(cx).latest_detailed_summary_or_text(); context_store.replace_context(AssistantContext::Thread(ThreadContext { id, thread, diff --git a/crates/assistant2/src/inline_assistant.rs b/crates/assistant2/src/inline_assistant.rs index 548a3b15e21a60bc2ac403576fedc9d09f66293c..33a7d1f891204a83a1f4fef712a2534bdaf5ca1f 100644 --- a/crates/assistant2/src/inline_assistant.rs +++ b/crates/assistant2/src/inline_assistant.rs @@ -424,7 +424,8 @@ impl InlineAssistant { let mut assist_to_focus = None; for range in codegen_ranges { let assist_id = self.next_assist_id.post_inc(); - let context_store = cx.new(|_cx| ContextStore::new(workspace.clone())); + let context_store = + cx.new(|_cx| ContextStore::new(workspace.clone(), thread_store.clone())); let codegen = cx.new(|cx| { BufferCodegen::new( editor.read(cx).buffer().clone(), @@ -536,7 +537,8 @@ impl InlineAssistant { range.end = range.end.bias_right(&snapshot); } - let context_store = cx.new(|_cx| ContextStore::new(workspace.clone())); + let context_store = + cx.new(|_cx| ContextStore::new(workspace.clone(), thread_store.clone())); let codegen = cx.new(|cx| { BufferCodegen::new( diff --git a/crates/assistant2/src/message_editor.rs b/crates/assistant2/src/message_editor.rs index 8cbcb580ae76b25ceff352001816551546766212..7bef4a83dde7cabbaa991386a298dd7d82eaaad9 100644 --- a/crates/assistant2/src/message_editor.rs +++ b/crates/assistant2/src/message_editor.rs @@ -19,7 +19,7 @@ use ui::{ ButtonLike, Disclosure, KeyBinding, PlatformStyle, PopoverMenu, PopoverMenuHandle, Tooltip, prelude::*, }; -use util::ResultExt; +use util::ResultExt as _; use vim_mode_setting::VimModeSetting; use 
workspace::Workspace; @@ -31,7 +31,7 @@ use crate::profile_selector::ProfileSelector; use crate::thread::{RequestKind, Thread}; use crate::thread_store::ThreadStore; use crate::{ - AssistantDiff, Chat, ChatMode, OpenAssistantDiff, RemoveAllContext, ThreadEvent, + AssistantDiff, Chat, ChatMode, NewThread, OpenAssistantDiff, RemoveAllContext, ThreadEvent, ToggleContextPicker, ToggleProfileSelector, }; @@ -49,6 +49,7 @@ pub struct MessageEditor { model_selector: Entity, profile_selector: Entity, edits_expanded: bool, + waiting_for_summaries_to_send: bool, _subscriptions: Vec, } @@ -141,6 +142,7 @@ impl MessageEditor { ) }), edits_expanded: false, + waiting_for_summaries_to_send: false, profile_selector: cx .new(|cx| ProfileSelector::new(fs, thread_store, editor.focus_handle(cx), cx)), _subscriptions: subscriptions, @@ -225,10 +227,12 @@ impl MessageEditor { let thread = self.thread.clone(); let context_store = self.context_store.clone(); let checkpoint = self.project.read(cx).git_store().read(cx).checkpoint(cx); - cx.spawn(async move |_, cx| { + + cx.spawn(async move |this, cx| { let checkpoint = checkpoint.await.ok(); refresh_task.await; let (system_prompt_context, load_error) = system_prompt_context_task.await; + thread .update(cx, |thread, cx| { thread.set_system_prompt_context(system_prompt_context); @@ -237,6 +241,7 @@ impl MessageEditor { } }) .ok(); + thread .update(cx, |thread, cx| { let context = context_store.read(cx).context().clone(); @@ -244,6 +249,31 @@ impl MessageEditor { action_log.clear_reviewed_changes(cx); }); thread.insert_user_message(user_message, context, checkpoint, cx); + }) + .ok(); + + if let Some(wait_for_summaries) = context_store + .update(cx, |context_store, cx| context_store.wait_for_summaries(cx)) + .log_err() + { + this.update(cx, |this, cx| { + this.waiting_for_summaries_to_send = true; + cx.notify(); + }) + .ok(); + + wait_for_summaries.await; + + this.update(cx, |this, cx| { + this.waiting_for_summaries_to_send = false; + 
cx.notify(); + }) + .ok(); + } + + // Send to model after summaries are done + thread + .update(cx, |thread, cx| { thread.send_to_model(model, request_kind, cx); }) .ok(); @@ -309,7 +339,9 @@ impl Render for MessageEditor { let focus_handle = self.editor.focus_handle(cx); let inline_context_picker = self.inline_context_picker.clone(); - let is_generating = self.thread.read(cx).is_generating(); + let thread = self.thread.read(cx); + let is_generating = thread.is_generating(); + let is_too_long = thread.is_getting_too_long(cx); let is_model_selected = self.is_model_selected(cx); let is_editor_empty = self.is_editor_empty(cx); let submit_label_color = if is_editor_empty { @@ -339,6 +371,41 @@ impl Render for MessageEditor { v_flex() .size_full() + .when(self.waiting_for_summaries_to_send, |parent| { + parent.child( + h_flex().py_3().w_full().justify_center().child( + h_flex() + .flex_none() + .px_2() + .py_2() + .bg(editor_bg_color) + .border_1() + .border_color(cx.theme().colors().border_variant) + .rounded_lg() + .shadow_md() + .gap_1() + .child( + Icon::new(IconName::ArrowCircle) + .size(IconSize::XSmall) + .color(Color::Muted) + .with_animation( + "arrow-circle", + Animation::new(Duration::from_secs(2)).repeat(), + |icon, delta| { + icon.transform(gpui::Transformation::rotate( + gpui::percentage(delta), + )) + }, + ), + ) + .child( + Label::new("Summarizing context…") + .size(LabelSize::XSmall) + .color(Color::Muted), + ), + ), + ) + }) .when(is_generating, |parent| { let focus_handle = self.editor.focus_handle(cx).clone(); parent.child( @@ -622,28 +689,29 @@ impl Render for MessageEditor { v_flex() .gap_5() .child({ - let settings = ThemeSettings::get_global(cx); - let text_style = TextStyle { - color: cx.theme().colors().text, - font_family: settings.ui_font.family.clone(), - font_fallbacks: settings.ui_font.fallbacks.clone(), - font_features: settings.ui_font.features.clone(), - font_size: font_size.into(), - font_weight: settings.ui_font.weight, - line_height: 
line_height.into(), - ..Default::default() - }; - - EditorElement::new( - &self.editor, - EditorStyle { - background: editor_bg_color, - local_player: cx.theme().players().local(), - text: text_style, - syntax: cx.theme().syntax().clone(), + let settings = ThemeSettings::get_global(cx); + let text_style = TextStyle { + color: cx.theme().colors().text, + font_family: settings.ui_font.family.clone(), + font_fallbacks: settings.ui_font.fallbacks.clone(), + font_features: settings.ui_font.features.clone(), + font_size: font_size.into(), + font_weight: settings.ui_font.weight, + line_height: line_height.into(), ..Default::default() - }, - ) + }; + + EditorElement::new( + &self.editor, + EditorStyle { + background: editor_bg_color, + local_player: cx.theme().players().local(), + text: text_style, + syntax: cx.theme().syntax().clone(), + ..Default::default() + }, + ).into_any() + }) .child( PopoverMenu::new("inline-context-picker") @@ -675,7 +743,8 @@ impl Render for MessageEditor { .disabled( is_editor_empty || !is_model_selected - || is_generating, + || is_generating + || self.waiting_for_summaries_to_send ) .child( h_flex() @@ -723,7 +792,61 @@ impl Render for MessageEditor { ), ), ), - ), + ) ) + .when(is_too_long, |parent| { + parent.child( + h_flex() + .p_2() + .gap_2() + .flex_wrap() + .justify_between() + .bg(cx.theme().status().warning_background.opacity(0.1)) + .border_t_1() + .border_color(cx.theme().colors().border) + .child( + h_flex() + .gap_2() + .items_start() + .child( + h_flex() + .h(line_height) + .justify_center() + .child( + Icon::new(IconName::Warning) + .color(Color::Warning) + .size(IconSize::XSmall), + ), + ) + .child( + v_flex() + .mr_auto() + .child(Label::new("Thread reaching the token limit soon").size(LabelSize::Small)) + .child( + Label::new( + "Start a new thread from a summary to continue the conversation.", + ) + .size(LabelSize::Small) + .color(Color::Muted), + ), + ), + ) + .child( + Button::new("new-thread", "Start New Thread") + 
.on_click(cx.listener(|this, _, window, cx| { + let from_thread_id = Some(this.thread.read(cx).id().clone()); + + window.dispatch_action(Box::new(NewThread { + from_thread_id + }), cx); + })) + .icon(IconName::Plus) + .icon_position(IconPosition::Start) + .icon_size(IconSize::Small) + .style(ButtonStyle::Tinted(ui::TintColor::Accent)) + .label_size(LabelSize::Small), + ), + ) + }) } } diff --git a/crates/assistant2/src/terminal_inline_assistant.rs b/crates/assistant2/src/terminal_inline_assistant.rs index a67e1ed3477e4e9d47653b241771730460f895cb..15460609bb33df487dd7e0d905b18d6088921a4f 100644 --- a/crates/assistant2/src/terminal_inline_assistant.rs +++ b/crates/assistant2/src/terminal_inline_assistant.rs @@ -75,7 +75,8 @@ impl TerminalInlineAssistant { let assist_id = self.next_assist_id.post_inc(); let prompt_buffer = cx.new(|cx| MultiBuffer::singleton(cx.new(|cx| Buffer::local(String::new(), cx)), cx)); - let context_store = cx.new(|_cx| ContextStore::new(workspace.clone())); + let context_store = + cx.new(|_cx| ContextStore::new(workspace.clone(), thread_store.clone())); let codegen = cx.new(|_| TerminalCodegen::new(terminal, self.telemetry.clone())); let prompt_editor = cx.new(|cx| { diff --git a/crates/assistant2/src/thread.rs b/crates/assistant2/src/thread.rs index 8115ed4fdc251f3b3d656a2e3b90b3effed53e06..e84ffd07f48d23c3592bd3245a674575983f2ed9 100644 --- a/crates/assistant2/src/thread.rs +++ b/crates/assistant2/src/thread.rs @@ -24,6 +24,7 @@ use project::{Project, Worktree}; use prompt_store::{ AssistantSystemPromptContext, PromptBuilder, RulesFile, WorktreeInfoForSystemPrompt, }; +use schemars::JsonSchema; use serde::{Deserialize, Serialize}; use settings::Settings; use util::{ResultExt as _, TryFutureExt as _, maybe, post_inc}; @@ -43,7 +44,9 @@ pub enum RequestKind { Summarize, } -#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Serialize, Deserialize)] +#[derive( + Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Serialize, 
Deserialize, JsonSchema, +)] pub struct ThreadId(Arc); impl ThreadId { @@ -173,12 +176,26 @@ impl LastRestoreCheckpoint { } } +#[derive(Clone, Debug, Default, Serialize, Deserialize)] +pub enum DetailedSummaryState { + #[default] + NotGenerated, + Generating { + message_id: MessageId, + }, + Generated { + text: SharedString, + message_id: MessageId, + }, +} + /// A thread of conversation with the LLM. pub struct Thread { id: ThreadId, updated_at: DateTime, summary: Option, pending_summary: Task>, + detailed_summary_state: DetailedSummaryState, messages: Vec, next_message_id: MessageId, context: BTreeMap, @@ -211,6 +228,7 @@ impl Thread { updated_at: Utc::now(), summary: None, pending_summary: Task::ready(None), + detailed_summary_state: DetailedSummaryState::NotGenerated, messages: Vec::new(), next_message_id: MessageId(0), context: BTreeMap::default(), @@ -260,6 +278,7 @@ impl Thread { updated_at: serialized.updated_at, summary: Some(serialized.summary), pending_summary: Task::ready(None), + detailed_summary_state: serialized.detailed_summary_state, messages: serialized .messages .into_iter() @@ -328,6 +347,19 @@ impl Thread { cx.emit(ThreadEvent::SummaryChanged); } + pub fn latest_detailed_summary_or_text(&self) -> SharedString { + self.latest_detailed_summary() + .unwrap_or_else(|| self.text().into()) + } + + fn latest_detailed_summary(&self) -> Option { + if let DetailedSummaryState::Generated { text, .. 
} = &self.detailed_summary_state { + Some(text.clone()) + } else { + None + } + } + pub fn message(&self, id: MessageId) -> Option<&Message> { self.messages.iter().find(|message| message.id == id) } @@ -658,6 +690,7 @@ impl Thread { .collect(), initial_project_snapshot, cumulative_token_usage: this.cumulative_token_usage.clone(), + detailed_summary_state: this.detailed_summary_state.clone(), }) }) } @@ -1202,6 +1235,87 @@ impl Thread { }); } + pub fn generate_detailed_summary(&mut self, cx: &mut Context) -> Option> { + let last_message_id = self.messages.last().map(|message| message.id)?; + + match &self.detailed_summary_state { + DetailedSummaryState::Generating { message_id, .. } + | DetailedSummaryState::Generated { message_id, .. } + if *message_id == last_message_id => + { + // Already up-to-date + return None; + } + _ => {} + } + + let provider = LanguageModelRegistry::read_global(cx).active_provider()?; + let model = LanguageModelRegistry::read_global(cx).active_model()?; + + if !provider.is_authenticated(cx) { + return None; + } + + let mut request = self.to_completion_request(RequestKind::Summarize, cx); + + request.messages.push(LanguageModelRequestMessage { + role: Role::User, + content: vec![ + "Generate a detailed summary of this conversation. Include:\n\ + 1. A brief overview of what was discussed\n\ + 2. Key facts or information discovered\n\ + 3. Outcomes or conclusions reached\n\ + 4. Any action items or next steps if any\n\ + Format it in Markdown with headings and bullet points." 
+ .into(), + ], + cache: false, + }); + + let task = cx.spawn(async move |thread, cx| { + let stream = model.stream_completion_text(request, &cx); + let Some(mut messages) = stream.await.log_err() else { + thread + .update(cx, |this, _cx| { + this.detailed_summary_state = DetailedSummaryState::NotGenerated; + }) + .log_err(); + + return; + }; + + let mut new_detailed_summary = String::new(); + + while let Some(chunk) = messages.stream.next().await { + if let Some(chunk) = chunk.log_err() { + new_detailed_summary.push_str(&chunk); + } + } + + thread + .update(cx, |this, _cx| { + this.detailed_summary_state = DetailedSummaryState::Generated { + text: new_detailed_summary.into(), + message_id: last_message_id, + }; + }) + .log_err(); + }); + + self.detailed_summary_state = DetailedSummaryState::Generating { + message_id: last_message_id, + }; + + Some(task) + } + + pub fn is_generating_detailed_summary(&self) -> bool { + matches!( + self.detailed_summary_state, + DetailedSummaryState::Generating { .. 
} + ) + } + pub fn use_pending_tools( &mut self, cx: &mut Context, @@ -1596,6 +1710,28 @@ impl Thread { self.cumulative_token_usage.clone() } + pub fn is_getting_too_long(&self, cx: &App) -> bool { + let model_registry = LanguageModelRegistry::read_global(cx); + let Some(model) = model_registry.active_model() else { + return false; + }; + + let max_tokens = model.max_token_count(); + + let current_usage = + self.cumulative_token_usage.input_tokens + self.cumulative_token_usage.output_tokens; + + #[cfg(debug_assertions)] + let warning_threshold: f32 = std::env::var("ZED_THREAD_WARNING_THRESHOLD") + .unwrap_or("0.9".to_string()) + .parse() + .unwrap(); + #[cfg(not(debug_assertions))] + let warning_threshold: f32 = 0.9; + + current_usage as f32 >= (max_tokens as f32 * warning_threshold) + } + pub fn deny_tool_use( &mut self, tool_use_id: LanguageModelToolUseId, diff --git a/crates/assistant2/src/thread_store.rs b/crates/assistant2/src/thread_store.rs index bd540d5d26c3dbd93adbf97df9b8b74968f3e962..3a3c497b16c08ff26aff6c2be4ae46645b78742d 100644 --- a/crates/assistant2/src/thread_store.rs +++ b/crates/assistant2/src/thread_store.rs @@ -24,7 +24,9 @@ use serde::{Deserialize, Serialize}; use settings::{Settings as _, SettingsStore}; use util::ResultExt as _; -use crate::thread::{MessageId, ProjectSnapshot, Thread, ThreadEvent, ThreadId}; +use crate::thread::{ + DetailedSummaryState, MessageId, ProjectSnapshot, Thread, ThreadEvent, ThreadId, +}; pub fn init(cx: &mut App) { ThreadsDatabase::init(cx); @@ -320,7 +322,7 @@ pub struct SerializedThreadMetadata { pub updated_at: DateTime, } -#[derive(Serialize, Deserialize)] +#[derive(Serialize, Deserialize, Debug)] pub struct SerializedThread { pub version: String, pub summary: SharedString, @@ -330,6 +332,8 @@ pub struct SerializedThread { pub initial_project_snapshot: Option>, #[serde(default)] pub cumulative_token_usage: TokenUsage, + #[serde(default)] + pub detailed_summary_state: DetailedSummaryState, } impl 
SerializedThread { @@ -413,6 +417,7 @@ impl LegacySerializedThread { messages: self.messages.into_iter().map(|msg| msg.upgrade()).collect(), initial_project_snapshot: self.initial_project_snapshot, cumulative_token_usage: TokenUsage::default(), + detailed_summary_state: DetailedSummaryState::default(), } } } diff --git a/crates/assistant2/src/ui/context_pill.rs b/crates/assistant2/src/ui/context_pill.rs index dd39a48e2ec7345521602c807a1483272ea498ca..6ca80cc27e8689a80d26e00b8c78884048695e0b 100644 --- a/crates/assistant2/src/ui/context_pill.rs +++ b/crates/assistant2/src/ui/context_pill.rs @@ -1,7 +1,8 @@ -use std::rc::Rc; +use std::{rc::Rc, time::Duration}; use file_icons::FileIcons; use gpui::ClickEvent; +use gpui::{Animation, AnimationExt as _, pulsating_between}; use ui::{IconButtonShape, Tooltip, prelude::*}; use crate::context::{AssistantContext, ContextId, ContextKind}; @@ -170,6 +171,22 @@ impl RenderOnce for ContextPill { element .cursor_pointer() .on_click(move |event, window, cx| on_click(event, window, cx)) + }) + .map(|element| { + if context.summarizing { + element + .tooltip(ui::Tooltip::text("Summarizing...")) + .with_animation( + "pulsating-ctx-pill", + Animation::new(Duration::from_secs(2)) + .repeat() + .with_easing(pulsating_between(0.4, 0.8)), + |label, delta| label.opacity(delta), + ) + .into_any_element() + } else { + element.into_any() + } }), ContextPill::Suggested { name, @@ -220,7 +237,8 @@ impl RenderOnce for ContextPill { .when_some(on_click.as_ref(), |element, on_click| { let on_click = on_click.clone(); element.on_click(move |event, window, cx| on_click(event, window, cx)) - }), + }) + .into_any(), } } } @@ -232,6 +250,7 @@ pub struct AddedContext { pub parent: Option, pub tooltip: Option, pub icon_path: Option, + pub summarizing: bool, } impl AddedContext { @@ -256,6 +275,7 @@ impl AddedContext { parent, tooltip: Some(full_path_string), icon_path: FileIcons::get_icon(&full_path, cx), + summarizing: false, } } @@ -280,6 +300,7 @@ impl 
AddedContext { parent, tooltip: Some(full_path_string), icon_path: None, + summarizing: false, } } @@ -290,6 +311,7 @@ impl AddedContext { parent: None, tooltip: None, icon_path: None, + summarizing: false, }, AssistantContext::FetchedUrl(fetched_url_context) => AddedContext { @@ -299,6 +321,7 @@ impl AddedContext { parent: None, tooltip: None, icon_path: None, + summarizing: false, }, AssistantContext::Thread(thread_context) => AddedContext { @@ -308,6 +331,10 @@ impl AddedContext { parent: None, tooltip: None, icon_path: None, + summarizing: thread_context + .thread + .read(cx) + .is_generating_detailed_summary(), }, } }