edit_prediction_context.rs

  1use crate::assemble_excerpts::assemble_excerpts;
  2use anyhow::Result;
  3use collections::HashMap;
  4use futures::{FutureExt, StreamExt as _, channel::mpsc, future};
  5use gpui::{App, AppContext, AsyncApp, Context, Entity, EventEmitter, Task, WeakEntity};
  6use language::{Anchor, Buffer, BufferSnapshot, OffsetRangeExt as _, Point, ToOffset as _};
  7use project::{LocationLink, Project, ProjectPath};
  8use smallvec::SmallVec;
  9use std::{
 10    collections::hash_map,
 11    ops::Range,
 12    path::Path,
 13    sync::Arc,
 14    time::{Duration, Instant},
 15};
 16use util::{RangeExt as _, ResultExt};
 17
 18mod assemble_excerpts;
 19#[cfg(test)]
 20mod edit_prediction_context_tests;
 21mod excerpt;
 22#[cfg(test)]
 23mod fake_definition_lsp;
 24
 25pub use cloud_llm_client::predict_edits_v3::Line;
 26pub use excerpt::{EditPredictionExcerpt, EditPredictionExcerptOptions, EditPredictionExcerptText};
 27pub use zeta_prompt::{RelatedExcerpt, RelatedFile};
 28
/// Default number of lines above and below the cursor row that are scanned for
/// identifiers. Used as the initial `identifier_line_count` of a new
/// `RelatedExcerptStore`.
const IDENTIFIER_LINE_COUNT: u32 = 3;
 30
/// Tracks the files and excerpts that are related to the current cursor
/// position by following the definitions of nearby identifiers.
///
/// Refreshes are requested with `refresh` and processed by a debounced
/// background task spawned in `new`.
pub struct RelatedExcerptStore {
    /// Weak handle to the project used to request definitions and look up worktrees.
    project: WeakEntity<Project>,
    /// Related files produced by the most recent refresh (or assigned through
    /// `set_related_files`).
    related_files: Arc<[RelatedFile]>,
    /// Buffers backing `related_files`. A refresh writes both fields from the
    /// same sorted list, so they are in matching order at that point.
    related_file_buffers: Vec<Entity<Buffer>>,
    /// Definitions resolved during the most recent refresh, keyed by identifier.
    /// Replaced wholesale at the end of every refresh.
    cache: HashMap<Identifier, Arc<CacheEntry>>,
    /// Sends `(buffer, position)` refresh requests to the debounce task.
    update_tx: mpsc::UnboundedSender<(Entity<Buffer>, Anchor)>,
    /// Number of lines above and below the cursor row that are scanned for identifiers.
    identifier_line_count: u32,
}
 39
/// Events emitted by [`RelatedExcerptStore`] while it processes a refresh.
pub enum RelatedExcerptStoreEvent {
    /// Emitted when a debounced refresh begins, before identifiers are collected.
    StartedRefresh,
    /// Emitted after the store's cache and related files have been replaced.
    FinishedRefresh {
        /// Number of identifiers whose definitions were served from the cache.
        cache_hit_count: usize,
        /// Number of identifiers whose definitions were requested from the
        /// project and resolved successfully.
        cache_miss_count: usize,
        /// Mean time, over cache misses, between issuing the definition requests
        /// and each request resolving. Zero when there were no cache misses.
        mean_definition_latency: Duration,
        /// Largest such time among the cache misses. Zero when there were none.
        max_definition_latency: Duration,
    },
}
 49
/// An identifier occurrence found near the cursor. Used as the key of the
/// definition cache, so equality and hashing cover both the text and the range.
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
struct Identifier {
    /// The identifier's text as it appears in the buffer.
    pub name: String,
    /// The anchored range of the identifier within its buffer.
    pub range: Range<Anchor>,
}
 55
/// How the definitions for a single identifier are obtained during a refresh.
enum DefinitionTask {
    /// The definitions were already present in the store's cache.
    CacheHit(Arc<CacheEntry>),
    /// The definitions are being requested from the project; the task resolves
    /// to the reported locations, if any.
    CacheMiss(Task<Result<Option<Vec<LocationLink>>>>),
}
 60
