distill.rs

 1use serde::Deserialize;
 2use std::sync::Arc;
 3
 4use crate::{
 5    DistillArguments,
 6    example::Example,
 7    source_location::SourceLocation,
 8    training::{
 9        context::ContextType,
10        llm_client::LlmClient,
11        teacher::{TeacherModel, TeacherOutput},
12    },
13};
14use anyhow::Result;
15use reqwest_client::ReqwestClient;
16
17#[derive(Debug, Deserialize)]
18pub struct SplitCommit {
19    repo_url: String,
20    commit_sha: String,
21    edit_history: String,
22    expected_patch: String,
23    cursor_position: String,
24}
25
26pub async fn run_distill(arguments: DistillArguments) -> Result<()> {
27    let split_commits: Vec<SplitCommit> = std::fs::read_to_string(&arguments.split_commit_dataset)
28        .expect("Failed to read split commit dataset")
29        .lines()
30        .map(|line| serde_json::from_str(line).expect("Failed to parse JSON line"))
31        .collect();
32
33    let http_client: Arc<dyn http_client::HttpClient> = Arc::new(ReqwestClient::new());
34
35    let llm_client = if let Some(cache_path) = arguments.batch {
36        LlmClient::batch(&cache_path, http_client)?
37    } else {
38        LlmClient::plain(http_client)?
39    };
40
41    let mut teacher = TeacherModel::new(
42        "claude-sonnet-4-5".to_string(),
43        ContextType::CurrentFile,
44        llm_client,
45    );
46
47    let mut num_marked_for_batching = 0;
48
49    for commit in split_commits {
50        if let Some(distilled) = distill_one(&mut teacher, commit).await? {
51            println!("{}", serde_json::to_string(&distilled)?);
52        } else {
53            if num_marked_for_batching == 0 {
54                log::warn!("Marked for batching");
55            }
56            num_marked_for_batching += 1;
57        }
58    }
59
60    eprintln!(
61        "{} requests are marked for batching",
62        num_marked_for_batching
63    );
64    let llm_client = teacher.client;
65    llm_client.sync_batches().await?;
66
67    Ok(())
68}
69
70pub async fn distill_one(
71    teacher: &mut TeacherModel,
72    commit: SplitCommit,
73) -> Result<Option<TeacherOutput>> {
74    let cursor: SourceLocation = commit
75        .cursor_position
76        .parse()
77        .expect("Failed to parse cursor position");
78
79    let path = cursor.path.to_rel_path_buf();
80
81    let example = Example {
82        repository_url: commit.repo_url,
83        revision: commit.commit_sha,
84        uncommitted_diff: commit.edit_history.clone(),
85        cursor_path: path.as_std_path().to_path_buf(),
86        cursor_position: commit.cursor_position,
87        edit_history: commit.edit_history, // todo: trim
88        expected_patch: commit.expected_patch,
89    };
90
91    let prediction = teacher.predict(example).await;
92
93    prediction
94}