format_prompt.rs

  1use crate::{
  2    PromptFormat,
  3    example::{Example, ExamplePrompt},
  4    headless::EpAppState,
  5    load_project::run_load_project,
  6    progress::{Progress, Step},
  7    retrieve_context::run_context_retrieval,
  8};
  9use edit_prediction::{
 10    EditPredictionStore,
 11    zeta2::{zeta2_output_for_patch, zeta2_prompt_input},
 12};
 13use gpui::AsyncApp;
 14use std::sync::Arc;
 15use zeta_prompt::format_zeta_prompt;
 16
 17pub async fn run_format_prompt(
 18    example: &mut Example,
 19    prompt_format: PromptFormat,
 20    app_state: Arc<EpAppState>,
 21    mut cx: AsyncApp,
 22) {
 23    run_context_retrieval(example, app_state.clone(), cx.clone()).await;
 24
 25    let _step_progress = Progress::global().start(Step::FormatPrompt, &example.name);
 26
 27    match prompt_format {
 28        PromptFormat::Teacher => {
 29            let prompt = TeacherPrompt::format_prompt(example);
 30            example.prompt = Some(ExamplePrompt {
 31                input: prompt,
 32                expected_output: example.expected_patch.clone(), // TODO
 33                format: prompt_format,
 34            });
 35        }
 36        PromptFormat::Zeta2 => {
 37            run_load_project(example, app_state, cx.clone()).await;
 38
 39            let ep_store = cx
 40                .update(|cx| EditPredictionStore::try_global(cx).unwrap())
 41                .unwrap();
 42
 43            let state = example.state.as_ref().unwrap();
 44            let snapshot = state
 45                .buffer
 46                .read_with(&cx, |buffer, _| buffer.snapshot())
 47                .unwrap();
 48            let project = state.project.clone();
 49            let (_, input) = ep_store
 50                .update(&mut cx, |ep_store, _cx| {
 51                    zeta2_prompt_input(
 52                        &snapshot,
 53                        example.context.as_ref().unwrap().files.clone(),
 54                        ep_store.edit_history_for_project(&project),
 55                        example.cursor_path.clone(),
 56                        example.buffer.as_ref().unwrap().cursor_offset,
 57                    )
 58                })
 59                .unwrap();
 60            let prompt = format_zeta_prompt(&input);
 61            let expected_output = zeta2_output_for_patch(&input, &example.expected_patch.clone());
 62            example.prompt = Some(ExamplePrompt {
 63                input: prompt,
 64                expected_output,
 65                format: prompt_format,
 66            });
 67        }
 68    };
 69}
 70
 71pub struct TeacherPrompt;
 72
 73impl TeacherPrompt {
 74    const PROMPT: &str = include_str!("teacher.prompt.md");
 75    pub(crate) const EDITABLE_REGION_START: &str = "<|editable_region_start|>\n";
 76    pub(crate) const EDITABLE_REGION_END: &str = "<|editable_region_end|>";
 77
 78    /// Truncate edit history to this number of last lines
 79    const MAX_HISTORY_LINES: usize = 128;
 80
 81    pub fn format_prompt(example: &Example) -> String {
 82        let edit_history = Self::format_edit_history(&example.edit_history);
 83        let context = Self::format_context(example);
 84        let editable_region = Self::format_editable_region(example);
 85
 86        let prompt = Self::PROMPT
 87            .replace("{{context}}", &context)
 88            .replace("{{edit_history}}", &edit_history)
 89            .replace("{{editable_region}}", &editable_region);
 90
 91        prompt
 92    }
 93
 94    pub fn parse(example: &Example, response: &str) -> String {
 95        // Ideally, we should always be able to find cursor position in the retrieved context.
 96        // In reality, sometimes we don't find it for these reasons:
 97        // 1. `example.cursor_position` contains _more_ context than included in the retrieved context
 98        //    (can be fixed by getting cursor coordinates at the load_example stage)
 99        // 2. Context retriever just didn't include cursor line.
100        //
101        // In that case, fallback to using `cursor_position` as excerpt.
102        let cursor_file = &example
103            .buffer
104            .as_ref()
105            .expect("`buffer` should be filled in in the context collection step")
106            .content;
107
108        // Extract updated (new) editable region from the model response
109        let new_editable_region = extract_last_codeblock(response);
110
111        // Reconstruct old editable region we sent to the model
112        let old_editable_region = Self::format_editable_region(example);
113        let old_editable_region = Self::extract_editable_region(&old_editable_region);
114        if !cursor_file.contains(&old_editable_region) {
115            panic!("Something's wrong: editable_region is not found in the cursor file")
116        }
117
118        // Apply editable region to a larger context and compute diff.
119        // This is needed to get a better context lines around the editable region
120        let edited_file = cursor_file.replace(&old_editable_region, &new_editable_region);
121        let diff = language::unified_diff(&cursor_file, &edited_file);
122
123        let diff = indoc::formatdoc! {"
124            --- a/{path}
125            +++ b/{path}
126            {diff}",
127            path = example.cursor_path.to_string_lossy(),
128            diff = diff,
129        };
130
131        diff
132    }
133
134    fn format_edit_history(edit_history: &str) -> String {
135        // Strip comments ("garbage lines") from edit history
136        let lines = edit_history
137            .lines()
138            .filter(|&s| Self::is_udiff_content_line(s))
139            .collect::<Vec<_>>();
140
141        let history_lines = if lines.len() > Self::MAX_HISTORY_LINES {
142            &lines[lines.len() - Self::MAX_HISTORY_LINES..]
143        } else {
144            &lines
145        };
146
147        if history_lines.is_empty() {
148            return "(No edit history)".to_string();
149        }
150
151        history_lines.join("\n")
152    }
153
154    fn format_context(example: &Example) -> String {
155        if example.context.is_none() {
156            panic!("Missing context retriever step");
157        }
158
159        let mut prompt = String::new();
160        zeta_prompt::write_related_files(&mut prompt, &example.context.as_ref().unwrap().files);
161
162        prompt
163    }
164
165    fn format_editable_region(example: &Example) -> String {
166        let mut result = String::new();
167
168        let path_str = example.cursor_path.to_string_lossy();
169        result.push_str(&format!("`````path=\"{path_str}\"\n"));
170        result.push_str(Self::EDITABLE_REGION_START);
171
172        // TODO: control number of lines around cursor
173        result.push_str(&example.cursor_position);
174        if !example.cursor_position.ends_with('\n') {
175            result.push('\n');
176        }
177
178        result.push_str(&format!("{}\n", Self::EDITABLE_REGION_END));
179        result.push_str("`````");
180
181        result
182    }
183
184    fn extract_editable_region(text: &str) -> String {
185        let start = text
186            .find(Self::EDITABLE_REGION_START)
187            .map_or(0, |pos| pos + Self::EDITABLE_REGION_START.len());
188        let end = text.find(Self::EDITABLE_REGION_END).unwrap_or(text.len());
189
190        let region = &text[start..end];
191
192        region.replace("<|user_cursor|>", "")
193    }
194
195    fn is_udiff_content_line(s: &str) -> bool {
196        s.starts_with("-")
197            || s.starts_with("+")
198            || s.starts_with(" ")
199            || s.starts_with("---")
200            || s.starts_with("+++")
201            || s.starts_with("@@")
202    }
203}
204
/// Returns the contents of the last closed fenced code block in `text`.
///
/// A fence is a run of three or more backticks. Everything after the opening run up to the
/// end of that line (e.g. an info string) is skipped, and the block ends at the next run of
/// at least as many backticks. Scanning stops at the first fence that is never closed.
/// When no closed block exists, the whole of `text` is returned.
fn extract_last_codeblock(text: &str) -> String {
    let bytes = text.as_bytes();
    let mut latest: Option<&str> = None;
    let mut cursor = 0;

    while let Some(offset) = text[cursor..].find("```") {
        let fence_start = cursor + offset;

        // Measure the full backtick run; the closing fence must repeat it.
        let fence_len = bytes[fence_start..]
            .iter()
            .take_while(|&&byte| byte == b'`')
            .count();
        let fence = &text[fence_start..fence_start + fence_len];

        // Skip the rest of the opening line (info string), stopping on its newline.
        let after_fence = fence_start + fence_len;
        let header_end = bytes[after_fence..]
            .iter()
            .position(|&byte| byte == b'\n')
            .map_or(bytes.len(), |pos| after_fence + pos);

        let Some(close_offset) = text[header_end..].find(fence) else {
            break;
        };

        // `header_end` points at the newline ending the opening line, so the body starts
        // one byte later.
        latest = Some(&text[header_end + 1..header_end + close_offset]);
        cursor = header_end + close_offset + fence_len;
    }

    latest.map_or_else(|| text.to_owned(), str::to_owned)
}
236
#[cfg(test)]
mod tests {
    use super::*;

    /// With several fenced blocks in a response, only the final one is returned, even
    /// when its fence is longer and carries an info string.
    #[test]
    fn test_extract_last_code_block() {
        let response = concat!(
            "Some thinking\n",
            "\n",
            "```\n",
            "first block\n",
            "```\n",
            "\n",
            "`````path='something' lines=1:2\n",
            "last block\n",
            "`````\n",
        );

        assert_eq!(extract_last_codeblock(response), "last block\n");
    }

    /// Only the text strictly between the region markers is kept, including its blank line.
    #[test]
    fn test_extract_editable_region() {
        let prompt = concat!(
            "some lines\n",
            "are\n",
            "here\n",
            "<|editable_region_start|>\n",
            "one\n",
            "two three\n",
            "\n",
            "<|editable_region_end|>\n",
            "more\n",
            "lines here\n",
        );

        let region = TeacherPrompt::extract_editable_region(prompt);

        assert_eq!(region, "one\ntwo three\n\n");
    }
}