thread_store.rs

  1use std::borrow::Cow;
  2use std::path::PathBuf;
  3use std::sync::Arc;
  4
  5use anyhow::{Result, anyhow};
  6use assistant_settings::{AgentProfile, AssistantSettings};
  7use assistant_tool::{ToolId, ToolSource, ToolWorkingSet};
  8use chrono::{DateTime, Utc};
  9use collections::HashMap;
 10use context_server::manager::ContextServerManager;
 11use context_server::{ContextServerFactoryRegistry, ContextServerTool};
 12use futures::FutureExt as _;
 13use futures::future::{self, BoxFuture, Shared};
 14use gpui::{
 15    App, BackgroundExecutor, Context, Entity, Global, ReadGlobal, SharedString, Subscription, Task,
 16    prelude::*,
 17};
 18use heed::Database;
 19use heed::types::SerdeBincode;
 20use language_model::{LanguageModelToolUseId, Role, TokenUsage};
 21use project::Project;
 22use prompt_store::PromptBuilder;
 23use serde::{Deserialize, Serialize};
 24use settings::{Settings as _, SettingsStore};
 25use util::ResultExt as _;
 26
 27use crate::thread::{MessageId, ProjectSnapshot, Thread, ThreadEvent, ThreadId};
 28
/// Initializes thread persistence by delegating to `ThreadsDatabase::init`, which
/// starts opening the threads database and registers the resulting future as a global.
pub fn init(cx: &mut App) {
    ThreadsDatabase::init(cx);
}
 32
