Commit da099f7a authored by Carlos de Lannoy's avatar Carlos de Lannoy
Browse files

add script for tool comparison

parent 8075de02
import os, sys, argparse, re
from io import StringIO
import seaborn as sns
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
from os.path import basename
import numpy as np
import pandas as pd
from scipy.linalg import logm
from itertools import permutations, chain
__location__ = os.path.realpath(os.path.join(os.getcwd(), os.path.dirname(__file__)))
sys.path.append(__location__)
from helper_functions import parse_input_path, parse_output_dir
# Hex colours handed out to tools by position: plot_tool_tr_bars indexes this list with the rank
# of each non-manual tool, so it supports at most len(colors) tools; plot_efret_dists takes a
# leading slice of it as its palette.
colors = ['#1b9e77', '#d95f02', '#7570b3', '#e7298a', '#66a61e']
def plot_tool_tr_bars(df, ax):
    """
    Draw transition rates per tool on [ax] as points with error bars, clustered per transition.
    Rows with tool == "manual" are not drawn as points; they become dashed horizontal lines
    spanning the cluster of points that belongs to the same transition.

    :param df: data frame with columns transition, tool, tr, ci_low and ci_high
    :param ax: matplotlib axes to draw on
    """
    is_manual = df.tool == 'manual'
    manual_df = df.loc[is_manual].copy()
    plot_df = df.loc[~is_manual].copy()

    tool_list = plot_df.tool.unique()
    nb_tools = len(tool_list)
    half_span = nb_tools // 2

    # every transition gets an integer x position, every tool a small offset around it
    tridx_dict = {tr: pos for pos, tr in enumerate(plot_df.transition.unique())}
    tool_offsets = np.linspace(-1 * half_span, half_span, nb_tools) * 0.1
    offset_dict = dict(zip(tool_list, tool_offsets))
    color_dict = {tn: colors[rank] for rank, tn in enumerate(tool_list)}

    plot_df['x'] = [tridx_dict[tr] + offset_dict[tn] for tr, tn in zip(plot_df.transition, plot_df.tool)]
    plot_df['color'] = [color_dict[tn] for tn in plot_df.tool]

    # left/right bounds of the manual reference line: extent of the cluster, padded by 0.05
    x_per_transition = plot_df.groupby('transition').x
    manual_df['lb'] = manual_df.transition.map(x_per_transition.min()) - 0.05
    manual_df['rb'] = manual_df.transition.map(x_per_transition.max()) + 0.05

    ci_offsets = plot_df[['ci_low', 'ci_high']].to_numpy().T
    ax.errorbar(x=plot_df.x, y=plot_df.tr, yerr=ci_offsets, color='white', ecolor=plot_df.color, fmt='.')
    ax.scatter(x=plot_df.x, y=plot_df.tr, color=plot_df.color)
    ax.hlines(y='tr', xmin='lb', xmax='rb', linestyles='dashed', colors='black', data=manual_df)

    ax.set_ylabel('transition rate ($s^{-1}$)')
    ax.set_xlabel('transition')
    ax.set_xticks(ticks=list(range(len(tridx_dict))))
    ax.set_xticklabels(labels=[tr.replace('_', '') for tr in tridx_dict])
