from math import sqrt, floor, ceil, gcd
import fractions
import random


def xgcd(a, b):
    """
    Extended Euclidean algorithm.

    Given two integers (a, b), return (g, u, v) where g is a GCD of a and b,
    and (u, v) are the coefficients of the Bezout relation a*u + b*v == g.

    The relation a*u + b*v == g always holds for the returned triple.  For
    non-negative inputs g is the usual non-negative GCD; with negative
    inputs g may come out negative because Python's floor division is used
    (e.g. xgcd(4, -6) yields g == -2).
    """
    # Invariants: a == a0*x0 + b0*y0 and b == a0*x1 + b0*y1, where
    # (a0, b0) are the original arguments.
    (x0, x1, y0, y1) = (1, 0, 0, 1)
    while b != 0:
        q, r = divmod(a, b)
        a, b = b, r
        x0, x1 = x1, x0 - q * x1
        y0, y1 = y1, y0 - q * y1
    # On exit b == 0 and the last non-zero remainder, i.e. the GCD, is in a.
    return (a, x0, y0)

def keygen(l, k):
    """
    Generate an instance of the problem.

    l is the bit size used to draw n (and therefore bounds a < n);
    k is the bit size used to draw the masking coefficients x and y.

    Returns (public, private) where
      public  = (aprime, nprime)
      private = (a, n, x, y, z, t)
    with gcd(a, n) == 1, x*t - z*y == 1 and
      aprime = a*x + n*y
      nprime = a*z + n*t

    Raises ValueError if l or k is smaller than 1 (no valid instance can
    be drawn from zero random bits).
    """
    if l < 1 or k < 1:
        raise ValueError("l and k must be positive bit sizes")

    # Secret pair: n drawn on l bits, a uniform in [0, n), coprime to n.
    while True:
        n = random.getrandbits(l)
        if n == 0:
            # random.randrange(0) would raise ValueError (empty range).
            continue
        a = random.randrange(n)
        if gcd(a, n) == 1:
            break

    # Masking coefficients: a coprime pair of k-bit integers.
    while True:
        x = random.getrandbits(k)
        y = random.getrandbits(k)
        if gcd(x, y) == 1:
            break

    # Bezout coefficients give x*t + y*z == 1; flipping the sign of z
    # turns this into the determinant relation x*t - z*y == 1.
    _, t, z = xgcd(x, y)
    z = -z
    assert x*t - z*y == 1
    aprime = a*x + n*y
    nprime = a*z + n*t
    return (aprime, nprime), (a, n, x, y, z, t)

################################################################

def dot(u, v):
    """
    Inner product of the vectors u and v.

    Coordinates are paired with zip, so if the two vectors differ in
    length the extra coordinates of the longer one are ignored.  An
    empty input yields the integer 0.
    """
    total = 0
    # Explicit left-to-right accumulation (rather than sum()) so that the
    # order and kind of additions is exactly what the callers rely on.
    for left, right in zip(u, v):
        total += left * right
    return total

def sqnorm(u):
    """
    Squared Euclidean norm of the vector u, i.e. the dot product of u
    with itself.  No square root is taken, so integer vectors give an
    exact integer result.
    """
    squared_length = dot(u, u)
    return squared_length

def aupbv(a, u, b, v):
    """
    Linear combination of two vectors: for scalars (a, b) and vectors
    (u, v), build the list whose coordinates are a*u[i] + b*v[i].

    Coordinates are paired with zip, so the result has the length of
    the shorter input vector.
    """
    return [a * ui + b * vi for ui, vi in zip(u, v)]

def norm(u):
    """
    Euclidean norm of the vector u, computed as the floating-point
    square root of its dot product with itself.
    """
    squared_length = dot(u, u)
    return sqrt(squared_length)

def scale(s, v):
    """
    Divide the vector v by the (potentially large) scalar s.

    The division is done by multiplying every coordinate with the exact
    rational 1/s, so integer coordinates produce a list of Fractions
    rather than rounded floats.
    """
    inverse = fractions.Fraction(1, s)
    scaled = []
    for coordinate in v:
        scaled.append(inverse * coordinate)
    return scaled

def lagrange_reduction(u, v):
    """
    Lagrange (Gauss) reduction of a two-dimensional lattice basis.

    Takes a basis (u, v) and returns a pair (u', v') of lattice vectors
    with sqnorm(u') <= sqnorm(v'), obtained by repeatedly subtracting
    the nearest-integer multiple of the shorter vector from the longer
    one.  All arithmetic is exact (integers and Fractions of integers).

    A zero vector in the role of the shorter vector makes the quotient
    computation divide by zero.
    """
    # Start with u as the longer of the two vectors.
    if sqnorm(u) < sqnorm(v):
        u, v = v, u
    while True:
        # Nearest integer to the projection coefficient of u onto v.
        ratio = fractions.Fraction(dot(u, v), dot(v, v))
        remainder = aupbv(1, u, -round(ratio), v)
        u, v = v, remainder
        # Stop as soon as the new remainder is no shorter than its
        # predecessor: the pair is then reduced.
        if sqnorm(u) <= sqnorm(v):
            return (u, v)

