/*
 * Decompiled with CFR 0.152.
 */
package org.janelia.thickness.inference;

import java.util.Arrays;
import mpicbg.models.AffineModel1D;
import mpicbg.models.IllDefinedDataPointsException;
import mpicbg.models.Model;
import mpicbg.models.NotEnoughDataPointsException;
import net.imglib2.Cursor;
import net.imglib2.RandomAccess;
import net.imglib2.RandomAccessible;
import net.imglib2.RandomAccessibleInterval;
import net.imglib2.img.array.ArrayImg;
import net.imglib2.img.array.ArrayImgFactory;
import net.imglib2.interpolation.randomaccess.NLinearInterpolatorFactory;
import net.imglib2.interpolation.randomaccess.NearestNeighborInterpolatorFactory;
import net.imglib2.transform.Transform;
import net.imglib2.type.NativeType;
import net.imglib2.type.Type;
import net.imglib2.type.numeric.RealType;
import net.imglib2.type.numeric.real.DoubleType;
import net.imglib2.util.ConstantUtils;
import net.imglib2.util.Util;
import net.imglib2.view.IntervalView;
import net.imglib2.view.TransformView;
import net.imglib2.view.Views;
import org.janelia.thickness.EstimateScalingFactors;
import org.janelia.thickness.ShiftCoordinates;
import org.janelia.thickness.inference.Options;
import org.janelia.thickness.inference.fits.AbstractCorrelationFit;
import org.janelia.thickness.inference.visitor.LazyVisitor;
import org.janelia.thickness.inference.visitor.Visitor;
import org.janelia.thickness.lut.LUTRealTransform;
import org.janelia.thickness.lut.PermutationTransform;
import org.janelia.utility.MatrixStripConversion;
import org.janelia.utility.arrays.ArraySortedIndices;
import org.janelia.utility.arrays.ReplaceNaNs;

public class InferFromMatrix {
    private final AbstractCorrelationFit correlationFit;

    /**
     * Creates an inference instance around the given correlation fit.
     *
     * @param correlationFit fit implementation; stored as-is (no copy, no null check) and
     *                       used by {@code getMediatedShifts} to estimate the fits from the matrix
     */
    public InferFromMatrix(AbstractCorrelationFit correlationFit) {
        this.correlationFit = correlationFit;
    }

    /**
     * Convenience overload that runs the estimation with a freshly created {@link LazyVisitor}.
     *
     * @param matrix              similarity matrix handed on unchanged
     * @param startingCoordinates initial coordinates handed on unchanged
     * @param options             inference options handed on unchanged
     * @return whatever the visitor-taking overload returns
     * @throws Exception propagated from the visitor-taking overload
     */
    public <T extends RealType<T> & NativeType<T>> double[] estimateZCoordinates(RandomAccessibleInterval<T> matrix, double[] startingCoordinates, Options options) throws Exception {
        final Visitor visitor = new LazyVisitor();
        return this.estimateZCoordinates(matrix, startingCoordinates, visitor, options);
    }

    /**
     * Convenience overload that supplies neutral defaults for everything the full overload
     * needs beyond the matrix and the starting coordinates.
     *
     * <p>Defaults used: an empty function estimate, scaling factors that are all {@code 1.0},
     * shift weights that are all {@code 1.0}, and a constant estimate-weight image of value
     * {@code 1.0} defined on the interval of {@code inputMatrix}. Scaling factors and shift
     * weights are two separate arrays, each of length {@code startingCoordinates.length}.
     *
     * @param inputMatrix         similarity matrix; also provides the interval and the number of
     *                            dimensions of the constant weight image
     * @param startingCoordinates initial coordinates; its length determines the length of the
     *                            default arrays
     * @param visitor             visitor handed on unchanged
     * @param options             inference options handed on unchanged
     * @return whatever the full overload returns
     * @throws Exception propagated from the full overload
     */
    public <T extends RealType<T> & NativeType<T>> double[] estimateZCoordinates(RandomAccessibleInterval<T> inputMatrix, double[] startingCoordinates, Visitor visitor, Options options) throws Exception {
        final int n = startingCoordinates.length;

        final double[] scalingFactors = new double[n];
        Arrays.fill(scalingFactors, 1.0);

        final double[] shiftWeights = new double[n];
        Arrays.fill(shiftWeights, 1.0);

        // The constant must keep its static type DoubleType so that the weight type parameter
        // of the full overload can be inferred as a RealType.
        final RandomAccessibleInterval<DoubleType> estimateWeights = ConstantUtils.constantRandomAccessibleInterval(new DoubleType(1.0), inputMatrix.numDimensions(), inputMatrix);

        return this.estimateZCoordinates(inputMatrix, startingCoordinates, new double[0], scalingFactors, estimateWeights, shiftWeights, visitor, options);
    }

