edit_prediction_context.rs

  1use crate::assemble_excerpts::assemble_excerpt_ranges;
  2use anyhow::Result;
  3use collections::HashMap;
  4use futures::{FutureExt, StreamExt as _, channel::mpsc, future};
  5use gpui::{App, AppContext, AsyncApp, Context, Entity, EntityId, EventEmitter, Task, WeakEntity};
  6use language::{Anchor, Buffer, BufferSnapshot, OffsetRangeExt as _, Point, ToOffset as _};
  7use project::{LocationLink, Project, ProjectPath};
  8use smallvec::SmallVec;
  9use std::{
 10    collections::hash_map,
 11    ops::Range,
 12    path::Path,
 13    sync::Arc,
 14    time::{Duration, Instant},
 15};
 16use util::paths::PathStyle;
 17use util::rel_path::RelPath;
 18use util::{RangeExt as _, ResultExt};
 19
 20mod assemble_excerpts;
 21#[cfg(test)]
 22mod edit_prediction_context_tests;
 23#[cfg(test)]
 24mod fake_definition_lsp;
 25
 26pub use zeta_prompt::{RelatedExcerpt, RelatedFile};
 27
/// Default number of rows above and below the cursor row that are scanned for
/// identifiers. Used as the initial value of `RelatedExcerptStore::identifier_line_count`.
const IDENTIFIER_LINE_COUNT: u32 = 3;
 29
/// Tracks the buffers and excerpts related to the most recently refreshed cursor
/// position, together with a cache of definition lookups keyed by identifier.
pub struct RelatedExcerptStore {
    /// Weak handle to the project used for definition lookups and worktree resolution.
    project: WeakEntity<Project>,
    /// Related buffers produced by the last completed refresh or by `set_related_files`.
    related_buffers: Vec<RelatedBuffer>,
    /// Definitions resolved during the last completed refresh, keyed by identifier.
    cache: HashMap<Identifier, Arc<CacheEntry>>,
    /// Sends `(buffer, position)` refresh requests to the debounced task spawned in `new`.
    update_tx: mpsc::UnboundedSender<(Entity<Buffer>, Anchor)>,
    /// Number of rows on each side of the cursor row that are scanned for identifiers.
    identifier_line_count: u32,
}
 37
/// A buffer containing excerpts related to the current position.
struct RelatedBuffer {
    /// The buffer the excerpts are taken from.
    buffer: Entity<Buffer>,
    /// Path reported in `RelatedFile::path`; built as `<worktree root name>/<relative path>`
    /// by `rebuild_related_files`, or copied from the input file by `set_related_files`.
    path: Arc<Path>,
    /// Excerpt ranges, stored as anchors so they can be re-resolved after edits.
    anchor_ranges: Vec<Range<Anchor>>,
    /// Excerpts resolved for a specific buffer version, if they have been computed.
    cached_file: Option<CachedRelatedFile>,
}
 44
/// Excerpts resolved from a `RelatedBuffer`'s anchor ranges at a particular buffer version.
struct CachedRelatedFile {
    /// The resolved excerpts (row range and text).
    excerpts: Vec<RelatedExcerpt>,
    /// Buffer version the excerpts were computed from; compared against the
    /// buffer's current version to decide whether they can be reused.
    buffer_version: clock::Global,
}
 49
/// Events emitted by `RelatedExcerptStore` around each refresh.
pub enum RelatedExcerptStoreEvent {
    /// Emitted when a refresh begins, before identifiers are collected.
    StartedRefresh,
    /// Emitted after the store's cache and related buffers have been replaced.
    FinishedRefresh {
        /// Number of identifiers whose definitions were found in the previous cache.
        cache_hit_count: usize,
        /// Number of identifiers whose definitions were requested from the project
        /// and resolved successfully.
        cache_miss_count: usize,
        /// Mean time from the start of the lookup batch until each cache-miss
        /// lookup completed (zero when there were no misses).
        mean_definition_latency: Duration,
        /// Longest time from the start of the lookup batch until a cache-miss
        /// lookup completed.
        max_definition_latency: Duration,
    },
}
 59
