1//! Zeta2 prompt planning and generation code shared with cloud.
2
3use anyhow::{Context as _, Result, anyhow};
4use cloud_llm_client::predict_edits_v3::{
5 self, Event, Line, Point, PromptFormat, ReferencedDeclaration,
6};
7use indoc::indoc;
8use ordered_float::OrderedFloat;
9use rustc_hash::{FxHashMap, FxHashSet};
10use serde::Serialize;
11use std::fmt::Write;
12use std::sync::Arc;
13use std::{cmp::Reverse, collections::BinaryHeap, ops::Range, path::Path};
14use strum::{EnumIter, IntoEnumIterator};
15
/// Prompt byte budget used when the request does not set `prompt_max_bytes` (10 KiB).
pub const DEFAULT_MAX_PROMPT_BYTES: usize = 10_240;

/// Marker spliced into the prompt at the user's cursor position.
pub const CURSOR_MARKER: &str = "<|user_cursor|>";
/// Marker placed before the first line of the editable region.
///
/// NOTE: Differs from zed version of constant - includes a newline
pub const EDITABLE_REGION_START_MARKER_WITH_NEWLINE: &str = "<|editable_region_start|>\n";
/// Marker placed after the last line of the editable region.
///
/// NOTE: Differs from zed version of constant - includes a newline
pub const EDITABLE_REGION_END_MARKER_WITH_NEWLINE: &str = "<|editable_region_end|>\n";
23
// TODO: use constants for markers?
/// Instructions that open the prompt for `PromptFormat::MarkedExcerpt`, where the model rewrites
/// the region between the editable-region markers. The text ends with the heading under which
/// the edit history is appended.
const MARKED_EXCERPT_INSTRUCTIONS: &str = indoc! {"
    You are a code completion assistant and your task is to analyze user edits and then rewrite an excerpt that the user provides, suggesting the appropriate edits within the excerpt, taking into account the cursor location.

    The excerpt to edit will be wrapped in markers <|editable_region_start|> and <|editable_region_end|>. The cursor position is marked with <|user_cursor|>. Please respond with edited code for that region.

    Other code is provided for context, and `…` indicates when code has been skipped.

    # Edit History:

"};
35
/// Instructions that open the prompt for `PromptFormat::LabeledSections`, where the model picks a
/// labeled section and returns its replacement. The text ends with the heading under which the
/// edit history is appended.
const LABELED_SECTIONS_INSTRUCTIONS: &str = indoc! {r#"
    You are a code completion assistant and your task is to analyze user edits, and suggest an edit to one of the provided sections of code.

    Sections of code are grouped by file and then labeled by `<|section_N|>` (e.g `<|section_8|>`).

    The cursor position is marked with `<|user_cursor|>` and it will appear within a special section labeled `<|current_section|>`. Prefer editing the current section until no more changes are needed within it.

    Respond ONLY with the name of the section to edit on a single line, followed by all of the code that should replace that section. For example:

    <|current_section|>
    for i in 0..16 {
        println!("{i}");
    }

    # Edit History:

"#};
53
/// Instructions that open the prompt for `PromptFormat::NumLinesUniDiff`, where the model answers
/// with a unified diff against line-numbered files. The text ends with the heading under which
/// the edit history is appended; `UNIFIED_DIFF_REMINDER` closes this prompt format.
const NUMBERED_LINES_INSTRUCTIONS: &str = indoc! {r#"
    # Instructions

    You are a code completion assistant helping a programmer finish their work. Your task is to:

    1. Analyze the edit history to understand what the programmer is trying to achieve
    2. Identify any incomplete refactoring or changes that need to be finished
    3. Make the remaining edits that a human programmer would logically make next
    4. Apply systematic changes consistently across the entire codebase - if you see a pattern starting, complete it everywhere.

    Focus on:
    - Understanding the intent behind the changes (e.g., improving error handling, refactoring APIs, fixing bugs)
    - Completing any partially-applied changes across the codebase
    - Ensuring consistency with the programming style and patterns already established
    - Making edits that maintain or improve code quality
    - If the programmer started refactoring one instance of a pattern, find and update ALL similar instances
    - Don't write a lot of code if you're not sure what to do

    Rules:
    - Do not just mechanically apply patterns - reason about what changes make sense given the context and the programmer's apparent goals.
    - Do not just fix syntax errors - look for the broader refactoring pattern and apply it systematically throughout the code.
    - Write the edits in the unified diff format as shown in the example.

    # Example output:

    ```
    --- a/src/myapp/cli.py
    +++ b/src/myapp/cli.py
    @@ -1,3 +1,3 @@
    -
    -
    -import sys
    +import json
    ```

    # Edit History:

"#};
92
/// Closing reminder appended after the rendered files for `PromptFormat::NumLinesUniDiff`.
const UNIFIED_DIFF_REMINDER: &str = indoc! {"
    ---

    Please analyze the edit history and the files, then provide the unified diff for your predicted edits.
    Do not include the cursor marker in your output.
    If you're editing multiple files, be sure to reflect filename in the hunk's header.
"};
100
/// Plan of which snippets go into the prompt for one
/// [`predict_edits_v3::PredictEditsRequest`].
///
/// Built by [`PlannedPrompt::populate`] and rendered by [`PlannedPrompt::to_prompt_string`].
pub struct PlannedPrompt<'a> {
    /// The request being planned for. The excerpt, declarations and signatures are borrowed from
    /// it.
    request: &'a predict_edits_v3::PredictEditsRequest,
    /// Snippets to include in the prompt. These may overlap - they are merged / deduplicated in
    /// `to_prompt_string`.
    snippets: Vec<PlannedSnippet<'a>>,
    /// Bytes charged against the prompt budget so far. This covers the excerpt plus planned
    /// snippet text; file paths and instructions are not counted.
    budget_used: usize,
}
108
/// A piece of a file's text to include in the prompt, together with the lines it covers.
#[derive(Clone, Debug)]
pub struct PlannedSnippet<'a> {
    /// Path of the file the snippet comes from.
    path: Arc<Path>,
    /// Lines of the file covered by the snippet. `to_prompt_string` uses this range to merge
    /// overlapping snippets and to number lines. The end is presumably exclusive; TODO confirm.
    range: Range<Line>,
    /// The snippet's text, borrowed from the request.
    text: &'a str,
    /// Copied from the request's `text_is_truncated` flag for this text.
    // TODO: Indicate this in the output
    #[allow(dead_code)]
    text_is_truncated: bool,
}
118
/// How a referenced declaration is rendered in the prompt.
///
/// Note: the variant order is significant. The derived `Ord` takes part in ordering the
/// planner's priority-queue entries.
#[derive(EnumIter, Clone, Copy, PartialEq, Eq, Hash, Debug, PartialOrd, Ord)]
pub enum DeclarationStyle {
    /// Only the declaration's signature text (its `signature_range`).
    Signature,
    /// The declaration's full text.
    Declaration,
}
124
/// Describes the sections rendered by [`PlannedPrompt::to_prompt_string`].
#[derive(Clone, Debug, Serialize)]
pub struct SectionLabels {
    /// Index into `section_ranges` of the section that received the excerpt's insertions (cursor
    /// and region markers). It is 0 for `PromptFormat::OnlySnippets`, which has no insertions.
    pub excerpt_index: usize,
    /// File path and line range of each rendered section, in output order.
    pub section_ranges: Vec<(Arc<Path>, Range<Line>)>,
}
130
131impl<'a> PlannedPrompt<'a> {
132 /// Greedy one-pass knapsack algorithm to populate the prompt plan. Does the following:
133 ///
134 /// Initializes a priority queue by populating it with each snippet, finding the
135 /// DeclarationStyle that minimizes `score_density = score / snippet.range(style).len()`. When a
136 /// "signature" snippet is popped, insert an entry for the "declaration" variant that reflects
137 /// the cost of upgrade.
138 ///
139 /// TODO: Implement an early halting condition. One option might be to have another priority
140 /// queue where the score is the size, and update it accordingly. Another option might be to
141 /// have some simpler heuristic like bailing after N failed insertions, or based on how much
142 /// budget is left.
143 ///
144 /// TODO: Has the current known sources of imprecision:
145 ///
146 /// * Does not consider snippet overlap when ranking. For example, it might add a field to the
147 /// plan even though the containing struct is already included.
148 ///
149 /// * Does not consider cost of signatures when ranking snippets - this is tricky since
150 /// signatures may be shared by multiple snippets.
151 ///
152 /// * Does not include file paths / other text when considering max_bytes.
153 pub fn populate(request: &'a predict_edits_v3::PredictEditsRequest) -> Result<Self> {
154 let mut this = PlannedPrompt {
155 request,
156 snippets: Vec::new(),
157 budget_used: request.excerpt.len(),
158 };
159 let mut included_parents = FxHashSet::default();
160 let additional_parents = this.additional_parent_signatures(
161 &request.excerpt_path,
162 request.excerpt_parent,
163 &included_parents,
164 )?;
165 this.add_parents(&mut included_parents, additional_parents);
166
167 let max_bytes = request.prompt_max_bytes.unwrap_or(DEFAULT_MAX_PROMPT_BYTES);
168
169 if this.budget_used > max_bytes {
170 return Err(anyhow!(
171 "Excerpt + signatures size of {} already exceeds budget of {}",
172 this.budget_used,
173 max_bytes
174 ));
175 }
176
177 #[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord)]
178 struct QueueEntry {
179 score_density: OrderedFloat<f32>,
180 declaration_index: usize,
181 style: DeclarationStyle,
182 }
183
184 // Initialize priority queue with the best score for each snippet.
185 let mut queue: BinaryHeap<QueueEntry> = BinaryHeap::new();
186 for (declaration_index, declaration) in request.referenced_declarations.iter().enumerate() {
187 let (style, score_density) = DeclarationStyle::iter()
188 .map(|style| {
189 (
190 style,
191 OrderedFloat(declaration_score_density(&declaration, style)),
192 )
193 })
194 .max_by_key(|(_, score_density)| *score_density)
195 .unwrap();
196 queue.push(QueueEntry {
197 score_density,
198 declaration_index,
199 style,
200 });
201 }
202
203 // Knapsack selection loop
204 while let Some(queue_entry) = queue.pop() {
205 let Some(declaration) = request
206 .referenced_declarations
207 .get(queue_entry.declaration_index)
208 else {
209 return Err(anyhow!(
210 "Invalid declaration index {}",
211 queue_entry.declaration_index
212 ));
213 };
214
215 let mut additional_bytes = declaration_size(declaration, queue_entry.style);
216 if this.budget_used + additional_bytes > max_bytes {
217 continue;
218 }
219
220 let additional_parents = this.additional_parent_signatures(
221 &declaration.path,
222 declaration.parent_index,
223 &mut included_parents,
224 )?;
225 additional_bytes += additional_parents
226 .iter()
227 .map(|(_, snippet)| snippet.text.len())
228 .sum::<usize>();
229 if this.budget_used + additional_bytes > max_bytes {
230 continue;
231 }
232
233 this.budget_used += additional_bytes;
234 this.add_parents(&mut included_parents, additional_parents);
235 let planned_snippet = match queue_entry.style {
236 DeclarationStyle::Signature => {
237 let Some(text) = declaration.text.get(declaration.signature_range.clone())
238 else {
239 return Err(anyhow!(
240 "Invalid declaration signature_range {:?} with text.len() = {}",
241 declaration.signature_range,
242 declaration.text.len()
243 ));
244 };
245 let signature_start_line = declaration.range.start
246 + Line(
247 declaration.text[..declaration.signature_range.start]
248 .lines()
249 .count() as u32,
250 );
251 let signature_end_line = signature_start_line
252 + Line(
253 declaration.text
254 [declaration.signature_range.start..declaration.signature_range.end]
255 .lines()
256 .count() as u32,
257 );
258 let range = signature_start_line..signature_end_line;
259
260 PlannedSnippet {
261 path: declaration.path.clone(),
262 range,
263 text,
264 text_is_truncated: declaration.text_is_truncated,
265 }
266 }
267 DeclarationStyle::Declaration => PlannedSnippet {
268 path: declaration.path.clone(),
269 range: declaration.range.clone(),
270 text: &declaration.text,
271 text_is_truncated: declaration.text_is_truncated,
272 },
273 };
274 this.snippets.push(planned_snippet);
275
276 // When a Signature is consumed, insert an entry for Definition style.
277 if queue_entry.style == DeclarationStyle::Signature {
278 let signature_size = declaration_size(&declaration, DeclarationStyle::Signature);
279 let declaration_size =
280 declaration_size(&declaration, DeclarationStyle::Declaration);
281 let signature_score = declaration_score(&declaration, DeclarationStyle::Signature);
282 let declaration_score =
283 declaration_score(&declaration, DeclarationStyle::Declaration);
284
285 let score_diff = declaration_score - signature_score;
286 let size_diff = declaration_size.saturating_sub(signature_size);
287 if score_diff > 0.0001 && size_diff > 0 {
288 queue.push(QueueEntry {
289 declaration_index: queue_entry.declaration_index,
290 score_density: OrderedFloat(score_diff / (size_diff as f32)),
291 style: DeclarationStyle::Declaration,
292 });
293 }
294 }
295 }
296
297 anyhow::Ok(this)
298 }
299
300 fn add_parents(
301 &mut self,
302 included_parents: &mut FxHashSet<usize>,
303 snippets: Vec<(usize, PlannedSnippet<'a>)>,
304 ) {
305 for (parent_index, snippet) in snippets {
306 included_parents.insert(parent_index);
307 self.budget_used += snippet.text.len();
308 self.snippets.push(snippet);
309 }
310 }
311
312 fn additional_parent_signatures(
313 &self,
314 path: &Arc<Path>,
315 parent_index: Option<usize>,
316 included_parents: &FxHashSet<usize>,
317 ) -> Result<Vec<(usize, PlannedSnippet<'a>)>> {
318 let mut results = Vec::new();
319 self.additional_parent_signatures_impl(path, parent_index, included_parents, &mut results)?;
320 Ok(results)
321 }
322
323 fn additional_parent_signatures_impl(
324 &self,
325 path: &Arc<Path>,
326 parent_index: Option<usize>,
327 included_parents: &FxHashSet<usize>,
328 results: &mut Vec<(usize, PlannedSnippet<'a>)>,
329 ) -> Result<()> {
330 let Some(parent_index) = parent_index else {
331 return Ok(());
332 };
333 if included_parents.contains(&parent_index) {
334 return Ok(());
335 }
336 let Some(parent_signature) = self.request.signatures.get(parent_index) else {
337 return Err(anyhow!("Invalid parent index {}", parent_index));
338 };
339 results.push((
340 parent_index,
341 PlannedSnippet {
342 path: path.clone(),
343 range: parent_signature.range.clone(),
344 text: &parent_signature.text,
345 text_is_truncated: parent_signature.text_is_truncated,
346 },
347 ));
348 self.additional_parent_signatures_impl(
349 path,
350 parent_signature.parent_index,
351 included_parents,
352 results,
353 )
354 }
355
356 /// Renders the planned context. Each file starts with "```FILE_PATH\n` and ends with triple
357 /// backticks, with a newline after each file. Outputs a line with "..." between nonconsecutive
358 /// chunks.
359 pub fn to_prompt_string(&'a self) -> Result<(String, SectionLabels)> {
360 let mut file_to_snippets: FxHashMap<&'a std::path::Path, Vec<&PlannedSnippet<'a>>> =
361 FxHashMap::default();
362 for snippet in &self.snippets {
363 file_to_snippets
364 .entry(&snippet.path)
365 .or_default()
366 .push(snippet);
367 }
368
369 // Reorder so that file with cursor comes last
370 let mut file_snippets = Vec::new();
371 let mut excerpt_file_snippets = Vec::new();
372 for (file_path, snippets) in file_to_snippets {
373 if file_path == self.request.excerpt_path.as_ref() {
374 excerpt_file_snippets = snippets;
375 } else {
376 file_snippets.push((file_path, snippets, false));
377 }
378 }
379 let excerpt_snippet = PlannedSnippet {
380 path: self.request.excerpt_path.clone(),
381 range: self.request.excerpt_line_range.clone(),
382 text: &self.request.excerpt,
383 text_is_truncated: false,
384 };
385 excerpt_file_snippets.push(&excerpt_snippet);
386 file_snippets.push((&self.request.excerpt_path, excerpt_file_snippets, true));
387
388 let mut excerpt_file_insertions = match self.request.prompt_format {
389 PromptFormat::MarkedExcerpt => vec![
390 (
391 Point {
392 line: self.request.excerpt_line_range.start,
393 column: 0,
394 },
395 EDITABLE_REGION_START_MARKER_WITH_NEWLINE,
396 ),
397 (self.request.cursor_point, CURSOR_MARKER),
398 (
399 Point {
400 line: self.request.excerpt_line_range.end,
401 column: 0,
402 },
403 EDITABLE_REGION_END_MARKER_WITH_NEWLINE,
404 ),
405 ],
406 PromptFormat::LabeledSections => vec![(self.request.cursor_point, CURSOR_MARKER)],
407 PromptFormat::NumLinesUniDiff => {
408 vec![(self.request.cursor_point, CURSOR_MARKER)]
409 }
410 PromptFormat::OnlySnippets => vec![],
411 };
412
413 let mut prompt = match self.request.prompt_format {
414 PromptFormat::MarkedExcerpt => MARKED_EXCERPT_INSTRUCTIONS.to_string(),
415 PromptFormat::LabeledSections => LABELED_SECTIONS_INSTRUCTIONS.to_string(),
416 PromptFormat::NumLinesUniDiff => NUMBERED_LINES_INSTRUCTIONS.to_string(),
417 // only intended for use via zeta_cli
418 PromptFormat::OnlySnippets => String::new(),
419 };
420
421 if self.request.events.is_empty() {
422 prompt.push_str("No edits yet.\n\n");
423 } else {
424 prompt.push_str(
425 "The following are the latest edits made by the user, from earlier to later.\n\n",
426 );
427 Self::push_events(&mut prompt, &self.request.events);
428 }
429
430 if self.request.prompt_format == PromptFormat::NumLinesUniDiff {
431 if self.request.referenced_declarations.is_empty() {
432 prompt.push_str(indoc! {"
433 # File under the cursor:
434
435 The cursor marker <|user_cursor|> indicates the current user cursor position.
436 The file is in current state, edits from edit history have been applied.
437 We prepend line numbers (e.g., `123|<actual line>`); they are not part of the file.
438
439 "});
440 } else {
441 // Note: This hasn't been trained on yet
442 prompt.push_str(indoc! {"
443 # Code Excerpts:
444
445 The cursor marker <|user_cursor|> indicates the current user cursor position.
446 Other excerpts of code from the project have been included as context based on their similarity to the code under the cursor.
447 Context excerpts are not guaranteed to be relevant, so use your own judgement.
448 Files are in their current state, edits from edit history have been applied.
449 We prepend line numbers (e.g., `123|<actual line>`); they are not part of the file.
450
451 "});
452 }
453 } else {
454 prompt.push_str("\n## Code\n\n");
455 }
456
457 let section_labels =
458 self.push_file_snippets(&mut prompt, &mut excerpt_file_insertions, file_snippets)?;
459
460 if self.request.prompt_format == PromptFormat::NumLinesUniDiff {
461 prompt.push_str(UNIFIED_DIFF_REMINDER);
462 }
463
464 Ok((prompt, section_labels))
465 }
466
467 fn push_events(output: &mut String, events: &[predict_edits_v3::Event]) {
468 for event in events {
469 match event {
470 Event::BufferChange {
471 path,
472 old_path,
473 diff,
474 predicted,
475 } => {
476 if let Some(old_path) = &old_path
477 && let Some(new_path) = &path
478 {
479 if old_path != new_path {
480 writeln!(
481 output,
482 "User renamed {} to {}\n\n",
483 old_path.display(),
484 new_path.display()
485 )
486 .unwrap();
487 }
488 }
489
490 let path = path
491 .as_ref()
492 .map_or_else(|| "untitled".to_string(), |path| path.display().to_string());
493
494 if *predicted {
495 writeln!(
496 output,
497 "User accepted prediction {:?}:\n`````diff\n{}\n`````\n",
498 path, diff
499 )
500 .unwrap();
501 } else {
502 writeln!(
503 output,
504 "User edited {:?}:\n`````diff\n{}\n`````\n",
505 path, diff
506 )
507 .unwrap();
508 }
509 }
510 }
511 }
512 }
513
514 fn push_file_snippets(
515 &self,
516 output: &mut String,
517 excerpt_file_insertions: &mut Vec<(Point, &'static str)>,
518 file_snippets: Vec<(&'a Path, Vec<&'a PlannedSnippet>, bool)>,
519 ) -> Result<SectionLabels> {
520 let mut section_ranges = Vec::new();
521 let mut excerpt_index = None;
522
523 for (file_path, mut snippets, is_excerpt_file) in file_snippets {
524 snippets.sort_by_key(|s| (s.range.start, Reverse(s.range.end)));
525
526 // TODO: What if the snippets get expanded too large to be editable?
527 let mut current_snippet: Option<(&PlannedSnippet, Range<Line>)> = None;
528 let mut disjoint_snippets: Vec<(&PlannedSnippet, Range<Line>)> = Vec::new();
529 for snippet in snippets {
530 if let Some((_, current_snippet_range)) = current_snippet.as_mut()
531 && snippet.range.start <= current_snippet_range.end
532 {
533 current_snippet_range.end = current_snippet_range.end.max(snippet.range.end);
534 continue;
535 }
536 if let Some(current_snippet) = current_snippet.take() {
537 disjoint_snippets.push(current_snippet);
538 }
539 current_snippet = Some((snippet, snippet.range.clone()));
540 }
541 if let Some(current_snippet) = current_snippet.take() {
542 disjoint_snippets.push(current_snippet);
543 }
544
545 // TODO: remove filename=?
546 writeln!(output, "`````filename={}", file_path.display()).ok();
547 let mut skipped_last_snippet = false;
548 for (snippet, range) in disjoint_snippets {
549 let section_index = section_ranges.len();
550
551 match self.request.prompt_format {
552 PromptFormat::MarkedExcerpt
553 | PromptFormat::OnlySnippets
554 | PromptFormat::NumLinesUniDiff => {
555 if range.start.0 > 0 && !skipped_last_snippet {
556 output.push_str("…\n");
557 }
558 }
559 PromptFormat::LabeledSections => {
560 if is_excerpt_file
561 && range.start <= self.request.excerpt_line_range.start
562 && range.end >= self.request.excerpt_line_range.end
563 {
564 writeln!(output, "<|current_section|>").ok();
565 } else {
566 writeln!(output, "<|section_{}|>", section_index).ok();
567 }
568 }
569 }
570
571 let push_full_snippet = |output: &mut String| {
572 if self.request.prompt_format == PromptFormat::NumLinesUniDiff {
573 for (i, line) in snippet.text.lines().enumerate() {
574 writeln!(output, "{}|{}", i as u32 + range.start.0 + 1, line)?;
575 }
576 } else {
577 output.push_str(&snippet.text);
578 }
579 anyhow::Ok(())
580 };
581
582 if is_excerpt_file {
583 if self.request.prompt_format == PromptFormat::OnlySnippets {
584 if range.start >= self.request.excerpt_line_range.start
585 && range.end <= self.request.excerpt_line_range.end
586 {
587 skipped_last_snippet = true;
588 } else {
589 skipped_last_snippet = false;
590 output.push_str(snippet.text);
591 }
592 } else if !excerpt_file_insertions.is_empty() {
593 let lines = snippet.text.lines().collect::<Vec<_>>();
594 let push_line = |output: &mut String, line_ix: usize| {
595 if self.request.prompt_format == PromptFormat::NumLinesUniDiff {
596 write!(output, "{}|", line_ix as u32 + range.start.0 + 1)?;
597 }
598 anyhow::Ok(writeln!(output, "{}", lines[line_ix])?)
599 };
600 let mut last_line_ix = 0;
601 let mut insertion_ix = 0;
602 while insertion_ix < excerpt_file_insertions.len() {
603 let (point, insertion) = &excerpt_file_insertions[insertion_ix];
604 let found = point.line >= range.start && point.line <= range.end;
605 if found {
606 excerpt_index = Some(section_index);
607 let insertion_line_ix = (point.line.0 - range.start.0) as usize;
608 for line_ix in last_line_ix..insertion_line_ix {
609 push_line(output, line_ix)?;
610 }
611 if let Some(next_line) = lines.get(insertion_line_ix) {
612 if self.request.prompt_format == PromptFormat::NumLinesUniDiff {
613 write!(
614 output,
615 "{}|",
616 insertion_line_ix as u32 + range.start.0 + 1
617 )?
618 }
619 output.push_str(&next_line[..point.column as usize]);
620 output.push_str(insertion);
621 writeln!(output, "{}", &next_line[point.column as usize..])?;
622 } else {
623 writeln!(output, "{}", insertion)?;
624 }
625 last_line_ix = insertion_line_ix + 1;
626 excerpt_file_insertions.remove(insertion_ix);
627 continue;
628 }
629 insertion_ix += 1;
630 }
631 skipped_last_snippet = false;
632 for line_ix in last_line_ix..lines.len() {
633 push_line(output, line_ix)?;
634 }
635 } else {
636 skipped_last_snippet = false;
637 push_full_snippet(output)?;
638 }
639 } else {
640 skipped_last_snippet = false;
641 push_full_snippet(output)?;
642 }
643
644 section_ranges.push((snippet.path.clone(), range));
645 }
646
647 output.push_str("`````\n\n");
648 }
649
650 Ok(SectionLabels {
651 // TODO: Clean this up
652 excerpt_index: match self.request.prompt_format {
653 PromptFormat::OnlySnippets => 0,
654 _ => excerpt_index.context("bug: no snippet found for excerpt")?,
655 },
656 section_ranges,
657 })
658 }
659}
660
661fn declaration_score_density(declaration: &ReferencedDeclaration, style: DeclarationStyle) -> f32 {
662 declaration_score(declaration, style) / declaration_size(declaration, style) as f32
663}
664
665fn declaration_score(declaration: &ReferencedDeclaration, style: DeclarationStyle) -> f32 {
666 match style {
667 DeclarationStyle::Signature => declaration.signature_score,
668 DeclarationStyle::Declaration => declaration.declaration_score,
669 }
670}
671
672fn declaration_size(declaration: &ReferencedDeclaration, style: DeclarationStyle) -> usize {
673 match style {
674 DeclarationStyle::Signature => declaration.signature_range.len(),
675 DeclarationStyle::Declaration => declaration.text.len(),
676 }
677}