format_prompt.rs

  1use crate::{
  2    PromptFormat,
  3    example::{Example, ExamplePrompt},
  4    headless::EpAppState,
  5    load_project::run_load_project,
  6    progress::{Progress, Step},
  7    retrieve_context::run_context_retrieval,
  8};
  9use anyhow::{Context as _, Result};
 10use edit_prediction::{EditPredictionStore, zeta2::zeta2_prompt_input};
 11use gpui::{AsyncApp, Entity};
 12use similar::DiffableStr;
 13use std::fmt::Write as _;
 14use std::sync::Arc;
 15use zeta_prompt::format_zeta_prompt;
 16
 17pub async fn run_format_prompt(
 18    example: &mut Example,
 19    prompt_format: PromptFormat,
 20    app_state: Arc<EpAppState>,
 21    mut cx: AsyncApp,
 22) -> Result<()> {
 23    run_context_retrieval(example, app_state.clone(), cx.clone()).await?;
 24
 25    let step_progress = Progress::global().start(Step::FormatPrompt, &example.spec.name);
 26
 27    match prompt_format {
 28        PromptFormat::Teacher => {
 29            step_progress.set_substatus("formatting teacher prompt");
 30            let prompt = TeacherPrompt::format_prompt(example);
 31            example.prompt = Some(ExamplePrompt {
 32                input: prompt,
 33                expected_output: example
 34                    .spec
 35                    .expected_patches
 36                    .first()
 37                    .cloned()
 38                    .unwrap_or_default(),
 39                format: prompt_format,
 40            });
 41        }
 42        PromptFormat::Zeta2 => {
 43            step_progress.set_substatus("loading project");
 44            run_load_project(example, app_state, cx.clone()).await?;
 45
 46            step_progress.set_substatus("formatting zeta2 prompt");
 47
 48            let ep_store: Entity<EditPredictionStore> = cx.update(|cx| {
 49                EditPredictionStore::try_global(cx).context("EditPredictionStore not initialized")
 50            })?;
 51
 52            let state = example.state.as_ref().context("state must be set")?;
 53            let snapshot = state.buffer.read_with(&cx, |buffer, _| buffer.snapshot());
 54            let project = state.project.clone();
 55            let (_, input) =
 56                ep_store.update(&mut cx, |ep_store: &mut EditPredictionStore, cx| {
 57                    let events = ep_store
 58                        .edit_history_for_project(&project, cx)
 59                        .into_iter()
 60                        .map(|e| e.event)
 61                        .collect();
 62                    anyhow::Ok(zeta2_prompt_input(
 63                        &snapshot,
 64                        example
 65                            .context
 66                            .as_ref()
 67                            .context("context must be set")?
 68                            .files
 69                            .clone(),
 70                        events,
 71                        example.spec.cursor_path.clone(),
 72                        example
 73                            .buffer
 74                            .as_ref()
 75                            .context("buffer must be set")?
 76                            .cursor_offset,
 77                    ))
 78                })?;
 79            let prompt = format_zeta_prompt(&input);
 80            let expected_output = zeta2_output_for_patch(
 81                &input,
 82                &example
 83                    .spec
 84                    .expected_patches
 85                    .first()
 86                    .context("expected patches is empty")?
 87                    .clone(),
 88            )?;
 89            example.prompt = Some(ExamplePrompt {
 90                input: prompt,
 91                expected_output,
 92                format: prompt_format,
 93            });
 94        }
 95    };
 96    Ok(())
 97}
 98
 99pub fn zeta2_output_for_patch(input: &zeta_prompt::ZetaPromptInput, patch: &str) -> Result<String> {
100    let mut old_editable_region =
101        input.cursor_excerpt[input.editable_range_in_excerpt.clone()].to_string();
102
103    if !old_editable_region.ends_with_newline() {
104        old_editable_region.push('\n');
105    }
106
107    edit_prediction::udiff::apply_diff_to_string(patch, &old_editable_region).with_context(|| {
108        format!(
109            "Patch:\n```\n{}```\n\nEditable region:\n```\n{}```",
110            patch, old_editable_region
111        )
112    })
113}
114
115pub struct TeacherPrompt;
116
117impl TeacherPrompt {
118    const PROMPT: &str = include_str!("teacher.prompt.md");
119    pub(crate) const EDITABLE_REGION_START: &str = "<|editable_region_start|>\n";
120    pub(crate) const EDITABLE_REGION_END: &str = "\n<|editable_region_end|>";
121    pub(crate) const USER_CURSOR_MARKER: &str = "<|user_cursor|>";
122
123    /// Truncate edit history to this number of last lines
124    const MAX_HISTORY_LINES: usize = 128;
125
126    pub fn format_prompt(example: &Example) -> String {
127        let edit_history = Self::format_edit_history(&example.spec.edit_history);
128        let context = Self::format_context(example);
129        let cursor_excerpt = Self::format_cursor_excerpt(example);
130
131        let prompt = Self::PROMPT
132            .replace("{{context}}", &context)
133            .replace("{{edit_history}}", &edit_history)
134            .replace("{{cursor_excerpt}}", &cursor_excerpt);
135
136        prompt
137    }
138
139    pub fn parse(example: &Example, response: &str) -> Result<String> {
140        // Ideally, we should always be able to find cursor position in the retrieved context.
141        // In reality, sometimes we don't find it for these reasons:
142        // 1. `example.cursor_position` contains _more_ context than included in the retrieved context
143        //    (can be fixed by getting cursor coordinates at the load_example stage)
144        // 2. Context retriever just didn't include cursor line.
145        //
146        // In that case, fallback to using `cursor_position` as excerpt.
147        let example_buffer = example
148            .buffer
149            .as_ref()
150            .context("`buffer` should be filled in in the context collection step")?;
151
152        // Extract updated (new) editable region from the model response.
153        // The model may include editable region markers in its output, so we need to strip them.
154        let new_editable_region = extract_last_codeblock(response);
155        let mut new_editable_region = Self::extract_editable_region(&new_editable_region);
156
157        let old_editable_region =
158            example_buffer.content[example_buffer.editable_range.clone()].to_string();
159
160        // Normalize leading newlines: if old starts with newline but new doesn't,
161        // prepend newline to new to preserve whitespace structure.
162        // This handles the case where the model drops the leading blank line.
163        if old_editable_region.starts_with('\n') && !new_editable_region.starts_with('\n') {
164            new_editable_region.insert(0, '\n');
165        }
166
167        let editable_region_start_line = example_buffer.content
168            [..example_buffer.editable_range.start]
169            .matches('\n')
170            .count();
171
172        let diff = language::unified_diff_with_offsets(
173            &old_editable_region,
174            &new_editable_region,
175            editable_region_start_line as u32,
176            editable_region_start_line as u32,
177        );
178
179        let diff = indoc::formatdoc! {"
180            --- a/{path}
181            +++ b/{path}
182            {diff}",
183            path = example.spec.cursor_path.to_string_lossy(),
184            diff = diff,
185        };
186
187        Ok(diff)
188    }
189
190    fn format_edit_history(edit_history: &str) -> String {
191        // Strip comments ("garbage lines") from edit history
192        let lines = edit_history
193            .lines()
194            .filter(|&s| Self::is_udiff_content_line(s))
195            .collect::<Vec<_>>();
196
197        let history_lines = if lines.len() > Self::MAX_HISTORY_LINES {
198            &lines[lines.len() - Self::MAX_HISTORY_LINES..]
199        } else {
200            &lines
201        };
202
203        if history_lines.is_empty() {
204            return "(No edit history)".to_string();
205        }
206
207        history_lines.join("\n")
208    }
209
210    fn format_context(example: &Example) -> String {
211        let context = example
212            .context
213            .as_ref()
214            .expect("Missing context retriever step");
215
216        if context.files.is_empty() {
217            return "(No context)".to_string();
218        }
219
220        let mut prompt = String::new();
221        for file in context.files.as_ref() {
222            let path_str = file.path.to_string_lossy();
223            writeln!(&mut prompt, "`````{path_str}").ok();
224            let mut prev_row = 0;
225            for excerpt in &file.excerpts {
226                if excerpt.row_range.start > prev_row {
227                    prompt.push_str("\n");
228                }
229                prompt.push_str(&excerpt.text);
230                prompt.push('\n');
231                prev_row = excerpt.row_range.end;
232            }
233            if prev_row < file.max_row {
234                prompt.push_str("\n");
235            }
236            prompt.push_str("\n`````");
237        }
238
239        prompt
240    }
241
242    fn format_cursor_excerpt(example: &Example) -> String {
243        let mut result = String::new();
244
245        let example_buffer = example.buffer.as_ref().unwrap();
246
247        let path_str = example.spec.cursor_path.to_string_lossy();
248        result.push_str(&format!("`````{path_str}\n"));
249        result.push_str(
250            &example_buffer.content
251                [example_buffer.context_range.start..example_buffer.editable_range.start],
252        );
253        result.push_str(Self::EDITABLE_REGION_START);
254        result.push_str(
255            &example_buffer.content
256                [example_buffer.editable_range.start..example_buffer.cursor_offset],
257        );
258        result.push_str(Self::USER_CURSOR_MARKER);
259        result.push_str(
260            &example_buffer.content
261                [example_buffer.cursor_offset..example_buffer.editable_range.end],
262        );
263        result.push_str(Self::EDITABLE_REGION_END);
264        result.push_str(
265            &example_buffer.content
266                [example_buffer.editable_range.end..example_buffer.context_range.end],
267        );
268        result.push_str("\n`````");
269
270        result
271    }
272
273    fn extract_editable_region(text: &str) -> String {
274        let start = text
275            .find(Self::EDITABLE_REGION_START)
276            .map_or(0, |pos| pos + Self::EDITABLE_REGION_START.len());
277        let end = text.find(Self::EDITABLE_REGION_END).unwrap_or(text.len());
278
279        let region = &text[start..end];
280        let region = region.strip_suffix('\n').unwrap_or(region);
281
282        region.replace(Self::USER_CURSOR_MARKER, "")
283    }
284
285    fn is_udiff_content_line(s: &str) -> bool {
286        s.starts_with("-")
287            || s.starts_with("+")
288            || s.starts_with(" ")
289            || s.starts_with("---")
290            || s.starts_with("+++")
291            || s.starts_with("@@")
292    }
293}
294
/// Returns the contents of the last fenced code block in `text`, or all of `text` when no
/// complete block is found.
///
/// A fence opens at any run of three or more backticks; the remainder of that line is skipped.
/// The block closes at the first later line that begins with a backtick run of the same length,
/// so shorter runs inside the block do not terminate it. The returned contents include the
/// newline preceding the closing fence. Scanning stops at the first fence that is never closed.
fn extract_last_codeblock(text: &str) -> String {
    let mut latest: Option<&str> = None;
    let mut cursor = 0;

    while let Some(offset) = text[cursor..].find("```") {
        let fence_start = cursor + offset;
        let fence_len = text[fence_start..]
            .bytes()
            .take_while(|&byte| byte == b'`')
            .count();
        let closing_fence = format!("\n{}", "`".repeat(fence_len));

        // Skip whatever follows the backticks on the opening line (e.g. an info string).
        let after_fence = fence_start + fence_len;
        let line_end = text[after_fence..]
            .find('\n')
            .map_or(text.len(), |newline| after_fence + newline);

        let Some(close_offset) = text[line_end..].find(&closing_fence) else {
            break;
        };

        // `line_end` and the match both sit on a newline: the body starts just after the
        // former and ends just after the latter.
        let close_at = line_end + close_offset;
        latest = Some(&text[line_end + 1..close_at + 1]);
        cursor = close_at + closing_fence.len();
    }

    latest.map_or_else(|| text.to_string(), str::to_string)
}
326
#[cfg(test)]
mod tests {
    use super::*;

