prompts.rs

  1use assets::Assets;
  2use fs::Fs;
  3use futures::StreamExt;
  4use handlebars::{Handlebars, RenderError, TemplateError};
  5use language::BufferSnapshot;
  6use parking_lot::Mutex;
  7use serde::Serialize;
  8use std::{ops::Range, sync::Arc, time::Duration};
  9use util::ResultExt;
 10
 11#[derive(Serialize)]
 12pub struct ContentPromptContext {
 13    pub content_type: String,
 14    pub language_name: Option<String>,
 15    pub is_truncated: bool,
 16    pub document_content: String,
 17    pub user_prompt: String,
 18    pub rewrite_section: String,
 19    pub rewrite_section_with_selections: String,
 20    pub has_insertion: bool,
 21    pub has_replacement: bool,
 22}
 23
 24#[derive(Serialize)]
 25pub struct TerminalAssistantPromptContext {
 26    pub os: String,
 27    pub arch: String,
 28    pub shell: Option<String>,
 29    pub working_directory: Option<String>,
 30    pub latest_output: Vec<String>,
 31    pub user_prompt: String,
 32}
 33
/// Renders the assistant's prompts from Handlebars templates.
///
/// The template registry is shared behind `Arc<Mutex<_>>` so the optional override watcher
/// (a background task started by `PromptBuilder::new`) can re-register templates while the
/// builder is in use.
pub struct PromptBuilder {
    handlebars: Arc<Mutex<Handlebars<'static>>>,
}
 37
