use std::{
    cmp::Reverse, collections::hash_map::Entry, ops::Range, path::PathBuf, sync::Arc, time::Instant,
};

use crate::{
    ZetaContextRetrievalDebugInfo, ZetaContextRetrievalStartedDebugInfo, ZetaDebugInfo,
    ZetaSearchQueryDebugInfo, merge_excerpts::merge_excerpts,
};
use anyhow::{Result, anyhow};
use cloud_zeta2_prompt::write_codeblock;
use collections::HashMap;
use edit_prediction_context::{EditPredictionExcerpt, EditPredictionExcerptOptions, Line};
use futures::{
    StreamExt,
    channel::mpsc::{self, UnboundedSender},
    stream::BoxStream,
};
use gpui::{App, AppContext, AsyncApp, Entity, Task};
use indoc::indoc;
use language::{
    Anchor, Bias, Buffer, BufferSnapshot, OffsetRangeExt, Point, TextBufferSnapshot, ToPoint as _,
};
use language_model::{
    LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent, LanguageModelId,
    LanguageModelProviderId, LanguageModelRegistry, LanguageModelRequest,
    LanguageModelRequestMessage, LanguageModelRequestTool, LanguageModelToolResult,
    LanguageModelToolUse, MessageContent, Role,
};
use project::{
    Project, WorktreeSettings,
    search::{SearchQuery, SearchResult},
};
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
use util::{
    ResultExt as _,
    paths::{PathMatcher, PathStyle},
};
use workspace::item::Settings as _;

const SEARCH_PROMPT: &str = indoc! {r#"
    ## Task

    You are part of an edit prediction system in a code editor. Your role is to identify relevant code locations
    that will serve as context for predicting the next required edit.

    **Your task:**
    - Analyze the user's recent edits and current cursor context
    - Use the `search` tool to find code that may be relevant for predicting the next edit
    - Focus on finding:
      - Code patterns that might need similar changes based on the recent edits
      - Functions, variables, types, and constants referenced in the current cursor context
      - Related implementations, usages, or dependencies that may require consistent updates

    **Important constraints:**
    - This conversation has exactly 2 turns
    - You must make ALL search queries in your first response via the `search` tool
    - All queries will be executed in parallel and results returned together
    - In the second turn, you will select the most relevant results via the `select` tool

    ## User Edits

    {edits}

    ## Current cursor context

    `````{current_file_path}
    {cursor_excerpt}
    `````

    --
    Use the `search` tool now
"#};

const SEARCH_TOOL_NAME: &str = "search";

/// Search for relevant code
///
/// For the best results, run multiple queries at once with a single invocation of this tool.
#[derive(Clone, Deserialize, Serialize, JsonSchema)]
pub struct SearchToolInput {
    /// An array of queries to run for gathering context relevant to the next prediction
    #[schemars(length(max = 5))]
    pub queries: Box<[SearchToolQuery]>,
}

#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
pub struct SearchToolQuery {
    /// A glob pattern to match file paths in the codebase
    pub glob: String,
    /// A regular expression to match content within the files matched by the glob pattern
    pub regex: String,
}
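
// An illustrative `search` call the model might produce (hypothetical values):
//
// {
//   "queries": [
//     { "glob": "**/*.rs", "regex": "fn\\s+merge_excerpts" },
//     { "glob": "src/editor/**", "regex": "EditPredictionExcerpt" }
//   ]
// }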

const RESULTS_MESSAGE: &str = indoc! {"
    Here are the results of your queries combined and grouped by file:

"};

const SELECT_TOOL_NAME: &str = "select";

const SELECT_PROMPT: &str = indoc! {"
    Use the `select` tool now to pick the most relevant line ranges according to the user state provided in the first message.
    Make sure to include enough lines of context so that the edit prediction model can suggest accurate edits.
    Include up to 200 lines in total.
"};

/// Select line ranges from search results
#[derive(Deserialize, JsonSchema)]
struct SelectToolInput {
    /// The line ranges to select from search results.
    ranges: Vec<SelectLineRange>,
}

/// A specific line range to select from a file
#[derive(Debug, Deserialize, JsonSchema)]
struct SelectLineRange {
    /// The file path containing the lines to select.
    /// Exactly as it appears in the search result codeblocks.
    path: PathBuf,
    /// The starting line number (1-based)
    #[schemars(range(min = 1))]
    start_line: u32,
    /// The ending line number (1-based, inclusive)
    #[schemars(range(min = 1))]
    end_line: u32,
}
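
// An illustrative `select` call (hypothetical values):
//
// {
//   "ranges": [
//     { "path": "src/editor/excerpt.rs", "start_line": 40, "end_line": 72 }
//   ]
// }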

#[derive(Debug, Clone, PartialEq)]
pub struct LlmContextOptions {
    pub excerpt: EditPredictionExcerptOptions,
}

pub const MODEL_PROVIDER_ID: LanguageModelProviderId = language_model::ANTHROPIC_PROVIDER_ID;

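/// Gathers context for edit prediction in two LLM turns: the model first issues
/// `search` queries, then selects line ranges from the combined results.
///
/// A minimal sketch of a call site (hypothetical bindings; assumes a buffer,
/// cursor anchor, project, and excerpt options are already in hand):
///
/// ```ignore
/// let task = find_related_excerpts(
///     buffer.clone(),
///     cursor_anchor,
///     &project,
///     recent_edits_unified_diff,
///     &LlmContextOptions { excerpt: excerpt_options },
///     None, // no debug channel
///     cx,
/// );
/// ```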
pub fn find_related_excerpts(
    buffer: Entity<language::Buffer>,
    cursor_position: Anchor,
    project: &Entity<Project>,
    mut edit_history_unified_diff: String,
    options: &LlmContextOptions,
    debug_tx: Option<mpsc::UnboundedSender<ZetaDebugInfo>>,
    cx: &App,
) -> Task<Result<HashMap<Entity<Buffer>, Vec<Range<Anchor>>>>> {
    let language_model_registry = LanguageModelRegistry::global(cx);
    let Some(model) = language_model_registry
        .read(cx)
        .available_models(cx)
        .find(|model| {
            model.provider_id() == MODEL_PROVIDER_ID
                && model.id() == LanguageModelId("claude-haiku-4-5-latest".into())
            // model.provider_id() == LanguageModelProviderId::new("zeta-ctx-qwen-30b")
            // model.provider_id() == LanguageModelProviderId::new("ollama")
            //     && model.id() == LanguageModelId("gpt-oss:20b".into())
        })
    else {
        return Task::ready(Err(anyhow!("could not find context model")));
    };

    if edit_history_unified_diff.is_empty() {
        edit_history_unified_diff.push_str("(No user edits yet)");
    }

    // TODO [zeta2] include breadcrumbs?
    let snapshot = buffer.read(cx).snapshot();
    let cursor_point = cursor_position.to_point(&snapshot);
    let Some(cursor_excerpt) =
        EditPredictionExcerpt::select_from_buffer(cursor_point, &snapshot, &options.excerpt, None)
    else {
        return Task::ready(Ok(HashMap::default()));
    };

    let current_file_path = snapshot
        .file()
        .map(|f| f.full_path(cx).display().to_string())
        .unwrap_or_else(|| "untitled".to_string());

    let prompt = SEARCH_PROMPT
        .replace("{edits}", &edit_history_unified_diff)
        .replace("{current_file_path}", &current_file_path)
        .replace("{cursor_excerpt}", &cursor_excerpt.text(&snapshot).body);

    if let Some(debug_tx) = &debug_tx {
        debug_tx
            .unbounded_send(ZetaDebugInfo::ContextRetrievalStarted(
                ZetaContextRetrievalStartedDebugInfo {
                    project: project.clone(),
                    timestamp: Instant::now(),
                    search_prompt: prompt.clone(),
                },
            ))
            .ok();
    }

    let path_style = project.read(cx).path_style(cx);

    let exclude_matcher = {
        let global_settings = WorktreeSettings::get_global(cx);
        let exclude_patterns = global_settings
            .file_scan_exclusions
            .sources()
            .iter()
            .chain(global_settings.private_files.sources().iter());

        match PathMatcher::new(exclude_patterns, path_style) {
            Ok(matcher) => matcher,
            Err(err) => {
                return Task::ready(Err(anyhow!(err)));
            }
        }
    };

    let project = project.clone();
    cx.spawn(async move |cx| {
        let initial_prompt_message = LanguageModelRequestMessage {
            role: Role::User,
            content: vec![prompt.into()],
            cache: false,
        };

        let mut search_stream = request_tool_call::<SearchToolInput>(
            vec![initial_prompt_message.clone()],
            SEARCH_TOOL_NAME,
            &model,
            cx,
        )
        .await?;

        let mut select_request_messages = Vec::with_capacity(5); // initial prompt, LLM response/thinking, tool use, tool result, select prompt
        select_request_messages.push(initial_prompt_message);

        let mut regex_by_glob: HashMap<String, String> = HashMap::default();
        let mut search_calls = Vec::new();

        while let Some(event) = search_stream.next().await {
            match event? {
                LanguageModelCompletionEvent::ToolUse(tool_use) => {
                    if !tool_use.is_input_complete {
                        continue;
                    }

                    if tool_use.name.as_ref() == SEARCH_TOOL_NAME {
                        let input =
                            serde_json::from_value::<SearchToolInput>(tool_use.input.clone())?;

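                        // Queries that share a glob are merged into one
                        // alternation, e.g. "foo" and "bar" for the same glob
                        // become the single regex "foo|bar", so each glob is
                        // searched only once.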
                        for query in input.queries {
                            let regex = regex_by_glob.entry(query.glob).or_default();
                            if !regex.is_empty() {
                                regex.push('|');
                            }
                            regex.push_str(&query.regex);
                        }

                        search_calls.push(tool_use);
                    } else {
                        log::warn!(
                            "context gathering model tried to use unknown tool: {}",
                            tool_use.name
                        );
                    }
                }
                LanguageModelCompletionEvent::Text(txt) => {
                    if let Some(LanguageModelRequestMessage {
                        role: Role::Assistant,
                        content,
                        ..
                    }) = select_request_messages.last_mut()
                    {
                        if let Some(MessageContent::Text(existing_text)) = content.last_mut() {
                            existing_text.push_str(&txt);
                        } else {
                            content.push(MessageContent::Text(txt));
                        }
                    } else {
                        select_request_messages.push(LanguageModelRequestMessage {
                            role: Role::Assistant,
                            content: vec![MessageContent::Text(txt)],
                            cache: false,
                        });
                    }
                }
                LanguageModelCompletionEvent::Thinking { text, signature } => {
                    if let Some(LanguageModelRequestMessage {
                        role: Role::Assistant,
                        content,
                        ..
                    }) = select_request_messages.last_mut()
                    {
                        if let Some(MessageContent::Thinking {
                            text: existing_text,
                            signature: existing_signature,
                        }) = content.last_mut()
                        {
                            existing_text.push_str(&text);
                            *existing_signature = signature;
                        } else {
                            content.push(MessageContent::Thinking { text, signature });
                        }
                    } else {
                        select_request_messages.push(LanguageModelRequestMessage {
                            role: Role::Assistant,
                            content: vec![MessageContent::Thinking { text, signature }],
                            cache: false,
                        });
                    }
                }
                LanguageModelCompletionEvent::RedactedThinking { data } => {
                    if let Some(LanguageModelRequestMessage {
                        role: Role::Assistant,
                        content,
                        ..
                    }) = select_request_messages.last_mut()
                    {
                        if let Some(MessageContent::RedactedThinking(existing_data)) =
                            content.last_mut()
                        {
                            existing_data.push_str(&data);
                        } else {
                            content.push(MessageContent::RedactedThinking(data));
                        }
                    } else {
                        select_request_messages.push(LanguageModelRequestMessage {
                            role: Role::Assistant,
                            content: vec![MessageContent::RedactedThinking(data)],
                            cache: false,
                        });
                    }
                }
                ev @ LanguageModelCompletionEvent::ToolUseJsonParseError { .. } => {
                    log::error!("{ev:?}");
                }
                ev => {
                    log::trace!("context search event: {ev:?}")
                }
            }
        }

        let search_tool_use = if search_calls.is_empty() {
            log::warn!("context model ran 0 searches");
            return anyhow::Ok(Default::default());
        } else if search_calls.len() == 1 {
            search_calls.swap_remove(0)
        } else {
            // In theory, the model could perform multiple search calls.
            // Dealing with them separately is not worth it when it doesn't happen in practice,
            // so if it does happen, we combine them into one here.
            // The second request doesn't need to know it was actually two different calls ;)
            let input = serde_json::to_value(&SearchToolInput {
                queries: regex_by_glob
                    .iter()
                    .map(|(glob, regex)| SearchToolQuery {
                        glob: glob.clone(),
                        regex: regex.clone(),
                    })
                    .collect(),
            })
            .unwrap_or_default();

            LanguageModelToolUse {
                id: search_calls.swap_remove(0).id,
                name: SEARCH_TOOL_NAME.into(),
                raw_input: serde_json::to_string(&input).unwrap_or_default(),
                input,
                is_input_complete: true,
            }
        };

        if let Some(debug_tx) = &debug_tx {
            debug_tx
                .unbounded_send(ZetaDebugInfo::SearchQueriesGenerated(
                    ZetaSearchQueryDebugInfo {
                        project: project.clone(),
                        timestamp: Instant::now(),
                        queries: regex_by_glob
                            .iter()
                            .map(|(glob, regex)| SearchToolQuery {
                                glob: glob.clone(),
                                regex: regex.clone(),
                            })
                            .collect(),
                    },
                ))
                .ok();
        }

        let (results_tx, mut results_rx) = mpsc::unbounded();

        for (glob, regex) in regex_by_glob {
            let exclude_matcher = exclude_matcher.clone();
            let results_tx = results_tx.clone();
            let project = project.clone();
            cx.spawn(async move |cx| {
                run_query(
                    &glob,
                    &regex,
                    results_tx.clone(),
                    path_style,
                    exclude_matcher,
                    &project,
                    cx,
                )
                .await
                .log_err();
            })
            .detach()
        }
        drop(results_tx);

        struct ResultBuffer {
            buffer: Entity<Buffer>,
            snapshot: TextBufferSnapshot,
        }

        let (result_buffers_by_path, merged_result) = cx
            .background_spawn(async move {
                let mut excerpts_by_buffer: HashMap<Entity<Buffer>, MatchedBuffer> =
                    HashMap::default();

                while let Some((buffer, matched)) = results_rx.next().await {
                    match excerpts_by_buffer.entry(buffer) {
                        Entry::Occupied(mut entry) => {
                            let entry = entry.get_mut();
                            entry.full_path = matched.full_path;
                            entry.snapshot = matched.snapshot;
                            entry.line_ranges.extend(matched.line_ranges);
                        }
                        Entry::Vacant(entry) => {
                            entry.insert(matched);
                        }
                    }
                }

                let mut result_buffers_by_path = HashMap::default();
                let mut merged_result = RESULTS_MESSAGE.to_string();

                for (buffer, mut matched) in excerpts_by_buffer {
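                    // Sort by start ascending, then end descending, so that at
                    // equal starts the widest range comes first; this is the
                    // order merge_excerpts presumably expects for single-pass
                    // merging of overlapping ranges.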
                    matched
                        .line_ranges
                        .sort_unstable_by_key(|range| (range.start, Reverse(range.end)));

                    write_codeblock(
                        &matched.full_path,
                        merge_excerpts(&matched.snapshot, matched.line_ranges).iter(),
                        &[],
                        Line(matched.snapshot.max_point().row),
                        true,
                        &mut merged_result,
                    );

                    result_buffers_by_path.insert(
                        matched.full_path,
                        ResultBuffer {
                            buffer,
                            snapshot: matched.snapshot.text,
                        },
                    );
                }

                (result_buffers_by_path, merged_result)
            })
            .await;

        if let Some(debug_tx) = &debug_tx {
            debug_tx
                .unbounded_send(ZetaDebugInfo::SearchQueriesExecuted(
                    ZetaContextRetrievalDebugInfo {
                        project: project.clone(),
                        timestamp: Instant::now(),
                    },
                ))
                .ok();
        }

        let tool_result = LanguageModelToolResult {
            tool_use_id: search_tool_use.id.clone(),
            tool_name: SEARCH_TOOL_NAME.into(),
            is_error: false,
            content: merged_result.into(),
            output: None,
        };

        select_request_messages.extend([
            LanguageModelRequestMessage {
                role: Role::Assistant,
                content: vec![MessageContent::ToolUse(search_tool_use)],
                cache: false,
            },
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec![MessageContent::ToolResult(tool_result)],
                cache: false,
            },
        ]);

        if result_buffers_by_path.is_empty() {
            log::trace!("context gathering queries produced no results");
            return anyhow::Ok(HashMap::default());
        }

        select_request_messages.push(LanguageModelRequestMessage {
            role: Role::User,
            content: vec![SELECT_PROMPT.into()],
            cache: false,
        });

        let mut select_stream = request_tool_call::<SelectToolInput>(
            select_request_messages,
            SELECT_TOOL_NAME,
            &model,
            cx,
        )
        .await?;

        cx.background_spawn(async move {
            let mut selected_ranges = Vec::new();

            while let Some(event) = select_stream.next().await {
                match event? {
                    LanguageModelCompletionEvent::ToolUse(tool_use) => {
                        if !tool_use.is_input_complete {
                            continue;
                        }

                        if tool_use.name.as_ref() == SELECT_TOOL_NAME {
                            let call =
                                serde_json::from_value::<SelectToolInput>(tool_use.input.clone())?;
                            selected_ranges.extend(call.ranges);
                        } else {
                            log::warn!(
                                "context gathering model tried to use unknown tool: {}",
                                tool_use.name
                            );
                        }
                    }
                    ev @ LanguageModelCompletionEvent::ToolUseJsonParseError { .. } => {
                        log::error!("{ev:?}");
                    }
                    ev => {
                        log::trace!("context select event: {ev:?}")
                    }
                }
            }

            if let Some(debug_tx) = &debug_tx {
                debug_tx
                    .unbounded_send(ZetaDebugInfo::SearchResultsFiltered(
                        ZetaContextRetrievalDebugInfo {
                            project: project.clone(),
                            timestamp: Instant::now(),
                        },
                    ))
                    .ok();
            }

            if selected_ranges.is_empty() {
                log::trace!("context gathering selected no ranges")
            }

            selected_ranges.sort_unstable_by(|a, b| {
                a.start_line
                    .cmp(&b.start_line)
                    .then(b.end_line.cmp(&a.end_line))
            });

            let mut related_excerpts_by_buffer: HashMap<_, Vec<_>> = HashMap::default();

            for selected_range in selected_ranges {
                if let Some(ResultBuffer { buffer, snapshot }) =
                    result_buffers_by_path.get(&selected_range.path)
                {
                    let start_point = Point::new(selected_range.start_line.saturating_sub(1), 0);
                    let end_point =
                        snapshot.clip_point(Point::new(selected_range.end_line, 0), Bias::Left);
                    let range =
                        snapshot.anchor_after(start_point)..snapshot.anchor_before(end_point);

                    related_excerpts_by_buffer
                        .entry(buffer.clone())
                        .or_default()
                        .push(range);
                } else {
                    log::warn!(
                        "selected path that wasn't included in search results: {}",
                        selected_range.path.display()
                    );
                }
            }

            anyhow::Ok(related_excerpts_by_buffer)
        })
        .await
    })
}

async fn request_tool_call<T: JsonSchema>(
    messages: Vec<LanguageModelRequestMessage>,
    tool_name: &'static str,
    model: &Arc<dyn LanguageModel>,
    cx: &mut AsyncApp,
) -> Result<BoxStream<'static, Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>>
{
    let schema = schemars::schema_for!(T);
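    // The tool description comes from the doc comment on `T`, which schemars
    // stores in the schema's top-level "description" field; the unwrap below
    // relies on every tool input type having one.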

    let request = LanguageModelRequest {
        messages,
        tools: vec![LanguageModelRequestTool {
            name: tool_name.into(),
            description: schema
                .get("description")
                .and_then(|description| description.as_str())
                .unwrap()
                .to_string(),
            input_schema: serde_json::to_value(schema).unwrap(),
        }],
        ..Default::default()
    };

    Ok(model.stream_completion(request, cx).await?)
}

const MIN_EXCERPT_LEN: usize = 16;
const MAX_EXCERPT_LEN: usize = 768;
const MAX_RESULT_BYTES_PER_QUERY: usize = MAX_EXCERPT_LEN * 5;
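// i.e. each query's combined results are capped at five max-size excerpts
// (3840 bytes).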

struct MatchedBuffer {
    snapshot: BufferSnapshot,
    line_ranges: Vec<Range<Line>>,
    full_path: PathBuf,
}

async fn run_query(
    glob: &str,
    regex: &str,
    results_tx: UnboundedSender<(Entity<Buffer>, MatchedBuffer)>,
    path_style: PathStyle,
    exclude_matcher: PathMatcher,
    project: &Entity<Project>,
    cx: &mut AsyncApp,
) -> Result<()> {
    let include_matcher = PathMatcher::new(vec![glob], path_style)?;

    let query = SearchQuery::regex(
        regex,
        false,
        true,
        false,
        true,
        include_matcher,
        exclude_matcher,
        true,
        None,
    )?;

    let results = project.update(cx, |project, cx| project.search(query, cx))?;
    futures::pin_mut!(results);

    let mut total_bytes = 0;

    while let Some(SearchResult::Buffer { buffer, ranges }) = results.next().await {
        if ranges.is_empty() {
            continue;
        }

        let Some((snapshot, full_path)) = buffer.read_with(cx, |buffer, cx| {
            Some((buffer.snapshot(), buffer.file()?.full_path(cx)))
        })?
        else {
            continue;
        };

        let results_tx = results_tx.clone();
        cx.background_spawn(async move {
            let mut line_ranges = Vec::with_capacity(ranges.len());

            for range in ranges {
                let offset_range = range.to_offset(&snapshot);
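                // Anchor the excerpt on the midpoint of the match so the
                // excerpt extends roughly evenly above and below it.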
                let query_point = (offset_range.start + offset_range.len() / 2).to_point(&snapshot);

                if total_bytes + MIN_EXCERPT_LEN >= MAX_RESULT_BYTES_PER_QUERY {
                    break;
                }

                let excerpt = EditPredictionExcerpt::select_from_buffer(
                    query_point,
                    &snapshot,
                    &EditPredictionExcerptOptions {
                        max_bytes: MAX_EXCERPT_LEN.min(MAX_RESULT_BYTES_PER_QUERY - total_bytes),
                        min_bytes: MIN_EXCERPT_LEN,
                        target_before_cursor_over_total_bytes: 0.5,
                    },
                    None,
                );

                if let Some(excerpt) = excerpt {
                    total_bytes += excerpt.range.len();
                    if !excerpt.line_range.is_empty() {
                        line_ranges.push(excerpt.line_range);
                    }
                }
            }

            results_tx
                .unbounded_send((
                    buffer,
                    MatchedBuffer {
                        snapshot,
                        line_ranges,
                        full_path,
                    },
                ))
                .log_err();
        })
        .detach();
    }

    anyhow::Ok(())
}