/*
 * Decompiled with CFR 0.152.
 */
package de.rub.nds.protocol.crypto.ec;

import de.rub.nds.protocol.crypto.ec.EllipticCurve;
import de.rub.nds.protocol.crypto.ec.FieldElement;
import de.rub.nds.protocol.crypto.ec.FieldElementFp;
import de.rub.nds.protocol.crypto.ec.Point;
import java.math.BigInteger;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;

/**
 * Elliptic curve in short Weierstrass form {@code y^2 = x^3 + a*x + b} over the field
 * {@code F_p}, where {@code p} is the curve modulus (assumed prime, as required by
 * {@link #modSqrt(BigInteger, BigInteger)}).
 *
 * <p>Point coordinates are {@link FieldElementFp} instances bound to the curve modulus.
 */
public class EllipticCurveOverFp extends EllipticCurve {
    private static final Logger LOGGER = LogManager.getLogger();

    private static final BigInteger TWO = BigInteger.valueOf(2);
    private static final BigInteger THREE = BigInteger.valueOf(3);
    private static final BigInteger FOUR = BigInteger.valueOf(4);

    private final FieldElementFp fieldA;
    private final FieldElementFp fieldB;

    /**
     * Creates a curve from its coefficients and field modulus only.
     *
     * @param a coefficient {@code a} of the curve equation
     * @param b coefficient {@code b} of the curve equation
     * @param p field modulus (assumed prime)
     */
    public EllipticCurveOverFp(BigInteger a, BigInteger b, BigInteger p) {
        super(p);
        this.fieldA = new FieldElementFp(a, this.getModulus());
        this.fieldB = new FieldElementFp(b, this.getModulus());
    }

    /**
     * Creates a curve from its coefficients, field modulus and additional parameters that are
     * forwarded to the superclass.
     *
     * @param a coefficient {@code a} of the curve equation
     * @param b coefficient {@code b} of the curve equation
     * @param p field modulus (assumed prime)
     * @param x presumably the x-coordinate of the base point — see {@link EllipticCurve}
     * @param y presumably the y-coordinate of the base point — see {@link EllipticCurve}
     * @param q presumably the order of the base point — see {@link EllipticCurve}
     */
    public EllipticCurveOverFp(
            BigInteger a, BigInteger b, BigInteger p, BigInteger x, BigInteger y, BigInteger q) {
        super(p, x, y, q);
        this.fieldA = new FieldElementFp(a, this.getModulus());
        this.fieldB = new FieldElementFp(b, this.getModulus());
    }

    /**
     * Wraps the given coordinates into a point whose coordinates are field elements modulo the
     * curve modulus. The point is not checked to lie on the curve.
     */
    @Override
    public Point getPoint(BigInteger x, BigInteger y) {
        FieldElementFp elemX = new FieldElementFp(x, this.getModulus());
        FieldElementFp elemY = new FieldElementFp(y, this.getModulus());
        return new Point(elemX, elemY);
    }

    /**
     * Checks whether {@code p} satisfies {@code y^2 = x^3 + a*x + b}.
     *
     * <p>The point at infinity is considered to be on the curve. Points whose coordinates are not
     * exactly {@link FieldElementFp}, or whose coordinate modulus differs from the curve modulus,
     * are rejected.
     */
    @Override
    public boolean isOnCurve(Point p) {
        if (p.isAtInfinity()) {
            return true;
        }
        if (p.getFieldX().getClass() != FieldElementFp.class
                || p.getFieldY().getClass() != FieldElementFp.class) {
            return false;
        }
        FieldElementFp x = (FieldElementFp) p.getFieldX();
        FieldElementFp y = (FieldElementFp) p.getFieldY();
        // Compare moduli by value: equal BigIntegers are not necessarily the same instance.
        BigInteger modulus = this.getModulus();
        if (!modulus.equals(x.getModulus()) || !modulus.equals(y.getModulus())) {
            return false;
        }
        FieldElementFp leftPart = (FieldElementFp) y.mult(y);
        FieldElementFp rightPart =
                (FieldElementFp) x.mult(x.mult(x)).add(x.mult(this.getFieldA())).add(this.getFieldB());
        return leftPart.equals(rightPart);
    }

