1use crate::{
2 PromptFormat,
3 example::{Example, ExamplePrompt},
4 headless::EpAppState,
5 load_project::run_load_project,
6 progress::{Progress, Step},
7 retrieve_context::run_context_retrieval,
8};
9use anyhow::{Context as _, Result, ensure};
10use edit_prediction::{
11 EditPredictionStore,
12 zeta2::{zeta2_output_for_patch, zeta2_prompt_input},
13};
14use gpui::AsyncApp;
15use std::sync::Arc;
16use zeta_prompt::format_zeta_prompt;
17
18pub async fn run_format_prompt(
19 example: &mut Example,
20 prompt_format: PromptFormat,
21 app_state: Arc<EpAppState>,
22 mut cx: AsyncApp,
23) -> Result<()> {
24 run_context_retrieval(example, app_state.clone(), cx.clone()).await?;
25
26 let _step_progress = Progress::global().start(Step::FormatPrompt, &example.spec.name);
27
28 match prompt_format {
29 PromptFormat::Teacher => {
30 let prompt = TeacherPrompt::format_prompt(example);
31 example.prompt = Some(ExamplePrompt {
32 input: prompt,
33 expected_output: example.spec.expected_patch.clone(), // TODO
34 format: prompt_format,
35 });
36 }
37 PromptFormat::Zeta2 => {
38 run_load_project(example, app_state, cx.clone()).await?;
39
40 let ep_store = cx.update(|cx| {
41 EditPredictionStore::try_global(cx).context("EditPredictionStore not initialized")
42 })??;
43
44 let state = example.state.as_ref().context("state must be set")?;
45 let snapshot = state.buffer.read_with(&cx, |buffer, _| buffer.snapshot())?;
46 let project = state.project.clone();
47 let (_, input) = ep_store.update(&mut cx, |ep_store, cx| {
48 anyhow::Ok(zeta2_prompt_input(
49 &snapshot,
50 example
51 .context
52 .as_ref()
53 .context("context must be set")?
54 .files
55 .clone(),
56 ep_store.edit_history_for_project(&project, cx),
57 example.spec.cursor_path.clone(),
58 example
59 .buffer
60 .as_ref()
61 .context("buffer must be set")?
62 .cursor_offset,
63 ))
64 })??;
65 let prompt = format_zeta_prompt(&input);
66 let expected_output =
67 zeta2_output_for_patch(&input, &example.spec.expected_patch.clone())?;
68 example.prompt = Some(ExamplePrompt {
69 input: prompt,
70 expected_output,
71 format: prompt_format,
72 });
73 }
74 };
75 Ok(())
76}
77
78pub struct TeacherPrompt;
79
80impl TeacherPrompt {
81 const PROMPT: &str = include_str!("teacher.prompt.md");
82 pub(crate) const EDITABLE_REGION_START: &str = "<|editable_region_start|>\n";
83 pub(crate) const EDITABLE_REGION_END: &str = "<|editable_region_end|>";
84
85 /// Truncate edit history to this number of last lines
86 const MAX_HISTORY_LINES: usize = 128;
87
88 pub fn format_prompt(example: &Example) -> String {
89 let edit_history = Self::format_edit_history(&example.spec.edit_history);
90 let context = Self::format_context(example);
91 let editable_region = Self::format_editable_region(example);
92
93 let prompt = Self::PROMPT
94 .replace("{{context}}", &context)
95 .replace("{{edit_history}}", &edit_history)
96 .replace("{{editable_region}}", &editable_region);
97
98 prompt
99 }
100
101 pub fn parse(example: &Example, response: &str) -> Result<String> {
102 // Ideally, we should always be able to find cursor position in the retrieved context.
103 // In reality, sometimes we don't find it for these reasons:
104 // 1. `example.cursor_position` contains _more_ context than included in the retrieved context
105 // (can be fixed by getting cursor coordinates at the load_example stage)
106 // 2. Context retriever just didn't include cursor line.
107 //
108 // In that case, fallback to using `cursor_position` as excerpt.
109 let cursor_file = &example
110 .buffer
111 .as_ref()
112 .context("`buffer` should be filled in in the context collection step")?
113 .content;
114
115 // Extract updated (new) editable region from the model response
116 let new_editable_region = extract_last_codeblock(response);
117
118 // Reconstruct old editable region we sent to the model
119 let old_editable_region = Self::format_editable_region(example);
120 let old_editable_region = Self::extract_editable_region(&old_editable_region);
121 ensure!(
122 cursor_file.contains(&old_editable_region),
123 "Something's wrong: editable_region is not found in the cursor file"
124 );
125
126 // Apply editable region to a larger context and compute diff.
127 // This is needed to get a better context lines around the editable region
128 let edited_file = cursor_file.replace(&old_editable_region, &new_editable_region);
129 let diff = language::unified_diff(&cursor_file, &edited_file);
130
131 let diff = indoc::formatdoc! {"
132 --- a/{path}
133 +++ b/{path}
134 {diff}",
135 path = example.spec.cursor_path.to_string_lossy(),
136 diff = diff,
137 };
138
139 Ok(diff)
140 }
141
142 fn format_edit_history(edit_history: &str) -> String {
143 // Strip comments ("garbage lines") from edit history
144 let lines = edit_history
145 .lines()
146 .filter(|&s| Self::is_udiff_content_line(s))
147 .collect::<Vec<_>>();
148
149 let history_lines = if lines.len() > Self::MAX_HISTORY_LINES {
150 &lines[lines.len() - Self::MAX_HISTORY_LINES..]
151 } else {
152 &lines
153 };
154
155 if history_lines.is_empty() {
156 return "(No edit history)".to_string();
157 }
158
159 history_lines.join("\n")
160 }
161
162 fn format_context(example: &Example) -> String {
163 assert!(example.context.is_some(), "Missing context retriever step");
164
165 let mut prompt = String::new();
166 zeta_prompt::write_related_files(&mut prompt, &example.context.as_ref().unwrap().files);
167
168 prompt
169 }
170
171 fn format_editable_region(example: &Example) -> String {
172 let mut result = String::new();
173
174 let path_str = example.spec.cursor_path.to_string_lossy();
175 result.push_str(&format!("`````path=\"{path_str}\"\n"));
176 result.push_str(Self::EDITABLE_REGION_START);
177
178 // TODO: control number of lines around cursor
179 result.push_str(&example.spec.cursor_position);
180 if !example.spec.cursor_position.ends_with('\n') {
181 result.push('\n');
182 }
183
184 result.push_str(&format!("{}\n", Self::EDITABLE_REGION_END));
185 result.push_str("`````");
186
187 result
188 }
189
190 fn extract_editable_region(text: &str) -> String {
191 let start = text
192 .find(Self::EDITABLE_REGION_START)
193 .map_or(0, |pos| pos + Self::EDITABLE_REGION_START.len());
194 let end = text.find(Self::EDITABLE_REGION_END).unwrap_or(text.len());
195
196 let region = &text[start..end];
197
198 region.replace("<|user_cursor|>", "")
199 }
200
201 fn is_udiff_content_line(s: &str) -> bool {
202 s.starts_with("-")
203 || s.starts_with("+")
204 || s.starts_with(" ")
205 || s.starts_with("---")
206 || s.starts_with("+++")
207 || s.starts_with("@@")
208 }
209}
210
/// Returns the contents of the last complete fenced code block in `text`.
///
/// A fence is a run of three or more backticks. The rest of the opening line (the
/// info string) is skipped, and the block ends at the next run of at least as many
/// backticks. The returned text excludes both fences. If `text` contains no complete
/// block, the whole of `text` is returned; a trailing unterminated block is ignored.
fn extract_last_codeblock(text: &str) -> String {
    let bytes = text.as_bytes();
    let backtick_run_len =
        |from: usize| bytes[from..].iter().take_while(|&&b| b == b'`').count();

    let mut last_block = None;
    let mut search_start = 0;

    while let Some(offset) = text[search_start..].find("```") {
        let fence_start = search_start + offset;
        let fence_len = backtick_run_len(fence_start);
        let fence_end = fence_start + fence_len;

        // Skip the info string: the block body starts after the opening line's newline.
        let line_end = bytes[fence_end..]
            .iter()
            .position(|&b| b == b'\n')
            .map_or(bytes.len(), |pos| fence_end + pos);

        let closing_fence = "`".repeat(fence_len);
        let Some(close_offset) = text[line_end..].find(&closing_fence) else {
            // Unterminated block: keep whatever complete block was found earlier.
            break;
        };
        let close_start = line_end + close_offset;

        // `text[line_end]` is the newline ending the opening line (a match implies the
        // haystack is non-empty and cannot start with a backtick), so `line_end + 1`
        // is in bounds and never exceeds `close_start`.
        last_block = Some(text[line_end + 1..close_start].to_string());

        // Consume the entire closing run. A closing fence may be longer than the
        // opening one, and its surplus backticks must not be read as a new opening
        // fence on the next iteration.
        search_start = close_start + backtick_run_len(close_start);
    }

    last_block.unwrap_or_else(|| text.to_string())
}
242
#[cfg(test)]
mod tests {
    use super::*;

    /// With several fenced blocks in a response, only the final block's body is
    /// returned, without its fences or info string.
    #[test]
    fn test_extract_last_code_block() {
        let response = concat!(
            "Some thinking\n",
            "\n",
            "```\n",
            "first block\n",
            "```\n",
            "\n",
            "`````path='something' lines=1:2\n",
            "last block\n",
            "`````\n",
        );

        assert_eq!(extract_last_codeblock(response), "last block\n");
    }

    /// Only the text between the editable-region markers is returned, including the
    /// blank line that precedes the end marker.
    #[test]
    fn test_extract_editable_region() {
        let prompt = concat!(
            "some lines\n",
            "are\n",
            "here\n",
            "<|editable_region_start|>\n",
            "one\n",
            "two three\n",
            "\n",
            "<|editable_region_end|>\n",
            "more\n",
            "lines here\n",
        );

        let region = TeacherPrompt::extract_editable_region(prompt);

        assert_eq!(region, "one\ntwo three\n\n");
    }
}