import React, { FC, useCallback, useEffect, useMemo, useState } from "react"
import { useSelector } from "react-redux"
import * as THREE from "three"

import { AnimatedGltfUrlModel } from "src/components/Canvas/Viewer/SceneItems/GltfUrlModel"
import { DisplayMode, viewOptionsSelectors } from "src/store/ui/viewOptions"
import { useMachineCoordsStore } from "src/store/zustandMachine"
import { vector3Attribute } from "./visUtils"

/** Props for the `OvertravelScene` component. */
interface OvertravelSceneProps {
  /** URL of the glTF model to load; forwarded unchanged to `AnimatedGltfUrlModel`. */
  url: string
}

/** Opacity used for the model's standard materials when the display mode is `Transparent`. */
const TRANSPARENT_OPACITY = 0.3

/** Opacity used for every other (visible) display mode. */
const OPAQUE_OPACITY = 1.0

/** Program cache key shared by all materials patched with the travel-limit shader. */
const TRAVEL_LIMIT_PROGRAM_CACHE_KEY = "travel_limit_shader"

/** Vertex-shader declarations, injected in front of `#include <common>`. */
const TRAVEL_LIMIT_VERTEX_DECLARATIONS = `
  attribute vec3 ramp;

  varying vec3 v_ramp;
  varying vec3 v_position;
  varying vec3 v_normal;
  #include <common>
`

/**
 * Vertex-shader body, injected in front of `#include <fog_vertex>`. It forwards the `ramp`
 * attribute and the object-space position/normal to the fragment stage.
 */
const TRAVEL_LIMIT_VERTEX_ASSIGNMENTS = `
  v_ramp = ramp;
  v_position = position;
  v_normal = normal;
  #include <fog_vertex>
`

/** Fragment-shader declarations, injected in front of `#include <common>`. */
const TRAVEL_LIMIT_FRAGMENT_DECLARATIONS = `
  varying vec3 v_ramp;
  varying vec3 v_position;
  varying vec3 v_normal;

  const vec3 PLANE_NORMAL = normalize(vec3(1.0, 1.0, 1.0));

  #include <common>
`

/**
 * Fragment-shader tail, appended after `#include <dithering_fragment>` (i.e. after the standard
 * lighting output has been written to `gl_FragColor`).
 *
 * - Alpha: `n_a` masks out the component along the surface normal. The interpolated `ramp` is
 *   remapped to |2r - 1| per in-plane axis, so it is ~0 in the middle of a face and ~1 at its
 *   rim (assuming box-like, axis-aligned faces — TODO confirm for other models). Fragments with a
 *   rim value above ~0.955 keep full alpha; everything else is scaled down to `base` (0.15).
 * - Colour: the lit RGB is replaced by stripes perpendicular to (1,1,1) in object space, with
 *   period `STRIPE_SCALE`, alternating between black and `BASE_COLOR`.
 */
const TRAVEL_LIMIT_FRAGMENT_SHADING = `
  #include <dithering_fragment>

  float STRIPE_SCALE = 18.0;
  vec3 BASE_COLOR = vec3(0.98, 0.88, 0.31);
  float STRIPE_RAMP = 0.025;

  vec3 n_a = 1.0 - abs(v_normal);

  vec3 coord_2d = n_a * abs((v_ramp * n_a - 0.5) * 2.0);
  float max_coord_2d = max(max(coord_2d.x, coord_2d.y), coord_2d.z);
  float base = 0.15;
  float alpha_factor = smoothstep(0.955, 0.96, max_coord_2d) * (1.0 - base) + base;
  gl_FragColor.a *= alpha_factor;

  float dist = dot(v_position, PLANE_NORMAL);
  float dist_mod = abs(mod(dist, STRIPE_SCALE) - STRIPE_SCALE * 0.5) / (STRIPE_SCALE * 0.5);
  float dist_mod_smooth = smoothstep(0.5 - STRIPE_RAMP, 0.5 + STRIPE_RAMP, dist_mod);
  gl_FragColor.rgb = vec3(dist_mod_smooth) * BASE_COLOR;
`

/** Minimal read interface shared by `BufferAttribute` and `InterleavedBufferAttribute`. */
interface PositionSource {
  readonly count: number
  getX(index: number): number
  getY(index: number): number
  getZ(index: number): number
}

/**
 * Builds the flat data for the per-vertex `ramp` attribute used by the travel-limit shader.
 * For each vertex it emits three components. Each component is 0.0 when the vertex lies below
 * the centre of the geometry's axis-aligned bounds on that axis, and 1.0 otherwise.
 *
 * Values are read through `getX/getY/getZ` so that interleaved or offset attributes are handled
 * correctly. The bounds are seeded with ±Infinity so that arbitrarily large coordinates work.
 */
