format_prompt.rs

  1use crate::{
  2    FormatPromptArgs, PredictionProvider,
  3    example::{Example, ExamplePrompt},
  4    headless::EpAppState,
  5    progress::{Progress, Step},
  6    retrieve_context::run_context_retrieval,
  7};
  8use anyhow::{Context as _, Result};
  9use edit_prediction::cursor_excerpt::editable_and_context_ranges_for_cursor_position;
 10use gpui::{AppContext, AsyncApp};
 11use language::{Buffer, OffsetRangeExt, Point};
 12use similar::DiffableStr;
 13use std::sync::Arc;
 14use std::{fmt::Write as _, ops::Range};
 15use zeta_prompt::format_zeta_prompt;
 16
 17pub async fn run_format_prompt(
 18    example: &mut Example,
 19    args: &FormatPromptArgs,
 20    app_state: Arc<EpAppState>,
 21    cx: AsyncApp,
 22) -> Result<()> {
 23    run_context_retrieval(example, app_state.clone(), cx.clone()).await?;
 24
 25    let step_progress = Progress::global().start(Step::FormatPrompt, &example.spec.name);
 26
 27    let prompt_inputs = example
 28        .prompt_inputs
 29        .as_ref()
 30        .context("prompt_inputs must be set after context retrieval")?;
 31
 32    let language = app_state
 33        .languages
 34        .load_language_for_file_path(&example.spec.cursor_path)
 35        .await
 36        .ok();
 37    let snapshot_fut = cx.update(|cx| {
 38        Buffer::build_snapshot(
 39            prompt_inputs.content.as_str().into(),
 40            language,
 41            Some(app_state.languages.clone()),
 42            cx,
 43        )
 44    });
 45    let cursor_point = Point::new(prompt_inputs.cursor_row, prompt_inputs.cursor_column);
 46    let snapshot = cx.background_spawn(snapshot_fut).await;
 47
 48    let (editable_range, context_range) = editable_and_context_ranges_for_cursor_position(
 49        cursor_point,
 50        &snapshot,
 51        edit_prediction::zeta2::MAX_EDITABLE_TOKENS,
 52        edit_prediction::zeta2::MAX_CONTEXT_TOKENS,
 53    );
 54    let editable_range = editable_range.to_offset(&snapshot);
 55    let context_range = context_range.to_offset(&snapshot);
 56
 57    match args.provider {
 58        PredictionProvider::Teacher | PredictionProvider::TeacherNonBatching => {
 59            step_progress.set_substatus("formatting teacher prompt");
 60            let prompt = TeacherPrompt::format_prompt(example, editable_range, context_range);
 61            example.prompt = Some(ExamplePrompt {
 62                input: prompt,
 63                expected_output: example
 64                    .spec
 65                    .expected_patches
 66                    .first()
 67                    .cloned()
 68                    .unwrap_or_default(),
 69                provider: args.provider,
 70            });
 71        }
 72        PredictionProvider::Zeta2(version) => {
 73            step_progress.set_substatus("formatting zeta2 prompt");
 74
 75            let context_start = context_range.start;
 76            let cursor_offset_in_excerpt = prompt_inputs.cursor_offset - context_start;
 77            let editable_range_in_excerpt =
 78                (editable_range.start - context_start)..(editable_range.end - context_start);
 79            let input = zeta_prompt::ZetaPromptInput {
 80                cursor_path: example.spec.cursor_path.clone(),
 81                cursor_excerpt: prompt_inputs.content[context_range].to_string().into(),
 82                editable_range_in_excerpt,
 83                cursor_offset_in_excerpt,
 84                events: prompt_inputs.edit_history.clone(),
 85                related_files: prompt_inputs.related_files.clone().unwrap_or_default(),
 86            };
 87            let prompt = format_zeta_prompt(&input, version);
 88            let expected_output = zeta2_output_for_patch(
 89                &input,
 90                &example
 91                    .spec
 92                    .expected_patches
 93                    .first()
 94                    .context("expected patches is empty")?
 95                    .clone(),
 96            )?;
 97            example.prompt = Some(ExamplePrompt {
 98                input: prompt,
 99                expected_output,
100                provider: args.provider,
101            });
102        }
103        _ => {
104            panic!("Cannot format prompt for {:?}", args.provider);
105        }
106    };
107    Ok(())
108}
109
110pub fn zeta2_output_for_patch(input: &zeta_prompt::ZetaPromptInput, patch: &str) -> Result<String> {
111    let mut old_editable_region =
112        input.cursor_excerpt[input.editable_range_in_excerpt.clone()].to_string();
113
114    if !old_editable_region.ends_with_newline() {
115        old_editable_region.push('\n');
116    }
117
118    edit_prediction::udiff::apply_diff_to_string(patch, &old_editable_region).with_context(|| {
119        format!(
120            "Patch:\n```\n{}```\n\nEditable region:\n```\n{}```",
121            patch, old_editable_region
122        )
123    })
124}
125
126pub struct TeacherPrompt;
127
128impl TeacherPrompt {
129    const PROMPT: &str = include_str!("teacher.prompt.md");
130    pub(crate) const EDITABLE_REGION_START: &str = "<|editable_region_start|>\n";
131    pub(crate) const EDITABLE_REGION_END: &str = "\n<|editable_region_end|>";
132    pub(crate) const USER_CURSOR_MARKER: &str = "<|user_cursor|>";
133
134    /// Truncate edit history to this number of last lines
135    const MAX_HISTORY_LINES: usize = 128;
136
137    pub fn format_prompt(
138        example: &Example,
139        editable_range: Range<usize>,
140        context_range: Range<usize>,
141    ) -> String {
142        let edit_history = Self::format_edit_history(&example.spec.edit_history);
143        let context = Self::format_context(example);
144        let cursor_excerpt = Self::format_cursor_excerpt(example, editable_range, context_range);
145
146        let prompt = Self::PROMPT
147            .replace("{{context}}", &context)
148            .replace("{{edit_history}}", &edit_history)
149            .replace("{{cursor_excerpt}}", &cursor_excerpt);
150
151        prompt
152    }
153
154    pub fn parse(example: &Example, response: &str) -> Result<String> {
155        // Extract updated (new) editable region from the model response.
156        // The model may include editable region markers in its output, so we need to strip them.
157        let new_editable_region = extract_last_codeblock(response);
158        let mut new_editable_region = Self::extract_editable_region(&new_editable_region);
159        let old_editable_region = Self::extract_editable_region(
160            &example
161                .prompt
162                .as_ref()
163                .context("example prompt missing")?
164                .input,
165        );
166        let prompt_inputs = example
167            .prompt_inputs
168            .as_ref()
169            .context("example is missing prompt inputs")?;
170
171        // Normalize leading newlines: if old starts with newline but new doesn't,
172        // prepend newline to new to preserve whitespace structure.
173        // This handles the case where the model drops the leading blank line.
174        if old_editable_region.starts_with('\n') && !new_editable_region.starts_with('\n') {
175            new_editable_region.insert(0, '\n');
176        }
177
178        let (editable_region_offset, _) = prompt_inputs
179            .content
180            .match_indices(&old_editable_region)
181            .min_by_key(|(index, _)| index.abs_diff(prompt_inputs.cursor_offset))
182            .unwrap();
183        let editable_region_start_line = prompt_inputs.content[..editable_region_offset]
184            .matches('\n')
185            .count();
186
187        let diff = language::unified_diff_with_offsets(
188            &old_editable_region,
189            &new_editable_region,
190            editable_region_start_line as u32,
191            editable_region_start_line as u32,
192        );
193
194        let diff = indoc::formatdoc! {"
195            --- a/{path}
196            +++ b/{path}
197            {diff}",
198            path = example.spec.cursor_path.to_string_lossy(),
199            diff = diff,
200        };
201
202        Ok(diff)
203    }
204
205    fn format_edit_history(edit_history: &str) -> String {
206        // Strip comments ("garbage lines") from edit history
207        let lines = edit_history
208            .lines()
209            .filter(|&s| Self::is_udiff_content_line(s))
210            .collect::<Vec<_>>();
211
212        let history_lines = if lines.len() > Self::MAX_HISTORY_LINES {
213            &lines[lines.len() - Self::MAX_HISTORY_LINES..]
214        } else {
215            &lines
216        };
217
218        if history_lines.is_empty() {
219            return "(No edit history)".to_string();
220        }
221
222        history_lines.join("\n")
223    }
224
225    fn format_context(example: &Example) -> String {
226        let related_files = example
227            .prompt_inputs
228            .as_ref()
229            .and_then(|pi| pi.related_files.as_ref());
230
231        let Some(related_files) = related_files else {
232            return "(No context)".to_string();
233        };
234
235        if related_files.is_empty() {
236            return "(No context)".to_string();
237        }
238
239        let mut prompt = String::new();
240        for file in related_files {
241            let path_str = file.path.to_string_lossy();
242            writeln!(&mut prompt, "`````{path_str}").ok();
243            let mut prev_row = 0;
244            for excerpt in &file.excerpts {
245                if excerpt.row_range.start > prev_row {
246                    prompt.push_str("\n");
247                }
248                prompt.push_str(&excerpt.text);
249                prompt.push('\n');
250                prev_row = excerpt.row_range.end;
251            }
252            if prev_row < file.max_row {
253                prompt.push_str("\n");
254            }
255            prompt.push_str("\n`````");
256        }
257
258        prompt
259    }
260
261    fn format_cursor_excerpt(
262        example: &Example,
263        editable_range: Range<usize>,
264        context_range: Range<usize>,
265    ) -> String {
266        let mut result = String::new();
267
268        let prompt_inputs = example.prompt_inputs.as_ref().unwrap();
269
270        let path_str = example.spec.cursor_path.to_string_lossy();
271        result.push_str(&format!("`````{path_str}\n"));
272        result.push_str(&prompt_inputs.content[context_range.start..editable_range.start]);
273        result.push_str(Self::EDITABLE_REGION_START);
274        result.push_str(&prompt_inputs.content[editable_range.start..prompt_inputs.cursor_offset]);
275        result.push_str(Self::USER_CURSOR_MARKER);
276        result.push_str(&prompt_inputs.content[prompt_inputs.cursor_offset..editable_range.end]);
277        result.push_str(Self::EDITABLE_REGION_END);
278        result.push_str(&prompt_inputs.content[editable_range.end..context_range.end]);
279        result.push_str("\n`````");
280
281        result
282    }
283
284    fn extract_editable_region(text: &str) -> String {
285        let start = text
286            .find(Self::EDITABLE_REGION_START)
287            .map_or(0, |pos| pos + Self::EDITABLE_REGION_START.len());
288        let end = text.find(Self::EDITABLE_REGION_END).unwrap_or(text.len());
289
290        let region = &text[start..end];
291        let region = region.strip_suffix('\n').unwrap_or(region);
292
293        region.replace(Self::USER_CURSOR_MARKER, "")
294    }
295
296    fn is_udiff_content_line(s: &str) -> bool {
297        s.starts_with("-")
298            || s.starts_with("+")
299            || s.starts_with(" ")
300            || s.starts_with("---")
301            || s.starts_with("+++")
302            || s.starts_with("@@")
303    }
304}
305
/// Returns the contents of the last fenced code block in `text`, or all of `text` when no
/// complete block is found.
///
/// A fence is a run of three or more backticks; the rest of the fence line is skipped. The
/// block ends at the first later line that starts with at least as many backticks, so
/// shorter fences nested inside the block are kept as content. The returned text includes
/// the newline that precedes the closing fence.
///
/// Scanning stops at the first fence that has no closing counterpart.
/// NOTE(review): opening fences are not required to start a line, so a run of three
/// backticks inside prose is treated as a fence as well.
fn extract_last_codeblock(text: &str) -> String {
    let mut latest: Option<&str> = None;
    let mut cursor = 0;

    while let Some(offset) = text[cursor..].find("```") {
        let fence_start = cursor + offset;
        let fence_len = text[fence_start..]
            .bytes()
            .take_while(|&byte| byte == b'`')
            .count();
        let fence_end = fence_start + fence_len;

        // The closing fence must repeat the opening backticks at the start of a line.
        let closing_fence = format!("\n{}", &text[fence_start..fence_end]);

        // Skip whatever follows the backticks on the opening line.
        let header_end = text[fence_end..]
            .find('\n')
            .map_or(text.len(), |newline| fence_end + newline);

        let Some(body_len) = text[header_end..].find(&closing_fence) else {
            break;
        };

        latest = Some(&text[header_end + 1..header_end + body_len + 1]);
        cursor = header_end + body_len + closing_fence.len();
    }

    latest.map_or_else(|| text.to_string(), str::to_string)
}
337
#[cfg(test)]
mod tests {
    use super::*;

