ollama.rs

use anyhow::{anyhow, bail, Result};
use futures::{future::BoxFuture, stream::BoxStream, FutureExt, StreamExt};
use gpui::{AnyView, AppContext, AsyncAppContext, ModelContext, Subscription, Task};
use http_client::HttpClient;
use ollama::{
    get_models, preload_model, stream_chat_completion, ChatMessage, ChatOptions, ChatRequest,
    ChatResponseDelta, OllamaToolCall,
};
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
use settings::{Settings, SettingsStore};
use std::{collections::BTreeMap, sync::Arc, time::Duration};
use ui::{prelude::*, ButtonLike, Indicator};
use util::ResultExt;

use crate::{
    settings::AllLanguageModelSettings, LanguageModel, LanguageModelId, LanguageModelName,
    LanguageModelProvider, LanguageModelProviderId, LanguageModelProviderName,
    LanguageModelProviderState, LanguageModelRequest, RateLimiter, Role,
};

const OLLAMA_DOWNLOAD_URL: &str = "https://ollama.com/download";
const OLLAMA_LIBRARY_URL: &str = "https://ollama.com/library";
const OLLAMA_SITE: &str = "https://ollama.com/";

const PROVIDER_ID: &str = "ollama";
const PROVIDER_NAME: &str = "Ollama";

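/// Connection settings for the Ollama provider. `api_url` points at the Ollama
/// server; Ollama's conventional local address is `http://localhost:11434`,
/// though the actual default value is applied in the settings layer, not here.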
#[derive(Default, Debug, Clone, PartialEq)]
pub struct OllamaSettings {
    pub api_url: String,
    pub low_speed_timeout: Option<Duration>,
    pub available_models: Vec<AvailableModel>,
}

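/// A model made available via Zed's settings rather than discovered from the
/// Ollama API. A sketch of the settings JSON this deserializes from (the field
/// names match this struct; the surrounding key path is an assumption about
/// the settings schema):
///
/// ```json
/// {
///   "language_models": {
///     "ollama": {
///       "available_models": [
///         { "name": "llama3.1:latest", "display_name": "Llama 3.1", "max_tokens": 8192 }
///       ]
///     }
///   }
/// }
/// ```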
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize, JsonSchema)]
pub struct AvailableModel {
    /// The model name in the Ollama API (e.g. "llama3.1:latest")
    pub name: String,
    /// The model's name in Zed's UI, such as in the model selector dropdown menu in the assistant panel.
    pub display_name: Option<String>,
    /// The model's context length, passed to Ollama as the num_ctx (a.k.a. n_ctx) parameter.
    pub max_tokens: usize,
}

pub struct OllamaLanguageModelProvider {
    http_client: Arc<dyn HttpClient>,
    state: gpui::Model<State>,
}

pub struct State {
    http_client: Arc<dyn HttpClient>,
    available_models: Vec<ollama::Model>,
    _subscription: Subscription,
}

impl State {
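    // Ollama has no real authentication; a reachable server that reports at
    // least one installed model is treated as "authenticated".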
    fn is_authenticated(&self) -> bool {
        !self.available_models.is_empty()
    }

    fn fetch_models(&mut self, cx: &mut ModelContext<Self>) -> Task<Result<()>> {
        let settings = &AllLanguageModelSettings::get_global(cx).ollama;
        let http_client = self.http_client.clone();
        let api_url = settings.api_url.clone();

        // As a proxy for the server being "authenticated", we'll check if it's up by fetching the models
        cx.spawn(|this, mut cx| async move {
            let models = get_models(http_client.as_ref(), &api_url, None).await?;

            let mut models: Vec<ollama::Model> = models
                .into_iter()
                // Since there is no metadata from the Ollama API
                // indicating which models are embedding models,
                // simply filter out models with "-embed" in their name
                .filter(|model| !model.name.contains("-embed"))
                .map(|model| ollama::Model::new(&model.name, None, None))
                .collect();

            models.sort_by(|a, b| a.name.cmp(&b.name));

            this.update(&mut cx, |this, cx| {
                this.available_models = models;
                cx.notify();
            })
        })
    }

    fn authenticate(&mut self, cx: &mut ModelContext<Self>) -> Task<Result<()>> {
        if self.is_authenticated() {
            Task::ready(Ok(()))
        } else {
            self.fetch_models(cx)
        }
    }
}

impl OllamaLanguageModelProvider {
    pub fn new(http_client: Arc<dyn HttpClient>, cx: &mut AppContext) -> Self {
        let this = Self {
            http_client: http_client.clone(),
            state: cx.new_model(|cx| State {
                http_client,
                available_models: Default::default(),
                _subscription: cx.observe_global::<SettingsStore>(|this: &mut State, cx| {
                    this.fetch_models(cx).detach();
                    cx.notify();
                }),
            }),
        };
        this.state
            .update(cx, |state, cx| state.fetch_models(cx).detach());
        this
    }
}

impl LanguageModelProviderState for OllamaLanguageModelProvider {
    type ObservableEntity = State;

    fn observable_entity(&self) -> Option<gpui::Model<Self::ObservableEntity>> {
        Some(self.state.clone())
    }
}

impl LanguageModelProvider for OllamaLanguageModelProvider {
    fn id(&self) -> LanguageModelProviderId {
        LanguageModelProviderId(PROVIDER_ID.into())
    }

    fn name(&self) -> LanguageModelProviderName {
        LanguageModelProviderName(PROVIDER_NAME.into())
    }

    fn icon(&self) -> IconName {
        IconName::AiOllama
    }

    fn provided_models(&self, cx: &AppContext) -> Vec<Arc<dyn LanguageModel>> {
        let mut models: BTreeMap<String, ollama::Model> = BTreeMap::default();

        // Add models from the Ollama API
        for model in self.state.read(cx).available_models.iter() {
            models.insert(model.name.clone(), model.clone());
        }

        // Override with available models from settings
        for model in AllLanguageModelSettings::get_global(cx)
            .ollama
            .available_models
            .iter()
        {
            models.insert(
                model.name.clone(),
                ollama::Model {
                    name: model.name.clone(),
                    display_name: model.display_name.clone(),
                    max_tokens: model.max_tokens,
                    keep_alive: None,
                },
            );
        }

        models
            .into_values()
            .map(|model| {
                Arc::new(OllamaLanguageModel {
                    id: LanguageModelId::from(model.name.clone()),
                    model: model.clone(),
                    http_client: self.http_client.clone(),
                    request_limiter: RateLimiter::new(4),
                }) as Arc<dyn LanguageModel>
            })
            .collect()
    }

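    // Ask Ollama to load the model into memory ahead of the first request so
    // the initial completion doesn't pay the cold-start cost.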
    fn load_model(&self, model: Arc<dyn LanguageModel>, cx: &AppContext) {
        let settings = &AllLanguageModelSettings::get_global(cx).ollama;
        let http_client = self.http_client.clone();
        let api_url = settings.api_url.clone();
        let id = model.id().0.to_string();
        cx.spawn(|_| async move { preload_model(http_client, &api_url, &id).await })
            .detach_and_log_err(cx);
    }

    fn is_authenticated(&self, cx: &AppContext) -> bool {
        self.state.read(cx).is_authenticated()
    }

    fn authenticate(&self, cx: &mut AppContext) -> Task<Result<()>> {
        self.state.update(cx, |state, cx| state.authenticate(cx))
    }

    fn configuration_view(&self, cx: &mut WindowContext) -> AnyView {
        let state = self.state.clone();
        cx.new_view(|cx| ConfigurationView::new(state, cx)).into()
    }

    fn reset_credentials(&self, cx: &mut AppContext) -> Task<Result<()>> {
        self.state.update(cx, |state, cx| state.fetch_models(cx))
    }
}

pub struct OllamaLanguageModel {
    id: LanguageModelId,
    model: ollama::Model,
    http_client: Arc<dyn HttpClient>,
    request_limiter: RateLimiter,
}

impl OllamaLanguageModel {
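    // Maps Zed's request into the body of Ollama's `/api/chat` endpoint. A
    // rough sketch of the serialized JSON, assuming `ChatRequest`'s serde
    // output matches the Ollama chat API's field names:
    //
    // {
    //   "model": "llama3.1:latest",
    //   "messages": [{ "role": "user", "content": "Hello!" }],
    //   "stream": true,
    //   "keep_alive": "5m",
    //   "options": { "num_ctx": 8192, "stop": [], "temperature": 1.0 },
    //   "tools": []
    // }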
    fn to_ollama_request(&self, request: LanguageModelRequest) -> ChatRequest {
        ChatRequest {
            model: self.model.name.clone(),
            messages: request
                .messages
                .into_iter()
                .map(|msg| match msg.role {
                    Role::User => ChatMessage::User {
                        content: msg.string_contents(),
                    },
                    Role::Assistant => ChatMessage::Assistant {
                        content: msg.string_contents(),
                        tool_calls: None,
                    },
                    Role::System => ChatMessage::System {
                        content: msg.string_contents(),
                    },
                })
                .collect(),
            keep_alive: self.model.keep_alive.clone().unwrap_or_default(),
            stream: true,
            options: Some(ChatOptions {
                num_ctx: Some(self.model.max_tokens),
                stop: Some(request.stop),
                temperature: Some(request.temperature),
                ..Default::default()
            }),
            tools: vec![],
        }
    }

    fn request_completion(
        &self,
        request: ChatRequest,
        cx: &AsyncAppContext,
    ) -> BoxFuture<'static, Result<ChatResponseDelta>> {
        let http_client = self.http_client.clone();

        let Ok(api_url) = cx.update(|cx| {
            let settings = &AllLanguageModelSettings::get_global(cx).ollama;
            settings.api_url.clone()
        }) else {
            return futures::future::ready(Err(anyhow!("App state dropped"))).boxed();
        };

        async move { ollama::complete(http_client.as_ref(), &api_url, request).await }.boxed()
    }
}

impl LanguageModel for OllamaLanguageModel {
    fn id(&self) -> LanguageModelId {
        self.id.clone()
    }

    fn name(&self) -> LanguageModelName {
        LanguageModelName::from(self.model.display_name().to_string())
    }

    fn provider_id(&self) -> LanguageModelProviderId {
        LanguageModelProviderId(PROVIDER_ID.into())
    }

    fn provider_name(&self) -> LanguageModelProviderName {
        LanguageModelProviderName(PROVIDER_NAME.into())
    }

    fn telemetry_id(&self) -> String {
        format!("ollama/{}", self.model.id())
    }

    fn max_token_count(&self) -> usize {
        self.model.max_token_count()
    }

    fn count_tokens(
        &self,
        request: LanguageModelRequest,
        _cx: &AppContext,
    ) -> BoxFuture<'static, Result<usize>> {
        // There is no endpoint for this _yet_ in Ollama
        // see: https://github.com/ollama/ollama/issues/1716 and https://github.com/ollama/ollama/issues/3582
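        // Fall back to the common ~4 characters per token approximation, so a
        // 2,000-character conversation counts as roughly 500 tokens.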
        let token_count = request
            .messages
            .iter()
            .map(|msg| msg.string_contents().chars().count())
            .sum::<usize>()
            / 4;

        async move { Ok(token_count) }.boxed()
    }

    fn stream_completion(
        &self,
        request: LanguageModelRequest,
        cx: &AsyncAppContext,
    ) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>> {
        let request = self.to_ollama_request(request);

        let http_client = self.http_client.clone();
        let Ok((api_url, low_speed_timeout)) = cx.update(|cx| {
            let settings = &AllLanguageModelSettings::get_global(cx).ollama;
            (settings.api_url.clone(), settings.low_speed_timeout)
        }) else {
            return futures::future::ready(Err(anyhow!("App state dropped"))).boxed();
        };

        let future = self.request_limiter.stream(async move {
            let response =
                stream_chat_completion(http_client.as_ref(), &api_url, request, low_speed_timeout)
                    .await?;
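            // Each streamed delta carries one message fragment. In practice
            // the server sends assistant fragments mid-stream, but every role
            // is matched exhaustively and its content forwarded.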
            let stream = response
                .filter_map(|response| async move {
                    match response {
                        Ok(delta) => {
                            let content = match delta.message {
                                ChatMessage::User { content } => content,
                                ChatMessage::Assistant { content, .. } => content,
                                ChatMessage::System { content } => content,
                            };
                            Some(Ok(content))
                        }
                        Err(error) => Some(Err(error)),
                    }
                })
                .boxed();
            Ok(stream)
        });

        async move { Ok(future.await?.boxed()) }.boxed()
    }

    fn use_any_tool(
        &self,
        request: LanguageModelRequest,
        tool_name: String,
        tool_description: String,
        schema: serde_json::Value,
        cx: &AsyncAppContext,
    ) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>> {
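        // Tool use goes through the non-streaming completion path: attach a
        // single tool to the request and, if the model invoked it, yield the
        // call's JSON arguments as a one-item stream.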
        use ollama::{OllamaFunctionTool, OllamaTool};
        let function = OllamaFunctionTool {
            name: tool_name.clone(),
            description: Some(tool_description),
            parameters: Some(schema),
        };
        let tools = vec![OllamaTool::Function { function }];
        let request = self.to_ollama_request(request).with_tools(tools);
        let response = self.request_completion(request, cx);
        self.request_limiter
            .run(async move {
                let response = response.await?;
                let ChatMessage::Assistant { tool_calls, .. } = response.message else {
                    bail!("message does not have an assistant role");
                };
                if let Some(tool_calls) = tool_calls.filter(|calls| !calls.is_empty()) {
                    for call in tool_calls {
                        let OllamaToolCall::Function(function) = call;
                        if function.name == tool_name {
                            return Ok(futures::stream::once(async move {
                                Ok(function.arguments.to_string())
                            })
                            .boxed());
                        }
                    }
                } else {
                    bail!("assistant message does not have any tool calls");
                };

                bail!("tool not used")
            })
            .boxed()
    }
}

struct ConfigurationView {
    state: gpui::Model<State>,
    loading_models_task: Option<Task<()>>,
}

impl ConfigurationView {
    pub fn new(state: gpui::Model<State>, cx: &mut ViewContext<Self>) -> Self {
        let loading_models_task = Some(cx.spawn({
            let state = state.clone();
            |this, mut cx| async move {
                if let Some(task) = state
                    .update(&mut cx, |state, cx| state.authenticate(cx))
                    .log_err()
                {
                    task.await.log_err();
                }
                this.update(&mut cx, |this, cx| {
                    this.loading_models_task = None;
                    cx.notify();
                })
                .log_err();
            }
        }));

        Self {
            state,
            loading_models_task,
        }
    }

    fn retry_connection(&self, cx: &mut WindowContext) {
        self.state
            .update(cx, |state, cx| state.fetch_models(cx))
            .detach_and_log_err(cx);
    }
}

impl Render for ConfigurationView {
    fn render(&mut self, cx: &mut ViewContext<Self>) -> impl IntoElement {
        let is_authenticated = self.state.read(cx).is_authenticated();

        let ollama_intro = "Run Llama 3.1, Mistral, Gemma 2, and other large language models locally with Ollama.";
        let ollama_reqs =
            "Ollama must be running with at least one model installed to use it in the assistant.";

        let mut inline_code_bg = cx.theme().colors().editor_background;
        inline_code_bg.fade_out(0.5);

        if self.loading_models_task.is_some() {
            div().child(Label::new("Loading models...")).into_any()
        } else {
            v_flex()
                .size_full()
                .gap_3()
                .child(
                    v_flex()
                        .size_full()
                        .gap_2()
                        .p_1()
                        .child(Label::new(ollama_intro))
                        .child(Label::new(ollama_reqs))
                        .child(
                            h_flex()
                                .gap_0p5()
                                .child(Label::new("Once installed, try "))
                                .child(
                                    div()
                                        .bg(inline_code_bg)
                                        .px_1p5()
                                        .rounded_md()
                                        .child(Label::new("ollama run llama3.1")),
                                ),
                        ),
                )
                .child(
                    h_flex()
                        .w_full()
                        .pt_2()
                        .justify_between()
                        .gap_2()
                        .child(
                            h_flex()
                                .w_full()
                                .gap_2()
                                .map(|this| {
                                    if is_authenticated {
                                        this.child(
                                            Button::new("ollama-site", "Ollama")
                                                .style(ButtonStyle::Subtle)
                                                .icon(IconName::ExternalLink)
                                                .icon_size(IconSize::XSmall)
                                                .icon_color(Color::Muted)
                                                .on_click(move |_, cx| cx.open_url(OLLAMA_SITE))
                                                .into_any_element(),
                                        )
                                    } else {
                                        this.child(
                                            Button::new(
                                                "download_ollama_button",
                                                "Download Ollama",
                                            )
                                            .style(ButtonStyle::Subtle)
                                            .icon(IconName::ExternalLink)
                                            .icon_size(IconSize::XSmall)
                                            .icon_color(Color::Muted)
                                            .on_click(move |_, cx| cx.open_url(OLLAMA_DOWNLOAD_URL))
                                            .into_any_element(),
                                        )
                                    }
                                })
                                .child(
                                    Button::new("view-models", "All Models")
                                        .style(ButtonStyle::Subtle)
                                        .icon(IconName::ExternalLink)
                                        .icon_size(IconSize::XSmall)
                                        .icon_color(Color::Muted)
                                        .on_click(move |_, cx| cx.open_url(OLLAMA_LIBRARY_URL)),
                                ),
                        )
                        .child(if is_authenticated {
                            // This is a button only so the spacing stays
                            // consistent; it should remain disabled.
                            ButtonLike::new("connected")
                                .disabled(true)
                                // Since this won't ever be clickable, we can use the arrow cursor
                                .cursor_style(gpui::CursorStyle::Arrow)
                                .child(
                                    h_flex()
                                        .gap_2()
                                        .child(Indicator::dot().color(Color::Success))
                                        .child(Label::new("Connected"))
                                        .into_any_element(),
                                )
                                .into_any_element()
                        } else {
                            Button::new("retry_ollama_models", "Connect")
                                .icon_position(IconPosition::Start)
                                .icon(IconName::ArrowCircle)
                                .on_click(cx.listener(move |this, _, cx| this.retry_connection(cx)))
                                .into_any_element()
                        }),
                )
                .into_any()
        }
    }
}