edit_prediction_context.rs

  1use crate::assemble_excerpts::assemble_excerpt_ranges;
  2use anyhow::Result;
  3use collections::HashMap;
  4use futures::{FutureExt, StreamExt as _, channel::mpsc, future};
  5use gpui::{App, AppContext, AsyncApp, Context, Entity, EntityId, EventEmitter, Task, WeakEntity};
  6use language::{Anchor, Buffer, BufferSnapshot, OffsetRangeExt as _, Point, ToOffset as _};
  7use project::{LocationLink, Project, ProjectPath};
  8use smallvec::SmallVec;
  9use std::{
 10    collections::hash_map,
 11    ops::Range,
 12    path::Path,
 13    sync::Arc,
 14    time::{Duration, Instant},
 15};
 16use util::paths::PathStyle;
 17use util::rel_path::RelPath;
 18use util::{RangeExt as _, ResultExt};
 19
 20mod assemble_excerpts;
 21#[cfg(test)]
 22mod edit_prediction_context_tests;
 23#[cfg(test)]
 24mod fake_definition_lsp;
 25
 26pub use zeta_prompt::{RelatedExcerpt, RelatedFile};
 27
/// Default number of rows on each side of the cursor row that are scanned for
/// identifiers. Used to initialize `RelatedExcerptStore::identifier_line_count`,
/// which can be changed through `RelatedExcerptStore::set_identifier_line_count`.
const IDENTIFIER_LINE_COUNT: u32 = 3;
 29
/// Keeps track of the buffers (and excerpts within them) that contain definitions
/// of identifiers found near a cursor position.
///
/// Refreshes are requested with `refresh`, debounced by a task spawned in `new`,
/// and their results are exposed through `related_files`.
pub struct RelatedExcerptStore {
    /// Project used to request definitions / type definitions and to look up worktrees.
    project: WeakEntity<Project>,
    /// Buffers produced by the most recent refresh, or by `set_related_files`.
    related_buffers: Vec<RelatedBuffer>,
    /// Definition lookups of the most recent refresh, keyed by identifier.
    /// The whole map is replaced at the end of every refresh.
    cache: HashMap<Identifier, Arc<CacheEntry>>,
    /// Sending half of the channel that feeds refresh requests to the debounce task.
    update_tx: mpsc::UnboundedSender<(Entity<Buffer>, Anchor)>,
    /// Number of rows on each side of the cursor row that are scanned for identifiers.
    identifier_line_count: u32,
}
 37
/// A buffer that holds related definitions, together with the excerpts selected from it.
struct RelatedBuffer {
    /// The buffer the excerpts are taken from.
    buffer: Entity<Buffer>,
    /// Path reported in `RelatedFile::path`. When built by `rebuild_related_files`
    /// this is the worktree root name joined with the worktree-relative path.
    path: Arc<Path>,
    /// One anchored range per excerpt; parallel to `excerpt_orders`.
    anchor_ranges: Vec<Range<Anchor>>,
    /// The `order` value of the excerpt at the same index in `anchor_ranges`.
    excerpt_orders: Vec<usize>,
    /// Materialized excerpts; reused by `related_file` while the buffer version
    /// still equals the recorded `buffer_version`.
    cached_file: Option<CachedRelatedFile>,
}
 45
/// Excerpts of a `RelatedBuffer` resolved against one specific buffer version.
struct CachedRelatedFile {
    /// Excerpt rows, text and order as computed by `RelatedBuffer::fill_cache`.
    excerpts: Vec<RelatedExcerpt>,
    /// Version of the buffer the excerpts were computed from.
    buffer_version: clock::Global,
}
 50
/// Events emitted by `RelatedExcerptStore` around each refresh.
pub enum RelatedExcerptStoreEvent {
    /// Emitted by `fetch_excerpts` before identifiers are collected.
    StartedRefresh,
    /// Emitted by `fetch_excerpts` after the new cache and related buffers were stored.
    FinishedRefresh {
        /// Number of identifiers whose definitions were found in the cache.
        cache_hit_count: usize,
        /// Number of identifiers for which definition requests had to be issued.
        cache_miss_count: usize,
        /// Mean latency over the cache misses; `Duration::ZERO` when there were none.
        mean_definition_latency: Duration,
        /// Largest latency observed among the cache misses.
        max_definition_latency: Duration,
    },
}
 60
