// related_excerpts.rs

  1use std::{cmp::Reverse, fmt::Write, ops::Range, path::PathBuf, sync::Arc};
  2
  3use crate::merge_excerpts::write_merged_excerpts;
  4use anyhow::{Result, anyhow};
  5use collections::HashMap;
  6use edit_prediction_context::{EditPredictionExcerpt, EditPredictionExcerptOptions, Line};
  7use futures::{StreamExt, stream::BoxStream};
  8use gpui::{App, AsyncApp, Entity, Task};
  9use indoc::indoc;
 10use language::{Anchor, Bias, Buffer, OffsetRangeExt, Point, TextBufferSnapshot, ToPoint as _};
 11use language_model::{
 12    LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent, LanguageModelId,
 13    LanguageModelRegistry, LanguageModelRequest, LanguageModelRequestMessage,
 14    LanguageModelRequestTool, LanguageModelToolResult, MessageContent, Role,
 15};
 16use project::{
 17    Project, WorktreeSettings,
 18    search::{SearchQuery, SearchResult},
 19};
 20use schemars::JsonSchema;
 21use serde::Deserialize;
 22use util::paths::{PathMatcher, PathStyle};
 23use workspace::item::Settings as _;
 24
// Prompt for turn 1 of the two-turn conversation. The placeholders `{edits}`,
// `{current_file_path}`, and `{cursor_excerpt}` are filled in via plain
// `str::replace` in `find_related_excerpts` before the request is sent.
const SEARCH_PROMPT: &str = indoc! {r#"
    ## Task

    You are part of an edit prediction system in a code editor. Your role is to identify relevant code locations
    that will serve as context for predicting the next required edit.

    **Your task:**
    - Analyze the user's recent edits and current cursor context
    - Use the `search` tool to find code that may be relevant for predicting the next edit
    - Focus on finding:
       - Code patterns that might need similar changes based on the recent edits
       - Functions, variables, types, and constants referenced in the current cursor context
       - Related implementations, usages, or dependencies that may require consistent updates

    **Important constraints:**
    - This conversation has exactly 2 turns
    - You must make ALL search queries in your first response via the `search` tool
    - All queries will be executed in parallel and results returned together
    - In the second turn, you will select the most relevant results via the `select` tool.

    ## User Edits

    {edits}

    ## Current cursor context

    `````filename={current_file_path}
    {cursor_excerpt}
    `````

    --
    Use the `search` tool now
"#};

// Name of the tool the model must call in turn 1 (see `SearchToolInput`).
const SEARCH_TOOL_NAME: &str = "search";
 60
// NOTE: the `///` doc comments on this type and its fields become the tool
// description / property descriptions in the JSON schema sent to the model
// (via `schemars`), so they are model-facing text — don't edit them casually.
/// Search for relevant code
///
/// For the best results, run multiple queries at once with a single invocation of this tool.
#[derive(Deserialize, JsonSchema)]
struct SearchToolInput {
    /// An array of queries to run for gathering context relevant to the next prediction
    #[schemars(length(max = 5))]
    queries: Box<[SearchToolQuery]>,
}
 70
// One query of a `search` tool call. Field doc comments are model-facing
// schema descriptions (see note on `SearchToolInput`). Executed by `run_query`.
#[derive(Deserialize, JsonSchema)]
struct SearchToolQuery {
    /// A glob pattern to match file paths in the codebase
    glob: String,
    /// A regular expression to match content within the files matched by the glob pattern
    regex: String,
    /// Whether the regex is case-sensitive. Defaults to false (case-insensitive).
    #[serde(default)]
    case_sensitive: bool,
}
 81
// Preamble prepended to the merged search results in the `search` tool result.
const RESULTS_MESSAGE: &str = indoc! {"
    Here are the results of your queries combined and grouped by file:

"};

// Name of the tool the model must call in turn 2 (see `SelectToolInput`).
const SELECT_TOOL_NAME: &str = "select";

// Final user message of turn 2, instructing the model to call `select`.
const SELECT_PROMPT: &str = indoc! {"
    Use the `select` tool now to pick the most relevant line ranges according to the user state provided in the first message.
    Make sure to include enough lines of context so that the edit prediction model can suggest accurate edits.
    Include up to 200 lines in total.
"};
 94
// Input of the `select` tool. Doc comments are model-facing schema text
// (see note on `SearchToolInput`).
/// Select line ranges from search results
#[derive(Deserialize, JsonSchema)]
struct SelectToolInput {
    /// The line ranges to select from search results.
    ranges: Vec<SelectLineRange>,
}
101
// NOTE(review): nothing enforces `end_line >= start_line`; an inverted pair
// from the model would flow through unvalidated — confirm downstream handling
// in `find_related_excerpts`. Doc comments below are model-facing schema text.
/// A specific line range to select from a file
#[derive(Debug, Deserialize, JsonSchema)]
struct SelectLineRange {
    /// The file path containing the lines to select
    /// Exactly as it appears in the search result codeblocks.
    path: PathBuf,
    /// The starting line number (1-based)
    #[schemars(range(min = 1))]
    start_line: u32,
    /// The ending line number (1-based, inclusive)
    #[schemars(range(min = 1))]
    end_line: u32,
}
115
/// Options controlling LLM-driven context gathering.
#[derive(Debug, Clone, PartialEq)]
pub struct LlmContextOptions {
    /// Sizing options for the excerpt taken around the cursor when building
    /// the search prompt.
    pub excerpt: EditPredictionExcerptOptions,
}
120
/// Gathers code excerpts related to the user's recent edits and cursor
/// position by driving a two-turn LLM conversation:
///
/// 1. The model is shown the edits and a cursor excerpt and must call the
///    `search` tool; every query is executed against the project.
/// 2. The merged results are sent back and the model must call the `select`
///    tool to pick line ranges, which are resolved to anchor ranges.
///
/// Returns a map from buffer to the selected anchor ranges within it, sorted
/// by position. Returns an empty map when no excerpt fits around the cursor
/// or when the searches produce no results; errors if the required model is
/// unavailable or a request/parse step fails.
pub fn find_related_excerpts<'a>(
    buffer: Entity<language::Buffer>,
    cursor_position: Anchor,
    project: &Entity<Project>,
    events: impl Iterator<Item = &'a crate::Event>,
    options: &LlmContextOptions,
    cx: &App,
) -> Task<Result<HashMap<Entity<Buffer>, Vec<Range<Anchor>>>>> {
    // Hard dependency on one specific Anthropic model being registered.
    let language_model_registry = LanguageModelRegistry::global(cx);
    let Some(model) = language_model_registry
        .read(cx)
        .available_models(cx)
        .find(|model| {
            model.provider_id() == language_model::ANTHROPIC_PROVIDER_ID
                && model.id() == LanguageModelId("claude-haiku-4-5-latest".into())
        })
    else {
        return Task::ready(Err(anyhow!("could not find claude model")));
    };

    // Render recent edit events into the prompt's `{edits}` section.
    let mut edits_string = String::new();

    for event in events {
        if let Some(event) = event.to_request_event(cx) {
            writeln!(&mut edits_string, "{event}").ok();
        }
    }

    if edits_string.is_empty() {
        edits_string.push_str("(No user edits yet)");
    }

    // TODO [zeta2] include breadcrumbs?
    let snapshot = buffer.read(cx).snapshot();
    let cursor_point = cursor_position.to_point(&snapshot);
    let Some(cursor_excerpt) =
        EditPredictionExcerpt::select_from_buffer(cursor_point, &snapshot, &options.excerpt, None)
    else {
        // No excerpt fits around the cursor: nothing to base a search on.
        return Task::ready(Ok(HashMap::default()));
    };

    let current_file_path = snapshot
        .file()
        .map(|f| f.full_path(cx).display().to_string())
        .unwrap_or_else(|| "untitled".to_string());

    let prompt = SEARCH_PROMPT
        .replace("{edits}", &edits_string)
        .replace("{current_file_path}", &current_file_path)
        .replace("{cursor_excerpt}", &cursor_excerpt.text(&snapshot).body);

    let path_style = project.read(cx).path_style(cx);

    // Never surface files the user excluded from scanning or marked private.
    let exclude_matcher = {
        let global_settings = WorktreeSettings::get_global(cx);
        let exclude_patterns = global_settings
            .file_scan_exclusions
            .sources()
            .iter()
            .chain(global_settings.private_files.sources().iter());

        match PathMatcher::new(exclude_patterns, path_style) {
            Ok(matcher) => matcher,
            Err(err) => {
                return Task::ready(Err(anyhow!(err)));
            }
        }
    };

    let project = project.clone();
    cx.spawn(async move |cx| {
        let initial_prompt_message = LanguageModelRequestMessage {
            role: Role::User,
            content: vec![prompt.into()],
            cache: false,
        };

        // Turn 1: stream the model's response and collect its `search` calls.
        let mut search_stream = request_tool_call::<SearchToolInput>(
            vec![initial_prompt_message.clone()],
            SEARCH_TOOL_NAME,
            &model,
            cx,
        )
        .await?;

        // Transcript replayed to the model in turn 2; assistant output from
        // turn 1 is accumulated into it as it streams in.
        let mut select_request_messages = Vec::with_capacity(5); // initial prompt, LLM response/thinking, tool use, tool result, select prompt
        select_request_messages.push(initial_prompt_message);
        let mut search_calls = Vec::new();

        while let Some(event) = search_stream.next().await {
            match event? {
                LanguageModelCompletionEvent::ToolUse(tool_use) => {
                    if !tool_use.is_input_complete {
                        // Partial tool-use events arrive while the input JSON
                        // is still streaming; wait for the complete one.
                        continue;
                    }

                    if tool_use.name.as_ref() == SEARCH_TOOL_NAME {
                        // Record the transcript position so the tool-use /
                        // tool-result pair can later be spliced in at the
                        // point where the call actually occurred.
                        search_calls.push((select_request_messages.len(), tool_use));
                    } else {
                        log::warn!(
                            "context gathering model tried to use unknown tool: {}",
                            tool_use.name
                        );
                    }
                }
                // The next three arms replay streamed assistant output into
                // `select_request_messages`, coalescing consecutive chunks of
                // the same content kind into the trailing assistant message.
                LanguageModelCompletionEvent::Text(txt) => {
                    if let Some(LanguageModelRequestMessage {
                        role: Role::Assistant,
                        content,
                        ..
                    }) = select_request_messages.last_mut()
                    {
                        if let Some(MessageContent::Text(existing_text)) = content.last_mut() {
                            existing_text.push_str(&txt);
                        } else {
                            content.push(MessageContent::Text(txt));
                        }
                    } else {
                        select_request_messages.push(LanguageModelRequestMessage {
                            role: Role::Assistant,
                            content: vec![MessageContent::Text(txt)],
                            cache: false,
                        });
                    }
                }
                LanguageModelCompletionEvent::Thinking { text, signature } => {
                    if let Some(LanguageModelRequestMessage {
                        role: Role::Assistant,
                        content,
                        ..
                    }) = select_request_messages.last_mut()
                    {
                        if let Some(MessageContent::Thinking {
                            text: existing_text,
                            signature: existing_signature,
                        }) = content.last_mut()
                        {
                            existing_text.push_str(&text);
                            // Later chunks carry the authoritative signature;
                            // keep only the most recent one.
                            *existing_signature = signature;
                        } else {
                            content.push(MessageContent::Thinking { text, signature });
                        }
                    } else {
                        select_request_messages.push(LanguageModelRequestMessage {
                            role: Role::Assistant,
                            content: vec![MessageContent::Thinking { text, signature }],
                            cache: false,
                        });
                    }
                }
                LanguageModelCompletionEvent::RedactedThinking { data } => {
                    if let Some(LanguageModelRequestMessage {
                        role: Role::Assistant,
                        content,
                        ..
                    }) = select_request_messages.last_mut()
                    {
                        if let Some(MessageContent::RedactedThinking(existing_data)) =
                            content.last_mut()
                        {
                            existing_data.push_str(&data);
                        } else {
                            content.push(MessageContent::RedactedThinking(data));
                        }
                    } else {
                        select_request_messages.push(LanguageModelRequestMessage {
                            role: Role::Assistant,
                            content: vec![MessageContent::RedactedThinking(data)],
                            cache: false,
                        });
                    }
                }
                ev @ LanguageModelCompletionEvent::ToolUseJsonParseError { .. } => {
                    log::error!("{ev:?}");
                }
                ev => {
                    log::trace!("context search event: {ev:?}")
                }
            }
        }

        // Buffers that appeared in search results, keyed by display path, so
        // that `select` line ranges can be resolved back to buffers later.
        struct ResultBuffer {
            buffer: Entity<Buffer>,
            snapshot: TextBufferSnapshot,
        }

        let mut result_buffers_by_path = HashMap::default();

        // Iterate in reverse so each recorded splice index remains valid as
        // message pairs are inserted into the transcript.
        for (index, tool_use) in search_calls.into_iter().rev() {
            let call = serde_json::from_value::<SearchToolInput>(tool_use.input.clone())?;

            let mut excerpts_by_buffer = HashMap::default();

            for query in call.queries {
                // TODO [zeta2] parallelize?

                run_query(
                    query,
                    &mut excerpts_by_buffer,
                    path_style,
                    exclude_matcher.clone(),
                    &project,
                    cx,
                )
                .await?;
            }

            if excerpts_by_buffer.is_empty() {
                // No results for this call: skip it entirely rather than
                // sending an empty tool result.
                continue;
            }

            // Build one combined tool-result string, grouped by file.
            let mut merged_result = RESULTS_MESSAGE.to_string();

            for (buffer_entity, mut excerpts_for_buffer) in excerpts_by_buffer {
                // Order excerpts by start line; ties broken by longest first,
                // which suits the overlap-merging writer below.
                excerpts_for_buffer.sort_unstable_by_key(|range| (range.start, Reverse(range.end)));

                buffer_entity
                    .clone()
                    .read_with(cx, |buffer, cx| {
                        let Some(file) = buffer.file() else {
                            // Untitled buffers have no path the model could
                            // reference in `select`; skip them.
                            return;
                        };

                        let path = file.full_path(cx);

                        writeln!(&mut merged_result, "`````filename={}", path.display()).unwrap();

                        let snapshot = buffer.snapshot();

                        write_merged_excerpts(
                            &snapshot,
                            excerpts_for_buffer,
                            &[],
                            &mut merged_result,
                        );

                        merged_result.push_str("`````\n\n");

                        result_buffers_by_path.insert(
                            path,
                            ResultBuffer {
                                buffer: buffer_entity,
                                snapshot: snapshot.text,
                            },
                        );
                    })
                    .ok();
            }

            let tool_result = LanguageModelToolResult {
                tool_use_id: tool_use.id.clone(),
                tool_name: SEARCH_TOOL_NAME.into(),
                is_error: false,
                content: merged_result.into(),
                output: None,
            };

            // Almost always appends at the end, but in theory, the model could return some text after the tool call
            // or perform parallel tool calls, so we splice at the message index for correctness.
            select_request_messages.splice(
                index..index,
                [
                    LanguageModelRequestMessage {
                        role: Role::Assistant,
                        content: vec![MessageContent::ToolUse(tool_use)],
                        cache: false,
                    },
                    LanguageModelRequestMessage {
                        role: Role::User,
                        content: vec![MessageContent::ToolResult(tool_result)],
                        cache: false,
                    },
                ],
            );
        }

        if result_buffers_by_path.is_empty() {
            log::trace!("context gathering queries produced no results");
            return anyhow::Ok(HashMap::default());
        }

        select_request_messages.push(LanguageModelRequestMessage {
            role: Role::User,
            content: vec![SELECT_PROMPT.into()],
            cache: false,
        });

        // Turn 2: ask the model to pick line ranges from the results above.
        let mut select_stream = request_tool_call::<SelectToolInput>(
            select_request_messages,
            SELECT_TOOL_NAME,
            &model,
            cx,
        )
        .await?;
        let mut selected_ranges = Vec::new();

        while let Some(event) = select_stream.next().await {
            match event? {
                LanguageModelCompletionEvent::ToolUse(tool_use) => {
                    if !tool_use.is_input_complete {
                        continue;
                    }

                    if tool_use.name.as_ref() == SELECT_TOOL_NAME {
                        let call =
                            serde_json::from_value::<SelectToolInput>(tool_use.input.clone())?;
                        selected_ranges.extend(call.ranges);
                    } else {
                        log::warn!(
                            "context gathering model tried to use unknown tool: {}",
                            tool_use.name
                        );
                    }
                }
                ev @ LanguageModelCompletionEvent::ToolUseJsonParseError { .. } => {
                    log::error!("{ev:?}");
                }
                ev => {
                    log::trace!("context select event: {ev:?}")
                }
            }
        }

        if selected_ranges.is_empty() {
            log::trace!("context gathering selected no ranges")
        }

        let mut related_excerpts_by_buffer: HashMap<_, Vec<_>> = HashMap::default();

        // Resolve each selected 1-based inclusive line range to an anchor
        // range in the corresponding result buffer.
        for selected_range in selected_ranges {
            if let Some(ResultBuffer { buffer, snapshot }) =
                result_buffers_by_path.get(&selected_range.path)
            {
                // Start of `start_line` through start of the line after
                // `end_line`, clipped to the buffer end.
                // NOTE(review): `end_line < start_line` isn't validated and
                // would yield an inverted range — confirm downstream handling.
                let start_point = Point::new(selected_range.start_line.saturating_sub(1), 0);
                let end_point =
                    snapshot.clip_point(Point::new(selected_range.end_line, 0), Bias::Left);
                let range = snapshot.anchor_after(start_point)..snapshot.anchor_before(end_point);

                related_excerpts_by_buffer
                    .entry(buffer.clone())
                    .or_default()
                    .push(range);
            } else {
                // The model can hallucinate paths that never appeared in the
                // search results; drop them with a warning.
                log::warn!(
                    "selected path that wasn't included in search results: {}",
                    selected_range.path.display()
                );
            }
        }

        // Sort each buffer's ranges by start position, ties by longest first.
        for (buffer, ranges) in &mut related_excerpts_by_buffer {
            buffer.read_with(cx, |buffer, _cx| {
                ranges.sort_unstable_by(|a, b| {
                    a.start
                        .cmp(&b.start, buffer)
                        .then(b.end.cmp(&a.end, buffer))
                });
            })?;
        }

        anyhow::Ok(related_excerpts_by_buffer)
    })
}
484
485async fn request_tool_call<T: JsonSchema>(
486    messages: Vec<LanguageModelRequestMessage>,
487    tool_name: &'static str,
488    model: &Arc<dyn LanguageModel>,
489    cx: &mut AsyncApp,
490) -> Result<BoxStream<'static, Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>>
491{
492    let schema = schemars::schema_for!(T);
493
494    let request = LanguageModelRequest {
495        messages,
496        tools: vec![LanguageModelRequestTool {
497            name: tool_name.into(),
498            description: schema
499                .get("description")
500                .and_then(|description| description.as_str())
501                .unwrap()
502                .to_string(),
503            input_schema: serde_json::to_value(schema).unwrap(),
504        }],
505        ..Default::default()
506    };
507
508    Ok(model.stream_completion(request, cx).await?)
509}
510
/// Smallest excerpt (in bytes) worth emitting around a search match.
const MIN_EXCERPT_LEN: usize = 16;
/// Largest single excerpt (in bytes) taken around one search match.
const MAX_EXCERPT_LEN: usize = 768;
/// Byte budget for all excerpts produced by a single search query.
const MAX_RESULT_BYTES_PER_QUERY: usize = MAX_EXCERPT_LEN * 5;
514
/// Executes one `search` tool query against the project, appending the line
/// ranges of excerpts found around each match into `excerpts_by_buffer`.
///
/// Output is capped at `MAX_RESULT_BYTES_PER_QUERY` bytes per query; each
/// match contributes at most one excerpt of up to `MAX_EXCERPT_LEN` bytes
/// centered on the match.
async fn run_query(
    args: SearchToolQuery,
    excerpts_by_buffer: &mut HashMap<Entity<Buffer>, Vec<Range<Line>>>,
    path_style: PathStyle,
    exclude_matcher: PathMatcher,
    project: &Entity<Project>,
    cx: &mut AsyncApp,
) -> Result<()> {
    let include_matcher = PathMatcher::new(vec![args.glob], path_style)?;

    // NOTE(review): several positional boolean flags here — verify their
    // meaning against `SearchQuery::regex`'s signature before changing any.
    let query = SearchQuery::regex(
        &args.regex,
        false,
        args.case_sensitive,
        false,
        true,
        include_matcher,
        exclude_matcher,
        true,
        None,
    )?;

    let results = project.update(cx, |project, cx| project.search(query, cx))?;
    futures::pin_mut!(results);

    // Bytes of excerpt text accumulated for this query across all buffers.
    let mut total_bytes = 0;

    // NOTE(review): any stream item that isn't `SearchResult::Buffer` ends
    // this loop early — confirm the search stream only yields `Buffer`
    // results (or only yields other variants at the end).
    while let Some(SearchResult::Buffer { buffer, ranges }) = results.next().await {
        if ranges.is_empty() {
            continue;
        }

        let excerpts_for_buffer = excerpts_by_buffer
            .entry(buffer.clone())
            .or_insert_with(|| Vec::with_capacity(ranges.len()));

        let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot())?;

        for range in ranges {
            let offset_range = range.to_offset(&snapshot);
            // Center the excerpt on the midpoint of the match.
            let query_point = (offset_range.start + offset_range.len() / 2).to_point(&snapshot);

            // Stop once the remaining budget can't fit even a minimal excerpt.
            // `total_bytes` persists across buffers, so once exhausted, later
            // buffers break out immediately too.
            if total_bytes + MIN_EXCERPT_LEN >= MAX_RESULT_BYTES_PER_QUERY {
                break;
            }

            let excerpt = EditPredictionExcerpt::select_from_buffer(
                query_point,
                &snapshot,
                &EditPredictionExcerptOptions {
                    // Never exceed the remaining per-query budget.
                    max_bytes: MAX_EXCERPT_LEN.min(MAX_RESULT_BYTES_PER_QUERY - total_bytes),
                    min_bytes: MIN_EXCERPT_LEN,
                    target_before_cursor_over_total_bytes: 0.5,
                },
                None,
            );

            if let Some(excerpt) = excerpt {
                total_bytes += excerpt.range.len();
                if !excerpt.line_range.is_empty() {
                    excerpts_for_buffer.push(excerpt.line_range);
                }
            }
        }

        // Drop the map entry if nothing usable was collected for this buffer.
        if excerpts_for_buffer.is_empty() {
            excerpts_by_buffer.remove(&buffer);
        }
    }

    anyhow::Ok(())
}