@@ -12,7 +12,7 @@ use axum::{
 };
 use futures::StreamExt as _;
 use http_client::IsahcHttpClient;
-use rpc::{PerformCompletionParams, EXPIRED_LLM_TOKEN_HEADER_NAME};
+use rpc::{LanguageModelProvider, PerformCompletionParams, EXPIRED_LLM_TOKEN_HEADER_NAME};
 use std::sync::Arc;
 pub use token::*;
 
@@ -94,29 +94,118 @@ async fn perform_completion(
     Extension(_claims): Extension<LlmTokenClaims>,
     Json(params): Json<PerformCompletionParams>,
 ) -> Result<impl IntoResponse> {
-    let api_key = state
-        .config
-        .anthropic_api_key
-        .as_ref()
-        .context("no Anthropic AI API key configured on the server")?;
-    let chunks = anthropic::stream_completion(
-        &state.http_client,
-        anthropic::ANTHROPIC_API_URL,
-        api_key,
-        serde_json::from_str(&params.provider_request.get())?,
-        None,
-    )
-    .await?;
-
-    let stream = chunks.map(|event| {
-        let mut buffer = Vec::new();
-        event.map(|chunk| {
-            buffer.clear();
-            serde_json::to_writer(&mut buffer, &chunk).unwrap();
-            buffer.push(b'\n');
-            buffer
-        })
-    });
-
-    Ok(Response::new(Body::wrap_stream(stream)))
+    match params.provider {
+        LanguageModelProvider::Anthropic => {
+            let api_key = state
+                .config
+                .anthropic_api_key
+                .as_ref()
+                .context("no Anthropic AI API key configured on the server")?;
+            let chunks = anthropic::stream_completion(
+                &state.http_client,
+                anthropic::ANTHROPIC_API_URL,
+                api_key,
+                serde_json::from_str(&params.provider_request.get())?,
+                None,
+            )
+            .await?;
+
+            let stream = chunks.map(|event| {
+                let mut buffer = Vec::new();
+                event.map(|chunk| {
+                    buffer.clear();
+                    serde_json::to_writer(&mut buffer, &chunk).unwrap();
+                    buffer.push(b'\n');
+                    buffer
+                })
+            });
+
+            Ok(Response::new(Body::wrap_stream(stream)))
+        }
+        LanguageModelProvider::OpenAi => {
+            let api_key = state
+                .config
+                .openai_api_key
+                .as_ref()
+                .context("no OpenAI API key configured on the server")?;
+            let chunks = open_ai::stream_completion(
+                &state.http_client,
+                open_ai::OPEN_AI_API_URL,
+                api_key,
+                serde_json::from_str(&params.provider_request.get())?,
+                None,
+            )
+            .await?;
+
+            let stream = chunks.map(|event| {
+                let mut buffer = Vec::new();
+                event.map(|chunk| {
+                    buffer.clear();
+                    serde_json::to_writer(&mut buffer, &chunk).unwrap();
+                    buffer.push(b'\n');
+                    buffer
+                })
+            });
+
+            Ok(Response::new(Body::wrap_stream(stream)))
+        }
+        LanguageModelProvider::Google => {
+            let api_key = state
+                .config
+                .google_ai_api_key
+                .as_ref()
+                .context("no Google AI API key configured on the server")?;
+            let chunks = google_ai::stream_generate_content(
+                &state.http_client,
+                google_ai::API_URL,
+                api_key,
+                serde_json::from_str(&params.provider_request.get())?,
+            )
+            .await?;
+
+            let stream = chunks.map(|event| {
+                let mut buffer = Vec::new();
+                event.map(|chunk| {
+                    buffer.clear();
+                    serde_json::to_writer(&mut buffer, &chunk).unwrap();
+                    buffer.push(b'\n');
+                    buffer
+                })
+            });
+
+            Ok(Response::new(Body::wrap_stream(stream)))
+        }
+        LanguageModelProvider::Zed => {
+            let api_key = state
+                .config
+                .qwen2_7b_api_key
+                .as_ref()
+                .context("no Qwen2-7B API key configured on the server")?;
+            let api_url = state
+                .config
+                .qwen2_7b_api_url
+                .as_ref()
+                .context("no Qwen2-7B URL configured on the server")?;
+            let chunks = open_ai::stream_completion(
+                &state.http_client,
+                &api_url,
+                api_key,
+                serde_json::from_str(&params.provider_request.get())?,
+                None,
+            )
+            .await?;
+
+            let stream = chunks.map(|event| {
+                let mut buffer = Vec::new();
+                event.map(|chunk| {
+                    buffer.clear();
+                    serde_json::to_writer(&mut buffer, &chunk).unwrap();
+                    buffer.push(b'\n');
+                    buffer
+                })
+            });
+
+            Ok(Response::new(Body::wrap_stream(stream)))
+        }
+    }
 }
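
Each arm of the match above performs the same chunk-to-NDJSON framing: serialize one JSON object per line and push it onto the response body. A minimal sketch of how that shared framing could be factored out, assuming only `futures`, `serde`, and `serde_json` (the helper name `to_ndjson_stream` is illustrative, not part of this diff):

    use futures::{Stream, StreamExt as _};
    use serde::Serialize;

    /// Map a stream of provider chunks to newline-delimited JSON buffers,
    /// mirroring the inline `chunks.map(...)` blocks in each match arm.
    fn to_ndjson_stream<T: Serialize, E>(
        chunks: impl Stream<Item = Result<T, E>>,
    ) -> impl Stream<Item = Result<Vec<u8>, E>> {
        chunks.map(|event| {
            event.map(|chunk| {
                // `to_vec` cannot fail for types that serialize to plain JSON,
                // matching the `unwrap` in the original code.
                let mut buffer = serde_json::to_vec(&chunk).unwrap();
                buffer.push(b'\n');
                buffer
            })
        })
    }
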
@@ -10,7 +10,7 @@ use collections::BTreeMap;
 use feature_flags::{FeatureFlag, FeatureFlagAppExt};
 use futures::{future::BoxFuture, stream::BoxStream, AsyncBufReadExt, FutureExt, StreamExt};
 use gpui::{AnyView, AppContext, AsyncAppContext, Model, ModelContext, Subscription, Task};
-use http_client::{HttpClient, Method};
+use http_client::{AsyncBody, HttpClient, Method, Response};
 use schemars::JsonSchema;
 use serde::{Deserialize, Serialize};
 use serde_json::value::RawValue;
@@ -239,6 +239,47 @@ pub struct CloudLanguageModel {
 #[derive(Clone, Default)]
 struct LlmApiToken(Arc<RwLock<Option<String>>>);
 
+impl CloudLanguageModel {
+    async fn perform_llm_completion(
+        client: Arc<Client>,
+        llm_api_token: LlmApiToken,
+        body: PerformCompletionParams,
+    ) -> Result<Response<AsyncBody>> {
+        let http_client = &client.http_client();
+
+        let mut token = llm_api_token.acquire(&client).await?;
+        let mut did_retry = false;
+
+        let response = loop {
+            let request = http_client::Request::builder()
+                .method(Method::POST)
+                .uri(http_client.build_zed_llm_url("/completion", &[])?.as_ref())
+                .header("Content-Type", "application/json")
+                .header("Authorization", format!("Bearer {token}"))
+                .body(serde_json::to_string(&body)?.into())?;
+            let response = http_client.send(request).await?;
+            if response.status().is_success() {
+                break response;
+            } else if !did_retry
+                && response
+                    .headers()
+                    .get(EXPIRED_LLM_TOKEN_HEADER_NAME)
+                    .is_some()
+            {
+                did_retry = true;
+                token = llm_api_token.refresh(&client).await?;
+            } else {
+                break Err(anyhow!(
+                    "cloud language model completion failed with status {}",
+                    response.status()
+                ))?;
+            }
+        };
+
+        Ok(response)
+    }
+}
+
 impl LanguageModel for CloudLanguageModel {
     fn id(&self) -> LanguageModelId {
         self.id.clone()
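
The retry in `perform_llm_completion` relies on the service flagging expired tokens distinctly from other failures. The middleware that does so is outside this diff; a hedged sketch of the shape such a rejection could take (only the header constant comes from the source, the rest is assumed):

    use axum::{
        http::{HeaderValue, StatusCode},
        response::{IntoResponse, Response},
    };
    use rpc::EXPIRED_LLM_TOKEN_HEADER_NAME;

    /// Hypothetical rejection for an expired LLM token: a 401 carrying the
    /// marker header that the client's single-retry loop checks for.
    fn expired_token_response() -> Response {
        let mut response = StatusCode::UNAUTHORIZED.into_response();
        response.headers_mut().insert(
            EXPIRED_LLM_TOKEN_HEADER_NAME,
            HeaderValue::from_static("true"),
        );
        response
    }
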
@@ -314,46 +355,21 @@ impl LanguageModel for CloudLanguageModel {
                 .update(|cx| cx.has_flag::<LlmServiceFeatureFlag>())
                 .unwrap_or(false)
             {
-                let http_client = self.client.http_client();
                 let llm_api_token = self.llm_api_token.clone();
                 let future = self.request_limiter.stream(async move {
-                    let request = serde_json::to_string(&request)?;
-                    let mut token = llm_api_token.acquire(&client).await?;
-                    let mut did_retry = false;
-
-                    let response = loop {
-                        let request = http_client::Request::builder()
-                            .method(Method::POST)
-                            .uri(http_client.build_zed_llm_url("/completion", &[])?.as_ref())
-                            .header("Content-Type", "application/json")
-                            .header("Authorization", format!("Bearer {token}"))
-                            .body(
-                                serde_json::to_string(&PerformCompletionParams {
-                                    provider_request: RawValue::from_string(request.clone())?,
-                                })?
-                                .into(),
-                            )?;
-                        let response = http_client.send(request).await?;
-                        if response.status().is_success() {
-                            break response;
-                        } else if !did_retry
-                            && response
-                                .headers()
-                                .get(EXPIRED_LLM_TOKEN_HEADER_NAME)
-                                .is_some()
-                        {
-                            did_retry = true;
-                            token = llm_api_token.refresh(&client).await?;
-                        } else {
-                            break Err(anyhow!(
-                                "cloud language model completion failed with status {}",
-                                response.status()
-                            ))?;
-                        }
-                    };
-
+                    let response = Self::perform_llm_completion(
+                        client.clone(),
+                        llm_api_token,
+                        PerformCompletionParams {
+                            provider: client::LanguageModelProvider::Anthropic,
+                            model: request.model.clone(),
+                            provider_request: RawValue::from_string(serde_json::to_string(
+                                &request,
+                            )?)?,
+                        },
+                    )
+                    .await?;
                     let body = BufReader::new(response.into_body());
-
                     let stream =
                         futures::stream::try_unfold(body, move |mut body| async move {
                             let mut buffer = String::new();
@@ -389,54 +405,171 @@ impl LanguageModel for CloudLanguageModel {
             CloudModel::OpenAi(model) => {
                 let client = self.client.clone();
                 let request = request.into_open_ai(model.id().into());
-                let future = self.request_limiter.stream(async move {
-                    let request = serde_json::to_string(&request)?;
-                    let stream = client
-                        .request_stream(proto::StreamCompleteWithLanguageModel {
-                            provider: proto::LanguageModelProvider::OpenAi as i32,
-                            request,
-                        })
+
+                if cx
+                    .update(|cx| cx.has_flag::<LlmServiceFeatureFlag>())
+                    .unwrap_or(false)
+                {
+                    let llm_api_token = self.llm_api_token.clone();
+                    let future = self.request_limiter.stream(async move {
+                        let response = Self::perform_llm_completion(
+                            client.clone(),
+                            llm_api_token,
+                            PerformCompletionParams {
+                                provider: client::LanguageModelProvider::OpenAi,
+                                model: request.model.clone(),
+                                provider_request: RawValue::from_string(serde_json::to_string(
+                                    &request,
+                                )?)?,
+                            },
+                        )
                         .await?;
-                    Ok(open_ai::extract_text_from_events(
-                        stream.map(|item| Ok(serde_json::from_str(&item?.event)?)),
-                    ))
-                });
-                async move { Ok(future.await?.boxed()) }.boxed()
+                        let body = BufReader::new(response.into_body());
+                        let stream =
+                            futures::stream::try_unfold(body, move |mut body| async move {
+                                let mut buffer = String::new();
+                                match body.read_line(&mut buffer).await {
+                                    Ok(0) => Ok(None),
+                                    Ok(_) => {
+                                        let event: open_ai::ResponseStreamEvent =
+                                            serde_json::from_str(&buffer)?;
+                                        Ok(Some((event, body)))
+                                    }
+                                    Err(e) => Err(e.into()),
+                                }
+                            });
+
+                        Ok(open_ai::extract_text_from_events(stream))
+                    });
+                    async move { Ok(future.await?.boxed()) }.boxed()
+                } else {
+                    let future = self.request_limiter.stream(async move {
+                        let request = serde_json::to_string(&request)?;
+                        let stream = client
+                            .request_stream(proto::StreamCompleteWithLanguageModel {
+                                provider: proto::LanguageModelProvider::OpenAi as i32,
+                                request,
+                            })
+                            .await?;
+                        Ok(open_ai::extract_text_from_events(
+                            stream.map(|item| Ok(serde_json::from_str(&item?.event)?)),
+                        ))
+                    });
+                    async move { Ok(future.await?.boxed()) }.boxed()
+                }
             }
             CloudModel::Google(model) => {
                 let client = self.client.clone();
                 let request = request.into_google(model.id().into());
-                let future = self.request_limiter.stream(async move {
-                    let request = serde_json::to_string(&request)?;
-                    let stream = client
-                        .request_stream(proto::StreamCompleteWithLanguageModel {
-                            provider: proto::LanguageModelProvider::Google as i32,
-                            request,
-                        })
+
+                if cx
+                    .update(|cx| cx.has_flag::<LlmServiceFeatureFlag>())
+                    .unwrap_or(false)
+                {
+                    let llm_api_token = self.llm_api_token.clone();
+                    let future = self.request_limiter.stream(async move {
+                        let response = Self::perform_llm_completion(
+                            client.clone(),
+                            llm_api_token,
+                            PerformCompletionParams {
+                                provider: client::LanguageModelProvider::Google,
+                                model: request.model.clone(),
+                                provider_request: RawValue::from_string(serde_json::to_string(
+                                    &request,
+                                )?)?,
+                            },
+                        )
                         .await?;
-                    Ok(google_ai::extract_text_from_events(
-                        stream.map(|item| Ok(serde_json::from_str(&item?.event)?)),
-                    ))
-                });
-                async move { Ok(future.await?.boxed()) }.boxed()
+                        let body = BufReader::new(response.into_body());
+                        let stream =
+                            futures::stream::try_unfold(body, move |mut body| async move {
+                                let mut buffer = String::new();
+                                match body.read_line(&mut buffer).await {
+                                    Ok(0) => Ok(None),
+                                    Ok(_) => {
+                                        let event: google_ai::GenerateContentResponse =
+                                            serde_json::from_str(&buffer)?;
+                                        Ok(Some((event, body)))
+                                    }
+                                    Err(e) => Err(e.into()),
+                                }
+                            });
+
+                        Ok(google_ai::extract_text_from_events(stream))
+                    });
+                    async move { Ok(future.await?.boxed()) }.boxed()
+                } else {
+                    let future = self.request_limiter.stream(async move {
+                        let request = serde_json::to_string(&request)?;
+                        let stream = client
+                            .request_stream(proto::StreamCompleteWithLanguageModel {
+                                provider: proto::LanguageModelProvider::Google as i32,
+                                request,
+                            })
+                            .await?;
+                        Ok(google_ai::extract_text_from_events(
+                            stream.map(|item| Ok(serde_json::from_str(&item?.event)?)),
+                        ))
+                    });
+                    async move { Ok(future.await?.boxed()) }.boxed()
+                }
             }
             CloudModel::Zed(model) => {
                 let client = self.client.clone();
                 let mut request = request.into_open_ai(model.id().into());
                 request.max_tokens = Some(4000);
-                let future = self.request_limiter.stream(async move {
-                    let request = serde_json::to_string(&request)?;
-                    let stream = client
-                        .request_stream(proto::StreamCompleteWithLanguageModel {
-                            provider: proto::LanguageModelProvider::Zed as i32,
-                            request,
-                        })
+
+                if cx
+                    .update(|cx| cx.has_flag::<LlmServiceFeatureFlag>())
+                    .unwrap_or(false)
+                {
+                    let llm_api_token = self.llm_api_token.clone();
+                    let future = self.request_limiter.stream(async move {
+                        let response = Self::perform_llm_completion(
+                            client.clone(),
+                            llm_api_token,
+                            PerformCompletionParams {
+                                provider: client::LanguageModelProvider::Zed,
+                                model: request.model.clone(),
+                                provider_request: RawValue::from_string(serde_json::to_string(
+                                    &request,
+                                )?)?,
+                            },
+                        )
                         .await?;
-                    Ok(open_ai::extract_text_from_events(
-                        stream.map(|item| Ok(serde_json::from_str(&item?.event)?)),
-                    ))
-                });
-                async move { Ok(future.await?.boxed()) }.boxed()
+                        let body = BufReader::new(response.into_body());
+                        let stream =
+                            futures::stream::try_unfold(body, move |mut body| async move {
+                                let mut buffer = String::new();
+                                match body.read_line(&mut buffer).await {
+                                    Ok(0) => Ok(None),
+                                    Ok(_) => {
+                                        let event: open_ai::ResponseStreamEvent =
+                                            serde_json::from_str(&buffer)?;
+                                        Ok(Some((event, body)))
+                                    }
+                                    Err(e) => Err(e.into()),
+                                }
+                            });
+
+                        Ok(open_ai::extract_text_from_events(stream))
+                    });
+                    async move { Ok(future.await?.boxed()) }.boxed()
+                } else {
+                    let future = self.request_limiter.stream(async move {
+                        let request = serde_json::to_string(&request)?;
+                        let stream = client
+                            .request_stream(proto::StreamCompleteWithLanguageModel {
+                                provider: proto::LanguageModelProvider::Zed as i32,
+                                request,
+                            })
+                            .await?;
+                        Ok(open_ai::extract_text_from_events(
+                            stream.map(|item| Ok(serde_json::from_str(&item?.event)?)),
+                        ))
+                    });
+                    async move { Ok(future.await?.boxed()) }.boxed()
+                }
             }
         }
     }
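
The three feature-flagged branches above also share their decoding loop: read one line from the buffered body, deserialize it as one typed event. A sketch of that shape as a generic helper, assuming `futures`, `serde`, and `anyhow` (the name `ndjson_events` is illustrative, not part of this diff):

    use anyhow::Result;
    use futures::{io::AsyncBufRead, AsyncBufReadExt as _, Stream};
    use serde::de::DeserializeOwned;

    /// Decode a newline-delimited JSON body into a stream of typed events,
    /// mirroring the `try_unfold` loops in each provider branch.
    fn ndjson_events<R, T>(body: R) -> impl Stream<Item = Result<T>>
    where
        R: AsyncBufRead + Unpin,
        T: DeserializeOwned,
    {
        futures::stream::try_unfold(body, |mut body| async move {
            let mut buffer = String::new();
            match body.read_line(&mut buffer).await {
                Ok(0) => Ok(None), // end of stream
                Ok(_) => {
                    let event: T = serde_json::from_str(&buffer)?;
                    Ok(Some((event, body)))
                }
                Err(e) => Err(anyhow::Error::from(e)),
            }
        })
    }
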
@@ -2,7 +2,18 @@ use serde::{Deserialize, Serialize};
 pub const EXPIRED_LLM_TOKEN_HEADER_NAME: &str = "x-zed-expired-token";
 
+#[derive(Debug, PartialEq, Eq, Clone, Copy, Serialize, Deserialize)]
+#[serde(rename_all = "snake_case")]
+pub enum LanguageModelProvider {
+    Anthropic,
+    OpenAi,
+    Google,
+    Zed,
+}
+
 #[derive(Serialize, Deserialize)]
 pub struct PerformCompletionParams {
+    pub provider: LanguageModelProvider,
+    pub model: String,
     pub provider_request: Box<serde_json::value::RawValue>,
 }
 
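
Because of `rename_all = "snake_case"`, the provider tag serializes as `"anthropic"`, `"open_ai"`, `"google"`, or `"zed"`, and `provider_request` is spliced into the payload verbatim as raw JSON. A quick sketch of the resulting wire format (the model id and inner request body here are illustrative):

    use serde_json::value::RawValue;

    fn example() -> anyhow::Result<()> {
        let params = PerformCompletionParams {
            provider: LanguageModelProvider::OpenAi,
            model: "gpt-4o".into(), // illustrative model id
            provider_request: RawValue::from_string(r#"{"messages":[]}"#.into())?,
        };
        // Produces:
        // {"provider":"open_ai","model":"gpt-4o","provider_request":{"messages":[]}}
        println!("{}", serde_json::to_string(&params)?);
        Ok(())
    }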