def plot_efret_dists(df, ax, style='box'):
    """
    Plot E_FRET distributions per state, as boxplots or as histograms. Rows with an infinite
    or missing efret value are discarded first.

    :param df: data frame with columns efret, state and tool; tool == "manual" marks the reference
    :param ax: a single axes for style 'box'; for style 'hist' a list of axes that is indexed
        once per unique state, so len(ax) must be at least the number of states
    :param style: 'box' (default) or 'hist'; any other value draws nothing
    """
    is_usable = (np.abs(df.efret) != np.inf) & df.efret.notna()
    df = df.loc[is_usable, :].copy()
    nb_tool_colors = len(df.tool.unique()) - 1  # one colour per tool, manual excluded

    if style == 'box':
        # white is appended as the last palette entry, after the per-tool colours
        box_palette = sns.color_palette(colors[:nb_tool_colors] + ['#ffffff'])
        sns.boxplot(data=df, x='state', y='efret', hue='tool', palette=box_palette,
                    width=0.3, showfliers=False, ax=ax)
        ax.set_ylim((-1, 2))
        ax.set_xlabel('')
        ax.set_ylabel('$E_{PR}$')
        return
    if style != 'hist':
        return

    reference_df = df.query('tool == "manual"').copy()
    tools_df = df.query('tool != "manual"').copy()
    tool_colors = colors[:nb_tool_colors]
    for panel_idx, st in enumerate(df.state.unique()):
        panel = ax[panel_idx]
        # filled grey histogram: manual reference
        sns.histplot(data=reference_df.query(f'state == {st}'), x='efret', color='grey',
                     stat='count', bins=np.arange(0, 1, 0.01),
                     element='step', fill=True, ax=panel)
        # unfilled outlines on top: one per tool
        sns.histplot(data=tools_df.query(f'state == {st}'), x='efret', hue='tool',
                     stat='count', bins=np.arange(0, 1, 0.01), palette=sns.color_palette(tool_colors),
                     element='step', fill=False, ax=panel)
        panel.set_xlim((0, 1))
        panel.set_xlabel('$E_{PR}$')
def parse_fretboard_results(in_dir):
    """
    Parse a FRETboard evaluation directory into an E_FRET table and a transition-rate table.

    Files read:
      {in_dir}/summary_stats/transitions.tsv      tab-separated transition rates (FRETboard, manual)
      {in_dir}/summary_stats/transitions.tsv.npy  array holding the limits used for ci_low/ci_high
      {in_dir}/trace_csvs/*.csv                   tab-separated traces with columns E_FRET, predicted, manual

    :param in_dir: path of the FRETboard evaluation directory
    :return: tuple (efdf, trdf); efdf has columns efret, state, tool and trdf has columns
        transition, tool, tr, ci_low, ci_high. Both contain 'FRETboard' and 'manual' rows.
    """
    # --- parse transition rates ---
    stats_dir = f'{in_dir}/summary_stats'
    ci_limits = np.load(f'{stats_dir}/transitions.tsv.npy')[0, :, :].reshape(-1)
    # only two names are given, so the remaining leading column becomes the index (the transition name)
    trdf = pd.read_csv(f'{stats_dir}/transitions.tsv', sep='\t', header=0, names=['FRETboard', 'manual'])
    trdf = trdf.reset_index().rename({'index': 'transition'}, axis=1)
    # var_name must be a scalar: recent pandas versions reject a list for non-MultiIndex columns
    trdf = trdf.melt(id_vars=['transition'], value_vars=['FRETboard', 'manual'],
                     var_name='tool', value_name='tr')

    # first half of the flattened limits goes to ci_low, second half to ci_high
    # NOTE(review): assumes the array is ordered like the FRETboard rows of trdf - confirm
    half = len(ci_limits) // 2
    is_fretboard = trdf.tool == 'FRETboard'
    trdf.loc[is_fretboard, 'ci_low'] = ci_limits[:half]
    trdf.loc[is_fretboard, 'ci_high'] = ci_limits[half:]
    trdf.loc[trdf.tool == 'manual', ['ci_low', 'ci_high']] = 0.0

    # --- parse efret values ---
    predicted_list = []
    manual_list = []
    for fn in parse_input_path(f'{in_dir}/trace_csvs', pattern='*.csv'):
        trace_df = pd.read_csv(fn, sep='\t')
        predicted_list.append(trace_df.loc[:, ['E_FRET', 'predicted']].copy())
        manual_list.append(trace_df.loc[:, ['E_FRET', 'manual']].copy())

    efdf_predicted = pd.concat(predicted_list).rename({'E_FRET': 'efret', 'predicted': 'state'}, axis=1)
    efdf_predicted['tool'] = 'FRETboard'
    efdf_manual = pd.concat(manual_list).rename({'E_FRET': 'efret', 'manual': 'state'}, axis=1)
    efdf_manual['tool'] = 'manual'
    efdf = pd.concat((efdf_predicted, efdf_manual))

    # todo Quick fix: remove trash states
    # keep only states that occur in a transition name of the form '<state>_<state>'
    states_of_interest = {int(s) for tr in trdf.transition.unique() for s in tr.split('_')}
    efdf = efdf.loc[efdf.state.isin(states_of_interest)].copy()
    return efdf, trdf
