Cargo.lock 🔗
@@ -93,6 +93,7 @@ dependencies = [
"postage",
"rand 0.8.5",
"rusqlite",
+ "schemars",
"serde",
"serde_json",
"tiktoken-rs",
Marshall Bowers created this pull request.
This PR wires up support for [Azure
OpenAI](https://learn.microsoft.com/en-us/azure/ai-services/openai/overview)
as an alternative AI provider in the assistant panel.
This can be configured using the following in the settings file:
```json
{
"assistant": {
"provider": {
"type": "azure_openai",
"api_url": "https://{your-resource-name}.openai.azure.com",
"deployment_id": "gpt-4",
"api_version": "2023-05-15"
}
  }
}
```
You will need to deploy a model within Azure and update the settings
accordingly.
Release Notes:
- N/A
Cargo.lock | 1
assets/settings/default.json | 16 +
crates/ai/Cargo.toml | 1
crates/ai/src/providers/open_ai/completion.rs | 67 ++++++-
crates/assistant/src/assistant_panel.rs | 77 ++++---
crates/assistant/src/assistant_settings.rs | 193 ++++++++++++++++++--
crates/client/src/telemetry.rs | 2
7 files changed, 291 insertions(+), 66 deletions(-)
@@ -93,6 +93,7 @@ dependencies = [
"postage",
"rand 0.8.5",
"rusqlite",
+ "schemars",
"serde",
"serde_json",
"tiktoken-rs",
@@ -228,15 +228,29 @@
"default_width": 640,
// Default height when the assistant is docked to the bottom.
"default_height": 320,
+ // Deprecated: Please use `provider.api_url` instead.
// The default OpenAI API endpoint to use when starting new conversations.
"openai_api_url": "https://api.openai.com/v1",
+ // Deprecated: Please use `provider.default_model` instead.
// The default OpenAI model to use when starting new conversations. This
// setting can take three values:
//
  // 1. "gpt-3.5-turbo-0613"
  // 2. "gpt-4-0613"
// 3. "gpt-4-1106-preview"
- "default_open_ai_model": "gpt-4-1106-preview"
+ "default_open_ai_model": "gpt-4-1106-preview",
+ "provider": {
+ "type": "openai",
+ // The default OpenAI API endpoint to use when starting new conversations.
+ "api_url": "https://api.openai.com/v1",
+ // The default OpenAI model to use when starting new conversations. This
+ // setting can take three values:
+ //
+ // 1. "gpt-3.5-turbo-0613"
+ // 2. "gpt-4-0613"
+ // 3. "gpt-4-1106-preview"
+ "default_model": "gpt-4-1106-preview"
+ }
},
// Whether the screen sharing icon is shown in the os status bar.
"show_call_status_icon": true,
@@ -29,6 +29,7 @@ parse_duration = "2.1.1"
postage.workspace = true
rand.workspace = true
rusqlite = { version = "0.29.0", features = ["blob", "array", "modern_sqlite"] }
+schemars.workspace = true
serde.workspace = true
serde_json.workspace = true
tiktoken-rs.workspace = true
@@ -1,3 +1,10 @@
+use std::{
+ env,
+ fmt::{self, Display},
+ io,
+ sync::Arc,
+};
+
use anyhow::{anyhow, Result};
use futures::{
future::BoxFuture, io::BufReader, stream::BoxStream, AsyncBufReadExt, AsyncReadExt, FutureExt,
@@ -6,23 +13,17 @@ use futures::{
use gpui::{AppContext, BackgroundExecutor};
use isahc::{http::StatusCode, Request, RequestExt};
use parking_lot::RwLock;
+use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
-use std::{
- env,
- fmt::{self, Display},
- io,
- sync::Arc,
-};
use util::ResultExt;
+use crate::providers::open_ai::{OpenAiLanguageModel, OPEN_AI_API_URL};
use crate::{
auth::{CredentialProvider, ProviderCredential},
completion::{CompletionProvider, CompletionRequest},
models::LanguageModel,
};
-use crate::providers::open_ai::{OpenAiLanguageModel, OPEN_AI_API_URL};
-
#[derive(Clone, Copy, Serialize, Deserialize, Debug, Eq, PartialEq)]
#[serde(rename_all = "lowercase")]
pub enum Role {
@@ -196,12 +197,56 @@ async fn stream_completion(
}
}
+#[derive(Debug, Clone, Copy, Serialize, Deserialize, JsonSchema)]
+pub enum AzureOpenAiApiVersion {
+ /// Retiring April 2, 2024.
+ #[serde(rename = "2023-03-15-preview")]
+ V2023_03_15Preview,
+ #[serde(rename = "2023-05-15")]
+ V2023_05_15,
+ /// Retiring April 2, 2024.
+ #[serde(rename = "2023-06-01-preview")]
+ V2023_06_01Preview,
+ /// Retiring April 2, 2024.
+ #[serde(rename = "2023-07-01-preview")]
+ V2023_07_01Preview,
+ /// Retiring April 2, 2024.
+ #[serde(rename = "2023-08-01-preview")]
+ V2023_08_01Preview,
+ /// Retiring April 2, 2024.
+ #[serde(rename = "2023-09-01-preview")]
+ V2023_09_01Preview,
+ #[serde(rename = "2023-12-01-preview")]
+ V2023_12_01Preview,
+ #[serde(rename = "2024-02-15-preview")]
+ V2024_02_15Preview,
+}
+
+impl fmt::Display for AzureOpenAiApiVersion {
+ fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
+ write!(
+ f,
+ "{}",
+ match self {
+ Self::V2023_03_15Preview => "2023-03-15-preview",
+ Self::V2023_05_15 => "2023-05-15",
+ Self::V2023_06_01Preview => "2023-06-01-preview",
+ Self::V2023_07_01Preview => "2023-07-01-preview",
+ Self::V2023_08_01Preview => "2023-08-01-preview",
+ Self::V2023_09_01Preview => "2023-09-01-preview",
+ Self::V2023_12_01Preview => "2023-12-01-preview",
+ Self::V2024_02_15Preview => "2024-02-15-preview",
+ }
+ )
+ }
+}
+
#[derive(Clone)]
pub enum OpenAiCompletionProviderKind {
OpenAi,
AzureOpenAi {
deployment_id: String,
- api_version: String,
+ api_version: AzureOpenAiApiVersion,
},
}
@@ -217,8 +262,8 @@ impl OpenAiCompletionProviderKind {
deployment_id,
api_version,
} => {
- // https://learn.microsoft.com/en-us/azure/ai-services/openai/reference#completions
- format!("{api_url}/openai/deployments/{deployment_id}/completions?api-version={api_version}")
+ // https://learn.microsoft.com/en-us/azure/ai-services/openai/reference#chat-completions
+ format!("{api_url}/openai/deployments/{deployment_id}/chat/completions?api-version={api_version}")
}
}
}
@@ -124,16 +124,18 @@ impl AssistantPanel {
.await
.log_err()
.unwrap_or_default();
- let (api_url, model_name) = cx.update(|cx| {
+ let (provider_kind, api_url, model_name) = cx.update(|cx| {
let settings = AssistantSettings::get_global(cx);
- (
- settings.openai_api_url.clone(),
- settings.default_open_ai_model.full_name().to_string(),
- )
- })?;
+ anyhow::Ok((
+ settings.provider_kind()?,
+ settings.provider_api_url()?,
+ settings.provider_model_name()?,
+ ))
+ })??;
+
let completion_provider = OpenAiCompletionProvider::new(
api_url,
- OpenAiCompletionProviderKind::OpenAi,
+ provider_kind,
model_name,
cx.background_executor().clone(),
)
@@ -693,24 +695,29 @@ impl AssistantPanel {
Task::ready(Ok(Vec::new()))
};
- let mut model = AssistantSettings::get_global(cx)
- .default_open_ai_model
- .clone();
- let model_name = model.full_name();
-
- let prompt = cx.background_executor().spawn(async move {
- let snippets = snippets.await?;
+ let Some(mut model_name) = AssistantSettings::get_global(cx)
+ .provider_model_name()
+ .log_err()
+ else {
+ return;
+ };
- let language_name = language_name.as_deref();
- generate_content_prompt(
- user_prompt,
- language_name,
- buffer,
- range,
- snippets,
- model_name,
- project_name,
- )
+ let prompt = cx.background_executor().spawn({
+ let model_name = model_name.clone();
+ async move {
+ let snippets = snippets.await?;
+
+ let language_name = language_name.as_deref();
+ generate_content_prompt(
+ user_prompt,
+ language_name,
+ buffer,
+ range,
+ snippets,
+ &model_name,
+ project_name,
+ )
+ }
});
let mut messages = Vec::new();
@@ -722,7 +729,7 @@ impl AssistantPanel {
.messages(cx)
.map(|message| message.to_open_ai_message(buffer)),
);
- model = conversation.model.clone();
+ model_name = conversation.model.full_name().to_string();
}
cx.spawn(|_, mut cx| async move {
@@ -735,7 +742,7 @@ impl AssistantPanel {
});
let request = Box::new(OpenAiRequest {
- model: model.full_name().into(),
+ model: model_name,
messages,
stream: true,
stop: vec!["|END|>".to_string()],
@@ -1454,8 +1461,14 @@ impl Conversation {
});
let settings = AssistantSettings::get_global(cx);
- let model = settings.default_open_ai_model.clone();
- let api_url = settings.openai_api_url.clone();
+ let model = settings
+ .provider_model()
+ .log_err()
+ .unwrap_or(OpenAiModel::FourTurbo);
+ let api_url = settings
+ .provider_api_url()
+ .log_err()
+ .unwrap_or_else(|| OPEN_AI_API_URL.to_string());
let mut this = Self {
id: Some(Uuid::new_v4().to_string()),
@@ -3655,9 +3668,9 @@ fn report_assistant_event(
let client = workspace.read(cx).project().read(cx).client();
let telemetry = client.telemetry();
- let model = AssistantSettings::get_global(cx)
- .default_open_ai_model
- .clone();
+ let Ok(model_name) = AssistantSettings::get_global(cx).provider_model_name() else {
+ return;
+ };
- telemetry.report_assistant_event(conversation_id, assistant_kind, model.full_name())
+ telemetry.report_assistant_event(conversation_id, assistant_kind, &model_name)
}
@@ -1,10 +1,14 @@
-use anyhow;
+use ai::providers::open_ai::{
+ AzureOpenAiApiVersion, OpenAiCompletionProviderKind, OPEN_AI_API_URL,
+};
+use anyhow::anyhow;
use gpui::Pixels;
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
use settings::Settings;
-#[derive(Clone, Debug, Serialize, Deserialize, JsonSchema, PartialEq)]
+#[derive(Clone, Copy, Debug, Serialize, Deserialize, JsonSchema, PartialEq)]
+#[serde(rename_all = "snake_case")]
pub enum OpenAiModel {
#[serde(rename = "gpt-3.5-turbo-0613")]
ThreePointFiveTurbo,
@@ -17,25 +21,25 @@ pub enum OpenAiModel {
impl OpenAiModel {
pub fn full_name(&self) -> &'static str {
match self {
- OpenAiModel::ThreePointFiveTurbo => "gpt-3.5-turbo-0613",
- OpenAiModel::Four => "gpt-4-0613",
- OpenAiModel::FourTurbo => "gpt-4-1106-preview",
+ Self::ThreePointFiveTurbo => "gpt-3.5-turbo-0613",
+ Self::Four => "gpt-4-0613",
+ Self::FourTurbo => "gpt-4-1106-preview",
}
}
pub fn short_name(&self) -> &'static str {
match self {
- OpenAiModel::ThreePointFiveTurbo => "gpt-3.5-turbo",
- OpenAiModel::Four => "gpt-4",
- OpenAiModel::FourTurbo => "gpt-4-turbo",
+ Self::ThreePointFiveTurbo => "gpt-3.5-turbo",
+ Self::Four => "gpt-4",
+ Self::FourTurbo => "gpt-4-turbo",
}
}
pub fn cycle(&self) -> Self {
match self {
- OpenAiModel::ThreePointFiveTurbo => OpenAiModel::Four,
- OpenAiModel::Four => OpenAiModel::FourTurbo,
- OpenAiModel::FourTurbo => OpenAiModel::ThreePointFiveTurbo,
+ Self::ThreePointFiveTurbo => Self::Four,
+ Self::Four => Self::FourTurbo,
+ Self::FourTurbo => Self::ThreePointFiveTurbo,
}
}
}
@@ -48,14 +52,99 @@ pub enum AssistantDockPosition {
Bottom,
}
-#[derive(Deserialize, Debug)]
+#[derive(Debug, Deserialize)]
pub struct AssistantSettings {
+ /// Whether to show the assistant panel button in the status bar.
pub button: bool,
+ /// Where to dock the assistant.
pub dock: AssistantDockPosition,
+ /// Default width in pixels when the assistant is docked to the left or right.
pub default_width: Pixels,
+ /// Default height in pixels when the assistant is docked to the bottom.
pub default_height: Pixels,
+ /// The default OpenAI model to use when starting new conversations.
+ #[deprecated = "Please use `provider.default_model` instead."]
pub default_open_ai_model: OpenAiModel,
+ /// OpenAI API base URL to use when starting new conversations.
+ #[deprecated = "Please use `provider.api_url` instead."]
pub openai_api_url: String,
+ /// The settings for the AI provider.
+ pub provider: AiProviderSettings,
+}
+
+impl AssistantSettings {
+ pub fn provider_kind(&self) -> anyhow::Result<OpenAiCompletionProviderKind> {
+ match &self.provider {
+ AiProviderSettings::OpenAi(_) => Ok(OpenAiCompletionProviderKind::OpenAi),
+ AiProviderSettings::AzureOpenAi(settings) => {
+ let deployment_id = settings
+ .deployment_id
+ .clone()
+ .ok_or_else(|| anyhow!("no Azure OpenAI deployment ID"))?;
+ let api_version = settings
+ .api_version
+ .ok_or_else(|| anyhow!("no Azure OpenAI API version"))?;
+
+ Ok(OpenAiCompletionProviderKind::AzureOpenAi {
+ deployment_id,
+ api_version,
+ })
+ }
+ }
+ }
+
+ pub fn provider_api_url(&self) -> anyhow::Result<String> {
+ match &self.provider {
+ AiProviderSettings::OpenAi(settings) => Ok(settings
+ .api_url
+ .clone()
+ .unwrap_or_else(|| OPEN_AI_API_URL.to_string())),
+ AiProviderSettings::AzureOpenAi(settings) => settings
+ .api_url
+ .clone()
+ .ok_or_else(|| anyhow!("no Azure OpenAI API URL")),
+ }
+ }
+
+ pub fn provider_model(&self) -> anyhow::Result<OpenAiModel> {
+ match &self.provider {
+ AiProviderSettings::OpenAi(settings) => {
+ Ok(settings.default_model.unwrap_or(OpenAiModel::FourTurbo))
+ }
+ AiProviderSettings::AzureOpenAi(_settings) => {
+ // TODO: We need to use an Azure OpenAI model here.
+ Ok(OpenAiModel::FourTurbo)
+ }
+ }
+ }
+
+ pub fn provider_model_name(&self) -> anyhow::Result<String> {
+ match &self.provider {
+ AiProviderSettings::OpenAi(settings) => Ok(settings
+ .default_model
+ .unwrap_or(OpenAiModel::FourTurbo)
+ .full_name()
+ .to_string()),
+ AiProviderSettings::AzureOpenAi(settings) => settings
+ .deployment_id
+ .clone()
+ .ok_or_else(|| anyhow!("no Azure OpenAI deployment ID")),
+ }
+ }
+}
+
+impl Settings for AssistantSettings {
+ const KEY: Option<&'static str> = Some("assistant");
+
+ type FileContent = AssistantSettingsContent;
+
+ fn load(
+ default_value: &Self::FileContent,
+ user_values: &[&Self::FileContent],
+ _: &mut gpui::AppContext,
+ ) -> anyhow::Result<Self> {
+ Self::load_via_json_merge(default_value, user_values)
+ }
}
/// Assistant panel settings
@@ -77,26 +166,88 @@ pub struct AssistantSettingsContent {
///
/// Default: 320
pub default_height: Option<f32>,
+ /// Deprecated: Please use `provider.default_model` instead.
/// The default OpenAI model to use when starting new conversations.
///
/// Default: gpt-4-1106-preview
+ #[deprecated = "Please use `provider.default_model` instead."]
pub default_open_ai_model: Option<OpenAiModel>,
+ /// Deprecated: Please use `provider.api_url` instead.
/// OpenAI API base URL to use when starting new conversations.
///
/// Default: https://api.openai.com/v1
+ #[deprecated = "Please use `provider.api_url` instead."]
pub openai_api_url: Option<String>,
+ /// The settings for the AI provider.
+ #[serde(default)]
+ pub provider: AiProviderSettingsContent,
}
-impl Settings for AssistantSettings {
- const KEY: Option<&'static str> = Some("assistant");
+#[derive(Debug, Clone, Deserialize)]
+#[serde(tag = "type", rename_all = "snake_case")]
+pub enum AiProviderSettings {
+ /// The settings for the OpenAI provider.
+ #[serde(rename = "openai")]
+ OpenAi(OpenAiProviderSettings),
+ /// The settings for the Azure OpenAI provider.
+ #[serde(rename = "azure_openai")]
+ AzureOpenAi(AzureOpenAiProviderSettings),
+}
- type FileContent = AssistantSettingsContent;
+/// The settings for the AI provider used by the Zed Assistant.
+#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
+#[serde(tag = "type", rename_all = "snake_case")]
+pub enum AiProviderSettingsContent {
+ /// The settings for the OpenAI provider.
+ #[serde(rename = "openai")]
+ OpenAi(OpenAiProviderSettingsContent),
+ /// The settings for the Azure OpenAI provider.
+ #[serde(rename = "azure_openai")]
+ AzureOpenAi(AzureOpenAiProviderSettingsContent),
+}
- fn load(
- default_value: &Self::FileContent,
- user_values: &[&Self::FileContent],
- _: &mut gpui::AppContext,
- ) -> anyhow::Result<Self> {
- Self::load_via_json_merge(default_value, user_values)
+impl Default for AiProviderSettingsContent {
+ fn default() -> Self {
+ Self::OpenAi(OpenAiProviderSettingsContent::default())
}
}
+
+#[derive(Debug, Clone, Deserialize)]
+pub struct OpenAiProviderSettings {
+ /// The OpenAI API base URL to use when starting new conversations.
+ pub api_url: Option<String>,
+ /// The default OpenAI model to use when starting new conversations.
+ pub default_model: Option<OpenAiModel>,
+}
+
+#[derive(Debug, Default, Clone, Serialize, Deserialize, JsonSchema)]
+pub struct OpenAiProviderSettingsContent {
+ /// The OpenAI API base URL to use when starting new conversations.
+ ///
+ /// Default: https://api.openai.com/v1
+ pub api_url: Option<String>,
+ /// The default OpenAI model to use when starting new conversations.
+ ///
+ /// Default: gpt-4-1106-preview
+ pub default_model: Option<OpenAiModel>,
+}
+
+#[derive(Debug, Clone, Deserialize)]
+pub struct AzureOpenAiProviderSettings {
+ /// The Azure OpenAI API base URL to use when starting new conversations.
+ pub api_url: Option<String>,
+ /// The Azure OpenAI API version.
+ pub api_version: Option<AzureOpenAiApiVersion>,
+ /// The Azure OpenAI API deployment ID.
+ pub deployment_id: Option<String>,
+}
+
+#[derive(Debug, Default, Clone, Serialize, Deserialize, JsonSchema)]
+pub struct AzureOpenAiProviderSettingsContent {
+ /// The Azure OpenAI API base URL to use when starting new conversations.
+ pub api_url: Option<String>,
+ /// The Azure OpenAI API version.
+ pub api_version: Option<AzureOpenAiApiVersion>,
+ /// The Azure OpenAI API deployment ID.
+ pub deployment_id: Option<String>,
+}
@@ -263,7 +263,7 @@ impl Telemetry {
self: &Arc<Self>,
conversation_id: Option<String>,
kind: AssistantKind,
- model: &'static str,
+ model: &str,
) {
let event = Event::Assistant(AssistantEvent {
conversation_id,