1use std::{cmp::Reverse, fmt::Write, ops::Range, path::PathBuf, sync::Arc, time::Instant};
  2
  3use crate::{
  4    ZetaContextRetrievalDebugInfo, ZetaDebugInfo, ZetaSearchQueryDebugInfo,
  5    merge_excerpts::write_merged_excerpts,
  6};
  7use anyhow::{Result, anyhow};
  8use collections::HashMap;
  9use edit_prediction_context::{EditPredictionExcerpt, EditPredictionExcerptOptions, Line};
 10use futures::{StreamExt, channel::mpsc, stream::BoxStream};
 11use gpui::{App, AsyncApp, Entity, Task};
 12use indoc::indoc;
 13use language::{Anchor, Bias, Buffer, OffsetRangeExt, Point, TextBufferSnapshot, ToPoint as _};
 14use language_model::{
 15    LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent, LanguageModelId,
 16    LanguageModelRegistry, LanguageModelRequest, LanguageModelRequestMessage,
 17    LanguageModelRequestTool, LanguageModelToolResult, MessageContent, Role,
 18};
 19use project::{
 20    Project, WorktreeSettings,
 21    search::{SearchQuery, SearchResult},
 22};
 23use schemars::JsonSchema;
 24use serde::Deserialize;
 25use util::paths::{PathMatcher, PathStyle};
 26use workspace::item::Settings as _;
 27
