1//! Zeta2 prompt planning and generation code shared with cloud.
2
3use anyhow::{Context as _, Result, anyhow};
4use cloud_llm_client::predict_edits_v3::{self, Event, PromptFormat, ReferencedDeclaration};
5use indoc::indoc;
6use ordered_float::OrderedFloat;
7use rustc_hash::{FxHashMap, FxHashSet};
8use serde::Serialize;
9use std::fmt::Write;
10use std::sync::Arc;
11use std::{cmp::Reverse, collections::BinaryHeap, ops::Range, path::Path};
12use strum::{EnumIter, IntoEnumIterator};
13
/// Prompt size budget, in bytes, used when the request does not set `prompt_max_bytes`.
pub const DEFAULT_MAX_PROMPT_BYTES: usize = 10 * 1024;

/// Marker inserted into the excerpt at the cursor position (`MarkedExcerpt` and
/// `LabeledSections` formats).
pub const CURSOR_MARKER: &str = "<|cursor_position|>";
/// Inserted at the start of the excerpt in `MarkedExcerpt` prompts.
///
/// NOTE: Differs from zed version of constant - includes a newline
pub const EDITABLE_REGION_START_MARKER_WITH_NEWLINE: &str = "<|editable_region_start|>\n";
/// Inserted at the end of the excerpt in `MarkedExcerpt` prompts.
///
/// NOTE: Differs from zed version of constant - includes a newline
pub const EDITABLE_REGION_END_MARKER_WITH_NEWLINE: &str = "<|editable_region_end|>\n";
21
// TODO: use constants for markers?
/// System prompt returned by `system_prompt` for `PromptFormat::MarkedExcerpt`.
const MARKED_EXCERPT_SYSTEM_PROMPT: &str = indoc! {"
    You are a code completion assistant and your task is to analyze user edits and then rewrite an excerpt that the user provides, suggesting the appropriate edits within the excerpt, taking into account the cursor location.

    The excerpt to edit will be wrapped in markers <|editable_region_start|> and <|editable_region_end|>. The cursor position is marked with <|cursor_position|>. Please respond with edited code for that region.

    Other code is provided for context, and `…` indicates when code has been skipped.
"};
30
/// System prompt returned by `system_prompt` for `PromptFormat::LabeledSections`.
const LABELED_SECTIONS_SYSTEM_PROMPT: &str = indoc! {r#"
    You are a code completion assistant and your task is to analyze user edits, and suggest an edit to one of the provided sections of code.

    Sections of code are grouped by file and then labeled by `<|section_N|>` (e.g `<|section_8|>`).

    The cursor position is marked with `<|cursor_position|>` and it will appear within a special section labeled `<|current_section|>`. Prefer editing the current section until no more changes are needed within it.

    Respond ONLY with the name of the section to edit on a single line, followed by all of the code that should replace that section. For example:

    <|current_section|>
    for i in 0..16 {
        println!("{i}");
    }
"#};
45
/// The set of context snippets chosen for a single request's prompt.
///
/// Built by `PlannedPrompt::populate` and rendered by `PlannedPrompt::to_prompt_string`.
pub struct PlannedPrompt<'a> {
    /// The request this plan was built from; snippet text is borrowed from it.
    request: &'a predict_edits_v3::PredictEditsRequest,
    /// Snippets to include in the prompt. These may overlap - they are merged / deduplicated in
    /// `to_prompt_string`.
    snippets: Vec<PlannedSnippet<'a>>,
    /// Bytes of the prompt budget consumed so far. Starts at the excerpt length and is compared
    /// against the request's `prompt_max_bytes` (or `DEFAULT_MAX_PROMPT_BYTES`).
    budget_used: usize,
}
53
54pub fn system_prompt(format: PromptFormat) -> &'static str {
55 match format {
56 PromptFormat::MarkedExcerpt => MARKED_EXCERPT_SYSTEM_PROMPT,
57 PromptFormat::LabeledSections => LABELED_SECTIONS_SYSTEM_PROMPT,
58 // only intended for use via zeta_cli
59 PromptFormat::OnlySnippets => "",
60 }
61}
62
/// A piece of source text selected for the prompt, tagged with the file and range it came from.
#[derive(Clone, Debug)]
pub struct PlannedSnippet<'a> {
    /// File the snippet belongs to; snippets are grouped by this path when rendering.
    path: Arc<Path>,
    /// Offset range of the snippet within `path`; used to sort and merge overlapping snippets.
    range: Range<usize>,
    /// The snippet's text, borrowed from the request (excerpt, declaration or signature text).
    text: &'a str,
    /// Copied from the source declaration/signature's `text_is_truncated` flag (always `false`
    /// for the excerpt). Not yet used when rendering, hence `allow(dead_code)`.
    // TODO: Indicate this in the output
    #[allow(dead_code)]
    text_is_truncated: bool,
}
72
/// How a referenced declaration is rendered into the prompt.
///
/// NOTE: Variant order is significant. `EnumIter` yields variants in declaration order, and
/// `PlannedPrompt::populate` picks the style with the highest score density via
/// `Iterator::max_by_key`, which returns the last maximum on ties, so `Declaration` wins ties.
#[derive(EnumIter, Clone, Copy, PartialEq, Eq, Hash, Debug, PartialOrd, Ord)]
pub enum DeclarationStyle {
    /// Only the declaration's signature (`signature_range` within its text).
    Signature,
    /// The declaration's full text.
    Declaration,
}
78
/// Describes the sections emitted by `PlannedPrompt::to_prompt_string`.
#[derive(Clone, Debug, Serialize)]
pub struct SectionLabels {
    /// Index into `section_ranges` of the section in which an excerpt marker (cursor / editable
    /// region) was placed. Always 0 for `PromptFormat::OnlySnippets`.
    pub excerpt_index: usize,
    /// File path and offset range of each rendered section, in output order.
    pub section_ranges: Vec<(Arc<Path>, Range<usize>)>,
}
84
85impl<'a> PlannedPrompt<'a> {
86 /// Greedy one-pass knapsack algorithm to populate the prompt plan. Does the following:
87 ///
88 /// Initializes a priority queue by populating it with each snippet, finding the
89 /// DeclarationStyle that minimizes `score_density = score / snippet.range(style).len()`. When a
90 /// "signature" snippet is popped, insert an entry for the "declaration" variant that reflects
91 /// the cost of upgrade.
92 ///
93 /// TODO: Implement an early halting condition. One option might be to have another priority
94 /// queue where the score is the size, and update it accordingly. Another option might be to
95 /// have some simpler heuristic like bailing after N failed insertions, or based on how much
96 /// budget is left.
97 ///
98 /// TODO: Has the current known sources of imprecision:
99 ///
100 /// * Does not consider snippet overlap when ranking. For example, it might add a field to the
101 /// plan even though the containing struct is already included.
102 ///
103 /// * Does not consider cost of signatures when ranking snippets - this is tricky since
104 /// signatures may be shared by multiple snippets.
105 ///
106 /// * Does not include file paths / other text when considering max_bytes.
107 pub fn populate(request: &'a predict_edits_v3::PredictEditsRequest) -> Result<Self> {
108 let mut this = PlannedPrompt {
109 request,
110 snippets: Vec::new(),
111 budget_used: request.excerpt.len(),
112 };
113 let mut included_parents = FxHashSet::default();
114 let additional_parents = this.additional_parent_signatures(
115 &request.excerpt_path,
116 request.excerpt_parent,
117 &included_parents,
118 )?;
119 this.add_parents(&mut included_parents, additional_parents);
120
121 let max_bytes = request.prompt_max_bytes.unwrap_or(DEFAULT_MAX_PROMPT_BYTES);
122
123 if this.budget_used > max_bytes {
124 return Err(anyhow!(
125 "Excerpt + signatures size of {} already exceeds budget of {}",
126 this.budget_used,
127 max_bytes
128 ));
129 }
130
131 #[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord)]
132 struct QueueEntry {
133 score_density: OrderedFloat<f32>,
134 declaration_index: usize,
135 style: DeclarationStyle,
136 }
137
138 // Initialize priority queue with the best score for each snippet.
139 let mut queue: BinaryHeap<QueueEntry> = BinaryHeap::new();
140 for (declaration_index, declaration) in request.referenced_declarations.iter().enumerate() {
141 let (style, score_density) = DeclarationStyle::iter()
142 .map(|style| {
143 (
144 style,
145 OrderedFloat(declaration_score_density(&declaration, style)),
146 )
147 })
148 .max_by_key(|(_, score_density)| *score_density)
149 .unwrap();
150 queue.push(QueueEntry {
151 score_density,
152 declaration_index,
153 style,
154 });
155 }
156
157 // Knapsack selection loop
158 while let Some(queue_entry) = queue.pop() {
159 let Some(declaration) = request
160 .referenced_declarations
161 .get(queue_entry.declaration_index)
162 else {
163 return Err(anyhow!(
164 "Invalid declaration index {}",
165 queue_entry.declaration_index
166 ));
167 };
168
169 let mut additional_bytes = declaration_size(declaration, queue_entry.style);
170 if this.budget_used + additional_bytes > max_bytes {
171 continue;
172 }
173
174 let additional_parents = this.additional_parent_signatures(
175 &declaration.path,
176 declaration.parent_index,
177 &mut included_parents,
178 )?;
179 additional_bytes += additional_parents
180 .iter()
181 .map(|(_, snippet)| snippet.text.len())
182 .sum::<usize>();
183 if this.budget_used + additional_bytes > max_bytes {
184 continue;
185 }
186
187 this.budget_used += additional_bytes;
188 this.add_parents(&mut included_parents, additional_parents);
189 let planned_snippet = match queue_entry.style {
190 DeclarationStyle::Signature => {
191 let Some(text) = declaration.text.get(declaration.signature_range.clone())
192 else {
193 return Err(anyhow!(
194 "Invalid declaration signature_range {:?} with text.len() = {}",
195 declaration.signature_range,
196 declaration.text.len()
197 ));
198 };
199 PlannedSnippet {
200 path: declaration.path.clone(),
201 range: (declaration.signature_range.start + declaration.range.start)
202 ..(declaration.signature_range.end + declaration.range.start),
203 text,
204 text_is_truncated: declaration.text_is_truncated,
205 }
206 }
207 DeclarationStyle::Declaration => PlannedSnippet {
208 path: declaration.path.clone(),
209 range: declaration.range.clone(),
210 text: &declaration.text,
211 text_is_truncated: declaration.text_is_truncated,
212 },
213 };
214 this.snippets.push(planned_snippet);
215
216 // When a Signature is consumed, insert an entry for Definition style.
217 if queue_entry.style == DeclarationStyle::Signature {
218 let signature_size = declaration_size(&declaration, DeclarationStyle::Signature);
219 let declaration_size =
220 declaration_size(&declaration, DeclarationStyle::Declaration);
221 let signature_score = declaration_score(&declaration, DeclarationStyle::Signature);
222 let declaration_score =
223 declaration_score(&declaration, DeclarationStyle::Declaration);
224
225 let score_diff = declaration_score - signature_score;
226 let size_diff = declaration_size.saturating_sub(signature_size);
227 if score_diff > 0.0001 && size_diff > 0 {
228 queue.push(QueueEntry {
229 declaration_index: queue_entry.declaration_index,
230 score_density: OrderedFloat(score_diff / (size_diff as f32)),
231 style: DeclarationStyle::Declaration,
232 });
233 }
234 }
235 }
236
237 anyhow::Ok(this)
238 }
239
240 fn add_parents(
241 &mut self,
242 included_parents: &mut FxHashSet<usize>,
243 snippets: Vec<(usize, PlannedSnippet<'a>)>,
244 ) {
245 for (parent_index, snippet) in snippets {
246 included_parents.insert(parent_index);
247 self.budget_used += snippet.text.len();
248 self.snippets.push(snippet);
249 }
250 }
251
252 fn additional_parent_signatures(
253 &self,
254 path: &Arc<Path>,
255 parent_index: Option<usize>,
256 included_parents: &FxHashSet<usize>,
257 ) -> Result<Vec<(usize, PlannedSnippet<'a>)>> {
258 let mut results = Vec::new();
259 self.additional_parent_signatures_impl(path, parent_index, included_parents, &mut results)?;
260 Ok(results)
261 }
262
263 fn additional_parent_signatures_impl(
264 &self,
265 path: &Arc<Path>,
266 parent_index: Option<usize>,
267 included_parents: &FxHashSet<usize>,
268 results: &mut Vec<(usize, PlannedSnippet<'a>)>,
269 ) -> Result<()> {
270 let Some(parent_index) = parent_index else {
271 return Ok(());
272 };
273 if included_parents.contains(&parent_index) {
274 return Ok(());
275 }
276 let Some(parent_signature) = self.request.signatures.get(parent_index) else {
277 return Err(anyhow!("Invalid parent index {}", parent_index));
278 };
279 results.push((
280 parent_index,
281 PlannedSnippet {
282 path: path.clone(),
283 range: parent_signature.range.clone(),
284 text: &parent_signature.text,
285 text_is_truncated: parent_signature.text_is_truncated,
286 },
287 ));
288 self.additional_parent_signatures_impl(
289 path,
290 parent_signature.parent_index,
291 included_parents,
292 results,
293 )
294 }
295
296 /// Renders the planned context. Each file starts with "```FILE_PATH\n` and ends with triple
297 /// backticks, with a newline after each file. Outputs a line with "..." between nonconsecutive
298 /// chunks.
299 pub fn to_prompt_string(&'a self) -> Result<(String, SectionLabels)> {
300 let mut file_to_snippets: FxHashMap<&'a std::path::Path, Vec<&PlannedSnippet<'a>>> =
301 FxHashMap::default();
302 for snippet in &self.snippets {
303 file_to_snippets
304 .entry(&snippet.path)
305 .or_default()
306 .push(snippet);
307 }
308
309 // Reorder so that file with cursor comes last
310 let mut file_snippets = Vec::new();
311 let mut excerpt_file_snippets = Vec::new();
312 for (file_path, snippets) in file_to_snippets {
313 if file_path == self.request.excerpt_path.as_ref() {
314 excerpt_file_snippets = snippets;
315 } else {
316 file_snippets.push((file_path, snippets, false));
317 }
318 }
319 let excerpt_snippet = PlannedSnippet {
320 path: self.request.excerpt_path.clone(),
321 range: self.request.excerpt_range.clone(),
322 text: &self.request.excerpt,
323 text_is_truncated: false,
324 };
325 excerpt_file_snippets.push(&excerpt_snippet);
326 file_snippets.push((&self.request.excerpt_path, excerpt_file_snippets, true));
327
328 let mut excerpt_file_insertions = match self.request.prompt_format {
329 PromptFormat::MarkedExcerpt => vec![
330 (
331 self.request.excerpt_range.start,
332 EDITABLE_REGION_START_MARKER_WITH_NEWLINE,
333 ),
334 (
335 self.request.excerpt_range.start + self.request.cursor_offset,
336 CURSOR_MARKER,
337 ),
338 (
339 self.request
340 .excerpt_range
341 .end
342 .saturating_sub(0)
343 .max(self.request.excerpt_range.start),
344 EDITABLE_REGION_END_MARKER_WITH_NEWLINE,
345 ),
346 ],
347 PromptFormat::LabeledSections => vec![(
348 self.request.excerpt_range.start + self.request.cursor_offset,
349 CURSOR_MARKER,
350 )],
351 PromptFormat::OnlySnippets => vec![],
352 };
353
354 let mut prompt = String::new();
355 prompt.push_str("## User Edits\n\n");
356 Self::push_events(&mut prompt, &self.request.events);
357
358 prompt.push_str("\n## Code\n\n");
359 let section_labels =
360 self.push_file_snippets(&mut prompt, &mut excerpt_file_insertions, file_snippets)?;
361 Ok((prompt, section_labels))
362 }
363
364 fn push_events(output: &mut String, events: &[predict_edits_v3::Event]) {
365 for event in events {
366 match event {
367 Event::BufferChange {
368 path,
369 old_path,
370 diff,
371 predicted,
372 } => {
373 if let Some(old_path) = &old_path
374 && let Some(new_path) = &path
375 {
376 if old_path != new_path {
377 writeln!(
378 output,
379 "User renamed {} to {}\n\n",
380 old_path.display(),
381 new_path.display()
382 )
383 .unwrap();
384 }
385 }
386
387 let path = path
388 .as_ref()
389 .map_or_else(|| "untitled".to_string(), |path| path.display().to_string());
390
391 if *predicted {
392 writeln!(
393 output,
394 "User accepted prediction {:?}:\n```diff\n{}\n```\n",
395 path, diff
396 )
397 .unwrap();
398 } else {
399 writeln!(output, "User edited {:?}:\n```diff\n{}\n```\n", path, diff)
400 .unwrap();
401 }
402 }
403 }
404 }
405 }
406
407 fn push_file_snippets(
408 &self,
409 output: &mut String,
410 excerpt_file_insertions: &mut Vec<(usize, &'static str)>,
411 file_snippets: Vec<(&'a Path, Vec<&'a PlannedSnippet>, bool)>,
412 ) -> Result<SectionLabels> {
413 let mut section_ranges = Vec::new();
414 let mut excerpt_index = None;
415
416 for (file_path, mut snippets, is_excerpt_file) in file_snippets {
417 snippets.sort_by_key(|s| (s.range.start, Reverse(s.range.end)));
418
419 // TODO: What if the snippets get expanded too large to be editable?
420 let mut current_snippet: Option<(&PlannedSnippet, Range<usize>)> = None;
421 let mut disjoint_snippets: Vec<(&PlannedSnippet, Range<usize>)> = Vec::new();
422 for snippet in snippets {
423 if let Some((_, current_snippet_range)) = current_snippet.as_mut()
424 && snippet.range.start < current_snippet_range.end
425 {
426 if snippet.range.end > current_snippet_range.end {
427 current_snippet_range.end = snippet.range.end;
428 }
429 continue;
430 }
431 if let Some(current_snippet) = current_snippet.take() {
432 disjoint_snippets.push(current_snippet);
433 }
434 current_snippet = Some((snippet, snippet.range.clone()));
435 }
436 if let Some(current_snippet) = current_snippet.take() {
437 disjoint_snippets.push(current_snippet);
438 }
439
440 writeln!(output, "```{}", file_path.display()).ok();
441 let mut skipped_last_snippet = false;
442 for (snippet, range) in disjoint_snippets {
443 let section_index = section_ranges.len();
444
445 match self.request.prompt_format {
446 PromptFormat::MarkedExcerpt | PromptFormat::OnlySnippets => {
447 if range.start > 0 && !skipped_last_snippet {
448 output.push_str("…\n");
449 }
450 }
451 PromptFormat::LabeledSections => {
452 if is_excerpt_file
453 && range.start <= self.request.excerpt_range.start
454 && range.end >= self.request.excerpt_range.end
455 {
456 writeln!(output, "<|current_section|>").ok();
457 } else {
458 writeln!(output, "<|section_{}|>", section_index).ok();
459 }
460 }
461 }
462
463 if is_excerpt_file {
464 if self.request.prompt_format == PromptFormat::OnlySnippets {
465 if range.start >= self.request.excerpt_range.start
466 && range.end <= self.request.excerpt_range.end
467 {
468 skipped_last_snippet = true;
469 } else {
470 skipped_last_snippet = false;
471 output.push_str(snippet.text);
472 }
473 } else {
474 let mut last_offset = range.start;
475 let mut i = 0;
476 while i < excerpt_file_insertions.len() {
477 let (offset, insertion) = &excerpt_file_insertions[i];
478 let found = *offset >= range.start && *offset <= range.end;
479 if found {
480 excerpt_index = Some(section_index);
481 output.push_str(
482 &snippet.text[last_offset - range.start..offset - range.start],
483 );
484 output.push_str(insertion);
485 last_offset = *offset;
486 excerpt_file_insertions.remove(i);
487 continue;
488 }
489 i += 1;
490 }
491 skipped_last_snippet = false;
492 output.push_str(&snippet.text[last_offset - range.start..]);
493 }
494 } else {
495 skipped_last_snippet = false;
496 output.push_str(snippet.text);
497 }
498
499 section_ranges.push((snippet.path.clone(), range));
500 }
501
502 output.push_str("```\n\n");
503 }
504
505 Ok(SectionLabels {
506 // TODO: Clean this up
507 excerpt_index: match self.request.prompt_format {
508 PromptFormat::OnlySnippets => 0,
509 _ => excerpt_index.context("bug: no snippet found for excerpt")?,
510 },
511 section_ranges,
512 })
513 }
514}
515
516fn declaration_score_density(declaration: &ReferencedDeclaration, style: DeclarationStyle) -> f32 {
517 declaration_score(declaration, style) / declaration_size(declaration, style) as f32
518}
519
520fn declaration_score(declaration: &ReferencedDeclaration, style: DeclarationStyle) -> f32 {
521 match style {
522 DeclarationStyle::Signature => declaration.signature_score,
523 DeclarationStyle::Declaration => declaration.declaration_score,
524 }
525}
526
527fn declaration_size(declaration: &ReferencedDeclaration, style: DeclarationStyle) -> usize {
528 match style {
529 DeclarationStyle::Signature => declaration.signature_range.len(),
530 DeclarationStyle::Declaration => declaration.text.len(),
531 }
532}