assistant: Require user to accept TOS for cloud provider (#16111)

Thorsten Ball created

This adds the requirement for users to accept the terms of service the
first time they send a message with the Cloud provider.

Once this is out and in a nightly, we need to add the check to the
server side too, to authenticate access to the models.

Demo:


https://github.com/user-attachments/assets/0edebf74-8120-4fa2-b801-bb76f04e8a17



Release Notes:

- N/A

Change summary

crates/assistant/src/assistant_panel.rs                         | 44 +
crates/client/src/test.rs                                       |  7 
crates/client/src/user.rs                                       | 53 +
crates/collab/migrations.sqlite/20221109000000_test_schema.sql  |  3 
crates/collab/migrations/20240812073542_add_accepted_tos_at.sql |  1 
crates/collab/src/db/queries/users.rs                           | 20 
crates/collab/src/db/tables/user.rs                             |  1 
crates/collab/src/db/tests.rs                                   |  1 
crates/collab/src/db/tests/user_tests.rs                        | 45 +
crates/collab/src/rpc.rs                                        | 21 
crates/language_model/src/language_model.rs                     | 10 
crates/language_model/src/provider/cloud.rs                     | 84 ++
crates/proto/proto/zed.proto                                    | 13 
crates/proto/src/proto.rs                                       |  3 
14 files changed, 297 insertions(+), 9 deletions(-)

Detailed changes

crates/assistant/src/assistant_panel.rs 🔗

@@ -490,6 +490,7 @@ impl AssistantPanel {
                     }
                     language_model::Event::ProviderStateChanged => {
                         this.ensure_authenticated(cx);
+                        cx.notify()
                     }
                     language_model::Event::AddedProvider(_)
                     | language_model::Event::RemovedProvider(_) => {
@@ -1712,6 +1713,7 @@ pub struct ContextEditor {
     assistant_panel: WeakView<AssistantPanel>,
     error_message: Option<SharedString>,
     debug_inspector: Option<ContextInspector>,
+    show_accept_terms: bool,
 }
 
 const DEFAULT_TAB_TITLE: &str = "New Context";
@@ -1772,6 +1774,7 @@ impl ContextEditor {
             assistant_panel,
             error_message: None,
             debug_inspector: None,
+            show_accept_terms: false,
         };
         this.update_message_headers(cx);
         this.insert_slash_command_output_sections(sections, cx);
@@ -1804,6 +1807,16 @@ impl ContextEditor {
     }
 
     fn assist(&mut self, _: &Assist, cx: &mut ViewContext<Self>) {
+        let provider = LanguageModelRegistry::read_global(cx).active_provider();
+        if provider
+            .as_ref()
+            .map_or(false, |provider| provider.must_accept_terms(cx))
+        {
+            self.show_accept_terms = true;
+            cx.notify();
+            return;
+        }
+
         if !self.apply_active_workflow_step(cx) {
             self.error_message = None;
             self.send_to_model(cx);
@@ -3388,7 +3401,14 @@ impl ContextEditor {
             None => (ButtonStyle::Filled, None),
         };
 
+        let provider = LanguageModelRegistry::read_global(cx).active_provider();
+        let disabled = self.show_accept_terms
+            && provider
+                .as_ref()
+                .map_or(false, |provider| provider.must_accept_terms(cx));
+
         ButtonLike::new("send_button")
+            .disabled(disabled)
             .style(style)
             .when_some(tooltip, |button, tooltip| {
                 button.tooltip(move |_| tooltip.clone())
@@ -3437,6 +3457,15 @@ impl EventEmitter<SearchEvent> for ContextEditor {}
 
 impl Render for ContextEditor {
     fn render(&mut self, cx: &mut ViewContext<Self>) -> impl IntoElement {
+        let provider = LanguageModelRegistry::read_global(cx).active_provider();
+        let accept_terms = if self.show_accept_terms {
+            provider
+                .as_ref()
+                .and_then(|provider| provider.render_accept_terms(cx))
+        } else {
+            None
+        };
+
         v_flex()
             .key_context("ContextEditor")
             .capture_action(cx.listener(ContextEditor::cancel))
@@ -3455,6 +3484,21 @@ impl Render for ContextEditor {
                     .bg(cx.theme().colors().editor_background)
                     .child(self.editor.clone()),
             )
+            .when_some(accept_terms, |this, element| {
+                this.child(
+                    div()
+                        .absolute()
+                        .right_4()
+                        .bottom_10()
+                        .max_w_96()
+                        .py_2()
+                        .px_3()
+                        .elevation_2(cx)
+                        .bg(cx.theme().colors().surface_background)
+                        .occlude()
+                        .child(element),
+                )
+            })
             .child(
                 h_flex().flex_none().relative().child(
                     h_flex()

crates/client/src/test.rs 🔗

@@ -1,5 +1,6 @@
 use crate::{Client, Connection, Credentials, EstablishConnectionError, UserStore};
 use anyhow::{anyhow, Result};
+use chrono::Duration;
 use futures::{stream::BoxStream, StreamExt};
 use gpui::{BackgroundExecutor, Context, Model, TestAppContext};
 use parking_lot::Mutex;
@@ -162,6 +163,11 @@ impl FakeServer {
                 return Ok(*message.downcast().unwrap());
             }
 
+            let accepted_tos_at = chrono::Utc::now()
+                .checked_sub_signed(Duration::hours(5))
+                .expect("failed to build accepted_tos_at")
+                .timestamp() as u64;
+
             if message.is::<TypedEnvelope<GetPrivateUserInfo>>() {
                 self.respond(
                     message
@@ -172,6 +178,7 @@ impl FakeServer {
                         metrics_id: "the-metrics-id".into(),
                         staff: false,
                         flags: Default::default(),
+                        accepted_tos_at: Some(accepted_tos_at),
                     },
                 );
                 continue;

crates/client/src/user.rs 🔗

@@ -1,5 +1,6 @@
 use super::{proto, Client, Status, TypedEnvelope};
 use anyhow::{anyhow, Context, Result};
+use chrono::{DateTime, Utc};
 use collections::{hash_map::Entry, HashMap, HashSet};
 use feature_flags::FeatureFlagAppExt;
 use futures::{channel::mpsc, Future, StreamExt};
@@ -94,6 +95,7 @@ pub struct UserStore {
     update_contacts_tx: mpsc::UnboundedSender<UpdateContacts>,
     current_plan: Option<proto::Plan>,
     current_user: watch::Receiver<Option<Arc<User>>>,
+    accepted_tos_at: Option<Option<DateTime<Utc>>>,
     contacts: Vec<Arc<Contact>>,
     incoming_contact_requests: Vec<Arc<User>>,
     outgoing_contact_requests: Vec<Arc<User>>,
@@ -150,6 +152,7 @@ impl UserStore {
             by_github_login: Default::default(),
             current_user: current_user_rx,
             current_plan: None,
+            accepted_tos_at: None,
             contacts: Default::default(),
             incoming_contact_requests: Default::default(),
             participant_indices: Default::default(),
@@ -189,9 +192,10 @@ impl UserStore {
                                 } else {
                                     break;
                                 };
-                                let fetch_metrics_id =
+                                let fetch_private_user_info =
                                     client.request(proto::GetPrivateUserInfo {}).log_err();
-                                let (user, info) = futures::join!(fetch_user, fetch_metrics_id);
+                                let (user, info) =
+                                    futures::join!(fetch_user, fetch_private_user_info);
 
                                 cx.update(|cx| {
                                     if let Some(info) = info {
@@ -202,9 +206,17 @@ impl UserStore {
                                         client.telemetry.set_authenticated_user_info(
                                             Some(info.metrics_id.clone()),
                                             staff,
-                                        )
+                                        );
+
+                                        this.update(cx, |this, _| {
+                                            this.set_current_user_accepted_tos_at(
+                                                info.accepted_tos_at,
+                                            );
+                                        })
+                                    } else {
+                                        anyhow::Ok(())
                                     }
-                                })?;
+                                })??;
 
                                 current_user_tx.send(user).await.ok();
 
@@ -680,6 +692,39 @@ impl UserStore {
         self.current_user.clone()
     }
 
+    pub fn current_user_has_accepted_terms(&self) -> Option<bool> {
+        self.accepted_tos_at
+            .map(|accepted_tos_at| accepted_tos_at.is_some())
+    }
+
+    pub fn accept_terms_of_service(&mut self, cx: &mut ModelContext<Self>) -> Task<Result<()>> {
+        if self.current_user().is_none() {
+            return Task::ready(Err(anyhow!("no current user")));
+        };
+
+        let client = self.client.clone();
+        cx.spawn(move |this, mut cx| async move {
+            if let Some(client) = client.upgrade() {
+                let response = client
+                    .request(proto::AcceptTermsOfService {})
+                    .await
+                    .context("error accepting tos")?;
+
+                this.update(&mut cx, |this, _| {
+                    this.set_current_user_accepted_tos_at(Some(response.accepted_tos_at))
+                })
+            } else {
+                Err(anyhow!("client not found"))
+            }
+        })
+    }
+
+    fn set_current_user_accepted_tos_at(&mut self, accepted_tos_at: Option<u64>) {
+        self.accepted_tos_at = Some(
+            accepted_tos_at.and_then(|timestamp| DateTime::from_timestamp(timestamp as i64, 0)),
+        );
+    }
+
     fn load_users(
         &mut self,
         request: impl RequestMessage<Response = UsersResponse>,

crates/collab/migrations.sqlite/20221109000000_test_schema.sql 🔗

@@ -9,7 +9,8 @@ CREATE TABLE "users" (
     "connected_once" BOOLEAN NOT NULL DEFAULT false,
     "created_at" TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
     "metrics_id" TEXT,
-    "github_user_id" INTEGER
+    "github_user_id" INTEGER,
+    "accepted_tos_at" TIMESTAMP WITHOUT TIME ZONE
 );
 CREATE UNIQUE INDEX "index_users_github_login" ON "users" ("github_login");
 CREATE UNIQUE INDEX "index_invite_code_users" ON "users" ("invite_code");

crates/collab/src/db/queries/users.rs 🔗

@@ -225,6 +225,26 @@ impl Database {
         .await
     }
 
+    /// Sets "accepted_tos_at" on the user to the given timestamp.
+    pub async fn set_user_accepted_tos_at(
+        &self,
+        id: UserId,
+        accepted_tos_at: Option<DateTime>,
+    ) -> Result<()> {
+        self.transaction(|tx| async move {
+            user::Entity::update_many()
+                .filter(user::Column::Id.eq(id))
+                .set(user::ActiveModel {
+                    accepted_tos_at: ActiveValue::set(accepted_tos_at),
+                    ..Default::default()
+                })
+                .exec(&*tx)
+                .await?;
+            Ok(())
+        })
+        .await
+    }
+
     /// hard delete the user.
     pub async fn destroy_user(&self, id: UserId) -> Result<()> {
         self.transaction(|tx| async move {

crates/collab/src/db/tables/user.rs 🔗

@@ -18,6 +18,7 @@ pub struct Model {
     pub connected_once: bool,
     pub metrics_id: Uuid,
     pub created_at: DateTime,
+    pub accepted_tos_at: Option<DateTime>,
 }
 
 #[derive(Copy, Clone, Debug, EnumIter, DeriveRelation)]

crates/collab/src/db/tests.rs 🔗

@@ -10,6 +10,7 @@ mod extension_tests;
 mod feature_flag_tests;
 mod message_tests;
 mod processed_stripe_event_tests;
+mod user_tests;
 
 use crate::migrations::run_database_migrations;
 

crates/collab/src/db/tests/user_tests.rs 🔗

@@ -0,0 +1,45 @@
+use chrono::Utc;
+
+use crate::{
+    db::{Database, NewUserParams},
+    test_both_dbs,
+};
+use std::sync::Arc;
+
+test_both_dbs!(
+    test_accepted_tos,
+    test_accepted_tos_postgres,
+    test_accepted_tos_sqlite
+);
+
+async fn test_accepted_tos(db: &Arc<Database>) {
+    let user_id = db
+        .create_user(
+            "user1@example.com",
+            false,
+            NewUserParams {
+                github_login: "user1".to_string(),
+                github_user_id: 1,
+            },
+        )
+        .await
+        .unwrap()
+        .user_id;
+
+    let user = db.get_user_by_id(user_id).await.unwrap().unwrap();
+    assert!(user.accepted_tos_at.is_none());
+
+    let accepted_tos_at = Utc::now().naive_utc();
+    db.set_user_accepted_tos_at(user_id, Some(accepted_tos_at))
+        .await
+        .unwrap();
+
+    let user = db.get_user_by_id(user_id).await.unwrap().unwrap();
+    assert!(user.accepted_tos_at.is_some());
+    assert_eq!(user.accepted_tos_at, Some(accepted_tos_at));
+
+    db.set_user_accepted_tos_at(user_id, None).await.unwrap();
+
+    let user = db.get_user_by_id(user_id).await.unwrap().unwrap();
+    assert!(user.accepted_tos_at.is_none());
+}

crates/collab/src/rpc.rs 🔗

@@ -31,6 +31,7 @@ use axum::{
     routing::get,
     Extension, Router, TypedHeader,
 };
+use chrono::Utc;
 use collections::{HashMap, HashSet};
 pub use connection_pool::{ConnectionPool, ZedVersion};
 use core::fmt::{self, Debug, Formatter};
@@ -604,6 +605,7 @@ impl Server {
             .add_message_handler(user_message_handler(update_followers))
             .add_request_handler(user_handler(get_private_user_info))
             .add_request_handler(user_handler(get_llm_api_token))
+            .add_request_handler(user_handler(accept_terms_of_service))
             .add_message_handler(user_message_handler(acknowledge_channel_message))
             .add_message_handler(user_message_handler(acknowledge_buffer_version))
             .add_request_handler(user_handler(get_supermaven_api_key))
@@ -4882,6 +4884,25 @@ async fn get_private_user_info(
         metrics_id,
         staff: user.admin,
         flags,
+        accepted_tos_at: user.accepted_tos_at.map(|t| t.and_utc().timestamp() as u64),
+    })?;
+    Ok(())
+}
+
+/// Accept the terms of service (tos) on behalf of the current user
+async fn accept_terms_of_service(
+    _request: proto::AcceptTermsOfService,
+    response: Response<proto::AcceptTermsOfService>,
+    session: UserSession,
+) -> Result<()> {
+    let db = session.db().await;
+
+    let accepted_tos_at = Utc::now();
+    db.set_user_accepted_tos_at(session.user_id(), Some(accepted_tos_at.naive_utc()))
+        .await?;
+
+    response.send(proto::AcceptTermsOfServiceResponse {
+        accepted_tos_at: accepted_tos_at.timestamp() as u64,
     })?;
     Ok(())
 }

crates/language_model/src/language_model.rs 🔗

@@ -9,7 +9,9 @@ pub mod settings;
 use anyhow::Result;
 use client::{Client, UserStore};
 use futures::{future::BoxFuture, stream::BoxStream};
-use gpui::{AnyView, AppContext, AsyncAppContext, Model, SharedString, Task, WindowContext};
+use gpui::{
+    AnyElement, AnyView, AppContext, AsyncAppContext, Model, SharedString, Task, WindowContext,
+};
 pub use model::*;
 use project::Fs;
 use proto::Plan;
@@ -114,6 +116,12 @@ pub trait LanguageModelProvider: 'static {
     fn is_authenticated(&self, cx: &AppContext) -> bool;
     fn authenticate(&self, cx: &mut AppContext) -> Task<Result<()>>;
     fn configuration_view(&self, cx: &mut WindowContext) -> AnyView;
+    fn must_accept_terms(&self, _cx: &AppContext) -> bool {
+        false
+    }
+    fn render_accept_terms(&self, _cx: &mut WindowContext) -> Option<AnyElement> {
+        None
+    }
     fn reset_credentials(&self, cx: &mut AppContext) -> Task<Result<()>>;
 }
 

crates/language_model/src/provider/cloud.rs 🔗

@@ -9,7 +9,10 @@ use client::{Client, PerformCompletionParams, UserStore, EXPIRED_LLM_TOKEN_HEADE
 use collections::BTreeMap;
 use feature_flags::{FeatureFlagAppExt, LanguageModels};
 use futures::{future::BoxFuture, stream::BoxStream, AsyncBufReadExt, FutureExt, StreamExt};
-use gpui::{AnyView, AppContext, AsyncAppContext, Model, ModelContext, Subscription, Task};
+use gpui::{
+    AnyElement, AnyView, AppContext, AsyncAppContext, FontWeight, Model, ModelContext,
+    Subscription, Task,
+};
 use http_client::{AsyncBody, HttpClient, Method, Response};
 use schemars::JsonSchema;
 use serde::{Deserialize, Serialize};
@@ -62,6 +65,7 @@ pub struct State {
     client: Arc<Client>,
     user_store: Model<UserStore>,
     status: client::Status,
+    accept_terms: Option<Task<Result<()>>>,
     _subscription: Subscription,
 }
 
@@ -77,6 +81,26 @@ impl State {
             this.update(&mut cx, |_, cx| cx.notify())
         })
     }
+
+    fn has_accepted_terms_of_service(&self, cx: &AppContext) -> bool {
+        self.user_store
+            .read(cx)
+            .current_user_has_accepted_terms()
+            .unwrap_or(false)
+    }
+
+    fn accept_terms_of_service(&mut self, cx: &mut ModelContext<Self>) {
+        let user_store = self.user_store.clone();
+        self.accept_terms = Some(cx.spawn(move |this, mut cx| async move {
+            let _ = user_store
+                .update(&mut cx, |store, cx| store.accept_terms_of_service(cx))?
+                .await;
+            this.update(&mut cx, |this, cx| {
+                this.accept_terms = None;
+                cx.notify()
+            })
+        }));
+    }
 }
 
 impl CloudLanguageModelProvider {
@@ -88,6 +112,7 @@ impl CloudLanguageModelProvider {
             client: client.clone(),
             user_store,
             status,
+            accept_terms: None,
             _subscription: cx.observe_global::<SettingsStore>(|_, cx| {
                 cx.notify();
             }),
@@ -223,6 +248,57 @@ impl LanguageModelProvider for CloudLanguageModelProvider {
         .into()
     }
 
+    fn must_accept_terms(&self, cx: &AppContext) -> bool {
+        !self.state.read(cx).has_accepted_terms_of_service(cx)
+    }
+
+    fn render_accept_terms(&self, cx: &mut WindowContext) -> Option<AnyElement> {
+        let state = self.state.read(cx);
+
+        let terms = [(
+            "anthropic_terms_of_service",
+            "Anthropic Terms of Service",
+            "https://www.anthropic.com/legal/consumer-terms",
+        )]
+        .map(|(id, label, url)| {
+            Button::new(id, label)
+                .style(ButtonStyle::Subtle)
+                .icon(IconName::ExternalLink)
+                .icon_size(IconSize::XSmall)
+                .icon_color(Color::Muted)
+                .on_click(move |_, cx| cx.open_url(url))
+        });
+
+        if state.has_accepted_terms_of_service(cx) {
+            None
+        } else {
+            let disabled = state.accept_terms.is_some();
+            Some(
+                v_flex()
+                    .child(Label::new("Terms & Conditions").weight(FontWeight::SEMIBOLD))
+                    .child("Please read and accept the terms and conditions of Zed AI and our provider partners to continue.")
+                    .child(v_flex().m_2().gap_1().children(terms))
+                    .child(
+                        h_flex().justify_end().mt_1().child(
+                            Button::new("accept_terms", "Accept")
+                                .disabled(disabled)
+                                .on_click({
+                                    let state = self.state.downgrade();
+                                    move |_, cx| {
+                                        state
+                                            .update(cx, |state, cx| {
+                                                state.accept_terms_of_service(cx)
+                                            })
+                                            .ok();
+                                    }
+                                }),
+                        ),
+                    )
+                    .into_any(),
+            )
+        }
+    }
+
     fn reset_credentials(&self, _cx: &mut AppContext) -> Task<Result<()>> {
         Task::ready(Ok(()))
     }
@@ -766,6 +842,7 @@ impl Render for ConfigurationView {
 
         let is_connected = !self.state.read(cx).is_signed_out();
         let plan = self.state.read(cx).user_store.read(cx).current_plan();
+        let must_accept_terms = !self.state.read(cx).has_accepted_terms_of_service(cx);
 
         let is_pro = plan == Some(proto::Plan::ZedPro);
 
@@ -773,6 +850,11 @@ impl Render for ConfigurationView {
             v_flex()
                 .gap_3()
                 .max_w_4_5()
+                .when(must_accept_terms, |this| {
+                    this.child(Label::new(
+                        "You must accept the terms of service to use this provider.",
+                    ))
+                })
                 .child(Label::new(
                     if is_pro {
                         "You have full access to Zed's hosted models from Anthropic, OpenAI, Google with faster speeds and higher limits through Zed Pro."

crates/proto/proto/zed.proto 🔗

@@ -49,7 +49,7 @@ message Envelope {
         GetDefinition get_definition = 32;
         GetDefinitionResponse get_definition_response = 33;
         GetDeclaration get_declaration = 237;
-        GetDeclarationResponse get_declaration_response = 238; // current max
+        GetDeclarationResponse get_declaration_response = 238;
         GetTypeDefinition get_type_definition = 34;
         GetTypeDefinitionResponse get_type_definition_response = 35;
 
@@ -130,6 +130,8 @@ message Envelope {
         GetPrivateUserInfoResponse get_private_user_info_response = 103;
         UpdateUserPlan update_user_plan = 234;
         UpdateDiffBase update_diff_base = 104;
+        AcceptTermsOfService accept_terms_of_service = 239;
+        AcceptTermsOfServiceResponse accept_terms_of_service_response = 240; // current max
 
         OnTypeFormatting on_type_formatting = 105;
         OnTypeFormattingResponse on_type_formatting_response = 106;
@@ -270,7 +272,7 @@ message Envelope {
         AddWorktreeResponse add_worktree_response = 223;
 
         GetLlmToken get_llm_token = 235;
-        GetLlmTokenResponse get_llm_token_response = 236; // current max
+        GetLlmTokenResponse get_llm_token_response = 236;
     }
 
     reserved 158 to 161;
@@ -1692,6 +1694,7 @@ message GetPrivateUserInfoResponse {
     string metrics_id = 1;
     bool staff = 2;
     repeated string flags = 3;
+    optional uint64 accepted_tos_at = 4;
 }
 
 enum Plan {
@@ -1703,6 +1706,12 @@ message UpdateUserPlan {
     Plan plan = 1;
 }
 
+message AcceptTermsOfService {}
+
+message AcceptTermsOfServiceResponse {
+    uint64 accepted_tos_at = 1;
+}
+
 // Entities
 
 message ViewId {

crates/proto/src/proto.rs 🔗

@@ -187,6 +187,8 @@ impl fmt::Display for PeerId {
 }
 
 messages!(
+    (AcceptTermsOfService, Foreground),
+    (AcceptTermsOfServiceResponse, Foreground),
     (Ack, Foreground),
     (AckBufferOperation, Background),
     (AckChannelMessage, Background),
@@ -409,6 +411,7 @@ messages!(
 );
 
 request_messages!(
+    (AcceptTermsOfService, AcceptTermsOfServiceResponse),
     (ApplyCodeAction, ApplyCodeActionResponse),
     (
         ApplyCompletionAdditionalEdits,