def enumerate(r, s, B, N):
    """
    List the vectors x1*r + x2*s of the lattice spanned by (r, s)
    whose norm is at most B.

    (r, s) is assumed to be a reduced basis (the two shortest vectors).
    It is assumed that r and s have norm about N and that B is close
    to N: all lengths are divided by N before being converted to
    floats, which keeps the floating-point quantities of moderate size.

    NOTE(review): this function shadows the builtin `enumerate` for the
    whole module; the name is kept because callers use it.
    """
    # Squared lattice volume (Gram determinant), then volume / N^2.
    gram_det = sqnorm(r) * sqnorm(s) - dot(r, s)**2
    scaled_volume = sqrt(fractions.Fraction(gram_det, N**4))
    # Gram-Schmidt data: projection coefficient of s on r, and the
    # N-scaled lengths of r* (= r) and s* (= volume / |r*|).
    gs_coeff = fractions.Fraction(dot(r, s), dot(r, r))
    r_star_len = norm(scale(N, r))
    s_star_len = scaled_volume / r_star_len
    # Scaled bound and the resulting range for the s-coefficient.
    bound = fractions.Fraction(B, N)
    x2_limit = floor(bound / s_star_len)
    squared_bound = B * B
    found = []
    for x2 in range(-x2_limit, x2_limit + 1):
        # The r-coefficient must satisfy |x1 + mu*x2| <= bound / |r*|.
        shift = gs_coeff * x2
        upper = floor(bound / r_star_len - shift)
        lower = ceil(-bound / r_star_len - shift)
        for x1 in range(lower, upper + 1):
            candidate = aupbv(x1, r, x2, s)
            # The float ranges only pre-select; the exact test decides.
            if sqnorm(candidate) <= squared_bound:
                found.append(candidate)
    return found

def linear_program(aprime, nprime, l, k):
    """
    Return the list of all [gamma, delta] found such that
    |gamma * aprime + delta * nprime| <= 2**l
    |gamma| <= 2**k
    |delta| <= 2**k

    The pairs are read off short vectors of the lattice spanned by
    (aprime, w, 0) and (nprime, 0, w) with w = 2**(l - k): a vector
    gamma*u + delta*v has coordinates
    (gamma*aprime + delta*nprime, gamma*w, delta*w).
    """
    weight = 2**(l - k)
    value_bound = 2**l
    coeff_bound = 2**k
    r, s = lagrange_reduction([aprime, weight, 0], [nprime, 0, weight])
    # Enumerate within radius 2**(l+1), then filter coordinate-wise.
    candidates = enumerate(r, s, 2**(l+1), value_bound)
    solutions = []
    for (combination, gamma_w, delta_w) in candidates:
        # Undo the weighting of the last two coordinates.
        gamma = gamma_w // weight
        delta = delta_w // weight
        if abs(combination) > value_bound:
            continue
        if abs(gamma) > coeff_bound or abs(delta) > coeff_bound:
            continue
        solutions.append([gamma, delta])
    return solutions

def break_protocol(aprime, nprime, l, k):
    """
    Given the server input (a', n'), produce candidate values for the
    secrets of the client (a, n).

    Every ordered pair of solutions of the linear program is used as
    the two rows of a 2x2 integer matrix.  When its determinant is +1
    or -1, the matrix is applied to (a', n'); the image [a, n] is kept
    if 0 < a < n.  Returns the list of all such [a, n].
    """
    omega = linear_program(aprime, nprime, l, k)
    recovered = []
    for first_row in omega:
        t, mz = first_row
        for second_row in omega:
            my, x = second_row
            # Only unimodular matrices can invert the masking transform.
            if abs(t * x - my * mz) == 1:
                a = t * aprime + mz * nprime
                n = my * aprime + x * nprime
                if 0 < a < n:
                    recovered.append([a, n])
    return recovered

################################################################

# Experiment parameters: l is the bit size used for the secret pair,
# k the bit size of the masking coefficients.
l = 2048
k = 128

# Draw a random instance and attack it using only the public pair.
pk, sk = keygen(l, k)
aprime, nprime = pk[0], pk[1]
candidates = break_protocol(aprime, nprime, l, k)
print("{} candidates found".format(len(candidates)))

# The first two entries of the private key are the secrets to recover.
a, n = sk[0], sk[1]
assert [a, n] in candidates
print("correct solution found")
