import { err, JsonDecoder, ok } from 'ts.data.json';
import { ModelInfo } from 'types/nn-types/ModelInfo';
import { FlatShape, IModelGraph, ModelGraphLink, ModelGraphNode } from 'types/nn-types/ModelGraph';
import { CheckpointInfo } from 'types/nn-types/CheckpointInfo';
import { ModelStats } from 'types/nn-types/ModelStats';
import { EntityType } from 'types/inspection-types/EntityType';
import ValueWithUnit from 'types/nn-types/ValueWithUnit';
import { $JsonDecoderErrors } from 'ts.data.json/dist/json-decoder';
import ModelSourceCode from 'types/nn-types/ModelSourceCode';
import * as mathjs from 'mathjs';
import { constructLayerInnardGraph } from 'tools/construct-layer-innard-graph';
import { DataRow, DataValue } from 'types/inspection-types/DataArray';
import { LayerwiseStatisticalDescriptors, StatisticalDescriptors } from 'types/nn-types/StatisticalDescriptors';

/** Decodes a JSON array in which every element must be a string. */
export const stringArrayDecoder: JsonDecoder.Decoder<string[]> = JsonDecoder.array<string>(
    JsonDecoder.string,
    'string[]',
);

/** ************************
 * (string | number)[]     *
 ************************* */

/**
 * Decodes a single JSON value that is either a string or a number.
 * Alternatives are tried in order: string first, then number.
 */
const stringOrNumberDecoder: JsonDecoder.Decoder<string | number> = JsonDecoder.oneOf<string | number>(
    [JsonDecoder.string, JsonDecoder.number],
    'string | number',
);

/** Decodes a flat JSON array whose elements are each either a string or a number. */
export const stringOrNumberArrayDecoder: JsonDecoder.Decoder<Array<string | number>> = JsonDecoder.array<
    string | number
>(stringOrNumberDecoder, 'Array<string|number>[]');

/** ****************
 * Dict<undefined> *
 ***************** */

/**
 * Decodes a JSON object into a string-keyed dictionary.
 *
 * Values are handed to `JsonDecoder.succeed`, which accepts any input unchanged, so the
 * `undefined` element type is nominal: whatever the JSON holds is passed through as-is.
 */
const undefinedDictDecoder = JsonDecoder.dictionary<undefined>(
    JsonDecoder.succeed,
    'Dict<undefined>',
);

/** *****
 * Date *
 ****** */

/**
 * Decodes a date string into a `Date`.
 *
 * The input must be a string that `Date.parse` can turn into a number (not `NaN`).
 * Any other value, including non-strings, fails with a primitive "Date" error.
 */
const dateDecoder: JsonDecoder.Decoder<Date> = new JsonDecoder.Decoder<Date>((json: unknown) => {
    // Guard clause: reject non-strings and strings the Date parser cannot handle.
    if (typeof json !== 'string' || isNaN(Date.parse(json))) {
        return err<Date>($JsonDecoderErrors.primitiveError(json, 'Date'));
    }

    return ok<Date>(new Date(json));
});

/** *******************************************
 * ValueWithUnit: { unit: string, value: number } *
 ******************************************** */

/** Decodes a measurement object of the shape `{ unit: string, value: number }`. Both fields are required. */
export const valueWithUnitDecoder: JsonDecoder.Decoder<ValueWithUnit> = JsonDecoder.object<ValueWithUnit>(
    { unit: JsonDecoder.string, value: JsonDecoder.number },
    'ValueWithUnit',
);

/** **********
 * ModelInfo *
 *********** */

/**
 * Decodes the metadata record of a single model.
 *
 * - `children` / `parents`: arrays of strings.
 * - `timestamp`: a date string, parsed into a `Date` (see `dateDecoder`).
 * - `classes`: an array whose entries may be strings or numbers.
 */
export const modelInfoDecoder: JsonDecoder.Decoder<ModelInfo> = JsonDecoder.object<ModelInfo>(
    {
        name: JsonDecoder.string,
        label: JsonDecoder.string,
        children: stringArrayDecoder,
        timestamp: dateDecoder,
        sourceFile: JsonDecoder.string,
        parents: stringArrayDecoder,
        classes: stringOrNumberArrayDecoder,
    },
    'ModelInfo',
);

/** ****************
 * ModelSourceCode *
 ***************** */

/** Decodes the source-code bundle of a model: three required string fields. */
export const modelSourceCodeDecoder: JsonDecoder.Decoder<ModelSourceCode> = JsonDecoder.object<ModelSourceCode>(
    {
        modelFunction: JsonDecoder.string,
        metadataFunction: JsonDecoder.string,
        fullSource: JsonDecoder.string,
    },
    'ModelSourceCode',
);

/** ***********
 * ModelStats *
 ************ */

/**
 * Decodes summary statistics of a model.
 *
 * The two counts are plain numbers. The three measurements each carry their own unit
 * and are decoded with `valueWithUnitDecoder`.
 */