/// Creates, opens, saves and deletes threads for a project, and keeps the shared tool
/// working set in sync with the active agent profile and running context servers.
pub struct ThreadStore {
    /// Project handed to every thread this store creates or deserializes, and to the
    /// context server manager.
    project: Entity<Project>,
    /// Tool working set shared with threads; profiles enable/disable tools in it and
    /// context server tools are inserted into / removed from it.
    tools: Arc<ToolWorkingSet>,
    /// Prompt builder handed to every thread this store creates or deserializes.
    prompt_builder: Arc<PromptBuilder>,
    /// Manager whose start/stop events drive context server tool registration.
    context_server_manager: Entity<ContextServerManager>,
    /// Tool IDs registered per context server ID, kept so they can be removed from the
    /// working set when that server stops.
    context_server_tool_ids: HashMap<Arc<str>, Vec<ToolId>>,
    /// Metadata of the saved threads, as last loaded from the database by `reload`.
    threads: Vec<SerializedThreadMetadata>,
    /// Subscriptions created in `new` (currently the settings observer), stored for the
    /// lifetime of the store.
    _subscriptions: Vec<Subscription>,
}
 42
 43impl ThreadStore {
 44    pub fn new(
 45        project: Entity<Project>,
 46        tools: Arc<ToolWorkingSet>,
 47        prompt_builder: Arc<PromptBuilder>,
 48        cx: &mut App,
 49    ) -> Result<Entity<Self>> {
 50        let this = cx.new(|cx| {
 51            let context_server_factory_registry = ContextServerFactoryRegistry::default_global(cx);
 52            let context_server_manager = cx.new(|cx| {
 53                ContextServerManager::new(context_server_factory_registry, project.clone(), cx)
 54            });
 55            let settings_subscription =
 56                cx.observe_global::<SettingsStore>(move |this: &mut Self, cx| {
 57                    this.load_default_profile(cx);
 58                });
 59
 60            let this = Self {
 61                project,
 62                tools,
 63                prompt_builder,
 64                context_server_manager,
 65                context_server_tool_ids: HashMap::default(),
 66                threads: Vec::new(),
 67                _subscriptions: vec![settings_subscription],
 68            };
 69            this.load_default_profile(cx);
 70            this.register_context_server_handlers(cx);
 71            this.reload(cx).detach_and_log_err(cx);
 72
 73            this
 74        });
 75
 76        Ok(this)
 77    }
 78
    /// Returns a clone of the handle to this store's context server manager.
    pub fn context_server_manager(&self) -> Entity<ContextServerManager> {
        self.context_server_manager.clone()
    }
 82
    /// Returns a clone of the `Arc` pointing at the shared tool working set.
    pub fn tools(&self) -> Arc<ToolWorkingSet> {
        self.tools.clone()
    }
 86
    /// Returns the number of threads in the store's cached metadata list.
    ///
    /// This reflects the list as of the last completed `reload` (or deletion), not a
    /// fresh database query.
    pub fn thread_count(&self) -> usize {
        self.threads.len()
    }
 91
 92    pub fn threads(&self) -> Vec<SerializedThreadMetadata> {
 93        let mut threads = self.threads.iter().cloned().collect::<Vec<_>>();
 94        threads.sort_unstable_by_key(|thread| std::cmp::Reverse(thread.updated_at));
 95        threads
 96    }
 97
 98    pub fn recent_threads(&self, limit: usize) -> Vec<SerializedThreadMetadata> {
 99        self.threads().into_iter().take(limit).collect()
100    }
101
102    pub fn create_thread(&mut self, cx: &mut Context<Self>) -> Entity<Thread> {
103        cx.new(|cx| {
104            Thread::new(
105                self.project.clone(),
106                self.tools.clone(),
107                self.prompt_builder.clone(),
108                cx,
109            )
110        })
111    }
112
113    pub fn open_thread(
114        &self,
115        id: &ThreadId,
116        cx: &mut Context<Self>,
117    ) -> Task<Result<Entity<Thread>>> {
118        let id = id.clone();
119        let database_future = ThreadsDatabase::global_future(cx);
120        cx.spawn(async move |this, cx| {
121            let database = database_future.await.map_err(|err| anyhow!(err))?;
122            let thread = database
123                .try_find_thread(id.clone())
124                .await?
125                .ok_or_else(|| anyhow!("no thread found with ID: {id:?}"))?;
126
127            let thread = this.update(cx, |this, cx| {
128                cx.new(|cx| {
129                    Thread::deserialize(
130                        id.clone(),
131                        thread,
132                        this.project.clone(),
133                        this.tools.clone(),
134                        this.prompt_builder.clone(),
135                        cx,
136                    )
137                })
138            })?;
139
140            let (system_prompt_context, load_error) = thread
141                .update(cx, |thread, cx| thread.load_system_prompt_context(cx))?
142                .await;
143            thread.update(cx, |thread, cx| {
144                thread.set_system_prompt_context(system_prompt_context);
145                if let Some(load_error) = load_error {
146                    cx.emit(ThreadEvent::ShowError(load_error));
147                }
148            })?;
149
150            Ok(thread)
151        })
152    }
153
154    pub fn save_thread(&self, thread: &Entity<Thread>, cx: &mut Context<Self>) -> Task<Result<()>> {
155        let (metadata, serialized_thread) =
156            thread.update(cx, |thread, cx| (thread.id().clone(), thread.serialize(cx)));
157
158        let database_future = ThreadsDatabase::global_future(cx);
159        cx.spawn(async move |this, cx| {
160            let serialized_thread = serialized_thread.await?;
161            let database = database_future.await.map_err(|err| anyhow!(err))?;
162            database.save_thread(metadata, serialized_thread).await?;
163
164            this.update(cx, |this, cx| this.reload(cx))?.await
165        })
166    }
167
168    pub fn delete_thread(&mut self, id: &ThreadId, cx: &mut Context<Self>) -> Task<Result<()>> {
169        let id = id.clone();
170        let database_future = ThreadsDatabase::global_future(cx);
171        cx.spawn(async move |this, cx| {
172            let database = database_future.await.map_err(|err| anyhow!(err))?;
173            database.delete_thread(id.clone()).await?;
174
175            this.update(cx, |this, _cx| {
176                this.threads.retain(|thread| thread.id != id)
177            })
178        })
179    }
180
181    pub fn reload(&self, cx: &mut Context<Self>) -> Task<Result<()>> {
182        let database_future = ThreadsDatabase::global_future(cx);
183        cx.spawn(async move |this, cx| {
184            let threads = database_future
185                .await
186                .map_err(|err| anyhow!(err))?
187                .list_threads()
188                .await?;
189
190            this.update(cx, |this, cx| {
191                this.threads = threads;
192                cx.notify();
193            })
194        })
195    }
196
197    fn load_default_profile(&self, cx: &Context<Self>) {
198        let assistant_settings = AssistantSettings::get_global(cx);
199
200        self.load_profile_by_id(&assistant_settings.default_profile, cx);
201    }
202
203    pub fn load_profile_by_id(&self, profile_id: &Arc<str>, cx: &Context<Self>) {
204        let assistant_settings = AssistantSettings::get_global(cx);
205
206        if let Some(profile) = assistant_settings.profiles.get(profile_id) {
207            self.load_profile(profile, cx);
208        }
209    }
210
211    pub fn load_profile(&self, profile: &AgentProfile, cx: &Context<Self>) {
212        self.tools.disable_all_tools();
213        self.tools.enable(
214            ToolSource::Native,
215            &profile
216                .tools
217                .iter()
218                .filter_map(|(tool, enabled)| enabled.then(|| tool.clone()))
219                .collect::<Vec<_>>(),
220        );
221
222        if profile.enable_all_context_servers {
223            for context_server in self.context_server_manager.read(cx).all_servers() {
224                self.tools.enable_source(
225                    ToolSource::ContextServer {
226                        id: context_server.id().into(),
227                    },
228                    cx,
229                );
230            }
231        } else {
232            for (context_server_id, preset) in &profile.context_servers {
233                self.tools.enable(
234                    ToolSource::ContextServer {
235                        id: context_server_id.clone().into(),
236                    },
237                    &preset
238                        .tools
239                        .iter()
240                        .filter_map(|(tool, enabled)| enabled.then(|| tool.clone()))
241                        .collect::<Vec<_>>(),
242                )
243            }
244        }
245    }
246
247    fn register_context_server_handlers(&self, cx: &mut Context<Self>) {
248        cx.subscribe(
249            &self.context_server_manager.clone(),
250            Self::handle_context_server_event,
251        )
252        .detach();
253    }
254
255    fn handle_context_server_event(
256        &mut self,
257        context_server_manager: Entity<ContextServerManager>,
258        event: &context_server::manager::Event,
259        cx: &mut Context<Self>,
260    ) {
261        let tool_working_set = self.tools.clone();
262        match event {
263            context_server::manager::Event::ServerStarted { server_id } => {
264                if let Some(server) = context_server_manager.read(cx).get_server(server_id) {
265                    let context_server_manager = context_server_manager.clone();
266                    cx.spawn({
267                        let server = server.clone();
268                        let server_id = server_id.clone();
269                        async move |this, cx| {
270                            let Some(protocol) = server.client() else {
271                                return;
272                            };
273
274                            if protocol.capable(context_server::protocol::ServerCapability::Tools) {
275                                if let Some(tools) = protocol.list_tools().await.log_err() {
276                                    let tool_ids = tools
277                                        .tools
278                                        .into_iter()
279                                        .map(|tool| {
280                                            log::info!(
281                                                "registering context server tool: {:?}",
282                                                tool.name
283                                            );
284                                            tool_working_set.insert(Arc::new(
285                                                ContextServerTool::new(
286                                                    context_server_manager.clone(),
287                                                    server.id(),
288                                                    tool,
289                                                ),
290                                            ))
291                                        })
292                                        .collect::<Vec<_>>();
293
294                                    this.update(cx, |this, cx| {
295                                        this.context_server_tool_ids.insert(server_id, tool_ids);
296                                        this.load_default_profile(cx);
297                                    })
298                                    .log_err();
299                                }
300                            }
301                        }
302                    })
303                    .detach();
304                }
305            }
306            context_server::manager::Event::ServerStopped { server_id } => {
307                if let Some(tool_ids) = self.context_server_tool_ids.remove(server_id) {
308                    tool_working_set.remove(&tool_ids);
309                    self.load_default_profile(cx);
310                }
311            }
312        }
313    }
314}
315
/// Lightweight description of a saved thread: its database key plus the summary and
/// last-update time copied from the stored thread.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SerializedThreadMetadata {
    /// Database key of the thread.
    pub id: ThreadId,
    /// Summary copied from the stored thread.
    pub summary: SharedString,
    /// Last-update timestamp copied from the stored thread; used to order listings.
    pub updated_at: DateTime<Utc>,
}
322
/// Persisted form of a thread, stored as JSON in the threads database.
#[derive(Serialize, Deserialize)]
pub struct SerializedThread {
    /// Format version; `from_json` only accepts `SerializedThread::VERSION` here.
    pub version: String,
    /// Thread summary.
    pub summary: SharedString,
    /// Time of the last update to the thread.
    pub updated_at: DateTime<Utc>,
    /// Messages of the thread.
    pub messages: Vec<SerializedMessage>,
    /// Optional project snapshot; defaults to `None` when absent from the stored data.
    #[serde(default)]
    pub initial_project_snapshot: Option<Arc<ProjectSnapshot>>,
    /// Accumulated token usage; defaults to `TokenUsage::default()` when absent.
    #[serde(default)]
    pub cumulative_token_usage: TokenUsage,
}
334
335impl SerializedThread {
336    pub const VERSION: &'static str = "0.1.0";
337
338    pub fn from_json(json: &[u8]) -> Result<Self> {
339        let saved_thread_json = serde_json::from_slice::<serde_json::Value>(json)?;
340        match saved_thread_json.get("version") {
341            Some(serde_json::Value::String(version)) => match version.as_str() {
342                SerializedThread::VERSION => Ok(serde_json::from_value::<SerializedThread>(
343                    saved_thread_json,
344                )?),
345                _ => Err(anyhow!(
346                    "unrecognized serialized thread version: {}",
347                    version
348                )),
349            },
350            None => {
351                let saved_thread =
352                    serde_json::from_value::<LegacySerializedThread>(saved_thread_json)?;
353                Ok(saved_thread.upgrade())
354            }
355            version => Err(anyhow!(
356                "unrecognized serialized thread version: {:?}",
357                version
358            )),
359        }
360    }
361}
362
/// Persisted form of a single message in a thread.
#[derive(Debug, Serialize, Deserialize)]
pub struct SerializedMessage {
    /// Identifier of the message.
    pub id: MessageId,
    /// Role of the message's author.
    pub role: Role,
    /// Content segments of the message; defaults to empty when absent.
    #[serde(default)]
    pub segments: Vec<SerializedMessageSegment>,
    /// Tool uses attached to the message; defaults to empty when absent.
    #[serde(default)]
    pub tool_uses: Vec<SerializedToolUse>,
    /// Tool results attached to the message; defaults to empty when absent.
    #[serde(default)]
    pub tool_results: Vec<SerializedToolResult>,
}
374
/// One segment of a message's content.
///
/// Serialized with a `"type"` tag field whose value is `"text"` or `"thinking"`.
#[derive(Debug, Serialize, Deserialize)]
#[serde(tag = "type")]
pub enum SerializedMessageSegment {
    /// A text segment.
    #[serde(rename = "text")]
    Text { text: String },
    /// A thinking segment.
    #[serde(rename = "thinking")]
    Thinking { text: String },
}
383
/// Persisted record of a tool use attached to a message.
#[derive(Debug, Serialize, Deserialize)]
pub struct SerializedToolUse {
    /// Identifier of this tool use.
    pub id: LanguageModelToolUseId,
    /// Name of the tool.
    pub name: SharedString,
    /// Tool input as an arbitrary JSON value.
    pub input: serde_json::Value,
}
390
/// Persisted record of a tool result attached to a message.
#[derive(Debug, Serialize, Deserialize)]
pub struct SerializedToolResult {
    /// Identifier of the tool use this result belongs to.
    pub tool_use_id: LanguageModelToolUseId,
    /// Whether the result represents an error.
    pub is_error: bool,
    /// Result content.
    pub content: Arc<str>,
}
397
/// Thread format used for stored threads that have no `version` field; converted to
/// `SerializedThread` via `upgrade`.
#[derive(Serialize, Deserialize)]
struct LegacySerializedThread {
    /// Thread summary.
    pub summary: SharedString,
    /// Time of the last update to the thread.
    pub updated_at: DateTime<Utc>,
    /// Messages in the legacy single-text format.
    pub messages: Vec<LegacySerializedMessage>,
    /// Optional project snapshot; defaults to `None` when absent from the stored data.
    #[serde(default)]
    pub initial_project_snapshot: Option<Arc<ProjectSnapshot>>,
}
406
407impl LegacySerializedThread {
408    pub fn upgrade(self) -> SerializedThread {
409        SerializedThread {
410            version: SerializedThread::VERSION.to_string(),
411            summary: self.summary,
412            updated_at: self.updated_at,
413            messages: self.messages.into_iter().map(|msg| msg.upgrade()).collect(),
414            initial_project_snapshot: self.initial_project_snapshot,
415            cumulative_token_usage: TokenUsage::default(),
416        }
417    }
418}
419
/// Message format used inside `LegacySerializedThread`: content is a single `text`
/// string instead of a list of segments.
#[derive(Debug, Serialize, Deserialize)]
struct LegacySerializedMessage {
    /// Identifier of the message.
    pub id: MessageId,
    /// Role of the message's author.
    pub role: Role,
    /// Entire message content as one string.
    pub text: String,
    /// Tool uses attached to the message; defaults to empty when absent.
    #[serde(default)]
    pub tool_uses: Vec<SerializedToolUse>,
    /// Tool results attached to the message; defaults to empty when absent.
    #[serde(default)]
    pub tool_results: Vec<SerializedToolResult>,
}
430
431impl LegacySerializedMessage {
432    fn upgrade(self) -> SerializedMessage {
433        SerializedMessage {
434            id: self.id,
435            role: self.role,
436            segments: vec![SerializedMessageSegment::Text { text: self.text }],
437            tool_uses: self.tool_uses,
438            tool_results: self.tool_results,
439        }
440    }
441}
442
/// Global holding the shared future that resolves to the opened `ThreadsDatabase`, or
/// to the (reference-counted) error produced while opening it.
struct GlobalThreadsDatabase(
    Shared<BoxFuture<'static, Result<Arc<ThreadsDatabase>, Arc<anyhow::Error>>>>,
);

