import { Table, Tbody, Td, Th, Thead, Tr } from '@chakra-ui/react'
import {
  BarElement,
  CategoryScale,
  Chart as ChartJS,
  Legend,
  LinearScale,
  LineElement,
  PointElement,
  Title,
  Tooltip,
} from 'chart.js'
import { useEffect, useMemo, useRef } from 'react'
import { Line } from 'react-chartjs-2'
import { getColorByIndex } from '../../../lib/colors'
import { CohortCell } from '../CohortCell'
import CohortDrillDown, { useCohortDrillDown } from '../CohortDrillDown'
import MetricHeader from '../MetricHeader'
import styles from '../pipelineAnalytics.module.css'
import { BaseMetricProps, BucketingProps, BucketType } from '../Types'
import {
  cohortNameFromDate,
  HeaderContainer,
  renderCohortName,
} from '../Utilities'

// Number of month offsets (Month 0 through Month 11) tracked for each cohort.
// It sizes the chart labels, the table columns and the win/loss tallies.
const MONTHS_TO_ANALYZE = 12

// Props accepted by this metric: the shared metric props combined with the
// bucketing configuration (both imported from ../Types).
type Props = BaseMetricProps & BucketingProps

// Win and loss shares for one cohort at one month offset, expressed as
// percentages of a total (see calculateRates).
type CohortRates = {
  // Wins divided by the total, multiplied by 100
  winRate: number
  // Losses divided by the total, multiplied by 100
  lossRate: number
}

/**
 * Converts a cohort key into a local-time `Date` for the first day of the
 * period the key names.
 *
 * Both branches build the date from numeric parts so the result is always
 * local midnight on day 1. Parsing a date-only string such as `2024-03-01`
 * would instead be interpreted as UTC, which lands on the previous calendar
 * day in time zones west of UTC and breaks later month arithmetic.
 *
 * @param cohort Cohort key: `<year>-Q<n>` for quarters, `<year>-<month>` for months.
 * @param bucketingScheme Which key format `cohort` uses.
 * @param fiscalYearStartMonth 1-based calendar month in which fiscal Q1 starts
 *   (only used for quarter keys).
 * @returns The first day of the cohort period; an invalid `Date` if the key
 *   cannot be parsed.
 * @throws Error when `bucketingScheme` is neither `'month'` nor `'quarter'`.
 */
const cohortToDate = (
  cohort: string,
  bucketingScheme: BucketType,
  fiscalYearStartMonth: number = 1,
) => {
  switch (bucketingScheme) {
    case 'quarter': {
      const [year, quarter] = cohort.split('-Q')
      const fiscalQuarter = parseInt(quarter, 10)
      // Zero-based month offset counted from January of the key's year; it may
      // run past December when the fiscal year starts late in the year.
      const monthsFromJanuary =
        (fiscalQuarter - 1) * 3 + fiscalYearStartMonth - 1
      const calendarMonth = monthsFromJanuary % 12
      const calendarYear =
        parseInt(year, 10) + Math.floor(monthsFromJanuary / 12)
      return new Date(calendarYear, calendarMonth, 1)
    }
    case 'month': {
      const [year, month] = cohort.split('-')
      // The key's month is 1-based, the Date constructor's month is 0-based
      return new Date(parseInt(year, 10), parseInt(month, 10) - 1, 1)
    }
    default:
      throw new Error(
        `Unsupported bucket type: ${bucketingScheme}. Expected 'month' or 'quarter'.`,
      )
  }
}

/**
 * Reports whether `baseDate`, moved forward by `monthsToAdd` calendar months,
 * falls strictly after `now`.
 *
 * `baseDate` is copied before shifting, so the caller's object is not mutated.
 * The shift uses `Date.prototype.setMonth`, which rolls over when the day of
 * month does not exist in the target month; the dates produced by
 * `cohortToDate`'s quarter branch are always on day 1.
 *
 * @param baseDate Starting date.
 * @param monthsToAdd Number of calendar months to add to `baseDate`.
 * @param now Reference instant to compare against; defaults to the current time.
 * @returns `true` when the shifted date is later than `now`, `false` when it
 *   is equal or earlier.
 */
