1use std::path::PathBuf;
2use std::sync::Arc;
3
4use anyhow::{anyhow, Result};
5use assistant_tool::{ToolId, ToolWorkingSet};
6use chrono::{DateTime, Utc};
7use collections::HashMap;
8use context_server::manager::ContextServerManager;
9use context_server::{ContextServerFactoryRegistry, ContextServerTool};
10use futures::future::{self, BoxFuture, Shared};
11use futures::FutureExt as _;
12use gpui::{prelude::*, App, BackgroundExecutor, Context, Entity, SharedString, Task};
13use heed::types::SerdeBincode;
14use heed::Database;
15use language_model::Role;
16use project::Project;
17use serde::{Deserialize, Serialize};
18use util::ResultExt as _;
19
20use crate::thread::{MessageId, Thread, ThreadId};
21
/// Holds the list of saved assistant threads and mediates access to the on-disk
/// threads database.
///
/// It also listens to a [`ContextServerManager`] so that tools exposed by context
/// servers are added to (and removed from) the shared tool working set.
pub struct ThreadStore {
    // NOTE(review): not read anywhere in this file; the `allow` silences the
    // dead-code lint. Presumably kept for upcoming features — confirm or remove.
    #[allow(unused)]
    project: Entity<Project>,
    /// Tool working set handed to every thread created or opened by this store;
    /// context-server tools are registered into it.
    tools: Arc<ToolWorkingSet>,
    /// Manager whose `ServerStarted`/`ServerStopped` events drive tool registration.
    context_server_manager: Entity<ContextServerManager>,
    /// IDs of the tools registered for each context server, keyed by server ID, so they
    /// can be removed from `tools` when that server stops.
    context_server_tool_ids: HashMap<Arc<str>, Vec<ToolId>>,
    /// Cached metadata of the saved threads, in the order the database returned them.
    threads: Vec<SavedThreadMetadata>,
    /// Shared future resolving to the opened database. The error is wrapped in an `Arc`
    /// so that the future's output is `Clone`, which `Shared` requires.
    database_future: Shared<BoxFuture<'static, Result<Arc<ThreadsDatabase>, Arc<anyhow::Error>>>>,
}
31
32impl ThreadStore {
33 pub fn new(
34 project: Entity<Project>,
35 tools: Arc<ToolWorkingSet>,
36 cx: &mut App,
37 ) -> Task<Result<Entity<Self>>> {
38 cx.spawn(|mut cx| async move {
39 let this = cx.new(|cx: &mut Context<Self>| {
40 let context_server_factory_registry =
41 ContextServerFactoryRegistry::default_global(cx);
42 let context_server_manager = cx.new(|cx| {
43 ContextServerManager::new(context_server_factory_registry, project.clone(), cx)
44 });
45
46 let executor = cx.background_executor().clone();
47 let database_future = executor
48 .spawn({
49 let executor = executor.clone();
50 let database_path = paths::support_dir().join("threads/threads-db.0.mdb");
51 async move { ThreadsDatabase::new(database_path, executor) }
52 })
53 .then(|result| future::ready(result.map(Arc::new).map_err(Arc::new)))
54 .boxed()
55 .shared();
56
57 let this = Self {
58 project,
59 tools,
60 context_server_manager,
61 context_server_tool_ids: HashMap::default(),
62 threads: Vec::new(),
63 database_future,
64 };
65 this.register_context_server_handlers(cx);
66
67 this
68 })?;
69
70 log::info!("[assistant2-debug] reloading threads");
71 this.update(&mut cx, |this, cx| this.reload(cx))?.await?;
72 log::info!("[assistant2-debug] finished reloading threads");
73
74 Ok(this)
75 })
76 }
77
78 /// Returns the number of threads.
79 pub fn thread_count(&self) -> usize {
80 self.threads.len()
81 }
82
83 pub fn threads(&self) -> Vec<SavedThreadMetadata> {
84 let mut threads = self.threads.iter().cloned().collect::<Vec<_>>();
85 threads.sort_unstable_by_key(|thread| std::cmp::Reverse(thread.updated_at));
86 threads
87 }
88
89 pub fn recent_threads(&self, limit: usize) -> Vec<SavedThreadMetadata> {
90 self.threads().into_iter().take(limit).collect()
91 }
92
93 pub fn create_thread(&mut self, cx: &mut Context<Self>) -> Entity<Thread> {
94 cx.new(|cx| Thread::new(self.tools.clone(), cx))
95 }
96
97 pub fn open_thread(
98 &self,
99 id: &ThreadId,
100 cx: &mut Context<Self>,
101 ) -> Task<Result<Entity<Thread>>> {
102 let id = id.clone();
103 let database_future = self.database_future.clone();
104 cx.spawn(|this, mut cx| async move {
105 let database = database_future.await.map_err(|err| anyhow!(err))?;
106 let thread = database
107 .try_find_thread(id.clone())
108 .await?
109 .ok_or_else(|| anyhow!("no thread found with ID: {id:?}"))?;
110
111 this.update(&mut cx, |this, cx| {
112 cx.new(|cx| Thread::from_saved(id.clone(), thread, this.tools.clone(), cx))
113 })
114 })
115 }
116
117 pub fn save_thread(&self, thread: &Entity<Thread>, cx: &mut Context<Self>) -> Task<Result<()>> {
118 let (metadata, thread) = thread.update(cx, |thread, _cx| {
119 let id = thread.id().clone();
120 let thread = SavedThread {
121 summary: thread.summary_or_default(),
122 updated_at: thread.updated_at(),
123 messages: thread
124 .messages()
125 .map(|message| SavedMessage {
126 id: message.id,
127 role: message.role,
128 text: message.text.clone(),
129 })
130 .collect(),
131 };
132
133 (id, thread)
134 });
135
136 let database_future = self.database_future.clone();
137 cx.spawn(|this, mut cx| async move {
138 let database = database_future.await.map_err(|err| anyhow!(err))?;
139 database.save_thread(metadata, thread).await?;
140
141 this.update(&mut cx, |this, cx| this.reload(cx))?.await
142 })
143 }
144
145 pub fn delete_thread(&mut self, id: &ThreadId, cx: &mut Context<Self>) -> Task<Result<()>> {
146 let id = id.clone();
147 let database_future = self.database_future.clone();
148 cx.spawn(|this, mut cx| async move {
149 let database = database_future.await.map_err(|err| anyhow!(err))?;
150 database.delete_thread(id.clone()).await?;
151
152 this.update(&mut cx, |this, _cx| {
153 this.threads.retain(|thread| thread.id != id)
154 })
155 })
156 }
157
158 fn reload(&self, cx: &mut Context<Self>) -> Task<Result<()>> {
159 let database_future = self.database_future.clone();
160 cx.spawn(|this, mut cx| async move {
161 let threads = database_future
162 .await
163 .map_err(|err| anyhow!(err))?
164 .list_threads()
165 .await?;
166
167 this.update(&mut cx, |this, cx| {
168 this.threads = threads;
169 cx.notify();
170 })
171 })
172 }
173
174 fn register_context_server_handlers(&self, cx: &mut Context<Self>) {
175 cx.subscribe(
176 &self.context_server_manager.clone(),
177 Self::handle_context_server_event,
178 )
179 .detach();
180 }
181
182 fn handle_context_server_event(
183 &mut self,
184 context_server_manager: Entity<ContextServerManager>,
185 event: &context_server::manager::Event,
186 cx: &mut Context<Self>,
187 ) {
188 let tool_working_set = self.tools.clone();
189 match event {
190 context_server::manager::Event::ServerStarted { server_id } => {
191 if let Some(server) = context_server_manager.read(cx).get_server(server_id) {
192 let context_server_manager = context_server_manager.clone();
193 cx.spawn({
194 let server = server.clone();
195 let server_id = server_id.clone();
196 |this, mut cx| async move {
197 let Some(protocol) = server.client() else {
198 return;
199 };
200
201 if protocol.capable(context_server::protocol::ServerCapability::Tools) {
202 if let Some(tools) = protocol.list_tools().await.log_err() {
203 let tool_ids = tools
204 .tools
205 .into_iter()
206 .map(|tool| {
207 log::info!(
208 "registering context server tool: {:?}",
209 tool.name
210 );
211 tool_working_set.insert(Arc::new(
212 ContextServerTool::new(
213 context_server_manager.clone(),
214 server.id(),
215 tool,
216 ),
217 ))
218 })
219 .collect::<Vec<_>>();
220
221 this.update(&mut cx, |this, _cx| {
222 this.context_server_tool_ids.insert(server_id, tool_ids);
223 })
224 .log_err();
225 }
226 }
227 }
228 })
229 .detach();
230 }
231 }
232 context_server::manager::Event::ServerStopped { server_id } => {
233 if let Some(tool_ids) = self.context_server_tool_ids.remove(server_id) {
234 tool_working_set.remove(&tool_ids);
235 }
236 }
237 }
238 }
239}
240
/// Metadata describing a saved thread, as listed by the thread store.
///
/// Field order is part of the serialized representation; avoid reordering fields.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SavedThreadMetadata {
    /// Database key of the thread.
    pub id: ThreadId,
    /// Thread summary as stored in the database.
    pub summary: SharedString,
    /// Last update time, used to order threads most-recent-first.
    pub updated_at: DateTime<Utc>,
}
247
/// Persisted form of a thread, stored bincode-encoded in the threads database under its
/// `ThreadId`.
///
/// NOTE: bincode is positional. Reordering, adding, or removing fields changes the
/// on-disk format, and previously saved threads would then fail to decode.
#[derive(Serialize, Deserialize)]
pub struct SavedThread {
    /// Thread summary at the time of saving.
    pub summary: SharedString,
    /// Last update time of the thread at the time of saving.
    pub updated_at: DateTime<Utc>,
    /// Messages of the thread, in the order they were captured.
    pub messages: Vec<SavedMessage>,
}
254
/// Persisted form of a single thread message, stored as part of a `SavedThread`.
///
/// NOTE: encoded with bincode as part of `SavedThread`. Field order is the on-disk
/// format and must not change.
#[derive(Serialize, Deserialize)]
pub struct SavedMessage {
    /// ID of the message within its thread.
    pub id: MessageId,
    /// Author role of the message.
    pub role: Role,
    /// Text content of the message.
    pub text: String,
}
261
/// LMDB-backed store (via heed) mapping `ThreadId`s to bincode-encoded `SavedThread`s.
struct ThreadsDatabase {
    /// Executor on which every database transaction is run.
    executor: BackgroundExecutor,
    /// Opened LMDB environment; cloned into each spawned transaction task.
    env: heed::Env,
    /// Handle to the named `threads` database inside `env`.
    threads: Database<SerdeBincode<ThreadId>, SerdeBincode<SavedThread>>,
}
267
268impl ThreadsDatabase {
269 pub fn new(path: PathBuf, executor: BackgroundExecutor) -> Result<Self> {
270 std::fs::create_dir_all(&path)?;
271
272 const ONE_GB_IN_BYTES: usize = 1024 * 1024 * 1024;
273 let env = unsafe {
274 heed::EnvOpenOptions::new()
275 .map_size(ONE_GB_IN_BYTES)
276 .max_dbs(1)
277 .open(path)?
278 };
279
280 let mut txn = env.write_txn()?;
281 let threads = env.create_database(&mut txn, Some("threads"))?;
282 txn.commit()?;
283
284 Ok(Self {
285 executor,
286 env,
287 threads,
288 })
289 }
290
291 pub fn list_threads(&self) -> Task<Result<Vec<SavedThreadMetadata>>> {
292 let env = self.env.clone();
293 let threads = self.threads;
294
295 self.executor.spawn(async move {
296 let txn = env.read_txn()?;
297 let mut iter = threads.iter(&txn)?;
298 let mut threads = Vec::new();
299 while let Some((key, value)) = iter.next().transpose()? {
300 threads.push(SavedThreadMetadata {
301 id: key,
302 summary: value.summary,
303 updated_at: value.updated_at,
304 });
305 }
306
307 Ok(threads)
308 })
309 }
310
311 pub fn try_find_thread(&self, id: ThreadId) -> Task<Result<Option<SavedThread>>> {
312 let env = self.env.clone();
313 let threads = self.threads;
314
315 self.executor.spawn(async move {
316 let txn = env.read_txn()?;
317 let thread = threads.get(&txn, &id)?;
318 Ok(thread)
319 })
320 }
321
322 pub fn save_thread(&self, id: ThreadId, thread: SavedThread) -> Task<Result<()>> {
323 let env = self.env.clone();
324 let threads = self.threads;
325
326 self.executor.spawn(async move {
327 let mut txn = env.write_txn()?;
328 threads.put(&mut txn, &id, &thread)?;
329 txn.commit()?;
330 Ok(())
331 })
332 }
333
334 pub fn delete_thread(&self, id: ThreadId) -> Task<Result<()>> {
335 let env = self.env.clone();
336 let threads = self.threads;
337
338 self.executor.spawn(async move {
339 let mut txn = env.write_txn()?;
340 threads.delete(&mut txn, &id)?;
341 txn.commit()?;
342 Ok(())
343 })
344 }
345}