1use crate::assemble_excerpts::assemble_excerpts;
2use anyhow::Result;
3use collections::HashMap;
4use futures::{FutureExt, StreamExt as _, channel::mpsc, future};
5use gpui::{App, AppContext, AsyncApp, Context, Entity, EventEmitter, Task, WeakEntity};
6use language::{Anchor, Buffer, BufferSnapshot, OffsetRangeExt as _, Point, ToOffset as _};
7use project::{LocationLink, Project, ProjectPath};
8use smallvec::SmallVec;
9use std::{
10 collections::hash_map,
11 ops::Range,
12 path::Path,
13 sync::Arc,
14 time::{Duration, Instant},
15};
16use util::{RangeExt as _, ResultExt};
17
18mod assemble_excerpts;
19#[cfg(test)]
20mod edit_prediction_context_tests;
21mod excerpt;
22#[cfg(test)]
23mod fake_definition_lsp;
24
25pub use cloud_llm_client::predict_edits_v3::Line;
26pub use excerpt::{EditPredictionExcerpt, EditPredictionExcerptOptions, EditPredictionExcerptText};
27pub use zeta_prompt::{RelatedExcerpt, RelatedFile};
28
/// Default number of lines above and below the cursor that are scanned for identifiers.
const IDENTIFIER_LINE_COUNT: u32 = 3;
30
/// Caches LSP definition lookups for identifiers near the cursor and assembles
/// the results into [`RelatedFile`] excerpts for edit prediction.
pub struct RelatedExcerptStore {
    project: WeakEntity<Project>,
    /// Most recently assembled related files, sorted by path.
    related_files: Arc<[RelatedFile]>,
    /// Buffers paired positionally with `related_files`.
    related_file_buffers: Vec<Entity<Buffer>>,
    /// Definition results keyed by identifier; replaced wholesale on each refresh.
    cache: HashMap<Identifier, Arc<CacheEntry>>,
    /// Feeds (buffer, cursor position) pairs into the debounced refresh loop.
    update_tx: mpsc::UnboundedSender<(Entity<Buffer>, Anchor)>,
    /// Number of lines around the cursor searched for identifiers.
    identifier_line_count: u32,
}
39
/// Events emitted while the store refreshes its related files.
pub enum RelatedExcerptStoreEvent {
    /// A refresh has begun for a new cursor position.
    StartedRefresh,
    /// A refresh finished, with statistics about the definition lookups it performed.
    FinishedRefresh {
        cache_hit_count: usize,
        cache_miss_count: usize,
        /// Mean latency of the definition requests that missed the cache
        /// (zero when there were no misses).
        mean_definition_latency: Duration,
        /// Maximum latency among the definition requests that missed the cache.
        max_definition_latency: Duration,
    },
}
49
/// A named symbol occurrence in a buffer; used as the key for the definition cache.
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
struct Identifier {
    pub name: String,
    /// Anchor range covering the identifier's text in the source buffer.
    pub range: Range<Anchor>,
}
55
/// Either a previously cached definition result or an in-flight LSP request.
enum DefinitionTask {
    CacheHit(Arc<CacheEntry>),
    CacheMiss(Task<Result<Option<Vec<LocationLink>>>>),
}
60
/// Cached definition locations for a single identifier.
#[derive(Debug)]
struct CacheEntry {
    // Inline size of 1: most identifiers resolve to a single definition.
    definitions: SmallVec<[CachedDefinition; 1]>,
}
65
/// The location of one resolved definition.
#[derive(Clone, Debug)]
struct CachedDefinition {
    path: ProjectPath,
    buffer: Entity<Buffer>,
    /// Anchor range of the definition target within `buffer`.
    anchor_range: Range<Anchor>,
}
72
/// How long to wait after the last refresh request before actually fetching definitions.
const DEBOUNCE_DURATION: Duration = Duration::from_millis(100);
74
// Marker impl allowing the store to emit refresh events via `cx.emit`.
impl EventEmitter<RelatedExcerptStoreEvent> for RelatedExcerptStore {}
76
impl RelatedExcerptStore {
    /// Creates a store bound to `project` and spawns a background task that
    /// debounces refresh requests before fetching definitions.
    pub fn new(project: &Entity<Project>, cx: &mut Context<Self>) -> Self {
        let (update_tx, mut update_rx) = mpsc::unbounded::<(Entity<Buffer>, Anchor)>();
        cx.spawn(async move |this, cx| {
            let executor = cx.background_executor().clone();
            while let Some((mut buffer, mut position)) = update_rx.next().await {
                // Debounce: restart the timer whenever a newer request arrives,
                // acting only on the most recent (buffer, position) pair.
                let mut timer = executor.timer(DEBOUNCE_DURATION).fuse();
                loop {
                    futures::select_biased! {
                        next = update_rx.next() => {
                            if let Some((new_buffer, new_position)) = next {
                                buffer = new_buffer;
                                position = new_position;
                                timer = executor.timer(DEBOUNCE_DURATION).fuse();
                            } else {
                                // Channel closed: all senders (and the store) are gone.
                                return anyhow::Ok(());
                            }
                        }
                        _ = timer => break,
                    }
                }

                Self::fetch_excerpts(this.clone(), buffer, position, cx).await?;
            }
            anyhow::Ok(())
        })
        .detach_and_log_err(cx);

        RelatedExcerptStore {
            project: project.downgrade(),
            update_tx,
            related_files: Vec::new().into(),
            related_file_buffers: Vec::new(),
            cache: Default::default(),
            identifier_line_count: IDENTIFIER_LINE_COUNT,
        }
    }