    /**
     * Runs the iterative coordinate estimation and returns the resulting coordinates.
     *
     * <p>Outline of what this method does:
     * <ol>
     *   <li>Clones {@code startingCoordinates} into the working look-up table and sorts a copy
     *       of it via {@code ArraySortedIndices.sort} (which presumably fills the forward and
     *       inverse index permutations - confirm against that helper).</li>
     *   <li>Copies the band of {@code inputMatrix} within {@code options.comparisonRange} into a
     *       newly allocated strip image that is viewed as a matrix with a NaN-valued extension
     *       value.</li>
     *   <li>Picks a {@link Regularizer} from {@code options.regularizationType}.</li>
     *   <li>For {@code options.nIterations} iterations: computes mediated shifts, applies them,
     *       calls {@code ReplaceNaNs.replace}, optionally prevents re-ordering, regularizes,
     *       writes the result back into the look-up table, re-sorts, and notifies the
     *       visitor.</li>
     * </ol>
     *
     * <p>The visitor is called once with iteration index {@code 0} before the first update and
     * once with {@code iteration + 1} after every update.
     *
     * @param inputMatrix         similarity matrix; its size along dimension 0 is used as the
     *                            number of sections
     * @param startingCoordinates initial coordinates; cloned, and also passed to
     *                            {@code applyShifts} as the regularizer coordinates
     * @param functionEstimate    not read by this method
     * @param scalingFactors      per-section scaling factors; this array is written to by this
     *                            method (it is re-permuted after every iteration)
     * @param estimateWeights     weights forwarded to {@code getMediatedShifts}
     * @param shiftWeights        weights forwarded to {@code getMediatedShifts}
     * @param visitor             callback invoked as described above
     * @param options             inference options
     * @return the working look-up table after the last iteration (a new array, not
     *         {@code startingCoordinates})
     * @throws Exception propagated from the fit, the regularizer, or the helpers
     */
    public <T extends RealType<T> & NativeType<T>, W extends RealType<W>> double[] estimateZCoordinates(RandomAccessibleInterval<T> inputMatrix, double[] startingCoordinates, double[] functionEstimate, double[] scalingFactors, RandomAccessibleInterval<W> estimateWeights, double[] shiftWeights, Visitor visitor, Options options) throws Exception {
        Regularizer regularizer;
        double[] lut = (double[])startingCoordinates.clone();
        int n = (int)inputMatrix.dimension(0);
        int[] permutationLut = new int[n];
        int[] inverse = (int[])permutationLut.clone();
        int nMatrixDim = inputMatrix.numDimensions();
        // One-element holder so getMediatedShifts can hand the latest fits back to the visitor.
        RandomAccessibleInterval[] correlationFitsStore = new RandomAccessibleInterval[]{null};
        double[] permutedLut = (double[])lut.clone();
        double[] scalingFactorsPrevious = (double[])scalingFactors.clone();
        ArraySortedIndices.sort(permutedLut, permutationLut, inverse);
        RealType nanExtension = (RealType)((RealType)Util.getTypeFromInterval(inputMatrix)).createVariable();
        nanExtension.setReal(Double.NaN);
        // Strip storage of width 2 * comparisonRange + 1 for the scaled copy of the matrix.
        ArrayImg inputScaledStrip = new ArrayImgFactory((NativeType)nanExtension.createVariable()).create(new long[]{2 * options.comparisonRange + 1, n});
        RandomAccessibleInterval<Type> inputScaledMatrix = MatrixStripConversion.stripToMatrix(inputScaledStrip, nanExtension.copy());
        // Initialise the scaled strip with the values of the input matrix strip.
        Cursor source = Views.flatIterable(MatrixStripConversion.matrixToStrip(inputMatrix, options.comparisonRange, nanExtension.copy())).cursor();
        Cursor target = Views.flatIterable((RandomAccessibleInterval)inputScaledStrip).cursor();
        while (source.hasNext()) {
            ((RealType)target.next()).set((Type)source.next());
        }
        switch (options.regularizationType) {
            case BORDER: {
                regularizer = new BorderRegularization((Model<?>)new AffineModel1D(), n);
                break;
            }
            case IDENTITY: {
                regularizer = new IdentityRegularization((Model<?>)new AffineModel1D(), n);
                break;
            }
            default: {
                // NONE and any other value: coordinates are not regularized.
                regularizer = new NoRegularization();
            }
        }
        double[] shiftsArray = new double[n];
        double[] weightSums = new double[n];
        for (int iteration = 0; iteration < options.nIterations; ++iteration) {
            PermutationTransform permutation = new PermutationTransform(inverse, nMatrixDim, nMatrixDim);
            IntervalView matrix = Views.interval((RandomAccessible)new TransformView(inputMatrix, (Transform)permutation), inputMatrix);
            IntervalView scaledMatrix = Views.interval((RandomAccessible)new TransformView(inputScaledMatrix, (Transform)permutation), inputScaledMatrix);
            if (iteration == 0) {
                // Report the initial state before anything has been updated.
                visitor.act(iteration, matrix, scaledMatrix, lut, permutationLut, inverse, scalingFactors, (RandomAccessibleInterval<double[]>)correlationFitsStore[0]);
            }
            // The accumulators are reused across iterations and must start from zero.
            Arrays.fill(shiftsArray, 0.0);
            Arrays.fill(weightSums, 0.0);
            double[] shifts = this.getMediatedShifts((RandomAccessibleInterval<T>)matrix, (RandomAccessibleInterval<T>)scaledMatrix, permutedLut, scalingFactors, iteration, (RandomAccessibleInterval<double[]>[])correlationFitsStore, shiftsArray, weightSums, estimateWeights, shiftWeights, options);
            this.applyShifts(permutedLut, shifts, startingCoordinates, permutation.copyToDimension(1, 1), options);
            ReplaceNaNs.replace(permutedLut);
            if (!options.withReorder.booleanValue()) {
                this.preventReorder(permutedLut, options);
            }
            regularizer.regularize(permutedLut, options);
            // Scatter the sorted-order results back through the current inverse permutation ...
            this.updateArray(permutedLut, lut, inverse);
            this.updateArray(scalingFactors, scalingFactorsPrevious, inverse);
            // ... then re-sort, because the update may have changed the order of the coordinates,
            // and bring the scaling factors into the new order.
            permutedLut = (double[])lut.clone();
            ArraySortedIndices.sort(permutedLut, permutationLut, inverse);
            this.updateArray(scalingFactorsPrevious, scalingFactors, permutationLut);
            visitor.act(iteration + 1, matrix, scaledMatrix, lut, permutationLut, inverse, scalingFactors, (RandomAccessibleInterval<double[]>)correlationFitsStore[0]);
        }
        return lut;
    }

