Add Qwen2-7B to the list of zed.dev models (#15649)

Created by Antonio Scandurra and Nathan.

Release Notes:

- N/A

---------

Co-authored-by: Nathan <nathan@zed.dev>

Change summary

crates/collab/k8s/collab.template.yml          | 10 ++++++
crates/collab/src/lib.rs                       |  2 +
crates/collab/src/rpc.rs                       | 24 +++++++++++++++
crates/collab/src/tests/test_server.rs         |  2 +
crates/language_model/src/model/cloud_model.rs | 31 ++++++++++++++++++++
crates/language_model/src/provider/cloud.rs    | 29 ++++++++++++++++++
crates/language_model/src/request.rs           |  1 +
crates/open_ai/src/open_ai.rs                  | 14 ++++++++
crates/proto/proto/zed.proto                   |  1 +
9 files changed, 112 insertions(+), 2 deletions(-)

Detailed changes

crates/collab/k8s/collab.template.yml 🔗

@@ -127,6 +127,16 @@ spec:
                 secretKeyRef:
                   name: google-ai
                   key: api_key
+            - name: QWEN2_7B_API_KEY
+              valueFrom:
+                secretKeyRef:
+                  name: hugging-face
+                  key: api_key
+            - name: QWEN2_7B_API_URL
+              valueFrom:
+                secretKeyRef:
+                  name: hugging-face
+                  key: qwen2_api_url
             - name: BLOB_STORE_ACCESS_KEY
               valueFrom:
                 secretKeyRef:

crates/collab/src/lib.rs 🔗

@@ -151,6 +151,8 @@ pub struct Config {
     pub openai_api_key: Option<Arc<str>>,
     pub google_ai_api_key: Option<Arc<str>>,
     pub anthropic_api_key: Option<Arc<str>>,
+    pub qwen2_7b_api_key: Option<Arc<str>>,
+    pub qwen2_7b_api_url: Option<Arc<str>>,
     pub zed_client_checksum_seed: Option<String>,
     pub slack_panics_webhook: Option<String>,
     pub auto_join_channel_id: Option<ChannelId>,

crates/collab/src/rpc.rs 🔗

@@ -4706,6 +4706,30 @@ async fn stream_complete_with_language_model(
                 })?;
             }
         }
+        Some(proto::LanguageModelProvider::Zed) => {
+            let api_key = config
+                .qwen2_7b_api_key
+                .as_ref()
+                .context("no Qwen2-7B API key configured on the server")?;
+            let api_url = config
+                .qwen2_7b_api_url
+                .as_ref()
+                .context("no Qwen2-7B URL configured on the server")?;
+            let mut events = open_ai::stream_completion(
+                session.http_client.as_ref(),
+                &api_url,
+                api_key,
+                serde_json::from_str(&request.request)?,
+                None,
+            )
+            .await?;
+            while let Some(event) = events.next().await {
+                let event = event?;
+                response.send(proto::StreamCompleteWithLanguageModelResponse {
+                    event: serde_json::to_string(&event)?,
+                })?;
+            }
+        }
         None => return Err(anyhow!("unknown provider"))?,
     }
 

crates/collab/src/tests/test_server.rs 🔗

@@ -672,6 +672,8 @@ impl TestServer {
                 stripe_api_key: None,
                 stripe_price_id: None,
                 supermaven_admin_api_key: None,
+                qwen2_7b_api_key: None,
+                qwen2_7b_api_url: None,
             },
         })
     }

crates/language_model/src/model/cloud_model.rs 🔗

@@ -1,5 +1,6 @@
 use schemars::JsonSchema;
 use serde::{Deserialize, Serialize};
+use strum::EnumIter;
 
 #[derive(Clone, Debug, PartialEq, Serialize, Deserialize, JsonSchema)]
 #[serde(tag = "provider", rename_all = "lowercase")]
@@ -7,6 +8,33 @@ pub enum CloudModel {
     Anthropic(anthropic::Model),
     OpenAi(open_ai::Model),
     Google(google_ai::Model),
+    Zed(ZedModel),
+}
+
+#[derive(Clone, Debug, PartialEq, Serialize, Deserialize, JsonSchema, EnumIter)]
+pub enum ZedModel {
+    #[serde(rename = "qwen2-7b-instruct")]
+    Qwen2_7bInstruct,
+}
+
+impl ZedModel {
+    pub fn id(&self) -> &str {
+        match self {
+            ZedModel::Qwen2_7bInstruct => "qwen2-7b-instruct",
+        }
+    }
+
+    pub fn display_name(&self) -> &str {
+        match self {
+            ZedModel::Qwen2_7bInstruct => "Qwen2 7B Instruct",
+        }
+    }
+
+    pub fn max_token_count(&self) -> usize {
+        match self {
+            ZedModel::Qwen2_7bInstruct => 8192,
+        }
+    }
 }
 
 impl Default for CloudModel {
@@ -21,6 +49,7 @@ impl CloudModel {
             CloudModel::Anthropic(model) => model.id(),
             CloudModel::OpenAi(model) => model.id(),
             CloudModel::Google(model) => model.id(),
+            CloudModel::Zed(model) => model.id(),
         }
     }
 
