clear all
close all
load('preprocessed_MCT')

% Combination concentrations [\muM]; row k belongs to drug k (used below as
% tick labels of the combination heat maps).
concs = [0.35,1.08,3.25,10;
    0.825,3.25,12.5,50;
    0.825,3.25,12.5,50];

drugs = {'Erlotinib','Lapatinib','PD0325901','PLX_{4720}','CHIR_{265}','Selumetinib','Vandetanib'};

% Single-treatment agreement per concentration. plotConcFun's signature is
% (sim,data,...): the simulation goes first, the data second.
plotConcFun(sim_fit,data_fit,1:8,drugs,'corr','MCT - single treatment');

load('preprocessed_MCT_combi')

combis = {[2,1],[2,3],[1,3]};

% All 4x4 concentration pairs of one combination, enumerated column-major;
% independent of the combination, so computed once.
[X,Y] = meshgrid(1:4,1:4);
XX = X(:);
YY = Y(:);
nCV = 5;

figure
for icombi = 1:numel(combis)
    pair = combis{icombi};
    % ZZ(iz,icv): correlation across cell lines between the replicate-median
    % prediction and the replicate-median data at concentration pair iz in
    % cross-validation fold icv.
    ZZ = NaN(numel(XX),nCV);
    for iz = 1:numel(XX)
        for icv = 1:nCV
            x = nanmedian(sim_fit(:,pair(1),pair(2),XX(iz),YY(iz),icv,:),7);
            y = nanmedian(data_fit(:,pair(1),pair(2),XX(iz),YY(iz),icv,:),7);
            ZZ(iz,icv) = getcorrcoeff(x,y);
        end
    end
    subplot(1,3,icombi)
    imagesc(reshape(mean(ZZ,2),4,4));
    caxis([-1,1])
    view(270,90)
    axis square
    if(icombi==numel(combis))
        h = colorbar;
        ylabel(h,'correlation')
    end
    tmp = gca;
    tmp.XTick = 1:4;
    tmp.YTick = 1:4;
    tmp.XTickLabel = concs(pair(1),:);
    tmp.YTickLabel = concs(pair(2),:);
    xlabel([drugs{pair(1)} ' [\muM]'])
    ylabel([drugs{pair(2)} ' [\muM]'])
    title([drugs{pair(1)} ' - '  drugs{pair(2)} ': ' num2str(mean(ZZ(:))) ' avg. corr.'])
end

%% Load single-treatment and combination results into one workspace
% NOTE: load() without an output argument writes every stored variable into
% the workspace, so the order below matters: each later file overwrites the
% same-named variables of the earlier ones.
load('preprocessed')
data_ccle = data_fit;  % data_fit of 'preprocessed'; not used further below
load('preprocessed_MCT_combi')
% Keep the combination arrays under *_combi names before the next load
% (presumably) overwrites data_fit/sim_fit/concs_fit with single-treatment data.
sim_fit_combi = sim_fit;
data_fit_combi = data_fit;
concs_combi = concs_fit;

load('preprocessed_MCT')

%% Bliss-independence and highest-single-agent (HSA) reference predictions
% For every cell line, CV fold and ordered drug pair (i,j), fit a Hill curve
% to the replicate-median single-treatment data of each drug and combine the
% single-agent effects E = 1 - viability at the combination concentrations:
%   Bliss viability = 1 - (E_i + E_j - E_i*E_j)
%   HSA viability   = 1 - max(E_i, E_j)
% The prediction is written into every replicate slot (last dimension), so
% both arrays have the size of data_fit_combi.
bliss_fit_combi = NaN(size(data_fit_combi));
hsa_fit_combi = NaN(size(data_fit_combi));
nDrug = 3;
nConc = 4;
for icl = 1:120
    for icv = 1:5
        % Fit each drug's single-agent curve once per (cell line, fold);
        % an empty entry marks a drug without any single-treatment data.
        pars = cell(1,nDrug);
        for idrug = 1:nDrug
            viability = squeeze(nanmedian(data_fit(icl,idrug,:,icv,:),5));
            if(any(~isnan(viability)))
                pars{idrug} = getFunPars(concs_fit(idrug,:),viability);
            end
        end
        for idrug = 1:nDrug
            if(isempty(pars{idrug}))
                continue
            end
            % Assumes concs_combi(d,:,s) holds drug d's combination
            % concentrations when it is pair member s (1 = first, 2 = second)
            % -- TODO confirm. Previously the pair combis{1} was used for
            % every (idrug,jdrug), i.e. the same concentrations for all pairs.
            effect_i = 1-hillCurve(concs_combi(idrug,:,1),pars{idrug});
            for jdrug = 1:nDrug
                if(isempty(pars{jdrug}))
                    continue
                end
                effect_j = 1-hillCurve(concs_combi(jdrug,:,2),pars{jdrug});
                for iconc = 1:nConc
                    for jconc = 1:nConc
                        ei = effect_i(iconc);
                        ej = effect_j(jconc);
                        bliss_fit_combi(icl,idrug,jdrug,iconc,jconc,icv,:) = 1-(ei + ej - ei*ej);
                        hsa_fit_combi(icl,idrug,jdrug,iconc,jconc,icv,:) = 1-max(ei,ej);
                    end
                end
            end
        end
    end
