1use crate::{
2 PromptFormat,
3 example::{Example, ExamplePrompt},
4 headless::EpAppState,
5 load_project::run_load_project,
6 retrieve_context::run_context_retrieval,
7};
8use edit_prediction::{
9 EditPredictionStore,
10 zeta2::{zeta2_output_for_patch, zeta2_prompt_input},
11};
12use gpui::AsyncApp;
13use std::sync::Arc;
14use zeta_prompt::format_zeta_prompt;
15
16pub async fn run_format_prompt(
17 example: &mut Example,
18 prompt_format: PromptFormat,
19 app_state: Arc<EpAppState>,
20 mut cx: AsyncApp,
21) {
22 run_context_retrieval(example, app_state.clone(), cx.clone()).await;
23
24 match prompt_format {
25 PromptFormat::Teacher => {
26 let prompt = TeacherPrompt::format_prompt(example);
27 example.prompt = Some(ExamplePrompt {
28 input: prompt,
29 expected_output: example.expected_patch.clone(), // TODO
30 format: prompt_format,
31 });
32 }
33 PromptFormat::Zeta2 => {
34 run_load_project(example, app_state, cx.clone()).await;
35
36 let ep_store = cx
37 .update(|cx| EditPredictionStore::try_global(cx).unwrap())
38 .unwrap();
39
40 let state = example.state.as_ref().unwrap();
41 let snapshot = state
42 .buffer
43 .read_with(&cx, |buffer, _| buffer.snapshot())
44 .unwrap();
45 let project = state.project.clone();
46 let (_, input) = ep_store
47 .update(&mut cx, |ep_store, _cx| {
48 zeta2_prompt_input(
49 &snapshot,
50 example.context.as_ref().unwrap().files.clone(),
51 ep_store.edit_history_for_project(&project),
52 example.cursor_path.clone(),
53 example.buffer.as_ref().unwrap().cursor_offset,
54 )
55 })
56 .unwrap();
57 let prompt = format_zeta_prompt(&input);
58 let expected_output = zeta2_output_for_patch(&input, &example.expected_patch.clone());
59 example.prompt = Some(ExamplePrompt {
60 input: prompt,
61 expected_output,
62 format: prompt_format,
63 });
64 }
65 };
66}
67
68pub struct TeacherPrompt;
69
70impl TeacherPrompt {
71 const PROMPT: &str = include_str!("teacher.prompt.md");
72 pub(crate) const EDITABLE_REGION_START: &str = "<|editable_region_start|>\n";
73 pub(crate) const EDITABLE_REGION_END: &str = "<|editable_region_end|>";
74
75 /// Truncate edit history to this number of last lines
76 const MAX_HISTORY_LINES: usize = 128;
77
78 pub fn format_prompt(example: &Example) -> String {
79 let edit_history = Self::format_edit_history(&example.edit_history);
80 let context = Self::format_context(example);
81 let editable_region = Self::format_editable_region(example);
82
83 let prompt = Self::PROMPT
84 .replace("{{context}}", &context)
85 .replace("{{edit_history}}", &edit_history)
86 .replace("{{editable_region}}", &editable_region);
87
88 prompt
89 }
90
91 pub fn parse(example: &Example, response: &str) -> String {
92 // Ideally, we should always be able to find cursor position in the retrieved context.
93 // In reality, sometimes we don't find it for these reasons:
94 // 1. `example.cursor_position` contains _more_ context than included in the retrieved context
95 // (can be fixed by getting cursor coordinates at the load_example stage)
96 // 2. Context retriever just didn't include cursor line.
97 //
98 // In that case, fallback to using `cursor_position` as excerpt.
99 let cursor_file = &example
100 .buffer
101 .as_ref()
102 .expect("`buffer` should be filled in in the context collection step")
103 .content;
104
105 // Extract updated (new) editable region from the model response
106 let new_editable_region = extract_last_codeblock(response);
107
108 // Reconstruct old editable region we sent to the model
109 let old_editable_region = Self::format_editable_region(example);
110 let old_editable_region = Self::extract_editable_region(&old_editable_region);
111 if !cursor_file.contains(&old_editable_region) {
112 panic!("Something's wrong: editable_region is not found in the cursor file")
113 }
114
115 // Apply editable region to a larger context and compute diff.
116 // This is needed to get a better context lines around the editable region
117 let edited_file = cursor_file.replace(&old_editable_region, &new_editable_region);
118 let diff = language::unified_diff(&cursor_file, &edited_file);
119
120 let diff = indoc::formatdoc! {"
121 --- a/{path}
122 +++ b/{path}
123 {diff}",
124 path = example.cursor_path.to_string_lossy(),
125 diff = diff,
126 };
127
128 diff
129 }
130
131 fn format_edit_history(edit_history: &str) -> String {
132 // Strip comments ("garbage lines") from edit history
133 let lines = edit_history
134 .lines()
135 .filter(|&s| Self::is_udiff_content_line(s))
136 .collect::<Vec<_>>();
137
138 let history_lines = if lines.len() > Self::MAX_HISTORY_LINES {
139 &lines[lines.len() - Self::MAX_HISTORY_LINES..]
140 } else {
141 &lines
142 };
143
144 if history_lines.is_empty() {
145 return "(No edit history)".to_string();
146 }
147
148 history_lines.join("\n")
149 }
150
151 fn format_context(example: &Example) -> String {
152 if example.context.is_none() {
153 panic!("Missing context retriever step");
154 }
155
156 let mut prompt = String::new();
157 zeta_prompt::write_related_files(&mut prompt, &example.context.as_ref().unwrap().files);
158
159 prompt
160 }
161
162 fn format_editable_region(example: &Example) -> String {
163 let mut result = String::new();
164
165 let path_str = example.cursor_path.to_string_lossy();
166 result.push_str(&format!("`````path=\"{path_str}\"\n"));
167 result.push_str(Self::EDITABLE_REGION_START);
168
169 // TODO: control number of lines around cursor
170 result.push_str(&example.cursor_position);
171 if !example.cursor_position.ends_with('\n') {
172 result.push('\n');
173 }
174
175 result.push_str(&format!("{}\n", Self::EDITABLE_REGION_END));
176 result.push_str("`````");
177
178 result
179 }
180
181 fn extract_editable_region(text: &str) -> String {
182 let start = text
183 .find(Self::EDITABLE_REGION_START)
184 .map_or(0, |pos| pos + Self::EDITABLE_REGION_START.len());
185 let end = text.find(Self::EDITABLE_REGION_END).unwrap_or(text.len());
186
187 let region = &text[start..end];
188
189 region.replace("<|user_cursor|>", "")
190 }
191
192 fn is_udiff_content_line(s: &str) -> bool {
193 s.starts_with("-")
194 || s.starts_with("+")
195 || s.starts_with(" ")
196 || s.starts_with("---")
197 || s.starts_with("+++")
198 || s.starts_with("@@")
199 }
200}
201
/// Returns the contents of the last closed fenced code block in `text`.
///
/// An opening fence is a run of three or more backticks. Everything up to the end of that line
/// (e.g. a language tag or `path=...` info string) is skipped. The block ends at the next run of
/// the same number of backticks. If no closed block exists, the whole `text` is returned.
fn extract_last_codeblock(text: &str) -> String {
    let bytes = text.as_bytes();
    let mut cursor = 0;
    let mut found: Option<&str> = None;

    loop {
        let Some(relative_open) = text[cursor..].find("```") else {
            break;
        };
        let fence_start = cursor + relative_open;
        let fence_len = bytes[fence_start..]
            .iter()
            .take_while(|&&byte| byte == b'`')
            .count();
        let fence = "`".repeat(fence_len);

        // Skip the rest of the opening line (the info string).
        let after_fence = fence_start + fence_len;
        let header_end = text[after_fence..]
            .find('\n')
            .map_or(text.len(), |offset| after_fence + offset);

        let Some(relative_close) = text[header_end..].find(fence.as_str()) else {
            // Unclosed fence: keep whatever block was found before it.
            break;
        };
        found = Some(&text[header_end + 1..header_end + relative_close]);
        cursor = header_end + relative_close + fence_len;
    }

    found.map_or_else(|| text.to_owned(), str::to_owned)
}
233
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_extract_last_code_block() {
        let text = indoc::indoc! {"
            Some thinking

            ```
            first block
            ```

            `````path='something' lines=1:2
            last block
            `````
            "};
        let last_block = extract_last_codeblock(text);
        assert_eq!(last_block, "last block\n");
    }

    #[test]
    fn test_extract_last_code_block_skips_info_string() {
        let text = "Sure:\n```rust\nfn main() {}\n```\n";
        assert_eq!(extract_last_codeblock(text), "fn main() {}\n");
    }

    #[test]
    fn test_extract_last_code_block_without_blocks_returns_input() {
        let text = "no code blocks here";
        assert_eq!(extract_last_codeblock(text), text);
    }

    #[test]
    fn test_extract_last_code_block_ignores_unclosed_block() {
        let text = "```\nfirst\n```\n```\nunclosed";
        assert_eq!(extract_last_codeblock(text), "first\n");
    }

    #[test]
    fn test_extract_editable_region() {
        let text = indoc::indoc! {"
            some lines
            are
            here
            <|editable_region_start|>
            one
            two three

            <|editable_region_end|>
            more
            lines here
            "};
        let parsed = TeacherPrompt::extract_editable_region(text);
        assert_eq!(
            parsed,
            indoc::indoc! {"
            one
            two three

            "}
        );
    }

    #[test]
    fn test_extract_editable_region_without_markers() {
        assert_eq!(
            TeacherPrompt::extract_editable_region("let x<|user_cursor|> = 1;\n"),
            "let x = 1;\n"
        );
    }

    #[test]
    fn test_format_edit_history_keeps_only_diff_lines() {
        let history = "diff --git a/x b/x\n--- a/x\n+++ b/x\n@@ -1 +1 @@\n-old\n+new\n";
        assert_eq!(
            TeacherPrompt::format_edit_history(history),
            "--- a/x\n+++ b/x\n@@ -1 +1 @@\n-old\n+new"
        );
        assert_eq!(
            TeacherPrompt::format_edit_history("index 123\n"),
            "(No edit history)"
        );
    }

    #[test]
    fn test_format_edit_history_truncates_to_last_lines() {
        let history: String = (0..200).map(|i| format!("+{i}\n")).collect();
        let formatted = TeacherPrompt::format_edit_history(&history);
        assert_eq!(formatted.lines().count(), TeacherPrompt::MAX_HISTORY_LINES);
        assert_eq!(formatted.lines().next(), Some("+72"));
        assert_eq!(formatted.lines().last(), Some("+199"));
    }
}