1use std::path::PathBuf;
2use std::sync::Arc;
3
4use anyhow::{anyhow, Result};
5use assistant_tool::{ToolId, ToolWorkingSet};
6use chrono::{DateTime, Utc};
7use collections::HashMap;
8use context_server::manager::ContextServerManager;
9use context_server::{ContextServerFactoryRegistry, ContextServerTool};
10use futures::future::{self, BoxFuture, Shared};
11use futures::FutureExt as _;
12use gpui::{
13 prelude::*, App, BackgroundExecutor, Context, Entity, Global, ReadGlobal, SharedString, Task,
14};
15use heed::types::{SerdeBincode, SerdeJson};
16use heed::Database;
17use language_model::{LanguageModelToolUseId, Role};
18use project::Project;
19use serde::{Deserialize, Serialize};
20use util::ResultExt as _;
21
22use crate::thread::{MessageId, Thread, ThreadId};
23
24pub fn init(cx: &mut App) {
25 ThreadsDatabase::init(cx);
26}
27
/// Tracks the saved assistant threads for a project, and the tools contributed to the
/// shared tool set by the project's context servers.
pub struct ThreadStore {
    /// Project that threads created or opened by this store are bound to.
    project: Entity<Project>,
    /// Tool set handed to threads; context-server tools are inserted into and removed from it.
    tools: Arc<ToolWorkingSet>,
    /// Manager whose start/stop events drive context-server tool registration.
    context_server_manager: Entity<ContextServerManager>,
    /// Tool IDs registered per context server, so they can be removed when that server stops.
    context_server_tool_ids: HashMap<Arc<str>, Vec<ToolId>>,
    /// Saved-thread metadata in database order, as loaded by `reload` and pruned by
    /// `delete_thread`; `threads()` returns a copy sorted by recency.
    threads: Vec<SavedThreadMetadata>,
}
35
36impl ThreadStore {
37 pub fn new(
38 project: Entity<Project>,
39 tools: Arc<ToolWorkingSet>,
40 cx: &mut App,
41 ) -> Result<Entity<Self>> {
42 let this = cx.new(|cx| {
43 let context_server_factory_registry = ContextServerFactoryRegistry::default_global(cx);
44 let context_server_manager = cx.new(|cx| {
45 ContextServerManager::new(context_server_factory_registry, project.clone(), cx)
46 });
47
48 let this = Self {
49 project,
50 tools,
51 context_server_manager,
52 context_server_tool_ids: HashMap::default(),
53 threads: Vec::new(),
54 };
55 this.register_context_server_handlers(cx);
56 this.reload(cx).detach_and_log_err(cx);
57
58 this
59 });
60
61 Ok(this)
62 }
63
64 /// Returns the number of threads.
65 pub fn thread_count(&self) -> usize {
66 self.threads.len()
67 }
68
69 pub fn threads(&self) -> Vec<SavedThreadMetadata> {
70 let mut threads = self.threads.iter().cloned().collect::<Vec<_>>();
71 threads.sort_unstable_by_key(|thread| std::cmp::Reverse(thread.updated_at));
72 threads
73 }
74
75 pub fn recent_threads(&self, limit: usize) -> Vec<SavedThreadMetadata> {
76 self.threads().into_iter().take(limit).collect()
77 }
78
79 pub fn create_thread(&mut self, cx: &mut Context<Self>) -> Entity<Thread> {
80 cx.new(|cx| Thread::new(self.project.clone(), self.tools.clone(), cx))
81 }
82
83 pub fn open_thread(
84 &self,
85 id: &ThreadId,
86 cx: &mut Context<Self>,
87 ) -> Task<Result<Entity<Thread>>> {
88 let id = id.clone();
89 let database_future = ThreadsDatabase::global_future(cx);
90 cx.spawn(|this, mut cx| async move {
91 let database = database_future.await.map_err(|err| anyhow!(err))?;
92 let thread = database
93 .try_find_thread(id.clone())
94 .await?
95 .ok_or_else(|| anyhow!("no thread found with ID: {id:?}"))?;
96
97 this.update(&mut cx, |this, cx| {
98 cx.new(|cx| {
99 Thread::from_saved(
100 id.clone(),
101 thread,
102 this.project.clone(),
103 this.tools.clone(),
104 cx,
105 )
106 })
107 })
108 })
109 }
110
111 pub fn save_thread(&self, thread: &Entity<Thread>, cx: &mut Context<Self>) -> Task<Result<()>> {
112 let (metadata, thread) = thread.update(cx, |thread, _cx| {
113 let id = thread.id().clone();
114 let thread = SavedThread {
115 summary: thread.summary_or_default(),
116 updated_at: thread.updated_at(),
117 messages: thread
118 .messages()
119 .map(|message| SavedMessage {
120 id: message.id,
121 role: message.role,
122 text: message.text.clone(),
123 tool_uses: thread
124 .tool_uses_for_message(message.id)
125 .into_iter()
126 .map(|tool_use| SavedToolUse {
127 id: tool_use.id,
128 name: tool_use.name,
129 input: tool_use.input,
130 })
131 .collect(),
132 tool_results: thread
133 .tool_results_for_message(message.id)
134 .into_iter()
135 .map(|tool_result| SavedToolResult {
136 tool_use_id: tool_result.tool_use_id.clone(),
137 is_error: tool_result.is_error,
138 content: tool_result.content.clone(),
139 })
140 .collect(),
141 })
142 .collect(),
143 };
144
145 (id, thread)
146 });
147
148 let database_future = ThreadsDatabase::global_future(cx);
149 cx.spawn(|this, mut cx| async move {
150 let database = database_future.await.map_err(|err| anyhow!(err))?;
151 database.save_thread(metadata, thread).await?;
152
153 this.update(&mut cx, |this, cx| this.reload(cx))?.await
154 })
155 }
156
157 pub fn delete_thread(&mut self, id: &ThreadId, cx: &mut Context<Self>) -> Task<Result<()>> {
158 let id = id.clone();
159 let database_future = ThreadsDatabase::global_future(cx);
160 cx.spawn(|this, mut cx| async move {
161 let database = database_future.await.map_err(|err| anyhow!(err))?;
162 database.delete_thread(id.clone()).await?;
163
164 this.update(&mut cx, |this, _cx| {
165 this.threads.retain(|thread| thread.id != id)
166 })
167 })
168 }
169
170 pub fn reload(&self, cx: &mut Context<Self>) -> Task<Result<()>> {
171 let database_future = ThreadsDatabase::global_future(cx);
172 cx.spawn(|this, mut cx| async move {
173 let threads = database_future
174 .await
175 .map_err(|err| anyhow!(err))?
176 .list_threads()
177 .await?;
178
179 this.update(&mut cx, |this, cx| {
180 this.threads = threads;
181 cx.notify();
182 })
183 })
184 }
185
186 fn register_context_server_handlers(&self, cx: &mut Context<Self>) {
187 cx.subscribe(
188 &self.context_server_manager.clone(),
189 Self::handle_context_server_event,
190 )
191 .detach();
192 }
193
194 fn handle_context_server_event(
195 &mut self,
196 context_server_manager: Entity<ContextServerManager>,
197 event: &context_server::manager::Event,
198 cx: &mut Context<Self>,
199 ) {
200 let tool_working_set = self.tools.clone();
201 match event {
202 context_server::manager::Event::ServerStarted { server_id } => {
203 if let Some(server) = context_server_manager.read(cx).get_server(server_id) {
204 let context_server_manager = context_server_manager.clone();
205 cx.spawn({
206 let server = server.clone();
207 let server_id = server_id.clone();
208 |this, mut cx| async move {
209 let Some(protocol) = server.client() else {
210 return;
211 };
212
213 if protocol.capable(context_server::protocol::ServerCapability::Tools) {
214 if let Some(tools) = protocol.list_tools().await.log_err() {
215 let tool_ids = tools
216 .tools
217 .into_iter()
218 .map(|tool| {
219 log::info!(
220 "registering context server tool: {:?}",
221 tool.name
222 );
223 tool_working_set.insert(Arc::new(
224 ContextServerTool::new(
225 context_server_manager.clone(),
226 server.id(),
227 tool,
228 ),
229 ))
230 })
231 .collect::<Vec<_>>();
232
233 this.update(&mut cx, |this, _cx| {
234 this.context_server_tool_ids.insert(server_id, tool_ids);
235 })
236 .log_err();
237 }
238 }
239 }
240 })
241 .detach();
242 }
243 }
244 context_server::manager::Event::ServerStopped { server_id } => {
245 if let Some(tool_ids) = self.context_server_tool_ids.remove(server_id) {
246 tool_working_set.remove(&tool_ids);
247 }
248 }
249 }
250 }
251}
252
/// Lightweight description of a saved thread (no messages), used for listing threads.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SavedThreadMetadata {
    /// Key under which the thread is stored in the threads database.
    pub id: ThreadId,
    /// Summary copied from the stored [`SavedThread`].
    pub summary: SharedString,
    /// Last-update timestamp copied from the stored [`SavedThread`]; used to sort by recency.
    pub updated_at: DateTime<Utc>,
}
259
/// Persisted form of a thread, stored as JSON in the threads database keyed by [`ThreadId`].
///
/// Field names and types form the on-disk schema; changing them affects reading existing data.
#[derive(Serialize, Deserialize)]
pub struct SavedThread {
    pub summary: SharedString,
    pub updated_at: DateTime<Utc>,
    pub messages: Vec<SavedMessage>,
}
266
/// Persisted form of a single message within a [`SavedThread`].
#[derive(Debug, Serialize, Deserialize)]
pub struct SavedMessage {
    pub id: MessageId,
    pub role: Role,
    pub text: String,
    /// Tool uses attached to this message. Deserializes as empty when absent from the stored
    /// JSON (presumably so records saved before this field existed still load — confirm).
    #[serde(default)]
    pub tool_uses: Vec<SavedToolUse>,
    /// Tool results attached to this message. Deserializes as empty when absent from the
    /// stored JSON, like `tool_uses`.
    #[serde(default)]
    pub tool_results: Vec<SavedToolResult>,
}
277
/// Persisted form of a tool use recorded on a [`SavedMessage`].
#[derive(Debug, Serialize, Deserialize)]
pub struct SavedToolUse {
    /// Identifier that [`SavedToolResult::tool_use_id`] refers back to.
    pub id: LanguageModelToolUseId,
    /// Name of the tool that was invoked.
    pub name: SharedString,
    /// JSON input supplied for the tool call.
    pub input: serde_json::Value,
}
284
/// Persisted form of a tool result recorded on a [`SavedMessage`].
#[derive(Debug, Serialize, Deserialize)]
pub struct SavedToolResult {
    /// ID of the [`SavedToolUse`] this result answers.
    pub tool_use_id: LanguageModelToolUseId,
    /// Whether this result was recorded as an error.
    pub is_error: bool,
    /// Result content as text.
    pub content: Arc<str>,
}
291
/// gpui global holding the shared future that opens the [`ThreadsDatabase`].
///
/// Both the database and the error are wrapped in `Arc` because a `Shared` future's output
/// must be `Clone`, letting every caller await the same open operation.
struct GlobalThreadsDatabase(
    Shared<BoxFuture<'static, Result<Arc<ThreadsDatabase>, Arc<anyhow::Error>>>>,
);

