import { RandomizationAttributes } from '@/components/Studies/Randomization/Report/Report.model';
import Chip from '@/components/UI/Chip/Chip';
import GroupLabel from '@/components/UI/GroupLabel';
import type { Column } from '@/components/UI/Table/TableComponent.model';
import {
  calculateAge,
  formatNumber,
  hasOwnProperty,
  roundToTwo,
  standardDeviation,
  standardErrorOfMean,
} from '@/helpers';
import { _flatten, _get, _isEmpty, _isNumber, _notNil } from '@/littledash';
import type { Animal } from '@/model/Animal.model';
import { ID } from '@/model/Common.model';
import { PresetCalculation } from '@/model/PresetCalculation.model';
import { DateUtils } from '@/utils/Date.utils';
import _set from 'lodash/set';
import kmeans, { ClusteringOutput } from 'node-kmeans';
import { TbInfoCircle } from 'react-icons/tb';
import { mean, median, medianAbsoluteDeviation, shuffle } from 'simple-statistics';
import { RandomizeAnimal } from './Randomize';
import type {
  RandomizationResult,
  RandomizationResultFlattened,
  RandomizationTableData,
  RandomizeMetric,
  RandomizeOptions,
  RandomizeState,
  ResultsTable,
  SubjectAttribute,
} from './Randomize.model';
import { calculatePValueAndAnova } from './Statistics';
import type { AnovaMetric as AnovaMetricProps, OneWayAnovaResults } from './Statistics.model';

/** Minimal metric shape needed for clustering: a `_get` path to the numeric value on a subject. */
type Metric = { accessor: string };

/**
 * Subject attributes that randomisation can stratify by.
 * `accessor` is used as a `_get` path on each subject.
 */
export const subjectAttrs: SubjectAttribute[] = [
  { id: 'sex', Header: 'Sex', accessor: 'sex' },
  { id: 'donor_id', Header: 'Donor ID', accessor: 'alt_ids.donor' },
  { id: 'dob', Header: 'Date of birth', accessor: 'dob' },
];

/** `subjectAttrs` indexed by `id` for direct lookup (a later duplicate id would overwrite an earlier one). */
export const mappedAttrs: Record<string, SubjectAttribute> = {};
for (const attribute of subjectAttrs) {
  mappedAttrs[attribute.id] = attribute;
}

/**
 * Splits subjects into cohorts, one per distinct value found at `attr.accessor`.
 *
 * Cohorts appear in order of first appearance of their value. Subjects keep their input order
 * within a cohort. When no attribute is selected (`attr` is empty), every subject lands in a
 * single cohort.
 *
 * @param attr the selected attribute column (may be empty)
 * @param subjects subjects to split
 * @returns one array of subjects per distinct attribute value
 */
export const constructCohorts = (attr: Column<RandomizationTableData>, subjects: Animal[]): Animal[][] => {
  if (_isEmpty(attr)) {
    return [subjects];
  }

  // Single pass grouping. Map iterates keys in insertion order, which preserves first-appearance
  // ordering. Map uses SameValueZero keys, so NaN values group together instead of vanishing.
  const cohortsByValue = new Map<unknown, Animal[]>();
  for (const subject of subjects) {
    const value = _get(subject, attr.accessor);
    const cohort = cohortsByValue.get(value);
    if (cohort) {
      cohort.push(subject);
    } else {
      cohortsByValue.set(value, [subject]);
    }
  }

  return Array.from(cohortsByValue.values());
};

/**
 * Builds one feature vector per subject for k-means. Each vector holds the value at every
 * metric's `accessor` path, in metric order.
 *
 * Only missing values (`null`/`undefined`) are skipped. A subject lacking a metric therefore
 * yields a shorter vector.
 *
 * @param subjects subjects to vectorise
 * @param options metrics whose `accessor` paths are read
 * @returns one vector per subject, in subject order
 */
