format_prompt.rs

  1use crate::{
  2    PromptFormat,
  3    example::{Example, ExamplePrompt},
  4    headless::EpAppState,
  5    retrieve_context::run_context_retrieval,
  6};
  7use edit_prediction::{EditPredictionStore, zeta2::zeta2_prompt_input};
  8use gpui::AsyncApp;
  9use std::sync::Arc;
 10use zeta_prompt::format_zeta_prompt;
 11
 12pub async fn run_format_prompt(
 13    example: &mut Example,
 14    prompt_format: PromptFormat,
 15    app_state: Arc<EpAppState>,
 16    mut cx: AsyncApp,
 17) {
 18    run_context_retrieval(example, app_state, cx.clone()).await;
 19
 20    let prompt = match prompt_format {
 21        PromptFormat::Teacher => TeacherPrompt::format(example),
 22        PromptFormat::Zeta2 => {
 23            let ep_store = cx
 24                .update(|cx| EditPredictionStore::try_global(cx).unwrap())
 25                .unwrap();
 26
 27            let state = example.state.as_ref().unwrap();
 28            let snapshot = state
 29                .buffer
 30                .read_with(&cx, |buffer, _| buffer.snapshot())
 31                .unwrap();
 32            let project = state.project.clone();
 33            let (_, input) = ep_store
 34                .update(&mut cx, |ep_store, _cx| {
 35                    zeta2_prompt_input(
 36                        &snapshot,
 37                        example.context.as_ref().unwrap().files.clone(),
 38                        ep_store.edit_history_for_project(&project),
 39                        example.cursor_path.clone(),
 40                        example.buffer.as_ref().unwrap().cursor_offset,
 41                    )
 42                })
 43                .unwrap();
 44            format_zeta_prompt(&input)
 45        }
 46    };
 47
 48    example.prompt = Some(ExamplePrompt {
 49        input: prompt,
 50        expected_output: example.expected_patch.clone(), // TODO
 51        format: prompt_format,
 52    });
 53}
 54
 55pub trait PromptFormatter {
 56    fn format(example: &Example) -> String;
 57}
 58
 59pub trait PromptParser {
 60    /// Return unified diff patch of prediction given raw LLM response
 61    fn parse(example: &Example, response: &str) -> String;
 62}
 63
 64pub struct TeacherPrompt;
 65
 66impl PromptFormatter for TeacherPrompt {
 67    fn format(example: &Example) -> String {
 68        let edit_history = Self::format_edit_history(&example.edit_history);
 69        let context = Self::format_context(example);
 70        let editable_region = Self::format_editable_region(example);
 71
 72        let prompt = Self::PROMPT
 73            .replace("{{context}}", &context)
 74            .replace("{{edit_history}}", &edit_history)
 75            .replace("{{editable_region}}", &editable_region);
 76
 77        prompt
 78    }
 79}
 80
 81impl TeacherPrompt {
 82    const PROMPT: &str = include_str!("teacher.prompt.md");
 83    pub(crate) const EDITABLE_REGION_START: &str = "<|editable_region_start|>\n";
 84    pub(crate) const EDITABLE_REGION_END: &str = "<|editable_region_end|>";
 85
 86    /// Truncate edit history to this number of last lines
 87    const MAX_HISTORY_LINES: usize = 128;
 88
 89    fn format_edit_history(edit_history: &str) -> String {
 90        // Strip comments ("garbage lines") from edit history
 91        let lines = edit_history
 92            .lines()
 93            .filter(|&s| Self::is_udiff_content_line(s))
 94            .collect::<Vec<_>>();
 95
 96        let history_lines = if lines.len() > Self::MAX_HISTORY_LINES {
 97            &lines[lines.len() - Self::MAX_HISTORY_LINES..]
 98        } else {
 99            &lines
100        };
101
102        if history_lines.is_empty() {
103            return "(No edit history)".to_string();
104        }
105
106        history_lines.join("\n")
107    }
108
109    fn format_context(example: &Example) -> String {
110        if example.context.is_none() {
111            panic!("Missing context retriever step");
112        }
113
114        let mut prompt = String::new();
115        zeta_prompt::write_related_files(&mut prompt, &example.context.as_ref().unwrap().files);
116
117        prompt
118    }
119
120    fn format_editable_region(example: &Example) -> String {
121        let mut result = String::new();
122
123        let path_str = example.cursor_path.to_string_lossy();
124        result.push_str(&format!("`````path=\"{path_str}\"\n"));
125        result.push_str(Self::EDITABLE_REGION_START);
126
127        // TODO: control number of lines around cursor
128        result.push_str(&example.cursor_position);
129        if !example.cursor_position.ends_with('\n') {
130            result.push('\n');
131        }
132
133        result.push_str(&format!("{}\n", Self::EDITABLE_REGION_END));
134        result.push_str("`````");
135
136        result
137    }
138
139    fn extract_editable_region(text: &str) -> String {
140        let start = text
141            .find(Self::EDITABLE_REGION_START)
142            .map_or(0, |pos| pos + Self::EDITABLE_REGION_START.len());
143        let end = text.find(Self::EDITABLE_REGION_END).unwrap_or(text.len());
144
145        let region = &text[start..end];
146
147        region.replace("<|user_cursor|>", "")
148    }
149
150    fn is_udiff_content_line(s: &str) -> bool {
151        s.starts_with("-")
152            || s.starts_with("+")
153            || s.starts_with(" ")
154            || s.starts_with("---")
155            || s.starts_with("+++")
156            || s.starts_with("@@")
157    }
158}
159
160impl PromptParser for TeacherPrompt {
161    fn parse(example: &Example, response: &str) -> String {
162        // Ideally, we should always be able to find cursor position in the retrieved context.
163        // In reality, sometimes we don't find it for these reasons:
164        // 1. `example.cursor_position` contains _more_ context than included in the retrieved context
165        //    (can be fixed by getting cursor coordinates at the load_example stage)
166        // 2. Context retriever just didn't include cursor line.
167        //
168        // In that case, fallback to using `cursor_position` as excerpt.
169        let cursor_file = &example
170            .buffer
171            .as_ref()
172            .expect("`buffer` should be filled in in the context collection step")
173            .content;
174
175        // Extract updated (new) editable region from the model response
176        let new_editable_region = extract_last_codeblock(response);
177
178        // Reconstruct old editable region we sent to the model
179        let old_editable_region = Self::format_editable_region(example);
180        let old_editable_region = Self::extract_editable_region(&old_editable_region);
181        if !cursor_file.contains(&old_editable_region) {
182            panic!("Something's wrong: editable_region is not found in the cursor file")
183        }
184
185        // Apply editable region to a larger context and compute diff.
186        // This is needed to get a better context lines around the editable region
187        let edited_file = cursor_file.replace(&old_editable_region, &new_editable_region);
188        let diff = language::unified_diff(&cursor_file, &edited_file);
189
190        let diff = indoc::formatdoc! {"
191            --- a/{path}
192            +++ b/{path}
193            {diff}
194            ",
195            path = example.cursor_path.to_string_lossy(),
196            diff = diff,
197        };
198
199        diff
200    }
201}
202
/// Returns the contents of the last complete fenced code block in `text`.
///
/// - An opening fence is a run of three or more backticks. The block is closed by the next
///   occurrence of the same number of backticks.
/// - Everything after the backticks on the opening fence line (e.g. a language or
///   `path=...` info string) is skipped.
/// - A single newline directly before the closing fence is not part of the result.
/// - If `text` contains no complete block, the whole `text` is returned unchanged.
fn extract_last_codeblock(text: &str) -> String {
    let mut last_block = None;
    let mut search_start = 0;

    while let Some(relative_start) = text[search_start..].find("```") {
        let fence_start = search_start + relative_start;

        // The closing fence must use the same number of backticks as the opening one.
        let backtick_count = text[fence_start..]
            .bytes()
            .take_while(|&b| b == b'`')
            .count();
        let closing_fence = "`".repeat(backtick_count);

        // Block content starts on the line after the opening fence.
        let Some(newline_offset) = text[fence_start..].find('\n') else {
            break;
        };
        let content_start = fence_start + newline_offset + 1;

        let Some(relative_end) = text[content_start..].find(&closing_fence) else {
            break;
        };
        let content_end = content_start + relative_end;

        // Strip only a newline that is actually there. Blindly dropping one byte would panic
        // on an empty block and cut a real character when the fence follows text directly.
        let content = &text[content_start..content_end];
        let content = content.strip_suffix('\n').unwrap_or(content);
        last_block = Some(content.to_string());

        search_start = content_end + backtick_count;
    }

    last_block.unwrap_or_else(|| text.to_string())
}
234
#[cfg(test)]
mod tests {
    use super::*;

    /// The last fenced block wins, even when its fence is longer than three backticks and
    /// carries an info string.
    #[test]
    fn test_extract_last_code_block() {
        let response = indoc::indoc! {"
            Some thinking

            ```
            first block
            ```

            `````path='something' lines=1:2
            last block
            `````
            "};

        assert_eq!(extract_last_codeblock(response), "last block");
    }

    /// Only the lines between the editable-region markers are returned.
    #[test]
    fn test_extract_editable_region() {
        let expected = indoc::indoc! {"
            one
            two three

            "};
        let document = indoc::indoc! {"
            some lines
            are
            here
            <|editable_region_start|>
            one
            two three

            <|editable_region_end|>
            more
            lines here
            "};

        let region = TeacherPrompt::extract_editable_region(document);
        assert_eq!(region, expected);
    }
}