// edit_prediction_context2.rs

  1use crate::assemble_excerpts::assemble_excerpts;
  2use anyhow::Result;
  3use collections::HashMap;
  4use futures::{FutureExt, StreamExt as _, channel::mpsc, future};
  5use gpui::{App, AppContext, AsyncApp, Context, Entity, EventEmitter, Task, WeakEntity};
  6use language::{Anchor, Buffer, BufferSnapshot, OffsetRangeExt as _, Point, Rope, ToOffset as _};
  7use project::{LocationLink, Project, ProjectPath};
  8use serde::{Serialize, Serializer};
  9use smallvec::SmallVec;
 10use std::{
 11    collections::hash_map,
 12    ops::Range,
 13    sync::Arc,
 14    time::{Duration, Instant},
 15};
 16use util::{RangeExt as _, ResultExt};
 17
 18mod assemble_excerpts;
 19#[cfg(test)]
 20mod edit_prediction_context_tests;
 21#[cfg(test)]
 22mod fake_definition_lsp;
 23
/// Tracks files related to the user's cursor position by resolving the
/// definitions of nearby identifiers and caching the results between refreshes.
pub struct RelatedExcerptStore {
    project: WeakEntity<Project>,
    /// The most recently computed related files, with their assembled excerpts.
    related_files: Vec<RelatedFile>,
    /// Definition-lookup results memoized per identifier occurrence, so a
    /// refresh can skip language-server round trips for identifiers it has
    /// already resolved.
    cache: HashMap<Identifier, Arc<CacheEntry>>,
    /// Feeds (buffer, position) refresh requests to the debounced background task.
    update_tx: mpsc::UnboundedSender<(Entity<Buffer>, Anchor)>,
}
 30
/// Events emitted by [`RelatedExcerptStore`] around each refresh cycle.
pub enum RelatedExcerptStoreEvent {
    /// A refresh has started (after the debounce interval elapsed).
    StartedRefresh,
    /// A refresh finished; reports cache effectiveness and the latency of the
    /// definition requests issued for cache misses.
    FinishedRefresh {
        cache_hit_count: usize,
        cache_miss_count: usize,
        mean_definition_latency: Duration,
        max_definition_latency: Duration,
    },
}
 40
/// An identifier occurrence in the queried buffer; serves as the key for the
/// definition cache (both name and location participate in equality/hashing).
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
struct Identifier {
    pub name: String,
    /// Anchor range of this occurrence in the buffer it was found in.
    pub range: Range<Anchor>,
}
 46
/// The outcome of consulting the cache for one identifier: either a ready
/// cached entry, or an in-flight go-to-definition request to await.
enum DefinitionTask {
    CacheHit(Arc<CacheEntry>),
    CacheMiss(Task<Result<Option<Vec<LocationLink>>>>),
}
 51
/// Cached definition results for a single identifier. Inline capacity of 1
/// because an identifier typically resolves to exactly one definition.
#[derive(Debug)]
struct CacheEntry {
    definitions: SmallVec<[CachedDefinition; 1]>,
}
 56
/// A resolved definition location retained across refreshes.
#[derive(Clone, Debug)]
struct CachedDefinition {
    path: ProjectPath,
    buffer: Entity<Buffer>,
    /// Range of the definition inside `buffer`.
    anchor_range: Range<Anchor>,
}
 63
