aのn乗をmで割った余り

タイトルのやつを3種類書きました
以下ソースです

#include<stdio.h>
#include<time.h>
int exp_mod(int a,int n,int m)
{
  long int res = 1;
  for(;n>0;n--)
    res = (res * a) % m;
  return (int)res % m;
}

int exp_mod_fixed(int a,int n, int m)
{
  long int res = 1;
  if(n==0)
    return 1;
  if(n%2){
    res = exp_mod(a,n/2,m) % m;
    res = res*res % m;
    res = res*a % m;
  }else{
    res = exp_mod(a,n/2,m) % m;
    res = res*res % m;
  }
  return (int)res;
}

int exp_mod_binaly(int a,int n,int m)
{
  long int res=1;
  long int exp = a;
  while(n > 0){
    if(n & 1)
      res = (exp * res) % m;
    exp = (exp*exp) % m;
    n >>= 1;
  }
  return (int)res;
}
int main()
{
  int a,n,m;
  clock_t start,end;
  printf("    base : \t");
  scanf("%d",&a);
  printf("exponent : \t");
  scanf("%d",&n);
  printf("     mod : \t");
  scanf("%d",&m);

  start = clock();
  printf("  result : \t %d\n",exp_mod(a,n,m));
  end = clock();
  printf("calc time : %.2f sec by exp_mod\n",(double)(end-start)/CLOCKS_PER_SEC);

  start = clock();
  printf("  result : \t %d\n",exp_mod_fixed(a,n,m));
  end = clock();
  printf("calc time : %.2f sec by exp_mod_fixed\n",(double)(end-start)/CLOCKS_PER_SEC);

  start = clock();
  printf("  result : \t %d\n",exp_mod_binaly(a,n,m));
  end = clock();
  printf("calc time : %.2f sec by exp_binaly\n",(double)(end-start)/CLOCKS_PER_SEC);

  return 0;
}


実行例

    base :      2526
exponent :      353
     mod :      777
  result :       174
calc time : 0.00 sec by exp_mod
  result :       174
calc time : 0.00 sec by exp_mod_fixed
  result :       174
calc time : 0.00 sec by exp_binaly

    base :      2012123499
exponent :      353255243
     mod :      5352552
  result :       795627
calc time : 4.67 sec by exp_mod
  result :       795627
calc time : 2.34 sec by exp_mod_fixed
  result :       795627
calc time : 0.00 sec by exp_binaly

    base :      2000000000
exponent :      2123942242
     mod :      4255
  result :       1020
calc time : 28.72 sec by exp_mod
  result :       1020
calc time : 14.23 sec by exp_mod_fixed
  result :       1020
calc time : 0.00 sec by exp_binaly

3種類の関数を動作させ、計測時間を印字させています。
なかなか面白い結果になりました。

上のexp_modは、「aのn乗をmで割った余りを求めたい!」と思ったときに誰もが真っ先に思いつくコードだと思います。計算の後から剰余をとっても先に剰余をとっても答えが変わらない性質を使っています。オーバーフローを防ぐために計算過程はlong intに格納しています。たかだかint型の掛け算、余り算なので、long intで収まります。しかし、上の実行例のように2000000000が入力されたときは単純に2000000000回ループが回り、計算時間が見ての通り大変なことになっています。

真ん中のexp_mod_fixedは、べき乗の演算がもつ特性を使っています。exp値が2で割れるとき、a/2は整数となり、a^exp = (a^exp/2)*(a^exp/2)と分解でき、2で割れないとき、a^exp = (a^exp/2)*(a^exp/2)*a(小数点切り捨て)とかけるのを利用して、再帰的にもとめています。exp値が指数関数的に減少するので計算効率はいいと思ってたんですけど、なんだかそうでもない感じですね...。個人的には好きなコードです。

下のexp_mod_binalyが今回の目玉です。あるexpを2進数表記にし、下のビットから見ていったときにnビット目が1だったらa^nをresにかける、というだけのシンプルなものです。aのn乗をmで割った余りが知りたいので、a^nをresにかけたものをmで割ったものを新しくresとしています。指数法則でべき乗同士の掛け算はexp値の足し算となることを利用しています。aが倍々に増えていくので、掛け算の回数も少なくてすみます。これはint型なので、たかだか32回の計算で終わります。

追記 (2017/07/08)
exp_mod_fixed関数に誤りがありました

int exp_mod_fixed(int a,int n, int m)
{
  long int res = 1;
  if(n==0)
    return 1;
  if(n%2){
    res = exp_mod_fixed(a,n/2,m) % m;
    res = res*res % m;
    res = res*a % m;
  }else{
    res = exp_mod_fixed(a,n/2,m) % m;
    res = res*res % m;
  }
  return (int)res;
}

元のコードだとexp_mod_fixed内でexp_modのほうを呼び出してしまっていたので実行が遅かったみたいです
上のコードで、しっかりexp_mod_fixed内で自分自身を呼び出す再帰構造になおしました。これで試してみると、

実行結果

    base :      2000000000
exponent :      2109424214
     mod :      4344113
  result :       1763220
calc time : 27.95 sec by exp_mod
  result :       1763220
calc time : 0.00 sec by exp_mod_fixed
  result :       1763220
calc time : 0.00 sec by exp_binaly


やっぱ速いジャン↑↑