@@ -254,6 +254,7 @@ impl LanguageModel for CopilotChatLanguageModel {
Ok(request) => request,
Err(err) => return futures::future::ready(Err(err)).boxed(),
};
+ let is_streaming = copilot_request.stream;
let request_limiter = self.request_limiter.clone();
let future = cx.spawn(async move |cx| {
@@ -261,7 +262,10 @@ impl LanguageModel for CopilotChatLanguageModel {
request_limiter
.stream(async move {
let response = request.await?;
- Ok(map_to_language_model_completion_events(response))
+ Ok(map_to_language_model_completion_events(
+ response,
+ is_streaming,
+ ))
})
.await
});
@@ -271,6 +275,7 @@ impl LanguageModel for CopilotChatLanguageModel {
pub fn map_to_language_model_completion_events(
events: Pin<Box<dyn Send + Stream<Item = Result<ResponseEvent>>>>,
+ is_streaming: bool,
) -> impl Stream<Item = Result<LanguageModelCompletionEvent>> {
#[derive(Default)]
struct RawToolCall {
@@ -289,7 +294,7 @@ pub fn map_to_language_model_completion_events(
events,
tool_calls_by_index: HashMap::default(),
},
- |mut state| async move {
+ move |mut state| async move {
if let Some(event) = state.events.next().await {
match event {
Ok(event) => {
@@ -300,7 +305,13 @@ pub fn map_to_language_model_completion_events(
));
};
- let Some(delta) = choice.delta.as_ref() else {
+ let delta = if is_streaming {
+ choice.delta.as_ref()
+ } else {
+ choice.message.as_ref()
+ };
+
+ let Some(delta) = delta else {
return Some((
vec![Err(anyhow!("Response contained no delta"))],
state,
@@ -312,26 +323,26 @@ pub fn map_to_language_model_completion_events(
events.push(Ok(LanguageModelCompletionEvent::Text(content)));
}
- for tool_call in &delta.tool_calls {
- let entry = state
- .tool_calls_by_index
- .entry(tool_call.index)
- .or_default();
+ for tool_call in &delta.tool_calls {
+ let entry = state
+ .tool_calls_by_index
+ .entry(tool_call.index)
+ .or_default();
- if let Some(tool_id) = tool_call.id.clone() {
- entry.id = tool_id;
- }
+ if let Some(tool_id) = tool_call.id.clone() {
+ entry.id = tool_id;
+ }
- if let Some(function) = tool_call.function.as_ref() {
- if let Some(name) = function.name.clone() {
- entry.name = name;
- }
+ if let Some(function) = tool_call.function.as_ref() {
+ if let Some(name) = function.name.clone() {
+ entry.name = name;
+ }
- if let Some(arguments) = function.arguments.clone() {
- entry.arguments.push_str(&arguments);
- }
+ if let Some(arguments) = function.arguments.clone() {
+ entry.arguments.push_str(&arguments);
}
}
+ }
match choice.finish_reason.as_deref() {
Some("stop") => {
@@ -361,7 +372,7 @@ pub fn map_to_language_model_completion_events(
)));
}
Some(stop_reason) => {
- log::error!("Unexpected Copilot Chat stop_reason: {stop_reason:?}",);
+ log::error!("Unexpected Copilot Chat stop_reason: {stop_reason:?}");
events.push(Ok(LanguageModelCompletionEvent::Stop(
StopReason::EndTurn,
)));