export const constructVectors = (subjects: Animal[], options: Metric[]): number[][] =>
  subjects.map((subject) => {
    const vector: number[] = [];
    for (const metric of options) {
      const value = _get(subject, metric.accessor);
      // `_notNil` rather than truthiness: a measurement of 0 is real data and must stay in the
      // vector, otherwise vectors become misaligned across subjects.
      if (_notNil(value)) {
        vector.push(value);
      }
    }
    return vector;
  });

/**
 * Promise wrapper around `kmeans.clusterize`.
 *
 * `k` is capped at the number of vectors, since there cannot be more clusters than points.
 * The promise rejects with `'Could not cluster data'` (original error as `cause`) when no
 * result comes back.
 */
export const clusterData = (vectors: number[][], k = 4): Promise<ClusteringOutput[]> =>
  new Promise((resolve, reject) => {
    const clusterCount = Math.min(k, vectors.length);
    kmeans.clusterize(vectors, { k: clusterCount }, (err, res) => {
      if (!res) {
        reject(new Error('Could not cluster data', { cause: err }));
        return;
      }
      resolve(res);
    });
  });

/**
 * Prepares groups for per-cohort filling.
 *
 * Each returned group is a shallow copy with two extra fields:
 * - `initialAttrCapacity`: `floor(size / attrs)`, the number of subjects of each cohort the group
 *   takes in the first pass.
 * - `cohort_subject_ids`: `attrs` fresh, independent empty lists, one per cohort.
 *
 * @param groups groups to prepare (not mutated)
 * @param attrs number of cohorts (defaults to 1)
 */
export const groupsWithAttributeSizes = (groups: RandomizationResult[], attrs: number = 1): RandomizationResult[] =>
  groups.map((group) => ({
    ...group,
    initialAttrCapacity: Math.floor(group.size / attrs),
    cohort_subject_ids: [...Array(attrs)].map(() => []),
  }));

/**
 * K-means clusters one cohort by the given metrics.
 *
 * Cohorts with fewer than 4 subjects always use k = 2; otherwise `kSize` is used.
 *
 * @returns one array per cluster of cohort-local subject indices, shuffled within each cluster
 *   so the order in which subjects are taken from a cluster is random
 */
export const clusterCohort = async (cohort: Animal[], cohortLength: number, metrics: Metric[], kSize: number) => {
  const k = cohortLength < 4 ? 2 : kSize;
  const kmeansResults = await clusterData(constructVectors(cohort, metrics), k);
  return kmeansResults.map(({ clusterInd }) => shuffle(clusterInd));
};

/**
 * @randomiseSubjects
 * ———
 * Attributes: Sex, donor ID, strain
 * Metrics: Tumour volume, weight, glucose (these are what we cluster by)
 * Cohort: Subjects grouped by their unique selected attribute (a single cohort if no attribute is selected)
 *
 * Outcome:
 * ———
 * The goal of this function is to give every group a fair distribution of attributes and a variety of
 * weighted metric values.
 *
 * Steps:
 * ———
 * 1. For each cohort, k-means cluster the subjects' metric values.
 * 2. Starting from a random group, walk the groups round-robin, handing out subjects in cluster order
 *    until every group has reached its `initialAttrCapacity` for that cohort.
 * 3. Hand out whatever is left via `distrubuteRemainders` (groups visited in shuffled order) until the
 *    store is depleted or every group is full.
 *
 * With no metrics selected, clustering is skipped and `distributeByAttributeOnly` is used instead.
 *
 * Improvements:
 * ———
 * 1. Break down into smaller testable functions
 * 2. Have cohort_subject_ids come in as a set of empty arrays to populate
 * 3. How can we cluster 10+ metrics?
 * 4. Find a simpler way to do this in one pass instead of looping twice (main loop + distrubuteRemainders)
 *
 * @param subjects animals to randomise
 * @param state randomisation options (`groups`, `metrics`, `attr`, `kSize`)
 * @returns the groups, each with a flat `cohort_subject_ids` list of assigned subject ids
 **/

