1use std::borrow::Cow;
2use std::path::PathBuf;
3use std::sync::Arc;
4
5use anyhow::{anyhow, Result};
6use assistant_tool::{ToolId, ToolWorkingSet};
7use chrono::{DateTime, Utc};
8use collections::HashMap;
9use context_server::manager::ContextServerManager;
10use context_server::{ContextServerFactoryRegistry, ContextServerTool};
11use futures::future::{self, BoxFuture, Shared};
12use futures::FutureExt as _;
13use gpui::{
14 prelude::*, App, BackgroundExecutor, Context, Entity, Global, ReadGlobal, SharedString, Task,
15};
16use heed::types::SerdeBincode;
17use heed::Database;
18use language_model::{LanguageModelToolUseId, Role};
19use project::Project;
20use prompt_store::PromptBuilder;
21use serde::{Deserialize, Serialize};
22use util::ResultExt as _;
23
24use crate::thread::{MessageId, ProjectSnapshot, Thread, ThreadEvent, ThreadId};
25
26pub fn init(cx: &mut App) {
27 ThreadsDatabase::init(cx);
28}
29
/// Tracks the saved threads for a project and keeps context-server tools registered in the
/// shared tool set.
pub struct ThreadStore {
    project: Entity<Project>,
    /// Shared tool set; context-server tools are inserted into and removed from it as
    /// servers start and stop.
    tools: Arc<ToolWorkingSet>,
    prompt_builder: Arc<PromptBuilder>,
    context_server_manager: Entity<ContextServerManager>,
    /// Ids of the tools registered for each context server (keyed by server id), so they can
    /// be removed from `tools` when that server stops.
    context_server_tool_ids: HashMap<Arc<str>, Vec<ToolId>>,
    /// Cached metadata for saved threads, replaced wholesale by `reload`. Kept unsorted;
    /// `threads()` returns a sorted copy.
    threads: Vec<SerializedThreadMetadata>,
}
38
39impl ThreadStore {
40 pub fn new(
41 project: Entity<Project>,
42 tools: Arc<ToolWorkingSet>,
43 prompt_builder: Arc<PromptBuilder>,
44 cx: &mut App,
45 ) -> Result<Entity<Self>> {
46 let this = cx.new(|cx| {
47 let context_server_factory_registry = ContextServerFactoryRegistry::default_global(cx);
48 let context_server_manager = cx.new(|cx| {
49 ContextServerManager::new(context_server_factory_registry, project.clone(), cx)
50 });
51
52 let this = Self {
53 project,
54 tools,
55 prompt_builder,
56 context_server_manager,
57 context_server_tool_ids: HashMap::default(),
58 threads: Vec::new(),
59 };
60 this.register_context_server_handlers(cx);
61 this.reload(cx).detach_and_log_err(cx);
62
63 this
64 });
65
66 Ok(this)
67 }
68
69 pub fn context_server_manager(&self) -> Entity<ContextServerManager> {
70 self.context_server_manager.clone()
71 }
72
73 pub fn tools(&self) -> Arc<ToolWorkingSet> {
74 self.tools.clone()
75 }
76
77 /// Returns the number of threads.
78 pub fn thread_count(&self) -> usize {
79 self.threads.len()
80 }
81
82 pub fn threads(&self) -> Vec<SerializedThreadMetadata> {
83 let mut threads = self.threads.iter().cloned().collect::<Vec<_>>();
84 threads.sort_unstable_by_key(|thread| std::cmp::Reverse(thread.updated_at));
85 threads
86 }
87
88 pub fn recent_threads(&self, limit: usize) -> Vec<SerializedThreadMetadata> {
89 self.threads().into_iter().take(limit).collect()
90 }
91
92 pub fn create_thread(&mut self, cx: &mut Context<Self>) -> Entity<Thread> {
93 cx.new(|cx| {
94 Thread::new(
95 self.project.clone(),
96 self.tools.clone(),
97 self.prompt_builder.clone(),
98 cx,
99 )
100 })
101 }
102
103 pub fn open_thread(
104 &self,
105 id: &ThreadId,
106 cx: &mut Context<Self>,
107 ) -> Task<Result<Entity<Thread>>> {
108 let id = id.clone();
109 let database_future = ThreadsDatabase::global_future(cx);
110 cx.spawn(async move |this, cx| {
111 let database = database_future.await.map_err(|err| anyhow!(err))?;
112 let thread = database
113 .try_find_thread(id.clone())
114 .await?
115 .ok_or_else(|| anyhow!("no thread found with ID: {id:?}"))?;
116
117 let thread = this.update(cx, |this, cx| {
118 cx.new(|cx| {
119 Thread::deserialize(
120 id.clone(),
121 thread,
122 this.project.clone(),
123 this.tools.clone(),
124 this.prompt_builder.clone(),
125 cx,
126 )
127 })
128 })?;
129
130 let (system_prompt_context, load_error) = thread
131 .update(cx, |thread, cx| thread.load_system_prompt_context(cx))?
132 .await;
133 thread.update(cx, |thread, cx| {
134 thread.set_system_prompt_context(system_prompt_context);
135 if let Some(load_error) = load_error {
136 cx.emit(ThreadEvent::ShowError(load_error));
137 }
138 })?;
139
140 Ok(thread)
141 })
142 }
143
144 pub fn save_thread(&self, thread: &Entity<Thread>, cx: &mut Context<Self>) -> Task<Result<()>> {
145 let (metadata, serialized_thread) =
146 thread.update(cx, |thread, cx| (thread.id().clone(), thread.serialize(cx)));
147
148 let database_future = ThreadsDatabase::global_future(cx);
149 cx.spawn(async move |this, cx| {
150 let serialized_thread = serialized_thread.await?;
151 let database = database_future.await.map_err(|err| anyhow!(err))?;
152 database.save_thread(metadata, serialized_thread).await?;
153
154 this.update(cx, |this, cx| this.reload(cx))?.await
155 })
156 }
157
158 pub fn delete_thread(&mut self, id: &ThreadId, cx: &mut Context<Self>) -> Task<Result<()>> {
159 let id = id.clone();
160 let database_future = ThreadsDatabase::global_future(cx);
161 cx.spawn(async move |this, cx| {
162 let database = database_future.await.map_err(|err| anyhow!(err))?;
163 database.delete_thread(id.clone()).await?;
164
165 this.update(cx, |this, _cx| {
166 this.threads.retain(|thread| thread.id != id)
167 })
168 })
169 }
170
171 pub fn reload(&self, cx: &mut Context<Self>) -> Task<Result<()>> {
172 let database_future = ThreadsDatabase::global_future(cx);
173 cx.spawn(async move |this, cx| {
174 let threads = database_future
175 .await
176 .map_err(|err| anyhow!(err))?
177 .list_threads()
178 .await?;
179
180 this.update(cx, |this, cx| {
181 this.threads = threads;
182 cx.notify();
183 })
184 })
185 }
186
187 fn register_context_server_handlers(&self, cx: &mut Context<Self>) {
188 cx.subscribe(
189 &self.context_server_manager.clone(),
190 Self::handle_context_server_event,
191 )
192 .detach();
193 }
194
195 fn handle_context_server_event(
196 &mut self,
197 context_server_manager: Entity<ContextServerManager>,
198 event: &context_server::manager::Event,
199 cx: &mut Context<Self>,
200 ) {
201 let tool_working_set = self.tools.clone();
202 match event {
203 context_server::manager::Event::ServerStarted { server_id } => {
204 if let Some(server) = context_server_manager.read(cx).get_server(server_id) {
205 let context_server_manager = context_server_manager.clone();
206 cx.spawn({
207 let server = server.clone();
208 let server_id = server_id.clone();
209 async move |this, cx| {
210 let Some(protocol) = server.client() else {
211 return;
212 };
213
214 if protocol.capable(context_server::protocol::ServerCapability::Tools) {
215 if let Some(tools) = protocol.list_tools().await.log_err() {
216 let tool_ids = tools
217 .tools
218 .into_iter()
219 .map(|tool| {
220 log::info!(
221 "registering context server tool: {:?}",
222 tool.name
223 );
224 tool_working_set.insert(Arc::new(
225 ContextServerTool::new(
226 context_server_manager.clone(),
227 server.id(),
228 tool,
229 ),
230 ))
231 })
232 .collect::<Vec<_>>();
233
234 this.update(cx, |this, _cx| {
235 this.context_server_tool_ids.insert(server_id, tool_ids);
236 })
237 .log_err();
238 }
239 }
240 }
241 })
242 .detach();
243 }
244 }
245 context_server::manager::Event::ServerStopped { server_id } => {
246 if let Some(tool_ids) = self.context_server_tool_ids.remove(server_id) {
247 tool_working_set.remove(&tool_ids);
248 }
249 }
250 }
251 }
252}
253
/// Summary information about a saved thread, as returned by [`ThreadStore::threads`].
///
/// `ThreadsDatabase::list_threads` builds one of these from each stored [`SerializedThread`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SerializedThreadMetadata {
    /// Key of the thread in the threads database.
    pub id: ThreadId,
    /// The thread's summary, copied from [`SerializedThread::summary`].
    pub summary: SharedString,
    /// Last-update time; [`ThreadStore::threads`] sorts newest-first by this field.
    pub updated_at: DateTime<Utc>,
}
260
/// The persisted form of a thread, stored as JSON in the threads database.
///
/// Decode stored bytes with [`SerializedThread::from_json`] rather than plain serde. That
/// function checks `version` and upgrades legacy (unversioned) data.
#[derive(Serialize, Deserialize)]
pub struct SerializedThread {
    /// Format version string. Upgraded legacy threads are stamped with
    /// [`SerializedThread::VERSION`].
    pub version: String,
    pub summary: SharedString,
    pub updated_at: DateTime<Utc>,
    pub messages: Vec<SerializedMessage>,
    /// Optional project snapshot; deserializes to `None` when the field is missing.
    #[serde(default)]
    pub initial_project_snapshot: Option<Arc<ProjectSnapshot>>,
}
270
271impl SerializedThread {
272 pub const VERSION: &'static str = "0.1.0";
273
274 pub fn from_json(json: &[u8]) -> Result<Self> {
275 let saved_thread_json = serde_json::from_slice::<serde_json::Value>(json)?;
276 match saved_thread_json.get("version") {
277 Some(serde_json::Value::String(version)) => match version.as_str() {
278 SerializedThread::VERSION => Ok(serde_json::from_value::<SerializedThread>(
279 saved_thread_json,
280 )?),
281 _ => Err(anyhow!(
282 "unrecognized serialized thread version: {}",
283 version
284 )),
285 },
286 None => {
287 let saved_thread =
288 serde_json::from_value::<LegacySerializedThread>(saved_thread_json)?;
289 Ok(saved_thread.upgrade())
290 }
291 version => Err(anyhow!(
292 "unrecognized serialized thread version: {:?}",
293 version
294 )),
295 }
296 }
297}
298
/// A single message within a [`SerializedThread`].
#[derive(Debug, Serialize, Deserialize)]
pub struct SerializedMessage {
    pub id: MessageId,
    pub role: Role,
    /// Message content as a list of segments; empty when the field is missing.
    #[serde(default)]
    pub segments: Vec<SerializedMessageSegment>,
    /// Tool invocations recorded on this message; empty when the field is missing.
    #[serde(default)]
    pub tool_uses: Vec<SerializedToolUse>,
    /// Tool results recorded on this message; empty when the field is missing.
    #[serde(default)]
    pub tool_results: Vec<SerializedToolResult>,
}
310
/// One piece of a message's content.
///
/// Serialized as an internally tagged object whose `"type"` field is `"text"` or `"thinking"`.
#[derive(Debug, Serialize, Deserialize)]
#[serde(tag = "type")]
pub enum SerializedMessageSegment {
    /// Regular message text.
    #[serde(rename = "text")]
    Text { text: String },
    /// Text of a "thinking" segment. Presumably this is model reasoning output; confirm
    /// against `Thread`.
    #[serde(rename = "thinking")]
    Thinking { text: String },
}
319
/// A persisted tool invocation.
#[derive(Debug, Serialize, Deserialize)]
pub struct SerializedToolUse {
    /// Identifier of this tool use. Presumably matched by
    /// [`SerializedToolResult::tool_use_id`]; verify.
    pub id: LanguageModelToolUseId,
    /// Name of the invoked tool.
    pub name: SharedString,
    /// Input passed to the tool, as arbitrary JSON.
    pub input: serde_json::Value,
}
326
/// A persisted tool result.
#[derive(Debug, Serialize, Deserialize)]
pub struct SerializedToolResult {
    /// Id of the tool use this result belongs to. Presumably equal to a
    /// [`SerializedToolUse::id`]; verify.
    pub tool_use_id: LanguageModelToolUseId,
    /// Whether the result represents an error.
    pub is_error: bool,
    /// Result content as text.
    pub content: Arc<str>,
}
333
/// The thread format used before the `version` field was introduced.
///
/// [`SerializedThread::from_json`] decodes this shape when the JSON has no `version` key. It
/// then converts the result with [`LegacySerializedThread::upgrade`].
#[derive(Serialize, Deserialize)]
struct LegacySerializedThread {
    pub summary: SharedString,
    pub updated_at: DateTime<Utc>,
    pub messages: Vec<LegacySerializedMessage>,
    /// Deserializes to `None` when the field is missing.
    #[serde(default)]
    pub initial_project_snapshot: Option<Arc<ProjectSnapshot>>,
}
342
343impl LegacySerializedThread {
344 pub fn upgrade(self) -> SerializedThread {
345 SerializedThread {
346 version: SerializedThread::VERSION.to_string(),
347 summary: self.summary,
348 updated_at: self.updated_at,
349 messages: self.messages.into_iter().map(|msg| msg.upgrade()).collect(),
350 initial_project_snapshot: self.initial_project_snapshot,
351 }
352 }
353}
354
/// The message format used by [`LegacySerializedThread`]: the content is a single `text`
/// string instead of a list of segments.
#[derive(Debug, Serialize, Deserialize)]
struct LegacySerializedMessage {
    pub id: MessageId,
    pub role: Role,
    pub text: String,
    /// Empty when the field is missing.
    #[serde(default)]
    pub tool_uses: Vec<SerializedToolUse>,
    /// Empty when the field is missing.
    #[serde(default)]
    pub tool_results: Vec<SerializedToolResult>,
}
365
366impl LegacySerializedMessage {
367 fn upgrade(self) -> SerializedMessage {
368 SerializedMessage {
369 id: self.id,
370 role: self.role,
371 segments: vec![SerializedMessageSegment::Text { text: self.text }],
372 tool_uses: self.tool_uses,
373 tool_results: self.tool_results,
374 }
375 }
376}
377
/// gpui global holding the shared future that opens the [`ThreadsDatabase`].
///
/// Because the future is `Shared`, every caller of `ThreadsDatabase::global_future` awaits the
/// same open attempt. The error is wrapped in `Arc` so the shared output is cloneable.
struct GlobalThreadsDatabase(
    Shared<BoxFuture<'static, Result<Arc<ThreadsDatabase>, Arc<anyhow::Error>>>>,
);

