format_prompt.rs

  1use crate::{
  2    PromptFormat,
  3    example::{Example, ExamplePrompt},
  4    headless::EpAppState,
  5    load_project::run_load_project,
  6    progress::{Progress, Step},
  7    retrieve_context::run_context_retrieval,
  8};
  9use anyhow::{Context as _, Result};
 10use edit_prediction::{
 11    EditPredictionStore,
 12    zeta2::{zeta2_output_for_patch, zeta2_prompt_input},
 13};
 14use gpui::{AsyncApp, Entity};
 15use std::fmt::Write as _;
 16use std::sync::Arc;
 17use zeta_prompt::format_zeta_prompt;
 18
 19pub async fn run_format_prompt(
 20    example: &mut Example,
 21    prompt_format: PromptFormat,
 22    app_state: Arc<EpAppState>,
 23    mut cx: AsyncApp,
 24) -> Result<()> {
 25    run_context_retrieval(example, app_state.clone(), cx.clone()).await?;
 26
 27    let step_progress = Progress::global().start(Step::FormatPrompt, &example.spec.name);
 28
 29    match prompt_format {
 30        PromptFormat::Teacher => {
 31            step_progress.set_substatus("formatting teacher prompt");
 32            let prompt = TeacherPrompt::format_prompt(example);
 33            example.prompt = Some(ExamplePrompt {
 34                input: prompt,
 35                expected_output: example
 36                    .spec
 37                    .expected_patches
 38                    .first()
 39                    .cloned()
 40                    .unwrap_or_default(),
 41                format: prompt_format,
 42            });
 43        }
 44        PromptFormat::Zeta2 => {
 45            step_progress.set_substatus("loading project");
 46            run_load_project(example, app_state, cx.clone()).await?;
 47
 48            step_progress.set_substatus("formatting zeta2 prompt");
 49
 50            let ep_store: Entity<EditPredictionStore> = cx.update(|cx| {
 51                EditPredictionStore::try_global(cx).context("EditPredictionStore not initialized")
 52            })?;
 53
 54            let state = example.state.as_ref().context("state must be set")?;
 55            let snapshot = state.buffer.read_with(&cx, |buffer, _| buffer.snapshot());
 56            let project = state.project.clone();
 57            let (_, input) =
 58                ep_store.update(&mut cx, |ep_store: &mut EditPredictionStore, cx| {
 59                    let events = ep_store
 60                        .edit_history_for_project(&project, cx)
 61                        .into_iter()
 62                        .map(|e| e.event)
 63                        .collect();
 64                    anyhow::Ok(zeta2_prompt_input(
 65                        &snapshot,
 66                        example
 67                            .context
 68                            .as_ref()
 69                            .context("context must be set")?
 70                            .files
 71                            .clone(),
 72                        events,
 73                        example.spec.cursor_path.clone(),
 74                        example
 75                            .buffer
 76                            .as_ref()
 77                            .context("buffer must be set")?
 78                            .cursor_offset,
 79                    ))
 80                })?;
 81            let prompt = format_zeta_prompt(&input);
 82            let expected_output = zeta2_output_for_patch(
 83                &input,
 84                &example
 85                    .spec
 86                    .expected_patches
 87                    .first()
 88                    .context("expected patches is empty")?
 89                    .clone(),
 90            )?;
 91            example.prompt = Some(ExamplePrompt {
 92                input: prompt,
 93                expected_output,
 94                format: prompt_format,
 95            });
 96        }
 97    };
 98    Ok(())
 99}
