format_prompt.rs

  1use crate::{
  2    FormatPromptArgs, PredictionProvider,
  3    example::{Example, ExamplePrompt},
  4    headless::EpAppState,
  5    progress::{ExampleProgress, Step},
  6    retrieve_context::run_context_retrieval,
  7};
  8use anyhow::{Context as _, Result};
  9use edit_prediction::cursor_excerpt::editable_and_context_ranges_for_cursor_position;
 10use gpui::{AppContext, AsyncApp};
 11use language::{Buffer, OffsetRangeExt, Point};
 12use similar::DiffableStr;
 13use std::sync::Arc;
 14use std::{fmt::Write as _, ops::Range};
 15use zeta_prompt::format_zeta_prompt;
 16
 17pub async fn run_format_prompt(
 18    example: &mut Example,
 19    args: &FormatPromptArgs,
 20    app_state: Arc<EpAppState>,
 21    example_progress: &ExampleProgress,
 22    cx: AsyncApp,
 23) -> Result<()> {
 24    run_context_retrieval(example, app_state.clone(), example_progress, cx.clone()).await?;
 25
 26    let step_progress = example_progress.start(Step::FormatPrompt);
 27
 28    let prompt_inputs = example
 29        .prompt_inputs
 30        .as_ref()
 31        .context("prompt_inputs must be set after context retrieval")?;
 32
 33    let language = app_state
 34        .languages
 35        .load_language_for_file_path(&example.spec.cursor_path)
 36        .await
 37        .ok();
 38    let snapshot_fut = cx.update(|cx| {
 39        Buffer::build_snapshot(
 40            prompt_inputs.content.as_str().into(),
 41            language,
 42            Some(app_state.languages.clone()),
 43            cx,
 44        )
 45    });
 46    let cursor_point = Point::new(prompt_inputs.cursor_row, prompt_inputs.cursor_column);
 47    let snapshot = cx.background_spawn(snapshot_fut).await;
 48
 49    match args.provider {
 50        PredictionProvider::Teacher(version) | PredictionProvider::TeacherNonBatching(version) => {
 51            step_progress.set_substatus("formatting teacher prompt");
 52
 53            let (editable_range, context_range) = editable_and_context_ranges_for_cursor_position(
 54                cursor_point,
 55                &snapshot,
 56                edit_prediction::zeta2::max_editable_tokens(version),
 57                edit_prediction::zeta2::MAX_CONTEXT_TOKENS,
 58            );
 59            let editable_range = editable_range.to_offset(&snapshot);
 60            let context_range = context_range.to_offset(&snapshot);
 61
 62            let prompt = TeacherPrompt::format_prompt(example, editable_range, context_range);
 63            example.prompt = Some(ExamplePrompt {
 64                input: prompt,
 65                expected_output: example
 66                    .spec
 67                    .expected_patches
 68                    .first()
 69                    .cloned()
 70                    .unwrap_or_default(),
 71                provider: args.provider,
 72            });
 73        }
 74        PredictionProvider::Zeta2(version) => {
 75            step_progress.set_substatus("formatting zeta2 prompt");
 76
 77            let (editable_range, context_range) = editable_and_context_ranges_for_cursor_position(
 78                cursor_point,
 79                &snapshot,
 80                edit_prediction::zeta2::max_editable_tokens(version),
 81                edit_prediction::zeta2::MAX_CONTEXT_TOKENS,
 82            );
 83            let editable_range = editable_range.to_offset(&snapshot);
 84            let context_range = context_range.to_offset(&snapshot);
 85
 86            let context_start = context_range.start;
 87            let cursor_offset_in_excerpt = prompt_inputs.cursor_offset - context_start;
 88            let editable_range_in_excerpt =
 89                (editable_range.start - context_start)..(editable_range.end - context_start);
 90            let input = zeta_prompt::ZetaPromptInput {
 91                cursor_path: example.spec.cursor_path.clone(),
 92                cursor_excerpt: prompt_inputs.content[context_range].to_string().into(),
 93                editable_range_in_excerpt,
 94                cursor_offset_in_excerpt,
 95                events: prompt_inputs.edit_history.clone(),
 96                related_files: prompt_inputs.related_files.clone().unwrap_or_default(),
 97            };
 98            let prompt = format_zeta_prompt(&input, version);
 99            let expected_output = zeta2_output_for_patch(
100                &input,
101                &example
102                    .spec
103                    .expected_patches
104                    .first()
105                    .context("expected patches is empty")?
106                    .clone(),
107            )?;
108            example.prompt = Some(ExamplePrompt {
109                input: prompt,
110                expected_output,
111                provider: args.provider,
112            });
113        }
114        _ => {
115            panic!("Cannot format prompt for {:?}", args.provider);
116        }
117    };
118    Ok(())
119}
120
121pub fn zeta2_output_for_patch(input: &zeta_prompt::ZetaPromptInput, patch: &str) -> Result<String> {
122    let mut old_editable_region =
123        input.cursor_excerpt[input.editable_range_in_excerpt.clone()].to_string();
124
125    if !old_editable_region.ends_with_newline() {
126        old_editable_region.push('\n');
127    }
128
129    edit_prediction::udiff::apply_diff_to_string(patch, &old_editable_region).with_context(|| {
130        format!(
131            "Patch:\n```\n{}```\n\nEditable region:\n```\n{}```",
132            patch, old_editable_region
133        )
134    })
135}
136
137pub struct TeacherPrompt;
138
139impl TeacherPrompt {
140    const PROMPT: &str = include_str!("teacher.prompt.md");
141    pub(crate) const EDITABLE_REGION_START: &str = "<|editable_region_start|>\n";
142    pub(crate) const EDITABLE_REGION_END: &str = "\n<|editable_region_end|>";
143    pub(crate) const USER_CURSOR_MARKER: &str = "<|user_cursor|>";
144
145    /// Truncate edit history to this number of last lines
146    const MAX_HISTORY_LINES: usize = 128;
147
148    pub fn format_prompt(
149        example: &Example,
150        editable_range: Range<usize>,
151        context_range: Range<usize>,
152    ) -> String {
153        let edit_history = Self::format_edit_history(&example.spec.edit_history);
154        let context = Self::format_context(example);
155        let cursor_excerpt = Self::format_cursor_excerpt(example, editable_range, context_range);
156
157        let prompt = Self::PROMPT
158            .replace("{{context}}", &context)
159            .replace("{{edit_history}}", &edit_history)
160            .replace("{{cursor_excerpt}}", &cursor_excerpt);
161
162        prompt
163    }
164
165    pub fn parse(example: &Example, response: &str) -> Result<String> {
166        // Extract updated (new) editable region from the model response.
167        // The model may include editable region markers in its output, so we need to strip them.
168        let new_editable_region = extract_last_codeblock(response);
169        let mut new_editable_region = Self::extract_editable_region(&new_editable_region);
170        let old_editable_region = Self::extract_editable_region(
171            &example
172                .prompt
173                .as_ref()
174                .context("example prompt missing")?
175                .input,
176        );
177        let prompt_inputs = example
178            .prompt_inputs
179            .as_ref()
180            .context("example is missing prompt inputs")?;
181
182        // Normalize leading newlines: if old starts with newline but new doesn't,
183        // prepend newline to new to preserve whitespace structure.
184        // This handles the case where the model drops the leading blank line.
185        if old_editable_region.starts_with('\n') && !new_editable_region.starts_with('\n') {
186            new_editable_region.insert(0, '\n');
187        }
188
189        let (editable_region_offset, _) = prompt_inputs
190            .content
191            .match_indices(&old_editable_region)
192            .min_by_key(|(index, _)| index.abs_diff(prompt_inputs.cursor_offset))
193            .unwrap();
194        let editable_region_start_line = prompt_inputs.content[..editable_region_offset]
195            .matches('\n')
196            .count();
197
198        let diff = language::unified_diff_with_offsets(
199            &old_editable_region,
200            &new_editable_region,
201            editable_region_start_line as u32,
202            editable_region_start_line as u32,
203        );
204
205        let diff = indoc::formatdoc! {"
206            --- a/{path}
207            +++ b/{path}
208            {diff}",
209            path = example.spec.cursor_path.to_string_lossy(),
210            diff = diff,
211        };
212
213        Ok(diff)
214    }
215
216    fn format_edit_history(edit_history: &str) -> String {
217        // Strip comments ("garbage lines") from edit history
218        let lines = edit_history
219            .lines()
220            .filter(|&s| Self::is_udiff_content_line(s))
221            .collect::<Vec<_>>();
222
223        let history_lines = if lines.len() > Self::MAX_HISTORY_LINES {
224            &lines[lines.len() - Self::MAX_HISTORY_LINES..]
225        } else {
226            &lines
227        };
228
229        if history_lines.is_empty() {
230            return "(No edit history)".to_string();
231        }
232
233        history_lines.join("\n")
234    }
235
236    fn format_context(example: &Example) -> String {
237        let related_files = example
238            .prompt_inputs
239            .as_ref()
240            .and_then(|pi| pi.related_files.as_ref());
241
242        let Some(related_files) = related_files else {
243            return "(No context)".to_string();
244        };
245
246        if related_files.is_empty() {
247            return "(No context)".to_string();
248        }
249
250        let mut prompt = String::new();
251        for file in related_files {
252            let path_str = file.path.to_string_lossy();
253            writeln!(&mut prompt, "`````{path_str}").ok();
254
255            let mut prev_row = 0;
256            for excerpt in &file.excerpts {
257                if excerpt.row_range.start > prev_row {
258                    prompt.push_str("\n");
259                }
260                prompt.push_str(&excerpt.text);
261                prompt.push('\n');
262                prev_row = excerpt.row_range.end;
263            }
264            if prev_row < file.max_row {
265                prompt.push_str("\n");
266            }
267            prompt.push_str("\n`````\n");
268        }
269
270        prompt
271    }
272
273    fn format_cursor_excerpt(
274        example: &Example,
275        editable_range: Range<usize>,
276        context_range: Range<usize>,
277    ) -> String {
278        let mut result = String::new();
279
280        let prompt_inputs = example.prompt_inputs.as_ref().unwrap();
281
282        let path_str = example.spec.cursor_path.to_string_lossy();
283        result.push_str(&format!("`````{path_str}\n"));
284        result.push_str(&prompt_inputs.content[context_range.start..editable_range.start]);
285        result.push_str(Self::EDITABLE_REGION_START);
286        result.push_str(&prompt_inputs.content[editable_range.start..prompt_inputs.cursor_offset]);
287        result.push_str(Self::USER_CURSOR_MARKER);
288        result.push_str(&prompt_inputs.content[prompt_inputs.cursor_offset..editable_range.end]);
289        result.push_str(Self::EDITABLE_REGION_END);
290        result.push_str(&prompt_inputs.content[editable_range.end..context_range.end]);
291        result.push_str("\n`````");
292
293        result
294    }
295
296    fn extract_editable_region(text: &str) -> String {
297        let start = text
298            .rfind(Self::EDITABLE_REGION_START)
299            .map_or(0, |pos| pos + Self::EDITABLE_REGION_START.len());
300        let end = text.rfind(Self::EDITABLE_REGION_END).unwrap_or(text.len());
301
302        let region = &text[start..end];
303        let region = region.strip_suffix('\n').unwrap_or(region);
304
305        region.replace(Self::USER_CURSOR_MARKER, "")
306    }
307
308    fn is_udiff_content_line(s: &str) -> bool {
309        s.starts_with("-")
310            || s.starts_with("+")
311            || s.starts_with(" ")
312            || s.starts_with("---")
313            || s.starts_with("+++")
314            || s.starts_with("@@")
315    }
316}
317
/// Returns the body of the last complete fenced code block in `text`.
/// Returns all of `text` when no complete block exists.
///
/// An opening fence is a run of three or more backticks, optionally followed by an info string
/// up to the end of the line. The block ends at the first later occurrence of a newline
/// followed by the same number of backticks. Shorter backtick runs inside the block are
/// therefore treated as content. The returned body includes the newline that precedes the
/// closing fence.
fn extract_last_codeblock(text: &str) -> String {
    let bytes = text.as_bytes();
    let mut cursor = 0;
    let mut last_block: Option<&str> = None;

    while let Some(relative_start) = text[cursor..].find("```") {
        let fence_start = cursor + relative_start;
        let fence_len = bytes[fence_start..]
            .iter()
            .take_while(|&&byte| byte == b'`')
            .count();
        let closing_fence = format!("\n{}", "`".repeat(fence_len));

        // Skip the info string: the body starts after the opening line's newline.
        let info_start = fence_start + fence_len;
        let line_end = bytes[info_start..]
            .iter()
            .position(|&byte| byte == b'\n')
            .map_or(bytes.len(), |offset| info_start + offset);

        let Some(closing_offset) = text[line_end..].find(&closing_fence) else {
            break;
        };
        last_block = Some(&text[line_end + 1..line_end + closing_offset + 1]);
        cursor = line_end + closing_offset + closing_fence.len();
    }

    last_block.unwrap_or(text).to_string()
}
349
/// Unit tests for fenced-code-block extraction and editable-region extraction.
#[cfg(test)]
mod tests {
    use super::*;

