Add tool use support for OpenAI models (#28051)

Marshall Bowers created

This PR adds support for using tools to the OpenAI models.

Release Notes:

- agent: Added support for tool use with OpenAI models (Preview only).

Change summary

crates/language_models/src/provider/cloud.rs   |  16 
crates/language_models/src/provider/open_ai.rs | 207 +++++++++++++++++--
crates/open_ai/src/open_ai.rs                  |  13 -
3 files changed, 188 insertions(+), 48 deletions(-)

Detailed changes

crates/language_models/src/provider/cloud.rs 🔗

@@ -587,7 +587,7 @@ impl LanguageModel for CloudLanguageModel {
         match self.model {
             CloudModel::Anthropic(_) => true,
             CloudModel::Google(_) => true,
-            CloudModel::OpenAi(_) => false,
+            CloudModel::OpenAi(_) => true,
         }
     }
 
@@ -705,15 +705,13 @@ impl LanguageModel for CloudLanguageModel {
                         },
                     )
                     .await?;
-                    Ok(open_ai::extract_text_from_events(response_lines(response)))
+                    Ok(
+                        crate::provider::open_ai::map_to_language_model_completion_events(
+                            Box::pin(response_lines(response)),
+                        ),
+                    )
                 });
-                async move {
-                    Ok(future
-                        .await?
-                        .map(|result| result.map(LanguageModelCompletionEvent::Text))
-                        .boxed())
-                }
-                .boxed()
+                async move { Ok(future.await?.boxed()) }.boxed()
             }
             CloudModel::Google(model) => {
                 let client = self.client.clone();

crates/language_models/src/provider/open_ai.rs 🔗

@@ -1,7 +1,8 @@
 use anyhow::{Context as _, Result, anyhow};
-use collections::BTreeMap;
+use collections::{BTreeMap, HashMap};
 use credentials_provider::CredentialsProvider;
 use editor::{Editor, EditorElement, EditorStyle};
+use futures::Stream;
 use futures::{FutureExt, StreamExt, future::BoxFuture};
 use gpui::{
     AnyView, App, AsyncApp, Context, Entity, FontStyle, Subscription, Task, TextStyle, WhiteSpace,
@@ -10,17 +11,20 @@ use http_client::HttpClient;
 use language_model::{
     AuthenticateError, LanguageModel, LanguageModelCompletionEvent, LanguageModelId,
     LanguageModelName, LanguageModelProvider, LanguageModelProviderId, LanguageModelProviderName,
-    LanguageModelProviderState, LanguageModelRequest, RateLimiter, Role,
+    LanguageModelProviderState, LanguageModelRequest, LanguageModelToolUse, MessageContent,
+    RateLimiter, Role, StopReason,
 };
 use open_ai::{ResponseStreamEvent, stream_completion};
 use schemars::JsonSchema;
 use serde::{Deserialize, Serialize};
 use settings::{Settings, SettingsStore};
+use std::pin::Pin;
+use std::str::FromStr as _;
 use std::sync::Arc;
 use strum::IntoEnumIterator;
 use theme::ThemeSettings;
 use ui::{Icon, IconName, List, Tooltip, prelude::*};
-use util::ResultExt;
+use util::{ResultExt, maybe};
 
 use crate::{AllLanguageModelSettings, ui::InstructionListItem};
 
@@ -289,7 +293,7 @@ impl LanguageModel for OpenAiLanguageModel {
     }
 
     fn supports_tools(&self) -> bool {
-        false
+        true
     }
 
     fn telemetry_id(&self) -> String {
@@ -322,12 +326,8 @@ impl LanguageModel for OpenAiLanguageModel {
     > {
         let request = into_open_ai(request, self.model.id().into(), self.max_output_tokens());
         let completions = self.stream_completion(request, cx);
-        async move {
-            Ok(open_ai::extract_text_from_events(completions.await?)
-                .map(|result| result.map(LanguageModelCompletionEvent::Text))
-                .boxed())
-        }
-        .boxed()
+        async move { Ok(map_to_language_model_completion_events(completions.await?).boxed()) }
+            .boxed()
     }
 }
 
@@ -337,33 +337,186 @@ pub fn into_open_ai(
     max_output_tokens: Option<u32>,
 ) -> open_ai::Request {
     let stream = !model.starts_with("o1-");
+
+    let mut messages = Vec::new();
+    for message in request.messages {
+        for content in message.content {
+            match content {
+                MessageContent::Text(text) => messages.push(match message.role {
+                    Role::User => open_ai::RequestMessage::User { content: text },
+                    Role::Assistant => open_ai::RequestMessage::Assistant {
+                        content: Some(text),
+                        tool_calls: Vec::new(),
+                    },
+                    Role::System => open_ai::RequestMessage::System { content: text },
+                }),
+                MessageContent::Image(_) => {}
+                MessageContent::ToolUse(tool_use) => {
+                    let tool_call = open_ai::ToolCall {
+                        id: tool_use.id.to_string(),
+                        content: open_ai::ToolCallContent::Function {
+                            function: open_ai::FunctionContent {
+                                name: tool_use.name.to_string(),
+                                arguments: serde_json::to_string(&tool_use.input)
+                                    .unwrap_or_default(),
+                            },
+                        },
+                    };
+
+                    if let Some(last_assistant_message) = messages.iter_mut().rfind(|message| {
+                        matches!(message, open_ai::RequestMessage::Assistant { .. })
+                    }) {
+                        if let open_ai::RequestMessage::Assistant { tool_calls, .. } =
+                            last_assistant_message
+                        {
+                            tool_calls.push(tool_call);
+                        }
+                    } else {
+                        messages.push(open_ai::RequestMessage::Assistant {
+                            content: None,
+                            tool_calls: vec![tool_call],
+                        });
+                    }
+                }
+                MessageContent::ToolResult(tool_result) => {
+                    messages.push(open_ai::RequestMessage::Tool {
+                        content: tool_result.content.to_string(),
+                        tool_call_id: tool_result.tool_use_id.to_string(),
+                    });
+                }
+            }
+        }
+    }
+
     open_ai::Request {
         model,
-        messages: request
-            .messages
-            .into_iter()
-            .map(|msg| match msg.role {
-                Role::User => open_ai::RequestMessage::User {
-                    content: msg.string_contents(),
-                },
-                Role::Assistant => open_ai::RequestMessage::Assistant {
-                    content: Some(msg.string_contents()),
-                    tool_calls: Vec::new(),
-                },
-                Role::System => open_ai::RequestMessage::System {
-                    content: msg.string_contents(),
-                },
-            })
-            .collect(),
+        messages,
         stream,
         stop: request.stop,
         temperature: request.temperature.unwrap_or(1.0),
         max_tokens: max_output_tokens,
-        tools: Vec::new(),
+        tools: request
+            .tools
+            .into_iter()
+            .map(|tool| open_ai::ToolDefinition::Function {
+                function: open_ai::FunctionDefinition {
+                    name: tool.name,
+                    description: Some(tool.description),
+                    parameters: Some(tool.input_schema),
+                },
+            })
+            .collect(),
         tool_choice: None,
     }
 }
 
+pub fn map_to_language_model_completion_events(
+    events: Pin<Box<dyn Send + Stream<Item = Result<ResponseStreamEvent>>>>,
+) -> impl Stream<Item = Result<LanguageModelCompletionEvent>> {
+    #[derive(Default)]
+    struct RawToolCall {
+        id: String,
+        name: String,
+        arguments: String,
+    }
+
+    struct State {
+        events: Pin<Box<dyn Send + Stream<Item = Result<ResponseStreamEvent>>>>,
+        tool_calls_by_index: HashMap<usize, RawToolCall>,
+    }
+
+    futures::stream::unfold(
+        State {
+            events,
+            tool_calls_by_index: HashMap::default(),
+        },
+        |mut state| async move {
+            if let Some(event) = state.events.next().await {
+                match event {
+                    Ok(event) => {
+                        let Some(choice) = event.choices.first() else {
+                            return Some((
+                                vec![Err(anyhow!("Response contained no choices"))],
+                                state,
+                            ));
+                        };
+
+                        let mut events = Vec::new();
+                        if let Some(content) = choice.delta.content.clone() {
+                            events.push(Ok(LanguageModelCompletionEvent::Text(content)));
+                        }
+
+                        if let Some(tool_calls) = choice.delta.tool_calls.as_ref() {
+                            for tool_call in tool_calls {
+                                let entry = state
+                                    .tool_calls_by_index
+                                    .entry(tool_call.index)
+                                    .or_default();
+
+                                if let Some(tool_id) = tool_call.id.clone() {
+                                    entry.id = tool_id;
+                                }
+
+                                if let Some(function) = tool_call.function.as_ref() {
+                                    if let Some(name) = function.name.clone() {
+                                        entry.name = name;
+                                    }
+
+                                    if let Some(arguments) = function.arguments.clone() {
+                                        entry.arguments.push_str(&arguments);
+                                    }
+                                }
+                            }
+                        }
+
+                        match choice.finish_reason.as_deref() {
+                            Some("stop") => {
+                                events.push(Ok(LanguageModelCompletionEvent::Stop(
+                                    StopReason::EndTurn,
+                                )));
+                            }
+                            Some("tool_calls") => {
+                                events.extend(state.tool_calls_by_index.drain().map(
+                                    |(_, tool_call)| {
+                                        maybe!({
+                                            Ok(LanguageModelCompletionEvent::ToolUse(
+                                                LanguageModelToolUse {
+                                                    id: tool_call.id.into(),
+                                                    name: tool_call.name.as_str().into(),
+                                                    input: serde_json::Value::from_str(
+                                                        &tool_call.arguments,
+                                                    )?,
+                                                },
+                                            ))
+                                        })
+                                    },
+                                ));
+
+                                events.push(Ok(LanguageModelCompletionEvent::Stop(
+                                    StopReason::ToolUse,
+                                )));
+                            }
+                            Some(stop_reason) => {
+                                log::error!("Unexpected OpenAI stop_reason: {stop_reason:?}",);
+                                events.push(Ok(LanguageModelCompletionEvent::Stop(
+                                    StopReason::EndTurn,
+                                )));
+                            }
+                            None => {}
+                        }
+
+                        return Some((events, state));
+                    }
+                    Err(err) => return Some((vec![Err(err)], state)),
+                }
+            }
+
+            None
+        },
+    )
+    .flat_map(futures::stream::iter)
+}
+
 pub fn count_open_ai_tokens(
     request: LanguageModelRequest,
     model: open_ai::Model,

crates/open_ai/src/open_ai.rs 🔗

@@ -2,7 +2,7 @@ mod supported_countries;
 
 use anyhow::{Context as _, Result, anyhow};
 use futures::{
-    AsyncBufReadExt, AsyncReadExt, Stream, StreamExt,
+    AsyncBufReadExt, AsyncReadExt, StreamExt,
     io::BufReader,
     stream::{self, BoxStream},
 };
@@ -618,14 +618,3 @@ pub fn embed<'a>(
         }
     }
 }
-
-pub fn extract_text_from_events(
-    response: impl Stream<Item = Result<ResponseStreamEvent>>,
-) -> impl Stream<Item = Result<String>> {
-    response.filter_map(|response| async move {
-        match response {
-            Ok(mut response) => Some(Ok(response.choices.pop()?.delta.content?)),
-            Err(error) => Some(Err(error)),
-        }
-    })
-}