format_prompt.rs

  1use crate::{
  2    PromptFormat,
  3    example::{Example, ExamplePrompt},
  4    headless::EpAppState,
  5    load_project::run_load_project,
  6    progress::{Progress, Step},
  7    retrieve_context::run_context_retrieval,
  8};
  9use anyhow::{Context as _, Result, ensure};
 10use edit_prediction::{
 11    EditPredictionStore,
 12    zeta2::{zeta2_output_for_patch, zeta2_prompt_input},
 13};
 14use gpui::AsyncApp;
 15use std::sync::Arc;
 16use zeta_prompt::format_zeta_prompt;
 17
 18pub async fn run_format_prompt(
 19    example: &mut Example,
 20    prompt_format: PromptFormat,
 21    app_state: Arc<EpAppState>,
 22    mut cx: AsyncApp,
 23) -> Result<()> {
 24    run_context_retrieval(example, app_state.clone(), cx.clone()).await?;
 25
 26    let _step_progress = Progress::global().start(Step::FormatPrompt, &example.spec.name);
 27
 28    match prompt_format {
 29        PromptFormat::Teacher => {
 30            let prompt = TeacherPrompt::format_prompt(example);
 31            example.prompt = Some(ExamplePrompt {
 32                input: prompt,
 33                // TODO
 34                expected_output: example
 35                    .spec
 36                    .expected_patches
 37                    .first()
 38                    .context("no expected patches")?
 39                    .clone(),
 40                format: prompt_format,
 41            });
 42        }
 43        PromptFormat::Zeta2 => {
 44            run_load_project(example, app_state, cx.clone()).await?;
 45
 46            let ep_store = cx.update(|cx| {
 47                EditPredictionStore::try_global(cx).context("EditPredictionStore not initialized")
 48            })??;
 49
 50            let state = example.state.as_ref().context("state must be set")?;
 51            let snapshot = state.buffer.read_with(&cx, |buffer, _| buffer.snapshot())?;
 52            let project = state.project.clone();
 53            let (_, input) = ep_store.update(&mut cx, |ep_store, cx| {
 54                let events = ep_store
 55                    .edit_history_for_project(&project, cx)
 56                    .into_iter()
 57                    .map(|e| e.event)
 58                    .collect();
 59                anyhow::Ok(zeta2_prompt_input(
 60                    &snapshot,
 61                    example
 62                        .context
 63                        .as_ref()
 64                        .context("context must be set")?
 65                        .files
 66                        .clone(),
 67                    events,
 68                    example.spec.cursor_path.clone(),
 69                    example
 70                        .buffer
 71                        .as_ref()
 72                        .context("buffer must be set")?
 73                        .cursor_offset,
 74                ))
 75            })??;
 76            let prompt = format_zeta_prompt(&input);
 77            let expected_output = zeta2_output_for_patch(
 78                &input,
 79                &example
 80                    .spec
 81                    .expected_patches
 82                    .first()
 83                    .context("expected patches is empty")?
 84                    .clone(),
 85            )?;
 86            example.prompt = Some(ExamplePrompt {
 87                input: prompt,
 88                expected_output,
 89                format: prompt_format,
 90            });
 91        }
 92    };
 93    Ok(())
 94}
 95
 96pub struct TeacherPrompt;
 97
 98impl TeacherPrompt {
 99    const PROMPT: &str = include_str!("teacher.prompt.md");
100    pub(crate) const EDITABLE_REGION_START: &str = "<|editable_region_start|>\n";
101    pub(crate) const EDITABLE_REGION_END: &str = "<|editable_region_end|>";
102    pub(crate) const USER_CURSOR_MARKER: &str = "<|user_cursor|>";
103
104    /// Truncate edit history to this number of last lines
105    const MAX_HISTORY_LINES: usize = 128;
106
107    pub fn format_prompt(example: &Example) -> String {
108        let edit_history = Self::format_edit_history(&example.spec.edit_history);
109        let context = Self::format_context(example);
110        let editable_region = Self::format_editable_region(example);
111
112        let prompt = Self::PROMPT
113            .replace("{{context}}", &context)
114            .replace("{{edit_history}}", &edit_history)
115            .replace("{{editable_region}}", &editable_region);
116
117        prompt
118    }
119
120    pub fn parse(example: &Example, response: &str) -> Result<String> {
121        // Ideally, we should always be able to find cursor position in the retrieved context.
122        // In reality, sometimes we don't find it for these reasons:
123        // 1. `example.cursor_position` contains _more_ context than included in the retrieved context
124        //    (can be fixed by getting cursor coordinates at the load_example stage)
125        // 2. Context retriever just didn't include cursor line.
126        //
127        // In that case, fallback to using `cursor_position` as excerpt.
128        let cursor_file = &example
129            .buffer
130            .as_ref()
131            .context("`buffer` should be filled in in the context collection step")?
132            .content;
133
134        // Extract updated (new) editable region from the model response
135        let new_editable_region = extract_last_codeblock(response);
136
137        // Reconstruct old editable region we sent to the model
138        let old_editable_region = Self::format_editable_region(example);
139        let old_editable_region = Self::extract_editable_region(&old_editable_region);
140        ensure!(
141            cursor_file.contains(&old_editable_region),
142            "Something's wrong: editable_region is not found in the cursor file"
143        );
144
145        // Apply editable region to a larger context and compute diff.
146        // This is needed to get a better context lines around the editable region
147        let edited_file = cursor_file.replace(&old_editable_region, &new_editable_region);
148        let diff = language::unified_diff(&cursor_file, &edited_file);
149
150        let diff = indoc::formatdoc! {"
151            --- a/{path}
152            +++ b/{path}
153            {diff}",
154            path = example.spec.cursor_path.to_string_lossy(),
155            diff = diff,
156        };
157
158        Ok(diff)
159    }
160
161    fn format_edit_history(edit_history: &str) -> String {
162        // Strip comments ("garbage lines") from edit history
163        let lines = edit_history
164            .lines()
165            .filter(|&s| Self::is_udiff_content_line(s))
166            .collect::<Vec<_>>();
167
168        let history_lines = if lines.len() > Self::MAX_HISTORY_LINES {
169            &lines[lines.len() - Self::MAX_HISTORY_LINES..]
170        } else {
171            &lines
172        };
173
174        if history_lines.is_empty() {
175            return "(No edit history)".to_string();
176        }
177
178        history_lines.join("\n")
179    }
180
181    fn format_context(example: &Example) -> String {
182        assert!(example.context.is_some(), "Missing context retriever step");
183
184        let mut prompt = String::new();
185        zeta_prompt::write_related_files(&mut prompt, &example.context.as_ref().unwrap().files);
186
187        prompt
188    }
189
190    fn format_editable_region(example: &Example) -> String {
191        let mut result = String::new();
192
193        let path_str = example.spec.cursor_path.to_string_lossy();
194        result.push_str(&format!("`````path=\"{path_str}\"\n"));
195        result.push_str(Self::EDITABLE_REGION_START);
196
197        // TODO: control number of lines around cursor
198        let (mut excerpt, offset) = example.spec.cursor_excerpt().unwrap();
199        excerpt.insert_str(offset, Self::USER_CURSOR_MARKER);
200        result.push_str(&excerpt);
201        if !result.ends_with('\n') {
202            result.push('\n');
203        }
204
205        result.push_str(Self::EDITABLE_REGION_END);
206        result.push_str("\n`````");
207
208        result
209    }
210
211    fn extract_editable_region(text: &str) -> String {
212        let start = text
213            .find(Self::EDITABLE_REGION_START)
214            .map_or(0, |pos| pos + Self::EDITABLE_REGION_START.len());
215        let end = text.find(Self::EDITABLE_REGION_END).unwrap_or(text.len());
216
217        let region = &text[start..end];
218
219        region.replace("<|user_cursor|>", "")
220    }
221
222    fn is_udiff_content_line(s: &str) -> bool {
223        s.starts_with("-")
224            || s.starts_with("+")
225            || s.starts_with(" ")
226            || s.starts_with("---")
227            || s.starts_with("+++")
228            || s.starts_with("@@")
229    }
230}
231
/// Returns the contents of the last complete fenced code block in `text`.
///
/// A fence is a run of three or more backticks; the remainder of the opening line (the
/// info string) is skipped, and the block ends at the next run of backticks of the same
/// length as the opening fence. Scanning stops at the first fence that is never closed.
/// When no complete block is found, the whole `text` is returned.
fn extract_last_codeblock(text: &str) -> String {
    let mut latest: Option<&str> = None;
    let mut cursor = 0;

    while let Some(found) = text[cursor..].find("```") {
        let fence_start = cursor + found;
        let from_fence = &text[fence_start..];

        // The opening fence is the full run of backticks starting here.
        let fence_len = from_fence.len() - from_fence.trim_start_matches('`').len();
        let fence = &from_fence[..fence_len];

        // Skip the rest of the opening line (info string), stopping at its newline.
        let after_fence = fence_start + fence_len;
        let line_end = text[after_fence..]
            .find('\n')
            .map_or(text.len(), |offset| after_fence + offset);

        let Some(close_offset) = text[line_end..].find(fence) else {
            break;
        };

        // `text[line_end]` is the newline ending the opening line, so content starts one
        // byte later and runs up to the closing fence.
        latest = Some(&text[line_end + 1..line_end + close_offset]);
        cursor = line_end + close_offset + fence_len;
    }

    latest.map_or_else(|| text.to_string(), str::to_string)
}
263
#[cfg(test)]
mod tests {
    use super::*;
    use indoc::indoc;

    /// The last fenced block wins, even when its fence is longer than the earlier one
    /// and carries an info string on the opening line.
    #[test]
    fn test_extract_last_code_block() {
        let response = indoc! {"
            Some thinking

            ```
            first block
            ```

            `````path='something' lines=1:2
            last block
            `````
            "};

        assert_eq!(extract_last_codeblock(response), "last block\n");
    }

    /// Only the text between the editable-region markers is returned.
    #[test]
    fn test_extract_editable_region() {
        let prompt_excerpt = indoc! {"
            some lines
            are
            here
            <|editable_region_start|>
            one
            two three

            <|editable_region_end|>
            more
            lines here
            "};
        let expected_region = indoc! {"
            one
            two three

            "};

        assert_eq!(
            TeacherPrompt::extract_editable_region(prompt_excerpt),
            expected_region
        );
    }
}