1use crate::assemble_excerpts::assemble_excerpt_ranges;
2use anyhow::Result;
3use collections::HashMap;
4use futures::{FutureExt, StreamExt as _, channel::mpsc, future};
5use gpui::{App, AppContext, AsyncApp, Context, Entity, EntityId, EventEmitter, Task, WeakEntity};
6use language::{Anchor, Buffer, BufferSnapshot, OffsetRangeExt as _, Point, ToOffset as _};
7use project::{LocationLink, Project, ProjectPath};
8use smallvec::SmallVec;
9use std::{
10 collections::hash_map,
11 ops::Range,
12 path::Path,
13 sync::Arc,
14 time::{Duration, Instant},
15};
16use util::paths::PathStyle;
17use util::rel_path::RelPath;
18use util::{RangeExt as _, ResultExt};
19
20mod assemble_excerpts;
21#[cfg(test)]
22mod edit_prediction_context_tests;
23#[cfg(test)]
24mod fake_definition_lsp;
25
26pub use zeta_prompt::{RelatedExcerpt, RelatedFile};
27
/// Default number of lines above/below the cursor to scan for identifiers.
const IDENTIFIER_LINE_COUNT: u32 = 3;
29
/// Maintains excerpts from other buffers that are related to the user's
/// current cursor position, by resolving LSP definitions and type definitions
/// for identifiers near the cursor. Lookup results are cached per identifier.
pub struct RelatedExcerptStore {
    project: WeakEntity<Project>,
    /// Buffers containing related definitions, ordered by relevance.
    related_buffers: Vec<RelatedBuffer>,
    /// Definition lookup results keyed by identifier occurrence; replaced
    /// wholesale on each refresh with entries for the current identifiers.
    cache: HashMap<Identifier, Arc<CacheEntry>>,
    /// Feeds (buffer, cursor position) pairs to the debounced refresh loop.
    update_tx: mpsc::UnboundedSender<(Entity<Buffer>, Anchor)>,
    /// How many lines around the cursor to scan for identifiers.
    identifier_line_count: u32,
}
37
/// A buffer containing related definitions, plus the excerpt ranges to
/// extract from it.
struct RelatedBuffer {
    buffer: Entity<Buffer>,
    /// Full path including the worktree root name (e.g. `root/src/lib.rs`).
    path: Arc<Path>,
    /// Excerpt ranges, stored as anchors so they track buffer edits.
    anchor_ranges: Vec<Range<Anchor>>,
    /// Per-excerpt ordering values, parallel to `anchor_ranges`.
    excerpt_orders: Vec<usize>,
    /// Rendered excerpt text, invalidated when the buffer version changes.
    cached_file: Option<CachedRelatedFile>,
}
45
/// Rendered excerpts for a buffer, valid only while the buffer stays at
/// `buffer_version`.
struct CachedRelatedFile {
    excerpts: Vec<RelatedExcerpt>,
    /// Buffer version at the time the excerpts were rendered.
    buffer_version: clock::Global,
}
50
/// Events emitted around a refresh of the related-excerpt set.
pub enum RelatedExcerptStoreEvent {
    /// A refresh began for a new cursor position.
    StartedRefresh,
    /// A refresh finished, with statistics about the definition lookups.
    FinishedRefresh {
        cache_hit_count: usize,
        cache_miss_count: usize,
        /// Mean latency across cache misses (zero when there were none).
        mean_definition_latency: Duration,
        max_definition_latency: Duration,
    },
}
60
/// An identifier occurrence in the source buffer; used as the cache key for
/// definition lookups.
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
struct Identifier {
    pub name: String,
    /// Range of this occurrence in the buffer.
    pub range: Range<Anchor>,
}
66
/// Either a previously-cached lookup result, or an in-flight task that
/// resolves to the pair of LSP requests (definitions, type definitions)
/// issued on a cache miss.
enum DefinitionTask {
    CacheHit(Arc<CacheEntry>),
    CacheMiss(
        Task<
            Option<(
                Task<Result<Option<Vec<LocationLink>>>>,
                Task<Result<Option<Vec<LocationLink>>>>,
            )>,
        >,
    ),
}
78
/// Definition lookup results for a single identifier.
#[derive(Debug)]
struct CacheEntry {
    definitions: SmallVec<[CachedDefinition; 1]>,
    /// Type definitions, excluding any that duplicate an entry in `definitions`.
    type_definitions: SmallVec<[CachedDefinition; 1]>,
}
84
/// The resolved location of a single definition target.
#[derive(Clone, Debug)]
struct CachedDefinition {
    path: ProjectPath,
    buffer: Entity<Buffer>,
    /// Range of the definition target within `buffer`.
    anchor_range: Range<Anchor>,
}
91
/// Delay after the last cursor movement before definitions are fetched.
const DEBOUNCE_DURATION: Duration = Duration::from_millis(100);
93
// Lets subscribers observe refresh lifecycle events via gpui's event system.
impl EventEmitter<RelatedExcerptStoreEvent> for RelatedExcerptStore {}
95
96impl RelatedExcerptStore {
    /// Creates a store and spawns a background task that debounces refresh
    /// requests before fetching definition-based excerpts.
    pub fn new(project: &Entity<Project>, cx: &mut Context<Self>) -> Self {
        let (update_tx, mut update_rx) = mpsc::unbounded::<(Entity<Buffer>, Anchor)>();
        cx.spawn(async move |this, cx| {
            let executor = cx.background_executor().clone();
            // Debounce loop: each new update replaces the pending one and
            // restarts the timer, so only the most recent (buffer, position)
            // pair is fetched once input goes quiet for DEBOUNCE_DURATION.
            while let Some((mut buffer, mut position)) = update_rx.next().await {
                let mut timer = executor.timer(DEBOUNCE_DURATION).fuse();
                loop {
                    futures::select_biased! {
                        next = update_rx.next() => {
                            if let Some((new_buffer, new_position)) = next {
                                buffer = new_buffer;
                                position = new_position;
                                // Restart the debounce window.
                                timer = executor.timer(DEBOUNCE_DURATION).fuse();
                            } else {
                                // Sender dropped; shut the loop down cleanly.
                                return anyhow::Ok(());
                            }
                        }
                        _ = timer => break,
                    }
                }

                // NOTE(review): an error here terminates the refresh loop for
                // the lifetime of the store (it is only logged, via
                // `detach_and_log_err`) — confirm that is intended.
                Self::fetch_excerpts(this.clone(), buffer, position, cx).await?;
            }
            anyhow::Ok(())
        })
        .detach_and_log_err(cx);

        RelatedExcerptStore {
            project: project.downgrade(),
            update_tx,
            related_buffers: Vec::new(),
            cache: Default::default(),
            identifier_line_count: IDENTIFIER_LINE_COUNT,
        }
    }
132
    /// Overrides how many lines around the cursor are scanned for identifiers.
    pub fn set_identifier_line_count(&mut self, count: u32) {
        self.identifier_line_count = count;
    }
136
    /// Requests a (debounced) refresh of related excerpts for the given cursor
    /// position. Send errors are deliberately ignored: if the background task
    /// has shut down there is nothing useful to do.
    pub fn refresh(&mut self, buffer: Entity<Buffer>, position: Anchor, _: &mut Context<Self>) {
        self.update_tx.unbounded_send((buffer, position)).ok();
    }
140
141 pub fn related_files(&mut self, cx: &App) -> Vec<RelatedFile> {
142 self.related_buffers
143 .iter_mut()
144 .map(|related| related.related_file(cx))
145 .collect()
146 }
147
148 pub fn related_files_with_buffers(
149 &mut self,
150 cx: &App,
151 ) -> impl Iterator<Item = (RelatedFile, Entity<Buffer>)> {
152 self.related_buffers
153 .iter_mut()
154 .map(|related| (related.related_file(cx), related.buffer.clone()))
155 }
156
    /// Replaces the related buffers from a previously-produced list of
    /// `RelatedFile`s, resolving each file's path back to an open buffer.
    ///
    /// Entries whose worktree cannot be found, whose path fails to resolve, or
    /// whose buffer is not currently open in the project are silently dropped.
    pub fn set_related_files(&mut self, files: Vec<RelatedFile>, cx: &App) {
        self.related_buffers = files
            .into_iter()
            .filter_map(|file| {
                let project = self.project.upgrade()?;
                let project = project.read(cx);
                // The first component of `file.path` is a worktree root name;
                // find the worktree it names.
                let worktree = project.worktrees(cx).find(|wt| {
                    let root_name = wt.read(cx).root_name().as_unix_str();
                    file.path
                        .components()
                        .next()
                        .is_some_and(|c| c.as_os_str() == root_name)
                })?;
                let worktree = worktree.read(cx);
                // Strip the root-name prefix to get the project-relative path.
                let relative_path = file
                    .path
                    .strip_prefix(worktree.root_name().as_unix_str())
                    .ok()?;
                let relative_path = RelPath::new(relative_path, PathStyle::Posix).ok()?;
                let project_path = ProjectPath {
                    worktree_id: worktree.id(),
                    path: relative_path.into_owned().into(),
                };
                let buffer = project.get_open_buffer(&project_path, cx)?;
                let snapshot = buffer.read(cx).snapshot();
                // Convert the row ranges back into anchors (start of first
                // row through end of last row) so the excerpts keep tracking
                // the buffer across edits.
                let mut anchor_ranges = Vec::with_capacity(file.excerpts.len());
                let mut excerpt_orders = Vec::with_capacity(file.excerpts.len());
                for excerpt in &file.excerpts {
                    let start = snapshot.anchor_before(Point::new(excerpt.row_range.start, 0));
                    let end_row = excerpt.row_range.end;
                    let end_col = snapshot.line_len(end_row);
                    let end = snapshot.anchor_after(Point::new(end_row, end_col));
                    anchor_ranges.push(start..end);
                    excerpt_orders.push(excerpt.order);
                }
                Some(RelatedBuffer {
                    buffer,
                    path: file.path.clone(),
                    anchor_ranges,
                    excerpt_orders,
                    // Leave the rendered-text cache empty; it is filled lazily
                    // on the next `related_file` call.
                    cached_file: None,
                })
            })
            .collect();
    }
202
    /// Performs one refresh for the given cursor position: collects nearby
    /// identifiers, resolves their definitions and type definitions (reusing
    /// the previous refresh's cache where possible), then rebuilds the cache
    /// and `related_buffers`, emitting `StartedRefresh`/`FinishedRefresh`.
    async fn fetch_excerpts(
        this: WeakEntity<Self>,
        buffer: Entity<Buffer>,
        position: Anchor,
        cx: &mut AsyncApp,
    ) -> Result<()> {
        let (project, snapshot, identifier_line_count) = this.read_with(cx, |this, cx| {
            (
                this.project.upgrade(),
                buffer.read(cx).snapshot(),
                this.identifier_line_count,
            )
        })?;
        let Some(project) = project else {
            // Project was dropped; nothing to refresh.
            return Ok(());
        };

        let file = snapshot.file().cloned();
        if let Some(file) = &file {
            log::debug!("retrieving_context buffer:{}", file.path().as_unix_str());
        }

        this.update(cx, |_, cx| {
            cx.emit(RelatedExcerptStoreEvent::StartedRefresh);
        })?;

        let identifiers_with_ranks = cx
            .background_spawn(async move {
                let cursor_offset = position.to_offset(&snapshot);
                let identifiers =
                    identifiers_for_position(&snapshot, position, identifier_line_count);

                // Compute byte distance from cursor to each identifier, then sort by
                // distance so we can assign ordinal ranks. Identifiers at the same
                // distance share the same rank.
                let mut identifiers_with_distance: Vec<(Identifier, usize)> = identifiers
                    .into_iter()
                    .map(|id| {
                        let start = id.range.start.to_offset(&snapshot);
                        let end = id.range.end.to_offset(&snapshot);
                        // Distance is zero when the cursor is inside the
                        // identifier's range.
                        let distance = if cursor_offset < start {
                            start - cursor_offset
                        } else if cursor_offset > end {
                            cursor_offset - end
                        } else {
                            0
                        };
                        (id, distance)
                    })
                    .collect();
                identifiers_with_distance.sort_by_key(|(_, distance)| *distance);

                let mut cursor_distances: HashMap<Identifier, usize> = HashMap::default();
                let mut current_rank = 0;
                let mut previous_distance = None;
                for (identifier, distance) in &identifiers_with_distance {
                    // A new distance value starts a new rank; equal distances
                    // share the previous rank.
                    if previous_distance != Some(*distance) {
                        current_rank = cursor_distances.len();
                        previous_distance = Some(*distance);
                    }
                    cursor_distances.insert(identifier.clone(), current_rank);
                }

                (identifiers_with_distance, cursor_distances)
            })
            .await;

        let (identifiers_with_distance, cursor_distances) = identifiers_with_ranks;

        let async_cx = cx.clone();
        let start_time = Instant::now();
        // Build one future per identifier. Cache hits resolve immediately;
        // misses issue the two LSP requests and post-process the locations.
        let futures = this.update(cx, |this, cx| {
            identifiers_with_distance
                .into_iter()
                .map(|(identifier, _)| {
                    let task = if let Some(entry) = this.cache.get(&identifier) {
                        DefinitionTask::CacheHit(entry.clone())
                    } else {
                        let project = this.project.clone();
                        let buffer = buffer.downgrade();
                        DefinitionTask::CacheMiss(cx.spawn(async move |_, cx| {
                            let buffer = buffer.upgrade()?;
                            let definitions = project
                                .update(cx, |project, cx| {
                                    project.definitions(&buffer, identifier.range.start, cx)
                                })
                                .ok()?;
                            let type_definitions = project
                                .update(cx, |project, cx| {
                                    project.type_definitions(&buffer, identifier.range.start, cx)
                                })
                                .ok()?;
                            Some((definitions, type_definitions))
                        }))
                    };

                    let cx = async_cx.clone();
                    let project = project.clone();
                    async move {
                        match task {
                            DefinitionTask::CacheHit(cache_entry) => {
                                // `None` duration marks this as a cache hit
                                // for the statistics below.
                                Some((identifier, cache_entry, None))
                            }
                            DefinitionTask::CacheMiss(task) => {
                                let (definitions, type_definitions) = task.await?;
                                // Await both LSP requests concurrently.
                                let (definition_locations, type_definition_locations) =
                                    futures::join!(definitions, type_definitions);
                                let duration = start_time.elapsed();

                                // Failed requests degrade to empty results.
                                let definition_locations =
                                    definition_locations.log_err().flatten().unwrap_or_default();
                                let type_definition_locations = type_definition_locations
                                    .log_err()
                                    .flatten()
                                    .unwrap_or_default();

                                Some(cx.update(|cx| {
                                    let definitions: SmallVec<[CachedDefinition; 1]> =
                                        definition_locations
                                            .into_iter()
                                            .filter_map(|location| {
                                                process_definition(location, &project, cx)
                                            })
                                            .collect();

                                    // Drop type definitions that duplicate a
                                    // plain definition.
                                    let type_definitions: SmallVec<[CachedDefinition; 1]> =
                                        type_definition_locations
                                            .into_iter()
                                            .filter_map(|location| {
                                                process_definition(location, &project, cx)
                                            })
                                            .filter(|type_def| {
                                                !definitions.iter().any(|def| {
                                                    def.buffer.entity_id()
                                                        == type_def.buffer.entity_id()
                                                        && def.anchor_range == type_def.anchor_range
                                                })
                                            })
                                            .collect();

                                    (
                                        identifier,
                                        Arc::new(CacheEntry {
                                            definitions,
                                            type_definitions,
                                        }),
                                        Some(duration),
                                    )
                                }))
                            }
                        }
                    }
                })
                .collect::<Vec<_>>()
        })?;

        // Collect results and accumulate hit/miss + latency statistics.
        let mut cache_hit_count = 0;
        let mut cache_miss_count = 0;
        let mut mean_definition_latency = Duration::ZERO;
        let mut max_definition_latency = Duration::ZERO;
        let mut new_cache = HashMap::default();
        new_cache.reserve(futures.len());
        for (identifier, entry, duration) in future::join_all(futures).await.into_iter().flatten() {
            new_cache.insert(identifier, entry);
            if let Some(duration) = duration {
                cache_miss_count += 1;
                mean_definition_latency += duration;
                max_definition_latency = max_definition_latency.max(duration);
            } else {
                cache_hit_count += 1;
            }
        }
        // Average over misses; `max(1)` guards against division by zero.
        mean_definition_latency /= cache_miss_count.max(1) as u32;

        let (new_cache, related_buffers) =
            rebuild_related_files(&project, new_cache, &cursor_distances, cx).await?;

        if let Some(file) = &file {
            log::debug!(
                "finished retrieving context buffer:{}, latency:{:?}",
                file.path().as_unix_str(),
                start_time.elapsed()
            );
        }

        this.update(cx, |this, cx| {
            this.cache = new_cache;
            this.related_buffers = related_buffers;
            cx.emit(RelatedExcerptStoreEvent::FinishedRefresh {
                cache_hit_count,
                cache_miss_count,
                mean_definition_latency,
                max_definition_latency,
            });
        })?;

        anyhow::Ok(())
    }
401}
402
/// Rebuilds the ordered list of `RelatedBuffer`s from a freshly-collected
/// definition cache.
///
/// On the foreground, waits for each definition buffer to finish parsing and
/// captures its snapshot plus its worktree's root name; then moves to a
/// background task to group definition ranges by buffer, assemble them into
/// excerpts, and sort the buffers by proximity of their identifiers to the
/// cursor. Returns the cache back along with the rebuilt buffers.
async fn rebuild_related_files(
    project: &Entity<Project>,
    mut new_entries: HashMap<Identifier, Arc<CacheEntry>>,
    cursor_distances: &HashMap<Identifier, usize>,
    cx: &mut AsyncApp,
) -> Result<(HashMap<Identifier, Arc<CacheEntry>>, Vec<RelatedBuffer>)> {
    let mut snapshots = HashMap::default();
    let mut worktree_root_names = HashMap::default();
    for entry in new_entries.values() {
        for definition in entry
            .definitions
            .iter()
            .chain(entry.type_definitions.iter())
        {
            // Snapshot each buffer once, after its parsing settles, so syntax
            // data in the snapshot is up to date.
            if let hash_map::Entry::Vacant(e) = snapshots.entry(definition.buffer.entity_id()) {
                definition
                    .buffer
                    .read_with(cx, |buffer, _| buffer.parsing_idle())
                    .await;
                e.insert(
                    definition
                        .buffer
                        .read_with(cx, |buffer, _| buffer.snapshot()),
                );
            }
            // Resolve each worktree's root name once; unknown worktrees are
            // simply left out of the map (their buffers get filtered later).
            let worktree_id = definition.path.worktree_id;
            if let hash_map::Entry::Vacant(e) =
                worktree_root_names.entry(definition.path.worktree_id)
            {
                project.read_with(cx, |project, cx| {
                    if let Some(worktree) = project.worktree_for_id(worktree_id, cx) {
                        e.insert(worktree.read(cx).root_name().as_unix_str().to_string());
                    }
                });
            }
        }
    }

    let cursor_distances = cursor_distances.clone();
    Ok(cx
        .background_spawn(async move {
            // Group definition point-ranges by target buffer, tracking each
            // buffer's best (minimum) cursor-distance rank for sorting below.
            let mut ranges_by_buffer =
                HashMap::<EntityId, (Entity<Buffer>, Vec<(Range<Point>, usize)>)>::default();
            let mut paths_by_buffer = HashMap::default();
            let mut min_rank_by_buffer = HashMap::<EntityId, usize>::default();
            for (identifier, entry) in new_entries.iter_mut() {
                let rank = cursor_distances
                    .get(identifier)
                    .copied()
                    .unwrap_or(usize::MAX);
                for definition in entry
                    .definitions
                    .iter()
                    .chain(entry.type_definitions.iter())
                {
                    let Some(snapshot) = snapshots.get(&definition.buffer.entity_id()) else {
                        continue;
                    };
                    paths_by_buffer.insert(definition.buffer.entity_id(), definition.path.clone());

                    let buffer_rank = min_rank_by_buffer
                        .entry(definition.buffer.entity_id())
                        .or_insert(usize::MAX);
                    *buffer_rank = (*buffer_rank).min(rank);

                    ranges_by_buffer
                        .entry(definition.buffer.entity_id())
                        .or_insert_with(|| (definition.buffer.clone(), Vec::new()))
                        .1
                        .push((definition.anchor_range.to_point(snapshot), rank));
                }
            }

            let mut related_buffers: Vec<RelatedBuffer> = ranges_by_buffer
                .into_iter()
                .filter_map(|(entity_id, (buffer, ranges))| {
                    let snapshot = snapshots.get(&entity_id)?;
                    let project_path = paths_by_buffer.get(&entity_id)?;
                    // Merge/expand the raw ranges into display-ready excerpts.
                    let assembled = assemble_excerpt_ranges(snapshot, ranges);
                    let root_name = worktree_root_names.get(&project_path.worktree_id)?;

                    // Prefix the project-relative path with the worktree root
                    // name, producing the same shape `set_related_files` parses.
                    let path: Arc<Path> = Path::new(&format!(
                        "{}/{}",
                        root_name,
                        project_path.path.as_unix_str()
                    ))
                    .into();

                    // Pin each excerpt with anchors spanning whole lines.
                    let mut anchor_ranges = Vec::with_capacity(assembled.len());
                    let mut excerpt_orders = Vec::with_capacity(assembled.len());
                    for (row_range, order) in assembled {
                        let start = snapshot.anchor_before(Point::new(row_range.start, 0));
                        let end_col = snapshot.line_len(row_range.end);
                        let end = snapshot.anchor_after(Point::new(row_range.end, end_col));
                        anchor_ranges.push(start..end);
                        excerpt_orders.push(order);
                    }

                    let mut related_buffer = RelatedBuffer {
                        buffer,
                        path,
                        anchor_ranges,
                        excerpt_orders,
                        cached_file: None,
                    };
                    // Render the excerpt text eagerly while the snapshot is
                    // already at hand.
                    related_buffer.fill_cache(snapshot);
                    Some(related_buffer)
                })
                .collect();

            // Closest-to-cursor buffers first; ties broken by path so the
            // ordering is deterministic.
            related_buffers.sort_by(|a, b| {
                let rank_a = min_rank_by_buffer
                    .get(&a.buffer.entity_id())
                    .copied()
                    .unwrap_or(usize::MAX);
                let rank_b = min_rank_by_buffer
                    .get(&b.buffer.entity_id())
                    .copied()
                    .unwrap_or(usize::MAX);
                rank_a.cmp(&rank_b).then_with(|| a.path.cmp(&b.path))
            });

            (new_entries, related_buffers)
        })
        .await)
}
529
530impl RelatedBuffer {
531 fn related_file(&mut self, cx: &App) -> RelatedFile {
532 let buffer = self.buffer.read(cx);
533 let path = self.path.clone();
534 let cached = if let Some(cached) = &self.cached_file
535 && buffer.version() == cached.buffer_version
536 {
537 cached
538 } else {
539 self.fill_cache(buffer)
540 };
541 let related_file = RelatedFile {
542 path,
543 excerpts: cached.excerpts.clone(),
544 max_row: buffer.max_point().row,
545 in_open_source_repo: false,
546 };
547 return related_file;
548 }
549
550 fn fill_cache(&mut self, buffer: &text::BufferSnapshot) -> &CachedRelatedFile {
551 let excerpts = self
552 .anchor_ranges
553 .iter()
554 .zip(self.excerpt_orders.iter())
555 .map(|(range, &order)| {
556 let start = range.start.to_point(buffer);
557 let end = range.end.to_point(buffer);
558 RelatedExcerpt {
559 row_range: start.row..end.row,
560 text: buffer.text_for_range(start..end).collect::<String>().into(),
561 order,
562 }
563 })
564 .collect::<Vec<_>>();
565 self.cached_file = Some(CachedRelatedFile {
566 excerpts,
567 buffer_version: buffer.version().clone(),
568 });
569 self.cached_file.as_ref().unwrap()
570 }
571}
572
573use language::ToPoint as _;
574
/// Maximum definition target size in bytes; larger targets are assumed to be
/// whole modules rather than single symbols and are skipped.
const MAX_TARGET_LEN: usize = 128;
576
577fn process_definition(
578 location: LocationLink,
579 project: &Entity<Project>,
580 cx: &mut App,
581) -> Option<CachedDefinition> {
582 let buffer = location.target.buffer.read(cx);
583 let anchor_range = location.target.range;
584 let file = buffer.file()?;
585 let worktree = project.read(cx).worktree_for_id(file.worktree_id(cx), cx)?;
586 if worktree.read(cx).is_single_file() {
587 return None;
588 }
589
590 // If the target range is large, it likely means we requested the definition of an entire module.
591 // For individual definitions, the target range should be small as it only covers the symbol.
592 let buffer = location.target.buffer.read(cx);
593 let target_len = anchor_range.to_offset(&buffer).len();
594 if target_len > MAX_TARGET_LEN {
595 return None;
596 }
597
598 Some(CachedDefinition {
599 path: ProjectPath {
600 worktree_id: file.worktree_id(cx),
601 path: file.path().clone(),
602 },
603 buffer: location.target.buffer,
604 anchor_range,
605 })
606}
607
/// Gets all of the identifiers that are present in the lines around the given
/// position, and in the headers/signatures of its containing outline items.
fn identifiers_for_position(
    buffer: &BufferSnapshot,
    position: Anchor,
    identifier_line_count: u32,
) -> Vec<Identifier> {
    let offset = position.to_offset(buffer);
    let point = buffer.offset_to_point(offset);

    // Search for identifiers on lines adjacent to the cursor.
    let start = Point::new(point.row.saturating_sub(identifier_line_count), 0);
    let end = Point::new(point.row + identifier_line_count + 1, 0).min(buffer.max_point());
    let line_range = start..end;
    let mut ranges = vec![line_range.to_offset(&buffer)];

    // Search for identifiers mentioned in headers/signatures of containing outline items.
    let outline_items = buffer.outline_items_as_offsets_containing(offset..offset, false, None);
    for item in outline_items {
        if let Some(body_range) = item.body_range(&buffer) {
            // Only the header portion of the item — everything before its body.
            ranges.push(item.range.start..body_range.start.to_offset(&buffer));
        } else {
            ranges.push(item.range.clone());
        }
    }

    // Sort by start (ties: longer range first) and merge overlapping or
    // touching ranges so each region is only scanned once below.
    ranges.sort_by(|a, b| a.start.cmp(&b.start).then(b.end.cmp(&a.end)));
    ranges.dedup_by(|a, b| {
        if a.start <= b.end {
            b.start = b.start.min(a.start);
            b.end = b.end.max(a.end);
            true
        } else {
            false
        }
    });

    let mut identifiers = Vec::new();
    let outer_range =
        ranges.first().map_or(0, |r| r.start)..ranges.last().map_or(buffer.len(), |r| r.end);

    // Run the highlight query once over the whole span; identifier nodes are
    // recognized via the grammar's identifier capture indices.
    let mut captures = buffer.captures(outer_range.clone(), |grammar| {
        grammar
            .highlights_config
            .as_ref()
            .map(|config| &config.query)
    });

    for range in ranges {
        // Re-aim the capture stream at this range's start; the `break` below
        // stops once captures pass the range's end.
        captures.set_byte_range(range.start..outer_range.end);

        let mut last_range = None;
        while let Some(capture) = captures.peek() {
            let node_range = capture.node.byte_range();
            if node_range.start > range.end {
                break;
            }
            let config = captures.grammars()[capture.grammar_index]
                .highlights_config
                .as_ref();

            // Keep captures that are identifiers fully inside the range,
            // aren't lowercase TSX tags, and don't repeat the previous node.
            if let Some(config) = config
                && config.identifier_capture_indices.contains(&capture.index)
                && range.contains_inclusive(&node_range)
                && !is_tsx_tag(&buffer, &capture.node)
                && Some(&node_range) != last_range.as_ref()
            {
                let name = buffer.text_for_range(node_range.clone()).collect();
                identifiers.push(Identifier {
                    range: buffer.anchor_after(node_range.start)
                        ..buffer.anchor_before(node_range.end),
                    name,
                });
                last_range = Some(node_range);
            }

            captures.advance();
        }
    }

    identifiers
}
690
/// Returns true when `node` is an all-lowercase JSX/TSX tag name (e.g. `div`),
/// so callers can skip fetching definitions for built-in tags. Capitalized tag
/// names (user-defined components) return false so their definitions are
/// still fetched.
fn is_tsx_tag(buffer: &BufferSnapshot, node: &tree_sitter::Node) -> bool {
    // Only languages with JSX tag configuration participate.
    let Some(language_config) = buffer
        .language()
        .and_then(|l| l.config().jsx_tag_auto_close.as_ref())
    else {
        return false;
    };
    let Some(parent_kind) = node.parent().map(|n| n.kind()) else {
        return false;
    };

    // If the parent node isn't any recognized tag node kind, this isn't a tag.
    // NOTE(review): the two `erroneous_*` clauses use opposite comparisons
    // (`!=` vs `==`) inside `is_some_and`, which looks inconsistent with the
    // surrounding "parent is none of the tag kinds" checks — verify intended.
    if parent_kind != &language_config.open_tag_node_name
        && parent_kind != &language_config.close_tag_node_name
        && parent_kind != &language_config.tag_name_node_name
        && language_config
            .erroneous_close_tag_name_node_name
            .as_ref()
            .is_some_and(|kind| parent_kind != kind)
        && language_config
            .erroneous_close_tag_node_name
            .as_ref()
            .is_some_and(|kind| parent_kind == kind)
        && parent_kind != &language_config.jsx_element_node_name
    {
        return false;
    }
    // do fetch `<Component />`, model probably understands `<div>`, but needs info for user defined components
    if !buffer
        .text_for_range(node.byte_range())
        .all(|str| str.chars().all(|c| c.is_lowercase()))
    {
        return false;
    }
    true
}