Commit d7b39a18 authored by Carlos de Lannoy

add scripts for noise and supervision ramp

parent 8fbfdb2f
@@ -25,6 +25,14 @@ python scripts/evaluate_traces.py --fb "simulated/eval_supplementary/5_10_state_
--outdir "simulated/eval_supplementary/5_10_state_equidist_eval/"
echo "Done"
echo "evaluating traces for noise_ramp..."
for f in 13.0 18.2 25.5 35.7 50.0 ; do
python scripts/evaluate_traces.py --fb "simulated/eval_supplementary/noise_ramp/${f}_fb" \
--manual "simulated/data/5_noise_ramp/${f}/dats_labeled/" \
--target-states 1 2 3 \
--outdir "simulated/eval_supplementary/noise_ramp/${f}_eval"
done
## Generate fig. 3
python scripts/generate_sim_fig.py --eval-dirs simulated/eval/*_eval --cat-names 1_2_state_don_eval 2_3_state_don 3_3_state_kinetic 4_10_state_equidist --out-svg "simulated/eval/sim_figure.svg"
......
@@ -209,13 +209,15 @@ loader_fun = dict(matlab=matlab_load_fun, fretboard=fretboard_load_fun)[args.man
nb_classes = len(target_states_w_ground)
acc_list = []
acc_df = pd.DataFrame(0, columns=['correct', 'total'], index=args.categories)
acc_df = pd.DataFrame(0, columns=['correct', 'total', 'supervised_correct', 'supervised_total'], index=args.categories)
total_pts = 0
correct_pts = 0
max_state = 0
plt.rcParams.update({'font.size': 30}) # large text for trace plots
efret_dict = {cat:{st: [] for st in target_states_w_ground} for cat in args.categories}
efret_pred_dict = {cat:{st: [] for st in target_states_w_ground} for cat in args.categories}
efret_pred_dict = {cat: {st: [] for st in target_states_w_ground} for cat in args.categories}
# supervision_acc_dict = {'correct': 0, 'total': 0}
for fb in fb_files:
cat = [cat for cat in args.categories if cat in fb]
if not len(cat): continue
@@ -224,7 +226,12 @@ for fb in fb_files:
fb_base = basename(fb)
if fb_base not in manual_dict: continue # skip if no ground truth file available
dat_df = pd.read_csv(fb, sep='\t')
if not dat_df.label.isnull().all(): continue # skip if read was used as labeled example
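    # traces that already carry labels were used as supervised training examples: count their accuracy separately, then skip them below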
if not dat_df.label.isnull().all():
acc_df.loc[cat, 'supervised_correct'] = acc_df.loc[cat, 'supervised_correct'] + len(dat_df.query('label == predicted'))
acc_df.loc[cat, 'supervised_total'] = acc_df.loc[cat, 'supervised_total'] + len(dat_df)
# supervision_acc_dict['correct'] += len(dat_df.query('label == predicted'))
# supervision_acc_dict['total'] += len(dat_df)
continue # skip if read was used as labeled example
# load manual labels
manual_df = loader_fun(manual_dict[fb_base])
@@ -337,6 +344,7 @@ for fb in fb_files:
transition_df.index = [f'{str(idx[0])}_{str(idx[1])}' for idx in transition_df.index.to_flat_index()]
plt.rcParams.update({'font.size': 15}) # smaller text for summary plots
acc_df.loc[:, 'accuracy'] = acc_df.correct / acc_df.total
acc_df.loc[:, 'supervised_accuracy'] = acc_df.supervised_correct / acc_df.supervised_total
acc_df.to_csv(f'{summary_dir}/accuracy_per_category.tsv', sep='\t')
# Plot transition density plots
......
import argparse, os, sys, re
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
parser = argparse.ArgumentParser(description='Draw a histogram of the per-trace log-probabilities listed in an index table')
parser.add_argument('--in-csv', type=str, required=True)
parser.add_argument('--out-svg', type=str, required=True)
args = parser.parse_args()
df = pd.read_csv(args.in_csv)
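# derive the number of simulated states from the trace index in the file name: indices up to 150 are the 2-state traces, higher indices the 3-state traces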
df.loc[:, 'nb_states'] = df.apply(lambda x: [2, 3][int(re.search('(?<=trace_)[0-9]+(?=\.dat)', x.trace).group(0)) > 150], axis=1)
fig, ax = plt.subplots(figsize=(8.25, 2.9375))
sns.histplot(x='logprob', hue='nb_states', data=df, ax=ax, bins=30)
ax.set_xlabel('log(P($X_n$|θ)) / $T_n$')
plt.tight_layout()
plt.savefig(args.out_svg)
import argparse, os, sys, pickle, re
from os.path import basename, splitext
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
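# ground-truth E_FRET values and transition rates of the simulated system; the same values appear as dashed reference lines in the figure below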
efret_gt_dict = {1: 0.2, 2: 0.4, 3: 0.6}
tr_dict = {'1_2': 0.01, '1_3': 0.01,
'2_1': 0.01, '3_1': 0.01,
'2_3': 0.0, '3_2': 0.0}
parser = argparse.ArgumentParser(description='Produce figure comparing several runs for the same numerical parameter')
parser.add_argument('--eval-dirs', type=str, nargs='+', required=True)
parser.add_argument('--param-name', type=str, required=True)
parser.add_argument('--out-svg', type=str, required=True)
args = parser.parse_args()
ed_list = []
for ed in args.eval_dirs:
if ed[-1] != '/': ed_list.append(ed + '/')
else: ed_list.append(ed)
accuracy_df = pd.DataFrame(columns=['accuracy', 'supervised_accuracy'])
yerr_dict = {}
transitions_list = []
efret_list = []
# -- collect data ---
for ed in ed_list:
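    # the ramp parameter value (e.g. SNR) is encoded as the leading number of each evaluation directory name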
cat = float(re.search('^[.0-9]+(?=_)', basename(ed[:-1])).group(0))
# Accuracy
acc_df = pd.read_csv(f'{ed}summary_stats/accuracy_per_category.tsv', sep='\t')
accuracy_df.loc[cat, 'accuracy'] = acc_df.loc[0, 'accuracy']
accuracy_df.loc[cat, 'supervised_accuracy'] = acc_df.loc[0, 'supervised_accuracy']
# efret means
with open(f'{ed}summary_stats/efret_pred_dict.pkl', 'rb') as fh: efret_pred_dict = pickle.load(fh)['']
edf = pd.concat([pd.DataFrame({'label': st, 'SNR': cat, 'E_FRET': efret_pred_dict[st]}) for st in efret_pred_dict])
with open(f'{ed}summary_stats/_means_dict.pkl', 'rb') as fh: means_dict = pickle.load(fh)
# edf = pd.read_csv(f'{ed}summary_stats/event_counts_kde_.tsv', sep='\t', header=0)
# edf.loc[:, 'SNR'] = cat
efret_list.append(edf)
# transition rates
tdf = pd.read_csv(f'{ed}summary_stats/transitions.tsv', header=0, names=['transition', 'tr'], sep='\t',
usecols=[0,1], index_col=0)
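    # transitions.tsv.npy is assumed to hold bootstrapped confidence bounds per transition; [0, 1, :] is used as the lower and [0, 0, :] as the upper bound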
tr_sd = np.load(f'{ed}summary_stats/transitions.tsv.npy')
tdf.loc[:, 'ci_low'] = tr_sd[0, 1, :]
tdf.loc[:, 'ci_high'] = tr_sd[0, 0, :]
tdf.index = pd.MultiIndex.from_tuples([(cat, ct) for ct in tdf.index])
transitions_list.append(tdf)
# --- concatenating, typecasting ---
accuracy_df.index.rename('SNR', inplace=True)
accuracy_df.reset_index(inplace=True)
accuracy_df.SNR = accuracy_df.SNR.astype(float)
accuracy_df.accuracy = accuracy_df.accuracy.astype(float) * 100
accuracy_df.supervised_accuracy = accuracy_df.supervised_accuracy.astype(float) * 100
transitions_df = pd.concat(transitions_list)
transitions_df.index.rename(['SNR', 'transition'], inplace=True)
transitions_df.sort_index(level='SNR', inplace=True)
transitions_df.reset_index(inplace=True)
transitions_df.SNR = transitions_df.SNR.astype(float)
efret_df = pd.concat(efret_list)
# efret_df.index.rename(['SNR', 'state'], inplace=True)
# efret_df.reset_index(inplace=True)
efret_df.SNR = efret_df.SNR.astype(float)
# --- plotting ---
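# three stacked panels sharing the x-axis: fitted E_FRET levels, transition rates and accuracy as a function of the ramp parameter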
fig, (ax_efret, ax_tr, ax_acc) = plt.subplots(3,1, figsize=(8.25, 2.9375 * 3), sharex=True)
# efret
# sns.lineplot(x='SNR', y='E_FRET', hue='label', data=efret_df, ax=ax_efret, ci='sd',
# err_style='bars', markers=True, estimator='median',
# palette={1: 'black', 2: 'black', 3: 'black'},
# err_kws={'capsize': 3.0})
sns.lineplot(x='SNR', y='E_FRET', hue='label', data=efret_df, ax=ax_efret, ci=95,
err_style='bars', markers=True,
# palette={1: 'black', 2: 'black', 3: 'black'},
err_kws={'capsize': 3.0})
ax_efret.axhline(0.2, ls='--', color='black', zorder=3, alpha=0.5)
ax_efret.axhline(0.4, ls='--', color='black', zorder=3, alpha=0.5)
ax_efret.axhline(0.6, ls='--', color='black', zorder=3, alpha=0.5)
# ax_efret.set_ylim(0, 1.2)
ax_efret.set_ylabel('$E_{PR}$')
# transition rate
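# colour-blind-safe palette (ColorBrewer Dark2), one colour per transition type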
cb_dark2 = ['#1b9e77','#d95f02','#7570b3','#e7298a','#66a61e','#e6ab02']
tr_color_dict = {tr: cb_dark2[i] for i, tr in enumerate(transitions_df.transition.unique())}
for trt, cdf in transitions_df.groupby('transition'):
line = ax_tr.errorbar(x=cdf.SNR, y=cdf.tr, yerr=cdf.loc[:,['ci_low', 'ci_high']].to_numpy().T,
color=tr_color_dict[trt],
zorder=1)
line.set_label(trt.replace('_', '→'))
ax_tr.legend()
ax_tr.axhline(0.01, ls='--', color='black', zorder=3, alpha=0.5)
ax_tr.axhline(0.0, ls='--', color='black', zorder=3, alpha=0.5)
ax_tr.set_ylabel('transition rate ($s^{-1}$)')
# Accuracy
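# solid line: accuracy on unlabelled traces; dotted line: accuracy on the traces supplied as supervised examples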
sns.lineplot(x='SNR', y='accuracy', color='black', data=accuracy_df, ax=ax_acc)
sns.lineplot(x='SNR', y='supervised_accuracy', color='black', data=accuracy_df, ax=ax_acc, ls='dotted')
ax_acc.set_ylabel('accuracy (%)')
plt.xlabel(args.param_name)
plt.tight_layout()
plt.savefig(args.out_svg)
plt.close(fig)