crates/language_model/src/provider/cloud.rs
@@ -4,10 +4,10 @@ use crate::{
LanguageModelName, LanguageModelProviderId, LanguageModelProviderName,
LanguageModelProviderState, LanguageModelRequest, RateLimiter, ZedModel,
};
-use anyhow::{anyhow, bail, Context as _, Result};
+use anyhow::{anyhow, bail, Result};
use client::{Client, PerformCompletionParams, UserStore, EXPIRED_LLM_TOKEN_HEADER_NAME};
use collections::BTreeMap;
-use feature_flags::{FeatureFlag, FeatureFlagAppExt, LanguageModels};
+use feature_flags::{FeatureFlagAppExt, LanguageModels};
use futures::{future::BoxFuture, stream::BoxStream, AsyncBufReadExt, FutureExt, StreamExt};
use gpui::{AnyView, AppContext, AsyncAppContext, Model, ModelContext, Subscription, Task};
use http_client::{AsyncBody, HttpClient, Method, Response};
@@ -228,16 +228,6 @@ impl LanguageModelProvider for CloudLanguageModelProvider {
}
}
-struct LlmServiceFeatureFlag;
-
-impl FeatureFlag for LlmServiceFeatureFlag {
- const NAME: &'static str = "llm-service";
-
- fn enabled_for_staff() -> bool {
- false
- }
-}
-
pub struct CloudLanguageModel {
id: LanguageModelId,
model: CloudModel,
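
For reference, the gate this hunk removes looked like the sketch below when paired with its call sites: a minimal reconstruction from the removed lines, assuming only the `feature_flags` and `gpui` items this file already imports.

use feature_flags::{FeatureFlag, FeatureFlagAppExt};
use gpui::AsyncAppContext;

struct LlmServiceFeatureFlag;

impl FeatureFlag for LlmServiceFeatureFlag {
    const NAME: &'static str = "llm-service";

    // Staff did not get this flag implicitly; it had to be set per user.
    fn enabled_for_staff() -> bool {
        false
    }
}

// How every branch below used to check it: a failed `update` (app gone)
// counted as "flag off", falling back to the proto-based RPC path.
fn llm_service_enabled(cx: &AsyncAppContext) -> bool {
    cx.update(|cx| cx.has_flag::<LlmServiceFeatureFlag>())
        .unwrap_or(false)
}

With the flag gone, the LLM-service path is unconditional and the old `proto::StreamCompleteWithLanguageModel` fallback disappears.
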
@@ -354,232 +344,148 @@ impl LanguageModel for CloudLanguageModel {
fn stream_completion(
&self,
request: LanguageModelRequest,
- cx: &AsyncAppContext,
+ _cx: &AsyncAppContext,
) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>> {
match &self.model {
CloudModel::Anthropic(model) => {
let request = request.into_anthropic(model.id().into());
let client = self.client.clone();
-
- if cx
- .update(|cx| cx.has_flag::<LlmServiceFeatureFlag>())
- .unwrap_or(false)
- {
- let llm_api_token = self.llm_api_token.clone();
- let future = self.request_limiter.stream(async move {
- let response = Self::perform_llm_completion(
- client.clone(),
- llm_api_token,
- PerformCompletionParams {
- provider: client::LanguageModelProvider::Anthropic,
- model: request.model.clone(),
- provider_request: RawValue::from_string(serde_json::to_string(
- &request,
- )?)?,
- },
- )
- .await?;
- let body = BufReader::new(response.into_body());
- let stream =
- futures::stream::try_unfold(body, move |mut body| async move {
- let mut buffer = String::new();
- match body.read_line(&mut buffer).await {
- Ok(0) => Ok(None),
- Ok(_) => {
- let event: anthropic::Event =
- serde_json::from_str(&buffer)?;
- Ok(Some((event, body)))
- }
- Err(e) => Err(e.into()),
- }
- });
-
- Ok(anthropic::extract_text_from_events(stream))
- });
- async move { Ok(future.await?.boxed()) }.boxed()
- } else {
- let future = self.request_limiter.stream(async move {
- let request = serde_json::to_string(&request)?;
- let stream = client
- .request_stream(proto::StreamCompleteWithLanguageModel {
- provider: proto::LanguageModelProvider::Anthropic as i32,
- request,
- })
- .await?
- .map(|event| Ok(serde_json::from_str(&event?.event)?));
- Ok(anthropic::extract_text_from_events(stream))
+ let llm_api_token = self.llm_api_token.clone();
+ let future = self.request_limiter.stream(async move {
+ let response = Self::perform_llm_completion(
+ client.clone(),
+ llm_api_token,
+ PerformCompletionParams {
+ provider: client::LanguageModelProvider::Anthropic,
+ model: request.model.clone(),
+ provider_request: RawValue::from_string(serde_json::to_string(
+ &request,
+ )?)?,
+ },
+ )
+ .await?;
+ let body = BufReader::new(response.into_body());
+ let stream = futures::stream::try_unfold(body, move |mut body| async move {
+ let mut buffer = String::new();
+ match body.read_line(&mut buffer).await {
+ Ok(0) => Ok(None),
+ Ok(_) => {
+ let event: anthropic::Event = serde_json::from_str(&buffer)?;
+ Ok(Some((event, body)))
+ }
+ Err(e) => Err(e.into()),
+ }
});
- async move { Ok(future.await?.boxed()) }.boxed()
- }
+
+ Ok(anthropic::extract_text_from_events(stream))
+ });
+ async move { Ok(future.await?.boxed()) }.boxed()
}
CloudModel::OpenAi(model) => {
let client = self.client.clone();
let request = request.into_open_ai(model.id().into());
-
- if cx
- .update(|cx| cx.has_flag::<LlmServiceFeatureFlag>())
- .unwrap_or(false)
- {
- let llm_api_token = self.llm_api_token.clone();
- let future = self.request_limiter.stream(async move {
- let response = Self::perform_llm_completion(
- client.clone(),
- llm_api_token,
- PerformCompletionParams {
- provider: client::LanguageModelProvider::OpenAi,
- model: request.model.clone(),
- provider_request: RawValue::from_string(serde_json::to_string(
- &request,
- )?)?,
- },
- )
- .await?;
- let body = BufReader::new(response.into_body());
- let stream =
- futures::stream::try_unfold(body, move |mut body| async move {
- let mut buffer = String::new();
- match body.read_line(&mut buffer).await {
- Ok(0) => Ok(None),
- Ok(_) => {
- let event: open_ai::ResponseStreamEvent =
- serde_json::from_str(&buffer)?;
- Ok(Some((event, body)))
- }
- Err(e) => Err(e.into()),
- }
- });
-
- Ok(open_ai::extract_text_from_events(stream))
- });
- async move { Ok(future.await?.boxed()) }.boxed()
- } else {
- let future = self.request_limiter.stream(async move {
- let request = serde_json::to_string(&request)?;
- let stream = client
- .request_stream(proto::StreamCompleteWithLanguageModel {
- provider: proto::LanguageModelProvider::OpenAi as i32,
- request,
- })
- .await?;
- Ok(open_ai::extract_text_from_events(
- stream.map(|item| Ok(serde_json::from_str(&item?.event)?)),
- ))
+ let llm_api_token = self.llm_api_token.clone();
+ let future = self.request_limiter.stream(async move {
+ let response = Self::perform_llm_completion(
+ client.clone(),
+ llm_api_token,
+ PerformCompletionParams {
+ provider: client::LanguageModelProvider::OpenAi,
+ model: request.model.clone(),
+ provider_request: RawValue::from_string(serde_json::to_string(
+ &request,
+ )?)?,
+ },
+ )
+ .await?;
+ let body = BufReader::new(response.into_body());
+ let stream = futures::stream::try_unfold(body, move |mut body| async move {
+ let mut buffer = String::new();
+ match body.read_line(&mut buffer).await {
+ Ok(0) => Ok(None),
+ Ok(_) => {
+ let event: open_ai::ResponseStreamEvent =
+ serde_json::from_str(&buffer)?;
+ Ok(Some((event, body)))
+ }
+ Err(e) => Err(e.into()),
+ }
});
- async move { Ok(future.await?.boxed()) }.boxed()
- }
+
+ Ok(open_ai::extract_text_from_events(stream))
+ });
+ async move { Ok(future.await?.boxed()) }.boxed()
}
CloudModel::Google(model) => {
let client = self.client.clone();
let request = request.into_google(model.id().into());
-
- if cx
- .update(|cx| cx.has_flag::<LlmServiceFeatureFlag>())
- .unwrap_or(false)
- {
- let llm_api_token = self.llm_api_token.clone();
- let future = self.request_limiter.stream(async move {
- let response = Self::perform_llm_completion(
- client.clone(),
- llm_api_token,
- PerformCompletionParams {
- provider: client::LanguageModelProvider::Google,
- model: request.model.clone(),
- provider_request: RawValue::from_string(serde_json::to_string(
- &request,
- )?)?,
- },
- )
- .await?;
- let body = BufReader::new(response.into_body());
- let stream =
- futures::stream::try_unfold(body, move |mut body| async move {
- let mut buffer = String::new();
- match body.read_line(&mut buffer).await {
- Ok(0) => Ok(None),
- Ok(_) => {
- let event: google_ai::GenerateContentResponse =
- serde_json::from_str(&buffer)?;
- Ok(Some((event, body)))
- }
- Err(e) => Err(e.into()),
- }
- });
-
- Ok(google_ai::extract_text_from_events(stream))
- });
- async move { Ok(future.await?.boxed()) }.boxed()
- } else {
- let future = self.request_limiter.stream(async move {
- let request = serde_json::to_string(&request)?;
- let stream = client
- .request_stream(proto::StreamCompleteWithLanguageModel {
- provider: proto::LanguageModelProvider::Google as i32,
- request,
- })
- .await?;
- Ok(google_ai::extract_text_from_events(
- stream.map(|item| Ok(serde_json::from_str(&item?.event)?)),
- ))
+ let llm_api_token = self.llm_api_token.clone();
+ let future = self.request_limiter.stream(async move {
+ let response = Self::perform_llm_completion(
+ client.clone(),
+ llm_api_token,
+ PerformCompletionParams {
+ provider: client::LanguageModelProvider::Google,
+ model: request.model.clone(),
+ provider_request: RawValue::from_string(serde_json::to_string(
+ &request,
+ )?)?,
+ },
+ )
+ .await?;
+ let body = BufReader::new(response.into_body());
+ let stream = futures::stream::try_unfold(body, move |mut body| async move {
+ let mut buffer = String::new();
+ match body.read_line(&mut buffer).await {
+ Ok(0) => Ok(None),
+ Ok(_) => {
+ let event: google_ai::GenerateContentResponse =
+ serde_json::from_str(&buffer)?;
+ Ok(Some((event, body)))
+ }
+ Err(e) => Err(e.into()),
+ }
});
- async move { Ok(future.await?.boxed()) }.boxed()
- }
+
+ Ok(google_ai::extract_text_from_events(stream))
+ });
+ async move { Ok(future.await?.boxed()) }.boxed()
}
CloudModel::Zed(model) => {
let client = self.client.clone();
let mut request = request.into_open_ai(model.id().into());
request.max_tokens = Some(4000);
-
- if cx
- .update(|cx| cx.has_flag::<LlmServiceFeatureFlag>())
- .unwrap_or(false)
- {
- let llm_api_token = self.llm_api_token.clone();
- let future = self.request_limiter.stream(async move {
- let response = Self::perform_llm_completion(
- client.clone(),
- llm_api_token,
- PerformCompletionParams {
- provider: client::LanguageModelProvider::Zed,
- model: request.model.clone(),
- provider_request: RawValue::from_string(serde_json::to_string(
- &request,
- )?)?,
- },
- )
- .await?;
- let body = BufReader::new(response.into_body());
- let stream =
- futures::stream::try_unfold(body, move |mut body| async move {
- let mut buffer = String::new();
- match body.read_line(&mut buffer).await {
- Ok(0) => Ok(None),
- Ok(_) => {
- let event: open_ai::ResponseStreamEvent =
- serde_json::from_str(&buffer)?;
- Ok(Some((event, body)))
- }
- Err(e) => Err(e.into()),
- }
- });
-
- Ok(open_ai::extract_text_from_events(stream))
- });
- async move { Ok(future.await?.boxed()) }.boxed()
- } else {
- let future = self.request_limiter.stream(async move {
- let request = serde_json::to_string(&request)?;
- let stream = client
- .request_stream(proto::StreamCompleteWithLanguageModel {
- provider: proto::LanguageModelProvider::Zed as i32,
- request,
- })
- .await?;
- Ok(open_ai::extract_text_from_events(
- stream.map(|item| Ok(serde_json::from_str(&item?.event)?)),
- ))
+ let llm_api_token = self.llm_api_token.clone();
+ let future = self.request_limiter.stream(async move {
+ let response = Self::perform_llm_completion(
+ client.clone(),
+ llm_api_token,
+ PerformCompletionParams {
+ provider: client::LanguageModelProvider::Zed,
+ model: request.model.clone(),
+ provider_request: RawValue::from_string(serde_json::to_string(
+ &request,
+ )?)?,
+ },
+ )
+ .await?;
+ let body = BufReader::new(response.into_body());
+ let stream = futures::stream::try_unfold(body, move |mut body| async move {
+ let mut buffer = String::new();
+ match body.read_line(&mut buffer).await {
+ Ok(0) => Ok(None),
+ Ok(_) => {
+ let event: open_ai::ResponseStreamEvent =
+ serde_json::from_str(&buffer)?;
+ Ok(Some((event, body)))
+ }
+ Err(e) => Err(e.into()),
+ }
});
- async move { Ok(future.await?.boxed()) }.boxed()
- }
+
+ Ok(open_ai::extract_text_from_events(stream))
+ });
+ async move { Ok(future.await?.boxed()) }.boxed()
}
}
}
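
All four provider arms now share the same shape: POST the completion, wrap the response body in a `BufReader`, and turn "read a line, parse an event" into a stream with `futures::stream::try_unfold`. A standalone sketch of that pattern, generic over the event type (the helper and its name are illustrative, not part of this change):

use anyhow::Result;
use futures::io::BufReader;
use futures::{stream::BoxStream, AsyncBufReadExt, AsyncRead, StreamExt};
use serde::de::DeserializeOwned;

/// Yields one deserialized event per newline-delimited JSON line in `body`.
fn events_from_body<R, Event>(body: R) -> BoxStream<'static, Result<Event>>
where
    R: AsyncRead + Send + Unpin + 'static,
    Event: DeserializeOwned + Send + 'static,
{
    let body = BufReader::new(body);
    futures::stream::try_unfold(body, |mut body| async move {
        let mut buffer = String::new();
        match body.read_line(&mut buffer).await {
            Ok(0) => Ok(None), // EOF cleanly ends the stream
            Ok(_) => {
                let event: Event = serde_json::from_str(&buffer)?;
                Ok(Some((event, body)))
            }
            Err(e) => Err(e.into()),
        }
    })
    .boxed()
}

`try_unfold` threads the reader through as loop state, so the resulting stream is `'static` and can be handed to `anthropic::extract_text_from_events`, `open_ai::extract_text_from_events`, or `google_ai::extract_text_from_events` without borrowing the response.
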
@@ -590,7 +496,7 @@ impl LanguageModel for CloudLanguageModel {
tool_name: String,
tool_description: String,
input_schema: serde_json::Value,
- cx: &AsyncAppContext,
+ _cx: &AsyncAppContext,
) -> BoxFuture<'static, Result<serde_json::Value>> {
match &self.model {
CloudModel::Anthropic(model) => {
@@ -605,106 +511,67 @@ impl LanguageModel for CloudLanguageModel {
input_schema,
}];
- if cx
- .update(|cx| cx.has_flag::<LlmServiceFeatureFlag>())
- .unwrap_or(false)
- {
- let llm_api_token = self.llm_api_token.clone();
- self.request_limiter
- .run(async move {
- let response = Self::perform_llm_completion(
- client.clone(),
- llm_api_token,
- PerformCompletionParams {
- provider: client::LanguageModelProvider::Anthropic,
- model: request.model.clone(),
- provider_request: RawValue::from_string(
- serde_json::to_string(&request)?,
- )?,
- },
- )
- .await?;
-
- let mut tool_use_index = None;
- let mut tool_input = String::new();
- let mut body = BufReader::new(response.into_body());
- let mut line = String::new();
- while body.read_line(&mut line).await? > 0 {
- let event: anthropic::Event = serde_json::from_str(&line)?;
- line.clear();
-
- match event {
- anthropic::Event::ContentBlockStart {
- content_block,
- index,
- } => {
- if let anthropic::Content::ToolUse { name, .. } =
- content_block
- {
- if name == tool_name {
- tool_use_index = Some(index);
- }
- }
- }
- anthropic::Event::ContentBlockDelta { index, delta } => {
- match delta {
- anthropic::ContentDelta::TextDelta { .. } => {}
- anthropic::ContentDelta::InputJsonDelta {
- partial_json,
- } => {
- if Some(index) == tool_use_index {
- tool_input.push_str(&partial_json);
- }
- }
+ let llm_api_token = self.llm_api_token.clone();
+ self.request_limiter
+ .run(async move {
+ let response = Self::perform_llm_completion(
+ client.clone(),
+ llm_api_token,
+ PerformCompletionParams {
+ provider: client::LanguageModelProvider::Anthropic,
+ model: request.model.clone(),
+ provider_request: RawValue::from_string(serde_json::to_string(
+ &request,
+ )?)?,
+ },
+ )
+ .await?;
+
+ let mut tool_use_index = None;
+ let mut tool_input = String::new();
+ let mut body = BufReader::new(response.into_body());
+ let mut line = String::new();
+ while body.read_line(&mut line).await? > 0 {
+ let event: anthropic::Event = serde_json::from_str(&line)?;
+ line.clear();
+
+ match event {
+ anthropic::Event::ContentBlockStart {
+ content_block,
+ index,
+ } => {
+ if let anthropic::Content::ToolUse { name, .. } = content_block
+ {
+ if name == tool_name {
+ tool_use_index = Some(index);
}
}
- anthropic::Event::ContentBlockStop { index } => {
+ }
+ anthropic::Event::ContentBlockDelta { index, delta } => match delta
+ {
+ anthropic::ContentDelta::TextDelta { .. } => {}
+ anthropic::ContentDelta::InputJsonDelta { partial_json } => {
if Some(index) == tool_use_index {
- return Ok(serde_json::from_str(&tool_input)?);
+ tool_input.push_str(&partial_json);
}
}
- _ => {}
+ },
+ anthropic::Event::ContentBlockStop { index } => {
+ if Some(index) == tool_use_index {
+ return Ok(serde_json::from_str(&tool_input)?);
+ }
}
+ _ => {}
}
+ }
- if tool_use_index.is_some() {
- Err(anyhow!("tool content incomplete"))
- } else {
- Err(anyhow!("tool not used"))
- }
- })
- .boxed()
- } else {
- self.request_limiter
- .run(async move {
- let request = serde_json::to_string(&request)?;
- let response = client
- .request(proto::CompleteWithLanguageModel {
- provider: proto::LanguageModelProvider::Anthropic as i32,
- request,
- })
- .await?;
- let response: anthropic::Response =
- serde_json::from_str(&response.completion)?;
- response
- .content
- .into_iter()
- .find_map(|content| {
- if let anthropic::Content::ToolUse { name, input, .. } = content
- {
- if name == tool_name {
- Some(input)
- } else {
- None
- }
- } else {
- None
- }
- })
- .context("tool not used")
- })
- .boxed()
- }
+ if tool_use_index.is_some() {
+ Err(anyhow!("tool content incomplete"))
+ } else {
+ Err(anyhow!("tool not used"))
+ }
+ })
+ .boxed()
}
CloudModel::OpenAi(model) => {
let mut request = request.into_open_ai(model.id().into());
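
The Anthropic tool-use branch above is a small state machine over streamed events: note the block index when a `ToolUse` block with the right name starts, concatenate its `InputJsonDelta` fragments, and parse the result when the block stops. The same logic over an in-memory event sequence, using the `anthropic` types exactly as they appear in this hunk (the free function is illustrative):

use anyhow::{anyhow, Result};

fn collect_tool_input(
    events: impl IntoIterator<Item = anthropic::Event>,
    tool_name: &str,
) -> Result<serde_json::Value> {
    let mut tool_use_index = None;
    let mut tool_input = String::new();

    for event in events {
        match event {
            // Remember which content block carries our tool's input.
            anthropic::Event::ContentBlockStart {
                content_block,
                index,
            } => {
                if let anthropic::Content::ToolUse { name, .. } = content_block {
                    if name == tool_name {
                        tool_use_index = Some(index);
                    }
                }
            }
            // Input arrives as partial JSON text; accumulate it verbatim.
            anthropic::Event::ContentBlockDelta { index, delta } => {
                if let anthropic::ContentDelta::InputJsonDelta { partial_json } = delta {
                    if Some(index) == tool_use_index {
                        tool_input.push_str(&partial_json);
                    }
                }
            }
            // Block complete: the accumulated fragments form one JSON value.
            anthropic::Event::ContentBlockStop { index } => {
                if Some(index) == tool_use_index {
                    return Ok(serde_json::from_str(&tool_input)?);
                }
            }
            _ => {}
        }
    }

    if tool_use_index.is_some() {
        Err(anyhow!("tool content incomplete"))
    } else {
        Err(anyhow!("tool not used"))
    }
}
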
@@ -723,115 +590,59 @@ impl LanguageModel for CloudLanguageModel {
function.parameters = Some(input_schema);
request.tools = vec![open_ai::ToolDefinition::Function { function }];
- if cx
- .update(|cx| cx.has_flag::<LlmServiceFeatureFlag>())
- .unwrap_or(false)
- {
- let llm_api_token = self.llm_api_token.clone();
- self.request_limiter
- .run(async move {
- let response = Self::perform_llm_completion(
- client.clone(),
- llm_api_token,
- PerformCompletionParams {
- provider: client::LanguageModelProvider::OpenAi,
- model: request.model.clone(),
- provider_request: RawValue::from_string(
- serde_json::to_string(&request)?,
- )?,
- },
- )
- .await?;
-
- let mut body = BufReader::new(response.into_body());
- let mut line = String::new();
- let mut load_state = None;
-
- while body.read_line(&mut line).await? > 0 {
- let part: open_ai::ResponseStreamEvent =
- serde_json::from_str(&line)?;
- line.clear();
-
- for choice in part.choices {
- let Some(tool_calls) = choice.delta.tool_calls else {
- continue;
- };
-
- for call in tool_calls {
- if let Some(func) = call.function {
- if func.name.as_deref() == Some(tool_name.as_str()) {
- load_state = Some((String::default(), call.index));
- }
- if let Some((arguments, (output, index))) =
- func.arguments.zip(load_state.as_mut())
- {
- if call.index == *index {
- output.push_str(&arguments);
- }
- }
- }
- }
- }
- }
+ let llm_api_token = self.llm_api_token.clone();
+ self.request_limiter
+ .run(async move {
+ let response = Self::perform_llm_completion(
+ client.clone(),
+ llm_api_token,
+ PerformCompletionParams {
+ provider: client::LanguageModelProvider::OpenAi,
+ model: request.model.clone(),
+ provider_request: RawValue::from_string(serde_json::to_string(
+ &request,
+ )?)?,
+ },
+ )
+ .await?;
- if let Some((arguments, _)) = load_state {
- return Ok(serde_json::from_str(&arguments)?);
- } else {
- bail!("tool not used");
- }
- })
- .boxed()
- } else {
- self.request_limiter
- .run(async move {
- let request = serde_json::to_string(&request)?;
- let response = client
- .request_stream(proto::StreamCompleteWithLanguageModel {
- provider: proto::LanguageModelProvider::OpenAi as i32,
- request,
- })
- .await?;
- let mut load_state = None;
- let mut response = response.map(
- |item: Result<
- proto::StreamCompleteWithLanguageModelResponse,
- anyhow::Error,
- >| {
- Result::<open_ai::ResponseStreamEvent, anyhow::Error>::Ok(
- serde_json::from_str(&item?.event)?,
- )
- },
- );
- while let Some(Ok(part)) = response.next().await {
- for choice in part.choices {
- let Some(tool_calls) = choice.delta.tool_calls else {
- continue;
- };
-
- for call in tool_calls {
- if let Some(func) = call.function {
- if func.name.as_deref() == Some(tool_name.as_str()) {
- load_state = Some((String::default(), call.index));
- }
- if let Some((arguments, (output, index))) =
- func.arguments.zip(load_state.as_mut())
- {
- if call.index == *index {
- output.push_str(&arguments);
- }
+ let mut body = BufReader::new(response.into_body());
+ let mut line = String::new();
+ let mut load_state = None;
+
+ while body.read_line(&mut line).await? > 0 {
+ let part: open_ai::ResponseStreamEvent = serde_json::from_str(&line)?;
+ line.clear();
+
+ for choice in part.choices {
+ let Some(tool_calls) = choice.delta.tool_calls else {
+ continue;
+ };
+
+ for call in tool_calls {
+ if let Some(func) = call.function {
+ if func.name.as_deref() == Some(tool_name.as_str()) {
+ load_state = Some((String::default(), call.index));
+ }
+ if let Some((arguments, (output, index))) =
+ func.arguments.zip(load_state.as_mut())
+ {
+ if call.index == *index {
+ output.push_str(&arguments);
}
}
}
}
}
- if let Some((arguments, _)) = load_state {
- return Ok(serde_json::from_str(&arguments)?);
- } else {
- bail!("tool not used");
- }
- })
- .boxed()
- }
+ }
+
+ if let Some((arguments, _)) = load_state {
+ return Ok(serde_json::from_str(&arguments)?);
+ } else {
+ bail!("tool not used");
+ }
+ })
+ .boxed()
}
CloudModel::Google(_) => {
future::ready(Err(anyhow!("tool use not implemented for Google AI"))).boxed()
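
The OpenAI arm above (and the Zed arm in the next hunk) does the equivalent accumulation, keyed by tool-call index instead of content-block index: the fragment that names the function also fixes the index, and later fragments contribute only argument text. Sketched over a sequence of `open_ai::ResponseStreamEvent`s, mirroring the loop in this hunk (the free function is illustrative):

use anyhow::{bail, Result};

fn collect_tool_arguments(
    parts: impl IntoIterator<Item = open_ai::ResponseStreamEvent>,
    tool_name: &str,
) -> Result<serde_json::Value> {
    // (accumulated argument text, index of the matching tool call)
    let mut load_state = None;

    for part in parts {
        for choice in part.choices {
            let Some(tool_calls) = choice.delta.tool_calls else {
                continue;
            };
            for call in tool_calls {
                if let Some(func) = call.function {
                    if func.name.as_deref() == Some(tool_name) {
                        load_state = Some((String::default(), call.index));
                    }
                    // Append argument fragments for the call we latched onto.
                    if let Some((arguments, (output, index))) =
                        func.arguments.zip(load_state.as_mut())
                    {
                        if call.index == *index {
                            output.push_str(&arguments);
                        }
                    }
                }
            }
        }
    }

    if let Some((arguments, _)) = load_state {
        Ok(serde_json::from_str(&arguments)?)
    } else {
        bail!("tool not used")
    }
}
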
@@ -854,114 +665,58 @@ impl LanguageModel for CloudLanguageModel {
function.parameters = Some(input_schema);
request.tools = vec![open_ai::ToolDefinition::Function { function }];
- if cx
- .update(|cx| cx.has_flag::<LlmServiceFeatureFlag>())
- .unwrap_or(false)
- {
- let llm_api_token = self.llm_api_token.clone();
- self.request_limiter
- .run(async move {
- let response = Self::perform_llm_completion(
- client.clone(),
- llm_api_token,
- PerformCompletionParams {
- provider: client::LanguageModelProvider::Zed,
- model: request.model.clone(),
- provider_request: RawValue::from_string(
- serde_json::to_string(&request)?,
- )?,
- },
- )
- .await?;
-
- let mut body = BufReader::new(response.into_body());
- let mut line = String::new();
- let mut load_state = None;
-
- while body.read_line(&mut line).await? > 0 {
- let part: open_ai::ResponseStreamEvent =
- serde_json::from_str(&line)?;
- line.clear();
-
- for choice in part.choices {
- let Some(tool_calls) = choice.delta.tool_calls else {
- continue;
- };
-
- for call in tool_calls {
- if let Some(func) = call.function {
- if func.name.as_deref() == Some(tool_name.as_str()) {
- load_state = Some((String::default(), call.index));
- }
- if let Some((arguments, (output, index))) =
- func.arguments.zip(load_state.as_mut())
- {
- if call.index == *index {
- output.push_str(&arguments);
- }
- }
+ let llm_api_token = self.llm_api_token.clone();
+ self.request_limiter
+ .run(async move {
+ let response = Self::perform_llm_completion(
+ client.clone(),
+ llm_api_token,
+ PerformCompletionParams {
+ provider: client::LanguageModelProvider::Zed,
+ model: request.model.clone(),
+ provider_request: RawValue::from_string(serde_json::to_string(
+ &request,
+ )?)?,
+ },
+ )
+ .await?;
+
+ let mut body = BufReader::new(response.into_body());
+ let mut line = String::new();
+ let mut load_state = None;
+
+ while body.read_line(&mut line).await? > 0 {
+ let part: open_ai::ResponseStreamEvent = serde_json::from_str(&line)?;
+ line.clear();
+
+ for choice in part.choices {
+ let Some(tool_calls) = choice.delta.tool_calls else {
+ continue;
+ };
+
+ for call in tool_calls {
+ if let Some(func) = call.function {
+ if func.name.as_deref() == Some(tool_name.as_str()) {
+ load_state = Some((String::default(), call.index));
}
- }
- }
- }
- if let Some((arguments, _)) = load_state {
- return Ok(serde_json::from_str(&arguments)?);
- } else {
- bail!("tool not used");
- }
- })
- .boxed()
- } else {
- self.request_limiter
- .run(async move {
- let request = serde_json::to_string(&request)?;
- let response = client
- .request_stream(proto::StreamCompleteWithLanguageModel {
- provider: proto::LanguageModelProvider::OpenAi as i32,
- request,
- })
- .await?;
- let mut load_state = None;
- let mut response = response.map(
- |item: Result<
- proto::StreamCompleteWithLanguageModelResponse,
- anyhow::Error,
- >| {
- Result::<open_ai::ResponseStreamEvent, anyhow::Error>::Ok(
- serde_json::from_str(&item?.event)?,
- )
- },
- );
- while let Some(Ok(part)) = response.next().await {
- for choice in part.choices {
- let Some(tool_calls) = choice.delta.tool_calls else {
- continue;
- };
-
- for call in tool_calls {
- if let Some(func) = call.function {
- if func.name.as_deref() == Some(tool_name.as_str()) {
- load_state = Some((String::default(), call.index));
- }
- if let Some((arguments, (output, index))) =
- func.arguments.zip(load_state.as_mut())
- {
- if call.index == *index {
- output.push_str(&arguments);
- }
+ if let Some((arguments, (output, index))) =
+ func.arguments.zip(load_state.as_mut())
+ {
+ if call.index == *index {
+ output.push_str(&arguments);
}
}
}
}
}
- if let Some((arguments, _)) = load_state {
- return Ok(serde_json::from_str(&arguments)?);
- } else {
- bail!("tool not used");
- }
- })
- .boxed()
- }
+ }
+ if let Some((arguments, _)) = load_state {
+ return Ok(serde_json::from_str(&arguments)?);
+ } else {
+ bail!("tool not used");
+ }
+ })
+ .boxed()
}
}
}
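
One constant across every branch: the provider-specific request is serialized once and carried as a `serde_json` `RawValue`, so the client validates the JSON but never re-parses it, and the LLM service can forward it to the provider verbatim. A sketch of that construction, assuming only that `provider_request` is the boxed `RawValue` the call sites imply (the helper itself is illustrative):

use anyhow::Result;
use client::PerformCompletionParams;
use serde_json::value::RawValue;

fn completion_params<T: serde::Serialize>(
    provider: client::LanguageModelProvider,
    model: String,
    request: &T,
) -> Result<PerformCompletionParams> {
    Ok(PerformCompletionParams {
        provider,
        model,
        // `from_string` checks that the text is valid JSON and stores it
        // as-is; no typed deserialization happens on the client.
        provider_request: RawValue::from_string(serde_json::to_string(request)?)?,
    })
}
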