/* big.c -- the "small" big number routines
    by Spencer Putt
    
They're very specific to this signer, but they might make a good reference */


#include <stdlib.h>
#include <stdio.h>
#include <string.h>
#include "big.h"

big big_create(void) {
    big op;
    op = (big) malloc(sizeof(bignum));
    op->numlength = -1;
    op->sign = 0;
    memset(op->number,0,kMaxLength);
    return op;
}

/* Release a number obtained from big_create().
 * free(NULL) is a no-op, so passing NULL is harmless. */
void big_clear(big a)
{
    free(a);
}

/* Load result's magnitude from a raw byte buffer, least significant byte
 * first (buffer[0] becomes number[0]).  All other bytes are cleared.
 * At most kMaxLength bytes are copied; a longer buffer is truncated so the
 * fixed-size number[] array is never overrun.  length <= 0 yields zero.
 * The sign field is left unchanged. */
void big_read(big result, const unsigned char* buffer, int length) {
    memset(result->number, 0, kMaxLength);
    if (length <= 0)
        return;
    if (length > kMaxLength)
        length = kMaxLength;    /* bound the write to the destination buffer */
    memcpy(result->number, buffer, (size_t) length);
}

/* result = a + b for an unsigned machine-word addend b.
 * NOTE(review): b's bytes are copied in host order, which matches the
 * least-significant-byte-first layout of number[] only on little-endian hosts. */
void big_add_ui(big result, big a, unsigned int b) {
    big addend = big_create();

    big_read(addend, (const unsigned char*) &b, sizeof b);
    big_add(result, a, addend);
    big_clear(addend);
}

/* result = a + b (signed).
 * Mixed signs are rewritten as a subtraction of magnitudes through big_sub.
 * Same signs add the magnitudes word by word (32-bit words, least significant
 * first) and give the result the common sign.  A carry out of the top word is
 * discarded.  result may be the same object as a or b, because every word is
 * read before it is written. */
void big_add(big result, big a, big b) {
    unsigned int ta, tb, tr, current;
    unsigned int carry = 0;
    int sign;
    big temp;

    if (a->sign && !b->sign) {
        /* (-|a|) + b == b - |a| */
        temp = big_create();
        big_set(temp, a);
        temp->sign = 0;
        big_sub(result, b, temp);
        big_clear(temp);
        return;
    }
    if (b->sign && !a->sign) {
        /* a + (-|b|) == a - |b| */
        temp = big_create();
        big_set(temp, b);
        temp->sign = 0;
        big_sub(result, a, temp);
        big_clear(temp);
        return;
    }

    /* Same sign: add magnitudes, keep the shared sign. */
    sign = a->sign;
    for (current = 0; current < kMaxLength/4; current++) {
        ta = ((unsigned int*)a->number)[current];
        tb = ((unsigned int*)b->number)[current];

        tr = ta + tb + carry;
        /* With a carry-in, tr == ta also means the sum wrapped
           (tb == 0xFFFFFFFF); without one, only tr < ta does. */
        carry = carry ? (tr <= ta) : (tr < ta);

        ((unsigned int*)result->number)[current] = tr;
    }
    result->sign = sign;
}

/* result = a - b for an unsigned machine-word subtrahend b.
 * NOTE(review): b's bytes are copied in host order, which matches the
 * least-significant-byte-first layout of number[] only on little-endian hosts. */
void big_sub_ui(big result, big a, unsigned int b) {
    big subtrahend = big_create();

    big_read(subtrahend, (const unsigned char*) &b, sizeof b);
    big_sub(result, a, subtrahend);
    big_clear(subtrahend);
}

/* Three-way comparison of |a| and |b| (signs ignored): returns 1, 0 or -1. */
static int big_sub_cmp_magnitude(big a, big b) {
    int i;

    for (i = (kMaxLength/4) - 1; i >= 0; i--) {
        unsigned int wa = ((unsigned int*)a->number)[i];
        unsigned int wb = ((unsigned int*)b->number)[i];
        if (wa != wb)
            return (wa > wb) ? 1 : -1;
    }
    return 0;
}

/* result->number = |a| - |b| across the whole buffer; requires |a| >= |b|.
 * Every word is read before it is written, so result may alias a or b.
 * Covering every word (not just up to a's top word) clears any stale high
 * words in result.  Signs are not touched. */
