format_prompt.rs

  1use crate::{
  2    FormatPromptArgs, PredictionProvider,
  3    example::{Example, ExamplePrompt},
  4    headless::EpAppState,
  5    progress::{ExampleProgress, Step},
  6    retrieve_context::run_context_retrieval,
  7};
  8use anyhow::{Context as _, Result};
  9use edit_prediction::cursor_excerpt::editable_and_context_ranges_for_cursor_position;
 10use gpui::{AppContext, AsyncApp};
 11use language::{Buffer, OffsetRangeExt, Point};
 12use similar::DiffableStr;
 13use std::sync::Arc;
 14use std::{fmt::Write as _, ops::Range};
 15use zeta_prompt::ZetaVersion;
 16use zeta_prompt::format_zeta_prompt;
 17
 18pub async fn run_format_prompt(
 19    example: &mut Example,
 20    args: &FormatPromptArgs,
 21    app_state: Arc<EpAppState>,
 22    example_progress: &ExampleProgress,
 23    cx: AsyncApp,
 24) -> Result<()> {
 25    run_context_retrieval(example, app_state.clone(), example_progress, cx.clone()).await?;
 26
 27    let step_progress = example_progress.start(Step::FormatPrompt);
 28
 29    let prompt_inputs = example
 30        .prompt_inputs
 31        .as_ref()
 32        .context("prompt_inputs must be set after context retrieval")?;
 33
 34    let language = app_state
 35        .languages
 36        .load_language_for_file_path(&example.spec.cursor_path)
 37        .await
 38        .ok();
 39    let snapshot_fut = cx.update(|cx| {
 40        Buffer::build_snapshot(
 41            prompt_inputs.content.as_str().into(),
 42            language,
 43            Some(app_state.languages.clone()),
 44            cx,
 45        )
 46    });
 47    let cursor_point = Point::new(prompt_inputs.cursor_row, prompt_inputs.cursor_column);
 48    let snapshot = cx.background_spawn(snapshot_fut).await;
 49
 50    match args.provider {
 51        PredictionProvider::Teacher(version) | PredictionProvider::TeacherNonBatching(version) => {
 52            step_progress.set_substatus("formatting teacher prompt");
 53
 54            let (editable_range, context_range) = editable_and_context_ranges_for_cursor_position(
 55                cursor_point,
 56                &snapshot,
 57                edit_prediction::zeta2::max_editable_tokens(version),
 58                edit_prediction::zeta2::MAX_CONTEXT_TOKENS,
 59            );
 60            let editable_range = editable_range.to_offset(&snapshot);
 61            let context_range = context_range.to_offset(&snapshot);
 62
 63            let prompt = TeacherPrompt::format_prompt(example, editable_range, context_range);
 64            example.prompt = Some(ExamplePrompt {
 65                input: prompt,
 66                expected_output: example
 67                    .spec
 68                    .expected_patches
 69                    .first()
 70                    .cloned()
 71                    .unwrap_or_default(),
 72                provider: args.provider,
 73            });
 74        }
 75        PredictionProvider::Zeta2(version) => {
 76            step_progress.set_substatus("formatting zeta2 prompt");
 77
 78            let (editable_range, context_range) = editable_and_context_ranges_for_cursor_position(
 79                cursor_point,
 80                &snapshot,
 81                edit_prediction::zeta2::max_editable_tokens(version),
 82                edit_prediction::zeta2::MAX_CONTEXT_TOKENS,
 83            );
 84            let editable_range = editable_range.to_offset(&snapshot);
 85            let context_range = context_range.to_offset(&snapshot);
 86
 87            let context_start = context_range.start;
 88            let cursor_offset_in_excerpt = prompt_inputs.cursor_offset - context_start;
 89            let editable_range_in_excerpt =
 90                (editable_range.start - context_start)..(editable_range.end - context_start);
 91            let input = zeta_prompt::ZetaPromptInput {
 92                cursor_path: example.spec.cursor_path.clone(),
 93                cursor_excerpt: prompt_inputs.content[context_range].to_string().into(),
 94                editable_range_in_excerpt,
 95                cursor_offset_in_excerpt,
 96                events: prompt_inputs.edit_history.clone(),
 97                related_files: prompt_inputs.related_files.clone().unwrap_or_default(),
 98            };
 99            let prompt = format_zeta_prompt(&input, version);
100            let expected_output = zeta2_output_for_patch(
101                &input,
102                &example
103                    .spec
104                    .expected_patches
105                    .first()
106                    .context("expected patches is empty")?
107                    .clone(),
108                version,
109            )?;
110            example.prompt = Some(ExamplePrompt {
111                input: prompt,
112                expected_output,
113                provider: args.provider,
114            });
115        }
116        _ => {
117            panic!("Cannot format prompt for {:?}", args.provider);
118        }
119    };
120    Ok(())
121}
122
123pub fn zeta2_output_for_patch(
124    input: &zeta_prompt::ZetaPromptInput,
125    patch: &str,
126    version: ZetaVersion,
127) -> Result<String> {
128    let mut old_editable_region =
129        input.cursor_excerpt[input.editable_range_in_excerpt.clone()].to_string();
130
131    if !old_editable_region.ends_with_newline() {
132        old_editable_region.push('\n');
133    }
134
135    let mut result = edit_prediction::udiff::apply_diff_to_string(patch, &old_editable_region)
136        .with_context(|| {
137            format!(
138                "Patch:\n```\n{}```\n\nEditable region:\n```\n{}```",
139                patch, old_editable_region
140            )
141        })?;
142
143    if version == ZetaVersion::V0120GitMergeMarkers {
144        if !result.ends_with('\n') {
145            result.push('\n');
146        }
147        result.push_str(zeta_prompt::v0120_git_merge_markers::END_MARKER);
148    }
149
150    Ok(result)
151}
152
153pub struct TeacherPrompt;
154
155impl TeacherPrompt {
156    const PROMPT: &str = include_str!("teacher.prompt.md");
157    pub(crate) const EDITABLE_REGION_START: &str = "<|editable_region_start|>\n";
158    pub(crate) const EDITABLE_REGION_END: &str = "\n<|editable_region_end|>";
159    pub(crate) const USER_CURSOR_MARKER: &str = "<|user_cursor|>";
160
161    /// Truncate edit history to this number of last lines
162    const MAX_HISTORY_LINES: usize = 128;
163
164    pub fn format_prompt(
165        example: &Example,
166        editable_range: Range<usize>,
167        context_range: Range<usize>,
168    ) -> String {
169        let edit_history = Self::format_edit_history(&example.spec.edit_history);
170        let context = Self::format_context(example);
171        let cursor_excerpt = Self::format_cursor_excerpt(example, editable_range, context_range);
172
173        let prompt = Self::PROMPT
174            .replace("{{context}}", &context)
175            .replace("{{edit_history}}", &edit_history)
176            .replace("{{cursor_excerpt}}", &cursor_excerpt);
177
178        prompt
179    }
180
181    pub fn parse(example: &Example, response: &str) -> Result<String> {
182        // Extract updated (new) editable region from the model response.
183        // The model may include editable region markers in its output, so we need to strip them.
184        let new_editable_region = extract_last_codeblock(response);
185        let mut new_editable_region = Self::extract_editable_region(&new_editable_region);
186        let old_editable_region = Self::extract_editable_region(
187            &example
188                .prompt
189                .as_ref()
190                .context("example prompt missing")?
191                .input,
192        );
193        let prompt_inputs = example
194            .prompt_inputs
195            .as_ref()
196            .context("example is missing prompt inputs")?;
197
198        // Normalize leading newlines: if old starts with newline but new doesn't,
199        // prepend newline to new to preserve whitespace structure.
200        // This handles the case where the model drops the leading blank line.
201        if old_editable_region.starts_with('\n') && !new_editable_region.starts_with('\n') {
202            new_editable_region.insert(0, '\n');
203        }
204
205        let (editable_region_offset, _) = prompt_inputs
206            .content
207            .match_indices(&old_editable_region)
208            .min_by_key(|(index, _)| index.abs_diff(prompt_inputs.cursor_offset))
209            .unwrap();
210        let editable_region_start_line = prompt_inputs.content[..editable_region_offset]
211            .matches('\n')
212            .count();
213
214        let diff = language::unified_diff_with_offsets(
215            &old_editable_region,
216            &new_editable_region,
217            editable_region_start_line as u32,
218            editable_region_start_line as u32,
219        );
220
221        let diff = indoc::formatdoc! {"
222            --- a/{path}
223            +++ b/{path}
224            {diff}",
225            path = example.spec.cursor_path.to_string_lossy(),
226            diff = diff,
227        };
228
229        Ok(diff)
230    }
231
232    fn format_edit_history(edit_history: &str) -> String {
233        // Strip comments ("garbage lines") from edit history
234        let lines = edit_history
235            .lines()
236            .filter(|&s| Self::is_udiff_content_line(s))
237            .collect::<Vec<_>>();
238
239        let history_lines = if lines.len() > Self::MAX_HISTORY_LINES {
240            &lines[lines.len() - Self::MAX_HISTORY_LINES..]
241        } else {
242            &lines
243        };
244
245        if history_lines.is_empty() {
246            return "(No edit history)".to_string();
247        }
248
249        history_lines.join("\n")
250    }
251
252    fn format_context(example: &Example) -> String {
253        let related_files = example
254            .prompt_inputs
255            .as_ref()
256            .and_then(|pi| pi.related_files.as_ref());
257
258        let Some(related_files) = related_files else {
259            return "(No context)".to_string();
260        };
261
262        if related_files.is_empty() {
263            return "(No context)".to_string();
264        }
265
266        let mut prompt = String::new();
267        for file in related_files {
268            let path_str = file.path.to_string_lossy();
269            writeln!(&mut prompt, "`````{path_str}").ok();
270
271            let mut prev_row = 0;
272            for excerpt in &file.excerpts {
273                if excerpt.row_range.start > prev_row {
274                    prompt.push_str("\n");
275                }
276                prompt.push_str(&excerpt.text);
277                prompt.push('\n');
278                prev_row = excerpt.row_range.end;
279            }
280            if prev_row < file.max_row {
281                prompt.push_str("\n");
282            }
283            prompt.push_str("\n`````\n");
284        }
285
286        prompt
287    }
288
289    fn format_cursor_excerpt(
290        example: &Example,
291        editable_range: Range<usize>,
292        context_range: Range<usize>,
293    ) -> String {
294        let mut result = String::new();
295
296        let prompt_inputs = example.prompt_inputs.as_ref().unwrap();
297
298        let path_str = example.spec.cursor_path.to_string_lossy();
299        result.push_str(&format!("`````{path_str}\n"));
300        result.push_str(&prompt_inputs.content[context_range.start..editable_range.start]);
301        result.push_str(Self::EDITABLE_REGION_START);
302        result.push_str(&prompt_inputs.content[editable_range.start..prompt_inputs.cursor_offset]);
303        result.push_str(Self::USER_CURSOR_MARKER);
304        result.push_str(&prompt_inputs.content[prompt_inputs.cursor_offset..editable_range.end]);
305        result.push_str(Self::EDITABLE_REGION_END);
306        result.push_str(&prompt_inputs.content[editable_range.end..context_range.end]);
307        result.push_str("\n`````");
308
309        result
310    }
311
312    fn extract_editable_region(text: &str) -> String {
313        let start = text
314            .rfind(Self::EDITABLE_REGION_START)
315            .map_or(0, |pos| pos + Self::EDITABLE_REGION_START.len());
316        let end = text.rfind(Self::EDITABLE_REGION_END).unwrap_or(text.len());
317
318        let region = &text[start..end];
319        let region = region.strip_suffix('\n').unwrap_or(region);
320
321        region.replace(Self::USER_CURSOR_MARKER, "")
322    }
323
324    fn is_udiff_content_line(s: &str) -> bool {
325        s.starts_with("-")
326            || s.starts_with("+")
327            || s.starts_with(" ")
328            || s.starts_with("---")
329            || s.starts_with("+++")
330            || s.starts_with("@@")
331    }
332}
333
/// Returns the contents of the last fenced code block in `text`. If `text` contains no
/// complete fenced block, all of `text` is returned.
///
/// An opening fence is a run of three or more backticks at the start of a line, optionally
/// preceded by spaces or tabs. The rest of that line (a language tag, a path, ...) is
/// skipped. The block ends at the first later line that starts with the same number of
/// backticks, so shorter backtick runs inside a longer fence are kept as content. The
/// returned text keeps the newline that ends the block's last line.
fn extract_last_codeblock(text: &str) -> String {
    let bytes = text.as_bytes();
    let mut last_block = None;
    let mut search_start = 0;

    while let Some(offset) = text[search_start..].find("```") {
        let fence_start = search_start + offset;
        let mut fence_end = fence_start;
        while fence_end < bytes.len() && bytes[fence_end] == b'`' {
            fence_end += 1;
        }

        // Backticks in the middle of a line (e.g. ```inline``` code in prose) do not open
        // a block. Skip the whole run and keep searching.
        if !is_at_line_start(text, fence_start) {
            search_start = fence_end;
            continue;
        }

        let closing_pattern = format!("\n{}", "`".repeat(fence_end - fence_start));
        // Index of the newline that ends the opening fence line (or the end of `text`).
        let info_end = text[fence_end..]
            .find('\n')
            .map_or(text.len(), |pos| fence_end + pos);

        let Some(end_pos) = text[info_end..].find(&closing_pattern) else {
            break;
        };
        last_block = Some(text[info_end + 1..info_end + end_pos + 1].to_string());
        search_start = info_end + end_pos + closing_pattern.len();
    }

    last_block.unwrap_or_else(|| text.to_string())
}

