ollama.rs

   1use anyhow::{Result, anyhow};
   2use credentials_provider::CredentialsProvider;
   3use fs::Fs;
   4use futures::{FutureExt, StreamExt, future::BoxFuture, stream::BoxStream};
   5use futures::{Stream, TryFutureExt, stream};
   6use gpui::{AnyView, App, AsyncApp, Context, CursorStyle, Entity, Task};
   7use http_client::HttpClient;
   8use language_model::{
   9    ApiKeyState, AuthenticateError, EnvVar, IconOrSvg, LanguageModel, LanguageModelCompletionError,
  10    LanguageModelCompletionEvent, LanguageModelId, LanguageModelName, LanguageModelProvider,
  11    LanguageModelProviderId, LanguageModelProviderName, LanguageModelProviderState,
  12    LanguageModelRequest, LanguageModelRequestTool, LanguageModelToolChoice, LanguageModelToolUse,
  13    LanguageModelToolUseId, MessageContent, RateLimiter, Role, StopReason, TokenUsage, env_var,
  14};
  15use menu;
  16use ollama::{
  17    ChatMessage, ChatOptions, ChatRequest, ChatResponseDelta, OLLAMA_API_URL, OllamaFunctionCall,
  18    OllamaFunctionTool, OllamaToolCall, get_models, show_model, stream_chat_completion,
  19};
  20pub use settings::OllamaAvailableModel as AvailableModel;
  21use settings::{Settings, SettingsStore, update_settings_file};
  22use std::pin::Pin;
  23use std::sync::LazyLock;
  24use std::{collections::HashMap, sync::Arc};
  25use ui::{
  26    ButtonLike, ButtonLink, ConfiguredApiCard, ElevationIndex, List, ListBulletItem, Tooltip,
  27    prelude::*,
  28};
  29use ui_input::InputField;
  30
  31use crate::AllLanguageModelSettings;
  32
/// URL of the Ollama download page.
const OLLAMA_DOWNLOAD_URL: &str = "https://ollama.com/download";
/// URL of the Ollama model library page.
const OLLAMA_LIBRARY_URL: &str = "https://ollama.com/library";
/// URL of the Ollama home page.
const OLLAMA_SITE: &str = "https://ollama.com/";

/// Provider identifier returned by the provider and by every model it exposes.
const PROVIDER_ID: LanguageModelProviderId = LanguageModelProviderId::new("ollama");
/// Provider display name returned by the provider and by every model it exposes.
const PROVIDER_NAME: LanguageModelProviderName = LanguageModelProviderName::new("Ollama");

/// Name of the environment variable wrapped by `API_KEY_ENV_VAR`.
const API_KEY_ENV_VAR_NAME: &str = "OLLAMA_API_KEY";
/// Lazily initialized `EnvVar` for `API_KEY_ENV_VAR_NAME`; a clone of it is handed to
/// `ApiKeyState::new` when the provider state is created.
static API_KEY_ENV_VAR: LazyLock<EnvVar> = env_var!(API_KEY_ENV_VAR_NAME);
  42
/// Settings for the Ollama provider, read from `AllLanguageModelSettings`.
#[derive(Default, Debug, Clone, PartialEq)]
pub struct OllamaSettings {
    /// Base URL of the Ollama server. An empty string makes the provider fall back to
    /// `OLLAMA_API_URL`.
    pub api_url: String,
    // NOTE(review): not read anywhere in the visible part of this file; presumably controls
    // whether models are discovered from the server — confirm against the rest of the file.
    pub auto_discover: bool,
    /// Models declared in settings; merged over the models fetched from the server.
    pub available_models: Vec<AvailableModel>,
    /// When set, replaces `max_tokens` of every fetched model.
    pub context_window: Option<u64>,
}
  50
/// Language model provider for an Ollama server.
pub struct OllamaLanguageModelProvider {
    /// HTTP client cloned into every `OllamaLanguageModel` this provider creates.
    http_client: Arc<dyn HttpClient>,
    /// Shared entity holding the API key state and the list of fetched models.
    state: Entity<State>,
}
  55
