Use LLM service for tool call requests (#16046)

Max Brunsfeld and Marshall created

Release Notes:

- N/A

---------

Co-authored-by: Marshall <marshall@zed.dev>

Change summary

crates/collab/src/llm.rs                    |  44 +-
crates/language_model/src/provider/cloud.rs | 413 ++++++++++++++++------
2 files changed, 328 insertions(+), 129 deletions(-)

Detailed changes

crates/collab/src/llm.rs 🔗

@@ -322,25 +322,33 @@ async fn perform_completion(
 }
 
 fn normalize_model_name(provider: LanguageModelProvider, name: String) -> String {
-    match provider {
-        LanguageModelProvider::Anthropic => {
-            for prefix in &[
-                "claude-3-5-sonnet",
-                "claude-3-haiku",
-                "claude-3-opus",
-                "claude-3-sonnet",
-            ] {
-                if name.starts_with(prefix) {
-                    return prefix.to_string();
-                }
-            }
-        }
-        LanguageModelProvider::OpenAi => {}
-        LanguageModelProvider::Google => {}
-        LanguageModelProvider::Zed => {}
-    }
+    let prefixes: &[_] = match provider {
+        LanguageModelProvider::Anthropic => &[
+            "claude-3-5-sonnet",
+            "claude-3-haiku",
+            "claude-3-opus",
+            "claude-3-sonnet",
+        ],
+        LanguageModelProvider::OpenAi => &[
+            "gpt-3.5-turbo",
+            "gpt-4-turbo-preview",
+            "gpt-4o-mini",
+            "gpt-4o",
+            "gpt-4",
+        ],
+        LanguageModelProvider::Google => &[],
+        LanguageModelProvider::Zed => &[],
+    };
 
-    name
+    if let Some(prefix) = prefixes
+        .iter()
+        .filter(|&&prefix| name.starts_with(prefix))
+        .max_by_key(|&&prefix| prefix.len())
+    {
+        prefix.to_string()
+    } else {
+        name
+    }
 }
 
 async fn check_usage_limit(

crates/language_model/src/provider/cloud.rs 🔗

@@ -590,7 +590,7 @@ impl LanguageModel for CloudLanguageModel {
         tool_name: String,
         tool_description: String,
         input_schema: serde_json::Value,
-        _cx: &AsyncAppContext,
+        cx: &AsyncAppContext,
     ) -> BoxFuture<'static, Result<serde_json::Value>> {
         match &self.model {
             CloudModel::Anthropic(model) => {
@@ -605,34 +605,106 @@ impl LanguageModel for CloudLanguageModel {
                     input_schema,
                 }];
 
-                self.request_limiter
-                    .run(async move {
-                        let request = serde_json::to_string(&request)?;
-                        let response = client
-                            .request(proto::CompleteWithLanguageModel {
-                                provider: proto::LanguageModelProvider::Anthropic as i32,
-                                request,
-                            })
+                if cx
+                    .update(|cx| cx.has_flag::<LlmServiceFeatureFlag>())
+                    .unwrap_or(false)
+                {
+                    let llm_api_token = self.llm_api_token.clone();
+                    self.request_limiter
+                        .run(async move {
+                            let response = Self::perform_llm_completion(
+                                client.clone(),
+                                llm_api_token,
+                                PerformCompletionParams {
+                                    provider: client::LanguageModelProvider::Anthropic,
+                                    model: request.model.clone(),
+                                    provider_request: RawValue::from_string(
+                                        serde_json::to_string(&request)?,
+                                    )?,
+                                },
+                            )
                             .await?;
-                        let response: anthropic::Response =
-                            serde_json::from_str(&response.completion)?;
-                        response
-                            .content
-                            .into_iter()
-                            .find_map(|content| {
-                                if let anthropic::Content::ToolUse { name, input, .. } = content {
-                                    if name == tool_name {
-                                        Some(input)
+
+                            let mut tool_use_index = None;
+                            let mut tool_input = String::new();
+                            let mut body = BufReader::new(response.into_body());
+                            let mut line = String::new();
+                            while body.read_line(&mut line).await? > 0 {
+                                let event: anthropic::Event = serde_json::from_str(&line)?;
+                                line.clear();
+
+                                match event {
+                                    anthropic::Event::ContentBlockStart {
+                                        content_block,
+                                        index,
+                                    } => {
+                                        if let anthropic::Content::ToolUse { name, .. } =
+                                            content_block
+                                        {
+                                            if name == tool_name {
+                                                tool_use_index = Some(index);
+                                            }
+                                        }
+                                    }
+                                    anthropic::Event::ContentBlockDelta { index, delta } => {
+                                        match delta {
+                                            anthropic::ContentDelta::TextDelta { .. } => {}
+                                            anthropic::ContentDelta::InputJsonDelta {
+                                                partial_json,
+                                            } => {
+                                                if Some(index) == tool_use_index {
+                                                    tool_input.push_str(&partial_json);
+                                                }
+                                            }
+                                        }
+                                    }
+                                    anthropic::Event::ContentBlockStop { index } => {
+                                        if Some(index) == tool_use_index {
+                                            return Ok(serde_json::from_str(&tool_input)?);
+                                        }
+                                    }
+                                    _ => {}
+                                }
+                            }
+
+                            if tool_use_index.is_some() {
+                                Err(anyhow!("tool content incomplete"))
+                            } else {
+                                Err(anyhow!("tool not used"))
+                            }
+                        })
+                        .boxed()
+                } else {
+                    self.request_limiter
+                        .run(async move {
+                            let request = serde_json::to_string(&request)?;
+                            let response = client
+                                .request(proto::CompleteWithLanguageModel {
+                                    provider: proto::LanguageModelProvider::Anthropic as i32,
+                                    request,
+                                })
+                                .await?;
+                            let response: anthropic::Response =
+                                serde_json::from_str(&response.completion)?;
+                            response
+                                .content
+                                .into_iter()
+                                .find_map(|content| {
+                                    if let anthropic::Content::ToolUse { name, input, .. } = content
+                                    {
+                                        if name == tool_name {
+                                            Some(input)
+                                        } else {
+                                            None
+                                        }
                                     } else {
                                         None
                                     }
-                                } else {
-                                    None
-                                }
-                            })
-                            .context("tool not used")
-                    })
-                    .boxed()
+                                })
+                                .context("tool not used")
+                        })
+                        .boxed()
+                }
             }
             CloudModel::OpenAi(model) => {
                 let mut request = request.into_open_ai(model.id().into());
@@ -650,56 +722,116 @@ impl LanguageModel for CloudLanguageModel {
                 function.description = Some(tool_description);
                 function.parameters = Some(input_schema);
                 request.tools = vec![open_ai::ToolDefinition::Function { function }];
-                self.request_limiter
-                    .run(async move {
-                        let request = serde_json::to_string(&request)?;
-                        let response = client
-                            .request_stream(proto::StreamCompleteWithLanguageModel {
-                                provider: proto::LanguageModelProvider::OpenAi as i32,
-                                request,
-                            })
+
+                if cx
+                    .update(|cx| cx.has_flag::<LlmServiceFeatureFlag>())
+                    .unwrap_or(false)
+                {
+                    let llm_api_token = self.llm_api_token.clone();
+                    self.request_limiter
+                        .run(async move {
+                            let response = Self::perform_llm_completion(
+                                client.clone(),
+                                llm_api_token,
+                                PerformCompletionParams {
+                                    provider: client::LanguageModelProvider::OpenAi,
+                                    model: request.model.clone(),
+                                    provider_request: RawValue::from_string(
+                                        serde_json::to_string(&request)?,
+                                    )?,
+                                },
+                            )
                             .await?;
-                        // Call arguments are gonna be streamed in over multiple chunks.
-                        let mut load_state = None;
-                        let mut response = response.map(
-                            |item: Result<
-                                proto::StreamCompleteWithLanguageModelResponse,
-                                anyhow::Error,
-                            >| {
-                                Result::<open_ai::ResponseStreamEvent, anyhow::Error>::Ok(
-                                    serde_json::from_str(&item?.event)?,
-                                )
-                            },
-                        );
-                        while let Some(Ok(part)) = response.next().await {
-                            for choice in part.choices {
-                                let Some(tool_calls) = choice.delta.tool_calls else {
-                                    continue;
-                                };
-
-                                for call in tool_calls {
-                                    if let Some(func) = call.function {
-                                        if func.name.as_deref() == Some(tool_name.as_str()) {
-                                            load_state = Some((String::default(), call.index));
+
+                            let mut body = BufReader::new(response.into_body());
+                            let mut line = String::new();
+                            let mut load_state = None;
+
+                            while body.read_line(&mut line).await? > 0 {
+                                let part: open_ai::ResponseStreamEvent =
+                                    serde_json::from_str(&line)?;
+                                line.clear();
+
+                                for choice in part.choices {
+                                    let Some(tool_calls) = choice.delta.tool_calls else {
+                                        continue;
+                                    };
+
+                                    for call in tool_calls {
+                                        if let Some(func) = call.function {
+                                            if func.name.as_deref() == Some(tool_name.as_str()) {
+                                                load_state = Some((String::default(), call.index));
+                                            }
+                                            if let Some((arguments, (output, index))) =
+                                                func.arguments.zip(load_state.as_mut())
+                                            {
+                                                if call.index == *index {
+                                                    output.push_str(&arguments);
+                                                }
+                                            }
                                         }
-                                        if let Some((arguments, (output, index))) =
-                                            func.arguments.zip(load_state.as_mut())
-                                        {
-                                            if call.index == *index {
-                                                output.push_str(&arguments);
+                                    }
+                                }
+                            }
+
+                            if let Some((arguments, _)) = load_state {
+                                return Ok(serde_json::from_str(&arguments)?);
+                            } else {
+                                bail!("tool not used");
+                            }
+                        })
+                        .boxed()
+                } else {
+                    self.request_limiter
+                        .run(async move {
+                            let request = serde_json::to_string(&request)?;
+                            let response = client
+                                .request_stream(proto::StreamCompleteWithLanguageModel {
+                                    provider: proto::LanguageModelProvider::OpenAi as i32,
+                                    request,
+                                })
+                                .await?;
+                            let mut load_state = None;
+                            let mut response = response.map(
+                                |item: Result<
+                                    proto::StreamCompleteWithLanguageModelResponse,
+                                    anyhow::Error,
+                                >| {
+                                    Result::<open_ai::ResponseStreamEvent, anyhow::Error>::Ok(
+                                        serde_json::from_str(&item?.event)?,
+                                    )
+                                },
+                            );
+                            while let Some(Ok(part)) = response.next().await {
+                                for choice in part.choices {
+                                    let Some(tool_calls) = choice.delta.tool_calls else {
+                                        continue;
+                                    };
+
+                                    for call in tool_calls {
+                                        if let Some(func) = call.function {
+                                            if func.name.as_deref() == Some(tool_name.as_str()) {
+                                                load_state = Some((String::default(), call.index));
+                                            }
+                                            if let Some((arguments, (output, index))) =
+                                                func.arguments.zip(load_state.as_mut())
+                                            {
+                                                if call.index == *index {
+                                                    output.push_str(&arguments);
+                                                }
                                             }
                                         }
                                     }
                                 }
                             }
-                        }
-                        if let Some((arguments, _)) = load_state {
-                            return Ok(serde_json::from_str(&arguments)?);
-                        } else {
-                            bail!("tool not used");
-                        }
-                    })
-                    .boxed()
+                            if let Some((arguments, _)) = load_state {
+                                return Ok(serde_json::from_str(&arguments)?);
+                            } else {
+                                bail!("tool not used");
+                            }
+                        })
+                        .boxed()
+                }
             }
             CloudModel::Google(_) => {
                 future::ready(Err(anyhow!("tool use not implemented for Google AI"))).boxed()
@@ -721,56 +853,115 @@ impl LanguageModel for CloudLanguageModel {
                 function.description = Some(tool_description);
                 function.parameters = Some(input_schema);
                 request.tools = vec![open_ai::ToolDefinition::Function { function }];
-                self.request_limiter
-                    .run(async move {
-                        let request = serde_json::to_string(&request)?;
-                        let response = client
-                            .request_stream(proto::StreamCompleteWithLanguageModel {
-                                provider: proto::LanguageModelProvider::OpenAi as i32,
-                                request,
-                            })
+
+                if cx
+                    .update(|cx| cx.has_flag::<LlmServiceFeatureFlag>())
+                    .unwrap_or(false)
+                {
+                    let llm_api_token = self.llm_api_token.clone();
+                    self.request_limiter
+                        .run(async move {
+                            let response = Self::perform_llm_completion(
+                                client.clone(),
+                                llm_api_token,
+                                PerformCompletionParams {
+                                    provider: client::LanguageModelProvider::Zed,
+                                    model: request.model.clone(),
+                                    provider_request: RawValue::from_string(
+                                        serde_json::to_string(&request)?,
+                                    )?,
+                                },
+                            )
                             .await?;
-                        // Call arguments are gonna be streamed in over multiple chunks.
-                        let mut load_state = None;
-                        let mut response = response.map(
-                            |item: Result<
-                                proto::StreamCompleteWithLanguageModelResponse,
-                                anyhow::Error,
-                            >| {
-                                Result::<open_ai::ResponseStreamEvent, anyhow::Error>::Ok(
-                                    serde_json::from_str(&item?.event)?,
-                                )
-                            },
-                        );
-                        while let Some(Ok(part)) = response.next().await {
-                            for choice in part.choices {
-                                let Some(tool_calls) = choice.delta.tool_calls else {
-                                    continue;
-                                };
-
-                                for call in tool_calls {
-                                    if let Some(func) = call.function {
-                                        if func.name.as_deref() == Some(tool_name.as_str()) {
-                                            load_state = Some((String::default(), call.index));
+
+                            let mut body = BufReader::new(response.into_body());
+                            let mut line = String::new();
+                            let mut load_state = None;
+
+                            while body.read_line(&mut line).await? > 0 {
+                                let part: open_ai::ResponseStreamEvent =
+                                    serde_json::from_str(&line)?;
+                                line.clear();
+
+                                for choice in part.choices {
+                                    let Some(tool_calls) = choice.delta.tool_calls else {
+                                        continue;
+                                    };
+
+                                    for call in tool_calls {
+                                        if let Some(func) = call.function {
+                                            if func.name.as_deref() == Some(tool_name.as_str()) {
+                                                load_state = Some((String::default(), call.index));
+                                            }
+                                            if let Some((arguments, (output, index))) =
+                                                func.arguments.zip(load_state.as_mut())
+                                            {
+                                                if call.index == *index {
+                                                    output.push_str(&arguments);
+                                                }
+                                            }
                                         }
-                                        if let Some((arguments, (output, index))) =
-                                            func.arguments.zip(load_state.as_mut())
-                                        {
-                                            if call.index == *index {
-                                                output.push_str(&arguments);
+                                    }
+                                }
+                            }
+                            if let Some((arguments, _)) = load_state {
+                                return Ok(serde_json::from_str(&arguments)?);
+                            } else {
+                                bail!("tool not used");
+                            }
+                        })
+                        .boxed()
+                } else {
+                    self.request_limiter
+                        .run(async move {
+                            let request = serde_json::to_string(&request)?;
+                            let response = client
+                                .request_stream(proto::StreamCompleteWithLanguageModel {
+                                    provider: proto::LanguageModelProvider::OpenAi as i32,
+                                    request,
+                                })
+                                .await?;
+                            let mut load_state = None;
+                            let mut response = response.map(
+                                |item: Result<
+                                    proto::StreamCompleteWithLanguageModelResponse,
+                                    anyhow::Error,
+                                >| {
+                                    Result::<open_ai::ResponseStreamEvent, anyhow::Error>::Ok(
+                                        serde_json::from_str(&item?.event)?,
+                                    )
+                                },
+                            );
+                            while let Some(Ok(part)) = response.next().await {
+                                for choice in part.choices {
+                                    let Some(tool_calls) = choice.delta.tool_calls else {
+                                        continue;
+                                    };
+
+                                    for call in tool_calls {
+                                        if let Some(func) = call.function {
+                                            if func.name.as_deref() == Some(tool_name.as_str()) {
+                                                load_state = Some((String::default(), call.index));
+                                            }
+                                            if let Some((arguments, (output, index))) =
+                                                func.arguments.zip(load_state.as_mut())
+                                            {
+                                                if call.index == *index {
+                                                    output.push_str(&arguments);
+                                                }
                                             }
                                         }
                                     }
                                 }
                             }
-                        }
-                        if let Some((arguments, _)) = load_state {
-                            return Ok(serde_json::from_str(&arguments)?);
-                        } else {
-                            bail!("tool not used");
-                        }
-                    })
-                    .boxed()
+                            if let Some((arguments, _)) = load_state {
+                                return Ok(serde_json::from_str(&arguments)?);
+                            } else {
+                                bail!("tool not used");
+                            }
+                        })
+                        .boxed()
+                }
             }
         }
     }