Detailed changes
@@ -853,7 +853,17 @@
}
},
// Different settings for specific language models.
- "language_models": {},
+ "language_models": {
+ "anthropic": {
+ "api_url": "https://api.anthropic.com"
+ },
+ "openai": {
+ "api_url": "https://api.openai.com/v1"
+ },
+ "ollama": {
+ "api_url": "http://localhost:11434"
+ }
+ },
// Zed's Prettier integration settings.
// Allows to enable/disable formatting with Prettier
// and configure default Prettier, used when no project-level Prettier installation is found.
@@ -23,7 +23,7 @@ use gpui::{actions, impl_actions, AppContext, Global, SharedString, UpdateGlobal
use indexed_docs::IndexedDocsRegistry;
pub(crate) use inline_assistant::*;
use language_model::{
- LanguageModelId, LanguageModelProviderName, LanguageModelRegistry, LanguageModelResponseMessage,
+ LanguageModelId, LanguageModelProviderId, LanguageModelRegistry, LanguageModelResponseMessage,
};
pub(crate) use model_selector::*;
use semantic_index::{CloudEmbeddingProvider, SemanticIndex};
@@ -231,7 +231,7 @@ fn init_completion_provider(cx: &mut AppContext) {
fn update_active_language_model_from_settings(cx: &mut AppContext) {
let settings = AssistantSettings::get_global(cx);
- let provider_name = LanguageModelProviderName::from(settings.default_model.provider.clone());
+ let provider_name = LanguageModelProviderId::from(settings.default_model.provider.clone());
let model_id = LanguageModelId::from(settings.default_model.model.clone());
let Some(provider) = LanguageModelRegistry::global(cx)
@@ -144,8 +144,8 @@ impl AssistantSettingsContent {
fs,
cx,
move |content, _| {
- if content.open_ai.is_none() {
- content.open_ai =
+ if content.openai.is_none() {
+ content.openai =
Some(language_model::settings::OpenAiSettingsContent {
api_url,
low_speed_timeout_in_seconds,
@@ -243,7 +243,7 @@ impl AssistantSettingsContent {
pub fn set_model(&mut self, language_model: Arc<dyn LanguageModel>) {
let model = language_model.id().0.to_string();
- let provider = language_model.provider_name().0.to_string();
+ let provider = language_model.provider_id().0.to_string();
match self {
AssistantSettingsContent::Versioned(settings) => match settings {
@@ -1438,7 +1438,7 @@ impl Render for PromptEditor {
{
let model_name = available_model.name().0.clone();
let provider =
- available_model.provider_name().0.clone();
+ available_model.provider_id().0.clone();
move |_| {
h_flex()
.w_full()
@@ -565,7 +565,7 @@ impl Render for PromptEditor {
{
let model_name = available_model.name().0.clone();
let provider =
- available_model.provider_name().0.clone();
+ available_model.provider_id().0.clone();
move |_| {
h_flex()
.w_full()
@@ -2,7 +2,7 @@ use anyhow::{anyhow, Result};
use futures::{future::BoxFuture, stream::BoxStream, FutureExt, StreamExt};
use gpui::{AppContext, Global, Model, ModelContext, Task};
use language_model::{
- LanguageModel, LanguageModelProvider, LanguageModelProviderName, LanguageModelRegistry,
+ LanguageModel, LanguageModelProvider, LanguageModelProviderId, LanguageModelRegistry,
LanguageModelRequest,
};
use smol::lock::{Semaphore, SemaphoreGuardArc};
@@ -89,7 +89,7 @@ impl LanguageModelCompletionProvider {
pub fn set_active_provider(
&mut self,
- provider_name: LanguageModelProviderName,
+ provider_name: LanguageModelProviderId,
cx: &mut ModelContext<Self>,
) {
self.active_provider = LanguageModelRegistry::read_global(cx).provider(&provider_name);
@@ -103,14 +103,19 @@ impl LanguageModelCompletionProvider {
pub fn set_active_model(&mut self, model: Arc<dyn LanguageModel>, cx: &mut ModelContext<Self>) {
if self.active_model.as_ref().map_or(false, |m| {
- m.id() == model.id() && m.provider_name() == model.provider_name()
+ m.id() == model.id() && m.provider_id() == model.provider_id()
}) {
return;
}
self.active_provider =
- LanguageModelRegistry::read_global(cx).provider(&model.provider_name());
- self.active_model = Some(model);
+ LanguageModelRegistry::read_global(cx).provider(&model.provider_id());
+ self.active_model = Some(model.clone());
+
+ if let Some(provider) = self.active_provider.as_ref() {
+ provider.load_model(model, cx);
+ }
+
cx.notify();
}
@@ -25,6 +25,7 @@ pub fn init(client: Arc<Client>, cx: &mut AppContext) {
pub trait LanguageModel: Send + Sync {
fn id(&self) -> LanguageModelId;
fn name(&self) -> LanguageModelName;
+ fn provider_id(&self) -> LanguageModelProviderId;
fn provider_name(&self) -> LanguageModelProviderName;
fn telemetry_id(&self) -> String;
@@ -44,8 +45,10 @@ pub trait LanguageModel: Send + Sync {
}
pub trait LanguageModelProvider: 'static {
+ fn id(&self) -> LanguageModelProviderId;
fn name(&self) -> LanguageModelProviderName;
fn provided_models(&self, cx: &AppContext) -> Vec<Arc<dyn LanguageModel>>;
+ fn load_model(&self, _model: Arc<dyn LanguageModel>, _cx: &AppContext) {}
fn is_authenticated(&self, cx: &AppContext) -> bool;
fn authenticate(&self, cx: &AppContext) -> Task<Result<()>>;
fn authentication_prompt(&self, cx: &mut WindowContext) -> AnyView;
@@ -62,6 +65,9 @@ pub struct LanguageModelId(pub SharedString);
#[derive(Clone, Eq, PartialEq, Hash, Debug)]
pub struct LanguageModelName(pub SharedString);
+#[derive(Clone, Eq, PartialEq, Hash, Debug)]
+pub struct LanguageModelProviderId(pub SharedString);
+
#[derive(Clone, Eq, PartialEq, Hash, Debug)]
pub struct LanguageModelProviderName(pub SharedString);
@@ -77,6 +83,12 @@ impl From<String> for LanguageModelName {
}
}
+impl From<String> for LanguageModelProviderId {
+ fn from(value: String) -> Self {
+ Self(SharedString::from(value))
+ }
+}
+
impl From<String> for LanguageModelProviderName {
fn from(value: String) -> Self {
Self(SharedString::from(value))
@@ -1,6 +1,5 @@
use anthropic::{stream_completion, Request, RequestMessage};
use anyhow::{anyhow, Result};
-use collections::HashMap;
use editor::{Editor, EditorElement, EditorStyle};
use futures::{future::BoxFuture, stream::BoxStream, FutureExt, StreamExt};
use gpui::{
@@ -9,7 +8,7 @@ use gpui::{
};
use http_client::HttpClient;
use settings::{Settings, SettingsStore};
-use std::{sync::Arc, time::Duration};
+use std::{collections::BTreeMap, sync::Arc, time::Duration};
use strum::IntoEnumIterator;
use theme::ThemeSettings;
use ui::prelude::*;
@@ -17,11 +16,12 @@ use util::ResultExt;
use crate::{
settings::AllLanguageModelSettings, LanguageModel, LanguageModelId, LanguageModelName,
- LanguageModelProvider, LanguageModelProviderName, LanguageModelProviderState,
- LanguageModelRequest, LanguageModelRequestMessage, Role,
+ LanguageModelProvider, LanguageModelProviderId, LanguageModelProviderName,
+ LanguageModelProviderState, LanguageModelRequest, LanguageModelRequestMessage, Role,
};
-const PROVIDER_NAME: &str = "anthropic";
+const PROVIDER_ID: &str = "anthropic";
+const PROVIDER_NAME: &str = "Anthropic";
#[derive(Default, Clone, Debug, PartialEq)]
pub struct AnthropicSettings {
@@ -37,7 +37,6 @@ pub struct AnthropicLanguageModelProvider {
struct State {
api_key: Option<String>,
- settings: AnthropicSettings,
_subscription: Subscription,
}
@@ -45,9 +44,7 @@ impl AnthropicLanguageModelProvider {
pub fn new(http_client: Arc<dyn HttpClient>, cx: &mut AppContext) -> Self {
let state = cx.new_model(|cx| State {
api_key: None,
- settings: AnthropicSettings::default(),
- _subscription: cx.observe_global::<SettingsStore>(|this: &mut State, cx| {
- this.settings = AllLanguageModelSettings::get_global(cx).anthropic.clone();
+ _subscription: cx.observe_global::<SettingsStore>(|_, cx| {
cx.notify();
}),
});
@@ -64,12 +61,16 @@ impl LanguageModelProviderState for AnthropicLanguageModelProvider {
}
impl LanguageModelProvider for AnthropicLanguageModelProvider {
+ fn id(&self) -> LanguageModelProviderId {
+ LanguageModelProviderId(PROVIDER_ID.into())
+ }
+
fn name(&self) -> LanguageModelProviderName {
LanguageModelProviderName(PROVIDER_NAME.into())
}
fn provided_models(&self, cx: &AppContext) -> Vec<Arc<dyn LanguageModel>> {
- let mut models = HashMap::default();
+ let mut models = BTreeMap::default();
// Add base models from anthropic::Model::iter()
for model in anthropic::Model::iter() {
@@ -79,7 +80,11 @@ impl LanguageModelProvider for AnthropicLanguageModelProvider {
}
// Override with available models from settings
- for model in &self.state.read(cx).settings.available_models {
+ for model in AllLanguageModelSettings::get_global(cx)
+ .anthropic
+ .available_models
+ .iter()
+ {
models.insert(model.id().to_string(), model.clone());
}
@@ -104,7 +109,10 @@ impl LanguageModelProvider for AnthropicLanguageModelProvider {
if self.is_authenticated(cx) {
Task::ready(Ok(()))
} else {
- let api_url = self.state.read(cx).settings.api_url.clone();
+ let api_url = AllLanguageModelSettings::get_global(cx)
+ .anthropic
+ .api_url
+ .clone();
let state = self.state.clone();
cx.spawn(|mut cx| async move {
let api_key = if let Ok(api_key) = std::env::var("ANTHROPIC_API_KEY") {
@@ -132,7 +140,8 @@ impl LanguageModelProvider for AnthropicLanguageModelProvider {
fn reset_credentials(&self, cx: &AppContext) -> Task<Result<()>> {
let state = self.state.clone();
- let delete_credentials = cx.delete_credentials(&self.state.read(cx).settings.api_url);
+ let delete_credentials =
+ cx.delete_credentials(&AllLanguageModelSettings::get_global(cx).anthropic.api_url);
cx.spawn(|mut cx| async move {
delete_credentials.await.log_err();
state.update(&mut cx, |this, cx| {
@@ -221,6 +230,10 @@ impl LanguageModel for AnthropicModel {
LanguageModelName::from(self.model.display_name().to_string())
}
+ fn provider_id(&self) -> LanguageModelProviderId {
+ LanguageModelProviderId(PROVIDER_ID.into())
+ }
+
fn provider_name(&self) -> LanguageModelProviderName {
LanguageModelProviderName(PROVIDER_NAME.into())
}
@@ -249,11 +262,13 @@ impl LanguageModel for AnthropicModel {
let request = self.to_anthropic_request(request);
let http_client = self.http_client.clone();
- let Ok((api_key, api_url, low_speed_timeout)) = cx.read_model(&self.state, |state, _| {
+
+ let Ok((api_key, api_url, low_speed_timeout)) = cx.read_model(&self.state, |state, cx| {
+ let settings = &AllLanguageModelSettings::get_global(cx).anthropic;
(
state.api_key.clone(),
- state.settings.api_url.clone(),
- state.settings.low_speed_timeout,
+ settings.api_url.clone(),
+ settings.low_speed_timeout,
)
}) else {
return futures::future::ready(Err(anyhow!("App state dropped"))).boxed();
@@ -365,7 +380,10 @@ impl AuthenticationPrompt {
}
let write_credentials = cx.write_credentials(
- &self.state.read(cx).settings.api_url,
+ AllLanguageModelSettings::get_global(cx)
+ .anthropic
+ .api_url
+ .as_str(),
"Bearer",
api_key.as_bytes(),
);
@@ -1,15 +1,15 @@
use super::open_ai::count_open_ai_tokens;
use crate::{
settings::AllLanguageModelSettings, CloudModel, LanguageModel, LanguageModelId,
- LanguageModelName, LanguageModelProviderName, LanguageModelProviderState, LanguageModelRequest,
+ LanguageModelName, LanguageModelProviderId, LanguageModelProviderName,
+ LanguageModelProviderState, LanguageModelRequest,
};
use anyhow::Result;
use client::Client;
-use collections::HashMap;
use futures::{future::BoxFuture, stream::BoxStream, FutureExt, StreamExt, TryFutureExt};
use gpui::{AnyView, AppContext, AsyncAppContext, Subscription, Task};
use settings::{Settings, SettingsStore};
-use std::sync::Arc;
+use std::{collections::BTreeMap, sync::Arc};
use strum::IntoEnumIterator;
use ui::prelude::*;
@@ -17,6 +17,7 @@ use crate::LanguageModelProvider;
use super::anthropic::{count_anthropic_tokens, preprocess_anthropic_request};
+pub const PROVIDER_ID: &str = "zed.dev";
pub const PROVIDER_NAME: &str = "zed.dev";
#[derive(Default, Clone, Debug, PartialEq)]
@@ -33,7 +34,6 @@ pub struct CloudLanguageModelProvider {
struct State {
client: Arc<Client>,
status: client::Status,
- settings: ZedDotDevSettings,
_subscription: Subscription,
}
@@ -52,9 +52,7 @@ impl CloudLanguageModelProvider {
let state = cx.new_model(|cx| State {
client: client.clone(),
status,
- settings: ZedDotDevSettings::default(),
- _subscription: cx.observe_global::<SettingsStore>(|this: &mut State, cx| {
- this.settings = AllLanguageModelSettings::get_global(cx).zed_dot_dev.clone();
+ _subscription: cx.observe_global::<SettingsStore>(|_, cx| {
cx.notify();
}),
});
@@ -90,12 +88,16 @@ impl LanguageModelProviderState for CloudLanguageModelProvider {
}
impl LanguageModelProvider for CloudLanguageModelProvider {
+ fn id(&self) -> LanguageModelProviderId {
+ LanguageModelProviderId(PROVIDER_ID.into())
+ }
+
fn name(&self) -> LanguageModelProviderName {
LanguageModelProviderName(PROVIDER_NAME.into())
}
fn provided_models(&self, cx: &AppContext) -> Vec<Arc<dyn LanguageModel>> {
- let mut models = HashMap::default();
+ let mut models = BTreeMap::default();
// Add base models from CloudModel::iter()
for model in CloudModel::iter() {
@@ -105,7 +107,10 @@ impl LanguageModelProvider for CloudLanguageModelProvider {
}
// Override with available models from settings
- for model in &self.state.read(cx).settings.available_models {
+ for model in &AllLanguageModelSettings::get_global(cx)
+ .zed_dot_dev
+ .available_models
+ {
models.insert(model.id().to_string(), model.clone());
}
@@ -156,6 +161,10 @@ impl LanguageModel for CloudLanguageModel {
LanguageModelName::from(self.model.display_name().to_string())
}
+ fn provider_id(&self) -> LanguageModelProviderId {
+ LanguageModelProviderId(PROVIDER_ID.into())
+ }
+
fn provider_name(&self) -> LanguageModelProviderName {
LanguageModelProviderName(PROVIDER_NAME.into())
}
@@ -187,6 +196,9 @@ impl LanguageModel for CloudLanguageModel {
| CloudModel::Claude3Opus
| CloudModel::Claude3Sonnet
| CloudModel::Claude3Haiku => count_anthropic_tokens(request, cx),
+ CloudModel::Custom { name, .. } if name.starts_with("anthropic/") => {
+ count_anthropic_tokens(request, cx)
+ }
_ => {
let request = self.client.request(proto::CountTokensWithLanguageModel {
model: self.model.id().to_string(),
@@ -5,7 +5,8 @@ use futures::{channel::mpsc, future::BoxFuture, stream::BoxStream, FutureExt, St
use crate::{
LanguageModel, LanguageModelId, LanguageModelName, LanguageModelProvider,
- LanguageModelProviderName, LanguageModelProviderState, LanguageModelRequest,
+ LanguageModelProviderId, LanguageModelProviderName, LanguageModelProviderState,
+ LanguageModelRequest,
};
use gpui::{AnyView, AppContext, AsyncAppContext, Task};
use http_client::Result;
@@ -19,8 +20,12 @@ pub fn language_model_name() -> LanguageModelName {
LanguageModelName::from("Fake".to_string())
}
+pub fn provider_id() -> LanguageModelProviderId {
+ LanguageModelProviderId::from("fake".to_string())
+}
+
pub fn provider_name() -> LanguageModelProviderName {
- LanguageModelProviderName::from("fake".to_string())
+ LanguageModelProviderName::from("Fake".to_string())
}
#[derive(Clone, Default)]
@@ -35,6 +40,10 @@ impl LanguageModelProviderState for FakeLanguageModelProvider {
}
impl LanguageModelProvider for FakeLanguageModelProvider {
+ fn id(&self) -> LanguageModelProviderId {
+ provider_id()
+ }
+
fn name(&self) -> LanguageModelProviderName {
provider_name()
}
@@ -125,6 +134,10 @@ impl LanguageModel for FakeLanguageModel {
language_model_name()
}
+ fn provider_id(&self) -> LanguageModelProviderId {
+ provider_id()
+ }
+
fn provider_name(&self) -> LanguageModelProviderName {
provider_name()
}
@@ -2,21 +2,24 @@ use anyhow::{anyhow, Result};
use futures::{future::BoxFuture, stream::BoxStream, FutureExt, StreamExt};
use gpui::{AnyView, AppContext, AsyncAppContext, ModelContext, Subscription, Task};
use http_client::HttpClient;
-use ollama::{get_models, stream_chat_completion, ChatMessage, ChatOptions, ChatRequest};
+use ollama::{
+ get_models, preload_model, stream_chat_completion, ChatMessage, ChatOptions, ChatRequest,
+};
use settings::{Settings, SettingsStore};
use std::{sync::Arc, time::Duration};
use ui::{prelude::*, ButtonLike, ElevationIndex};
use crate::{
settings::AllLanguageModelSettings, LanguageModel, LanguageModelId, LanguageModelName,
- LanguageModelProvider, LanguageModelProviderName, LanguageModelProviderState,
- LanguageModelRequest, Role,
+ LanguageModelProvider, LanguageModelProviderId, LanguageModelProviderName,
+ LanguageModelProviderState, LanguageModelRequest, Role,
};
const OLLAMA_DOWNLOAD_URL: &str = "https://ollama.com/download";
const OLLAMA_LIBRARY_URL: &str = "https://ollama.com/library";
-const PROVIDER_NAME: &str = "ollama";
+const PROVIDER_ID: &str = "ollama";
+const PROVIDER_NAME: &str = "Ollama";
#[derive(Default, Debug, Clone, PartialEq)]
pub struct OllamaSettings {
@@ -32,14 +35,14 @@ pub struct OllamaLanguageModelProvider {
struct State {
http_client: Arc<dyn HttpClient>,
available_models: Vec<ollama::Model>,
- settings: OllamaSettings,
_subscription: Subscription,
}
impl State {
- fn fetch_models(&self, cx: &mut ModelContext<Self>) -> Task<Result<()>> {
+ fn fetch_models(&self, cx: &ModelContext<Self>) -> Task<Result<()>> {
+ let settings = &AllLanguageModelSettings::get_global(cx).ollama;
let http_client = self.http_client.clone();
- let api_url = self.settings.api_url.clone();
+ let api_url = settings.api_url.clone();
// As a proxy for the server being "authenticated", we'll check if its up by fetching the models
cx.spawn(|this, mut cx| async move {
@@ -66,23 +69,25 @@ impl State {
impl OllamaLanguageModelProvider {
pub fn new(http_client: Arc<dyn HttpClient>, cx: &mut AppContext) -> Self {
- Self {
+ let this = Self {
http_client: http_client.clone(),
state: cx.new_model(|cx| State {
http_client,
available_models: Default::default(),
- settings: OllamaSettings::default(),
_subscription: cx.observe_global::<SettingsStore>(|this: &mut State, cx| {
- this.settings = AllLanguageModelSettings::get_global(cx).ollama.clone();
+ this.fetch_models(cx).detach_and_log_err(cx);
cx.notify();
}),
}),
- }
+ };
+ this.fetch_models(cx).detach_and_log_err(cx);
+ this
}
fn fetch_models(&self, cx: &AppContext) -> Task<Result<()>> {
+ let settings = &AllLanguageModelSettings::get_global(cx).ollama;
let http_client = self.http_client.clone();
- let api_url = self.state.read(cx).settings.api_url.clone();
+ let api_url = settings.api_url.clone();
let state = self.state.clone();
// As a proxy for the server being "authenticated", we'll check if its up by fetching the models
@@ -117,6 +122,10 @@ impl LanguageModelProviderState for OllamaLanguageModelProvider {
}
impl LanguageModelProvider for OllamaLanguageModelProvider {
+ fn id(&self) -> LanguageModelProviderId {
+ LanguageModelProviderId(PROVIDER_ID.into())
+ }
+
fn name(&self) -> LanguageModelProviderName {
LanguageModelProviderName(PROVIDER_NAME.into())
}
@@ -131,12 +140,20 @@ impl LanguageModelProvider for OllamaLanguageModelProvider {
id: LanguageModelId::from(model.name.clone()),
model: model.clone(),
http_client: self.http_client.clone(),
- state: self.state.clone(),
}) as Arc<dyn LanguageModel>
})
.collect()
}
+ fn load_model(&self, model: Arc<dyn LanguageModel>, cx: &AppContext) {
+ let settings = &AllLanguageModelSettings::get_global(cx).ollama;
+ let http_client = self.http_client.clone();
+ let api_url = settings.api_url.clone();
+ let id = model.id().0.to_string();
+ cx.spawn(|_| async move { preload_model(http_client, &api_url, &id).await })
+ .detach_and_log_err(cx);
+ }
+
fn is_authenticated(&self, cx: &AppContext) -> bool {
!self.state.read(cx).available_models.is_empty()
}
@@ -167,7 +184,6 @@ impl LanguageModelProvider for OllamaLanguageModelProvider {
pub struct OllamaLanguageModel {
id: LanguageModelId,
model: ollama::Model,
- state: gpui::Model<State>,
http_client: Arc<dyn HttpClient>,
}
@@ -211,6 +227,14 @@ impl LanguageModel for OllamaLanguageModel {
LanguageModelName::from(self.model.display_name().to_string())
}
+ fn provider_id(&self) -> LanguageModelProviderId {
+ LanguageModelProviderId(PROVIDER_ID.into())
+ }
+
+ fn provider_name(&self) -> LanguageModelProviderName {
+ LanguageModelProviderName(PROVIDER_NAME.into())
+ }
+
fn max_token_count(&self) -> usize {
self.model.max_token_count()
}
@@ -219,10 +243,6 @@ impl LanguageModel for OllamaLanguageModel {
format!("ollama/{}", self.model.id())
}
- fn provider_name(&self) -> LanguageModelProviderName {
- LanguageModelProviderName(PROVIDER_NAME.into())
- }
-
fn count_tokens(
&self,
request: LanguageModelRequest,
@@ -248,11 +268,9 @@ impl LanguageModel for OllamaLanguageModel {
let request = self.to_ollama_request(request);
let http_client = self.http_client.clone();
- let Ok((api_url, low_speed_timeout)) = cx.read_model(&self.state, |state, _| {
- (
- state.settings.api_url.clone(),
- state.settings.low_speed_timeout,
- )
+ let Ok((api_url, low_speed_timeout)) = cx.update(|cx| {
+ let settings = &AllLanguageModelSettings::get_global(cx).ollama;
+ (settings.api_url.clone(), settings.low_speed_timeout)
}) else {
return futures::future::ready(Err(anyhow!("App state dropped"))).boxed();
};
@@ -1,5 +1,5 @@
use anyhow::{anyhow, Result};
-use collections::HashMap;
+use collections::BTreeMap;
use editor::{Editor, EditorElement, EditorStyle};
use futures::{future::BoxFuture, FutureExt, StreamExt};
use gpui::{
@@ -17,11 +17,12 @@ use util::ResultExt;
use crate::{
settings::AllLanguageModelSettings, LanguageModel, LanguageModelId, LanguageModelName,
- LanguageModelProvider, LanguageModelProviderName, LanguageModelProviderState,
- LanguageModelRequest, Role,
+ LanguageModelProvider, LanguageModelProviderId, LanguageModelProviderName,
+ LanguageModelProviderState, LanguageModelRequest, Role,
};
-const PROVIDER_NAME: &str = "openai";
+const PROVIDER_ID: &str = "openai";
+const PROVIDER_NAME: &str = "OpenAI";
#[derive(Default, Clone, Debug, PartialEq)]
pub struct OpenAiSettings {
@@ -37,7 +38,6 @@ pub struct OpenAiLanguageModelProvider {
struct State {
api_key: Option<String>,
- settings: OpenAiSettings,
_subscription: Subscription,
}
@@ -45,9 +45,7 @@ impl OpenAiLanguageModelProvider {
pub fn new(http_client: Arc<dyn HttpClient>, cx: &mut AppContext) -> Self {
let state = cx.new_model(|cx| State {
api_key: None,
- settings: OpenAiSettings::default(),
- _subscription: cx.observe_global::<SettingsStore>(|this: &mut State, cx| {
- this.settings = AllLanguageModelSettings::get_global(cx).open_ai.clone();
+ _subscription: cx.observe_global::<SettingsStore>(|_this: &mut State, cx| {
cx.notify();
}),
});
@@ -65,12 +63,16 @@ impl LanguageModelProviderState for OpenAiLanguageModelProvider {
}
impl LanguageModelProvider for OpenAiLanguageModelProvider {
+ fn id(&self) -> LanguageModelProviderId {
+ LanguageModelProviderId(PROVIDER_ID.into())
+ }
+
fn name(&self) -> LanguageModelProviderName {
LanguageModelProviderName(PROVIDER_NAME.into())
}
fn provided_models(&self, cx: &AppContext) -> Vec<Arc<dyn LanguageModel>> {
- let mut models = HashMap::default();
+ let mut models = BTreeMap::default();
// Add base models from open_ai::Model::iter()
for model in open_ai::Model::iter() {
@@ -80,7 +82,10 @@ impl LanguageModelProvider for OpenAiLanguageModelProvider {
}
// Override with available models from settings
- for model in &self.state.read(cx).settings.available_models {
+ for model in &AllLanguageModelSettings::get_global(cx)
+ .openai
+ .available_models
+ {
models.insert(model.id().to_string(), model.clone());
}
@@ -105,7 +110,10 @@ impl LanguageModelProvider for OpenAiLanguageModelProvider {
if self.is_authenticated(cx) {
Task::ready(Ok(()))
} else {
- let api_url = self.state.read(cx).settings.api_url.clone();
+ let api_url = AllLanguageModelSettings::get_global(cx)
+ .openai
+ .api_url
+ .clone();
let state = self.state.clone();
cx.spawn(|mut cx| async move {
let api_key = if let Ok(api_key) = std::env::var("OPENAI_API_KEY") {
@@ -131,7 +139,8 @@ impl LanguageModelProvider for OpenAiLanguageModelProvider {
}
fn reset_credentials(&self, cx: &AppContext) -> Task<Result<()>> {
- let delete_credentials = cx.delete_credentials(&self.state.read(cx).settings.api_url);
+ let settings = &AllLanguageModelSettings::get_global(cx).openai;
+ let delete_credentials = cx.delete_credentials(&settings.api_url);
let state = self.state.clone();
cx.spawn(|mut cx| async move {
delete_credentials.await.log_err();
@@ -188,6 +197,10 @@ impl LanguageModel for OpenAiLanguageModel {
LanguageModelName::from(self.model.display_name().to_string())
}
+ fn provider_id(&self) -> LanguageModelProviderId {
+ LanguageModelProviderId(PROVIDER_ID.into())
+ }
+
fn provider_name(&self) -> LanguageModelProviderName {
LanguageModelProviderName(PROVIDER_NAME.into())
}
@@ -216,11 +229,12 @@ impl LanguageModel for OpenAiLanguageModel {
let request = self.to_open_ai_request(request);
let http_client = self.http_client.clone();
- let Ok((api_key, api_url, low_speed_timeout)) = cx.read_model(&self.state, |state, _| {
+ let Ok((api_key, api_url, low_speed_timeout)) = cx.read_model(&self.state, |state, cx| {
+ let settings = &AllLanguageModelSettings::get_global(cx).openai;
(
state.api_key.clone(),
- state.settings.api_url.clone(),
- state.settings.low_speed_timeout,
+ settings.api_url.clone(),
+ settings.low_speed_timeout,
)
}) else {
return futures::future::ready(Err(anyhow!("App state dropped"))).boxed();
@@ -307,11 +321,9 @@ impl AuthenticationPrompt {
return;
}
- let write_credentials = cx.write_credentials(
- &self.state.read(cx).settings.api_url,
- "Bearer",
- api_key.as_bytes(),
- );
+ let settings = &AllLanguageModelSettings::get_global(cx).openai;
+ let write_credentials =
+ cx.write_credentials(&settings.api_url, "Bearer", api_key.as_bytes());
let state = self.state.clone();
cx.spawn(|_, mut cx| async move {
write_credentials.await?;
@@ -9,7 +9,7 @@ use crate::{
anthropic::AnthropicLanguageModelProvider, cloud::CloudLanguageModelProvider,
ollama::OllamaLanguageModelProvider, open_ai::OpenAiLanguageModelProvider,
},
- LanguageModel, LanguageModelProvider, LanguageModelProviderName, LanguageModelProviderState,
+ LanguageModel, LanguageModelProvider, LanguageModelProviderId, LanguageModelProviderState,
};
pub fn init(client: Arc<Client>, cx: &mut AppContext) {
@@ -48,7 +48,7 @@ fn register_language_model_providers(
registry.register_provider(CloudLanguageModelProvider::new(client.clone(), cx), cx);
} else {
registry.unregister_provider(
- &LanguageModelProviderName::from(
+ &LanguageModelProviderId::from(
-                    crate::provider::cloud::PROVIDER_NAME.to_string(),
+                    crate::provider::cloud::PROVIDER_ID.to_string(),
),
cx,
@@ -65,7 +65,7 @@ impl Global for GlobalLanguageModelRegistry {}
#[derive(Default)]
pub struct LanguageModelRegistry {
- providers: HashMap<LanguageModelProviderName, Arc<dyn LanguageModelProvider>>,
+ providers: HashMap<LanguageModelProviderId, Arc<dyn LanguageModelProvider>>,
}
impl LanguageModelRegistry {
@@ -94,7 +94,7 @@ impl LanguageModelRegistry {
provider: T,
cx: &mut ModelContext<Self>,
) {
- let name = provider.name();
+ let name = provider.id();
if let Some(subscription) = provider.subscribe(cx) {
subscription.detach();
@@ -106,7 +106,7 @@ impl LanguageModelRegistry {
pub fn unregister_provider(
&mut self,
- name: &LanguageModelProviderName,
+ name: &LanguageModelProviderId,
cx: &mut ModelContext<Self>,
) {
if self.providers.remove(name).is_some() {
@@ -116,7 +116,7 @@ impl LanguageModelRegistry {
pub fn providers(
&self,
- ) -> impl Iterator<Item = (&LanguageModelProviderName, &Arc<dyn LanguageModelProvider>)> {
+ ) -> impl Iterator<Item = (&LanguageModelProviderId, &Arc<dyn LanguageModelProvider>)> {
self.providers.iter()
}
@@ -130,7 +130,7 @@ impl LanguageModelRegistry {
pub fn available_models_grouped_by_provider(
&self,
cx: &AppContext,
- ) -> HashMap<LanguageModelProviderName, Vec<Arc<dyn LanguageModel>>> {
+ ) -> HashMap<LanguageModelProviderId, Vec<Arc<dyn LanguageModel>>> {
self.providers
.iter()
.map(|(name, provider)| (name.clone(), provider.provided_models(cx)))
@@ -139,7 +139,7 @@ impl LanguageModelRegistry {
pub fn provider(
&self,
- name: &LanguageModelProviderName,
+ name: &LanguageModelProviderId,
) -> Option<Arc<dyn LanguageModelProvider>> {
self.providers.get(name).cloned()
}
@@ -160,10 +160,10 @@ mod tests {
let providers = registry.read(cx).providers().collect::<Vec<_>>();
assert_eq!(providers.len(), 1);
- assert_eq!(providers[0].0, &crate::provider::fake::provider_name());
+ assert_eq!(providers[0].0, &crate::provider::fake::provider_id());
registry.update(cx, |registry, cx| {
- registry.unregister_provider(&crate::provider::fake::provider_name(), cx);
+ registry.unregister_provider(&crate::provider::fake::provider_id(), cx);
});
let providers = registry.read(cx).providers().collect::<Vec<_>>();
@@ -21,9 +21,9 @@ pub fn init(cx: &mut AppContext) {
#[derive(Default)]
pub struct AllLanguageModelSettings {
- pub open_ai: OpenAiSettings,
pub anthropic: AnthropicSettings,
pub ollama: OllamaSettings,
+ pub openai: OpenAiSettings,
pub zed_dot_dev: ZedDotDevSettings,
}
@@ -31,7 +31,7 @@ pub struct AllLanguageModelSettings {
pub struct AllLanguageModelSettingsContent {
pub anthropic: Option<AnthropicSettingsContent>,
pub ollama: Option<OllamaSettingsContent>,
- pub open_ai: Option<OpenAiSettingsContent>,
+ pub openai: Option<OpenAiSettingsContent>,
#[serde(rename = "zed.dev")]
pub zed_dot_dev: Option<ZedDotDevSettingsContent>,
}
@@ -110,21 +110,21 @@ impl settings::Settings for AllLanguageModelSettings {
}
merge(
- &mut settings.open_ai.api_url,
- value.open_ai.as_ref().and_then(|s| s.api_url.clone()),
+ &mut settings.openai.api_url,
+ value.openai.as_ref().and_then(|s| s.api_url.clone()),
);
if let Some(low_speed_timeout_in_seconds) = value
- .open_ai
+ .openai
.as_ref()
.and_then(|s| s.low_speed_timeout_in_seconds)
{
- settings.open_ai.low_speed_timeout =
+ settings.openai.low_speed_timeout =
Some(Duration::from_secs(low_speed_timeout_in_seconds));
}
merge(
- &mut settings.open_ai.available_models,
+ &mut settings.openai.available_models,
value
- .open_ai
+ .openai
.as_ref()
.and_then(|s| s.available_models.clone()),
);
@@ -4,7 +4,7 @@ use http_client::{AsyncBody, HttpClient, Method, Request as HttpRequest};
use isahc::config::Configurable;
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
-use std::{convert::TryFrom, time::Duration};
+use std::{convert::TryFrom, sync::Arc, time::Duration};
pub const OLLAMA_API_URL: &str = "http://localhost:11434";
@@ -243,7 +243,7 @@ pub async fn get_models(
}
/// Sends an empty request to Ollama to trigger loading the model
-pub async fn preload_model(client: &dyn HttpClient, api_url: &str, model: &str) -> Result<()> {
+pub async fn preload_model(client: Arc<dyn HttpClient>, api_url: &str, model: &str) -> Result<()> {
let uri = format!("{api_url}/api/generate");
let request = HttpRequest::builder()
.method(Method::POST)
@@ -85,12 +85,8 @@ To do so, add the following to your Zed `settings.json`:
```json
{
- "assistant": {
- "version": "1",
- "provider": {
- "name": "openai",
- "type": "openai",
- "default_model": "gpt-4-turbo-preview",
+ "language_models": {
+ "openai": {
"api_url": "http://localhost:11434/v1"
}
}
@@ -103,51 +99,32 @@ The custom URL here is `http://localhost:11434/v1`.
You can use Ollama with the Zed assistant by making Ollama appear as an OpenAPI endpoint.
-1. Add the following to your Zed `settings.json`:
-
- ```json
- {
- "assistant": {
- "version": "1",
- "provider": {
- "name": "openai",
- "type": "openai",
- "default_model": "gpt-4-turbo-preview",
- "api_url": "http://localhost:11434/v1"
- }
- }
- }
+1. Download, for example, the `mistral` model with Ollama:
```
-2. Download, for example, the `mistral` model with Ollama:
+ ollama pull mistral
```
- ollama run mistral
+2. Make sure that the Ollama server is running. You can start it either via running the Ollama app, or launching:
```
-3. Copy the model and change its name to match the model in the Zed `settings.json`:
+ ollama serve
```
- ollama cp mistral gpt-4-turbo-preview
- ```
-4. Use `assistant: reset key` (see the [Setup](#setup) section above) and enter the following API key:
- ```
- ollama
- ```
-5. Restart Zed
-
-### Using Claude 3.5 Sonnet
-
-You can use Claude with the Zed assistant by adding the following settings:
+3. In the assistant panel, select one of the Ollama models using the model dropdown.
+4. (Optional) If you want to change the default url that is used to access the Ollama server, you can do so by adding the following settings:
```json
-"assistant": {
- "version": "1",
- "provider": {
- "default_model": "claude-3-5-sonnet",
- "name": "anthropic"
+{
+ "language_models": {
+ "ollama": {
+ "api_url": "http://localhost:11434"
+ }
}
-},
+}
```
-When you save the settings, the assistant panel will open and ask you to add your Anthropic API key.
-You need can obtain this key [here](https://console.anthropic.com/settings/keys).
+### Using Claude 3.5 Sonnet
+
+You can use Claude with the Zed assistant by choosing it via the model dropdown in the assistant panel.
+
+You can obtain an API key [here](https://console.anthropic.com/settings/keys).
Even if you pay for Claude Pro, you will still have to [pay for additional credits](https://console.anthropic.com/settings/plans) to use it via the API.