From c55cdd0cb9d16468d175f9f41f8bdf22aa6a4b9f Mon Sep 17 00:00:00 2001 From: Marshall Bowers Date: Thu, 23 Jan 2025 19:09:26 -0500 Subject: [PATCH] assistant2: Add thread persistence (#23582) This PR adds persistence for threads in Assistant2. Threads are now persisted to an LMDB database. Release Notes: - N/A --- Cargo.lock | 3 +- crates/assistant2/Cargo.toml | 3 +- crates/assistant2/src/active_thread.rs | 19 +- crates/assistant2/src/assistant_panel.rs | 66 ++-- crates/assistant2/src/context_picker.rs | 48 ++- .../context_picker/thread_context_picker.rs | 39 +- crates/assistant2/src/thread.rs | 35 ++ crates/assistant2/src/thread_history.rs | 47 +-- crates/assistant2/src/thread_store.rs | 358 ++++++++++-------- 9 files changed, 370 insertions(+), 248 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index b627b05afb218cbcb0d501669a51e4632a2ca5a2..c543cba54b6120ae40aeba80eaa837e706505e3c 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -456,6 +456,7 @@ dependencies = [ "futures 0.3.31", "fuzzy", "gpui", + "heed", "html_to_markdown", "http_client", "indoc", @@ -470,6 +471,7 @@ dependencies = [ "menu", "multi_buffer", "parking_lot", + "paths", "picker", "project", "prompt_library", @@ -490,7 +492,6 @@ dependencies = [ "time", "time_format", "ui", - "unindent", "util", "uuid", "workspace", diff --git a/crates/assistant2/Cargo.toml b/crates/assistant2/Cargo.toml index 5d213f5b435474956639fd37494620ab7d0d6f04..f53595cd0c067f6060817e300fe7eff0dba13d0d 100644 --- a/crates/assistant2/Cargo.toml +++ b/crates/assistant2/Cargo.toml @@ -39,6 +39,7 @@ fs.workspace = true futures.workspace = true fuzzy.workspace = true gpui.workspace = true +heed.workspace = true html_to_markdown.workspace = true http_client.workspace = true itertools.workspace = true @@ -52,6 +53,7 @@ markdown.workspace = true menu.workspace = true multi_buffer.workspace = true parking_lot.workspace = true +paths.workspace = true picker.workspace = true project.workspace = true prompt_library.workspace = true @@ -71,7 +73,6 @@ theme.workspace = true time.workspace = true time_format.workspace = true ui.workspace = true -unindent.workspace = true util.workspace = true uuid.workspace = true workspace.workspace = true diff --git a/crates/assistant2/src/active_thread.rs b/crates/assistant2/src/active_thread.rs index 948ac37bacb44c288a8c1e91ef1bf720f340cfdc..3f643fdc00ca813928794d76b8e0e8da588fc0b7 100644 --- a/crates/assistant2/src/active_thread.rs +++ b/crates/assistant2/src/active_thread.rs @@ -16,12 +16,14 @@ use ui::prelude::*; use workspace::Workspace; use crate::thread::{MessageId, Thread, ThreadError, ThreadEvent}; +use crate::thread_store::ThreadStore; use crate::ui::ContextPill; pub struct ActiveThread { workspace: WeakView, language_registry: Arc, tools: Arc, + thread_store: Model, thread: Model, messages: Vec, list_state: ListState, @@ -33,6 +35,7 @@ pub struct ActiveThread { impl ActiveThread { pub fn new( thread: Model, + thread_store: Model, workspace: WeakView, language_registry: Arc, tools: Arc, @@ -47,6 +50,7 @@ impl ActiveThread { workspace, language_registry, tools, + thread_store, thread: thread.clone(), messages: Vec::new(), rendered_messages_by_id: HashMap::default(), @@ -192,8 +196,13 @@ impl ActiveThread { ThreadEvent::ShowError(error) => { self.last_error = Some(error.clone()); } - ThreadEvent::StreamedCompletion => {} - ThreadEvent::SummaryChanged => {} + ThreadEvent::StreamedCompletion | ThreadEvent::SummaryChanged => { + self.thread_store + .update(cx, |thread_store, cx| { + thread_store.save_thread(&self.thread, cx) + }) + .detach_and_log_err(cx); + } ThreadEvent::StreamedAssistantText(message_id, text) => { if let Some(markdown) = self.rendered_messages_by_id.get_mut(&message_id) { markdown.update(cx, |markdown, cx| { @@ -211,6 +220,12 @@ impl ActiveThread { self.push_message(message_id, message_text, cx); } + self.thread_store + .update(cx, |thread_store, cx| { + thread_store.save_thread(&self.thread, cx) + }) + .detach_and_log_err(cx); + cx.notify(); } ThreadEvent::UsePendingTools => { diff --git a/crates/assistant2/src/assistant_panel.rs b/crates/assistant2/src/assistant_panel.rs index d98e238613774c91348c49dcd416c93cb7652e0e..f9e08fa738774fd49ae2f3bb29281450f3baa6c7 100644 --- a/crates/assistant2/src/assistant_panel.rs +++ b/crates/assistant2/src/assistant_panel.rs @@ -179,6 +179,7 @@ impl AssistantPanel { thread: cx.new_view(|cx| { ActiveThread::new( thread.clone(), + thread_store.clone(), workspace, language_registry, tools.clone(), @@ -239,6 +240,7 @@ impl AssistantPanel { self.thread = cx.new_view(|cx| { ActiveThread::new( thread.clone(), + self.thread_store.clone(), self.workspace.clone(), self.language_registry.clone(), self.tools.clone(), @@ -361,34 +363,41 @@ impl AssistantPanel { }) } - pub(crate) fn open_thread(&mut self, thread_id: &ThreadId, cx: &mut ViewContext) { - let Some(thread) = self + pub(crate) fn open_thread( + &mut self, + thread_id: &ThreadId, + cx: &mut ViewContext, + ) -> Task> { + let open_thread_task = self .thread_store - .update(cx, |this, cx| this.open_thread(thread_id, cx)) - else { - return; - }; + .update(cx, |this, cx| this.open_thread(thread_id, cx)); - self.active_view = ActiveView::Thread; - self.thread = cx.new_view(|cx| { - ActiveThread::new( - thread.clone(), - self.workspace.clone(), - self.language_registry.clone(), - self.tools.clone(), - cx, - ) - }); - self.message_editor = cx.new_view(|cx| { - MessageEditor::new( - self.fs.clone(), - self.workspace.clone(), - self.thread_store.downgrade(), - thread, - cx, - ) - }); - self.message_editor.focus_handle(cx).focus(cx); + cx.spawn(|this, mut cx| async move { + let thread = open_thread_task.await?; + this.update(&mut cx, |this, cx| { + this.active_view = ActiveView::Thread; + this.thread = cx.new_view(|cx| { + ActiveThread::new( + thread.clone(), + this.thread_store.clone(), + this.workspace.clone(), + this.language_registry.clone(), + this.tools.clone(), + cx, + ) + }); + this.message_editor = cx.new_view(|cx| { + MessageEditor::new( + this.fs.clone(), + this.workspace.clone(), + this.thread_store.downgrade(), + thread, + cx, + ) + }); + this.message_editor.focus_handle(cx).focus(cx); + }) + }) } pub(crate) fn open_configuration(&mut self, cx: &mut ViewContext) { @@ -437,7 +446,8 @@ impl AssistantPanel { pub(crate) fn delete_thread(&mut self, thread_id: &ThreadId, cx: &mut ViewContext) { self.thread_store - .update(cx, |this, cx| this.delete_thread(thread_id, cx)); + .update(cx, |this, cx| this.delete_thread(thread_id, cx)) + .detach_and_log_err(cx); } } @@ -655,7 +665,7 @@ impl AssistantPanel { fn render_thread_empty_state(&self, cx: &mut ViewContext) -> impl IntoElement { let recent_threads = self .thread_store - .update(cx, |this, cx| this.recent_threads(3, cx)); + .update(cx, |this, _cx| this.recent_threads(3)); v_flex() .gap_2() diff --git a/crates/assistant2/src/context_picker.rs b/crates/assistant2/src/context_picker.rs index dbca3a115ae1eaca212dc6cd2c3dae0420c41800..b1b230bd05fd0a228fa7a2c6b2d1972bf8763630 100644 --- a/crates/assistant2/src/context_picker.rs +++ b/crates/assistant2/src/context_picker.rs @@ -6,10 +6,12 @@ mod thread_context_picker; use std::path::PathBuf; use std::sync::Arc; +use anyhow::{anyhow, Result}; use editor::Editor; use file_context_picker::render_file_context_entry; use gpui::{ - AppContext, DismissEvent, EventEmitter, FocusHandle, FocusableView, View, WeakModel, WeakView, + AppContext, DismissEvent, EventEmitter, FocusHandle, FocusableView, Task, View, WeakModel, + WeakView, }; use project::ProjectPath; use thread_context_picker::{render_thread_context_entry, ThreadContextEntry}; @@ -237,7 +239,8 @@ impl ContextPicker { }, move |cx| { context_picker.update(cx, |this, cx| { - this.add_recent_thread(thread.clone(), cx); + this.add_recent_thread(thread.clone(), cx) + .detach_and_log_err(cx); }) }, ) @@ -260,25 +263,32 @@ impl ContextPicker { cx.notify(); } - fn add_recent_thread(&self, thread: ThreadContextEntry, cx: &mut ViewContext) { + fn add_recent_thread( + &self, + thread: ThreadContextEntry, + cx: &mut ViewContext, + ) -> Task> { let Some(context_store) = self.context_store.upgrade() else { - return; + return Task::ready(Err(anyhow!("context store not available"))); }; - let Some(thread) = self + let Some(thread_store) = self .thread_store - .clone() - .and_then(|this| this.upgrade()) - .and_then(|this| this.update(cx, |this, cx| this.open_thread(&thread.id, cx))) + .as_ref() + .and_then(|thread_store| thread_store.upgrade()) else { - return; + return Task::ready(Err(anyhow!("thread store not available"))); }; - context_store.update(cx, |context_store, cx| { - context_store.add_thread(thread, cx); - }); + let open_thread_task = thread_store.update(cx, |this, cx| this.open_thread(&thread.id, cx)); + cx.spawn(|this, mut cx| async move { + let thread = open_thread_task.await?; + context_store.update(&mut cx, |context_store, cx| { + context_store.add_thread(thread, cx); + })?; - cx.notify(); + this.update(&mut cx, |_this, cx| cx.notify()) + }) } fn recent_entries(&self, cx: &mut WindowContext) -> Vec { @@ -332,19 +342,17 @@ impl ContextPicker { return recent; }; - thread_store.update(cx, |thread_store, cx| { + thread_store.update(cx, |thread_store, _cx| { recent.extend( thread_store - .threads(cx) + .threads() .into_iter() - .filter(|thread| !current_threads.contains(thread.read(cx).id())) + .filter(|thread| !current_threads.contains(&thread.id)) .take(2) .map(|thread| { - let thread = thread.read(cx); - RecentEntry::Thread(ThreadContextEntry { - id: thread.id().clone(), - summary: thread.summary_or_default(), + id: thread.id, + summary: thread.summary, }) }), ) diff --git a/crates/assistant2/src/context_picker/thread_context_picker.rs b/crates/assistant2/src/context_picker/thread_context_picker.rs index bea32d3d043ee8d14815b04ffc81c55a9d276b11..16c05d5405577f6fe9e4b4664f1b1c4292fc7a53 100644 --- a/crates/assistant2/src/context_picker/thread_context_picker.rs +++ b/crates/assistant2/src/context_picker/thread_context_picker.rs @@ -99,13 +99,12 @@ impl PickerDelegate for ThreadContextPickerDelegate { } fn update_matches(&mut self, query: String, cx: &mut ViewContext>) -> Task<()> { - let Ok(threads) = self.thread_store.update(cx, |this, cx| { - this.threads(cx) + let Ok(threads) = self.thread_store.update(cx, |this, _cx| { + this.threads() .into_iter() - .map(|thread| { - let id = thread.read(cx).id().clone(); - let summary = thread.read(cx).summary_or_default(); - ThreadContextEntry { id, summary } + .map(|thread| ThreadContextEntry { + id: thread.id, + summary: thread.summary, }) .collect::>() }) else { @@ -159,19 +158,23 @@ impl PickerDelegate for ThreadContextPickerDelegate { return; }; - let Some(thread) = thread_store.update(cx, |this, cx| this.open_thread(&entry.id, cx)) - else { - return; - }; + let open_thread_task = thread_store.update(cx, |this, cx| this.open_thread(&entry.id, cx)); - self.context_store - .update(cx, |context_store, cx| context_store.add_thread(thread, cx)) - .ok(); - - match self.confirm_behavior { - ConfirmBehavior::KeepOpen => {} - ConfirmBehavior::Close => self.dismissed(cx), - } + cx.spawn(|this, mut cx| async move { + let thread = open_thread_task.await?; + this.update(&mut cx, |this, cx| { + this.delegate + .context_store + .update(cx, |context_store, cx| context_store.add_thread(thread, cx)) + .ok(); + + match this.delegate.confirm_behavior { + ConfirmBehavior::KeepOpen => {} + ConfirmBehavior::Close => this.delegate.dismissed(cx), + } + }) + }) + .detach_and_log_err(cx); } fn dismissed(&mut self, cx: &mut ViewContext>) { diff --git a/crates/assistant2/src/thread.rs b/crates/assistant2/src/thread.rs index fa0fec39be63809912d7953d56cc47a6624728b4..57cc0c3b87ca0f4c4e9803f127c244adc4c8f2d8 100644 --- a/crates/assistant2/src/thread.rs +++ b/crates/assistant2/src/thread.rs @@ -18,6 +18,7 @@ use util::{post_inc, TryFutureExt as _}; use uuid::Uuid; use crate::context::{attach_context_to_message, ContextId, ContextSnapshot}; +use crate::thread_store::SavedThread; #[derive(Debug, Clone, Copy)] pub enum RequestKind { @@ -94,6 +95,40 @@ impl Thread { } } + pub fn from_saved( + id: ThreadId, + saved: SavedThread, + tools: Arc, + _cx: &mut ModelContext, + ) -> Self { + let next_message_id = MessageId(saved.messages.len()); + + Self { + id, + updated_at: saved.updated_at, + summary: Some(saved.summary), + pending_summary: Task::ready(None), + messages: saved + .messages + .into_iter() + .map(|message| Message { + id: message.id, + role: message.role, + text: message.text, + }) + .collect(), + next_message_id, + context: BTreeMap::default(), + context_by_message: HashMap::default(), + completion_count: 0, + pending_completions: Vec::new(), + tools, + tool_uses_by_message: HashMap::default(), + tool_results_by_message: HashMap::default(), + pending_tool_uses_by_id: HashMap::default(), + } + } + pub fn id(&self) -> &ThreadId { &self.id } diff --git a/crates/assistant2/src/thread_history.rs b/crates/assistant2/src/thread_history.rs index 18619fd0514b79df7a1a04e42179766654487926..eeb1bae84c8f565805c4187d7909c7110e9b1ff0 100644 --- a/crates/assistant2/src/thread_history.rs +++ b/crates/assistant2/src/thread_history.rs @@ -5,8 +5,7 @@ use gpui::{ use time::{OffsetDateTime, UtcOffset}; use ui::{prelude::*, IconButtonShape, ListItem, ListItemSpacing, Tooltip}; -use crate::thread::Thread; -use crate::thread_store::ThreadStore; +use crate::thread_store::{SavedThreadMetadata, ThreadStore}; use crate::{AssistantPanel, RemoveSelectedThread}; pub struct ThreadHistory { @@ -33,8 +32,7 @@ impl ThreadHistory { } pub fn select_prev(&mut self, _: &menu::SelectPrev, cx: &mut ViewContext) { - let count = self.thread_store.read(cx).non_empty_len(cx); - + let count = self.thread_store.read(cx).thread_count(); if count > 0 { if self.selected_index == 0 { self.set_selected_index(count - 1, cx); @@ -45,8 +43,7 @@ impl ThreadHistory { } pub fn select_next(&mut self, _: &menu::SelectNext, cx: &mut ViewContext) { - let count = self.thread_store.read(cx).non_empty_len(cx); - + let count = self.thread_store.read(cx).thread_count(); if count > 0 { if self.selected_index == count - 1 { self.set_selected_index(0, cx); @@ -57,14 +54,14 @@ impl ThreadHistory { } fn select_first(&mut self, _: &menu::SelectFirst, cx: &mut ViewContext) { - let count = self.thread_store.read(cx).non_empty_len(cx); + let count = self.thread_store.read(cx).thread_count(); if count > 0 { self.set_selected_index(0, cx); } } fn select_last(&mut self, _: &menu::SelectLast, cx: &mut ViewContext) { - let count = self.thread_store.read(cx).non_empty_len(cx); + let count = self.thread_store.read(cx).thread_count(); if count > 0 { self.set_selected_index(count - 1, cx); } @@ -78,14 +75,11 @@ impl ThreadHistory { } fn confirm(&mut self, _: &menu::Confirm, cx: &mut ViewContext) { - let threads = self.thread_store.update(cx, |this, cx| this.threads(cx)); + let threads = self.thread_store.update(cx, |this, _cx| this.threads()); if let Some(thread) = threads.get(self.selected_index) { self.assistant_panel - .update(cx, move |this, cx| { - let thread_id = thread.read(cx).id().clone(); - this.open_thread(&thread_id, cx) - }) + .update(cx, move |this, cx| this.open_thread(&thread.id, cx)) .ok(); cx.notify(); @@ -93,13 +87,12 @@ impl ThreadHistory { } fn remove_selected_thread(&mut self, _: &RemoveSelectedThread, cx: &mut ViewContext) { - let threads = self.thread_store.update(cx, |this, cx| this.threads(cx)); + let threads = self.thread_store.update(cx, |this, _cx| this.threads()); if let Some(thread) = threads.get(self.selected_index) { self.assistant_panel .update(cx, |this, cx| { - let thread_id = thread.read(cx).id().clone(); - this.delete_thread(&thread_id, cx); + this.delete_thread(&thread.id, cx); }) .ok(); @@ -116,7 +109,7 @@ impl FocusableView for ThreadHistory { impl Render for ThreadHistory { fn render(&mut self, cx: &mut ViewContext) -> impl IntoElement { - let threads = self.thread_store.update(cx, |this, cx| this.threads(cx)); + let threads = self.thread_store.update(cx, |this, _cx| this.threads()); let selected_index = self.selected_index; v_flex() @@ -172,14 +165,14 @@ impl Render for ThreadHistory { #[derive(IntoElement)] pub struct PastThread { - thread: Model, + thread: SavedThreadMetadata, assistant_panel: WeakView, selected: bool, } impl PastThread { pub fn new( - thread: Model, + thread: SavedThreadMetadata, assistant_panel: WeakView, selected: bool, ) -> Self { @@ -193,14 +186,10 @@ impl PastThread { impl RenderOnce for PastThread { fn render(self, cx: &mut WindowContext) -> impl IntoElement { - let (id, summary) = { - let thread = self.thread.read(cx); - (thread.id().clone(), thread.summary_or_default()) - }; + let summary = self.thread.summary; let thread_timestamp = time_format::format_localized_timestamp( - OffsetDateTime::from_unix_timestamp(self.thread.read(cx).updated_at().timestamp()) - .unwrap(), + OffsetDateTime::from_unix_timestamp(self.thread.updated_at.timestamp()).unwrap(), OffsetDateTime::now_utc(), self.assistant_panel .update(cx, |this, _cx| this.local_timezone()) @@ -208,7 +197,7 @@ impl RenderOnce for PastThread { time_format::TimestampFormat::EnhancedAbsolute, ); - ListItem::new(("past-thread", self.thread.entity_id())) + ListItem::new(SharedString::from(self.thread.id.to_string())) .outlined() .toggle_state(self.selected) .start_slot( @@ -233,7 +222,7 @@ impl RenderOnce for PastThread { .tooltip(|cx| Tooltip::text("Delete Thread", cx)) .on_click({ let assistant_panel = self.assistant_panel.clone(); - let id = id.clone(); + let id = self.thread.id.clone(); move |_event, cx| { assistant_panel .update(cx, |this, cx| { @@ -246,11 +235,11 @@ impl RenderOnce for PastThread { ) .on_click({ let assistant_panel = self.assistant_panel.clone(); - let id = id.clone(); + let id = self.thread.id.clone(); move |_event, cx| { assistant_panel .update(cx, |this, cx| { - this.open_thread(&id, cx); + this.open_thread(&id, cx).detach_and_log_err(cx); }) .ok(); } diff --git a/crates/assistant2/src/thread_store.rs b/crates/assistant2/src/thread_store.rs index e07e447f79e2b0a601105c1aa836f8f5cc279bf5..ee377649a1719cbc3589eb006971e48a1626a802 100644 --- a/crates/assistant2/src/thread_store.rs +++ b/crates/assistant2/src/thread_store.rs @@ -1,16 +1,23 @@ +use std::path::PathBuf; use std::sync::Arc; -use anyhow::Result; +use anyhow::{anyhow, Result}; use assistant_tool::{ToolId, ToolWorkingSet}; +use chrono::{DateTime, Utc}; use collections::HashMap; use context_server::manager::ContextServerManager; use context_server::{ContextServerFactoryRegistry, ContextServerTool}; -use gpui::{prelude::*, AppContext, Model, ModelContext, Task}; +use futures::future::{self, BoxFuture, Shared}; +use futures::FutureExt as _; +use gpui::{prelude::*, AppContext, BackgroundExecutor, Model, ModelContext, SharedString, Task}; +use heed::types::SerdeBincode; +use heed::Database; +use language_model::Role; use project::Project; -use unindent::Unindent; +use serde::{Deserialize, Serialize}; use util::ResultExt as _; -use crate::thread::{Thread, ThreadId}; +use crate::thread::{MessageId, Thread, ThreadId}; pub struct ThreadStore { #[allow(unused)] @@ -18,7 +25,8 @@ pub struct ThreadStore { tools: Arc, context_server_manager: Model, context_server_tool_ids: HashMap, Vec>, - threads: Vec>, + threads: Vec, + database_future: Shared, Arc>>>, } impl ThreadStore { @@ -35,61 +43,138 @@ impl ThreadStore { ContextServerManager::new(context_server_factory_registry, project.clone(), cx) }); - let mut this = Self { + let executor = cx.background_executor().clone(); + let database_future = executor + .spawn({ + let executor = executor.clone(); + let database_path = paths::support_dir().join("threads/threads-db.0.mdb"); + async move { ThreadsDatabase::new(database_path, executor) } + }) + .then(|result| future::ready(result.map(Arc::new).map_err(Arc::new))) + .boxed() + .shared(); + + let this = Self { project, tools, context_server_manager, context_server_tool_ids: HashMap::default(), threads: Vec::new(), + database_future, }; - this.mock_recent_threads(cx); this.register_context_server_handlers(cx); this })?; + this.update(&mut cx, |this, cx| this.reload(cx))?.await?; + Ok(this) }) } - /// Returns the number of non-empty threads. - pub fn non_empty_len(&self, cx: &AppContext) -> usize { - self.threads - .iter() - .filter(|thread| !thread.read(cx).is_empty()) - .count() + /// Returns the number of threads. + pub fn thread_count(&self) -> usize { + self.threads.len() } - pub fn threads(&self, cx: &ModelContext) -> Vec> { - let mut threads = self - .threads - .iter() - .filter(|thread| !thread.read(cx).is_empty()) - .cloned() - .collect::>(); - threads.sort_unstable_by_key(|thread| std::cmp::Reverse(thread.read(cx).updated_at())); + pub fn threads(&self) -> Vec { + let mut threads = self.threads.iter().cloned().collect::>(); + threads.sort_unstable_by_key(|thread| std::cmp::Reverse(thread.updated_at)); threads } - pub fn recent_threads(&self, limit: usize, cx: &ModelContext) -> Vec> { - self.threads(cx).into_iter().take(limit).collect() + pub fn recent_threads(&self, limit: usize) -> Vec { + self.threads().into_iter().take(limit).collect() } pub fn create_thread(&mut self, cx: &mut ModelContext) -> Model { - let thread = cx.new_model(|cx| Thread::new(self.tools.clone(), cx)); - self.threads.push(thread.clone()); - thread + cx.new_model(|cx| Thread::new(self.tools.clone(), cx)) } - pub fn open_thread(&self, id: &ThreadId, cx: &mut ModelContext) -> Option> { - self.threads - .iter() - .find(|thread| thread.read(cx).id() == id) - .cloned() + pub fn open_thread( + &self, + id: &ThreadId, + cx: &mut ModelContext, + ) -> Task>> { + let id = id.clone(); + let database_future = self.database_future.clone(); + cx.spawn(|this, mut cx| async move { + let database = database_future.await.map_err(|err| anyhow!(err))?; + let thread = database + .try_find_thread(id.clone()) + .await? + .ok_or_else(|| anyhow!("no thread found with ID: {id:?}"))?; + + this.update(&mut cx, |this, cx| { + cx.new_model(|cx| Thread::from_saved(id.clone(), thread, this.tools.clone(), cx)) + }) + }) } - pub fn delete_thread(&mut self, id: &ThreadId, cx: &mut ModelContext) { - self.threads.retain(|thread| thread.read(cx).id() != id); + pub fn save_thread( + &self, + thread: &Model, + cx: &mut ModelContext, + ) -> Task> { + let (metadata, thread) = thread.update(cx, |thread, _cx| { + let id = thread.id().clone(); + let thread = SavedThread { + summary: thread.summary_or_default(), + updated_at: thread.updated_at(), + messages: thread + .messages() + .map(|message| SavedMessage { + id: message.id, + role: message.role, + text: message.text.clone(), + }) + .collect(), + }; + + (id, thread) + }); + + let database_future = self.database_future.clone(); + cx.spawn(|this, mut cx| async move { + let database = database_future.await.map_err(|err| anyhow!(err))?; + database.save_thread(metadata, thread).await?; + + this.update(&mut cx, |this, cx| this.reload(cx))?.await + }) + } + + pub fn delete_thread( + &mut self, + id: &ThreadId, + cx: &mut ModelContext, + ) -> Task> { + let id = id.clone(); + let database_future = self.database_future.clone(); + cx.spawn(|this, mut cx| async move { + let database = database_future.await.map_err(|err| anyhow!(err))?; + database.delete_thread(id.clone()).await?; + + this.update(&mut cx, |this, _cx| { + this.threads.retain(|thread| thread.id != id) + }) + }) + } + + fn reload(&self, cx: &mut ModelContext) -> Task> { + let database_future = self.database_future.clone(); + cx.spawn(|this, mut cx| async move { + let threads = database_future + .await + .map_err(|err| anyhow!(err))? + .list_threads() + .await?; + + this.update(&mut cx, |this, cx| { + this.threads = threads; + cx.notify(); + }) + }) } fn register_context_server_handlers(&self, cx: &mut ModelContext) { @@ -159,133 +244,108 @@ impl ThreadStore { } } -impl ThreadStore { - /// Creates some mocked recent threads for testing purposes. - fn mock_recent_threads(&mut self, cx: &mut ModelContext) { - use language_model::Role; - - self.threads.push(cx.new_model(|cx| { - let mut thread = Thread::new(self.tools.clone(), cx); - thread.set_summary("Introduction to quantum computing", cx); - thread.insert_user_message("Hello! Can you help me understand quantum computing?", Vec::new(), cx); - thread.insert_message(Role::Assistant, "Of course! I'd be happy to help you understand quantum computing. Quantum computing is a fascinating field that uses the principles of quantum mechanics to process information. Unlike classical computers that use bits (0s and 1s), quantum computers use quantum bits or 'qubits'. These qubits can exist in multiple states simultaneously, a property called superposition. This allows quantum computers to perform certain calculations much faster than classical computers. What specific aspect of quantum computing would you like to know more about?", cx); - thread.insert_user_message("That's interesting! Can you explain how quantum entanglement is used in quantum computing?", Vec::new(), cx); - thread.insert_message(Role::Assistant, "Certainly! Quantum entanglement is a key principle used in quantum computing. When two qubits become entangled, the state of one qubit is directly related to the state of the other, regardless of the distance between them. This property is used in quantum computing to create complex quantum states and to perform operations on multiple qubits simultaneously. Entanglement allows quantum computers to process information in ways that classical computers cannot, potentially solving certain problems much more efficiently. For example, it's crucial in quantum error correction and in algorithms like quantum teleportation, which is important for quantum communication.", cx); - thread - })); - - self.threads.push(cx.new_model(|cx| { - let mut thread = Thread::new(self.tools.clone(), cx); - thread.set_summary("Rust web development and async programming", cx); - thread.insert_user_message("Can you show me an example of Rust code for a simple web server?", Vec::new(), cx); - thread.insert_message(Role::Assistant, "Certainly! Here's an example of a simple web server in Rust using the `actix-web` framework: - - ```rust - use actix_web::{web, App, HttpResponse, HttpServer, Responder}; - - async fn hello() -> impl Responder { - HttpResponse::Ok().body(\"Hello, World!\") - } - - #[actix_web::main] - async fn main() -> std::io::Result<()> { - HttpServer::new(|| { - App::new() - .route(\"/\", web::get().to(hello)) - }) - .bind(\"127.0.0.1:8080\")? - .run() - .await - } - ``` - - This code creates a basic web server that responds with 'Hello, World!' when you access the root URL. Here's a breakdown of what's happening: - - 1. We import necessary items from the `actix-web` crate. - 2. We define an async `hello` function that returns a simple HTTP response. - 3. In the `main` function, we set up the server to listen on `127.0.0.1:8080`. - 4. We configure the app to respond to GET requests on the root path with our `hello` function. - - To run this, you'd need to add `actix-web` to your `Cargo.toml` dependencies: - - ```toml - [dependencies] - actix-web = \"4.0\" - ``` - - Then you can run the server with `cargo run` and access it at `http://localhost:8080`.".unindent(), cx); - thread.insert_user_message("That's great! Can you explain more about async functions in Rust?", Vec::new(), cx); - thread.insert_message(Role::Assistant, "Certainly! Async functions are a key feature in Rust for writing efficient, non-blocking code, especially for I/O-bound operations. Here's an overview: - - 1. **Syntax**: Async functions are declared using the `async` keyword: - - ```rust - async fn my_async_function() -> Result<(), Error> { - // Asynchronous code here - } - ``` - - 2. **Futures**: Async functions return a `Future`. A `Future` represents a value that may not be available yet but will be at some point. - - 3. **Await**: Inside an async function, you can use the `.await` syntax to wait for other async operations to complete: - - ```rust - async fn fetch_data() -> Result { - let response = make_http_request().await?; - let data = process_response(response).await?; - Ok(data) - } - ``` - - 4. **Non-blocking**: Async functions allow the runtime to work on other tasks while waiting for I/O or other operations to complete, making efficient use of system resources. - - 5. **Runtime**: To execute async code, you need a runtime like `tokio` or `async-std`. Actix-web, which we used in the previous example, includes its own runtime. +#[derive(Debug, Clone, Serialize, Deserialize)] +pub struct SavedThreadMetadata { + pub id: ThreadId, + pub summary: SharedString, + pub updated_at: DateTime, +} - 6. **Error Handling**: Async functions work well with Rust's `?` operator for error handling. +#[derive(Serialize, Deserialize)] +pub struct SavedThread { + pub summary: SharedString, + pub updated_at: DateTime, + pub messages: Vec, +} - Async programming in Rust provides a powerful way to write concurrent code that's both safe and efficient. It's particularly useful for servers, network programming, and any application that deals with many concurrent operations.".unindent(), cx); - thread - })); +#[derive(Serialize, Deserialize)] +pub struct SavedMessage { + pub id: MessageId, + pub role: Role, + pub text: String, +} - self.threads.push(cx.new_model(|cx| { - let mut thread = Thread::new(self.tools.clone(), cx); - thread.set_summary("Rust code with long lines", cx); - thread.insert_user_message("Could you write me some Rust code with long lines?", Vec::new(), cx); - thread.insert_message(Role::Assistant, r#"Here's some Rust code with some intentionally long lines: - ```rust - use std::collections::{HashMap, HashSet}; - use std::sync::{Arc, Mutex}; - use std::thread; +struct ThreadsDatabase { + executor: BackgroundExecutor, + env: heed::Env, + threads: Database, SerdeBincode>, +} - fn main() { - let very_long_vector = vec![1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25]; +impl ThreadsDatabase { + pub fn new(path: PathBuf, executor: BackgroundExecutor) -> Result { + std::fs::create_dir_all(&path)?; + + const ONE_GB_IN_BYTES: usize = 1024 * 1024 * 1024; + let env = unsafe { + heed::EnvOpenOptions::new() + .map_size(ONE_GB_IN_BYTES) + .max_dbs(1) + .open(path)? + }; + + let mut txn = env.write_txn()?; + let threads = env.create_database(&mut txn, Some("threads"))?; + txn.commit()?; + + Ok(Self { + executor, + env, + threads, + }) + } - let complicated_hashmap: HashMap> = [("key1".to_string(), vec![(1, 1.1, "value1".to_string()), (2, 2.2, "value2".to_string())]), ("key2".to_string(), vec![(3, 3.3, "value3".to_string()), (4, 4.4, "value4".to_string())])].iter().cloned().collect(); + pub fn list_threads(&self) -> Task>> { + let env = self.env.clone(); + let threads = self.threads; + + self.executor.spawn(async move { + let txn = env.read_txn()?; + let mut iter = threads.iter(&txn)?; + let mut threads = Vec::new(); + while let Some((key, value)) = iter.next().transpose()? { + threads.push(SavedThreadMetadata { + id: key, + summary: value.summary, + updated_at: value.updated_at, + }); + } - let nested_structure = Arc::new(Mutex::new(HashMap::new())); + Ok(threads) + }) + } - let long_closure = |x: i32, y: i32, z: i32| -> i32 { let result = x * y + z; println!("The result of the long closure calculation is: {}", result); result }; + pub fn try_find_thread(&self, id: ThreadId) -> Task>> { + let env = self.env.clone(); + let threads = self.threads; - let thread_handles: Vec<_> = (0..10).map(|i| { - let nested_structure_clone = Arc::clone(&nested_structure); - thread::spawn(move || { - let mut lock = nested_structure_clone.lock().unwrap(); - lock.entry(format!("thread_{}", i)).or_insert_with(|| HashSet::new()).insert(i * i); - }) - }).collect(); + self.executor.spawn(async move { + let txn = env.read_txn()?; + let thread = threads.get(&txn, &id)?; + Ok(thread) + }) + } - for handle in thread_handles { - handle.join().unwrap(); - } + pub fn save_thread(&self, id: ThreadId, thread: SavedThread) -> Task> { + let env = self.env.clone(); + let threads = self.threads; - println!("The final state of the nested structure is: {:?}", nested_structure.lock().unwrap()); + self.executor.spawn(async move { + let mut txn = env.write_txn()?; + threads.put(&mut txn, &id, &thread)?; + txn.commit()?; + Ok(()) + }) + } - let complex_expression = very_long_vector.iter().filter(|&&x| x % 2 == 0).map(|&x| x * x).fold(0, |acc, x| acc + x) + long_closure(5, 10, 15); + pub fn delete_thread(&self, id: ThreadId) -> Task> { + let env = self.env.clone(); + let threads = self.threads; - println!("The result of the complex expression is: {}", complex_expression); - } - ```"#.unindent(), cx); - thread - })); + self.executor.spawn(async move { + let mut txn = env.write_txn()?; + threads.delete(&mut txn, &id)?; + txn.commit()?; + Ok(()) + }) } }