/*
 * Rijndael Reference ANSI C code
 * authors: Paulo Barreto
 *          Vincent Rijmen
 *
 * adapted from rijndael-alg-ref.c   v2.0   August '99 by Charles Bouillaguet on February 2011
 *
 * This implements a tweaked version of the AES-128 where the final MixColumn IS PRESENT.
 *
 * This slightly simplifies the fault attack (!), without making the cipher any less secure.
 *
 * This file is in the public domain.
 */

#include <assert.h>
#include <stdio.h>
#include <stdlib.h>
#include <time.h>

#include "../rijndael.h"

/* First index into the `shifts` table used by ShiftRow. The macro expands
 * using whatever `BC` is in scope at the point of use, and evaluates to 0 for
 * the 4-column blocks used throughout this file.
 * NOTE(review): `shifts` is declared outside this file -- confirm that BC - 4
 * is a valid index into it for block widths other than 4.
 */
#define SC	((BC - 4))

#include "../boxes-ref.h"

/* S-box difference tables, filled by init_diff_tables().
 * n_sols[din][dout] is the number of ordered pairs (x, y) of byte values with
 * x ^ y == din and S[x] ^ S[y] == dout. sols_x[din][dout][k] and
 * sols_y[din][dout][k] hold the x and the y of the k-th such pair.
 */
int n_sols[256][256];
word8 * sols_x[256][256];
word8 * sols_y[256][256];


static void KeyAddition(word8 a[4][MAXBC], word8 rk[4][MAXBC], word8 BC) {
	/* XOR the round key rk into the state a, in place,
	 * over the first BC columns of each of the four rows.
	 */
	int row, col;

	for (row = 0; row < 4; row++) {
		word8 *state_row = a[row];
		const word8 *key_row = rk[row];

		for (col = 0; col < BC; col++) {
			state_row[col] = state_row[col] ^ key_row[col];
		}
	}
}

static void ShiftRow(word8 a[4][MAXBC], word8 d, word8 BC) {
	/* Cyclically rotate rows 1, 2 and 3 of the state in place; row 0 is
	 * left untouched. The rotation amount of each row is read from
	 * shifts[SC][row][d].
	 */
	word8 rotated[MAXBC];
	int row, col;

	for (row = 1; row < 4; row++) {
		const int offset = shifts[SC][row][d];

		/* build the rotated row in a scratch buffer so that no source
		 * byte is overwritten before it has been read */
		col = 0;
		while (col < BC) {
			rotated[col] = a[row][(col + offset) % BC];
			col++;
		}

		col = 0;
		while (col < BC) {
			a[row][col] = rotated[col];
			col++;
		}
	}
}

static void Substitution(word8 a[4][MAXBC], const word8 box[256], word8 BC) {
	/* Pass every byte in the first BC columns of the state through the
	 * lookup table box, in place.
	 */
	int row, col;

	for (row = 0; row < 4; row++) {
		word8 *cell = a[row];

		for (col = 0; col < BC; col++, cell++) {
			*cell = box[*cell];
		}
	}
}
   
static void MixColumn(word8 a[4][MAXBC], word8 BC) {
	/* Linearly mix the four bytes of each of the first BC columns, in
	 * place: output byte i is
	 * 2*in[i] ^ 3*in[i+1] ^ in[i+2] ^ in[i+3] (indices mod 4),
	 * with the products computed by mul().
	 */
	int row, col;

	for (col = 0; col < BC; col++) {
		word8 in[4];

		/* snapshot the column: every output byte needs all four inputs */
		for (row = 0; row < 4; row++) {
			in[row] = a[row][col];
		}

		for (row = 0; row < 4; row++) {
			a[row][col] = mul(2, in[row])
				^ mul(3, in[(row + 1) % 4])
				^ in[(row + 2) % 4]
				^ in[(row + 3) % 4];
		}
	}
}

static void InvMixColumn(word8 a[4][MAXBC], word8 BC) {
	/* Opposite operation of MixColumn, applied in place to each of the
	 * first BC columns: output byte i is
	 * 0xe*in[i] ^ 0xb*in[i+1] ^ 0xd*in[i+2] ^ 0x9*in[i+3] (indices mod 4),
	 * with the products computed by mul().
	 */
	int row, col;

	for (col = 0; col < BC; col++) {
		word8 in[4];

		/* snapshot the column: every output byte needs all four inputs */
		for (row = 0; row < 4; row++) {
			in[row] = a[row][col];
		}

		for (row = 0; row < 4; row++) {
			a[row][col] = mul(0xe, in[row])
				^ mul(0xb, in[(row + 1) % 4])
				^ mul(0xd, in[(row + 2) % 4])
				^ mul(0x9, in[(row + 3) % 4]);
		}
	}
}