/// A file containing definitions related to the cursor position, together with
/// the excerpts of it that should be surfaced as prediction context.
#[derive(Clone, Debug, Serialize)]
pub struct RelatedFile {
    #[serde(serialize_with = "serialize_project_path")]
    pub path: ProjectPath,
    #[serde(skip)]
    pub buffer: WeakEntity<Buffer>,
    pub excerpts: Vec<RelatedExcerpt>,
    /// Row of the buffer's last line at the time the excerpts were assembled.
    pub max_row: u32,
}
 73
 74impl RelatedFile {
 75    pub fn merge_excerpts(&mut self) {
 76        self.excerpts.sort_unstable_by(|a, b| {
 77            a.point_range
 78                .start
 79                .cmp(&b.point_range.start)
 80                .then(b.point_range.end.cmp(&a.point_range.end))
 81        });
 82
 83        let mut index = 1;
 84        while index < self.excerpts.len() {
 85            if self.excerpts[index - 1]
 86                .point_range
 87                .end
 88                .cmp(&self.excerpts[index].point_range.start)
 89                .is_ge()
 90            {
 91                let removed = self.excerpts.remove(index);
 92                if removed
 93                    .point_range
 94                    .end
 95                    .cmp(&self.excerpts[index - 1].point_range.end)
 96                    .is_gt()
 97                {
 98                    self.excerpts[index - 1].point_range.end = removed.point_range.end;
 99                    self.excerpts[index - 1].anchor_range.end = removed.anchor_range.end;
100                }
101            } else {
102                index += 1;
103            }
104        }
105    }
106}
107
/// A contiguous region of a related file, carrying both its stable anchor
/// range and the point range and text resolved when it was captured.
#[derive(Clone, Debug, Serialize)]
pub struct RelatedExcerpt {
    #[serde(skip)]
    pub anchor_range: Range<Anchor>,
    #[serde(serialize_with = "serialize_point_range")]
    pub point_range: Range<Point>,
    #[serde(serialize_with = "serialize_rope")]
    pub text: Rope,
}
117
118fn serialize_project_path<S: Serializer>(
119    project_path: &ProjectPath,
120    serializer: S,
121) -> Result<S::Ok, S::Error> {
122    project_path.path.serialize(serializer)
123}
124
125fn serialize_rope<S: Serializer>(rope: &Rope, serializer: S) -> Result<S::Ok, S::Error> {
126    rope.to_string().serialize(serializer)
127}
128
129fn serialize_point_range<S: Serializer>(
130    range: &Range<Point>,
131    serializer: S,
132) -> Result<S::Ok, S::Error> {
133    [
134        [range.start.row, range.start.column],
135        [range.end.row, range.end.column],
136    ]
137    .serialize(serializer)
138}
139
/// How long to wait after the last refresh request before actually fetching
/// excerpts; newer requests restart this timer.
const DEBOUNCE_DURATION: Duration = Duration::from_millis(100);

impl EventEmitter<RelatedExcerptStoreEvent> for RelatedExcerptStore {}
143
impl RelatedExcerptStore {
    /// Creates the store and spawns a background task that debounces incoming
    /// refresh requests, processing only the most recent (buffer, position).
    pub fn new(project: &Entity<Project>, cx: &mut Context<Self>) -> Self {
        let (update_tx, mut update_rx) = mpsc::unbounded::<(Entity<Buffer>, Anchor)>();
        cx.spawn(async move |this, cx| {
            let executor = cx.background_executor().clone();
            while let Some((mut buffer, mut position)) = update_rx.next().await {
                // Debounce: each newer request replaces the pending one and
                // restarts the timer; we fetch only once the stream goes quiet.
                let mut timer = executor.timer(DEBOUNCE_DURATION).fuse();
                loop {
                    futures::select_biased! {
                        next = update_rx.next() => {
                            if let Some((new_buffer, new_position)) = next {
                                buffer = new_buffer;
                                position = new_position;
                                timer = executor.timer(DEBOUNCE_DURATION).fuse();
                            } else {
                                // Sender dropped — shut the task down cleanly.
                                return anyhow::Ok(());
                            }
                        }
                        _ = timer => break,
                    }
                }

                Self::fetch_excerpts(this.clone(), buffer, position, cx).await?;
            }
            anyhow::Ok(())
        })
        .detach_and_log_err(cx);

        RelatedExcerptStore {
            project: project.downgrade(),
            update_tx,
            related_files: Vec::new(),
            cache: Default::default(),
        }
    }

    /// Requests a (debounced) refresh of related excerpts for `position` in
    /// `buffer`. Errors from a dropped channel are deliberately ignored.
    pub fn refresh(&mut self, buffer: Entity<Buffer>, position: Anchor, _: &mut Context<Self>) {
        self.update_tx.unbounded_send((buffer, position)).ok();
    }

    /// The most recently computed related files.
    pub fn related_files(&self) -> &[RelatedFile] {
        &self.related_files
    }