/// Returns whether only spaces or tabs precede byte `index` on its line in `text`.
fn is_at_line_start(text: &str, index: usize) -> bool {
    let line_start = text[..index].rfind('\n').map_or(0, |pos| pos + 1);
    text[line_start..index]
        .chars()
        .all(|c| c == ' ' || c == '\t')
}
365
// Unit tests for the response-parsing helpers `extract_last_codeblock` and
// `TeacherPrompt::extract_editable_region`.
#[cfg(test)]
mod tests {
    use super::*;

    // With several fenced blocks only the last one is returned, without the info string
    // (here a path) that follows its opening fence.
    #[test]
    fn test_extract_last_code_block() {
        let text = indoc::indoc! {"
            Some thinking

            ```
            first block
            ```

            `````path='something' lines=1:2
            last block
            `````
            "};
        let last_block = extract_last_codeblock(text);
        assert_eq!(last_block, "last block\n");
    }

    // Triple-backtick runs inside a five-backtick fence are content, not fences.
    #[test]
    fn test_extract_codeblock_with_nested_fences() {
        let text = indoc::indoc! {"
            `````
            content with ``` inline
            and ```python nested
            more content
            `````
            "};
        let last_block = extract_last_codeblock(text);
        assert_eq!(
            last_block,
            "content with ``` inline\nand ```python nested\nmore content\n"
        );
    }

    // Single- and triple-backtick runs inside a block's lines are kept verbatim.
    #[test]
    fn test_extract_codeblock_ignores_inline_backticks() {
        let text = indoc::indoc! {"
            `````
            here is some `code` with inline backticks
            and here```more```stuff
            `````
            "};
        let last_block = extract_last_codeblock(text);
        assert_eq!(
            last_block,
            "here is some `code` with inline backticks\nand here```more```stuff\n"
        );
    }

    // Only the text between the region markers is kept, and the blank line just before
    // the end marker does not appear in the result.
    #[test]
    fn test_extract_editable_region() {
        let text = indoc::indoc! {"
            some lines
            are
            here
            <|editable_region_start|>
            one
            two three

            <|editable_region_end|>
            more
            lines here
            "};
        let parsed = TeacherPrompt::extract_editable_region(text);
        assert_eq!(
            parsed,
            indoc::indoc! {"
            one
            two three"}
        );
    }

    // A complete ```bibtex block nested inside a five-backtick fence is returned intact.
    #[test]
    fn test_extract_last_codeblock_nested_bibtex() {
        let text = indoc::indoc! {r#"
            Looking at the edit history, I can see that a Citation section was just added.

            `````
            ## Collaborations
            Our mission is to create a 4D generative model.

            ## Citation

            If you found Unique3D helpful, please cite our report:
            ```bibtex
            @misc{wu2024unique3d,
                  title={Unique3D},
            }
            ```
            `````
            "#};
        let last_block = extract_last_codeblock(text);
        assert_eq!(
            last_block,
            indoc::indoc! {r#"
            ## Collaborations
            Our mission is to create a 4D generative model.

            ## Citation

            If you found Unique3D helpful, please cite our report:
            ```bibtex
            @misc{wu2024unique3d,
                  title={Unique3D},
            }
            ```
            "#}
        );
    }

    // Without region markers the whole input is treated as the editable region.
    #[test]
    fn test_extract_editable_region_no_markers() {
        let text = indoc::indoc! {"
            one
            two three"};
        let parsed = TeacherPrompt::extract_editable_region(text);
        assert_eq!(
            parsed,
            indoc::indoc! {"
            one
            two three"}
        );
    }

    // The <|user_cursor|> marker is removed from the extracted region.
    #[test]
    fn test_extract_editable_region_strips_cursor_marker() {
        let text = indoc::indoc! {"
            <|editable_region_start|>
            one
            <|user_cursor|>two three

            <|editable_region_end|>
            "};
        let parsed = TeacherPrompt::extract_editable_region(text);
        assert_eq!(
            parsed,
            indoc::indoc! {"
            one
            two three"}
        );
    }
}