1use std::{cmp::Reverse, fmt::Write, ops::Range, path::PathBuf, sync::Arc};
2
3use crate::merge_excerpts::write_merged_excerpts;
4use anyhow::{Result, anyhow};
5use collections::HashMap;
6use edit_prediction_context::{EditPredictionExcerpt, EditPredictionExcerptOptions, Line};
7use futures::{StreamExt, stream::BoxStream};
8use gpui::{App, AsyncApp, Entity, Task};
9use indoc::indoc;
10use language::{Anchor, Bias, Buffer, OffsetRangeExt, Point, TextBufferSnapshot, ToPoint as _};
11use language_model::{
12 LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent, LanguageModelId,
13 LanguageModelRegistry, LanguageModelRequest, LanguageModelRequestMessage,
14 LanguageModelRequestTool, LanguageModelToolResult, MessageContent, Role,
15};
16use project::{
17 Project, WorktreeSettings,
18 search::{SearchQuery, SearchResult},
19};
20use schemars::JsonSchema;
21use serde::Deserialize;
22use util::paths::{PathMatcher, PathStyle};
23use workspace::item::Settings as _;
24
/// Prompt for the first turn of the context-gathering conversation: instructs
/// the model to issue all of its `search` tool calls up front.
///
/// The `{edits}`, `{current_file_path}`, and `{cursor_excerpt}` placeholders
/// are substituted via `str::replace` before the prompt is sent.
const SEARCH_PROMPT: &str = indoc! {r#"
    ## Task

    You are part of an edit prediction system in a code editor. Your role is to identify relevant code locations
    that will serve as context for predicting the next required edit.

    **Your task:**
    - Analyze the user's recent edits and current cursor context
    - Use the `search` tool to find code that may be relevant for predicting the next edit
    - Focus on finding:
    - Code patterns that might need similar changes based on the recent edits
    - Functions, variables, types, and constants referenced in the current cursor context
    - Related implementations, usages, or dependencies that may require consistent updates

    **Important constraints:**
    - This conversation has exactly 2 turns
    - You must make ALL search queries in your first response via the `search` tool
    - All queries will be executed in parallel and results returned together
    - In the second turn, you will select the most relevant results via the `select` tool.

    ## User Edits

    {edits}

    ## Current cursor context

    `````filename={current_file_path}
    {cursor_excerpt}
    `````

    --
    Use the `search` tool now
"#};
58
// Name of the tool the model must call in the first turn.
const SEARCH_TOOL_NAME: &str = "search";

/// Search for relevant code
///
/// For the best results, run multiple queries at once with a single invocation of this tool.
// NOTE: the `///` doc comments on this struct and its fields are emitted into
// the JSON schema sent to the model (via schemars), so they are model-facing
// text — edit them with care.
#[derive(Deserialize, JsonSchema)]
struct SearchToolInput {
    /// An array of queries to run for gathering context relevant to the next prediction
    #[schemars(length(max = 5))]
    queries: Box<[SearchToolQuery]>,
}

// One search query: a path glob plus a content regex, executed together.
// Doc comments below are also model-facing schema descriptions.
#[derive(Deserialize, JsonSchema)]
struct SearchToolQuery {
    /// A glob pattern to match file paths in the codebase
    glob: String,
    /// A regular expression to match content within the files matched by the glob pattern
    regex: String,
    /// Whether the regex is case-sensitive. Defaults to false (case-insensitive).
    #[serde(default)]
    case_sensitive: bool,
}
81
// Preamble prepended to the merged search results sent back to the model.
const RESULTS_MESSAGE: &str = indoc! {"
    Here are the results of your queries combined and grouped by file:

"};

// Name of the tool the model must call in the second turn.
const SELECT_TOOL_NAME: &str = "select";

// Prompt for the second turn: asks the model to narrow the combined search
// results down to the most relevant line ranges.
const SELECT_PROMPT: &str = indoc! {"
    Use the `select` tool now to pick the most relevant line ranges according to the user state provided in the first message.
    Make sure to include enough lines of context so that the edit prediction model can suggest accurate edits.
    Include up to 200 lines in total.
"};

/// Select line ranges from search results
// NOTE: the `///` doc comments on this struct and its fields are emitted into
// the model-facing JSON schema via schemars.
#[derive(Deserialize, JsonSchema)]
struct SelectToolInput {
    /// The line ranges to select from search results.
    ranges: Vec<SelectLineRange>,
}

/// A specific line range to select from a file
#[derive(Debug, Deserialize, JsonSchema)]
struct SelectLineRange {
    /// The file path containing the lines to select
    /// Exactly as it appears in the search result codeblocks.
    path: PathBuf,
    /// The starting line number (1-based)
    #[schemars(range(min = 1))]
    start_line: u32,
    /// The ending line number (1-based, inclusive)
    #[schemars(range(min = 1))]
    end_line: u32,
}

/// Options controlling LLM-driven context gathering.
#[derive(Debug, Clone, PartialEq)]
pub struct LlmContextOptions {
    /// Sizing options for the excerpt around the cursor shown to the model.
    pub excerpt: EditPredictionExcerptOptions,
}
120
121pub fn find_related_excerpts<'a>(
122 buffer: Entity<language::Buffer>,
123 cursor_position: Anchor,
124 project: &Entity<Project>,
125 events: impl Iterator<Item = &'a crate::Event>,
126 options: &LlmContextOptions,
127 cx: &App,
128) -> Task<Result<HashMap<Entity<Buffer>, Vec<Range<Anchor>>>>> {
129 let language_model_registry = LanguageModelRegistry::global(cx);
130 let Some(model) = language_model_registry
131 .read(cx)
132 .available_models(cx)
133 .find(|model| {
134 model.provider_id() == language_model::ANTHROPIC_PROVIDER_ID
135 && model.id() == LanguageModelId("claude-haiku-4-5-latest".into())
136 })
137 else {
138 return Task::ready(Err(anyhow!("could not find claude model")));
139 };
140
141 let mut edits_string = String::new();
142
143 for event in events {
144 if let Some(event) = event.to_request_event(cx) {
145 writeln!(&mut edits_string, "{event}").ok();
146 }
147 }
148
149 if edits_string.is_empty() {
150 edits_string.push_str("(No user edits yet)");
151 }
152
153 // TODO [zeta2] include breadcrumbs?
154 let snapshot = buffer.read(cx).snapshot();
155 let cursor_point = cursor_position.to_point(&snapshot);
156 let Some(cursor_excerpt) =
157 EditPredictionExcerpt::select_from_buffer(cursor_point, &snapshot, &options.excerpt, None)
158 else {
159 return Task::ready(Ok(HashMap::default()));
160 };
161
162 let current_file_path = snapshot
163 .file()
164 .map(|f| f.full_path(cx).display().to_string())
165 .unwrap_or_else(|| "untitled".to_string());
166
167 let prompt = SEARCH_PROMPT
168 .replace("{edits}", &edits_string)
169 .replace("{current_file_path}", ¤t_file_path)
170 .replace("{cursor_excerpt}", &cursor_excerpt.text(&snapshot).body);
171
172 let path_style = project.read(cx).path_style(cx);
173
174 let exclude_matcher = {
175 let global_settings = WorktreeSettings::get_global(cx);
176 let exclude_patterns = global_settings
177 .file_scan_exclusions
178 .sources()
179 .iter()
180 .chain(global_settings.private_files.sources().iter());
181
182 match PathMatcher::new(exclude_patterns, path_style) {
183 Ok(matcher) => matcher,
184 Err(err) => {
185 return Task::ready(Err(anyhow!(err)));
186 }
187 }
188 };
189
190 let project = project.clone();
191 cx.spawn(async move |cx| {
192 let initial_prompt_message = LanguageModelRequestMessage {
193 role: Role::User,
194 content: vec![prompt.into()],
195 cache: false,
196 };
197
198 let mut search_stream = request_tool_call::<SearchToolInput>(
199 vec![initial_prompt_message.clone()],
200 SEARCH_TOOL_NAME,
201 &model,
202 cx,
203 )
204 .await?;
205
206 let mut select_request_messages = Vec::with_capacity(5); // initial prompt, LLM response/thinking, tool use, tool result, select prompt
207 select_request_messages.push(initial_prompt_message);
208 let mut search_calls = Vec::new();
209
210 while let Some(event) = search_stream.next().await {
211 match event? {
212 LanguageModelCompletionEvent::ToolUse(tool_use) => {
213 if !tool_use.is_input_complete {
214 continue;
215 }
216
217 if tool_use.name.as_ref() == SEARCH_TOOL_NAME {
218 search_calls.push((select_request_messages.len(), tool_use));
219 } else {
220 log::warn!(
221 "context gathering model tried to use unknown tool: {}",
222 tool_use.name
223 );
224 }
225 }
226 LanguageModelCompletionEvent::Text(txt) => {
227 if let Some(LanguageModelRequestMessage {
228 role: Role::Assistant,
229 content,
230 ..
231 }) = select_request_messages.last_mut()
232 {
233 if let Some(MessageContent::Text(existing_text)) = content.last_mut() {
234 existing_text.push_str(&txt);
235 } else {
236 content.push(MessageContent::Text(txt));
237 }
238 } else {
239 select_request_messages.push(LanguageModelRequestMessage {
240 role: Role::Assistant,
241 content: vec![MessageContent::Text(txt)],
242 cache: false,
243 });
244 }
245 }
246 LanguageModelCompletionEvent::Thinking { text, signature } => {
247 if let Some(LanguageModelRequestMessage {
248 role: Role::Assistant,
249 content,
250 ..
251 }) = select_request_messages.last_mut()
252 {
253 if let Some(MessageContent::Thinking {
254 text: existing_text,
255 signature: existing_signature,
256 }) = content.last_mut()
257 {
258 existing_text.push_str(&text);
259 *existing_signature = signature;
260 } else {
261 content.push(MessageContent::Thinking { text, signature });
262 }
263 } else {
264 select_request_messages.push(LanguageModelRequestMessage {
265 role: Role::Assistant,
266 content: vec![MessageContent::Thinking { text, signature }],
267 cache: false,
268 });
269 }
270 }
271 LanguageModelCompletionEvent::RedactedThinking { data } => {
272 if let Some(LanguageModelRequestMessage {
273 role: Role::Assistant,
274 content,
275 ..
276 }) = select_request_messages.last_mut()
277 {
278 if let Some(MessageContent::RedactedThinking(existing_data)) =
279 content.last_mut()
280 {
281 existing_data.push_str(&data);
282 } else {
283 content.push(MessageContent::RedactedThinking(data));
284 }
285 } else {
286 select_request_messages.push(LanguageModelRequestMessage {
287 role: Role::Assistant,
288 content: vec![MessageContent::RedactedThinking(data)],
289 cache: false,
290 });
291 }
292 }
293 ev @ LanguageModelCompletionEvent::ToolUseJsonParseError { .. } => {
294 log::error!("{ev:?}");
295 }
296 ev => {
297 log::trace!("context search event: {ev:?}")
298 }
299 }
300 }
301
302 struct ResultBuffer {
303 buffer: Entity<Buffer>,
304 snapshot: TextBufferSnapshot,
305 }
306
307 let mut result_buffers_by_path = HashMap::default();
308
309 for (index, tool_use) in search_calls.into_iter().rev() {
310 let call = serde_json::from_value::<SearchToolInput>(tool_use.input.clone())?;
311
312 let mut excerpts_by_buffer = HashMap::default();
313
314 for query in call.queries {
315 // TODO [zeta2] parallelize?
316
317 run_query(
318 query,
319 &mut excerpts_by_buffer,
320 path_style,
321 exclude_matcher.clone(),
322 &project,
323 cx,
324 )
325 .await?;
326 }
327
328 if excerpts_by_buffer.is_empty() {
329 continue;
330 }
331
332 let mut merged_result = RESULTS_MESSAGE.to_string();
333
334 for (buffer_entity, mut excerpts_for_buffer) in excerpts_by_buffer {
335 excerpts_for_buffer.sort_unstable_by_key(|range| (range.start, Reverse(range.end)));
336
337 buffer_entity
338 .clone()
339 .read_with(cx, |buffer, cx| {
340 let Some(file) = buffer.file() else {
341 return;
342 };
343
344 let path = file.full_path(cx);
345
346 writeln!(&mut merged_result, "`````filename={}", path.display()).unwrap();
347
348 let snapshot = buffer.snapshot();
349
350 write_merged_excerpts(
351 &snapshot,
352 excerpts_for_buffer,
353 &[],
354 &mut merged_result,
355 );
356
357 merged_result.push_str("`````\n\n");
358
359 result_buffers_by_path.insert(
360 path,
361 ResultBuffer {
362 buffer: buffer_entity,
363 snapshot: snapshot.text,
364 },
365 );
366 })
367 .ok();
368 }
369
370 let tool_result = LanguageModelToolResult {
371 tool_use_id: tool_use.id.clone(),
372 tool_name: SEARCH_TOOL_NAME.into(),
373 is_error: false,
374 content: merged_result.into(),
375 output: None,
376 };
377
378 // Almost always appends at the end, but in theory, the model could return some text after the tool call
379 // or perform parallel tool calls, so we splice at the message index for correctness.
380 select_request_messages.splice(
381 index..index,
382 [
383 LanguageModelRequestMessage {
384 role: Role::Assistant,
385 content: vec![MessageContent::ToolUse(tool_use)],
386 cache: false,
387 },
388 LanguageModelRequestMessage {
389 role: Role::User,
390 content: vec![MessageContent::ToolResult(tool_result)],
391 cache: false,
392 },
393 ],
394 );
395 }
396
397 if result_buffers_by_path.is_empty() {
398 log::trace!("context gathering queries produced no results");
399 return anyhow::Ok(HashMap::default());
400 }
401
402 select_request_messages.push(LanguageModelRequestMessage {
403 role: Role::User,
404 content: vec![SELECT_PROMPT.into()],
405 cache: false,
406 });
407
408 let mut select_stream = request_tool_call::<SelectToolInput>(
409 select_request_messages,
410 SELECT_TOOL_NAME,
411 &model,
412 cx,
413 )
414 .await?;
415 let mut selected_ranges = Vec::new();
416
417 while let Some(event) = select_stream.next().await {
418 match event? {
419 LanguageModelCompletionEvent::ToolUse(tool_use) => {
420 if !tool_use.is_input_complete {
421 continue;
422 }
423
424 if tool_use.name.as_ref() == SELECT_TOOL_NAME {
425 let call =
426 serde_json::from_value::<SelectToolInput>(tool_use.input.clone())?;
427 selected_ranges.extend(call.ranges);
428 } else {
429 log::warn!(
430 "context gathering model tried to use unknown tool: {}",
431 tool_use.name
432 );
433 }
434 }
435 ev @ LanguageModelCompletionEvent::ToolUseJsonParseError { .. } => {
436 log::error!("{ev:?}");
437 }
438 ev => {
439 log::trace!("context select event: {ev:?}")
440 }
441 }
442 }
443
444 if selected_ranges.is_empty() {
445 log::trace!("context gathering selected no ranges")
446 }
447
448 let mut related_excerpts_by_buffer: HashMap<_, Vec<_>> = HashMap::default();
449
450 for selected_range in selected_ranges {
451 if let Some(ResultBuffer { buffer, snapshot }) =
452 result_buffers_by_path.get(&selected_range.path)
453 {
454 let start_point = Point::new(selected_range.start_line.saturating_sub(1), 0);
455 let end_point =
456 snapshot.clip_point(Point::new(selected_range.end_line, 0), Bias::Left);
457 let range = snapshot.anchor_after(start_point)..snapshot.anchor_before(end_point);
458
459 related_excerpts_by_buffer
460 .entry(buffer.clone())
461 .or_default()
462 .push(range);
463 } else {
464 log::warn!(
465 "selected path that wasn't included in search results: {}",
466 selected_range.path.display()
467 );
468 }
469 }
470
471 for (buffer, ranges) in &mut related_excerpts_by_buffer {
472 buffer.read_with(cx, |buffer, _cx| {
473 ranges.sort_unstable_by(|a, b| {
474 a.start
475 .cmp(&b.start, buffer)
476 .then(b.end.cmp(&a.end, buffer))
477 });
478 })?;
479 }
480
481 anyhow::Ok(related_excerpts_by_buffer)
482 })
483}
484
485async fn request_tool_call<T: JsonSchema>(
486 messages: Vec<LanguageModelRequestMessage>,
487 tool_name: &'static str,
488 model: &Arc<dyn LanguageModel>,
489 cx: &mut AsyncApp,
490) -> Result<BoxStream<'static, Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>>
491{
492 let schema = schemars::schema_for!(T);
493
494 let request = LanguageModelRequest {
495 messages,
496 tools: vec![LanguageModelRequestTool {
497 name: tool_name.into(),
498 description: schema
499 .get("description")
500 .and_then(|description| description.as_str())
501 .unwrap()
502 .to_string(),
503 input_schema: serde_json::to_value(schema).unwrap(),
504 }],
505 ..Default::default()
506 };
507
508 Ok(model.stream_completion(request, cx).await?)
509}
510
// Minimum size of a single search-result excerpt, in bytes.
const MIN_EXCERPT_LEN: usize = 16;
// Maximum size of a single search-result excerpt, in bytes.
const MAX_EXCERPT_LEN: usize = 768;
// Total byte budget for all excerpts produced by one search query.
const MAX_RESULT_BYTES_PER_QUERY: usize = MAX_EXCERPT_LEN * 5;
514
/// Runs a single search query (glob + regex) against the project and appends
/// the resulting excerpt line ranges to `excerpts_by_buffer`, keyed by the
/// buffer each match was found in.
///
/// Each match is expanded into an excerpt around its midpoint via
/// `EditPredictionExcerpt::select_from_buffer`, and the combined excerpt size
/// for this query is capped at `MAX_RESULT_BYTES_PER_QUERY` bytes.
///
/// # Errors
/// Fails if the glob or regex is invalid, or if the project/buffer state
/// cannot be read.
async fn run_query(
    args: SearchToolQuery,
    excerpts_by_buffer: &mut HashMap<Entity<Buffer>, Vec<Range<Line>>>,
    path_style: PathStyle,
    exclude_matcher: PathMatcher,
    project: &Entity<Project>,
    cx: &mut AsyncApp,
) -> Result<()> {
    let include_matcher = PathMatcher::new(vec![args.glob], path_style)?;

    let query = SearchQuery::regex(
        &args.regex,
        false,
        args.case_sensitive,
        false,
        true,
        include_matcher,
        exclude_matcher,
        true,
        None,
    )?;

    let results = project.update(cx, |project, cx| project.search(query, cx))?;
    futures::pin_mut!(results);

    // Running total of excerpt bytes gathered for this query.
    let mut total_bytes = 0;

    while let Some(SearchResult::Buffer { buffer, ranges }) = results.next().await {
        if ranges.is_empty() {
            continue;
        }

        let excerpts_for_buffer = excerpts_by_buffer
            .entry(buffer.clone())
            .or_insert_with(|| Vec::with_capacity(ranges.len()));

        let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot())?;

        for range in ranges {
            let offset_range = range.to_offset(&snapshot);
            // Center the excerpt on the midpoint of the match.
            let query_point = (offset_range.start + offset_range.len() / 2).to_point(&snapshot);

            // Stop expanding matches in this buffer once the remaining budget
            // can't fit even a minimum-sized excerpt.
            if total_bytes + MIN_EXCERPT_LEN >= MAX_RESULT_BYTES_PER_QUERY {
                break;
            }

            let excerpt = EditPredictionExcerpt::select_from_buffer(
                query_point,
                &snapshot,
                &EditPredictionExcerptOptions {
                    // Shrink the per-excerpt cap to whatever budget remains.
                    max_bytes: MAX_EXCERPT_LEN.min(MAX_RESULT_BYTES_PER_QUERY - total_bytes),
                    min_bytes: MIN_EXCERPT_LEN,
                    target_before_cursor_over_total_bytes: 0.5,
                },
                None,
            );

            if let Some(excerpt) = excerpt {
                total_bytes += excerpt.range.len();
                if !excerpt.line_range.is_empty() {
                    excerpts_for_buffer.push(excerpt.line_range);
                }
            }
        }

        // Don't leave an empty entry behind if no excerpt survived the
        // budget/emptiness checks for this buffer.
        if excerpts_for_buffer.is_empty() {
            excerpts_by_buffer.remove(&buffer);
        }
    }

    anyhow::Ok(())
}