diff --git a/crates/edit_prediction/src/mercury.rs b/crates/edit_prediction/src/mercury.rs index 71362f4c873ca7b6f89030392449916cdc297b8e..df47a38062344512a784c6d2feb563e9848afb27 100644 --- a/crates/edit_prediction/src/mercury.rs +++ b/crates/edit_prediction/src/mercury.rs @@ -137,6 +137,7 @@ impl Mercury { content: open_ai::MessageContent::Plain(prompt), }], stream: false, + stream_options: None, max_completion_tokens: None, stop: vec![], temperature: None, diff --git a/crates/edit_prediction_cli/src/openai_client.rs b/crates/edit_prediction_cli/src/openai_client.rs index c9947e16099c7923e6c948045eda8ca08ff625cf..6bc9c2d77c0d6be6e2955182ebbce096be422945 100644 --- a/crates/edit_prediction_cli/src/openai_client.rs +++ b/crates/edit_prediction_cli/src/openai_client.rs @@ -40,6 +40,7 @@ impl PlainOpenAiClient { model: model.to_string(), messages, stream: false, + stream_options: None, max_completion_tokens: Some(max_tokens), stop: Vec::new(), temperature: None, @@ -490,6 +491,7 @@ impl BatchingOpenAiClient { model: serializable_request.model, messages, stream: false, + stream_options: None, max_completion_tokens: Some(serializable_request.max_tokens), stop: Vec::new(), temperature: None, diff --git a/crates/language_models/src/provider/open_ai.rs b/crates/language_models/src/provider/open_ai.rs index c1ebf76e0b0678d35a5e013e87f9efd9488a4e8d..e033be8ee234156bc2452c11fad438be70a7f143 100644 --- a/crates/language_models/src/provider/open_ai.rs +++ b/crates/language_models/src/provider/open_ai.rs @@ -506,6 +506,11 @@ pub fn into_open_ai( model: model_id.into(), messages, stream, + stream_options: if stream { + Some(open_ai::StreamOptions::default()) + } else { + None + }, stop: request.stop, temperature: request.temperature.or(Some(1.0)), max_completion_tokens: max_output_tokens, diff --git a/crates/open_ai/src/open_ai.rs b/crates/open_ai/src/open_ai.rs index 25946591e320df4e2d58e8dd0341d7f27451cc89..c4a3e078d76eb028b90e5b80fe95b1281b795f34 100644 --- 
a/crates/open_ai/src/open_ai.rs +++ b/crates/open_ai/src/open_ai.rs @@ -295,12 +295,27 @@ impl Model { } } +#[derive(Debug, Serialize, Deserialize)] +pub struct StreamOptions { + pub include_usage: bool, +} + +impl Default for StreamOptions { + fn default() -> Self { + Self { + include_usage: true, + } + } +} + #[derive(Debug, Serialize, Deserialize)] pub struct Request { pub model: String, pub messages: Vec<RequestMessage>, pub stream: bool, #[serde(default, skip_serializing_if = "Option::is_none")] + pub stream_options: Option<StreamOptions>, + #[serde(default, skip_serializing_if = "Option::is_none")] pub max_completion_tokens: Option<u64>, #[serde(default, skip_serializing_if = "Vec::is_empty")] pub stop: Vec<String>,