/// Provider state shared by the provider, its models, and the configuration view.
pub struct State {
    /// API key bookkeeping; queried with the current API URL to obtain the key.
    api_key_state: ApiKeyState,
    /// Credentials backend passed along to the `ApiKeyState` load/store operations.
    credentials_provider: Arc<dyn CredentialsProvider>,
    /// HTTP client used to list models and fetch their details.
    http_client: Arc<dyn HttpClient>,
    /// Models stored by the last successful `fetch_models` run, sorted by name.
    /// A non-empty list is what `is_authenticated` reports as "authenticated".
    fetched_models: Vec<ollama::Model>,
    /// Task handle of the most recently started model fetch.
    fetch_model_task: Option<Task<Result<()>>>,
}
  63
  64impl State {
  65    fn is_authenticated(&self) -> bool {
  66        !self.fetched_models.is_empty()
  67    }
  68
  69    fn set_api_key(&mut self, api_key: Option<String>, cx: &mut Context<Self>) -> Task<Result<()>> {
  70        let credentials_provider = self.credentials_provider.clone();
  71        let api_url = OllamaLanguageModelProvider::api_url(cx);
  72        let task = self.api_key_state.store(
  73            api_url,
  74            api_key,
  75            |this| &mut this.api_key_state,
  76            credentials_provider,
  77            cx,
  78        );
  79
  80        self.fetched_models.clear();
  81        cx.spawn(async move |this, cx| {
  82            let result = task.await;
  83            this.update(cx, |this, cx| this.restart_fetch_models_task(cx))
  84                .ok();
  85            result
  86        })
  87    }
  88
  89    fn authenticate(&mut self, cx: &mut Context<Self>) -> Task<Result<(), AuthenticateError>> {
  90        let credentials_provider = self.credentials_provider.clone();
  91        let api_url = OllamaLanguageModelProvider::api_url(cx);
  92        let task = self.api_key_state.load_if_needed(
  93            api_url,
  94            |this| &mut this.api_key_state,
  95            credentials_provider,
  96            cx,
  97        );
  98
  99        // Always try to fetch models - if no API key is needed (local Ollama), it will work
 100        // If API key is needed and provided, it will work
 101        // If API key is needed and not provided, it will fail gracefully
 102        cx.spawn(async move |this, cx| {
 103            let result = task.await;
 104            this.update(cx, |this, cx| this.restart_fetch_models_task(cx))
 105                .ok();
 106            result
 107        })
 108    }
 109
 110    fn fetch_models(&mut self, cx: &mut Context<Self>) -> Task<Result<()>> {
 111        let http_client = Arc::clone(&self.http_client);
 112        let api_url = OllamaLanguageModelProvider::api_url(cx);
 113        let api_key = self.api_key_state.key(&api_url);
 114
 115        // As a proxy for the server being "authenticated", we'll check if its up by fetching the models
 116        cx.spawn(async move |this, cx| {
 117            let models = get_models(http_client.as_ref(), &api_url, api_key.as_deref()).await?;
 118
 119            let tasks = models
 120                .into_iter()
 121                // Since there is no metadata from the Ollama API
 122                // indicating which models are embedding models,
 123                // simply filter out models with "-embed" in their name
 124                .filter(|model| !model.name.contains("-embed"))
 125                .map(|model| {
 126                    let http_client = Arc::clone(&http_client);
 127                    let api_url = api_url.clone();
 128                    let api_key = api_key.clone();
 129                    async move {
 130                        let name = model.name.as_str();
 131                        let model =
 132                            show_model(http_client.as_ref(), &api_url, api_key.as_deref(), name)
 133                                .await?;
 134                        let ollama_model = ollama::Model::new(
 135                            name,
 136                            None,
 137                            model.context_length,
 138                            Some(model.supports_tools()),
 139                            Some(model.supports_vision()),
 140                            Some(model.supports_thinking()),
 141                        );
 142                        Ok(ollama_model)
 143                    }
 144                });
 145
 146            // Rate-limit capability fetches
 147            // since there is an arbitrary number of models available
 148            let mut ollama_models: Vec<_> = futures::stream::iter(tasks)
 149                .buffer_unordered(5)
 150                .collect::<Vec<Result<_>>>()
 151                .await
 152                .into_iter()
 153                .collect::<Result<Vec<_>>>()?;
 154
 155            ollama_models.sort_by(|a, b| a.name.cmp(&b.name));
 156
 157            this.update(cx, |this, cx| {
 158                this.fetched_models = ollama_models;
 159                cx.notify();
 160            })
 161        })
 162    }
 163
 164    fn restart_fetch_models_task(&mut self, cx: &mut Context<Self>) {
 165        let task = self.fetch_models(cx);
 166        self.fetch_model_task.replace(task);
 167    }
 168}
 169
 170impl OllamaLanguageModelProvider {
 171    pub fn new(
 172        http_client: Arc<dyn HttpClient>,
 173        credentials_provider: Arc<dyn CredentialsProvider>,
 174        cx: &mut App,
 175    ) -> Self {
 176        let this = Self {
 177            http_client: http_client.clone(),
 178            state: cx.new(|cx| {
 179                cx.observe_global::<SettingsStore>({
 180                    let mut last_settings = OllamaLanguageModelProvider::settings(cx).clone();
 181                    move |this: &mut State, cx| {
 182                        let current_settings = OllamaLanguageModelProvider::settings(cx);
 183                        let settings_changed = current_settings != &last_settings;
 184                        if settings_changed {
 185                            let url_changed = last_settings.api_url != current_settings.api_url;
 186                            last_settings = current_settings.clone();
 187                            if url_changed {
 188                                let credentials_provider = this.credentials_provider.clone();
 189                                let api_url = Self::api_url(cx);
 190                                this.api_key_state.handle_url_change(
 191                                    api_url,
 192                                    |this| &mut this.api_key_state,
 193                                    credentials_provider,
 194                                    cx,
 195                                );
 196                                this.fetched_models.clear();
 197                                this.authenticate(cx).detach();
 198                            }
 199                            cx.notify();
 200                        }
 201                    }
 202                })
 203                .detach();
 204
 205                State {
 206                    http_client,
 207                    fetched_models: Default::default(),
 208                    fetch_model_task: None,
 209                    api_key_state: ApiKeyState::new(Self::api_url(cx), (*API_KEY_ENV_VAR).clone()),
 210                    credentials_provider,
 211                }
 212            }),
 213        };
 214        this
 215    }
 216
 217    fn settings(cx: &App) -> &OllamaSettings {
 218        &AllLanguageModelSettings::get_global(cx).ollama
 219    }
 220
 221    fn api_url(cx: &App) -> SharedString {
 222        let api_url = &Self::settings(cx).api_url;
 223        if api_url.is_empty() {
 224            OLLAMA_API_URL.into()
 225        } else {
 226            SharedString::new(api_url.as_str())
 227        }
 228    }
 229
 230    fn has_custom_url(cx: &App) -> bool {
 231        Self::settings(cx).api_url != OLLAMA_API_URL
 232    }
 233}
 234
 235impl LanguageModelProviderState for OllamaLanguageModelProvider {
 236    type ObservableEntity = State;
 237
 238    fn observable_entity(&self) -> Option<Entity<Self::ObservableEntity>> {
 239        Some(self.state.clone())
 240    }
 241}
 242
 243impl LanguageModelProvider for OllamaLanguageModelProvider {
 244    fn id(&self) -> LanguageModelProviderId {
 245        PROVIDER_ID
 246    }
 247
 248    fn name(&self) -> LanguageModelProviderName {
 249        PROVIDER_NAME
 250    }
 251
 252    fn icon(&self) -> IconOrSvg {
 253        IconOrSvg::Icon(IconName::AiOllama)
 254    }
 255
 256    fn default_model(&self, _: &App) -> Option<Arc<dyn LanguageModel>> {
 257        // We shouldn't try to select default model, because it might lead to a load call for an unloaded model.
 258        // In a constrained environment where user might not have enough resources it'll be a bad UX to select something
 259        // to load by default.
 260        None
 261    }
 262
 263    fn default_fast_model(&self, _: &App) -> Option<Arc<dyn LanguageModel>> {
 264        // See explanation for default_model.
 265        None
 266    }
 267
 268    fn provided_models(&self, cx: &App) -> Vec<Arc<dyn LanguageModel>> {
 269        let mut models: HashMap<String, ollama::Model> = HashMap::new();
 270        let settings = OllamaLanguageModelProvider::settings(cx);
 271
 272        // Add models from the Ollama API
 273        for model in self.state.read(cx).fetched_models.iter() {
 274            let mut model = model.clone();
 275            if let Some(context_window) = settings.context_window {
 276                model.max_tokens = context_window;
 277            }
 278            models.insert(model.name.clone(), model);
 279        }
 280
 281        // Override with available models from settings
 282        merge_settings_into_models(
 283            &mut models,
 284            &settings.available_models,
 285            settings.context_window,
 286        );
 287
 288        let mut models = models
 289            .into_values()
 290            .map(|model| {
 291                Arc::new(OllamaLanguageModel {
 292                    id: LanguageModelId::from(model.name.clone()),
 293                    model,
 294                    http_client: self.http_client.clone(),
 295                    request_limiter: RateLimiter::new(4),
 296                    state: self.state.clone(),
 297                }) as Arc<dyn LanguageModel>
 298            })
 299            .collect::<Vec<_>>();
 300        models.sort_by_key(|model| model.name());
 301        models
 302    }
 303
 304    fn is_authenticated(&self, cx: &App) -> bool {
 305        self.state.read(cx).is_authenticated()
 306    }
 307
 308    fn authenticate(&self, cx: &mut App) -> Task<Result<(), AuthenticateError>> {
 309        self.state.update(cx, |state, cx| state.authenticate(cx))
 310    }
 311
 312    fn configuration_view(
 313        &self,
 314        _target_agent: language_model::ConfigurationViewTargetAgent,
 315        window: &mut Window,
 316        cx: &mut App,
 317    ) -> AnyView {
 318        let state = self.state.clone();
 319        cx.new(|cx| ConfigurationView::new(state, window, cx))
 320            .into()
 321    }
 322
 323    fn reset_credentials(&self, cx: &mut App) -> Task<Result<()>> {
 324        self.state
 325            .update(cx, |state, cx| state.set_api_key(None, cx))
 326    }
 327}
 328
