zeta: Refresh LLM token in case it expired (#21796)

Thorsten Ball, Antonio, and Bennet created

Release Notes:

- N/A

---------

Co-authored-by: Antonio <antonio@zed.dev>
Co-authored-by: Bennet <bennet@zed.dev>

Change summary

crates/feature_flags/src/feature_flags.rs                       |   5 
crates/inline_completion_button/src/inline_completion_button.rs |   4 
crates/zed/src/zed/inline_completion_registry.rs                |  49 
crates/zeta/src/zeta.rs                                         | 245 +-
4 files changed, 190 insertions(+), 113 deletions(-)

Detailed changes

crates/feature_flags/src/feature_flags.rs 🔗

@@ -59,6 +59,11 @@ impl FeatureFlag for ToolUseFeatureFlag {
     }
 }
 
+pub struct ZetaFeatureFlag;
+impl FeatureFlag for ZetaFeatureFlag {
+    const NAME: &'static str = "zeta";
+}
+
 pub struct Remoting {}
 impl FeatureFlag for Remoting {
     const NAME: &'static str = "remoting";

crates/inline_completion_button/src/inline_completion_button.rs 🔗

@@ -1,7 +1,7 @@
 use anyhow::Result;
 use copilot::{Copilot, Status};
 use editor::{scroll::Autoscroll, Editor};
-use feature_flags::FeatureFlagAppExt;
+use feature_flags::{FeatureFlagAppExt, ZetaFeatureFlag};
 use fs::Fs;
 use gpui::{
     div, Action, AnchorCorner, AppContext, AsyncWindowContext, Entity, IntoElement, ParentElement,
@@ -199,7 +199,7 @@ impl Render for InlineCompletionButton {
             }
 
             InlineCompletionProvider::Zeta => {
-                if !cx.is_staff() {
+                if !cx.has_flag::<ZetaFeatureFlag>() {
                     return div();
                 }
 

crates/zed/src/zed/inline_completion_registry.rs 🔗

@@ -4,9 +4,9 @@ use client::Client;
 use collections::HashMap;
 use copilot::{Copilot, CopilotCompletionProvider};
 use editor::{Editor, EditorMode};
-use feature_flags::FeatureFlagAppExt;
+use feature_flags::{FeatureFlagAppExt, ZetaFeatureFlag};
 use gpui::{AnyWindowHandle, AppContext, Context, ViewContext, WeakView};
-use language::language_settings::all_language_settings;
+use language::language_settings::{all_language_settings, InlineCompletionProvider};
 use settings::SettingsStore;
 use supermaven::{Supermaven, SupermavenCompletionProvider};
 
@@ -49,22 +49,45 @@ pub fn init(client: Arc<Client>, cx: &mut AppContext) {
         });
     }
 
-    cx.observe_global::<SettingsStore>(move |cx| {
-        let new_provider = all_language_settings(None, cx).inline_completions.provider;
-        if new_provider != provider {
-            provider = new_provider;
-            for (editor, window) in editors.borrow().iter() {
-                _ = window.update(cx, |_window, cx| {
-                    _ = editor.update(cx, |editor, cx| {
-                        assign_inline_completion_provider(editor, provider, &client, cx);
-                    })
-                });
+    cx.observe_flag::<ZetaFeatureFlag, _>({
+        let editors = editors.clone();
+        let client = client.clone();
+        move |_flag, cx| {
+            let provider = all_language_settings(None, cx).inline_completions.provider;
+            assign_inline_completion_providers(&editors, provider, &client, cx)
+        }
+    })
+    .detach();
+
+    cx.observe_global::<SettingsStore>({
+        let editors = editors.clone();
+        let client = client.clone();
+        move |cx| {
+            let new_provider = all_language_settings(None, cx).inline_completions.provider;
+            if new_provider != provider {
+                provider = new_provider;
+                assign_inline_completion_providers(&editors, provider, &client, cx)
             }
         }
     })
     .detach();
 }
 
+fn assign_inline_completion_providers(
+    editors: &Rc<RefCell<HashMap<WeakView<Editor>, AnyWindowHandle>>>,
+    provider: InlineCompletionProvider,
+    client: &Arc<Client>,
+    cx: &mut AppContext,
+) {
+    for (editor, window) in editors.borrow().iter() {
+        _ = window.update(cx, |_window, cx| {
+            _ = editor.update(cx, |editor, cx| {
+                assign_inline_completion_provider(editor, provider, &client, cx);
+            })
+        });
+    }
+}
+
 fn register_backward_compatible_actions(editor: &mut Editor, cx: &ViewContext<Editor>) {
     // We renamed some of these actions to not be copilot-specific, but that
     // would have not been backwards-compatible. So here we are re-registering
@@ -129,7 +152,7 @@ fn assign_inline_completion_provider(
             }
         }
         language::language_settings::InlineCompletionProvider::Zeta => {
-            if cx.is_staff() {
+            if cx.has_flag::<ZetaFeatureFlag>() {
                 let zeta = zeta::Zeta::register(client.clone(), cx);
                 if let Some(buffer) = editor.buffer().read(cx).as_singleton() {
                     if buffer.read(cx).file().is_some() {

crates/zeta/src/zeta.rs 🔗

@@ -13,7 +13,7 @@ use language::{
     Point, ToOffset, ToPoint,
 };
 use language_models::LlmApiToken;
-use rpc::{PredictEditsParams, PredictEditsResponse};
+use rpc::{PredictEditsParams, PredictEditsResponse, EXPIRED_LLM_TOKEN_HEADER_NAME};
 use std::{
     borrow::Cow,
     cmp,
@@ -269,8 +269,6 @@ impl Zeta {
         cx.spawn(|this, mut cx| async move {
             let start = std::time::Instant::now();
 
-            let token = llm_token.acquire(&client).await?;
-
             let mut input_events = String::new();
             for event in events {
                 if !input_events.is_empty() {
@@ -283,141 +281,192 @@ impl Zeta {
 
             log::debug!("Events:\n{}\nExcerpt:\n{}", input_events, input_excerpt);
 
-            let http_client = client.http_client();
             let body = PredictEditsParams {
                 input_events: input_events.clone(),
                 input_excerpt: input_excerpt.clone(),
             };
+
+            let response = Self::perform_predict_edits(&client, llm_token, body).await?;
+
+            let output_excerpt = response.output_excerpt;
+            log::debug!("prediction took: {:?}", start.elapsed());
+            log::debug!("completion response: {}", output_excerpt);
+
+            let inline_completion = Self::process_completion_response(
+                output_excerpt,
+                &snapshot,
+                excerpt_range,
+                path,
+                input_events,
+                input_excerpt,
+            )?;
+
+            this.update(&mut cx, |this, cx| {
+                this.recent_completions
+                    .push_front(inline_completion.clone());
+                if this.recent_completions.len() > 50 {
+                    this.recent_completions.pop_back();
+                }
+                cx.notify();
+            })?;
+
+            Ok(inline_completion)
+        })
+    }
+
+    async fn perform_predict_edits(
+        client: &Arc<Client>,
+        llm_token: LlmApiToken,
+        body: PredictEditsParams,
+    ) -> Result<PredictEditsResponse> {
+        let http_client = client.http_client();
+        let mut token = llm_token.acquire(client).await?;
+        let mut did_retry = false;
+
+        loop {
             let request_builder = http_client::Request::builder();
             let request = request_builder
                 .method(Method::POST)
                 .uri(
-                    client
-                        .http_client()
+                    http_client
                         .build_zed_llm_url("/predict_edits", &[])?
                         .as_ref(),
                 )
                 .header("Content-Type", "application/json")
                 .header("Authorization", format!("Bearer {}", token))
                 .body(serde_json::to_string(&body)?.into())?;
+
             let mut response = http_client.send(request).await?;
-            let mut body = String::new();
-            response.body_mut().read_to_string(&mut body).await?;
-            if !response.status().is_success() {
+
+            if response.status().is_success() {
+                let mut body = String::new();
+                response.body_mut().read_to_string(&mut body).await?;
+                return Ok(serde_json::from_str(&body)?);
+            } else if !did_retry
+                && response
+                    .headers()
+                    .get(EXPIRED_LLM_TOKEN_HEADER_NAME)
+                    .is_some()
+            {
+                did_retry = true;
+                token = llm_token.refresh(client).await?;
+            } else {
+                let mut body = String::new();
+                response.body_mut().read_to_string(&mut body).await?;
                 return Err(anyhow!(
                     "error predicting edits.\nStatus: {:?}\nBody: {}",
                     response.status(),
                     body
                 ));
             }
+        }
+    }
 
-            let response = serde_json::from_str::<PredictEditsResponse>(&body)?;
-            let output_excerpt = response.output_excerpt;
-            log::debug!("prediction took: {:?}", start.elapsed());
-            log::debug!("completion response: {}", output_excerpt);
+    fn process_completion_response(
+        output_excerpt: String,
+        snapshot: &BufferSnapshot,
+        excerpt_range: Range<usize>,
+        path: Arc<Path>,
+        input_events: String,
+        input_excerpt: String,
+    ) -> Result<InlineCompletion> {
+        let content = output_excerpt.replace(CURSOR_MARKER, "");
 
-            let content = output_excerpt.replace(CURSOR_MARKER, "");
-            let mut new_text = content.as_str();
+        let codefence_start = content
+            .find(EDITABLE_REGION_START_MARKER)
+            .context("could not find start marker")?;
+        let content = &content[codefence_start..];
 
-            let codefence_start = new_text
-                .find(EDITABLE_REGION_START_MARKER)
-                .context("could not find start marker")?;
-            new_text = &new_text[codefence_start..];
+        let newline_ix = content.find('\n').context("could not find newline")?;
+        let content = &content[newline_ix + 1..];
 
-            let newline_ix = new_text.find('\n').context("could not find newline")?;
-            new_text = &new_text[newline_ix + 1..];
+        let codefence_end = content
+            .rfind(&format!("\n{EDITABLE_REGION_END_MARKER}"))
+            .context("could not find end marker")?;
+        let new_text = &content[..codefence_end];
 
-            let codefence_end = new_text
-                .rfind(&format!("\n{EDITABLE_REGION_END_MARKER}"))
-                .context("could not find end marker")?;
-            new_text = &new_text[..codefence_end];
-            log::debug!("sanitized completion response: {}", new_text);
+        let old_text = snapshot
+            .text_for_range(excerpt_range.clone())
+            .collect::<String>();
 
-            let old_text = snapshot
-                .text_for_range(excerpt_range.clone())
-                .collect::<String>();
+        let edits = Self::compute_edits(old_text, new_text, excerpt_range.start, snapshot);
 
-            let diff = similar::TextDiff::from_chars(old_text.as_str(), new_text);
+        Ok(InlineCompletion {
+            id: InlineCompletionId::new(),
+            path,
+            excerpt_range,
+            edits: edits.into(),
+            snapshot: snapshot.clone(),
+            input_events: input_events.into(),
+            input_excerpt: input_excerpt.into(),
+            output_excerpt: output_excerpt.into(),
+        })
+    }
 
-            let mut edits: Vec<(Range<usize>, String)> = Vec::new();
-            let mut old_start = excerpt_range.start;
-            for change in diff.iter_all_changes() {
-                let value = change.value();
-                match change.tag() {
-                    similar::ChangeTag::Equal => {
-                        old_start += value.len();
-                    }
-                    similar::ChangeTag::Delete => {
-                        let old_end = old_start + value.len();
-                        if let Some((last_old_range, _)) = edits.last_mut() {
-                            if last_old_range.end == old_start {
-                                last_old_range.end = old_end;
-                            } else {
-                                edits.push((old_start..old_end, String::new()));
-                            }
+    fn compute_edits(
+        old_text: String,
+        new_text: &str,
+        offset: usize,
+        snapshot: &BufferSnapshot,
+    ) -> Vec<(Range<Anchor>, String)> {
+        let diff = similar::TextDiff::from_chars(old_text.as_str(), new_text);
+
+        let mut edits: Vec<(Range<usize>, String)> = Vec::new();
+        let mut old_start = offset;
+        for change in diff.iter_all_changes() {
+            let value = change.value();
+            match change.tag() {
+                similar::ChangeTag::Equal => {
+                    old_start += value.len();
+                }
+                similar::ChangeTag::Delete => {
+                    let old_end = old_start + value.len();
+                    if let Some((last_old_range, _)) = edits.last_mut() {
+                        if last_old_range.end == old_start {
+                            last_old_range.end = old_end;
                         } else {
                             edits.push((old_start..old_end, String::new()));
                         }
-
-                        old_start = old_end;
+                    } else {
+                        edits.push((old_start..old_end, String::new()));
                     }
-                    similar::ChangeTag::Insert => {
-                        if let Some((last_old_range, last_new_text)) = edits.last_mut() {
-                            if last_old_range.end == old_start {
-                                last_new_text.push_str(value);
-                            } else {
-                                edits.push((old_start..old_start, value.into()));
-                            }
+                    old_start = old_end;
+                }
+                similar::ChangeTag::Insert => {
+                    if let Some((last_old_range, last_new_text)) = edits.last_mut() {
+                        if last_old_range.end == old_start {
+                            last_new_text.push_str(value);
                         } else {
                             edits.push((old_start..old_start, value.into()));
                         }
+                    } else {
+                        edits.push((old_start..old_start, value.into()));
                     }
                 }
             }
+        }
 
-            let edits = edits
-                .into_iter()
-                .map(|(mut old_range, new_text)| {
-                    let prefix_len = common_prefix(
-                        snapshot.chars_for_range(old_range.clone()),
-                        new_text.chars(),
-                    );
-                    old_range.start += prefix_len;
-                    let suffix_len = common_prefix(
-                        snapshot.reversed_chars_for_range(old_range.clone()),
-                        new_text[prefix_len..].chars().rev(),
-                    );
-                    old_range.end = old_range.end.saturating_sub(suffix_len);
-
-                    let new_text = new_text[prefix_len..new_text.len() - suffix_len].to_string();
-                    (
-                        snapshot.anchor_after(old_range.start)
-                            ..snapshot.anchor_before(old_range.end),
-                        new_text,
-                    )
-                })
-                .collect();
-            let inline_completion = InlineCompletion {
-                id: InlineCompletionId::new(),
-                path,
-                excerpt_range,
-                edits,
-                snapshot,
-                input_events: input_events.into(),
-                input_excerpt: input_excerpt.into(),
-                output_excerpt: output_excerpt.into(),
-            };
-            this.update(&mut cx, |this, cx| {
-                this.recent_completions
-                    .push_front(inline_completion.clone());
-                if this.recent_completions.len() > 50 {
-                    this.recent_completions.pop_back();
-                }
-                cx.notify();
-            })?;
-
-            Ok(inline_completion)
-        })
+        edits
+            .into_iter()
+            .map(|(mut old_range, new_text)| {
+                let prefix_len = common_prefix(
+                    snapshot.chars_for_range(old_range.clone()),
+                    new_text.chars(),
+                );
+                old_range.start += prefix_len;
+                let suffix_len = common_prefix(
+                    snapshot.reversed_chars_for_range(old_range.clone()),
+                    new_text[prefix_len..].chars().rev(),
+                );
+                old_range.end = old_range.end.saturating_sub(suffix_len);
+
+                let new_text = new_text[prefix_len..new_text.len() - suffix_len].to_string();
+                (
+                    snapshot.anchor_after(old_range.start)..snapshot.anchor_before(old_range.end),
+                    new_text,
+                )
+            })
+            .collect()
     }
 
     pub fn is_completion_rated(&self, completion_id: InlineCompletionId) -> bool {