1use std::{cmp::Reverse, fmt::Write, ops::Range, path::PathBuf, sync::Arc, time::Instant};
2
3use crate::{
4 ZetaContextRetrievalDebugInfo, ZetaDebugInfo, ZetaSearchQueryDebugInfo,
5 merge_excerpts::write_merged_excerpts,
6};
7use anyhow::{Result, anyhow};
8use collections::HashMap;
9use edit_prediction_context::{EditPredictionExcerpt, EditPredictionExcerptOptions, Line};
10use futures::{StreamExt, channel::mpsc, stream::BoxStream};
11use gpui::{App, AsyncApp, Entity, Task};
12use indoc::indoc;
13use language::{Anchor, Bias, Buffer, OffsetRangeExt, Point, TextBufferSnapshot, ToPoint as _};
14use language_model::{
15 LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent, LanguageModelId,
16 LanguageModelRegistry, LanguageModelRequest, LanguageModelRequestMessage,
17 LanguageModelRequestTool, LanguageModelToolResult, MessageContent, Role,
18};
19use project::{
20 Project, WorktreeSettings,
21 search::{SearchQuery, SearchResult},
22};
23use schemars::JsonSchema;
24use serde::Deserialize;
25use util::paths::{PathMatcher, PathStyle};
26use workspace::item::Settings as _;
27
// First-turn prompt for the context-gathering model. The `{edits}`,
// `{current_file_path}`, and `{cursor_excerpt}` placeholders are substituted via
// `str::replace` in `find_related_excerpts`; the literal text is sent to the
// model verbatim, so treat it as part of the runtime behavior.
const SEARCH_PROMPT: &str = indoc! {r#"
    ## Task

    You are part of an edit prediction system in a code editor. Your role is to identify relevant code locations
    that will serve as context for predicting the next required edit.

    **Your task:**
    - Analyze the user's recent edits and current cursor context
    - Use the `search` tool to find code that may be relevant for predicting the next edit
    - Focus on finding:
      - Code patterns that might need similar changes based on the recent edits
      - Functions, variables, types, and constants referenced in the current cursor context
      - Related implementations, usages, or dependencies that may require consistent updates

    **Important constraints:**
    - This conversation has exactly 2 turns
    - You must make ALL search queries in your first response via the `search` tool
    - All queries will be executed in parallel and results returned together
    - In the second turn, you will select the most relevant results via the `select` tool.

    ## User Edits

    {edits}

    ## Current cursor context

    `````filename={current_file_path}
    {cursor_excerpt}
    `````

    --
    Use the `search` tool now
"#};
61
// Name of the tool the model must call in the first turn; matched against
// `tool_use.name` when processing the completion stream.
const SEARCH_TOOL_NAME: &str = "search";
63
/// Search for relevant code
///
/// For the best results, run multiple queries at once with a single invocation of this tool.
// NOTE: the `///` doc comments on this type and its fields are extracted by
// schemars into the JSON schema `description` fields sent to the model (see
// `request_tool_call`), so they are runtime data — edit with care.
#[derive(Clone, Deserialize, JsonSchema)]
pub struct SearchToolInput {
    /// An array of queries to run for gathering context relevant to the next prediction
    #[schemars(length(max = 5))]
    pub queries: Box<[SearchToolQuery]>,
}
73
// One (glob, regex) search request produced by the model. The `///` docs below
// are surfaced to the model through the generated JSON schema.
#[derive(Debug, Clone, Deserialize, JsonSchema)]
pub struct SearchToolQuery {
    /// A glob pattern to match file paths in the codebase
    pub glob: String,
    /// A regular expression to match content within the files matched by the glob pattern
    pub regex: String,
    /// Whether the regex is case-sensitive. Defaults to false (case-insensitive).
    #[serde(default)]
    pub case_sensitive: bool,
}
84
// Header prepended to the merged search results that are sent back to the model
// as the `search` tool's result.
const RESULTS_MESSAGE: &str = indoc! {"
    Here are the results of your queries combined and grouped by file:

"};

// Name of the tool the model must call in the second turn.
const SELECT_TOOL_NAME: &str = "select";

// Second-turn prompt instructing the model to narrow the search results down to
// the line ranges that should become edit-prediction context.
const SELECT_PROMPT: &str = indoc! {"
    Use the `select` tool now to pick the most relevant line ranges according to the user state provided in the first message.
    Make sure to include enough lines of context so that the edit prediction model can suggest accurate edits.
    Include up to 200 lines in total.
"};
97
/// Select line ranges from search results
// NOTE: the `///` docs double as the schemars-generated tool description sent to
// the model — runtime data, not just documentation.
#[derive(Deserialize, JsonSchema)]
struct SelectToolInput {
    /// The line ranges to select from search results.
    ranges: Vec<SelectLineRange>,
}
104
/// A specific line range to select from a file
// Line numbers are 1-based as presented to the model; they are converted to
// 0-based points (and clamped) when resolving against buffer snapshots.
#[derive(Debug, Deserialize, JsonSchema)]
struct SelectLineRange {
    /// The file path containing the lines to select
    /// Exactly as it appears in the search result codeblocks.
    path: PathBuf,
    /// The starting line number (1-based)
    #[schemars(range(min = 1))]
    start_line: u32,
    /// The ending line number (1-based, inclusive)
    #[schemars(range(min = 1))]
    end_line: u32,
}
118
/// Options controlling LLM-driven context retrieval.
#[derive(Debug, Clone, PartialEq)]
pub struct LlmContextOptions {
    /// How to size and position the cursor excerpt included in the search prompt.
    pub excerpt: EditPredictionExcerptOptions,
}
123
124pub fn find_related_excerpts<'a>(
125 buffer: Entity<language::Buffer>,
126 cursor_position: Anchor,
127 project: &Entity<Project>,
128 events: impl Iterator<Item = &'a crate::Event>,
129 options: &LlmContextOptions,
130 debug_tx: Option<mpsc::UnboundedSender<ZetaDebugInfo>>,
131 cx: &App,
132) -> Task<Result<HashMap<Entity<Buffer>, Vec<Range<Anchor>>>>> {
133 let language_model_registry = LanguageModelRegistry::global(cx);
134 let Some(model) = language_model_registry
135 .read(cx)
136 .available_models(cx)
137 .find(|model| {
138 model.provider_id() == language_model::OLLAMA_PROVIDER_ID
139 && model.id() == LanguageModelId("qwen3:8b".into())
140 })
141 else {
142 return Task::ready(Err(anyhow!("could not find claude model")));
143 };
144
145 let mut edits_string = String::new();
146
147 for event in events {
148 if let Some(event) = event.to_request_event(cx) {
149 writeln!(&mut edits_string, "{event}").ok();
150 }
151 }
152
153 if edits_string.is_empty() {
154 edits_string.push_str("(No user edits yet)");
155 }
156
157 // TODO [zeta2] include breadcrumbs?
158 let snapshot = buffer.read(cx).snapshot();
159 let cursor_point = cursor_position.to_point(&snapshot);
160 let Some(cursor_excerpt) =
161 EditPredictionExcerpt::select_from_buffer(cursor_point, &snapshot, &options.excerpt, None)
162 else {
163 return Task::ready(Ok(HashMap::default()));
164 };
165
166 let current_file_path = snapshot
167 .file()
168 .map(|f| f.full_path(cx).display().to_string())
169 .unwrap_or_else(|| "untitled".to_string());
170
171 let prompt = SEARCH_PROMPT
172 .replace("{edits}", &edits_string)
173 .replace("{current_file_path}", ¤t_file_path)
174 .replace("{cursor_excerpt}", &cursor_excerpt.text(&snapshot).body);
175
176 let path_style = project.read(cx).path_style(cx);
177
178 let exclude_matcher = {
179 let global_settings = WorktreeSettings::get_global(cx);
180 let exclude_patterns = global_settings
181 .file_scan_exclusions
182 .sources()
183 .iter()
184 .chain(global_settings.private_files.sources().iter());
185
186 match PathMatcher::new(exclude_patterns, path_style) {
187 Ok(matcher) => matcher,
188 Err(err) => {
189 return Task::ready(Err(anyhow!(err)));
190 }
191 }
192 };
193
194 let project = project.clone();
195 cx.spawn(async move |cx| {
196 let initial_prompt_message = LanguageModelRequestMessage {
197 role: Role::User,
198 content: vec![prompt.into()],
199 cache: false,
200 };
201
202 let mut search_stream = request_tool_call::<SearchToolInput>(
203 vec![initial_prompt_message.clone()],
204 SEARCH_TOOL_NAME,
205 &model,
206 cx,
207 )
208 .await?;
209
210 let mut select_request_messages = Vec::with_capacity(5); // initial prompt, LLM response/thinking, tool use, tool result, select prompt
211 select_request_messages.push(initial_prompt_message);
212 let mut search_calls = Vec::new();
213
214 while let Some(event) = search_stream.next().await {
215 match event? {
216 LanguageModelCompletionEvent::ToolUse(tool_use) => {
217 if !tool_use.is_input_complete {
218 continue;
219 }
220
221 if tool_use.name.as_ref() == SEARCH_TOOL_NAME {
222 search_calls.push((select_request_messages.len(), tool_use));
223 } else {
224 log::warn!(
225 "context gathering model tried to use unknown tool: {}",
226 tool_use.name
227 );
228 }
229 }
230 LanguageModelCompletionEvent::Text(txt) => {
231 if let Some(LanguageModelRequestMessage {
232 role: Role::Assistant,
233 content,
234 ..
235 }) = select_request_messages.last_mut()
236 {
237 if let Some(MessageContent::Text(existing_text)) = content.last_mut() {
238 existing_text.push_str(&txt);
239 } else {
240 content.push(MessageContent::Text(txt));
241 }
242 } else {
243 select_request_messages.push(LanguageModelRequestMessage {
244 role: Role::Assistant,
245 content: vec![MessageContent::Text(txt)],
246 cache: false,
247 });
248 }
249 }
250 LanguageModelCompletionEvent::Thinking { text, signature } => {
251 if let Some(LanguageModelRequestMessage {
252 role: Role::Assistant,
253 content,
254 ..
255 }) = select_request_messages.last_mut()
256 {
257 if let Some(MessageContent::Thinking {
258 text: existing_text,
259 signature: existing_signature,
260 }) = content.last_mut()
261 {
262 existing_text.push_str(&text);
263 *existing_signature = signature;
264 } else {
265 content.push(MessageContent::Thinking { text, signature });
266 }
267 } else {
268 select_request_messages.push(LanguageModelRequestMessage {
269 role: Role::Assistant,
270 content: vec![MessageContent::Thinking { text, signature }],
271 cache: false,
272 });
273 }
274 }
275 LanguageModelCompletionEvent::RedactedThinking { data } => {
276 if let Some(LanguageModelRequestMessage {
277 role: Role::Assistant,
278 content,
279 ..
280 }) = select_request_messages.last_mut()
281 {
282 if let Some(MessageContent::RedactedThinking(existing_data)) =
283 content.last_mut()
284 {
285 existing_data.push_str(&data);
286 } else {
287 content.push(MessageContent::RedactedThinking(data));
288 }
289 } else {
290 select_request_messages.push(LanguageModelRequestMessage {
291 role: Role::Assistant,
292 content: vec![MessageContent::RedactedThinking(data)],
293 cache: false,
294 });
295 }
296 }
297 ev @ LanguageModelCompletionEvent::ToolUseJsonParseError { .. } => {
298 log::error!("{ev:?}");
299 }
300 ev => {
301 log::trace!("context search event: {ev:?}")
302 }
303 }
304 }
305
306 struct ResultBuffer {
307 buffer: Entity<Buffer>,
308 snapshot: TextBufferSnapshot,
309 }
310
311 let search_queries = search_calls
312 .iter()
313 .map(|(_, tool_use)| {
314 Ok(serde_json::from_value::<SearchToolInput>(
315 tool_use.input.clone(),
316 )?)
317 })
318 .collect::<Result<Vec<_>>>()?;
319
320 if let Some(debug_tx) = &debug_tx {
321 debug_tx
322 .unbounded_send(ZetaDebugInfo::SearchQueriesGenerated(
323 ZetaSearchQueryDebugInfo {
324 project: project.clone(),
325 timestamp: Instant::now(),
326 queries: search_queries
327 .iter()
328 .flat_map(|call| call.queries.iter().cloned())
329 .collect(),
330 },
331 ))
332 .ok();
333 }
334
335 let mut result_buffers_by_path = HashMap::default();
336
337 for ((index, tool_use), call) in search_calls.into_iter().zip(search_queries).rev() {
338 let mut excerpts_by_buffer = HashMap::default();
339
340 for query in call.queries {
341 // TODO [zeta2] parallelize?
342
343 run_query(
344 query,
345 &mut excerpts_by_buffer,
346 path_style,
347 exclude_matcher.clone(),
348 &project,
349 cx,
350 )
351 .await?;
352 }
353
354 if excerpts_by_buffer.is_empty() {
355 continue;
356 }
357
358 let mut merged_result = RESULTS_MESSAGE.to_string();
359
360 for (buffer_entity, mut excerpts_for_buffer) in excerpts_by_buffer {
361 excerpts_for_buffer.sort_unstable_by_key(|range| (range.start, Reverse(range.end)));
362
363 buffer_entity
364 .clone()
365 .read_with(cx, |buffer, cx| {
366 let Some(file) = buffer.file() else {
367 return;
368 };
369
370 let path = file.full_path(cx);
371
372 writeln!(&mut merged_result, "`````filename={}", path.display()).unwrap();
373
374 let snapshot = buffer.snapshot();
375
376 write_merged_excerpts(
377 &snapshot,
378 excerpts_for_buffer,
379 &[],
380 &mut merged_result,
381 );
382
383 merged_result.push_str("`````\n\n");
384
385 result_buffers_by_path.insert(
386 path,
387 ResultBuffer {
388 buffer: buffer_entity,
389 snapshot: snapshot.text,
390 },
391 );
392 })
393 .ok();
394 }
395
396 let tool_result = LanguageModelToolResult {
397 tool_use_id: tool_use.id.clone(),
398 tool_name: SEARCH_TOOL_NAME.into(),
399 is_error: false,
400 content: merged_result.into(),
401 output: None,
402 };
403
404 // Almost always appends at the end, but in theory, the model could return some text after the tool call
405 // or perform parallel tool calls, so we splice at the message index for correctness.
406 select_request_messages.splice(
407 index..index,
408 [
409 LanguageModelRequestMessage {
410 role: Role::Assistant,
411 content: vec![MessageContent::ToolUse(tool_use)],
412 cache: false,
413 },
414 LanguageModelRequestMessage {
415 role: Role::User,
416 content: vec![MessageContent::ToolResult(tool_result)],
417 cache: false,
418 },
419 ],
420 );
421
422 if let Some(debug_tx) = &debug_tx {
423 debug_tx
424 .unbounded_send(ZetaDebugInfo::SearchQueriesExecuted(
425 ZetaContextRetrievalDebugInfo {
426 project: project.clone(),
427 timestamp: Instant::now(),
428 },
429 ))
430 .ok();
431 }
432 }
433
434 if result_buffers_by_path.is_empty() {
435 log::trace!("context gathering queries produced no results");
436 return anyhow::Ok(HashMap::default());
437 }
438
439 select_request_messages.push(LanguageModelRequestMessage {
440 role: Role::User,
441 content: vec![SELECT_PROMPT.into()],
442 cache: false,
443 });
444
445 let mut select_stream = request_tool_call::<SelectToolInput>(
446 select_request_messages,
447 SELECT_TOOL_NAME,
448 &model,
449 cx,
450 )
451 .await?;
452 let mut selected_ranges = Vec::new();
453
454 while let Some(event) = select_stream.next().await {
455 match event? {
456 LanguageModelCompletionEvent::ToolUse(tool_use) => {
457 if !tool_use.is_input_complete {
458 continue;
459 }
460
461 if tool_use.name.as_ref() == SELECT_TOOL_NAME {
462 let call =
463 serde_json::from_value::<SelectToolInput>(tool_use.input.clone())?;
464 selected_ranges.extend(call.ranges);
465 } else {
466 log::warn!(
467 "context gathering model tried to use unknown tool: {}",
468 tool_use.name
469 );
470 }
471 }
472 ev @ LanguageModelCompletionEvent::ToolUseJsonParseError { .. } => {
473 log::error!("{ev:?}");
474 }
475 ev => {
476 log::trace!("context select event: {ev:?}")
477 }
478 }
479 }
480
481 if selected_ranges.is_empty() {
482 log::trace!("context gathering selected no ranges")
483 }
484
485 let mut related_excerpts_by_buffer: HashMap<_, Vec<_>> = HashMap::default();
486
487 for selected_range in selected_ranges {
488 if let Some(ResultBuffer { buffer, snapshot }) =
489 result_buffers_by_path.get(&selected_range.path)
490 {
491 let start_point = Point::new(selected_range.start_line.saturating_sub(1), 0);
492 let end_point =
493 snapshot.clip_point(Point::new(selected_range.end_line, 0), Bias::Left);
494 let range = snapshot.anchor_after(start_point)..snapshot.anchor_before(end_point);
495
496 related_excerpts_by_buffer
497 .entry(buffer.clone())
498 .or_default()
499 .push(range);
500 } else {
501 log::warn!(
502 "selected path that wasn't included in search results: {}",
503 selected_range.path.display()
504 );
505 }
506 }
507
508 for (buffer, ranges) in &mut related_excerpts_by_buffer {
509 buffer.read_with(cx, |buffer, _cx| {
510 ranges.sort_unstable_by(|a, b| {
511 a.start
512 .cmp(&b.start, buffer)
513 .then(b.end.cmp(&a.end, buffer))
514 });
515 })?;
516 }
517
518 anyhow::Ok(related_excerpts_by_buffer)
519 })
520}
521
522async fn request_tool_call<T: JsonSchema>(
523 messages: Vec<LanguageModelRequestMessage>,
524 tool_name: &'static str,
525 model: &Arc<dyn LanguageModel>,
526 cx: &mut AsyncApp,
527) -> Result<BoxStream<'static, Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>>
528{
529 let schema = schemars::schema_for!(T);
530
531 let request = LanguageModelRequest {
532 messages,
533 tools: vec![LanguageModelRequestTool {
534 name: tool_name.into(),
535 description: schema
536 .get("description")
537 .and_then(|description| description.as_str())
538 .unwrap()
539 .to_string(),
540 input_schema: serde_json::to_value(schema).unwrap(),
541 }],
542 thinking_allowed: false,
543 ..Default::default()
544 };
545
546 Ok(model.stream_completion(request, cx).await?)
547}
548
// Minimum size (in bytes) of a single excerpt extracted around a search match.
const MIN_EXCERPT_LEN: usize = 16;
// Maximum size (in bytes) of a single excerpt extracted around a search match.
const MAX_EXCERPT_LEN: usize = 768;
// Total byte budget for all excerpts produced by one search query.
const MAX_RESULT_BYTES_PER_QUERY: usize = MAX_EXCERPT_LEN * 5;
552
/// Executes one model-issued search query against the project, appending the
/// resulting excerpt line ranges to `excerpts_by_buffer`.
///
/// Excerpts are centered on each match and capped by a per-query byte budget
/// (`MAX_RESULT_BYTES_PER_QUERY`). Buffers that end up contributing no
/// excerpts are removed from the map before returning.
async fn run_query(
    args: SearchToolQuery,
    excerpts_by_buffer: &mut HashMap<Entity<Buffer>, Vec<Range<Line>>>,
    path_style: PathStyle,
    exclude_matcher: PathMatcher,
    project: &Entity<Project>,
    cx: &mut AsyncApp,
) -> Result<()> {
    let include_matcher = PathMatcher::new(vec![args.glob], path_style)?;

    // NOTE(review): the positional boolean flags map to `SearchQuery::regex`'s
    // parameters — confirm their meaning against that signature before editing.
    let query = SearchQuery::regex(
        &args.regex,
        false,
        args.case_sensitive,
        false,
        true,
        include_matcher,
        exclude_matcher,
        true,
        None,
    )?;

    let results = project.update(cx, |project, cx| project.search(query, cx))?;
    futures::pin_mut!(results);

    // Running total of excerpt bytes produced by this query, across all buffers.
    let mut total_bytes = 0;

    while let Some(SearchResult::Buffer { buffer, ranges }) = results.next().await {
        if ranges.is_empty() {
            continue;
        }

        let excerpts_for_buffer = excerpts_by_buffer
            .entry(buffer.clone())
            .or_insert_with(|| Vec::with_capacity(ranges.len()));

        let snapshot = buffer.read_with(cx, |buffer, _cx| buffer.snapshot())?;

        for range in ranges {
            let offset_range = range.to_offset(&snapshot);
            // Center the excerpt on the midpoint of the match.
            let query_point = (offset_range.start + offset_range.len() / 2).to_point(&snapshot);

            // Stop once the remaining budget can't fit even a minimal excerpt.
            // NOTE(review): this only breaks the per-buffer loop; later buffers
            // from the stream are still visited (and immediately hit this same
            // check), with their empty entries cleaned up below.
            if total_bytes + MIN_EXCERPT_LEN >= MAX_RESULT_BYTES_PER_QUERY {
                break;
            }

            let excerpt = EditPredictionExcerpt::select_from_buffer(
                query_point,
                &snapshot,
                &EditPredictionExcerptOptions {
                    // Cap each excerpt at both the per-excerpt max and whatever
                    // budget remains for this query.
                    max_bytes: MAX_EXCERPT_LEN.min(MAX_RESULT_BYTES_PER_QUERY - total_bytes),
                    min_bytes: MIN_EXCERPT_LEN,
                    target_before_cursor_over_total_bytes: 0.5,
                },
                None,
            );

            if let Some(excerpt) = excerpt {
                total_bytes += excerpt.range.len();
                if !excerpt.line_range.is_empty() {
                    excerpts_for_buffer.push(excerpt.line_range);
                }
            }
        }

        // Don't leave an empty entry behind if no excerpt survived the budget.
        if excerpts_for_buffer.is_empty() {
            excerpts_by_buffer.remove(&buffer);
        }
    }

    anyhow::Ok(())
}