import {IDiffExpResults, VisData} from '../../../utils/database';
import * as Plotly from 'plotly.js';
import {Layout, Shape} from 'plotly.js';
import {VolcanoColorMap, VolcanoMark, VolcanoProps} from './index';

/**
 * Bucket differential-expression rows by their `comparison` field.
 *
 * Side effect: every row is annotated in place with `-log10(pvalue)` and
 * `-log10(padj)`, which later stages (e.g. `grouSigs`) read directly.
 *
 * @param data rows from every comparison, in any order
 * @returns map from comparison name to its rows, in input order
 */
export function groupByCompare(data: IDiffExpResults[]): {
    [key: string]: IDiffExpResults[];
} {
    const byComparison: {[k: string]: IDiffExpResults[]} = {};
    for (const row of data) {
        const key = row.comparison;
        if (!byComparison[key]) {
            byComparison[key] = [];
        }
        // Pre-compute the y-axis values used by the volcano plot.
        row['-log10(pvalue)'] = -Math.log10(row['pvalue']);
        row['-log10(padj)'] = -Math.log10(row['padj']);

        byComparison[key].push(row);
    }
    return byComparison;
}
/**
 * Build the per-panel axis layout entries, one x/y pair per comparison.
 *
 * Every axis gets a thin black mirrored frame with no grid. Every y axis
 * shares the same `[0, max]` range. Only the un-suffixed `yaxis` shows tick
 * labels; every numbered `yaxisN` hides them.
 *
 * NOTE(review): for the first panel, both `xaxis`/`yaxis` and
 * `xaxis1`/`yaxis1` are emitted. Confirm whether Plotly treats the
 * `1`-suffixed keys as aliases of the first axis or ignores them.
 *
 * @param max     upper bound of every y-axis range
 * @param compare comparison names; only the count is used
 * @returns layout fragment to spread into a Plotly layout
 */
export function handleAxes(max: number, compare: string[]) {
    // TODO: build the x/y templates as separate functions
    const frame = {
        linecolor: 'black',
        linewidth: 1,
        mirror: true,
        showgrid: false
    };
    const xTemplate = {...frame};
    const yTemplate = {range: [0, max], ...frame};
    const axes: {[key: string]: any} = {};

    compare.forEach((_, idx) => {
        const n = idx + 1;
        if (idx === 0) {
            axes['xaxis'] = {...xTemplate};
            axes['yaxis'] = {...yTemplate};
        }
        axes['xaxis' + n] = {...xTemplate};
        axes['yaxis' + n] = {...yTemplate, showticklabels: false};
    });
    return axes;
}

/**
 * Split one comparison's rows into the four volcano-plot categories.
 *
 * A row is significant when `d[pchoice] <= pcut` (inclusive). Significant rows
 * are then classified by log2 fold change against `log2(FCcut)`:
 *  - `Upregulated`:   log2foldchange >=  log2(FCcut)
 *  - `Downregulated`: log2foldchange <= -log2(FCcut)
 *  - `Normal`:        strictly between the two cut-offs
 *
 * Rows with `d[pchoice] > pcut` are `Not Significant`. Some rows are left out
 * of every group, because their values do not compare as numbers (e.g. NaN):
 *  - rows whose p-value is not comparable;
 *  - significant rows whose fold change is not comparable.
 *
 * Each group holds the raw rows (`extra`) plus parallel per-field arrays. The
 * `-log10(pvalue)` / `-log10(padj)` fields are read as-is, so they must already
 * be filled in (e.g. by `groupByCompare`).
 *
 * @param data    rows for a single comparison
 * @param FCcut   fold-change cut-off on the linear scale (converted with log2)
 * @param pchoice which p-value column decides significance
 * @param pcut    significance threshold
 * @returns groups keyed `Upregulated`, `Downregulated`, `Normal`, `Not Significant`
 */