    /**
     * Computes one mediated shift per entry of {@code lut}.
     *
     * <p>Steps, in order:
     * <ol>
     *   <li>Estimates the correlation fits from {@code scaledMatrix} and stores them in
     *       {@code correlationFitsStore[0]}. Nearest-neighbor interpolation is used when
     *       {@code lut} is the identity mapping, n-linear interpolation otherwise.</li>
     *   <li>Passes {@code matrix}, {@code scalingFactors}, {@code lut} and the fits to
     *       {@code EstimateScalingFactors.estimateQuadraticFromMatrix} (presumably updating
     *       {@code scalingFactors} in place - confirm against that helper).</li>
     *   <li>Overwrites every entry of {@code scaledMatrix} that lies within
     *       {@code options.comparisonRange} of the diagonal with the matching entry of
     *       {@code matrix}; off-diagonal entries are additionally multiplied by
     *       {@code scalingFactors[z] * scalingFactors[k]}.</li>
     *   <li>Passes the accumulators to {@code ShiftCoordinates.collectShiftsFromMatrix} and
     *       divides {@code shiftsArray} element-wise by {@code weightSums}.</li>
     * </ol>
     *
     * @param matrix               unscaled matrix
     * @param scaledMatrix         scaled matrix; written to by this method
     * @param lut                  current coordinates
     * @param scalingFactors       per-section scaling factors
     * @param iteration            not read by this method
     * @param correlationFitsStore holder whose element 0 receives the estimated fits
     * @param shiftsArray          shift accumulator handed to the shift collection
     * @param weightSums           weight accumulator handed to the shift collection
     * @param estimateWeightMatrix weights forwarded to the fit and scaling-factor estimation
     * @param shiftWeights         weights forwarded to the shift collection
     * @param options              inference options
     * @return a new array of length {@code lut.length} holding
     *         {@code shiftsArray[i] / weightSums[i]}
     * @throws NotEnoughDataPointsException  declared for the called estimation helpers
     * @throws IllDefinedDataPointsException declared for the called estimation helpers
     */
    public <T extends RealType<T>, W extends RealType<W>> double[] getMediatedShifts(RandomAccessibleInterval<T> matrix, RandomAccessibleInterval<T> scaledMatrix, double[] lut, double[] scalingFactors, int iteration, RandomAccessibleInterval<double[]>[] correlationFitsStore, double[] shiftsArray, double[] weightSums, RandomAccessibleInterval<W> estimateWeightMatrix, double[] shiftWeights, Options options) throws NotEnoughDataPointsException, IllDefinedDataPointsException {
        final int numDims = scaledMatrix.numDimensions();
        final LUTRealTransform lutTransform = new LUTRealTransform(lut, numDims, numDims);
        final boolean identityLut = InferFromMatrix.isIdentity(lut);

        final RandomAccessibleInterval<double[]> fits = this.correlationFit.estimateFromMatrix(scaledMatrix, lut, lutTransform, estimateWeightMatrix, options, identityLut ? new NearestNeighborInterpolatorFactory() : new NLinearInterpolatorFactory());
        correlationFitsStore[0] = fits;

        EstimateScalingFactors.estimateQuadraticFromMatrix(matrix, scalingFactors, lut, fits, options.scalingFactorRegularizerWeight, options.comparisonRange, options.scalingFactorEstimationIterations, estimateWeightMatrix);

        // Refresh the scaled matrix inside the comparison band from the unscaled one.
        final RandomAccess rawAccess = matrix.randomAccess();
        final RandomAccess scaledAccess = scaledMatrix.randomAccess();
        final int sectionCount = lut.length;
        for (int z = 0; z < sectionCount; ++z) {
            rawAccess.setPosition(z, 0);
            scaledAccess.setPosition(z, 0);
            final int first = Math.max(0, z - options.comparisonRange);
            final int end = Math.min(sectionCount, z + options.comparisonRange + 1);
            for (int k = first; k < end; ++k) {
                rawAccess.setPosition(k, 1);
                scaledAccess.setPosition(k, 1);
                ((RealType)scaledAccess.get()).set((Type)rawAccess.get());
                // The diagonal entry is copied as-is; only off-diagonal entries are scaled.
                if (k != z) {
                    ((RealType)scaledAccess.get()).mul(scalingFactors[z] * scalingFactors[k]);
                }
            }
        }

        ShiftCoordinates.collectShiftsFromMatrix(lut, scaledMatrix, scalingFactors, fits, shiftsArray, weightSums, shiftWeights, options);

        final double[] result = new double[sectionCount];
        InferFromMatrix.mediateShifts(shiftsArray, weightSums, result);
        return result;
    }