static void big_sub_magnitude(big result, big a, big b) {
    unsigned int i, borrow = 0;

    for (i = 0; i < kMaxLength/4; i++) {
        unsigned int ta = ((unsigned int*)a->number)[i];
        unsigned int tb = ((unsigned int*)b->number)[i];

        ((unsigned int*)result->number)[i] = ta - tb - borrow;
        /* Borrow out iff ta < tb + borrow_in, written so it cannot overflow
           (tb == 0xFFFFFFFF with a borrow-in must still borrow). */
        borrow = (ta < tb) || (borrow && ta == tb);
    }
}

/* result = a - b (signed).  result may be the same object as a or b.
 * A zero difference always gets sign 0. */
void big_sub(big result, big a, big b) {
    int a_sign = a->sign;
    int cmp;
    big temp;

    if (!a->sign != !b->sign) {
        /* Signs differ: a - b == a + (-b), and -b has a's sign, so this is a
           same-sign addition whose result carries a's sign. */
        temp = big_create();
        big_set(temp, b);
        temp->sign = a->sign;
        big_add(result, a, temp);
        big_clear(temp);
        result->sign = a_sign;
        return;
    }

    /* Same sign: subtract the smaller magnitude from the larger one. */
    cmp = big_sub_cmp_magnitude(a, b);
    if (cmp >= 0) {
        big_sub_magnitude(result, a, b);          /* |a| - |b| */
        result->sign = (cmp == 0) ? 0 : a_sign;   /* never produce -0 */
    } else {
        big_sub_magnitude(result, b, a);          /* |b| - |a| */
        result->sign = !a_sign;
    }
}

/* Three-way comparison of a against an unsigned machine word b, with the
 * same return convention as big_cmp (1, 0, -1). */
int big_cmp_ui(big a, unsigned int b) {
    big rhs = big_create();
    int order;

    big_set_ui(rhs, b);
    order = big_cmp(a, rhs);
    big_clear(rhs);

    return order;
}

/* Three-way signed comparison: returns 1 if a > b, -1 if a < b, 0 if equal.
 * Any nonzero sign field counts as negative.
 * NOTE: a negative zero (sign set, magnitude 0) still compares below +0. */
int big_cmp(big a, big b) {
    int i, order = 0;
    unsigned int wa, wb;

    if (!a->sign && b->sign) return 1;     /* a positive, b negative */
    if (a->sign && !b->sign) return -1;    /* a negative, b positive */

    /* Same sign: compare magnitudes from the most significant word down. */
    for (i = (kMaxLength/4) - 1; i >= 0; i--) {
        wa = ((unsigned int*)a->number)[i];
        wb = ((unsigned int*)b->number)[i];
        if (wa != wb) {
            order = (wa > wb) ? 1 : -1;
            break;
        }
    }

    /* Both negative: the larger magnitude is the smaller value. */
    return a->sign ? -order : order;
}

/* Print a to stdout as uppercase hex, most significant byte first, then a
 * newline.  The digits are preceded by '-' when the sign is set, or by a
 * space otherwise.  Zero prints as "0" with no sign column. */
void big_print(big a) {
    int top = kMaxLength - 1;

    /* skip leading zero bytes */
    while (top >= 0 && a->number[top] == 0)
        top--;

    if (top < 0) {
        puts("0");
        return;
    }

    putchar(a->sign ? '-' : ' ');
    for (; top >= 0; top--)
        printf("%02X", a->number[top]);
    putchar('\n');
}

/* Negate a in place by toggling its sign flag (normalised to 0 or 1).
 * The magnitude is untouched, so negating zero yields a negative zero. */
void big_neg(big a) {
    if (a->sign)
        a->sign = 0;
    else
        a->sign = 1;
}

/* Copy a's magnitude and sign into result.  numlength is not copied.
 * Copying an object onto itself is a no-op; memcpy with the same source and
 * destination would be undefined behaviour. */
void big_set(big result, big a) {
    if (result == a)
        return;
    memcpy(result->number, a->number, kMaxLength);
    result->sign = a->sign;
}

/* Set result to the non-negative value a.  All higher bytes are cleared and
 * the sign is reset, so a previously negative result does not become -a.
 * NOTE(review): a's bytes are copied in host order, which matches the
 * least-significant-byte-first layout only on little-endian hosts. */
