format_prompt.rs

  1use crate::{
  2    PromptFormat,
  3    example::{Example, ExamplePrompt},
  4    headless::EpAppState,
  5    load_project::run_load_project,
  6    progress::{Progress, Step},
  7    retrieve_context::run_context_retrieval,
  8};
  9use edit_prediction::{
 10    EditPredictionStore,
 11    zeta2::{zeta2_output_for_patch, zeta2_prompt_input},
 12};
 13use gpui::AsyncApp;
 14use std::sync::Arc;
 15use zeta_prompt::format_zeta_prompt;
 16
 17pub async fn run_format_prompt(
 18    example: &mut Example,
 19    prompt_format: PromptFormat,
 20    app_state: Arc<EpAppState>,
 21    progress: Arc<Progress>,
 22    mut cx: AsyncApp,
 23) {
 24    run_context_retrieval(example, app_state.clone(), progress.clone(), cx.clone()).await;
 25
 26    let _step_progress = progress.start(Step::FormatPrompt, &example.name);
 27
 28    match prompt_format {
 29        PromptFormat::Teacher => {
 30            let prompt = TeacherPrompt::format_prompt(example);
 31            example.prompt = Some(ExamplePrompt {
 32                input: prompt,
 33                expected_output: example.expected_patch.clone(), // TODO
 34                format: prompt_format,
 35            });
 36        }
 37        PromptFormat::Zeta2 => {
 38            run_load_project(example, app_state, progress.clone(), cx.clone()).await;
 39
 40            let ep_store = cx
 41                .update(|cx| EditPredictionStore::try_global(cx).unwrap())
 42                .unwrap();
 43
 44            let state = example.state.as_ref().unwrap();
 45            let snapshot = state
 46                .buffer
 47                .read_with(&cx, |buffer, _| buffer.snapshot())
 48                .unwrap();
 49            let project = state.project.clone();
 50            let (_, input) = ep_store
 51                .update(&mut cx, |ep_store, _cx| {
 52                    zeta2_prompt_input(
 53                        &snapshot,
 54                        example.context.as_ref().unwrap().files.clone(),
 55                        ep_store.edit_history_for_project(&project),
 56                        example.cursor_path.clone(),
 57                        example.buffer.as_ref().unwrap().cursor_offset,
 58                    )
 59                })
 60                .unwrap();
 61            let prompt = format_zeta_prompt(&input);
 62            let expected_output = zeta2_output_for_patch(&input, &example.expected_patch.clone());
 63            example.prompt = Some(ExamplePrompt {
 64                input: prompt,
 65                expected_output,
 66                format: prompt_format,
 67            });
 68        }
 69    };
 70}
 71
 72pub struct TeacherPrompt;
 73
 74impl TeacherPrompt {
 75    const PROMPT: &str = include_str!("teacher.prompt.md");
 76    pub(crate) const EDITABLE_REGION_START: &str = "<|editable_region_start|>\n";
 77    pub(crate) const EDITABLE_REGION_END: &str = "<|editable_region_end|>";
 78
 79    /// Truncate edit history to this number of last lines
 80    const MAX_HISTORY_LINES: usize = 128;
 81
 82    pub fn format_prompt(example: &Example) -> String {
 83        let edit_history = Self::format_edit_history(&example.edit_history);
 84        let context = Self::format_context(example);
 85        let editable_region = Self::format_editable_region(example);
 86
 87        let prompt = Self::PROMPT
 88            .replace("{{context}}", &context)
 89            .replace("{{edit_history}}", &edit_history)
 90            .replace("{{editable_region}}", &editable_region);
 91
 92        prompt
 93    }
 94
 95    pub fn parse(example: &Example, response: &str) -> String {
 96        // Ideally, we should always be able to find cursor position in the retrieved context.
 97        // In reality, sometimes we don't find it for these reasons:
 98        // 1. `example.cursor_position` contains _more_ context than included in the retrieved context
 99        //    (can be fixed by getting cursor coordinates at the load_example stage)
100        // 2. Context retriever just didn't include cursor line.
101        //
102        // In that case, fallback to using `cursor_position` as excerpt.
103        let cursor_file = &example
104            .buffer
105            .as_ref()
106            .expect("`buffer` should be filled in in the context collection step")
107            .content;
108
109        // Extract updated (new) editable region from the model response
110        let new_editable_region = extract_last_codeblock(response);
111
112        // Reconstruct old editable region we sent to the model
113        let old_editable_region = Self::format_editable_region(example);
114        let old_editable_region = Self::extract_editable_region(&old_editable_region);
115        if !cursor_file.contains(&old_editable_region) {
116            panic!("Something's wrong: editable_region is not found in the cursor file")
117        }
118
119        // Apply editable region to a larger context and compute diff.
120        // This is needed to get a better context lines around the editable region
121        let edited_file = cursor_file.replace(&old_editable_region, &new_editable_region);
122        let diff = language::unified_diff(&cursor_file, &edited_file);
123
124        let diff = indoc::formatdoc! {"
125            --- a/{path}
126            +++ b/{path}
127            {diff}",
128            path = example.cursor_path.to_string_lossy(),
129            diff = diff,
130        };
131
132        diff
133    }
134
135    fn format_edit_history(edit_history: &str) -> String {
136        // Strip comments ("garbage lines") from edit history
137        let lines = edit_history
138            .lines()
139            .filter(|&s| Self::is_udiff_content_line(s))
140            .collect::<Vec<_>>();
141
142        let history_lines = if lines.len() > Self::MAX_HISTORY_LINES {
143            &lines[lines.len() - Self::MAX_HISTORY_LINES..]
144        } else {
145            &lines
146        };
147
148        if history_lines.is_empty() {
149            return "(No edit history)".to_string();
150        }
151
152        history_lines.join("\n")
153    }
154
155    fn format_context(example: &Example) -> String {
156        if example.context.is_none() {
157            panic!("Missing context retriever step");
158        }
159
160        let mut prompt = String::new();
161        zeta_prompt::write_related_files(&mut prompt, &example.context.as_ref().unwrap().files);
162
163        prompt
164    }
165
166    fn format_editable_region(example: &Example) -> String {
167        let mut result = String::new();
168
169        let path_str = example.cursor_path.to_string_lossy();
170        result.push_str(&format!("`````path=\"{path_str}\"\n"));
171        result.push_str(Self::EDITABLE_REGION_START);
172
173        // TODO: control number of lines around cursor
174        result.push_str(&example.cursor_position);
175        if !example.cursor_position.ends_with('\n') {
176            result.push('\n');
177        }
178
179        result.push_str(&format!("{}\n", Self::EDITABLE_REGION_END));
180        result.push_str("`````");
181
182        result
183    }
184
185    fn extract_editable_region(text: &str) -> String {
186        let start = text
187            .find(Self::EDITABLE_REGION_START)
188            .map_or(0, |pos| pos + Self::EDITABLE_REGION_START.len());
189        let end = text.find(Self::EDITABLE_REGION_END).unwrap_or(text.len());
190
191        let region = &text[start..end];
192
193        region.replace("<|user_cursor|>", "")
194    }
195
196    fn is_udiff_content_line(s: &str) -> bool {
197        s.starts_with("-")
198            || s.starts_with("+")
199            || s.starts_with(" ")
200            || s.starts_with("---")
201            || s.starts_with("+++")
202            || s.starts_with("@@")
203    }
204}
205
/// Returns the body of the last fenced code block found in `text`.
///
/// A block opens with a run of three or more backticks; the rest of that line
/// (the info string) is skipped. It closes at the next occurrence of a backtick
/// run of the same length, which is matched anywhere after the opening line,
/// not only at the start of a line. The returned body starts right after the
/// opening line's newline and ends just before the closing backticks.
///
/// Scanning stops at the first opening fence that has no closing fence. If no
/// complete block was found, the whole of `text` is returned.
fn extract_last_codeblock(text: &str) -> String {
    let bytes = text.as_bytes();
    let mut last_block: Option<&str> = None;
    let mut cursor = 0;

    while let Some(offset) = text[cursor..].find("```") {
        let fence_start = cursor + offset;

        // Length of the opening backtick run; the closing fence must match it.
        let fence_len = bytes[fence_start..]
            .iter()
            .take_while(|&&byte| byte == b'`')
            .count();
        let fence = "`".repeat(fence_len);

        // Index of the newline ending the opening line, or the end of the text.
        let after_fence = fence_start + fence_len;
        let line_end = bytes[after_fence..]
            .iter()
            .position(|&byte| byte == b'\n')
            .map_or(bytes.len(), |pos| after_fence + pos);

        match text[line_end..].find(fence.as_str()) {
            Some(close_offset) => {
                // `line_end` points at the newline, so the body starts one past it.
                last_block = Some(&text[line_end + 1..line_end + close_offset]);
                cursor = line_end + close_offset + fence_len;
            }
            None => break,
        }
    }

    last_block.unwrap_or(text).to_string()
}
237
#[cfg(test)]
mod tests {
    use super::*;

    /// The last of several fenced blocks is returned, even when its fence is
    /// longer and carries an info string.
    #[test]
    fn test_extract_last_code_block() {
        let response = concat!(
            "Some thinking\n",
            "\n",
            "```\n",
            "first block\n",
            "```\n",
            "\n",
            "`````path='something' lines=1:2\n",
            "last block\n",
            "`````\n",
        );

        assert_eq!(extract_last_codeblock(response), "last block\n");
    }

    /// Only the text between the editable-region markers is returned.
    #[test]
    fn test_extract_editable_region() {
        let input = concat!(
            "some lines\n",
            "are\n",
            "here\n",
            "<|editable_region_start|>\n",
            "one\n",
            "two three\n",
            "\n",
            "<|editable_region_end|>\n",
            "more\n",
            "lines here\n",
        );
        let expected = concat!("one\n", "two three\n", "\n");

        let region = TeacherPrompt::extract_editable_region(input);
        assert_eq!(region, expected);
    }
}