/// An identifier occurrence in a buffer; used as the key of the definition cache.
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
struct Identifier {
    /// Text of the identifier.
    pub name: String,
    /// Anchored range the identifier occupies in the buffer.
    pub range: Range<Anchor>,
}
 66
/// How the definitions of one identifier are obtained during a refresh.
enum DefinitionTask {
    /// The identifier was already present in the cache; the entry is reused as-is.
    CacheHit(Arc<CacheEntry>),
    /// The identifier was not cached; both requests are pending project tasks.
    CacheMiss {
        /// Pending result of `Project::definitions` for the identifier.
        definitions: Task<Result<Option<Vec<LocationLink>>>>,
        /// Pending result of `Project::type_definitions` for the identifier.
        type_definitions: Task<Result<Option<Vec<LocationLink>>>>,
    },
}
 74
/// Cached lookup result for a single identifier.
#[derive(Debug)]
struct CacheEntry {
    /// Definition targets that passed `process_definition`.
    definitions: SmallVec<[CachedDefinition; 1]>,
    /// Type-definition targets that passed `process_definition` and do not
    /// duplicate an entry of `definitions` (same buffer and anchor range).
    type_definitions: SmallVec<[CachedDefinition; 1]>,
}
 80
/// Location of one definition target.
#[derive(Clone, Debug)]
struct CachedDefinition {
    /// Worktree id and worktree-relative path of the target buffer's file.
    path: ProjectPath,
    /// Buffer containing the target.
    buffer: Entity<Buffer>,
    /// Target range reported by the definition request.
    anchor_range: Range<Anchor>,
}
 87
/// Quiet period that must elapse after the latest `refresh` request before the
/// background task spawned in `RelatedExcerptStore::new` starts fetching excerpts.
const DEBOUNCE_DURATION: Duration = Duration::from_millis(100);
 89
