1use crate::{
2 PromptFormat,
3 example::{Example, ExamplePrompt},
4 headless::EpAppState,
5 retrieve_context::run_context_retrieval,
6};
7use edit_prediction::{EditPredictionStore, zeta2::zeta2_prompt_input};
8use gpui::AsyncApp;
9use std::sync::Arc;
10use zeta_prompt::format_zeta_prompt;
11
12pub async fn run_format_prompt(
13 example: &mut Example,
14 prompt_format: PromptFormat,
15 app_state: Arc<EpAppState>,
16 mut cx: AsyncApp,
17) {
18 run_context_retrieval(example, app_state, cx.clone()).await;
19
20 let prompt = match prompt_format {
21 PromptFormat::Teacher => TeacherPrompt::format(example),
22 PromptFormat::Zeta2 => {
23 let ep_store = cx
24 .update(|cx| EditPredictionStore::try_global(cx).unwrap())
25 .unwrap();
26
27 let state = example.state.as_ref().unwrap();
28 let snapshot = state
29 .buffer
30 .read_with(&cx, |buffer, _| buffer.snapshot())
31 .unwrap();
32 let project = state.project.clone();
33 let (_, input) = ep_store
34 .update(&mut cx, |ep_store, _cx| {
35 zeta2_prompt_input(
36 &snapshot,
37 example.context.as_ref().unwrap().files.clone(),
38 ep_store.edit_history_for_project(&project),
39 example.cursor_path.clone(),
40 example.buffer.as_ref().unwrap().cursor_offset,
41 )
42 })
43 .unwrap();
44 format_zeta_prompt(&input)
45 }
46 };
47
48 example.prompt = Some(ExamplePrompt {
49 input: prompt,
50 expected_output: example.expected_patch.clone(), // TODO
51 format: prompt_format,
52 });
53}
54
/// Builds the prompt text for an example in one particular prompt format.
pub trait PromptFormatter {
    /// Returns the prompt string built from `example`.
    fn format(example: &Example) -> String;
}
58
/// Turns a raw model response for an example back into a patch.
pub trait PromptParser {
    /// Return unified diff patch of prediction given raw LLM response
    fn parse(example: &Example, response: &str) -> String;
}
63
/// Prompt format built on the `teacher.prompt.md` template; implements both
/// [`PromptFormatter`] and [`PromptParser`].
pub struct TeacherPrompt;
65
66impl PromptFormatter for TeacherPrompt {
67 fn format(example: &Example) -> String {
68 let edit_history = Self::format_edit_history(&example.edit_history);
69 let context = Self::format_context(example);
70 let editable_region = Self::format_editable_region(example);
71
72 let prompt = Self::PROMPT
73 .replace("{{context}}", &context)
74 .replace("{{edit_history}}", &edit_history)
75 .replace("{{editable_region}}", &editable_region);
76
77 prompt
78 }
79}
80
81impl TeacherPrompt {
82 const PROMPT: &str = include_str!("teacher.prompt.md");
83 pub(crate) const EDITABLE_REGION_START: &str = "<|editable_region_start|>\n";
84 pub(crate) const EDITABLE_REGION_END: &str = "<|editable_region_end|>";
85
86 /// Truncate edit history to this number of last lines
87 const MAX_HISTORY_LINES: usize = 128;
88
89 fn format_edit_history(edit_history: &str) -> String {
90 // Strip comments ("garbage lines") from edit history
91 let lines = edit_history
92 .lines()
93 .filter(|&s| Self::is_udiff_content_line(s))
94 .collect::<Vec<_>>();
95
96 let history_lines = if lines.len() > Self::MAX_HISTORY_LINES {
97 &lines[lines.len() - Self::MAX_HISTORY_LINES..]
98 } else {
99 &lines
100 };
101
102 if history_lines.is_empty() {
103 return "(No edit history)".to_string();
104 }
105
106 history_lines.join("\n")
107 }
108
109 fn format_context(example: &Example) -> String {
110 if example.context.is_none() {
111 panic!("Missing context retriever step");
112 }
113
114 let mut prompt = String::new();
115 zeta_prompt::write_related_files(&mut prompt, &example.context.as_ref().unwrap().files);
116
117 prompt
118 }
119
120 fn format_editable_region(example: &Example) -> String {
121 let mut result = String::new();
122
123 let path_str = example.cursor_path.to_string_lossy();
124 result.push_str(&format!("`````path=\"{path_str}\"\n"));
125 result.push_str(Self::EDITABLE_REGION_START);
126
127 // TODO: control number of lines around cursor
128 result.push_str(&example.cursor_position);
129 if !example.cursor_position.ends_with('\n') {
130 result.push('\n');
131 }
132
133 result.push_str(&format!("{}\n", Self::EDITABLE_REGION_END));
134 result.push_str("`````");
135
136 result
137 }
138
139 fn extract_editable_region(text: &str) -> String {
140 let start = text
141 .find(Self::EDITABLE_REGION_START)
142 .map_or(0, |pos| pos + Self::EDITABLE_REGION_START.len());
143 let end = text.find(Self::EDITABLE_REGION_END).unwrap_or(text.len());
144
145 let region = &text[start..end];
146
147 region.replace("<|user_cursor|>", "")
148 }
149
150 fn is_udiff_content_line(s: &str) -> bool {
151 s.starts_with("-")
152 || s.starts_with("+")
153 || s.starts_with(" ")
154 || s.starts_with("---")
155 || s.starts_with("+++")
156 || s.starts_with("@@")
157 }
158}
159
160impl PromptParser for TeacherPrompt {
161 fn parse(example: &Example, response: &str) -> String {
162 // Ideally, we should always be able to find cursor position in the retrieved context.
163 // In reality, sometimes we don't find it for these reasons:
164 // 1. `example.cursor_position` contains _more_ context than included in the retrieved context
165 // (can be fixed by getting cursor coordinates at the load_example stage)
166 // 2. Context retriever just didn't include cursor line.
167 //
168 // In that case, fallback to using `cursor_position` as excerpt.
169 let cursor_file = &example
170 .buffer
171 .as_ref()
172 .expect("`buffer` should be filled in in the context collection step")
173 .content;
174
175 // Extract updated (new) editable region from the model response
176 let new_editable_region = extract_last_codeblock(response);
177
178 // Reconstruct old editable region we sent to the model
179 let old_editable_region = Self::format_editable_region(example);
180 let old_editable_region = Self::extract_editable_region(&old_editable_region);
181 if !cursor_file.contains(&old_editable_region) {
182 panic!("Something's wrong: editable_region is not found in the cursor file")
183 }
184
185 // Apply editable region to a larger context and compute diff.
186 // This is needed to get a better context lines around the editable region
187 let edited_file = cursor_file.replace(&old_editable_region, &new_editable_region);
188 let diff = language::unified_diff(&cursor_file, &edited_file);
189
190 let diff = indoc::formatdoc! {"
191 --- a/{path}
192 +++ b/{path}
193 {diff}
194 ",
195 path = example.cursor_path.to_string_lossy(),
196 diff = diff,
197 };
198
199 diff
200 }
201}
202
/// Returns the contents of the last fenced code block in `text`.
///
/// A fence is a run of three or more backticks. Anything following the opening fence
/// on the same line (an info string such as `path=...`) is skipped, and the block ends
/// at the next run of at least as many backticks. A single newline directly before the
/// closing fence is not part of the returned contents.
///
/// If `text` contains no complete fenced block, the whole of `text` is returned. An
/// unterminated fence ends the search; blocks found before it still count.
fn extract_last_codeblock(text: &str) -> String {
    let mut last_block = None;
    let mut search_start = 0;

    while let Some(offset) = text[search_start..].find("```") {
        let fence_start = search_start + offset;
        let fence_len = text[fence_start..]
            .bytes()
            .take_while(|&byte| byte == b'`')
            .count();
        let fence = &text[fence_start..fence_start + fence_len];
        let after_fence = fence_start + fence_len;

        // The block body starts on the line after the opening fence.
        let body_start = match text[after_fence..].find('\n') {
            Some(newline) => after_fence + newline + 1,
            None => break,
        };

        let body_len = match text[body_start..].find(fence) {
            Some(len) => len,
            None => break,
        };

        let body = &text[body_start..body_start + body_len];
        // Drop only the newline that separates the body from the closing fence. Cutting
        // one byte unconditionally would panic on an empty block and would lose a
        // content character when the closing fence is not on its own line.
        let body = body.strip_suffix('\n').unwrap_or(body);

        last_block = Some(body.to_string());
        search_start = body_start + body_len + fence_len;
    }

    last_block.unwrap_or_else(|| text.to_string())
}
234
#[cfg(test)]
mod tests {
    use super::*;

    /// With several fenced blocks in a response, only the final one is returned, even
    /// when its fence is longer and carries an info string.
    #[test]
    fn test_extract_last_code_block() {
        let response = indoc::indoc! {"
            Some thinking

            ```
            first block
            ```

            `````path='something' lines=1:2
            last block
            `````
            "};

        assert_eq!(extract_last_codeblock(response), "last block");
    }

    /// Only the text between the editable-region markers is kept.
    #[test]
    fn test_extract_editable_region() {
        let prompt = indoc::indoc! {"
            some lines
            are
            here
            <|editable_region_start|>
            one
            two three

            <|editable_region_end|>
            more
            lines here
            "};
        let expected = indoc::indoc! {"
            one
            two three

            "};

        let region = TeacherPrompt::extract_editable_region(prompt);

        assert_eq!(region, expected);
    }
}