format_prompt.rs

  1use crate::{
  2    PromptFormat,
  3    example::{Example, ExamplePrompt},
  4    headless::EpAppState,
  5    load_project::run_load_project,
  6    progress::{Progress, Step},
  7    retrieve_context::run_context_retrieval,
  8};
  9use anyhow::{Context as _, Result, ensure};
 10use edit_prediction::{
 11    EditPredictionStore,
 12    zeta2::{zeta2_output_for_patch, zeta2_prompt_input},
 13};
 14use gpui::AsyncApp;
 15use std::sync::Arc;
 16use zeta_prompt::format_zeta_prompt;
 17
 18pub async fn run_format_prompt(
 19    example: &mut Example,
 20    prompt_format: PromptFormat,
 21    app_state: Arc<EpAppState>,
 22    mut cx: AsyncApp,
 23) -> Result<()> {
 24    run_context_retrieval(example, app_state.clone(), cx.clone()).await?;
 25
 26    let _step_progress = Progress::global().start(Step::FormatPrompt, &example.spec.name);
 27
 28    match prompt_format {
 29        PromptFormat::Teacher => {
 30            let prompt = TeacherPrompt::format_prompt(example);
 31            example.prompt = Some(ExamplePrompt {
 32                input: prompt,
 33                expected_output: example.spec.expected_patch.clone(), // TODO
 34                format: prompt_format,
 35            });
 36        }
 37        PromptFormat::Zeta2 => {
 38            run_load_project(example, app_state, cx.clone()).await?;
 39
 40            let ep_store = cx.update(|cx| {
 41                EditPredictionStore::try_global(cx).context("EditPredictionStore not initialized")
 42            })??;
 43
 44            let state = example.state.as_ref().context("state must be set")?;
 45            let snapshot = state.buffer.read_with(&cx, |buffer, _| buffer.snapshot())?;
 46            let project = state.project.clone();
 47            let (_, input) = ep_store.update(&mut cx, |ep_store, cx| {
 48                anyhow::Ok(zeta2_prompt_input(
 49                    &snapshot,
 50                    example
 51                        .context
 52                        .as_ref()
 53                        .context("context must be set")?
 54                        .files
 55                        .clone(),
 56                    ep_store.edit_history_for_project(&project, cx),
 57                    example.spec.cursor_path.clone(),
 58                    example
 59                        .buffer
 60                        .as_ref()
 61                        .context("buffer must be set")?
 62                        .cursor_offset,
 63                ))
 64            })??;
 65            let prompt = format_zeta_prompt(&input);
 66            let expected_output =
 67                zeta2_output_for_patch(&input, &example.spec.expected_patch.clone())?;
 68            example.prompt = Some(ExamplePrompt {
 69                input: prompt,
 70                expected_output,
 71                format: prompt_format,
 72            });
 73        }
 74    };
 75    Ok(())
 76}
 77
 78pub struct TeacherPrompt;
 79
 80impl TeacherPrompt {
 81    const PROMPT: &str = include_str!("teacher.prompt.md");
 82    pub(crate) const EDITABLE_REGION_START: &str = "<|editable_region_start|>\n";
 83    pub(crate) const EDITABLE_REGION_END: &str = "<|editable_region_end|>";
 84
 85    /// Truncate edit history to this number of last lines
 86    const MAX_HISTORY_LINES: usize = 128;
 87
 88    pub fn format_prompt(example: &Example) -> String {
 89        let edit_history = Self::format_edit_history(&example.spec.edit_history);
 90        let context = Self::format_context(example);
 91        let editable_region = Self::format_editable_region(example);
 92
 93        let prompt = Self::PROMPT
 94            .replace("{{context}}", &context)
 95            .replace("{{edit_history}}", &edit_history)
 96            .replace("{{editable_region}}", &editable_region);
 97
 98        prompt
 99    }
100
101    pub fn parse(example: &Example, response: &str) -> Result<String> {
102        // Ideally, we should always be able to find cursor position in the retrieved context.
103        // In reality, sometimes we don't find it for these reasons:
104        // 1. `example.cursor_position` contains _more_ context than included in the retrieved context
105        //    (can be fixed by getting cursor coordinates at the load_example stage)
106        // 2. Context retriever just didn't include cursor line.
107        //
108        // In that case, fallback to using `cursor_position` as excerpt.
109        let cursor_file = &example
110            .buffer
111            .as_ref()
112            .context("`buffer` should be filled in in the context collection step")?
113            .content;
114
115        // Extract updated (new) editable region from the model response
116        let new_editable_region = extract_last_codeblock(response);
117
118        // Reconstruct old editable region we sent to the model
119        let old_editable_region = Self::format_editable_region(example);
120        let old_editable_region = Self::extract_editable_region(&old_editable_region);
121        ensure!(
122            cursor_file.contains(&old_editable_region),
123            "Something's wrong: editable_region is not found in the cursor file"
124        );
125
126        // Apply editable region to a larger context and compute diff.
127        // This is needed to get a better context lines around the editable region
128        let edited_file = cursor_file.replace(&old_editable_region, &new_editable_region);
129        let diff = language::unified_diff(&cursor_file, &edited_file);
130
131        let diff = indoc::formatdoc! {"
132            --- a/{path}
133            +++ b/{path}
134            {diff}",
135            path = example.spec.cursor_path.to_string_lossy(),
136            diff = diff,
137        };
138
139        Ok(diff)
140    }
141
142    fn format_edit_history(edit_history: &str) -> String {
143        // Strip comments ("garbage lines") from edit history
144        let lines = edit_history
145            .lines()
146            .filter(|&s| Self::is_udiff_content_line(s))
147            .collect::<Vec<_>>();
148
149        let history_lines = if lines.len() > Self::MAX_HISTORY_LINES {
150            &lines[lines.len() - Self::MAX_HISTORY_LINES..]
151        } else {
152            &lines
153        };
154
155        if history_lines.is_empty() {
156            return "(No edit history)".to_string();
157        }
158
159        history_lines.join("\n")
160    }
161
162    fn format_context(example: &Example) -> String {
163        assert!(example.context.is_some(), "Missing context retriever step");
164
165        let mut prompt = String::new();
166        zeta_prompt::write_related_files(&mut prompt, &example.context.as_ref().unwrap().files);
167
168        prompt
169    }
170
171    fn format_editable_region(example: &Example) -> String {
172        let mut result = String::new();
173
174        let path_str = example.spec.cursor_path.to_string_lossy();
175        result.push_str(&format!("`````path=\"{path_str}\"\n"));
176        result.push_str(Self::EDITABLE_REGION_START);
177
178        // TODO: control number of lines around cursor
179        result.push_str(&example.spec.cursor_position);
180        if !example.spec.cursor_position.ends_with('\n') {
181            result.push('\n');
182        }
183
184        result.push_str(&format!("{}\n", Self::EDITABLE_REGION_END));
185        result.push_str("`````");
186
187        result
188    }
189
190    fn extract_editable_region(text: &str) -> String {
191        let start = text
192            .find(Self::EDITABLE_REGION_START)
193            .map_or(0, |pos| pos + Self::EDITABLE_REGION_START.len());
194        let end = text.find(Self::EDITABLE_REGION_END).unwrap_or(text.len());
195
196        let region = &text[start..end];
197
198        region.replace("<|user_cursor|>", "")
199    }
200
201    fn is_udiff_content_line(s: &str) -> bool {
202        s.starts_with("-")
203            || s.starts_with("+")
204            || s.starts_with(" ")
205            || s.starts_with("---")
206            || s.starts_with("+++")
207            || s.starts_with("@@")
208    }
209}
210
/// Returns the contents of the last complete fenced code block in `text`.
///
/// An opening fence is a run of three or more backticks; anything after it on
/// the same line (e.g. `path="..."`) is skipped. The block ends at the next
/// occurrence of the same number of backticks. The returned contents start on
/// the line after the opening fence and include the newline before the closing
/// fence. If no complete block exists, the whole `text` is returned.
fn extract_last_codeblock(text: &str) -> String {
    let bytes = text.as_bytes();
    let mut cursor = 0;
    let mut found: Option<&str> = None;

    while let Some(offset) = text[cursor..].find("```") {
        let fence_start = cursor + offset;

        // Length of the whole backtick run that opens this block.
        let fence_len = bytes[fence_start..]
            .iter()
            .take_while(|&&b| b == b'`')
            .count();
        let fence_end = fence_start + fence_len;

        // Skip the info string: stop at the newline (or the end of the text).
        let line_end = bytes[fence_end..]
            .iter()
            .position(|&b| b == b'\n')
            .map_or(bytes.len(), |p| fence_end + p);

        let closing = "`".repeat(fence_len);
        match text[line_end..].find(&closing) {
            Some(close_offset) => {
                found = Some(&text[line_end + 1..line_end + close_offset]);
                cursor = line_end + close_offset + fence_len;
            }
            None => break,
        }
    }

    found.map_or_else(|| text.to_string(), |block| block.to_string())
}
242
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_extract_last_code_block() {
        let text = indoc::indoc! {"
            Some thinking

            ```
            first block
            ```

            `````path='something' lines=1:2
            last block
            `````
            "};
        let last_block = extract_last_codeblock(text);
        assert_eq!(last_block, "last block\n");
    }

    #[test]
    fn test_extract_last_code_block_without_fences_returns_input() {
        let text = "no code here\n";
        assert_eq!(extract_last_codeblock(text), text);
    }

    #[test]
    fn test_extract_last_code_block_ignores_unclosed_trailing_fence() {
        let text = indoc::indoc! {"
            ```
            complete
            ```
            ```
            never closed
            "};
        assert_eq!(extract_last_codeblock(text), "complete\n");
    }

    #[test]
    fn test_extract_last_code_block_keeps_shorter_fences_inside_longer_block() {
        let text = indoc::indoc! {"
            `````path='src/lib.rs'
            /// ```
            /// example
            /// ```
            fn f() {}
            `````
            "};
        assert_eq!(
            extract_last_codeblock(text),
            "/// ```\n/// example\n/// ```\nfn f() {}\n"
        );
    }

    #[test]
    fn test_extract_editable_region() {
        let text = indoc::indoc! {"
            some lines
            are
            here
            <|editable_region_start|>
            one
            two three

            <|editable_region_end|>
            more
            lines here
            "};
        let parsed = TeacherPrompt::extract_editable_region(text);
        assert_eq!(
            parsed,
            indoc::indoc! {"
            one
            two three

            "}
        );
    }

    #[test]
    fn test_extract_editable_region_strips_cursor_inside_region() {
        let text = "before\n<|editable_region_start|>\nfoo<|user_cursor|>bar\n<|editable_region_end|>\nafter\n";
        assert_eq!(TeacherPrompt::extract_editable_region(text), "foobar\n");
    }

    #[test]
    fn test_extract_editable_region_without_markers_returns_whole_text() {
        let parsed = TeacherPrompt::extract_editable_region("let x<|user_cursor|> = 1;\n");
        assert_eq!(parsed, "let x = 1;\n");
    }
}