import React from 'react';
import { WidgetProps } from 'App/WidgetPanel/Widget';
import { AxisBottom, AxisLeft } from '@visx/axis';
import { scaleOrdinal } from '@visx/scale';
import ParentSize from '@visx/responsive/lib/components/ParentSize';
import { Group } from '@visx/group';
import { LinePath } from '@visx/shape';
import { curveBasis } from '@visx/curve';
import { LegendItem, LegendLabel, LegendOrdinal } from '@visx/legend';
import { InnerVisContainer, LegendContainer, OuterVisContainer } from 'App/WidgetPanel/Widgets/StyledContainers';
import { schemeLineStyles } from 'tools/d3-scheme-line-styles';
import { getXScale, getYScale } from 'App/WidgetPanel/Widgets/multi-time-scale';
import { produce } from 'immer';
import AxisLabels from 'App/WidgetPanel/Widgets/AxisLabels';
import _ from 'lodash';
import { getNumberFormatter } from 'tools/helpers';
import useLocalStorage from 'tools/hooks/useLocalStorage';

/**
 * Default spacing between the SVG border and the plot area, used when the caller passes no `margin`.
 * The plot is shifted by `left`/`top`, and all four values are subtracted from the measured size.
 */
const DEFAULT_MARGIN = {
    top: 10,
    right: 25,
    bottom: 35,
    left: 45,
};

/**
 * Plots several numeric attributes of the widget's data entities over training steps.
 *
 * Every key of the first entity's first data point except `step` is treated as an attribute and
 * drawn as its own line. Attributes are distinguished by line style, entities by stroke color.
 * Hovering a line or legend entry dims the other attributes. Clicking a legend entry toggles the
 * attribute's visibility. The selection is persisted in local storage per widget id so it
 * survives unmount/mount cycles.
 *
 * Renders nothing when the widget has no data entities.
 */
