rsa.c 4.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166
  1. #include <stdio.h>
  2. #include <stdint.h>
  3. #include <stdlib.h>
  4. #include <math.h>
  5. #include <time.h>
  6. #include <string.h>
  /*
   * File-scope state. None of these names is referenced by the functions in
   * the portion of the file reviewed here; they are kept because code elsewhere
   * may still depend on them (NOTE(review): confirm before removing).
   */
  char buffer[1024];
  const int MAX_DIGITS = 50;
  int i;      /* no initializer: static storage, so it starts at 0 */
  int j = 0;
  /* Public half of a key pair: rsa_encrypt raises each byte to `exponent`
   * and reduces the result modulo `modulus`. */
  struct public_key_class {
      long long modulus, exponent;
  };
  /* Private half of a key pair: rsa_decrypt raises each ciphertext value to
   * `exponent` and reduces the result modulo `modulus`. */
  struct private_key_class {
      long long modulus, exponent;
  };
  /*
   * Greatest common divisor via Euclid's algorithm: gcd(a, b) == gcd(b % a, a)
   * until the first argument reaches 0, at which point the second is the answer.
   * gcd(0, b) == b, so gcd(0, 0) == 0. The sign of the result is not normalised;
   * it follows C's truncating %, e.g. gcd(4, -6) == -2.
   * (The C standard library provides no gcd, hence this helper.)
   */
  long long gcd(long long a, long long b) {
      return (a == 0) ? b : gcd(b % a, a);
  }
  /*
   * Extended Euclidean algorithm, returning only the Bezout coefficient of b.
   *
   * Runs the remainder sequence starting from (b, a) and returns y such that
   * x*a + y*b equals the final non-zero remainder (the gcd) for some x.
   * When gcd(a, b) == 1 this makes y the inverse of b modulo a, but y may be
   * negative: ExtEuclid(3120, 17) == -367, and -367 + 3120 == 2753 is the
   * inverse of 17 mod 3120. Callers must normalise the result themselves.
   * If a == 0 the loop never runs and 1 is returned.
   *
   * Only the coefficient of b is tracked; the coefficient of a was previously
   * computed and discarded. Locals are also no longer named `gcd`, which
   * shadowed the gcd() function.
   */
  long long ExtEuclid(long long a, long long b) {
      long long prev_r = b, cur_r = a;   /* consecutive remainders */
      long long prev_y = 1, cur_y = 0;   /* coefficient of b in each remainder */
      while (cur_r != 0) {
          const long long q = prev_r / cur_r;
          const long long next_r = prev_r % cur_r;
          const long long next_y = prev_y - cur_y * q;
          prev_r = cur_r;
          cur_r = next_r;
          prev_y = cur_y;
          cur_y = next_y;
      }
      return prev_y;
  }
  /*
   * Computes (a * b) mod `mod` without ever overflowing.
   *
   * Precondition: mod > 0 (rsa_modExp rejects m <= 0 before calling).
   * Negative operands are first mapped to their least non-negative residue,
   * so the result is always in [0, mod).
   *
   * Why not just multiply and check? For long long, a * b overflows once the
   * operands exceed roughly 2^31.5, and signed overflow is undefined behaviour.
   * A test like "product / a == b" performed after the multiplication may be
   * folded to "true" by the optimiser. All arithmetic below is therefore done
   * in unsigned long long on operands already reduced below mod (<= LLONG_MAX),
   * where every intermediate sum stays below 2 * mod < 2^64.
   */
  static inline long long modmult(long long a, long long b, long long mod) {
      const unsigned long long m = (unsigned long long) mod;
      long long ra = a % mod;
      long long rb = b % mod;
      if (ra < 0) {
          ra += mod;
      }
      if (rb < 0) {
          rb += mod;
      }
      unsigned long long x = (unsigned long long) ra;
      unsigned long long y = (unsigned long long) rb;

      /* Fast path: the exact product fits in unsigned long long
       * (always the case when mod < 2^32). Checked before multiplying. */
      if (y == 0 || x <= ~0ULL / y) {
          return (long long) ((x * y) % m);
      }

      /* Slow path: binary double-and-add, reducing after every step. */
      unsigned long long acc = 0;
      while (y != 0) {
          if (y & 1u) {
              acc += x;          /* acc, x < m, so the sum is < 2m < 2^64 */
              if (acc >= m) {
                  acc -= m;
              }
          }
          x += x;                /* x < m, so 2x < 2^64 */
          if (x >= m) {
              x -= m;
          }
          y >>= 1;
      }
      return (long long) acc;
  }
  /*
   * Modular exponentiation: returns b^e mod m using right-to-left binary
   * square-and-multiply, delegating every multiplication to modmult.
   *
   * Returns -1 if an argument is out of range (b < 0, e < 0 or m <= 0).
   * Otherwise the result is reduced modulo m and non-negative, so -1 is
   * unambiguous as an error marker.
   */
  long long rsa_modExp(long long b, long long e, long long m) {
      if (b < 0 || e < 0 || m <= 0) {
          return -1;
      }
      /* Start from 1 mod m rather than 1: when m == 1 the answer is 0 even for
       * e == 0, and the loop below would not run to reduce it. */
      long long product = 1 % m;
      b %= m;
      while (e > 0) {
          if (e & 1) {
              product = modmult(product, b, m);
          }
          b = modmult(b, b, m);
          e >>= 1;
      }
      return product;
  }
  93. // Calling this function will generate a public and private key and store them in the pointers
  94. // it is given.
  95. void rsa_gen_keys(struct public_key_class *pub, struct private_key_class *priv) {
  96. pub->modulus = 1580420911;
  97. pub->exponent = 131073;
  98. priv->modulus = 1580420911;
  99. priv->exponent = 874267137;
  100. }
  101. unsigned char *rsa_encrypt(const unsigned char *message, const unsigned long message_size,
  102. const struct public_key_class *pub) {
  103. long long *encrypted = malloc(sizeof(long long) * message_size);
  104. if (encrypted == NULL) {
  105. fprintf(stderr,
  106. "Error: Heap allocation failed.\n");
  107. return NULL;
  108. }
  109. long long i = 0;
  110. for (i = 0; i < message_size; i++) {
  111. if ((encrypted[i] = rsa_modExp(message[i], pub->exponent, pub->modulus)) == -1)
  112. return NULL;
  113. }
  114. return (unsigned char*)encrypted;
  115. }
  116. unsigned char *rsa_decrypt(const long long *message,
  117. const unsigned long message_size,
  118. const struct private_key_class *priv) {
  119. if (message_size % sizeof(long long) != 0) {
  120. fprintf(stderr,
  121. "Error: message_size is not divisible by %d, so cannot be output of rsa_encrypt\n",
  122. (int) sizeof(long long));
  123. return NULL;
  124. }
  125. // We allocate space to do the decryption (temp) and space for the output as a char array
  126. // (decrypted)
  127. unsigned char *decrypted = malloc(message_size / sizeof(long long));
  128. unsigned char *temp = malloc(message_size);
  129. if ((decrypted == NULL) || (temp == NULL)) {
  130. fprintf(stderr,
  131. "Error: Heap allocation failed.\n");
  132. return NULL;
  133. }
  134. // Now we go through each 8-byte chunk and decrypt it.
  135. long long i = 0;
  136. for (i = 0; i < message_size / 8; i++) {
  137. if ((temp[i] = rsa_modExp(message[i], priv->exponent, priv->modulus)) == -1) {
  138. free(temp);
  139. return NULL;
  140. }
  141. }
  142. // The result should be a number in the char range, which gives back the original byte.
  143. // We put that into decrypted, then return.
  144. for (i = 0; i < message_size / 8; i++) {
  145. decrypted[i] = temp[i];
  146. }
  147. free(temp);
  148. return decrypted;
  149. }