1//! Zeta2 prompt planning and generation code shared with cloud.
2
3use anyhow::{Result, anyhow};
4use cloud_llm_client::predict_edits_v3::{self, Event, ReferencedDeclaration};
5use indoc::indoc;
6use ordered_float::OrderedFloat;
7use rustc_hash::{FxHashMap, FxHashSet};
8use std::fmt::Write;
9use std::{cmp::Reverse, collections::BinaryHeap, ops::Range, path::Path};
10use strum::{EnumIter, IntoEnumIterator};
11
/// Default prompt byte budget (10 KiB). Presumably used by callers as `PlanOptions::max_bytes`
/// — it is not referenced elsewhere in this section; confirm against callers.
pub const DEFAULT_MAX_PROMPT_BYTES: usize = 10 * 1024;

/// Marker inserted into the excerpt at the cursor position when rendering the prompt.
pub const CURSOR_MARKER: &str = "<|user_cursor_is_here|>";
/// Marker inserted at the start of the excerpt (the editable region) when rendering the prompt.
///
/// NOTE: Differs from zed version of constant - includes a newline
pub const EDITABLE_REGION_START_MARKER_WITH_NEWLINE: &str = "<|editable_region_start|>\n";
/// Marker inserted at the end of the excerpt (the editable region) when rendering the prompt.
///
/// NOTE: Differs from zed version of constant - includes a newline
pub const EDITABLE_REGION_END_MARKER_WITH_NEWLINE: &str = "<|editable_region_end|>\n";

