1use std::path::PathBuf;
2use std::sync::Arc;
3
4use anyhow::{anyhow, Result};
5use assistant_tool::{ToolId, ToolWorkingSet};
6use chrono::{DateTime, Utc};
7use collections::HashMap;
8use context_server::manager::ContextServerManager;
9use context_server::{ContextServerFactoryRegistry, ContextServerTool};
10use futures::future::{self, BoxFuture, Shared};
11use futures::FutureExt as _;
12use gpui::{prelude::*, App, BackgroundExecutor, Context, Entity, SharedString, Task};
13use heed::types::SerdeBincode;
14use heed::Database;
15use language_model::Role;
16use project::Project;
17use serde::{Deserialize, Serialize};
18use util::ResultExt as _;
19
20use crate::thread::{MessageId, Thread, ThreadId};
21
/// Tracks the saved assistant threads and the tools contributed by context servers.
///
/// `threads` caches the metadata of every saved thread, loaded from a heed (LMDB)
/// database. `context_server_tool_ids` remembers which tool ids each context server
/// registered in `tools`, so they can be removed again when that server stops.
pub struct ThreadStore {
    // Not read by any `ThreadStore` method; the store only keeps a handle to the project it
    // was created for. NOTE(review): drop the field (and this allow) if nothing needs it.
    #[allow(unused)]
    project: Entity<Project>,
    /// Tool set that new threads are created with and that context-server tools are added to.
    tools: Arc<ToolWorkingSet>,
    /// Manager whose start/stop events drive context-server tool (un)registration.
    context_server_manager: Entity<ContextServerManager>,
    /// Tool ids registered per context server id.
    context_server_tool_ids: HashMap<Arc<str>, Vec<ToolId>>,
    /// Cached metadata for all saved threads (unsorted; see `threads()` for sorted access).
    threads: Vec<SavedThreadMetadata>,
    /// Resolves to the opened threads database, or to the error that occurred opening it.
    database_future: Shared<BoxFuture<'static, Result<Arc<ThreadsDatabase>, Arc<anyhow::Error>>>>,
}
31
32impl ThreadStore {
33 pub fn new(
34 project: Entity<Project>,
35 tools: Arc<ToolWorkingSet>,
36 cx: &mut App,
37 ) -> Result<Entity<Self>> {
38 let this = cx.new(|cx| {
39 let context_server_factory_registry = ContextServerFactoryRegistry::default_global(cx);
40 let context_server_manager = cx.new(|cx| {
41 ContextServerManager::new(context_server_factory_registry, project.clone(), cx)
42 });
43
44 let executor = cx.background_executor().clone();
45 let database_future = executor
46 .spawn({
47 let executor = executor.clone();
48 let database_path = paths::support_dir().join("threads/threads-db.0.mdb");
49 async move { ThreadsDatabase::new(database_path, executor) }
50 })
51 .then(|result| future::ready(result.map(Arc::new).map_err(Arc::new)))
52 .boxed()
53 .shared();
54
55 let this = Self {
56 project,
57 tools,
58 context_server_manager,
59 context_server_tool_ids: HashMap::default(),
60 threads: Vec::new(),
61 database_future,
62 };
63 this.register_context_server_handlers(cx);
64 this.reload(cx).detach_and_log_err(cx);
65
66 this
67 });
68
69 Ok(this)
70 }
71
72 /// Returns the number of threads.
73 pub fn thread_count(&self) -> usize {
74 self.threads.len()
75 }
76
77 pub fn threads(&self) -> Vec<SavedThreadMetadata> {
78 let mut threads = self.threads.iter().cloned().collect::<Vec<_>>();
79 threads.sort_unstable_by_key(|thread| std::cmp::Reverse(thread.updated_at));
80 threads
81 }
82
83 pub fn recent_threads(&self, limit: usize) -> Vec<SavedThreadMetadata> {
84 self.threads().into_iter().take(limit).collect()
85 }
86
87 pub fn create_thread(&mut self, cx: &mut Context<Self>) -> Entity<Thread> {
88 cx.new(|cx| Thread::new(self.tools.clone(), cx))
89 }
90
91 pub fn open_thread(
92 &self,
93 id: &ThreadId,
94 cx: &mut Context<Self>,
95 ) -> Task<Result<Entity<Thread>>> {
96 let id = id.clone();
97 let database_future = self.database_future.clone();
98 cx.spawn(|this, mut cx| async move {
99 let database = database_future.await.map_err(|err| anyhow!(err))?;
100 let thread = database
101 .try_find_thread(id.clone())
102 .await?
103 .ok_or_else(|| anyhow!("no thread found with ID: {id:?}"))?;
104
105 this.update(&mut cx, |this, cx| {
106 cx.new(|cx| Thread::from_saved(id.clone(), thread, this.tools.clone(), cx))
107 })
108 })
109 }
110
111 pub fn save_thread(&self, thread: &Entity<Thread>, cx: &mut Context<Self>) -> Task<Result<()>> {
112 let (metadata, thread) = thread.update(cx, |thread, _cx| {
113 let id = thread.id().clone();
114 let thread = SavedThread {
115 summary: thread.summary_or_default(),
116 updated_at: thread.updated_at(),
117 messages: thread
118 .messages()
119 .map(|message| SavedMessage {
120 id: message.id,
121 role: message.role,
122 text: message.text.clone(),
123 })
124 .collect(),
125 };
126
127 (id, thread)
128 });
129
130 let database_future = self.database_future.clone();
131 cx.spawn(|this, mut cx| async move {
132 let database = database_future.await.map_err(|err| anyhow!(err))?;
133 database.save_thread(metadata, thread).await?;
134
135 this.update(&mut cx, |this, cx| this.reload(cx))?.await
136 })
137 }
138
139 pub fn delete_thread(&mut self, id: &ThreadId, cx: &mut Context<Self>) -> Task<Result<()>> {
140 let id = id.clone();
141 let database_future = self.database_future.clone();
142 cx.spawn(|this, mut cx| async move {
143 let database = database_future.await.map_err(|err| anyhow!(err))?;
144 database.delete_thread(id.clone()).await?;
145
146 this.update(&mut cx, |this, _cx| {
147 this.threads.retain(|thread| thread.id != id)
148 })
149 })
150 }
151
152 fn reload(&self, cx: &mut Context<Self>) -> Task<Result<()>> {
153 let database_future = self.database_future.clone();
154 cx.spawn(|this, mut cx| async move {
155 let threads = database_future
156 .await
157 .map_err(|err| anyhow!(err))?
158 .list_threads()
159 .await?;
160
161 this.update(&mut cx, |this, cx| {
162 this.threads = threads;
163 cx.notify();
164 })
165 })
166 }
167
168 fn register_context_server_handlers(&self, cx: &mut Context<Self>) {
169 cx.subscribe(
170 &self.context_server_manager.clone(),
171 Self::handle_context_server_event,
172 )
173 .detach();
174 }
175
176 fn handle_context_server_event(
177 &mut self,
178 context_server_manager: Entity<ContextServerManager>,
179 event: &context_server::manager::Event,
180 cx: &mut Context<Self>,
181 ) {
182 let tool_working_set = self.tools.clone();
183 match event {
184 context_server::manager::Event::ServerStarted { server_id } => {
185 if let Some(server) = context_server_manager.read(cx).get_server(server_id) {
186 let context_server_manager = context_server_manager.clone();
187 cx.spawn({
188 let server = server.clone();
189 let server_id = server_id.clone();
190 |this, mut cx| async move {
191 let Some(protocol) = server.client() else {
192 return;
193 };
194
195 if protocol.capable(context_server::protocol::ServerCapability::Tools) {
196 if let Some(tools) = protocol.list_tools().await.log_err() {
197 let tool_ids = tools
198 .tools
199 .into_iter()
200 .map(|tool| {
201 log::info!(
202 "registering context server tool: {:?}",
203 tool.name
204 );
205 tool_working_set.insert(Arc::new(
206 ContextServerTool::new(
207 context_server_manager.clone(),
208 server.id(),
209 tool,
210 ),
211 ))
212 })
213 .collect::<Vec<_>>();
214
215 this.update(&mut cx, |this, _cx| {
216 this.context_server_tool_ids.insert(server_id, tool_ids);
217 })
218 .log_err();
219 }
220 }
221 }
222 })
223 .detach();
224 }
225 }
226 context_server::manager::Event::ServerStopped { server_id } => {
227 if let Some(tool_ids) = self.context_server_tool_ids.remove(server_id) {
228 tool_working_set.remove(&tool_ids);
229 }
230 }
231 }
232 }
233}
234
/// Summary information about a saved thread, without its messages.
///
/// Built by `ThreadsDatabase::list_threads` and returned by `ThreadStore::threads`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SavedThreadMetadata {
    /// Key under which the thread is stored in the threads database.
    pub id: ThreadId,
    /// The thread's summary, copied from [`SavedThread::summary`].
    pub summary: SharedString,
    /// When the thread was last updated; used to order threads newest-first.
    pub updated_at: DateTime<Utc>,
}
241
/// A thread as persisted in the threads database.
///
/// Values are encoded with bincode through heed's `SerdeBincode`. Bincode encodes fields
/// positionally, so reordering, removing or adding fields changes the on-disk format.
#[derive(Serialize, Deserialize)]
pub struct SavedThread {
    /// Summary text (the thread's summary, or its default when none was produced).
    pub summary: SharedString,
    /// When the thread was last updated.
    pub updated_at: DateTime<Utc>,
    /// The thread's messages, in the order yielded by `Thread::messages`.
    pub messages: Vec<SavedMessage>,
}
248
/// A single message of a [`SavedThread`].
///
/// Part of the bincode-encoded on-disk format, so field order must not change.
#[derive(Serialize, Deserialize)]
pub struct SavedMessage {
    /// Id of the message within its thread.
    pub id: MessageId,
    /// Who authored the message (e.g. user or assistant).
    pub role: Role,
    /// The message's text content.
    pub text: String,
}
255
/// Thin wrapper around the heed (LMDB) environment that stores saved threads.
///
/// The query and update methods clone the needed handles and run their work on `executor`,
/// returning a [`Task`].
struct ThreadsDatabase {
    /// Executor the database operations are spawned on.
    executor: BackgroundExecutor,
    /// The opened heed environment; cloned into each spawned operation.
    env: heed::Env,
    /// `ThreadId` -> `SavedThread`, both bincode-encoded.
    threads: Database<SerdeBincode<ThreadId>, SerdeBincode<SavedThread>>,
}
261
262impl ThreadsDatabase {
263 pub fn new(path: PathBuf, executor: BackgroundExecutor) -> Result<Self> {
264 std::fs::create_dir_all(&path)?;
265
266 const ONE_GB_IN_BYTES: usize = 1024 * 1024 * 1024;
267 let env = unsafe {
268 heed::EnvOpenOptions::new()
269 .map_size(ONE_GB_IN_BYTES)
270 .max_dbs(1)
271 .open(path)?
272 };
273
274 let mut txn = env.write_txn()?;
275 let threads = env.create_database(&mut txn, Some("threads"))?;
276 txn.commit()?;
277
278 Ok(Self {
279 executor,
280 env,
281 threads,
282 })
283 }
284
285 pub fn list_threads(&self) -> Task<Result<Vec<SavedThreadMetadata>>> {
286 let env = self.env.clone();
287 let threads = self.threads;
288
289 self.executor.spawn(async move {
290 let txn = env.read_txn()?;
291 let mut iter = threads.iter(&txn)?;
292 let mut threads = Vec::new();
293 while let Some((key, value)) = iter.next().transpose()? {
294 threads.push(SavedThreadMetadata {
295 id: key,
296 summary: value.summary,
297 updated_at: value.updated_at,
298 });
299 }
300
301 Ok(threads)
302 })
303 }
304
305 pub fn try_find_thread(&self, id: ThreadId) -> Task<Result<Option<SavedThread>>> {
306 let env = self.env.clone();
307 let threads = self.threads;
308
309 self.executor.spawn(async move {
310 let txn = env.read_txn()?;
311 let thread = threads.get(&txn, &id)?;
312 Ok(thread)
313 })
314 }
315
316 pub fn save_thread(&self, id: ThreadId, thread: SavedThread) -> Task<Result<()>> {
317 let env = self.env.clone();
318 let threads = self.threads;
319
320 self.executor.spawn(async move {
321 let mut txn = env.write_txn()?;
322 threads.put(&mut txn, &id, &thread)?;
323 txn.commit()?;
324 Ok(())
325 })
326 }
327
328 pub fn delete_thread(&self, id: ThreadId) -> Task<Result<()>> {
329 let env = self.env.clone();
330 let threads = self.threads;
331
332 self.executor.spawn(async move {
333 let mut txn = env.write_txn()?;
334 threads.delete(&mut txn, &id)?;
335 txn.commit()?;
336 Ok(())
337 })
338 }
339}