use std::{
    cmp::Reverse,
    collections::hash_map::Entry,
    ops::Range,
    path::PathBuf,
    sync::{
        Arc,
        atomic::{AtomicUsize, Ordering},
    },
    time::Instant,
};
4
5use crate::{
6 ZetaContextRetrievalDebugInfo, ZetaContextRetrievalStartedDebugInfo, ZetaDebugInfo,
7 ZetaSearchQueryDebugInfo, merge_excerpts::merge_excerpts,
8};
9use anyhow::{Result, anyhow};
10use cloud_zeta2_prompt::write_codeblock;
11use collections::HashMap;
12use edit_prediction_context::{EditPredictionExcerpt, EditPredictionExcerptOptions, Line};
13use futures::{
14 StreamExt,
15 channel::mpsc::{self, UnboundedSender},
16 stream::BoxStream,
17};
18use gpui::{App, AppContext, AsyncApp, Entity, Task};
19use indoc::indoc;
20use language::{
21 Anchor, Bias, Buffer, BufferSnapshot, OffsetRangeExt, Point, TextBufferSnapshot, ToPoint as _,
22};
23use language_model::{
24 LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent, LanguageModelId,
25 LanguageModelProviderId, LanguageModelRegistry, LanguageModelRequest,
26 LanguageModelRequestMessage, LanguageModelRequestTool, LanguageModelToolResult,
27 LanguageModelToolUse, MessageContent, Role,
28};
29use project::{
30 Project, WorktreeSettings,
31 search::{SearchQuery, SearchResult},
32};
33use schemars::JsonSchema;
34use serde::{Deserialize, Serialize};
35use util::{
36 ResultExt as _,
37 paths::{PathMatcher, PathStyle},
38};
39use workspace::item::Settings as _;
40
/// Prompt for the first turn of the context-gathering conversation.
/// The `{edits}`, `{current_file_path}`, and `{cursor_excerpt}` placeholders
/// are substituted via `str::replace` in `find_related_excerpts` before the
/// request is sent.
const SEARCH_PROMPT: &str = indoc! {r#"
    ## Task

    You are part of an edit prediction system in a code editor. Your role is to identify relevant code locations
    that will serve as context for predicting the next required edit.

    **Your task:**
    - Analyze the user's recent edits and current cursor context
    - Use the `search` tool to find code that may be relevant for predicting the next edit
    - Focus on finding:
      - Code patterns that might need similar changes based on the recent edits
      - Functions, variables, types, and constants referenced in the current cursor context
      - Related implementations, usages, or dependencies that may require consistent updates

    **Important constraints:**
    - This conversation has exactly 2 turns
    - You must make ALL search queries in your first response via the `search` tool
    - All queries will be executed in parallel and results returned together
    - In the second turn, you will select the most relevant results via the `select` tool.

    ## User Edits

    {edits}

    ## Current cursor context

    `````{current_file_path}
    {cursor_excerpt}
    `````

    --
    Use the `search` tool now
"#};
74
// Name of the tool the model must call in turn one to request searches.
const SEARCH_TOOL_NAME: &str = "search";
76
/// Search for relevant code
///
/// For the best results, run multiple queries at once with a single invocation of this tool.
// NOTE: the `///` doc comments on this struct and its fields are emitted into
// the derived JSON schema; the struct-level one becomes the tool description
// read by `request_tool_call`, so they are part of the prompt — edit with care.
#[derive(Clone, Deserialize, Serialize, JsonSchema)]
pub struct SearchToolInput {
    /// An array of queries to run for gathering context relevant to the next prediction
    #[schemars(length(max = 5))]
    pub queries: Box<[SearchToolQuery]>,
}
86
// A single glob + regex pair produced by the model. All regexes sharing a glob
// are later OR-ed together into one search (see `regex_by_glob` in
// `find_related_excerpts`). The `///` comments feed the JSON schema.
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
pub struct SearchToolQuery {
    /// A glob pattern to match file paths in the codebase
    pub glob: String,
    /// A regular expression to match content within the files matched by the glob pattern
    pub regex: String,
}
94
// Header prepended to the merged search results that are sent back to the
// model as the `search` tool's result in turn two.
const RESULTS_MESSAGE: &str = indoc! {"
    Here are the results of your queries combined and grouped by file:

"};
99
// Name of the tool the model must call in turn two to pick line ranges.
const SELECT_TOOL_NAME: &str = "select";
101
// Final user message of turn two, instructing the model to call `select`.
const SELECT_PROMPT: &str = indoc! {"
    Use the `select` tool now to pick the most relevant line ranges according to the user state provided in the first message.
    Make sure to include enough lines of context so that the edit prediction model can suggest accurate edits.
    Include up to 200 lines in total.
"};
107
/// Select line ranges from search results
// The `///` doc comment above doubles as the tool description in the derived
// JSON schema (read by `request_tool_call`), so it must not be removed.
#[derive(Deserialize, JsonSchema)]
struct SelectToolInput {
    /// The line ranges to select from search results.
    ranges: Vec<SelectLineRange>,
}
114
/// A specific line range to select from a file
// Line numbers are 1-based and inclusive; they are converted to 0-based points
// and clipped to the buffer in `find_related_excerpts`. The `///` comments
// feed the JSON schema shown to the model.
#[derive(Debug, Deserialize, JsonSchema)]
struct SelectLineRange {
    /// The file path containing the lines to select
    /// Exactly as it appears in the search result codeblocks.
    path: PathBuf,
    /// The starting line number (1-based)
    #[schemars(range(min = 1))]
    start_line: u32,
    /// The ending line number (1-based, inclusive)
    #[schemars(range(min = 1))]
    end_line: u32,
}
128
/// Options controlling LLM-driven context retrieval.
#[derive(Debug, Clone, PartialEq)]
pub struct LlmContextOptions {
    /// Sizing options for the excerpt taken around the user's cursor.
    pub excerpt: EditPredictionExcerptOptions,
}
133
/// Provider whose models are searched when looking up the context model.
pub const MODEL_PROVIDER_ID: LanguageModelProviderId = language_model::ANTHROPIC_PROVIDER_ID;
135
136pub fn find_related_excerpts(
137 buffer: Entity<language::Buffer>,
138 cursor_position: Anchor,
139 project: &Entity<Project>,
140 mut edit_history_unified_diff: String,
141 options: &LlmContextOptions,
142 debug_tx: Option<mpsc::UnboundedSender<ZetaDebugInfo>>,
143 cx: &App,
144) -> Task<Result<HashMap<Entity<Buffer>, Vec<Range<Anchor>>>>> {
145 let language_model_registry = LanguageModelRegistry::global(cx);
146 let Some(model) = language_model_registry
147 .read(cx)
148 .available_models(cx)
149 .find(|model| {
150 model.provider_id() == MODEL_PROVIDER_ID
151 && model.id() == LanguageModelId("claude-haiku-4-5-latest".into())
152 })
153 else {
154 return Task::ready(Err(anyhow!("could not find context model")));
155 };
156
157 if edit_history_unified_diff.is_empty() {
158 edit_history_unified_diff.push_str("(No user edits yet)");
159 }
160
161 // TODO [zeta2] include breadcrumbs?
162 let snapshot = buffer.read(cx).snapshot();
163 let cursor_point = cursor_position.to_point(&snapshot);
164 let Some(cursor_excerpt) =
165 EditPredictionExcerpt::select_from_buffer(cursor_point, &snapshot, &options.excerpt, None)
166 else {
167 return Task::ready(Ok(HashMap::default()));
168 };
169
170 let current_file_path = snapshot
171 .file()
172 .map(|f| f.full_path(cx).display().to_string())
173 .unwrap_or_else(|| "untitled".to_string());
174
175 let prompt = SEARCH_PROMPT
176 .replace("{edits}", &edit_history_unified_diff)
177 .replace("{current_file_path}", ¤t_file_path)
178 .replace("{cursor_excerpt}", &cursor_excerpt.text(&snapshot).body);
179
180 if let Some(debug_tx) = &debug_tx {
181 debug_tx
182 .unbounded_send(ZetaDebugInfo::ContextRetrievalStarted(
183 ZetaContextRetrievalStartedDebugInfo {
184 project: project.clone(),
185 timestamp: Instant::now(),
186 search_prompt: prompt.clone(),
187 },
188 ))
189 .ok();
190 }
191
192 let path_style = project.read(cx).path_style(cx);
193
194 let exclude_matcher = {
195 let global_settings = WorktreeSettings::get_global(cx);
196 let exclude_patterns = global_settings
197 .file_scan_exclusions
198 .sources()
199 .iter()
200 .chain(global_settings.private_files.sources().iter());
201
202 match PathMatcher::new(exclude_patterns, path_style) {
203 Ok(matcher) => matcher,
204 Err(err) => {
205 return Task::ready(Err(anyhow!(err)));
206 }
207 }
208 };
209
210 let project = project.clone();
211 cx.spawn(async move |cx| {
212 let initial_prompt_message = LanguageModelRequestMessage {
213 role: Role::User,
214 content: vec![prompt.into()],
215 cache: false,
216 };
217
218 let mut search_stream = request_tool_call::<SearchToolInput>(
219 vec![initial_prompt_message.clone()],
220 SEARCH_TOOL_NAME,
221 &model,
222 cx,
223 )
224 .await?;
225
226 let mut select_request_messages = Vec::with_capacity(5); // initial prompt, LLM response/thinking, tool use, tool result, select prompt
227 select_request_messages.push(initial_prompt_message);
228
229 let mut regex_by_glob: HashMap<String, String> = HashMap::default();
230 let mut search_calls = Vec::new();
231
232 while let Some(event) = search_stream.next().await {
233 match event? {
234 LanguageModelCompletionEvent::ToolUse(tool_use) => {
235 if !tool_use.is_input_complete {
236 continue;
237 }
238
239 if tool_use.name.as_ref() == SEARCH_TOOL_NAME {
240 let input =
241 serde_json::from_value::<SearchToolInput>(tool_use.input.clone())?;
242
243 for query in input.queries {
244 let regex = regex_by_glob.entry(query.glob).or_default();
245 if !regex.is_empty() {
246 regex.push('|');
247 }
248 regex.push_str(&query.regex);
249 }
250
251 search_calls.push(tool_use);
252 } else {
253 log::warn!(
254 "context gathering model tried to use unknown tool: {}",
255 tool_use.name
256 );
257 }
258 }
259 LanguageModelCompletionEvent::Text(txt) => {
260 if let Some(LanguageModelRequestMessage {
261 role: Role::Assistant,
262 content,
263 ..
264 }) = select_request_messages.last_mut()
265 {
266 if let Some(MessageContent::Text(existing_text)) = content.last_mut() {
267 existing_text.push_str(&txt);
268 } else {
269 content.push(MessageContent::Text(txt));
270 }
271 } else {
272 select_request_messages.push(LanguageModelRequestMessage {
273 role: Role::Assistant,
274 content: vec![MessageContent::Text(txt)],
275 cache: false,
276 });
277 }
278 }
279 LanguageModelCompletionEvent::Thinking { text, signature } => {
280 if let Some(LanguageModelRequestMessage {
281 role: Role::Assistant,
282 content,
283 ..
284 }) = select_request_messages.last_mut()
285 {
286 if let Some(MessageContent::Thinking {
287 text: existing_text,
288 signature: existing_signature,
289 }) = content.last_mut()
290 {
291 existing_text.push_str(&text);
292 *existing_signature = signature;
293 } else {
294 content.push(MessageContent::Thinking { text, signature });
295 }
296 } else {
297 select_request_messages.push(LanguageModelRequestMessage {
298 role: Role::Assistant,
299 content: vec![MessageContent::Thinking { text, signature }],
300 cache: false,
301 });
302 }
303 }
304 LanguageModelCompletionEvent::RedactedThinking { data } => {
305 if let Some(LanguageModelRequestMessage {
306 role: Role::Assistant,
307 content,
308 ..
309 }) = select_request_messages.last_mut()
310 {
311 if let Some(MessageContent::RedactedThinking(existing_data)) =
312 content.last_mut()
313 {
314 existing_data.push_str(&data);
315 } else {
316 content.push(MessageContent::RedactedThinking(data));
317 }
318 } else {
319 select_request_messages.push(LanguageModelRequestMessage {
320 role: Role::Assistant,
321 content: vec![MessageContent::RedactedThinking(data)],
322 cache: false,
323 });
324 }
325 }
326 ev @ LanguageModelCompletionEvent::ToolUseJsonParseError { .. } => {
327 log::error!("{ev:?}");
328 }
329 ev => {
330 log::trace!("context search event: {ev:?}")
331 }
332 }
333 }
334
335 let search_tool_use = if search_calls.is_empty() {
336 log::warn!("context model ran 0 searches");
337 return anyhow::Ok(Default::default());
338 } else if search_calls.len() == 1 {
339 search_calls.swap_remove(0)
340 } else {
341 // In theory, the model could perform multiple search calls
342 // Dealing with them separately is not worth it when it doesn't happen in practice.
343 // If it were to happen, here we would combine them into one.
344 // The second request doesn't need to know it was actually two different calls ;)
345 let input = serde_json::to_value(&SearchToolInput {
346 queries: regex_by_glob
347 .iter()
348 .map(|(glob, regex)| SearchToolQuery {
349 glob: glob.clone(),
350 regex: regex.clone(),
351 })
352 .collect(),
353 })
354 .unwrap_or_default();
355
356 LanguageModelToolUse {
357 id: search_calls.swap_remove(0).id,
358 name: SELECT_TOOL_NAME.into(),
359 raw_input: serde_json::to_string(&input).unwrap_or_default(),
360 input,
361 is_input_complete: true,
362 }
363 };
364
365 if let Some(debug_tx) = &debug_tx {
366 debug_tx
367 .unbounded_send(ZetaDebugInfo::SearchQueriesGenerated(
368 ZetaSearchQueryDebugInfo {
369 project: project.clone(),
370 timestamp: Instant::now(),
371 queries: regex_by_glob
372 .iter()
373 .map(|(glob, regex)| SearchToolQuery {
374 glob: glob.clone(),
375 regex: regex.clone(),
376 })
377 .collect(),
378 },
379 ))
380 .ok();
381 }
382
383 let (results_tx, mut results_rx) = mpsc::unbounded();
384
385 for (glob, regex) in regex_by_glob {
386 let exclude_matcher = exclude_matcher.clone();
387 let results_tx = results_tx.clone();
388 let project = project.clone();
389 cx.spawn(async move |cx| {
390 run_query(
391 &glob,
392 ®ex,
393 results_tx.clone(),
394 path_style,
395 exclude_matcher,
396 &project,
397 cx,
398 )
399 .await
400 .log_err();
401 })
402 .detach()
403 }
404 drop(results_tx);
405
406 struct ResultBuffer {
407 buffer: Entity<Buffer>,
408 snapshot: TextBufferSnapshot,
409 }
410
411 let (result_buffers_by_path, merged_result) = cx
412 .background_spawn(async move {
413 let mut excerpts_by_buffer: HashMap<Entity<Buffer>, MatchedBuffer> =
414 HashMap::default();
415
416 while let Some((buffer, matched)) = results_rx.next().await {
417 match excerpts_by_buffer.entry(buffer) {
418 Entry::Occupied(mut entry) => {
419 let entry = entry.get_mut();
420 entry.full_path = matched.full_path;
421 entry.snapshot = matched.snapshot;
422 entry.line_ranges.extend(matched.line_ranges);
423 }
424 Entry::Vacant(entry) => {
425 entry.insert(matched);
426 }
427 }
428 }
429
430 let mut result_buffers_by_path = HashMap::default();
431 let mut merged_result = RESULTS_MESSAGE.to_string();
432
433 for (buffer, mut matched) in excerpts_by_buffer {
434 matched
435 .line_ranges
436 .sort_unstable_by_key(|range| (range.start, Reverse(range.end)));
437
438 write_codeblock(
439 &matched.full_path,
440 merge_excerpts(&matched.snapshot, matched.line_ranges).iter(),
441 &[],
442 Line(matched.snapshot.max_point().row),
443 true,
444 &mut merged_result,
445 );
446
447 result_buffers_by_path.insert(
448 matched.full_path,
449 ResultBuffer {
450 buffer,
451 snapshot: matched.snapshot.text,
452 },
453 );
454 }
455
456 (result_buffers_by_path, merged_result)
457 })
458 .await;
459
460 if let Some(debug_tx) = &debug_tx {
461 debug_tx
462 .unbounded_send(ZetaDebugInfo::SearchQueriesExecuted(
463 ZetaContextRetrievalDebugInfo {
464 project: project.clone(),
465 timestamp: Instant::now(),
466 },
467 ))
468 .ok();
469 }
470
471 let tool_result = LanguageModelToolResult {
472 tool_use_id: search_tool_use.id.clone(),
473 tool_name: SEARCH_TOOL_NAME.into(),
474 is_error: false,
475 content: merged_result.into(),
476 output: None,
477 };
478
479 select_request_messages.extend([
480 LanguageModelRequestMessage {
481 role: Role::Assistant,
482 content: vec![MessageContent::ToolUse(search_tool_use)],
483 cache: false,
484 },
485 LanguageModelRequestMessage {
486 role: Role::User,
487 content: vec![MessageContent::ToolResult(tool_result)],
488 cache: false,
489 },
490 ]);
491
492 if result_buffers_by_path.is_empty() {
493 log::trace!("context gathering queries produced no results");
494 return anyhow::Ok(HashMap::default());
495 }
496
497 select_request_messages.push(LanguageModelRequestMessage {
498 role: Role::User,
499 content: vec![SELECT_PROMPT.into()],
500 cache: false,
501 });
502
503 let mut select_stream = request_tool_call::<SelectToolInput>(
504 select_request_messages,
505 SELECT_TOOL_NAME,
506 &model,
507 cx,
508 )
509 .await?;
510
511 cx.background_spawn(async move {
512 let mut selected_ranges = Vec::new();
513
514 while let Some(event) = select_stream.next().await {
515 match event? {
516 LanguageModelCompletionEvent::ToolUse(tool_use) => {
517 if !tool_use.is_input_complete {
518 continue;
519 }
520
521 if tool_use.name.as_ref() == SELECT_TOOL_NAME {
522 let call =
523 serde_json::from_value::<SelectToolInput>(tool_use.input.clone())?;
524 selected_ranges.extend(call.ranges);
525 } else {
526 log::warn!(
527 "context gathering model tried to use unknown tool: {}",
528 tool_use.name
529 );
530 }
531 }
532 ev @ LanguageModelCompletionEvent::ToolUseJsonParseError { .. } => {
533 log::error!("{ev:?}");
534 }
535 ev => {
536 log::trace!("context select event: {ev:?}")
537 }
538 }
539 }
540
541 if let Some(debug_tx) = &debug_tx {
542 debug_tx
543 .unbounded_send(ZetaDebugInfo::SearchResultsFiltered(
544 ZetaContextRetrievalDebugInfo {
545 project: project.clone(),
546 timestamp: Instant::now(),
547 },
548 ))
549 .ok();
550 }
551
552 if selected_ranges.is_empty() {
553 log::trace!("context gathering selected no ranges")
554 }
555
556 selected_ranges.sort_unstable_by(|a, b| {
557 a.start_line
558 .cmp(&b.start_line)
559 .then(b.end_line.cmp(&a.end_line))
560 });
561
562 let mut related_excerpts_by_buffer: HashMap<_, Vec<_>> = HashMap::default();
563
564 for selected_range in selected_ranges {
565 if let Some(ResultBuffer { buffer, snapshot }) =
566 result_buffers_by_path.get(&selected_range.path)
567 {
568 let start_point = Point::new(selected_range.start_line.saturating_sub(1), 0);
569 let end_point =
570 snapshot.clip_point(Point::new(selected_range.end_line, 0), Bias::Left);
571 let range =
572 snapshot.anchor_after(start_point)..snapshot.anchor_before(end_point);
573
574 related_excerpts_by_buffer
575 .entry(buffer.clone())
576 .or_default()
577 .push(range);
578 } else {
579 log::warn!(
580 "selected path that wasn't included in search results: {}",
581 selected_range.path.display()
582 );
583 }
584 }
585
586 anyhow::Ok(related_excerpts_by_buffer)
587 })
588 .await
589 })
590}
591
592async fn request_tool_call<T: JsonSchema>(
593 messages: Vec<LanguageModelRequestMessage>,
594 tool_name: &'static str,
595 model: &Arc<dyn LanguageModel>,
596 cx: &mut AsyncApp,
597) -> Result<BoxStream<'static, Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>>
598{
599 let schema = schemars::schema_for!(T);
600
601 let request = LanguageModelRequest {
602 messages,
603 tools: vec![LanguageModelRequestTool {
604 name: tool_name.into(),
605 description: schema
606 .get("description")
607 .and_then(|description| description.as_str())
608 .unwrap()
609 .to_string(),
610 input_schema: serde_json::to_value(schema).unwrap(),
611 }],
612 ..Default::default()
613 };
614
615 Ok(model.stream_completion(request, cx).await?)
616}
617
// Minimum size of a single result excerpt, in bytes.
const MIN_EXCERPT_LEN: usize = 16;
// Maximum size of a single result excerpt, in bytes.
const MAX_EXCERPT_LEN: usize = 768;
// Byte budget for all excerpts produced by a single query (see `run_query`).
const MAX_RESULT_BYTES_PER_QUERY: usize = MAX_EXCERPT_LEN * 5;
621
/// Excerpt line ranges found in one buffer by `run_query`, together with the
/// snapshot they were computed against.
struct MatchedBuffer {
    snapshot: BufferSnapshot,
    // May be unsorted/overlapping here; sorted and merged by the consumer.
    line_ranges: Vec<Range<Line>>,
    full_path: PathBuf,
}
627
628async fn run_query(
629 glob: &str,
630 regex: &str,
631 results_tx: UnboundedSender<(Entity<Buffer>, MatchedBuffer)>,
632 path_style: PathStyle,
633 exclude_matcher: PathMatcher,
634 project: &Entity<Project>,
635 cx: &mut AsyncApp,
636) -> Result<()> {
637 let include_matcher = PathMatcher::new(vec![glob], path_style)?;
638
639 let query = SearchQuery::regex(
640 regex,
641 false,
642 true,
643 false,
644 true,
645 include_matcher,
646 exclude_matcher,
647 true,
648 None,
649 )?;
650
651 let results = project.update(cx, |project, cx| project.search(query, cx))?;
652 futures::pin_mut!(results);
653
654 let mut total_bytes = 0;
655
656 while let Some(SearchResult::Buffer { buffer, ranges }) = results.next().await {
657 if ranges.is_empty() {
658 continue;
659 }
660
661 let Some((snapshot, full_path)) = buffer.read_with(cx, |buffer, cx| {
662 Some((buffer.snapshot(), buffer.file()?.full_path(cx)))
663 })?
664 else {
665 continue;
666 };
667
668 let results_tx = results_tx.clone();
669 cx.background_spawn(async move {
670 let mut line_ranges = Vec::with_capacity(ranges.len());
671
672 for range in ranges {
673 let offset_range = range.to_offset(&snapshot);
674 let query_point = (offset_range.start + offset_range.len() / 2).to_point(&snapshot);
675
676 if total_bytes + MIN_EXCERPT_LEN >= MAX_RESULT_BYTES_PER_QUERY {
677 break;
678 }
679
680 let excerpt = EditPredictionExcerpt::select_from_buffer(
681 query_point,
682 &snapshot,
683 &EditPredictionExcerptOptions {
684 max_bytes: MAX_EXCERPT_LEN.min(MAX_RESULT_BYTES_PER_QUERY - total_bytes),
685 min_bytes: MIN_EXCERPT_LEN,
686 target_before_cursor_over_total_bytes: 0.5,
687 },
688 None,
689 );
690
691 if let Some(excerpt) = excerpt {
692 total_bytes += excerpt.range.len();
693 if !excerpt.line_range.is_empty() {
694 line_ranges.push(excerpt.line_range);
695 }
696 }
697 }
698
699 results_tx
700 .unbounded_send((
701 buffer,
702 MatchedBuffer {
703 snapshot,
704 line_ranges,
705 full_path,
706 },
707 ))
708 .log_err();
709 })
710 .detach();
711 }
712
713 anyhow::Ok(())
714}