1use std::path::PathBuf;
2use std::sync::Arc;
3
4use anyhow::{anyhow, Result};
5use assistant_tool::{ToolId, ToolWorkingSet};
6use chrono::{DateTime, Utc};
7use collections::HashMap;
8use context_server::manager::ContextServerManager;
9use context_server::{ContextServerFactoryRegistry, ContextServerTool};
10use futures::future::{self, BoxFuture, Shared};
11use futures::FutureExt as _;
12use gpui::{
13 prelude::*, App, BackgroundExecutor, Context, Entity, Global, ReadGlobal, SharedString, Task,
14};
15use heed::types::{SerdeBincode, SerdeJson};
16use heed::Database;
17use language_model::{LanguageModelToolUseId, Role};
18use project::Project;
19use prompt_store::PromptBuilder;
20use serde::{Deserialize, Serialize};
21use util::ResultExt as _;
22
23use crate::thread::{MessageId, ProjectSnapshot, Thread, ThreadId};
24
25pub fn init(cx: &mut App) {
26 ThreadsDatabase::init(cx);
27}
28
29pub struct ThreadStore {
30 project: Entity<Project>,
31 tools: Arc<ToolWorkingSet>,
32 prompt_builder: Arc<PromptBuilder>,
33 context_server_manager: Entity<ContextServerManager>,
34 context_server_tool_ids: HashMap<Arc<str>, Vec<ToolId>>,
35 threads: Vec<SerializedThreadMetadata>,
36}
37
38impl ThreadStore {
39 pub fn new(
40 project: Entity<Project>,
41 tools: Arc<ToolWorkingSet>,
42 prompt_builder: Arc<PromptBuilder>,
43 cx: &mut App,
44 ) -> Result<Entity<Self>> {
45 let this = cx.new(|cx| {
46 let context_server_factory_registry = ContextServerFactoryRegistry::default_global(cx);
47 let context_server_manager = cx.new(|cx| {
48 ContextServerManager::new(context_server_factory_registry, project.clone(), cx)
49 });
50
51 let this = Self {
52 project,
53 tools,
54 prompt_builder,
55 context_server_manager,
56 context_server_tool_ids: HashMap::default(),
57 threads: Vec::new(),
58 };
59 this.register_context_server_handlers(cx);
60 this.reload(cx).detach_and_log_err(cx);
61
62 this
63 });
64
65 Ok(this)
66 }
67
68 pub fn context_server_manager(&self) -> Entity<ContextServerManager> {
69 self.context_server_manager.clone()
70 }
71
72 /// Returns the number of threads.
73 pub fn thread_count(&self) -> usize {
74 self.threads.len()
75 }
76
77 pub fn threads(&self) -> Vec<SerializedThreadMetadata> {
78 let mut threads = self.threads.iter().cloned().collect::<Vec<_>>();
79 threads.sort_unstable_by_key(|thread| std::cmp::Reverse(thread.updated_at));
80 threads
81 }
82
83 pub fn recent_threads(&self, limit: usize) -> Vec<SerializedThreadMetadata> {
84 self.threads().into_iter().take(limit).collect()
85 }
86
87 pub fn create_thread(&mut self, cx: &mut Context<Self>) -> Entity<Thread> {
88 cx.new(|cx| {
89 Thread::new(
90 self.project.clone(),
91 self.tools.clone(),
92 self.prompt_builder.clone(),
93 cx,
94 )
95 })
96 }
97
98 pub fn open_thread(
99 &self,
100 id: &ThreadId,
101 cx: &mut Context<Self>,
102 ) -> Task<Result<Entity<Thread>>> {
103 let id = id.clone();
104 let database_future = ThreadsDatabase::global_future(cx);
105 cx.spawn(|this, mut cx| async move {
106 let database = database_future.await.map_err(|err| anyhow!(err))?;
107 let thread = database
108 .try_find_thread(id.clone())
109 .await?
110 .ok_or_else(|| anyhow!("no thread found with ID: {id:?}"))?;
111
112 this.update(&mut cx, |this, cx| {
113 cx.new(|cx| {
114 Thread::deserialize(
115 id.clone(),
116 thread,
117 this.project.clone(),
118 this.tools.clone(),
119 this.prompt_builder.clone(),
120 cx,
121 )
122 })
123 })
124 })
125 }
126
127 pub fn save_thread(&self, thread: &Entity<Thread>, cx: &mut Context<Self>) -> Task<Result<()>> {
128 let (metadata, serialized_thread) =
129 thread.update(cx, |thread, cx| (thread.id().clone(), thread.serialize(cx)));
130
131 let database_future = ThreadsDatabase::global_future(cx);
132 cx.spawn(|this, mut cx| async move {
133 let serialized_thread = serialized_thread.await?;
134 let database = database_future.await.map_err(|err| anyhow!(err))?;
135 database.save_thread(metadata, serialized_thread).await?;
136
137 this.update(&mut cx, |this, cx| this.reload(cx))?.await
138 })
139 }
140
141 pub fn delete_thread(&mut self, id: &ThreadId, cx: &mut Context<Self>) -> Task<Result<()>> {
142 let id = id.clone();
143 let database_future = ThreadsDatabase::global_future(cx);
144 cx.spawn(|this, mut cx| async move {
145 let database = database_future.await.map_err(|err| anyhow!(err))?;
146 database.delete_thread(id.clone()).await?;
147
148 this.update(&mut cx, |this, _cx| {
149 this.threads.retain(|thread| thread.id != id)
150 })
151 })
152 }
153
154 pub fn reload(&self, cx: &mut Context<Self>) -> Task<Result<()>> {
155 let database_future = ThreadsDatabase::global_future(cx);
156 cx.spawn(|this, mut cx| async move {
157 let threads = database_future
158 .await
159 .map_err(|err| anyhow!(err))?
160 .list_threads()
161 .await?;
162
163 this.update(&mut cx, |this, cx| {
164 this.threads = threads;
165 cx.notify();
166 })
167 })
168 }
169
170 fn register_context_server_handlers(&self, cx: &mut Context<Self>) {
171 cx.subscribe(
172 &self.context_server_manager.clone(),
173 Self::handle_context_server_event,
174 )
175 .detach();
176 }
177
178 fn handle_context_server_event(
179 &mut self,
180 context_server_manager: Entity<ContextServerManager>,
181 event: &context_server::manager::Event,
182 cx: &mut Context<Self>,
183 ) {
184 let tool_working_set = self.tools.clone();
185 match event {
186 context_server::manager::Event::ServerStarted { server_id } => {
187 if let Some(server) = context_server_manager.read(cx).get_server(server_id) {
188 let context_server_manager = context_server_manager.clone();
189 cx.spawn({
190 let server = server.clone();
191 let server_id = server_id.clone();
192 |this, mut cx| async move {
193 let Some(protocol) = server.client() else {
194 return;
195 };
196
197 if protocol.capable(context_server::protocol::ServerCapability::Tools) {
198 if let Some(tools) = protocol.list_tools().await.log_err() {
199 let tool_ids = tools
200 .tools
201 .into_iter()
202 .map(|tool| {
203 log::info!(
204 "registering context server tool: {:?}",
205 tool.name
206 );
207 tool_working_set.insert(Arc::new(
208 ContextServerTool::new(
209 context_server_manager.clone(),
210 server.id(),
211 tool,
212 ),
213 ))
214 })
215 .collect::<Vec<_>>();
216
217 this.update(&mut cx, |this, _cx| {
218 this.context_server_tool_ids.insert(server_id, tool_ids);
219 })
220 .log_err();
221 }
222 }
223 }
224 })
225 .detach();
226 }
227 }
228 context_server::manager::Event::ServerStopped { server_id } => {
229 if let Some(tool_ids) = self.context_server_tool_ids.remove(server_id) {
230 tool_working_set.remove(&tool_ids);
231 }
232 }
233 }
234 }
235}
236
/// Lightweight summary of a stored thread, as listed by `ThreadsDatabase::list_threads`
/// (built from the database key plus the stored thread's `summary` and `updated_at`).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SerializedThreadMetadata {
    /// Database key of the thread.
    pub id: ThreadId,
    /// Copied from the stored thread's `summary`.
    pub summary: SharedString,
    /// Copied from the stored thread's `updated_at`; `ThreadStore::threads` sorts by this.
    pub updated_at: DateTime<Utc>,
}
243
/// Full persisted form of a thread, stored JSON-encoded as the value in the `"threads"`
/// database (keyed by the thread's `ThreadId`).
#[derive(Serialize, Deserialize)]
pub struct SerializedThread {
    /// Thread summary.
    pub summary: SharedString,
    /// Last-update timestamp.
    pub updated_at: DateTime<Utc>,
    /// The thread's messages.
    pub messages: Vec<SerializedMessage>,
    // `serde(default)`: a stored record without this field deserializes as `None`.
    #[serde(default)]
    pub initial_project_snapshot: Option<Arc<ProjectSnapshot>>,
}
252
/// Persisted form of a single message within a [`SerializedThread`].
#[derive(Debug, Serialize, Deserialize)]
pub struct SerializedMessage {
    /// Message id.
    pub id: MessageId,
    /// Role of the message author.
    pub role: Role,
    /// Message text.
    pub text: String,
    // `serde(default)`: a stored message without this field deserializes with an empty list.
    #[serde(default)]
    pub tool_uses: Vec<SerializedToolUse>,
    // `serde(default)`: a stored message without this field deserializes with an empty list.
    #[serde(default)]
    pub tool_results: Vec<SerializedToolResult>,
}
263
/// Persisted record of a tool invocation within a message.
#[derive(Debug, Serialize, Deserialize)]
pub struct SerializedToolUse {
    /// Id of this tool use.
    pub id: LanguageModelToolUseId,
    /// Name of the invoked tool.
    pub name: SharedString,
    /// Tool input, kept as an arbitrary JSON value.
    pub input: serde_json::Value,
}
270
/// Persisted result of a tool invocation within a message.
#[derive(Debug, Serialize, Deserialize)]
pub struct SerializedToolResult {
    /// Id of the tool use this result answers (presumably matches a `SerializedToolUse::id`).
    pub tool_use_id: LanguageModelToolUseId,
    /// Whether the tool reported an error.
    pub is_error: bool,
    /// Result content.
    pub content: Arc<str>,
}
277
/// gpui global holding the shared future that opens the [`ThreadsDatabase`].
///
/// The future is `Shared` so every caller awaits the same single open attempt. Both arms of
/// its output are `Arc`-wrapped because `Shared` requires a `Clone` output and
/// `anyhow::Error` is not `Clone`.
struct GlobalThreadsDatabase(
    Shared<BoxFuture<'static, Result<Arc<ThreadsDatabase>, Arc<anyhow::Error>>>>,
);

