ollama.rs

  1use crate::LanguageModelCompletionProvider;
  2use crate::{CompletionProvider, LanguageModel, LanguageModelRequest};
  3use anyhow::Result;
  4use futures::StreamExt as _;
  5use futures::{future::BoxFuture, stream::BoxStream, FutureExt};
  6use gpui::{AnyView, AppContext, Task};
  7use http::HttpClient;
  8use language_model::Role;
  9use ollama::Model as OllamaModel;
 10use ollama::{
 11    get_models, preload_model, stream_chat_completion, ChatMessage, ChatOptions, ChatRequest,
 12};
 13use std::sync::Arc;
 14use std::time::Duration;
 15use ui::{prelude::*, ButtonLike, ElevationIndex};
 16
 17const OLLAMA_DOWNLOAD_URL: &str = "https://ollama.com/download";
 18const OLLAMA_LIBRARY_URL: &str = "https://ollama.com/library";
 19
 20pub struct OllamaCompletionProvider {
 21    api_url: String,
 22    model: OllamaModel,
 23    http_client: Arc<dyn HttpClient>,
 24    low_speed_timeout: Option<Duration>,
 25    settings_version: usize,
 26    available_models: Vec<OllamaModel>,
 27}
 28
 29impl LanguageModelCompletionProvider for OllamaCompletionProvider {
 30    fn available_models(&self) -> Vec<LanguageModel> {
 31        self.available_models
 32            .iter()
 33            .map(|m| LanguageModel::Ollama(m.clone()))
 34            .collect()
 35    }
 36
 37    fn settings_version(&self) -> usize {
 38        self.settings_version
 39    }
 40
 41    fn is_authenticated(&self) -> bool {
 42        !self.available_models.is_empty()
 43    }
 44
 45    fn authenticate(&self, cx: &AppContext) -> Task<Result<()>> {
 46        if self.is_authenticated() {
 47            Task::ready(Ok(()))
 48        } else {
 49            self.fetch_models(cx)
 50        }
 51    }
 52
 53    fn authentication_prompt(&self, cx: &mut WindowContext) -> AnyView {
 54        let fetch_models = Box::new(move |cx: &mut WindowContext| {
 55            cx.update_global::<CompletionProvider, _>(|provider, cx| {
 56                provider
 57                    .update_current_as::<_, OllamaCompletionProvider>(|provider| {
 58                        provider.fetch_models(cx)
 59                    })
 60                    .unwrap_or_else(|| Task::ready(Ok(())))
 61            })
 62        });
 63
 64        cx.new_view(|cx| DownloadOllamaMessage::new(fetch_models, cx))
 65            .into()
 66    }
 67
 68    fn reset_credentials(&self, cx: &AppContext) -> Task<Result<()>> {
 69        self.fetch_models(cx)
 70    }
 71
 72    fn model(&self) -> LanguageModel {
 73        LanguageModel::Ollama(self.model.clone())
 74    }
 75
 76    fn count_tokens(
 77        &self,
 78        request: LanguageModelRequest,
 79        _cx: &AppContext,
 80    ) -> BoxFuture<'static, Result<usize>> {
 81        // There is no endpoint for this _yet_ in Ollama
 82        // see: https://github.com/ollama/ollama/issues/1716 and https://github.com/ollama/ollama/issues/3582
 83        let token_count = request
 84            .messages
 85            .iter()
 86            .map(|msg| msg.content.chars().count())
 87            .sum::<usize>()
 88            / 4;
 89
 90        async move { Ok(token_count) }.boxed()
 91    }
 92
 93    fn stream_completion(
 94        &self,
 95        request: LanguageModelRequest,
 96    ) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>> {
 97        let request = self.to_ollama_request(request);
 98
 99        let http_client = self.http_client.clone();
100        let api_url = self.api_url.clone();
101        let low_speed_timeout = self.low_speed_timeout;
102        async move {
103            let request =
104                stream_chat_completion(http_client.as_ref(), &api_url, request, low_speed_timeout);
105            let response = request.await?;
106            let stream = response
107                .filter_map(|response| async move {
108                    match response {
109                        Ok(delta) => {
110                            let content = match delta.message {
111                                ChatMessage::User { content } => content,
112                                ChatMessage::Assistant { content } => content,
113                                ChatMessage::System { content } => content,
114                            };
115                            Some(Ok(content))
116                        }
117                        Err(error) => Some(Err(error)),
118                    }
119                })
120                .boxed();
121            Ok(stream)
122        }
123        .boxed()
124    }
125
126    fn as_any_mut(&mut self) -> &mut dyn std::any::Any {
127        self
128    }
129}
130
131impl OllamaCompletionProvider {
132    pub fn new(
133        model: OllamaModel,
134        api_url: String,
135        http_client: Arc<dyn HttpClient>,
136        low_speed_timeout: Option<Duration>,
137        settings_version: usize,
138        cx: &AppContext,
139    ) -> Self {
140        cx.spawn({
141            let api_url = api_url.clone();
142            let client = http_client.clone();
143            let model = model.name.clone();
144
145            |_| async move {
146                if model.is_empty() {
147                    return Ok(());
148                }
149                preload_model(client.as_ref(), &api_url, &model).await
150            }
151        })
152        .detach_and_log_err(cx);
153
154        Self {
155            api_url,
156            model,
157            http_client,
158            low_speed_timeout,
159            settings_version,
160            available_models: Default::default(),
161        }
162    }
163
164    pub fn update(
165        &mut self,
166        model: OllamaModel,
167        api_url: String,
168        low_speed_timeout: Option<Duration>,
169        settings_version: usize,
170        cx: &AppContext,
171    ) {
172        cx.spawn({
173            let api_url = api_url.clone();
174            let client = self.http_client.clone();
175            let model = model.name.clone();
176
177            |_| async move { preload_model(client.as_ref(), &api_url, &model).await }
178        })
179        .detach_and_log_err(cx);
180
181        if model.name.is_empty() {
182            self.select_first_available_model()
183        } else {
184            self.model = model;
185        }
186
187        self.api_url = api_url;
188        self.low_speed_timeout = low_speed_timeout;
189        self.settings_version = settings_version;
190    }
191
192    pub fn select_first_available_model(&mut self) {
193        if let Some(model) = self.available_models.first() {
194            self.model = model.clone();
195        }
196    }
197
198    pub fn fetch_models(&self, cx: &AppContext) -> Task<Result<()>> {
199        let http_client = self.http_client.clone();
200        let api_url = self.api_url.clone();
201
202        // As a proxy for the server being "authenticated", we'll check if its up by fetching the models
203        cx.spawn(|mut cx| async move {
204            let models = get_models(http_client.as_ref(), &api_url, None).await?;
205
206            let mut models: Vec<OllamaModel> = models
207                .into_iter()
208                // Since there is no metadata from the Ollama API
209                // indicating which models are embedding models,
210                // simply filter out models with "-embed" in their name
211                .filter(|model| !model.name.contains("-embed"))
212                .map(|model| OllamaModel::new(&model.name))
213                .collect();
214
215            models.sort_by(|a, b| a.name.cmp(&b.name));
216
217            cx.update_global::<CompletionProvider, _>(|provider, _cx| {
218                provider.update_current_as::<_, OllamaCompletionProvider>(|provider| {
219                    provider.available_models = models;
220
221                    if !provider.available_models.is_empty() && provider.model.name.is_empty() {
222                        provider.select_first_available_model()
223                    }
224                });
225            })
226        })
227    }
228
229    fn to_ollama_request(&self, request: LanguageModelRequest) -> ChatRequest {
230        let model = match request.model {
231            LanguageModel::Ollama(model) => model,
232            _ => self.model.clone(),
233        };
234
235        ChatRequest {
236            model: model.name,
237            messages: request
238                .messages
239                .into_iter()
240                .map(|msg| match msg.role {
241                    Role::User => ChatMessage::User {
242                        content: msg.content,
243                    },
244                    Role::Assistant => ChatMessage::Assistant {
245                        content: msg.content,
246                    },
247                    Role::System => ChatMessage::System {
248                        content: msg.content,
249                    },
250                })
251                .collect(),
252            keep_alive: model.keep_alive.unwrap_or_default(),
253            stream: true,
254            options: Some(ChatOptions {
255                num_ctx: Some(model.max_tokens),
256                stop: Some(request.stop),
257                temperature: Some(request.temperature),
258                ..Default::default()
259            }),
260        }
261    }
262}
263
264struct DownloadOllamaMessage {
265    retry_connection: Box<dyn Fn(&mut WindowContext) -> Task<Result<()>>>,
266}
267
268impl DownloadOllamaMessage {
269    pub fn new(
270        retry_connection: Box<dyn Fn(&mut WindowContext) -> Task<Result<()>>>,
271        _cx: &mut ViewContext<Self>,
272    ) -> Self {
273        Self { retry_connection }
274    }
275
276    fn render_download_button(&self, _cx: &mut ViewContext<Self>) -> impl IntoElement {
277        ButtonLike::new("download_ollama_button")
278            .style(ButtonStyle::Filled)
279            .size(ButtonSize::Large)
280            .layer(ElevationIndex::ModalSurface)
281            .child(Label::new("Get Ollama"))
282            .on_click(move |_, cx| cx.open_url(OLLAMA_DOWNLOAD_URL))
283    }
284
285    fn render_retry_button(&self, cx: &mut ViewContext<Self>) -> impl IntoElement {
286        ButtonLike::new("retry_ollama_models")
287            .style(ButtonStyle::Filled)
288            .size(ButtonSize::Large)
289            .layer(ElevationIndex::ModalSurface)
290            .child(Label::new("Retry"))
291            .on_click(cx.listener(move |this, _, cx| {
292                let connected = (this.retry_connection)(cx);
293
294                cx.spawn(|_this, _cx| async move {
295                    connected.await?;
296                    anyhow::Ok(())
297                })
298                .detach_and_log_err(cx)
299            }))
300    }
301
302    fn render_next_steps(&self, _cx: &mut ViewContext<Self>) -> impl IntoElement {
303        v_flex()
304            .p_4()
305            .size_full()
306            .gap_2()
307            .child(
308                Label::new("Once Ollama is on your machine, make sure to download a model or two.")
309                    .size(LabelSize::Large),
310            )
311            .child(
312                h_flex().w_full().p_4().justify_center().gap_2().child(
313                    ButtonLike::new("view-models")
314                        .style(ButtonStyle::Filled)
315                        .size(ButtonSize::Large)
316                        .layer(ElevationIndex::ModalSurface)
317                        .child(Label::new("View Available Models"))
318                        .on_click(move |_, cx| cx.open_url(OLLAMA_LIBRARY_URL)),
319                ),
320            )
321    }
322}
323
324impl Render for DownloadOllamaMessage {
325    fn render(&mut self, cx: &mut ViewContext<Self>) -> impl IntoElement {
326        v_flex()
327            .p_4()
328            .size_full()
329            .gap_2()
330            .child(Label::new("To use Ollama models via the assistant, Ollama must be running on your machine with at least one model downloaded.").size(LabelSize::Large))
331            .child(
332                h_flex()
333                    .w_full()
334                    .p_4()
335                    .justify_center()
336                    .gap_2()
337                    .child(
338                        self.render_download_button(cx)
339                    )
340                    .child(
341                        self.render_retry_button(cx)
342                    )
343            )
344            .child(self.render_next_steps(cx))
345            .into_any()
346    }
347}