Change summary
crates/edit_prediction/src/edit_prediction.rs | 8 +++++++-
crates/edit_prediction/src/zeta.rs | 6 +++++-
crates/edit_prediction_cli/src/predict.rs | 7 ++++++-
3 files changed, 18 insertions(+), 3 deletions(-)
Detailed changes
@@ -125,6 +125,7 @@ impl Global for EditPredictionStoreGlobal {}
#[derive(Clone)]
pub struct Zeta2RawConfig {
pub model_id: Option<String>,
+ pub environment: Option<String>,
pub format: ZetaFormat,
}
@@ -760,7 +761,12 @@ impl EditPredictionStore {
let version_str = env::var("ZED_ZETA_FORMAT").ok()?;
let format = ZetaFormat::parse(&version_str).ok()?;
let model_id = env::var("ZED_ZETA_MODEL").ok();
- Some(Zeta2RawConfig { model_id, format })
+ let environment = env::var("ZED_ZETA_ENVIRONMENT").ok();
+ Some(Zeta2RawConfig {
+ model_id,
+ environment,
+ format,
+ })
}
pub fn set_edit_prediction_model(&mut self, model: EditPredictionModel) {
@@ -186,13 +186,17 @@ pub fn request_prediction_with_zeta(
let prompt = format_zeta_prompt(&prompt_input, config.format);
let prefill = get_prefill(&prompt_input, config.format);
let prompt = format!("{prompt}{prefill}");
+ let environment = config
+ .environment
+ .clone()
+ .or_else(|| Some(config.format.to_string().to_lowercase()));
let request = RawCompletionRequest {
model: config.model_id.clone().unwrap_or_default(),
prompt,
temperature: None,
stop: vec![],
max_tokens: Some(2048),
- environment: Some(config.format.to_string().to_lowercase()),
+ environment,
};
editable_range_in_excerpt = zeta_prompt::excerpt_range_for_format(
@@ -148,7 +148,12 @@ pub async fn run_prediction(
if let PredictionProvider::Zeta2(format) = provider {
if format != ZetaFormat::default() {
let model_id = std::env::var("ZED_ZETA_MODEL").ok();
- store.set_zeta2_raw_config(Zeta2RawConfig { model_id, format });
+ let environment = std::env::var("ZED_ZETA_ENVIRONMENT").ok();
+ store.set_zeta2_raw_config(Zeta2RawConfig {
+ model_id,
+ environment,
+ format,
+ });
}
}
});