provider_test.go

  1package ollama
  2
  3import (
  4	"context"
  5	"testing"
  6	"time"
  7)
  8
  9func TestGetProvider(t *testing.T) {
 10	if !IsInstalled() {
 11		t.Skip("Ollama is not installed, skipping GetProvider test")
 12	}
 13
 14	ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
 15	defer cancel()
 16
 17	provider, err := GetProvider(ctx)
 18	if err != nil {
 19		t.Fatalf("Failed to get provider: %v", err)
 20	}
 21
 22	if provider.Name != "Ollama" {
 23		t.Errorf("Expected provider name to be 'Ollama', got '%s'", provider.Name)
 24	}
 25
 26	if provider.ID != "ollama" {
 27		t.Errorf("Expected provider ID to be 'ollama', got '%s'", provider.ID)
 28	}
 29
 30	t.Logf("✓ Provider: %s (ID: %s) with %d models",
 31		provider.Name, provider.ID, len(provider.Models))
 32
 33	// Test model details
 34	for _, model := range provider.Models {
 35		t.Logf("  - %s (context: %d, max_tokens: %d, images: %v)",
 36			model.ID, model.ContextWindow, model.DefaultMaxTokens, model.SupportsImages)
 37	}
 38
 39	// Cleanup
 40	defer func() {
 41		if processManager.crushStartedOllama {
 42			cleanup()
 43		}
 44	}()
 45}
 46
 47func TestExtractModelFamily(t *testing.T) {
 48	tests := []struct {
 49		modelName string
 50		expected  string
 51	}{
 52		{"llama3.2:3b", "llama"},
 53		{"mistral:7b", "mistral"},
 54		{"gemma:2b", "gemma"},
 55		{"qwen2.5:14b", "qwen"},
 56		{"phi3:3.8b", "phi"},
 57		{"codellama:13b", "codellama"},
 58		{"llava:13b", "llava"},
 59		{"llama-vision:7b", "llama-vision"},
 60		{"unknown-model:1b", "unknown"},
 61	}
 62
 63	for _, tt := range tests {
 64		t.Run(tt.modelName, func(t *testing.T) {
 65			result := extractModelFamily(tt.modelName)
 66			if result != tt.expected {
 67				t.Errorf("extractModelFamily(%s) = %s, expected %s",
 68					tt.modelName, result, tt.expected)
 69			}
 70		})
 71	}
 72}
 73
 74func TestGetContextWindow(t *testing.T) {
 75	tests := []struct {
 76		family   string
 77		expected int64
 78	}{
 79		{"llama", 131072},
 80		{"mistral", 32768},
 81		{"gemma", 8192},
 82		{"qwen", 131072},
 83		{"phi", 131072},
 84		{"codellama", 16384},
 85		{"unknown", 8192},
 86	}
 87
 88	for _, tt := range tests {
 89		t.Run(tt.family, func(t *testing.T) {
 90			result := getContextWindow(tt.family)
 91			if result != tt.expected {
 92				t.Errorf("getContextWindow(%s) = %d, expected %d",
 93					tt.family, result, tt.expected)
 94			}
 95		})
 96	}
 97}
 98
 99func TestSupportsImages(t *testing.T) {
100	tests := []struct {
101		family   string
102		expected bool
103	}{
104		{"llama-vision", true},
105		{"llava", true},
106		{"llama", false},
107		{"mistral", false},
108		{"unknown", false},
109	}
110
111	for _, tt := range tests {
112		t.Run(tt.family, func(t *testing.T) {
113			result := supportsImages(tt.family)
114			if result != tt.expected {
115				t.Errorf("supportsImages(%s) = %v, expected %v",
116					tt.family, result, tt.expected)
117			}
118		})
119	}
120}