    /**
     * Applies the shifts to the coordinates in place and blends the result with the
     * regularizer coordinates.
     *
     * <p>For every index {@code i} with a finite shift the new value is
     * {@code w * regularizerCoordinates[permutation.applyInverse(i)]
     * + (1 - w) * (coordinates[i] + options.shiftProportion * shifts[i])}
     * with {@code w = options.coordinateUpdateRegularizerWeight}. Entries whose shift is NaN
     * or infinite are left completely untouched (they are not blended either).
     *
     * @param coordinates            coordinates to update; modified in place
     * @param shifts                 one shift per coordinate
     * @param regularizerCoordinates coordinates the result is pulled towards
     * @param permutation            maps index {@code i} to the index used in
     *                               {@code regularizerCoordinates} via {@code applyInverse}
     * @param options                supplies {@code shiftProportion} and
     *                               {@code coordinateUpdateRegularizerWeight}
     */
    public void applyShifts(double[] coordinates, double[] shifts, double[] regularizerCoordinates, PermutationTransform permutation, Options options) {
        final double keepWeight = 1.0 - options.coordinateUpdateRegularizerWeight;
        for (int i = 0; i < coordinates.length; ++i) {
            final double shift = shifts[i];
            if (Double.isFinite(shift)) {
                final double shifted = coordinates[i] + options.shiftProportion * shift;
                coordinates[i] = options.coordinateUpdateRegularizerWeight * regularizerCoordinates[permutation.applyInverse(i)] + keepWeight * shifted;
            }
        }
    }