function computeRamp(positions: PositionSource): number[] {
  const min = new THREE.Vector3(Infinity, Infinity, Infinity)
  const max = new THREE.Vector3(-Infinity, -Infinity, -Infinity)
  const point = new THREE.Vector3()
  for (let i = 0; i < positions.count; i++) {
    point.set(positions.getX(i), positions.getY(i), positions.getZ(i))
    min.min(point)
    max.max(point)
  }
  const center = min.add(max).multiplyScalar(0.5)

  const ramp: number[] = []
  for (let i = 0; i < positions.count; i++) {
    ramp.push(
      positions.getX(i) < center.x ? 0.0 : 1.0,
      positions.getY(i) < center.y ? 0.0 : 1.0,
      positions.getZ(i) < center.z ? 0.0 : 1.0
    )
  }
  return ramp
}

/** `onBeforeCompile` hook that injects the travel-limit GLSL into the standard material shaders. */
function patchTravelLimitShader(shader: THREE.Shader): void {
  shader.vertexShader = shader.vertexShader
    .replace("#include <common>", TRAVEL_LIMIT_VERTEX_DECLARATIONS)
    .replace("#include <fog_vertex>", TRAVEL_LIMIT_VERTEX_ASSIGNMENTS)
  shader.fragmentShader = shader.fragmentShader
    .replace("#include <common>", TRAVEL_LIMIT_FRAGMENT_DECLARATIONS)
    .replace("#include <dithering_fragment>", TRAVEL_LIMIT_FRAGMENT_SHADING)
}

/** Normalises a mesh's `material` (a single material or an array of them) to an array. */
function toMaterialArray(material: THREE.Material | THREE.Material[]): THREE.Material[] {
  return Array.isArray(material) ? material : [material]
}

/**
 * Renders the overtravel-limit model loaded from `url` with a custom striped, edge-highlighted
 * shader.
 *
 * The model is shown only when the WCS is shown and the overtravel display mode is not `Hidden`.
 * It is offset along -Z by the current tool length. In `Transparent` mode its standard
 * materials use a reduced opacity.
 */
export const OvertravelScene: FC<OvertravelSceneProps> = ({ url }) => {
  const showWcs = useSelector(viewOptionsSelectors.showWcs)
  const overtravelDisplayMode = useSelector(viewOptionsSelectors.overtravelLimitsDisplayMode)
  const visible = showWcs && overtravelDisplayMode !== DisplayMode.Hidden

  // Subscribe to the store instead of calling getState() during render. A render-time snapshot
  // would not re-render this component when the tool length changes, so the offset would go stale.
  const toolLength = useMachineCoordsStore((state) => state.toolLength)
  const position = useMemo(() => new THREE.Vector3(0, 0, -toolLength), [toolLength])

  const [mesh, setMesh] = useState<THREE.Mesh>()

  // Keep a stable reference so the model component does not see a "new" material config
  // (with fresh THREE.Color instances) on every render.
  const materialProps = useMemo(
    () => ({
      flatShading: false,
      metalness: 1.0,
      roughness: 0.6,
      color: new THREE.Color("#ffffff"),
      wireColor: new THREE.Color("#000000"),
      transparent: true,
      side: THREE.DoubleSide,
    }),
    []
  )

  const sceneCallback = useCallback(
    (group: THREE.Group) => {
      // NOTE(review): assumes the loaded glTF nests the target mesh as the first grandchild.
      const child = group.children[0]?.children?.[0]
      if (!(child instanceof THREE.Mesh)) {
        setMesh(undefined)
        return
      }

      const positions = child.geometry.getAttribute("position")
      if (positions) {
        child.geometry.setAttribute("ramp", vector3Attribute(computeRamp(positions)))
      }

      for (const material of toMaterialArray(child.material)) {
        material.customProgramCacheKey = () => TRAVEL_LIMIT_PROGRAM_CACHE_KEY
        material.onBeforeCompile = patchTravelLimitShader
        // Force a program rebuild in case this material was already compiled without the patch.
        material.needsUpdate = true
      }
      setMesh(child)
    },
    [setMesh]
  )

  // Apply display-mode opacity as an effect rather than mutating three.js objects during render.
  useEffect(() => {
    if (!mesh) {
      return
    }
    const opacity =
      overtravelDisplayMode === DisplayMode.Transparent ? TRANSPARENT_OPACITY : OPAQUE_OPACITY
    for (const material of toMaterialArray(mesh.material)) {
      if (material instanceof THREE.MeshStandardMaterial) {
        material.opacity = opacity
      }
    }
  }, [mesh, overtravelDisplayMode])

  return (
    <group visible={visible} position={position}>
      <AnimatedGltfUrlModel
        url={url}
        flatShading={true}
        material={materialProps}
        sceneCallback={sceneCallback}
      />
    </group>
  )
}
