use std::{
    cmp::Reverse, collections::hash_map::Entry, ops::Range, path::PathBuf, sync::Arc, time::Instant,
};

use crate::{
    ZetaContextRetrievalDebugInfo, ZetaContextRetrievalStartedDebugInfo, ZetaDebugInfo,
    ZetaSearchQueryDebugInfo, merge_excerpts::merge_excerpts,
};
use anyhow::{Result, anyhow};
use cloud_zeta2_prompt::write_codeblock;
use collections::HashMap;
use edit_prediction_context::{EditPredictionExcerpt, EditPredictionExcerptOptions, Line};
use futures::{
    StreamExt,
    channel::mpsc::{self, UnboundedSender},
    stream::BoxStream,
};
use gpui::{App, AppContext, AsyncApp, Entity, Task};
use indoc::indoc;
use language::{
    Anchor, Bias, Buffer, BufferSnapshot, OffsetRangeExt, Point, TextBufferSnapshot, ToPoint as _,
};
use language_model::{
    LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent, LanguageModelId,
    LanguageModelProviderId, LanguageModelRegistry, LanguageModelRequest,
    LanguageModelRequestMessage, LanguageModelRequestTool, LanguageModelToolResult,
    LanguageModelToolUse, MessageContent, Role,
};
use project::{
    Project, WorktreeSettings,
    search::{SearchQuery, SearchResult},
};
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
use util::{
    ResultExt as _,
    paths::{PathMatcher, PathStyle},
};
use workspace::item::Settings as _;

const SEARCH_PROMPT: &str = indoc! {r#"
    ## Task

    You are part of an edit prediction system in a code editor. Your role is to identify relevant code locations
    that will serve as context for predicting the next required edit.

    **Your task:**
    - Analyze the user's recent edits and current cursor context
    - Use the `search` tool to find code that may be relevant for predicting the next edit
    - Focus on finding:
       - Code patterns that might need similar changes based on the recent edits
       - Functions, variables, types, and constants referenced in the current cursor context
       - Related implementations, usages, or dependencies that may require consistent updates

    **Important constraints:**
    - This conversation has exactly 2 turns
    - You must make ALL search queries in your first response via the `search` tool
    - All queries will be executed in parallel and results returned together
    - In the second turn, you will select the most relevant results via the `select` tool.

    ## User Edits

    {edits}

    ## Current cursor context

    `````{current_file_path}
    {cursor_excerpt}
    `````

    --
    Use the `search` tool now
"#};

const SEARCH_TOOL_NAME: &str = "search";

/// Search for relevant code
///
/// For the best results, run multiple queries at once with a single invocation of this tool.
#[derive(Clone, Deserialize, Serialize, JsonSchema)]
pub struct SearchToolInput {
    /// An array of queries to run for gathering context relevant to the next prediction
    #[schemars(length(max = 5))]
    pub queries: Box<[SearchToolQuery]>,
}

#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
pub struct SearchToolQuery {
    /// A glob pattern to match file paths in the codebase
    pub glob: String,
    /// A regular expression to match content within the files matched by the glob pattern
    pub regex: String,
}

const RESULTS_MESSAGE: &str = indoc! {"
    Here are the results of your queries combined and grouped by file:

"};

const SELECT_TOOL_NAME: &str = "select";

const SELECT_PROMPT: &str = indoc! {"
    Use the `select` tool now to pick the most relevant line ranges according to the user state provided in the first message.
    Make sure to include enough lines of context so that the edit prediction model can suggest accurate edits.
    Include up to 200 lines in total.
"};

/// Select line ranges from search results
#[derive(Deserialize, JsonSchema)]
struct SelectToolInput {
    /// The line ranges to select from search results.
    ranges: Vec<SelectLineRange>,
}

/// A specific line range to select from a file
#[derive(Debug, Deserialize, JsonSchema)]
struct SelectLineRange {
    /// The file path containing the lines to select,
    /// exactly as it appears in the search result codeblocks.
    path: PathBuf,
    /// The starting line number (1-based)
    #[schemars(range(min = 1))]
    start_line: u32,
    /// The ending line number (1-based, inclusive)
    #[schemars(range(min = 1))]
    end_line: u32,
}

#[derive(Debug, Clone, PartialEq)]
pub struct LlmContextOptions {
    pub excerpt: EditPredictionExcerptOptions,
}

pub const MODEL_PROVIDER_ID: LanguageModelProviderId = language_model::ANTHROPIC_PROVIDER_ID;

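/// Gathers excerpts related to the cursor position in two model turns: the
/// context model first issues regex searches over the project via the
/// `search` tool, then picks the most relevant line ranges from the combined
/// results via the `select` tool.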
pub fn find_related_excerpts(
    buffer: Entity<language::Buffer>,
    cursor_position: Anchor,
    project: &Entity<Project>,
    mut edit_history_unified_diff: String,
    options: &LlmContextOptions,
    debug_tx: Option<mpsc::UnboundedSender<ZetaDebugInfo>>,
    cx: &App,
) -> Task<Result<HashMap<Entity<Buffer>, Vec<Range<Anchor>>>>> {
    let language_model_registry = LanguageModelRegistry::global(cx);
    let Some(model) = language_model_registry
        .read(cx)
        .available_models(cx)
        .find(|model| {
            model.provider_id() == MODEL_PROVIDER_ID
                && model.id() == LanguageModelId("claude-haiku-4-5-latest".into())
            // model.provider_id() == LanguageModelProviderId::new("zeta-ctx-qwen-30b")
            // model.provider_id() == LanguageModelProviderId::new("ollama")
            //     && model.id() == LanguageModelId("gpt-oss:20b".into())
        })
    else {
        return Task::ready(Err(anyhow!("could not find context model")));
    };

    if edit_history_unified_diff.is_empty() {
        edit_history_unified_diff.push_str("(No user edits yet)");
    }

    // TODO [zeta2] include breadcrumbs?
    let snapshot = buffer.read(cx).snapshot();
    let cursor_point = cursor_position.to_point(&snapshot);
    let Some(cursor_excerpt) =
        EditPredictionExcerpt::select_from_buffer(cursor_point, &snapshot, &options.excerpt, None)
    else {
        return Task::ready(Ok(HashMap::default()));
    };

    let current_file_path = snapshot
        .file()
        .map(|f| f.full_path(cx).display().to_string())
        .unwrap_or_else(|| "untitled".to_string());

    let prompt = SEARCH_PROMPT
        .replace("{edits}", &edit_history_unified_diff)
        .replace("{current_file_path}", &current_file_path)
        .replace("{cursor_excerpt}", &cursor_excerpt.text(&snapshot).body);

    if let Some(debug_tx) = &debug_tx {
        debug_tx
            .unbounded_send(ZetaDebugInfo::ContextRetrievalStarted(
                ZetaContextRetrievalStartedDebugInfo {
                    project: project.clone(),
                    timestamp: Instant::now(),
                    search_prompt: prompt.clone(),
                },
            ))
            .ok();
    }

    let path_style = project.read(cx).path_style(cx);

    let exclude_matcher = {
        let global_settings = WorktreeSettings::get_global(cx);
        let exclude_patterns = global_settings
            .file_scan_exclusions
            .sources()
            .iter()
            .chain(global_settings.private_files.sources().iter());

        match PathMatcher::new(exclude_patterns, path_style) {
            Ok(matcher) => matcher,
            Err(err) => {
                return Task::ready(Err(anyhow!(err)));
            }
        }
    };

    let project = project.clone();
    cx.spawn(async move |cx| {
        let initial_prompt_message = LanguageModelRequestMessage {
            role: Role::User,
            content: vec![prompt.into()],
            cache: false,
        };

        let mut search_stream = request_tool_call::<SearchToolInput>(
            vec![initial_prompt_message.clone()],
            SEARCH_TOOL_NAME,
            &model,
            cx,
        )
        .await?;

        // initial prompt, LLM response/thinking, tool use, tool result, select prompt
        let mut select_request_messages = Vec::with_capacity(5);
        select_request_messages.push(initial_prompt_message);

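        // Merge all requested queries per glob by OR-ing their regexes, so we
        // run at most one project search per glob.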
        let mut regex_by_glob: HashMap<String, String> = HashMap::default();
        let mut search_calls = Vec::new();

        while let Some(event) = search_stream.next().await {
            match event? {
                LanguageModelCompletionEvent::ToolUse(tool_use) => {
                    if !tool_use.is_input_complete {
                        continue;
                    }

                    if tool_use.name.as_ref() == SEARCH_TOOL_NAME {
                        let input =
                            serde_json::from_value::<SearchToolInput>(tool_use.input.clone())?;

                        for query in input.queries {
                            let regex = regex_by_glob.entry(query.glob).or_default();
                            if !regex.is_empty() {
                                regex.push('|');
                            }
                            regex.push_str(&query.regex);
                        }

                        search_calls.push(tool_use);
                    } else {
                        log::warn!(
                            "context gathering model tried to use unknown tool: {}",
                            tool_use.name
                        );
                    }
                }
                LanguageModelCompletionEvent::Text(txt) => {
                    if let Some(LanguageModelRequestMessage {
                        role: Role::Assistant,
                        content,
                        ..
                    }) = select_request_messages.last_mut()
                    {
                        if let Some(MessageContent::Text(existing_text)) = content.last_mut() {
                            existing_text.push_str(&txt);
                        } else {
                            content.push(MessageContent::Text(txt));
                        }
                    } else {
                        select_request_messages.push(LanguageModelRequestMessage {
                            role: Role::Assistant,
                            content: vec![MessageContent::Text(txt)],
                            cache: false,
                        });
                    }
                }
                LanguageModelCompletionEvent::Thinking { text, signature } => {
                    if let Some(LanguageModelRequestMessage {
                        role: Role::Assistant,
                        content,
                        ..
                    }) = select_request_messages.last_mut()
                    {
                        if let Some(MessageContent::Thinking {
                            text: existing_text,
                            signature: existing_signature,
                        }) = content.last_mut()
                        {
                            existing_text.push_str(&text);
                            *existing_signature = signature;
                        } else {
                            content.push(MessageContent::Thinking { text, signature });
                        }
                    } else {
                        select_request_messages.push(LanguageModelRequestMessage {
                            role: Role::Assistant,
                            content: vec![MessageContent::Thinking { text, signature }],
                            cache: false,
                        });
                    }
                }
                LanguageModelCompletionEvent::RedactedThinking { data } => {
                    if let Some(LanguageModelRequestMessage {
                        role: Role::Assistant,
                        content,
                        ..
                    }) = select_request_messages.last_mut()
                    {
                        if let Some(MessageContent::RedactedThinking(existing_data)) =
                            content.last_mut()
                        {
                            existing_data.push_str(&data);
                        } else {
                            content.push(MessageContent::RedactedThinking(data));
                        }
                    } else {
                        select_request_messages.push(LanguageModelRequestMessage {
                            role: Role::Assistant,
                            content: vec![MessageContent::RedactedThinking(data)],
                            cache: false,
                        });
                    }
                }
                ev @ LanguageModelCompletionEvent::ToolUseJsonParseError { .. } => {
                    log::error!("{ev:?}");
                }
                ev => {
                    log::trace!("context search event: {ev:?}")
                }
            }
        }

        let search_tool_use = if search_calls.is_empty() {
            log::warn!("context model ran 0 searches");
            return anyhow::Ok(Default::default());
        } else if search_calls.len() == 1 {
            search_calls.swap_remove(0)
        } else {
            // In theory, the model could make multiple search calls.
            // Handling them separately isn't worth it since that doesn't happen in practice,
            // so if it does, we combine them into a single call here.
            // The second request doesn't need to know they were actually separate calls ;)
            let input = serde_json::to_value(&SearchToolInput {
                queries: regex_by_glob
                    .iter()
                    .map(|(glob, regex)| SearchToolQuery {
                        glob: glob.clone(),
                        regex: regex.clone(),
                    })
                    .collect(),
            })
            .unwrap_or_default();

            LanguageModelToolUse {
                id: search_calls.swap_remove(0).id,
                name: SEARCH_TOOL_NAME.into(),
                raw_input: serde_json::to_string(&input).unwrap_or_default(),
                input,
                is_input_complete: true,
            }
        };

        if let Some(debug_tx) = &debug_tx {
            debug_tx
                .unbounded_send(ZetaDebugInfo::SearchQueriesGenerated(
                    ZetaSearchQueryDebugInfo {
                        project: project.clone(),
                        timestamp: Instant::now(),
                        queries: regex_by_glob
                            .iter()
                            .map(|(glob, regex)| SearchToolQuery {
                                glob: glob.clone(),
                                regex: regex.clone(),
                            })
                            .collect(),
                    },
                ))
                .ok();
        }

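        // Run each merged query concurrently; matches are funneled back
        // through the unbounded channel and collected below.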
        let (results_tx, mut results_rx) = mpsc::unbounded();

        for (glob, regex) in regex_by_glob {
            let exclude_matcher = exclude_matcher.clone();
            let results_tx = results_tx.clone();
            let project = project.clone();
            cx.spawn(async move |cx| {
                run_query(
                    &glob,
                    &regex,
                    results_tx.clone(),
                    path_style,
                    exclude_matcher,
                    &project,
                    cx,
                )
                .await
                .log_err();
            })
            .detach()
        }
        drop(results_tx);

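        /// A buffer that matched at least one query, paired with the text
        /// snapshot that the reported line numbers refer to.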
        struct ResultBuffer {
            buffer: Entity<Buffer>,
            snapshot: TextBufferSnapshot,
        }

        let (result_buffers_by_path, merged_result) = cx
            .background_spawn(async move {
                let mut excerpts_by_buffer: HashMap<Entity<Buffer>, MatchedBuffer> =
                    HashMap::default();

                while let Some((buffer, matched)) = results_rx.next().await {
                    match excerpts_by_buffer.entry(buffer) {
                        Entry::Occupied(mut entry) => {
                            let entry = entry.get_mut();
                            entry.full_path = matched.full_path;
                            entry.snapshot = matched.snapshot;
                            entry.line_ranges.extend(matched.line_ranges);
                        }
                        Entry::Vacant(entry) => {
                            entry.insert(matched);
                        }
                    }
                }

                let mut result_buffers_by_path = HashMap::default();
                let mut merged_result = RESULTS_MESSAGE.to_string();

                for (buffer, mut matched) in excerpts_by_buffer {
                    matched
                        .line_ranges
                        .sort_unstable_by_key(|range| (range.start, Reverse(range.end)));

                    write_codeblock(
                        &matched.full_path,
                        merge_excerpts(&matched.snapshot, matched.line_ranges).iter(),
                        &[],
                        Line(matched.snapshot.max_point().row),
                        true,
                        &mut merged_result,
                    );

                    result_buffers_by_path.insert(
                        matched.full_path,
                        ResultBuffer {
                            buffer,
                            snapshot: matched.snapshot.text,
                        },
                    );
                }

                (result_buffers_by_path, merged_result)
            })
            .await;

        if let Some(debug_tx) = &debug_tx {
            debug_tx
                .unbounded_send(ZetaDebugInfo::SearchQueriesExecuted(
                    ZetaContextRetrievalDebugInfo {
                        project: project.clone(),
                        timestamp: Instant::now(),
                    },
                ))
                .ok();
        }

        let tool_result = LanguageModelToolResult {
            tool_use_id: search_tool_use.id.clone(),
            tool_name: SEARCH_TOOL_NAME.into(),
            is_error: false,
            content: merged_result.into(),
            output: None,
        };

        select_request_messages.extend([
            LanguageModelRequestMessage {
                role: Role::Assistant,
                content: vec![MessageContent::ToolUse(search_tool_use)],
                cache: false,
            },
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec![MessageContent::ToolResult(tool_result)],
                cache: false,
            },
        ]);

        if result_buffers_by_path.is_empty() {
            log::trace!("context gathering queries produced no results");
            return anyhow::Ok(HashMap::default());
        }

        select_request_messages.push(LanguageModelRequestMessage {
            role: Role::User,
            content: vec![SELECT_PROMPT.into()],
            cache: false,
        });

        let mut select_stream = request_tool_call::<SelectToolInput>(
            select_request_messages,
            SELECT_TOOL_NAME,
            &model,
            cx,
        )
        .await?;

        cx.background_spawn(async move {
            let mut selected_ranges = Vec::new();

            while let Some(event) = select_stream.next().await {
                match event? {
                    LanguageModelCompletionEvent::ToolUse(tool_use) => {
                        if !tool_use.is_input_complete {
                            continue;
                        }

                        if tool_use.name.as_ref() == SELECT_TOOL_NAME {
                            let call =
                                serde_json::from_value::<SelectToolInput>(tool_use.input.clone())?;
                            selected_ranges.extend(call.ranges);
                        } else {
                            log::warn!(
                                "context gathering model tried to use unknown tool: {}",
                                tool_use.name
                            );
                        }
                    }
                    ev @ LanguageModelCompletionEvent::ToolUseJsonParseError { .. } => {
                        log::error!("{ev:?}");
                    }
                    ev => {
                        log::trace!("context select event: {ev:?}")
                    }
                }
            }

            if let Some(debug_tx) = &debug_tx {
                debug_tx
                    .unbounded_send(ZetaDebugInfo::SearchResultsFiltered(
                        ZetaContextRetrievalDebugInfo {
                            project: project.clone(),
                            timestamp: Instant::now(),
                        },
                    ))
                    .ok();
            }

            if selected_ranges.is_empty() {
                log::trace!("context gathering selected no ranges")
            }

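            // Order selections by start line, with larger ranges first on
            // ties, so each buffer's excerpt list comes out sorted.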
            selected_ranges.sort_unstable_by(|a, b| {
                a.start_line
                    .cmp(&b.start_line)
                    .then(b.end_line.cmp(&a.end_line))
            });

            let mut related_excerpts_by_buffer: HashMap<_, Vec<_>> = HashMap::default();

            for selected_range in selected_ranges {
                if let Some(ResultBuffer { buffer, snapshot }) =
                    result_buffers_by_path.get(&selected_range.path)
                {
                    // Convert the 1-based inclusive line range to a 0-based
                    // anchor range spanning whole lines.
                    let start_point = Point::new(selected_range.start_line.saturating_sub(1), 0);
                    let end_point =
                        snapshot.clip_point(Point::new(selected_range.end_line, 0), Bias::Left);
                    let range =
                        snapshot.anchor_after(start_point)..snapshot.anchor_before(end_point);

                    related_excerpts_by_buffer
                        .entry(buffer.clone())
                        .or_default()
                        .push(range);
                } else {
                    log::warn!(
                        "selected path that wasn't included in search results: {}",
                        selected_range.path.display()
                    );
                }
            }

            anyhow::Ok(related_excerpts_by_buffer)
        })
        .await
    })
}

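/// Starts a completion request that exposes a single tool derived from `T`'s
/// JSON schema, returning the raw completion event stream.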
async fn request_tool_call<T: JsonSchema>(
    messages: Vec<LanguageModelRequestMessage>,
    tool_name: &'static str,
    model: &Arc<dyn LanguageModel>,
    cx: &mut AsyncApp,
) -> Result<BoxStream<'static, Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>>
{
    let schema = schemars::schema_for!(T);

    let request = LanguageModelRequest {
        messages,
        tools: vec![LanguageModelRequestTool {
            name: tool_name.into(),
            description: schema
                .get("description")
                .and_then(|description| description.as_str())
                .unwrap()
                .to_string(),
            input_schema: serde_json::to_value(schema).unwrap(),
        }],
        ..Default::default()
    };

    Ok(model.stream_completion(request, cx).await?)
}

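// Byte budgets for expanding search matches into excerpts: each match grows
// into an excerpt of between MIN_EXCERPT_LEN and MAX_EXCERPT_LEN bytes, and a
// single query contributes at most MAX_RESULT_BYTES_PER_QUERY bytes in total.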
const MIN_EXCERPT_LEN: usize = 16;
const MAX_EXCERPT_LEN: usize = 768;
const MAX_RESULT_BYTES_PER_QUERY: usize = MAX_EXCERPT_LEN * 5;

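/// Search results from a single buffer: the expanded excerpt line ranges plus
/// the snapshot they were computed against.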
struct MatchedBuffer {
    snapshot: BufferSnapshot,
    line_ranges: Vec<Range<Line>>,
    full_path: PathBuf,
}

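/// Runs one regex query over files matching `glob`, expands each match into a
/// surrounding excerpt, and sends per-buffer results down `results_tx`.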
async fn run_query(
    glob: &str,
    regex: &str,
    results_tx: UnboundedSender<(Entity<Buffer>, MatchedBuffer)>,
    path_style: PathStyle,
    exclude_matcher: PathMatcher,
    project: &Entity<Project>,
    cx: &mut AsyncApp,
) -> Result<()> {
    let include_matcher = PathMatcher::new(vec![glob], path_style)?;

    let query = SearchQuery::regex(
        regex,
        false,
        true,
        false,
        true,
        include_matcher,
        exclude_matcher,
        true,
        None,
    )?;

    let results = project.update(cx, |project, cx| project.search(query, cx))?;
    futures::pin_mut!(results);

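    // Byte budget for this query's excerpts. Note that each background task
    // spawned below captures its own copy of the counter, so the cap applies
    // per matched buffer rather than strictly across the whole query.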
    let mut total_bytes = 0;

    while let Some(SearchResult::Buffer { buffer, ranges }) = results.next().await {
        if ranges.is_empty() {
            continue;
        }

        let Some((snapshot, full_path)) = buffer.read_with(cx, |buffer, cx| {
            Some((buffer.snapshot(), buffer.file()?.full_path(cx)))
        })?
        else {
            continue;
        };

        let results_tx = results_tx.clone();
        cx.background_spawn(async move {
            let mut line_ranges = Vec::with_capacity(ranges.len());

            for range in ranges {
                let offset_range = range.to_offset(&snapshot);
                let query_point = (offset_range.start + offset_range.len() / 2).to_point(&snapshot);

                if total_bytes + MIN_EXCERPT_LEN >= MAX_RESULT_BYTES_PER_QUERY {
                    break;
                }

                let excerpt = EditPredictionExcerpt::select_from_buffer(
                    query_point,
                    &snapshot,
                    &EditPredictionExcerptOptions {
                        max_bytes: MAX_EXCERPT_LEN.min(MAX_RESULT_BYTES_PER_QUERY - total_bytes),
                        min_bytes: MIN_EXCERPT_LEN,
                        target_before_cursor_over_total_bytes: 0.5,
                    },
                    None,
                );

                if let Some(excerpt) = excerpt {
                    total_bytes += excerpt.range.len();
                    if !excerpt.line_range.is_empty() {
                        line_ranges.push(excerpt.line_range);
                    }
                }
            }

            results_tx
                .unbounded_send((
                    buffer,
                    MatchedBuffer {
                        snapshot,
                        line_ranges,
                        full_path,
                    },
                ))
                .log_err();
        })
        .detach();
    }

    anyhow::Ok(())
}