import { Table, Thead, Tbody, Tr, Th, Td } from '@chakra-ui/react'
import styles from './pipelineAnalytics.module.css'
import { useMemo, useState } from 'react'
import {
  CategoryScale,
  LinearScale,
  Chart as ChartJS,
  Legend,
  BarElement,
  Title,
  Tooltip,
  PointElement,
  LineElement,
} from 'chart.js'
import { getColorByIndex } from '../../lib/colors'
import { Line } from 'react-chartjs-2'
import { cohortNameFromDate, renderCohortName } from './Utilities'
import { OpportunityHistory } from './Types'

/** Number of month offsets after creation tracked per cohort ("Month 0" … "Month 11"). */
const MONTHS_TO_ANALYZE = 12
/** Bucket granularity handed to `cohortNameFromDate` / `renderCohortName`. */
const COHORT_BUCKET = 'month' as const

/** Win and loss shares of a cohort, both expressed in percent. */
interface CohortRates {
  /** Percentage of the cohort's opportunities closed as won. */
  winRate: number
  /** Percentage of the cohort's opportunities closed as lost. */
  lossRate: number
}

/**
 * Converts a `YYYY-MM` cohort key into a Date at LOCAL midnight on the 1st of
 * that month.
 *
 * Why not `new Date(`${cohort}-01`)`: a date-only ISO string is parsed as UTC,
 * while every caller does month arithmetic with local-time `getMonth`/`setMonth`.
 * West of UTC that landed on the last day of the previous month, which shifted
 * the "is this month in the future" cut-off and let `setMonth` overflow
 * (e.g. Jan 31 + 1 month -> Mar 2/3).
 *
 * Malformed keys yield an Invalid Date, as the string-parsing version did.
 */
const cohortToDate = (cohort: string): Date => {
  const [year, month] = cohort.split('-').map(Number)
  return new Date(Number(year), Number(month) - 1, 1)
}
/**
 * Reports whether the moment `monthsToAdd` calendar months after `baseDate`
 * is still ahead of the current time. `baseDate` is left untouched.
 */
const isMonthInFuture = (baseDate: Date, monthsToAdd: number): boolean => {
  const shifted = new Date(baseDate.getTime())
  shifted.setMonth(shifted.getMonth() + monthsToAdd)
  const now = new Date()
  return now < shifted
}

/**
 * Turns raw win/loss counts into percentages of `total`.
 * A zero (or otherwise falsy) total gives 0% for both instead of NaN/Infinity.
 */
const calculateRates = (
  wins: number,
  losses: number,
  total: number,
): CohortRates => {
  const toPercent = (count: number) => (count / total) * 100
  return total
    ? { winRate: toPercent(wins), lossRate: toPercent(losses) }
    : { winRate: 0, lossRate: 0 }
}

/**
 * Builds, per cohort, a MONTHS_TO_ANALYZE-long series of win/loss rates.
 * Entry `i` is `null` while month `i` of that cohort has not started yet;
 * otherwise it holds `calculateRates(wins, losses, cohortSize)` for that month.
 * Missing per-month counts are treated as 0.
 */
const calculateCohortRates = (
  cohortsFound: string[],
  winCounts: Record<string, Record<number, number>>,
  lossCounts: Record<string, Record<number, number>>,
  cohortSizes: Record<string, number>,
) => {
  const result: Record<string, (CohortRates | null)[]> = {}

  for (const cohort of cohortsFound) {
    const start = cohortToDate(cohort)
    const size = cohortSizes[cohort]
    const winsByMonth = winCounts[cohort] || {}
    const lossesByMonth = lossCounts[cohort] || {}

    result[cohort] = Array.from({ length: MONTHS_TO_ANALYZE }, (_, month) =>
      isMonthInFuture(start, month)
        ? null
        : calculateRates(
            winsByMonth[month] || 0,
            lossesByMonth[month] || 0,
            size,
          ),
    )
  }

  return result
}

/**
 * Shapes per-cohort rates into react-chartjs-2 line data: one x-axis label per
 * analysed month and one win-rate dataset per cohort (colour chosen by the
 * cohort's position in `cohortsFound`).
 *
 * Trailing months without a value (`null`, i.e. not reached yet) are dropped
 * so each cohort's line ends at its latest observed month.
 */
