ollama.rs

   1use anyhow::{Result, anyhow};
   2use fs::Fs;
   3use futures::{FutureExt, StreamExt, future::BoxFuture, stream::BoxStream};
   4use futures::{Stream, TryFutureExt, stream};
   5use gpui::{AnyView, App, AsyncApp, Context, CursorStyle, Entity, Task};
   6use http_client::HttpClient;
   7use language_model::{
   8    ApiKeyState, AuthenticateError, EnvVar, IconOrSvg, LanguageModel, LanguageModelCompletionError,
   9    LanguageModelCompletionEvent, LanguageModelId, LanguageModelName, LanguageModelProvider,
  10    LanguageModelProviderId, LanguageModelProviderName, LanguageModelProviderState,
  11    LanguageModelRequest, LanguageModelRequestTool, LanguageModelToolChoice, LanguageModelToolUse,
  12    LanguageModelToolUseId, MessageContent, RateLimiter, Role, StopReason, TokenUsage, env_var,
  13};
  14use menu;
  15use ollama::{
  16    ChatMessage, ChatOptions, ChatRequest, ChatResponseDelta, OLLAMA_API_URL, OllamaFunctionCall,
  17    OllamaFunctionTool, OllamaToolCall, get_models, show_model, stream_chat_completion,
  18};
  19pub use settings::OllamaAvailableModel as AvailableModel;
  20use settings::{Settings, SettingsStore, update_settings_file};
  21use std::pin::Pin;
  22use std::sync::LazyLock;
  23use std::{collections::HashMap, sync::Arc};
  24use ui::{
  25    ButtonLike, ButtonLink, ConfiguredApiCard, ElevationIndex, List, ListBulletItem, Tooltip,
  26    prelude::*,
  27};
  28use ui_input::InputField;
  29
  30use crate::AllLanguageModelSettings;
  31
  32const OLLAMA_DOWNLOAD_URL: &str = "https://ollama.com/download";
  33const OLLAMA_LIBRARY_URL: &str = "https://ollama.com/library";
  34const OLLAMA_SITE: &str = "https://ollama.com/";
  35
  36const PROVIDER_ID: LanguageModelProviderId = LanguageModelProviderId::new("ollama");
  37const PROVIDER_NAME: LanguageModelProviderName = LanguageModelProviderName::new("Ollama");
  38
  39const API_KEY_ENV_VAR_NAME: &str = "OLLAMA_API_KEY";
  40static API_KEY_ENV_VAR: LazyLock<EnvVar> = env_var!(API_KEY_ENV_VAR_NAME);
  41
  42#[derive(Default, Debug, Clone, PartialEq)]
  43pub struct OllamaSettings {
  44    pub api_url: String,
  45    pub auto_discover: bool,
  46    pub available_models: Vec<AvailableModel>,
  47    pub context_window: Option<u64>,
  48}
  49
  50pub struct OllamaLanguageModelProvider {
  51    http_client: Arc<dyn HttpClient>,
  52    state: Entity<State>,
  53}
  54
  55pub struct State {
  56    api_key_state: ApiKeyState,
  57    http_client: Arc<dyn HttpClient>,
  58    fetched_models: Vec<ollama::Model>,
  59    fetch_model_task: Option<Task<Result<()>>>,
  60}
  61
  62impl State {
  63    fn is_authenticated(&self) -> bool {
  64        !self.fetched_models.is_empty()
  65    }
  66
  67    fn set_api_key(&mut self, api_key: Option<String>, cx: &mut Context<Self>) -> Task<Result<()>> {
  68        let api_url = OllamaLanguageModelProvider::api_url(cx);
  69        let task = self
  70            .api_key_state
  71            .store(api_url, api_key, |this| &mut this.api_key_state, cx);
  72
  73        self.fetched_models.clear();
  74        cx.spawn(async move |this, cx| {
  75            let result = task.await;
  76            this.update(cx, |this, cx| this.restart_fetch_models_task(cx))
  77                .ok();
  78            result
  79        })
  80    }
  81
  82    fn authenticate(&mut self, cx: &mut Context<Self>) -> Task<Result<(), AuthenticateError>> {
  83        let api_url = OllamaLanguageModelProvider::api_url(cx);
  84        let task = self
  85            .api_key_state
  86            .load_if_needed(api_url, |this| &mut this.api_key_state, cx);
  87
  88        // Always try to fetch models - if no API key is needed (local Ollama), it will work
  89        // If API key is needed and provided, it will work
  90        // If API key is needed and not provided, it will fail gracefully
  91        cx.spawn(async move |this, cx| {
  92            let result = task.await;
  93            this.update(cx, |this, cx| this.restart_fetch_models_task(cx))
  94                .ok();
  95            result
  96        })
  97    }
  98
  99    fn fetch_models(&mut self, cx: &mut Context<Self>) -> Task<Result<()>> {
 100        let http_client = Arc::clone(&self.http_client);
 101        let api_url = OllamaLanguageModelProvider::api_url(cx);
 102        let api_key = self.api_key_state.key(&api_url);
 103
 104        // As a proxy for the server being "authenticated", we'll check if its up by fetching the models
 105        cx.spawn(async move |this, cx| {
 106            let models = get_models(http_client.as_ref(), &api_url, api_key.as_deref()).await?;
 107
 108            let tasks = models
 109                .into_iter()
 110                // Since there is no metadata from the Ollama API
 111                // indicating which models are embedding models,
 112                // simply filter out models with "-embed" in their name
 113                .filter(|model| !model.name.contains("-embed"))
 114                .map(|model| {
 115                    let http_client = Arc::clone(&http_client);
 116                    let api_url = api_url.clone();
 117                    let api_key = api_key.clone();
 118                    async move {
 119                        let name = model.name.as_str();
 120                        let model =
 121                            show_model(http_client.as_ref(), &api_url, api_key.as_deref(), name)
 122                                .await?;
 123                        let ollama_model = ollama::Model::new(
 124                            name,
 125                            None,
 126                            model.context_length,
 127                            Some(model.supports_tools()),
 128                            Some(model.supports_vision()),
 129                            Some(model.supports_thinking()),
 130                        );
 131                        Ok(ollama_model)
 132                    }
 133                });
 134
 135            // Rate-limit capability fetches
 136            // since there is an arbitrary number of models available
 137            let mut ollama_models: Vec<_> = futures::stream::iter(tasks)
 138                .buffer_unordered(5)
 139                .collect::<Vec<Result<_>>>()
 140                .await
 141                .into_iter()
 142                .collect::<Result<Vec<_>>>()?;
 143
 144            ollama_models.sort_by(|a, b| a.name.cmp(&b.name));
 145
 146            this.update(cx, |this, cx| {
 147                this.fetched_models = ollama_models;
 148                cx.notify();
 149            })
 150        })
 151    }
 152
 153    fn restart_fetch_models_task(&mut self, cx: &mut Context<Self>) {
 154        let task = self.fetch_models(cx);
 155        self.fetch_model_task.replace(task);
 156    }
 157}
 158
 159impl OllamaLanguageModelProvider {
 160    pub fn new(http_client: Arc<dyn HttpClient>, cx: &mut App) -> Self {
 161        let this = Self {
 162            http_client: http_client.clone(),
 163            state: cx.new(|cx| {
 164                cx.observe_global::<SettingsStore>({
 165                    let mut last_settings = OllamaLanguageModelProvider::settings(cx).clone();
 166                    move |this: &mut State, cx| {
 167                        let current_settings = OllamaLanguageModelProvider::settings(cx);
 168                        let settings_changed = current_settings != &last_settings;
 169                        if settings_changed {
 170                            let url_changed = last_settings.api_url != current_settings.api_url;
 171                            last_settings = current_settings.clone();
 172                            if url_changed {
 173                                this.fetched_models.clear();
 174                                this.authenticate(cx).detach();
 175                            }
 176                            cx.notify();
 177                        }
 178                    }
 179                })
 180                .detach();
 181
 182                State {
 183                    http_client,
 184                    fetched_models: Default::default(),
 185                    fetch_model_task: None,
 186                    api_key_state: ApiKeyState::new(Self::api_url(cx), (*API_KEY_ENV_VAR).clone()),
 187                }
 188            }),
 189        };
 190        this
 191    }
 192
 193    fn settings(cx: &App) -> &OllamaSettings {
 194        &AllLanguageModelSettings::get_global(cx).ollama
 195    }
 196
 197    fn api_url(cx: &App) -> SharedString {
 198        let api_url = &Self::settings(cx).api_url;
 199        if api_url.is_empty() {
 200            OLLAMA_API_URL.into()
 201        } else {
 202            SharedString::new(api_url.as_str())
 203        }
 204    }
 205
 206    fn has_custom_url(cx: &App) -> bool {
 207        Self::settings(cx).api_url != OLLAMA_API_URL
 208    }
 209}
 210
 211impl LanguageModelProviderState for OllamaLanguageModelProvider {
 212    type ObservableEntity = State;
 213
 214    fn observable_entity(&self) -> Option<Entity<Self::ObservableEntity>> {
 215        Some(self.state.clone())
 216    }
 217}
 218
 219impl LanguageModelProvider for OllamaLanguageModelProvider {
 220    fn id(&self) -> LanguageModelProviderId {
 221        PROVIDER_ID
 222    }
 223
 224    fn name(&self) -> LanguageModelProviderName {
 225        PROVIDER_NAME
 226    }
 227
 228    fn icon(&self) -> IconOrSvg {
 229        IconOrSvg::Icon(IconName::AiOllama)
 230    }
 231
 232    fn default_model(&self, _: &App) -> Option<Arc<dyn LanguageModel>> {
 233        // We shouldn't try to select default model, because it might lead to a load call for an unloaded model.
 234        // In a constrained environment where user might not have enough resources it'll be a bad UX to select something
 235        // to load by default.
 236        None
 237    }
 238
 239    fn default_fast_model(&self, _: &App) -> Option<Arc<dyn LanguageModel>> {
 240        // See explanation for default_model.
 241        None
 242    }
 243
 244    fn provided_models(&self, cx: &App) -> Vec<Arc<dyn LanguageModel>> {
 245        let mut models: HashMap<String, ollama::Model> = HashMap::new();
 246        let settings = OllamaLanguageModelProvider::settings(cx);
 247
 248        // Add models from the Ollama API
 249        for model in self.state.read(cx).fetched_models.iter() {
 250            let mut model = model.clone();
 251            if let Some(context_window) = settings.context_window {
 252                model.max_tokens = context_window;
 253            }
 254            models.insert(model.name.clone(), model);
 255        }
 256
 257        // Override with available models from settings
 258        merge_settings_into_models(
 259            &mut models,
 260            &settings.available_models,
 261            settings.context_window,
 262        );
 263
 264        let mut models = models
 265            .into_values()
 266            .map(|model| {
 267                Arc::new(OllamaLanguageModel {
 268                    id: LanguageModelId::from(model.name.clone()),
 269                    model,
 270                    http_client: self.http_client.clone(),
 271                    request_limiter: RateLimiter::new(4),
 272                    state: self.state.clone(),
 273                }) as Arc<dyn LanguageModel>
 274            })
 275            .collect::<Vec<_>>();
 276        models.sort_by_key(|model| model.name());
 277        models
 278    }
 279
 280    fn is_authenticated(&self, cx: &App) -> bool {
 281        self.state.read(cx).is_authenticated()
 282    }
 283
 284    fn authenticate(&self, cx: &mut App) -> Task<Result<(), AuthenticateError>> {
 285        self.state.update(cx, |state, cx| state.authenticate(cx))
 286    }
 287
 288    fn configuration_view(
 289        &self,
 290        _target_agent: language_model::ConfigurationViewTargetAgent,
 291        window: &mut Window,
 292        cx: &mut App,
 293    ) -> AnyView {
 294        let state = self.state.clone();
 295        cx.new(|cx| ConfigurationView::new(state, window, cx))
 296            .into()
 297    }
 298
 299    fn reset_credentials(&self, cx: &mut App) -> Task<Result<()>> {
 300        self.state
 301            .update(cx, |state, cx| state.set_api_key(None, cx))
 302    }
 303}
 304
 305pub struct OllamaLanguageModel {
 306    id: LanguageModelId,
 307    model: ollama::Model,
 308    http_client: Arc<dyn HttpClient>,
 309    request_limiter: RateLimiter,
 310    state: Entity<State>,
 311}
 312
 313impl OllamaLanguageModel {
 314    fn to_ollama_request(&self, request: LanguageModelRequest) -> ChatRequest {
 315        let supports_vision = self.model.supports_vision.unwrap_or(false);
 316
 317        let mut messages = Vec::with_capacity(request.messages.len());
 318
 319        for mut msg in request.messages.into_iter() {
 320            let images = if supports_vision {
 321                msg.content
 322                    .iter()
 323                    .filter_map(|content| match content {
 324                        MessageContent::Image(image) => Some(image.source.to_string()),
 325                        _ => None,
 326                    })
 327                    .collect::<Vec<String>>()
 328            } else {
 329                vec![]
 330            };
 331
 332            match msg.role {
 333                Role::User => {
 334                    for tool_result in msg
 335                        .content
 336                        .extract_if(.., |x| matches!(x, MessageContent::ToolResult(..)))
 337                    {
 338                        match tool_result {
 339                            MessageContent::ToolResult(tool_result) => {
 340                                messages.push(ChatMessage::Tool {
 341                                    tool_name: tool_result.tool_name.to_string(),
 342                                    content: tool_result.content.to_str().unwrap_or("").to_string(),
 343                                })
 344                            }
 345                            _ => unreachable!("Only tool result should be extracted"),
 346                        }
 347                    }
 348                    if !msg.content.is_empty() {
 349                        messages.push(ChatMessage::User {
 350                            content: msg.string_contents(),
 351                            images: if images.is_empty() {
 352                                None
 353                            } else {
 354                                Some(images)
 355                            },
 356                        })
 357                    }
 358                }
 359                Role::Assistant => {
 360                    let content = msg.string_contents();
 361                    let mut thinking = None;
 362                    let mut tool_calls = Vec::new();
 363                    for content in msg.content.into_iter() {
 364                        match content {
 365                            MessageContent::Thinking { text, .. } if !text.is_empty() => {
 366                                thinking = Some(text)
 367                            }
 368                            MessageContent::ToolUse(tool_use) => {
 369                                tool_calls.push(OllamaToolCall {
 370                                    id: tool_use.id.to_string(),
 371                                    function: OllamaFunctionCall {
 372                                        name: tool_use.name.to_string(),
 373                                        arguments: tool_use.input,
 374                                    },
 375                                });
 376                            }
 377                            _ => (),
 378                        }
 379                    }
 380                    messages.push(ChatMessage::Assistant {
 381                        content,
 382                        tool_calls: Some(tool_calls),
 383                        images: if images.is_empty() {
 384                            None
 385                        } else {
 386                            Some(images)
 387                        },
 388                        thinking,
 389                    })
 390                }
 391                Role::System => messages.push(ChatMessage::System {
 392                    content: msg.string_contents(),
 393                }),
 394            }
 395        }
 396        ChatRequest {
 397            model: self.model.name.clone(),
 398            messages,
 399            keep_alive: self.model.keep_alive.clone().unwrap_or_default(),
 400            stream: true,
 401            options: Some(ChatOptions {
 402                num_ctx: Some(self.model.max_tokens),
 403                // Only send stop tokens if explicitly provided. When empty/None,
 404                // Ollama will use the model's default stop tokens from its Modelfile.
 405                // Sending an empty array would override and disable the defaults.
 406                stop: if request.stop.is_empty() {
 407                    None
 408                } else {
 409                    Some(request.stop)
 410                },
 411                temperature: request.temperature.or(Some(1.0)),
 412                ..Default::default()
 413            }),
 414            think: self
 415                .model
 416                .supports_thinking
 417                .map(|supports_thinking| supports_thinking && request.thinking_allowed),
 418            tools: if self.model.supports_tools.unwrap_or(false) {
 419                request.tools.into_iter().map(tool_into_ollama).collect()
 420            } else {
 421                vec![]
 422            },
 423        }
 424    }
 425}
 426
 427impl LanguageModel for OllamaLanguageModel {
 428    fn id(&self) -> LanguageModelId {
 429        self.id.clone()
 430    }
 431
 432    fn name(&self) -> LanguageModelName {
 433        LanguageModelName::from(self.model.display_name().to_string())
 434    }
 435
 436    fn provider_id(&self) -> LanguageModelProviderId {
 437        PROVIDER_ID
 438    }
 439
 440    fn provider_name(&self) -> LanguageModelProviderName {
 441        PROVIDER_NAME
 442    }
 443
 444    fn supports_tools(&self) -> bool {
 445        self.model.supports_tools.unwrap_or(false)
 446    }
 447
 448    fn supports_images(&self) -> bool {
 449        self.model.supports_vision.unwrap_or(false)
 450    }
 451
 452    fn supports_thinking(&self) -> bool {
 453        self.model.supports_thinking.unwrap_or(false)
 454    }
 455
 456    fn supports_tool_choice(&self, choice: LanguageModelToolChoice) -> bool {
 457        match choice {
 458            LanguageModelToolChoice::Auto => false,
 459            LanguageModelToolChoice::Any => false,
 460            LanguageModelToolChoice::None => false,
 461        }
 462    }
 463
 464    fn telemetry_id(&self) -> String {
 465        format!("ollama/{}", self.model.id())
 466    }
 467
 468    fn max_token_count(&self) -> u64 {
 469        self.model.max_token_count()
 470    }
 471
 472    fn count_tokens(
 473        &self,
 474        request: LanguageModelRequest,
 475        _cx: &App,
 476    ) -> BoxFuture<'static, Result<u64>> {
 477        // There is no endpoint for this _yet_ in Ollama
 478        // see: https://github.com/ollama/ollama/issues/1716 and https://github.com/ollama/ollama/issues/3582
 479        let token_count = request
 480            .messages
 481            .iter()
 482            .map(|msg| msg.string_contents().chars().count())
 483            .sum::<usize>()
 484            / 4;
 485
 486        async move { Ok(token_count as u64) }.boxed()
 487    }
 488
 489    fn stream_completion(
 490        &self,
 491        request: LanguageModelRequest,
 492        cx: &AsyncApp,
 493    ) -> BoxFuture<
 494        'static,
 495        Result<
 496            BoxStream<'static, Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>,
 497            LanguageModelCompletionError,
 498        >,
 499    > {
 500        let request = self.to_ollama_request(request);
 501
 502        let http_client = self.http_client.clone();
 503        let (api_key, api_url) = self.state.read_with(cx, |state, cx| {
 504            let api_url = OllamaLanguageModelProvider::api_url(cx);
 505            (state.api_key_state.key(&api_url), api_url)
 506        });
 507
 508        let future = self.request_limiter.stream(async move {
 509            let stream =
 510                stream_chat_completion(http_client.as_ref(), &api_url, api_key.as_deref(), request)
 511                    .await?;
 512            let stream = map_to_language_model_completion_events(stream);
 513            Ok(stream)
 514        });
 515
 516        future.map_ok(|f| f.boxed()).boxed()
 517    }
 518}
 519
 520fn map_to_language_model_completion_events(
 521    stream: Pin<Box<dyn Stream<Item = anyhow::Result<ChatResponseDelta>> + Send>>,
 522) -> impl Stream<Item = Result<LanguageModelCompletionEvent, LanguageModelCompletionError>> {
 523    struct State {
 524        stream: Pin<Box<dyn Stream<Item = anyhow::Result<ChatResponseDelta>> + Send>>,
 525        used_tools: bool,
 526    }
 527
 528    // We need to create a ToolUse and Stop event from a single
 529    // response from the original stream
 530    let stream = stream::unfold(
 531        State {
 532            stream,
 533            used_tools: false,
 534        },
 535        async move |mut state| {
 536            let response = state.stream.next().await?;
 537
 538            let delta = match response {
 539                Ok(delta) => delta,
 540                Err(e) => {
 541                    let event = Err(LanguageModelCompletionError::from(anyhow!(e)));
 542                    return Some((vec![event], state));
 543                }
 544            };
 545
 546            let mut events = Vec::new();
 547
 548            match delta.message {
 549                ChatMessage::User { content, images: _ } => {
 550                    events.push(Ok(LanguageModelCompletionEvent::Text(content)));
 551                }
 552                ChatMessage::System { content } => {
 553                    events.push(Ok(LanguageModelCompletionEvent::Text(content)));
 554                }
 555                ChatMessage::Tool { content, .. } => {
 556                    events.push(Ok(LanguageModelCompletionEvent::Text(content)));
 557                }
 558                ChatMessage::Assistant {
 559                    content,
 560                    tool_calls,
 561                    images: _,
 562                    thinking,
 563                } => {
 564                    if let Some(text) = thinking {
 565                        events.push(Ok(LanguageModelCompletionEvent::Thinking {
 566                            text,
 567                            signature: None,
 568                        }));
 569                    }
 570
 571                    if let Some(tool_call) = tool_calls.and_then(|v| v.into_iter().next()) {
 572                        let OllamaToolCall { id, function } = tool_call;
 573                        let event = LanguageModelCompletionEvent::ToolUse(LanguageModelToolUse {
 574                            id: LanguageModelToolUseId::from(id),
 575                            name: Arc::from(function.name),
 576                            raw_input: function.arguments.to_string(),
 577                            input: function.arguments,
 578                            is_input_complete: true,
 579                            thought_signature: None,
 580                        });
 581                        events.push(Ok(event));
 582                        state.used_tools = true;
 583                    } else if !content.is_empty() {
 584                        events.push(Ok(LanguageModelCompletionEvent::Text(content)));
 585                    }
 586                }
 587            };
 588
 589            if delta.done {
 590                events.push(Ok(LanguageModelCompletionEvent::UsageUpdate(TokenUsage {
 591                    input_tokens: delta.prompt_eval_count.unwrap_or(0),
 592                    output_tokens: delta.eval_count.unwrap_or(0),
 593                    cache_creation_input_tokens: 0,
 594                    cache_read_input_tokens: 0,
 595                })));
 596                if state.used_tools {
 597                    state.used_tools = false;
 598                    events.push(Ok(LanguageModelCompletionEvent::Stop(StopReason::ToolUse)));
 599                } else {
 600                    events.push(Ok(LanguageModelCompletionEvent::Stop(StopReason::EndTurn)));
 601                }
 602            }
 603
 604            Some((events, state))
 605        },
 606    );
 607
 608    stream.flat_map(futures::stream::iter)
 609}
 610
 611struct ConfigurationView {
 612    api_key_editor: Entity<InputField>,
 613    api_url_editor: Entity<InputField>,
 614    context_window_editor: Entity<InputField>,
 615    state: Entity<State>,
 616}
 617
 618impl ConfigurationView {
 619    pub fn new(state: Entity<State>, window: &mut Window, cx: &mut Context<Self>) -> Self {
 620        let api_key_editor = cx.new(|cx| InputField::new(window, cx, "63e02e...").label("API key"));
 621
 622        let api_url_editor = cx.new(|cx| {
 623            let input = InputField::new(window, cx, OLLAMA_API_URL).label("API URL");
 624            input.set_text(&OllamaLanguageModelProvider::api_url(cx), window, cx);
 625            input
 626        });
 627
 628        let context_window_editor = cx.new(|cx| {
 629            let input = InputField::new(window, cx, "8192").label("Context Window");
 630            if let Some(context_window) = OllamaLanguageModelProvider::settings(cx).context_window {
 631                input.set_text(&context_window.to_string(), window, cx);
 632            }
 633            input
 634        });
 635
 636        cx.observe(&state, |_, _, cx| {
 637            cx.notify();
 638        })
 639        .detach();
 640
 641        Self {
 642            api_key_editor,
 643            api_url_editor,
 644            context_window_editor,
 645            state,
 646        }
 647    }
 648
 649    fn retry_connection(&mut self, window: &mut Window, cx: &mut Context<Self>) {
 650        let has_api_url = OllamaLanguageModelProvider::has_custom_url(cx);
 651        let has_api_key = self
 652            .state
 653            .read_with(cx, |state, _| state.api_key_state.has_key());
 654        if !has_api_url {
 655            self.save_api_url(cx);
 656        }
 657        if !has_api_key {
 658            self.save_api_key(&Default::default(), window, cx);
 659        }
 660
 661        self.state.update(cx, |state, cx| {
 662            state.restart_fetch_models_task(cx);
 663        });
 664    }
 665
 666    fn save_api_key(&mut self, _: &menu::Confirm, window: &mut Window, cx: &mut Context<Self>) {
 667        let api_key = self.api_key_editor.read(cx).text(cx).trim().to_string();
 668        if api_key.is_empty() {
 669            return;
 670        }
 671
 672        // url changes can cause the editor to be displayed again
 673        self.api_key_editor
 674            .update(cx, |input, cx| input.set_text("", window, cx));
 675
 676        let state = self.state.clone();
 677        cx.spawn_in(window, async move |_, cx| {
 678            state
 679                .update(cx, |state, cx| state.set_api_key(Some(api_key), cx))
 680                .await
 681        })
 682        .detach_and_log_err(cx);
 683    }
 684
 685    fn reset_api_key(&mut self, window: &mut Window, cx: &mut Context<Self>) {
 686        self.api_key_editor
 687            .update(cx, |input, cx| input.set_text("", window, cx));
 688
 689        let state = self.state.clone();
 690        cx.spawn_in(window, async move |_, cx| {
 691            state
 692                .update(cx, |state, cx| state.set_api_key(None, cx))
 693                .await
 694        })
 695        .detach_and_log_err(cx);
 696
 697        cx.notify();
 698    }
 699
 700    fn save_api_url(&self, cx: &mut Context<Self>) {
 701        let api_url = self.api_url_editor.read(cx).text(cx).trim().to_string();
 702        let current_url = OllamaLanguageModelProvider::api_url(cx);
 703        if !api_url.is_empty() && &api_url != &current_url {
 704            let fs = <dyn Fs>::global(cx);
 705            update_settings_file(fs, cx, move |settings, _| {
 706                settings
 707                    .language_models
 708                    .get_or_insert_default()
 709                    .ollama
 710                    .get_or_insert_default()
 711                    .api_url = Some(api_url);
 712            });
 713        }
 714    }
 715
 716    fn reset_api_url(&mut self, window: &mut Window, cx: &mut Context<Self>) {
 717        self.api_url_editor
 718            .update(cx, |input, cx| input.set_text("", window, cx));
 719        let fs = <dyn Fs>::global(cx);
 720        update_settings_file(fs, cx, |settings, _cx| {
 721            if let Some(settings) = settings
 722                .language_models
 723                .as_mut()
 724                .and_then(|models| models.ollama.as_mut())
 725            {
 726                settings.api_url = Some(OLLAMA_API_URL.into());
 727            }
 728        });
 729        cx.notify();
 730    }
 731
 732    fn save_context_window(&mut self, cx: &mut Context<Self>) {
 733        let context_window_str = self
 734            .context_window_editor
 735            .read(cx)
 736            .text(cx)
 737            .trim()
 738            .to_string();
 739        let current_context_window = OllamaLanguageModelProvider::settings(cx).context_window;
 740
 741        if let Ok(context_window) = context_window_str.parse::<u64>() {
 742            if Some(context_window) != current_context_window {
 743                let fs = <dyn Fs>::global(cx);
 744                update_settings_file(fs, cx, move |settings, _| {
 745                    settings
 746                        .language_models
 747                        .get_or_insert_default()
 748                        .ollama
 749                        .get_or_insert_default()
 750                        .context_window = Some(context_window);
 751                });
 752            }
 753        } else if context_window_str.is_empty() && current_context_window.is_some() {
 754            let fs = <dyn Fs>::global(cx);
 755            update_settings_file(fs, cx, move |settings, _| {
 756                settings
 757                    .language_models
 758                    .get_or_insert_default()
 759                    .ollama
 760                    .get_or_insert_default()
 761                    .context_window = None;
 762            });
 763        }
 764    }
 765
 766    fn reset_context_window(&mut self, window: &mut Window, cx: &mut Context<Self>) {
 767        self.context_window_editor
 768            .update(cx, |input, cx| input.set_text("", window, cx));
 769        let fs = <dyn Fs>::global(cx);
 770        update_settings_file(fs, cx, |settings, _cx| {
 771            if let Some(settings) = settings
 772                .language_models
 773                .as_mut()
 774                .and_then(|models| models.ollama.as_mut())
 775            {
 776                settings.context_window = None;
 777            }
 778        });
 779        cx.notify();
 780    }
 781
 782    fn render_instructions(cx: &App) -> Div {
 783        v_flex()
 784            .gap_2()
 785            .child(Label::new(
 786                "Run LLMs locally on your machine with Ollama, or connect to an Ollama server. \
 787                Can provide access to Llama, Mistral, Gemma, and hundreds of other models.",
 788            ))
 789            .child(Label::new("To use local Ollama:"))
 790            .child(
 791                List::new()
 792                    .child(
 793                        ListBulletItem::new("")
 794                            .child(Label::new("Download and install Ollama from"))
 795                            .child(ButtonLink::new("ollama.com", "https://ollama.com/download")),
 796                    )
 797                    .child(
 798                        ListBulletItem::new("")
 799                            .child(Label::new("Start Ollama and download a model:"))
 800                            .child(Label::new("ollama run gpt-oss:20b").inline_code(cx)),
 801                    )
 802                    .child(ListBulletItem::new(
 803                        "Click 'Connect' below to start using Ollama in Zed",
 804                    )),
 805            )
 806            .child(Label::new(
 807                "Alternatively, you can connect to an Ollama server by specifying its \
 808                URL and API key (may not be required):",
 809            ))
 810    }
 811
 812    fn render_api_key_editor(&self, cx: &Context<Self>) -> impl IntoElement {
 813        let state = self.state.read(cx);
 814        let env_var_set = state.api_key_state.is_from_env_var();
 815        let configured_card_label = if env_var_set {
 816            format!("API key set in {API_KEY_ENV_VAR_NAME} environment variable.")
 817        } else {
 818            "API key configured".to_string()
 819        };
 820
 821        if !state.api_key_state.has_key() {
 822            v_flex()
 823              .on_action(cx.listener(Self::save_api_key))
 824              .child(self.api_key_editor.clone())
 825              .child(
 826                  Label::new(
 827                      format!("You can also set the {API_KEY_ENV_VAR_NAME} environment variable and restart Zed.")
 828                  )
 829                  .size(LabelSize::Small)
 830                  .color(Color::Muted),
 831              )
 832              .into_any_element()
 833        } else {
 834            ConfiguredApiCard::new(configured_card_label)
 835                .disabled(env_var_set)
 836                .on_click(cx.listener(|this, _, window, cx| this.reset_api_key(window, cx)))
 837                .when(env_var_set, |this| {
 838                    this.tooltip_label(format!("To reset your API key, unset the {API_KEY_ENV_VAR_NAME} environment variable."))
 839                })
 840                .into_any_element()
 841        }
 842    }
 843
 844    fn render_context_window_editor(&self, cx: &Context<Self>) -> Div {
 845        let settings = OllamaLanguageModelProvider::settings(cx);
 846        let custom_context_window_set = settings.context_window.is_some();
 847
 848        if custom_context_window_set {
 849            h_flex()
 850                .p_3()
 851                .justify_between()
 852                .rounded_md()
 853                .border_1()
 854                .border_color(cx.theme().colors().border)
 855                .bg(cx.theme().colors().elevated_surface_background)
 856                .child(
 857                    h_flex()
 858                        .gap_2()
 859                        .child(Icon::new(IconName::Check).color(Color::Success))
 860                        .child(v_flex().gap_1().child(Label::new(format!(
 861                            "Context Window: {}",
 862                            settings.context_window.unwrap()
 863                        )))),
 864                )
 865                .child(
 866                    Button::new("reset-context-window", "Reset")
 867                        .label_size(LabelSize::Small)
 868                        .start_icon(Icon::new(IconName::Undo).size(IconSize::Small))
 869                        .layer(ElevationIndex::ModalSurface)
 870                        .on_click(
 871                            cx.listener(|this, _, window, cx| {
 872                                this.reset_context_window(window, cx)
 873                            }),
 874                        ),
 875                )
 876        } else {
 877            v_flex()
 878                .on_action(
 879                    cx.listener(|this, _: &menu::Confirm, _window, cx| {
 880                        this.save_context_window(cx)
 881                    }),
 882                )
 883                .child(self.context_window_editor.clone())
 884                .child(
 885                    Label::new("Default: Model specific")
 886                        .size(LabelSize::Small)
 887                        .color(Color::Muted),
 888                )
 889        }
 890    }
 891
 892    fn render_api_url_editor(&self, cx: &Context<Self>) -> Div {
 893        let api_url = OllamaLanguageModelProvider::api_url(cx);
 894        let custom_api_url_set = api_url != OLLAMA_API_URL;
 895
 896        if custom_api_url_set {
 897            h_flex()
 898                .p_3()
 899                .justify_between()
 900                .rounded_md()
 901                .border_1()
 902                .border_color(cx.theme().colors().border)
 903                .bg(cx.theme().colors().elevated_surface_background)
 904                .child(
 905                    h_flex()
 906                        .gap_2()
 907                        .child(Icon::new(IconName::Check).color(Color::Success))
 908                        .child(v_flex().gap_1().child(Label::new(api_url))),
 909                )
 910                .child(
 911                    Button::new("reset-api-url", "Reset API URL")
 912                        .label_size(LabelSize::Small)
 913                        .start_icon(Icon::new(IconName::Undo).size(IconSize::Small))
 914                        .layer(ElevationIndex::ModalSurface)
 915                        .on_click(
 916                            cx.listener(|this, _, window, cx| this.reset_api_url(window, cx)),
 917                        ),
 918                )
 919        } else {
 920            v_flex()
 921                .on_action(cx.listener(|this, _: &menu::Confirm, _window, cx| {
 922                    this.save_api_url(cx);
 923                    cx.notify();
 924                }))
 925                .gap_2()
 926                .child(self.api_url_editor.clone())
 927        }
 928    }
 929}
 930
 931impl Render for ConfigurationView {
 932    fn render(&mut self, _: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
 933        let is_authenticated = self.state.read(cx).is_authenticated();
 934
 935        v_flex()
 936            .gap_2()
 937            .child(Self::render_instructions(cx))
 938            .child(self.render_api_url_editor(cx))
 939            .child(self.render_context_window_editor(cx))
 940            .child(self.render_api_key_editor(cx))
 941            .child(
 942                h_flex()
 943                    .w_full()
 944                    .justify_between()
 945                    .gap_2()
 946                    .child(
 947                        h_flex()
 948                            .w_full()
 949                            .gap_2()
 950                            .map(|this| {
 951                                if is_authenticated {
 952                                    this.child(
 953                                        Button::new("ollama-site", "Ollama")
 954                                            .style(ButtonStyle::Subtle)
 955                                            .end_icon(
 956                                                Icon::new(IconName::ArrowUpRight)
 957                                                    .size(IconSize::XSmall)
 958                                                    .color(Color::Muted),
 959                                            )
 960                                            .on_click(move |_, _, cx| cx.open_url(OLLAMA_SITE))
 961                                            .into_any_element(),
 962                                    )
 963                                } else {
 964                                    this.child(
 965                                        Button::new("download_ollama_button", "Download Ollama")
 966                                            .style(ButtonStyle::Subtle)
 967                                            .end_icon(
 968                                                Icon::new(IconName::ArrowUpRight)
 969                                                    .size(IconSize::XSmall)
 970                                                    .color(Color::Muted),
 971                                            )
 972                                            .on_click(move |_, _, cx| {
 973                                                cx.open_url(OLLAMA_DOWNLOAD_URL)
 974                                            })
 975                                            .into_any_element(),
 976                                    )
 977                                }
 978                            })
 979                            .child(
 980                                Button::new("view-models", "View All Models")
 981                                    .style(ButtonStyle::Subtle)
 982                                    .end_icon(
 983                                        Icon::new(IconName::ArrowUpRight)
 984                                            .size(IconSize::XSmall)
 985                                            .color(Color::Muted),
 986                                    )
 987                                    .on_click(move |_, _, cx| cx.open_url(OLLAMA_LIBRARY_URL)),
 988                            ),
 989                    )
 990                    .map(|this| {
 991                        if is_authenticated {
 992                            this.child(
 993                                ButtonLike::new("connected")
 994                                    .disabled(true)
 995                                    .cursor_style(CursorStyle::Arrow)
 996                                    .child(
 997                                        h_flex()
 998                                            .gap_2()
 999                                            .child(Icon::new(IconName::Check).color(Color::Success))
1000                                            .child(Label::new("Connected"))
1001                                            .into_any_element(),
1002                                    )
1003                                    .child(
1004                                        IconButton::new("refresh-models", IconName::RotateCcw)
1005                                            .tooltip(Tooltip::text("Refresh Models"))
1006                                            .on_click(cx.listener(|this, _, window, cx| {
1007                                                this.state.update(cx, |state, _| {
1008                                                    state.fetched_models.clear();
1009                                                });
1010                                                this.retry_connection(window, cx);
1011                                            })),
1012                                    ),
1013                            )
1014                        } else {
1015                            this.child(
1016                                Button::new("retry_ollama_models", "Connect")
1017                                    .start_icon(
1018                                        Icon::new(IconName::PlayOutlined).size(IconSize::XSmall),
1019                                    )
1020                                    .on_click(cx.listener(move |this, _, window, cx| {
1021                                        this.retry_connection(window, cx)
1022                                    })),
1023                            )
1024                        }
1025                    }),
1026            )
1027    }
1028}
1029
1030fn merge_settings_into_models(
1031    models: &mut HashMap<String, ollama::Model>,
1032    available_models: &[AvailableModel],
1033    context_window: Option<u64>,
1034) {
1035    for setting_model in available_models {
1036        if let Some(model) = models.get_mut(&setting_model.name) {
1037            if context_window.is_none() {
1038                model.max_tokens = setting_model.max_tokens;
1039            }
1040            model.display_name = setting_model.display_name.clone();
1041            model.keep_alive = setting_model.keep_alive.clone();
1042            model.supports_tools = setting_model.supports_tools;
1043            model.supports_vision = setting_model.supports_images;
1044            model.supports_thinking = setting_model.supports_thinking;
1045        } else {
1046            models.insert(
1047                setting_model.name.clone(),
1048                ollama::Model {
1049                    name: setting_model.name.clone(),
1050                    display_name: setting_model.display_name.clone(),
1051                    max_tokens: context_window.unwrap_or(setting_model.max_tokens),
1052                    keep_alive: setting_model.keep_alive.clone(),
1053                    supports_tools: setting_model.supports_tools,
1054                    supports_vision: setting_model.supports_images,
1055                    supports_thinking: setting_model.supports_thinking,
1056                },
1057            );
1058        }
1059    }
1060}
1061
1062fn tool_into_ollama(tool: LanguageModelRequestTool) -> ollama::OllamaTool {
1063    ollama::OllamaTool::Function {
1064        function: OllamaFunctionTool {
1065            name: tool.name,
1066            description: Some(tool.description),
1067            parameters: Some(tool.input_schema),
1068        },
1069    }
1070}
1071
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a model as it looks before settings are merged: a 4096-token limit and no
    /// display name, keep-alive, or capability information.
    fn discovered_model(name: &str) -> ollama::Model {
        ollama::Model {
            name: name.to_string(),
            display_name: None,
            max_tokens: 4096,
            keep_alive: None,
            supports_tools: None,
            supports_vision: None,
            supports_thinking: None,
        }
    }

    /// Builds a settings entry for `name` with the given display name and token limit
    /// that also declares tool support.
    fn configured_model(name: &str, display_name: &str, max_tokens: u64) -> AvailableModel {
        AvailableModel {
            name: name.to_string(),
            display_name: Some(display_name.to_string()),
            max_tokens,
            keep_alive: None,
            supports_tools: Some(true),
            supports_images: None,
            supports_thinking: None,
        }
    }

    #[test]
    fn test_merge_settings_preserves_display_names_for_similar_models() {
        // Regression test for https://github.com/zed-industries/zed/issues/43646
        // Models that share a base name (e.g. qwen2.5-coder:1.5b and qwen2.5-coder:3b) must
        // each receive their own display_name from settings, not an arbitrary one.
        const SMALL: &str = "qwen2.5-coder:1.5b";
        const LARGE: &str = "qwen2.5-coder:3b";

        let mut models: HashMap<String, ollama::Model> = [SMALL, LARGE]
            .into_iter()
            .map(|name| (name.to_string(), discovered_model(name)))
            .collect();

        let available_models = [
            configured_model(SMALL, "QWEN2.5 Coder 1.5B", 5000),
            configured_model(LARGE, "QWEN2.5 Coder 3B", 6000),
        ];

        merge_settings_into_models(&mut models, &available_models, None);

        let small = models.get(SMALL).expect("1.5b model missing");
        assert_eq!(
            small.display_name.as_deref(),
            Some("QWEN2.5 Coder 1.5B"),
            "1.5b model should have its own display_name"
        );
        assert_eq!(small.max_tokens, 5000);

        let large = models.get(LARGE).expect("3b model missing");
        assert_eq!(
            large.display_name.as_deref(),
            Some("QWEN2.5 Coder 3B"),
            "3b model should have its own display_name"
        );
        assert_eq!(large.max_tokens, 6000);
    }
}