diff --git a/crates/edit_prediction/src/edit_prediction.rs b/crates/edit_prediction/src/edit_prediction.rs index 6b2019aa30030b0852f74bc851e2012feac4f0e2..5c7ce045121739f341b84dd87d827878550f4048 100644 --- a/crates/edit_prediction/src/edit_prediction.rs +++ b/crates/edit_prediction/src/edit_prediction.rs @@ -125,6 +125,7 @@ impl Global for EditPredictionStoreGlobal {} #[derive(Clone)] pub struct Zeta2RawConfig { pub model_id: Option, + pub environment: Option, pub format: ZetaFormat, } @@ -760,7 +761,12 @@ impl EditPredictionStore { let version_str = env::var("ZED_ZETA_FORMAT").ok()?; let format = ZetaFormat::parse(&version_str).ok()?; let model_id = env::var("ZED_ZETA_MODEL").ok(); - Some(Zeta2RawConfig { model_id, format }) + let environment = env::var("ZED_ZETA_ENVIRONMENT").ok(); + Some(Zeta2RawConfig { + model_id, + environment, + format, + }) } pub fn set_edit_prediction_model(&mut self, model: EditPredictionModel) { diff --git a/crates/edit_prediction/src/zeta.rs b/crates/edit_prediction/src/zeta.rs index 8c158c074bf926d2cee9b77cec65b28c4317a22a..ccb058e1193eaf2919c286c6e675a907e4af159f 100644 --- a/crates/edit_prediction/src/zeta.rs +++ b/crates/edit_prediction/src/zeta.rs @@ -186,13 +186,17 @@ pub fn request_prediction_with_zeta( let prompt = format_zeta_prompt(&prompt_input, config.format); let prefill = get_prefill(&prompt_input, config.format); let prompt = format!("{prompt}{prefill}"); + let environment = config + .environment + .clone() + .or_else(|| Some(config.format.to_string().to_lowercase())); let request = RawCompletionRequest { model: config.model_id.clone().unwrap_or_default(), prompt, temperature: None, stop: vec![], max_tokens: Some(2048), - environment: Some(config.format.to_string().to_lowercase()), + environment, }; editable_range_in_excerpt = zeta_prompt::excerpt_range_for_format( diff --git a/crates/edit_prediction_cli/src/predict.rs b/crates/edit_prediction_cli/src/predict.rs index 8f537dc0817a9cb0b4fd74348ae5e43d4f63beb9..bd89d54ab37521ecb9661b6f1bb0156f30ba1acb 100644 --- a/crates/edit_prediction_cli/src/predict.rs +++ b/crates/edit_prediction_cli/src/predict.rs @@ -148,7 +148,12 @@ pub async fn run_prediction( if let PredictionProvider::Zeta2(format) = provider { if format != ZetaFormat::default() { let model_id = std::env::var("ZED_ZETA_MODEL").ok(); - store.set_zeta2_raw_config(Zeta2RawConfig { model_id, format }); + let environment = std::env::var("ZED_ZETA_ENVIRONMENT").ok(); + store.set_zeta2_raw_config(Zeta2RawConfig { + model_id, + environment, + format, + }); } } });