1use std::borrow::Cow;
2
/// Returns the contents of the prompt file `name`, read at runtime from the
/// crate's `src/prompts` directory (used when the `dynamic_prompts` feature is on).
///
/// The directory is fixed at compile time from `CARGO_MANIFEST_DIR`. A prompt is
/// read from disk only on its first request. Its contents are then leaked to obtain
/// a `'static` string and cached for the rest of the process, so later edits to the
/// file are not observed by this process. Each distinct `name` leaks its file
/// contents at most once, even when several threads request it concurrently.
///
/// # Panics
///
/// Panics if the file cannot be read, which includes the file not being valid
/// UTF-8.
#[cfg(feature = "dynamic_prompts")]
pub fn get_prompt(name: &'static str) -> Cow<'static, str> {
    use std::collections::HashMap;
    use std::path::Path;
    use std::sync::{LazyLock, PoisonError, RwLock};

    const PROMPTS_DIR: &str = concat!(env!("CARGO_MANIFEST_DIR"), "/src/prompts");

    static PROMPT_CACHE: LazyLock<RwLock<HashMap<&'static str, &'static str>>> =
        LazyLock::new(|| RwLock::new(HashMap::new()));

    // The map only ever gains complete `&'static str` entries, so even a poisoned
    // lock still guards consistent data. Recover it rather than turning one
    // unrelated panic into a panic on every later call.
    let cached = PROMPT_CACHE
        .read()
        .unwrap_or_else(PoisonError::into_inner)
        .get(name)
        .copied();
    if let Some(contents) = cached {
        return Cow::Borrowed(contents);
    }

    // Read outside any lock so slow I/O never blocks other callers, and a failed
    // read panics without poisoning the cache.
    let filesystem_path = Path::new(PROMPTS_DIR).join(name);
    let contents = std::fs::read_to_string(&filesystem_path).unwrap_or_else(|err| {
        panic!(
            "Failed to read prompt {name} from {}: {err}",
            filesystem_path.display()
        )
    });

    // Another thread may have cached this prompt since our read-lock check.
    // `entry` keeps whichever copy got there first and only leaks our buffer when
    // we are the one inserting it; otherwise `contents` is dropped normally.
    let mut cache = PROMPT_CACHE.write().unwrap_or_else(PoisonError::into_inner);
    let contents: &'static str = *cache.entry(name).or_insert_with(|| {
        let leaked: &'static str = contents.leak();
        leaked
    });
    Cow::Borrowed(contents)
}
26
27#[cfg(not(feature = "dynamic_prompts"))]
28pub fn get_prompt(name: &'static str) -> Cow<'static, str> {
29 use rust_embed::RustEmbed;
30
31 #[derive(RustEmbed)]
32 #[folder = "src/prompts"]
33 struct EmbeddedPrompts;
34
35 match EmbeddedPrompts::get(name) {
36 Some(file) => match file.data {
37 Cow::Borrowed(bytes) => {
38 Cow::Borrowed(std::str::from_utf8(bytes).expect("prompt file is not valid UTF-8"))
39 }
40 Cow::Owned(bytes) => {
41 Cow::Owned(String::from_utf8(bytes).expect("prompt file is not valid UTF-8"))
42 }
43 },
44 None => panic!("prompt file not found: {name}"),
45 }
46}