thread_store.rs

  1use std::borrow::Cow;
  2use std::path::PathBuf;
  3use std::sync::Arc;
  4
  5use anyhow::{anyhow, Result};
  6use assistant_settings::AssistantSettings;
  7use assistant_tool::{ToolId, ToolSource, ToolWorkingSet};
  8use chrono::{DateTime, Utc};
  9use collections::HashMap;
 10use context_server::manager::ContextServerManager;
 11use context_server::{ContextServerFactoryRegistry, ContextServerTool};
 12use futures::future::{self, BoxFuture, Shared};
 13use futures::FutureExt as _;
 14use gpui::{
 15    prelude::*, App, BackgroundExecutor, Context, Entity, Global, ReadGlobal, SharedString, Task,
 16};
 17use heed::types::SerdeBincode;
 18use heed::Database;
 19use language_model::{LanguageModelToolUseId, Role};
 20use project::Project;
 21use prompt_store::PromptBuilder;
 22use serde::{Deserialize, Serialize};
 23use settings::Settings as _;
 24use util::ResultExt as _;
 25
 26use crate::thread::{MessageId, ProjectSnapshot, Thread, ThreadEvent, ThreadId};
 27
/// Initializes this module's global state by registering the threads
/// database future as a gpui global (see `ThreadsDatabase::init`).
pub fn init(cx: &mut App) {
    ThreadsDatabase::init(cx);
}
 31
