1use std::path::PathBuf;
2use std::sync::Arc;
3
4use anyhow::{anyhow, Result};
5use assistant_tool::{ToolId, ToolWorkingSet};
6use chrono::{DateTime, Utc};
7use collections::HashMap;
8use context_server::manager::ContextServerManager;
9use context_server::{ContextServerFactoryRegistry, ContextServerTool};
10use futures::future::{self, BoxFuture, Shared};
11use futures::FutureExt as _;
12use gpui::{
13 prelude::*, App, BackgroundExecutor, Context, Entity, Global, ReadGlobal, SharedString, Task,
14};
15use heed::types::{SerdeBincode, SerdeJson};
16use heed::Database;
17use language_model::{LanguageModelToolUseId, Role};
18use project::Project;
19use serde::{Deserialize, Serialize};
20use util::ResultExt as _;
21
22use crate::thread::{MessageId, Thread, ThreadId};
23
/// Initializes this module by starting to open the global threads database in the
/// background (see `ThreadsDatabase::init`).
///
/// `ThreadStore`'s database operations read the global installed here, so this should be
/// called before any of them run.
pub fn init(cx: &mut App) {
    ThreadsDatabase::init(cx);
}
27
/// Tracks the saved threads of a project and the context-server tools made available to
/// threads created from it.
pub struct ThreadStore {
    // Only used at construction time (to create the context server manager); kept on the
    // struct, hence the lint allowance.
    #[allow(unused)]
    project: Entity<Project>,
    /// Tool set handed to every thread; context-server tools are inserted into and removed
    /// from it as servers start and stop.
    tools: Arc<ToolWorkingSet>,
    /// Manager whose start/stop events drive context-server tool registration.
    context_server_manager: Entity<ContextServerManager>,
    /// Tool ids registered per context server id, so they can be removed when that server stops.
    context_server_tool_ids: HashMap<Arc<str>, Vec<ToolId>>,
    /// Metadata of saved threads as last loaded from the database, in database iteration
    /// order (`threads()` sorts by recency).
    threads: Vec<SavedThreadMetadata>,
}
36
37impl ThreadStore {
38 pub fn new(
39 project: Entity<Project>,
40 tools: Arc<ToolWorkingSet>,
41 cx: &mut App,
42 ) -> Result<Entity<Self>> {
43 let this = cx.new(|cx| {
44 let context_server_factory_registry = ContextServerFactoryRegistry::default_global(cx);
45 let context_server_manager = cx.new(|cx| {
46 ContextServerManager::new(context_server_factory_registry, project.clone(), cx)
47 });
48
49 let this = Self {
50 project,
51 tools,
52 context_server_manager,
53 context_server_tool_ids: HashMap::default(),
54 threads: Vec::new(),
55 };
56 this.register_context_server_handlers(cx);
57 this.reload(cx).detach_and_log_err(cx);
58
59 this
60 });
61
62 Ok(this)
63 }
64
65 /// Returns the number of threads.
66 pub fn thread_count(&self) -> usize {
67 self.threads.len()
68 }
69
70 pub fn threads(&self) -> Vec<SavedThreadMetadata> {
71 let mut threads = self.threads.iter().cloned().collect::<Vec<_>>();
72 threads.sort_unstable_by_key(|thread| std::cmp::Reverse(thread.updated_at));
73 threads
74 }
75
76 pub fn recent_threads(&self, limit: usize) -> Vec<SavedThreadMetadata> {
77 self.threads().into_iter().take(limit).collect()
78 }
79
80 pub fn create_thread(&mut self, cx: &mut Context<Self>) -> Entity<Thread> {
81 cx.new(|cx| Thread::new(self.tools.clone(), cx))
82 }
83
84 pub fn open_thread(
85 &self,
86 id: &ThreadId,
87 cx: &mut Context<Self>,
88 ) -> Task<Result<Entity<Thread>>> {
89 let id = id.clone();
90 let database_future = ThreadsDatabase::global_future(cx);
91 cx.spawn(|this, mut cx| async move {
92 let database = database_future.await.map_err(|err| anyhow!(err))?;
93 let thread = database
94 .try_find_thread(id.clone())
95 .await?
96 .ok_or_else(|| anyhow!("no thread found with ID: {id:?}"))?;
97
98 this.update(&mut cx, |this, cx| {
99 cx.new(|cx| Thread::from_saved(id.clone(), thread, this.tools.clone(), cx))
100 })
101 })
102 }
103
104 pub fn save_thread(&self, thread: &Entity<Thread>, cx: &mut Context<Self>) -> Task<Result<()>> {
105 let (metadata, thread) = thread.update(cx, |thread, _cx| {
106 let id = thread.id().clone();
107 let thread = SavedThread {
108 summary: thread.summary_or_default(),
109 updated_at: thread.updated_at(),
110 messages: thread
111 .messages()
112 .map(|message| SavedMessage {
113 id: message.id,
114 role: message.role,
115 text: message.text.clone(),
116 tool_uses: thread
117 .tool_uses_for_message(message.id)
118 .into_iter()
119 .map(|tool_use| SavedToolUse {
120 id: tool_use.id,
121 name: tool_use.name,
122 input: tool_use.input,
123 })
124 .collect(),
125 tool_results: thread
126 .tool_results_for_message(message.id)
127 .into_iter()
128 .map(|tool_result| SavedToolResult {
129 tool_use_id: tool_result.tool_use_id.clone(),
130 is_error: tool_result.is_error,
131 content: tool_result.content.clone(),
132 })
133 .collect(),
134 })
135 .collect(),
136 };
137
138 (id, thread)
139 });
140
141 let database_future = ThreadsDatabase::global_future(cx);
142 cx.spawn(|this, mut cx| async move {
143 let database = database_future.await.map_err(|err| anyhow!(err))?;
144 database.save_thread(metadata, thread).await?;
145
146 this.update(&mut cx, |this, cx| this.reload(cx))?.await
147 })
148 }
149
150 pub fn delete_thread(&mut self, id: &ThreadId, cx: &mut Context<Self>) -> Task<Result<()>> {
151 let id = id.clone();
152 let database_future = ThreadsDatabase::global_future(cx);
153 cx.spawn(|this, mut cx| async move {
154 let database = database_future.await.map_err(|err| anyhow!(err))?;
155 database.delete_thread(id.clone()).await?;
156
157 this.update(&mut cx, |this, _cx| {
158 this.threads.retain(|thread| thread.id != id)
159 })
160 })
161 }
162
163 pub fn reload(&self, cx: &mut Context<Self>) -> Task<Result<()>> {
164 let database_future = ThreadsDatabase::global_future(cx);
165 cx.spawn(|this, mut cx| async move {
166 let threads = database_future
167 .await
168 .map_err(|err| anyhow!(err))?
169 .list_threads()
170 .await?;
171
172 this.update(&mut cx, |this, cx| {
173 this.threads = threads;
174 cx.notify();
175 })
176 })
177 }
178
179 fn register_context_server_handlers(&self, cx: &mut Context<Self>) {
180 cx.subscribe(
181 &self.context_server_manager.clone(),
182 Self::handle_context_server_event,
183 )
184 .detach();
185 }
186
187 fn handle_context_server_event(
188 &mut self,
189 context_server_manager: Entity<ContextServerManager>,
190 event: &context_server::manager::Event,
191 cx: &mut Context<Self>,
192 ) {
193 let tool_working_set = self.tools.clone();
194 match event {
195 context_server::manager::Event::ServerStarted { server_id } => {
196 if let Some(server) = context_server_manager.read(cx).get_server(server_id) {
197 let context_server_manager = context_server_manager.clone();
198 cx.spawn({
199 let server = server.clone();
200 let server_id = server_id.clone();
201 |this, mut cx| async move {
202 let Some(protocol) = server.client() else {
203 return;
204 };
205
206 if protocol.capable(context_server::protocol::ServerCapability::Tools) {
207 if let Some(tools) = protocol.list_tools().await.log_err() {
208 let tool_ids = tools
209 .tools
210 .into_iter()
211 .map(|tool| {
212 log::info!(
213 "registering context server tool: {:?}",
214 tool.name
215 );
216 tool_working_set.insert(Arc::new(
217 ContextServerTool::new(
218 context_server_manager.clone(),
219 server.id(),
220 tool,
221 ),
222 ))
223 })
224 .collect::<Vec<_>>();
225
226 this.update(&mut cx, |this, _cx| {
227 this.context_server_tool_ids.insert(server_id, tool_ids);
228 })
229 .log_err();
230 }
231 }
232 }
233 })
234 .detach();
235 }
236 }
237 context_server::manager::Event::ServerStopped { server_id } => {
238 if let Some(tool_ids) = self.context_server_tool_ids.remove(server_id) {
239 tool_working_set.remove(&tool_ids);
240 }
241 }
242 }
243 }
244}
245
/// Lightweight description of a saved thread, used for listing threads without their messages.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SavedThreadMetadata {
    /// Database key of the thread.
    pub id: ThreadId,
    /// Thread summary as stored at save time.
    pub summary: SharedString,
    /// Last-update timestamp as stored at save time; used to order threads by recency.
    pub updated_at: DateTime<Utc>,
}
252
253#[derive(Serialize, Deserialize)]
254pub struct SavedThread {
255 pub summary: SharedString,
256 pub updated_at: DateTime<Utc>,
257 pub messages: Vec<SavedMessage>,
258}
259
/// Persisted form of a single message within a [`SavedThread`].
#[derive(Debug, Serialize, Deserialize)]
pub struct SavedMessage {
    pub id: MessageId,
    pub role: Role,
    pub text: String,
    // `serde(default)` makes a missing field deserialize as an empty list — presumably so
    // threads saved before tool uses/results were recorded still load. NOTE(review): confirm.
    #[serde(default)]
    pub tool_uses: Vec<SavedToolUse>,
    #[serde(default)]
    pub tool_results: Vec<SavedToolResult>,
}
270
/// Persisted record of a tool invocation attached to a message.
#[derive(Debug, Serialize, Deserialize)]
pub struct SavedToolUse {
    /// Id of the tool use; referenced by [`SavedToolResult::tool_use_id`].
    pub id: LanguageModelToolUseId,
    /// Name of the invoked tool.
    pub name: SharedString,
    /// Tool input as arbitrary JSON.
    pub input: serde_json::Value,
}
277
/// Persisted record of the result of a tool invocation attached to a message.
#[derive(Debug, Serialize, Deserialize)]
pub struct SavedToolResult {
    /// Id of the tool use this result belongs to.
    pub tool_use_id: LanguageModelToolUseId,
    /// Whether the tool reported an error.
    pub is_error: bool,
    /// Result content as text.
    pub content: Arc<str>,
}
284
/// App-global handle to the shared future that resolves to the opened threads database, or
/// to the error encountered while opening it.
///
/// The future is `Shared` so every awaiting caller observes the same single open attempt;
/// both outcomes are wrapped in `Arc` so they are cheaply cloneable.
struct GlobalThreadsDatabase(
    Shared<BoxFuture<'static, Result<Arc<ThreadsDatabase>, Arc<anyhow::Error>>>>,
);