def parse_ebfret_results(in_dir, framerate):
    """
    Parse ebFRET output into an E_FRET table and a transition-rate table.

    Files read:
      {in_dir}/ebFRET_analyzed_summary.csv               number of states, state centers, transition matrix
      {in_dir}/ebFRET_analyzed_summary_bootstrapped.csv  bootstrapped transition values
      {in_dir}/ebFRET_analyzed.dat                       whitespace-separated trace_nb, d, a, state

    Side effect: stores the per-state mean/sd in the module-level efret_dict[exp]['ebFRET'].

    :param in_dir: path of the ebFRET result directory
    :param framerate: factor applied to the matrix logarithm of the transition probability
        matrix to obtain transition rates
    :return: tuple (edf, trdf); edf has columns state, efret, tool and trdf has columns
        transition, tool, tr, ci_low, ci_high. tool is 'ebFRET' throughout.
    """
    # --- parse original summary ---
    with open(f'{in_dir}/ebFRET_analyzed_summary.csv') as fh:
        eb_txt = fh.read()
    eb_params_txt = re.search(r'Parameters[\sa-zA-Z.,0-9_+-]+', eb_txt).group(0)
    nb_states = int(re.search(r'(?<=Num_States,)[0-9]+', eb_params_txt).group(0))
    state_numbers = list(range(1, nb_states + 1))
    # permutations come out in row-major order, the same order as the off-diagonal mask used below
    tr_names = [f'{src}_{dst}' for src, dst in permutations(state_numbers, 2)]

    eb_means_txt = re.search(r'(?<=Center)[\sa-zA-Z.,0-9_+-]+(?=Precision)',
                             eb_params_txt).group(0).strip().replace(' ', '')
    efret_df = pd.read_csv(StringIO(eb_means_txt), names=state_numbers).T
    # NOTE(review): efret_dict and exp are not parameters; they are looked up in module scope and
    # must have been set by the calling script. Kept as-is so existing behaviour does not change.
    efret_dict[exp]['ebFRET'] = {ri: {'mean': r.loc['Mean'], 'sd': r.loc['Std']} for ri, r in efret_df.iterrows()}

    eb_trans_txt = re.search(r'(?<=Transition_Matrix)[\sa-zA-Z.,0-9_+-]+',
                             eb_params_txt).group(0).strip().replace(' ', '')
    eb_trans_means_txt = re.search(r'(?<=Mean,)[\sa-zA-Z.,0-9_+-]+(?=Std)',
                                   eb_trans_txt).group(0).replace('\n,', '\n')
    transition_probs = np.genfromtxt(StringIO(eb_trans_means_txt), delimiter=',')
    # only off-diagonal entries are kept, so the added identity matrix does not affect the result
    off_diagonal = np.invert(np.eye(nb_states, dtype=bool))
    transition_rates = (np.eye(nb_states) + framerate * logm(transition_probs))[off_diagonal]
    trdf = pd.DataFrame({'transition': tr_names, 'tool': 'ebFRET', 'tr': transition_rates}).set_index(['transition'])

    # --- parse bootstrap summary ---
    with open(f'{in_dir}/ebFRET_analyzed_summary_bootstrapped.csv') as fh:
        eb_bs_txt = fh.read()
    tr_bs_txt = re.search(r'(?<=bootstrap_tr\n\s{4}tr)[\sa-zA-Z.,0-9_+-]+', eb_bs_txt).group(0)
    bs_df = pd.read_csv(StringIO(tr_bs_txt), names=tr_names)
    # interval = bootstrap mean +/- 2 sd, stored as distances below/above tr (aligned on transition name)
    bs_mean, bs_sd = bs_df.mean(), bs_df.std()
    trdf['ci_low'] = trdf.tr - (bs_mean - 2 * bs_sd)
    trdf['ci_high'] = (bs_mean + 2 * bs_sd) - trdf.tr
    trdf.reset_index(inplace=True)

    # --- parse efret values ---
    edf = pd.read_csv(f'{in_dir}/ebFRET_analyzed.dat', sep=r'\s+', names=['trace_nb', 'd', 'a', 'state'])
    edf.state = edf.state.astype(int)
    edf['efret'] = edf.a / (edf.d + edf.a)
    edf.drop(['d', 'a', 'trace_nb'], inplace=True, axis=1)
    edf['tool'] = 'ebFRET'
    return edf, trdf
