templates.rs

 1use anyhow::Result;
 2use gpui::SharedString;
 3use handlebars::Handlebars;
 4use rust_embed::RustEmbed;
 5use serde::Serialize;
 6use std::sync::Arc;
 7
 8#[derive(RustEmbed)]
 9#[folder = "src/templates"]
10#[include = "*.hbs"]
11struct Assets;
12
13pub struct Templates(Handlebars<'static>);
14
15impl Templates {
16    pub fn new() -> Arc<Self> {
17        let mut handlebars = Handlebars::new();
18        handlebars.set_strict_mode(true);
19        handlebars.register_helper("contains", Box::new(contains));
20        handlebars.register_embed_templates::<Assets>().unwrap();
21        Arc::new(Self(handlebars))
22    }
23}
24
25pub trait Template: Sized {
26    const TEMPLATE_NAME: &'static str;
27
28    fn render(&self, templates: &Templates) -> Result<String>
29    where
30        Self: Serialize + Sized,
31    {
32        Ok(templates.0.render(Self::TEMPLATE_NAME, self)?)
33    }
34}
35
36#[derive(Serialize)]
37pub struct GlobTemplate {
38    pub project_roots: String,
39}
40
41impl Template for GlobTemplate {
42    const TEMPLATE_NAME: &'static str = "glob.hbs";
43}
44
45#[derive(Serialize)]
46pub struct SystemPromptTemplate<'a> {
47    #[serde(flatten)]
48    pub project: &'a prompt_store::ProjectContext,
49    pub available_tools: Vec<SharedString>,
50}
51
52impl Template for SystemPromptTemplate<'_> {
53    const TEMPLATE_NAME: &'static str = "system_prompt.hbs";
54}
55
56/// Handlebars helper for checking if an item is in a list
57fn contains(
58    h: &handlebars::Helper,
59    _: &handlebars::Handlebars,
60    _: &handlebars::Context,
61    _: &mut handlebars::RenderContext,
62    out: &mut dyn handlebars::Output,
63) -> handlebars::HelperResult {
64    let list = h
65        .param(0)
66        .and_then(|v| v.value().as_array())
67        .ok_or_else(|| {
68            handlebars::RenderError::new("contains: missing or invalid list parameter")
69        })?;
70    let query = h.param(1).map(|v| v.value()).ok_or_else(|| {
71        handlebars::RenderError::new("contains: missing or invalid query parameter")
72    })?;
73
74    if list.contains(&query) {
75        out.write("true")?;
76    }
77
78    Ok(())
79}
80
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_system_prompt_template() {
        let project = prompt_store::ProjectContext::default();
        let template = SystemPromptTemplate {
            project: &project,
            available_tools: vec!["echo".into()],
        };
        let templates = Templates::new();
        let rendered = template.render(&templates).unwrap();
        assert!(rendered.contains("## Fixing Diagnostics"));
    }

    /// Exercises the `contains` helper directly, as a subexpression, and on bad input.
    #[test]
    fn test_contains_helper() {
        #[derive(Serialize)]
        struct Data {
            tools: Vec<&'static str>,
        }

        let mut handlebars = Handlebars::new();
        handlebars.register_helper("contains", Box::new(contains));
        let data = Data {
            tools: vec!["echo", "grep"],
        };
        let render = |template: &str| handlebars.render_template(template, &data);

        // Direct use: prints "true" on a hit, nothing on a miss.
        assert_eq!(render(r#"{{contains tools "echo"}}"#).unwrap(), "true");
        assert_eq!(render(r#"{{contains tools "missing"}}"#).unwrap(), "");

        // Subexpression use: the output drives `#if` truthiness.
        assert_eq!(
            render(r#"{{#if (contains tools "grep")}}yes{{else}}no{{/if}}"#).unwrap(),
            "yes"
        );
        assert_eq!(
            render(r#"{{#if (contains tools "missing")}}yes{{else}}no{{/if}}"#).unwrap(),
            "no"
        );

        // A non-array first parameter is rejected.
        assert!(render(r#"{{contains undefined_list "echo"}}"#).is_err());
    }
}