clear all
prep_figure
close all

%% Generalization benchmark: RMSE and Pearson correlation per data set
% For each comparison below, every entry of `sets` is scored on each of the
% five cross-validation folds (index 4 of the sim/data arrays) after taking
% a NaN-ignoring median over dimension 5. The per-fold scores are
%   * printed as mean +- std, together with a Bonferroni-adjusted paired
%     t-test p-value against the reference set sets{1}, and
%   * plotted as bars with error bars (gramm), then saved via makeFigure.
load('gdsc_from_ccle_sigmoid')

for comparison = {'mech','stat_test','stat_indep_test'}

    index = 1;
    values = [];   % metric value for each (set, metric, fold) triple
    modes = {};    % metric name ('rmse' or 'corr') for each entry of values
    types = {};    % set label for each entry of values

    switch(comparison{1})
        case 'mech'
            % Ten random-split replicates per split; the k-th occurrence of
            % a 'rand ...' label loads preprocessed_random_k.
            randtrain = repmat({'rand train'},[1,10]);
            randval = repmat({'rand val'},[1,10]);
            randtest = repmat({'rand test'},[1,10]);
            sets = {'training set','test set','indep. test set','GDSC_mech','GDSC_sigmoid',randtrain{:},randval{:},randtest{:}};
            irandtrain = 1;
            irandval = 1;
            irandtest = 1;
        case 'stat_test'
            sets = {'test set','lasso','lasso_int1','lasso_int2','lasso_int3','lasso_allint','glmgraph','GDSC-sigmoid'};
            load('preprocessed_stat')
        case 'stat_indep_test'
            sets = {'indep. test set','lasso','lasso_int1','lasso_int2','lasso_int3','lasso_allint','glmgraph','GDSC-sigmoid'};
            load('preprocessed_stat')
    end

    % Loop variables are named setName/metric (not set/mode) so they do not
    % shadow the MATLAB built-in functions of those names.
    for setName = sets
        % Pick the simulated (sim) and measured (data) arrays for this set.
        % Each load() pulls its variables straight into this workspace.
        switch(setName{1})
            case 'rand train'
                load(['preprocessed_random_' num2str(irandtrain)])
                sim = sim_fit;
                data = data_fit;
                irandtrain = irandtrain + 1;
            case 'rand val'
                load(['preprocessed_random_' num2str(irandval)])
                sim = sim_val;
                data = data_val;
                irandval = irandval + 1;
            case 'rand test'
                load(['preprocessed_random_' num2str(irandtest)])
                sim = sim_test;
                data = data_test;
                irandtest = irandtest + 1;
            case 'GDSC_mech'
                % Restrict to entries where gdsc_from_ccle_sigmoid is
                % defined, so this set is scored on the same entries as
                % 'GDSC_sigmoid'.
                load('preprocessed_GDSC')
                sim = sim_fit;
                sim(isnan(gdsc_from_ccle_sigmoid)) = NaN;
                data = data_fit;
                data(isnan(gdsc_from_ccle_sigmoid)) = NaN;
            case 'GDSC_sigmoid'
                % NOTE(review): data_fit is whatever the most recent load()
                % provided (preprocessed_GDSC when this follows 'GDSC_mech'
                % in `sets`) - this case depends on the order of `sets`.
                load('gdsc_from_ccle_sigmoid')
                sim = gdsc_from_ccle_sigmoid;
                data = data_fit;
                data(isnan(gdsc_from_ccle_sigmoid)) = NaN;
            case 'training set'
                load('preprocessed')
                sim = sim_fit;
                data = data_fit;
            case 'test set'
                load('preprocessed')
                sim = sim_val;
                data = data_val;
            case 'indep. test set'
                load('preprocessed')
                sim = sim_test;
                data = data_test;
            otherwise
                % Statistical models: sim_stat_val / sim_stat_test (from
                % preprocessed_stat) are looked up by model name, presumably
                % a containers.Map - confirm. The 'GDSC-sigmoid' entry reads
                % the 'randomforest' model.
                % NOTE(review): the plot label stays 'GDSC-sigmoid' while the
                % printed summary says 'randomforest' - confirm intended.
                switch(comparison{1})
                    case 'stat_test'
                        sim = sim_stat_val(strrep(setName{1},'GDSC-sigmoid','randomforest'));
                        data = data_val;
                    case 'stat_indep_test'
                        sim = sim_stat_test(strrep(setName{1},'GDSC-sigmoid','randomforest'));
                        data = data_test;
                end
        end

        for metric = {'rmse','corr'}
            % Choose the scoring function once per metric, not per fold.
            switch(metric{1})
                case 'rmse'
                    metricFn = @getrmse;
                case 'corr'
                    metricFn = @getcorrcoeff;
            end
            for icv = 1:5   % cross-validation folds along dimension 4
                % Collapse dimension 5 with a NaN-ignoring median.
                x = nanmedian(sim(:,:,:,icv,:),5);
                y = nanmedian(data(:,:,:,icv,:),5);
                try
                    values(index) = metricFn(x,y);
                catch err
                    % Best effort: keep going with NaN, but say why.
                    warning('figure_generalization:metricFailed', ...
                        'Could not compute %s for "%s" (fold %d): %s', ...
                        metric{1}, setName{1}, icv, err.message);
                    values(index) = NaN;
                end
                modes{index} = metric{1};
                types{index} = setName{1};
                index = index + 1;
            end
        end
    end

    % Text summary per distinct set. The p-value is a paired t-test against
    % the reference set sets{1}, Bonferroni-adjusted by the number of
    % DISTINCT sets compared with it. The repeated 'rand ...' labels of the
    % 'mech' comparison count once, not ten times. The adjusted p is capped
    % at 1.
    uniqueSets = unique(sets);
    nComparisons = numel(uniqueSets) - 1;
    for setName = uniqueSets
        disp(strrep(setName{1},'GDSC-sigmoid','randomforest'))
        for metric = {'corr','rmse'}
            isMetric = strcmp(modes,metric{1});
            ref = values(strcmp(types,sets{1}) & isMetric);
            vals = values(strcmp(types,setName{1}) & isMetric);
            if numel(ref) == numel(vals)
                % Entries are in fold order, so they pair up fold by fold.
                [~,p] = ttest(ref,vals);
            else
                % e.g. 10 pooled random replicates vs. one reference run
                p = NaN;
            end
            pAdj = p * nComparisons;
            pAdj(pAdj > 1) = 1;   % NaN > 1 is false, so NaN stays NaN
            disp([metric{1} ': ' num2str(nanmean(vals)) '+-' num2str(nanstd(vals)) '  p=' num2str(pAdj)])
        end
    end

    figure
    g = gramm('x',types,'y',values);
    g.facet_grid(modes,[],'scale','free_y');
    g.stat_summary('geom','bar');
    g.stat_summary('geom','errorbar');
    g.axe_property('XTickLabelRotation',90,'YGrid','on');
    if (strcmp(comparison{1},'mech'))
        g.set_order_options('x',{'training set','rand train','test set','rand val','indep. test set','rand test','GDSC_mech','GDSC_sigmoid'});
    end

    g.draw();
    % Assumes gramm orders facet rows alphabetically, so handle 1 is 'corr'
    % and handle 2 is 'rmse' - confirm if the facet ordering changes.
    g.facet_axes_handles(1).YLim(2) = 1;
    g.facet_axes_handles(1).YTick = 0:0.2:1;
    g.facet_axes_handles(2).YLim(2) = 0.3;
    g.facet_axes_handles(2).YTick = 0:0.05:0.3;
    switch(comparison{1})
        case  'mech'
            figWidth = 150;
        otherwise
            figWidth = 210;
    end
    makeFigure(0,0,2,figWidth,145,['figure_generalization_' comparison{1}],true);
end


function z = getcorrcoeff(x,y)
    % GETCORRCOEFF Pearson correlation between x and y over valid pairs.
    %   z = GETCORRCOEFF(x,y) flattens x and y (same number of elements) to
    %   column vectors and returns their Pearson correlation. The result uses
    %   only positions where BOTH x and y are non-NaN, matching getrmse,
    %   which also skips NaN pairs. Flattening first also makes row-vector
    %   and N-d inputs yield a scalar.
    z = corr(x(:), y(:), 'Type', 'Pearson', 'Rows', 'complete');
end

function z = getrmse(x,y)
    % GETRMSE Root-mean-square difference between x and y.
    %   z = GETRMSE(x,y) compares x and y element by element after
    %   flattening both to column vectors. Positions whose difference is
    %   NaN are left out of the mean.
    residual = x(:) - y(:);
    z = sqrt(mean(residual.^2, 'omitnan'));
end
