format_prompt.rs

  1use crate::{
  2    PromptFormat,
  3    example::{Example, ExamplePrompt},
  4    headless::EpAppState,
  5    load_project::run_load_project,
  6    progress::{Progress, Step},
  7    retrieve_context::run_context_retrieval,
  8};
  9use anyhow::{Context as _, Result, ensure};
 10use edit_prediction::{
 11    EditPredictionStore,
 12    zeta2::{zeta2_output_for_patch, zeta2_prompt_input},
 13};
 14use gpui::AsyncApp;
 15use std::sync::Arc;
 16use zeta_prompt::format_zeta_prompt;
 17
 18pub async fn run_format_prompt(
 19    example: &mut Example,
 20    prompt_format: PromptFormat,
 21    app_state: Arc<EpAppState>,
 22    mut cx: AsyncApp,
 23) -> Result<()> {
 24    run_context_retrieval(example, app_state.clone(), cx.clone()).await?;
 25
 26    let step_progress = Progress::global().start(Step::FormatPrompt, &example.spec.name);
 27
 28    match prompt_format {
 29        PromptFormat::Teacher => {
 30            step_progress.set_substatus("formatting teacher prompt");
 31            let prompt = TeacherPrompt::format_prompt(example);
 32            example.prompt = Some(ExamplePrompt {
 33                input: prompt,
 34                expected_output: example
 35                    .spec
 36                    .expected_patches
 37                    .first()
 38                    .cloned()
 39                    .unwrap_or_default(),
 40                format: prompt_format,
 41            });
 42        }
 43        PromptFormat::Zeta2 => {
 44            step_progress.set_substatus("loading project");
 45            run_load_project(example, app_state, cx.clone()).await?;
 46
 47            step_progress.set_substatus("formatting zeta2 prompt");
 48
 49            let ep_store = cx.update(|cx| {
 50                EditPredictionStore::try_global(cx).context("EditPredictionStore not initialized")
 51            })??;
 52
 53            let state = example.state.as_ref().context("state must be set")?;
 54            let snapshot = state.buffer.read_with(&cx, |buffer, _| buffer.snapshot())?;
 55            let project = state.project.clone();
 56            let (_, input) = ep_store.update(&mut cx, |ep_store, cx| {
 57                let events = ep_store
 58                    .edit_history_for_project(&project, cx)
 59                    .into_iter()
 60                    .map(|e| e.event)
 61                    .collect();
 62                anyhow::Ok(zeta2_prompt_input(
 63                    &snapshot,
 64                    example
 65                        .context
 66                        .as_ref()
 67                        .context("context must be set")?
 68                        .files
 69                        .clone(),
 70                    events,
 71                    example.spec.cursor_path.clone(),
 72                    example
 73                        .buffer
 74                        .as_ref()
 75                        .context("buffer must be set")?
 76                        .cursor_offset,
 77                ))
 78            })??;
 79            let prompt = format_zeta_prompt(&input);
 80            let expected_output = zeta2_output_for_patch(
 81                &input,
 82                &example
 83                    .spec
 84                    .expected_patches
 85                    .first()
 86                    .context("expected patches is empty")?
 87                    .clone(),
 88            )?;
 89            example.prompt = Some(ExamplePrompt {
 90                input: prompt,
 91                expected_output,
 92                format: prompt_format,
 93            });
 94        }
 95    };
 96    Ok(())
 97}
 98
 99pub struct TeacherPrompt;
