@@ -19,6 +19,7 @@ use futures::{
select_biased,
};
use gpui::BackgroundExecutor;
+use gpui::http_client::Url;
use gpui::{
App, AsyncApp, Entity, EntityId, Global, SharedString, Subscription, Task, WeakEntity, actions,
http_client::{self, AsyncBody, Method},
@@ -127,15 +128,6 @@ static EDIT_PREDICTIONS_MODEL_ID: LazyLock<String> = LazyLock::new(|| {
}
.to_string()
});
-static PREDICT_EDITS_URL: LazyLock<Option<String>> = LazyLock::new(|| {
- env::var("ZED_PREDICT_EDITS_URL").ok().or_else(|| {
- if *USE_OLLAMA {
- Some("http://localhost:11434/v1/chat/completions".into())
- } else {
- None
- }
- })
-});
pub struct Zeta2FeatureFlag;
@@ -170,6 +162,7 @@ pub struct EditPredictionStore {
reject_predictions_tx: mpsc::UnboundedSender<EditPredictionRejection>,
shown_predictions: VecDeque<EditPrediction>,
rated_predictions: HashSet<EditPredictionId>,
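+    /// Custom prediction endpoint that overrides the Zed-hosted API; populated from
+    /// `ZED_PREDICT_EDITS_URL`, the local Ollama fallback, or `set_custom_predict_edits_url` in tests.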
+ custom_predict_edits_url: Option<Arc<Url>>,
}
#[derive(Copy, Clone, Default, PartialEq, Eq)]
@@ -568,6 +561,20 @@ impl EditPredictionStore {
reject_predictions_tx: reject_tx,
rated_predictions: Default::default(),
shown_predictions: Default::default(),
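+            // An explicit ZED_PREDICT_EDITS_URL takes precedence; otherwise fall back to the
+            // local Ollama endpoint when USE_OLLAMA is set.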
+ custom_predict_edits_url: match env::var("ZED_PREDICT_EDITS_URL") {
+ Ok(custom_url) => Url::parse(&custom_url).log_err().map(Into::into),
+ Err(_) => {
+ if *USE_OLLAMA {
+ Some(
+ Url::parse("http://localhost:11434/v1/chat/completions")
+ .unwrap()
+ .into(),
+ )
+ } else {
+ None
+ }
+ }
+ },
};
this.configure_context_retrieval(cx);
@@ -586,6 +593,11 @@ impl EditPredictionStore {
this
}
+ #[cfg(test)]
+ pub fn set_custom_predict_edits_url(&mut self, url: Url) {
+ self.custom_predict_edits_url = Some(url.into());
+ }
+
pub fn set_edit_prediction_model(&mut self, model: EditPredictionModel) {
self.edit_prediction_model = model;
}
@@ -1015,8 +1027,13 @@ impl EditPredictionStore {
}
fn accept_current_prediction(&mut self, project: &Entity<Project>, cx: &mut Context<Self>) {
+ let custom_accept_url = env::var("ZED_ACCEPT_PREDICTION_URL").ok();
match self.edit_prediction_model {
- EditPredictionModel::Zeta1 | EditPredictionModel::Zeta2 => {}
+ EditPredictionModel::Zeta1 | EditPredictionModel::Zeta2 => {
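+                // Acceptance is reported to Zed's servers, so skip it when a custom prediction
+                // endpoint is in use unless ZED_ACCEPT_PREDICTION_URL is also set.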
+ if self.custom_predict_edits_url.is_some() && custom_accept_url.is_none() {
+ return;
+ }
+ }
EditPredictionModel::Sweep | EditPredictionModel::Mercury => return,
}
@@ -1036,12 +1053,15 @@ impl EditPredictionStore {
let llm_token = self.llm_token.clone();
let app_version = AppVersion::global(cx);
cx.spawn(async move |this, cx| {
- let url = if let Ok(predict_edits_url) = env::var("ZED_ACCEPT_PREDICTION_URL") {
- http_client::Url::parse(&predict_edits_url)?
+ let (url, require_auth) = if let Some(accept_edits_url) = custom_accept_url {
+ (http_client::Url::parse(&accept_edits_url)?, false)
} else {
- client
- .http_client()
- .build_zed_llm_url("/predict_edits/accept", &[])?
+ (
+ client
+ .http_client()
+ .build_zed_llm_url("/predict_edits/accept", &[])?,
+ true,
+ )
};
let response = cx
@@ -1058,6 +1078,7 @@ impl EditPredictionStore {
client,
llm_token,
app_version,
+ require_auth,
))
.await;
@@ -1116,6 +1137,7 @@ impl EditPredictionStore {
client.clone(),
llm_token.clone(),
app_version.clone(),
+ true,
)
.await;
@@ -1161,7 +1183,11 @@ impl EditPredictionStore {
was_shown: bool,
) {
match self.edit_prediction_model {
- EditPredictionModel::Zeta1 | EditPredictionModel::Zeta2 => {}
+ EditPredictionModel::Zeta1 | EditPredictionModel::Zeta2 => {
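+            // Skip shown/rejected reporting entirely when a custom prediction endpoint is in use.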
+ if self.custom_predict_edits_url.is_some() {
+ return;
+ }
+ }
EditPredictionModel::Sweep | EditPredictionModel::Mercury => return,
}
@@ -1671,13 +1697,9 @@ impl EditPredictionStore {
#[cfg(feature = "cli-support")] eval_cache: Option<Arc<dyn EvalCache>>,
#[cfg(feature = "cli-support")] eval_cache_kind: EvalCacheEntryKind,
) -> Result<(open_ai::Response, Option<EditPredictionUsage>)> {
- let url = if let Some(predict_edits_url) = PREDICT_EDITS_URL.as_ref() {
- http_client::Url::parse(&predict_edits_url)?
- } else {
- client
- .http_client()
- .build_zed_llm_url("/predict_edits/raw", &[])?
- };
+ let url = client
+ .http_client()
+ .build_zed_llm_url("/predict_edits/raw", &[])?;
#[cfg(feature = "cli-support")]
let cache_key = if let Some(cache) = eval_cache {
@@ -1710,6 +1732,7 @@ impl EditPredictionStore {
client,
llm_token,
app_version,
+ true,
)
.await?;
@@ -1770,23 +1793,34 @@ impl EditPredictionStore {
client: Arc<Client>,
llm_token: LlmApiToken,
app_version: Version,
+ require_auth: bool,
) -> Result<(Res, Option<EditPredictionUsage>)>
where
Res: DeserializeOwned,
{
let http_client = client.http_client();
- let mut token = llm_token.acquire(&client).await?;
+
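+        // Zed-hosted endpoints require an LLM token; for custom endpoints auth is best-effort.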
+ let mut token = if require_auth {
+ Some(llm_token.acquire(&client).await?)
+ } else {
+ llm_token.acquire(&client).await.ok()
+ };
let mut did_retry = false;
loop {
let request_builder = http_client::Request::builder().method(Method::POST);
- let request = build(
- request_builder
- .header("Content-Type", "application/json")
- .header("Authorization", format!("Bearer {}", token))
- .header(ZED_VERSION_HEADER_NAME, app_version.to_string()),
- )?;
+ let mut request_builder = request_builder
+ .header("Content-Type", "application/json")
+ .header(ZED_VERSION_HEADER_NAME, app_version.to_string());
+
+ // Only add Authorization header if we have a token
+ if let Some(ref token_value) = token {
+ request_builder =
+ request_builder.header("Authorization", format!("Bearer {}", token_value));
+ }
+
+ let request = build(request_builder)?;
let mut response = http_client.send(request).await?;
@@ -1810,13 +1844,14 @@ impl EditPredictionStore {
response.body_mut().read_to_end(&mut body).await?;
return Ok((serde_json::from_slice(&body)?, usage));
} else if !did_retry
+ && token.is_some()
&& response
.headers()
.get(EXPIRED_LLM_TOKEN_HEADER_NAME)
.is_some()
{
did_retry = true;
- token = llm_token.refresh(&client).await?;
+ token = Some(llm_token.refresh(&client).await?);
} else {
let mut body = String::new();
response.body_mut().read_to_string(&mut body).await?;
@@ -1914,6 +1914,174 @@ fn from_completion_edits(
.collect()
}
+#[gpui::test]
+async fn test_unauthenticated_without_custom_url_blocks_prediction(cx: &mut TestAppContext) {
+ init_test(cx);
+
+ let fs = FakeFs::new(cx.executor());
+ fs.insert_tree(
+ "/project",
+ serde_json::json!({
+ "main.rs": "fn main() {\n \n}\n"
+ }),
+ )
+ .await;
+
+ let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await;
+
+ let http_client = FakeHttpClient::create(|_req| async move {
+ Ok(gpui::http_client::Response::builder()
+ .status(401)
+ .body("Unauthorized".into())
+ .unwrap())
+ });
+
+ let client =
+ cx.update(|cx| client::Client::new(Arc::new(FakeSystemClock::new()), http_client, cx));
+ cx.update(|cx| {
+ language_model::RefreshLlmTokenListener::register(client.clone(), cx);
+ });
+
+ let ep_store = cx.new(|cx| EditPredictionStore::new(client, project.read(cx).user_store(), cx));
+
+ let buffer = project
+ .update(cx, |project, cx| {
+ let path = project
+ .find_project_path(path!("/project/main.rs"), cx)
+ .unwrap();
+ project.open_buffer(path, cx)
+ })
+ .await
+ .unwrap();
+
+ let cursor = buffer.read_with(cx, |buffer, _| buffer.anchor_before(Point::new(1, 4)));
+ ep_store.update(cx, |ep_store, cx| {
+ ep_store.register_buffer(&buffer, &project, cx)
+ });
+ cx.background_executor.run_until_parked();
+
+ let completion_task = ep_store.update(cx, |ep_store, cx| {
+ ep_store.set_edit_prediction_model(EditPredictionModel::Zeta1);
+ ep_store.request_prediction(&project, &buffer, cursor, Default::default(), cx)
+ });
+
+ let result = completion_task.await;
+ assert!(
+ result.is_err(),
+ "Without authentication and without custom URL, prediction should fail"
+ );
+}
+
+#[gpui::test]
+async fn test_unauthenticated_with_custom_url_allows_prediction(cx: &mut TestAppContext) {
+ init_test(cx);
+
+ let fs = FakeFs::new(cx.executor());
+ fs.insert_tree(
+ "/project",
+ serde_json::json!({
+ "main.rs": "fn main() {\n \n}\n"
+ }),
+ )
+ .await;
+
+ let project = Project::test(fs.clone(), [path!("/project").as_ref()], cx).await;
+
+ let predict_called = Arc::new(std::sync::atomic::AtomicBool::new(false));
+ let predict_called_clone = predict_called.clone();
+
+ let http_client = FakeHttpClient::create({
+ move |req| {
+ let uri = req.uri().path().to_string();
+ let predict_called = predict_called_clone.clone();
+ async move {
+ if uri.contains("predict") {
+ predict_called.store(true, std::sync::atomic::Ordering::SeqCst);
+ Ok(gpui::http_client::Response::builder()
+ .body(
+ serde_json::to_string(&open_ai::Response {
+ id: "test-123".to_string(),
+ object: "chat.completion".to_string(),
+ created: 0,
+ model: "test".to_string(),
+ usage: open_ai::Usage {
+ prompt_tokens: 0,
+ completion_tokens: 0,
+ total_tokens: 0,
+ },
+ choices: vec![open_ai::Choice {
+ index: 0,
+ message: open_ai::RequestMessage::Assistant {
+ content: Some(open_ai::MessageContent::Plain(
+ indoc! {"
+ ```main.rs
+ <|start_of_file|>
+ <|editable_region_start|>
+ fn main() {
+ println!(\"Hello, world!\");
+ }
+ <|editable_region_end|>
+ ```
+ "}
+ .to_string(),
+ )),
+ tool_calls: vec![],
+ },
+ finish_reason: Some("stop".to_string()),
+ }],
+ })
+ .unwrap()
+ .into(),
+ )
+ .unwrap())
+ } else {
+ Ok(gpui::http_client::Response::builder()
+ .status(401)
+ .body("Unauthorized".into())
+ .unwrap())
+ }
+ }
+ }
+ });
+
+ let client =
+ cx.update(|cx| client::Client::new(Arc::new(FakeSystemClock::new()), http_client, cx));
+ cx.update(|cx| {
+ language_model::RefreshLlmTokenListener::register(client.clone(), cx);
+ });
+
+ let ep_store = cx.new(|cx| EditPredictionStore::new(client, project.read(cx).user_store(), cx));
+
+ let buffer = project
+ .update(cx, |project, cx| {
+ let path = project
+ .find_project_path(path!("/project/main.rs"), cx)
+ .unwrap();
+ project.open_buffer(path, cx)
+ })
+ .await
+ .unwrap();
+
+ let cursor = buffer.read_with(cx, |buffer, _| buffer.anchor_before(Point::new(1, 4)));
+ ep_store.update(cx, |ep_store, cx| {
+ ep_store.register_buffer(&buffer, &project, cx)
+ });
+ cx.background_executor.run_until_parked();
+
+ let completion_task = ep_store.update(cx, |ep_store, cx| {
+ ep_store.set_custom_predict_edits_url(Url::parse("http://test/predict").unwrap());
+ ep_store.set_edit_prediction_model(EditPredictionModel::Zeta1);
+ ep_store.request_prediction(&project, &buffer, cursor, Default::default(), cx)
+ });
+
+ let _ = completion_task.await;
+
+ assert!(
+ predict_called.load(std::sync::atomic::Ordering::SeqCst),
+ "With custom URL, predict endpoint should be called even without authentication"
+ );
+}
+
#[ctor::ctor]
fn init_logger() {
zlog::init_test();
@@ -78,6 +78,19 @@ pub(crate) fn request_prediction_with_zeta1(
cx,
);
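+    // Route the request to the custom endpoint (no auth required) when one is configured;
+    // otherwise use the Zed-hosted /predict_edits/v2 endpoint.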
+ let (uri, require_auth) = match &store.custom_predict_edits_url {
+ Some(custom_url) => (custom_url.clone(), false),
+ None => {
+ match client
+ .http_client()
+ .build_zed_llm_url("/predict_edits/v2", &[])
+ {
+ Ok(url) => (url.into(), true),
+ Err(err) => return Task::ready(Err(err)),
+ }
+ }
+ };
+
cx.spawn(async move |this, cx| {
let GatherContextOutput {
mut body,
@@ -102,25 +115,16 @@ pub(crate) fn request_prediction_with_zeta1(
body.input_excerpt
);
- let http_client = client.http_client();
-
let response = EditPredictionStore::send_api_request::<PredictEditsResponse>(
|request| {
- let uri = if let Ok(predict_edits_url) = std::env::var("ZED_PREDICT_EDITS_URL") {
- predict_edits_url
- } else {
- http_client
- .build_zed_llm_url("/predict_edits/v2", &[])?
- .as_str()
- .into()
- };
Ok(request
- .uri(uri)
+ .uri(uri.as_str())
.body(serde_json::to_string(&body)?.into())?)
},
client,
llm_token,
app_version,
+ require_auth,
)
.await;