// Marks the wrapper as a gpui global so it can be stored via `set_global`.
impl Global for GlobalThreadsDatabase {}
448
/// Handle to the on-disk threads database.
pub(crate) struct ThreadsDatabase {
    /// Executor on which every database operation is spawned.
    executor: BackgroundExecutor,
    /// The heed environment backing the database.
    env: heed::Env,
    /// The `"threads"` database: keys are bincode-encoded `ThreadId`s, values are
    /// `SerializedThread`s encoded as JSON.
    threads: Database<SerdeBincode<ThreadId>, SerializedThread>,
}
454
455impl heed::BytesEncode<'_> for SerializedThread {
456    type EItem = SerializedThread;
457
458    fn bytes_encode(item: &Self::EItem) -> Result<Cow<[u8]>, heed::BoxedError> {
459        serde_json::to_vec(item).map(Cow::Owned).map_err(Into::into)
460    }
461}
462
463impl<'a> heed::BytesDecode<'a> for SerializedThread {
464    type DItem = SerializedThread;
465
466    fn bytes_decode(bytes: &'a [u8]) -> Result<Self::DItem, heed::BoxedError> {
467        // We implement this type manually because we want to call `SerializedThread::from_json`,
468        // instead of the Deserialize trait implementation for `SerializedThread`.
469        SerializedThread::from_json(bytes).map_err(Into::into)
470    }
471}
472
473impl ThreadsDatabase {
    /// Returns a clone of the shared database-opening future stored in the
    /// `GlobalThreadsDatabase` global.
    ///
    /// NOTE(review): presumably requires `ThreadsDatabase::init` to have run first so
    /// that the global exists — confirm the behavior of `ReadGlobal::global` when unset.
    fn global_future(
        cx: &mut App,
    ) -> Shared<BoxFuture<'static, Result<Arc<ThreadsDatabase>, Arc<anyhow::Error>>>> {
        GlobalThreadsDatabase::global(cx).0.clone()
    }
