thread_store.rs

  1use std::borrow::Cow;
  2use std::path::PathBuf;
  3use std::sync::Arc;
  4
  5use anyhow::{Result, anyhow};
  6use assistant_settings::{AgentProfile, AssistantSettings};
  7use assistant_tool::{ToolId, ToolSource, ToolWorkingSet};
  8use chrono::{DateTime, Utc};
  9use collections::HashMap;
 10use context_server::manager::ContextServerManager;
 11use context_server::{ContextServerFactoryRegistry, ContextServerTool};
 12use futures::FutureExt as _;
 13use futures::future::{self, BoxFuture, Shared};
 14use gpui::{
 15    App, BackgroundExecutor, Context, Entity, Global, ReadGlobal, SharedString, Task, prelude::*,
 16};
 17use heed::Database;
 18use heed::types::SerdeBincode;
 19use language_model::{LanguageModelToolUseId, Role, TokenUsage};
 20use project::Project;
 21use prompt_store::PromptBuilder;
 22use serde::{Deserialize, Serialize};
 23use settings::Settings as _;
 24use util::ResultExt as _;
 25
 26use crate::thread::{MessageId, ProjectSnapshot, Thread, ThreadEvent, ThreadId};
 27
/// Initializes this module's globals by registering the shared threads database
/// (see `ThreadsDatabase::init`).
pub fn init(cx: &mut App) {
    ThreadsDatabase::init(cx);
}
 31
