prompts.rs

  1use assets::Assets;
  2use fs::Fs;
  3use futures::StreamExt;
  4use handlebars::{Handlebars, RenderError, TemplateError};
  5use language::BufferSnapshot;
  6use parking_lot::Mutex;
  7use serde::Serialize;
  8use std::{ops::Range, sync::Arc, time::Duration};
  9use util::ResultExt;
 10
 11#[derive(Serialize)]
 12pub struct ContentPromptContext {
 13    pub content_type: String,
 14    pub language_name: Option<String>,
 15    pub is_insert: bool,
 16    pub is_truncated: bool,
 17    pub document_content: String,
 18    pub user_prompt: String,
 19    pub rewrite_section: Option<String>,
 20}
 21
 22#[derive(Serialize)]
 23pub struct TerminalAssistantPromptContext {
 24    pub os: String,
 25    pub arch: String,
 26    pub shell: Option<String>,
 27    pub working_directory: Option<String>,
 28    pub latest_output: Vec<String>,
 29    pub user_prompt: String,
 30}
 31
 32pub struct PromptBuilder {
 33    handlebars: Arc<Mutex<Handlebars<'static>>>,
 34}
 35
 36impl PromptBuilder {
 37    pub fn new(
 38        fs_and_cx: Option<(Arc<dyn Fs>, &gpui::AppContext)>,
 39    ) -> Result<Self, Box<TemplateError>> {
 40        let mut handlebars = Handlebars::new();
 41        Self::register_templates(&mut handlebars)?;
 42
 43        let handlebars = Arc::new(Mutex::new(handlebars));
 44
 45        if let Some((fs, cx)) = fs_and_cx {
 46            Self::watch_fs_for_template_overrides(fs, cx, handlebars.clone());
 47        }
 48
 49        Ok(Self { handlebars })
 50    }
 51
 52    fn watch_fs_for_template_overrides(
 53        fs: Arc<dyn Fs>,
 54        cx: &gpui::AppContext,
 55        handlebars: Arc<Mutex<Handlebars<'static>>>,
 56    ) {
 57        let templates_dir = paths::prompt_templates_dir();
 58
 59        cx.background_executor()
 60            .spawn(async move {
 61                // Create the prompt templates directory if it doesn't exist
 62                if !fs.is_dir(templates_dir).await {
 63                    if let Err(e) = fs.create_dir(templates_dir).await {
 64                        log::error!("Failed to create prompt templates directory: {}", e);
 65                        return;
 66                    }
 67                }
 68
 69                // Initial scan of the prompts directory
 70                if let Ok(mut entries) = fs.read_dir(templates_dir).await {
 71                    while let Some(Ok(file_path)) = entries.next().await {
 72                        if file_path.to_string_lossy().ends_with(".hbs") {
 73                            if let Ok(content) = fs.load(&file_path).await {
 74                                let file_name = file_path.file_stem().unwrap().to_string_lossy();
 75
 76                                match handlebars.lock().register_template_string(&file_name, content) {
 77                                    Ok(_) => {
 78                                        log::info!(
 79                                            "Successfully registered template override: {} ({})",
 80                                            file_name,
 81                                            file_path.display()
 82                                        );
 83                                    },
 84                                    Err(e) => {
 85                                        log::error!(
 86                                            "Failed to register template during initial scan: {} ({})",
 87                                            e,
 88                                            file_path.display()
 89                                        );
 90                                    },
 91                                }
 92                            }
 93                        }
 94                    }
 95                }
 96
 97                // Watch for changes
 98                let (mut changes, watcher) = fs.watch(templates_dir, Duration::from_secs(1)).await;
 99                while let Some(changed_paths) = changes.next().await {
100                    for changed_path in changed_paths {
101                        if changed_path.extension().map_or(false, |ext| ext == "hbs") {
102                            log::info!("Reloading template: {}", changed_path.display());
103                            if let Some(content) = fs.load(&changed_path).await.log_err() {
104                                let file_name = changed_path.file_stem().unwrap().to_string_lossy();
105                                let file_path = changed_path.to_string_lossy();
106                                match handlebars.lock().register_template_string(&file_name, content) {
107                                    Ok(_) => log::info!(
108                                        "Successfully reloaded template: {} ({})",
109                                        file_name,
110                                        file_path
111                                    ),
112                                    Err(e) => log::error!(
113                                        "Failed to register template: {} ({})",
114                                        e,
115                                        file_path
116                                    ),
117                                }
118                            }
119                        }
120                    }
121                }
122                drop(watcher);
123            })
124            .detach();
125    }
126
127    fn register_templates(handlebars: &mut Handlebars) -> Result<(), Box<TemplateError>> {
128        let mut register_template = |id: &str| {
129            let prompt = Assets::get(&format!("prompts/{}.hbs", id))
130                .unwrap_or_else(|| panic!("{} prompt template not found", id))
131                .data;
132            handlebars
133                .register_template_string(id, String::from_utf8_lossy(&prompt))
134                .map_err(Box::new)
135        };
136
137        register_template("content_prompt")?;
138        register_template("terminal_assistant_prompt")?;
139        register_template("edit_workflow")?;
140        register_template("step_resolution")?;
141
142        Ok(())
143    }
144
145    pub fn generate_content_prompt(
146        &self,
147        user_prompt: String,
148        language_name: Option<&str>,
149        buffer: BufferSnapshot,
150        range: Range<usize>,
151    ) -> Result<String, RenderError> {
152        let content_type = match language_name {
153            None | Some("Markdown" | "Plain Text") => "text",
154            Some(_) => "code",
155        };
156
157        const MAX_CTX: usize = 50000;
158        let is_insert = range.is_empty();
159        let mut is_truncated = false;
160
161        let before_range = 0..range.start;
162        let truncated_before = if before_range.len() > MAX_CTX {
163            is_truncated = true;
164            range.start - MAX_CTX..range.start
165        } else {
166            before_range
167        };
168
169        let after_range = range.end..buffer.len();
170        let truncated_after = if after_range.len() > MAX_CTX {
171            is_truncated = true;
172            range.end..range.end + MAX_CTX
173        } else {
174            after_range
175        };
176
177        let mut document_content = String::new();
178        for chunk in buffer.text_for_range(truncated_before) {
179            document_content.push_str(chunk);
180        }
181        if is_insert {
182            document_content.push_str("<insert_here></insert_here>");
183        } else {
184            document_content.push_str("<rewrite_this>\n");
185            for chunk in buffer.text_for_range(range.clone()) {
186                document_content.push_str(chunk);
187            }
188            document_content.push_str("\n</rewrite_this>");
189        }
190        for chunk in buffer.text_for_range(truncated_after) {
191            document_content.push_str(chunk);
192        }
193
194        let rewrite_section = if !is_insert {
195            let mut section = String::new();
196            for chunk in buffer.text_for_range(range.clone()) {
197                section.push_str(chunk);
198            }
199            Some(section)
200        } else {
201            None
202        };
203
204        let context = ContentPromptContext {
205            content_type: content_type.to_string(),
206            language_name: language_name.map(|s| s.to_string()),
207            is_insert,
208            is_truncated,
209            document_content,
210            user_prompt,
211            rewrite_section,
212        };
213
214        self.handlebars.lock().render("content_prompt", &context)
215    }
216
217    pub fn generate_terminal_assistant_prompt(
218        &self,
219        user_prompt: &str,
220        shell: Option<&str>,
221        working_directory: Option<&str>,
222        latest_output: &[String],
223    ) -> Result<String, RenderError> {
224        let context = TerminalAssistantPromptContext {
225            os: std::env::consts::OS.to_string(),
226            arch: std::env::consts::ARCH.to_string(),
227            shell: shell.map(|s| s.to_string()),
228            working_directory: working_directory.map(|s| s.to_string()),
229            latest_output: latest_output.to_vec(),
230            user_prompt: user_prompt.to_string(),
231        };
232
233        self.handlebars
234            .lock()
235            .render("terminal_assistant_prompt", &context)
236    }
237
238    pub fn generate_workflow_prompt(&self) -> Result<String, RenderError> {
239        self.handlebars.lock().render("edit_workflow", &())
240    }
241
242    pub fn generate_step_resolution_prompt(&self) -> Result<String, RenderError> {
243        self.handlebars.lock().render("step_resolution", &())
244    }
245}