1use crate::assemble_excerpts::assemble_excerpt_ranges;
2use anyhow::Result;
3use collections::HashMap;
4use futures::{FutureExt, StreamExt as _, channel::mpsc, future};
5use gpui::{App, AppContext, AsyncApp, Context, Entity, EntityId, EventEmitter, Task, WeakEntity};
6use language::{Anchor, Buffer, BufferSnapshot, OffsetRangeExt as _, Point, ToOffset as _};
7use project::{LocationLink, Project, ProjectPath};
8use smallvec::SmallVec;
9use std::{
10 collections::hash_map,
11 ops::Range,
12 path::Path,
13 sync::Arc,
14 time::{Duration, Instant},
15};
16use util::paths::PathStyle;
17use util::rel_path::RelPath;
18use util::{RangeExt as _, ResultExt};
19
20mod assemble_excerpts;
21#[cfg(test)]
22mod edit_prediction_context_tests;
23mod excerpt;
24#[cfg(test)]
25mod fake_definition_lsp;
26
27pub use cloud_llm_client::predict_edits_v3::Line;
28pub use excerpt::{EditPredictionExcerpt, EditPredictionExcerptOptions, EditPredictionExcerptText};
29pub use zeta_prompt::{RelatedExcerpt, RelatedFile};
30
/// How many lines above and below the cursor are scanned for identifiers
/// when refreshing related excerpts (see `identifiers_for_position`).
const IDENTIFIER_LINE_COUNT: u32 = 3;

/// Maintains excerpts of files related to the user's cursor position by
/// resolving definitions of nearby identifiers, caching lookups per
/// identifier across refreshes.
pub struct RelatedExcerptStore {
    // Weak handle so the store doesn't keep the project alive.
    project: WeakEntity<Project>,
    // Buffers (with excerpt ranges) currently considered related.
    related_buffers: Vec<RelatedBuffer>,
    // Definition-lookup results keyed by identifier occurrence.
    cache: HashMap<Identifier, Arc<CacheEntry>>,
    // Feeds the debounced refresh task spawned in `new`.
    update_tx: mpsc::UnboundedSender<(Entity<Buffer>, Anchor)>,
    // Runtime-tunable copy of `IDENTIFIER_LINE_COUNT`.
    identifier_line_count: u32,
}
40
/// A buffer contributing related excerpts, with the anchored excerpt ranges
/// and a cache of the extracted excerpt text.
struct RelatedBuffer {
    buffer: Entity<Buffer>,
    // Path including the worktree root name as its first component.
    path: Arc<Path>,
    // Excerpt ranges kept as anchors so they survive buffer edits.
    anchor_ranges: Vec<Range<Anchor>>,
    // Extracted excerpt text; stale once the buffer version advances.
    cached_file: Option<CachedRelatedFile>,
}

/// Excerpt text extracted at a specific buffer version.
struct CachedRelatedFile {
    excerpts: Vec<RelatedExcerpt>,
    // The buffer version the excerpts were extracted from; used to detect
    // staleness in `RelatedBuffer::related_file`.
    buffer_version: clock::Global,
}
52
/// Events emitted by [`RelatedExcerptStore`] while refreshing excerpts.
pub enum RelatedExcerptStoreEvent {
    /// A refresh began for a cursor position.
    StartedRefresh,
    /// A refresh finished, with statistics about the definition lookups.
    FinishedRefresh {
        cache_hit_count: usize,
        cache_miss_count: usize,
        mean_definition_latency: Duration,
        max_definition_latency: Duration,
    },
}

/// An identifier occurrence used as a cache key: its text plus the anchored
/// range where it appears in the source buffer.
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
struct Identifier {
    pub name: String,
    pub range: Range<Anchor>,
}
68
/// Outcome of consulting the cache for one identifier: either a previously
/// resolved entry, or a pending definition request to the language server.
enum DefinitionTask {
    CacheHit(Arc<CacheEntry>),
    CacheMiss(Task<Result<Option<Vec<LocationLink>>>>),
}

