1use assets::Assets;
2use fs::Fs;
3use futures::StreamExt;
4use handlebars::{Handlebars, RenderError, TemplateError};
5use language::BufferSnapshot;
6use parking_lot::Mutex;
7use serde::Serialize;
8use std::{ops::Range, sync::Arc, time::Duration};
9use util::ResultExt;
10
11#[derive(Serialize)]
12pub struct ContentPromptContext {
13 pub content_type: String,
14 pub language_name: Option<String>,
15 pub is_truncated: bool,
16 pub document_content: String,
17 pub user_prompt: String,
18 pub rewrite_section: String,
19 pub rewrite_section_prefix: String,
20 pub rewrite_section_suffix: String,
21 pub rewrite_section_with_edits: String,
22 pub has_insertion: bool,
23 pub has_replacement: bool,
24}
25
26#[derive(Serialize)]
27pub struct TerminalAssistantPromptContext {
28 pub os: String,
29 pub arch: String,
30 pub shell: Option<String>,
31 pub working_directory: Option<String>,
32 pub latest_output: Vec<String>,
33 pub user_prompt: String,
34}
35
36/// Context required to generate a workflow step resolution prompt.
37#[derive(Debug, Serialize)]
38pub struct StepResolutionContext {
39 /// The full context, including <step>...</step> tags
40 pub workflow_context: String,
41 /// The text of the specific step from the context to resolve
42 pub step_to_resolve: String,
43}
44
45pub struct PromptBuilder {
46 handlebars: Arc<Mutex<Handlebars<'static>>>,
47}
48
49pub struct PromptOverrideContext<'a> {
50 pub dev_mode: bool,
51 pub fs: Arc<dyn Fs>,
52 pub cx: &'a mut gpui::AppContext,
53}
54
55impl PromptBuilder {
56 pub fn new(override_cx: Option<PromptOverrideContext>) -> Result<Self, Box<TemplateError>> {
57 let mut handlebars = Handlebars::new();
58 Self::register_templates(&mut handlebars)?;
59
60 let handlebars = Arc::new(Mutex::new(handlebars));
61
62 if let Some(override_cx) = override_cx {
63 Self::watch_fs_for_template_overrides(override_cx, handlebars.clone());
64 }
65
66 Ok(Self { handlebars })
67 }
68
69 fn watch_fs_for_template_overrides(
70 PromptOverrideContext { dev_mode, fs, cx }: PromptOverrideContext,
71 handlebars: Arc<Mutex<Handlebars<'static>>>,
72 ) {
73 cx.background_executor()
74 .spawn(async move {
75 let templates_dir = if dev_mode {
76 std::env::current_dir()
77 .ok()
78 .and_then(|pwd| {
79 let pwd_assets_prompts = pwd.join("assets").join("prompts");
80 pwd_assets_prompts.exists().then_some(pwd_assets_prompts)
81 })
82 .unwrap_or_else(|| paths::prompt_overrides_dir().clone())
83 } else {
84 paths::prompt_overrides_dir().clone()
85 };
86
87 // Create the prompt templates directory if it doesn't exist
88 if !fs.is_dir(&templates_dir).await {
89 if let Err(e) = fs.create_dir(&templates_dir).await {
90 log::error!("Failed to create prompt templates directory: {}", e);
91 return;
92 }
93 }
94
95 // Initial scan of the prompts directory
96 if let Ok(mut entries) = fs.read_dir(&templates_dir).await {
97 while let Some(Ok(file_path)) = entries.next().await {
98 if file_path.to_string_lossy().ends_with(".hbs") {
99 if let Ok(content) = fs.load(&file_path).await {
100 let file_name = file_path.file_stem().unwrap().to_string_lossy();
101
102 match handlebars.lock().register_template_string(&file_name, content) {
103 Ok(_) => {
104 log::info!(
105 "Successfully registered template override: {} ({})",
106 file_name,
107 file_path.display()
108 );
109 },
110 Err(e) => {
111 log::error!(
112 "Failed to register template during initial scan: {} ({})",
113 e,
114 file_path.display()
115 );
116 },
117 }
118 }
119 }
120 }
121 }
122
123 // Watch for changes
124 let (mut changes, watcher) = fs.watch(&templates_dir, Duration::from_secs(1)).await;
125 while let Some(changed_paths) = changes.next().await {
126 for changed_path in changed_paths {
127 if changed_path.extension().map_or(false, |ext| ext == "hbs") {
128 log::info!("Reloading template: {}", changed_path.display());
129 if let Some(content) = fs.load(&changed_path).await.log_err() {
130 let file_name = changed_path.file_stem().unwrap().to_string_lossy();
131 let file_path = changed_path.to_string_lossy();
132 match handlebars.lock().register_template_string(&file_name, content) {
133 Ok(_) => log::info!(
134 "Successfully reloaded template: {} ({})",
135 file_name,
136 file_path
137 ),
138 Err(e) => log::error!(
139 "Failed to register template: {} ({})",
140 e,
141 file_path
142 ),
143 }
144 }
145 }
146 }
147 }
148 drop(watcher);
149 })
150 .detach();
151 }
152
153 fn register_templates(handlebars: &mut Handlebars) -> Result<(), Box<TemplateError>> {
154 let mut register_template = |id: &str| {
155 let prompt = Assets::get(&format!("prompts/{}.hbs", id))
156 .unwrap_or_else(|| panic!("{} prompt template not found", id))
157 .data;
158 handlebars
159 .register_template_string(id, String::from_utf8_lossy(&prompt))
160 .map_err(Box::new)
161 };
162
163 register_template("content_prompt")?;
164 register_template("terminal_assistant_prompt")?;
165 register_template("edit_workflow")?;
166 register_template("step_resolution")?;
167
168 Ok(())
169 }
170
171 pub fn generate_content_prompt(
172 &self,
173 user_prompt: String,
174 language_name: Option<&str>,
175 buffer: BufferSnapshot,
176 transform_range: Range<usize>,
177 selected_ranges: Vec<Range<usize>>,
178 transform_context_range: Range<usize>,
179 ) -> Result<String, RenderError> {
180 let content_type = match language_name {
181 None | Some("Markdown" | "Plain Text") => "text",
182 Some(_) => "code",
183 };
184
185 const MAX_CTX: usize = 50000;
186 let mut is_truncated = false;
187
188 let before_range = 0..transform_range.start;
189 let truncated_before = if before_range.len() > MAX_CTX {
190 is_truncated = true;
191 transform_range.start - MAX_CTX..transform_range.start
192 } else {
193 before_range
194 };
195
196 let after_range = transform_range.end..buffer.len();
197 let truncated_after = if after_range.len() > MAX_CTX {
198 is_truncated = true;
199 transform_range.end..transform_range.end + MAX_CTX
200 } else {
201 after_range
202 };
203
204 let mut document_content = String::new();
205 for chunk in buffer.text_for_range(truncated_before) {
206 document_content.push_str(chunk);
207 }
208
209 document_content.push_str("<rewrite_this>\n");
210 for chunk in buffer.text_for_range(transform_range.clone()) {
211 document_content.push_str(chunk);
212 }
213 document_content.push_str("\n</rewrite_this>");
214
215 for chunk in buffer.text_for_range(truncated_after) {
216 document_content.push_str(chunk);
217 }
218
219 let mut rewrite_section = String::new();
220 for chunk in buffer.text_for_range(transform_range.clone()) {
221 rewrite_section.push_str(chunk);
222 }
223
224 let mut rewrite_section_prefix = String::new();
225 for chunk in buffer.text_for_range(transform_context_range.start..transform_range.start) {
226 rewrite_section_prefix.push_str(chunk);
227 }
228
229 let mut rewrite_section_suffix = String::new();
230 for chunk in buffer.text_for_range(transform_range.end..transform_context_range.end) {
231 rewrite_section_suffix.push_str(chunk);
232 }
233
234 let rewrite_section_with_edits = {
235 let mut section_with_selections = String::new();
236 let mut last_end = 0;
237 for selected_range in &selected_ranges {
238 if selected_range.start > last_end {
239 section_with_selections.push_str(
240 &rewrite_section[last_end..selected_range.start - transform_range.start],
241 );
242 }
243 if selected_range.start == selected_range.end {
244 section_with_selections.push_str("<insert_here></insert_here>");
245 } else {
246 section_with_selections.push_str("<edit_here>");
247 section_with_selections.push_str(
248 &rewrite_section[selected_range.start - transform_range.start
249 ..selected_range.end - transform_range.start],
250 );
251 section_with_selections.push_str("</edit_here>");
252 }
253 last_end = selected_range.end - transform_range.start;
254 }
255 if last_end < rewrite_section.len() {
256 section_with_selections.push_str(&rewrite_section[last_end..]);
257 }
258 section_with_selections
259 };
260
261 let has_insertion = selected_ranges.iter().any(|range| range.start == range.end);
262 let has_replacement = selected_ranges.iter().any(|range| range.start != range.end);
263
264 let context = ContentPromptContext {
265 content_type: content_type.to_string(),
266 language_name: language_name.map(|s| s.to_string()),
267 is_truncated,
268 document_content,
269 user_prompt,
270 rewrite_section,
271 rewrite_section_prefix,
272 rewrite_section_suffix,
273 rewrite_section_with_edits,
274 has_insertion,
275 has_replacement,
276 };
277
278 self.handlebars.lock().render("content_prompt", &context)
279 }
280
281 pub fn generate_terminal_assistant_prompt(
282 &self,
283 user_prompt: &str,
284 shell: Option<&str>,
285 working_directory: Option<&str>,
286 latest_output: &[String],
287 ) -> Result<String, RenderError> {
288 let context = TerminalAssistantPromptContext {
289 os: std::env::consts::OS.to_string(),
290 arch: std::env::consts::ARCH.to_string(),
291 shell: shell.map(|s| s.to_string()),
292 working_directory: working_directory.map(|s| s.to_string()),
293 latest_output: latest_output.to_vec(),
294 user_prompt: user_prompt.to_string(),
295 };
296
297 self.handlebars
298 .lock()
299 .render("terminal_assistant_prompt", &context)
300 }
301
302 pub fn generate_workflow_prompt(&self) -> Result<String, RenderError> {
303 self.handlebars.lock().render("edit_workflow", &())
304 }
305
306 pub fn generate_step_resolution_prompt(
307 &self,
308 context: &StepResolutionContext,
309 ) -> Result<String, RenderError> {
310 self.handlebars.lock().render("step_resolution", context)
311 }
312}