def parse_mashfret_results(in_dir):
    """
    Parse MASH-FRET output into an E_FRET table and a transition-rate table.

    Files read:
      {in_dir}/kinetics/*.fit      one fit per transition; the file name contains '<from>to<to>'
      {in_dir}/traces_ASCII/*.txt  tab-separated traces with columns 'FRET' and 'state sequence'

    States are numbered 1..n in order of increasing FRET value in both returned tables:
    by the values in the fit file names for trdf, and by mean efret per state for efdf.

    :param in_dir: path of the MASH-FRET result directory
    :return: tuple (efdf, trdf); efdf has columns efret, state, tool and trdf has columns
        tr, ci_low, ci_high, transition, tool. tool is 'MASH-FRET' throughout.
    """
    # --- parse transition rates ---
    fit_dict = {}
    for fit_fn in parse_input_path(f'{in_dir}/kinetics', pattern='*.fit'):
        src, dst = re.search('[0-9]+to[0-9]+', fit_fn).group(0).split('to')
        fit_dict[(int(src), int(dst))] = fit_fn

    # np.unique returns its values sorted, so numbering runs from low to high fret state
    unique_fret_values = np.unique(np.array(list(fit_dict)).reshape(-1))
    efret2num_dict = {ufv: i + 1 for i, ufv in enumerate(unique_fret_values)}
    dt_df = pd.DataFrame(0.0, index=pd.MultiIndex.from_tuples(list(permutations(unique_fret_values, 2))),
                         columns=['mean', 'sd'])
    for fit_tup, fit_fn in fit_dict.items():
        if fit_tup[0] == fit_tup[1]:
            continue
        with open(fit_fn, 'r') as fh:
            block_reached = False
            for line in fh:
                if 'fitting results (bootstrap)' in line:
                    block_reached = True
                if block_reached and '(s):' in line:
                    # the last two tab-separated fields of the first '(s):' line are stored as mean and sd
                    dt_df.loc[fit_tup, ['mean', 'sd']] = [float(a.strip()) for a in line.split('\t')[-2:]]
                    break

    # rate = 1 / mean; interval bounds are taken at mean -/+ one sd and stored as distances from tr
    # NOTE(review): mean/sd are presumably dwell times in seconds - confirm against the .fit format
    dt_mean, dt_sd = dt_df['mean'], dt_df['sd']
    rate = 1 / dt_mean
    trdf = pd.DataFrame({
        'tr': rate,
        'ci_low': rate - 1 / (dt_mean + dt_sd),
        'ci_high': 1 / (dt_mean - dt_sd) - rate,
    })
    trdf['transition'] = [f'{efret2num_dict[src]}_{efret2num_dict[dst]}' for src, dst in trdf.index]
    trdf['tool'] = 'MASH-FRET'
    trdf.reset_index(drop=True, inplace=True)

    # --- parse efret from backsimulation results ---
    trace_list = [pd.read_csv(fn, skiprows=1, sep='\t', usecols=['FRET', 'state sequence'])
                  for fn in parse_input_path(f'{in_dir}/traces_ASCII', pattern='*.txt')]
    efdf = pd.concat(trace_list).rename({'FRET': 'efret', 'state sequence': 'state'}, axis=1)
    efdf['tool'] = 'MASH-FRET'
    efdf = efdf.query('state != -1').copy()

    # reorder states in order of mean, to match transition rate state numbering:
    # map each original state label to its 1-based rank by mean efret
    st_mean = efdf.groupby('state').efret.mean()
    st_order = sorted(st_mean.index, key=st_mean.get)
    st_rank_dict = {st: rank + 1 for rank, st in enumerate(st_order)}
    efdf['state'] = efdf.state.map(st_rank_dict)
    return efdf, trdf