/// The resolved definitions for one identifier.
#[derive(Debug)]
struct CacheEntry {
    // Most identifiers resolve to a single definition, so one fits inline.
    definitions: SmallVec<[CachedDefinition; 1]>,
}

/// One resolved definition location, retained with its buffer and project path.
#[derive(Clone, Debug)]
struct CachedDefinition {
    path: ProjectPath,
    buffer: Entity<Buffer>,
    anchor_range: Range<Anchor>,
}

/// Quiet period the refresh task waits for before acting on the most recent
/// refresh request.
const DEBOUNCE_DURATION: Duration = Duration::from_millis(100);

impl EventEmitter<RelatedExcerptStoreEvent> for RelatedExcerptStore {}
89
impl RelatedExcerptStore {
    /// Creates the store and spawns a long-lived task that debounces
    /// `(buffer, position)` refresh requests: each request waits for
    /// `DEBOUNCE_DURATION` of channel quiet before excerpts are fetched,
    /// and newer requests supersede older ones.
    pub fn new(project: &Entity<Project>, cx: &mut Context<Self>) -> Self {
        let (update_tx, mut update_rx) = mpsc::unbounded::<(Entity<Buffer>, Anchor)>();
        cx.spawn(async move |this, cx| {
            let executor = cx.background_executor().clone();
            while let Some((mut buffer, mut position)) = update_rx.next().await {
                let mut timer = executor.timer(DEBOUNCE_DURATION).fuse();
                loop {
                    futures::select_biased! {
                        next = update_rx.next() => {
                            if let Some((new_buffer, new_position)) = next {
                                // A newer request arrived: adopt it and
                                // restart the debounce window.
                                buffer = new_buffer;
                                position = new_position;
                                timer = executor.timer(DEBOUNCE_DURATION).fuse();
                            } else {
                                // Channel closed — the store was dropped.
                                return anyhow::Ok(());
                            }
                        }
                        _ = timer => break,
                    }
                }

                Self::fetch_excerpts(this.clone(), buffer, position, cx).await?;
            }
            anyhow::Ok(())
        })
        .detach_and_log_err(cx);

        RelatedExcerptStore {
            project: project.downgrade(),
            update_tx,
            related_buffers: Vec::new(),
            cache: Default::default(),
            identifier_line_count: IDENTIFIER_LINE_COUNT,
        }
    }

    /// Overrides how many lines around the cursor are scanned for identifiers.
    pub fn set_identifier_line_count(&mut self, count: u32) {
        self.identifier_line_count = count;
    }

    /// Requests a refresh for `position` in `buffer`. The work is debounced
    /// by the task spawned in `new`; errors sending are ignored (the task
    /// only stops when the store is dropped).
    pub fn refresh(&mut self, buffer: Entity<Buffer>, position: Anchor, _: &mut Context<Self>) {
        self.update_tx.unbounded_send((buffer, position)).ok();
    }

    /// Returns a `RelatedFile` per tracked buffer, re-extracting excerpt text
    /// for any buffer that changed since its cache was filled.
    pub fn related_files(&mut self, cx: &App) -> Vec<RelatedFile> {
        self.related_buffers
            .iter_mut()
            .map(|related| related.related_file(cx))
            .collect()
    }

    /// Like `related_files`, but pairs each file with its buffer handle.
    pub fn related_files_with_buffers(&mut self, cx: &App) -> Vec<(RelatedFile, Entity<Buffer>)> {
        self.related_buffers
            .iter_mut()
            .map(|related| (related.related_file(cx), related.buffer.clone()))
            .collect::<Vec<_>>()
    }