/// A single Ollama model exposed through the `LanguageModel` trait.
pub struct OllamaLanguageModel {
    /// Model id, derived from the Ollama model name.
    id: LanguageModelId,
    /// Model description (name, context size, optional capability flags).
    model: ollama::Model,
    /// HTTP client used for chat completion requests.
    http_client: Arc<dyn HttpClient>,
    /// Limiter that completion requests are routed through.
    request_limiter: RateLimiter,
    /// Provider state; read to obtain the API key for the current API URL.
    state: Entity<State>,
}
 336
 337impl OllamaLanguageModel {
 338    fn to_ollama_request(&self, request: LanguageModelRequest) -> ChatRequest {
 339        let supports_vision = self.model.supports_vision.unwrap_or(false);
 340
 341        let mut messages = Vec::with_capacity(request.messages.len());
 342
 343        for mut msg in request.messages.into_iter() {
 344            let images = if supports_vision {
 345                msg.content
 346                    .iter()
 347                    .filter_map(|content| match content {
 348                        MessageContent::Image(image) => Some(image.source.to_string()),
 349                        _ => None,
 350                    })
 351                    .collect::<Vec<String>>()
 352            } else {
 353                vec![]
 354            };
 355
 356            match msg.role {
 357                Role::User => {
 358                    for tool_result in msg
 359                        .content
 360                        .extract_if(.., |x| matches!(x, MessageContent::ToolResult(..)))
 361                    {
 362                        match tool_result {
 363                            MessageContent::ToolResult(tool_result) => {
 364                                messages.push(ChatMessage::Tool {
 365                                    tool_name: tool_result.tool_name.to_string(),
 366                                    content: tool_result.content.to_str().unwrap_or("").to_string(),
 367                                })
 368                            }
 369                            _ => unreachable!("Only tool result should be extracted"),
 370                        }
 371                    }
 372                    if !msg.content.is_empty() {
 373                        messages.push(ChatMessage::User {
 374                            content: msg.string_contents(),
 375                            images: if images.is_empty() {
 376                                None
 377                            } else {
 378                                Some(images)
 379                            },
 380                        })
 381                    }
 382                }
 383                Role::Assistant => {
 384                    let content = msg.string_contents();
 385                    let mut thinking = None;
 386                    let mut tool_calls = Vec::new();
 387                    for content in msg.content.into_iter() {
 388                        match content {
 389                            MessageContent::Thinking { text, .. } if !text.is_empty() => {
 390                                thinking = Some(text)
 391                            }
 392                            MessageContent::ToolUse(tool_use) => {
 393                                tool_calls.push(OllamaToolCall {
 394                                    id: tool_use.id.to_string(),
 395                                    function: OllamaFunctionCall {
 396                                        name: tool_use.name.to_string(),
 397                                        arguments: tool_use.input,
 398                                    },
 399                                });
 400                            }
 401                            _ => (),
 402                        }
 403                    }
 404                    messages.push(ChatMessage::Assistant {
 405                        content,
 406                        tool_calls: Some(tool_calls),
 407                        images: if images.is_empty() {
 408                            None
 409                        } else {
 410                            Some(images)
 411                        },
 412                        thinking,
 413                    })
 414                }
 415                Role::System => messages.push(ChatMessage::System {
 416                    content: msg.string_contents(),
 417                }),
 418            }
 419        }
 420        ChatRequest {
 421            model: self.model.name.clone(),
 422            messages,
 423            keep_alive: self.model.keep_alive.clone().unwrap_or_default(),
 424            stream: true,
 425            options: Some(ChatOptions {
 426                num_ctx: Some(self.model.max_tokens),
 427                // Only send stop tokens if explicitly provided. When empty/None,
 428                // Ollama will use the model's default stop tokens from its Modelfile.
 429                // Sending an empty array would override and disable the defaults.
 430                stop: if request.stop.is_empty() {
 431                    None
 432                } else {
 433                    Some(request.stop)
 434                },
 435                temperature: request.temperature.or(Some(1.0)),
 436                ..Default::default()
 437            }),
 438            think: self
 439                .model
 440                .supports_thinking
 441                .map(|supports_thinking| supports_thinking && request.thinking_allowed),
 442            tools: if self.model.supports_tools.unwrap_or(false) {
 443                request.tools.into_iter().map(tool_into_ollama).collect()
 444            } else {
 445                vec![]
 446            },
 447        }
 448    }
 449}
 450
 451impl LanguageModel for OllamaLanguageModel {
 452    fn id(&self) -> LanguageModelId {
 453        self.id.clone()
 454    }
 455
 456    fn name(&self) -> LanguageModelName {
 457        LanguageModelName::from(self.model.display_name().to_string())
 458    }
 459
 460    fn provider_id(&self) -> LanguageModelProviderId {
 461        PROVIDER_ID
 462    }
 463
 464    fn provider_name(&self) -> LanguageModelProviderName {
 465        PROVIDER_NAME
 466    }
 467
 468    fn supports_tools(&self) -> bool {
 469        self.model.supports_tools.unwrap_or(false)
 470    }
 471
 472    fn supports_images(&self) -> bool {
 473        self.model.supports_vision.unwrap_or(false)
 474    }
 475
 476    fn supports_thinking(&self) -> bool {
 477        self.model.supports_thinking.unwrap_or(false)
 478    }
 479
 480    fn supports_tool_choice(&self, choice: LanguageModelToolChoice) -> bool {
 481        match choice {
 482            LanguageModelToolChoice::Auto => false,
 483            LanguageModelToolChoice::Any => false,
 484            LanguageModelToolChoice::None => false,
 485        }
 486    }
 487
 488    fn telemetry_id(&self) -> String {
 489        format!("ollama/{}", self.model.id())
 490    }
 491
 492    fn max_token_count(&self) -> u64 {
 493        self.model.max_token_count()
 494    }
 495
 496    fn count_tokens(
 497        &self,
 498        request: LanguageModelRequest,
 499        _cx: &App,
 500    ) -> BoxFuture<'static, Result<u64>> {
 501        // There is no endpoint for this _yet_ in Ollama
 502        // see: https://github.com/ollama/ollama/issues/1716 and https://github.com/ollama/ollama/issues/3582
 503        let token_count = request
 504            .messages
 505            .iter()
 506            .map(|msg| msg.string_contents().chars().count())
 507            .sum::<usize>()
 508            / 4;
 509
 510        async move { Ok(token_count as u64) }.boxed()
 511    }
 512
 513    fn stream_completion(
 514        &self,
 515        request: LanguageModelRequest,
 516        cx: &AsyncApp,
 517    ) -> BoxFuture<
 518        'static,
 519        Result<
 520            BoxStream<'static, Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>,
 521            LanguageModelCompletionError,
 522        >,
 523    > {
 524        let request = self.to_ollama_request(request);
 525
 526        let http_client = self.http_client.clone();
 527        let (api_key, api_url) = self.state.read_with(cx, |state, cx| {
 528            let api_url = OllamaLanguageModelProvider::api_url(cx);
 529            (state.api_key_state.key(&api_url), api_url)
 530        });
 531
 532        let future = self.request_limiter.stream(async move {
 533            let stream =
 534                stream_chat_completion(http_client.as_ref(), &api_url, api_key.as_deref(), request)
 535                    .await?;
 536            let stream = map_to_language_model_completion_events(stream);
 537            Ok(stream)
 538        });
 539
 540        future.map_ok(|f| f.boxed()).boxed()
 541    }
 542}
 543
 544fn map_to_language_model_completion_events(
 545    stream: Pin<Box<dyn Stream<Item = anyhow::Result<ChatResponseDelta>> + Send>>,
 546) -> impl Stream<Item = Result<LanguageModelCompletionEvent, LanguageModelCompletionError>> {
 547    struct State {
 548        stream: Pin<Box<dyn Stream<Item = anyhow::Result<ChatResponseDelta>> + Send>>,
 549        used_tools: bool,
 550    }
 551
 552    // We need to create a ToolUse and Stop event from a single
 553    // response from the original stream
 554    let stream = stream::unfold(
 555        State {
 556            stream,
 557            used_tools: false,
 558        },
 559        async move |mut state| {
 560            let response = state.stream.next().await?;
 561
 562            let delta = match response {
 563                Ok(delta) => delta,
 564                Err(e) => {
 565                    let event = Err(LanguageModelCompletionError::from(anyhow!(e)));
 566                    return Some((vec![event], state));
 567                }
 568            };
 569
 570            let mut events = Vec::new();
 571
 572            match delta.message {
 573                ChatMessage::User { content, images: _ } => {
 574                    events.push(Ok(LanguageModelCompletionEvent::Text(content)));
 575                }
 576                ChatMessage::System { content } => {
 577                    events.push(Ok(LanguageModelCompletionEvent::Text(content)));
 578                }
 579                ChatMessage::Tool { content, .. } => {
 580                    events.push(Ok(LanguageModelCompletionEvent::Text(content)));
 581                }
 582                ChatMessage::Assistant {
 583                    content,
 584                    tool_calls,
 585                    images: _,
 586                    thinking,
 587                } => {
 588                    if let Some(text) = thinking {
 589                        events.push(Ok(LanguageModelCompletionEvent::Thinking {
 590                            text,
 591                            signature: None,
 592                        }));
 593                    }
 594
 595                    if let Some(tool_call) = tool_calls.and_then(|v| v.into_iter().next()) {
 596                        let OllamaToolCall { id, function } = tool_call;
 597                        let event = LanguageModelCompletionEvent::ToolUse(LanguageModelToolUse {
 598                            id: LanguageModelToolUseId::from(id),
 599                            name: Arc::from(function.name),
 600                            raw_input: function.arguments.to_string(),
 601                            input: function.arguments,
 602                            is_input_complete: true,
 603                            thought_signature: None,
 604                        });
 605                        events.push(Ok(event));
 606                        state.used_tools = true;
 607                    } else if !content.is_empty() {
 608                        events.push(Ok(LanguageModelCompletionEvent::Text(content)));
 609                    }
 610                }
 611            };
 612
 613            if delta.done {
 614                events.push(Ok(LanguageModelCompletionEvent::UsageUpdate(TokenUsage {
 615                    input_tokens: delta.prompt_eval_count.unwrap_or(0),
 616                    output_tokens: delta.eval_count.unwrap_or(0),
 617                    cache_creation_input_tokens: 0,
 618                    cache_read_input_tokens: 0,
 619                })));
 620                if state.used_tools {
 621                    state.used_tools = false;
 622                    events.push(Ok(LanguageModelCompletionEvent::Stop(StopReason::ToolUse)));
 623                } else {
 624                    events.push(Ok(LanguageModelCompletionEvent::Stop(StopReason::EndTurn)));
 625                }
 626            }
 627
 628            Some((events, state))
 629        },
 630    );
 631
 632    stream.flat_map(futures::stream::iter)
 633}
 634
