ollama.rs

use anyhow::{anyhow, bail, Result};
use futures::{future::BoxFuture, stream::BoxStream, FutureExt, StreamExt};
use gpui::{AnyView, AppContext, AsyncAppContext, ModelContext, Subscription, Task};
use http_client::HttpClient;
use ollama::{
    get_models, preload_model, stream_chat_completion, ChatMessage, ChatOptions, ChatRequest,
    ChatResponseDelta, OllamaToolCall,
};
use settings::{Settings, SettingsStore};
use std::{sync::Arc, time::Duration};
use ui::{prelude::*, ButtonLike, Indicator};
use util::ResultExt;

use crate::{
    settings::AllLanguageModelSettings, LanguageModel, LanguageModelId, LanguageModelName,
    LanguageModelProvider, LanguageModelProviderId, LanguageModelProviderName,
    LanguageModelProviderState, LanguageModelRequest, RateLimiter, Role,
};

const OLLAMA_DOWNLOAD_URL: &str = "https://ollama.com/download";
const OLLAMA_LIBRARY_URL: &str = "https://ollama.com/library";
const OLLAMA_SITE: &str = "https://ollama.com/";

const PROVIDER_ID: &str = "ollama";
const PROVIDER_NAME: &str = "Ollama";

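/// Settings for the Ollama provider: the base URL of the Ollama HTTP API
/// (Ollama serves on http://localhost:11434 by default) and an optional
/// timeout for detecting stalled responses.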
#[derive(Default, Debug, Clone, PartialEq)]
pub struct OllamaSettings {
    pub api_url: String,
    pub low_speed_timeout: Option<Duration>,
}

pub struct OllamaLanguageModelProvider {
    http_client: Arc<dyn HttpClient>,
    state: gpui::Model<State>,
}
pub struct State {
    http_client: Arc<dyn HttpClient>,
    available_models: Vec<ollama::Model>,
    _subscription: Subscription,
}

impl State {
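    // Ollama needs no API key, so "authenticated" here just means the model
    // list has been fetched successfully at least once.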
    fn is_authenticated(&self) -> bool {
        !self.available_models.is_empty()
    }

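    /// Fetch the model list from the Ollama server, filtering out embedding
    /// models and sorting the rest by name.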
    fn fetch_models(&mut self, cx: &mut ModelContext<Self>) -> Task<Result<()>> {
        let settings = &AllLanguageModelSettings::get_global(cx).ollama;
        let http_client = self.http_client.clone();
        let api_url = settings.api_url.clone();

        // As a proxy for the server being "authenticated", we check that it's
        // up by fetching the models.
        cx.spawn(|this, mut cx| async move {
            let models = get_models(http_client.as_ref(), &api_url, None).await?;

            let mut models: Vec<ollama::Model> = models
                .into_iter()
                // Since there is no metadata from the Ollama API
                // indicating which models are embedding models,
                // simply filter out models with "-embed" in their name.
                .filter(|model| !model.name.contains("-embed"))
                .map(|model| ollama::Model::new(&model.name))
                .collect();

            models.sort_by(|a, b| a.name.cmp(&b.name));

            this.update(&mut cx, |this, cx| {
                this.available_models = models;
                cx.notify();
            })
        })
    }

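    /// Succeed immediately if models are already loaded; otherwise fetch them.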
    fn authenticate(&mut self, cx: &mut ModelContext<Self>) -> Task<Result<()>> {
        if self.is_authenticated() {
            Task::ready(Ok(()))
        } else {
            self.fetch_models(cx)
        }
    }
}

impl OllamaLanguageModelProvider {
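    /// Create the provider and immediately kick off a model fetch.
    ///
    /// A minimal usage sketch (assuming an `http_client` and an `AppContext`
    /// are in hand, as they are at provider-registration time):
    ///
    /// ```ignore
    /// let provider = OllamaLanguageModelProvider::new(http_client.clone(), cx);
    /// ```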
    pub fn new(http_client: Arc<dyn HttpClient>, cx: &mut AppContext) -> Self {
        let this = Self {
            http_client: http_client.clone(),
            state: cx.new_model(|cx| State {
                http_client,
                available_models: Default::default(),
                _subscription: cx.observe_global::<SettingsStore>(|this: &mut State, cx| {
                    this.fetch_models(cx).detach();
                    cx.notify();
                }),
            }),
        };
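        // Kick off the initial fetch so models are ready as soon as the
        // server responds.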
        this.state
            .update(cx, |state, cx| state.fetch_models(cx).detach());
        this
    }
}

impl LanguageModelProviderState for OllamaLanguageModelProvider {
    type ObservableEntity = State;

    fn observable_entity(&self) -> Option<gpui::Model<Self::ObservableEntity>> {
        Some(self.state.clone())
    }
}

impl LanguageModelProvider for OllamaLanguageModelProvider {
    fn id(&self) -> LanguageModelProviderId {
        LanguageModelProviderId(PROVIDER_ID.into())
    }

    fn name(&self) -> LanguageModelProviderName {
        LanguageModelProviderName(PROVIDER_NAME.into())
    }

    fn icon(&self) -> IconName {
        IconName::AiOllama
    }

    fn provided_models(&self, cx: &AppContext) -> Vec<Arc<dyn LanguageModel>> {
        self.state
            .read(cx)
            .available_models
            .iter()
            .map(|model| {
                Arc::new(OllamaLanguageModel {
                    id: LanguageModelId::from(model.name.clone()),
                    model: model.clone(),
                    http_client: self.http_client.clone(),
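                    // Allow up to four concurrent requests per model.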
                    request_limiter: RateLimiter::new(4),
                }) as Arc<dyn LanguageModel>
            })
            .collect()
    }

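    // Preloading pulls the model into memory so the first completion doesn't
    // pay the model-load latency.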
    fn load_model(&self, model: Arc<dyn LanguageModel>, cx: &AppContext) {
        let settings = &AllLanguageModelSettings::get_global(cx).ollama;
        let http_client = self.http_client.clone();
        let api_url = settings.api_url.clone();
        let id = model.id().0.to_string();
        cx.spawn(|_| async move { preload_model(http_client, &api_url, &id).await })
            .detach_and_log_err(cx);
    }

    fn is_authenticated(&self, cx: &AppContext) -> bool {
        self.state.read(cx).is_authenticated()
    }

    fn authenticate(&self, cx: &mut AppContext) -> Task<Result<()>> {
        self.state.update(cx, |state, cx| state.authenticate(cx))
    }

    fn configuration_view(&self, cx: &mut WindowContext) -> AnyView {
        let state = self.state.clone();
        cx.new_view(|cx| ConfigurationView::new(state, cx)).into()
    }

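    // There are no stored credentials to reset; refetching the model list is
    // the closest equivalent.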
    fn reset_credentials(&self, cx: &mut AppContext) -> Task<Result<()>> {
        self.state.update(cx, |state, cx| state.fetch_models(cx))
    }
}

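/// A single chat model served by the configured Ollama instance.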
pub struct OllamaLanguageModel {
    id: LanguageModelId,
    model: ollama::Model,
    http_client: Arc<dyn HttpClient>,
    request_limiter: RateLimiter,
}

impl OllamaLanguageModel {
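    /// Convert a provider-agnostic request into Ollama's chat format,
    /// flattening each message to its plain-text contents.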
    fn to_ollama_request(&self, request: LanguageModelRequest) -> ChatRequest {
        ChatRequest {
            model: self.model.name.clone(),
            messages: request
                .messages
                .into_iter()
                .map(|msg| match msg.role {
                    Role::User => ChatMessage::User {
                        content: msg.string_contents(),
                    },
                    Role::Assistant => ChatMessage::Assistant {
                        content: msg.string_contents(),
                        tool_calls: None,
                    },
                    Role::System => ChatMessage::System {
                        content: msg.string_contents(),
                    },
                })
                .collect(),
            keep_alive: self.model.keep_alive.clone().unwrap_or_default(),
            stream: true,
            options: Some(ChatOptions {
                num_ctx: Some(self.model.max_tokens),
                stop: Some(request.stop),
                temperature: Some(request.temperature),
                ..Default::default()
            }),
            tools: vec![],
        }
    }

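    /// Issue a single non-streaming chat completion (used for tool calls,
    /// where the full response is needed at once).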
    fn request_completion(
        &self,
        request: ChatRequest,
        cx: &AsyncAppContext,
    ) -> BoxFuture<'static, Result<ChatResponseDelta>> {
        let http_client = self.http_client.clone();

        let Ok(api_url) = cx.update(|cx| {
            let settings = &AllLanguageModelSettings::get_global(cx).ollama;
            settings.api_url.clone()
        }) else {
            return futures::future::ready(Err(anyhow!("App state dropped"))).boxed();
        };

        async move { ollama::complete(http_client.as_ref(), &api_url, request).await }.boxed()
    }
}

impl LanguageModel for OllamaLanguageModel {
    fn id(&self) -> LanguageModelId {
        self.id.clone()
    }

    fn name(&self) -> LanguageModelName {
        LanguageModelName::from(self.model.display_name().to_string())
    }

    fn provider_id(&self) -> LanguageModelProviderId {
        LanguageModelProviderId(PROVIDER_ID.into())
    }

    fn provider_name(&self) -> LanguageModelProviderName {
        LanguageModelProviderName(PROVIDER_NAME.into())
    }

    fn telemetry_id(&self) -> String {
        format!("ollama/{}", self.model.id())
    }

    fn max_token_count(&self) -> usize {
        self.model.max_token_count()
    }

    fn count_tokens(
        &self,
        request: LanguageModelRequest,
        _cx: &AppContext,
    ) -> BoxFuture<'static, Result<usize>> {
        // There is no token-counting endpoint in Ollama _yet_; see
        // https://github.com/ollama/ollama/issues/1716 and
        // https://github.com/ollama/ollama/issues/3582.
        // Until then, estimate with the common heuristic of roughly four
        // characters per token (e.g. 400 characters ~= 100 tokens).
        let token_count = request
            .messages
            .iter()
            .map(|msg| msg.string_contents().chars().count())
            .sum::<usize>()
            / 4;

        async move { Ok(token_count) }.boxed()
    }

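    /// Stream a chat completion, surfacing each delta's message content as it
    /// arrives.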
    fn stream_completion(
        &self,
        request: LanguageModelRequest,
        cx: &AsyncAppContext,
    ) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>> {
        let request = self.to_ollama_request(request);

        let http_client = self.http_client.clone();
        let Ok((api_url, low_speed_timeout)) = cx.update(|cx| {
            let settings = &AllLanguageModelSettings::get_global(cx).ollama;
            (settings.api_url.clone(), settings.low_speed_timeout)
        }) else {
            return futures::future::ready(Err(anyhow!("App state dropped"))).boxed();
        };

        let future = self.request_limiter.stream(async move {
            let response =
                stream_chat_completion(http_client.as_ref(), &api_url, request, low_speed_timeout)
                    .await?;
            let stream = response
                .filter_map(|response| async move {
                    match response {
                        Ok(delta) => {
                            let content = match delta.message {
                                ChatMessage::User { content } => content,
                                ChatMessage::Assistant { content, .. } => content,
                                ChatMessage::System { content } => content,
                            };
                            Some(Ok(content))
                        }
                        Err(error) => Some(Err(error)),
                    }
                })
                .boxed();
            Ok(stream)
        });

        async move { Ok(future.await?.boxed()) }.boxed()
    }

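    /// Request a tool call by attaching a single tool to a non-streaming
    /// completion, returning that tool's arguments as a one-item stream;
    /// bails if the model didn't call the tool.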
    fn use_any_tool(
        &self,
        request: LanguageModelRequest,
        tool_name: String,
        tool_description: String,
        schema: serde_json::Value,
        cx: &AsyncAppContext,
    ) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>> {
        use ollama::{OllamaFunctionTool, OllamaTool};

        let function = OllamaFunctionTool {
            name: tool_name.clone(),
            description: Some(tool_description),
            parameters: Some(schema),
        };
        let tools = vec![OllamaTool::Function { function }];
        let request = self.to_ollama_request(request).with_tools(tools);
        let response = self.request_completion(request, cx);
        self.request_limiter
            .run(async move {
                let response = response.await?;
                let ChatMessage::Assistant { tool_calls, .. } = response.message else {
                    bail!("message does not have an assistant role");
                };
                if let Some(tool_calls) = tool_calls.filter(|calls| !calls.is_empty()) {
                    for call in tool_calls {
                        let OllamaToolCall::Function(function) = call;
                        if function.name == tool_name {
                            return Ok(futures::stream::once(async move {
                                Ok(function.arguments.to_string())
                            })
                            .boxed());
                        }
                    }
                } else {
                    bail!("assistant message does not have any tool calls");
                };

                bail!("tool not used")
            })
            .boxed()
    }
}

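/// Settings UI for Ollama: shows download and setup guidance until the server
/// is reachable, then a "Connected" indicator.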
struct ConfigurationView {
    state: gpui::Model<State>,
    loading_models_task: Option<Task<()>>,
}

impl ConfigurationView {
    pub fn new(state: gpui::Model<State>, cx: &mut ViewContext<Self>) -> Self {
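        // Probe the server as soon as the view opens; clearing the task once
        // the probe finishes dismisses the "Loading models..." state.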
        let loading_models_task = Some(cx.spawn({
            let state = state.clone();
            |this, mut cx| async move {
                if let Some(task) = state
                    .update(&mut cx, |state, cx| state.authenticate(cx))
                    .log_err()
                {
                    task.await.log_err();
                }
                this.update(&mut cx, |this, cx| {
                    this.loading_models_task = None;
                    cx.notify();
                })
                .log_err();
            }
        }));

        Self {
            state,
            loading_models_task,
        }
    }

    fn retry_connection(&self, cx: &mut WindowContext) {
        self.state
            .update(cx, |state, cx| state.fetch_models(cx))
            .detach_and_log_err(cx);
    }
}

impl Render for ConfigurationView {
    fn render(&mut self, cx: &mut ViewContext<Self>) -> impl IntoElement {
        let is_authenticated = self.state.read(cx).is_authenticated();

        let ollama_intro = "Get up and running with Llama 3.1, Mistral, Gemma 2, and other large language models with Ollama.";
        let ollama_reqs =
            "Ollama must be running with at least one model installed to use it in the assistant.";

        // Render the inline shell command on a semi-transparent editor
        // background so it reads as code.
        let mut inline_code_bg = cx.theme().colors().editor_background;
        inline_code_bg.fade_out(0.5);

        if self.loading_models_task.is_some() {
            div().child(Label::new("Loading models...")).into_any()
        } else {
            v_flex()
                .size_full()
                .gap_3()
                .child(
                    v_flex()
                        .size_full()
                        .gap_2()
                        .p_1()
                        .child(Label::new(ollama_intro))
                        .child(Label::new(ollama_reqs))
                        .child(
                            h_flex()
                                .gap_0p5()
                                .child(Label::new("Once installed, try "))
                                .child(
                                    div()
                                        .bg(inline_code_bg)
                                        .px_1p5()
                                        .rounded_md()
                                        .child(Label::new("ollama run llama3.1")),
                                ),
                        ),
                )
                .child(
                    h_flex()
                        .w_full()
                        .pt_2()
                        .justify_between()
                        .gap_2()
                        .child(
                            h_flex()
                                .w_full()
                                .gap_2()
                                .map(|this| {
                                    if is_authenticated {
                                        this.child(
                                            Button::new("ollama-site", "Ollama")
                                                .style(ButtonStyle::Subtle)
                                                .icon(IconName::ExternalLink)
                                                .icon_size(IconSize::XSmall)
                                                .icon_color(Color::Muted)
                                                .on_click(move |_, cx| cx.open_url(OLLAMA_SITE))
                                                .into_any_element(),
                                        )
                                    } else {
                                        this.child(
                                            Button::new(
                                                "download_ollama_button",
                                                "Download Ollama",
                                            )
                                            .style(ButtonStyle::Subtle)
                                            .icon(IconName::ExternalLink)
                                            .icon_size(IconSize::XSmall)
                                            .icon_color(Color::Muted)
                                            .on_click(move |_, cx| cx.open_url(OLLAMA_DOWNLOAD_URL))
                                            .into_any_element(),
                                        )
                                    }
                                })
                                .child(
                                    Button::new("view-models", "All Models")
                                        .style(ButtonStyle::Subtle)
                                        .icon(IconName::ExternalLink)
                                        .icon_size(IconSize::XSmall)
                                        .icon_color(Color::Muted)
                                        .on_click(move |_, cx| cx.open_url(OLLAMA_LIBRARY_URL)),
                                ),
                        )
                        .child(if is_authenticated {
                            // This is only a button to ensure the spacing is
                            // correct; it should stay disabled.
                            ButtonLike::new("connected")
                                .disabled(true)
                                // Since this won't ever be clickable, we can
                                // use the arrow cursor.
                                .cursor_style(gpui::CursorStyle::Arrow)
                                .child(
                                    h_flex()
                                        .gap_2()
                                        .child(Indicator::dot().color(Color::Success))
                                        .child(Label::new("Connected"))
                                        .into_any_element(),
                                )
                                .into_any_element()
                        } else {
                            Button::new("retry_ollama_models", "Connect")
                                .icon_position(IconPosition::Start)
                                .icon(IconName::ArrowCircle)
                                .on_click(cx.listener(move |this, _, cx| this.retry_connection(cx)))
                                .into_any_element()
                        }),
                )
                .into_any()
        }
    }
}