1use cloud_llm_client::predict_edits_v3::DeclarationScoreComponents;
2use collections::HashMap;
3use language::BufferSnapshot;
4use ordered_float::OrderedFloat;
5use project::ProjectEntryId;
6use serde::Serialize;
7use std::{cmp::Reverse, ops::Range, path::Path, sync::Arc};
8use strum::EnumIter;
9use text::{Point, ToPoint};
10use util::RangeExt as _;
11
12use crate::{
13 CachedDeclarationPath, Declaration, EditPredictionExcerpt, Identifier,
14 imports::{Import, Imports, Module},
15 reference::{Reference, ReferenceRegion},
16 syntax_index::SyntaxIndexState,
17 text_similarity::{Occurrences, jaccard_similarity, weighted_overlap_coefficient},
18};
19
/// Const-generic limit passed to `SyntaxIndexState::declarations_for_identifier` in
/// [`scored_declarations`]; presumably caps how many candidate declarations are retrieved per
/// identifier (TODO confirm against the index implementation).
const MAX_IDENTIFIER_DECLARATION_COUNT: usize = 16;
21
/// Options controlling which candidate declarations [`scored_declarations`] keeps.
#[derive(Clone, Debug, PartialEq)]
pub struct EditPredictionScoreOptions {
    /// When true, declarations in the current buffer whose item range intersects the excerpt,
    /// or which are among the excerpt's parent declarations, are skipped because they are
    /// already included in the prompt.
    pub omit_excerpt_overlaps: bool,
    /// When greater than zero, a declaration is kept only if its
    /// [`DeclarationStyle::Declaration`] score is strictly greater than this fraction of the
    /// best such score among the candidates for the same identifier. Zero or a negative value
    /// disables this prefilter.
    pub prefilter_score_ratio: f32,
}
27
/// A candidate declaration for an identifier referenced near the cursor, together with the raw
/// signals used to score it.
#[derive(Clone, Debug)]
pub struct ScoredDeclaration {
    /// identifier used by the local reference (when the reference resolves through an alias
    /// import, this is still the local name, not the aliased external identifier)
    pub identifier: Identifier,
    /// The candidate declaration found in the syntax index.
    pub declaration: Declaration,
    /// Signals combined by [`ScoredDeclaration::score`] and
    /// [`ScoredDeclaration::retrieval_score`].
    pub components: DeclarationScoreComponents,
}
35
/// Which portion of a declaration is being scored or sized.
#[derive(EnumIter, Clone, Copy, PartialEq, Eq, Hash, Debug)]
pub enum DeclarationStyle {
    /// Only the declaration's signature (its `signature_range`).
    Signature,
    /// The whole declaration item (its text / `item_range`).
    Declaration,
}
41
/// Serializable per-style scores for a declaration plus its retrieval score.
///
/// NOTE(review): not constructed in this file; presumably filled from
/// [`ScoredDeclaration::score`] and [`ScoredDeclaration::retrieval_score`] — confirm at the
/// construction site.
#[derive(Clone, Debug, Serialize, Default)]
pub struct DeclarationScores {
    /// Score for [`DeclarationStyle::Signature`] (presumably).
    pub signature: f32,
    /// Score for [`DeclarationStyle::Declaration`] (presumably).
    pub declaration: f32,
    /// Retrieval score (presumably [`ScoredDeclaration::retrieval_score`]).
    pub retrieval: f32,
}
48
49impl ScoredDeclaration {
50 /// Returns the score for this declaration with the specified style.
51 pub fn score(&self, style: DeclarationStyle) -> f32 {
52 // TODO: handle truncation
53
54 // Score related to how likely this is the correct declaration, range 0 to 1
55 let retrieval = self.retrieval_score();
56
57 // Score related to the distance between the reference and cursor, range 0 to 1
58 let distance_score = if self.components.is_referenced_nearby {
59 1.0 / (1.0 + self.components.reference_line_distance as f32 / 10.0).powf(2.0)
60 } else {
61 // same score as ~14 lines away, rationale is to not overly penalize references from parent signatures
62 0.5
63 };
64
65 // For now instead of linear combination, the scores are just multiplied together.
66 let combined_score = 10.0 * retrieval * distance_score;
67
68 match style {
69 DeclarationStyle::Signature => {
70 combined_score * self.components.excerpt_vs_signature_weighted_overlap
71 }
72 DeclarationStyle::Declaration => {
73 2.0 * combined_score * self.components.excerpt_vs_item_weighted_overlap
74 }
75 }
76 }
77
78 pub fn retrieval_score(&self) -> f32 {
79 let mut score = if self.components.is_same_file {
80 10.0 / self.components.same_file_declaration_count as f32
81 } else if self.components.path_import_match_count > 0 {
82 3.0
83 } else if self.components.wildcard_path_import_match_count > 0 {
84 1.0
85 } else if self.components.normalized_import_similarity > 0.0 {
86 self.components.normalized_import_similarity
87 } else if self.components.normalized_wildcard_import_similarity > 0.0 {
88 0.5 * self.components.normalized_wildcard_import_similarity
89 } else {
90 1.0 / self.components.declaration_count as f32
91 };
92 score *= 1. + self.components.included_by_others as f32 / 2.;
93 score *= 1. + self.components.includes_others as f32 / 4.;
94 score
95 }
96
97 pub fn size(&self, style: DeclarationStyle) -> usize {
98 match &self.declaration {
99 Declaration::File { declaration, .. } => match style {
100 DeclarationStyle::Signature => declaration.signature_range.len(),
101 DeclarationStyle::Declaration => declaration.text.len(),
102 },
103 Declaration::Buffer { declaration, .. } => match style {
104 DeclarationStyle::Signature => declaration.signature_range.len(),
105 DeclarationStyle::Declaration => declaration.item_range.len(),
106 },
107 }
108 }
109
110 pub fn score_density(&self, style: DeclarationStyle) -> f32 {
111 self.score(style) / self.size(style) as f32
112 }
113}
114
115pub fn scored_declarations(
116 options: &EditPredictionScoreOptions,
117 index: &SyntaxIndexState,
118 excerpt: &EditPredictionExcerpt,
119 excerpt_occurrences: &Occurrences,
120 adjacent_occurrences: &Occurrences,
121 imports: &Imports,
122 identifier_to_references: HashMap<Identifier, Vec<Reference>>,
123 cursor_offset: usize,
124 current_buffer: &BufferSnapshot,
125) -> Vec<ScoredDeclaration> {
126 let cursor_point = cursor_offset.to_point(¤t_buffer);
127
128 let mut wildcard_import_occurrences = Vec::new();
129 let mut wildcard_import_paths = Vec::new();
130 for wildcard_import in imports.wildcard_modules.iter() {
131 match wildcard_import {
132 Module::Namespace(namespace) => {
133 wildcard_import_occurrences.push(namespace.occurrences())
134 }
135 Module::SourceExact(path) => wildcard_import_paths.push(path),
136 Module::SourceFuzzy(path) => {
137 wildcard_import_occurrences.push(Occurrences::from_path(&path))
138 }
139 }
140 }
141
142 let mut scored_declarations = Vec::new();
143 let mut project_entry_id_to_outline_ranges: HashMap<ProjectEntryId, Vec<Range<usize>>> =
144 HashMap::default();
145 for (identifier, references) in identifier_to_references {
146 let mut import_occurrences = Vec::new();
147 let mut import_paths = Vec::new();
148 let mut found_external_identifier: Option<&Identifier> = None;
149
150 if let Some(imports) = imports.identifier_to_imports.get(&identifier) {
151 // only use alias when it's the only import, could be generalized if some language
152 // has overlapping aliases
153 //
154 // TODO: when an aliased declaration is included in the prompt, should include the
155 // aliasing in the prompt.
156 //
157 // TODO: For SourceFuzzy consider having componentwise comparison that pays
158 // attention to ordering.
159 if let [
160 Import::Alias {
161 module,
162 external_identifier,
163 },
164 ] = imports.as_slice()
165 {
166 match module {
167 Module::Namespace(namespace) => {
168 import_occurrences.push(namespace.occurrences())
169 }
170 Module::SourceExact(path) => import_paths.push(path),
171 Module::SourceFuzzy(path) => {
172 import_occurrences.push(Occurrences::from_path(&path))
173 }
174 }
175 found_external_identifier = Some(&external_identifier);
176 } else {
177 for import in imports {
178 match import {
179 Import::Direct { module } => match module {
180 Module::Namespace(namespace) => {
181 import_occurrences.push(namespace.occurrences())
182 }
183 Module::SourceExact(path) => import_paths.push(path),
184 Module::SourceFuzzy(path) => {
185 import_occurrences.push(Occurrences::from_path(&path))
186 }
187 },
188 Import::Alias { .. } => {}
189 }
190 }
191 }
192 }
193
194 let identifier_to_lookup = found_external_identifier.unwrap_or(&identifier);
195 // TODO: update this to be able to return more declarations? Especially if there is the
196 // ability to quickly filter a large list (based on imports)
197 let identifier_declarations = index
198 .declarations_for_identifier::<MAX_IDENTIFIER_DECLARATION_COUNT>(&identifier_to_lookup);
199 let declaration_count = identifier_declarations.len();
200
201 if declaration_count == 0 {
202 continue;
203 }
204
205 // TODO: option to filter out other candidates when same file / import match
206 let mut checked_declarations = Vec::with_capacity(declaration_count);
207 for (declaration_id, declaration) in identifier_declarations {
208 match declaration {
209 Declaration::Buffer {
210 buffer_id,
211 declaration: buffer_declaration,
212 ..
213 } => {
214 if buffer_id == ¤t_buffer.remote_id() {
215 let already_included_in_prompt =
216 range_intersection(&buffer_declaration.item_range, &excerpt.range)
217 .is_some()
218 || excerpt
219 .parent_declarations
220 .iter()
221 .any(|(excerpt_parent, _)| excerpt_parent == &declaration_id);
222 if !options.omit_excerpt_overlaps || !already_included_in_prompt {
223 let declaration_line = buffer_declaration
224 .item_range
225 .start
226 .to_point(current_buffer)
227 .row;
228 let declaration_line_distance =
229 (cursor_point.row as i32 - declaration_line as i32).unsigned_abs();
230 checked_declarations.push(CheckedDeclaration {
231 declaration,
232 same_file_line_distance: Some(declaration_line_distance),
233 path_import_match_count: 0,
234 wildcard_path_import_match_count: 0,
235 });
236 }
237 continue;
238 } else {
239 }
240 }
241 Declaration::File { .. } => {}
242 }
243 let declaration_path = declaration.cached_path();
244 let path_import_match_count = import_paths
245 .iter()
246 .filter(|import_path| {
247 declaration_path_matches_import(&declaration_path, import_path)
248 })
249 .count();
250 let wildcard_path_import_match_count = wildcard_import_paths
251 .iter()
252 .filter(|import_path| {
253 declaration_path_matches_import(&declaration_path, import_path)
254 })
255 .count();
256 checked_declarations.push(CheckedDeclaration {
257 declaration,
258 same_file_line_distance: None,
259 path_import_match_count,
260 wildcard_path_import_match_count,
261 });
262 }
263
264 let mut max_import_similarity = 0.0;
265 let mut max_wildcard_import_similarity = 0.0;
266 // todo! consider max retrieval score instead?
267 let mut max_score = 0.0;
268
269 let mut scored_declarations_for_identifier = Vec::with_capacity(checked_declarations.len());
270 for checked_declaration in checked_declarations {
271 let same_file_declaration_count =
272 index.file_declaration_count(checked_declaration.declaration);
273
274 let declaration = score_declaration(
275 &identifier,
276 &references,
277 checked_declaration,
278 same_file_declaration_count,
279 declaration_count,
280 &excerpt_occurrences,
281 &adjacent_occurrences,
282 &import_occurrences,
283 &wildcard_import_occurrences,
284 cursor_point,
285 current_buffer,
286 );
287
288 if declaration.components.import_similarity > max_import_similarity {
289 max_import_similarity = declaration.components.import_similarity;
290 }
291
292 if declaration.components.wildcard_import_similarity > max_wildcard_import_similarity {
293 max_wildcard_import_similarity = declaration.components.wildcard_import_similarity;
294 }
295
296 project_entry_id_to_outline_ranges
297 .entry(declaration.declaration.project_entry_id())
298 .or_default()
299 .push(declaration.declaration.item_range());
300 let score = declaration.score(DeclarationStyle::Declaration);
301 scored_declarations_for_identifier.push(declaration);
302
303 if score > max_score {
304 max_score = score;
305 }
306 }
307
308 if max_import_similarity > 0.0
309 || max_wildcard_import_similarity > 0.0
310 || options.prefilter_score_ratio > 0.0
311 {
312 for mut declaration in scored_declarations_for_identifier.into_iter() {
313 if max_import_similarity > 0.0 {
314 declaration.components.max_import_similarity = max_import_similarity;
315 declaration.components.normalized_import_similarity =
316 declaration.components.import_similarity / max_import_similarity;
317 }
318 if max_wildcard_import_similarity > 0.0 {
319 declaration.components.normalized_wildcard_import_similarity =
320 declaration.components.wildcard_import_similarity
321 / max_wildcard_import_similarity;
322 }
323 if options.prefilter_score_ratio <= 0.0
324 || declaration.score(DeclarationStyle::Declaration)
325 > max_score * options.prefilter_score_ratio
326 {
327 scored_declarations.push(declaration);
328 }
329 }
330 } else {
331 scored_declarations.extend(scored_declarations_for_identifier);
332 }
333 }
334
335 // TODO: Inform this via import / retrieval scores of outline items
336 // TODO: Consider using a sweepline
337 for scored_declaration in scored_declarations.iter_mut() {
338 let project_entry_id = scored_declaration.declaration.project_entry_id();
339 let Some(ranges) = project_entry_id_to_outline_ranges.get(&project_entry_id) else {
340 continue;
341 };
342 for range in ranges {
343 if range.contains_inclusive(&scored_declaration.declaration.item_range()) {
344 scored_declaration.components.included_by_others += 1
345 } else if scored_declaration
346 .declaration
347 .item_range()
348 .contains_inclusive(range)
349 {
350 scored_declaration.components.includes_others += 1
351 }
352 }
353 }
354
355 scored_declarations.sort_unstable_by_key(|declaration| {
356 Reverse(OrderedFloat(
357 declaration.score(DeclarationStyle::Declaration),
358 ))
359 });
360
361 scored_declarations
362}
363
/// A candidate declaration after the cheap path/proximity checks in [`scored_declarations`],
/// ready to be passed to [`score_declaration`].
struct CheckedDeclaration<'a> {
    /// The candidate declaration from the syntax index.
    declaration: &'a Declaration,
    /// Row distance between the cursor and the start of the declaration's item when the
    /// declaration is in the current buffer; `None` for declarations elsewhere.
    same_file_line_distance: Option<u32>,
    /// Number of exact (non-wildcard) import paths matching the declaration's path. Always 0
    /// for current-buffer declarations.
    path_import_match_count: usize,
    /// Number of exact wildcard-import paths matching the declaration's path. Always 0 for
    /// current-buffer declarations.
    wildcard_path_import_match_count: usize,
}
370
371fn declaration_path_matches_import(
372 declaration_path: &CachedDeclarationPath,
373 import_path: &Arc<Path>,
374) -> bool {
375 if import_path.is_absolute() {
376 declaration_path.equals_absolute_path(import_path)
377 } else {
378 declaration_path.ends_with_posix_path(import_path)
379 }
380}
381
/// Returns the overlapping part of `a` and `b`, or `None` when they share no elements.
///
/// Ranges that only touch (e.g. `0..5` and `5..10`) do not intersect, and neither does an empty
/// range with anything.
fn range_intersection<T: Ord + Clone>(a: &Range<T>, b: &Range<T>) -> Option<Range<T>> {
    let overlap_start = std::cmp::max(a.start.clone(), b.start.clone());
    let overlap_end = std::cmp::min(a.end.clone(), b.end.clone());
    (overlap_start < overlap_end).then_some(overlap_start..overlap_end)
}
391
392fn score_declaration(
393 identifier: &Identifier,
394 references: &[Reference],
395 checked_declaration: CheckedDeclaration,
396 same_file_declaration_count: usize,
397 declaration_count: usize,
398 excerpt_occurrences: &Occurrences,
399 adjacent_occurrences: &Occurrences,
400 import_occurrences: &[Occurrences],
401 wildcard_import_occurrences: &[Occurrences],
402 cursor: Point,
403 current_buffer: &BufferSnapshot,
404) -> ScoredDeclaration {
405 let CheckedDeclaration {
406 declaration,
407 same_file_line_distance,
408 path_import_match_count,
409 wildcard_path_import_match_count,
410 } = checked_declaration;
411
412 let is_referenced_nearby = references
413 .iter()
414 .any(|r| r.region == ReferenceRegion::Nearby);
415 let is_referenced_in_breadcrumb = references
416 .iter()
417 .any(|r| r.region == ReferenceRegion::Breadcrumb);
418 let reference_count = references.len();
419 let reference_line_distance = references
420 .iter()
421 .map(|r| {
422 let reference_line = r.range.start.to_point(current_buffer).row as i32;
423 (cursor.row as i32 - reference_line).unsigned_abs()
424 })
425 .min()
426 .unwrap();
427
428 let is_same_file = same_file_line_distance.is_some();
429 let declaration_line_distance = same_file_line_distance.unwrap_or(u32::MAX);
430
431 let item_source_occurrences = Occurrences::within_string(&declaration.item_text().0);
432 let item_signature_occurrences = Occurrences::within_string(&declaration.signature_text().0);
433 let excerpt_vs_item_jaccard = jaccard_similarity(excerpt_occurrences, &item_source_occurrences);
434 let excerpt_vs_signature_jaccard =
435 jaccard_similarity(excerpt_occurrences, &item_signature_occurrences);
436 let adjacent_vs_item_jaccard =
437 jaccard_similarity(adjacent_occurrences, &item_source_occurrences);
438 let adjacent_vs_signature_jaccard =
439 jaccard_similarity(adjacent_occurrences, &item_signature_occurrences);
440
441 let excerpt_vs_item_weighted_overlap =
442 weighted_overlap_coefficient(excerpt_occurrences, &item_source_occurrences);
443 let excerpt_vs_signature_weighted_overlap =
444 weighted_overlap_coefficient(excerpt_occurrences, &item_signature_occurrences);
445 let adjacent_vs_item_weighted_overlap =
446 weighted_overlap_coefficient(adjacent_occurrences, &item_source_occurrences);
447 let adjacent_vs_signature_weighted_overlap =
448 weighted_overlap_coefficient(adjacent_occurrences, &item_signature_occurrences);
449
450 let mut import_similarity = 0f32;
451 let mut wildcard_import_similarity = 0f32;
452 if !import_occurrences.is_empty() || !wildcard_import_occurrences.is_empty() {
453 let cached_path = declaration.cached_path();
454 let path_occurrences = Occurrences::from_worktree_path(
455 cached_path
456 .worktree_abs_path
457 .file_name()
458 .map(|f| f.to_string_lossy()),
459 &cached_path.rel_path,
460 );
461 import_similarity = import_occurrences
462 .iter()
463 .map(|namespace_occurrences| {
464 OrderedFloat(jaccard_similarity(namespace_occurrences, &path_occurrences))
465 })
466 .max()
467 .map(|similarity| similarity.into_inner())
468 .unwrap_or_default();
469
470 // TODO: Consider something other than max
471 wildcard_import_similarity = wildcard_import_occurrences
472 .iter()
473 .map(|namespace_occurrences| {
474 OrderedFloat(jaccard_similarity(namespace_occurrences, &path_occurrences))
475 })
476 .max()
477 .map(|similarity| similarity.into_inner())
478 .unwrap_or_default();
479 }
480
481 // TODO: Consider adding declaration_file_count
482 let score_components = DeclarationScoreComponents {
483 is_same_file,
484 is_referenced_nearby,
485 is_referenced_in_breadcrumb,
486 reference_line_distance,
487 declaration_line_distance,
488 reference_count,
489 same_file_declaration_count,
490 declaration_count,
491 excerpt_vs_item_jaccard,
492 excerpt_vs_signature_jaccard,
493 adjacent_vs_item_jaccard,
494 adjacent_vs_signature_jaccard,
495 excerpt_vs_item_weighted_overlap,
496 excerpt_vs_signature_weighted_overlap,
497 adjacent_vs_item_weighted_overlap,
498 adjacent_vs_signature_weighted_overlap,
499 path_import_match_count,
500 wildcard_path_import_match_count,
501 import_similarity,
502 max_import_similarity: 0.0,
503 normalized_import_similarity: 0.0,
504 wildcard_import_similarity,
505 normalized_wildcard_import_similarity: 0.0,
506 included_by_others: 0,
507 includes_others: 0,
508 };
509
510 ScoredDeclaration {
511 identifier: identifier.clone(),
512 declaration: declaration.clone(),
513 components: score_components,
514 }
515}
516
#[cfg(test)]
mod test {
    use super::*;

    /// Exercises both the relative (suffix) and absolute (exact) import path matching rules.
    #[test]
    fn test_declaration_path_matches() {
        let declaration_path =
            CachedDeclarationPath::new_for_test("/home/user/project", "src/maths.ts");
        let matches = |import_path: &str| {
            declaration_path_matches_import(&declaration_path, &Path::new(import_path).into())
        };

        // Relative imports match trailing portions of the declaration path.
        assert!(matches("maths.ts"));
        assert!(matches("project/src/maths.ts"));
        assert!(matches("user/project/src/maths.ts"));

        // Absolute imports must name the declaration's full path.
        assert!(matches("/home/user/project/src/maths.ts"));

        // A different file name never matches, relative or absolute.
        assert!(!matches("other.ts"));
        assert!(!matches("/home/user/project/src/other.ts"));
    }
}