1//! Zeta2 prompt planning and generation code shared with cloud.
2
3use anyhow::{Context as _, Result, anyhow};
4use cloud_llm_client::predict_edits_v3::{self, Line, Point, PromptFormat, ReferencedDeclaration};
5use indoc::indoc;
6use ordered_float::OrderedFloat;
7use rustc_hash::{FxHashMap, FxHashSet};
8use serde::Serialize;
9use std::fmt::Write;
10use std::sync::Arc;
11use std::{cmp::Reverse, collections::BinaryHeap, ops::Range, path::Path};
12use strum::{EnumIter, IntoEnumIterator};
13
/// Byte budget used by the prompt planner when the request does not set `prompt_max_bytes`
/// (10 KiB).
pub const DEFAULT_MAX_PROMPT_BYTES: usize = 10_240;
15
/// Marker spliced into the excerpt file at the user's cursor position.
pub const CURSOR_MARKER: &str = "<|user_cursor|>";
/// NOTE: Differs from zed version of constant - includes a newline
pub const EDITABLE_REGION_START_MARKER_WITH_NEWLINE: &str =
    concat!("<|editable_region_start|>", "\n");
/// NOTE: Differs from zed version of constant - includes a newline
pub const EDITABLE_REGION_END_MARKER_WITH_NEWLINE: &str =
    concat!("<|editable_region_end|>", "\n");