/// An identifier occurrence in a buffer; used as the key of the definition cache.
/// Both the name and the anchor range take part in equality and hashing.
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
struct Identifier {
    /// The identifier's text.
    pub name: String,
    /// Anchor range covering the identifier's syntax node.
    pub range: Range<Anchor>,
}
 65
/// How the definitions for one identifier will be obtained during a refresh.
enum DefinitionTask {
    /// The identifier was already present in the store's cache.
    CacheHit(Arc<CacheEntry>),
    /// The identifier was not cached; holds the pending `Project::definitions` request.
    CacheMiss(Task<Result<Option<Vec<LocationLink>>>>),
}
 70
/// The definitions resolved for a single identifier.
#[derive(Debug)]
struct CacheEntry {
    /// Definitions that passed the filtering in `process_definition`.
    definitions: SmallVec<[CachedDefinition; 1]>,
}
 75
/// A single definition target retained in the cache.
#[derive(Clone, Debug)]
struct CachedDefinition {
    /// Project path (worktree id and relative path) of the file containing the definition.
    path: ProjectPath,
    /// The buffer containing the definition.
    buffer: Entity<Buffer>,
    /// Anchor range of the definition target within `buffer`.
    anchor_range: Range<Anchor>,
}
 82
/// How long the refresh loop waits without receiving a new request before it
/// runs a refresh for the most recent one.
const DEBOUNCE_DURATION: Duration = Duration::from_millis(100);
 84
