/*
 * Decompiled with CFR 0.152.
 */
package org.apache.flink.table.planner.functions.sql.ml;

import java.util.Collections;
import java.util.List;
import org.apache.calcite.rel.type.RelDataType;
import org.apache.calcite.rel.type.RelDataTypeFactory;
import org.apache.calcite.sql.SqlCallBinding;
import org.apache.calcite.sql.SqlOperandCountRange;
import org.apache.calcite.sql.SqlOperator;
import org.apache.calcite.sql.SqlOperatorBinding;
import org.apache.calcite.sql.type.SqlOperandCountRanges;
import org.apache.calcite.sql.type.SqlOperandMetadata;
import org.apache.calcite.sql.type.SqlTypeName;
import org.apache.flink.table.planner.functions.sql.ml.SqlMLTableFunction;
import org.apache.flink.table.planner.functions.utils.SqlValidatorUtils;

/**
 * SQL table function {@code ML_PREDICT}.
 *
 * <p>Operands, in order: {@code INPUT} (a table), {@code MODEL}, {@code ARGS} (written as
 * {@code DESCRIPTOR(input_columns)}) and an optional {@code CONFIG} map. The result row type is
 * every column of the input table followed by the model's output columns, where the latter are
 * passed through {@link SqlValidatorUtils#makeOutputUnique} together with the input columns
 * (presumably to resolve name clashes — confirm in that helper).
 */
public class SqlMLPredictTableFunction extends SqlMLTableFunction {

    /** Creates the {@code ML_PREDICT} operator together with its operand metadata. */
    public SqlMLPredictTableFunction() {
        super("ML_PREDICT", new PredictOperandMetadata());
    }

    /**
     * Reports whether the operand at {@code ordinal} must be a scalar expression.
     *
     * @param ordinal zero-based operand position
     * @return {@code false} only for ordinal 0 (the input table); {@code true} for every other operand
     */
    @Override
    public boolean argumentMustBeScalar(int ordinal) {
        final boolean isInputTable = ordinal == 0;
        return !isInputTable;
    }

    /**
     * Builds the output row type: the input table's columns (keeping the input's struct kind),
     * followed by the model output columns as returned by
     * {@link SqlValidatorUtils#makeOutputUnique}.
     *
     * @param opBinding binding whose operand 0 is the input table type and operand 1 is treated as
     *     the model output row type
     * @return the combined row type
     */
    @Override
    protected RelDataType inferRowType(SqlOperatorBinding opBinding) {
        final RelDataType tableType = opBinding.getOperandType(0);
        final RelDataType modelType = opBinding.getOperandType(1);

        // builder() mutates and returns itself, so the steps can be issued one by one.
        final RelDataTypeFactory.Builder rowBuilder = opBinding.getTypeFactory().builder();
        rowBuilder.kind(tableType.getStructKind());
        // Input columns come first, unchanged.
        rowBuilder.addAll(tableType.getFieldList());
        // Model output columns follow, adjusted against the input column list.
        rowBuilder.addAll(
                SqlValidatorUtils.makeOutputUnique(
                        tableType.getFieldList(), modelType.getFieldList()));
        return rowBuilder.build();
    }

    /** Operand metadata for {@code ML_PREDICT}: three mandatory operands plus an optional config. */
    private static class PredictOperandMetadata implements SqlOperandMetadata {
        private static final List<String> PARAM_NAMES = List.of("INPUT", "MODEL", "ARGS", "CONFIG");
        private static final List<String> MANDATORY_PARAM_NAMES = List.of("INPUT", "MODEL", "ARGS");

        /** Zero-based position of the {@code ARGS} descriptor operand. */
        private static final int DESCRIPTOR_ORDINAL = 2;

        /** Zero-based position of the optional {@code CONFIG} operand. */
        private static final int CONFIG_ORDINAL = 3;

        PredictOperandMetadata() {
        }

        /**
         * Declares every parameter with SQL type {@code ANY}; actual operand validation is done in
         * {@link #checkOperandTypes}.
         */
        @Override
        public List<RelDataType> paramTypes(RelDataTypeFactory typeFactory) {
            final RelDataType anyType = typeFactory.createSqlType(SqlTypeName.ANY);
            return Collections.nCopies(PARAM_NAMES.size(), anyType);
        }

        /** Returns all parameter names, mandatory ones first, in positional order. */
        @Override
        public List<String> paramNames() {
            return PARAM_NAMES;
        }

        /**
         * Validates the operands in three steps: the table/descriptor operands, the model signature,
         * and finally the config map, but only when a fourth operand is present. Failures are routed
         * through {@link SqlValidatorUtils} helpers that either throw or return {@code false}
         * depending on {@code throwOnFailure}.
         */
        @Override
        public boolean checkOperandTypes(SqlCallBinding callBinding, boolean throwOnFailure) {
            final boolean tableAndDescriptorValid =
                    SqlValidatorUtils.checkTableAndDescriptorOperands(callBinding, DESCRIPTOR_ORDINAL);
            if (!tableAndDescriptorValid) {
                return SqlValidatorUtils.throwValidationSignatureErrorOrReturnFalse(
                        callBinding, throwOnFailure);
            }

            // NOTE(review): assumes the int argument of checkModelSignature is the descriptor
            // operand position — confirm against SqlMLTableFunction.
            final boolean modelValid =
                    SqlValidatorUtils.throwExceptionOrReturnFalse(
                            SqlMLTableFunction.checkModelSignature(callBinding, DESCRIPTOR_ORDINAL),
                            throwOnFailure);
            if (!modelValid) {
                return false;
            }

            final boolean configSupplied = callBinding.getOperandCount() >= PARAM_NAMES.size();
            if (!configSupplied) {
                return true;
            }
            return SqlValidatorUtils.throwExceptionOrReturnFalse(
                    SqlMLTableFunction.checkConfig(callBinding, callBinding.operand(CONFIG_ORDINAL)),
                    throwOnFailure);
        }

        /** Accepts between the mandatory operand count and the full parameter count, inclusive. */
        @Override
        public SqlOperandCountRange getOperandCountRange() {
            final int minOperands = MANDATORY_PARAM_NAMES.size();
            final int maxOperands = PARAM_NAMES.size();
            return SqlOperandCountRanges.between(minOperands, maxOperands);
        }

        /** An operand is optional when it lies at or past the minimum count but below the maximum. */
        @Override
        public boolean isOptional(int i) {
            final SqlOperandCountRange range = this.getOperandCountRange();
            return range.getMin() <= i && i < range.getMax();
        }

        /** Human-readable signature used in validation error messages. */
        @Override
        public String getAllowedSignatures(SqlOperator op, String opName) {
            return opName + "(TABLE table_name, MODEL model_name, DESCRIPTOR(input_columns), [MAP[]])";
        }
    }
}

