use std::{
    cmp::Reverse,
    collections::hash_map::Entry,
    fmt::Write,
    ops::Range,
    path::PathBuf,
    sync::{
        Arc,
        atomic::{AtomicUsize, Ordering},
    },
    time::Instant,
};
5
6use crate::{
7 ZetaContextRetrievalDebugInfo, ZetaDebugInfo, ZetaSearchQueryDebugInfo,
8 merge_excerpts::write_merged_excerpts,
9};
10use anyhow::{Result, anyhow};
11use collections::HashMap;
12use edit_prediction_context::{EditPredictionExcerpt, EditPredictionExcerptOptions, Line};
13use futures::{
14 StreamExt,
15 channel::mpsc::{self, UnboundedSender},
16 stream::BoxStream,
17};
18use gpui::{App, AppContext, AsyncApp, Entity, Task};
19use indoc::indoc;
20use language::{
21 Anchor, Bias, Buffer, BufferSnapshot, OffsetRangeExt, Point, TextBufferSnapshot, ToPoint as _,
22};
23use language_model::{
24 LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent, LanguageModelId,
25 LanguageModelRegistry, LanguageModelRequest, LanguageModelRequestMessage,
26 LanguageModelRequestTool, LanguageModelToolResult, LanguageModelToolUse, MessageContent, Role,
27};
28use project::{
29 Project, WorktreeSettings,
30 search::{SearchQuery, SearchResult},
31};
32use schemars::JsonSchema;
33use serde::{Deserialize, Serialize};
34use util::{
35 ResultExt as _,
36 paths::{PathMatcher, PathStyle},
37};
38use workspace::item::Settings as _;
39
/// First-turn prompt. Asks the model to issue ALL of its `search` tool calls
/// in a single response, given the user's recent edits and an excerpt around
/// the cursor. The `{edits}`, `{current_file_path}`, and `{cursor_excerpt}`
/// placeholders are substituted in `find_related_excerpts` before sending.
const SEARCH_PROMPT: &str = indoc! {r#"
    ## Task

    You are part of an edit prediction system in a code editor. Your role is to identify relevant code locations
    that will serve as context for predicting the next required edit.

    **Your task:**
    - Analyze the user's recent edits and current cursor context
    - Use the `search` tool to find code that may be relevant for predicting the next edit
    - Focus on finding:
      - Code patterns that might need similar changes based on the recent edits
      - Functions, variables, types, and constants referenced in the current cursor context
      - Related implementations, usages, or dependencies that may require consistent updates

    **Important constraints:**
    - This conversation has exactly 2 turns
    - You must make ALL search queries in your first response via the `search` tool
    - All queries will be executed in parallel and results returned together
    - In the second turn, you will select the most relevant results via the `select` tool.

    ## User Edits

    {edits}

    ## Current cursor context

    `````filename={current_file_path}
    {cursor_excerpt}
    `````

    --
    Use the `search` tool now
"#};

/// Name of the tool the model must call in the first turn.
const SEARCH_TOOL_NAME: &str = "search";
75
/// Search for relevant code
///
/// For the best results, run multiple queries at once with a single invocation of this tool.
// NOTE: the `///` doc comments on this type and its fields are not just for
// readers — schemars lifts them into the JSON schema sent to the model as the
// tool description (see `request_tool_call`). Edit them as prompt text.
#[derive(Clone, Deserialize, Serialize, JsonSchema)]
pub struct SearchToolInput {
    /// An array of queries to run for gathering context relevant to the next prediction
    #[schemars(length(max = 5))]
    pub queries: Box<[SearchToolQuery]>,
}
85
// One glob + regex pair produced by the model. Queries sharing a glob are
// merged (regexes or-ed together) before execution in `find_related_excerpts`.
// NOTE: the `///` field docs below are serialized into the tool's JSON schema
// by schemars and read by the model — treat them as part of the prompt.
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
pub struct SearchToolQuery {
    /// A glob pattern to match file paths in the codebase
    pub glob: String,
    /// A regular expression to match content within the files matched by the glob pattern
    pub regex: String,
}
93
/// Preamble for the second-turn message carrying the merged search results
/// back to the model.
const RESULTS_MESSAGE: &str = indoc! {"
    Here are the results of your queries combined and grouped by file:

"};

/// Name of the tool the model must call in the second turn.
const SELECT_TOOL_NAME: &str = "select";

/// Final user message nudging the model to call `select` with the most
/// relevant line ranges from the search results.
const SELECT_PROMPT: &str = indoc! {"
    Use the `select` tool now to pick the most relevant line ranges according to the user state provided in the first message.
    Make sure to include enough lines of context so that the edit prediction model can suggest accurate edits.
    Include up to 200 lines in total.
"};
106
/// Select line ranges from search results
// NOTE: the `///` docs here become the tool's schema description via schemars
// (see `request_tool_call`) — they are read by the model, not just humans.
#[derive(Deserialize, JsonSchema)]
struct SelectToolInput {
    /// The line ranges to select from search results.
    ranges: Vec<SelectLineRange>,
}
113
/// A specific line range to select from a file
// NOTE: the `///` field docs are serialized into the tool's JSON schema by
// schemars and read by the model — treat them as part of the prompt. The
// path is matched verbatim against `result_buffers_by_path` keys, which is
// why it must appear exactly as in the result codeblocks.
#[derive(Debug, Deserialize, JsonSchema)]
struct SelectLineRange {
    /// The file path containing the lines to select
    /// Exactly as it appears in the search result codeblocks.
    path: PathBuf,
    /// The starting line number (1-based)
    #[schemars(range(min = 1))]
    start_line: u32,
    /// The ending line number (1-based, inclusive)
    #[schemars(range(min = 1))]
    end_line: u32,
}
127
/// Options controlling LLM-driven context gathering.
#[derive(Debug, Clone, PartialEq)]
pub struct LlmContextOptions {
    /// Controls how the excerpt around the cursor is selected and sized.
    pub excerpt: EditPredictionExcerptOptions,
}
132
133pub fn find_related_excerpts<'a>(
134 buffer: Entity<language::Buffer>,
135 cursor_position: Anchor,
136 project: &Entity<Project>,
137 events: impl Iterator<Item = &'a crate::Event>,
138 options: &LlmContextOptions,
139 debug_tx: Option<mpsc::UnboundedSender<ZetaDebugInfo>>,
140 cx: &App,
141) -> Task<Result<HashMap<Entity<Buffer>, Vec<Range<Anchor>>>>> {
142 let language_model_registry = LanguageModelRegistry::global(cx);
143 let Some(model) = language_model_registry
144 .read(cx)
145 .available_models(cx)
146 .find(|model| {
147 model.provider_id() == language_model::ANTHROPIC_PROVIDER_ID
148 && model.id() == LanguageModelId("claude-haiku-4-5-latest".into())
149 })
150 else {
151 return Task::ready(Err(anyhow!("could not find claude model")));
152 };
153
154 let mut edits_string = String::new();
155
156 for event in events {
157 if let Some(event) = event.to_request_event(cx) {
158 writeln!(&mut edits_string, "{event}").ok();
159 }
160 }
161
162 if edits_string.is_empty() {
163 edits_string.push_str("(No user edits yet)");
164 }
165
166 // TODO [zeta2] include breadcrumbs?
167 let snapshot = buffer.read(cx).snapshot();
168 let cursor_point = cursor_position.to_point(&snapshot);
169 let Some(cursor_excerpt) =
170 EditPredictionExcerpt::select_from_buffer(cursor_point, &snapshot, &options.excerpt, None)
171 else {
172 return Task::ready(Ok(HashMap::default()));
173 };
174
175 let current_file_path = snapshot
176 .file()
177 .map(|f| f.full_path(cx).display().to_string())
178 .unwrap_or_else(|| "untitled".to_string());
179
180 let prompt = SEARCH_PROMPT
181 .replace("{edits}", &edits_string)
182 .replace("{current_file_path}", ¤t_file_path)
183 .replace("{cursor_excerpt}", &cursor_excerpt.text(&snapshot).body);
184
185 let path_style = project.read(cx).path_style(cx);
186
187 let exclude_matcher = {
188 let global_settings = WorktreeSettings::get_global(cx);
189 let exclude_patterns = global_settings
190 .file_scan_exclusions
191 .sources()
192 .iter()
193 .chain(global_settings.private_files.sources().iter());
194
195 match PathMatcher::new(exclude_patterns, path_style) {
196 Ok(matcher) => matcher,
197 Err(err) => {
198 return Task::ready(Err(anyhow!(err)));
199 }
200 }
201 };
202
203 let project = project.clone();
204 cx.spawn(async move |cx| {
205 let initial_prompt_message = LanguageModelRequestMessage {
206 role: Role::User,
207 content: vec![prompt.into()],
208 cache: false,
209 };
210
211 let mut search_stream = request_tool_call::<SearchToolInput>(
212 vec![initial_prompt_message.clone()],
213 SEARCH_TOOL_NAME,
214 &model,
215 cx,
216 )
217 .await?;
218
219 let mut select_request_messages = Vec::with_capacity(5); // initial prompt, LLM response/thinking, tool use, tool result, select prompt
220 select_request_messages.push(initial_prompt_message);
221
222 let mut regex_by_glob: HashMap<String, String> = HashMap::default();
223 let mut search_calls = Vec::new();
224
225 while let Some(event) = search_stream.next().await {
226 match event? {
227 LanguageModelCompletionEvent::ToolUse(tool_use) => {
228 if !tool_use.is_input_complete {
229 continue;
230 }
231
232 if tool_use.name.as_ref() == SEARCH_TOOL_NAME {
233 let input =
234 serde_json::from_value::<SearchToolInput>(tool_use.input.clone())?;
235
236 for query in input.queries {
237 let regex = regex_by_glob.entry(query.glob).or_default();
238 if !regex.is_empty() {
239 regex.push('|');
240 }
241 regex.push_str(&query.regex);
242 }
243
244 search_calls.push(tool_use);
245 } else {
246 log::warn!(
247 "context gathering model tried to use unknown tool: {}",
248 tool_use.name
249 );
250 }
251 }
252 LanguageModelCompletionEvent::Text(txt) => {
253 if let Some(LanguageModelRequestMessage {
254 role: Role::Assistant,
255 content,
256 ..
257 }) = select_request_messages.last_mut()
258 {
259 if let Some(MessageContent::Text(existing_text)) = content.last_mut() {
260 existing_text.push_str(&txt);
261 } else {
262 content.push(MessageContent::Text(txt));
263 }
264 } else {
265 select_request_messages.push(LanguageModelRequestMessage {
266 role: Role::Assistant,
267 content: vec![MessageContent::Text(txt)],
268 cache: false,
269 });
270 }
271 }
272 LanguageModelCompletionEvent::Thinking { text, signature } => {
273 if let Some(LanguageModelRequestMessage {
274 role: Role::Assistant,
275 content,
276 ..
277 }) = select_request_messages.last_mut()
278 {
279 if let Some(MessageContent::Thinking {
280 text: existing_text,
281 signature: existing_signature,
282 }) = content.last_mut()
283 {
284 existing_text.push_str(&text);
285 *existing_signature = signature;
286 } else {
287 content.push(MessageContent::Thinking { text, signature });
288 }
289 } else {
290 select_request_messages.push(LanguageModelRequestMessage {
291 role: Role::Assistant,
292 content: vec![MessageContent::Thinking { text, signature }],
293 cache: false,
294 });
295 }
296 }
297 LanguageModelCompletionEvent::RedactedThinking { data } => {
298 if let Some(LanguageModelRequestMessage {
299 role: Role::Assistant,
300 content,
301 ..
302 }) = select_request_messages.last_mut()
303 {
304 if let Some(MessageContent::RedactedThinking(existing_data)) =
305 content.last_mut()
306 {
307 existing_data.push_str(&data);
308 } else {
309 content.push(MessageContent::RedactedThinking(data));
310 }
311 } else {
312 select_request_messages.push(LanguageModelRequestMessage {
313 role: Role::Assistant,
314 content: vec![MessageContent::RedactedThinking(data)],
315 cache: false,
316 });
317 }
318 }
319 ev @ LanguageModelCompletionEvent::ToolUseJsonParseError { .. } => {
320 log::error!("{ev:?}");
321 }
322 ev => {
323 log::trace!("context search event: {ev:?}")
324 }
325 }
326 }
327
328 let search_tool_use = if search_calls.is_empty() {
329 log::warn!("context model ran 0 searches");
330 return anyhow::Ok(Default::default());
331 } else if search_calls.len() == 1 {
332 search_calls.swap_remove(0)
333 } else {
334 // In theory, the model could perform multiple search calls
335 // Dealing with them separately is not worth it when it doesn't happen in practice.
336 // If it were to happen, here we would combine them into one.
337 // The second request doesn't need to know it was actually two different calls ;)
338 let input = serde_json::to_value(&SearchToolInput {
339 queries: regex_by_glob
340 .iter()
341 .map(|(glob, regex)| SearchToolQuery {
342 glob: glob.clone(),
343 regex: regex.clone(),
344 })
345 .collect(),
346 })
347 .unwrap_or_default();
348
349 LanguageModelToolUse {
350 id: search_calls.swap_remove(0).id,
351 name: SELECT_TOOL_NAME.into(),
352 raw_input: serde_json::to_string(&input).unwrap_or_default(),
353 input,
354 is_input_complete: true,
355 }
356 };
357
358 if let Some(debug_tx) = &debug_tx {
359 debug_tx
360 .unbounded_send(ZetaDebugInfo::SearchQueriesGenerated(
361 ZetaSearchQueryDebugInfo {
362 project: project.clone(),
363 timestamp: Instant::now(),
364 queries: regex_by_glob
365 .iter()
366 .map(|(glob, regex)| SearchToolQuery {
367 glob: glob.clone(),
368 regex: regex.clone(),
369 })
370 .collect(),
371 },
372 ))
373 .ok();
374 }
375
376 let (results_tx, mut results_rx) = mpsc::unbounded();
377
378 for (glob, regex) in regex_by_glob {
379 let exclude_matcher = exclude_matcher.clone();
380 let results_tx = results_tx.clone();
381 let project = project.clone();
382 cx.spawn(async move |cx| {
383 run_query(
384 &glob,
385 ®ex,
386 results_tx.clone(),
387 path_style,
388 exclude_matcher,
389 &project,
390 cx,
391 )
392 .await
393 .log_err();
394 })
395 .detach()
396 }
397 drop(results_tx);
398
399 struct ResultBuffer {
400 buffer: Entity<Buffer>,
401 snapshot: TextBufferSnapshot,
402 }
403
404 let (result_buffers_by_path, merged_result) = cx
405 .background_spawn(async move {
406 let mut excerpts_by_buffer: HashMap<Entity<Buffer>, MatchedBuffer> =
407 HashMap::default();
408
409 while let Some((buffer, matched)) = results_rx.next().await {
410 match excerpts_by_buffer.entry(buffer) {
411 Entry::Occupied(mut entry) => {
412 let entry = entry.get_mut();
413 entry.full_path = matched.full_path;
414 entry.snapshot = matched.snapshot;
415 entry.line_ranges.extend(matched.line_ranges);
416 }
417 Entry::Vacant(entry) => {
418 entry.insert(matched);
419 }
420 }
421 }
422
423 let mut result_buffers_by_path = HashMap::default();
424 let mut merged_result = RESULTS_MESSAGE.to_string();
425
426 for (buffer, mut matched) in excerpts_by_buffer {
427 matched
428 .line_ranges
429 .sort_unstable_by_key(|range| (range.start, Reverse(range.end)));
430
431 writeln!(
432 &mut merged_result,
433 "`````filename={}",
434 matched.full_path.display()
435 )
436 .unwrap();
437 write_merged_excerpts(
438 &matched.snapshot,
439 matched.line_ranges,
440 &[],
441 &mut merged_result,
442 );
443 merged_result.push_str("`````\n\n");
444
445 result_buffers_by_path.insert(
446 matched.full_path,
447 ResultBuffer {
448 buffer,
449 snapshot: matched.snapshot.text,
450 },
451 );
452 }
453
454 (result_buffers_by_path, merged_result)
455 })
456 .await;
457
458 if let Some(debug_tx) = &debug_tx {
459 debug_tx
460 .unbounded_send(ZetaDebugInfo::SearchQueriesExecuted(
461 ZetaContextRetrievalDebugInfo {
462 project: project.clone(),
463 timestamp: Instant::now(),
464 },
465 ))
466 .ok();
467 }
468
469 let tool_result = LanguageModelToolResult {
470 tool_use_id: search_tool_use.id.clone(),
471 tool_name: SEARCH_TOOL_NAME.into(),
472 is_error: false,
473 content: merged_result.into(),
474 output: None,
475 };
476
477 select_request_messages.extend([
478 LanguageModelRequestMessage {
479 role: Role::Assistant,
480 content: vec![MessageContent::ToolUse(search_tool_use)],
481 cache: false,
482 },
483 LanguageModelRequestMessage {
484 role: Role::User,
485 content: vec![MessageContent::ToolResult(tool_result)],
486 cache: false,
487 },
488 ]);
489
490 if result_buffers_by_path.is_empty() {
491 log::trace!("context gathering queries produced no results");
492 return anyhow::Ok(HashMap::default());
493 }
494
495 select_request_messages.push(LanguageModelRequestMessage {
496 role: Role::User,
497 content: vec![SELECT_PROMPT.into()],
498 cache: false,
499 });
500
501 let mut select_stream = request_tool_call::<SelectToolInput>(
502 select_request_messages,
503 SELECT_TOOL_NAME,
504 &model,
505 cx,
506 )
507 .await?;
508
509 cx.background_spawn(async move {
510 let mut selected_ranges = Vec::new();
511
512 while let Some(event) = select_stream.next().await {
513 match event? {
514 LanguageModelCompletionEvent::ToolUse(tool_use) => {
515 if !tool_use.is_input_complete {
516 continue;
517 }
518
519 if tool_use.name.as_ref() == SELECT_TOOL_NAME {
520 let call =
521 serde_json::from_value::<SelectToolInput>(tool_use.input.clone())?;
522 selected_ranges.extend(call.ranges);
523 } else {
524 log::warn!(
525 "context gathering model tried to use unknown tool: {}",
526 tool_use.name
527 );
528 }
529 }
530 ev @ LanguageModelCompletionEvent::ToolUseJsonParseError { .. } => {
531 log::error!("{ev:?}");
532 }
533 ev => {
534 log::trace!("context select event: {ev:?}")
535 }
536 }
537 }
538
539 if let Some(debug_tx) = &debug_tx {
540 debug_tx
541 .unbounded_send(ZetaDebugInfo::SearchResultsFiltered(
542 ZetaContextRetrievalDebugInfo {
543 project: project.clone(),
544 timestamp: Instant::now(),
545 },
546 ))
547 .ok();
548 }
549
550 if selected_ranges.is_empty() {
551 log::trace!("context gathering selected no ranges")
552 }
553
554 selected_ranges.sort_unstable_by(|a, b| {
555 a.start_line
556 .cmp(&b.start_line)
557 .then(b.end_line.cmp(&a.end_line))
558 });
559
560 let mut related_excerpts_by_buffer: HashMap<_, Vec<_>> = HashMap::default();
561
562 for selected_range in selected_ranges {
563 if let Some(ResultBuffer { buffer, snapshot }) =
564 result_buffers_by_path.get(&selected_range.path)
565 {
566 let start_point = Point::new(selected_range.start_line.saturating_sub(1), 0);
567 let end_point =
568 snapshot.clip_point(Point::new(selected_range.end_line, 0), Bias::Left);
569 let range =
570 snapshot.anchor_after(start_point)..snapshot.anchor_before(end_point);
571
572 related_excerpts_by_buffer
573 .entry(buffer.clone())
574 .or_default()
575 .push(range);
576 } else {
577 log::warn!(
578 "selected path that wasn't included in search results: {}",
579 selected_range.path.display()
580 );
581 }
582 }
583
584 anyhow::Ok(related_excerpts_by_buffer)
585 })
586 .await
587 })
588}
589
590async fn request_tool_call<T: JsonSchema>(
591 messages: Vec<LanguageModelRequestMessage>,
592 tool_name: &'static str,
593 model: &Arc<dyn LanguageModel>,
594 cx: &mut AsyncApp,
595) -> Result<BoxStream<'static, Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>>
596{
597 let schema = schemars::schema_for!(T);
598
599 let request = LanguageModelRequest {
600 messages,
601 tools: vec![LanguageModelRequestTool {
602 name: tool_name.into(),
603 description: schema
604 .get("description")
605 .and_then(|description| description.as_str())
606 .unwrap()
607 .to_string(),
608 input_schema: serde_json::to_value(schema).unwrap(),
609 }],
610 ..Default::default()
611 };
612
613 Ok(model.stream_completion(request, cx).await?)
614}
615
// Sizing limits for excerpts expanded around search matches.
const MIN_EXCERPT_LEN: usize = 16;
const MAX_EXCERPT_LEN: usize = 768;
/// Budget of excerpt bytes produced for a single search query.
const MAX_RESULT_BYTES_PER_QUERY: usize = MAX_EXCERPT_LEN * 5;

/// Search hits for one buffer: the snapshot the line ranges were computed
/// against, plus the full path used to label the result codeblock.
struct MatchedBuffer {
    snapshot: BufferSnapshot,
    line_ranges: Vec<Range<Line>>,
    full_path: PathBuf,
}
625
626async fn run_query(
627 glob: &str,
628 regex: &str,
629 results_tx: UnboundedSender<(Entity<Buffer>, MatchedBuffer)>,
630 path_style: PathStyle,
631 exclude_matcher: PathMatcher,
632 project: &Entity<Project>,
633 cx: &mut AsyncApp,
634) -> Result<()> {
635 let include_matcher = PathMatcher::new(vec![glob], path_style)?;
636
637 let query = SearchQuery::regex(
638 regex,
639 false,
640 true,
641 false,
642 true,
643 include_matcher,
644 exclude_matcher,
645 true,
646 None,
647 )?;
648
649 let results = project.update(cx, |project, cx| project.search(query, cx))?;
650 futures::pin_mut!(results);
651
652 let mut total_bytes = 0;
653
654 while let Some(SearchResult::Buffer { buffer, ranges }) = results.next().await {
655 if ranges.is_empty() {
656 continue;
657 }
658
659 let Some((snapshot, full_path)) = buffer.read_with(cx, |buffer, cx| {
660 Some((buffer.snapshot(), buffer.file()?.full_path(cx)))
661 })?
662 else {
663 continue;
664 };
665
666 let results_tx = results_tx.clone();
667 cx.background_spawn(async move {
668 let mut line_ranges = Vec::with_capacity(ranges.len());
669
670 for range in ranges {
671 let offset_range = range.to_offset(&snapshot);
672 let query_point = (offset_range.start + offset_range.len() / 2).to_point(&snapshot);
673
674 if total_bytes + MIN_EXCERPT_LEN >= MAX_RESULT_BYTES_PER_QUERY {
675 break;
676 }
677
678 let excerpt = EditPredictionExcerpt::select_from_buffer(
679 query_point,
680 &snapshot,
681 &EditPredictionExcerptOptions {
682 max_bytes: MAX_EXCERPT_LEN.min(MAX_RESULT_BYTES_PER_QUERY - total_bytes),
683 min_bytes: MIN_EXCERPT_LEN,
684 target_before_cursor_over_total_bytes: 0.5,
685 },
686 None,
687 );
688
689 if let Some(excerpt) = excerpt {
690 total_bytes += excerpt.range.len();
691 if !excerpt.line_range.is_empty() {
692 line_ranges.push(excerpt.line_range);
693 }
694 }
695 }
696
697 results_tx
698 .unbounded_send((
699 buffer,
700 MatchedBuffer {
701 snapshot,
702 line_ranges,
703 full_path,
704 },
705 ))
706 .log_err();
707 })
708 .detach();
709 }
710
711 anyhow::Ok(())
712}