ollama.rs

use anyhow::{anyhow, bail, Result};
use futures::{future::BoxFuture, stream::BoxStream, FutureExt, StreamExt};
use gpui::{AnyView, AppContext, AsyncAppContext, ModelContext, Subscription, Task};
use http_client::HttpClient;
use ollama::{
    get_models, preload_model, stream_chat_completion, ChatMessage, ChatOptions, ChatRequest,
    ChatResponseDelta, OllamaToolCall,
};
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
use settings::{Settings, SettingsStore};
use std::{collections::BTreeMap, sync::Arc, time::Duration};
use ui::{prelude::*, ButtonLike, Indicator};
use util::ResultExt;

use crate::LanguageModelCompletionEvent;
use crate::{
    settings::AllLanguageModelSettings, LanguageModel, LanguageModelId, LanguageModelName,
    LanguageModelProvider, LanguageModelProviderId, LanguageModelProviderName,
    LanguageModelProviderState, LanguageModelRequest, RateLimiter, Role,
};

const OLLAMA_DOWNLOAD_URL: &str = "https://ollama.com/download";
const OLLAMA_LIBRARY_URL: &str = "https://ollama.com/library";
const OLLAMA_SITE: &str = "https://ollama.com/";

const PROVIDER_ID: &str = "ollama";
const PROVIDER_NAME: &str = "Ollama";

#[derive(Default, Debug, Clone, PartialEq)]
pub struct OllamaSettings {
    pub api_url: String,
    pub low_speed_timeout: Option<Duration>,
    pub available_models: Vec<AvailableModel>,
}

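/// A model entry that can be declared manually in the language model
/// settings. An illustrative settings entry (the values shown are
/// hypothetical examples, not defaults):
///
/// ```json
/// {
///     "name": "llama3.1:latest",
///     "display_name": "Llama 3.1",
///     "max_tokens": 8192
/// }
/// ```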
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize, JsonSchema)]
pub struct AvailableModel {
    /// The model name in the Ollama API (e.g. "llama3.1:latest")
    pub name: String,
    /// The model's name in Zed's UI, such as in the model selector dropdown menu in the assistant panel.
    pub display_name: Option<String>,
    /// The model's context window size, passed to Ollama as `num_ctx` (a.k.a. `n_ctx`)
    pub max_tokens: usize,
}

pub struct OllamaLanguageModelProvider {
    http_client: Arc<dyn HttpClient>,
    state: gpui::Model<State>,
}

pub struct State {
    http_client: Arc<dyn HttpClient>,
    available_models: Vec<ollama::Model>,
    _subscription: Subscription,
}

impl State {
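    // Ollama has no real authentication; a non-empty model list serves as a
    // proxy for "connected and usable".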
    fn is_authenticated(&self) -> bool {
        !self.available_models.is_empty()
    }

    fn fetch_models(&mut self, cx: &mut ModelContext<Self>) -> Task<Result<()>> {
        let settings = &AllLanguageModelSettings::get_global(cx).ollama;
        let http_client = self.http_client.clone();
        let api_url = settings.api_url.clone();

        // As a proxy for the server being "authenticated", check whether it's up by fetching the models.
        cx.spawn(|this, mut cx| async move {
            let models = get_models(http_client.as_ref(), &api_url, None).await?;

            let mut models: Vec<ollama::Model> = models
                .into_iter()
                // Since there is no metadata from the Ollama API
                // indicating which models are embedding models,
                // simply filter out models with "-embed" in their name
                .filter(|model| !model.name.contains("-embed"))
                .map(|model| ollama::Model::new(&model.name, None, None))
                .collect();

            models.sort_by(|a, b| a.name.cmp(&b.name));

            this.update(&mut cx, |this, cx| {
                this.available_models = models;
                cx.notify();
            })
        })
    }

    fn authenticate(&mut self, cx: &mut ModelContext<Self>) -> Task<Result<()>> {
        if self.is_authenticated() {
            Task::ready(Ok(()))
        } else {
            self.fetch_models(cx)
        }
    }
}

impl OllamaLanguageModelProvider {
    pub fn new(http_client: Arc<dyn HttpClient>, cx: &mut AppContext) -> Self {
        let this = Self {
            http_client: http_client.clone(),
            state: cx.new_model(|cx| State {
                http_client,
                available_models: Default::default(),
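                // Refetch the model list whenever settings change (e.g. a new api_url).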
                _subscription: cx.observe_global::<SettingsStore>(|this: &mut State, cx| {
                    this.fetch_models(cx).detach();
                    cx.notify();
                }),
            }),
        };
        this.state
            .update(cx, |state, cx| state.fetch_models(cx).detach());
        this
    }
}

impl LanguageModelProviderState for OllamaLanguageModelProvider {
    type ObservableEntity = State;

    fn observable_entity(&self) -> Option<gpui::Model<Self::ObservableEntity>> {
        Some(self.state.clone())
    }
}

impl LanguageModelProvider for OllamaLanguageModelProvider {
    fn id(&self) -> LanguageModelProviderId {
        LanguageModelProviderId(PROVIDER_ID.into())
    }

    fn name(&self) -> LanguageModelProviderName {
        LanguageModelProviderName(PROVIDER_NAME.into())
    }

    fn icon(&self) -> IconName {
        IconName::AiOllama
    }

    fn provided_models(&self, cx: &AppContext) -> Vec<Arc<dyn LanguageModel>> {
        let mut models: BTreeMap<String, ollama::Model> = BTreeMap::default();

        // Add models from the Ollama API
        for model in self.state.read(cx).available_models.iter() {
            models.insert(model.name.clone(), model.clone());
        }

        // Override with available models from settings
        for model in AllLanguageModelSettings::get_global(cx)
            .ollama
            .available_models
            .iter()
        {
            models.insert(
                model.name.clone(),
                ollama::Model {
                    name: model.name.clone(),
                    display_name: model.display_name.clone(),
                    max_tokens: model.max_tokens,
                    keep_alive: None,
                },
            );
        }

        models
            .into_values()
            .map(|model| {
                Arc::new(OllamaLanguageModel {
                    id: LanguageModelId::from(model.name.clone()),
                    model: model.clone(),
                    http_client: self.http_client.clone(),
                    request_limiter: RateLimiter::new(4),
                }) as Arc<dyn LanguageModel>
            })
            .collect()
    }

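    // Ask the Ollama server to load the model into memory ahead of time, so
    // the first completion request doesn't stall while the model loads.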
    fn load_model(&self, model: Arc<dyn LanguageModel>, cx: &AppContext) {
        let settings = &AllLanguageModelSettings::get_global(cx).ollama;
        let http_client = self.http_client.clone();
        let api_url = settings.api_url.clone();
        let id = model.id().0.to_string();
        cx.spawn(|_| async move { preload_model(http_client, &api_url, &id).await })
            .detach_and_log_err(cx);
    }

    fn is_authenticated(&self, cx: &AppContext) -> bool {
        self.state.read(cx).is_authenticated()
    }

    fn authenticate(&self, cx: &mut AppContext) -> Task<Result<()>> {
        self.state.update(cx, |state, cx| state.authenticate(cx))
    }

    fn configuration_view(&self, cx: &mut WindowContext) -> AnyView {
        let state = self.state.clone();
        cx.new_view(|cx| ConfigurationView::new(state, cx)).into()
    }

    fn reset_credentials(&self, cx: &mut AppContext) -> Task<Result<()>> {
        self.state.update(cx, |state, cx| state.fetch_models(cx))
    }
}

pub struct OllamaLanguageModel {
    id: LanguageModelId,
    model: ollama::Model,
    http_client: Arc<dyn HttpClient>,
    request_limiter: RateLimiter,
}

impl OllamaLanguageModel {
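    // Convert Zed's provider-agnostic request into Ollama's chat format,
    // mapping each message onto the matching Ollama role.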
    fn to_ollama_request(&self, request: LanguageModelRequest) -> ChatRequest {
        ChatRequest {
            model: self.model.name.clone(),
            messages: request
                .messages
                .into_iter()
                .map(|msg| match msg.role {
                    Role::User => ChatMessage::User {
                        content: msg.string_contents(),
                    },
                    Role::Assistant => ChatMessage::Assistant {
                        content: msg.string_contents(),
                        tool_calls: None,
                    },
                    Role::System => ChatMessage::System {
                        content: msg.string_contents(),
                    },
                })
                .collect(),
            keep_alive: self.model.keep_alive.clone().unwrap_or_default(),
            stream: true,
            options: Some(ChatOptions {
                num_ctx: Some(self.model.max_tokens),
                stop: Some(request.stop),
                temperature: Some(request.temperature),
                ..Default::default()
            }),
            tools: vec![],
        }
    }
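
    // Issue a single chat completion via `ollama::complete`. Used for tool
    // calls, which Ollama returns on the assistant message rather than as a
    // stream.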
    fn request_completion(
        &self,
        request: ChatRequest,
        cx: &AsyncAppContext,
    ) -> BoxFuture<'static, Result<ChatResponseDelta>> {
        let http_client = self.http_client.clone();

        let Ok(api_url) = cx.update(|cx| {
            let settings = &AllLanguageModelSettings::get_global(cx).ollama;
            settings.api_url.clone()
        }) else {
            return futures::future::ready(Err(anyhow!("App state dropped"))).boxed();
        };

        async move { ollama::complete(http_client.as_ref(), &api_url, request).await }.boxed()
    }
}

impl LanguageModel for OllamaLanguageModel {
    fn id(&self) -> LanguageModelId {
        self.id.clone()
    }

    fn name(&self) -> LanguageModelName {
        LanguageModelName::from(self.model.display_name().to_string())
    }

    fn provider_id(&self) -> LanguageModelProviderId {
        LanguageModelProviderId(PROVIDER_ID.into())
    }

    fn provider_name(&self) -> LanguageModelProviderName {
        LanguageModelProviderName(PROVIDER_NAME.into())
    }

    fn telemetry_id(&self) -> String {
        format!("ollama/{}", self.model.id())
    }

    fn max_token_count(&self) -> usize {
        self.model.max_token_count()
    }

    fn count_tokens(
        &self,
        request: LanguageModelRequest,
        _cx: &AppContext,
    ) -> BoxFuture<'static, Result<usize>> {
        // There is no endpoint for this _yet_ in Ollama
        // see: https://github.com/ollama/ollama/issues/1716 and https://github.com/ollama/ollama/issues/3582
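        // Until then, fall back to a rough estimate of ~4 characters per token.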
        let token_count = request
            .messages
            .iter()
            .map(|msg| msg.string_contents().chars().count())
            .sum::<usize>()
            / 4;

        async move { Ok(token_count) }.boxed()
    }

    fn stream_completion(
        &self,
        request: LanguageModelRequest,
        cx: &AsyncAppContext,
    ) -> BoxFuture<'static, Result<BoxStream<'static, Result<LanguageModelCompletionEvent>>>> {
        let request = self.to_ollama_request(request);

        let http_client = self.http_client.clone();
        let Ok((api_url, low_speed_timeout)) = cx.update(|cx| {
            let settings = &AllLanguageModelSettings::get_global(cx).ollama;
            (settings.api_url.clone(), settings.low_speed_timeout)
        }) else {
            return futures::future::ready(Err(anyhow!("App state dropped"))).boxed();
        };

        let future = self.request_limiter.stream(async move {
            let response =
                stream_chat_completion(http_client.as_ref(), &api_url, request, low_speed_timeout)
                    .await?;
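            // Extract the text content from each streamed delta, whatever role
            // Ollama attributes it to.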
            let stream = response
                .filter_map(|response| async move {
                    match response {
                        Ok(delta) => {
                            let content = match delta.message {
                                ChatMessage::User { content } => content,
                                ChatMessage::Assistant { content, .. } => content,
                                ChatMessage::System { content } => content,
                            };
                            Some(Ok(content))
                        }
                        Err(error) => Some(Err(error)),
                    }
                })
                .boxed();
            Ok(stream)
        });

        async move {
            Ok(future
                .await?
                .map(|result| result.map(LanguageModelCompletionEvent::Text))
                .boxed())
        }
        .boxed()
    }

    fn use_any_tool(
        &self,
        request: LanguageModelRequest,
        tool_name: String,
        tool_description: String,
        schema: serde_json::Value,
        cx: &AsyncAppContext,
    ) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>> {
        use ollama::{OllamaFunctionTool, OllamaTool};
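
        // Describe the tool to Ollama, request one completion with the tool
        // attached, then surface the matching call's arguments as a
        // single-item stream.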
        let function = OllamaFunctionTool {
            name: tool_name.clone(),
            description: Some(tool_description),
            parameters: Some(schema),
        };
        let tools = vec![OllamaTool::Function { function }];
        let request = self.to_ollama_request(request).with_tools(tools);
        let response = self.request_completion(request, cx);
        self.request_limiter
            .run(async move {
                let response = response.await?;
                let ChatMessage::Assistant { tool_calls, .. } = response.message else {
                    bail!("message does not have an assistant role");
                };
                if let Some(tool_calls) = tool_calls.filter(|calls| !calls.is_empty()) {
                    for call in tool_calls {
                        let OllamaToolCall::Function(function) = call;
                        if function.name == tool_name {
                            return Ok(futures::stream::once(async move {
                                Ok(function.arguments.to_string())
                            })
                            .boxed());
                        }
                    }
                } else {
                    bail!("assistant message does not have any tool calls");
                };

                bail!("tool not used")
            })
            .boxed()
    }
}

struct ConfigurationView {
    state: gpui::Model<State>,
    loading_models_task: Option<Task<()>>,
}

impl ConfigurationView {
    pub fn new(state: gpui::Model<State>, cx: &mut ViewContext<Self>) -> Self {
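        // Kick off the initial model fetch; once it finishes (successfully or
        // not), clear the task so the view leaves its loading state.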
        let loading_models_task = Some(cx.spawn({
            let state = state.clone();
            |this, mut cx| async move {
                if let Some(task) = state
                    .update(&mut cx, |state, cx| state.authenticate(cx))
                    .log_err()
                {
                    task.await.log_err();
                }
                this.update(&mut cx, |this, cx| {
                    this.loading_models_task = None;
                    cx.notify();
                })
                .log_err();
            }
        }));

        Self {
            state,
            loading_models_task,
        }
    }

    fn retry_connection(&self, cx: &mut WindowContext) {
        self.state
            .update(cx, |state, cx| state.fetch_models(cx))
            .detach_and_log_err(cx);
    }
}

impl Render for ConfigurationView {
    fn render(&mut self, cx: &mut ViewContext<Self>) -> impl IntoElement {
        let is_authenticated = self.state.read(cx).is_authenticated();

        let ollama_intro = "Get up and running with Llama 3.1, Mistral, Gemma 2, and other large language models with Ollama.";
        let ollama_reqs =
            "Ollama must be running with at least one model installed to use it in the assistant.";

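        // Dim the editor background so the sample command below reads as an
        // inline code chip.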
        let mut inline_code_bg = cx.theme().colors().editor_background;
        inline_code_bg.fade_out(0.5);

        if self.loading_models_task.is_some() {
            div().child(Label::new("Loading models...")).into_any()
        } else {
            v_flex()
                .size_full()
                .gap_3()
                .child(
                    v_flex()
                        .size_full()
                        .gap_2()
                        .p_1()
                        .child(Label::new(ollama_intro))
                        .child(Label::new(ollama_reqs))
                        .child(
                            h_flex()
                                .gap_0p5()
                                .child(Label::new("Once installed, try "))
                                .child(
                                    div()
                                        .bg(inline_code_bg)
                                        .px_1p5()
                                        .rounded_md()
                                        .child(Label::new("ollama run llama3.1")),
                                ),
                        ),
                )
                .child(
                    h_flex()
                        .w_full()
                        .pt_2()
                        .justify_between()
                        .gap_2()
                        .child(
                            h_flex()
                                .w_full()
                                .gap_2()
                                .map(|this| {
                                    if is_authenticated {
                                        this.child(
                                            Button::new("ollama-site", "Ollama")
                                                .style(ButtonStyle::Subtle)
                                                .icon(IconName::ExternalLink)
                                                .icon_size(IconSize::XSmall)
                                                .icon_color(Color::Muted)
                                                .on_click(move |_, cx| cx.open_url(OLLAMA_SITE))
                                                .into_any_element(),
                                        )
                                    } else {
                                        this.child(
                                            Button::new(
                                                "download_ollama_button",
                                                "Download Ollama",
                                            )
                                            .style(ButtonStyle::Subtle)
                                            .icon(IconName::ExternalLink)
                                            .icon_size(IconSize::XSmall)
                                            .icon_color(Color::Muted)
                                            .on_click(move |_, cx| cx.open_url(OLLAMA_DOWNLOAD_URL))
                                            .into_any_element(),
                                        )
                                    }
                                })
                                .child(
                                    Button::new("view-models", "All Models")
                                        .style(ButtonStyle::Subtle)
                                        .icon(IconName::ExternalLink)
                                        .icon_size(IconSize::XSmall)
                                        .icon_color(Color::Muted)
                                        .on_click(move |_, cx| cx.open_url(OLLAMA_LIBRARY_URL)),
                                ),
                        )
                        .child(if is_authenticated {
                            // This is a button only so the spacing stays correct;
                            // it should remain disabled.
                            ButtonLike::new("connected")
                                .disabled(true)
                                // Since this won't ever be clickable, we can use the arrow cursor
                                .cursor_style(gpui::CursorStyle::Arrow)
                                .child(
                                    h_flex()
                                        .gap_2()
                                        .child(Indicator::dot().color(Color::Success))
                                        .child(Label::new("Connected"))
                                        .into_any_element(),
                                )
                                .into_any_element()
                        } else {
                            Button::new("retry_ollama_models", "Connect")
                                .icon_position(IconPosition::Start)
                                .icon(IconName::ArrowCircle)
                                .on_click(cx.listener(move |this, _, cx| this.retry_connection(cx)))
                                .into_any_element()
                        }),
                )
                .into_any()
        }
    }
}