1use std::path::Path;
2
3use crate::{source_location::SourceLocation, training::teacher::TeacherModel};
4
/// Selects which source text [`collect_context`] gathers for the prompt.
///
/// Derives `clap::ValueEnum`, so it can be supplied as a command-line value.
#[derive(Debug, Clone, Default, clap::ValueEnum)]
pub enum ContextType {
    // The whole file that contains the cursor; this is the default variant.
    #[default]
    CurrentFile,
}
10
/// Maximum size, in bytes, of the context text that [`collect_context`] returns verbatim.
///
/// Longer contexts are cut down to at most this many bytes of the original text, plus
/// `[...N bytes truncated]` marker text.
const MAX_CONTEXT_SIZE: usize = 32768;
12
13pub fn collect_context(
14 context_type: &ContextType,
15 worktree_dir: &Path,
16 cursor: SourceLocation,
17) -> String {
18 let context = match context_type {
19 ContextType::CurrentFile => {
20 let file_path = worktree_dir.join(cursor.path.as_std_path());
21 let context = std::fs::read_to_string(&file_path).unwrap_or_default();
22
23 let context = add_special_tags(&context, worktree_dir, cursor);
24 context
25 }
26 };
27
28 let region_end_offset = context.find(TeacherModel::REGION_END);
29
30 if context.len() <= MAX_CONTEXT_SIZE {
31 return context;
32 }
33
34 if let Some(region_end_offset) = region_end_offset
35 && region_end_offset + TeacherModel::REGION_END.len() > MAX_CONTEXT_SIZE
36 {
37 let to_truncate = context.len() - MAX_CONTEXT_SIZE;
38 format!(
39 "[...{} bytes truncated]\n{}\n",
40 to_truncate,
41 &context[to_truncate..]
42 )
43 } else {
44 format!(
45 "{}\n[...{} bytes truncated]\n",
46 &context[..MAX_CONTEXT_SIZE],
47 context.len() - MAX_CONTEXT_SIZE
48 )
49 }
50}
51
52/// Add <|editable_region_start/end|> tags
53fn add_special_tags(context: &str, worktree_dir: &Path, cursor: SourceLocation) -> String {
54 let path = worktree_dir.join(cursor.path.as_std_path());
55 let file = std::fs::read_to_string(&path).unwrap_or_default();
56 let lines = file.lines().collect::<Vec<_>>();
57 let cursor_row = cursor.point.row as usize;
58 let start_line = cursor_row.saturating_sub(TeacherModel::LEFT_CONTEXT_SIZE);
59 let end_line = (cursor_row + TeacherModel::RIGHT_CONTEXT_SIZE).min(lines.len());
60
61 let snippet = lines[start_line..end_line].join("\n");
62
63 if context.contains(&snippet) {
64 let mut cursor_line = lines[cursor_row].to_string();
65 cursor_line.insert_str(cursor.point.column as usize, TeacherModel::USER_CURSOR);
66
67 let mut snippet_with_tags_lines = vec![];
68 snippet_with_tags_lines.push(TeacherModel::REGION_START);
69 snippet_with_tags_lines.extend(&lines[start_line..cursor_row]);
70 snippet_with_tags_lines.push(&cursor_line);
71 snippet_with_tags_lines.extend(&lines[cursor_row + 1..end_line]);
72 snippet_with_tags_lines.push(TeacherModel::REGION_END);
73 let snippet_with_tags = snippet_with_tags_lines.join("\n");
74
75 context.replace(&snippet, &snippet_with_tags)
76 } else {
77 log::warn!(
78 "Can't find area around the cursor in the context; proceeding without special tags"
79 );
80 context.to_string()
81 }
82}
83
84pub fn strip_special_tags(context: &str) -> String {
85 context
86 .replace(TeacherModel::REGION_START, "")
87 .replace(TeacherModel::REGION_END, "")
88 .replace(TeacherModel::USER_CURSOR, "")
89}