1use std::path::PathBuf;
2use std::sync::Arc;
3
4use anyhow::{anyhow, Result};
5use assistant_tool::{ToolId, ToolWorkingSet};
6use chrono::{DateTime, Utc};
7use collections::HashMap;
8use context_server::manager::ContextServerManager;
9use context_server::{ContextServerFactoryRegistry, ContextServerTool};
10use futures::future::{self, BoxFuture, Shared};
11use futures::FutureExt as _;
12use gpui::{prelude::*, AppContext, BackgroundExecutor, Model, ModelContext, SharedString, Task};
13use heed::types::SerdeBincode;
14use heed::Database;
15use language_model::Role;
16use project::Project;
17use serde::{Deserialize, Serialize};
18use util::ResultExt as _;
19
20use crate::thread::{MessageId, Thread, ThreadId};
21
22pub struct ThreadStore {
23 #[allow(unused)]
24 project: Model<Project>,
25 tools: Arc<ToolWorkingSet>,
26 context_server_manager: Model<ContextServerManager>,
27 context_server_tool_ids: HashMap<Arc<str>, Vec<ToolId>>,
28 threads: Vec<SavedThreadMetadata>,
29 database_future: Shared<BoxFuture<'static, Result<Arc<ThreadsDatabase>, Arc<anyhow::Error>>>>,
30}
31
32impl ThreadStore {
33 pub fn new(
34 project: Model<Project>,
35 tools: Arc<ToolWorkingSet>,
36 cx: &mut AppContext,
37 ) -> Task<Result<Model<Self>>> {
38 cx.spawn(|mut cx| async move {
39 let this = cx.new_model(|cx: &mut ModelContext<Self>| {
40 let context_server_factory_registry =
41 ContextServerFactoryRegistry::default_global(cx);
42 let context_server_manager = cx.new_model(|cx| {
43 ContextServerManager::new(context_server_factory_registry, project.clone(), cx)
44 });
45
46 let executor = cx.background_executor().clone();
47 let database_future = executor
48 .spawn({
49 let executor = executor.clone();
50 let database_path = paths::support_dir().join("threads/threads-db.0.mdb");
51 async move { ThreadsDatabase::new(database_path, executor) }
52 })
53 .then(|result| future::ready(result.map(Arc::new).map_err(Arc::new)))
54 .boxed()
55 .shared();
56
57 let this = Self {
58 project,
59 tools,
60 context_server_manager,
61 context_server_tool_ids: HashMap::default(),
62 threads: Vec::new(),
63 database_future,
64 };
65 this.register_context_server_handlers(cx);
66
67 this
68 })?;
69
70 this.update(&mut cx, |this, cx| this.reload(cx))?.await?;
71
72 Ok(this)
73 })
74 }
75
76 /// Returns the number of threads.
77 pub fn thread_count(&self) -> usize {
78 self.threads.len()
79 }
80
81 pub fn threads(&self) -> Vec<SavedThreadMetadata> {
82 let mut threads = self.threads.iter().cloned().collect::<Vec<_>>();
83 threads.sort_unstable_by_key(|thread| std::cmp::Reverse(thread.updated_at));
84 threads
85 }
86
87 pub fn recent_threads(&self, limit: usize) -> Vec<SavedThreadMetadata> {
88 self.threads().into_iter().take(limit).collect()
89 }
90
91 pub fn create_thread(&mut self, cx: &mut ModelContext<Self>) -> Model<Thread> {
92 cx.new_model(|cx| Thread::new(self.tools.clone(), cx))
93 }
94
95 pub fn open_thread(
96 &self,
97 id: &ThreadId,
98 cx: &mut ModelContext<Self>,
99 ) -> Task<Result<Model<Thread>>> {
100 let id = id.clone();
101 let database_future = self.database_future.clone();
102 cx.spawn(|this, mut cx| async move {
103 let database = database_future.await.map_err(|err| anyhow!(err))?;
104 let thread = database
105 .try_find_thread(id.clone())
106 .await?
107 .ok_or_else(|| anyhow!("no thread found with ID: {id:?}"))?;
108
109 this.update(&mut cx, |this, cx| {
110 cx.new_model(|cx| Thread::from_saved(id.clone(), thread, this.tools.clone(), cx))
111 })
112 })
113 }
114
115 pub fn save_thread(
116 &self,
117 thread: &Model<Thread>,
118 cx: &mut ModelContext<Self>,
119 ) -> Task<Result<()>> {
120 let (metadata, thread) = thread.update(cx, |thread, _cx| {
121 let id = thread.id().clone();
122 let thread = SavedThread {
123 summary: thread.summary_or_default(),
124 updated_at: thread.updated_at(),
125 messages: thread
126 .messages()
127 .map(|message| SavedMessage {
128 id: message.id,
129 role: message.role,
130 text: message.text.clone(),
131 })
132 .collect(),
133 };
134
135 (id, thread)
136 });
137
138 let database_future = self.database_future.clone();
139 cx.spawn(|this, mut cx| async move {
140 let database = database_future.await.map_err(|err| anyhow!(err))?;
141 database.save_thread(metadata, thread).await?;
142
143 this.update(&mut cx, |this, cx| this.reload(cx))?.await
144 })
145 }
146
147 pub fn delete_thread(
148 &mut self,
149 id: &ThreadId,
150 cx: &mut ModelContext<Self>,
151 ) -> Task<Result<()>> {
152 let id = id.clone();
153 let database_future = self.database_future.clone();
154 cx.spawn(|this, mut cx| async move {
155 let database = database_future.await.map_err(|err| anyhow!(err))?;
156 database.delete_thread(id.clone()).await?;
157
158 this.update(&mut cx, |this, _cx| {
159 this.threads.retain(|thread| thread.id != id)
160 })
161 })
162 }
163
164 fn reload(&self, cx: &mut ModelContext<Self>) -> Task<Result<()>> {
165 let database_future = self.database_future.clone();
166 cx.spawn(|this, mut cx| async move {
167 let threads = database_future
168 .await
169 .map_err(|err| anyhow!(err))?
170 .list_threads()
171 .await?;
172
173 this.update(&mut cx, |this, cx| {
174 this.threads = threads;
175 cx.notify();
176 })
177 })
178 }
179
180 fn register_context_server_handlers(&self, cx: &mut ModelContext<Self>) {
181 cx.subscribe(
182 &self.context_server_manager.clone(),
183 Self::handle_context_server_event,
184 )
185 .detach();
186 }
187
188 fn handle_context_server_event(
189 &mut self,
190 context_server_manager: Model<ContextServerManager>,
191 event: &context_server::manager::Event,
192 cx: &mut ModelContext<Self>,
193 ) {
194 let tool_working_set = self.tools.clone();
195 match event {
196 context_server::manager::Event::ServerStarted { server_id } => {
197 if let Some(server) = context_server_manager.read(cx).get_server(server_id) {
198 let context_server_manager = context_server_manager.clone();
199 cx.spawn({
200 let server = server.clone();
201 let server_id = server_id.clone();
202 |this, mut cx| async move {
203 let Some(protocol) = server.client() else {
204 return;
205 };
206
207 if protocol.capable(context_server::protocol::ServerCapability::Tools) {
208 if let Some(tools) = protocol.list_tools().await.log_err() {
209 let tool_ids = tools
210 .tools
211 .into_iter()
212 .map(|tool| {
213 log::info!(
214 "registering context server tool: {:?}",
215 tool.name
216 );
217 tool_working_set.insert(Arc::new(
218 ContextServerTool::new(
219 context_server_manager.clone(),
220 server.id(),
221 tool,
222 ),
223 ))
224 })
225 .collect::<Vec<_>>();
226
227 this.update(&mut cx, |this, _cx| {
228 this.context_server_tool_ids.insert(server_id, tool_ids);
229 })
230 .log_err();
231 }
232 }
233 }
234 })
235 .detach();
236 }
237 }
238 context_server::manager::Event::ServerStopped { server_id } => {
239 if let Some(tool_ids) = self.context_server_tool_ids.remove(server_id) {
240 tool_working_set.remove(&tool_ids);
241 }
242 }
243 }
244 }
245}
246
247#[derive(Debug, Clone, Serialize, Deserialize)]
248pub struct SavedThreadMetadata {
249 pub id: ThreadId,
250 pub summary: SharedString,
251 pub updated_at: DateTime<Utc>,
252}
253
254#[derive(Serialize, Deserialize)]
255pub struct SavedThread {
256 pub summary: SharedString,
257 pub updated_at: DateTime<Utc>,
258 pub messages: Vec<SavedMessage>,
259}
260
261#[derive(Serialize, Deserialize)]
262pub struct SavedMessage {
263 pub id: MessageId,
264 pub role: Role,
265 pub text: String,
266}
267
268struct ThreadsDatabase {
269 executor: BackgroundExecutor,
270 env: heed::Env,
271 threads: Database<SerdeBincode<ThreadId>, SerdeBincode<SavedThread>>,
272}
273
274impl ThreadsDatabase {
275 pub fn new(path: PathBuf, executor: BackgroundExecutor) -> Result<Self> {
276 std::fs::create_dir_all(&path)?;
277
278 const ONE_GB_IN_BYTES: usize = 1024 * 1024 * 1024;
279 let env = unsafe {
280 heed::EnvOpenOptions::new()
281 .map_size(ONE_GB_IN_BYTES)
282 .max_dbs(1)
283 .open(path)?
284 };
285
286 let mut txn = env.write_txn()?;
287 let threads = env.create_database(&mut txn, Some("threads"))?;
288 txn.commit()?;
289
290 Ok(Self {
291 executor,
292 env,
293 threads,
294 })
295 }
296
297 pub fn list_threads(&self) -> Task<Result<Vec<SavedThreadMetadata>>> {
298 let env = self.env.clone();
299 let threads = self.threads;
300
301 self.executor.spawn(async move {
302 let txn = env.read_txn()?;
303 let mut iter = threads.iter(&txn)?;
304 let mut threads = Vec::new();
305 while let Some((key, value)) = iter.next().transpose()? {
306 threads.push(SavedThreadMetadata {
307 id: key,
308 summary: value.summary,
309 updated_at: value.updated_at,
310 });
311 }
312
313 Ok(threads)
314 })
315 }
316
317 pub fn try_find_thread(&self, id: ThreadId) -> Task<Result<Option<SavedThread>>> {
318 let env = self.env.clone();
319 let threads = self.threads;
320
321 self.executor.spawn(async move {
322 let txn = env.read_txn()?;
323 let thread = threads.get(&txn, &id)?;
324 Ok(thread)
325 })
326 }
327
328 pub fn save_thread(&self, id: ThreadId, thread: SavedThread) -> Task<Result<()>> {
329 let env = self.env.clone();
330 let threads = self.threads;
331
332 self.executor.spawn(async move {
333 let mut txn = env.write_txn()?;
334 threads.put(&mut txn, &id, &thread)?;
335 txn.commit()?;
336 Ok(())
337 })
338 }
339
340 pub fn delete_thread(&self, id: ThreadId) -> Task<Result<()>> {
341 let env = self.env.clone();
342 let threads = self.threads;
343
344 self.executor.spawn(async move {
345 let mut txn = env.write_txn()?;
346 threads.delete(&mut txn, &id)?;
347 txn.commit()?;
348 Ok(())
349 })
350 }
351}