// ollama.rs

  1use anyhow::{Result, anyhow};
  2use futures::{FutureExt, StreamExt, future::BoxFuture, stream::BoxStream};
  3use gpui::{AnyView, App, AsyncApp, Context, Subscription, Task};
  4use http_client::HttpClient;
  5use language_model::{
  6    AuthenticateError, LanguageModelCompletionError, LanguageModelCompletionEvent,
  7};
  8use language_model::{
  9    LanguageModel, LanguageModelId, LanguageModelName, LanguageModelProvider,
 10    LanguageModelProviderId, LanguageModelProviderName, LanguageModelProviderState,
 11    LanguageModelRequest, RateLimiter, Role,
 12};
 13use ollama::{
 14    ChatMessage, ChatOptions, ChatRequest, KeepAlive, get_models, preload_model,
 15    stream_chat_completion,
 16};
 17use schemars::JsonSchema;
 18use serde::{Deserialize, Serialize};
 19use settings::{Settings, SettingsStore};
 20use std::{collections::BTreeMap, sync::Arc};
 21use ui::{ButtonLike, Indicator, List, prelude::*};
 22use util::ResultExt;
 23
 24use crate::AllLanguageModelSettings;
 25use crate::ui::InstructionListItem;
 26
// External links surfaced in the configuration view below.
const OLLAMA_DOWNLOAD_URL: &str = "https://ollama.com/download";
const OLLAMA_LIBRARY_URL: &str = "https://ollama.com/library";
const OLLAMA_SITE: &str = "https://ollama.com/";

// Identifiers this provider registers itself under.
const PROVIDER_ID: &str = "ollama";
const PROVIDER_NAME: &str = "Ollama";
 33
/// User-configurable settings for the Ollama provider.
#[derive(Default, Debug, Clone, PartialEq)]
pub struct OllamaSettings {
    /// Base URL of the Ollama HTTP API.
    pub api_url: String,
    /// Models declared in settings; these take precedence over models
    /// reported by the Ollama API when names collide.
    pub available_models: Vec<AvailableModel>,
}
 39
/// A model entry as declared in the user's settings
/// (deserialized via serde; schema exposed via `JsonSchema`).
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize, JsonSchema)]
pub struct AvailableModel {
    /// The model name in the Ollama API (e.g. "llama3.2:latest")
    pub name: String,
    /// The model's name in Zed's UI, such as in the model selector dropdown menu in the assistant panel.
    pub display_name: Option<String>,
    /// The Context Length parameter to the model (aka num_ctx or n_ctx)
    pub max_tokens: usize,
    /// The number of seconds to keep the connection open after the last request
    pub keep_alive: Option<KeepAlive>,
}
 51
/// Language-model provider backed by an Ollama server reachable over HTTP.
pub struct OllamaLanguageModelProvider {
    http_client: Arc<dyn HttpClient>,
    /// Shared state holding the model list fetched from the server.
    state: gpui::Entity<State>,
}
 56