end

%% Cross-validated performance summary
% For each metric ('rmse', 'corr') and predictor, compare the replicate-median
% prediction with the replicate-median data in each of the 5 CV folds:
%   mech1 - mechanistic model, single treatments
%   mech2 - mechanistic model, combinations
%   Bliss - Bliss independence built from fitted single-drug curves
%   HSA   - highest single agent built from fitted single-drug curves
% values/types/modes are parallel vectors consumed by the bar chart below.
% Loop variables are named setName/metric so the builtins set() and mode()
% are not shadowed in the script workspace.
index = 1;
values = [];
modes = {};
types = {};

sets = {'mech1','mech2','Bliss','HSA'};

for metric = {'rmse','corr'}
    for setName = sets
        for icv = 1:5
            switch(setName{1})
                case 'mech1'
                    sim = nanmedian(sim_fit(:,:,:,icv,:),5);
                    data = nanmedian(data_fit(:,:,:,icv,:),5);
                case 'mech2'
                    sim = nanmedian(sim_fit_combi(:,:,:,:,:,icv,:),7);
                    data = nanmedian(data_fit_combi(:,:,:,:,:,icv,:),7);
                case 'Bliss'
                    sim = nanmedian(bliss_fit_combi(:,:,:,:,:,icv,:),7);
                    data = nanmedian(data_fit_combi(:,:,:,:,:,icv,:),7);
                case 'HSA'
                    sim = nanmedian(hsa_fit_combi(:,:,:,:,:,icv,:),7);
                    data = nanmedian(data_fit_combi(:,:,:,:,:,icv,:),7);
            end
            switch(metric{1})
                case 'corr'
                    values(index) = getcorrcoeff(sim,data);
                case 'rmse'
                    values(index) = getrmse(sim,data);
            end
            types{index} = setName{1};
            modes{index} = metric{1};
            index = index+1;
        end
    end
end

% Paired t-test of every predictor against the reference sets{2} across the
% CV folds, Bonferroni-corrected for the length(sets)-1 comparisons. The
% corrected value is capped at 1; NaN (e.g. the reference against itself)
% is left as NaN.
nComparisons = length(sets)-1;
for setName = sets
    disp(setName{1})
    for metric = {'corr','rmse'}
        isMetric = strcmp(modes,metric{1});
        ref = values(strcmp(types,sets{2}) & isMetric);
        vals = values(strcmp(types,setName{1}) & isMetric);
        [~,p] = ttest(ref,vals);
        pAdj = p*nComparisons;
        pAdj(pAdj > 1) = 1;
        disp([metric{1} ': ' num2str(nanmean(vals)) '+-' num2str(nanstd(vals)) '  p=' num2str(pAdj)])
    end
end
%% Bar chart (mean + error bar) of the CV performance summary
% One facet row per metric with independent y-scales. Assumption (confirm):
% gramm orders the facet rows alphabetically, so facet handle 1 is 'corr'
% and facet handle 2 is 'rmse'; their y-limits/ticks are fixed afterwards.
g = gramm('x',types,'y',values);
facet_grid(g,modes,[],'scale','free_y')
stat_summary(g,'geom','bar');
stat_summary(g,'geom','errorbar');
set_order_options(g,'x',sets);
axe_property(g,'XTickLabelRotation',90,'YGrid','on');
figure
draw(g)

