edit_prediction_context.rs

  1use crate::assemble_excerpts::assemble_excerpt_ranges;
  2use anyhow::Result;
  3use collections::HashMap;
  4use futures::{FutureExt, StreamExt as _, channel::mpsc, future};
  5use gpui::{App, AppContext, AsyncApp, Context, Entity, EntityId, EventEmitter, Task, WeakEntity};
  6use language::{Anchor, Buffer, BufferSnapshot, OffsetRangeExt as _, Point, ToOffset as _};
  7use project::{LocationLink, Project, ProjectPath};
  8use smallvec::SmallVec;
  9use std::{
 10    collections::hash_map,
 11    ops::Range,
 12    path::Path,
 13    sync::Arc,
 14    time::{Duration, Instant},
 15};
 16use util::paths::PathStyle;
 17use util::rel_path::RelPath;
 18use util::{RangeExt as _, ResultExt};
 19
 20mod assemble_excerpts;
 21#[cfg(test)]
 22mod edit_prediction_context_tests;
 23#[cfg(test)]
 24mod fake_definition_lsp;
 25
 26pub use zeta_prompt::{RelatedExcerpt, RelatedFile};
 27
/// Default number of lines above and below the cursor that are scanned for
/// identifiers when gathering edit-prediction context (see
/// `identifiers_for_position`). Overridable via `set_identifier_line_count`.
const IDENTIFIER_LINE_COUNT: u32 = 3;
 29
/// Maintains excerpts from buffers related to the user's cursor position.
/// Identifiers near the cursor are resolved to definitions/type definitions
/// through the project, results are cached per identifier, and the resolved
/// locations are assembled into `RelatedFile` excerpts.
pub struct RelatedExcerptStore {
    /// Weak handle to the owning project; refreshes become no-ops once the
    /// project is dropped.
    project: WeakEntity<Project>,
    /// Current set of related buffers; rebuilt on each refresh and sorted by
    /// path (see `rebuild_related_files`).
    related_buffers: Vec<RelatedBuffer>,
    /// Definition-lookup results keyed by identifier occurrence. Rebuilt each
    /// refresh, reusing entries for identifiers still near the cursor.
    cache: HashMap<Identifier, Arc<CacheEntry>>,
    /// Feeds (buffer, cursor position) updates to the debounced background
    /// worker spawned in `new`.
    update_tx: mpsc::UnboundedSender<(Entity<Buffer>, Anchor)>,
    /// Number of lines above/below the cursor scanned for identifiers.
    identifier_line_count: u32,
}
 37
/// A buffer contributing one or more excerpts to the related-file context.
struct RelatedBuffer {
    buffer: Entity<Buffer>,
    /// Worktree-root-qualified path, formatted as `root_name/relative/path`
    /// (see `rebuild_related_files`).
    path: Arc<Path>,
    /// Excerpt ranges within `buffer`, spanning whole lines.
    anchor_ranges: Vec<Range<Anchor>>,
    /// Extracted excerpt text, reused while the buffer version is unchanged
    /// (see `RelatedBuffer::related_file`).
    cached_file: Option<CachedRelatedFile>,
}
 44
/// Excerpt text extracted from a buffer, tagged with the buffer version it
/// was extracted at so that staleness can be detected cheaply.
struct CachedRelatedFile {
    excerpts: Vec<RelatedExcerpt>,
    /// Version of the buffer when `excerpts` was produced; compared against
    /// the live buffer's version before reuse.
    buffer_version: clock::Global,
}
 49
/// Events emitted by `RelatedExcerptStore` while refreshing related excerpts.
pub enum RelatedExcerptStoreEvent {
    /// A refresh began for a new cursor position.
    StartedRefresh,
    /// A refresh completed.
    FinishedRefresh {
        /// Identifiers whose definitions were served from the cache.
        cache_hit_count: usize,
        /// Identifiers that required fresh definition requests.
        cache_miss_count: usize,
        /// Mean time from the start of the refresh until each cache-missing
        /// definition request resolved (zero when there were no misses).
        mean_definition_latency: Duration,
        /// Maximum such time across all cache-missing requests.
        max_definition_latency: Duration,
    },
}
 59
/// An identifier occurrence in a buffer. Hash/equality include the anchor
/// range, so the same name at two locations forms two distinct cache keys.
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
struct Identifier {
    pub name: String,
    pub range: Range<Anchor>,
}
 65