    /**
     * Removes descending steps from the coordinates in place.
     *
     * <p>Walks the array from left to right; whenever an entry is strictly smaller than its
     * (possibly already corrected) predecessor, it is replaced by the predecessor plus
     * {@code options.minimumSectionThickness}. Entries that are equal to or larger than their
     * predecessor are not changed, and neither are comparisons involving NaN.
     *
     * @param coordinates coordinates to correct; modified in place
     * @param options     supplies {@code minimumSectionThickness}
     */
    public void preventReorder(double[] coordinates, Options options) {
        int index = 1;
        while (index < coordinates.length) {
            final double predecessor = coordinates[index - 1];
            if (predecessor > coordinates[index]) {
                coordinates[index] = predecessor + options.minimumSectionThickness;
            }
            ++index;
        }
    }

    /**
     * Scatters {@code source} into {@code target}: {@code target[permutation[i]] = source[i]}
     * for every {@code i} in {@code [0, target.length)}.
     *
     * @param source      values to copy; must have at least {@code target.length} entries
     * @param target      array that receives the values; modified in place
     * @param permutation destination index for every source index
     */
    public void updateArray(double[] source, double[] target, int[] permutation) {
        final int count = target.length;
        for (int from = 0; from < count; ++from) {
            final int to = permutation[from];
            target[to] = source[from];
        }
    }

    /**
     * Writes the element-wise quotient {@code shifts[i] / weightSums[i]} into
     * {@code mediatedShifts} for every index of {@code mediatedShifts}.
     *
     * <p>No special handling of zero weights: the result follows IEEE division and can be
     * NaN or infinite.
     *
     * @param shifts         numerators
     * @param weightSums     denominators
     * @param mediatedShifts output array; its length determines how many entries are computed
     */
    public static void mediateShifts(double[] shifts, double[] weightSums, double[] mediatedShifts) {
        Arrays.setAll(mediatedShifts, i -> shifts[i] / weightSums[i]);
    }

