thread_store.rs

  1use std::borrow::Cow;
  2use std::path::PathBuf;
  3use std::sync::Arc;
  4
  5use anyhow::{Result, anyhow};
  6use assistant_settings::{AgentProfile, AgentProfileId, AssistantSettings};
  7use assistant_tool::{ToolId, ToolSource, ToolWorkingSet};
  8use chrono::{DateTime, Utc};
  9use collections::HashMap;
 10use context_server::manager::ContextServerManager;
 11use context_server::{ContextServerFactoryRegistry, ContextServerTool};
 12use futures::FutureExt as _;
 13use futures::future::{self, BoxFuture, Shared};
 14use gpui::{
 15    App, BackgroundExecutor, Context, Entity, Global, ReadGlobal, SharedString, Subscription, Task,
 16    prelude::*,
 17};
 18use heed::Database;
 19use heed::types::SerdeBincode;
 20use language_model::{LanguageModelToolUseId, Role, TokenUsage};
 21use project::Project;
 22use prompt_store::PromptBuilder;
 23use serde::{Deserialize, Serialize};
 24use settings::{Settings as _, SettingsStore};
 25use util::ResultExt as _;
 26
 27use crate::thread::{
 28    DetailedSummaryState, MessageId, ProjectSnapshot, Thread, ThreadEvent, ThreadId,
 29};
 30
 31pub fn init(cx: &mut App) {
 32    ThreadsDatabase::init(cx);
 33}
 34
 35pub struct ThreadStore {
 36    project: Entity<Project>,
 37    tools: Arc<ToolWorkingSet>,
 38    prompt_builder: Arc<PromptBuilder>,
 39    context_server_manager: Entity<ContextServerManager>,
 40    context_server_tool_ids: HashMap<Arc<str>, Vec<ToolId>>,
 41    threads: Vec<SerializedThreadMetadata>,
 42    _subscriptions: Vec<Subscription>,
 43}
 44
 45impl ThreadStore {
 46    pub fn new(
 47        project: Entity<Project>,
 48        tools: Arc<ToolWorkingSet>,
 49        prompt_builder: Arc<PromptBuilder>,
 50        cx: &mut App,
 51    ) -> Result<Entity<Self>> {
 52        let this = cx.new(|cx| {
 53            let context_server_factory_registry = ContextServerFactoryRegistry::default_global(cx);
 54            let context_server_manager = cx.new(|cx| {
 55                ContextServerManager::new(context_server_factory_registry, project.clone(), cx)
 56            });
 57            let settings_subscription =
 58                cx.observe_global::<SettingsStore>(move |this: &mut Self, cx| {
 59                    this.load_default_profile(cx);
 60                });
 61
 62            let this = Self {
 63                project,
 64                tools,
 65                prompt_builder,
 66                context_server_manager,
 67                context_server_tool_ids: HashMap::default(),
 68                threads: Vec::new(),
 69                _subscriptions: vec![settings_subscription],
 70            };
 71            this.load_default_profile(cx);
 72            this.register_context_server_handlers(cx);
 73            this.reload(cx).detach_and_log_err(cx);
 74
 75            this
 76        });
 77
 78        Ok(this)
 79    }
 80
 81    pub fn context_server_manager(&self) -> Entity<ContextServerManager> {
 82        self.context_server_manager.clone()
 83    }
 84
 85    pub fn tools(&self) -> Arc<ToolWorkingSet> {
 86        self.tools.clone()
 87    }
 88
 89    /// Returns the number of threads.
 90    pub fn thread_count(&self) -> usize {
 91        self.threads.len()
 92    }
 93
 94    pub fn threads(&self) -> Vec<SerializedThreadMetadata> {
 95        let mut threads = self.threads.iter().cloned().collect::<Vec<_>>();
 96        threads.sort_unstable_by_key(|thread| std::cmp::Reverse(thread.updated_at));
 97        threads
 98    }
 99
100    pub fn recent_threads(&self, limit: usize) -> Vec<SerializedThreadMetadata> {
101        self.threads().into_iter().take(limit).collect()
102    }
103
104    pub fn create_thread(&mut self, cx: &mut Context<Self>) -> Entity<Thread> {
105        cx.new(|cx| {
106            Thread::new(
107                self.project.clone(),
108                self.tools.clone(),
109                self.prompt_builder.clone(),
110                cx,
111            )
112        })
113    }
114
115    pub fn open_thread(
116        &self,
117        id: &ThreadId,
118        cx: &mut Context<Self>,
119    ) -> Task<Result<Entity<Thread>>> {
120        let id = id.clone();
121        let database_future = ThreadsDatabase::global_future(cx);
122        cx.spawn(async move |this, cx| {
123            let database = database_future.await.map_err(|err| anyhow!(err))?;
124            let thread = database
125                .try_find_thread(id.clone())
126                .await?
127                .ok_or_else(|| anyhow!("no thread found with ID: {id:?}"))?;
128
129            let thread = this.update(cx, |this, cx| {
130                cx.new(|cx| {
131                    Thread::deserialize(
132                        id.clone(),
133                        thread,
134                        this.project.clone(),
135                        this.tools.clone(),
136                        this.prompt_builder.clone(),
137                        cx,
138                    )
139                })
140            })?;
141
142            let (system_prompt_context, load_error) = thread
143                .update(cx, |thread, cx| thread.load_system_prompt_context(cx))?
144                .await;
145            thread.update(cx, |thread, cx| {
146                thread.set_system_prompt_context(system_prompt_context);
147                if let Some(load_error) = load_error {
148                    cx.emit(ThreadEvent::ShowError(load_error));
149                }
150            })?;
151
152            Ok(thread)
153        })
154    }
155
156    pub fn save_thread(&self, thread: &Entity<Thread>, cx: &mut Context<Self>) -> Task<Result<()>> {
157        let (metadata, serialized_thread) =
158            thread.update(cx, |thread, cx| (thread.id().clone(), thread.serialize(cx)));
159
160        let database_future = ThreadsDatabase::global_future(cx);
161        cx.spawn(async move |this, cx| {
162            let serialized_thread = serialized_thread.await?;
163            let database = database_future.await.map_err(|err| anyhow!(err))?;
164            database.save_thread(metadata, serialized_thread).await?;
165
166            this.update(cx, |this, cx| this.reload(cx))?.await
167        })
168    }
169
170    pub fn delete_thread(&mut self, id: &ThreadId, cx: &mut Context<Self>) -> Task<Result<()>> {
171        let id = id.clone();
172        let database_future = ThreadsDatabase::global_future(cx);
173        cx.spawn(async move |this, cx| {
174            let database = database_future.await.map_err(|err| anyhow!(err))?;
175            database.delete_thread(id.clone()).await?;
176
177            this.update(cx, |this, _cx| {
178                this.threads.retain(|thread| thread.id != id)
179            })
180        })
181    }
182
183    pub fn reload(&self, cx: &mut Context<Self>) -> Task<Result<()>> {
184        let database_future = ThreadsDatabase::global_future(cx);
185        cx.spawn(async move |this, cx| {
186            let threads = database_future
187                .await
188                .map_err(|err| anyhow!(err))?
189                .list_threads()
190                .await?;
191
192            this.update(cx, |this, cx| {
193                this.threads = threads;
194                cx.notify();
195            })
196        })
197    }
198
199    fn load_default_profile(&self, cx: &Context<Self>) {
200        let assistant_settings = AssistantSettings::get_global(cx);
201
202        self.load_profile_by_id(&assistant_settings.default_profile, cx);
203    }
204
205    pub fn load_profile_by_id(&self, profile_id: &AgentProfileId, cx: &Context<Self>) {
206        let assistant_settings = AssistantSettings::get_global(cx);
207
208        if let Some(profile) = assistant_settings.profiles.get(profile_id) {
209            self.load_profile(profile, cx);
210        }
211    }
212
213    pub fn load_profile(&self, profile: &AgentProfile, cx: &Context<Self>) {
214        self.tools.disable_all_tools();
215        self.tools.enable(
216            ToolSource::Native,
217            &profile
218                .tools
219                .iter()
220                .filter_map(|(tool, enabled)| enabled.then(|| tool.clone()))
221                .collect::<Vec<_>>(),
222        );
223
224        if profile.enable_all_context_servers {
225            for context_server in self.context_server_manager.read(cx).all_servers() {
226                self.tools.enable_source(
227                    ToolSource::ContextServer {
228                        id: context_server.id().into(),
229                    },
230                    cx,
231                );
232            }
233        } else {
234            for (context_server_id, preset) in &profile.context_servers {
235                self.tools.enable(
236                    ToolSource::ContextServer {
237                        id: context_server_id.clone().into(),
238                    },
239                    &preset
240                        .tools
241                        .iter()
242                        .filter_map(|(tool, enabled)| enabled.then(|| tool.clone()))
243                        .collect::<Vec<_>>(),
244                )
245            }
246        }
247    }
248
249    fn register_context_server_handlers(&self, cx: &mut Context<Self>) {
250        cx.subscribe(
251            &self.context_server_manager.clone(),
252            Self::handle_context_server_event,
253        )
254        .detach();
255    }
256
257    fn handle_context_server_event(
258        &mut self,
259        context_server_manager: Entity<ContextServerManager>,
260        event: &context_server::manager::Event,
261        cx: &mut Context<Self>,
262    ) {
263        let tool_working_set = self.tools.clone();
264        match event {
265            context_server::manager::Event::ServerStarted { server_id } => {
266                if let Some(server) = context_server_manager.read(cx).get_server(server_id) {
267                    let context_server_manager = context_server_manager.clone();
268                    cx.spawn({
269                        let server = server.clone();
270                        let server_id = server_id.clone();
271                        async move |this, cx| {
272                            let Some(protocol) = server.client() else {
273                                return;
274                            };
275
276                            if protocol.capable(context_server::protocol::ServerCapability::Tools) {
277                                if let Some(tools) = protocol.list_tools().await.log_err() {
278                                    let tool_ids = tools
279                                        .tools
280                                        .into_iter()
281                                        .map(|tool| {
282                                            log::info!(
283                                                "registering context server tool: {:?}",
284                                                tool.name
285                                            );
286                                            tool_working_set.insert(Arc::new(
287                                                ContextServerTool::new(
288                                                    context_server_manager.clone(),
289                                                    server.id(),
290                                                    tool,
291                                                ),
292                                            ))
293                                        })
294                                        .collect::<Vec<_>>();
295
296                                    this.update(cx, |this, cx| {
297                                        this.context_server_tool_ids.insert(server_id, tool_ids);
298                                        this.load_default_profile(cx);
299                                    })
300                                    .log_err();
301                                }
302                            }
303                        }
304                    })
305                    .detach();
306                }
307            }
308            context_server::manager::Event::ServerStopped { server_id } => {
309                if let Some(tool_ids) = self.context_server_tool_ids.remove(server_id) {
310                    tool_working_set.remove(&tool_ids);
311                    self.load_default_profile(cx);
312                }
313            }
314        }
315    }
316}
317
318#[derive(Debug, Clone, Serialize, Deserialize)]
319pub struct SerializedThreadMetadata {
320    pub id: ThreadId,
321    pub summary: SharedString,
322    pub updated_at: DateTime<Utc>,
323}
324
325#[derive(Serialize, Deserialize, Debug)]
326pub struct SerializedThread {
327    pub version: String,
328    pub summary: SharedString,
329    pub updated_at: DateTime<Utc>,
330    pub messages: Vec<SerializedMessage>,
331    #[serde(default)]
332    pub initial_project_snapshot: Option<Arc<ProjectSnapshot>>,
333    #[serde(default)]
334    pub cumulative_token_usage: TokenUsage,
335    #[serde(default)]
336    pub detailed_summary_state: DetailedSummaryState,
337}
338
339impl SerializedThread {
340    pub const VERSION: &'static str = "0.1.0";
341
342    pub fn from_json(json: &[u8]) -> Result<Self> {
343        let saved_thread_json = serde_json::from_slice::<serde_json::Value>(json)?;
344        match saved_thread_json.get("version") {
345            Some(serde_json::Value::String(version)) => match version.as_str() {
346                SerializedThread::VERSION => Ok(serde_json::from_value::<SerializedThread>(
347                    saved_thread_json,
348                )?),
349                _ => Err(anyhow!(
350                    "unrecognized serialized thread version: {}",
351                    version
352                )),
353            },
354            None => {
355                let saved_thread =
356                    serde_json::from_value::<LegacySerializedThread>(saved_thread_json)?;
357                Ok(saved_thread.upgrade())
358            }
359            version => Err(anyhow!(
360                "unrecognized serialized thread version: {:?}",
361                version
362            )),
363        }
364    }
365}
366
367#[derive(Debug, Serialize, Deserialize)]
368pub struct SerializedMessage {
369    pub id: MessageId,
370    pub role: Role,
371    #[serde(default)]
372    pub segments: Vec<SerializedMessageSegment>,
373    #[serde(default)]
374    pub tool_uses: Vec<SerializedToolUse>,
375    #[serde(default)]
376    pub tool_results: Vec<SerializedToolResult>,
377    #[serde(default)]
378    pub context: String,
379}
380
381#[derive(Debug, Serialize, Deserialize)]
382#[serde(tag = "type")]
383pub enum SerializedMessageSegment {
384    #[serde(rename = "text")]
385    Text { text: String },
386    #[serde(rename = "thinking")]
387    Thinking { text: String },
388}
389
390#[derive(Debug, Serialize, Deserialize)]
391pub struct SerializedToolUse {
392    pub id: LanguageModelToolUseId,
393    pub name: SharedString,
394    pub input: serde_json::Value,
395}
396
397#[derive(Debug, Serialize, Deserialize)]
398pub struct SerializedToolResult {
399    pub tool_use_id: LanguageModelToolUseId,
400    pub is_error: bool,
401    pub content: Arc<str>,
402}
403
404#[derive(Serialize, Deserialize)]
405struct LegacySerializedThread {
406    pub summary: SharedString,
407    pub updated_at: DateTime<Utc>,
408    pub messages: Vec<LegacySerializedMessage>,
409    #[serde(default)]
410    pub initial_project_snapshot: Option<Arc<ProjectSnapshot>>,
411}
412
413impl LegacySerializedThread {
414    pub fn upgrade(self) -> SerializedThread {
415        SerializedThread {
416            version: SerializedThread::VERSION.to_string(),
417            summary: self.summary,
418            updated_at: self.updated_at,
419            messages: self.messages.into_iter().map(|msg| msg.upgrade()).collect(),
420            initial_project_snapshot: self.initial_project_snapshot,
421            cumulative_token_usage: TokenUsage::default(),
422            detailed_summary_state: DetailedSummaryState::default(),
423        }
424    }
425}
426
427#[derive(Debug, Serialize, Deserialize)]
428struct LegacySerializedMessage {
429    pub id: MessageId,
430    pub role: Role,
431    pub text: String,
432    #[serde(default)]
433    pub tool_uses: Vec<SerializedToolUse>,
434    #[serde(default)]
435    pub tool_results: Vec<SerializedToolResult>,
436}
437
438impl LegacySerializedMessage {
439    fn upgrade(self) -> SerializedMessage {
440        SerializedMessage {
441            id: self.id,
442            role: self.role,
443            segments: vec![SerializedMessageSegment::Text { text: self.text }],
444            tool_uses: self.tool_uses,
445            tool_results: self.tool_results,
446            context: String::new(),
447        }
448    }
449}
450
451struct GlobalThreadsDatabase(
452    Shared<BoxFuture<'static, Result<Arc<ThreadsDatabase>, Arc<anyhow::Error>>>>,
453);
454
455impl Global for GlobalThreadsDatabase {}
456
457pub(crate) struct ThreadsDatabase {
458    executor: BackgroundExecutor,
459    env: heed::Env,
460    threads: Database<SerdeBincode<ThreadId>, SerializedThread>,
461}
462
463impl heed::BytesEncode<'_> for SerializedThread {
464    type EItem = SerializedThread;
465
466    fn bytes_encode(item: &Self::EItem) -> Result<Cow<[u8]>, heed::BoxedError> {
467        serde_json::to_vec(item).map(Cow::Owned).map_err(Into::into)
468    }
469}
470
471impl<'a> heed::BytesDecode<'a> for SerializedThread {
472    type DItem = SerializedThread;
473
474    fn bytes_decode(bytes: &'a [u8]) -> Result<Self::DItem, heed::BoxedError> {
475        // We implement this type manually because we want to call `SerializedThread::from_json`,
476        // instead of the Deserialize trait implementation for `SerializedThread`.
477        SerializedThread::from_json(bytes).map_err(Into::into)
478    }
479}
480
481impl ThreadsDatabase {
482    fn global_future(
483        cx: &mut App,
484    ) -> Shared<BoxFuture<'static, Result<Arc<ThreadsDatabase>, Arc<anyhow::Error>>>> {
485        GlobalThreadsDatabase::global(cx).0.clone()
486    }
487
488    fn init(cx: &mut App) {
489        let executor = cx.background_executor().clone();
490        let database_future = executor
491            .spawn({
492                let executor = executor.clone();
493                let database_path = paths::support_dir().join("threads/threads-db.1.mdb");
494                async move { ThreadsDatabase::new(database_path, executor) }
495            })
496            .then(|result| future::ready(result.map(Arc::new).map_err(Arc::new)))
497            .boxed()
498            .shared();
499
500        cx.set_global(GlobalThreadsDatabase(database_future));
501    }
502
503    pub fn new(path: PathBuf, executor: BackgroundExecutor) -> Result<Self> {
504        std::fs::create_dir_all(&path)?;
505
506        const ONE_GB_IN_BYTES: usize = 1024 * 1024 * 1024;
507        let env = unsafe {
508            heed::EnvOpenOptions::new()
509                .map_size(ONE_GB_IN_BYTES)
510                .max_dbs(1)
511                .open(path)?
512        };
513
514        let mut txn = env.write_txn()?;
515        let threads = env.create_database(&mut txn, Some("threads"))?;
516        txn.commit()?;
517
518        Ok(Self {
519            executor,
520            env,
521            threads,
522        })
523    }
524
525    pub fn list_threads(&self) -> Task<Result<Vec<SerializedThreadMetadata>>> {
526        let env = self.env.clone();
527        let threads = self.threads;
528
529        self.executor.spawn(async move {
530            let txn = env.read_txn()?;
531            let mut iter = threads.iter(&txn)?;
532            let mut threads = Vec::new();
533            while let Some((key, value)) = iter.next().transpose()? {
534                threads.push(SerializedThreadMetadata {
535                    id: key,
536                    summary: value.summary,
537                    updated_at: value.updated_at,
538                });
539            }
540
541            Ok(threads)
542        })
543    }
544
545    pub fn try_find_thread(&self, id: ThreadId) -> Task<Result<Option<SerializedThread>>> {
546        let env = self.env.clone();
547        let threads = self.threads;
548
549        self.executor.spawn(async move {
550            let txn = env.read_txn()?;
551            let thread = threads.get(&txn, &id)?;
552            Ok(thread)
553        })
554    }
555
556    pub fn save_thread(&self, id: ThreadId, thread: SerializedThread) -> Task<Result<()>> {
557        let env = self.env.clone();
558        let threads = self.threads;
559
560        self.executor.spawn(async move {
561            let mut txn = env.write_txn()?;
562            threads.put(&mut txn, &id, &thread)?;
563            txn.commit()?;
564            Ok(())
565        })
566    }
567
568    pub fn delete_thread(&self, id: ThreadId) -> Task<Result<()>> {
569        let env = self.env.clone();
570        let threads = self.threads;
571
572        self.executor.spawn(async move {
573            let mut txn = env.write_txn()?;
574            threads.delete(&mut txn, &id)?;
575            txn.commit()?;
576            Ok(())
577        })
578    }
579}