Fix interaction with Anthropic models when using it via zed.dev (#15009)

Created by Antonio Scandurra and Bennet

Release Notes:

- N/A

---------

Co-authored-by: Bennet <bennet@zed.dev>

Change summary

crates/completion/src/cloud.rs                 | 31 +++++++++++--------
crates/language_model/src/model/cloud_model.rs | 16 ----------
crates/language_model/src/request.rs           |  5 ++
3 files changed, 22 insertions(+), 30 deletions(-)

Detailed changes

crates/completion/src/cloud.rs 🔗

@@ -101,7 +101,7 @@ impl LanguageModelCompletionProvider for CloudCompletionProvider {
         request: LanguageModelRequest,
         cx: &AppContext,
     ) -> BoxFuture<'static, Result<usize>> {
-        match request.model {
+        match &request.model {
             LanguageModel::Cloud(CloudModel::Gpt4)
             | LanguageModel::Cloud(CloudModel::Gpt4Turbo)
             | LanguageModel::Cloud(CloudModel::Gpt4Omni)
@@ -118,19 +118,24 @@ impl LanguageModelCompletionProvider for CloudCompletionProvider {
                 count_open_ai_tokens(request, cx.background_executor())
             }
             LanguageModel::Cloud(CloudModel::Custom { name, .. }) => {
-                let request = self.client.request(proto::CountTokensWithLanguageModel {
-                    model: name,
-                    messages: request
-                        .messages
-                        .iter()
-                        .map(|message| message.to_proto())
-                        .collect(),
-                });
-                async move {
-                    let response = request.await?;
-                    Ok(response.token_count as usize)
+                if name.starts_with("anthropic/") {
+                    // Can't find a tokenizer for Anthropic models, so for now just use the same as OpenAI's as an approximation.
+                    count_open_ai_tokens(request, cx.background_executor())
+                } else {
+                    let request = self.client.request(proto::CountTokensWithLanguageModel {
+                        model: name.clone(),
+                        messages: request
+                            .messages
+                            .iter()
+                            .map(|message| message.to_proto())
+                            .collect(),
+                    });
+                    async move {
+                        let response = request.await?;
+                        Ok(response.token_count as usize)
+                    }
+                    .boxed()
                 }
-                .boxed()
             }
             _ => future::ready(Err(anyhow!("invalid model"))).boxed(),
         }

crates/language_model/src/model/cloud_model.rs 🔗

@@ -1,4 +1,3 @@
-use crate::LanguageModelRequest;
 pub use anthropic::Model as AnthropicModel;
 pub use ollama::Model as OllamaModel;
 pub use open_ai::Model as OpenAiModel;
@@ -88,19 +87,4 @@ impl CloudModel {
             Self::Custom { max_tokens, .. } => max_tokens.unwrap_or(200_000),
         }
     }
-
-    pub fn preprocess_request(&self, request: &mut LanguageModelRequest) {
-        match self {
-            Self::Claude3Opus
-            | Self::Claude3Sonnet
-            | Self::Claude3Haiku
-            | Self::Claude3_5Sonnet => {
-                request.preprocess_anthropic();
-            }
-            Self::Custom { name, .. } if name.starts_with("anthropic/") => {
-                request.preprocess_anthropic();
-            }
-            _ => {}
-        }
-    }
 }

crates/language_model/src/request.rs 🔗

@@ -45,7 +45,7 @@ impl LanguageModelRequest {
     pub fn preprocess(&mut self) {
         match &self.model {
             LanguageModel::OpenAi(_) => {}
-            LanguageModel::Anthropic(_) => {}
+            LanguageModel::Anthropic(_) => self.preprocess_anthropic(),
             LanguageModel::Ollama(_) => {}
             LanguageModel::Cloud(model) => match model {
                 CloudModel::Claude3Opus
@@ -54,6 +54,9 @@ impl LanguageModelRequest {
                 | CloudModel::Claude3_5Sonnet => {
                     self.preprocess_anthropic();
                 }
+                CloudModel::Custom { name, .. } if name.starts_with("anthropic/") => {
+                    self.preprocess_anthropic();
+                }
                 _ => {}
             },
         }