100
101impl TeacherPrompt {
102    const PROMPT: &str = include_str!("teacher.prompt.md");
103    pub(crate) const EDITABLE_REGION_START: &str = "<|editable_region_start|>\n";
104    pub(crate) const EDITABLE_REGION_END: &str = "<|editable_region_end|>";
105    pub(crate) const USER_CURSOR_MARKER: &str = "<|user_cursor|>";
106
107    /// Truncate edit history to this number of last lines
108    const MAX_HISTORY_LINES: usize = 128;
109
110    pub fn format_prompt(example: &Example) -> String {
111        let edit_history = Self::format_edit_history(&example.spec.edit_history);
112        let context = Self::format_context(example);
113        let editable_region = Self::format_editable_region(example);
114
115        let prompt = Self::PROMPT
116            .replace("{{context}}", &context)
117            .replace("{{edit_history}}", &edit_history)
118            .replace("{{editable_region}}", &editable_region);
119
120        prompt
121    }
122
123    pub fn parse(example: &Example, response: &str) -> Result<String> {
124        // Ideally, we should always be able to find cursor position in the retrieved context.
125        // In reality, sometimes we don't find it for these reasons:
126        // 1. `example.cursor_position` contains _more_ context than included in the retrieved context
127        //    (can be fixed by getting cursor coordinates at the load_example stage)
128        // 2. Context retriever just didn't include cursor line.
129        //
130        // In that case, fallback to using `cursor_position` as excerpt.
131        let example_buffer = example
132            .buffer
133            .as_ref()
134            .context("`buffer` should be filled in in the context collection step")?;
135        let cursor_file = &example_buffer.content;
136
137        // Extract updated (new) editable region from the model response.
138        // The model may include editable region markers in its output, so we need to strip them.
139        let new_editable_region = extract_last_codeblock(response);
140        let mut new_editable_region = Self::extract_editable_region(&new_editable_region);
141
142        let old_editable_region =
143            example_buffer.content[example_buffer.editable_range.clone()].to_string();
144
145        // Normalize leading newlines: if old starts with newline but new doesn't,
146        // prepend newline to new to preserve whitespace structure.
147        // This handles the case where the model drops the leading blank line.
148        if old_editable_region.starts_with('\n') && !new_editable_region.starts_with('\n') {
149            new_editable_region.insert(0, '\n');
150        }
151
152        ensure!(
153            cursor_file.contains(&old_editable_region),
154            "Something's wrong: editable_region is not found in the cursor file"
155        );
156
157        // Apply editable region to a larger context and compute diff.
158        // This is needed to get a better context lines around the editable region
159        let edited_file = cursor_file.replace(&old_editable_region, &new_editable_region);
160        let diff = language::unified_diff(&cursor_file, &edited_file);
161
162        let diff = indoc::formatdoc! {"
163            --- a/{path}
164            +++ b/{path}
165            {diff}",
166            path = example.spec.cursor_path.to_string_lossy(),
167            diff = diff,
168        };
169
170        Ok(diff)
171    }
172
173    fn format_edit_history(edit_history: &str) -> String {
174        // Strip comments ("garbage lines") from edit history
175        let lines = edit_history
176            .lines()
177            .filter(|&s| Self::is_udiff_content_line(s))
178            .collect::<Vec<_>>();
179
180        let history_lines = if lines.len() > Self::MAX_HISTORY_LINES {
181            &lines[lines.len() - Self::MAX_HISTORY_LINES..]
182        } else {
183            &lines
184        };
185
186        if history_lines.is_empty() {
187            return "(No edit history)".to_string();
188        }
189
190        history_lines.join("\n")
191    }
192
193    fn format_context(example: &Example) -> String {
194        assert!(example.context.is_some(), "Missing context retriever step");
195
196        let mut prompt = String::new();
197        zeta_prompt::write_related_files(&mut prompt, &example.context.as_ref().unwrap().files);
198
199        prompt
200    }
201
202    fn format_editable_region(example: &Example) -> String {
203        let mut result = String::new();
204
205        let example_buffer = example.buffer.as_ref().unwrap();
206
207        let path_str = example.spec.cursor_path.to_string_lossy();
208        result.push_str(&format!("`````path=\"{path_str}\"\n"));
209        result.push_str(
210            &example_buffer.content
211                [example_buffer.context_range.start..example_buffer.editable_range.start],
212        );
213        result.push_str(Self::EDITABLE_REGION_START);
214        result.push_str(
215            &example_buffer.content
216                [example_buffer.editable_range.start..example_buffer.cursor_offset],
217        );
218        result.push_str(Self::USER_CURSOR_MARKER);
219        result.push_str(
220            &example_buffer.content
221                [example_buffer.cursor_offset..example_buffer.editable_range.end],
222        );
223        result.push_str(Self::EDITABLE_REGION_END);
224        result.push_str(
225            &example_buffer.content
226                [example_buffer.editable_range.end..example_buffer.context_range.end],
227        );
228        result.push_str("\n`````");
229
230        result
231    }
232
233    fn extract_editable_region(text: &str) -> String {
234        let start = text
235            .find(Self::EDITABLE_REGION_START)
236            .map_or(0, |pos| pos + Self::EDITABLE_REGION_START.len());
237        let end = text.find(Self::EDITABLE_REGION_END).unwrap_or(text.len());
238
239        let region = &text[start..end];
240        let region = region.strip_suffix('\n').unwrap_or(region);
241
242        region.replace("<|user_cursor|>", "")
243    }
244
245    fn is_udiff_content_line(s: &str) -> bool {
246        s.starts_with("-")
247            || s.starts_with("+")
248            || s.starts_with(" ")
249            || s.starts_with("---")
250            || s.starts_with("+++")
251            || s.starts_with("@@")
252    }
253}
254
/// Returns the content of the last closed fenced code block in `text`, or all of
/// `text` when no closed block is found.
///
/// An opening fence is a run of three or more backticks at the start of a line; the
/// rest of that line (the info string) is not part of the content. The block ends at
/// the next line that starts with the same number of backticks, so a fence with more
/// backticks can wrap content that itself contains shorter fences. The returned
/// content keeps the newline that precedes the closing fence.
///
/// Backtick runs that do not start a line are skipped. Without that rule, backticks
/// quoted inline in the prose before the answer would be taken as an opening fence and
/// paired with the real opening fence as if it were the closing one.
fn extract_last_codeblock(text: &str) -> String {
    let bytes = text.as_bytes();
    let mut last_block = None;
    let mut search_start = 0;

    while let Some(offset) = text[search_start..].find("```") {
        let fence_start = search_start + offset;
        let fence_len = bytes[fence_start..]
            .iter()
            .take_while(|&&byte| byte == b'`')
            .count();
        let fence_end = fence_start + fence_len;

        // Inline backticks cannot open a block; resume scanning after the whole run.
        if fence_start > 0 && bytes[fence_start - 1] != b'\n' {
            search_start = fence_end;
            continue;
        }

        // Skip the info string. A fence on the final, unterminated line has no content.
        let Some(info_len) = text[fence_end..].find('\n') else {
            break;
        };
        let content_start = fence_end + info_len + 1;

        // Start the search at the newline that ends the opening line, so that a
        // closing fence on the very next line (an empty block) is still found.
        let closing_fence = format!("\n{}", "`".repeat(fence_len));
        let Some(close_offset) = text[content_start - 1..].find(&closing_fence) else {
            break;
        };
        let content_end = content_start + close_offset;

        last_block = Some(&text[content_start..content_end]);
        search_start = content_end + fence_len;
    }

    last_block.unwrap_or(text).to_string()
}
286
#[cfg(test)]
mod tests {
    use super::*;

