ModelLens / inference_lib.py
luisrui's picture
Deploy MLPMetricFull v2 (47k models, with ID emb)
f86c505 verified
Raw
History Blame Contribute Delete
22 kB
"""Self-contained inference module for the recommendation web app.
Contains trimmed copies of ``MLPMetric`` and ``MLPMetricFull`` (and their
dependencies) so HF Spaces deployments do not need to ship the full
``module/`` package. The class layout and parameter names match the trained
checkpoint exactly, so the original ``state_dict`` loads with
``strict=False`` and a clean diff.
"""
from __future__ import annotations
import hashlib
import math
import re
from types import SimpleNamespace
from typing import Optional
import torch
import torch.nn as nn
class ModelNameAvgEncoder(nn.Module):
    """Encode model names as the mean of hashed-token embeddings.

    Each name is lower-cased and split into tokens (see ``_split``). Every token
    is hashed (MD5 modulo ``hash_buckets``) onto a row of ``tok_emb``, and the
    rows are averaged. When ``args.use_id_emb`` is set, a per-model ID
    embedding is concatenated after the name vector.
    """

    def __init__(self, args, hash_buckets: int = 10000):
        """
        Args:
            args: config namespace. Reads ``token_dim`` and, optionally,
                ``use_id_emb``; when that flag is set it also reads
                ``num_models`` and ``model_dim``.
            hash_buckets: number of rows in the hashed-token embedding table.
        """
        super().__init__()
        self.hash_buckets = hash_buckets
        self.tok_emb = nn.Embedding(self.hash_buckets, args.token_dim)
        self.use_id_emb = bool(getattr(args, "use_id_emb", False))
        if self.use_id_emb:
            # One extra row (index ``num_models``) is reserved for unknown models.
            self.id_emb = nn.Embedding(args.num_models + 1, args.model_dim)
            self.unk_model_id = args.num_models

    @staticmethod
    def _split(name: str) -> list[str]:
        """Return the de-duplicated tokens of *name*, in first-seen order.

        The tokens are:
        - the full lower-cased, stripped name;
        - the part after the last ``/`` (if any);
        - every non-empty piece from splitting on ``/``, ``_``, ``-`` and
          whitespace.

        Empty or ``None`` names yield ``[]``.
        """
        n = (name or "").strip().lower()
        if not n:
            return []
        toks = [n]
        if "/" in n:
            toks.append(n.split("/")[-1])
        toks.extend(t for t in re.split(r"[\/_\-\s]+", n) if t)
        # dict keys keep insertion order, so this is a stable de-duplication.
        return list(dict.fromkeys(toks))

    def _hash(self, tok: str) -> int:
        """Map *tok* to a bucket in ``[0, hash_buckets)`` via MD5."""
        return int(hashlib.md5(tok.encode()).hexdigest(), 16) % self.hash_buckets

    def forward(self, model_ids: torch.LongTensor, model_names: list[str]) -> torch.Tensor:
        """Encode a batch of models.

        Args:
            model_ids: per-model integer ids. Only read when ``use_id_emb`` is
                set. Ids outside ``[0, num_models)`` use the reserved
                unknown-model row instead of raising.
            model_names: one name per model. Names with no tokens get an
                all-zero name vector.

        Returns:
            ``[len(model_names), token_dim (+ model_dim)]`` features. An empty
            ``model_names`` yields a tensor with zero rows.
        """
        weight = self.tok_emb.weight
        device = weight.device
        dim = self.tok_emb.embedding_dim
        vecs = []
        for n in model_names:
            toks = self._split(n)
            if not toks:
                # Same device/dtype as the table, so stacking never mixes dtypes.
                vecs.append(weight.new_zeros(dim))
                continue
            idxs = torch.tensor([self._hash(t) for t in toks], device=device, dtype=torch.long)
            vecs.append(self.tok_emb(idxs).mean(dim=0))
        # torch.stack rejects an empty list; emit a well-formed (0, dim) tensor instead.
        h_name = torch.stack(vecs, dim=0) if vecs else weight.new_zeros(0, dim)
        feats = [h_name]
        if self.use_id_emb:
            ids = model_ids.to(device=device, dtype=torch.long)
            unknown = (ids < 0) | (ids >= self.unk_model_id)
            feats.append(self.id_emb(ids.masked_fill(unknown, self.unk_model_id)))
        return torch.cat(feats, dim=-1)
