edit_prediction_context.rs

  1use crate::assemble_excerpts::assemble_excerpt_ranges;
  2use anyhow::Result;
  3use collections::HashMap;
  4use futures::{FutureExt, StreamExt as _, channel::mpsc, future};
  5use gpui::{App, AppContext, AsyncApp, Context, Entity, EntityId, EventEmitter, Task, WeakEntity};
  6use language::{Anchor, Buffer, BufferSnapshot, OffsetRangeExt as _, Point, ToOffset as _};
  7use project::{LocationLink, Project, ProjectPath};
  8use smallvec::SmallVec;
  9use std::{
 10    collections::hash_map,
 11    ops::Range,
 12    path::Path,
 13    sync::Arc,
 14    time::{Duration, Instant},
 15};
 16use util::paths::PathStyle;
 17use util::rel_path::RelPath;
 18use util::{RangeExt as _, ResultExt};
 19
 20mod assemble_excerpts;
 21#[cfg(test)]
 22mod edit_prediction_context_tests;
 23mod excerpt;
 24#[cfg(test)]
 25mod fake_definition_lsp;
 26
 27pub use cloud_llm_client::predict_edits_v3::Line;
 28pub use excerpt::{EditPredictionExcerpt, EditPredictionExcerptOptions, EditPredictionExcerptText};
 29pub use zeta_prompt::{RelatedExcerpt, RelatedFile};
 30
/// Default number of rows above and below the cursor row that are scanned for
/// identifiers. Used to initialize `RelatedExcerptStore::identifier_line_count`.
const IDENTIFIER_LINE_COUNT: u32 = 3;
 32
/// Tracks excerpts of other buffers that are related to the text around a position
/// in a buffer, found by resolving the definitions of nearby identifiers.
pub struct RelatedExcerptStore {
    /// The project used to request definitions and look up worktrees (held weakly).
    project: WeakEntity<Project>,
    /// The related buffers produced by the most recent refresh, or assigned
    /// explicitly through `set_related_files`.
    related_buffers: Vec<RelatedBuffer>,
    /// Definitions resolved during the most recent refresh, keyed by identifier.
    cache: HashMap<Identifier, Arc<CacheEntry>>,
    /// Sends refresh requests to the debouncing task spawned in `new`.
    update_tx: mpsc::UnboundedSender<(Entity<Buffer>, Anchor)>,
    /// Number of rows above and below the cursor row that are scanned for identifiers.
    identifier_line_count: u32,
}
 40
/// A buffer that contains related content, together with the ranges of it that
/// are exposed as excerpts.
struct RelatedBuffer {
    /// The buffer the excerpts are taken from.
    buffer: Entity<Buffer>,
    /// Path reported in the `RelatedFile` produced from this buffer.
    path: Arc<Path>,
    /// Anchored excerpt ranges; converted to rows and text by `fill_cache`.
    anchor_ranges: Vec<Range<Anchor>>,
    /// Excerpts computed for a specific buffer version, reused by `related_file`
    /// while the buffer's version is unchanged.
    cached_file: Option<CachedRelatedFile>,
}
 47
/// Excerpts materialized from a `RelatedBuffer`'s anchor ranges, tagged with the
/// buffer version they were computed against.
struct CachedRelatedFile {
    /// Row range and text for each anchor range.
    excerpts: Vec<RelatedExcerpt>,
    /// Version of the buffer at the time `excerpts` were computed.
    buffer_version: clock::Global,
}
 52
/// Events emitted by `RelatedExcerptStore` while refreshing its related excerpts.
pub enum RelatedExcerptStoreEvent {
    /// Emitted when a refresh begins, before identifiers are collected.
    StartedRefresh,
    /// Emitted after a refresh has replaced the store's cache and related buffers.
    FinishedRefresh {
        /// Number of identifiers whose definitions were taken from the cache.
        cache_hit_count: usize,
        /// Number of identifiers whose definitions were requested from the project
        /// and resolved successfully.
        cache_miss_count: usize,
        /// Mean time between the start of the requests and the completion of each
        /// cache-miss request (zero when there were no misses).
        mean_definition_latency: Duration,
        /// Longest such time among the cache-miss requests.
        max_definition_latency: Duration,
    },
}
 62
