use crate::assemble_excerpts::assemble_excerpts;
use anyhow::Result;
use collections::HashMap;
use futures::{FutureExt, StreamExt as _, channel::mpsc, future};
use gpui::{App, AppContext, AsyncApp, Context, Entity, EventEmitter, Task, WeakEntity};
use language::{Anchor, Buffer, BufferSnapshot, OffsetRangeExt as _, Point, Rope, ToOffset as _};
use project::{LocationLink, Project, ProjectPath};
use serde::{Serialize, Serializer};
use smallvec::SmallVec;
use std::{
    collections::hash_map,
    ops::Range,
    sync::Arc,
    time::{Duration, Instant},
};
use util::{RangeExt as _, ResultExt};

mod assemble_excerpts;
#[cfg(test)]
mod edit_prediction_context_tests;
mod excerpt;
#[cfg(test)]
mod fake_definition_lsp;

pub use cloud_llm_client::predict_edits_v3::Line;
pub use excerpt::{EditPredictionExcerpt, EditPredictionExcerptOptions, EditPredictionExcerptText};

const IDENTIFIER_LINE_COUNT: u32 = 3;
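
/// Maintains excerpts of files related to a cursor position by looking up
/// LSP definitions for identifiers near the cursor and caching the results.
/// Refreshes are debounced, and `RelatedExcerptStoreEvent`s report cache
/// hit/miss counts and definition-lookup latency.
///
/// Illustrative usage (a sketch, not a doctest; assumes `project`, `buffer`,
/// and `position` are already in scope):
///
/// ```ignore
/// let store = cx.new(|cx| RelatedExcerptStore::new(&project, cx));
/// store.update(cx, |store, cx| store.refresh(buffer, position, cx));
/// // Once RelatedExcerptStoreEvent::FinishedRefresh is emitted:
/// let related = store.read(cx).related_files();
/// ```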
pub struct RelatedExcerptStore {
    project: WeakEntity<Project>,
    related_files: Vec<RelatedFile>,
    cache: HashMap<Identifier, Arc<CacheEntry>>,
    update_tx: mpsc::UnboundedSender<(Entity<Buffer>, Anchor)>,
    identifier_line_count: u32,
}

pub enum RelatedExcerptStoreEvent {
    StartedRefresh,
    FinishedRefresh {
        cache_hit_count: usize,
        cache_miss_count: usize,
        mean_definition_latency: Duration,
        max_definition_latency: Duration,
    },
}

#[derive(Clone, Debug, PartialEq, Eq, Hash)]
struct Identifier {
    pub name: String,
    pub range: Range<Anchor>,
}

enum DefinitionTask {
    CacheHit(Arc<CacheEntry>),
    CacheMiss(Task<Result<Option<Vec<LocationLink>>>>),
}

#[derive(Debug)]
struct CacheEntry {
    definitions: SmallVec<[CachedDefinition; 1]>,
}

#[derive(Clone, Debug)]
struct CachedDefinition {
    path: ProjectPath,
    buffer: Entity<Buffer>,
    anchor_range: Range<Anchor>,
}

#[derive(Clone, Debug, Serialize)]
pub struct RelatedFile {
    #[serde(serialize_with = "serialize_project_path")]
    pub path: ProjectPath,
    #[serde(skip)]
    pub buffer: WeakEntity<Buffer>,
    pub excerpts: Vec<RelatedExcerpt>,
    pub max_row: u32,
}

impl RelatedFile {
    /// Sorts excerpts by position and merges any that overlap or abut, so the
    /// excerpt list ends up ordered and non-overlapping.
    pub fn merge_excerpts(&mut self) {
        // Sort by ascending start; ties are broken by descending end, so a
        // range that contains another comes first.
        self.excerpts.sort_unstable_by(|a, b| {
            a.point_range
                .start
                .cmp(&b.point_range.start)
                .then(b.point_range.end.cmp(&a.point_range.end))
        });

        let mut index = 1;
        while index < self.excerpts.len() {
            if self.excerpts[index - 1]
                .point_range
                .end
                .cmp(&self.excerpts[index].point_range.start)
                .is_ge()
            {
                // The previous excerpt reaches the start of this one: remove
                // this one, extending the previous excerpt if needed.
                let removed = self.excerpts.remove(index);
                if removed
                    .point_range
                    .end
                    .cmp(&self.excerpts[index - 1].point_range.end)
                    .is_gt()
                {
                    self.excerpts[index - 1].point_range.end = removed.point_range.end;
                    self.excerpts[index - 1].anchor_range.end = removed.anchor_range.end;
                }
            } else {
                index += 1;
            }
        }
    }
}
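
// A minimal standalone sketch of the coalescing performed by
// `RelatedFile::merge_excerpts`, applied to plain `Range<u32>` row ranges
// instead of excerpts (illustrative only; these names are not part of the
// real API).
#[cfg(test)]
mod merge_excerpts_sketch {
    use std::ops::Range;

    fn merge(ranges: &mut Vec<Range<u32>>) {
        // Same ordering as `merge_excerpts`: ascending start, descending end.
        ranges.sort_unstable_by(|a, b| a.start.cmp(&b.start).then(b.end.cmp(&a.end)));
        let mut index = 1;
        while index < ranges.len() {
            if ranges[index - 1].end >= ranges[index].start {
                let removed = ranges.remove(index);
                if removed.end > ranges[index - 1].end {
                    ranges[index - 1].end = removed.end;
                }
            } else {
                index += 1;
            }
        }
    }

    #[test]
    fn overlapping_and_abutting_ranges_merge() {
        let mut ranges = vec![10..12, 0..5, 3..8, 8..9];
        merge(&mut ranges);
        assert_eq!(ranges, vec![0..9, 10..12]);
    }
}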

#[derive(Clone, Debug, Serialize)]
pub struct RelatedExcerpt {
    #[serde(skip)]
    pub anchor_range: Range<Anchor>,
    #[serde(serialize_with = "serialize_point_range")]
    pub point_range: Range<Point>,
    #[serde(serialize_with = "serialize_rope")]
    pub text: Rope,
}

fn serialize_project_path<S: Serializer>(
    project_path: &ProjectPath,
    serializer: S,
) -> Result<S::Ok, S::Error> {
    project_path.path.serialize(serializer)
}

fn serialize_rope<S: Serializer>(rope: &Rope, serializer: S) -> Result<S::Ok, S::Error> {
    rope.to_string().serialize(serializer)
}

fn serialize_point_range<S: Serializer>(
    range: &Range<Point>,
    serializer: S,
) -> Result<S::Ok, S::Error> {
    [
        [range.start.row, range.start.column],
        [range.end.row, range.end.column],
    ]
    .serialize(serializer)
}
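
// With these helpers, a `RelatedFile` serializes roughly as follows (an
// assumed JSON rendering, shown only for illustration):
//   {
//     "path": "src/main.rs",
//     "excerpts": [{ "point_range": [[0, 0], [3, 0]], "text": "fn main() {\n..." }],
//     "max_row": 120
//   }
// `point_range` is `[[start_row, start_column], [end_row, end_column]]`; the
// buffer handle and anchor ranges are skipped entirely.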

const DEBOUNCE_DURATION: Duration = Duration::from_millis(100);

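// The refresh loop in `RelatedExcerptStore::new` below debounces updates by
// racing the update channel against a timer with `select_biased!`, which
// polls its arms in source order. A minimal standalone sketch of that
// priority, assuming only the `futures` crate (the real loop uses gpui's
// background executor for its timer):
#[cfg(test)]
mod select_biased_sketch {
    use futures::{executor::block_on, future, select_biased};

    #[test]
    fn earlier_arm_wins_when_both_are_ready() {
        block_on(async {
            let winner = select_biased! {
                // Both futures are immediately ready; `select_biased!` picks
                // the first arm, just as a pending update beats the timer.
                x = future::ready("update") => x,
                y = future::ready("timer") => y,
            };
            assert_eq!(winner, "update");
        });
    }
}
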
impl EventEmitter<RelatedExcerptStoreEvent> for RelatedExcerptStore {}

impl RelatedExcerptStore {
    pub fn new(project: &Entity<Project>, cx: &mut Context<Self>) -> Self {
        let (update_tx, mut update_rx) = mpsc::unbounded::<(Entity<Buffer>, Anchor)>();
        cx.spawn(async move |this, cx| {
            let executor = cx.background_executor().clone();
            while let Some((mut buffer, mut position)) = update_rx.next().await {
                // Debounce: restart the timer whenever a newer update arrives,
                // and only fetch excerpts once the channel has been quiet for
                // DEBOUNCE_DURATION.
                let mut timer = executor.timer(DEBOUNCE_DURATION).fuse();
                loop {
                    futures::select_biased! {
                        next = update_rx.next() => {
                            if let Some((new_buffer, new_position)) = next {
                                buffer = new_buffer;
                                position = new_position;
                                timer = executor.timer(DEBOUNCE_DURATION).fuse();
                            } else {
                                return anyhow::Ok(());
                            }
                        }
                        _ = timer => break,
                    }
                }

                Self::fetch_excerpts(this.clone(), buffer, position, cx).await?;
            }
            anyhow::Ok(())
        })
        .detach_and_log_err(cx);

        RelatedExcerptStore {
            project: project.downgrade(),
            update_tx,
            related_files: Vec::new(),
            cache: Default::default(),
            identifier_line_count: IDENTIFIER_LINE_COUNT,
        }
    }

    pub fn set_identifier_line_count(&mut self, count: u32) {
        self.identifier_line_count = count;
    }

    pub fn refresh(&mut self, buffer: Entity<Buffer>, position: Anchor, _: &mut Context<Self>) {
        self.update_tx.unbounded_send((buffer, position)).ok();
    }

    pub fn related_files(&self) -> &[RelatedFile] {
        &self.related_files
    }

    async fn fetch_excerpts(
        this: WeakEntity<Self>,
        buffer: Entity<Buffer>,
        position: Anchor,
        cx: &mut AsyncApp,
    ) -> Result<()> {
        let (project, snapshot, identifier_line_count) = this.read_with(cx, |this, cx| {
            (
                this.project.upgrade(),
                buffer.read(cx).snapshot(),
                this.identifier_line_count,
            )
        })?;
        let Some(project) = project else {
            return Ok(());
        };

        let file = snapshot.file().cloned();
        if let Some(file) = &file {
            log::debug!("retrieving context buffer:{}", file.path().as_unix_str());
        }

        this.update(cx, |_, cx| {
            cx.emit(RelatedExcerptStoreEvent::StartedRefresh);
        })?;

        let identifiers = cx
            .background_spawn(async move {
                identifiers_for_position(&snapshot, position, identifier_line_count)
            })
            .await;

        let async_cx = cx.clone();
        let start_time = Instant::now();
        let futures = this.update(cx, |this, cx| {
            identifiers
                .into_iter()
                .filter_map(|identifier| {
                    // Reuse cached definitions when possible; otherwise ask the
                    // project (and ultimately the language server) for them.
                    let task = if let Some(entry) = this.cache.get(&identifier) {
                        DefinitionTask::CacheHit(entry.clone())
                    } else {
                        DefinitionTask::CacheMiss(
                            this.project
                                .update(cx, |project, cx| {
                                    project.definitions(&buffer, identifier.range.start, cx)
                                })
                                .ok()?,
                        )
                    };

                    let cx = async_cx.clone();
                    let project = project.clone();
                    Some(async move {
                        match task {
                            DefinitionTask::CacheHit(cache_entry) => {
                                Some((identifier, cache_entry, None))
                            }
                            DefinitionTask::CacheMiss(task) => {
                                let locations = task.await.log_err()??;
                                let duration = start_time.elapsed();
                                cx.update(|cx| {
                                    (
                                        identifier,
                                        Arc::new(CacheEntry {
                                            definitions: locations
                                                .into_iter()
                                                .filter_map(|location| {
                                                    process_definition(location, &project, cx)
                                                })
                                                .collect(),
                                        }),
                                        Some(duration),
                                    )
                                })
                                .ok()
                            }
                        }
                    })
                })
                .collect::<Vec<_>>()
        })?;

        let mut cache_hit_count = 0;
        let mut cache_miss_count = 0;
        let mut mean_definition_latency = Duration::ZERO;
        let mut max_definition_latency = Duration::ZERO;
        let mut new_cache = HashMap::default();
        new_cache.reserve(futures.len());
        for (identifier, entry, duration) in future::join_all(futures).await.into_iter().flatten() {
            new_cache.insert(identifier, entry);
            if let Some(duration) = duration {
                cache_miss_count += 1;
                mean_definition_latency += duration;
                max_definition_latency = max_definition_latency.max(duration);
            } else {
                cache_hit_count += 1;
            }
        }
        // Guard against dividing by zero when every identifier was cached.
        mean_definition_latency /= cache_miss_count.max(1) as u32;

        let (new_cache, related_files) = rebuild_related_files(new_cache, cx).await?;

        if let Some(file) = &file {
            log::debug!(
                "finished retrieving context buffer:{}, latency:{:?}",
                file.path().as_unix_str(),
                start_time.elapsed()
            );
        }

        this.update(cx, |this, cx| {
            this.cache = new_cache;
            this.related_files = related_files;
            cx.emit(RelatedExcerptStoreEvent::FinishedRefresh {
                cache_hit_count,
                cache_miss_count,
                mean_definition_latency,
                max_definition_latency,
            });
        })?;

        anyhow::Ok(())
    }
}
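
// Sketch of the latency bookkeeping in `fetch_excerpts`: the mean is the sum
// of cache-miss lookup durations divided by the miss count (clamped to 1 to
// avoid dividing by zero), and the max is tracked alongside. A standalone,
// std-only illustration:
#[cfg(test)]
mod latency_accounting_sketch {
    use std::time::Duration;

    #[test]
    fn mean_and_max_over_miss_durations() {
        let miss_durations = [Duration::from_millis(10), Duration::from_millis(30)];
        let mut mean = Duration::ZERO;
        let mut max = Duration::ZERO;
        for duration in miss_durations {
            mean += duration;
            max = max.max(duration);
        }
        mean /= (miss_durations.len() as u32).max(1);
        assert_eq!(mean, Duration::from_millis(20));
        assert_eq!(max, Duration::from_millis(30));
    }
}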

async fn rebuild_related_files(
    new_entries: HashMap<Identifier, Arc<CacheEntry>>,
    cx: &mut AsyncApp,
) -> Result<(HashMap<Identifier, Arc<CacheEntry>>, Vec<RelatedFile>)> {
    // Wait for parsing to settle, then capture one snapshot per buffer that
    // appears in the definitions.
    let mut snapshots = HashMap::default();
    for entry in new_entries.values() {
        for definition in &entry.definitions {
            if let hash_map::Entry::Vacant(e) = snapshots.entry(definition.buffer.entity_id()) {
                definition
                    .buffer
                    .read_with(cx, |buffer, _| buffer.parsing_idle())?
                    .await;
                e.insert(
                    definition
                        .buffer
                        .read_with(cx, |buffer, _| buffer.snapshot())?,
                );
            }
        }
    }

    Ok(cx
        .background_spawn(async move {
            // Group definition ranges by buffer, then assemble the excerpts
            // for each buffer off the main thread.
            let mut files = Vec::<RelatedFile>::new();
            let mut ranges_by_buffer = HashMap::<_, Vec<Range<Point>>>::default();
            let mut paths_by_buffer = HashMap::default();
            for entry in new_entries.values() {
                for definition in &entry.definitions {
                    let Some(snapshot) = snapshots.get(&definition.buffer.entity_id()) else {
                        continue;
                    };
                    paths_by_buffer.insert(definition.buffer.entity_id(), definition.path.clone());
                    ranges_by_buffer
                        .entry(definition.buffer.clone())
                        .or_default()
                        .push(definition.anchor_range.to_point(snapshot));
                }
            }

            for (buffer, ranges) in ranges_by_buffer {
                let Some(snapshot) = snapshots.get(&buffer.entity_id()) else {
                    continue;
                };
                let Some(project_path) = paths_by_buffer.get(&buffer.entity_id()) else {
                    continue;
                };
                let excerpts = assemble_excerpts(snapshot, ranges);
                files.push(RelatedFile {
                    path: project_path.clone(),
                    buffer: buffer.downgrade(),
                    excerpts,
                    max_row: snapshot.max_point().row,
                });
            }

            files.sort_by_key(|file| file.path.clone());
            (new_entries, files)
        })
        .await)
}
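
// The grouping above uses the map-entry idiom: one `Vec` of ranges per
// buffer, created on first touch. A standalone sketch with plain keys and
// offset ranges (illustrative names only):
#[cfg(test)]
mod group_by_buffer_sketch {
    use std::collections::HashMap;
    use std::ops::Range;

    #[test]
    fn ranges_group_by_key() {
        let definitions = [("a.rs", 0..4), ("b.rs", 2..6), ("a.rs", 10..12)];
        let mut ranges_by_file: HashMap<&str, Vec<Range<u32>>> = HashMap::new();
        for (file, range) in definitions {
            ranges_by_file.entry(file).or_default().push(range);
        }
        assert_eq!(ranges_by_file["a.rs"], vec![0..4, 10..12]);
        assert_eq!(ranges_by_file["b.rs"], vec![2..6]);
    }
}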

fn process_definition(
    location: LocationLink,
    project: &Entity<Project>,
    cx: &mut App,
) -> Option<CachedDefinition> {
    let buffer = location.target.buffer.read(cx);
    let anchor_range = location.target.range;
    let file = buffer.file()?;
    let worktree = project.read(cx).worktree_for_id(file.worktree_id(cx), cx)?;
    if worktree.read(cx).is_single_file() {
        return None;
    }
    Some(CachedDefinition {
        path: ProjectPath {
            worktree_id: file.worktree_id(cx),
            path: file.path().clone(),
        },
        buffer: location.target.buffer,
        anchor_range,
    })
}

/// Gets all of the identifiers present on the lines within
/// `identifier_line_count` rows of the given position, as well as in the
/// headers/signatures of the outline items containing it.
fn identifiers_for_position(
    buffer: &BufferSnapshot,
    position: Anchor,
    identifier_line_count: u32,
) -> Vec<Identifier> {
    let offset = position.to_offset(buffer);
    let point = buffer.offset_to_point(offset);

    // Search for identifiers on lines adjacent to the cursor.
    let start = Point::new(point.row.saturating_sub(identifier_line_count), 0);
    let end = Point::new(point.row + identifier_line_count + 1, 0).min(buffer.max_point());
    let line_range = start..end;
    let mut ranges = vec![line_range.to_offset(&buffer)];

    // Search for identifiers mentioned in headers/signatures of containing outline items.
    let outline_items = buffer.outline_items_as_offsets_containing(offset..offset, false, None);
    for item in outline_items {
        if let Some(body_range) = item.body_range(&buffer) {
            ranges.push(item.range.start..body_range.start.to_offset(&buffer));
        } else {
            ranges.push(item.range.clone());
        }
    }

    // Sort the ranges and coalesce any that overlap, so each identifier is
    // visited at most once below.
    ranges.sort_by(|a, b| a.start.cmp(&b.start).then(b.end.cmp(&a.end)));
    ranges.dedup_by(|a, b| {
        if a.start <= b.end {
            b.start = b.start.min(a.start);
            b.end = b.end.max(a.end);
            true
        } else {
            false
        }
    });

    let mut identifiers = Vec::new();
    let outer_range =
        ranges.first().map_or(0, |r| r.start)..ranges.last().map_or(buffer.len(), |r| r.end);

    let mut captures = buffer
        .syntax
        .captures(outer_range.clone(), &buffer.text, |grammar| {
            grammar
                .highlights_config
                .as_ref()
                .map(|config| &config.query)
        });

    for range in ranges {
        captures.set_byte_range(range.start..outer_range.end);

        let mut last_range = None;
        while let Some(capture) = captures.peek() {
            let node_range = capture.node.byte_range();
            if node_range.start > range.end {
                break;
            }
            let config = captures.grammars()[capture.grammar_index]
                .highlights_config
                .as_ref();

            if let Some(config) = config
                && config.identifier_capture_indices.contains(&capture.index)
                && range.contains_inclusive(&node_range)
                && Some(&node_range) != last_range.as_ref()
            {
                let name = buffer.text_for_range(node_range.clone()).collect();
                identifiers.push(Identifier {
                    range: buffer.anchor_after(node_range.start)
                        ..buffer.anchor_before(node_range.end),
                    name,
                });
                last_range = Some(node_range);
            }

            captures.advance();
        }
    }

    identifiers
}
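
// A standalone, std-only sketch of the `sort` + `dedup_by` idiom above, which
// coalesces overlapping offset ranges in place (illustrative; not part of the
// real API):
#[cfg(test)]
mod range_coalescing_sketch {
    #[test]
    fn overlapping_ranges_coalesce() {
        let mut ranges = vec![0..5, 10..12, 3..8];
        ranges.sort_by(|a, b| a.start.cmp(&b.start).then(b.end.cmp(&a.end)));
        // `dedup_by` passes each element along with the previously retained
        // one; returning `true` drops the current element after widening the
        // previous one to cover it.
        ranges.dedup_by(|a, b| {
            if a.start <= b.end {
                b.start = b.start.min(a.start);
                b.end = b.end.max(a.end);
                true
            } else {
                false
            }
        });
        assert_eq!(ranges, vec![0..8, 10..12]);
    }
}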