collab: Add support for more providers to the LLM service (#15832)

Created by Marshall Bowers

This PR adds support for additional providers to the LLM service (a sketch of the updated request shape follows the list):

- OpenAI
- Google
- Custom Zed models (through Hugging Face)
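
With this change, a request to the LLM service's `/completion` endpoint identifies the provider and model explicitly, alongside the provider-specific payload, which is forwarded as raw JSON. A minimal sketch of building `PerformCompletionParams` (field names come from the diff below; the `gpt-4o` payload, the helper function, and the error handling are illustrative assumptions, not part of this PR):

```rust
use rpc::{LanguageModelProvider, PerformCompletionParams};
use serde_json::value::RawValue;

fn example_params() -> anyhow::Result<PerformCompletionParams> {
    // Hypothetical OpenAI chat payload; the service passes it through verbatim.
    let provider_request = serde_json::json!({
        "model": "gpt-4o",
        "messages": [{ "role": "user", "content": "Hello" }],
        "stream": true,
    });

    Ok(PerformCompletionParams {
        provider: LanguageModelProvider::OpenAi,
        model: "gpt-4o".into(),
        provider_request: RawValue::from_string(serde_json::to_string(&provider_request)?)?,
    })
}
```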

Release Notes:

- N/A

Change summary

crates/collab/src/llm.rs                    | 141 +++++++++--
crates/language_model/src/provider/cloud.rs | 281 ++++++++++++++++------
crates/rpc/src/llm.rs                       |  11 
3 files changed, 333 insertions(+), 100 deletions(-)

Detailed changes

crates/collab/src/llm.rs 🔗

@@ -12,7 +12,7 @@ use axum::{
 };
 use futures::StreamExt as _;
 use http_client::IsahcHttpClient;
-use rpc::{PerformCompletionParams, EXPIRED_LLM_TOKEN_HEADER_NAME};
+use rpc::{LanguageModelProvider, PerformCompletionParams, EXPIRED_LLM_TOKEN_HEADER_NAME};
 use std::sync::Arc;
 
 pub use token::*;
@@ -94,29 +94,118 @@ async fn perform_completion(
     Extension(_claims): Extension<LlmTokenClaims>,
     Json(params): Json<PerformCompletionParams>,
 ) -> Result<impl IntoResponse> {
-    let api_key = state
-        .config
-        .anthropic_api_key
-        .as_ref()
-        .context("no Anthropic AI API key configured on the server")?;
-    let chunks = anthropic::stream_completion(
-        &state.http_client,
-        anthropic::ANTHROPIC_API_URL,
-        api_key,
-        serde_json::from_str(&params.provider_request.get())?,
-        None,
-    )
-    .await?;
-
-    let stream = chunks.map(|event| {
-        let mut buffer = Vec::new();
-        event.map(|chunk| {
-            buffer.clear();
-            serde_json::to_writer(&mut buffer, &chunk).unwrap();
-            buffer.push(b'\n');
-            buffer
-        })
-    });
-
-    Ok(Response::new(Body::wrap_stream(stream)))
+    match params.provider {
+        LanguageModelProvider::Anthropic => {
+            let api_key = state
+                .config
+                .anthropic_api_key
+                .as_ref()
+                .context("no Anthropic AI API key configured on the server")?;
+            let chunks = anthropic::stream_completion(
+                &state.http_client,
+                anthropic::ANTHROPIC_API_URL,
+                api_key,
+                serde_json::from_str(&params.provider_request.get())?,
+                None,
+            )
+            .await?;
+
+            let stream = chunks.map(|event| {
+                let mut buffer = Vec::new();
+                event.map(|chunk| {
+                    buffer.clear();
+                    serde_json::to_writer(&mut buffer, &chunk).unwrap();
+                    buffer.push(b'\n');
+                    buffer
+                })
+            });
+
+            Ok(Response::new(Body::wrap_stream(stream)))
+        }
+        LanguageModelProvider::OpenAi => {
+            let api_key = state
+                .config
+                .openai_api_key
+                .as_ref()
+                .context("no OpenAI API key configured on the server")?;
+            let chunks = open_ai::stream_completion(
+                &state.http_client,
+                open_ai::OPEN_AI_API_URL,
+                api_key,
+                serde_json::from_str(&params.provider_request.get())?,
+                None,
+            )
+            .await?;
+
+            let stream = chunks.map(|event| {
+                let mut buffer = Vec::new();
+                event.map(|chunk| {
+                    buffer.clear();
+                    serde_json::to_writer(&mut buffer, &chunk).unwrap();
+                    buffer.push(b'\n');
+                    buffer
+                })
+            });
+
+            Ok(Response::new(Body::wrap_stream(stream)))
+        }
+        LanguageModelProvider::Google => {
+            let api_key = state
+                .config
+                .google_ai_api_key
+                .as_ref()
+                .context("no Google AI API key configured on the server")?;
+            let chunks = google_ai::stream_generate_content(
+                &state.http_client,
+                google_ai::API_URL,
+                api_key,
+                serde_json::from_str(&params.provider_request.get())?,
+            )
+            .await?;
+
+            let stream = chunks.map(|event| {
+                let mut buffer = Vec::new();
+                event.map(|chunk| {
+                    buffer.clear();
+                    serde_json::to_writer(&mut buffer, &chunk).unwrap();
+                    buffer.push(b'\n');
+                    buffer
+                })
+            });
+
+            Ok(Response::new(Body::wrap_stream(stream)))
+        }
+        LanguageModelProvider::Zed => {
+            let api_key = state
+                .config
+                .qwen2_7b_api_key
+                .as_ref()
+                .context("no Qwen2-7B API key configured on the server")?;
+            let api_url = state
+                .config
+                .qwen2_7b_api_url
+                .as_ref()
+                .context("no Qwen2-7B URL configured on the server")?;
+            let chunks = open_ai::stream_completion(
+                &state.http_client,
+                &api_url,
+                api_key,
+                serde_json::from_str(&params.provider_request.get())?,
+                None,
+            )
+            .await?;
+
+            let stream = chunks.map(|event| {
+                let mut buffer = Vec::new();
+                event.map(|chunk| {
+                    buffer.clear();
+                    serde_json::to_writer(&mut buffer, &chunk).unwrap();
+                    buffer.push(b'\n');
+                    buffer
+                })
+            });
+
+            Ok(Response::new(Body::wrap_stream(stream)))
+        }
+    }
 }
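
The new match arms above read per-provider credentials (and, for the Zed-hosted model, a URL) from the collab server configuration. A hypothetical sketch of the config fields this implies (only the field names are taken from the diff; the struct layout and types are assumed):

```rust
use std::sync::Arc;

// Assumed shape of the relevant collab server config fields; each arm of
// `perform_completion` errors out if its key (or URL) is not configured.
pub struct LlmServiceConfig {
    pub anthropic_api_key: Option<Arc<str>>,
    pub openai_api_key: Option<Arc<str>>,
    pub google_ai_api_key: Option<Arc<str>>,
    pub qwen2_7b_api_key: Option<Arc<str>>,
    pub qwen2_7b_api_url: Option<Arc<str>>,
}
```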

crates/language_model/src/provider/cloud.rs 🔗

@@ -10,7 +10,7 @@ use collections::BTreeMap;
 use feature_flags::{FeatureFlag, FeatureFlagAppExt};
 use futures::{future::BoxFuture, stream::BoxStream, AsyncBufReadExt, FutureExt, StreamExt};
 use gpui::{AnyView, AppContext, AsyncAppContext, Model, ModelContext, Subscription, Task};
-use http_client::{HttpClient, Method};
+use http_client::{AsyncBody, HttpClient, Method, Response};
 use schemars::JsonSchema;
 use serde::{Deserialize, Serialize};
 use serde_json::value::RawValue;
@@ -239,6 +239,47 @@ pub struct CloudLanguageModel {
 #[derive(Clone, Default)]
 struct LlmApiToken(Arc<RwLock<Option<String>>>);
 
+impl CloudLanguageModel {
+    async fn perform_llm_completion(
+        client: Arc<Client>,
+        llm_api_token: LlmApiToken,
+        body: PerformCompletionParams,
+    ) -> Result<Response<AsyncBody>> {
+        let http_client = &client.http_client();
+
+        let mut token = llm_api_token.acquire(&client).await?;
+        let mut did_retry = false;
+
+        let response = loop {
+            let request = http_client::Request::builder()
+                .method(Method::POST)
+                .uri(http_client.build_zed_llm_url("/completion", &[])?.as_ref())
+                .header("Content-Type", "application/json")
+                .header("Authorization", format!("Bearer {token}"))
+                .body(serde_json::to_string(&body)?.into())?;
+            let response = http_client.send(request).await?;
+            if response.status().is_success() {
+                break response;
+            } else if !did_retry
+                && response
+                    .headers()
+                    .get(EXPIRED_LLM_TOKEN_HEADER_NAME)
+                    .is_some()
+            {
+                did_retry = true;
+                token = llm_api_token.refresh(&client).await?;
+            } else {
+                break Err(anyhow!(
+                    "cloud language model completion failed with status {}",
+                    response.status()
+                ))?;
+            }
+        };
+
+        Ok(response)
+    }
+}
+
 impl LanguageModel for CloudLanguageModel {
     fn id(&self) -> LanguageModelId {
         self.id.clone()
@@ -314,46 +355,21 @@ impl LanguageModel for CloudLanguageModel {
                     .update(|cx| cx.has_flag::<LlmServiceFeatureFlag>())
                     .unwrap_or(false)
                 {
-                    let http_client = self.client.http_client();
                     let llm_api_token = self.llm_api_token.clone();
                     let future = self.request_limiter.stream(async move {
-                        let request = serde_json::to_string(&request)?;
-                        let mut token = llm_api_token.acquire(&client).await?;
-                        let mut did_retry = false;
-
-                        let response = loop {
-                            let request = http_client::Request::builder()
-                                .method(Method::POST)
-                                .uri(http_client.build_zed_llm_url("/completion", &[])?.as_ref())
-                                .header("Content-Type", "application/json")
-                                .header("Authorization", format!("Bearer {token}"))
-                                .body(
-                                    serde_json::to_string(&PerformCompletionParams {
-                                        provider_request: RawValue::from_string(request.clone())?,
-                                    })?
-                                    .into(),
-                                )?;
-                            let response = http_client.send(request).await?;
-                            if response.status().is_success() {
-                                break response;
-                            } else if !did_retry
-                                && response
-                                    .headers()
-                                    .get(EXPIRED_LLM_TOKEN_HEADER_NAME)
-                                    .is_some()
-                            {
-                                did_retry = true;
-                                token = llm_api_token.refresh(&client).await?;
-                            } else {
-                                break Err(anyhow!(
-                                    "cloud language model completion failed with status {}",
-                                    response.status()
-                                ))?;
-                            }
-                        };
-
+                        let response = Self::perform_llm_completion(
+                            client.clone(),
+                            llm_api_token,
+                            PerformCompletionParams {
+                                provider: client::LanguageModelProvider::Anthropic,
+                                model: request.model.clone(),
+                                provider_request: RawValue::from_string(serde_json::to_string(
+                                    &request,
+                                )?)?,
+                            },
+                        )
+                        .await?;
                         let body = BufReader::new(response.into_body());
-
                         let stream =
                             futures::stream::try_unfold(body, move |mut body| async move {
                                 let mut buffer = String::new();
@@ -389,54 +405,171 @@ impl LanguageModel for CloudLanguageModel {
             CloudModel::OpenAi(model) => {
                 let client = self.client.clone();
                 let request = request.into_open_ai(model.id().into());
-                let future = self.request_limiter.stream(async move {
-                    let request = serde_json::to_string(&request)?;
-                    let stream = client
-                        .request_stream(proto::StreamCompleteWithLanguageModel {
-                            provider: proto::LanguageModelProvider::OpenAi as i32,
-                            request,
-                        })
+
+                if cx
+                    .update(|cx| cx.has_flag::<LlmServiceFeatureFlag>())
+                    .unwrap_or(false)
+                {
+                    let llm_api_token = self.llm_api_token.clone();
+                    let future = self.request_limiter.stream(async move {
+                        let response = Self::perform_llm_completion(
+                            client.clone(),
+                            llm_api_token,
+                            PerformCompletionParams {
+                                provider: client::LanguageModelProvider::OpenAi,
+                                model: request.model.clone(),
+                                provider_request: RawValue::from_string(serde_json::to_string(
+                                    &request,
+                                )?)?,
+                            },
+                        )
                         .await?;
-                    Ok(open_ai::extract_text_from_events(
-                        stream.map(|item| Ok(serde_json::from_str(&item?.event)?)),
-                    ))
-                });
-                async move { Ok(future.await?.boxed()) }.boxed()
+                        let body = BufReader::new(response.into_body());
+                        let stream =
+                            futures::stream::try_unfold(body, move |mut body| async move {
+                                let mut buffer = String::new();
+                                match body.read_line(&mut buffer).await {
+                                    Ok(0) => Ok(None),
+                                    Ok(_) => {
+                                        let event: open_ai::ResponseStreamEvent =
+                                            serde_json::from_str(&buffer)?;
+                                        Ok(Some((event, body)))
+                                    }
+                                    Err(e) => Err(e.into()),
+                                }
+                            });
+
+                        Ok(open_ai::extract_text_from_events(stream))
+                    });
+                    async move { Ok(future.await?.boxed()) }.boxed()
+                } else {
+                    let future = self.request_limiter.stream(async move {
+                        let request = serde_json::to_string(&request)?;
+                        let stream = client
+                            .request_stream(proto::StreamCompleteWithLanguageModel {
+                                provider: proto::LanguageModelProvider::OpenAi as i32,
+                                request,
+                            })
+                            .await?;
+                        Ok(open_ai::extract_text_from_events(
+                            stream.map(|item| Ok(serde_json::from_str(&item?.event)?)),
+                        ))
+                    });
+                    async move { Ok(future.await?.boxed()) }.boxed()
+                }
             }
             CloudModel::Google(model) => {
                 let client = self.client.clone();
                 let request = request.into_google(model.id().into());
-                let future = self.request_limiter.stream(async move {
-                    let request = serde_json::to_string(&request)?;
-                    let stream = client
-                        .request_stream(proto::StreamCompleteWithLanguageModel {
-                            provider: proto::LanguageModelProvider::Google as i32,
-                            request,
-                        })
+
+                if cx
+                    .update(|cx| cx.has_flag::<LlmServiceFeatureFlag>())
+                    .unwrap_or(false)
+                {
+                    let llm_api_token = self.llm_api_token.clone();
+                    let future = self.request_limiter.stream(async move {
+                        let response = Self::perform_llm_completion(
+                            client.clone(),
+                            llm_api_token,
+                            PerformCompletionParams {
+                                provider: client::LanguageModelProvider::Google,
+                                model: request.model.clone(),
+                                provider_request: RawValue::from_string(serde_json::to_string(
+                                    &request,
+                                )?)?,
+                            },
+                        )
                         .await?;
-                    Ok(google_ai::extract_text_from_events(
-                        stream.map(|item| Ok(serde_json::from_str(&item?.event)?)),
-                    ))
-                });
-                async move { Ok(future.await?.boxed()) }.boxed()
+                        let body = BufReader::new(response.into_body());
+                        let stream =
+                            futures::stream::try_unfold(body, move |mut body| async move {
+                                let mut buffer = String::new();
+                                match body.read_line(&mut buffer).await {
+                                    Ok(0) => Ok(None),
+                                    Ok(_) => {
+                                        let event: google_ai::GenerateContentResponse =
+                                            serde_json::from_str(&buffer)?;
+                                        Ok(Some((event, body)))
+                                    }
+                                    Err(e) => Err(e.into()),
+                                }
+                            });
+
+                        Ok(google_ai::extract_text_from_events(stream))
+                    });
+                    async move { Ok(future.await?.boxed()) }.boxed()
+                } else {
+                    let future = self.request_limiter.stream(async move {
+                        let request = serde_json::to_string(&request)?;
+                        let stream = client
+                            .request_stream(proto::StreamCompleteWithLanguageModel {
+                                provider: proto::LanguageModelProvider::Google as i32,
+                                request,
+                            })
+                            .await?;
+                        Ok(google_ai::extract_text_from_events(
+                            stream.map(|item| Ok(serde_json::from_str(&item?.event)?)),
+                        ))
+                    });
+                    async move { Ok(future.await?.boxed()) }.boxed()
+                }
             }
             CloudModel::Zed(model) => {
                 let client = self.client.clone();
                 let mut request = request.into_open_ai(model.id().into());
                 request.max_tokens = Some(4000);
-                let future = self.request_limiter.stream(async move {
-                    let request = serde_json::to_string(&request)?;
-                    let stream = client
-                        .request_stream(proto::StreamCompleteWithLanguageModel {
-                            provider: proto::LanguageModelProvider::Zed as i32,
-                            request,
-                        })
+
+                if cx
+                    .update(|cx| cx.has_flag::<LlmServiceFeatureFlag>())
+                    .unwrap_or(false)
+                {
+                    let llm_api_token = self.llm_api_token.clone();
+                    let future = self.request_limiter.stream(async move {
+                        let response = Self::perform_llm_completion(
+                            client.clone(),
+                            llm_api_token,
+                            PerformCompletionParams {
+                                provider: client::LanguageModelProvider::Zed,
+                                model: request.model.clone(),
+                                provider_request: RawValue::from_string(serde_json::to_string(
+                                    &request,
+                                )?)?,
+                            },
+                        )
                         .await?;
-                    Ok(open_ai::extract_text_from_events(
-                        stream.map(|item| Ok(serde_json::from_str(&item?.event)?)),
-                    ))
-                });
-                async move { Ok(future.await?.boxed()) }.boxed()
+                        let body = BufReader::new(response.into_body());
+                        let stream =
+                            futures::stream::try_unfold(body, move |mut body| async move {
+                                let mut buffer = String::new();
+                                match body.read_line(&mut buffer).await {
+                                    Ok(0) => Ok(None),
+                                    Ok(_) => {
+                                        let event: open_ai::ResponseStreamEvent =
+                                            serde_json::from_str(&buffer)?;
+                                        Ok(Some((event, body)))
+                                    }
+                                    Err(e) => Err(e.into()),
+                                }
+                            });
+
+                        Ok(open_ai::extract_text_from_events(stream))
+                    });
+                    async move { Ok(future.await?.boxed()) }.boxed()
+                } else {
+                    let future = self.request_limiter.stream(async move {
+                        let request = serde_json::to_string(&request)?;
+                        let stream = client
+                            .request_stream(proto::StreamCompleteWithLanguageModel {
+                                provider: proto::LanguageModelProvider::Zed as i32,
+                                request,
+                            })
+                            .await?;
+                        Ok(open_ai::extract_text_from_events(
+                            stream.map(|item| Ok(serde_json::from_str(&item?.event)?)),
+                        ))
+                    });
+                    async move { Ok(future.await?.boxed()) }.boxed()
+                }
             }
         }
     }

crates/rpc/src/llm.rs 🔗

@@ -2,7 +2,18 @@ use serde::{Deserialize, Serialize};
 
 pub const EXPIRED_LLM_TOKEN_HEADER_NAME: &str = "x-zed-expired-token";
 
+#[derive(Debug, PartialEq, Eq, Clone, Copy, Serialize, Deserialize)]
+#[serde(rename_all = "snake_case")]
+pub enum LanguageModelProvider {
+    Anthropic,
+    OpenAi,
+    Google,
+    Zed,
+}
+
 #[derive(Serialize, Deserialize)]
 pub struct PerformCompletionParams {
+    pub provider: LanguageModelProvider,
+    pub model: String,
     pub provider_request: Box<serde_json::value::RawValue>,
 }