import Amplify, { API, graphqlOperation } from 'aws-amplify'
import { listLabeledSensorDatas } from '../graphql/queries'
import { Matrix } from 'ml-matrix'
import LogisticRegression from 'ml-logistic-regression'

// Module-level logistic-regression classifier. trainGestureModel fits it in
// place and predictGesture queries it, so both always use this one instance.
const gestureModel = new LogisticRegression({
  learningRate: 0.05,
  numSteps: 10000
})

/**
 * Runs the shared gesture model on a single sample.
 *
 * The sample is wrapped into a one-row Matrix before being handed to
 * gestureModel.predict. The prediction is logged to the console and returned.
 *
 * @param {Array} sample Feature values for one sample (one matrix row).
 * @returns {*} Whatever gestureModel.predict returns for the one-row matrix
 *   (presumably an array holding a single predicted class; confirm against
 *   the library).
 */
export const predictGesture = sample => {
  const prediction = gestureModel.predict(new Matrix([sample]))
  console.log(prediction)
  return prediction
}

/**
 * Fetches all labeled sensor rows, fits the shared gesture model on them and
 * logs the mean absolute difference between the model's predictions on the
 * training set and the true labels.
 *
 * Every fetched row is treated as [...features, label]: all but the last
 * entry form a feature row, the last entry is the label.
 *
 * @returns {Promise<void>} Resolves once training and logging are done.
 * @throws {Error} If no labeled sensor data is available to train on.
 */
export const trainGestureModel = async () => {
  const allSensorData = await fetchAllSensorData()
  if (allSensorData.length === 0) {
    // Fail loudly rather than building matrices with zero rows.
    throw new Error('Cannot train gesture model: no labeled sensor data found')
  }

  const xs = new Matrix(allSensorData.map(row => row.slice(0, -1)))
  // Hand columnVector plain scalars. row.slice(-1) would produce one-element
  // arrays that only work if the matrix happens to coerce them to numbers.
  const ys = Matrix.columnVector(allSensorData.map(row => row[row.length - 1]))
  gestureModel.train(xs, ys)
  console.log(
    `Trained logistic regression with ${allSensorData.length} samples`
  )

  // Re-predict the training set to report how well the model fits it.
  const trainedPred = Matrix.columnVector(gestureModel.predict(xs))
  console.log('Training predictions', trainedPred)
  const diff = Matrix.sub(trainedPred, ys)
  console.log('Diff', diff)
  const abs = Matrix.abs(diff)
  console.log('Abs', abs)
  const loss = abs.mean()
  console.log(`Training loss: ${loss}`)
}

/**
 * Collects every labeled sensor row by paging through fetchSensorBatch.
 *
 * Starts with a null token and keeps requesting the next page until a batch
 * comes back whose nextToken is null or undefined.
 *
 * @returns {Promise<Array<Array>>} All rows, in the order the batches
 *   returned them.
 */
const fetchAllSensorData = async () => {
  const sensorData = []
  let nextToken = null
  do {
    const batch = await fetchSensorBatch(nextToken)
    nextToken = batch.nextToken
    // Append at the end. splice(-1, 0, ...) would insert before the current
    // last element and interleave later pages into earlier ones.
    sensorData.push(...batch.data)
  } while (nextToken != null)

  return sensorData
}

// Each record carries fields alpha1..alpha10, beta1..beta10, gamma1..gamma10.
const SENSOR_COUNT = 10
const SENSOR_FIELD_PREFIXES = ['alpha', 'beta', 'gamma']

/**
 * Flattens one labeled sensor record into a training row laid out as
 * [alpha1, beta1, gamma1, alpha2, ..., gamma10, label - 1].
 *
 * @param {Object} sensor Record returned by the listLabeledSensorDatas query.
 * @returns {Array} 30 feature values followed by the shifted label.
 */
const toLabeledSensorRow = sensor => {
  const row = []
  for (let index = 1; index <= SENSOR_COUNT; index += 1) {
    for (const prefix of SENSOR_FIELD_PREFIXES) {
      row.push(sensor[`${prefix}${index}`])
    }
  }
  //  TODO... don't normalize the labels to 0 (raised), 1 (not raised)...  I think the logit
  //  model should be able to predict classes of 1 and 2 but something wasn't working.
  //  NOTE(review): the logistic-regression library may expect zero-based class
  //  labels, which would explain the failure with 1/2. Confirm against its source.
  row.push(sensor.label - 1)
  return row
}

/**
 * Fetches one page (up to 25 records) of labeled sensor data and converts
 * each record into a flat training row.
 *
 * @param {?string} nextBatchToken Pagination token from the previous page,
 *   or null for the first page.
 * @returns {Promise<{data: Array<Array>, nextToken: *}>} The converted rows
 *   plus the token for the following page as reported by the query.
 * @throws {Error} If the response has no listLabeledSensorDatas result.
 */
const fetchSensorBatch = async nextBatchToken => {
  const { data } = await API.graphql(
    graphqlOperation(listLabeledSensorDatas, {
      limit: 25,
      nextToken: nextBatchToken
    })
  )

  // Check before destructuring: optional chaining alone would just yield
  // undefined and make the destructuring below throw an opaque TypeError.
  const connection = data?.listLabeledSensorDatas
  if (connection == null) {
    throw new Error('GraphQL response did not contain listLabeledSensorDatas')
  }
  const { items, nextToken } = connection

  return {
    data: items.map(sensor => toLabeledSensorRow(sensor)),
    nextToken
  }
}
