Detailed changes
@@ -490,6 +490,7 @@ impl AssistantPanel {
}
language_model::Event::ProviderStateChanged => {
this.ensure_authenticated(cx);
+ cx.notify()
}
language_model::Event::AddedProvider(_)
| language_model::Event::RemovedProvider(_) => {
@@ -1712,6 +1713,7 @@ pub struct ContextEditor {
assistant_panel: WeakView<AssistantPanel>,
error_message: Option<SharedString>,
debug_inspector: Option<ContextInspector>,
+ show_accept_terms: bool,
}
const DEFAULT_TAB_TITLE: &str = "New Context";
@@ -1772,6 +1774,7 @@ impl ContextEditor {
assistant_panel,
error_message: None,
debug_inspector: None,
+ show_accept_terms: false,
};
this.update_message_headers(cx);
this.insert_slash_command_output_sections(sections, cx);
@@ -1804,6 +1807,16 @@ impl ContextEditor {
}
fn assist(&mut self, _: &Assist, cx: &mut ViewContext<Self>) {
+ let provider = LanguageModelRegistry::read_global(cx).active_provider();
+ if provider
+ .as_ref()
+ .map_or(false, |provider| provider.must_accept_terms(cx))
+ {
+ self.show_accept_terms = true;
+ cx.notify();
+ return;
+ }
+
if !self.apply_active_workflow_step(cx) {
self.error_message = None;
self.send_to_model(cx);
@@ -3388,7 +3401,14 @@ impl ContextEditor {
None => (ButtonStyle::Filled, None),
};
+ let provider = LanguageModelRegistry::read_global(cx).active_provider();
+ let disabled = self.show_accept_terms
+ && provider
+ .as_ref()
+ .map_or(false, |provider| provider.must_accept_terms(cx));
+
ButtonLike::new("send_button")
+ .disabled(disabled)
.style(style)
.when_some(tooltip, |button, tooltip| {
button.tooltip(move |_| tooltip.clone())
@@ -3437,6 +3457,15 @@ impl EventEmitter<SearchEvent> for ContextEditor {}
impl Render for ContextEditor {
fn render(&mut self, cx: &mut ViewContext<Self>) -> impl IntoElement {
+ let provider = LanguageModelRegistry::read_global(cx).active_provider();
+ let accept_terms = if self.show_accept_terms {
+ provider
+ .as_ref()
+ .and_then(|provider| provider.render_accept_terms(cx))
+ } else {
+ None
+ };
+
v_flex()
.key_context("ContextEditor")
.capture_action(cx.listener(ContextEditor::cancel))
@@ -3455,6 +3484,21 @@ impl Render for ContextEditor {
.bg(cx.theme().colors().editor_background)
.child(self.editor.clone()),
)
+ .when_some(accept_terms, |this, element| {
+ this.child(
+ div()
+ .absolute()
+ .right_4()
+ .bottom_10()
+ .max_w_96()
+ .py_2()
+ .px_3()
+ .elevation_2(cx)
+ .bg(cx.theme().colors().surface_background)
+ .occlude()
+ .child(element),
+ )
+ })
.child(
h_flex().flex_none().relative().child(
h_flex()
@@ -1,5 +1,6 @@
use crate::{Client, Connection, Credentials, EstablishConnectionError, UserStore};
use anyhow::{anyhow, Result};
+use chrono::Duration;
use futures::{stream::BoxStream, StreamExt};
use gpui::{BackgroundExecutor, Context, Model, TestAppContext};
use parking_lot::Mutex;
@@ -162,6 +163,11 @@ impl FakeServer {
return Ok(*message.downcast().unwrap());
}
+ let accepted_tos_at = chrono::Utc::now()
+ .checked_sub_signed(Duration::hours(5))
+ .expect("failed to build accepted_tos_at")
+ .timestamp() as u64;
+
if message.is::<TypedEnvelope<GetPrivateUserInfo>>() {
self.respond(
message
@@ -172,6 +178,7 @@ impl FakeServer {
metrics_id: "the-metrics-id".into(),
staff: false,
flags: Default::default(),
+ accepted_tos_at: Some(accepted_tos_at),
},
);
continue;
@@ -1,5 +1,6 @@
use super::{proto, Client, Status, TypedEnvelope};
use anyhow::{anyhow, Context, Result};
+use chrono::{DateTime, Utc};
use collections::{hash_map::Entry, HashMap, HashSet};
use feature_flags::FeatureFlagAppExt;
use futures::{channel::mpsc, Future, StreamExt};
@@ -94,6 +95,7 @@ pub struct UserStore {
update_contacts_tx: mpsc::UnboundedSender<UpdateContacts>,
current_plan: Option<proto::Plan>,
current_user: watch::Receiver<Option<Arc<User>>>,
+ accepted_tos_at: Option<Option<DateTime<Utc>>>,
contacts: Vec<Arc<Contact>>,
incoming_contact_requests: Vec<Arc<User>>,
outgoing_contact_requests: Vec<Arc<User>>,
@@ -150,6 +152,7 @@ impl UserStore {
by_github_login: Default::default(),
current_user: current_user_rx,
current_plan: None,
+ accepted_tos_at: None,
contacts: Default::default(),
incoming_contact_requests: Default::default(),
participant_indices: Default::default(),
@@ -189,9 +192,10 @@ impl UserStore {
} else {
break;
};
- let fetch_metrics_id =
+ let fetch_private_user_info =
client.request(proto::GetPrivateUserInfo {}).log_err();
- let (user, info) = futures::join!(fetch_user, fetch_metrics_id);
+ let (user, info) =
+ futures::join!(fetch_user, fetch_private_user_info);
cx.update(|cx| {
if let Some(info) = info {
@@ -202,9 +206,17 @@ impl UserStore {
client.telemetry.set_authenticated_user_info(
Some(info.metrics_id.clone()),
staff,
- )
+ );
+
+ this.update(cx, |this, _| {
+ this.set_current_user_accepted_tos_at(
+ info.accepted_tos_at,
+ );
+ })
+ } else {
+ anyhow::Ok(())
}
- })?;
+ })??;
current_user_tx.send(user).await.ok();
@@ -680,6 +692,39 @@ impl UserStore {
self.current_user.clone()
}
+ pub fn current_user_has_accepted_terms(&self) -> Option<bool> {
+ self.accepted_tos_at
+ .map(|accepted_tos_at| accepted_tos_at.is_some())
+ }
+
+ pub fn accept_terms_of_service(&mut self, cx: &mut ModelContext<Self>) -> Task<Result<()>> {
+ if self.current_user().is_none() {
+ return Task::ready(Err(anyhow!("no current user")));
+ };
+
+ let client = self.client.clone();
+ cx.spawn(move |this, mut cx| async move {
+ if let Some(client) = client.upgrade() {
+ let response = client
+ .request(proto::AcceptTermsOfService {})
+ .await
+ .context("error accepting tos")?;
+
+ this.update(&mut cx, |this, _| {
+ this.set_current_user_accepted_tos_at(Some(response.accepted_tos_at))
+ })
+ } else {
+ Err(anyhow!("client not found"))
+ }
+ })
+ }
+
+ fn set_current_user_accepted_tos_at(&mut self, accepted_tos_at: Option<u64>) {
+ self.accepted_tos_at = Some(
+ accepted_tos_at.and_then(|timestamp| DateTime::from_timestamp(timestamp as i64, 0)),
+ );
+ }
+
fn load_users(
&mut self,
request: impl RequestMessage<Response = UsersResponse>,
@@ -9,7 +9,8 @@ CREATE TABLE "users" (
"connected_once" BOOLEAN NOT NULL DEFAULT false,
"created_at" TIMESTAMP NOT NULL DEFAULT CURRENT_TIMESTAMP,
"metrics_id" TEXT,
- "github_user_id" INTEGER
+ "github_user_id" INTEGER,
+ "accepted_tos_at" TIMESTAMP WITHOUT TIME ZONE
);
CREATE UNIQUE INDEX "index_users_github_login" ON "users" ("github_login");
CREATE UNIQUE INDEX "index_invite_code_users" ON "users" ("invite_code");
@@ -0,0 +1 @@
+ALTER TABLE users ADD accepted_tos_at TIMESTAMP WITHOUT TIME ZONE;
@@ -225,6 +225,26 @@ impl Database {
.await
}
+ /// Sets "accepted_tos_at" on the user to the given timestamp.
+ pub async fn set_user_accepted_tos_at(
+ &self,
+ id: UserId,
+ accepted_tos_at: Option<DateTime>,
+ ) -> Result<()> {
+ self.transaction(|tx| async move {
+ user::Entity::update_many()
+ .filter(user::Column::Id.eq(id))
+ .set(user::ActiveModel {
+ accepted_tos_at: ActiveValue::set(accepted_tos_at),
+ ..Default::default()
+ })
+ .exec(&*tx)
+ .await?;
+ Ok(())
+ })
+ .await
+ }
+
/// hard delete the user.
pub async fn destroy_user(&self, id: UserId) -> Result<()> {
self.transaction(|tx| async move {
@@ -18,6 +18,7 @@ pub struct Model {
pub connected_once: bool,
pub metrics_id: Uuid,
pub created_at: DateTime,
+ pub accepted_tos_at: Option<DateTime>,
}
#[derive(Copy, Clone, Debug, EnumIter, DeriveRelation)]
@@ -10,6 +10,7 @@ mod extension_tests;
mod feature_flag_tests;
mod message_tests;
mod processed_stripe_event_tests;
+mod user_tests;
use crate::migrations::run_database_migrations;
@@ -0,0 +1,45 @@
+use chrono::Utc;
+
+use crate::{
+ db::{Database, NewUserParams},
+ test_both_dbs,
+};
+use std::sync::Arc;
+
+test_both_dbs!(
+ test_accepted_tos,
+ test_accepted_tos_postgres,
+ test_accepted_tos_sqlite
+);
+
+async fn test_accepted_tos(db: &Arc<Database>) {
+ let user_id = db
+ .create_user(
+ "user1@example.com",
+ false,
+ NewUserParams {
+ github_login: "user1".to_string(),
+ github_user_id: 1,
+ },
+ )
+ .await
+ .unwrap()
+ .user_id;
+
+ let user = db.get_user_by_id(user_id).await.unwrap().unwrap();
+ assert!(user.accepted_tos_at.is_none());
+
+ let accepted_tos_at = Utc::now().naive_utc();
+ db.set_user_accepted_tos_at(user_id, Some(accepted_tos_at))
+ .await
+ .unwrap();
+
+ let user = db.get_user_by_id(user_id).await.unwrap().unwrap();
+ assert!(user.accepted_tos_at.is_some());
+ assert_eq!(user.accepted_tos_at, Some(accepted_tos_at));
+
+ db.set_user_accepted_tos_at(user_id, None).await.unwrap();
+
+ let user = db.get_user_by_id(user_id).await.unwrap().unwrap();
+ assert!(user.accepted_tos_at.is_none());
+}
@@ -31,6 +31,7 @@ use axum::{
routing::get,
Extension, Router, TypedHeader,
};
+use chrono::Utc;
use collections::{HashMap, HashSet};
pub use connection_pool::{ConnectionPool, ZedVersion};
use core::fmt::{self, Debug, Formatter};
@@ -604,6 +605,7 @@ impl Server {
.add_message_handler(user_message_handler(update_followers))
.add_request_handler(user_handler(get_private_user_info))
.add_request_handler(user_handler(get_llm_api_token))
+ .add_request_handler(user_handler(accept_terms_of_service))
.add_message_handler(user_message_handler(acknowledge_channel_message))
.add_message_handler(user_message_handler(acknowledge_buffer_version))
.add_request_handler(user_handler(get_supermaven_api_key))
@@ -4882,6 +4884,25 @@ async fn get_private_user_info(
metrics_id,
staff: user.admin,
flags,
+ accepted_tos_at: user.accepted_tos_at.map(|t| t.and_utc().timestamp() as u64),
+ })?;
+ Ok(())
+}
+
+/// Accept the terms of service (tos) on behalf of the current user
+async fn accept_terms_of_service(
+ _request: proto::AcceptTermsOfService,
+ response: Response<proto::AcceptTermsOfService>,
+ session: UserSession,
+) -> Result<()> {
+ let db = session.db().await;
+
+ let accepted_tos_at = Utc::now();
+ db.set_user_accepted_tos_at(session.user_id(), Some(accepted_tos_at.naive_utc()))
+ .await?;
+
+ response.send(proto::AcceptTermsOfServiceResponse {
+ accepted_tos_at: accepted_tos_at.timestamp() as u64,
})?;
Ok(())
}
@@ -9,7 +9,9 @@ pub mod settings;
use anyhow::Result;
use client::{Client, UserStore};
use futures::{future::BoxFuture, stream::BoxStream};
-use gpui::{AnyView, AppContext, AsyncAppContext, Model, SharedString, Task, WindowContext};
+use gpui::{
+ AnyElement, AnyView, AppContext, AsyncAppContext, Model, SharedString, Task, WindowContext,
+};
pub use model::*;
use project::Fs;
use proto::Plan;
@@ -114,6 +116,12 @@ pub trait LanguageModelProvider: 'static {
fn is_authenticated(&self, cx: &AppContext) -> bool;
fn authenticate(&self, cx: &mut AppContext) -> Task<Result<()>>;
fn configuration_view(&self, cx: &mut WindowContext) -> AnyView;
+ fn must_accept_terms(&self, _cx: &AppContext) -> bool {
+ false
+ }
+ fn render_accept_terms(&self, _cx: &mut WindowContext) -> Option<AnyElement> {
+ None
+ }
fn reset_credentials(&self, cx: &mut AppContext) -> Task<Result<()>>;
}
@@ -9,7 +9,10 @@ use client::{Client, PerformCompletionParams, UserStore, EXPIRED_LLM_TOKEN_HEADE
use collections::BTreeMap;
use feature_flags::{FeatureFlagAppExt, LanguageModels};
use futures::{future::BoxFuture, stream::BoxStream, AsyncBufReadExt, FutureExt, StreamExt};
-use gpui::{AnyView, AppContext, AsyncAppContext, Model, ModelContext, Subscription, Task};
+use gpui::{
+ AnyElement, AnyView, AppContext, AsyncAppContext, FontWeight, Model, ModelContext,
+ Subscription, Task,
+};
use http_client::{AsyncBody, HttpClient, Method, Response};
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
@@ -62,6 +65,7 @@ pub struct State {
client: Arc<Client>,
user_store: Model<UserStore>,
status: client::Status,
+ accept_terms: Option<Task<Result<()>>>,
_subscription: Subscription,
}
@@ -77,6 +81,26 @@ impl State {
this.update(&mut cx, |_, cx| cx.notify())
})
}
+
+ fn has_accepted_terms_of_service(&self, cx: &AppContext) -> bool {
+ self.user_store
+ .read(cx)
+ .current_user_has_accepted_terms()
+ .unwrap_or(false)
+ }
+
+ fn accept_terms_of_service(&mut self, cx: &mut ModelContext<Self>) {
+ let user_store = self.user_store.clone();
+ self.accept_terms = Some(cx.spawn(move |this, mut cx| async move {
+ let _ = user_store
+ .update(&mut cx, |store, cx| store.accept_terms_of_service(cx))?
+ .await;
+ this.update(&mut cx, |this, cx| {
+ this.accept_terms = None;
+ cx.notify()
+ })
+ }));
+ }
}
impl CloudLanguageModelProvider {
@@ -88,6 +112,7 @@ impl CloudLanguageModelProvider {
client: client.clone(),
user_store,
status,
+ accept_terms: None,
_subscription: cx.observe_global::<SettingsStore>(|_, cx| {
cx.notify();
}),
@@ -223,6 +248,57 @@ impl LanguageModelProvider for CloudLanguageModelProvider {
.into()
}
+ fn must_accept_terms(&self, cx: &AppContext) -> bool {
+ !self.state.read(cx).has_accepted_terms_of_service(cx)
+ }
+
+ fn render_accept_terms(&self, cx: &mut WindowContext) -> Option<AnyElement> {
+ let state = self.state.read(cx);
+
+ let terms = [(
+ "anthropic_terms_of_service",
+ "Anthropic Terms of Service",
+ "https://www.anthropic.com/legal/consumer-terms",
+ )]
+ .map(|(id, label, url)| {
+ Button::new(id, label)
+ .style(ButtonStyle::Subtle)
+ .icon(IconName::ExternalLink)
+ .icon_size(IconSize::XSmall)
+ .icon_color(Color::Muted)
+ .on_click(move |_, cx| cx.open_url(url))
+ });
+
+ if state.has_accepted_terms_of_service(cx) {
+ None
+ } else {
+ let disabled = state.accept_terms.is_some();
+ Some(
+ v_flex()
+ .child(Label::new("Terms & Conditions").weight(FontWeight::SEMIBOLD))
+ .child("Please read and accept the terms and conditions of Zed AI and our provider partners to continue.")
+ .child(v_flex().m_2().gap_1().children(terms))
+ .child(
+ h_flex().justify_end().mt_1().child(
+ Button::new("accept_terms", "Accept")
+ .disabled(disabled)
+ .on_click({
+ let state = self.state.downgrade();
+ move |_, cx| {
+ state
+ .update(cx, |state, cx| {
+ state.accept_terms_of_service(cx)
+ })
+ .ok();
+ }
+ }),
+ ),
+ )
+ .into_any(),
+ )
+ }
+ }
+
fn reset_credentials(&self, _cx: &mut AppContext) -> Task<Result<()>> {
Task::ready(Ok(()))
}
@@ -766,6 +842,7 @@ impl Render for ConfigurationView {
let is_connected = !self.state.read(cx).is_signed_out();
let plan = self.state.read(cx).user_store.read(cx).current_plan();
+ let must_accept_terms = !self.state.read(cx).has_accepted_terms_of_service(cx);
let is_pro = plan == Some(proto::Plan::ZedPro);
@@ -773,6 +850,11 @@ impl Render for ConfigurationView {
v_flex()
.gap_3()
.max_w_4_5()
+ .when(must_accept_terms, |this| {
+ this.child(Label::new(
+ "You must accept the terms of service to use this provider.",
+ ))
+ })
.child(Label::new(
if is_pro {
"You have full access to Zed's hosted models from Anthropic, OpenAI, Google with faster speeds and higher limits through Zed Pro."
@@ -49,7 +49,7 @@ message Envelope {
GetDefinition get_definition = 32;
GetDefinitionResponse get_definition_response = 33;
GetDeclaration get_declaration = 237;
- GetDeclarationResponse get_declaration_response = 238; // current max
+ GetDeclarationResponse get_declaration_response = 238;
GetTypeDefinition get_type_definition = 34;
GetTypeDefinitionResponse get_type_definition_response = 35;
@@ -130,6 +130,8 @@ message Envelope {
GetPrivateUserInfoResponse get_private_user_info_response = 103;
UpdateUserPlan update_user_plan = 234;
UpdateDiffBase update_diff_base = 104;
+ AcceptTermsOfService accept_terms_of_service = 239;
+ AcceptTermsOfServiceResponse accept_terms_of_service_response = 240; // current max
OnTypeFormatting on_type_formatting = 105;
OnTypeFormattingResponse on_type_formatting_response = 106;
@@ -270,7 +272,7 @@ message Envelope {
AddWorktreeResponse add_worktree_response = 223;
GetLlmToken get_llm_token = 235;
- GetLlmTokenResponse get_llm_token_response = 236; // current max
+ GetLlmTokenResponse get_llm_token_response = 236;
}
reserved 158 to 161;
@@ -1692,6 +1694,7 @@ message GetPrivateUserInfoResponse {
string metrics_id = 1;
bool staff = 2;
repeated string flags = 3;
+ optional uint64 accepted_tos_at = 4;
}
enum Plan {
@@ -1703,6 +1706,12 @@ message UpdateUserPlan {
Plan plan = 1;
}
+message AcceptTermsOfService {}
+
+message AcceptTermsOfServiceResponse {
+ uint64 accepted_tos_at = 1;
+}
+
// Entities
message ViewId {
@@ -187,6 +187,8 @@ impl fmt::Display for PeerId {
}
messages!(
+ (AcceptTermsOfService, Foreground),
+ (AcceptTermsOfServiceResponse, Foreground),
(Ack, Foreground),
(AckBufferOperation, Background),
(AckChannelMessage, Background),
@@ -409,6 +411,7 @@ messages!(
);
request_messages!(
+ (AcceptTermsOfService, AcceptTermsOfServiceResponse),
(ApplyCodeAction, ApplyCodeActionResponse),
(
ApplyCompletionAdditionalEdits,