import React, {
  forwardRef,
  PropsWithChildren,
  ReactElement,
  Ref,
  useEffect,
  useImperativeHandle,
  useRef,
  useState,
} from 'react';
import {
  GridChildComponentProps,
  GridOnScrollProps,
  VariableSizeGrid as Grid,
} from 'react-window';
import { Table } from 'antd';
import classNames from 'classnames';
import ResizeObserver from 'rc-resize-observer';
import { ScrollToInitRef } from '@models/types';
import {
  ResetAfterIndicesOptions,
  ResizeObserverState,
  TableBody,
  VirtualTableColumn,
  VirtualTableProps,
} from './models';

import './styles.scss';

/**
 * Extra body height (px) added when `hasLongCellData` is set — presumably room
 * for a horizontal scrollbar; NOTE(review): confirm against styles.scss.
 */
const LONG_CELL_SCROLLBAR_SIZE = 16;

/**
 * Antd `Table` whose body is replaced by a react-window `VariableSizeGrid`, so
 * the rows of `dataSource` are rendered through a virtualized grid.
 *
 * - Body height is `rowHeight` per row, capped at `maxViewRow` rows (plus
 *   `LONG_CELL_SCROLLBAR_SIZE` when `hasLongCellData` is set).
 * - Columns with an explicit `width` keep it. The remaining table width (the
 *   whole measured width when `hasLongCellData` is set) is split evenly among
 *   columns without one. With `hasLongCellData`, any leftover width is spread
 *   over all columns.
 * - When the rows are taller than the body, some body column widths are reduced
 *   by `scrollbarSize + 1` (presumably so the body lines up with the header).
 * - `verticalScrollCorrection` is added to the last header column when there
 *   are more rows than `maxViewRow`.
 *
 * The forwarded ref exposes `scrollToInit()`, which scrolls the body grid back
 * to its top-left corner.
 */