21
// TODO: use constants for markers?
/// Prompt preamble for `PromptFormat::MarkedExcerpt`. It ends with the "# Edit History:" heading,
/// after which `to_prompt_string` appends the edit history and the code.
const MARKED_EXCERPT_INSTRUCTIONS: &str = indoc! {"
    You are a code completion assistant and your task is to analyze user edits and then rewrite an excerpt that the user provides, suggesting the appropriate edits within the excerpt, taking into account the cursor location.

    The excerpt to edit will be wrapped in markers <|editable_region_start|> and <|editable_region_end|>. The cursor position is marked with <|user_cursor|>. Please respond with edited code for that region.

    Other code is provided for context, and `…` indicates when code has been skipped.

    # Edit History:

"};
33
/// Prompt preamble for `PromptFormat::LabeledSections`. It ends with the "# Edit History:"
/// heading, after which `to_prompt_string` appends the edit history and the labeled sections.
const LABELED_SECTIONS_INSTRUCTIONS: &str = indoc! {r#"
    You are a code completion assistant and your task is to analyze user edits, and suggest an edit to one of the provided sections of code.

    Sections of code are grouped by file and then labeled by `<|section_N|>` (e.g `<|section_8|>`).

    The cursor position is marked with `<|user_cursor|>` and it will appear within a special section labeled `<|current_section|>`. Prefer editing the current section until no more changes are needed within it.

    Respond ONLY with the name of the section to edit on a single line, followed by all of the code that should replace that section. For example:

    <|current_section|>
    for i in 0..16 {
        println!("{i}");
    }

    # Edit History:

"#};
51
/// Prompt preamble for `PromptFormat::NumLinesUniDiff`. It ends with the "# Edit History:"
/// heading, after which `to_prompt_string` appends the edit history, the line-numbered code and
/// `UNIFIED_DIFF_REMINDER`.
const NUMBERED_LINES_INSTRUCTIONS: &str = indoc! {r#"
    # Instructions

    You are a code completion assistant helping a programmer finish their work. Your task is to:

    1. Analyze the edit history to understand what the programmer is trying to achieve
    2. Identify any incomplete refactoring or changes that need to be finished
    3. Make the remaining edits that a human programmer would logically make next
    4. Apply systematic changes consistently across the entire codebase - if you see a pattern starting, complete it everywhere.

    Focus on:
    - Understanding the intent behind the changes (e.g., improving error handling, refactoring APIs, fixing bugs)
    - Completing any partially-applied changes across the codebase
    - Ensuring consistency with the programming style and patterns already established
    - Making edits that maintain or improve code quality
    - If the programmer started refactoring one instance of a pattern, find and update ALL similar instances
    - Don't write a lot of code if you're not sure what to do

    Rules:
    - Do not just mechanically apply patterns - reason about what changes make sense given the context and the programmer's apparent goals.
    - Do not just fix syntax errors - look for the broader refactoring pattern and apply it systematically throughout the code.
    - Write the edits in the unified diff format as shown in the example.

    # Example output:

    ```
    --- a/src/myapp/cli.py
    +++ b/src/myapp/cli.py
    @@ -1,3 +1,3 @@
    -
    -
    -import sys
    +import json
    ```

    # Edit History:

"#};
90
/// Trailer appended after the code for `PromptFormat::NumLinesUniDiff`, restating the expected
/// unified-diff output format.
const UNIFIED_DIFF_REMINDER: &str = indoc! {"
    ---

    Please analyze the edit history and the files, then provide the unified diff for your predicted edits.
    Do not include the cursor marker in your output.
    If you're editing multiple files, be sure to reflect filename in the hunk's header.
"};
98
99pub struct PlannedPrompt<'a> {
100 request: &'a predict_edits_v3::PredictEditsRequest,
101 /// Snippets to include in the prompt. These may overlap - they are merged / deduplicated in
102 /// `to_prompt_string`.
103 snippets: Vec<PlannedSnippet<'a>>,
104 budget_used: usize,
105}
106
/// A piece of source text selected for the prompt. It records the file it belongs to and the
/// lines it covers.
#[derive(Clone, Debug)]
pub struct PlannedSnippet<'a> {
    /// File the snippet belongs to. Snippets are grouped by this path when rendered.
    path: Arc<Path>,
    /// Lines covered by the snippet. Overlapping ranges within one file are merged when rendered.
    range: Range<Line>,
    /// Snippet text, borrowed from the request.
    text: &'a str,
    // TODO: Indicate this in the output
    // Not read yet; `dead_code` is allowed until truncation is surfaced in the rendered prompt.
    #[allow(dead_code)]
    text_is_truncated: bool,
}
116
/// How much of a referenced declaration to include in the prompt.
///
/// Variant order is significant. It defines the derived `Ord` and the `EnumIter` order used when
/// the planner picks the densest style for each declaration.
#[derive(EnumIter, Clone, Copy, PartialEq, Eq, Hash, Debug, PartialOrd, Ord)]
pub enum DeclarationStyle {
    /// Only the declaration's signature (the `signature_range` slice of its text).
    Signature,
    /// The declaration's full text.
    Declaration,
}
122
/// Describes the sections emitted by `PlannedPrompt::to_prompt_string`, in render order.
#[derive(Clone, Debug, Serialize)]
pub struct SectionLabels {
    /// Index into `section_ranges` of the section that received the excerpt insertions. It is
    /// always 0 for `PromptFormat::OnlySnippets`.
    pub excerpt_index: usize,
    /// File and merged line range of each rendered section. A section's position in this list is
    /// the `N` used in its `<|section_N|>` label.
    pub section_ranges: Vec<(Arc<Path>, Range<Line>)>,
}
128
129impl<'a> PlannedPrompt<'a> {
130 /// Greedy one-pass knapsack algorithm to populate the prompt plan. Does the following:
131 ///
132 /// Initializes a priority queue by populating it with each snippet, finding the
133 /// DeclarationStyle that minimizes `score_density = score / snippet.range(style).len()`. When a
134 /// "signature" snippet is popped, insert an entry for the "declaration" variant that reflects
135 /// the cost of upgrade.
136 ///
137 /// TODO: Implement an early halting condition. One option might be to have another priority
138 /// queue where the score is the size, and update it accordingly. Another option might be to
139 /// have some simpler heuristic like bailing after N failed insertions, or based on how much
140 /// budget is left.
141 ///
142 /// TODO: Has the current known sources of imprecision:
143 ///
144 /// * Does not consider snippet overlap when ranking. For example, it might add a field to the
145 /// plan even though the containing struct is already included.
146 ///
147 /// * Does not consider cost of signatures when ranking snippets - this is tricky since
148 /// signatures may be shared by multiple snippets.
149 ///
150 /// * Does not include file paths / other text when considering max_bytes.
151 pub fn populate(request: &'a predict_edits_v3::PredictEditsRequest) -> Result<Self> {
152 let mut this = PlannedPrompt {
153 request,
154 snippets: Vec::new(),
155 budget_used: request.excerpt.len(),
156 };
157 let mut included_parents = FxHashSet::default();
158 let additional_parents = this.additional_parent_signatures(
159 &request.excerpt_path,
160 request.excerpt_parent,
161 &included_parents,
162 )?;
163 this.add_parents(&mut included_parents, additional_parents);
164
165 let max_bytes = request.prompt_max_bytes.unwrap_or(DEFAULT_MAX_PROMPT_BYTES);
166
167 if this.budget_used > max_bytes {
168 return Err(anyhow!(
169 "Excerpt + signatures size of {} already exceeds budget of {}",
170 this.budget_used,
171 max_bytes
172 ));
173 }
174
175 #[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord)]
176 struct QueueEntry {
177 score_density: OrderedFloat<f32>,
178 declaration_index: usize,
179 style: DeclarationStyle,
180 }
181
182 // Initialize priority queue with the best score for each snippet.
183 let mut queue: BinaryHeap<QueueEntry> = BinaryHeap::new();
184 for (declaration_index, declaration) in request.referenced_declarations.iter().enumerate() {
185 let (style, score_density) = DeclarationStyle::iter()
186 .map(|style| {
187 (
188 style,
189 OrderedFloat(declaration_score_density(&declaration, style)),
190 )
191 })
192 .max_by_key(|(_, score_density)| *score_density)
193 .unwrap();
194 queue.push(QueueEntry {
195 score_density,
196 declaration_index,
197 style,
198 });
199 }
200
201 // Knapsack selection loop
202 while let Some(queue_entry) = queue.pop() {
203 let Some(declaration) = request
204 .referenced_declarations
205 .get(queue_entry.declaration_index)
206 else {
207 return Err(anyhow!(
208 "Invalid declaration index {}",
209 queue_entry.declaration_index
210 ));
211 };
212
213 let mut additional_bytes = declaration_size(declaration, queue_entry.style);
214 if this.budget_used + additional_bytes > max_bytes {
215 continue;
216 }
217
218 let additional_parents = this.additional_parent_signatures(
219 &declaration.path,
220 declaration.parent_index,
221 &mut included_parents,
222 )?;
223 additional_bytes += additional_parents
224 .iter()
225 .map(|(_, snippet)| snippet.text.len())
226 .sum::<usize>();
227 if this.budget_used + additional_bytes > max_bytes {
228 continue;
229 }
230
231 this.budget_used += additional_bytes;
232 this.add_parents(&mut included_parents, additional_parents);
233 let planned_snippet = match queue_entry.style {
234 DeclarationStyle::Signature => {
235 let Some(text) = declaration.text.get(declaration.signature_range.clone())
236 else {
237 return Err(anyhow!(
238 "Invalid declaration signature_range {:?} with text.len() = {}",
239 declaration.signature_range,
240 declaration.text.len()
241 ));
242 };
243 let signature_start_line = declaration.range.start
244 + Line(
245 declaration.text[..declaration.signature_range.start]
246 .lines()
247 .count() as u32,
248 );
249 let signature_end_line = signature_start_line
250 + Line(
251 declaration.text
252 [declaration.signature_range.start..declaration.signature_range.end]
253 .lines()
254 .count() as u32,
255 );
256 let range = signature_start_line..signature_end_line;
257
258 PlannedSnippet {
259 path: declaration.path.clone(),
260 range,
261 text,
262 text_is_truncated: declaration.text_is_truncated,
263 }
264 }
265 DeclarationStyle::Declaration => PlannedSnippet {
266 path: declaration.path.clone(),
267 range: declaration.range.clone(),
268 text: &declaration.text,
269 text_is_truncated: declaration.text_is_truncated,
270 },
271 };
272 this.snippets.push(planned_snippet);
273
274 // When a Signature is consumed, insert an entry for Definition style.
275 if queue_entry.style == DeclarationStyle::Signature {
276 let signature_size = declaration_size(&declaration, DeclarationStyle::Signature);
277 let declaration_size =
278 declaration_size(&declaration, DeclarationStyle::Declaration);
279 let signature_score = declaration_score(&declaration, DeclarationStyle::Signature);
280 let declaration_score =
281 declaration_score(&declaration, DeclarationStyle::Declaration);
282
283 let score_diff = declaration_score - signature_score;
284 let size_diff = declaration_size.saturating_sub(signature_size);
285 if score_diff > 0.0001 && size_diff > 0 {
286 queue.push(QueueEntry {
287 declaration_index: queue_entry.declaration_index,
288 score_density: OrderedFloat(score_diff / (size_diff as f32)),
289 style: DeclarationStyle::Declaration,
290 });
291 }
292 }
293 }
294
295 anyhow::Ok(this)
296 }
297
298 fn add_parents(
299 &mut self,
300 included_parents: &mut FxHashSet<usize>,
301 snippets: Vec<(usize, PlannedSnippet<'a>)>,
302 ) {
303 for (parent_index, snippet) in snippets {
304 included_parents.insert(parent_index);
305 self.budget_used += snippet.text.len();
306 self.snippets.push(snippet);
307 }
308 }
309
310 fn additional_parent_signatures(
311 &self,
312 path: &Arc<Path>,
313 parent_index: Option<usize>,
314 included_parents: &FxHashSet<usize>,
315 ) -> Result<Vec<(usize, PlannedSnippet<'a>)>> {
316 let mut results = Vec::new();
317 self.additional_parent_signatures_impl(path, parent_index, included_parents, &mut results)?;
318 Ok(results)
319 }
320
321 fn additional_parent_signatures_impl(
322 &self,
323 path: &Arc<Path>,
324 parent_index: Option<usize>,
325 included_parents: &FxHashSet<usize>,
326 results: &mut Vec<(usize, PlannedSnippet<'a>)>,
327 ) -> Result<()> {
328 let Some(parent_index) = parent_index else {
329 return Ok(());
330 };
331 if included_parents.contains(&parent_index) {
332 return Ok(());
333 }
334 let Some(parent_signature) = self.request.signatures.get(parent_index) else {
335 return Err(anyhow!("Invalid parent index {}", parent_index));
336 };
337 results.push((
338 parent_index,
339 PlannedSnippet {
340 path: path.clone(),
341 range: parent_signature.range.clone(),
342 text: &parent_signature.text,
343 text_is_truncated: parent_signature.text_is_truncated,
344 },
345 ));
346 self.additional_parent_signatures_impl(
347 path,
348 parent_signature.parent_index,
349 included_parents,
350 results,
351 )
352 }
353
354 /// Renders the planned context. Each file starts with "```FILE_PATH\n` and ends with triple
355 /// backticks, with a newline after each file. Outputs a line with "..." between nonconsecutive
356 /// chunks.
357 pub fn to_prompt_string(&'a self) -> Result<(String, SectionLabels)> {
358 let mut file_to_snippets: FxHashMap<&'a std::path::Path, Vec<&PlannedSnippet<'a>>> =
359 FxHashMap::default();
360 for snippet in &self.snippets {
361 file_to_snippets
362 .entry(&snippet.path)
363 .or_default()
364 .push(snippet);
365 }
366
367 // Reorder so that file with cursor comes last
368 let mut file_snippets = Vec::new();
369 let mut excerpt_file_snippets = Vec::new();
370 for (file_path, snippets) in file_to_snippets {
371 if file_path == self.request.excerpt_path.as_ref() {
372 excerpt_file_snippets = snippets;
373 } else {
374 file_snippets.push((file_path, snippets, false));
375 }
376 }
377 let excerpt_snippet = PlannedSnippet {
378 path: self.request.excerpt_path.clone(),
379 range: self.request.excerpt_line_range.clone(),
380 text: &self.request.excerpt,
381 text_is_truncated: false,
382 };
383 excerpt_file_snippets.push(&excerpt_snippet);
384 file_snippets.push((&self.request.excerpt_path, excerpt_file_snippets, true));
385
386 let mut excerpt_file_insertions = match self.request.prompt_format {
387 PromptFormat::MarkedExcerpt => vec![
388 (
389 Point {
390 line: self.request.excerpt_line_range.start,
391 column: 0,
392 },
393 EDITABLE_REGION_START_MARKER_WITH_NEWLINE,
394 ),
395 (self.request.cursor_point, CURSOR_MARKER),
396 (
397 Point {
398 line: self.request.excerpt_line_range.end,
399 column: 0,
400 },
401 EDITABLE_REGION_END_MARKER_WITH_NEWLINE,
402 ),
403 ],
404 PromptFormat::LabeledSections => vec![(self.request.cursor_point, CURSOR_MARKER)],
405 PromptFormat::NumLinesUniDiff => {
406 vec![(self.request.cursor_point, CURSOR_MARKER)]
407 }
408 PromptFormat::OnlySnippets => vec![],
409 };
410
411 let mut prompt = match self.request.prompt_format {
412 PromptFormat::MarkedExcerpt => MARKED_EXCERPT_INSTRUCTIONS.to_string(),
413 PromptFormat::LabeledSections => LABELED_SECTIONS_INSTRUCTIONS.to_string(),
414 PromptFormat::NumLinesUniDiff => NUMBERED_LINES_INSTRUCTIONS.to_string(),
415 // only intended for use via zeta_cli
416 PromptFormat::OnlySnippets => String::new(),
417 };
418
419 if self.request.events.is_empty() {
420 prompt.push_str("(No edit history)\n\n");
421 } else {
422 prompt.push_str(
423 "The following are the latest edits made by the user, from earlier to later.\n\n",
424 );
425 Self::push_events(&mut prompt, &self.request.events);
426 }
427
428 if self.request.prompt_format == PromptFormat::NumLinesUniDiff {
429 if self.request.referenced_declarations.is_empty() {
430 prompt.push_str(indoc! {"
431 # File under the cursor:
432
433 The cursor marker <|user_cursor|> indicates the current user cursor position.
434 The file is in current state, edits from edit history have been applied.
435 We prepend line numbers (e.g., `123|<actual line>`); they are not part of the file.
436
437 "});
438 } else {
439 // Note: This hasn't been trained on yet
440 prompt.push_str(indoc! {"
441 # Code Excerpts:
442
443 The cursor marker <|user_cursor|> indicates the current user cursor position.
444 Other excerpts of code from the project have been included as context based on their similarity to the code under the cursor.
445 Context excerpts are not guaranteed to be relevant, so use your own judgement.
446 Files are in their current state, edits from edit history have been applied.
447 We prepend line numbers (e.g., `123|<actual line>`); they are not part of the file.
448
449 "});
450 }
451 } else {
452 prompt.push_str("\n## Code\n\n");
453 }
454
455 let section_labels =
456 self.push_file_snippets(&mut prompt, &mut excerpt_file_insertions, file_snippets)?;
457
458 if self.request.prompt_format == PromptFormat::NumLinesUniDiff {
459 prompt.push_str(UNIFIED_DIFF_REMINDER);
460 }
461
462 Ok((prompt, section_labels))
463 }
464
465 fn push_events(output: &mut String, events: &[predict_edits_v3::Event]) {
466 if events.is_empty() {
467 return;
468 };
469
470 writeln!(output, "`````diff").unwrap();
471 for event in events {
472 writeln!(output, "{}", event).unwrap();
473 }
474 writeln!(output, "`````\n").unwrap();
475 }
476
477 fn push_file_snippets(
478 &self,
479 output: &mut String,
480 excerpt_file_insertions: &mut Vec<(Point, &'static str)>,
481 file_snippets: Vec<(&'a Path, Vec<&'a PlannedSnippet>, bool)>,
482 ) -> Result<SectionLabels> {
483 let mut section_ranges = Vec::new();
484 let mut excerpt_index = None;
485
486 for (file_path, mut snippets, is_excerpt_file) in file_snippets {
487 snippets.sort_by_key(|s| (s.range.start, Reverse(s.range.end)));
488
489 // TODO: What if the snippets get expanded too large to be editable?
490 let mut current_snippet: Option<(&PlannedSnippet, Range<Line>)> = None;
491 let mut disjoint_snippets: Vec<(&PlannedSnippet, Range<Line>)> = Vec::new();
492 for snippet in snippets {
493 if let Some((_, current_snippet_range)) = current_snippet.as_mut()
494 && snippet.range.start <= current_snippet_range.end
495 {
496 current_snippet_range.end = current_snippet_range.end.max(snippet.range.end);
497 continue;
498 }
499 if let Some(current_snippet) = current_snippet.take() {
500 disjoint_snippets.push(current_snippet);
501 }
502 current_snippet = Some((snippet, snippet.range.clone()));
503 }
504 if let Some(current_snippet) = current_snippet.take() {
505 disjoint_snippets.push(current_snippet);
506 }
507
508 // TODO: remove filename=?
509 writeln!(output, "`````filename={}", file_path.display()).ok();
510 let mut skipped_last_snippet = false;
511 for (snippet, range) in disjoint_snippets {
512 let section_index = section_ranges.len();
513
514 match self.request.prompt_format {
515 PromptFormat::MarkedExcerpt
516 | PromptFormat::OnlySnippets
517 | PromptFormat::NumLinesUniDiff => {
518 if range.start.0 > 0 && !skipped_last_snippet {
519 output.push_str("…\n");
520 }
521 }
522 PromptFormat::LabeledSections => {
523 if is_excerpt_file
524 && range.start <= self.request.excerpt_line_range.start
525 && range.end >= self.request.excerpt_line_range.end
526 {
527 writeln!(output, "<|current_section|>").ok();
528 } else {
529 writeln!(output, "<|section_{}|>", section_index).ok();
530 }
531 }
532 }
533
534 let push_full_snippet = |output: &mut String| {
535 if self.request.prompt_format == PromptFormat::NumLinesUniDiff {
536 for (i, line) in snippet.text.lines().enumerate() {
537 writeln!(output, "{}|{}", i as u32 + range.start.0 + 1, line)?;
538 }
539 } else {
540 output.push_str(&snippet.text);
541 }
542 anyhow::Ok(())
543 };
544
545 if is_excerpt_file {
546 if self.request.prompt_format == PromptFormat::OnlySnippets {
547 if range.start >= self.request.excerpt_line_range.start
548 && range.end <= self.request.excerpt_line_range.end
549 {
550 skipped_last_snippet = true;
551 } else {
552 skipped_last_snippet = false;
553 output.push_str(snippet.text);
554 }
555 } else if !excerpt_file_insertions.is_empty() {
556 let lines = snippet.text.lines().collect::<Vec<_>>();
557 let push_line = |output: &mut String, line_ix: usize| {
558 if self.request.prompt_format == PromptFormat::NumLinesUniDiff {
559 write!(output, "{}|", line_ix as u32 + range.start.0 + 1)?;
560 }
561 anyhow::Ok(writeln!(output, "{}", lines[line_ix])?)
562 };
563 let mut last_line_ix = 0;
564 let mut insertion_ix = 0;
565 while insertion_ix < excerpt_file_insertions.len() {
566 let (point, insertion) = &excerpt_file_insertions[insertion_ix];
567 let found = point.line >= range.start && point.line <= range.end;
568 if found {
569 excerpt_index = Some(section_index);
570 let insertion_line_ix = (point.line.0 - range.start.0) as usize;
571 for line_ix in last_line_ix..insertion_line_ix {
572 push_line(output, line_ix)?;
573 }
574 if let Some(next_line) = lines.get(insertion_line_ix) {
575 if self.request.prompt_format == PromptFormat::NumLinesUniDiff {
576 write!(
577 output,
578 "{}|",
579 insertion_line_ix as u32 + range.start.0 + 1
580 )?
581 }
582 output.push_str(&next_line[..point.column as usize]);
583 output.push_str(insertion);
584 writeln!(output, "{}", &next_line[point.column as usize..])?;
585 } else {
586 writeln!(output, "{}", insertion)?;
587 }
588 last_line_ix = insertion_line_ix + 1;
589 excerpt_file_insertions.remove(insertion_ix);
590 continue;
591 }
592 insertion_ix += 1;
593 }
594 skipped_last_snippet = false;
595 for line_ix in last_line_ix..lines.len() {
596 push_line(output, line_ix)?;
597 }
598 } else {
599 skipped_last_snippet = false;
600 push_full_snippet(output)?;
601 }
602 } else {
603 skipped_last_snippet = false;
604 push_full_snippet(output)?;
605 }
606
607 section_ranges.push((snippet.path.clone(), range));
608 }
609
610 output.push_str("`````\n\n");
611 }
612
613 Ok(SectionLabels {
614 // TODO: Clean this up
615 excerpt_index: match self.request.prompt_format {
616 PromptFormat::OnlySnippets => 0,
617 _ => excerpt_index.context("bug: no snippet found for excerpt")?,
618 },
619 section_ranges,
620 })
621 }
622}
623
624fn declaration_score_density(declaration: &ReferencedDeclaration, style: DeclarationStyle) -> f32 {
625 declaration_score(declaration, style) / declaration_size(declaration, style) as f32
626}
627
628fn declaration_score(declaration: &ReferencedDeclaration, style: DeclarationStyle) -> f32 {
629 match style {
630 DeclarationStyle::Signature => declaration.signature_score,
631 DeclarationStyle::Declaration => declaration.declaration_score,
632 }
633}
634
635fn declaration_size(declaration: &ReferencedDeclaration, style: DeclarationStyle) -> usize {
636 match style {
637 DeclarationStyle::Signature => declaration.signature_range.len(),
638 DeclarationStyle::Declaration => declaration.text.len(),
639 }
640}