const isMonthInFuture = (
  baseDate: Date,
  monthsToAdd: number,
  now: Date = new Date(),
) => {
  const shiftedDate = new Date(baseDate)
  shiftedDate.setMonth(shiftedDate.getMonth() + monthsToAdd)
  return shiftedDate > now
}

/**
 * Turns raw win/loss counts into percentages of `total`.
 *
 * @param wins Number of won opportunities.
 * @param losses Number of lost opportunities.
 * @param total Denominator for both rates.
 * @returns Win and loss rates as percentages; both are 0 when `total` is
 *   falsy (for example 0), which avoids dividing by zero.
 */
const calculateRates = (
  wins: number,
  losses: number,
  total: number,
): CohortRates => {
  const toPercent = (count: number) => (total ? (count / total) * 100 : 0)

  return {
    winRate: toPercent(wins),
    lossRate: toPercent(losses),
  }
}

/**
 * Builds, for every cohort, one entry per month offset
 * (`MONTHS_TO_ANALYZE` entries in total).
 *
 * An entry is `null` when that month offset has not been reached yet for the
 * cohort (per `isMonthInFuture`); otherwise it holds the win/loss rates for
 * that offset, relative to the cohort size.
 *
 * @param cohortsFound Cohort keys to process.
 * @param winCounts Win counts keyed by cohort, then by month offset.
 * @param lossCounts Loss counts keyed by cohort, then by month offset.
 * @param cohortSizes Number of opportunities in each cohort.
 * @param bucketingScheme Format of the cohort keys.
 * @param fiscalYearStartMonth 1-based month in which the fiscal year starts.
 * @returns Rates per cohort, indexed by month offset.
 */
const calculateCohortRates = (
  cohortsFound: string[],
  winCounts: Record<string, Record<number, number>>,
  lossCounts: Record<string, Record<number, number>>,
  cohortSizes: Record<string, number>,
  bucketingScheme: BucketType,
  fiscalYearStartMonth: number,
) => {
  const ratesByCohort: Record<string, (CohortRates | null)[]> = {}

  for (const cohort of cohortsFound) {
    const cohortStart = cohortToDate(
      cohort,
      bucketingScheme,
      fiscalYearStartMonth,
    )
    const size = cohortSizes[cohort]
    // Cohorts without any recorded wins or losses fall back to empty tallies
    const winsByMonth = winCounts[cohort] || {}
    const lossesByMonth = lossCounts[cohort] || {}

    ratesByCohort[cohort] = Array.from(
      { length: MONTHS_TO_ANALYZE },
      (_, monthOffset) =>
        isMonthInFuture(cohortStart, monthOffset)
          ? null
          : calculateRates(
              winsByMonth[monthOffset] || 0,
              lossesByMonth[monthOffset] || 0,
              size,
            ),
    )
  }

  return ratesByCohort
}

/**
 * Shapes the per-cohort rates into the `{ labels, datasets }` structure that
 * is passed to the `Line` chart.
 *
 * Each cohort becomes one dataset of win rates. Trailing `null` entries are
 * dropped so a cohort's series ends at its last non-null month.
 *
 * @param cohortsFound Cohort keys, in the order the datasets should appear.
 * @param cohortRates Rates per cohort, indexed by month offset.
 * @param bucketingScheme Format of the cohort keys, used for the legend label.
 * @returns Chart labels (`Month 0` ...) and one dataset per cohort.
 */
const getChartData = (
  cohortsFound: string[],
  cohortRates: Record<string, (CohortRates | null)[]>,
  bucketingScheme: BucketType,
) => {
  const labels = Array.from(
    { length: MONTHS_TO_ANALYZE },
    (_, i) => `Month ${i}`,
  )

  const datasets = cohortsFound.map((cohort, cohortIndex) => {
    const winRates = cohortRates[cohort].map(rate => rate?.winRate ?? null)
    // Locate the last real value once, then cut everything after it.
    // When every entry is null the index is -1 and the series is empty.
    const lastValidIndex = winRates.findLastIndex(value => value !== null)
    const data = winRates.slice(0, lastValidIndex + 1)
    const color = getColorByIndex(cohortIndex, cohortsFound.length)

    return {
      label: renderCohortName(cohort, bucketingScheme),
      data,
      borderColor: color,
      backgroundColor: color,
      tension: 0,
      fill: false,
    }
  })

  return { labels, datasets }
}