479
480    fn init(cx: &mut App) {
481        let executor = cx.background_executor().clone();
482        let database_future = executor
483            .spawn({
484                let executor = executor.clone();
485                let database_path = paths::support_dir().join("threads/threads-db.1.mdb");
486                async move { ThreadsDatabase::new(database_path, executor) }
487            })
488            .then(|result| future::ready(result.map(Arc::new).map_err(Arc::new)))
489            .boxed()
490            .shared();
491
492        cx.set_global(GlobalThreadsDatabase(database_future));
493    }
494
495    pub fn new(path: PathBuf, executor: BackgroundExecutor) -> Result<Self> {
496        std::fs::create_dir_all(&path)?;
497
498        const ONE_GB_IN_BYTES: usize = 1024 * 1024 * 1024;
499        let env = unsafe {
500            heed::EnvOpenOptions::new()
501                .map_size(ONE_GB_IN_BYTES)
502                .max_dbs(1)
503                .open(path)?
504        };
505
506        let mut txn = env.write_txn()?;
507        let threads = env.create_database(&mut txn, Some("threads"))?;
508        txn.commit()?;
509
510        Ok(Self {
511            executor,
512            env,
513            threads,
514        })
515    }
516
517    pub fn list_threads(&self) -> Task<Result<Vec<SerializedThreadMetadata>>> {
518        let env = self.env.clone();
519        let threads = self.threads;
520
521        self.executor.spawn(async move {
522            let txn = env.read_txn()?;
523            let mut iter = threads.iter(&txn)?;
524            let mut threads = Vec::new();
525            while let Some((key, value)) = iter.next().transpose()? {
526                threads.push(SerializedThreadMetadata {
527                    id: key,
528                    summary: value.summary,
529                    updated_at: value.updated_at,
530                });
531            }
532
533            Ok(threads)
534        })
535    }
536
537    pub fn try_find_thread(&self, id: ThreadId) -> Task<Result<Option<SerializedThread>>> {
538        let env = self.env.clone();
539        let threads = self.threads;
540
541        self.executor.spawn(async move {
542            let txn = env.read_txn()?;
543            let thread = threads.get(&txn, &id)?;
544            Ok(thread)
545        })
546    }
547
548    pub fn save_thread(&self, id: ThreadId, thread: SerializedThread) -> Task<Result<()>> {
549        let env = self.env.clone();
550        let threads = self.threads;
551
552        self.executor.spawn(async move {
553            let mut txn = env.write_txn()?;
554            threads.put(&mut txn, &id, &thread)?;
555            txn.commit()?;
556            Ok(())
557        })
558    }
559
560    pub fn delete_thread(&self, id: ThreadId) -> Task<Result<()>> {
561        let env = self.env.clone();
562        let threads = self.threads;
563
564        self.executor.spawn(async move {
565            let mut txn = env.write_txn()?;
566            threads.delete(&mut txn, &id)?;
567            txn.commit()?;
568            Ok(())
569        })
570    }
571}