|
|
from transformers import PretrainedConfig |
|
|
|
|
|
class GCNConfig(PretrainedConfig):
    """Hugging Face configuration class registered as ``model_type = "gcn"``.

    Holds five GCN hyperparameters as plain instance attributes. Every other
    keyword argument is handed unchanged to ``PretrainedConfig.__init__``, so the
    generic configuration options still apply and the object can be persisted
    with ``save_pretrained``.
    """

    # Model-type identifier for this configuration class.
    model_type = "gcn"

    def __init__(
        self,
        input_feature: int = 64,
        emb_input: int = 20,
        hidden_size: int = 64,
        n_layers: int = 6,
        num_classes: int = 1,
        **kwargs,
    ) -> None:
        """Record the GCN hyperparameters, then initialise the base config.

        Args:
            input_feature: Stored as ``self.input_feature``; presumably the
                per-node input feature size (confirm against the model code).
            emb_input: Stored as ``self.emb_input``; its meaning is not visible
                here (presumably an embedding input size — TODO confirm).
            hidden_size: Stored as ``self.hidden_size``; presumably the width of
                the hidden layers.
            n_layers: Stored as ``self.n_layers``; presumably the number of GCN
                layers.
            num_classes: Stored as ``self.num_classes``; presumably the number
                of model outputs.
            **kwargs: Forwarded as-is to ``PretrainedConfig.__init__``.

        No range or type checking is performed on any of the values.
        """
        # Model-specific fields go on the instance first; the base class then
        # consumes whatever generic options remain in ``kwargs``.
        self.input_feature = input_feature
        self.emb_input = emb_input
        self.hidden_size = hidden_size
        self.n_layers = n_layers
        self.num_classes = num_classes

        super().__init__(**kwargs)
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Instantiate the config with explicit hyperparameters (these match the
    # constructor defaults) and persist it under the relative path
    # "custom-gcn" via PretrainedConfig.save_pretrained.
    gcn_config = GCNConfig(
        input_feature=64,
        emb_input=20,
        hidden_size=64,
        n_layers=6,
        num_classes=1,
    )
    gcn_config.save_pretrained("custom-gcn")