export const randomiseSubjects = async (
  subjects: Animal[],
  state: RandomizeOptions
): Promise<RandomizationResultFlattened[]> => {
  const { metrics, attr, kSize } = state;
  const cohorts = constructCohorts(attr, subjects);
  let groups = groupsWithAttributeSizes(state.groups, cohorts.length);
  // Subject ids still waiting for a group, one list per cohort; ids are removed as they are placed.
  const cohortStore = cohorts.map((c) => c.map((s) => s.id));

  if (metrics.length === 0) {
    return distributeByAttributeOnly(groups, cohortStore as number[][]);
  }
  if (groups.length === 0) {
    // Nothing to place subjects into; the round-robin below would index an empty array.
    return [];
  }

  // Random starting group, uniform over every index in [0, groups.length) — including the last group.
  let groupCount = Math.floor(Math.random() * groups.length);

  for (let index = 0; index < cohorts.length; index++) {
    const store = cohortStore[index];
    const clusters = await clusterCohort(cohorts[index], store.length, metrics, kSize);
    // Cohort-local subject indices, cluster by cluster.
    const flattenedClusters = _flatten(clusters) ?? [];
    const hasCapacity = (g: RandomizationResult) =>
      (g.initialAttrCapacity ?? 0) > g.cohort_subject_ids[index].length;

    // Round-robin over the groups. Each group under capacity takes the next clustered subject.
    // The loop stops once no group has capacity left; this also covers NaN capacities, which
    // would otherwise never compare equal and spin forever.
    while (store.length > 0 && flattenedClusters.length > 0 && groups.some(hasCapacity)) {
      // Use the indexed group directly; looking it up again by id picks the wrong group when ids collide.
      const group = groups[groupCount];
      if (hasCapacity(group)) {
        const subject = cohorts[index][flattenedClusters[0]];
        group.cohort_subject_ids[index].push(subject.id as number);
        flattenedClusters.shift();
        const storeIndex = store.findIndex((s) => s === subject.id);
        if (storeIndex !== -1) {
          store.splice(storeIndex, 1);
        }
      }

      groupCount = (groupCount + 1) % groups.length;
    }
  }

  // Subjects not placed by the capacity-limited pass above.
  if ((_flatten(cohortStore) ?? []).length > 0) {
    groups = distrubuteRemainders(groups, cohortStore);
  }

  return groups.map((g) => ({
    ...g,
    cohort_subject_ids: _flatten(g.cohort_subject_ids) as number[],
  }));
};

/**
 * Distributes subjects across groups by attribute alone (no metric clustering).
 *
 * 1. Groups are visited in a shuffled order. Each group takes its proportional share of every
 *    cohort, `floor(cohort.length * group.size / totalAnimals)`, provided the share still fits
 *    within `group.size`.
 * 2. Leftover subjects (largest leftover cohort first) are handed out one at a time, round-robin,
 *    to groups that still have space.
 *
 * Note: the `cohort_subject_ids` arrays of the passed `groups` are mutated during step 1.
 *
 * @param groups groups to fill; `cohort_subject_ids` holds per-cohort id lists
 * @param cohorts subject ids, one list per attribute value
 * @returns every group, in the original order, with a flat `cohort_subject_ids` list. Subjects
 *   that cannot fit because all groups are full remain unassigned.
 */