export const modelStatsDecoder: JsonDecoder.Decoder<ModelStats> = JsonDecoder.object<ModelStats>(
    {
        // plain counts
        numTrainableParameters: JsonDecoder.number,
        numLayers: JsonDecoder.number,
        // measurements with units
        executionTimeTestset: valueWithUnitDecoder,
        memoryUsage: valueWithUnitDecoder,
        modelSaveSize: valueWithUnitDecoder,
    },
    'ModelStats',
);

/** ***********
 * ModelGraph *
 ************ */

/**
 * Decodes one dimension of a shape: a number or `null`.
 * NOTE(review): `null` presumably marks an unspecified dimension (e.g. a batch axis) — confirm.
 */
const numberOrNullDecoder: JsonDecoder.Decoder<number | null> = JsonDecoder.oneOf<number | null>(
    [JsonDecoder.number, JsonDecoder.isNull(null)],
    'number | null',
);

/** Decodes a one-level list of dimensions, e.g. `[null, 28, 28]`. */
const flatShapeDecoder: JsonDecoder.Decoder<Array<number | null>> = JsonDecoder.array<number | null>(
    numberOrNullDecoder,
    '(number | null)[]',
);

/** Decodes one entry of a shape: either a single dimension or a nested flat shape. */
const shapeEntryDecoder = JsonDecoder.oneOf<(number | null) | FlatShape>(
    [numberOrNullDecoder, flatShapeDecoder],
    '(number | null) | (number | null)[]',
);

/**
 * Decodes a layer shape. Because each entry may itself be a flat shape, both
 * `[null, 10]` and `[[null, 10], [null, 5]]` are accepted.
 */
const shapeDecoder = JsonDecoder.array<(number | null) | FlatShape>(shapeEntryDecoder, 'Shape');

/** Decodes a value that must be one of the `EntityType` enum's values; anything else is a decode error. */
const entityTypeDecoder: JsonDecoder.Decoder<EntityType> = JsonDecoder.enumeration<EntityType>(
    EntityType,
    'EntityType',
);

/** Element type of a node's `interestingness` field, used to type the per-node empty default. */
type NodeInterestingness = ModelGraphNode['interestingness'];

/**
 * Builds a decoder for a single node of a model graph.
 *
 * @param modelId - Id of the owning model. It is not read from the JSON; it is stamped onto every
 *   decoded node as `parentModelId`.
 * @returns A decoder that also runs `constructLayerInnardGraph` on each successfully decoded node
 *   before returning it. NOTE(review): that helper presumably fills in `innardGraph` — confirm.
 */
const modelGraphNodeDecoder = (modelId: string) =>
    JsonDecoder.object<ModelGraphNode>(
        {
            id: JsonDecoder.string,
            parentModelId: JsonDecoder.constant(modelId),
            kerasId: JsonDecoder.string,
            name: JsonDecoder.string,
            clsName: JsonDecoder.string,
            inputShape: shapeDecoder,
            outputShape: shapeDecoder,
            numParameter: JsonDecoder.number,
            config: undefinedDictDecoder,
            // Passes through whatever value the JSON holds for this key, unchanged.
            innardGraph: JsonDecoder.succeed,
            // Missing or unrecognised entity types fall back to MISC instead of failing the node.
            type: JsonDecoder.failover<EntityType>(EntityType.MISC, entityTypeDecoder),
            // Produce a NEW empty array on every decode. `JsonDecoder.constant([])` would return the
            // same array instance for every node of the graph, so mutating one node's list would
            // silently change all of them.
            interestingness: new JsonDecoder.Decoder<NodeInterestingness>(() => ok<NodeInterestingness>([])),
        },
        'ModelGraphNode',
    ).map((node: ModelGraphNode) => {
        constructLayerInnardGraph(node);
        return node;
    });

/**
 * Decodes one edge of the model graph, from `source` to `target`.
 * NOTE(review): both strings presumably reference node `id`s — confirm.
 */
const modelGraphLinkDecoder: JsonDecoder.Decoder<ModelGraphLink> = JsonDecoder.object<ModelGraphLink>(
    { source: JsonDecoder.string, target: JsonDecoder.string },
    'ModelGraphLink',
);

/**
 * Builds a decoder for a complete model graph: a `directed` flag plus node and link lists.
 *
 * @param modelId - Passed to the node decoder, which stamps it onto every node as `parentModelId`.
 */
export const imodelGraphDecoder = (modelId: string): JsonDecoder.Decoder<IModelGraph> => {
    const nodeListDecoder = JsonDecoder.array<ModelGraphNode>(modelGraphNodeDecoder(modelId), 'Node[]');
    const linkListDecoder = JsonDecoder.array<ModelGraphLink>(modelGraphLinkDecoder, 'Link[]');

    return JsonDecoder.object<IModelGraph>(
        {
            directed: JsonDecoder.boolean,
            nodes: nodeListDecoder,
            links: linkListDecoder,
        },
        'IModelGraph',
    );
};

/** ***************
 * CheckpointInfo *
 **************** */

/**
 * Decodes a checkpoint's performance statistics: a dictionary of numbers.
 * NOTE(review): keys are presumably metric names — confirm against the producer.
 */
