context.rs

 1use std::path::Path;
 2
 3use crate::{source_location::SourceLocation, training::teacher::TeacherModel};
 4
/// Selects which source `collect_context` gathers its text from.
#[derive(Debug, Clone, Default, clap::ValueEnum)]
pub enum ContextType {
    // Plain `//` comment on purpose: `clap::ValueEnum` turns variant doc comments into
    // CLI help text, so a `///` here would change the program's output.
    // Use the contents of the file the cursor is located in (the default).
    #[default]
    CurrentFile,
}
10
/// Largest context, in bytes, that `collect_context` returns without truncating (32 KiB).
const MAX_CONTEXT_SIZE: usize = 32 * 1024;
12
13pub fn collect_context(
14    context_type: &ContextType,
15    worktree_dir: &Path,
16    cursor: SourceLocation,
17) -> String {
18    let context = match context_type {
19        ContextType::CurrentFile => {
20            let file_path = worktree_dir.join(cursor.path.as_std_path());
21            let context = std::fs::read_to_string(&file_path).unwrap_or_default();
22
23            let context = add_special_tags(&context, worktree_dir, cursor);
24            context
25        }
26    };
27
28    let region_end_offset = context.find(TeacherModel::REGION_END);
29
30    if context.len() <= MAX_CONTEXT_SIZE {
31        return context;
32    }
33
34    if let Some(region_end_offset) = region_end_offset
35        && region_end_offset + TeacherModel::REGION_END.len() > MAX_CONTEXT_SIZE
36    {
37        let to_truncate = context.len() - MAX_CONTEXT_SIZE;
38        format!(
39            "[...{} bytes truncated]\n{}\n",
40            to_truncate,
41            &context[to_truncate..]
42        )
43    } else {
44        format!(
45            "{}\n[...{} bytes truncated]\n",
46            &context[..MAX_CONTEXT_SIZE],
47            context.len() - MAX_CONTEXT_SIZE
48        )
49    }
50}
51
52/// Add <|editable_region_start/end|> tags
53fn add_special_tags(context: &str, worktree_dir: &Path, cursor: SourceLocation) -> String {
54    let path = worktree_dir.join(cursor.path.as_std_path());
55    let file = std::fs::read_to_string(&path).unwrap_or_default();
56    let lines = file.lines().collect::<Vec<_>>();
57    let cursor_row = cursor.point.row as usize;
58    let start_line = cursor_row.saturating_sub(TeacherModel::LEFT_CONTEXT_SIZE);
59    let end_line = (cursor_row + TeacherModel::RIGHT_CONTEXT_SIZE).min(lines.len());
60
61    let snippet = lines[start_line..end_line].join("\n");
62
63    if context.contains(&snippet) {
64        let mut cursor_line = lines[cursor_row].to_string();
65        cursor_line.insert_str(cursor.point.column as usize, TeacherModel::USER_CURSOR);
66
67        let mut snippet_with_tags_lines = vec![];
68        snippet_with_tags_lines.push(TeacherModel::REGION_START);
69        snippet_with_tags_lines.extend(&lines[start_line..cursor_row]);
70        snippet_with_tags_lines.push(&cursor_line);
71        snippet_with_tags_lines.extend(&lines[cursor_row + 1..end_line]);
72        snippet_with_tags_lines.push(TeacherModel::REGION_END);
73        let snippet_with_tags = snippet_with_tags_lines.join("\n");
74
75        context.replace(&snippet, &snippet_with_tags)
76    } else {
77        log::warn!(
78            "Can't find area around the cursor in the context; proceeding without special tags"
79        );
80        context.to_string()
81    }
82}
83
84pub fn strip_special_tags(context: &str) -> String {
85    context
86        .replace(TeacherModel::REGION_START, "")
87        .replace(TeacherModel::REGION_END, "")
88        .replace(TeacherModel::USER_CURSOR, "")
89}