// Chart.js components used by the line chart below. Registration is global,
// so it is done once when the module loads instead of on every render.
ChartJS.register(
  CategoryScale,
  LinearScale,
  BarElement,
  Title,
  Tooltip,
  Legend,
  PointElement,
  LineElement,
)

// Inline Chart.js plugin: draws a grey vertical guide line through the
// hovered x position, behind the datasets.
const verticalHoverLine = {
  id: 'verticalHoverLine',
  beforeDatasetsDraw(chart: any) {
    const {
      ctx,
      chartArea: { top, bottom },
    } = chart

    ctx.save()

    // The points of the first dataset are used to find the active x position
    chart.getDatasetMeta(0).data.forEach((point: any) => {
      if (point.active) {
        ctx.beginPath()
        ctx.strokeStyle = 'grey'
        ctx.moveTo(point.x, top)
        ctx.lineTo(point.x, bottom)
        ctx.stroke()
      }
    })
    ctx.restore()
  },
}

// Renders one table cell with the win rate (green) and loss rate (red), or a
// dash when there are no rates for that cell.
const renderRateCell = (rates: CohortRates | null, key?: string) => (
  <Td key={key}>
    {rates ? (
      <div style={{ display: 'flex' }}>
        <span
          style={{
            color: '#16a34a',
            width: '50px',
            textAlign: 'right',
            marginRight: '16px',
          }}
        >
          {rates.winRate.toFixed(1)}%
        </span>
        <span style={{ color: '#dc2626', width: '50px', textAlign: 'right' }}>
          {rates.lossRate.toFixed(1)}%
        </span>
      </div>
    ) : (
      '-'
    )}
  </Td>
)

/**
 * Cohort win-rate metric.
 *
 * Opportunities are grouped into cohorts by the date stored in `cohortField`.
 * For each cohort the component shows, per month offset since the cohort
 * date, the cumulative share of opportunities that were closed as won or
 * lost. The result is rendered as:
 *  - a line chart of win rate per cohort,
 *  - a table of win/loss percentages per cohort and month offset, with a
 *    final row that pools all cohorts that have reached that month offset,
 *  - a drill-down listing the opportunities of the clicked cohort.
 *
 * `cohortField` is required: it is dereferenced unconditionally.
 */