/// The cached definitions of one identifier.
#[derive(Debug)]
struct CacheEntry {
    /// Definitions that passed `process_definition`; stored inline when there
    /// is at most one.
    definitions: SmallVec<[CachedDefinition; 1]>,
}
 65
/// A single definition target that was accepted by `process_definition`.
#[derive(Clone, Debug)]
struct CachedDefinition {
    /// Project path (worktree id and relative path) of the file containing the definition.
    path: ProjectPath,
    /// Buffer containing the definition.
    buffer: Entity<Buffer>,
    /// Anchored target range of the definition within `buffer`.
    anchor_range: Range<Anchor>,
}
 72
/// Quiet period that must pass without a new refresh request before the most
/// recent request is processed.
const DEBOUNCE_DURATION: Duration = Duration::from_millis(100);
 74
// Allows `RelatedExcerptStore` to emit `RelatedExcerptStoreEvent`s to its subscribers.
impl EventEmitter<RelatedExcerptStoreEvent> for RelatedExcerptStore {}
 76
impl RelatedExcerptStore {
    /// Creates a store for `project` and spawns the task that services refresh
    /// requests.
    ///
    /// The task debounces requests: after receiving one it keeps replacing it
    /// with any newer request, restarting a `DEBOUNCE_DURATION` timer each time,
    /// and only calls `fetch_excerpts` with the latest request once the timer
    /// fires. The task ends when the request channel closes or when
    /// `fetch_excerpts` returns an error; errors are logged.
    pub fn new(project: &Entity<Project>, cx: &mut Context<Self>) -> Self {
        let (update_tx, mut update_rx) = mpsc::unbounded::<(Entity<Buffer>, Anchor)>();
        cx.spawn(async move |this, cx| {
            let executor = cx.background_executor().clone();
            while let Some((mut buffer, mut position)) = update_rx.next().await {
                let mut timer = executor.timer(DEBOUNCE_DURATION).fuse();
                loop {
                    // Biased selection: a pending request is always observed
                    // before the timer, so the newest request wins.
                    futures::select_biased! {
                        next = update_rx.next() => {
                            if let Some((new_buffer, new_position)) = next {
                                buffer = new_buffer;
                                position = new_position;
                                // Restart the quiet period for the newer request.
                                timer = executor.timer(DEBOUNCE_DURATION).fuse();
                            } else {
                                // The sender was dropped; nothing left to service.
                                return anyhow::Ok(());
                            }
                        }
                        _ = timer => break,
                    }
                }

                Self::fetch_excerpts(this.clone(), buffer, position, cx).await?;
            }
            anyhow::Ok(())
        })
        .detach_and_log_err(cx);

        RelatedExcerptStore {
            project: project.downgrade(),
            update_tx,
            related_files: Vec::new().into(),
            related_file_buffers: Vec::new(),
            cache: Default::default(),
            identifier_line_count: IDENTIFIER_LINE_COUNT,
        }
    }

    /// Sets how many lines above and below the cursor row are scanned for
    /// identifiers. The value is read at the start of each refresh.
    pub fn set_identifier_line_count(&mut self, count: u32) {
        self.identifier_line_count = count;
    }

    /// Requests a (debounced) refresh of the related files for `position` in
    /// `buffer`. A failure to enqueue the request is ignored.
    pub fn refresh(&mut self, buffer: Entity<Buffer>, position: Anchor, _: &mut Context<Self>) {
        self.update_tx.unbounded_send((buffer, position)).ok();
    }

    /// Returns the current related files. Cloning the `Arc` does not copy the files.
    pub fn related_files(&self) -> Arc<[RelatedFile]> {
        self.related_files.clone()
    }

    /// Returns the current related files paired positionally with the stored
    /// buffers. Iteration stops at the shorter of the two lists.
    ///
    /// NOTE(review): `set_related_files` replaces only the files, so after it is
    /// called the pairing here may not correspond — confirm callers do not mix
    /// the two.
    pub fn related_files_with_buffers(
        &self,
    ) -> impl Iterator<Item = (RelatedFile, Entity<Buffer>)> {
        self.related_files
            .iter()
            .cloned()
            .zip(self.related_file_buffers.iter().cloned())
    }

