From f0da3b74f818226a3630fad021be7fde08163176 Mon Sep 17 00:00:00 2001 From: Agus Zubiaga Date: Mon, 12 May 2025 04:45:48 -0300 Subject: [PATCH] agent: Handle thread title generation errors (#30273) The title of a (text) thread would get stuck in "Loading Summary..." when the request to generate it failed. We now handle this case by falling back to the default title, and letting the user manually edit the title or retry generating it. https://github.com/user-attachments/assets/898d26ad-d31f-4b62-9b05-519d923b1b22 Release Notes: - agent: Handle thread title generation errors --------- Co-authored-by: Richard Feldman --- crates/agent/src/active_thread.rs | 19 +- crates/agent/src/agent_diff.rs | 12 +- crates/agent/src/agent_panel.rs | 92 +++-- crates/agent/src/context.rs | 11 +- crates/agent/src/context_strip.rs | 4 +- crates/agent/src/history_store.rs | 4 +- crates/agent/src/thread.rs | 319 ++++++++++++++++-- .../assistant_context_editor/src/context.rs | 157 ++++++--- .../src/context/context_tests.rs | 188 ++++++++++- .../src/context_editor.rs | 7 +- .../src/context_store.rs | 5 +- 11 files changed, 677 insertions(+), 141 deletions(-) diff --git a/crates/agent/src/active_thread.rs b/crates/agent/src/active_thread.rs index 841482e4823f4ba308999ad4ad79258907f0ac46..c0d42bf5a2fafeed2e3b308118e9edca4e5d4684 100644 --- a/crates/agent/src/active_thread.rs +++ b/crates/agent/src/active_thread.rs @@ -6,7 +6,7 @@ use crate::context_strip::{ContextStrip, ContextStripEvent, SuggestContextKind}; use crate::message_editor::insert_message_creases; use crate::thread::{ LastRestoreCheckpoint, MessageCrease, MessageId, MessageSegment, Thread, ThreadError, - ThreadEvent, ThreadFeedback, + ThreadEvent, ThreadFeedback, ThreadSummary, }; use crate::thread_store::{RulesLoadingError, TextThreadStore, ThreadStore}; use crate::tool_use::{PendingToolUseStatus, ToolUse}; @@ -823,12 +823,12 @@ impl ActiveThread { self.messages.is_empty() } - pub fn summary(&self, cx: &App) -> Option { + pub fn summary<'a>(&'a self, cx: &'a App) -> &'a ThreadSummary { self.thread.read(cx).summary() } - pub fn summary_or_default(&self, cx: &App) -> SharedString { - self.thread.read(cx).summary_or_default() + pub fn regenerate_summary(&self, cx: &mut App) { + self.thread.update(cx, |thread, cx| thread.summarize(cx)) } pub fn cancel_last_completion(&mut self, window: &mut Window, cx: &mut App) -> bool { @@ -1134,11 +1134,7 @@ impl ActiveThread { return; } - let title = self - .thread - .read(cx) - .summary() - .unwrap_or("Agent Panel".into()); + let title = self.thread.read(cx).summary().unwrap_or("Agent Panel"); match AssistantSettings::get_global(cx).notify_when_agent_waiting { NotifyWhenAgentWaiting::PrimaryScreen => { @@ -3442,10 +3438,7 @@ pub(crate) fn open_active_thread_as_markdown( workspace.update_in(cx, |workspace, window, cx| { let thread = thread.read(cx); let markdown = thread.to_markdown(cx)?; - let thread_summary = thread - .summary() - .map(|summary| summary.to_string()) - .unwrap_or_else(|| "Thread".to_string()); + let thread_summary = thread.summary().or_default().to_string(); let project = workspace.project().clone(); diff --git a/crates/agent/src/agent_diff.rs b/crates/agent/src/agent_diff.rs index 54e4f1f9aa90bac05370596415e8bebc1df267a5..d83b2cf80df1a57bc9ad3444eadf196d10112a7f 100644 --- a/crates/agent/src/agent_diff.rs +++ b/crates/agent/src/agent_diff.rs @@ -215,11 +215,7 @@ impl AgentDiffPane { } fn update_title(&mut self, cx: &mut Context) { - let new_title = self - .thread - .read(cx) - .summary() - .unwrap_or("Agent Changes".into()); + let new_title = self.thread.read(cx).summary().unwrap_or("Agent Changes"); if new_title != self.title { self.title = new_title; cx.emit(EditorEvent::TitleChanged); @@ -469,11 +465,7 @@ impl Item for AgentDiffPane { } fn tab_content(&self, params: TabContentParams, _window: &Window, cx: &App) -> AnyElement { - let summary = self - .thread - .read(cx) - .summary() - .unwrap_or("Agent Changes".into()); + let summary = self.thread.read(cx).summary().unwrap_or("Agent Changes"); Label::new(format!("Review: {}", summary)) .color(if params.selected { Color::Default diff --git a/crates/agent/src/agent_panel.rs b/crates/agent/src/agent_panel.rs index c48863ae41b98f3747cfe192673baa5ffb9a2f8f..eb1635bf4b4b741f62c2dd46eff3274de307aa1e 100644 --- a/crates/agent/src/agent_panel.rs +++ b/crates/agent/src/agent_panel.rs @@ -10,8 +10,8 @@ use serde::{Deserialize, Serialize}; use anyhow::{Result, anyhow}; use assistant_context_editor::{ AgentPanelDelegate, AssistantContext, ConfigurationError, ContextEditor, ContextEvent, - SlashCommandCompletionProvider, humanize_token_count, make_lsp_adapter_delegate, - render_remaining_tokens, + ContextSummary, SlashCommandCompletionProvider, humanize_token_count, + make_lsp_adapter_delegate, render_remaining_tokens, }; use assistant_settings::{AssistantDockPosition, AssistantSettings}; use assistant_slash_command::SlashCommandWorkingSet; @@ -59,7 +59,7 @@ use crate::agent_configuration::{AgentConfiguration, AssistantConfigurationEvent use crate::agent_diff::AgentDiff; use crate::history_store::{HistoryStore, RecentEntry}; use crate::message_editor::{MessageEditor, MessageEditorEvent}; -use crate::thread::{Thread, ThreadError, ThreadId, TokenUsageRatio}; +use crate::thread::{Thread, ThreadError, ThreadId, ThreadSummary, TokenUsageRatio}; use crate::thread_history::{HistoryEntryElement, ThreadHistory}; use crate::thread_store::ThreadStore; use crate::ui::AgentOnboardingModal; @@ -196,7 +196,7 @@ impl ActiveView { } pub fn thread(thread: Entity, window: &mut Window, cx: &mut App) -> Self { - let summary = thread.read(cx).summary_or_default(); + let summary = thread.read(cx).summary().or_default(); let editor = cx.new(|cx| { let mut editor = Editor::single_line(window, cx); @@ -218,7 +218,7 @@ impl ActiveView { } EditorEvent::Blurred => { if editor.read(cx).text(cx).is_empty() { - let summary = thread.read(cx).summary_or_default(); + let summary = thread.read(cx).summary().or_default(); editor.update(cx, |editor, cx| { editor.set_text(summary, window, cx); @@ -233,7 +233,7 @@ impl ActiveView { let editor = editor.clone(); move |thread, event, window, cx| match event { ThreadEvent::SummaryGenerated => { - let summary = thread.read(cx).summary_or_default(); + let summary = thread.read(cx).summary().or_default(); editor.update(cx, |editor, cx| { editor.set_text(summary, window, cx); @@ -296,7 +296,8 @@ impl ActiveView { .read(cx) .context() .read(cx) - .summary_or_default(); + .summary() + .or_default(); editor.update(cx, |editor, cx| { editor.set_text(summary, window, cx); @@ -311,7 +312,7 @@ impl ActiveView { let editor = editor.clone(); move |assistant_context, event, window, cx| match event { ContextEvent::SummaryGenerated => { - let summary = assistant_context.read(cx).summary_or_default(); + let summary = assistant_context.read(cx).summary().or_default(); editor.update(cx, |editor, cx| { editor.set_text(summary, window, cx); @@ -1452,23 +1453,45 @@ impl AgentPanel { .. } => { let active_thread = self.thread.read(cx); - let is_empty = active_thread.is_empty(); - - let summary = active_thread.summary(cx); + let state = if active_thread.is_empty() { + &ThreadSummary::Pending + } else { + active_thread.summary(cx) + }; - if is_empty { - Label::new(Thread::DEFAULT_SUMMARY.clone()) + match state { + ThreadSummary::Pending => Label::new(ThreadSummary::DEFAULT.clone()) .truncate() - .into_any_element() - } else if summary.is_none() { - Label::new(LOADING_SUMMARY_PLACEHOLDER) + .into_any_element(), + ThreadSummary::Generating => Label::new(LOADING_SUMMARY_PLACEHOLDER) .truncate() - .into_any_element() - } else { - div() + .into_any_element(), + ThreadSummary::Ready(_) => div() .w_full() .child(change_title_editor.clone()) - .into_any_element() + .into_any_element(), + ThreadSummary::Error => h_flex() + .w_full() + .child(change_title_editor.clone()) + .child( + ui::IconButton::new("retry-summary-generation", IconName::RotateCcw) + .on_click({ + let active_thread = self.thread.clone(); + move |_, _window, cx| { + active_thread.update(cx, |thread, cx| { + thread.regenerate_summary(cx); + }); + } + }) + .tooltip(move |_window, cx| { + cx.new(|_| { + Tooltip::new("Failed to generate title") + .meta("Click to try again") + }) + .into() + }), + ) + .into_any_element(), } } ActiveView::PromptEditor { @@ -1476,14 +1499,13 @@ impl AgentPanel { context_editor, .. } => { - let context_editor = context_editor.read(cx); - let summary = context_editor.context().read(cx).summary(); + let summary = context_editor.read(cx).context().read(cx).summary(); match summary { - None => Label::new(AssistantContext::DEFAULT_SUMMARY.clone()) + ContextSummary::Pending => Label::new(ContextSummary::DEFAULT) .truncate() .into_any_element(), - Some(summary) => { + ContextSummary::Content(summary) => { if summary.done { div() .w_full() @@ -1495,6 +1517,28 @@ impl AgentPanel { .into_any_element() } } + ContextSummary::Error => h_flex() + .w_full() + .child(title_editor.clone()) + .child( + ui::IconButton::new("retry-summary-generation", IconName::RotateCcw) + .on_click({ + let context_editor = context_editor.clone(); + move |_, _window, cx| { + context_editor.update(cx, |context_editor, cx| { + context_editor.regenerate_summary(cx); + }); + } + }) + .tooltip(move |_window, cx| { + cx.new(|_| { + Tooltip::new("Failed to generate title") + .meta("Click to try again") + }) + .into() + }), + ) + .into_any_element(), } } ActiveView::History => Label::new("History").truncate().into_any_element(), diff --git a/crates/agent/src/context.rs b/crates/agent/src/context.rs index f458896b1817a949d1f2a8f03ef97699cad3f27d..98437778aa0dbbcd161e576af7cb1be1c5860a04 100644 --- a/crates/agent/src/context.rs +++ b/crates/agent/src/context.rs @@ -586,10 +586,7 @@ impl ThreadContextHandle { } pub fn title(&self, cx: &App) -> SharedString { - self.thread - .read(cx) - .summary() - .unwrap_or_else(|| "New thread".into()) + self.thread.read(cx).summary().or_default() } fn load(self, cx: &App) -> Task>)>> { @@ -597,9 +594,7 @@ impl ThreadContextHandle { let text = Thread::wait_for_detailed_summary_or_text(&self.thread, cx).await?; let title = self .thread - .read_with(cx, |thread, _cx| { - thread.summary().unwrap_or_else(|| "New thread".into()) - }) + .read_with(cx, |thread, _cx| thread.summary().or_default()) .ok()?; let context = AgentContext::Thread(ThreadContext { title, @@ -642,7 +637,7 @@ impl TextThreadContextHandle { } pub fn title(&self, cx: &App) -> SharedString { - self.context.read(cx).summary_or_default() + self.context.read(cx).summary().or_default() } fn load(self, cx: &App) -> Task>)>> { diff --git a/crates/agent/src/context_strip.rs b/crates/agent/src/context_strip.rs index f9d9ff278181b640eb029e8f7359ef9469487996..8fe1a21d7480d2703fa0545712baad7c1e2463e6 100644 --- a/crates/agent/src/context_strip.rs +++ b/crates/agent/src/context_strip.rs @@ -160,7 +160,7 @@ impl ContextStrip { } Some(SuggestedContext::Thread { - name: active_thread.summary_or_default(), + name: active_thread.summary().or_default(), thread: weak_active_thread, }) } else if let Some(active_context_editor) = panel.active_context_editor() { @@ -174,7 +174,7 @@ impl ContextStrip { } Some(SuggestedContext::TextThread { - name: context.summary_or_default(), + name: context.summary().or_default(), context: weak_context, }) } else { diff --git a/crates/agent/src/history_store.rs b/crates/agent/src/history_store.rs index 85fdac2bab0066d8c7025fb60139d9ba62db68bd..c8d9e9a26396bb57aae30664f581a7bff1b78984 100644 --- a/crates/agent/src/history_store.rs +++ b/crates/agent/src/history_store.rs @@ -71,8 +71,8 @@ impl Eq for RecentEntry {} impl RecentEntry { pub(crate) fn summary(&self, cx: &App) -> SharedString { match self { - RecentEntry::Thread(_, thread) => thread.read(cx).summary_or_default(), - RecentEntry::Context(context) => context.read(cx).summary_or_default(), + RecentEntry::Thread(_, thread) => thread.read(cx).summary().or_default(), + RecentEntry::Context(context) => context.read(cx).summary().or_default(), } } } diff --git a/crates/agent/src/thread.rs b/crates/agent/src/thread.rs index c6f2d74ff9587211f9e5969505a166eefe8eeddb..a65eda5b40c123b75197ec3fd99a57753e7c9ed7 100644 --- a/crates/agent/src/thread.rs +++ b/crates/agent/src/thread.rs @@ -36,7 +36,7 @@ use serde::{Deserialize, Serialize}; use settings::Settings; use thiserror::Error; use ui::Window; -use util::{ResultExt as _, TryFutureExt as _, post_inc}; +use util::{ResultExt as _, post_inc}; use uuid::Uuid; use zed_llm_client::CompletionRequestStatus; @@ -324,7 +324,7 @@ pub enum QueueState { pub struct Thread { id: ThreadId, updated_at: DateTime, - summary: Option, + summary: ThreadSummary, pending_summary: Task>, detailed_summary_task: Task>, detailed_summary_tx: postage::watch::Sender, @@ -361,6 +361,33 @@ pub struct Thread { configured_model: Option, } +#[derive(Clone, Debug, PartialEq, Eq)] +pub enum ThreadSummary { + Pending, + Generating, + Ready(SharedString), + Error, +} + +impl ThreadSummary { + pub const DEFAULT: SharedString = SharedString::new_static("New Thread"); + + pub fn or_default(&self) -> SharedString { + self.unwrap_or(Self::DEFAULT) + } + + pub fn unwrap_or(&self, message: impl Into) -> SharedString { + self.ready().unwrap_or_else(|| message.into()) + } + + pub fn ready(&self) -> Option { + match self { + ThreadSummary::Ready(summary) => Some(summary.clone()), + ThreadSummary::Pending | ThreadSummary::Generating | ThreadSummary::Error => None, + } + } +} + #[derive(Debug, Clone, Serialize, Deserialize)] pub struct ExceededWindowError { /// Model used when last message exceeded context window @@ -383,7 +410,7 @@ impl Thread { Self { id: ThreadId::new(), updated_at: Utc::now(), - summary: None, + summary: ThreadSummary::Pending, pending_summary: Task::ready(None), detailed_summary_task: Task::ready(None), detailed_summary_tx, @@ -471,7 +498,7 @@ impl Thread { Self { id, updated_at: serialized.updated_at, - summary: Some(serialized.summary), + summary: ThreadSummary::Ready(serialized.summary), pending_summary: Task::ready(None), detailed_summary_task: Task::ready(None), detailed_summary_tx, @@ -572,10 +599,6 @@ impl Thread { self.last_prompt_id = PromptId::new(); } - pub fn summary(&self) -> Option { - self.summary.clone() - } - pub fn project_context(&self) -> SharedProjectContext { self.project_context.clone() } @@ -596,26 +619,25 @@ impl Thread { cx.notify(); } - pub const DEFAULT_SUMMARY: SharedString = SharedString::new_static("New Thread"); - - pub fn summary_or_default(&self) -> SharedString { - self.summary.clone().unwrap_or(Self::DEFAULT_SUMMARY) + pub fn summary(&self) -> &ThreadSummary { + &self.summary } pub fn set_summary(&mut self, new_summary: impl Into, cx: &mut Context) { - let Some(current_summary) = &self.summary else { - // Don't allow setting summary until generated - return; + let current_summary = match &self.summary { + ThreadSummary::Pending | ThreadSummary::Generating => return, + ThreadSummary::Ready(summary) => summary, + ThreadSummary::Error => &ThreadSummary::DEFAULT, }; let mut new_summary = new_summary.into(); if new_summary.is_empty() { - new_summary = Self::DEFAULT_SUMMARY; + new_summary = ThreadSummary::DEFAULT; } if current_summary != &new_summary { - self.summary = Some(new_summary); + self.summary = ThreadSummary::Ready(new_summary); cx.emit(ThreadEvent::SummaryChanged); } } @@ -1029,7 +1051,7 @@ impl Thread { let initial_project_snapshot = initial_project_snapshot.await; this.read_with(cx, |this, cx| SerializedThread { version: SerializedThread::VERSION.to_string(), - summary: this.summary_or_default(), + summary: this.summary().or_default(), updated_at: this.updated_at(), messages: this .messages() @@ -1625,7 +1647,7 @@ impl Thread { // If there is a response without tool use, summarize the message. Otherwise, // allow two tool uses before summarizing. - if thread.summary.is_none() + if matches!(thread.summary, ThreadSummary::Pending) && thread.messages.len() >= 2 && (!thread.has_pending_tool_uses() || thread.messages.len() >= 6) { @@ -1739,6 +1761,7 @@ impl Thread { pub fn summarize(&mut self, cx: &mut Context) { let Some(model) = LanguageModelRegistry::read_global(cx).thread_summary_model() else { + println!("No thread summary model"); return; }; @@ -1753,13 +1776,17 @@ impl Thread { let request = self.to_summarize_request(&model.model, added_user_message.into(), cx); + self.summary = ThreadSummary::Generating; + self.pending_summary = cx.spawn(async move |this, cx| { - async move { + let result = async { let mut messages = model.model.stream_completion(request, &cx).await?; let mut new_summary = String::new(); while let Some(event) = messages.next().await { - let event = event?; + let Ok(event) = event else { + continue; + }; let text = match event { LanguageModelCompletionEvent::Text(text) => text, LanguageModelCompletionEvent::StatusUpdate( @@ -1785,18 +1812,29 @@ impl Thread { } } - this.update(cx, |this, cx| { - if !new_summary.is_empty() { - this.summary = Some(new_summary.into()); - } + anyhow::Ok(new_summary) + } + .await; - cx.emit(ThreadEvent::SummaryGenerated); - })?; + this.update(cx, |this, cx| { + match result { + Ok(new_summary) => { + if new_summary.is_empty() { + this.summary = ThreadSummary::Error; + } else { + this.summary = ThreadSummary::Ready(new_summary.into()); + } + } + Err(err) => { + this.summary = ThreadSummary::Error; + log::error!("Failed to generate thread summary: {}", err); + } + } + cx.emit(ThreadEvent::SummaryGenerated); + }) + .log_err()?; - anyhow::Ok(()) - } - .log_err() - .await + Some(()) }); } @@ -2406,9 +2444,8 @@ impl Thread { pub fn to_markdown(&self, cx: &App) -> Result { let mut markdown = Vec::new(); - if let Some(summary) = self.summary() { - writeln!(markdown, "# {summary}\n")?; - }; + let summary = self.summary().or_default(); + writeln!(markdown, "# {summary}\n")?; for message in self.messages() { writeln!( @@ -2725,7 +2762,7 @@ mod tests { use assistant_tool::ToolRegistry; use editor::EditorSettings; use gpui::TestAppContext; - use language_model::fake_provider::FakeLanguageModel; + use language_model::fake_provider::{FakeLanguageModel, FakeLanguageModelProvider}; use project::{FakeFs, Project}; use prompt_store::PromptBuilder; use serde_json::json; @@ -3226,6 +3263,196 @@ fn main() {{ assert_eq!(request.temperature, None); } + #[gpui::test] + async fn test_thread_summary(cx: &mut TestAppContext) { + init_test_settings(cx); + + let project = create_test_project(cx, json!({})).await; + + let (_, _thread_store, thread, _context_store, model) = + setup_test_environment(cx, project.clone()).await; + + // Initial state should be pending + thread.read_with(cx, |thread, _| { + assert!(matches!(thread.summary(), ThreadSummary::Pending)); + assert_eq!(thread.summary().or_default(), ThreadSummary::DEFAULT); + }); + + // Manually setting the summary should not be allowed in this state + thread.update(cx, |thread, cx| { + thread.set_summary("This should not work", cx); + }); + + thread.read_with(cx, |thread, _| { + assert!(matches!(thread.summary(), ThreadSummary::Pending)); + }); + + // Send a message + thread.update(cx, |thread, cx| { + thread.insert_user_message("Hi!", ContextLoadResult::default(), None, vec![], cx); + thread.send_to_model(model.clone(), None, cx); + }); + + let fake_model = model.as_fake(); + simulate_successful_response(&fake_model, cx); + + // Should start generating summary when there are >= 2 messages + thread.read_with(cx, |thread, _| { + assert_eq!(*thread.summary(), ThreadSummary::Generating); + }); + + // Should not be able to set the summary while generating + thread.update(cx, |thread, cx| { + thread.set_summary("This should not work either", cx); + }); + + thread.read_with(cx, |thread, _| { + assert!(matches!(thread.summary(), ThreadSummary::Generating)); + assert_eq!(thread.summary().or_default(), ThreadSummary::DEFAULT); + }); + + cx.run_until_parked(); + fake_model.stream_last_completion_response("Brief".into()); + fake_model.stream_last_completion_response(" Introduction".into()); + fake_model.end_last_completion_stream(); + cx.run_until_parked(); + + // Summary should be set + thread.read_with(cx, |thread, _| { + assert!(matches!(thread.summary(), ThreadSummary::Ready(_))); + assert_eq!(thread.summary().or_default(), "Brief Introduction"); + }); + + // Now we should be able to set a summary + thread.update(cx, |thread, cx| { + thread.set_summary("Brief Intro", cx); + }); + + thread.read_with(cx, |thread, _| { + assert_eq!(thread.summary().or_default(), "Brief Intro"); + }); + + // Test setting an empty summary (should default to DEFAULT) + thread.update(cx, |thread, cx| { + thread.set_summary("", cx); + }); + + thread.read_with(cx, |thread, _| { + assert!(matches!(thread.summary(), ThreadSummary::Ready(_))); + assert_eq!(thread.summary().or_default(), ThreadSummary::DEFAULT); + }); + } + + #[gpui::test] + async fn test_thread_summary_error_set_manually(cx: &mut TestAppContext) { + init_test_settings(cx); + + let project = create_test_project(cx, json!({})).await; + + let (_, _thread_store, thread, _context_store, model) = + setup_test_environment(cx, project.clone()).await; + + test_summarize_error(&model, &thread, cx); + + // Now we should be able to set a summary + thread.update(cx, |thread, cx| { + thread.set_summary("Brief Intro", cx); + }); + + thread.read_with(cx, |thread, _| { + assert!(matches!(thread.summary(), ThreadSummary::Ready(_))); + assert_eq!(thread.summary().or_default(), "Brief Intro"); + }); + } + + #[gpui::test] + async fn test_thread_summary_error_retry(cx: &mut TestAppContext) { + init_test_settings(cx); + + let project = create_test_project(cx, json!({})).await; + + let (_, _thread_store, thread, _context_store, model) = + setup_test_environment(cx, project.clone()).await; + + test_summarize_error(&model, &thread, cx); + + // Sending another message should not trigger another summarize request + thread.update(cx, |thread, cx| { + thread.insert_user_message( + "How are you?", + ContextLoadResult::default(), + None, + vec![], + cx, + ); + thread.send_to_model(model.clone(), None, cx); + }); + + let fake_model = model.as_fake(); + simulate_successful_response(&fake_model, cx); + + thread.read_with(cx, |thread, _| { + // State is still Error, not Generating + assert!(matches!(thread.summary(), ThreadSummary::Error)); + }); + + // But the summarize request can be invoked manually + thread.update(cx, |thread, cx| { + thread.summarize(cx); + }); + + thread.read_with(cx, |thread, _| { + assert!(matches!(thread.summary(), ThreadSummary::Generating)); + }); + + cx.run_until_parked(); + fake_model.stream_last_completion_response("A successful summary".into()); + fake_model.end_last_completion_stream(); + cx.run_until_parked(); + + thread.read_with(cx, |thread, _| { + assert!(matches!(thread.summary(), ThreadSummary::Ready(_))); + assert_eq!(thread.summary().or_default(), "A successful summary"); + }); + } + + fn test_summarize_error( + model: &Arc, + thread: &Entity, + cx: &mut TestAppContext, + ) { + thread.update(cx, |thread, cx| { + thread.insert_user_message("Hi!", ContextLoadResult::default(), None, vec![], cx); + thread.send_to_model(model.clone(), None, cx); + }); + + let fake_model = model.as_fake(); + simulate_successful_response(&fake_model, cx); + + thread.read_with(cx, |thread, _| { + assert!(matches!(thread.summary(), ThreadSummary::Generating)); + assert_eq!(thread.summary().or_default(), ThreadSummary::DEFAULT); + }); + + // Simulate summary request ending + cx.run_until_parked(); + fake_model.end_last_completion_stream(); + cx.run_until_parked(); + + // State is set to Error and default message + thread.read_with(cx, |thread, _| { + assert!(matches!(thread.summary(), ThreadSummary::Error)); + assert_eq!(thread.summary().or_default(), ThreadSummary::DEFAULT); + }); + } + + fn simulate_successful_response(fake_model: &FakeLanguageModel, cx: &mut TestAppContext) { + cx.run_until_parked(); + fake_model.stream_last_completion_response("Assistant response".into()); + fake_model.end_last_completion_stream(); + cx.run_until_parked(); + } + fn init_test_settings(cx: &mut TestAppContext) { cx.update(|cx| { let settings_store = SettingsStore::test(cx); @@ -3282,9 +3509,29 @@ fn main() {{ let thread = thread_store.update(cx, |store, cx| store.create_thread(cx)); let context_store = cx.new(|_cx| ContextStore::new(project.downgrade(), None)); - let model = FakeLanguageModel::default(); + let provider = Arc::new(FakeLanguageModelProvider); + let model = provider.test_model(); let model: Arc = Arc::new(model); + cx.update(|_, cx| { + LanguageModelRegistry::global(cx).update(cx, |registry, cx| { + registry.set_default_model( + Some(ConfiguredModel { + provider: provider.clone(), + model: model.clone(), + }), + cx, + ); + registry.set_thread_summary_model( + Some(ConfiguredModel { + provider, + model: model.clone(), + }), + cx, + ); + }) + }); + (workspace, thread_store, thread, context_store, model) } diff --git a/crates/assistant_context_editor/src/context.rs b/crates/assistant_context_editor/src/context.rs index 28f03de1a61d044d613136b28bb75ac3c9b944cd..047ca89db03cd510f63c92bfeac02ef86931766e 100644 --- a/crates/assistant_context_editor/src/context.rs +++ b/crates/assistant_context_editor/src/context.rs @@ -1,7 +1,7 @@ #[cfg(test)] mod context_tests; -use anyhow::{Context as _, Result, anyhow}; +use anyhow::{Context as _, Result, anyhow, bail}; use assistant_settings::AssistantSettings; use assistant_slash_command::{ SlashCommandContent, SlashCommandEvent, SlashCommandLine, SlashCommandOutputSection, @@ -133,7 +133,7 @@ pub enum ContextOperation { version: clock::Global, }, UpdateSummary { - summary: ContextSummary, + summary: ContextSummaryContent, version: clock::Global, }, SlashCommandStarted { @@ -203,7 +203,7 @@ impl ContextOperation { version: language::proto::deserialize_version(&update.version), }), proto::context_operation::Variant::UpdateSummary(update) => Ok(Self::UpdateSummary { - summary: ContextSummary { + summary: ContextSummaryContent { text: update.summary, done: update.done, timestamp: language::proto::deserialize_timestamp( @@ -467,11 +467,73 @@ pub enum ContextEvent { Operation(ContextOperation), } -#[derive(Clone, Default, Debug)] -pub struct ContextSummary { +#[derive(Clone, Debug, Eq, PartialEq)] +pub enum ContextSummary { + Pending, + Content(ContextSummaryContent), + Error, +} + +#[derive(Default, Clone, Debug, Eq, PartialEq)] +pub struct ContextSummaryContent { pub text: String, pub done: bool, - timestamp: clock::Lamport, + pub timestamp: clock::Lamport, +} + +impl ContextSummary { + pub const DEFAULT: &str = "New Text Thread"; + + pub fn or_default(&self) -> SharedString { + self.unwrap_or(Self::DEFAULT) + } + + pub fn unwrap_or(&self, message: impl Into) -> SharedString { + self.content() + .map_or_else(|| message.into(), |content| content.text.clone().into()) + } + + pub fn content(&self) -> Option<&ContextSummaryContent> { + match self { + ContextSummary::Content(content) => Some(content), + ContextSummary::Pending | ContextSummary::Error => None, + } + } + + fn content_as_mut(&mut self) -> Option<&mut ContextSummaryContent> { + match self { + ContextSummary::Content(content) => Some(content), + ContextSummary::Pending | ContextSummary::Error => None, + } + } + + fn content_or_set_empty(&mut self) -> &mut ContextSummaryContent { + match self { + ContextSummary::Content(content) => content, + ContextSummary::Pending | ContextSummary::Error => { + let content = ContextSummaryContent::default(); + *self = ContextSummary::Content(content); + self.content_as_mut().unwrap() + } + } + } + + pub fn is_pending(&self) -> bool { + matches!(self, ContextSummary::Pending) + } + + fn timestamp(&self) -> Option { + match self { + ContextSummary::Content(content) => Some(content.timestamp), + ContextSummary::Pending | ContextSummary::Error => None, + } + } +} + +impl PartialOrd for ContextSummary { + fn partial_cmp(&self, other: &Self) -> Option { + self.timestamp().partial_cmp(&other.timestamp()) + } } #[derive(Clone, Debug, Eq, PartialEq)] @@ -607,7 +669,7 @@ pub struct AssistantContext { message_anchors: Vec, contents: Vec, messages_metadata: HashMap, - summary: Option, + summary: ContextSummary, summary_task: Task>, completion_count: usize, pending_completions: Vec, @@ -694,7 +756,7 @@ impl AssistantContext { slash_command_output_sections: Vec::new(), thought_process_output_sections: Vec::new(), edits_since_last_parse: edits_since_last_slash_command_parse, - summary: None, + summary: ContextSummary::Pending, summary_task: Task::ready(None), completion_count: Default::default(), pending_completions: Default::default(), @@ -753,7 +815,7 @@ impl AssistantContext { .collect(), summary: self .summary - .as_ref() + .content() .map(|summary| summary.text.clone()) .unwrap_or_default(), slash_command_output_sections: self @@ -939,12 +1001,10 @@ impl AssistantContext { summary: new_summary, .. } => { - if self - .summary - .as_ref() - .map_or(true, |summary| new_summary.timestamp > summary.timestamp) - { - self.summary = Some(new_summary); + if self.summary.timestamp().map_or(true, |current_timestamp| { + new_summary.timestamp > current_timestamp + }) { + self.summary = ContextSummary::Content(new_summary); summary_generated = true; } } @@ -1102,8 +1162,8 @@ impl AssistantContext { self.path.as_ref() } - pub fn summary(&self) -> Option<&ContextSummary> { - self.summary.as_ref() + pub fn summary(&self) -> &ContextSummary { + &self.summary } pub fn parsed_slash_commands(&self) -> &[ParsedSlashCommand] { @@ -2576,7 +2636,7 @@ impl AssistantContext { return; }; - if replace_old || (self.message_anchors.len() >= 2 && self.summary.is_none()) { + if replace_old || (self.message_anchors.len() >= 2 && self.summary.is_pending()) { if !model.provider.is_authenticated(cx) { return; } @@ -2593,17 +2653,20 @@ impl AssistantContext { // If there is no summary, it is set with `done: false` so that "Loading Summary…" can // be displayed. - if self.summary.is_none() { - self.summary = Some(ContextSummary { - text: "".to_string(), - done: false, - timestamp: clock::Lamport::default(), - }); - replace_old = true; + match self.summary { + ContextSummary::Pending | ContextSummary::Error => { + self.summary = ContextSummary::Content(ContextSummaryContent { + text: "".to_string(), + done: false, + timestamp: clock::Lamport::default(), + }); + replace_old = true; + } + ContextSummary::Content(_) => {} } self.summary_task = cx.spawn(async move |this, cx| { - async move { + let result = async { let stream = model.model.stream_completion_text(request, &cx); let mut messages = stream.await?; @@ -2614,7 +2677,7 @@ impl AssistantContext { this.update(cx, |this, cx| { let version = this.version.clone(); let timestamp = this.next_timestamp(); - let summary = this.summary.get_or_insert(ContextSummary::default()); + let summary = this.summary.content_or_set_empty(); if !replaced && replace_old { summary.text.clear(); replaced = true; @@ -2636,10 +2699,19 @@ impl AssistantContext { } } + this.read_with(cx, |this, _cx| { + if let Some(summary) = this.summary.content() { + if summary.text.is_empty() { + bail!("Model generated an empty summary"); + } + } + Ok(()) + })??; + this.update(cx, |this, cx| { let version = this.version.clone(); let timestamp = this.next_timestamp(); - if let Some(summary) = this.summary.as_mut() { + if let Some(summary) = this.summary.content_as_mut() { summary.done = true; summary.timestamp = timestamp; let operation = ContextOperation::UpdateSummary { @@ -2654,8 +2726,18 @@ impl AssistantContext { anyhow::Ok(()) } - .log_err() - .await + .await; + + if let Err(err) = result { + this.update(cx, |this, cx| { + this.summary = ContextSummary::Error; + cx.emit(ContextEvent::SummaryChanged); + }) + .log_err(); + log::error!("Error generating context summary: {}", err); + } + + Some(()) }); } } @@ -2769,7 +2851,7 @@ impl AssistantContext { let (old_path, summary) = this.read_with(cx, |this, _| { let path = this.path.clone(); - let summary = if let Some(summary) = this.summary.as_ref() { + let summary = if let Some(summary) = this.summary.content() { if summary.done { Some(summary.text.clone()) } else { @@ -2823,21 +2905,12 @@ impl AssistantContext { pub fn set_custom_summary(&mut self, custom_summary: String, cx: &mut Context) { let timestamp = self.next_timestamp(); - let summary = self.summary.get_or_insert(ContextSummary::default()); + let summary = self.summary.content_or_set_empty(); summary.timestamp = timestamp; summary.done = true; summary.text = custom_summary; cx.emit(ContextEvent::SummaryChanged); } - - pub const DEFAULT_SUMMARY: SharedString = SharedString::new_static("New Text Thread"); - - pub fn summary_or_default(&self) -> SharedString { - self.summary - .as_ref() - .map(|summary| summary.text.clone().into()) - .unwrap_or(Self::DEFAULT_SUMMARY) - } } #[derive(Debug, Default)] @@ -3053,7 +3126,7 @@ impl SavedContext { let timestamp = next_timestamp.tick(); operations.push(ContextOperation::UpdateSummary { - summary: ContextSummary { + summary: ContextSummaryContent { text: self.summary, done: true, timestamp, diff --git a/crates/assistant_context_editor/src/context/context_tests.rs b/crates/assistant_context_editor/src/context/context_tests.rs index a8a5f3835d3a1597e717d401e030ae84b9a2b99a..3983a901587e563b3c9d490237a182d9ef3c01b3 100644 --- a/crates/assistant_context_editor/src/context/context_tests.rs +++ b/crates/assistant_context_editor/src/context/context_tests.rs @@ -1,5 +1,5 @@ use crate::{ - AssistantContext, CacheStatus, ContextEvent, ContextId, ContextOperation, + AssistantContext, CacheStatus, ContextEvent, ContextId, ContextOperation, ContextSummary, InvokedSlashCommandId, MessageCacheMetadata, MessageId, MessageStatus, }; use anyhow::Result; @@ -16,7 +16,10 @@ use futures::{ }; use gpui::{App, Entity, SharedString, Task, TestAppContext, WeakEntity, prelude::*}; use language::{Buffer, BufferSnapshot, LanguageRegistry, LspAdapterDelegate}; -use language_model::{LanguageModelCacheConfiguration, LanguageModelRegistry, Role}; +use language_model::{ + ConfiguredModel, LanguageModelCacheConfiguration, LanguageModelRegistry, Role, + fake_provider::{FakeLanguageModel, FakeLanguageModelProvider}, +}; use parking_lot::Mutex; use pretty_assertions::assert_eq; use project::Project; @@ -1177,6 +1180,187 @@ fn test_mark_cache_anchors(cx: &mut App) { ); } +#[gpui::test] +async fn test_summarization(cx: &mut TestAppContext) { + let (context, fake_model) = setup_context_editor_with_fake_model(cx); + + // Initial state should be pending + context.read_with(cx, |context, _| { + assert!(matches!(context.summary(), ContextSummary::Pending)); + assert_eq!(context.summary().or_default(), ContextSummary::DEFAULT); + }); + + let message_1 = context.read_with(cx, |context, _cx| context.message_anchors[0].clone()); + context.update(cx, |context, cx| { + context + .insert_message_after(message_1.id, Role::Assistant, MessageStatus::Done, cx) + .unwrap(); + }); + + // Send a message + context.update(cx, |context, cx| { + context.assist(cx); + }); + + simulate_successful_response(&fake_model, cx); + + // Should start generating summary when there are >= 2 messages + context.read_with(cx, |context, _| { + assert!(!context.summary().content().unwrap().done); + }); + + cx.run_until_parked(); + fake_model.stream_last_completion_response("Brief".into()); + fake_model.stream_last_completion_response(" Introduction".into()); + fake_model.end_last_completion_stream(); + cx.run_until_parked(); + + // Summary should be set + context.read_with(cx, |context, _| { + assert_eq!(context.summary().or_default(), "Brief Introduction"); + }); + + // We should be able to manually set a summary + context.update(cx, |context, cx| { + context.set_custom_summary("Brief Intro".into(), cx); + }); + + context.read_with(cx, |context, _| { + assert_eq!(context.summary().or_default(), "Brief Intro"); + }); +} + +#[gpui::test] +async fn test_thread_summary_error_set_manually(cx: &mut TestAppContext) { + let (context, fake_model) = setup_context_editor_with_fake_model(cx); + + test_summarize_error(&fake_model, &context, cx); + + // Now we should be able to set a summary + context.update(cx, |context, cx| { + context.set_custom_summary("Brief Intro".into(), cx); + }); + + context.read_with(cx, |context, _| { + assert_eq!(context.summary().or_default(), "Brief Intro"); + }); +} + +#[gpui::test] +async fn test_thread_summary_error_retry(cx: &mut TestAppContext) { + let (context, fake_model) = setup_context_editor_with_fake_model(cx); + + test_summarize_error(&fake_model, &context, cx); + + // Sending another message should not trigger another summarize request + context.update(cx, |context, cx| { + context.assist(cx); + }); + + simulate_successful_response(&fake_model, cx); + + context.read_with(cx, |context, _| { + // State is still Error, not Generating + assert!(matches!(context.summary(), ContextSummary::Error)); + }); + + // But the summarize request can be invoked manually + context.update(cx, |context, cx| { + context.summarize(true, cx); + }); + + context.read_with(cx, |context, _| { + assert!(!context.summary().content().unwrap().done); + }); + + cx.run_until_parked(); + fake_model.stream_last_completion_response("A successful summary".into()); + fake_model.end_last_completion_stream(); + cx.run_until_parked(); + + context.read_with(cx, |context, _| { + assert_eq!(context.summary().or_default(), "A successful summary"); + }); +} + +fn test_summarize_error( + model: &Arc, + context: &Entity, + cx: &mut TestAppContext, +) { + let message_1 = context.read_with(cx, |context, _cx| context.message_anchors[0].clone()); + context.update(cx, |context, cx| { + context + .insert_message_after(message_1.id, Role::Assistant, MessageStatus::Done, cx) + .unwrap(); + }); + + // Send a message + context.update(cx, |context, cx| { + context.assist(cx); + }); + + simulate_successful_response(&model, cx); + + context.read_with(cx, |context, _| { + assert!(!context.summary().content().unwrap().done); + }); + + // Simulate summary request ending + cx.run_until_parked(); + model.end_last_completion_stream(); + cx.run_until_parked(); + + // State is set to Error and default message + context.read_with(cx, |context, _| { + assert_eq!(*context.summary(), ContextSummary::Error); + assert_eq!(context.summary().or_default(), ContextSummary::DEFAULT); + }); +} + +fn setup_context_editor_with_fake_model( + cx: &mut TestAppContext, +) -> (Entity, Arc) { + let registry = Arc::new(LanguageRegistry::test(cx.executor().clone())); + + let fake_provider = Arc::new(FakeLanguageModelProvider); + let fake_model = Arc::new(fake_provider.test_model()); + + cx.update(|cx| { + init_test(cx); + LanguageModelRegistry::global(cx).update(cx, |registry, cx| { + registry.set_default_model( + Some(ConfiguredModel { + provider: fake_provider.clone(), + model: fake_model.clone(), + }), + cx, + ) + }) + }); + + let prompt_builder = Arc::new(PromptBuilder::new(None).unwrap()); + let context = cx.new(|cx| { + AssistantContext::local( + registry, + None, + None, + prompt_builder.clone(), + Arc::new(SlashCommandWorkingSet::default()), + cx, + ) + }); + + (context, fake_model) +} + +fn simulate_successful_response(fake_model: &FakeLanguageModel, cx: &mut TestAppContext) { + cx.run_until_parked(); + fake_model.stream_last_completion_response("Assistant response".into()); + fake_model.end_last_completion_stream(); + cx.run_until_parked(); +} + fn messages(context: &Entity, cx: &App) -> Vec<(MessageId, Role, Range)> { context .read(cx) diff --git a/crates/assistant_context_editor/src/context_editor.rs b/crates/assistant_context_editor/src/context_editor.rs index 37cb766986f41dd2cc315fd2f3142511211f9b30..21ec018dc8c4fe9b5d027e362e095f2d783ab1cd 100644 --- a/crates/assistant_context_editor/src/context_editor.rs +++ b/crates/assistant_context_editor/src/context_editor.rs @@ -1860,7 +1860,12 @@ impl ContextEditor { } pub fn title(&self, cx: &App) -> SharedString { - self.context.read(cx).summary_or_default() + self.context.read(cx).summary().or_default() + } + + pub fn regenerate_summary(&mut self, cx: &mut Context) { + self.context + .update(cx, |context, cx| context.summarize(true, cx)); } fn render_notice(&self, cx: &mut Context) -> Option { diff --git a/crates/assistant_context_editor/src/context_store.rs b/crates/assistant_context_editor/src/context_store.rs index fe89d7610948e670a87963a8e5c8f24f8529311f..f1f3b501a65f59288f97663e2096df0dae84c6a0 100644 --- a/crates/assistant_context_editor/src/context_store.rs +++ b/crates/assistant_context_editor/src/context_store.rs @@ -648,7 +648,10 @@ impl ContextStore { if context.replica_id() == ReplicaId::default() { Some(proto::ContextMetadata { context_id: context.id().to_proto(), - summary: context.summary().map(|summary| summary.text.clone()), + summary: context + .summary() + .content() + .map(|summary| summary.text.clone()), }) } else { None