From ec487d8f649603e040fec4df2764585f9425532f Mon Sep 17 00:00:00 2001
From: Richard Feldman
Date: Fri, 19 Jul 2024 13:35:34 -0400
Subject: [PATCH] Extract completion provider crate (#14823)

We will soon need `semantic_index` to be able to use `CompletionProvider`. This is currently impossible due to a cyclic crate dependency: `CompletionProvider` lives in the `assistant` crate, which depends on `semantic_index`.

This PR breaks the dependency cycle by extracting two crates out of `assistant`: `language_model` and `completion`.

Only one piece of logic changed: [this code](https://github.com/zed-industries/zed/commit/922fcaf5a6076e56890373035b1065b13512546d#diff-3857b3707687a4d585f1200eec4c34a7a079eae8d303b4ce5b4fce46234ace9fR61-R69).

* As of https://github.com/zed-industries/zed/pull/13276, whenever an OpenAI completion provider was asked for its available models, it would consult the global assistant settings to check whether the user had configured an `available_models` setting, and if so, return that.
* This PR changes it so that, instead of eagerly asking the assistant settings for this info (the new crate must not depend on `assistant`, or else the dependency cycle would be back), OpenAI completion providers now store the user-configured models as part of their struct, and whenever the settings change, we update the provider.

In theory, this change should not affect user-visible behavior...but since it's the only change in this large PR that's more than just moving code around, I'm mentioning it here in case there's an unexpected regression in practice! (cc @amtoaer in case you'd like to try out this branch and verify that the feature is still working the way you expect.)
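To make the mechanics concrete, here is a minimal sketch of the new arrangement, using stand-in types (plain `String` models, no gpui/settings plumbing); only the names `available_models_from_settings`, `settings_version`, and `update` are taken from the actual code below, and the method bodies are simplified:

```rust
// Sketch of the new ownership model: the provider caches the user-configured
// model list instead of reading the global assistant settings on demand.
struct OpenAiCompletionProvider {
    model: String,
    available_models_from_settings: Vec<String>,
    settings_version: usize,
}

impl OpenAiCompletionProvider {
    fn available_models(&self) -> Vec<String> {
        if self.available_models_from_settings.is_empty() {
            // No user-configured list: fall back to the provider's own model.
            vec![self.model.clone()]
        } else {
            self.available_models_from_settings.clone()
        }
    }

    // Invoked from a settings observer each time the user edits their
    // settings, so the cached list never goes stale.
    fn update(&mut self, available_models: Vec<String>, version: usize) {
        self.available_models_from_settings = available_models;
        self.settings_version = version;
    }
}

fn main() {
    let mut provider = OpenAiCompletionProvider {
        model: "gpt-4o".into(),
        available_models_from_settings: Vec::new(),
        settings_version: 0,
    };
    assert_eq!(provider.available_models(), vec!["gpt-4o".to_string()]);

    // Simulate a settings change that configures an explicit model list.
    provider.update(vec!["gpt-4o-mini".into()], 1);
    assert_eq!(provider.available_models(), vec!["gpt-4o-mini".to_string()]);
}
```

(The real `available_models` in `crates/completion/src/open_ai.rs` below additionally enumerates the built-in OpenAI models when no override is configured.)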
Release Notes:

- N/A

---------

Co-authored-by: Marshall Bowers
---
 Cargo.lock                                    |  64 +++-
 Cargo.toml                                    |   4 +
 crates/assistant/Cargo.toml                   |   5 +-
 crates/assistant/src/assistant.rs             | 191 ++----------
 crates/assistant/src/assistant_panel.rs       |  13 +-
 crates/assistant/src/assistant_settings.rs    | 293 ++++++++----------
 crates/assistant/src/context.rs               |  11 +-
 crates/assistant/src/inline_assistant.rs      |   9 +-
 crates/assistant/src/model_selector.rs        |   2 +-
 crates/assistant/src/prompt_library.rs        |   3 +-
 .../src/terminal_inline_assistant.rs          |   6 +-
 crates/collab/Cargo.toml                      |   2 +
 crates/collab/src/tests/test_server.rs        |   2 +-
 crates/completion/Cargo.toml                  |  56 ++++
 crates/completion/LICENSE-GPL                 |   1 +
 .../src}/anthropic.rs                         |  57 +---
 .../src}/cloud.rs                             |   7 +-
 .../src/completion.rs}                        | 169 ++--------
 .../src}/fake.rs                              |   2 +-
 .../src}/ollama.rs                            |  19 +-
 .../src}/open_ai.rs                           |  63 ++--
 crates/language_model/Cargo.toml              |  41 +++
 crates/language_model/LICENSE-GPL             |   1 +
 crates/language_model/src/language_model.rs   |   7 +
 .../language_model/src/model/cloud_model.rs   | 160 ++++++++++
 crates/language_model/src/model/mod.rs        |  60 ++++
 crates/language_model/src/request.rs          | 110 +++++++
 crates/language_model/src/role.rs             |  68 ++++
 crates/semantic_index/Cargo.toml              |   1 +
 crates/semantic_index/src/semantic_index.rs   |   3 +
 30 files changed, 820 insertions(+), 610 deletions(-)
 create mode 100644 crates/completion/Cargo.toml
 create mode 120000 crates/completion/LICENSE-GPL
 rename crates/{assistant/src/completion_provider => completion/src}/anthropic.rs (86%)
 rename crates/{assistant/src/completion_provider => completion/src}/cloud.rs (96%)
 rename crates/{assistant/src/completion_provider.rs => completion/src/completion.rs} (57%)
 rename crates/{assistant/src/completion_provider => completion/src}/fake.rs (97%)
 rename crates/{assistant/src/completion_provider => completion/src}/ollama.rs (96%)
 rename crates/{assistant/src/completion_provider => completion/src}/open_ai.rs (89%)
 create mode 100644 crates/language_model/Cargo.toml
 create mode 120000 crates/language_model/LICENSE-GPL
 create mode 100644 crates/language_model/src/language_model.rs
 create mode 100644 crates/language_model/src/model/cloud_model.rs
 create mode 100644 crates/language_model/src/model/mod.rs
 create mode 100644 crates/language_model/src/request.rs
 create mode 100644 crates/language_model/src/role.rs

diff --git a/Cargo.lock b/Cargo.lock
index c734a89f45d7d30db324816c4bc7d0a632f6f4fb..40f9c5922850bf415844f6f44d8ffe04f00c83e6 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -382,6 +382,7 @@ dependencies = [
  "clock",
  "collections",
  "command_palette_hooks",
+ "completion",
  "ctor",
  "editor",
  "env_logger",
@@ -396,6 +397,7 @@ dependencies = [
  "indexed_docs",
  "indoc",
  "language",
+ "language_model",
  "log",
  "menu",
  "multi_buffer",
@@ -418,13 +420,11 @@ dependencies = [
  "settings",
  "similar",
  "smol",
- "strum",
  "telemetry_events",
  "terminal",
  "terminal_view",
  "text",
  "theme",
- "tiktoken-rs",
  "toml 0.8.10",
  "ui",
  "unindent",
@@ -2491,6 +2491,7 @@ dependencies = [
  "clock",
  "collab_ui",
  "collections",
+ "completion",
  "ctor",
  "dashmap",
  "dev_server_projects",
@@ -2673,6 +2674,42 @@ dependencies = [
  "gpui",
 ]

+[[package]]
+name = "completion"
+version = "0.1.0"
+dependencies = [
+ "anthropic",
+ "anyhow",
+ "client",
+ "collections",
+ "ctor",
+ "editor",
+ "env_logger",
+ "futures 0.3.28",
+ "gpui",
+ "http 0.1.0",
+ "language",
+ "language_model",
+ "log",
+ "menu",
+ "ollama",
+ "open_ai",
+ "parking_lot",
+ "project",
+ "rand 0.8.5",
+ "serde",
+ "serde_json",
+ "settings",
+ "smol",
+ "strum",
+ "text",
+ "theme",
+ "tiktoken-rs",
+ "ui",
+ "unindent",
+ "util",
+]
+
 [[package]]
 name = "concurrent-queue"
 version = "2.2.0"
@@ -5996,6 +6033,28 @@ dependencies = [
  "util",
 ]

+[[package]]
+name = "language_model"
+version = "0.1.0"
+dependencies = [
+ "anthropic",
+ "ctor",
+ "editor",
+ "env_logger",
+ "language",
+ "log",
+ "ollama",
+ "open_ai",
+ "project",
+ "proto",
+ "rand 0.8.5",
+ "schemars",
+ "serde",
+ "strum",
+ "text",
+ "unindent",
+]
+
 [[package]]
 name = "language_selector"
 version = "0.1.0"
@@ -9510,6 +9569,7 @@ dependencies = [
  "client",
  "clock",
  "collections",
+ "completion",
  "env_logger",
  "fs",
  "futures 0.3.28",
diff --git a/Cargo.toml b/Cargo.toml
index 2f607134c450bcef7525588768cc6041994ea9fe..1df8affd08737532ef5725d56b01d9be4351f07b 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -19,6 +19,7 @@ members = [
     "crates/collections",
     "crates/command_palette",
     "crates/command_palette_hooks",
+    "crates/completion",
     "crates/copilot",
     "crates/db",
     "crates/dev_server_projects",
@@ -50,6 +51,7 @@ members = [
     "crates/install_cli",
     "crates/journal",
     "crates/language",
+    "crates/language_model",
     "crates/language_selector",
     "crates/language_tools",
     "crates/languages",
@@ -176,6 +178,7 @@ collab_ui = { path = "crates/collab_ui" }
 collections = { path = "crates/collections" }
 command_palette = { path = "crates/command_palette" }
 command_palette_hooks = { path = "crates/command_palette_hooks" }
+completion = { path = "crates/completion" }
 copilot = { path = "crates/copilot" }
 db = { path = "crates/db" }
 dev_server_projects = { path = "crates/dev_server_projects" }
@@ -205,6 +208,7 @@ inline_completion_button = { path = "crates/inline_completion_button" }
 install_cli = { path = "crates/install_cli" }
 journal = { path = "crates/journal" }
 language = { path = "crates/language" }
+language_model = {
path = "crates/language_model" } language_selector = { path = "crates/language_selector" } language_tools = { path = "crates/language_tools" } languages = { path = "crates/languages" } diff --git a/crates/assistant/Cargo.toml b/crates/assistant/Cargo.toml index e3ddd4e2c7454c7bff5c0388a6071323cedaf154..201e16bd57eef0fcbecfd1bc41bb27ad26819123 100644 --- a/crates/assistant/Cargo.toml +++ b/crates/assistant/Cargo.toml @@ -33,6 +33,7 @@ client.workspace = true clock.workspace = true collections.workspace = true command_palette_hooks.workspace = true +completion.workspace = true editor.workspace = true feature_flags.workspace = true fs.workspace = true @@ -45,6 +46,7 @@ http.workspace = true indexed_docs.workspace = true indoc.workspace = true language.workspace = true +language_model.workspace = true log.workspace = true menu.workspace = true multi_buffer.workspace = true @@ -64,12 +66,10 @@ serde_json.workspace = true settings.workspace = true similar.workspace = true smol.workspace = true -strum.workspace = true telemetry_events.workspace = true terminal.workspace = true terminal_view.workspace = true theme.workspace = true -tiktoken-rs.workspace = true toml.workspace = true ui.workspace = true util.workspace = true @@ -79,6 +79,7 @@ picker.workspace = true roxmltree = "0.20.0" [dev-dependencies] +completion = { workspace = true, features = ["test-support"] } ctor.workspace = true editor = { workspace = true, features = ["test-support"] } env_logger.workspace = true diff --git a/crates/assistant/src/assistant.rs b/crates/assistant/src/assistant.rs index cf3726485ff4909080895b710ad5c18d0960e780..0b12cc099cf958a1d842975fed1bd692de438747 100644 --- a/crates/assistant/src/assistant.rs +++ b/crates/assistant/src/assistant.rs @@ -1,6 +1,5 @@ pub mod assistant_panel; pub mod assistant_settings; -mod completion_provider; mod context; pub mod context_store; mod inline_assistant; @@ -12,17 +11,20 @@ mod streaming_diff; mod terminal_inline_assistant; pub use assistant_panel::{AssistantPanel, AssistantPanelEvent}; -use assistant_settings::{AnthropicModel, AssistantSettings, CloudModel, OllamaModel, OpenAiModel}; +use assistant_settings::AssistantSettings; use assistant_slash_command::SlashCommandRegistry; use client::{proto, Client}; use command_palette_hooks::CommandPaletteFilter; -pub use completion_provider::*; +use completion::CompletionProvider; pub use context::*; pub use context_store::*; use fs::Fs; -use gpui::{actions, impl_actions, AppContext, Global, SharedString, UpdateGlobal}; +use gpui::{ + actions, impl_actions, AppContext, BorrowAppContext, Global, SharedString, UpdateGlobal, +}; use indexed_docs::IndexedDocsRegistry; pub(crate) use inline_assistant::*; +use language_model::LanguageModelResponseMessage; pub(crate) use model_selector::*; use semantic_index::{CloudEmbeddingProvider, SemanticIndex}; use serde::{Deserialize, Serialize}; @@ -32,10 +34,7 @@ use slash_command::{ file_command, now_command, project_command, prompt_command, search_command, symbols_command, tabs_command, term_command, }; -use std::{ - fmt::{self, Display}, - sync::Arc, -}; +use std::sync::Arc; pub(crate) use streaming_diff::*; actions!( @@ -73,166 +72,6 @@ impl MessageId { } } -#[derive(Clone, Copy, Serialize, Deserialize, Debug, Eq, PartialEq)] -#[serde(rename_all = "lowercase")] -pub enum Role { - User, - Assistant, - System, -} - -impl Role { - pub fn from_proto(role: i32) -> Role { - match proto::LanguageModelRole::from_i32(role) { - Some(proto::LanguageModelRole::LanguageModelUser) => Role::User, - 
Some(proto::LanguageModelRole::LanguageModelAssistant) => Role::Assistant, - Some(proto::LanguageModelRole::LanguageModelSystem) => Role::System, - Some(proto::LanguageModelRole::LanguageModelTool) => Role::System, - None => Role::User, - } - } - - pub fn to_proto(&self) -> proto::LanguageModelRole { - match self { - Role::User => proto::LanguageModelRole::LanguageModelUser, - Role::Assistant => proto::LanguageModelRole::LanguageModelAssistant, - Role::System => proto::LanguageModelRole::LanguageModelSystem, - } - } - - pub fn cycle(self) -> Role { - match self { - Role::User => Role::Assistant, - Role::Assistant => Role::System, - Role::System => Role::User, - } - } -} - -impl Display for Role { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> std::fmt::Result { - match self { - Role::User => write!(f, "user"), - Role::Assistant => write!(f, "assistant"), - Role::System => write!(f, "system"), - } - } -} - -#[derive(Clone, Debug, Serialize, Deserialize, PartialEq)] -pub enum LanguageModel { - Cloud(CloudModel), - OpenAi(OpenAiModel), - Anthropic(AnthropicModel), - Ollama(OllamaModel), -} - -impl Default for LanguageModel { - fn default() -> Self { - LanguageModel::Cloud(CloudModel::default()) - } -} - -impl LanguageModel { - pub fn telemetry_id(&self) -> String { - match self { - LanguageModel::OpenAi(model) => format!("openai/{}", model.id()), - LanguageModel::Anthropic(model) => format!("anthropic/{}", model.id()), - LanguageModel::Cloud(model) => format!("zed.dev/{}", model.id()), - LanguageModel::Ollama(model) => format!("ollama/{}", model.id()), - } - } - - pub fn display_name(&self) -> String { - match self { - LanguageModel::OpenAi(model) => model.display_name().into(), - LanguageModel::Anthropic(model) => model.display_name().into(), - LanguageModel::Cloud(model) => model.display_name().into(), - LanguageModel::Ollama(model) => model.display_name().into(), - } - } - - pub fn max_token_count(&self) -> usize { - match self { - LanguageModel::OpenAi(model) => model.max_token_count(), - LanguageModel::Anthropic(model) => model.max_token_count(), - LanguageModel::Cloud(model) => model.max_token_count(), - LanguageModel::Ollama(model) => model.max_token_count(), - } - } - - pub fn id(&self) -> &str { - match self { - LanguageModel::OpenAi(model) => model.id(), - LanguageModel::Anthropic(model) => model.id(), - LanguageModel::Cloud(model) => model.id(), - LanguageModel::Ollama(model) => model.id(), - } - } -} - -#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)] -pub struct LanguageModelRequestMessage { - pub role: Role, - pub content: String, -} - -impl LanguageModelRequestMessage { - pub fn to_proto(&self) -> proto::LanguageModelRequestMessage { - proto::LanguageModelRequestMessage { - role: self.role.to_proto() as i32, - content: self.content.clone(), - tool_calls: Vec::new(), - tool_call_id: None, - } - } -} - -#[derive(Debug, Default, Serialize, Deserialize)] -pub struct LanguageModelRequest { - pub model: LanguageModel, - pub messages: Vec, - pub stop: Vec, - pub temperature: f32, -} - -impl LanguageModelRequest { - pub fn to_proto(&self) -> proto::CompleteWithLanguageModel { - proto::CompleteWithLanguageModel { - model: self.model.id().to_string(), - messages: self.messages.iter().map(|m| m.to_proto()).collect(), - stop: self.stop.clone(), - temperature: self.temperature, - tool_choice: None, - tools: Vec::new(), - } - } - - /// Before we send the request to the server, we can perform fixups on it appropriate to the model. 
- pub fn preprocess(&mut self) { - match &self.model { - LanguageModel::OpenAi(_) => {} - LanguageModel::Anthropic(_) => {} - LanguageModel::Ollama(_) => {} - LanguageModel::Cloud(model) => match model { - CloudModel::Claude3Opus - | CloudModel::Claude3Sonnet - | CloudModel::Claude3Haiku - | CloudModel::Claude3_5Sonnet => { - preprocess_anthropic_request(self); - } - _ => {} - }, - } - } -} - -#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)] -pub struct LanguageModelResponseMessage { - pub role: Option, - pub content: Option, -} - #[derive(Deserialize, Debug)] pub struct LanguageModelUsage { pub prompt_tokens: u32, @@ -343,7 +182,7 @@ pub fn init(fs: Arc, client: Arc, cx: &mut AppContext) { context_store::init(&client); prompt_library::init(cx); - completion_provider::init(client.clone(), cx); + init_completion_provider(Arc::clone(&client), cx); assistant_slash_command::init(cx); register_slash_commands(cx); assistant_panel::init(cx); @@ -368,6 +207,20 @@ pub fn init(fs: Arc, client: Arc, cx: &mut AppContext) { .detach(); } +fn init_completion_provider(client: Arc, cx: &mut AppContext) { + let provider = assistant_settings::create_provider_from_settings(client.clone(), 0, cx); + cx.set_global(CompletionProvider::new(provider, Some(client))); + + let mut settings_version = 0; + cx.observe_global::(move |cx| { + settings_version += 1; + cx.update_global::(|provider, cx| { + assistant_settings::update_completion_provider_settings(provider, settings_version, cx); + }) + }) + .detach(); +} + fn register_slash_commands(cx: &mut AppContext) { let slash_command_registry = SlashCommandRegistry::global(cx); slash_command_registry.register_command(file_command::FileSlashCommand, true); diff --git a/crates/assistant/src/assistant_panel.rs b/crates/assistant/src/assistant_panel.rs index 92bd4b9cbe983208013ef6efcd445f7bb9a7ce1d..e02c26837ad8ad88815988a23cb9ebfdaa779c5d 100644 --- a/crates/assistant/src/assistant_panel.rs +++ b/crates/assistant/src/assistant_panel.rs @@ -8,18 +8,18 @@ use crate::{ SlashCommandCompletionProvider, SlashCommandRegistry, }, terminal_inline_assistant::TerminalInlineAssistant, - Assist, CompletionProvider, ConfirmCommand, Context, ContextEvent, ContextId, ContextStore, - CycleMessageRole, DebugEditSteps, DeployHistory, DeployPromptLibrary, EditStep, - EditStepOperations, EditSuggestionGroup, InlineAssist, InlineAssistId, InlineAssistant, - InsertIntoEditor, MessageStatus, ModelSelector, PendingSlashCommand, PendingSlashCommandStatus, - QuoteSelection, RemoteContextMetadata, ResetKey, Role, SavedContextMetadata, Split, - ToggleFocus, ToggleModelSelector, + Assist, ConfirmCommand, Context, ContextEvent, ContextId, ContextStore, CycleMessageRole, + DebugEditSteps, DeployHistory, DeployPromptLibrary, EditStep, EditStepOperations, + EditSuggestionGroup, InlineAssist, InlineAssistId, InlineAssistant, InsertIntoEditor, + MessageStatus, ModelSelector, PendingSlashCommand, PendingSlashCommandStatus, QuoteSelection, + RemoteContextMetadata, ResetKey, SavedContextMetadata, Split, ToggleFocus, ToggleModelSelector, }; use anyhow::{anyhow, Result}; use assistant_slash_command::{SlashCommand, SlashCommandOutputSection}; use breadcrumbs::Breadcrumbs; use client::proto; use collections::{BTreeSet, HashMap, HashSet}; +use completion::CompletionProvider; use editor::{ actions::{FoldAt, MoveToEndOfLine, Newline, ShowCompletions, UnfoldAt}, display_map::{ @@ -43,6 +43,7 @@ use language::{ language_settings::SoftWrap, Buffer, Capability, LanguageRegistry, LspAdapterDelegate, Point, ToOffset, 
}; +use language_model::Role; use multi_buffer::MultiBufferRow; use picker::{Picker, PickerDelegate}; use project::{Project, ProjectLspAdapterDelegate}; diff --git a/crates/assistant/src/assistant_settings.rs b/crates/assistant/src/assistant_settings.rs index d341973326d6230161140c37a618d75000bcbd07..7fca691e7a244ad5824c99aeee159ccce94f7fe6 100644 --- a/crates/assistant/src/assistant_settings.rs +++ b/crates/assistant/src/assistant_settings.rs @@ -1,166 +1,19 @@ -use std::fmt; - -use crate::{preprocess_anthropic_request, LanguageModel, LanguageModelRequest}; -pub use anthropic::Model as AnthropicModel; -use gpui::Pixels; -pub use ollama::Model as OllamaModel; -pub use open_ai::Model as OpenAiModel; -use schemars::{ - schema::{InstanceType, Metadata, Schema, SchemaObject}, - JsonSchema, -}; -use serde::{ - de::{self, Visitor}, - Deserialize, Deserializer, Serialize, Serializer, +use std::{sync::Arc, time::Duration}; + +use anthropic::Model as AnthropicModel; +use client::Client; +use completion::{ + AnthropicCompletionProvider, CloudCompletionProvider, CompletionProvider, + LanguageModelCompletionProvider, OllamaCompletionProvider, OpenAiCompletionProvider, }; +use gpui::{AppContext, Pixels}; +use language_model::{CloudModel, LanguageModel}; +use ollama::Model as OllamaModel; +use open_ai::Model as OpenAiModel; +use parking_lot::RwLock; +use schemars::{schema::Schema, JsonSchema}; +use serde::{Deserialize, Serialize}; use settings::{Settings, SettingsSources}; -use strum::{EnumIter, IntoEnumIterator}; - -#[derive(Clone, Debug, Default, PartialEq, EnumIter)] -pub enum CloudModel { - Gpt3Point5Turbo, - Gpt4, - Gpt4Turbo, - #[default] - Gpt4Omni, - Gpt4OmniMini, - Claude3_5Sonnet, - Claude3Opus, - Claude3Sonnet, - Claude3Haiku, - Gemini15Pro, - Gemini15Flash, - Custom(String), -} - -impl Serialize for CloudModel { - fn serialize(&self, serializer: S) -> Result - where - S: Serializer, - { - serializer.serialize_str(self.id()) - } -} - -impl<'de> Deserialize<'de> for CloudModel { - fn deserialize(deserializer: D) -> Result - where - D: Deserializer<'de>, - { - struct ZedDotDevModelVisitor; - - impl<'de> Visitor<'de> for ZedDotDevModelVisitor { - type Value = CloudModel; - - fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result { - formatter.write_str("a string for a ZedDotDevModel variant or a custom model") - } - - fn visit_str(self, value: &str) -> Result - where - E: de::Error, - { - let model = CloudModel::iter() - .find(|model| model.id() == value) - .unwrap_or_else(|| CloudModel::Custom(value.to_string())); - Ok(model) - } - } - - deserializer.deserialize_str(ZedDotDevModelVisitor) - } -} - -impl JsonSchema for CloudModel { - fn schema_name() -> String { - "ZedDotDevModel".to_owned() - } - - fn json_schema(_generator: &mut schemars::gen::SchemaGenerator) -> Schema { - let variants = CloudModel::iter() - .filter_map(|model| { - let id = model.id(); - if id.is_empty() { - None - } else { - Some(id.to_string()) - } - }) - .collect::>(); - Schema::Object(SchemaObject { - instance_type: Some(InstanceType::String.into()), - enum_values: Some(variants.iter().map(|s| s.clone().into()).collect()), - metadata: Some(Box::new(Metadata { - title: Some("ZedDotDevModel".to_owned()), - default: Some(CloudModel::default().id().into()), - examples: variants.into_iter().map(Into::into).collect(), - ..Default::default() - })), - ..Default::default() - }) - } -} - -impl CloudModel { - pub fn id(&self) -> &str { - match self { - Self::Gpt3Point5Turbo => "gpt-3.5-turbo", - Self::Gpt4 => "gpt-4", - 
Self::Gpt4Turbo => "gpt-4-turbo-preview", - Self::Gpt4Omni => "gpt-4o", - Self::Gpt4OmniMini => "gpt-4o-mini", - Self::Claude3_5Sonnet => "claude-3-5-sonnet", - Self::Claude3Opus => "claude-3-opus", - Self::Claude3Sonnet => "claude-3-sonnet", - Self::Claude3Haiku => "claude-3-haiku", - Self::Gemini15Pro => "gemini-1.5-pro", - Self::Gemini15Flash => "gemini-1.5-flash", - Self::Custom(id) => id, - } - } - - pub fn display_name(&self) -> &str { - match self { - Self::Gpt3Point5Turbo => "GPT 3.5 Turbo", - Self::Gpt4 => "GPT 4", - Self::Gpt4Turbo => "GPT 4 Turbo", - Self::Gpt4Omni => "GPT 4 Omni", - Self::Gpt4OmniMini => "GPT 4 Omni Mini", - Self::Claude3_5Sonnet => "Claude 3.5 Sonnet", - Self::Claude3Opus => "Claude 3 Opus", - Self::Claude3Sonnet => "Claude 3 Sonnet", - Self::Claude3Haiku => "Claude 3 Haiku", - Self::Gemini15Pro => "Gemini 1.5 Pro", - Self::Gemini15Flash => "Gemini 1.5 Flash", - Self::Custom(id) => id.as_str(), - } - } - - pub fn max_token_count(&self) -> usize { - match self { - Self::Gpt3Point5Turbo => 2048, - Self::Gpt4 => 4096, - Self::Gpt4Turbo | Self::Gpt4Omni => 128000, - Self::Gpt4OmniMini => 128000, - Self::Claude3_5Sonnet - | Self::Claude3Opus - | Self::Claude3Sonnet - | Self::Claude3Haiku => 200000, - Self::Gemini15Pro => 128000, - Self::Gemini15Flash => 32000, - Self::Custom(_) => 4096, // TODO: Make this configurable - } - } - - pub fn preprocess_request(&self, request: &mut LanguageModelRequest) { - match self { - Self::Claude3Opus | Self::Claude3Sonnet | Self::Claude3Haiku => { - preprocess_anthropic_request(request) - } - _ => {} - } - } -} #[derive(Copy, Clone, Default, Debug, Serialize, Deserialize, JsonSchema)] #[serde(rename_all = "snake_case")] @@ -620,6 +473,124 @@ fn merge(target: &mut T, value: Option) { } } +pub fn update_completion_provider_settings( + provider: &mut CompletionProvider, + version: usize, + cx: &mut AppContext, +) { + let updated = match &AssistantSettings::get_global(cx).provider { + AssistantProvider::ZedDotDev { model } => provider + .update_current_as::<_, CloudCompletionProvider>(|provider| { + provider.update(model.clone(), version); + }), + AssistantProvider::OpenAi { + model, + api_url, + low_speed_timeout_in_seconds, + available_models, + } => provider.update_current_as::<_, OpenAiCompletionProvider>(|provider| { + provider.update( + choose_openai_model(&model, &available_models), + api_url.clone(), + low_speed_timeout_in_seconds.map(Duration::from_secs), + version, + ); + }), + AssistantProvider::Anthropic { + model, + api_url, + low_speed_timeout_in_seconds, + } => provider.update_current_as::<_, AnthropicCompletionProvider>(|provider| { + provider.update( + model.clone(), + api_url.clone(), + low_speed_timeout_in_seconds.map(Duration::from_secs), + version, + ); + }), + AssistantProvider::Ollama { + model, + api_url, + low_speed_timeout_in_seconds, + } => provider.update_current_as::<_, OllamaCompletionProvider>(|provider| { + provider.update( + model.clone(), + api_url.clone(), + low_speed_timeout_in_seconds.map(Duration::from_secs), + version, + cx, + ); + }), + }; + + // Previously configured provider was changed to another one + if updated.is_none() { + provider.update_provider(|client| create_provider_from_settings(client, version, cx)); + } +} + +pub(crate) fn create_provider_from_settings( + client: Arc, + settings_version: usize, + cx: &mut AppContext, +) -> Arc> { + match &AssistantSettings::get_global(cx).provider { + AssistantProvider::ZedDotDev { model } => Arc::new(RwLock::new( + 
CloudCompletionProvider::new(model.clone(), client.clone(), settings_version, cx), + )), + AssistantProvider::OpenAi { + model, + api_url, + low_speed_timeout_in_seconds, + available_models, + } => Arc::new(RwLock::new(OpenAiCompletionProvider::new( + choose_openai_model(&model, &available_models), + api_url.clone(), + client.http_client(), + low_speed_timeout_in_seconds.map(Duration::from_secs), + settings_version, + available_models.clone(), + ))), + AssistantProvider::Anthropic { + model, + api_url, + low_speed_timeout_in_seconds, + } => Arc::new(RwLock::new(AnthropicCompletionProvider::new( + model.clone(), + api_url.clone(), + client.http_client(), + low_speed_timeout_in_seconds.map(Duration::from_secs), + settings_version, + ))), + AssistantProvider::Ollama { + model, + api_url, + low_speed_timeout_in_seconds, + } => Arc::new(RwLock::new(OllamaCompletionProvider::new( + model.clone(), + api_url.clone(), + client.http_client(), + low_speed_timeout_in_seconds.map(Duration::from_secs), + settings_version, + cx, + ))), + } +} + +/// Choose which model to use for openai provider. +/// If the model is not available, try to use the first available model, or fallback to the original model. +fn choose_openai_model( + model: &::open_ai::Model, + available_models: &[::open_ai::Model], +) -> ::open_ai::Model { + available_models + .iter() + .find(|&m| m == model) + .or_else(|| available_models.first()) + .unwrap_or_else(|| model) + .clone() +} + #[cfg(test)] mod tests { use gpui::{AppContext, UpdateGlobal}; diff --git a/crates/assistant/src/context.rs b/crates/assistant/src/context.rs index 25f24753a1bd75e49d8358469efb258b4e1d37db..f75b693bbd057b523b9f83f91807b6995d53d516 100644 --- a/crates/assistant/src/context.rs +++ b/crates/assistant/src/context.rs @@ -1,12 +1,12 @@ use crate::{ - prompt_library::PromptStore, slash_command::SlashCommandLine, CompletionProvider, - LanguageModelRequest, LanguageModelRequestMessage, MessageId, MessageStatus, Role, + prompt_library::PromptStore, slash_command::SlashCommandLine, CompletionProvider, MessageId, + MessageStatus, }; use anyhow::{anyhow, Context as _, Result}; use assistant_slash_command::{ SlashCommandOutput, SlashCommandOutputSection, SlashCommandRegistry, }; -use client::{proto, telemetry::Telemetry}; +use client::{self, proto, telemetry::Telemetry}; use clock::ReplicaId; use collections::{HashMap, HashSet}; use fs::Fs; @@ -18,6 +18,8 @@ use gpui::{AppContext, Context as _, EventEmitter, Model, ModelContext, Subscrip use language::{ AnchorRangeExt, Bias, Buffer, LanguageRegistry, OffsetRangeExt, ParseStatus, Point, ToOffset, }; +use language_model::LanguageModelRequestMessage; +use language_model::{LanguageModelRequest, Role}; use open_ai::Model as OpenAiModel; use paths::contexts_dir; use project::Project; @@ -2477,9 +2479,10 @@ mod tests { use crate::{ assistant_panel, prompt_library, slash_command::{active_command, file_command}, - FakeCompletionProvider, MessageId, + MessageId, }; use assistant_slash_command::{ArgumentCompletion, SlashCommand}; + use completion::FakeCompletionProvider; use fs::FakeFs; use gpui::{AppContext, TestAppContext, WeakView}; use indoc::indoc; diff --git a/crates/assistant/src/inline_assistant.rs b/crates/assistant/src/inline_assistant.rs index be14e271e86f0fb33b1d874e0a926175b7a1aa06..b8dbcacd2b4a00f9cc9188f4d60976ebf9c7b0b0 100644 --- a/crates/assistant/src/inline_assistant.rs +++ b/crates/assistant/src/inline_assistant.rs @@ -1,7 +1,6 @@ use crate::{ assistant_settings::AssistantSettings, humanize_token_count, 
prompts::generate_content_prompt, - AssistantPanel, AssistantPanelEvent, CompletionProvider, Hunk, LanguageModelRequest, - LanguageModelRequestMessage, Role, StreamingDiff, + AssistantPanel, AssistantPanelEvent, CompletionProvider, Hunk, StreamingDiff, }; use anyhow::{anyhow, Context as _, Result}; use client::telemetry::Telemetry; @@ -28,6 +27,7 @@ use gpui::{ WhiteSpace, WindowContext, }; use language::{Buffer, Point, Selection, TransactionId}; +use language_model::{LanguageModelRequest, LanguageModelRequestMessage, Role}; use multi_buffer::MultiBufferRow; use parking_lot::Mutex; use rope::Rope; @@ -1432,8 +1432,7 @@ impl Render for PromptEditor { PopoverMenu::new("model-switcher") .menu(move |cx| { ContextMenu::build(cx, |mut menu, cx| { - for model in CompletionProvider::global(cx).available_models(cx) - { + for model in CompletionProvider::global(cx).available_models() { menu = menu.custom_entry( { let model = model.clone(); @@ -2606,7 +2605,7 @@ fn merge_ranges(ranges: &mut Vec>, buffer: &MultiBufferSnapshot) { #[cfg(test)] mod tests { use super::*; - use crate::FakeCompletionProvider; + use completion::FakeCompletionProvider; use futures::stream::{self}; use gpui::{Context, TestAppContext}; use indoc::indoc; diff --git a/crates/assistant/src/model_selector.rs b/crates/assistant/src/model_selector.rs index a27b2b55655aa904d941d2143efa1c8743234d21..6cd50a59dabb107c3ed343cefc996060dd18dd51 100644 --- a/crates/assistant/src/model_selector.rs +++ b/crates/assistant/src/model_selector.rs @@ -23,7 +23,7 @@ impl RenderOnce for ModelSelector { .with_handle(self.handle) .menu(move |cx| { ContextMenu::build(cx, |mut menu, cx| { - for model in CompletionProvider::global(cx).available_models(cx) { + for model in CompletionProvider::global(cx).available_models() { menu = menu.custom_entry( { let model = model.clone(); diff --git a/crates/assistant/src/prompt_library.rs b/crates/assistant/src/prompt_library.rs index 9d782aedc78cdd3dcb3da88da313205976e0b3f9..a59f4e3c0f907db6af6bf6a043bcb0ecf50784ac 100644 --- a/crates/assistant/src/prompt_library.rs +++ b/crates/assistant/src/prompt_library.rs @@ -1,6 +1,6 @@ use crate::{ slash_command::SlashCommandCompletionProvider, AssistantPanel, CompletionProvider, - InlineAssist, InlineAssistant, LanguageModelRequest, LanguageModelRequestMessage, Role, + InlineAssist, InlineAssistant, }; use anyhow::{anyhow, Result}; use assets::Assets; @@ -19,6 +19,7 @@ use gpui::{ }; use heed::{types::SerdeBincode, Database, RoTxn}; use language::{language_settings::SoftWrap, Buffer, LanguageRegistry}; +use language_model::{LanguageModelRequest, LanguageModelRequestMessage, Role}; use parking_lot::RwLock; use picker::{Picker, PickerDelegate}; use rope::Rope; diff --git a/crates/assistant/src/terminal_inline_assistant.rs b/crates/assistant/src/terminal_inline_assistant.rs index 8f2cd63bac2798d9bf318a968af9378e6503853d..192db0cf5e13488c8124e35863dff57a414ee8bd 100644 --- a/crates/assistant/src/terminal_inline_assistant.rs +++ b/crates/assistant/src/terminal_inline_assistant.rs @@ -1,7 +1,7 @@ use crate::{ assistant_settings::AssistantSettings, humanize_token_count, prompts::generate_terminal_assistant_prompt, AssistantPanel, AssistantPanelEvent, - CompletionProvider, LanguageModelRequest, LanguageModelRequestMessage, Role, + CompletionProvider, }; use anyhow::{Context as _, Result}; use client::telemetry::Telemetry; @@ -17,6 +17,7 @@ use gpui::{ Model, ModelContext, Subscription, Task, TextStyle, UpdateGlobal, View, WeakView, WhiteSpace, }; use language::Buffer; +use 
language_model::{LanguageModelRequest, LanguageModelRequestMessage, Role}; use settings::{update_settings_file, Settings}; use std::{ cmp, @@ -558,8 +559,7 @@ impl Render for PromptEditor { PopoverMenu::new("model-switcher") .menu(move |cx| { ContextMenu::build(cx, |mut menu, cx| { - for model in CompletionProvider::global(cx).available_models(cx) - { + for model in CompletionProvider::global(cx).available_models() { menu = menu.custom_entry( { let model = model.clone(); diff --git a/crates/collab/Cargo.toml b/crates/collab/Cargo.toml index a413a464895e79ad9ef229c6bb47c98dd53458dc..cf99e7c90cff1b0ae9cf2affc5ae8d493c11d6f7 100644 --- a/crates/collab/Cargo.toml +++ b/crates/collab/Cargo.toml @@ -30,6 +30,7 @@ chrono.workspace = true clock.workspace = true clickhouse.workspace = true collections.workspace = true +completion.workspace = true dashmap = "5.4" envy = "0.4.2" futures.workspace = true @@ -79,6 +80,7 @@ channel.workspace = true client = { workspace = true, features = ["test-support"] } collab_ui = { workspace = true, features = ["test-support"] } collections = { workspace = true, features = ["test-support"] } +completion = { workspace = true, features = ["test-support"] } ctor.workspace = true editor = { workspace = true, features = ["test-support"] } env_logger.workspace = true diff --git a/crates/collab/src/tests/test_server.rs b/crates/collab/src/tests/test_server.rs index a3eafd0f9467992ed6e51a04e5218f70079c35c7..61c0a8239de0a45443ebcdcf283e2be4e4c0fff5 100644 --- a/crates/collab/src/tests/test_server.rs +++ b/crates/collab/src/tests/test_server.rs @@ -295,7 +295,7 @@ impl TestServer { menu::init(); dev_server_projects::init(client.clone(), cx); settings::KeymapFile::load_asset(os_keymap, cx).unwrap(); - assistant::FakeCompletionProvider::setup_test(cx); + completion::FakeCompletionProvider::setup_test(cx); assistant::context_store::init(&client); }); diff --git a/crates/completion/Cargo.toml b/crates/completion/Cargo.toml new file mode 100644 index 0000000000000000000000000000000000000000..18181e7bb5d3df89891223baa25c9fe997601641 --- /dev/null +++ b/crates/completion/Cargo.toml @@ -0,0 +1,56 @@ +[package] +name = "completion" +version = "0.1.0" +edition = "2021" +publish = false +license = "GPL-3.0-or-later" + +[lints] +workspace = true + +[lib] +path = "src/completion.rs" +doctest = false + +[features] +test-support = [ + "editor/test-support", + "language/test-support", + "project/test-support", + "text/test-support", +] + +[dependencies] +anthropic = { workspace = true, features = ["schemars"] } +anyhow.workspace = true +client.workspace = true +collections.workspace = true +editor.workspace = true +futures.workspace = true +gpui.workspace = true +http.workspace = true +language_model.workspace = true +log.workspace = true +menu.workspace = true +ollama = { workspace = true, features = ["schemars"] } +open_ai = { workspace = true, features = ["schemars"] } +parking_lot.workspace = true +serde.workspace = true +serde_json.workspace = true +settings.workspace = true +smol.workspace = true +strum.workspace = true +theme.workspace = true +tiktoken-rs.workspace = true +ui.workspace = true +util.workspace = true + +[dev-dependencies] +ctor.workspace = true +editor = { workspace = true, features = ["test-support"] } +env_logger.workspace = true +language = { workspace = true, features = ["test-support"] } +project = { workspace = true, features = ["test-support"] } +rand.workspace = true +text = { workspace = true, features = ["test-support"] } +unindent.workspace = true diff --git 
a/crates/completion/LICENSE-GPL b/crates/completion/LICENSE-GPL new file mode 120000 index 0000000000000000000000000000000000000000..89e542f750cd3860a0598eff0dc34b56d7336dc4 --- /dev/null +++ b/crates/completion/LICENSE-GPL @@ -0,0 +1 @@ +../../LICENSE-GPL \ No newline at end of file diff --git a/crates/assistant/src/completion_provider/anthropic.rs b/crates/completion/src/anthropic.rs similarity index 86% rename from crates/assistant/src/completion_provider/anthropic.rs rename to crates/completion/src/anthropic.rs index 48d2020cbee15978aa74859bf42dd6020a29cbf1..dc71ebd8ca47fb7b3033ae69df3d59764bbefaca 100644 --- a/crates/assistant/src/completion_provider/anthropic.rs +++ b/crates/completion/src/anthropic.rs @@ -1,14 +1,12 @@ -use crate::{ - assistant_settings::AnthropicModel, CompletionProvider, LanguageModel, LanguageModelRequest, - Role, -}; -use crate::{count_open_ai_tokens, LanguageModelCompletionProvider, LanguageModelRequestMessage}; -use anthropic::{stream_completion, Request, RequestMessage}; +use crate::{count_open_ai_tokens, LanguageModelCompletionProvider}; +use crate::{CompletionProvider, LanguageModel, LanguageModelRequest}; +use anthropic::{stream_completion, Model as AnthropicModel, Request, RequestMessage}; use anyhow::{anyhow, Result}; use editor::{Editor, EditorElement, EditorStyle}; use futures::{future::BoxFuture, stream::BoxStream, FutureExt, StreamExt}; use gpui::{AnyView, AppContext, FontStyle, Task, TextStyle, View, WhiteSpace}; use http::HttpClient; +use language_model::Role; use settings::Settings; use std::time::Duration; use std::{env, sync::Arc}; @@ -27,7 +25,7 @@ pub struct AnthropicCompletionProvider { } impl LanguageModelCompletionProvider for AnthropicCompletionProvider { - fn available_models(&self, _cx: &AppContext) -> Vec { + fn available_models(&self) -> Vec { AnthropicModel::iter() .map(LanguageModel::Anthropic) .collect() @@ -176,7 +174,7 @@ impl AnthropicCompletionProvider { } fn to_anthropic_request(&self, mut request: LanguageModelRequest) -> Request { - preprocess_anthropic_request(&mut request); + request.preprocess_anthropic(); let model = match request.model { LanguageModel::Anthropic(model) => model, @@ -213,49 +211,6 @@ impl AnthropicCompletionProvider { } } -pub fn preprocess_anthropic_request(request: &mut LanguageModelRequest) { - let mut new_messages: Vec = Vec::new(); - let mut system_message = String::new(); - - for message in request.messages.drain(..) 
{ - if message.content.is_empty() { - continue; - } - - match message.role { - Role::User | Role::Assistant => { - if let Some(last_message) = new_messages.last_mut() { - if last_message.role == message.role { - last_message.content.push_str("\n\n"); - last_message.content.push_str(&message.content); - continue; - } - } - - new_messages.push(message); - } - Role::System => { - if !system_message.is_empty() { - system_message.push_str("\n\n"); - } - system_message.push_str(&message.content); - } - } - } - - if !system_message.is_empty() { - new_messages.insert( - 0, - LanguageModelRequestMessage { - role: Role::System, - content: system_message, - }, - ); - } - - request.messages = new_messages; -} - struct AuthenticationPrompt { api_key: View, api_url: String, diff --git a/crates/assistant/src/completion_provider/cloud.rs b/crates/completion/src/cloud.rs similarity index 96% rename from crates/assistant/src/completion_provider/cloud.rs rename to crates/completion/src/cloud.rs index 32b8587116c17f31b249fc7cd56c90baac150585..f84576aeca10168d4c9a24624cbc591d000d2057 100644 --- a/crates/assistant/src/completion_provider/cloud.rs +++ b/crates/completion/src/cloud.rs @@ -1,11 +1,12 @@ use crate::{ - assistant_settings::CloudModel, count_open_ai_tokens, CompletionProvider, LanguageModel, - LanguageModelCompletionProvider, LanguageModelRequest, + count_open_ai_tokens, CompletionProvider, LanguageModel, LanguageModelCompletionProvider, + LanguageModelRequest, }; use anyhow::{anyhow, Result}; use client::{proto, Client}; use futures::{future::BoxFuture, stream::BoxStream, FutureExt, StreamExt, TryFutureExt}; use gpui::{AnyView, AppContext, Task}; +use language_model::CloudModel; use std::{future, sync::Arc}; use strum::IntoEnumIterator; use ui::prelude::*; @@ -52,7 +53,7 @@ impl CloudCompletionProvider { } impl LanguageModelCompletionProvider for CloudCompletionProvider { - fn available_models(&self, _cx: &AppContext) -> Vec { + fn available_models(&self) -> Vec { let mut custom_model = if let CloudModel::Custom(custom_model) = self.model.clone() { Some(custom_model) } else { diff --git a/crates/assistant/src/completion_provider.rs b/crates/completion/src/completion.rs similarity index 57% rename from crates/assistant/src/completion_provider.rs rename to crates/completion/src/completion.rs index 13f91f70e32b19a2f94466d722b7ef4d37c36075..a219e90b51c75c1235213431a46d30fd29fee8ed 100644 --- a/crates/assistant/src/completion_provider.rs +++ b/crates/completion/src/completion.rs @@ -6,52 +6,19 @@ mod ollama; mod open_ai; pub use anthropic::*; +use anyhow::Result; +use client::Client; pub use cloud::*; #[cfg(any(test, feature = "test-support"))] pub use fake::*; +use futures::{future::BoxFuture, stream::BoxStream, StreamExt}; +use gpui::{AnyView, AppContext, Task, WindowContext}; +use language_model::{LanguageModel, LanguageModelRequest}; pub use ollama::*; pub use open_ai::*; use parking_lot::RwLock; use smol::lock::{Semaphore, SemaphoreGuardArc}; - -use crate::{ - assistant_settings::{AssistantProvider, AssistantSettings}, - LanguageModel, LanguageModelRequest, -}; -use anyhow::Result; -use client::Client; -use futures::{future::BoxFuture, stream::BoxStream, StreamExt}; -use gpui::{AnyView, AppContext, BorrowAppContext, Task, WindowContext}; -use settings::{Settings, SettingsStore}; -use std::{any::Any, pin::Pin, sync::Arc, task::Poll, time::Duration}; - -/// Choose which model to use for openai provider. 
-/// If the model is not available, try to use the first available model, or fallback to the original model. -fn choose_openai_model( - model: &::open_ai::Model, - available_models: &[::open_ai::Model], -) -> ::open_ai::Model { - available_models - .iter() - .find(|&m| m == model) - .or_else(|| available_models.first()) - .unwrap_or_else(|| model) - .clone() -} - -pub fn init(client: Arc, cx: &mut AppContext) { - let provider = create_provider_from_settings(client.clone(), 0, cx); - cx.set_global(CompletionProvider::new(provider, Some(client))); - - let mut settings_version = 0; - cx.observe_global::(move |cx| { - settings_version += 1; - cx.update_global::(|provider, cx| { - provider.update_settings(settings_version, cx); - }) - }) - .detach(); -} +use std::{any::Any, pin::Pin, sync::Arc, task::Poll}; pub struct CompletionResponse { inner: BoxStream<'static, Result>, @@ -70,7 +37,7 @@ impl futures::Stream for CompletionResponse { } pub trait LanguageModelCompletionProvider: Send + Sync { - fn available_models(&self, cx: &AppContext) -> Vec; + fn available_models(&self) -> Vec; fn settings_version(&self) -> usize; fn is_authenticated(&self) -> bool; fn authenticate(&self, cx: &AppContext) -> Task>; @@ -110,8 +77,8 @@ impl CompletionProvider { } } - pub fn available_models(&self, cx: &AppContext) -> Vec { - self.provider.read().available_models(cx) + pub fn available_models(&self) -> Vec { + self.provider.read().available_models() } pub fn settings_version(&self) -> usize { @@ -176,6 +143,17 @@ impl CompletionProvider { Ok(completion) }) } + + pub fn update_provider( + &mut self, + get_provider: impl FnOnce(Arc) -> Arc>, + ) { + if let Some(client) = &self.client { + self.provider = get_provider(Arc::clone(client)); + } else { + log::warn!("completion provider cannot be updated because its client was not set"); + } + } } impl gpui::Global for CompletionProvider {} @@ -196,109 +174,6 @@ impl CompletionProvider { None } } - - pub fn update_settings(&mut self, version: usize, cx: &mut AppContext) { - let updated = match &AssistantSettings::get_global(cx).provider { - AssistantProvider::ZedDotDev { model } => self - .update_current_as::<_, CloudCompletionProvider>(|provider| { - provider.update(model.clone(), version); - }), - AssistantProvider::OpenAi { - model, - api_url, - low_speed_timeout_in_seconds, - available_models, - } => self.update_current_as::<_, OpenAiCompletionProvider>(|provider| { - provider.update( - choose_openai_model(&model, &available_models), - api_url.clone(), - low_speed_timeout_in_seconds.map(Duration::from_secs), - version, - ); - }), - AssistantProvider::Anthropic { - model, - api_url, - low_speed_timeout_in_seconds, - } => self.update_current_as::<_, AnthropicCompletionProvider>(|provider| { - provider.update( - model.clone(), - api_url.clone(), - low_speed_timeout_in_seconds.map(Duration::from_secs), - version, - ); - }), - AssistantProvider::Ollama { - model, - api_url, - low_speed_timeout_in_seconds, - } => self.update_current_as::<_, OllamaCompletionProvider>(|provider| { - provider.update( - model.clone(), - api_url.clone(), - low_speed_timeout_in_seconds.map(Duration::from_secs), - version, - cx, - ); - }), - }; - - // Previously configured provider was changed to another one - if updated.is_none() { - if let Some(client) = self.client.clone() { - self.provider = create_provider_from_settings(client, version, cx); - } else { - log::warn!("completion provider cannot be created because client is not set"); - } - } - } -} - -fn create_provider_from_settings( - 
client: Arc, - settings_version: usize, - cx: &mut AppContext, -) -> Arc> { - match &AssistantSettings::get_global(cx).provider { - AssistantProvider::ZedDotDev { model } => Arc::new(RwLock::new( - CloudCompletionProvider::new(model.clone(), client.clone(), settings_version, cx), - )), - AssistantProvider::OpenAi { - model, - api_url, - low_speed_timeout_in_seconds, - available_models, - } => Arc::new(RwLock::new(OpenAiCompletionProvider::new( - choose_openai_model(&model, &available_models), - api_url.clone(), - client.http_client(), - low_speed_timeout_in_seconds.map(Duration::from_secs), - settings_version, - ))), - AssistantProvider::Anthropic { - model, - api_url, - low_speed_timeout_in_seconds, - } => Arc::new(RwLock::new(AnthropicCompletionProvider::new( - model.clone(), - api_url.clone(), - client.http_client(), - low_speed_timeout_in_seconds.map(Duration::from_secs), - settings_version, - ))), - AssistantProvider::Ollama { - model, - api_url, - low_speed_timeout_in_seconds, - } => Arc::new(RwLock::new(OllamaCompletionProvider::new( - model.clone(), - api_url.clone(), - client.http_client(), - low_speed_timeout_in_seconds.map(Duration::from_secs), - settings_version, - cx, - ))), - } } #[cfg(test)] @@ -311,8 +186,8 @@ mod tests { use smol::stream::StreamExt; use crate::{ - completion_provider::MAX_CONCURRENT_COMPLETION_REQUESTS, CompletionProvider, - FakeCompletionProvider, LanguageModelRequest, + CompletionProvider, FakeCompletionProvider, LanguageModelRequest, + MAX_CONCURRENT_COMPLETION_REQUESTS, }; #[gpui::test] diff --git a/crates/assistant/src/completion_provider/fake.rs b/crates/completion/src/fake.rs similarity index 97% rename from crates/assistant/src/completion_provider/fake.rs rename to crates/completion/src/fake.rs index e9ad8d9a0faa5c69c4c0e8c733dce8e02c7c6240..9eee0f736fa20f1c4b0c1c639d43408d4221c164 100644 --- a/crates/assistant/src/completion_provider/fake.rs +++ b/crates/completion/src/fake.rs @@ -62,7 +62,7 @@ impl FakeCompletionProvider { } impl LanguageModelCompletionProvider for FakeCompletionProvider { - fn available_models(&self, _cx: &AppContext) -> Vec { + fn available_models(&self) -> Vec { vec![LanguageModel::default()] } diff --git a/crates/assistant/src/completion_provider/ollama.rs b/crates/completion/src/ollama.rs similarity index 96% rename from crates/assistant/src/completion_provider/ollama.rs rename to crates/completion/src/ollama.rs index 59d79e3ae7d5114533047a3f237a086a02bcaa5e..30d797c76b452c7bc0bba6b43dab3b161acc9471 100644 --- a/crates/assistant/src/completion_provider/ollama.rs +++ b/crates/completion/src/ollama.rs @@ -1,15 +1,14 @@ use crate::LanguageModelCompletionProvider; -use crate::{ - assistant_settings::OllamaModel, CompletionProvider, LanguageModel, LanguageModelRequest, Role, -}; +use crate::{CompletionProvider, LanguageModel, LanguageModelRequest}; use anyhow::Result; use futures::StreamExt as _; use futures::{future::BoxFuture, stream::BoxStream, FutureExt}; use gpui::{AnyView, AppContext, Task}; use http::HttpClient; +use language_model::Role; +use ollama::Model as OllamaModel; use ollama::{ get_models, preload_model, stream_chat_completion, ChatMessage, ChatOptions, ChatRequest, - Role as OllamaRole, }; use std::sync::Arc; use std::time::Duration; @@ -28,7 +27,7 @@ pub struct OllamaCompletionProvider { } impl LanguageModelCompletionProvider for OllamaCompletionProvider { - fn available_models(&self, _cx: &AppContext) -> Vec { + fn available_models(&self) -> Vec { self.available_models .iter() .map(|m| 
LanguageModel::Ollama(m.clone()))
@@ -262,16 +261,6 @@ impl OllamaCompletionProvider {
     }
 }

-impl From<Role> for ollama::Role {
-    fn from(val: Role) -> Self {
-        match val {
-            Role::User => OllamaRole::User,
-            Role::Assistant => OllamaRole::Assistant,
-            Role::System => OllamaRole::System,
-        }
-    }
-}
-
 struct DownloadOllamaMessage {
     retry_connection: Box<dyn Fn(&mut WindowContext) -> Task<Result<()>>>,
 }
diff --git a/crates/assistant/src/completion_provider/open_ai.rs b/crates/completion/src/open_ai.rs
similarity index 89%
rename from crates/assistant/src/completion_provider/open_ai.rs
rename to crates/completion/src/open_ai.rs
index fd65d1afe513438f8501e75d9df933e747544333..0a0f6d5b4a3d512535c960e7d49a8784b1317018 100644
--- a/crates/assistant/src/completion_provider/open_ai.rs
+++ b/crates/completion/src/open_ai.rs
@@ -1,15 +1,13 @@
-use crate::assistant_settings::CloudModel;
-use crate::assistant_settings::{AssistantProvider, AssistantSettings};
+use crate::CompletionProvider;
 use crate::LanguageModelCompletionProvider;
-use crate::{
-    assistant_settings::OpenAiModel, CompletionProvider, LanguageModel, LanguageModelRequest, Role,
-};
 use anyhow::{anyhow, Result};
 use editor::{Editor, EditorElement, EditorStyle};
 use futures::{future::BoxFuture, stream::BoxStream, FutureExt, StreamExt};
 use gpui::{AnyView, AppContext, FontStyle, Task, TextStyle, View, WhiteSpace};
 use http::HttpClient;
-use open_ai::{stream_completion, Request, RequestMessage, Role as OpenAiRole};
+use language_model::{CloudModel, LanguageModel, LanguageModelRequest, Role};
+use open_ai::Model as OpenAiModel;
+use open_ai::{stream_completion, Request, RequestMessage};
 use settings::Settings;
 use std::time::Duration;
 use std::{env, sync::Arc};
@@ -25,6 +23,7 @@ pub struct OpenAiCompletionProvider {
     http_client: Arc<dyn HttpClient>,
     low_speed_timeout: Option<Duration>,
     settings_version: usize,
+    available_models_from_settings: Vec<OpenAiModel>,
 }

 impl OpenAiCompletionProvider {
@@ -34,6 +33,7 @@ impl OpenAiCompletionProvider {
         http_client: Arc<dyn HttpClient>,
         low_speed_timeout: Option<Duration>,
         settings_version: usize,
+        available_models_from_settings: Vec<OpenAiModel>,
     ) -> Self {
         Self {
             api_key: None,
@@ -42,6 +42,7 @@
             http_client,
             low_speed_timeout,
             settings_version,
+            available_models_from_settings,
         }
     }

@@ -92,30 +93,26 @@
 }

 impl LanguageModelCompletionProvider for OpenAiCompletionProvider {
-    fn available_models(&self, cx: &AppContext) -> Vec<LanguageModel> {
-        if let AssistantProvider::OpenAi {
-            available_models, ..
-        } = &AssistantSettings::get_global(cx).provider
-        {
-            if !available_models.is_empty() {
-                return available_models
-                    .iter()
-                    .cloned()
-                    .map(LanguageModel::OpenAi)
-                    .collect();
-            }
-        }
-        let available_models = if matches!(self.model, OpenAiModel::Custom { .. }) {
-            vec![self.model.clone()]
+    fn available_models(&self) -> Vec<LanguageModel> {
+        if self.available_models_from_settings.is_empty() {
+            let available_models = if matches!(self.model, OpenAiModel::Custom { .. }) {
+                vec![self.model.clone()]
+            } else {
+                OpenAiModel::iter()
+                    .filter(|model| !matches!(model, OpenAiModel::Custom { .. }))
+                    .collect()
+            };
+            available_models
+                .into_iter()
+                .map(LanguageModel::OpenAi)
+                .collect()
         } else {
-            OpenAiModel::iter()
-                .filter(|model| !matches!(model, OpenAiModel::Custom { .. }))
+            self.available_models_from_settings
+                .iter()
+                .cloned()
+                .map(LanguageModel::OpenAi)
                 .collect()
-        };
-        available_models
-            .into_iter()
-            .map(LanguageModel::OpenAi)
-            .collect()
+        }
     }

     fn settings_version(&self) -> usize {
@@ -255,16 +252,6 @@ pub fn count_open_ai_tokens(
     .boxed()
 }

-impl From<Role> for open_ai::Role {
-    fn from(val: Role) -> Self {
-        match val {
-            Role::User => OpenAiRole::User,
-            Role::Assistant => OpenAiRole::Assistant,
-            Role::System => OpenAiRole::System,
-        }
-    }
-}
-
 struct AuthenticationPrompt {
     api_key: View<Editor>,
     api_url: String,
diff --git a/crates/language_model/Cargo.toml b/crates/language_model/Cargo.toml
new file mode 100644
index 0000000000000000000000000000000000000000..bdc3ad63d55a7be56f5762f22ea7b4d26cbcf0ec
--- /dev/null
+++ b/crates/language_model/Cargo.toml
@@ -0,0 +1,41 @@
+[package]
+name = "language_model"
+version = "0.1.0"
+edition = "2021"
+publish = false
+license = "GPL-3.0-or-later"
+
+[lints]
+workspace = true
+
+[lib]
+path = "src/language_model.rs"
+doctest = false
+
+[features]
+test-support = [
+    "editor/test-support",
+    "language/test-support",
+    "project/test-support",
+    "text/test-support",
+]
+
+[dependencies]
+anthropic = { workspace = true, features = ["schemars"] }
+ollama = { workspace = true, features = ["schemars"] }
+open_ai = { workspace = true, features = ["schemars"] }
+schemars.workspace = true
+serde.workspace = true
+strum.workspace = true
+proto = { workspace = true, features = ["test-support"] }
+
+[dev-dependencies]
+ctor.workspace = true
+editor = { workspace = true, features = ["test-support"] }
+env_logger.workspace = true
+language = { workspace = true, features = ["test-support"] }
+log.workspace = true
+project = { workspace = true, features = ["test-support"] }
+rand.workspace = true
+text = { workspace = true, features = ["test-support"] }
+unindent.workspace = true
diff --git a/crates/language_model/LICENSE-GPL b/crates/language_model/LICENSE-GPL
new file mode 120000
index 0000000000000000000000000000000000000000..89e542f750cd3860a0598eff0dc34b56d7336dc4
--- /dev/null
+++ b/crates/language_model/LICENSE-GPL
@@ -0,0 +1 @@
+../../LICENSE-GPL
\ No newline at end of file
diff --git a/crates/language_model/src/language_model.rs b/crates/language_model/src/language_model.rs
new file mode 100644
index 0000000000000000000000000000000000000000..09de409ff47ecdf9e84fbccfc70153a7aa29fcd6
--- /dev/null
+++ b/crates/language_model/src/language_model.rs
@@ -0,0 +1,7 @@
+mod model;
+mod request;
+mod role;
+
+pub use model::*;
+pub use request::*;
+pub use role::*;
diff --git a/crates/language_model/src/model/cloud_model.rs b/crates/language_model/src/model/cloud_model.rs
new file mode 100644
index 0000000000000000000000000000000000000000..20b2bf7d4f90e59dba89bc5e05e56453cb8fa98d
--- /dev/null
+++ b/crates/language_model/src/model/cloud_model.rs
@@ -0,0 +1,160 @@
+use crate::LanguageModelRequest;
+pub use anthropic::Model as AnthropicModel;
+pub use ollama::Model as OllamaModel;
+pub use open_ai::Model as OpenAiModel;
+use schemars::{
+    schema::{InstanceType, Metadata, Schema, SchemaObject},
+    JsonSchema,
+};
+use serde::{
+    de::{self, Visitor},
+    Deserialize, Deserializer, Serialize, Serializer,
+};
+use std::fmt;
+use strum::{EnumIter, IntoEnumIterator};
+
+#[derive(Clone, Debug, Default, PartialEq, EnumIter)]
+pub enum CloudModel {
+    Gpt3Point5Turbo,
+    Gpt4,
+    Gpt4Turbo,
+    #[default]
+    Gpt4Omni,
+    Gpt4OmniMini,
+    Claude3_5Sonnet,
+    Claude3Opus,
+    Claude3Sonnet,
+    Claude3Haiku,
+    Gemini15Pro,
+    Gemini15Flash,
+    Custom(String),
+} + +impl Serialize for CloudModel { + fn serialize(&self, serializer: S) -> Result + where + S: Serializer, + { + serializer.serialize_str(self.id()) + } +} + +impl<'de> Deserialize<'de> for CloudModel { + fn deserialize(deserializer: D) -> Result + where + D: Deserializer<'de>, + { + struct ZedDotDevModelVisitor; + + impl<'de> Visitor<'de> for ZedDotDevModelVisitor { + type Value = CloudModel; + + fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result { + formatter.write_str("a string for a ZedDotDevModel variant or a custom model") + } + + fn visit_str(self, value: &str) -> Result + where + E: de::Error, + { + let model = CloudModel::iter() + .find(|model| model.id() == value) + .unwrap_or_else(|| CloudModel::Custom(value.to_string())); + Ok(model) + } + } + + deserializer.deserialize_str(ZedDotDevModelVisitor) + } +} + +impl JsonSchema for CloudModel { + fn schema_name() -> String { + "ZedDotDevModel".to_owned() + } + + fn json_schema(_generator: &mut schemars::gen::SchemaGenerator) -> Schema { + let variants = CloudModel::iter() + .filter_map(|model| { + let id = model.id(); + if id.is_empty() { + None + } else { + Some(id.to_string()) + } + }) + .collect::>(); + Schema::Object(SchemaObject { + instance_type: Some(InstanceType::String.into()), + enum_values: Some(variants.iter().map(|s| s.clone().into()).collect()), + metadata: Some(Box::new(Metadata { + title: Some("ZedDotDevModel".to_owned()), + default: Some(CloudModel::default().id().into()), + examples: variants.into_iter().map(Into::into).collect(), + ..Default::default() + })), + ..Default::default() + }) + } +} + +impl CloudModel { + pub fn id(&self) -> &str { + match self { + Self::Gpt3Point5Turbo => "gpt-3.5-turbo", + Self::Gpt4 => "gpt-4", + Self::Gpt4Turbo => "gpt-4-turbo-preview", + Self::Gpt4Omni => "gpt-4o", + Self::Gpt4OmniMini => "gpt-4o-mini", + Self::Claude3_5Sonnet => "claude-3-5-sonnet", + Self::Claude3Opus => "claude-3-opus", + Self::Claude3Sonnet => "claude-3-sonnet", + Self::Claude3Haiku => "claude-3-haiku", + Self::Gemini15Pro => "gemini-1.5-pro", + Self::Gemini15Flash => "gemini-1.5-flash", + Self::Custom(id) => id, + } + } + + pub fn display_name(&self) -> &str { + match self { + Self::Gpt3Point5Turbo => "GPT 3.5 Turbo", + Self::Gpt4 => "GPT 4", + Self::Gpt4Turbo => "GPT 4 Turbo", + Self::Gpt4Omni => "GPT 4 Omni", + Self::Gpt4OmniMini => "GPT 4 Omni Mini", + Self::Claude3_5Sonnet => "Claude 3.5 Sonnet", + Self::Claude3Opus => "Claude 3 Opus", + Self::Claude3Sonnet => "Claude 3 Sonnet", + Self::Claude3Haiku => "Claude 3 Haiku", + Self::Gemini15Pro => "Gemini 1.5 Pro", + Self::Gemini15Flash => "Gemini 1.5 Flash", + Self::Custom(id) => id.as_str(), + } + } + + pub fn max_token_count(&self) -> usize { + match self { + Self::Gpt3Point5Turbo => 2048, + Self::Gpt4 => 4096, + Self::Gpt4Turbo | Self::Gpt4Omni => 128000, + Self::Gpt4OmniMini => 128000, + Self::Claude3_5Sonnet + | Self::Claude3Opus + | Self::Claude3Sonnet + | Self::Claude3Haiku => 200000, + Self::Gemini15Pro => 128000, + Self::Gemini15Flash => 32000, + Self::Custom(_) => 4096, // TODO: Make this configurable + } + } + + pub fn preprocess_request(&self, request: &mut LanguageModelRequest) { + match self { + Self::Claude3Opus | Self::Claude3Sonnet | Self::Claude3Haiku => { + request.preprocess_anthropic() + } + _ => {} + } + } +} diff --git a/crates/language_model/src/model/mod.rs b/crates/language_model/src/model/mod.rs new file mode 100644 index 0000000000000000000000000000000000000000..b61766308f44b68f9978c1b1fc83d7b25cb77ffa --- 
diff --git a/crates/language_model/src/model/mod.rs b/crates/language_model/src/model/mod.rs
new file mode 100644
index 0000000000000000000000000000000000000000..b61766308f44b68f9978c1b1fc83d7b25cb77ffa
--- /dev/null
+++ b/crates/language_model/src/model/mod.rs
@@ -0,0 +1,60 @@
+pub mod cloud_model;
+
+pub use anthropic::Model as AnthropicModel;
+pub use cloud_model::*;
+pub use ollama::Model as OllamaModel;
+pub use open_ai::Model as OpenAiModel;
+
+use serde::{Deserialize, Serialize};
+
+#[derive(Clone, Debug, Serialize, Deserialize, PartialEq)]
+pub enum LanguageModel {
+    Cloud(CloudModel),
+    OpenAi(OpenAiModel),
+    Anthropic(AnthropicModel),
+    Ollama(OllamaModel),
+}
+
+impl Default for LanguageModel {
+    fn default() -> Self {
+        LanguageModel::Cloud(CloudModel::default())
+    }
+}
+
+impl LanguageModel {
+    pub fn telemetry_id(&self) -> String {
+        match self {
+            LanguageModel::OpenAi(model) => format!("openai/{}", model.id()),
+            LanguageModel::Anthropic(model) => format!("anthropic/{}", model.id()),
+            LanguageModel::Cloud(model) => format!("zed.dev/{}", model.id()),
+            LanguageModel::Ollama(model) => format!("ollama/{}", model.id()),
+        }
+    }
+
+    pub fn display_name(&self) -> String {
+        match self {
+            LanguageModel::OpenAi(model) => model.display_name().into(),
+            LanguageModel::Anthropic(model) => model.display_name().into(),
+            LanguageModel::Cloud(model) => model.display_name().into(),
+            LanguageModel::Ollama(model) => model.display_name().into(),
+        }
+    }
+
+    pub fn max_token_count(&self) -> usize {
+        match self {
+            LanguageModel::OpenAi(model) => model.max_token_count(),
+            LanguageModel::Anthropic(model) => model.max_token_count(),
+            LanguageModel::Cloud(model) => model.max_token_count(),
+            LanguageModel::Ollama(model) => model.max_token_count(),
+        }
+    }
+
+    pub fn id(&self) -> &str {
+        match self {
+            LanguageModel::OpenAi(model) => model.id(),
+            LanguageModel::Anthropic(model) => model.id(),
+            LanguageModel::Cloud(model) => model.id(),
+            LanguageModel::Ollama(model) => model.id(),
+        }
+    }
+}
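`LanguageModel` lets callers ask for ids, display names, and token budgets without matching on the backing provider, and `telemetry_id` namespaces the id by provider. A small usage sketch, assuming the crate is linked as `language_model`:

```rust
use language_model::{CloudModel, LanguageModel};

fn main() {
    // The default model is the cloud-hosted default.
    let model = LanguageModel::default();
    assert_eq!(model.telemetry_id(), "zed.dev/gpt-4o");
    assert_eq!(model.max_token_count(), 128000);

    // Every variant answers the same questions, so callers never
    // need to know which backend serves the model.
    let claude = LanguageModel::Cloud(CloudModel::Claude3Opus);
    assert_eq!(claude.display_name(), "Claude 3 Opus");
}
```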
diff --git a/crates/language_model/src/request.rs b/crates/language_model/src/request.rs
new file mode 100644
index 0000000000000000000000000000000000000000..f9c4322cdff4f613923851c41a5c67e097c3b603
--- /dev/null
+++ b/crates/language_model/src/request.rs
@@ -0,0 +1,110 @@
+use crate::{
+    model::{CloudModel, LanguageModel},
+    role::Role,
+};
+use serde::{Deserialize, Serialize};
+
+#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
+pub struct LanguageModelRequestMessage {
+    pub role: Role,
+    pub content: String,
+}
+
+impl LanguageModelRequestMessage {
+    pub fn to_proto(&self) -> proto::LanguageModelRequestMessage {
+        proto::LanguageModelRequestMessage {
+            role: self.role.to_proto() as i32,
+            content: self.content.clone(),
+            tool_calls: Vec::new(),
+            tool_call_id: None,
+        }
+    }
+}
+
+#[derive(Debug, Default, Serialize, Deserialize)]
+pub struct LanguageModelRequest {
+    pub model: LanguageModel,
+    pub messages: Vec<LanguageModelRequestMessage>,
+    pub stop: Vec<String>,
+    pub temperature: f32,
+}
+
+impl LanguageModelRequest {
+    pub fn to_proto(&self) -> proto::CompleteWithLanguageModel {
+        proto::CompleteWithLanguageModel {
+            model: self.model.id().to_string(),
+            messages: self.messages.iter().map(|m| m.to_proto()).collect(),
+            stop: self.stop.clone(),
+            temperature: self.temperature,
+            tool_choice: None,
+            tools: Vec::new(),
+        }
+    }
+
+    /// Before we send the request to the server, we can perform fixups on it appropriate to the model.
+    pub fn preprocess(&mut self) {
+        match &self.model {
+            LanguageModel::OpenAi(_) => {}
+            LanguageModel::Anthropic(_) => {}
+            LanguageModel::Ollama(_) => {}
+            LanguageModel::Cloud(model) => match model {
+                CloudModel::Claude3Opus
+                | CloudModel::Claude3Sonnet
+                | CloudModel::Claude3Haiku
+                | CloudModel::Claude3_5Sonnet => {
+                    self.preprocess_anthropic();
+                }
+                _ => {}
+            },
+        }
+    }
+
+    pub fn preprocess_anthropic(&mut self) {
+        let mut new_messages: Vec<LanguageModelRequestMessage> = Vec::new();
+        let mut system_message = String::new();
+
+        for message in self.messages.drain(..) {
+            if message.content.is_empty() {
+                continue;
+            }
+
+            match message.role {
+                Role::User | Role::Assistant => {
+                    if let Some(last_message) = new_messages.last_mut() {
+                        if last_message.role == message.role {
+                            last_message.content.push_str("\n\n");
+                            last_message.content.push_str(&message.content);
+                            continue;
+                        }
+                    }
+
+                    new_messages.push(message);
+                }
+                Role::System => {
+                    if !system_message.is_empty() {
+                        system_message.push_str("\n\n");
+                    }
+                    system_message.push_str(&message.content);
+                }
+            }
+        }
+
+        if !system_message.is_empty() {
+            new_messages.insert(
+                0,
+                LanguageModelRequestMessage {
+                    role: Role::System,
+                    content: system_message,
+                },
+            );
+        }
+
+        self.messages = new_messages;
+    }
+}
+
+#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
+pub struct LanguageModelResponseMessage {
+    pub role: Option<Role>,
+    pub content: Option<String>,
+}
diff --git a/crates/language_model/src/role.rs b/crates/language_model/src/role.rs
new file mode 100644
index 0000000000000000000000000000000000000000..f6276a4823651c200b207c44eb612e16283ac913
--- /dev/null
+++ b/crates/language_model/src/role.rs
@@ -0,0 +1,68 @@
+use serde::{Deserialize, Serialize};
+use std::fmt::{self, Display};
+
+#[derive(Clone, Copy, Serialize, Deserialize, Debug, Eq, PartialEq)]
+#[serde(rename_all = "lowercase")]
+pub enum Role {
+    User,
+    Assistant,
+    System,
+}
+
+impl Role {
+    pub fn from_proto(role: i32) -> Role {
+        match proto::LanguageModelRole::from_i32(role) {
+            Some(proto::LanguageModelRole::LanguageModelUser) => Role::User,
+            Some(proto::LanguageModelRole::LanguageModelAssistant) => Role::Assistant,
+            Some(proto::LanguageModelRole::LanguageModelSystem) => Role::System,
+            Some(proto::LanguageModelRole::LanguageModelTool) => Role::System,
+            None => Role::User,
+        }
+    }
+
+    pub fn to_proto(&self) -> proto::LanguageModelRole {
+        match self {
+            Role::User => proto::LanguageModelRole::LanguageModelUser,
+            Role::Assistant => proto::LanguageModelRole::LanguageModelAssistant,
+            Role::System => proto::LanguageModelRole::LanguageModelSystem,
+        }
+    }
+
+    pub fn cycle(self) -> Role {
+        match self {
+            Role::User => Role::Assistant,
+            Role::Assistant => Role::System,
+            Role::System => Role::User,
+        }
+    }
+}
+
+impl Display for Role {
+    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> std::fmt::Result {
+        match self {
+            Role::User => write!(f, "user"),
+            Role::Assistant => write!(f, "assistant"),
+            Role::System => write!(f, "system"),
+        }
+    }
+}
+
+impl From<Role> for ollama::Role {
+    fn from(val: Role) -> Self {
+        match val {
+            Role::User => ollama::Role::User,
+            Role::Assistant => ollama::Role::Assistant,
+            Role::System => ollama::Role::System,
+        }
+    }
+}
+
+impl From<Role> for open_ai::Role {
+    fn from(val: Role) -> Self {
+        match val {
+            Role::User => open_ai::Role::User,
+            Role::Assistant => open_ai::Role::Assistant,
+            Role::System => open_ai::Role::System,
+        }
+    }
+}
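`preprocess_anthropic` above normalizes a request for Anthropic-style models: empty messages are dropped, consecutive messages with the same role are merged, and all system content is collapsed into a single message hoisted to the front. A worked example of that behavior, again assuming the crate is linked as `language_model`:

```rust
use language_model::{LanguageModelRequest, LanguageModelRequestMessage, Role};

fn main() {
    let mut request = LanguageModelRequest {
        messages: vec![
            LanguageModelRequestMessage { role: Role::User, content: "Hi.".into() },
            LanguageModelRequestMessage { role: Role::System, content: "Be terse.".into() },
            LanguageModelRequestMessage { role: Role::User, content: "Hello again.".into() },
        ],
        ..Default::default()
    };
    request.preprocess_anthropic();

    // The system content is hoisted into a single leading message, and the
    // two user messages (now consecutive) merge, separated by a blank line.
    assert_eq!(request.messages.len(), 2);
    assert_eq!(request.messages[0].role, Role::System);
    assert_eq!(request.messages[1].content, "Hi.\n\nHello again.");
}
```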
diff --git a/crates/semantic_index/Cargo.toml b/crates/semantic_index/Cargo.toml
index 3f49490941f767157d235778fff12b98694168f2..19cb0c96fee452bdc7cf9f62e3b22beea3350fd0 100644
--- a/crates/semantic_index/Cargo.toml
+++ b/crates/semantic_index/Cargo.toml
@@ -22,6 +22,7 @@ anyhow.workspace = true
 client.workspace = true
 clock.workspace = true
 collections.workspace = true
+completion.workspace = true
 fs.workspace = true
 futures.workspace = true
 futures-batch.workspace = true
diff --git a/crates/semantic_index/src/semantic_index.rs b/crates/semantic_index/src/semantic_index.rs
index 7a29f3be2526663f2139d1c2510f5d31f58fbc20..4c43fc1e468ed7b3a7f5fb876714edf1a8803115 100644
--- a/crates/semantic_index/src/semantic_index.rs
+++ b/crates/semantic_index/src/semantic_index.rs
@@ -1261,3 +1261,6 @@ mod tests {
         );
     }
 }
+
+// See https://github.com/zed-industries/zed/pull/14823#discussion_r1684616398 for why this is here and when it should be removed.
+type _TODO = completion::CompletionProvider;