| """Self-contained inference module for the recommendation web app. |
| |
| Contains trimmed copies of ``MLPMetric`` and ``MLPMetricFull`` (and their |
| dependencies) so HF Spaces deployments do not need to ship the full |
| ``module/`` package. The class layout and parameter names match the trained |
| checkpoint exactly, so the original ``state_dict`` loads with |
| ``strict=False`` and a clean diff. |
| """ |
| from __future__ import annotations |
|
|
| import hashlib |
| import math |
| import re |
| from types import SimpleNamespace |
| from typing import Optional |
|
|
| import torch |
| import torch.nn as nn |
|
|
|
|
class ModelNameAvgEncoder(nn.Module):
    """Hashed-token average over a model name. Optionally adds an ID embedding.

    Each name is lower-cased and broken into tokens: the full name, the
    segment after the last ``/`` (if any), and every non-empty piece between
    ``/``, ``_``, ``-`` or whitespace. Repeated tokens are dropped (first
    occurrence wins). Every token is hashed with MD5 modulo ``hash_buckets``
    to a row of ``tok_emb``, and the rows are averaged. Empty or ``None``
    names produce a zero vector.
    """

    # Separators used to break a model name into sub-tokens.
    _SPLIT_RE = re.compile(r"[\/_\-\s]+")

    def __init__(self, args, hash_buckets: int = 10000):
        super().__init__()
        self.hash_buckets = hash_buckets
        self.tok_emb = nn.Embedding(self.hash_buckets, args.token_dim)
        self.use_id_emb = bool(getattr(args, "use_id_emb", False))
        if self.use_id_emb:
            # num_models + 1 rows: index ``num_models`` is reserved as the
            # unknown-model id.
            self.id_emb = nn.Embedding(args.num_models + 1, args.model_dim)
            self.unk_model_id = args.num_models

    @staticmethod
    def _split(name: Optional[str]) -> list[str]:
        """Return the de-duplicated, order-preserving token list for ``name``."""
        n = (name or "").strip().lower()
        if not n:
            return []
        toks = [n]
        if "/" in n:
            toks.append(n.split("/")[-1])
        toks.extend(t for t in ModelNameAvgEncoder._SPLIT_RE.split(n) if t)
        # dict keeps insertion order, so this drops repeats but keeps the
        # first occurrence of each token.
        return list(dict.fromkeys(toks))

    def _hash(self, tok: str) -> int:
        """Map ``tok`` to a bucket index in ``[0, hash_buckets)``.

        MD5 is used instead of the built-in ``hash`` because the latter is
        salted per process for ``str``, which would make the mapping
        non-reproducible.
        """
        return int(hashlib.md5(tok.encode()).hexdigest(), 16) % self.hash_buckets

    def forward(self, model_ids: torch.LongTensor, model_names: list[str]) -> torch.Tensor:
        """Encode a batch of models.

        Args:
            model_ids: indices into ``id_emb``; only read when ``use_id_emb``.
                Expected to have one entry per name.
            model_names: one name per model (``None`` / empty is allowed and
                yields a zero name vector).

        Returns:
            Tensor of shape ``[len(model_names), token_dim]``, or
            ``[len(model_names), token_dim + model_dim]`` when ``use_id_emb``.
        """
        weight = self.tok_emb.weight
        dim = self.tok_emb.embedding_dim
        vecs = []
        for name in model_names:
            toks = self._split(name)
            if not toks:
                # Unnamed model: zero vector in the embedding's dtype/device.
                vecs.append(weight.new_zeros(dim))
                continue
            idxs = torch.tensor(
                [self._hash(t) for t in toks], device=weight.device, dtype=torch.long
            )
            vecs.append(self.tok_emb(idxs).mean(dim=0))
        # torch.stack rejects an empty list, so an empty batch is built directly.
        h_name = torch.stack(vecs, dim=0) if vecs else weight.new_zeros(0, dim)
        feats = [h_name]
        if self.use_id_emb:
            feats.append(self.id_emb(model_ids.to(weight.device)))
        return torch.cat(feats, dim=-1)