/// Work needed to resolve one identifier: either an already-cached result, or
/// a pair of in-flight project requests.
enum DefinitionTask {
    /// The identifier was resolved during a previous refresh.
    CacheHit(Arc<CacheEntry>),
    /// Freshly issued definition and type-definition requests.
    CacheMiss {
        definitions: Task<Result<Option<Vec<LocationLink>>>>,
        type_definitions: Task<Result<Option<Vec<LocationLink>>>>,
    },
}
 73
/// Resolved definition locations for a single identifier.
#[derive(Debug)]
struct CacheEntry {
    definitions: SmallVec<[CachedDefinition; 1]>,
    /// Type definitions, excluding any that duplicate an entry in
    /// `definitions` (same buffer and anchor range).
    type_definitions: SmallVec<[CachedDefinition; 1]>,
}
 79
/// A single resolved definition target: the buffer it lives in, its project
/// path, and the anchor range covering the defined symbol.
#[derive(Clone, Debug)]
struct CachedDefinition {
    path: ProjectPath,
    buffer: Entity<Buffer>,
    anchor_range: Range<Anchor>,
}
 86
/// How long the cursor must remain still before a refresh actually runs; each
/// new update restarts this timer (see the worker loop in `new`).
const DEBOUNCE_DURATION: Duration = Duration::from_millis(100);

impl EventEmitter<RelatedExcerptStoreEvent> for RelatedExcerptStore {}
 90
impl RelatedExcerptStore {
    /// Creates a store for `project` and spawns a background worker that
    /// debounces cursor updates arriving on `update_tx` before fetching
    /// related excerpts for the most recent (buffer, position).
    pub fn new(project: &Entity<Project>, cx: &mut Context<Self>) -> Self {
        let (update_tx, mut update_rx) = mpsc::unbounded::<(Entity<Buffer>, Anchor)>();
        cx.spawn(async move |this, cx| {
            let executor = cx.background_executor().clone();
            while let Some((mut buffer, mut position)) = update_rx.next().await {
                // Debounce: every newer update replaces the pending one and
                // restarts the timer; we only fetch once the channel has been
                // quiet for DEBOUNCE_DURATION.
                let mut timer = executor.timer(DEBOUNCE_DURATION).fuse();
                loop {
                    futures::select_biased! {
                        next = update_rx.next() => {
                            if let Some((new_buffer, new_position)) = next {
                                buffer = new_buffer;
                                position = new_position;
                                timer = executor.timer(DEBOUNCE_DURATION).fuse();
                            } else {
                                // All senders dropped: shut the worker down.
                                return anyhow::Ok(());
                            }
                        }
                        _ = timer => break,
                    }
                }

                // `?` ends the worker when the store entity has been dropped.
                Self::fetch_excerpts(this.clone(), buffer, position, cx).await?;
            }
            anyhow::Ok(())
        })
        .detach_and_log_err(cx);

        RelatedExcerptStore {
            project: project.downgrade(),
            update_tx,
            related_buffers: Vec::new(),
            cache: Default::default(),
            identifier_line_count: IDENTIFIER_LINE_COUNT,
        }
    }

    /// Overrides how many lines above/below the cursor are scanned for
    /// identifiers (defaults to `IDENTIFIER_LINE_COUNT`).
    pub fn set_identifier_line_count(&mut self, count: u32) {
        self.identifier_line_count = count;
    }

    /// Requests a (debounced) refresh for the given cursor position. A closed
    /// channel — possible only during teardown — is deliberately ignored.
    pub fn refresh(&mut self, buffer: Entity<Buffer>, position: Anchor, _: &mut Context<Self>) {
        self.update_tx.unbounded_send((buffer, position)).ok();
    }

    /// Returns the current related files, re-extracting excerpt text for any
    /// buffer that changed since its cache was filled.
    pub fn related_files(&mut self, cx: &App) -> Vec<RelatedFile> {
        self.related_buffers
            .iter_mut()
            .map(|related| related.related_file(cx))
            .collect()
    }

    /// Like `related_files`, but pairs each file with a handle to the buffer
    /// it was extracted from.
    pub fn related_files_with_buffers(
        &mut self,
        cx: &App,
    ) -> impl Iterator<Item = (RelatedFile, Entity<Buffer>)> {
        self.related_buffers
            .iter_mut()
            .map(|related| (related.related_file(cx), related.buffer.clone()))
    }

