matchbench.model package
Subpackages
- matchbench.model.column_type_annotation package
- matchbench.model.entity_alignment package
  - Submodules
    - matchbench.model.entity_alignment.Dual_AMN module
    - matchbench.model.entity_alignment.LargeEA module
    - matchbench.model.entity_alignment.bertint module
      - Basic_Bert_Unit_model
      - BertInt
      - MLP
      - PairwiseDataset
      - RelationalDataloader
      - all_entity_pairs_gene()
      - attributeValue_emb_gene()
      - attributeView_interaction_F_gene()
      - batch_dual_aggregation_feature_gene()
      - candidate_generate()
      - clean_attribute_data()
      - desornameView_interaction_F_gene()
      - dump_other_data()
      - ent2attributeValues_gene()
      - get_attributeValue_embedding()
      - get_attributeView_interaction_feature()
      - get_attribute_value_type()
      - get_entity_embedding()
      - get_neighView_and_desView_interaction_feature()
      - get_tokens_of_value()
      - kernel_mus()
      - kernel_sigmas()
      - neigh_ent_dict_gene()
      - neighborView_interaction_F_gene()
      - padding_to_longest()
      - read_att_data()
      - read_attribute_datas()
      - remove_one_to_N_att_data_by_threshold()
      - sort_a()
      - test_read_emb()
    - matchbench.model.entity_alignment.kecg module
    - matchbench.model.entity_alignment.lightea module
    - matchbench.model.entity_alignment.rrea module
    - matchbench.model.entity_alignment.sdea module
      - Basic_Bert_Unit_model
      - BertDataLoader
      - DBPpreprocess()
      - Dataset
      - GRUAttnNet
      - Highway
      - KBStore
        - KBStore.add_fact()
        - KBStore.add_item()
        - KBStore.add_to_blocks()
        - KBStore.add_tuple()
        - KBStore.add_word_level_blocks()
        - KBStore.calculate_func()
        - KBStore.get_or_add_item()
        - KBStore.get_property_table_line()
        - KBStore.load()
        - KBStore.load_entities()
        - KBStore.load_facts()
        - KBStore.load_kb()
        - KBStore.load_kb_from_saved()
        - KBStore.load_literals()
        - KBStore.load_properties()
        - KBStore.load_relations()
        - KBStore.save_base_info()
        - KBStore.save_datas()
        - KBStore.save_facts()
        - KBStore.save_property_table()
        - KBStore.save_seq_form()
      - OEAFileType
      - PairwiseDataset
      - RelationDataset
      - RelationModel
        - RelationModel.rnn
        - RelationModel.combiner
        - RelationModel.ent_embedding1
        - RelationModel.ent_embedding2
        - RelationModel.case_study()
        - RelationModel.forward()
        - RelationModel.get_emb()
        - RelationModel.get_ent_embedding()
        - RelationModel.get_neighbors_batch()
        - RelationModel.get_rel_embeds()
        - RelationModel.pos_neg_count()
        - RelationModel.training
      - RelationValidDataset
      - SDEA
        - SDEA.if_neg_sample_2
        - SDEA.all_embed1s
        - SDEA.all_embed2s
        - SDEA.entity_mode
        - SDEA.a()
        - SDEA.calculate_loss()
        - SDEA.class_name_str()
        - SDEA.encode()
        - SDEA.forward()
        - SDEA.get_tensor_data()
        - SDEA.load_links()
        - SDEA.load_links_sep()
        - SDEA.load_source_target()
        - SDEA.negative_sample()
        - SDEA.oea_truth_line()
        - SDEA.prepare_dataloader()
        - SDEA.reduce_tokens()
        - SDEA.reduce_tokens_with_freq()
        - SDEA.run_step()
        - SDEA.training
      - compress_uri()
      - load_list()
      - load_list_p()
      - oea_attr_line()
      - oea_rel_line()
      - save_dict_p()
      - save_list()
      - save_list_p()
      - stripSquareBrackets()
      - strip_square_brackets()
      - text_to_word_sequence()
      - ttl_no_compress_line()
    - matchbench.model.entity_alignment.seu module
  - Module contents
- matchbench.model.entity_matching package
  - Submodules
    - matchbench.model.entity_matching.deepmatcher module
    - matchbench.model.entity_matching.ditto module
    - matchbench.model.entity_matching.jointbert module
    - matchbench.model.entity_matching.robem module
      - ASLSingleLabel
      - AugMode
      - BasicAug
      - EmDataset
      - Highway
      - RobEM
        - RobEM.context_forward()
        - RobEM.context_similarity_layers()
        - RobEM.forward()
        - RobEM.get_lm()
        - RobEM.get_lm_class()
        - RobEM.get_lm_dim()
        - RobEM.get_tokenizers()
        - RobEM.has_type_token()
        - RobEM.load_source_target()
        - RobEM.predict()
        - RobEM.prepare_dataloader()
        - RobEM.reset_weights()
        - RobEM.resize_embedding()
        - RobEM.run_step()
        - RobEM.training
      - RobertaClassificationHead
      - RobustAugmenter
      - SimpleClassifier
      - Summarizer
      - cosine_similarity()
      - set_to_device()
    - matchbench.model.entity_matching.rotom module
  - Module contents
- matchbench.model.schema_matching package
  - Submodules
    - matchbench.model.schema_matching.embdi module
      - Edge
      - EdgeList
        - EdgeList.convert_cell_value()
        - EdgeList.convert_to_dict()
        - EdgeList.convert_to_numeric()
        - EdgeList.evaluate_frequencies()
        - EdgeList.f_no_smoothing()
        - EdgeList.find_intersection_flatten()
        - EdgeList.get_edgelist()
        - EdgeList.get_prefixes()
        - EdgeList.inverse_freq()
        - EdgeList.inverse_smooth()
        - EdgeList.log_freq()
        - EdgeList.prepare_split()
        - EdgeList.smooth_exp()
        - EdgeList.smooth_freq()
      - Embdi
      - Graph
      - Node
      - RandomWalk
  - Module contents
Submodules
matchbench.model.base_model module
- class matchbench.model.base_model.CTAModel(multi_label, loss_function)
Bases: GeneralModel
- calculate_loss(logits, label)
Compute the loss from the logits and the labels.
- Parameters:
logits, label
- Returns:
loss
- compute_metric(prediction, label)
- load_source_target(dataset_src, dataset_tgt, **kwargs)
- class matchbench.model.base_model.EAModel
Bases: GeneralModel
Base class for entity alignment models.
- compute_metric(prediction, stage=1)
Calculate hits@1.
- Parameters:
prediction (torch.Tensor): the top-ranked candidate indexes for each test entity.
- Returns:
hits@1
- Return type:
float
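A minimal sketch of the hits@1 computation, in plain Python. It assumes that prediction[i] lists candidate target indexes for test entity i, best first, and that, unless explicit gold indexes are passed, the gold counterpart of test entity i is target i. Both conventions and the name hits_at_1 are assumptions for illustration, not taken from the MatchBench source.

```python
from typing import Optional, Sequence


def hits_at_1(prediction: Sequence[Sequence[int]],
              gold: Optional[Sequence[int]] = None) -> float:
    """Fraction of test entities whose top-ranked candidate is the gold target."""
    if len(prediction) == 0:
        return 0.0
    # Assumed convention: without explicit gold indexes, entity i aligns with target i.
    if gold is None:
        gold = range(len(prediction))
    # int() also accepts 0-d tensor elements, so tensor rows work too.
    hits = sum(1 for row, g in zip(prediction, gold) if int(row[0]) == int(g))
    return hits / len(prediction)


# Entities 0 and 2 rank their gold target first; entity 1 does not.
print(hits_at_1([[0, 5], [3, 1], [2, 4]]))
```

Only the first column of each row matters for hits@1; hits@k would check whether the gold index appears in the first k entries instead.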
- compute_metric_stage_2(prediction)
Calculate hits@1.
- Parameters:
prediction (torch.Tensor): the top-ranked candidate indexes for each test entity.
- Returns:
hits@1
- Return type:
float
- get_emb(loader, device=device(type='cuda')) → Tensor
Convert lists of entity token IDs into embeddings produced by the encoder.
- Parameters:
loader (torch.utils.data.DataLoader)
device (torch.device, optional): cuda or cpu. Defaults to "cuda".
- Returns:
the output embeddings of the encoder.
- Return type:
torch.Tensor
- get_emb_r(loader: DataLoader, model, rel_embedding, all_embed, mode, device=device(type='cuda')) → Tensor
- load_source_target(dataset_src, dataset_tgt)
Prepare source and target data.
- Parameters:
dataset_src (datasets.arrow_dataset.Dataset): source dataset.
dataset_tgt (datasets.arrow_dataset.Dataset): target dataset.
- predict(dataloader, device=device(type='cuda'), stage=1, train_dataloader=None)
Predict the closest candidate entities for each entity in the test set.
- Parameters:
dataloader (torch.utils.data.DataLoader)
- Returns:
the top-ranked candidate indexes.
- Return type:
torch.Tensor
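The similarity measure behind predict is not documented here. The sketch below assumes cosine similarity between source and target entity embeddings (such as the encoder outputs returned by get_emb), uses plain Python lists in place of tensors, and returns ranked candidate indexes in the shape compute_metric consumes. The names cosine and top_k_candidates are hypothetical.

```python
import math
from typing import List, Sequence


def cosine(u: Sequence[float], v: Sequence[float]) -> float:
    """Cosine similarity; returns 0.0 if either vector is all zeros."""
    dot = sum(a * b for a, b in zip(u, v))
    norm = math.sqrt(sum(a * a for a in u)) * math.sqrt(sum(b * b for b in v))
    return dot / norm if norm else 0.0


def top_k_candidates(src_emb: Sequence[Sequence[float]],
                     tgt_emb: Sequence[Sequence[float]],
                     k: int = 1) -> List[List[int]]:
    """For each source entity, return indexes of its k most similar targets, best first."""
    result = []
    for s in src_emb:
        scores = [cosine(s, t) for t in tgt_emb]
        ranked = sorted(range(len(tgt_emb)), key=lambda j: scores[j], reverse=True)
        result.append(ranked[:k])
    return result


src = [[1.0, 0.0], [0.0, 1.0]]
tgt = [[0.0, 2.0], [3.0, 0.0], [1.0, 1.0]]
print(top_k_candidates(src, tgt, k=2))  # [[1, 2], [0, 2]]
```

This brute-force loop compares every source entity with every target; with real tensors the same ranking is usually computed as one normalized matrix product followed by a top-k.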
- predict_stage_2(dataloader, device=device(type='cuda'))
- class matchbench.model.base_model.EMModel
Bases: GeneralModel
The basic model for entity matching.
- calculate_thresold(all_y, all_probs)
Calculate the decision threshold for class 1: if the predicted probability is greater than the threshold, the predicted label is 1.
- Parameters:
all_y (list of int): the labels of the dataset.
all_probs (list of float): the predicted probabilities for the dataset.
- Returns:
The threshold for class 1.
- Return type:
float
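The docstring does not say how the threshold is chosen. One common approach, sketched here as an assumption, sweeps a grid of thresholds in [0, 1] and keeps the first one that maximizes F1 on class 1, labeling a sample 1 only when its probability is strictly greater than the threshold. The step size and the names f1_at and best_threshold are hypothetical.

```python
from typing import Sequence


def f1_at(all_y: Sequence[int], all_probs: Sequence[float], th: float) -> float:
    """F1 on class 1 when a sample is labeled 1 iff its probability exceeds th."""
    tp = fp = fn = 0
    for y, p in zip(all_y, all_probs):
        pred = 1 if p > th else 0
        if pred == 1 and y == 1:
            tp += 1
        elif pred == 1:
            fp += 1
        elif y == 1:
            fn += 1
    if tp == 0:
        return 0.0
    precision = tp / (tp + fp)
    recall = tp / (tp + fn)
    return 2 * precision * recall / (precision + recall)


def best_threshold(all_y: Sequence[int], all_probs: Sequence[float],
                   step: float = 0.05) -> float:
    """Grid-search thresholds in [0, 1]; keep the first one with the highest F1."""
    best_th, best_f1 = 0.5, -1.0
    for i in range(int(round(1 / step)) + 1):
        th = i * step
        score = f1_at(all_y, all_probs, th)
        if score > best_f1:
            best_th, best_f1 = th, score
    return best_th


labels = [0, 0, 1, 1]
probs = [0.12, 0.42, 0.37, 0.83]
th = best_threshold(labels, probs)
print(th, f1_at(labels, probs, th))
```

Tuning the threshold on a validation split and then reusing it on the test split avoids fitting it to the data being evaluated.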
- compute_metric(label, pred)
Compute the metric for entity matching.
- Parameters:
label (list of int): the true labels.
pred (list of int): the labels predicted by the EM model.
- Returns:
The F1 score.
- Return type:
float
- get_sentence(table)
Get sentences from a table.
- Parameters:
table (datasets.arrow_dataset.Dataset): the table to process.
- Returns:
The list of sentences.
- Return type:
list of str
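The sentence format that get_sentence produces is not documented here. As an assumption, the sketch below uses a Ditto-style serialization, one "COL <attribute> VAL <value>" segment per column, over rows given as dicts (iterating a datasets.Dataset also yields one dict per row). The function name serialize_rows and the default of skipping an id column are hypothetical.

```python
from typing import Dict, Iterable, List, Optional


def serialize_rows(rows: Iterable[Dict[str, object]],
                   skip: Optional[set] = None) -> List[str]:
    """Turn each row into one 'COL <attr> VAL <value>' sentence (Ditto-style)."""
    skip = {"id"} if skip is None else skip
    sentences = []
    for row in rows:
        parts = []
        for col, val in row.items():
            if col in skip:
                continue
            # Missing values become an empty VAL segment.
            text = "" if val is None else str(val)
            parts.append(f"COL {col} VAL {text}".rstrip())
        sentences.append(" ".join(parts))
    return sentences


rows = [{"id": 0, "title": "iPhone 13", "price": 799},
        {"id": 1, "title": "Galaxy S21", "price": None}]
for sentence in serialize_rows(rows):
    print(sentence)
```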
- load_source_target(data_src, data_tgt)
Prepare source and target data.
- Parameters:
data_src (datasets.dataset_dict.DatasetDict): source dataset.
data_tgt (datasets.dataset_dict.DatasetDict): target dataset.
- class matchbench.model.base_model.GeneralModel(**kwargs)
Bases: Module, PyTorchModelHubMixin
- calculate_loss()
Compute the loss from the logits and the labels.
- Parameters:
logits, label
- Returns:
loss
- encode()
For PLMs, map the serialization to its [CLS] embedding (trainable). For pretrained word embeddings (PWEs), map the pair to an embedding (not trainable).
- forward()
self.encoder maps a PLM serialization to a feature vector, and self.matcher maps a feature vector to logits. For PLMs, forward runs encode followed by match; for PWEs, it runs match only.
- match()
Map an embedding to logits.
- predict()
- Parameters:
batch data
- Returns:
prediction
- prepare_dataloader(dataset, split='train', batch_size=32)
Build the data loader. For PLMs, the data is serialized; for PWEs, it is encoded.
- run_step()
- Parameters:
batch data
- Returns:
loss
- serialize()
For PLMs, convert a pair into the PLM serialization. Not needed for pretrained word embeddings.
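The GeneralModel hooks split the work differently for pretrained language models (PLMs) and pretrained word embeddings (PWEs). The toy sketch below mirrors that split in plain Python: PLM inputs are serialized during data preparation and encoded inside forward, while PWE inputs are encoded once up front and forward only matches. SketchModel, ToyOverlapModel, the Jaccard-overlap "encoder", and the linear "matcher" are hypothetical stand-ins, not MatchBench code, and no training happens here.

```python
from abc import ABC, abstractmethod
from typing import List, Sequence, Tuple

Pair = Tuple[str, str]


class SketchModel(ABC):
    """Hypothetical stand-in for GeneralModel's hook layout (not the real class)."""

    def __init__(self, uses_plm: bool) -> None:
        self.uses_plm = uses_plm

    def serialize(self, pair: Pair) -> str:
        # PLM path only: turn a record pair into a single input string.
        left, right = pair
        return f"{left} [SEP] {right}"

    @abstractmethod
    def encode(self, item) -> List[float]:
        """Map a serialized string (PLM) or a raw pair (PWE) to a feature vector."""

    @abstractmethod
    def match(self, feature: Sequence[float]) -> float:
        """Map a feature vector to a logit."""

    def prepare(self, pairs: Sequence[Pair]) -> list:
        # PLMs: serialize now, encode later inside forward (trainable encoder).
        # PWEs: encode once now with frozen embeddings.
        if self.uses_plm:
            return [self.serialize(p) for p in pairs]
        return [self.encode(p) for p in pairs]

    def forward(self, batch: Sequence) -> List[float]:
        # PLMs: encode + match; PWEs: inputs are already features, so match only.
        if self.uses_plm:
            return [self.match(self.encode(x)) for x in batch]
        return [self.match(x) for x in batch]


class ToyOverlapModel(SketchModel):
    def encode(self, item) -> List[float]:
        # Toy "encoder": Jaccard overlap of the token sets of the two sides.
        left, right = item.split(" [SEP] ") if isinstance(item, str) else item
        a, b = set(left.split()), set(right.split())
        union = a | b
        return [len(a & b) / len(union) if union else 0.0]

    def match(self, feature: Sequence[float]) -> float:
        # Toy linear "matcher": a positive logit means "match".
        return 4.0 * feature[0] - 2.0


pairs = [("apple iphone 13", "iphone 13 apple"), ("apple iphone", "samsung galaxy")]
for mode in (True, False):
    model = ToyOverlapModel(uses_plm=mode)
    print(model.forward(model.prepare(pairs)))  # [2.0, -2.0] in both modes
```

Both paths reach the same logits here because the toy encoder has no trainable state; in a real PLM model, keeping encode inside forward is what lets gradients reach the encoder.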
- class matchbench.model.base_model.SMModel
Bases: GeneralModel
- compute_metric(match_results, ground_truth)
- matchbench.model.base_model.f1_score_multilabel(true_list, pred_list)
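f1_score_multilabel has no docstring. Assuming that true_list and pred_list hold one collection of labels per sample (as in multi-label column type annotation) and that the score is micro-averaged over all labels, a sketch could look like the following. The averaging scheme and the name f1_multilabel_micro are assumptions.

```python
from typing import Iterable, Sequence


def f1_multilabel_micro(true_list: Sequence[Iterable[str]],
                        pred_list: Sequence[Iterable[str]]) -> float:
    """Micro-averaged F1: pool label-level TP/FP/FN counts across all samples."""
    tp = fp = fn = 0
    for true, pred in zip(true_list, pred_list):
        t, p = set(true), set(pred)
        tp += len(t & p)   # labels predicted and correct
        fp += len(p - t)   # labels predicted but wrong
        fn += len(t - p)   # correct labels that were missed
    if tp == 0:
        return 0.0
    precision = tp / (tp + fp)
    recall = tp / (tp + fn)
    return 2 * precision * recall / (precision + recall)


# Sample 1 misses "b"; sample 2 adds a spurious "d".
print(f1_multilabel_micro([["a", "b"], ["c"]], [["a"], ["c", "d"]]))
```

Macro-averaging (computing F1 per label and then averaging) would weight rare types more heavily; which variant MatchBench uses is not stated here.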