@@ -7,10 +7,12 @@ use gpui::{DismissEvent, Entity, EventEmitter, FocusHandle, Focusable, Render, T
use language_model::LanguageModelRegistry;
use language_models::{
AllLanguageModelSettings, OpenAiCompatibleSettingsContent,
- provider::open_ai_compatible::AvailableModel,
+ provider::open_ai_compatible::{AvailableModel, ModelCapabilities},
};
use settings::update_settings_file;
-use ui::{Banner, KeyBinding, Modal, ModalFooter, ModalHeader, Section, prelude::*};
+use ui::{
+ Banner, Checkbox, KeyBinding, Modal, ModalFooter, ModalHeader, Section, ToggleState, prelude::*,
+};
use ui_input::SingleLineInput;
use workspace::{ModalView, Workspace};
@@ -69,11 +71,19 @@ impl AddLlmProviderInput {
}
}
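+/// Checkbox state for each optional model capability, mirroring the fields
+/// of `ModelCapabilities`.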
+struct ModelCapabilityToggles {
+ pub supports_tools: ToggleState,
+ pub supports_images: ToggleState,
+ pub supports_parallel_tool_calls: ToggleState,
+ pub supports_prompt_cache_key: ToggleState,
+}
+
struct ModelInput {
name: Entity<SingleLineInput>,
max_completion_tokens: Entity<SingleLineInput>,
max_output_tokens: Entity<SingleLineInput>,
max_tokens: Entity<SingleLineInput>,
+ capabilities: ModelCapabilityToggles,
}
impl ModelInput {
@@ -100,11 +110,23 @@ impl ModelInput {
cx,
);
let max_tokens = single_line_input("Max Tokens", "Max Tokens", Some("200000"), window, cx);
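+ // Seed the capability checkboxes from the provider's defaults.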
+ let ModelCapabilities {
+ tools,
+ images,
+ parallel_tool_calls,
+ prompt_cache_key,
+ } = ModelCapabilities::default();
Self {
name: model_name,
max_completion_tokens,
max_output_tokens,
max_tokens,
+ capabilities: ModelCapabilityToggles {
+ supports_tools: tools.into(),
+ supports_images: images.into(),
+ supports_parallel_tool_calls: parallel_tool_calls.into(),
+ supports_prompt_cache_key: prompt_cache_key.into(),
+ },
}
}
@@ -136,6 +158,12 @@ impl ModelInput {
.text(cx)
.parse::<u64>()
.map_err(|_| SharedString::from("Max Tokens must be a number"))?,
+ capabilities: ModelCapabilities {
+ tools: self.capabilities.supports_tools.selected(),
+ images: self.capabilities.supports_images.selected(),
+ parallel_tool_calls: self.capabilities.supports_parallel_tool_calls.selected(),
+ prompt_cache_key: self.capabilities.supports_prompt_cache_key.selected(),
+ },
})
}
}
@@ -322,6 +350,55 @@ impl AddLlmProviderModal {
.child(model.max_output_tokens.clone()),
)
.child(model.max_tokens.clone())
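+ // One checkbox per capability; toggling writes the new state back into
+ // this model's `ModelCapabilityToggles`.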
+ .child(
+ v_flex()
+ .gap_1()
+ .child(
+ Checkbox::new(("supports-tools", ix), model.capabilities.supports_tools)
+ .label("Supports tools")
+ .on_click(cx.listener(move |this, checked, _window, cx| {
+ this.input.models[ix].capabilities.supports_tools = *checked;
+ cx.notify();
+ })),
+ )
+ .child(
+ Checkbox::new(("supports-images", ix), model.capabilities.supports_images)
+ .label("Supports images")
+ .on_click(cx.listener(move |this, checked, _window, cx| {
+ this.input.models[ix].capabilities.supports_images = *checked;
+ cx.notify();
+ })),
+ )
+ .child(
+ Checkbox::new(
+ ("supports-parallel-tool-calls", ix),
+ model.capabilities.supports_parallel_tool_calls,
+ )
+ .label("Supports parallel_tool_calls")
+ .on_click(cx.listener(
+ move |this, checked, _window, cx| {
+ this.input.models[ix]
+ .capabilities
+ .supports_parallel_tool_calls = *checked;
+ cx.notify();
+ },
+ )),
+ )
+ .child(
+ Checkbox::new(
+ ("supports-prompt-cache-key", ix),
+ model.capabilities.supports_prompt_cache_key,
+ )
+ .label("Supports prompt_cache_key")
+ .on_click(cx.listener(
+ move |this, checked, _window, cx| {
+ this.input.models[ix].capabilities.supports_prompt_cache_key =
+ *checked;
+ cx.notify();
+ },
+ )),
+ ),
+ )
.when(has_more_than_one_model, |this| {
this.child(
Button::new(("remove-model", ix), "Remove Model")
@@ -562,6 +639,93 @@ mod tests {
);
}
+ #[gpui::test]
+ async fn test_model_input_default_capabilities(cx: &mut TestAppContext) {
+ let cx = setup_test(cx).await;
+
+ cx.update(|window, cx| {
+ let model_input = ModelInput::new(window, cx);
+ model_input.name.update(cx, |input, cx| {
+ input.editor().update(cx, |editor, cx| {
+ editor.set_text("somemodel", window, cx);
+ });
+ });
+ assert_eq!(
+ model_input.capabilities.supports_tools,
+ ToggleState::Selected
+ );
+ assert_eq!(
+ model_input.capabilities.supports_images,
+ ToggleState::Unselected
+ );
+ assert_eq!(
+ model_input.capabilities.supports_parallel_tool_calls,
+ ToggleState::Unselected
+ );
+ assert_eq!(
+ model_input.capabilities.supports_prompt_cache_key,
+ ToggleState::Unselected
+ );
+
+ let parsed_model = model_input.parse(cx).unwrap();
+ assert!(parsed_model.capabilities.tools);
+ assert!(!parsed_model.capabilities.images);
+ assert!(!parsed_model.capabilities.parallel_tool_calls);
+ assert!(!parsed_model.capabilities.prompt_cache_key);
+ });
+ }
+
+ #[gpui::test]
+ async fn test_model_input_deselected_capabilities(cx: &mut TestAppContext) {
+ let cx = setup_test(cx).await;
+
+ cx.update(|window, cx| {
+ let mut model_input = ModelInput::new(window, cx);
+ model_input.name.update(cx, |input, cx| {
+ input.editor().update(cx, |editor, cx| {
+ editor.set_text("somemodel", window, cx);
+ });
+ });
+
+ model_input.capabilities.supports_tools = ToggleState::Unselected;
+ model_input.capabilities.supports_images = ToggleState::Unselected;
+ model_input.capabilities.supports_parallel_tool_calls = ToggleState::Unselected;
+ model_input.capabilities.supports_prompt_cache_key = ToggleState::Unselected;
+
+ let parsed_model = model_input.parse(cx).unwrap();
+ assert!(!parsed_model.capabilities.tools);
+ assert!(!parsed_model.capabilities.images);
+ assert!(!parsed_model.capabilities.parallel_tool_calls);
+ assert!(!parsed_model.capabilities.prompt_cache_key);
+ });
+ }
+
+ #[gpui::test]
+ async fn test_model_input_with_name_and_capabilities(cx: &mut TestAppContext) {
+ let cx = setup_test(cx).await;
+
+ cx.update(|window, cx| {
+ let mut model_input = ModelInput::new(window, cx);
+ model_input.name.update(cx, |input, cx| {
+ input.editor().update(cx, |editor, cx| {
+ editor.set_text("somemodel", window, cx);
+ });
+ });
+
+ model_input.capabilities.supports_tools = ToggleState::Selected;
+ model_input.capabilities.supports_images = ToggleState::Unselected;
+ model_input.capabilities.supports_parallel_tool_calls = ToggleState::Selected;
+ model_input.capabilities.supports_prompt_cache_key = ToggleState::Unselected;
+
+ let parsed_model = model_input.parse(cx).unwrap();
+ assert_eq!(parsed_model.name, "somemodel");
+ assert!(parsed_model.capabilities.tools);
+ assert!(!parsed_model.capabilities.images);
+ assert!(parsed_model.capabilities.parallel_tool_calls);
+ assert!(!parsed_model.capabilities.prompt_cache_key);
+ });
+ }
+
async fn setup_test(cx: &mut TestAppContext) -> &mut VisualTestContext {
cx.update(|cx| {
let store = SettingsStore::test(cx);
@@ -38,6 +38,27 @@ pub struct AvailableModel {
pub max_tokens: u64,
pub max_output_tokens: Option<u64>,
pub max_completion_tokens: Option<u64>,
+ #[serde(default)]
+ pub capabilities: ModelCapabilities,
+}
+
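+/// Optional features an OpenAI-compatible model supports. The `capabilities`
+/// field on `AvailableModel` is `#[serde(default)]`, so omitting it in
+/// settings falls back to the `Default` impl below.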
+#[derive(Clone, Debug, PartialEq, Serialize, Deserialize, JsonSchema)]
+pub struct ModelCapabilities {
+ pub tools: bool,
+ pub images: bool,
+ pub parallel_tool_calls: bool,
+ pub prompt_cache_key: bool,
+}
+
+impl Default for ModelCapabilities {
+ fn default() -> Self {
+ Self {
+ tools: true,
+ images: false,
+ parallel_tool_calls: false,
+ prompt_cache_key: false,
+ }
+ }
}
pub struct OpenAiCompatibleLanguageModelProvider {
@@ -293,17 +314,17 @@ impl LanguageModel for OpenAiCompatibleLanguageModel {
}
fn supports_tools(&self) -> bool {
- true
+ self.model.capabilities.tools
}
fn supports_images(&self) -> bool {
- false
+ self.model.capabilities.images
}
fn supports_tool_choice(&self, choice: LanguageModelToolChoice) -> bool {
match choice {
- LanguageModelToolChoice::Auto => true,
- LanguageModelToolChoice::Any => true,
+ LanguageModelToolChoice::Auto => self.model.capabilities.tools,
+ LanguageModelToolChoice::Any => self.model.capabilities.tools,
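+ // Opting out of tool use is always supported.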
LanguageModelToolChoice::None => true,
}
}
@@ -355,13 +376,11 @@ impl LanguageModel for OpenAiCompatibleLanguageModel {
LanguageModelCompletionError,
>,
> {
- let supports_parallel_tool_call = true;
- let supports_prompt_cache_key = false;
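+ // Use the per-model capability flags instead of the previously
+ // hardcoded values.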
let request = into_open_ai(
request,
&self.model.name,
- supports_parallel_tool_call,
- supports_prompt_cache_key,
+ self.model.capabilities.parallel_tool_calls,
+ self.model.capabilities.prompt_cache_key,
self.max_output_tokens(),
None,
);
@@ -427,7 +427,7 @@ Custom models will be listed in the model dropdown in the Agent Panel.
Zed supports using [OpenAI compatible APIs](https://platform.openai.com/docs/api-reference/chat) by specifying a custom `api_url` and `available_models` for the OpenAI provider.
This is useful for connecting to other hosted services (like Together AI, Anyscale, etc.) or local models.
-You can add a custom, OpenAI-compatible model via either via the UI or by editing your `settings.json`.
+You can add a custom, OpenAI-compatible model either via the UI or by editing your `settings.json`.
To do it via the UI, go to the Agent Panel settings (`agent: open settings`) and look for the "Add Provider" button to the right of the "LLM Providers" section title.
Then, fill in the input fields in the modal.
@@ -443,7 +443,13 @@ To do it via your `settings.json`, add the following snippet under `language_mod
{
"name": "mistralai/Mixtral-8x7B-Instruct-v0.1",
"display_name": "Together Mixtral 8x7B",
- "max_tokens": 32768
+ "max_tokens": 32768,
+ "capabilities": {
+ "tools": true,
+ "images": false,
+ "parallel_tool_calls": false,
+ "prompt_cache_key": false
+ }
}
]
}
@@ -451,6 +457,13 @@ To do it via your `settings.json`, add the following snippet under `language_mod
}
```
+By default, OpenAI-compatible models inherit the following capabilities:
+
+- `tools`: true (supports tool/function calling)
+- `images`: false (does not support image inputs)
+- `parallel_tool_calls`: false (does not support the `parallel_tool_calls` parameter)
+- `prompt_cache_key`: false (does not support the `prompt_cache_key` parameter)
+
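+The `capabilities` object is optional; when omitted, the defaults above apply. For example, this entry is equivalent to the fully spelled-out one above:
+
+```json
+{
+  "name": "mistralai/Mixtral-8x7B-Instruct-v0.1",
+  "display_name": "Together Mixtral 8x7B",
+  "max_tokens": 32768
+}
+```
+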
Note that LLM API keys aren't stored in your settings file.
So, ensure the key is set as an environment variable (`OPENAI_API_KEY=<your api key>`) so your settings can pick it up.