import React, { FunctionComponent, memo, useContext } from 'react';
import { WidgetProps } from '../Widget';
import _ from 'lodash';
import { InnerVisContainer, OuterVisContainer } from './StyledContainers';
import { ParentSize } from '@visx/responsive';
import { Group } from '@visx/group';
import ClassSelectionContext from 'App/ClassSelectionContext';
import { interpolateYlGnBu } from 'd3-scale-chromatic';
import { scaleBand, scaleLog } from '@visx/scale';
import { getNumberFormatter } from 'tools/helpers';
import { AxisBottom, AxisLeft, AxisRight, AxisTop } from '@visx/axis';
import { Text } from '@visx/text';

/**
 * Spacing between the SVG edge and the matrix plot area.
 * Used when the widget receives no `margin` prop.
 */
const DEFAULT_MARGIN = {
    top: 25,
    right: 25,
    bottom: 25,
    left: 25,
};

/**
 * Log scale mapping a (shifted) cell value onto the lower 80% of the YlGnBu
 * colour ramp. Built once at module level rather than once per rendered cell.
 */
const colorLogScale = scaleLog<number>().range([0, 0.8]).domain([0.01, 1.01]);

/**
 * Colour for a confusion-matrix cell.
 *
 * The value is shifted by 0.01 so that 0 lands on the lower end of the log
 * domain [0.01, 1.01]; values in [0, 1] therefore map to ramp positions [0, 0.8].
 */
const scaleColor = (t: number): string => interpolateYlGnBu(colorLogScale(t + 0.01) ?? 0);

/**
 * Renders a confusion-matrix heat map for the classes currently selected in
 * `ClassSelectionContext`.
 *
 * Reads the rows of the first data entity of `widgetDefinition`. Each row is
 * expected to carry a `true` class, a `pred` class and a numeric `value`.
 * Columns are predicted classes and rows are true classes. Each cell is coloured
 * by its value and labelled with it (2 fraction digits).
 */
const ConfusionMatrixWidget: FunctionComponent<WidgetProps> = ({
    widgetDefinition,
    margin = DEFAULT_MARGIN,
}: WidgetProps) => {
    const { selectedClasses } = useContext(ClassSelectionContext);

    const entity = widgetDefinition.dataEntities[0];

    // Only keep cells whose true AND predicted class are both selected.
    const data = entity.data.filter(
        (dataRow) =>
            selectedClasses.includes(dataRow['true'] as number | string) &&
            selectedClasses.includes(dataRow['pred'] as number | string),
    );

    const numberFormatter = getNumberFormatter(2);
    // Axis tick values are band indices; translate them back to class labels.
    const tickFormatter = (v: number) => `${selectedClasses[v]}`;

    // Class label -> its position in selectedClasses, i.e. its band index on both axes.
    const idxClassReverseMapping: Record<number | string, number> = {};
    selectedClasses.forEach((c, i) => (idxClassReverseMapping[c] = i));

    return (
        <OuterVisContainer>
            <InnerVisContainer>
                <ParentSize debounceTime={10}>
                    {({ width: visWidth, height: visHeight }) => {
                        const xExtend = visWidth - margin.left - margin.right;
                        const yExtend = visHeight - margin.top - margin.bottom;

                        // The measured size can be 0 before the parent has been laid out,
                        // or smaller than the margins; there is no plot area to draw then.
                        if (xExtend <= 0 || yExtend <= 0) {
                            return null;
                        }

                        const classIndices = _.range(selectedClasses.length);
                        const scaleX = scaleBand<number>({ range: [0, xExtend], domain: classIndices });
                        const scaleY = scaleBand<number>({ range: [0, yExtend], domain: classIndices });
                        const cellWidth = scaleX.bandwidth();
                        const cellHeight = scaleY.bandwidth();

                        return (
                            <svg width={visWidth} height={visHeight} style={{ background: '#fff' }}>
                                <Group top={margin.top} left={margin.left}>
                                    {data.map((dataRow) => {
                                        // Column = predicted class, row = true class.
                                        const x = scaleX(idxClassReverseMapping[dataRow['pred'] as number]) ?? 0;
                                        const y = scaleY(idxClassReverseMapping[dataRow['true'] as number]) ?? 0;
                                        const value = dataRow['value'] as number;

                                        return (
                                            <Group key={`${dataRow['pred']}-${dataRow['true']}`}>
                                                <rect
                                                    x={x}
                                                    y={y}
                                                    width={cellWidth}
                                                    height={cellHeight}
                                                    fill={scaleColor(value)}
                                                />
                                                <Text
                                                    fontSize={12}
                                                    textAnchor={'middle'}
                                                    verticalAnchor={'middle'}
                                                    x={x + cellWidth / 2}
                                                    y={y + cellHeight / 2}
                                                    width={cellWidth}
                                                    scaleToFit={'shrink-only'}
                                                >
                                                    {numberFormatter(value)}
                                                </Text>
                                            </Group>
                                        );
                                    })}

                                    {/* Top/right axes only carry the axis titles; their ticks are blank.
                                        The top axis shares scaleX (predicted class), the right axis
                                        shares scaleY (true class). */}
                                    <AxisTop
                                        label={'Predicted Class'}
                                        scale={scaleX}
                                        labelOffset={-5}
                                        tickFormat={() => ''}
                                        tickLength={3}
                                    />
                                    <AxisBottom
                                        tickFormat={tickFormatter}
                                        top={yExtend}
                                        scale={scaleX}
                                        tickLength={3}
                                    />
                                    <AxisLeft tickFormat={tickFormatter} scale={scaleY} tickLength={3} />
                                    <AxisRight
                                        tickLength={3}
                                        label={'True Class'}
                                        labelOffset={8}
                                        tickFormat={() => ''}
                                        scale={scaleY}
                                        left={xExtend}
                                    />
                                </Group>
                            </svg>
                        );
                    }}
                </ParentSize>
            </InnerVisContainer>
        </OuterVisContainer>
    );
};

export default memo(ConfusionMatrixWidget);
