1use crate::assemble_excerpts::assemble_excerpts;
2use anyhow::Result;
3use collections::HashMap;
4use futures::{FutureExt, StreamExt as _, channel::mpsc, future};
5use gpui::{App, AppContext, AsyncApp, Context, Entity, EventEmitter, Task, WeakEntity};
6use language::{Anchor, Buffer, BufferSnapshot, OffsetRangeExt as _, Point, ToOffset as _};
7use project::{LocationLink, Project, ProjectPath};
8use smallvec::SmallVec;
9use std::{
10 collections::hash_map,
11 ops::Range,
12 path::Path,
13 sync::Arc,
14 time::{Duration, Instant},
15};
16use util::{RangeExt as _, ResultExt};
17
18mod assemble_excerpts;
19#[cfg(test)]
20mod edit_prediction_context_tests;
21mod excerpt;
22#[cfg(test)]
23mod fake_definition_lsp;
24
25pub use cloud_llm_client::predict_edits_v3::Line;
26pub use excerpt::{EditPredictionExcerpt, EditPredictionExcerptOptions, EditPredictionExcerptText};
27pub use zeta_prompt::{RelatedExcerpt, RelatedFile};
28
/// Default number of lines above and below the cursor to scan for identifiers
/// (see `identifiers_for_position`); overridable via `set_identifier_line_count`.
const IDENTIFIER_LINE_COUNT: u32 = 3;
30
/// Maintains a set of excerpts "related" to the current cursor position, found
/// by resolving language-server definitions for identifiers near the cursor.
pub struct RelatedExcerptStore {
    /// Weak handle so the store does not keep the project alive.
    project: WeakEntity<Project>,
    /// The most recently assembled related files, cheap to clone and share.
    related_files: Arc<[RelatedFile]>,
    /// Buffers paired positionally with `related_files` entries.
    related_file_buffers: Vec<Entity<Buffer>>,
    /// Definitions resolved for each identifier; rebuilt on every refresh, so
    /// entries for identifiers no longer near the cursor are evicted.
    cache: HashMap<Identifier, Arc<CacheEntry>>,
    /// Sends (buffer, position) refresh requests to the debouncing task
    /// spawned in `new`.
    update_tx: mpsc::UnboundedSender<(Entity<Buffer>, Anchor)>,
    /// Number of lines around the cursor scanned for identifiers.
    identifier_line_count: u32,
}
39
/// Events emitted while refreshing related excerpts.
pub enum RelatedExcerptStoreEvent {
    /// A refresh has begun (emitted before identifier extraction).
    StartedRefresh,
    /// A refresh completed; carries cache statistics for this refresh.
    FinishedRefresh {
        /// Identifiers whose definitions were already cached.
        cache_hit_count: usize,
        /// Identifiers that required a new go-to-definition request.
        cache_miss_count: usize,
        /// Mean latency across cache-miss definition requests
        /// (zero when there were no misses).
        mean_definition_latency: Duration,
        /// Slowest cache-miss definition request.
        max_definition_latency: Duration,
    },
}
49
/// An identifier occurrence in the source buffer, used as the cache key for
/// resolved definitions.
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
struct Identifier {
    /// The identifier's text.
    pub name: String,
    /// Anchored range of this occurrence in the source buffer.
    pub range: Range<Anchor>,
}
55
/// Either a cached definition lookup or an in-flight language-server request.
enum DefinitionTask {
    /// Definitions were already resolved for this identifier.
    CacheHit(Arc<CacheEntry>),
    /// A pending go-to-definition request for this identifier.
    CacheMiss(Task<Result<Option<Vec<LocationLink>>>>),
}
60
/// The resolved definitions for a single identifier.
#[derive(Debug)]
struct CacheEntry {
    // SmallVec: the common case is a single definition per identifier.
    definitions: SmallVec<[CachedDefinition; 1]>,
}
65
/// A single definition location retained in the cache.
#[derive(Clone, Debug)]
struct CachedDefinition {
    /// Project-relative path of the buffer containing the definition.
    path: ProjectPath,
    /// The buffer containing the definition.
    buffer: Entity<Buffer>,
    /// Anchored range of the definition target within `buffer`.
    anchor_range: Range<Anchor>,
}
72
/// Quiet period after the last `refresh` request before excerpts are fetched;
/// newer requests restart the timer and supersede older ones.
const DEBOUNCE_DURATION: Duration = Duration::from_millis(100);
74
// Lets observers subscribe to refresh lifecycle events from this store.
impl EventEmitter<RelatedExcerptStoreEvent> for RelatedExcerptStore {}
76
impl RelatedExcerptStore {
    /// Creates a new store and spawns a detached background task that
    /// debounces incoming `refresh` requests before fetching excerpts.
    pub fn new(project: &Entity<Project>, cx: &mut Context<Self>) -> Self {
        let (update_tx, mut update_rx) = mpsc::unbounded::<(Entity<Buffer>, Anchor)>();
        cx.spawn(async move |this, cx| {
            let executor = cx.background_executor().clone();
            while let Some((mut buffer, mut position)) = update_rx.next().await {
                // Debounce: every newer (buffer, position) restarts the timer
                // and replaces the pending request, so only the most recent
                // request is serviced once the channel goes quiet.
                let mut timer = executor.timer(DEBOUNCE_DURATION).fuse();
                loop {
                    futures::select_biased! {
                        next = update_rx.next() => {
                            if let Some((new_buffer, new_position)) = next {
                                buffer = new_buffer;
                                position = new_position;
                                timer = executor.timer(DEBOUNCE_DURATION).fuse();
                            } else {
                                // Channel closed: the store was dropped.
                                return anyhow::Ok(());
                            }
                        }
                        _ = timer => break,
                    }
                }

                Self::fetch_excerpts(this.clone(), buffer, position, cx).await?;
            }
            anyhow::Ok(())
        })
        .detach_and_log_err(cx);

        RelatedExcerptStore {
            project: project.downgrade(),
            update_tx,
            related_files: Vec::new().into(),
            related_file_buffers: Vec::new(),
            cache: Default::default(),
            identifier_line_count: IDENTIFIER_LINE_COUNT,
        }
    }

