format_prompt.rs

  1use crate::{
  2    FormatPromptArgs, PredictionProvider,
  3    example::{Example, ExamplePrompt},
  4    headless::EpAppState,
  5    progress::{ExampleProgress, Step},
  6    retrieve_context::run_context_retrieval,
  7};
  8use anyhow::{Context as _, Result};
  9use edit_prediction::cursor_excerpt::editable_and_context_ranges_for_cursor_position;
 10use gpui::{AppContext, AsyncApp};
 11use language::{Buffer, OffsetRangeExt, Point};
 12use similar::DiffableStr;
 13use std::sync::Arc;
 14use std::{fmt::Write as _, ops::Range};
 15use zeta_prompt::format_zeta_prompt;
 16
 17pub async fn run_format_prompt(
 18    example: &mut Example,
 19    args: &FormatPromptArgs,
 20    app_state: Arc<EpAppState>,
 21    example_progress: &ExampleProgress,
 22    cx: AsyncApp,
 23) -> Result<()> {
 24    run_context_retrieval(example, app_state.clone(), example_progress, cx.clone()).await?;
 25
 26    let step_progress = example_progress.start(Step::FormatPrompt);
 27
 28    let prompt_inputs = example
 29        .prompt_inputs
 30        .as_ref()
 31        .context("prompt_inputs must be set after context retrieval")?;
 32
 33    let language = app_state
 34        .languages
 35        .load_language_for_file_path(&example.spec.cursor_path)
 36        .await
 37        .ok();
 38    let snapshot_fut = cx.update(|cx| {
 39        Buffer::build_snapshot(
 40            prompt_inputs.content.as_str().into(),
 41            language,
 42            Some(app_state.languages.clone()),
 43            cx,
 44        )
 45    });
 46    let cursor_point = Point::new(prompt_inputs.cursor_row, prompt_inputs.cursor_column);
 47    let snapshot = cx.background_spawn(snapshot_fut).await;
 48
 49    match args.provider {
 50        PredictionProvider::Teacher(version) | PredictionProvider::TeacherNonBatching(version) => {
 51            step_progress.set_substatus("formatting teacher prompt");
 52
 53            let (editable_range, context_range) = editable_and_context_ranges_for_cursor_position(
 54                cursor_point,
 55                &snapshot,
 56                edit_prediction::zeta2::max_editable_tokens(version),
 57                edit_prediction::zeta2::MAX_CONTEXT_TOKENS,
 58            );
 59            let editable_range = editable_range.to_offset(&snapshot);
 60            let context_range = context_range.to_offset(&snapshot);
 61
 62            let prompt = TeacherPrompt::format_prompt(example, editable_range, context_range);
 63            example.prompt = Some(ExamplePrompt {
 64                input: prompt,
 65                expected_output: example
 66                    .spec
 67                    .expected_patches
 68                    .first()
 69                    .cloned()
 70                    .unwrap_or_default(),
 71                provider: args.provider,
 72            });
 73        }
 74        PredictionProvider::Zeta2(version) => {
 75            step_progress.set_substatus("formatting zeta2 prompt");
 76
 77            let (editable_range, context_range) = editable_and_context_ranges_for_cursor_position(
 78                cursor_point,
 79                &snapshot,
 80                edit_prediction::zeta2::max_editable_tokens(version),
 81                edit_prediction::zeta2::MAX_CONTEXT_TOKENS,
 82            );
 83            let editable_range = editable_range.to_offset(&snapshot);
 84            let context_range = context_range.to_offset(&snapshot);
 85
 86            let context_start = context_range.start;
 87            let cursor_offset_in_excerpt = prompt_inputs.cursor_offset - context_start;
 88            let editable_range_in_excerpt =
 89                (editable_range.start - context_start)..(editable_range.end - context_start);
 90            let input = zeta_prompt::ZetaPromptInput {
 91                cursor_path: example.spec.cursor_path.clone(),
 92                cursor_excerpt: prompt_inputs.content[context_range].to_string().into(),
 93                editable_range_in_excerpt,
 94                cursor_offset_in_excerpt,
 95                events: prompt_inputs.edit_history.clone(),
 96                related_files: prompt_inputs.related_files.clone().unwrap_or_default(),
 97            };
 98            let prompt = format_zeta_prompt(&input, version);
 99            let expected_output = zeta2_output_for_patch(
100                &input,
101                &example
102                    .spec
103                    .expected_patches
104                    .first()
105                    .context("expected patches is empty")?
106                    .clone(),
107            )?;
108            example.prompt = Some(ExamplePrompt {
109                input: prompt,
110                expected_output,
111                provider: args.provider,
112            });
113        }
114        _ => {
115            panic!("Cannot format prompt for {:?}", args.provider);
116        }
117    };
118    Ok(())
119}
120
121pub fn zeta2_output_for_patch(input: &zeta_prompt::ZetaPromptInput, patch: &str) -> Result<String> {
122    let mut old_editable_region =
123        input.cursor_excerpt[input.editable_range_in_excerpt.clone()].to_string();
124
125    if !old_editable_region.ends_with_newline() {
126        old_editable_region.push('\n');
127    }
128
129    edit_prediction::udiff::apply_diff_to_string(patch, &old_editable_region).with_context(|| {
130        format!(
131            "Patch:\n```\n{}```\n\nEditable region:\n```\n{}```",
132            patch, old_editable_region
133        )
134    })
135}
136
137pub struct TeacherPrompt;
138
139impl TeacherPrompt {
140    const PROMPT: &str = include_str!("teacher.prompt.md");
141    pub(crate) const EDITABLE_REGION_START: &str = "<|editable_region_start|>\n";
142    pub(crate) const EDITABLE_REGION_END: &str = "\n<|editable_region_end|>";
143    pub(crate) const USER_CURSOR_MARKER: &str = "<|user_cursor|>";
144
145    /// Truncate edit history to this number of last lines
146    const MAX_HISTORY_LINES: usize = 128;
147
148    pub fn format_prompt(
149        example: &Example,
150        editable_range: Range<usize>,
151        context_range: Range<usize>,
152    ) -> String {
153        let edit_history = Self::format_edit_history(&example.spec.edit_history);
154        let context = Self::format_context(example);
155        let cursor_excerpt = Self::format_cursor_excerpt(example, editable_range, context_range);
156
157        let prompt = Self::PROMPT
158            .replace("{{context}}", &context)
159            .replace("{{edit_history}}", &edit_history)
160            .replace("{{cursor_excerpt}}", &cursor_excerpt);
161
162        prompt
163    }
164
165    pub fn parse(example: &Example, response: &str) -> Result<String> {
166        // Extract updated (new) editable region from the model response.
167        // The model may include editable region markers in its output, so we need to strip them.
168        let new_editable_region = extract_last_codeblock(response);
169        let mut new_editable_region = Self::extract_editable_region(&new_editable_region);
170        let old_editable_region = Self::extract_editable_region(
171            &example
172                .prompt
173                .as_ref()
174                .context("example prompt missing")?
175                .input,
176        );
177        let prompt_inputs = example
178            .prompt_inputs
179            .as_ref()
180            .context("example is missing prompt inputs")?;
181
182        // Normalize leading newlines: if old starts with newline but new doesn't,
183        // prepend newline to new to preserve whitespace structure.
184        // This handles the case where the model drops the leading blank line.
185        if old_editable_region.starts_with('\n') && !new_editable_region.starts_with('\n') {
186            new_editable_region.insert(0, '\n');
187        }
188
189        let (editable_region_offset, _) = prompt_inputs
190            .content
191            .match_indices(&old_editable_region)
192            .min_by_key(|(index, _)| index.abs_diff(prompt_inputs.cursor_offset))
193            .unwrap();
194        let editable_region_start_line = prompt_inputs.content[..editable_region_offset]
195            .matches('\n')
196            .count();
197
198        let diff = language::unified_diff_with_offsets(
199            &old_editable_region,
200            &new_editable_region,
201            editable_region_start_line as u32,
202            editable_region_start_line as u32,
203        );
204
205        let diff = indoc::formatdoc! {"
206            --- a/{path}
207            +++ b/{path}
208            {diff}",
209            path = example.spec.cursor_path.to_string_lossy(),
210            diff = diff,
211        };
212
213        Ok(diff)
214    }
215
216    fn format_edit_history(edit_history: &str) -> String {
217        // Strip comments ("garbage lines") from edit history
218        let lines = edit_history
219            .lines()
220            .filter(|&s| Self::is_udiff_content_line(s))
221            .collect::<Vec<_>>();
222
223        let history_lines = if lines.len() > Self::MAX_HISTORY_LINES {
224            &lines[lines.len() - Self::MAX_HISTORY_LINES..]
225        } else {
226            &lines
227        };
228
229        if history_lines.is_empty() {
230            return "(No edit history)".to_string();
231        }
232
233        history_lines.join("\n")
234    }
235
236    fn format_context(example: &Example) -> String {
237        let related_files = example
238            .prompt_inputs
239            .as_ref()
240            .and_then(|pi| pi.related_files.as_ref());
241
242        let Some(related_files) = related_files else {
243            return "(No context)".to_string();
244        };
245
246        if related_files.is_empty() {
247            return "(No context)".to_string();
248        }
249
250        let mut prompt = String::new();
251        for file in related_files {
252            let path_str = file.path.to_string_lossy();
253            writeln!(&mut prompt, "`````{path_str}").ok();
254            let mut prev_row = 0;
255            for excerpt in &file.excerpts {
256                if excerpt.row_range.start > prev_row {
257                    prompt.push_str("\n");
258                }
259                prompt.push_str(&excerpt.text);
260                prompt.push('\n');
261                prev_row = excerpt.row_range.end;
262            }
263            if prev_row < file.max_row {
264                prompt.push_str("\n");
265            }
266            prompt.push_str("\n`````");
267        }
268
269        prompt
270    }
271
272    fn format_cursor_excerpt(
273        example: &Example,
274        editable_range: Range<usize>,
275        context_range: Range<usize>,
276    ) -> String {
277        let mut result = String::new();
278
279        let prompt_inputs = example.prompt_inputs.as_ref().unwrap();
280
281        let path_str = example.spec.cursor_path.to_string_lossy();
282        result.push_str(&format!("`````{path_str}\n"));
283        result.push_str(&prompt_inputs.content[context_range.start..editable_range.start]);
284        result.push_str(Self::EDITABLE_REGION_START);
285        result.push_str(&prompt_inputs.content[editable_range.start..prompt_inputs.cursor_offset]);
286        result.push_str(Self::USER_CURSOR_MARKER);
287        result.push_str(&prompt_inputs.content[prompt_inputs.cursor_offset..editable_range.end]);
288        result.push_str(Self::EDITABLE_REGION_END);
289        result.push_str(&prompt_inputs.content[editable_range.end..context_range.end]);
290        result.push_str("\n`````");
291
292        result
293    }
294
295    fn extract_editable_region(text: &str) -> String {
296        let start = text
297            .find(Self::EDITABLE_REGION_START)
298            .map_or(0, |pos| pos + Self::EDITABLE_REGION_START.len());
299        let end = text.find(Self::EDITABLE_REGION_END).unwrap_or(text.len());
300
301        let region = &text[start..end];
302        let region = region.strip_suffix('\n').unwrap_or(region);
303
304        region.replace(Self::USER_CURSOR_MARKER, "")
305    }
306
307    fn is_udiff_content_line(s: &str) -> bool {
308        s.starts_with("-")
309            || s.starts_with("+")
310            || s.starts_with(" ")
311            || s.starts_with("---")
312            || s.starts_with("+++")
313            || s.starts_with("@@")
314    }
315}
316
/// Returns the body of the last complete fenced code block in `text`.
///
/// An opening fence is a run of three or more backticks; the rest of that line
/// (the info string) is skipped. The block ends at the first later line that
/// starts with a newline followed by the same number of backticks. The returned
/// body includes its trailing newline. When no complete block is found, the whole
/// `text` is returned.
fn extract_last_codeblock(text: &str) -> String {
    let bytes = text.as_bytes();
    let mut latest: Option<&str> = None;
    let mut cursor = 0;

    while let Some(relative) = text[cursor..].find("```") {
        let fence_start = cursor + relative;
        let fence_len = bytes[fence_start..]
            .iter()
            .take_while(|&&byte| byte == b'`')
            .count();
        let closing_fence = format!("\n{}", "`".repeat(fence_len));

        // Skip the info string; `line_end` lands on the '\n' ending the fence line
        // (or on the end of `text`).
        let info_start = fence_start + fence_len;
        let line_end = info_start
            + bytes[info_start..]
                .iter()
                .take_while(|&&byte| byte != b'\n')
                .count();

        let Some(close_offset) = text[line_end..].find(&closing_fence) else {
            break;
        };
        latest = Some(&text[line_end + 1..line_end + close_offset + 1]);
        cursor = line_end + close_offset + closing_fence.len();
    }

    match latest {
        Some(block) => block.to_owned(),
        None => text.to_string(),
    }
}
348
#[cfg(test)]
mod tests {
    use super::*;

