diff --git a/crates/cloud_llm_client/src/cloud_llm_client.rs b/crates/cloud_llm_client/src/cloud_llm_client.rs index bb77c3a5b7f8009093cbf7bc427160ed535e6c62..ff8275fe40eae6945691a7b8d315414617be0235 100644 --- a/crates/cloud_llm_client/src/cloud_llm_client.rs +++ b/crates/cloud_llm_client/src/cloud_llm_client.rs @@ -183,13 +183,13 @@ pub struct PredictEditsGitInfo { #[derive(Debug, Clone, Serialize, Deserialize)] pub struct PredictEditsResponse { - pub request_id: Uuid, + pub request_id: String, pub output_excerpt: String, } #[derive(Debug, Clone, Serialize, Deserialize)] pub struct AcceptEditPredictionBody { - pub request_id: Uuid, + pub request_id: String, } #[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Clone, Copy, Serialize, Deserialize)] diff --git a/crates/zeta/src/zeta.rs b/crates/zeta/src/zeta.rs index 0b939bdee27851a9ec9975b586c9a7bcad67484f..708a53ff47bd2c60e6b9620e8bed30b16419ba14 100644 --- a/crates/zeta/src/zeta.rs +++ b/crates/zeta/src/zeta.rs @@ -652,7 +652,7 @@ impl Zeta { .header(ZED_VERSION_HEADER_NAME, app_version.to_string()) .body( serde_json::to_string(&AcceptEditPredictionBody { - request_id: request_id.0, + request_id: request_id.0.to_string(), })? .into(), )?) 
@@ -735,6 +735,8 @@ impl Zeta { return anyhow::Ok(None); }; + let request_id = Uuid::from_str(&request_id).context("failed to parse request id")?; + let edit_preview = edit_preview.await; Ok(Some(EditPrediction { @@ -2162,7 +2164,7 @@ mod tests { .status(200) .body( serde_json::to_string(&PredictEditsResponse { - request_id: Uuid::new_v4(), + request_id: Uuid::new_v4().to_string(), output_excerpt: completion_response.lock().clone(), }) .unwrap() diff --git a/crates/zeta2/src/prediction.rs b/crates/zeta2/src/prediction.rs index 54a6987b3f781a48fe928636dc3537117ee6a401..e9f726ce00c36b5235919c0e185876996f4fda03 100644 --- a/crates/zeta2/src/prediction.rs +++ b/crates/zeta2/src/prediction.rs @@ -1,21 +1,14 @@ use std::{ops::Range, sync::Arc}; -use gpui::{AsyncApp, Entity}; +use gpui::{AsyncApp, Entity, SharedString}; use language::{Anchor, Buffer, BufferSnapshot, EditPreview, OffsetRangeExt, TextBufferSnapshot}; -use uuid::Uuid; -#[derive(Copy, Clone, Default, Debug, PartialEq, Eq, Hash)] -pub struct EditPredictionId(pub Uuid); - -impl Into<Uuid> for EditPredictionId { - fn into(self) -> Uuid { - self.0 - } -} +#[derive(Clone, Default, Debug, PartialEq, Eq, Hash)] +pub struct EditPredictionId(pub SharedString); impl From<EditPredictionId> for gpui::ElementId { fn from(value: EditPredictionId) -> Self { - gpui::ElementId::Uuid(value.0) + gpui::ElementId::Name(value.0) } } @@ -149,7 +142,7 @@ mod tests { .await; let prediction = EditPrediction { - id: EditPredictionId(Uuid::new_v4()), + id: EditPredictionId("prediction-1".into()), edits, snapshot: cx.read(|cx| buffer.read(cx).snapshot()), buffer: buffer.clone(), diff --git a/crates/zeta2/src/zeta2.rs b/crates/zeta2/src/zeta2.rs index ff0ff4f1ba2af59f32cddee96e4b9c0dd25af22d..3a51f9975ccbcf3fb325712f7aafadc5187da541 100644 --- a/crates/zeta2/src/zeta2.rs +++ b/crates/zeta2/src/zeta2.rs @@ -30,8 +30,8 @@ use project::Project; use release_channel::AppVersion; use serde::de::DeserializeOwned; use std::collections::{VecDeque, hash_map}; -use 
uuid::Uuid; +use std::env; use std::ops::Range; use std::path::Path; use std::str::FromStr as _; @@ -88,8 +88,24 @@ pub const DEFAULT_OPTIONS: ZetaOptions = ZetaOptions { buffer_change_grouping_interval: Duration::from_secs(1), }; -static MODEL_ID: LazyLock<String> = - LazyLock::new(|| std::env::var("ZED_ZETA2_MODEL").unwrap_or("yqvev8r3".to_string())); +static USE_OLLAMA: LazyLock<bool> = + LazyLock::new(|| env::var("ZED_ZETA2_OLLAMA").is_ok_and(|var| !var.is_empty())); +static MODEL_ID: LazyLock<String> = LazyLock::new(|| { + env::var("ZED_ZETA2_MODEL").unwrap_or(if *USE_OLLAMA { + "qwen3-coder:30b".to_string() + } else { + "yqvev8r3".to_string() + }) +}); +static PREDICT_EDITS_URL: LazyLock<Option<String>> = LazyLock::new(|| { + env::var("ZED_PREDICT_EDITS_URL").ok().or_else(|| { + if *USE_OLLAMA { + Some("http://localhost:11434/v1/chat/completions".into()) + } else { + None + } + }) +}); pub struct Zeta2FeatureFlag; @@ -567,13 +583,13 @@ impl Zeta { let Some(prediction) = project_state.current_prediction.take() else { return; }; - let request_id = prediction.prediction.id.into(); + let request_id = prediction.prediction.id.to_string(); let client = self.client.clone(); let llm_token = self.llm_token.clone(); let app_version = AppVersion::global(cx); cx.spawn(async move |this, cx| { - let url = if let Ok(predict_edits_url) = std::env::var("ZED_ACCEPT_PREDICTION_URL") { + let url = if let Ok(predict_edits_url) = env::var("ZED_ACCEPT_PREDICTION_URL") { http_client::Url::parse(&predict_edits_url)? } else { client @@ -585,7 +601,10 @@ impl Zeta { .background_spawn(Self::send_api_request::<()>( move |builder| { let req = builder.uri(url.as_ref()).body( - serde_json::to_string(&AcceptEditPredictionBody { request_id })?.into(), + serde_json::to_string(&AcceptEditPredictionBody { + request_id: request_id.clone(), + })? + .into(), ); Ok(req?) 
}, @@ -875,7 +894,7 @@ impl Zeta { None }; - if cfg!(debug_assertions) && std::env::var("ZED_ZETA2_SKIP_REQUEST").is_ok() { + if cfg!(debug_assertions) && env::var("ZED_ZETA2_SKIP_REQUEST").is_ok() { if let Some(debug_response_tx) = debug_response_tx { debug_response_tx .send((Err("Request skipped".to_string()), TimeDelta::zero())) @@ -923,7 +942,7 @@ impl Zeta { } let (res, usage) = response?; - let request_id = EditPredictionId(Uuid::from_str(&res.id)?); + let request_id = EditPredictionId(res.id.clone().into()); let Some(output_text) = text_from_response(res) else { return Ok((None, usage)) }; @@ -980,7 +999,7 @@ impl Zeta { app_version: SemanticVersion, request: open_ai::Request, ) -> Result<(open_ai::Response, Option)> { - let url = if let Ok(predict_edits_url) = std::env::var("ZED_PREDICT_EDITS_URL") { + let url = if let Some(predict_edits_url) = PREDICT_EDITS_URL.as_ref() { http_client::Url::parse(&predict_edits_url)? } else { client