class MLPMetric(nn.Module):
    """MLP recommender that scores (dataset, model) pairs and ranks candidates.

    Inputs are raw dataset-description embeddings plus task, metric, size and
    family side features. Candidate models are encoded by
    ``ModelNameAvgEncoder``. The class mirrors the checkpoint at
    ``checkpoint/mlp/unified_augmented/ablation_no_model_id_no_dataset_id``.
    Sub-module names are kept as in the original layout so checkpoint keys
    line up.
    """

    def __init__(self, args):
        super().__init__()
        self.use_id_emb = bool(getattr(args, "use_id_emb", False))
        if self.use_id_emb:
            # Not read by any method of this class; presumably kept so the
            # checkpoint's ``model_embedding`` key has a home.
            self.model_embedding = nn.Embedding(args.num_models, args.model_dim)
        else:
            self.model_embedding = None
        self.task_embedding = nn.Embedding(args.num_tasks, args.task_dim)
        self.model_info_encoder = ModelNameAvgEncoder(args)
        self.size_embedding = nn.Embedding(args.num_size_buckets, args.size_dim)
        self.num_size_buckets = int(args.num_size_buckets)
        self.use_size_prior = bool(getattr(args, "use_size_prior", True))
        self.use_family_prior = bool(getattr(args, "use_family_prior", False))
        if self.use_family_prior:
            family_dim = int(getattr(args, "family_dim", args.size_dim))
            self.family_embedding = nn.Embedding(args.num_families, family_dim)
            self.family_dim = family_dim
        else:
            self.family_dim = 0
        # Model-Spider fusion path is disabled (not used by this checkpoint).
        self.use_ms_spider_repr = False
        self.ms_fusion_dim = 0

        # Resolve the metric configuration first, so the backbone is built once
        # with its final input width.
        self.use_metric_embedding = bool(getattr(args, "use_metric_feature", True))
        self.num_metrics = int(getattr(args, "num_metrics", 1))
        self.metric_dim = int(getattr(args, "metric_dim", args.task_dim))
        self.unknown_metric_id = int(getattr(args, "unknown_metric_id", 0))
        metric_in_dim = self.metric_dim if self.use_metric_embedding else 0

        model_info_dim = args.token_dim + (args.model_dim if self.use_id_emb else 0)
        dataset_info_dim = args.dataset_desp_dim + args.task_dim
        backbone_in_dim = (
            model_info_dim
            + dataset_info_dim
            + args.size_dim
            + self.family_dim
            + self.ms_fusion_dim
            + metric_in_dim
        )
        self.backbone = nn.Sequential(
            nn.Linear(backbone_in_dim, args.hidden_dim),
            nn.ReLU(),
            nn.Dropout(args.dropout_rate),
            nn.Linear(args.hidden_dim, args.hidden_dim),
            nn.ReLU(),
            nn.Dropout(args.dropout_rate),
        )
        self.pairwise_head = nn.Linear(args.hidden_dim, 1)
        self.pointwise_head = nn.Linear(args.hidden_dim, 1)
        prior_in_dim = args.size_dim + self.family_dim
        self.prior_head = nn.Sequential(
            nn.Linear(prior_in_dim, args.hidden_dim // 2),
            nn.ReLU(),
            nn.Linear(args.hidden_dim // 2, 1),
        )
        self.temperature = nn.Parameter(torch.tensor(1.0))
        # Registered after ``temperature``, keeping the original module order.
        if self.use_metric_embedding:
            self.metric_embedding = nn.Embedding(max(self.num_metrics, 1), self.metric_dim)
        else:
            self.metric_embedding = None

    def encode_model(self, model_ids: torch.LongTensor, model_names: list[str]) -> torch.Tensor:
        """Return per-model features from ``model_info_encoder``."""
        return self.model_info_encoder(model_ids, model_names)

    @torch.no_grad()
    def build_model_cache(
        self,
        all_model_names: list[str],
        all_model_size_ids: torch.LongTensor,
        all_model_family_ids: Optional[torch.LongTensor] = None,
        device=None,
    ):
        """Precompute the per-model features consumed by ``score_matrix``.

        Models get ids ``0..M-1`` in the order of ``all_model_names``.

        Args:
            all_model_names: one name per candidate model (``M`` entries).
            all_model_size_ids: one size-bucket id per model.
            all_model_family_ids: optional family id per model. Only used when
                ``use_family_prior`` is enabled.
            device: target device; defaults to the device of the parameters.

        Returns:
            dict with keys ``h_model``, ``h_size``, ``size_ids``, ``h_family``
            and ``family_ids``. The family entries are ``None`` when unused.

        Raises:
            ValueError: if the size or family id count differs from ``M``.
        """
        if device is None:
            device = next(self.parameters()).device
        size_ids = all_model_size_ids.to(device=device, dtype=torch.long)
        M = len(all_model_names)
        if size_ids.shape[0] != M:
            raise ValueError(
                f"expected {M} size ids (one per model name), got {size_ids.shape[0]}"
            )
        model_ids = torch.arange(M, device=device, dtype=torch.long)
        cache = {
            "h_model": self.encode_model(model_ids, all_model_names),
            "h_size": self.size_embedding(size_ids),
            "size_ids": size_ids,
            "h_family": None,
            "family_ids": None,
        }
        if self.use_family_prior and all_model_family_ids is not None:
            family_ids = all_model_family_ids.to(device=device, dtype=torch.long)
            if family_ids.shape[0] != M:
                raise ValueError(
                    f"expected {M} family ids (one per model name), got {family_ids.shape[0]}"
                )
            cache["h_family"] = self.family_embedding(family_ids)
            cache["family_ids"] = family_ids
        return cache

    def _metric_embed(
        self, metric_ids: Optional[torch.LongTensor], batch_size: int, device
    ) -> Optional[torch.Tensor]:
        """Return ``[batch_size, metric_dim]`` metric embeddings.

        Returns ``None`` when the metric feature is disabled. When
        ``metric_ids`` is ``None``, every row uses ``unknown_metric_id``.
        """
        if not self.use_metric_embedding or self.metric_embedding is None:
            return None
        if metric_ids is None:
            metric_ids = torch.full(
                (batch_size,), int(self.unknown_metric_id), dtype=torch.long, device=device
            )
        else:
            metric_ids = metric_ids.to(device=device, dtype=torch.long)
        return self.metric_embedding(metric_ids)

    @torch.no_grad()
    def score_matrix(
        self,
        task_ids: torch.LongTensor,
        dataset_desp_batch: torch.Tensor,
        model_cache: dict,
        metric_ids: Optional[torch.LongTensor] = None,
        chunk_size: int = 8192,
    ) -> torch.Tensor:
        """Score every cached model against each query dataset.

        Args:
            task_ids: ``[B]`` task indices, moved to the dataset batch's device.
            dataset_desp_batch: ``[B, D]`` dataset-description embeddings. Its
                device is used for all intermediate tensors.
            model_cache: output of ``build_model_cache``.
            metric_ids: optional ``[B]`` metric indices.
            chunk_size: number of models fed through the backbone per pass.
                This bounds each pass to ``B * chunk_size`` rows.

        Returns:
            ``[B, M]`` tensor of ``(pairwise_score + prior) / temperature``.
            The temperature is clamped to at least ``1e-3``.

        Raises:
            ValueError: if ``chunk_size < 1``, or if the family prior is enabled
                but the cache has no family features.
        """
        if chunk_size < 1:
            # A non-positive chunk size would never advance through the models.
            raise ValueError(f"chunk_size must be >= 1, got {chunk_size}")
        device = dataset_desp_batch.device
        B = dataset_desp_batch.size(0)
        h_task = self.task_embedding(task_ids.to(device=device, dtype=torch.long))
        h_data = dataset_desp_batch
        h_metric = self._metric_embed(metric_ids, B, device)
        h_model_all = model_cache["h_model"]
        h_size_all = model_cache["h_size"]
        h_family_all = model_cache.get("h_family")
        if self.use_family_prior and h_family_all is None:
            raise ValueError(
                "use_family_prior is enabled but model_cache has no 'h_family'; "
                "pass all_model_family_ids to build_model_cache"
            )
        M = h_model_all.size(0)

        if self.use_size_prior or self.use_family_prior:
            if h_family_all is not None:
                prior_inp_all = torch.cat([h_size_all, h_family_all], dim=-1)
            else:
                prior_inp_all = h_size_all
            prior_all = self.prior_head(prior_inp_all).squeeze(-1)
        else:
            prior_all = torch.zeros(M, device=device)

        out = torch.empty(B, M, device=device)
        T = torch.clamp(self.temperature, min=1e-3)
        for start in range(0, M, chunk_size):
            end = min(start + chunk_size, M)
            m = end - start
            # Feature order must match the trained backbone's input layout.
            parts = [
                h_model_all[start:end].unsqueeze(0).expand(B, m, -1),
                h_data.unsqueeze(1).expand(B, m, -1),
                h_size_all[start:end].unsqueeze(0).expand(B, m, -1),
            ]
            if h_family_all is not None:
                parts.append(h_family_all[start:end].unsqueeze(0).expand(B, m, -1))
            parts.append(h_task.unsqueeze(1).expand(B, m, -1))
            if h_metric is not None:
                parts.append(h_metric.unsqueeze(1).expand(B, m, -1))
            residual_inp = torch.cat(parts, dim=-1)
            h = self.backbone(residual_inp.reshape(B * m, -1))
            s_chunk = self.pairwise_head(h).reshape(B, m)
            out[:, start:end] = (s_chunk + prior_all[start:end].unsqueeze(0)) / T
        return out
class MLPMetricFull(MLPMetric):
    """Full-feature recommender.

    Combines model-id, model-name and model-description embeddings with
    dataset-id and dataset-description embeddings.

    For inference on a *new user dataset* (no global dataset_id), we:
      - feed UNK as dataset_id, so dataset_id_embedding still contributes a
        learned [UNK] prior;
      - feed the user's OpenAI embedding directly as dataset_desc_emb,
        bypassing the training-time ``dataset_desc_matrix`` lookup.

    The parameter layout matches the training-time class. The state_dict
    therefore loads via ``load_state_dict(strict=False)`` after stripping the
    buffers that are only useful for the train-set IDs.
    """

    def __init__(self, args):
        # ---- dim bookkeeping ----
        self.dataset_id_emb_dim = int(getattr(args, "dataset_id_emb_dim", 256))
        self.dataset_desp_emb_dim = int(getattr(args, "dataset_desp_emb_dim", 1536))
        self.model_desp_emb_dim = int(getattr(args, "model_desp_emb_dim", 1536))
        # Information-source flags (kept for parity; defaults match training)
        self.use_model_id_emb = bool(getattr(args, "use_model_id_emb", True))
        self.use_model_name_emb = bool(getattr(args, "use_model_name_emb", True))
        self.use_model_desc_emb = bool(getattr(args, "use_model_desc_emb", True))
        self.use_dataset_id_emb = bool(getattr(args, "use_dataset_id_emb", True))
        self.use_dataset_desc_emb = bool(getattr(args, "use_dataset_desc_emb", True))
        self.use_size_feature = bool(getattr(args, "use_size_feature", True))

        # The parent builds task/size/family/metric embeddings, prior_head,
        # temperature and a placeholder backbone (rebuilt below). Its size is
        # taken from ``args.dataset_desp_dim``. Override that value only for
        # the parent call, and always restore the caller's ``args``, even if
        # the parent constructor raises.
        missing = object()
        orig_desp_dim = getattr(args, "dataset_desp_dim", missing)
        args.dataset_desp_dim = self.dataset_id_emb_dim + self.dataset_desp_emb_dim
        try:
            super().__init__(args)
        finally:
            if orig_desp_dim is missing:
                del args.dataset_desp_dim
            else:
                args.dataset_desp_dim = orig_desp_dim

        # ==== Model-side components (own name encoder + own id emb) ====
        if self.use_model_name_emb:
            # Name-only copy of the config, so this encoder adds no id emb of its own.
            args_name_only = SimpleNamespace(**vars(args))
            args_name_only.use_id_emb = False
            self._name_encoder = ModelNameAvgEncoder(args_name_only)
        else:
            self._name_encoder = None

        if self.use_model_id_emb:
            # Row ``num_models`` is the [UNK] slot for ids outside the trained range.
            self._id_emb = nn.Embedding(args.num_models + 1, args.model_dim)
            self.unk_model_id = args.num_models
        else:
            self._id_emb = None
            self.unk_model_id = 0

        # Model-description buffer: one row per known model (no rows when disabled).
        desc_rows = args.num_models if self.use_model_desc_emb else 0
        self.register_buffer(
            "model_desc_matrix",
            torch.zeros(desc_rows, self.model_desp_emb_dim),
        )

        # ==== Dataset-side components ====
        num_datasets = int(getattr(args, "num_datasets", 100000))
        if self.use_dataset_id_emb:
            # +2: one for [UNK], one for the upper slack (matches training)
            self.dataset_id_embedding = nn.Embedding(num_datasets + 2, self.dataset_id_emb_dim)
            self.unk_dataset_id = num_datasets + 1
        else:
            self.dataset_id_embedding = None
            self.unk_dataset_id = 0
        # ``dataset_desc_matrix`` is NOT registered at inference time: the
        # user's OpenAI embedding is used directly. The stripped checkpoint
        # also omits this buffer.

        # ==== Recompute backbone input dim and rebuild ====
        model_info_dim = (
            (args.token_dim if self.use_model_name_emb else 0)
            + (args.model_dim if self.use_model_id_emb else 0)
            + (self.model_desp_emb_dim if self.use_model_desc_emb else 0)
        )
        self.model_info_dim = model_info_dim
        dataset_emb_dim = (
            (self.dataset_id_emb_dim if self.use_dataset_id_emb else 0)
            + (self.dataset_desp_emb_dim if self.use_dataset_desc_emb else 0)
        )
        self.dataset_emb_dim = dataset_emb_dim
        dataset_info_dim = dataset_emb_dim + args.task_dim
        metric_dim = self.metric_dim if self.use_metric_embedding else 0
        size_emb_dim_eff = args.size_dim if self.use_size_feature else 0
        backbone_in = (
            model_info_dim
            + dataset_info_dim
            + size_emb_dim_eff
            + self.family_dim
            + metric_dim
        )
        self.backbone = nn.Sequential(
            nn.Linear(backbone_in, args.hidden_dim),
            nn.ReLU(),
            nn.Dropout(args.dropout_rate),
            nn.Linear(args.hidden_dim, args.hidden_dim),
            nn.ReLU(),
            nn.Dropout(args.dropout_rate),
        )

        # Rebuild prior_head only if some prior input is active. Otherwise the
        # parent's head stays registered but score_matrix never calls it.
        prior_in_actual = 0
        if self.use_size_prior and self.use_size_feature:
            prior_in_actual += args.size_dim
        if self.use_family_prior:
            prior_in_actual += self.family_dim
        if prior_in_actual > 0:
            self.prior_head = nn.Sequential(
                nn.Linear(prior_in_actual, args.hidden_dim // 2),
                nn.ReLU(),
                nn.Linear(args.hidden_dim // 2, 1),
            )

    # ------------------------------------------------------------------
    # Model-side encoding (used by build_model_cache)
    # ------------------------------------------------------------------
    def encode_model(
        self, model_ids: torch.LongTensor, model_names: list[str],
    ) -> torch.Tensor:
        """Concatenate the enabled model features.

        Features are concatenated in the order: name, id, description.

        Ids outside ``[0, num_models)`` use the [UNK] id row. Models without a
        description row get a zero description vector, rather than borrowing
        another model's row.
        """
        B = model_ids.shape[0]
        device = model_ids.device
        parts = []
        if self.use_model_name_emb:
            parts.append(self._name_encoder(model_ids, model_names))
        if self.use_model_id_emb:
            unknown = (model_ids < 0) | (model_ids >= self.unk_model_id)
            parts.append(self._id_emb(model_ids.masked_fill(unknown, self.unk_model_id)))
        if self.use_model_desc_emb:
            n_desc = self.model_desc_matrix.shape[0]
            desc = torch.zeros(
                B, self.model_desp_emb_dim, device=device, dtype=self.model_desc_matrix.dtype
            )
            if n_desc > 0:
                has_desc = (model_ids >= 0) & (model_ids < n_desc)
                desc[has_desc] = self.model_desc_matrix[model_ids[has_desc]]
            parts.append(desc)
        if not parts:
            return torch.zeros(B, 0, device=device)
        if len(parts) == 1:
            return parts[0]
        return torch.cat(parts, dim=-1)

    @torch.no_grad()
    def build_model_cache(
        self,
        all_model_names: list[str],
        all_model_size_ids: torch.LongTensor,
        all_model_family_ids: Optional[torch.LongTensor] = None,
        device=None,
    ):
        """Precompute the per-model features consumed by ``score_matrix``.

        Models get ids ``0..M-1`` in the order of ``all_model_names``.
        ``h_size`` is ``None`` when ``use_size_feature`` is off. The family
        entries are ``None`` unless the family prior is enabled and family ids
        are given.

        Raises:
            ValueError: if the size or family id count differs from the number
                of names.
        """
        if device is None:
            device = next(self.parameters()).device
        size_ids = all_model_size_ids.to(device=device, dtype=torch.long)
        M = len(all_model_names)
        if size_ids.shape[0] != M:
            raise ValueError(
                f"expected {M} size ids (one per model name), got {size_ids.shape[0]}"
            )
        model_ids = torch.arange(M, device=device, dtype=torch.long)
        h_model = self.encode_model(model_ids, all_model_names)
        h_size = self.size_embedding(size_ids) if self.use_size_feature else None
        cache = {"h_model": h_model, "h_size": h_size, "size_ids": size_ids}
        if self.use_family_prior and all_model_family_ids is not None:
            family_ids = all_model_family_ids.to(device=device, dtype=torch.long)
            if family_ids.shape[0] != M:
                raise ValueError(
                    f"expected {M} family ids (one per model name), got {family_ids.shape[0]}"
                )
            cache["h_family"] = self.family_embedding(family_ids)
            cache["family_ids"] = family_ids
        else:
            cache["h_family"] = None
            cache["family_ids"] = None
        return cache

    # ------------------------------------------------------------------
    # Dataset-side encoding for inference: user's OpenAI emb + UNK id
    # ------------------------------------------------------------------
    def _encode_dataset_at_inference(
        self, dataset_desp_emb: torch.Tensor,
    ) -> torch.Tensor:
        """Build dataset features for a user dataset.

        ``dataset_desp_emb`` is the user's OpenAI vector of shape
        ``[B, dataset_desp_emb_dim]``. It is paired with the learned [UNK]
        dataset-id embedding, using the training-time concatenation order
        (id_emb || desc_emb).

        Raises:
            ValueError: if the description is used and its last dimension is
                not ``dataset_desp_emb_dim``.
        """
        if self.use_dataset_desc_emb and dataset_desp_emb.shape[-1] != self.dataset_desp_emb_dim:
            raise ValueError(
                f"expected dataset embeddings of width {self.dataset_desp_emb_dim}, "
                f"got {dataset_desp_emb.shape[-1]}"
            )
        B = dataset_desp_emb.shape[0]
        device = dataset_desp_emb.device
        parts = []
        if self.use_dataset_id_emb and self.dataset_id_embedding is not None:
            unk = torch.full((B,), int(self.unk_dataset_id), dtype=torch.long, device=device)
            parts.append(self.dataset_id_embedding(unk))
        if self.use_dataset_desc_emb:
            parts.append(dataset_desp_emb)
        if not parts:
            return torch.zeros(B, 0, device=device)
        if len(parts) == 1:
            return parts[0]
        return torch.cat(parts, dim=-1)

    # ------------------------------------------------------------------
    # score_matrix at inference time
    # ------------------------------------------------------------------
    @torch.no_grad()
    def score_matrix(
        self,
        task_ids: torch.LongTensor,
        dataset_desp_batch: torch.Tensor,
        model_cache: dict,
        metric_ids: Optional[torch.LongTensor] = None,
        chunk_size: int = 8192,
    ) -> torch.Tensor:
        """Score every cached model against each user dataset.

        ``dataset_desp_batch`` here is the OpenAI embedding ``[B, 1536]``.
        ``task_ids`` and ``metric_ids`` are moved to that tensor's device.

        Returns:
            ``[B, M]`` tensor of ``(pairwise_score + prior) / temperature``.

        Raises:
            ValueError: if ``chunk_size < 1``, or if the family prior is enabled
                but the cache has no family features.
        """
        if chunk_size < 1:
            # A non-positive chunk size would never advance through the models.
            raise ValueError(f"chunk_size must be >= 1, got {chunk_size}")
        device = dataset_desp_batch.device
        B = dataset_desp_batch.size(0)
        h_task = self.task_embedding(task_ids.to(device=device, dtype=torch.long))
        h_data = self._encode_dataset_at_inference(dataset_desp_batch)
        if metric_ids is not None:
            metric_ids = metric_ids.to(device=device, dtype=torch.long)
        h_metric = self._metric_embed(metric_ids, B, device)

        h_model_all = model_cache["h_model"]
        h_size_all = model_cache["h_size"] if self.use_size_feature else None
        h_family_all = model_cache.get("h_family")
        if self.use_family_prior and h_family_all is None:
            raise ValueError(
                "use_family_prior is enabled but model_cache has no 'h_family'; "
                "pass all_model_family_ids to build_model_cache"
            )
        M = h_model_all.size(0)

        prior_parts_all = []
        if self.use_size_prior and h_size_all is not None:
            prior_parts_all.append(h_size_all)
        if self.use_family_prior:
            prior_parts_all.append(h_family_all)
        if prior_parts_all:
            prior_all = self.prior_head(torch.cat(prior_parts_all, dim=-1)).squeeze(-1)
        else:
            prior_all = torch.zeros(M, device=device)

        out = torch.empty(B, M, device=device)
        T = torch.clamp(self.temperature, min=1e-3)
        for start in range(0, M, chunk_size):
            end = min(start + chunk_size, M)
            m = end - start
            # Feature order must match the trained backbone's input layout;
            # zero-width model/dataset features are skipped.
            parts = []
            h_model = h_model_all[start:end]
            if h_model.shape[1] > 0:
                parts.append(h_model.unsqueeze(0).expand(B, m, -1))
            if h_data.shape[1] > 0:
                parts.append(h_data.unsqueeze(1).expand(B, m, -1))
            if h_size_all is not None:
                parts.append(h_size_all[start:end].unsqueeze(0).expand(B, m, -1))
            if h_family_all is not None:
                parts.append(h_family_all[start:end].unsqueeze(0).expand(B, m, -1))
            parts.append(h_task.unsqueeze(1).expand(B, m, -1))
            if h_metric is not None:
                parts.append(h_metric.unsqueeze(1).expand(B, m, -1))
            residual_inp = torch.cat(parts, dim=-1)
            h = self.backbone(residual_inp.reshape(B * m, -1))
            s_chunk = self.pairwise_head(h).reshape(B, m)
            out[:, start:end] = (s_chunk + prior_all[start:end].unsqueeze(0)) / T
        return out