    /// The last of several fenced blocks wins, even when fences differ in length.
    #[test]
    fn test_extract_last_code_block() {
        let response = indoc::indoc! {"
            Some thinking

            ```
            first block
            ```

            `````path='something' lines=1:2
            last block
            `````
            "};
        assert_eq!(extract_last_codeblock(response), "last block\n");
    }

    /// Shorter backtick runs inside a longer fence do not close it.
    #[test]
    fn test_extract_codeblock_with_nested_fences() {
        let response = indoc::indoc! {"
            `````
            content with ``` inline
            and ```python nested
            more content
            `````
            "};
        let expected = "content with ``` inline\nand ```python nested\nmore content\n";
        assert_eq!(extract_last_codeblock(response), expected);
    }

    /// Backticks that are not at the start of a line never close a fence.
    #[test]
    fn test_extract_codeblock_ignores_inline_backticks() {
        let response = indoc::indoc! {"
            `````
            here is some `code` with inline backticks
            and here```more```stuff
            `````
            "};
        let expected = "here is some `code` with inline backticks\nand here```more```stuff\n";
        assert_eq!(extract_last_codeblock(response), expected);
    }

    /// Text between the markers is returned without the trailing newline.
    #[test]
    fn test_extract_editable_region() {
        let prompt = indoc::indoc! {"
            some lines
            are
            here
            <|editable_region_start|>
            one
            two three

            <|editable_region_end|>
            more
            lines here
            "};
        let expected = indoc::indoc! {"
            one
            two three"};
        assert_eq!(TeacherPrompt::extract_editable_region(prompt), expected);
    }

    /// A three-backtick block nested in a five-backtick block stays in the body.
    #[test]
    fn test_extract_last_codeblock_nested_bibtex() {
        let response = indoc::indoc! {r#"
            Looking at the edit history, I can see that a Citation section was just added.

            `````
            ## Collaborations
            Our mission is to create a 4D generative model.

            ## Citation

            If you found Unique3D helpful, please cite our report:
            ```bibtex
            @misc{wu2024unique3d,
                  title={Unique3D},
            }
            ```
            `````
            "#};
        let expected = indoc::indoc! {r#"
            ## Collaborations
            Our mission is to create a 4D generative model.

            ## Citation

            If you found Unique3D helpful, please cite our report:
            ```bibtex
            @misc{wu2024unique3d,
                  title={Unique3D},
            }
            ```
            "#};
        assert_eq!(extract_last_codeblock(response), expected);
    }

    /// Without markers the whole text is the editable region.
    #[test]
    fn test_extract_editable_region_no_markers() {
        let unmarked = indoc::indoc! {"
            one
            two three"};
        let expected = indoc::indoc! {"
            one
            two three"};
        assert_eq!(TeacherPrompt::extract_editable_region(unmarked), expected);
    }

    /// Cursor markers inside the region are removed.
    #[test]
    fn test_extract_editable_region_strips_cursor_marker() {
        let marked = indoc::indoc! {"
            <|editable_region_start|>
            one
            <|user_cursor|>two three

            <|editable_region_end|>
            "};
        let expected = indoc::indoc! {"
            one
            two three"};
        assert_eq!(TeacherPrompt::extract_editable_region(marked), expected);
    }
}