/// Tracks the persisted assistant threads for a project and creates, opens,
/// saves, and deletes [`Thread`] entities backed by the threads database.
pub struct ThreadStore {
    project: Entity<Project>,
    /// Tool working set shared with every thread this store creates or opens.
    tools: Arc<ToolWorkingSet>,
    prompt_builder: Arc<PromptBuilder>,
    context_server_manager: Entity<ContextServerManager>,
    /// Tool IDs registered for each started context server, keyed by server
    /// ID, so they can be removed from `tools` when that server stops.
    context_server_tool_ids: HashMap<Arc<str>, Vec<ToolId>>,
    /// Cached metadata of stored threads (unsorted); replaced by `reload` and
    /// pruned by `delete_thread`.
    threads: Vec<SerializedThreadMetadata>,
}
 40
 41impl ThreadStore {
 42    pub fn new(
 43        project: Entity<Project>,
 44        tools: Arc<ToolWorkingSet>,
 45        prompt_builder: Arc<PromptBuilder>,
 46        cx: &mut App,
 47    ) -> Result<Entity<Self>> {
 48        let this = cx.new(|cx| {
 49            let context_server_factory_registry = ContextServerFactoryRegistry::default_global(cx);
 50            let context_server_manager = cx.new(|cx| {
 51                ContextServerManager::new(context_server_factory_registry, project.clone(), cx)
 52            });
 53
 54            let this = Self {
 55                project,
 56                tools,
 57                prompt_builder,
 58                context_server_manager,
 59                context_server_tool_ids: HashMap::default(),
 60                threads: Vec::new(),
 61            };
 62            this.load_default_profile(cx);
 63            this.register_context_server_handlers(cx);
 64            this.reload(cx).detach_and_log_err(cx);
 65
 66            this
 67        });
 68
 69        Ok(this)
 70    }
 71
 72    pub fn context_server_manager(&self) -> Entity<ContextServerManager> {
 73        self.context_server_manager.clone()
 74    }
 75
 76    pub fn tools(&self) -> Arc<ToolWorkingSet> {
 77        self.tools.clone()
 78    }
 79
 80    /// Returns the number of threads.
 81    pub fn thread_count(&self) -> usize {
 82        self.threads.len()
 83    }
 84
 85    pub fn threads(&self) -> Vec<SerializedThreadMetadata> {
 86        let mut threads = self.threads.iter().cloned().collect::<Vec<_>>();
 87        threads.sort_unstable_by_key(|thread| std::cmp::Reverse(thread.updated_at));
 88        threads
 89    }
 90
 91    pub fn recent_threads(&self, limit: usize) -> Vec<SerializedThreadMetadata> {
 92        self.threads().into_iter().take(limit).collect()
 93    }
 94
 95    pub fn create_thread(&mut self, cx: &mut Context<Self>) -> Entity<Thread> {
 96        cx.new(|cx| {
 97            Thread::new(
 98                self.project.clone(),
 99                self.tools.clone(),
100                self.prompt_builder.clone(),
101                cx,
102            )
103        })
104    }
105
106    pub fn open_thread(
107        &self,
108        id: &ThreadId,
109        cx: &mut Context<Self>,
110    ) -> Task<Result<Entity<Thread>>> {
111        let id = id.clone();
112        let database_future = ThreadsDatabase::global_future(cx);
113        cx.spawn(async move |this, cx| {
114            let database = database_future.await.map_err(|err| anyhow!(err))?;
115            let thread = database
116                .try_find_thread(id.clone())
117                .await?
118                .ok_or_else(|| anyhow!("no thread found with ID: {id:?}"))?;
119
120            let thread = this.update(cx, |this, cx| {
121                cx.new(|cx| {
122                    Thread::deserialize(
123                        id.clone(),
124                        thread,
125                        this.project.clone(),
126                        this.tools.clone(),
127                        this.prompt_builder.clone(),
128                        cx,
129                    )
130                })
131            })?;
132
133            let (system_prompt_context, load_error) = thread
134                .update(cx, |thread, cx| thread.load_system_prompt_context(cx))?
135                .await;
136            thread.update(cx, |thread, cx| {
137                thread.set_system_prompt_context(system_prompt_context);
138                if let Some(load_error) = load_error {
139                    cx.emit(ThreadEvent::ShowError(load_error));
140                }
141            })?;
142
143            Ok(thread)
144        })
145    }
146
147    pub fn save_thread(&self, thread: &Entity<Thread>, cx: &mut Context<Self>) -> Task<Result<()>> {
148        let (metadata, serialized_thread) =
149            thread.update(cx, |thread, cx| (thread.id().clone(), thread.serialize(cx)));
150
151        let database_future = ThreadsDatabase::global_future(cx);
152        cx.spawn(async move |this, cx| {
153            let serialized_thread = serialized_thread.await?;
154            let database = database_future.await.map_err(|err| anyhow!(err))?;
155            database.save_thread(metadata, serialized_thread).await?;
156
157            this.update(cx, |this, cx| this.reload(cx))?.await
158        })
159    }
160
161    pub fn delete_thread(&mut self, id: &ThreadId, cx: &mut Context<Self>) -> Task<Result<()>> {
162        let id = id.clone();
163        let database_future = ThreadsDatabase::global_future(cx);
164        cx.spawn(async move |this, cx| {
165            let database = database_future.await.map_err(|err| anyhow!(err))?;
166            database.delete_thread(id.clone()).await?;
167
168            this.update(cx, |this, _cx| {
169                this.threads.retain(|thread| thread.id != id)
170            })
171        })
172    }
173
174    pub fn reload(&self, cx: &mut Context<Self>) -> Task<Result<()>> {
175        let database_future = ThreadsDatabase::global_future(cx);
176        cx.spawn(async move |this, cx| {
177            let threads = database_future
178                .await
179                .map_err(|err| anyhow!(err))?
180                .list_threads()
181                .await?;
182
183            this.update(cx, |this, cx| {
184                this.threads = threads;
185                cx.notify();
186            })
187        })
188    }
189
190    pub fn load_default_profile(&self, cx: &mut Context<Self>) {
191        let assistant_settings = AssistantSettings::get_global(cx);
192
193        if let Some(profile) = assistant_settings
194            .profiles
195            .get(&assistant_settings.default_profile)
196        {
197            self.tools.disable_source(ToolSource::Native, cx);
198            self.tools.enable(
199                ToolSource::Native,
200                &profile
201                    .tools
202                    .iter()
203                    .filter_map(|(tool, enabled)| enabled.then(|| tool.clone()))
204                    .collect::<Vec<_>>(),
205            );
206
207            for (context_server_id, preset) in &profile.context_servers {
208                self.tools.enable(
209                    ToolSource::ContextServer {
210                        id: context_server_id.clone().into(),
211                    },
212                    &preset
213                        .tools
214                        .iter()
215                        .filter_map(|(tool, enabled)| enabled.then(|| tool.clone()))
216                        .collect::<Vec<_>>(),
217                )
218            }
219        }
220    }
221
222    fn register_context_server_handlers(&self, cx: &mut Context<Self>) {
223        cx.subscribe(
224            &self.context_server_manager.clone(),
225            Self::handle_context_server_event,
226        )
227        .detach();
228    }
229
230    fn handle_context_server_event(
231        &mut self,
232        context_server_manager: Entity<ContextServerManager>,
233        event: &context_server::manager::Event,
234        cx: &mut Context<Self>,
235    ) {
236        let tool_working_set = self.tools.clone();
237        match event {
238            context_server::manager::Event::ServerStarted { server_id } => {
239                if let Some(server) = context_server_manager.read(cx).get_server(server_id) {
240                    let context_server_manager = context_server_manager.clone();
241                    cx.spawn({
242                        let server = server.clone();
243                        let server_id = server_id.clone();
244                        async move |this, cx| {
245                            let Some(protocol) = server.client() else {
246                                return;
247                            };
248
249                            if protocol.capable(context_server::protocol::ServerCapability::Tools) {
250                                if let Some(tools) = protocol.list_tools().await.log_err() {
251                                    let tool_ids = tools
252                                        .tools
253                                        .into_iter()
254                                        .map(|tool| {
255                                            log::info!(
256                                                "registering context server tool: {:?}",
257                                                tool.name
258                                            );
259                                            tool_working_set.insert(Arc::new(
260                                                ContextServerTool::new(
261                                                    context_server_manager.clone(),
262                                                    server.id(),
263                                                    tool,
264                                                ),
265                                            ))
266                                        })
267                                        .collect::<Vec<_>>();
268
269                                    this.update(cx, |this, _cx| {
270                                        this.context_server_tool_ids.insert(server_id, tool_ids);
271                                    })
272                                    .log_err();
273                                }
274                            }
275                        }
276                    })
277                    .detach();
278                }
279            }
280            context_server::manager::Event::ServerStopped { server_id } => {
281                if let Some(tool_ids) = self.context_server_tool_ids.remove(server_id) {
282                    tool_working_set.remove(&tool_ids);
283                }
284            }
285        }
286    }
287}
288
/// Lightweight summary of a stored thread, built by
/// `ThreadsDatabase::list_threads` from each database record.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SerializedThreadMetadata {
    pub id: ThreadId,
    pub summary: SharedString,
    /// Last-updated timestamp; `ThreadStore::threads` sorts newest first by it.
    pub updated_at: DateTime<Utc>,
}
295
/// A thread as persisted in the threads database (JSON-encoded via the
/// `heed::BytesEncode` impl in this file).
///
/// `version` identifies the schema; `SerializedThread::from_json` accepts only
/// `SerializedThread::VERSION` or unversioned legacy payloads.
#[derive(Serialize, Deserialize)]
pub struct SerializedThread {
    pub version: String,
    pub summary: SharedString,
    pub updated_at: DateTime<Utc>,
    pub messages: Vec<SerializedMessage>,
    /// Deserializes as `None` when the field is absent (`#[serde(default)]`).
    #[serde(default)]
    pub initial_project_snapshot: Option<Arc<ProjectSnapshot>>,
}
305
306impl SerializedThread {
307    pub const VERSION: &'static str = "0.1.0";
308
309    pub fn from_json(json: &[u8]) -> Result<Self> {
310        let saved_thread_json = serde_json::from_slice::<serde_json::Value>(json)?;
311        match saved_thread_json.get("version") {
312            Some(serde_json::Value::String(version)) => match version.as_str() {
313                SerializedThread::VERSION => Ok(serde_json::from_value::<SerializedThread>(
314                    saved_thread_json,
315                )?),
316                _ => Err(anyhow!(
317                    "unrecognized serialized thread version: {}",
318                    version
319                )),
320            },
321            None => {
322                let saved_thread =
323                    serde_json::from_value::<LegacySerializedThread>(saved_thread_json)?;
324                Ok(saved_thread.upgrade())
325            }
326            version => Err(anyhow!(
327                "unrecognized serialized thread version: {:?}",
328                version
329            )),
330        }
331    }
332}
333
/// A single message within a [`SerializedThread`].
#[derive(Debug, Serialize, Deserialize)]
pub struct SerializedMessage {
    pub id: MessageId,
    pub role: Role,
    /// Content segments in order; empty when absent from the payload.
    #[serde(default)]
    pub segments: Vec<SerializedMessageSegment>,
    /// Tool uses recorded on this message; empty when absent.
    #[serde(default)]
    pub tool_uses: Vec<SerializedToolUse>,
    /// Tool results recorded on this message; empty when absent.
    #[serde(default)]
    pub tool_results: Vec<SerializedToolResult>,
}
345
346#[derive(Debug, Serialize, Deserialize)]
347#[serde(tag = "type")]
348pub enum SerializedMessageSegment {
349    #[serde(rename = "text")]
350    Text { text: String },
351    #[serde(rename = "thinking")]
352    Thinking { text: String },
353}
354
/// A tool use recorded on a message.
#[derive(Debug, Serialize, Deserialize)]
pub struct SerializedToolUse {
    /// Tool-use ID; presumably referenced by `SerializedToolResult::tool_use_id`.
    pub id: LanguageModelToolUseId,
    pub name: SharedString,
    /// Tool input, stored as arbitrary JSON.
    pub input: serde_json::Value,
}
361
/// The result of a tool use, presumably matched to it via `tool_use_id`.
#[derive(Debug, Serialize, Deserialize)]
pub struct SerializedToolResult {
    pub tool_use_id: LanguageModelToolUseId,
    /// Whether `content` describes an error rather than a normal result.
    pub is_error: bool,
    pub content: Arc<str>,
}
368
/// Thread layout that predates the `version` field.
/// `SerializedThread::from_json` routes version-less payloads here and then
/// calls `upgrade`.
#[derive(Serialize, Deserialize)]
struct LegacySerializedThread {
    pub summary: SharedString,
    pub updated_at: DateTime<Utc>,
    pub messages: Vec<LegacySerializedMessage>,
    #[serde(default)]
    pub initial_project_snapshot: Option<Arc<ProjectSnapshot>>,
}
377
378impl LegacySerializedThread {
379    pub fn upgrade(self) -> SerializedThread {
380        SerializedThread {
381            version: SerializedThread::VERSION.to_string(),
382            summary: self.summary,
383            updated_at: self.updated_at,
384            messages: self.messages.into_iter().map(|msg| msg.upgrade()).collect(),
385            initial_project_snapshot: self.initial_project_snapshot,
386        }
387    }
388}
389
/// Legacy message layout: its content is one plain `text` string rather
/// than a list of segments.
#[derive(Debug, Serialize, Deserialize)]
struct LegacySerializedMessage {
    pub id: MessageId,
    pub role: Role,
    pub text: String,
    #[serde(default)]
    pub tool_uses: Vec<SerializedToolUse>,
    #[serde(default)]
    pub tool_results: Vec<SerializedToolResult>,
}
400
401impl LegacySerializedMessage {
402    fn upgrade(self) -> SerializedMessage {
403        SerializedMessage {
404            id: self.id,
405            role: self.role,
406            segments: vec![SerializedMessageSegment::Text { text: self.text }],
407            tool_uses: self.tool_uses,
408            tool_results: self.tool_results,
409        }
410    }
411}
412
/// gpui global holding the future that opens the threads database.
///
/// The future is `Shared` so every caller awaits the same open attempt. The
/// error is wrapped in `Arc` because `Shared` needs a cloneable output.
struct GlobalThreadsDatabase(
    Shared<BoxFuture<'static, Result<Arc<ThreadsDatabase>, Arc<anyhow::Error>>>>,
);