    // --- extract_last_codeblock ---

    /// With several fenced blocks in the response, only the final one is returned.
    #[test]
    fn test_extract_last_code_block() {
        let response = concat!(
            "Some thinking\n",
            "\n",
            "```\n",
            "first block\n",
            "```\n",
            "\n",
            "`````path='something' lines=1:2\n",
            "last block\n",
            "`````\n",
        );
        assert_eq!(extract_last_codeblock(response), "last block\n");
    }

    /// Three-backtick runs inside a five-backtick block do not terminate it.
    #[test]
    fn test_extract_codeblock_with_nested_fences() {
        let response = concat!(
            "`````\n",
            "content with ``` inline\n",
            "and ```python nested\n",
            "more content\n",
            "`````\n",
        );
        let expected = "content with ``` inline\nand ```python nested\nmore content\n";
        assert_eq!(extract_last_codeblock(response), expected);
    }

    /// Single backticks and mid-line backtick runs are kept as block content.
    #[test]
    fn test_extract_codeblock_ignores_inline_backticks() {
        let response = concat!(
            "`````\n",
            "here is some `code` with inline backticks\n",
            "and here```more```stuff\n",
            "`````\n",
        );
        let expected = "here is some `code` with inline backticks\nand here```more```stuff\n";
        assert_eq!(extract_last_codeblock(response), expected);
    }

