diff --git a/crates/google_ai/src/google_ai.rs b/crates/google_ai/src/google_ai.rs index b95c35b91acc4c3db9325f37cce4313a740d3efb..84bb15affa5aac239cb7f67f7d48db6e9f47b107 100644 --- a/crates/google_ai/src/google_ai.rs +++ b/crates/google_ai/src/google_ai.rs @@ -76,12 +76,10 @@ pub async fn count_tokens( client: &dyn HttpClient, api_url: &str, api_key: &str, + model_id: &str, request: CountTokensRequest, ) -> Result { - let uri = format!( - "{}/v1beta/models/gemini-pro:countTokens?key={}", - api_url, api_key - ); + let uri = format!("{api_url}/v1beta/models/{model_id}:countTokens?key={api_key}",); let request = serde_json::to_string(&request)?; let request_builder = HttpRequest::builder() diff --git a/crates/language_models/src/provider/google.rs b/crates/language_models/src/provider/google.rs index 1b25eade5af9a03b1c88c0d7bb81df4045251ed3..1bb0df310e3d56ad09254a2f4d6f8c5f3cf3a8a7 100644 --- a/crates/language_models/src/provider/google.rs +++ b/crates/language_models/src/provider/google.rs @@ -327,7 +327,8 @@ impl LanguageModel for GoogleLanguageModel { request: LanguageModelRequest, cx: &App, ) -> BoxFuture<'static, Result> { - let request = into_google(request, self.model.id().to_string()); + let model_id = self.model.id().to_string(); + let request = into_google(request, model_id.clone()); let http_client = self.http_client.clone(); let api_key = self.state.read(cx).api_key.clone(); @@ -340,6 +341,7 @@ impl LanguageModel for GoogleLanguageModel { http_client.as_ref(), &api_url, &api_key, + &model_id, google_ai::CountTokensRequest { contents: request.contents, },