impl Global for GlobalThreadsDatabase {}
283
284pub(crate) struct ThreadsDatabase {
285 executor: BackgroundExecutor,
286 env: heed::Env,
287 threads: Database<SerdeBincode<ThreadId>, SerdeJson<SerializedThread>>,
288}
289
290impl ThreadsDatabase {
291 fn global_future(
292 cx: &mut App,
293 ) -> Shared<BoxFuture<'static, Result<Arc<ThreadsDatabase>, Arc<anyhow::Error>>>> {
294 GlobalThreadsDatabase::global(cx).0.clone()
295 }
296
297 fn init(cx: &mut App) {
298 let executor = cx.background_executor().clone();
299 let database_future = executor
300 .spawn({
301 let executor = executor.clone();
302 let database_path = paths::support_dir().join("threads/threads-db.1.mdb");
303 async move { ThreadsDatabase::new(database_path, executor) }
304 })
305 .then(|result| future::ready(result.map(Arc::new).map_err(Arc::new)))
306 .boxed()
307 .shared();
308
309 cx.set_global(GlobalThreadsDatabase(database_future));
310 }
311
312 pub fn new(path: PathBuf, executor: BackgroundExecutor) -> Result<Self> {
313 std::fs::create_dir_all(&path)?;
314
315 const ONE_GB_IN_BYTES: usize = 1024 * 1024 * 1024;
316 let env = unsafe {
317 heed::EnvOpenOptions::new()
318 .map_size(ONE_GB_IN_BYTES)
319 .max_dbs(1)
320 .open(path)?
321 };
322
323 let mut txn = env.write_txn()?;
324 let threads = env.create_database(&mut txn, Some("threads"))?;
325 txn.commit()?;
326
327 Ok(Self {
328 executor,
329 env,
330 threads,
331 })
332 }
333
334 pub fn list_threads(&self) -> Task<Result<Vec<SerializedThreadMetadata>>> {
335 let env = self.env.clone();
336 let threads = self.threads;
337
338 self.executor.spawn(async move {
339 let txn = env.read_txn()?;
340 let mut iter = threads.iter(&txn)?;
341 let mut threads = Vec::new();
342 while let Some((key, value)) = iter.next().transpose()? {
343 threads.push(SerializedThreadMetadata {
344 id: key,
345 summary: value.summary,
346 updated_at: value.updated_at,
347 });
348 }
349
350 Ok(threads)
351 })
352 }
353
354 pub fn try_find_thread(&self, id: ThreadId) -> Task<Result<Option<SerializedThread>>> {
355 let env = self.env.clone();
356 let threads = self.threads;
357
358 self.executor.spawn(async move {
359 let txn = env.read_txn()?;
360 let thread = threads.get(&txn, &id)?;
361 Ok(thread)
362 })
363 }
364
365 pub fn save_thread(&self, id: ThreadId, thread: SerializedThread) -> Task<Result<()>> {
366 let env = self.env.clone();
367 let threads = self.threads;
368
369 self.executor.spawn(async move {
370 let mut txn = env.write_txn()?;
371 threads.put(&mut txn, &id, &thread)?;
372 txn.commit()?;
373 Ok(())
374 })
375 }
376
377 pub fn delete_thread(&self, id: ThreadId) -> Task<Result<()>> {
378 let env = self.env.clone();
379 let threads = self.threads;
380
381 self.executor.spawn(async move {
382 let mut txn = env.write_txn()?;
383 threads.delete(&mut txn, &id)?;
384 txn.commit()?;
385 Ok(())
386 })
387 }
388}