export const distributeByAttributeOnly = (
  groups: RandomizationResult[],
  cohorts: number[][]
): RandomizationResultFlattened[] => {
  const totalAnimals = cohorts.reduce((sum, cohort) => sum + cohort.length, 0);
  // `cohort_subject_ids` is nested per cohort, so its own length counts lists, not subjects.
  const assignedCount = (group: RandomizationResult) => (_flatten(group.cohort_subject_ids) ?? []).length;

  const remainingCohorts: number[][] = [];

  const shuffledGroups = shuffle(groups);
  cohorts.forEach((cohort) => {
    const remainingAnimals = [...cohort];
    shuffledGroups.forEach((group) => {
      // This group's proportional share of the cohort.
      const targetCount = Math.floor((cohort.length * group.size) / totalAnimals);

      // Only take the share if it fits in the group's remaining capacity (counted in subjects).
      if (targetCount + assignedCount(group) <= group.size) {
        group.cohort_subject_ids.push(remainingAnimals.splice(0, targetCount));
      }
    });

    // Anything not handed out proportionally goes to the backlog.
    if (remainingAnimals.length > 0) {
      remainingCohorts.push(remainingAnimals);
    }
  });

  const remainingCohortsSorted = remainingCohorts.sort((a, b) => b.length - a.length);

  // Every group is returned, including ones already filled by the proportional pass.
  const flattenedGroups = groups.map((g) => ({
    ...g,
    cohort_subject_ids: (_flatten(g.cohort_subject_ids) ?? []) as number[],
  }));
  const hasSpace = (g: { cohort_subject_ids: unknown[]; size: number }) => g.cohort_subject_ids.length < g.size;
  // Only groups with space take part in the leftover round-robin.
  const openGroups = flattenedGroups.filter(hasSpace);

  let currentGroupIndex = 0;

  remainingCohortsSorted.forEach((cohort) => {
    const remainingAnimals = [...cohort];

    // Assign each remaining animal to a group in order
    while (remainingAnimals.length > 0) {
      // Every group is full: stop instead of looping forever; the rest stay unassigned.
      if (!openGroups.some(hasSpace)) {
        break;
      }
      const currentGroup = openGroups[currentGroupIndex];
      if (hasSpace(currentGroup)) {
        currentGroup.cohort_subject_ids.push(...remainingAnimals.splice(0, 1));
      }
      currentGroupIndex = currentGroupIndex + 1 < openGroups.length ? currentGroupIndex + 1 : 0;
    }
  });

  return flattenedGroups;
};

/**
 * Places subjects still left in `cohortStore` into groups that have not reached `size`.
 *
 * Groups are visited in shuffled order and each is filled before moving on. Subjects are taken
 * one at a time from each cohort in turn (the cohort cursor carries over between groups), and
 * each id is appended to the matching per-cohort list in `cohort_subject_ids`.
 *
 * Both `cohortStore` (ids are removed) and the group objects are mutated in place. Work stops
 * once every cohort store is empty.
 *
 * @returns the same `groups` array (original order)
 */
export const distrubuteRemainders = (groups: RandomizationResult[], cohortStore: ID[][]) => {
  let cohortIndex = 0;
  const storeIsEmpty = () => (_flatten(cohortStore) ?? []).length === 0;

  for (const g of shuffle(groups)) {
    // `<` rather than `!==`: a group already at or over its size must not receive more subjects.
    while ((_flatten(g.cohort_subject_ids) ?? []).length < g.size && !storeIsEmpty()) {
      const store = cohortStore[cohortIndex];

      if (!_isEmpty(store)) {
        // No optional chaining here: silently skipping the push would drop the subject from the store unassigned.
        g.cohort_subject_ids[cohortIndex].push(store[0] as number);
        store.splice(0, 1);
      }

      cohortIndex = cohortIndex + 1 !== cohortStore.length ? cohortIndex + 1 : 0;
    }
  }

  return groups;
};

/**
 * True when every subject has a truthy value at `attr.accessor`. This is vacuously true when
 * there are no subjects.
 */
export const allAttrVariantsPresent = (attr: Partial<SubjectAttribute>, subjects: Animal[]) =>
  subjects.every((subject) => Boolean(_get(subject, attr.accessor)));

/**
 * Builds the results-table column definitions:
 * - group name and population;
 * - one column per non-exclusion metric, whose header shows the one-way ANOVA p-value and a link
 *   that opens the full ANOVA table;
 * - when an attribute is selected, a column rendering per-variant counts as chips.
 *
 * @param selectedAttr attribute the subjects were stratified by (may be empty)
 * @param selectedMetrics metrics to show; exclusion metrics get no column
 * @param pValues ANOVA results keyed by calculation id (the middle segment of the metric accessor)
 * @param openModal opens the ANOVA table modal
 * @param closeModal passed through to the modal
 */
