1//! Zeta2 prompt planning and generation code shared with cloud.
2
3use anyhow::{Context as _, Result, anyhow};
4use cloud_llm_client::predict_edits_v3::{self, Event, PromptFormat, ReferencedDeclaration};
5use indoc::indoc;
6use ordered_float::OrderedFloat;
7use rustc_hash::{FxHashMap, FxHashSet};
8use std::fmt::Write;
9use std::sync::Arc;
10use std::{cmp::Reverse, collections::BinaryHeap, ops::Range, path::Path};
11use strum::{EnumIter, IntoEnumIterator};
12
/// Prompt byte budget used by `PlannedPrompt::populate` when the request does not specify
/// `prompt_max_bytes`.
pub const DEFAULT_MAX_PROMPT_BYTES: usize = 10 * 1024;

/// Marker inserted into the rendered excerpt at the cursor offset (used by the `MarkedExcerpt`
/// and `LabeledSections` prompt formats).
pub const CURSOR_MARKER: &str = "<|cursor_position|>";
/// Marker inserted at the start of the excerpt range in the `MarkedExcerpt` prompt format.
///
/// NOTE: Differs from zed version of constant - includes a newline
pub const EDITABLE_REGION_START_MARKER_WITH_NEWLINE: &str = "<|editable_region_start|>\n";
/// Marker inserted at the end of the excerpt range in the `MarkedExcerpt` prompt format.
///
/// NOTE: Differs from zed version of constant - includes a newline
pub const EDITABLE_REGION_END_MARKER_WITH_NEWLINE: &str = "<|editable_region_end|>\n";
20
// TODO: use constants for markers?
/// System prompt returned by `system_prompt` for `PromptFormat::MarkedExcerpt`.
///
/// The marker names spelled out in this text are hard-coded here rather than built from
/// `CURSOR_MARKER` and the editable-region marker constants, so the two must be kept in sync by
/// hand.
const MARKED_EXCERPT_SYSTEM_PROMPT: &str = indoc! {"
 You are a code completion assistant and your task is to analyze user edits and then rewrite an excerpt that the user provides, suggesting the appropriate edits within the excerpt, taking into account the cursor location.

 The excerpt to edit will be wrapped in markers <|editable_region_start|> and <|editable_region_end|>. The cursor position is marked with <|cursor_position|>. Please respond with edited code for that region.

 Other code is provided for context, and `…` indicates when code has been skipped.
"};
29
/// System prompt returned by `system_prompt` for `PromptFormat::LabeledSections`.
///
/// The `<|section_N|>` / `<|current_section|>` labels described in this text are the ones that
/// `PlannedPrompt::push_file_snippets` writes before each section in that format.
const LABELED_SECTIONS_SYSTEM_PROMPT: &str = indoc! {r#"
 You are a code completion assistant and your task is to analyze user edits, and suggest an edit to one of the provided sections of code.

 Sections of code are grouped by file and then labeled by `<|section_N|>` (e.g `<|section_8|>`).

 The cursor position is marked with `<|cursor_position|>` and it will appear within a special section labeled `<|current_section|>`. Prefer editing the current section until no more changes are needed within it.

 Respond ONLY with the name of the section to edit on a single line, followed by all of the code that should replace that section. For example:

 <|current_section|>
 for i in 0..16 {
 println!("{i}");
 }
"#};
44
/// A plan for which snippets of context to include in a prompt for a given request.
///
/// Built by `PlannedPrompt::populate` and rendered by `PlannedPrompt::to_prompt_string`.
pub struct PlannedPrompt<'a> {
    /// The request this plan was built from; snippet text borrows from it.
    request: &'a predict_edits_v3::PredictEditsRequest,
    /// Snippets to include in the prompt. These may overlap - they are merged / deduplicated in
    /// `to_prompt_string`.
    snippets: Vec<PlannedSnippet<'a>>,
    /// Number of bytes charged against the prompt budget so far. Starts at the length of the
    /// request's excerpt and grows as snippets are planned.
    budget_used: usize,
}
52
53pub fn system_prompt(format: PromptFormat) -> &'static str {
54 match format {
55 PromptFormat::MarkedExcerpt => MARKED_EXCERPT_SYSTEM_PROMPT,
56 PromptFormat::LabeledSections => LABELED_SECTIONS_SYSTEM_PROMPT,
57 // only intended for use via zeta_cli
58 PromptFormat::OnlySnippets => "",
59 }
60}
61
/// A piece of file text selected for inclusion in the prompt.
#[derive(Clone, Debug)]
pub struct PlannedSnippet<'a> {
    /// Path of the file the snippet is attributed to; snippets are grouped by this when rendering.
    path: Arc<Path>,
    /// Offset range of the snippet. Ranges of snippets in the same file are compared with each
    /// other (to merge overlaps) and with the request's `excerpt_range` when rendering.
    range: Range<usize>,
    /// The snippet's text, borrowed from the request.
    text: &'a str,
    // TODO: Indicate this in the output
    /// Copied from the declaration or signature the snippet was taken from; not yet used when
    /// rendering.
    #[allow(dead_code)]
    text_is_truncated: bool,
}
71
/// How much of a referenced declaration to include in the prompt.
///
/// The variant order is significant: it determines the derived `Ord` and the `EnumIter`
/// iteration order.
#[derive(EnumIter, Clone, Copy, PartialEq, Eq, Hash, Debug, PartialOrd, Ord)]
pub enum DeclarationStyle {
    /// Include only the `signature_range` portion of the declaration's text.
    Signature,
    /// Include the declaration's full text.
    Declaration,
}
77
/// Describes the sections that were written into a rendered prompt.
#[derive(Clone, Debug)]
pub struct SectionLabels {
    /// Index into `section_ranges` of the section that received the excerpt insertions (cursor /
    /// editable-region markers). Always `0` for the `OnlySnippets` format.
    pub excerpt_index: usize,
    /// Path and (merged) range of every section, in the order the sections were written.
    pub section_ranges: Vec<(Arc<Path>, Range<usize>)>,
}
83
84impl<'a> PlannedPrompt<'a> {
85 /// Greedy one-pass knapsack algorithm to populate the prompt plan. Does the following:
86 ///
87 /// Initializes a priority queue by populating it with each snippet, finding the
88 /// DeclarationStyle that minimizes `score_density = score / snippet.range(style).len()`. When a
89 /// "signature" snippet is popped, insert an entry for the "declaration" variant that reflects
90 /// the cost of upgrade.
91 ///
92 /// TODO: Implement an early halting condition. One option might be to have another priority
93 /// queue where the score is the size, and update it accordingly. Another option might be to
94 /// have some simpler heuristic like bailing after N failed insertions, or based on how much
95 /// budget is left.
96 ///
97 /// TODO: Has the current known sources of imprecision:
98 ///
99 /// * Does not consider snippet overlap when ranking. For example, it might add a field to the
100 /// plan even though the containing struct is already included.
101 ///
102 /// * Does not consider cost of signatures when ranking snippets - this is tricky since
103 /// signatures may be shared by multiple snippets.
104 ///
105 /// * Does not include file paths / other text when considering max_bytes.
106 pub fn populate(request: &'a predict_edits_v3::PredictEditsRequest) -> Result<Self> {
107 let mut this = PlannedPrompt {
108 request,
109 snippets: Vec::new(),
110 budget_used: request.excerpt.len(),
111 };
112 let mut included_parents = FxHashSet::default();
113 let additional_parents = this.additional_parent_signatures(
114 &request.excerpt_path,
115 request.excerpt_parent,
116 &included_parents,
117 )?;
118 this.add_parents(&mut included_parents, additional_parents);
119
120 let max_bytes = request.prompt_max_bytes.unwrap_or(DEFAULT_MAX_PROMPT_BYTES);
121
122 if this.budget_used > max_bytes {
123 return Err(anyhow!(
124 "Excerpt + signatures size of {} already exceeds budget of {}",
125 this.budget_used,
126 max_bytes
127 ));
128 }
129
130 #[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord)]
131 struct QueueEntry {
132 score_density: OrderedFloat<f32>,
133 declaration_index: usize,
134 style: DeclarationStyle,
135 }
136
137 // Initialize priority queue with the best score for each snippet.
138 let mut queue: BinaryHeap<QueueEntry> = BinaryHeap::new();
139 for (declaration_index, declaration) in request.referenced_declarations.iter().enumerate() {
140 let (style, score_density) = DeclarationStyle::iter()
141 .map(|style| {
142 (
143 style,
144 OrderedFloat(declaration_score_density(&declaration, style)),
145 )
146 })
147 .max_by_key(|(_, score_density)| *score_density)
148 .unwrap();
149 queue.push(QueueEntry {
150 score_density,
151 declaration_index,
152 style,
153 });
154 }
155
156 // Knapsack selection loop
157 while let Some(queue_entry) = queue.pop() {
158 let Some(declaration) = request
159 .referenced_declarations
160 .get(queue_entry.declaration_index)
161 else {
162 return Err(anyhow!(
163 "Invalid declaration index {}",
164 queue_entry.declaration_index
165 ));
166 };
167
168 let mut additional_bytes = declaration_size(declaration, queue_entry.style);
169 if this.budget_used + additional_bytes > max_bytes {
170 continue;
171 }
172
173 let additional_parents = this.additional_parent_signatures(
174 &declaration.path,
175 declaration.parent_index,
176 &mut included_parents,
177 )?;
178 additional_bytes += additional_parents
179 .iter()
180 .map(|(_, snippet)| snippet.text.len())
181 .sum::<usize>();
182 if this.budget_used + additional_bytes > max_bytes {
183 continue;
184 }
185
186 this.budget_used += additional_bytes;
187 this.add_parents(&mut included_parents, additional_parents);
188 let planned_snippet = match queue_entry.style {
189 DeclarationStyle::Signature => {
190 let Some(text) = declaration.text.get(declaration.signature_range.clone())
191 else {
192 return Err(anyhow!(
193 "Invalid declaration signature_range {:?} with text.len() = {}",
194 declaration.signature_range,
195 declaration.text.len()
196 ));
197 };
198 PlannedSnippet {
199 path: declaration.path.clone(),
200 range: (declaration.signature_range.start + declaration.range.start)
201 ..(declaration.signature_range.end + declaration.range.start),
202 text,
203 text_is_truncated: declaration.text_is_truncated,
204 }
205 }
206 DeclarationStyle::Declaration => PlannedSnippet {
207 path: declaration.path.clone(),
208 range: declaration.range.clone(),
209 text: &declaration.text,
210 text_is_truncated: declaration.text_is_truncated,
211 },
212 };
213 this.snippets.push(planned_snippet);
214
215 // When a Signature is consumed, insert an entry for Definition style.
216 if queue_entry.style == DeclarationStyle::Signature {
217 let signature_size = declaration_size(&declaration, DeclarationStyle::Signature);
218 let declaration_size =
219 declaration_size(&declaration, DeclarationStyle::Declaration);
220 let signature_score = declaration_score(&declaration, DeclarationStyle::Signature);
221 let declaration_score =
222 declaration_score(&declaration, DeclarationStyle::Declaration);
223
224 let score_diff = declaration_score - signature_score;
225 let size_diff = declaration_size.saturating_sub(signature_size);
226 if score_diff > 0.0001 && size_diff > 0 {
227 queue.push(QueueEntry {
228 declaration_index: queue_entry.declaration_index,
229 score_density: OrderedFloat(score_diff / (size_diff as f32)),
230 style: DeclarationStyle::Declaration,
231 });
232 }
233 }
234 }
235
236 anyhow::Ok(this)
237 }
238
239 fn add_parents(
240 &mut self,
241 included_parents: &mut FxHashSet<usize>,
242 snippets: Vec<(usize, PlannedSnippet<'a>)>,
243 ) {
244 for (parent_index, snippet) in snippets {
245 included_parents.insert(parent_index);
246 self.budget_used += snippet.text.len();
247 self.snippets.push(snippet);
248 }
249 }
250
251 fn additional_parent_signatures(
252 &self,
253 path: &Arc<Path>,
254 parent_index: Option<usize>,
255 included_parents: &FxHashSet<usize>,
256 ) -> Result<Vec<(usize, PlannedSnippet<'a>)>> {
257 let mut results = Vec::new();
258 self.additional_parent_signatures_impl(path, parent_index, included_parents, &mut results)?;
259 Ok(results)
260 }
261
262 fn additional_parent_signatures_impl(
263 &self,
264 path: &Arc<Path>,
265 parent_index: Option<usize>,
266 included_parents: &FxHashSet<usize>,
267 results: &mut Vec<(usize, PlannedSnippet<'a>)>,
268 ) -> Result<()> {
269 let Some(parent_index) = parent_index else {
270 return Ok(());
271 };
272 if included_parents.contains(&parent_index) {
273 return Ok(());
274 }
275 let Some(parent_signature) = self.request.signatures.get(parent_index) else {
276 return Err(anyhow!("Invalid parent index {}", parent_index));
277 };
278 results.push((
279 parent_index,
280 PlannedSnippet {
281 path: path.clone(),
282 range: parent_signature.range.clone(),
283 text: &parent_signature.text,
284 text_is_truncated: parent_signature.text_is_truncated,
285 },
286 ));
287 self.additional_parent_signatures_impl(
288 path,
289 parent_signature.parent_index,
290 included_parents,
291 results,
292 )
293 }
294
295 /// Renders the planned context. Each file starts with "```FILE_PATH\n` and ends with triple
296 /// backticks, with a newline after each file. Outputs a line with "..." between nonconsecutive
297 /// chunks.
298 pub fn to_prompt_string(&'a self) -> Result<(String, SectionLabels)> {
299 let mut file_to_snippets: FxHashMap<&'a std::path::Path, Vec<&PlannedSnippet<'a>>> =
300 FxHashMap::default();
301 for snippet in &self.snippets {
302 file_to_snippets
303 .entry(&snippet.path)
304 .or_default()
305 .push(snippet);
306 }
307
308 // Reorder so that file with cursor comes last
309 let mut file_snippets = Vec::new();
310 let mut excerpt_file_snippets = Vec::new();
311 for (file_path, snippets) in file_to_snippets {
312 if file_path == self.request.excerpt_path.as_ref() {
313 excerpt_file_snippets = snippets;
314 } else {
315 file_snippets.push((file_path, snippets, false));
316 }
317 }
318 let excerpt_snippet = PlannedSnippet {
319 path: self.request.excerpt_path.clone(),
320 range: self.request.excerpt_range.clone(),
321 text: &self.request.excerpt,
322 text_is_truncated: false,
323 };
324 excerpt_file_snippets.push(&excerpt_snippet);
325 file_snippets.push((&self.request.excerpt_path, excerpt_file_snippets, true));
326
327 let mut excerpt_file_insertions = match self.request.prompt_format {
328 PromptFormat::MarkedExcerpt => vec![
329 (
330 self.request.excerpt_range.start,
331 EDITABLE_REGION_START_MARKER_WITH_NEWLINE,
332 ),
333 (
334 self.request.excerpt_range.start + self.request.cursor_offset,
335 CURSOR_MARKER,
336 ),
337 (
338 self.request
339 .excerpt_range
340 .end
341 .saturating_sub(0)
342 .max(self.request.excerpt_range.start),
343 EDITABLE_REGION_END_MARKER_WITH_NEWLINE,
344 ),
345 ],
346 PromptFormat::LabeledSections => vec![(
347 self.request.excerpt_range.start + self.request.cursor_offset,
348 CURSOR_MARKER,
349 )],
350 PromptFormat::OnlySnippets => vec![],
351 };
352
353 let mut prompt = String::new();
354 prompt.push_str("## User Edits\n\n");
355 Self::push_events(&mut prompt, &self.request.events);
356
357 prompt.push_str("\n## Code\n\n");
358 let section_labels =
359 self.push_file_snippets(&mut prompt, &mut excerpt_file_insertions, file_snippets)?;
360 Ok((prompt, section_labels))
361 }
362
363 fn push_events(output: &mut String, events: &[predict_edits_v3::Event]) {
364 for event in events {
365 match event {
366 Event::BufferChange {
367 path,
368 old_path,
369 diff,
370 predicted,
371 } => {
372 if let Some(old_path) = &old_path
373 && let Some(new_path) = &path
374 {
375 if old_path != new_path {
376 writeln!(
377 output,
378 "User renamed {} to {}\n\n",
379 old_path.display(),
380 new_path.display()
381 )
382 .unwrap();
383 }
384 }
385
386 let path = path
387 .as_ref()
388 .map_or_else(|| "untitled".to_string(), |path| path.display().to_string());
389
390 if *predicted {
391 writeln!(
392 output,
393 "User accepted prediction {:?}:\n```diff\n{}\n```\n",
394 path, diff
395 )
396 .unwrap();
397 } else {
398 writeln!(output, "User edited {:?}:\n```diff\n{}\n```\n", path, diff)
399 .unwrap();
400 }
401 }
402 }
403 }
404 }
405
406 fn push_file_snippets(
407 &self,
408 output: &mut String,
409 excerpt_file_insertions: &mut Vec<(usize, &'static str)>,
410 file_snippets: Vec<(&'a Path, Vec<&'a PlannedSnippet>, bool)>,
411 ) -> Result<SectionLabels> {
412 let mut section_ranges = Vec::new();
413 let mut excerpt_index = None;
414
415 for (file_path, mut snippets, is_excerpt_file) in file_snippets {
416 snippets.sort_by_key(|s| (s.range.start, Reverse(s.range.end)));
417
418 // TODO: What if the snippets get expanded too large to be editable?
419 let mut current_snippet: Option<(&PlannedSnippet, Range<usize>)> = None;
420 let mut disjoint_snippets: Vec<(&PlannedSnippet, Range<usize>)> = Vec::new();
421 for snippet in snippets {
422 if let Some((_, current_snippet_range)) = current_snippet.as_mut()
423 && snippet.range.start < current_snippet_range.end
424 {
425 if snippet.range.end > current_snippet_range.end {
426 current_snippet_range.end = snippet.range.end;
427 }
428 continue;
429 }
430 if let Some(current_snippet) = current_snippet.take() {
431 disjoint_snippets.push(current_snippet);
432 }
433 current_snippet = Some((snippet, snippet.range.clone()));
434 }
435 if let Some(current_snippet) = current_snippet.take() {
436 disjoint_snippets.push(current_snippet);
437 }
438
439 writeln!(output, "```{}", file_path.display()).ok();
440 let mut skipped_last_snippet = false;
441 for (snippet, range) in disjoint_snippets {
442 let section_index = section_ranges.len();
443
444 match self.request.prompt_format {
445 PromptFormat::MarkedExcerpt | PromptFormat::OnlySnippets => {
446 if range.start > 0 && !skipped_last_snippet {
447 output.push_str("…\n");
448 }
449 }
450 PromptFormat::LabeledSections => {
451 if is_excerpt_file
452 && range.start <= self.request.excerpt_range.start
453 && range.end >= self.request.excerpt_range.end
454 {
455 writeln!(output, "<|current_section|>").ok();
456 } else {
457 writeln!(output, "<|section_{}|>", section_index).ok();
458 }
459 }
460 }
461
462 if is_excerpt_file {
463 if self.request.prompt_format == PromptFormat::OnlySnippets {
464 if range.start >= self.request.excerpt_range.start
465 && range.end <= self.request.excerpt_range.end
466 {
467 skipped_last_snippet = true;
468 } else {
469 skipped_last_snippet = false;
470 output.push_str(snippet.text);
471 }
472 } else {
473 let mut last_offset = range.start;
474 let mut i = 0;
475 while i < excerpt_file_insertions.len() {
476 let (offset, insertion) = &excerpt_file_insertions[i];
477 let found = *offset >= range.start && *offset <= range.end;
478 if found {
479 excerpt_index = Some(section_index);
480 output.push_str(
481 &snippet.text[last_offset - range.start..offset - range.start],
482 );
483 output.push_str(insertion);
484 last_offset = *offset;
485 excerpt_file_insertions.remove(i);
486 continue;
487 }
488 i += 1;
489 }
490 skipped_last_snippet = false;
491 output.push_str(&snippet.text[last_offset - range.start..]);
492 }
493 } else {
494 skipped_last_snippet = false;
495 output.push_str(snippet.text);
496 }
497
498 section_ranges.push((snippet.path.clone(), range));
499 }
500
501 output.push_str("```\n\n");
502 }
503
504 Ok(SectionLabels {
505 // TODO: Clean this up
506 excerpt_index: match self.request.prompt_format {
507 PromptFormat::OnlySnippets => 0,
508 _ => excerpt_index.context("bug: no snippet found for excerpt")?,
509 },
510 section_ranges,
511 })
512 }
513}
514
515fn declaration_score_density(declaration: &ReferencedDeclaration, style: DeclarationStyle) -> f32 {
516 declaration_score(declaration, style) / declaration_size(declaration, style) as f32
517}
518
519fn declaration_score(declaration: &ReferencedDeclaration, style: DeclarationStyle) -> f32 {
520 match style {
521 DeclarationStyle::Signature => declaration.signature_score,
522 DeclarationStyle::Declaration => declaration.declaration_score,
523 }
524}
525
526fn declaration_size(declaration: &ReferencedDeclaration, style: DeclarationStyle) -> usize {
527 match style {
528 DeclarationStyle::Signature => declaration.signature_range.len(),
529 DeclarationStyle::Declaration => declaration.text.len(),
530 }
531}