    /// Overrides how many lines around the cursor are scanned for identifiers.
    pub fn set_identifier_line_count(&mut self, count: u32) {
        self.identifier_line_count = count;
    }

    /// Requests a refresh of related excerpts for `position` in `buffer`.
    /// Requests are debounced; only the most recent one is serviced.
    /// Send errors are ignored: they only occur when the background task has
    /// already shut down.
    pub fn refresh(&mut self, buffer: Entity<Buffer>, position: Anchor, _: &mut Context<Self>) {
        self.update_tx.unbounded_send((buffer, position)).ok();
    }

    /// Returns the most recently computed related files (cheap `Arc` clone).
    pub fn related_files(&self) -> Arc<[RelatedFile]> {
        self.related_files.clone()
    }

    /// Returns the related files paired with the buffers they were derived
    /// from; the two collections are rebuilt together in matching order.
    pub fn related_files_with_buffers(
        &self,
    ) -> impl Iterator<Item = (RelatedFile, Entity<Buffer>)> {
        self.related_files
            .iter()
            .cloned()
            .zip(self.related_file_buffers.iter().cloned())
    }

    /// Replaces the related files directly, bypassing the refresh pipeline.
    /// NOTE(review): this does not update `related_file_buffers`, so
    /// `related_files_with_buffers` can pair these files with stale buffers —
    /// confirm callers only use this where buffers are irrelevant (e.g. tests).
    pub fn set_related_files(&mut self, files: Vec<RelatedFile>) {
        self.related_files = files.into();
    }