const SingleEntityMultiTimeWidget = ({ widgetDefinition, margin = DEFAULT_MARGIN }: WidgetProps) => {
    const [hoveredAttribute, setHoveredAttribute] = React.useState<string>();

    const { dataEntities } = widgetDefinition;
    const firstEntity = dataEntities[0];
    // Attribute names come from the first data point of the first entity.
    // Missing entities/data yield no attributes instead of a TypeError.
    const firstDatum = firstEntity?.data[0];
    const attributes = firstDatum ? Object.keys(firstDatum).filter((a) => a !== 'step') : [];

    // Put this into local storage, so it keeps constant over unmount-mount-cycles
    const [selectedAttributes, setSelectedAttributes] = useLocalStorage(
        `widget:${widgetDefinition.widgetId}/selectedAttributes`,
        attributes,
    );

    /**
     * Reconciles a persisted selection with the attributes present in the current data.
     * The stored list can outlive the data it was created for: unknown entries are dropped
     * (result keeps the order of `attributes`). If none of the current attributes remain
     * selected, all of them are shown, so the plot never ends up permanently empty.
     */
    const resolveVisibleAttributes = (stored: string[]): string[] => {
        const visible = attributes.filter((a) => stored.includes(a));
        return visible.length > 0 ? visible : [...attributes];
    };
    const visibleAttributes = resolveVisibleAttributes(selectedAttributes);

    // All hooks have run above, so bailing out here keeps the hook order stable.
    if (!firstEntity) {
        return null;
    }

    const unit = firstEntity.unit;
    // Legend line color used when there is only a single entity.
    const singleEntityColor = firstEntity.color;
    const tickFormat = getNumberFormatter(3);

    const linestyleScale = scaleOrdinal({
        domain: attributes,
        range: [...schemeLineStyles],
    });

    // While an attribute is hovered, every other attribute is drawn at half opacity.
    const getStrokeOpacity = (attr: string) => {
        if (hoveredAttribute) {
            return hoveredAttribute === attr ? 1 : 0.5;
        }

        return 1;
    };

    const onMouseOutHandler = () => {
        setHoveredAttribute(undefined);
    };

    return (
        <>
            <OuterVisContainer>
                <InnerVisContainer>
                    <ParentSize debounceTime={10}>
                        {({ width: visWidth, height: visHeight }) => {
                            const xMax = visWidth - margin.left - margin.right;
                            const yMax = visHeight - margin.top - margin.bottom;
                            const scaleX = getXScale(dataEntities, xMax, widgetDefinition.entityRanges);
                            const scaleY = getYScale(
                                dataEntities,
                                // If widget has its range defined from outside, use all attributes for scaling
                                widgetDefinition.entityRanges ? attributes : visibleAttributes,
                                widgetDefinition.scaleCreatorFn,
                                yMax,
                                widgetDefinition.entityRanges,
                            );

                            return (
                                <svg width={visWidth} height={visHeight} style={{ background: '#fff' }}>
                                    <Group left={margin.left} top={margin.top}>
                                        <AxisLeft tickFormat={tickFormat} scale={scaleY} numTicks={5} />
                                        <AxisBottom tickFormat={tickFormat} top={yMax} scale={scaleX} numTicks={5} />
                                        {dataEntities.map(({ data, color }, entityIndex) => (
                                            // One line per visible attribute and entity; keys are scoped per entity.
                                            <React.Fragment key={entityIndex}>
                                                {visibleAttributes.map((attr) => (
                                                    <LinePath
                                                        key={attr}
                                                        onMouseMove={() => setHoveredAttribute(attr)}
                                                        onMouseOut={onMouseOutHandler}
                                                        data={data}
                                                        curve={curveBasis}
                                                        x={(d) => scaleX(d['step'] as number) ?? 0}
                                                        y={(d) => scaleY(d[attr] as number) ?? 0}
                                                        stroke={color}
                                                        strokeWidth={2}
                                                        strokeOpacity={getStrokeOpacity(attr)}
                                                        style={{ ...linestyleScale(attr), fill: 'none' }}
                                                    />
                                                ))}
                                            </React.Fragment>
                                        ))}

                                        <AxisLabels
                                            labelX={'Training Step'}
                                            labelY={
                                                attributes.length === 1
                                                    ? _.startCase(attributes[0]) + (unit ? ` (${unit})` : '')
                                                    : undefined
                                            }
                                            xMax={xMax}
                                            yMax={yMax}
                                            markIncompatibleScales={widgetDefinition.entityRanges === undefined}
                                        />
                                    </Group>
                                </svg>
                            );
                        }}
                    </ParentSize>
                </InnerVisContainer>
            </OuterVisContainer>
            {attributes.length > 1 && (
                <LegendContainer>
                    <LegendOrdinal scale={linestyleScale} labelFormat={(label) => `${label}`}>
                        {(labels) => (
                            <div style={{ display: 'flex', flexDirection: 'row', flexWrap: 'wrap' }}>
                                {labels.map((label) => {
                                    const attr = label.text;
                                    const isVisible = visibleAttributes.includes(attr);

                                    const onMouseMoveHandler = () => {
                                        setHoveredAttribute(attr);
                                    };
                                    const onClickHandler = () => {
                                        setSelectedAttributes((prevState) =>
                                            // Start from the reconciled selection so stale entries cannot
                                            // distort the "all selected" / "none left" checks below.
                                            produce(resolveVisibleAttributes(prevState), (draftState) => {
                                                const index = draftState.indexOf(attr);

                                                if (index < 0) {
                                                    // If the attribute was hidden, show it
                                                    draftState.push(attr);
                                                    return draftState;
                                                }

                                                if (draftState.length === attributes.length) {
                                                    // If all attributes are selected, hide all others and only
                                                    // select the clicked one (draft untouched, so returning a new
                                                    // value is allowed by immer)
                                                    return [attr];
                                                }

                                                // Otherwise, hide the clicked one ...
                                                draftState.splice(index, 1);

                                                // ... and if there are no visible attributes left, unhide all of them
                                                if (draftState.length === 0) {
                                                    draftState.push(...attributes);
                                                }

                                                return draftState;
                                            }),
                                        );
                                    };

                                    const lineColor = !isVisible
                                        ? 'var(--gray)'
                                        : dataEntities.length > 1
                                        ? 'var(--gray-dark)'
                                        : singleEntityColor;
                                    const textColor = isVisible ? 'var(--dark)' : 'var(--gray)';

                                    return (
                                        <LegendItem
                                            key={attr}
                                            margin="0 5px"
                                            onMouseMove={onMouseMoveHandler}
                                            onMouseOut={onMouseOutHandler}
                                            onClick={onClickHandler}
                                        >
                                            <svg width={20} height={10}>
                                                <path
                                                    style={{
                                                        ...label.value,
                                                        stroke: lineColor,
                                                        strokeWidth: 2,
                                                        strokeOpacity: getStrokeOpacity(attr),
                                                    }}
                                                    d={'m 0,0 20,10'}
                                                />
                                            </svg>
                                            <LegendLabel
                                                style={{
                                                    margin: '0 6px 0 6px',
                                                    color: textColor,
                                                }}
                                            >
                                                {_.startCase(attr)}
                                            </LegendLabel>
                                        </LegendItem>
                                    );
                                })}
                            </div>
                        )}
                    </LegendOrdinal>
                </LegendContainer>
            )}
        </>
    );
};

export default SingleEntityMultiTimeWidget;
