ep: Option to configure custom Baseten environment (#50706)

Oleksiy Syvokon created

Release Notes:

- N/A

Change summary

crates/edit_prediction/src/edit_prediction.rs | 8 +++++++-
crates/edit_prediction/src/zeta.rs            | 6 +++++-
crates/edit_prediction_cli/src/predict.rs     | 7 ++++++-
3 files changed, 18 insertions(+), 3 deletions(-)

Detailed changes

crates/edit_prediction/src/edit_prediction.rs 🔗

@@ -125,6 +125,7 @@ impl Global for EditPredictionStoreGlobal {}
 #[derive(Clone)]
 pub struct Zeta2RawConfig {
     pub model_id: Option<String>,
+    pub environment: Option<String>,
     pub format: ZetaFormat,
 }
 
@@ -760,7 +761,12 @@ impl EditPredictionStore {
         let version_str = env::var("ZED_ZETA_FORMAT").ok()?;
         let format = ZetaFormat::parse(&version_str).ok()?;
         let model_id = env::var("ZED_ZETA_MODEL").ok();
-        Some(Zeta2RawConfig { model_id, format })
+        let environment = env::var("ZED_ZETA_ENVIRONMENT").ok();
+        Some(Zeta2RawConfig {
+            model_id,
+            environment,
+            format,
+        })
     }
 
     pub fn set_edit_prediction_model(&mut self, model: EditPredictionModel) {

crates/edit_prediction/src/zeta.rs 🔗

@@ -186,13 +186,17 @@ pub fn request_prediction_with_zeta(
                     let prompt = format_zeta_prompt(&prompt_input, config.format);
                     let prefill = get_prefill(&prompt_input, config.format);
                     let prompt = format!("{prompt}{prefill}");
+                    let environment = config
+                        .environment
+                        .clone()
+                        .or_else(|| Some(config.format.to_string().to_lowercase()));
                     let request = RawCompletionRequest {
                         model: config.model_id.clone().unwrap_or_default(),
                         prompt,
                         temperature: None,
                         stop: vec![],
                         max_tokens: Some(2048),
-                        environment: Some(config.format.to_string().to_lowercase()),
+                        environment,
                     };
 
                     editable_range_in_excerpt = zeta_prompt::excerpt_range_for_format(

crates/edit_prediction_cli/src/predict.rs 🔗

@@ -148,7 +148,12 @@ pub async fn run_prediction(
         if let PredictionProvider::Zeta2(format) = provider {
             if format != ZetaFormat::default() {
                 let model_id = std::env::var("ZED_ZETA_MODEL").ok();
-                store.set_zeta2_raw_config(Zeta2RawConfig { model_id, format });
+                let environment = std::env::var("ZED_ZETA_ENVIRONMENT").ok();
+                store.set_zeta2_raw_config(Zeta2RawConfig {
+                    model_id,
+                    environment,
+                    format,
+                });
             }
         }
     });