void big_set_ui(big result, unsigned int a) {
    big_read(result, (const unsigned char*) &a, sizeof a);
    result->sign = 0;
}


/* result = a * b for an unsigned machine-word multiplier b. */
void big_mul_ui(big result, big a, unsigned int b) {
    big factor = big_create();

    big_set_ui(factor, b);
    big_mul(result, a, factor);
    big_clear(factor);
}


// Adds two little-endian arrays of 16-bit limbs: res = a + b.
// la and lb are the limb counts of a and b. res needs room for
// max(la, lb) + 1 limbs.
// Returns the number of limbs written to res: max(la, lb), plus one if the
// most significant addition carried out.
int add_streams(unsigned short *res, unsigned short *a, unsigned short *b, int la, int lb) {
    int temp;
    unsigned int total, carry;
    unsigned short *ptr;

    // Make a the longer operand so only a has to be walked past lb.
    // (Equal lengths need no swap; addition is commutative.)
    if (lb > la) {
        temp = lb;
        lb = la;
        la = temp;

        ptr = b;
        b = a;
        a = ptr;
    }

    // Do the limb sum in unsigned int so the carry is simply bit 16.
    // This is exact even for 0xFFFF + 0xFFFF + 1.
    carry = 0;
    for (temp = 0; temp < lb; temp++) {
        total = (unsigned int) a[temp] + b[temp] + carry;
        res[temp] = (unsigned short) total;
        carry = total >> 16;
    }
    for (; temp < la; temp++) {
        total = (unsigned int) a[temp] + carry;
        res[temp] = (unsigned short) total;
        carry = total >> 16;
    }

    if (carry == 0) return temp;
    res[temp] = (unsigned short) carry;
    return temp + 1;
}


/* Add the full product a * b into a two-word little-endian accumulator:
 * (result[1]:result[0]) += a * b, modulo 2^64.
 * The accumulator is not cleared first; callers pass a zeroed one for a plain
 * product.  A single 64-bit multiply replaces the old 32-step shift-and-add
 * loop and gives identical results.
 * NOTE(review): assumes unsigned int is 32 bits, as the rest of this file does. */
void big_mul_long(unsigned int* result, unsigned int a, unsigned int b) {
    unsigned long long acc = ((unsigned long long) result[1] << 32) | result[0];

    acc += (unsigned long long) a * b;
    result[0] = (unsigned int) acc;
    result[1] = (unsigned int) (acc >> 32);
}

/* Karatsuba multiplication of two little-endian arrays of 16-bit limbs.
 *
 * Adds a * b into result, where a has la limbs and b has lb limbs.  Every
 * caller in this file passes la <= lb.  result is expected to be zeroed by the
 * caller, because the base case accumulates into it (big_mul_long adds).
 *
 * With length = lb rounded up to even, i = length/2 and B = 2^16:
 *   a = a0 + a1*B^i, b = b0 + b1*B^i
 *   bsA = a1*b1, bsB = a0*b0, bsC = (a0+a1)(b0+b1) - bsA - bsB
 *   result = bsA*B^(2i) + bsC*B^i + bsB
 *
 * NOTE(review): resa/resb hold 65 limbs, so this assumes i <= 64 (operands of
 * at most 128 limbs); nothing checks it.
 * NOTE(review): the general case writes 4*length bytes of result, so the larger
 * operand must use at most half of kMaxLength.
 * NOTE(review): the base case reads limb pairs through unsigned int*.  That
 * assumes a 32-bit unsigned int and tolerates unaligned, type-punned access.
 */
