From 6b58e38f4af1a63fef7f090f22222fbf374314b7 Mon Sep 17 00:00:00 2001 From: Agus Zubiaga Date: Mon, 12 May 2025 04:45:48 -0300 Subject: [PATCH] agent: Handle thread title generation errors (#30273) The title of a (text) thread would get stuck in "Loading Summary..." when the request to generate it failed. We now handle this case by falling back to the default title, and letting the user manually edit the title or retry generating it. https://github.com/user-attachments/assets/898d26ad-d31f-4b62-9b05-519d923b1b22 Release Notes: - agent: Handle thread title generation errors --------- Co-authored-by: Richard Feldman --- crates/agent/src/active_thread.rs | 19 +- crates/agent/src/agent_diff.rs | 12 +- crates/agent/src/agent_panel.rs | 92 +++-- crates/agent/src/context.rs | 11 +- crates/agent/src/context_strip.rs | 4 +- crates/agent/src/history_store.rs | 4 +- crates/agent/src/thread.rs | 319 ++++++++++++++++-- .../assistant_context_editor/src/context.rs | 157 ++++++--- .../src/context/context_tests.rs | 189 ++++++++++- .../src/context_editor.rs | 7 +- .../src/context_store.rs | 5 +- 11 files changed, 678 insertions(+), 141 deletions(-) diff --git a/crates/agent/src/active_thread.rs b/crates/agent/src/active_thread.rs index 841482e4823f4ba308999ad4ad79258907f0ac46..c0d42bf5a2fafeed2e3b308118e9edca4e5d4684 100644 --- a/crates/agent/src/active_thread.rs +++ b/crates/agent/src/active_thread.rs @@ -6,7 +6,7 @@ use crate::context_strip::{ContextStrip, ContextStripEvent, SuggestContextKind}; use crate::message_editor::insert_message_creases; use crate::thread::{ LastRestoreCheckpoint, MessageCrease, MessageId, MessageSegment, Thread, ThreadError, - ThreadEvent, ThreadFeedback, + ThreadEvent, ThreadFeedback, ThreadSummary, }; use crate::thread_store::{RulesLoadingError, TextThreadStore, ThreadStore}; use crate::tool_use::{PendingToolUseStatus, ToolUse}; @@ -823,12 +823,12 @@ impl ActiveThread { self.messages.is_empty() } - pub fn summary(&self, cx: &App) -> Option { + pub fn summary<'a>(&'a self, cx: &'a App) -> &'a ThreadSummary { self.thread.read(cx).summary() } - pub fn summary_or_default(&self, cx: &App) -> SharedString { - self.thread.read(cx).summary_or_default() + pub fn regenerate_summary(&self, cx: &mut App) { + self.thread.update(cx, |thread, cx| thread.summarize(cx)) } pub fn cancel_last_completion(&mut self, window: &mut Window, cx: &mut App) -> bool { @@ -1134,11 +1134,7 @@ impl ActiveThread { return; } - let title = self - .thread - .read(cx) - .summary() - .unwrap_or("Agent Panel".into()); + let title = self.thread.read(cx).summary().unwrap_or("Agent Panel"); match AssistantSettings::get_global(cx).notify_when_agent_waiting { NotifyWhenAgentWaiting::PrimaryScreen => { @@ -3442,10 +3438,7 @@ pub(crate) fn open_active_thread_as_markdown( workspace.update_in(cx, |workspace, window, cx| { let thread = thread.read(cx); let markdown = thread.to_markdown(cx)?; - let thread_summary = thread - .summary() - .map(|summary| summary.to_string()) - .unwrap_or_else(|| "Thread".to_string()); + let thread_summary = thread.summary().or_default().to_string(); let project = workspace.project().clone(); diff --git a/crates/agent/src/agent_diff.rs b/crates/agent/src/agent_diff.rs index 54e4f1f9aa90bac05370596415e8bebc1df267a5..d83b2cf80df1a57bc9ad3444eadf196d10112a7f 100644 --- a/crates/agent/src/agent_diff.rs +++ b/crates/agent/src/agent_diff.rs @@ -215,11 +215,7 @@ impl AgentDiffPane { } fn update_title(&mut self, cx: &mut Context) { - let new_title = self - .thread - .read(cx) - .summary() - .unwrap_or("Agent Changes".into()); + let new_title = self.thread.read(cx).summary().unwrap_or("Agent Changes"); if new_title != self.title { self.title = new_title; cx.emit(EditorEvent::TitleChanged); @@ -469,11 +465,7 @@ impl Item for AgentDiffPane { } fn tab_content(&self, params: TabContentParams, _window: &Window, cx: &App) -> AnyElement { - let summary = self - .thread - .read(cx) - .summary() - .unwrap_or("Agent Changes".into()); + let summary = self.thread.read(cx).summary().unwrap_or("Agent Changes"); Label::new(format!("Review: {}", summary)) .color(if params.selected { Color::Default diff --git a/crates/agent/src/agent_panel.rs b/crates/agent/src/agent_panel.rs index c48863ae41b98f3747cfe192673baa5ffb9a2f8f..eb1635bf4b4b741f62c2dd46eff3274de307aa1e 100644 --- a/crates/agent/src/agent_panel.rs +++ b/crates/agent/src/agent_panel.rs @@ -10,8 +10,8 @@ use serde::{Deserialize, Serialize}; use anyhow::{Result, anyhow}; use assistant_context_editor::{ AgentPanelDelegate, AssistantContext, ConfigurationError, ContextEditor, ContextEvent, - SlashCommandCompletionProvider, humanize_token_count, make_lsp_adapter_delegate, - render_remaining_tokens, + ContextSummary, SlashCommandCompletionProvider, humanize_token_count, + make_lsp_adapter_delegate, render_remaining_tokens, }; use assistant_settings::{AssistantDockPosition, AssistantSettings}; use assistant_slash_command::SlashCommandWorkingSet; @@ -59,7 +59,7 @@ use crate::agent_configuration::{AgentConfiguration, AssistantConfigurationEvent use crate::agent_diff::AgentDiff; use crate::history_store::{HistoryStore, RecentEntry}; use crate::message_editor::{MessageEditor, MessageEditorEvent}; -use crate::thread::{Thread, ThreadError, ThreadId, TokenUsageRatio}; +use crate::thread::{Thread, ThreadError, ThreadId, ThreadSummary, TokenUsageRatio}; use crate::thread_history::{HistoryEntryElement, ThreadHistory}; use crate::thread_store::ThreadStore; use crate::ui::AgentOnboardingModal; @@ -196,7 +196,7 @@ impl ActiveView { } pub fn thread(thread: Entity, window: &mut Window, cx: &mut App) -> Self { - let summary = thread.read(cx).summary_or_default(); + let summary = thread.read(cx).summary().or_default(); let editor = cx.new(|cx| { let mut editor = Editor::single_line(window, cx); @@ -218,7 +218,7 @@ impl ActiveView { } EditorEvent::Blurred => { if editor.read(cx).text(cx).is_empty() { - let summary = thread.read(cx).summary_or_default(); + let summary = thread.read(cx).summary().or_default(); editor.update(cx, |editor, cx| { editor.set_text(summary, window, cx); @@ -233,7 +233,7 @@ impl ActiveView { let editor = editor.clone(); move |thread, event, window, cx| match event { ThreadEvent::SummaryGenerated => { - let summary = thread.read(cx).summary_or_default(); + let summary = thread.read(cx).summary().or_default(); editor.update(cx, |editor, cx| { editor.set_text(summary, window, cx); @@ -296,7 +296,8 @@ impl ActiveView { .read(cx) .context() .read(cx) - .summary_or_default(); + .summary() + .or_default(); editor.update(cx, |editor, cx| { editor.set_text(summary, window, cx); @@ -311,7 +312,7 @@ impl ActiveView { let editor = editor.clone(); move |assistant_context, event, window, cx| match event { ContextEvent::SummaryGenerated => { - let summary = assistant_context.read(cx).summary_or_default(); + let summary = assistant_context.read(cx).summary().or_default(); editor.update(cx, |editor, cx| { editor.set_text(summary, window, cx); @@ -1452,23 +1453,45 @@ impl AgentPanel { .. } => { let active_thread = self.thread.read(cx); - let is_empty = active_thread.is_empty(); - - let summary = active_thread.summary(cx); + let state = if active_thread.is_empty() { + &ThreadSummary::Pending + } else { + active_thread.summary(cx) + }; - if is_empty { - Label::new(Thread::DEFAULT_SUMMARY.clone()) + match state { + ThreadSummary::Pending => Label::new(ThreadSummary::DEFAULT.clone()) .truncate() - .into_any_element() - } else if summary.is_none() { - Label::new(LOADING_SUMMARY_PLACEHOLDER) + .into_any_element(), + ThreadSummary::Generating => Label::new(LOADING_SUMMARY_PLACEHOLDER) .truncate() - .into_any_element() - } else { - div() + .into_any_element(), + ThreadSummary::Ready(_) => div() .w_full() .child(change_title_editor.clone()) - .into_any_element() + .into_any_element(), + ThreadSummary::Error => h_flex() + .w_full() + .child(change_title_editor.clone()) + .child( + ui::IconButton::new("retry-summary-generation", IconName::RotateCcw) + .on_click({ + let active_thread = self.thread.clone(); + move |_, _window, cx| { + active_thread.update(cx, |thread, cx| { + thread.regenerate_summary(cx); + }); + } + }) + .tooltip(move |_window, cx| { + cx.new(|_| { + Tooltip::new("Failed to generate title") + .meta("Click to try again") + }) + .into() + }), + ) + .into_any_element(), } } ActiveView::PromptEditor { @@ -1476,14 +1499,13 @@ impl AgentPanel { context_editor, .. } => { - let context_editor = context_editor.read(cx); - let summary = context_editor.context().read(cx).summary(); + let summary = context_editor.read(cx).context().read(cx).summary(); match summary { - None => Label::new(AssistantContext::DEFAULT_SUMMARY.clone()) + ContextSummary::Pending => Label::new(ContextSummary::DEFAULT) .truncate() .into_any_element(), - Some(summary) => { + ContextSummary::Content(summary) => { if summary.done { div() .w_full() @@ -1495,6 +1517,28 @@ impl AgentPanel { .into_any_element() } } + ContextSummary::Error => h_flex() + .w_full() + .child(title_editor.clone()) + .child( + ui::IconButton::new("retry-summary-generation", IconName::RotateCcw) + .on_click({ + let context_editor = context_editor.clone(); + move |_, _window, cx| { + context_editor.update(cx, |context_editor, cx| { + context_editor.regenerate_summary(cx); + }); + } + }) + .tooltip(move |_window, cx| { + cx.new(|_| { + Tooltip::new("Failed to generate title") + .meta("Click to try again") + }) + .into() + }), + ) + .into_any_element(), } } ActiveView::History => Label::new("History").truncate().into_any_element(), diff --git a/crates/agent/src/context.rs b/crates/agent/src/context.rs index 4fa8183e714cbdd50c49d658f70168aa9d2fa475..5bcd1eb3c55f6f5f6c09dd55a96315e23ca1b3b0 100644 --- a/crates/agent/src/context.rs +++ b/crates/agent/src/context.rs @@ -586,10 +586,7 @@ impl ThreadContextHandle { } pub fn title(&self, cx: &App) -> SharedString { - self.thread - .read(cx) - .summary() - .unwrap_or_else(|| "New thread".into()) + self.thread.read(cx).summary().or_default() } fn load(self, cx: &App) -> Task>)>> { @@ -597,9 +594,7 @@ impl ThreadContextHandle { let text = Thread::wait_for_detailed_summary_or_text(&self.thread, cx).await?; let title = self .thread - .read_with(cx, |thread, _cx| { - thread.summary().unwrap_or_else(|| "New thread".into()) - }) + .read_with(cx, |thread, _cx| thread.summary().or_default()) .ok()?; let context = AgentContext::Thread(ThreadContext { title, @@ -642,7 +637,7 @@ impl TextThreadContextHandle { } pub fn title(&self, cx: &App) -> SharedString { - self.context.read(cx).summary_or_default() + self.context.read(cx).summary().or_default() } fn load(self, cx: &App) -> Task>)>> { diff --git a/crates/agent/src/context_strip.rs b/crates/agent/src/context_strip.rs index f9d9ff278181b640eb029e8f7359ef9469487996..8fe1a21d7480d2703fa0545712baad7c1e2463e6 100644 --- a/crates/agent/src/context_strip.rs +++ b/crates/agent/src/context_strip.rs @@ -160,7 +160,7 @@ impl ContextStrip { } Some(SuggestedContext::Thread { - name: active_thread.summary_or_default(), + name: active_thread.summary().or_default(), thread: weak_active_thread, }) } else if let Some(active_context_editor) = panel.active_context_editor() { @@ -174,7 +174,7 @@ impl ContextStrip { } Some(SuggestedContext::TextThread { - name: context.summary_or_default(), + name: context.summary().or_default(), context: weak_context, }) } else { diff --git a/crates/agent/src/history_store.rs b/crates/agent/src/history_store.rs index 85fdac2bab0066d8c7025fb60139d9ba62db68bd..c8d9e9a26396bb57aae30664f581a7bff1b78984 100644 --- a/crates/agent/src/history_store.rs +++ b/crates/agent/src/history_store.rs @@ -71,8 +71,8 @@ impl Eq for RecentEntry {} impl RecentEntry { pub(crate) fn summary(&self, cx: &App) -> SharedString { match self { - RecentEntry::Thread(_, thread) => thread.read(cx).summary_or_default(), - RecentEntry::Context(context) => context.read(cx).summary_or_default(), + RecentEntry::Thread(_, thread) => thread.read(cx).summary().or_default(), + RecentEntry::Context(context) => context.read(cx).summary().or_default(), } } } diff --git a/crates/agent/src/thread.rs b/crates/agent/src/thread.rs index 0f349514eb818bd427f19da03adae69e1e4025c7..eed797ca7b27a2fbb74c732aa63c290899b8aa72 100644 --- a/crates/agent/src/thread.rs +++ b/crates/agent/src/thread.rs @@ -36,7 +36,7 @@ use serde::{Deserialize, Serialize}; use settings::Settings; use thiserror::Error; use ui::Window; -use util::{ResultExt as _, TryFutureExt as _, post_inc}; +use util::{ResultExt as _, post_inc}; use uuid::Uuid; use zed_llm_client::CompletionRequestStatus; @@ -324,7 +324,7 @@ pub enum QueueState { pub struct Thread { id: ThreadId, updated_at: DateTime, - summary: Option, + summary: ThreadSummary, pending_summary: Task>, detailed_summary_task: Task>, detailed_summary_tx: postage::watch::Sender, @@ -361,6 +361,33 @@ pub struct Thread { configured_model: Option, } +#[derive(Clone, Debug, PartialEq, Eq)] +pub enum ThreadSummary { + Pending, + Generating, + Ready(SharedString), + Error, +} + +impl ThreadSummary { + pub const DEFAULT: SharedString = SharedString::new_static("New Thread"); + + pub fn or_default(&self) -> SharedString { + self.unwrap_or(Self::DEFAULT) + } + + pub fn unwrap_or(&self, message: impl Into) -> SharedString { + self.ready().unwrap_or_else(|| message.into()) + } + + pub fn ready(&self) -> Option { + match self { + ThreadSummary::Ready(summary) => Some(summary.clone()), + ThreadSummary::Pending | ThreadSummary::Generating | ThreadSummary::Error => None, + } + } +} + #[derive(Debug, Clone, Serialize, Deserialize)] pub struct ExceededWindowError { /// Model used when last message exceeded context window @@ -383,7 +410,7 @@ impl Thread { Self { id: ThreadId::new(), updated_at: Utc::now(), - summary: None, + summary: ThreadSummary::Pending, pending_summary: Task::ready(None), detailed_summary_task: Task::ready(None), detailed_summary_tx, @@ -471,7 +498,7 @@ impl Thread { Self { id, updated_at: serialized.updated_at, - summary: Some(serialized.summary), + summary: ThreadSummary::Ready(serialized.summary), pending_summary: Task::ready(None), detailed_summary_task: Task::ready(None), detailed_summary_tx, @@ -572,10 +599,6 @@ impl Thread { self.last_prompt_id = PromptId::new(); } - pub fn summary(&self) -> Option { - self.summary.clone() - } - pub fn project_context(&self) -> SharedProjectContext { self.project_context.clone() } @@ -596,26 +619,25 @@ impl Thread { cx.notify(); } - pub const DEFAULT_SUMMARY: SharedString = SharedString::new_static("New Thread"); - - pub fn summary_or_default(&self) -> SharedString { - self.summary.clone().unwrap_or(Self::DEFAULT_SUMMARY) + pub fn summary(&self) -> &ThreadSummary { + &self.summary } pub fn set_summary(&mut self, new_summary: impl Into, cx: &mut Context) { - let Some(current_summary) = &self.summary else { - // Don't allow setting summary until generated - return; + let current_summary = match &self.summary { + ThreadSummary::Pending | ThreadSummary::Generating => return, + ThreadSummary::Ready(summary) => summary, + ThreadSummary::Error => &ThreadSummary::DEFAULT, }; let mut new_summary = new_summary.into(); if new_summary.is_empty() { - new_summary = Self::DEFAULT_SUMMARY; + new_summary = ThreadSummary::DEFAULT; } if current_summary != &new_summary { - self.summary = Some(new_summary); + self.summary = ThreadSummary::Ready(new_summary); cx.emit(ThreadEvent::SummaryChanged); } } @@ -1029,7 +1051,7 @@ impl Thread { let initial_project_snapshot = initial_project_snapshot.await; this.read_with(cx, |this, cx| SerializedThread { version: SerializedThread::VERSION.to_string(), - summary: this.summary_or_default(), + summary: this.summary().or_default(), updated_at: this.updated_at(), messages: this .messages() @@ -1625,7 +1647,7 @@ impl Thread { // If there is a response without tool use, summarize the message. Otherwise, // allow two tool uses before summarizing. - if thread.summary.is_none() + if matches!(thread.summary, ThreadSummary::Pending) && thread.messages.len() >= 2 && (!thread.has_pending_tool_uses() || thread.messages.len() >= 6) { @@ -1739,6 +1761,7 @@ impl Thread { pub fn summarize(&mut self, cx: &mut Context) { let Some(model) = LanguageModelRegistry::read_global(cx).thread_summary_model() else { + println!("No thread summary model"); return; }; @@ -1753,13 +1776,17 @@ impl Thread { let request = self.to_summarize_request(&model.model, added_user_message.into(), cx); + self.summary = ThreadSummary::Generating; + self.pending_summary = cx.spawn(async move |this, cx| { - async move { + let result = async { let mut messages = model.model.stream_completion(request, &cx).await?; let mut new_summary = String::new(); while let Some(event) = messages.next().await { - let event = event?; + let Ok(event) = event else { + continue; + }; let text = match event { LanguageModelCompletionEvent::Text(text) => text, LanguageModelCompletionEvent::StatusUpdate( @@ -1785,18 +1812,29 @@ impl Thread { } } - this.update(cx, |this, cx| { - if !new_summary.is_empty() { - this.summary = Some(new_summary.into()); - } + anyhow::Ok(new_summary) + } + .await; - cx.emit(ThreadEvent::SummaryGenerated); - })?; + this.update(cx, |this, cx| { + match result { + Ok(new_summary) => { + if new_summary.is_empty() { + this.summary = ThreadSummary::Error; + } else { + this.summary = ThreadSummary::Ready(new_summary.into()); + } + } + Err(err) => { + this.summary = ThreadSummary::Error; + log::error!("Failed to generate thread summary: {}", err); + } + } + cx.emit(ThreadEvent::SummaryGenerated); + }) + .log_err()?; - anyhow::Ok(()) - } - .log_err() - .await + Some(()) }); } @@ -2406,9 +2444,8 @@ impl Thread { pub fn to_markdown(&self, cx: &App) -> Result { let mut markdown = Vec::new(); - if let Some(summary) = self.summary() { - writeln!(markdown, "# {summary}\n")?; - }; + let summary = self.summary().or_default(); + writeln!(markdown, "# {summary}\n")?; for message in self.messages() { writeln!( @@ -2725,7 +2762,7 @@ mod tests { use assistant_tool::ToolRegistry; use editor::EditorSettings; use gpui::TestAppContext; - use language_model::fake_provider::FakeLanguageModel; + use language_model::fake_provider::{FakeLanguageModel, FakeLanguageModelProvider}; use project::{FakeFs, Project}; use prompt_store::PromptBuilder; use serde_json::json; @@ -3226,6 +3263,196 @@ fn main() {{ assert_eq!(request.temperature, None); } + #[gpui::test] + async fn test_thread_summary(cx: &mut TestAppContext) { + init_test_settings(cx); + + let project = create_test_project(cx, json!({})).await; + + let (_, _thread_store, thread, _context_store, model) = + setup_test_environment(cx, project.clone()).await; + + // Initial state should be pending + thread.read_with(cx, |thread, _| { + assert!(matches!(thread.summary(), ThreadSummary::Pending)); + assert_eq!(thread.summary().or_default(), ThreadSummary::DEFAULT); + }); + + // Manually setting the summary should not be allowed in this state + thread.update(cx, |thread, cx| { + thread.set_summary("This should not work", cx); + }); + + thread.read_with(cx, |thread, _| { + assert!(matches!(thread.summary(), ThreadSummary::Pending)); + }); + + // Send a message + thread.update(cx, |thread, cx| { + thread.insert_user_message("Hi!", ContextLoadResult::default(), None, vec![], cx); + thread.send_to_model(model.clone(), None, cx); + }); + + let fake_model = model.as_fake(); + simulate_successful_response(&fake_model, cx); + + // Should start generating summary when there are >= 2 messages + thread.read_with(cx, |thread, _| { + assert_eq!(*thread.summary(), ThreadSummary::Generating); + }); + + // Should not be able to set the summary while generating + thread.update(cx, |thread, cx| { + thread.set_summary("This should not work either", cx); + }); + + thread.read_with(cx, |thread, _| { + assert!(matches!(thread.summary(), ThreadSummary::Generating)); + assert_eq!(thread.summary().or_default(), ThreadSummary::DEFAULT); + }); + + cx.run_until_parked(); + fake_model.stream_last_completion_response("Brief".into()); + fake_model.stream_last_completion_response(" Introduction".into()); + fake_model.end_last_completion_stream(); + cx.run_until_parked(); + + // Summary should be set + thread.read_with(cx, |thread, _| { + assert!(matches!(thread.summary(), ThreadSummary::Ready(_))); + assert_eq!(thread.summary().or_default(), "Brief Introduction"); + }); + + // Now we should be able to set a summary + thread.update(cx, |thread, cx| { + thread.set_summary("Brief Intro", cx); + }); + + thread.read_with(cx, |thread, _| { + assert_eq!(thread.summary().or_default(), "Brief Intro"); + }); + + // Test setting an empty summary (should default to DEFAULT) + thread.update(cx, |thread, cx| { + thread.set_summary("", cx); + }); + + thread.read_with(cx, |thread, _| { + assert!(matches!(thread.summary(), ThreadSummary::Ready(_))); + assert_eq!(thread.summary().or_default(), ThreadSummary::DEFAULT); + }); + } + + #[gpui::test] + async fn test_thread_summary_error_set_manually(cx: &mut TestAppContext) { + init_test_settings(cx); + + let project = create_test_project(cx, json!({})).await; + + let (_, _thread_store, thread, _context_store, model) = + setup_test_environment(cx, project.clone()).await; + + test_summarize_error(&model, &thread, cx); + + // Now we should be able to set a summary + thread.update(cx, |thread, cx| { + thread.set_summary("Brief Intro", cx); + }); + + thread.read_with(cx, |thread, _| { + assert!(matches!(thread.summary(), ThreadSummary::Ready(_))); + assert_eq!(thread.summary().or_default(), "Brief Intro"); + }); + } + + #[gpui::test] + async fn test_thread_summary_error_retry(cx: &mut TestAppContext) { + init_test_settings(cx); + + let project = create_test_project(cx, json!({})).await; + + let (_, _thread_store, thread, _context_store, model) = + setup_test_environment(cx, project.clone()).await; + + test_summarize_error(&model, &thread, cx); + + // Sending another message should not trigger another summarize request + thread.update(cx, |thread, cx| { + thread.insert_user_message( + "How are you?", + ContextLoadResult::default(), + None, + vec![], + cx, + ); + thread.send_to_model(model.clone(), None, cx); + }); + + let fake_model = model.as_fake(); + simulate_successful_response(&fake_model, cx); + + thread.read_with(cx, |thread, _| { + // State is still Error, not Generating + assert!(matches!(thread.summary(), ThreadSummary::Error)); + }); + + // But the summarize request can be invoked manually + thread.update(cx, |thread, cx| { + thread.summarize(cx); + }); + + thread.read_with(cx, |thread, _| { + assert!(matches!(thread.summary(), ThreadSummary::Generating)); + }); + + cx.run_until_parked(); + fake_model.stream_last_completion_response("A successful summary".into()); + fake_model.end_last_completion_stream(); + cx.run_until_parked(); + + thread.read_with(cx, |thread, _| { + assert!(matches!(thread.summary(), ThreadSummary::Ready(_))); + assert_eq!(thread.summary().or_default(), "A successful summary"); + }); + } + + fn test_summarize_error( + model: &Arc, + thread: &Entity, + cx: &mut TestAppContext, + ) { + thread.update(cx, |thread, cx| { + thread.insert_user_message("Hi!", ContextLoadResult::default(), None, vec![], cx); + thread.send_to_model(model.clone(), None, cx); + }); + + let fake_model = model.as_fake(); + simulate_successful_response(&fake_model, cx); + + thread.read_with(cx, |thread, _| { + assert!(matches!(thread.summary(), ThreadSummary::Generating)); + assert_eq!(thread.summary().or_default(), ThreadSummary::DEFAULT); + }); + + // Simulate summary request ending + cx.run_until_parked(); + fake_model.end_last_completion_stream(); + cx.run_until_parked(); + + // State is set to Error and default message + thread.read_with(cx, |thread, _| { + assert!(matches!(thread.summary(), ThreadSummary::Error)); + assert_eq!(thread.summary().or_default(), ThreadSummary::DEFAULT); + }); + } + + fn simulate_successful_response(fake_model: &FakeLanguageModel, cx: &mut TestAppContext) { + cx.run_until_parked(); + fake_model.stream_last_completion_response("Assistant response".into()); + fake_model.end_last_completion_stream(); + cx.run_until_parked(); + } + fn init_test_settings(cx: &mut TestAppContext) { cx.update(|cx| { let settings_store = SettingsStore::test(cx); @@ -3282,9 +3509,29 @@ fn main() {{ let thread = thread_store.update(cx, |store, cx| store.create_thread(cx)); let context_store = cx.new(|_cx| ContextStore::new(project.downgrade(), None)); - let model = FakeLanguageModel::default(); + let provider = Arc::new(FakeLanguageModelProvider); + let model = provider.test_model(); let model: Arc = Arc::new(model); + cx.update(|_, cx| { + LanguageModelRegistry::global(cx).update(cx, |registry, cx| { + registry.set_default_model( + Some(ConfiguredModel { + provider: provider.clone(), + model: model.clone(), + }), + cx, + ); + registry.set_thread_summary_model( + Some(ConfiguredModel { + provider, + model: model.clone(), + }), + cx, + ); + }) + }); + (workspace, thread_store, thread, context_store, model) } diff --git a/crates/assistant_context_editor/src/context.rs b/crates/assistant_context_editor/src/context.rs index 793cf38b8a9fb365da27a313f5bb66a38fe76786..fcdd624f92f7a1c4609b48b54e0bd73338ef6fb5 100644 --- a/crates/assistant_context_editor/src/context.rs +++ b/crates/assistant_context_editor/src/context.rs @@ -2,7 +2,7 @@ mod context_tests; use crate::patch::{AssistantEdit, AssistantPatch, AssistantPatchStatus}; -use anyhow::{Context as _, Result, anyhow}; +use anyhow::{Context as _, Result, anyhow, bail}; use assistant_settings::AssistantSettings; use assistant_slash_command::{ SlashCommandContent, SlashCommandEvent, SlashCommandLine, SlashCommandOutputSection, @@ -143,7 +143,7 @@ pub enum ContextOperation { version: clock::Global, }, UpdateSummary { - summary: ContextSummary, + summary: ContextSummaryContent, version: clock::Global, }, SlashCommandStarted { @@ -213,7 +213,7 @@ impl ContextOperation { version: language::proto::deserialize_version(&update.version), }), proto::context_operation::Variant::UpdateSummary(update) => Ok(Self::UpdateSummary { - summary: ContextSummary { + summary: ContextSummaryContent { text: update.summary, done: update.done, timestamp: language::proto::deserialize_timestamp( @@ -481,11 +481,73 @@ pub enum ContextEvent { Operation(ContextOperation), } -#[derive(Clone, Default, Debug)] -pub struct ContextSummary { +#[derive(Clone, Debug, Eq, PartialEq)] +pub enum ContextSummary { + Pending, + Content(ContextSummaryContent), + Error, +} + +#[derive(Default, Clone, Debug, Eq, PartialEq)] +pub struct ContextSummaryContent { pub text: String, pub done: bool, - timestamp: clock::Lamport, + pub timestamp: clock::Lamport, +} + +impl ContextSummary { + pub const DEFAULT: &str = "New Text Thread"; + + pub fn or_default(&self) -> SharedString { + self.unwrap_or(Self::DEFAULT) + } + + pub fn unwrap_or(&self, message: impl Into) -> SharedString { + self.content() + .map_or_else(|| message.into(), |content| content.text.clone().into()) + } + + pub fn content(&self) -> Option<&ContextSummaryContent> { + match self { + ContextSummary::Content(content) => Some(content), + ContextSummary::Pending | ContextSummary::Error => None, + } + } + + fn content_as_mut(&mut self) -> Option<&mut ContextSummaryContent> { + match self { + ContextSummary::Content(content) => Some(content), + ContextSummary::Pending | ContextSummary::Error => None, + } + } + + fn content_or_set_empty(&mut self) -> &mut ContextSummaryContent { + match self { + ContextSummary::Content(content) => content, + ContextSummary::Pending | ContextSummary::Error => { + let content = ContextSummaryContent::default(); + *self = ContextSummary::Content(content); + self.content_as_mut().unwrap() + } + } + } + + pub fn is_pending(&self) -> bool { + matches!(self, ContextSummary::Pending) + } + + fn timestamp(&self) -> Option { + match self { + ContextSummary::Content(content) => Some(content.timestamp), + ContextSummary::Pending | ContextSummary::Error => None, + } + } +} + +impl PartialOrd for ContextSummary { + fn partial_cmp(&self, other: &Self) -> Option { + self.timestamp().partial_cmp(&other.timestamp()) + } } #[derive(Clone, Debug, Eq, PartialEq)] @@ -641,7 +703,7 @@ pub struct AssistantContext { message_anchors: Vec, contents: Vec, messages_metadata: HashMap, - summary: Option, + summary: ContextSummary, summary_task: Task>, completion_count: usize, pending_completions: Vec, @@ -742,7 +804,7 @@ impl AssistantContext { slash_command_output_sections: Vec::new(), thought_process_output_sections: Vec::new(), edits_since_last_parse: edits_since_last_slash_command_parse, - summary: None, + summary: ContextSummary::Pending, summary_task: Task::ready(None), completion_count: Default::default(), pending_completions: Default::default(), @@ -803,7 +865,7 @@ impl AssistantContext { .collect(), summary: self .summary - .as_ref() + .content() .map(|summary| summary.text.clone()) .unwrap_or_default(), slash_command_output_sections: self @@ -989,12 +1051,10 @@ impl AssistantContext { summary: new_summary, .. } => { - if self - .summary - .as_ref() - .map_or(true, |summary| new_summary.timestamp > summary.timestamp) - { - self.summary = Some(new_summary); + if self.summary.timestamp().map_or(true, |current_timestamp| { + new_summary.timestamp > current_timestamp + }) { + self.summary = ContextSummary::Content(new_summary); summary_generated = true; } } @@ -1152,8 +1212,8 @@ impl AssistantContext { self.path.as_ref() } - pub fn summary(&self) -> Option<&ContextSummary> { - self.summary.as_ref() + pub fn summary(&self) -> &ContextSummary { + &self.summary } pub fn patch_containing(&self, position: Point, cx: &App) -> Option<&AssistantPatch> { @@ -2980,7 +3040,7 @@ impl AssistantContext { return; }; - if replace_old || (self.message_anchors.len() >= 2 && self.summary.is_none()) { + if replace_old || (self.message_anchors.len() >= 2 && self.summary.is_pending()) { if !model.provider.is_authenticated(cx) { return; } @@ -2997,17 +3057,20 @@ impl AssistantContext { // If there is no summary, it is set with `done: false` so that "Loading Summary…" can // be displayed. - if self.summary.is_none() { - self.summary = Some(ContextSummary { - text: "".to_string(), - done: false, - timestamp: clock::Lamport::default(), - }); - replace_old = true; + match self.summary { + ContextSummary::Pending | ContextSummary::Error => { + self.summary = ContextSummary::Content(ContextSummaryContent { + text: "".to_string(), + done: false, + timestamp: clock::Lamport::default(), + }); + replace_old = true; + } + ContextSummary::Content(_) => {} } self.summary_task = cx.spawn(async move |this, cx| { - async move { + let result = async { let stream = model.model.stream_completion_text(request, &cx); let mut messages = stream.await?; @@ -3018,7 +3081,7 @@ impl AssistantContext { this.update(cx, |this, cx| { let version = this.version.clone(); let timestamp = this.next_timestamp(); - let summary = this.summary.get_or_insert(ContextSummary::default()); + let summary = this.summary.content_or_set_empty(); if !replaced && replace_old { summary.text.clear(); replaced = true; @@ -3040,10 +3103,19 @@ impl AssistantContext { } } + this.read_with(cx, |this, _cx| { + if let Some(summary) = this.summary.content() { + if summary.text.is_empty() { + bail!("Model generated an empty summary"); + } + } + Ok(()) + })??; + this.update(cx, |this, cx| { let version = this.version.clone(); let timestamp = this.next_timestamp(); - if let Some(summary) = this.summary.as_mut() { + if let Some(summary) = this.summary.content_as_mut() { summary.done = true; summary.timestamp = timestamp; let operation = ContextOperation::UpdateSummary { @@ -3058,8 +3130,18 @@ impl AssistantContext { anyhow::Ok(()) } - .log_err() - .await + .await; + + if let Err(err) = result { + this.update(cx, |this, cx| { + this.summary = ContextSummary::Error; + cx.emit(ContextEvent::SummaryChanged); + }) + .log_err(); + log::error!("Error generating context summary: {}", err); + } + + Some(()) }); } } @@ -3173,7 +3255,7 @@ impl AssistantContext { let (old_path, summary) = this.read_with(cx, |this, _| { let path = this.path.clone(); - let summary = if let Some(summary) = this.summary.as_ref() { + let summary = if let Some(summary) = this.summary.content() { if summary.done { Some(summary.text.clone()) } else { @@ -3227,21 +3309,12 @@ impl AssistantContext { pub fn set_custom_summary(&mut self, custom_summary: String, cx: &mut Context) { let timestamp = self.next_timestamp(); - let summary = self.summary.get_or_insert(ContextSummary::default()); + let summary = self.summary.content_or_set_empty(); summary.timestamp = timestamp; summary.done = true; summary.text = custom_summary; cx.emit(ContextEvent::SummaryChanged); } - - pub const DEFAULT_SUMMARY: SharedString = SharedString::new_static("New Text Thread"); - - pub fn summary_or_default(&self) -> SharedString { - self.summary - .as_ref() - .map(|summary| summary.text.clone().into()) - .unwrap_or(Self::DEFAULT_SUMMARY) - } } fn trimmed_text_in_range(buffer: &BufferSnapshot, range: Range) -> String { @@ -3475,7 +3548,7 @@ impl SavedContext { let timestamp = next_timestamp.tick(); operations.push(ContextOperation::UpdateSummary { - summary: ContextSummary { + summary: ContextSummaryContent { text: self.summary, done: true, timestamp, diff --git a/crates/assistant_context_editor/src/context/context_tests.rs b/crates/assistant_context_editor/src/context/context_tests.rs index 66da203886a5252c59277db5e5fd5f73da6610eb..7f9d0d893f09e0ef89e4b268e9a6ca4c03404452 100644 --- a/crates/assistant_context_editor/src/context/context_tests.rs +++ b/crates/assistant_context_editor/src/context/context_tests.rs @@ -1,6 +1,7 @@ use crate::{ AssistantContext, AssistantEdit, AssistantEditKind, CacheStatus, ContextEvent, ContextId, - ContextOperation, InvokedSlashCommandId, MessageCacheMetadata, MessageId, MessageStatus, + ContextOperation, ContextSummary, InvokedSlashCommandId, MessageCacheMetadata, MessageId, + MessageStatus, }; use anyhow::Result; use assistant_slash_command::{ @@ -16,7 +17,10 @@ use futures::{ }; use gpui::{App, Entity, SharedString, Task, TestAppContext, WeakEntity, prelude::*}; use language::{Buffer, BufferSnapshot, LanguageRegistry, LspAdapterDelegate}; -use language_model::{LanguageModelCacheConfiguration, LanguageModelRegistry, Role}; +use language_model::{ + ConfiguredModel, LanguageModelCacheConfiguration, LanguageModelRegistry, Role, + fake_provider::{FakeLanguageModel, FakeLanguageModelProvider}, +}; use parking_lot::Mutex; use pretty_assertions::assert_eq; use project::Project; @@ -1575,6 +1579,187 @@ fn test_mark_cache_anchors(cx: &mut App) { ); } +#[gpui::test] +async fn test_summarization(cx: &mut TestAppContext) { + let (context, fake_model) = setup_context_editor_with_fake_model(cx); + + // Initial state should be pending + context.read_with(cx, |context, _| { + assert!(matches!(context.summary(), ContextSummary::Pending)); + assert_eq!(context.summary().or_default(), ContextSummary::DEFAULT); + }); + + let message_1 = context.read_with(cx, |context, _cx| context.message_anchors[0].clone()); + context.update(cx, |context, cx| { + context + .insert_message_after(message_1.id, Role::Assistant, MessageStatus::Done, cx) + .unwrap(); + }); + + // Send a message + context.update(cx, |context, cx| { + context.assist(cx); + }); + + simulate_successful_response(&fake_model, cx); + + // Should start generating summary when there are >= 2 messages + context.read_with(cx, |context, _| { + assert!(!context.summary().content().unwrap().done); + }); + + cx.run_until_parked(); + fake_model.stream_last_completion_response("Brief".into()); + fake_model.stream_last_completion_response(" Introduction".into()); + fake_model.end_last_completion_stream(); + cx.run_until_parked(); + + // Summary should be set + context.read_with(cx, |context, _| { + assert_eq!(context.summary().or_default(), "Brief Introduction"); + }); + + // We should be able to manually set a summary + context.update(cx, |context, cx| { + context.set_custom_summary("Brief Intro".into(), cx); + }); + + context.read_with(cx, |context, _| { + assert_eq!(context.summary().or_default(), "Brief Intro"); + }); +} + +#[gpui::test] +async fn test_thread_summary_error_set_manually(cx: &mut TestAppContext) { + let (context, fake_model) = setup_context_editor_with_fake_model(cx); + + test_summarize_error(&fake_model, &context, cx); + + // Now we should be able to set a summary + context.update(cx, |context, cx| { + context.set_custom_summary("Brief Intro".into(), cx); + }); + + context.read_with(cx, |context, _| { + assert_eq!(context.summary().or_default(), "Brief Intro"); + }); +} + +#[gpui::test] +async fn test_thread_summary_error_retry(cx: &mut TestAppContext) { + let (context, fake_model) = setup_context_editor_with_fake_model(cx); + + test_summarize_error(&fake_model, &context, cx); + + // Sending another message should not trigger another summarize request + context.update(cx, |context, cx| { + context.assist(cx); + }); + + simulate_successful_response(&fake_model, cx); + + context.read_with(cx, |context, _| { + // State is still Error, not Generating + assert!(matches!(context.summary(), ContextSummary::Error)); + }); + + // But the summarize request can be invoked manually + context.update(cx, |context, cx| { + context.summarize(true, cx); + }); + + context.read_with(cx, |context, _| { + assert!(!context.summary().content().unwrap().done); + }); + + cx.run_until_parked(); + fake_model.stream_last_completion_response("A successful summary".into()); + fake_model.end_last_completion_stream(); + cx.run_until_parked(); + + context.read_with(cx, |context, _| { + assert_eq!(context.summary().or_default(), "A successful summary"); + }); +} + +fn test_summarize_error( + model: &Arc, + context: &Entity, + cx: &mut TestAppContext, +) { + let message_1 = context.read_with(cx, |context, _cx| context.message_anchors[0].clone()); + context.update(cx, |context, cx| { + context + .insert_message_after(message_1.id, Role::Assistant, MessageStatus::Done, cx) + .unwrap(); + }); + + // Send a message + context.update(cx, |context, cx| { + context.assist(cx); + }); + + simulate_successful_response(&model, cx); + + context.read_with(cx, |context, _| { + assert!(!context.summary().content().unwrap().done); + }); + + // Simulate summary request ending + cx.run_until_parked(); + model.end_last_completion_stream(); + cx.run_until_parked(); + + // State is set to Error and default message + context.read_with(cx, |context, _| { + assert_eq!(*context.summary(), ContextSummary::Error); + assert_eq!(context.summary().or_default(), ContextSummary::DEFAULT); + }); +} + +fn setup_context_editor_with_fake_model( + cx: &mut TestAppContext, +) -> (Entity, Arc) { + let registry = Arc::new(LanguageRegistry::test(cx.executor().clone())); + + let fake_provider = Arc::new(FakeLanguageModelProvider); + let fake_model = Arc::new(fake_provider.test_model()); + + cx.update(|cx| { + init_test(cx); + LanguageModelRegistry::global(cx).update(cx, |registry, cx| { + registry.set_default_model( + Some(ConfiguredModel { + provider: fake_provider.clone(), + model: fake_model.clone(), + }), + cx, + ) + }) + }); + + let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap()); + let context = cx.new(|cx| { + AssistantContext::local( + registry, + None, + None, + prompt_builder.clone(), + Arc::new(SlashCommandWorkingSet::default()), + cx, + ) + }); + + (context, fake_model) +} + +fn simulate_successful_response(fake_model: &FakeLanguageModel, cx: &mut TestAppContext) { + cx.run_until_parked(); + fake_model.stream_last_completion_response("Assistant response".into()); + fake_model.end_last_completion_stream(); + cx.run_until_parked(); +} + fn messages(context: &Entity, cx: &App) -> Vec<(MessageId, Role, Range)> { context .read(cx) diff --git a/crates/assistant_context_editor/src/context_editor.rs b/crates/assistant_context_editor/src/context_editor.rs index 8ff672613aa9962c92487279e4a460d2961bbdc9..12d3b687cc1b34d4bd8d7477a0a963e7cd7ec3d0 100644 --- a/crates/assistant_context_editor/src/context_editor.rs +++ b/crates/assistant_context_editor/src/context_editor.rs @@ -2202,7 +2202,12 @@ impl ContextEditor { } pub fn title(&self, cx: &App) -> SharedString { - self.context.read(cx).summary_or_default() + self.context.read(cx).summary().or_default() + } + + pub fn regenerate_summary(&mut self, cx: &mut Context) { + self.context + .update(cx, |context, cx| context.summarize(true, cx)); } fn render_patch_block( diff --git a/crates/assistant_context_editor/src/context_store.rs b/crates/assistant_context_editor/src/context_store.rs index fe89d7610948e670a87963a8e5c8f24f8529311f..f1f3b501a65f59288f97663e2096df0dae84c6a0 100644 --- a/crates/assistant_context_editor/src/context_store.rs +++ b/crates/assistant_context_editor/src/context_store.rs @@ -648,7 +648,10 @@ impl ContextStore { if context.replica_id() == ReplicaId::default() { Some(proto::ContextMetadata { context_id: context.id().to_proto(), - summary: context.summary().map(|summary| summary.text.clone()), + summary: context + .summary() + .content() + .map(|summary| summary.text.clone()), }) } else { None