1use crate::{
2 PromptFormat,
3 example::{Example, ExamplePrompt},
4 headless::EpAppState,
5 load_project::run_load_project,
6 progress::{Progress, Step},
7 retrieve_context::run_context_retrieval,
8};
9use edit_prediction::{
10 EditPredictionStore,
11 zeta2::{zeta2_output_for_patch, zeta2_prompt_input},
12};
13use gpui::AsyncApp;
14use std::sync::Arc;
15use zeta_prompt::format_zeta_prompt;
16
17pub async fn run_format_prompt(
18 example: &mut Example,
19 prompt_format: PromptFormat,
20 app_state: Arc<EpAppState>,
21 progress: Arc<Progress>,
22 mut cx: AsyncApp,
23) {
24 run_context_retrieval(example, app_state.clone(), progress.clone(), cx.clone()).await;
25
26 let _step_progress = progress.start(Step::FormatPrompt, &example.name);
27
28 match prompt_format {
29 PromptFormat::Teacher => {
30 let prompt = TeacherPrompt::format_prompt(example);
31 example.prompt = Some(ExamplePrompt {
32 input: prompt,
33 expected_output: example.expected_patch.clone(), // TODO
34 format: prompt_format,
35 });
36 }
37 PromptFormat::Zeta2 => {
38 run_load_project(example, app_state, progress.clone(), cx.clone()).await;
39
40 let ep_store = cx
41 .update(|cx| EditPredictionStore::try_global(cx).unwrap())
42 .unwrap();
43
44 let state = example.state.as_ref().unwrap();
45 let snapshot = state
46 .buffer
47 .read_with(&cx, |buffer, _| buffer.snapshot())
48 .unwrap();
49 let project = state.project.clone();
50 let (_, input) = ep_store
51 .update(&mut cx, |ep_store, _cx| {
52 zeta2_prompt_input(
53 &snapshot,
54 example.context.as_ref().unwrap().files.clone(),
55 ep_store.edit_history_for_project(&project),
56 example.cursor_path.clone(),
57 example.buffer.as_ref().unwrap().cursor_offset,
58 )
59 })
60 .unwrap();
61 let prompt = format_zeta_prompt(&input);
62 let expected_output = zeta2_output_for_patch(&input, &example.expected_patch.clone());
63 example.prompt = Some(ExamplePrompt {
64 input: prompt,
65 expected_output,
66 format: prompt_format,
67 });
68 }
69 };
70}
71
72pub struct TeacherPrompt;
73
74impl TeacherPrompt {
75 const PROMPT: &str = include_str!("teacher.prompt.md");
76 pub(crate) const EDITABLE_REGION_START: &str = "<|editable_region_start|>\n";
77 pub(crate) const EDITABLE_REGION_END: &str = "<|editable_region_end|>";
78
79 /// Truncate edit history to this number of last lines
80 const MAX_HISTORY_LINES: usize = 128;
81
82 pub fn format_prompt(example: &Example) -> String {
83 let edit_history = Self::format_edit_history(&example.edit_history);
84 let context = Self::format_context(example);
85 let editable_region = Self::format_editable_region(example);
86
87 let prompt = Self::PROMPT
88 .replace("{{context}}", &context)
89 .replace("{{edit_history}}", &edit_history)
90 .replace("{{editable_region}}", &editable_region);
91
92 prompt
93 }
94
95 pub fn parse(example: &Example, response: &str) -> String {
96 // Ideally, we should always be able to find cursor position in the retrieved context.
97 // In reality, sometimes we don't find it for these reasons:
98 // 1. `example.cursor_position` contains _more_ context than included in the retrieved context
99 // (can be fixed by getting cursor coordinates at the load_example stage)
100 // 2. Context retriever just didn't include cursor line.
101 //
102 // In that case, fallback to using `cursor_position` as excerpt.
103 let cursor_file = &example
104 .buffer
105 .as_ref()
106 .expect("`buffer` should be filled in in the context collection step")
107 .content;
108
109 // Extract updated (new) editable region from the model response
110 let new_editable_region = extract_last_codeblock(response);
111
112 // Reconstruct old editable region we sent to the model
113 let old_editable_region = Self::format_editable_region(example);
114 let old_editable_region = Self::extract_editable_region(&old_editable_region);
115 if !cursor_file.contains(&old_editable_region) {
116 panic!("Something's wrong: editable_region is not found in the cursor file")
117 }
118
119 // Apply editable region to a larger context and compute diff.
120 // This is needed to get a better context lines around the editable region
121 let edited_file = cursor_file.replace(&old_editable_region, &new_editable_region);
122 let diff = language::unified_diff(&cursor_file, &edited_file);
123
124 let diff = indoc::formatdoc! {"
125 --- a/{path}
126 +++ b/{path}
127 {diff}",
128 path = example.cursor_path.to_string_lossy(),
129 diff = diff,
130 };
131
132 diff
133 }
134
135 fn format_edit_history(edit_history: &str) -> String {
136 // Strip comments ("garbage lines") from edit history
137 let lines = edit_history
138 .lines()
139 .filter(|&s| Self::is_udiff_content_line(s))
140 .collect::<Vec<_>>();
141
142 let history_lines = if lines.len() > Self::MAX_HISTORY_LINES {
143 &lines[lines.len() - Self::MAX_HISTORY_LINES..]
144 } else {
145 &lines
146 };
147
148 if history_lines.is_empty() {
149 return "(No edit history)".to_string();
150 }
151
152 history_lines.join("\n")
153 }
154
155 fn format_context(example: &Example) -> String {
156 if example.context.is_none() {
157 panic!("Missing context retriever step");
158 }
159
160 let mut prompt = String::new();
161 zeta_prompt::write_related_files(&mut prompt, &example.context.as_ref().unwrap().files);
162
163 prompt
164 }
165
166 fn format_editable_region(example: &Example) -> String {
167 let mut result = String::new();
168
169 let path_str = example.cursor_path.to_string_lossy();
170 result.push_str(&format!("`````path=\"{path_str}\"\n"));
171 result.push_str(Self::EDITABLE_REGION_START);
172
173 // TODO: control number of lines around cursor
174 result.push_str(&example.cursor_position);
175 if !example.cursor_position.ends_with('\n') {
176 result.push('\n');
177 }
178
179 result.push_str(&format!("{}\n", Self::EDITABLE_REGION_END));
180 result.push_str("`````");
181
182 result
183 }
184
185 fn extract_editable_region(text: &str) -> String {
186 let start = text
187 .find(Self::EDITABLE_REGION_START)
188 .map_or(0, |pos| pos + Self::EDITABLE_REGION_START.len());
189 let end = text.find(Self::EDITABLE_REGION_END).unwrap_or(text.len());
190
191 let region = &text[start..end];
192
193 region.replace("<|user_cursor|>", "")
194 }
195
196 fn is_udiff_content_line(s: &str) -> bool {
197 s.starts_with("-")
198 || s.starts_with("+")
199 || s.starts_with(" ")
200 || s.starts_with("---")
201 || s.starts_with("+++")
202 || s.starts_with("@@")
203 }
204}
205
/// Returns the contents of the last complete fenced code block in `text`.
///
/// A fence is a run of three or more backticks; the rest of the opening line (the info
/// string) is skipped. The block ends at the next run of at least as many backticks as the
/// opening fence. If `text` contains no complete block, the whole of `text` is returned.
fn extract_last_codeblock(text: &str) -> String {
    let mut latest: Option<&str> = None;
    let mut cursor = 0;

    while let Some(offset) = text[cursor..].find("```") {
        let fence_start = cursor + offset;
        let from_fence = &text[fence_start..];

        // The opening fence is the whole run of backticks, which may be longer than three.
        let fence_len = from_fence.bytes().take_while(|&b| b == b'`').count();
        let fence = &from_fence[..fence_len];

        // Skip the info string: the block body starts after the end of the opening line.
        let info = &from_fence[fence_len..];
        let line_end = fence_start + fence_len + info.find('\n').unwrap_or(info.len());

        match text[line_end..].find(fence) {
            Some(close_offset) => {
                latest = Some(&text[line_end + 1..line_end + close_offset]);
                cursor = line_end + close_offset + fence_len;
            }
            // An unterminated fence ends the scan; earlier blocks still count.
            None => break,
        }
    }

    latest.map_or_else(|| text.to_string(), str::to_string)
}
237
#[cfg(test)]
mod tests {
    use super::*;

    /// When a response holds several fenced blocks, the final one is returned, and the info
    /// string on its opening fence is not part of the result.
    #[test]
    fn test_extract_last_code_block() {
        let response = indoc::indoc! {"
            Some thinking

            ```
            first block
            ```

            `````path='something' lines=1:2
            last block
            `````
            "};

        assert_eq!(extract_last_codeblock(response), "last block\n");
    }

    /// Only the text between the editable-region markers is kept, including the blank line
    /// that precedes the end marker.
    #[test]
    fn test_extract_editable_region() {
        let input = indoc::indoc! {"
            some lines
            are
            here
            <|editable_region_start|>
            one
            two three

            <|editable_region_end|>
            more
            lines here
            "};
        let expected = indoc::indoc! {"
            one
            two three

            "};

        assert_eq!(TeacherPrompt::extract_editable_region(input), expected);
    }
}