    /// With several fenced blocks in a response, only the final one is returned.
    #[test]
    fn test_extract_last_code_block() {
        let response = indoc::indoc! {"
            Some thinking

            ```
            first block
            ```

            `````path='something' lines=1:2
            last block
            `````
            "};
        assert_eq!(extract_last_codeblock(response), "last block\n");
    }

    /// Shorter fences inside a five-backtick block are part of the block's content.
    #[test]
    fn test_extract_codeblock_with_nested_fences() {
        let response = indoc::indoc! {"
            `````
            content with ``` inline
            and ```python nested
            more content
            `````
            "};
        let expected = "content with ``` inline\nand ```python nested\nmore content\n";
        assert_eq!(extract_last_codeblock(response), expected);
    }

    /// Backticks in the middle of a content line do not terminate the block.
    #[test]
    fn test_extract_codeblock_ignores_inline_backticks() {
        let response = indoc::indoc! {"
            `````
            here is some `code` with inline backticks
            and here```more```stuff
            `````
            "};
        let expected = "here is some `code` with inline backticks\nand here```more```stuff\n";
        assert_eq!(extract_last_codeblock(response), expected);
    }

    /// Only the text between the region markers is kept, minus one trailing newline.
    #[test]
    fn test_extract_editable_region() {
        let excerpt = indoc::indoc! {"
            some lines
            are
            here
            <|editable_region_start|>
            one
            two three

            <|editable_region_end|>
            more
            lines here
            "};
        let region = TeacherPrompt::extract_editable_region(excerpt);
        assert_eq!(region, "one\ntwo three");
    }

    /// A three-backtick block nested in a five-backtick block is returned verbatim.
    #[test]
    fn test_extract_last_codeblock_nested_bibtex() {
        let response = indoc::indoc! {r#"
            Looking at the edit history, I can see that a Citation section was just added.

            `````
            ## Collaborations
            Our mission is to create a 4D generative model.

            ## Citation

            If you found Unique3D helpful, please cite our report:
            ```bibtex
            @misc{wu2024unique3d,
                  title={Unique3D},
            }
            ```
            `````
            "#};
        let expected = indoc::indoc! {r#"
            ## Collaborations
            Our mission is to create a 4D generative model.

            ## Citation

            If you found Unique3D helpful, please cite our report:
            ```bibtex
            @misc{wu2024unique3d,
                  title={Unique3D},
            }
            ```
            "#};
        assert_eq!(extract_last_codeblock(response), expected);
    }

    /// Text without any markers is returned unchanged.
    #[test]
    fn test_extract_editable_region_no_markers() {
        let excerpt = "one\ntwo three";
        let region = TeacherPrompt::extract_editable_region(excerpt);
        assert_eq!(region, "one\ntwo three");
    }

    /// The cursor marker is removed from the extracted region.
    #[test]
    fn test_extract_editable_region_strips_cursor_marker() {
        let excerpt = indoc::indoc! {"
            <|editable_region_start|>
            one
            <|user_cursor|>two three

            <|editable_region_end|>
            "};
        let region = TeacherPrompt::extract_editable_region(excerpt);
        assert_eq!(region, "one\ntwo three");
    }
}
482}