# --- command line interface ---
parser = argparse.ArgumentParser(description='Plot performance of different tools for side-by-side comparison')
parser.add_argument('--eval-dir', type=str, required=True)
parser.add_argument('--experiments', type=str, nargs='+', required=True,
                    help='String to identify directories for the same dataset analyzed by different tools')
parser.add_argument('--framerate', type=float, default=10.0,
                    help='recording frame rate, required to translate transition probs to transition rates')
parser.add_argument('--data-dir', type=str, required=True)
parser.add_argument('--out-dir', type=str, required=True)
args = parser.parse_args()

out_dir = parse_output_dir(args.out_dir)
eval_dirs = [ed[0] for ed in os.walk(args.eval_dir)]  # every directory below eval-dir, recursively

# Per-experiment containers. efret_dict and the loop variable exp must stay at module level
# under these names: parse_ebfret_results reads and writes them as globals.
trdf_dict = {}
efret_df_dict = {}
efret_dict = {}
exp_dict = {}
for en in args.experiments:
    exp_dict[en] = [ed for ed in eval_dirs if en in basename(ed)]
    trdf_dict[en] = []
    efret_df_dict[en] = []
    efret_dict[en] = {}

# Collect E_FRET mean,sd, transition rates; the directory suffix decides which parser applies
for exp in exp_dict:
    for ed in exp_dict[exp]:
        if ed.endswith('_eval'):  # the fretboard dir
            efdf, trdf = parse_fretboard_results(ed)
        elif ed.endswith('_mash'):
            efdf, trdf = parse_mashfret_results(ed)
        elif ed.endswith('_ebFRET'):
            efdf, trdf = parse_ebfret_results(ed, args.framerate)
        else:
            continue
        trdf_dict[exp].append(trdf)
        efret_df_dict[exp].append(efdf)

# --- plotting ---
# composed figure: one column per experiment, E_FRET boxplots on top, transition rates below
nb_exp = len(exp_dict)
fig = plt.figure(constrained_layout=False, figsize=(48/2.54, 40/2.54))
gs = gridspec.GridSpec(2, nb_exp, figure=fig, wspace=0.2, hspace=0.30)
plot_types = ['efret', 'transition']
for cidx, exp in enumerate(exp_dict):
    if cidx == 0:
        ax_dict = {plot_name: fig.add_subplot(gs[idx, cidx]) for idx, plot_name in enumerate(plot_types)}
        first_ax_dict = ax_dict
    else:
        # later columns share their y axis with the first column of the same row
        ax_dict = {plot_name: fig.add_subplot(gs[idx, cidx], sharey=first_ax_dict[plot_name])
                   for idx, plot_name in enumerate(plot_types)}

    # move fb and manual results to back of list
    idx_fb = np.argwhere([edf.tool.iloc[0] == 'FRETboard' for edf in efret_df_dict[exp]])[0, 0]
    efret_df_dict[exp].append(efret_df_dict[exp].pop(idx_fb))
    trdf_dict[exp].append(trdf_dict[exp].pop(idx_fb))
    efret_all = pd.concat(efret_df_dict[exp])

    # Plot histograms of FRET values in a separate figure, one row per state
    nb_states = len(efret_df_dict[exp][0].state.unique())
    fig_hists = plt.figure(constrained_layout=False, figsize=(48/2.54, 40/2.54))
    gs_hists = gridspec.GridSpec(nb_states, 1, figure=fig_hists, hspace=0.30)
    ax_list = [fig_hists.add_subplot(gs_hists[ii, 0]) for ii in range(nb_states)]
    plot_efret_dists(efret_all, ax_list, 'hist')
    # NOTE(review): out_dir is used as a prefix, so it presumably ends with a path separator - confirm
    fig_hists.savefig(f'{out_dir}efret_hists_{exp}.svg')
    plt.close(fig_hists)  # release the per-experiment figure once it is written

    # add plots to composition figures
    plot_efret_dists(efret_all, ax_dict['efret'])
    plot_tool_tr_bars(pd.concat(trdf_dict[exp]), ax_dict['transition'])
fig.savefig(f'{out_dir}tool_comparison_composed.svg')
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment