/*
 * Decompiled with CFR 0.152.
 */
package org.apache.spark.sql.catalyst.expressions;

import org.apache.spark.sql.catalyst.util.ArrayData;
import org.apache.spark.sql.errors.QueryExecutionErrors;
import org.apache.spark.unsafe.types.UTF8String;

/**
 * Static implementations of vector functions over arrays of {@code float}: cosine similarity,
 * inner product, L2 distance, L1/L2/infinity norms and normalization.
 *
 * <p>Every function returns {@code null} as soon as it encounters a null element in any input.
 * The main loops read eight elements per iteration, and a scalar loop handles the remaining tail.
 * Sums are accumulated in {@code float}.
 */
public class VectorFunctionImplUtils {

    /** Returns true if any of the {@code count} elements of {@code data} starting at {@code from} is null. */
    private static boolean anyNull(ArrayData data, int from, int count) {
        int end = from + count;
        for (int k = from; k < end; k++) {
            if (data.isNullAt(k)) {
                return true;
            }
        }
        return false;
    }

    /**
     * Returns the common element count of {@code left} and {@code right}.
     *
     * @throws RuntimeException the error built by {@code QueryExecutionErrors.vectorDimensionMismatchError}
     *     (reported under {@code funcName}) when the two arrays have different lengths
     */
    private static int requireSameDimension(ArrayData left, ArrayData right, UTF8String funcName) {
        int leftLen = left.numElements();
        int rightLen = right.numElements();
        if (leftLen != rightLen) {
            throw QueryExecutionErrors.vectorDimensionMismatchError(funcName.toString(), leftLen, rightLen);
        }
        return leftLen;
    }

    /**
     * Computes the cosine similarity {@code dot(left, right) / (||left||_2 * ||right||_2)}.
     *
     * @param left first vector
     * @param right second vector; must have as many elements as {@code left}
     * @param funcName SQL function name used in the dimension-mismatch error
     * @return the similarity, or {@code null} if the vectors are empty, contain a null element,
     *     or the product of their norms is below {@link Float#MIN_NORMAL} (e.g. a zero vector)
     * @throws RuntimeException the dimension-mismatch error when the lengths differ
     */
    public static Float vectorCosineSimilarity(ArrayData left, ArrayData right, UTF8String funcName) {
        int len = requireSameDimension(left, right, funcName);
        if (len == 0) {
            return null;
        }
        float dotProduct = 0.0f;
        float norm1Sq = 0.0f;
        float norm2Sq = 0.0f;
        int simdLimit = len / 8 * 8;
        int i = 0;
        for (; i < simdLimit; i += 8) {
            if (anyNull(left, i, 8) || anyNull(right, i, 8)) {
                return null;
            }
            float a0 = left.getFloat(i);
            float a1 = left.getFloat(i + 1);
            float a2 = left.getFloat(i + 2);
            float a3 = left.getFloat(i + 3);
            float a4 = left.getFloat(i + 4);
            float a5 = left.getFloat(i + 5);
            float a6 = left.getFloat(i + 6);
            float a7 = left.getFloat(i + 7);
            float b0 = right.getFloat(i);
            float b1 = right.getFloat(i + 1);
            float b2 = right.getFloat(i + 2);
            float b3 = right.getFloat(i + 3);
            float b4 = right.getFloat(i + 4);
            float b5 = right.getFloat(i + 5);
            float b6 = right.getFloat(i + 6);
            float b7 = right.getFloat(i + 7);
            dotProduct += a0 * b0 + a1 * b1 + a2 * b2 + a3 * b3 + a4 * b4 + a5 * b5 + a6 * b6 + a7 * b7;
            norm1Sq += a0 * a0 + a1 * a1 + a2 * a2 + a3 * a3 + a4 * a4 + a5 * a5 + a6 * a6 + a7 * a7;
            norm2Sq += b0 * b0 + b1 * b1 + b2 * b2 + b3 * b3 + b4 * b4 + b5 * b5 + b6 * b6 + b7 * b7;
        }
        for (; i < len; i++) {
            if (left.isNullAt(i) || right.isNullAt(i)) {
                return null;
            }
            float a = left.getFloat(i);
            float b = right.getFloat(i);
            dotProduct += a * b;
            norm1Sq += a * a;
            norm2Sq += b * b;
        }
        // Multiply the squared norms in double. The exact product of two finite floats always fits
        // in a double. A float product instead overflows to Infinity for large vectors, which makes
        // the result 0 or NaN. It also loses precision or flushes to 0 for small vectors, which
        // makes the result null.
        float normProduct = (float) Math.sqrt((double) norm1Sq * (double) norm2Sq);
        if (normProduct < Float.MIN_NORMAL) {
            return null;
        }
        return Float.valueOf(dotProduct / normProduct);
    }

