ollama.rs

use anyhow::{anyhow, bail, Result};
use futures::{future::BoxFuture, stream::BoxStream, FutureExt, StreamExt};
use gpui::{AnyView, AppContext, AsyncAppContext, ModelContext, Subscription, Task};
use http_client::HttpClient;
use ollama::{
    get_models, preload_model, stream_chat_completion, ChatMessage, ChatOptions, ChatRequest,
    ChatResponseDelta, KeepAlive, OllamaToolCall,
};
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
use settings::{Settings, SettingsStore};
use std::{collections::BTreeMap, sync::Arc};
use ui::{prelude::*, ButtonLike, Indicator};
use util::ResultExt;

use crate::LanguageModelCompletionEvent;
use crate::{
    settings::AllLanguageModelSettings, LanguageModel, LanguageModelId, LanguageModelName,
    LanguageModelProvider, LanguageModelProviderId, LanguageModelProviderName,
    LanguageModelProviderState, LanguageModelRequest, RateLimiter, Role,
};

const OLLAMA_DOWNLOAD_URL: &str = "https://ollama.com/download";
const OLLAMA_LIBRARY_URL: &str = "https://ollama.com/library";
const OLLAMA_SITE: &str = "https://ollama.com/";

const PROVIDER_ID: &str = "ollama";
const PROVIDER_NAME: &str = "Ollama";

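/// Settings for the Ollama provider, read from the `ollama` section of
/// Zed's language model settings (see `AllLanguageModelSettings`).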
#[derive(Default, Debug, Clone, PartialEq)]
pub struct OllamaSettings {
    pub api_url: String,
    pub available_models: Vec<AvailableModel>,
}

#[derive(Clone, Debug, PartialEq, Serialize, Deserialize, JsonSchema)]
pub struct AvailableModel {
    /// The model name in the Ollama API (e.g. "llama3.2:latest")
    pub name: String,
    /// The model's name in Zed's UI, such as in the model selector dropdown menu in the assistant panel.
    pub display_name: Option<String>,
    /// The context length parameter for the model (aka num_ctx or n_ctx)
    pub max_tokens: usize,
    /// The number of seconds to keep the connection open after the last request
    pub keep_alive: Option<KeepAlive>,
}

pub struct OllamaLanguageModelProvider {
    http_client: Arc<dyn HttpClient>,
    state: gpui::Model<State>,
}

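/// Live connection state for the Ollama provider: the models reported by the
/// local Ollama server, plus the background task that fetches them.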
pub struct State {
    http_client: Arc<dyn HttpClient>,
    available_models: Vec<ollama::Model>,
    fetch_model_task: Option<Task<Result<()>>>,
    _subscription: Subscription,
}

impl State {
    fn is_authenticated(&self) -> bool {
        !self.available_models.is_empty()
    }

    fn fetch_models(&mut self, cx: &mut ModelContext<Self>) -> Task<Result<()>> {
        let settings = &AllLanguageModelSettings::get_global(cx).ollama;
        let http_client = self.http_client.clone();
        let api_url = settings.api_url.clone();

        // As a proxy for the server being "authenticated", we'll check if it's up by fetching the models
        cx.spawn(|this, mut cx| async move {
            let models = get_models(http_client.as_ref(), &api_url, None).await?;

            let mut models: Vec<ollama::Model> = models
                .into_iter()
                // Since there is no metadata from the Ollama API
                // indicating which models are embedding models,
                // simply filter out models with "-embed" in their name
                .filter(|model| !model.name.contains("-embed"))
                .map(|model| ollama::Model::new(&model.name, None, None))
                .collect();

            models.sort_by(|a, b| a.name.cmp(&b.name));

            this.update(&mut cx, |this, cx| {
                this.available_models = models;
                cx.notify();
            })
        })
    }

    fn restart_fetch_models_task(&mut self, cx: &mut ModelContext<Self>) {
        let task = self.fetch_models(cx);
        self.fetch_model_task.replace(task);
    }

    fn authenticate(&mut self, cx: &mut ModelContext<Self>) -> Task<Result<()>> {
        if self.is_authenticated() {
            Task::ready(Ok(()))
        } else {
            self.fetch_models(cx)
        }
    }
}

impl OllamaLanguageModelProvider {
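    /// Creates the provider, subscribing to settings changes so the model list
    /// is re-fetched whenever the Ollama settings change, and kicking off an
    /// initial fetch of the models from the local server.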
    pub fn new(http_client: Arc<dyn HttpClient>, cx: &mut AppContext) -> Self {
        let this = Self {
            http_client: http_client.clone(),
            state: cx.new_model(|cx| {
                let subscription = cx.observe_global::<SettingsStore>({
                    let mut settings = AllLanguageModelSettings::get_global(cx).ollama.clone();
                    move |this: &mut State, cx| {
                        let new_settings = &AllLanguageModelSettings::get_global(cx).ollama;
                        if &settings != new_settings {
                            settings = new_settings.clone();
                            this.restart_fetch_models_task(cx);
                            cx.notify();
                        }
                    }
                });

                State {
                    http_client,
                    available_models: Default::default(),
                    fetch_model_task: None,
                    _subscription: subscription,
                }
            }),
        };
        this.state
            .update(cx, |state, cx| state.restart_fetch_models_task(cx));
        this
    }
}

impl LanguageModelProviderState for OllamaLanguageModelProvider {
    type ObservableEntity = State;

    fn observable_entity(&self) -> Option<gpui::Model<Self::ObservableEntity>> {
        Some(self.state.clone())
    }
}

impl LanguageModelProvider for OllamaLanguageModelProvider {
    fn id(&self) -> LanguageModelProviderId {
        LanguageModelProviderId(PROVIDER_ID.into())
    }

    fn name(&self) -> LanguageModelProviderName {
        LanguageModelProviderName(PROVIDER_NAME.into())
    }

    fn icon(&self) -> IconName {
        IconName::AiOllama
    }

    fn provided_models(&self, cx: &AppContext) -> Vec<Arc<dyn LanguageModel>> {
        let mut models: BTreeMap<String, ollama::Model> = BTreeMap::default();

        // Add models from the Ollama API
        for model in self.state.read(cx).available_models.iter() {
            models.insert(model.name.clone(), model.clone());
        }

        // Override with available models from settings
        for model in AllLanguageModelSettings::get_global(cx)
            .ollama
            .available_models
            .iter()
        {
            models.insert(
                model.name.clone(),
                ollama::Model {
                    name: model.name.clone(),
                    display_name: model.display_name.clone(),
                    max_tokens: model.max_tokens,
                    keep_alive: model.keep_alive.clone(),
                },
            );
        }

        models
            .into_values()
            .map(|model| {
                Arc::new(OllamaLanguageModel {
                    id: LanguageModelId::from(model.name.clone()),
                    model: model.clone(),
                    http_client: self.http_client.clone(),
                    request_limiter: RateLimiter::new(4),
                }) as Arc<dyn LanguageModel>
            })
            .collect()
    }

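    /// Asks the Ollama server to preload the given model so the first
    /// completion request doesn't pay the model load cost.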
    fn load_model(&self, model: Arc<dyn LanguageModel>, cx: &AppContext) {
        let settings = &AllLanguageModelSettings::get_global(cx).ollama;
        let http_client = self.http_client.clone();
        let api_url = settings.api_url.clone();
        let id = model.id().0.to_string();
        cx.spawn(|_| async move { preload_model(http_client, &api_url, &id).await })
            .detach_and_log_err(cx);
    }

    fn is_authenticated(&self, cx: &AppContext) -> bool {
        self.state.read(cx).is_authenticated()
    }

    fn authenticate(&self, cx: &mut AppContext) -> Task<Result<()>> {
        self.state.update(cx, |state, cx| state.authenticate(cx))
    }

    fn configuration_view(&self, cx: &mut WindowContext) -> AnyView {
        let state = self.state.clone();
        cx.new_view(|cx| ConfigurationView::new(state, cx)).into()
    }

    fn reset_credentials(&self, cx: &mut AppContext) -> Task<Result<()>> {
        self.state.update(cx, |state, cx| state.fetch_models(cx))
    }
}

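/// A single Ollama model exposed to Zed as a `LanguageModel`, with its own
/// rate limiter for outgoing requests.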
pub struct OllamaLanguageModel {
    id: LanguageModelId,
    model: ollama::Model,
    http_client: Arc<dyn HttpClient>,
    request_limiter: RateLimiter,
}

impl OllamaLanguageModel {
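    /// Converts Zed's provider-agnostic request into an Ollama `ChatRequest`,
    /// mapping message roles and carrying over the model's context length,
    /// stop sequences, and temperature (defaulting to 1.0).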
    fn to_ollama_request(&self, request: LanguageModelRequest) -> ChatRequest {
        ChatRequest {
            model: self.model.name.clone(),
            messages: request
                .messages
                .into_iter()
                .map(|msg| match msg.role {
                    Role::User => ChatMessage::User {
                        content: msg.string_contents(),
                    },
                    Role::Assistant => ChatMessage::Assistant {
                        content: msg.string_contents(),
                        tool_calls: None,
                    },
                    Role::System => ChatMessage::System {
                        content: msg.string_contents(),
                    },
                })
                .collect(),
            keep_alive: self.model.keep_alive.clone().unwrap_or_default(),
            stream: true,
            options: Some(ChatOptions {
                num_ctx: Some(self.model.max_tokens),
                stop: Some(request.stop),
                temperature: request.temperature.or(Some(1.0)),
                ..Default::default()
            }),
            tools: vec![],
        }
    }

    fn request_completion(
        &self,
        request: ChatRequest,
        cx: &AsyncAppContext,
    ) -> BoxFuture<'static, Result<ChatResponseDelta>> {
        let http_client = self.http_client.clone();

        let Ok(api_url) = cx.update(|cx| {
            let settings = &AllLanguageModelSettings::get_global(cx).ollama;
            settings.api_url.clone()
        }) else {
            return futures::future::ready(Err(anyhow!("App state dropped"))).boxed();
        };

        async move { ollama::complete(http_client.as_ref(), &api_url, request).await }.boxed()
    }
}

impl LanguageModel for OllamaLanguageModel {
    fn id(&self) -> LanguageModelId {
        self.id.clone()
    }

    fn name(&self) -> LanguageModelName {
        LanguageModelName::from(self.model.display_name().to_string())
    }

    fn provider_id(&self) -> LanguageModelProviderId {
        LanguageModelProviderId(PROVIDER_ID.into())
    }

    fn provider_name(&self) -> LanguageModelProviderName {
        LanguageModelProviderName(PROVIDER_NAME.into())
    }

    fn telemetry_id(&self) -> String {
        format!("ollama/{}", self.model.id())
    }

    fn max_token_count(&self) -> usize {
        self.model.max_token_count()
    }

    fn count_tokens(
        &self,
        request: LanguageModelRequest,
        _cx: &AppContext,
    ) -> BoxFuture<'static, Result<usize>> {
        // There is no endpoint for this _yet_ in Ollama
        // see: https://github.com/ollama/ollama/issues/1716 and https://github.com/ollama/ollama/issues/3582
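        // Fall back to a rough estimate of one token per four characters of message content.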
        let token_count = request
            .messages
            .iter()
            .map(|msg| msg.string_contents().chars().count())
            .sum::<usize>()
            / 4;

        async move { Ok(token_count) }.boxed()
    }

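    /// Streams a chat completion from the Ollama server, rate-limited by the
    /// model's request limiter, and maps each delta's message content to a
    /// `LanguageModelCompletionEvent::Text` event.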
    fn stream_completion(
        &self,
        request: LanguageModelRequest,
        cx: &AsyncAppContext,
    ) -> BoxFuture<'static, Result<BoxStream<'static, Result<LanguageModelCompletionEvent>>>> {
        let request = self.to_ollama_request(request);

        let http_client = self.http_client.clone();
        let Ok(api_url) = cx.update(|cx| {
            let settings = &AllLanguageModelSettings::get_global(cx).ollama;
            settings.api_url.clone()
        }) else {
            return futures::future::ready(Err(anyhow!("App state dropped"))).boxed();
        };

        let future = self.request_limiter.stream(async move {
            let response = stream_chat_completion(http_client.as_ref(), &api_url, request).await?;
            let stream = response
                .filter_map(|response| async move {
                    match response {
                        Ok(delta) => {
                            let content = match delta.message {
                                ChatMessage::User { content } => content,
                                ChatMessage::Assistant { content, .. } => content,
                                ChatMessage::System { content } => content,
                            };
                            Some(Ok(content))
                        }
                        Err(error) => Some(Err(error)),
                    }
                })
                .boxed();
            Ok(stream)
        });

        async move {
            Ok(future
                .await?
                .map(|result| result.map(LanguageModelCompletionEvent::Text))
                .boxed())
        }
        .boxed()
    }

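    /// Forces a tool call by sending a non-streaming completion with the given
    /// tool attached, then returns the matching tool call's arguments as a
    /// single-item stream.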
    fn use_any_tool(
        &self,
        request: LanguageModelRequest,
        tool_name: String,
        tool_description: String,
        schema: serde_json::Value,
        cx: &AsyncAppContext,
    ) -> BoxFuture<'static, Result<BoxStream<'static, Result<String>>>> {
        use ollama::{OllamaFunctionTool, OllamaTool};
        let function = OllamaFunctionTool {
            name: tool_name.clone(),
            description: Some(tool_description),
            parameters: Some(schema),
        };
        let tools = vec![OllamaTool::Function { function }];
        let request = self.to_ollama_request(request).with_tools(tools);
        let response = self.request_completion(request, cx);
        self.request_limiter
            .run(async move {
                let response = response.await?;
                let ChatMessage::Assistant { tool_calls, .. } = response.message else {
                    bail!("message does not have an assistant role");
                };
                if let Some(tool_calls) = tool_calls.filter(|calls| !calls.is_empty()) {
                    for call in tool_calls {
                        let OllamaToolCall::Function(function) = call;
                        if function.name == tool_name {
                            return Ok(futures::stream::once(async move {
                                Ok(function.arguments.to_string())
                            })
                            .boxed());
                        }
                    }
                } else {
                    bail!("assistant message does not have any tool calls");
                };

                bail!("tool not used")
            })
            .boxed()
    }
}

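/// The provider's configuration UI: shows setup instructions and download
/// links while disconnected, and a "Connected" indicator once models have
/// been fetched from the local Ollama server.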
struct ConfigurationView {
    state: gpui::Model<State>,
    loading_models_task: Option<Task<()>>,
}

impl ConfigurationView {
    pub fn new(state: gpui::Model<State>, cx: &mut ViewContext<Self>) -> Self {
        let loading_models_task = Some(cx.spawn({
            let state = state.clone();
            |this, mut cx| async move {
                if let Some(task) = state
                    .update(&mut cx, |state, cx| state.authenticate(cx))
                    .log_err()
                {
                    task.await.log_err();
                }
                this.update(&mut cx, |this, cx| {
                    this.loading_models_task = None;
                    cx.notify();
                })
                .log_err();
            }
        }));

        Self {
            state,
            loading_models_task,
        }
    }

    fn retry_connection(&self, cx: &mut WindowContext) {
        self.state
            .update(cx, |state, cx| state.fetch_models(cx))
            .detach_and_log_err(cx);
    }
}

impl Render for ConfigurationView {
    fn render(&mut self, cx: &mut ViewContext<Self>) -> impl IntoElement {
        let is_authenticated = self.state.read(cx).is_authenticated();

        let ollama_intro = "Get up and running with Llama 3.2, Mistral, Gemma 2, and other large language models with Ollama.";
        let ollama_reqs =
            "Ollama must be running with at least one model installed to use it in the assistant.";

        let mut inline_code_bg = cx.theme().colors().editor_background;
        inline_code_bg.fade_out(0.5);

        if self.loading_models_task.is_some() {
            div().child(Label::new("Loading models...")).into_any()
        } else {
            v_flex()
                .size_full()
                .gap_3()
                .child(
                    v_flex()
                        .size_full()
                        .gap_2()
                        .p_1()
                        .child(Label::new(ollama_intro))
                        .child(Label::new(ollama_reqs))
                        .child(
                            h_flex()
                                .gap_0p5()
                                .child(Label::new("Once installed, try "))
                                .child(
                                    div()
                                        .bg(inline_code_bg)
                                        .px_1p5()
                                        .rounded_md()
                                        .child(Label::new("ollama run llama3.2")),
                                ),
                        ),
                )
                .child(
                    h_flex()
                        .w_full()
                        .pt_2()
                        .justify_between()
                        .gap_2()
                        .child(
                            h_flex()
                                .w_full()
                                .gap_2()
                                .map(|this| {
                                    if is_authenticated {
                                        this.child(
                                            Button::new("ollama-site", "Ollama")
                                                .style(ButtonStyle::Subtle)
                                                .icon(IconName::ExternalLink)
                                                .icon_size(IconSize::XSmall)
                                                .icon_color(Color::Muted)
                                                .on_click(move |_, cx| cx.open_url(OLLAMA_SITE))
                                                .into_any_element(),
                                        )
                                    } else {
                                        this.child(
                                            Button::new(
                                                "download_ollama_button",
                                                "Download Ollama",
                                            )
                                            .style(ButtonStyle::Subtle)
                                            .icon(IconName::ExternalLink)
                                            .icon_size(IconSize::XSmall)
                                            .icon_color(Color::Muted)
                                            .on_click(move |_, cx| cx.open_url(OLLAMA_DOWNLOAD_URL))
                                            .into_any_element(),
                                        )
                                    }
                                })
                                .child(
                                    Button::new("view-models", "All Models")
                                        .style(ButtonStyle::Subtle)
                                        .icon(IconName::ExternalLink)
                                        .icon_size(IconSize::XSmall)
                                        .icon_color(Color::Muted)
                                        .on_click(move |_, cx| cx.open_url(OLLAMA_LIBRARY_URL)),
                                ),
                        )
                        .child(if is_authenticated {
                            // This is only a button to ensure the spacing is correct;
                            // it should stay disabled.
                            ButtonLike::new("connected")
                                .disabled(true)
                                // Since this won't ever be clickable, we can use the arrow cursor
                                .cursor_style(gpui::CursorStyle::Arrow)
                                .child(
                                    h_flex()
                                        .gap_2()
                                        .child(Indicator::dot().color(Color::Success))
                                        .child(Label::new("Connected"))
                                        .into_any_element(),
                                )
                                .into_any_element()
                        } else {
                            Button::new("retry_ollama_models", "Connect")
                                .icon_position(IconPosition::Start)
                                .icon(IconName::ArrowCircle)
                                .on_click(cx.listener(move |this, _, cx| this.retry_connection(cx)))
                                .into_any_element()
                        }),
                )
                .into_any()
        }
    }
}