    /// Collects identifiers around `position`, resolves each one's definitions
    /// (using the cache where possible), then rebuilds the related-file list
    /// and emits `StartedRefresh`/`FinishedRefresh` events.
    async fn fetch_excerpts(
        this: WeakEntity<Self>,
        buffer: Entity<Buffer>,
        position: Anchor,
        cx: &mut AsyncApp,
    ) -> Result<()> {
        let (project, snapshot) = this.read_with(cx, |this, cx| {
            (this.project.upgrade(), buffer.read(cx).snapshot())
        })?;
        let Some(project) = project else {
            // The project has been dropped; nothing to refresh.
            return Ok(());
        };

        let file = snapshot.file().cloned();
        if let Some(file) = &file {
            log::debug!("retrieving_context buffer:{}", file.path().as_unix_str());
        }

        this.update(cx, |_, cx| {
            cx.emit(RelatedExcerptStoreEvent::StartedRefresh);
        })?;

        // Identifier extraction walks the syntax tree, so run it off the main thread.
        let identifiers = cx
            .background_spawn(async move { identifiers_for_position(&snapshot, position) })
            .await;

        let async_cx = cx.clone();
        let start_time = Instant::now();
        // Build one future per identifier: cache hits resolve immediately;
        // misses await a go-to-definition request and produce a new cache entry.
        let futures = this.update(cx, |this, cx| {
            identifiers
                .into_iter()
                .filter_map(|identifier| {
                    let task = if let Some(entry) = this.cache.get(&identifier) {
                        DefinitionTask::CacheHit(entry.clone())
                    } else {
                        DefinitionTask::CacheMiss(
                            this.project
                                .update(cx, |project, cx| {
                                    project.definitions(&buffer, identifier.range.start, cx)
                                })
                                .ok()?,
                        )
                    };

                    let cx = async_cx.clone();
                    let project = project.clone();
                    Some(async move {
                        match task {
                            DefinitionTask::CacheHit(cache_entry) => {
                                // `None` duration marks this as a cache hit below.
                                Some((identifier, cache_entry, None))
                            }
                            DefinitionTask::CacheMiss(task) => {
                                let locations = task.await.log_err()??;
                                let duration = start_time.elapsed();
                                cx.update(|cx| {
                                    (
                                        identifier,
                                        Arc::new(CacheEntry {
                                            definitions: locations
                                                .into_iter()
                                                .filter_map(|location| {
                                                    process_definition(location, &project, cx)
                                                })
                                                .collect(),
                                        }),
                                        Some(duration),
                                    )
                                })
                                .ok()
                            }
                        }
                    })
                })
                .collect::<Vec<_>>()
        })?;

        let mut cache_hit_count = 0;
        let mut cache_miss_count = 0;
        let mut mean_definition_latency = Duration::ZERO;
        let mut max_definition_latency = Duration::ZERO;
        let mut new_cache = HashMap::default();
        new_cache.reserve(futures.len());
        // A `Some(duration)` distinguishes cache misses from hits.
        for (identifier, entry, duration) in future::join_all(futures).await.into_iter().flatten() {
            new_cache.insert(identifier, entry);
            if let Some(duration) = duration {
                cache_miss_count += 1;
                mean_definition_latency += duration;
                max_definition_latency = max_definition_latency.max(duration);
            } else {
                cache_hit_count += 1;
            }
        }
        // `.max(1)` guards against division by zero when everything was cached.
        mean_definition_latency /= cache_miss_count.max(1) as u32;

        let (new_cache, related_files) = rebuild_related_files(new_cache, cx).await?;

        if let Some(file) = &file {
            log::debug!(
                "finished retrieving context buffer:{}, latency:{:?}",
                file.path().as_unix_str(),
                start_time.elapsed()
            );
        }

        // Replace the cache wholesale: entries for identifiers no longer near
        // the cursor are dropped here.
        this.update(cx, |this, cx| {
            this.cache = new_cache;
            this.related_files = related_files;
            cx.emit(RelatedExcerptStoreEvent::FinishedRefresh {
                cache_hit_count,
                cache_miss_count,
                mean_definition_latency,
                max_definition_latency,
            });
        })?;

        anyhow::Ok(())
    }
}
306
/// Converts the definition cache into a sorted list of [`RelatedFile`]s.
///
/// Waits for each definition's buffer to finish parsing and snapshots it once,
/// then groups the definition ranges per buffer and assembles excerpts on a
/// background thread. Returns the entries back alongside the files so the
/// caller can retain both.
async fn rebuild_related_files(
    new_entries: HashMap<Identifier, Arc<CacheEntry>>,
    cx: &mut AsyncApp,
) -> Result<(HashMap<Identifier, Arc<CacheEntry>>, Vec<RelatedFile>)> {
    // One snapshot per distinct buffer, taken only after parsing has settled.
    let mut snapshots = HashMap::default();
    for entry in new_entries.values() {
        for definition in &entry.definitions {
            if let hash_map::Entry::Vacant(e) = snapshots.entry(definition.buffer.entity_id()) {
                definition
                    .buffer
                    .read_with(cx, |buffer, _| buffer.parsing_idle())?
                    .await;
                e.insert(
                    definition
                        .buffer
                        .read_with(cx, |buffer, _| buffer.snapshot())?,
                );
            }
        }
    }

    Ok(cx
        .background_spawn(async move {
            let mut files = Vec::<RelatedFile>::new();
            // Group definition ranges by the buffer they live in, remembering
            // each buffer's project path as we go.
            let mut ranges_by_buffer = HashMap::<_, Vec<Range<Point>>>::default();
            let mut paths_by_buffer = HashMap::default();
            for entry in new_entries.values() {
                for definition in &entry.definitions {
                    let Some(snapshot) = snapshots.get(&definition.buffer.entity_id()) else {
                        continue;
                    };
                    paths_by_buffer.insert(definition.buffer.entity_id(), definition.path.clone());
                    ranges_by_buffer
                        .entry(definition.buffer.clone())
                        .or_default()
                        .push(definition.anchor_range.to_point(snapshot));
                }
            }

            for (buffer, ranges) in ranges_by_buffer {
                let Some(snapshot) = snapshots.get(&buffer.entity_id()) else {
                    continue;
                };
                let Some(project_path) = paths_by_buffer.get(&buffer.entity_id()) else {
                    continue;
                };
                let excerpts = assemble_excerpts(snapshot, ranges);
                files.push(RelatedFile {
                    path: project_path.clone(),
                    buffer: buffer.downgrade(),
                    excerpts,
                    max_row: snapshot.max_point().row,
                });
            }

            // Deterministic ordering for consumers and tests.
            files.sort_by_key(|file| file.path.clone());
            (new_entries, files)
        })
        .await)
}
367
368fn process_definition(
369    location: LocationLink,
370    project: &Entity<Project>,
371    cx: &mut App,
372) -> Option<CachedDefinition> {
373    let buffer = location.target.buffer.read(cx);
374    let anchor_range = location.target.range;
375    let file = buffer.file()?;
376    let worktree = project.read(cx).worktree_for_id(file.worktree_id(cx), cx)?;
377    if worktree.read(cx).is_single_file() {
378        return None;
379    }
380    Some(CachedDefinition {
381        path: ProjectPath {
382            worktree_id: file.worktree_id(cx),
383            path: file.path().clone(),
384        },
385        buffer: location.target.buffer,
386        anchor_range,
387    })
388}
389
/// Gets all of the identifiers that are present in the given line, and its containing
/// outline items.
fn identifiers_for_position(buffer: &BufferSnapshot, position: Anchor) -> Vec<Identifier> {
    let offset = position.to_offset(buffer);
    let point = buffer.offset_to_point(offset);

    // The entire line containing the cursor, clamped to the end of the buffer.
    let line_range = Point::new(point.row, 0)..Point::new(point.row + 1, 0).min(buffer.max_point());
    let mut ranges = vec![line_range.to_offset(&buffer)];

    // Include the range of the outline item itself, but not its body.
    let outline_items = buffer.outline_items_as_offsets_containing(offset..offset, false, None);
    for item in outline_items {
        if let Some(body_range) = item.body_range(&buffer) {
            ranges.push(item.range.start..body_range.start.to_offset(&buffer));
        } else {
            ranges.push(item.range.clone());
        }
    }

    // Sort by start (ties: widest range first), then merge overlapping or
    // adjacent ranges in place.
    ranges.sort_by(|a, b| a.start.cmp(&b.start).then(b.end.cmp(&a.end)));
    ranges.dedup_by(|a, b| {
        if a.start <= b.end {
            b.start = b.start.min(a.start);
            b.end = b.end.max(a.end);
            true
        } else {
            false
        }
    });

    let mut identifiers = Vec::new();
    // A single span covering all ranges, so one capture query can serve them all.
    let outer_range =
        ranges.first().map_or(0, |r| r.start)..ranges.last().map_or(buffer.len(), |r| r.end);

    let mut captures = buffer
        .syntax
        .captures(outer_range.clone(), &buffer.text, |grammar| {
            grammar
                .highlights_config
                .as_ref()
                .map(|config| &config.query)
        });

    for range in ranges {
        // Reuse the same capture iterator across ranges, moving its window
        // forward; the `node_range.start > range.end` check below stops each
        // range's scan.
        captures.set_byte_range(range.start..outer_range.end);

        let mut last_range = None;
        while let Some(capture) = captures.peek() {
            let node_range = capture.node.byte_range();
            if node_range.start > range.end {
                break;
            }
            let config = captures.grammars()[capture.grammar_index]
                .highlights_config
                .as_ref();

            // Keep captures the grammar marks as identifiers; `last_range`
            // skips duplicate captures of the same node by multiple patterns.
            if let Some(config) = config
                && config.identifier_capture_indices.contains(&capture.index)
                && range.contains_inclusive(&node_range)
                && Some(&node_range) != last_range.as_ref()
            {
                let name = buffer.text_for_range(node_range.clone()).collect();
                identifiers.push(Identifier {
                    range: buffer.anchor_after(node_range.start)
                        ..buffer.anchor_before(node_range.end),
                    name,
                });
                last_range = Some(node_range);
            }

            captures.advance();
        }
    }

    identifiers
}