"""Registry of predictor names and named groups of predictor specs.

A predictor *spec* is a string of one or more short names from
``BASIC_PREDICTOR_MAP`` joined by ``'+'`` (for example ``'ev+onehot'``).
``get_predictor_cls`` resolves a spec to the list of classes it names, and
``get_predictor_names`` expands a group keyword into a list of specs.
"""
from predictors import base_predictors, ev_predictors, hmm_predictors, onehot_predictors
from predictors import unirep_predictors, esm_predictors, vae_predictors
# NOTE(review): these two names are not used in this module; they are kept
# because other modules may import them from here -- confirm before removing.
from predictors.base_predictors import BoostingPredictor, JointPredictor


# Short name -> predictor class. These are the only names that may appear
# as '+'-separated components of a predictor spec.
BASIC_PREDICTOR_MAP = {
    'mutation': base_predictors.MutationRadiusPredictor,
    'ev': ev_predictors.EVPredictor,
    'onehot': onehot_predictors.OnehotRidgePredictor,
    'georgiev': onehot_predictors.GeorgievRidgePredictor,
    'eunirep_reg': unirep_predictors.EUniRepRegressionPredictor,
    'gunirep_reg': unirep_predictors.GUniRepRegressionPredictor,
    'eunirep_ll': unirep_predictors.EUniRepLLPredictor,
    'gunirep_ll': unirep_predictors.GUniRepLLPredictor,
    'hmm': hmm_predictors.HMMPredictor,
    'blosum': base_predictors.BLOSUM62Predictor,
    'gesm': esm_predictors.GlobalESMPredictor,
    'gesm_reg': esm_predictors.GlobalESMRegressionPredictor,
    'vae': vae_predictors.VaePredictor,
}

# Named groups of predictor specs, selected via get_predictor_names().
CORE_PREDICTORS = [
    'eunirep_reg',
    'ev+onehot',
    'gesm+onehot',
    'eunirep_ll+onehot',
    'vae+onehot',
]

BASELINE_PREDICTORS = [
    'georgiev',
    'onehot',
    'hmm+onehot',
    'blosum+onehot',
    'mutation+onehot',
]

ADDITIONAL_PREDICTORS = [
    'gunirep_ll+onehot',
    'gesm_reg',
]

UNSUPERVISED_PREDICTORS = [
    'ev',
    'vae',
    'hmm',
    'blosum',
    'mutation',
    'eunirep_ll',
    'gunirep_ll',
    'gesm',
]


def get_predictor_cls(predictor_name):
    """Resolve a predictor spec to the list of classes it names.

    Args:
        predictor_name: One or more keys of ``BASIC_PREDICTOR_MAP`` joined
            by ``'+'``, e.g. ``'onehot'`` or ``'ev+onehot'``.

    Returns:
        A list with one class per component, in the order the components
        appear in ``predictor_name``.

    Raises:
        KeyError: If any component is not a key of ``BASIC_PREDICTOR_MAP``.
            The message lists the unknown components and the valid names.
    """
    names = predictor_name.split('+')
    # Check every component up front so the error reports all bad names
    # together with the full spec, instead of a bare KeyError('name').
    unknown = [n for n in names if n not in BASIC_PREDICTOR_MAP]
    if unknown:
        raise KeyError(
            'Unknown predictor name(s) {} in {!r}; valid names are {}'.format(
                unknown, predictor_name, sorted(BASIC_PREDICTOR_MAP)))
    return [BASIC_PREDICTOR_MAP[n] for n in names]


def get_predictor_names(key):
    """Expand a group keyword into a list of predictor specs.

    Args:
        key: One of ``'core'``, ``'baselines'``, ``'additional'``,
            ``'unsupervised'`` or ``'all'``; any other value is treated as
            a single predictor spec.

    Returns:
        A new list of predictor specs. For ``'all'`` the groups are
        concatenated in the order core, baselines, additional,
        unsupervised. For an unrecognised ``key`` the result is ``[key]``
        (the key is not validated here).
    """
    # Every branch returns a fresh list so callers can mutate the result
    # without altering the module-level group constants.
    if key == 'core':
        return list(CORE_PREDICTORS)
    if key == 'baselines':
        return list(BASELINE_PREDICTORS)
    if key == 'additional':
        return list(ADDITIONAL_PREDICTORS)
    if key == 'unsupervised':
        return list(UNSUPERVISED_PREDICTORS)
    if key == 'all':
        return (CORE_PREDICTORS + BASELINE_PREDICTORS
                + ADDITIONAL_PREDICTORS + UNSUPERVISED_PREDICTORS)
    return [key]