export const grouSigs = (data: IDiffExpResults[], FCcut: number, pchoice: 'pvalue' | 'padj', pcut: number) => {
    const LFCcut = Math.log2(FCcut);

    // Fresh arrays per group so the buckets never share state.
    const emptyGroup = () => ({
        extra: [] as IDiffExpResults[],
        log2foldchangeList: [] as number[],
        log10pvalueList: [] as number[],
        log10padjList: [] as number[],
        basemeanList: [] as number[]
    });
    // Insertion order here is the trace order used by callers.
    const sigGroup: {[key: string]: any} = {
        Upregulated: emptyGroup(),
        Downregulated: emptyGroup(),
        Normal: emptyGroup(),
        'Not Significant': emptyGroup()
    };

    /** Name of the group a row belongs to, or undefined to drop the row. */
    const classify = (d: IDiffExpResults): string | undefined => {
        const lfc = d['log2foldchange'];
        const p = d[pchoice];
        if (p > pcut) return 'Not Significant';
        // Neither `> pcut` nor `<= pcut`: the p-value is NaN / not comparable.
        if (!(p <= pcut)) return undefined;
        if (lfc >= LFCcut) return 'Upregulated';
        if (lfc <= -LFCcut) return 'Downregulated';
        if (lfc > -LFCcut && lfc < LFCcut) return 'Normal';
        return undefined;
    };

    data.forEach(d => {
        const key = classify(d);
        if (key === undefined) return;
        const group = sigGroup[key];
        group.extra.push(d);
        group.log2foldchangeList.push(d.log2foldchange);
        group.log10pvalueList.push(d['-log10(pvalue)']);
        group.log10padjList.push(d['-log10(padj)']);
        group.basemeanList.push(d['basemean']);
    });
    return sigGroup;
};

/**
 * Build one Plotly scatter trace per (comparison, significance group) and
 * compute the overall log2-fold-change extent.
 *
 * Comparisons missing from `data` are skipped. Panels after the first are
 * bound to axes `x{i+1}`/`y{i+1}` and hidden from the legend, so each group
 * appears in the legend once.
 *
 * @param data       rows keyed by comparison name (presumably the output of
 *                   `groupByCompare` — confirm against callers)
 * @param compare    comparison names, one subplot panel each
 * @param pchoice    p-value column used for significance and for the y values
 * @param pcut       significance threshold
 * @param FCcut      fold-change cut-off (linear scale)
 * @param markSize   marker size for every trace
 * @param colorMap   colour per group name
 * @returns `figure`: the traces. `range`: `[min, max]` of all finite
 *          log2 fold changes, always including 0, so it is `[0, 0]`
 *          when there is no data.
 */
export function getData(
    data: any,
    compare: string[],
    pchoice: 'pvalue' | 'padj',
    pcut: number,
    FCcut: number,
    markSize: number,
    colorMap: VolcanoColorMap
) {
    // Start at 0 so the extent always includes the origin. Tracking min/max
    // directly avoids pushing `undefined` (empty group) or NaN into a sort,
    // which used to corrupt the result.
    let minX = 0;
    let maxX = 0;
    const volc: Plotly.Data[] = [];

    compare.forEach((c, i) => {
        const cData = data[c];
        if (!cData) return;
        const groups = grouSigs(cData, FCcut, pchoice, pcut);

        // The first panel uses the default axes; later panels get their own.
        const ax: {[key: string]: any} = {};
        if (i > 0) {
            ax[`xaxis`] = `x${i + 1}`;
            ax[`yaxis`] = `y${i + 1}`;
            ax['showlegend'] = false;
        }

        Object.keys(groups).forEach(g => {
            const plot = groups[g];

            for (const raw of plot.log2foldchangeList) {
                const lfc = Number(raw);
                if (!Number.isFinite(lfc)) continue;
                if (lfc < minX) minX = lfc;
                if (lfc > maxX) maxX = lfc;
            }

            const volcano: Plotly.Data = {
                x: plot.log2foldchangeList,
                y: plot[`log10${pchoice}List`],
                ...ax,
                customdata: plot.extra,
                name: g,
                mode: 'markers',
                type: 'scattergl',
                legendgroup: g,
                marker: {
                    color: colorMap[g as VolcanoMark],
                    size: markSize
                }
            };
            volc.push(volcano);
        });
    });

    return {
        figure: volc,
        range: [minX, maxX]
    };
}
/**
 * Build the dashed threshold shapes for every panel.
 *
 * Each panel gets three shapes:
 *  1. A black horizontal line at `-log10(pcut)`, spanning the x extent
 *     `range` padded by 2 on each side.
 *  2. A red vertical line at `-log2(FCcut)`, spanning the full paper height.
 *  3. A red vertical line at `+log2(FCcut)`, spanning the full paper height.
 *
 * @param FCcut   fold-change cut-off (linear scale)
 * @param pcut    p-value cut-off
 * @param compare comparison names; only the count is used
 * @param range   `[min, max]` x extent to cover
 * @returns one array of three shapes per panel
 */