    /// Replaces the tracked buffers from externally supplied `RelatedFile`s.
    /// Files whose path doesn't match an open buffer in some worktree are
    /// silently dropped.
    pub fn set_related_files(&mut self, files: Vec<RelatedFile>, cx: &App) {
        self.related_buffers = files
            .into_iter()
            .filter_map(|file| {
                let project = self.project.upgrade()?;
                let project = project.read(cx);
                // The first path component is expected to be a worktree root
                // name (see how paths are built in `rebuild_related_files`).
                let worktree = project.worktrees(cx).find(|wt| {
                    let root_name = wt.read(cx).root_name().as_unix_str();
                    file.path
                        .components()
                        .next()
                        .is_some_and(|c| c.as_os_str() == root_name)
                })?;
                let worktree = worktree.read(cx);
                let relative_path = file
                    .path
                    .strip_prefix(worktree.root_name().as_unix_str())
                    .ok()?;
                let relative_path = RelPath::new(relative_path, PathStyle::Posix).ok()?;
                let project_path = ProjectPath {
                    worktree_id: worktree.id(),
                    path: relative_path.into_owned().into(),
                };
                let buffer = project.get_open_buffer(&project_path, cx)?;
                let snapshot = buffer.read(cx).snapshot();
                // Convert row ranges back into anchors; the end row is
                // treated as inclusive, anchored at its last column.
                let anchor_ranges = file
                    .excerpts
                    .iter()
                    .map(|excerpt| {
                        let start = snapshot.anchor_before(Point::new(excerpt.row_range.start, 0));
                        let end_row = excerpt.row_range.end;
                        let end_col = snapshot.line_len(end_row);
                        let end = snapshot.anchor_after(Point::new(end_row, end_col));
                        start..end
                    })
                    .collect();
                Some(RelatedBuffer {
                    buffer,
                    path: file.path.clone(),
                    anchor_ranges,
                    cached_file: None,
                })
            })
            .collect();
    }