// TODO: use constants for markers?
/// System prompt describing the editing task to the model.
///
/// The marker names are written out literally here rather than built from the constants above,
/// so any change to a marker must be mirrored in this text by hand.
pub const SYSTEM_PROMPT: &str = indoc! {"
    You are a code completion assistant and your task is to analyze user edits and then rewrite an excerpt that the user provides, suggesting the appropriate edits within the excerpt, taking into account the cursor location.

    The excerpt to edit will be wrapped in markers <|editable_region_start|> and <|editable_region_end|>. The cursor position is marked with <|user_cursor_is_here|>. Please respond with edited code for that region.
"};
26
27pub struct PlannedPrompt<'a> {
28 request: &'a predict_edits_v3::PredictEditsRequest,
29 /// Snippets to include in the prompt. These may overlap - they are merged / deduplicated in
30 /// `to_prompt_string`.
31 snippets: Vec<PlannedSnippet<'a>>,
32 budget_used: usize,
33}
34
/// Options controlling `PlannedPrompt::populate`.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub struct PlanOptions {
    /// Maximum number of bytes the excerpt, parent signatures, and selected snippets may occupy.
    /// Planning fails outright if the excerpt plus its parent signatures alone exceed it.
    pub max_bytes: usize,
}
38
/// A contiguous piece of a file selected for inclusion in the prompt.
#[derive(Clone, Debug)]
pub struct PlannedSnippet<'a> {
    /// File the snippet comes from; snippets are grouped by this path when rendered.
    path: &'a Path,
    /// Range covered in the file's coordinates. Used to sort and merge overlapping snippets; its
    /// offsets (relative to `range.start`) are applied as byte offsets into `text` when rendering.
    range: Range<usize>,
    /// Text of the snippet.
    text: &'a str,
    /// Copied from the source signature/declaration's `text_is_truncated` flag. Not read yet,
    /// hence the `dead_code` allowance.
    // TODO: Indicate this in the output
    #[allow(dead_code)]
    text_is_truncated: bool,
}
48
/// How much of a referenced declaration to include in the prompt.
///
/// Variant order is significant: it determines both the derived `Ord` and the order produced by
/// `SnippetStyle::iter()`.
#[derive(EnumIter, Clone, Copy, PartialEq, Eq, Hash, Debug, PartialOrd, Ord)]
pub enum SnippetStyle {
    /// Only the declaration's signature (`signature_range` within the declaration text).
    Signature,
    /// The declaration's full text.
    Declaration,
}
54
55impl<'a> PlannedPrompt<'a> {
56 /// Greedy one-pass knapsack algorithm to populate the prompt plan. Does the following:
57 ///
58 /// Initializes a priority queue by populating it with each snippet, finding the SnippetStyle
59 /// that minimizes `score_density = score / snippet.range(style).len()`. When a "signature"
60 /// snippet is popped, insert an entry for the "declaration" variant that reflects the cost of
61 /// upgrade.
62 ///
63 /// TODO: Implement an early halting condition. One option might be to have another priority
64 /// queue where the score is the size, and update it accordingly. Another option might be to
65 /// have some simpler heuristic like bailing after N failed insertions, or based on how much
66 /// budget is left.
67 ///
68 /// TODO: Has the current known sources of imprecision:
69 ///
70 /// * Does not consider snippet overlap when ranking. For example, it might add a field to the
71 /// plan even though the containing struct is already included.
72 ///
73 /// * Does not consider cost of signatures when ranking snippets - this is tricky since
74 /// signatures may be shared by multiple snippets.
75 ///
76 /// * Does not include file paths / other text when considering max_bytes.
77 pub fn populate(
78 request: &'a predict_edits_v3::PredictEditsRequest,
79 options: &PlanOptions,
80 ) -> Result<Self> {
81 let mut this = PlannedPrompt {
82 request,
83 snippets: Vec::new(),
84 budget_used: request.excerpt.len(),
85 };
86 let mut included_parents = FxHashSet::default();
87 let additional_parents = this.additional_parent_signatures(
88 &request.excerpt_path,
89 request.excerpt_parent,
90 &included_parents,
91 )?;
92 this.add_parents(&mut included_parents, additional_parents);
93
94 if this.budget_used > options.max_bytes {
95 return Err(anyhow!(
96 "Excerpt + signatures size of {} already exceeds budget of {}",
97 this.budget_used,
98 options.max_bytes
99 ));
100 }
101
102 #[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord)]
103 struct QueueEntry {
104 score_density: OrderedFloat<f32>,
105 declaration_index: usize,
106 style: SnippetStyle,
107 }
108
109 // Initialize priority queue with the best score for each snippet.
110 let mut queue: BinaryHeap<QueueEntry> = BinaryHeap::new();
111 for (declaration_index, declaration) in request.referenced_declarations.iter().enumerate() {
112 let (style, score_density) = SnippetStyle::iter()
113 .map(|style| {
114 (
115 style,
116 OrderedFloat(declaration_score_density(&declaration, style)),
117 )
118 })
119 .max_by_key(|(_, score_density)| *score_density)
120 .unwrap();
121 queue.push(QueueEntry {
122 score_density,
123 declaration_index,
124 style,
125 });
126 }
127
128 // Knapsack selection loop
129 while let Some(queue_entry) = queue.pop() {
130 let Some(declaration) = request
131 .referenced_declarations
132 .get(queue_entry.declaration_index)
133 else {
134 return Err(anyhow!(
135 "Invalid declaration index {}",
136 queue_entry.declaration_index
137 ));
138 };
139
140 let mut additional_bytes = declaration_size(declaration, queue_entry.style);
141 if this.budget_used + additional_bytes > options.max_bytes {
142 continue;
143 }
144
145 let additional_parents = this.additional_parent_signatures(
146 &declaration.path,
147 declaration.parent_index,
148 &mut included_parents,
149 )?;
150 additional_bytes += additional_parents
151 .iter()
152 .map(|(_, snippet)| snippet.text.len())
153 .sum::<usize>();
154 if this.budget_used + additional_bytes > options.max_bytes {
155 continue;
156 }
157
158 this.budget_used += additional_bytes;
159 this.add_parents(&mut included_parents, additional_parents);
160 let planned_snippet = match queue_entry.style {
161 SnippetStyle::Signature => {
162 let Some(text) = declaration.text.get(declaration.signature_range.clone())
163 else {
164 return Err(anyhow!(
165 "Invalid declaration signature_range {:?} with text.len() = {}",
166 declaration.signature_range,
167 declaration.text.len()
168 ));
169 };
170 PlannedSnippet {
171 path: &declaration.path,
172 range: (declaration.signature_range.start + declaration.range.start)
173 ..(declaration.signature_range.end + declaration.range.start),
174 text,
175 text_is_truncated: declaration.text_is_truncated,
176 }
177 }
178 SnippetStyle::Declaration => PlannedSnippet {
179 path: &declaration.path,
180 range: declaration.range.clone(),
181 text: &declaration.text,
182 text_is_truncated: declaration.text_is_truncated,
183 },
184 };
185 this.snippets.push(planned_snippet);
186
187 // When a Signature is consumed, insert an entry for Definition style.
188 if queue_entry.style == SnippetStyle::Signature {
189 let signature_size = declaration_size(&declaration, SnippetStyle::Signature);
190 let declaration_size = declaration_size(&declaration, SnippetStyle::Declaration);
191 let signature_score = declaration_score(&declaration, SnippetStyle::Signature);
192 let declaration_score = declaration_score(&declaration, SnippetStyle::Declaration);
193
194 let score_diff = declaration_score - signature_score;
195 let size_diff = declaration_size.saturating_sub(signature_size);
196 if score_diff > 0.0001 && size_diff > 0 {
197 queue.push(QueueEntry {
198 declaration_index: queue_entry.declaration_index,
199 score_density: OrderedFloat(score_diff / (size_diff as f32)),
200 style: SnippetStyle::Declaration,
201 });
202 }
203 }
204 }
205
206 anyhow::Ok(this)
207 }
208
209 fn add_parents(
210 &mut self,
211 included_parents: &mut FxHashSet<usize>,
212 snippets: Vec<(usize, PlannedSnippet<'a>)>,
213 ) {
214 for (parent_index, snippet) in snippets {
215 included_parents.insert(parent_index);
216 self.budget_used += snippet.text.len();
217 self.snippets.push(snippet);
218 }
219 }
220
221 fn additional_parent_signatures(
222 &self,
223 path: &'a Path,
224 parent_index: Option<usize>,
225 included_parents: &FxHashSet<usize>,
226 ) -> Result<Vec<(usize, PlannedSnippet<'a>)>> {
227 let mut results = Vec::new();
228 self.additional_parent_signatures_impl(path, parent_index, included_parents, &mut results)?;
229 Ok(results)
230 }
231
232 fn additional_parent_signatures_impl(
233 &self,
234 path: &'a Path,
235 parent_index: Option<usize>,
236 included_parents: &FxHashSet<usize>,
237 results: &mut Vec<(usize, PlannedSnippet<'a>)>,
238 ) -> Result<()> {
239 let Some(parent_index) = parent_index else {
240 return Ok(());
241 };
242 if included_parents.contains(&parent_index) {
243 return Ok(());
244 }
245 let Some(parent_signature) = self.request.signatures.get(parent_index) else {
246 return Err(anyhow!("Invalid parent index {}", parent_index));
247 };
248 results.push((
249 parent_index,
250 PlannedSnippet {
251 path,
252 range: parent_signature.range.clone(),
253 text: &parent_signature.text,
254 text_is_truncated: parent_signature.text_is_truncated,
255 },
256 ));
257 self.additional_parent_signatures_impl(
258 path,
259 parent_signature.parent_index,
260 included_parents,
261 results,
262 )
263 }
264
265 /// Renders the planned context. Each file starts with "```FILE_PATH\n` and ends with triple
266 /// backticks, with a newline after each file. Outputs a line with "..." between nonconsecutive
267 /// chunks.
268 pub fn to_prompt_string(&self) -> String {
269 let mut file_to_snippets: FxHashMap<&'a std::path::Path, Vec<&PlannedSnippet<'a>>> =
270 FxHashMap::default();
271 for snippet in &self.snippets {
272 file_to_snippets
273 .entry(&snippet.path)
274 .or_default()
275 .push(snippet);
276 }
277
278 // Reorder so that file with cursor comes last
279 let mut file_snippets = Vec::new();
280 let mut excerpt_file_snippets = Vec::new();
281 for (file_path, snippets) in file_to_snippets {
282 if file_path == &self.request.excerpt_path {
283 excerpt_file_snippets = snippets;
284 } else {
285 file_snippets.push((file_path, snippets, false));
286 }
287 }
288 let excerpt_snippet = PlannedSnippet {
289 path: &self.request.excerpt_path,
290 range: self.request.excerpt_range.clone(),
291 text: &self.request.excerpt,
292 text_is_truncated: false,
293 };
294 excerpt_file_snippets.push(&excerpt_snippet);
295 file_snippets.push((&self.request.excerpt_path, excerpt_file_snippets, true));
296
297 let mut excerpt_file_insertions = vec![
298 (
299 self.request.excerpt_range.start,
300 EDITABLE_REGION_START_MARKER_WITH_NEWLINE,
301 ),
302 (
303 self.request.excerpt_range.start + self.request.cursor_offset,
304 CURSOR_MARKER,
305 ),
306 (
307 self.request
308 .excerpt_range
309 .end
310 .saturating_sub(0)
311 .max(self.request.excerpt_range.start),
312 EDITABLE_REGION_END_MARKER_WITH_NEWLINE,
313 ),
314 ];
315
316 let mut output = String::new();
317 output.push_str("## User Edits\n\n");
318 Self::push_events(&mut output, &self.request.events);
319
320 output.push_str("\n## Code\n\n");
321 Self::push_file_snippets(&mut output, &mut excerpt_file_insertions, file_snippets);
322 output
323 }
324
325 fn push_events(output: &mut String, events: &[predict_edits_v3::Event]) {
326 for event in events {
327 match event {
328 Event::BufferChange {
329 path,
330 old_path,
331 diff,
332 predicted,
333 } => {
334 if let Some(old_path) = &old_path
335 && let Some(new_path) = &path
336 {
337 if old_path != new_path {
338 writeln!(
339 output,
340 "User renamed {} to {}\n\n",
341 old_path.display(),
342 new_path.display()
343 )
344 .unwrap();
345 }
346 }
347
348 let path = path
349 .as_ref()
350 .map_or_else(|| "untitled".to_string(), |path| path.display().to_string());
351
352 if *predicted {
353 writeln!(
354 output,
355 "User accepted prediction {:?}:\n```diff\n{}\n```\n",
356 path, diff
357 )
358 .unwrap();
359 } else {
360 writeln!(output, "User edited {:?}:\n```diff\n{}\n```\n", path, diff)
361 .unwrap();
362 }
363 }
364 }
365 }
366 }
367
368 fn push_file_snippets(
369 output: &mut String,
370 excerpt_file_insertions: &mut Vec<(usize, &'static str)>,
371 file_snippets: Vec<(&Path, Vec<&PlannedSnippet>, bool)>,
372 ) {
373 fn push_excerpt_file_range(
374 range: Range<usize>,
375 text: &str,
376 excerpt_file_insertions: &mut Vec<(usize, &'static str)>,
377 output: &mut String,
378 ) {
379 let mut last_offset = range.start;
380 let mut i = 0;
381 while i < excerpt_file_insertions.len() {
382 let (offset, insertion) = &excerpt_file_insertions[i];
383 let found = *offset >= range.start && *offset <= range.end;
384 if found {
385 output.push_str(&text[last_offset - range.start..offset - range.start]);
386 output.push_str(insertion);
387 last_offset = *offset;
388 excerpt_file_insertions.remove(i);
389 continue;
390 }
391 i += 1;
392 }
393 output.push_str(&text[last_offset - range.start..]);
394 }
395
396 for (file_path, mut snippets, is_excerpt_file) in file_snippets {
397 output.push_str(&format!("```{}\n", file_path.display()));
398
399 let mut last_included_range: Option<Range<usize>> = None;
400 snippets.sort_by_key(|s| (s.range.start, Reverse(s.range.end)));
401 for snippet in snippets {
402 if let Some(last_range) = &last_included_range
403 && snippet.range.start < last_range.end
404 {
405 if snippet.range.end <= last_range.end {
406 continue;
407 }
408 // TODO: Should probably also handle case where there is just one char (newline)
409 // between snippets - assume it's a newline.
410 let text = &snippet.text[last_range.end - snippet.range.start..];
411 if is_excerpt_file {
412 push_excerpt_file_range(
413 last_range.end..snippet.range.end,
414 text,
415 excerpt_file_insertions,
416 output,
417 );
418 } else {
419 output.push_str(text);
420 }
421 last_included_range = Some(last_range.start..snippet.range.end);
422 continue;
423 }
424 if last_included_range.is_some() {
425 output.push_str("…\n");
426 }
427 if is_excerpt_file {
428 push_excerpt_file_range(
429 snippet.range.clone(),
430 snippet.text,
431 excerpt_file_insertions,
432 output,
433 );
434 } else {
435 output.push_str(snippet.text);
436 }
437 last_included_range = Some(snippet.range.clone());
438 }
439
440 output.push_str("```\n\n");
441 }
442 }
443}
444
445fn declaration_score_density(declaration: &ReferencedDeclaration, style: SnippetStyle) -> f32 {
446 declaration_score(declaration, style) / declaration_size(declaration, style) as f32
447}
448
449fn declaration_score(declaration: &ReferencedDeclaration, style: SnippetStyle) -> f32 {
450 match style {
451 SnippetStyle::Signature => declaration.signature_score,
452 SnippetStyle::Declaration => declaration.declaration_score,
453 }
454}
455
456fn declaration_size(declaration: &ReferencedDeclaration, style: SnippetStyle) -> usize {
457 match style {
458 SnippetStyle::Signature => declaration.signature_range.len(),
459 SnippetStyle::Declaration => declaration.text.len(),
460 }
461}