    /**
     * Computes the inner (dot) product {@code sum(left[i] * right[i])}.
     *
     * @param left first vector
     * @param right second vector; must have as many elements as {@code left}
     * @param funcName SQL function name used in the dimension-mismatch error
     * @return the dot product ({@code 0.0f} for empty vectors), or {@code null} if any element is null
     * @throws RuntimeException the dimension-mismatch error when the lengths differ
     */
    public static Float vectorInnerProduct(ArrayData left, ArrayData right, UTF8String funcName) {
        int len = requireSameDimension(left, right, funcName);
        if (len == 0) {
            return Float.valueOf(0.0f);
        }
        float dotProduct = 0.0f;
        int simdLimit = len / 8 * 8;
        int i = 0;
        for (; i < simdLimit; i += 8) {
            if (anyNull(left, i, 8) || anyNull(right, i, 8)) {
                return null;
            }
            float a0 = left.getFloat(i);
            float a1 = left.getFloat(i + 1);
            float a2 = left.getFloat(i + 2);
            float a3 = left.getFloat(i + 3);
            float a4 = left.getFloat(i + 4);
            float a5 = left.getFloat(i + 5);
            float a6 = left.getFloat(i + 6);
            float a7 = left.getFloat(i + 7);
            float b0 = right.getFloat(i);
            float b1 = right.getFloat(i + 1);
            float b2 = right.getFloat(i + 2);
            float b3 = right.getFloat(i + 3);
            float b4 = right.getFloat(i + 4);
            float b5 = right.getFloat(i + 5);
            float b6 = right.getFloat(i + 6);
            float b7 = right.getFloat(i + 7);
            dotProduct += a0 * b0 + a1 * b1 + a2 * b2 + a3 * b3 + a4 * b4 + a5 * b5 + a6 * b6 + a7 * b7;
        }
        for (; i < len; i++) {
            if (left.isNullAt(i) || right.isNullAt(i)) {
                return null;
            }
            float a = left.getFloat(i);
            float b = right.getFloat(i);
            dotProduct += a * b;
        }
        return Float.valueOf(dotProduct);
    }

    /**
     * Computes the Euclidean distance {@code sqrt(sum((left[i] - right[i])^2))}.
     *
     * @param left first vector
     * @param right second vector; must have as many elements as {@code left}
     * @param funcName SQL function name used in the dimension-mismatch error
     * @return the distance ({@code 0.0f} for empty vectors), or {@code null} if any element is null
     * @throws RuntimeException the dimension-mismatch error when the lengths differ
     */
    public static Float vectorL2Distance(ArrayData left, ArrayData right, UTF8String funcName) {
        int len = requireSameDimension(left, right, funcName);
        if (len == 0) {
            return Float.valueOf(0.0f);
        }
        float sumSq = 0.0f;
        int simdLimit = len / 8 * 8;
        int i = 0;
        for (; i < simdLimit; i += 8) {
            if (anyNull(left, i, 8) || anyNull(right, i, 8)) {
                return null;
            }
            float d0 = left.getFloat(i) - right.getFloat(i);
            float d1 = left.getFloat(i + 1) - right.getFloat(i + 1);
            float d2 = left.getFloat(i + 2) - right.getFloat(i + 2);
            float d3 = left.getFloat(i + 3) - right.getFloat(i + 3);
            float d4 = left.getFloat(i + 4) - right.getFloat(i + 4);
            float d5 = left.getFloat(i + 5) - right.getFloat(i + 5);
            float d6 = left.getFloat(i + 6) - right.getFloat(i + 6);
            float d7 = left.getFloat(i + 7) - right.getFloat(i + 7);
            sumSq += d0 * d0 + d1 * d1 + d2 * d2 + d3 * d3 + d4 * d4 + d5 * d5 + d6 * d6 + d7 * d7;
        }
        for (; i < len; i++) {
            if (left.isNullAt(i) || right.isNullAt(i)) {
                return null;
            }
            float diff = left.getFloat(i) - right.getFloat(i);
            sumSq += diff * diff;
        }
        return Float.valueOf((float) Math.sqrt(sumSq));
    }

