format_prompt.rs

  1use crate::{
  2    PromptFormat,
  3    example::{Example, ExamplePrompt},
  4    headless::EpAppState,
  5    load_project::run_load_project,
  6    retrieve_context::run_context_retrieval,
  7};
  8use edit_prediction::{
  9    EditPredictionStore,
 10    zeta2::{zeta2_output_for_patch, zeta2_prompt_input},
 11};
 12use gpui::AsyncApp;
 13use std::sync::Arc;
 14use zeta_prompt::format_zeta_prompt;
 15
 16pub async fn run_format_prompt(
 17    example: &mut Example,
 18    prompt_format: PromptFormat,
 19    app_state: Arc<EpAppState>,
 20    mut cx: AsyncApp,
 21) {
 22    run_context_retrieval(example, app_state.clone(), cx.clone()).await;
 23
 24    match prompt_format {
 25        PromptFormat::Teacher => {
 26            let prompt = TeacherPrompt::format_prompt(example);
 27            example.prompt = Some(ExamplePrompt {
 28                input: prompt,
 29                expected_output: example.expected_patch.clone(), // TODO
 30                format: prompt_format,
 31            });
 32        }
 33        PromptFormat::Zeta2 => {
 34            run_load_project(example, app_state, cx.clone()).await;
 35
 36            let ep_store = cx
 37                .update(|cx| EditPredictionStore::try_global(cx).unwrap())
 38                .unwrap();
 39
 40            let state = example.state.as_ref().unwrap();
 41            let snapshot = state
 42                .buffer
 43                .read_with(&cx, |buffer, _| buffer.snapshot())
 44                .unwrap();
 45            let project = state.project.clone();
 46            let (_, input) = ep_store
 47                .update(&mut cx, |ep_store, _cx| {
 48                    zeta2_prompt_input(
 49                        &snapshot,
 50                        example.context.as_ref().unwrap().files.clone(),
 51                        ep_store.edit_history_for_project(&project),
 52                        example.cursor_path.clone(),
 53                        example.buffer.as_ref().unwrap().cursor_offset,
 54                    )
 55                })
 56                .unwrap();
 57            let prompt = format_zeta_prompt(&input);
 58            let expected_output = zeta2_output_for_patch(&input, &example.expected_patch.clone());
 59            example.prompt = Some(ExamplePrompt {
 60                input: prompt,
 61                expected_output,
 62                format: prompt_format,
 63            });
 64        }
 65    };
 66}
 67
 68pub struct TeacherPrompt;
 69
 70impl TeacherPrompt {
 71    const PROMPT: &str = include_str!("teacher.prompt.md");
 72    pub(crate) const EDITABLE_REGION_START: &str = "<|editable_region_start|>\n";
 73    pub(crate) const EDITABLE_REGION_END: &str = "<|editable_region_end|>";
 74
 75    /// Truncate edit history to this number of last lines
 76    const MAX_HISTORY_LINES: usize = 128;
 77
 78    pub fn format_prompt(example: &Example) -> String {
 79        let edit_history = Self::format_edit_history(&example.edit_history);
 80        let context = Self::format_context(example);
 81        let editable_region = Self::format_editable_region(example);
 82
 83        let prompt = Self::PROMPT
 84            .replace("{{context}}", &context)
 85            .replace("{{edit_history}}", &edit_history)
 86            .replace("{{editable_region}}", &editable_region);
 87
 88        prompt
 89    }
 90
 91    pub fn parse(example: &Example, response: &str) -> String {
 92        // Ideally, we should always be able to find cursor position in the retrieved context.
 93        // In reality, sometimes we don't find it for these reasons:
 94        // 1. `example.cursor_position` contains _more_ context than included in the retrieved context
 95        //    (can be fixed by getting cursor coordinates at the load_example stage)
 96        // 2. Context retriever just didn't include cursor line.
 97        //
 98        // In that case, fallback to using `cursor_position` as excerpt.
 99        let cursor_file = &example
100            .buffer
101            .as_ref()
102            .expect("`buffer` should be filled in in the context collection step")
103            .content;
104
105        // Extract updated (new) editable region from the model response
106        let new_editable_region = extract_last_codeblock(response);
107
108        // Reconstruct old editable region we sent to the model
109        let old_editable_region = Self::format_editable_region(example);
110        let old_editable_region = Self::extract_editable_region(&old_editable_region);
111        if !cursor_file.contains(&old_editable_region) {
112            panic!("Something's wrong: editable_region is not found in the cursor file")
113        }
114
115        // Apply editable region to a larger context and compute diff.
116        // This is needed to get a better context lines around the editable region
117        let edited_file = cursor_file.replace(&old_editable_region, &new_editable_region);
118        let diff = language::unified_diff(&cursor_file, &edited_file);
119
120        let diff = indoc::formatdoc! {"
121            --- a/{path}
122            +++ b/{path}
123            {diff}",
124            path = example.cursor_path.to_string_lossy(),
125            diff = diff,
126        };
127
128        diff
129    }
130
131    fn format_edit_history(edit_history: &str) -> String {
132        // Strip comments ("garbage lines") from edit history
133        let lines = edit_history
134            .lines()
135            .filter(|&s| Self::is_udiff_content_line(s))
136            .collect::<Vec<_>>();
137
138        let history_lines = if lines.len() > Self::MAX_HISTORY_LINES {
139            &lines[lines.len() - Self::MAX_HISTORY_LINES..]
140        } else {
141            &lines
142        };
143
144        if history_lines.is_empty() {
145            return "(No edit history)".to_string();
146        }
147
148        history_lines.join("\n")
149    }
150
151    fn format_context(example: &Example) -> String {
152        if example.context.is_none() {
153            panic!("Missing context retriever step");
154        }
155
156        let mut prompt = String::new();
157        zeta_prompt::write_related_files(&mut prompt, &example.context.as_ref().unwrap().files);
158
159        prompt
160    }
161
162    fn format_editable_region(example: &Example) -> String {
163        let mut result = String::new();
164
165        let path_str = example.cursor_path.to_string_lossy();
166        result.push_str(&format!("`````path=\"{path_str}\"\n"));
167        result.push_str(Self::EDITABLE_REGION_START);
168
169        // TODO: control number of lines around cursor
170        result.push_str(&example.cursor_position);
171        if !example.cursor_position.ends_with('\n') {
172            result.push('\n');
173        }
174
175        result.push_str(&format!("{}\n", Self::EDITABLE_REGION_END));
176        result.push_str("`````");
177
178        result
179    }
180
181    fn extract_editable_region(text: &str) -> String {
182        let start = text
183            .find(Self::EDITABLE_REGION_START)
184            .map_or(0, |pos| pos + Self::EDITABLE_REGION_START.len());
185        let end = text.find(Self::EDITABLE_REGION_END).unwrap_or(text.len());
186
187        let region = &text[start..end];
188
189        region.replace("<|user_cursor|>", "")
190    }
191
192    fn is_udiff_content_line(s: &str) -> bool {
193        s.starts_with("-")
194            || s.starts_with("+")
195            || s.starts_with(" ")
196            || s.starts_with("---")
197            || s.starts_with("+++")
198            || s.starts_with("@@")
199    }
200}
201
/// Returns the contents of the last complete fenced code block in `text`, or the whole of
/// `text` if it contains no complete fenced block.
///
/// An opening fence is a run of three or more backticks. Anything after it on the same line
/// (e.g. `path="..."`) is treated as an info string and skipped. The block ends at the first
/// later line that, ignoring leading spaces and tabs, starts with at least as many backticks as
/// the opening fence. Backticks in the middle of a content line do not close the block.
///
/// The returned contents exclude both fence lines but keep the trailing newline of the last
/// content line.
fn extract_last_codeblock(text: &str) -> String {
    let mut last_block = None;
    let mut search_start = 0;

    while let Some(offset) = text[search_start..].find("```") {
        let fence_start = search_start + offset;
        let fence_len = text[fence_start..]
            .bytes()
            .take_while(|&byte| byte == b'`')
            .count();
        let fence = &text[fence_start..fence_start + fence_len];

        // The block's contents begin on the line after the opening fence.
        let content_start = match text[fence_start..].find('\n') {
            Some(newline) => fence_start + newline + 1,
            None => break,
        };

        match find_closing_fence(text, content_start, fence) {
            Some((closing_line_start, closing_line_end)) => {
                last_block = Some(text[content_start..closing_line_start].to_string());
                // Skip the whole closing line so leftover backticks of a longer closing fence
                // are not mistaken for a new opening fence.
                search_start = closing_line_end;
            }
            // An unclosed block ends the search; the previous complete block (if any) wins.
            None => break,
        }
    }

    last_block.unwrap_or_else(|| text.to_string())
}