@@ -29,6 +58,7 @@ impl CloudModel {
             CloudModel::Anthropic(model) => model.display_name(),
             CloudModel::OpenAi(model) => model.display_name(),
             CloudModel::Google(model) => model.display_name(),
+            CloudModel::Zed(model) => model.display_name(),
         }
     }
 
@@ -37,6 +67,7 @@ impl CloudModel {
             CloudModel::Anthropic(model) => model.max_token_count(),
             CloudModel::OpenAi(model) => model.max_token_count(),
             CloudModel::Google(model) => model.max_token_count(),
+            CloudModel::Zed(model) => model.max_token_count(),
         }
     }
 }

crates/language_model/src/provider/cloud.rs 🔗

@@ -2,7 +2,7 @@ use super::open_ai::count_open_ai_tokens;
 use crate::{
     settings::AllLanguageModelSettings, CloudModel, LanguageModel, LanguageModelId,
     LanguageModelName, LanguageModelProviderId, LanguageModelProviderName,
-    LanguageModelProviderState, LanguageModelRequest, RateLimiter,
+    LanguageModelProviderState, LanguageModelRequest, RateLimiter, ZedModel,
 };
 use anyhow::{anyhow, Context as _, Result};
 use client::{Client, UserStore};
@@ -146,6 +146,9 @@ impl LanguageModelProvider for CloudLanguageModelProvider {
                 models.insert(model.id().to_string(), CloudModel::Google(model));
             }
         }
+        for model in ZedModel::iter() {
+            models.insert(model.id().to_string(), CloudModel::Zed(model));
+        }
 
         // Override with available models from settings
         for model in &AllLanguageModelSettings::get_global(cx)
@@ -263,6 +266,9 @@ impl LanguageModel for CloudLanguageModel {
                 }
                 .boxed()
             }
+            CloudModel::Zed(_) => {
+                count_open_ai_tokens(request, open_ai::Model::ThreePointFiveTurbo, cx)
+            }
         }
     }
 
@@ -323,6 +329,24 @@ impl LanguageModel for CloudLanguageModel {
                 });
                 async move { Ok(future.await?.boxed()) }.boxed()
             }
+            CloudModel::Zed(model) => {
+                let client = self.client.clone();
+                let mut request = request.into_open_ai(model.id().into());
+                request.max_tokens = Some(4000);
+                let future = self.request_limiter.stream(async move {
+                    let request = serde_json::to_string(&request)?;
+                    let stream = client
+                        .request_stream(proto::StreamCompleteWithLanguageModel {
+                            provider: proto::LanguageModelProvider::Zed as i32,
+                            request,
+                        })
+                        .await?;
+                    Ok(open_ai::extract_text_from_events(
+                        stream.map(|item| Ok(serde_json::from_str(&item?.event)?)),
+                    ))
+                });
+                async move { Ok(future.await?.boxed()) }.boxed()
+            }
         }
     }
 
@@ -382,6 +406,9 @@ impl LanguageModel for CloudLanguageModel {
             CloudModel::Google(_) => {
                 future::ready(Err(anyhow!("tool use not implemented for Google AI"))).boxed()
             }
+            CloudModel::Zed(_) => {
+                future::ready(Err(anyhow!("tool use not implemented for Zed models"))).boxed()
+            }
         }
     }
 }

crates/language_model/src/request.rs 🔗

@@ -37,6 +37,7 @@ impl LanguageModelRequest {
             stream: true,
             stop: self.stop,
             temperature: self.temperature,
+            max_tokens: None,
             tools: Vec::new(),
             tool_choice: None,
         }

crates/open_ai/src/open_ai.rs 🔗

@@ -116,6 +116,8 @@ pub struct Request {
     pub model: String,
     pub messages: Vec<RequestMessage>,
     pub stream: bool,
+    #[serde(default, skip_serializing_if = "Option::is_none")]
+    pub max_tokens: Option<usize>,
     pub stop: Vec<String>,
     pub temperature: f32,
     #[serde(default, skip_serializing_if = "Option::is_none")]
@@ -216,6 +218,13 @@ pub struct ChoiceDelta {
     pub finish_reason: Option<String>,
 }
 
+#[derive(Serialize, Deserialize, Debug)]
+#[serde(untagged)]
+pub enum ResponseStreamResult {
+    Ok(ResponseStreamEvent),
+    Err { error: String },
+}
+
 #[derive(Serialize, Deserialize, Debug)]
 pub struct ResponseStreamEvent {
     pub created: u32,
@@ -256,7 +265,10 @@ pub async fn stream_completion(
                             None
                         } else {
                             match serde_json::from_str(line) {
-                                Ok(response) => Some(Ok(response)),
+                                Ok(ResponseStreamResult::Ok(response)) => Some(Ok(response)),
+                                Ok(ResponseStreamResult::Err { error }) => {
+                                    Some(Err(anyhow!(error)))
+                                }
                                 Err(error) => Some(Err(anyhow!(error))),
                             }
                         }

crates/proto/proto/zed.proto 🔗

@@ -2099,6 +2099,7 @@ enum LanguageModelProvider {
     Anthropic = 0;
     OpenAI = 1;
     Google = 2;
+    Zed = 3;
 }
 
 message GetCachedEmbeddings {