ollama.rs

use anyhow::{Result, anyhow};
use futures::{FutureExt, StreamExt, future::BoxFuture, stream::BoxStream};
use futures::{Stream, TryFutureExt, stream};
use gpui::{AnyView, App, AsyncApp, Context, Subscription, Task};
use http_client::HttpClient;
use language_model::{
    AuthenticateError, LanguageModelCompletionError, LanguageModelCompletionEvent,
    LanguageModelRequestTool, LanguageModelToolChoice, LanguageModelToolUse,
    LanguageModelToolUseId, MessageContent, StopReason,
};
use language_model::{
    LanguageModel, LanguageModelId, LanguageModelName, LanguageModelProvider,
    LanguageModelProviderId, LanguageModelProviderName, LanguageModelProviderState,
    LanguageModelRequest, RateLimiter, Role,
};
use ollama::{
    ChatMessage, ChatOptions, ChatRequest, ChatResponseDelta, KeepAlive, OllamaFunctionTool,
    OllamaToolCall, get_models, preload_model, show_model, stream_chat_completion,
};
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
use settings::{Settings, SettingsStore};
use std::pin::Pin;
use std::sync::atomic::{AtomicU64, Ordering};
use std::{collections::BTreeMap, sync::Arc};
use ui::{ButtonLike, Indicator, List, prelude::*};
use util::ResultExt;

use crate::AllLanguageModelSettings;
use crate::ui::InstructionListItem;

const OLLAMA_DOWNLOAD_URL: &str = "https://ollama.com/download";
const OLLAMA_LIBRARY_URL: &str = "https://ollama.com/library";
const OLLAMA_SITE: &str = "https://ollama.com/";

const PROVIDER_ID: &str = "ollama";
const PROVIDER_NAME: &str = "Ollama";

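/// Provider-level settings for Ollama: the server URL plus any models the user
/// declares explicitly in settings (in addition to those discovered from the
/// server itself).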
#[derive(Default, Debug, Clone, PartialEq)]
pub struct OllamaSettings {
    pub api_url: String,
    pub available_models: Vec<AvailableModel>,
}

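/// A user-declared model entry. A minimal sketch of how one might appear in
/// Zed's `settings.json` (the model name and values are illustrative, not
/// defaults):
///
/// ```json
/// {
///   "language_models": {
///     "ollama": {
///       "api_url": "http://localhost:11434",
///       "available_models": [
///         {
///           "name": "llama3.2:latest",
///           "display_name": "Llama 3.2",
///           "max_tokens": 8192,
///           "supports_tools": true
///         }
///       ]
///     }
///   }
/// }
/// ```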
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize, JsonSchema)]
pub struct AvailableModel {
    /// The model name in the Ollama API (e.g. "llama3.2:latest")
    pub name: String,
    /// The model's name as it appears in Zed's UI, such as the model selector dropdown in the assistant panel
    pub display_name: Option<String>,
    /// The model's context-length parameter (aka num_ctx or n_ctx)
    pub max_tokens: usize,
    /// How long the model stays loaded in memory after the last request (Ollama's keep_alive)
    pub keep_alive: Option<KeepAlive>,
    /// Whether the model supports tool calls
    pub supports_tools: Option<bool>,
    /// Whether the model supports thinking, enabling Ollama's think mode
    pub supports_thinking: Option<bool>,
}

pub struct OllamaLanguageModelProvider {
    http_client: Arc<dyn HttpClient>,
    state: gpui::Entity<State>,
}

pub struct State {
    http_client: Arc<dyn HttpClient>,
    available_models: Vec<ollama::Model>,
    fetch_model_task: Option<Task<Result<()>>>,
    _subscription: Subscription,
}

impl State {
    fn is_authenticated(&self) -> bool {
        !self.available_models.is_empty()
    }

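    /// Fetches the list of installed models from the Ollama server, then
    /// queries each model's capabilities (tool and thinking support)
    /// concurrently.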
    fn fetch_models(&mut self, cx: &mut Context<Self>) -> Task<Result<()>> {
        let settings = &AllLanguageModelSettings::get_global(cx).ollama;
        let http_client = Arc::clone(&self.http_client);
        let api_url = settings.api_url.clone();

        // As a proxy for the server being "authenticated", check that it's up
        // by fetching the models
        cx.spawn(async move |this, cx| {
            let models = get_models(http_client.as_ref(), &api_url, None).await?;

            let tasks = models
                .into_iter()
                // Since there is no metadata from the Ollama API
                // indicating which models are embedding models,
                // simply filter out models with "-embed" in their name
                .filter(|model| !model.name.contains("-embed"))
                .map(|model| {
                    let http_client = Arc::clone(&http_client);
                    let api_url = api_url.clone();
                    async move {
                        let name = model.name.as_str();
                        let capabilities = show_model(http_client.as_ref(), &api_url, name).await?;
                        let ollama_model = ollama::Model::new(
                            name,
                            None,
                            None,
                            Some(capabilities.supports_tools()),
                            Some(capabilities.supports_thinking()),
                        );
                        Ok(ollama_model)
                    }
                });

            // Rate-limit the capability fetches, since an arbitrary number of
            // models may be installed
            let mut ollama_models: Vec<_> = futures::stream::iter(tasks)
                .buffer_unordered(5)
                .collect::<Vec<Result<_>>>()
                .await
                .into_iter()
                .collect::<Result<Vec<_>>>()?;

            ollama_models.sort_by(|a, b| a.name.cmp(&b.name));

            this.update(cx, |this, cx| {
                this.available_models = ollama_models;
                cx.notify();
            })
        })
    }

    fn restart_fetch_models_task(&mut self, cx: &mut Context<Self>) {
        let task = self.fetch_models(cx);
        self.fetch_model_task.replace(task);
    }

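    /// Ollama has no credentials; "authenticating" just means confirming the
    /// server is reachable and has at least one model installed.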
    fn authenticate(&mut self, cx: &mut Context<Self>) -> Task<Result<(), AuthenticateError>> {
        if self.is_authenticated() {
            return Task::ready(Ok(()));
        }

        let fetch_models_task = self.fetch_models(cx);
        cx.spawn(async move |_this, _cx| Ok(fetch_models_task.await?))
    }
}

impl OllamaLanguageModelProvider {
    pub fn new(http_client: Arc<dyn HttpClient>, cx: &mut App) -> Self {
        let this = Self {
            http_client: http_client.clone(),
            state: cx.new(|cx| {
                let subscription = cx.observe_global::<SettingsStore>({
                    let mut settings = AllLanguageModelSettings::get_global(cx).ollama.clone();
                    move |this: &mut State, cx| {
                        let new_settings = &AllLanguageModelSettings::get_global(cx).ollama;
                        if &settings != new_settings {
                            settings = new_settings.clone();
                            this.restart_fetch_models_task(cx);
                            cx.notify();
                        }
                    }
                });

                State {
                    http_client,
                    available_models: Default::default(),
                    fetch_model_task: None,
                    _subscription: subscription,
                }
            }),
        };
        this.state
            .update(cx, |state, cx| state.restart_fetch_models_task(cx));
        this
    }
}

impl LanguageModelProviderState for OllamaLanguageModelProvider {
    type ObservableEntity = State;

    fn observable_entity(&self) -> Option<gpui::Entity<Self::ObservableEntity>> {
        Some(self.state.clone())
    }
}

impl LanguageModelProvider for OllamaLanguageModelProvider {
    fn id(&self) -> LanguageModelProviderId {
        LanguageModelProviderId(PROVIDER_ID.into())
    }

    fn name(&self) -> LanguageModelProviderName {
        LanguageModelProviderName(PROVIDER_NAME.into())
    }

    fn icon(&self) -> IconName {
        IconName::AiOllama
    }

    fn default_model(&self, cx: &App) -> Option<Arc<dyn LanguageModel>> {
        self.provided_models(cx).into_iter().next()
    }

    fn default_fast_model(&self, cx: &App) -> Option<Arc<dyn LanguageModel>> {
        self.default_model(cx)
    }

    fn provided_models(&self, cx: &App) -> Vec<Arc<dyn LanguageModel>> {
        let mut models: BTreeMap<String, ollama::Model> = BTreeMap::default();

        // Add models from the Ollama API
        for model in self.state.read(cx).available_models.iter() {
            models.insert(model.name.clone(), model.clone());
        }

        // Override with available models from settings
        for model in AllLanguageModelSettings::get_global(cx)
            .ollama
            .available_models
            .iter()
        {
            models.insert(
                model.name.clone(),
                ollama::Model {
                    name: model.name.clone(),
                    display_name: model.display_name.clone(),
                    max_tokens: model.max_tokens,
                    keep_alive: model.keep_alive.clone(),
                    supports_tools: model.supports_tools,
                    supports_thinking: model.supports_thinking,
                },
            );
        }

        models
            .into_values()
            .map(|model| {
                Arc::new(OllamaLanguageModel {
                    id: LanguageModelId::from(model.name.clone()),
                    model: model.clone(),
                    http_client: self.http_client.clone(),
                    request_limiter: RateLimiter::new(4),
                }) as Arc<dyn LanguageModel>
            })
            .collect()
    }

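    /// Asks the server to load the model into memory ahead of time, so the
    /// first completion doesn't pay the model-load latency.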
    fn load_model(&self, model: Arc<dyn LanguageModel>, cx: &App) {
        let settings = &AllLanguageModelSettings::get_global(cx).ollama;
        let http_client = self.http_client.clone();
        let api_url = settings.api_url.clone();
        let id = model.id().0.to_string();
        cx.spawn(async move |_| preload_model(http_client, &api_url, &id).await)
            .detach_and_log_err(cx);
    }

    fn is_authenticated(&self, cx: &App) -> bool {
        self.state.read(cx).is_authenticated()
    }

    fn authenticate(&self, cx: &mut App) -> Task<Result<(), AuthenticateError>> {
        self.state.update(cx, |state, cx| state.authenticate(cx))
    }

    fn configuration_view(&self, window: &mut Window, cx: &mut App) -> AnyView {
        let state = self.state.clone();
        cx.new(|cx| ConfigurationView::new(state, window, cx))
            .into()
    }

    fn reset_credentials(&self, cx: &mut App) -> Task<Result<()>> {
        self.state.update(cx, |state, cx| state.fetch_models(cx))
    }
}

pub struct OllamaLanguageModel {
    id: LanguageModelId,
    model: ollama::Model,
    http_client: Arc<dyn HttpClient>,
    request_limiter: RateLimiter,
}

impl OllamaLanguageModel {
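    /// Converts Zed's provider-agnostic request into Ollama's chat format.
    /// Message content is flattened to its string form via `string_contents`;
    /// assistant thinking text is forwarded as a separate field.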
    fn to_ollama_request(&self, request: LanguageModelRequest) -> ChatRequest {
        ChatRequest {
            model: self.model.name.clone(),
            messages: request
                .messages
                .into_iter()
                .map(|msg| match msg.role {
                    Role::User => ChatMessage::User {
                        content: msg.string_contents(),
                    },
                    Role::Assistant => {
                        let content = msg.string_contents();
                        let thinking = msg.content.into_iter().find_map(|content| match content {
                            MessageContent::Thinking { text, .. } if !text.is_empty() => Some(text),
                            _ => None,
                        });
                        ChatMessage::Assistant {
                            content,
                            tool_calls: None,
                            thinking,
                        }
                    }
                    Role::System => ChatMessage::System {
                        content: msg.string_contents(),
                    },
                })
                .collect(),
            keep_alive: self.model.keep_alive.clone().unwrap_or_default(),
            stream: true,
            options: Some(ChatOptions {
                num_ctx: Some(self.model.max_tokens),
                stop: Some(request.stop),
                temperature: request.temperature.or(Some(1.0)),
                ..Default::default()
            }),
            think: self.model.supports_thinking,
            tools: request.tools.into_iter().map(tool_into_ollama).collect(),
        }
    }
}

impl LanguageModel for OllamaLanguageModel {
    fn id(&self) -> LanguageModelId {
        self.id.clone()
    }

    fn name(&self) -> LanguageModelName {
        LanguageModelName::from(self.model.display_name().to_string())
    }

    fn provider_id(&self) -> LanguageModelProviderId {
        LanguageModelProviderId(PROVIDER_ID.into())
    }

    fn provider_name(&self) -> LanguageModelProviderName {
        LanguageModelProviderName(PROVIDER_NAME.into())
    }

    fn supports_tools(&self) -> bool {
        self.model.supports_tools.unwrap_or(false)
    }

    fn supports_images(&self) -> bool {
        false
    }

    fn supports_tool_choice(&self, choice: LanguageModelToolChoice) -> bool {
        match choice {
            LanguageModelToolChoice::Auto => false,
            LanguageModelToolChoice::Any => false,
            LanguageModelToolChoice::None => false,
        }
    }

    fn telemetry_id(&self) -> String {
        format!("ollama/{}", self.model.id())
    }

    fn max_token_count(&self) -> usize {
        self.model.max_token_count()
    }

    fn count_tokens(
        &self,
        request: LanguageModelRequest,
        _cx: &App,
    ) -> BoxFuture<'static, Result<usize>> {
        // There is no endpoint for this _yet_ in Ollama
        // see: https://github.com/ollama/ollama/issues/1716 and https://github.com/ollama/ollama/issues/3582
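        // Until then, estimate roughly four characters per token, a common
        // rule of thumb for English text (so ~400 chars counts as ~100 tokens).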
        let token_count = request
            .messages
            .iter()
            .map(|msg| msg.string_contents().chars().count())
            .sum::<usize>()
            / 4;

        async move { Ok(token_count) }.boxed()
    }

    fn stream_completion(
        &self,
        request: LanguageModelRequest,
        cx: &AsyncApp,
    ) -> BoxFuture<
        'static,
        Result<
            BoxStream<'static, Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>,
        >,
    > {
        let request = self.to_ollama_request(request);

        let http_client = self.http_client.clone();
        let Ok(api_url) = cx.update(|cx| {
            let settings = &AllLanguageModelSettings::get_global(cx).ollama;
            settings.api_url.clone()
        }) else {
            return futures::future::ready(Err(anyhow!("App state dropped"))).boxed();
        };

        let future = self.request_limiter.stream(async move {
            let stream = stream_chat_completion(http_client.as_ref(), &api_url, request).await?;
            let stream = map_to_language_model_completion_events(stream);
            Ok(stream)
        });

        future.map_ok(|f| f.boxed()).boxed()
    }
}

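/// Adapts Ollama's streaming deltas into Zed's completion events. A single
/// delta can expand into several events (thinking, tool use, text, stop), so
/// the stream is unfolded into a `Vec` of events per delta and then flattened.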
fn map_to_language_model_completion_events(
    stream: Pin<Box<dyn Stream<Item = anyhow::Result<ChatResponseDelta>> + Send>>,
) -> impl Stream<Item = Result<LanguageModelCompletionEvent, LanguageModelCompletionError>> {
    // Ollama's tool calls carry no ids, so we synthesize unique ones from the
    // function name plus a process-wide counter
    static TOOL_CALL_COUNTER: AtomicU64 = AtomicU64::new(0);

    struct State {
        stream: Pin<Box<dyn Stream<Item = anyhow::Result<ChatResponseDelta>> + Send>>,
        used_tools: bool,
    }

    // A single response from the underlying stream can produce both a ToolUse
    // and a Stop event, so each unfold step yields a Vec of events
    let stream = stream::unfold(
        State {
            stream,
            used_tools: false,
        },
        async move |mut state| {
            let response = state.stream.next().await?;

            let delta = match response {
                Ok(delta) => delta,
                Err(e) => {
                    let event = Err(LanguageModelCompletionError::Other(anyhow!(e)));
                    return Some((vec![event], state));
                }
            };

            let mut events = Vec::new();

            match delta.message {
                ChatMessage::User { content } => {
                    events.push(Ok(LanguageModelCompletionEvent::Text(content)));
                }
                ChatMessage::System { content } => {
                    events.push(Ok(LanguageModelCompletionEvent::Text(content)));
                }
                ChatMessage::Assistant {
                    content,
                    tool_calls,
                    thinking,
                } => {
                    if let Some(text) = thinking {
                        events.push(Ok(LanguageModelCompletionEvent::Thinking {
                            text,
                            signature: None,
                        }));
                    }

                    if let Some(tool_call) = tool_calls.and_then(|v| v.into_iter().next()) {
                        match tool_call {
                            OllamaToolCall::Function(function) => {
                                let tool_id = format!(
                                    "{}-{}",
                                    &function.name,
                                    TOOL_CALL_COUNTER.fetch_add(1, Ordering::Relaxed)
                                );
                                let event =
                                    LanguageModelCompletionEvent::ToolUse(LanguageModelToolUse {
                                        id: LanguageModelToolUseId::from(tool_id),
                                        name: Arc::from(function.name),
                                        raw_input: function.arguments.to_string(),
                                        input: function.arguments,
                                        is_input_complete: true,
                                    });
                                events.push(Ok(event));
                                state.used_tools = true;
                            }
                        }
                    } else if !content.is_empty() {
                        events.push(Ok(LanguageModelCompletionEvent::Text(content)));
                    }
                }
            };

            if delta.done {
                if state.used_tools {
                    state.used_tools = false;
                    events.push(Ok(LanguageModelCompletionEvent::Stop(StopReason::ToolUse)));
                } else {
                    events.push(Ok(LanguageModelCompletionEvent::Stop(StopReason::EndTurn)));
                }
            }

            Some((events, state))
        },
    );

    stream.flat_map(futures::stream::iter)
}

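/// The provider's configuration UI: shows install instructions while no
/// models are available and a "Connected" indicator once they are.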
struct ConfigurationView {
    state: gpui::Entity<State>,
    loading_models_task: Option<Task<()>>,
}

impl ConfigurationView {
    pub fn new(state: gpui::Entity<State>, window: &mut Window, cx: &mut Context<Self>) -> Self {
        let loading_models_task = Some(cx.spawn_in(window, {
            let state = state.clone();
            async move |this, cx| {
                if let Some(task) = state
                    .update(cx, |state, cx| state.authenticate(cx))
                    .log_err()
                {
                    task.await.log_err();
                }
                this.update(cx, |this, cx| {
                    this.loading_models_task = None;
                    cx.notify();
                })
                .log_err();
            }
        }));

        Self {
            state,
            loading_models_task,
        }
    }

    fn retry_connection(&self, cx: &mut App) {
        self.state
            .update(cx, |state, cx| state.fetch_models(cx))
            .detach_and_log_err(cx);
    }
}

impl Render for ConfigurationView {
    fn render(&mut self, _: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
        let is_authenticated = self.state.read(cx).is_authenticated();

        let ollama_intro =
            "Get up & running with Llama 3.3, Mistral, Gemma 2, and other LLMs with Ollama.";

        if self.loading_models_task.is_some() {
            div().child(Label::new("Loading models...")).into_any()
        } else {
            v_flex()
                .gap_2()
                .child(
                    v_flex().gap_1().child(Label::new(ollama_intro)).child(
                        List::new()
                            .child(InstructionListItem::text_only("Ollama must be running with at least one model installed to use it in the assistant."))
                            .child(InstructionListItem::text_only(
                                "Once installed, try `ollama run llama3.2`",
                            )),
                    ),
                )
                .child(
                    h_flex()
                        .w_full()
                        .justify_between()
                        .gap_2()
                        .child(
                            h_flex()
                                .w_full()
                                .gap_2()
                                .map(|this| {
                                    if is_authenticated {
                                        this.child(
                                            Button::new("ollama-site", "Ollama")
                                                .style(ButtonStyle::Subtle)
                                                .icon(IconName::ArrowUpRight)
                                                .icon_size(IconSize::XSmall)
                                                .icon_color(Color::Muted)
                                                .on_click(move |_, _, cx| cx.open_url(OLLAMA_SITE))
                                                .into_any_element(),
                                        )
                                    } else {
                                        this.child(
                                            Button::new(
                                                "download_ollama_button",
                                                "Download Ollama",
                                            )
                                            .style(ButtonStyle::Subtle)
                                            .icon(IconName::ArrowUpRight)
                                            .icon_size(IconSize::XSmall)
                                            .icon_color(Color::Muted)
                                            .on_click(move |_, _, cx| {
                                                cx.open_url(OLLAMA_DOWNLOAD_URL)
                                            })
                                            .into_any_element(),
                                        )
                                    }
                                })
                                .child(
                                    Button::new("view-models", "All Models")
                                        .style(ButtonStyle::Subtle)
                                        .icon(IconName::ArrowUpRight)
                                        .icon_size(IconSize::XSmall)
                                        .icon_color(Color::Muted)
                                        .on_click(move |_, _, cx| cx.open_url(OLLAMA_LIBRARY_URL)),
                                ),
                        )
                        .map(|this| {
                            if is_authenticated {
                                this.child(
                                    ButtonLike::new("connected")
                                        .disabled(true)
                                        .cursor_style(gpui::CursorStyle::Arrow)
                                        .child(
                                            h_flex()
                                                .gap_2()
                                                .child(Indicator::dot().color(Color::Success))
                                                .child(Label::new("Connected"))
                                                .into_any_element(),
                                        ),
                                )
                            } else {
                                this.child(
                                    Button::new("retry_ollama_models", "Connect")
                                        .icon_position(IconPosition::Start)
                                        .icon_size(IconSize::XSmall)
                                        .icon(IconName::Play)
                                        .on_click(cx.listener(move |this, _, _, cx| {
                                            this.retry_connection(cx)
                                        })),
                                )
                            }
                        })
                )
                .into_any()
        }
    }
}

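/// Converts a Zed tool definition into Ollama's function-tool schema. On the
/// wire, one of these serializes to roughly the following shape (the tool
/// name and fields shown are illustrative):
///
/// ```json
/// {
///   "type": "function",
///   "function": {
///     "name": "get_weather",
///     "description": "Look up the current weather for a location",
///     "parameters": { "type": "object", "properties": { "location": { "type": "string" } } }
///   }
/// }
/// ```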
fn tool_into_ollama(tool: LanguageModelRequestTool) -> ollama::OllamaTool {
    ollama::OllamaTool::Function {
        function: OllamaFunctionTool {
            name: tool.name,
            description: Some(tool.description),
            parameters: Some(tool.input_schema),
        },
    }
}