// Lets the store emit `RelatedExcerptStoreEvent`s through `cx.emit` (see `fetch_excerpts`).
impl EventEmitter<RelatedExcerptStoreEvent> for RelatedExcerptStore {}
 91
 92impl RelatedExcerptStore {
 93    pub fn new(project: &Entity<Project>, cx: &mut Context<Self>) -> Self {
 94        let (update_tx, mut update_rx) = mpsc::unbounded::<(Entity<Buffer>, Anchor)>();
 95        cx.spawn(async move |this, cx| {
 96            let executor = cx.background_executor().clone();
 97            while let Some((mut buffer, mut position)) = update_rx.next().await {
 98                let mut timer = executor.timer(DEBOUNCE_DURATION).fuse();
 99                loop {
100                    futures::select_biased! {
101                        next = update_rx.next() => {
102                            if let Some((new_buffer, new_position)) = next {
103                                buffer = new_buffer;
104                                position = new_position;
105                                timer = executor.timer(DEBOUNCE_DURATION).fuse();
106                            } else {
107                                return anyhow::Ok(());
108                            }
109                        }
110                        _ = timer => break,
111                    }
112                }
113
114                Self::fetch_excerpts(this.clone(), buffer, position, cx).await?;
115            }
116            anyhow::Ok(())
117        })
118        .detach_and_log_err(cx);
119
120        RelatedExcerptStore {
121            project: project.downgrade(),
122            update_tx,
123            related_buffers: Vec::new(),
124            cache: Default::default(),
125            identifier_line_count: IDENTIFIER_LINE_COUNT,
126        }
127    }
128
    /// Sets how many rows on each side of the cursor row are scanned for
    /// identifiers by subsequent refreshes.
    pub fn set_identifier_line_count(&mut self, count: u32) {
        self.identifier_line_count = count;
    }
132
    /// Requests a (debounced) refresh of the related excerpts for `position` in `buffer`.
    ///
    /// The request is queued on the update channel; a send failure is ignored.
    pub fn refresh(&mut self, buffer: Entity<Buffer>, position: Anchor, _: &mut Context<Self>) {
        self.update_tx.unbounded_send((buffer, position)).ok();
    }
136
137    pub fn related_files(&mut self, cx: &App) -> Vec<RelatedFile> {
138        self.related_buffers
139            .iter_mut()
140            .map(|related| related.related_file(cx))
141            .collect()
142    }
143
    /// Like `related_files`, but lazily yields each `RelatedFile` paired with a
    /// handle to the buffer it was produced from.
    pub fn related_files_with_buffers(
        &mut self,
        cx: &App,
    ) -> impl Iterator<Item = (RelatedFile, Entity<Buffer>)> {
        self.related_buffers
            .iter_mut()
            .map(|related| (related.related_file(cx), related.buffer.clone()))
    }
152
153    pub fn set_related_files(&mut self, files: Vec<RelatedFile>, cx: &App) {
154        self.related_buffers = files
155            .into_iter()
156            .filter_map(|file| {
157                let project = self.project.upgrade()?;
158                let project = project.read(cx);
159                let worktree = project.worktrees(cx).find(|wt| {
160                    let root_name = wt.read(cx).root_name().as_unix_str();
161                    file.path
162                        .components()
163                        .next()
164                        .is_some_and(|c| c.as_os_str() == root_name)
165                })?;
166                let worktree = worktree.read(cx);
167                let relative_path = file
168                    .path
169                    .strip_prefix(worktree.root_name().as_unix_str())
170                    .ok()?;
171                let relative_path = RelPath::new(relative_path, PathStyle::Posix).ok()?;
172                let project_path = ProjectPath {
173                    worktree_id: worktree.id(),
174                    path: relative_path.into_owned().into(),
175                };
176                let buffer = project.get_open_buffer(&project_path, cx)?;
177                let snapshot = buffer.read(cx).snapshot();
178                let mut anchor_ranges = Vec::with_capacity(file.excerpts.len());
179                let mut excerpt_orders = Vec::with_capacity(file.excerpts.len());
180                for excerpt in &file.excerpts {
181                    let start = snapshot.anchor_before(Point::new(excerpt.row_range.start, 0));
182                    let end_row = excerpt.row_range.end;
183                    let end_col = snapshot.line_len(end_row);
184                    let end = snapshot.anchor_after(Point::new(end_row, end_col));
185                    anchor_ranges.push(start..end);
186                    excerpt_orders.push(excerpt.order);
187                }
188                Some(RelatedBuffer {
189                    buffer,
190                    path: file.path.clone(),
191                    anchor_ranges,
192                    excerpt_orders,
193                    cached_file: None,
194                })
195            })
196            .collect();
197    }
198
    /// Performs one refresh for `position` in `buffer`.
    ///
    /// Steps:
    /// 1. Snapshot the buffer and emit `StartedRefresh`.
    /// 2. On the background executor, collect identifiers near the position and
    ///    rank them by byte distance from the cursor.
    /// 3. For each identifier, reuse its cache entry or request definitions and
    ///    type definitions from the project.
    /// 4. Await all lookups, gather hit/miss counts and latency statistics.
    /// 5. Rebuild the related buffers, store them with the new cache, and emit
    ///    `FinishedRefresh`.
    ///
    /// Returns `Ok(())` without doing anything if the project has been dropped.
    /// Errors from accessing `this` are propagated (presumably the store was
    /// released — confirm against `WeakEntity` semantics).
    async fn fetch_excerpts(
        this: WeakEntity<Self>,
        buffer: Entity<Buffer>,
        position: Anchor,
        cx: &mut AsyncApp,
    ) -> Result<()> {
        let (project, snapshot, identifier_line_count) = this.read_with(cx, |this, cx| {
            (
                this.project.upgrade(),
                buffer.read(cx).snapshot(),
                this.identifier_line_count,
            )
        })?;
        let Some(project) = project else {
            return Ok(());
        };

        // Kept only for the debug log lines below.
        let file = snapshot.file().cloned();
        if let Some(file) = &file {
            log::debug!("retrieving_context buffer:{}", file.path().as_unix_str());
        }

        this.update(cx, |_, cx| {
            cx.emit(RelatedExcerptStoreEvent::StartedRefresh);
        })?;

        let identifiers_with_ranks = cx
            .background_spawn(async move {
                let cursor_offset = position.to_offset(&snapshot);
                let identifiers =
                    identifiers_for_position(&snapshot, position, identifier_line_count);

                // Compute byte distance from cursor to each identifier, then sort by
                // distance so we can assign ordinal ranks. Identifiers at the same
                // distance share the same rank.
                let mut identifiers_with_distance: Vec<(Identifier, usize)> = identifiers
                    .into_iter()
                    .map(|id| {
                        let start = id.range.start.to_offset(&snapshot);
                        let end = id.range.end.to_offset(&snapshot);
                        // Distance is zero when the cursor lies within the identifier.
                        let distance = if cursor_offset < start {
                            start - cursor_offset
                        } else if cursor_offset > end {
                            cursor_offset - end
                        } else {
                            0
                        };
                        (id, distance)
                    })
                    .collect();
                identifiers_with_distance.sort_by_key(|(_, distance)| *distance);

                // Despite its name, this map stores ranks, not raw distances.
                let mut cursor_distances: HashMap<Identifier, usize> = HashMap::default();
                let mut current_rank = 0;
                let mut previous_distance = None;
                for (identifier, distance) in &identifiers_with_distance {
                    // A new distance starts a new rank equal to the number of
                    // identifiers recorded so far.
                    if previous_distance != Some(*distance) {
                        current_rank = cursor_distances.len();
                        previous_distance = Some(*distance);
                    }
                    cursor_distances.insert(identifier.clone(), current_rank);
                }

                (identifiers_with_distance, cursor_distances)
            })
            .await;

        let (identifiers_with_distance, cursor_distances) = identifiers_with_ranks;

        let async_cx = cx.clone();
        // Shared start time: every cache-miss latency is measured from here.
        let start_time = Instant::now();
        let futures = this.update(cx, |this, cx| {
            identifiers_with_distance
                .into_iter()
                .filter_map(|(identifier, _)| {
                    // Reuse the cached entry when available; otherwise issue both
                    // requests. An identifier is skipped if the project cannot be updated.
                    let task = if let Some(entry) = this.cache.get(&identifier) {
                        DefinitionTask::CacheHit(entry.clone())
                    } else {
                        let definitions = this
                            .project
                            .update(cx, |project, cx| {
                                project.definitions(&buffer, identifier.range.start, cx)
                            })
                            .ok()?;
                        let type_definitions = this
                            .project
                            .update(cx, |project, cx| {
                                project.type_definitions(&buffer, identifier.range.start, cx)
                            })
                            .ok()?;
                        DefinitionTask::CacheMiss {
                            definitions,
                            type_definitions,
                        }
                    };

                    let cx = async_cx.clone();
                    let project = project.clone();
                    // Each future yields (identifier, entry, Some(latency)) for a
                    // miss and (identifier, entry, None) for a hit.
                    Some(async move {
                        match task {
                            DefinitionTask::CacheHit(cache_entry) => {
                                Some((identifier, cache_entry, None))
                            }
                            DefinitionTask::CacheMiss {
                                definitions,
                                type_definitions,
                            } => {
                                let (definition_locations, type_definition_locations) =
                                    futures::join!(definitions, type_definitions);
                                let duration = start_time.elapsed();

                                // Failed or empty responses are logged and treated as
                                // "no locations".
                                let definition_locations =
                                    definition_locations.log_err().flatten().unwrap_or_default();
                                let type_definition_locations = type_definition_locations
                                    .log_err()
                                    .flatten()
                                    .unwrap_or_default();

                                Some(cx.update(|cx| {
                                    let definitions: SmallVec<[CachedDefinition; 1]> =
                                        definition_locations
                                            .into_iter()
                                            .filter_map(|location| {
                                                process_definition(location, &project, cx)
                                            })
                                            .collect();

                                    // Drop type definitions that point at the same
                                    // buffer and range as one of the definitions.
                                    let type_definitions: SmallVec<[CachedDefinition; 1]> =
                                        type_definition_locations
                                            .into_iter()
                                            .filter_map(|location| {
                                                process_definition(location, &project, cx)
                                            })
                                            .filter(|type_def| {
                                                !definitions.iter().any(|def| {
                                                    def.buffer.entity_id()
                                                        == type_def.buffer.entity_id()
                                                        && def.anchor_range == type_def.anchor_range
                                                })
                                            })
                                            .collect();

                                    (
                                        identifier,
                                        Arc::new(CacheEntry {
                                            definitions,
                                            type_definitions,
                                        }),
                                        Some(duration),
                                    )
                                }))
                            }
                        }
                    })
                })
                .collect::<Vec<_>>()
        })?;

        let mut cache_hit_count = 0;
        let mut cache_miss_count = 0;
        let mut mean_definition_latency = Duration::ZERO;
        let mut max_definition_latency = Duration::ZERO;
        let mut new_cache = HashMap::default();
        new_cache.reserve(futures.len());
        for (identifier, entry, duration) in future::join_all(futures).await.into_iter().flatten() {
            new_cache.insert(identifier, entry);
            if let Some(duration) = duration {
                cache_miss_count += 1;
                // Accumulates the sum here; divided by the miss count below.
                mean_definition_latency += duration;
                max_definition_latency = max_definition_latency.max(duration);
            } else {
                cache_hit_count += 1;
            }
        }
        // `.max(1)` avoids dividing by zero when every lookup was a cache hit.
        mean_definition_latency /= cache_miss_count.max(1) as u32;

        let (new_cache, related_buffers) =
            rebuild_related_files(&project, new_cache, &cursor_distances, cx).await?;

        if let Some(file) = &file {
            log::debug!(
                "finished retrieving context buffer:{}, latency:{:?}",
                file.path().as_unix_str(),
                start_time.elapsed()
            );
        }

        // Publish the results: the cache now only holds this refresh's identifiers.
        this.update(cx, |this, cx| {
            this.cache = new_cache;
            this.related_buffers = related_buffers;
            cx.emit(RelatedExcerptStoreEvent::FinishedRefresh {
                cache_hit_count,
                cache_miss_count,
                mean_definition_latency,
                max_definition_latency,
            });
        })?;

        anyhow::Ok(())
    }
399}
400
/// Builds the list of related buffers from freshly resolved cache entries.
///
/// First, on the calling context, a snapshot is taken of every buffer referenced
/// by a definition (after awaiting `parsing_idle()` on it) and the root name of
/// every referenced worktree is recorded. Then, on the background executor, the
/// definition ranges are grouped per buffer, assembled into excerpt row ranges
/// with `assemble_excerpt_ranges`, converted to anchors, and the resulting
/// buffers are sorted by the best (lowest) identifier rank they contain, with
/// the path as tie-breaker.
///
/// `cursor_distances` maps identifiers to their rank; identifiers missing from
/// it are treated as rank `usize::MAX`.
///
/// Returns `new_entries` together with the related buffers.
async fn rebuild_related_files(
    project: &Entity<Project>,
    mut new_entries: HashMap<Identifier, Arc<CacheEntry>>,
    cursor_distances: &HashMap<Identifier, usize>,
    cx: &mut AsyncApp,
) -> Result<(HashMap<Identifier, Arc<CacheEntry>>, Vec<RelatedBuffer>)> {
    let mut snapshots = HashMap::default();
    let mut worktree_root_names = HashMap::default();
    for entry in new_entries.values() {
        for definition in entry
            .definitions
            .iter()
            .chain(entry.type_definitions.iter())
        {
            // Snapshot each buffer once, after its pending parse has settled.
            if let hash_map::Entry::Vacant(e) = snapshots.entry(definition.buffer.entity_id()) {
                definition
                    .buffer
                    .read_with(cx, |buffer, _| buffer.parsing_idle())
                    .await;
                e.insert(
                    definition
                        .buffer
                        .read_with(cx, |buffer, _| buffer.snapshot()),
                );
            }
            // Record each worktree's root name once; worktrees that can no longer
            // be found are left out, which later drops their buffers.
            let worktree_id = definition.path.worktree_id;
            if let hash_map::Entry::Vacant(e) =
                worktree_root_names.entry(definition.path.worktree_id)
            {
                project.read_with(cx, |project, cx| {
                    if let Some(worktree) = project.worktree_for_id(worktree_id, cx) {
                        e.insert(worktree.read(cx).root_name().as_unix_str().to_string());
                    }
                });
            }
        }
    }

    let cursor_distances = cursor_distances.clone();
    Ok(cx
        .background_spawn(async move {
            let mut ranges_by_buffer =
                HashMap::<EntityId, (Entity<Buffer>, Vec<(Range<Point>, usize)>)>::default();
            let mut paths_by_buffer = HashMap::default();
            let mut min_rank_by_buffer = HashMap::<EntityId, usize>::default();
            // NOTE(review): entries are only read in this loop even though it
            // iterates with `iter_mut`.
            for (identifier, entry) in new_entries.iter_mut() {
                let rank = cursor_distances
                    .get(identifier)
                    .copied()
                    .unwrap_or(usize::MAX);
                for definition in entry
                    .definitions
                    .iter()
                    .chain(entry.type_definitions.iter())
                {
                    let Some(snapshot) = snapshots.get(&definition.buffer.entity_id()) else {
                        continue;
                    };
                    paths_by_buffer.insert(definition.buffer.entity_id(), definition.path.clone());

                    // Track the best rank of any identifier defined in this buffer.
                    let buffer_rank = min_rank_by_buffer
                        .entry(definition.buffer.entity_id())
                        .or_insert(usize::MAX);
                    *buffer_rank = (*buffer_rank).min(rank);

                    ranges_by_buffer
                        .entry(definition.buffer.entity_id())
                        .or_insert_with(|| (definition.buffer.clone(), Vec::new()))
                        .1
                        .push((definition.anchor_range.to_point(snapshot), rank));
                }
            }

            let mut related_buffers: Vec<RelatedBuffer> = ranges_by_buffer
                .into_iter()
                .filter_map(|(entity_id, (buffer, ranges))| {
                    let snapshot = snapshots.get(&entity_id)?;
                    let project_path = paths_by_buffer.get(&entity_id)?;
                    let assembled = assemble_excerpt_ranges(snapshot, ranges);
                    let root_name = worktree_root_names.get(&project_path.worktree_id)?;

                    // Reported path is "<worktree root name>/<worktree-relative path>".
                    let path: Arc<Path> = Path::new(&format!(
                        "{}/{}",
                        root_name,
                        project_path.path.as_unix_str()
                    ))
                    .into();

                    let mut anchor_ranges = Vec::with_capacity(assembled.len());
                    let mut excerpt_orders = Vec::with_capacity(assembled.len());
                    for (row_range, order) in assembled {
                        // Anchor from the start of the first row to the end of the
                        // row at `row_range.end`.
                        let start = snapshot.anchor_before(Point::new(row_range.start, 0));
                        let end_col = snapshot.line_len(row_range.end);
                        let end = snapshot.anchor_after(Point::new(row_range.end, end_col));
                        anchor_ranges.push(start..end);
                        excerpt_orders.push(order);
                    }

                    let mut related_buffer = RelatedBuffer {
                        buffer,
                        path,
                        anchor_ranges,
                        excerpt_orders,
                        cached_file: None,
                    };
                    // Materialize the excerpts now, while already off the main thread.
                    related_buffer.fill_cache(snapshot);
                    Some(related_buffer)
                })
                .collect();

            // Buffers with the closest identifiers first; path breaks ties.
            related_buffers.sort_by(|a, b| {
                let rank_a = min_rank_by_buffer
                    .get(&a.buffer.entity_id())
                    .copied()
                    .unwrap_or(usize::MAX);
                let rank_b = min_rank_by_buffer
                    .get(&b.buffer.entity_id())
                    .copied()
                    .unwrap_or(usize::MAX);
                rank_a.cmp(&rank_b).then_with(|| a.path.cmp(&b.path))
            });

            (new_entries, related_buffers)
        })
        .await)
}
527
528impl RelatedBuffer {
529    fn related_file(&mut self, cx: &App) -> RelatedFile {
530        let buffer = self.buffer.read(cx);
531        let path = self.path.clone();
532        let cached = if let Some(cached) = &self.cached_file
533            && buffer.version() == cached.buffer_version
534        {
535            cached
536        } else {
537            self.fill_cache(buffer)
538        };
539        let related_file = RelatedFile {
540            path,
541            excerpts: cached.excerpts.clone(),
542            max_row: buffer.max_point().row,
543            in_open_source_repo: false,
544        };
545        return related_file;
546    }
547
548    fn fill_cache(&mut self, buffer: &text::BufferSnapshot) -> &CachedRelatedFile {
549        let excerpts = self
550            .anchor_ranges
551            .iter()
552            .zip(self.excerpt_orders.iter())
553            .map(|(range, &order)| {
554                let start = range.start.to_point(buffer);
555                let end = range.end.to_point(buffer);
556                RelatedExcerpt {
557                    row_range: start.row..end.row,
558                    text: buffer.text_for_range(start..end).collect::<String>().into(),
559                    order,
560                }
561            })
562            .collect::<Vec<_>>();
563        self.cached_file = Some(CachedRelatedFile {
564            excerpts: excerpts,
565            buffer_version: buffer.version().clone(),
566        });
567        self.cached_file.as_ref().unwrap()
568    }
569}
570
571use language::ToPoint as _;
572
/// Largest definition target range, in bytes, that `process_definition` accepts.
/// Larger targets are discarded.
const MAX_TARGET_LEN: usize = 128;
574
575fn process_definition(
576    location: LocationLink,
577    project: &Entity<Project>,
578    cx: &mut App,
579) -> Option<CachedDefinition> {
580    let buffer = location.target.buffer.read(cx);
581    let anchor_range = location.target.range;
582    let file = buffer.file()?;
583    let worktree = project.read(cx).worktree_for_id(file.worktree_id(cx), cx)?;
584    if worktree.read(cx).is_single_file() {
585        return None;
586    }
587
588    // If the target range is large, it likely means we requested the definition of an entire module.
589    // For individual definitions, the target range should be small as it only covers the symbol.
590    let buffer = location.target.buffer.read(cx);
591    let target_len = anchor_range.to_offset(&buffer).len();
592    if target_len > MAX_TARGET_LEN {
593        return None;
594    }
595
596    Some(CachedDefinition {
597        path: ProjectPath {
598            worktree_id: file.worktree_id(cx),
599            path: file.path().clone(),
600        },
601        buffer: location.target.buffer,
602        anchor_range,
603    })
604}
605
606/// Gets all of the identifiers that are present in the given line, and its containing
607/// outline items.
608fn identifiers_for_position(
609    buffer: &BufferSnapshot,
610    position: Anchor,
611    identifier_line_count: u32,
612) -> Vec<Identifier> {
613    let offset = position.to_offset(buffer);
614    let point = buffer.offset_to_point(offset);
615
616    // Search for identifiers on lines adjacent to the cursor.
617    let start = Point::new(point.row.saturating_sub(identifier_line_count), 0);
618    let end = Point::new(point.row + identifier_line_count + 1, 0).min(buffer.max_point());
619    let line_range = start..end;
620    let mut ranges = vec![line_range.to_offset(&buffer)];
621
622    // Search for identifiers mentioned in headers/signatures of containing outline items.
623    let outline_items = buffer.outline_items_as_offsets_containing(offset..offset, false, None);
624    for item in outline_items {
625        if let Some(body_range) = item.body_range(&buffer) {
626            ranges.push(item.range.start..body_range.start.to_offset(&buffer));
627        } else {
628            ranges.push(item.range.clone());
629        }
630    }
631
632    ranges.sort_by(|a, b| a.start.cmp(&b.start).then(b.end.cmp(&a.end)));
633    ranges.dedup_by(|a, b| {
634        if a.start <= b.end {
635            b.start = b.start.min(a.start);
636            b.end = b.end.max(a.end);
637            true
638        } else {
639            false
640        }
641    });
642
643    let mut identifiers = Vec::new();
644    let outer_range =
645        ranges.first().map_or(0, |r| r.start)..ranges.last().map_or(buffer.len(), |r| r.end);
646
647    let mut captures = buffer.captures(outer_range.clone(), |grammar| {
648        grammar
649            .highlights_config
650            .as_ref()
651            .map(|config| &config.query)
652    });
653
654    for range in ranges {
655        captures.set_byte_range(range.start..outer_range.end);
656
657        let mut last_range = None;
658        while let Some(capture) = captures.peek() {
659            let node_range = capture.node.byte_range();
660            if node_range.start > range.end {
661                break;
662            }
663            let config = captures.grammars()[capture.grammar_index]
664                .highlights_config
665                .as_ref();
666
667            if let Some(config) = config
668                && config.identifier_capture_indices.contains(&capture.index)
669                && range.contains_inclusive(&node_range)
670                && Some(&node_range) != last_range.as_ref()
671            {
672                let name = buffer.text_for_range(node_range.clone()).collect();
673                identifiers.push(Identifier {
674                    range: buffer.anchor_after(node_range.start)
675                        ..buffer.anchor_before(node_range.end),
676                    name,
677                });
678                last_range = Some(node_range);
679            }
680
681            captures.advance();
682        }
683    }
684
685    identifiers
686}