use std::{
    cmp::Reverse,
    collections::hash_map::Entry,
    fmt::Write,
    ops::Range,
    path::PathBuf,
    sync::{
        Arc,
        atomic::{AtomicUsize, Ordering},
    },
    time::Instant,
};
5
6use crate::{
7 ZetaContextRetrievalDebugInfo, ZetaDebugInfo, ZetaSearchQueryDebugInfo,
8 merge_excerpts::write_merged_excerpts,
9};
10use anyhow::{Result, anyhow};
11use collections::HashMap;
12use edit_prediction_context::{EditPredictionExcerpt, EditPredictionExcerptOptions, Line};
13use futures::{
14 StreamExt,
15 channel::mpsc::{self, UnboundedSender},
16 stream::BoxStream,
17};
18use gpui::{App, AppContext, AsyncApp, Entity, Task};
19use indoc::indoc;
20use language::{
21 Anchor, Bias, Buffer, BufferSnapshot, OffsetRangeExt, Point, TextBufferSnapshot, ToPoint as _,
22};
23use language_model::{
24 LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent, LanguageModelId,
25 LanguageModelRegistry, LanguageModelRequest, LanguageModelRequestMessage,
26 LanguageModelRequestTool, LanguageModelToolResult, LanguageModelToolUse, MessageContent, Role,
27};
28use project::{
29 Project, WorktreeSettings,
30 search::{SearchQuery, SearchResult},
31};
32use schemars::JsonSchema;
33use serde::{Deserialize, Serialize};
34use util::{
35 ResultExt as _,
36 paths::{PathMatcher, PathStyle},
37};
38use workspace::item::Settings as _;
39
/// First-turn prompt. Asks the model to issue ALL of its `search` tool calls
/// in a single response, given the user's recent edits and an excerpt around
/// the cursor. The `{edits}`, `{current_file_path}`, and `{cursor_excerpt}`
/// placeholders are substituted in `find_related_excerpts` before sending.
const SEARCH_PROMPT: &str = indoc! {r#"
    ## Task

    You are part of an edit prediction system in a code editor. Your role is to identify relevant code locations
    that will serve as context for predicting the next required edit.

    **Your task:**
    - Analyze the user's recent edits and current cursor context
    - Use the `search` tool to find code that may be relevant for predicting the next edit
    - Focus on finding:
      - Code patterns that might need similar changes based on the recent edits
      - Functions, variables, types, and constants referenced in the current cursor context
      - Related implementations, usages, or dependencies that may require consistent updates

    **Important constraints:**
    - This conversation has exactly 2 turns
    - You must make ALL search queries in your first response via the `search` tool
    - All queries will be executed in parallel and results returned together
    - In the second turn, you will select the most relevant results via the `select` tool.

    ## User Edits

    {edits}

    ## Current cursor context

    `````filename={current_file_path}
    {cursor_excerpt}
    `````

    --
    Use the `search` tool now
"#};

/// Name of the tool the model must call in the first turn.
const SEARCH_TOOL_NAME: &str = "search";
75
/// Search for relevant code
///
/// For the best results, run multiple queries at once with a single invocation of this tool.
// NOTE: the `///` doc comments on this type and its fields are not just for
// readers — schemars lifts them into the JSON schema sent to the model as the
// tool description (see `request_tool_call`). Edit them as prompt text.
#[derive(Clone, Deserialize, Serialize, JsonSchema)]
pub struct SearchToolInput {
    /// An array of queries to run for gathering context relevant to the next prediction
    #[schemars(length(max = 5))]
    pub queries: Box<[SearchToolQuery]>,
}
85
// One glob + regex pair produced by the model. Queries sharing a glob are
// merged (regexes or-ed together) before execution in `find_related_excerpts`.
// NOTE: the `///` field docs below are serialized into the tool's JSON schema
// by schemars and read by the model — treat them as part of the prompt.
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
pub struct SearchToolQuery {
    /// A glob pattern to match file paths in the codebase
    pub glob: String,
    /// A regular expression to match content within the files matched by the glob pattern
    pub regex: String,
}
93
/// Preamble for the second-turn message carrying the merged search results
/// back to the model.
const RESULTS_MESSAGE: &str = indoc! {"
    Here are the results of your queries combined and grouped by file:

"};

/// Name of the tool the model must call in the second turn.
const SELECT_TOOL_NAME: &str = "select";

/// Final user message nudging the model to call `select` with the most
/// relevant line ranges from the search results.
const SELECT_PROMPT: &str = indoc! {"
    Use the `select` tool now to pick the most relevant line ranges according to the user state provided in the first message.
    Make sure to include enough lines of context so that the edit prediction model can suggest accurate edits.
    Include up to 200 lines in total.
"};
106
/// Select line ranges from search results
// NOTE: the `///` docs here become the tool's schema description via schemars
// (see `request_tool_call`) — they are read by the model, not just humans.
#[derive(Deserialize, JsonSchema)]
struct SelectToolInput {
    /// The line ranges to select from search results.
    ranges: Vec<SelectLineRange>,
}
113
/// A specific line range to select from a file
// NOTE: the `///` field docs are serialized into the tool's JSON schema by
// schemars and read by the model — treat them as part of the prompt. The
// path is matched verbatim against `result_buffers_by_path` keys, which is
// why it must appear exactly as in the result codeblocks.
#[derive(Debug, Deserialize, JsonSchema)]
struct SelectLineRange {
    /// The file path containing the lines to select
    /// Exactly as it appears in the search result codeblocks.
    path: PathBuf,
    /// The starting line number (1-based)
    #[schemars(range(min = 1))]
    start_line: u32,
    /// The ending line number (1-based, inclusive)
    #[schemars(range(min = 1))]
    end_line: u32,
}
127
/// Options controlling LLM-driven context gathering.
#[derive(Debug, Clone, PartialEq)]
pub struct LlmContextOptions {
    /// Controls how the excerpt around the cursor is selected and sized.
    pub excerpt: EditPredictionExcerptOptions,
}
132
133pub fn find_related_excerpts<'a>(
134 buffer: Entity<language::Buffer>,
135 cursor_position: Anchor,
136 project: &Entity<Project>,
137 events: impl Iterator<Item = &'a crate::Event>,
138 options: &LlmContextOptions,
139 debug_tx: Option<mpsc::UnboundedSender<ZetaDebugInfo>>,
140 cx: &App,
141) -> Task<Result<HashMap<Entity<Buffer>, Vec<Range<Anchor>>>>> {
142 let language_model_registry = LanguageModelRegistry::global(cx);
143 let Some(model) = language_model_registry
144 .read(cx)
145 .available_models(cx)
146 .find(|model| {
147 model.provider_id() == language_model::ANTHROPIC_PROVIDER_ID
148 && model.id() == LanguageModelId("claude-haiku-4-5-latest".into())
149 })
150 else {
151 return Task::ready(Err(anyhow!("could not find claude model")));
152 };
153
154 let mut edits_string = String::new();
155
156 for event in events {
157 if let Some(event) = event.to_request_event(cx) {
158 writeln!(&mut edits_string, "{event}").ok();
159 }
160 }
161
162 if edits_string.is_empty() {
163 edits_string.push_str("(No user edits yet)");
164 }
165
166 // TODO [zeta2] include breadcrumbs?
167 let snapshot = buffer.read(cx).snapshot();
168 let cursor_point = cursor_position.to_point(&snapshot);
169 let Some(cursor_excerpt) =
170 EditPredictionExcerpt::select_from_buffer(cursor_point, &snapshot, &options.excerpt, None)
171 else {
172 return Task::ready(Ok(HashMap::default()));
173 };
174
175 let current_file_path = snapshot
176 .file()
177 .map(|f| f.full_path(cx).display().to_string())
178 .unwrap_or_else(|| "untitled".to_string());
179
180 let prompt = SEARCH_PROMPT
181 .replace("{edits}", &edits_string)
182 .replace("{current_file_path}", ¤t_file_path)
183 .replace("{cursor_excerpt}", &cursor_excerpt.text(&snapshot).body);
184
185 let path_style = project.read(cx).path_style(cx);
186
187 let exclude_matcher = {
188 let global_settings = WorktreeSettings::get_global(cx);
189 let exclude_patterns = global_settings
190 .file_scan_exclusions
191 .sources()
192 .iter()
193 .chain(global_settings.private_files.sources().iter());
194
195 match PathMatcher::new(exclude_patterns, path_style) {
196 Ok(matcher) => matcher,
197 Err(err) => {
198 return Task::ready(Err(anyhow!(err)));
199 }
200 }
201 };
202
203 let project = project.clone();
204 cx.spawn(async move |cx| {
205 let initial_prompt_message = LanguageModelRequestMessage {
206 role: Role::User,
207 content: vec![prompt.into()],
208 cache: false,
209 };
210
211 let mut search_stream = request_tool_call::<SearchToolInput>(
212 vec![initial_prompt_message.clone()],
213 SEARCH_TOOL_NAME,
214 &model,
215 cx,
216 )
217 .await?;
218
219 let mut select_request_messages = Vec::with_capacity(5); // initial prompt, LLM response/thinking, tool use, tool result, select prompt
220 select_request_messages.push(initial_prompt_message);
221
222 let mut regex_by_glob: HashMap<String, String> = HashMap::default();
223 let mut search_calls = Vec::new();
224
225 while let Some(event) = search_stream.next().await {
226 match event? {
227 LanguageModelCompletionEvent::ToolUse(tool_use) => {
228 if !tool_use.is_input_complete {
229 continue;
230 }
231
232 if tool_use.name.as_ref() == SEARCH_TOOL_NAME {
233 let input =
234 serde_json::from_value::<SearchToolInput>(tool_use.input.clone())?;
235
236 for query in input.queries {
237 let regex = regex_by_glob.entry(query.glob).or_default();
238 if !regex.is_empty() {
239 regex.push('|');
240 }
241 regex.push_str(&query.regex);
242 }
243
244 search_calls.push(tool_use);
245 } else {
246 log::warn!(
247 "context gathering model tried to use unknown tool: {}",
248 tool_use.name
249 );
250 }
251 }
252 LanguageModelCompletionEvent::Text(txt) => {
253 if let Some(LanguageModelRequestMessage {
254 role: Role::Assistant,
255 content,
256 ..
257 }) = select_request_messages.last_mut()
258 {
259 if let Some(MessageContent::Text(existing_text)) = content.last_mut() {
260 existing_text.push_str(&txt);
261 } else {
262 content.push(MessageContent::Text(txt));
263 }
264 } else {
265 select_request_messages.push(LanguageModelRequestMessage {
266 role: Role::Assistant,
267 content: vec![MessageContent::Text(txt)],
268 cache: false,
269 });
270 }
271 }
272 LanguageModelCompletionEvent::Thinking { text, signature } => {
273 if let Some(LanguageModelRequestMessage {
274 role: Role::Assistant,
275 content,
276 ..
277 }) = select_request_messages.last_mut()
278 {
279 if let Some(MessageContent::Thinking {
280 text: existing_text,
281 signature: existing_signature,
282 }) = content.last_mut()
283 {
284 existing_text.push_str(&text);
285 *existing_signature = signature;
286 } else {
287 content.push(MessageContent::Thinking { text, signature });
288 }
289 } else {
290 select_request_messages.push(LanguageModelRequestMessage {
291 role: Role::Assistant,
292 content: vec![MessageContent::Thinking { text, signature }],
293 cache: false,
294 });
295 }
296 }
297 LanguageModelCompletionEvent::RedactedThinking { data } => {
298 if let Some(LanguageModelRequestMessage {
299 role: Role::Assistant,
300 content,
301 ..
302 }) = select_request_messages.last_mut()
303 {
304 if let Some(MessageContent::RedactedThinking(existing_data)) =
305 content.last_mut()
306 {
307 existing_data.push_str(&data);
308 } else {
309 content.push(MessageContent::RedactedThinking(data));
310 }
311 } else {
312 select_request_messages.push(LanguageModelRequestMessage {
313 role: Role::Assistant,
314 content: vec![MessageContent::RedactedThinking(data)],
315 cache: false,
316 });
317 }
318 }
319 ev @ LanguageModelCompletionEvent::ToolUseJsonParseError { .. } => {
320 log::error!("{ev:?}");
321 }
322 ev => {
323 log::trace!("context search event: {ev:?}")
324 }
325 }
326 }
327
328 let search_tool_use = if search_calls.is_empty() {
329 log::warn!("context model ran 0 searches");
330 return anyhow::Ok(Default::default());
331 } else if search_calls.len() == 1 {
332 search_calls.swap_remove(0)
333 } else {
334 // In theory, the model could perform multiple search calls
335 // Dealing with them separately is not worth it when it doesn't happen in practice.
336 // If it were to happen, here we would combine them into one.
337 // The second request doesn't need to know it was actually two different calls ;)
338 let input = serde_json::to_value(&SearchToolInput {
339 queries: regex_by_glob
340 .iter()
341 .map(|(glob, regex)| SearchToolQuery {
342 glob: glob.clone(),
343 regex: regex.clone(),
344 })
345 .collect(),
346 })
347 .unwrap_or_default();
348
349 LanguageModelToolUse {
350 id: search_calls.swap_remove(0).id,
351 name: SELECT_TOOL_NAME.into(),
352 raw_input: serde_json::to_string(&input).unwrap_or_default(),
353 input,
354 is_input_complete: true,
355 }
356 };
357
358 if let Some(debug_tx) = &debug_tx {
359 debug_tx
360 .unbounded_send(ZetaDebugInfo::SearchQueriesGenerated(
361 ZetaSearchQueryDebugInfo {
362 project: project.clone(),
363 timestamp: Instant::now(),
364 queries: regex_by_glob
365 .iter()
366 .map(|(glob, regex)| SearchToolQuery {
367 glob: glob.clone(),
368 regex: regex.clone(),
369 })
370 .collect(),
371 },
372 ))
373 .ok();
374 }
375
376 let (results_tx, mut results_rx) = mpsc::unbounded();
377
378 for (glob, regex) in regex_by_glob {
379 let exclude_matcher = exclude_matcher.clone();
380 let results_tx = results_tx.clone();
381 let project = project.clone();
382 cx.spawn(async move |cx| {
383 run_query(
384 &glob,
385 ®ex,
386 results_tx.clone(),
387 path_style,
388 exclude_matcher,
389 &project,
390 cx,
391 )
392 .await
393 .log_err();
394 })
395 .detach()
396 }
397 drop(results_tx);
398
399 struct ResultBuffer {
400 buffer: Entity<Buffer>,
401 snapshot: TextBufferSnapshot,
402 }
403
404 let (result_buffers_by_path, merged_result) = cx
405 .background_spawn(async move {
406 let mut excerpts_by_buffer: HashMap<Entity<Buffer>, MatchedBuffer> =
407 HashMap::default();
408
409 while let Some((buffer, matched)) = results_rx.next().await {
410 match excerpts_by_buffer.entry(buffer) {
411 Entry::Occupied(mut entry) => {
412 let entry = entry.get_mut();
413 entry.full_path = matched.full_path;
414 entry.snapshot = matched.snapshot;
415 entry.line_ranges.extend(matched.line_ranges);
416 }
417 Entry::Vacant(entry) => {
418 entry.insert(matched);
419 }
420 }
421 }
422
423 let mut result_buffers_by_path = HashMap::default();
424 let mut merged_result = RESULTS_MESSAGE.to_string();
425
426 for (buffer, mut matched) in excerpts_by_buffer {
427 matched
428 .line_ranges
429 .sort_unstable_by_key(|range| (range.start, Reverse(range.end)));
430
431 writeln!(
432 &mut merged_result,
433 "`````filename={}",
434 matched.full_path.display()
435 )
436 .unwrap();
437 write_merged_excerpts(
438 &matched.snapshot,
439 matched.line_ranges,
440 &[],
441 &mut merged_result,
442 );
443 merged_result.push_str("`````\n\n");
444
445 result_buffers_by_path.insert(
446 matched.full_path,
447 ResultBuffer {
448 buffer,
449 snapshot: matched.snapshot.text,
450 },
451 );
452 }
453
454 (result_buffers_by_path, merged_result)
455 })
456 .await;
457
458 if let Some(debug_tx) = &debug_tx {
459 debug_tx
460 .unbounded_send(ZetaDebugInfo::SearchQueriesExecuted(
461 ZetaContextRetrievalDebugInfo {
462 project: project.clone(),
463 timestamp: Instant::now(),
464 },
465 ))
466 .ok();
467 }
468
469 let tool_result = LanguageModelToolResult {
470 tool_use_id: search_tool_use.id.clone(),
471 tool_name: SEARCH_TOOL_NAME.into(),
472 is_error: false,
473 content: merged_result.into(),
474 output: None,
475 };
476
477 select_request_messages.extend([
478 LanguageModelRequestMessage {
479 role: Role::Assistant,
480 content: vec![MessageContent::ToolUse(search_tool_use)],
481 cache: false,
482 },
483 LanguageModelRequestMessage {
484 role: Role::User,
485 content: vec![MessageContent::ToolResult(tool_result)],
486 cache: false,
487 },
488 ]);
489
490 if result_buffers_by_path.is_empty() {
491 log::trace!("context gathering queries produced no results");
492 return anyhow::Ok(HashMap::default());
493 }
494
495 select_request_messages.push(LanguageModelRequestMessage {
496 role: Role::User,
497 content: vec![SELECT_PROMPT.into()],
498 cache: false,
499 });
500
501 let mut select_stream = request_tool_call::<SelectToolInput>(
502 select_request_messages,
503 SELECT_TOOL_NAME,
504 &model,
505 cx,
506 )
507 .await?;
508
509 cx.background_spawn(async move {
510 let mut selected_ranges = Vec::new();
511
512 while let Some(event) = select_stream.next().await {
513 match event? {
514 LanguageModelCompletionEvent::ToolUse(tool_use) => {
515 if !tool_use.is_input_complete {
516 continue;
517 }
518
519 if tool_use.name.as_ref() == SELECT_TOOL_NAME {
520 let call =
521 serde_json::from_value::<SelectToolInput>(tool_use.input.clone())?;
522 selected_ranges.extend(call.ranges);
523 } else {
524 log::warn!(
525 "context gathering model tried to use unknown tool: {}",
526 tool_use.name
527 );
528 }
529 }
530 ev @ LanguageModelCompletionEvent::ToolUseJsonParseError { .. } => {
531 log::error!("{ev:?}");
532 }
533 ev => {
534 log::trace!("context select event: {ev:?}")
535 }
536 }
537 }
538
539 if let Some(debug_tx) = &debug_tx {
540 debug_tx
541 .unbounded_send(ZetaDebugInfo::SearchResultsFiltered(
542 ZetaContextRetrievalDebugInfo {
543 project: project.clone(),
544 timestamp: Instant::now(),
545 },
546 ))
547 .ok();
548 }
549
550 if selected_ranges.is_empty() {
551 log::trace!("context gathering selected no ranges")
552 }
553
554 selected_ranges.sort_unstable_by(|a, b| {
555 a.start_line
556 .cmp(&b.start_line)
557 .then(b.end_line.cmp(&a.end_line))
558 });
559
560 let mut related_excerpts_by_buffer: HashMap<_, Vec<_>> = HashMap::default();
561
562 for selected_range in selected_ranges {
563 if let Some(ResultBuffer { buffer, snapshot }) =
564 result_buffers_by_path.get(&selected_range.path)
565 {
566 let start_point = Point::new(selected_range.start_line.saturating_sub(1), 0);
567 let end_point =
568 snapshot.clip_point(Point::new(selected_range.end_line, 0), Bias::Left);
569 let range =
570 snapshot.anchor_after(start_point)..snapshot.anchor_before(end_point);
571
572 related_excerpts_by_buffer
573 .entry(buffer.clone())
574 .or_default()
575 .push(range);
576 } else {
577 log::warn!(
578 "selected path that wasn't included in search results: {}",
579 selected_range.path.display()
580 );
581 }
582 }
583
584 anyhow::Ok(related_excerpts_by_buffer)
585 })
586 .await
587 })
588}
589
590async fn request_tool_call<T: JsonSchema>(
591 messages: Vec<LanguageModelRequestMessage>,
592 tool_name: &'static str,
593 model: &Arc<dyn LanguageModel>,
594 cx: &mut AsyncApp,
595) -> Result<BoxStream<'static, Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>>
596{
597 let schema = schemars::schema_for!(T);
598
599 let request = LanguageModelRequest {
600 messages,
601 tools: vec![LanguageModelRequestTool {
602 name: tool_name.into(),
603 description: schema
604 .get("description")
605 .and_then(|description| description.as_str())
606 .unwrap()
607 .to_string(),
608 input_schema: serde_json::to_value(schema).unwrap(),
609 }],
610 ..Default::default()
611 };
612
613 Ok(model.stream_completion(request, cx).await?)
614}
615
// Sizing limits for excerpts expanded around search matches.
const MIN_EXCERPT_LEN: usize = 16;
const MAX_EXCERPT_LEN: usize = 768;
/// Budget of excerpt bytes produced for a single search query.
const MAX_RESULT_BYTES_PER_QUERY: usize = MAX_EXCERPT_LEN * 5;

/// Search hits for one buffer: the snapshot the line ranges were computed
/// against, plus the full path used to label the result codeblock.
struct MatchedBuffer {
    snapshot: BufferSnapshot,
    line_ranges: Vec<Range<Line>>,
    full_path: PathBuf,
}
625
626async fn run_query(
627 glob: &str,
628 regex: &str,
629 results_tx: UnboundedSender<(Entity<Buffer>, MatchedBuffer)>,
630 path_style: PathStyle,
631 exclude_matcher: PathMatcher,
632 project: &Entity<Project>,
633 cx: &mut AsyncApp,
634) -> Result<()> {
635 let include_matcher = PathMatcher::new(vec![glob], path_style)?;
636
637 let query = SearchQuery::regex(
638 regex,
639 false,
640 true,
641 false,
642 true,
643 include_matcher,
644 exclude_matcher,
645 true,
646 None,
647 )?;
648
649 let results = project.update(cx, |project, cx| project.search(query, cx))?;
650 futures::pin_mut!(results);
651
652 let mut total_bytes = 0;
653
654 while let Some(SearchResult::Buffer { buffer, ranges }) = results.next().await {
655 if ranges.is_empty() {
656 continue;
657 }
658
659 let Some((snapshot, full_path)) = buffer.read_with(cx, |buffer, cx| {
660 Some((buffer.snapshot(), buffer.file()?.full_path(cx)))
661 })?
662 else {
663 continue;
664 };
665
666 let results_tx = results_tx.clone();
667 cx.background_spawn(async move {
668 let mut line_ranges = Vec::with_capacity(ranges.len());
669
670 for range in ranges {
671 let offset_range = range.to_offset(&snapshot);
672 let query_point = (offset_range.start + offset_range.len() / 2).to_point(&snapshot);
673
674 if total_bytes + MIN_EXCERPT_LEN >= MAX_RESULT_BYTES_PER_QUERY {
675 break;
676 }
677
678 let excerpt = EditPredictionExcerpt::select_from_buffer(
679 query_point,
680 &snapshot,
681 &EditPredictionExcerptOptions {
682 max_bytes: MAX_EXCERPT_LEN.min(MAX_RESULT_BYTES_PER_QUERY - total_bytes),
683 min_bytes: MIN_EXCERPT_LEN,
684 target_before_cursor_over_total_bytes: 0.5,
685 },
686 None,
687 );
688
689 if let Some(excerpt) = excerpt {
690 total_bytes += excerpt.range.len();
691 if !excerpt.line_range.is_empty() {
692 line_ranges.push(excerpt.line_range);
693 }
694 }
695 }
696
697 results_tx
698 .unbounded_send((
699 buffer,
700 MatchedBuffer {
701 snapshot,
702 line_ranges,
703 full_path,
704 },
705 ))
706 .log_err();
707 })
708 .detach();
709 }
710
711 anyhow::Ok(())
712}