/// Configuration view for the Ollama provider.
struct ConfigurationView {
    /// Input for the API key.
    api_key_editor: Entity<InputField>,
    /// Input for the API URL, pre-filled with the effective URL.
    api_url_editor: Entity<InputField>,
    /// Input for the context window, pre-filled when one is configured.
    context_window_editor: Entity<InputField>,
    /// Provider state; observed so the view is notified whenever it changes.
    state: Entity<State>,
}
 641
 642impl ConfigurationView {
 643    pub fn new(state: Entity<State>, window: &mut Window, cx: &mut Context<Self>) -> Self {
 644        let api_key_editor = cx.new(|cx| InputField::new(window, cx, "63e02e...").label("API key"));
 645
 646        let api_url_editor = cx.new(|cx| {
 647            let input = InputField::new(window, cx, OLLAMA_API_URL).label("API URL");
 648            input.set_text(&OllamaLanguageModelProvider::api_url(cx), window, cx);
 649            input
 650        });
 651
 652        let context_window_editor = cx.new(|cx| {
 653            let input = InputField::new(window, cx, "8192").label("Context Window");
 654            if let Some(context_window) = OllamaLanguageModelProvider::settings(cx).context_window {
 655                input.set_text(&context_window.to_string(), window, cx);
 656            }
 657            input
 658        });
 659
 660        cx.observe(&state, |_, _, cx| {
 661            cx.notify();
 662        })
 663        .detach();
 664
 665        Self {
 666            api_key_editor,
 667            api_url_editor,
 668            context_window_editor,
 669            state,
 670        }
 671    }
 672
 673    fn retry_connection(&mut self, window: &mut Window, cx: &mut Context<Self>) {
 674        let has_api_url = OllamaLanguageModelProvider::has_custom_url(cx);
 675        let has_api_key = self
 676            .state
 677            .read_with(cx, |state, _| state.api_key_state.has_key());
 678        if !has_api_url {
 679            self.save_api_url(cx);
 680        }
 681        if !has_api_key {
 682            self.save_api_key(&Default::default(), window, cx);
 683        }
 684
 685        self.state.update(cx, |state, cx| {
 686            state.restart_fetch_models_task(cx);
 687        });
 688    }
 689
 690    fn save_api_key(&mut self, _: &menu::Confirm, window: &mut Window, cx: &mut Context<Self>) {
 691        let api_key = self.api_key_editor.read(cx).text(cx).trim().to_string();
 692        if api_key.is_empty() {
 693            return;
 694        }
 695
 696        // url changes can cause the editor to be displayed again
 697        self.api_key_editor
 698            .update(cx, |input, cx| input.set_text("", window, cx));
 699
 700        let state = self.state.clone();
 701        cx.spawn_in(window, async move |_, cx| {
 702            state
 703                .update(cx, |state, cx| state.set_api_key(Some(api_key), cx))
 704                .await
 705        })
 706        .detach_and_log_err(cx);
 707    }
 708
 709    fn reset_api_key(&mut self, window: &mut Window, cx: &mut Context<Self>) {
 710        self.api_key_editor
 711            .update(cx, |input, cx| input.set_text("", window, cx));
 712
 713        let state = self.state.clone();
 714        cx.spawn_in(window, async move |_, cx| {
 715            state
 716                .update(cx, |state, cx| state.set_api_key(None, cx))
 717                .await
 718        })
 719        .detach_and_log_err(cx);
 720
 721        cx.notify();
 722    }
 723
 724    fn save_api_url(&self, cx: &mut Context<Self>) {
 725        let api_url = self.api_url_editor.read(cx).text(cx).trim().to_string();
 726        let current_url = OllamaLanguageModelProvider::api_url(cx);
 727        if !api_url.is_empty() && &api_url != &current_url {
 728            let fs = <dyn Fs>::global(cx);
 729            update_settings_file(fs, cx, move |settings, _| {
 730                settings
 731                    .language_models
 732                    .get_or_insert_default()
 733                    .ollama
 734                    .get_or_insert_default()
 735                    .api_url = Some(api_url);
 736            });
 737        }
 738    }
 739
 740    fn reset_api_url(&mut self, window: &mut Window, cx: &mut Context<Self>) {
 741        self.api_url_editor
 742            .update(cx, |input, cx| input.set_text("", window, cx));
 743        let fs = <dyn Fs>::global(cx);
 744        update_settings_file(fs, cx, |settings, _cx| {
 745            if let Some(settings) = settings
 746                .language_models
 747                .as_mut()
 748                .and_then(|models| models.ollama.as_mut())
 749            {
 750                settings.api_url = Some(OLLAMA_API_URL.into());
 751            }
 752        });
 753        cx.notify();
 754    }
 755
 756    fn save_context_window(&mut self, cx: &mut Context<Self>) {
 757        let context_window_str = self
 758            .context_window_editor
 759            .read(cx)
 760            .text(cx)
 761            .trim()
 762            .to_string();
 763        let current_context_window = OllamaLanguageModelProvider::settings(cx).context_window;
 764
 765        if let Ok(context_window) = context_window_str.parse::<u64>() {
 766            if Some(context_window) != current_context_window {
 767                let fs = <dyn Fs>::global(cx);
 768                update_settings_file(fs, cx, move |settings, _| {
 769                    settings
 770                        .language_models
 771                        .get_or_insert_default()
 772                        .ollama
 773                        .get_or_insert_default()
 774                        .context_window = Some(context_window);
 775                });
 776            }
 777        } else if context_window_str.is_empty() && current_context_window.is_some() {
 778            let fs = <dyn Fs>::global(cx);
 779            update_settings_file(fs, cx, move |settings, _| {
 780                settings
 781                    .language_models
 782                    .get_or_insert_default()
 783                    .ollama
 784                    .get_or_insert_default()
 785                    .context_window = None;
 786            });
 787        }
 788    }
 789
 790    fn reset_context_window(&mut self, window: &mut Window, cx: &mut Context<Self>) {
 791        self.context_window_editor
 792            .update(cx, |input, cx| input.set_text("", window, cx));
 793        let fs = <dyn Fs>::global(cx);
 794        update_settings_file(fs, cx, |settings, _cx| {
 795            if let Some(settings) = settings
 796                .language_models
 797                .as_mut()
 798                .and_then(|models| models.ollama.as_mut())
 799            {
 800                settings.context_window = None;
 801            }
 802        });
 803        cx.notify();
 804    }
 805
 806    fn render_instructions(cx: &App) -> Div {
 807        v_flex()
 808            .gap_2()
 809            .child(Label::new(
 810                "Run LLMs locally on your machine with Ollama, or connect to an Ollama server. \
 811                Can provide access to Llama, Mistral, Gemma, and hundreds of other models.",
 812            ))
 813            .child(Label::new("To use local Ollama:"))
 814            .child(
 815                List::new()
 816                    .child(
 817                        ListBulletItem::new("")
 818                            .child(Label::new("Download and install Ollama from"))
 819                            .child(ButtonLink::new("ollama.com", "https://ollama.com/download")),
 820                    )
 821                    .child(
 822                        ListBulletItem::new("")
 823                            .child(Label::new("Start Ollama and download a model:"))
 824                            .child(Label::new("ollama run gpt-oss:20b").inline_code(cx)),
 825                    )
 826                    .child(ListBulletItem::new(
 827                        "Click 'Connect' below to start using Ollama in Zed",
 828                    )),
 829            )
 830            .child(Label::new(
 831                "Alternatively, you can connect to an Ollama server by specifying its \
 832                URL and API key (may not be required):",
 833            ))
 834    }
 835
 836    fn render_api_key_editor(&self, cx: &Context<Self>) -> impl IntoElement {
 837        let state = self.state.read(cx);
 838        let env_var_set = state.api_key_state.is_from_env_var();
 839        let configured_card_label = if env_var_set {
 840            format!("API key set in {API_KEY_ENV_VAR_NAME} environment variable.")
 841        } else {
 842            "API key configured".to_string()
 843        };
 844
 845        if !state.api_key_state.has_key() {
 846            v_flex()
 847              .on_action(cx.listener(Self::save_api_key))
 848              .child(self.api_key_editor.clone())
 849              .child(
 850                  Label::new(
 851                      format!("You can also set the {API_KEY_ENV_VAR_NAME} environment variable and restart Zed.")
 852                  )
 853                  .size(LabelSize::Small)
 854                  .color(Color::Muted),
 855              )
 856              .into_any_element()
 857        } else {
 858            ConfiguredApiCard::new(configured_card_label)
 859                .disabled(env_var_set)
 860                .on_click(cx.listener(|this, _, window, cx| this.reset_api_key(window, cx)))
 861                .when(env_var_set, |this| {
 862                    this.tooltip_label(format!("To reset your API key, unset the {API_KEY_ENV_VAR_NAME} environment variable."))
 863                })
 864                .into_any_element()
 865        }
 866    }
 867
 868    fn render_context_window_editor(&self, cx: &Context<Self>) -> Div {
 869        let settings = OllamaLanguageModelProvider::settings(cx);
 870        let custom_context_window_set = settings.context_window.is_some();
 871
 872        if custom_context_window_set {
 873            h_flex()
 874                .p_3()
 875                .justify_between()
 876                .rounded_md()
 877                .border_1()
 878                .border_color(cx.theme().colors().border)
 879                .bg(cx.theme().colors().elevated_surface_background)
 880                .child(
 881                    h_flex()
 882                        .gap_2()
 883                        .child(Icon::new(IconName::Check).color(Color::Success))
 884                        .child(v_flex().gap_1().child(Label::new(format!(
 885                            "Context Window: {}",
 886                            settings.context_window.unwrap()
 887                        )))),
 888                )
 889                .child(
 890                    Button::new("reset-context-window", "Reset")
 891                        .label_size(LabelSize::Small)
 892                        .start_icon(Icon::new(IconName::Undo).size(IconSize::Small))
 893                        .layer(ElevationIndex::ModalSurface)
 894                        .on_click(
 895                            cx.listener(|this, _, window, cx| {
 896                                this.reset_context_window(window, cx)
 897                            }),
 898                        ),
 899                )
 900        } else {
 901            v_flex()
 902                .on_action(
 903                    cx.listener(|this, _: &menu::Confirm, _window, cx| {
 904                        this.save_context_window(cx)
 905                    }),
 906                )
 907                .child(self.context_window_editor.clone())
 908                .child(
 909                    Label::new("Default: Model specific")
 910                        .size(LabelSize::Small)
 911                        .color(Color::Muted),
 912                )
 913        }
 914    }
 915
 916    fn render_api_url_editor(&self, cx: &Context<Self>) -> Div {
 917        let api_url = OllamaLanguageModelProvider::api_url(cx);
 918        let custom_api_url_set = api_url != OLLAMA_API_URL;
 919
 920        if custom_api_url_set {
 921            h_flex()
 922                .p_3()
 923                .justify_between()
 924                .rounded_md()
 925                .border_1()
 926                .border_color(cx.theme().colors().border)
 927                .bg(cx.theme().colors().elevated_surface_background)
 928                .child(
 929                    h_flex()
 930                        .gap_2()
 931                        .child(Icon::new(IconName::Check).color(Color::Success))
 932                        .child(v_flex().gap_1().child(Label::new(api_url))),
 933                )
 934                .child(
 935                    Button::new("reset-api-url", "Reset API URL")
 936                        .label_size(LabelSize::Small)
 937                        .start_icon(Icon::new(IconName::Undo).size(IconSize::Small))
 938                        .layer(ElevationIndex::ModalSurface)
 939                        .on_click(
 940                            cx.listener(|this, _, window, cx| this.reset_api_url(window, cx)),
 941                        ),
 942                )
 943        } else {
 944            v_flex()
 945                .on_action(cx.listener(|this, _: &menu::Confirm, _window, cx| {
 946                    this.save_api_url(cx);
 947                    cx.notify();
 948                }))
 949                .gap_2()
 950                .child(self.api_url_editor.clone())
 951        }
 952    }
 953}
 954
 955impl Render for ConfigurationView {
 956    fn render(&mut self, _: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
 957        let is_authenticated = self.state.read(cx).is_authenticated();
 958
 959        v_flex()
 960            .gap_2()
 961            .child(Self::render_instructions(cx))
 962            .child(self.render_api_url_editor(cx))
 963            .child(self.render_context_window_editor(cx))
 964            .child(self.render_api_key_editor(cx))
 965            .child(
 966                h_flex()
 967                    .w_full()
 968                    .justify_between()
 969                    .gap_2()
 970                    .child(
 971                        h_flex()
 972                            .w_full()
 973                            .gap_2()
 974                            .map(|this| {
 975                                if is_authenticated {
 976                                    this.child(
 977                                        Button::new("ollama-site", "Ollama")
 978                                            .style(ButtonStyle::Subtle)
 979                                            .end_icon(
 980                                                Icon::new(IconName::ArrowUpRight)
 981                                                    .size(IconSize::XSmall)
 982                                                    .color(Color::Muted),
 983                                            )
 984                                            .on_click(move |_, _, cx| cx.open_url(OLLAMA_SITE))
 985                                            .into_any_element(),
 986                                    )
 987                                } else {
 988                                    this.child(
 989                                        Button::new("download_ollama_button", "Download Ollama")
 990                                            .style(ButtonStyle::Subtle)
 991                                            .end_icon(
 992                                                Icon::new(IconName::ArrowUpRight)
 993                                                    .size(IconSize::XSmall)
 994                                                    .color(Color::Muted),
 995                                            )
 996                                            .on_click(move |_, _, cx| {
 997                                                cx.open_url(OLLAMA_DOWNLOAD_URL)
 998                                            })
 999                                            .into_any_element(),
1000                                    )
1001                                }
1002                            })
1003                            .child(
1004                                Button::new("view-models", "View All Models")
1005                                    .style(ButtonStyle::Subtle)
1006                                    .end_icon(
1007                                        Icon::new(IconName::ArrowUpRight)
1008                                            .size(IconSize::XSmall)
1009                                            .color(Color::Muted),
1010                                    )
1011                                    .on_click(move |_, _, cx| cx.open_url(OLLAMA_LIBRARY_URL)),
1012                            ),
1013                    )
1014                    .map(|this| {
1015                        if is_authenticated {
1016                            this.child(
1017                                ButtonLike::new("connected")
1018                                    .disabled(true)
1019                                    .cursor_style(CursorStyle::Arrow)
1020                                    .child(
1021                                        h_flex()
1022                                            .gap_2()
1023                                            .child(Icon::new(IconName::Check).color(Color::Success))
1024                                            .child(Label::new("Connected"))
1025                                            .into_any_element(),
1026                                    )
1027                                    .child(
1028                                        IconButton::new("refresh-models", IconName::RotateCcw)
1029                                            .tooltip(Tooltip::text("Refresh Models"))
1030                                            .on_click(cx.listener(|this, _, window, cx| {
1031                                                this.state.update(cx, |state, _| {
1032                                                    state.fetched_models.clear();
1033                                                });
1034                                                this.retry_connection(window, cx);
1035                                            })),
1036                                    ),
1037                            )
1038                        } else {
1039                            this.child(
1040                                Button::new("retry_ollama_models", "Connect")
1041                                    .start_icon(
1042                                        Icon::new(IconName::PlayOutlined).size(IconSize::XSmall),
1043                                    )
1044                                    .on_click(cx.listener(move |this, _, window, cx| {
1045                                        this.retry_connection(window, cx)
1046                                    })),
1047                            )
1048                        }
1049                    }),
1050            )
1051    }
1052}
1053
1054fn merge_settings_into_models(
1055    models: &mut HashMap<String, ollama::Model>,
1056    available_models: &[AvailableModel],
1057    context_window: Option<u64>,
1058) {
1059    for setting_model in available_models {
1060        if let Some(model) = models.get_mut(&setting_model.name) {
1061            if context_window.is_none() {
1062                model.max_tokens = setting_model.max_tokens;
1063            }
1064            model.display_name = setting_model.display_name.clone();
1065            model.keep_alive = setting_model.keep_alive.clone();
1066            model.supports_tools = setting_model.supports_tools;
1067            model.supports_vision = setting_model.supports_images;
1068            model.supports_thinking = setting_model.supports_thinking;
1069        } else {
1070            models.insert(
1071                setting_model.name.clone(),
1072                ollama::Model {
1073                    name: setting_model.name.clone(),
1074                    display_name: setting_model.display_name.clone(),
1075                    max_tokens: context_window.unwrap_or(setting_model.max_tokens),
1076                    keep_alive: setting_model.keep_alive.clone(),
1077                    supports_tools: setting_model.supports_tools,
1078                    supports_vision: setting_model.supports_images,
1079                    supports_thinking: setting_model.supports_thinking,
1080                },
1081            );
1082        }
1083    }
1084}
1085
1086fn tool_into_ollama(tool: LanguageModelRequestTool) -> ollama::OllamaTool {
1087    ollama::OllamaTool::Function {
1088        function: OllamaFunctionTool {
1089            name: tool.name,
1090            description: Some(tool.description),
1091            parameters: Some(tool.input_schema),
1092        },
1093    }
1094}
1095
#[cfg(test)]
mod tests {
    use super::*;