int rijndaelKeySched (word8 k[4][MAXKC], int keyBits, int blockBits, word8 W[MAXROUNDS+1][4][MAXBC]) {
	/* Calculate the round keys from the cipher key.
	 *
	 * k          cipher key, one column per 32 bits (keyBits/32 columns are read)
	 * keyBits    key length in bits
	 * blockBits  block length in bits
	 * W          output; W[r] receives the round key of round r for
	 *            r = 0 .. ROUNDS, where ROUNDS = 6 + max(keyBits, blockBits)/32
	 *
	 * Returns 0 on success, -1 if keyBits is not supported and -2 if
	 * blockBits is not supported. A size is supported when it is a multiple
	 * of 32, at least 128, and fits the MAXKC / MAXBC columns of the
	 * buffers. Nothing is written to W on failure.
	 */
	int KC, BC, ROUNDS, total;
	int i, j, t, rconpointer = 0;
	word8 tk[4][MAXKC];

	/* Reject sizes that would overflow tk / W, or (for a zero-width key)
	 * never advance t in the expansion loop below. */
	if (keyBits % 32 != 0 || keyBits < 128 || keyBits/32 > MAXKC)
		return -1;
	if (blockBits % 32 != 0 || blockBits < 128 || blockBits/32 > MAXBC)
		return -2;

	BC = blockBits/32;
	KC = keyBits/32;
	ROUNDS = 6 + (BC > KC ? BC : KC);
	total = (ROUNDS+1)*BC;   /* number of round-key columns to produce */

	for(j = 0; j < KC; j++)
		for(i = 0; i < 4; i++)
			tk[i][j] = k[i][j];
	t = 0;
	/* copy values into round key array */
	for(j = 0; (j < KC) && (t < total); j++, t++)
		for(i = 0; i < 4; i++) W[t / BC][i][t % BC] = tk[i][j];

	while (t < total) { /* while not enough round key material calculated */
		/* calculate new values: the first column absorbs the S-boxed,
		 * rotated last column and the round constant */
		for(i = 0; i < 4; i++)
			tk[i][0] ^= S[tk[(i+1)%4][KC-1]];
		tk[0][0] ^= rcon[rconpointer++];

		if (KC <= 6) {
			for(j = 1; j < KC; j++)
				for(i = 0; i < 4; i++) tk[i][j] ^= tk[i][j-1];
		} else {
			/* long keys: column 4 gets an extra S-box layer */
			for(j = 1; j < 4; j++)
				for(i = 0; i < 4; i++) tk[i][j] ^= tk[i][j-1];
			for(i = 0; i < 4; i++) tk[i][4] ^= S[tk[i][3]];
			for(j = 5; j < KC; j++)
				for(i = 0; i < 4; i++) tk[i][j] ^= tk[i][j-1];
		}

		/* copy values into round key array */
		for(j = 0; (j < KC) && (t < total); j++, t++)
			for(i = 0; i < 4; i++) W[t / BC][i][t % BC] = tk[i][j];
	}

	return 0;
}
      
int rijndaelEncrypt (word8 a[4][MAXBC], int keyBits, int blockBits, word8 rk[MAXROUNDS+1][4][MAXBC])
{
	/* Encrypt one block in place with the round keys rk.
	 * Every one of the ROUNDS rounds is a full round, i.e. the final
	 * MixColumn is present. Always returns 0.
	 */
	const int widest = (blockBits > keyBits) ? blockBits : keyBits;
	const int BC = blockBits/32;
	const int ROUNDS = 6 + widest/32;
	int r = 0;

	/* initial key addition, with round key 0 */
	KeyAddition(a, rk[r], BC);

	/* rounds 1 .. ROUNDS, all identical */
	r = 1;
	while (r <= ROUNDS) {
		Substitution(a, S, BC);
		ShiftRow(a, 0, BC);
		MixColumn(a, BC);
		KeyAddition(a, rk[r], BC);
		r++;
	}

	return 0;
}