    /// A complete nested ```bibtex block is preserved inside the outer block.
    #[test]
    fn test_extract_last_codeblock_nested_bibtex() {
        let expected = indoc::indoc! {r#"
            ## Collaborations
            Our mission is to create a 4D generative model.

            ## Citation

            If you found Unique3D helpful, please cite our report:
            ```bibtex
            @misc{wu2024unique3d,
                  title={Unique3D},
            }
            ```
            "#};
        let response = indoc::indoc! {r#"
            Looking at the edit history, I can see that a Citation section was just added.

            `````
            ## Collaborations
            Our mission is to create a 4D generative model.

            ## Citation

            If you found Unique3D helpful, please cite our report:
            ```bibtex
            @misc{wu2024unique3d,
                  title={Unique3D},
            }
            ```
            `````
            "#};
        assert_eq!(extract_last_codeblock(response), expected);
    }

    // --- TeacherPrompt::extract_editable_region ---

    /// Text outside the markers is dropped, along with one trailing newline of the region.
    #[test]
    fn test_extract_editable_region() {
        let response = concat!(
            "some lines\n",
            "are\n",
            "here\n",
            "<|editable_region_start|>\n",
            "one\n",
            "two three\n",
            "\n",
            "<|editable_region_end|>\n",
            "more\n",
            "lines here\n",
        );
        let region = TeacherPrompt::extract_editable_region(response);
        assert_eq!(region, "one\ntwo three");
    }

    /// Without markers the whole input is the region.
    #[test]
    fn test_extract_editable_region_no_markers() {
        let response = "one\ntwo three";
        let region = TeacherPrompt::extract_editable_region(response);
        assert_eq!(region, "one\ntwo three");
    }

    /// The cursor marker is removed from the extracted region.
    #[test]
    fn test_extract_editable_region_strips_cursor_marker() {
        let response = concat!(
            "<|editable_region_start|>\n",
            "one\n",
            "<|user_cursor|>two three\n",
            "\n",
            "<|editable_region_end|>\n",
        );
        let region = TeacherPrompt::extract_editable_region(response);
        assert_eq!(region, "one\ntwo three");
    }
}