ollama.rs

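//! Ollama language model provider for Zed: discovers the models served by a
//! local Ollama instance, streams chat completions from them, and renders the
//! provider's configuration view in the assistant panel.
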
use anyhow::{anyhow, bail, Result};
use futures::{future::BoxFuture, stream::BoxStream, FutureExt, StreamExt};
use gpui::{AnyView, AppContext, AsyncAppContext, ModelContext, Subscription, Task};
use http_client::HttpClient;
use language_model::{
    LanguageModel, LanguageModelCompletionEvent, LanguageModelId, LanguageModelName,
    LanguageModelProvider, LanguageModelProviderId, LanguageModelProviderName,
    LanguageModelProviderState, LanguageModelRequest, RateLimiter, Role,
};
use ollama::{
    get_models, preload_model, stream_chat_completion, ChatMessage, ChatOptions, ChatRequest,
    ChatResponseDelta, KeepAlive, OllamaToolCall,
};
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
use settings::{Settings, SettingsStore};
use std::{collections::BTreeMap, sync::Arc};
use ui::{prelude::*, ButtonLike, Indicator};
use util::ResultExt;

use crate::AllLanguageModelSettings;

const OLLAMA_DOWNLOAD_URL: &str = "https://ollama.com/download";
const OLLAMA_LIBRARY_URL: &str = "https://ollama.com/library";
const OLLAMA_SITE: &str = "https://ollama.com/";

const PROVIDER_ID: &str = "ollama";
const PROVIDER_NAME: &str = "Ollama";
#[derive(Default, Debug, Clone, PartialEq)]
pub struct OllamaSettings {
    pub api_url: String,
    pub available_models: Vec<AvailableModel>,
}

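/// A model declared manually in Zed's settings; these entries are merged with,
/// and take precedence over, the models reported by the Ollama API.
///
/// A minimal sketch of the corresponding `settings.json` entry. The exact
/// nesting and the serialized form of `keep_alive` depend on Zed's settings
/// schema and on `KeepAlive`'s serde representation, so treat the values below
/// as illustrative:
///
/// ```json
/// {
///   "language_models": {
///     "ollama": {
///       "available_models": [
///         {
///           "name": "llama3.2:latest",
///           "display_name": "Llama 3.2",
///           "max_tokens": 8192,
///           "keep_alive": "10m"
///         }
///       ]
///     }
///   }
/// }
/// ```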
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize, JsonSchema)]
pub struct AvailableModel {
    /// The model name in the Ollama API (e.g. "llama3.2:latest")
    pub name: String,
    /// The model's name as shown in Zed's UI, such as the model selector dropdown in the assistant panel.
    pub display_name: Option<String>,
    /// The model's context length, passed to Ollama as `num_ctx` (also known as `n_ctx`)
    pub max_tokens: usize,
    /// How long to keep the model loaded in memory after the last request (Ollama's `keep_alive` parameter)
    pub keep_alive: Option<KeepAlive>,
}

pub struct OllamaLanguageModelProvider {
    http_client: Arc<dyn HttpClient>,
    state: gpui::Model<State>,
}

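/// Connection state for the provider: the models most recently fetched from
/// the Ollama server, plus a settings subscription that refetches them
/// whenever the Ollama settings change.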
pub struct State {
    http_client: Arc<dyn HttpClient>,
    available_models: Vec<ollama::Model>,
    fetch_model_task: Option<Task<Result<()>>>,
    _subscription: Subscription,
}

impl State {
    fn is_authenticated(&self) -> bool {
        !self.available_models.is_empty()
    }

    fn fetch_models(&mut self, cx: &mut ModelContext<Self>) -> Task<Result<()>> {
        let settings = &AllLanguageModelSettings::get_global(cx).ollama;
        let http_client = self.http_client.clone();
        let api_url = settings.api_url.clone();

        // As a proxy for the server being "authenticated", we check whether it's up by fetching the models
        cx.spawn(|this, mut cx| async move {
            let models = get_models(http_client.as_ref(), &api_url, None).await?;

            let mut models: Vec<ollama::Model> = models
                .into_iter()
                // Since there is no metadata from the Ollama API
                // indicating which models are embedding models,
                // simply filter out models with "-embed" in their name
                .filter(|model| !model.name.contains("-embed"))
                .map(|model| ollama::Model::new(&model.name, None, None))
                .collect();

            models.sort_by(|a, b| a.name.cmp(&b.name));

            this.update(&mut cx, |this, cx| {
                this.available_models = models;
                cx.notify();
            })
        })
    }

    fn restart_fetch_models_task(&mut self, cx: &mut ModelContext<Self>) {
        let task = self.fetch_models(cx);
        self.fetch_model_task.replace(task);
    }

    fn authenticate(&mut self, cx: &mut ModelContext<Self>) -> Task<Result<()>> {
        if self.is_authenticated() {
            Task::ready(Ok(()))
        } else {
            self.fetch_models(cx)
        }
    }
}

impl OllamaLanguageModelProvider {
    pub fn new(http_client: Arc<dyn HttpClient>, cx: &mut AppContext) -> Self {
        let this = Self {
            http_client: http_client.clone(),
            state: cx.new_model(|cx| {
                let subscription = cx.observe_global::<SettingsStore>({
                    let mut settings = AllLanguageModelSettings::get_global(cx).ollama.clone();
                    move |this: &mut State, cx| {
                        let new_settings = &AllLanguageModelSettings::get_global(cx).ollama;
                        if &settings != new_settings {
                            settings = new_settings.clone();
                            this.restart_fetch_models_task(cx);
                            cx.notify();
                        }
                    }
                });

                State {
                    http_client,
                    available_models: Default::default(),
                    fetch_model_task: None,
                    _subscription: subscription,
                }
            }),
        };
        this.state
            .update(cx, |state, cx| state.restart_fetch_models_task(cx));
        this
    }
}

impl LanguageModelProviderState for OllamaLanguageModelProvider {
    type ObservableEntity = State;

    fn observable_entity(&self) -> Option<gpui::Model<Self::ObservableEntity>> {
        Some(self.state.clone())
    }
}

impl LanguageModelProvider for OllamaLanguageModelProvider {
    fn id(&self) -> LanguageModelProviderId {
        LanguageModelProviderId(PROVIDER_ID.into())
    }

    fn name(&self) -> LanguageModelProviderName {
        LanguageModelProviderName(PROVIDER_NAME.into())
    }

    fn icon(&self) -> IconName {
        IconName::AiOllama
    }

    fn provided_models(&self, cx: &AppContext) -> Vec<Arc<dyn LanguageModel>> {
        let mut models: BTreeMap<String, ollama::Model> = BTreeMap::default();

        // Add models from the Ollama API
        for model in self.state.read(cx).available_models.iter() {
            models.insert(model.name.clone(), model.clone());
        }

        // Override with available models from settings
        for model in AllLanguageModelSettings::get_global(cx)
            .ollama
            .available_models
            .iter()
        {
            models.insert(
                model.name.clone(),
                ollama::Model {
                    name: model.name.clone(),
                    display_name: model.display_name.clone(),
                    max_tokens: model.max_tokens,
                    keep_alive: model.keep_alive.clone(),
                },
            );
        }

        models
            .into_values()
            .map(|model| {
                Arc::new(OllamaLanguageModel {
                    id: LanguageModelId::from(model.name.clone()),
                    model: model.clone(),
                    http_client: self.http_client.clone(),
                    request_limiter: RateLimiter::new(4),
                }) as Arc<dyn LanguageModel>
            })
            .collect()
    }

    fn load_model(&self, model: Arc<dyn LanguageModel>, cx: &AppContext) {
        let settings = &AllLanguageModelSettings::get_global(cx).ollama;
        let http_client = self.http_client.clone();
        let api_url = settings.api_url.clone();
        let id = model.id().0.to_string();
        cx.spawn(|_| async move { preload_model(http_client, &api_url, &id).await })
            .detach_and_log_err(cx);
    }

    fn is_authenticated(&self, cx: &AppContext) -> bool {
        self.state.read(cx).is_authenticated()
    }

    fn authenticate(&self, cx: &mut AppContext) -> Task<Result<()>> {
        self.state.update(cx, |state, cx| state.authenticate(cx))
    }

    fn configuration_view(&self, cx: &mut WindowContext) -> AnyView {
        let state = self.state.clone();
        cx.new_view(|cx| ConfigurationView::new(state, cx)).into()
    }

    fn reset_credentials(&self, cx: &mut AppContext) -> Task<Result<()>> {
        self.state.update(cx, |state, cx| state.fetch_models(cx))
    }
}

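/// A single Ollama model exposed to Zed as a `LanguageModel`. All requests
/// pass through `request_limiter`, which caps the number of in-flight
/// requests to the local server.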
pub struct OllamaLanguageModel {
    id: LanguageModelId,
    model: ollama::Model,
    http_client: Arc<dyn HttpClient>,
    request_limiter: RateLimiter,
}

impl OllamaLanguageModel {
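    /// Converts Zed's provider-agnostic request into an Ollama `ChatRequest`,
    /// flattening each message to plain text and applying the model's
    /// configured context window and keep-alive settings.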
    fn to_ollama_request(&self, request: LanguageModelRequest) -> ChatRequest {
        ChatRequest {
            model: self.model.name.clone(),
            messages: request
                .messages
                .into_iter()
                .map(|msg| match msg.role {
                    Role::User => ChatMessage::User {
                        content: msg.string_contents(),
                    },
                    Role::Assistant => ChatMessage::Assistant {
                        content: msg.string_contents(),
                        tool_calls: None,
                    },
                    Role::System => ChatMessage::System {
                        content: msg.string_contents(),
                    },
                })
                .collect(),
            keep_alive: self.model.keep_alive.clone().unwrap_or_default(),
            stream: true,
            options: Some(ChatOptions {
                num_ctx: Some(self.model.max_tokens),
                stop: Some(request.stop),
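                // Fall back to a temperature of 1.0 when the request doesn't
                // specify one.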
                temperature: request.temperature.or(Some(1.0)),
                ..Default::default()
            }),
            tools: vec![],
        }
    }
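
    /// Requests a single chat completion and resolves with the full response;
    /// used by `use_any_tool`, which needs the complete message to inspect any
    /// returned tool calls.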
    fn request_completion(
        &self,
        request: ChatRequest,
        cx: &AsyncAppContext,
    ) -> BoxFuture<'static, Result<ChatResponseDelta>> {
        let http_client = self.http_client.clone();

        let Ok(api_url) = cx.update(|cx| {
            let settings = &AllLanguageModelSettings::get_global(cx).ollama;
            settings.api_url.clone()
        }) else {
            return futures::future::ready(Err(anyhow!("App state dropped"))).boxed();
        };

        async move { ollama::complete(http_client.as_ref(), &api_url, request).await }.boxed()
    }
}

impl LanguageModel for OllamaLanguageModel {
    fn id(&self) -> LanguageModelId {
        self.id.clone()
    }

    fn name(&self) -> LanguageModelName {
        LanguageModelName::from(self.model.display_name().to_string())
    }

    fn provider_id(&self) -> LanguageModelProviderId {
        LanguageModelProviderId(PROVIDER_ID.into())
    }

    fn provider_name(&self) -> LanguageModelProviderName {
        LanguageModelProviderName(PROVIDER_NAME.into())
    }

    fn telemetry_id(&self) -> String {
        format!("ollama/{}", self.model.id())
    }

    fn max_token_count(&self) -> usize {
        self.model.max_token_count()
    }

    fn count_tokens(
        &self,
        request: LanguageModelRequest,
        _cx: &AppContext,
    ) -> BoxFuture<'static, Result<usize>> {
        // There is no endpoint for this _yet_ in Ollama
        // see: https://github.com/ollama/ollama/issues/1716 and https://github.com/ollama/ollama/issues/3582
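        // As a stopgap, estimate roughly four characters per token, a common
        // rule of thumb for English text.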
        let token_count = request
            .messages
            .iter()
            .map(|msg| msg.string_contents().chars().count())
            .sum::<usize>()
            / 4;

        async move { Ok(token_count) }.boxed()
    }

    fn stream_completion(
        &self,
        request: LanguageModelRequest,
        cx: &AsyncAppContext,
    ) -> BoxFuture<'static, Result<BoxStream<'static, Result<LanguageModelCompletionEvent>>>> {
        let request = self.to_ollama_request(request);

        let http_client = self.http_client.clone();
        let Ok(api_url) = cx.update(|cx| {
            let settings = &AllLanguageModelSettings::get_global(cx).ollama;
            settings.api_url.clone()
        }) else {
            return futures::future::ready(Err(anyhow!("App state dropped"))).boxed();
        };

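        // Rate-limit the request, then flatten the stream of response deltas
        // into a stream of plain text chunks, taking the content from
        // whichever message role each delta carries.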
        let future = self.request_limiter.stream(async move {
            let response = stream_chat_completion(http_client.as_ref(), &api_url, request).await?;
            let stream = response
                .filter_map(|response| async move {
                    match response {
                        Ok(delta) => {
                            let content = match delta.message {
                                ChatMessage::User { content } => content,
                                ChatMessage::Assistant { content, .. } => content,
                                ChatMessage::System { content } => content,
                            };
                            Some(Ok(content))
                        }
                        Err(error) => Some(Err(error)),
                    }
                })
                .boxed();
            Ok(stream)
        });

        async move {
            Ok(future
                .await?
                .map(|result| result.map(LanguageModelCompletionEvent::Text))
                .boxed())
        }
        .boxed()
    }

    fn use_any_tool(
        &self,
        request: LanguageModelRequest,
        tool_name: String,
        tool_description: String,
        schema: serde_json::Value,
        cx: &AsyncAppContext,
    ) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>> {
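        // Ollama accepts function-style tools, so wrap the provided JSON
        // schema in a single function definition and look for a matching call
        // in the (non-streamed) response.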
        use ollama::{OllamaFunctionTool, OllamaTool};
        let function = OllamaFunctionTool {
            name: tool_name.clone(),
            description: Some(tool_description),
            parameters: Some(schema),
        };
        let tools = vec![OllamaTool::Function { function }];
        let request = self.to_ollama_request(request).with_tools(tools);
        let response = self.request_completion(request, cx);
        self.request_limiter
            .run(async move {
                let response = response.await?;
                let ChatMessage::Assistant { tool_calls, .. } = response.message else {
                    bail!("message does not have an assistant role");
                };
                if let Some(tool_calls) = tool_calls.filter(|calls| !calls.is_empty()) {
                    for call in tool_calls {
                        let OllamaToolCall::Function(function) = call;
                        if function.name == tool_name {
                            return Ok(futures::stream::once(async move {
                                Ok(function.arguments.to_string())
                            })
                            .boxed());
                        }
                    }
                } else {
                    bail!("assistant message does not have any tool calls");
                };

                bail!("tool not used")
            })
            .boxed()
    }
}

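/// Settings UI for the provider: shows install instructions while no models
/// have been fetched, and a connection indicator plus links to ollama.com
/// once the server has responded with at least one model.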
struct ConfigurationView {
    state: gpui::Model<State>,
    loading_models_task: Option<Task<()>>,
}

impl ConfigurationView {
    pub fn new(state: gpui::Model<State>, cx: &mut ViewContext<Self>) -> Self {
        let loading_models_task = Some(cx.spawn({
            let state = state.clone();
            |this, mut cx| async move {
                if let Some(task) = state
                    .update(&mut cx, |state, cx| state.authenticate(cx))
                    .log_err()
                {
                    task.await.log_err();
                }
                this.update(&mut cx, |this, cx| {
                    this.loading_models_task = None;
                    cx.notify();
                })
                .log_err();
            }
        }));

        Self {
            state,
            loading_models_task,
        }
    }

    fn retry_connection(&self, cx: &mut WindowContext) {
        self.state
            .update(cx, |state, cx| state.fetch_models(cx))
            .detach_and_log_err(cx);
    }
}

impl Render for ConfigurationView {
    fn render(&mut self, cx: &mut ViewContext<Self>) -> impl IntoElement {
        let is_authenticated = self.state.read(cx).is_authenticated();

        let ollama_intro = "Get up and running with Llama 3.3, Mistral, Gemma 2, and other large language models with Ollama.";
        let ollama_reqs =
            "Ollama must be running with at least one model installed to use it in the assistant.";

        let mut inline_code_bg = cx.theme().colors().editor_background;
        inline_code_bg.fade_out(0.5);

        if self.loading_models_task.is_some() {
            div().child(Label::new("Loading models...")).into_any()
        } else {
            v_flex()
                .size_full()
                .gap_3()
                .child(
                    v_flex()
                        .size_full()
                        .gap_2()
                        .p_1()
                        .child(Label::new(ollama_intro))
                        .child(Label::new(ollama_reqs))
                        .child(
                            h_flex()
                                .gap_0p5()
                                .child(Label::new("Once installed, try "))
                                .child(
                                    div()
                                        .bg(inline_code_bg)
                                        .px_1p5()
                                        .rounded_md()
                                        .child(Label::new("ollama run llama3.2")),
                                ),
                        ),
                )
                .child(
                    h_flex()
                        .w_full()
                        .pt_2()
                        .justify_between()
                        .gap_2()
                        .child(
                            h_flex()
                                .w_full()
                                .gap_2()
                                .map(|this| {
                                    if is_authenticated {
                                        this.child(
                                            Button::new("ollama-site", "Ollama")
                                                .style(ButtonStyle::Subtle)
                                                .icon(IconName::ExternalLink)
                                                .icon_size(IconSize::XSmall)
                                                .icon_color(Color::Muted)
                                                .on_click(move |_, cx| cx.open_url(OLLAMA_SITE))
                                                .into_any_element(),
                                        )
                                    } else {
                                        this.child(
                                            Button::new(
                                                "download_ollama_button",
                                                "Download Ollama",
                                            )
                                            .style(ButtonStyle::Subtle)
                                            .icon(IconName::ExternalLink)
                                            .icon_size(IconSize::XSmall)
                                            .icon_color(Color::Muted)
                                            .on_click(move |_, cx| cx.open_url(OLLAMA_DOWNLOAD_URL))
                                            .into_any_element(),
                                        )
                                    }
                                })
                                .child(
                                    Button::new("view-models", "All Models")
                                        .style(ButtonStyle::Subtle)
                                        .icon(IconName::ExternalLink)
                                        .icon_size(IconSize::XSmall)
                                        .icon_color(Color::Muted)
                                        .on_click(move |_, cx| cx.open_url(OLLAMA_LIBRARY_URL)),
                                ),
                        )
                        .child(if is_authenticated {
                            // This is only a button to ensure the spacing is correct;
                            // it should stay disabled.
                            ButtonLike::new("connected")
                                .disabled(true)
                                // Since this won't ever be clickable, we can use the arrow cursor
                                .cursor_style(gpui::CursorStyle::Arrow)
                                .child(
                                    h_flex()
                                        .gap_2()
                                        .child(Indicator::dot().color(Color::Success))
                                        .child(Label::new("Connected"))
                                        .into_any_element(),
                                )
                                .into_any_element()
                        } else {
                            Button::new("retry_ollama_models", "Connect")
                                .icon_position(IconPosition::Start)
                                .icon(IconName::ArrowCircle)
                                .on_click(cx.listener(move |this, _, cx| this.retry_connection(cx)))
                                .into_any_element()
                        }),
                )
                .into_any()
        }
    }
}