edit_prediction_context.rs

use crate::assemble_excerpts::assemble_excerpts;
use anyhow::Result;
use collections::HashMap;
use futures::{FutureExt, StreamExt as _, channel::mpsc, future};
use gpui::{App, AppContext, AsyncApp, Context, Entity, EventEmitter, Task, WeakEntity};
use language::{Anchor, Buffer, BufferSnapshot, OffsetRangeExt as _, Point, ToOffset as _};
use project::{LocationLink, Project, ProjectPath};
use smallvec::SmallVec;
use std::{
    collections::hash_map,
    ops::Range,
    path::Path,
    sync::Arc,
    time::{Duration, Instant},
};
use util::{RangeExt as _, ResultExt};

mod assemble_excerpts;
#[cfg(test)]
mod edit_prediction_context_tests;
mod excerpt;
#[cfg(test)]
mod fake_definition_lsp;

pub use cloud_llm_client::predict_edits_v3::Line;
pub use excerpt::{EditPredictionExcerpt, EditPredictionExcerptOptions, EditPredictionExcerptText};
pub use zeta_prompt::{RelatedExcerpt, RelatedFile};

const IDENTIFIER_LINE_COUNT: u32 = 3;

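/// Tracks definitions for identifiers near the cursor, caching them per identifier, and
/// exposes the assembled excerpts as `RelatedFile`s for edit prediction context.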
pub struct RelatedExcerptStore {
    project: WeakEntity<Project>,
    related_files: Arc<[RelatedFile]>,
    related_file_buffers: Vec<Entity<Buffer>>,
    cache: HashMap<Identifier, Arc<CacheEntry>>,
    update_tx: mpsc::UnboundedSender<(Entity<Buffer>, Anchor)>,
    identifier_line_count: u32,
}

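/// Events emitted while the store refreshes its cache of related excerpts.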
pub enum RelatedExcerptStoreEvent {
    StartedRefresh,
    FinishedRefresh {
        cache_hit_count: usize,
        cache_miss_count: usize,
        mean_definition_latency: Duration,
        max_definition_latency: Duration,
    },
}

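/// An identifier occurrence in the source buffer, keyed by its name and anchor range.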
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
struct Identifier {
    pub name: String,
    pub range: Range<Anchor>,
}

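/// Either a cached set of definitions or an in-flight go-to-definition request.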
enum DefinitionTask {
    CacheHit(Arc<CacheEntry>),
    CacheMiss(Task<Result<Option<Vec<LocationLink>>>>),
}

#[derive(Debug)]
struct CacheEntry {
    definitions: SmallVec<[CachedDefinition; 1]>,
}

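/// A resolved definition location, retained so its excerpt can be reassembled later
/// without repeating the language server request.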
#[derive(Clone, Debug)]
struct CachedDefinition {
    path: ProjectPath,
    buffer: Entity<Buffer>,
    anchor_range: Range<Anchor>,
}

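/// How long to wait after the most recent refresh request before fetching excerpts.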
const DEBOUNCE_DURATION: Duration = Duration::from_millis(100);

impl EventEmitter<RelatedExcerptStoreEvent> for RelatedExcerptStore {}

impl RelatedExcerptStore {
    pub fn new(project: &Entity<Project>, cx: &mut Context<Self>) -> Self {
        let (update_tx, mut update_rx) = mpsc::unbounded::<(Entity<Buffer>, Anchor)>();
        cx.spawn(async move |this, cx| {
            let executor = cx.background_executor().clone();
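            // Debounce refresh requests: restart the timer whenever a newer
            // (buffer, position) arrives, and only fetch excerpts once the stream
            // has been quiet for DEBOUNCE_DURATION.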
            while let Some((mut buffer, mut position)) = update_rx.next().await {
                let mut timer = executor.timer(DEBOUNCE_DURATION).fuse();
                loop {
                    futures::select_biased! {
                        next = update_rx.next() => {
                            if let Some((new_buffer, new_position)) = next {
                                buffer = new_buffer;
                                position = new_position;
                                timer = executor.timer(DEBOUNCE_DURATION).fuse();
                            } else {
                                return anyhow::Ok(());
                            }
                        }
                        _ = timer => break,
                    }
                }

                Self::fetch_excerpts(this.clone(), buffer, position, cx).await?;
            }
            anyhow::Ok(())
        })
        .detach_and_log_err(cx);

        RelatedExcerptStore {
            project: project.downgrade(),
            update_tx,
            related_files: Vec::new().into(),
            related_file_buffers: Vec::new(),
            cache: Default::default(),
            identifier_line_count: IDENTIFIER_LINE_COUNT,
        }
    }

    pub fn set_identifier_line_count(&mut self, count: u32) {
        self.identifier_line_count = count;
    }

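    /// Schedules a debounced refresh of related excerpts for the given cursor position.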
    pub fn refresh(&mut self, buffer: Entity<Buffer>, position: Anchor, _: &mut Context<Self>) {
        self.update_tx.unbounded_send((buffer, position)).ok();
    }

    pub fn related_files(&self) -> Arc<[RelatedFile]> {
        self.related_files.clone()
    }

    pub fn related_files_with_buffers(
        &self,
    ) -> impl Iterator<Item = (RelatedFile, Entity<Buffer>)> {
        self.related_files
            .iter()
            .cloned()
            .zip(self.related_file_buffers.iter().cloned())
    }

    pub fn set_related_files(&mut self, files: Vec<RelatedFile>) {
        self.related_files = files.into();
    }

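    /// Collects identifiers near the cursor, resolves their definitions (reusing the
    /// cache where possible), and rebuilds the set of related files.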
    async fn fetch_excerpts(
        this: WeakEntity<Self>,
        buffer: Entity<Buffer>,
        position: Anchor,
        cx: &mut AsyncApp,
    ) -> Result<()> {
        let (project, snapshot, identifier_line_count) = this.read_with(cx, |this, cx| {
            (
                this.project.upgrade(),
                buffer.read(cx).snapshot(),
                this.identifier_line_count,
            )
        })?;
        let Some(project) = project else {
            return Ok(());
        };

        let file = snapshot.file().cloned();
        if let Some(file) = &file {
            log::debug!("retrieving_context buffer:{}", file.path().as_unix_str());
        }

        this.update(cx, |_, cx| {
            cx.emit(RelatedExcerptStoreEvent::StartedRefresh);
        })?;

        let identifiers = cx
            .background_spawn(async move {
                identifiers_for_position(&snapshot, position, identifier_line_count)
            })
            .await;

        let async_cx = cx.clone();
        let start_time = Instant::now();
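        // For each identifier, reuse cached definitions when available; otherwise ask
        // the project for definitions and cache the processed locations.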
        let futures = this.update(cx, |this, cx| {
            identifiers
                .into_iter()
                .filter_map(|identifier| {
                    let task = if let Some(entry) = this.cache.get(&identifier) {
                        DefinitionTask::CacheHit(entry.clone())
                    } else {
                        DefinitionTask::CacheMiss(
                            this.project
                                .update(cx, |project, cx| {
                                    project.definitions(&buffer, identifier.range.start, cx)
                                })
                                .ok()?,
                        )
                    };

                    let cx = async_cx.clone();
                    let project = project.clone();
                    Some(async move {
                        match task {
                            DefinitionTask::CacheHit(cache_entry) => {
                                Some((identifier, cache_entry, None))
                            }
                            DefinitionTask::CacheMiss(task) => {
                                let locations = task.await.log_err()??;
                                let duration = start_time.elapsed();
                                cx.update(|cx| {
                                    (
                                        identifier,
                                        Arc::new(CacheEntry {
                                            definitions: locations
                                                .into_iter()
                                                .filter_map(|location| {
                                                    process_definition(location, &project, cx)
                                                })
                                                .collect(),
                                        }),
                                        Some(duration),
                                    )
                                })
                                .ok()
                            }
                        }
                    })
                })
                .collect::<Vec<_>>()
        })?;

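        // Await all definition lookups, recording cache hit/miss statistics while
        // building the new cache.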
        let mut cache_hit_count = 0;
        let mut cache_miss_count = 0;
        let mut mean_definition_latency = Duration::ZERO;
        let mut max_definition_latency = Duration::ZERO;
        let mut new_cache = HashMap::default();
        new_cache.reserve(futures.len());
        for (identifier, entry, duration) in future::join_all(futures).await.into_iter().flatten() {
            new_cache.insert(identifier, entry);
            if let Some(duration) = duration {
                cache_miss_count += 1;
                mean_definition_latency += duration;
                max_definition_latency = max_definition_latency.max(duration);
            } else {
                cache_hit_count += 1;
            }
        }
        mean_definition_latency /= cache_miss_count.max(1) as u32;

        let (new_cache, related_files, related_file_buffers) =
            rebuild_related_files(&project, new_cache, cx).await?;

        if let Some(file) = &file {
            log::debug!(
                "finished retrieving context buffer:{}, latency:{:?}",
                file.path().as_unix_str(),
                start_time.elapsed()
            );
        }

        this.update(cx, |this, cx| {
            this.cache = new_cache;
            this.related_files = related_files.into();
            this.related_file_buffers = related_file_buffers;
            cx.emit(RelatedExcerptStoreEvent::FinishedRefresh {
                cache_hit_count,
                cache_miss_count,
                mean_definition_latency,
                max_definition_latency,
            });
        })?;

        anyhow::Ok(())
    }
}

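/// Waits for the definition buffers to finish parsing, then assembles excerpts for every
/// cached definition, grouping them into one `RelatedFile` per buffer.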
async fn rebuild_related_files(
    project: &Entity<Project>,
    new_entries: HashMap<Identifier, Arc<CacheEntry>>,
    cx: &mut AsyncApp,
) -> Result<(
    HashMap<Identifier, Arc<CacheEntry>>,
    Vec<RelatedFile>,
    Vec<Entity<Buffer>>,
)> {
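    // Collect a snapshot for each definition buffer (waiting for parsing to settle) and
    // the root name of each referenced worktree before handing off to the background
    // executor.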
    let mut snapshots = HashMap::default();
    let mut worktree_root_names = HashMap::default();
    for entry in new_entries.values() {
        for definition in &entry.definitions {
            if let hash_map::Entry::Vacant(e) = snapshots.entry(definition.buffer.entity_id()) {
                definition
                    .buffer
                    .read_with(cx, |buffer, _| buffer.parsing_idle())?
                    .await;
                e.insert(
                    definition
                        .buffer
                        .read_with(cx, |buffer, _| buffer.snapshot())?,
                );
            }
            let worktree_id = definition.path.worktree_id;
            if let hash_map::Entry::Vacant(e) =
                worktree_root_names.entry(definition.path.worktree_id)
            {
                project.read_with(cx, |project, cx| {
                    if let Some(worktree) = project.worktree_for_id(worktree_id, cx) {
                        e.insert(worktree.read(cx).root_name().as_unix_str().to_string());
                    }
                })?;
            }
        }
    }

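    // Assemble the excerpts on the background executor, grouping definition ranges by
    // buffer and producing one RelatedFile per buffer, sorted by path.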
    Ok(cx
        .background_spawn(async move {
            let mut files = Vec::new();
            let mut ranges_by_buffer = HashMap::<_, Vec<Range<Point>>>::default();
            let mut paths_by_buffer = HashMap::default();
            for entry in new_entries.values() {
                for definition in &entry.definitions {
                    let Some(snapshot) = snapshots.get(&definition.buffer.entity_id()) else {
                        continue;
                    };
                    paths_by_buffer.insert(definition.buffer.entity_id(), definition.path.clone());
                    ranges_by_buffer
                        .entry(definition.buffer.clone())
                        .or_default()
                        .push(definition.anchor_range.to_point(snapshot));
                }
            }

            for (buffer, ranges) in ranges_by_buffer {
                let Some(snapshot) = snapshots.get(&buffer.entity_id()) else {
                    continue;
                };
                let Some(project_path) = paths_by_buffer.get(&buffer.entity_id()) else {
                    continue;
                };
                let excerpts = assemble_excerpts(snapshot, ranges);
                let Some(root_name) = worktree_root_names.get(&project_path.worktree_id) else {
                    continue;
                };

                let path = Path::new(&format!(
                    "{}/{}",
                    root_name,
                    project_path.path.as_unix_str()
                ))
                .into();

                files.push((
                    buffer,
                    RelatedFile {
                        path,
                        excerpts,
                        max_row: snapshot.max_point().row,
                    },
                ));
            }

            files.sort_by_key(|(_, file)| file.path.clone());
            let (related_buffers, related_files) = files.into_iter().unzip();

            (new_entries, related_files, related_buffers)
        })
        .await)
}

const MAX_TARGET_LEN: usize = 128;

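/// Converts a go-to-definition result into a `CachedDefinition`, skipping targets in
/// single-file worktrees and targets whose range is too large to be a single symbol.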
fn process_definition(
    location: LocationLink,
    project: &Entity<Project>,
    cx: &mut App,
) -> Option<CachedDefinition> {
    let buffer = location.target.buffer.read(cx);
    let anchor_range = location.target.range;
    let file = buffer.file()?;
    let worktree = project.read(cx).worktree_for_id(file.worktree_id(cx), cx)?;
    if worktree.read(cx).is_single_file() {
        return None;
    }

    // If the target range is large, it likely means we requested the definition of an entire module.
    // For individual definitions, the target range should be small as it only covers the symbol.
    let buffer = location.target.buffer.read(cx);
    let target_len = anchor_range.to_offset(&buffer).len();
    if target_len > MAX_TARGET_LEN {
        return None;
    }

    Some(CachedDefinition {
        path: ProjectPath {
            worktree_id: file.worktree_id(cx),
            path: file.path().clone(),
        },
        buffer: location.target.buffer,
        anchor_range,
    })
}

/// Gets all of the identifiers that are present on the lines around the given position,
/// as well as in the headers of its containing outline items.
fn identifiers_for_position(
    buffer: &BufferSnapshot,
    position: Anchor,
    identifier_line_count: u32,
) -> Vec<Identifier> {
    let offset = position.to_offset(buffer);
    let point = buffer.offset_to_point(offset);

    // Search for identifiers on lines adjacent to the cursor.
    let start = Point::new(point.row.saturating_sub(identifier_line_count), 0);
    let end = Point::new(point.row + identifier_line_count + 1, 0).min(buffer.max_point());
    let line_range = start..end;
    let mut ranges = vec![line_range.to_offset(&buffer)];

    // Search for identifiers mentioned in headers/signatures of containing outline items.
    let outline_items = buffer.outline_items_as_offsets_containing(offset..offset, false, None);
    for item in outline_items {
        if let Some(body_range) = item.body_range(&buffer) {
            ranges.push(item.range.start..body_range.start.to_offset(&buffer));
        } else {
            ranges.push(item.range.clone());
        }
    }

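    // Merge overlapping ranges so each region of the buffer is only scanned once.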
    ranges.sort_by(|a, b| a.start.cmp(&b.start).then(b.end.cmp(&a.end)));
    ranges.dedup_by(|a, b| {
        if a.start <= b.end {
            b.start = b.start.min(a.start);
            b.end = b.end.max(a.end);
            true
        } else {
            false
        }
    });

    let mut identifiers = Vec::new();
    let outer_range =
        ranges.first().map_or(0, |r| r.start)..ranges.last().map_or(buffer.len(), |r| r.end);

    let mut captures = buffer
        .syntax
        .captures(outer_range.clone(), &buffer.text, |grammar| {
            grammar
                .highlights_config
                .as_ref()
                .map(|config| &config.query)
        });

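    // Walk the highlight captures inside each range, collecting nodes marked as
    // identifiers and skipping consecutive duplicates of the same node range.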
    for range in ranges {
        captures.set_byte_range(range.start..outer_range.end);

        let mut last_range = None;
        while let Some(capture) = captures.peek() {
            let node_range = capture.node.byte_range();
            if node_range.start > range.end {
                break;
            }
            let config = captures.grammars()[capture.grammar_index]
                .highlights_config
                .as_ref();

            if let Some(config) = config
                && config.identifier_capture_indices.contains(&capture.index)
                && range.contains_inclusive(&node_range)
                && Some(&node_range) != last_range.as_ref()
            {
                let name = buffer.text_for_range(node_range.clone()).collect();
                identifiers.push(Identifier {
                    range: buffer.anchor_after(node_range.start)
                        ..buffer.anchor_before(node_range.end),
                    name,
                });
                last_range = Some(node_range);
            }

            captures.advance();
        }
    }

    identifiers
}