const getChartData = (
  cohortsFound: string[],
  cohortRates: Record<string, (CohortRates | null)[]>,
) => {
  const labels = Array.from(
    { length: MONTHS_TO_ANALYZE },
    (_, month) => `Month ${month}`,
  )

  const datasets = cohortsFound.map((cohort, cohortIndex) => {
    const winRates = (cohortRates[cohort] ?? []).map(
      rate => rate?.winRate ?? null,
    )
    // Computed once per cohort; slicing keeps every value up to and including
    // the last non-null one (an all-null series becomes empty).
    const lastObservedMonth = winRates.findLastIndex(value => value !== null)
    const color = getColorByIndex(cohortIndex)

    return {
      label: renderCohortName(cohort, COHORT_BUCKET),
      data: winRates.slice(0, lastObservedMonth + 1),
      borderColor: color,
      backgroundColor: color,
      tension: 0.3,
      fill: false,
    }
  })

  return { labels, datasets }
}

// Register the Chart.js pieces this module's charts rely on once, at import
// time, instead of on every render.
ChartJS.register(
  CategoryScale,
  LinearScale,
  BarElement,
  Title,
  Tooltip,
  Legend,
  PointElement,
  LineElement,
)

/** Month offsets 0..MONTHS_TO_ANALYZE-1, used for table columns. */
const MONTH_INDICES = Array.from({ length: MONTHS_TO_ANALYZE }, (_, i) => i)

/** Per-cohort aggregates derived from the raw opportunity data. */
type CohortStats = {
  /** Cohort keys, sorted ascending. */
  cohortsFound: string[]
  /** cohort -> month offset -> opportunities closed as won by that month (cumulative). */
  winCounts: Record<string, Record<number, number>>
  /** cohort -> month offset -> opportunities closed as lost by that month (cumulative). */
  lossCounts: Record<string, Record<number, number>>
  /** cohort -> opportunities whose current stage is not a closed stage. */
  openCounts: Record<string, number>
  /** cohort -> number of opportunities created in that cohort. */
  cohortSizes: Record<string, number>
}

/** True when `stage` is one of the two closed stages this report counts. */
const isClosedStage = (stage: unknown) =>
  stage === 'Closed Won' || stage === 'Closed Lost'

/**
 * Whole calendar months from `from` to `to`, ignoring the day of month
 * (Jan 31 -> Feb 1 counts as 1). Negative when `to` precedes `from`.
 */
const monthsBetween = (from: Date, to: Date) =>
  (to.getFullYear() - from.getFullYear()) * 12 +
  (to.getMonth() - from.getMonth())

/** Sort comparator: oldest history record first. */
const byCreatedDateAsc = (a: OpportunityHistory, b: OpportunityHistory) =>
  new Date(a.CreatedDate).getTime() - new Date(b.CreatedDate).getTime()

// TODO: Current assumption is that opp stages only progress forward, address this edge case
/**
 * Buckets opportunities into cohorts by creation date and, for each month
 * offset 0..MONTHS_TO_ANALYZE-1 after creation, counts how many of each
 * cohort's opportunities had closed won / closed lost by then.
 *
 * An opportunity's closing month comes from its EARLIEST closed-stage history
 * record that falls inside the analysis window; later closed-stage records
 * (e.g. edits made after closing) do not push the closing month back.
 *
 * @param data - expected to carry `opportunities` and `opportunity_history`
 *   arrays; empty aggregates are returned when any of them is missing.
 * @param fiscalYearStartMonth - forwarded to `cohortNameFromDate`.
 */
