language_models: Fix non-streaming Copilot Chat models (#28537)

Marshall Bowers created

This PR fixes usage of non-streaming Copilot Chat models.

Closes https://github.com/zed-industries/zed/issues/28528.

Release Notes:

- Fixed an issue with using non-streaming Copilot Chat models (e.g., o1,
o3-mini).

Change summary

crates/language_models/src/provider/copilot_chat.rs | 49 +++++++++-----
1 file changed, 30 insertions(+), 19 deletions(-)

Detailed changes

crates/language_models/src/provider/copilot_chat.rs 🔗

@@ -254,6 +254,7 @@ impl LanguageModel for CopilotChatLanguageModel {
             Ok(request) => request,
             Err(err) => return futures::future::ready(Err(err)).boxed(),
         };
+        let is_streaming = copilot_request.stream;
 
         let request_limiter = self.request_limiter.clone();
         let future = cx.spawn(async move |cx| {
@@ -261,7 +262,10 @@ impl LanguageModel for CopilotChatLanguageModel {
             request_limiter
                 .stream(async move {
                     let response = request.await?;
-                    Ok(map_to_language_model_completion_events(response))
+                    Ok(map_to_language_model_completion_events(
+                        response,
+                        is_streaming,
+                    ))
                 })
                 .await
         });
@@ -271,6 +275,7 @@ impl LanguageModel for CopilotChatLanguageModel {
 
 pub fn map_to_language_model_completion_events(
     events: Pin<Box<dyn Send + Stream<Item = Result<ResponseEvent>>>>,
+    is_streaming: bool,
 ) -> impl Stream<Item = Result<LanguageModelCompletionEvent>> {
     #[derive(Default)]
     struct RawToolCall {
@@ -289,7 +294,7 @@ pub fn map_to_language_model_completion_events(
             events,
             tool_calls_by_index: HashMap::default(),
         },
-        |mut state| async move {
+        move |mut state| async move {
             if let Some(event) = state.events.next().await {
                 match event {
                     Ok(event) => {
@@ -300,7 +305,13 @@ pub fn map_to_language_model_completion_events(
                             ));
                         };
 
-                        let Some(delta) = choice.delta.as_ref() else {
+                        let delta = if is_streaming {
+                            choice.delta.as_ref()
+                        } else {
+                            choice.message.as_ref()
+                        };
+
+                        let Some(delta) = delta else {
                             return Some((
                                 vec![Err(anyhow!("Response contained no delta"))],
                                 state,
@@ -312,26 +323,26 @@ pub fn map_to_language_model_completion_events(
                             events.push(Ok(LanguageModelCompletionEvent::Text(content)));
                         }
 
-                            for tool_call in &delta.tool_calls {
-                                let entry = state
-                                    .tool_calls_by_index
-                                    .entry(tool_call.index)
-                                    .or_default();
+                        for tool_call in &delta.tool_calls {
+                            let entry = state
+                                .tool_calls_by_index
+                                .entry(tool_call.index)
+                                .or_default();
 
-                                if let Some(tool_id) = tool_call.id.clone() {
-                                    entry.id = tool_id;
-                                }
+                            if let Some(tool_id) = tool_call.id.clone() {
+                                entry.id = tool_id;
+                            }
 
-                                if let Some(function) = tool_call.function.as_ref() {
-                                    if let Some(name) = function.name.clone() {
-                                        entry.name = name;
-                                    }
+                            if let Some(function) = tool_call.function.as_ref() {
+                                if let Some(name) = function.name.clone() {
+                                    entry.name = name;
+                                }
 
-                                    if let Some(arguments) = function.arguments.clone() {
-                                        entry.arguments.push_str(&arguments);
-                                    }
+                                if let Some(arguments) = function.arguments.clone() {
+                                    entry.arguments.push_str(&arguments);
                                 }
                             }
+                        }
 
                         match choice.finish_reason.as_deref() {
                             Some("stop") => {
@@ -361,7 +372,7 @@ pub fn map_to_language_model_completion_events(
                                 )));
                             }
                             Some(stop_reason) => {
-                                log::error!("Unexpected Copilot Chat stop_reason: {stop_reason:?}",);
+                                log::error!("Unexpected Copilot Chat stop_reason: {stop_reason:?}");
                                 events.push(Ok(LanguageModelCompletionEvent::Stop(
                                     StopReason::EndTurn,
                                 )));