1use serde::Deserialize;
2use std::sync::Arc;
3
4use crate::{
5 DistillArguments,
6 example::Example,
7 source_location::SourceLocation,
8 training::{
9 context::ContextType,
10 llm_client::LlmClient,
11 teacher::{TeacherModel, TeacherOutput},
12 },
13};
14use anyhow::Result;
15use reqwest_client::ReqwestClient;
16
17#[derive(Debug, Deserialize)]
18pub struct SplitCommit {
19 repo_url: String,
20 commit_sha: String,
21 edit_history: String,
22 expected_patch: String,
23 cursor_position: String,
24}
25
26pub async fn run_distill(arguments: DistillArguments) -> Result<()> {
27 let split_commits: Vec<SplitCommit> = std::fs::read_to_string(&arguments.split_commit_dataset)
28 .expect("Failed to read split commit dataset")
29 .lines()
30 .map(|line| serde_json::from_str(line).expect("Failed to parse JSON line"))
31 .collect();
32
33 let http_client: Arc<dyn http_client::HttpClient> = Arc::new(ReqwestClient::new());
34
35 let llm_client = if let Some(cache_path) = arguments.batch {
36 LlmClient::batch(&cache_path, http_client)?
37 } else {
38 LlmClient::plain(http_client)?
39 };
40
41 let mut teacher = TeacherModel::new(
42 "claude-sonnet-4-5".to_string(),
43 ContextType::CurrentFile,
44 llm_client,
45 );
46
47 let mut num_marked_for_batching = 0;
48
49 for commit in split_commits {
50 if let Some(distilled) = distill_one(&mut teacher, commit).await? {
51 println!("{}", serde_json::to_string(&distilled)?);
52 } else {
53 if num_marked_for_batching == 0 {
54 log::warn!("Marked for batching");
55 }
56 num_marked_for_batching += 1;
57 }
58 }
59
60 eprintln!(
61 "{} requests are marked for batching",
62 num_marked_for_batching
63 );
64 let llm_client = teacher.client;
65 llm_client.sync_batches().await?;
66
67 Ok(())
68}
69
70pub async fn distill_one(
71 teacher: &mut TeacherModel,
72 commit: SplitCommit,
73) -> Result<Option<TeacherOutput>> {
74 let cursor: SourceLocation = commit
75 .cursor_position
76 .parse()
77 .expect("Failed to parse cursor position");
78
79 let path = cursor.path.to_rel_path_buf();
80
81 let example = Example {
82 repository_url: commit.repo_url,
83 revision: commit.commit_sha,
84 uncommitted_diff: commit.edit_history.clone(),
85 cursor_path: path.as_std_path().to_path_buf(),
86 cursor_position: commit.cursor_position,
87 edit_history: commit.edit_history, // todo: trim
88 expected_patch: commit.expected_patch,
89 };
90
91 let prediction = teacher.predict(example).await;
92
93 prediction
94}