1//! Zeta2 prompt planning and generation code shared with cloud.
2
3use anyhow::{Context as _, Result, anyhow};
4use cloud_llm_client::predict_edits_v3::{self, Event, PromptFormat, ReferencedDeclaration};
5use indoc::indoc;
6use ordered_float::OrderedFloat;
7use rustc_hash::{FxHashMap, FxHashSet};
8use std::fmt::Write;
9use std::sync::Arc;
10use std::{cmp::Reverse, collections::BinaryHeap, ops::Range, path::Path};
11use strum::{EnumIter, IntoEnumIterator};
12
/// Byte budget used by `PlannedPrompt::populate` when the request does not specify
/// `prompt_max_bytes`.
pub const DEFAULT_MAX_PROMPT_BYTES: usize = 10 * 1024;

/// Marker inserted into the rendered excerpt at the cursor offset.
pub const CURSOR_MARKER: &str = "<|cursor_position|>";
/// Marker inserted at the start of the excerpt in the `MarkedExcerpt` prompt format.
///
/// NOTE: Differs from zed version of constant - includes a newline
pub const EDITABLE_REGION_START_MARKER_WITH_NEWLINE: &str = "<|editable_region_start|>\n";
/// Marker inserted at the end of the excerpt in the `MarkedExcerpt` prompt format.
///
/// NOTE: Differs from zed version of constant - includes a newline
pub const EDITABLE_REGION_END_MARKER_WITH_NEWLINE: &str = "<|editable_region_end|>\n";
20
// TODO: use constants for markers?
/// System prompt returned by `system_prompt` for `PromptFormat::MarkedExcerpt`. The marker names
/// in the text are written out literally and must be kept in sync with the marker constants.
const MARKED_EXCERPT_SYSTEM_PROMPT: &str = indoc! {"
    You are a code completion assistant and your task is to analyze user edits and then rewrite an excerpt that the user provides, suggesting the appropriate edits within the excerpt, taking into account the cursor location.

    The excerpt to edit will be wrapped in markers <|editable_region_start|> and <|editable_region_end|>. The cursor position is marked with <|cursor_position|>. Please respond with edited code for that region.

    Other code is provided for context, and `…` indicates when code has been skipped.
"};
29
/// System prompt returned by `system_prompt` for `PromptFormat::LabeledSections`. It describes
/// the `<|section_N|>` / `<|current_section|>` labels that `push_file_snippets` writes.
const LABELED_SECTIONS_SYSTEM_PROMPT: &str = indoc! {r#"
    You are a code completion assistant and your task is to analyze user edits, and suggest an edit to one of the provided sections of code.

    Sections of code are grouped by file and then labeled by `<|section_N|>` (e.g `<|section_8|>`).

    The cursor position is marked with `<|cursor_position|>` and it will appear within a special section labeled `<|current_section|>`. Prefer editing the current section until no more changes are needed within it.

    Respond ONLY with the name of the section to edit on a single line, followed by all of the code that should replace that section. For example:

    <|current_section|>
    for i in 0..16 {
        println!("{i}");
    }
"#};
44
/// The set of snippets selected for a single prediction request's prompt.
///
/// Built by `PlannedPrompt::populate` and rendered by `PlannedPrompt::to_prompt_string`.
pub struct PlannedPrompt<'a> {
    /// The request this plan was built for; snippet text is borrowed from it.
    request: &'a predict_edits_v3::PredictEditsRequest,
    /// Snippets to include in the prompt. These may overlap - they are merged / deduplicated in
    /// `to_prompt_string`.
    snippets: Vec<PlannedSnippet<'a>>,
    /// Running count of bytes charged against the prompt budget. Starts at the excerpt's length
    /// and grows as parent signatures and declarations are planned.
    budget_used: usize,
}
52
53pub fn system_prompt(format: PromptFormat) -> &'static str {
54 match format {
55 PromptFormat::MarkedExcerpt => MARKED_EXCERPT_SYSTEM_PROMPT,
56 PromptFormat::LabeledSections => LABELED_SECTIONS_SYSTEM_PROMPT,
57 // only intended for use via zeta_cli
58 PromptFormat::OnlySnippets => "",
59 }
60}
61
/// A span of text from one file that has been selected for inclusion in the prompt.
#[derive(Clone, Debug)]
pub struct PlannedSnippet<'a> {
    /// File the snippet comes from; snippets are grouped by this path when rendered.
    path: Arc<Path>,
    /// Offsets of the snippet within its file (presumably byte offsets - they are used to slice
    /// `text` relative to `range.start`). Used to order, merge and label snippets.
    range: Range<usize>,
    /// The snippet's text, borrowed from the request.
    text: &'a str,
    /// Copied from the request's `text_is_truncated` flag for this text; not yet reflected in the
    /// rendered prompt.
    // TODO: Indicate this in the output
    #[allow(dead_code)]
    text_is_truncated: bool,
}
71
/// How much of a referenced declaration to include in the prompt.
#[derive(EnumIter, Clone, Copy, PartialEq, Eq, Hash, Debug, PartialOrd, Ord)]
pub enum SnippetStyle {
    /// Only the `signature_range` portion of the declaration's text.
    Signature,
    /// The declaration's full text.
    Declaration,
}
77
/// Describes the sections written by `PlannedPrompt::to_prompt_string`.
#[derive(Clone, Debug)]
pub struct SectionLabels {
    /// Index into `section_ranges` of the section in which the excerpt's markers were inserted.
    /// Always `0` for `PromptFormat::OnlySnippets`, which inserts no markers.
    pub excerpt_index: usize,
    /// File path and range of each merged snippet, in the order they were visited while
    /// rendering.
    pub section_ranges: Vec<(Arc<Path>, Range<usize>)>,
}
83
84impl<'a> PlannedPrompt<'a> {
85 /// Greedy one-pass knapsack algorithm to populate the prompt plan. Does the following:
86 ///
87 /// Initializes a priority queue by populating it with each snippet, finding the SnippetStyle
88 /// that minimizes `score_density = score / snippet.range(style).len()`. When a "signature"
89 /// snippet is popped, insert an entry for the "declaration" variant that reflects the cost of
90 /// upgrade.
91 ///
92 /// TODO: Implement an early halting condition. One option might be to have another priority
93 /// queue where the score is the size, and update it accordingly. Another option might be to
94 /// have some simpler heuristic like bailing after N failed insertions, or based on how much
95 /// budget is left.
96 ///
97 /// TODO: Has the current known sources of imprecision:
98 ///
99 /// * Does not consider snippet overlap when ranking. For example, it might add a field to the
100 /// plan even though the containing struct is already included.
101 ///
102 /// * Does not consider cost of signatures when ranking snippets - this is tricky since
103 /// signatures may be shared by multiple snippets.
104 ///
105 /// * Does not include file paths / other text when considering max_bytes.
106 pub fn populate(request: &'a predict_edits_v3::PredictEditsRequest) -> Result<Self> {
107 let mut this = PlannedPrompt {
108 request,
109 snippets: Vec::new(),
110 budget_used: request.excerpt.len(),
111 };
112 let mut included_parents = FxHashSet::default();
113 let additional_parents = this.additional_parent_signatures(
114 &request.excerpt_path,
115 request.excerpt_parent,
116 &included_parents,
117 )?;
118 this.add_parents(&mut included_parents, additional_parents);
119
120 let max_bytes = request.prompt_max_bytes.unwrap_or(DEFAULT_MAX_PROMPT_BYTES);
121
122 if this.budget_used > max_bytes {
123 return Err(anyhow!(
124 "Excerpt + signatures size of {} already exceeds budget of {}",
125 this.budget_used,
126 max_bytes
127 ));
128 }
129
130 #[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord)]
131 struct QueueEntry {
132 score_density: OrderedFloat<f32>,
133 declaration_index: usize,
134 style: SnippetStyle,
135 }
136
137 // Initialize priority queue with the best score for each snippet.
138 let mut queue: BinaryHeap<QueueEntry> = BinaryHeap::new();
139 for (declaration_index, declaration) in request.referenced_declarations.iter().enumerate() {
140 let (style, score_density) = SnippetStyle::iter()
141 .map(|style| {
142 (
143 style,
144 OrderedFloat(declaration_score_density(&declaration, style)),
145 )
146 })
147 .max_by_key(|(_, score_density)| *score_density)
148 .unwrap();
149 queue.push(QueueEntry {
150 score_density,
151 declaration_index,
152 style,
153 });
154 }
155
156 // Knapsack selection loop
157 while let Some(queue_entry) = queue.pop() {
158 let Some(declaration) = request
159 .referenced_declarations
160 .get(queue_entry.declaration_index)
161 else {
162 return Err(anyhow!(
163 "Invalid declaration index {}",
164 queue_entry.declaration_index
165 ));
166 };
167
168 let mut additional_bytes = declaration_size(declaration, queue_entry.style);
169 if this.budget_used + additional_bytes > max_bytes {
170 continue;
171 }
172
173 let additional_parents = this.additional_parent_signatures(
174 &declaration.path,
175 declaration.parent_index,
176 &mut included_parents,
177 )?;
178 additional_bytes += additional_parents
179 .iter()
180 .map(|(_, snippet)| snippet.text.len())
181 .sum::<usize>();
182 if this.budget_used + additional_bytes > max_bytes {
183 continue;
184 }
185
186 this.budget_used += additional_bytes;
187 this.add_parents(&mut included_parents, additional_parents);
188 let planned_snippet = match queue_entry.style {
189 SnippetStyle::Signature => {
190 let Some(text) = declaration.text.get(declaration.signature_range.clone())
191 else {
192 return Err(anyhow!(
193 "Invalid declaration signature_range {:?} with text.len() = {}",
194 declaration.signature_range,
195 declaration.text.len()
196 ));
197 };
198 PlannedSnippet {
199 path: declaration.path.clone(),
200 range: (declaration.signature_range.start + declaration.range.start)
201 ..(declaration.signature_range.end + declaration.range.start),
202 text,
203 text_is_truncated: declaration.text_is_truncated,
204 }
205 }
206 SnippetStyle::Declaration => PlannedSnippet {
207 path: declaration.path.clone(),
208 range: declaration.range.clone(),
209 text: &declaration.text,
210 text_is_truncated: declaration.text_is_truncated,
211 },
212 };
213 this.snippets.push(planned_snippet);
214
215 // When a Signature is consumed, insert an entry for Definition style.
216 if queue_entry.style == SnippetStyle::Signature {
217 let signature_size = declaration_size(&declaration, SnippetStyle::Signature);
218 let declaration_size = declaration_size(&declaration, SnippetStyle::Declaration);
219 let signature_score = declaration_score(&declaration, SnippetStyle::Signature);
220 let declaration_score = declaration_score(&declaration, SnippetStyle::Declaration);
221
222 let score_diff = declaration_score - signature_score;
223 let size_diff = declaration_size.saturating_sub(signature_size);
224 if score_diff > 0.0001 && size_diff > 0 {
225 queue.push(QueueEntry {
226 declaration_index: queue_entry.declaration_index,
227 score_density: OrderedFloat(score_diff / (size_diff as f32)),
228 style: SnippetStyle::Declaration,
229 });
230 }
231 }
232 }
233
234 anyhow::Ok(this)
235 }
236
237 fn add_parents(
238 &mut self,
239 included_parents: &mut FxHashSet<usize>,
240 snippets: Vec<(usize, PlannedSnippet<'a>)>,
241 ) {
242 for (parent_index, snippet) in snippets {
243 included_parents.insert(parent_index);
244 self.budget_used += snippet.text.len();
245 self.snippets.push(snippet);
246 }
247 }
248
249 fn additional_parent_signatures(
250 &self,
251 path: &Arc<Path>,
252 parent_index: Option<usize>,
253 included_parents: &FxHashSet<usize>,
254 ) -> Result<Vec<(usize, PlannedSnippet<'a>)>> {
255 let mut results = Vec::new();
256 self.additional_parent_signatures_impl(path, parent_index, included_parents, &mut results)?;
257 Ok(results)
258 }
259
260 fn additional_parent_signatures_impl(
261 &self,
262 path: &Arc<Path>,
263 parent_index: Option<usize>,
264 included_parents: &FxHashSet<usize>,
265 results: &mut Vec<(usize, PlannedSnippet<'a>)>,
266 ) -> Result<()> {
267 let Some(parent_index) = parent_index else {
268 return Ok(());
269 };
270 if included_parents.contains(&parent_index) {
271 return Ok(());
272 }
273 let Some(parent_signature) = self.request.signatures.get(parent_index) else {
274 return Err(anyhow!("Invalid parent index {}", parent_index));
275 };
276 results.push((
277 parent_index,
278 PlannedSnippet {
279 path: path.clone(),
280 range: parent_signature.range.clone(),
281 text: &parent_signature.text,
282 text_is_truncated: parent_signature.text_is_truncated,
283 },
284 ));
285 self.additional_parent_signatures_impl(
286 path,
287 parent_signature.parent_index,
288 included_parents,
289 results,
290 )
291 }
292
293 /// Renders the planned context. Each file starts with "```FILE_PATH\n` and ends with triple
294 /// backticks, with a newline after each file. Outputs a line with "..." between nonconsecutive
295 /// chunks.
296 pub fn to_prompt_string(&'a self) -> Result<(String, SectionLabels)> {
297 let mut file_to_snippets: FxHashMap<&'a std::path::Path, Vec<&PlannedSnippet<'a>>> =
298 FxHashMap::default();
299 for snippet in &self.snippets {
300 file_to_snippets
301 .entry(&snippet.path)
302 .or_default()
303 .push(snippet);
304 }
305
306 // Reorder so that file with cursor comes last
307 let mut file_snippets = Vec::new();
308 let mut excerpt_file_snippets = Vec::new();
309 for (file_path, snippets) in file_to_snippets {
310 if file_path == self.request.excerpt_path.as_ref() {
311 excerpt_file_snippets = snippets;
312 } else {
313 file_snippets.push((file_path, snippets, false));
314 }
315 }
316 let excerpt_snippet = PlannedSnippet {
317 path: self.request.excerpt_path.clone(),
318 range: self.request.excerpt_range.clone(),
319 text: &self.request.excerpt,
320 text_is_truncated: false,
321 };
322 excerpt_file_snippets.push(&excerpt_snippet);
323 file_snippets.push((&self.request.excerpt_path, excerpt_file_snippets, true));
324
325 let mut excerpt_file_insertions = match self.request.prompt_format {
326 PromptFormat::MarkedExcerpt => vec![
327 (
328 self.request.excerpt_range.start,
329 EDITABLE_REGION_START_MARKER_WITH_NEWLINE,
330 ),
331 (
332 self.request.excerpt_range.start + self.request.cursor_offset,
333 CURSOR_MARKER,
334 ),
335 (
336 self.request
337 .excerpt_range
338 .end
339 .saturating_sub(0)
340 .max(self.request.excerpt_range.start),
341 EDITABLE_REGION_END_MARKER_WITH_NEWLINE,
342 ),
343 ],
344 PromptFormat::LabeledSections => vec![(
345 self.request.excerpt_range.start + self.request.cursor_offset,
346 CURSOR_MARKER,
347 )],
348 PromptFormat::OnlySnippets => vec![],
349 };
350
351 let mut prompt = String::new();
352 prompt.push_str("## User Edits\n\n");
353 Self::push_events(&mut prompt, &self.request.events);
354
355 prompt.push_str("\n## Code\n\n");
356 let section_labels =
357 self.push_file_snippets(&mut prompt, &mut excerpt_file_insertions, file_snippets)?;
358 Ok((prompt, section_labels))
359 }
360
361 fn push_events(output: &mut String, events: &[predict_edits_v3::Event]) {
362 for event in events {
363 match event {
364 Event::BufferChange {
365 path,
366 old_path,
367 diff,
368 predicted,
369 } => {
370 if let Some(old_path) = &old_path
371 && let Some(new_path) = &path
372 {
373 if old_path != new_path {
374 writeln!(
375 output,
376 "User renamed {} to {}\n\n",
377 old_path.display(),
378 new_path.display()
379 )
380 .unwrap();
381 }
382 }
383
384 let path = path
385 .as_ref()
386 .map_or_else(|| "untitled".to_string(), |path| path.display().to_string());
387
388 if *predicted {
389 writeln!(
390 output,
391 "User accepted prediction {:?}:\n```diff\n{}\n```\n",
392 path, diff
393 )
394 .unwrap();
395 } else {
396 writeln!(output, "User edited {:?}:\n```diff\n{}\n```\n", path, diff)
397 .unwrap();
398 }
399 }
400 }
401 }
402 }
403
404 fn push_file_snippets(
405 &self,
406 output: &mut String,
407 excerpt_file_insertions: &mut Vec<(usize, &'static str)>,
408 file_snippets: Vec<(&'a Path, Vec<&'a PlannedSnippet>, bool)>,
409 ) -> Result<SectionLabels> {
410 let mut section_ranges = Vec::new();
411 let mut excerpt_index = None;
412
413 for (file_path, mut snippets, is_excerpt_file) in file_snippets {
414 snippets.sort_by_key(|s| (s.range.start, Reverse(s.range.end)));
415
416 // TODO: What if the snippets get expanded too large to be editable?
417 let mut current_snippet: Option<(&PlannedSnippet, Range<usize>)> = None;
418 let mut disjoint_snippets: Vec<(&PlannedSnippet, Range<usize>)> = Vec::new();
419 for snippet in snippets {
420 if let Some((_, current_snippet_range)) = current_snippet.as_mut()
421 && snippet.range.start < current_snippet_range.end
422 {
423 if snippet.range.end > current_snippet_range.end {
424 current_snippet_range.end = snippet.range.end;
425 }
426 continue;
427 }
428 if let Some(current_snippet) = current_snippet.take() {
429 disjoint_snippets.push(current_snippet);
430 }
431 current_snippet = Some((snippet, snippet.range.clone()));
432 }
433 if let Some(current_snippet) = current_snippet.take() {
434 disjoint_snippets.push(current_snippet);
435 }
436
437 writeln!(output, "```{}", file_path.display()).ok();
438 let mut skipped_last_snippet = false;
439 for (snippet, range) in disjoint_snippets {
440 let section_index = section_ranges.len();
441
442 match self.request.prompt_format {
443 PromptFormat::MarkedExcerpt | PromptFormat::OnlySnippets => {
444 if range.start > 0 && !skipped_last_snippet {
445 output.push_str("…\n");
446 }
447 }
448 PromptFormat::LabeledSections => {
449 if is_excerpt_file
450 && range.start <= self.request.excerpt_range.start
451 && range.end >= self.request.excerpt_range.end
452 {
453 writeln!(output, "<|current_section|>").ok();
454 } else {
455 writeln!(output, "<|section_{}|>", section_index).ok();
456 }
457 }
458 }
459
460 if is_excerpt_file {
461 if self.request.prompt_format == PromptFormat::OnlySnippets {
462 if range.start >= self.request.excerpt_range.start
463 && range.end <= self.request.excerpt_range.end
464 {
465 skipped_last_snippet = true;
466 } else {
467 skipped_last_snippet = false;
468 output.push_str(snippet.text);
469 }
470 } else {
471 let mut last_offset = range.start;
472 let mut i = 0;
473 while i < excerpt_file_insertions.len() {
474 let (offset, insertion) = &excerpt_file_insertions[i];
475 let found = *offset >= range.start && *offset <= range.end;
476 if found {
477 excerpt_index = Some(section_index);
478 output.push_str(
479 &snippet.text[last_offset - range.start..offset - range.start],
480 );
481 output.push_str(insertion);
482 last_offset = *offset;
483 excerpt_file_insertions.remove(i);
484 continue;
485 }
486 i += 1;
487 }
488 skipped_last_snippet = false;
489 output.push_str(&snippet.text[last_offset - range.start..]);
490 }
491 } else {
492 skipped_last_snippet = false;
493 output.push_str(snippet.text);
494 }
495
496 section_ranges.push((snippet.path.clone(), range));
497 }
498
499 output.push_str("```\n\n");
500 }
501
502 Ok(SectionLabels {
503 // TODO: Clean this up
504 excerpt_index: match self.request.prompt_format {
505 PromptFormat::OnlySnippets => 0,
506 _ => excerpt_index.context("bug: no snippet found for excerpt")?,
507 },
508 section_ranges,
509 })
510 }
511}
512
513fn declaration_score_density(declaration: &ReferencedDeclaration, style: SnippetStyle) -> f32 {
514 declaration_score(declaration, style) / declaration_size(declaration, style) as f32
515}
516
517fn declaration_score(declaration: &ReferencedDeclaration, style: SnippetStyle) -> f32 {
518 match style {
519 SnippetStyle::Signature => declaration.signature_score,
520 SnippetStyle::Declaration => declaration.declaration_score,
521 }
522}
523
524fn declaration_size(declaration: &ReferencedDeclaration, style: SnippetStyle) -> usize {
525 match style {
526 SnippetStyle::Signature => declaration.signature_range.len(),
527 SnippetStyle::Declaration => declaration.text.len(),
528 }
529}