axFirst = g.facet_axes_handles(1);
axFirst.YLim(2) = 1;
axFirst.YTick = 0:0.2:1;

axSecond = g.facet_axes_handles(2);
axSecond.YLim(2) = 0.3;
axSecond.YTick = 0:0.05:0.3;

makeFigure(0,0,2,80,138,'figure_prediction_mct',true)

%% Residual distributions of the combination predictors
% For each predictor and CV fold, a kernel density estimate (ksdensity) of
% the residual: replicate-median prediction minus replicate-median data.
% xx/yy hold the evaluation points/densities; types names the predictor.
index = 1;
types = {};
xx = {};
yy = {};

sets = {'Model Combi Drug','Bliss Combi Drug','HSA Combi Drug'};

% The data side does not depend on the predictor: compute it once per fold.
dataByFold = cell(1,5);
for icv = 1:5
    dataByFold{icv} = nanmedian(data_fit_combi(:,:,:,:,:,icv,:),7);
end

% setName (not set) so the builtin set() stays usable in the workspace.
for setName = sets
    for icv = 1:5
        switch(setName{1})
            case 'Model Combi Drug'
                sim = nanmedian(sim_fit_combi(:,:,:,:,:,icv,:),7);
            case 'Bliss Combi Drug'
                sim = nanmedian(bliss_fit_combi(:,:,:,:,:,icv,:),7);
            case 'HSA Combi Drug'
                sim = nanmedian(hsa_fit_combi(:,:,:,:,:,icv,:),7);
        end
        data = dataByFold{icv};
        [y,x] = ksdensity(sim(:)-data(:));
        xx{index} = x;
        yy{index} = y;
        types{index} = setName{1};
        index = index+1;
    end
end


%% Residual density plot
% One coloured line per predictor; gramm's stat_summary aggregates the
% per-fold density curves collected in xx/yy.
g = gramm('x',xx,'y',yy,'color',types);
stat_summary(g,'geom','lines');
set_names(g,'x','prediction-data','y','probability density')
figure
draw(g)

makeFigure(0,0,2,71,60,'figure_residuals_synergy',true)

%% Example combination surfaces for one cell line and CV fold
% Measured response next to the mechanistic-model, Bliss and HSA predictions
% for drug pair `combi` in cell line `icl` and fold `icv`.
combi = [2,1];
icl = 12;
icv = 2;

% Replicate median (dim 7) for data and model, as in the rest of the script.
% The Bliss/HSA arrays hold the same value in every replicate slot.
surfaces = {
    squeeze(nanmedian(data_fit_combi(icl,combi(1),combi(2),:,:,icv,:),7)), 'data';
    squeeze(nanmedian(sim_fit_combi(icl,combi(1),combi(2),:,:,icv,:),7)), 'mechanistic model';
    squeeze(bliss_fit_combi(icl,combi(1),combi(2),:,:,icv,1)), 'bliss independence';
    squeeze(hsa_fit_combi(icl,combi(1),combi(2),:,:,icv,1)), 'highest single agent'};

figure;
for ipanel = 1:size(surfaces,1)
    subplot(1,5,ipanel)
    plotDrugSurface(surfaces{ipanel,1},concs(combi(1),:),concs(combi(2),:))
    title(surfaces{ipanel,2})
    xlabel([drugs{combi(1)} ' [\muM]'])
    if(ipanel == 1)
        ylabel([drugs{combi(2)} ' [\muM]'])
    else
        % Panels share the y-axis of the first one.
        tmp = gca;
        tmp.YTick = [];
    end
end

% The fifth panel only hosts the shared colour bar.
subplot(1,5,5)
colorbar
caxis([0,1])

makeFigure(0,0,2,400,70,'figure_combination_mct',true)

%% Prediction vs. data scatter for the three combination predictors
% x: replicate-median data; y: replicate-median prediction of each predictor
% (one facet column each); colour: CV fold (dimension 6 of the median arrays).
figure
x = nanmedian(data_fit_combi(:,:,:,:,:,:,:),7);

y1 = nanmedian(sim_fit_combi(:,:,:,:,:,:,:),7);
column1 = ones(size(x));

y2 = nanmedian(bliss_fit_combi(:,:,:,:,:,:,:),7);
column2 = 2*ones(size(x));

