ollama.rs

use anyhow::{anyhow, bail, Result};
use futures::{future::BoxFuture, stream::BoxStream, FutureExt, StreamExt};
use gpui::{AnyView, AppContext, AsyncAppContext, ModelContext, Subscription, Task};
use http_client::HttpClient;
use ollama::{
    get_models, preload_model, stream_chat_completion, ChatMessage, ChatOptions, ChatRequest,
    ChatResponseDelta, KeepAlive, OllamaToolCall,
};
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
use settings::{Settings, SettingsStore};
use std::{collections::BTreeMap, sync::Arc, time::Duration};
use ui::{prelude::*, ButtonLike, Indicator};
use util::ResultExt;

use crate::LanguageModelCompletionEvent;
use crate::{
    settings::AllLanguageModelSettings, LanguageModel, LanguageModelId, LanguageModelName,
    LanguageModelProvider, LanguageModelProviderId, LanguageModelProviderName,
    LanguageModelProviderState, LanguageModelRequest, RateLimiter, Role,
};

const OLLAMA_DOWNLOAD_URL: &str = "https://ollama.com/download";
const OLLAMA_LIBRARY_URL: &str = "https://ollama.com/library";
const OLLAMA_SITE: &str = "https://ollama.com/";

const PROVIDER_ID: &str = "ollama";
const PROVIDER_NAME: &str = "Ollama";

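/// Settings for the Ollama provider: where to reach the server, how long to
/// tolerate a slow connection, and any models declared manually in settings.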
#[derive(Default, Debug, Clone, PartialEq)]
pub struct OllamaSettings {
    pub api_url: String,
    pub low_speed_timeout: Option<Duration>,
    pub available_models: Vec<AvailableModel>,
}

#[derive(Clone, Debug, PartialEq, Serialize, Deserialize, JsonSchema)]
pub struct AvailableModel {
    /// The model name in the Ollama API (e.g. "llama3.1:latest")
    pub name: String,
    /// The model's name in Zed's UI, such as in the model selector dropdown menu in the assistant panel.
    pub display_name: Option<String>,
    /// The model's context length parameter (aka num_ctx or n_ctx)
    pub max_tokens: usize,
    /// The number of seconds to keep the connection open after the last request
    pub keep_alive: Option<KeepAlive>,
}

pub struct OllamaLanguageModelProvider {
    http_client: Arc<dyn HttpClient>,
    state: gpui::Model<State>,
}

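/// Shared provider state: the model list fetched from the Ollama server, plus
/// a subscription that refetches it whenever the settings change.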
pub struct State {
    http_client: Arc<dyn HttpClient>,
    available_models: Vec<ollama::Model>,
    _subscription: Subscription,
}

impl State {
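    /// Ollama has no real authentication; a non-empty model list serves as a
    /// proxy for the server being reachable.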
    fn is_authenticated(&self) -> bool {
        !self.available_models.is_empty()
    }

    fn fetch_models(&mut self, cx: &mut ModelContext<Self>) -> Task<Result<()>> {
        let settings = &AllLanguageModelSettings::get_global(cx).ollama;
        let http_client = self.http_client.clone();
        let api_url = settings.api_url.clone();

        // As a proxy for the server being "authenticated", we'll check if it's up by fetching the models
        cx.spawn(|this, mut cx| async move {
            let models = get_models(http_client.as_ref(), &api_url, None).await?;

            let mut models: Vec<ollama::Model> = models
                .into_iter()
                // Since the Ollama API provides no metadata indicating which
                // models are embedding models, simply filter out models with
                // "-embed" in their name
                .filter(|model| !model.name.contains("-embed"))
                .map(|model| ollama::Model::new(&model.name, None, None))
                .collect();

            models.sort_by(|a, b| a.name.cmp(&b.name));

            this.update(&mut cx, |this, cx| {
                this.available_models = models;
                cx.notify();
            })
        })
    }

    fn authenticate(&mut self, cx: &mut ModelContext<Self>) -> Task<Result<()>> {
        if self.is_authenticated() {
            Task::ready(Ok(()))
        } else {
            self.fetch_models(cx)
        }
    }
}

impl OllamaLanguageModelProvider {
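    /// Creates the provider and eagerly fetches the model list; the settings
    /// subscription refetches it whenever the global settings change.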
    pub fn new(http_client: Arc<dyn HttpClient>, cx: &mut AppContext) -> Self {
        let this = Self {
            http_client: http_client.clone(),
            state: cx.new_model(|cx| State {
                http_client,
                available_models: Default::default(),
                _subscription: cx.observe_global::<SettingsStore>(|this: &mut State, cx| {
                    this.fetch_models(cx).detach();
                    cx.notify();
                }),
            }),
        };
        this.state
            .update(cx, |state, cx| state.fetch_models(cx).detach());
        this
    }
}

impl LanguageModelProviderState for OllamaLanguageModelProvider {
    type ObservableEntity = State;

    fn observable_entity(&self) -> Option<gpui::Model<Self::ObservableEntity>> {
        Some(self.state.clone())
    }
}

impl LanguageModelProvider for OllamaLanguageModelProvider {
    fn id(&self) -> LanguageModelProviderId {
        LanguageModelProviderId(PROVIDER_ID.into())
    }

    fn name(&self) -> LanguageModelProviderName {
        LanguageModelProviderName(PROVIDER_NAME.into())
    }

    fn icon(&self) -> IconName {
        IconName::AiOllama
    }

    fn provided_models(&self, cx: &AppContext) -> Vec<Arc<dyn LanguageModel>> {
        let mut models: BTreeMap<String, ollama::Model> = BTreeMap::default();

        // Add models from the Ollama API
        for model in self.state.read(cx).available_models.iter() {
            models.insert(model.name.clone(), model.clone());
        }

        // Override with available models from settings
        for model in AllLanguageModelSettings::get_global(cx)
            .ollama
            .available_models
            .iter()
        {
            models.insert(
                model.name.clone(),
                ollama::Model {
                    name: model.name.clone(),
                    display_name: model.display_name.clone(),
                    max_tokens: model.max_tokens,
                    keep_alive: model.keep_alive.clone(),
                },
            );
        }

        models
            .into_values()
            .map(|model| {
                Arc::new(OllamaLanguageModel {
                    id: LanguageModelId::from(model.name.clone()),
                    model: model.clone(),
                    http_client: self.http_client.clone(),
                    request_limiter: RateLimiter::new(4),
                }) as Arc<dyn LanguageModel>
            })
            .collect()
    }

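    /// Asks the Ollama server to load the model ahead of time (via
    /// `preload_model`), so the first completion doesn't pay the startup cost.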
    fn load_model(&self, model: Arc<dyn LanguageModel>, cx: &AppContext) {
        let settings = &AllLanguageModelSettings::get_global(cx).ollama;
        let http_client = self.http_client.clone();
        let api_url = settings.api_url.clone();
        let id = model.id().0.to_string();
        cx.spawn(|_| async move { preload_model(http_client, &api_url, &id).await })
            .detach_and_log_err(cx);
    }

    fn is_authenticated(&self, cx: &AppContext) -> bool {
        self.state.read(cx).is_authenticated()
    }

    fn authenticate(&self, cx: &mut AppContext) -> Task<Result<()>> {
        self.state.update(cx, |state, cx| state.authenticate(cx))
    }

    fn configuration_view(&self, cx: &mut WindowContext) -> AnyView {
        let state = self.state.clone();
        cx.new_view(|cx| ConfigurationView::new(state, cx)).into()
    }

    fn reset_credentials(&self, cx: &mut AppContext) -> Task<Result<()>> {
        self.state.update(cx, |state, cx| state.fetch_models(cx))
    }
}

pub struct OllamaLanguageModel {
    id: LanguageModelId,
    model: ollama::Model,
    http_client: Arc<dyn HttpClient>,
    request_limiter: RateLimiter,
}

impl OllamaLanguageModel {
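    /// Converts Zed's provider-agnostic request into an Ollama chat request,
    /// mapping each message onto the corresponding Ollama role.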
    fn to_ollama_request(&self, request: LanguageModelRequest) -> ChatRequest {
        ChatRequest {
            model: self.model.name.clone(),
            messages: request
                .messages
                .into_iter()
                .map(|msg| match msg.role {
                    Role::User => ChatMessage::User {
                        content: msg.string_contents(),
                    },
                    Role::Assistant => ChatMessage::Assistant {
                        content: msg.string_contents(),
                        tool_calls: None,
                    },
                    Role::System => ChatMessage::System {
                        content: msg.string_contents(),
                    },
                })
                .collect(),
            keep_alive: self.model.keep_alive.clone().unwrap_or_default(),
            stream: true,
            options: Some(ChatOptions {
                num_ctx: Some(self.model.max_tokens),
                stop: Some(request.stop),
                temperature: request.temperature.or(Some(1.0)),
                ..Default::default()
            }),
            tools: vec![],
        }
    }
    fn request_completion(
        &self,
        request: ChatRequest,
        cx: &AsyncAppContext,
    ) -> BoxFuture<'static, Result<ChatResponseDelta>> {
        let http_client = self.http_client.clone();

        let Ok(api_url) = cx.update(|cx| {
            let settings = &AllLanguageModelSettings::get_global(cx).ollama;
            settings.api_url.clone()
        }) else {
            return futures::future::ready(Err(anyhow!("App state dropped"))).boxed();
        };

        async move { ollama::complete(http_client.as_ref(), &api_url, request).await }.boxed()
    }
}

impl LanguageModel for OllamaLanguageModel {
    fn id(&self) -> LanguageModelId {
        self.id.clone()
    }

    fn name(&self) -> LanguageModelName {
        LanguageModelName::from(self.model.display_name().to_string())
    }

    fn provider_id(&self) -> LanguageModelProviderId {
        LanguageModelProviderId(PROVIDER_ID.into())
    }

    fn provider_name(&self) -> LanguageModelProviderName {
        LanguageModelProviderName(PROVIDER_NAME.into())
    }

    fn telemetry_id(&self) -> String {
        format!("ollama/{}", self.model.id())
    }

    fn max_token_count(&self) -> usize {
        self.model.max_token_count()
    }

    fn count_tokens(
        &self,
        request: LanguageModelRequest,
        _cx: &AppContext,
    ) -> BoxFuture<'static, Result<usize>> {
        // There is no endpoint for this _yet_ in Ollama
        // see: https://github.com/ollama/ollama/issues/1716 and https://github.com/ollama/ollama/issues/3582
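        // Instead, estimate roughly four characters per token as a crude heuristic.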
        let token_count = request
            .messages
            .iter()
            .map(|msg| msg.string_contents().chars().count())
            .sum::<usize>()
            / 4;

        async move { Ok(token_count) }.boxed()
    }

    fn stream_completion(
        &self,
        request: LanguageModelRequest,
        cx: &AsyncAppContext,
    ) -> BoxFuture<'static, Result<BoxStream<'static, Result<LanguageModelCompletionEvent>>>> {
        let request = self.to_ollama_request(request);

        let http_client = self.http_client.clone();
        let Ok((api_url, low_speed_timeout)) = cx.update(|cx| {
            let settings = &AllLanguageModelSettings::get_global(cx).ollama;
            (settings.api_url.clone(), settings.low_speed_timeout)
        }) else {
            return futures::future::ready(Err(anyhow!("App state dropped"))).boxed();
        };

        let future = self.request_limiter.stream(async move {
            let response =
                stream_chat_completion(http_client.as_ref(), &api_url, request, low_speed_timeout)
                    .await?;
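            // Pull the text content out of each streamed delta, whatever its role.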
            let stream = response
                .filter_map(|response| async move {
                    match response {
                        Ok(delta) => {
                            let content = match delta.message {
                                ChatMessage::User { content } => content,
                                ChatMessage::Assistant { content, .. } => content,
                                ChatMessage::System { content } => content,
                            };
                            Some(Ok(content))
                        }
                        Err(error) => Some(Err(error)),
                    }
                })
                .boxed();
            Ok(stream)
        });

        async move {
            Ok(future
                .await?
                .map(|result| result.map(LanguageModelCompletionEvent::Text))
                .boxed())
        }
        .boxed()
    }

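    /// Attaches a single tool to the request and returns the arguments of the
    /// matching tool call as a one-item stream, failing if the model never calls it.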
    fn use_any_tool(
        &self,
        request: LanguageModelRequest,
        tool_name: String,
        tool_description: String,
        schema: serde_json::Value,
        cx: &AsyncAppContext,
    ) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>> {
        use ollama::{OllamaFunctionTool, OllamaTool};
        let function = OllamaFunctionTool {
            name: tool_name.clone(),
            description: Some(tool_description),
            parameters: Some(schema),
        };
        let tools = vec![OllamaTool::Function { function }];
        let request = self.to_ollama_request(request).with_tools(tools);
        let response = self.request_completion(request, cx);
        self.request_limiter
            .run(async move {
                let response = response.await?;
                let ChatMessage::Assistant { tool_calls, .. } = response.message else {
                    bail!("message does not have an assistant role");
                };
                if let Some(tool_calls) = tool_calls.filter(|calls| !calls.is_empty()) {
                    for call in tool_calls {
                        let OllamaToolCall::Function(function) = call;
                        if function.name == tool_name {
                            return Ok(futures::stream::once(async move {
                                Ok(function.arguments.to_string())
                            })
                            .boxed());
                        }
                    }
                } else {
                    bail!("assistant message does not have any tool calls");
                };

                bail!("tool not used")
            })
            .boxed()
    }
}

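/// The provider's configuration UI: shows install instructions while the model
/// list loads, then either a "Connected" indicator or a retry button.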
struct ConfigurationView {
    state: gpui::Model<State>,
    loading_models_task: Option<Task<()>>,
}

impl ConfigurationView {
    pub fn new(state: gpui::Model<State>, cx: &mut ViewContext<Self>) -> Self {
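        // Kick off model discovery and clear the loading indicator once it completes.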
        let loading_models_task = Some(cx.spawn({
            let state = state.clone();
            |this, mut cx| async move {
                if let Some(task) = state
                    .update(&mut cx, |state, cx| state.authenticate(cx))
                    .log_err()
                {
                    task.await.log_err();
                }
                this.update(&mut cx, |this, cx| {
                    this.loading_models_task = None;
                    cx.notify();
                })
                .log_err();
            }
        }));

        Self {
            state,
            loading_models_task,
        }
    }

    fn retry_connection(&self, cx: &mut WindowContext) {
        self.state
            .update(cx, |state, cx| state.fetch_models(cx))
            .detach_and_log_err(cx);
    }
}

impl Render for ConfigurationView {
    fn render(&mut self, cx: &mut ViewContext<Self>) -> impl IntoElement {
        let is_authenticated = self.state.read(cx).is_authenticated();

        let ollama_intro = "Get up and running with Llama 3.1, Mistral, Gemma 2, and other large language models with Ollama.";
        let ollama_reqs =
            "Ollama must be running with at least one model installed to use it in the assistant.";

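        // Fade the editor background so the inline command renders as a subtle code chip.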
        let mut inline_code_bg = cx.theme().colors().editor_background;
        inline_code_bg.fade_out(0.5);

        if self.loading_models_task.is_some() {
            div().child(Label::new("Loading models...")).into_any()
        } else {
            v_flex()
                .size_full()
                .gap_3()
                .child(
                    v_flex()
                        .size_full()
                        .gap_2()
                        .p_1()
                        .child(Label::new(ollama_intro))
                        .child(Label::new(ollama_reqs))
                        .child(
                            h_flex()
                                .gap_0p5()
                                .child(Label::new("Once installed, try "))
                                .child(
                                    div()
                                        .bg(inline_code_bg)
                                        .px_1p5()
                                        .rounded_md()
                                        .child(Label::new("ollama run llama3.1")),
                                ),
                        ),
                )
                .child(
                    h_flex()
                        .w_full()
                        .pt_2()
                        .justify_between()
                        .gap_2()
                        .child(
                            h_flex()
                                .w_full()
                                .gap_2()
                                .map(|this| {
                                    if is_authenticated {
                                        this.child(
                                            Button::new("ollama-site", "Ollama")
                                                .style(ButtonStyle::Subtle)
                                                .icon(IconName::ExternalLink)
                                                .icon_size(IconSize::XSmall)
                                                .icon_color(Color::Muted)
                                                .on_click(move |_, cx| cx.open_url(OLLAMA_SITE))
                                                .into_any_element(),
                                        )
                                    } else {
                                        this.child(
                                            Button::new(
                                                "download_ollama_button",
                                                "Download Ollama",
                                            )
                                            .style(ButtonStyle::Subtle)
                                            .icon(IconName::ExternalLink)
                                            .icon_size(IconSize::XSmall)
                                            .icon_color(Color::Muted)
                                            .on_click(move |_, cx| cx.open_url(OLLAMA_DOWNLOAD_URL))
                                            .into_any_element(),
                                        )
                                    }
                                })
                                .child(
                                    Button::new("view-models", "All Models")
                                        .style(ButtonStyle::Subtle)
                                        .icon(IconName::ExternalLink)
                                        .icon_size(IconSize::XSmall)
                                        .icon_color(Color::Muted)
                                        .on_click(move |_, cx| cx.open_url(OLLAMA_LIBRARY_URL)),
                                ),
                        )
                        .child(if is_authenticated {
                            // This is a button only to keep the spacing consistent;
                            // it should stay disabled.
                            ButtonLike::new("connected")
                                .disabled(true)
                                // Since this won't ever be clickable, we can use the arrow cursor
                                .cursor_style(gpui::CursorStyle::Arrow)
                                .child(
                                    h_flex()
                                        .gap_2()
                                        .child(Indicator::dot().color(Color::Success))
                                        .child(Label::new("Connected"))
                                        .into_any_element(),
                                )
                                .into_any_element()
                        } else {
                            Button::new("retry_ollama_models", "Connect")
                                .icon_position(IconPosition::Start)
                                .icon(IconName::ArrowCircle)
                                .on_click(cx.listener(move |this, _, cx| this.retry_connection(cx)))
                                .into_any_element()
                        }),
                )
                .into_any()
        }
    }
}