Detailed changes
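The hunks below add tool-use support (`use_any_tool`) to the cloud provider's OpenAI and Zed model branches, to the Ollama provider, and to the OpenAI provider, and extend the `ollama` and `open_ai` client crates with the request and response types those implementations need (tool definitions, tool choices, and tool-call messages).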
@@ -4,7 +4,7 @@ use crate::{
LanguageModelName, LanguageModelProviderId, LanguageModelProviderName,
LanguageModelProviderState, LanguageModelRequest, RateLimiter, ZedModel,
};
-use anyhow::{anyhow, Context as _, Result};
+use anyhow::{anyhow, bail, Context as _, Result};
use client::{Client, PerformCompletionParams, UserStore, EXPIRED_LLM_TOKEN_HEADER_NAME};
use collections::BTreeMap;
use feature_flags::{FeatureFlag, FeatureFlagAppExt, LanguageModels};
@@ -634,14 +634,143 @@ impl LanguageModel for CloudLanguageModel {
})
.boxed()
}
- CloudModel::OpenAi(_) => {
- future::ready(Err(anyhow!("tool use not implemented for OpenAI"))).boxed()
+ CloudModel::OpenAi(model) => {
+ let mut request = request.into_open_ai(model.id().into());
+ let client = self.client.clone();
+ let mut function = open_ai::FunctionDefinition {
+ name: tool_name.clone(),
+ description: None,
+ parameters: None,
+ };
+ let func = open_ai::ToolDefinition::Function {
+ function: function.clone(),
+ };
+ request.tool_choice = Some(open_ai::ToolChoice::Other(func.clone()));
+ // Fill in the description and parameters separately; they aren't needed for the tool_choice field.
+ function.description = Some(tool_description);
+ function.parameters = Some(input_schema);
+ request.tools = vec![open_ai::ToolDefinition::Function { function }];
+ self.request_limiter
+ .run(async move {
+ let request = serde_json::to_string(&request)?;
+ let response = client
+ .request_stream(proto::StreamCompleteWithLanguageModel {
+ provider: proto::LanguageModelProvider::OpenAi as i32,
+ request,
+ })
+ .await?;
+ // Tool call arguments are streamed in over multiple chunks.
+ let mut load_state = None;
+ let mut response = response.map(
+ |item: Result<
+ proto::StreamCompleteWithLanguageModelResponse,
+ anyhow::Error,
+ >| {
+ Result::<open_ai::ResponseStreamEvent, anyhow::Error>::Ok(
+ serde_json::from_str(&item?.event)?,
+ )
+ },
+ );
+ while let Some(Ok(part)) = response.next().await {
+ for choice in part.choices {
+ let Some(tool_calls) = choice.delta.tool_calls else {
+ continue;
+ };
+
+ for call in tool_calls {
+ if let Some(func) = call.function {
+ if func.name.as_deref() == Some(tool_name.as_str()) {
+ load_state = Some((String::default(), call.index));
+ }
+ if let Some((arguments, (output, index))) =
+ func.arguments.zip(load_state.as_mut())
+ {
+ if call.index == *index {
+ output.push_str(&arguments);
+ }
+ }
+ }
+ }
+ }
+ }
+ if let Some((arguments, _)) = load_state {
+ return Ok(serde_json::from_str(&arguments)?);
+ } else {
+ bail!("tool not used");
+ }
+ })
+ .boxed()
}
CloudModel::Google(_) => {
future::ready(Err(anyhow!("tool use not implemented for Google AI"))).boxed()
}
- CloudModel::Zed(_) => {
- future::ready(Err(anyhow!("tool use not implemented for Zed models"))).boxed()
+ CloudModel::Zed(model) => {
+ // All Zed models are OpenAI-based at the time of writing.
+ let mut request = request.into_open_ai(model.id().into());
+ let client = self.client.clone();
+ let mut function = open_ai::FunctionDefinition {
+ name: tool_name.clone(),
+ description: None,
+ parameters: None,
+ };
+ let func = open_ai::ToolDefinition::Function {
+ function: function.clone(),
+ };
+ request.tool_choice = Some(open_ai::ToolChoice::Other(func.clone()));
+ // Fill in the description and parameters separately; they aren't needed for the tool_choice field.
+ function.description = Some(tool_description);
+ function.parameters = Some(input_schema);
+ request.tools = vec![open_ai::ToolDefinition::Function { function }];
+ self.request_limiter
+ .run(async move {
+ let request = serde_json::to_string(&request)?;
+ let response = client
+ .request_stream(proto::StreamCompleteWithLanguageModel {
+ provider: proto::LanguageModelProvider::OpenAi as i32,
+ request,
+ })
+ .await?;
+ // Tool call arguments are streamed in over multiple chunks.
+ let mut load_state = None;
+ let mut response = response.map(
+ |item: Result<
+ proto::StreamCompleteWithLanguageModelResponse,
+ anyhow::Error,
+ >| {
+ Result::<open_ai::ResponseStreamEvent, anyhow::Error>::Ok(
+ serde_json::from_str(&item?.event)?,
+ )
+ },
+ );
+ while let Some(Ok(part)) = response.next().await {
+ for choice in part.choices {
+ let Some(tool_calls) = choice.delta.tool_calls else {
+ continue;
+ };
+
+ for call in tool_calls {
+ if let Some(func) = call.function {
+ if func.name.as_deref() == Some(tool_name.as_str()) {
+ load_state = Some((String::default(), call.index));
+ }
+ if let Some((arguments, (output, index))) =
+ func.arguments.zip(load_state.as_mut())
+ {
+ if call.index == *index {
+ output.push_str(&arguments);
+ }
+ }
+ }
+ }
+ }
+ }
+ if let Some((arguments, _)) = load_state {
+ return Ok(serde_json::from_str(&arguments)?);
+ } else {
+ bail!("tool not used");
+ }
+ })
+ .boxed()
}
}
}
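Both the OpenAI and Zed branches above collect tool-call arguments that arrive as streamed deltas: the chunk that names the requested tool starts an accumulator keyed by the call index, later argument fragments for that index are concatenated, and the result is parsed as JSON once the stream ends. A minimal sketch of that accumulation step, with simplified stand-in chunk types and a helper name of my own (the real `open_ai` stream types carry more fields), assuming the `anyhow` and `serde_json` crates:

use serde_json::Value;

// Simplified stand-ins for the streamed tool-call chunk types.
struct FunctionChunk {
    name: Option<String>,
    arguments: Option<String>,
}

struct ToolCallChunk {
    index: usize,
    function: Option<FunctionChunk>,
}

/// Concatenate the argument fragments of the call that matches `tool_name`,
/// then parse the accumulated text as JSON.
fn collect_tool_arguments(
    chunks: impl IntoIterator<Item = ToolCallChunk>,
    tool_name: &str,
) -> anyhow::Result<Value> {
    // (accumulated argument text, index of the call being tracked)
    let mut load_state: Option<(String, usize)> = None;
    for call in chunks {
        let Some(func) = call.function else { continue };
        // The chunk that introduces the call carries the function name;
        // start accumulating when it names the tool we asked for.
        if func.name.as_deref() == Some(tool_name) {
            load_state = Some((String::new(), call.index));
        }
        // Later chunks carry only argument fragments; append the ones that
        // belong to the tracked call.
        if let Some((arguments, (output, index))) = func.arguments.zip(load_state.as_mut()) {
            if call.index == *index {
                output.push_str(&arguments);
            }
        }
    }
    match load_state {
        Some((arguments, _)) => Ok(serde_json::from_str(&arguments)?),
        None => anyhow::bail!("tool not used"),
    }
}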
@@ -1,12 +1,14 @@
-use anyhow::{anyhow, Result};
+use anyhow::{anyhow, bail, Result};
use futures::{future::BoxFuture, stream::BoxStream, FutureExt, StreamExt};
use gpui::{AnyView, AppContext, AsyncAppContext, ModelContext, Subscription, Task};
use http_client::HttpClient;
use ollama::{
get_models, preload_model, stream_chat_completion, ChatMessage, ChatOptions, ChatRequest,
+ ChatResponseDelta, OllamaToolCall,
};
+use serde_json::Value;
use settings::{Settings, SettingsStore};
-use std::{future, sync::Arc, time::Duration};
+use std::{sync::Arc, time::Duration};
use ui::{prelude::*, ButtonLike, Indicator};
use util::ResultExt;
@@ -184,6 +186,7 @@ impl OllamaLanguageModel {
},
Role::Assistant => ChatMessage::Assistant {
content: msg.content,
+ tool_calls: None,
},
Role::System => ChatMessage::System {
content: msg.content,
@@ -198,8 +201,25 @@ impl OllamaLanguageModel {
temperature: Some(request.temperature),
..Default::default()
}),
+ tools: vec![],
}
}
+ fn request_completion(
+ &self,
+ request: ChatRequest,
+ cx: &AsyncAppContext,
+ ) -> BoxFuture<'static, Result<ChatResponseDelta>> {
+ let http_client = self.http_client.clone();
+
+ let Ok(api_url) = cx.update(|cx| {
+ let settings = &AllLanguageModelSettings::get_global(cx).ollama;
+ settings.api_url.clone()
+ }) else {
+ return futures::future::ready(Err(anyhow!("App state dropped"))).boxed();
+ };
+
+ async move { ollama::complete(http_client.as_ref(), &api_url, request).await }.boxed()
+ }
}
impl LanguageModel for OllamaLanguageModel {
@@ -269,7 +289,7 @@ impl LanguageModel for OllamaLanguageModel {
Ok(delta) => {
let content = match delta.message {
ChatMessage::User { content } => content,
- ChatMessage::Assistant { content } => content,
+ ChatMessage::Assistant { content, .. } => content,
ChatMessage::System { content } => content,
};
Some(Ok(content))
@@ -286,13 +306,48 @@ impl LanguageModel for OllamaLanguageModel {
fn use_any_tool(
&self,
- _request: LanguageModelRequest,
- _name: String,
- _description: String,
- _schema: serde_json::Value,
- _cx: &AsyncAppContext,
+ request: LanguageModelRequest,
+ tool_name: String,
+ tool_description: String,
+ schema: serde_json::Value,
+ cx: &AsyncAppContext,
) -> BoxFuture<'static, Result<serde_json::Value>> {
- future::ready(Err(anyhow!("not implemented"))).boxed()
+ use ollama::{OllamaFunctionTool, OllamaTool};
+ let function = OllamaFunctionTool {
+ name: tool_name.clone(),
+ description: Some(tool_description),
+ parameters: Some(schema),
+ };
+ let tools = vec![OllamaTool::Function { function }];
+ let request = self.to_ollama_request(request).with_tools(tools);
+ let response = self.request_completion(request, cx);
+ self.request_limiter
+ .run(async move {
+ let response = response.await?;
+ let ChatMessage::Assistant {
+ tool_calls,
+ content,
+ } = response.message
+ else {
+ bail!("message does not have an assistant role");
+ };
+ if let Some(tool_calls) = tool_calls.filter(|calls| !calls.is_empty()) {
+ for call in tool_calls {
+ let OllamaToolCall::Function(function) = call;
+ if function.name == tool_name {
+ return Ok(function.arguments);
+ }
+ }
+ } else if let Ok(args) = serde_json::from_str::<Value>(&content) {
+ // The message content itself parsed as JSON; treat it as the tool arguments.
+ return Ok(args);
+ } else {
+ bail!("assistant message does not have any tool calls");
+ };
+
+ bail!("tool not used")
+ })
+ .boxed()
}
}
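The Ollama branch issues a single non-streaming `complete` request and then accepts either of two response shapes: an assistant message carrying structured `tool_calls`, or, as a fallback, a message whose `content` is itself the argument JSON. Illustrative payloads for the two shapes (the values are made up, and unrelated response fields such as `model` are omitted):

use serde_json::json;

fn main() {
    // Preferred shape: a structured tool call on the assistant message.
    let structured = json!({
        "message": {
            "role": "assistant",
            "content": "",
            "tool_calls": [{
                "function": { "name": "file_search", "arguments": { "query": "TODO" } }
            }]
        }
    });

    // Fallback shape: the model wrote the arguments as JSON text in `content`,
    // which the code above parses with `serde_json::from_str`.
    let content_only = json!({
        "message": {
            "role": "assistant",
            "content": "{\"query\": \"TODO\"}"
        }
    });

    println!("{structured}\n{content_only}");
}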
@@ -1,4 +1,4 @@
-use anyhow::{anyhow, Result};
+use anyhow::{anyhow, bail, Result};
use collections::BTreeMap;
use editor::{Editor, EditorElement, EditorStyle};
use futures::{future::BoxFuture, FutureExt, StreamExt};
@@ -7,11 +7,13 @@ use gpui::{
View, WhiteSpace,
};
use http_client::HttpClient;
-use open_ai::stream_completion;
+use open_ai::{
+ stream_completion, FunctionDefinition, ResponseStreamEvent, ToolChoice, ToolDefinition,
+};
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
use settings::{Settings, SettingsStore};
-use std::{future, sync::Arc, time::Duration};
+use std::{sync::Arc, time::Duration};
use strum::IntoEnumIterator;
use theme::ThemeSettings;
use ui::{prelude::*, Indicator};
@@ -206,6 +208,41 @@ pub struct OpenAiLanguageModel {
request_limiter: RateLimiter,
}
+impl OpenAiLanguageModel {
+ fn stream_completion(
+ &self,
+ request: open_ai::Request,
+ cx: &AsyncAppContext,
+ ) -> BoxFuture<'static, Result<futures::stream::BoxStream<'static, Result<ResponseStreamEvent>>>>
+ {
+ let http_client = self.http_client.clone();
+ let Ok((api_key, api_url, low_speed_timeout)) = cx.read_model(&self.state, |state, cx| {
+ let settings = &AllLanguageModelSettings::get_global(cx).openai;
+ (
+ state.api_key.clone(),
+ settings.api_url.clone(),
+ settings.low_speed_timeout,
+ )
+ }) else {
+ return futures::future::ready(Err(anyhow!("App state dropped"))).boxed();
+ };
+
+ let future = self.request_limiter.stream(async move {
+ let api_key = api_key.ok_or_else(|| anyhow!("missing api key"))?;
+ let request = stream_completion(
+ http_client.as_ref(),
+ &api_url,
+ &api_key,
+ request,
+ low_speed_timeout,
+ );
+ let response = request.await?;
+ Ok(response)
+ });
+
+ async move { Ok(future.await?.boxed()) }.boxed()
+ }
+}
impl LanguageModel for OpenAiLanguageModel {
fn id(&self) -> LanguageModelId {
self.id.clone()
@@ -245,44 +282,68 @@ impl LanguageModel for OpenAiLanguageModel {
cx: &AsyncAppContext,
) -> BoxFuture<'static, Result<futures::stream::BoxStream<'static, Result<String>>>> {
let request = request.into_open_ai(self.model.id().into());
-
- let http_client = self.http_client.clone();
- let Ok((api_key, api_url, low_speed_timeout)) = cx.read_model(&self.state, |state, cx| {
- let settings = &AllLanguageModelSettings::get_global(cx).openai;
- (
- state.api_key.clone(),
- settings.api_url.clone(),
- settings.low_speed_timeout,
- )
- }) else {
- return futures::future::ready(Err(anyhow!("App state dropped"))).boxed();
- };
-
- let future = self.request_limiter.stream(async move {
- let api_key = api_key.ok_or_else(|| anyhow!("missing api key"))?;
- let request = stream_completion(
- http_client.as_ref(),
- &api_url,
- &api_key,
- request,
- low_speed_timeout,
- );
- let response = request.await?;
- Ok(open_ai::extract_text_from_events(response).boxed())
- });
-
- async move { Ok(future.await?.boxed()) }.boxed()
+ let completions = self.stream_completion(request, cx);
+ async move { Ok(open_ai::extract_text_from_events(completions.await?).boxed()) }.boxed()
}
fn use_any_tool(
&self,
- _request: LanguageModelRequest,
- _name: String,
- _description: String,
- _schema: serde_json::Value,
- _cx: &AsyncAppContext,
+ request: LanguageModelRequest,
+ tool_name: String,
+ tool_description: String,
+ schema: serde_json::Value,
+ cx: &AsyncAppContext,
) -> BoxFuture<'static, Result<serde_json::Value>> {
- future::ready(Err(anyhow!("not implemented"))).boxed()
+ let mut request = request.into_open_ai(self.model.id().into());
+ let mut function = FunctionDefinition {
+ name: tool_name.clone(),
+ description: None,
+ parameters: None,
+ };
+ let func = ToolDefinition::Function {
+ function: function.clone(),
+ };
+ request.tool_choice = Some(ToolChoice::Other(func.clone()));
+ // Fill in the description and parameters separately; they aren't needed for the tool_choice field.
+ function.description = Some(tool_description);
+ function.parameters = Some(schema);
+ request.tools = vec![ToolDefinition::Function { function }];
+ let response = self.stream_completion(request, cx);
+ self.request_limiter
+ .run(async move {
+ let mut response = response.await?;
+
+ // Tool call arguments are streamed in over multiple chunks.
+ let mut load_state = None;
+ while let Some(Ok(part)) = response.next().await {
+ for choice in part.choices {
+ let Some(tool_calls) = choice.delta.tool_calls else {
+ continue;
+ };
+
+ for call in tool_calls {
+ if let Some(func) = call.function {
+ if func.name.as_deref() == Some(tool_name.as_str()) {
+ load_state = Some((String::default(), call.index));
+ }
+ if let Some((arguments, (output, index))) =
+ func.arguments.zip(load_state.as_mut())
+ {
+ if call.index == *index {
+ output.push_str(&arguments);
+ }
+ }
+ }
+ }
+ }
+ }
+ if let Some((arguments, _)) = load_state {
+ return Ok(serde_json::from_str(&arguments)?);
+ } else {
+ bail!("tool not used");
+ }
+ })
+ .boxed()
}
}
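As with the other providers, `use_any_tool` resolves to the raw argument object, so callers are expected to deserialize it against their tool's schema. A hypothetical caller sketch (the `file_search` tool, its argument struct, and the helper name are illustrative, not part of this change; `LanguageModel`, `LanguageModelRequest`, and `AsyncAppContext` are the types from the diff above):

use anyhow::Result;
use gpui::AsyncAppContext;
use serde::Deserialize;

// Illustrative argument struct matching a hypothetical "file_search" schema.
#[derive(Deserialize)]
struct FileSearchArgs {
    query: String,
    max_results: Option<usize>,
}

async fn run_file_search_tool(
    model: &dyn LanguageModel,
    request: LanguageModelRequest,
    schema: serde_json::Value, // JSON schema describing FileSearchArgs
    cx: &AsyncAppContext,
) -> Result<FileSearchArgs> {
    let arguments = model
        .use_any_tool(
            request,
            "file_search".into(),
            "Search the project's files".into(),
            schema,
            cx,
        )
        .await?;
    Ok(serde_json::from_value(arguments)?)
}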
@@ -4,6 +4,7 @@ use http_client::{AsyncBody, HttpClient, Method, Request as HttpRequest};
use isahc::config::Configurable;
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
+use serde_json::Value;
use std::{convert::TryFrom, sync::Arc, time::Duration};
pub const OLLAMA_API_URL: &str = "http://localhost:11434";
@@ -94,22 +95,63 @@ impl Model {
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
#[serde(tag = "role", rename_all = "lowercase")]
pub enum ChatMessage {
- Assistant { content: String },
- User { content: String },
- System { content: String },
+ Assistant {
+ content: String,
+ tool_calls: Option<Vec<OllamaToolCall>>,
+ },
+ User {
+ content: String,
+ },
+ System {
+ content: String,
+ },
}
-#[derive(Serialize)]
+#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
+#[serde(rename_all = "lowercase")]
+pub enum OllamaToolCall {
+ Function(OllamaFunctionCall),
+}
+
+#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
+pub struct OllamaFunctionCall {
+ pub name: String,
+ pub arguments: Value,
+}
+
+#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
+pub struct OllamaFunctionTool {
+ pub name: String,
+ pub description: Option<String>,
+ pub parameters: Option<Value>,
+}
+
+#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
+#[serde(tag = "type", rename_all = "lowercase")]
+pub enum OllamaTool {
+ Function { function: OllamaFunctionTool },
+}
+
+#[derive(Serialize, Debug)]
pub struct ChatRequest {
pub model: String,
pub messages: Vec<ChatMessage>,
pub stream: bool,
pub keep_alive: KeepAlive,
pub options: Option<ChatOptions>,
+ pub tools: Vec<OllamaTool>,
+}
+
+impl ChatRequest {
+ pub fn with_tools(mut self, tools: Vec<OllamaTool>) -> Self {
+ self.stream = false;
+ self.tools = tools;
+ self
+ }
}
// https://github.com/ollama/ollama/blob/main/docs/modelfile.md#valid-parameters-and-values
-#[derive(Serialize, Default)]
+#[derive(Serialize, Default, Debug)]
pub struct ChatOptions {
pub num_ctx: Option<usize>,
pub num_predict: Option<isize>,
@@ -118,7 +160,7 @@ pub struct ChatOptions {
pub top_p: Option<f32>,
}
-#[derive(Deserialize)]
+#[derive(Deserialize, Debug)]
pub struct ChatResponseDelta {
#[allow(unused)]
pub model: String,
@@ -162,6 +204,38 @@ pub struct ModelDetails {
pub quantization_level: String,
}
+pub async fn complete(
+ client: &dyn HttpClient,
+ api_url: &str,
+ request: ChatRequest,
+) -> Result<ChatResponseDelta> {
+ let uri = format!("{api_url}/api/chat");
+ let request_builder = HttpRequest::builder()
+ .method(Method::POST)
+ .uri(uri)
+ .header("Content-Type", "application/json");
+
+ let serialized_request = serde_json::to_string(&request)?;
+ let request = request_builder.body(AsyncBody::from(serialized_request))?;
+
+ let mut response = client.send(request).await?;
+ if response.status().is_success() {
+ let mut body = Vec::new();
+ response.body_mut().read_to_end(&mut body).await?;
+ let response_message: ChatResponseDelta = serde_json::from_slice(&body)?;
+ Ok(response_message)
+ } else {
+ let mut body = Vec::new();
+ response.body_mut().read_to_end(&mut body).await?;
+ let body_str = std::str::from_utf8(&body)?;
+ Err(anyhow!(
+ "Failed to connect to API: {} {}",
+ response.status(),
+ body_str
+ ))
+ }
+}
+
pub async fn stream_chat_completion(
client: &dyn HttpClient,
api_url: &str,
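For reference, the new `ollama` types serialize as sketched below; this is a hand-written example based only on the derives and serde attributes in the diff above (the tool name and schema are illustrative):

use ollama::{ChatMessage, OllamaFunctionCall, OllamaFunctionTool, OllamaTool, OllamaToolCall};
use serde_json::json;

fn main() -> Result<(), serde_json::Error> {
    let tool = OllamaTool::Function {
        function: OllamaFunctionTool {
            name: "file_search".into(),
            description: Some("Search the project's files".into()),
            parameters: Some(json!({ "type": "object" })),
        },
    };
    // `#[serde(tag = "type", rename_all = "lowercase")]` yields:
    // {"type":"function","function":{"name":"file_search",...}}
    println!("{}", serde_json::to_string(&tool)?);

    let reply = ChatMessage::Assistant {
        content: String::new(),
        tool_calls: Some(vec![OllamaToolCall::Function(OllamaFunctionCall {
            name: "file_search".into(),
            arguments: json!({ "query": "TODO" }),
        })]),
    };
    // `#[serde(tag = "role", rename_all = "lowercase")]` on ChatMessage yields:
    // {"role":"assistant","content":"","tool_calls":[{"function":{"name":"file_search","arguments":{...}}}]}
    println!("{}", serde_json::to_string(&reply)?);
    Ok(())
}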
@@ -3,7 +3,7 @@ use futures::{io::BufReader, stream::BoxStream, AsyncBufReadExt, AsyncReadExt, S
use http_client::{AsyncBody, HttpClient, Method, Request as HttpRequest};
use isahc::config::Configurable;
use serde::{Deserialize, Serialize};
-use serde_json::{Map, Value};
+use serde_json::Value;
use std::{convert::TryFrom, future::Future, time::Duration};
use strum::EnumIter;
@@ -121,25 +121,34 @@ pub struct Request {
pub stop: Vec<String>,
pub temperature: f32,
#[serde(default, skip_serializing_if = "Option::is_none")]
- pub tool_choice: Option<String>,
+ pub tool_choice: Option<ToolChoice>,
#[serde(default, skip_serializing_if = "Vec::is_empty")]
pub tools: Vec<ToolDefinition>,
}
-#[derive(Debug, Deserialize, Serialize)]
-pub struct FunctionDefinition {
- pub name: String,
- pub description: Option<String>,
- pub parameters: Option<Map<String, Value>>,
+#[derive(Debug, Serialize, Deserialize)]
+#[serde(untagged)]
+pub enum ToolChoice {
+ Auto,
+ Required,
+ None,
+ Other(ToolDefinition),
}
-#[derive(Deserialize, Serialize, Debug)]
+#[derive(Clone, Deserialize, Serialize, Debug)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum ToolDefinition {
#[allow(dead_code)]
Function { function: FunctionDefinition },
}
+#[derive(Clone, Debug, Serialize, Deserialize)]
+pub struct FunctionDefinition {
+ pub name: String,
+ pub description: Option<String>,
+ pub parameters: Option<Value>,
+}
+
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
#[serde(tag = "role", rename_all = "lowercase")]
pub enum RequestMessage {