/// Finds the first line at or after byte offset `from` that, ignoring leading spaces and tabs,
/// starts with `fence`. Returns that line's start offset and end offset (the index of its
/// `\n`, or `text.len()` for the final line).
fn find_closing_fence(text: &str, from: usize, fence: &str) -> Option<(usize, usize)> {
    let mut line_start = from;
    loop {
        let line_end = text[line_start..]
            .find('\n')
            .map_or(text.len(), |offset| line_start + offset);
        let line = text[line_start..line_end].trim_start_matches(|c: char| c == ' ' || c == '\t');
        if line.starts_with(fence) {
            return Some((line_start, line_end));
        }
        if line_end == text.len() {
            return None;
        }
        line_start = line_end + 1;
    }
}
233
#[cfg(test)]
mod tests {
    use super::*;

    /// The last complete fenced block wins, regardless of fence length or info string.
    #[test]
    fn test_extract_last_code_block() {
        let response = indoc::indoc! {"
            Some thinking

            ```
            first block
            ```

            `````path='something' lines=1:2
            last block
            `````
            "};
        let expected = "last block\n";

        assert_eq!(extract_last_codeblock(response), expected);
    }

    /// Only the text between the region markers is returned, including its trailing blank line.
    #[test]
    fn test_extract_editable_region() {
        let prompt = indoc::indoc! {"
            some lines
            are
            here
            <|editable_region_start|>
            one
            two three

            <|editable_region_end|>
            more
            lines here
            "};
        let expected = indoc::indoc! {"
            one
            two three

            "};

        let region = TeacherPrompt::extract_editable_region(prompt);
        assert_eq!(region, expected);
    }
}