/// Inputs needed to load and watch user overrides of the prompt templates.
pub struct PromptOverrideContext<'a> {
    /// When `true`, overrides are read from `./assets/prompts` (relative to the current
    /// directory) if that directory exists, falling back to the prompt-overrides directory.
    pub dev_mode: bool,
    /// File system used to create, scan, load and watch the templates directory.
    pub fs: Arc<dyn Fs>,
    /// App context whose background executor runs the override watcher task.
    pub cx: &'a mut gpui::AppContext,
}
 43
 44impl PromptBuilder {
 45    pub fn new(override_cx: Option<PromptOverrideContext>) -> Result<Self, Box<TemplateError>> {
 46        let mut handlebars = Handlebars::new();
 47        Self::register_templates(&mut handlebars)?;
 48
 49        let handlebars = Arc::new(Mutex::new(handlebars));
 50
 51        if let Some(override_cx) = override_cx {
 52            Self::watch_fs_for_template_overrides(override_cx, handlebars.clone());
 53        }
 54
 55        Ok(Self { handlebars })
 56    }
 57
 58    fn watch_fs_for_template_overrides(
 59        PromptOverrideContext { dev_mode, fs, cx }: PromptOverrideContext,
 60        handlebars: Arc<Mutex<Handlebars<'static>>>,
 61    ) {
 62        cx.background_executor()
 63            .spawn(async move {
 64                let templates_dir = if dev_mode {
 65                    std::env::current_dir()
 66                        .ok()
 67                        .and_then(|pwd| {
 68                            let pwd_assets_prompts = pwd.join("assets").join("prompts");
 69                            pwd_assets_prompts.exists().then_some(pwd_assets_prompts)
 70                        })
 71                        .unwrap_or_else(|| paths::prompt_overrides_dir().clone())
 72                } else {
 73                    paths::prompt_overrides_dir().clone()
 74                };
 75
 76                // Create the prompt templates directory if it doesn't exist
 77                if !fs.is_dir(&templates_dir).await {
 78                    if let Err(e) = fs.create_dir(&templates_dir).await {
 79                        log::error!("Failed to create prompt templates directory: {}", e);
 80                        return;
 81                    }
 82                }
 83
 84                // Initial scan of the prompts directory
 85                if let Ok(mut entries) = fs.read_dir(&templates_dir).await {
 86                    while let Some(Ok(file_path)) = entries.next().await {
 87                        if file_path.to_string_lossy().ends_with(".hbs") {
 88                            if let Ok(content) = fs.load(&file_path).await {
 89                                let file_name = file_path.file_stem().unwrap().to_string_lossy();
 90
 91                                match handlebars.lock().register_template_string(&file_name, content) {
 92                                    Ok(_) => {
 93                                        log::info!(
 94                                            "Successfully registered template override: {} ({})",
 95                                            file_name,
 96                                            file_path.display()
 97                                        );
 98                                    },
 99                                    Err(e) => {
100                                        log::error!(
101                                            "Failed to register template during initial scan: {} ({})",
102                                            e,
103                                            file_path.display()
104                                        );
105                                    },
106                                }
107                            }
108                        }
109                    }
110                }
111
112                // Watch for changes
113                let (mut changes, watcher) = fs.watch(&templates_dir, Duration::from_secs(1)).await;
114                while let Some(changed_paths) = changes.next().await {
115                    for changed_path in changed_paths {
116                        if changed_path.extension().map_or(false, |ext| ext == "hbs") {
117                            log::info!("Reloading template: {}", changed_path.display());
118                            if let Some(content) = fs.load(&changed_path).await.log_err() {
119                                let file_name = changed_path.file_stem().unwrap().to_string_lossy();
120                                let file_path = changed_path.to_string_lossy();
121                                match handlebars.lock().register_template_string(&file_name, content) {
122                                    Ok(_) => log::info!(
123                                        "Successfully reloaded template: {} ({})",
124                                        file_name,
125                                        file_path
126                                    ),
127                                    Err(e) => log::error!(
128                                        "Failed to register template: {} ({})",
129                                        e,
130                                        file_path
131                                    ),
132                                }
133                            }
134                        }
135                    }
136                }
137                drop(watcher);
138            })
139            .detach();
140    }
141
142    fn register_templates(handlebars: &mut Handlebars) -> Result<(), Box<TemplateError>> {
143        let mut register_template = |id: &str| {
144            let prompt = Assets::get(&format!("prompts/{}.hbs", id))
145                .unwrap_or_else(|| panic!("{} prompt template not found", id))
146                .data;
147            handlebars
148                .register_template_string(id, String::from_utf8_lossy(&prompt))
149                .map_err(Box::new)
150        };
151
152        register_template("content_prompt")?;
153        register_template("terminal_assistant_prompt")?;
154        register_template("edit_workflow")?;
155        register_template("step_resolution")?;
156
157        Ok(())
158    }
159
160    pub fn generate_content_prompt(
161        &self,
162        user_prompt: String,
163        language_name: Option<&str>,
164        buffer: BufferSnapshot,
165        transform_range: Range<usize>,
166        selected_ranges: Vec<Range<usize>>,
167    ) -> Result<String, RenderError> {
168        let content_type = match language_name {
169            None | Some("Markdown" | "Plain Text") => "text",
170            Some(_) => "code",
171        };
172
173        const MAX_CTX: usize = 50000;
174        let mut is_truncated = false;
175
176        let before_range = 0..transform_range.start;
177        let truncated_before = if before_range.len() > MAX_CTX {
178            is_truncated = true;
179            transform_range.start - MAX_CTX..transform_range.start
180        } else {
181            before_range
182        };
183
184        let after_range = transform_range.end..buffer.len();
185        let truncated_after = if after_range.len() > MAX_CTX {
186            is_truncated = true;
187            transform_range.end..transform_range.end + MAX_CTX
188        } else {
189            after_range
190        };
191
192        let mut document_content = String::new();
193        for chunk in buffer.text_for_range(truncated_before) {
194            document_content.push_str(chunk);
195        }
196        document_content.push_str("<rewrite_this>\n");
197        for chunk in buffer.text_for_range(transform_range.clone()) {
198            document_content.push_str(chunk);
199        }
200        document_content.push_str("\n</rewrite_this>");
201
202        for chunk in buffer.text_for_range(truncated_after) {
203            document_content.push_str(chunk);
204        }
205
206        let mut rewrite_section = String::new();
207        for chunk in buffer.text_for_range(transform_range.clone()) {
208            rewrite_section.push_str(chunk);
209        }
210
211        let rewrite_section_with_selections = {
212            let mut section_with_selections = String::new();
213            let mut last_end = 0;
214            for selected_range in &selected_ranges {
215                if selected_range.start > last_end {
216                    section_with_selections.push_str(
217                        &rewrite_section[last_end..selected_range.start - transform_range.start],
218                    );
219                }
220                if selected_range.start == selected_range.end {
221                    section_with_selections.push_str("<insert_here></insert_here>");
222                } else {
223                    section_with_selections.push_str("<edit_here>");
224                    section_with_selections.push_str(
225                        &rewrite_section[selected_range.start - transform_range.start
226                            ..selected_range.end - transform_range.start],
227                    );
228                    section_with_selections.push_str("</edit_here>");
229                }
230                last_end = selected_range.end - transform_range.start;
231            }
232            if last_end < rewrite_section.len() {
233                section_with_selections.push_str(&rewrite_section[last_end..]);
234            }
235            section_with_selections
236        };
237
238        let has_insertion = selected_ranges.iter().any(|range| range.start == range.end);
239        let has_replacement = selected_ranges.iter().any(|range| range.start != range.end);
240
241        let context = ContentPromptContext {
242            content_type: content_type.to_string(),
243            language_name: language_name.map(|s| s.to_string()),
244            is_truncated,
245            document_content,
246            user_prompt,
247            rewrite_section,
248            rewrite_section_with_selections,
249            has_insertion,
250            has_replacement,
251        };
252
253        self.handlebars.lock().render("content_prompt", &context)
254    }
255
256    pub fn generate_terminal_assistant_prompt(
257        &self,
258        user_prompt: &str,
259        shell: Option<&str>,
260        working_directory: Option<&str>,
261        latest_output: &[String],
262    ) -> Result<String, RenderError> {
263        let context = TerminalAssistantPromptContext {
264            os: std::env::consts::OS.to_string(),
265            arch: std::env::consts::ARCH.to_string(),
266            shell: shell.map(|s| s.to_string()),
267            working_directory: working_directory.map(|s| s.to_string()),
268            latest_output: latest_output.to_vec(),
269            user_prompt: user_prompt.to_string(),
270        };
271
272        self.handlebars
273            .lock()
274            .render("terminal_assistant_prompt", &context)
275    }
276
277    pub fn generate_workflow_prompt(&self) -> Result<String, RenderError> {
278        self.handlebars.lock().render("edit_workflow", &())
279    }
280
281    pub fn generate_step_resolution_prompt(&self) -> Result<String, RenderError> {
282        self.handlebars.lock().render("step_resolution", &())
283    }
284}