impl Global for GlobalThreadsDatabase {}
290
/// LMDB-backed (via `heed`) storage for saved threads.
pub(crate) struct ThreadsDatabase {
    /// Executor on which every database operation is spawned, keeping LMDB work off the caller.
    executor: BackgroundExecutor,
    /// The opened LMDB environment.
    env: heed::Env,
    /// Thread table: bincode-encoded `ThreadId` keys, JSON-encoded `SavedThread` values.
    threads: Database<SerdeBincode<ThreadId>, SerdeJson<SavedThread>>,
}
296
297impl ThreadsDatabase {
298 fn global_future(
299 cx: &mut App,
300 ) -> Shared<BoxFuture<'static, Result<Arc<ThreadsDatabase>, Arc<anyhow::Error>>>> {
301 GlobalThreadsDatabase::global(cx).0.clone()
302 }
303
304 fn init(cx: &mut App) {
305 let executor = cx.background_executor().clone();
306 let database_future = executor
307 .spawn({
308 let executor = executor.clone();
309 let database_path = paths::support_dir().join("threads/threads-db.1.mdb");
310 async move { ThreadsDatabase::new(database_path, executor) }
311 })
312 .then(|result| future::ready(result.map(Arc::new).map_err(Arc::new)))
313 .boxed()
314 .shared();
315
316 cx.set_global(GlobalThreadsDatabase(database_future));
317 }
318
319 pub fn new(path: PathBuf, executor: BackgroundExecutor) -> Result<Self> {
320 std::fs::create_dir_all(&path)?;
321
322 const ONE_GB_IN_BYTES: usize = 1024 * 1024 * 1024;
323 let env = unsafe {
324 heed::EnvOpenOptions::new()
325 .map_size(ONE_GB_IN_BYTES)
326 .max_dbs(1)
327 .open(path)?
328 };
329
330 let mut txn = env.write_txn()?;
331 let threads = env.create_database(&mut txn, Some("threads"))?;
332 txn.commit()?;
333
334 Ok(Self {
335 executor,
336 env,
337 threads,
338 })
339 }
340
341 pub fn list_threads(&self) -> Task<Result<Vec<SavedThreadMetadata>>> {
342 let env = self.env.clone();
343 let threads = self.threads;
344
345 self.executor.spawn(async move {
346 let txn = env.read_txn()?;
347 let mut iter = threads.iter(&txn)?;
348 let mut threads = Vec::new();
349 while let Some((key, value)) = iter.next().transpose()? {
350 threads.push(SavedThreadMetadata {
351 id: key,
352 summary: value.summary,
353 updated_at: value.updated_at,
354 });
355 }
356
357 Ok(threads)
358 })
359 }
360
361 pub fn try_find_thread(&self, id: ThreadId) -> Task<Result<Option<SavedThread>>> {
362 let env = self.env.clone();
363 let threads = self.threads;
364
365 self.executor.spawn(async move {
366 let txn = env.read_txn()?;
367 let thread = threads.get(&txn, &id)?;
368 Ok(thread)
369 })
370 }
371
372 pub fn save_thread(&self, id: ThreadId, thread: SavedThread) -> Task<Result<()>> {
373 let env = self.env.clone();
374 let threads = self.threads;
375
376 self.executor.spawn(async move {
377 let mut txn = env.write_txn()?;
378 threads.put(&mut txn, &id, &thread)?;
379 txn.commit()?;
380 Ok(())
381 })
382 }
383
384 pub fn delete_thread(&self, id: ThreadId) -> Task<Result<()>> {
385 let env = self.env.clone();
386 let threads = self.threads;
387
388 self.executor.spawn(async move {
389 let mut txn = env.write_txn()?;
390 threads.delete(&mut txn, &id)?;
391 txn.commit()?;
392 Ok(())
393 })
394 }
395}