1use crate::{
2 PromptFormat,
3 example::{Example, ExamplePrompt},
4 headless::EpAppState,
5 load_project::run_load_project,
6 progress::{Progress, Step},
7 retrieve_context::run_context_retrieval,
8};
9use edit_prediction::{
10 EditPredictionStore,
11 zeta2::{zeta2_output_for_patch, zeta2_prompt_input},
12};
13use gpui::AsyncApp;
14use std::sync::Arc;
15use zeta_prompt::format_zeta_prompt;
16
17pub async fn run_format_prompt(
18 example: &mut Example,
19 prompt_format: PromptFormat,
20 app_state: Arc<EpAppState>,
21 mut cx: AsyncApp,
22) {
23 run_context_retrieval(example, app_state.clone(), cx.clone()).await;
24
25 let _step_progress = Progress::global().start(Step::FormatPrompt, &example.name);
26
27 match prompt_format {
28 PromptFormat::Teacher => {
29 let prompt = TeacherPrompt::format_prompt(example);
30 example.prompt = Some(ExamplePrompt {
31 input: prompt,
32 expected_output: example.expected_patch.clone(), // TODO
33 format: prompt_format,
34 });
35 }
36 PromptFormat::Zeta2 => {
37 run_load_project(example, app_state, cx.clone()).await;
38
39 let ep_store = cx
40 .update(|cx| EditPredictionStore::try_global(cx).unwrap())
41 .unwrap();
42
43 let state = example.state.as_ref().unwrap();
44 let snapshot = state
45 .buffer
46 .read_with(&cx, |buffer, _| buffer.snapshot())
47 .unwrap();
48 let project = state.project.clone();
49 let (_, input) = ep_store
50 .update(&mut cx, |ep_store, _cx| {
51 zeta2_prompt_input(
52 &snapshot,
53 example.context.as_ref().unwrap().files.clone(),
54 ep_store.edit_history_for_project(&project),
55 example.cursor_path.clone(),
56 example.buffer.as_ref().unwrap().cursor_offset,
57 )
58 })
59 .unwrap();
60 let prompt = format_zeta_prompt(&input);
61 let expected_output = zeta2_output_for_patch(&input, &example.expected_patch.clone());
62 example.prompt = Some(ExamplePrompt {
63 input: prompt,
64 expected_output,
65 format: prompt_format,
66 });
67 }
68 };
69}
70
71pub struct TeacherPrompt;
72
73impl TeacherPrompt {
74 const PROMPT: &str = include_str!("teacher.prompt.md");
75 pub(crate) const EDITABLE_REGION_START: &str = "<|editable_region_start|>\n";
76 pub(crate) const EDITABLE_REGION_END: &str = "<|editable_region_end|>";
77
78 /// Truncate edit history to this number of last lines
79 const MAX_HISTORY_LINES: usize = 128;
80
81 pub fn format_prompt(example: &Example) -> String {
82 let edit_history = Self::format_edit_history(&example.edit_history);
83 let context = Self::format_context(example);
84 let editable_region = Self::format_editable_region(example);
85
86 let prompt = Self::PROMPT
87 .replace("{{context}}", &context)
88 .replace("{{edit_history}}", &edit_history)
89 .replace("{{editable_region}}", &editable_region);
90
91 prompt
92 }
93
94 pub fn parse(example: &Example, response: &str) -> String {
95 // Ideally, we should always be able to find cursor position in the retrieved context.
96 // In reality, sometimes we don't find it for these reasons:
97 // 1. `example.cursor_position` contains _more_ context than included in the retrieved context
98 // (can be fixed by getting cursor coordinates at the load_example stage)
99 // 2. Context retriever just didn't include cursor line.
100 //
101 // In that case, fallback to using `cursor_position` as excerpt.
102 let cursor_file = &example
103 .buffer
104 .as_ref()
105 .expect("`buffer` should be filled in in the context collection step")
106 .content;
107
108 // Extract updated (new) editable region from the model response
109 let new_editable_region = extract_last_codeblock(response);
110
111 // Reconstruct old editable region we sent to the model
112 let old_editable_region = Self::format_editable_region(example);
113 let old_editable_region = Self::extract_editable_region(&old_editable_region);
114 if !cursor_file.contains(&old_editable_region) {
115 panic!("Something's wrong: editable_region is not found in the cursor file")
116 }
117
118 // Apply editable region to a larger context and compute diff.
119 // This is needed to get a better context lines around the editable region
120 let edited_file = cursor_file.replace(&old_editable_region, &new_editable_region);
121 let diff = language::unified_diff(&cursor_file, &edited_file);
122
123 let diff = indoc::formatdoc! {"
124 --- a/{path}
125 +++ b/{path}
126 {diff}",
127 path = example.cursor_path.to_string_lossy(),
128 diff = diff,
129 };
130
131 diff
132 }
133
134 fn format_edit_history(edit_history: &str) -> String {
135 // Strip comments ("garbage lines") from edit history
136 let lines = edit_history
137 .lines()
138 .filter(|&s| Self::is_udiff_content_line(s))
139 .collect::<Vec<_>>();
140
141 let history_lines = if lines.len() > Self::MAX_HISTORY_LINES {
142 &lines[lines.len() - Self::MAX_HISTORY_LINES..]
143 } else {
144 &lines
145 };
146
147 if history_lines.is_empty() {
148 return "(No edit history)".to_string();
149 }
150
151 history_lines.join("\n")
152 }
153
154 fn format_context(example: &Example) -> String {
155 if example.context.is_none() {
156 panic!("Missing context retriever step");
157 }
158
159 let mut prompt = String::new();
160 zeta_prompt::write_related_files(&mut prompt, &example.context.as_ref().unwrap().files);
161
162 prompt
163 }
164
165 fn format_editable_region(example: &Example) -> String {
166 let mut result = String::new();
167
168 let path_str = example.cursor_path.to_string_lossy();
169 result.push_str(&format!("`````path=\"{path_str}\"\n"));
170 result.push_str(Self::EDITABLE_REGION_START);
171
172 // TODO: control number of lines around cursor
173 result.push_str(&example.cursor_position);
174 if !example.cursor_position.ends_with('\n') {
175 result.push('\n');
176 }
177
178 result.push_str(&format!("{}\n", Self::EDITABLE_REGION_END));
179 result.push_str("`````");
180
181 result
182 }
183
184 fn extract_editable_region(text: &str) -> String {
185 let start = text
186 .find(Self::EDITABLE_REGION_START)
187 .map_or(0, |pos| pos + Self::EDITABLE_REGION_START.len());
188 let end = text.find(Self::EDITABLE_REGION_END).unwrap_or(text.len());
189
190 let region = &text[start..end];
191
192 region.replace("<|user_cursor|>", "")
193 }
194
195 fn is_udiff_content_line(s: &str) -> bool {
196 s.starts_with("-")
197 || s.starts_with("+")
198 || s.starts_with(" ")
199 || s.starts_with("---")
200 || s.starts_with("+++")
201 || s.starts_with("@@")
202 }
203}
204
/// Returns the body of the last complete fenced code block in `text`.
///
/// An opening fence is a run of three or more backticks. The rest of that line (a
/// language tag or an info string such as `path="..."`) is skipped. The block ends at
/// the next occurrence of as many backticks as the opening run had. The returned body
/// runs from just after the opening line up to, but not including, that closing run.
/// Scanning stops at the first fence that is never closed.
///
/// Falls back to the whole `text` when no complete block is found.
fn extract_last_codeblock(text: &str) -> String {
    let bytes = text.as_bytes();
    let mut last_body: Option<&str> = None;
    let mut offset = 0;

    while let Some(found) = text[offset..].find("```") {
        let fence_start = offset + found;

        // The closing fence is the next occurrence of this many backticks.
        let fence_len = bytes[fence_start..]
            .iter()
            .take_while(|&&byte| byte == b'`')
            .count();
        let fence = "`".repeat(fence_len);

        // Skip the info string: jump to the newline ending the opening line, or to
        // the end of `text` if that line is never terminated.
        let after_fence = fence_start + fence_len;
        let line_end = bytes[after_fence..]
            .iter()
            .position(|&byte| byte == b'\n')
            .map_or(bytes.len(), |pos| after_fence + pos);

        match text[line_end..].find(fence.as_str()) {
            Some(close) => {
                last_body = Some(&text[line_end + 1..line_end + close]);
                offset = line_end + close + fence_len;
            }
            None => break,
        }
    }

    last_body.map_or_else(|| text.to_string(), str::to_string)
}
236
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_extract_last_code_block() {
        // Two fenced blocks: a plain three-backtick one and a five-backtick one with an
        // info string. Only the body of the last block is expected back.
        let response = concat!(
            "Some thinking\n",
            "\n",
            "```\n",
            "first block\n",
            "```\n",
            "\n",
            "`````path='something' lines=1:2\n",
            "last block\n",
            "`````\n",
        );

        assert_eq!(extract_last_codeblock(response), "last block\n");
    }

    #[test]
    fn test_extract_editable_region() {
        // Lines outside the markers are dropped; blank lines inside the region are kept.
        let text = concat!(
            "some lines\n",
            "are\n",
            "here\n",
            "<|editable_region_start|>\n",
            "one\n",
            "two three\n",
            "\n",
            "<|editable_region_end|>\n",
            "more\n",
            "lines here\n",
        );

        assert_eq!(
            TeacherPrompt::extract_editable_region(text),
            "one\ntwo three\n\n"
        );
    }
}