/// An identifier occurrence in a buffer. Used as the key of the definition cache,
/// so two occurrences are equal only if both the name and the anchored range match.
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
struct Identifier {
    /// The text of the identifier.
    pub name: String,
    /// The anchored range the identifier occupies in its buffer.
    pub range: Range<Anchor>,
}
 68
/// How the definitions for one identifier are obtained during a refresh.
enum DefinitionTask {
    /// The definitions were already present in the store's cache.
    CacheHit(Arc<CacheEntry>),
    /// The definitions are being requested from the project; the task resolves
    /// to the location links, if any.
    CacheMiss(Task<Result<Option<Vec<LocationLink>>>>),
}
 73
/// The definitions that were resolved for a single identifier.
#[derive(Debug)]
struct CacheEntry {
    /// Definitions kept after filtering by `process_definition`; may be empty.
    definitions: SmallVec<[CachedDefinition; 1]>,
}
 78
/// A definition location retained in the cache.
#[derive(Clone, Debug)]
struct CachedDefinition {
    /// Project path of the file backing `buffer`.
    path: ProjectPath,
    /// The buffer that contains the definition.
    buffer: Entity<Buffer>,
    /// The target range of the definition within `buffer`.
    anchor_range: Range<Anchor>,
}
 85
/// How long the refresh task waits after the most recent refresh request before
/// fetching excerpts; each new request restarts the wait.
const DEBOUNCE_DURATION: Duration = Duration::from_millis(100);
 87