y3 = nanmedian(hsa_fit_combi(:,:,:,:,:,:,:),7);
column3 = 3*ones(size(x));

% CV-fold index of every entry: 1:5 placed along dim 6 and expanded to size(x).
color = bsxfun(@times,permute(1:5,[1,3,4,5,6,2]),column1);

xx = [x(:);x(:);x(:)];
yy = [y1(:);y2(:);y3(:)];
col = [column1(:);column2(:);column3(:)];
cc = [color(:);color(:);color(:)];

columns = {'Mechanistic Model','Bliss Independence','Highest Single Agent'};

g = gramm('x',xx,'y',yy,'color',cc);
g.geom_point()
g.set_names('x','experimental data','y','model prediction','Color','CrossValidation','Column','')
g.set_order_options('column',columns)
g.axe_property('DataAspectRatio',[1,1,1])
g.facet_grid([],columns(col));
g.draw()

makeFigure(0,0,2,250,120,'figure_combination_corr_mct',true)





% %%
% figure
% subplot(1,2,1)
% plot(excess_over_bliss_sim(:),excess_over_bliss_data(:),'rx');
% title('excess over bliss')
% xlim([-1,1])
% ylim([-1,1])
% xlabel('simulation')
% ylabel('data')
% subplot(1,2,2)
% plot(excess_over_HSA_sim(:),excess_over_HSA_data(:),'rx');
% title('excess over HSA')
% xlim([-1,1])
% ylim([-1,1])
% xlabel('simulation')
% ylabel('data')
%%

function z = getcorrcoeff(x,y)
% GETCORRCOEFF Pearson correlation of x and y over complete (non-NaN) pairs.
%   z = getcorrcoeff(x,y) flattens x and y (arrays with the same number of
%   elements) to column vectors, drops every pair in which either value is
%   NaN, and returns the Pearson correlation of the remaining pairs.
%   Returns NaN when no complete pair remains.
    % Flatten so that row vectors / N-d arrays always reach corr as columns
    % (corr on two row vectors would return a matrix, not a scalar).
    x = x(:);
    y = y(:);
    % Require both values: a NaN in x alone would otherwise make corr NaN.
    keep = ~isnan(x) & ~isnan(y);
    if(~any(keep))
        z = NaN;
    else
        z = corr(x(keep),y(keep),'type','Pearson');
    end
end

function z = getrmse(x,y)
% GETRMSE Root-mean-square difference of x and y, ignoring NaN residuals.
%   Both inputs are flattened; NaN residuals are left out of the mean.
%   Returns NaN when every residual is NaN.
    residual = x(:) - y(:);
    z = sqrt(mean(residual.^2,'omitnan'));
end


