ollama.rs

use anyhow::{anyhow, Result};
use futures::{future::BoxFuture, stream::BoxStream, FutureExt, StreamExt};
use gpui::{AnyView, AppContext, AsyncAppContext, ModelContext, Subscription, Task};
use http_client::HttpClient;
use ollama::{
    get_models, preload_model, stream_chat_completion, ChatMessage, ChatOptions, ChatRequest,
};
use settings::{Settings, SettingsStore};
use std::{future, sync::Arc, time::Duration};
use ui::{prelude::*, ButtonLike, ElevationIndex};

use crate::{
    settings::AllLanguageModelSettings, LanguageModel, LanguageModelId, LanguageModelName,
    LanguageModelProvider, LanguageModelProviderId, LanguageModelProviderName,
    LanguageModelProviderState, LanguageModelRequest, RateLimiter, Role,
};

const OLLAMA_DOWNLOAD_URL: &str = "https://ollama.com/download";
const OLLAMA_LIBRARY_URL: &str = "https://ollama.com/library";

const PROVIDER_ID: &str = "ollama";
const PROVIDER_NAME: &str = "Ollama";

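/// Settings for the Ollama provider: the base URL of the local Ollama server
/// (Ollama listens on http://localhost:11434 by default) and an optional
/// timeout for requests whose transfer rate drops too low.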
#[derive(Default, Debug, Clone, PartialEq)]
pub struct OllamaSettings {
    pub api_url: String,
    pub low_speed_timeout: Option<Duration>,
}

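/// A [`LanguageModelProvider`] backed by a locally running Ollama server.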
pub struct OllamaLanguageModelProvider {
    http_client: Arc<dyn HttpClient>,
    state: gpui::Model<State>,
}

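/// Shared provider state: the models the server currently reports, plus a
/// subscription that re-fetches them whenever the settings change.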
pub struct State {
    http_client: Arc<dyn HttpClient>,
    available_models: Vec<ollama::Model>,
    _subscription: Subscription,
}

impl State {
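    /// Queries the Ollama server for its installed models, filters out
    /// embedding models, and stores the remainder sorted by name.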
    fn fetch_models(&mut self, cx: &mut ModelContext<Self>) -> Task<Result<()>> {
        let settings = &AllLanguageModelSettings::get_global(cx).ollama;
        let http_client = self.http_client.clone();
        let api_url = settings.api_url.clone();

        // As a proxy for the server being "authenticated", check that it's up
        // by fetching the models
        cx.spawn(|this, mut cx| async move {
            let models = get_models(http_client.as_ref(), &api_url, None).await?;

            let mut models: Vec<ollama::Model> = models
                .into_iter()
                // Since the Ollama API provides no metadata indicating which
                // models are embedding models, simply filter out models with
                // "-embed" in their name
                .filter(|model| !model.name.contains("-embed"))
                .map(|model| ollama::Model::new(&model.name))
                .collect();

            models.sort_by(|a, b| a.name.cmp(&b.name));

            this.update(&mut cx, |this, cx| {
                this.available_models = models;
                cx.notify();
            })
        })
    }
}

impl OllamaLanguageModelProvider {
    pub fn new(http_client: Arc<dyn HttpClient>, cx: &mut AppContext) -> Self {
        let this = Self {
            http_client: http_client.clone(),
            state: cx.new_model(|cx| State {
                http_client,
                available_models: Default::default(),
                _subscription: cx.observe_global::<SettingsStore>(|this: &mut State, cx| {
                    this.fetch_models(cx).detach();
                    cx.notify();
                }),
            }),
        };
        this.state
            .update(cx, |state, cx| state.fetch_models(cx).detach());
        this
    }
}

impl LanguageModelProviderState for OllamaLanguageModelProvider {
    type ObservableEntity = State;

    fn observable_entity(&self) -> Option<gpui::Model<Self::ObservableEntity>> {
        Some(self.state.clone())
    }
}

impl LanguageModelProvider for OllamaLanguageModelProvider {
    fn id(&self) -> LanguageModelProviderId {
        LanguageModelProviderId(PROVIDER_ID.into())
    }

    fn name(&self) -> LanguageModelProviderName {
        LanguageModelProviderName(PROVIDER_NAME.into())
    }

    fn provided_models(&self, cx: &AppContext) -> Vec<Arc<dyn LanguageModel>> {
        self.state
            .read(cx)
            .available_models
            .iter()
            .map(|model| {
                Arc::new(OllamaLanguageModel {
                    id: LanguageModelId::from(model.name.clone()),
                    model: model.clone(),
                    http_client: self.http_client.clone(),
                    request_limiter: RateLimiter::new(4),
                }) as Arc<dyn LanguageModel>
            })
            .collect()
    }

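    /// Asks the Ollama server to preload the given model so the first
    /// completion request doesn't pay the cost of loading it into memory.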
    fn load_model(&self, model: Arc<dyn LanguageModel>, cx: &AppContext) {
        let settings = &AllLanguageModelSettings::get_global(cx).ollama;
        let http_client = self.http_client.clone();
        let api_url = settings.api_url.clone();
        let id = model.id().0.to_string();
        cx.spawn(|_| async move { preload_model(http_client, &api_url, &id).await })
            .detach_and_log_err(cx);
    }

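    /// Ollama requires no credentials; a non-empty model list is treated as
    /// evidence that a server is reachable.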
    fn is_authenticated(&self, cx: &AppContext) -> bool {
        !self.state.read(cx).available_models.is_empty()
    }

    fn authenticate(&self, cx: &mut AppContext) -> Task<Result<()>> {
        if self.is_authenticated(cx) {
            Task::ready(Ok(()))
        } else {
            self.state.update(cx, |state, cx| state.fetch_models(cx))
        }
    }

    fn authentication_prompt(&self, cx: &mut WindowContext) -> AnyView {
        let state = self.state.clone();
        let fetch_models = Box::new(move |cx: &mut WindowContext| {
            state.update(cx, |this, cx| this.fetch_models(cx))
        });

        cx.new_view(|cx| DownloadOllamaMessage::new(fetch_models, cx))
            .into()
    }

    fn reset_credentials(&self, cx: &mut AppContext) -> Task<Result<()>> {
        self.state.update(cx, |state, cx| state.fetch_models(cx))
    }
}

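/// A single chat model served by the local Ollama instance.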
pub struct OllamaLanguageModel {
    id: LanguageModelId,
    model: ollama::Model,
    http_client: Arc<dyn HttpClient>,
    request_limiter: RateLimiter,
}

impl OllamaLanguageModel {
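    /// Converts the provider-agnostic request into Ollama's chat format,
    /// mapping each message's role onto the matching `ChatMessage` variant.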
    fn to_ollama_request(&self, request: LanguageModelRequest) -> ChatRequest {
        ChatRequest {
            model: self.model.name.clone(),
            messages: request
                .messages
                .into_iter()
                .map(|msg| match msg.role {
                    Role::User => ChatMessage::User {
                        content: msg.content,
                    },
                    Role::Assistant => ChatMessage::Assistant {
                        content: msg.content,
                    },
                    Role::System => ChatMessage::System {
                        content: msg.content,
                    },
                })
                .collect(),
            keep_alive: self.model.keep_alive.clone().unwrap_or_default(),
            stream: true,
            options: Some(ChatOptions {
                num_ctx: Some(self.model.max_tokens),
                stop: Some(request.stop),
                temperature: Some(request.temperature),
                ..Default::default()
            }),
        }
    }
}

impl LanguageModel for OllamaLanguageModel {
    fn id(&self) -> LanguageModelId {
        self.id.clone()
    }

    fn name(&self) -> LanguageModelName {
        LanguageModelName::from(self.model.display_name().to_string())
    }

    fn provider_id(&self) -> LanguageModelProviderId {
        LanguageModelProviderId(PROVIDER_ID.into())
    }

    fn provider_name(&self) -> LanguageModelProviderName {
        LanguageModelProviderName(PROVIDER_NAME.into())
    }

    fn telemetry_id(&self) -> String {
        format!("ollama/{}", self.model.id())
    }

    fn max_token_count(&self) -> usize {
        self.model.max_token_count()
    }

    fn count_tokens(
        &self,
        request: LanguageModelRequest,
        _cx: &AppContext,
    ) -> BoxFuture<'static, Result<usize>> {
        // There is no endpoint for this _yet_ in Ollama
        // see: https://github.com/ollama/ollama/issues/1716 and https://github.com/ollama/ollama/issues/3582
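        // As a stopgap, estimate roughly four characters per token, a common
        // rule of thumb for English text.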
        let token_count = request
            .messages
            .iter()
            .map(|msg| msg.content.chars().count())
            .sum::<usize>()
            / 4;

        async move { Ok(token_count) }.boxed()
    }

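    /// Streams a chat completion from Ollama, yielding the content of each
    /// message delta as it arrives. Requests are funneled through this
    /// model's rate limiter.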
    fn stream_completion(
        &self,
        request: LanguageModelRequest,
        cx: &AsyncAppContext,
    ) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>> {
        let request = self.to_ollama_request(request);

        let http_client = self.http_client.clone();
        let Ok((api_url, low_speed_timeout)) = cx.update(|cx| {
            let settings = &AllLanguageModelSettings::get_global(cx).ollama;
            (settings.api_url.clone(), settings.low_speed_timeout)
        }) else {
            return futures::future::ready(Err(anyhow!("App state dropped"))).boxed();
        };

        let future = self.request_limiter.stream(async move {
            let response =
                stream_chat_completion(http_client.as_ref(), &api_url, request, low_speed_timeout)
                    .await?;
            let stream = response
                .filter_map(|response| async move {
                    match response {
                        Ok(delta) => {
                            let content = match delta.message {
                                ChatMessage::User { content } => content,
                                ChatMessage::Assistant { content } => content,
                                ChatMessage::System { content } => content,
                            };
                            Some(Ok(content))
                        }
                        Err(error) => Some(Err(error)),
                    }
                })
                .boxed();
            Ok(stream)
        });

        async move { Ok(future.await?.boxed()) }.boxed()
    }

    fn use_any_tool(
        &self,
        _request: LanguageModelRequest,
        _name: String,
        _description: String,
        _schema: serde_json::Value,
        _cx: &AsyncAppContext,
    ) -> BoxFuture<'static, Result<serde_json::Value>> {
        future::ready(Err(anyhow!("not implemented"))).boxed()
    }
}

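/// The view shown when no Ollama server (or no model) can be found: a
/// download link, a retry button, and a pointer to the model library.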
struct DownloadOllamaMessage {
    retry_connection: Box<dyn Fn(&mut WindowContext) -> Task<Result<()>>>,
}

impl DownloadOllamaMessage {
    pub fn new(
        retry_connection: Box<dyn Fn(&mut WindowContext) -> Task<Result<()>>>,
        _cx: &mut ViewContext<Self>,
    ) -> Self {
        Self { retry_connection }
    }

    fn render_download_button(&self, _cx: &mut ViewContext<Self>) -> impl IntoElement {
        ButtonLike::new("download_ollama_button")
            .style(ButtonStyle::Filled)
            .size(ButtonSize::Large)
            .layer(ElevationIndex::ModalSurface)
            .child(Label::new("Get Ollama"))
            .on_click(move |_, cx| cx.open_url(OLLAMA_DOWNLOAD_URL))
    }

    fn render_retry_button(&self, cx: &mut ViewContext<Self>) -> impl IntoElement {
        ButtonLike::new("retry_ollama_models")
            .style(ButtonStyle::Filled)
            .size(ButtonSize::Large)
            .layer(ElevationIndex::ModalSurface)
            .child(Label::new("Retry"))
            .on_click(cx.listener(move |this, _, cx| {
                let connected = (this.retry_connection)(cx);

                cx.spawn(|_this, _cx| async move {
                    connected.await?;
                    anyhow::Ok(())
                })
                .detach_and_log_err(cx)
            }))
    }

    fn render_next_steps(&self, _cx: &mut ViewContext<Self>) -> impl IntoElement {
        v_flex()
            .p_4()
            .size_full()
            .gap_2()
            .child(
                Label::new("Once Ollama is on your machine, make sure to download a model or two.")
                    .size(LabelSize::Large),
            )
            .child(
                h_flex().w_full().p_4().justify_center().gap_2().child(
                    ButtonLike::new("view-models")
                        .style(ButtonStyle::Filled)
                        .size(ButtonSize::Large)
                        .layer(ElevationIndex::ModalSurface)
                        .child(Label::new("View Available Models"))
                        .on_click(move |_, cx| cx.open_url(OLLAMA_LIBRARY_URL)),
                ),
            )
    }
}

impl Render for DownloadOllamaMessage {
    fn render(&mut self, cx: &mut ViewContext<Self>) -> impl IntoElement {
        v_flex()
            .p_4()
            .size_full()
            .gap_2()
            .child(Label::new("To use Ollama models via the assistant, Ollama must be running on your machine with at least one model downloaded.").size(LabelSize::Large))
            .child(
                h_flex()
                    .w_full()
                    .p_4()
                    .justify_center()
                    .gap_2()
                    .child(self.render_download_button(cx))
                    .child(self.render_retry_button(cx)),
            )
            .child(self.render_next_steps(cx))
            .into_any()
    }
}