    /**
     * Fills a strip image with pairwise weight products.
     *
     * <p>All pixels of {@code strip} are first set to NaN. The strip is then accessed through
     * {@code MatrixStripConversion.stripToMatrix(strip, t)}, and for every pair
     * {@code (z1, z2)} with {@code |z1 - z2| <= range} and both indices inside
     * {@code [0, weights.length)} the value at position {@code (z2, z1)} is set to
     * {@code weights[z1] * weights[z2]}, where {@code range} is half the size of the strip
     * along dimension 0 (integer division).
     *
     * @param strip   strip image to fill; modified in place
     * @param weights one weight per index
     * @param t       value forwarded to {@code stripToMatrix}
     */
    public static <T extends RealType<T>> void fillWeightStrip(RandomAccessibleInterval<T> strip, double[] weights, T t) {
        final int count = weights.length;
        final int halfWidth = (int)(strip.dimension(0) / 2L);
        for (RealType pixel : Views.flatIterable(strip)) {
            pixel.setReal(Double.NaN);
        }
        final RandomAccess matrixAccess = MatrixStripConversion.stripToMatrix(strip, t).randomAccess();
        for (int row = 0; row < count; ++row) {
            matrixAccess.setPosition(row, 1);
            final double rowWeight = weights[row];
            // Clamp the band around the diagonal to the valid index range.
            final int from = Math.max(0, row - halfWidth);
            final int to = Math.min(count - 1, row + halfWidth);
            for (int col = from; col <= to; ++col) {
                matrixAccess.setPosition(col, 0);
                ((RealType)matrixAccess.get()).setReal(rowWeight * weights[col]);
            }
        }
    }

    /**
     * Tells whether the mapping sends every index to itself.
     *
     * @param mapping values to check
     * @return {@code true} if {@code mapping[i] == i} for every index (an empty array
     *         qualifies); {@code false} as soon as one entry differs, including NaN
     */
    public static boolean isIdentity(double[] mapping) {
        for (int index = 0; index < mapping.length; ++index) {
            // Equality with the int index already implies that the value is integral.
            if (mapping[index] != index) {
                return false;
            }
        }
        return true;
    }

    public static class IdentityRegularization
    extends ModelRegularization {
        /**
         * Sets up a regularization whose target values are {@code 0, 1, ..., length - 1}, each
         * with weight {@code 1.0}.
         *
         * @param m      model handed to the superclass
         * @param length number of target values and weights
         */
        public IdentityRegularization(Model<?> m, int length) {
            super(m, range(0, length, 1), constVals(length, 1.0));
        }

        /**
         * Uses every coordinate for the fit.
         *
         * @param coordinates current coordinates
         * @return the very same array that was passed in (no copy)
         */
        @Override
        protected double[] extractRelevantCoordinates(double[] coordinates) {
            return coordinates;
        }

        /**
         * Builds the arithmetic sequence {@code start, start + step, start + 2 * step, ...}
         * containing every value that lies before {@code stop} in the direction of
         * {@code step} ({@code stop} itself is excluded).
         *
         * <p>The number of elements is the quotient {@code (stop - start) / step} rounded up,
         * so a trailing partial step still yields an element: {@code range(0, 5, 2)} is
         * {@code [0, 2, 4]}. If {@code step} points away from {@code stop}, or
         * {@code start == stop}, the result is empty.
         *
         * @param start first value of the sequence
         * @param stop  exclusive end of the sequence
         * @param step  increment between consecutive values; may be negative
         * @return a new array with the sequence values
         * @throws ArithmeticException if {@code step} is zero, or if the sequence would have
         *                             more elements than an array can hold
         */
        public static double[] range(int start, int stop, int step) {
            // long arithmetic: stop - start can exceed the int range
            final long span = (long)stop - (long)start;
            // integer division by a zero step throws ArithmeticException
            long count = span / step;
            final long remainder = span % step;
            // Round up when the division truncated a positive, fractional quotient; the
            // remainder carries the sign of span, so "same sign as step" means quotient > 0.
            if (remainder != 0 && (remainder > 0) == (step > 0)) {
                ++count;
            }
            if (count < 0) {
                // step points away from stop: nothing to enumerate
                count = 0;
            }
            final double[] result = new double[Math.toIntExact(count)];
            for (int i = 0; i < result.length; ++i) {
                // computed from the index instead of accumulated, so the value cannot wrap around
                result[i] = (double)((long)start + (long)i * (long)step);
            }
            return result;
        }

        /**
         * Creates an array in which every entry has the same value.
         *
         * @param length number of entries
         * @param val    value stored in every entry
         * @return a new array of the given length filled with {@code val}
         */
        public static double[] constVals(int length, double val) {
            final double[] filled = new double[length];
            Arrays.fill(filled, val);
            return filled;
        }
    }