|
|
|
|
class MLPMetric(nn.Module):
    """MLP recommender that takes raw dataset description embeddings, plus
    task / metric / size / family side features, and ranks model candidates.

    Mirrors the checkpoint at
    ``checkpoint/mlp/unified_augmented/ablation_no_model_id_no_dataset_id``.

    For each (dataset, model) pair, ``score_matrix`` concatenates
    ``[model info | dataset desc | size | family? | task | metric?]`` and runs
    it through ``backbone`` and ``pairwise_head``. It then adds a per-model
    prior from ``prior_head`` and divides by the learned ``temperature``,
    which is clamped to at least 1e-3.
    """

    def __init__(self, args):
        super().__init__()
        self.use_id_emb = bool(getattr(args, "use_id_emb", False))
        if self.use_id_emb:
            self.model_embedding = nn.Embedding(args.num_models, args.model_dim)
        else:
            self.model_embedding = None

        self.task_embedding = nn.Embedding(args.num_tasks, args.task_dim)
        self.model_info_encoder = ModelNameAvgEncoder(args)
        self.size_embedding = nn.Embedding(args.num_size_buckets, args.size_dim)
        self.num_size_buckets = int(args.num_size_buckets)
        self.use_size_prior = bool(getattr(args, "use_size_prior", True))

        self.use_family_prior = bool(getattr(args, "use_family_prior", False))
        if self.use_family_prior:
            family_dim = int(getattr(args, "family_dim", args.size_dim))
            self.family_embedding = nn.Embedding(args.num_families, family_dim)
            self.family_dim = family_dim
        else:
            self.family_dim = 0

        # The MS/Spider representation fusion branch is always off in this
        # trimmed copy; it contributes 0 input features.
        self.use_ms_spider_repr = False
        self.ms_fusion_dim = 0

        model_info_dim = args.token_dim + (args.model_dim if self.use_id_emb else 0)
        dataset_info_dim = args.dataset_desp_dim + args.task_dim
        backbone_in_dim = (
            model_info_dim + dataset_info_dim + args.size_dim + self.family_dim + self.ms_fusion_dim
        )

        self.backbone = self._build_backbone(backbone_in_dim, args.hidden_dim, args.dropout_rate)
        self.pairwise_head = nn.Linear(args.hidden_dim, 1)
        self.pointwise_head = nn.Linear(args.hidden_dim, 1)

        prior_in_dim = args.size_dim + self.family_dim
        self.prior_head = nn.Sequential(
            nn.Linear(prior_in_dim, args.hidden_dim // 2),
            nn.ReLU(),
            nn.Linear(args.hidden_dim // 2, 1),
        )
        self.temperature = nn.Parameter(torch.tensor(1.0))

        # Optional metric feature. It widens the backbone input by
        # ``metric_dim``, so the backbone is rebuilt with the same layout and
        # a larger first Linear.
        self.use_metric_embedding = bool(getattr(args, "use_metric_feature", True))
        self.num_metrics = int(getattr(args, "num_metrics", 1))
        self.metric_dim = int(getattr(args, "metric_dim", args.task_dim))
        self.unknown_metric_id = int(getattr(args, "unknown_metric_id", 0))
        if self.use_metric_embedding:
            self.metric_embedding = nn.Embedding(max(self.num_metrics, 1), self.metric_dim)
            self.backbone = self._build_backbone(
                self.backbone[0].in_features + self.metric_dim,
                self.backbone[0].out_features,
                self.backbone[2].p,
            )
        else:
            self.metric_embedding = None

    @staticmethod
    def _build_backbone(in_dim: int, hidden_dim: int, dropout: float) -> nn.Sequential:
        """Build two ``Linear -> ReLU -> Dropout`` stages.

        The layout is fixed, so the state_dict keys are ``backbone.0.*`` and
        ``backbone.3.*``.
        """
        return nn.Sequential(
            nn.Linear(in_dim, hidden_dim),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(hidden_dim, hidden_dim),
            nn.ReLU(),
            nn.Dropout(dropout),
        )

    def encode_model(self, model_ids: torch.LongTensor, model_names: list[str]) -> torch.Tensor:
        """Encode candidate models with ``model_info_encoder``."""
        return self.model_info_encoder(model_ids, model_names)

    @torch.no_grad()
    def build_model_cache(
        self,
        all_model_names: list[str],
        all_model_size_ids: torch.LongTensor,
        all_model_family_ids: Optional[torch.LongTensor] = None,
        device=None,
    ):
        """Precompute the per-model features that every ``score_matrix`` call reuses.

        Args:
            all_model_names: one name per candidate model (M entries). Model
                ``i`` receives model id ``i``.
            all_model_size_ids: size-bucket id per model, with M entries.
            all_model_family_ids: family id per model (M entries). It is only
                used when ``use_family_prior`` is enabled.
            device: target device. Defaults to the device of the module's
                first parameter.

        Returns:
            dict with ``h_model``, ``h_size`` and ``size_ids``, plus
            ``h_family`` and ``family_ids``. The last two are ``None`` unless
            the family prior is enabled and family ids were supplied.

        Raises:
            ValueError: if an id tensor does not have one entry per model name.
        """
        if device is None:
            device = next(self.parameters()).device
        size_ids = all_model_size_ids.to(device=device, dtype=torch.long)
        M = len(all_model_names)
        if size_ids.shape[0] != M:
            raise ValueError(
                f"all_model_size_ids has {size_ids.shape[0]} entries, expected {M}"
            )
        model_ids = torch.arange(M, device=device, dtype=torch.long)

        cache = {
            "h_model": self.encode_model(model_ids, all_model_names),
            "h_size": self.size_embedding(size_ids),
            "size_ids": size_ids,
            "h_family": None,
            "family_ids": None,
        }
        if self.use_family_prior and all_model_family_ids is not None:
            family_ids = all_model_family_ids.to(device=device, dtype=torch.long)
            if family_ids.shape[0] != M:
                raise ValueError(
                    f"all_model_family_ids has {family_ids.shape[0]} entries, expected {M}"
                )
            cache["h_family"] = self.family_embedding(family_ids)
            cache["family_ids"] = family_ids
        return cache

    def _metric_embed(
        self, metric_ids: Optional[torch.LongTensor], batch_size: int, device
    ) -> Optional[torch.Tensor]:
        """Return metric embeddings, or ``None`` when the metric feature is off.

        When ``metric_ids`` is ``None``, every row uses ``unknown_metric_id``.
        """
        if not self.use_metric_embedding or self.metric_embedding is None:
            return None
        if metric_ids is None:
            metric_ids = torch.full(
                (batch_size,), int(self.unknown_metric_id), dtype=torch.long, device=device
            )
        return self.metric_embedding(metric_ids)

    @torch.no_grad()
    def score_matrix(
        self,
        task_ids: torch.LongTensor,
        dataset_desp_batch: torch.Tensor,
        model_cache: dict,
        metric_ids: Optional[torch.LongTensor] = None,
        chunk_size: int = 8192,
    ) -> torch.Tensor:
        """Score every (dataset, model) pair.

        Args:
            task_ids: one task id per dataset row (B entries).
            dataset_desp_batch: ``[B, dataset_desp_dim]`` description embeddings.
            model_cache: output of ``build_model_cache``.
            metric_ids: optional per-row metric ids. Defaults to the unknown metric.
            chunk_size: number of models processed per backbone pass. Must be
                positive.

        Returns:
            ``[B, M]`` tensor of scores, where M is the number of cached models.

        Raises:
            ValueError: if ``chunk_size`` is not positive, or if the family
                prior is enabled but ``model_cache`` has no family features.
        """
        if chunk_size <= 0:
            # With a non-positive step the chunk loop would never advance.
            raise ValueError(f"chunk_size must be a positive integer, got {chunk_size}")
        device = dataset_desp_batch.device
        B = dataset_desp_batch.size(0)

        h_task = self.task_embedding(task_ids)
        h_data = dataset_desp_batch
        h_metric = self._metric_embed(metric_ids, B, device)

        h_model_all = model_cache["h_model"]
        h_size_all = model_cache["h_size"]
        h_family_all = model_cache.get("h_family")
        if self.use_family_prior and self.family_dim > 0 and h_family_all is None:
            # backbone and prior_head were sized to include family features.
            raise ValueError(
                "use_family_prior is enabled but model_cache has no 'h_family'; "
                "pass all_model_family_ids to build_model_cache()"
            )
        M = h_model_all.size(0)

        if self.use_size_prior or self.use_family_prior:
            if h_family_all is None:
                prior_inp_all = h_size_all
            else:
                prior_inp_all = torch.cat([h_size_all, h_family_all], dim=-1)
            prior_all = self.prior_head(prior_inp_all).squeeze(-1)
        else:
            prior_all = torch.zeros(M, device=device)

        out = torch.empty(B, M, device=device)
        T = torch.clamp(self.temperature, min=1e-3)

        for start in range(0, M, chunk_size):
            end = min(start + chunk_size, M)
            m = end - start
            # Concatenation order must match the order used to build backbone.
            parts = [
                h_model_all[start:end].unsqueeze(0).expand(B, m, -1),
                h_data.unsqueeze(1).expand(B, m, -1),
                h_size_all[start:end].unsqueeze(0).expand(B, m, -1),
            ]
            if h_family_all is not None:
                parts.append(h_family_all[start:end].unsqueeze(0).expand(B, m, -1))
            parts.append(h_task.unsqueeze(1).expand(B, m, -1))
            if h_metric is not None:
                parts.append(h_metric.unsqueeze(1).expand(B, m, -1))
            residual_inp = torch.cat(parts, dim=-1)

            # Explicit feature width: ``-1`` cannot be inferred when B == 0.
            h = self.backbone(residual_inp.reshape(B * m, residual_inp.size(-1)))
            s_chunk = self.pairwise_head(h).reshape(B, m)
            out[:, start:end] = (s_chunk + prior_all[start:end].unsqueeze(0)) / T
        return out
|
|
|
|
class MLPMetricFull(MLPMetric):
    """Full-feature recommender. Uses model-id emb, model-name emb, model-desc
    emb, dataset-id emb, and dataset-desc emb.

    For inference on a *new user dataset* (no global dataset_id), we:
      - feed UNK as dataset_id (so dataset_id_embedding still contributes a
        learned [UNK] prior),
      - feed the user's OpenAI embedding directly as dataset_desc_emb,
        bypassing the training-time ``dataset_desc_matrix`` lookup.

    Parameter layout matches the training-time class so the state_dict loads
    via ``load_state_dict(strict=False)`` after stripping the buffers that
    are only useful for the train-set IDs.
    """

    def __init__(self, args):
        # Plain (non-Module, non-Parameter) attributes may be set before
        # nn.Module.__init__ runs; they persist through super().__init__().
        self.dataset_id_emb_dim = int(getattr(args, "dataset_id_emb_dim", 256))
        self.dataset_desp_emb_dim = int(getattr(args, "dataset_desp_emb_dim", 1536))
        self.model_desp_emb_dim = int(getattr(args, "model_desp_emb_dim", 1536))

        # Feature toggles (all on by default).
        self.use_model_id_emb = bool(getattr(args, "use_model_id_emb", True))
        self.use_model_name_emb = bool(getattr(args, "use_model_name_emb", True))
        self.use_model_desc_emb = bool(getattr(args, "use_model_desc_emb", True))
        self.use_dataset_id_emb = bool(getattr(args, "use_dataset_id_emb", True))
        self.use_dataset_desc_emb = bool(getattr(args, "use_dataset_desc_emb", True))
        self.use_size_feature = bool(getattr(args, "use_size_feature", True))

        # The parent sizes its backbone from ``args.dataset_desp_dim``. That
        # backbone is replaced below. Override the value temporarily with the
        # id+desc width, and always restore the caller's value, even if the
        # parent constructor raises.
        orig_desp_dim = args.dataset_desp_dim
        args.dataset_desp_dim = self.dataset_id_emb_dim + self.dataset_desp_emb_dim
        try:
            super().__init__(args)
        finally:
            args.dataset_desp_dim = orig_desp_dim

        # Name encoder without its own id embedding; ids are handled by _id_emb.
        if self.use_model_name_emb:
            args_name_only = SimpleNamespace(**vars(args))
            args_name_only.use_id_emb = False
            self._name_encoder = ModelNameAvgEncoder(args_name_only)
        else:
            self._name_encoder = None

        # num_models + 1 rows: index ``num_models`` is the unknown-model id.
        if self.use_model_id_emb:
            self._id_emb = nn.Embedding(args.num_models + 1, args.model_dim)
            self.unk_model_id = args.num_models
        else:
            self._id_emb = None
            self.unk_model_id = 0

        # Per-model description vectors. This buffer is zero-initialised here;
        # presumably it is populated from the checkpoint (confirm with loader).
        if self.use_model_desc_emb:
            self.register_buffer(
                "model_desc_matrix",
                torch.zeros(args.num_models, self.model_desp_emb_dim),
            )
        else:
            self.register_buffer(
                "model_desc_matrix",
                torch.zeros(0, self.model_desp_emb_dim),
            )

        # num_datasets + 2 rows; index ``num_datasets + 1`` is the [UNK]
        # dataset id used for user-supplied datasets.
        num_datasets = int(getattr(args, "num_datasets", 100000))
        if self.use_dataset_id_emb:
            self.dataset_id_embedding = nn.Embedding(num_datasets + 2, self.dataset_id_emb_dim)
            self.unk_dataset_id = num_datasets + 1
        else:
            self.dataset_id_embedding = None
            self.unk_dataset_id = 0

        # Rebuild the backbone for the enabled feature set.
        model_info_dim = (
            (args.token_dim if self.use_model_name_emb else 0)
            + (args.model_dim if self.use_model_id_emb else 0)
            + (self.model_desp_emb_dim if self.use_model_desc_emb else 0)
        )
        self.model_info_dim = model_info_dim

        dataset_emb_dim = (
            (self.dataset_id_emb_dim if self.use_dataset_id_emb else 0)
            + (self.dataset_desp_emb_dim if self.use_dataset_desc_emb else 0)
        )
        self.dataset_emb_dim = dataset_emb_dim
        dataset_info_dim = dataset_emb_dim + args.task_dim
        metric_dim = self.metric_dim if self.use_metric_embedding else 0
        size_emb_dim_eff = args.size_dim if self.use_size_feature else 0
        backbone_in = (
            model_info_dim
            + dataset_info_dim
            + size_emb_dim_eff
            + self.family_dim
            + metric_dim
        )
        self.backbone = nn.Sequential(
            nn.Linear(backbone_in, args.hidden_dim),
            nn.ReLU(),
            nn.Dropout(args.dropout_rate),
            nn.Linear(args.hidden_dim, args.hidden_dim),
            nn.ReLU(),
            nn.Dropout(args.dropout_rate),
        )

        # Re-size the prior head to the prior features actually used. If none
        # are used, the parent's prior_head is kept, and score_matrix never
        # calls it.
        prior_in_actual = 0
        if self.use_size_prior and self.use_size_feature:
            prior_in_actual += args.size_dim
        if self.use_family_prior:
            prior_in_actual += self.family_dim
        if prior_in_actual > 0:
            self.prior_head = nn.Sequential(
                nn.Linear(prior_in_actual, args.hidden_dim // 2),
                nn.ReLU(),
                nn.Linear(args.hidden_dim // 2, 1),
            )

    def encode_model(
        self, model_ids: torch.LongTensor, model_names: list[str],
    ) -> torch.Tensor:
        """Concatenate the enabled model features: ``[name | id | desc]``.

        Description lookups clamp ``model_ids`` into the valid row range of
        ``model_desc_matrix``. Returns a ``[B, 0]`` tensor when every model
        feature is disabled.
        """
        B = model_ids.shape[0]
        device = model_ids.device
        parts = []
        if self.use_model_name_emb:
            parts.append(self._name_encoder(model_ids, model_names))
        if self.use_model_id_emb:
            parts.append(self._id_emb(model_ids))
        if self.use_model_desc_emb:
            if self.model_desc_matrix.shape[0] > 0:
                safe_ids = model_ids.clamp(0, self.model_desc_matrix.shape[0] - 1)
                parts.append(self.model_desc_matrix[safe_ids])
            else:
                parts.append(torch.zeros(B, self.model_desp_emb_dim, device=device))
        if not parts:
            return torch.zeros(B, 0, device=device)
        if len(parts) == 1:
            return parts[0]
        return torch.cat(parts, dim=-1)

    @torch.no_grad()
    def build_model_cache(
        self,
        all_model_names: list[str],
        all_model_size_ids: torch.LongTensor,
        all_model_family_ids: Optional[torch.LongTensor] = None,
        device=None,
    ):
        """Precompute per-model features for ``score_matrix``.

        ``h_size`` is ``None`` when ``use_size_feature`` is off. ``h_family``
        and ``family_ids`` are ``None`` unless the family prior is enabled and
        family ids are given.

        Raises:
            ValueError: if an id tensor does not have one entry per model name.
        """
        if device is None:
            device = next(self.parameters()).device
        size_ids = all_model_size_ids.to(device=device, dtype=torch.long)
        M = len(all_model_names)
        if size_ids.shape[0] != M:
            raise ValueError(
                f"all_model_size_ids has {size_ids.shape[0]} entries, expected {M}"
            )
        model_ids = torch.arange(M, device=device, dtype=torch.long)

        h_model = self.encode_model(model_ids, all_model_names)
        h_size = self.size_embedding(size_ids) if self.use_size_feature else None
        cache = {"h_model": h_model, "h_size": h_size, "size_ids": size_ids}
        if self.use_family_prior and all_model_family_ids is not None:
            family_ids = all_model_family_ids.to(device=device, dtype=torch.long)
            if family_ids.shape[0] != M:
                raise ValueError(
                    f"all_model_family_ids has {family_ids.shape[0]} entries, expected {M}"
                )
            cache["h_family"] = self.family_embedding(family_ids)
            cache["family_ids"] = family_ids
        else:
            cache["h_family"] = None
            cache["family_ids"] = None
        return cache

    def _encode_dataset_at_inference(
        self, dataset_desp_emb: torch.Tensor,
    ) -> torch.Tensor:
        """``dataset_desp_emb`` is the user's OpenAI vector of shape
        ``[B, dataset_desp_emb_dim]``. We pair it with a learned [UNK]
        dataset-id embedding, matching the training-time concatenation order
        (id_emb || desc_emb). Returns ``[B, 0]`` when both parts are disabled.
        """
        B = dataset_desp_emb.shape[0]
        device = dataset_desp_emb.device
        parts = []
        if self.use_dataset_id_emb and self.dataset_id_embedding is not None:
            unk = torch.full((B,), int(self.unk_dataset_id), dtype=torch.long, device=device)
            parts.append(self.dataset_id_embedding(unk))
        if self.use_dataset_desc_emb:
            parts.append(dataset_desp_emb)
        if not parts:
            return torch.zeros(B, 0, device=device)
        if len(parts) == 1:
            return parts[0]
        return torch.cat(parts, dim=-1)

    @torch.no_grad()
    def score_matrix(
        self,
        task_ids: torch.LongTensor,
        dataset_desp_batch: torch.Tensor,
        model_cache: dict,
        metric_ids: Optional[torch.LongTensor] = None,
        chunk_size: int = 8192,
    ) -> torch.Tensor:
        """Score every (user dataset, model) pair.

        Args:
            task_ids: one task id per dataset row (B entries).
            dataset_desp_batch: the OpenAI embedding, of shape
                ``[B, dataset_desp_emb_dim]`` (1536 by default).
            model_cache: output of ``build_model_cache``.
            metric_ids: optional per-row metric ids. Defaults to the unknown metric.
            chunk_size: number of models processed per backbone pass. Must be
                positive.

        Returns:
            ``[B, M]`` tensor of scores.

        Raises:
            ValueError: if ``chunk_size`` is not positive, or if the family
                prior is enabled but ``model_cache`` has no family features.
        """
        if chunk_size <= 0:
            # With a non-positive step the chunk loop would never advance.
            raise ValueError(f"chunk_size must be a positive integer, got {chunk_size}")
        device = dataset_desp_batch.device
        B = dataset_desp_batch.size(0)

        h_task = self.task_embedding(task_ids)
        h_data = self._encode_dataset_at_inference(dataset_desp_batch)
        h_metric = self._metric_embed(metric_ids, B, device)

        h_model_all = model_cache["h_model"]
        h_size_all = model_cache["h_size"] if self.use_size_feature else None
        h_family_all = model_cache.get("h_family")
        if self.use_family_prior and self.family_dim > 0 and h_family_all is None:
            # backbone (and possibly prior_head) were sized to include family features.
            raise ValueError(
                "use_family_prior is enabled but model_cache has no 'h_family'; "
                "pass all_model_family_ids to build_model_cache()"
            )
        M = h_model_all.size(0)

        prior_parts_all = []
        if self.use_size_prior and h_size_all is not None:
            prior_parts_all.append(h_size_all)
        if self.use_family_prior and h_family_all is not None:
            prior_parts_all.append(h_family_all)
        if prior_parts_all:
            prior_all = self.prior_head(torch.cat(prior_parts_all, dim=-1)).squeeze(-1)
        else:
            prior_all = torch.zeros(M, device=device)

        out = torch.empty(B, M, device=device)
        T = torch.clamp(self.temperature, min=1e-3)
        use_model_feats = h_model_all.shape[1] > 0
        use_data_feats = h_data.shape[1] > 0

        for start in range(0, M, chunk_size):
            end = min(start + chunk_size, M)
            m = end - start
            # Concatenation order must match the order used to size backbone:
            # model | dataset | size | family | task | metric.
            parts = []
            if use_model_feats:
                parts.append(h_model_all[start:end].unsqueeze(0).expand(B, m, -1))
            if use_data_feats:
                parts.append(h_data.unsqueeze(1).expand(B, m, -1))
            if h_size_all is not None:
                parts.append(h_size_all[start:end].unsqueeze(0).expand(B, m, -1))
            if h_family_all is not None:
                parts.append(h_family_all[start:end].unsqueeze(0).expand(B, m, -1))
            parts.append(h_task.unsqueeze(1).expand(B, m, -1))
            if h_metric is not None:
                parts.append(h_metric.unsqueeze(1).expand(B, m, -1))
            residual_inp = torch.cat(parts, dim=-1)

            # Explicit feature width: ``-1`` cannot be inferred when B == 0.
            h = self.backbone(residual_inp.reshape(B * m, residual_inp.size(-1)))
            s_chunk = self.pairwise_head(h).reshape(B, m)
            out[:, start:end] = (s_chunk + prior_all[start:end].unsqueeze(0)) / T

        return out
|
|