void big_mul_code(big result, unsigned short* a, unsigned short* b, int la, int lb) {
    int i, length, lx, ly;
    unsigned short *x1, *y1;
    bignum bsA, bsB, bsC;
    int lra, lrb;                            // limb counts of the half sums
    unsigned short resa[65], resb[65];       // a0+a1 and b0+b1 (i+1 limbs max)

    // An empty left operand means the product is zero: nothing to add.
    // la can shrink to 0 when a is shorter than b's low half; b never shrinks.
    if (!la) return;

    // Base case: both operands fit in 32 bits, so the product fits in 64.
    if (lb <= 2) {
        if (lb == 1) {
            big_mul_long((unsigned int*) result->number, (unsigned int) *a, (unsigned int) *b);
        } else if (la == 1) {
            big_mul_long((unsigned int*) result->number, (unsigned int) *a, *((unsigned int*) b));
        } else {
            big_mul_long((unsigned int*) result->number, *((unsigned int*) a), *((unsigned int*) b));
        }
        return;
    }

    // Split point: half of lb, rounded up to an even total length.
    length = (lb & 1) ? lb + 1 : lb;
    i = length >> 1;

    // High halves.  a may be shorter than i limbs, leaving x1 empty (lx == 0).
    lx = (la > i) ? (la - i) : 0;
    ly = lb - i;
    x1 = a + i;
    y1 = b + i;

    // bsA = a1 * b1 (shorter operand first).
    memset(bsA.number, 0, kMaxLength);
    bsA.sign = 0;
    if (ly > lx) big_mul_code(&bsA, x1, y1, lx, ly);
    else big_mul_code(&bsA, y1, x1, ly, lx);

    // bsB = a0 * b0; both low halves are exactly i limbs.
    memset(bsB.number, 0, kMaxLength);
    bsB.sign = 0;
    big_mul_code(&bsB, a, b, i, i);

    // resa = a0 + a1, resb = b0 + b1.
    lra = add_streams(resa, a, x1, i, lx);
    lrb = add_streams(resb, b, y1, i, ly);

    // bsC = (a0+a1)(b0+b1) - a1*b1 - a0*b0 == a0*b1 + a1*b0 (non-negative).
    memset(bsC.number, 0, kMaxLength);
    bsC.sign = 0;
    if (lra > lrb) big_mul_code(&bsC, resb, resa, lrb, lra);
    else big_mul_code(&bsC, resa, resb, lra, lrb);

    big_sub(&bsC, &bsC, &bsA);
    big_sub(&bsC, &bsC, &bsB);

    // Place bsA at limb offset 2i (byte offset 2*length) and bsB at offset 0.
    memcpy(result->number + length*2, bsA.number, length*2);
    memcpy(result->number, bsB.number, length*2);

    // Reuse bsA as bsC shifted up by i limbs (byte offset length), then add it.
    memset(bsA.number, 0, length);
    memcpy(bsA.number + length, bsC.number, kMaxLength - length);
    big_add(result, result, &bsA);
}

/* result = a * b (signed).
 * The product is built in a temporary, so result may be the same object as
 * a or b.  Operand lengths are measured in 16-bit limbs for big_mul_code.
 * A zero operand gives +0 (no negative zero).
 * NOTE(review): big_mul_code does not bound-check its writes, so the
 * operands must be small enough for the product to fit in kMaxLength bytes. */
void big_mul(big result, big a, big b) {
    int la = 0, lb = 0, i;      /* limb counts; stay 0 for a zero operand */
    unsigned int temp;
    big tempresult;

    tempresult = big_create();

    /* Count limbs up to and including the most significant nonzero one. */
    for (i = (kMaxLength/4)-1; i >= 0; i--) {
        temp = ((unsigned int*) a->number)[i];
        if (temp) {
            la = (temp > 0x0000FFFF) ? 2*(i + 1) : 2*i + 1;
            break;
        }
    }
    for (i = (kMaxLength/4)-1; i >= 0; i--) {
        temp = ((unsigned int*) b->number)[i];
        if (temp) {
            lb = (temp > 0x0000FFFF) ? 2*(i + 1) : 2*i + 1;
            break;
        }
    }

    /* big_mul_code expects the shorter operand first; a 0 length adds nothing. */
    if (lb > la) big_mul_code(tempresult, (unsigned short*) a->number, (unsigned short*) b->number, la, lb);
    else big_mul_code(tempresult, (unsigned short*) b->number, (unsigned short*) a->number, lb, la);

    /* Negative iff exactly one operand is negative and the product is nonzero. */
    if (la && lb && (!a->sign != !b->sign))
        tempresult->sign = 1;

    big_set(result, tempresult);
    big_clear(tempresult);
}
        
/* Shift a's magnitude right by one bit in place; the least significant bit
 * is dropped.  The sign is untouched.  Bytes are processed from the most
 * significant end, and each byte's low bit becomes the next byte's high bit. */