100
101pub struct TeacherPrompt;
102
103impl TeacherPrompt {
104    const PROMPT: &str = include_str!("teacher.prompt.md");
105    pub(crate) const EDITABLE_REGION_START: &str = "<|editable_region_start|>\n";
106    pub(crate) const EDITABLE_REGION_END: &str = "\n<|editable_region_end|>";
107    pub(crate) const USER_CURSOR_MARKER: &str = "<|user_cursor|>";
108
109    /// Truncate edit history to this number of last lines
110    const MAX_HISTORY_LINES: usize = 128;
111
112    pub fn format_prompt(example: &Example) -> String {
113        let edit_history = Self::format_edit_history(&example.spec.edit_history);
114        let context = Self::format_context(example);
115        let cursor_excerpt = Self::format_cursor_excerpt(example);
116
117        let prompt = Self::PROMPT
118            .replace("{{context}}", &context)
119            .replace("{{edit_history}}", &edit_history)
120            .replace("{{cursor_excerpt}}", &cursor_excerpt);
121
122        prompt
123    }
124
125    pub fn parse(example: &Example, response: &str) -> Result<String> {
126        // Ideally, we should always be able to find cursor position in the retrieved context.
127        // In reality, sometimes we don't find it for these reasons:
128        // 1. `example.cursor_position` contains _more_ context than included in the retrieved context
129        //    (can be fixed by getting cursor coordinates at the load_example stage)
130        // 2. Context retriever just didn't include cursor line.
131        //
132        // In that case, fallback to using `cursor_position` as excerpt.
133        let example_buffer = example
134            .buffer
135            .as_ref()
136            .context("`buffer` should be filled in in the context collection step")?;
137
138        // Extract updated (new) editable region from the model response.
139        // The model may include editable region markers in its output, so we need to strip them.
140        let new_editable_region = extract_last_codeblock(response);
141        let mut new_editable_region = Self::extract_editable_region(&new_editable_region);
142
143        let old_editable_region =
144            example_buffer.content[example_buffer.editable_range.clone()].to_string();
145
146        // Normalize leading newlines: if old starts with newline but new doesn't,
147        // prepend newline to new to preserve whitespace structure.
148        // This handles the case where the model drops the leading blank line.
149        if old_editable_region.starts_with('\n') && !new_editable_region.starts_with('\n') {
150            new_editable_region.insert(0, '\n');
151        }
152
153        let editable_region_start_line = example_buffer.content
154            [..example_buffer.editable_range.start]
155            .matches('\n')
156            .count();
157
158        let diff = language::unified_diff_with_offsets(
159            &old_editable_region,
160            &new_editable_region,
161            editable_region_start_line as u32,
162            editable_region_start_line as u32,
163        );
164
165        let diff = indoc::formatdoc! {"
166            --- a/{path}
167            +++ b/{path}
168            {diff}",
169            path = example.spec.cursor_path.to_string_lossy(),
170            diff = diff,
171        };
172
173        Ok(diff)
174    }
175
176    fn format_edit_history(edit_history: &str) -> String {
177        // Strip comments ("garbage lines") from edit history
178        let lines = edit_history
179            .lines()
180            .filter(|&s| Self::is_udiff_content_line(s))
181            .collect::<Vec<_>>();
182
183        let history_lines = if lines.len() > Self::MAX_HISTORY_LINES {
184            &lines[lines.len() - Self::MAX_HISTORY_LINES..]
185        } else {
186            &lines
187        };
188
189        if history_lines.is_empty() {
190            return "(No edit history)".to_string();
191        }
192
193        history_lines.join("\n")
194    }
195
196    fn format_context(example: &Example) -> String {
197        let context = example
198            .context
199            .as_ref()
200            .expect("Missing context retriever step");
201
202        if context.files.is_empty() {
203            return "(No context)".to_string();
204        }
205
206        let mut prompt = String::new();
207        for file in context.files.as_ref() {
208            let path_str = file.path.to_string_lossy();
209            writeln!(&mut prompt, "`````{path_str}").ok();
210            let mut prev_row = 0;
211            for excerpt in &file.excerpts {
212                if excerpt.row_range.start > prev_row {
213                    prompt.push_str("\n");
214                }
215                prompt.push_str(&excerpt.text);
216                prompt.push('\n');
217                prev_row = excerpt.row_range.end;
218            }
219            if prev_row < file.max_row {
220                prompt.push_str("\n");
221            }
222            prompt.push_str("\n`````");
223        }
224
225        prompt
226    }
227
228    fn format_cursor_excerpt(example: &Example) -> String {
229        let mut result = String::new();
230
231        let example_buffer = example.buffer.as_ref().unwrap();
232
233        let path_str = example.spec.cursor_path.to_string_lossy();
234        result.push_str(&format!("`````{path_str}\n"));
235        result.push_str(
236            &example_buffer.content
237                [example_buffer.context_range.start..example_buffer.editable_range.start],
238        );
239        result.push_str(Self::EDITABLE_REGION_START);
240        result.push_str(
241            &example_buffer.content
242                [example_buffer.editable_range.start..example_buffer.cursor_offset],
243        );
244        result.push_str(Self::USER_CURSOR_MARKER);
245        result.push_str(
246            &example_buffer.content
247                [example_buffer.cursor_offset..example_buffer.editable_range.end],
248        );
249        result.push_str(Self::EDITABLE_REGION_END);
250        result.push_str(
251            &example_buffer.content
252                [example_buffer.editable_range.end..example_buffer.context_range.end],
253        );
254        result.push_str("\n`````");
255
256        result
257    }
258
259    fn extract_editable_region(text: &str) -> String {
260        let start = text
261            .find(Self::EDITABLE_REGION_START)
262            .map_or(0, |pos| pos + Self::EDITABLE_REGION_START.len());
263        let end = text.find(Self::EDITABLE_REGION_END).unwrap_or(text.len());
264
265        let region = &text[start..end];
266        let region = region.strip_suffix('\n').unwrap_or(region);
267
268        region.replace(Self::USER_CURSOR_MARKER, "")
269    }
270
271    fn is_udiff_content_line(s: &str) -> bool {
272        s.starts_with("-")
273            || s.starts_with("+")
274            || s.starts_with(" ")
275            || s.starts_with("---")
276            || s.starts_with("+++")
277            || s.starts_with("@@")
278    }
279}
280
/// Returns the body of the last complete fenced code block in `text`, or all of `text` when
/// no complete block is found.
///
/// The closing pattern is a newline followed by as many backticks as the opening fence had.
/// Shorter fences inside a longer-fenced block are therefore treated as content. The rest of
/// the opening line (language tag, path, ...) is skipped, and the returned body keeps its
/// trailing newline.
fn extract_last_codeblock(text: &str) -> String {
    let mut found: Option<&str> = None;
    let mut cursor = 0;

    while let Some(offset) = text[cursor..].find("```") {
        let fence_start = cursor + offset;
        let fence_len = text[fence_start..]
            .bytes()
            .take_while(|&byte| byte == b'`')
            .count();
        let after_fence = fence_start + fence_len;

        // Skip the remainder of the opening line (the info string).
        let line_end = text[after_fence..]
            .find('\n')
            .map_or(text.len(), |pos| after_fence + pos);
        let closing = format!("\n{}", "`".repeat(fence_len));

        let Some(close_offset) = text[line_end..].find(&closing) else {
            break;
        };
        let close_pos = line_end + close_offset;
        found = Some(&text[line_end + 1..close_pos + 1]);
        cursor = close_pos + closing.len();
    }

    found.map_or_else(|| text.to_string(), str::to_string)
}
312
#[cfg(test)]
mod tests {
    use super::*;

