Detailed changes
@@ -1,10 +1,10 @@
mod supported_countries;
-use std::{pin::Pin, str::FromStr};
+use std::str::FromStr;
use anyhow::{Context as _, Result, anyhow};
use chrono::{DateTime, Utc};
-use futures::{AsyncBufReadExt, AsyncReadExt, Stream, StreamExt, io::BufReader, stream::BoxStream};
+use futures::{AsyncBufReadExt, AsyncReadExt, StreamExt, io::BufReader, stream::BoxStream};
use http_client::http::{HeaderMap, HeaderValue};
use http_client::{AsyncBody, HttpClient, Method, Request as HttpRequest};
use serde::{Deserialize, Serialize};
@@ -437,50 +437,6 @@ pub async fn stream_completion_with_rate_limit_info(
}
}
-pub async fn extract_tool_args_from_events(
- tool_name: String,
- mut events: Pin<Box<dyn Send + Stream<Item = Result<Event>>>>,
-) -> Result<impl Send + Stream<Item = Result<String>>> {
- let mut tool_use_index = None;
- while let Some(event) = events.next().await {
- if let Event::ContentBlockStart {
- index,
- content_block: ResponseContent::ToolUse { name, .. },
- } = event?
- {
- if name == tool_name {
- tool_use_index = Some(index);
- break;
- }
- }
- }
-
- let Some(tool_use_index) = tool_use_index else {
- return Err(anyhow!("tool not used"));
- };
-
- Ok(events.filter_map(move |event| {
- let result = match event {
- Err(error) => Some(Err(error)),
- Ok(Event::ContentBlockDelta { index, delta }) => match delta {
- ContentDelta::TextDelta { .. } => None,
- ContentDelta::ThinkingDelta { .. } => None,
- ContentDelta::SignatureDelta { .. } => None,
- ContentDelta::InputJsonDelta { partial_json } => {
- if index == tool_use_index {
- Some(Ok(partial_json))
- } else {
- None
- }
- }
- },
- _ => None,
- };
-
- async move { result }
- }))
-}
-
#[derive(Debug, Serialize, Deserialize, Copy, Clone)]
#[serde(rename_all = "lowercase")]
pub enum CacheControlType {
@@ -8,9 +8,7 @@ use aws_config::Region;
use aws_config::stalled_stream_protection::StalledStreamProtectionConfig;
use aws_credential_types::Credentials;
use aws_http_client::AwsHttpClient;
-use bedrock::bedrock_client::types::{
- ContentBlockDelta, ContentBlockStart, ContentBlockStartEvent, ConverseStreamOutput,
-};
+use bedrock::bedrock_client::types::{ContentBlockDelta, ContentBlockStart, ConverseStreamOutput};
use bedrock::bedrock_client::{self, Config};
use bedrock::{BedrockError, BedrockInnerContent, BedrockMessage, BedrockStreamingResponse, Model};
use collections::{BTreeMap, HashMap};
@@ -544,70 +542,6 @@ pub fn get_bedrock_tokens(
.boxed()
}
-pub async fn extract_tool_args_from_events(
- name: String,
- mut events: Pin<Box<dyn Send + Stream<Item = Result<BedrockStreamingResponse, BedrockError>>>>,
- handle: Handle,
-) -> Result<impl Send + Stream<Item = Result<String>>> {
- handle
- .spawn(async move {
- let mut tool_use_index = None;
- while let Some(event) = events.next().await {
- if let BedrockStreamingResponse::ContentBlockStart(ContentBlockStartEvent {
- content_block_index,
- start,
- ..
- }) = event?
- {
- match start {
- None => {
- continue;
- }
- Some(start) => match start.as_tool_use() {
- Ok(tool_use) => {
- if name == tool_use.name {
- tool_use_index = Some(content_block_index);
- break;
- }
- }
- Err(err) => {
- return Err(anyhow!("Failed to parse tool use event: {:?}", err));
- }
- },
- }
- }
- }
-
- let Some(tool_use_index) = tool_use_index else {
- return Err(anyhow!("Tool is not used"));
- };
-
- Ok(events.filter_map(move |event| {
- let result = match event {
- Err(_err) => None,
- Ok(output) => match output.clone() {
- BedrockStreamingResponse::ContentBlockDelta(inner) => {
- match inner.clone().delta {
- Some(ContentBlockDelta::ToolUse(tool_use)) => {
- if inner.content_block_index == tool_use_index {
- Some(Ok(tool_use.input))
- } else {
- None
- }
- }
- _ => None,
- }
- }
- _ => None,
- },
- };
-
- async move { result }
- }))
- })
- .await?
-}
-
pub fn map_to_language_model_completion_events(
events: Pin<Box<dyn Send + Stream<Item = Result<BedrockStreamingResponse, BedrockError>>>>,
handle: Handle,
@@ -12,7 +12,6 @@ use serde_json::Value;
use std::{
convert::TryFrom,
future::{self, Future},
- pin::Pin,
};
use strum::EnumIter;
@@ -620,57 +619,6 @@ pub fn embed<'a>(
}
}
-pub async fn extract_tool_args_from_events(
- tool_name: String,
- mut events: Pin<Box<dyn Send + Stream<Item = Result<ResponseStreamEvent>>>>,
-) -> Result<impl Send + Stream<Item = Result<String>>> {
- let mut tool_use_index = None;
- let mut first_chunk = None;
- while let Some(event) = events.next().await {
- let call = event?.choices.into_iter().find_map(|choice| {
- choice.delta.tool_calls?.into_iter().find_map(|call| {
- if call.function.as_ref()?.name.as_deref()? == tool_name {
- Some(call)
- } else {
- None
- }
- })
- });
- if let Some(call) = call {
- tool_use_index = Some(call.index);
- first_chunk = call.function.and_then(|func| func.arguments);
- break;
- }
- }
-
- let Some(tool_use_index) = tool_use_index else {
- return Err(anyhow!("tool not used"));
- };
-
- Ok(events.filter_map(move |event| {
- let result = match event {
- Err(error) => Some(Err(error)),
- Ok(ResponseStreamEvent { choices, .. }) => choices.into_iter().find_map(|choice| {
- choice.delta.tool_calls?.into_iter().find_map(|call| {
- if call.index == tool_use_index {
- let func = call.function?;
- let mut arguments = func.arguments?;
- if let Some(mut first_chunk) = first_chunk.take() {
- first_chunk.push_str(&arguments);
- arguments = first_chunk
- }
- Some(Ok(arguments))
- } else {
- None
- }
- })
- }),
- };
-
- async move { result }
- }))
-}
-
pub fn extract_text_from_events(
response: impl Stream<Item = Result<ResponseStreamEvent>>,
) -> impl Stream<Item = Result<String>> {