    /// With several fenced blocks, only the last one is returned. This holds even when it uses
    /// a longer fence with an info string.
    #[test]
    fn test_extract_last_code_block() {
        let text = indoc::indoc! {"
            Some thinking

            ```
            first block
            ```

            `````path='something' lines=1:2
            last block
            `````
            "};
        let last_block = extract_last_codeblock(text);
        assert_eq!(last_block, "last block\n");
    }

    /// Three-backtick runs inside a five-backtick block do not close it.
    #[test]
    fn test_extract_codeblock_with_nested_fences() {
        let text = indoc::indoc! {"
            `````
            content with ``` inline
            and ```python nested
            more content
            `````
            "};
        let last_block = extract_last_codeblock(text);
        assert_eq!(
            last_block,
            "content with ``` inline\nand ```python nested\nmore content\n"
        );
    }

    /// Backticks that do not start a line are kept as content.
    #[test]
    fn test_extract_codeblock_ignores_inline_backticks() {
        let text = indoc::indoc! {"
            `````
            here is some `code` with inline backticks
            and here```more```stuff
            `````
            "};
        let last_block = extract_last_codeblock(text);
        assert_eq!(
            last_block,
            "here is some `code` with inline backticks\nand here```more```stuff\n"
        );
    }

    /// Only text between the markers is kept. The one trailing newline before the end marker
    /// is dropped.
    #[test]
    fn test_extract_editable_region() {
        let text = indoc::indoc! {"
            some lines
            are
            here
            <|editable_region_start|>
            one
            two three

            <|editable_region_end|>
            more
            lines here
            "};
        let parsed = TeacherPrompt::extract_editable_region(text);
        assert_eq!(
            parsed,
            indoc::indoc! {"
            one
            two three"}
        );
    }

    /// A ```bibtex block nested inside a five-backtick block survives intact.
    #[test]
    fn test_extract_last_codeblock_nested_bibtex() {
        let text = indoc::indoc! {r#"
            Looking at the edit history, I can see that a Citation section was just added.

            `````
            ## Collaborations
            Our mission is to create a 4D generative model.

            ## Citation

            If you found Unique3D helpful, please cite our report:
            ```bibtex
            @misc{wu2024unique3d,
                  title={Unique3D},
            }
            ```
            `````
            "#};
        let last_block = extract_last_codeblock(text);
        assert_eq!(
            last_block,
            indoc::indoc! {r#"
            ## Collaborations
            Our mission is to create a 4D generative model.

            ## Citation

            If you found Unique3D helpful, please cite our report:
            ```bibtex
            @misc{wu2024unique3d,
                  title={Unique3D},
            }
            ```
            "#}
        );
    }

    /// Without markers, the whole text is treated as the editable region.
    #[test]
    fn test_extract_editable_region_no_markers() {
        let text = indoc::indoc! {"
            one
            two three"};
        let parsed = TeacherPrompt::extract_editable_region(text);
        assert_eq!(
            parsed,
            indoc::indoc! {"
            one
            two three"}
        );
    }

    /// The user cursor marker is removed from the extracted region.
    #[test]
    fn test_extract_editable_region_strips_cursor_marker() {
        let text = indoc::indoc! {"
            <|editable_region_start|>
            one
            <|user_cursor|>two three

            <|editable_region_end|>
            "};
        let parsed = TeacherPrompt::extract_editable_region(text);
        assert_eq!(
            parsed,
            indoc::indoc! {"
            one
            two three"}
        );
    }
}