Detailed changes
@@ -208,8 +208,9 @@ impl BatchingLlmClient {
model: &str,
max_tokens: u64,
messages: &[Message],
+ seed: Option<usize>,
) -> Result<Option<AnthropicResponse>> {
- let request_hash_str = Self::request_hash(model, max_tokens, messages);
+ let request_hash_str = Self::request_hash(model, max_tokens, messages, seed);
let connection = self.connection.lock().unwrap();
let response: Vec<String> = connection.select_bound(
&sql!(SELECT response FROM cache WHERE request_hash = ?1 AND response IS NOT NULL;),
@@ -220,8 +221,14 @@ impl BatchingLlmClient {
.and_then(|text| serde_json::from_str(&text).ok()))
}
- pub fn mark_for_batch(&self, model: &str, max_tokens: u64, messages: &[Message]) -> Result<()> {
- let request_hash = Self::request_hash(model, max_tokens, messages);
+ pub fn mark_for_batch(
+ &self,
+ model: &str,
+ max_tokens: u64,
+ messages: &[Message],
+ seed: Option<usize>,
+ ) -> Result<()> {
+ let request_hash = Self::request_hash(model, max_tokens, messages, seed);
let serializable_messages: Vec<SerializableMessage> = messages
.iter()
@@ -259,13 +266,14 @@ impl BatchingLlmClient {
model: &str,
max_tokens: u64,
messages: Vec<Message>,
+ seed: Option<usize>,
) -> Result<Option<AnthropicResponse>> {
- let response = self.lookup(model, max_tokens, &messages)?;
+ let response = self.lookup(model, max_tokens, &messages, seed)?;
if let Some(response) = response {
return Ok(Some(response));
}
- self.mark_for_batch(model, max_tokens, &messages)?;
+ self.mark_for_batch(model, max_tokens, &messages, seed)?;
Ok(None)
}
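The `lookup` path above goes through the crate's own `select_bound`/`sql!` helpers. As a point of reference only, here is a minimal sketch of the equivalent query written directly against `rusqlite`; the table and column names come from the hunks above, while the function name and everything else are illustrative:

```rust
use rusqlite::{Connection, OptionalExtension};

// Illustrative stand-in for the cache lookup above, using plain rusqlite
// instead of the project's connection helpers: fetch the stored response
// for a request hash, if one has already been written back.
fn cached_response(conn: &Connection, request_hash: &str) -> rusqlite::Result<Option<String>> {
    conn.query_row(
        "SELECT response FROM cache WHERE request_hash = ?1 AND response IS NOT NULL",
        [request_hash],
        |row| row.get(0),
    )
    .optional()
}
```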
@@ -606,13 +614,21 @@ impl BatchingLlmClient {
Ok(all_batch_ids)
}
- fn request_hash(model: &str, max_tokens: u64, messages: &[Message]) -> String {
+ fn request_hash(
+ model: &str,
+ max_tokens: u64,
+ messages: &[Message],
+ seed: Option<usize>,
+ ) -> String {
let mut hasher = std::hash::DefaultHasher::new();
model.hash(&mut hasher);
max_tokens.hash(&mut hasher);
for msg in messages {
message_content_to_string(&msg.content).hash(&mut hasher);
}
+ if let Some(seed) = seed {
+ seed.hash(&mut hasher);
+ }
let request_hash = hasher.finish();
format!("{request_hash:016x}")
}
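Folding the seed into the hash only when it is present keeps unseeded requests addressable under their existing cache keys, while each distinct seed value produces a distinct key for otherwise identical requests. A self-contained sketch of that scheme, assuming the same `DefaultHasher` as above (the `cache_key` name and the plain-string `contents` slice are illustrative simplifications):

```rust
use std::hash::{Hash, Hasher};

// Sketch of the seed-aware cache key: with `seed == None` the hashed input is
// identical to the pre-change scheme, so existing cache rows keep their keys;
// each `Some(seed)` yields a different key for otherwise identical requests.
fn cache_key(model: &str, max_tokens: u64, contents: &[String], seed: Option<usize>) -> String {
    let mut hasher = std::hash::DefaultHasher::new();
    model.hash(&mut hasher);
    max_tokens.hash(&mut hasher);
    for content in contents {
        content.hash(&mut hasher);
    }
    if let Some(seed) = seed {
        seed.hash(&mut hasher);
    }
    format!("{:016x}", hasher.finish())
}
```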
@@ -655,6 +671,7 @@ impl AnthropicClient {
model: &str,
max_tokens: u64,
messages: Vec<Message>,
+ seed: Option<usize>,
) -> Result<Option<AnthropicResponse>> {
match self {
AnthropicClient::Plain(plain_llm_client) => plain_llm_client
@@ -663,7 +680,7 @@ impl AnthropicClient {
.map(Some),
AnthropicClient::Batch(batching_llm_client) => {
batching_llm_client
- .generate(model, max_tokens, messages)
+ .generate(model, max_tokens, messages, seed)
.await
}
AnthropicClient::Dummy => panic!("Dummy LLM client is not expected to be used"),
@@ -138,8 +138,9 @@ impl BatchingOpenAiClient {
model: &str,
max_tokens: u64,
messages: &[RequestMessage],
+ seed: Option<usize>,
) -> Result<Option<OpenAiResponse>> {
- let request_hash_str = Self::request_hash(model, max_tokens, messages);
+ let request_hash_str = Self::request_hash(model, max_tokens, messages, seed);
let connection = self.connection.lock().unwrap();
let response: Vec<String> = connection.select_bound(
&sql!(SELECT response FROM openai_cache WHERE request_hash = ?1 AND response IS NOT NULL;),
@@ -155,8 +156,9 @@ impl BatchingOpenAiClient {
model: &str,
max_tokens: u64,
messages: &[RequestMessage],
+ seed: Option<usize>,
) -> Result<()> {
- let request_hash = Self::request_hash(model, max_tokens, messages);
+ let request_hash = Self::request_hash(model, max_tokens, messages, seed);
let serializable_messages: Vec<SerializableMessage> = messages
.iter()
@@ -191,13 +193,14 @@ impl BatchingOpenAiClient {
model: &str,
max_tokens: u64,
messages: Vec<RequestMessage>,
+ seed: Option<usize>,
) -> Result<Option<OpenAiResponse>> {
- let response = self.lookup(model, max_tokens, &messages)?;
+ let response = self.lookup(model, max_tokens, &messages, seed)?;
if let Some(response) = response {
return Ok(Some(response));
}
- self.mark_for_batch(model, max_tokens, &messages)?;
+ self.mark_for_batch(model, max_tokens, &messages, seed)?;
Ok(None)
}
@@ -558,7 +561,12 @@ impl BatchingOpenAiClient {
Ok(all_batch_ids)
}
- fn request_hash(model: &str, max_tokens: u64, messages: &[RequestMessage]) -> String {
+ fn request_hash(
+ model: &str,
+ max_tokens: u64,
+ messages: &[RequestMessage],
+ seed: Option<usize>,
+ ) -> String {
let mut hasher = std::hash::DefaultHasher::new();
"openai".hash(&mut hasher);
model.hash(&mut hasher);
@@ -566,6 +574,9 @@ impl BatchingOpenAiClient {
for msg in messages {
message_content_to_string(msg).hash(&mut hasher);
}
+ if let Some(seed) = seed {
+ seed.hash(&mut hasher);
+ }
let request_hash = hasher.finish();
format!("{request_hash:016x}")
}
@@ -631,6 +642,7 @@ impl OpenAiClient {
model: &str,
max_tokens: u64,
messages: Vec<RequestMessage>,
+ seed: Option<usize>,
) -> Result<Option<OpenAiResponse>> {
match self {
OpenAiClient::Plain(plain_client) => plain_client
@@ -638,7 +650,9 @@ impl OpenAiClient {
.await
.map(Some),
OpenAiClient::Batch(batching_client) => {
- batching_client.generate(model, max_tokens, messages).await
+ batching_client
+ .generate(model, max_tokens, messages, seed)
+ .await
}
OpenAiClient::Dummy => panic!("Dummy OpenAI client is not expected to be used"),
}
@@ -69,7 +69,7 @@ pub async fn run_prediction(
.await?;
let batched = matches!(provider, PredictionProvider::Teacher(..));
- return predict_teacher(example, backend, batched).await;
+ return predict_teacher(example, backend, batched, repetition_count).await;
}
run_load_project(example, app_state.clone(), example_progress, cx.clone()).await?;
@@ -261,10 +261,13 @@ async fn predict_teacher(
example: &mut Example,
backend: TeacherBackend,
batched: bool,
+ repetition_count: usize,
) -> anyhow::Result<()> {
match backend {
- TeacherBackend::Sonnet45 => predict_anthropic(example, backend, batched).await,
- TeacherBackend::Gpt52 => predict_openai(example, backend, batched).await,
+ TeacherBackend::Sonnet45 => {
+ predict_anthropic(example, backend, batched, repetition_count).await
+ }
+ TeacherBackend::Gpt52 => predict_openai(example, backend, batched, repetition_count).await,
}
}
@@ -272,6 +275,7 @@ async fn predict_anthropic(
example: &mut Example,
backend: TeacherBackend,
batched: bool,
+ repetition_count: usize,
) -> anyhow::Result<()> {
let llm_model_name = backend.model_name();
let max_tokens = 16384;
@@ -286,46 +290,49 @@ async fn predict_anthropic(
let prompt = example.prompt.as_ref().context("Prompt is required")?;
- let messages = vec![anthropic::Message {
- role: anthropic::Role::User,
- content: vec![anthropic::RequestContent::Text {
- text: prompt.input.clone(),
- cache_control: None,
- }],
- }];
-
- let Some(response) = llm_client
- .generate(llm_model_name, max_tokens, messages)
- .await?
- else {
- // Request stashed for batched processing
- return Ok(());
- };
+ for ix in 0..repetition_count {
+ let messages = vec![anthropic::Message {
+ role: anthropic::Role::User,
+ content: vec![anthropic::RequestContent::Text {
+ text: prompt.input.clone(),
+ cache_control: None,
+ }],
+ }];
+
+ let seed = if repetition_count > 1 { Some(ix) } else { None };
+ let Some(response) = llm_client
+ .generate(llm_model_name, max_tokens, messages, seed)
+ .await?
+ else {
+            // Request stashed for batched processing; remaining repetitions are queued on later passes
+ return Ok(());
+ };
- let actual_output = response
- .content
- .into_iter()
- .filter_map(|content| match content {
- anthropic::ResponseContent::Text { text } => Some(text),
- _ => None,
- })
- .collect::<Vec<String>>()
- .join("\n");
-
- let actual_patch = TeacherPrompt::parse(example, &actual_output)?;
-
- let prediction = ExamplePrediction {
- actual_patch: Some(actual_patch),
- actual_output,
- error: None,
- provider: if batched {
- PredictionProvider::Teacher(backend)
- } else {
- PredictionProvider::TeacherNonBatching(backend)
- },
- };
+ let actual_output = response
+ .content
+ .into_iter()
+ .filter_map(|content| match content {
+ anthropic::ResponseContent::Text { text } => Some(text),
+ _ => None,
+ })
+ .collect::<Vec<String>>()
+ .join("\n");
+
+ let actual_patch = TeacherPrompt::parse(example, &actual_output)?;
+
+ let prediction = ExamplePrediction {
+ actual_patch: Some(actual_patch),
+ actual_output,
+ error: None,
+ provider: if batched {
+ PredictionProvider::Teacher(backend)
+ } else {
+ PredictionProvider::TeacherNonBatching(backend)
+ },
+ };
- example.predictions.push(prediction);
+ example.predictions.push(prediction);
+ }
Ok(())
}
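To make the interplay between the repetition loop and the cache concrete, here is a small usage sketch built on the hypothetical `cache_key` helper from the earlier example; the model string and prompt are placeholders:

```rust
// Same model, same prompt: each repetition's seed yields its own cache key, so
// every iteration gets its own batched request and stored response, while an
// unseeded single run keeps the original, seed-free key.
fn demo_repetition_keys() {
    let messages = vec!["<prompt text>".to_string()];
    let unseeded = cache_key("example-model", 16384, &messages, None);
    let per_repetition: Vec<String> = (0..3_usize)
        .map(|ix| cache_key("example-model", 16384, &messages, Some(ix)))
        .collect();
    // Distinct in practice (64-bit hash collisions aside).
    assert!(!per_repetition.contains(&unseeded));
}
```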
@@ -333,6 +340,7 @@ async fn predict_openai(
example: &mut Example,
backend: TeacherBackend,
batched: bool,
+ repetition_count: usize,
) -> anyhow::Result<()> {
let llm_model_name = backend.model_name();
let max_tokens = 16384;
@@ -347,52 +355,55 @@ async fn predict_openai(
let prompt = example.prompt.as_ref().context("Prompt is required")?;
- let messages = vec![open_ai::RequestMessage::User {
- content: open_ai::MessageContent::Plain(prompt.input.clone()),
- }];
+ for ix in 0..repetition_count {
+ let messages = vec![open_ai::RequestMessage::User {
+ content: open_ai::MessageContent::Plain(prompt.input.clone()),
+ }];
+
+ let seed = if repetition_count > 1 { Some(ix) } else { None };
+ let Some(response) = llm_client
+ .generate(llm_model_name, max_tokens, messages, seed)
+ .await?
+ else {
+            // Request stashed for batched processing; remaining repetitions are queued on later passes
+ return Ok(());
+ };
- let Some(response) = llm_client
- .generate(llm_model_name, max_tokens, messages)
- .await?
- else {
- // Request stashed for batched processing
- return Ok(());
- };
+ let actual_output = response
+ .choices
+ .into_iter()
+ .filter_map(|choice| match choice.message {
+ open_ai::RequestMessage::Assistant { content, .. } => content.map(|c| match c {
+ open_ai::MessageContent::Plain(text) => text,
+ open_ai::MessageContent::Multipart(parts) => parts
+ .into_iter()
+ .filter_map(|p| match p {
+ open_ai::MessagePart::Text { text } => Some(text),
+ _ => None,
+ })
+ .collect::<Vec<_>>()
+ .join(""),
+ }),
+ _ => None,
+ })
+ .collect::<Vec<String>>()
+ .join("\n");
- let actual_output = response
- .choices
- .into_iter()
- .filter_map(|choice| match choice.message {
- open_ai::RequestMessage::Assistant { content, .. } => content.map(|c| match c {
- open_ai::MessageContent::Plain(text) => text,
- open_ai::MessageContent::Multipart(parts) => parts
- .into_iter()
- .filter_map(|p| match p {
- open_ai::MessagePart::Text { text } => Some(text),
- _ => None,
- })
- .collect::<Vec<_>>()
- .join(""),
- }),
- _ => None,
- })
- .collect::<Vec<String>>()
- .join("\n");
-
- let actual_patch = TeacherPrompt::parse(example, &actual_output)?;
-
- let prediction = ExamplePrediction {
- actual_patch: Some(actual_patch),
- actual_output,
- error: None,
- provider: if batched {
- PredictionProvider::Teacher(backend)
- } else {
- PredictionProvider::TeacherNonBatching(backend)
- },
- };
+ let actual_patch = TeacherPrompt::parse(example, &actual_output)?;
+
+ let prediction = ExamplePrediction {
+ actual_patch: Some(actual_patch),
+ actual_output,
+ error: None,
+ provider: if batched {
+ PredictionProvider::Teacher(backend)
+ } else {
+ PredictionProvider::TeacherNonBatching(backend)
+ },
+ };
- example.predictions.push(prediction);
+ example.predictions.push(prediction);
+ }
Ok(())
}
@@ -172,7 +172,7 @@ impl QaClient {
cache_control: None,
}],
}];
- let response = client.generate(model, max_tokens, messages).await?;
+ let response = client.generate(model, max_tokens, messages, None).await?;
Ok(response.map(|r| {
r.content
.iter()
@@ -188,7 +188,7 @@ impl QaClient {
let messages = vec![open_ai::RequestMessage::User {
content: open_ai::MessageContent::Plain(prompt.to_string()),
}];
- let response = client.generate(model, max_tokens, messages).await?;
+ let response = client.generate(model, max_tokens, messages, None).await?;
Ok(response.map(|r| {
r.choices
.into_iter()
@@ -152,7 +152,7 @@ impl RepairClient {
cache_control: None,
}],
}];
- let response = client.generate(model, max_tokens, messages).await?;
+ let response = client.generate(model, max_tokens, messages, None).await?;
Ok(response.map(|r| {
r.content
.iter()
@@ -168,7 +168,7 @@ impl RepairClient {
let messages = vec![open_ai::RequestMessage::User {
content: open_ai::MessageContent::Plain(prompt.to_string()),
}];
- let response = client.generate(model, max_tokens, messages).await?;
+ let response = client.generate(model, max_tokens, messages, None).await?;
Ok(response.map(|r| {
r.choices
.into_iter()