@@ -1,7 +1,8 @@
use anyhow::{Context as _, Result, anyhow};
-use collections::BTreeMap;
+use collections::{BTreeMap, HashMap};
use credentials_provider::CredentialsProvider;
use editor::{Editor, EditorElement, EditorStyle};
+use futures::Stream;
use futures::{FutureExt, StreamExt, future::BoxFuture};
use gpui::{
AnyView, App, AsyncApp, Context, Entity, FontStyle, Subscription, Task, TextStyle, WhiteSpace,
@@ -10,17 +11,20 @@ use http_client::HttpClient;
use language_model::{
AuthenticateError, LanguageModel, LanguageModelCompletionEvent, LanguageModelId,
LanguageModelName, LanguageModelProvider, LanguageModelProviderId, LanguageModelProviderName,
- LanguageModelProviderState, LanguageModelRequest, RateLimiter, Role,
+ LanguageModelProviderState, LanguageModelRequest, LanguageModelToolUse, MessageContent,
+ RateLimiter, Role, StopReason,
};
use open_ai::{ResponseStreamEvent, stream_completion};
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
use settings::{Settings, SettingsStore};
+use std::pin::Pin;
+use std::str::FromStr as _;
use std::sync::Arc;
use strum::IntoEnumIterator;
use theme::ThemeSettings;
use ui::{Icon, IconName, List, Tooltip, prelude::*};
-use util::ResultExt;
+use util::{ResultExt, maybe};
use crate::{AllLanguageModelSettings, ui::InstructionListItem};
@@ -289,7 +293,7 @@ impl LanguageModel for OpenAiLanguageModel {
}
fn supports_tools(&self) -> bool {
- false
+ true
}
fn telemetry_id(&self) -> String {
@@ -322,12 +326,8 @@ impl LanguageModel for OpenAiLanguageModel {
> {
let request = into_open_ai(request, self.model.id().into(), self.max_output_tokens());
let completions = self.stream_completion(request, cx);
- async move {
- Ok(open_ai::extract_text_from_events(completions.await?)
- .map(|result| result.map(LanguageModelCompletionEvent::Text))
- .boxed())
- }
- .boxed()
+ async move { Ok(map_to_language_model_completion_events(completions.await?).boxed()) }
+ .boxed()
}
}
@@ -337,33 +337,186 @@ pub fn into_open_ai(
max_output_tokens: Option<u32>,
) -> open_ai::Request {
let stream = !model.starts_with("o1-");
+
+ let mut messages = Vec::new();
+ for message in request.messages {
+ for content in message.content {
+ match content {
+ MessageContent::Text(text) => messages.push(match message.role {
+ Role::User => open_ai::RequestMessage::User { content: text },
+ Role::Assistant => open_ai::RequestMessage::Assistant {
+ content: Some(text),
+ tool_calls: Vec::new(),
+ },
+ Role::System => open_ai::RequestMessage::System { content: text },
+ }),
+ MessageContent::Image(_) => {}
+ MessageContent::ToolUse(tool_use) => {
+ let tool_call = open_ai::ToolCall {
+ id: tool_use.id.to_string(),
+ content: open_ai::ToolCallContent::Function {
+ function: open_ai::FunctionContent {
+ name: tool_use.name.to_string(),
+ arguments: serde_json::to_string(&tool_use.input)
+ .unwrap_or_default(),
+ },
+ },
+ };
+
+ if let Some(last_assistant_message) = messages.iter_mut().rfind(|message| {
+ matches!(message, open_ai::RequestMessage::Assistant { .. })
+ }) {
+ if let open_ai::RequestMessage::Assistant { tool_calls, .. } =
+ last_assistant_message
+ {
+ tool_calls.push(tool_call);
+ }
+ } else {
+ messages.push(open_ai::RequestMessage::Assistant {
+ content: None,
+ tool_calls: vec![tool_call],
+ });
+ }
+ }
+ MessageContent::ToolResult(tool_result) => {
+ messages.push(open_ai::RequestMessage::Tool {
+ content: tool_result.content.to_string(),
+ tool_call_id: tool_result.tool_use_id.to_string(),
+ });
+ }
+ }
+ }
+ }
+
open_ai::Request {
model,
- messages: request
- .messages
- .into_iter()
- .map(|msg| match msg.role {
- Role::User => open_ai::RequestMessage::User {
- content: msg.string_contents(),
- },
- Role::Assistant => open_ai::RequestMessage::Assistant {
- content: Some(msg.string_contents()),
- tool_calls: Vec::new(),
- },
- Role::System => open_ai::RequestMessage::System {
- content: msg.string_contents(),
- },
- })
- .collect(),
+ messages,
stream,
stop: request.stop,
temperature: request.temperature.unwrap_or(1.0),
max_tokens: max_output_tokens,
- tools: Vec::new(),
+ tools: request
+ .tools
+ .into_iter()
+ .map(|tool| open_ai::ToolDefinition::Function {
+ function: open_ai::FunctionDefinition {
+ name: tool.name,
+ description: Some(tool.description),
+ parameters: Some(tool.input_schema),
+ },
+ })
+ .collect(),
tool_choice: None,
}
}
+pub fn map_to_language_model_completion_events(
+ events: Pin<Box<dyn Send + Stream<Item = Result<ResponseStreamEvent>>>>,
+) -> impl Stream<Item = Result<LanguageModelCompletionEvent>> {
+ #[derive(Default)]
+ struct RawToolCall {
+ id: String,
+ name: String,
+ arguments: String,
+ }
+
+ struct State {
+ events: Pin<Box<dyn Send + Stream<Item = Result<ResponseStreamEvent>>>>,
+ tool_calls_by_index: HashMap<usize, RawToolCall>,
+ }
+
+ futures::stream::unfold(
+ State {
+ events,
+ tool_calls_by_index: HashMap::default(),
+ },
+ |mut state| async move {
+ if let Some(event) = state.events.next().await {
+ match event {
+ Ok(event) => {
+ let Some(choice) = event.choices.first() else {
+ return Some((
+ vec![Err(anyhow!("Response contained no choices"))],
+ state,
+ ));
+ };
+
+ let mut events = Vec::new();
+ if let Some(content) = choice.delta.content.clone() {
+ events.push(Ok(LanguageModelCompletionEvent::Text(content)));
+ }
+
+ if let Some(tool_calls) = choice.delta.tool_calls.as_ref() {
+ for tool_call in tool_calls {
+ let entry = state
+ .tool_calls_by_index
+ .entry(tool_call.index)
+ .or_default();
+
+ if let Some(tool_id) = tool_call.id.clone() {
+ entry.id = tool_id;
+ }
+
+ if let Some(function) = tool_call.function.as_ref() {
+ if let Some(name) = function.name.clone() {
+ entry.name = name;
+ }
+
+ if let Some(arguments) = function.arguments.clone() {
+ entry.arguments.push_str(&arguments);
+ }
+ }
+ }
+ }
+
+ match choice.finish_reason.as_deref() {
+ Some("stop") => {
+ events.push(Ok(LanguageModelCompletionEvent::Stop(
+ StopReason::EndTurn,
+ )));
+ }
+ Some("tool_calls") => {
+ events.extend(state.tool_calls_by_index.drain().map(
+ |(_, tool_call)| {
+ maybe!({
+ Ok(LanguageModelCompletionEvent::ToolUse(
+ LanguageModelToolUse {
+ id: tool_call.id.into(),
+ name: tool_call.name.as_str().into(),
+ input: serde_json::Value::from_str(
+ &tool_call.arguments,
+ )?,
+ },
+ ))
+ })
+ },
+ ));
+
+ events.push(Ok(LanguageModelCompletionEvent::Stop(
+ StopReason::ToolUse,
+ )));
+ }
+ Some(stop_reason) => {
+ log::error!("Unexpected OpenAI stop_reason: {stop_reason:?}",);
+ events.push(Ok(LanguageModelCompletionEvent::Stop(
+ StopReason::EndTurn,
+ )));
+ }
+ None => {}
+ }
+
+ return Some((events, state));
+ }
+ Err(err) => return Some((vec![Err(err)], state)),
+ }
+ }
+
+ None
+ },
+ )
+ .flat_map(futures::stream::iter)
+}
+
pub fn count_open_ai_tokens(
request: LanguageModelRequest,
model: open_ai::Model,
@@ -2,7 +2,7 @@ mod supported_countries;
use anyhow::{Context as _, Result, anyhow};
use futures::{
- AsyncBufReadExt, AsyncReadExt, Stream, StreamExt,
+ AsyncBufReadExt, AsyncReadExt, StreamExt,
io::BufReader,
stream::{self, BoxStream},
};
@@ -618,14 +618,3 @@ pub fn embed<'a>(
}
}
}
-
-pub fn extract_text_from_events(
- response: impl Stream<Item = Result<ResponseStreamEvent>>,
-) -> impl Stream<Item = Result<String>> {
- response.filter_map(|response| async move {
- match response {
- Ok(mut response) => Some(Ok(response.choices.pop()?.delta.content?)),
- Err(error) => Some(Err(error)),
- }
- })
-}