    /**
     * Returns {@code (x, -y)} for an affine point {@code (x, y)}.
     *
     * <p>If the coordinates are not {@link FieldElementFp}, a warning is logged and the point
     * {@code (0, 0)} is returned.
     */
    @Override
    protected Point inverseAffine(Point p) {
        if (!(p.getFieldX() instanceof FieldElementFp) || !(p.getFieldY() instanceof FieldElementFp)) {
            LOGGER.warn("Trying to invert non Fp point with Fp curve. Returning point at (0,0)");
            return this.getPoint(BigInteger.ZERO, BigInteger.ZERO);
        }
        FieldElementFp x = (FieldElementFp) p.getFieldX();
        FieldElementFp invY = (FieldElementFp) p.getFieldY().addInv();
        return new Point(x, invY);
    }

    /**
     * Adds two affine points with the chord-and-tangent formulas.
     *
     * <p>The slope is {@code (3*x1^2 + a) / (2*y1)} when {@code p.equals(q)} (doubling) and
     * {@code (y2 - y1) / (x2 - x1)} otherwise. The result is {@code x3 = lambda^2 - x1 - x2} and
     * {@code y3 = lambda*(x1 - x3) - y1}.
     *
     * <p>If any coordinate is not a {@link FieldElementFp}, or an {@link ArithmeticException} is
     * raised (presumably by a division with a non-invertible denominator), a warning is logged and
     * {@code (0, 0)} is returned.
     */
    @Override
    protected Point additionFormular(Point p, Point q) {
        if (!(p.getFieldX() instanceof FieldElementFp
                && p.getFieldY() instanceof FieldElementFp
                && q.getFieldX() instanceof FieldElementFp
                && q.getFieldY() instanceof FieldElementFp)) {
            LOGGER.warn("Trying to add non Fp points with Fp curve. Returning point at (0,0)");
            return this.getPoint(BigInteger.ZERO, BigInteger.ZERO);
        }
        try {
            FieldElementFp x1 = (FieldElementFp) p.getFieldX();
            FieldElementFp y1 = (FieldElementFp) p.getFieldY();
            FieldElementFp x2 = (FieldElementFp) q.getFieldX();
            FieldElementFp y2 = (FieldElementFp) q.getFieldY();
            FieldElementFp lambda;
            if (p.equals(q)) {
                FieldElementFp two = new FieldElementFp(TWO, this.getModulus());
                FieldElementFp three = new FieldElementFp(THREE, this.getModulus());
                lambda = (FieldElementFp) x1.mult(x1).mult(three).add(this.getFieldA()).divide(y1.mult(two));
            } else {
                lambda = (FieldElementFp) y2.subtract(y1).divide(x2.subtract(x1));
            }
            FieldElementFp lambdaSq = (FieldElementFp) lambda.mult(lambda);
            FieldElementFp x3 = (FieldElementFp) lambdaSq.subtract(x1).subtract(x2);
            FieldElementFp y3 = (FieldElementFp) lambda.mult(x1.subtract(x3)).subtract(y1);
            return new Point(x3, y3);
        } catch (ArithmeticException e) {
            LOGGER.warn("Encountered an arithmetic exception during addition. Returning point at 0,0");
            return this.getPoint(BigInteger.ZERO, BigInteger.ZERO);
        }
    }

    /** Creates a field element of this curve's field with the given value. */
    @Override
    public FieldElement createFieldElement(BigInteger value) {
        return new FieldElementFp(value, this.getModulus());
    }

    /**
     * Creates a point on the curve with the given x-coordinate.
     *
     * <p>Computes {@code x^3 + a*x + b mod p} and takes its modular square root. If no root exists,
     * returns the base point (with a warning) when {@code returnBasepointUponError} is set, and
     * {@code null} otherwise. When the root found is even, the point is passed through
     * {@code inverse(...)}, presumably negating y so that an odd y-coordinate is returned.
     *
     * @param x the desired x-coordinate
     * @param returnBasepointUponError whether to fall back to the base point instead of
     *     {@code null} when {@code x} is not the x-coordinate of any curve point
     * @return a point with x-coordinate {@code x}, the base point, or {@code null}
     */
    @Override
    public Point createAPointOnCurve(BigInteger x, boolean returnBasepointUponError) {
        BigInteger modulus = this.getModulus();
        BigInteger rhs = x.pow(3)
                .add(x.multiply(this.getFieldA().getData()))
                .add(this.getFieldB().getData())
                .mod(modulus);
        BigInteger y = this.modSqrt(rhs, modulus);
        if (y == null) {
            if (returnBasepointUponError) {
                LOGGER.warn("Was unable to create point on curve - using basepoint instead");
                return this.getBasePoint();
            }
            return null;
        }
        Point created = this.getPoint(x, y);
        if (!y.testBit(0)) {
            created = this.inverse(created);
        }
        return created;
    }

