@@ -14,7 +14,7 @@ use license_detection::LICENSE_FILES_TO_CHECK;
pub use license_detection::is_license_eligible_for_data_collection;
pub use rate_completion_modal::*;
-use anyhow::{Context as _, Result};
+use anyhow::{Context as _, Result, anyhow};
use arrayvec::ArrayVec;
use client::{Client, UserStore};
use collections::{HashMap, HashSet, VecDeque};
@@ -23,7 +23,7 @@ use gpui::{
App, AppContext as _, AsyncApp, Context, Entity, EntityId, Global, SemanticVersion,
Subscription, Task, WeakEntity, actions,
};
-use http_client::{HttpClient, Method};
+use http_client::{AsyncBody, HttpClient, Method, Request, Response};
use input_excerpt::excerpt_for_cursor_position;
use language::{
Anchor, Buffer, BufferSnapshot, EditPreview, OffsetRangeExt, ToOffset, ToPoint, text_diff,
@@ -54,8 +54,8 @@ use workspace::Workspace;
use workspace::notifications::{ErrorMessagePrompt, NotificationId};
use worktree::Worktree;
use zed_llm_client::{
- EXPIRED_LLM_TOKEN_HEADER_NAME, MINIMUM_REQUIRED_VERSION_HEADER_NAME, PredictEditsBody,
- PredictEditsResponse, ZED_VERSION_HEADER_NAME,
+ AcceptEditPredictionBody, EXPIRED_LLM_TOKEN_HEADER_NAME, MINIMUM_REQUIRED_VERSION_HEADER_NAME,
+ PredictEditsBody, PredictEditsResponse, ZED_VERSION_HEADER_NAME,
};
const CURSOR_MARKER: &'static str = "<|user_cursor_is_here|>";
@@ -823,6 +823,74 @@ and then another
}
}
+ fn accept_edit_prediction(
+ &mut self,
+ request_id: InlineCompletionId,
+ cx: &mut Context<Self>,
+ ) -> Task<Result<()>> {
+ let client = self.client.clone();
+ let llm_token = self.llm_token.clone();
+ let app_version = AppVersion::global(cx);
+ cx.spawn(async move |this, cx| {
+ let http_client = client.http_client();
+ let mut response = llm_token_retry(&llm_token, &client, |token| {
+ let request_builder = http_client::Request::builder().method(Method::POST);
+ let request_builder =
+ if let Ok(accept_prediction_url) = std::env::var("ZED_ACCEPT_PREDICTION_URL") {
+ request_builder.uri(accept_prediction_url)
+ } else {
+ request_builder.uri(
+ http_client
+ .build_zed_llm_url("/predict_edits/accept", &[])?
+ .as_ref(),
+ )
+ };
+ Ok(request_builder
+ .header("Content-Type", "application/json")
+ .header("Authorization", format!("Bearer {}", token))
+ .header(ZED_VERSION_HEADER_NAME, app_version.to_string())
+ .body(
+ serde_json::to_string(&AcceptEditPredictionBody {
+ request_id: request_id.0,
+ })?
+ .into(),
+ )?)
+ })
+ .await?;
+
+ if let Some(minimum_required_version) = response
+ .headers()
+ .get(MINIMUM_REQUIRED_VERSION_HEADER_NAME)
+ .and_then(|version| SemanticVersion::from_str(version.to_str().ok()?).ok())
+ {
+ if app_version < minimum_required_version {
+ return Err(anyhow!(ZedUpdateRequiredError {
+ minimum_version: minimum_required_version
+ }));
+ }
+ }
+
+ if response.status().is_success() {
+ if let Some(usage) = EditPredictionUsage::from_headers(response.headers()).ok() {
+ this.update(cx, |this, cx| {
+ this.last_usage = Some(usage);
+ cx.notify();
+ })?;
+ }
+
+ Ok(())
+ } else {
+ let mut body = String::new();
+ response.body_mut().read_to_string(&mut body).await?;
+ Err(anyhow!(
+ "error accepting edit prediction.\nStatus: {:?}\nBody: {}",
+ response.status(),
+ body
+ ))
+ }
+ })
+ }
+
fn process_completion_response(
prediction_response: PredictEditsResponse,
buffer: Entity<Buffer>,
@@ -1381,6 +1449,34 @@ impl ProviderDataCollection {
}
}
+async fn llm_token_retry(
+ llm_token: &LlmApiToken,
+ client: &Arc<Client>,
+ build_request: impl Fn(String) -> Result<Request<AsyncBody>>,
+) -> Result<Response<AsyncBody>> {
+ let mut did_retry = false;
+ let http_client = client.http_client();
+ let mut token = llm_token.acquire(client).await?;
+ loop {
+ let request = build_request(token.clone())?;
+ let response = http_client.send(request).await?;
+
+ if !did_retry
+ && !response.status().is_success()
+ && response
+ .headers()
+ .get(EXPIRED_LLM_TOKEN_HEADER_NAME)
+ .is_some()
+ {
+ did_retry = true;
+ token = llm_token.refresh(client).await?;
+ continue;
+ }
+
+ return Ok(response);
+ }
+}
+
pub struct ZetaInlineCompletionProvider {
zeta: Entity<Zeta>,
pending_completions: ArrayVec<PendingCompletion, 2>,
@@ -1597,7 +1693,18 @@ impl inline_completion::EditPredictionProvider for ZetaInlineCompletionProvider
// Right now we don't support cycling.
}
- fn accept(&mut self, _cx: &mut Context<Self>) {
+ fn accept(&mut self, cx: &mut Context<Self>) {
+ let completion_id = self
+ .current_completion
+ .as_ref()
+ .map(|completion| completion.completion.id);
+ if let Some(completion_id) = completion_id {
+ self.zeta
+ .update(cx, |zeta, cx| {
+ zeta.accept_edit_prediction(completion_id, cx)
+ })
+ .detach();
+ }
self.pending_completions.clear();
}