// Allows the store to emit `RelatedExcerptStoreEvent`s (see the `cx.emit` calls in `fetch_excerpts`).
impl EventEmitter<RelatedExcerptStoreEvent> for RelatedExcerptStore {}
 86
 87impl RelatedExcerptStore {
    /// Creates a store for `project` and spawns the debounced refresh task.
    ///
    /// The spawned task receives the `(buffer, position)` requests sent by
    /// `refresh`, waits until no new request has arrived for `DEBOUNCE_DURATION`,
    /// and then runs `fetch_excerpts` for the most recent request only.
    pub fn new(project: &Entity<Project>, cx: &mut Context<Self>) -> Self {
        let (update_tx, mut update_rx) = mpsc::unbounded::<(Entity<Buffer>, Anchor)>();
        cx.spawn(async move |this, cx| {
            let executor = cx.background_executor().clone();
            while let Some((mut buffer, mut position)) = update_rx.next().await {
                let mut timer = executor.timer(DEBOUNCE_DURATION).fuse();
                loop {
                    // Biased so that a pending request is checked before the timer:
                    // each new request replaces the pending one and restarts the
                    // debounce timer.
                    futures::select_biased! {
                        next = update_rx.next() => {
                            if let Some((new_buffer, new_position)) = next {
                                buffer = new_buffer;
                                position = new_position;
                                timer = executor.timer(DEBOUNCE_DURATION).fuse();
                            } else {
                                // The request channel yielded `None`, so no further
                                // requests can arrive; end the task.
                                return anyhow::Ok(());
                            }
                        }
                        _ = timer => break,
                    }
                }

                // An error here ends the task; it is reported by `detach_and_log_err` below.
                Self::fetch_excerpts(this.clone(), buffer, position, cx).await?;
            }
            anyhow::Ok(())
        })
        .detach_and_log_err(cx);

        RelatedExcerptStore {
            project: project.downgrade(),
            update_tx,
            related_buffers: Vec::new(),
            cache: Default::default(),
            identifier_line_count: IDENTIFIER_LINE_COUNT,
        }
    }
123
    /// Sets how many rows on each side of the cursor row are scanned for
    /// identifiers. The value is read at the start of each subsequent refresh.
    pub fn set_identifier_line_count(&mut self, count: u32) {
        self.identifier_line_count = count;
    }
127
    /// Requests a debounced refresh of the related excerpts for `position` in `buffer`.
    ///
    /// The request is queued for the task spawned in `new`; a failure to send
    /// is deliberately ignored.
    pub fn refresh(&mut self, buffer: Entity<Buffer>, position: Anchor, _: &mut Context<Self>) {
        self.update_tx.unbounded_send((buffer, position)).ok();
    }
131
132    pub fn related_files(&mut self, cx: &App) -> Vec<RelatedFile> {
133        self.related_buffers
134            .iter_mut()
135            .map(|related| related.related_file(cx))
136            .collect()
137    }
138
139    pub fn related_files_with_buffers(&mut self, cx: &App) -> Vec<(RelatedFile, Entity<Buffer>)> {
140        self.related_buffers
141            .iter_mut()
142            .map(|related| (related.related_file(cx), related.buffer.clone()))
143            .collect::<Vec<_>>()
144    }
145
146    pub fn set_related_files(&mut self, files: Vec<RelatedFile>, cx: &App) {
147        self.related_buffers = files
148            .into_iter()
149            .filter_map(|file| {
150                let project = self.project.upgrade()?;
151                let project = project.read(cx);
152                let worktree = project.worktrees(cx).find(|wt| {
153                    let root_name = wt.read(cx).root_name().as_unix_str();
154                    file.path
155                        .components()
156                        .next()
157                        .is_some_and(|c| c.as_os_str() == root_name)
158                })?;
159                let worktree = worktree.read(cx);
160                let relative_path = file
161                    .path
162                    .strip_prefix(worktree.root_name().as_unix_str())
163                    .ok()?;
164                let relative_path = RelPath::new(relative_path, PathStyle::Posix).ok()?;
165                let project_path = ProjectPath {
166                    worktree_id: worktree.id(),
167                    path: relative_path.into_owned().into(),
168                };
169                let buffer = project.get_open_buffer(&project_path, cx)?;
170                let snapshot = buffer.read(cx).snapshot();
171                let anchor_ranges = file
172                    .excerpts
173                    .iter()
174                    .map(|excerpt| {
175                        let start = snapshot.anchor_before(Point::new(excerpt.row_range.start, 0));
176                        let end_row = excerpt.row_range.end;
177                        let end_col = snapshot.line_len(end_row);
178                        let end = snapshot.anchor_after(Point::new(end_row, end_col));
179                        start..end
180                    })
181                    .collect();
182                Some(RelatedBuffer {
183                    buffer,
184                    path: file.path.clone(),
185                    anchor_ranges,
186                    cached_file: None,
187                })
188            })
189            .collect();
190    }
191
    /// Recomputes the related excerpts for `position` in `buffer` and stores the
    /// result on `this`.
    ///
    /// Steps: collect the identifiers near the position, resolve each
    /// identifier's definitions (reusing entries from the current cache when
    /// present), rebuild the related buffers from those definitions, then
    /// replace the store's cache and related buffers. Emits `StartedRefresh`
    /// before the work begins and `FinishedRefresh` with cache and latency
    /// statistics once it completes.
    ///
    /// Returns `Ok(())` without emitting anything if the project has been
    /// dropped. Propagates errors from accessing `this`.
    async fn fetch_excerpts(
        this: WeakEntity<Self>,
        buffer: Entity<Buffer>,
        position: Anchor,
        cx: &mut AsyncApp,
    ) -> Result<()> {
        let (project, snapshot, identifier_line_count) = this.read_with(cx, |this, cx| {
            (
                this.project.upgrade(),
                buffer.read(cx).snapshot(),
                this.identifier_line_count,
            )
        })?;
        let Some(project) = project else {
            return Ok(());
        };

        let file = snapshot.file().cloned();
        if let Some(file) = &file {
            log::debug!("retrieving_context buffer:{}", file.path().as_unix_str());
        }

        this.update(cx, |_, cx| {
            cx.emit(RelatedExcerptStoreEvent::StartedRefresh);
        })?;

        // Identifier collection only needs the snapshot, so it runs on the background executor.
        let identifiers = cx
            .background_spawn(async move {
                identifiers_for_position(&snapshot, position, identifier_line_count)
            })
            .await;

        let async_cx = cx.clone();
        let start_time = Instant::now();
        // Build one future per identifier: cache hits resolve immediately, cache
        // misses wait on a definition request issued to the project.
        let futures = this.update(cx, |this, cx| {
            identifiers
                .into_iter()
                .filter_map(|identifier| {
                    let task = if let Some(entry) = this.cache.get(&identifier) {
                        DefinitionTask::CacheHit(entry.clone())
                    } else {
                        DefinitionTask::CacheMiss(
                            this.project
                                .update(cx, |project, cx| {
                                    project.definitions(&buffer, identifier.range.start, cx)
                                })
                                .ok()?,
                        )
                    };

                    let cx = async_cx.clone();
                    let project = project.clone();
                    Some(async move {
                        match task {
                            // `None` as the duration marks this result as a cache hit.
                            DefinitionTask::CacheHit(cache_entry) => {
                                Some((identifier, cache_entry, None))
                            }
                            DefinitionTask::CacheMiss(task) => {
                                // A failed request is logged and, like a request that
                                // returned no locations, drops this identifier.
                                let locations = task.await.log_err()??;
                                // Measured from the start of the whole batch, not
                                // from when this particular request was issued.
                                let duration = start_time.elapsed();
                                Some(cx.update(|cx| {
                                    (
                                        identifier,
                                        Arc::new(CacheEntry {
                                            definitions: locations
                                                .into_iter()
                                                .filter_map(|location| {
                                                    process_definition(location, &project, cx)
                                                })
                                                .collect(),
                                        }),
                                        Some(duration),
                                    )
                                }))
                            }
                        }
                    })
                })
                .collect::<Vec<_>>()
        })?;

        let mut cache_hit_count = 0;
        let mut cache_miss_count = 0;
        let mut mean_definition_latency = Duration::ZERO;
        let mut max_definition_latency = Duration::ZERO;
        // The new cache contains only identifiers seen in this refresh, so entries
        // for identifiers that are no longer nearby are discarded.
        let mut new_cache = HashMap::default();
        new_cache.reserve(futures.len());
        for (identifier, entry, duration) in future::join_all(futures).await.into_iter().flatten() {
            new_cache.insert(identifier, entry);
            if let Some(duration) = duration {
                cache_miss_count += 1;
                mean_definition_latency += duration;
                max_definition_latency = max_definition_latency.max(duration);
            } else {
                cache_hit_count += 1;
            }
        }
        // `max(1)` avoids dividing by zero when every lookup was a cache hit.
        mean_definition_latency /= cache_miss_count.max(1) as u32;

        let (new_cache, related_buffers) = rebuild_related_files(&project, new_cache, cx).await?;

        if let Some(file) = &file {
            log::debug!(
                "finished retrieving context buffer:{}, latency:{:?}",
                file.path().as_unix_str(),
                start_time.elapsed()
            );
        }

        this.update(cx, |this, cx| {
            this.cache = new_cache;
            this.related_buffers = related_buffers;
            cx.emit(RelatedExcerptStoreEvent::FinishedRefresh {
                cache_hit_count,
                cache_miss_count,
                mean_definition_latency,
                max_definition_latency,
            });
        })?;

        anyhow::Ok(())
    }
314}
315
316async fn rebuild_related_files(
317    project: &Entity<Project>,
318    mut new_entries: HashMap<Identifier, Arc<CacheEntry>>,
319    cx: &mut AsyncApp,
320) -> Result<(HashMap<Identifier, Arc<CacheEntry>>, Vec<RelatedBuffer>)> {
321    let mut snapshots = HashMap::default();
322    let mut worktree_root_names = HashMap::default();
323    for entry in new_entries.values() {
324        for definition in &entry.definitions {
325            if let hash_map::Entry::Vacant(e) = snapshots.entry(definition.buffer.entity_id()) {
326                definition
327                    .buffer
328                    .read_with(cx, |buffer, _| buffer.parsing_idle())
329                    .await;
330                e.insert(
331                    definition
332                        .buffer
333                        .read_with(cx, |buffer, _| buffer.snapshot()),
334                );
335            }
336            let worktree_id = definition.path.worktree_id;
337            if let hash_map::Entry::Vacant(e) =
338                worktree_root_names.entry(definition.path.worktree_id)
339            {
340                project.read_with(cx, |project, cx| {
341                    if let Some(worktree) = project.worktree_for_id(worktree_id, cx) {
342                        e.insert(worktree.read(cx).root_name().as_unix_str().to_string());
343                    }
344                });
345            }
346        }
347    }
348
349    Ok(cx
350        .background_spawn(async move {
351            let mut ranges_by_buffer =
352                HashMap::<EntityId, (Entity<Buffer>, Vec<Range<Point>>)>::default();
353            let mut paths_by_buffer = HashMap::default();
354            for entry in new_entries.values_mut() {
355                for definition in &entry.definitions {
356                    let Some(snapshot) = snapshots.get(&definition.buffer.entity_id()) else {
357                        continue;
358                    };
359                    paths_by_buffer.insert(definition.buffer.entity_id(), definition.path.clone());
360
361                    ranges_by_buffer
362                        .entry(definition.buffer.entity_id())
363                        .or_insert_with(|| (definition.buffer.clone(), Vec::new()))
364                        .1
365                        .push(definition.anchor_range.to_point(snapshot));
366                }
367            }
368
369            let mut related_buffers: Vec<RelatedBuffer> = ranges_by_buffer
370                .into_iter()
371                .filter_map(|(entity_id, (buffer, ranges))| {
372                    let snapshot = snapshots.get(&entity_id)?;
373                    let project_path = paths_by_buffer.get(&entity_id)?;
374                    let row_ranges = assemble_excerpt_ranges(snapshot, ranges);
375                    let root_name = worktree_root_names.get(&project_path.worktree_id)?;
376
377                    let path: Arc<Path> = Path::new(&format!(
378                        "{}/{}",
379                        root_name,
380                        project_path.path.as_unix_str()
381                    ))
382                    .into();
383
384                    let anchor_ranges = row_ranges
385                        .into_iter()
386                        .map(|row_range| {
387                            let start = snapshot.anchor_before(Point::new(row_range.start, 0));
388                            let end_col = snapshot.line_len(row_range.end);
389                            let end = snapshot.anchor_after(Point::new(row_range.end, end_col));
390                            start..end
391                        })
392                        .collect();
393
394                    let mut related_buffer = RelatedBuffer {
395                        buffer,
396                        path,
397                        anchor_ranges,
398                        cached_file: None,
399                    };
400                    related_buffer.fill_cache(snapshot);
401                    Some(related_buffer)
402                })
403                .collect();
404
405            related_buffers.sort_by_key(|related| related.path.clone());
406
407            (new_entries, related_buffers)
408        })
409        .await)
410}
411
412impl RelatedBuffer {
413    fn related_file(&mut self, cx: &App) -> RelatedFile {
414        let buffer = self.buffer.read(cx);
415        let path = self.path.clone();
416        let cached = if let Some(cached) = &self.cached_file
417            && buffer.version() == cached.buffer_version
418        {
419            cached
420        } else {
421            self.fill_cache(buffer)
422        };
423        let related_file = RelatedFile {
424            path,
425            excerpts: cached.excerpts.clone(),
426            max_row: buffer.max_point().row,
427        };
428        return related_file;
429    }
430
431    fn fill_cache(&mut self, buffer: &text::BufferSnapshot) -> &CachedRelatedFile {
432        let excerpts = self
433            .anchor_ranges
434            .iter()
435            .map(|range| {
436                let start = range.start.to_point(buffer);
437                let end = range.end.to_point(buffer);
438                RelatedExcerpt {
439                    row_range: start.row..end.row,
440                    text: buffer.text_for_range(start..end).collect::<String>().into(),
441                }
442            })
443            .collect::<Vec<_>>();
444        self.cached_file = Some(CachedRelatedFile {
445            excerpts: excerpts,
446            buffer_version: buffer.version().clone(),
447        });
448        self.cached_file.as_ref().unwrap()
449    }
450}
451
452use language::ToPoint as _;
453
/// Maximum length of a definition's target range, measured as the length of its
/// offset range. `process_definition` discards definitions with longer targets.
const MAX_TARGET_LEN: usize = 128;
455
456fn process_definition(
457    location: LocationLink,
458    project: &Entity<Project>,
459    cx: &mut App,
460) -> Option<CachedDefinition> {
461    let buffer = location.target.buffer.read(cx);
462    let anchor_range = location.target.range;
463    let file = buffer.file()?;
464    let worktree = project.read(cx).worktree_for_id(file.worktree_id(cx), cx)?;
465    if worktree.read(cx).is_single_file() {
466        return None;
467    }
468
469    // If the target range is large, it likely means we requested the definition of an entire module.
470    // For individual definitions, the target range should be small as it only covers the symbol.
471    let buffer = location.target.buffer.read(cx);
472    let target_len = anchor_range.to_offset(&buffer).len();
473    if target_len > MAX_TARGET_LEN {
474        return None;
475    }
476
477    Some(CachedDefinition {
478        path: ProjectPath {
479            worktree_id: file.worktree_id(cx),
480            path: file.path().clone(),
481        },
482        buffer: location.target.buffer,
483        anchor_range,
484    })
485}
486
487/// Gets all of the identifiers that are present in the given line, and its containing
488/// outline items.
489fn identifiers_for_position(
490    buffer: &BufferSnapshot,
491    position: Anchor,
492    identifier_line_count: u32,
493) -> Vec<Identifier> {
494    let offset = position.to_offset(buffer);
495    let point = buffer.offset_to_point(offset);
496
497    // Search for identifiers on lines adjacent to the cursor.
498    let start = Point::new(point.row.saturating_sub(identifier_line_count), 0);
499    let end = Point::new(point.row + identifier_line_count + 1, 0).min(buffer.max_point());
500    let line_range = start..end;
501    let mut ranges = vec![line_range.to_offset(&buffer)];
502
503    // Search for identifiers mentioned in headers/signatures of containing outline items.
504    let outline_items = buffer.outline_items_as_offsets_containing(offset..offset, false, None);
505    for item in outline_items {
506        if let Some(body_range) = item.body_range(&buffer) {
507            ranges.push(item.range.start..body_range.start.to_offset(&buffer));
508        } else {
509            ranges.push(item.range.clone());
510        }
511    }
512
513    ranges.sort_by(|a, b| a.start.cmp(&b.start).then(b.end.cmp(&a.end)));
514    ranges.dedup_by(|a, b| {
515        if a.start <= b.end {
516            b.start = b.start.min(a.start);
517            b.end = b.end.max(a.end);
518            true
519        } else {
520            false
521        }
522    });
523
524    let mut identifiers = Vec::new();
525    let outer_range =
526        ranges.first().map_or(0, |r| r.start)..ranges.last().map_or(buffer.len(), |r| r.end);
527
528    let mut captures = buffer
529        .syntax
530        .captures(outer_range.clone(), &buffer.text, |grammar| {
531            grammar
532                .highlights_config
533                .as_ref()
534                .map(|config| &config.query)
535        });
536
537    for range in ranges {
538        captures.set_byte_range(range.start..outer_range.end);
539
540        let mut last_range = None;
541        while let Some(capture) = captures.peek() {
542            let node_range = capture.node.byte_range();
543            if node_range.start > range.end {
544                break;
545            }
546            let config = captures.grammars()[capture.grammar_index]
547                .highlights_config
548                .as_ref();
549
550            if let Some(config) = config
551                && config.identifier_capture_indices.contains(&capture.index)
552                && range.contains_inclusive(&node_range)
553                && Some(&node_range) != last_range.as_ref()
554            {
555                let name = buffer.text_for_range(node_range.clone()).collect();
556                identifiers.push(Identifier {
557                    range: buffer.anchor_after(node_range.start)
558                        ..buffer.anchor_before(node_range.end),
559                    name,
560                });
561                last_range = Some(node_range);
562            }
563
564            captures.advance();
565        }
566    }
567
568    identifiers
569}