@@ -577,6 +577,10 @@ impl NativeAgent {
}
fn save_thread(&mut self, thread: Entity<Thread>, cx: &mut Context<Self>) {
+ if thread.read(cx).is_empty() {
+ return;
+ }
+
let database_future = ThreadsDatabase::connect(cx);
let (id, db_thread) =
thread.update(cx, |thread, cx| (thread.id().clone(), thread.to_db(cx)));
@@ -989,12 +993,19 @@ impl acp_thread::AgentSessionResume for NativeAgentSessionResume {
#[cfg(test)]
mod tests {
+ use crate::HistoryEntryId;
+
use super::*;
- use acp_thread::{AgentConnection, AgentModelGroupName, AgentModelId, AgentModelInfo};
+ use acp_thread::{
+ AgentConnection, AgentModelGroupName, AgentModelId, AgentModelInfo, MentionUri,
+ };
use fs::FakeFs;
use gpui::TestAppContext;
+ use indoc::indoc;
+ use language_model::fake_provider::FakeLanguageModel;
use serde_json::json;
use settings::SettingsStore;
+ use util::path;
#[gpui::test]
async fn test_maintaining_project_context(cx: &mut TestAppContext) {
@@ -1179,6 +1190,163 @@ mod tests {
);
}
+ #[gpui::test]
+ #[cfg_attr(target_os = "windows", ignore)] // TODO: Fix this test on Windows
+ async fn test_save_load_thread(cx: &mut TestAppContext) {
+ init_test(cx);
+ let fs = FakeFs::new(cx.executor());
+ fs.insert_tree(
+ "/",
+ json!({
+ "a": {
+ "b.md": "Lorem"
+ }
+ }),
+ )
+ .await;
+ let project = Project::test(fs.clone(), [path!("/a").as_ref()], cx).await;
+ let context_store = cx.new(|cx| assistant_context::ContextStore::fake(project.clone(), cx));
+ let history_store = cx.new(|cx| HistoryStore::new(context_store, cx));
+ let agent = NativeAgent::new(
+ project.clone(),
+ history_store.clone(),
+ Templates::new(),
+ None,
+ fs.clone(),
+ &mut cx.to_async(),
+ )
+ .await
+ .unwrap();
+ let connection = Rc::new(NativeAgentConnection(agent.clone()));
+
+ let acp_thread = cx
+ .update(|cx| {
+ connection
+ .clone()
+ .new_thread(project.clone(), Path::new(""), cx)
+ })
+ .await
+ .unwrap();
+ let session_id = acp_thread.read_with(cx, |thread, _| thread.session_id().clone());
+ let thread = agent.read_with(cx, |agent, _| {
+ agent.sessions.get(&session_id).unwrap().thread.clone()
+ });
+
+ // Ensure empty threads are not saved, even if they get mutated.
+ let model = Arc::new(FakeLanguageModel::default());
+ let summary_model = Arc::new(FakeLanguageModel::default());
+ thread.update(cx, |thread, cx| {
+ thread.set_model(model, cx);
+ thread.set_summarization_model(Some(summary_model), cx);
+ });
+ cx.run_until_parked();
+ assert_eq!(history_entries(&history_store, cx), vec![]);
+
+ let model = thread.read_with(cx, |thread, _| thread.model().unwrap().clone());
+ let model = model.as_fake();
+ let summary_model = thread.read_with(cx, |thread, _| {
+ thread.summarization_model().unwrap().clone()
+ });
+ let summary_model = summary_model.as_fake();
+ let send = acp_thread.update(cx, |thread, cx| {
+ thread.send(
+ vec![
+ "What does ".into(),
+ acp::ContentBlock::ResourceLink(acp::ResourceLink {
+ name: "b.md".into(),
+ uri: MentionUri::File {
+ abs_path: path!("/a/b.md").into(),
+ }
+ .to_uri()
+ .to_string(),
+ annotations: None,
+ description: None,
+ mime_type: None,
+ size: None,
+ title: None,
+ }),
+ " mean?".into(),
+ ],
+ cx,
+ )
+ });
+ let send = cx.foreground_executor().spawn(send);
+ cx.run_until_parked();
+
+ model.send_last_completion_stream_text_chunk("Lorem.");
+ model.end_last_completion_stream();
+ cx.run_until_parked();
+ summary_model.send_last_completion_stream_text_chunk("Explaining /a/b.md");
+ summary_model.end_last_completion_stream();
+
+ send.await.unwrap();
+ acp_thread.read_with(cx, |thread, cx| {
+ assert_eq!(
+ thread.to_markdown(cx),
+ indoc! {"
+ ## User
+
+ What does [@b.md](file:///a/b.md) mean?
+
+ ## Assistant
+
+ Lorem.
+
+ "}
+ )
+ });
+
+ // Drop the ACP thread, which should cause the session to be dropped as well.
+ cx.update(|_| {
+ drop(thread);
+ drop(acp_thread);
+ });
+ agent.read_with(cx, |agent, _| {
+ assert_eq!(agent.sessions.keys().cloned().collect::<Vec<_>>(), []);
+ });
+
+ // Ensure the thread can be reloaded from disk.
+ assert_eq!(
+ history_entries(&history_store, cx),
+ vec![(
+ HistoryEntryId::AcpThread(session_id.clone()),
+ "Explaining /a/b.md".into()
+ )]
+ );
+ let acp_thread = agent
+ .update(cx, |agent, cx| agent.open_thread(session_id.clone(), cx))
+ .await
+ .unwrap();
+ acp_thread.read_with(cx, |thread, cx| {
+ assert_eq!(
+ thread.to_markdown(cx),
+ indoc! {"
+ ## User
+
+ What does [@b.md](file:///a/b.md) mean?
+
+ ## Assistant
+
+ Lorem.
+
+ "}
+ )
+ });
+ }
+
+ fn history_entries(
+ history: &Entity<HistoryStore>,
+ cx: &mut TestAppContext,
+ ) -> Vec<(HistoryEntryId, String)> {
+ history.read_with(cx, |history, cx| {
+ history
+ .entries(cx)
+ .iter()
+ .map(|e| (e.id(), e.title().to_string()))
+ .collect::<Vec<_>>()
+ })
+ }
+
fn init_test(cx: &mut TestAppContext) {
env_logger::try_init().ok();
cx.update(|cx| {
@@ -720,7 +720,7 @@ impl Thread {
pub fn to_db(&self, cx: &App) -> Task<DbThread> {
let initial_project_snapshot = self.initial_project_snapshot.clone();
let mut thread = DbThread {
- title: self.title.clone().unwrap_or_default(),
+ title: self.title(),
messages: self.messages.clone(),
updated_at: self.updated_at,
detailed_summary: self.summary.clone(),
@@ -870,6 +870,10 @@ impl Thread {
&self.action_log
}
+ pub fn is_empty(&self) -> bool {
+ self.messages.is_empty() && self.title.is_none()
+ }
+
pub fn model(&self) -> Option<&Arc<dyn LanguageModel>> {
self.model.as_ref()
}
@@ -884,6 +888,10 @@ impl Thread {
cx.notify()
}
+ pub fn summarization_model(&self) -> Option<&Arc<dyn LanguageModel>> {
+ self.summarization_model.as_ref()
+ }
+
pub fn set_summarization_model(
&mut self,
model: Option<Arc<dyn LanguageModel>>,
@@ -1068,6 +1076,7 @@ impl Thread {
event_stream: event_stream.clone(),
_task: cx.spawn(async move |this, cx| {
log::info!("Starting agent turn execution");
+ let mut update_title = None;
let turn_result: Result<StopReason> = async {
let mut completion_intent = CompletionIntent::UserPrompt;
loop {
@@ -1122,10 +1131,15 @@ impl Thread {
this.pending_message()
.tool_results
.insert(tool_result.tool_use_id.clone(), tool_result);
- })
- .ok();
+ })?;
}
+ this.update(cx, |this, cx| {
+ if this.title.is_none() && update_title.is_none() {
+ update_title = Some(this.update_title(&event_stream, cx));
+ }
+ })?;
+
if tool_use_limit_reached {
log::info!("Tool use limit reached, completing turn");
this.update(cx, |this, _cx| this.tool_use_limit_reached = true)?;
@@ -1146,10 +1160,6 @@ impl Thread {
Ok(reason) => {
log::info!("Turn execution completed: {:?}", reason);
- let update_title = this
- .update(cx, |this, cx| this.update_title(&event_stream, cx))
- .ok()
- .flatten();
if let Some(update_title) = update_title {
update_title.await.context("update title failed").log_err();
}
@@ -1593,17 +1603,14 @@ impl Thread {
&mut self,
event_stream: &ThreadEventStream,
cx: &mut Context<Self>,
- ) -> Option<Task<Result<()>>> {
- if self.title.is_some() {
- log::debug!("Skipping title generation because we already have one.");
- return None;
- }
-
+ ) -> Task<Result<()>> {
log::info!(
"Generating title with model: {:?}",
self.summarization_model.as_ref().map(|model| model.name())
);
- let model = self.summarization_model.clone()?;
+ let Some(model) = self.summarization_model.clone() else {
+ return Task::ready(Ok(()));
+ };
let event_stream = event_stream.clone();
let mut request = LanguageModelRequest {
intent: Some(CompletionIntent::ThreadSummarization),
@@ -1620,7 +1627,7 @@ impl Thread {
content: vec![SUMMARIZE_THREAD_PROMPT.into()],
cache: false,
});
- Some(cx.spawn(async move |this, cx| {
+ cx.spawn(async move |this, cx| {
let mut title = String::new();
let mut messages = model.stream_completion(request, cx).await?;
while let Some(event) = messages.next().await {
@@ -1655,7 +1662,7 @@ impl Thread {
this.title = Some(title);
cx.notify();
})
- }))
+ })
}
fn last_user_message(&self) -> Option<&UserMessage> {
@@ -2457,18 +2464,15 @@ impl From<UserMessageContent> for acp::ContentBlock {
uri: None,
}),
UserMessageContent::Mention { uri, content } => {
- acp::ContentBlock::ResourceLink(acp::ResourceLink {
- uri: uri.to_uri().to_string(),
- name: uri.name(),
+ acp::ContentBlock::Resource(acp::EmbeddedResource {
+ resource: acp::EmbeddedResourceResource::TextResourceContents(
+ acp::TextResourceContents {
+ mime_type: None,
+ text: content,
+ uri: uri.to_uri().to_string(),
+ },
+ ),
annotations: None,
- description: if content.is_empty() {
- None
- } else {
- Some(content)
- },
- mime_type: None,
- size: None,
- title: None,
})
}
}