    /**
     * Computes the L1 (Manhattan) norm {@code sum(|vec[i]|)}.
     *
     * @param vec input vector
     * @return the norm ({@code 0.0f} for an empty vector), or {@code null} if any element is null
     */
    public static Float vectorL1Norm(ArrayData vec) {
        int len = vec.numElements();
        if (len == 0) {
            return Float.valueOf(0.0f);
        }
        float sum = 0.0f;
        int simdLimit = len / 8 * 8;
        int i = 0;
        for (; i < simdLimit; i += 8) {
            if (anyNull(vec, i, 8)) {
                return null;
            }
            float a0 = vec.getFloat(i);
            float a1 = vec.getFloat(i + 1);
            float a2 = vec.getFloat(i + 2);
            float a3 = vec.getFloat(i + 3);
            float a4 = vec.getFloat(i + 4);
            float a5 = vec.getFloat(i + 5);
            float a6 = vec.getFloat(i + 6);
            float a7 = vec.getFloat(i + 7);
            sum += Math.abs(a0) + Math.abs(a1) + Math.abs(a2) + Math.abs(a3)
                + Math.abs(a4) + Math.abs(a5) + Math.abs(a6) + Math.abs(a7);
        }
        for (; i < len; i++) {
            if (vec.isNullAt(i)) {
                return null;
            }
            sum += Math.abs(vec.getFloat(i));
        }
        return Float.valueOf(sum);
    }

    /**
     * Computes the L2 (Euclidean) norm {@code sqrt(sum(vec[i]^2))}.
     *
     * @param vec input vector
     * @return the norm ({@code 0.0f} for an empty vector), or {@code null} if any element is null
     */
    public static Float vectorL2Norm(ArrayData vec) {
        int len = vec.numElements();
        if (len == 0) {
            return Float.valueOf(0.0f);
        }
        float sumSq = 0.0f;
        int simdLimit = len / 8 * 8;
        int i = 0;
        for (; i < simdLimit; i += 8) {
            if (anyNull(vec, i, 8)) {
                return null;
            }
            float a0 = vec.getFloat(i);
            float a1 = vec.getFloat(i + 1);
            float a2 = vec.getFloat(i + 2);
            float a3 = vec.getFloat(i + 3);
            float a4 = vec.getFloat(i + 4);
            float a5 = vec.getFloat(i + 5);
            float a6 = vec.getFloat(i + 6);
            float a7 = vec.getFloat(i + 7);
            sumSq += a0 * a0 + a1 * a1 + a2 * a2 + a3 * a3 + a4 * a4 + a5 * a5 + a6 * a6 + a7 * a7;
        }
        for (; i < len; i++) {
            if (vec.isNullAt(i)) {
                return null;
            }
            float a = vec.getFloat(i);
            sumSq += a * a;
        }
        return Float.valueOf((float) Math.sqrt(sumSq));
    }

