prompt_assets.rs

 1use std::borrow::Cow;
 2
 3#[cfg(feature = "dynamic_prompts")]
 4pub fn get_prompt(name: &'static str) -> Cow<'static, str> {
 5    use anyhow::Context;
 6    use std::collections::HashMap;
 7    use std::path::Path;
 8    use std::sync::{LazyLock, RwLock};
 9
10    const PROMPTS_DIR: &str = concat!(env!("CARGO_MANIFEST_DIR"), "/src/prompts");
11
12    static PROMPT_CACHE: LazyLock<RwLock<HashMap<&'static str, &'static str>>> =
13        LazyLock::new(|| RwLock::new(HashMap::default()));
14
15    let filesystem_path = Path::new(PROMPTS_DIR).join(name);
16    if let Some(cached_contents) = PROMPT_CACHE.read().unwrap().get(name) {
17        return Cow::Borrowed(cached_contents);
18    }
19    let contents = std::fs::read_to_string(&filesystem_path)
20        .context(name)
21        .expect("Failed to read prompt");
22    let leaked = contents.leak();
23    PROMPT_CACHE.write().unwrap().insert(name, leaked);
24    return Cow::Borrowed(leaked);
25}
26
27#[cfg(not(feature = "dynamic_prompts"))]
28pub fn get_prompt(name: &'static str) -> Cow<'static, str> {
29    use rust_embed::RustEmbed;
30
31    #[derive(RustEmbed)]
32    #[folder = "src/prompts"]
33    struct EmbeddedPrompts;
34
35    match EmbeddedPrompts::get(name) {
36        Some(file) => match file.data {
37            Cow::Borrowed(bytes) => {
38                Cow::Borrowed(std::str::from_utf8(bytes).expect("prompt file is not valid UTF-8"))
39            }
40            Cow::Owned(bytes) => {
41                Cow::Owned(String::from_utf8(bytes).expect("prompt file is not valid UTF-8"))
42            }
43        },
44        None => panic!("prompt file not found: {name}"),
45    }
46}