Detailed changes
@@ -382,6 +382,7 @@ dependencies = [
"clock",
"collections",
"command_palette_hooks",
+ "completion",
"ctor",
"editor",
"env_logger",
@@ -396,6 +397,7 @@ dependencies = [
"indexed_docs",
"indoc",
"language",
+ "language_model",
"log",
"menu",
"multi_buffer",
@@ -418,13 +420,11 @@ dependencies = [
"settings",
"similar",
"smol",
- "strum",
"telemetry_events",
"terminal",
"terminal_view",
"text",
"theme",
- "tiktoken-rs",
"toml 0.8.10",
"ui",
"unindent",
@@ -2491,6 +2491,7 @@ dependencies = [
"clock",
"collab_ui",
"collections",
+ "completion",
"ctor",
"dashmap",
"dev_server_projects",
@@ -2673,6 +2674,42 @@ dependencies = [
"gpui",
]
+[[package]]
+name = "completion"
+version = "0.1.0"
+dependencies = [
+ "anthropic",
+ "anyhow",
+ "client",
+ "collections",
+ "ctor",
+ "editor",
+ "env_logger",
+ "futures 0.3.28",
+ "gpui",
+ "http 0.1.0",
+ "language",
+ "language_model",
+ "log",
+ "menu",
+ "ollama",
+ "open_ai",
+ "parking_lot",
+ "project",
+ "rand 0.8.5",
+ "serde",
+ "serde_json",
+ "settings",
+ "smol",
+ "strum",
+ "text",
+ "theme",
+ "tiktoken-rs",
+ "ui",
+ "unindent",
+ "util",
+]
+
[[package]]
name = "concurrent-queue"
version = "2.2.0"
@@ -5996,6 +6033,28 @@ dependencies = [
"util",
]
+[[package]]
+name = "language_model"
+version = "0.1.0"
+dependencies = [
+ "anthropic",
+ "ctor",
+ "editor",
+ "env_logger",
+ "language",
+ "log",
+ "ollama",
+ "open_ai",
+ "project",
+ "proto",
+ "rand 0.8.5",
+ "schemars",
+ "serde",
+ "strum",
+ "text",
+ "unindent",
+]
+
[[package]]
name = "language_selector"
version = "0.1.0"
@@ -9510,6 +9569,7 @@ dependencies = [
"client",
"clock",
"collections",
+ "completion",
"env_logger",
"fs",
"futures 0.3.28",
@@ -19,6 +19,7 @@ members = [
"crates/collections",
"crates/command_palette",
"crates/command_palette_hooks",
+ "crates/completion",
"crates/copilot",
"crates/db",
"crates/dev_server_projects",
@@ -50,6 +51,7 @@ members = [
"crates/install_cli",
"crates/journal",
"crates/language",
+ "crates/language_model",
"crates/language_selector",
"crates/language_tools",
"crates/languages",
@@ -176,6 +178,7 @@ collab_ui = { path = "crates/collab_ui" }
collections = { path = "crates/collections" }
command_palette = { path = "crates/command_palette" }
command_palette_hooks = { path = "crates/command_palette_hooks" }
+completion = { path = "crates/completion" }
copilot = { path = "crates/copilot" }
db = { path = "crates/db" }
dev_server_projects = { path = "crates/dev_server_projects" }
@@ -205,6 +208,7 @@ inline_completion_button = { path = "crates/inline_completion_button" }
install_cli = { path = "crates/install_cli" }
journal = { path = "crates/journal" }
language = { path = "crates/language" }
+language_model = { path = "crates/language_model" }
language_selector = { path = "crates/language_selector" }
language_tools = { path = "crates/language_tools" }
languages = { path = "crates/languages" }
@@ -33,6 +33,7 @@ client.workspace = true
clock.workspace = true
collections.workspace = true
command_palette_hooks.workspace = true
+completion.workspace = true
editor.workspace = true
feature_flags.workspace = true
fs.workspace = true
@@ -45,6 +46,7 @@ http.workspace = true
indexed_docs.workspace = true
indoc.workspace = true
language.workspace = true
+language_model.workspace = true
log.workspace = true
menu.workspace = true
multi_buffer.workspace = true
@@ -64,12 +66,10 @@ serde_json.workspace = true
settings.workspace = true
similar.workspace = true
smol.workspace = true
-strum.workspace = true
telemetry_events.workspace = true
terminal.workspace = true
terminal_view.workspace = true
theme.workspace = true
-tiktoken-rs.workspace = true
toml.workspace = true
ui.workspace = true
util.workspace = true
@@ -79,6 +79,7 @@ picker.workspace = true
roxmltree = "0.20.0"
[dev-dependencies]
+completion = { workspace = true, features = ["test-support"] }
ctor.workspace = true
editor = { workspace = true, features = ["test-support"] }
env_logger.workspace = true
@@ -1,6 +1,5 @@
pub mod assistant_panel;
pub mod assistant_settings;
-mod completion_provider;
mod context;
pub mod context_store;
mod inline_assistant;
@@ -12,17 +11,20 @@ mod streaming_diff;
mod terminal_inline_assistant;
pub use assistant_panel::{AssistantPanel, AssistantPanelEvent};
-use assistant_settings::{AnthropicModel, AssistantSettings, CloudModel, OllamaModel, OpenAiModel};
+use assistant_settings::AssistantSettings;
use assistant_slash_command::SlashCommandRegistry;
use client::{proto, Client};
use command_palette_hooks::CommandPaletteFilter;
-pub use completion_provider::*;
+use completion::CompletionProvider;
pub use context::*;
pub use context_store::*;
use fs::Fs;
-use gpui::{actions, impl_actions, AppContext, Global, SharedString, UpdateGlobal};
+use gpui::{
+ actions, impl_actions, AppContext, BorrowAppContext, Global, SharedString, UpdateGlobal,
+};
use indexed_docs::IndexedDocsRegistry;
pub(crate) use inline_assistant::*;
+use language_model::LanguageModelResponseMessage;
pub(crate) use model_selector::*;
use semantic_index::{CloudEmbeddingProvider, SemanticIndex};
use serde::{Deserialize, Serialize};
@@ -32,10 +34,7 @@ use slash_command::{
file_command, now_command, project_command, prompt_command, search_command, symbols_command,
tabs_command, term_command,
};
-use std::{
- fmt::{self, Display},
- sync::Arc,
-};
+use std::sync::Arc;
pub(crate) use streaming_diff::*;
actions!(
@@ -73,166 +72,6 @@ impl MessageId {
}
}
-#[derive(Clone, Copy, Serialize, Deserialize, Debug, Eq, PartialEq)]
-#[serde(rename_all = "lowercase")]
-pub enum Role {
- User,
- Assistant,
- System,
-}
-
-impl Role {
- pub fn from_proto(role: i32) -> Role {
- match proto::LanguageModelRole::from_i32(role) {
- Some(proto::LanguageModelRole::LanguageModelUser) => Role::User,
- Some(proto::LanguageModelRole::LanguageModelAssistant) => Role::Assistant,
- Some(proto::LanguageModelRole::LanguageModelSystem) => Role::System,
- Some(proto::LanguageModelRole::LanguageModelTool) => Role::System,
- None => Role::User,
- }
- }
-
- pub fn to_proto(&self) -> proto::LanguageModelRole {
- match self {
- Role::User => proto::LanguageModelRole::LanguageModelUser,
- Role::Assistant => proto::LanguageModelRole::LanguageModelAssistant,
- Role::System => proto::LanguageModelRole::LanguageModelSystem,
- }
- }
-
- pub fn cycle(self) -> Role {
- match self {
- Role::User => Role::Assistant,
- Role::Assistant => Role::System,
- Role::System => Role::User,
- }
- }
-}
-
-impl Display for Role {
- fn fmt(&self, f: &mut fmt::Formatter<'_>) -> std::fmt::Result {
- match self {
- Role::User => write!(f, "user"),
- Role::Assistant => write!(f, "assistant"),
- Role::System => write!(f, "system"),
- }
- }
-}
-
-#[derive(Clone, Debug, Serialize, Deserialize, PartialEq)]
-pub enum LanguageModel {
- Cloud(CloudModel),
- OpenAi(OpenAiModel),
- Anthropic(AnthropicModel),
- Ollama(OllamaModel),
-}
-
-impl Default for LanguageModel {
- fn default() -> Self {
- LanguageModel::Cloud(CloudModel::default())
- }
-}
-
-impl LanguageModel {
- pub fn telemetry_id(&self) -> String {
- match self {
- LanguageModel::OpenAi(model) => format!("openai/{}", model.id()),
- LanguageModel::Anthropic(model) => format!("anthropic/{}", model.id()),
- LanguageModel::Cloud(model) => format!("zed.dev/{}", model.id()),
- LanguageModel::Ollama(model) => format!("ollama/{}", model.id()),
- }
- }
-
- pub fn display_name(&self) -> String {
- match self {
- LanguageModel::OpenAi(model) => model.display_name().into(),
- LanguageModel::Anthropic(model) => model.display_name().into(),
- LanguageModel::Cloud(model) => model.display_name().into(),
- LanguageModel::Ollama(model) => model.display_name().into(),
- }
- }
-
- pub fn max_token_count(&self) -> usize {
- match self {
- LanguageModel::OpenAi(model) => model.max_token_count(),
- LanguageModel::Anthropic(model) => model.max_token_count(),
- LanguageModel::Cloud(model) => model.max_token_count(),
- LanguageModel::Ollama(model) => model.max_token_count(),
- }
- }
-
- pub fn id(&self) -> &str {
- match self {
- LanguageModel::OpenAi(model) => model.id(),
- LanguageModel::Anthropic(model) => model.id(),
- LanguageModel::Cloud(model) => model.id(),
- LanguageModel::Ollama(model) => model.id(),
- }
- }
-}
-
-#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
-pub struct LanguageModelRequestMessage {
- pub role: Role,
- pub content: String,
-}
-
-impl LanguageModelRequestMessage {
- pub fn to_proto(&self) -> proto::LanguageModelRequestMessage {
- proto::LanguageModelRequestMessage {
- role: self.role.to_proto() as i32,
- content: self.content.clone(),
- tool_calls: Vec::new(),
- tool_call_id: None,
- }
- }
-}
-
-#[derive(Debug, Default, Serialize, Deserialize)]
-pub struct LanguageModelRequest {
- pub model: LanguageModel,
- pub messages: Vec<LanguageModelRequestMessage>,
- pub stop: Vec<String>,
- pub temperature: f32,
-}
-
-impl LanguageModelRequest {
- pub fn to_proto(&self) -> proto::CompleteWithLanguageModel {
- proto::CompleteWithLanguageModel {
- model: self.model.id().to_string(),
- messages: self.messages.iter().map(|m| m.to_proto()).collect(),
- stop: self.stop.clone(),
- temperature: self.temperature,
- tool_choice: None,
- tools: Vec::new(),
- }
- }
-
- /// Before we send the request to the server, we can perform fixups on it appropriate to the model.
- pub fn preprocess(&mut self) {
- match &self.model {
- LanguageModel::OpenAi(_) => {}
- LanguageModel::Anthropic(_) => {}
- LanguageModel::Ollama(_) => {}
- LanguageModel::Cloud(model) => match model {
- CloudModel::Claude3Opus
- | CloudModel::Claude3Sonnet
- | CloudModel::Claude3Haiku
- | CloudModel::Claude3_5Sonnet => {
- preprocess_anthropic_request(self);
- }
- _ => {}
- },
- }
- }
-}
-
-#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
-pub struct LanguageModelResponseMessage {
- pub role: Option<Role>,
- pub content: Option<String>,
-}
-
#[derive(Deserialize, Debug)]
pub struct LanguageModelUsage {
pub prompt_tokens: u32,
@@ -343,7 +182,7 @@ pub fn init(fs: Arc<dyn Fs>, client: Arc<Client>, cx: &mut AppContext) {
context_store::init(&client);
prompt_library::init(cx);
- completion_provider::init(client.clone(), cx);
+ init_completion_provider(Arc::clone(&client), cx);
assistant_slash_command::init(cx);
register_slash_commands(cx);
assistant_panel::init(cx);
@@ -368,6 +207,20 @@ pub fn init(fs: Arc<dyn Fs>, client: Arc<Client>, cx: &mut AppContext) {
.detach();
}
+fn init_completion_provider(client: Arc<Client>, cx: &mut AppContext) {
+ let provider = assistant_settings::create_provider_from_settings(client.clone(), 0, cx);
+ cx.set_global(CompletionProvider::new(provider, Some(client)));
+
+ let mut settings_version = 0;
+ cx.observe_global::<SettingsStore>(move |cx| {
+ settings_version += 1;
+ cx.update_global::<CompletionProvider, _>(|provider, cx| {
+ assistant_settings::update_completion_provider_settings(provider, settings_version, cx);
+ })
+ })
+ .detach();
+}
+
fn register_slash_commands(cx: &mut AppContext) {
let slash_command_registry = SlashCommandRegistry::global(cx);
slash_command_registry.register_command(file_command::FileSlashCommand, true);
@@ -8,18 +8,18 @@ use crate::{
SlashCommandCompletionProvider, SlashCommandRegistry,
},
terminal_inline_assistant::TerminalInlineAssistant,
- Assist, CompletionProvider, ConfirmCommand, Context, ContextEvent, ContextId, ContextStore,
- CycleMessageRole, DebugEditSteps, DeployHistory, DeployPromptLibrary, EditStep,
- EditStepOperations, EditSuggestionGroup, InlineAssist, InlineAssistId, InlineAssistant,
- InsertIntoEditor, MessageStatus, ModelSelector, PendingSlashCommand, PendingSlashCommandStatus,
- QuoteSelection, RemoteContextMetadata, ResetKey, Role, SavedContextMetadata, Split,
- ToggleFocus, ToggleModelSelector,
+ Assist, ConfirmCommand, Context, ContextEvent, ContextId, ContextStore, CycleMessageRole,
+ DebugEditSteps, DeployHistory, DeployPromptLibrary, EditStep, EditStepOperations,
+ EditSuggestionGroup, InlineAssist, InlineAssistId, InlineAssistant, InsertIntoEditor,
+ MessageStatus, ModelSelector, PendingSlashCommand, PendingSlashCommandStatus, QuoteSelection,
+ RemoteContextMetadata, ResetKey, SavedContextMetadata, Split, ToggleFocus, ToggleModelSelector,
};
use anyhow::{anyhow, Result};
use assistant_slash_command::{SlashCommand, SlashCommandOutputSection};
use breadcrumbs::Breadcrumbs;
use client::proto;
use collections::{BTreeSet, HashMap, HashSet};
+use completion::CompletionProvider;
use editor::{
actions::{FoldAt, MoveToEndOfLine, Newline, ShowCompletions, UnfoldAt},
display_map::{
@@ -43,6 +43,7 @@ use language::{
language_settings::SoftWrap, Buffer, Capability, LanguageRegistry, LspAdapterDelegate, Point,
ToOffset,
};
+use language_model::Role;
use multi_buffer::MultiBufferRow;
use picker::{Picker, PickerDelegate};
use project::{Project, ProjectLspAdapterDelegate};
@@ -1,166 +1,19 @@
-use std::fmt;
-
-use crate::{preprocess_anthropic_request, LanguageModel, LanguageModelRequest};
-pub use anthropic::Model as AnthropicModel;
-use gpui::Pixels;
-pub use ollama::Model as OllamaModel;
-pub use open_ai::Model as OpenAiModel;
-use schemars::{
- schema::{InstanceType, Metadata, Schema, SchemaObject},
- JsonSchema,
-};
-use serde::{
- de::{self, Visitor},
- Deserialize, Deserializer, Serialize, Serializer,
+use std::{sync::Arc, time::Duration};
+
+use anthropic::Model as AnthropicModel;
+use client::Client;
+use completion::{
+ AnthropicCompletionProvider, CloudCompletionProvider, CompletionProvider,
+ LanguageModelCompletionProvider, OllamaCompletionProvider, OpenAiCompletionProvider,
};
+use gpui::{AppContext, Pixels};
+use language_model::{CloudModel, LanguageModel};
+use ollama::Model as OllamaModel;
+use open_ai::Model as OpenAiModel;
+use parking_lot::RwLock;
+use schemars::{schema::Schema, JsonSchema};
+use serde::{Deserialize, Serialize};
use settings::{Settings, SettingsSources};
-use strum::{EnumIter, IntoEnumIterator};
-
-#[derive(Clone, Debug, Default, PartialEq, EnumIter)]
-pub enum CloudModel {
- Gpt3Point5Turbo,
- Gpt4,
- Gpt4Turbo,
- #[default]
- Gpt4Omni,
- Gpt4OmniMini,
- Claude3_5Sonnet,
- Claude3Opus,
- Claude3Sonnet,
- Claude3Haiku,
- Gemini15Pro,
- Gemini15Flash,
- Custom(String),
-}
-
-impl Serialize for CloudModel {
- fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
- where
- S: Serializer,
- {
- serializer.serialize_str(self.id())
- }
-}
-
-impl<'de> Deserialize<'de> for CloudModel {
- fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
- where
- D: Deserializer<'de>,
- {
- struct ZedDotDevModelVisitor;
-
- impl<'de> Visitor<'de> for ZedDotDevModelVisitor {
- type Value = CloudModel;
-
- fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
- formatter.write_str("a string for a ZedDotDevModel variant or a custom model")
- }
-
- fn visit_str<E>(self, value: &str) -> Result<Self::Value, E>
- where
- E: de::Error,
- {
- let model = CloudModel::iter()
- .find(|model| model.id() == value)
- .unwrap_or_else(|| CloudModel::Custom(value.to_string()));
- Ok(model)
- }
- }
-
- deserializer.deserialize_str(ZedDotDevModelVisitor)
- }
-}
-
-impl JsonSchema for CloudModel {
- fn schema_name() -> String {
- "ZedDotDevModel".to_owned()
- }
-
- fn json_schema(_generator: &mut schemars::gen::SchemaGenerator) -> Schema {
- let variants = CloudModel::iter()
- .filter_map(|model| {
- let id = model.id();
- if id.is_empty() {
- None
- } else {
- Some(id.to_string())
- }
- })
- .collect::<Vec<_>>();
- Schema::Object(SchemaObject {
- instance_type: Some(InstanceType::String.into()),
- enum_values: Some(variants.iter().map(|s| s.clone().into()).collect()),
- metadata: Some(Box::new(Metadata {
- title: Some("ZedDotDevModel".to_owned()),
- default: Some(CloudModel::default().id().into()),
- examples: variants.into_iter().map(Into::into).collect(),
- ..Default::default()
- })),
- ..Default::default()
- })
- }
-}
-
-impl CloudModel {
- pub fn id(&self) -> &str {
- match self {
- Self::Gpt3Point5Turbo => "gpt-3.5-turbo",
- Self::Gpt4 => "gpt-4",
- Self::Gpt4Turbo => "gpt-4-turbo-preview",
- Self::Gpt4Omni => "gpt-4o",
- Self::Gpt4OmniMini => "gpt-4o-mini",
- Self::Claude3_5Sonnet => "claude-3-5-sonnet",
- Self::Claude3Opus => "claude-3-opus",
- Self::Claude3Sonnet => "claude-3-sonnet",
- Self::Claude3Haiku => "claude-3-haiku",
- Self::Gemini15Pro => "gemini-1.5-pro",
- Self::Gemini15Flash => "gemini-1.5-flash",
- Self::Custom(id) => id,
- }
- }
-
- pub fn display_name(&self) -> &str {
- match self {
- Self::Gpt3Point5Turbo => "GPT 3.5 Turbo",
- Self::Gpt4 => "GPT 4",
- Self::Gpt4Turbo => "GPT 4 Turbo",
- Self::Gpt4Omni => "GPT 4 Omni",
- Self::Gpt4OmniMini => "GPT 4 Omni Mini",
- Self::Claude3_5Sonnet => "Claude 3.5 Sonnet",
- Self::Claude3Opus => "Claude 3 Opus",
- Self::Claude3Sonnet => "Claude 3 Sonnet",
- Self::Claude3Haiku => "Claude 3 Haiku",
- Self::Gemini15Pro => "Gemini 1.5 Pro",
- Self::Gemini15Flash => "Gemini 1.5 Flash",
- Self::Custom(id) => id.as_str(),
- }
- }
-
- pub fn max_token_count(&self) -> usize {
- match self {
- Self::Gpt3Point5Turbo => 2048,
- Self::Gpt4 => 4096,
- Self::Gpt4Turbo | Self::Gpt4Omni => 128000,
- Self::Gpt4OmniMini => 128000,
- Self::Claude3_5Sonnet
- | Self::Claude3Opus
- | Self::Claude3Sonnet
- | Self::Claude3Haiku => 200000,
- Self::Gemini15Pro => 128000,
- Self::Gemini15Flash => 32000,
- Self::Custom(_) => 4096, // TODO: Make this configurable
- }
- }
-
- pub fn preprocess_request(&self, request: &mut LanguageModelRequest) {
- match self {
- Self::Claude3Opus | Self::Claude3Sonnet | Self::Claude3Haiku => {
- preprocess_anthropic_request(request)
- }
- _ => {}
- }
- }
-}
#[derive(Copy, Clone, Default, Debug, Serialize, Deserialize, JsonSchema)]
#[serde(rename_all = "snake_case")]
@@ -620,6 +473,124 @@ fn merge<T>(target: &mut T, value: Option<T>) {
}
}
+pub fn update_completion_provider_settings(
+ provider: &mut CompletionProvider,
+ version: usize,
+ cx: &mut AppContext,
+) {
+ let updated = match &AssistantSettings::get_global(cx).provider {
+ AssistantProvider::ZedDotDev { model } => provider
+ .update_current_as::<_, CloudCompletionProvider>(|provider| {
+ provider.update(model.clone(), version);
+ }),
+ AssistantProvider::OpenAi {
+ model,
+ api_url,
+ low_speed_timeout_in_seconds,
+ available_models,
+ } => provider.update_current_as::<_, OpenAiCompletionProvider>(|provider| {
+ provider.update(
+ choose_openai_model(&model, &available_models),
+ api_url.clone(),
+ low_speed_timeout_in_seconds.map(Duration::from_secs),
+ version,
+ );
+ }),
+ AssistantProvider::Anthropic {
+ model,
+ api_url,
+ low_speed_timeout_in_seconds,
+ } => provider.update_current_as::<_, AnthropicCompletionProvider>(|provider| {
+ provider.update(
+ model.clone(),
+ api_url.clone(),
+ low_speed_timeout_in_seconds.map(Duration::from_secs),
+ version,
+ );
+ }),
+ AssistantProvider::Ollama {
+ model,
+ api_url,
+ low_speed_timeout_in_seconds,
+ } => provider.update_current_as::<_, OllamaCompletionProvider>(|provider| {
+ provider.update(
+ model.clone(),
+ api_url.clone(),
+ low_speed_timeout_in_seconds.map(Duration::from_secs),
+ version,
+ cx,
+ );
+ }),
+ };
+
+ // Previously configured provider was changed to another one
+ if updated.is_none() {
+ provider.update_provider(|client| create_provider_from_settings(client, version, cx));
+ }
+}
+
+pub(crate) fn create_provider_from_settings(
+ client: Arc<Client>,
+ settings_version: usize,
+ cx: &mut AppContext,
+) -> Arc<RwLock<dyn LanguageModelCompletionProvider>> {
+ match &AssistantSettings::get_global(cx).provider {
+ AssistantProvider::ZedDotDev { model } => Arc::new(RwLock::new(
+ CloudCompletionProvider::new(model.clone(), client.clone(), settings_version, cx),
+ )),
+ AssistantProvider::OpenAi {
+ model,
+ api_url,
+ low_speed_timeout_in_seconds,
+ available_models,
+ } => Arc::new(RwLock::new(OpenAiCompletionProvider::new(
+ choose_openai_model(&model, &available_models),
+ api_url.clone(),
+ client.http_client(),
+ low_speed_timeout_in_seconds.map(Duration::from_secs),
+ settings_version,
+ available_models.clone(),
+ ))),
+ AssistantProvider::Anthropic {
+ model,
+ api_url,
+ low_speed_timeout_in_seconds,
+ } => Arc::new(RwLock::new(AnthropicCompletionProvider::new(
+ model.clone(),
+ api_url.clone(),
+ client.http_client(),
+ low_speed_timeout_in_seconds.map(Duration::from_secs),
+ settings_version,
+ ))),
+ AssistantProvider::Ollama {
+ model,
+ api_url,
+ low_speed_timeout_in_seconds,
+ } => Arc::new(RwLock::new(OllamaCompletionProvider::new(
+ model.clone(),
+ api_url.clone(),
+ client.http_client(),
+ low_speed_timeout_in_seconds.map(Duration::from_secs),
+ settings_version,
+ cx,
+ ))),
+ }
+}
+
+/// Choose which model to use for openai provider.
+/// If the model is not available, try to use the first available model, or fallback to the original model.
+fn choose_openai_model(
+ model: &::open_ai::Model,
+ available_models: &[::open_ai::Model],
+) -> ::open_ai::Model {
+ available_models
+ .iter()
+ .find(|&m| m == model)
+ .or_else(|| available_models.first())
+ .unwrap_or_else(|| model)
+ .clone()
+}
+
#[cfg(test)]
mod tests {
use gpui::{AppContext, UpdateGlobal};
@@ -1,12 +1,12 @@
use crate::{
- prompt_library::PromptStore, slash_command::SlashCommandLine, CompletionProvider,
- LanguageModelRequest, LanguageModelRequestMessage, MessageId, MessageStatus, Role,
+ prompt_library::PromptStore, slash_command::SlashCommandLine, CompletionProvider, MessageId,
+ MessageStatus,
};
use anyhow::{anyhow, Context as _, Result};
use assistant_slash_command::{
SlashCommandOutput, SlashCommandOutputSection, SlashCommandRegistry,
};
-use client::{proto, telemetry::Telemetry};
+use client::{self, proto, telemetry::Telemetry};
use clock::ReplicaId;
use collections::{HashMap, HashSet};
use fs::Fs;
@@ -18,6 +18,8 @@ use gpui::{AppContext, Context as _, EventEmitter, Model, ModelContext, Subscrip
use language::{
AnchorRangeExt, Bias, Buffer, LanguageRegistry, OffsetRangeExt, ParseStatus, Point, ToOffset,
};
+use language_model::LanguageModelRequestMessage;
+use language_model::{LanguageModelRequest, Role};
use open_ai::Model as OpenAiModel;
use paths::contexts_dir;
use project::Project;
@@ -2477,9 +2479,10 @@ mod tests {
use crate::{
assistant_panel, prompt_library,
slash_command::{active_command, file_command},
- FakeCompletionProvider, MessageId,
+ MessageId,
};
use assistant_slash_command::{ArgumentCompletion, SlashCommand};
+ use completion::FakeCompletionProvider;
use fs::FakeFs;
use gpui::{AppContext, TestAppContext, WeakView};
use indoc::indoc;
@@ -1,7 +1,6 @@
use crate::{
assistant_settings::AssistantSettings, humanize_token_count, prompts::generate_content_prompt,
- AssistantPanel, AssistantPanelEvent, CompletionProvider, Hunk, LanguageModelRequest,
- LanguageModelRequestMessage, Role, StreamingDiff,
+ AssistantPanel, AssistantPanelEvent, CompletionProvider, Hunk, StreamingDiff,
};
use anyhow::{anyhow, Context as _, Result};
use client::telemetry::Telemetry;
@@ -28,6 +27,7 @@ use gpui::{
WhiteSpace, WindowContext,
};
use language::{Buffer, Point, Selection, TransactionId};
+use language_model::{LanguageModelRequest, LanguageModelRequestMessage, Role};
use multi_buffer::MultiBufferRow;
use parking_lot::Mutex;
use rope::Rope;
@@ -1432,8 +1432,7 @@ impl Render for PromptEditor {
PopoverMenu::new("model-switcher")
.menu(move |cx| {
ContextMenu::build(cx, |mut menu, cx| {
- for model in CompletionProvider::global(cx).available_models(cx)
- {
+ for model in CompletionProvider::global(cx).available_models() {
menu = menu.custom_entry(
{
let model = model.clone();
@@ -2606,7 +2605,7 @@ fn merge_ranges(ranges: &mut Vec<Range<Anchor>>, buffer: &MultiBufferSnapshot) {
#[cfg(test)]
mod tests {
use super::*;
- use crate::FakeCompletionProvider;
+ use completion::FakeCompletionProvider;
use futures::stream::{self};
use gpui::{Context, TestAppContext};
use indoc::indoc;
@@ -23,7 +23,7 @@ impl RenderOnce for ModelSelector {
.with_handle(self.handle)
.menu(move |cx| {
ContextMenu::build(cx, |mut menu, cx| {
- for model in CompletionProvider::global(cx).available_models(cx) {
+ for model in CompletionProvider::global(cx).available_models() {
menu = menu.custom_entry(
{
let model = model.clone();
@@ -1,6 +1,6 @@
use crate::{
slash_command::SlashCommandCompletionProvider, AssistantPanel, CompletionProvider,
- InlineAssist, InlineAssistant, LanguageModelRequest, LanguageModelRequestMessage, Role,
+ InlineAssist, InlineAssistant,
};
use anyhow::{anyhow, Result};
use assets::Assets;
@@ -19,6 +19,7 @@ use gpui::{
};
use heed::{types::SerdeBincode, Database, RoTxn};
use language::{language_settings::SoftWrap, Buffer, LanguageRegistry};
+use language_model::{LanguageModelRequest, LanguageModelRequestMessage, Role};
use parking_lot::RwLock;
use picker::{Picker, PickerDelegate};
use rope::Rope;
@@ -1,7 +1,7 @@
use crate::{
assistant_settings::AssistantSettings, humanize_token_count,
prompts::generate_terminal_assistant_prompt, AssistantPanel, AssistantPanelEvent,
- CompletionProvider, LanguageModelRequest, LanguageModelRequestMessage, Role,
+ CompletionProvider,
};
use anyhow::{Context as _, Result};
use client::telemetry::Telemetry;
@@ -17,6 +17,7 @@ use gpui::{
Model, ModelContext, Subscription, Task, TextStyle, UpdateGlobal, View, WeakView, WhiteSpace,
};
use language::Buffer;
+use language_model::{LanguageModelRequest, LanguageModelRequestMessage, Role};
use settings::{update_settings_file, Settings};
use std::{
cmp,
@@ -558,8 +559,7 @@ impl Render for PromptEditor {
PopoverMenu::new("model-switcher")
.menu(move |cx| {
ContextMenu::build(cx, |mut menu, cx| {
- for model in CompletionProvider::global(cx).available_models(cx)
- {
+ for model in CompletionProvider::global(cx).available_models() {
menu = menu.custom_entry(
{
let model = model.clone();
@@ -30,6 +30,7 @@ chrono.workspace = true
clock.workspace = true
clickhouse.workspace = true
collections.workspace = true
+completion.workspace = true
dashmap = "5.4"
envy = "0.4.2"
futures.workspace = true
@@ -79,6 +80,7 @@ channel.workspace = true
client = { workspace = true, features = ["test-support"] }
collab_ui = { workspace = true, features = ["test-support"] }
collections = { workspace = true, features = ["test-support"] }
+completion = { workspace = true, features = ["test-support"] }
ctor.workspace = true
editor = { workspace = true, features = ["test-support"] }
env_logger.workspace = true
@@ -295,7 +295,7 @@ impl TestServer {
menu::init();
dev_server_projects::init(client.clone(), cx);
settings::KeymapFile::load_asset(os_keymap, cx).unwrap();
- assistant::FakeCompletionProvider::setup_test(cx);
+ completion::FakeCompletionProvider::setup_test(cx);
assistant::context_store::init(&client);
});
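
The model-selector call sites above now read models straight off the provider, so `available_models` no longer takes a `cx` argument. A minimal sketch of the new call shape, using only items shown in this diff (the `list_model_names` helper is hypothetical):

```rust
// Hypothetical helper: enumerate display names from the `completion`
// crate's global provider (no `cx` argument on `available_models` anymore).
fn list_model_names(cx: &gpui::AppContext) -> Vec<String> {
    completion::CompletionProvider::global(cx)
        .available_models()
        .iter()
        .map(|model| model.display_name())
        .collect()
}
```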
@@ -0,0 +1,56 @@
+[package]
+name = "completion"
+version = "0.1.0"
+edition = "2021"
+publish = false
+license = "GPL-3.0-or-later"
+
+[lints]
+workspace = true
+
+[lib]
+path = "src/completion.rs"
+doctest = false
+
+[features]
+test-support = [
+ "editor/test-support",
+ "language/test-support",
+ "project/test-support",
+ "text/test-support",
+]
+
+[dependencies]
+anthropic = { workspace = true, features = ["schemars"] }
+anyhow.workspace = true
+client.workspace = true
+collections.workspace = true
+editor.workspace = true
+futures.workspace = true
+gpui.workspace = true
+http.workspace = true
+language_model.workspace = true
+log.workspace = true
+menu.workspace = true
+ollama = { workspace = true, features = ["schemars"] }
+open_ai = { workspace = true, features = ["schemars"] }
+parking_lot.workspace = true
+serde.workspace = true
+serde_json.workspace = true
+settings.workspace = true
+smol.workspace = true
+strum.workspace = true
+theme.workspace = true
+tiktoken-rs.workspace = true
+ui.workspace = true
+util.workspace = true
+
+[dev-dependencies]
+ctor.workspace = true
+editor = { workspace = true, features = ["test-support"] }
+env_logger.workspace = true
+language = { workspace = true, features = ["test-support"] }
+project = { workspace = true, features = ["test-support"] }
+rand.workspace = true
+text = { workspace = true, features = ["test-support"] }
+unindent.workspace = true
@@ -0,0 +1 @@
+../../LICENSE-GPL
@@ -1,14 +1,12 @@
-use crate::{
- assistant_settings::AnthropicModel, CompletionProvider, LanguageModel, LanguageModelRequest,
- Role,
-};
-use crate::{count_open_ai_tokens, LanguageModelCompletionProvider, LanguageModelRequestMessage};
-use anthropic::{stream_completion, Request, RequestMessage};
+use crate::{count_open_ai_tokens, LanguageModelCompletionProvider};
+use crate::{CompletionProvider, LanguageModel, LanguageModelRequest};
+use anthropic::{stream_completion, Model as AnthropicModel, Request, RequestMessage};
use anyhow::{anyhow, Result};
use editor::{Editor, EditorElement, EditorStyle};
use futures::{future::BoxFuture, stream::BoxStream, FutureExt, StreamExt};
use gpui::{AnyView, AppContext, FontStyle, Task, TextStyle, View, WhiteSpace};
use http::HttpClient;
+use language_model::Role;
use settings::Settings;
use std::time::Duration;
use std::{env, sync::Arc};
@@ -27,7 +25,7 @@ pub struct AnthropicCompletionProvider {
}
impl LanguageModelCompletionProvider for AnthropicCompletionProvider {
- fn available_models(&self, _cx: &AppContext) -> Vec<LanguageModel> {
+ fn available_models(&self) -> Vec<LanguageModel> {
AnthropicModel::iter()
.map(LanguageModel::Anthropic)
.collect()
@@ -176,7 +174,7 @@ impl AnthropicCompletionProvider {
}
fn to_anthropic_request(&self, mut request: LanguageModelRequest) -> Request {
- preprocess_anthropic_request(&mut request);
+ request.preprocess_anthropic();
let model = match request.model {
LanguageModel::Anthropic(model) => model,
@@ -213,49 +211,6 @@ impl AnthropicCompletionProvider {
}
}
-pub fn preprocess_anthropic_request(request: &mut LanguageModelRequest) {
- let mut new_messages: Vec<LanguageModelRequestMessage> = Vec::new();
- let mut system_message = String::new();
-
- for message in request.messages.drain(..) {
- if message.content.is_empty() {
- continue;
- }
-
- match message.role {
- Role::User | Role::Assistant => {
- if let Some(last_message) = new_messages.last_mut() {
- if last_message.role == message.role {
- last_message.content.push_str("\n\n");
- last_message.content.push_str(&message.content);
- continue;
- }
- }
-
- new_messages.push(message);
- }
- Role::System => {
- if !system_message.is_empty() {
- system_message.push_str("\n\n");
- }
- system_message.push_str(&message.content);
- }
- }
- }
-
- if !system_message.is_empty() {
- new_messages.insert(
- 0,
- LanguageModelRequestMessage {
- role: Role::System,
- content: system_message,
- },
- );
- }
-
- request.messages = new_messages;
-}
-
struct AuthenticationPrompt {
api_key: View<Editor>,
api_url: String,
@@ -1,11 +1,12 @@
use crate::{
- assistant_settings::CloudModel, count_open_ai_tokens, CompletionProvider, LanguageModel,
- LanguageModelCompletionProvider, LanguageModelRequest,
+ count_open_ai_tokens, CompletionProvider, LanguageModel, LanguageModelCompletionProvider,
+ LanguageModelRequest,
};
use anyhow::{anyhow, Result};
use client::{proto, Client};
use futures::{future::BoxFuture, stream::BoxStream, FutureExt, StreamExt, TryFutureExt};
use gpui::{AnyView, AppContext, Task};
+use language_model::CloudModel;
use std::{future, sync::Arc};
use strum::IntoEnumIterator;
use ui::prelude::*;
@@ -52,7 +53,7 @@ impl CloudCompletionProvider {
}
impl LanguageModelCompletionProvider for CloudCompletionProvider {
- fn available_models(&self, _cx: &AppContext) -> Vec<LanguageModel> {
+ fn available_models(&self) -> Vec<LanguageModel> {
let mut custom_model = if let CloudModel::Custom(custom_model) = self.model.clone() {
Some(custom_model)
} else {
@@ -6,52 +6,19 @@ mod ollama;
mod open_ai;
pub use anthropic::*;
+use anyhow::Result;
+use client::Client;
pub use cloud::*;
#[cfg(any(test, feature = "test-support"))]
pub use fake::*;
+use futures::{future::BoxFuture, stream::BoxStream, StreamExt};
+use gpui::{AnyView, AppContext, Task, WindowContext};
+use language_model::{LanguageModel, LanguageModelRequest};
pub use ollama::*;
pub use open_ai::*;
use parking_lot::RwLock;
use smol::lock::{Semaphore, SemaphoreGuardArc};
-
-use crate::{
- assistant_settings::{AssistantProvider, AssistantSettings},
- LanguageModel, LanguageModelRequest,
-};
-use anyhow::Result;
-use client::Client;
-use futures::{future::BoxFuture, stream::BoxStream, StreamExt};
-use gpui::{AnyView, AppContext, BorrowAppContext, Task, WindowContext};
-use settings::{Settings, SettingsStore};
-use std::{any::Any, pin::Pin, sync::Arc, task::Poll, time::Duration};
-
-/// Choose which model to use for openai provider.
-/// If the model is not available, try to use the first available model, or fallback to the original model.
-fn choose_openai_model(
- model: &::open_ai::Model,
- available_models: &[::open_ai::Model],
-) -> ::open_ai::Model {
- available_models
- .iter()
- .find(|&m| m == model)
- .or_else(|| available_models.first())
- .unwrap_or_else(|| model)
- .clone()
-}
-
-pub fn init(client: Arc<Client>, cx: &mut AppContext) {
- let provider = create_provider_from_settings(client.clone(), 0, cx);
- cx.set_global(CompletionProvider::new(provider, Some(client)));
-
- let mut settings_version = 0;
- cx.observe_global::<SettingsStore>(move |cx| {
- settings_version += 1;
- cx.update_global::<CompletionProvider, _>(|provider, cx| {
- provider.update_settings(settings_version, cx);
- })
- })
- .detach();
-}
+use std::{any::Any, pin::Pin, sync::Arc, task::Poll};
pub struct CompletionResponse {
inner: BoxStream<'static, Result<String>>,
@@ -70,7 +37,7 @@ impl futures::Stream for CompletionResponse {
}
pub trait LanguageModelCompletionProvider: Send + Sync {
- fn available_models(&self, cx: &AppContext) -> Vec<LanguageModel>;
+ fn available_models(&self) -> Vec<LanguageModel>;
fn settings_version(&self) -> usize;
fn is_authenticated(&self) -> bool;
fn authenticate(&self, cx: &AppContext) -> Task<Result<()>>;
@@ -110,8 +77,8 @@ impl CompletionProvider {
}
}
- pub fn available_models(&self, cx: &AppContext) -> Vec<LanguageModel> {
- self.provider.read().available_models(cx)
+ pub fn available_models(&self) -> Vec<LanguageModel> {
+ self.provider.read().available_models()
}
pub fn settings_version(&self) -> usize {
@@ -176,6 +143,17 @@ impl CompletionProvider {
Ok(completion)
})
}
+
+ pub fn update_provider(
+ &mut self,
+ get_provider: impl FnOnce(Arc<Client>) -> Arc<RwLock<dyn LanguageModelCompletionProvider>>,
+ ) {
+ if let Some(client) = &self.client {
+ self.provider = get_provider(Arc::clone(client));
+ } else {
+ log::warn!("completion provider cannot be updated because its client was not set");
+ }
+ }
}
impl gpui::Global for CompletionProvider {}
@@ -196,109 +174,6 @@ impl CompletionProvider {
None
}
}
-
- pub fn update_settings(&mut self, version: usize, cx: &mut AppContext) {
- let updated = match &AssistantSettings::get_global(cx).provider {
- AssistantProvider::ZedDotDev { model } => self
- .update_current_as::<_, CloudCompletionProvider>(|provider| {
- provider.update(model.clone(), version);
- }),
- AssistantProvider::OpenAi {
- model,
- api_url,
- low_speed_timeout_in_seconds,
- available_models,
- } => self.update_current_as::<_, OpenAiCompletionProvider>(|provider| {
- provider.update(
- choose_openai_model(&model, &available_models),
- api_url.clone(),
- low_speed_timeout_in_seconds.map(Duration::from_secs),
- version,
- );
- }),
- AssistantProvider::Anthropic {
- model,
- api_url,
- low_speed_timeout_in_seconds,
- } => self.update_current_as::<_, AnthropicCompletionProvider>(|provider| {
- provider.update(
- model.clone(),
- api_url.clone(),
- low_speed_timeout_in_seconds.map(Duration::from_secs),
- version,
- );
- }),
- AssistantProvider::Ollama {
- model,
- api_url,
- low_speed_timeout_in_seconds,
- } => self.update_current_as::<_, OllamaCompletionProvider>(|provider| {
- provider.update(
- model.clone(),
- api_url.clone(),
- low_speed_timeout_in_seconds.map(Duration::from_secs),
- version,
- cx,
- );
- }),
- };
-
- // Previously configured provider was changed to another one
- if updated.is_none() {
- if let Some(client) = self.client.clone() {
- self.provider = create_provider_from_settings(client, version, cx);
- } else {
- log::warn!("completion provider cannot be created because client is not set");
- }
- }
- }
-}
-
-fn create_provider_from_settings(
- client: Arc<Client>,
- settings_version: usize,
- cx: &mut AppContext,
-) -> Arc<RwLock<dyn LanguageModelCompletionProvider>> {
- match &AssistantSettings::get_global(cx).provider {
- AssistantProvider::ZedDotDev { model } => Arc::new(RwLock::new(
- CloudCompletionProvider::new(model.clone(), client.clone(), settings_version, cx),
- )),
- AssistantProvider::OpenAi {
- model,
- api_url,
- low_speed_timeout_in_seconds,
- available_models,
- } => Arc::new(RwLock::new(OpenAiCompletionProvider::new(
- choose_openai_model(&model, &available_models),
- api_url.clone(),
- client.http_client(),
- low_speed_timeout_in_seconds.map(Duration::from_secs),
- settings_version,
- ))),
- AssistantProvider::Anthropic {
- model,
- api_url,
- low_speed_timeout_in_seconds,
- } => Arc::new(RwLock::new(AnthropicCompletionProvider::new(
- model.clone(),
- api_url.clone(),
- client.http_client(),
- low_speed_timeout_in_seconds.map(Duration::from_secs),
- settings_version,
- ))),
- AssistantProvider::Ollama {
- model,
- api_url,
- low_speed_timeout_in_seconds,
- } => Arc::new(RwLock::new(OllamaCompletionProvider::new(
- model.clone(),
- api_url.clone(),
- client.http_client(),
- low_speed_timeout_in_seconds.map(Duration::from_secs),
- settings_version,
- cx,
- ))),
- }
}
#[cfg(test)]
@@ -311,8 +186,8 @@ mod tests {
use smol::stream::StreamExt;
use crate::{
- completion_provider::MAX_CONCURRENT_COMPLETION_REQUESTS, CompletionProvider,
- FakeCompletionProvider, LanguageModelRequest,
+ CompletionProvider, FakeCompletionProvider, LanguageModelRequest,
+ MAX_CONCURRENT_COMPLETION_REQUESTS,
};
#[gpui::test]
@@ -62,7 +62,7 @@ impl FakeCompletionProvider {
}
impl LanguageModelCompletionProvider for FakeCompletionProvider {
- fn available_models(&self, _cx: &AppContext) -> Vec<LanguageModel> {
+ fn available_models(&self) -> Vec<LanguageModel> {
vec![LanguageModel::default()]
}
@@ -1,15 +1,14 @@
use crate::LanguageModelCompletionProvider;
-use crate::{
- assistant_settings::OllamaModel, CompletionProvider, LanguageModel, LanguageModelRequest, Role,
-};
+use crate::{CompletionProvider, LanguageModel, LanguageModelRequest};
use anyhow::Result;
use futures::StreamExt as _;
use futures::{future::BoxFuture, stream::BoxStream, FutureExt};
use gpui::{AnyView, AppContext, Task};
use http::HttpClient;
+use language_model::Role;
+use ollama::Model as OllamaModel;
use ollama::{
get_models, preload_model, stream_chat_completion, ChatMessage, ChatOptions, ChatRequest,
- Role as OllamaRole,
};
use std::sync::Arc;
use std::time::Duration;
@@ -28,7 +27,7 @@ pub struct OllamaCompletionProvider {
}
impl LanguageModelCompletionProvider for OllamaCompletionProvider {
- fn available_models(&self, _cx: &AppContext) -> Vec<LanguageModel> {
+ fn available_models(&self) -> Vec<LanguageModel> {
self.available_models
.iter()
.map(|m| LanguageModel::Ollama(m.clone()))
@@ -262,16 +261,6 @@ impl OllamaCompletionProvider {
}
}
-impl From<Role> for ollama::Role {
- fn from(val: Role) -> Self {
- match val {
- Role::User => OllamaRole::User,
- Role::Assistant => OllamaRole::Assistant,
- Role::System => OllamaRole::System,
- }
- }
-}
-
struct DownloadOllamaMessage {
retry_connection: Box<dyn Fn(&mut WindowContext) -> Task<Result<()>>>,
}
@@ -1,15 +1,13 @@
-use crate::assistant_settings::CloudModel;
-use crate::assistant_settings::{AssistantProvider, AssistantSettings};
+use crate::CompletionProvider;
use crate::LanguageModelCompletionProvider;
-use crate::{
- assistant_settings::OpenAiModel, CompletionProvider, LanguageModel, LanguageModelRequest, Role,
-};
use anyhow::{anyhow, Result};
use editor::{Editor, EditorElement, EditorStyle};
use futures::{future::BoxFuture, stream::BoxStream, FutureExt, StreamExt};
use gpui::{AnyView, AppContext, FontStyle, Task, TextStyle, View, WhiteSpace};
use http::HttpClient;
-use open_ai::{stream_completion, Request, RequestMessage, Role as OpenAiRole};
+use language_model::{CloudModel, LanguageModel, LanguageModelRequest, Role};
+use open_ai::Model as OpenAiModel;
+use open_ai::{stream_completion, Request, RequestMessage};
use settings::Settings;
use std::time::Duration;
use std::{env, sync::Arc};
@@ -25,6 +23,7 @@ pub struct OpenAiCompletionProvider {
http_client: Arc<dyn HttpClient>,
low_speed_timeout: Option<Duration>,
settings_version: usize,
+ available_models_from_settings: Vec<OpenAiModel>,
}
impl OpenAiCompletionProvider {
@@ -34,6 +33,7 @@ impl OpenAiCompletionProvider {
http_client: Arc<dyn HttpClient>,
low_speed_timeout: Option<Duration>,
settings_version: usize,
+ available_models_from_settings: Vec<OpenAiModel>,
) -> Self {
Self {
api_key: None,
@@ -42,6 +42,7 @@ impl OpenAiCompletionProvider {
http_client,
low_speed_timeout,
settings_version,
+ available_models_from_settings,
}
}
@@ -92,30 +93,26 @@ impl OpenAiCompletionProvider {
}
impl LanguageModelCompletionProvider for OpenAiCompletionProvider {
- fn available_models(&self, cx: &AppContext) -> Vec<LanguageModel> {
- if let AssistantProvider::OpenAi {
- available_models, ..
- } = &AssistantSettings::get_global(cx).provider
- {
- if !available_models.is_empty() {
- return available_models
- .iter()
- .cloned()
- .map(LanguageModel::OpenAi)
- .collect();
- }
- }
- let available_models = if matches!(self.model, OpenAiModel::Custom { .. }) {
- vec![self.model.clone()]
+ fn available_models(&self) -> Vec<LanguageModel> {
+ if self.available_models_from_settings.is_empty() {
+ let available_models = if matches!(self.model, OpenAiModel::Custom { .. }) {
+ vec![self.model.clone()]
+ } else {
+ OpenAiModel::iter()
+ .filter(|model| !matches!(model, OpenAiModel::Custom { .. }))
+ .collect()
+ };
+ available_models
+ .into_iter()
+ .map(LanguageModel::OpenAi)
+ .collect()
} else {
- OpenAiModel::iter()
- .filter(|model| !matches!(model, OpenAiModel::Custom { .. }))
+ self.available_models_from_settings
+ .iter()
+ .cloned()
+ .map(LanguageModel::OpenAi)
.collect()
- };
- available_models
- .into_iter()
- .map(LanguageModel::OpenAi)
- .collect()
+ }
}
fn settings_version(&self) -> usize {
@@ -255,16 +252,6 @@ pub fn count_open_ai_tokens(
.boxed()
}
-impl From<Role> for open_ai::Role {
- fn from(val: Role) -> Self {
- match val {
- Role::User => OpenAiRole::User,
- Role::Assistant => OpenAiRole::Assistant,
- Role::System => OpenAiRole::System,
- }
- }
-}
-
struct AuthenticationPrompt {
api_key: View<Editor>,
api_url: String,
@@ -0,0 +1,41 @@
+[package]
+name = "language_model"
+version = "0.1.0"
+edition = "2021"
+publish = false
+license = "GPL-3.0-or-later"
+
+[lints]
+workspace = true
+
+[lib]
+path = "src/language_model.rs"
+doctest = false
+
+[features]
+test-support = [
+ "editor/test-support",
+ "language/test-support",
+ "project/test-support",
+ "text/test-support",
+]
+
+[dependencies]
+anthropic = { workspace = true, features = ["schemars"] }
+ollama = { workspace = true, features = ["schemars"] }
+open_ai = { workspace = true, features = ["schemars"] }
+schemars.workspace = true
+serde.workspace = true
+strum.workspace = true
+proto = { workspace = true, features = ["test-support"] }
+
+[dev-dependencies]
+ctor.workspace = true
+editor = { workspace = true, features = ["test-support"] }
+env_logger.workspace = true
+language = { workspace = true, features = ["test-support"] }
+log.workspace = true
+project = { workspace = true, features = ["test-support"] }
+rand.workspace = true
+text = { workspace = true, features = ["test-support"] }
+unindent.workspace = true
@@ -0,0 +1 @@
+../../LICENSE-GPL
@@ -0,0 +1,7 @@
+mod model;
+mod request;
+mod role;
+
+pub use model::*;
+pub use request::*;
+pub use role::*;
@@ -0,0 +1,160 @@
+use crate::LanguageModelRequest;
+pub use anthropic::Model as AnthropicModel;
+pub use ollama::Model as OllamaModel;
+pub use open_ai::Model as OpenAiModel;
+use schemars::{
+ schema::{InstanceType, Metadata, Schema, SchemaObject},
+ JsonSchema,
+};
+use serde::{
+ de::{self, Visitor},
+ Deserialize, Deserializer, Serialize, Serializer,
+};
+use std::fmt;
+use strum::{EnumIter, IntoEnumIterator};
+
+#[derive(Clone, Debug, Default, PartialEq, EnumIter)]
+pub enum CloudModel {
+ Gpt3Point5Turbo,
+ Gpt4,
+ Gpt4Turbo,
+ #[default]
+ Gpt4Omni,
+ Gpt4OmniMini,
+ Claude3_5Sonnet,
+ Claude3Opus,
+ Claude3Sonnet,
+ Claude3Haiku,
+ Gemini15Pro,
+ Gemini15Flash,
+ Custom(String),
+}
+
+impl Serialize for CloudModel {
+ fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
+ where
+ S: Serializer,
+ {
+ serializer.serialize_str(self.id())
+ }
+}
+
+impl<'de> Deserialize<'de> for CloudModel {
+ fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
+ where
+ D: Deserializer<'de>,
+ {
+ struct ZedDotDevModelVisitor;
+
+ impl<'de> Visitor<'de> for ZedDotDevModelVisitor {
+ type Value = CloudModel;
+
+ fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
+ formatter.write_str("a string for a ZedDotDevModel variant or a custom model")
+ }
+
+ fn visit_str<E>(self, value: &str) -> Result<Self::Value, E>
+ where
+ E: de::Error,
+ {
+ let model = CloudModel::iter()
+ .find(|model| model.id() == value)
+ .unwrap_or_else(|| CloudModel::Custom(value.to_string()));
+ Ok(model)
+ }
+ }
+
+ deserializer.deserialize_str(ZedDotDevModelVisitor)
+ }
+}
+
+impl JsonSchema for CloudModel {
+ fn schema_name() -> String {
+ "ZedDotDevModel".to_owned()
+ }
+
+ fn json_schema(_generator: &mut schemars::gen::SchemaGenerator) -> Schema {
+ let variants = CloudModel::iter()
+ .filter_map(|model| {
+ let id = model.id();
+ if id.is_empty() {
+ None
+ } else {
+ Some(id.to_string())
+ }
+ })
+ .collect::<Vec<_>>();
+ Schema::Object(SchemaObject {
+ instance_type: Some(InstanceType::String.into()),
+ enum_values: Some(variants.iter().map(|s| s.clone().into()).collect()),
+ metadata: Some(Box::new(Metadata {
+ title: Some("ZedDotDevModel".to_owned()),
+ default: Some(CloudModel::default().id().into()),
+ examples: variants.into_iter().map(Into::into).collect(),
+ ..Default::default()
+ })),
+ ..Default::default()
+ })
+ }
+}
+
+impl CloudModel {
+ pub fn id(&self) -> &str {
+ match self {
+ Self::Gpt3Point5Turbo => "gpt-3.5-turbo",
+ Self::Gpt4 => "gpt-4",
+ Self::Gpt4Turbo => "gpt-4-turbo-preview",
+ Self::Gpt4Omni => "gpt-4o",
+ Self::Gpt4OmniMini => "gpt-4o-mini",
+ Self::Claude3_5Sonnet => "claude-3-5-sonnet",
+ Self::Claude3Opus => "claude-3-opus",
+ Self::Claude3Sonnet => "claude-3-sonnet",
+ Self::Claude3Haiku => "claude-3-haiku",
+ Self::Gemini15Pro => "gemini-1.5-pro",
+ Self::Gemini15Flash => "gemini-1.5-flash",
+ Self::Custom(id) => id,
+ }
+ }
+
+ pub fn display_name(&self) -> &str {
+ match self {
+ Self::Gpt3Point5Turbo => "GPT 3.5 Turbo",
+ Self::Gpt4 => "GPT 4",
+ Self::Gpt4Turbo => "GPT 4 Turbo",
+ Self::Gpt4Omni => "GPT 4 Omni",
+ Self::Gpt4OmniMini => "GPT 4 Omni Mini",
+ Self::Claude3_5Sonnet => "Claude 3.5 Sonnet",
+ Self::Claude3Opus => "Claude 3 Opus",
+ Self::Claude3Sonnet => "Claude 3 Sonnet",
+ Self::Claude3Haiku => "Claude 3 Haiku",
+ Self::Gemini15Pro => "Gemini 1.5 Pro",
+ Self::Gemini15Flash => "Gemini 1.5 Flash",
+ Self::Custom(id) => id.as_str(),
+ }
+ }
+
+ pub fn max_token_count(&self) -> usize {
+ match self {
+ Self::Gpt3Point5Turbo => 2048,
+ Self::Gpt4 => 4096,
+ Self::Gpt4Turbo | Self::Gpt4Omni => 128000,
+ Self::Gpt4OmniMini => 128000,
+ Self::Claude3_5Sonnet
+ | Self::Claude3Opus
+ | Self::Claude3Sonnet
+ | Self::Claude3Haiku => 200000,
+ Self::Gemini15Pro => 128000,
+ Self::Gemini15Flash => 32000,
+ Self::Custom(_) => 4096, // TODO: Make this configurable
+ }
+ }
+
+ pub fn preprocess_request(&self, request: &mut LanguageModelRequest) {
+ match self {
+ Self::Claude3Opus | Self::Claude3Sonnet | Self::Claude3Haiku => {
+ request.preprocess_anthropic()
+ }
+ _ => {}
+ }
+ }
+}
@@ -0,0 +1,60 @@
+pub mod cloud_model;
+
+pub use anthropic::Model as AnthropicModel;
+pub use cloud_model::*;
+pub use ollama::Model as OllamaModel;
+pub use open_ai::Model as OpenAiModel;
+
+use serde::{Deserialize, Serialize};
+
+#[derive(Clone, Debug, Serialize, Deserialize, PartialEq)]
+pub enum LanguageModel {
+ Cloud(CloudModel),
+ OpenAi(OpenAiModel),
+ Anthropic(AnthropicModel),
+ Ollama(OllamaModel),
+}
+
+impl Default for LanguageModel {
+ fn default() -> Self {
+ LanguageModel::Cloud(CloudModel::default())
+ }
+}
+
+impl LanguageModel {
+ pub fn telemetry_id(&self) -> String {
+ match self {
+ LanguageModel::OpenAi(model) => format!("openai/{}", model.id()),
+ LanguageModel::Anthropic(model) => format!("anthropic/{}", model.id()),
+ LanguageModel::Cloud(model) => format!("zed.dev/{}", model.id()),
+ LanguageModel::Ollama(model) => format!("ollama/{}", model.id()),
+ }
+ }
+
+ pub fn display_name(&self) -> String {
+ match self {
+ LanguageModel::OpenAi(model) => model.display_name().into(),
+ LanguageModel::Anthropic(model) => model.display_name().into(),
+ LanguageModel::Cloud(model) => model.display_name().into(),
+ LanguageModel::Ollama(model) => model.display_name().into(),
+ }
+ }
+
+ pub fn max_token_count(&self) -> usize {
+ match self {
+ LanguageModel::OpenAi(model) => model.max_token_count(),
+ LanguageModel::Anthropic(model) => model.max_token_count(),
+ LanguageModel::Cloud(model) => model.max_token_count(),
+ LanguageModel::Ollama(model) => model.max_token_count(),
+ }
+ }
+
+ pub fn id(&self) -> &str {
+ match self {
+ LanguageModel::OpenAi(model) => model.id(),
+ LanguageModel::Anthropic(model) => model.id(),
+ LanguageModel::Cloud(model) => model.id(),
+ LanguageModel::Ollama(model) => model.id(),
+ }
+ }
+}
@@ -0,0 +1,110 @@
+use crate::{
+ model::{CloudModel, LanguageModel},
+ role::Role,
+};
+use serde::{Deserialize, Serialize};
+
+#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
+pub struct LanguageModelRequestMessage {
+ pub role: Role,
+ pub content: String,
+}
+
+impl LanguageModelRequestMessage {
+ pub fn to_proto(&self) -> proto::LanguageModelRequestMessage {
+ proto::LanguageModelRequestMessage {
+ role: self.role.to_proto() as i32,
+ content: self.content.clone(),
+ tool_calls: Vec::new(),
+ tool_call_id: None,
+ }
+ }
+}
+
+#[derive(Debug, Default, Serialize, Deserialize)]
+pub struct LanguageModelRequest {
+ pub model: LanguageModel,
+ pub messages: Vec<LanguageModelRequestMessage>,
+ pub stop: Vec<String>,
+ pub temperature: f32,
+}
+
+impl LanguageModelRequest {
+ pub fn to_proto(&self) -> proto::CompleteWithLanguageModel {
+ proto::CompleteWithLanguageModel {
+ model: self.model.id().to_string(),
+ messages: self.messages.iter().map(|m| m.to_proto()).collect(),
+ stop: self.stop.clone(),
+ temperature: self.temperature,
+ tool_choice: None,
+ tools: Vec::new(),
+ }
+ }
+
+ /// Before we send the request to the server, we can perform fixups on it appropriate to the model.
+ pub fn preprocess(&mut self) {
+ match &self.model {
+ LanguageModel::OpenAi(_) => {}
+ LanguageModel::Anthropic(_) => {}
+ LanguageModel::Ollama(_) => {}
+ LanguageModel::Cloud(model) => match model {
+ CloudModel::Claude3Opus
+ | CloudModel::Claude3Sonnet
+ | CloudModel::Claude3Haiku
+ | CloudModel::Claude3_5Sonnet => {
+ self.preprocess_anthropic();
+ }
+ _ => {}
+ },
+ }
+ }
+
+ pub fn preprocess_anthropic(&mut self) {
+ let mut new_messages: Vec<LanguageModelRequestMessage> = Vec::new();
+ let mut system_message = String::new();
+
+ for message in self.messages.drain(..) {
+ if message.content.is_empty() {
+ continue;
+ }
+
+ match message.role {
+ Role::User | Role::Assistant => {
+ if let Some(last_message) = new_messages.last_mut() {
+ if last_message.role == message.role {
+ last_message.content.push_str("\n\n");
+ last_message.content.push_str(&message.content);
+ continue;
+ }
+ }
+
+ new_messages.push(message);
+ }
+ Role::System => {
+ if !system_message.is_empty() {
+ system_message.push_str("\n\n");
+ }
+ system_message.push_str(&message.content);
+ }
+ }
+ }
+
+ if !system_message.is_empty() {
+ new_messages.insert(
+ 0,
+ LanguageModelRequestMessage {
+ role: Role::System,
+ content: system_message,
+ },
+ );
+ }
+
+ self.messages = new_messages;
+ }
+}
+
+#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
+pub struct LanguageModelResponseMessage {
+ pub role: Option<Role>,
+ pub content: Option<String>,
+}
@@ -0,0 +1,68 @@
+use serde::{Deserialize, Serialize};
+use std::fmt::{self, Display};
+
+#[derive(Clone, Copy, Serialize, Deserialize, Debug, Eq, PartialEq)]
+#[serde(rename_all = "lowercase")]
+pub enum Role {
+ User,
+ Assistant,
+ System,
+}
+
+impl Role {
+ pub fn from_proto(role: i32) -> Role {
+ match proto::LanguageModelRole::from_i32(role) {
+ Some(proto::LanguageModelRole::LanguageModelUser) => Role::User,
+ Some(proto::LanguageModelRole::LanguageModelAssistant) => Role::Assistant,
+ Some(proto::LanguageModelRole::LanguageModelSystem) => Role::System,
+ Some(proto::LanguageModelRole::LanguageModelTool) => Role::System,
+ None => Role::User,
+ }
+ }
+
+ pub fn to_proto(&self) -> proto::LanguageModelRole {
+ match self {
+ Role::User => proto::LanguageModelRole::LanguageModelUser,
+ Role::Assistant => proto::LanguageModelRole::LanguageModelAssistant,
+ Role::System => proto::LanguageModelRole::LanguageModelSystem,
+ }
+ }
+
+ pub fn cycle(self) -> Role {
+ match self {
+ Role::User => Role::Assistant,
+ Role::Assistant => Role::System,
+ Role::System => Role::User,
+ }
+ }
+}
+
+impl Display for Role {
+ fn fmt(&self, f: &mut fmt::Formatter<'_>) -> std::fmt::Result {
+ match self {
+ Role::User => write!(f, "user"),
+ Role::Assistant => write!(f, "assistant"),
+ Role::System => write!(f, "system"),
+ }
+ }
+}
+
+impl From<Role> for ollama::Role {
+ fn from(val: Role) -> Self {
+ match val {
+ Role::User => ollama::Role::User,
+ Role::Assistant => ollama::Role::Assistant,
+ Role::System => ollama::Role::System,
+ }
+ }
+}
+
+impl From<Role> for open_ai::Role {
+ fn from(val: Role) -> Self {
+ match val {
+ Role::User => open_ai::Role::User,
+ Role::Assistant => open_ai::Role::Assistant,
+ Role::System => open_ai::Role::System,
+ }
+ }
+}
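
With the request and role types extracted into `language_model`, callers can compose requests from that crate alone. A minimal sketch assuming the types above (default cloud model, fixed temperature, hypothetical `example_request` helper):

```rust
use language_model::{LanguageModel, LanguageModelRequest, LanguageModelRequestMessage, Role};

// Build a request from the extracted types and apply model-specific fixups.
fn example_request() -> LanguageModelRequest {
    let mut request = LanguageModelRequest {
        model: LanguageModel::default(),
        messages: vec![
            LanguageModelRequestMessage {
                role: Role::System,
                content: "You are a helpful assistant.".into(),
            },
            LanguageModelRequestMessage {
                role: Role::User,
                content: "Summarize this diff.".into(),
            },
        ],
        stop: Vec::new(),
        temperature: 1.0,
    };
    // For Claude-backed models, `preprocess` folds system messages into a
    // single leading system message (see `preprocess_anthropic` above).
    request.preprocess();
    request
}
```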
@@ -22,6 +22,7 @@ anyhow.workspace = true
client.workspace = true
clock.workspace = true
collections.workspace = true
+completion.workspace = true
fs.workspace = true
futures.workspace = true
futures-batch.workspace = true
@@ -1261,3 +1261,6 @@ mod tests {
);
}
}
+
+// See https://github.com/zed-industries/zed/pull/14823#discussion_r1684616398 for why this is here and when it should be removed.
+type _TODO = completion::CompletionProvider;
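
Tests outside the `assistant` crate now install the fake provider through `completion`'s `test-support` feature, as the collab test-server hunk above does. A hedged sketch of a standalone setup (the test name and assertion are hypothetical):

```rust
#[cfg(test)]
mod example_tests {
    use gpui::TestAppContext;

    // Hypothetical scaffold: install the fake global provider and read it back.
    #[gpui::test]
    fn installs_fake_provider(cx: &mut TestAppContext) {
        cx.update(|cx| {
            completion::FakeCompletionProvider::setup_test(cx);
            let models = completion::CompletionProvider::global(cx).available_models();
            assert!(!models.is_empty()); // the fake exposes `LanguageModel::default()`
        });
    }
}
```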