const resultTableColumns = (
  selectedAttr: Column<RandomizationTableData>,
  selectedMetrics: Array<RandomizeMetric> = [],
  pValues: Record<string, OneWayAnovaResults> = {},
  openModal: (modal: any, props: any) => void,
  closeModal: () => void
) => {
  let columns: Column<RandomizationTableData>[] = [
    {
      id: 'group',
      Header: 'Group',
      accessor: 'name',
      Cell: ({ row: { original } }) => {
        if (hasOwnProperty(original, 'group')) {
          return <GroupLabel group={original.group} />;
        }

        return original.name;
      },
    },
    {
      id: 'population',
      Header: 'Population',
      accessor: 'size',
      align: 'right',
    },
  ];

  if (!_isEmpty(selectedMetrics)) {
    columns = [
      ...columns,
      ...selectedMetrics.reduce<Column<RandomizationTableData>[]>((acc, metric) => {
        if (!metric.isExclusion) {
          // Accessors look like `significantMeasurement.<calculationId>.<average>`; ANOVA results are keyed by the id.
          const oneWayAnova: OneWayAnovaResults | undefined = pValues[metric.accessor.split('.')[1]];
          // A missing ANOVA entry (e.g. every group was empty) shows "No P Value" instead of throwing.
          const pValue = oneWayAnova?.pvalue;
          acc.push({
            ...metric,
            Header: (
              <div>
                <div>{metric.Header}</div>
                <div className="dark-gray normal f6 flex flex-row">
                  {_notNil(pValue) && _isNumber(pValue) ? `P = ${pValue?.toFixed(4)}` : 'No P Value'}
                  {_notNil(oneWayAnova) && (
                    <a
                      className="flex items-center ml1"
                      onClick={() => {
                        openModal('ONEWAYANOVA_TABLE', {
                          oneWayAnova,
                          closeModal: closeModal,
                        });
                      }}
                    >
                      <TbInfoCircle />
                    </a>
                  )}
                </div>
              </div>
            ),
          });
        }
        return acc;
      }, []),
    ];
  }

  if (!_isEmpty(selectedAttr)) {
    columns.push({
      ...selectedAttr,
      width: 350,
      Cell: ({ row: { original } }) => {
        const attributes = _get(original, selectedAttr.accessor);
        // `typeof null === 'object'`, so rule null out before treating the value as a counts map.
        if (_notNil(attributes) && typeof attributes === 'object') {
          return (
            <>
              {Object.keys(attributes)
                .sort()
                .map((k, i) => {
                  const title =
                    selectedAttr?.id === RandomizationAttributes.dob
                      ? `${k} (${calculateAge(k, DateUtils.dateNow(), true)})`
                      : k;
                  return <Chip key={k} className={i !== 0 ? 'ml2' : ''} title={title} value={attributes[k]} />;
                })}
            </>
          );
        }
        return attributes || '';
      },
    });
  }

  return columns;
};

/**
 * `_get` path to a metric's aggregate field.
 * Defaults to the raw `value` field when `average` is null/undefined.
 */
const accessor = (m: string, average?: string) => {
  const field = average ?? 'value';
  return `significantMeasurement.${m}.${field}`;
};

/** `_get` path to a metric's raw `value` field. */
const valueAccessor = (m: string) => accessor(m);

/**
 * Table column definitions for the selected calculations. Each column reads the chosen average
 * at `significantMeasurement.<id>.<averageType>`.
 *
 * Animal sub-rows carry only the raw `value` (from PHP), which is used as a fallback. When a row
 * carries a deviation map (`sd` for `'mean'`, `mad` otherwise), the cell renders
 * `value ± deviation`.
 */
export const subjectMetrics = (calculations: PresetCalculation[], averageType = 'mean'): RandomizeMetric[] => {
  const deviationType = averageType === 'mean' ? 'sd' : 'mad';

  return calculations.map((calculation) => ({
    Header: calculation.name,
    id: calculation.id,
    accessor: accessor(calculation.id, averageType) as keyof Animal,
    align: 'right',
    Cell: ({ row: { original } }: { row: { original: RandomizationTableData } }) => {
      const averagePath = accessor(calculation.id, averageType);
      const rawValuePath = accessor(calculation.id, 'value');
      const formatted = formatNumber(_get(original, averagePath) ?? _get(original, rawValuePath), true);

      if (!Object.prototype.hasOwnProperty.call(original, deviationType)) {
        return formatted;
      }

      const deviation = _get(original[deviationType], averagePath);
      return (
        <>
          <span className="dib mr2">{formatted}</span>
          {_notNil(deviation) && <span>±{formatNumber(deviation, true)}</span>}
        </>
      );
    },
  }));
};

