Sum-Check
I'm writing this post as an introduction to the Sum-Check protocol.
Motivation
Although there are many references on Sum-Check (see the original paper, among many others), I decided to write a hopefully simpler explanation for people getting started with Zero-Knowledge.
I would also like to highlight the excellent manuscript by Justin Thaler, which I used as a reference while learning, as well as the very helpful discussions with Diego Kingston.
Math background
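Consider the $v$-variate polynomial $g(x_1, \ldots, x_v)$, defined over a finite field $\mathbb{F}$, and let $\{0,1\}$ be the Boolean domain, containing the elements 0 and 1. The idea of the sum-check protocol is for the prover $P$ to convince the verifier $V$ that the sum of $g$ over all points of the Boolean hypercube $\{0,1\}^v$ is equal to $H$. In mathematical terms:

$$H = \sum_{b_1 \in \{0,1\}} \sum_{b_2 \in \{0,1\}} \cdots \sum_{b_v \in \{0,1\}} g(b_1, b_2, \ldots, b_v) \tag{1}$$

We can also represent the equation above more compactly as follows:

$$H = \sum_{b \in \{0,1\}^v} g(b)$$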
As discussed in the manuscript, there are several problems which can be represented as a sum just like equation 1, for example matrix multiplication (see Thaler13 or this nice video by Modular Labs on how they use sumcheck for proving matrix multiplication in the realm of Machine Learning verification).
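A great question at this point is: why should the verifier use sumcheck if it can simply evaluate the polynomial at all points of the Boolean hypercube? The reason is speed. Compare the alternatives below:
- Without sumcheck - evaluate $g$ at all $2^v$ points ($O(2^v)$ runtime complexity)
- With sumcheck (as depicted later) - evaluate the univariate polynomials sent by the prover at 0 and 1 in each of the $v$ rounds ($2v$ evaluations, plus one evaluation at each random challenge) and evaluate $g$ at a single point (roughly $O(v)$ runtime complexity, ignoring the degree of each univariate polynomial)
Hence, as $v$ grows, it's reasonable to use sumcheck.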
Explanation of protocol
I find it very useful to see an example of how the protocol works on a polynomial I can understand. So let's start with a simple example from the manuscript:

$$g(x_1, x_2, x_3) = 2x_1^3 + x_1x_3 + x_2x_3$$
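First, let's calculate $H$. We evaluate $g$ at all points of the Boolean hypercube $\{0,1\}^3$, i.e. we use all possible combinations of 0 and 1 for the variables $x_1$, $x_2$ and $x_3$:

$$H = \sum_{(b_1, b_2, b_3) \in \{0,1\}^3} g(b_1, b_2, b_3) = g(0,0,0) + g(0,0,1) + \cdots + g(1,1,1) = 0 + 0 + 0 + 1 + 2 + 3 + 2 + 4 = 12$$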
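As a quick sanity check, the brute-force computation of $H$ can be sketched in a few lines of Python. This is only an illustration: it uses ordinary Python integers instead of finite-field arithmetic, which is enough for this small example.

```python
from itertools import product

# Example polynomial from the text: g(x1, x2, x3) = 2*x1^3 + x1*x3 + x2*x3
def g(x1, x2, x3):
    return 2 * x1**3 + x1 * x3 + x2 * x3

# Naive computation of H: evaluate g at all 2^3 = 8 points of {0,1}^3
H = sum(g(*b) for b in product((0, 1), repeat=3))
print(H)  # 12
```

This is exactly the $O(2^v)$ work that the verifier wants to avoid.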
Now, we describe how the protocol takes place, so that the prover convinces the verifier that $H = 12$. The key to understanding sumcheck is to observe that it's a recursive protocol, i.e. there are multiple rounds (one per variable of the polynomial), and in each round the prover "transforms" the polynomial from $v$ variables (in our example, $v = 3$) into a univariate polynomial, which the verifier then evaluates over the Boolean domain (i.e. at the points 0 and 1).
Start of protocol
- $P$ sends the claimed value of $H$, in our case 12.
First round - $s_1(X_1)$
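The prover calculates a "transformed" version of the polynomial, the univariate polynomial $s_1(X_1)$, by leaving $x_1$ free and summing $g$ over all Boolean values of the remaining variables:

$$s_1(X_1) = \sum_{x_2 \in \{0,1\}} \sum_{x_3 \in \{0,1\}} g(X_1, x_2, x_3) = 8X_1^3 + 2X_1 + 1$$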
- Prover sends the polynomial $s_1(X_1)$ to the verifier
- Verifier checks that $s_1(0) + s_1(1) = H$, i.e. $1 + 11 = 12$
- Verifier draws a random element $r_1$ from $\mathbb{F}$ and sends it to P (let's assume $r_1 = 2$)
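The first round can be sketched in Python as follows. As an illustrative simplification (not how a real prover would transmit it), $s_1$ is represented as a Python function that performs the partial sum, and arithmetic uses ordinary integers rather than a finite field; $r_1 = 2$ is the value assumed in the text.

```python
from itertools import product

def g(x1, x2, x3):
    return 2 * x1**3 + x1 * x3 + x2 * x3

# Prover's round-1 message: s1(X1) = sum over x2, x3 in {0,1} of g(X1, x2, x3)
def s1(X1):
    return sum(g(X1, x2, x3) for x2, x3 in product((0, 1), repeat=2))

# s1 agrees with the closed form 8*X1^3 + 2*X1 + 1 at a few sample points
assert all(s1(x) == 8 * x**3 + 2 * x + 1 for x in range(5))

H = 12
# Verifier's round-1 check: s1(0) + s1(1) == H
assert s1(0) + s1(1) == H  # 1 + 11 == 12

r1 = 2  # the "random" challenge assumed in the text
print(s1(r1))  # 69, the value that s2(0) + s2(1) must match in round 2
```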
Second round - $s_2(X_2)$
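Again, the prover calculates a "transformed" version of the polynomial $g$, this time fixing $x_1$ to the random value $r_1 = 2$ given by the verifier. Note that again the polynomial is univariate:

$$s_2(X_2) = \sum_{x_3 \in \{0,1\}} g(r_1, X_2, x_3) = g(2, X_2, 0) + g(2, X_2, 1) = 34 + X_2$$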
- Prover sends the polynomial $s_2(X_2)$ to the verifier
- Verifier checks that $s_2(0) + s_2(1) = s_1(r_1)$, i.e. $34 + 35 = 69 = s_1(2)$
Here, observe that the verifier does not compute $s_2(0)$ and $s_2(1)$ by summing over $g$ itself. Instead, the verifier gets those values from the prover (through $s_2$) and simply verifies that the left side and right side match.
- Verifier draws a random element $r_2$ from $\mathbb{F}$ and sends it to P (let's assume $r_2 = 3$)
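The second round follows the same pattern. A minimal sketch, again with plain integers, each $s_i$ represented as a Python function (an illustrative choice), and the challenges $r_1 = 2$ and $r_2 = 3$ fixed as in the text:

```python
def g(x1, x2, x3):
    return 2 * x1**3 + x1 * x3 + x2 * x3

def s1(X1):
    return 8 * X1**3 + 2 * X1 + 1  # round-1 polynomial received earlier from the prover

r1 = 2  # challenge from round 1

# Prover's round-2 message: x1 is fixed to r1, only x3 is summed over {0,1}
def s2(X2):
    return sum(g(r1, X2, x3) for x3 in (0, 1))

# s2 agrees with the closed form 34 + X2
assert all(s2(x) == 34 + x for x in range(5))

# Verifier's round-2 check: s2(0) + s2(1) == s1(r1)
assert s2(0) + s2(1) == s1(r1)  # 34 + 35 == 69

r2 = 3  # challenge assumed in the text
print(s2(r2))  # 37
```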
Third round - $s_3(X_3)$
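Again, the prover calculates a "transformed" version of the polynomial $g$, this time using the random values $r_1 = 2$ and $r_2 = 3$ given by the verifier. Note that again the polynomial is univariate (and there is nothing left to sum over):

$$s_3(X_3) = g(r_1, r_2, X_3) = g(2, 3, X_3) = 16 + 5X_3$$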
- Prover sends the polynomial $s_3(X_3)$ to the verifier
- Verifier checks that $s_3(0) + s_3(1) = s_2(r_2)$, i.e. $16 + 21 = 37 = s_2(3)$
- Verifier draws a random element $r_3$ from $\mathbb{F}$ and sends it to P (let's assume $r_3 = 6$)
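A sketch of the third round under the same assumptions (plain integers, each $s_i$ as a Python function, $r_1 = 2$ and $r_2 = 3$ as in the text). Since only $x_3$ is left, the prover's polynomial is just $g$ with the first two variables fixed:

```python
def g(x1, x2, x3):
    return 2 * x1**3 + x1 * x3 + x2 * x3

def s2(X2):
    return 34 + X2  # round-2 polynomial received earlier from the prover

r1, r2 = 2, 3  # challenges from rounds 1 and 2

# Prover's round-3 message: x1 and x2 are fixed, so there is nothing left to sum
def s3(X3):
    return g(r1, r2, X3)

# s3 agrees with the closed form 16 + 5*X3
assert all(s3(x) == 16 + 5 * x for x in range(5))

# Verifier's round-3 check: s3(0) + s3(1) == s2(r2)
assert s3(0) + s3(1) == s2(r2)  # 16 + 21 == 37
```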
Final round
- Verifier has to check that $s_3(r_3) = g(r_1, r_2, r_3)$, i.e. $s_3(6) = 16 + 30 = 46 = g(2, 3, 6)$
This is the power of the sumcheck protocol - the verifier only needs to evaluate $g$ at the single point $(r_1, r_2, r_3)$ via one oracle query, which is much more efficient than calculating the sums over the Boolean hypercube on every round for $s_1$, $s_2$, and so forth.
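To make the efficiency argument concrete, the sketch below counts evaluations of $g$: the verifier's final check uses one, whereas the naive sum over $\{0,1\}^3$ uses eight. The global counter `calls` is hypothetical instrumentation added for illustration, $s_3$ is hard-coded to the closed form from the third round, and only the verifier's work is counted (the prover still performs the sums).

```python
from itertools import product

calls = 0  # hypothetical counter: how many times g has been evaluated

def g(x1, x2, x3):
    global calls
    calls += 1
    return 2 * x1**3 + x1 * x3 + x2 * x3

def s3(X3):
    return 16 + 5 * X3  # round-3 polynomial received from the prover

r1, r2, r3 = 2, 3, 6

# Verifier's final check: a single oracle query to g
assert s3(r3) == g(r1, r2, r3)  # 46 == 46
verifier_calls = calls

# For comparison: naive summation evaluates g at every point of {0,1}^3
calls = 0
naive_sum = sum(g(*b) for b in product((0, 1), repeat=3))
print(verifier_calls, calls)  # 1 8
```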
Code implementation
I wrote a simple Python implementation of the sumcheck protocol, using the sympy library to represent polynomials. The code is reproduced below.
```python
from sympy import symbols, Poly

# Prover, Verifier and Oracle are defined in the full implementation
# (not reproduced in this post)

# Initialization
x1, x2, x3 = symbols("x1 x2 x3")
poly = Poly(2*x1**3 + x1*x3 + x2*x3)
idx_to_vars = {0: x3, 1: x2, 2: x1}
p = Prover(poly, idx_to_vars)
v = Verifier()
num_of_rounds = len(poly.free_symbols)

# Prover (P) sends the value of the sum H to the verifier (V)
total_sum = p.calculate_sum(num_of_rounds)
random_values = {}
vars_to_iterate = num_of_rounds - 1
prev_s = total_sum
# Fixed "random" challenges r1, r2, r3 so that the run matches the example;
# a real verifier would draw them at random from the field
random_values_store = [2, 3, 6]
random_value = 0
for round_idx in range(num_of_rounds):
    # P computes the univariate polynomial s_i for this round
    s = p.calculate_sum(vars_to_iterate, random_values)
    # V checks that s_i(0) + s_i(1) = s_{i-1}(r_{i-1}) (= H = 12 in the first round)
    v.verify_univariate_poly(s, prev_s, random_value)
    random_value = random_values_store.pop(0)
    # We fetch the random_value idx in desc order (x1, then x2, then x3)
    x_variable = idx_to_vars[num_of_rounds - round_idx - 1]
    random_values[x_variable] = random_value
    prev_s = s
    vars_to_iterate -= 1

# Final round: the verifier executes a single oracle query
result_from_oracle = Oracle().evaluate_polynomial(poly, random_values)
# Verifier checks that s3(6) = g(r1, r2, r3)
v.assert_poly_matches_external_query(s, random_value, result_from_oracle)
```
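Since the Prover, Verifier and Oracle classes are not shown here, the following is a self-contained sketch in plain Python (no sympy) that runs the same protocol end to end. It makes assumptions that do not come from the implementation above: arithmetic is done modulo the prime $2^{31} - 1$, each univariate $s_i$ is sent as its values at $0, 1, \ldots, d$ (with $d = 3$, the highest power of any single variable in $g$), and the verifier computes $s_i(r_i)$ by Lagrange interpolation. The helper names `lagrange_eval`, `prover_message` and `sumcheck` are hypothetical.

```python
import random
from itertools import product

P = 2**31 - 1  # assumed prime modulus: all arithmetic happens in F_P


def g(x):
    """Example polynomial 2*x1^3 + x1*x3 + x2*x3 over F_P; x is a list [x1, x2, x3]."""
    x1, x2, x3 = x
    return (2 * x1**3 + x1 * x3 + x2 * x3) % P


def lagrange_eval(ys, r):
    """Evaluate at r the polynomial taking values ys at the points 0, 1, ..., len(ys) - 1."""
    total = 0
    for j, y in enumerate(ys):
        num, den = 1, 1
        for k in range(len(ys)):
            if k != j:
                num = num * (r - k) % P
                den = den * (j - k) % P
        # divide by den using Fermat's little theorem (P is prime)
        total = (total + y * num * pow(den, P - 2, P)) % P
    return total


def prover_message(poly, prefix, num_vars, degree):
    """Hypothetical prover step: s_i given by its values at 0..degree.

    The first len(prefix) variables are fixed to the verifier's challenges,
    the next one is set to t, and the remaining ones are summed over {0,1}.
    """
    tail = num_vars - len(prefix) - 1
    return [
        sum(poly(prefix + [t] + list(b)) for b in product((0, 1), repeat=tail)) % P
        for t in range(degree + 1)
    ]


def sumcheck(poly, num_vars, degree, claim, challenges=None):
    """Run an honest prover and the verifier together; return True if V accepts."""
    rs = []
    for i in range(num_vars):
        ys = prover_message(poly, rs, num_vars, degree)  # P sends s_i
        if (ys[0] + ys[1]) % P != claim % P:  # V checks s_i(0) + s_i(1) == claim
            return False
        r = challenges[i] if challenges is not None else random.randrange(P)
        claim = lagrange_eval(ys, r)  # the new claim is s_i(r_i)
        rs.append(r)
    return claim == poly(rs)  # final oracle query g(r_1, ..., r_v)


print(sumcheck(g, 3, 3, 12, challenges=[2, 3, 6]))  # True
print(sumcheck(g, 3, 3, 13))  # False: s_1(0) + s_1(1) = 12, not 13
```

With the challenges fixed to 2, 3 and 6, the intermediate claims are 69, 37 and 46, matching the walkthrough. Sending $d + 1$ evaluations per round instead of coefficients is just one common design choice for this sketch, not necessarily what the sympy-based version does.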
Conclusion
In this post, I tried to explain the sumcheck protocol with an example from Thaler's manuscript, so that one can more easily grasp the main concepts.
I also recommend this article from Lambdaclass, which discusses sumcheck in greater depth.