# custom-gcn / configuration_gcn.py
# Provenance (from the hosting page): last commit 40fcca6 "Upload model"
# by Huhujingjing; file size 920 bytes.
from transformers import PretrainedConfig
class GCNConfig(PretrainedConfig):
    """Configuration object holding the hyper-parameters of the GCN model.

    The GCN-specific sizes are stored as plain attributes. Every other keyword
    argument is forwarded unchanged to ``PretrainedConfig.__init__``.

    Args:
        input_feature: Dimension of the input features.
        emb_input: Embedding dimension of the input features.
        hidden_size: Hidden size of the GCN.
        n_layers: Number of GCN layers.
        num_classes: Number of output classes.
        **kwargs: Extra options passed straight through to ``PretrainedConfig``.
    """

    # Model-type identifier string for this configuration class.
    model_type = "gcn"

    def __init__(
        self,
        input_feature: int = 64,
        emb_input: int = 20,
        hidden_size: int = 64,
        n_layers: int = 6,
        num_classes: int = 1,
        **kwargs,
    ):
        # The model-specific attributes are assigned before the base-class
        # constructor runs, matching the original statement order.
        self.input_feature = input_feature
        self.emb_input = emb_input
        self.hidden_size = hidden_size
        self.n_layers = n_layers
        self.num_classes = num_classes

        # Remaining (generic) options are handled by PretrainedConfig.
        super().__init__(**kwargs)
if __name__ == "__main__":
    # Script entry point: build a GCNConfig with explicitly spelled-out
    # hyper-parameters and save it under the "custom-gcn" path.
    hyperparams = {
        "input_feature": 64,
        "emb_input": 20,
        "hidden_size": 64,
        "n_layers": 6,
        "num_classes": 1,
    }
    gcn_config = GCNConfig(**hyperparams)
    gcn_config.save_pretrained("custom-gcn")