void print_state(word8 a[4][MAXBC], int BC) {
  /* Print the first BC columns of the state to stdout, one row per line,
   * each byte as two hex digits followed by a space.
   */
  int row = 0;

  while (row < 4) {
    int col;

    for (col = 0; col < BC; ++col) {
      printf("%02x ", a[row][col]);
    }
    printf("\n");
    ++row;
  }
}

void init_diff_tables(void) {
  /* Build the S-box difference tables n_sols, sols_x and sols_y.
   *
   * For every ordered pair of bytes (x, y), the pair is filed under its
   * input difference x ^ y and its output difference S[x] ^ S[y]. On return
   * n_sols[din][dout] is the number of pairs filed there, and
   * sols_x[din][dout][k] / sols_y[din][dout][k] are the x and y of the k-th
   * one, in increasing (x, y) order. Cells with no pair get NULL pointers.
   *
   * Prints a message to stderr and exits with status 1 if memory runs out.
   * The tables are never freed; calling this twice leaks the first set.
   */
  int x, y;
  int din, dout;

  // count solutions
  for(x=0; x<0x100; x++)
    for(y=0; y<0x100; y++)
      n_sols[x][y] = 0;

  for(x=0; x<0x100; x++)
    for(y=0; y<0x100; y++)
      n_sols[x^y][ S[x] ^ S[y] ]++;

  // allocate space for solutions, then reset the counters so that they can
  // serve as write cursors in the fill pass below
  for(din=0; din<0x100; din++)
    for(dout=0; dout<0x100; dout++) {
      const size_t count = (size_t) n_sols[din][dout];

      if (count == 0) {
        // nothing to store: skip the zero-byte malloc, whose result is
        // implementation-defined anyway
        sols_x[din][dout] = NULL;
        sols_y[din][dout] = NULL;
        continue;
      }

      sols_x[din][dout] = malloc(count * sizeof *sols_x[din][dout]);
      sols_y[din][dout] = malloc(count * sizeof *sols_y[din][dout]);
      if (sols_x[din][dout] == NULL || sols_y[din][dout] == NULL) {
        fprintf(stderr, "init_diff_tables: out of memory\n");
        exit(1);
      }
      n_sols[din][dout] = 0;
    }

  // fill solution tables
  for(x=0; x<0x100; x++)
    for(y=0; y<0x100; y++) {
      const int delta_i = x^y;
      const int delta_o = S[x] ^ S[y];
      const int slot = n_sols[ delta_i ][ delta_o ];

      sols_x[ delta_i ][ delta_o ][ slot ] = x;
      sols_y[ delta_i ][ delta_o ][ slot ] = y;
      n_sols[ delta_i ][ delta_o ] = slot + 1;
    }
}


