diff --git a/crates/language_model/src/provider/cloud.rs b/crates/language_model/src/provider/cloud.rs
index 2f6651a2eb1275fe78ec9d504380a9488f7712e6..27a4f4f8b72acd981dcb887334028461382bb9a8 100644
--- a/crates/language_model/src/provider/cloud.rs
+++ b/crates/language_model/src/provider/cloud.rs
@@ -4,10 +4,10 @@ use crate::{
     LanguageModelName, LanguageModelProviderId, LanguageModelProviderName,
     LanguageModelProviderState, LanguageModelRequest, RateLimiter, ZedModel,
 };
-use anyhow::{anyhow, bail, Context as _, Result};
+use anyhow::{anyhow, bail, Result};
 use client::{Client, PerformCompletionParams, UserStore, EXPIRED_LLM_TOKEN_HEADER_NAME};
 use collections::BTreeMap;
-use feature_flags::{FeatureFlag, FeatureFlagAppExt, LanguageModels};
+use feature_flags::{FeatureFlagAppExt, LanguageModels};
 use futures::{future::BoxFuture, stream::BoxStream, AsyncBufReadExt, FutureExt, StreamExt};
 use gpui::{AnyView, AppContext, AsyncAppContext, Model, ModelContext, Subscription, Task};
 use http_client::{AsyncBody, HttpClient, Method, Response};
@@ -228,16 +228,6 @@ impl LanguageModelProvider for CloudLanguageModelProvider {
     }
 }

-struct LlmServiceFeatureFlag;
-
-impl FeatureFlag for LlmServiceFeatureFlag {
-    const NAME: &'static str = "llm-service";
-
-    fn enabled_for_staff() -> bool {
-        false
-    }
-}
-
 pub struct CloudLanguageModel {
     id: LanguageModelId,
     model: CloudModel,
@@ -354,232 +344,148 @@ impl LanguageModel for CloudLanguageModel {
     fn stream_completion(
         &self,
         request: LanguageModelRequest,
-        cx: &AsyncAppContext,
+        _cx: &AsyncAppContext,
     ) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>> {
         match &self.model {
             CloudModel::Anthropic(model) => {
                 let request = request.into_anthropic(model.id().into());
                 let client = self.client.clone();
-
-                if cx
-                    .update(|cx| cx.has_flag::<LlmServiceFeatureFlag>())
-                    .unwrap_or(false)
-                {
-                    let llm_api_token = self.llm_api_token.clone();
-                    let future = self.request_limiter.stream(async move {
-                        let response = Self::perform_llm_completion(
-                            client.clone(),
-                            llm_api_token,
-                            PerformCompletionParams {
-                                provider: client::LanguageModelProvider::Anthropic,
-                                model: request.model.clone(),
-                                provider_request: RawValue::from_string(serde_json::to_string(
-                                    &request,
-                                )?)?,
-                            },
-                        )
-                        .await?;
-                        let body = BufReader::new(response.into_body());
-                        let stream =
-                            futures::stream::try_unfold(body, move |mut body| async move {
-                                let mut buffer = String::new();
-                                match body.read_line(&mut buffer).await {
-                                    Ok(0) => Ok(None),
-                                    Ok(_) => {
-                                        let event: anthropic::Event =
-                                            serde_json::from_str(&buffer)?;
-                                        Ok(Some((event, body)))
-                                    }
-                                    Err(e) => Err(e.into()),
-                                }
-                            });
-
-                        Ok(anthropic::extract_text_from_events(stream))
-                    });
-                    async move { Ok(future.await?.boxed()) }.boxed()
-                } else {
-                    let future = self.request_limiter.stream(async move {
-                        let request = serde_json::to_string(&request)?;
-                        let stream = client
-                            .request_stream(proto::StreamCompleteWithLanguageModel {
-                                provider: proto::LanguageModelProvider::Anthropic as i32,
-                                request,
-                            })
-                            .await?
-                            .map(|event| Ok(serde_json::from_str(&event?.event)?));
-                        Ok(anthropic::extract_text_from_events(stream))
-                    });
-                    async move { Ok(future.await?.boxed()) }.boxed()
-                }
+                let llm_api_token = self.llm_api_token.clone();
+                let future = self.request_limiter.stream(async move {
+                    let response = Self::perform_llm_completion(
+                        client.clone(),
+                        llm_api_token,
+                        PerformCompletionParams {
+                            provider: client::LanguageModelProvider::Anthropic,
+                            model: request.model.clone(),
+                            provider_request: RawValue::from_string(serde_json::to_string(
+                                &request,
+                            )?)?,
+                        },
+                    )
+                    .await?;
+                    let body = BufReader::new(response.into_body());
+                    let stream = futures::stream::try_unfold(body, move |mut body| async move {
+                        let mut buffer = String::new();
+                        match body.read_line(&mut buffer).await {
+                            Ok(0) => Ok(None),
+                            Ok(_) => {
+                                let event: anthropic::Event = serde_json::from_str(&buffer)?;
+                                Ok(Some((event, body)))
+                            }
+                            Err(e) => Err(e.into()),
+                        }
+                    });
+
+                    Ok(anthropic::extract_text_from_events(stream))
+                });
+                async move { Ok(future.await?.boxed()) }.boxed()
             }
             CloudModel::OpenAi(model) => {
                 let client = self.client.clone();
                 let request = request.into_open_ai(model.id().into());
-
-                if cx
-                    .update(|cx| cx.has_flag::<LlmServiceFeatureFlag>())
-                    .unwrap_or(false)
-                {
-                    let llm_api_token = self.llm_api_token.clone();
-                    let future = self.request_limiter.stream(async move {
-                        let response = Self::perform_llm_completion(
-                            client.clone(),
-                            llm_api_token,
-                            PerformCompletionParams {
-                                provider: client::LanguageModelProvider::OpenAi,
-                                model: request.model.clone(),
-                                provider_request: RawValue::from_string(serde_json::to_string(
-                                    &request,
-                                )?)?,
-                            },
-                        )
-                        .await?;
-                        let body = BufReader::new(response.into_body());
-                        let stream =
-                            futures::stream::try_unfold(body, move |mut body| async move {
-                                let mut buffer = String::new();
-                                match body.read_line(&mut buffer).await {
-                                    Ok(0) => Ok(None),
-                                    Ok(_) => {
-                                        let event: open_ai::ResponseStreamEvent =
-                                            serde_json::from_str(&buffer)?;
-                                        Ok(Some((event, body)))
-                                    }
-                                    Err(e) => Err(e.into()),
-                                }
-                            });
-
-                        Ok(open_ai::extract_text_from_events(stream))
-                    });
-                    async move { Ok(future.await?.boxed()) }.boxed()
-                } else {
-                    let future = self.request_limiter.stream(async move {
-                        let request = serde_json::to_string(&request)?;
-                        let stream = client
-                            .request_stream(proto::StreamCompleteWithLanguageModel {
-                                provider: proto::LanguageModelProvider::OpenAi as i32,
-                                request,
-                            })
-                            .await?;
-                        Ok(open_ai::extract_text_from_events(
-                            stream.map(|item| Ok(serde_json::from_str(&item?.event)?)),
-                        ))
-                    });
-                    async move { Ok(future.await?.boxed()) }.boxed()
-                }
+                let llm_api_token = self.llm_api_token.clone();
+                let future = self.request_limiter.stream(async move {
+                    let response = Self::perform_llm_completion(
+                        client.clone(),
+                        llm_api_token,
+                        PerformCompletionParams {
+                            provider: client::LanguageModelProvider::OpenAi,
+                            model: request.model.clone(),
+                            provider_request: RawValue::from_string(serde_json::to_string(
+                                &request,
+                            )?)?,
+                        },
+                    )
+                    .await?;
+                    let body = BufReader::new(response.into_body());
+                    let stream = futures::stream::try_unfold(body, move |mut body| async move {
+                        let mut buffer = String::new();
+                        match body.read_line(&mut buffer).await {
+                            Ok(0) => Ok(None),
+                            Ok(_) => {
+                                let event: open_ai::ResponseStreamEvent =
+                                    serde_json::from_str(&buffer)?;
+                                Ok(Some((event, body)))
+                            }
+                            Err(e) => Err(e.into()),
+                        }
+                    });
+
+                    Ok(open_ai::extract_text_from_events(stream))
+                });
+                async move { Ok(future.await?.boxed()) }.boxed()
             }
             CloudModel::Google(model) => {
                 let client = self.client.clone();
                 let request = request.into_google(model.id().into());
-
-                if cx
-                    .update(|cx| cx.has_flag::<LlmServiceFeatureFlag>())
-                    .unwrap_or(false)
-                {
-                    let llm_api_token = self.llm_api_token.clone();
-                    let future = self.request_limiter.stream(async move {
-                        let response = Self::perform_llm_completion(
-                            client.clone(),
-                            llm_api_token,
-                            PerformCompletionParams {
-                                provider: client::LanguageModelProvider::Google,
-                                model: request.model.clone(),
-                                provider_request: RawValue::from_string(serde_json::to_string(
-                                    &request,
-                                )?)?,
-                            },
-                        )
-                        .await?;
-                        let body = BufReader::new(response.into_body());
-                        let stream =
-                            futures::stream::try_unfold(body, move |mut body| async move {
-                                let mut buffer = String::new();
-                                match body.read_line(&mut buffer).await {
-                                    Ok(0) => Ok(None),
-                                    Ok(_) => {
-                                        let event: google_ai::GenerateContentResponse =
-                                            serde_json::from_str(&buffer)?;
-                                        Ok(Some((event, body)))
-                                    }
-                                    Err(e) => Err(e.into()),
-                                }
-                            });
-
-                        Ok(google_ai::extract_text_from_events(stream))
-                    });
-                    async move { Ok(future.await?.boxed()) }.boxed()
-                } else {
-                    let future = self.request_limiter.stream(async move {
-                        let request = serde_json::to_string(&request)?;
-                        let stream = client
-                            .request_stream(proto::StreamCompleteWithLanguageModel {
-                                provider: proto::LanguageModelProvider::Google as i32,
-                                request,
-                            })
-                            .await?;
-                        Ok(google_ai::extract_text_from_events(
-                            stream.map(|item| Ok(serde_json::from_str(&item?.event)?)),
-                        ))
-                    });
-                    async move { Ok(future.await?.boxed()) }.boxed()
-                }
+                let llm_api_token = self.llm_api_token.clone();
+                let future = self.request_limiter.stream(async move {
+                    let response = Self::perform_llm_completion(
+                        client.clone(),
+                        llm_api_token,
+                        PerformCompletionParams {
+                            provider: client::LanguageModelProvider::Google,
+                            model: request.model.clone(),
+                            provider_request: RawValue::from_string(serde_json::to_string(
+                                &request,
+                            )?)?,
+                        },
+                    )
+                    .await?;
+                    let body = BufReader::new(response.into_body());
+                    let stream = futures::stream::try_unfold(body, move |mut body| async move {
+                        let mut buffer = String::new();
+                        match body.read_line(&mut buffer).await {
+                            Ok(0) => Ok(None),
+                            Ok(_) => {
+                                let event: google_ai::GenerateContentResponse =
+                                    serde_json::from_str(&buffer)?;
+                                Ok(Some((event, body)))
+                            }
+                            Err(e) => Err(e.into()),
+                        }
+                    });
+
+                    Ok(google_ai::extract_text_from_events(stream))
+                });
+                async move { Ok(future.await?.boxed()) }.boxed()
             }
             CloudModel::Zed(model) => {
                 let client = self.client.clone();
                 let mut request = request.into_open_ai(model.id().into());
                 request.max_tokens = Some(4000);
-
-                if cx
-                    .update(|cx| cx.has_flag::<LlmServiceFeatureFlag>())
-                    .unwrap_or(false)
-                {
-                    let llm_api_token = self.llm_api_token.clone();
-                    let future = self.request_limiter.stream(async move {
-                        let response = Self::perform_llm_completion(
-                            client.clone(),
-                            llm_api_token,
-                            PerformCompletionParams {
-                                provider: client::LanguageModelProvider::Zed,
-                                model: request.model.clone(),
-                                provider_request: RawValue::from_string(serde_json::to_string(
-                                    &request,
-                                )?)?,
-                            },
-                        )
-                        .await?;
-                        let body = BufReader::new(response.into_body());
-                        let stream =
-                            futures::stream::try_unfold(body, move |mut body| async move {
-                                let mut buffer = String::new();
-                                match body.read_line(&mut buffer).await {
-                                    Ok(0) => Ok(None),
-                                    Ok(_) => {
-                                        let event: open_ai::ResponseStreamEvent =
-                                            serde_json::from_str(&buffer)?;
-                                        Ok(Some((event, body)))
-                                    }
-                                    Err(e) => Err(e.into()),
-                                }
-                            });
-
-                        Ok(open_ai::extract_text_from_events(stream))
-                    });
-                    async move { Ok(future.await?.boxed()) }.boxed()
-                } else {
-                    let future = self.request_limiter.stream(async move {
-                        let request = serde_json::to_string(&request)?;
-                        let stream = client
-                            .request_stream(proto::StreamCompleteWithLanguageModel {
-                                provider: proto::LanguageModelProvider::Zed as i32,
-                                request,
-                            })
-                            .await?;
-                        Ok(open_ai::extract_text_from_events(
-                            stream.map(|item| Ok(serde_json::from_str(&item?.event)?)),
-                        ))
-                    });
-                    async move { Ok(future.await?.boxed()) }.boxed()
-                }
+                let llm_api_token = self.llm_api_token.clone();
+                let future = self.request_limiter.stream(async move {
+                    let response = Self::perform_llm_completion(
+                        client.clone(),
+                        llm_api_token,
+                        PerformCompletionParams {
+                            provider: client::LanguageModelProvider::Zed,
+                            model: request.model.clone(),
+                            provider_request: RawValue::from_string(serde_json::to_string(
+                                &request,
+                            )?)?,
+                        },
+                    )
+                    .await?;
+                    let body = BufReader::new(response.into_body());
+                    let stream = futures::stream::try_unfold(body, move |mut body| async move {
+                        let mut buffer = String::new();
+                        match body.read_line(&mut buffer).await {
+                            Ok(0) => Ok(None),
+                            Ok(_) => {
+                                let event: open_ai::ResponseStreamEvent =
+                                    serde_json::from_str(&buffer)?;
+                                Ok(Some((event, body)))
+                            }
+                            Err(e) => Err(e.into()),
+                        }
+                    });
+
+                    Ok(open_ai::extract_text_from_events(stream))
+                });
+                async move { Ok(future.await?.boxed()) }.boxed()
             }
         }
     }
@@ -590,7 +496,7 @@ impl LanguageModel for CloudLanguageModel {
         tool_name: String,
         tool_description: String,
         input_schema: serde_json::Value,
-        cx: &AsyncAppContext,
+        _cx: &AsyncAppContext,
     ) -> BoxFuture<'static, Result<serde_json::Value>> {
         match &self.model {
             CloudModel::Anthropic(model) => {
@@ -605,106 +511,67 @@ impl LanguageModel for CloudLanguageModel {
                     input_schema,
                 }];

-                if cx
-                    .update(|cx| cx.has_flag::<LlmServiceFeatureFlag>())
-                    .unwrap_or(false)
-                {
-                    let llm_api_token = self.llm_api_token.clone();
-                    self.request_limiter
-                        .run(async move {
-                            let response = Self::perform_llm_completion(
-                                client.clone(),
-                                llm_api_token,
-                                PerformCompletionParams {
-                                    provider: client::LanguageModelProvider::Anthropic,
-                                    model: request.model.clone(),
-                                    provider_request: RawValue::from_string(
-                                        serde_json::to_string(&request)?,
-                                    )?,
-                                },
-                            )
-                            .await?;
-
-                            let mut tool_use_index = None;
-                            let mut tool_input = String::new();
-                            let mut body = BufReader::new(response.into_body());
-                            let mut line = String::new();
-                            while body.read_line(&mut line).await? > 0 {
-                                let event: anthropic::Event = serde_json::from_str(&line)?;
-                                line.clear();
-
-                                match event {
-                                    anthropic::Event::ContentBlockStart {
-                                        content_block,
-                                        index,
-                                    } => {
-                                        if let anthropic::Content::ToolUse { name, .. } =
-                                            content_block
-                                        {
-                                            if name == tool_name {
-                                                tool_use_index = Some(index);
-                                            }
-                                        }
-                                    }
-                                    anthropic::Event::ContentBlockDelta { index, delta } => {
-                                        match delta {
-                                            anthropic::ContentDelta::TextDelta { .. } => {}
-                                            anthropic::ContentDelta::InputJsonDelta {
-                                                partial_json,
-                                            } => {
-                                                if Some(index) == tool_use_index {
-                                                    tool_input.push_str(&partial_json);
-                                                }
-                                            }
-                                        }
-                                    }
-                                    anthropic::Event::ContentBlockStop { index } => {
-                                        if Some(index) == tool_use_index {
-                                            return Ok(serde_json::from_str(&tool_input)?);
-                                        }
-                                    }
-                                    _ => {}
-                                }
-                            }
-
-                            if tool_use_index.is_some() {
-                                Err(anyhow!("tool content incomplete"))
-                            } else {
-                                Err(anyhow!("tool not used"))
-                            }
-                        })
-                        .boxed()
-                } else {
-                    self.request_limiter
-                        .run(async move {
-                            let request = serde_json::to_string(&request)?;
-                            let response = client
-                                .request(proto::CompleteWithLanguageModel {
-                                    provider: proto::LanguageModelProvider::Anthropic as i32,
-                                    request,
-                                })
-                                .await?;
-                            let response: anthropic::Response =
-                                serde_json::from_str(&response.completion)?;
-                            response
-                                .content
-                                .into_iter()
-                                .find_map(|content| {
-                                    if let anthropic::Content::ToolUse { name, input, .. } = content
-                                    {
-                                        if name == tool_name {
-                                            Some(input)
-                                        } else {
-                                            None
-                                        }
-                                    } else {
-                                        None
-                                    }
-                                })
-                                .context("tool not used")
-                        })
-                        .boxed()
-                }
+                let llm_api_token = self.llm_api_token.clone();
+                self.request_limiter
+                    .run(async move {
+                        let response = Self::perform_llm_completion(
+                            client.clone(),
+                            llm_api_token,
+                            PerformCompletionParams {
+                                provider: client::LanguageModelProvider::Anthropic,
+                                model: request.model.clone(),
+                                provider_request: RawValue::from_string(serde_json::to_string(
+                                    &request,
+                                )?)?,
+                            },
+                        )
+                        .await?;
+
+                        let mut tool_use_index = None;
+                        let mut tool_input = String::new();
+                        let mut body = BufReader::new(response.into_body());
+                        let mut line = String::new();
+                        while body.read_line(&mut line).await? > 0 {
+                            let event: anthropic::Event = serde_json::from_str(&line)?;
+                            line.clear();
+
+                            match event {
+                                anthropic::Event::ContentBlockStart {
+                                    content_block,
+                                    index,
+                                } => {
+                                    if let anthropic::Content::ToolUse { name, .. } = content_block
+                                    {
+                                        if name == tool_name {
+                                            tool_use_index = Some(index);
+                                        }
+                                    }
+                                }
+                                anthropic::Event::ContentBlockDelta { index, delta } => match delta
+                                {
+                                    anthropic::ContentDelta::TextDelta { .. } => {}
+                                    anthropic::ContentDelta::InputJsonDelta { partial_json } => {
+                                        if Some(index) == tool_use_index {
+                                            tool_input.push_str(&partial_json);
+                                        }
+                                    }
+                                },
+                                anthropic::Event::ContentBlockStop { index } => {
+                                    if Some(index) == tool_use_index {
+                                        return Ok(serde_json::from_str(&tool_input)?);
+                                    }
+                                }
+                                _ => {}
+                            }
+                        }
+
+                        if tool_use_index.is_some() {
+                            Err(anyhow!("tool content incomplete"))
+                        } else {
+                            Err(anyhow!("tool not used"))
+                        }
+                    })
+                    .boxed()
             }
             CloudModel::OpenAi(model) => {
                 let mut request = request.into_open_ai(model.id().into());
@@ -723,115 +590,59 @@ impl LanguageModel for CloudLanguageModel {
                 function.parameters = Some(input_schema);
                 request.tools = vec![open_ai::ToolDefinition::Function { function }];
-                if cx
-                    .update(|cx| cx.has_flag::<LlmServiceFeatureFlag>())
-                    .unwrap_or(false)
-                {
-                    let llm_api_token = self.llm_api_token.clone();
-                    self.request_limiter
-                        .run(async move {
-                            let response = Self::perform_llm_completion(
-                                client.clone(),
-                                llm_api_token,
-                                PerformCompletionParams {
-                                    provider: client::LanguageModelProvider::OpenAi,
-                                    model: request.model.clone(),
-                                    provider_request: RawValue::from_string(
-                                        serde_json::to_string(&request)?,
-                                    )?,
-                                },
-                            )
-                            .await?;
-
-                            let mut body = BufReader::new(response.into_body());
-                            let mut line = String::new();
-                            let mut load_state = None;
-
-                            while body.read_line(&mut line).await? > 0 {
-                                let part: open_ai::ResponseStreamEvent =
-                                    serde_json::from_str(&line)?;
-                                line.clear();
-
-                                for choice in part.choices {
-                                    let Some(tool_calls) = choice.delta.tool_calls else {
-                                        continue;
-                                    };
-
-                                    for call in tool_calls {
-                                        if let Some(func) = call.function {
-                                            if func.name.as_deref() == Some(tool_name.as_str()) {
-                                                load_state = Some((String::default(), call.index));
-                                            }
-                                            if let Some((arguments, (output, index))) =
-                                                func.arguments.zip(load_state.as_mut())
-                                            {
-                                                if call.index == *index {
-                                                    output.push_str(&arguments);
-                                                }
-                                            }
-                                        }
-                                    }
-                                }
-                            }
-
-                            if let Some((arguments, _)) = load_state {
-                                return Ok(serde_json::from_str(&arguments)?);
-                            } else {
-                                bail!("tool not used");
-                            }
-                        })
-                        .boxed()
-                } else {
-                    self.request_limiter
-                        .run(async move {
-                            let request = serde_json::to_string(&request)?;
-                            let response = client
-                                .request_stream(proto::StreamCompleteWithLanguageModel {
-                                    provider: proto::LanguageModelProvider::OpenAi as i32,
-                                    request,
-                                })
-                                .await?;
-                            let mut load_state = None;
-                            let mut response = response.map(
-                                |item: Result<
-                                    proto::StreamCompleteWithLanguageModelResponse,
-                                    anyhow::Error,
-                                >| {
-                                    Result::<open_ai::ResponseStreamEvent, anyhow::Error>::Ok(
-                                        serde_json::from_str(&item?.event)?,
-                                    )
-                                },
-                            );
-                            while let Some(Ok(part)) = response.next().await {
-                                for choice in part.choices {
-                                    let Some(tool_calls) = choice.delta.tool_calls else {
-                                        continue;
-                                    };
-
-                                    for call in tool_calls {
-                                        if let Some(func) = call.function {
-                                            if func.name.as_deref() == Some(tool_name.as_str()) {
-                                                load_state = Some((String::default(), call.index));
-                                            }
-                                            if let Some((arguments, (output, index))) =
-                                                func.arguments.zip(load_state.as_mut())
-                                            {
-                                                if call.index == *index {
-                                                    output.push_str(&arguments);
-                                                }
-                                            }
-                                        }
-                                    }
-                                }
-                            }
-                            if let Some((arguments, _)) = load_state {
-                                return Ok(serde_json::from_str(&arguments)?);
-                            } else {
-                                bail!("tool not used");
-                            }
-                        })
-                        .boxed()
-                }
+                let llm_api_token = self.llm_api_token.clone();
+                self.request_limiter
+                    .run(async move {
+                        let response = Self::perform_llm_completion(
+                            client.clone(),
+                            llm_api_token,
+                            PerformCompletionParams {
+                                provider: client::LanguageModelProvider::OpenAi,
+                                model: request.model.clone(),
+                                provider_request: RawValue::from_string(serde_json::to_string(
+                                    &request,
+                                )?)?,
+                            },
+                        )
+                        .await?;
+
+                        let mut body = BufReader::new(response.into_body());
+                        let mut line = String::new();
+                        let mut load_state = None;
+
+                        while body.read_line(&mut line).await? > 0 {
+                            let part: open_ai::ResponseStreamEvent = serde_json::from_str(&line)?;
+                            line.clear();
+
+                            for choice in part.choices {
+                                let Some(tool_calls) = choice.delta.tool_calls else {
+                                    continue;
+                                };
+
+                                for call in tool_calls {
+                                    if let Some(func) = call.function {
+                                        if func.name.as_deref() == Some(tool_name.as_str()) {
+                                            load_state = Some((String::default(), call.index));
+                                        }
+                                        if let Some((arguments, (output, index))) =
+                                            func.arguments.zip(load_state.as_mut())
+                                        {
+                                            if call.index == *index {
+                                                output.push_str(&arguments);
+                                            }
+                                        }
+                                    }
+                                }
+                            }
+                        }
+
+                        if let Some((arguments, _)) = load_state {
+                            return Ok(serde_json::from_str(&arguments)?);
+                        } else {
+                            bail!("tool not used");
+                        }
+                    })
+                    .boxed()
             }
             CloudModel::Google(_) => {
                 future::ready(Err(anyhow!("tool use not implemented for Google AI"))).boxed()
             }
@@ -854,114 +665,58 @@ impl LanguageModel for CloudLanguageModel {
                 function.parameters = Some(input_schema);
                 request.tools = vec![open_ai::ToolDefinition::Function { function }];
-                if cx
-                    .update(|cx| cx.has_flag::<LlmServiceFeatureFlag>())
-                    .unwrap_or(false)
-                {
-                    let llm_api_token = self.llm_api_token.clone();
-                    self.request_limiter
-                        .run(async move {
-                            let response = Self::perform_llm_completion(
-                                client.clone(),
-                                llm_api_token,
-                                PerformCompletionParams {
-                                    provider: client::LanguageModelProvider::Zed,
-                                    model: request.model.clone(),
-                                    provider_request: RawValue::from_string(
-                                        serde_json::to_string(&request)?,
-                                    )?,
-                                },
-                            )
-                            .await?;
-
-                            let mut body = BufReader::new(response.into_body());
-                            let mut line = String::new();
-                            let mut load_state = None;
-
-                            while body.read_line(&mut line).await? > 0 {
-                                let part: open_ai::ResponseStreamEvent =
-                                    serde_json::from_str(&line)?;
-                                line.clear();
-
-                                for choice in part.choices {
-                                    let Some(tool_calls) = choice.delta.tool_calls else {
-                                        continue;
-                                    };
-
-                                    for call in tool_calls {
-                                        if let Some(func) = call.function {
-                                            if func.name.as_deref() == Some(tool_name.as_str()) {
-                                                load_state = Some((String::default(), call.index));
-                                            }
-                                            if let Some((arguments, (output, index))) =
-                                                func.arguments.zip(load_state.as_mut())
-                                            {
-                                                if call.index == *index {
-                                                    output.push_str(&arguments);
-                                                }
-                                            }
-                                        }
-                                    }
-                                }
-                            }
-
-                            if let Some((arguments, _)) = load_state {
-                                return Ok(serde_json::from_str(&arguments)?);
-                            } else {
-                                bail!("tool not used");
-                            }
-                        })
-                        .boxed()
-                } else {
-                    self.request_limiter
-                        .run(async move {
-                            let request = serde_json::to_string(&request)?;
-                            let response = client
-                                .request_stream(proto::StreamCompleteWithLanguageModel {
-                                    provider: proto::LanguageModelProvider::OpenAi as i32,
-                                    request,
-                                })
-                                .await?;
-                            let mut load_state = None;
-                            let mut response = response.map(
-                                |item: Result<
-                                    proto::StreamCompleteWithLanguageModelResponse,
-                                    anyhow::Error,
-                                >| {
-                                    Result::<open_ai::ResponseStreamEvent, anyhow::Error>::Ok(
-                                        serde_json::from_str(&item?.event)?,
-                                    )
-                                },
-                            );
-                            while let Some(Ok(part)) = response.next().await {
-                                for choice in part.choices {
-                                    let Some(tool_calls) = choice.delta.tool_calls else {
-                                        continue;
-                                    };
-
-                                    for call in tool_calls {
-                                        if let Some(func) = call.function {
-                                            if func.name.as_deref() == Some(tool_name.as_str()) {
-                                                load_state = Some((String::default(), call.index));
-                                            }
-                                            if let Some((arguments, (output, index))) =
-                                                func.arguments.zip(load_state.as_mut())
-                                            {
-                                                if call.index == *index {
-                                                    output.push_str(&arguments);
-                                                }
-                                            }
-                                        }
-                                    }
-                                }
-                            }
-                            if let Some((arguments, _)) = load_state {
-                                return Ok(serde_json::from_str(&arguments)?);
-                            } else {
-                                bail!("tool not used");
-                            }
-                        })
-                        .boxed()
-                }
+                let llm_api_token = self.llm_api_token.clone();
+                self.request_limiter
+                    .run(async move {
+                        let response = Self::perform_llm_completion(
+                            client.clone(),
+                            llm_api_token,
+                            PerformCompletionParams {
+                                provider: client::LanguageModelProvider::Zed,
+                                model: request.model.clone(),
+                                provider_request: RawValue::from_string(serde_json::to_string(
+                                    &request,
+                                )?)?,
+                            },
+                        )
+                        .await?;
+
+                        let mut body = BufReader::new(response.into_body());
+                        let mut line = String::new();
+                        let mut load_state = None;
+
+                        while body.read_line(&mut line).await? > 0 {
+                            let part: open_ai::ResponseStreamEvent = serde_json::from_str(&line)?;
+                            line.clear();
+
+                            for choice in part.choices {
+                                let Some(tool_calls) = choice.delta.tool_calls else {
+                                    continue;
+                                };
+
+                                for call in tool_calls {
+                                    if let Some(func) = call.function {
+                                        if func.name.as_deref() == Some(tool_name.as_str()) {
+                                            load_state = Some((String::default(), call.index));
+                                        }
+                                        if let Some((arguments, (output, index))) =
+                                            func.arguments.zip(load_state.as_mut())
+                                        {
+                                            if call.index == *index {
+                                                output.push_str(&arguments);
+                                            }
+                                        }
+                                    }
+                                }
+                            }
+                        }
+
+                        if let Some((arguments, _)) = load_state {
+                            return Ok(serde_json::from_str(&arguments)?);
+                        } else {
+                            bail!("tool not used");
+                        }
+                    })
+                    .boxed()
             }
         }
     }
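
Note (illustrative sketch, not part of the patch): every provider arm above now consumes the LLM service response the same way. The body is newline-delimited JSON, and `futures::stream::try_unfold` threads the `BufReader` through repeated `read_line` calls, ending the stream when `read_line` returns `Ok(0)` at EOF. The following standalone Rust sketch shows just that pattern, assuming only the `futures`, `serde_json`, and `anyhow` crates; the in-memory `Cursor` is a stand-in for `response.into_body()`.

// Sketch only: parses newline-delimited JSON events from an async reader,
// the way each provider arm in the patch consumes the response body.
use anyhow::Result;
use futures::io::{BufReader, Cursor};
use futures::{stream, AsyncBufReadExt, StreamExt};

fn main() -> Result<()> {
    // Stand-in for `response.into_body()`: one JSON event per line.
    let body = Cursor::new("{\"text\":\"hello\"}\n{\"text\":\"world\"}\n");
    let reader = BufReader::new(body);

    // `try_unfold` owns the reader as loop state: read a line, parse it,
    // yield the event, and hand the reader back for the next step.
    let events = stream::try_unfold(reader, |mut reader| async move {
        let mut buffer = String::new();
        match reader.read_line(&mut buffer).await {
            Ok(0) => Ok(None), // EOF ends the stream
            Ok(_) => {
                let event: serde_json::Value = serde_json::from_str(&buffer)?;
                Ok(Some((event, reader)))
            }
            Err(e) => Err(anyhow::Error::from(e)),
        }
    });

    futures::executor::block_on(async {
        futures::pin_mut!(events);
        while let Some(event) = events.next().await {
            println!("{}", event?);
        }
        Ok(())
    })
}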
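
Note (illustrative sketch, not part of the patch): the `use_any_tool` arms accumulate a tool call's arguments from streamed fragments. The chunk that names the requested tool opens an accumulator keyed by the call index, later fragments with the same index are appended, and the buffer is parsed as JSON once the stream ends. The standalone sketch below mirrors that bookkeeping; `ToolCallChunk` is a hypothetical stand-in for the fields the patch reads off `open_ai::ResponseStreamEvent`.

// Sketch only: `ToolCallChunk` is a hypothetical stand-in for the
// delta/tool-call fields of `open_ai::ResponseStreamEvent`.
struct ToolCallChunk {
    index: usize,
    name: Option<String>,      // present on the first chunk of a call
    arguments: Option<String>, // partial JSON, spread across chunks
}

fn collect_tool_arguments(
    chunks: impl IntoIterator<Item = ToolCallChunk>,
    tool_name: &str,
) -> anyhow::Result<serde_json::Value> {
    // (accumulated JSON so far, index of the call being tracked)
    let mut load_state: Option<(String, usize)> = None;
    for chunk in chunks {
        // A chunk naming the requested tool starts a fresh accumulator.
        if chunk.name.as_deref() == Some(tool_name) {
            load_state = Some((String::new(), chunk.index));
        }
        // Append argument fragments that belong to the tracked call.
        if let Some((arguments, (output, index))) =
            chunk.arguments.zip(load_state.as_mut())
        {
            if chunk.index == *index {
                output.push_str(&arguments);
            }
        }
    }
    match load_state {
        Some((arguments, _)) => Ok(serde_json::from_str(&arguments)?),
        None => anyhow::bail!("tool not used"),
    }
}

fn main() -> anyhow::Result<()> {
    let chunks = vec![
        ToolCallChunk {
            index: 0,
            name: Some("get_weather".into()),
            arguments: Some("{\"city\":".into()),
        },
        ToolCallChunk {
            index: 0,
            name: None,
            arguments: Some("\"Berlin\"}".into()),
        },
    ];
    println!("{}", collect_tool_arguments(chunks, "get_weather")?);
    Ok(())
}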