1//! Zeta2 prompt planning and generation code shared with cloud.
2
3use anyhow::{Context as _, Result, anyhow};
4use cloud_llm_client::predict_edits_v3::{self, Event, PromptFormat, ReferencedDeclaration};
5use indoc::indoc;
6use ordered_float::OrderedFloat;
7use rustc_hash::{FxHashMap, FxHashSet};
8use std::fmt::Write;
9use std::sync::Arc;
10use std::{cmp::Reverse, collections::BinaryHeap, ops::Range, path::Path};
11use strum::{EnumIter, IntoEnumIterator};
12
/// Prompt budget, in bytes, used by `PlannedPrompt::populate` when the request does not specify
/// `prompt_max_bytes`.
pub const DEFAULT_MAX_PROMPT_BYTES: usize = 10 * 1024;

/// Marker inserted into the rendered excerpt at the cursor position.
pub const CURSOR_MARKER: &str = "<|cursor_position|>";
/// Marker inserted at the start of the excerpt in the `MarkedExcerpt` prompt format.
///
/// NOTE: Differs from zed version of constant - includes a newline
pub const EDITABLE_REGION_START_MARKER_WITH_NEWLINE: &str = "<|editable_region_start|>\n";
/// Marker inserted at the end of the excerpt in the `MarkedExcerpt` prompt format.
///
/// NOTE: Differs from zed version of constant - includes a newline
pub const EDITABLE_REGION_END_MARKER_WITH_NEWLINE: &str = "<|editable_region_end|>\n";
20
// TODO: use constants for markers?
/// System prompt returned by `system_prompt` for `PromptFormat::MarkedExcerpt`.
///
/// The marker names spelled out in the text mirror `CURSOR_MARKER` and the editable-region
/// marker constants, but are duplicated here as literal text.
const MARKED_EXCERPT_SYSTEM_PROMPT: &str = indoc! {"
    You are a code completion assistant and your task is to analyze user edits and then rewrite an excerpt that the user provides, suggesting the appropriate edits within the excerpt, taking into account the cursor location.

    The excerpt to edit will be wrapped in markers <|editable_region_start|> and <|editable_region_end|>. The cursor position is marked with <|cursor_position|>. Please respond with edited code for that region.

    Other code is provided for context, and `…` indicates when code has been skipped.
"};
29
/// System prompt returned by `system_prompt` for `PromptFormat::LabeledSections`.
///
/// The section labels described in the text (`<|section_N|>`, `<|current_section|>`) are the
/// ones written by `PlannedPrompt::push_file_snippets`.
const LABELED_SECTIONS_SYSTEM_PROMPT: &str = indoc! {r#"
    You are a code completion assistant and your task is to analyze user edits, and suggest an edit to one of the provided sections of code.

    Sections of code are grouped by file and then labeled by `<|section_N|>` (e.g `<|section_8|>`).

    The cursor position is marked with `<|cursor_position|>` and it will appear within a special section labeled `<|current_section|>`. Prefer editing the current section until no more changes are needed within it.

    Respond ONLY with the name of the section to edit on a single line, followed by all of the code that should replace that section. For example:

    <|current_section|>
    for i in 0..16 {
        println!("{i}");
    }
"#};
44
/// A plan of which context snippets accompany the excerpt of a predict-edits request.
///
/// Built by `PlannedPrompt::populate` and rendered by `PlannedPrompt::to_prompt_string`.
pub struct PlannedPrompt<'a> {
    /// The request this plan was built for; snippet text borrows from it.
    request: &'a predict_edits_v3::PredictEditsRequest,
    /// Snippets to include in the prompt. These may overlap - they are merged / deduplicated in
    /// `to_prompt_string`.
    snippets: Vec<PlannedSnippet<'a>>,
    /// Bytes charged against the prompt budget so far. Starts at the excerpt length and grows as
    /// snippets are added.
    budget_used: usize,
}
52
53pub fn system_prompt(format: PromptFormat) -> &'static str {
54 match format {
55 PromptFormat::MarkedExcerpt => MARKED_EXCERPT_SYSTEM_PROMPT,
56 PromptFormat::LabeledSections => LABELED_SECTIONS_SYSTEM_PROMPT,
57 }
58}
59
/// A piece of text from a file that is planned for inclusion in the prompt.
#[derive(Clone, Debug)]
pub struct PlannedSnippet<'a> {
    /// File the snippet comes from; snippets are grouped by this path when rendering.
    path: Arc<Path>,
    /// Offset range of the snippet within its file. Used to order snippets and to detect
    /// overlap between them (presumably byte offsets - confirm against the request producer).
    range: Range<usize>,
    /// The snippet text, borrowed from the request.
    text: &'a str,
    /// Copied from the request's truncation flag for the declaration or signature this snippet
    /// was built from.
    // TODO: Indicate this in the output
    #[allow(dead_code)]
    text_is_truncated: bool,
}
69
/// How much of a referenced declaration is rendered as a snippet.
///
/// NOTE: the variant order is significant. It defines the derived `Ord`, which is used as a
/// tie-breaker when entries are ordered in the planning queue, and `EnumIter` visits the
/// variants in this order.
#[derive(EnumIter, Clone, Copy, PartialEq, Eq, Hash, Debug, PartialOrd, Ord)]
pub enum SnippetStyle {
    /// Only the `signature_range` portion of the declaration text.
    Signature,
    /// The whole declaration text.
    Declaration,
}
75
/// Describes the sections written by `PlannedPrompt::to_prompt_string`, so that a section
/// referenced by index can be mapped back to a file and range.
#[derive(Clone, Debug)]
pub struct SectionLabels {
    /// Index into `section_ranges` of the section associated with the excerpt.
    pub excerpt_index: usize,
    /// File path and range of every rendered section, in output order. A section's index in
    /// this list is the `N` used for its `<|section_N|>` label.
    pub section_ranges: Vec<(Arc<Path>, Range<usize>)>,
}
81
82impl<'a> PlannedPrompt<'a> {
83 /// Greedy one-pass knapsack algorithm to populate the prompt plan. Does the following:
84 ///
85 /// Initializes a priority queue by populating it with each snippet, finding the SnippetStyle
86 /// that minimizes `score_density = score / snippet.range(style).len()`. When a "signature"
87 /// snippet is popped, insert an entry for the "declaration" variant that reflects the cost of
88 /// upgrade.
89 ///
90 /// TODO: Implement an early halting condition. One option might be to have another priority
91 /// queue where the score is the size, and update it accordingly. Another option might be to
92 /// have some simpler heuristic like bailing after N failed insertions, or based on how much
93 /// budget is left.
94 ///
95 /// TODO: Has the current known sources of imprecision:
96 ///
97 /// * Does not consider snippet overlap when ranking. For example, it might add a field to the
98 /// plan even though the containing struct is already included.
99 ///
100 /// * Does not consider cost of signatures when ranking snippets - this is tricky since
101 /// signatures may be shared by multiple snippets.
102 ///
103 /// * Does not include file paths / other text when considering max_bytes.
104 pub fn populate(request: &'a predict_edits_v3::PredictEditsRequest) -> Result<Self> {
105 let mut this = PlannedPrompt {
106 request,
107 snippets: Vec::new(),
108 budget_used: request.excerpt.len(),
109 };
110 let mut included_parents = FxHashSet::default();
111 let additional_parents = this.additional_parent_signatures(
112 &request.excerpt_path,
113 request.excerpt_parent,
114 &included_parents,
115 )?;
116 this.add_parents(&mut included_parents, additional_parents);
117
118 let max_bytes = request.prompt_max_bytes.unwrap_or(DEFAULT_MAX_PROMPT_BYTES);
119
120 if this.budget_used > max_bytes {
121 return Err(anyhow!(
122 "Excerpt + signatures size of {} already exceeds budget of {}",
123 this.budget_used,
124 max_bytes
125 ));
126 }
127
128 #[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord)]
129 struct QueueEntry {
130 score_density: OrderedFloat<f32>,
131 declaration_index: usize,
132 style: SnippetStyle,
133 }
134
135 // Initialize priority queue with the best score for each snippet.
136 let mut queue: BinaryHeap<QueueEntry> = BinaryHeap::new();
137 for (declaration_index, declaration) in request.referenced_declarations.iter().enumerate() {
138 let (style, score_density) = SnippetStyle::iter()
139 .map(|style| {
140 (
141 style,
142 OrderedFloat(declaration_score_density(&declaration, style)),
143 )
144 })
145 .max_by_key(|(_, score_density)| *score_density)
146 .unwrap();
147 queue.push(QueueEntry {
148 score_density,
149 declaration_index,
150 style,
151 });
152 }
153
154 // Knapsack selection loop
155 while let Some(queue_entry) = queue.pop() {
156 let Some(declaration) = request
157 .referenced_declarations
158 .get(queue_entry.declaration_index)
159 else {
160 return Err(anyhow!(
161 "Invalid declaration index {}",
162 queue_entry.declaration_index
163 ));
164 };
165
166 let mut additional_bytes = declaration_size(declaration, queue_entry.style);
167 if this.budget_used + additional_bytes > max_bytes {
168 continue;
169 }
170
171 let additional_parents = this.additional_parent_signatures(
172 &declaration.path,
173 declaration.parent_index,
174 &mut included_parents,
175 )?;
176 additional_bytes += additional_parents
177 .iter()
178 .map(|(_, snippet)| snippet.text.len())
179 .sum::<usize>();
180 if this.budget_used + additional_bytes > max_bytes {
181 continue;
182 }
183
184 this.budget_used += additional_bytes;
185 this.add_parents(&mut included_parents, additional_parents);
186 let planned_snippet = match queue_entry.style {
187 SnippetStyle::Signature => {
188 let Some(text) = declaration.text.get(declaration.signature_range.clone())
189 else {
190 return Err(anyhow!(
191 "Invalid declaration signature_range {:?} with text.len() = {}",
192 declaration.signature_range,
193 declaration.text.len()
194 ));
195 };
196 PlannedSnippet {
197 path: declaration.path.clone(),
198 range: (declaration.signature_range.start + declaration.range.start)
199 ..(declaration.signature_range.end + declaration.range.start),
200 text,
201 text_is_truncated: declaration.text_is_truncated,
202 }
203 }
204 SnippetStyle::Declaration => PlannedSnippet {
205 path: declaration.path.clone(),
206 range: declaration.range.clone(),
207 text: &declaration.text,
208 text_is_truncated: declaration.text_is_truncated,
209 },
210 };
211 this.snippets.push(planned_snippet);
212
213 // When a Signature is consumed, insert an entry for Definition style.
214 if queue_entry.style == SnippetStyle::Signature {
215 let signature_size = declaration_size(&declaration, SnippetStyle::Signature);
216 let declaration_size = declaration_size(&declaration, SnippetStyle::Declaration);
217 let signature_score = declaration_score(&declaration, SnippetStyle::Signature);
218 let declaration_score = declaration_score(&declaration, SnippetStyle::Declaration);
219
220 let score_diff = declaration_score - signature_score;
221 let size_diff = declaration_size.saturating_sub(signature_size);
222 if score_diff > 0.0001 && size_diff > 0 {
223 queue.push(QueueEntry {
224 declaration_index: queue_entry.declaration_index,
225 score_density: OrderedFloat(score_diff / (size_diff as f32)),
226 style: SnippetStyle::Declaration,
227 });
228 }
229 }
230 }
231
232 anyhow::Ok(this)
233 }
234
235 fn add_parents(
236 &mut self,
237 included_parents: &mut FxHashSet<usize>,
238 snippets: Vec<(usize, PlannedSnippet<'a>)>,
239 ) {
240 for (parent_index, snippet) in snippets {
241 included_parents.insert(parent_index);
242 self.budget_used += snippet.text.len();
243 self.snippets.push(snippet);
244 }
245 }
246
247 fn additional_parent_signatures(
248 &self,
249 path: &Arc<Path>,
250 parent_index: Option<usize>,
251 included_parents: &FxHashSet<usize>,
252 ) -> Result<Vec<(usize, PlannedSnippet<'a>)>> {
253 let mut results = Vec::new();
254 self.additional_parent_signatures_impl(path, parent_index, included_parents, &mut results)?;
255 Ok(results)
256 }
257
258 fn additional_parent_signatures_impl(
259 &self,
260 path: &Arc<Path>,
261 parent_index: Option<usize>,
262 included_parents: &FxHashSet<usize>,
263 results: &mut Vec<(usize, PlannedSnippet<'a>)>,
264 ) -> Result<()> {
265 let Some(parent_index) = parent_index else {
266 return Ok(());
267 };
268 if included_parents.contains(&parent_index) {
269 return Ok(());
270 }
271 let Some(parent_signature) = self.request.signatures.get(parent_index) else {
272 return Err(anyhow!("Invalid parent index {}", parent_index));
273 };
274 results.push((
275 parent_index,
276 PlannedSnippet {
277 path: path.clone(),
278 range: parent_signature.range.clone(),
279 text: &parent_signature.text,
280 text_is_truncated: parent_signature.text_is_truncated,
281 },
282 ));
283 self.additional_parent_signatures_impl(
284 path,
285 parent_signature.parent_index,
286 included_parents,
287 results,
288 )
289 }
290
291 /// Renders the planned context. Each file starts with "```FILE_PATH\n` and ends with triple
292 /// backticks, with a newline after each file. Outputs a line with "..." between nonconsecutive
293 /// chunks.
294 pub fn to_prompt_string(&'a self) -> Result<(String, SectionLabels)> {
295 let mut file_to_snippets: FxHashMap<&'a std::path::Path, Vec<&PlannedSnippet<'a>>> =
296 FxHashMap::default();
297 for snippet in &self.snippets {
298 file_to_snippets
299 .entry(&snippet.path)
300 .or_default()
301 .push(snippet);
302 }
303
304 // Reorder so that file with cursor comes last
305 let mut file_snippets = Vec::new();
306 let mut excerpt_file_snippets = Vec::new();
307 for (file_path, snippets) in file_to_snippets {
308 if file_path == self.request.excerpt_path.as_ref() {
309 excerpt_file_snippets = snippets;
310 } else {
311 file_snippets.push((file_path, snippets, false));
312 }
313 }
314 let excerpt_snippet = PlannedSnippet {
315 path: self.request.excerpt_path.clone(),
316 range: self.request.excerpt_range.clone(),
317 text: &self.request.excerpt,
318 text_is_truncated: false,
319 };
320 excerpt_file_snippets.push(&excerpt_snippet);
321 file_snippets.push((&self.request.excerpt_path, excerpt_file_snippets, true));
322
323 let mut excerpt_file_insertions = match self.request.prompt_format {
324 PromptFormat::MarkedExcerpt => vec![
325 (
326 self.request.excerpt_range.start,
327 EDITABLE_REGION_START_MARKER_WITH_NEWLINE,
328 ),
329 (
330 self.request.excerpt_range.start + self.request.cursor_offset,
331 CURSOR_MARKER,
332 ),
333 (
334 self.request
335 .excerpt_range
336 .end
337 .saturating_sub(0)
338 .max(self.request.excerpt_range.start),
339 EDITABLE_REGION_END_MARKER_WITH_NEWLINE,
340 ),
341 ],
342 PromptFormat::LabeledSections => vec![(
343 self.request.excerpt_range.start + self.request.cursor_offset,
344 CURSOR_MARKER,
345 )],
346 };
347
348 let mut prompt = String::new();
349 prompt.push_str("## User Edits\n\n");
350 Self::push_events(&mut prompt, &self.request.events);
351
352 prompt.push_str("\n## Code\n\n");
353 let section_labels =
354 self.push_file_snippets(&mut prompt, &mut excerpt_file_insertions, file_snippets)?;
355 Ok((prompt, section_labels))
356 }
357
358 fn push_events(output: &mut String, events: &[predict_edits_v3::Event]) {
359 for event in events {
360 match event {
361 Event::BufferChange {
362 path,
363 old_path,
364 diff,
365 predicted,
366 } => {
367 if let Some(old_path) = &old_path
368 && let Some(new_path) = &path
369 {
370 if old_path != new_path {
371 writeln!(
372 output,
373 "User renamed {} to {}\n\n",
374 old_path.display(),
375 new_path.display()
376 )
377 .unwrap();
378 }
379 }
380
381 let path = path
382 .as_ref()
383 .map_or_else(|| "untitled".to_string(), |path| path.display().to_string());
384
385 if *predicted {
386 writeln!(
387 output,
388 "User accepted prediction {:?}:\n```diff\n{}\n```\n",
389 path, diff
390 )
391 .unwrap();
392 } else {
393 writeln!(output, "User edited {:?}:\n```diff\n{}\n```\n", path, diff)
394 .unwrap();
395 }
396 }
397 }
398 }
399 }
400
401 fn push_file_snippets(
402 &self,
403 output: &mut String,
404 excerpt_file_insertions: &mut Vec<(usize, &'static str)>,
405 file_snippets: Vec<(&'a Path, Vec<&'a PlannedSnippet>, bool)>,
406 ) -> Result<SectionLabels> {
407 let mut section_ranges = Vec::new();
408 let mut excerpt_index = None;
409
410 for (file_path, mut snippets, is_excerpt_file) in file_snippets {
411 snippets.sort_by_key(|s| (s.range.start, Reverse(s.range.end)));
412
413 // TODO: What if the snippets get expanded too large to be editable?
414 let mut current_snippet: Option<(&PlannedSnippet, Range<usize>)> = None;
415 let mut disjoint_snippets: Vec<(&PlannedSnippet, Range<usize>)> = Vec::new();
416 for snippet in snippets {
417 if let Some((_, current_snippet_range)) = current_snippet.as_mut()
418 && snippet.range.start < current_snippet_range.end
419 {
420 if snippet.range.end > current_snippet_range.end {
421 current_snippet_range.end = snippet.range.end;
422 }
423 continue;
424 }
425 if let Some(current_snippet) = current_snippet.take() {
426 disjoint_snippets.push(current_snippet);
427 }
428 current_snippet = Some((snippet, snippet.range.clone()));
429 }
430 if let Some(current_snippet) = current_snippet.take() {
431 disjoint_snippets.push(current_snippet);
432 }
433
434 writeln!(output, "```{}", file_path.display()).ok();
435 for (snippet, range) in disjoint_snippets {
436 let section_index = section_ranges.len();
437
438 match self.request.prompt_format {
439 PromptFormat::MarkedExcerpt => {
440 if range.start > 0 {
441 output.push_str("…\n");
442 }
443 }
444 PromptFormat::LabeledSections => {
445 if is_excerpt_file
446 && range.start <= self.request.excerpt_range.start
447 && range.end >= self.request.excerpt_range.end
448 {
449 writeln!(output, "<|current_section|>").ok();
450 } else {
451 writeln!(output, "<|section_{}|>", section_index).ok();
452 }
453 }
454 }
455
456 if is_excerpt_file {
457 excerpt_index = Some(section_index);
458 let mut last_offset = range.start;
459 let mut i = 0;
460 while i < excerpt_file_insertions.len() {
461 let (offset, insertion) = &excerpt_file_insertions[i];
462 let found = *offset >= range.start && *offset <= range.end;
463 if found {
464 output.push_str(
465 &snippet.text[last_offset - range.start..offset - range.start],
466 );
467 output.push_str(insertion);
468 last_offset = *offset;
469 excerpt_file_insertions.remove(i);
470 continue;
471 }
472 i += 1;
473 }
474 output.push_str(&snippet.text[last_offset - range.start..]);
475 } else {
476 output.push_str(snippet.text);
477 }
478
479 section_ranges.push((snippet.path.clone(), range));
480 }
481
482 output.push_str("```\n\n");
483 }
484
485 Ok(SectionLabels {
486 excerpt_index: excerpt_index.context("bug: no snippet found for excerpt")?,
487 section_ranges,
488 })
489 }
490}
491
492fn declaration_score_density(declaration: &ReferencedDeclaration, style: SnippetStyle) -> f32 {
493 declaration_score(declaration, style) / declaration_size(declaration, style) as f32
494}
495
496fn declaration_score(declaration: &ReferencedDeclaration, style: SnippetStyle) -> f32 {
497 match style {
498 SnippetStyle::Signature => declaration.signature_score,
499 SnippetStyle::Declaration => declaration.declaration_score,
500 }
501}
502
503fn declaration_size(declaration: &ReferencedDeclaration, style: SnippetStyle) -> usize {
504 match style {
505 SnippetStyle::Signature => declaration.signature_range.len(),
506 SnippetStyle::Declaration => declaration.text.len(),
507 }
508}