    /// Replaces the related-buffer set from serialized `RelatedFile`s, e.g.
    /// when restoring state. Each file path's first component is matched
    /// against a worktree root name; files whose worktree or open buffer
    /// can't be resolved are silently dropped.
    pub fn set_related_files(&mut self, files: Vec<RelatedFile>, cx: &App) {
        self.related_buffers = files
            .into_iter()
            .filter_map(|file| {
                let project = self.project.upgrade()?;
                let project = project.read(cx);
                let worktree = project.worktrees(cx).find(|wt| {
                    let root_name = wt.read(cx).root_name().as_unix_str();
                    file.path
                        .components()
                        .next()
                        .is_some_and(|c| c.as_os_str() == root_name)
                })?;
                let worktree = worktree.read(cx);
                let relative_path = file
                    .path
                    .strip_prefix(worktree.root_name().as_unix_str())
                    .ok()?;
                let relative_path = RelPath::new(relative_path, PathStyle::Posix).ok()?;
                let project_path = ProjectPath {
                    worktree_id: worktree.id(),
                    path: relative_path.into_owned().into(),
                };
                // Only already-open buffers are considered; this never loads
                // a buffer from disk.
                let buffer = project.get_open_buffer(&project_path, cx)?;
                let snapshot = buffer.read(cx).snapshot();
                let anchor_ranges = file
                    .excerpts
                    .iter()
                    .map(|excerpt| {
                        // Expand each stored row range to full lines:
                        // row_range.end is treated as the excerpt's last row.
                        let start = snapshot.anchor_before(Point::new(excerpt.row_range.start, 0));
                        let end_row = excerpt.row_range.end;
                        let end_col = snapshot.line_len(end_row);
                        let end = snapshot.anchor_after(Point::new(end_row, end_col));
                        start..end
                    })
                    .collect();
                Some(RelatedBuffer {
                    buffer,
                    path: file.path.clone(),
                    anchor_ranges,
                    cached_file: None,
                })
            })
            .collect();
    }