/**
 * Table column definitions showing each calculation's raw value at
 * `significantMeasurement.<id>.value`.
 *
 * When the row carries a deviation map, the cell renders `value ± deviation`. `averageType`
 * selects which map is read: `'mean'` → `sd`, anything else → `mad`.
 */
export const subjectValueMetrics = (calculations: PresetCalculation[], averageType = 'mean'): RandomizeMetric[] => {
  return calculations.map((m) => ({
    Header: m.name,
    id: m.id,
    accessor: valueAccessor(m.id) as keyof Animal,
    align: 'right',
    Cell: ({ row: { original } }: { row: { original: RandomizationTableData } }) => {
      const valuePath = valueAccessor(m.id);
      // A single read suffices: a fallback to `valueAccessor(m.id)` would read this very same path.
      const value = formatNumber(_get(original, valuePath), true);
      const deviationType = averageType === 'mean' ? 'sd' : 'mad';
      if (Object.prototype.hasOwnProperty.call(original, deviationType)) {
        const deviation = _get(original[deviationType], valuePath);
        return (
          <>
            <span className="dib mr2">{value}</span>
            {_notNil(deviation) && <span>±{formatNumber(deviation, true)}</span>}
          </>
        );
      }

      return value;
    },
  }));
};

/**
 * Attaches a `significantMeasurement` to every animal (shallow copies; the input animals are not mutated).
 *
 * - When `measurementsOnDate` is given, each animal's entry is looked up by `api_id`. The result
 *   is `null` if the animal has no entry.
 * - Otherwise the animal's `latestMeasurement` is used.
 */
export const animalsWithSignificantMeasurement = (
  animals: Animal[],
  measurementsOnDate: RandomizeState['randomizeByDate']['measurementsOnDate'] = null
): RandomizeAnimal[] => {
  const useDateMeasurements = _notNil(measurementsOnDate);

  return animals.map((animal: Animal) => {
    const significantMeasurement = useDateMeasurements
      ? (measurementsOnDate?.[animal.api_id ?? ''] ?? null)
      : animal.latestMeasurement;
    return { ...animal, significantMeasurement };
  });
};

/**
 * Builds the randomisation results table.
 *
 * Result groups with no assigned subjects are skipped. Every other group produces one row with:
 * - population: existing `animals_count` plus the newly assigned subjects;
 * - one sub-row per assigned subject;
 * - per-variant counts of the selected attribute;
 * - per-metric averages: mean with SD/SEM, or median with MAD.
 *
 * Each non-exclusion metric also gets a one-way ANOVA across the groups' samples.
 *
 * @param sd `'sd'` stores the standard deviation; any other value stores the standard error of the mean
 * @param averageType `'mean'` or `'median'`; any other value leaves the metric values unset
 * @returns the table rows (no holes), ANOVA results keyed by calculation id, and the column definitions
 */
