format_prompt.rs

  1use crate::{
  2    PromptFormat,
  3    example::{Example, ExamplePrompt},
  4    headless::EpAppState,
  5    load_project::run_load_project,
  6    progress::{Progress, Step},
  7    retrieve_context::run_context_retrieval,
  8};
  9use anyhow::{Context as _, Result, ensure};
 10use edit_prediction::{
 11    EditPredictionStore,
 12    zeta2::{zeta2_output_for_patch, zeta2_prompt_input},
 13};
 14use gpui::AsyncApp;
 15use std::sync::Arc;
 16use zeta_prompt::format_zeta_prompt;
 17
 18pub async fn run_format_prompt(
 19    example: &mut Example,
 20    prompt_format: PromptFormat,
 21    app_state: Arc<EpAppState>,
 22    mut cx: AsyncApp,
 23) -> Result<()> {
 24    run_context_retrieval(example, app_state.clone(), cx.clone()).await?;
 25
 26    let _step_progress = Progress::global().start(Step::FormatPrompt, &example.spec.name);
 27
 28    match prompt_format {
 29        PromptFormat::Teacher => {
 30            let prompt = TeacherPrompt::format_prompt(example);
 31            example.prompt = Some(ExamplePrompt {
 32                input: prompt,
 33                expected_output: example.spec.expected_patch.clone(), // TODO
 34                format: prompt_format,
 35            });
 36        }
 37        PromptFormat::Zeta2 => {
 38            run_load_project(example, app_state, cx.clone()).await?;
 39
 40            let ep_store = cx.update(|cx| {
 41                EditPredictionStore::try_global(cx).context("EditPredictionStore not initialized")
 42            })??;
 43
 44            let state = example.state.as_ref().context("state must be set")?;
 45            let snapshot = state.buffer.read_with(&cx, |buffer, _| buffer.snapshot())?;
 46            let project = state.project.clone();
 47            let (_, input) = ep_store.update(&mut cx, |ep_store, cx| {
 48                let events = ep_store
 49                    .edit_history_for_project(&project, cx)
 50                    .into_iter()
 51                    .map(|e| e.event)
 52                    .collect();
 53                anyhow::Ok(zeta2_prompt_input(
 54                    &snapshot,
 55                    example
 56                        .context
 57                        .as_ref()
 58                        .context("context must be set")?
 59                        .files
 60                        .clone(),
 61                    events,
 62                    example.spec.cursor_path.clone(),
 63                    example
 64                        .buffer
 65                        .as_ref()
 66                        .context("buffer must be set")?
 67                        .cursor_offset,
 68                ))
 69            })??;
 70            let prompt = format_zeta_prompt(&input);
 71            let expected_output =
 72                zeta2_output_for_patch(&input, &example.spec.expected_patch.clone())?;
 73            example.prompt = Some(ExamplePrompt {
 74                input: prompt,
 75                expected_output,
 76                format: prompt_format,
 77            });
 78        }
 79    };
 80    Ok(())
 81}
 82
 83pub struct TeacherPrompt;
 84
 85impl TeacherPrompt {
 86    const PROMPT: &str = include_str!("teacher.prompt.md");
 87    pub(crate) const EDITABLE_REGION_START: &str = "<|editable_region_start|>\n";
 88    pub(crate) const EDITABLE_REGION_END: &str = "<|editable_region_end|>";
 89    pub(crate) const USER_CURSOR_MARKER: &str = "<|user_cursor|>";
 90
 91    /// Truncate edit history to this number of last lines
 92    const MAX_HISTORY_LINES: usize = 128;
 93
 94    pub fn format_prompt(example: &Example) -> String {
 95        let edit_history = Self::format_edit_history(&example.spec.edit_history);
 96        let context = Self::format_context(example);
 97        let editable_region = Self::format_editable_region(example);
 98
 99        let prompt = Self::PROMPT
100            .replace("{{context}}", &context)
101            .replace("{{edit_history}}", &edit_history)
102            .replace("{{editable_region}}", &editable_region);
103
104        prompt
105    }
106
107    pub fn parse(example: &Example, response: &str) -> Result<String> {
108        // Ideally, we should always be able to find cursor position in the retrieved context.
109        // In reality, sometimes we don't find it for these reasons:
110        // 1. `example.cursor_position` contains _more_ context than included in the retrieved context
111        //    (can be fixed by getting cursor coordinates at the load_example stage)
112        // 2. Context retriever just didn't include cursor line.
113        //
114        // In that case, fallback to using `cursor_position` as excerpt.
115        let cursor_file = &example
116            .buffer
117            .as_ref()
118            .context("`buffer` should be filled in in the context collection step")?
119            .content;
120
121        // Extract updated (new) editable region from the model response
122        let new_editable_region = extract_last_codeblock(response);
123
124        // Reconstruct old editable region we sent to the model
125        let old_editable_region = Self::format_editable_region(example);
126        let old_editable_region = Self::extract_editable_region(&old_editable_region);
127        ensure!(
128            cursor_file.contains(&old_editable_region),
129            "Something's wrong: editable_region is not found in the cursor file"
130        );
131
132        // Apply editable region to a larger context and compute diff.
133        // This is needed to get a better context lines around the editable region
134        let edited_file = cursor_file.replace(&old_editable_region, &new_editable_region);
135        let diff = language::unified_diff(&cursor_file, &edited_file);
136
137        let diff = indoc::formatdoc! {"
138            --- a/{path}
139            +++ b/{path}
140            {diff}",
141            path = example.spec.cursor_path.to_string_lossy(),
142            diff = diff,
143        };
144
145        Ok(diff)
146    }
147
148    fn format_edit_history(edit_history: &str) -> String {
149        // Strip comments ("garbage lines") from edit history
150        let lines = edit_history
151            .lines()
152            .filter(|&s| Self::is_udiff_content_line(s))
153            .collect::<Vec<_>>();
154
155        let history_lines = if lines.len() > Self::MAX_HISTORY_LINES {
156            &lines[lines.len() - Self::MAX_HISTORY_LINES..]
157        } else {
158            &lines
159        };
160
161        if history_lines.is_empty() {
162            return "(No edit history)".to_string();
163        }
164
165        history_lines.join("\n")
166    }
167
168    fn format_context(example: &Example) -> String {
169        assert!(example.context.is_some(), "Missing context retriever step");
170
171        let mut prompt = String::new();
172        zeta_prompt::write_related_files(&mut prompt, &example.context.as_ref().unwrap().files);
173
174        prompt
175    }
176
177    fn format_editable_region(example: &Example) -> String {
178        let mut result = String::new();
179
180        let path_str = example.spec.cursor_path.to_string_lossy();
181        result.push_str(&format!("`````path=\"{path_str}\"\n"));
182        result.push_str(Self::EDITABLE_REGION_START);
183
184        // TODO: control number of lines around cursor
185        let (mut excerpt, offset) = example.spec.cursor_excerpt().unwrap();
186        excerpt.insert_str(offset, Self::USER_CURSOR_MARKER);
187        result.push_str(&excerpt);
188        if !result.ends_with('\n') {
189            result.push('\n');
190        }
191
192        result.push_str(Self::EDITABLE_REGION_END);
193        result.push_str("\n`````");
194
195        result
196    }
197
198    fn extract_editable_region(text: &str) -> String {
199        let start = text
200            .find(Self::EDITABLE_REGION_START)
201            .map_or(0, |pos| pos + Self::EDITABLE_REGION_START.len());
202        let end = text.find(Self::EDITABLE_REGION_END).unwrap_or(text.len());
203
204        let region = &text[start..end];
205
206        region.replace("<|user_cursor|>", "")
207    }
208
209    fn is_udiff_content_line(s: &str) -> bool {
210        s.starts_with("-")
211            || s.starts_with("+")
212            || s.starts_with(" ")
213            || s.starts_with("---")
214            || s.starts_with("+++")
215            || s.starts_with("@@")
216    }
217}
218
/// Returns the content of the last closed fenced code block in `text`. If `text` has no closed
/// fenced block, the whole `text` is returned.
///
/// Fences are recognised per line, following the markdown convention:
/// - An opening fence is a line that starts with at least three backticks, after optional
///   whitespace. It may be followed by an info string such as `path="..."`.
/// - The block is closed by a line consisting only of backticks, at least as many as the
///   opening fence has.
///
/// Backtick runs in the middle of a line never open or close a block. Code that itself contains
/// triple backticks is therefore extracted intact. The returned content excludes both fence
/// lines and keeps the trailing newline of its last line. An unclosed trailing block is ignored.
fn extract_last_codeblock(text: &str) -> String {
    let mut last_block = None;
    // Backtick count of the currently open fence and the byte offset where its content starts.
    let mut open_fence: Option<(usize, usize)> = None;
    let mut line_start = 0;

    for line in text.split_inclusive('\n') {
        let line_end = line_start + line.len();
        let trimmed = line.trim();
        let fence_len = trimmed.bytes().take_while(|&b| b == b'`').count();

        match open_fence {
            None if fence_len >= 3 => open_fence = Some((fence_len, line_end)),
            Some((open_len, content_start))
                if fence_len >= open_len && fence_len == trimmed.len() =>
            {
                last_block = Some(text[content_start..line_start].to_string());
                open_fence = None;
            }
            _ => {}
        }

        line_start = line_end;
    }

    last_block.unwrap_or_else(|| text.to_string())
}
250
#[cfg(test)]
mod tests {
    use super::*;

    /// With several fenced blocks in a response, only the content of the last one (here a
    /// five-backtick block carrying an info string) is returned.
    #[test]
    fn test_extract_last_code_block() {
        let response = indoc::indoc! {"
            Some thinking

            ```
            first block
            ```

            `````path='something' lines=1:2
            last block
            `````
            "};

        assert_eq!(extract_last_codeblock(response), "last block\n");
    }

    /// Text outside the region markers is dropped. Everything between them, including the
    /// trailing blank line, is kept.
    #[test]
    fn test_extract_editable_region() {
        let expected = indoc::indoc! {"
            one
            two three

            "};
        let prompt_text = indoc::indoc! {"
            some lines
            are
            here
            <|editable_region_start|>
            one
            two three

            <|editable_region_end|>
            more
            lines here
            "};

        assert_eq!(TeacherPrompt::extract_editable_region(prompt_text), expected);
    }
}