// TODO: Current assumption is that opp stages only progress forward, address this edge case
function OpportunityWinRate({
  data,
  fiscalYearStartMonth,
  salesforceInstanceUrl,
  accountId,
  name,
  icon,
  backgroundColor,
  description,
  bucketingScheme,
  cohortField,
}: Props) {
  const {
    cohortsFound,
    winCounts,
    lossCounts,
    openCounts,
    cohortSizes,
    cohortOpportunities,
  } = useMemo(() => {
    if (!data || !data.opportunities || !data.opportunity_history)
      return {
        cohortsFound: [],
        winCounts: {},
        lossCounts: {},
        openCounts: {},
        cohortSizes: {},
        cohortOpportunities: {},
      }

    const cohortCounts: Record<string, number> = {}
    const openCounts: Record<string, number> = {}
    const cohortOpportunities: Record<string, any[]> = {}

    // First pass: assign every opportunity to its cohort
    data.opportunities.forEach((opp: any) => {
      const cohort = cohortNameFromDate(
        opp[cohortField!.name] as string,
        bucketingScheme,
        fiscalYearStartMonth,
        accountId,
      )

      cohortCounts[cohort] = (cohortCounts[cohort] ?? 0) + 1

      if (!opp.IsClosed) {
        openCounts[cohort] = (openCounts[cohort] ?? 0) + 1
      }

      if (!cohortOpportunities[cohort]) {
        cohortOpportunities[cohort] = []
      }
      cohortOpportunities[cohort].push(opp)
    })

    const sortedCohorts = Object.keys(cohortCounts).sort()

    const winCounts: Record<string, Record<number, number>> = {}
    const lossCounts: Record<string, Record<number, number>> = {}

    // Second pass: cumulative win/loss tallies per cohort and month offset
    sortedCohorts.forEach(cohort => {
      const cohortWins: Record<number, number> = {}
      const cohortLosses: Record<number, number> = {}
      winCounts[cohort] = cohortWins
      lossCounts[cohort] = cohortLosses

      cohortOpportunities[cohort].forEach(opp => {
        // Open opportunities count towards neither wins nor losses
        if (!opp.IsClosed) {
          return
        }

        const cohortDate = new Date(opp[cohortField!.name])
        const closeDate = new Date(opp.CloseDate)

        // Whole calendar months between the cohort date and the close date
        const monthsDiff =
          (closeDate.getFullYear() - cohortDate.getFullYear()) * 12 +
          (closeDate.getMonth() - cohortDate.getMonth())

        // Only process if closure happened within analysis window
        if (monthsDiff >= 0 && monthsDiff < MONTHS_TO_ANALYZE) {
          const tally = opp.IsWon ? cohortWins : cohortLosses
          // Count the outcome in the closing month and every later month so
          // the tallies are cumulative
          for (let month = monthsDiff; month < MONTHS_TO_ANALYZE; month++) {
            tally[month] = (tally[month] ?? 0) + 1
          }
        }
      })
    })

    return {
      cohortsFound: sortedCohorts,
      winCounts,
      lossCounts,
      openCounts,
      cohortSizes: cohortCounts,
      cohortOpportunities,
    }
  }, [data, fiscalYearStartMonth, accountId, bucketingScheme, cohortField])

  const { selectedCohort, setSelectedCohort } = useCohortDrillDown(
    cohortField!,
    bucketingScheme,
  )

  const cohortRates = useMemo(
    () =>
      calculateCohortRates(
        cohortsFound,
        winCounts,
        lossCounts,
        cohortSizes,
        bucketingScheme,
        fiscalYearStartMonth,
      ),
    [
      cohortsFound,
      winCounts,
      lossCounts,
      cohortSizes,
      bucketingScheme,
      fiscalYearStartMonth,
    ],
  )

  // Rebuilt only when the underlying rates change, not on every render
  const chartData = useMemo(
    () => getChartData(cohortsFound, cohortRates, bucketingScheme),
    [cohortsFound, cohortRates, bucketingScheme],
  )

  // Sticky second column
  const tableRef = useRef<HTMLTableElement>(null)

  // Keep the sticky second column aligned with the right edge of the first
  // column. Runs whenever the set of cohorts changes (cohortsFound is
  // recomputed when the data, bucketing or cohort field change, so newly
  // rendered rows get positioned too) and on window resize.
  useEffect(() => {
    const updateStickyColumnPositions = () => {
      const table = tableRef.current
      if (!table) return

      const firstCol = table.querySelector<HTMLElement>('th:nth-child(1)')
      if (!firstCol) return

      // Overlap by one pixel relative to the first column's width
      const left = `${firstCol.offsetWidth - 1}px`
      table
        .querySelectorAll<HTMLElement>('th:nth-child(2), td:nth-child(2)')
        .forEach(cell => {
          cell.style.left = left
        })
    }

    updateStickyColumnPositions() // Initial setup

    window.addEventListener('resize', updateStickyColumnPositions)
    return () =>
      window.removeEventListener('resize', updateStickyColumnPositions)
  }, [cohortsFound])

  return (
    <>
      <HeaderContainer>
        <MetricHeader
          icon={icon}
          backgroundColor={backgroundColor}
          name={name}
          description={description}
        />
      </HeaderContainer>

      <div style={{ height: '50vh' }}>
        <Line
          data={chartData}
          plugins={[verticalHoverLine]}
          options={{
            responsive: true,
            maintainAspectRatio: false,
            interaction: {
              mode: 'index',
              intersect: false,
            },
            scales: {
              y: {
                beginAtZero: true,
                ticks: {
                  callback: value => `${value}%`,
                },
              },
              x: {
                grid: {
                  display: false,
                },
              },
            },
            plugins: {
              datalabels: {
                display: false,
              },
              legend: {
                position: 'bottom' as const,
              },
              tooltip: {
                callbacks: {
                  label: context => {
                    const value = context.parsed.y
                    return `${value.toFixed(1)}%`
                  },
                },
              },
            },
          }}
        />
      </div>
      <div style={{ width: '100%', overflowX: 'scroll', marginTop: '48px' }}>
        <Table
          className={`${styles.fixedFirstColumn} ${styles.fixedSecondColumn}`}
          variant='simple'
          size='sm'
          ref={tableRef}
        >
          <Thead>
            <Tr>
              <Th>{cohortField!.label}</Th>
              <Th style={{ whiteSpace: 'nowrap' }}>Cohort Size</Th>
              {Array.from({ length: MONTHS_TO_ANALYZE }, (_, i) => (
                <Th key={i}>Month {i}</Th>
              ))}
            </Tr>
          </Thead>
          <Tbody>
            {cohortsFound.map((cohort, index) => {
              const cohortSize = cohortSizes[cohort]
              const cohortColor = getColorByIndex(index, cohortsFound.length)
              const rates = cohortRates[cohort]

              return (
                <Tr key={cohort}>
                  <Td>
                    <CohortCell
                      onClick={() => setSelectedCohort(cohort)}
                      color={cohortColor}
                    >
                      {renderCohortName(cohort, bucketingScheme)}
                    </CohortCell>
                  </Td>
                  <Td key={`${cohort}-size`} style={{ whiteSpace: 'nowrap' }}>
                    {cohortSize}{' '}
                    <span style={{ color: '#71717a', fontSize: '0.9em' }}>
                      ({openCounts[cohort] || 0} open)
                    </span>
                  </Td>
                  {Array.from({ length: MONTHS_TO_ANALYZE }, (_, i) =>
                    renderRateCell(rates[i], `${cohort}-month-${i}`),
                  )}
                </Tr>
              )
            })}
            <Tr key='average'>
              <Td key='average-label'>Average Win Rate</Td>
              <Td key='average-size'>-</Td>
              {Array.from({ length: MONTHS_TO_ANALYZE }, (_, i) => {
                // A cohort contributes to this month offset only if it has
                // rates for it, i.e. the month is not in the future for that
                // cohort. Reusing cohortRates keeps this row in step with the
                // cohort rows above.
                const validCohorts = cohortsFound.filter(
                  cohort => cohortRates[cohort]?.[i] != null,
                )

                if (validCohorts.length === 0)
                  return <Td key={`average-month-${i}`}>-</Td>

                // Pool the counts of all contributing cohorts, so larger
                // cohorts weigh more than smaller ones
                const totalWins = validCohorts.reduce(
                  (sum, cohort) => sum + (winCounts[cohort]?.[i] || 0),
                  0,
                )
                const totalLosses = validCohorts.reduce(
                  (sum, cohort) => sum + (lossCounts[cohort]?.[i] || 0),
                  0,
                )
                const totalOpps = validCohorts.reduce(
                  (sum, cohort) => sum + (cohortSizes[cohort] || 0),
                  0,
                )

                const rates = calculateRates(totalWins, totalLosses, totalOpps)
                return renderRateCell(rates, `average-month-${i}`)
              })}
            </Tr>
          </Tbody>
        </Table>
      </div>

      <CohortDrillDown
        selectedCohort={selectedCohort}
        opportunitiesByCohort={cohortOpportunities}
        cohortField={cohortField!.name}
        salesforceInstanceUrl={salesforceInstanceUrl}
        renderCohortName={cohort => renderCohortName(cohort, bucketingScheme)}
      />
    </>
  )
}

export default OpportunityWinRate
