use std::{
    cmp::Reverse, collections::hash_map::Entry, ops::Range, path::PathBuf, sync::Arc, time::Instant,
};

use crate::{
    ZetaContextRetrievalDebugInfo, ZetaContextRetrievalStartedDebugInfo, ZetaDebugInfo,
    ZetaSearchQueryDebugInfo, merge_excerpts::merge_excerpts,
};
use anyhow::{Result, anyhow};
use cloud_zeta2_prompt::write_codeblock;
use collections::HashMap;
use edit_prediction_context::{EditPredictionExcerpt, EditPredictionExcerptOptions, Line};
use futures::{
    StreamExt,
    channel::mpsc::{self, UnboundedSender},
    stream::BoxStream,
};
use gpui::{App, AppContext, AsyncApp, Entity, Task};
use indoc::indoc;
use language::{
    Anchor, Bias, Buffer, BufferSnapshot, OffsetRangeExt, Point, TextBufferSnapshot, ToPoint as _,
};
use language_model::{
    LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent, LanguageModelId,
    LanguageModelProviderId, LanguageModelRegistry, LanguageModelRequest,
    LanguageModelRequestMessage, LanguageModelRequestTool, LanguageModelToolResult,
    LanguageModelToolUse, MessageContent, Role,
};
use project::{
    Project, WorktreeSettings,
    search::{SearchQuery, SearchResult},
};
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
use util::{
    ResultExt as _,
    paths::{PathMatcher, PathStyle},
};
use workspace::item::Settings as _;
const SEARCH_PROMPT: &str = indoc! {r#"
    ## Task

    You are part of an edit prediction system in a code editor. Your role is to identify relevant code locations
    that will serve as context for predicting the next required edit.

    **Your task:**
    - Analyze the user's recent edits and current cursor context
    - Use the `search` tool to find code that may be relevant for predicting the next edit
    - Focus on finding:
      - Code patterns that might need similar changes based on the recent edits
      - Functions, variables, types, and constants referenced in the current cursor context
      - Related implementations, usages, or dependencies that may require consistent updates

    **Important constraints:**
    - This conversation has exactly 2 turns
    - You must make ALL search queries in your first response via the `search` tool
    - All queries will be executed in parallel and results returned together
    - In the second turn, you will select the most relevant results via the `select` tool

    ## User Edits

    {edits}

    ## Current cursor context

    `````{current_file_path}
    {cursor_excerpt}
    `````

    --
    Use the `search` tool now
"#};

const SEARCH_TOOL_NAME: &str = "search";

/// Search for relevant code
///
/// For the best results, run multiple queries at once with a single invocation of this tool.
#[derive(Clone, Deserialize, Serialize, JsonSchema)]
pub struct SearchToolInput {
    /// An array of queries to run for gathering context relevant to the next prediction
    #[schemars(length(max = 5))]
    pub queries: Box<[SearchToolQuery]>,
}

#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
pub struct SearchToolQuery {
    /// A glob pattern to match file paths in the codebase
    pub glob: String,
    /// A regular expression to match content within the files matched by the glob pattern
    pub regex: String,
}
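// A hypothetical `search` call the model might emit (illustrative only):
// { "queries": [{ "glob": "**/*.rs", "regex": "fn\\s+merge_excerpts" }] }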

const RESULTS_MESSAGE: &str = indoc! {"
    Here are the results of your queries combined and grouped by file:

"};

const SELECT_TOOL_NAME: &str = "select";

const SELECT_PROMPT: &str = indoc! {"
    Use the `select` tool now to pick the most relevant line ranges according to the user state provided in the first message.
    Make sure to include enough lines of context so that the edit prediction model can suggest accurate edits.
    Include up to 200 lines in total.
"};

/// Select line ranges from search results
#[derive(Deserialize, JsonSchema)]
struct SelectToolInput {
    /// The line ranges to select from search results.
    ranges: Vec<SelectLineRange>,
}

/// A specific line range to select from a file
#[derive(Debug, Deserialize, JsonSchema)]
struct SelectLineRange {
    /// The file path containing the lines to select,
    /// exactly as it appears in the search result codeblocks.
    path: PathBuf,
    /// The starting line number (1-based)
    #[schemars(range(min = 1))]
    start_line: u32,
    /// The ending line number (1-based, inclusive)
    #[schemars(range(min = 1))]
    end_line: u32,
}

#[derive(Debug, Clone, PartialEq)]
pub struct LlmContextOptions {
    pub excerpt: EditPredictionExcerptOptions,
}

pub const MODEL_PROVIDER_ID: LanguageModelProviderId = language_model::ANTHROPIC_PROVIDER_ID;

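/// Runs a two-turn conversation with the context model: in the first turn the
/// model issues `search` queries, which we execute against the project; in the
/// second turn it `select`s line ranges from the combined results. Returns the
/// selected ranges grouped by buffer.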
pub fn find_related_excerpts(
    buffer: Entity<language::Buffer>,
    cursor_position: Anchor,
    project: &Entity<Project>,
    mut edit_history_unified_diff: String,
    options: &LlmContextOptions,
    debug_tx: Option<mpsc::UnboundedSender<ZetaDebugInfo>>,
    cx: &App,
) -> Task<Result<HashMap<Entity<Buffer>, Vec<Range<Anchor>>>>> {
    let language_model_registry = LanguageModelRegistry::global(cx);
    let Some(model) = language_model_registry
        .read(cx)
        .available_models(cx)
        .find(|model| {
            model.provider_id() == MODEL_PROVIDER_ID
                && model.id() == LanguageModelId("claude-haiku-4-5-latest".into())
            // model.provider_id() == LanguageModelProviderId::new("zeta-ctx-qwen-30b")
            // model.provider_id() == LanguageModelProviderId::new("ollama")
            //     && model.id() == LanguageModelId("gpt-oss:20b".into())
        })
    else {
        return Task::ready(Err(anyhow!("could not find context model")));
    };

    if edit_history_unified_diff.is_empty() {
        edit_history_unified_diff.push_str("(No user edits yet)");
    }

    // TODO [zeta2] include breadcrumbs?
    let snapshot = buffer.read(cx).snapshot();
    let cursor_point = cursor_position.to_point(&snapshot);
    let Some(cursor_excerpt) =
        EditPredictionExcerpt::select_from_buffer(cursor_point, &snapshot, &options.excerpt, None)
    else {
        return Task::ready(Ok(HashMap::default()));
    };

    let current_file_path = snapshot
        .file()
        .map(|f| f.full_path(cx).display().to_string())
        .unwrap_or_else(|| "untitled".to_string());

    let prompt = SEARCH_PROMPT
        .replace("{edits}", &edit_history_unified_diff)
        .replace("{current_file_path}", &current_file_path)
        .replace("{cursor_excerpt}", &cursor_excerpt.text(&snapshot).body);

    if let Some(debug_tx) = &debug_tx {
        debug_tx
            .unbounded_send(ZetaDebugInfo::ContextRetrievalStarted(
                ZetaContextRetrievalStartedDebugInfo {
                    project: project.clone(),
                    timestamp: Instant::now(),
                    search_prompt: prompt.clone(),
                },
            ))
            .ok();
    }

    let path_style = project.read(cx).path_style(cx);

    let exclude_matcher = {
        let global_settings = WorktreeSettings::get_global(cx);
        let exclude_patterns = global_settings
            .file_scan_exclusions
            .sources()
            .iter()
            .chain(global_settings.private_files.sources().iter());

        match PathMatcher::new(exclude_patterns, path_style) {
            Ok(matcher) => matcher,
            Err(err) => {
                return Task::ready(Err(anyhow!(err)));
            }
        }
    };

    let project = project.clone();
    cx.spawn(async move |cx| {
        let initial_prompt_message = LanguageModelRequestMessage {
            role: Role::User,
            content: vec![prompt.into()],
            cache: false,
        };

        let mut search_stream = request_tool_call::<SearchToolInput>(
            vec![initial_prompt_message.clone()],
            SEARCH_TOOL_NAME,
            &model,
            cx,
        )
        .await?;

        let mut select_request_messages = Vec::with_capacity(5); // initial prompt, LLM response/thinking, tool use, tool result, select prompt
        select_request_messages.push(initial_prompt_message);

        let mut regex_by_glob: HashMap<String, String> = HashMap::default();
        let mut search_calls = Vec::new();

        while let Some(event) = search_stream.next().await {
            match event? {
                LanguageModelCompletionEvent::ToolUse(tool_use) => {
                    if !tool_use.is_input_complete {
                        continue;
                    }

                    if tool_use.name.as_ref() == SEARCH_TOOL_NAME {
                        let input =
                            serde_json::from_value::<SearchToolInput>(tool_use.input.clone())?;

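                        // Merge every regex that targets the same glob into a
                        // single alternation, so each glob is searched only once.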
                        for query in input.queries {
                            let regex = regex_by_glob.entry(query.glob).or_default();
                            if !regex.is_empty() {
                                regex.push('|');
                            }
                            regex.push_str(&query.regex);
                        }

                        search_calls.push(tool_use);
                    } else {
                        log::warn!(
                            "context gathering model tried to use unknown tool: {}",
                            tool_use.name
                        );
                    }
                }
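                // Accumulate the assistant's streamed text and thinking into
                // the message list so the second (select) turn can see them.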
                LanguageModelCompletionEvent::Text(txt) => {
                    if let Some(LanguageModelRequestMessage {
                        role: Role::Assistant,
                        content,
                        ..
                    }) = select_request_messages.last_mut()
                    {
                        if let Some(MessageContent::Text(existing_text)) = content.last_mut() {
                            existing_text.push_str(&txt);
                        } else {
                            content.push(MessageContent::Text(txt));
                        }
                    } else {
                        select_request_messages.push(LanguageModelRequestMessage {
                            role: Role::Assistant,
                            content: vec![MessageContent::Text(txt)],
                            cache: false,
                        });
                    }
                }
                LanguageModelCompletionEvent::Thinking { text, signature } => {
                    if let Some(LanguageModelRequestMessage {
                        role: Role::Assistant,
                        content,
                        ..
                    }) = select_request_messages.last_mut()
                    {
                        if let Some(MessageContent::Thinking {
                            text: existing_text,
                            signature: existing_signature,
                        }) = content.last_mut()
                        {
                            existing_text.push_str(&text);
                            *existing_signature = signature;
                        } else {
                            content.push(MessageContent::Thinking { text, signature });
                        }
                    } else {
                        select_request_messages.push(LanguageModelRequestMessage {
                            role: Role::Assistant,
                            content: vec![MessageContent::Thinking { text, signature }],
                            cache: false,
                        });
                    }
                }
                LanguageModelCompletionEvent::RedactedThinking { data } => {
                    if let Some(LanguageModelRequestMessage {
                        role: Role::Assistant,
                        content,
                        ..
                    }) = select_request_messages.last_mut()
                    {
                        if let Some(MessageContent::RedactedThinking(existing_data)) =
                            content.last_mut()
                        {
                            existing_data.push_str(&data);
                        } else {
                            content.push(MessageContent::RedactedThinking(data));
                        }
                    } else {
                        select_request_messages.push(LanguageModelRequestMessage {
                            role: Role::Assistant,
                            content: vec![MessageContent::RedactedThinking(data)],
                            cache: false,
                        });
                    }
                }
                ev @ LanguageModelCompletionEvent::ToolUseJsonParseError { .. } => {
                    log::error!("{ev:?}");
                }
                ev => {
                    log::trace!("context search event: {ev:?}");
                }
            }
        }

        let search_tool_use = if search_calls.is_empty() {
            log::warn!("context model ran 0 searches");
            return anyhow::Ok(Default::default());
        } else if search_calls.len() == 1 {
            search_calls.swap_remove(0)
        } else {
            // In theory, the model could perform multiple search calls.
            // Handling them separately is not worth it when it doesn't happen in practice,
            // so if it does happen, we combine them into one here.
            // The second request doesn't need to know it was actually two different calls ;)
            let input = serde_json::to_value(&SearchToolInput {
                queries: regex_by_glob
                    .iter()
                    .map(|(glob, regex)| SearchToolQuery {
                        glob: glob.clone(),
                        regex: regex.clone(),
                    })
                    .collect(),
            })
            .unwrap_or_default();
            LanguageModelToolUse {
                id: search_calls.swap_remove(0).id,
                name: SEARCH_TOOL_NAME.into(),
                raw_input: serde_json::to_string(&input).unwrap_or_default(),
                input,
                is_input_complete: true,
                thought_signature: None,
            }
        };

        if let Some(debug_tx) = &debug_tx {
            debug_tx
                .unbounded_send(ZetaDebugInfo::SearchQueriesGenerated(
                    ZetaSearchQueryDebugInfo {
                        project: project.clone(),
                        timestamp: Instant::now(),
                        queries: regex_by_glob
                            .iter()
                            .map(|(glob, regex)| SearchToolQuery {
                                glob: glob.clone(),
                                regex: regex.clone(),
                            })
                            .collect(),
                    },
                ))
                .ok();
        }

        let (results_tx, mut results_rx) = mpsc::unbounded();

        for (glob, regex) in regex_by_glob {
            let exclude_matcher = exclude_matcher.clone();
            let results_tx = results_tx.clone();
            let project = project.clone();
            cx.spawn(async move |cx| {
                run_query(
                    &glob,
                    &regex,
                    results_tx.clone(),
                    path_style,
                    exclude_matcher,
                    &project,
                    cx,
                )
                .await
                .log_err();
            })
            .detach();
        }
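        // Drop the original sender so `results_rx` finishes once every
        // spawned query task has sent its results and dropped its clone.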
        drop(results_tx);

        struct ResultBuffer {
            buffer: Entity<Buffer>,
            snapshot: TextBufferSnapshot,
        }

        let (result_buffers_by_path, merged_result) = cx
            .background_spawn(async move {
                let mut excerpts_by_buffer: HashMap<Entity<Buffer>, MatchedBuffer> =
                    HashMap::default();

                while let Some((buffer, matched)) = results_rx.next().await {
                    match excerpts_by_buffer.entry(buffer) {
                        Entry::Occupied(mut entry) => {
                            let entry = entry.get_mut();
                            entry.full_path = matched.full_path;
                            entry.snapshot = matched.snapshot;
                            entry.line_ranges.extend(matched.line_ranges);
                        }
                        Entry::Vacant(entry) => {
                            entry.insert(matched);
                        }
                    }
                }

                let mut result_buffers_by_path = HashMap::default();
                let mut merged_result = RESULTS_MESSAGE.to_string();

                for (buffer, mut matched) in excerpts_by_buffer {
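                    // Sort by ascending start, then descending end, so that
                    // enclosing ranges come before the ranges they contain.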
                    matched
                        .line_ranges
                        .sort_unstable_by_key(|range| (range.start, Reverse(range.end)));

                    write_codeblock(
                        &matched.full_path,
                        merge_excerpts(&matched.snapshot, matched.line_ranges).iter(),
                        &[],
                        Line(matched.snapshot.max_point().row),
                        true,
                        &mut merged_result,
                    );

                    result_buffers_by_path.insert(
                        matched.full_path,
                        ResultBuffer {
                            buffer,
                            snapshot: matched.snapshot.text,
                        },
                    );
                }

                (result_buffers_by_path, merged_result)
            })
            .await;

        if let Some(debug_tx) = &debug_tx {
            debug_tx
                .unbounded_send(ZetaDebugInfo::SearchQueriesExecuted(
                    ZetaContextRetrievalDebugInfo {
                        project: project.clone(),
                        timestamp: Instant::now(),
                    },
                ))
                .ok();
        }

        let tool_result = LanguageModelToolResult {
            tool_use_id: search_tool_use.id.clone(),
            tool_name: SEARCH_TOOL_NAME.into(),
            is_error: false,
            content: merged_result.into(),
            output: None,
        };

        select_request_messages.extend([
            LanguageModelRequestMessage {
                role: Role::Assistant,
                content: vec![MessageContent::ToolUse(search_tool_use)],
                cache: false,
            },
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec![MessageContent::ToolResult(tool_result)],
                cache: false,
            },
        ]);

        if result_buffers_by_path.is_empty() {
            log::trace!("context gathering queries produced no results");
            return anyhow::Ok(HashMap::default());
        }

        select_request_messages.push(LanguageModelRequestMessage {
            role: Role::User,
            content: vec![SELECT_PROMPT.into()],
            cache: false,
        });

        let mut select_stream = request_tool_call::<SelectToolInput>(
            select_request_messages,
            SELECT_TOOL_NAME,
            &model,
            cx,
        )
        .await?;

        cx.background_spawn(async move {
            let mut selected_ranges = Vec::new();

            while let Some(event) = select_stream.next().await {
                match event? {
                    LanguageModelCompletionEvent::ToolUse(tool_use) => {
                        if !tool_use.is_input_complete {
                            continue;
                        }

                        if tool_use.name.as_ref() == SELECT_TOOL_NAME {
                            let call =
                                serde_json::from_value::<SelectToolInput>(tool_use.input.clone())?;
                            selected_ranges.extend(call.ranges);
                        } else {
                            log::warn!(
                                "context gathering model tried to use unknown tool: {}",
                                tool_use.name
                            );
                        }
                    }
                    ev @ LanguageModelCompletionEvent::ToolUseJsonParseError { .. } => {
                        log::error!("{ev:?}");
                    }
                    ev => {
                        log::trace!("context select event: {ev:?}");
                    }
                }
            }

            if let Some(debug_tx) = &debug_tx {
                debug_tx
                    .unbounded_send(ZetaDebugInfo::SearchResultsFiltered(
                        ZetaContextRetrievalDebugInfo {
                            project: project.clone(),
                            timestamp: Instant::now(),
                        },
                    ))
                    .ok();
            }

            if selected_ranges.is_empty() {
                log::trace!("context gathering selected no ranges");
            }

            selected_ranges.sort_unstable_by(|a, b| {
                a.start_line
                    .cmp(&b.start_line)
                    .then(b.end_line.cmp(&a.end_line))
            });

            let mut related_excerpts_by_buffer: HashMap<_, Vec<_>> = HashMap::default();

            for selected_range in selected_ranges {
                if let Some(ResultBuffer { buffer, snapshot }) =
                    result_buffers_by_path.get(&selected_range.path)
                {
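                    // Convert the model's 1-based inclusive lines to 0-based
                    // points, clipping in case the model overshoots the file.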
                    let start_point = Point::new(selected_range.start_line.saturating_sub(1), 0);
                    let end_point =
                        snapshot.clip_point(Point::new(selected_range.end_line, 0), Bias::Left);
                    let range =
                        snapshot.anchor_after(start_point)..snapshot.anchor_before(end_point);

                    related_excerpts_by_buffer
                        .entry(buffer.clone())
                        .or_default()
                        .push(range);
                } else {
                    log::warn!(
                        "selected path that wasn't included in search results: {}",
                        selected_range.path.display()
                    );
                }
            }

            anyhow::Ok(related_excerpts_by_buffer)
        })
        .await
    })
}

async fn request_tool_call<T: JsonSchema>(
    messages: Vec<LanguageModelRequestMessage>,
    tool_name: &'static str,
    model: &Arc<dyn LanguageModel>,
    cx: &mut AsyncApp,
) -> Result<BoxStream<'static, Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>>
{
    let schema = schemars::schema_for!(T);

    let request = LanguageModelRequest {
        messages,
        tools: vec![LanguageModelRequestTool {
            name: tool_name.into(),
            description: schema
                .get("description")
                .and_then(|description| description.as_str())
                .unwrap()
                .to_string(),
            input_schema: serde_json::to_value(schema).unwrap(),
        }],
        ..Default::default()
    };

    Ok(model.stream_completion(request, cx).await?)
}

const MIN_EXCERPT_LEN: usize = 16;
const MAX_EXCERPT_LEN: usize = 768;
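// Cap on the total excerpt bytes a single query may contribute (about five
// max-size excerpts' worth).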
const MAX_RESULT_BYTES_PER_QUERY: usize = MAX_EXCERPT_LEN * 5;

struct MatchedBuffer {
    snapshot: BufferSnapshot,
    line_ranges: Vec<Range<Line>>,
    full_path: PathBuf,
}

async fn run_query(
    glob: &str,
    regex: &str,
    results_tx: UnboundedSender<(Entity<Buffer>, MatchedBuffer)>,
    path_style: PathStyle,
    exclude_matcher: PathMatcher,
    project: &Entity<Project>,
    cx: &mut AsyncApp,
) -> Result<()> {
    let include_matcher = PathMatcher::new(vec![glob], path_style)?;

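    // The positional booleans map to `SearchQuery::regex` options (e.g.
    // whole-word and case-sensitivity flags); see its signature for what each
    // position controls.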
    let query = SearchQuery::regex(
        regex,
        false,
        true,
        false,
        true,
        include_matcher,
        exclude_matcher,
        true,
        None,
    )?;

    let results = project.update(cx, |project, cx| project.search(query, cx))?;
    futures::pin_mut!(results);

    let mut total_bytes = 0;

    while let Some(SearchResult::Buffer { buffer, ranges }) = results.next().await {
        if ranges.is_empty() {
            continue;
        }

        let Some((snapshot, full_path)) = buffer.read_with(cx, |buffer, cx| {
            Some((buffer.snapshot(), buffer.file()?.full_path(cx)))
        })?
        else {
            continue;
        };

        let results_tx = results_tx.clone();
        cx.background_spawn(async move {
            let mut line_ranges = Vec::with_capacity(ranges.len());

            for range in ranges {
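                // Expand each match into an excerpt centered on the match's
                // midpoint, stopping once the per-query byte budget is spent.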
                let offset_range = range.to_offset(&snapshot);
                let query_point = (offset_range.start + offset_range.len() / 2).to_point(&snapshot);

                if total_bytes + MIN_EXCERPT_LEN >= MAX_RESULT_BYTES_PER_QUERY {
                    break;
                }

                let excerpt = EditPredictionExcerpt::select_from_buffer(
                    query_point,
                    &snapshot,
                    &EditPredictionExcerptOptions {
                        max_bytes: MAX_EXCERPT_LEN.min(MAX_RESULT_BYTES_PER_QUERY - total_bytes),
                        min_bytes: MIN_EXCERPT_LEN,
                        target_before_cursor_over_total_bytes: 0.5,
                    },
                    None,
                );

                if let Some(excerpt) = excerpt {
                    total_bytes += excerpt.range.len();
                    if !excerpt.line_range.is_empty() {
                        line_ranges.push(excerpt.line_range);
                    }
                }
            }

            results_tx
                .unbounded_send((
                    buffer,
                    MatchedBuffer {
                        snapshot,
                        line_ranges,
                        full_path,
                    },
                ))
                .log_err();
        })
        .detach();
    }

    anyhow::Ok(())
}