const checkpointStatisticsDecoder = JsonDecoder.dictionary(JsonDecoder.number, 'CheckpointStatistics');

/**
 * Decodes a single checkpoint record.
 *
 * - `step`, `batch`, `epoch`: numbers.
 * - `performanceStatistics`: optional.
 * - `timestamp`: a date string, parsed into a `Date`.
 * - `filesize`: a value with unit.
 */
const checkpointInfoDecoder = JsonDecoder.object<CheckpointInfo>(
    {
        step: JsonDecoder.number,
        batch: JsonDecoder.number,
        epoch: JsonDecoder.number,
        performanceStatistics: JsonDecoder.optional(checkpointStatisticsDecoder),
        timestamp: dateDecoder,
        filesize: valueWithUnitDecoder,
    },
    // Decoder name shown in error messages; previously mislabelled 'ModelInfo' (copy-paste).
    'CheckpointInfo',
);

/** Decodes the list of all checkpoints of a model. */
export const checkpointInfosDecoder = JsonDecoder.array<CheckpointInfo>(checkpointInfoDecoder, 'CheckpointInfo[]');

/** **********
 * DataArray *
 *********** */

/**
 * Decodes a (possibly nested) JSON array into a mathjs `Matrix`.
 *
 * `mathjs.matrix` throws when the nested arrays do not form a valid matrix (e.g. ragged rows
 * with mismatching lengths). That exception is caught and turned into a decoder error, so a
 * malformed matrix is reported as an ordinary decode failure instead of escaping `decode()`.
 */
const matrixDecoder: JsonDecoder.Decoder<mathjs.Matrix> = new JsonDecoder.Decoder<mathjs.Matrix>((json: unknown) => {
    if (!Array.isArray(json)) {
        return err<mathjs.Matrix>($JsonDecoderErrors.primitiveError(json, 'mathjs.Matrix'));
    }

    try {
        return ok<mathjs.Matrix>(mathjs.matrix(json));
    } catch (e) {
        const reason = e instanceof Error ? e.message : String(e);
        return err<mathjs.Matrix>(`${$JsonDecoderErrors.primitiveError(json, 'mathjs.Matrix')} (${reason})`);
    }
});

/** Decodes one cell value: a plain number, or else a matrix (tried in that order). */
const matrixOrNumberDecoder = JsonDecoder.oneOf<DataValue>([JsonDecoder.number, matrixDecoder], 'DataValue');

/**
 * Decodes an array of rows. Each row is a string-keyed dictionary of numbers or matrices.
 * NOTE(review): row keys are presumably column names — confirm.
 */
export const dataArrayDecoder = JsonDecoder.array<DataRow>(
    JsonDecoder.dictionary(matrixOrNumberDecoder, 'DataRow'),
    'DataArray',
);

/** ****************
 * Interestingness *
 ***************** */

/**
 * Decodes the descriptive statistics computed for one variable. All nine fields are required numbers.
 * NOTE(review): the `kl_div_*` fields are presumably KL divergences against different references
 * (previous step, initialization, model mean) — confirm against the producer.
 */
const statisticalDescriptorsDecoder: JsonDecoder.Decoder<StatisticalDescriptors> = JsonDecoder.object<
    StatisticalDescriptors
>(
    {
        // divergence-based descriptors
        avg_kl_div_between_steps: JsonDecoder.number,
        kl_div_from_initialization: JsonDecoder.number,
        kl_div_from_model_mean: JsonDecoder.number,
        kl_div_from_prev_step: JsonDecoder.number,
        // distribution summary
        maximum: JsonDecoder.number,
        median: JsonDecoder.number,
        minimum: JsonDecoder.number,
        variance: JsonDecoder.number,
        skew: JsonDecoder.number,
    },
    'StatisticalDescriptor',
);

/** A descriptor slot that may be absent; when present it must be a complete descriptor object. */
const optionalStatisticalDescriptorsDecoder = JsonDecoder.optional(statisticalDescriptorsDecoder);

/**
 * Decodes the statistics bundle of one layer. Every slot (dense/conv2d kernel and bias,
 * activation) is optional, so a layer only carries the entries that apply to it.
 */
const layerwiseStatisticalDescriptorsDecoder = JsonDecoder.object<LayerwiseStatisticalDescriptors>(
    {
        dense_kernel: optionalStatisticalDescriptorsDecoder,
        dense_bias: optionalStatisticalDescriptorsDecoder,
        conv2d_kernel: optionalStatisticalDescriptorsDecoder,
        conv2d_bias: optionalStatisticalDescriptorsDecoder,
        activation: optionalStatisticalDescriptorsDecoder,
    },
    'LayerwiseInterestingness',
);

/**
 * Decodes the full statistics catalog: a string-keyed dictionary whose values are per-layer
 * descriptor bundles.
 * NOTE(review): keys are presumably layer ids — confirm against the producer.
 */
export const statisticalDescriptorsCatalogDecoder = JsonDecoder.dictionary<LayerwiseStatisticalDescriptors>(
    layerwiseStatisticalDescriptorsDecoder,
    'Record<InterestingnessVariableType, StatisticalDescriptor>',
);