    /** Returns the coefficient {@code a} as a field element. */
    public FieldElementFp getFieldA() {
        return this.fieldA;
    }

    /** Returns the coefficient {@code b} as a field element. */
    public FieldElementFp getFieldB() {
        return this.fieldB;
    }

    /**
     * Computes the Legendre symbol of {@code a} modulo the odd prime {@code p} via Euler's
     * criterion {@code a^((p-1)/2) mod p}.
     *
     * @return {@code -1} if the power is {@code p-1}, {@code 1} if it is {@code 1}, and {@code 0}
     *     otherwise (i.e. {@code a ≡ 0}, or a non-prime {@code p})
     */
    private int legendreSymbol(BigInteger a, BigInteger p) {
        BigInteger ls = a.modPow(p.subtract(BigInteger.ONE).shiftRight(1), p);
        if (ls.equals(p.subtract(BigInteger.ONE))) {
            return -1;
        }
        // Explicit comparison instead of intValue(): truncation could spuriously yield 1 or -1.
        return ls.equals(BigInteger.ONE) ? 1 : 0;
    }

    /**
     * Computes a square root of {@code a} modulo the prime {@code p}.
     *
     * <p>Uses the direct formula {@code a^((p+1)/4)} when {@code p ≡ 3 (mod 4)} and the
     * Tonelli–Shanks algorithm otherwise.
     *
     * @param a the value whose root is wanted (any integer; it is reduced modulo {@code p})
     * @param p the modulus, assumed prime
     * @return some {@code r} with {@code r^2 ≡ a (mod p)}. This is {@code 0} when
     *     {@code a ≡ 0 (mod p)} and {@code null} when {@code a} is a quadratic non-residue. It is
     *     also {@code null} when the algorithm's invariant is violated, which can only happen for
     *     a non-prime {@code p}.
     */
    public BigInteger modSqrt(BigInteger a, BigInteger p) {
        if (p.equals(TWO)) {
            // Every element of F_2 is its own square root.
            return a.mod(p);
        }
        BigInteger value = a.mod(p);
        if (value.signum() == 0) {
            return BigInteger.ZERO;
        }
        if (this.legendreSymbol(value, p) != 1) {
            return null;
        }
        if (p.mod(FOUR).equals(THREE)) {
            return value.modPow(p.add(BigInteger.ONE).shiftRight(2), p);
        }
        // Tonelli-Shanks: write p - 1 = r * 2^e with r odd.
        BigInteger pMinusOne = p.subtract(BigInteger.ONE);
        int e = pMinusOne.getLowestSetBit();
        BigInteger r = pMinusOne.shiftRight(e);
        // Find a quadratic non-residue n.
        BigInteger n = TWO;
        while (this.legendreSymbol(n, p) != -1) {
            n = n.add(BigInteger.ONE);
        }
        BigInteger y = n.modPow(r, p);
        int s = e;
        BigInteger x = value.modPow(r.add(BigInteger.ONE).shiftRight(1), p); // a^((r+1)/2)
        BigInteger b = value.modPow(r, p); // a^r; invariant: x^2 = a*b and b^(2^(s-1)) = 1
        while (!b.equals(BigInteger.ONE)) {
            // Smallest m >= 1 with b^(2^m) = 1, found by repeated squaring.
            int m = 0;
            BigInteger t = b;
            do {
                t = t.multiply(t).mod(p);
                m++;
            } while (!t.equals(BigInteger.ONE) && m < s);
            if (m >= s) {
                // Invariant broken: cannot happen for prime p; avoid looping or a negative exponent.
                return null;
            }
            BigInteger g = y.modPow(BigInteger.ONE.shiftLeft(s - m - 1), p);
            y = g.multiply(g).mod(p);
            s = m;
            x = x.multiply(g).mod(p);
            b = b.multiply(y).mod(p);
        }
        return x;
    }
}

