matchbench.model package

Subpackages

Submodules

matchbench.model.base_model module

class matchbench.model.base_model.CTAModel(multi_label, loss_function)

Bases: GeneralModel

calculate_loss(logits, label)

Compute the loss from the logits and the labels.

Parameters:

logits, label

Returns:

loss
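CTAModel is constructed with a multi_label flag and a loss_function, but the concrete loss is not documented here. The sketch below assumes the usual pairing: a per-class sigmoid with binary cross-entropy for multi-label column type annotation, and softmax cross-entropy for the single-label case (these correspond to torch.nn.BCEWithLogitsLoss and torch.nn.CrossEntropyLoss). It is written in plain Python for one column, and the function name `cta_loss` is hypothetical, not part of matchbench.

```python
import math
from typing import Sequence, Union


def cta_loss(logits: Sequence[float], label: Union[int, Sequence[int]],
             multi_label: bool) -> float:
    """Hypothetical per-column loss; an assumption, not matchbench's implementation.

    multi_label=True:  `label` is a 0/1 vector; mean binary cross-entropy
                       over independent sigmoid outputs.
    multi_label=False: `label` is a class index; softmax cross-entropy.
    """
    if multi_label:
        total = 0.0
        for z, y in zip(logits, label):
            # Numerically stable BCE-with-logits: max(z, 0) - z*y + log(1 + exp(-|z|))
            total += max(z, 0.0) - z * y + math.log1p(math.exp(-abs(z)))
        return total / len(logits)
    # Softmax cross-entropy via log-sum-exp, shifted by the max logit for stability.
    m = max(logits)
    log_norm = m + math.log(sum(math.exp(z - m) for z in logits))
    return log_norm - logits[label]


print(cta_loss([0.0, 0.0], [1, 0], multi_label=True))   # log(2), about 0.693
print(cta_loss([2.0, 0.0, -1.0], 0, multi_label=False))
```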

compute_metric(prediction, label)
load_source_target(dataset_src, dataset_tgt, **kwargs)
training: bool
class matchbench.model.base_model.EAModel

Bases: GeneralModel

Base class for entity alignment models.

compute_metric(prediction, stage=1)

Calculate Hits@1 on the test set.

Parameters:

prediction (torch.Tensor) – the top-ranked candidate indexes for each test entity.

Returns:

hits@1

Return type:

float
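As an illustration of what Hits@1 measures, the sketch below assumes that `prediction` holds, for each test entity, candidate indexes ranked from most to least similar, and that the correct counterpart of each test entity is supplied as a separate `gold` list. How matchbench actually obtains the ground truth is not stated above, so `gold` and the helper `hits_at_k` are hypothetical; plain lists stand in for torch.Tensor.

```python
from typing import Sequence


def hits_at_k(prediction: Sequence[Sequence[int]], gold: Sequence[int], k: int = 1) -> float:
    """Hypothetical helper: fraction of test entities whose gold counterpart
    appears among their first k ranked candidates."""
    if len(prediction) != len(gold):
        raise ValueError("prediction and gold must have the same length")
    if not gold:
        return 0.0
    hits = sum(1 for cands, g in zip(prediction, gold) if g in list(cands)[:k])
    return hits / len(gold)


# Three test entities, candidates ranked best-first.
top = [[2, 0, 1], [1, 2, 0], [0, 1, 2]]
print(hits_at_k(top, gold=[2, 0, 0], k=1))  # entities 0 and 2 are correct at rank 1
```

Hits@k with k greater than 1 uses the same loop and only widens the slice of candidates that counts as a hit.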

compute_metric_stage_2(prediction)

Calculate Hits@1 for the stage-2 predictions.

Parameters:

prediction (torch.Tensor) – the top-ranked candidate indexes for each test entity.

Returns:

hits@1

Return type:

float

get_emb(loader, device=device(type='cuda')) → Tensor

Encode a list of entities, given as token IDs, into embeddings using the encoder.

Parameters:

loader (torch.utils.data.DataLoader)

device (torch.device, optional) – cuda or cpu. Defaults to “cuda”.

Returns:

the output embeddings of the encoder.

Return type:

torch.Tensor

get_emb_r(loader: DataLoader, model, rel_embedding, all_embed, mode, device=device(type='cuda')) → Tensor
load_source_target(dataset_src, dataset_tgt)

Prepare the source and target data.

Parameters:

dataset_src (datasets.arrow_dataset.Dataset) – source dataset.

dataset_tgt (datasets.arrow_dataset.Dataset) – target dataset.

predict(dataloader, device=device(type='cuda'), stage=1, train_dataloader=None)

Predict the closest entities for each entity in the test set.

Parameters:

dataloader (torch.utils.data.DataLoader)

Returns:

The indexes of the top-ranked (closest) entities.

Return type:

torch.Tensor
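The retrieval step behind predict can be sketched as a nearest-neighbour search over entity embeddings. The sketch assumes cosine similarity (the similarity measure is not stated above) and uses plain Python lists instead of torch tensors; `top_k_neighbors` is a hypothetical name. Its output has the shape that compute_metric consumes: one ranked list of candidate indexes per test entity.

```python
import math
from typing import List, Sequence


def _normalize(v: Sequence[float]) -> List[float]:
    norm = math.sqrt(sum(x * x for x in v)) or 1.0  # guard against zero vectors
    return [x / norm for x in v]


def top_k_neighbors(src: Sequence[Sequence[float]], tgt: Sequence[Sequence[float]],
                    k: int = 1) -> List[List[int]]:
    """Hypothetical sketch: for each source embedding, return the indexes of the
    k most cosine-similar target embeddings, best first."""
    tgt_n = [_normalize(t) for t in tgt]
    result = []
    for s in src:
        s_n = _normalize(s)
        sims = [sum(a * b for a, b in zip(s_n, t)) for t in tgt_n]
        ranked = sorted(range(len(tgt_n)), key=lambda j: sims[j], reverse=True)
        result.append(ranked[:k])
    return result


src = [[1.0, 0.0], [0.0, 1.0]]
tgt = [[0.0, 2.0], [3.0, 0.1], [1.0, 1.0]]
print(top_k_neighbors(src, tgt, k=2))  # [[1, 2], [0, 2]]
```

Normalizing once and taking dot products is equivalent to computing cosine similarity pairwise; on GPU the same idea is usually a single matrix product followed by a top-k.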

predict_stage_2(dataloader, device=device(type='cuda'))
training: bool
class matchbench.model.base_model.EMModel

Bases: GeneralModel

Base class for entity matching models.

calculate_thresold(all_y, all_probs)

Calculate the decision threshold for the positive class (label 1): if an example's predicted probability is greater than the threshold, its predicted label is 1.

Parameters:

all_y (list of int) – the labels of the dataset.

all_probs (list of float) – the predicted probabilities for the dataset.

Returns:

The threshold for the positive class.

Return type:

float

compute_metric(label, pred)

Compute the F1 score for entity matching.

Parameters:

label (list of int) – the true labels.

pred (list of int) – the labels predicted by the EM model.

Returns:

The f1 score.

Return type:

float
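For binary entity matching, assuming label 1 (match) is the positive class, the F1 score is the harmonic mean of precision and recall, F1 = 2PR / (P + R). The plain-Python sketch below computes it; `binary_f1` is a hypothetical name, and scikit-learn's `f1_score` with its default binary averaging computes the same quantity.

```python
from typing import Sequence


def binary_f1(label: Sequence[int], pred: Sequence[int]) -> float:
    """Hypothetical sketch: F1 score with 1 (match) as the positive class."""
    if len(label) != len(pred):
        raise ValueError("label and pred must have the same length")
    tp = sum(1 for y, p in zip(label, pred) if y == 1 and p == 1)
    fp = sum(1 for y, p in zip(label, pred) if y == 0 and p == 1)
    fn = sum(1 for y, p in zip(label, pred) if y == 1 and p == 0)
    precision = tp / (tp + fp) if tp + fp else 0.0
    recall = tp / (tp + fn) if tp + fn else 0.0
    if precision + recall == 0:
        return 0.0  # no true positives: F1 is defined as 0 here
    return 2 * precision * recall / (precision + recall)


# TP = 2, FP = 1, FN = 1, so precision = recall = 2/3 and F1 = 2/3.
print(binary_f1([1, 0, 1, 1, 0], [1, 1, 0, 1, 0]))
```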

get_sentence(table)

Get the sentences from a table.

Parameters:

table (datasets.arrow_dataset.Dataset) – the table to process.

Returns:

The list of sentences.

Return type:

List of str
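The serialization format used by get_sentence is not documented. One widely used format for entity matching, popularized by Ditto, writes each record as `COL <attribute> VAL <value>` pairs; the sketch below assumes that format. `serialize_rows` is hypothetical. It accepts any iterable of row dictionaries, which also covers a Hugging Face `Dataset`, since iterating over one yields a dict per row.

```python
from typing import Any, Iterable, List, Mapping


def serialize_rows(table: Iterable[Mapping[str, Any]]) -> List[str]:
    """Hypothetical sketch: turn each row into one Ditto-style
    'COL attr VAL value ...' sentence (format is an assumption)."""
    sentences = []
    for row in table:
        parts = [f"COL {attr} VAL {'' if value is None else value}"
                 for attr, value in row.items()]
        sentences.append(" ".join(parts))
    return sentences


rows = [{"title": "iPhone 13", "price": 799}]
print(serialize_rows(rows))  # ['COL title VAL iPhone 13 COL price VAL 799']
```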

load_source_target(data_src, data_tgt)

Prepare the source and target data.

Parameters:

data_src (datasets.dataset_dict.DatasetDict) – source dataset.

data_tgt (datasets.dataset_dict.DatasetDict) – target dataset.

training: bool
class matchbench.model.base_model.GeneralModel(**kwargs)

Bases: Module, PyTorchModelHubMixin

calculate_loss()

Compute the loss from the logits and the labels.

Parameters:

logits, label

Returns:

loss

encode()

For PLMs: serialization -> [CLS] embedding (trainable). For pretrained word embeddings (PWEs): pair -> embedding (not trainable).

forward()

self.encoder maps a PLM serialization to a feature vector, and self.matcher maps a feature vector to logits. For PLMs, forward runs encode followed by match; for PWEs, it runs match only.
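The encode/match split described above can be illustrated with a small framework-free class. The sketch only mirrors the described control flow (PLMs: encode, then match; PWEs: match on ready-made embeddings). The `uses_plm` flag, the class `ToyModel`, and the stub encoder and matcher are all hypothetical and are not GeneralModel's code.

```python
from typing import Callable, List, Sequence, Union


class ToyModel:
    """Hypothetical stand-in that mirrors the described encode/match flow."""

    def __init__(self, uses_plm: bool,
                 encoder: Callable[[str], List[float]],
                 matcher: Callable[[Sequence[float]], List[float]]) -> None:
        self.uses_plm = uses_plm  # hypothetical flag: PLM path vs. PWE path
        self.encoder = encoder    # serialization -> feature vector
        self.matcher = matcher    # feature vector -> logits

    def encode(self, serialized: str) -> List[float]:
        return self.encoder(serialized)

    def match(self, embedding: Sequence[float]) -> List[float]:
        return self.matcher(embedding)

    def forward(self, batch: Union[str, Sequence[float]]) -> List[float]:
        if self.uses_plm:
            # PLM: the input is a serialized pair, so encode first, then match.
            return self.match(self.encode(batch))
        # PWE: the input is already an embedding, so only match.
        return self.match(batch)


def toy_encoder(text: str) -> List[float]:
    return [float(len(text)), 1.0]


def toy_matcher(vec: Sequence[float]) -> List[float]:
    return [vec[0] - vec[1], vec[1] - vec[0]]  # two-class logits


print(ToyModel(True, toy_encoder, toy_matcher).forward("ab"))         # [1.0, -1.0]
print(ToyModel(False, toy_encoder, toy_matcher).forward([0.5, 2.0]))  # [-1.5, 1.5]
```

Keeping encode and match as separate methods lets the PWE path skip the encoder entirely while sharing the same matcher.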

match()

Map embeddings to logits (embedding -> logits).

predict()

Predict on a batch of data.

Returns:

prediction

prepare_dataloader(dataset, split='train', batch_size=32)

For PLMs, serialize the data; for PWEs, encode it.

run_step()

Run one step on a batch of data.

Returns:

loss

serialize()

For PLMs, convert a pair into a PLM serialization (pair -> PLM serialization). Not needed for pretrained word embeddings.
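For PLMs, serialize turns an entity pair into a single model input. A common convention for BERT-style models, assumed here, joins the two serialized records as `[CLS] left [SEP] right [SEP]`. In practice a Hugging Face tokenizer called with two text arguments inserts these special tokens itself. `serialize_pair` and its token defaults are hypothetical.

```python
def serialize_pair(left: str, right: str,
                   cls_token: str = "[CLS]", sep_token: str = "[SEP]") -> str:
    """Hypothetical sketch: join two serialized records into one BERT-style
    sequence-pair string (token strings are assumed defaults)."""
    return f"{cls_token} {left} {sep_token} {right} {sep_token}"


print(serialize_pair("COL title VAL iPhone 13", "COL title VAL Apple iPhone 13"))
```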

training: bool
class matchbench.model.base_model.SMModel

Bases: GeneralModel

compute_metric(match_results, ground_truth)
training: bool
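SMModel.compute_metric is undocumented. For schema matching, a typical evaluation, assumed here, treats match_results and ground_truth as collections of (source attribute, target attribute) pairs and reports precision, recall and F1 over their overlap. `schema_match_prf` is hypothetical.

```python
from typing import Dict, Iterable, Tuple

Pair = Tuple[str, str]  # (source attribute, target attribute), an assumed format


def schema_match_prf(match_results: Iterable[Pair],
                     ground_truth: Iterable[Pair]) -> Dict[str, float]:
    """Hypothetical sketch: set-based precision, recall and F1 for attribute matches."""
    predicted, gold = set(match_results), set(ground_truth)
    tp = len(predicted & gold)
    precision = tp / len(predicted) if predicted else 0.0
    recall = tp / len(gold) if gold else 0.0
    f1 = 2 * precision * recall / (precision + recall) if precision + recall else 0.0
    return {"precision": precision, "recall": recall, "f1": f1}


pred = [("name", "full_name"), ("dob", "birth_date"), ("zip", "city")]
gold = [("name", "full_name"), ("dob", "birth_date"), ("zip", "postcode")]
print(schema_match_prf(pred, gold))  # 2 of 3 pairs correct
```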
matchbench.model.base_model.f1_score_multilabel(true_list, pred_list)
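f1_score_multilabel is undocumented, and the averaging it uses is not stated. Assuming that true_list and pred_list are parallel lists whose elements are the label sets of each item, a micro-averaged F1 counts true positives, false positives and false negatives over all labels of all items, as in the sketch below. `micro_f1_multilabel` is hypothetical.

```python
from typing import Hashable, Iterable, Sequence


def micro_f1_multilabel(true_list: Sequence[Iterable[Hashable]],
                        pred_list: Sequence[Iterable[Hashable]]) -> float:
    """Hypothetical sketch: micro-averaged F1 over per-item label sets."""
    if len(true_list) != len(pred_list):
        raise ValueError("true_list and pred_list must have the same length")
    tp = fp = fn = 0
    for true_labels, pred_labels in zip(true_list, pred_list):
        t, p = set(true_labels), set(pred_labels)
        tp += len(t & p)   # labels predicted and correct
        fp += len(p - t)   # labels predicted but wrong
        fn += len(t - p)   # labels missed
    if tp == 0:
        return 0.0
    precision = tp / (tp + fp)
    recall = tp / (tp + fn)
    return 2 * precision * recall / (precision + recall)


true = [["person", "actor"], ["city"], ["year"]]
pred = [["person"], ["city", "country"], ["year"]]
print(micro_f1_multilabel(true, pred))  # TP = 3, FP = 1, FN = 1, so F1 = 0.75
```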

Module contents