edit_prediction_context.rs

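//! Maintains a cache of definitions for identifiers near the cursor and
//! assembles the code surrounding those definitions into per-file excerpts
//! that can serve as context for edit prediction. Refresh requests are
//! debounced, and definition lookups are cached per identifier so that
//! repeated refreshes near the same cursor position avoid redundant
//! language-server round trips.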
use crate::assemble_excerpts::assemble_excerpts;
use anyhow::Result;
use collections::HashMap;
use futures::{FutureExt, StreamExt as _, channel::mpsc, future};
use gpui::{App, AppContext, AsyncApp, Context, Entity, EventEmitter, Task, WeakEntity};
use language::{Anchor, Buffer, BufferSnapshot, OffsetRangeExt as _, Point, Rope, ToOffset as _};
use project::{LocationLink, Project, ProjectPath};
use serde::{Serialize, Serializer};
use smallvec::SmallVec;
use std::{
    collections::hash_map,
    ops::Range,
    sync::Arc,
    time::{Duration, Instant},
};
use util::{RangeExt as _, ResultExt};

mod assemble_excerpts;
#[cfg(test)]
mod edit_prediction_context_tests;
mod excerpt;
#[cfg(test)]
mod fake_definition_lsp;

pub use cloud_llm_client::predict_edits_v3::Line;
pub use excerpt::{EditPredictionExcerpt, EditPredictionExcerptOptions, EditPredictionExcerptText};

const IDENTIFIER_LINE_COUNT: u32 = 3;

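/// Caches definitions for identifiers near the cursor and exposes them as
/// per-file [`RelatedFile`] excerpts.
///
/// A minimal usage sketch (hedged: assumes a `Project` entity, a `Buffer`,
/// and a cursor `Anchor` are already in scope; not compiled as a doc test):
///
/// ```ignore
/// let store = cx.new(|cx| RelatedExcerptStore::new(&project, cx));
/// // Requests are debounced, so rapid cursor movement triggers one fetch.
/// store.update(cx, |store, cx| store.refresh(buffer.clone(), cursor, cx));
/// // After the store emits `RelatedExcerptStoreEvent::FinishedRefresh`:
/// let related: &[RelatedFile] = store.read(cx).related_files();
/// ```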
pub struct RelatedExcerptStore {
    project: WeakEntity<Project>,
    related_files: Vec<RelatedFile>,
    cache: HashMap<Identifier, Arc<CacheEntry>>,
    update_tx: mpsc::UnboundedSender<(Entity<Buffer>, Anchor)>,
    identifier_line_count: u32,
}

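/// Events emitted over the course of a refresh. `FinishedRefresh` reports how
/// many identifier lookups were served from the cache versus resolved via new
/// definition requests, plus latency statistics for the misses.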
pub enum RelatedExcerptStoreEvent {
    StartedRefresh,
    FinishedRefresh {
        cache_hit_count: usize,
        cache_miss_count: usize,
        mean_definition_latency: Duration,
        max_definition_latency: Duration,
    },
}

#[derive(Clone, Debug, PartialEq, Eq, Hash)]
struct Identifier {
    pub name: String,
    pub range: Range<Anchor>,
}

enum DefinitionTask {
    CacheHit(Arc<CacheEntry>),
    CacheMiss(Task<Result<Option<Vec<LocationLink>>>>),
}

#[derive(Debug)]
struct CacheEntry {
    definitions: SmallVec<[CachedDefinition; 1]>,
}

#[derive(Clone, Debug)]
struct CachedDefinition {
    path: ProjectPath,
    buffer: Entity<Buffer>,
    anchor_range: Range<Anchor>,
}

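/// A file containing definitions for identifiers found near the cursor,
/// together with the excerpts surrounding those definitions.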
#[derive(Clone, Debug, Serialize)]
pub struct RelatedFile {
    #[serde(serialize_with = "serialize_project_path")]
    pub path: ProjectPath,
    #[serde(skip)]
    pub buffer: WeakEntity<Buffer>,
    pub excerpts: Vec<RelatedExcerpt>,
    pub max_row: u32,
}

impl RelatedFile {
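    /// Merges overlapping and touching excerpts in place. Excerpts are first
    /// sorted by ascending start (ties broken by descending end), then each
    /// excerpt that starts at or before the previous excerpt's end is folded
    /// into it. For example, point ranges `2..5`, `1..4`, and `6..9` sort to
    /// `1..4, 2..5, 6..9` and merge to `1..5, 6..9`.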
    pub fn merge_excerpts(&mut self) {
        self.excerpts.sort_unstable_by(|a, b| {
            a.point_range
                .start
                .cmp(&b.point_range.start)
                .then(b.point_range.end.cmp(&a.point_range.end))
        });

        let mut index = 1;
        while index < self.excerpts.len() {
            if self.excerpts[index - 1]
                .point_range
                .end
                .cmp(&self.excerpts[index].point_range.start)
                .is_ge()
            {
                let removed = self.excerpts.remove(index);
                if removed
                    .point_range
                    .end
                    .cmp(&self.excerpts[index - 1].point_range.end)
                    .is_gt()
                {
                    self.excerpts[index - 1].point_range.end = removed.point_range.end;
                    self.excerpts[index - 1].anchor_range.end = removed.anchor_range.end;
                }
            } else {
                index += 1;
            }
        }
    }
}

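/// A contiguous region of a related file, stored both as an anchor range and
/// as a point range, along with the excerpt's text.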
#[derive(Clone, Debug, Serialize)]
pub struct RelatedExcerpt {
    #[serde(skip)]
    pub anchor_range: Range<Anchor>,
    #[serde(serialize_with = "serialize_point_range")]
    pub point_range: Range<Point>,
    #[serde(serialize_with = "serialize_rope")]
    pub text: Rope,
}

fn serialize_project_path<S: Serializer>(
    project_path: &ProjectPath,
    serializer: S,
) -> Result<S::Ok, S::Error> {
    project_path.path.serialize(serializer)
}

fn serialize_rope<S: Serializer>(rope: &Rope, serializer: S) -> Result<S::Ok, S::Error> {
    rope.to_string().serialize(serializer)
}

fn serialize_point_range<S: Serializer>(
    range: &Range<Point>,
    serializer: S,
) -> Result<S::Ok, S::Error> {
    [
        [range.start.row, range.start.column],
        [range.end.row, range.end.column],
    ]
    .serialize(serializer)
}

const DEBOUNCE_DURATION: Duration = Duration::from_millis(100);

impl EventEmitter<RelatedExcerptStoreEvent> for RelatedExcerptStore {}

impl RelatedExcerptStore {
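    /// Creates the store and spawns a task that drains `refresh` requests.
    /// Each incoming `(buffer, position)` pair restarts a `DEBOUNCE_DURATION`
    /// (100ms) timer; definitions are fetched only for the most recent pair,
    /// once the stream goes quiet for that long.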
    pub fn new(project: &Entity<Project>, cx: &mut Context<Self>) -> Self {
        let (update_tx, mut update_rx) = mpsc::unbounded::<(Entity<Buffer>, Anchor)>();
        cx.spawn(async move |this, cx| {
            let executor = cx.background_executor().clone();
            while let Some((mut buffer, mut position)) = update_rx.next().await {
                let mut timer = executor.timer(DEBOUNCE_DURATION).fuse();
                loop {
                    futures::select_biased! {
                        next = update_rx.next() => {
                            if let Some((new_buffer, new_position)) = next {
                                buffer = new_buffer;
                                position = new_position;
                                timer = executor.timer(DEBOUNCE_DURATION).fuse();
                            } else {
                                return anyhow::Ok(());
                            }
                        }
                        _ = timer => break,
                    }
                }

                Self::fetch_excerpts(this.clone(), buffer, position, cx).await?;
            }
            anyhow::Ok(())
        })
        .detach_and_log_err(cx);

        RelatedExcerptStore {
            project: project.downgrade(),
            update_tx,
            related_files: Vec::new(),
            cache: Default::default(),
            identifier_line_count: IDENTIFIER_LINE_COUNT,
        }
    }

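    /// Overrides how many lines above and below the cursor are scanned for
    /// identifiers (defaults to `IDENTIFIER_LINE_COUNT`).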
    pub fn set_identifier_line_count(&mut self, count: u32) {
        self.identifier_line_count = count;
    }

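    /// Queues a refresh for the given cursor position. Requests are coalesced
    /// by the debounce loop spawned in `new`, so rapid successive calls result
    /// in a single fetch.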
    pub fn refresh(&mut self, buffer: Entity<Buffer>, position: Anchor, _: &mut Context<Self>) {
        self.update_tx.unbounded_send((buffer, position)).ok();
    }

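    /// The most recently assembled related files, sorted by project path.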
    pub fn related_files(&self) -> &[RelatedFile] {
        &self.related_files
    }

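    /// Drives one refresh: emits `StartedRefresh`, collects identifiers near
    /// `position`, resolves each one either from the cache or via a
    /// `Project::definitions` request, rebuilds the related-file excerpts, and
    /// finally emits `FinishedRefresh` with cache and latency statistics.
    /// Entries absent from the new identifier set are dropped from the cache.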
    async fn fetch_excerpts(
        this: WeakEntity<Self>,
        buffer: Entity<Buffer>,
        position: Anchor,
        cx: &mut AsyncApp,
    ) -> Result<()> {
        let (project, snapshot, identifier_line_count) = this.read_with(cx, |this, cx| {
            (
                this.project.upgrade(),
                buffer.read(cx).snapshot(),
                this.identifier_line_count,
            )
        })?;
        let Some(project) = project else {
            return Ok(());
        };

        let file = snapshot.file().cloned();
        if let Some(file) = &file {
            log::debug!("retrieving_context buffer:{}", file.path().as_unix_str());
        }

        this.update(cx, |_, cx| {
            cx.emit(RelatedExcerptStoreEvent::StartedRefresh);
        })?;

        let identifiers = cx
            .background_spawn(async move {
                identifiers_for_position(&snapshot, position, identifier_line_count)
            })
            .await;

        let async_cx = cx.clone();
        let start_time = Instant::now();
        let futures = this.update(cx, |this, cx| {
            identifiers
                .into_iter()
                .filter_map(|identifier| {
                    let task = if let Some(entry) = this.cache.get(&identifier) {
                        DefinitionTask::CacheHit(entry.clone())
                    } else {
                        DefinitionTask::CacheMiss(
                            this.project
                                .update(cx, |project, cx| {
                                    project.definitions(&buffer, identifier.range.start, cx)
                                })
                                .ok()?,
                        )
                    };

                    let cx = async_cx.clone();
                    let project = project.clone();
                    Some(async move {
                        match task {
                            DefinitionTask::CacheHit(cache_entry) => {
                                Some((identifier, cache_entry, None))
                            }
                            DefinitionTask::CacheMiss(task) => {
                                let locations = task.await.log_err()??;
                                let duration = start_time.elapsed();
                                cx.update(|cx| {
                                    (
                                        identifier,
                                        Arc::new(CacheEntry {
                                            definitions: locations
                                                .into_iter()
                                                .filter_map(|location| {
                                                    process_definition(location, &project, cx)
                                                })
                                                .collect(),
                                        }),
                                        Some(duration),
                                    )
                                })
                                .ok()
                            }
                        }
                    })
                })
                .collect::<Vec<_>>()
        })?;

        let mut cache_hit_count = 0;
        let mut cache_miss_count = 0;
        let mut mean_definition_latency = Duration::ZERO;
        let mut max_definition_latency = Duration::ZERO;
        let mut new_cache = HashMap::default();
        new_cache.reserve(futures.len());
        for (identifier, entry, duration) in future::join_all(futures).await.into_iter().flatten() {
            new_cache.insert(identifier, entry);
            if let Some(duration) = duration {
                cache_miss_count += 1;
                mean_definition_latency += duration;
                max_definition_latency = max_definition_latency.max(duration);
            } else {
                cache_hit_count += 1;
            }
        }
        mean_definition_latency /= cache_miss_count.max(1) as u32;

        let (new_cache, related_files) = rebuild_related_files(new_cache, cx).await?;

        if let Some(file) = &file {
            log::debug!(
                "finished retrieving context buffer:{}, latency:{:?}",
                file.path().as_unix_str(),
                start_time.elapsed()
            );
        }

        this.update(cx, |this, cx| {
            this.cache = new_cache;
            this.related_files = related_files;
            cx.emit(RelatedExcerptStoreEvent::FinishedRefresh {
                cache_hit_count,
                cache_miss_count,
                mean_definition_latency,
                max_definition_latency,
            });
        })?;

        anyhow::Ok(())
    }
}

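/// Rebuilds the per-file excerpt list from the freshly resolved cache entries.
/// For each definition buffer this waits for parsing to go idle and takes a
/// snapshot, then (on a background thread) groups definition ranges by buffer,
/// assembles excerpts for each group, and sorts the resulting files by path.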
async fn rebuild_related_files(
    new_entries: HashMap<Identifier, Arc<CacheEntry>>,
    cx: &mut AsyncApp,
) -> Result<(HashMap<Identifier, Arc<CacheEntry>>, Vec<RelatedFile>)> {
    let mut snapshots = HashMap::default();
    for entry in new_entries.values() {
        for definition in &entry.definitions {
            if let hash_map::Entry::Vacant(e) = snapshots.entry(definition.buffer.entity_id()) {
                definition
                    .buffer
                    .read_with(cx, |buffer, _| buffer.parsing_idle())?
                    .await;
                e.insert(
                    definition
                        .buffer
                        .read_with(cx, |buffer, _| buffer.snapshot())?,
                );
            }
        }
    }

    Ok(cx
        .background_spawn(async move {
            let mut files = Vec::<RelatedFile>::new();
            let mut ranges_by_buffer = HashMap::<_, Vec<Range<Point>>>::default();
            let mut paths_by_buffer = HashMap::default();
            for entry in new_entries.values() {
                for definition in &entry.definitions {
                    let Some(snapshot) = snapshots.get(&definition.buffer.entity_id()) else {
                        continue;
                    };
                    paths_by_buffer.insert(definition.buffer.entity_id(), definition.path.clone());
                    ranges_by_buffer
                        .entry(definition.buffer.clone())
                        .or_default()
                        .push(definition.anchor_range.to_point(snapshot));
                }
            }

            for (buffer, ranges) in ranges_by_buffer {
                let Some(snapshot) = snapshots.get(&buffer.entity_id()) else {
                    continue;
                };
                let Some(project_path) = paths_by_buffer.get(&buffer.entity_id()) else {
                    continue;
                };
                let excerpts = assemble_excerpts(snapshot, ranges);
                files.push(RelatedFile {
                    path: project_path.clone(),
                    buffer: buffer.downgrade(),
                    excerpts,
                    max_row: snapshot.max_point().row,
                });
            }

            files.sort_by_key(|file| file.path.clone());
            (new_entries, files)
        })
        .await)
}

const MAX_TARGET_LEN: usize = 128;

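/// Converts an LSP `LocationLink` into a `CachedDefinition`, filtering out
/// targets that are unlikely to be useful context: buffers without a file,
/// single-file worktrees, and targets whose range exceeds `MAX_TARGET_LEN`
/// bytes.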
fn process_definition(
    location: LocationLink,
    project: &Entity<Project>,
    cx: &mut App,
) -> Option<CachedDefinition> {
    let buffer = location.target.buffer.read(cx);
    let anchor_range = location.target.range;
    let file = buffer.file()?;
    let worktree = project.read(cx).worktree_for_id(file.worktree_id(cx), cx)?;
    if worktree.read(cx).is_single_file() {
        return None;
    }

    // If the target range is large, it likely means we requested the definition of an entire
    // module. For individual definitions, the target range should be small, as it only covers
    // the symbol.
    let target_len = anchor_range.to_offset(&buffer).len();
    if target_len > MAX_TARGET_LEN {
        return None;
    }

    Some(CachedDefinition {
        path: ProjectPath {
            worktree_id: file.worktree_id(cx),
            path: file.path().clone(),
        },
        buffer: location.target.buffer,
        anchor_range,
    })
}

/// Gets all of the identifiers that appear within `identifier_line_count` lines
/// of the given position, as well as in the headers/signatures of its
/// containing outline items.
fn identifiers_for_position(
    buffer: &BufferSnapshot,
    position: Anchor,
    identifier_line_count: u32,
) -> Vec<Identifier> {
    let offset = position.to_offset(buffer);
    let point = buffer.offset_to_point(offset);

    // Search for identifiers on lines adjacent to the cursor.
    let start = Point::new(point.row.saturating_sub(identifier_line_count), 0);
    let end = Point::new(point.row + identifier_line_count + 1, 0).min(buffer.max_point());
    let line_range = start..end;
    let mut ranges = vec![line_range.to_offset(&buffer)];

    // Search for identifiers mentioned in headers/signatures of containing outline items.
    let outline_items = buffer.outline_items_as_offsets_containing(offset..offset, false, None);
    for item in outline_items {
        if let Some(body_range) = item.body_range(&buffer) {
            ranges.push(item.range.start..body_range.start.to_offset(&buffer));
        } else {
            ranges.push(item.range.clone());
        }
    }

    // Sort by ascending start (ties broken by descending end) and merge
    // overlapping ranges so each region is scanned only once.
    ranges.sort_by(|a, b| a.start.cmp(&b.start).then(b.end.cmp(&a.end)));
    ranges.dedup_by(|a, b| {
        if a.start <= b.end {
            b.start = b.start.min(a.start);
            b.end = b.end.max(a.end);
            true
        } else {
            false
        }
    });

    let mut identifiers = Vec::new();
    let outer_range =
        ranges.first().map_or(0, |r| r.start)..ranges.last().map_or(buffer.len(), |r| r.end);

    let mut captures = buffer
        .syntax
        .captures(outer_range.clone(), &buffer.text, |grammar| {
            grammar
                .highlights_config
                .as_ref()
                .map(|config| &config.query)
        });

    for range in ranges {
        captures.set_byte_range(range.start..outer_range.end);

        let mut last_range = None;
        while let Some(capture) = captures.peek() {
            let node_range = capture.node.byte_range();
            if node_range.start > range.end {
                break;
            }
            let config = captures.grammars()[capture.grammar_index]
                .highlights_config
                .as_ref();

            if let Some(config) = config
                && config.identifier_capture_indices.contains(&capture.index)
                && range.contains_inclusive(&node_range)
                && Some(&node_range) != last_range.as_ref()
            {
                let name = buffer.text_for_range(node_range.clone()).collect();
                identifiers.push(Identifier {
                    range: buffer.anchor_after(node_range.start)
                        ..buffer.anchor_before(node_range.end),
                    name,
                });
                last_range = Some(node_range);
            }

            captures.advance();
        }
    }

    identifiers
}