Add support for backwards compatibility in `PromptStore` (#15602)

Antonio Scandurra and Nathan created

Release Notes:

- N/A

---------

Co-authored-by: Nathan <nathan@zed.dev>

Change summary

crates/assistant/src/prompt_library.rs | 119 ++++++++++++++++++++++++---
1 file changed, 106 insertions(+), 13 deletions(-)

Detailed changes

crates/assistant/src/prompt_library.rs 🔗

@@ -16,7 +16,10 @@ use gpui::{
     EventEmitter, Global, HighlightStyle, PromptLevel, ReadGlobal, Subscription, Task, TextStyle,
     TitlebarOptions, UpdateGlobal, View, WindowBounds, WindowHandle, WindowOptions,
 };
-use heed::{types::SerdeBincode, Database, RoTxn};
+use heed::{
+    types::{SerdeBincode, SerdeJson, Str},
+    Database, RoTxn,
+};
 use language::{language_settings::SoftWrap, Buffer, LanguageRegistry};
 use language_model::{
     LanguageModelRegistry, LanguageModelRequest, LanguageModelRequestMessage, Role,
@@ -1059,23 +1062,30 @@ pub struct PromptMetadata {
     pub title: Option<SharedString>,
     pub default: bool,
     pub saved_at: DateTime<Utc>,
+    pub built_in: bool,
 }
 
 #[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, Serialize, Deserialize)]
-pub struct PromptId(Uuid);
+#[serde(tag = "kind")]
+pub enum PromptId {
+    User { uuid: Uuid },
+    EditWorkflow,
+}
 
 impl PromptId {
     pub fn new() -> PromptId {
-        PromptId(Uuid::new_v4())
+        PromptId::User {
+            uuid: Uuid::new_v4(),
+        }
     }
 }
 
 pub struct PromptStore {
     executor: BackgroundExecutor,
     env: heed::Env,
-    bodies: Database<SerdeBincode<PromptId>, SerdeBincode<String>>,
-    metadata: Database<SerdeBincode<PromptId>, SerdeBincode<PromptMetadata>>,
     metadata_cache: RwLock<MetadataCache>,
+    metadata: Database<SerdeJson<PromptId>, SerdeJson<PromptMetadata>>,
+    bodies: Database<SerdeJson<PromptId>, Str>,
 }
 
 #[derive(Default)]
@@ -1086,7 +1096,7 @@ struct MetadataCache {
 
 impl MetadataCache {
     fn from_db(
-        db: Database<SerdeBincode<PromptId>, SerdeBincode<PromptMetadata>>,
+        db: Database<SerdeJson<PromptId>, SerdeJson<PromptMetadata>>,
         txn: &RoTxn,
     ) -> Result<Self> {
         let mut cache = MetadataCache::default();
@@ -1138,35 +1148,116 @@ impl PromptStore {
                 let db_env = unsafe {
                     heed::EnvOpenOptions::new()
                         .map_size(1024 * 1024 * 1024) // 1GB
-                        .max_dbs(2) // bodies and metadata
+                        .max_dbs(4) // Metadata and bodies (possibly v1 of both as well)
                         .open(db_path)?
                 };
 
                 let mut txn = db_env.write_txn()?;
-                let bodies = db_env.create_database(&mut txn, Some("bodies"))?;
-                let metadata = db_env.create_database(&mut txn, Some("metadata"))?;
+                let metadata = db_env.create_database(&mut txn, Some("metadata.v2"))?;
+                let bodies = db_env.create_database(&mut txn, Some("bodies.v2"))?;
+                txn.commit()?;
+
+                Self::upgrade_dbs(&db_env, metadata, bodies).log_err();
+
+                let txn = db_env.read_txn()?;
                 let metadata_cache = MetadataCache::from_db(metadata, &txn)?;
                 txn.commit()?;
 
                 Ok(PromptStore {
                     executor,
                     env: db_env,
-                    bodies,
-                    metadata,
                     metadata_cache: RwLock::new(metadata_cache),
+                    metadata,
+                    bodies,
                 })
             }
         })
     }
 
+    fn upgrade_dbs(
+        env: &heed::Env,
+        metadata_db: heed::Database<SerdeJson<PromptId>, SerdeJson<PromptMetadata>>,
+        bodies_db: heed::Database<SerdeJson<PromptId>, Str>,
+    ) -> Result<()> {
+        #[derive(Clone, Debug, Eq, PartialEq, Serialize, Deserialize, Hash)]
+        pub struct PromptIdV1(Uuid);
+
+        #[derive(Clone, Debug, Serialize, Deserialize)]
+        pub struct PromptMetadataV1 {
+            pub id: PromptIdV1,
+            pub title: Option<SharedString>,
+            pub default: bool,
+            pub saved_at: DateTime<Utc>,
+        }
+
+        let mut txn = env.write_txn()?;
+        let Some(bodies_v1_db) = env
+            .open_database::<SerdeBincode<PromptIdV1>, SerdeBincode<String>>(
+                &txn,
+                Some("bodies"),
+            )?
+        else {
+            return Ok(());
+        };
+        let mut bodies_v1 = bodies_v1_db
+            .iter(&txn)?
+            .collect::<heed::Result<HashMap<_, _>>>()?;
+
+        let Some(metadata_v1_db) = env
+            .open_database::<SerdeBincode<PromptIdV1>, SerdeBincode<PromptMetadataV1>>(
+                &txn,
+                Some("metadata"),
+            )?
+        else {
+            return Ok(());
+        };
+        let metadata_v1 = metadata_v1_db
+            .iter(&txn)?
+            .collect::<heed::Result<HashMap<_, _>>>()?;
+
+        for (prompt_id_v1, metadata_v1) in metadata_v1 {
+            let prompt_id_v2 = PromptId::User {
+                uuid: prompt_id_v1.0,
+            };
+            let Some(body_v1) = bodies_v1.remove(&prompt_id_v1) else {
+                continue;
+            };
+
+            if metadata_db
+                .get(&txn, &prompt_id_v2)?
+                .map_or(true, |metadata_v2| {
+                    metadata_v1.saved_at > metadata_v2.saved_at
+                })
+            {
+                metadata_db.put(
+                    &mut txn,
+                    &prompt_id_v2,
+                    &PromptMetadata {
+                        id: prompt_id_v2,
+                        title: metadata_v1.title.clone(),
+                        default: metadata_v1.default,
+                        saved_at: metadata_v1.saved_at,
+                        built_in: false,
+                    },
+                )?;
+                bodies_db.put(&mut txn, &prompt_id_v2, &body_v1)?;
+            }
+        }
+
+        txn.commit()?;
+
+        Ok(())
+    }
+
     pub fn load(&self, id: PromptId) -> Task<Result<String>> {
         let env = self.env.clone();
         let bodies = self.bodies;
         self.executor.spawn(async move {
             let txn = env.read_txn()?;
-            bodies
+            Ok(bodies
                 .get(&txn, &id)?
-                .ok_or_else(|| anyhow!("prompt not found"))
+                .ok_or_else(|| anyhow!("prompt not found"))?
+                .into())
         })
     }
 
@@ -1260,6 +1351,7 @@ impl PromptStore {
             title,
             default,
             saved_at: Utc::now(),
+            built_in: false,
         };
         self.metadata_cache.write().insert(prompt_metadata.clone());
 
@@ -1290,6 +1382,7 @@ impl PromptStore {
             title,
             default,
             saved_at: Utc::now(),
+            built_in: false,
         };
         self.metadata_cache.write().insert(prompt_metadata.clone());