    /// Performs one refresh: gathers identifiers around `position`, resolves
    /// their definitions and type definitions (reusing `self.cache` where
    /// possible), rebuilds the related-file list, then stores the new cache
    /// and buffers. Emits `StartedRefresh` at the beginning and
    /// `FinishedRefresh` (with cache/latency statistics) at the end.
    async fn fetch_excerpts(
        this: WeakEntity<Self>,
        buffer: Entity<Buffer>,
        position: Anchor,
        cx: &mut AsyncApp,
    ) -> Result<()> {
        let (project, snapshot, identifier_line_count) = this.read_with(cx, |this, cx| {
            (
                this.project.upgrade(),
                buffer.read(cx).snapshot(),
                this.identifier_line_count,
            )
        })?;
        let Some(project) = project else {
            return Ok(());
        };

        let file = snapshot.file().cloned();
        if let Some(file) = &file {
            log::debug!("retrieving_context buffer:{}", file.path().as_unix_str());
        }

        this.update(cx, |_, cx| {
            cx.emit(RelatedExcerptStoreEvent::StartedRefresh);
        })?;

        // Identifier extraction does a syntax-tree query, so run it off the
        // main thread.
        let identifiers = cx
            .background_spawn(async move {
                identifiers_for_position(&snapshot, position, identifier_line_count)
            })
            .await;

        let async_cx = cx.clone();
        let start_time = Instant::now();
        // Build one future per identifier. Cache hits complete immediately;
        // misses await a definitions + type-definitions request pair.
        let futures = this.update(cx, |this, cx| {
            identifiers
                .into_iter()
                .filter_map(|identifier| {
                    let task = if let Some(entry) = this.cache.get(&identifier) {
                        DefinitionTask::CacheHit(entry.clone())
                    } else {
                        let definitions = this
                            .project
                            .update(cx, |project, cx| {
                                project.definitions(&buffer, identifier.range.start, cx)
                            })
                            .ok()?;
                        let type_definitions = this
                            .project
                            .update(cx, |project, cx| {
                                project.type_definitions(&buffer, identifier.range.start, cx)
                            })
                            .ok()?;
                        DefinitionTask::CacheMiss {
                            definitions,
                            type_definitions,
                        }
                    };

                    let cx = async_cx.clone();
                    let project = project.clone();
                    // Each future yields (identifier, entry, latency), where a
                    // `Some` latency marks a cache miss (used for statistics).
                    Some(async move {
                        match task {
                            DefinitionTask::CacheHit(cache_entry) => {
                                Some((identifier, cache_entry, None))
                            }
                            DefinitionTask::CacheMiss {
                                definitions,
                                type_definitions,
                            } => {
                                let (definition_locations, type_definition_locations) =
                                    futures::join!(definitions, type_definitions);
                                let duration = start_time.elapsed();

                                // Failed or empty responses degrade to "no
                                // definitions" rather than aborting the refresh.
                                let definition_locations =
                                    definition_locations.log_err().flatten().unwrap_or_default();
                                let type_definition_locations = type_definition_locations
                                    .log_err()
                                    .flatten()
                                    .unwrap_or_default();

                                Some(cx.update(|cx| {
                                    let definitions: SmallVec<[CachedDefinition; 1]> =
                                        definition_locations
                                            .into_iter()
                                            .filter_map(|location| {
                                                process_definition(location, &project, cx)
                                            })
                                            .collect();

                                    // Drop type definitions that merely repeat
                                    // a plain definition target.
                                    let type_definitions: SmallVec<[CachedDefinition; 1]> =
                                        type_definition_locations
                                            .into_iter()
                                            .filter_map(|location| {
                                                process_definition(location, &project, cx)
                                            })
                                            .filter(|type_def| {
                                                !definitions.iter().any(|def| {
                                                    def.buffer.entity_id()
                                                        == type_def.buffer.entity_id()
                                                        && def.anchor_range == type_def.anchor_range
                                                })
                                            })
                                            .collect();

                                    (
                                        identifier,
                                        Arc::new(CacheEntry {
                                            definitions,
                                            type_definitions,
                                        }),
                                        Some(duration),
                                    )
                                }))
                            }
                        }
                    })
                })
                .collect::<Vec<_>>()
        })?;

        let mut cache_hit_count = 0;
        let mut cache_miss_count = 0;
        let mut mean_definition_latency = Duration::ZERO;
        let mut max_definition_latency = Duration::ZERO;
        // The cache is rebuilt from scratch, so entries for identifiers no
        // longer near the cursor are dropped here.
        let mut new_cache = HashMap::default();
        new_cache.reserve(futures.len());
        for (identifier, entry, duration) in future::join_all(futures).await.into_iter().flatten() {
            new_cache.insert(identifier, entry);
            if let Some(duration) = duration {
                cache_miss_count += 1;
                mean_definition_latency += duration;
                max_definition_latency = max_definition_latency.max(duration);
            } else {
                cache_hit_count += 1;
            }
        }
        // `.max(1)` guards the division when every lookup was a cache hit.
        mean_definition_latency /= cache_miss_count.max(1) as u32;

        let (new_cache, related_buffers) = rebuild_related_files(&project, new_cache, cx).await?;

        if let Some(file) = &file {
            log::debug!(
                "finished retrieving context buffer:{}, latency:{:?}",
                file.path().as_unix_str(),
                start_time.elapsed()
            );
        }

        this.update(cx, |this, cx| {
            this.cache = new_cache;
            this.related_buffers = related_buffers;
            cx.emit(RelatedExcerptStoreEvent::FinishedRefresh {
                cache_hit_count,
                cache_miss_count,
                mean_definition_latency,
                max_definition_latency,
            });
        })?;

        anyhow::Ok(())
    }
}
361
/// Turns the freshly resolved cache entries into a sorted `Vec<RelatedBuffer>`.
///
/// Runs in two phases: first, on the foreground, it waits for each referenced
/// buffer to finish parsing, snapshots it, and collects worktree root names;
/// then, on the background executor, it groups definition ranges per buffer,
/// assembles them into excerpt row ranges, and builds the related-buffer list.
/// Returns the (unmodified) cache entries back along with the buffers.
async fn rebuild_related_files(
    project: &Entity<Project>,
    mut new_entries: HashMap<Identifier, Arc<CacheEntry>>,
    cx: &mut AsyncApp,
) -> Result<(HashMap<Identifier, Arc<CacheEntry>>, Vec<RelatedBuffer>)> {
    let mut snapshots = HashMap::default();
    let mut worktree_root_names = HashMap::default();
    for entry in new_entries.values() {
        for definition in entry
            .definitions
            .iter()
            .chain(entry.type_definitions.iter())
        {
            // Snapshot each buffer at most once, after parsing has settled so
            // excerpt assembly sees an up-to-date syntax tree.
            if let hash_map::Entry::Vacant(e) = snapshots.entry(definition.buffer.entity_id()) {
                definition
                    .buffer
                    .read_with(cx, |buffer, _| buffer.parsing_idle())
                    .await;
                e.insert(
                    definition
                        .buffer
                        .read_with(cx, |buffer, _| buffer.snapshot()),
                );
            }
            let worktree_id = definition.path.worktree_id;
            // Root names are looked up once per worktree; a missing worktree
            // simply leaves no entry, and its definitions are skipped below.
            if let hash_map::Entry::Vacant(e) =
                worktree_root_names.entry(definition.path.worktree_id)
            {
                project.read_with(cx, |project, cx| {
                    if let Some(worktree) = project.worktree_for_id(worktree_id, cx) {
                        e.insert(worktree.read(cx).root_name().as_unix_str().to_string());
                    }
                });
            }
        }
    }

    Ok(cx
        .background_spawn(async move {
            // Group all definition point-ranges by their containing buffer.
            let mut ranges_by_buffer =
                HashMap::<EntityId, (Entity<Buffer>, Vec<Range<Point>>)>::default();
            let mut paths_by_buffer = HashMap::default();
            for entry in new_entries.values_mut() {
                for definition in entry
                    .definitions
                    .iter()
                    .chain(entry.type_definitions.iter())
                {
                    let Some(snapshot) = snapshots.get(&definition.buffer.entity_id()) else {
                        continue;
                    };
                    paths_by_buffer.insert(definition.buffer.entity_id(), definition.path.clone());

                    ranges_by_buffer
                        .entry(definition.buffer.entity_id())
                        .or_insert_with(|| (definition.buffer.clone(), Vec::new()))
                        .1
                        .push(definition.anchor_range.to_point(snapshot));
                }
            }

            let mut related_buffers: Vec<RelatedBuffer> = ranges_by_buffer
                .into_iter()
                .filter_map(|(entity_id, (buffer, ranges))| {
                    let snapshot = snapshots.get(&entity_id)?;
                    let project_path = paths_by_buffer.get(&entity_id)?;
                    // Merge/expand raw ranges into display-worthy excerpt rows.
                    let row_ranges = assemble_excerpt_ranges(snapshot, ranges);
                    let root_name = worktree_root_names.get(&project_path.worktree_id)?;

                    // Present paths as `root_name/relative/path` so they are
                    // unambiguous across worktrees.
                    let path: Arc<Path> = Path::new(&format!(
                        "{}/{}",
                        root_name,
                        project_path.path.as_unix_str()
                    ))
                    .into();

                    let anchor_ranges = row_ranges
                        .into_iter()
                        .map(|row_range| {
                            // Expand each row range to whole lines; `end` is
                            // treated as the excerpt's last row (inclusive).
                            let start = snapshot.anchor_before(Point::new(row_range.start, 0));
                            let end_col = snapshot.line_len(row_range.end);
                            let end = snapshot.anchor_after(Point::new(row_range.end, end_col));
                            start..end
                        })
                        .collect();

                    let mut related_buffer = RelatedBuffer {
                        buffer,
                        path,
                        anchor_ranges,
                        cached_file: None,
                    };
                    // Pre-extract excerpt text while we already hold the
                    // snapshot, so later `related_file` calls hit the cache.
                    related_buffer.fill_cache(snapshot);
                    Some(related_buffer)
                })
                .collect();

            // Deterministic ordering for consumers and prompts.
            related_buffers.sort_by_key(|related| related.path.clone());

            (new_entries, related_buffers)
        })
        .await)
}
465
466impl RelatedBuffer {
467    fn related_file(&mut self, cx: &App) -> RelatedFile {
468        let buffer = self.buffer.read(cx);
469        let path = self.path.clone();
470        let cached = if let Some(cached) = &self.cached_file
471            && buffer.version() == cached.buffer_version
472        {
473            cached
474        } else {
475            self.fill_cache(buffer)
476        };
477        let related_file = RelatedFile {
478            path,
479            excerpts: cached.excerpts.clone(),
480            max_row: buffer.max_point().row,
481            in_open_source_repo: false,
482        };
483        return related_file;
484    }
485
486    fn fill_cache(&mut self, buffer: &text::BufferSnapshot) -> &CachedRelatedFile {
487        let excerpts = self
488            .anchor_ranges
489            .iter()
490            .map(|range| {
491                let start = range.start.to_point(buffer);
492                let end = range.end.to_point(buffer);
493                RelatedExcerpt {
494                    row_range: start.row..end.row,
495                    text: buffer.text_for_range(start..end).collect::<String>().into(),
496                }
497            })
498            .collect::<Vec<_>>();
499        self.cached_file = Some(CachedRelatedFile {
500            excerpts: excerpts,
501            buffer_version: buffer.version().clone(),
502        });
503        self.cached_file.as_ref().unwrap()
504    }
505}
506
507use language::ToPoint as _;
508
509const MAX_TARGET_LEN: usize = 128;
510
511fn process_definition(
512    location: LocationLink,
513    project: &Entity<Project>,
514    cx: &mut App,
515) -> Option<CachedDefinition> {
516    let buffer = location.target.buffer.read(cx);
517    let anchor_range = location.target.range;
518    let file = buffer.file()?;
519    let worktree = project.read(cx).worktree_for_id(file.worktree_id(cx), cx)?;
520    if worktree.read(cx).is_single_file() {
521        return None;
522    }
523
524    // If the target range is large, it likely means we requested the definition of an entire module.
525    // For individual definitions, the target range should be small as it only covers the symbol.
526    let buffer = location.target.buffer.read(cx);
527    let target_len = anchor_range.to_offset(&buffer).len();
528    if target_len > MAX_TARGET_LEN {
529        return None;
530    }
531
532    Some(CachedDefinition {
533        path: ProjectPath {
534            worktree_id: file.worktree_id(cx),
535            path: file.path().clone(),
536        },
537        buffer: location.target.buffer,
538        anchor_range,
539    })
540}
541
/// Gets all of the identifiers that are present in the given line, and its containing
/// outline items.
fn identifiers_for_position(
    buffer: &BufferSnapshot,
    position: Anchor,
    identifier_line_count: u32,
) -> Vec<Identifier> {
    let offset = position.to_offset(buffer);
    let point = buffer.offset_to_point(offset);

    // Search for identifiers on lines adjacent to the cursor.
    let start = Point::new(point.row.saturating_sub(identifier_line_count), 0);
    let end = Point::new(point.row + identifier_line_count + 1, 0).min(buffer.max_point());
    let line_range = start..end;
    let mut ranges = vec![line_range.to_offset(&buffer)];

    // Search for identifiers mentioned in headers/signatures of containing outline items.
    // When an item has a body, only its header (everything before the body)
    // is scanned; otherwise the whole item range is used.
    let outline_items = buffer.outline_items_as_offsets_containing(offset..offset, false, None);
    for item in outline_items {
        if let Some(body_range) = item.body_range(&buffer) {
            ranges.push(item.range.start..body_range.start.to_offset(&buffer));
        } else {
            ranges.push(item.range.clone());
        }
    }

    // Sort by ascending start (ties: descending end), then merge overlapping
    // or touching ranges. In `dedup_by`, `a` is the later element and `b` the
    // earlier kept one; returning true removes `a` after folding it into `b`.
    ranges.sort_by(|a, b| a.start.cmp(&b.start).then(b.end.cmp(&a.end)));
    ranges.dedup_by(|a, b| {
        if a.start <= b.end {
            b.start = b.start.min(a.start);
            b.end = b.end.max(a.end);
            true
        } else {
            false
        }
    });

    let mut identifiers = Vec::new();
    // One capture query spans all merged ranges; per-range we re-seek the
    // iterator below instead of rebuilding it.
    let outer_range =
        ranges.first().map_or(0, |r| r.start)..ranges.last().map_or(buffer.len(), |r| r.end);

    let mut captures = buffer
        .syntax
        .captures(outer_range.clone(), &buffer.text, |grammar| {
            grammar
                .highlights_config
                .as_ref()
                .map(|config| &config.query)
        });

    for range in ranges {
        // Seek to this range's start; the end stays at `outer_range.end` and
        // the loop breaks manually once captures pass `range.end`.
        captures.set_byte_range(range.start..outer_range.end);

        let mut last_range = None;
        while let Some(capture) = captures.peek() {
            let node_range = capture.node.byte_range();
            if node_range.start > range.end {
                break;
            }
            let config = captures.grammars()[capture.grammar_index]
                .highlights_config
                .as_ref();

            // Keep only identifier captures fully inside the range;
            // `last_range` skips duplicate captures of the same node when
            // multiple highlight patterns match it.
            if let Some(config) = config
                && config.identifier_capture_indices.contains(&capture.index)
                && range.contains_inclusive(&node_range)
                && Some(&node_range) != last_range.as_ref()
            {
                let name = buffer.text_for_range(node_range.clone()).collect();
                identifiers.push(Identifier {
                    range: buffer.anchor_after(node_range.start)
                        ..buffer.anchor_before(node_range.end),
                    name,
                });
                last_range = Some(node_range);
            }

            captures.advance();
        }
    }

    identifiers
}