@@ -105,18 +105,6 @@ impl<R: RequestMessage> Response<R> {
}
}
-struct StreamingResponse<R: RequestMessage> {
- peer: Arc<Peer>,
- receipt: Receipt<R>,
-}
-
-impl<R: RequestMessage> StreamingResponse<R> {
- fn send(&self, payload: R::Response) -> Result<()> {
- self.peer.respond(self.receipt, payload)?;
- Ok(())
- }
-}
-
#[derive(Clone, Debug)]
pub enum Principal {
User(User),
@@ -630,31 +618,6 @@ impl Server {
))
.add_message_handler(broadcast_project_message_from_host::<proto::AdvertiseContexts>)
.add_message_handler(update_context)
- .add_request_handler({
- let app_state = app_state.clone();
- move |request, response, session| {
- let app_state = app_state.clone();
- async move {
- complete_with_language_model(request, response, session, &app_state.config)
- .await
- }
- }
- })
- .add_streaming_request_handler({
- let app_state = app_state.clone();
- move |request, response, session| {
- let app_state = app_state.clone();
- async move {
- stream_complete_with_language_model(
- request,
- response,
- session,
- &app_state.config,
- )
- .await
- }
- }
- })
.add_request_handler({
let app_state = app_state.clone();
move |request, response, session| {
@@ -948,40 +911,6 @@ impl Server {
})
}
- fn add_streaming_request_handler<F, Fut, M>(&mut self, handler: F) -> &mut Self
- where
- F: 'static + Send + Sync + Fn(M, StreamingResponse<M>, Session) -> Fut,
- Fut: Send + Future<Output = Result<()>>,
- M: RequestMessage,
- {
- let handler = Arc::new(handler);
- self.add_handler(move |envelope, session| {
- let receipt = envelope.receipt();
- let handler = handler.clone();
- async move {
- let peer = session.peer.clone();
- let response = StreamingResponse {
- peer: peer.clone(),
- receipt,
- };
- match (handler)(envelope.payload, response, session).await {
- Ok(()) => {
- peer.end_stream(receipt)?;
- Ok(())
- }
- Err(error) => {
- let proto_err = match &error {
- Error::Internal(err) => err.to_proto(),
- _ => ErrorCode::Internal.message(format!("{}", error)).to_proto(),
- };
- peer.respond_with_error(receipt, proto_err)?;
- Err(error)
- }
- }
- }
- })
- }
-
#[allow(clippy::too_many_arguments)]
pub fn handle_connection(
self: &Arc<Self>,
@@ -4561,202 +4490,6 @@ async fn acknowledge_buffer_version(
Ok(())
}
-struct ZedProCompleteWithLanguageModelRateLimit;
-
-impl RateLimit for ZedProCompleteWithLanguageModelRateLimit {
- fn capacity(&self) -> usize {
- std::env::var("COMPLETE_WITH_LANGUAGE_MODEL_RATE_LIMIT_PER_HOUR")
- .ok()
- .and_then(|v| v.parse().ok())
- .unwrap_or(120) // Picked arbitrarily
- }
-
- fn refill_duration(&self) -> chrono::Duration {
- chrono::Duration::hours(1)
- }
-
- fn db_name(&self) -> &'static str {
- "zed-pro:complete-with-language-model"
- }
-}
-
-struct FreeCompleteWithLanguageModelRateLimit;
-
-impl RateLimit for FreeCompleteWithLanguageModelRateLimit {
- fn capacity(&self) -> usize {
- std::env::var("COMPLETE_WITH_LANGUAGE_MODEL_RATE_LIMIT_PER_HOUR_FREE")
- .ok()
- .and_then(|v| v.parse().ok())
- .unwrap_or(120 / 10) // Picked arbitrarily
- }
-
- fn refill_duration(&self) -> chrono::Duration {
- chrono::Duration::hours(1)
- }
-
- fn db_name(&self) -> &'static str {
- "free:complete-with-language-model"
- }
-}
-
-async fn complete_with_language_model(
- request: proto::CompleteWithLanguageModel,
- response: Response<proto::CompleteWithLanguageModel>,
- session: Session,
- config: &Config,
-) -> Result<()> {
- let Some(session) = session.for_user() else {
- return Err(anyhow!("user not found"))?;
- };
- authorize_access_to_language_models(&session).await?;
-
- let rate_limit: Box<dyn RateLimit> = match session.current_plan().await? {
- proto::Plan::ZedPro => Box::new(ZedProCompleteWithLanguageModelRateLimit),
- proto::Plan::Free => Box::new(FreeCompleteWithLanguageModelRateLimit),
- };
-
- session
- .app_state
- .rate_limiter
- .check(&*rate_limit, session.user_id())
- .await?;
-
- let result = match proto::LanguageModelProvider::from_i32(request.provider) {
- Some(proto::LanguageModelProvider::Anthropic) => {
- let api_key = config
- .anthropic_api_key
- .as_ref()
- .context("no Anthropic AI API key configured on the server")?;
- anthropic::complete(
- session.http_client.as_ref(),
- anthropic::ANTHROPIC_API_URL,
- api_key,
- serde_json::from_str(&request.request)?,
- )
- .await?
- }
- _ => return Err(anyhow!("unsupported provider"))?,
- };
-
- response.send(proto::CompleteWithLanguageModelResponse {
- completion: serde_json::to_string(&result)?,
- })?;
-
- Ok(())
-}
-
-async fn stream_complete_with_language_model(
- request: proto::StreamCompleteWithLanguageModel,
- response: StreamingResponse<proto::StreamCompleteWithLanguageModel>,
- session: Session,
- config: &Config,
-) -> Result<()> {
- let Some(session) = session.for_user() else {
- return Err(anyhow!("user not found"))?;
- };
- authorize_access_to_language_models(&session).await?;
-
- let rate_limit: Box<dyn RateLimit> = match session.current_plan().await? {
- proto::Plan::ZedPro => Box::new(ZedProCompleteWithLanguageModelRateLimit),
- proto::Plan::Free => Box::new(FreeCompleteWithLanguageModelRateLimit),
- };
-
- session
- .app_state
- .rate_limiter
- .check(&*rate_limit, session.user_id())
- .await?;
-
- match proto::LanguageModelProvider::from_i32(request.provider) {
- Some(proto::LanguageModelProvider::Anthropic) => {
- let api_key = config
- .anthropic_api_key
- .as_ref()
- .context("no Anthropic AI API key configured on the server")?;
- let mut chunks = anthropic::stream_completion(
- session.http_client.as_ref(),
- anthropic::ANTHROPIC_API_URL,
- api_key,
- serde_json::from_str(&request.request)?,
- None,
- )
- .await?;
- while let Some(event) = chunks.next().await {
- let chunk = event?;
- response.send(proto::StreamCompleteWithLanguageModelResponse {
- event: serde_json::to_string(&chunk)?,
- })?;
- }
- }
- Some(proto::LanguageModelProvider::OpenAi) => {
- let api_key = config
- .openai_api_key
- .as_ref()
- .context("no OpenAI API key configured on the server")?;
- let mut events = open_ai::stream_completion(
- session.http_client.as_ref(),
- open_ai::OPEN_AI_API_URL,
- api_key,
- serde_json::from_str(&request.request)?,
- None,
- )
- .await?;
- while let Some(event) = events.next().await {
- let event = event?;
- response.send(proto::StreamCompleteWithLanguageModelResponse {
- event: serde_json::to_string(&event)?,
- })?;
- }
- }
- Some(proto::LanguageModelProvider::Google) => {
- let api_key = config
- .google_ai_api_key
- .as_ref()
- .context("no Google AI API key configured on the server")?;
- let mut events = google_ai::stream_generate_content(
- session.http_client.as_ref(),
- google_ai::API_URL,
- api_key,
- serde_json::from_str(&request.request)?,
- )
- .await?;
- while let Some(event) = events.next().await {
- let event = event?;
- response.send(proto::StreamCompleteWithLanguageModelResponse {
- event: serde_json::to_string(&event)?,
- })?;
- }
- }
- Some(proto::LanguageModelProvider::Zed) => {
- let api_key = config
- .qwen2_7b_api_key
- .as_ref()
- .context("no Qwen2-7B API key configured on the server")?;
- let api_url = config
- .qwen2_7b_api_url
- .as_ref()
- .context("no Qwen2-7B URL configured on the server")?;
- let mut events = open_ai::stream_completion(
- session.http_client.as_ref(),
- &api_url,
- api_key,
- serde_json::from_str(&request.request)?,
- None,
- )
- .await?;
- while let Some(event) = events.next().await {
- let event = event?;
- response.send(proto::StreamCompleteWithLanguageModelResponse {
- event: serde_json::to_string(&event)?,
- })?;
- }
- }
- None => return Err(anyhow!("unknown provider"))?,
- }
-
- Ok(())
-}
-
async fn count_language_model_tokens(
request: proto::CountLanguageModelTokens,
response: Response<proto::CountLanguageModelTokens>,
@@ -298,10 +298,6 @@ messages!(
(PrepareRename, Background),
(PrepareRenameResponse, Background),
(ProjectEntryResponse, Foreground),
- (CompleteWithLanguageModel, Background),
- (CompleteWithLanguageModelResponse, Background),
- (StreamCompleteWithLanguageModel, Background),
- (StreamCompleteWithLanguageModelResponse, Background),
(CountLanguageModelTokens, Background),
(CountLanguageModelTokensResponse, Background),
(RefreshInlayHints, Foreground),
@@ -476,11 +472,6 @@ request_messages!(
(PerformRename, PerformRenameResponse),
(Ping, Ack),
(PrepareRename, PrepareRenameResponse),
- (CompleteWithLanguageModel, CompleteWithLanguageModelResponse),
- (
- StreamCompleteWithLanguageModel,
- StreamCompleteWithLanguageModelResponse
- ),
(CountLanguageModelTokens, CountLanguageModelTokensResponse),
(RefreshInlayHints, Ack),
(RejoinChannelBuffers, RejoinChannelBuffersResponse),