    /**
     * Regularization that fits the model on the first and last coordinate only, with target
     * values {@code 0} and {@code length - 1} and weight {@code 1.0} each.
     */
    public static class BorderRegularization
    extends ModelRegularization {
        /** Reused two-element buffer; overwritten and returned by every extraction. */
        private final double[] relevantCoordinates = new double[2];

        /**
         * @param m      model handed to the superclass
         * @param length number of coordinates; the last target value is {@code length - 1}
         */
        public BorderRegularization(Model<?> m, int length) {
            super(m, new double[]{0.0, length - 1}, new double[]{1.0, 1.0});
        }

        /**
         * Copies the first and the last coordinate into the internal buffer.
         *
         * @param coordinates current coordinates; must not be empty
         * @return the internal buffer (the same array instance on every call)
         */
        @Override
        protected double[] extractRelevantCoordinates(double[] coordinates) {
            final int last = coordinates.length - 1;
            this.relevantCoordinates[0] = coordinates[0];
            this.relevantCoordinates[1] = coordinates[last];
            return this.relevantCoordinates;
        }
    }

    /**
     * Base class for regularizers that fit a model from selected coordinates to fixed target
     * values and then apply the fitted model to all coordinates.
     */
    public static abstract class ModelRegularization
    implements Regularizer {
        private final Model<?> m;
        private final double[] regularizationValues;
        private final double[] weights;
        /** One-element buffer used to push single coordinates through the model. */
        private final double[] scratch = new double[1];

        /**
         * @param m                    model that is fitted and applied; stored as-is
         * @param regularizationValues target values of the fit; stored as-is
         * @param weights              weights of the fit; stored as-is
         */
        protected ModelRegularization(Model<?> m, double[] regularizationValues, double[] weights) {
            this.m = m;
            this.regularizationValues = regularizationValues;
            this.weights = weights;
        }

        /**
         * Selects the coordinates that are matched against the target values.
         *
         * @param coordinates current coordinates
         * @return the coordinates used as the source side of the fit
         */
        protected abstract double[] extractRelevantCoordinates(double[] coordinates);

        /**
         * Fits the model from the relevant coordinates to the target values and replaces every
         * coordinate by its image under the fitted model. {@code options} is not read.
         *
         * @param coordinates coordinates to regularize; modified in place
         * @param options     unused
         * @throws NotEnoughDataPointsException  declared for the model fit
         * @throws IllDefinedDataPointsException declared for the model fit
         */
        @Override
        public void regularize(double[] coordinates, Options options) throws NotEnoughDataPointsException, IllDefinedDataPointsException {
            // Both sides of the fit are passed as a single row of values.
            final double[][] sourceSide = new double[][]{this.extractRelevantCoordinates(coordinates)};
            final double[][] targetSide = new double[][]{this.regularizationValues};
            this.m.fit(sourceSide, targetSide, this.weights);
            for (int i = 0; i < coordinates.length; ++i) {
                this.scratch[0] = coordinates[i];
                this.m.applyInPlace(this.scratch);
                coordinates[i] = this.scratch[0];
            }
        }
    }

    /** Regularizer that does nothing: the coordinates are left exactly as they are. */
    public static class NoRegularization
    implements Regularizer {
        @Override
        public void regularize(double[] coordinates, Options options) {
            // intentionally empty
        }
    }

    public static interface Regularizer {
        public void regularize(double[] var1, Options var2) throws Exception;
    }

    /**
     * Selects the regularizer that {@code estimateZCoordinates} instantiates:
     * {@code NONE} yields a {@code NoRegularization}, {@code IDENTITY} an
     * {@code IdentityRegularization} and {@code BORDER} a {@code BorderRegularization}
     * (the latter two with an {@code AffineModel1D}).
     */
    public static enum RegularizationType {
        NONE,
        IDENTITY,
        BORDER;

    }
}