export const resultsTableTransformer = (
  results: RandomizationResultFlattened[],
  subjects: Animal[] = [],
  selectedAttr: Column<RandomizationTableData>,
  selectedMetrics: RandomizeMetric[] = [],
  sd = 'sd',
  openModal: (modal: any, props: any) => void,
  closeModal: () => void,
  averageType = 'mean'
): ResultsTable => {
  // Per metric accessor: one array of sample values per (non-empty) group.
  const anovaMetrics: Record<string, AnovaMetricProps> = {};
  // Rows are appended rather than written at the result index, so skipped groups leave no holes.
  const data: RandomizationTableData[] = [];

  results.forEach((v) => {
    if (_isEmpty(v.cohort_subject_ids ?? [])) {
      return;
    }
    const groupedSubjects = subjects.filter((s) => v.cohort_subject_ids.includes(s.id as number));
    const group: ResultsTable['data'][number] = {
      group: v,
      name: v.name,
      size: (v.animals_count ?? 0) + v.cohort_subject_ids.length,
      sd: { significantMeasurement: {} },
      mad: { significantMeasurement: {} },
      sex: v.sex ?? {},
      significantMeasurement: {},
      subRows: groupedSubjects.map((s) => ({
        name: s.name,
        sex: s.sex === 'm' ? 'Male' : 'Female',
        dob: s.dob,
        alt_ids: s.alt_ids,
        significantMeasurement: s.significantMeasurement,
      })),
    };

    if (!_isEmpty(selectedAttr)) {
      // Number of sub-rows per attribute value, e.g. `{ Male: 3, Female: 2 }`. Built by mutating
      // one accumulator instead of re-spreading it on every step.
      const counts = group.subRows.reduce<Record<string, number>>((acc, row) => {
        const key = String(_get(row, selectedAttr.accessor));
        acc[key] = (acc[key] ?? 0) + 1;
        return acc;
      }, {});
      _set(group, selectedAttr.accessor as string, counts);
    }

    if (!_isEmpty(selectedMetrics)) {
      selectedMetrics.forEach((metric) => {
        const groupMetrics = group.subRows.map((s) => Number(_get(s, accessor(metric.id))));
        if (!metric.isExclusion) {
          if (_notNil(anovaMetrics[metric.accessor])) {
            anovaMetrics[metric.accessor].push(groupMetrics);
          } else {
            anovaMetrics[metric.accessor] = [groupMetrics];
          }
        }

        switch (averageType) {
          case 'mean': {
            const sdCalc = standardDeviation(groupMetrics);
            if (!isNaN(sdCalc)) {
              // Any `sd` other than 'sd' selects the standard error of the mean.
              _set(
                group.sd,
                metric.accessor,
                sd === 'sd' ? sdCalc : standardErrorOfMean(sdCalc, groupMetrics.length)
              );
            }
            const meanValue = groupMetrics.length >= 2 ? roundToTwo(mean(groupMetrics)) : groupMetrics[0];
            _set(group, metric.accessor, meanValue);
            break;
          }
          case 'median': {
            const madCalc = medianAbsoluteDeviation(groupMetrics);
            if (!isNaN(madCalc)) {
              _set(group.mad, metric.accessor, madCalc);
            }
            const medianValue = groupMetrics.length >= 2 ? roundToTwo(median(groupMetrics)) : groupMetrics[0];
            _set(group, metric.accessor, medianValue);
            break;
          }
        }
      });
    }

    data.push(group);
  });

  // Keyed by calculation id: the middle segment of `significantMeasurement.<id>.<average>`.
  const oneWayAnovaResults = Object.entries(anovaMetrics).reduce<Record<string, OneWayAnovaResults>>(
    (acc, [metric, anova]) => {
      acc[metric.split('.')[1]] = calculatePValueAndAnova(anova);
      return acc;
    },
    {}
  );

  return {
    data,
    oneWayAnovaResults,
    columns: resultTableColumns(selectedAttr, selectedMetrics, oneWayAnovaResults, openModal, closeModal),
  };
};

/**
 * True when no animal is missing a truthy `significantMeasurement[metricName]`.
 * This is vacuously true for an empty list.
 */
export const metricPresentAcrossAllAnimals = (animals: RandomizeAnimal[], metricName: string): boolean =>
  !animals.some((animal) => !animal.significantMeasurement?.[metricName]);

/**
 * Keeps only the preset calculations whose measurement is present (truthy) on every animal.
 * Only those calculations can be used to randomise this set of animals.
 */
export const enabledMetricsForRandomize = (
  animals: RandomizeAnimal[],
  metrics: PresetCalculation[]
): PresetCalculation[] => {
  const isPresentForAll = (metric: PresetCalculation) => metricPresentAcrossAllAnimals(animals, metric.id);
  return metrics.filter(isPresentForAll);
};
