import React, { useState, useEffect } from 'react'
import { styled } from '@mui/material/styles'
import { Button } from '@mui/material'

// Constants and slider descriptions
import { MAX_RESPONSE_LENGTH } from '../../../redux/gptjSlicePg'
import {
  descRepetitionPenalty,
  descResponseLength,
  descTemperature,
  descTopK,
  descTopP,
} from '../config/inferenceExample.config'

//component
import { ExamplesModal } from './ExamplesModal'
import SliderItem from '../../../components/SliderItem'
import StyledLoader from './StyledLoader'

//utils
import { showNotification } from '../../../utils/toast'
import modelApiService from '../../../services/modelApi.service'

// Full-size row that holds the prompt pane (left) and the controls pane (right).
const PaneContainerMain = styled('div')`
  display: flex;
  flex-direction: row;
  justify-content: space-between;
  column-gap: 16px;
  width: 100%;
  height: 100%;
`
// White card taking 70% of the row width; stacks its children vertically.
const PaneContainerLeft = styled('div')`
  display: flex;
  flex-direction: column;
  width: 70%;
  height: 100%;
  padding: 16px;
  background-color: #fff;
  border-radius: 10px;
  box-shadow: 0 11px 44px 0 rgb(18 18 19 / 10%);
`
// White card taking 30% of the row width; stacks the sliders and actions vertically.
const PaneContainerRight = styled('div')`
  display: flex;
  flex-direction: column;
  width: 30%;
  height: 100%;
  padding: 16px;
  background-color: #fff;
  border-radius: 10px;
  box-shadow: 0 11px 44px 0 rgb(18 18 19 / 10%);
`
// Relatively positioned column that vertically centers the textarea or loader.
const PlayGround = styled('div')`
  display: flex;
  flex-direction: column;
  justify-content: center;
  row-gap: 16px;
  position: relative;
  width: 100%;
  height: 100%;
`
// Borderless, vertically resizable prompt textarea with a soft shadow on hover
// and a thin, translucent WebKit scrollbar.
// Note: the hover rule must be "&:hover" (no space). "& :hover" is a
// descendant selector, and a textarea has no descendants, so it never matched.
const StyledTextArea = styled('textarea')`
  resize: vertical;
  width: 100%;
  min-height: 100%;
  background-color: #fff;
  font-size: 16px;
  overflow: auto;
  padding: 16px 24px;
  outline: none !important;
  border: none;
  transition: box-shadow 300ms ease, color 300ms ease, border-color 300ms ease;
  &:hover {
    box-shadow: 0 2px 12px 0 rgb(18 18 19 / 14%);
  }
  &::-webkit-scrollbar {
    width: 8px;
  }
  &::-webkit-scrollbar-track {
    height: 80%;
    background-color: transparent;
    border-radius: 16px;
  }
  &::-webkit-scrollbar-thumb {
    background: rgba(0, 0, 0, 0.1);
    border-radius: 16px;
    height: 60px !important;
  }
`
// Row holding the action buttons at opposite ends. Inside a flex column,
// margin-top: auto pushes it to the bottom.
const ActionButtonContainer = styled('div')`
  margin-top: auto;
  padding: 16px;
  display: flex;
  justify-content: space-between;
`
// Dark, pill-shaped MUI button with a non-uppercased label that lifts slightly on hover.
// The hover rule repeats the base background color so the background stays the same on hover.
// text-transform is set directly on the root. The Button root element itself carries
// .MuiButton-root, so a "& .MuiButton-root" descendant rule would never match anything.
const StyledButton = styled(Button)`
  padding: 8px 16px;
  font-size: 16px;
  font-family: Inter;
  font-weight: 500;
  line-height: 18px;
  border: 0;
  cursor: pointer;
  border-radius: 24px;
  color: #fff;
  background-color: #121213;
  text-align: center;
  transform-style: preserve-3d;
  text-transform: none;
  align-self: center;
  transition: background-color 300ms ease, transform 300ms ease,
    color 300ms ease, -webkit-transform 300ms ease;
  &:hover {
    background-color: #121213;
    transform: translate3d(0px, -3px, 0.01px);
  }
`
const Gptj = () => {
  //component state
  const [exampleModalOpen, setExampleModalOpen] = useState(false)
  const [isLoading, setIsLoading] = useState(false)
  const [inputPrompt, setInputPrompt] = useState('')

  //Queue related states
  const [responseUrl, setResponseUrl] = useState(undefined)
  const [initialQueuePosition, setInitialQueuePosition] = useState(undefined)
  const [queuePosition, setQueuePosition] = useState(undefined)

  //Inference control parameters
  const [maxResponseLength, setMaxResponseLength] = useState(64)
  const [minResponseLength, setMinResponseLength] = useState(32)
  const [temperature, setTemperature] = useState(0.5)
  const [repetitionPenalty, setRepetitionPenalty] = useState(1.3)
  const [topK, setTopK] = useState(40)
  const [topP, setTopP] = useState(1.0)

  useEffect(() => {
    let pollingTimerId
    let count = 0

    const getTextInference = async () => {
      try {
        const response1 = await modelApiService.getInference(responseUrl)
        console.log('COUNT', count++)
        setQueuePosition((prev) => {
          if (prev && prev > 0) {
            return prev - 1
          }
          return prev
        })
        if (
          response1.data.success &&
          Array.isArray(response1.data.data.completion)
        ) {
          const source = response1.data.data.completion[0]
          setInputPrompt((prev) => prev + ' ' + source)
          clearInterval(pollingTimerId)
          setQueuePosition(undefined)
          showNotification('Request complete', 'success')
          setIsLoading(false)
        }
      } catch (error) {
        console.log('ERROR While polling:', error.message)
        clearInterval(pollingTimerId)
        setQueuePosition(undefined)
        showNotification('Something went wrong please try again', 'error')
        setIsLoading(false)
      }
    }

    if (queuePosition !== undefined) {
      pollingTimerId = setInterval(() => {
        getTextInference()
      }, 15 * 1000)
    }

    return () => {
      clearInterval(pollingTimerId)
    }
  }, [queuePosition])

  const setExample = (preset) => {
    setInputPrompt(preset.text)
    setMaxResponseLength(preset.responseLength)
    setMinResponseLength(32)
    setTemperature(preset.temperature)
    setRepetitionPenalty(preset.repetitionPenalty)
    setTopK(preset.topK)
    setTopP(preset.topP)
  }

  const handleSubmit = async () => {
    if (inputPrompt === '') {
      showNotification('Please enter some prompt')
      return
    }
    const reqBody = {
      prompt: inputPrompt,
      params: {
        max_length: maxResponseLength,
        min_length: minResponseLength,
        repetition_penalty: repetitionPenalty,
        temperature: temperature,
        top_k: topK,
        top_p: topP,
        // Default parameters
        num_beams: 1,
      },
    }

    try {
      setIsLoading(true)
      const response = await modelApiService.getInferenceData(`gpt-j`, reqBody)

      if (response.status === 200) {
        //console.log('response', response.data.data)
        setResponseUrl(response.data.data.responseUrl)
        setInitialQueuePosition(Number(response.data.data.queuePosition) + 1)
        setQueuePosition(Number(response.data.data.queuePosition) + 1)
      } else if (response.status === 429) {
        setIsLoading(false)
        showNotification('Quota exceeded, please try again after 5 min')
      } else {
        setIsLoading(false)
        showNotification('Something went wrong please try again.')
      }
    } catch (error) {
      console.log('ERROR:', error.message)
    }
  }

  return (
    <PaneContainerMain>
      <PaneContainerLeft>
        <PlayGround>
          {isLoading && (
            <StyledLoader
              queuePosition={queuePosition}
              initialQueuePosition={initialQueuePosition}
            />
          )}
          {!isLoading && (
            <StyledTextArea
              autoFocus={true}
              id='myTextArea'
              placeholder='Enter your prompt here. Eg: A step by step recipe to make bolognese pasta...'
              spellCheck={false}
              disabled={isLoading}
              rows={8}
              value={inputPrompt}
              onChange={(e) => setInputPrompt(e.target.value)}
            />
          )}
        </PlayGround>
        <ExamplesModal
          isOpen={exampleModalOpen}
          closeHandler={() => setExampleModalOpen(false)}
          setExample={setExample}
        />
      </PaneContainerLeft>
      <PaneContainerRight>
        <SliderItem
          label='Max response length'
          defaultValue={maxResponseLength}
          step={1}
          min={20}
          max={MAX_RESPONSE_LENGTH}
          description={descResponseLength}
          onChange={(newValue) => setMaxResponseLength(newValue)}
        />
        <SliderItem
          label='Min response length'
          defaultValue={minResponseLength}
          step={1}
          min={20}
          max={MAX_RESPONSE_LENGTH}
          description={descResponseLength}
          onChange={(newValue) => setMinResponseLength(newValue)}
        />
        <SliderItem
          label='Repetition penalty'
          defaultValue={repetitionPenalty}
          step={0.1}
          min={1.0}
          max={10.0}
          description={descRepetitionPenalty}
          onChange={(newValue) => setRepetitionPenalty(newValue)}
        />
        <SliderItem
          label='Temperature'
          defaultValue={temperature}
          step={0.01}
          min={0.01}
          max={1.0}
          description={descTemperature}
          onChange={(newValue) => setTemperature(newValue)}
        />
        <SliderItem
          label='Top k'
          defaultValue={topK}
          step={1}
          min={1}
          max={1000}
          description={descTopK}
          onChange={(newValue) => setTopK(newValue)}
        />
        <SliderItem
          label='Top p'
          defaultValue={topP}
          step={0.1}
          min={0.0}
          max={1.0}
          description={descTopP}
          onChange={(newValue) => setTopP(newValue)}
        />
        <ActionButtonContainer>
          <StyledButton
            variant='contained'
            disableFocusRipple
            disabled={isLoading}
            onClick={() => {
              setExampleModalOpen(true)
            }}
          >
            Examples
          </StyledButton>
          <StyledButton
            variant='contained'
            disableFocusRipple
            {...(isLoading ? { disabled: true } : {})}
            onClick={handleSubmit}
          >
            Submit
          </StyledButton>
        </ActionButtonContainer>
      </PaneContainerRight>
    </PaneContainerMain>
  )
}

export default Gptj
