assistant2: Show accept terms UI in thread empty state (#23630)

Agus Zubiaga and Danilo created

<img
src="https://github.com/user-attachments/assets/cea93cfb-8a40-48c4-9d90-f1751c79603b"
width=400>



Release Notes:

- N/A

---------

Co-authored-by: Danilo <danilo@zed.dev>

Change summary

crates/assistant2/src/assistant_panel.rs              |  97 ++++---
crates/assistant_context_editor/src/context_editor.rs |  15 
crates/language_model/src/language_model.rs           |  13 
crates/language_models/src/provider/cloud.rs          | 174 +++++-------
4 files changed, 155 insertions(+), 144 deletions(-)

Detailed changes

crates/assistant2/src/assistant_panel.rs 🔗

@@ -9,6 +9,7 @@ use assistant_context_editor::{
 use assistant_settings::{AssistantDockPosition, AssistantSettings};
 use assistant_slash_command::SlashCommandWorkingSet;
 use assistant_tool::ToolWorkingSet;
+
 use client::zed_urls;
 use editor::Editor;
 use fs::Fs;
@@ -18,7 +19,7 @@ use gpui::{
     ViewContext, WeakView, WindowContext,
 };
 use language::LanguageRegistry;
-use language_model::LanguageModelRegistry;
+use language_model::{LanguageModelProviderTosView, LanguageModelRegistry};
 use project::Project;
 use prompt_library::{open_prompt_library, PromptBuilder, PromptLibrary};
 use settings::{update_settings_file, Settings};
@@ -663,17 +664,16 @@ impl AssistantPanel {
     }
 
     fn configuration_error(&self, cx: &AppContext) -> Option<ConfigurationError> {
-        let provider = LanguageModelRegistry::read_global(cx).active_provider();
-        let is_authenticated = provider
-            .as_ref()
-            .map_or(false, |provider| provider.is_authenticated(cx));
+        let Some(provider) = LanguageModelRegistry::read_global(cx).active_provider() else {
+            return Some(ConfigurationError::NoProvider);
+        };
 
-        if provider.is_some() && is_authenticated {
-            return None;
+        if !provider.is_authenticated(cx) {
+            return Some(ConfigurationError::ProviderNotAuthenticated);
         }
 
-        if !is_authenticated {
-            return Some(ConfigurationError::ProviderNotAuthenticated);
+        if provider.must_accept_terms(cx) {
+            return Some(ConfigurationError::ProviderPendingTermsAcceptance(provider));
         }
 
         None
@@ -691,6 +691,9 @@ impl AssistantPanel {
                 .child(Headline::new("Welcome to the Assistant Panel").size(HeadlineSize::Small))
         };
 
+        let configuration_error = self.configuration_error(cx);
+        let no_error = configuration_error.is_none();
+
         v_flex()
             .gap_2()
             .child(
@@ -704,41 +707,51 @@ impl AssistantPanel {
                         .mb_4(),
                 ),
             )
-            .when(
-                matches!(
-                    self.configuration_error(cx),
-                    Some(ConfigurationError::ProviderNotAuthenticated)
-                ),
-                |parent| {
-                    parent.child(
-                        v_flex()
-                            .gap_0p5()
-                            .child(create_welcome_heading())
-                            .child(
-                                h_flex().mb_2().w_full().justify_center().child(
-                                    Label::new(
-                                        "To start using the assistant, configure at least one LLM provider.",
-                                    )
-                                    .color(Color::Muted),
-                                ),
-                            )
-                            .child(
-                                h_flex().w_full().justify_center().child(
-                                    Button::new("open-configuration", "Configure a Provider")
-                                        .size(ButtonSize::Compact)
-                                        .icon(Some(IconName::Sliders))
-                                        .icon_size(IconSize::Small)
-                                        .icon_position(IconPosition::Start)
-                                        .on_click(cx.listener(|this, _, cx| {
-                                            this.open_configuration(cx);
-                                        })),
+            .map(|parent| {
+                match configuration_error {
+                    Some(ConfigurationError::ProviderNotAuthenticated) | Some(ConfigurationError::NoProvider)  => {
+                        parent.child(
+                            v_flex()
+                                .gap_0p5()
+                                .child(create_welcome_heading())
+                                .child(
+                                    h_flex().mb_2().w_full().justify_center().child(
+                                        Label::new(
+                                            "To start using the assistant, configure at least one LLM provider.",
+                                        )
+                                        .color(Color::Muted),
+                                    ),
+                                )
+                                .child(
+                                    h_flex().w_full().justify_center().child(
+                                        Button::new("open-configuration", "Configure a Provider")
+                                            .size(ButtonSize::Compact)
+                                            .icon(Some(IconName::Sliders))
+                                            .icon_size(IconSize::Small)
+                                            .icon_position(IconPosition::Start)
+                                            .on_click(cx.listener(|this, _, cx| {
+                                                this.open_configuration(cx);
+                                            })),
+                                    ),
                                 ),
-                            ),
-                    )
-                },
-            )
+                        )
+                    }
+                    Some(ConfigurationError::ProviderPendingTermsAcceptance(provider)) => {
+                        parent.child(
+                            v_flex()
+                                .gap_0p5()
+                                .child(create_welcome_heading())
+                                .children(provider.render_accept_terms(
+                                    LanguageModelProviderTosView::ThreadEmptyState,
+                                    cx,
+                                )),
+                        )
+                    }
+                    None => parent,
+                }
+            })
             .when(
-                recent_threads.is_empty() && self.configuration_error(cx).is_none(),
+                recent_threads.is_empty() && no_error,
                 |parent| {
                     parent.child(
                         v_flex().gap_0p5().child(create_welcome_heading()).child(

crates/assistant_context_editor/src/context_editor.rs 🔗

@@ -31,7 +31,10 @@ use gpui::{
 };
 use indexed_docs::IndexedDocsStore;
 use language::{language_settings::SoftWrap, BufferSnapshot, LspAdapterDelegate, ToOffset};
-use language_model::{LanguageModelImage, LanguageModelRegistry, LanguageModelToolUse, Role};
+use language_model::{
+    LanguageModelImage, LanguageModelProvider, LanguageModelProviderTosView, LanguageModelRegistry,
+    LanguageModelToolUse, Role,
+};
 use language_model_selector::{LanguageModelSelector, LanguageModelSelectorPopoverMenu};
 use multi_buffer::MultiBufferRow;
 use picker::Picker;
@@ -2260,6 +2263,9 @@ impl ContextEditor {
             let label = match configuration_error {
                 ConfigurationError::NoProvider => "No LLM provider selected.",
                 ConfigurationError::ProviderNotAuthenticated => "LLM provider is not configured.",
+                ConfigurationError::ProviderPendingTermsAcceptance(_) => {
+                    "LLM provider requires accepting the Terms of Service."
+                }
             };
             Some(
                 h_flex()
@@ -2855,9 +2861,9 @@ impl Render for ContextEditor {
     fn render(&mut self, cx: &mut ViewContext<Self>) -> impl IntoElement {
         let provider = LanguageModelRegistry::read_global(cx).active_provider();
         let accept_terms = if self.show_accept_terms {
-            provider
-                .as_ref()
-                .and_then(|provider| provider.render_accept_terms(cx))
+            provider.as_ref().and_then(|provider| {
+                provider.render_accept_terms(LanguageModelProviderTosView::PromptEditorPopup, cx)
+            })
         } else {
             None
         };
@@ -3502,6 +3508,7 @@ fn size_for_image(data: &RenderImage, max_size: Size<Pixels>) -> Size<Pixels> {
 pub enum ConfigurationError {
     NoProvider,
     ProviderNotAuthenticated,
+    ProviderPendingTermsAcceptance(Arc<dyn LanguageModelProvider>),
 }
 
 fn configuration_error(cx: &AppContext) -> Option<ConfigurationError> {

crates/language_model/src/language_model.rs 🔗

@@ -245,12 +245,23 @@ pub trait LanguageModelProvider: 'static {
     fn must_accept_terms(&self, _cx: &AppContext) -> bool {
         false
     }
-    fn render_accept_terms(&self, _cx: &mut WindowContext) -> Option<AnyElement> {
+    fn render_accept_terms(
+        &self,
+        _view: LanguageModelProviderTosView,
+        _cx: &mut WindowContext,
+    ) -> Option<AnyElement> {
         None
     }
     fn reset_credentials(&self, cx: &mut AppContext) -> Task<Result<()>>;
 }
 
+#[derive(PartialEq, Eq)]
+pub enum LanguageModelProviderTosView {
+    ThreadEmptyState,
+    PromptEditorPopup,
+    Configuration,
+}
+
 pub trait LanguageModelProviderState: 'static {
     type ObservableEntity;
 

crates/language_models/src/provider/cloud.rs 🔗

@@ -12,14 +12,14 @@ use futures::{
     TryStreamExt as _,
 };
 use gpui::{
-    AnyElement, AnyView, AppContext, AsyncAppContext, EventEmitter, FontWeight, Global, Model,
-    ModelContext, ReadGlobal, Subscription, Task,
+    AnyElement, AnyView, AppContext, AsyncAppContext, EventEmitter, Global, Model, ModelContext,
+    ReadGlobal, Subscription, Task,
 };
 use http_client::{AsyncBody, HttpClient, Method, Response, StatusCode};
 use language_model::{
     CloudModel, LanguageModel, LanguageModelCacheConfiguration, LanguageModelId, LanguageModelName,
     LanguageModelProviderId, LanguageModelProviderName, LanguageModelProviderState,
-    LanguageModelRequest, RateLimiter, ZED_CLOUD_PROVIDER_ID,
+    LanguageModelProviderTosView, LanguageModelRequest, RateLimiter, ZED_CLOUD_PROVIDER_ID,
 };
 use language_model::{
     LanguageModelAvailability, LanguageModelCompletionEvent, LanguageModelProvider,
@@ -378,60 +378,12 @@ impl LanguageModelProvider for CloudLanguageModelProvider {
         !self.state.read(cx).has_accepted_terms_of_service(cx)
     }
 
-    fn render_accept_terms(&self, cx: &mut WindowContext) -> Option<AnyElement> {
-        let state = self.state.read(cx);
-
-        let terms = [(
-            "terms_of_service",
-            "Terms of Service",
-            "https://zed.dev/terms-of-service",
-        )]
-        .map(|(id, label, url)| {
-            Button::new(id, label)
-                .style(ButtonStyle::Subtle)
-                .icon(IconName::ExternalLink)
-                .icon_size(IconSize::XSmall)
-                .icon_color(Color::Muted)
-                .on_click(move |_, cx| cx.open_url(url))
-        });
-
-        if state.has_accepted_terms_of_service(cx) {
-            None
-        } else {
-            let disabled = state.accept_terms.is_some();
-            Some(
-                v_flex()
-                    .gap_2()
-                    .child(
-                        v_flex()
-                            .child(Label::new("Terms and Conditions").weight(FontWeight::MEDIUM))
-                            .child(
-                                Label::new(
-                                    "Please read and accept our terms and conditions to continue.",
-                                )
-                                .size(LabelSize::Small),
-                            ),
-                    )
-                    .child(v_flex().gap_1().children(terms))
-                    .child(
-                        h_flex().justify_end().child(
-                            Button::new("accept_terms", "I've read it and accept it")
-                                .disabled(disabled)
-                                .on_click({
-                                    let state = self.state.downgrade();
-                                    move |_, cx| {
-                                        state
-                                            .update(cx, |state, cx| {
-                                                state.accept_terms_of_service(cx)
-                                            })
-                                            .ok();
-                                    }
-                                }),
-                        ),
-                    )
-                    .into_any(),
-            )
-        }
+    fn render_accept_terms(
+        &self,
+        view: LanguageModelProviderTosView,
+        cx: &mut WindowContext,
+    ) -> Option<AnyElement> {
+        render_accept_terms(self.state.clone(), view, cx)
     }
 
     fn reset_credentials(&self, _cx: &mut AppContext) -> Task<Result<()>> {
@@ -439,6 +391,68 @@ impl LanguageModelProvider for CloudLanguageModelProvider {
     }
 }
 
+fn render_accept_terms(
+    state: Model<State>,
+    view_kind: LanguageModelProviderTosView,
+    cx: &mut WindowContext,
+) -> Option<AnyElement> {
+    if state.read(cx).has_accepted_terms_of_service(cx) {
+        return None;
+    }
+
+    let accept_terms_disabled = state.read(cx).accept_terms.is_some();
+
+    let terms_button = Button::new("terms_of_service", "Terms of Service")
+        .style(ButtonStyle::Subtle)
+        .icon(IconName::ArrowUpRight)
+        .icon_color(Color::Muted)
+        .icon_size(IconSize::XSmall)
+        .on_click(move |_, cx| cx.open_url("https://zed.dev/terms-of-service"));
+
+    let text = "To start using Zed AI, please read and accept the";
+
+    let form = v_flex()
+        .w_full()
+        .gap_2()
+        .when(
+            view_kind == LanguageModelProviderTosView::ThreadEmptyState,
+            |form| form.items_center(),
+        )
+        .child(
+            h_flex()
+                .flex_wrap()
+                .when(
+                    view_kind == LanguageModelProviderTosView::ThreadEmptyState,
+                    |form| form.justify_center(),
+                )
+                .child(Label::new(text))
+                .child(terms_button),
+        )
+        .child({
+            let button_container = h_flex().w_full().child(
+                Button::new("accept_terms", "I accept the Terms of Service")
+                    .style(ButtonStyle::Tinted(TintColor::Accent))
+                    .disabled(accept_terms_disabled)
+                    .on_click({
+                        let state = state.downgrade();
+                        move |_, cx| {
+                            state
+                                .update(cx, |state, cx| state.accept_terms_of_service(cx))
+                                .ok();
+                        }
+                    }),
+            );
+
+            match view_kind {
+                LanguageModelProviderTosView::ThreadEmptyState => button_container.justify_center(),
+                LanguageModelProviderTosView::PromptEditorPopup => button_container.justify_end(),
+                LanguageModelProviderTosView::Configuration => button_container.justify_start(),
+            }
+        });
+
+    Some(form.into_any())
+}
+
 pub struct CloudLanguageModel {
     id: LanguageModelId,
     model: CloudModel,
@@ -852,44 +866,6 @@ impl ConfigurationView {
         });
         cx.notify();
     }
-
-    fn render_accept_terms(&mut self, cx: &mut ViewContext<Self>) -> Option<AnyElement> {
-        if self.state.read(cx).has_accepted_terms_of_service(cx) {
-            return None;
-        }
-
-        let accept_terms_disabled = self.state.read(cx).accept_terms.is_some();
-
-        let terms_button = Button::new("terms_of_service", "Terms of Service")
-            .style(ButtonStyle::Subtle)
-            .icon(IconName::ArrowUpRight)
-            .icon_color(Color::Muted)
-            .icon_size(IconSize::XSmall)
-            .on_click(move |_, cx| cx.open_url("https://zed.dev/terms-of-service"));
-
-        let text = "To start using Zed AI, please read and accept the";
-
-        let form = v_flex()
-            .gap_1()
-            .child(h_flex().child(Label::new(text)).child(terms_button))
-            .child(
-                h_flex().child(
-                    Button::new("accept_terms", "I've read and accept the Terms of Service")
-                        .style(ButtonStyle::Tinted(TintColor::Accent))
-                        .disabled(accept_terms_disabled)
-                        .on_click({
-                            let state = self.state.downgrade();
-                            move |_, cx| {
-                                state
-                                    .update(cx, |state, cx| state.accept_terms_of_service(cx))
-                                    .ok();
-                            }
-                        }),
-                ),
-            );
-
-        Some(form.into_any())
-    }
 }
 
 impl Render for ConfigurationView {
@@ -939,8 +915,12 @@ impl Render for ConfigurationView {
         if is_connected {
             v_flex()
                 .gap_3()
-                .max_w_4_5()
-                .children(self.render_accept_terms(cx))
+                .w_full()
+                .children(render_accept_terms(
+                    self.state.clone(),
+                    LanguageModelProviderTosView::Configuration,
+                    cx,
+                ))
                 .when(has_accepted_terms, |this| {
                     this.child(subscription_text)
                         .children(manage_subscription_button)