matchbench.model package
Subpackages
- matchbench.model.column_type_annotation package
- matchbench.model.entity_alignment package
- Submodules
- matchbench.model.entity_alignment.Dual_AMN module
- matchbench.model.entity_alignment.LargeEA module
- matchbench.model.entity_alignment.bertint module
  - Basic_Bert_Unit_model, BertInt, MLP, PairwiseDataset, RelationalDataloader
  - all_entity_pairs_gene(), attributeValue_emb_gene(), attributeView_interaction_F_gene(), batch_dual_aggregation_feature_gene(), candidate_generate(), clean_attribute_data(), desornameView_interaction_F_gene(), dump_other_data(), ent2attributeValues_gene(), get_attributeValue_embedding(), get_attributeView_interaction_feature(), get_attribute_value_type(), get_entity_embedding(), get_neighView_and_desView_interaction_feature(), get_tokens_of_value(), kernel_mus(), kernel_sigmas(), neigh_ent_dict_gene(), neighborView_interaction_F_gene(), padding_to_longest(), read_att_data(), read_attribute_datas(), remove_one_to_N_att_data_by_threshold(), sort_a(), test_read_emb()
- matchbench.model.entity_alignment.kecg module
- matchbench.model.entity_alignment.lightea module
- matchbench.model.entity_alignment.rrea module
- matchbench.model.entity_alignment.sdea module
  - Basic_Bert_Unit_model, BertDataLoader, DBPpreprocess(), Dataset, GRUAttnNet, Highway
  - KBStore, KBStore.add_fact(), KBStore.add_item(), KBStore.add_to_blocks(), KBStore.add_tuple(), KBStore.add_word_level_blocks(), KBStore.calculate_func(), KBStore.get_or_add_item(), KBStore.get_property_table_line(), KBStore.load(), KBStore.load_entities(), KBStore.load_facts(), KBStore.load_kb(), KBStore.load_kb_from_saved(), KBStore.load_literals(), KBStore.load_properties(), KBStore.load_relations(), KBStore.save_base_info(), KBStore.save_datas(), KBStore.save_facts(), KBStore.save_property_table(), KBStore.save_seq_form()
  - OEAFileType, PairwiseDataset, RelationDataset
  - RelationModel, RelationModel.rnn, RelationModel.combiner, RelationModel.ent_embedding1, RelationModel.ent_embedding2, RelationModel.case_study(), RelationModel.forward(), RelationModel.get_emb(), RelationModel.get_ent_embedding(), RelationModel.get_neighbors_batch(), RelationModel.get_rel_embeds(), RelationModel.pos_neg_count(), RelationModel.training
  - RelationValidDataset
  - SDEA, SDEA.if_neg_sample_2, SDEA.all_embed1s, SDEA.all_embed2s, SDEA.entity_mode, SDEA.a(), SDEA.calculate_loss(), SDEA.class_name_str(), SDEA.encode(), SDEA.forward(), SDEA.get_tensor_data(), SDEA.load_links(), SDEA.load_links_sep(), SDEA.load_source_target(), SDEA.negative_sample(), SDEA.oea_truth_line(), SDEA.prepare_dataloader(), SDEA.reduce_tokens(), SDEA.reduce_tokens_with_freq(), SDEA.run_step(), SDEA.training
  - compress_uri(), load_list(), load_list_p(), oea_attr_line(), oea_rel_line(), save_dict_p(), save_list(), save_list_p(), stripSquareBrackets(), strip_square_brackets(), text_to_word_sequence(), ttl_no_compress_line()
- matchbench.model.entity_alignment.seu module
- Module contents
- matchbench.model.entity_matching package
- Submodules
- matchbench.model.entity_matching.deepmatcher module
- matchbench.model.entity_matching.ditto module
- matchbench.model.entity_matching.jointbert module
- matchbench.model.entity_matching.robem module
  - ASLSingleLabel, AugMode, BasicAug, EmDataset, Highway
  - RobEM, RobEM.context_forward(), RobEM.context_similarity_layers(), RobEM.forward(), RobEM.get_lm(), RobEM.get_lm_class(), RobEM.get_lm_dim(), RobEM.get_tokenizers(), RobEM.has_type_token(), RobEM.load_source_target(), RobEM.predict(), RobEM.prepare_dataloader(), RobEM.reset_weights(), RobEM.resize_embedding(), RobEM.run_step(), RobEM.training
  - RobertaClassificationHead, RobustAugmenter, SimpleClassifier, Summarizer, cosine_similarity(), set_to_device()
- matchbench.model.entity_matching.rotom module
- Module contents
- matchbench.model.schema_matching package
- Submodules
- matchbench.model.schema_matching.embdi module
  - Edge
  - EdgeList, EdgeList.convert_cell_value(), EdgeList.convert_to_dict(), EdgeList.convert_to_numeric(), EdgeList.evaluate_frequencies(), EdgeList.f_no_smoothing(), EdgeList.find_intersection_flatten(), EdgeList.get_edgelist(), EdgeList.get_prefixes(), EdgeList.inverse_freq(), EdgeList.inverse_smooth(), EdgeList.log_freq(), EdgeList.prepare_split(), EdgeList.smooth_exp(), EdgeList.smooth_freq()
  - Embdi, Graph, Node, RandomWalk
- Module contents
Submodules
matchbench.model.base_model module
- class matchbench.model.base_model.CTAModel(multi_label, loss_function)
Bases: GeneralModel
- calculate_loss(logits, label)
- Parameters:
logits, label
- Returns:
loss
- compute_metric(prediction, label)
- load_source_target(dataset_src, dataset_tgt, **kwargs)
- class matchbench.model.base_model.EAModel
Bases: GeneralModel
Base class for entity alignment models.
- compute_metric(prediction, stage=1)
Calculate the hits.
- Parameters:
prediction (torch.tensor): the top indexes of each testing entity.
- Returns:
hits@1
- Return type:
float
- compute_metric_stage_2(prediction)
Calculate the hits.
- Parameters:
prediction (torch.tensor): the top indexes of each testing entity.
- Returns:
hits@1
- Return type:
float
- get_emb(loader, device=device(type='cuda')) → Tensor
Convert a list of entity token IDs into a list of embeddings produced by the encoder.
- Parameters:
loader (torch.utils.data.DataLoader)
device (torch.device, optional, defaults to “cuda”): cuda or cpu.
- Returns:
the output embeddings of the encoder.
- Return type:
torch.tensor
- get_emb_r(loader: DataLoader, model, rel_embedding, all_embed, mode, device=device(type='cuda')) → Tensor
- load_source_target(dataset_src, dataset_tgt)
Prepare source and target data.
- Parameters:
dataset_src (datasets.arrow_dataset.Dataset): Source dataset.
dataset_tgt (datasets.arrow_dataset.Dataset): Target dataset.
- predict(dataloader, device=device(type='cuda'), stage=1, train_dataloader=None)
Predict the closest entities for each entity in the test set.
- Parameters:
dataloader (torch.utils.data.DataLoader)
- Returns:
top indexes.
- Return type:
torch.tensor
- predict_stage_2(dataloader, device=device(type='cuda'))
- class matchbench.model.base_model.EMModel
Bases: GeneralModel
The basic model for entity matching.
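The predict and compute_metric entries of EAModel above describe ranking candidate entities and scoring the ranking with hits@1. The following is a minimal pure-Python sketch of that idea, not the matchbench implementation. It assumes cosine similarity between embeddings and a gold list in which gold[i] is the index of the true counterpart of test entity i; the helper names cosine, top_k_indexes and hits_at_1 are hypothetical.

```python
import math


def cosine(u, v):
    """Cosine similarity of two equal-length vectors (0.0 if either is all zeros)."""
    dot = sum(a * b for a, b in zip(u, v))
    norm_u = math.sqrt(sum(a * a for a in u))
    norm_v = math.sqrt(sum(b * b for b in v))
    return dot / (norm_u * norm_v) if norm_u and norm_v else 0.0


def top_k_indexes(src_emb, tgt_emb, k=1):
    """For each source embedding, return the indexes of the k most similar targets."""
    result = []
    for u in src_emb:
        scores = [cosine(u, v) for v in tgt_emb]
        ranked = sorted(range(len(tgt_emb)), key=lambda j: scores[j], reverse=True)
        result.append(ranked[:k])
    return result


def hits_at_1(top_indexes, gold):
    """Fraction of test entities whose top-ranked candidate is the gold counterpart."""
    if not gold:
        return 0.0
    hits = sum(1 for row, g in zip(top_indexes, gold) if row and row[0] == g)
    return hits / len(gold)


# Toy data (made up): three source entities, three target entities.
src = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]
tgt = [[0.9, 0.1], [0.1, 0.9], [-1.0, 0.2]]
gold = [0, 1, 2]  # assumption: gold[i] is the true target index for source i

top = top_k_indexes(src, tgt, k=1)
score = hits_at_1(top, gold)  # the first two entities are ranked correctly, the third is not
```

In the library itself the embeddings would come from get_emb and the ranking would run on tensors; the list-based version here only illustrates the metric.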
- calculate_thresold(all_y, all_probs)
Calculate the threshold for class 1. If the predicted probability is greater than the threshold, the predicted label is 1.
- Parameters:
all_y (List of int): The labels of the dataset.
all_probs (List of double): The predicted probabilities for the dataset.
- Returns:
The threshold on 1-class.
- Return type:
double
- compute_metric(label, pred)
Compute the metric for entity matching.
- Parameters:
label (List of int): The true labels.
pred (List of int): The labels predicted by the EM model.
- Returns:
The f1 score.
- Return type:
float
- get_sentence(table)
Get sentences from a table.
- Parameters:
table (datasets.arrow_dataset.Dataset): Table to be processed.
- Returns:
The List of sentences.
- Return type:
List of str
- load_source_target(data_src, data_tgt)
Prepare source and target data.
- Parameters:
data_src (datasets.dataset_dict.DatasetDict): Source dataset.
data_tgt (datasets.dataset_dict.DatasetDict): Target dataset.
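The EMModel entries above mention a class-1 threshold and an F1 metric but do not say how the threshold is chosen. One common approach, shown here purely as an assumption, is to sweep candidate thresholds and keep the one with the best F1. The names f1_score and pick_threshold are hypothetical helpers, not matchbench functions.

```python
def f1_score(labels, preds):
    """Binary F1 for 0/1 labels; returns 0.0 when there are no true positives."""
    tp = sum(1 for y, p in zip(labels, preds) if y == 1 and p == 1)
    fp = sum(1 for y, p in zip(labels, preds) if y == 0 and p == 1)
    fn = sum(1 for y, p in zip(labels, preds) if y == 1 and p == 0)
    if tp == 0:
        return 0.0
    precision = tp / (tp + fp)
    recall = tp / (tp + fn)
    return 2 * precision * recall / (precision + recall)


def pick_threshold(all_y, all_probs, candidates=None):
    """Return (threshold, f1) for the candidate threshold with the highest F1.

    A probability strictly greater than the threshold is predicted as class 1,
    matching the rule stated in the documentation. Ties keep the lowest candidate.
    """
    if candidates is None:
        candidates = [i / 100 for i in range(0, 100, 5)]  # 0.00, 0.05, ..., 0.95
    best_th, best_f1 = 0.5, -1.0
    for th in candidates:
        preds = [1 if p > th else 0 for p in all_probs]
        score = f1_score(all_y, preds)
        if score > best_f1:
            best_th, best_f1 = th, score
    return best_th, best_f1


# Toy validation data (made up).
all_y = [1, 0, 1, 0, 1]
all_probs = [0.92, 0.40, 0.55, 0.62, 0.71]
threshold, best = pick_threshold(all_y, all_probs)
```

Choosing the threshold on a validation split rather than the test split keeps the reported F1 honest.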
- class matchbench.model.base_model.GeneralModel(**kwargs)
Bases: Module, PyTorchModelHubMixin
- calculate_loss()
- Parameters:
logits, label
- Returns:
loss
- encode()
For PLMs: serialization -> CLS embedding (trainable).
For pretrained word embeddings (PWEs): pair -> embedding (untrainable).
- forward()
self.encoder: PLM serialization -> feature vector.
self.matcher: feature vector -> logits.
For PLMs: encode + match. For PWEs: match.
- match()
embedding -> logits
- predict()
- Parameters:
batch data
- Returns:
prediction
- prepare_dataloader(dataset, split='train', batch_size=32)
For PLMs: serialize. For PWEs: encode.
- run_step()
- Parameters:
batch data
- Returns:
loss
- serialize()
For PLMs: pair -> PLM serialization. For pretrained word embeddings: not needed.
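The GeneralModel methods above form a pipeline: serialize a record pair, encode it, match the embedding to logits, and predict from the logits. The toy class below sketches that flow in plain Python so the division of labor is visible. It is not a GeneralModel subclass and uses no PyTorch; the "COL name VAL value" serialization is an assumed Ditto-style format, and the token-overlap "encoder" is a stand-in for a real PLM.

```python
class ToyMatcher:
    """Toy stand-in that mirrors the serialize -> encode -> match -> predict flow."""

    def serialize(self, left, right):
        # Assumed Ditto-style serialization: "COL <name> VAL <value>" per attribute.
        def one(record):
            return " ".join(f"COL {k} VAL {v}" for k, v in record.items())

        return f"{one(left)} [SEP] {one(right)}"

    def encode(self, text):
        # Stand-in for a PLM's CLS embedding: token-overlap features of the two sides.
        left, right = text.split(" [SEP] ", 1)
        a, b = set(left.lower().split()), set(right.lower().split())
        union = a | b
        jaccard = len(a & b) / len(union) if union else 0.0
        return [jaccard, 1.0 - jaccard]

    def match(self, embedding):
        # embedding -> logits, ordered as [non-match, match].
        jaccard, distance = embedding
        return [distance, jaccard]

    def forward(self, text):
        # PLM-style path from the documentation: encode + match.
        return self.match(self.encode(text))

    def predict(self, left, right):
        # Serialize first (the library does this while preparing the dataloader).
        logits = self.forward(self.serialize(left, right))
        return max(range(len(logits)), key=lambda i: logits[i])


model = ToyMatcher()
same = model.predict({"title": "iPhone 12"}, {"title": "iphone 12"})
different = model.predict({"title": "iPhone 12"}, {"name": "Galaxy S21"})
```

Keeping serialize, encode and match separate is what lets the PWE path skip serialization and precompute embeddings, as the method notes above describe.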
- class matchbench.model.base_model.SMModel
Bases: GeneralModel
- compute_metric(match_results, ground_truth)
- matchbench.model.base_model.f1_score_multilabel(true_list, pred_list)