/*
 * Decompiled with CFR 0.152.
 */
package ai.onnxruntime.platform;

import java.lang.invoke.MethodHandle;
import java.lang.invoke.MethodHandles;
import java.lang.invoke.MethodType;
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.nio.FloatBuffer;
import java.nio.ShortBuffer;
import java.util.logging.Level;
import java.util.logging.Logger;

/**
 * Conversions between 32-bit {@code float} values and two 16-bit floating point formats stored in
 * Java {@code short}s: fp16 (IEEE 754 binary16) and bf16 (bfloat16, the upper 16 bits of a float).
 *
 * <p>The scalar fp16 conversions dispatch through method handles resolved in the static
 * initializer. It prefers {@code Float.float16ToFloat} / {@code Float.floatToFloat16} when the
 * running JDK provides them. Otherwise it falls back to {@link #mlasFp16ToFloat(short)} /
 * {@link #mlasFloatToFp16(float)}.
 */
public final class Fp16Conversions {
    private static final Logger logger = Logger.getLogger(Fp16Conversions.class.getName());

    /** Handle of type {@code (short)float}; null only if both lookups in the static initializer failed. */
    private static final MethodHandle fp16ToFp32;

    /** Handle of type {@code (float)short}; null only if both lookups in the static initializer failed. */
    private static final MethodHandle fp32ToFp16;

    static {
        MethodHandle tmp16 = null;
        MethodHandle tmp32 = null;
        MethodHandles.Lookup lookup = MethodHandles.lookup();
        try {
            // Prefer the JDK's own conversions when this runtime provides them.
            tmp16 = lookup.findStatic(
                    Float.class, "float16ToFloat", MethodType.methodType(float.class, short.class));
            tmp32 = lookup.findStatic(
                    Float.class, "floatToFloat16", MethodType.methodType(short.class, float.class));
        } catch (IllegalAccessException | NoSuchMethodException e) {
            // The JDK methods are unavailable, so bind to the ports defined in this class.
            try {
                tmp16 = lookup.findStatic(Fp16Conversions.class, "mlasFp16ToFloat",
                        MethodType.methodType(float.class, short.class));
                tmp32 = lookup.findStatic(Fp16Conversions.class, "mlasFloatToFp16",
                        MethodType.methodType(short.class, float.class));
            } catch (IllegalAccessException | NoSuchMethodException ex) {
                // Report the fallback failure, keeping the original lookup failure for context.
                ex.addSuppressed(e);
                logger.log(Level.SEVERE, "Failed to find fp16 conversion methods on Fp16Conversions", ex);
            }
        }
        fp16ToFp32 = tmp16;
        fp32ToFp16 = tmp32;
    }

    /**
     * Converts the remaining elements of {@code buf} from fp32 to fp16.
     *
     * <p>Reads use absolute indexing, so the input buffer's position and limit are not modified.
     *
     * @param buf the values to convert; the elements between its position and limit are read
     * @return a new direct, native-order buffer with position 0 and limit equal to the number of
     *     converted elements
     * @throws IllegalArgumentException if the output would need more than {@code Integer.MAX_VALUE} bytes
     */
    public static ShortBuffer convertFloatBufferToFp16Buffer(FloatBuffer buf) {
        int pos = buf.position();
        int remaining = buf.remaining();
        ShortBuffer output = allocateDirectNative(remaining, Short.BYTES).asShortBuffer();
        for (int i = 0; i < remaining; ++i) {
            output.put(i, floatToFp16(buf.get(pos + i)));
        }
        return output;
    }

    /**
     * Converts the remaining fp16 elements of {@code buf} to fp32.
     *
     * <p>Reads use absolute indexing, so the input buffer's position and limit are not modified.
     *
     * @param buf fp16 bit patterns; the elements between its position and limit are read
     * @return a new direct, native-order buffer with position 0 and limit equal to the number of
     *     converted elements
     * @throws IllegalArgumentException if the output would need more than {@code Integer.MAX_VALUE} bytes
     */
    public static FloatBuffer convertFp16BufferToFloatBuffer(ShortBuffer buf) {
        int pos = buf.position();
        int remaining = buf.remaining();
        FloatBuffer output = allocateDirectNative(remaining, Float.BYTES).asFloatBuffer();
        for (int i = 0; i < remaining; ++i) {
            output.put(i, fp16ToFloat(buf.get(pos + i)));
        }
        return output;
    }

