ollama.rs

use anyhow::{Result, anyhow};
use futures::{FutureExt, StreamExt, future::BoxFuture, stream::BoxStream};
use futures::{Stream, TryFutureExt, stream};
use gpui::{AnyView, App, AsyncApp, Context, Subscription, Task};
use http_client::HttpClient;
use language_model::{
    AuthenticateError, LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent,
    LanguageModelId, LanguageModelName, LanguageModelProvider, LanguageModelProviderId,
    LanguageModelProviderName, LanguageModelProviderState, LanguageModelRequest,
    LanguageModelRequestTool, LanguageModelToolChoice, LanguageModelToolUse,
    LanguageModelToolUseId, MessageContent, RateLimiter, Role, StopReason, TokenUsage,
};
use ollama::{
    ChatMessage, ChatOptions, ChatRequest, ChatResponseDelta, OllamaFunctionCall,
    OllamaFunctionTool, OllamaToolCall, get_models, show_model, stream_chat_completion,
};
pub use settings::OllamaAvailableModel as AvailableModel;
use settings::{Settings, SettingsStore};
use std::pin::Pin;
use std::sync::atomic::{AtomicU64, Ordering};
use std::{collections::HashMap, sync::Arc};
use ui::{ButtonLike, Indicator, List, prelude::*};
use util::ResultExt;

use crate::AllLanguageModelSettings;
use crate::ui::InstructionListItem;

const OLLAMA_DOWNLOAD_URL: &str = "https://ollama.com/download";
const OLLAMA_LIBRARY_URL: &str = "https://ollama.com/library";
const OLLAMA_SITE: &str = "https://ollama.com/";

const PROVIDER_ID: LanguageModelProviderId = LanguageModelProviderId::new("ollama");
const PROVIDER_NAME: LanguageModelProviderName = LanguageModelProviderName::new("Ollama");
#[derive(Default, Debug, Clone, PartialEq)]
pub struct OllamaSettings {
    pub api_url: String,
    pub available_models: Vec<AvailableModel>,
}

pub struct OllamaLanguageModelProvider {
    http_client: Arc<dyn HttpClient>,
    state: gpui::Entity<State>,
}

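/// Runtime state shared by the provider and its configuration view. The
/// provider is considered "authenticated" once at least one model has been
/// fetched from the local Ollama server.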
pub struct State {
    http_client: Arc<dyn HttpClient>,
    available_models: Vec<ollama::Model>,
    fetch_model_task: Option<Task<Result<()>>>,
    _subscription: Subscription,
}

impl State {
    fn is_authenticated(&self) -> bool {
        !self.available_models.is_empty()
    }

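    /// Fetches the installed models from the Ollama server, queries each
    /// model's capabilities, and stores the sorted results on `self`.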
    fn fetch_models(&mut self, cx: &mut Context<Self>) -> Task<Result<()>> {
        let settings = &AllLanguageModelSettings::get_global(cx).ollama;
        let http_client = Arc::clone(&self.http_client);
        let api_url = settings.api_url.clone();

        // As a proxy for the server being "authenticated", check that it's up
        // by fetching the models.
        cx.spawn(async move |this, cx| {
            let models = get_models(http_client.as_ref(), &api_url, None).await?;

            let tasks = models
                .into_iter()
                // Since there is no metadata from the Ollama API indicating
                // which models are embedding models, simply filter out models
                // with "-embed" in their name.
                .filter(|model| !model.name.contains("-embed"))
                .map(|model| {
                    let http_client = Arc::clone(&http_client);
                    let api_url = api_url.clone();
                    async move {
                        let name = model.name.as_str();
                        let capabilities = show_model(http_client.as_ref(), &api_url, name).await?;
                        let ollama_model = ollama::Model::new(
                            name,
                            None,
                            None,
                            Some(capabilities.supports_tools()),
                            Some(capabilities.supports_vision()),
                            Some(capabilities.supports_thinking()),
                        );
                        Ok(ollama_model)
                    }
                });

            // Limit the number of concurrent capability requests, since an
            // arbitrary number of models may be installed.
            let mut ollama_models: Vec<_> = futures::stream::iter(tasks)
                .buffer_unordered(5)
                .collect::<Vec<Result<_>>>()
                .await
                .into_iter()
                .collect::<Result<Vec<_>>>()?;

            ollama_models.sort_by(|a, b| a.name.cmp(&b.name));

            this.update(cx, |this, cx| {
                this.available_models = ollama_models;
                cx.notify();
            })
        })
    }

    fn restart_fetch_models_task(&mut self, cx: &mut Context<Self>) {
        let task = self.fetch_models(cx);
        self.fetch_model_task.replace(task);
    }

    fn authenticate(&mut self, cx: &mut Context<Self>) -> Task<Result<(), AuthenticateError>> {
        if self.is_authenticated() {
            return Task::ready(Ok(()));
        }

        let fetch_models_task = self.fetch_models(cx);
        cx.spawn(async move |_this, _cx| Ok(fetch_models_task.await?))
    }
}

impl OllamaLanguageModelProvider {
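    /// Creates the provider and kicks off an initial model fetch; the fetch
    /// is restarted whenever the Ollama settings change.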
    pub fn new(http_client: Arc<dyn HttpClient>, cx: &mut App) -> Self {
        let this = Self {
            http_client: http_client.clone(),
            state: cx.new(|cx| {
                let subscription = cx.observe_global::<SettingsStore>({
                    let mut settings = AllLanguageModelSettings::get_global(cx).ollama.clone();
                    move |this: &mut State, cx| {
                        let new_settings = &AllLanguageModelSettings::get_global(cx).ollama;
                        if &settings != new_settings {
                            settings = new_settings.clone();
                            this.restart_fetch_models_task(cx);
                            cx.notify();
                        }
                    }
                });

                State {
                    http_client,
                    available_models: Default::default(),
                    fetch_model_task: None,
                    _subscription: subscription,
                }
            }),
        };
        this.state
            .update(cx, |state, cx| state.restart_fetch_models_task(cx));
        this
    }
}

impl LanguageModelProviderState for OllamaLanguageModelProvider {
    type ObservableEntity = State;

    fn observable_entity(&self) -> Option<gpui::Entity<Self::ObservableEntity>> {
        Some(self.state.clone())
    }
}

impl LanguageModelProvider for OllamaLanguageModelProvider {
    fn id(&self) -> LanguageModelProviderId {
        PROVIDER_ID
    }

    fn name(&self) -> LanguageModelProviderName {
        PROVIDER_NAME
    }

    fn icon(&self) -> IconName {
        IconName::AiOllama
    }

    fn default_model(&self, _: &App) -> Option<Arc<dyn LanguageModel>> {
        // We shouldn't try to select a default model, because doing so can
        // trigger a load of a model that isn't resident yet. In a constrained
        // environment where the user may not have enough resources, loading
        // something by default would be bad UX.
        None
    }

    fn default_fast_model(&self, _: &App) -> Option<Arc<dyn LanguageModel>> {
        // See the explanation in `default_model`.
        None
    }

    fn provided_models(&self, cx: &App) -> Vec<Arc<dyn LanguageModel>> {
        let mut models: HashMap<String, ollama::Model> = HashMap::new();

        // Add models from the Ollama API.
        for model in self.state.read(cx).available_models.iter() {
            models.insert(model.name.clone(), model.clone());
        }

        // Override them with models from the settings, which take precedence.
        for model in AllLanguageModelSettings::get_global(cx)
            .ollama
            .available_models
            .iter()
        {
            models.insert(
                model.name.clone(),
                ollama::Model {
                    name: model.name.clone(),
                    display_name: model.display_name.clone(),
                    max_tokens: model.max_tokens,
                    keep_alive: model.keep_alive.clone(),
                    supports_tools: model.supports_tools,
                    supports_vision: model.supports_images,
                    supports_thinking: model.supports_thinking,
                },
            );
        }

        let mut models = models
            .into_values()
            .map(|model| {
                Arc::new(OllamaLanguageModel {
                    id: LanguageModelId::from(model.name.clone()),
                    model,
                    http_client: self.http_client.clone(),
                    request_limiter: RateLimiter::new(4),
                }) as Arc<dyn LanguageModel>
            })
            .collect::<Vec<_>>();
        models.sort_by_key(|model| model.name());
        models
    }

    fn is_authenticated(&self, cx: &App) -> bool {
        self.state.read(cx).is_authenticated()
    }

    fn authenticate(&self, cx: &mut App) -> Task<Result<(), AuthenticateError>> {
        self.state.update(cx, |state, cx| state.authenticate(cx))
    }

    fn configuration_view(
        &self,
        _target_agent: language_model::ConfigurationViewTargetAgent,
        window: &mut Window,
        cx: &mut App,
    ) -> AnyView {
        let state = self.state.clone();
        cx.new(|cx| ConfigurationView::new(state, window, cx))
            .into()
    }

    fn reset_credentials(&self, cx: &mut App) -> Task<Result<()>> {
        self.state.update(cx, |state, cx| state.fetch_models(cx))
    }
}

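/// A single Ollama model exposed through the `LanguageModel` interface.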
pub struct OllamaLanguageModel {
    id: LanguageModelId,
    model: ollama::Model,
    http_client: Arc<dyn HttpClient>,
    request_limiter: RateLimiter,
}

impl OllamaLanguageModel {
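    /// Translates a `LanguageModelRequest` into Ollama's chat request format,
    /// splitting tool results into dedicated tool messages and attaching
    /// images only when the model supports vision.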
    fn to_ollama_request(&self, request: LanguageModelRequest) -> ChatRequest {
        let supports_vision = self.model.supports_vision.unwrap_or(false);

        let mut messages = Vec::with_capacity(request.messages.len());

        for mut msg in request.messages.into_iter() {
            let images = if supports_vision {
                msg.content
                    .iter()
                    .filter_map(|content| match content {
                        MessageContent::Image(image) => Some(image.source.to_string()),
                        _ => None,
                    })
                    .collect::<Vec<String>>()
            } else {
                vec![]
            };

            match msg.role {
                Role::User => {
                    for tool_result in msg
                        .content
                        .extract_if(.., |x| matches!(x, MessageContent::ToolResult(..)))
                    {
                        match tool_result {
                            MessageContent::ToolResult(tool_result) => {
                                messages.push(ChatMessage::Tool {
                                    tool_name: tool_result.tool_name.to_string(),
                                    content: tool_result.content.to_str().unwrap_or("").to_string(),
                                })
                            }
                            _ => unreachable!("only tool results are extracted above"),
                        }
                    }
                    if !msg.content.is_empty() {
                        messages.push(ChatMessage::User {
                            content: msg.string_contents(),
                            images: if images.is_empty() {
                                None
                            } else {
                                Some(images)
                            },
                        })
                    }
                }
                Role::Assistant => {
                    let content = msg.string_contents();
                    let mut thinking = None;
                    let mut tool_calls = Vec::new();
                    for content in msg.content.into_iter() {
                        match content {
                            MessageContent::Thinking { text, .. } if !text.is_empty() => {
                                thinking = Some(text)
                            }
                            MessageContent::ToolUse(tool_use) => {
                                tool_calls.push(OllamaToolCall::Function(OllamaFunctionCall {
                                    name: tool_use.name.to_string(),
                                    arguments: tool_use.input,
                                }));
                            }
                            _ => (),
                        }
                    }
                    messages.push(ChatMessage::Assistant {
                        content,
                        tool_calls: Some(tool_calls),
                        images: if images.is_empty() {
                            None
                        } else {
                            Some(images)
                        },
                        thinking,
                    })
                }
                Role::System => messages.push(ChatMessage::System {
                    content: msg.string_contents(),
                }),
            }
        }
        ChatRequest {
            model: self.model.name.clone(),
            messages,
            keep_alive: self.model.keep_alive.clone().unwrap_or_default(),
            stream: true,
            options: Some(ChatOptions {
                num_ctx: Some(self.model.max_tokens),
                stop: Some(request.stop),
                temperature: request.temperature.or(Some(1.0)),
                ..Default::default()
            }),
            think: self
                .model
                .supports_thinking
                .map(|supports_thinking| supports_thinking && request.thinking_allowed),
            tools: if self.model.supports_tools.unwrap_or(false) {
                request.tools.into_iter().map(tool_into_ollama).collect()
            } else {
                vec![]
            },
        }
    }
}

impl LanguageModel for OllamaLanguageModel {
    fn id(&self) -> LanguageModelId {
        self.id.clone()
    }

    fn name(&self) -> LanguageModelName {
        LanguageModelName::from(self.model.display_name().to_string())
    }

    fn provider_id(&self) -> LanguageModelProviderId {
        PROVIDER_ID
    }

    fn provider_name(&self) -> LanguageModelProviderName {
        PROVIDER_NAME
    }

    fn supports_tools(&self) -> bool {
        self.model.supports_tools.unwrap_or(false)
    }

    fn supports_images(&self) -> bool {
        self.model.supports_vision.unwrap_or(false)
    }

    fn supports_tool_choice(&self, choice: LanguageModelToolChoice) -> bool {
        match choice {
            LanguageModelToolChoice::Auto => false,
            LanguageModelToolChoice::Any => false,
            LanguageModelToolChoice::None => false,
        }
    }

    fn telemetry_id(&self) -> String {
        format!("ollama/{}", self.model.id())
    }

    fn max_token_count(&self) -> u64 {
        self.model.max_token_count()
    }

    fn count_tokens(
        &self,
        request: LanguageModelRequest,
        _cx: &App,
    ) -> BoxFuture<'static, Result<u64>> {
        // There is no endpoint for this _yet_ in Ollama,
        // see: https://github.com/ollama/ollama/issues/1716 and https://github.com/ollama/ollama/issues/3582
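        // Until then, estimate roughly 4 characters per token, a common
        // heuristic for English text.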
        let token_count = request
            .messages
            .iter()
            .map(|msg| msg.string_contents().chars().count())
            .sum::<usize>()
            / 4;

        async move { Ok(token_count as u64) }.boxed()
    }

    fn stream_completion(
        &self,
        request: LanguageModelRequest,
        cx: &AsyncApp,
    ) -> BoxFuture<
        'static,
        Result<
            BoxStream<'static, Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>,
            LanguageModelCompletionError,
        >,
    > {
        let request = self.to_ollama_request(request);

        let http_client = self.http_client.clone();
        let Ok(api_url) = cx.update(|cx| {
            let settings = &AllLanguageModelSettings::get_global(cx).ollama;
            settings.api_url.clone()
        }) else {
            return futures::future::ready(Err(anyhow!("App state dropped").into())).boxed();
        };

        let future = self.request_limiter.stream(async move {
            let stream = stream_chat_completion(http_client.as_ref(), &api_url, request).await?;
            let stream = map_to_language_model_completion_events(stream);
            Ok(stream)
        });

        future.map_ok(|f| f.boxed()).boxed()
    }
}

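/// Converts Ollama's streaming chat deltas into `LanguageModelCompletionEvent`s.
/// A single delta can expand into several events (thinking text, a tool use,
/// usage totals, and a stop reason), so each delta is mapped to a `Vec` of
/// events and the stream is flattened at the end.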
fn map_to_language_model_completion_events(
    stream: Pin<Box<dyn Stream<Item = anyhow::Result<ChatResponseDelta>> + Send>>,
) -> impl Stream<Item = Result<LanguageModelCompletionEvent, LanguageModelCompletionError>> {
    // Used for creating unique tool use IDs.
    static TOOL_CALL_COUNTER: AtomicU64 = AtomicU64::new(0);

    struct State {
        stream: Pin<Box<dyn Stream<Item = anyhow::Result<ChatResponseDelta>> + Send>>,
        used_tools: bool,
    }

    // A single response from the original stream may need to become several
    // events (e.g. a ToolUse followed by a Stop), so the unfold below yields
    // a Vec of events per delta.
    let stream = stream::unfold(
        State {
            stream,
            used_tools: false,
        },
        async move |mut state| {
            let response = state.stream.next().await?;

            let delta = match response {
                Ok(delta) => delta,
                Err(e) => {
                    let event = Err(LanguageModelCompletionError::from(anyhow!(e)));
                    return Some((vec![event], state));
                }
            };

            let mut events = Vec::new();

            match delta.message {
                ChatMessage::User { content, images: _ } => {
                    events.push(Ok(LanguageModelCompletionEvent::Text(content)));
                }
                ChatMessage::System { content } => {
                    events.push(Ok(LanguageModelCompletionEvent::Text(content)));
                }
                ChatMessage::Tool { content, .. } => {
                    events.push(Ok(LanguageModelCompletionEvent::Text(content)));
                }
                ChatMessage::Assistant {
                    content,
                    tool_calls,
                    images: _,
                    thinking,
                } => {
                    if let Some(text) = thinking {
                        events.push(Ok(LanguageModelCompletionEvent::Thinking {
                            text,
                            signature: None,
                        }));
                    }

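                    // Only the first tool call in a delta is surfaced as a
                    // ToolUse event; any additional calls in the same delta
                    // are dropped.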
                    if let Some(tool_call) = tool_calls.and_then(|v| v.into_iter().next()) {
                        match tool_call {
                            OllamaToolCall::Function(function) => {
                                let tool_id = format!(
                                    "{}-{}",
                                    &function.name,
                                    TOOL_CALL_COUNTER.fetch_add(1, Ordering::Relaxed)
                                );
                                let event =
                                    LanguageModelCompletionEvent::ToolUse(LanguageModelToolUse {
                                        id: LanguageModelToolUseId::from(tool_id),
                                        name: Arc::from(function.name),
                                        raw_input: function.arguments.to_string(),
                                        input: function.arguments,
                                        is_input_complete: true,
                                    });
                                events.push(Ok(event));
                                state.used_tools = true;
                            }
                        }
                    } else if !content.is_empty() {
                        events.push(Ok(LanguageModelCompletionEvent::Text(content)));
                    }
                }
            };

            if delta.done {
                events.push(Ok(LanguageModelCompletionEvent::UsageUpdate(TokenUsage {
                    input_tokens: delta.prompt_eval_count.unwrap_or(0),
                    output_tokens: delta.eval_count.unwrap_or(0),
                    cache_creation_input_tokens: 0,
                    cache_read_input_tokens: 0,
                })));
                if state.used_tools {
                    state.used_tools = false;
                    events.push(Ok(LanguageModelCompletionEvent::Stop(StopReason::ToolUse)));
                } else {
                    events.push(Ok(LanguageModelCompletionEvent::Stop(StopReason::EndTurn)));
                }
            }

            Some((events, state))
        },
    );

    stream.flat_map(futures::stream::iter)
}

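/// Settings-panel view for the Ollama provider: shows install and download
/// instructions while disconnected, and a "Connected" indicator once models
/// have been fetched.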
struct ConfigurationView {
    state: gpui::Entity<State>,
    loading_models_task: Option<Task<()>>,
}

impl ConfigurationView {
    pub fn new(state: gpui::Entity<State>, window: &mut Window, cx: &mut Context<Self>) -> Self {
        let loading_models_task = Some(cx.spawn_in(window, {
            let state = state.clone();
            async move |this, cx| {
                if let Some(task) = state
                    .update(cx, |state, cx| state.authenticate(cx))
                    .log_err()
                {
                    task.await.log_err();
                }
                this.update(cx, |this, cx| {
                    this.loading_models_task = None;
                    cx.notify();
                })
                .log_err();
            }
        }));

        Self {
            state,
            loading_models_task,
        }
    }

    fn retry_connection(&self, cx: &mut App) {
        self.state
            .update(cx, |state, cx| state.fetch_models(cx))
            .detach_and_log_err(cx);
    }
}

impl Render for ConfigurationView {
    fn render(&mut self, _: &mut Window, cx: &mut Context<Self>) -> impl IntoElement {
        let is_authenticated = self.state.read(cx).is_authenticated();

        let ollama_intro =
            "Get up & running with Llama 3.3, Mistral, Gemma 2, and other LLMs with Ollama.";

        if self.loading_models_task.is_some() {
            div().child(Label::new("Loading models...")).into_any()
        } else {
            v_flex()
                .gap_2()
                .child(
                    v_flex().gap_1().child(Label::new(ollama_intro)).child(
                        List::new()
                            .child(InstructionListItem::text_only("Ollama must be running with at least one model installed to use it in the assistant."))
                            .child(InstructionListItem::text_only(
                                "Once installed, try `ollama run llama3.2`",
                            )),
                    ),
                )
                .child(
                    h_flex()
                        .w_full()
                        .justify_between()
                        .gap_2()
                        .child(
                            h_flex()
                                .w_full()
                                .gap_2()
                                .map(|this| {
                                    if is_authenticated {
                                        this.child(
                                            Button::new("ollama-site", "Ollama")
                                                .style(ButtonStyle::Subtle)
                                                .icon(IconName::ArrowUpRight)
                                                .icon_size(IconSize::Small)
                                                .icon_color(Color::Muted)
                                                .on_click(move |_, _, cx| cx.open_url(OLLAMA_SITE))
                                                .into_any_element(),
                                        )
                                    } else {
                                        this.child(
                                            Button::new(
                                                "download_ollama_button",
                                                "Download Ollama",
                                            )
                                            .style(ButtonStyle::Subtle)
                                            .icon(IconName::ArrowUpRight)
                                            .icon_size(IconSize::Small)
                                            .icon_color(Color::Muted)
                                            .on_click(move |_, _, cx| {
                                                cx.open_url(OLLAMA_DOWNLOAD_URL)
                                            })
                                            .into_any_element(),
                                        )
                                    }
                                })
                                .child(
                                    Button::new("view-models", "View All Models")
                                        .style(ButtonStyle::Subtle)
                                        .icon(IconName::ArrowUpRight)
                                        .icon_size(IconSize::Small)
                                        .icon_color(Color::Muted)
                                        .on_click(move |_, _, cx| cx.open_url(OLLAMA_LIBRARY_URL)),
                                ),
                        )
                        .map(|this| {
                            if is_authenticated {
                                this.child(
                                    ButtonLike::new("connected")
                                        .disabled(true)
                                        .cursor_style(gpui::CursorStyle::Arrow)
                                        .child(
                                            h_flex()
                                                .gap_2()
                                                .child(Indicator::dot().color(Color::Success))
                                                .child(Label::new("Connected"))
                                                .into_any_element(),
                                        ),
                                )
                            } else {
                                this.child(
                                    Button::new("retry_ollama_models", "Connect")
                                        .icon_position(IconPosition::Start)
                                        .icon_size(IconSize::XSmall)
                                        .icon(IconName::PlayFilled)
                                        .on_click(cx.listener(move |this, _, _, cx| {
                                            this.retry_connection(cx)
                                        })),
                                )
                            }
                        })
                )
                .into_any()
        }
    }
}

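/// Converts a generic tool definition from the request into Ollama's
/// function-tool wire format.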
fn tool_into_ollama(tool: LanguageModelRequestTool) -> ollama::OllamaTool {
    ollama::OllamaTool::Function {
        function: OllamaFunctionTool {
            name: tool.name,
            description: Some(tool.description),
            parameters: Some(tool.input_schema),
        },
    }
}