void PiretQuisquater (word8 P[4][MAXBC],   word8 X_3[4][MAXBC],   word8 Xp_3[4][MAXBC])
{
  /* Key recovery from one correct / faulty ciphertext pair.
   *
   * P     known plaintext
   * X_3   ciphertext obtained without fault
   * Xp_3  ciphertext obtained with the fault
   *
   * The function enumerates candidate values for the intermediate states
   * (steps 1 to 10 below), filters them, and for every surviving candidate
   * derives a last round key, runs the key schedule backwards to a master
   * key, and checks that key by encrypting P and comparing with X_3.
   *
   * On success the master key is printed and the PROCESS EXITS with status
   * 0. If some byte of the output difference is zero the process exits with
   * status 1. The function only returns to its caller when the enumeration
   * finishes without any candidate passing the check. Progress is printed
   * to stdout. init_diff_tables() must have been called beforehand, since
   * n_sols and sols_x are read.
   */
  const int BC = 4;
  int i, j;

  word8 X_1[4][MAXBC];
  word8 X_2[4][MAXBC];

  word8 Xp_1[4][MAXBC];

  // first compute the difference just after the last SubBytes
  word8 delta_op[4][MAXBC];
  for(i=0; i<4; i++)
    for(j=0; j<4; j++)
      delta_op[i][j] = X_3[i][j] ^ Xp_3[i][j];

  InvMixColumn(delta_op, 4);
  ShiftRow(delta_op,1,4);

  // the difference tables hold no pair for a zero output difference paired
  // with a non-zero input difference, so this case cannot be handled here
  for(i=0; i<4; i++)
    for(j=0; j<4; j++)
      if (delta_op[i][j] == 0) {
        printf("Aouch ! zero difference in an output byte. You need an ad hoc attack in this case. I can't work....\n");
        exit(1);
      }

  int ctr0, ctr1, ctr2, ctr3, ctr4;

  word8 Delta_X_1[4];
  word8 Delta_X_2[4][MAXBC];
  int n_candidates = 0;


  // STEP 1 : guess difference in Y_0[0,0]
  for(ctr0 = 0; ctr0<0x100; ctr0++) {
    word8 delta = ctr0;

    // STEP 1 bis : compute difference in X_1[*,0]
    Delta_X_1[0] = mul(2, delta);
    Delta_X_1[1] = delta;
    Delta_X_1[2] = delta;
    Delta_X_1[3] = mul(3, delta);

    printf("\r%.1f %%", delta/256.*100);
    fflush(stdout);

    // STEP 2 : guess actual value of X_1[0,0] and X_1[1,0]
    for(ctr1 = 0; ctr1<0x100; ctr1++)
    for(ctr2 = 0; ctr2<0x100; ctr2++) {

      X_1[0][0] = ctr1;
      X_1[1][0] = ctr2;

      Xp_1[0][0] = X_1[0][0] ^ Delta_X_1[0];  // Get X'_1[0,0]
      Xp_1[1][0] = X_1[1][0] ^ Delta_X_1[1];  // Get X'_1[1,0]

      // these are the differences after SubBytes in X_1[0,0] and X_1[1,0]
      word8 foo = S[ X_1[0][0] ] ^ S[ Xp_1[0][0] ];
      word8 bar = S[ X_1[1][0] ] ^ S[ Xp_1[1][0] ];

      // STEP 3 : compute differences and actual values in X_2[*, 0]
      Delta_X_2[0][0] = mul(2, foo);
      Delta_X_2[1][0] = foo;
      Delta_X_2[2][0] = foo;
      Delta_X_2[3][0] = mul(3, foo);

      // difference in X_2[*, 3]
      Delta_X_2[0][3] = mul(3, bar);
      Delta_X_2[1][3] = mul(2, bar);
      Delta_X_2[2][3] = bar;
      Delta_X_2[3][3] = bar;

      // now recovers the actual values in X_2[*,0] and X_2[*,3]
      int a,b,c,d,e,f,g,h;
      for(a=0; a<n_sols[ Delta_X_2[0][0] ][ delta_op[0][0] ]; a++)
      for(b=0; b<n_sols[ Delta_X_2[1][0] ][ delta_op[1][0] ]; b++)
      for(c=0; c<n_sols[ Delta_X_2[2][0] ][ delta_op[2][0] ]; c++)
      for(d=0; d<n_sols[ Delta_X_2[3][0] ][ delta_op[3][0] ]; d++)

      for(e=0; e<n_sols[ Delta_X_2[0][3] ][ delta_op[0][3] ]; e++)
      for(f=0; f<n_sols[ Delta_X_2[1][3] ][ delta_op[1][3] ]; f++)
      for(g=0; g<n_sols[ Delta_X_2[2][3] ][ delta_op[2][3] ]; g++)
      for(h=0; h<n_sols[ Delta_X_2[3][3] ][ delta_op[3][3] ]; h++) {

        // read actual values from diff. table
        X_2[0][0] = sols_x[ Delta_X_2[0][0] ][ delta_op[0][0] ][a];
        X_2[1][0] = sols_x[ Delta_X_2[1][0] ][ delta_op[1][0] ][b];
        X_2[2][0] = sols_x[ Delta_X_2[2][0] ][ delta_op[2][0] ][c];
        X_2[3][0] = sols_x[ Delta_X_2[3][0] ][ delta_op[3][0] ][d];

        X_2[0][3] = sols_x[ Delta_X_2[0][3] ][ delta_op[0][3] ][e];
        X_2[1][3] = sols_x[ Delta_X_2[1][3] ][ delta_op[1][3] ][f];
        X_2[2][3] = sols_x[ Delta_X_2[2][3] ][ delta_op[2][3] ][g];
        X_2[3][3] = sols_x[ Delta_X_2[3][3] ][ delta_op[3][3] ][h];

        // STEP 4 : filter guesses thanks to lemma X, point i)

        // first, compute stuff = X_3[*,2] + X_3[*,3] + X_2[*,3]
        word8 stuff[4];
        stuff[0] = X_3[0][2] ^ X_3[0][3] ^ X_2[0][3];
        stuff[1] = X_3[1][2] ^ X_3[1][3] ^ X_2[1][3];
        stuff[2] = X_3[2][2] ^ X_3[2][3] ^ X_2[2][3];
        stuff[3] = X_3[3][2] ^ X_3[3][3] ^ X_2[3][3];

        // now computes InverseMixColumn(stuff)[1]
        word8 inv_mc_stuff_1 = mul(0x09, stuff[0]) ^ mul(0x0e, stuff[1]) ^ mul(0x0b, stuff[2]) ^ mul(0x0d, stuff[3]);

        // recomputes X_1[1][0]....
        word8 x_1_1_check = Si[ inv_mc_stuff_1 ^ S[ X_2[1][3] ] ^ S[ X_2[1][0] ] ];

        // and check it against the guess
        if (X_1[1][0] != x_1_1_check) continue;   // if not equal, try next possibility

        // STEP 5 : guess actual value of X_1[2,0]
        for(ctr3 = 0; ctr3<0x100; ctr3++) {
          X_1[2][0] = ctr3;

          Xp_1[2][0] = X_1[2][0] ^ Delta_X_1[2];  // Get X'_1[2,0]

          // this is the difference after SubBytes in X_1[2,0]
          word8 foobar = S[ X_1[2][0] ] ^ S[ Xp_1[2][0] ];

          // STEP 6 : compute differences and actual values in X_2[*, 2]
          Delta_X_2[0][2] = foobar;
          Delta_X_2[1][2] = mul(3, foobar);
          Delta_X_2[2][2] = mul(2, foobar);
          Delta_X_2[3][2] = foobar;


          // now recovers the actual values in X_2[*,2]
          int i2,j2,k2,l2;
          for(i2=0; i2<n_sols[ Delta_X_2[0][2] ][ delta_op[0][2] ]; i2++)
          for(j2=0; j2<n_sols[ Delta_X_2[1][2] ][ delta_op[1][2] ]; j2++)
          for(k2=0; k2<n_sols[ Delta_X_2[2][2] ][ delta_op[2][2] ]; k2++)
          for(l2=0; l2<n_sols[ Delta_X_2[3][2] ][ delta_op[3][2] ]; l2++) {

            // read actual values from diff. table
            X_2[0][2] = sols_x[ Delta_X_2[0][2] ][ delta_op[0][2] ][i2];
            X_2[1][2] = sols_x[ Delta_X_2[1][2] ][ delta_op[1][2] ][j2];
            X_2[2][2] = sols_x[ Delta_X_2[2][2] ][ delta_op[2][2] ][k2];
            X_2[3][2] = sols_x[ Delta_X_2[3][2] ][ delta_op[3][2] ][l2];

            // STEP 7 : filter guesses thanks to lemma X, point ii)

            // first, compute stuff = X_3[*,1] + X_3[*,2] + X_2[*,2]
            // (the step-4 buffer is reused; its old contents are no longer needed)
            stuff[0] = X_3[0][1] ^ X_3[0][2] ^ X_2[0][2];
            stuff[1] = X_3[1][1] ^ X_3[1][2] ^ X_2[1][2];
            stuff[2] = X_3[2][1] ^ X_3[2][2] ^ X_2[2][2];
            stuff[3] = X_3[3][1] ^ X_3[3][2] ^ X_2[3][2];

            // now computes InverseMixColumn(stuff)[2]
            word8 inv_mc_stuff2_2 = mul(0x0d, stuff[0]) ^ mul(0x09, stuff[1]) ^ mul(0x0e, stuff[2]) ^ mul(0x0b, stuff[3]);

            // recomputes X_1[2][0]....
            word8 x_1_2_check = Si[ inv_mc_stuff2_2 ^ S[ X_2[2][3] ] ^ S[ X_2[2][0] ] ];


            // and check it against the guess
            if (X_1[2][0] != x_1_2_check) continue;   // if not equal, try next possibility

            // STEP 8 : guess actual value of X_1[3,0]
            for(ctr4 = 0; ctr4<0x100; ctr4++) {
              X_1[3][0] = ctr4;

              Xp_1[3][0] = X_1[3][0] ^ Delta_X_1[3];  // Get X'_1[3,0]

              // this is the difference after SubBytes in X_1[3,0]
              word8 barfoo = S[ X_1[3][0] ] ^ S[ Xp_1[3][0] ];

              // STEP 9 : compute differences and actual values in X_2[*, 1]
              Delta_X_2[0][1] = barfoo;
              Delta_X_2[1][1] = barfoo;
              Delta_X_2[2][1] = mul(3, barfoo);
              Delta_X_2[3][1] = mul(2, barfoo);


              // now recovers the actual values in X_2[*,1]
              int m,n,o,p;
              for(m=0; m<n_sols[ Delta_X_2[0][1] ][ delta_op[0][1] ]; m++)
              for(n=0; n<n_sols[ Delta_X_2[1][1] ][ delta_op[1][1] ]; n++)
              for(o=0; o<n_sols[ Delta_X_2[2][1] ][ delta_op[2][1] ]; o++)
              for(p=0; p<n_sols[ Delta_X_2[3][1] ][ delta_op[3][1] ]; p++) {

                // read actual values from diff. tables
                X_2[0][1] = sols_x[ Delta_X_2[0][1] ][ delta_op[0][1] ][m];
                X_2[1][1] = sols_x[ Delta_X_2[1][1] ][ delta_op[1][1] ][n];
                X_2[2][1] = sols_x[ Delta_X_2[2][1] ][ delta_op[2][1] ][o];
                X_2[3][1] = sols_x[ Delta_X_2[3][1] ][ delta_op[3][1] ][p];

                // STEP 10 : filter guesses thanks to lemma X, point iii)

                // first, compute stuff = X_3[*,0] + X_3[*,1] + X_2[*,1]
                word8 stuff3[4];
                stuff3[0] = X_3[0][0] ^ X_3[0][1] ^ X_2[0][1];
                stuff3[1] = X_3[1][0] ^ X_3[1][1] ^ X_2[1][1];
                stuff3[2] = X_3[2][0] ^ X_3[2][1] ^ X_2[2][1];
                stuff3[3] = X_3[3][0] ^ X_3[3][1] ^ X_2[3][1];

                // now computes InverseMixColumn(stuff3)[3]
                word8 inv_mc_stuff3_3 = mul(0x0b, stuff3[0]) ^ mul(0x0d, stuff3[1]) ^ mul(0x09, stuff3[2]) ^ mul(0x0e, stuff3[3]);

                // recomputes X_1[3][0]....
                word8 x_1_3_check = Si[ inv_mc_stuff3_3 ^ S[ X_2[3][3] ] ^ S[ X_2[3][0] ] ];

                // and check it against the guess
                if (X_1[3][0] != x_1_3_check) continue;   // if not equal, try next possibility

                // STEP 11 : at this stage, X_2 is entirely known.

                // computes W_2 = MixColumn(ShiftRow(SubBytes(X_2))).
                // Declared with MAXBC columns because that is the row width
                // Substitution / ShiftRow / MixColumn index with.
                word8 W_2[4][MAXBC];
                int q,r;
                for(q=0; q<4; q++)
                  for(r=0; r<4; r++)
                    W_2[q][r] = X_2[q][r];

                Substitution(W_2,S,BC);
                ShiftRow(W_2,0,BC);
                MixColumn(W_2,BC);

                // and K_3, the last round key: what must be XORed onto W_2 to get X_3
                word8 K[4][MAXBC];
                for(q=0; q<4; q++)
                  for(r=0; r<4; r++)
                    K[q][r] = W_2[q][r] ^ X_3[q][r];

                // invert the key_schedule for 10 rounds.
                word8 K_[4][MAXBC];
                int s;

                for(s=9; s>=0; s--) {

                  // columns 1..3 of the previous round key
                  for(q=0; q<4; q++)
                    for(r=1; r<4; r++)
                      K_[q][r] = K[q][r] ^ K[q][r-1];

                  // column 0 needs the previous last column, just recovered
                  K_[0][0] = K[0][0] ^ S[ K_[1][3] ] ^ rcon[s];
                  K_[1][0] = K[1][0] ^ S[ K_[2][3] ];
                  K_[2][0] = K[2][0] ^ S[ K_[3][3] ];
                  K_[3][0] = K[3][0] ^ S[ K_[0][3] ];

                  for(q=0; q<4; q++)
                    for(r=0; r<4; r++)
                      K[q][r] = K_[q][r];
                }

                // Check the potential master key

                // generate the subkeys (OK, this is kind of stupid, considering what we have done before....)
                word8 RK[20][4][MAXBC];
                rijndaelKeySched (K, 128, 128, RK);
                word8 Test[4][MAXBC];
                for(q=0; q<4; q++)
                  for(r=0; r<4; r++)
                    Test[q][r] = P[q][r];

                rijndaelEncrypt (Test, 128, 128, RK);
                n_candidates++;


                int good = 1;
                for(q=0; q<4; q++)
                  for(r=0; r<4; r++)
                    good &= (Test[q][r] == X_3[q][r]);

                if (good) {
                  printf("\nMaster Key found by the attack : \n");
                  print_state(K, 4);

                  printf("\n\n%d candidate keys have been tested (expected number = 2^15)\n", n_candidates);

                  exit(0);
                }
              }
            }
          }
        }
      }
    }
  }
}




