use anyhow::{anyhow, bail, Result};
use futures::{future::BoxFuture, stream::BoxStream, FutureExt, StreamExt};
use gpui::{AnyView, AppContext, AsyncAppContext, ModelContext, Subscription, Task};
use http_client::HttpClient;
use ollama::{
    get_models, preload_model, stream_chat_completion, ChatMessage, ChatOptions, ChatRequest,
    ChatResponseDelta, OllamaToolCall,
};
use serde_json::Value;
use settings::{Settings, SettingsStore};
use std::{sync::Arc, time::Duration};
use ui::{prelude::*, ButtonLike, Indicator};
use util::ResultExt;

use crate::{
    settings::AllLanguageModelSettings, LanguageModel, LanguageModelId, LanguageModelName,
    LanguageModelProvider, LanguageModelProviderId, LanguageModelProviderName,
    LanguageModelProviderState, LanguageModelRequest, RateLimiter, Role,
};

const OLLAMA_DOWNLOAD_URL: &str = "https://ollama.com/download";
const OLLAMA_LIBRARY_URL: &str = "https://ollama.com/library";
const OLLAMA_SITE: &str = "https://ollama.com/";

const PROVIDER_ID: &str = "ollama";
const PROVIDER_NAME: &str = "Ollama";
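
/// Settings for the Ollama provider: the URL of the local Ollama server and
/// an optional low-speed timeout applied to streaming requests.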
#[derive(Default, Debug, Clone, PartialEq)]
pub struct OllamaSettings {
    pub api_url: String,
    pub low_speed_timeout: Option<Duration>,
}

pub struct OllamaLanguageModelProvider {
    http_client: Arc<dyn HttpClient>,
    state: gpui::Model<State>,
}
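
/// Shared provider state. The fetched model list doubles as the
/// "authenticated" signal: a non-empty list means the server was reachable.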
pub struct State {
    http_client: Arc<dyn HttpClient>,
    available_models: Vec<ollama::Model>,
    _subscription: Subscription,
}

impl State {
    fn is_authenticated(&self) -> bool {
        !self.available_models.is_empty()
    }
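
    /// Fetches the list of available models from the Ollama server,
    /// filtering out embedding models by name since the API does not
    /// mark them as such.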
    fn fetch_models(&mut self, cx: &mut ModelContext<Self>) -> Task<Result<()>> {
        let settings = &AllLanguageModelSettings::get_global(cx).ollama;
        let http_client = self.http_client.clone();
        let api_url = settings.api_url.clone();

        // As a proxy for the server being "authenticated", check if it's up
        // by fetching the models.
        cx.spawn(|this, mut cx| async move {
            let models = get_models(http_client.as_ref(), &api_url, None).await?;

            let mut models: Vec<ollama::Model> = models
                .into_iter()
                // Since there is no metadata from the Ollama API indicating
                // which models are embedding models, simply filter out models
                // with "-embed" in their name.
                .filter(|model| !model.name.contains("-embed"))
                .map(|model| ollama::Model::new(&model.name))
                .collect();

            models.sort_by(|a, b| a.name.cmp(&b.name));

            this.update(&mut cx, |this, cx| {
                this.available_models = models;
                cx.notify();
            })
        })
    }

    fn authenticate(&mut self, cx: &mut ModelContext<Self>) -> Task<Result<()>> {
        if self.is_authenticated() {
            Task::ready(Ok(()))
        } else {
            self.fetch_models(cx)
        }
    }
}

impl OllamaLanguageModelProvider {
    pub fn new(http_client: Arc<dyn HttpClient>, cx: &mut AppContext) -> Self {
        let this = Self {
            http_client: http_client.clone(),
            state: cx.new_model(|cx| State {
                http_client,
                available_models: Default::default(),
                _subscription: cx.observe_global::<SettingsStore>(|this: &mut State, cx| {
                    this.fetch_models(cx).detach();
                    cx.notify();
                }),
            }),
        };
        this.state
            .update(cx, |state, cx| state.fetch_models(cx).detach());
        this
    }
}

impl LanguageModelProviderState for OllamaLanguageModelProvider {
    type ObservableEntity = State;

    fn observable_entity(&self) -> Option<gpui::Model<Self::ObservableEntity>> {
        Some(self.state.clone())
    }
}

impl LanguageModelProvider for OllamaLanguageModelProvider {
    fn id(&self) -> LanguageModelProviderId {
        LanguageModelProviderId(PROVIDER_ID.into())
    }

    fn name(&self) -> LanguageModelProviderName {
        LanguageModelProviderName(PROVIDER_NAME.into())
    }

    fn icon(&self) -> IconName {
        IconName::AiOllama
    }

    fn provided_models(&self, cx: &AppContext) -> Vec<Arc<dyn LanguageModel>> {
        self.state
            .read(cx)
            .available_models
            .iter()
            .map(|model| {
                Arc::new(OllamaLanguageModel {
                    id: LanguageModelId::from(model.name.clone()),
                    model: model.clone(),
                    http_client: self.http_client.clone(),
                    request_limiter: RateLimiter::new(4),
                }) as Arc<dyn LanguageModel>
            })
            .collect()
    }

    fn load_model(&self, model: Arc<dyn LanguageModel>, cx: &AppContext) {
        let settings = &AllLanguageModelSettings::get_global(cx).ollama;
        let http_client = self.http_client.clone();
        let api_url = settings.api_url.clone();
        let id = model.id().0.to_string();
        cx.spawn(|_| async move { preload_model(http_client, &api_url, &id).await })
            .detach_and_log_err(cx);
    }

    fn is_authenticated(&self, cx: &AppContext) -> bool {
        self.state.read(cx).is_authenticated()
    }

    fn authenticate(&self, cx: &mut AppContext) -> Task<Result<()>> {
        self.state.update(cx, |state, cx| state.authenticate(cx))
    }

    fn configuration_view(&self, cx: &mut WindowContext) -> AnyView {
        let state = self.state.clone();
        cx.new_view(|cx| ConfigurationView::new(state, cx)).into()
    }

    fn reset_credentials(&self, cx: &mut AppContext) -> Task<Result<()>> {
        self.state.update(cx, |state, cx| state.fetch_models(cx))
    }
}
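
/// A single Ollama chat model, exposed to the assistant through the
/// `LanguageModel` trait.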
pub struct OllamaLanguageModel {
    id: LanguageModelId,
    model: ollama::Model,
    http_client: Arc<dyn HttpClient>,
    request_limiter: RateLimiter,
}

impl OllamaLanguageModel {
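    /// Converts a `LanguageModelRequest` into an Ollama `ChatRequest`,
    /// mapping roles one-to-one and requesting a streamed response.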
    fn to_ollama_request(&self, request: LanguageModelRequest) -> ChatRequest {
        ChatRequest {
            model: self.model.name.clone(),
            messages: request
                .messages
                .into_iter()
                .map(|msg| match msg.role {
                    Role::User => ChatMessage::User {
                        content: msg.content,
                    },
                    Role::Assistant => ChatMessage::Assistant {
                        content: msg.content,
                        tool_calls: None,
                    },
                    Role::System => ChatMessage::System {
                        content: msg.content,
                    },
                })
                .collect(),
            keep_alive: self.model.keep_alive.clone().unwrap_or_default(),
            stream: true,
            options: Some(ChatOptions {
                num_ctx: Some(self.model.max_tokens),
                stop: Some(request.stop),
                temperature: Some(request.temperature),
                ..Default::default()
            }),
            tools: vec![],
        }
    }

    fn request_completion(
        &self,
        request: ChatRequest,
        cx: &AsyncAppContext,
    ) -> BoxFuture<'static, Result<ChatResponseDelta>> {
        let http_client = self.http_client.clone();

        let Ok(api_url) = cx.update(|cx| {
            let settings = &AllLanguageModelSettings::get_global(cx).ollama;
            settings.api_url.clone()
        }) else {
            return futures::future::ready(Err(anyhow!("App state dropped"))).boxed();
        };

        async move { ollama::complete(http_client.as_ref(), &api_url, request).await }.boxed()
    }
}

impl LanguageModel for OllamaLanguageModel {
    fn id(&self) -> LanguageModelId {
        self.id.clone()
    }

    fn name(&self) -> LanguageModelName {
        LanguageModelName::from(self.model.display_name().to_string())
    }

    fn provider_id(&self) -> LanguageModelProviderId {
        LanguageModelProviderId(PROVIDER_ID.into())
    }

    fn provider_name(&self) -> LanguageModelProviderName {
        LanguageModelProviderName(PROVIDER_NAME.into())
    }

    fn telemetry_id(&self) -> String {
        format!("ollama/{}", self.model.id())
    }

    fn max_token_count(&self) -> usize {
        self.model.max_token_count()
    }
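
    /// Estimates token usage; Ollama exposes no tokenizer endpoint, so this
    /// approximates by assuming four characters per token.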
    fn count_tokens(
        &self,
        request: LanguageModelRequest,
        _cx: &AppContext,
    ) -> BoxFuture<'static, Result<usize>> {
        // There is no endpoint for this _yet_ in Ollama.
        // See: https://github.com/ollama/ollama/issues/1716 and https://github.com/ollama/ollama/issues/3582
        let token_count = request
            .messages
            .iter()
            .map(|msg| msg.content.chars().count())
            .sum::<usize>()
            / 4;

        async move { Ok(token_count) }.boxed()
    }

    fn stream_completion(
        &self,
        request: LanguageModelRequest,
        cx: &AsyncAppContext,
    ) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>> {
        let request = self.to_ollama_request(request);

        let http_client = self.http_client.clone();
        let Ok((api_url, low_speed_timeout)) = cx.update(|cx| {
            let settings = &AllLanguageModelSettings::get_global(cx).ollama;
            (settings.api_url.clone(), settings.low_speed_timeout)
        }) else {
            return futures::future::ready(Err(anyhow!("App state dropped"))).boxed();
        };

        let future = self.request_limiter.stream(async move {
            let response =
                stream_chat_completion(http_client.as_ref(), &api_url, request, low_speed_timeout)
                    .await?;
            let stream = response
                .filter_map(|response| async move {
                    match response {
                        Ok(delta) => {
                            let content = match delta.message {
                                ChatMessage::User { content } => content,
                                ChatMessage::Assistant { content, .. } => content,
                                ChatMessage::System { content } => content,
                            };
                            Some(Ok(content))
                        }
                        Err(error) => Some(Err(error)),
                    }
                })
                .boxed();
            Ok(stream)
        });

        async move { Ok(future.await?.boxed()) }.boxed()
    }
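
    /// Sends a completion request with a single tool attached and extracts
    /// that tool's arguments from the response, falling back to parsing the
    /// message content as JSON.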
    fn use_any_tool(
        &self,
        request: LanguageModelRequest,
        tool_name: String,
        tool_description: String,
        schema: serde_json::Value,
        cx: &AsyncAppContext,
    ) -> BoxFuture<'static, Result<serde_json::Value>> {
        use ollama::{OllamaFunctionTool, OllamaTool};
        let function = OllamaFunctionTool {
            name: tool_name.clone(),
            description: Some(tool_description),
            parameters: Some(schema),
        };
        let tools = vec![OllamaTool::Function { function }];
        let request = self.to_ollama_request(request).with_tools(tools);
        let response = self.request_completion(request, cx);
        self.request_limiter
            .run(async move {
                let response = response.await?;
                let ChatMessage::Assistant {
                    tool_calls,
                    content,
                } = response.message
                else {
                    bail!("message does not have an assistant role");
                };
                if let Some(tool_calls) = tool_calls.filter(|calls| !calls.is_empty()) {
                    for call in tool_calls {
                        let OllamaToolCall::Function(function) = call;
                        if function.name == tool_name {
                            return Ok(function.arguments);
                        }
                    }
                } else if let Ok(args) = serde_json::from_str::<Value>(&content) {
                    // No structured tool call was returned; parse the message
                    // content itself as the arguments.
                    return Ok(args);
                } else {
                    bail!("assistant message does not have any tool calls");
                };

                bail!("tool not used")
            })
            .boxed()
    }
}
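
/// Configuration UI for the Ollama provider: starts fetching models on
/// creation and renders either a loading state or the setup controls.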
struct ConfigurationView {
    state: gpui::Model<State>,
    loading_models_task: Option<Task<()>>,
}

impl ConfigurationView {
    pub fn new(state: gpui::Model<State>, cx: &mut ViewContext<Self>) -> Self {
        let loading_models_task = Some(cx.spawn({
            let state = state.clone();
            |this, mut cx| async move {
                if let Some(task) = state
                    .update(&mut cx, |state, cx| state.authenticate(cx))
                    .log_err()
                {
                    task.await.log_err();
                }
                this.update(&mut cx, |this, cx| {
                    this.loading_models_task = None;
                    cx.notify();
                })
                .log_err();
            }
        }));

        Self {
            state,
            loading_models_task,
        }
    }

    fn retry_connection(&self, cx: &mut WindowContext) {
        self.state
            .update(cx, |state, cx| state.fetch_models(cx))
            .detach_and_log_err(cx);
    }
}

impl Render for ConfigurationView {
    fn render(&mut self, cx: &mut ViewContext<Self>) -> impl IntoElement {
        let is_authenticated = self.state.read(cx).is_authenticated();

        let ollama_intro = "Get up and running with Llama 3.1, Mistral, Gemma 2, and other large language models with Ollama.";
        let ollama_reqs =
            "Ollama must be running with at least one model installed to use it in the assistant.";

        let mut inline_code_bg = cx.theme().colors().editor_background;
        inline_code_bg.fade_out(0.5);

        if self.loading_models_task.is_some() {
            div().child(Label::new("Loading models...")).into_any()
        } else {
            v_flex()
                .size_full()
                .gap_3()
                .child(
                    v_flex()
                        .size_full()
                        .gap_2()
                        .p_1()
                        .child(Label::new(ollama_intro))
                        .child(Label::new(ollama_reqs))
                        .child(
                            h_flex()
                                .gap_0p5()
                                .child(Label::new("Once installed, try "))
                                .child(
                                    div()
                                        .bg(inline_code_bg)
                                        .px_1p5()
                                        .rounded_md()
                                        .child(Label::new("ollama run llama3.1")),
                                ),
                        ),
                )
                .child(
                    h_flex()
                        .w_full()
                        .pt_2()
                        .justify_between()
                        .gap_2()
                        .child(
                            h_flex()
                                .w_full()
                                .gap_2()
                                .map(|this| {
                                    if is_authenticated {
                                        this.child(
                                            Button::new("ollama-site", "Ollama")
                                                .style(ButtonStyle::Subtle)
                                                .icon(IconName::ExternalLink)
                                                .icon_size(IconSize::XSmall)
                                                .icon_color(Color::Muted)
                                                .on_click(move |_, cx| cx.open_url(OLLAMA_SITE))
                                                .into_any_element(),
                                        )
                                    } else {
                                        this.child(
                                            Button::new(
                                                "download_ollama_button",
                                                "Download Ollama",
                                            )
                                            .style(ButtonStyle::Subtle)
                                            .icon(IconName::ExternalLink)
                                            .icon_size(IconSize::XSmall)
                                            .icon_color(Color::Muted)
                                            .on_click(move |_, cx| cx.open_url(OLLAMA_DOWNLOAD_URL))
                                            .into_any_element(),
                                        )
                                    }
                                })
                                .child(
                                    Button::new("view-models", "All Models")
                                        .style(ButtonStyle::Subtle)
                                        .icon(IconName::ExternalLink)
                                        .icon_size(IconSize::XSmall)
                                        .icon_color(Color::Muted)
                                        .on_click(move |_, cx| cx.open_url(OLLAMA_LIBRARY_URL)),
                                ),
                        )
                        .child(if is_authenticated {
                            // This is only a button to keep the spacing consistent;
                            // it should stay disabled.
                            ButtonLike::new("connected")
                                .disabled(true)
                                // Since this will never be clickable, use the arrow cursor.
                                .cursor_style(gpui::CursorStyle::Arrow)
                                .child(
                                    h_flex()
                                        .gap_2()
                                        .child(Indicator::dot().color(Color::Success))
                                        .child(Label::new("Connected"))
                                        .into_any_element(),
                                )
                                .into_any_element()
                        } else {
                            Button::new("retry_ollama_models", "Connect")
                                .icon_position(IconPosition::Start)
                                .icon(IconName::ArrowCircle)
                                .on_click(cx.listener(move |this, _, cx| this.retry_connection(cx)))
                                .into_any_element()
                        }),
                )
                .into_any()
        }
    }
}