거듭제곱 최적화

C++의 pow 함수가 느려서 대안을 찾아 보았고, 그 결과를 정리해 봅니다.

정확도가 떨어지더라도 float의 거듭제곱을 빠르게 얻고 싶다면
http://hackage.haskell.org/package/approximate-0.2.1.1/src/cbits/fast.c의 powf_fast 함수를 사용하면 좋은데, 그 함수는 아래와 같습니다.

float powf_fast(float a, float b) {
    union { float d; int x; } u = { a };
    u.x = (int)(b * (u.x - 1064866805) + 1064866805);
    return u.d;
}

정확도가 떨어지더라도 double의 거듭제곱을 빠르게 얻고 싶다면 Optimized Approximative pow() in C / C++ | Martin Ankerl의 거듭제곱 함수를 사용하면 좋은데, 그 함수는 아래와 같습니다.

inline double fastPow(double a, double b) {
  union {
    double d;
    int x[2];
  } u = { a };
  u.x[1] = (int)(b * (u.x[1] - 1072632447) + 1072632447);
  u.x[0] = 0;
  return u.d;
}

위의 함수보다 더 느리지만 더 정확한 함수는 아래와 같습니다.

// should be much more precise with large b
inline double fastPrecisePow(double a, double b) {
  // calculate approximation with fraction of the exponent
  int e = (int) b;
  union {
    double d;
    int x[2];
  } u = { a };
  u.x[1] = (int)((b - e) * (u.x[1] - 1072632447) + 1072632447);
  u.x[0] = 0;

  // exponentiation by squaring with the exponent's integer part
  // double r = u.d makes everything much slower, not sure why
  double r = 1.0;
  while (e) {
    if (e & 1) {
      r *= a;
    }
    a *= a;
    e >>= 1;
  }

  return r * u.d;
}

저는 위의 세 함수보다 좀 더 느리더라도 더 정확한 거듭제곱이 필요해서, 그걸 제 나름대로 구현해 보았습니다. Single-precision floating-point format – Wikipedia, the free encyclopedia를 보면 알 수 있듯이 부동소수점이 메모리에 저장되는 형식은 (-1의 부호승) * (1 + 소수부 * (2의 -23승)) * (2의 (지수부 – 127)승)이라는 2의 거듭제곱 형태입니다. 그걸 이용해 아래와 같은 과정을 거치면 거듭제곱 최적화가 가능합니다. 참고로, 아래에서 A와 B는 0 이상이며 log의 밑은 전부 2입니다.

A의 B승
= 2의 (logA * B)승
= 2의 (log((1 + (A의 소수부) / (2의 23승)) * (2의 ((A의 지수부) – 127)승)) * B)승
= 2의 ((log(1 + (A의 소수부) / (2의 23승)) + log(2의 ((A의 지수부) – 127)승)) * B)승
= 2의 ((log(1 + (A의 소수부) / (2의 23승)) + (A의 지수부) – 127) * B)승
이 됩니다.

위의 log 값은 log1부터 log2 사이이므로, 미리 계산해 배열에 기억해 뒀다가 참조하면 빠릅니다. 참고로, C 라이브러리엔 밑이 2인 log 함수가 없지만, 밑 변환 log 공식을 이용하면 밑이 2인 log 값을 간접적으로 구할 수 있습니다.

((log(1 + (A의 소수부) / (2의 23승)) + (A의 지수부) – 127) * B)를 x라 할 때, 2의 x승을 빠르게 구하려면 지수 법칙을 이용하면 됩니다. 즉, a의 (b + c)승은 (a의 b승) * (a의 c승)이므로, 2의 (x의 정수부 + x의 소수부)승은 (2의 (x의 정수부)승) * (2의 (x의 소수부)승)과 같습니다. x의 정수부는 -127부터 128까지의 정수이고 소수부는 -1부터 1까지의 실수이므로, 그 값들을 미리 계산해 기억해 뒀다가 참조하면 빠릅니다.

아래는 위의 내용을 실제로 구현한 소스 코드입니다. 성능 측정 결과로는 pow 함수보다 약 4배 빠릅니다.

const int exponent_bias = 127;