// Allows `RelatedExcerptStore` to emit `RelatedExcerptStoreEvent`s via `cx.emit`.
impl EventEmitter<RelatedExcerptStoreEvent> for RelatedExcerptStore {}
 89
 90impl RelatedExcerptStore {
 91    pub fn new(project: &Entity<Project>, cx: &mut Context<Self>) -> Self {
 92        let (update_tx, mut update_rx) = mpsc::unbounded::<(Entity<Buffer>, Anchor)>();
 93        cx.spawn(async move |this, cx| {
 94            let executor = cx.background_executor().clone();
 95            while let Some((mut buffer, mut position)) = update_rx.next().await {
 96                let mut timer = executor.timer(DEBOUNCE_DURATION).fuse();
 97                loop {
 98                    futures::select_biased! {
 99                        next = update_rx.next() => {
100                            if let Some((new_buffer, new_position)) = next {
101                                buffer = new_buffer;
102                                position = new_position;
103                                timer = executor.timer(DEBOUNCE_DURATION).fuse();
104                            } else {
105                                return anyhow::Ok(());
106                            }
107                        }
108                        _ = timer => break,
109                    }
110                }
111
112                Self::fetch_excerpts(this.clone(), buffer, position, cx).await?;
113            }
114            anyhow::Ok(())
115        })
116        .detach_and_log_err(cx);
117
118        RelatedExcerptStore {
119            project: project.downgrade(),
120            update_tx,
121            related_buffers: Vec::new(),
122            cache: Default::default(),
123            identifier_line_count: IDENTIFIER_LINE_COUNT,
124        }
125    }
126
    /// Sets how many rows above and below the cursor row are scanned for
    /// identifiers on subsequent refreshes.
    pub fn set_identifier_line_count(&mut self, count: u32) {
        self.identifier_line_count = count;
    }
130
131    pub fn refresh(&mut self, buffer: Entity<Buffer>, position: Anchor, _: &mut Context<Self>) {
132        self.update_tx.unbounded_send((buffer, position)).ok();
133    }
134
135    pub fn related_files(&mut self, cx: &App) -> Vec<RelatedFile> {
136        self.related_buffers
137            .iter_mut()
138            .map(|related| related.related_file(cx))
139            .collect()
140    }
141
142    pub fn related_files_with_buffers(&mut self, cx: &App) -> Vec<(RelatedFile, Entity<Buffer>)> {
143        self.related_buffers
144            .iter_mut()
145            .map(|related| (related.related_file(cx), related.buffer.clone()))
146            .collect::<Vec<_>>()
147    }
148
149    pub fn set_related_files(&mut self, files: Vec<RelatedFile>, cx: &App) {
150        self.related_buffers = files
151            .into_iter()
152            .filter_map(|file| {
153                let project = self.project.upgrade()?;
154                let project = project.read(cx);
155                let worktree = project.worktrees(cx).find(|wt| {
156                    let root_name = wt.read(cx).root_name().as_unix_str();
157                    file.path
158                        .components()
159                        .next()
160                        .is_some_and(|c| c.as_os_str() == root_name)
161                })?;
162                let worktree = worktree.read(cx);
163                let relative_path = file
164                    .path
165                    .strip_prefix(worktree.root_name().as_unix_str())
166                    .ok()?;
167                let relative_path = RelPath::new(relative_path, PathStyle::Posix).ok()?;
168                let project_path = ProjectPath {
169                    worktree_id: worktree.id(),
170                    path: relative_path.into_owned().into(),
171                };
172                let buffer = project.get_open_buffer(&project_path, cx)?;
173                let snapshot = buffer.read(cx).snapshot();
174                let anchor_ranges = file
175                    .excerpts
176                    .iter()
177                    .map(|excerpt| {
178                        let start = snapshot.anchor_before(Point::new(excerpt.row_range.start, 0));
179                        let end_row = excerpt.row_range.end;
180                        let end_col = snapshot.line_len(end_row);
181                        let end = snapshot.anchor_after(Point::new(end_row, end_col));
182                        start..end
183                    })
184                    .collect();
185                Some(RelatedBuffer {
186                    buffer,
187                    path: file.path.clone(),
188                    anchor_ranges,
189                    cached_file: None,
190                })
191            })
192            .collect();
193    }
194
    /// Recomputes the related excerpts for `position` in `buffer` and stores the
    /// result on the `RelatedExcerptStore`.
    ///
    /// Steps:
    /// 1. Emit `StartedRefresh`.
    /// 2. Collect identifiers near `position` on a background thread.
    /// 3. For each identifier, reuse the cached definitions or request them from
    ///    the project.
    /// 4. Rebuild the related buffers from the resulting definitions.
    /// 5. Replace the store's cache and related buffers, then emit
    ///    `FinishedRefresh` with cache and latency statistics.
    ///
    /// Returns `Ok(())` without doing anything if the project has been dropped.
    /// Returns an error if the store entity can no longer be accessed.
    async fn fetch_excerpts(
        this: WeakEntity<Self>,
        buffer: Entity<Buffer>,
        position: Anchor,
        cx: &mut AsyncApp,
    ) -> Result<()> {
        let (project, snapshot, identifier_line_count) = this.read_with(cx, |this, cx| {
            (
                this.project.upgrade(),
                buffer.read(cx).snapshot(),
                this.identifier_line_count,
            )
        })?;
        let Some(project) = project else {
            return Ok(());
        };

        let file = snapshot.file().cloned();
        if let Some(file) = &file {
            log::debug!("retrieving_context buffer:{}", file.path().as_unix_str());
        }

        this.update(cx, |_, cx| {
            cx.emit(RelatedExcerptStoreEvent::StartedRefresh);
        })?;

        // Identifier extraction runs off the main thread.
        let identifiers = cx
            .background_spawn(async move {
                identifiers_for_position(&snapshot, position, identifier_line_count)
            })
            .await;

        let async_cx = cx.clone();
        let start_time = Instant::now();
        let futures = this.update(cx, |this, cx| {
            identifiers
                .into_iter()
                .filter_map(|identifier| {
                    // Reuse the previous refresh's result for this identifier if there
                    // is one; otherwise start a definitions request. Identifiers are
                    // skipped if the project can no longer be updated.
                    let task = if let Some(entry) = this.cache.get(&identifier) {
                        DefinitionTask::CacheHit(entry.clone())
                    } else {
                        DefinitionTask::CacheMiss(
                            this.project
                                .update(cx, |project, cx| {
                                    project.definitions(&buffer, identifier.range.start, cx)
                                })
                                .ok()?,
                        )
                    };

                    let cx = async_cx.clone();
                    let project = project.clone();
                    Some(async move {
                        match task {
                            DefinitionTask::CacheHit(cache_entry) => {
                                Some((identifier, cache_entry, None))
                            }
                            DefinitionTask::CacheMiss(task) => {
                                // A failed request is logged and, like a request that
                                // yields no locations, drops the identifier.
                                let locations = task.await.log_err()??;
                                // Measured from the moment the requests were started.
                                let duration = start_time.elapsed();
                                Some(cx.update(|cx| {
                                    (
                                        identifier,
                                        Arc::new(CacheEntry {
                                            definitions: locations
                                                .into_iter()
                                                .filter_map(|location| {
                                                    process_definition(location, &project, cx)
                                                })
                                                .collect(),
                                        }),
                                        Some(duration),
                                    )
                                }))
                            }
                        }
                    })
                })
                .collect::<Vec<_>>()
        })?;

        let mut cache_hit_count = 0;
        let mut cache_miss_count = 0;
        let mut mean_definition_latency = Duration::ZERO;
        let mut max_definition_latency = Duration::ZERO;
        let mut new_cache = HashMap::default();
        new_cache.reserve(futures.len());
        // A present duration marks a cache miss; its absence marks a cache hit.
        for (identifier, entry, duration) in future::join_all(futures).await.into_iter().flatten() {
            new_cache.insert(identifier, entry);
            if let Some(duration) = duration {
                cache_miss_count += 1;
                mean_definition_latency += duration;
                max_definition_latency = max_definition_latency.max(duration);
            } else {
                cache_hit_count += 1;
            }
        }
        // `max(1)` avoids dividing by zero when there were no cache misses.
        mean_definition_latency /= cache_miss_count.max(1) as u32;

        let (new_cache, related_buffers) = rebuild_related_files(&project, new_cache, cx).await?;

        if let Some(file) = &file {
            log::debug!(
                "finished retrieving context buffer:{}, latency:{:?}",
                file.path().as_unix_str(),
                start_time.elapsed()
            );
        }

        // The new cache only contains identifiers seen in this refresh, so entries
        // for identifiers that are no longer nearby are discarded here.
        this.update(cx, |this, cx| {
            this.cache = new_cache;
            this.related_buffers = related_buffers;
            cx.emit(RelatedExcerptStoreEvent::FinishedRefresh {
                cache_hit_count,
                cache_miss_count,
                mean_definition_latency,
                max_definition_latency,
            });
        })?;

        anyhow::Ok(())
    }
317}
318
/// Builds the list of related buffers from the definitions in `new_entries`.
///
/// First, for every buffer that contains a definition, waits for
/// `parsing_idle()` and takes a snapshot, and records the root name of each
/// definition's worktree. Then, on a background thread, groups the definition
/// ranges by buffer, expands them into excerpt row ranges with
/// `assemble_excerpt_ranges`, and creates one `RelatedBuffer` per buffer with
/// its excerpt cache already filled.
///
/// Returns `new_entries` (handed back to the caller) together with the related
/// buffers sorted by path.
async fn rebuild_related_files(
    project: &Entity<Project>,
    mut new_entries: HashMap<Identifier, Arc<CacheEntry>>,
    cx: &mut AsyncApp,
) -> Result<(HashMap<Identifier, Arc<CacheEntry>>, Vec<RelatedBuffer>)> {
    let mut snapshots = HashMap::default();
    let mut worktree_root_names = HashMap::default();
    for entry in new_entries.values() {
        for definition in &entry.definitions {
            // Snapshot each buffer once, after waiting on `parsing_idle()`.
            if let hash_map::Entry::Vacant(e) = snapshots.entry(definition.buffer.entity_id()) {
                definition
                    .buffer
                    .read_with(cx, |buffer, _| buffer.parsing_idle())
                    .await;
                e.insert(
                    definition
                        .buffer
                        .read_with(cx, |buffer, _| buffer.snapshot()),
                );
            }
            // Record each worktree's root name once; it stays absent if the
            // worktree can no longer be found.
            let worktree_id = definition.path.worktree_id;
            if let hash_map::Entry::Vacant(e) =
                worktree_root_names.entry(definition.path.worktree_id)
            {
                project.read_with(cx, |project, cx| {
                    if let Some(worktree) = project.worktree_for_id(worktree_id, cx) {
                        e.insert(worktree.read(cx).root_name().as_unix_str().to_string());
                    }
                });
            }
        }
    }

    Ok(cx
        .background_spawn(async move {
            // Group the definition ranges (as points) by the buffer they belong to.
            let mut ranges_by_buffer =
                HashMap::<EntityId, (Entity<Buffer>, Vec<Range<Point>>)>::default();
            let mut paths_by_buffer = HashMap::default();
            for entry in new_entries.values_mut() {
                for definition in &entry.definitions {
                    let Some(snapshot) = snapshots.get(&definition.buffer.entity_id()) else {
                        continue;
                    };
                    paths_by_buffer.insert(definition.buffer.entity_id(), definition.path.clone());

                    ranges_by_buffer
                        .entry(definition.buffer.entity_id())
                        .or_insert_with(|| (definition.buffer.clone(), Vec::new()))
                        .1
                        .push(definition.anchor_range.to_point(snapshot));
                }
            }

            let mut related_buffers: Vec<RelatedBuffer> = ranges_by_buffer
                .into_iter()
                .filter_map(|(entity_id, (buffer, ranges))| {
                    let snapshot = snapshots.get(&entity_id)?;
                    let project_path = paths_by_buffer.get(&entity_id)?;
                    let row_ranges = assemble_excerpt_ranges(snapshot, ranges);
                    // Buffers whose worktree root name is unknown are dropped.
                    let root_name = worktree_root_names.get(&project_path.worktree_id)?;

                    // The reported path is "<worktree root name>/<path in worktree>".
                    let path: Arc<Path> = Path::new(&format!(
                        "{}/{}",
                        root_name,
                        project_path.path.as_unix_str()
                    ))
                    .into();

                    // Anchor each row range from the start of its first row to the
                    // end of the row given by `row_range.end`.
                    let anchor_ranges = row_ranges
                        .into_iter()
                        .map(|row_range| {
                            let start = snapshot.anchor_before(Point::new(row_range.start, 0));
                            let end_col = snapshot.line_len(row_range.end);
                            let end = snapshot.anchor_after(Point::new(row_range.end, end_col));
                            start..end
                        })
                        .collect();

                    let mut related_buffer = RelatedBuffer {
                        buffer,
                        path,
                        anchor_ranges,
                        cached_file: None,
                    };
                    // Compute the excerpts here, on the background thread.
                    related_buffer.fill_cache(snapshot);
                    Some(related_buffer)
                })
                .collect();

            related_buffers.sort_by_key(|related| related.path.clone());

            (new_entries, related_buffers)
        })
        .await)
}
414
415impl RelatedBuffer {
416    fn related_file(&mut self, cx: &App) -> RelatedFile {
417        let buffer = self.buffer.read(cx);
418        let path = self.path.clone();
419        let cached = if let Some(cached) = &self.cached_file
420            && buffer.version() == cached.buffer_version
421        {
422            cached
423        } else {
424            self.fill_cache(buffer)
425        };
426        let related_file = RelatedFile {
427            path,
428            excerpts: cached.excerpts.clone(),
429            max_row: buffer.max_point().row,
430        };
431        return related_file;
432    }
433
434    fn fill_cache(&mut self, buffer: &text::BufferSnapshot) -> &CachedRelatedFile {
435        let excerpts = self
436            .anchor_ranges
437            .iter()
438            .map(|range| {
439                let start = range.start.to_point(buffer);
440                let end = range.end.to_point(buffer);
441                RelatedExcerpt {
442                    row_range: start.row..end.row,
443                    text: buffer.text_for_range(start..end).collect::<String>().into(),
444                }
445            })
446            .collect::<Vec<_>>();
447        self.cached_file = Some(CachedRelatedFile {
448            excerpts: excerpts,
449            buffer_version: buffer.version().clone(),
450        });
451        self.cached_file.as_ref().unwrap()
452    }
453}
454
455use language::ToPoint as _;
456
/// Largest definition target range, measured as the length of its offset range,
/// that `process_definition` accepts; larger targets are discarded.
const MAX_TARGET_LEN: usize = 128;
458
459fn process_definition(
460    location: LocationLink,
461    project: &Entity<Project>,
462    cx: &mut App,
463) -> Option<CachedDefinition> {
464    let buffer = location.target.buffer.read(cx);
465    let anchor_range = location.target.range;
466    let file = buffer.file()?;
467    let worktree = project.read(cx).worktree_for_id(file.worktree_id(cx), cx)?;
468    if worktree.read(cx).is_single_file() {
469        return None;
470    }
471
472    // If the target range is large, it likely means we requested the definition of an entire module.
473    // For individual definitions, the target range should be small as it only covers the symbol.
474    let buffer = location.target.buffer.read(cx);
475    let target_len = anchor_range.to_offset(&buffer).len();
476    if target_len > MAX_TARGET_LEN {
477        return None;
478    }
479
480    Some(CachedDefinition {
481        path: ProjectPath {
482            worktree_id: file.worktree_id(cx),
483            path: file.path().clone(),
484        },
485        buffer: location.target.buffer,
486        anchor_range,
487    })
488}
489
490/// Gets all of the identifiers that are present in the given line, and its containing
491/// outline items.
492fn identifiers_for_position(
493    buffer: &BufferSnapshot,
494    position: Anchor,
495    identifier_line_count: u32,
496) -> Vec<Identifier> {
497    let offset = position.to_offset(buffer);
498    let point = buffer.offset_to_point(offset);
499
500    // Search for identifiers on lines adjacent to the cursor.
501    let start = Point::new(point.row.saturating_sub(identifier_line_count), 0);
502    let end = Point::new(point.row + identifier_line_count + 1, 0).min(buffer.max_point());
503    let line_range = start..end;
504    let mut ranges = vec![line_range.to_offset(&buffer)];
505
506    // Search for identifiers mentioned in headers/signatures of containing outline items.
507    let outline_items = buffer.outline_items_as_offsets_containing(offset..offset, false, None);
508    for item in outline_items {
509        if let Some(body_range) = item.body_range(&buffer) {
510            ranges.push(item.range.start..body_range.start.to_offset(&buffer));
511        } else {
512            ranges.push(item.range.clone());
513        }
514    }
515
516    ranges.sort_by(|a, b| a.start.cmp(&b.start).then(b.end.cmp(&a.end)));
517    ranges.dedup_by(|a, b| {
518        if a.start <= b.end {
519            b.start = b.start.min(a.start);
520            b.end = b.end.max(a.end);
521            true
522        } else {
523            false
524        }
525    });
526
527    let mut identifiers = Vec::new();
528    let outer_range =
529        ranges.first().map_or(0, |r| r.start)..ranges.last().map_or(buffer.len(), |r| r.end);
530
531    let mut captures = buffer
532        .syntax
533        .captures(outer_range.clone(), &buffer.text, |grammar| {
534            grammar
535                .highlights_config
536                .as_ref()
537                .map(|config| &config.query)
538        });
539
540    for range in ranges {
541        captures.set_byte_range(range.start..outer_range.end);
542
543        let mut last_range = None;
544        while let Some(capture) = captures.peek() {
545            let node_range = capture.node.byte_range();
546            if node_range.start > range.end {
547                break;
548            }
549            let config = captures.grammars()[capture.grammar_index]
550                .highlights_config
551                .as_ref();
552
553            if let Some(config) = config
554                && config.identifier_capture_indices.contains(&capture.index)
555                && range.contains_inclusive(&node_range)
556                && Some(&node_range) != last_range.as_ref()
557            {
558                let name = buffer.text_for_range(node_range.clone()).collect();
559                identifiers.push(Identifier {
560                    range: buffer.anchor_after(node_range.start)
561                        ..buffer.anchor_before(node_range.end),
562                    name,
563                });
564                last_range = Some(node_range);
565            }
566
567            captures.advance();
568        }
569    }
570
571    identifiers
572}