    /// Core refresh: gathers identifiers around `position`, resolves each via
    /// the cache or a fresh definition request, rebuilds the related-file
    /// list, and emits `StartedRefresh` / `FinishedRefresh` events.
    async fn fetch_excerpts(
        this: WeakEntity<Self>,
        buffer: Entity<Buffer>,
        position: Anchor,
        cx: &mut AsyncApp,
    ) -> Result<()> {
        let (project, snapshot, identifier_line_count) = this.read_with(cx, |this, cx| {
            (
                this.project.upgrade(),
                buffer.read(cx).snapshot(),
                this.identifier_line_count,
            )
        })?;
        let Some(project) = project else {
            // Project was dropped; nothing to refresh.
            return Ok(());
        };

        let file = snapshot.file().cloned();
        if let Some(file) = &file {
            log::debug!("retrieving_context buffer:{}", file.path().as_unix_str());
        }

        this.update(cx, |_, cx| {
            cx.emit(RelatedExcerptStoreEvent::StartedRefresh);
        })?;

        // Identifier extraction walks the syntax tree; do it off the main thread.
        let identifiers = cx
            .background_spawn(async move {
                identifiers_for_position(&snapshot, position, identifier_line_count)
            })
            .await;

        let async_cx = cx.clone();
        let start_time = Instant::now();
        // For each identifier, either reuse the cached entry or kick off a
        // definition request; each becomes a future yielding
        // (identifier, entry, latency-if-cache-miss).
        let futures = this.update(cx, |this, cx| {
            identifiers
                .into_iter()
                .filter_map(|identifier| {
                    let task = if let Some(entry) = this.cache.get(&identifier) {
                        DefinitionTask::CacheHit(entry.clone())
                    } else {
                        DefinitionTask::CacheMiss(
                            this.project
                                .update(cx, |project, cx| {
                                    project.definitions(&buffer, identifier.range.start, cx)
                                })
                                .ok()?,
                        )
                    };

                    let cx = async_cx.clone();
                    let project = project.clone();
                    Some(async move {
                        match task {
                            DefinitionTask::CacheHit(cache_entry) => {
                                Some((identifier, cache_entry, None))
                            }
                            DefinitionTask::CacheMiss(task) => {
                                let locations = task.await.log_err()??;
                                let duration = start_time.elapsed();
                                Some(cx.update(|cx| {
                                    (
                                        identifier,
                                        Arc::new(CacheEntry {
                                            definitions: locations
                                                .into_iter()
                                                .filter_map(|location| {
                                                    process_definition(location, &project, cx)
                                                })
                                                .collect(),
                                        }),
                                        Some(duration),
                                    )
                                }))
                            }
                        }
                    })
                })
                .collect::<Vec<_>>()
        })?;

        // Await all lookups, accumulating cache statistics. Identifiers whose
        // lookup failed are dropped from the new cache.
        let mut cache_hit_count = 0;
        let mut cache_miss_count = 0;
        let mut mean_definition_latency = Duration::ZERO;
        let mut max_definition_latency = Duration::ZERO;
        let mut new_cache = HashMap::default();
        new_cache.reserve(futures.len());
        for (identifier, entry, duration) in future::join_all(futures).await.into_iter().flatten() {
            new_cache.insert(identifier, entry);
            if let Some(duration) = duration {
                cache_miss_count += 1;
                mean_definition_latency += duration;
                max_definition_latency = max_definition_latency.max(duration);
            } else {
                cache_hit_count += 1;
            }
        }
        // max(1) avoids division by zero when every lookup was a cache hit.
        mean_definition_latency /= cache_miss_count.max(1) as u32;

        let (new_cache, related_buffers) = rebuild_related_files(&project, new_cache, cx).await?;

        if let Some(file) = &file {
            log::debug!(
                "finished retrieving context buffer:{}, latency:{:?}",
                file.path().as_unix_str(),
                start_time.elapsed()
            );
        }

        this.update(cx, |this, cx| {
            this.cache = new_cache;
            this.related_buffers = related_buffers;
            cx.emit(RelatedExcerptStoreEvent::FinishedRefresh {
                cache_hit_count,
                cache_miss_count,
                mean_definition_latency,
                max_definition_latency,
            });
        })?;

        anyhow::Ok(())
    }
}
318
/// Groups the cached definitions by buffer, merges their ranges into excerpt
/// row ranges on the background executor, and returns the cache along with
/// the resulting `RelatedBuffer`s sorted by path.
async fn rebuild_related_files(
    project: &Entity<Project>,
    mut new_entries: HashMap<Identifier, Arc<CacheEntry>>,
    cx: &mut AsyncApp,
) -> Result<(HashMap<Identifier, Arc<CacheEntry>>, Vec<RelatedBuffer>)> {
    // First pass (needs app access): snapshot each distinct buffer once —
    // waiting for parsing to settle so syntax data is current — and record
    // each referenced worktree's root name.
    let mut snapshots = HashMap::default();
    let mut worktree_root_names = HashMap::default();
    for entry in new_entries.values() {
        for definition in &entry.definitions {
            if let hash_map::Entry::Vacant(e) = snapshots.entry(definition.buffer.entity_id()) {
                definition
                    .buffer
                    .read_with(cx, |buffer, _| buffer.parsing_idle())
                    .await;
                e.insert(
                    definition
                        .buffer
                        .read_with(cx, |buffer, _| buffer.snapshot()),
                );
            }
            let worktree_id = definition.path.worktree_id;
            if let hash_map::Entry::Vacant(e) =
                worktree_root_names.entry(definition.path.worktree_id)
            {
                project.read_with(cx, |project, cx| {
                    if let Some(worktree) = project.worktree_for_id(worktree_id, cx) {
                        e.insert(worktree.read(cx).root_name().as_unix_str().to_string());
                    }
                });
            }
        }
    }

    // Second pass runs entirely on the background executor using the
    // snapshots captured above.
    Ok(cx
        .background_spawn(async move {
            let mut ranges_by_buffer =
                HashMap::<EntityId, (Entity<Buffer>, Vec<Range<Point>>)>::default();
            let mut paths_by_buffer = HashMap::default();
            for entry in new_entries.values_mut() {
                for definition in &entry.definitions {
                    let Some(snapshot) = snapshots.get(&definition.buffer.entity_id()) else {
                        continue;
                    };
                    paths_by_buffer.insert(definition.buffer.entity_id(), definition.path.clone());

                    ranges_by_buffer
                        .entry(definition.buffer.entity_id())
                        .or_insert_with(|| (definition.buffer.clone(), Vec::new()))
                        .1
                        .push(definition.anchor_range.to_point(snapshot));
                }
            }

            let mut related_buffers: Vec<RelatedBuffer> = ranges_by_buffer
                .into_iter()
                .filter_map(|(entity_id, (buffer, ranges))| {
                    let snapshot = snapshots.get(&entity_id)?;
                    let project_path = paths_by_buffer.get(&entity_id)?;
                    // Merge overlapping/adjacent definition ranges into
                    // excerpt row ranges.
                    let row_ranges = assemble_excerpt_ranges(snapshot, ranges);
                    let root_name = worktree_root_names.get(&project_path.worktree_id)?;

                    // Prefix the worktree root name so the path round-trips
                    // through `set_related_files`.
                    let path: Arc<Path> = Path::new(&format!(
                        "{}/{}",
                        root_name,
                        project_path.path.as_unix_str()
                    ))
                    .into();

                    // End row is treated as inclusive: anchor after its last column.
                    let anchor_ranges = row_ranges
                        .into_iter()
                        .map(|row_range| {
                            let start = snapshot.anchor_before(Point::new(row_range.start, 0));
                            let end_col = snapshot.line_len(row_range.end);
                            let end = snapshot.anchor_after(Point::new(row_range.end, end_col));
                            start..end
                        })
                        .collect();

                    let mut related_buffer = RelatedBuffer {
                        buffer,
                        path,
                        anchor_ranges,
                        cached_file: None,
                    };
                    // Extract excerpt text now, while we hold the snapshot.
                    related_buffer.fill_cache(snapshot);
                    Some(related_buffer)
                })
                .collect();

            // Stable output order for consumers (and for tests).
            related_buffers.sort_by_key(|related| related.path.clone());

            (new_entries, related_buffers)
        })
        .await)
}
414
415impl RelatedBuffer {
416 fn related_file(&mut self, cx: &App) -> RelatedFile {
417 let buffer = self.buffer.read(cx);
418 let path = self.path.clone();
419 let cached = if let Some(cached) = &self.cached_file
420 && buffer.version() == cached.buffer_version
421 {
422 cached
423 } else {
424 self.fill_cache(buffer)
425 };
426 let related_file = RelatedFile {
427 path,
428 excerpts: cached.excerpts.clone(),
429 max_row: buffer.max_point().row,
430 };
431 return related_file;
432 }
433
434 fn fill_cache(&mut self, buffer: &text::BufferSnapshot) -> &CachedRelatedFile {
435 let excerpts = self
436 .anchor_ranges
437 .iter()
438 .map(|range| {
439 let start = range.start.to_point(buffer);
440 let end = range.end.to_point(buffer);
441 RelatedExcerpt {
442 row_range: start.row..end.row,
443 text: buffer.text_for_range(start..end).collect::<String>().into(),
444 }
445 })
446 .collect::<Vec<_>>();
447 self.cached_file = Some(CachedRelatedFile {
448 excerpts: excerpts,
449 buffer_version: buffer.version().clone(),
450 });
451 self.cached_file.as_ref().unwrap()
452 }
453}
454
455use language::ToPoint as _;
456
457const MAX_TARGET_LEN: usize = 128;
458
459fn process_definition(
460 location: LocationLink,
461 project: &Entity<Project>,
462 cx: &mut App,
463) -> Option<CachedDefinition> {
464 let buffer = location.target.buffer.read(cx);
465 let anchor_range = location.target.range;
466 let file = buffer.file()?;
467 let worktree = project.read(cx).worktree_for_id(file.worktree_id(cx), cx)?;
468 if worktree.read(cx).is_single_file() {
469 return None;
470 }
471
472 // If the target range is large, it likely means we requested the definition of an entire module.
473 // For individual definitions, the target range should be small as it only covers the symbol.
474 let buffer = location.target.buffer.read(cx);
475 let target_len = anchor_range.to_offset(&buffer).len();
476 if target_len > MAX_TARGET_LEN {
477 return None;
478 }
479
480 Some(CachedDefinition {
481 path: ProjectPath {
482 worktree_id: file.worktree_id(cx),
483 path: file.path().clone(),
484 },
485 buffer: location.target.buffer,
486 anchor_range,
487 })
488}
489
/// Gets all of the identifiers that are present in the given line, and its containing
/// outline items.
///
/// Scans `identifier_line_count` lines on each side of the cursor plus the
/// header/signature spans of enclosing outline items, then walks the syntax
/// highlight captures over those (merged, deduplicated) ranges collecting
/// every capture marked as an identifier.
fn identifiers_for_position(
    buffer: &BufferSnapshot,
    position: Anchor,
    identifier_line_count: u32,
) -> Vec<Identifier> {
    let offset = position.to_offset(buffer);
    let point = buffer.offset_to_point(offset);

    // Search for identifiers on lines adjacent to the cursor.
    let start = Point::new(point.row.saturating_sub(identifier_line_count), 0);
    let end = Point::new(point.row + identifier_line_count + 1, 0).min(buffer.max_point());
    let line_range = start..end;
    let mut ranges = vec![line_range.to_offset(&buffer)];

    // Search for identifiers mentioned in headers/signatures of containing outline items.
    let outline_items = buffer.outline_items_as_offsets_containing(offset..offset, false, None);
    for item in outline_items {
        if let Some(body_range) = item.body_range(&buffer) {
            // Only the header (everything before the body) is scanned.
            ranges.push(item.range.start..body_range.start.to_offset(&buffer));
        } else {
            ranges.push(item.range.clone());
        }
    }

    // Sort by start (ties broken by larger end first), then merge
    // overlapping/adjacent ranges in place.
    ranges.sort_by(|a, b| a.start.cmp(&b.start).then(b.end.cmp(&a.end)));
    ranges.dedup_by(|a, b| {
        if a.start <= b.end {
            b.start = b.start.min(a.start);
            b.end = b.end.max(a.end);
            true
        } else {
            false
        }
    });

    let mut identifiers = Vec::new();
    // Span from the first range's start to the last range's end; captures
    // are created once over this outer span and re-windowed per range.
    let outer_range =
        ranges.first().map_or(0, |r| r.start)..ranges.last().map_or(buffer.len(), |r| r.end);

    let mut captures = buffer
        .syntax
        .captures(outer_range.clone(), &buffer.text, |grammar| {
            grammar
                .highlights_config
                .as_ref()
                .map(|config| &config.query)
        });

    for range in ranges {
        // Advance the capture iterator's window to this range; keep the
        // outer end so later ranges can reuse the same iterator.
        captures.set_byte_range(range.start..outer_range.end);

        let mut last_range = None;
        while let Some(capture) = captures.peek() {
            let node_range = capture.node.byte_range();
            if node_range.start > range.end {
                break;
            }
            let config = captures.grammars()[capture.grammar_index]
                .highlights_config
                .as_ref();

            // Keep captures flagged as identifiers that lie fully within the
            // range, skipping consecutive duplicates of the same node.
            if let Some(config) = config
                && config.identifier_capture_indices.contains(&capture.index)
                && range.contains_inclusive(&node_range)
                && Some(&node_range) != last_range.as_ref()
            {
                let name = buffer.text_for_range(node_range.clone()).collect();
                identifiers.push(Identifier {
                    range: buffer.anchor_after(node_range.start)
                        ..buffer.anchor_before(node_range.end),
                    name,
                });
                last_range = Some(node_range);
            }

            captures.advance();
        }
    }

    identifiers
}