/// Prompt for the first turn of context retrieval, asking the model to issue `search` tool calls.
///
/// `find_related_excerpts` fills in the `{edits}`, `{current_file_path}` and `{cursor_excerpt}`
/// placeholders with `str::replace` before sending it as the initial user message.
const SEARCH_PROMPT: &str = indoc! {r#"
    ## Task

    You are part of an edit prediction system in a code editor. Your role is to identify relevant code locations
    that will serve as context for predicting the next required edit.

    **Your task:**
    - Analyze the user's recent edits and current cursor context
    - Use the `search` tool to find code that may be relevant for predicting the next edit
    - Focus on finding:
       - Code patterns that might need similar changes based on the recent edits
       - Functions, variables, types, and constants referenced in the current cursor context
       - Related implementations, usages, or dependencies that may require consistent updates

    **Important constraints:**
    - This conversation has exactly 2 turns
    - You must make ALL search queries in your first response via the `search` tool
    - All queries will be executed in parallel and results returned together
    - In the second turn, you will select the most relevant results via the `select` tool.

    ## User Edits

    {edits}

    ## Current cursor context

    `````filename={current_file_path}
    {cursor_excerpt}
    `````

    --
    Use the `search` tool now
"#};
 61
/// Name under which the search tool is offered to the model in the first turn. Completed tool
/// calls carrying any other name are logged and ignored while reading the search response.
const SEARCH_TOOL_NAME: &str = "search";
 63
// Input of the `search` tool, deserialized from the model's tool call in `find_related_excerpts`.
//
// NOTE: the `///` doc comments on this type are not just documentation. `request_tool_call`
// derives the tool's JSON schema from this type and reads the tool description out of it, so the
// doc text below is effectively prompt text sent to the model. Edit it with that in mind.
/// Search for relevant code
///
/// For the best results, run multiple queries at once with a single invocation of this tool.
#[derive(Clone, Deserialize, JsonSchema)]
pub struct SearchToolInput {
    /// An array of queries to run for gathering context relevant to the next prediction
    #[schemars(length(max = 5))]
    pub queries: Box<[SearchToolQuery]>,
}
 73
// A single project search requested by the model; `run_query` turns it into a regex
// `SearchQuery` restricted to the files matching `glob`.
//
// NOTE: this type is part of the `search` tool's JSON schema (it derives `JsonSchema`), so the
// `///` field docs below presumably end up in the schema shown to the model — treat them as
// prompt text rather than plain documentation.
#[derive(Debug, Clone, Deserialize, JsonSchema)]
pub struct SearchToolQuery {
    /// A glob pattern to match file paths in the codebase
    pub glob: String,
    /// A regular expression to match content within the files matched by the glob pattern
    pub regex: String,
    /// Whether the regex is case-sensitive. Defaults to false (case-insensitive).
    #[serde(default)]
    pub case_sensitive: bool,
}
 84
/// Header that starts the tool-result text for a search call; the per-file code blocks with the
/// merged excerpts are appended after it.
const RESULTS_MESSAGE: &str = indoc! {"
    Here are the results of your queries combined and grouped by file:

"};
 89
/// Name under which the selection tool is offered to the model in the second turn. Completed
/// tool calls carrying any other name are logged and ignored while reading the select response.
const SELECT_TOOL_NAME: &str = "select";
 91
/// Prompt for the second turn: appended as the final user message (after the search tool
/// results) before the model is asked to call the `select` tool.
const SELECT_PROMPT: &str = indoc! {"
    Use the `select` tool now to pick the most relevant line ranges according to the user state provided in the first message.
    Make sure to include enough lines of context so that the edit prediction model can suggest accurate edits.
    Include up to 200 lines in total.
"};
 97
// Input of the `select` tool, deserialized from the model's tool call in the second turn.
//
// NOTE: the `///` doc comments on this type are not just documentation. `request_tool_call`
// derives the tool's JSON schema from this type and reads the tool description out of it, so the
// doc text below is effectively prompt text sent to the model.
/// Select line ranges from search results
#[derive(Deserialize, JsonSchema)]
struct SelectToolInput {
    /// The line ranges to select from search results.
    ranges: Vec<SelectLineRange>,
}
104
// One line range picked by the model. `find_related_excerpts` looks `path` up among the files
// that appeared in the search results (keyed by the same full path that was printed in the
// result code blocks) and converts the 1-based lines into buffer anchors.
//
// NOTE: this type is part of the `select` tool's JSON schema (it derives `JsonSchema`), so the
// `///` docs below presumably end up in the schema shown to the model — treat them as prompt
// text rather than plain documentation.
/// A specific line range to select from a file
#[derive(Debug, Deserialize, JsonSchema)]
struct SelectLineRange {
    /// The file path containing the lines to select
    /// Exactly as it appears in the search result codeblocks.
    path: PathBuf,
    /// The starting line number (1-based)
    #[schemars(range(min = 1))]
    start_line: u32,
    /// The ending line number (1-based, inclusive)
    #[schemars(range(min = 1))]
    end_line: u32,
}
118
/// Options for LLM-driven context retrieval (see `find_related_excerpts`).
#[derive(Debug, Clone, PartialEq)]
pub struct LlmContextOptions {
    /// Options used to select the excerpt around the cursor that is embedded in the search
    /// prompt.
    pub excerpt: EditPredictionExcerptOptions,
}
123
124pub fn find_related_excerpts<'a>(
125    buffer: Entity<language::Buffer>,
126    cursor_position: Anchor,
127    project: &Entity<Project>,
128    events: impl Iterator<Item = &'a crate::Event>,
129    options: &LlmContextOptions,
130    debug_tx: Option<mpsc::UnboundedSender<ZetaDebugInfo>>,
131    cx: &App,
132) -> Task<Result<HashMap<Entity<Buffer>, Vec<Range<Anchor>>>>> {
133    let language_model_registry = LanguageModelRegistry::global(cx);
134    let Some(model) = language_model_registry
135        .read(cx)
136        .available_models(cx)
137        .find(|model| {
138            model.provider_id() == language_model::OLLAMA_PROVIDER_ID
139                && model.id() == LanguageModelId("qwen3:8b".into())
140        })
141    else {
142        return Task::ready(Err(anyhow!("could not find claude model")));
143    };
144
145    let mut edits_string = String::new();
146
147    for event in events {
148        if let Some(event) = event.to_request_event(cx) {
149            writeln!(&mut edits_string, "{event}").ok();
150        }
151    }
152
153    if edits_string.is_empty() {
154        edits_string.push_str("(No user edits yet)");
155    }
156
157    // TODO [zeta2] include breadcrumbs?
158    let snapshot = buffer.read(cx).snapshot();
159    let cursor_point = cursor_position.to_point(&snapshot);
160    let Some(cursor_excerpt) =
161        EditPredictionExcerpt::select_from_buffer(cursor_point, &snapshot, &options.excerpt, None)
162    else {
163        return Task::ready(Ok(HashMap::default()));
164    };
165
166    let current_file_path = snapshot
167        .file()
168        .map(|f| f.full_path(cx).display().to_string())
169        .unwrap_or_else(|| "untitled".to_string());
170
171    let prompt = SEARCH_PROMPT
172        .replace("{edits}", &edits_string)
173        .replace("{current_file_path}", ¤t_file_path)
174        .replace("{cursor_excerpt}", &cursor_excerpt.text(&snapshot).body);
175
176    let path_style = project.read(cx).path_style(cx);
177
178    let exclude_matcher = {
179        let global_settings = WorktreeSettings::get_global(cx);
180        let exclude_patterns = global_settings
181            .file_scan_exclusions
182            .sources()
183            .iter()
184            .chain(global_settings.private_files.sources().iter());
185
186        match PathMatcher::new(exclude_patterns, path_style) {
187            Ok(matcher) => matcher,
188            Err(err) => {
189                return Task::ready(Err(anyhow!(err)));
190            }
191        }
192    };
193
194    let project = project.clone();
195    cx.spawn(async move |cx| {
196        let initial_prompt_message = LanguageModelRequestMessage {
197            role: Role::User,
198            content: vec![prompt.into()],
199            cache: false,
200        };
201
202        let mut search_stream = request_tool_call::<SearchToolInput>(
203            vec![initial_prompt_message.clone()],
204            SEARCH_TOOL_NAME,
205            &model,
206            cx,
207        )
208        .await?;
209
210        let mut select_request_messages = Vec::with_capacity(5); // initial prompt, LLM response/thinking, tool use, tool result, select prompt
211        select_request_messages.push(initial_prompt_message);
212        let mut search_calls = Vec::new();
213
214        while let Some(event) = search_stream.next().await {
215            match event? {
216                LanguageModelCompletionEvent::ToolUse(tool_use) => {
217                    if !tool_use.is_input_complete {
218                        continue;
219                    }
220
221                    if tool_use.name.as_ref() == SEARCH_TOOL_NAME {
222                        search_calls.push((select_request_messages.len(), tool_use));
223                    } else {
224                        log::warn!(
225                            "context gathering model tried to use unknown tool: {}",
226                            tool_use.name
227                        );
228                    }
229                }
230                LanguageModelCompletionEvent::Text(txt) => {
231                    if let Some(LanguageModelRequestMessage {
232                        role: Role::Assistant,
233                        content,
234                        ..
235                    }) = select_request_messages.last_mut()
236                    {
237                        if let Some(MessageContent::Text(existing_text)) = content.last_mut() {
238                            existing_text.push_str(&txt);
239                        } else {
240                            content.push(MessageContent::Text(txt));
241                        }
242                    } else {
243                        select_request_messages.push(LanguageModelRequestMessage {
244                            role: Role::Assistant,
245                            content: vec![MessageContent::Text(txt)],
246                            cache: false,
247                        });
248                    }
249                }
250                LanguageModelCompletionEvent::Thinking { text, signature } => {
251                    if let Some(LanguageModelRequestMessage {
252                        role: Role::Assistant,
253                        content,
254                        ..
255                    }) = select_request_messages.last_mut()
256                    {
257                        if let Some(MessageContent::Thinking {
258                            text: existing_text,
259                            signature: existing_signature,
260                        }) = content.last_mut()
261                        {
262                            existing_text.push_str(&text);
263                            *existing_signature = signature;
264                        } else {
265                            content.push(MessageContent::Thinking { text, signature });
266                        }
267                    } else {
268                        select_request_messages.push(LanguageModelRequestMessage {
269                            role: Role::Assistant,
270                            content: vec![MessageContent::Thinking { text, signature }],
271                            cache: false,
272                        });
273                    }
274                }
275                LanguageModelCompletionEvent::RedactedThinking { data } => {
276                    if let Some(LanguageModelRequestMessage {
277                        role: Role::Assistant,
278                        content,
279                        ..
280                    }) = select_request_messages.last_mut()
281                    {
282                        if let Some(MessageContent::RedactedThinking(existing_data)) =
283                            content.last_mut()
284                        {
285                            existing_data.push_str(&data);
286                        } else {
287                            content.push(MessageContent::RedactedThinking(data));
288                        }
289                    } else {
290                        select_request_messages.push(LanguageModelRequestMessage {
291                            role: Role::Assistant,
292                            content: vec![MessageContent::RedactedThinking(data)],
293                            cache: false,
294                        });
295                    }
296                }
297                ev @ LanguageModelCompletionEvent::ToolUseJsonParseError { .. } => {
298                    log::error!("{ev:?}");
299                }
300                ev => {
301                    log::trace!("context search event: {ev:?}")
302                }
303            }
304        }
305
306        struct ResultBuffer {
307            buffer: Entity<Buffer>,
308            snapshot: TextBufferSnapshot,
309        }
310
311        let search_queries = search_calls
312            .iter()
313            .map(|(_, tool_use)| {
314                Ok(serde_json::from_value::<SearchToolInput>(
315                    tool_use.input.clone(),
316                )?)
317            })
318            .collect::<Result<Vec<_>>>()?;
319
320        if let Some(debug_tx) = &debug_tx {
321            debug_tx
322                .unbounded_send(ZetaDebugInfo::SearchQueriesGenerated(
323                    ZetaSearchQueryDebugInfo {
324                        project: project.clone(),
325                        timestamp: Instant::now(),
326                        queries: search_queries
327                            .iter()
328                            .flat_map(|call| call.queries.iter().cloned())
329                            .collect(),
330                    },
331                ))
332                .ok();
333        }
334
335        let mut result_buffers_by_path = HashMap::default();
336
337        for ((index, tool_use), call) in search_calls.into_iter().zip(search_queries).rev() {
338            let mut excerpts_by_buffer = HashMap::default();
339
340            for query in call.queries {
341                // TODO [zeta2] parallelize?
342
343                run_query(
344                    query,
345                    &mut excerpts_by_buffer,
346                    path_style,
347                    exclude_matcher.clone(),
348                    &project,
349                    cx,
350                )
351                .await?;
352            }
353
354            if excerpts_by_buffer.is_empty() {
355                continue;
356            }
357
358            let mut merged_result = RESULTS_MESSAGE.to_string();
359
360            for (buffer_entity, mut excerpts_for_buffer) in excerpts_by_buffer {
361                excerpts_for_buffer.sort_unstable_by_key(|range| (range.start, Reverse(range.end)));
362
363                buffer_entity
364                    .clone()
365                    .read_with(cx, |buffer, cx| {
366                        let Some(file) = buffer.file() else {
367                            return;
368                        };
369
370                        let path = file.full_path(cx);
371
372                        writeln!(&mut merged_result, "`````filename={}", path.display()).unwrap();
373
374                        let snapshot = buffer.snapshot();
375
376                        write_merged_excerpts(
377                            &snapshot,
378                            excerpts_for_buffer,
379                            &[],
380                            &mut merged_result,
381                        );
382
383                        merged_result.push_str("`````\n\n");
384
385                        result_buffers_by_path.insert(
386                            path,
387                            ResultBuffer {
388                                buffer: buffer_entity,
389                                snapshot: snapshot.text,
390                            },
391                        );
392                    })
393                    .ok();
394            }
395
396            let tool_result = LanguageModelToolResult {
397                tool_use_id: tool_use.id.clone(),
398                tool_name: SEARCH_TOOL_NAME.into(),
399                is_error: false,
400                content: merged_result.into(),
401                output: None,
402            };
403
404            // Almost always appends at the end, but in theory, the model could return some text after the tool call
405            // or perform parallel tool calls, so we splice at the message index for correctness.
406            select_request_messages.splice(
407                index..index,
408                [
409                    LanguageModelRequestMessage {
410                        role: Role::Assistant,
411                        content: vec![MessageContent::ToolUse(tool_use)],
412                        cache: false,
413                    },
414                    LanguageModelRequestMessage {
415                        role: Role::User,
416                        content: vec![MessageContent::ToolResult(tool_result)],
417                        cache: false,
418                    },
419                ],
420            );
421
422            if let Some(debug_tx) = &debug_tx {
423                debug_tx
424                    .unbounded_send(ZetaDebugInfo::SearchQueriesExecuted(
425                        ZetaContextRetrievalDebugInfo {
426                            project: project.clone(),
427                            timestamp: Instant::now(),
428                        },
429                    ))
430                    .ok();
431            }
432        }
433
434        if result_buffers_by_path.is_empty() {
435            log::trace!("context gathering queries produced no results");
436            return anyhow::Ok(HashMap::default());
437        }
438
439        select_request_messages.push(LanguageModelRequestMessage {
440            role: Role::User,
441            content: vec![SELECT_PROMPT.into()],
442            cache: false,
443        });
444
445        let mut select_stream = request_tool_call::<SelectToolInput>(
446            select_request_messages,
447            SELECT_TOOL_NAME,
448            &model,
449            cx,
450        )
451        .await?;
452        let mut selected_ranges = Vec::new();
453
454        while let Some(event) = select_stream.next().await {
455            match event? {
456                LanguageModelCompletionEvent::ToolUse(tool_use) => {
457                    if !tool_use.is_input_complete {
458                        continue;
459                    }
460
461                    if tool_use.name.as_ref() == SELECT_TOOL_NAME {
462                        let call =
463                            serde_json::from_value::<SelectToolInput>(tool_use.input.clone())?;
464                        selected_ranges.extend(call.ranges);
465                    } else {
466                        log::warn!(
467                            "context gathering model tried to use unknown tool: {}",
468                            tool_use.name
469                        );
470                    }
471                }
472                ev @ LanguageModelCompletionEvent::ToolUseJsonParseError { .. } => {
473                    log::error!("{ev:?}");
474                }
475                ev => {
476                    log::trace!("context select event: {ev:?}")
477                }
478            }
479        }
480
481        if selected_ranges.is_empty() {
482            log::trace!("context gathering selected no ranges")
483        }
484
485        let mut related_excerpts_by_buffer: HashMap<_, Vec<_>> = HashMap::default();
486
487        for selected_range in selected_ranges {
488            if let Some(ResultBuffer { buffer, snapshot }) =
489                result_buffers_by_path.get(&selected_range.path)
490            {
491                let start_point = Point::new(selected_range.start_line.saturating_sub(1), 0);
492                let end_point =
493                    snapshot.clip_point(Point::new(selected_range.end_line, 0), Bias::Left);
494                let range = snapshot.anchor_after(start_point)..snapshot.anchor_before(end_point);
495
496                related_excerpts_by_buffer
497                    .entry(buffer.clone())
498                    .or_default()
499                    .push(range);
500            } else {
501                log::warn!(
502                    "selected path that wasn't included in search results: {}",
503                    selected_range.path.display()
504                );
505            }
506        }
507
508        for (buffer, ranges) in &mut related_excerpts_by_buffer {
509            buffer.read_with(cx, |buffer, _cx| {
510                ranges.sort_unstable_by(|a, b| {
511                    a.start
512                        .cmp(&b.start, buffer)
513                        .then(b.end.cmp(&a.end, buffer))
514                });
515            })?;
516        }
517
518        anyhow::Ok(related_excerpts_by_buffer)
519    })
520}
521
522async fn request_tool_call<T: JsonSchema>(
523    messages: Vec<LanguageModelRequestMessage>,
524    tool_name: &'static str,
525    model: &Arc<dyn LanguageModel>,
526    cx: &mut AsyncApp,
527) -> Result<BoxStream<'static, Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>>
528{
529    let schema = schemars::schema_for!(T);
530
531    let request = LanguageModelRequest {
532        messages,
533        tools: vec![LanguageModelRequestTool {
534            name: tool_name.into(),
535            description: schema
536                .get("description")
537                .and_then(|description| description.as_str())
538                .unwrap()
539                .to_string(),
540            input_schema: serde_json::to_value(schema).unwrap(),
541        }],
542        thinking_allowed: false,
543        ..Default::default()
544    };
545
546    Ok(model.stream_completion(request, cx).await?)
547}
548
/// Minimum size, in bytes, requested for each excerpt taken around a search match.
const MIN_EXCERPT_LEN: usize = 16;
/// Maximum size, in bytes, requested for each excerpt taken around a search match.
const MAX_EXCERPT_LEN: usize = 768;
/// Byte budget for all excerpts collected by a single `run_query` call (five full-size excerpts).
const MAX_RESULT_BYTES_PER_QUERY: usize = MAX_EXCERPT_LEN * 5;
552
553async fn run_query(
554    args: SearchToolQuery,
555    excerpts_by_buffer: &mut HashMap<Entity<Buffer>, Vec<Range<Line>>>,
556    path_style: PathStyle,
557    exclude_matcher: PathMatcher,
558    project: &Entity<Project>,
559    cx: &mut AsyncApp,
560) -> Result<()> {
561    let include_matcher = PathMatcher::new(vec![args.glob], path_style)?;
562
563    let query = SearchQuery::regex(
564        &args.regex,
565        false,
566        args.case_sensitive,
567        false,
568        true,
569        include_matcher,
570        exclude_matcher,
571        true,
572        None,
573    )?;
574
575    let results = project.update(cx, |project, cx| project.search(query, cx))?;
576    futures::pin_mut!(results);
577
578    let mut total_bytes = 0;
579
580    while let Some(SearchResult::Buffer { buffer, ranges }) = results.next().await {
581        if ranges.is_empty() {
582            continue;
583        }
584
585        let excerpts_for_buffer = excerpts_by_buffer
586            .entry(buffer.clone())
587            .or_insert_with(|| Vec::with_capacity(ranges.len()));
588
589        let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot())?;
590
591        for range in ranges {
592            let offset_range = range.to_offset(&snapshot);
593            let query_point = (offset_range.start + offset_range.len() / 2).to_point(&snapshot);
594
595            if total_bytes + MIN_EXCERPT_LEN >= MAX_RESULT_BYTES_PER_QUERY {
596                break;
597            }
598
599            let excerpt = EditPredictionExcerpt::select_from_buffer(
600                query_point,
601                &snapshot,
602                &EditPredictionExcerptOptions {
603                    max_bytes: MAX_EXCERPT_LEN.min(MAX_RESULT_BYTES_PER_QUERY - total_bytes),
604                    min_bytes: MIN_EXCERPT_LEN,
605                    target_before_cursor_over_total_bytes: 0.5,
606                },
607                None,
608            );
609
610            if let Some(excerpt) = excerpt {
611                total_bytes += excerpt.range.len();
612                if !excerpt.line_range.is_empty() {
613                    excerpts_for_buffer.push(excerpt.line_range);
614                }
615            }
616        }
617
618        if excerpts_for_buffer.is_empty() {
619            excerpts_by_buffer.remove(&buffer);
620        }
621    }
622
623    anyhow::Ok(())
624}