model-resolver.ts

  1import { getModel, type Model } from "@mariozechner/pi-ai";
  2import {
  3  type CustomModelConfig,
  4  type RumiloConfig,
  5} from "../config/schema.js";
  6import { ConfigError } from "../util/errors.js";
  7import { resolveHeaders } from "../util/env.js";
  8
  9export function resolveModel(
 10  modelString: string,
 11  config: RumiloConfig,
 12): Model<any> {
 13  const colonIndex = modelString.indexOf(":");
 14  if (colonIndex === -1) {
 15    throw new ConfigError("Model must be in provider:model format");
 16  }
 17
 18  const provider = modelString.slice(0, colonIndex);
 19  const modelName = modelString.slice(colonIndex + 1);
 20
 21  if (!provider || !modelName) {
 22    throw new ConfigError("Model must be in provider:model format");
 23  }
 24
 25  // Handle custom models
 26  if (provider === "custom") {
 27    return resolveCustomModel(modelName, config);
 28  }
 29
 30  // Handle built-in providers
 31  return getModel(provider as any, modelName);
 32}
 33
 34function resolveCustomModel(modelName: string, config: RumiloConfig): Model<any> {
 35  if (!config.custom_models) {
 36    throw new ConfigError(
 37      `No custom models configured. Use 'custom:' prefix only with custom model definitions in config.`,
 38    );
 39  }
 40
 41  const customConfig = config.custom_models[modelName];
 42  if (!customConfig) {
 43    const available = Object.keys(config.custom_models).join(", ");
 44    throw new ConfigError(
 45      `Custom model '${modelName}' not found. Available custom models: ${available}`,
 46    );
 47  }
 48
 49  return buildCustomModel(customConfig);
 50}
 51
 52function buildCustomModel(config: CustomModelConfig): Model<any> {
 53  const api = config.api as any;
 54
 55  const cost: any = {
 56    input: config.cost.input,
 57    output: config.cost.output,
 58  };
 59
 60  if (config.cost.cache_read !== undefined) {
 61    cost.cacheRead = config.cost.cache_read;
 62  }
 63
 64  if (config.cost.cache_write !== undefined) {
 65    cost.cacheWrite = config.cost.cache_write;
 66  }
 67
 68  const model: any = {
 69    id: config.id,
 70    name: config.name,
 71    api,
 72    provider: config.provider as any,
 73    baseUrl: config.base_url,
 74    reasoning: config.reasoning,
 75    input: config.input,
 76    cost,
 77    contextWindow: config.context_window,
 78    maxTokens: config.max_tokens,
 79  };
 80
 81  const resolvedHeaders = resolveHeaders(config.headers);
 82  if (resolvedHeaders) {
 83    model.headers = resolvedHeaders;
 84  }
 85
 86  if (config.compat) {
 87    model.compat = convertCompatConfig(config.compat);
 88  }
 89
 90  return model;
 91}
 92
 93function convertCompatConfig(
 94  compat: CustomModelConfig["compat"],
 95): any {
 96  if (!compat) {
 97    throw new Error("Compat config is expected to be defined");
 98  }
 99
100  return {
101    supportsStore: compat.supports_store,
102    supportsDeveloperRole: compat.supports_developer_role,
103    supportsReasoningEffort: compat.supports_reasoning_effort,
104    supportsUsageInStreaming: compat.supports_usage_in_streaming,
105    maxTokensField: compat.max_tokens_field,
106    requiresToolResultName: compat.requires_tool_result_name,
107    requiresAssistantAfterToolResult: compat.requires_assistant_after_tool_result,
108    requiresThinkingAsText: compat.requires_thinking_as_text,
109    requiresMistralToolIds: compat.requires_mistral_tool_ids,
110    thinkingFormat: compat.thinking_format,
111  };
112}