    // With several fenced blocks in a response, only the final one is returned and
    // the info string on its opening fence is dropped.
    #[test]
    fn test_extract_last_code_block() {
        let response = indoc::indoc! {"
            Some thinking

            ```
            first block
            ```

            `````path='something' lines=1:2
            last block
            `````
            "};

        assert_eq!(extract_last_codeblock(response), "last block\n");
    }

    // Shorter backtick runs inside a five-backtick block are plain content.
    #[test]
    fn test_extract_codeblock_with_nested_fences() {
        let response = indoc::indoc! {"
            `````
            content with ``` inline
            and ```python nested
            more content
            `````
            "};

        let extracted = extract_last_codeblock(response);

        assert_eq!(
            extracted,
            "content with ``` inline\nand ```python nested\nmore content\n"
        );
    }

    // Single backticks and mid-line triple backticks inside a block do not end it.
    #[test]
    fn test_extract_codeblock_ignores_inline_backticks() {
        let response = indoc::indoc! {"
            `````
            here is some `code` with inline backticks
            and here```more```stuff
            `````
            "};

        let extracted = extract_last_codeblock(response);

        assert_eq!(
            extracted,
            "here is some `code` with inline backticks\nand here```more```stuff\n"
        );
    }

    // Only the text between the markers is kept, minus one trailing newline.
    #[test]
    fn test_extract_editable_region() {
        let excerpt = indoc::indoc! {"
            some lines
            are
            here
            <|editable_region_start|>
            one
            two three

            <|editable_region_end|>
            more
            lines here
            "};

        let region = TeacherPrompt::extract_editable_region(excerpt);

        assert_eq!(region, "one\ntwo three\n");
    }

    // A complete three-backtick block nested in a five-backtick block is kept verbatim.
    #[test]
    fn test_extract_last_codeblock_nested_bibtex() {
        let response = indoc::indoc! {r#"
            Looking at the edit history, I can see that a Citation section was just added.

            `````
            ## Collaborations
            Our mission is to create a 4D generative model.

            ## Citation

            If you found Unique3D helpful, please cite our report:
            ```bibtex
            @misc{wu2024unique3d,
                  title={Unique3D},
            }
            ```
            `````
            "#};
        let expected = indoc::indoc! {r#"
            ## Collaborations
            Our mission is to create a 4D generative model.

            ## Citation

            If you found Unique3D helpful, please cite our report:
            ```bibtex
            @misc{wu2024unique3d,
                  title={Unique3D},
            }
            ```
            "#};

        assert_eq!(extract_last_codeblock(response), expected);
    }

    // Without markers the whole text is the region, minus one trailing newline.
    #[test]
    fn test_extract_editable_region_no_markers() {
        let excerpt = indoc::indoc! {"
            one
            two three
            "};

        let region = TeacherPrompt::extract_editable_region(excerpt);

        assert_eq!(region, "one\ntwo three");
    }

    // The cursor marker is removed from the extracted region.
    #[test]
    fn test_extract_editable_region_strips_cursor_marker() {
        let excerpt = indoc::indoc! {"
            <|editable_region_start|>
            one
            <|user_cursor|>two three

            <|editable_region_end|>
            "};

        let region = TeacherPrompt::extract_editable_region(excerpt);

        assert_eq!(region, "one\ntwo three\n");
    }
}