1use std::{
  2    cmp::Reverse, collections::hash_map::Entry, ops::Range, path::PathBuf, sync::Arc, time::Instant,
  3};
  4
  5use crate::{
  6    ZetaContextRetrievalDebugInfo, ZetaContextRetrievalStartedDebugInfo, ZetaDebugInfo,
  7    ZetaSearchQueryDebugInfo, merge_excerpts::merge_excerpts,
  8};
  9use anyhow::{Result, anyhow};
 10use cloud_zeta2_prompt::write_codeblock;
 11use collections::HashMap;
 12use edit_prediction_context::{EditPredictionExcerpt, EditPredictionExcerptOptions, Line};
 13use futures::{
 14    StreamExt,
 15    channel::mpsc::{self, UnboundedSender},
 16    stream::BoxStream,
 17};
 18use gpui::{App, AppContext, AsyncApp, Entity, Task};
 19use indoc::indoc;
 20use language::{
 21    Anchor, Bias, Buffer, BufferSnapshot, OffsetRangeExt, Point, TextBufferSnapshot, ToPoint as _,
 22};
 23use language_model::{
 24    LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent, LanguageModelId,
 25    LanguageModelProviderId, LanguageModelRegistry, LanguageModelRequest,
 26    LanguageModelRequestMessage, LanguageModelRequestTool, LanguageModelToolResult,
 27    LanguageModelToolUse, MessageContent, Role,
 28};
 29use project::{
 30    Project, WorktreeSettings,
 31    search::{SearchQuery, SearchResult},
 32};
 33use schemars::JsonSchema;
 34use serde::{Deserialize, Serialize};
 35use util::{
 36    ResultExt as _,
 37    paths::{PathMatcher, PathStyle},
 38};
 39use workspace::item::Settings as _;
 40