    /// A model entry with no display name, no capability flags and a 4096-token limit.
    fn bare_model(name: &str) -> ollama::Model {
        ollama::Model {
            name: name.to_string(),
            display_name: None,
            max_tokens: 4096,
            keep_alive: None,
            supports_tools: None,
            supports_vision: None,
            supports_thinking: None,
        }
    }

    /// A settings entry with the given display name and limit, and tool support enabled.
    fn settings_entry(name: &str, display_name: &str, max_tokens: u64) -> AvailableModel {
        AvailableModel {
            name: name.to_string(),
            display_name: Some(display_name.to_string()),
            max_tokens,
            keep_alive: None,
            supports_tools: Some(true),
            supports_images: None,
            supports_thinking: None,
        }
    }

    #[test]
    fn test_merge_settings_preserves_display_names_for_similar_models() {
        // Regression test for https://github.com/zed-industries/zed/issues/43646
        // Models sharing a base name (e.g., qwen2.5-coder:1.5b and qwen2.5-coder:3b) must each
        // receive their own display_name from settings rather than an arbitrary one.
        const SMALL: &str = "qwen2.5-coder:1.5b";
        const LARGE: &str = "qwen2.5-coder:3b";

        let mut models: HashMap<String, ollama::Model> = [SMALL, LARGE]
            .into_iter()
            .map(|name| (name.to_string(), bare_model(name)))
            .collect();

        let available_models = [
            settings_entry(SMALL, "QWEN2.5 Coder 1.5B", 5000),
            settings_entry(LARGE, "QWEN2.5 Coder 3B", 6000),
        ];

        merge_settings_into_models(&mut models, &available_models, None);

        let model_1_5b = models.get(SMALL).expect("1.5b model missing");
        let model_3b = models.get(LARGE).expect("3b model missing");

        assert_eq!(
            model_1_5b.display_name.as_deref(),
            Some("QWEN2.5 Coder 1.5B"),
            "1.5b model should have its own display_name"
        );
        assert_eq!(model_1_5b.max_tokens, 5000);

        assert_eq!(
            model_3b.display_name.as_deref(),
            Some("QWEN2.5 Coder 3B"),
            "3b model should have its own display_name"
        );
        assert_eq!(model_3b.max_tokens, 6000);
    }
}