/// Per-project store of persisted agent threads.
///
/// Holds the state needed to create, open, save and delete [`Thread`]s, plus the tool set and
/// context-server manager those threads use.
pub struct ThreadStore {
    project: Entity<Project>,
    /// Tool set shared with every thread; enabled/disabled according to the loaded profile.
    tools: Arc<ToolWorkingSet>,
    prompt_builder: Arc<PromptBuilder>,
    context_server_manager: Entity<ContextServerManager>,
    /// Ids of the tools registered for each running context server, keyed by server id, so
    /// they can be removed from `tools` when that server stops.
    context_server_tool_ids: HashMap<Arc<str>, Vec<ToolId>>,
    /// Metadata for all persisted threads, in the order the database listed them
    /// (`threads()` returns a copy sorted newest-first).
    threads: Vec<SerializedThreadMetadata>,
}
 40
 41impl ThreadStore {
 42    pub fn new(
 43        project: Entity<Project>,
 44        tools: Arc<ToolWorkingSet>,
 45        prompt_builder: Arc<PromptBuilder>,
 46        cx: &mut App,
 47    ) -> Result<Entity<Self>> {
 48        let this = cx.new(|cx| {
 49            let context_server_factory_registry = ContextServerFactoryRegistry::default_global(cx);
 50            let context_server_manager = cx.new(|cx| {
 51                ContextServerManager::new(context_server_factory_registry, project.clone(), cx)
 52            });
 53
 54            let this = Self {
 55                project,
 56                tools,
 57                prompt_builder,
 58                context_server_manager,
 59                context_server_tool_ids: HashMap::default(),
 60                threads: Vec::new(),
 61            };
 62            this.load_default_profile(cx);
 63            this.register_context_server_handlers(cx);
 64            this.reload(cx).detach_and_log_err(cx);
 65
 66            this
 67        });
 68
 69        Ok(this)
 70    }
 71
 72    pub fn context_server_manager(&self) -> Entity<ContextServerManager> {
 73        self.context_server_manager.clone()
 74    }
 75
 76    pub fn tools(&self) -> Arc<ToolWorkingSet> {
 77        self.tools.clone()
 78    }
 79
 80    /// Returns the number of threads.
 81    pub fn thread_count(&self) -> usize {
 82        self.threads.len()
 83    }
 84
 85    pub fn threads(&self) -> Vec<SerializedThreadMetadata> {
 86        let mut threads = self.threads.iter().cloned().collect::<Vec<_>>();
 87        threads.sort_unstable_by_key(|thread| std::cmp::Reverse(thread.updated_at));
 88        threads
 89    }
 90
 91    pub fn recent_threads(&self, limit: usize) -> Vec<SerializedThreadMetadata> {
 92        self.threads().into_iter().take(limit).collect()
 93    }
 94
 95    pub fn create_thread(&mut self, cx: &mut Context<Self>) -> Entity<Thread> {
 96        cx.new(|cx| {
 97            Thread::new(
 98                self.project.clone(),
 99                self.tools.clone(),
100                self.prompt_builder.clone(),
101                cx,
102            )
103        })
104    }
105
106    pub fn open_thread(
107        &self,
108        id: &ThreadId,
109        cx: &mut Context<Self>,
110    ) -> Task<Result<Entity<Thread>>> {
111        let id = id.clone();
112        let database_future = ThreadsDatabase::global_future(cx);
113        cx.spawn(async move |this, cx| {
114            let database = database_future.await.map_err(|err| anyhow!(err))?;
115            let thread = database
116                .try_find_thread(id.clone())
117                .await?
118                .ok_or_else(|| anyhow!("no thread found with ID: {id:?}"))?;
119
120            let thread = this.update(cx, |this, cx| {
121                cx.new(|cx| {
122                    Thread::deserialize(
123                        id.clone(),
124                        thread,
125                        this.project.clone(),
126                        this.tools.clone(),
127                        this.prompt_builder.clone(),
128                        cx,
129                    )
130                })
131            })?;
132
133            let (system_prompt_context, load_error) = thread
134                .update(cx, |thread, cx| thread.load_system_prompt_context(cx))?
135                .await;
136            thread.update(cx, |thread, cx| {
137                thread.set_system_prompt_context(system_prompt_context);
138                if let Some(load_error) = load_error {
139                    cx.emit(ThreadEvent::ShowError(load_error));
140                }
141            })?;
142
143            Ok(thread)
144        })
145    }
146
147    pub fn save_thread(&self, thread: &Entity<Thread>, cx: &mut Context<Self>) -> Task<Result<()>> {
148        let (metadata, serialized_thread) =
149            thread.update(cx, |thread, cx| (thread.id().clone(), thread.serialize(cx)));
150
151        let database_future = ThreadsDatabase::global_future(cx);
152        cx.spawn(async move |this, cx| {
153            let serialized_thread = serialized_thread.await?;
154            let database = database_future.await.map_err(|err| anyhow!(err))?;
155            database.save_thread(metadata, serialized_thread).await?;
156
157            this.update(cx, |this, cx| this.reload(cx))?.await
158        })
159    }
160
161    pub fn delete_thread(&mut self, id: &ThreadId, cx: &mut Context<Self>) -> Task<Result<()>> {
162        let id = id.clone();
163        let database_future = ThreadsDatabase::global_future(cx);
164        cx.spawn(async move |this, cx| {
165            let database = database_future.await.map_err(|err| anyhow!(err))?;
166            database.delete_thread(id.clone()).await?;
167
168            this.update(cx, |this, _cx| {
169                this.threads.retain(|thread| thread.id != id)
170            })
171        })
172    }
173
174    pub fn reload(&self, cx: &mut Context<Self>) -> Task<Result<()>> {
175        let database_future = ThreadsDatabase::global_future(cx);
176        cx.spawn(async move |this, cx| {
177            let threads = database_future
178                .await
179                .map_err(|err| anyhow!(err))?
180                .list_threads()
181                .await?;
182
183            this.update(cx, |this, cx| {
184                this.threads = threads;
185                cx.notify();
186            })
187        })
188    }
189
190    fn load_default_profile(&self, cx: &Context<Self>) {
191        let assistant_settings = AssistantSettings::get_global(cx);
192
193        self.load_profile_by_id(&assistant_settings.default_profile, cx);
194    }
195
196    pub fn load_profile_by_id(&self, profile_id: &Arc<str>, cx: &Context<Self>) {
197        let assistant_settings = AssistantSettings::get_global(cx);
198
199        if let Some(profile) = assistant_settings.profiles.get(profile_id) {
200            self.load_profile(profile);
201        }
202    }
203
204    pub fn load_profile(&self, profile: &AgentProfile) {
205        self.tools.disable_all_tools();
206        self.tools.enable(
207            ToolSource::Native,
208            &profile
209                .tools
210                .iter()
211                .filter_map(|(tool, enabled)| enabled.then(|| tool.clone()))
212                .collect::<Vec<_>>(),
213        );
214
215        for (context_server_id, preset) in &profile.context_servers {
216            self.tools.enable(
217                ToolSource::ContextServer {
218                    id: context_server_id.clone().into(),
219                },
220                &preset
221                    .tools
222                    .iter()
223                    .filter_map(|(tool, enabled)| enabled.then(|| tool.clone()))
224                    .collect::<Vec<_>>(),
225            )
226        }
227    }
228
229    fn register_context_server_handlers(&self, cx: &mut Context<Self>) {
230        cx.subscribe(
231            &self.context_server_manager.clone(),
232            Self::handle_context_server_event,
233        )
234        .detach();
235    }
236
237    fn handle_context_server_event(
238        &mut self,
239        context_server_manager: Entity<ContextServerManager>,
240        event: &context_server::manager::Event,
241        cx: &mut Context<Self>,
242    ) {
243        let tool_working_set = self.tools.clone();
244        match event {
245            context_server::manager::Event::ServerStarted { server_id } => {
246                if let Some(server) = context_server_manager.read(cx).get_server(server_id) {
247                    let context_server_manager = context_server_manager.clone();
248                    cx.spawn({
249                        let server = server.clone();
250                        let server_id = server_id.clone();
251                        async move |this, cx| {
252                            let Some(protocol) = server.client() else {
253                                return;
254                            };
255
256                            if protocol.capable(context_server::protocol::ServerCapability::Tools) {
257                                if let Some(tools) = protocol.list_tools().await.log_err() {
258                                    let tool_ids = tools
259                                        .tools
260                                        .into_iter()
261                                        .map(|tool| {
262                                            log::info!(
263                                                "registering context server tool: {:?}",
264                                                tool.name
265                                            );
266                                            tool_working_set.insert(Arc::new(
267                                                ContextServerTool::new(
268                                                    context_server_manager.clone(),
269                                                    server.id(),
270                                                    tool,
271                                                ),
272                                            ))
273                                        })
274                                        .collect::<Vec<_>>();
275
276                                    this.update(cx, |this, _cx| {
277                                        this.context_server_tool_ids.insert(server_id, tool_ids);
278                                    })
279                                    .log_err();
280                                }
281                            }
282                        }
283                    })
284                    .detach();
285                }
286            }
287            context_server::manager::Event::ServerStopped { server_id } => {
288                if let Some(tool_ids) = self.context_server_tool_ids.remove(server_id) {
289                    tool_working_set.remove(&tool_ids);
290                }
291            }
292        }
293    }
294}
295
/// Summary of a persisted thread used for listing, built by `ThreadsDatabase::list_threads`
/// from each stored thread's key, summary and update time.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SerializedThreadMetadata {
    pub id: ThreadId,
    pub summary: SharedString,
    /// Last update time; `ThreadStore::threads` sorts by this, newest first.
    pub updated_at: DateTime<Utc>,
}
302
/// On-disk representation of a thread (current format, see [`SerializedThread::VERSION`]).
///
/// Field names and attributes define the persisted JSON schema; change with care.
#[derive(Serialize, Deserialize)]
pub struct SerializedThread {
    /// Format version string; checked by `SerializedThread::from_json`.
    pub version: String,
    pub summary: SharedString,
    pub updated_at: DateTime<Utc>,
    pub messages: Vec<SerializedMessage>,
    // `#[serde(default)]` lets records written without this field still deserialize.
    #[serde(default)]
    pub initial_project_snapshot: Option<Arc<ProjectSnapshot>>,
    // Defaults when absent (e.g. records upgraded from the legacy format).
    #[serde(default)]
    pub cumulative_token_usage: TokenUsage,
}
314
315impl SerializedThread {
316    pub const VERSION: &'static str = "0.1.0";
317
318    pub fn from_json(json: &[u8]) -> Result<Self> {
319        let saved_thread_json = serde_json::from_slice::<serde_json::Value>(json)?;
320        match saved_thread_json.get("version") {
321            Some(serde_json::Value::String(version)) => match version.as_str() {
322                SerializedThread::VERSION => Ok(serde_json::from_value::<SerializedThread>(
323                    saved_thread_json,
324                )?),
325                _ => Err(anyhow!(
326                    "unrecognized serialized thread version: {}",
327                    version
328                )),
329            },
330            None => {
331                let saved_thread =
332                    serde_json::from_value::<LegacySerializedThread>(saved_thread_json)?;
333                Ok(saved_thread.upgrade())
334            }
335            version => Err(anyhow!(
336                "unrecognized serialized thread version: {:?}",
337                version
338            )),
339        }
340    }
341}
342
/// One persisted message within a [`SerializedThread`].
#[derive(Debug, Serialize, Deserialize)]
pub struct SerializedMessage {
    pub id: MessageId,
    pub role: Role,
    /// Ordered content segments; defaults to empty when missing.
    #[serde(default)]
    pub segments: Vec<SerializedMessageSegment>,
    #[serde(default)]
    pub tool_uses: Vec<SerializedToolUse>,
    #[serde(default)]
    pub tool_results: Vec<SerializedToolResult>,
}
354
/// One piece of a serialized message's content.
///
/// Internally tagged: serialized as an object whose `"type"` field is `"text"` or
/// `"thinking"` alongside a `"text"` field.
#[derive(Debug, Serialize, Deserialize)]
#[serde(tag = "type")]
pub enum SerializedMessageSegment {
    /// Plain message text.
    #[serde(rename = "text")]
    Text { text: String },
    /// Text of a "thinking" segment (NOTE(review): presumably model reasoning output — confirm).
    #[serde(rename = "thinking")]
    Thinking { text: String },
}
363
/// A persisted tool invocation: the tool's name and its JSON input.
#[derive(Debug, Serialize, Deserialize)]
pub struct SerializedToolUse {
    pub id: LanguageModelToolUseId,
    pub name: SharedString,
    pub input: serde_json::Value,
}
370
/// A persisted tool result, linked to its invocation by `tool_use_id`.
#[derive(Debug, Serialize, Deserialize)]
pub struct SerializedToolResult {
    pub tool_use_id: LanguageModelToolUseId,
    /// Whether the stored result was recorded as an error.
    pub is_error: bool,
    pub content: Arc<str>,
}
377
/// Thread format that predates the `version` field.
///
/// `SerializedThread::from_json` parses documents without a `version` key as this type and
/// upgrades them.
#[derive(Serialize, Deserialize)]
struct LegacySerializedThread {
    pub summary: SharedString,
    pub updated_at: DateTime<Utc>,
    pub messages: Vec<LegacySerializedMessage>,
    #[serde(default)]
    pub initial_project_snapshot: Option<Arc<ProjectSnapshot>>,
}
386
387impl LegacySerializedThread {
388    pub fn upgrade(self) -> SerializedThread {
389        SerializedThread {
390            version: SerializedThread::VERSION.to_string(),
391            summary: self.summary,
392            updated_at: self.updated_at,
393            messages: self.messages.into_iter().map(|msg| msg.upgrade()).collect(),
394            initial_project_snapshot: self.initial_project_snapshot,
395            cumulative_token_usage: TokenUsage::default(),
396        }
397    }
398}
399
/// Message format that predates segments: content is a single `text` string.
#[derive(Debug, Serialize, Deserialize)]
struct LegacySerializedMessage {
    pub id: MessageId,
    pub role: Role,
    pub text: String,
    #[serde(default)]
    pub tool_uses: Vec<SerializedToolUse>,
    #[serde(default)]
    pub tool_results: Vec<SerializedToolResult>,
}
410
411impl LegacySerializedMessage {
412    fn upgrade(self) -> SerializedMessage {
413        SerializedMessage {
414            id: self.id,
415            role: self.role,
416            segments: vec![SerializedMessageSegment::Text { text: self.text }],
417            tool_uses: self.tool_uses,
418            tool_results: self.tool_results,
419        }
420    }
421}
422
/// App-global holding the shared future that resolves to the threads database.
///
/// Both the success and error values are `Arc`-wrapped so the `Shared` future's output can be
/// cloned to every awaiter.
struct GlobalThreadsDatabase(
    Shared<BoxFuture<'static, Result<Arc<ThreadsDatabase>, Arc<anyhow::Error>>>>,
);

