prompt_library.rs

  1use anyhow::Context;
  2use collections::HashMap;
  3use fs::Fs;
  4
  5use gray_matter::{engine::YAML, Matter};
  6use parking_lot::RwLock;
  7use serde::{Deserialize, Serialize};
  8use smol::stream::StreamExt;
  9use std::sync::Arc;
 10use util::paths::PROMPTS_DIR;
 11use uuid::Uuid;
 12
 13use super::prompt::StaticPrompt;
 14
 15#[derive(Debug, Clone, Copy, Eq, PartialEq, Hash, Serialize, Deserialize)]
 16pub struct PromptId(pub Uuid);
 17
 18#[derive(Debug, Clone, Copy, Eq, PartialEq)]
 19pub enum SortOrder {
 20    Alphabetical,
 21}
 22
 23#[allow(unused)]
 24impl PromptId {
 25    pub fn new() -> Self {
 26        Self(Uuid::new_v4())
 27    }
 28
 29    pub fn from_str(id: &str) -> anyhow::Result<Self> {
 30        Ok(Self(Uuid::parse_str(id)?))
 31    }
 32}
 33
 34impl Default for PromptId {
 35    fn default() -> Self {
 36        Self::new()
 37    }
 38}
 39
 40#[derive(Default, Serialize, Deserialize)]
 41pub struct PromptLibraryState {
 42    /// A set of prompts that all assistant contexts will start with
 43    default_prompt: Vec<PromptId>,
 44    /// All [Prompt]s loaded into the library
 45    prompts: HashMap<PromptId, StaticPrompt>,
 46    /// Prompts that have been changed but haven't been
 47    /// saved back to the file system
 48    dirty_prompts: Vec<PromptId>,
 49    version: usize,
 50}
 51
 52pub struct PromptLibrary {
 53    state: RwLock<PromptLibraryState>,
 54}
 55
 56impl Default for PromptLibrary {
 57    fn default() -> Self {
 58        Self::new()
 59    }
 60}
 61
 62impl PromptLibrary {
 63    fn new() -> Self {
 64        Self {
 65            state: RwLock::new(PromptLibraryState::default()),
 66        }
 67    }
 68
 69    pub fn new_prompt(&self) -> StaticPrompt {
 70        StaticPrompt::default()
 71    }
 72
 73    pub fn add_prompt(&self, prompt: StaticPrompt) {
 74        let mut state = self.state.write();
 75        let id = *prompt.id();
 76        state.prompts.insert(id, prompt);
 77        state.version += 1;
 78    }
 79
 80    pub fn prompts(&self) -> HashMap<PromptId, StaticPrompt> {
 81        let state = self.state.read();
 82        state.prompts.clone()
 83    }
 84
 85    pub fn sorted_prompts(&self, sort_order: SortOrder) -> Vec<(PromptId, StaticPrompt)> {
 86        let state = self.state.read();
 87
 88        let mut prompts = state
 89            .prompts
 90            .iter()
 91            .map(|(id, prompt)| (*id, prompt.clone()))
 92            .collect::<Vec<_>>();
 93
 94        match sort_order {
 95            SortOrder::Alphabetical => prompts.sort_by(|(_, a), (_, b)| a.title().cmp(&b.title())),
 96        };
 97
 98        prompts
 99    }
100
101    pub fn prompt_by_id(&self, id: PromptId) -> Option<StaticPrompt> {
102        let state = self.state.read();
103        state.prompts.get(&id).cloned()
104    }
105
106    pub fn first_prompt_id(&self) -> Option<PromptId> {
107        let state = self.state.read();
108        state.prompts.keys().next().cloned()
109    }
110
111    pub fn is_dirty(&self, id: &PromptId) -> bool {
112        let state = self.state.read();
113        state.dirty_prompts.contains(&id)
114    }
115
116    pub fn set_dirty(&self, id: PromptId, dirty: bool) {
117        let mut state = self.state.write();
118        if dirty {
119            if !state.dirty_prompts.contains(&id) {
120                state.dirty_prompts.push(id);
121            }
122            state.version += 1;
123        } else {
124            state.dirty_prompts.retain(|&i| i != id);
125            state.version += 1;
126        }
127    }
128
129    /// Load the state of the prompt library from the file system
130    /// or create a new one if it doesn't exist
131    pub async fn load_index(fs: Arc<dyn Fs>) -> anyhow::Result<Self> {
132        let path = PROMPTS_DIR.join("index.json");
133
134        let state = if fs.is_file(&path).await {
135            let json = fs.load(&path).await?;
136            serde_json::from_str(&json)?
137        } else {
138            PromptLibraryState::default()
139        };
140
141        let mut prompt_library = Self {
142            state: RwLock::new(state),
143        };
144
145        prompt_library.load_prompts(fs).await?;
146
147        Ok(prompt_library)
148    }
149
150    /// Load all prompts from the file system
151    /// adding them to the library if they don't already exist
152    pub async fn load_prompts(&mut self, fs: Arc<dyn Fs>) -> anyhow::Result<()> {
153        self.state.get_mut().prompts.clear();
154
155        let mut prompt_paths = fs.read_dir(&PROMPTS_DIR).await?;
156
157        while let Some(prompt_path) = prompt_paths.next().await {
158            let prompt_path = prompt_path.with_context(|| "Failed to read prompt path")?;
159            let file_name_lossy = if prompt_path.file_name().is_some() {
160                Some(
161                    prompt_path
162                        .file_name()
163                        .unwrap()
164                        .to_string_lossy()
165                        .to_string(),
166                )
167            } else {
168                None
169            };
170
171            if !fs.is_file(&prompt_path).await
172                || prompt_path.extension().and_then(|ext| ext.to_str()) != Some("md")
173            {
174                continue;
175            }
176
177            let json = fs
178                .load(&prompt_path)
179                .await
180                .with_context(|| format!("Failed to load prompt {:?}", prompt_path))?;
181
182            // Check that the prompt is valid
183            let matter = Matter::<YAML>::new();
184            let result = matter.parse(&json);
185            if result.data.is_none() {
186                log::warn!("Invalid prompt: {:?}", prompt_path);
187                continue;
188            }
189
190            let static_prompt = StaticPrompt::new(json, file_name_lossy.clone());
191
192            let state = self.state.get_mut();
193
194            let id = Uuid::new_v4();
195            state.prompts.insert(PromptId(id), static_prompt);
196            state.version += 1;
197        }
198
199        // Write any changes back to the file system
200        self.save_index(fs.clone()).await?;
201
202        Ok(())
203    }
204
205    /// Save the current state of the prompt library to the
206    /// file system as a JSON file
207    pub async fn save_index(&self, fs: Arc<dyn Fs>) -> anyhow::Result<()> {
208        fs.create_dir(&PROMPTS_DIR).await?;
209
210        let path = PROMPTS_DIR.join("index.json");
211
212        let json = {
213            let state = self.state.read();
214            serde_json::to_string(&*state)?
215        };
216
217        fs.atomic_write(path, json).await?;
218
219        Ok(())
220    }
221
222    pub async fn save_prompt(
223        &self,
224        prompt_id: PromptId,
225        updated_content: Option<String>,
226        fs: Arc<dyn Fs>,
227    ) -> anyhow::Result<()> {
228        if let Some(updated_content) = updated_content {
229            let mut state = self.state.write();
230            if let Some(prompt) = state.prompts.get_mut(&prompt_id) {
231                prompt.update(prompt_id, updated_content);
232                state.version += 1;
233            }
234        }
235
236        if let Some(prompt) = self.prompt_by_id(prompt_id) {
237            prompt.save(fs).await?;
238            self.set_dirty(prompt_id, false);
239        } else {
240            log::warn!("Failed to save prompt: {:?}", prompt_id);
241        }
242
243        Ok(())
244    }
245}