1//! Zeta2 prompt planning and generation code shared with cloud.
2
3use anyhow::{Context as _, Result, anyhow};
4use cloud_llm_client::predict_edits_v3::{
5 self, Event, Line, Point, PromptFormat, ReferencedDeclaration,
6};
7use indoc::indoc;
8use ordered_float::OrderedFloat;
9use rustc_hash::{FxHashMap, FxHashSet};
10use serde::Serialize;
11use std::fmt::Write;
12use std::sync::Arc;
13use std::{cmp::Reverse, collections::BinaryHeap, ops::Range, path::Path};
14use strum::{EnumIter, IntoEnumIterator};
15
/// Prompt byte budget used by `PlannedPrompt::populate` when the request does not provide
/// `prompt_max_bytes`.
pub const DEFAULT_MAX_PROMPT_BYTES: usize = 10 * 1024;

/// Marker inserted into the rendered excerpt at the request's cursor point.
pub const CURSOR_MARKER: &str = "<|cursor_position|>";
/// Marker inserted at the first line of the excerpt in the `MarkedExcerpt` format.
///
/// NOTE: Differs from zed version of constant - includes a newline
pub const EDITABLE_REGION_START_MARKER_WITH_NEWLINE: &str = "<|editable_region_start|>\n";
/// Marker inserted at the end line of the excerpt in the `MarkedExcerpt` format.
///
/// NOTE: Differs from zed version of constant - includes a newline
pub const EDITABLE_REGION_END_MARKER_WITH_NEWLINE: &str = "<|editable_region_end|>\n";
23
// TODO: use constants for markers?
/// System prompt returned by `system_prompt` for `PromptFormat::MarkedExcerpt`.
///
/// The marker names spelled out in the text are written by hand here rather than built from
/// the marker constants (see the TODO above), so they must be kept in sync manually.
const MARKED_EXCERPT_SYSTEM_PROMPT: &str = indoc! {"
 You are a code completion assistant and your task is to analyze user edits and then rewrite an excerpt that the user provides, suggesting the appropriate edits within the excerpt, taking into account the cursor location.

 The excerpt to edit will be wrapped in markers <|editable_region_start|> and <|editable_region_end|>. The cursor position is marked with <|cursor_position|>. Please respond with edited code for that region.

 Other code is provided for context, and `…` indicates when code has been skipped.
"};
32
/// System prompt returned by `system_prompt` for `PromptFormat::LabeledSections`.
///
/// It asks for the label of the section to edit followed by that section's replacement code.
const LABELED_SECTIONS_SYSTEM_PROMPT: &str = indoc! {r#"
 You are a code completion assistant and your task is to analyze user edits, and suggest an edit to one of the provided sections of code.

 Sections of code are grouped by file and then labeled by `<|section_N|>` (e.g `<|section_8|>`).

 The cursor position is marked with `<|cursor_position|>` and it will appear within a special section labeled `<|current_section|>`. Prefer editing the current section until no more changes are needed within it.

 Respond ONLY with the name of the section to edit on a single line, followed by all of the code that should replace that section. For example:

 <|current_section|>
 for i in 0..16 {
 println!("{i}");
 }
"#};
47
/// System prompt returned by `system_prompt` for `PromptFormat::NumberedLines`.
///
/// It asks for the remaining edits in unified diff format and includes an example diff.
const NUMBERED_LINES_SYSTEM_PROMPT: &str = indoc! {r#"
 # Instructions

 You are a code completion assistant helping a programmer finish their work. Your task is to:

 1. Analyze the edit history to understand what the programmer is trying to achieve
 2. Identify any incomplete refactoring or changes that need to be finished
 3. Make the remaining edits that a human programmer would logically make next
 4. Apply systematic changes consistently across the entire codebase - if you see a pattern starting, complete it everywhere.

 Focus on:
 - Understanding the intent behind the changes (e.g., improving error handling, refactoring APIs, fixing bugs)
 - Completing any partially-applied changes across the codebase
 - Ensuring consistency with the programming style and patterns already established
 - Making edits that maintain or improve code quality
 - If the programmer started refactoring one instance of a pattern, find and update ALL similar instances
 - Don't write a lot of code if you're not sure what to do

 Rules:
 - Do not just mechanically apply patterns - reason about what changes make sense given the context and the programmer's apparent goals.
 - Do not just fix syntax errors - look for the broader refactoring pattern and apply it systematically throughout the code.
 - Write the edits in the unified diff format as shown in the example.

 # Example output:

 ```
 --- a/distill-claude/tmp-outs/edits_history.txt
 +++ b/distill-claude/tmp-outs/edits_history.txt
 @@ -1,3 +1,3 @@
 -
 -
 -import sys
 +import json
 ```
"#};
83
84pub struct PlannedPrompt<'a> {
85 request: &'a predict_edits_v3::PredictEditsRequest,
86 /// Snippets to include in the prompt. These may overlap - they are merged / deduplicated in
87 /// `to_prompt_string`.
88 snippets: Vec<PlannedSnippet<'a>>,
89 budget_used: usize,
90}
91
92pub fn system_prompt(format: PromptFormat) -> &'static str {
93 match format {
94 PromptFormat::MarkedExcerpt => MARKED_EXCERPT_SYSTEM_PROMPT,
95 PromptFormat::LabeledSections => LABELED_SECTIONS_SYSTEM_PROMPT,
96 PromptFormat::NumberedLines => NUMBERED_LINES_SYSTEM_PROMPT,
97 // only intended for use via zeta_cli
98 PromptFormat::OnlySnippets => "",
99 }
100}
101
102#[derive(Clone, Debug)]
103pub struct PlannedSnippet<'a> {
104 path: Arc<Path>,
105 range: Range<Line>,
106 text: &'a str,
107 // TODO: Indicate this in the output
108 #[allow(dead_code)]
109 text_is_truncated: bool,
110}
111
112#[derive(EnumIter, Clone, Copy, PartialEq, Eq, Hash, Debug, PartialOrd, Ord)]
113pub enum DeclarationStyle {
114 Signature,
115 Declaration,
116}
117
118#[derive(Clone, Debug, Serialize)]
119pub struct SectionLabels {
120 pub excerpt_index: usize,
121 pub section_ranges: Vec<(Arc<Path>, Range<Line>)>,
122}
123
124impl<'a> PlannedPrompt<'a> {
125 /// Greedy one-pass knapsack algorithm to populate the prompt plan. Does the following:
126 ///
127 /// Initializes a priority queue by populating it with each snippet, finding the
128 /// DeclarationStyle that minimizes `score_density = score / snippet.range(style).len()`. When a
129 /// "signature" snippet is popped, insert an entry for the "declaration" variant that reflects
130 /// the cost of upgrade.
131 ///
132 /// TODO: Implement an early halting condition. One option might be to have another priority
133 /// queue where the score is the size, and update it accordingly. Another option might be to
134 /// have some simpler heuristic like bailing after N failed insertions, or based on how much
135 /// budget is left.
136 ///
137 /// TODO: Has the current known sources of imprecision:
138 ///
139 /// * Does not consider snippet overlap when ranking. For example, it might add a field to the
140 /// plan even though the containing struct is already included.
141 ///
142 /// * Does not consider cost of signatures when ranking snippets - this is tricky since
143 /// signatures may be shared by multiple snippets.
144 ///
145 /// * Does not include file paths / other text when considering max_bytes.
146 pub fn populate(request: &'a predict_edits_v3::PredictEditsRequest) -> Result<Self> {
147 let mut this = PlannedPrompt {
148 request,
149 snippets: Vec::new(),
150 budget_used: request.excerpt.len(),
151 };
152 let mut included_parents = FxHashSet::default();
153 let additional_parents = this.additional_parent_signatures(
154 &request.excerpt_path,
155 request.excerpt_parent,
156 &included_parents,
157 )?;
158 this.add_parents(&mut included_parents, additional_parents);
159
160 let max_bytes = request.prompt_max_bytes.unwrap_or(DEFAULT_MAX_PROMPT_BYTES);
161
162 if this.budget_used > max_bytes {
163 return Err(anyhow!(
164 "Excerpt + signatures size of {} already exceeds budget of {}",
165 this.budget_used,
166 max_bytes
167 ));
168 }
169
170 #[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord)]
171 struct QueueEntry {
172 score_density: OrderedFloat<f32>,
173 declaration_index: usize,
174 style: DeclarationStyle,
175 }
176
177 // Initialize priority queue with the best score for each snippet.
178 let mut queue: BinaryHeap<QueueEntry> = BinaryHeap::new();
179 for (declaration_index, declaration) in request.referenced_declarations.iter().enumerate() {
180 let (style, score_density) = DeclarationStyle::iter()
181 .map(|style| {
182 (
183 style,
184 OrderedFloat(declaration_score_density(&declaration, style)),
185 )
186 })
187 .max_by_key(|(_, score_density)| *score_density)
188 .unwrap();
189 queue.push(QueueEntry {
190 score_density,
191 declaration_index,
192 style,
193 });
194 }
195
196 // Knapsack selection loop
197 while let Some(queue_entry) = queue.pop() {
198 let Some(declaration) = request
199 .referenced_declarations
200 .get(queue_entry.declaration_index)
201 else {
202 return Err(anyhow!(
203 "Invalid declaration index {}",
204 queue_entry.declaration_index
205 ));
206 };
207
208 let mut additional_bytes = declaration_size(declaration, queue_entry.style);
209 if this.budget_used + additional_bytes > max_bytes {
210 continue;
211 }
212
213 let additional_parents = this.additional_parent_signatures(
214 &declaration.path,
215 declaration.parent_index,
216 &mut included_parents,
217 )?;
218 additional_bytes += additional_parents
219 .iter()
220 .map(|(_, snippet)| snippet.text.len())
221 .sum::<usize>();
222 if this.budget_used + additional_bytes > max_bytes {
223 continue;
224 }
225
226 this.budget_used += additional_bytes;
227 this.add_parents(&mut included_parents, additional_parents);
228 let planned_snippet = match queue_entry.style {
229 DeclarationStyle::Signature => {
230 let Some(text) = declaration.text.get(declaration.signature_range.clone())
231 else {
232 return Err(anyhow!(
233 "Invalid declaration signature_range {:?} with text.len() = {}",
234 declaration.signature_range,
235 declaration.text.len()
236 ));
237 };
238 let signature_start_line = declaration.range.start
239 + Line(
240 declaration.text[..declaration.signature_range.start]
241 .lines()
242 .count() as u32,
243 );
244 let signature_end_line = signature_start_line
245 + Line(
246 declaration.text
247 [declaration.signature_range.start..declaration.signature_range.end]
248 .lines()
249 .count() as u32,
250 );
251 let range = signature_start_line..signature_end_line;
252
253 PlannedSnippet {
254 path: declaration.path.clone(),
255 range,
256 text,
257 text_is_truncated: declaration.text_is_truncated,
258 }
259 }
260 DeclarationStyle::Declaration => PlannedSnippet {
261 path: declaration.path.clone(),
262 range: declaration.range.clone(),
263 text: &declaration.text,
264 text_is_truncated: declaration.text_is_truncated,
265 },
266 };
267 this.snippets.push(planned_snippet);
268
269 // When a Signature is consumed, insert an entry for Definition style.
270 if queue_entry.style == DeclarationStyle::Signature {
271 let signature_size = declaration_size(&declaration, DeclarationStyle::Signature);
272 let declaration_size =
273 declaration_size(&declaration, DeclarationStyle::Declaration);
274 let signature_score = declaration_score(&declaration, DeclarationStyle::Signature);
275 let declaration_score =
276 declaration_score(&declaration, DeclarationStyle::Declaration);
277
278 let score_diff = declaration_score - signature_score;
279 let size_diff = declaration_size.saturating_sub(signature_size);
280 if score_diff > 0.0001 && size_diff > 0 {
281 queue.push(QueueEntry {
282 declaration_index: queue_entry.declaration_index,
283 score_density: OrderedFloat(score_diff / (size_diff as f32)),
284 style: DeclarationStyle::Declaration,
285 });
286 }
287 }
288 }
289
290 anyhow::Ok(this)
291 }
292
293 fn add_parents(
294 &mut self,
295 included_parents: &mut FxHashSet<usize>,
296 snippets: Vec<(usize, PlannedSnippet<'a>)>,
297 ) {
298 for (parent_index, snippet) in snippets {
299 included_parents.insert(parent_index);
300 self.budget_used += snippet.text.len();
301 self.snippets.push(snippet);
302 }
303 }
304
305 fn additional_parent_signatures(
306 &self,
307 path: &Arc<Path>,
308 parent_index: Option<usize>,
309 included_parents: &FxHashSet<usize>,
310 ) -> Result<Vec<(usize, PlannedSnippet<'a>)>> {
311 let mut results = Vec::new();
312 self.additional_parent_signatures_impl(path, parent_index, included_parents, &mut results)?;
313 Ok(results)
314 }
315
316 fn additional_parent_signatures_impl(
317 &self,
318 path: &Arc<Path>,
319 parent_index: Option<usize>,
320 included_parents: &FxHashSet<usize>,
321 results: &mut Vec<(usize, PlannedSnippet<'a>)>,
322 ) -> Result<()> {
323 let Some(parent_index) = parent_index else {
324 return Ok(());
325 };
326 if included_parents.contains(&parent_index) {
327 return Ok(());
328 }
329 let Some(parent_signature) = self.request.signatures.get(parent_index) else {
330 return Err(anyhow!("Invalid parent index {}", parent_index));
331 };
332 results.push((
333 parent_index,
334 PlannedSnippet {
335 path: path.clone(),
336 range: parent_signature.range.clone(),
337 text: &parent_signature.text,
338 text_is_truncated: parent_signature.text_is_truncated,
339 },
340 ));
341 self.additional_parent_signatures_impl(
342 path,
343 parent_signature.parent_index,
344 included_parents,
345 results,
346 )
347 }
348
349 /// Renders the planned context. Each file starts with "```FILE_PATH\n` and ends with triple
350 /// backticks, with a newline after each file. Outputs a line with "..." between nonconsecutive
351 /// chunks.
352 pub fn to_prompt_string(&'a self) -> Result<(String, SectionLabels)> {
353 let mut file_to_snippets: FxHashMap<&'a std::path::Path, Vec<&PlannedSnippet<'a>>> =
354 FxHashMap::default();
355 for snippet in &self.snippets {
356 file_to_snippets
357 .entry(&snippet.path)
358 .or_default()
359 .push(snippet);
360 }
361
362 // Reorder so that file with cursor comes last
363 let mut file_snippets = Vec::new();
364 let mut excerpt_file_snippets = Vec::new();
365 for (file_path, snippets) in file_to_snippets {
366 if file_path == self.request.excerpt_path.as_ref() {
367 excerpt_file_snippets = snippets;
368 } else {
369 file_snippets.push((file_path, snippets, false));
370 }
371 }
372 let excerpt_snippet = PlannedSnippet {
373 path: self.request.excerpt_path.clone(),
374 range: self.request.excerpt_line_range.clone(),
375 text: &self.request.excerpt,
376 text_is_truncated: false,
377 };
378 excerpt_file_snippets.push(&excerpt_snippet);
379 file_snippets.push((&self.request.excerpt_path, excerpt_file_snippets, true));
380
381 let mut excerpt_file_insertions = match self.request.prompt_format {
382 PromptFormat::MarkedExcerpt => vec![
383 (
384 Point {
385 line: self.request.excerpt_line_range.start,
386 column: 0,
387 },
388 EDITABLE_REGION_START_MARKER_WITH_NEWLINE,
389 ),
390 (self.request.cursor_point, CURSOR_MARKER),
391 (
392 Point {
393 line: self.request.excerpt_line_range.end,
394 column: 0,
395 },
396 EDITABLE_REGION_END_MARKER_WITH_NEWLINE,
397 ),
398 ],
399 PromptFormat::LabeledSections => vec![(self.request.cursor_point, CURSOR_MARKER)],
400 PromptFormat::NumberedLines => vec![(self.request.cursor_point, CURSOR_MARKER)],
401 PromptFormat::OnlySnippets => vec![],
402 };
403
404 let mut prompt = String::new();
405 prompt.push_str("## User Edits\n\n");
406 if self.request.events.is_empty() {
407 prompt.push_str("No edits yet.\n");
408 } else {
409 Self::push_events(&mut prompt, &self.request.events);
410 }
411
412 prompt.push_str("\n## Code\n\n");
413 let section_labels =
414 self.push_file_snippets(&mut prompt, &mut excerpt_file_insertions, file_snippets)?;
415 Ok((prompt, section_labels))
416 }
417
418 fn push_events(output: &mut String, events: &[predict_edits_v3::Event]) {
419 for event in events {
420 match event {
421 Event::BufferChange {
422 path,
423 old_path,
424 diff,
425 predicted,
426 } => {
427 if let Some(old_path) = &old_path
428 && let Some(new_path) = &path
429 {
430 if old_path != new_path {
431 writeln!(
432 output,
433 "User renamed {} to {}\n\n",
434 old_path.display(),
435 new_path.display()
436 )
437 .unwrap();
438 }
439 }
440
441 let path = path
442 .as_ref()
443 .map_or_else(|| "untitled".to_string(), |path| path.display().to_string());
444
445 if *predicted {
446 writeln!(
447 output,
448 "User accepted prediction {:?}:\n`````diff\n{}\n`````\n",
449 path, diff
450 )
451 .unwrap();
452 } else {
453 writeln!(
454 output,
455 "User edited {:?}:\n`````diff\n{}\n`````\n",
456 path, diff
457 )
458 .unwrap();
459 }
460 }
461 }
462 }
463 }
464
465 fn push_file_snippets(
466 &self,
467 output: &mut String,
468 excerpt_file_insertions: &mut Vec<(Point, &'static str)>,
469 file_snippets: Vec<(&'a Path, Vec<&'a PlannedSnippet>, bool)>,
470 ) -> Result<SectionLabels> {
471 let mut section_ranges = Vec::new();
472 let mut excerpt_index = None;
473
474 for (file_path, mut snippets, is_excerpt_file) in file_snippets {
475 snippets.sort_by_key(|s| (s.range.start, Reverse(s.range.end)));
476
477 // TODO: What if the snippets get expanded too large to be editable?
478 let mut current_snippet: Option<(&PlannedSnippet, Range<Line>)> = None;
479 let mut disjoint_snippets: Vec<(&PlannedSnippet, Range<Line>)> = Vec::new();
480 for snippet in snippets {
481 if let Some((_, current_snippet_range)) = current_snippet.as_mut()
482 && snippet.range.start <= current_snippet_range.end
483 {
484 current_snippet_range.end = current_snippet_range.end.max(snippet.range.end);
485 continue;
486 }
487 if let Some(current_snippet) = current_snippet.take() {
488 disjoint_snippets.push(current_snippet);
489 }
490 current_snippet = Some((snippet, snippet.range.clone()));
491 }
492 if let Some(current_snippet) = current_snippet.take() {
493 disjoint_snippets.push(current_snippet);
494 }
495
496 // TODO: remove filename=?
497 writeln!(output, "`````filename={}", file_path.display()).ok();
498 let mut skipped_last_snippet = false;
499 for (snippet, range) in disjoint_snippets {
500 let section_index = section_ranges.len();
501
502 match self.request.prompt_format {
503 PromptFormat::MarkedExcerpt
504 | PromptFormat::OnlySnippets
505 | PromptFormat::NumberedLines => {
506 if range.start.0 > 0 && !skipped_last_snippet {
507 output.push_str("…\n");
508 }
509 }
510 PromptFormat::LabeledSections => {
511 if is_excerpt_file
512 && range.start <= self.request.excerpt_line_range.start
513 && range.end >= self.request.excerpt_line_range.end
514 {
515 writeln!(output, "<|current_section|>").ok();
516 } else {
517 writeln!(output, "<|section_{}|>", section_index).ok();
518 }
519 }
520 }
521
522 let push_full_snippet = |output: &mut String| {
523 if self.request.prompt_format == PromptFormat::NumberedLines {
524 for (i, line) in snippet.text.lines().enumerate() {
525 writeln!(output, "{}|{}", i as u32 + range.start.0 + 1, line)?;
526 }
527 } else {
528 output.push_str(&snippet.text);
529 }
530 anyhow::Ok(())
531 };
532
533 if is_excerpt_file {
534 if self.request.prompt_format == PromptFormat::OnlySnippets {
535 if range.start >= self.request.excerpt_line_range.start
536 && range.end <= self.request.excerpt_line_range.end
537 {
538 skipped_last_snippet = true;
539 } else {
540 skipped_last_snippet = false;
541 output.push_str(snippet.text);
542 }
543 } else if !excerpt_file_insertions.is_empty() {
544 let lines = snippet.text.lines().collect::<Vec<_>>();
545 let push_line = |output: &mut String, line_ix: usize| {
546 if self.request.prompt_format == PromptFormat::NumberedLines {
547 write!(output, "{}|", line_ix as u32 + range.start.0 + 1)?;
548 }
549 anyhow::Ok(writeln!(output, "{}", lines[line_ix])?)
550 };
551 let mut last_line_ix = 0;
552 let mut insertion_ix = 0;
553 while insertion_ix < excerpt_file_insertions.len() {
554 let (point, insertion) = &excerpt_file_insertions[insertion_ix];
555 let found = point.line >= range.start && point.line <= range.end;
556 if found {
557 excerpt_index = Some(section_index);
558 let insertion_line_ix = (point.line.0 - range.start.0) as usize;
559 for line_ix in last_line_ix..insertion_line_ix {
560 push_line(output, line_ix)?;
561 }
562 if let Some(next_line) = lines.get(insertion_line_ix) {
563 if self.request.prompt_format == PromptFormat::NumberedLines {
564 write!(
565 output,
566 "{}|",
567 insertion_line_ix as u32 + range.start.0 + 1
568 )?
569 }
570 output.push_str(&next_line[..point.column as usize]);
571 output.push_str(insertion);
572 writeln!(output, "{}", &next_line[point.column as usize..])?;
573 } else {
574 writeln!(output, "{}", insertion)?;
575 }
576 last_line_ix = insertion_line_ix + 1;
577 excerpt_file_insertions.remove(insertion_ix);
578 continue;
579 }
580 insertion_ix += 1;
581 }
582 skipped_last_snippet = false;
583 for line_ix in last_line_ix..lines.len() {
584 push_line(output, line_ix)?;
585 }
586 } else {
587 skipped_last_snippet = false;
588 push_full_snippet(output)?;
589 }
590 } else {
591 skipped_last_snippet = false;
592 push_full_snippet(output)?;
593 }
594
595 section_ranges.push((snippet.path.clone(), range));
596 }
597
598 output.push_str("`````\n\n");
599 }
600
601 Ok(SectionLabels {
602 // TODO: Clean this up
603 excerpt_index: match self.request.prompt_format {
604 PromptFormat::OnlySnippets => 0,
605 _ => excerpt_index.context("bug: no snippet found for excerpt")?,
606 },
607 section_ranges,
608 })
609 }
610}
611
612fn declaration_score_density(declaration: &ReferencedDeclaration, style: DeclarationStyle) -> f32 {
613 declaration_score(declaration, style) / declaration_size(declaration, style) as f32
614}
615
616fn declaration_score(declaration: &ReferencedDeclaration, style: DeclarationStyle) -> f32 {
617 match style {
618 DeclarationStyle::Signature => declaration.signature_score,
619 DeclarationStyle::Declaration => declaration.declaration_score,
620 }
621}
622
623fn declaration_size(declaration: &ReferencedDeclaration, style: DeclarationStyle) -> usize {
624 match style {
625 DeclarationStyle::Signature => declaration.signature_range.len(),
626 DeclarationStyle::Declaration => declaration.text.len(),
627 }
628}