use std::{
    cmp::Reverse, collections::hash_map::Entry, ops::Range, path::PathBuf, sync::Arc, time::Instant,
};

use crate::{
    ZetaContextRetrievalDebugInfo, ZetaContextRetrievalStartedDebugInfo, ZetaDebugInfo,
    ZetaSearchQueryDebugInfo, merge_excerpts::merge_excerpts,
};
use anyhow::{Result, anyhow};
use cloud_zeta2_prompt::write_codeblock;
use collections::HashMap;
use edit_prediction_context::{EditPredictionExcerpt, EditPredictionExcerptOptions, Line};
use futures::{
    StreamExt,
    channel::mpsc::{self, UnboundedSender},
    stream::BoxStream,
};
use gpui::{App, AppContext, AsyncApp, Entity, Task};
use indoc::indoc;
use language::{
    Anchor, Bias, Buffer, BufferSnapshot, OffsetRangeExt, Point, TextBufferSnapshot, ToPoint as _,
};
use language_model::{
    LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent, LanguageModelId,
    LanguageModelProviderId, LanguageModelRegistry, LanguageModelRequest,
    LanguageModelRequestMessage, LanguageModelRequestTool, LanguageModelToolResult,
    LanguageModelToolUse, MessageContent, Role,
};
use project::{
    Project, WorktreeSettings,
    search::{SearchQuery, SearchResult},
};
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
use util::{
    ResultExt as _,
    paths::{PathMatcher, PathStyle},
};
use workspace::item::Settings as _;
const SEARCH_PROMPT: &str = indoc! {r#"
    ## Task

    You are part of an edit prediction system in a code editor. Your role is to identify relevant code locations
    that will serve as context for predicting the next required edit.

    **Your task:**
    - Analyze the user's recent edits and current cursor context
    - Use the `search` tool to find code that may be relevant for predicting the next edit
    - Focus on finding:
      - Code patterns that might need similar changes based on the recent edits
      - Functions, variables, types, and constants referenced in the current cursor context
      - Related implementations, usages, or dependencies that may require consistent updates

    **Important constraints:**
    - This conversation has exactly 2 turns
    - You must make ALL search queries in your first response via the `search` tool
    - All queries will be executed in parallel and results returned together
    - In the second turn, you will select the most relevant results via the `select` tool

    ## User Edits

    {edits}

    ## Current cursor context

    `````{current_file_path}
    {cursor_excerpt}
    `````

    --
    Use the `search` tool now
"#};

const SEARCH_TOOL_NAME: &str = "search";

/// Search for relevant code
///
/// For the best results, run multiple queries at once with a single invocation of this tool.
#[derive(Clone, Deserialize, Serialize, JsonSchema)]
pub struct SearchToolInput {
    /// An array of queries to run for gathering context relevant to the next prediction
    #[schemars(length(max = 5))]
    pub queries: Box<[SearchToolQuery]>,
}

#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
pub struct SearchToolQuery {
    /// A glob pattern to match file paths in the codebase
    pub glob: String,
    /// A regular expression to match content within the files matched by the glob pattern
    pub regex: String,
}
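// A hypothetical `search` call the model might emit (illustrative only):
// { "queries": [{ "glob": "**/*.rs", "regex": "fn\\s+merge_excerpts" }] }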

const RESULTS_MESSAGE: &str = indoc! {"
    Here are the results of your queries combined and grouped by file:

"};

const SELECT_TOOL_NAME: &str = "select";

const SELECT_PROMPT: &str = indoc! {"
    Use the `select` tool now to pick the most relevant line ranges according to the user state provided in the first message.
    Make sure to include enough lines of context so that the edit prediction model can suggest accurate edits.
    Include up to 200 lines in total.
"};

/// Select line ranges from search results
#[derive(Deserialize, JsonSchema)]
struct SelectToolInput {
    /// The line ranges to select from search results.
    ranges: Vec<SelectLineRange>,
}

/// A specific line range to select from a file
#[derive(Debug, Deserialize, JsonSchema)]
struct SelectLineRange {
    /// The file path containing the lines to select,
    /// exactly as it appears in the search result codeblocks.
    path: PathBuf,
    /// The starting line number (1-based)
    #[schemars(range(min = 1))]
    start_line: u32,
    /// The ending line number (1-based, inclusive)
    #[schemars(range(min = 1))]
    end_line: u32,
}

#[derive(Debug, Clone, PartialEq)]
pub struct LlmContextOptions {
    pub excerpt: EditPredictionExcerptOptions,
}

pub const MODEL_PROVIDER_ID: LanguageModelProviderId = language_model::ANTHROPIC_PROVIDER_ID;

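/// Runs a two-turn conversation with the context model: in the first turn the
/// model issues `search` queries, which we execute against the project; in the
/// second turn it `select`s line ranges from the combined results. Returns the
/// selected ranges grouped by buffer.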
pub fn find_related_excerpts(
    buffer: Entity<language::Buffer>,
    cursor_position: Anchor,
    project: &Entity<Project>,
    mut edit_history_unified_diff: String,
    options: &LlmContextOptions,
    debug_tx: Option<mpsc::UnboundedSender<ZetaDebugInfo>>,
    cx: &App,
) -> Task<Result<HashMap<Entity<Buffer>, Vec<Range<Anchor>>>>> {
    let language_model_registry = LanguageModelRegistry::global(cx);
    let Some(model) = language_model_registry
        .read(cx)
        .available_models(cx)
        .find(|model| {
            model.provider_id() == MODEL_PROVIDER_ID
                && model.id() == LanguageModelId("claude-haiku-4-5-latest".into())
            // model.provider_id() == LanguageModelProviderId::new("zeta-ctx-qwen-30b")
            // model.provider_id() == LanguageModelProviderId::new("ollama")
            //     && model.id() == LanguageModelId("gpt-oss:20b".into())
        })
    else {
        return Task::ready(Err(anyhow!("could not find context model")));
    };

    if edit_history_unified_diff.is_empty() {
        edit_history_unified_diff.push_str("(No user edits yet)");
    }

    // TODO [zeta2] include breadcrumbs?
    let snapshot = buffer.read(cx).snapshot();
    let cursor_point = cursor_position.to_point(&snapshot);
    let Some(cursor_excerpt) =
        EditPredictionExcerpt::select_from_buffer(cursor_point, &snapshot, &options.excerpt, None)
    else {
        return Task::ready(Ok(HashMap::default()));
    };

    let current_file_path = snapshot
        .file()
        .map(|f| f.full_path(cx).display().to_string())
        .unwrap_or_else(|| "untitled".to_string());

    let prompt = SEARCH_PROMPT
        .replace("{edits}", &edit_history_unified_diff)
        .replace("{current_file_path}", &current_file_path)
        .replace("{cursor_excerpt}", &cursor_excerpt.text(&snapshot).body);

    if let Some(debug_tx) = &debug_tx {
        debug_tx
            .unbounded_send(ZetaDebugInfo::ContextRetrievalStarted(
                ZetaContextRetrievalStartedDebugInfo {
                    project: project.clone(),
                    timestamp: Instant::now(),
                    search_prompt: prompt.clone(),
                },
            ))
            .ok();
    }

    let path_style = project.read(cx).path_style(cx);

    let exclude_matcher = {
        let global_settings = WorktreeSettings::get_global(cx);
        let exclude_patterns = global_settings
            .file_scan_exclusions
            .sources()
            .iter()
            .chain(global_settings.private_files.sources().iter());

        match PathMatcher::new(exclude_patterns, path_style) {
            Ok(matcher) => matcher,
            Err(err) => {
                return Task::ready(Err(anyhow!(err)));
            }
        }
    };

    let project = project.clone();
    cx.spawn(async move |cx| {
        let initial_prompt_message = LanguageModelRequestMessage {
            role: Role::User,
            content: vec![prompt.into()],
            cache: false,
        };

        let mut search_stream = request_tool_call::<SearchToolInput>(
            vec![initial_prompt_message.clone()],
            SEARCH_TOOL_NAME,
            &model,
            cx,
        )
        .await?;

        let mut select_request_messages = Vec::with_capacity(5); // initial prompt, LLM response/thinking, tool use, tool result, select prompt
        select_request_messages.push(initial_prompt_message);

        let mut regex_by_glob: HashMap<String, String> = HashMap::default();
        let mut search_calls = Vec::new();

        while let Some(event) = search_stream.next().await {
            match event? {
                LanguageModelCompletionEvent::ToolUse(tool_use) => {
                    if !tool_use.is_input_complete {
                        continue;
                    }

                    if tool_use.name.as_ref() == SEARCH_TOOL_NAME {
                        let input =
                            serde_json::from_value::<SearchToolInput>(tool_use.input.clone())?;

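                        // Merge every regex that targets the same glob into a
                        // single alternation, so each glob is searched only once.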
                        for query in input.queries {
                            let regex = regex_by_glob.entry(query.glob).or_default();
                            if !regex.is_empty() {
                                regex.push('|');
                            }
                            regex.push_str(&query.regex);
                        }

                        search_calls.push(tool_use);
                    } else {
                        log::warn!(
                            "context gathering model tried to use unknown tool: {}",
                            tool_use.name
                        );
                    }
                }
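                // Accumulate the assistant's streamed text and thinking into
                // the message list so the second (select) turn can see them.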
                LanguageModelCompletionEvent::Text(txt) => {
                    if let Some(LanguageModelRequestMessage {
                        role: Role::Assistant,
                        content,
                        ..
                    }) = select_request_messages.last_mut()
                    {
                        if let Some(MessageContent::Text(existing_text)) = content.last_mut() {
                            existing_text.push_str(&txt);
                        } else {
                            content.push(MessageContent::Text(txt));
                        }
                    } else {
                        select_request_messages.push(LanguageModelRequestMessage {
                            role: Role::Assistant,
                            content: vec![MessageContent::Text(txt)],
                            cache: false,
                        });
                    }
                }
                LanguageModelCompletionEvent::Thinking { text, signature } => {
                    if let Some(LanguageModelRequestMessage {
                        role: Role::Assistant,
                        content,
                        ..
                    }) = select_request_messages.last_mut()
                    {
                        if let Some(MessageContent::Thinking {
                            text: existing_text,
                            signature: existing_signature,
                        }) = content.last_mut()
                        {
                            existing_text.push_str(&text);
                            *existing_signature = signature;
                        } else {
                            content.push(MessageContent::Thinking { text, signature });
                        }
                    } else {
                        select_request_messages.push(LanguageModelRequestMessage {
                            role: Role::Assistant,
                            content: vec![MessageContent::Thinking { text, signature }],
                            cache: false,
                        });
                    }
                }
                LanguageModelCompletionEvent::RedactedThinking { data } => {
                    if let Some(LanguageModelRequestMessage {
                        role: Role::Assistant,
                        content,
                        ..
                    }) = select_request_messages.last_mut()
                    {
                        if let Some(MessageContent::RedactedThinking(existing_data)) =
                            content.last_mut()
                        {
                            existing_data.push_str(&data);
                        } else {
                            content.push(MessageContent::RedactedThinking(data));
                        }
                    } else {
                        select_request_messages.push(LanguageModelRequestMessage {
                            role: Role::Assistant,
                            content: vec![MessageContent::RedactedThinking(data)],
                            cache: false,
                        });
                    }
                }
                ev @ LanguageModelCompletionEvent::ToolUseJsonParseError { .. } => {
                    log::error!("{ev:?}");
                }
                ev => {
                    log::trace!("context search event: {ev:?}");
                }
            }
        }

        let search_tool_use = if search_calls.is_empty() {
            log::warn!("context model ran 0 searches");
            return anyhow::Ok(Default::default());
        } else if search_calls.len() == 1 {
            search_calls.swap_remove(0)
        } else {
            // In theory, the model could perform multiple search calls.
            // Handling them separately is not worth it when it doesn't happen in practice,
            // so if it does happen, we combine them into one here.
            // The second request doesn't need to know it was actually two different calls ;)
            let input = serde_json::to_value(&SearchToolInput {
                queries: regex_by_glob
                    .iter()
                    .map(|(glob, regex)| SearchToolQuery {
                        glob: glob.clone(),
                        regex: regex.clone(),
                    })
                    .collect(),
            })
            .unwrap_or_default();
            LanguageModelToolUse {
                id: search_calls.swap_remove(0).id,
                name: SEARCH_TOOL_NAME.into(),
                raw_input: serde_json::to_string(&input).unwrap_or_default(),
                input,
                is_input_complete: true,
                thought_signature: None,
            }
        };

        if let Some(debug_tx) = &debug_tx {
            debug_tx
                .unbounded_send(ZetaDebugInfo::SearchQueriesGenerated(
                    ZetaSearchQueryDebugInfo {
                        project: project.clone(),
                        timestamp: Instant::now(),
                        queries: regex_by_glob
                            .iter()
                            .map(|(glob, regex)| SearchToolQuery {
                                glob: glob.clone(),
                                regex: regex.clone(),
                            })
                            .collect(),
                    },
                ))
                .ok();
        }

        let (results_tx, mut results_rx) = mpsc::unbounded();

        for (glob, regex) in regex_by_glob {
            let exclude_matcher = exclude_matcher.clone();
            let results_tx = results_tx.clone();
            let project = project.clone();
            cx.spawn(async move |cx| {
                run_query(
                    &glob,
                    &regex,
                    results_tx.clone(),
                    path_style,
                    exclude_matcher,
                    &project,
                    cx,
                )
                .await
                .log_err();
            })
            .detach();
        }
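        // Drop the original sender so `results_rx` finishes once every
        // spawned query task has sent its results and dropped its clone.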
        drop(results_tx);

        struct ResultBuffer {
            buffer: Entity<Buffer>,
            snapshot: TextBufferSnapshot,
        }

        let (result_buffers_by_path, merged_result) = cx
            .background_spawn(async move {
                let mut excerpts_by_buffer: HashMap<Entity<Buffer>, MatchedBuffer> =
                    HashMap::default();

                while let Some((buffer, matched)) = results_rx.next().await {
                    match excerpts_by_buffer.entry(buffer) {
                        Entry::Occupied(mut entry) => {
                            let entry = entry.get_mut();
                            entry.full_path = matched.full_path;
                            entry.snapshot = matched.snapshot;
                            entry.line_ranges.extend(matched.line_ranges);
                        }
                        Entry::Vacant(entry) => {
                            entry.insert(matched);
                        }
                    }
                }

                let mut result_buffers_by_path = HashMap::default();
                let mut merged_result = RESULTS_MESSAGE.to_string();

                for (buffer, mut matched) in excerpts_by_buffer {
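                    // Sort by ascending start, then descending end, so that
                    // enclosing ranges come before the ranges they contain.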
                    matched
                        .line_ranges
                        .sort_unstable_by_key(|range| (range.start, Reverse(range.end)));

                    write_codeblock(
                        &matched.full_path,
                        merge_excerpts(&matched.snapshot, matched.line_ranges).iter(),
                        &[],
                        Line(matched.snapshot.max_point().row),
                        true,
                        &mut merged_result,
                    );

                    result_buffers_by_path.insert(
                        matched.full_path,
                        ResultBuffer {
                            buffer,
                            snapshot: matched.snapshot.text,
                        },
                    );
                }

                (result_buffers_by_path, merged_result)
            })
            .await;

        if let Some(debug_tx) = &debug_tx {
            debug_tx
                .unbounded_send(ZetaDebugInfo::SearchQueriesExecuted(
                    ZetaContextRetrievalDebugInfo {
                        project: project.clone(),
                        timestamp: Instant::now(),
                    },
                ))
                .ok();
        }

        let tool_result = LanguageModelToolResult {
            tool_use_id: search_tool_use.id.clone(),
            tool_name: SEARCH_TOOL_NAME.into(),
            is_error: false,
            content: merged_result.into(),
            output: None,
        };

        select_request_messages.extend([
            LanguageModelRequestMessage {
                role: Role::Assistant,
                content: vec![MessageContent::ToolUse(search_tool_use)],
                cache: false,
            },
            LanguageModelRequestMessage {
                role: Role::User,
                content: vec![MessageContent::ToolResult(tool_result)],
                cache: false,
            },
        ]);

        if result_buffers_by_path.is_empty() {
            log::trace!("context gathering queries produced no results");
            return anyhow::Ok(HashMap::default());
        }

        select_request_messages.push(LanguageModelRequestMessage {
            role: Role::User,
            content: vec![SELECT_PROMPT.into()],
            cache: false,
        });

        let mut select_stream = request_tool_call::<SelectToolInput>(
            select_request_messages,
            SELECT_TOOL_NAME,
            &model,
            cx,
        )
        .await?;

        cx.background_spawn(async move {
            let mut selected_ranges = Vec::new();

            while let Some(event) = select_stream.next().await {
                match event? {
                    LanguageModelCompletionEvent::ToolUse(tool_use) => {
                        if !tool_use.is_input_complete {
                            continue;
                        }

                        if tool_use.name.as_ref() == SELECT_TOOL_NAME {
                            let call =
                                serde_json::from_value::<SelectToolInput>(tool_use.input.clone())?;
                            selected_ranges.extend(call.ranges);
                        } else {
                            log::warn!(
                                "context gathering model tried to use unknown tool: {}",
                                tool_use.name
                            );
                        }
                    }
                    ev @ LanguageModelCompletionEvent::ToolUseJsonParseError { .. } => {
                        log::error!("{ev:?}");
                    }
                    ev => {
                        log::trace!("context select event: {ev:?}");
                    }
                }
            }

            if let Some(debug_tx) = &debug_tx {
                debug_tx
                    .unbounded_send(ZetaDebugInfo::SearchResultsFiltered(
                        ZetaContextRetrievalDebugInfo {
                            project: project.clone(),
                            timestamp: Instant::now(),
                        },
                    ))
                    .ok();
            }

            if selected_ranges.is_empty() {
                log::trace!("context gathering selected no ranges");
            }

            selected_ranges.sort_unstable_by(|a, b| {
                a.start_line
                    .cmp(&b.start_line)
                    .then(b.end_line.cmp(&a.end_line))
            });

            let mut related_excerpts_by_buffer: HashMap<_, Vec<_>> = HashMap::default();

            for selected_range in selected_ranges {
                if let Some(ResultBuffer { buffer, snapshot }) =
                    result_buffers_by_path.get(&selected_range.path)
                {
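                    // Convert the model's 1-based inclusive lines to 0-based
                    // points, clipping in case the model overshoots the file.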
                    let start_point = Point::new(selected_range.start_line.saturating_sub(1), 0);
                    let end_point =
                        snapshot.clip_point(Point::new(selected_range.end_line, 0), Bias::Left);
                    let range =
                        snapshot.anchor_after(start_point)..snapshot.anchor_before(end_point);

                    related_excerpts_by_buffer
                        .entry(buffer.clone())
                        .or_default()
                        .push(range);
                } else {
                    log::warn!(
                        "selected path that wasn't included in search results: {}",
                        selected_range.path.display()
                    );
                }
            }

            anyhow::Ok(related_excerpts_by_buffer)
        })
        .await
    })
}

async fn request_tool_call<T: JsonSchema>(
    messages: Vec<LanguageModelRequestMessage>,
    tool_name: &'static str,
    model: &Arc<dyn LanguageModel>,
    cx: &mut AsyncApp,
) -> Result<BoxStream<'static, Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>>
{
    let schema = schemars::schema_for!(T);

    let request = LanguageModelRequest {
        messages,
        tools: vec![LanguageModelRequestTool {
            name: tool_name.into(),
            description: schema
                .get("description")
                .and_then(|description| description.as_str())
                .unwrap()
                .to_string(),
            input_schema: serde_json::to_value(schema).unwrap(),
        }],
        ..Default::default()
    };

    Ok(model.stream_completion(request, cx).await?)
}

const MIN_EXCERPT_LEN: usize = 16;
const MAX_EXCERPT_LEN: usize = 768;
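// Cap on the total excerpt bytes a single query may contribute (about five
// max-size excerpts' worth).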
const MAX_RESULT_BYTES_PER_QUERY: usize = MAX_EXCERPT_LEN * 5;

struct MatchedBuffer {
    snapshot: BufferSnapshot,
    line_ranges: Vec<Range<Line>>,
    full_path: PathBuf,
}

async fn run_query(
    glob: &str,
    regex: &str,
    results_tx: UnboundedSender<(Entity<Buffer>, MatchedBuffer)>,
    path_style: PathStyle,
    exclude_matcher: PathMatcher,
    project: &Entity<Project>,
    cx: &mut AsyncApp,
) -> Result<()> {
    let include_matcher = PathMatcher::new(vec![glob], path_style)?;

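    // The positional booleans map to `SearchQuery::regex` options (e.g.
    // whole-word and case-sensitivity flags); see its signature for what each
    // position controls.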
    let query = SearchQuery::regex(
        regex,
        false,
        true,
        false,
        true,
        include_matcher,
        exclude_matcher,
        true,
        None,
    )?;

    let results = project.update(cx, |project, cx| project.search(query, cx))?;
    futures::pin_mut!(results);

    let mut total_bytes = 0;

    while let Some(SearchResult::Buffer { buffer, ranges }) = results.next().await {
        if ranges.is_empty() {
            continue;
        }

        let Some((snapshot, full_path)) = buffer.read_with(cx, |buffer, cx| {
            Some((buffer.snapshot(), buffer.file()?.full_path(cx)))
        })?
        else {
            continue;
        };

        let results_tx = results_tx.clone();
        cx.background_spawn(async move {
            let mut line_ranges = Vec::with_capacity(ranges.len());

            for range in ranges {
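                // Expand each match into an excerpt centered on the match's
                // midpoint, stopping once the per-query byte budget is spent.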
                let offset_range = range.to_offset(&snapshot);
                let query_point = (offset_range.start + offset_range.len() / 2).to_point(&snapshot);

                if total_bytes + MIN_EXCERPT_LEN >= MAX_RESULT_BYTES_PER_QUERY {
                    break;
                }

                let excerpt = EditPredictionExcerpt::select_from_buffer(
                    query_point,
                    &snapshot,
                    &EditPredictionExcerptOptions {
                        max_bytes: MAX_EXCERPT_LEN.min(MAX_RESULT_BYTES_PER_QUERY - total_bytes),
                        min_bytes: MIN_EXCERPT_LEN,
                        target_before_cursor_over_total_bytes: 0.5,
                    },
                    None,
                );

                if let Some(excerpt) = excerpt {
                    total_bytes += excerpt.range.len();
                    if !excerpt.line_range.is_empty() {
                        line_ranges.push(excerpt.line_range);
                    }
                }
            }

            results_tx
                .unbounded_send((
                    buffer,
                    MatchedBuffer {
                        snapshot,
                        line_ranges,
                        full_path,
                    },
                ))
                .log_err();
        })
        .detach();
    }

    anyhow::Ok(())
}