File size: 6,128 Bytes
7718235
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
library(ggplot2)
result.plot <- readRDS('figs/fig.5.prepare.RDS')
result.plot <- result.plot[result.plot$task.type=='Gene',]
result.plot$use.lw <- F
# remove itan tasks
result.plot <- result.plot[!grepl('.itan.split', result.plot$task.id),]
pick.cond <- 'auc'
# get unique models
uniq.models <- unique(gsub('.lw', '', result.plot$model))
# only keep the original models
uniq.models <- uniq.models[grepl('/$', uniq.models)]
uniq.genes <- unique(result.plot$task.id)
# for each gene and each fold, decide weather to use large window
for (g in uniq.genes) {
  for (m in uniq.models) {
    for (f in 0:4) {
      lw.loss <- result.plot$val.loss[result.plot$model == paste0(m, '.lw') & result.plot$task.id == g & result.plot$fold==f]
      loss <- result.plot$val.loss[result.plot$model == m & result.plot$task.id == g & result.plot$fold==f]
      lw.tr.auc <- result.plot$tr.auc[result.plot$model == paste0(m, '.lw') & result.plot$task.id == g & result.plot$fold==f]
      tr.auc <- result.plot$tr.auc[result.plot$model == m & result.plot$task.id == g & result.plot$fold==f]
      if (pick.cond == 'auc') {
        cond <- !is.na(mean(lw.tr.auc)) & lw.tr.auc > tr.auc
      } else if (pick.cond == 'loss') {
        cond <- !is.na(mean(lw.loss)) & loss > lw.loss
      } else if (pick.cond == 'auc+loss') {
        cond <- !is.na(lw.loss) & !is.na(lw.tr.auc) & (tr.auc/loss > lw.tr.auc/lw.loss)
      } else {
        cond <- F
      }
      if (cond) {
        # use lw
        to.remove <- which(result.plot$model == m & result.plot$task.id == g & result.plot$fold==f)
        to.anno <- which(result.plot$model == paste0(m, '.lw') & result.plot$task.id == g & result.plot$fold==f)
        result.plot$model[to.anno] <- m
        result.plot$use.lw[to.anno] <- T
        result.plot <- result.plot[-to.remove,]
      } else {
        to.remove <- which(result.plot$model == paste0(m, '.lw') & result.plot$task.id == g & result.plot$fold==f)
        result.plot <- result.plot[-to.remove,]
      }
    }
  }
}

result.plot$task.name[result.plot$task.id == "Q14524.clean"] <- "Gene: SCN5A"
result.plot <- result.plot[result.plot$model %in% c("PreMode/",
                                                    "PreMode.noESM/",
                                                    "PreMode.noMSA/",
                                                    "PreMode.noStructure/",
                                                    "PreMode.ptm/",
                                                    "ESM.SLP/",
                                                    "PreMode.noPretrain/",
                                                    "random.forest"
                                                    ),]
model.dic <- c("PreMode/"="1: PreMode",
               "PreMode.noPretrain/"="2: PreMode: no Pretrain",
               "random.forest"="3: Random Forest",
               "ESM.SLP/"="4: ESM + SLP",
               "PreMode.noESM/"="5: PreMode: no ESM",
               "PreMode.noMSA/"="6: PreMode: no MSA",
               "PreMode.noStructure/"="7: PreMode: no Structure",
               "PreMode.ptm/"="8: PreMode: add ptm")
result.plot$model <- model.dic[result.plot$model]

# plot the task weighted averages as well as task size weighted error bars
uniq.result.plot <- result.plot[result.plot$fold==0,]
for (i in 1:dim(uniq.result.plot)[1]) {
  aucs <- result.plot$auc[result.plot$model==uniq.result.plot$model[i] & 
                            result.plot$task.name==uniq.result.plot$task.name[i]]
  uniq.result.plot$auc[i] = mean(aucs, na.rm=T)
  uniq.result.plot$auc.se[i] = sd(aucs, na.rm=T) / sqrt(length(aucs))
}
# aggregate across models
uniq.model.result.plot <- uniq.result.plot[!duplicated(uniq.result.plot$model),]
for (i in 1:dim(uniq.model.result.plot)[1]) {
  task.sizes.lof <- uniq.result.plot$task.size.lof[uniq.result.plot$model==uniq.model.result.plot$model[i]] 
  task.sizes.gof <- uniq.result.plot$task.size.gof[uniq.result.plot$model==uniq.model.result.plot$model[i]]
  # change to harmonic average of task size
  task.sizes <- task.sizes.lof * task.sizes.gof / (task.sizes.lof + task.sizes.gof)
  aucs <- uniq.result.plot$auc[uniq.result.plot$model==uniq.model.result.plot$model[i]]
  auc.ses <- uniq.result.plot$auc.se[uniq.result.plot$model==uniq.model.result.plot$model[i]]
  # remove NA values
  task.sizes <- task.sizes[!is.na(aucs)]
  aucs <- aucs[!is.na(aucs)]
  auc.ses <- auc.ses[!is.na(auc.ses)]
  uniq.model.result.plot$auc[i] <- sum(aucs * task.sizes / sum(task.sizes), na.rm=T)
  uniq.model.result.plot$auc.se[i] <- sum(auc.ses * task.sizes / sum(task.sizes), na.rm=T)
}

uniq.model.result.plot$model.type <- 'PreMode: Ablation'
uniq.model.result.plot$model.type[uniq.model.result.plot$model == "8: PreMode: add ptm"] <- 'PreMode: add ptm'
uniq.model.result.plot$model.type[uniq.model.result.plot$model == "4: ESM + SLP"] <- 'Baselines'
uniq.model.result.plot$model.type[uniq.model.result.plot$model == "3: Random Forest"] <- 'Baselines'
uniq.model.result.plot$model.type[uniq.model.result.plot$model == "1: PreMode"] <- 'PreMode'
uniq.model.result.plot$model.type <- factor(uniq.model.result.plot$model.type, 
                                            levels = c('PreMode', 'PreMode: add ptm', 'PreMode: Ablation', 'Baselines'))
# chose other color scale
p <- ggplot(uniq.model.result.plot, aes(x=model, y=auc, col=model.type)) +
  geom_point() + scale_color_manual(values = c("#F8766D", "#CD9600", "#999999", "#619CFF")) + 
  geom_errorbar(aes(ymin=auc-auc.se, ymax=auc+auc.se), width=.2) +
  coord_flip() + guides(col=guide_legend(ncol=2)) +
  labs(x = "models", y = "auc", fill = "model") +
  theme_bw() + ylim(0.5, 0.9) + ggtitle('PreMode ablation analysis') +
  ggeasy::easy_center_title() +
  theme(axis.text.x = element_text(angle=60, vjust = 1, hjust = 1), 
        text = element_text(size = 16),
        plot.title = element_text(size=15),
        legend.text = element_text(size=10),
        legend.title = element_blank(),
        legend.position="bottom", 
        legend.direction="horizontal") + 
  ggeasy::easy_center_title()
ggsave('figs/fig.5b.pdf', p, height=5, width=6)