ollama.rs

  1use anyhow::{Result, anyhow, bail};
  2use futures::{FutureExt, StreamExt, future::BoxFuture, stream::BoxStream};
  3use gpui::{AnyView, App, AsyncApp, Context, Subscription, Task};
  4use http_client::HttpClient;
  5use language_model::{AuthenticateError, LanguageModelCompletionEvent};
  6use language_model::{
  7    LanguageModel, LanguageModelId, LanguageModelName, LanguageModelProvider,
  8    LanguageModelProviderId, LanguageModelProviderName, LanguageModelProviderState,
  9    LanguageModelRequest, RateLimiter, Role,
 10};
 11use ollama::{
 12    ChatMessage, ChatOptions, ChatRequest, ChatResponseDelta, KeepAlive, OllamaToolCall,
 13    get_models, preload_model, stream_chat_completion,
 14};
 15use schemars::JsonSchema;
 16use serde::{Deserialize, Serialize};
 17use settings::{Settings, SettingsStore};
 18use std::{collections::BTreeMap, sync::Arc};
 19use ui::{ButtonLike, Indicator, prelude::*};
 20use util::ResultExt;
 21
 22use crate::AllLanguageModelSettings;
 23
 24const OLLAMA_DOWNLOAD_URL: &str = "https://ollama.com/download";
 25const OLLAMA_LIBRARY_URL: &str = "https://ollama.com/library";
 26const OLLAMA_SITE: &str = "https://ollama.com/";
 27
 28const PROVIDER_ID: &str = "ollama";
 29const PROVIDER_NAME: &str = "Ollama";
 30
 31#[derive(Default, Debug, Clone, PartialEq)]
 32pub struct OllamaSettings {
 33    pub api_url: String,
 34    pub available_models: Vec<AvailableModel>,
 35}
 36
 37#[derive(Clone, Debug, PartialEq, Serialize, Deserialize, JsonSchema)]
 38pub struct AvailableModel {
 39    /// The model name in the Ollama API (e.g. "llama3.2:latest")
 40    pub name: String,
 41    /// The model's name in Zed's UI, such as in the model selector dropdown menu in the assistant panel.
 42    pub display_name: Option<String>,
 43    /// The Context Length parameter to the model (aka num_ctx or n_ctx)
 44    pub max_tokens: usize,
 45    /// The number of seconds to keep the connection open after the last request
 46    pub keep_alive: Option<KeepAlive>,
 47}
 48
 49pub struct OllamaLanguageModelProvider {
 50    http_client: Arc<dyn HttpClient>,
 51    state: gpui::Entity<State>,
 52}
 53
 54pub struct State {
 55    http_client: Arc<dyn HttpClient>,
 56    available_models: Vec<ollama::Model>,
 57    fetch_model_task: Option<Task<Result<()>>>,
 58    _subscription: Subscription,
 59}
 60
 61impl State {
 62    fn is_authenticated(&self) -> bool {
 63        !self.available_models.is_empty()
 64    }
 65
 66    fn fetch_models(&mut self, cx: &mut Context<Self>) -> Task<Result<()>> {
 67        let settings = &AllLanguageModelSettings::get_global(cx).ollama;
 68        let http_client = self.http_client.clone();
 69        let api_url = settings.api_url.clone();
 70
 71        // As a proxy for the server being "authenticated", we'll check if its up by fetching the models
 72        cx.spawn(async move |this, cx| {
 73            let models = get_models(http_client.as_ref(), &api_url, None).await?;
 74
 75            let mut models: Vec<ollama::Model> = models
 76                .into_iter()
 77                // Since there is no metadata from the Ollama API
 78                // indicating which models are embedding models,
 79                // simply filter out models with "-embed" in their name
 80                .filter(|model| !model.name.contains("-embed"))
 81                .map(|model| ollama::Model::new(&model.name, None, None))
 82                .collect();
 83
 84            models.sort_by(|a, b| a.name.cmp(&b.name));
 85
 86            this.update(cx, |this, cx| {
 87                this.available_models = models;
 88                cx.notify();
 89            })
 90        })
 91    }
 92
 93    fn restart_fetch_models_task(&mut self, cx: &mut Context<Self>) {
 94        let task = self.fetch_models(cx);
 95        self.fetch_model_task.replace(task);
 96    }
 97
 98    fn authenticate(&mut self, cx: &mut Context<Self>) -> Task<Result<(), AuthenticateError>> {
 99        if self.is_authenticated() {
100            return Task::ready(Ok(()));
101        }
102
103        let fetch_models_task = self.fetch_models(cx);
104        cx.spawn(async move |_this, _cx| Ok(fetch_models_task.await?))
105    }
106}
107
108impl OllamaLanguageModelProvider {
109    pub fn new(http_client: Arc<dyn HttpClient>, cx: &mut App) -> Self {
110        let this = Self {
111            http_client: http_client.clone(),
112            state: cx.new(|cx| {
113                let subscription = cx.observe_global::<SettingsStore>({
114                    let mut settings = AllLanguageModelSettings::get_global(cx).ollama.clone();
115                    move |this: &mut State, cx| {
116                        let new_settings = &AllLanguageModelSettings::get_global(cx).ollama;
117                        if &settings != new_settings {
118                            settings = new_settings.clone();
119                            this.restart_fetch_models_task(cx);
120                            cx.notify();
121                        }
122                    }
123                });
124
125                State {
126                    http_client,
127                    available_models: Default::default(),
128                    fetch_model_task: None,
129                    _subscription: subscription,
130                }
131            }),
132        };
133        this.state
134            .update(cx, |state, cx| state.restart_fetch_models_task(cx));
135        this
136    }
137}
138
139impl LanguageModelProviderState for OllamaLanguageModelProvider {
140    type ObservableEntity = State;
141
142    fn observable_entity(&self) -> Option<gpui::Entity<Self::ObservableEntity>> {
143        Some(self.state.clone())
144    }
145}
146
147impl LanguageModelProvider for OllamaLanguageModelProvider {
148    fn id(&self) -> LanguageModelProviderId {
149        LanguageModelProviderId(PROVIDER_ID.into())
150    }
151
152    fn name(&self) -> LanguageModelProviderName {
153        LanguageModelProviderName(PROVIDER_NAME.into())
154    }
155
156    fn icon(&self) -> IconName {
157        IconName::AiOllama
158    }
159
160    fn default_model(&self, cx: &App) -> Option<Arc<dyn LanguageModel>> {
161        self.provided_models(cx).into_iter().next()
162    }
163
164    fn provided_models(&self, cx: &App) -> Vec<Arc<dyn LanguageModel>> {
165        let mut models: BTreeMap<String, ollama::Model> = BTreeMap::default();
166
167        // Add models from the Ollama API
168        for model in self.state.read(cx).available_models.iter() {
169            models.insert(model.name.clone(), model.clone());
170        }
171
172        // Override with available models from settings
173        for model in AllLanguageModelSettings::get_global(cx)
174            .ollama
175            .available_models
176            .iter()
177        {
178            models.insert(
179                model.name.clone(),
180                ollama::Model {
181                    name: model.name.clone(),
182                    display_name: model.display_name.clone(),
183                    max_tokens: model.max_tokens,
184                    keep_alive: model.keep_alive.clone(),
185                },
186            );
187        }
188
189        models
190            .into_values()
191            .map(|model| {
192                Arc::new(OllamaLanguageModel {
193                    id: LanguageModelId::from(model.name.clone()),
194                    model: model.clone(),
195                    http_client: self.http_client.clone(),
196                    request_limiter: RateLimiter::new(4),
197                }) as Arc<dyn LanguageModel>
198            })
199            .collect()
200    }
201
202    fn load_model(&self, model: Arc<dyn LanguageModel>, cx: &App) {
203        let settings = &AllLanguageModelSettings::get_global(cx).ollama;
204        let http_client = self.http_client.clone();
205        let api_url = settings.api_url.clone();
206        let id = model.id().0.to_string();
207        cx.spawn(async move |_| preload_model(http_client, &api_url, &id).await)
208            .detach_and_log_err(cx);
209    }
210
211    fn is_authenticated(&self, cx: &App) -> bool {
212        self.state.read(cx).is_authenticated()
213    }
214
215    fn authenticate(&self, cx: &mut App) -> Task<Result<(), AuthenticateError>> {
216        self.state.update(cx, |state, cx| state.authenticate(cx))
217    }
218
219    fn configuration_view(&self, window: &mut Window, cx: &mut App) -> AnyView {
220        let state = self.state.clone();
221        cx.new(|cx| ConfigurationView::new(state, window, cx))
222            .into()
223    }
224
225    fn reset_credentials(&self, cx: &mut App) -> Task<Result<()>> {
226        self.state.update(cx, |state, cx| state.fetch_models(cx))
227    }
228}
229
230pub struct OllamaLanguageModel {
231    id: LanguageModelId,
232    model: ollama::Model,
233    http_client: Arc<dyn HttpClient>,
234    request_limiter: RateLimiter,
235}
236
237impl OllamaLanguageModel {
238    fn to_ollama_request(&self, request: LanguageModelRequest) -> ChatRequest {
239        ChatRequest {
240            model: self.model.name.clone(),
241            messages: request
242                .messages
243                .into_iter()
244                .map(|msg| match msg.role {
245                    Role::User => ChatMessage::User {
246                        content: msg.string_contents(),
247                    },
248                    Role::Assistant => ChatMessage::Assistant {
249                        content: msg.string_contents(),
250                        tool_calls: None,
251                    },
252                    Role::System => ChatMessage::System {
253                        content: msg.string_contents(),
254                    },
255                })
256                .collect(),
257            keep_alive: self.model.keep_alive.clone().unwrap_or_default(),
258            stream: true,
259            options: Some(ChatOptions {
260                num_ctx: Some(self.model.max_tokens),
261                stop: Some(request.stop),
262                temperature: request.temperature.or(Some(1.0)),
263                ..Default::default()
264            }),
265            tools: vec![],
266        }
267    }
268    fn request_completion(
269        &self,
270        request: ChatRequest,
271        cx: &AsyncApp,
272    ) -> BoxFuture<'static, Result<ChatResponseDelta>> {
273        let http_client = self.http_client.clone();
274
275        let Ok(api_url) = cx.update(|cx| {
276            let settings = &AllLanguageModelSettings::get_global(cx).ollama;
277            settings.api_url.clone()
278        }) else {
279            return futures::future::ready(Err(anyhow!("App state dropped"))).boxed();
280        };
281
282        async move { ollama::complete(http_client.as_ref(), &api_url, request).await }.boxed()
283    }
284}
285
286impl LanguageModel for OllamaLanguageModel {
287    fn id(&self) -> LanguageModelId {
288        self.id.clone()
289    }
290
291    fn name(&self) -> LanguageModelName {
292        LanguageModelName::from(self.model.display_name().to_string())
293    }
294
295    fn provider_id(&self) -> LanguageModelProviderId {
296        LanguageModelProviderId(PROVIDER_ID.into())
297    }
298
299    fn provider_name(&self) -> LanguageModelProviderName {
300        LanguageModelProviderName(PROVIDER_NAME.into())
301    }
302
303    fn supports_tools(&self) -> bool {
304        false
305    }
306
307    fn telemetry_id(&self) -> String {
308        format!("ollama/{}", self.model.id())
309    }
310
311    fn max_token_count(&self) -> usize {
312        self.model.max_token_count()
313    }
314
315    fn count_tokens(
316        &self,
317        request: LanguageModelRequest,
318        _cx: &App,
319    ) -> BoxFuture<'static, Result<usize>> {
320        // There is no endpoint for this _yet_ in Ollama
321        // see: https://github.com/ollama/ollama/issues/1716 and https://github.com/ollama/ollama/issues/3582
322        let token_count = request
323            .messages
324            .iter()
325            .map(|msg| msg.string_contents().chars().count())
326            .sum::<usize>()
327            / 4;
328
329        async move { Ok(token_count) }.boxed()
330    }
331
332    fn stream_completion(
333        &self,
334        request: LanguageModelRequest,
335        cx: &AsyncApp,
336    ) -> BoxFuture<'static, Result<BoxStream<'static, Result<LanguageModelCompletionEvent>>>> {
337        let request = self.to_ollama_request(request);
338
339        let http_client = self.http_client.clone();
340        let Ok(api_url) = cx.update(|cx| {
341            let settings = &AllLanguageModelSettings::get_global(cx).ollama;
342            settings.api_url.clone()
343        }) else {
344            return futures::future::ready(Err(anyhow!("App state dropped"))).boxed();
345        };
346
347        let future = self.request_limiter.stream(async move {
348            let response = stream_chat_completion(http_client.as_ref(), &api_url, request).await?;
349            let stream = response
350                .filter_map(|response| async move {
351                    match response {
352                        Ok(delta) => {
353                            let content = match delta.message {
354                                ChatMessage::User { content } => content,
355                                ChatMessage::Assistant { content, .. } => content,
356                                ChatMessage::System { content } => content,
357                            };
358                            Some(Ok(content))
359                        }
360                        Err(error) => Some(Err(error)),
361                    }
362                })
363                .boxed();
364            Ok(stream)
365        });
366
367        async move {
368            Ok(future
369                .await?
370                .map(|result| result.map(LanguageModelCompletionEvent::Text))
371                .boxed())
372        }
373        .boxed()
374    }
375
376    fn use_any_tool(
377        &self,
378        request: LanguageModelRequest,
379        tool_name: String,
380        tool_description: String,
381        schema: serde_json::Value,
382        cx: &AsyncApp,
383    ) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>> {
384        use ollama::{OllamaFunctionTool, OllamaTool};
385        let function = OllamaFunctionTool {
386            name: tool_name.clone(),
387            description: Some(tool_description),
388            parameters: Some(schema),
389        };
390        let tools = vec![OllamaTool::Function { function }];
391        let request = self.to_ollama_request(request).with_tools(tools);
392        let response = self.request_completion(request, cx);
393        self.request_limiter
394            .run(async move {
395                let response = response.await?;
396                let ChatMessage::Assistant { tool_calls, .. } = response.message else {
397                    bail!("message does not have an assistant role");
398                };
399                if let Some(tool_calls) = tool_calls.filter(|calls| !calls.is_empty()) {
400                    for call in tool_calls {
401                        let OllamaToolCall::Function(function) = call;
402                        if function.name == tool_name {
403                            return Ok(futures::stream::once(async move {
404                                Ok(function.arguments.to_string())
405                            })
406                            .boxed());
407                        }
408                    }
409                } else {
410                    bail!("assistant message does not have any tool calls");
411                };
412
413                bail!("tool not used")
414            })
415            .boxed()
416    }
417}
418
419struct ConfigurationView {
420    state: gpui::Entity<State>,
421    loading_models_task: Option<Task<()>>,
422}
423
424impl ConfigurationView {
425    pub fn new(state: gpui::Entity<State>, window: &mut Window, cx: &mut Context<Self>) -> Self {
426        let loading_models_task = Some(cx.spawn_in(window, {
427            let state = state.clone();
428            async move |this, cx| {
429                if let Some(task) = state
430                    .update(cx, |state, cx| state.authenticate(cx))
431                    .log_err()
432                {
433                    task.await.log_err();
434                }
435                this.update(cx, |this, cx| {
436                    this.loading_models_task = None;
437                    cx.notify();
438                })
439                .log_err();
440            }
441        }));
442
443        Self {
444            state,
445            loading_models_task,
446        }
447    }
448
449    fn retry_connection(&self, cx: &mut App) {
450        self.state
451            .update(cx, |state, cx| state.fetch_models(cx))
452            .detach_and_log_err(cx);
453    }
454}
455
456impl Render for ConfigurationView {
457    fn render(&mut self, _: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
458        let is_authenticated = self.state.read(cx).is_authenticated();
459
460        let ollama_intro = "Get up and running with Llama 3.3, Mistral, Gemma 2, and other large language models with Ollama.";
461        let ollama_reqs =
462            "Ollama must be running with at least one model installed to use it in the assistant.";
463
464        let inline_code_bg = cx.theme().colors().editor_foreground.opacity(0.05);
465
466        if self.loading_models_task.is_some() {
467            div().child(Label::new("Loading models...")).into_any()
468        } else {
469            v_flex()
470                .size_full()
471                .gap_3()
472                .child(
473                    v_flex()
474                        .size_full()
475                        .gap_2()
476                        .p_1()
477                        .child(Label::new(ollama_intro))
478                        .child(Label::new(ollama_reqs))
479                        .child(
480                            h_flex()
481                                .gap_0p5()
482                                .child(Label::new("Once installed, try "))
483                                .child(
484                                    div()
485                                        .bg(inline_code_bg)
486                                        .px_1p5()
487                                        .rounded_sm()
488                                        .child(Label::new("ollama run llama3.2")),
489                                ),
490                        ),
491                )
492                .child(
493                    h_flex()
494                        .w_full()
495                        .pt_2()
496                        .justify_between()
497                        .gap_2()
498                        .child(
499                            h_flex()
500                                .w_full()
501                                .gap_2()
502                                .map(|this| {
503                                    if is_authenticated {
504                                        this.child(
505                                            Button::new("ollama-site", "Ollama")
506                                                .style(ButtonStyle::Subtle)
507                                                .icon(IconName::ArrowUpRight)
508                                                .icon_size(IconSize::XSmall)
509                                                .icon_color(Color::Muted)
510                                                .on_click(move |_, _, cx| cx.open_url(OLLAMA_SITE))
511                                                .into_any_element(),
512                                        )
513                                    } else {
514                                        this.child(
515                                            Button::new(
516                                                "download_ollama_button",
517                                                "Download Ollama",
518                                            )
519                                            .style(ButtonStyle::Subtle)
520                                            .icon(IconName::ArrowUpRight)
521                                            .icon_size(IconSize::XSmall)
522                                            .icon_color(Color::Muted)
523                                            .on_click(move |_, _, cx| {
524                                                cx.open_url(OLLAMA_DOWNLOAD_URL)
525                                            })
526                                            .into_any_element(),
527                                        )
528                                    }
529                                })
530                                .child(
531                                    Button::new("view-models", "All Models")
532                                        .style(ButtonStyle::Subtle)
533                                        .icon(IconName::ArrowUpRight)
534                                        .icon_size(IconSize::XSmall)
535                                        .icon_color(Color::Muted)
536                                        .on_click(move |_, _, cx| cx.open_url(OLLAMA_LIBRARY_URL)),
537                                ),
538                        )
539                        .child(if is_authenticated {
540                            // This is only a button to ensure the spacing is correct
541                            // it should stay disabled
542                            ButtonLike::new("connected")
543                                .disabled(true)
544                                // Since this won't ever be clickable, we can use the arrow cursor
545                                .cursor_style(gpui::CursorStyle::Arrow)
546                                .child(
547                                    h_flex()
548                                        .gap_2()
549                                        .child(Indicator::dot().color(Color::Success))
550                                        .child(Label::new("Connected"))
551                                        .into_any_element(),
552                                )
553                                .into_any_element()
554                        } else {
555                            Button::new("retry_ollama_models", "Connect")
556                                .icon_position(IconPosition::Start)
557                                .icon(IconName::ArrowCircle)
558                                .on_click(
559                                    cx.listener(move |this, _, _, cx| this.retry_connection(cx)),
560                                )
561                                .into_any_element()
562                        }),
563                )
564                .into_any()
565        }
566    }
567}