function buildLines(FCcut: number, pcut: number, compare: string[], range: number[]) {
    const LFCcut = Math.log2(FCcut);
    const pLevel = -Math.log10(pcut);
    const dashed = (color: string) => ({
        color: color,
        width: 1,
        dash: 'dash'
    });

    return compare.map((_, idx) => {
        const panel = idx + 1;
        const xref = 'x' + panel;

        const significanceLine = {
            type: 'line',
            yref: 'y' + panel,
            xref: xref,
            x0: range[0] - 2,
            x1: range[1] + 2,
            y0: pLevel,
            y1: pLevel,
            line: dashed('rgb(0, 0, 0)')
        };
        const foldChangeLine = (x: number) => ({
            type: 'line',
            yref: 'paper',
            xref: xref,
            x0: x,
            x1: x,
            y0: 0,
            y1: 1,
            line: dashed('rgb(255, 0, 0)')
        });

        return [significanceLine, foldChangeLine(-LFCcut), foldChangeLine(LFCcut)];
    });
}

/**
 * Assemble the Plotly layout.
 *
 * The layout contains:
 *  - a one-row grid with one independent panel per comparison;
 *  - the axis definitions from `handleAxes`;
 *  - the threshold shapes from `buildLines`, flattened into one list.
 *
 * @param FCcut   fold-change cut-off (linear scale)
 * @param pcut    p-value cut-off
 * @param max     upper bound of every y-axis range
 * @param compare comparison names, one panel each
 * @param range   `[min, max]` x extent covered by the horizontal threshold line
 */
export function buildLayout(FCcut: number, pcut: number, max: number, compare: string[], range: number[]): Partial<Layout> {
    const axes = handleAxes(max, compare);
    const shapesPerPanel = buildLines(FCcut, pcut, compare, range);
    // Flatten [[panel-1 shapes], [panel-2 shapes], ...] into a single list.
    const shapes = shapesPerPanel.reduce((flat, panelShapes) => flat.concat(panelShapes), []) as Partial<Shape>[];
    return {
        grid: {
            rows: 1,
            columns: compare.length,
            pattern: 'independent'
        },
        ...axes,
        shapes
    };
}

/**
 * Entry point: turn the differential-expression results and the volcano
 * settings into Plotly traces plus a matching layout.
 *
 * The x extent computed while building the traces is passed on to the
 * layout, so the significance line spans the plotted data.
 *
 * @returns `{layout, traces}` ready to hand to Plotly
 */
export function getVolcanoFigure({data, compare, pcut, FCcut, max, markSize, pchoice, colorMap}: VolcanoProps & {data: VisData}) {
    const results = data.diffExpResults.data;
    const built = getData(results, compare, pchoice, pcut, FCcut, markSize, colorMap);
    return {
        layout: buildLayout(FCcut, pcut, max, compare, built.range),
        traces: built.figure
    };
}