void big_srl(big a) {
    int pos;
    unsigned char incoming = 0, outgoing;

    for (pos = kMaxLength - 1; pos >= 0; pos--) {
        outgoing = a->number[pos] & 1;
        a->number[pos] = (a->number[pos] >> 1) | (incoming ? 0x80 : 0);
        incoming = outgoing;
    }
}

/* Shift a's magnitude left by one bit in place (32-bit words, least
 * significant first).  A bit shifted out of the most significant word of the
 * buffer is discarded rather than written past the end of number[]. */
void big_sll(big a) {
    unsigned int *words = (unsigned int*) a->number;
    unsigned int cint, carry = 0;
    int i, last = -1;

    /* Only words up to the highest nonzero one need shifting. */
    for (i = 0; i < kMaxLength/4; i++) {
        if (words[i]) last = i;
    }
    if (last == -1) return;

    for (i = 0; i <= last; i++) {
        cint = words[i];
        words[i] = (cint << 1) | carry;
        carry = cint >> 31;
    }

    /* Spill the top bit into the next word, if the buffer has one. */
    if (carry && i < kMaxLength/4)
        words[i] = 1;
}
    
/* result = base^exp mod mod, by right-to-left binary exponentiation over the
 * bytes of exp (least significant first).
 *
 * Arguments are not copied.  result must not be the same object as exp or
 * mod: it is set to 1 before exp's bits are read and before mod is reused.
 * Aliasing base is fine, because base is only read by the first big_mod.
 * exp's sign is ignored.  A zero exponent yields 1 (not reduced by mod). */
void big_powm(big result, big base, big exp, big mod) {
    int i, length = -1;     /* index of exp's highest nonzero byte; -1 if exp == 0 */
    unsigned char cbyte;
    int bitcount, byteloc;
    big modulo;

    for (i = 0; i < kMaxLength; i++) {
        if (exp->number[i]) length = i;
    }

    modulo = big_create();
    big_mod(modulo, base, mod);     /* base^(2^0) mod m */
    big_set_ui(result, 1);          /* multiplicative identity */
    result->sign = 0;

    for (byteloc = 0; byteloc <= length; byteloc++) {
        cbyte = exp->number[byteloc];
        for (bitcount = 0; bitcount < 8; bitcount++, cbyte >>= 1) {
            if (cbyte & 1) {
                big_mul(result, result, modulo);
                big_mod(result, result, mod);
            }
            big_mul(modulo, modulo, modulo);    /* next power: square */
            big_mod(modulo, modulo, mod);
        }
    }
    big_clear(modulo);
}
        
    
/* Compute the symbol (ia / in) with the binary Jacobi-symbol recurrence:
 * strip factors of two (flipping on in = 3,5 mod 8), swap with the
 * reciprocity flip (both 3 mod 4), then reduce.  Returns 1, -1, or 0 when
 * gcd(ia, in) != 1.  For an odd prime in this is the Legendre symbol.
 * NOTE(review): the recurrence assumes in is odd and positive; confirm with
 * callers.  The inputs are copied and not modified. */
int big_legendre(big ia, big in) {
    int j = 1, mres, n_is_one;
    big val, a, n;

    a   = big_create();
    n   = big_create();
    val = big_create();

    big_set(a, ia);
    big_set(n, in);

    big_mod(a, a, n);
    while (big_cmp_ui(a, 0)) {
        /* Pull out factors of two: (2/n) = -1 iff n == 3 or 5 (mod 8). */
        while ((a->number[0] & 1) == 0) {
            big_srl(a);
            mres = n->number[0] & 0x07;
            if (mres == 3 || mres == 5) j = -j;
        }

        /* Quadratic reciprocity: swap a and n. */
        big_set(val, a);
        big_set(a, n);
        big_set(n, val);

        if (((a->number[0] & 3) == 3) && ((n->number[0] & 3) == 3))
            j = -j;

        big_mod(a, a, n);
    }

    /* Decide before freeing n; reading it afterwards was a use-after-free. */
    n_is_one = !big_cmp_ui(n, 1);

    big_clear(val);
    big_clear(a);
    big_clear(n);
    return n_is_one ? j : 0;
}
            
/* Shift a's magnitude left by one 32-bit word in place.  The most
 * significant word falls off the end, and word 0 becomes zero. */
