@@ -322,25 +322,33 @@ async fn perform_completion(
}
fn normalize_model_name(provider: LanguageModelProvider, name: String) -> String {
- match provider {
- LanguageModelProvider::Anthropic => {
- for prefix in &[
- "claude-3-5-sonnet",
- "claude-3-haiku",
- "claude-3-opus",
- "claude-3-sonnet",
- ] {
- if name.starts_with(prefix) {
- return prefix.to_string();
- }
- }
- }
- LanguageModelProvider::OpenAi => {}
- LanguageModelProvider::Google => {}
- LanguageModelProvider::Zed => {}
- }
+ let prefixes: &[_] = match provider {
+ LanguageModelProvider::Anthropic => &[
+ "claude-3-5-sonnet",
+ "claude-3-haiku",
+ "claude-3-opus",
+ "claude-3-sonnet",
+ ],
+ LanguageModelProvider::OpenAi => &[
+ "gpt-3.5-turbo",
+ "gpt-4-turbo-preview",
+ "gpt-4o-mini",
+ "gpt-4o",
+ "gpt-4",
+ ],
+ LanguageModelProvider::Google => &[],
+ LanguageModelProvider::Zed => &[],
+ };
- name
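+ // Prefer the longest matching prefix so more specific names win
+ // (e.g. "gpt-4o-mini" should not normalize to "gpt-4").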
+ if let Some(prefix) = prefixes
+ .iter()
+ .filter(|&&prefix| name.starts_with(prefix))
+ .max_by_key(|&&prefix| prefix.len())
+ {
+ prefix.to_string()
+ } else {
+ name
+ }
}
async fn check_usage_limit(
@@ -590,7 +590,7 @@ impl LanguageModel for CloudLanguageModel {
tool_name: String,
tool_description: String,
input_schema: serde_json::Value,
- _cx: &AsyncAppContext,
+ cx: &AsyncAppContext,
) -> BoxFuture<'static, Result<serde_json::Value>> {
match &self.model {
CloudModel::Anthropic(model) => {
@@ -605,34 +605,106 @@ impl LanguageModel for CloudLanguageModel {
input_schema,
}];
- self.request_limiter
- .run(async move {
- let request = serde_json::to_string(&request)?;
- let response = client
- .request(proto::CompleteWithLanguageModel {
- provider: proto::LanguageModelProvider::Anthropic as i32,
- request,
- })
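+ // When the LLM service feature flag is enabled, hit the new LLM
+ // service directly; otherwise fall back to the legacy RPC path.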
+ if cx
+ .update(|cx| cx.has_flag::<LlmServiceFeatureFlag>())
+ .unwrap_or(false)
+ {
+ let llm_api_token = self.llm_api_token.clone();
+ self.request_limiter
+ .run(async move {
+ let response = Self::perform_llm_completion(
+ client.clone(),
+ llm_api_token,
+ PerformCompletionParams {
+ provider: client::LanguageModelProvider::Anthropic,
+ model: request.model.clone(),
+ provider_request: RawValue::from_string(
+ serde_json::to_string(&request)?,
+ )?,
+ },
+ )
.await?;
- let response: anthropic::Response =
- serde_json::from_str(&response.completion)?;
- response
- .content
- .into_iter()
- .find_map(|content| {
- if let anthropic::Content::ToolUse { name, input, .. } = content {
- if name == tool_name {
- Some(input)
+
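+ // Scan the newline-delimited event stream for the tool-use content
+ // block matching `tool_name`, splicing its streamed JSON input back together.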
+ let mut tool_use_index = None;
+ let mut tool_input = String::new();
+ let mut body = BufReader::new(response.into_body());
+ let mut line = String::new();
+ while body.read_line(&mut line).await? > 0 {
+ let event: anthropic::Event = serde_json::from_str(&line)?;
+ line.clear();
+
+ match event {
+ anthropic::Event::ContentBlockStart {
+ content_block,
+ index,
+ } => {
+ if let anthropic::Content::ToolUse { name, .. } =
+ content_block
+ {
+ if name == tool_name {
+ tool_use_index = Some(index);
+ }
+ }
+ }
+ anthropic::Event::ContentBlockDelta { index, delta } => {
+ match delta {
+ anthropic::ContentDelta::TextDelta { .. } => {}
+ anthropic::ContentDelta::InputJsonDelta {
+ partial_json,
+ } => {
+ if Some(index) == tool_use_index {
+ tool_input.push_str(&partial_json);
+ }
+ }
+ }
+ }
+ anthropic::Event::ContentBlockStop { index } => {
+ if Some(index) == tool_use_index {
+ return Ok(serde_json::from_str(&tool_input)?);
+ }
+ }
+ _ => {}
+ }
+ }
+
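+ // Reaching here means the stream ended early: either midway through
+ // the matched tool use, or without the tool being invoked at all.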
+ if tool_use_index.is_some() {
+ Err(anyhow!("tool content incomplete"))
+ } else {
+ Err(anyhow!("tool not used"))
+ }
+ })
+ .boxed()
+ } else {
+ self.request_limiter
+ .run(async move {
+ let request = serde_json::to_string(&request)?;
+ let response = client
+ .request(proto::CompleteWithLanguageModel {
+ provider: proto::LanguageModelProvider::Anthropic as i32,
+ request,
+ })
+ .await?;
+ let response: anthropic::Response =
+ serde_json::from_str(&response.completion)?;
+ response
+ .content
+ .into_iter()
+ .find_map(|content| {
+ if let anthropic::Content::ToolUse { name, input, .. } = content
+ {
+ if name == tool_name {
+ Some(input)
+ } else {
+ None
+ }
} else {
None
}
- } else {
- None
- }
- })
- .context("tool not used")
- })
- .boxed()
+ })
+ .context("tool not used")
+ })
+ .boxed()
+ }
}
CloudModel::OpenAi(model) => {
let mut request = request.into_open_ai(model.id().into());
@@ -650,56 +722,116 @@ impl LanguageModel for CloudLanguageModel {
function.description = Some(tool_description);
function.parameters = Some(input_schema);
request.tools = vec![open_ai::ToolDefinition::Function { function }];
- self.request_limiter
- .run(async move {
- let request = serde_json::to_string(&request)?;
- let response = client
- .request_stream(proto::StreamCompleteWithLanguageModel {
- provider: proto::LanguageModelProvider::OpenAi as i32,
- request,
- })
+
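+ // Gate on the same LLM service feature flag as the Anthropic arm above.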
+ if cx
+ .update(|cx| cx.has_flag::<LlmServiceFeatureFlag>())
+ .unwrap_or(false)
+ {
+ let llm_api_token = self.llm_api_token.clone();
+ self.request_limiter
+ .run(async move {
+ let response = Self::perform_llm_completion(
+ client.clone(),
+ llm_api_token,
+ PerformCompletionParams {
+ provider: client::LanguageModelProvider::OpenAi,
+ model: request.model.clone(),
+ provider_request: RawValue::from_string(
+ serde_json::to_string(&request)?,
+ )?,
+ },
+ )
.await?;
- // Call arguments are gonna be streamed in over multiple chunks.
- let mut load_state = None;
- let mut response = response.map(
- |item: Result<
- proto::StreamCompleteWithLanguageModelResponse,
- anyhow::Error,
- >| {
- Result::<open_ai::ResponseStreamEvent, anyhow::Error>::Ok(
- serde_json::from_str(&item?.event)?,
- )
- },
- );
- while let Some(Ok(part)) = response.next().await {
- for choice in part.choices {
- let Some(tool_calls) = choice.delta.tool_calls else {
- continue;
- };
-
- for call in tool_calls {
- if let Some(func) = call.function {
- if func.name.as_deref() == Some(tool_name.as_str()) {
- load_state = Some((String::default(), call.index));
+
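+ // Tool call arguments are streamed in over multiple chunks;
+ // accumulate them keyed by the call index.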
+ let mut body = BufReader::new(response.into_body());
+ let mut line = String::new();
+ let mut load_state = None;
+
+ while body.read_line(&mut line).await? > 0 {
+ let part: open_ai::ResponseStreamEvent =
+ serde_json::from_str(&line)?;
+ line.clear();
+
+ for choice in part.choices {
+ let Some(tool_calls) = choice.delta.tool_calls else {
+ continue;
+ };
+
+ for call in tool_calls {
+ if let Some(func) = call.function {
+ if func.name.as_deref() == Some(tool_name.as_str()) {
+ load_state = Some((String::default(), call.index));
+ }
+ if let Some((arguments, (output, index))) =
+ func.arguments.zip(load_state.as_mut())
+ {
+ if call.index == *index {
+ output.push_str(&arguments);
+ }
+ }
}
- if let Some((arguments, (output, index))) =
- func.arguments.zip(load_state.as_mut())
- {
- if call.index == *index {
- output.push_str(&arguments);
+ }
+ }
+ }
+
+ if let Some((arguments, _)) = load_state {
+ return Ok(serde_json::from_str(&arguments)?);
+ } else {
+ bail!("tool not used");
+ }
+ })
+ .boxed()
+ } else {
+ self.request_limiter
+ .run(async move {
+ let request = serde_json::to_string(&request)?;
+ let response = client
+ .request_stream(proto::StreamCompleteWithLanguageModel {
+ provider: proto::LanguageModelProvider::OpenAi as i32,
+ request,
+ })
+ .await?;
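+ // Call arguments are streamed in over multiple chunks.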
+ let mut load_state = None;
+ let mut response = response.map(
+ |item: Result<
+ proto::StreamCompleteWithLanguageModelResponse,
+ anyhow::Error,
+ >| {
+ Result::<open_ai::ResponseStreamEvent, anyhow::Error>::Ok(
+ serde_json::from_str(&item?.event)?,
+ )
+ },
+ );
+ while let Some(Ok(part)) = response.next().await {
+ for choice in part.choices {
+ let Some(tool_calls) = choice.delta.tool_calls else {
+ continue;
+ };
+
+ for call in tool_calls {
+ if let Some(func) = call.function {
+ if func.name.as_deref() == Some(tool_name.as_str()) {
+ load_state = Some((String::default(), call.index));
+ }
+ if let Some((arguments, (output, index))) =
+ func.arguments.zip(load_state.as_mut())
+ {
+ if call.index == *index {
+ output.push_str(&arguments);
+ }
}
}
}
}
}
- }
- if let Some((arguments, _)) = load_state {
- return Ok(serde_json::from_str(&arguments)?);
- } else {
- bail!("tool not used");
- }
- })
- .boxed()
+ if let Some((arguments, _)) = load_state {
+ return Ok(serde_json::from_str(&arguments)?);
+ } else {
+ bail!("tool not used");
+ }
+ })
+ .boxed()
+ }
}
CloudModel::Google(_) => {
future::ready(Err(anyhow!("tool use not implemented for Google AI"))).boxed()
@@ -721,56 +853,115 @@ impl LanguageModel for CloudLanguageModel {
function.description = Some(tool_description);
function.parameters = Some(input_schema);
request.tools = vec![open_ai::ToolDefinition::Function { function }];
- self.request_limiter
- .run(async move {
- let request = serde_json::to_string(&request)?;
- let response = client
- .request_stream(proto::StreamCompleteWithLanguageModel {
- provider: proto::LanguageModelProvider::OpenAi as i32,
- request,
- })
+
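+ // Zed-hosted models speak the OpenAI-compatible protocol, so this
+ // arm mirrors the OpenAI one (note the OpenAI event types below).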
+ if cx
+ .update(|cx| cx.has_flag::<LlmServiceFeatureFlag>())
+ .unwrap_or(false)
+ {
+ let llm_api_token = self.llm_api_token.clone();
+ self.request_limiter
+ .run(async move {
+ let response = Self::perform_llm_completion(
+ client.clone(),
+ llm_api_token,
+ PerformCompletionParams {
+ provider: client::LanguageModelProvider::Zed,
+ model: request.model.clone(),
+ provider_request: RawValue::from_string(
+ serde_json::to_string(&request)?,
+ )?,
+ },
+ )
.await?;
- // Call arguments are gonna be streamed in over multiple chunks.
- let mut load_state = None;
- let mut response = response.map(
- |item: Result<
- proto::StreamCompleteWithLanguageModelResponse,
- anyhow::Error,
- >| {
- Result::<open_ai::ResponseStreamEvent, anyhow::Error>::Ok(
- serde_json::from_str(&item?.event)?,
- )
- },
- );
- while let Some(Ok(part)) = response.next().await {
- for choice in part.choices {
- let Some(tool_calls) = choice.delta.tool_calls else {
- continue;
- };
-
- for call in tool_calls {
- if let Some(func) = call.function {
- if func.name.as_deref() == Some(tool_name.as_str()) {
- load_state = Some((String::default(), call.index));
+
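+ // As above: tool call arguments arrive chunked, keyed by call index.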
+ let mut body = BufReader::new(response.into_body());
+ let mut line = String::new();
+ let mut load_state = None;
+
+ while body.read_line(&mut line).await? > 0 {
+ let part: open_ai::ResponseStreamEvent =
+ serde_json::from_str(&line)?;
+ line.clear();
+
+ for choice in part.choices {
+ let Some(tool_calls) = choice.delta.tool_calls else {
+ continue;
+ };
+
+ for call in tool_calls {
+ if let Some(func) = call.function {
+ if func.name.as_deref() == Some(tool_name.as_str()) {
+ load_state = Some((String::default(), call.index));
+ }
+ if let Some((arguments, (output, index))) =
+ func.arguments.zip(load_state.as_mut())
+ {
+ if call.index == *index {
+ output.push_str(&arguments);
+ }
+ }
}
- if let Some((arguments, (output, index))) =
- func.arguments.zip(load_state.as_mut())
- {
- if call.index == *index {
- output.push_str(&arguments);
+ }
+ }
+ }
+ if let Some((arguments, _)) = load_state {
+ return Ok(serde_json::from_str(&arguments)?);
+ } else {
+ bail!("tool not used");
+ }
+ })
+ .boxed()
+ } else {
+ self.request_limiter
+ .run(async move {
+ let request = serde_json::to_string(&request)?;
+ let response = client
+ .request_stream(proto::StreamCompleteWithLanguageModel {
+ provider: proto::LanguageModelProvider::OpenAi as i32,
+ request,
+ })
+ .await?;
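+ // Call arguments are streamed in over multiple chunks.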
+ let mut load_state = None;
+ let mut response = response.map(
+ |item: Result<
+ proto::StreamCompleteWithLanguageModelResponse,
+ anyhow::Error,
+ >| {
+ Result::<open_ai::ResponseStreamEvent, anyhow::Error>::Ok(
+ serde_json::from_str(&item?.event)?,
+ )
+ },
+ );
+ while let Some(Ok(part)) = response.next().await {
+ for choice in part.choices {
+ let Some(tool_calls) = choice.delta.tool_calls else {
+ continue;
+ };
+
+ for call in tool_calls {
+ if let Some(func) = call.function {
+ if func.name.as_deref() == Some(tool_name.as_str()) {
+ load_state = Some((String::default(), call.index));
+ }
+ if let Some((arguments, (output, index))) =
+ func.arguments.zip(load_state.as_mut())
+ {
+ if call.index == *index {
+ output.push_str(&arguments);
+ }
}
}
}
}
}
- }
- if let Some((arguments, _)) = load_state {
- return Ok(serde_json::from_str(&arguments)?);
- } else {
- bail!("tool not used");
- }
- })
- .boxed()
+ if let Some((arguments, _)) = load_state {
+ return Ok(serde_json::from_str(&arguments)?);
+ } else {
+ bail!("tool not used");
+ }
+ })
+ .boxed()
+ }
}
}
}