/// First-turn user prompt for the context model.
///
/// `{edits}`, `{current_file_path}` and `{cursor_excerpt}` are placeholders that
/// `find_related_excerpts` fills in before sending. The prompt ends by asking the model to call
/// the `search` tool, whose results are fed back in the second turn.
const SEARCH_PROMPT: &str = indoc! {r#"
    ## Task

    You are part of an edit prediction system in a code editor. Your role is to identify relevant code locations
    that will serve as context for predicting the next required edit.

    **Your task:**
    - Analyze the user's recent edits and current cursor context
    - Use the `search` tool to find code that may be relevant for predicting the next edit
    - Focus on finding:
       - Code patterns that might need similar changes based on the recent edits
       - Functions, variables, types, and constants referenced in the current cursor context
       - Related implementations, usages, or dependencies that may require consistent updates

    **Important constraints:**
    - This conversation has exactly 2 turns
    - You must make ALL search queries in your first response via the `search` tool
    - All queries will be executed in parallel and results returned together
    - In the second turn, you will select the most relevant results via the `select` tool.

    ## User Edits

    {edits}

    ## Current cursor context

    `````{current_file_path}
    {cursor_excerpt}
    `````

    --
    Use the `search` tool now
"#};
 74
/// Name under which [`SearchToolInput`] is offered to the context model as a tool.
const SEARCH_TOOL_NAME: &str = "search";

// NOTE: the `///` comments on this type and its fields are not only documentation. They are
// derived into the JSON schema sent to the model, and `request_tool_call` uses the type-level
// comment as the `search` tool's description (it fails if it is missing). Treat them as prompt text.
/// Search for relevant code
///
/// For the best results, run multiple queries at once with a single invocation of this tool.
#[derive(Clone, Deserialize, Serialize, JsonSchema)]
pub struct SearchToolInput {
    /// An array of queries to run for gathering context relevant to the next prediction
    // Schema-level limit advertised to the model.
    #[schemars(length(max = 5))]
    pub queries: Box<[SearchToolQuery]>,
}
 86
// The `///` comments on these fields are embedded in the `search` tool's JSON schema and are read
// by the context model; edit them as prompt text. This type is also reused for debug output and
// for the combined query list when the model issues several search calls.
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
pub struct SearchToolQuery {
    /// A glob pattern to match file paths in the codebase
    pub glob: String,
    /// A regular expression to match content within the files matched by the glob pattern
    // Queries sharing the same glob are OR-ed together (joined with `|`) before they are run.
    pub regex: String,
}
 94
/// Header of the `search` tool result; the per-file codeblocks with the matched excerpts are
/// appended after it.
const RESULTS_MESSAGE: &str = indoc! {"
    Here are the results of your queries combined and grouped by file:

"};

/// Name under which [`SelectToolInput`] is offered to the context model as a tool.
const SELECT_TOOL_NAME: &str = "select";

/// Second-turn user message asking the model to pick line ranges via the `select` tool.
const SELECT_PROMPT: &str = indoc! {"
    Use the `select` tool now to pick the most relevant line ranges according to the user state provided in the first message.
    Make sure to include enough lines of context so that the edit prediction model can suggest accurate edits.
    Include up to 200 lines in total.
"};
107
// NOTE: the type-level `///` comment below is used by `request_tool_call` as the `select` tool's
// description, and the field comment is embedded in its JSON schema. Both are read by the model.
/// Select line ranges from search results
#[derive(Deserialize, JsonSchema)]
struct SelectToolInput {
    /// The line ranges to select from search results.
    ranges: Vec<SelectLineRange>,
}
114
115/// A specific line range to select from a file
116#[derive(Debug, Deserialize, JsonSchema)]
117struct SelectLineRange {
118    /// The file path containing the lines to select
119    /// Exactly as it appears in the search result codeblocks.
120    path: PathBuf,
121    /// The starting line number (1-based)
122    #[schemars(range(min = 1))]
123    start_line: u32,
124    /// The ending line number (1-based, inclusive)
125    #[schemars(range(min = 1))]
126    end_line: u32,
127}
128
129#[derive(Debug, Clone, PartialEq)]
130pub struct LlmContextOptions {
131    pub excerpt: EditPredictionExcerptOptions,
132}
133
134pub const MODEL_PROVIDER_ID: LanguageModelProviderId = language_model::ANTHROPIC_PROVIDER_ID;
135
136pub fn find_related_excerpts(
137    buffer: Entity<language::Buffer>,
138    cursor_position: Anchor,
139    project: &Entity<Project>,
140    mut edit_history_unified_diff: String,
141    options: &LlmContextOptions,
142    debug_tx: Option<mpsc::UnboundedSender<ZetaDebugInfo>>,
143    cx: &App,
144) -> Task<Result<HashMap<Entity<Buffer>, Vec<Range<Anchor>>>>> {
145    let language_model_registry = LanguageModelRegistry::global(cx);
146    let Some(model) = language_model_registry
147        .read(cx)
148        .available_models(cx)
149        .find(|model| {
150            model.provider_id() == MODEL_PROVIDER_ID
151                && model.id() == LanguageModelId("claude-haiku-4-5-latest".into())
152        })
153    else {
154        return Task::ready(Err(anyhow!("could not find context model")));
155    };
156
157    if edit_history_unified_diff.is_empty() {
158        edit_history_unified_diff.push_str("(No user edits yet)");
159    }
160
161    // TODO [zeta2] include breadcrumbs?
162    let snapshot = buffer.read(cx).snapshot();
163    let cursor_point = cursor_position.to_point(&snapshot);
164    let Some(cursor_excerpt) =
165        EditPredictionExcerpt::select_from_buffer(cursor_point, &snapshot, &options.excerpt, None)
166    else {
167        return Task::ready(Ok(HashMap::default()));
168    };
169
170    let current_file_path = snapshot
171        .file()
172        .map(|f| f.full_path(cx).display().to_string())
173        .unwrap_or_else(|| "untitled".to_string());
174
175    let prompt = SEARCH_PROMPT
176        .replace("{edits}", &edit_history_unified_diff)
177        .replace("{current_file_path}", ¤t_file_path)
178        .replace("{cursor_excerpt}", &cursor_excerpt.text(&snapshot).body);
179
180    if let Some(debug_tx) = &debug_tx {
181        debug_tx
182            .unbounded_send(ZetaDebugInfo::ContextRetrievalStarted(
183                ZetaContextRetrievalStartedDebugInfo {
184                    project: project.clone(),
185                    timestamp: Instant::now(),
186                    search_prompt: prompt.clone(),
187                },
188            ))
189            .ok();
190    }
191
192    let path_style = project.read(cx).path_style(cx);
193
194    let exclude_matcher = {
195        let global_settings = WorktreeSettings::get_global(cx);
196        let exclude_patterns = global_settings
197            .file_scan_exclusions
198            .sources()
199            .iter()
200            .chain(global_settings.private_files.sources().iter());
201
202        match PathMatcher::new(exclude_patterns, path_style) {
203            Ok(matcher) => matcher,
204            Err(err) => {
205                return Task::ready(Err(anyhow!(err)));
206            }
207        }
208    };
209
210    let project = project.clone();
211    cx.spawn(async move |cx| {
212        let initial_prompt_message = LanguageModelRequestMessage {
213            role: Role::User,
214            content: vec![prompt.into()],
215            cache: false,
216        };
217
218        let mut search_stream = request_tool_call::<SearchToolInput>(
219            vec![initial_prompt_message.clone()],
220            SEARCH_TOOL_NAME,
221            &model,
222            cx,
223        )
224        .await?;
225
226        let mut select_request_messages = Vec::with_capacity(5); // initial prompt, LLM response/thinking, tool use, tool result, select prompt
227        select_request_messages.push(initial_prompt_message);
228
229        let mut regex_by_glob: HashMap<String, String> = HashMap::default();
230        let mut search_calls = Vec::new();
231
232        while let Some(event) = search_stream.next().await {
233            match event? {
234                LanguageModelCompletionEvent::ToolUse(tool_use) => {
235                    if !tool_use.is_input_complete {
236                        continue;
237                    }
238
239                    if tool_use.name.as_ref() == SEARCH_TOOL_NAME {
240                        let input =
241                            serde_json::from_value::<SearchToolInput>(tool_use.input.clone())?;
242
243                        for query in input.queries {
244                            let regex = regex_by_glob.entry(query.glob).or_default();
245                            if !regex.is_empty() {
246                                regex.push('|');
247                            }
248                            regex.push_str(&query.regex);
249                        }
250
251                        search_calls.push(tool_use);
252                    } else {
253                        log::warn!(
254                            "context gathering model tried to use unknown tool: {}",
255                            tool_use.name
256                        );
257                    }
258                }
259                LanguageModelCompletionEvent::Text(txt) => {
260                    if let Some(LanguageModelRequestMessage {
261                        role: Role::Assistant,
262                        content,
263                        ..
264                    }) = select_request_messages.last_mut()
265                    {
266                        if let Some(MessageContent::Text(existing_text)) = content.last_mut() {
267                            existing_text.push_str(&txt);
268                        } else {
269                            content.push(MessageContent::Text(txt));
270                        }
271                    } else {
272                        select_request_messages.push(LanguageModelRequestMessage {
273                            role: Role::Assistant,
274                            content: vec![MessageContent::Text(txt)],
275                            cache: false,
276                        });
277                    }
278                }
279                LanguageModelCompletionEvent::Thinking { text, signature } => {
280                    if let Some(LanguageModelRequestMessage {
281                        role: Role::Assistant,
282                        content,
283                        ..
284                    }) = select_request_messages.last_mut()
285                    {
286                        if let Some(MessageContent::Thinking {
287                            text: existing_text,
288                            signature: existing_signature,
289                        }) = content.last_mut()
290                        {
291                            existing_text.push_str(&text);
292                            *existing_signature = signature;
293                        } else {
294                            content.push(MessageContent::Thinking { text, signature });
295                        }
296                    } else {
297                        select_request_messages.push(LanguageModelRequestMessage {
298                            role: Role::Assistant,
299                            content: vec![MessageContent::Thinking { text, signature }],
300                            cache: false,
301                        });
302                    }
303                }
304                LanguageModelCompletionEvent::RedactedThinking { data } => {
305                    if let Some(LanguageModelRequestMessage {
306                        role: Role::Assistant,
307                        content,
308                        ..
309                    }) = select_request_messages.last_mut()
310                    {
311                        if let Some(MessageContent::RedactedThinking(existing_data)) =
312                            content.last_mut()
313                        {
314                            existing_data.push_str(&data);
315                        } else {
316                            content.push(MessageContent::RedactedThinking(data));
317                        }
318                    } else {
319                        select_request_messages.push(LanguageModelRequestMessage {
320                            role: Role::Assistant,
321                            content: vec![MessageContent::RedactedThinking(data)],
322                            cache: false,
323                        });
324                    }
325                }
326                ev @ LanguageModelCompletionEvent::ToolUseJsonParseError { .. } => {
327                    log::error!("{ev:?}");
328                }
329                ev => {
330                    log::trace!("context search event: {ev:?}")
331                }
332            }
333        }
334
335        let search_tool_use = if search_calls.is_empty() {
336            log::warn!("context model ran 0 searches");
337            return anyhow::Ok(Default::default());
338        } else if search_calls.len() == 1 {
339            search_calls.swap_remove(0)
340        } else {
341            // In theory, the model could perform multiple search calls
342            // Dealing with them separately is not worth it when it doesn't happen in practice.
343            // If it were to happen, here we would combine them into one.
344            // The second request doesn't need to know it was actually two different calls ;)
345            let input = serde_json::to_value(&SearchToolInput {
346                queries: regex_by_glob
347                    .iter()
348                    .map(|(glob, regex)| SearchToolQuery {
349                        glob: glob.clone(),
350                        regex: regex.clone(),
351                    })
352                    .collect(),
353            })
354            .unwrap_or_default();
355
356            LanguageModelToolUse {
357                id: search_calls.swap_remove(0).id,
358                name: SELECT_TOOL_NAME.into(),
359                raw_input: serde_json::to_string(&input).unwrap_or_default(),
360                input,
361                is_input_complete: true,
362            }
363        };
364
365        if let Some(debug_tx) = &debug_tx {
366            debug_tx
367                .unbounded_send(ZetaDebugInfo::SearchQueriesGenerated(
368                    ZetaSearchQueryDebugInfo {
369                        project: project.clone(),
370                        timestamp: Instant::now(),
371                        queries: regex_by_glob
372                            .iter()
373                            .map(|(glob, regex)| SearchToolQuery {
374                                glob: glob.clone(),
375                                regex: regex.clone(),
376                            })
377                            .collect(),
378                    },
379                ))
380                .ok();
381        }
382
383        let (results_tx, mut results_rx) = mpsc::unbounded();
384
385        for (glob, regex) in regex_by_glob {
386            let exclude_matcher = exclude_matcher.clone();
387            let results_tx = results_tx.clone();
388            let project = project.clone();
389            cx.spawn(async move |cx| {
390                run_query(
391                    &glob,
392                    ®ex,
393                    results_tx.clone(),
394                    path_style,
395                    exclude_matcher,
396                    &project,
397                    cx,
398                )
399                .await
400                .log_err();
401            })
402            .detach()
403        }
404        drop(results_tx);
405
406        struct ResultBuffer {
407            buffer: Entity<Buffer>,
408            snapshot: TextBufferSnapshot,
409        }
410
411        let (result_buffers_by_path, merged_result) = cx
412            .background_spawn(async move {
413                let mut excerpts_by_buffer: HashMap<Entity<Buffer>, MatchedBuffer> =
414                    HashMap::default();
415
416                while let Some((buffer, matched)) = results_rx.next().await {
417                    match excerpts_by_buffer.entry(buffer) {
418                        Entry::Occupied(mut entry) => {
419                            let entry = entry.get_mut();
420                            entry.full_path = matched.full_path;
421                            entry.snapshot = matched.snapshot;
422                            entry.line_ranges.extend(matched.line_ranges);
423                        }
424                        Entry::Vacant(entry) => {
425                            entry.insert(matched);
426                        }
427                    }
428                }
429
430                let mut result_buffers_by_path = HashMap::default();
431                let mut merged_result = RESULTS_MESSAGE.to_string();
432
433                for (buffer, mut matched) in excerpts_by_buffer {
434                    matched
435                        .line_ranges
436                        .sort_unstable_by_key(|range| (range.start, Reverse(range.end)));
437
438                    write_codeblock(
439                        &matched.full_path,
440                        merge_excerpts(&matched.snapshot, matched.line_ranges).iter(),
441                        &[],
442                        Line(matched.snapshot.max_point().row),
443                        true,
444                        &mut merged_result,
445                    );
446
447                    result_buffers_by_path.insert(
448                        matched.full_path,
449                        ResultBuffer {
450                            buffer,
451                            snapshot: matched.snapshot.text,
452                        },
453                    );
454                }
455
456                (result_buffers_by_path, merged_result)
457            })
458            .await;
459
460        if let Some(debug_tx) = &debug_tx {
461            debug_tx
462                .unbounded_send(ZetaDebugInfo::SearchQueriesExecuted(
463                    ZetaContextRetrievalDebugInfo {
464                        project: project.clone(),
465                        timestamp: Instant::now(),
466                    },
467                ))
468                .ok();
469        }
470
471        let tool_result = LanguageModelToolResult {
472            tool_use_id: search_tool_use.id.clone(),
473            tool_name: SEARCH_TOOL_NAME.into(),
474            is_error: false,
475            content: merged_result.into(),
476            output: None,
477        };
478
479        select_request_messages.extend([
480            LanguageModelRequestMessage {
481                role: Role::Assistant,
482                content: vec![MessageContent::ToolUse(search_tool_use)],
483                cache: false,
484            },
485            LanguageModelRequestMessage {
486                role: Role::User,
487                content: vec![MessageContent::ToolResult(tool_result)],
488                cache: false,
489            },
490        ]);
491
492        if result_buffers_by_path.is_empty() {
493            log::trace!("context gathering queries produced no results");
494            return anyhow::Ok(HashMap::default());
495        }
496
497        select_request_messages.push(LanguageModelRequestMessage {
498            role: Role::User,
499            content: vec![SELECT_PROMPT.into()],
500            cache: false,
501        });
502
503        let mut select_stream = request_tool_call::<SelectToolInput>(
504            select_request_messages,
505            SELECT_TOOL_NAME,
506            &model,
507            cx,
508        )
509        .await?;
510
511        cx.background_spawn(async move {
512            let mut selected_ranges = Vec::new();
513
514            while let Some(event) = select_stream.next().await {
515                match event? {
516                    LanguageModelCompletionEvent::ToolUse(tool_use) => {
517                        if !tool_use.is_input_complete {
518                            continue;
519                        }
520
521                        if tool_use.name.as_ref() == SELECT_TOOL_NAME {
522                            let call =
523                                serde_json::from_value::<SelectToolInput>(tool_use.input.clone())?;
524                            selected_ranges.extend(call.ranges);
525                        } else {
526                            log::warn!(
527                                "context gathering model tried to use unknown tool: {}",
528                                tool_use.name
529                            );
530                        }
531                    }
532                    ev @ LanguageModelCompletionEvent::ToolUseJsonParseError { .. } => {
533                        log::error!("{ev:?}");
534                    }
535                    ev => {
536                        log::trace!("context select event: {ev:?}")
537                    }
538                }
539            }
540
541            if let Some(debug_tx) = &debug_tx {
542                debug_tx
543                    .unbounded_send(ZetaDebugInfo::SearchResultsFiltered(
544                        ZetaContextRetrievalDebugInfo {
545                            project: project.clone(),
546                            timestamp: Instant::now(),
547                        },
548                    ))
549                    .ok();
550            }
551
552            if selected_ranges.is_empty() {
553                log::trace!("context gathering selected no ranges")
554            }
555
556            selected_ranges.sort_unstable_by(|a, b| {
557                a.start_line
558                    .cmp(&b.start_line)
559                    .then(b.end_line.cmp(&a.end_line))
560            });
561
562            let mut related_excerpts_by_buffer: HashMap<_, Vec<_>> = HashMap::default();
563
564            for selected_range in selected_ranges {
565                if let Some(ResultBuffer { buffer, snapshot }) =
566                    result_buffers_by_path.get(&selected_range.path)
567                {
568                    let start_point = Point::new(selected_range.start_line.saturating_sub(1), 0);
569                    let end_point =
570                        snapshot.clip_point(Point::new(selected_range.end_line, 0), Bias::Left);
571                    let range =
572                        snapshot.anchor_after(start_point)..snapshot.anchor_before(end_point);
573
574                    related_excerpts_by_buffer
575                        .entry(buffer.clone())
576                        .or_default()
577                        .push(range);
578                } else {
579                    log::warn!(
580                        "selected path that wasn't included in search results: {}",
581                        selected_range.path.display()
582                    );
583                }
584            }
585
586            anyhow::Ok(related_excerpts_by_buffer)
587        })
588        .await
589    })
590}
591
592async fn request_tool_call<T: JsonSchema>(
593    messages: Vec<LanguageModelRequestMessage>,
594    tool_name: &'static str,
595    model: &Arc<dyn LanguageModel>,
596    cx: &mut AsyncApp,
597) -> Result<BoxStream<'static, Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>>
598{
599    let schema = schemars::schema_for!(T);
600
601    let request = LanguageModelRequest {
602        messages,
603        tools: vec![LanguageModelRequestTool {
604            name: tool_name.into(),
605            description: schema
606                .get("description")
607                .and_then(|description| description.as_str())
608                .unwrap()
609                .to_string(),
610            input_schema: serde_json::to_value(schema).unwrap(),
611        }],
612        ..Default::default()
613    };
614
615    Ok(model.stream_completion(request, cx).await?)
616}
617
/// Smallest excerpt, in bytes, requested around a single search match.
const MIN_EXCERPT_LEN: usize = 16;
/// Largest excerpt, in bytes, requested around a single search match.
const MAX_EXCERPT_LEN: usize = 768;
/// Byte budget for the excerpts of one search query's results: room for five maximal excerpts.
const MAX_RESULT_BYTES_PER_QUERY: usize = 5 * MAX_EXCERPT_LEN;
621
622struct MatchedBuffer {
623    snapshot: BufferSnapshot,
624    line_ranges: Vec<Range<Line>>,
625    full_path: PathBuf,
626}
627
628async fn run_query(
629    glob: &str,
630    regex: &str,
631    results_tx: UnboundedSender<(Entity<Buffer>, MatchedBuffer)>,
632    path_style: PathStyle,
633    exclude_matcher: PathMatcher,
634    project: &Entity<Project>,
635    cx: &mut AsyncApp,
636) -> Result<()> {
637    let include_matcher = PathMatcher::new(vec![glob], path_style)?;
638
639    let query = SearchQuery::regex(
640        regex,
641        false,
642        true,
643        false,
644        true,
645        include_matcher,
646        exclude_matcher,
647        true,
648        None,
649    )?;
650
651    let results = project.update(cx, |project, cx| project.search(query, cx))?;
652    futures::pin_mut!(results);
653
654    let mut total_bytes = 0;
655
656    while let Some(SearchResult::Buffer { buffer, ranges }) = results.next().await {
657        if ranges.is_empty() {
658            continue;
659        }
660
661        let Some((snapshot, full_path)) = buffer.read_with(cx, |buffer, cx| {
662            Some((buffer.snapshot(), buffer.file()?.full_path(cx)))
663        })?
664        else {
665            continue;
666        };
667
668        let results_tx = results_tx.clone();
669        cx.background_spawn(async move {
670            let mut line_ranges = Vec::with_capacity(ranges.len());
671
672            for range in ranges {
673                let offset_range = range.to_offset(&snapshot);
674                let query_point = (offset_range.start + offset_range.len() / 2).to_point(&snapshot);
675
676                if total_bytes + MIN_EXCERPT_LEN >= MAX_RESULT_BYTES_PER_QUERY {
677                    break;
678                }
679
680                let excerpt = EditPredictionExcerpt::select_from_buffer(
681                    query_point,
682                    &snapshot,
683                    &EditPredictionExcerptOptions {
684                        max_bytes: MAX_EXCERPT_LEN.min(MAX_RESULT_BYTES_PER_QUERY - total_bytes),
685                        min_bytes: MIN_EXCERPT_LEN,
686                        target_before_cursor_over_total_bytes: 0.5,
687                    },
688                    None,
689                );
690
691                if let Some(excerpt) = excerpt {
692                    total_bytes += excerpt.range.len();
693                    if !excerpt.line_range.is_empty() {
694                        line_ranges.push(excerpt.line_range);
695                    }
696                }
697            }
698
699            results_tx
700                .unbounded_send((
701                    buffer,
702                    MatchedBuffer {
703                        snapshot,
704                        line_ranges,
705                        full_path,
706                    },
707                ))
708                .log_err();
709        })
710        .detach();
711    }
712
713    anyhow::Ok(())
714}