const int log2s_bit_size = 16;
float log2s[1 << log2s_bit_size];
const int log2s_size = sizeof log2s / sizeof log2s[0];
bool initialize_log2s()
{
	for (int i = 0; i < log2s_size; ++i)
		log2s[i] = log(1.f + static_cast<float>(i) / log2s_size) / log(2.f);
	return true;
}
bool log2s_is_initialized = initialize_log2s();

float integer_powers[256];
float fraction_powers[65536];
const int fraction_powers_size = sizeof fraction_powers / sizeof fraction_powers[0];
const int half_fraction_powers_size = fraction_powers_size / 2;
bool initialize_powers()
{
	for (int i = 0; i < sizeof integer_powers / sizeof integer_powers[0]; ++i)
		integer_powers[i] = pow(2.f, i - exponent_bias);
	for (int i = 0; i < fraction_powers_size; ++i)
		fraction_powers[i] = pow(2.f, static_cast<float>(i) / half_fraction_powers_size - 1.f);
	return true;
}
bool powers_are_initialized = initialize_powers();

float calculate_power(float base, float exponent)
{
	assert(base >= 0.f);
	assert(exponent >= 0.f);

	static const int significand_bit_size = 23;

	union ieee_754_byte32
	{
		float float_value;
		struct
		{
			unsigned significand : significand_bit_size;
			unsigned exponent : 8;
			unsigned sign : 1;
		} parts;
	};

	ieee_754_byte32 union_for_base { base };

	// 원래는 float base_significand = log2s[static_cast<int>((static_cast<float>(union_for_base.parts.significand) / (1 << significand_bit_count)) * log2s_size)];인데, 아래처럼 축약하면 더 빠릅니다.
	float base_significand = log2s[union_for_base.parts.significand >> (significand_bit_size - log2s_bit_size)];

	int base_exponent = union_for_base.parts.exponent - exponent_bias;
	
	// 이 방식에선 2의 -127승 미만은 제대로 처리하지 못합니다.
	float power = __max(-exponent_bias, (base_significand + base_exponent) * exponent);

	// 이 경우엔 제대로 처리하지 못하므로 예외 처리합니다. 코드 맨 앞에 놓으면 코드 실행이 느려지므로, 이렇게 중간에 뒀습니다.
	if (base == 0.f && exponent > 0.f)
		return 0.f;

	int integer_part_of_power = static_cast<int>(power);

	// 다음처럼 부동소수점 형식을 이용해 정수의 거듭제곱을 구할 수도 있는데, 성능 측정 결과로는 배열을 참조하는 게 더 빨랐습니다.
	// ieee_754_byte32 union_for_integer_part_of_power {};
	// union_for_integer_part_of_power.parts.exponent = integer_part_of_power + exponent_bias;
	// float _2_raised_to_the_power_of_integer_part_of_power = union_for_integer_part_of_power.float_value;
	float _2_raised_to_the_power_of_integer_part_of_power = integer_powers[integer_part_of_power + exponent_bias];
	assert(_2_raised_to_the_power_of_integer_part_of_power >= 0.f);

	float frational_part_of_power = power - integer_part_of_power;
	float _2_raised_to_the_power_of_frational_part_of_power = fraction_powers[static_cast<int>(half_fraction_powers_size * (1.f + frational_part_of_power))];
	assert(_2_raised_to_the_power_of_frational_part_of_power >= 0.f);

	return _2_raised_to_the_power_of_integer_part_of_power * _2_raised_to_the_power_of_frational_part_of_power; 
}
Advertisements

답글 남기기

아래 항목을 채우거나 오른쪽 아이콘 중 하나를 클릭하여 로그 인 하세요:

WordPress.com 로고

WordPress.com의 계정을 사용하여 댓글을 남깁니다. 로그아웃 / 변경 )

Twitter 사진

Twitter의 계정을 사용하여 댓글을 남깁니다. 로그아웃 / 변경 )

Facebook 사진

Facebook의 계정을 사용하여 댓글을 남깁니다. 로그아웃 / 변경 )

Google+ photo

Google+의 계정을 사용하여 댓글을 남깁니다. 로그아웃 / 변경 )

%s에 연결하는 중