    /**
     * Converts the remaining elements of {@code buf} from fp32 to bf16 (round to nearest even).
     *
     * <p>Reads use absolute indexing, so the input buffer's position and limit are not modified.
     *
     * @param buf the values to convert; the elements between its position and limit are read
     * @return a new direct, native-order buffer with position 0 and limit equal to the number of
     *     converted elements
     * @throws IllegalArgumentException if the output would need more than {@code Integer.MAX_VALUE} bytes
     */
    public static ShortBuffer convertFloatBufferToBf16Buffer(FloatBuffer buf) {
        int pos = buf.position();
        int remaining = buf.remaining();
        ShortBuffer output = allocateDirectNative(remaining, Short.BYTES).asShortBuffer();
        for (int i = 0; i < remaining; ++i) {
            output.put(i, floatToBf16(buf.get(pos + i)));
        }
        return output;
    }

    /**
     * Converts the remaining bf16 elements of {@code buf} to fp32.
     *
     * <p>Reads use absolute indexing, so the input buffer's position and limit are not modified.
     *
     * @param buf bf16 bit patterns; the elements between its position and limit are read
     * @return a new direct, native-order buffer with position 0 and limit equal to the number of
     *     converted elements
     * @throws IllegalArgumentException if the output would need more than {@code Integer.MAX_VALUE} bytes
     */
    public static FloatBuffer convertBf16BufferToFloatBuffer(ShortBuffer buf) {
        int pos = buf.position();
        int remaining = buf.remaining();
        FloatBuffer output = allocateDirectNative(remaining, Float.BYTES).asFloatBuffer();
        for (int i = 0; i < remaining; ++i) {
            output.put(i, bf16ToFloat(buf.get(pos + i)));
        }
        return output;
    }

    /** Allocates a direct native-order buffer for {@code elements} items of {@code bytesPerElement} bytes. */
    private static ByteBuffer allocateDirectNative(int elements, int bytesPerElement) {
        final int bytes;
        try {
            // Without the exact multiply, large counts wrap silently (e.g. (1 << 30) * 4 == 0).
            bytes = Math.multiplyExact(elements, bytesPerElement);
        } catch (ArithmeticException e) {
            throw new IllegalArgumentException("Cannot convert " + elements
                    + " elements: output would exceed Integer.MAX_VALUE bytes", e);
        }
        return ByteBuffer.allocateDirect(bytes).order(ByteOrder.nativeOrder());
    }

    /**
     * Converts an fp16 bit pattern to a float using the handle chosen at class initialization.
     *
     * @param input fp16 bits
     * @return the equivalent float value
     * @throws AssertionError if no conversion handle could be resolved, or the target throws
     */
    public static float fp16ToFloat(short input) {
        try {
            // invokeExact is signature-polymorphic: the cast fixes the call-site type to
            // (short)float, which must match the handle's type exactly.
            return (float) fp16ToFp32.invokeExact(input);
        } catch (Throwable e) {
            throw new AssertionError("Should not reach here", e);
        }
    }

    /**
     * Converts a float to an fp16 bit pattern using the handle chosen at class initialization.
     *
     * @param input the value to convert
     * @return the fp16 bits
     * @throws AssertionError if no conversion handle could be resolved, or the target throws
     */
    public static short floatToFp16(float input) {
        try {
            // The cast fixes the call-site type to (float)short; see fp16ToFloat.
            return (short) fp32ToFp16.invokeExact(input);
        } catch (Throwable e) {
            throw new AssertionError("Should not reach here", e);
        }
    }