impl Global for GlobalThreadsDatabase {}
383
/// heed (LMDB) storage for serialized threads, keyed by [`ThreadId`].
///
/// Every operation runs on `executor` and returns a [`Task`].
pub(crate) struct ThreadsDatabase {
    executor: BackgroundExecutor,
    env: heed::Env,
    /// Keys are bincode-encoded [`ThreadId`]s. Values are [`SerializedThread`]s encoded by the
    /// `heed::BytesEncode` / `heed::BytesDecode` impls for that type.
    threads: Database<SerdeBincode<ThreadId>, SerializedThread>,
}
389
390impl heed::BytesEncode<'_> for SerializedThread {
391 type EItem = SerializedThread;
392
393 fn bytes_encode(item: &Self::EItem) -> Result<Cow<[u8]>, heed::BoxedError> {
394 serde_json::to_vec(item).map(Cow::Owned).map_err(Into::into)
395 }
396}
397
398impl<'a> heed::BytesDecode<'a> for SerializedThread {
399 type DItem = SerializedThread;
400
401 fn bytes_decode(bytes: &'a [u8]) -> Result<Self::DItem, heed::BoxedError> {
402 // We implement this type manually because we want to call `SerializedThread::from_json`,
403 // instead of the Deserialize trait implementation for `SerializedThread`.
404 SerializedThread::from_json(bytes).map_err(Into::into)
405 }
406}
407
408impl ThreadsDatabase {
409 fn global_future(
410 cx: &mut App,
411 ) -> Shared<BoxFuture<'static, Result<Arc<ThreadsDatabase>, Arc<anyhow::Error>>>> {
412 GlobalThreadsDatabase::global(cx).0.clone()
413 }
414
415 fn init(cx: &mut App) {
416 let executor = cx.background_executor().clone();
417 let database_future = executor
418 .spawn({
419 let executor = executor.clone();
420 let database_path = paths::support_dir().join("threads/threads-db.1.mdb");
421 async move { ThreadsDatabase::new(database_path, executor) }
422 })
423 .then(|result| future::ready(result.map(Arc::new).map_err(Arc::new)))
424 .boxed()
425 .shared();
426
427 cx.set_global(GlobalThreadsDatabase(database_future));
428 }
429
430 pub fn new(path: PathBuf, executor: BackgroundExecutor) -> Result<Self> {
431 std::fs::create_dir_all(&path)?;
432
433 const ONE_GB_IN_BYTES: usize = 1024 * 1024 * 1024;
434 let env = unsafe {
435 heed::EnvOpenOptions::new()
436 .map_size(ONE_GB_IN_BYTES)
437 .max_dbs(1)
438 .open(path)?
439 };
440
441 let mut txn = env.write_txn()?;
442 let threads = env.create_database(&mut txn, Some("threads"))?;
443 txn.commit()?;
444
445 Ok(Self {
446 executor,
447 env,
448 threads,
449 })
450 }
451
452 pub fn list_threads(&self) -> Task<Result<Vec<SerializedThreadMetadata>>> {
453 let env = self.env.clone();
454 let threads = self.threads;
455
456 self.executor.spawn(async move {
457 let txn = env.read_txn()?;
458 let mut iter = threads.iter(&txn)?;
459 let mut threads = Vec::new();
460 while let Some((key, value)) = iter.next().transpose()? {
461 threads.push(SerializedThreadMetadata {
462 id: key,
463 summary: value.summary,
464 updated_at: value.updated_at,
465 });
466 }
467
468 Ok(threads)
469 })
470 }
471
472 pub fn try_find_thread(&self, id: ThreadId) -> Task<Result<Option<SerializedThread>>> {
473 let env = self.env.clone();
474 let threads = self.threads;
475
476 self.executor.spawn(async move {
477 let txn = env.read_txn()?;
478 let thread = threads.get(&txn, &id)?;
479 Ok(thread)
480 })
481 }
482
483 pub fn save_thread(&self, id: ThreadId, thread: SerializedThread) -> Task<Result<()>> {
484 let env = self.env.clone();
485 let threads = self.threads;
486
487 self.executor.spawn(async move {
488 let mut txn = env.write_txn()?;
489 threads.put(&mut txn, &id, &thread)?;
490 txn.commit()?;
491 Ok(())
492 })
493 }
494
495 pub fn delete_thread(&self, id: ThreadId) -> Task<Result<()>> {
496 let env = self.env.clone();
497 let threads = self.threads;
498
499 self.executor.spawn(async move {
500 let mut txn = env.write_txn()?;
501 threads.delete(&mut txn, &id)?;
502 txn.commit()?;
503 Ok(())
504 })
505 }
506}