    /// Finds identifiers around `position`, resolves their definitions via the
    /// project (reusing cached results where possible), then rebuilds the
    /// related files from the collected definitions and emits
    /// `FinishedRefresh` with cache statistics.
    async fn fetch_excerpts(
        this: WeakEntity<Self>,
        buffer: Entity<Buffer>,
        position: Anchor,
        cx: &mut AsyncApp,
    ) -> Result<()> {
        let (project, snapshot, identifier_line_count) = this.read_with(cx, |this, cx| {
            (
                this.project.upgrade(),
                buffer.read(cx).snapshot(),
                this.identifier_line_count,
            )
        })?;
        let Some(project) = project else {
            // The project was dropped; nothing to refresh.
            return Ok(());
        };

        let file = snapshot.file().cloned();
        if let Some(file) = &file {
            log::debug!("retrieving_context buffer:{}", file.path().as_unix_str());
        }

        this.update(cx, |_, cx| {
            cx.emit(RelatedExcerptStoreEvent::StartedRefresh);
        })?;

        // Identifier extraction walks the syntax tree, so run it off the main thread.
        let identifiers = cx
            .background_spawn(async move {
                identifiers_for_position(&snapshot, position, identifier_line_count)
            })
            .await;

        let async_cx = cx.clone();
        let start_time = Instant::now();
        // For each identifier, either reuse the cached entry or spawn a
        // go-to-definition request; produce one future per identifier.
        let futures = this.update(cx, |this, cx| {
            identifiers
                .into_iter()
                .filter_map(|identifier| {
                    let task = if let Some(entry) = this.cache.get(&identifier) {
                        DefinitionTask::CacheHit(entry.clone())
                    } else {
                        DefinitionTask::CacheMiss(
                            this.project
                                .update(cx, |project, cx| {
                                    project.definitions(&buffer, identifier.range.start, cx)
                                })
                                .ok()?,
                        )
                    };

                    let cx = async_cx.clone();
                    let project = project.clone();
                    Some(async move {
                        match task {
                            DefinitionTask::CacheHit(cache_entry) => {
                                // `None` duration marks this as a cache hit below.
                                Some((identifier, cache_entry, None))
                            }
                            DefinitionTask::CacheMiss(task) => {
                                let locations = task.await.log_err()??;
                                let duration = start_time.elapsed();
                                Some(cx.update(|cx| {
                                    (
                                        identifier,
                                        Arc::new(CacheEntry {
                                            // Filter out unhelpful targets (no file,
                                            // single-file worktree, over-large range).
                                            definitions: locations
                                                .into_iter()
                                                .filter_map(|location| {
                                                    process_definition(location, &project, cx)
                                                })
                                                .collect(),
                                        }),
                                        Some(duration),
                                    )
                                }))
                            }
                        }
                    })
                })
                .collect::<Vec<_>>()
        })?;

        // Await all lookups and accumulate cache statistics while building the
        // replacement cache (old entries not re-requested are dropped).
        let mut cache_hit_count = 0;
        let mut cache_miss_count = 0;
        let mut mean_definition_latency = Duration::ZERO;
        let mut max_definition_latency = Duration::ZERO;
        let mut new_cache = HashMap::default();
        new_cache.reserve(futures.len());
        for (identifier, entry, duration) in future::join_all(futures).await.into_iter().flatten() {
            new_cache.insert(identifier, entry);
            if let Some(duration) = duration {
                cache_miss_count += 1;
                mean_definition_latency += duration;
                max_definition_latency = max_definition_latency.max(duration);
            } else {
                cache_hit_count += 1;
            }
        }
        // `.max(1)` avoids division by zero when everything was a cache hit.
        mean_definition_latency /= cache_miss_count.max(1) as u32;

        let (new_cache, related_files, related_file_buffers) =
            rebuild_related_files(&project, new_cache, cx).await?;

        if let Some(file) = &file {
            log::debug!(
                "finished retrieving context buffer:{}, latency:{:?}",
                file.path().as_unix_str(),
                start_time.elapsed()
            );
        }

        this.update(cx, |this, cx| {
            this.cache = new_cache;
            this.related_files = related_files.into();
            this.related_file_buffers = related_file_buffers;
            cx.emit(RelatedExcerptStoreEvent::FinishedRefresh {
                cache_hit_count,
                cache_miss_count,
                mean_definition_latency,
                max_definition_latency,
            });
        })?;

        anyhow::Ok(())
    }
}
265
/// Groups the cached definitions by buffer, snapshots each buffer after its
/// parsing settles, assembles excerpts for each buffer's definition ranges,
/// and returns the cache together with the resulting `RelatedFile`s and their
/// buffers (the two output vectors are parallel and sorted by path).
async fn rebuild_related_files(
    project: &Entity<Project>,
    new_entries: HashMap<Identifier, Arc<CacheEntry>>,
    cx: &mut AsyncApp,
) -> Result<(
    HashMap<Identifier, Arc<CacheEntry>>,
    Vec<RelatedFile>,
    Vec<Entity<Buffer>>,
)> {
    // Snapshot each referenced buffer exactly once, and resolve each
    // worktree's root name for display paths.
    let mut snapshots = HashMap::default();
    let mut worktree_root_names = HashMap::default();
    for entry in new_entries.values() {
        for definition in &entry.definitions {
            if let hash_map::Entry::Vacant(e) = snapshots.entry(definition.buffer.entity_id()) {
                // Wait for the buffer's parsing to go idle so the snapshot's
                // syntax tree is up to date before excerpt assembly.
                definition
                    .buffer
                    .read_with(cx, |buffer, _| buffer.parsing_idle())
                    .await;
                e.insert(
                    definition
                        .buffer
                        .read_with(cx, |buffer, _| buffer.snapshot()),
                );
            }
            let worktree_id = definition.path.worktree_id;
            if let hash_map::Entry::Vacant(e) =
                worktree_root_names.entry(definition.path.worktree_id)
            {
                // If the worktree is gone, no name is inserted and the buffer
                // is skipped later when the root name lookup fails.
                project.read_with(cx, |project, cx| {
                    if let Some(worktree) = project.worktree_for_id(worktree_id, cx) {
                        e.insert(worktree.read(cx).root_name().as_unix_str().to_string());
                    }
                });
            }
        }
    }

    // The remaining work is pure computation over snapshots; run it off the
    // main thread.
    Ok(cx
        .background_spawn(async move {
            let mut files = Vec::new();
            let mut ranges_by_buffer = HashMap::<_, Vec<Range<Point>>>::default();
            let mut paths_by_buffer = HashMap::default();
            // Collect the point range of every definition, grouped by buffer.
            for entry in new_entries.values() {
                for definition in &entry.definitions {
                    let Some(snapshot) = snapshots.get(&definition.buffer.entity_id()) else {
                        continue;
                    };
                    paths_by_buffer.insert(definition.buffer.entity_id(), definition.path.clone());
                    ranges_by_buffer
                        .entry(definition.buffer.clone())
                        .or_default()
                        .push(definition.anchor_range.to_point(snapshot));
                }
            }

            for (buffer, ranges) in ranges_by_buffer {
                let Some(snapshot) = snapshots.get(&buffer.entity_id()) else {
                    continue;
                };
                let Some(project_path) = paths_by_buffer.get(&buffer.entity_id()) else {
                    continue;
                };
                let excerpts = assemble_excerpts(snapshot, ranges);
                // Skip buffers whose worktree root name couldn't be resolved.
                let Some(root_name) = worktree_root_names.get(&project_path.worktree_id) else {
                    continue;
                };

                // Present the path as "<worktree root>/<relative path>".
                let path = Path::new(&format!(
                    "{}/{}",
                    root_name,
                    project_path.path.as_unix_str()
                ))
                .into();

                files.push((
                    buffer,
                    RelatedFile {
                        path,
                        excerpts,
                        max_row: snapshot.max_point().row,
                    },
                ));
            }

            // Sort by path so the output ordering is deterministic.
            files.sort_by_key(|(_, file)| file.path.clone());
            let (related_buffers, related_files) = files.into_iter().unzip();

            (new_entries, related_files, related_buffers)
        })
        .await)
}
357
/// Maximum byte length for a definition's target range; larger targets likely
/// cover an entire module rather than a single symbol and are discarded.
const MAX_TARGET_LEN: usize = 128;
359
360fn process_definition(
361 location: LocationLink,
362 project: &Entity<Project>,
363 cx: &mut App,
364) -> Option<CachedDefinition> {
365 let buffer = location.target.buffer.read(cx);
366 let anchor_range = location.target.range;
367 let file = buffer.file()?;
368 let worktree = project.read(cx).worktree_for_id(file.worktree_id(cx), cx)?;
369 if worktree.read(cx).is_single_file() {
370 return None;
371 }
372
373 // If the target range is large, it likely means we requested the definition of an entire module.
374 // For individual definitions, the target range should be small as it only covers the symbol.
375 let buffer = location.target.buffer.read(cx);
376 let target_len = anchor_range.to_offset(&buffer).len();
377 if target_len > MAX_TARGET_LEN {
378 return None;
379 }
380
381 Some(CachedDefinition {
382 path: ProjectPath {
383 worktree_id: file.worktree_id(cx),
384 path: file.path().clone(),
385 },
386 buffer: location.target.buffer,
387 anchor_range,
388 })
389}
390
/// Gets all of the identifiers that are present in the given line, and its containing
/// outline items.
///
/// Identifiers are found by running each grammar's highlight query over the
/// relevant byte ranges and keeping captures the grammar marks as identifiers.
/// Returned ranges are anchored into the buffer.
fn identifiers_for_position(
    buffer: &BufferSnapshot,
    position: Anchor,
    identifier_line_count: u32,
) -> Vec<Identifier> {
    let offset = position.to_offset(buffer);
    let point = buffer.offset_to_point(offset);

    // Search for identifiers on lines adjacent to the cursor.
    let start = Point::new(point.row.saturating_sub(identifier_line_count), 0);
    let end = Point::new(point.row + identifier_line_count + 1, 0).min(buffer.max_point());
    let line_range = start..end;
    let mut ranges = vec![line_range.to_offset(&buffer)];

    // Search for identifiers mentioned in headers/signatures of containing outline items.
    let outline_items = buffer.outline_items_as_offsets_containing(offset..offset, false, None);
    for item in outline_items {
        if let Some(body_range) = item.body_range(&buffer) {
            // Take only the item's header: from its start to where the body begins.
            ranges.push(item.range.start..body_range.start.to_offset(&buffer));
        } else {
            ranges.push(item.range.clone());
        }
    }

    // Sort by start ascending, then by end descending (containing ranges sort
    // first), then merge overlapping or touching ranges in place.
    ranges.sort_by(|a, b| a.start.cmp(&b.start).then(b.end.cmp(&a.end)));
    ranges.dedup_by(|a, b| {
        // In `dedup_by`, `a` is the later element and `b` the retained earlier
        // one; returning true drops `a` after folding it into `b`.
        if a.start <= b.end {
            b.start = b.start.min(a.start);
            b.end = b.end.max(a.end);
            true
        } else {
            false
        }
    });

    let mut identifiers = Vec::new();
    // Overall span covering all merged ranges; one capture iterator is created
    // for it and reused across the per-range scans below.
    let outer_range =
        ranges.first().map_or(0, |r| r.start)..ranges.last().map_or(buffer.len(), |r| r.end);

    let mut captures = buffer
        .syntax
        .captures(outer_range.clone(), &buffer.text, |grammar| {
            grammar
                .highlights_config
                .as_ref()
                .map(|config| &config.query)
        });

    for range in ranges {
        // Advance the shared cursor to this range's start while keeping the
        // overall end, so the iterator can continue into later ranges.
        captures.set_byte_range(range.start..outer_range.end);

        let mut last_range = None;
        while let Some(capture) = captures.peek() {
            let node_range = capture.node.byte_range();
            if node_range.start > range.end {
                // Past the current range; handle remaining captures in the
                // next iteration of the outer loop.
                break;
            }
            let config = captures.grammars()[capture.grammar_index]
                .highlights_config
                .as_ref();

            // Keep captures that the grammar classifies as identifiers, lie
            // fully inside the current range, and aren't a repeat of the node
            // just recorded (a node can match multiple query patterns).
            if let Some(config) = config
                && config.identifier_capture_indices.contains(&capture.index)
                && range.contains_inclusive(&node_range)
                && Some(&node_range) != last_range.as_ref()
            {
                let name = buffer.text_for_range(node_range.clone()).collect();
                identifiers.push(Identifier {
                    range: buffer.anchor_after(node_range.start)
                        ..buffer.anchor_before(node_range.end),
                    name,
                });
                last_range = Some(node_range);
            }

            captures.advance();
        }
    }

    identifiers
}