    /**
     * Converts fp16 bits to a float using integer bit manipulation. This is the fallback used
     * when the JDK does not provide {@code Float.float16ToFloat}.
     *
     * @param input fp16 bits
     * @return the equivalent float value
     */
    public static float mlasFp16ToFloat(short input) {
        // Float bits of 2^-14 (biased exponent 113), the smallest normal fp16 magnitude.
        final int magic = 113 << 23;
        // The fp16 exponent mask (0x7c00) shifted into the float exponent field.
        final int shiftedExp = 0x7c00 << 13;

        // Move the fp16 exponent and mantissa into the float exponent/mantissa positions.
        int bits = (input & 0x7fff) << 13;
        final int exp = shiftedExp & bits;
        // Rebias the exponent from fp16 (bias 15) to float (bias 127).
        bits += (127 - 15) << 23;

        if (exp == shiftedExp) {
            // Inf/NaN: push the exponent the rest of the way to all ones.
            bits += (128 - 16) << 23;
        } else if (exp == 0) {
            // Zero/subnormal: compute 2^-14 * (1 + m/1024) - 2^-14 in float arithmetic, which
            // yields m * 2^-24 correctly normalized.
            bits += 1 << 23;
            float tmp = Float.intBitsToFloat(bits) - Float.intBitsToFloat(magic);
            bits = Float.floatToIntBits(tmp);
        }

        // Move the fp16 sign bit (bit 15) to the float sign bit (bit 31).
        bits |= (input & 0x8000) << 16;
        return Float.intBitsToFloat(bits);
    }

    /**
     * Converts a float to fp16 bits using integer bit manipulation, rounding to nearest even.
     * This is the fallback used when the JDK does not provide {@code Float.floatToFloat16}.
     * Magnitudes at or above 2^16 become infinity. Any NaN becomes the quiet NaN {@code 0x7e00}
     * (sign preserved).
     *
     * @param input the value to convert
     * @return the fp16 bits
     */
    public static short mlasFloatToFp16(float input) {
        final int f32Infinity = Float.floatToIntBits(Float.POSITIVE_INFINITY);
        // Float bits of 2^16. Magnitudes at or above this are not finite in fp16.
        final int f16Max = (127 + 16) << 23;
        // Float bits of 0.5f. Adding it aligns a subnormal fp16 mantissa with the float's low bits.
        final int denormMagic = ((127 - 15) + (23 - 10) + 1) << 23;
        final int signMask = 0x80000000;
        // Rebias the exponent from 127 to 15 and add just under half an fp16 ulp.
        final int roundingConst = ((15 - 127) << 23) + 0xfff;

        int bits = Float.floatToIntBits(input);
        final int sign = bits & signMask;
        // Clear the sign bit so the comparisons below operate on the magnitude.
        bits ^= sign;

        short output;
        if (bits >= f16Max) {
            // Inf or NaN: NaN bit patterns are larger than the infinity pattern.
            output = bits > f32Infinity ? (short) 0x7e00 : (short) 0x7c00;
        } else if (bits < (113 << 23)) {
            // Subnormal or zero. Float addition (round to nearest even) does the rounding.
            float tmp = Float.intBitsToFloat(bits) + Float.intBitsToFloat(denormMagic);
            output = (short) (Float.floatToIntBits(tmp) - denormMagic);
        } else {
            // Normal range. Adding 1 when the kept mantissa is odd turns round-half-down
            // into round-half-to-even.
            final int mantOdd = (bits >> 13) & 1;
            bits += roundingConst;
            bits += mantOdd;
            output = (short) (bits >> 13);
        }

        // Restore the sign in bit 15.
        return (short) (output | (short) (sign >> 16));
    }

    /**
     * Converts bf16 bits to a float. This is exact, because bf16 is the upper half of a float.
     *
     * @param input bf16 bits
     * @return the equivalent float value
     */
    public static float bf16ToFloat(short input) {
        // The sign extension from the short -> int promotion is shifted out of the top.
        return Float.intBitsToFloat(input << 16);
    }

    /**
     * Converts a float to bf16 bits, rounding to nearest with ties to even.
     *
     * @param input the value to convert
     * @return the bf16 bits
     */
    public static short floatToBf16(float input) {
        // floatToIntBits canonicalizes NaN to 0x7fc00000, so the rounding bias below cannot carry
        // a NaN into the exponent or overflow the int.
        int bits = Float.floatToIntBits(input);
        // Add just under half a bf16 ulp, plus one more when the kept lsb is odd (ties to even).
        final int lsb = (bits >> 16) & 1;
        bits += 0x7fff + lsb;
        return (short) (bits >> 16);
    }
}