    /// With several fenced blocks only the final body is returned, without its info string.
    #[test]
    fn test_extract_last_code_block() {
        let response = indoc::indoc! {"
            Some thinking

            ```
            first block
            ```

            `````path='something' lines=1:2
            last block
            `````
            "};

        assert_eq!(extract_last_codeblock(response), "last block\n");
    }

    /// Shorter fences inside a five-backtick block are treated as content.
    #[test]
    fn test_extract_codeblock_with_nested_fences() {
        let response = indoc::indoc! {"
            `````
            content with ``` inline
            and ```python nested
            more content
            `````
            "};
        let expected = "content with ``` inline\nand ```python nested\nmore content\n";

        assert_eq!(extract_last_codeblock(response), expected);
    }

    /// Backticks in the middle of a line do not terminate the block.
    #[test]
    fn test_extract_codeblock_ignores_inline_backticks() {
        let response = indoc::indoc! {"
            `````
            here is some `code` with inline backticks
            and here```more```stuff
            `````
            "};
        let expected = "here is some `code` with inline backticks\nand here```more```stuff\n";

        assert_eq!(extract_last_codeblock(response), expected);
    }

    /// Text outside the markers is dropped, as is the region's final newline.
    #[test]
    fn test_extract_editable_region() {
        let response = indoc::indoc! {"
            some lines
            are
            here
            <|editable_region_start|>
            one
            two three

            <|editable_region_end|>
            more
            lines here
            "};

        assert_eq!(
            TeacherPrompt::extract_editable_region(response),
            "one\ntwo three"
        );
    }

    /// A ```bibtex block nested inside a five-backtick block stays part of the body.
    #[test]
    fn test_extract_last_codeblock_nested_bibtex() {
        let response = indoc::indoc! {r#"
            Looking at the edit history, I can see that a Citation section was just added.

            `````
            ## Collaborations
            Our mission is to create a 4D generative model.

            ## Citation

            If you found Unique3D helpful, please cite our report:
            ```bibtex
            @misc{wu2024unique3d,
                  title={Unique3D},
            }
            ```
            `````
            "#};
        let expected = indoc::indoc! {r#"
            ## Collaborations
            Our mission is to create a 4D generative model.

            ## Citation

            If you found Unique3D helpful, please cite our report:
            ```bibtex
            @misc{wu2024unique3d,
                  title={Unique3D},
            }
            ```
            "#};

        assert_eq!(extract_last_codeblock(response), expected);
    }

    /// Without markers the whole text is treated as the region.
    #[test]
    fn test_extract_editable_region_no_markers() {
        let response = "one\ntwo three";

        assert_eq!(
            TeacherPrompt::extract_editable_region(response),
            "one\ntwo three"
        );
    }

    /// Cursor markers echoed by the model are removed from the region.
    #[test]
    fn test_extract_editable_region_strips_cursor_marker() {
        let response = indoc::indoc! {"
            <|editable_region_start|>
            one
            <|user_cursor|>two three

            <|editable_region_end|>
            "};

        assert_eq!(
            TeacherPrompt::extract_editable_region(response),
            "one\ntwo three"
        );
    }
}