    /// Replaces the related files with `files`. The stored buffers and the
    /// definition cache are left untouched.
    pub fn set_related_files(&mut self, files: Vec<RelatedFile>) {
        self.related_files = files.into();
    }

    /// Recomputes the related files for `position` in `buffer`.
    ///
    /// Steps: emit `StartedRefresh`; collect identifiers near the position on a
    /// background thread; resolve each identifier's definitions from the cache
    /// or by asking the project; rebuild the related files from the resolved
    /// definitions; store the results and emit `FinishedRefresh`.
    ///
    /// Returns `Ok(())` without doing anything if the project has been dropped.
    /// Returns an error if the store itself can no longer be accessed.
    async fn fetch_excerpts(
        this: WeakEntity<Self>,
        buffer: Entity<Buffer>,
        position: Anchor,
        cx: &mut AsyncApp,
    ) -> Result<()> {
        let (project, snapshot, identifier_line_count) = this.read_with(cx, |this, cx| {
            (
                this.project.upgrade(),
                buffer.read(cx).snapshot(),
                this.identifier_line_count,
            )
        })?;
        let Some(project) = project else {
            return Ok(());
        };

        // Kept only for the debug log lines below.
        let file = snapshot.file().cloned();
        if let Some(file) = &file {
            log::debug!("retrieving_context buffer:{}", file.path().as_unix_str());
        }

        this.update(cx, |_, cx| {
            cx.emit(RelatedExcerptStoreEvent::StartedRefresh);
        })?;

        let identifiers = cx
            .background_spawn(async move {
                identifiers_for_position(&snapshot, position, identifier_line_count)
            })
            .await;

        let async_cx = cx.clone();
        // Latencies of cache misses are measured from this instant.
        let start_time = Instant::now();
        let futures = this.update(cx, |this, cx| {
            identifiers
                .into_iter()
                .filter_map(|identifier| {
                    let task = if let Some(entry) = this.cache.get(&identifier) {
                        DefinitionTask::CacheHit(entry.clone())
                    } else {
                        DefinitionTask::CacheMiss(
                            this.project
                                .update(cx, |project, cx| {
                                    project.definitions(&buffer, identifier.range.start, cx)
                                })
                                // Skip the identifier if the project is gone.
                                .ok()?,
                        )
                    };

                    let cx = async_cx.clone();
                    let project = project.clone();
                    Some(async move {
                        match task {
                            DefinitionTask::CacheHit(cache_entry) => {
                                // `None` duration marks this result as a cache hit.
                                Some((identifier, cache_entry, None))
                            }
                            DefinitionTask::CacheMiss(task) => {
                                // A failed request is logged and, like a request
                                // with no result, yields `None` for this identifier.
                                let locations = task.await.log_err()??;
                                let duration = start_time.elapsed();
                                Some(cx.update(|cx| {
                                    (
                                        identifier,
                                        Arc::new(CacheEntry {
                                            definitions: locations
                                                .into_iter()
                                                .filter_map(|location| {
                                                    process_definition(location, &project, cx)
                                                })
                                                .collect(),
                                        }),
                                        Some(duration),
                                    )
                                }))
                            }
                        }
                    })
                })
                .collect::<Vec<_>>()
        })?;

        let mut cache_hit_count = 0;
        let mut cache_miss_count = 0;
        // Accumulates the sum of miss latencies; divided into a mean after the loop.
        let mut mean_definition_latency = Duration::ZERO;
        let mut max_definition_latency = Duration::ZERO;
        let mut new_cache = HashMap::default();
        new_cache.reserve(futures.len());
        // `flatten` drops identifiers whose definition request produced nothing.
        for (identifier, entry, duration) in future::join_all(futures).await.into_iter().flatten() {
            new_cache.insert(identifier, entry);
            if let Some(duration) = duration {
                cache_miss_count += 1;
                mean_definition_latency += duration;
                max_definition_latency = max_definition_latency.max(duration);
            } else {
                cache_hit_count += 1;
            }
        }
        // `max(1)` avoids dividing by zero when every identifier was a cache hit.
        mean_definition_latency /= cache_miss_count.max(1) as u32;

        let (new_cache, related_files, related_file_buffers) =
            rebuild_related_files(&project, new_cache, cx).await?;

        if let Some(file) = &file {
            log::debug!(
                "finished retrieving context buffer:{}, latency:{:?}",
                file.path().as_unix_str(),
                start_time.elapsed()
            );
        }

        this.update(cx, |this, cx| {
            this.cache = new_cache;
            this.related_files = related_files.into();
            this.related_file_buffers = related_file_buffers;
            cx.emit(RelatedExcerptStoreEvent::FinishedRefresh {
                cache_hit_count,
                cache_miss_count,
                mean_definition_latency,
                max_definition_latency,
            });
        })?;

        anyhow::Ok(())
    }
}
265
266async fn rebuild_related_files(
267    project: &Entity<Project>,
268    new_entries: HashMap<Identifier, Arc<CacheEntry>>,
269    cx: &mut AsyncApp,
270) -> Result<(
271    HashMap<Identifier, Arc<CacheEntry>>,
272    Vec<RelatedFile>,
273    Vec<Entity<Buffer>>,
274)> {
275    let mut snapshots = HashMap::default();
276    let mut worktree_root_names = HashMap::default();
277    for entry in new_entries.values() {
278        for definition in &entry.definitions {
279            if let hash_map::Entry::Vacant(e) = snapshots.entry(definition.buffer.entity_id()) {
280                definition
281                    .buffer
282                    .read_with(cx, |buffer, _| buffer.parsing_idle())
283                    .await;
284                e.insert(
285                    definition
286                        .buffer
287                        .read_with(cx, |buffer, _| buffer.snapshot()),
288                );
289            }
290            let worktree_id = definition.path.worktree_id;
291            if let hash_map::Entry::Vacant(e) =
292                worktree_root_names.entry(definition.path.worktree_id)
293            {
294                project.read_with(cx, |project, cx| {
295                    if let Some(worktree) = project.worktree_for_id(worktree_id, cx) {
296                        e.insert(worktree.read(cx).root_name().as_unix_str().to_string());
297                    }
298                });
299            }
300        }
301    }
302
303    Ok(cx
304        .background_spawn(async move {
305            let mut files = Vec::new();
306            let mut ranges_by_buffer = HashMap::<_, Vec<Range<Point>>>::default();
307            let mut paths_by_buffer = HashMap::default();
308            for entry in new_entries.values() {
309                for definition in &entry.definitions {
310                    let Some(snapshot) = snapshots.get(&definition.buffer.entity_id()) else {
311                        continue;
312                    };
313                    paths_by_buffer.insert(definition.buffer.entity_id(), definition.path.clone());
314                    ranges_by_buffer
315                        .entry(definition.buffer.clone())
316                        .or_default()
317                        .push(definition.anchor_range.to_point(snapshot));
318                }
319            }
320
321            for (buffer, ranges) in ranges_by_buffer {
322                let Some(snapshot) = snapshots.get(&buffer.entity_id()) else {
323                    continue;
324                };
325                let Some(project_path) = paths_by_buffer.get(&buffer.entity_id()) else {
326                    continue;
327                };
328                let excerpts = assemble_excerpts(snapshot, ranges);
329                let Some(root_name) = worktree_root_names.get(&project_path.worktree_id) else {
330                    continue;
331                };
332
333                let path = Path::new(&format!(
334                    "{}/{}",
335                    root_name,
336                    project_path.path.as_unix_str()
337                ))
338                .into();
339
340                files.push((
341                    buffer,
342                    RelatedFile {
343                        path,
344                        excerpts,
345                        max_row: snapshot.max_point().row,
346                    },
347                ));
348            }
349
350            files.sort_by_key(|(_, file)| file.path.clone());
351            let (related_buffers, related_files) = files.into_iter().unzip();
352
353            (new_entries, related_files, related_buffers)
354        })
355        .await)
356}
357
/// Largest definition target range, measured in buffer offsets, that
/// `process_definition` accepts. Targets longer than this are discarded.
const MAX_TARGET_LEN: usize = 128;
359
360fn process_definition(
361    location: LocationLink,
362    project: &Entity<Project>,
363    cx: &mut App,
364) -> Option<CachedDefinition> {
365    let buffer = location.target.buffer.read(cx);
366    let anchor_range = location.target.range;
367    let file = buffer.file()?;
368    let worktree = project.read(cx).worktree_for_id(file.worktree_id(cx), cx)?;
369    if worktree.read(cx).is_single_file() {
370        return None;
371    }
372
373    // If the target range is large, it likely means we requested the definition of an entire module.
374    // For individual definitions, the target range should be small as it only covers the symbol.
375    let buffer = location.target.buffer.read(cx);
376    let target_len = anchor_range.to_offset(&buffer).len();
377    if target_len > MAX_TARGET_LEN {
378        return None;
379    }
380
381    Some(CachedDefinition {
382        path: ProjectPath {
383            worktree_id: file.worktree_id(cx),
384            path: file.path().clone(),
385        },
386        buffer: location.target.buffer,
387        anchor_range,
388    })
389}
390
391/// Gets all of the identifiers that are present in the given line, and its containing
392/// outline items.
393fn identifiers_for_position(
394    buffer: &BufferSnapshot,
395    position: Anchor,
396    identifier_line_count: u32,
397) -> Vec<Identifier> {
398    let offset = position.to_offset(buffer);
399    let point = buffer.offset_to_point(offset);
400
401    // Search for identifiers on lines adjacent to the cursor.
402    let start = Point::new(point.row.saturating_sub(identifier_line_count), 0);
403    let end = Point::new(point.row + identifier_line_count + 1, 0).min(buffer.max_point());
404    let line_range = start..end;
405    let mut ranges = vec![line_range.to_offset(&buffer)];
406
407    // Search for identifiers mentioned in headers/signatures of containing outline items.
408    let outline_items = buffer.outline_items_as_offsets_containing(offset..offset, false, None);
409    for item in outline_items {
410        if let Some(body_range) = item.body_range(&buffer) {
411            ranges.push(item.range.start..body_range.start.to_offset(&buffer));
412        } else {
413            ranges.push(item.range.clone());
414        }
415    }
416
417    ranges.sort_by(|a, b| a.start.cmp(&b.start).then(b.end.cmp(&a.end)));
418    ranges.dedup_by(|a, b| {
419        if a.start <= b.end {
420            b.start = b.start.min(a.start);
421            b.end = b.end.max(a.end);
422            true
423        } else {
424            false
425        }
426    });
427
428    let mut identifiers = Vec::new();
429    let outer_range =
430        ranges.first().map_or(0, |r| r.start)..ranges.last().map_or(buffer.len(), |r| r.end);
431
432    let mut captures = buffer
433        .syntax
434        .captures(outer_range.clone(), &buffer.text, |grammar| {
435            grammar
436                .highlights_config
437                .as_ref()
438                .map(|config| &config.query)
439        });
440
441    for range in ranges {
442        captures.set_byte_range(range.start..outer_range.end);
443
444        let mut last_range = None;
445        while let Some(capture) = captures.peek() {
446            let node_range = capture.node.byte_range();
447            if node_range.start > range.end {
448                break;
449            }
450            let config = captures.grammars()[capture.grammar_index]
451                .highlights_config
452                .as_ref();
453
454            if let Some(config) = config
455                && config.identifier_capture_indices.contains(&capture.index)
456                && range.contains_inclusive(&node_range)
457                && Some(&node_range) != last_range.as_ref()
458            {
459                let name = buffer.text_for_range(node_range.clone()).collect();
460                identifiers.push(Identifier {
461                    range: buffer.anchor_after(node_range.start)
462                        ..buffer.anchor_before(node_range.end),
463                    name,
464                });
465                last_range = Some(node_range);
466            }
467
468            captures.advance();
469        }
470    }
471
472    identifiers
473}