impl Global for GlobalThreadsDatabase {}
428
/// Persistent store of serialized threads backed by a heed (LMDB) environment.
pub(crate) struct ThreadsDatabase {
    /// Executor on which all transactions are run.
    executor: BackgroundExecutor,
    env: heed::Env,
    /// `threads` database: bincode-encoded [`ThreadId`] keys, JSON-encoded
    /// [`SerializedThread`] values (see the `BytesEncode`/`BytesDecode` impls).
    threads: Database<SerdeBincode<ThreadId>, SerializedThread>,
}
434
435impl heed::BytesEncode<'_> for SerializedThread {
436    type EItem = SerializedThread;
437
438    fn bytes_encode(item: &Self::EItem) -> Result<Cow<[u8]>, heed::BoxedError> {
439        serde_json::to_vec(item).map(Cow::Owned).map_err(Into::into)
440    }
441}
442
443impl<'a> heed::BytesDecode<'a> for SerializedThread {
444    type DItem = SerializedThread;
445
446    fn bytes_decode(bytes: &'a [u8]) -> Result<Self::DItem, heed::BoxedError> {
447        // We implement this type manually because we want to call `SerializedThread::from_json`,
448        // instead of the Deserialize trait implementation for `SerializedThread`.
449        SerializedThread::from_json(bytes).map_err(Into::into)
450    }
451}
452
453impl ThreadsDatabase {
454    fn global_future(
455        cx: &mut App,
456    ) -> Shared<BoxFuture<'static, Result<Arc<ThreadsDatabase>, Arc<anyhow::Error>>>> {
457        GlobalThreadsDatabase::global(cx).0.clone()
458    }
459
460    fn init(cx: &mut App) {
461        let executor = cx.background_executor().clone();
462        let database_future = executor
463            .spawn({
464                let executor = executor.clone();
465                let database_path = paths::support_dir().join("threads/threads-db.1.mdb");
466                async move { ThreadsDatabase::new(database_path, executor) }
467            })
468            .then(|result| future::ready(result.map(Arc::new).map_err(Arc::new)))
469            .boxed()
470            .shared();
471
472        cx.set_global(GlobalThreadsDatabase(database_future));
473    }
474
475    pub fn new(path: PathBuf, executor: BackgroundExecutor) -> Result<Self> {
476        std::fs::create_dir_all(&path)?;
477
478        const ONE_GB_IN_BYTES: usize = 1024 * 1024 * 1024;
479        let env = unsafe {
480            heed::EnvOpenOptions::new()
481                .map_size(ONE_GB_IN_BYTES)
482                .max_dbs(1)
483                .open(path)?
484        };
485
486        let mut txn = env.write_txn()?;
487        let threads = env.create_database(&mut txn, Some("threads"))?;
488        txn.commit()?;
489
490        Ok(Self {
491            executor,
492            env,
493            threads,
494        })
495    }
496
497    pub fn list_threads(&self) -> Task<Result<Vec<SerializedThreadMetadata>>> {
498        let env = self.env.clone();
499        let threads = self.threads;
500
501        self.executor.spawn(async move {
502            let txn = env.read_txn()?;
503            let mut iter = threads.iter(&txn)?;
504            let mut threads = Vec::new();
505            while let Some((key, value)) = iter.next().transpose()? {
506                threads.push(SerializedThreadMetadata {
507                    id: key,
508                    summary: value.summary,
509                    updated_at: value.updated_at,
510                });
511            }
512
513            Ok(threads)
514        })
515    }
516
517    pub fn try_find_thread(&self, id: ThreadId) -> Task<Result<Option<SerializedThread>>> {
518        let env = self.env.clone();
519        let threads = self.threads;
520
521        self.executor.spawn(async move {
522            let txn = env.read_txn()?;
523            let thread = threads.get(&txn, &id)?;
524            Ok(thread)
525        })
526    }
527
528    pub fn save_thread(&self, id: ThreadId, thread: SerializedThread) -> Task<Result<()>> {
529        let env = self.env.clone();
530        let threads = self.threads;
531
532        self.executor.spawn(async move {
533            let mut txn = env.write_txn()?;
534            threads.put(&mut txn, &id, &thread)?;
535            txn.commit()?;
536            Ok(())
537        })
538    }
539
540    pub fn delete_thread(&self, id: ThreadId) -> Task<Result<()>> {
541        let env = self.env.clone();
542        let threads = self.threads;
543
544        self.executor.spawn(async move {
545            let mut txn = env.write_txn()?;
546            threads.delete(&mut txn, &id)?;
547            txn.commit()?;
548            Ok(())
549        })
550    }
551}