function [] = plotDrugSurface(data,conc1,conc2)
% PLOTDRUGSURFACE Heat map of a 4x4 drug-combination response.
%   plotDrugSurface(data,conc1,conc2) draws data(i,j) -- the response at
%   concentration index i of the first drug (labels conc1) and index j of the
%   second drug (labels conc2) -- into the current axes. The first drug is on
%   the x-axis and the second on the y-axis; the colour range is fixed to [0,1].
    % imagesc maps rows to y and columns to x. Transpose so the first drug
    % (rows of data) ends up on the x-axis, where conc1 is labelled.
    imagesc(data.');
    caxis([0,1]);
    axis square
    set(gca,'XTick',1:4,'YTick',1:4,'XTickLabel',conc1,'YTickLabel',conc2);
end

function [xx,yy] = plotConcFun(sim,data,conc,drugs,mode,title)
% PLOTCONCFUN Plot per-concentration agreement between simulation and data.
%   [xx,yy] = plotConcFun(sim,data,conc,drugs,mode,title) takes the
%   NaN-ignoring median of sim and data over replicates (dim 5). For every
%   concentration index ic, CV fold icv (1:5) and drug id, it scores the
%   agreement across cell lines (dim 1). Entry size(data,2)+1 pools all drugs
%   and is labelled 'all'. The layout (cell line, drug, concentration, fold,
%   replicate) is assumed from the indexing in this file.
%
%   Inputs:
%     conc  - maps concentration indices to x-axis values.
%     drugs - drug names; must contain at least size(data,2) entries.
%     mode  - the score to compute; 'corr' is the Pearson correlation.
%     title - the figure title.
%
%   Outputs:
%     xx(id,ic,icv) - the concentration index ic.
%     yy(id,ic,icv) - the score.
    switch(mode)
        case 'corr'
            fun = @(x,y) getcorrcoeff(x,y);
            ylim = [-1,1];
        otherwise
            error('plotConcFun:unknownMode','Unsupported mode ''%s''.',mode);
    end

    nDrug = size(data,2);
    nConc = size(data,3);
    nCV = 5;
    % Colour labels: the drugs actually present, then 'all' for the pooled
    % entry. Appending 'all' to the full name list mislabelled the pooled
    % entry whenever drugs held more names than data has drugs.
    labels = drugs(1:nDrug);
    labels{end+1} = 'all';

    xx = zeros(nDrug+1,nConc,nCV);
    yy = NaN(nDrug+1,nConc,nCV);
    cc = zeros(nDrug+1,nConc,nCV);
    for ic = 1:nConc
        for icv = 1:nCV
            for id = 1:nDrug
                x = nanmedian(sim(:,id,ic,icv,:),5);
                y = nanmedian(data(:,id,ic,icv,:),5);
                xx(id,ic,icv) = ic;
                yy(id,ic,icv) = fun(x(~isnan(y)),y(~isnan(y)));
                cc(id,ic,icv) = id;
            end

            % Pooled entry over all drugs.
            x = nanmedian(sim(:,:,ic,icv,:),5);
            y = nanmedian(data(:,:,ic,icv,:),5);
            xx(nDrug+1,ic,icv) = ic;
            yy(nDrug+1,ic,icv) = fun(x(~isnan(y)),y(~isnan(y)));
            cc(nDrug+1,ic,icv) = nDrug+1;
        end
    end

    figure
    g = gramm('x',conc(xx(:)),'y',yy(:),'color',labels(cc(:)));
    g.set_names('x','drug concentration','y','correlation coefficient','color','drug');
    if(numel(unique(xx(:)))>1)
        g.stat_summary('type','ci','geom','area');
        g.axe_property('XScale','log','YLim',ylim);
    else
        g.stat_boxplot()
    end

    g.set_title(title);
    g.draw;
end

function par = getFunPars(conc,viability)
% GETFUNPARS Least-squares fit of the hillCurve parameters to viability data.
%   par = getFunPars(conc,viability) returns par = [Amin, log(EC50), log(hill)],
%   the parameter vector expected by hillCurve (whose top asymptote is fixed
%   to 1). It minimises sse(conc,viability,par) subject to:
%     min(viability) <= Amin <= max(viability)
%     -5 <= log(EC50) <= 5
%     -5 <= log(hill) <= 5
%   conc and viability must have the same number of elements; points whose
%   viability is NaN are ignored. Returns NaN(1,3) if no such point remains.
    % Row orientation for both, so the residual in sse is element-wise and
    % never an implicit-expansion matrix.
    conc = conc(:)';
    viability = viability(:)';

    % A single NaN would make the objective NaN everywhere.
    valid = ~isnan(viability);
    if(~any(valid))
        par = NaN(1,3);
        return
    end
    conc = conc(valid);
    viability = viability(valid);

    lb = [min(viability),-5,-5];
    ub = [max(viability), 5, 5];

    options = optimoptions('fmincon','Display','none');
    par = fmincon(@(p) sse(conc,viability,p),[0,-1,0],[],[],[],[],lb,ub,[],options);
end

function obj = sse(conc,viability,par)
% SSE Sum of squared residuals between viability and hillCurve(conc,par).
%   This is the objective minimised by getFunPars.
    residual = viability - hillCurve(conc,par);
    obj = sum(residual.^2);
end

function value = hillCurve(conc,par)
% HILLCURVE Hill dose-response curve with the top asymptote fixed at 1.
%   value = hillCurve(conc,par), with par = [Amin, log(EC50), log(hill)],
%   evaluates element-wise over conc:
%     1 + (Amin-1)./(1+(EC50./conc).^hill)
%   The result tends to 1 as conc -> 0 and to Amin as conc -> Inf. The
%   exponent exp(par(3)) is always positive.
    top = 1;
    bottom = par(1);
    ec50 = exp(par(2));
    slope = exp(par(3));
    value = top + (bottom - top) ./ (1 + (ec50 ./ conc).^slope);
end