export const VirtualTable = forwardRef(
  (
    {
      columns,
      dataSource = [],
      maxViewRow,
      scroll,
      rowHeight = 54,
      hasLongCellData = false,
      verticalScrollCorrection = 0,
    }: PropsWithChildren<VirtualTableProps>,
    ref: Ref<ScrollToInitRef>,
  ) => {
    const gridRef = useRef<Grid>(null);

    // Measured by the ResizeObserver below; 0 until the first resize callback.
    const [tableWidth, setTableWidth] = useState(0);

    const dataCount = dataSource?.length ?? 0;

    // Derived during render (not stored in state) so the very first render
    // already gets the correct body height.
    const visibleRowCount = dataCount > maxViewRow ? maxViewRow : dataCount;
    const tableHeight =
      visibleRowCount * rowHeight +
      (hasLongCellData ? LONG_CELL_SCROLLBAR_SIZE : 0);

    useImperativeHandle(
      ref,
      () => ({
        scrollToInit: (): void => {
          gridRef.current?.scrollTo({ scrollLeft: 0, scrollTop: 0 });
        },
      }),
      [],
    );

    // VariableSizeGrid caches measured column widths and row heights, so the
    // cache must be dropped whenever an input of `columnWidth` / `rowHeight`
    // changes. `tableHeight` and `dataCount` are included because they decide
    // whether the rows overflow the body, which changes some column widths.
    useEffect(() => {
      gridRef.current?.resetAfterIndices({
        columnIndex: 0,
        rowIndex: 0,
        shouldForceUpdate: true,
      } as ResetAfterIndicesOptions);
    }, [tableWidth, tableHeight, dataCount, rowHeight, hasLongCellData, columns]);

    const columnsWithoutWidth = columns.filter(
      ({ width }: VirtualTableColumn) => !width,
    ).length;

    const sumWidth = columns.reduce(
      (acc: number, { width }: VirtualTableColumn) => acc + (width || 0),
      0,
    );

    // Width shared among the columns that have no explicit width.
    const currentTableWidth = hasLongCellData
      ? tableWidth
      : tableWidth - sumWidth;

    // `|| 1` avoids dividing by zero (Infinity) when every column has a width.
    const widthColumnsWithoutWidth = Math.floor(
      currentTableWidth / (columnsWithoutWidth || 1),
    );

    const tableWidthMoreColumnsWidth =
      currentTableWidth >
      widthColumnsWithoutWidth * columnsWithoutWidth + sumWidth;

    // With long cell data, spread any width left over after sizing every
    // column evenly across all columns (never negative).
    const additionalWidth =
      hasLongCellData && columns.length
        ? Math.max(
            Math.floor(
              (currentTableWidth -
                (sumWidth + widthColumnsWithoutWidth * columnsWithoutWidth)) /
                columns.length,
            ),
            0,
          )
        : 0;

    const currentColumns = columns.map((column: VirtualTableColumn) => ({
      ...column,
      width: (column.width || widthColumnsWithoutWidth) + additionalWidth,
    }));

    const xScrollAuto =
      hasLongCellData && !dataCount && !tableWidthMoreColumnsWidth;

    const renderVirtualList = (
      rawData: ReadonlyArray<any>,
      { scrollbarSize, onScroll }: TableBody,
    ): React.ReactNode => {
      const totalHeight = rawData.length * rowHeight;
      // Optional chaining so an empty `columns` array does not throw.
      const lastColumnHasWidth = !!columns[columns.length - 1]?.width;
      const deductedValue = scrollbarSize + 1;
      // True when the rows are taller than the body height.
      const rowsOverflowBody = totalHeight > tableHeight;
      const lastColumnIndex = currentColumns.length - 1;

      return (
        <Grid
          className="prov-virtual-table__body"
          ref={gridRef}
          columnCount={currentColumns.length}
          columnWidth={(index: number): number => {
            const { width } = currentColumns[index];

            if (!rowsOverflowBody) {
              return width;
            }

            const isLastColumn = index === lastColumnIndex;

            // A last column without explicit width absorbs the whole deduction.
            if (isLastColumn && !lastColumnHasWidth) {
              return width - deductedValue;
            }

            // Otherwise the deduction is shared among the other columns.
            if (!isLastColumn && lastColumnHasWidth) {
              const share = columnsWithoutWidth
                ? Math.floor(deductedValue / columnsWithoutWidth)
                : Math.round(deductedValue / currentColumns.length / 10) * 10;

              return width - share;
            }

            return width;
          }}
          height={tableHeight}
          rowCount={rawData.length}
          rowHeight={(): number => rowHeight}
          width={tableWidth}
          onScroll={({ scrollLeft }: GridOnScrollProps): void => {
            onScroll({ scrollLeft });
          }}
        >
          {({
            columnIndex,
            rowIndex,
            style,
          }: GridChildComponentProps): ReactElement => {
            const { render, dataIndex } = currentColumns[columnIndex];
            const record = rawData[rowIndex];

            return (
              <div
                className={classNames(
                  'vt-cell',
                  // rowIndex is 0-based: the first row (index 0) is "odd".
                  rowIndex % 2 === 0 ? 'vt-cell--odd-row' : 'vt-cell--even-row',
                )}
                style={style}
              >
                {render
                  ? render(
                      dataIndex ? record[dataIndex] : record,
                      record,
                      rowIndex,
                    )
                  : record[dataIndex || '']}
              </div>
            );
          }}
        </Grid>
      );
    };

    const widthCorrection =
      dataCount > maxViewRow ? verticalScrollCorrection : 0;

    // Header columns: same widths as the body, with the correction applied to
    // the last column only.
    const headerColumns = currentColumns.map(
      (column: VirtualTableColumn, index: number) => ({
        ...column,
        width:
          index === currentColumns.length - 1
            ? (column?.width || 0) + widthCorrection
            : column?.width,
      }),
    );

    return (
      <ResizeObserver
        onResize={({ width }: ResizeObserverState): void => {
          setTableWidth(width);
        }}
      >
        <div className="prov-virtual-table">
          <Table
            columns={headerColumns}
            dataSource={dataSource}
            pagination={false}
            components={{
              body: renderVirtualList,
            }}
            scroll={
              scroll || {
                y: maxViewRow * rowHeight,
                ...(xScrollAuto ? { x: 'auto' } : {}),
              }
            }
          />
        </div>
      </ResizeObserver>
    );
  },
);

VirtualTable.displayName = 'VirtualTable';