const buildCohortStats = (
  data: any,
  fiscalYearStartMonth: number,
): CohortStats => {
  if (!data || !data.opportunities || !data.opportunity_history) {
    return {
      cohortsFound: [],
      winCounts: {},
      lossCounts: {},
      openCounts: {},
      cohortSizes: {},
    }
  }

  const cohortSizes: Record<string, number> = {}
  const openCounts: Record<string, number> = {}
  const cohortOpps: Record<string, Set<any>> = {}

  data.opportunities.forEach((opp: any) => {
    const cohort = cohortNameFromDate(
      opp.CreatedDate,
      COHORT_BUCKET,
      fiscalYearStartMonth,
    )

    cohortSizes[cohort] = (cohortSizes[cohort] || 0) + 1

    if (!isClosedStage(opp.StageName)) {
      openCounts[cohort] = (openCounts[cohort] || 0) + 1
    }

    if (!cohortOpps[cohort]) {
      cohortOpps[cohort] = new Set()
    }
    cohortOpps[cohort].add(opp)
  })

  // Group the history by opportunity once (oldest record first) instead of
  // filtering and sorting the full history for every single opportunity.
  const sortedHistory: OpportunityHistory[] = [
    ...data.opportunity_history,
  ].sort(byCreatedDateAsc)
  const historyByOpp = new Map<unknown, OpportunityHistory[]>()
  for (const hist of sortedHistory) {
    const records = historyByOpp.get(hist.OpportunityId)
    if (records) {
      records.push(hist)
    } else {
      historyByOpp.set(hist.OpportunityId, [hist])
    }
  }

  const cohortsFound = Object.keys(cohortSizes).sort()
  const winCounts: CohortStats['winCounts'] = {}
  const lossCounts: CohortStats['lossCounts'] = {}

  cohortsFound.forEach(cohort => {
    const wins: Record<number, number> = {}
    const losses: Record<number, number> = {}
    winCounts[cohort] = wins
    lossCounts[cohort] = losses

    cohortOpps[cohort].forEach(opp => {
      const createdAt = new Date(opp.CreatedDate)

      // History is oldest-first, so this is the first transition to a closed
      // stage that lies within the analysis window.
      const closedRecord = (historyByOpp.get(opp.Id) ?? []).find(hist => {
        if (!isClosedStage(hist.StageName)) return false
        const offset = monthsBetween(createdAt, new Date(hist.CreatedDate))
        return offset >= 0 && offset < MONTHS_TO_ANALYZE
      })
      if (!closedRecord) return

      const closureMonth = monthsBetween(
        createdAt,
        new Date(closedRecord.CreatedDate),
      )
      // Counts are cumulative: a deal closed in month k counts for month k
      // and every later month in the window.
      const bucket = closedRecord.StageName === 'Closed Won' ? wins : losses
      for (let month = closureMonth; month < MONTHS_TO_ANALYZE; month++) {
        bucket[month] = (bucket[month] || 0) + 1
      }
    })
  })

  return { cohortsFound, winCounts, lossCounts, openCounts, cohortSizes }
}

/** Table cell with the win % (green) and loss % (red), or '-' when `rates` is null. */
const renderRateCell = (rates: CohortRates | null, key?: string) => (
  <Td key={key}>
    {rates ? (
      <div style={{ display: 'flex' }}>
        <span
          style={{
            color: '#16a34a',
            width: '50px',
            textAlign: 'right',
            marginRight: '16px',
          }}
        >
          {rates.winRate.toFixed(1)}%
        </span>
        <span style={{ color: '#dc2626', width: '50px', textAlign: 'right' }}>
          {rates.lossRate.toFixed(1)}%
        </span>
      </div>
    ) : (
      '-'
    )}
  </Td>
)

/**
 * Chart.js plugin that draws a grey vertical guide line, spanning the chart
 * area, through every active (hovered) point of the first dataset.
 */
const verticalHoverLine = {
  id: 'verticalHoverLine',
  beforeDatasetsDraw(chart: any) {
    const {
      ctx,
      chartArea: { top, bottom },
    } = chart

    ctx.save()
    chart.getDatasetMeta(0).data.forEach((point: any) => {
      if (point.active) {
        ctx.beginPath()
        ctx.strokeStyle = 'grey'
        ctx.moveTo(point.x, top)
        ctx.lineTo(point.x, bottom)
        ctx.stroke()
      }
    })
    ctx.restore()
  },
}

const CHART_PLUGINS = [verticalHoverLine]

/**
 * Opportunity win-rate cohort report: a line chart of win rate by month since
 * creation per cohort, plus a table of win/loss rates per cohort and a pooled
 * "Average Win Rate" row.
 */