    /**
     * Computes the infinity (max) norm {@code max(|vec[i]|)}.
     *
     * @param vec input vector
     * @return the norm ({@code 0.0f} for an empty vector), or {@code null} if any element is null
     */
    public static Float vectorInfNorm(ArrayData vec) {
        int len = vec.numElements();
        if (len == 0) {
            return Float.valueOf(0.0f);
        }
        float maxAbs = 0.0f;
        for (int i = 0; i < len; i++) {
            if (vec.isNullAt(i)) {
                return null;
            }
            float absVal = Math.abs(vec.getFloat(i));
            // NOTE(review): a NaN element never satisfies `absVal > maxAbs`, so NaNs are skipped
            // here, while vectorL1Norm/vectorL2Norm propagate them. Confirm this is intended.
            if (absVal > maxAbs) {
                maxAbs = absVal;
            }
        }
        return Float.valueOf(maxAbs);
    }

    /**
     * Divides every element of {@code vec} by {@code norm}.
     *
     * @param vec input vector
     * @param norm divisor, normally a norm of {@code vec}
     * @return a new array of the quotients; {@code vec} itself (not a copy) when it is empty; or
     *     {@code null} if {@code norm < Float.MIN_NORMAL} or any element is null
     */
    public static ArrayData vectorNormalizeWithNorm(ArrayData vec, float norm) {
        int len = vec.numElements();
        if (len == 0) {
            return vec;
        }
        // A zero (or denormal-sized) norm cannot be divided by meaningfully.
        if (norm < Float.MIN_NORMAL) {
            return null;
        }
        float[] result = new float[len];
        int simdLimit = len / 8 * 8;
        int i = 0;
        for (; i < simdLimit; i += 8) {
            if (anyNull(vec, i, 8)) {
                return null;
            }
            result[i] = vec.getFloat(i) / norm;
            result[i + 1] = vec.getFloat(i + 1) / norm;
            result[i + 2] = vec.getFloat(i + 2) / norm;
            result[i + 3] = vec.getFloat(i + 3) / norm;
            result[i + 4] = vec.getFloat(i + 4) / norm;
            result[i + 5] = vec.getFloat(i + 5) / norm;
            result[i + 6] = vec.getFloat(i + 6) / norm;
            result[i + 7] = vec.getFloat(i + 7) / norm;
        }
        for (; i < len; i++) {
            if (vec.isNullAt(i)) {
                return null;
            }
            result[i] = vec.getFloat(i) / norm;
        }
        return ArrayData.toArrayData(result);
    }

    /**
     * Computes the p-norm of {@code vec} for the supported degrees 1, 2 and positive infinity.
     *
     * @param vec input vector
     * @param degree norm degree; only {@code 1.0f}, {@code 2.0f} and {@link Float#POSITIVE_INFINITY}
     *     are accepted
     * @param funcName SQL function name used in the invalid-degree error
     * @return the norm, or {@code null} if any element is null
     * @throws RuntimeException the error built by {@code QueryExecutionErrors.invalidVectorNormDegreeError}
     *     for any other degree
     */
    public static Float vectorNorm(ArrayData vec, float degree, UTF8String funcName) {
        if (degree == 1.0f) {
            return vectorL1Norm(vec);
        }
        if (degree == 2.0f) {
            return vectorL2Norm(vec);
        }
        if (degree == Float.POSITIVE_INFINITY) {
            return vectorInfNorm(vec);
        }
        throw QueryExecutionErrors.invalidVectorNormDegreeError(funcName.toString(), degree);
    }

    /**
     * Scales {@code vec} to unit length under the norm of the given degree.
     *
     * @param vec input vector
     * @param degree norm degree, as accepted by {@link #vectorNorm}
     * @param funcName SQL function name used in the invalid-degree error
     * @return the normalized vector, or {@code null} under the conditions of {@link #vectorNorm} and
     *     {@link #vectorNormalizeWithNorm}
     */
    public static ArrayData vectorNormalize(ArrayData vec, float degree, UTF8String funcName) {
        Float norm = vectorNorm(vec, degree, funcName);
        if (norm == null) {
            return null;
        }
        return vectorNormalizeWithNorm(vec, norm.floatValue());
    }
}