void big_sl32(big a) {
    unsigned int *words = (unsigned int*) a->number;
    int i = kMaxLength/4 - 1;

    while (i > 0) {
        words[i] = words[i - 1];
        i--;
    }
    words[0] = 0;
}

/* Bit-serial long division of a's magnitude by b.
 *
 * Quotient bits are collected 32 at a time and stored into result word by
 * word, from a's top nonzero word down.  Words of result above that are
 * cleared.  When do_mod == 1 the remainder (bC) replaces result at the end.
 *
 * result may be the same object as a: each quotient word is stored only after
 * the dividend word below it has been loaded.  result must NOT be the same
 * object as b, because quotient words are written during the loop.
 * NOTE(review): if b's top bit is the top bit of the buffer's last word,
 * big_sll cannot keep the bit shifted out of bC, and the subtraction below
 * will be wrong.  Confirm operands never fill the whole buffer. */
void big_div_code(big result, big a, big b, int do_mod) {
    big bC;                                   /* running remainder */
    unsigned int i, last = 0, lastbyte;       /* last stays 0 when a == 0 */
    unsigned int carry = 0, new_carry;
    int shiftcount, shiftindex;
    unsigned int iresult = 0, iA;

    bC = big_create();

    /* lastbyte = index of a's highest nonzero 32-bit word. */
    for (i = 0; i < kMaxLength/4; i++) {
        if (((unsigned int*)a->number)[i]) last = i;
    }
    lastbyte = last;
    last = (last + 1) << 5;                   /* number of dividend bits */

    shiftcount = 31; shiftindex = lastbyte;
    iA = ((unsigned int*) a->number)[lastbyte];

    for (; last; last--, shiftcount--) {
        /* Shift the next dividend bit (MSB first) into the remainder. */
        iresult <<= 1;
        carry = iA & 0x80000000;
        iA <<= 1;
        new_carry = ((unsigned int*)bC->number)[lastbyte] & 0x80000000;
        big_sll(bC);
        ((unsigned int*)bC->number)[0] |= carry >> 31;
        if (new_carry || (big_cmp(bC, b) >= 0)) {
            big_sub(bC, bC, b);
            iresult |= 1;
        }
        if (shiftcount) continue;

        /* A full quotient word is ready.  Load the next dividend word first;
           after word 0 there is none, so a->number[-1] is never read. */
        if (shiftindex > 0)
            iA = ((unsigned int*) a->number)[shiftindex - 1];
        ((unsigned int*)result->number)[shiftindex] = iresult;
        shiftindex--;
        shiftcount = 32;
    }

    if (do_mod == 1) {
        big_set(result, bC);
    } else {
        /* The quotient has no bits above a's top word; clear stale data. */
        for (i = lastbyte + 1; i < kMaxLength/4; i++)
            ((unsigned int*)result->number)[i] = 0;
    }
    big_clear(bC);
}

/* result = a / b via big_div_code (quotient of magnitudes).
 * When big_cmp reports a < b, a diagnostic line is printed to stdout and
 * result is set to 0. */
void big_div(big result, big a, big b) {
    if (big_cmp(a, b) != -1) {
        big_div_code(result, a, b, 0);
        return;
    }
    puts("Small exception evoked.");
    big_set_ui(result, 0);
}

/* result = a mod mod, reduced into [0, mod) for a positive modulus.
 * A negative a maps to mod - (|a| mod mod), or to 0 when mod divides |a|.
 * result may be the same object as a; big_powm and big_legendre call it
 * that way. */
void big_mod(big result, big a, big mod) {
    int cmp = big_cmp(a, mod);
    int neg_mod = 0;

    if (cmp == -1) {                    /* a < mod */
        if (!a->sign) {
            /* 0 <= a < mod: already reduced (skip a self-copy). */
            if (result != a)
                big_set(result, a);
            return;
        }
        neg_mod = 1;                    /* a < 0: reduce |a|, then reflect */
    } else if (!cmp) {
        big_set_ui(result, 0);
        result->sign = 0;
        return;
    }

    big_div_code(result, a, mod, 1);    /* result = |a| mod mod */

    /* (-|a|) mod m == m - (|a| mod m), except an exact multiple stays 0. */
    if (neg_mod && big_cmp_ui(result, 0) != 0)
        big_sub(result, mod, result);
}

