prompts.rs

  1use assets::Assets;
  2use fs::Fs;
  3use futures::StreamExt;
  4use handlebars::{Handlebars, RenderError, TemplateError};
  5use language::BufferSnapshot;
  6use parking_lot::Mutex;
  7use serde::Serialize;
  8use std::{ops::Range, sync::Arc, time::Duration};
  9use util::ResultExt;
 10
 11#[derive(Serialize)]
 12pub struct ContentPromptContext {
 13    pub content_type: String,
 14    pub language_name: Option<String>,
 15    pub is_truncated: bool,
 16    pub document_content: String,
 17    pub user_prompt: String,
 18    pub rewrite_section: String,
 19    pub rewrite_section_prefix: String,
 20    pub rewrite_section_suffix: String,
 21    pub rewrite_section_with_edits: String,
 22    pub has_insertion: bool,
 23    pub has_replacement: bool,
 24}
 25
 26#[derive(Serialize)]
 27pub struct TerminalAssistantPromptContext {
 28    pub os: String,
 29    pub arch: String,
 30    pub shell: Option<String>,
 31    pub working_directory: Option<String>,
 32    pub latest_output: Vec<String>,
 33    pub user_prompt: String,
 34}
 35
 36/// Context required to generate a workflow step resolution prompt.
 37#[derive(Debug, Serialize)]
 38pub struct StepResolutionContext {
 39    /// The full context, including <step>...</step> tags
 40    pub workflow_context: String,
 41    /// The text of the specific step from the context to resolve
 42    pub step_to_resolve: String,
 43}
 44
 45pub struct PromptBuilder {
 46    handlebars: Arc<Mutex<Handlebars<'static>>>,
 47}
 48
 49pub struct PromptOverrideContext<'a> {
 50    pub dev_mode: bool,
 51    pub fs: Arc<dyn Fs>,
 52    pub cx: &'a mut gpui::AppContext,
 53}
 54
 55impl PromptBuilder {
 56    pub fn new(override_cx: Option<PromptOverrideContext>) -> Result<Self, Box<TemplateError>> {
 57        let mut handlebars = Handlebars::new();
 58        Self::register_templates(&mut handlebars)?;
 59
 60        let handlebars = Arc::new(Mutex::new(handlebars));
 61
 62        if let Some(override_cx) = override_cx {
 63            Self::watch_fs_for_template_overrides(override_cx, handlebars.clone());
 64        }
 65
 66        Ok(Self { handlebars })
 67    }
 68
 69    fn watch_fs_for_template_overrides(
 70        PromptOverrideContext { dev_mode, fs, cx }: PromptOverrideContext,
 71        handlebars: Arc<Mutex<Handlebars<'static>>>,
 72    ) {
 73        cx.background_executor()
 74            .spawn(async move {
 75                let templates_dir = if dev_mode {
 76                    std::env::current_dir()
 77                        .ok()
 78                        .and_then(|pwd| {
 79                            let pwd_assets_prompts = pwd.join("assets").join("prompts");
 80                            pwd_assets_prompts.exists().then_some(pwd_assets_prompts)
 81                        })
 82                        .unwrap_or_else(|| paths::prompt_overrides_dir().clone())
 83                } else {
 84                    paths::prompt_overrides_dir().clone()
 85                };
 86
 87                // Create the prompt templates directory if it doesn't exist
 88                if !fs.is_dir(&templates_dir).await {
 89                    if let Err(e) = fs.create_dir(&templates_dir).await {
 90                        log::error!("Failed to create prompt templates directory: {}", e);
 91                        return;
 92                    }
 93                }
 94
 95                // Initial scan of the prompts directory
 96                if let Ok(mut entries) = fs.read_dir(&templates_dir).await {
 97                    while let Some(Ok(file_path)) = entries.next().await {
 98                        if file_path.to_string_lossy().ends_with(".hbs") {
 99                            if let Ok(content) = fs.load(&file_path).await {
100                                let file_name = file_path.file_stem().unwrap().to_string_lossy();
101
102                                match handlebars.lock().register_template_string(&file_name, content) {
103                                    Ok(_) => {
104                                        log::info!(
105                                            "Successfully registered template override: {} ({})",
106                                            file_name,
107                                            file_path.display()
108                                        );
109                                    },
110                                    Err(e) => {
111                                        log::error!(
112                                            "Failed to register template during initial scan: {} ({})",
113                                            e,
114                                            file_path.display()
115                                        );
116                                    },
117                                }
118                            }
119                        }
120                    }
121                }
122
123                // Watch for changes
124                let (mut changes, watcher) = fs.watch(&templates_dir, Duration::from_secs(1)).await;
125                while let Some(changed_paths) = changes.next().await {
126                    for changed_path in changed_paths {
127                        if changed_path.extension().map_or(false, |ext| ext == "hbs") {
128                            log::info!("Reloading template: {}", changed_path.display());
129                            if let Some(content) = fs.load(&changed_path).await.log_err() {
130                                let file_name = changed_path.file_stem().unwrap().to_string_lossy();
131                                let file_path = changed_path.to_string_lossy();
132                                match handlebars.lock().register_template_string(&file_name, content) {
133                                    Ok(_) => log::info!(
134                                        "Successfully reloaded template: {} ({})",
135                                        file_name,
136                                        file_path
137                                    ),
138                                    Err(e) => log::error!(
139                                        "Failed to register template: {} ({})",
140                                        e,
141                                        file_path
142                                    ),
143                                }
144                            }
145                        }
146                    }
147                }
148                drop(watcher);
149            })
150            .detach();
151    }
152
153    fn register_templates(handlebars: &mut Handlebars) -> Result<(), Box<TemplateError>> {
154        let mut register_template = |id: &str| {
155            let prompt = Assets::get(&format!("prompts/{}.hbs", id))
156                .unwrap_or_else(|| panic!("{} prompt template not found", id))
157                .data;
158            handlebars
159                .register_template_string(id, String::from_utf8_lossy(&prompt))
160                .map_err(Box::new)
161        };
162
163        register_template("content_prompt")?;
164        register_template("terminal_assistant_prompt")?;
165        register_template("edit_workflow")?;
166        register_template("step_resolution")?;
167
168        Ok(())
169    }
170
171    pub fn generate_content_prompt(
172        &self,
173        user_prompt: String,
174        language_name: Option<&str>,
175        buffer: BufferSnapshot,
176        transform_range: Range<usize>,
177        selected_ranges: Vec<Range<usize>>,
178        transform_context_range: Range<usize>,
179    ) -> Result<String, RenderError> {
180        let content_type = match language_name {
181            None | Some("Markdown" | "Plain Text") => "text",
182            Some(_) => "code",
183        };
184
185        const MAX_CTX: usize = 50000;
186        let mut is_truncated = false;
187
188        let before_range = 0..transform_range.start;
189        let truncated_before = if before_range.len() > MAX_CTX {
190            is_truncated = true;
191            transform_range.start - MAX_CTX..transform_range.start
192        } else {
193            before_range
194        };
195
196        let after_range = transform_range.end..buffer.len();
197        let truncated_after = if after_range.len() > MAX_CTX {
198            is_truncated = true;
199            transform_range.end..transform_range.end + MAX_CTX
200        } else {
201            after_range
202        };
203
204        let mut document_content = String::new();
205        for chunk in buffer.text_for_range(truncated_before) {
206            document_content.push_str(chunk);
207        }
208
209        document_content.push_str("<rewrite_this>\n");
210        for chunk in buffer.text_for_range(transform_range.clone()) {
211            document_content.push_str(chunk);
212        }
213        document_content.push_str("\n</rewrite_this>");
214
215        for chunk in buffer.text_for_range(truncated_after) {
216            document_content.push_str(chunk);
217        }
218
219        let mut rewrite_section = String::new();
220        for chunk in buffer.text_for_range(transform_range.clone()) {
221            rewrite_section.push_str(chunk);
222        }
223
224        let mut rewrite_section_prefix = String::new();
225        for chunk in buffer.text_for_range(transform_context_range.start..transform_range.start) {
226            rewrite_section_prefix.push_str(chunk);
227        }
228
229        let mut rewrite_section_suffix = String::new();
230        for chunk in buffer.text_for_range(transform_range.end..transform_context_range.end) {
231            rewrite_section_suffix.push_str(chunk);
232        }
233
234        let rewrite_section_with_edits = {
235            let mut section_with_selections = String::new();
236            let mut last_end = 0;
237            for selected_range in &selected_ranges {
238                if selected_range.start > last_end {
239                    section_with_selections.push_str(
240                        &rewrite_section[last_end..selected_range.start - transform_range.start],
241                    );
242                }
243                if selected_range.start == selected_range.end {
244                    section_with_selections.push_str("<insert_here></insert_here>");
245                } else {
246                    section_with_selections.push_str("<edit_here>");
247                    section_with_selections.push_str(
248                        &rewrite_section[selected_range.start - transform_range.start
249                            ..selected_range.end - transform_range.start],
250                    );
251                    section_with_selections.push_str("</edit_here>");
252                }
253                last_end = selected_range.end - transform_range.start;
254            }
255            if last_end < rewrite_section.len() {
256                section_with_selections.push_str(&rewrite_section[last_end..]);
257            }
258            section_with_selections
259        };
260
261        let has_insertion = selected_ranges.iter().any(|range| range.start == range.end);
262        let has_replacement = selected_ranges.iter().any(|range| range.start != range.end);
263
264        let context = ContentPromptContext {
265            content_type: content_type.to_string(),
266            language_name: language_name.map(|s| s.to_string()),
267            is_truncated,
268            document_content,
269            user_prompt,
270            rewrite_section,
271            rewrite_section_prefix,
272            rewrite_section_suffix,
273            rewrite_section_with_edits,
274            has_insertion,
275            has_replacement,
276        };
277
278        self.handlebars.lock().render("content_prompt", &context)
279    }
280
281    pub fn generate_terminal_assistant_prompt(
282        &self,
283        user_prompt: &str,
284        shell: Option<&str>,
285        working_directory: Option<&str>,
286        latest_output: &[String],
287    ) -> Result<String, RenderError> {
288        let context = TerminalAssistantPromptContext {
289            os: std::env::consts::OS.to_string(),
290            arch: std::env::consts::ARCH.to_string(),
291            shell: shell.map(|s| s.to_string()),
292            working_directory: working_directory.map(|s| s.to_string()),
293            latest_output: latest_output.to_vec(),
294            user_prompt: user_prompt.to_string(),
295        };
296
297        self.handlebars
298            .lock()
299            .render("terminal_assistant_prompt", &context)
300    }
301
302    pub fn generate_workflow_prompt(&self) -> Result<String, RenderError> {
303        self.handlebars.lock().render("edit_workflow", &())
304    }
305
306    pub fn generate_step_resolution_prompt(
307        &self,
308        context: &StepResolutionContext,
309    ) -> Result<String, RenderError> {
310        self.handlebars.lock().render("step_resolution", context)
311    }
312}