function OpportunityWinRate({
  data,
  fiscalYearStartMonth,
}: {
  data: any
  fiscalYearStartMonth: number
}) {
  // fiscalYearStartMonth changes cohort assignment, so it must be a dependency.
  const { cohortsFound, winCounts, lossCounts, openCounts, cohortSizes } =
    useMemo(
      () => buildCohortStats(data, fiscalYearStartMonth),
      [data, fiscalYearStartMonth],
    )

  const cohortRates = useMemo(
    () =>
      calculateCohortRates(cohortsFound, winCounts, lossCounts, cohortSizes),
    [cohortsFound, winCounts, lossCounts, cohortSizes],
  )

  const chartData = useMemo(
    () => getChartData(cohortsFound, cohortRates),
    [cohortsFound, cohortRates],
  )

  // Pooled rates across every cohort old enough to have reached month i;
  // null when no cohort has reached it yet.
  const averageRates = MONTH_INDICES.map(i => {
    const eligible = cohortsFound.filter(
      cohort => !isMonthInFuture(cohortToDate(cohort), i),
    )
    if (eligible.length === 0) return null

    const sumOver = (pick: (cohort: string) => number) =>
      eligible.reduce((total, cohort) => total + pick(cohort), 0)

    return calculateRates(
      sumOver(cohort => winCounts[cohort]?.[i] || 0),
      sumOver(cohort => lossCounts[cohort]?.[i] || 0),
      sumOver(cohort => cohortSizes[cohort] || 0),
    )
  })

  return (
    <>
      <div
        style={{
          fontSize: '18px',
          fontWeight: '600',
          marginLeft: '10px',
          marginBottom: '24px',
        }}
      >
        Opportunity Win Rate
      </div>

      <div style={{ height: '50vh' }}>
        <Line
          data={chartData}
          plugins={CHART_PLUGINS}
          options={{
            responsive: true,
            maintainAspectRatio: false,
            interaction: {
              mode: 'index',
              intersect: false,
            },
            scales: {
              y: {
                beginAtZero: true,
                max: 100,
                ticks: {
                  callback: value => `${value}%`,
                },
              },
              x: {
                grid: {
                  display: false,
                },
              },
            },
            plugins: {
              datalabels: {
                display: false,
              },
              legend: {
                position: 'bottom' as const,
              },
              tooltip: {
                callbacks: {
                  label: context => {
                    const value = context.parsed.y
                    return `${value.toFixed(1)}%`
                  },
                },
              },
            },
          }}
        />
      </div>
      <div style={{ width: '100%', overflowX: 'scroll', marginTop: '48px' }}>
        <Table className={styles.fixedFirstColumn} variant='simple' size='sm'>
          <Thead>
            <Tr>
              <Th>Cohort</Th>
              <Th style={{ whiteSpace: 'nowrap' }}>Cohort Size</Th>
              {MONTH_INDICES.map(i => (
                <Th key={i}>Month {i}</Th>
              ))}
            </Tr>
          </Thead>
          <Tbody>
            {cohortsFound.map((cohort, index) => {
              const rates = cohortRates[cohort] ?? []

              return (
                <Tr key={cohort}>
                  <Td key={`${cohort}-label`}>
                    <span
                      style={{
                        display: 'inline-block',
                        width: '10px',
                        height: '10px',
                        backgroundColor: getColorByIndex(index),
                        marginRight: '8px',
                      }}
                    ></span>
                    {renderCohortName(cohort, COHORT_BUCKET)}
                  </Td>
                  <Td key={`${cohort}-size`} style={{ whiteSpace: 'nowrap' }}>
                    {cohortSizes[cohort]}{' '}
                    <span style={{ color: '#71717a', fontSize: '0.9em' }}>
                      ({openCounts[cohort] || 0} open)
                    </span>
                  </Td>
                  {MONTH_INDICES.map(i =>
                    renderRateCell(rates[i] ?? null, `${cohort}-month-${i}`),
                  )}
                </Tr>
              )
            })}
            <Tr key='average'>
              <Td key='average-label'>Average Win Rate</Td>
              <Td key='average-size'>-</Td>
              {averageRates.map((rates, i) =>
                renderRateCell(rates, `average-month-${i}`),
              )}
            </Tr>
          </Tbody>
        </Table>
      </div>
    </>
  )
}

export default OpportunityWinRate