impl Global for GlobalThreadsDatabase {}
297
/// heed (LMDB) backed store of saved threads.
pub(crate) struct ThreadsDatabase {
    /// Executor on which every database transaction is run.
    executor: BackgroundExecutor,
    /// Open heed environment; cloned into each background task.
    env: heed::Env,
    /// Maps bincode-encoded [`ThreadId`] keys to JSON-encoded [`SavedThread`] values.
    threads: Database<SerdeBincode<ThreadId>, SerdeJson<SavedThread>>,
}
303
304impl ThreadsDatabase {
305 fn global_future(
306 cx: &mut App,
307 ) -> Shared<BoxFuture<'static, Result<Arc<ThreadsDatabase>, Arc<anyhow::Error>>>> {
308 GlobalThreadsDatabase::global(cx).0.clone()
309 }
310
311 fn init(cx: &mut App) {
312 let executor = cx.background_executor().clone();
313 let database_future = executor
314 .spawn({
315 let executor = executor.clone();
316 let database_path = paths::support_dir().join("threads/threads-db.1.mdb");
317 async move { ThreadsDatabase::new(database_path, executor) }
318 })
319 .then(|result| future::ready(result.map(Arc::new).map_err(Arc::new)))
320 .boxed()
321 .shared();
322
323 cx.set_global(GlobalThreadsDatabase(database_future));
324 }
325
326 pub fn new(path: PathBuf, executor: BackgroundExecutor) -> Result<Self> {
327 std::fs::create_dir_all(&path)?;
328
329 const ONE_GB_IN_BYTES: usize = 1024 * 1024 * 1024;
330 let env = unsafe {
331 heed::EnvOpenOptions::new()
332 .map_size(ONE_GB_IN_BYTES)
333 .max_dbs(1)
334 .open(path)?
335 };
336
337 let mut txn = env.write_txn()?;
338 let threads = env.create_database(&mut txn, Some("threads"))?;
339 txn.commit()?;
340
341 Ok(Self {
342 executor,
343 env,
344 threads,
345 })
346 }
347
348 pub fn list_threads(&self) -> Task<Result<Vec<SavedThreadMetadata>>> {
349 let env = self.env.clone();
350 let threads = self.threads;
351
352 self.executor.spawn(async move {
353 let txn = env.read_txn()?;
354 let mut iter = threads.iter(&txn)?;
355 let mut threads = Vec::new();
356 while let Some((key, value)) = iter.next().transpose()? {
357 threads.push(SavedThreadMetadata {
358 id: key,
359 summary: value.summary,
360 updated_at: value.updated_at,
361 });
362 }
363
364 Ok(threads)
365 })
366 }
367
368 pub fn try_find_thread(&self, id: ThreadId) -> Task<Result<Option<SavedThread>>> {
369 let env = self.env.clone();
370 let threads = self.threads;
371
372 self.executor.spawn(async move {
373 let txn = env.read_txn()?;
374 let thread = threads.get(&txn, &id)?;
375 Ok(thread)
376 })
377 }
378
379 pub fn save_thread(&self, id: ThreadId, thread: SavedThread) -> Task<Result<()>> {
380 let env = self.env.clone();
381 let threads = self.threads;
382
383 self.executor.spawn(async move {
384 let mut txn = env.write_txn()?;
385 threads.put(&mut txn, &id, &thread)?;
386 txn.commit()?;
387 Ok(())
388 })
389 }
390
391 pub fn delete_thread(&self, id: ThreadId) -> Task<Result<()>> {
392 let env = self.env.clone();
393 let threads = self.threads;
394
395 self.executor.spawn(async move {
396 let mut txn = env.write_txn()?;
397 threads.delete(&mut txn, &id)?;
398 txn.commit()?;
399 Ok(())
400 })
401 }
402}