1use std::path::PathBuf;
2use std::sync::Arc;
3
4use anyhow::{anyhow, Result};
5use assistant_tool::{ToolId, ToolWorkingSet};
6use chrono::{DateTime, Utc};
7use collections::HashMap;
8use context_server::manager::ContextServerManager;
9use context_server::{ContextServerFactoryRegistry, ContextServerTool};
10use futures::future::{self, BoxFuture, Shared};
11use futures::FutureExt as _;
12use gpui::{prelude::*, App, BackgroundExecutor, Context, Entity, SharedString, Task};
13use heed::types::SerdeBincode;
14use heed::Database;
15use language_model::Role;
16use project::Project;
17use serde::{Deserialize, Serialize};
18use util::ResultExt as _;
19
20use crate::thread::{MessageId, Thread, ThreadId};
21
/// Holds the assistant's saved-thread metadata and keeps context-server tools in the shared
/// tool working set in sync with the context servers started and stopped by its manager.
pub struct ThreadStore {
    // Not read after construction in the code shown here; the allowance silences the
    // dead-code lint while the field is kept around.
    #[allow(unused)]
    project: Entity<Project>,
    /// Tool working set handed to every thread this store creates or opens.
    tools: Arc<ToolWorkingSet>,
    /// Manager whose server start/stop events drive context-server tool (un)registration.
    context_server_manager: Entity<ContextServerManager>,
    /// Tool IDs registered for each context server, so they can be removed when it stops.
    context_server_tool_ids: HashMap<Arc<str>, Vec<ToolId>>,
    /// Cached metadata of saved threads; replaced by `reload`, pruned by `delete_thread`.
    threads: Vec<SavedThreadMetadata>,
    /// Shared handle to the threads database, opened on the background executor. Errors are
    /// wrapped in `Arc` so the shared output can be cloned for every awaiter.
    database_future: Shared<BoxFuture<'static, Result<Arc<ThreadsDatabase>, Arc<anyhow::Error>>>>,
}
31
32impl ThreadStore {
33 pub fn new(
34 project: Entity<Project>,
35 tools: Arc<ToolWorkingSet>,
36 cx: &mut App,
37 ) -> Task<Result<Entity<Self>>> {
38 cx.spawn(|mut cx| async move {
39 let this = cx.new(|cx: &mut Context<Self>| {
40 let context_server_factory_registry =
41 ContextServerFactoryRegistry::default_global(cx);
42 let context_server_manager = cx.new(|cx| {
43 ContextServerManager::new(context_server_factory_registry, project.clone(), cx)
44 });
45
46 let executor = cx.background_executor().clone();
47 let database_future = executor
48 .spawn({
49 let executor = executor.clone();
50 let database_path = paths::support_dir().join("threads/threads-db.0.mdb");
51 async move { ThreadsDatabase::new(database_path, executor) }
52 })
53 .then(|result| future::ready(result.map(Arc::new).map_err(Arc::new)))
54 .boxed()
55 .shared();
56
57 let this = Self {
58 project,
59 tools,
60 context_server_manager,
61 context_server_tool_ids: HashMap::default(),
62 threads: Vec::new(),
63 database_future,
64 };
65 this.register_context_server_handlers(cx);
66
67 this
68 })?;
69
70 this.update(&mut cx, |this, cx| this.reload(cx))?.await?;
71
72 Ok(this)
73 })
74 }
75
76 /// Returns the number of threads.
77 pub fn thread_count(&self) -> usize {
78 self.threads.len()
79 }
80
81 pub fn threads(&self) -> Vec<SavedThreadMetadata> {
82 let mut threads = self.threads.iter().cloned().collect::<Vec<_>>();
83 threads.sort_unstable_by_key(|thread| std::cmp::Reverse(thread.updated_at));
84 threads
85 }
86
87 pub fn recent_threads(&self, limit: usize) -> Vec<SavedThreadMetadata> {
88 self.threads().into_iter().take(limit).collect()
89 }
90
91 pub fn create_thread(&mut self, cx: &mut Context<Self>) -> Entity<Thread> {
92 cx.new(|cx| Thread::new(self.tools.clone(), cx))
93 }
94
95 pub fn open_thread(
96 &self,
97 id: &ThreadId,
98 cx: &mut Context<Self>,
99 ) -> Task<Result<Entity<Thread>>> {
100 let id = id.clone();
101 let database_future = self.database_future.clone();
102 cx.spawn(|this, mut cx| async move {
103 let database = database_future.await.map_err(|err| anyhow!(err))?;
104 let thread = database
105 .try_find_thread(id.clone())
106 .await?
107 .ok_or_else(|| anyhow!("no thread found with ID: {id:?}"))?;
108
109 this.update(&mut cx, |this, cx| {
110 cx.new(|cx| Thread::from_saved(id.clone(), thread, this.tools.clone(), cx))
111 })
112 })
113 }
114
115 pub fn save_thread(&self, thread: &Entity<Thread>, cx: &mut Context<Self>) -> Task<Result<()>> {
116 let (metadata, thread) = thread.update(cx, |thread, _cx| {
117 let id = thread.id().clone();
118 let thread = SavedThread {
119 summary: thread.summary_or_default(),
120 updated_at: thread.updated_at(),
121 messages: thread
122 .messages()
123 .map(|message| SavedMessage {
124 id: message.id,
125 role: message.role,
126 text: message.text.clone(),
127 })
128 .collect(),
129 };
130
131 (id, thread)
132 });
133
134 let database_future = self.database_future.clone();
135 cx.spawn(|this, mut cx| async move {
136 let database = database_future.await.map_err(|err| anyhow!(err))?;
137 database.save_thread(metadata, thread).await?;
138
139 this.update(&mut cx, |this, cx| this.reload(cx))?.await
140 })
141 }
142
143 pub fn delete_thread(&mut self, id: &ThreadId, cx: &mut Context<Self>) -> Task<Result<()>> {
144 let id = id.clone();
145 let database_future = self.database_future.clone();
146 cx.spawn(|this, mut cx| async move {
147 let database = database_future.await.map_err(|err| anyhow!(err))?;
148 database.delete_thread(id.clone()).await?;
149
150 this.update(&mut cx, |this, _cx| {
151 this.threads.retain(|thread| thread.id != id)
152 })
153 })
154 }
155
156 fn reload(&self, cx: &mut Context<Self>) -> Task<Result<()>> {
157 let database_future = self.database_future.clone();
158 cx.spawn(|this, mut cx| async move {
159 let threads = database_future
160 .await
161 .map_err(|err| anyhow!(err))?
162 .list_threads()
163 .await?;
164
165 this.update(&mut cx, |this, cx| {
166 this.threads = threads;
167 cx.notify();
168 })
169 })
170 }
171
172 fn register_context_server_handlers(&self, cx: &mut Context<Self>) {
173 cx.subscribe(
174 &self.context_server_manager.clone(),
175 Self::handle_context_server_event,
176 )
177 .detach();
178 }
179
180 fn handle_context_server_event(
181 &mut self,
182 context_server_manager: Entity<ContextServerManager>,
183 event: &context_server::manager::Event,
184 cx: &mut Context<Self>,
185 ) {
186 let tool_working_set = self.tools.clone();
187 match event {
188 context_server::manager::Event::ServerStarted { server_id } => {
189 if let Some(server) = context_server_manager.read(cx).get_server(server_id) {
190 let context_server_manager = context_server_manager.clone();
191 cx.spawn({
192 let server = server.clone();
193 let server_id = server_id.clone();
194 |this, mut cx| async move {
195 let Some(protocol) = server.client() else {
196 return;
197 };
198
199 if protocol.capable(context_server::protocol::ServerCapability::Tools) {
200 if let Some(tools) = protocol.list_tools().await.log_err() {
201 let tool_ids = tools
202 .tools
203 .into_iter()
204 .map(|tool| {
205 log::info!(
206 "registering context server tool: {:?}",
207 tool.name
208 );
209 tool_working_set.insert(Arc::new(
210 ContextServerTool::new(
211 context_server_manager.clone(),
212 server.id(),
213 tool,
214 ),
215 ))
216 })
217 .collect::<Vec<_>>();
218
219 this.update(&mut cx, |this, _cx| {
220 this.context_server_tool_ids.insert(server_id, tool_ids);
221 })
222 .log_err();
223 }
224 }
225 }
226 })
227 .detach();
228 }
229 }
230 context_server::manager::Event::ServerStopped { server_id } => {
231 if let Some(tool_ids) = self.context_server_tool_ids.remove(server_id) {
232 tool_working_set.remove(&tool_ids);
233 }
234 }
235 }
236 }
237}
238
/// Summary information about a saved thread, built from a stored [`SavedThread`] when the
/// thread list is loaded.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SavedThreadMetadata {
    /// ID under which the thread is stored in the threads database.
    pub id: ThreadId,
    /// The thread's stored summary.
    pub summary: SharedString,
    /// Last update time of the thread; the thread list is ordered by this, newest first.
    pub updated_at: DateTime<Utc>,
}
245
/// Persisted form of a thread, stored as the value in the threads database keyed by its
/// [`ThreadId`].
///
/// Values are encoded with bincode (`SerdeBincode`), which encodes fields positionally:
/// reordering, adding or removing fields changes the on-disk format and makes existing
/// entries undecodable.
#[derive(Serialize, Deserialize)]
pub struct SavedThread {
    /// The thread's summary, or its default summary if none was produced.
    pub summary: SharedString,
    /// When the thread was last updated.
    pub updated_at: DateTime<Utc>,
    /// The thread's messages, in the order the thread yields them.
    pub messages: Vec<SavedMessage>,
}
252
/// Persisted form of a single thread message, stored inside a [`SavedThread`].
///
/// Part of the bincode-encoded database value, so field order is part of the on-disk format.
#[derive(Serialize, Deserialize)]
pub struct SavedMessage {
    /// The message's ID within its thread.
    pub id: MessageId,
    /// Who authored the message.
    pub role: Role,
    /// The message's text content.
    pub text: String,
}
259
/// LMDB-backed (via heed) storage for saved threads.
struct ThreadsDatabase {
    /// Executor on which every read and write transaction is run.
    executor: BackgroundExecutor,
    /// The heed environment holding the database.
    env: heed::Env,
    /// The `threads` database: bincode-encoded `ThreadId` keys to `SavedThread` values.
    threads: Database<SerdeBincode<ThreadId>, SerdeBincode<SavedThread>>,
}
265
266impl ThreadsDatabase {
267 pub fn new(path: PathBuf, executor: BackgroundExecutor) -> Result<Self> {
268 std::fs::create_dir_all(&path)?;
269
270 const ONE_GB_IN_BYTES: usize = 1024 * 1024 * 1024;
271 let env = unsafe {
272 heed::EnvOpenOptions::new()
273 .map_size(ONE_GB_IN_BYTES)
274 .max_dbs(1)
275 .open(path)?
276 };
277
278 let mut txn = env.write_txn()?;
279 let threads = env.create_database(&mut txn, Some("threads"))?;
280 txn.commit()?;
281
282 Ok(Self {
283 executor,
284 env,
285 threads,
286 })
287 }
288
289 pub fn list_threads(&self) -> Task<Result<Vec<SavedThreadMetadata>>> {
290 let env = self.env.clone();
291 let threads = self.threads;
292
293 self.executor.spawn(async move {
294 let txn = env.read_txn()?;
295 let mut iter = threads.iter(&txn)?;
296 let mut threads = Vec::new();
297 while let Some((key, value)) = iter.next().transpose()? {
298 threads.push(SavedThreadMetadata {
299 id: key,
300 summary: value.summary,
301 updated_at: value.updated_at,
302 });
303 }
304
305 Ok(threads)
306 })
307 }
308
309 pub fn try_find_thread(&self, id: ThreadId) -> Task<Result<Option<SavedThread>>> {
310 let env = self.env.clone();
311 let threads = self.threads;
312
313 self.executor.spawn(async move {
314 let txn = env.read_txn()?;
315 let thread = threads.get(&txn, &id)?;
316 Ok(thread)
317 })
318 }
319
320 pub fn save_thread(&self, id: ThreadId, thread: SavedThread) -> Task<Result<()>> {
321 let env = self.env.clone();
322 let threads = self.threads;
323
324 self.executor.spawn(async move {
325 let mut txn = env.write_txn()?;
326 threads.put(&mut txn, &id, &thread)?;
327 txn.commit()?;
328 Ok(())
329 })
330 }
331
332 pub fn delete_thread(&self, id: ThreadId) -> Task<Result<()>> {
333 let env = self.env.clone();
334 let threads = self.threads;
335
336 self.executor.spawn(async move {
337 let mut txn = env.write_txn()?;
338 threads.delete(&mut txn, &id)?;
339 txn.commit()?;
340 Ok(())
341 })
342 }
343}