1use crate::assemble_excerpts::assemble_excerpt_ranges;
2use anyhow::Result;
3use collections::HashMap;
4use futures::{FutureExt, StreamExt as _, channel::mpsc, future};
5use gpui::{App, AppContext, AsyncApp, Context, Entity, EntityId, EventEmitter, Task, WeakEntity};
6use language::{Anchor, Buffer, BufferSnapshot, OffsetRangeExt as _, Point, ToOffset as _};
7use project::{LocationLink, Project, ProjectPath};
8use smallvec::SmallVec;
9use std::{
10 collections::hash_map,
11 ops::Range,
12 path::Path,
13 sync::Arc,
14 time::{Duration, Instant},
15};
16use util::paths::PathStyle;
17use util::rel_path::RelPath;
18use util::{RangeExt as _, ResultExt};
19
20mod assemble_excerpts;
21#[cfg(test)]
22mod edit_prediction_context_tests;
23#[cfg(test)]
24mod fake_definition_lsp;
25
26pub use zeta_prompt::{RelatedExcerpt, RelatedFile};
27
/// Default number of lines above and below the cursor that are scanned for
/// identifiers when gathering definition candidates.
const IDENTIFIER_LINE_COUNT: u32 = 3;
29
/// Maintains excerpts from buffers related to the user's cursor position,
/// found by resolving LSP definitions of identifiers near the cursor.
/// Refreshes are debounced and definition lookups are cached per identifier.
pub struct RelatedExcerptStore {
    // Weak so the store doesn't keep the project alive.
    project: WeakEntity<Project>,
    // Buffers contributing excerpts, sorted by relevance (see `rebuild_related_files`).
    related_buffers: Vec<RelatedBuffer>,
    // Definition lookup results keyed by identifier occurrence.
    cache: HashMap<Identifier, Arc<CacheEntry>>,
    // Feeds (buffer, cursor position) pairs to the debounced refresh task.
    update_tx: mpsc::UnboundedSender<(Entity<Buffer>, Anchor)>,
    // How many lines above/below the cursor to scan for identifiers.
    identifier_line_count: u32,
}
37
/// One buffer contributing related excerpts, plus a per-version cache of the
/// rendered excerpt text.
struct RelatedBuffer {
    buffer: Entity<Buffer>,
    // Display path including the worktree root name (e.g. "root/src/lib.rs").
    path: Arc<Path>,
    // One anchor range per excerpt; kept parallel to `excerpt_orders`.
    anchor_ranges: Vec<Range<Anchor>>,
    // Presentation order of each excerpt; parallel to `anchor_ranges`.
    excerpt_orders: Vec<usize>,
    // Rendered excerpts for a specific buffer version; rebuilt after edits.
    cached_file: Option<CachedRelatedFile>,
}
45
/// Excerpt text snapshot, valid only while the buffer stays at `buffer_version`.
struct CachedRelatedFile {
    excerpts: Vec<RelatedExcerpt>,
    // Version the excerpts were rendered against; compared before reuse.
    buffer_version: clock::Global,
}
50
/// Lifecycle events emitted while refreshing related excerpts.
pub enum RelatedExcerptStoreEvent {
    /// A refresh pass has begun.
    StartedRefresh,
    /// A refresh pass has completed.
    FinishedRefresh {
        /// Identifiers whose definitions were served from the cache.
        cache_hit_count: usize,
        /// Identifiers that required fresh LSP definition requests.
        cache_miss_count: usize,
        /// Mean latency of the fresh definition requests (zero if none ran).
        mean_definition_latency: Duration,
        /// Longest latency among the fresh definition requests.
        max_definition_latency: Duration,
    },
}
60
/// A single identifier occurrence: its text plus the anchored range where it
/// appears. The same name at two locations forms two distinct keys.
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
struct Identifier {
    pub name: String,
    pub range: Range<Anchor>,
}
66
/// For one identifier, either a cached lookup result or in-flight LSP requests.
enum DefinitionTask {
    /// Definitions already resolved during a previous refresh.
    CacheHit(Arc<CacheEntry>),
    /// Freshly spawned go-to-definition and go-to-type-definition requests.
    CacheMiss {
        definitions: Task<Result<Option<Vec<LocationLink>>>>,
        type_definitions: Task<Result<Option<Vec<LocationLink>>>>,
    },
}
74
/// Resolved definition locations for a single identifier.
#[derive(Debug)]
struct CacheEntry {
    definitions: SmallVec<[CachedDefinition; 1]>,
    // Type definitions, excluding entries that duplicate one in `definitions`.
    type_definitions: SmallVec<[CachedDefinition; 1]>,
}
80
/// The location of one resolved definition target.
#[derive(Clone, Debug)]
struct CachedDefinition {
    path: ProjectPath,
    buffer: Entity<Buffer>,
    anchor_range: Range<Anchor>,
}
87
/// How long the refresh task waits after the most recent cursor update before
/// fetching definitions, so rapid cursor movement coalesces into one refresh.
const DEBOUNCE_DURATION: Duration = Duration::from_millis(100);
89
// Lets observers subscribe to the store's refresh lifecycle events.
impl EventEmitter<RelatedExcerptStoreEvent> for RelatedExcerptStore {}
91
92impl RelatedExcerptStore {
    /// Creates a store and spawns its long-lived refresh task. The task
    /// debounces incoming (buffer, position) updates and fetches excerpts for
    /// the most recent one once `DEBOUNCE_DURATION` passes without another.
    pub fn new(project: &Entity<Project>, cx: &mut Context<Self>) -> Self {
        let (update_tx, mut update_rx) = mpsc::unbounded::<(Entity<Buffer>, Anchor)>();
        cx.spawn(async move |this, cx| {
            let executor = cx.background_executor().clone();
            while let Some((mut buffer, mut position)) = update_rx.next().await {
                let mut timer = executor.timer(DEBOUNCE_DURATION).fuse();
                // Debounce: keep absorbing newer updates (restarting the timer
                // each time) until the timer fires with no newer update.
                loop {
                    futures::select_biased! {
                        next = update_rx.next() => {
                            if let Some((new_buffer, new_position)) = next {
                                buffer = new_buffer;
                                position = new_position;
                                timer = executor.timer(DEBOUNCE_DURATION).fuse();
                            } else {
                                // Sender dropped: the store is gone, stop the task.
                                return anyhow::Ok(());
                            }
                        }
                        _ = timer => break,
                    }
                }

                Self::fetch_excerpts(this.clone(), buffer, position, cx).await?;
            }
            anyhow::Ok(())
        })
        .detach_and_log_err(cx);

        RelatedExcerptStore {
            project: project.downgrade(),
            update_tx,
            related_buffers: Vec::new(),
            cache: Default::default(),
            identifier_line_count: IDENTIFIER_LINE_COUNT,
        }
    }
128
    /// Overrides how many lines above and below the cursor are scanned for
    /// identifiers on subsequent refreshes.
    pub fn set_identifier_line_count(&mut self, count: u32) {
        self.identifier_line_count = count;
    }
132
133 pub fn refresh(&mut self, buffer: Entity<Buffer>, position: Anchor, _: &mut Context<Self>) {
134 self.update_tx.unbounded_send((buffer, position)).ok();
135 }
136
137 pub fn related_files(&mut self, cx: &App) -> Vec<RelatedFile> {
138 self.related_buffers
139 .iter_mut()
140 .map(|related| related.related_file(cx))
141 .collect()
142 }
143
144 pub fn related_files_with_buffers(
145 &mut self,
146 cx: &App,
147 ) -> impl Iterator<Item = (RelatedFile, Entity<Buffer>)> {
148 self.related_buffers
149 .iter_mut()
150 .map(|related| (related.related_file(cx), related.buffer.clone()))
151 }
152
    /// Replaces the related buffers from externally supplied `RelatedFile`s
    /// (e.g. restored state), mapping each file path back to an open buffer.
    /// Files whose worktree or buffer cannot be found are silently dropped.
    pub fn set_related_files(&mut self, files: Vec<RelatedFile>, cx: &App) {
        self.related_buffers = files
            .into_iter()
            .filter_map(|file| {
                let project = self.project.upgrade()?;
                let project = project.read(cx);
                // `file.path` starts with the worktree root name; find the
                // worktree whose root matches the first path component.
                let worktree = project.worktrees(cx).find(|wt| {
                    let root_name = wt.read(cx).root_name().as_unix_str();
                    file.path
                        .components()
                        .next()
                        .is_some_and(|c| c.as_os_str() == root_name)
                })?;
                let worktree = worktree.read(cx);
                let relative_path = file
                    .path
                    .strip_prefix(worktree.root_name().as_unix_str())
                    .ok()?;
                let relative_path = RelPath::new(relative_path, PathStyle::Posix).ok()?;
                let project_path = ProjectPath {
                    worktree_id: worktree.id(),
                    path: relative_path.into_owned().into(),
                };
                // Only already-open buffers participate; this never opens one.
                let buffer = project.get_open_buffer(&project_path, cx)?;
                let snapshot = buffer.read(cx).snapshot();
                // Convert the stored row ranges back into anchors that span
                // whole lines (row start through end-of-line of the last row).
                let mut anchor_ranges = Vec::with_capacity(file.excerpts.len());
                let mut excerpt_orders = Vec::with_capacity(file.excerpts.len());
                for excerpt in &file.excerpts {
                    let start = snapshot.anchor_before(Point::new(excerpt.row_range.start, 0));
                    let end_row = excerpt.row_range.end;
                    let end_col = snapshot.line_len(end_row);
                    let end = snapshot.anchor_after(Point::new(end_row, end_col));
                    anchor_ranges.push(start..end);
                    excerpt_orders.push(excerpt.order);
                }
                Some(RelatedBuffer {
                    buffer,
                    path: file.path.clone(),
                    anchor_ranges,
                    excerpt_orders,
                    // Rendered lazily on the next `related_file` call.
                    cached_file: None,
                })
            })
            .collect();
    }
198
    /// Core refresh pipeline: gathers identifiers near `position`, resolves
    /// their definitions (reusing the cache where possible), then rebuilds
    /// `related_buffers` ranked by proximity to the cursor. Emits
    /// `StartedRefresh` up front and `FinishedRefresh` with stats at the end.
    async fn fetch_excerpts(
        this: WeakEntity<Self>,
        buffer: Entity<Buffer>,
        position: Anchor,
        cx: &mut AsyncApp,
    ) -> Result<()> {
        let (project, snapshot, identifier_line_count) = this.read_with(cx, |this, cx| {
            (
                this.project.upgrade(),
                buffer.read(cx).snapshot(),
                this.identifier_line_count,
            )
        })?;
        let Some(project) = project else {
            // Project was dropped; nothing to refresh.
            return Ok(());
        };

        let file = snapshot.file().cloned();
        if let Some(file) = &file {
            log::debug!("retrieving_context buffer:{}", file.path().as_unix_str());
        }

        this.update(cx, |_, cx| {
            cx.emit(RelatedExcerptStoreEvent::StartedRefresh);
        })?;

        // Done on the background executor: identifier extraction and ranking
        // only need the immutable snapshot.
        let identifiers_with_ranks = cx
            .background_spawn(async move {
                let cursor_offset = position.to_offset(&snapshot);
                let identifiers =
                    identifiers_for_position(&snapshot, position, identifier_line_count);

                // Compute byte distance from cursor to each identifier, then sort by
                // distance so we can assign ordinal ranks. Identifiers at the same
                // distance share the same rank.
                let mut identifiers_with_distance: Vec<(Identifier, usize)> = identifiers
                    .into_iter()
                    .map(|id| {
                        let start = id.range.start.to_offset(&snapshot);
                        let end = id.range.end.to_offset(&snapshot);
                        let distance = if cursor_offset < start {
                            start - cursor_offset
                        } else if cursor_offset > end {
                            cursor_offset - end
                        } else {
                            // Cursor lies inside the identifier itself.
                            0
                        };
                        (id, distance)
                    })
                    .collect();
                identifiers_with_distance.sort_by_key(|(_, distance)| *distance);

                let mut cursor_distances: HashMap<Identifier, usize> = HashMap::default();
                let mut current_rank = 0;
                let mut previous_distance = None;
                for (identifier, distance) in &identifiers_with_distance {
                    if previous_distance != Some(*distance) {
                        // New distance value: the rank jumps to the count of
                        // identifiers already ranked, so ties share a rank.
                        current_rank = cursor_distances.len();
                        previous_distance = Some(*distance);
                    }
                    cursor_distances.insert(identifier.clone(), current_rank);
                }

                (identifiers_with_distance, cursor_distances)
            })
            .await;

        let (identifiers_with_distance, cursor_distances) = identifiers_with_ranks;

        let async_cx = cx.clone();
        let start_time = Instant::now();
        // Build one future per identifier: cache hits resolve immediately,
        // misses await freshly spawned definition + type-definition requests.
        let futures = this.update(cx, |this, cx| {
            identifiers_with_distance
                .into_iter()
                .filter_map(|(identifier, _)| {
                    let task = if let Some(entry) = this.cache.get(&identifier) {
                        DefinitionTask::CacheHit(entry.clone())
                    } else {
                        let definitions = this
                            .project
                            .update(cx, |project, cx| {
                                project.definitions(&buffer, identifier.range.start, cx)
                            })
                            .ok()?;
                        let type_definitions = this
                            .project
                            .update(cx, |project, cx| {
                                project.type_definitions(&buffer, identifier.range.start, cx)
                            })
                            .ok()?;
                        DefinitionTask::CacheMiss {
                            definitions,
                            type_definitions,
                        }
                    };

                    let cx = async_cx.clone();
                    let project = project.clone();
                    Some(async move {
                        match task {
                            DefinitionTask::CacheHit(cache_entry) => {
                                // `None` duration marks this as a cache hit
                                // when stats are tallied below.
                                Some((identifier, cache_entry, None))
                            }
                            DefinitionTask::CacheMiss {
                                definitions,
                                type_definitions,
                            } => {
                                let (definition_locations, type_definition_locations) =
                                    futures::join!(definitions, type_definitions);
                                let duration = start_time.elapsed();

                                // Failed requests degrade to "no definitions"
                                // rather than aborting the whole refresh.
                                let definition_locations =
                                    definition_locations.log_err().flatten().unwrap_or_default();
                                let type_definition_locations = type_definition_locations
                                    .log_err()
                                    .flatten()
                                    .unwrap_or_default();

                                Some(cx.update(|cx| {
                                    let definitions: SmallVec<[CachedDefinition; 1]> =
                                        definition_locations
                                            .into_iter()
                                            .filter_map(|location| {
                                                process_definition(location, &project, cx)
                                            })
                                            .collect();

                                    // Drop type definitions that duplicate a
                                    // regular definition (same buffer + range).
                                    let type_definitions: SmallVec<[CachedDefinition; 1]> =
                                        type_definition_locations
                                            .into_iter()
                                            .filter_map(|location| {
                                                process_definition(location, &project, cx)
                                            })
                                            .filter(|type_def| {
                                                !definitions.iter().any(|def| {
                                                    def.buffer.entity_id()
                                                        == type_def.buffer.entity_id()
                                                        && def.anchor_range == type_def.anchor_range
                                                })
                                            })
                                            .collect();

                                    (
                                        identifier,
                                        Arc::new(CacheEntry {
                                            definitions,
                                            type_definitions,
                                        }),
                                        Some(duration),
                                    )
                                }))
                            }
                        }
                    })
                })
                .collect::<Vec<_>>()
        })?;

        // Await all lookups and tally hit/miss counts plus latency stats.
        let mut cache_hit_count = 0;
        let mut cache_miss_count = 0;
        let mut mean_definition_latency = Duration::ZERO;
        let mut max_definition_latency = Duration::ZERO;
        let mut new_cache = HashMap::default();
        new_cache.reserve(futures.len());
        for (identifier, entry, duration) in future::join_all(futures).await.into_iter().flatten() {
            new_cache.insert(identifier, entry);
            if let Some(duration) = duration {
                cache_miss_count += 1;
                mean_definition_latency += duration;
                max_definition_latency = max_definition_latency.max(duration);
            } else {
                cache_hit_count += 1;
            }
        }
        // Guard against division by zero when every lookup was a cache hit.
        mean_definition_latency /= cache_miss_count.max(1) as u32;

        let (new_cache, related_buffers) =
            rebuild_related_files(&project, new_cache, &cursor_distances, cx).await?;

        if let Some(file) = &file {
            log::debug!(
                "finished retrieving context buffer:{}, latency:{:?}",
                file.path().as_unix_str(),
                start_time.elapsed()
            );
        }

        // Replacing the cache wholesale also evicts entries for identifiers
        // that are no longer near the cursor.
        this.update(cx, |this, cx| {
            this.cache = new_cache;
            this.related_buffers = related_buffers;
            cx.emit(RelatedExcerptStoreEvent::FinishedRefresh {
                cache_hit_count,
                cache_miss_count,
                mean_definition_latency,
                max_definition_latency,
            });
        })?;

        anyhow::Ok(())
    }
399}
400
/// Groups all cached definitions by target buffer, assembles merged excerpt
/// ranges per buffer, and returns the (unchanged) cache together with
/// `RelatedBuffer`s sorted by the best identifier rank that referenced each
/// buffer (ties broken by path).
async fn rebuild_related_files(
    project: &Entity<Project>,
    mut new_entries: HashMap<Identifier, Arc<CacheEntry>>,
    cursor_distances: &HashMap<Identifier, usize>,
    cx: &mut AsyncApp,
) -> Result<(HashMap<Identifier, Arc<CacheEntry>>, Vec<RelatedBuffer>)> {
    // First pass (on the foreground): snapshot each definition's buffer after
    // waiting for parsing to settle, and record worktree root names, so the
    // heavy assembly below can run entirely on the background executor.
    let mut snapshots = HashMap::default();
    let mut worktree_root_names = HashMap::default();
    for entry in new_entries.values() {
        for definition in entry
            .definitions
            .iter()
            .chain(entry.type_definitions.iter())
        {
            if let hash_map::Entry::Vacant(e) = snapshots.entry(definition.buffer.entity_id()) {
                definition
                    .buffer
                    .read_with(cx, |buffer, _| buffer.parsing_idle())
                    .await;
                e.insert(
                    definition
                        .buffer
                        .read_with(cx, |buffer, _| buffer.snapshot()),
                );
            }
            let worktree_id = definition.path.worktree_id;
            if let hash_map::Entry::Vacant(e) =
                worktree_root_names.entry(definition.path.worktree_id)
            {
                project.read_with(cx, |project, cx| {
                    if let Some(worktree) = project.worktree_for_id(worktree_id, cx) {
                        e.insert(worktree.read(cx).root_name().as_unix_str().to_string());
                    }
                });
            }
        }
    }

    let cursor_distances = cursor_distances.clone();
    Ok(cx
        .background_spawn(async move {
            // Per target buffer, collect: its handle, its project path, the
            // point ranges of all definitions in it, and the best (lowest)
            // rank of any identifier that led to it.
            let mut ranges_by_buffer =
                HashMap::<EntityId, (Entity<Buffer>, Vec<(Range<Point>, usize)>)>::default();
            let mut paths_by_buffer = HashMap::default();
            let mut min_rank_by_buffer = HashMap::<EntityId, usize>::default();
            for (identifier, entry) in new_entries.iter_mut() {
                let rank = cursor_distances
                    .get(identifier)
                    .copied()
                    .unwrap_or(usize::MAX);
                for definition in entry
                    .definitions
                    .iter()
                    .chain(entry.type_definitions.iter())
                {
                    // Skip buffers whose snapshot wasn't captured above.
                    let Some(snapshot) = snapshots.get(&definition.buffer.entity_id()) else {
                        continue;
                    };
                    paths_by_buffer.insert(definition.buffer.entity_id(), definition.path.clone());

                    let buffer_rank = min_rank_by_buffer
                        .entry(definition.buffer.entity_id())
                        .or_insert(usize::MAX);
                    *buffer_rank = (*buffer_rank).min(rank);

                    ranges_by_buffer
                        .entry(definition.buffer.entity_id())
                        .or_insert_with(|| (definition.buffer.clone(), Vec::new()))
                        .1
                        .push((definition.anchor_range.to_point(snapshot), rank));
                }
            }

            let mut related_buffers: Vec<RelatedBuffer> = ranges_by_buffer
                .into_iter()
                .filter_map(|(entity_id, (buffer, ranges))| {
                    let snapshot = snapshots.get(&entity_id)?;
                    let project_path = paths_by_buffer.get(&entity_id)?;
                    let assembled = assemble_excerpt_ranges(snapshot, ranges);
                    let root_name = worktree_root_names.get(&project_path.worktree_id)?;

                    // Display path is "<worktree root>/<relative path>".
                    let path: Arc<Path> = Path::new(&format!(
                        "{}/{}",
                        root_name,
                        project_path.path.as_unix_str()
                    ))
                    .into();

                    // Anchor each assembled excerpt to whole lines: column 0
                    // of the first row through end-of-line of the last row.
                    let mut anchor_ranges = Vec::with_capacity(assembled.len());
                    let mut excerpt_orders = Vec::with_capacity(assembled.len());
                    for (row_range, order) in assembled {
                        let start = snapshot.anchor_before(Point::new(row_range.start, 0));
                        let end_col = snapshot.line_len(row_range.end);
                        let end = snapshot.anchor_after(Point::new(row_range.end, end_col));
                        anchor_ranges.push(start..end);
                        excerpt_orders.push(order);
                    }

                    let mut related_buffer = RelatedBuffer {
                        buffer,
                        path,
                        anchor_ranges,
                        excerpt_orders,
                        cached_file: None,
                    };
                    // Pre-render excerpt text while we already hold a snapshot.
                    related_buffer.fill_cache(snapshot);
                    Some(related_buffer)
                })
                .collect();

            // Closest-referenced buffers first; path tiebreak keeps the
            // ordering deterministic across refreshes.
            related_buffers.sort_by(|a, b| {
                let rank_a = min_rank_by_buffer
                    .get(&a.buffer.entity_id())
                    .copied()
                    .unwrap_or(usize::MAX);
                let rank_b = min_rank_by_buffer
                    .get(&b.buffer.entity_id())
                    .copied()
                    .unwrap_or(usize::MAX);
                rank_a.cmp(&rank_b).then_with(|| a.path.cmp(&b.path))
            });

            (new_entries, related_buffers)
        })
        .await)
}
527
528impl RelatedBuffer {
529 fn related_file(&mut self, cx: &App) -> RelatedFile {
530 let buffer = self.buffer.read(cx);
531 let path = self.path.clone();
532 let cached = if let Some(cached) = &self.cached_file
533 && buffer.version() == cached.buffer_version
534 {
535 cached
536 } else {
537 self.fill_cache(buffer)
538 };
539 let related_file = RelatedFile {
540 path,
541 excerpts: cached.excerpts.clone(),
542 max_row: buffer.max_point().row,
543 in_open_source_repo: false,
544 };
545 return related_file;
546 }
547
548 fn fill_cache(&mut self, buffer: &text::BufferSnapshot) -> &CachedRelatedFile {
549 let excerpts = self
550 .anchor_ranges
551 .iter()
552 .zip(self.excerpt_orders.iter())
553 .map(|(range, &order)| {
554 let start = range.start.to_point(buffer);
555 let end = range.end.to_point(buffer);
556 RelatedExcerpt {
557 row_range: start.row..end.row,
558 text: buffer.text_for_range(start..end).collect::<String>().into(),
559 order,
560 }
561 })
562 .collect::<Vec<_>>();
563 self.cached_file = Some(CachedRelatedFile {
564 excerpts: excerpts,
565 buffer_version: buffer.version().clone(),
566 });
567 self.cached_file.as_ref().unwrap()
568 }
569}
570
571use language::ToPoint as _;
572
/// Definitions whose target range exceeds this many bytes are discarded —
/// a large target usually means the "definition" resolved to a whole module
/// rather than an individual symbol.
const MAX_TARGET_LEN: usize = 128;
574
575fn process_definition(
576 location: LocationLink,
577 project: &Entity<Project>,
578 cx: &mut App,
579) -> Option<CachedDefinition> {
580 let buffer = location.target.buffer.read(cx);
581 let anchor_range = location.target.range;
582 let file = buffer.file()?;
583 let worktree = project.read(cx).worktree_for_id(file.worktree_id(cx), cx)?;
584 if worktree.read(cx).is_single_file() {
585 return None;
586 }
587
588 // If the target range is large, it likely means we requested the definition of an entire module.
589 // For individual definitions, the target range should be small as it only covers the symbol.
590 let buffer = location.target.buffer.read(cx);
591 let target_len = anchor_range.to_offset(&buffer).len();
592 if target_len > MAX_TARGET_LEN {
593 return None;
594 }
595
596 Some(CachedDefinition {
597 path: ProjectPath {
598 worktree_id: file.worktree_id(cx),
599 path: file.path().clone(),
600 },
601 buffer: location.target.buffer,
602 anchor_range,
603 })
604}
605
/// Gets all of the identifiers that are present in the given line, and its containing
/// outline items.
///
/// Scans `identifier_line_count` lines on either side of the cursor plus the
/// header/signature portion of every outline item containing the cursor, and
/// extracts identifier nodes via the grammar's highlight query.
fn identifiers_for_position(
    buffer: &BufferSnapshot,
    position: Anchor,
    identifier_line_count: u32,
) -> Vec<Identifier> {
    let offset = position.to_offset(buffer);
    let point = buffer.offset_to_point(offset);

    // Search for identifiers on lines adjacent to the cursor.
    let start = Point::new(point.row.saturating_sub(identifier_line_count), 0);
    let end = Point::new(point.row + identifier_line_count + 1, 0).min(buffer.max_point());
    let line_range = start..end;
    let mut ranges = vec![line_range.to_offset(&buffer)];

    // Search for identifiers mentioned in headers/signatures of containing outline items.
    let outline_items = buffer.outline_items_as_offsets_containing(offset..offset, false, None);
    for item in outline_items {
        if let Some(body_range) = item.body_range(&buffer) {
            // Header only: from the item's start up to where its body begins.
            ranges.push(item.range.start..body_range.start.to_offset(&buffer));
        } else {
            ranges.push(item.range.clone());
        }
    }

    // Sort by start (ties: larger range first), then merge overlapping or
    // touching ranges so each byte region is scanned at most once.
    ranges.sort_by(|a, b| a.start.cmp(&b.start).then(b.end.cmp(&a.end)));
    ranges.dedup_by(|a, b| {
        if a.start <= b.end {
            b.start = b.start.min(a.start);
            b.end = b.end.max(a.end);
            true
        } else {
            false
        }
    });

    let mut identifiers = Vec::new();
    let outer_range =
        ranges.first().map_or(0, |r| r.start)..ranges.last().map_or(buffer.len(), |r| r.end);

    // One capture iterator over the full span; re-seeked per merged range below.
    let mut captures = buffer.captures(outer_range.clone(), |grammar| {
        grammar
            .highlights_config
            .as_ref()
            .map(|config| &config.query)
    });

    for range in ranges {
        captures.set_byte_range(range.start..outer_range.end);

        // Tracks the previously accepted node range so duplicate captures of
        // the same node (one node can match several patterns) are skipped.
        let mut last_range = None;
        while let Some(capture) = captures.peek() {
            let node_range = capture.node.byte_range();
            // Past the end of this merged range; move on to the next one.
            if node_range.start > range.end {
                break;
            }
            let config = captures.grammars()[capture.grammar_index]
                .highlights_config
                .as_ref();

            // Accept only captures the grammar marks as identifiers, fully
            // contained in the current range, and not just-seen.
            if let Some(config) = config
                && config.identifier_capture_indices.contains(&capture.index)
                && range.contains_inclusive(&node_range)
                && Some(&node_range) != last_range.as_ref()
            {
                let name = buffer.text_for_range(node_range.clone()).collect();
                identifiers.push(Identifier {
                    range: buffer.anchor_after(node_range.start)
                        ..buffer.anchor_before(node_range.end),
                    name,
                });
                last_range = Some(node_range);
            }

            captures.advance();
        }
    }

    identifiers
}