impl Global for GlobalThreadsDatabase {}
418
/// heed (LMDB) store of serialized threads keyed by [`ThreadId`].
pub(crate) struct ThreadsDatabase {
    /// Executor on which every database operation is spawned.
    executor: BackgroundExecutor,
    env: heed::Env,
    /// Keys are bincode-encoded thread IDs. Values are JSON-encoded
    /// `SerializedThread`s (see the `BytesEncode`/`BytesDecode` impls).
    threads: Database<SerdeBincode<ThreadId>, SerializedThread>,
}
424
425impl heed::BytesEncode<'_> for SerializedThread {
426    type EItem = SerializedThread;
427
428    fn bytes_encode(item: &Self::EItem) -> Result<Cow<[u8]>, heed::BoxedError> {
429        serde_json::to_vec(item).map(Cow::Owned).map_err(Into::into)
430    }
431}
432
433impl<'a> heed::BytesDecode<'a> for SerializedThread {
434    type DItem = SerializedThread;
435
436    fn bytes_decode(bytes: &'a [u8]) -> Result<Self::DItem, heed::BoxedError> {
437        // We implement this type manually because we want to call `SerializedThread::from_json`,
438        // instead of the Deserialize trait implementation for `SerializedThread`.
439        SerializedThread::from_json(bytes).map_err(Into::into)
440    }
441}
442
443impl ThreadsDatabase {
444    fn global_future(
445        cx: &mut App,
446    ) -> Shared<BoxFuture<'static, Result<Arc<ThreadsDatabase>, Arc<anyhow::Error>>>> {
447        GlobalThreadsDatabase::global(cx).0.clone()
448    }
449
450    fn init(cx: &mut App) {
451        let executor = cx.background_executor().clone();
452        let database_future = executor
453            .spawn({
454                let executor = executor.clone();
455                let database_path = paths::support_dir().join("threads/threads-db.1.mdb");
456                async move { ThreadsDatabase::new(database_path, executor) }
457            })
458            .then(|result| future::ready(result.map(Arc::new).map_err(Arc::new)))
459            .boxed()
460            .shared();
461
462        cx.set_global(GlobalThreadsDatabase(database_future));
463    }
464
465    pub fn new(path: PathBuf, executor: BackgroundExecutor) -> Result<Self> {
466        std::fs::create_dir_all(&path)?;
467
468        const ONE_GB_IN_BYTES: usize = 1024 * 1024 * 1024;
469        let env = unsafe {
470            heed::EnvOpenOptions::new()
471                .map_size(ONE_GB_IN_BYTES)
472                .max_dbs(1)
473                .open(path)?
474        };
475
476        let mut txn = env.write_txn()?;
477        let threads = env.create_database(&mut txn, Some("threads"))?;
478        txn.commit()?;
479
480        Ok(Self {
481            executor,
482            env,
483            threads,
484        })
485    }
486
487    pub fn list_threads(&self) -> Task<Result<Vec<SerializedThreadMetadata>>> {
488        let env = self.env.clone();
489        let threads = self.threads;
490
491        self.executor.spawn(async move {
492            let txn = env.read_txn()?;
493            let mut iter = threads.iter(&txn)?;
494            let mut threads = Vec::new();
495            while let Some((key, value)) = iter.next().transpose()? {
496                threads.push(SerializedThreadMetadata {
497                    id: key,
498                    summary: value.summary,
499                    updated_at: value.updated_at,
500                });
501            }
502
503            Ok(threads)
504        })
505    }
506
507    pub fn try_find_thread(&self, id: ThreadId) -> Task<Result<Option<SerializedThread>>> {
508        let env = self.env.clone();
509        let threads = self.threads;
510
511        self.executor.spawn(async move {
512            let txn = env.read_txn()?;
513            let thread = threads.get(&txn, &id)?;
514            Ok(thread)
515        })
516    }
517
518    pub fn save_thread(&self, id: ThreadId, thread: SerializedThread) -> Task<Result<()>> {
519        let env = self.env.clone();
520        let threads = self.threads;
521
522        self.executor.spawn(async move {
523            let mut txn = env.write_txn()?;
524            threads.put(&mut txn, &id, &thread)?;
525            txn.commit()?;
526            Ok(())
527        })
528    }
529
530    pub fn delete_thread(&self, id: ThreadId) -> Task<Result<()>> {
531        let env = self.env.clone();
532        let threads = self.threads;
533
534        self.executor.spawn(async move {
535            let mut txn = env.write_txn()?;
536            threads.delete(&mut txn, &id)?;
537            txn.commit()?;
538            Ok(())
539        })
540    }
541}