int main(void) {
  /* Demo driver: picks a random key and plaintext, encrypts the plaintext
   * twice with the tweaked cipher (once cleanly, once with a one-byte fault
   * injected after round 7), then hands the plaintext and both ciphertexts
   * to PiretQuisquater() to recover the key.
   *
   * PiretQuisquater() terminates the process itself (status 0) when it finds
   * the key, so control only comes back here when the attack failed; that
   * case is reported and main returns 1. The process also exits with status
   * 1 when the randomly drawn fault happens to leave the byte unchanged.
   */
  int i,j,r;

  init_diff_tables();

  printf("\nDemonstrates the improved state recovery in the fault attack against the AES.\n");
  printf("(well, a tweaked version of the AES where the final MixColumn IS PRESENT (!).\n");
  printf("The attack is nevertheless adaptable to the normal case, but more cumbersome).\n\n");

  printf("C. Bouillaguet, P.Derbez, P.-A. Fouque, February 2011.\n");
  printf("This program is in the public domain.\n\n");

  srand((unsigned) time(NULL));

  // generate the key, the plaintext, the clean and faulty ciphertexts.

  word8 P[4][MAXBC];

  word8 C_clean[4][MAXBC];
  word8 C_faulty[4][MAXBC];

  word8 K[4][MAXBC];
  word8 RK[20][4][MAXBC];

  // temporary states
  word8 state1[4][MAXBC];
  word8 state2[4][MAXBC];

  // pick a random "master key"
  for(i=0; i<4; i++)
    for(j=0; j<4; j++)
      K[i][j] = rand() & 0xff;

  printf("Chosen Master key (the secret to recover) : \n");
  print_state(K,4);

  // build the subkeys
  rijndaelKeySched (K, 128, 128, RK);

  // pick a plaintext
  for(i=0; i<4; i++)
    for(j=0; j<4; j++)
      state1[i][j] = P[i][j] = rand() & 0xff;

  printf("Known (=randomly generated) plaintext (what we encipher with the key) : \n");
  print_state(P,4);

  // initial key addition
  KeyAddition(state1,RK[0],4);

  // encrypt 7 rounds
  for(r = 1; r < 8; r++) {
    Substitution(state1,S,4);
    ShiftRow(state1,0,4);
    MixColumn(state1,4);
    KeyAddition(state1,RK[r],4);
  }

  // initialize the second state
  for(i=0; i<4; i++)
    for(j=0; j<4; j++)
      state2[i][j] = state1[i][j];

  // introduce a one-byte difference in the state
  state2[0][0] = rand() & 0xff;
  word8 delta_i = S[state1[0][0]] ^ S[state2[0][0]];

  // a zero S-box output difference means the two states are still equal
  if (delta_i == 0) {
    printf("Zero fault introduced ! re-run the attack.\n");
    exit(1);
  }

  // encrypt the 3 last rounds, on both states in lockstep
  for(r = 8; r <= 10; r++) {
    Substitution(state1,S,4);          Substitution(state2,S,4);
    ShiftRow(state1,0,4);              ShiftRow(state2,0,4);
    MixColumn(state1,4);               MixColumn(state2,4);
    KeyAddition(state1,RK[r],4);       KeyAddition(state2,RK[r],4);
  }

  for(i=0; i<4; i++)
    for(j=0; j<4; j++) {
      C_clean[i][j] = state1[i][j];    C_faulty[i][j] = state2[i][j];
    }

  printf("clean ciphertext : \n");
  print_state(C_clean,4);

  printf("faulty ciphertext : \n");
  print_state(C_faulty,4);
  printf("\n\n");

  // only give the required information to the attack
  PiretQuisquater(P, C_clean, C_faulty);

  // only reached when no candidate key passed the check
  fprintf(stderr, "\nThe attack finished without recovering the key.\n");
  return 1;
}