    /// Overrides how many lines around the cursor are scanned for identifiers.
    pub fn set_identifier_line_count(&mut self, count: u32) {
        self.identifier_line_count = count;
    }

    /// Requests a (debounced) refresh for the given buffer and cursor position.
    pub fn refresh(&mut self, buffer: Entity<Buffer>, position: Anchor, _: &mut Context<Self>) {
        // Send errors just mean the background task has shut down; ignore them.
        self.update_tx.unbounded_send((buffer, position)).ok();
    }

    /// Returns the most recently assembled related files.
    pub fn related_files(&self) -> Arc<[RelatedFile]> {
        self.related_files.clone()
    }

    /// Returns each related file paired with the buffer it was derived from.
    pub fn related_files_with_buffers(
        &self,
    ) -> impl Iterator<Item = (RelatedFile, Entity<Buffer>)> {
        self.related_files
            .iter()
            .cloned()
            .zip(self.related_file_buffers.iter().cloned())
    }

    /// Replaces the related files directly.
    ///
    /// NOTE(review): this does not update `related_file_buffers`, so
    /// `related_files_with_buffers` may pair stale buffers afterwards —
    /// presumably callers only use `related_files` in that case; verify.
    pub fn set_related_files(&mut self, files: Vec<RelatedFile>) {
        self.related_files = files.into();
    }