/// Observable provider state: the fetched model list plus the plumbing
/// that keeps it up to date.
pub struct State {
    http_client: Arc<dyn HttpClient>,
    /// Models fetched from the Ollama server; empty is treated as
    /// "not authenticated" (see `is_authenticated`).
    available_models: Vec<ollama::Model>,
    /// Handle to the in-flight model-fetch task, if any.
    fetch_model_task: Option<Task<Result<()>>>,
    /// Keeps the settings observer alive for the lifetime of this state.
    _subscription: Subscription,
}
 63
 64impl State {
 65    fn is_authenticated(&self) -> bool {
 66        !self.available_models.is_empty()
 67    }
 68
 69    fn fetch_models(&mut self, cx: &mut Context<Self>) -> Task<Result<()>> {
 70        let settings = &AllLanguageModelSettings::get_global(cx).ollama;
 71        let http_client = self.http_client.clone();
 72        let api_url = settings.api_url.clone();
 73
 74        // As a proxy for the server being "authenticated", we'll check if its up by fetching the models
 75        cx.spawn(async move |this, cx| {
 76            let models = get_models(http_client.as_ref(), &api_url, None).await?;
 77
 78            let mut models: Vec<ollama::Model> = models
 79                .into_iter()
 80                // Since there is no metadata from the Ollama API
 81                // indicating which models are embedding models,
 82                // simply filter out models with "-embed" in their name
 83                .filter(|model| !model.name.contains("-embed"))
 84                .map(|model| ollama::Model::new(&model.name, None, None))
 85                .collect();
 86
 87            models.sort_by(|a, b| a.name.cmp(&b.name));
 88
 89            this.update(cx, |this, cx| {
 90                this.available_models = models;
 91                cx.notify();
 92            })
 93        })
 94    }
 95
 96    fn restart_fetch_models_task(&mut self, cx: &mut Context<Self>) {
 97        let task = self.fetch_models(cx);
 98        self.fetch_model_task.replace(task);
 99    }
100
101    fn authenticate(&mut self, cx: &mut Context<Self>) -> Task<Result<(), AuthenticateError>> {
102        if self.is_authenticated() {
103            return Task::ready(Ok(()));
104        }
105
106        let fetch_models_task = self.fetch_models(cx);
107        cx.spawn(async move |_this, _cx| Ok(fetch_models_task.await?))
108    }
109}
110
111impl OllamaLanguageModelProvider {
112    pub fn new(http_client: Arc<dyn HttpClient>, cx: &mut App) -> Self {
113        let this = Self {
114            http_client: http_client.clone(),
115            state: cx.new(|cx| {
116                let subscription = cx.observe_global::<SettingsStore>({
117                    let mut settings = AllLanguageModelSettings::get_global(cx).ollama.clone();
118                    move |this: &mut State, cx| {
119                        let new_settings = &AllLanguageModelSettings::get_global(cx).ollama;
120                        if &settings != new_settings {
121                            settings = new_settings.clone();
122                            this.restart_fetch_models_task(cx);
123                            cx.notify();
124                        }
125                    }
126                });
127
128                State {
129                    http_client,
130                    available_models: Default::default(),
131                    fetch_model_task: None,
132                    _subscription: subscription,
133                }
134            }),
135        };
136        this.state
137            .update(cx, |state, cx| state.restart_fetch_models_task(cx));
138        this
139    }
140}
141
142impl LanguageModelProviderState for OllamaLanguageModelProvider {
143    type ObservableEntity = State;
144
145    fn observable_entity(&self) -> Option<gpui::Entity<Self::ObservableEntity>> {
146        Some(self.state.clone())
147    }
148}
149
150impl LanguageModelProvider for OllamaLanguageModelProvider {
151    fn id(&self) -> LanguageModelProviderId {
152        LanguageModelProviderId(PROVIDER_ID.into())
153    }
154
155    fn name(&self) -> LanguageModelProviderName {
156        LanguageModelProviderName(PROVIDER_NAME.into())
157    }
158
159    fn icon(&self) -> IconName {
160        IconName::AiOllama
161    }
162
163    fn default_model(&self, cx: &App) -> Option<Arc<dyn LanguageModel>> {
164        self.provided_models(cx).into_iter().next()
165    }
166
167    fn default_fast_model(&self, cx: &App) -> Option<Arc<dyn LanguageModel>> {
168        self.default_model(cx)
169    }
170
171    fn provided_models(&self, cx: &App) -> Vec<Arc<dyn LanguageModel>> {
172        let mut models: BTreeMap<String, ollama::Model> = BTreeMap::default();
173
174        // Add models from the Ollama API
175        for model in self.state.read(cx).available_models.iter() {
176            models.insert(model.name.clone(), model.clone());
177        }
178
179        // Override with available models from settings
180        for model in AllLanguageModelSettings::get_global(cx)
181            .ollama
182            .available_models
183            .iter()
184        {
185            models.insert(
186                model.name.clone(),
187                ollama::Model {
188                    name: model.name.clone(),
189                    display_name: model.display_name.clone(),
190                    max_tokens: model.max_tokens,
191                    keep_alive: model.keep_alive.clone(),
192                },
193            );
194        }
195
196        models
197            .into_values()
198            .map(|model| {
199                Arc::new(OllamaLanguageModel {
200                    id: LanguageModelId::from(model.name.clone()),
201                    model: model.clone(),
202                    http_client: self.http_client.clone(),
203                    request_limiter: RateLimiter::new(4),
204                }) as Arc<dyn LanguageModel>
205            })
206            .collect()
207    }
208
209    fn load_model(&self, model: Arc<dyn LanguageModel>, cx: &App) {
210        let settings = &AllLanguageModelSettings::get_global(cx).ollama;
211        let http_client = self.http_client.clone();
212        let api_url = settings.api_url.clone();
213        let id = model.id().0.to_string();
214        cx.spawn(async move |_| preload_model(http_client, &api_url, &id).await)
215            .detach_and_log_err(cx);
216    }
217
218    fn is_authenticated(&self, cx: &App) -> bool {
219        self.state.read(cx).is_authenticated()
220    }
221
222    fn authenticate(&self, cx: &mut App) -> Task<Result<(), AuthenticateError>> {
223        self.state.update(cx, |state, cx| state.authenticate(cx))
224    }
225
226    fn configuration_view(&self, window: &mut Window, cx: &mut App) -> AnyView {
227        let state = self.state.clone();
228        cx.new(|cx| ConfigurationView::new(state, window, cx))
229            .into()
230    }
231
232    fn reset_credentials(&self, cx: &mut App) -> Task<Result<()>> {
233        self.state.update(cx, |state, cx| state.fetch_models(cx))
234    }
235}
236
/// A single Ollama-hosted model exposed through the `LanguageModel` trait.
pub struct OllamaLanguageModel {
    id: LanguageModelId,
    model: ollama::Model,
    http_client: Arc<dyn HttpClient>,
    /// Caps concurrent completion requests (see `RateLimiter::new(4)` at construction).
    request_limiter: RateLimiter,
}
243
244impl OllamaLanguageModel {
245    fn to_ollama_request(&self, request: LanguageModelRequest) -> ChatRequest {
246        ChatRequest {
247            model: self.model.name.clone(),
248            messages: request
249                .messages
250                .into_iter()
251                .map(|msg| match msg.role {
252                    Role::User => ChatMessage::User {
253                        content: msg.string_contents(),
254                    },
255                    Role::Assistant => ChatMessage::Assistant {
256                        content: msg.string_contents(),
257                        tool_calls: None,
258                    },
259                    Role::System => ChatMessage::System {
260                        content: msg.string_contents(),
261                    },
262                })
263                .collect(),
264            keep_alive: self.model.keep_alive.clone().unwrap_or_default(),
265            stream: true,
266            options: Some(ChatOptions {
267                num_ctx: Some(self.model.max_tokens),
268                stop: Some(request.stop),
269                temperature: request.temperature.or(Some(1.0)),
270                ..Default::default()
271            }),
272            tools: vec![],
273        }
274    }
275}
276
277impl LanguageModel for OllamaLanguageModel {
278    fn id(&self) -> LanguageModelId {
279        self.id.clone()
280    }
281
282    fn name(&self) -> LanguageModelName {
283        LanguageModelName::from(self.model.display_name().to_string())
284    }
285
286    fn provider_id(&self) -> LanguageModelProviderId {
287        LanguageModelProviderId(PROVIDER_ID.into())
288    }
289
290    fn provider_name(&self) -> LanguageModelProviderName {
291        LanguageModelProviderName(PROVIDER_NAME.into())
292    }
293
294    fn supports_tools(&self) -> bool {
295        false
296    }
297
298    fn telemetry_id(&self) -> String {
299        format!("ollama/{}", self.model.id())
300    }
301
302    fn max_token_count(&self) -> usize {
303        self.model.max_token_count()
304    }
305
306    fn count_tokens(
307        &self,
308        request: LanguageModelRequest,
309        _cx: &App,
310    ) -> BoxFuture<'static, Result<usize>> {
311        // There is no endpoint for this _yet_ in Ollama
312        // see: https://github.com/ollama/ollama/issues/1716 and https://github.com/ollama/ollama/issues/3582
313        let token_count = request
314            .messages
315            .iter()
316            .map(|msg| msg.string_contents().chars().count())
317            .sum::<usize>()
318            / 4;
319
320        async move { Ok(token_count) }.boxed()
321    }
322
323    fn stream_completion(
324        &self,
325        request: LanguageModelRequest,
326        cx: &AsyncApp,
327    ) -> BoxFuture<
328        'static,
329        Result<
330            BoxStream<'static, Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>,
331        >,
332    > {
333        let request = self.to_ollama_request(request);
334
335        let http_client = self.http_client.clone();
336        let Ok(api_url) = cx.update(|cx| {
337            let settings = &AllLanguageModelSettings::get_global(cx).ollama;
338            settings.api_url.clone()
339        }) else {
340            return futures::future::ready(Err(anyhow!("App state dropped"))).boxed();
341        };
342
343        let future = self.request_limiter.stream(async move {
344            let response = stream_chat_completion(http_client.as_ref(), &api_url, request).await?;
345            let stream = response
346                .filter_map(|response| async move {
347                    match response {
348                        Ok(delta) => {
349                            let content = match delta.message {
350                                ChatMessage::User { content } => content,
351                                ChatMessage::Assistant { content, .. } => content,
352                                ChatMessage::System { content } => content,
353                            };
354                            Some(Ok(content))
355                        }
356                        Err(error) => Some(Err(error)),
357                    }
358                })
359                .boxed();
360            Ok(stream)
361        });
362
363        async move {
364            Ok(future
365                .await?
366                .map(|result| {
367                    result
368                        .map(LanguageModelCompletionEvent::Text)
369                        .map_err(LanguageModelCompletionError::Other)
370                })
371                .boxed())
372        }
373        .boxed()
374    }
375}
376
/// Settings-panel view showing Ollama connection status and helpful links.
struct ConfigurationView {
    state: gpui::Entity<State>,
    /// `Some` while the initial model fetch is running; drives the
    /// "Loading models..." state in `render`.
    loading_models_task: Option<Task<()>>,
}
381
382impl ConfigurationView {
383    pub fn new(state: gpui::Entity<State>, window: &mut Window, cx: &mut Context<Self>) -> Self {
384        let loading_models_task = Some(cx.spawn_in(window, {
385            let state = state.clone();
386            async move |this, cx| {
387                if let Some(task) = state
388                    .update(cx, |state, cx| state.authenticate(cx))
389                    .log_err()
390                {
391                    task.await.log_err();
392                }
393                this.update(cx, |this, cx| {
394                    this.loading_models_task = None;
395                    cx.notify();
396                })
397                .log_err();
398            }
399        }));
400
401        Self {
402            state,
403            loading_models_task,
404        }
405    }
406
407    fn retry_connection(&self, cx: &mut App) {
408        self.state
409            .update(cx, |state, cx| state.fetch_models(cx))
410            .detach_and_log_err(cx);
411    }
412}
413
impl Render for ConfigurationView {
    /// Renders either a loading placeholder (while the initial model fetch
    /// runs) or the full configuration panel with links and connection status.
    fn render(&mut self, _: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
        // "Authenticated" here just means at least one model was fetched.
        let is_authenticated = self.state.read(cx).is_authenticated();

        let ollama_intro =
            "Get up & running with Llama 3.3, Mistral, Gemma 2, and other LLMs with Ollama.";

        if self.loading_models_task.is_some() {
            // Initial model fetch still in flight.
            div().child(Label::new("Loading models...")).into_any()
        } else {
            v_flex()
                .gap_2()
                .child(
                    v_flex().gap_1().child(Label::new(ollama_intro)).child(
                        List::new()
                            .child(InstructionListItem::text_only("Ollama must be running with at least one model installed to use it in the assistant."))
                            .child(InstructionListItem::text_only(
                                "Once installed, try `ollama run llama3.2`",
                            )),
                    ),
                )
                .child(
                    h_flex()
                        .w_full()
                        .justify_between()
                        .gap_2()
                        .child(
                            h_flex()
                                .w_full()
                                .gap_2()
                                .map(|this| {
                                    // Left side: link to the Ollama site when connected,
                                    // or to the download page when not.
                                    if is_authenticated {
                                        this.child(
                                            Button::new("ollama-site", "Ollama")
                                                .style(ButtonStyle::Subtle)
                                                .icon(IconName::ArrowUpRight)
                                                .icon_size(IconSize::XSmall)
                                                .icon_color(Color::Muted)
                                                .on_click(move |_, _, cx| cx.open_url(OLLAMA_SITE))
                                                .into_any_element(),
                                        )
                                    } else {
                                        this.child(
                                            Button::new(
                                                "download_ollama_button",
                                                "Download Ollama",
                                            )
                                            .style(ButtonStyle::Subtle)
                                            .icon(IconName::ArrowUpRight)
                                            .icon_size(IconSize::XSmall)
                                            .icon_color(Color::Muted)
                                            .on_click(move |_, _, cx| {
                                                cx.open_url(OLLAMA_DOWNLOAD_URL)
                                            })
                                            .into_any_element(),
                                        )
                                    }
                                })
                                .child(
                                    Button::new("view-models", "All Models")
                                        .style(ButtonStyle::Subtle)
                                        .icon(IconName::ArrowUpRight)
                                        .icon_size(IconSize::XSmall)
                                        .icon_color(Color::Muted)
                                        .on_click(move |_, _, cx| cx.open_url(OLLAMA_LIBRARY_URL)),
                                ),
                        )
                        .map(|this| {
                            // Right side: a disabled "Connected" indicator, or a
                            // "Connect" button that retries the model fetch.
                            if is_authenticated {
                                this.child(
                                    ButtonLike::new("connected")
                                        .disabled(true)
                                        .cursor_style(gpui::CursorStyle::Arrow)
                                        .child(
                                            h_flex()
                                                .gap_2()
                                                .child(Indicator::dot().color(Color::Success))
                                                .child(Label::new("Connected"))
                                                .into_any_element(),
                                        ),
                                )
                            } else {
                                this.child(
                                    Button::new("retry_ollama_models", "Connect")
                                        .icon_position(IconPosition::Start)
                                        .icon_size(IconSize::XSmall)
                                        .icon(IconName::Play)
                                        .on_click(cx.listener(move |this, _, _, cx| {
                                            this.retry_connection(cx)
                                        })),
                                )
                            }
                        })
                )
                .into_any()
        }
    }
}