1use anyhow::Context;
2use collections::HashMap;
3use fs::Fs;
4
5use gray_matter::{engine::YAML, Matter};
6use parking_lot::RwLock;
7use serde::{Deserialize, Serialize};
8use smol::stream::StreamExt;
9use std::sync::Arc;
10use util::paths::PROMPTS_DIR;
11use uuid::Uuid;
12
13use super::prompt::StaticPrompt;
14
15#[derive(Debug, Clone, Copy, Eq, PartialEq, Hash, Serialize, Deserialize)]
16pub struct PromptId(pub Uuid);
17
/// Ordering strategies accepted by [`PromptLibrary::sorted_prompts`].
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum SortOrder {
    /// Order prompts by comparing their titles.
    Alphabetical,
}
22
23#[allow(unused)]
24impl PromptId {
25 pub fn new() -> Self {
26 Self(Uuid::new_v4())
27 }
28
29 pub fn from_str(id: &str) -> anyhow::Result<Self> {
30 Ok(Self(Uuid::parse_str(id)?))
31 }
32}
33
34impl Default for PromptId {
35 fn default() -> Self {
36 Self::new()
37 }
38}
39
40#[derive(Default, Serialize, Deserialize)]
41pub struct PromptLibraryState {
42 /// A set of prompts that all assistant contexts will start with
43 default_prompt: Vec<PromptId>,
44 /// All [Prompt]s loaded into the library
45 prompts: HashMap<PromptId, StaticPrompt>,
46 /// Prompts that have been changed but haven't been
47 /// saved back to the file system
48 dirty_prompts: Vec<PromptId>,
49 version: usize,
50}
51
52pub struct PromptLibrary {
53 state: RwLock<PromptLibraryState>,
54}
55
56impl Default for PromptLibrary {
57 fn default() -> Self {
58 Self::new()
59 }
60}
61
62impl PromptLibrary {
63 fn new() -> Self {
64 Self {
65 state: RwLock::new(PromptLibraryState::default()),
66 }
67 }
68
69 pub fn new_prompt(&self) -> StaticPrompt {
70 StaticPrompt::default()
71 }
72
73 pub fn add_prompt(&self, prompt: StaticPrompt) {
74 let mut state = self.state.write();
75 let id = *prompt.id();
76 state.prompts.insert(id, prompt);
77 state.version += 1;
78 }
79
80 pub fn prompts(&self) -> HashMap<PromptId, StaticPrompt> {
81 let state = self.state.read();
82 state.prompts.clone()
83 }
84
85 pub fn sorted_prompts(&self, sort_order: SortOrder) -> Vec<(PromptId, StaticPrompt)> {
86 let state = self.state.read();
87
88 let mut prompts = state
89 .prompts
90 .iter()
91 .map(|(id, prompt)| (*id, prompt.clone()))
92 .collect::<Vec<_>>();
93
94 match sort_order {
95 SortOrder::Alphabetical => prompts.sort_by(|(_, a), (_, b)| a.title().cmp(&b.title())),
96 };
97
98 prompts
99 }
100
101 pub fn prompt_by_id(&self, id: PromptId) -> Option<StaticPrompt> {
102 let state = self.state.read();
103 state.prompts.get(&id).cloned()
104 }
105
106 pub fn first_prompt_id(&self) -> Option<PromptId> {
107 let state = self.state.read();
108 state.prompts.keys().next().cloned()
109 }
110
111 pub fn is_dirty(&self, id: &PromptId) -> bool {
112 let state = self.state.read();
113 state.dirty_prompts.contains(&id)
114 }
115
116 pub fn set_dirty(&self, id: PromptId, dirty: bool) {
117 let mut state = self.state.write();
118 if dirty {
119 if !state.dirty_prompts.contains(&id) {
120 state.dirty_prompts.push(id);
121 }
122 state.version += 1;
123 } else {
124 state.dirty_prompts.retain(|&i| i != id);
125 state.version += 1;
126 }
127 }
128
129 /// Load the state of the prompt library from the file system
130 /// or create a new one if it doesn't exist
131 pub async fn load_index(fs: Arc<dyn Fs>) -> anyhow::Result<Self> {
132 let path = PROMPTS_DIR.join("index.json");
133
134 let state = if fs.is_file(&path).await {
135 let json = fs.load(&path).await?;
136 serde_json::from_str(&json)?
137 } else {
138 PromptLibraryState::default()
139 };
140
141 let mut prompt_library = Self {
142 state: RwLock::new(state),
143 };
144
145 prompt_library.load_prompts(fs).await?;
146
147 Ok(prompt_library)
148 }
149
150 /// Load all prompts from the file system
151 /// adding them to the library if they don't already exist
152 pub async fn load_prompts(&mut self, fs: Arc<dyn Fs>) -> anyhow::Result<()> {
153 self.state.get_mut().prompts.clear();
154
155 let mut prompt_paths = fs.read_dir(&PROMPTS_DIR).await?;
156
157 while let Some(prompt_path) = prompt_paths.next().await {
158 let prompt_path = prompt_path.with_context(|| "Failed to read prompt path")?;
159 let file_name_lossy = if prompt_path.file_name().is_some() {
160 Some(
161 prompt_path
162 .file_name()
163 .unwrap()
164 .to_string_lossy()
165 .to_string(),
166 )
167 } else {
168 None
169 };
170
171 if !fs.is_file(&prompt_path).await
172 || prompt_path.extension().and_then(|ext| ext.to_str()) != Some("md")
173 {
174 continue;
175 }
176
177 let json = fs
178 .load(&prompt_path)
179 .await
180 .with_context(|| format!("Failed to load prompt {:?}", prompt_path))?;
181
182 // Check that the prompt is valid
183 let matter = Matter::<YAML>::new();
184 let result = matter.parse(&json);
185 if result.data.is_none() {
186 log::warn!("Invalid prompt: {:?}", prompt_path);
187 continue;
188 }
189
190 let static_prompt = StaticPrompt::new(json, file_name_lossy.clone());
191
192 let state = self.state.get_mut();
193
194 let id = Uuid::new_v4();
195 state.prompts.insert(PromptId(id), static_prompt);
196 state.version += 1;
197 }
198
199 // Write any changes back to the file system
200 self.save_index(fs.clone()).await?;
201
202 Ok(())
203 }
204
205 /// Save the current state of the prompt library to the
206 /// file system as a JSON file
207 pub async fn save_index(&self, fs: Arc<dyn Fs>) -> anyhow::Result<()> {
208 fs.create_dir(&PROMPTS_DIR).await?;
209
210 let path = PROMPTS_DIR.join("index.json");
211
212 let json = {
213 let state = self.state.read();
214 serde_json::to_string(&*state)?
215 };
216
217 fs.atomic_write(path, json).await?;
218
219 Ok(())
220 }
221
222 pub async fn save_prompt(
223 &self,
224 prompt_id: PromptId,
225 updated_content: Option<String>,
226 fs: Arc<dyn Fs>,
227 ) -> anyhow::Result<()> {
228 if let Some(updated_content) = updated_content {
229 let mut state = self.state.write();
230 if let Some(prompt) = state.prompts.get_mut(&prompt_id) {
231 prompt.update(prompt_id, updated_content);
232 state.version += 1;
233 }
234 }
235
236 if let Some(prompt) = self.prompt_by_id(prompt_id) {
237 prompt.save(fs).await?;
238 self.set_dirty(prompt_id, false);
239 } else {
240 log::warn!("Failed to save prompt: {:?}", prompt_id);
241 }
242
243 Ok(())
244 }
245}