    /// Core refresh: finds identifiers near `position`, resolves their
    /// definitions (reusing cached entries where possible), rebuilds the
    /// related-file excerpts, and emits start/finish events.
    async fn fetch_excerpts(
        this: WeakEntity<Self>,
        buffer: Entity<Buffer>,
        position: Anchor,
        cx: &mut AsyncApp,
    ) -> Result<()> {
        let (project, snapshot, identifier_line_count) = this.read_with(cx, |this, cx| {
            (
                this.project.upgrade(),
                buffer.read(cx).snapshot(),
                this.identifier_line_count,
            )
        })?;
        // The project may have been dropped; nothing to do in that case.
        let Some(project) = project else {
            return Ok(());
        };

        let file = snapshot.file().cloned();
        if let Some(file) = &file {
            log::debug!("retrieving_context buffer:{}", file.path().as_unix_str());
        }

        this.update(cx, |_, cx| {
            cx.emit(RelatedExcerptStoreEvent::StartedRefresh);
        })?;

        // Identifier extraction walks the syntax tree; do it off the main thread.
        let identifiers = cx
            .background_spawn(async move {
                identifiers_for_position(&snapshot, position, identifier_line_count)
            })
            .await;

        let async_cx = cx.clone();
        let start_time = Instant::now();
        // For each identifier, either reuse the cached entry or kick off an LSP
        // "go to definition" request. Each future yields
        // (identifier, entry, Option<latency>), where a `Some` latency marks a
        // cache miss.
        let futures = this.update(cx, |this, cx| {
            identifiers
                .into_iter()
                .filter_map(|identifier| {
                    let task = if let Some(entry) = this.cache.get(&identifier) {
                        DefinitionTask::CacheHit(entry.clone())
                    } else {
                        DefinitionTask::CacheMiss(
                            this.project
                                .update(cx, |project, cx| {
                                    project.definitions(&buffer, identifier.range.start, cx)
                                })
                                .ok()?,
                        )
                    };

                    let cx = async_cx.clone();
                    let project = project.clone();
                    Some(async move {
                        match task {
                            DefinitionTask::CacheHit(cache_entry) => {
                                Some((identifier, cache_entry, None))
                            }
                            DefinitionTask::CacheMiss(task) => {
                                let locations = task.await.log_err()??;
                                let duration = start_time.elapsed();
                                cx.update(|cx| {
                                    (
                                        identifier,
                                        Arc::new(CacheEntry {
                                            definitions: locations
                                                .into_iter()
                                                .filter_map(|location| {
                                                    process_definition(location, &project, cx)
                                                })
                                                .collect(),
                                        }),
                                        Some(duration),
                                    )
                                })
                                .ok()
                            }
                        }
                    })
                })
                .collect::<Vec<_>>()
        })?;

        let mut cache_hit_count = 0;
        let mut cache_miss_count = 0;
        let mut mean_definition_latency = Duration::ZERO;
        let mut max_definition_latency = Duration::ZERO;
        // Building a fresh cache (instead of mutating the old one) drops
        // entries for identifiers no longer near the cursor.
        let mut new_cache = HashMap::default();
        new_cache.reserve(futures.len());
        for (identifier, entry, duration) in future::join_all(futures).await.into_iter().flatten() {
            new_cache.insert(identifier, entry);
            if let Some(duration) = duration {
                cache_miss_count += 1;
                mean_definition_latency += duration;
                max_definition_latency = max_definition_latency.max(duration);
            } else {
                cache_hit_count += 1;
            }
        }
        // `.max(1)` guards the division when every lookup was a cache hit.
        mean_definition_latency /= cache_miss_count.max(1) as u32;

        let (new_cache, related_files, related_file_buffers) =
            rebuild_related_files(&project, new_cache, cx).await?;

        if let Some(file) = &file {
            log::debug!(
                "finished retrieving context buffer:{}, latency:{:?}",
                file.path().as_unix_str(),
                start_time.elapsed()
            );
        }

        this.update(cx, |this, cx| {
            this.cache = new_cache;
            this.related_files = related_files.into();
            this.related_file_buffers = related_file_buffers;
            cx.emit(RelatedExcerptStoreEvent::FinishedRefresh {
                cache_hit_count,
                cache_miss_count,
                mean_definition_latency,
                max_definition_latency,
            });
        })?;

        anyhow::Ok(())
    }
}
266
/// Turns the freshly built definition cache into a sorted list of
/// [`RelatedFile`]s (with their buffers), assembling excerpts for every
/// buffer that contains at least one cached definition.
///
/// Returns the cache back alongside the files/buffers so the caller can move
/// everything into the store in one step.
async fn rebuild_related_files(
    project: &Entity<Project>,
    new_entries: HashMap<Identifier, Arc<CacheEntry>>,
    cx: &mut AsyncApp,
) -> Result<(
    HashMap<Identifier, Arc<CacheEntry>>,
    Vec<RelatedFile>,
    Vec<Entity<Buffer>>,
)> {
    // Phase 1 (on the main thread): snapshot each distinct definition buffer
    // (after waiting for its parse to settle) and record worktree root names.
    let mut snapshots = HashMap::default();
    let mut worktree_root_names = HashMap::default();
    for entry in new_entries.values() {
        for definition in &entry.definitions {
            if let hash_map::Entry::Vacant(e) = snapshots.entry(definition.buffer.entity_id()) {
                // Wait until parsing is idle so excerpt assembly sees a
                // consistent syntax tree.
                definition
                    .buffer
                    .read_with(cx, |buffer, _| buffer.parsing_idle())?
                    .await;
                e.insert(
                    definition
                        .buffer
                        .read_with(cx, |buffer, _| buffer.snapshot())?,
                );
            }
            let worktree_id = definition.path.worktree_id;
            if let hash_map::Entry::Vacant(e) =
                worktree_root_names.entry(definition.path.worktree_id)
            {
                project.read_with(cx, |project, cx| {
                    if let Some(worktree) = project.worktree_for_id(worktree_id, cx) {
                        e.insert(worktree.read(cx).root_name().as_unix_str().to_string());
                    }
                })?;
            }
        }
    }

    // Phase 2 (background): group definition ranges per buffer, assemble
    // excerpts, and build display paths of the form "<root>/<relative path>".
    Ok(cx
        .background_spawn(async move {
            let mut files = Vec::new();
            let mut ranges_by_buffer = HashMap::<_, Vec<Range<Point>>>::default();
            let mut paths_by_buffer = HashMap::default();
            for entry in new_entries.values() {
                for definition in &entry.definitions {
                    let Some(snapshot) = snapshots.get(&definition.buffer.entity_id()) else {
                        continue;
                    };
                    paths_by_buffer.insert(definition.buffer.entity_id(), definition.path.clone());
                    ranges_by_buffer
                        .entry(definition.buffer.clone())
                        .or_default()
                        .push(definition.anchor_range.to_point(snapshot));
                }
            }

            for (buffer, ranges) in ranges_by_buffer {
                let Some(snapshot) = snapshots.get(&buffer.entity_id()) else {
                    continue;
                };
                let Some(project_path) = paths_by_buffer.get(&buffer.entity_id()) else {
                    continue;
                };
                let excerpts = assemble_excerpts(snapshot, ranges);
                // Worktree may have disappeared between phases; skip its files.
                let Some(root_name) = worktree_root_names.get(&project_path.worktree_id) else {
                    continue;
                };

                let path = Path::new(&format!(
                    "{}/{}",
                    root_name,
                    project_path.path.as_unix_str()
                ))
                .into();

                files.push((
                    buffer,
                    RelatedFile {
                        path,
                        excerpts,
                        max_row: snapshot.max_point().row,
                    },
                ));
            }

            // Sort by path so output order is deterministic across refreshes.
            files.sort_by_key(|(_, file)| file.path.clone());
            let (related_buffers, related_files) = files.into_iter().unzip();

            (new_entries, related_files, related_buffers)
        })
        .await)
}
358
/// Maximum byte length of a definition's target range; larger targets are
/// assumed to be whole modules rather than individual symbols and are skipped.
const MAX_TARGET_LEN: usize = 128;
360
361fn process_definition(
362 location: LocationLink,
363 project: &Entity<Project>,
364 cx: &mut App,
365) -> Option<CachedDefinition> {
366 let buffer = location.target.buffer.read(cx);
367 let anchor_range = location.target.range;
368 let file = buffer.file()?;
369 let worktree = project.read(cx).worktree_for_id(file.worktree_id(cx), cx)?;
370 if worktree.read(cx).is_single_file() {
371 return None;
372 }
373
374 // If the target range is large, it likely means we requested the definition of an entire module.
375 // For individual definitions, the target range should be small as it only covers the symbol.
376 let buffer = location.target.buffer.read(cx);
377 let target_len = anchor_range.to_offset(&buffer).len();
378 if target_len > MAX_TARGET_LEN {
379 return None;
380 }
381
382 Some(CachedDefinition {
383 path: ProjectPath {
384 worktree_id: file.worktree_id(cx),
385 path: file.path().clone(),
386 },
387 buffer: location.target.buffer,
388 anchor_range,
389 })
390}
391
/// Gets all of the identifiers that are present in the given line, and its containing
/// outline items.
///
/// Identifiers are collected from `identifier_line_count` lines on either side
/// of the cursor, plus the header/signature portion of every outline item
/// (function, type, etc.) that contains the cursor.
fn identifiers_for_position(
    buffer: &BufferSnapshot,
    position: Anchor,
    identifier_line_count: u32,
) -> Vec<Identifier> {
    let offset = position.to_offset(buffer);
    let point = buffer.offset_to_point(offset);

    // Search for identifiers on lines adjacent to the cursor.
    let start = Point::new(point.row.saturating_sub(identifier_line_count), 0);
    let end = Point::new(point.row + identifier_line_count + 1, 0).min(buffer.max_point());
    let line_range = start..end;
    let mut ranges = vec![line_range.to_offset(&buffer)];

    // Search for identifiers mentioned in headers/signatures of containing outline items.
    for item in outline_items {
        // Use only the header (everything before the body) when a body exists;
        // otherwise take the item's full range.
        if let Some(body_range) = item.body_range(&buffer) {
            ranges.push(item.range.start..body_range.start.to_offset(&buffer));
        } else {
            ranges.push(item.range.clone());
        }
    }

    // Sort by start ascending (ties: longer range first), then merge
    // overlapping or adjacent ranges in place. `dedup_by` keeps `b` (the
    // earlier element), which we widen to cover the removed `a`.
    ranges.sort_by(|a, b| a.start.cmp(&b.start).then(b.end.cmp(&a.end)));
    ranges.dedup_by(|a, b| {
        if a.start <= b.end {
            b.start = b.start.min(a.start);
            b.end = b.end.max(a.end);
            true
        } else {
            false
        }
    });

    let mut identifiers = Vec::new();
    // Span covering all merged ranges, used to drive a single capture query.
    let outer_range =
        ranges.first().map_or(0, |r| r.start)..ranges.last().map_or(buffer.len(), |r| r.end);

    let mut captures = buffer
        .syntax
        .captures(outer_range.clone(), &buffer.text, |grammar| {
            grammar
                .highlights_config
                .as_ref()
                .map(|config| &config.query)
        });

    for range in ranges {
        // Re-seek the capture cursor to this range's start; the end is bounded
        // manually below (via the `node_range.start > range.end` break) rather
        // than by the cursor's byte range.
        captures.set_byte_range(range.start..outer_range.end);

        let mut last_range = None;
        while let Some(capture) = captures.peek() {
            let node_range = capture.node.byte_range();
            if node_range.start > range.end {
                break;
            }
            let config = captures.grammars()[capture.grammar_index]
                .highlights_config
                .as_ref();

            // Keep captures that the grammar marks as identifiers, lie fully
            // within this range, and aren't duplicates of the previous node
            // (the same node can match multiple capture patterns).
            if let Some(config) = config
                && config.identifier_capture_indices.contains(&capture.index)
                && range.contains_inclusive(&node_range)
                && Some(&node_range) != last_range.as_ref()
            {
                let name = buffer.text_for_range(node_range.clone()).collect();
                identifiers.push(Identifier {
                    range: buffer.anchor_after(node_range.start)
                        ..buffer.anchor_before(node_range.end),
                    name,
                });
                last_range = Some(node_range);
            }

            captures.advance();
        }
    }

    identifiers
}