03 June 2024, Monday
GPT-2 parallelisation (part 3): layer normalisation with SIMD
In this article, we parallelise the layer normalisation found in a Transformer (such as GPT-2) using SIMD instructions.
This is the third part of the GPT-2 parallelisation series.
Layer normalisation ensures that the input (the token embeddings before the QKV phase) and the subsequent intermediate attention layers are standardised.
This prevents the learned weights and biases from being affected by sample-specific artefacts in the input and intermediate layers. In other words, it allows the Transformer to learn and generalise across the features in the data, rather than focusing on the features within individual samples.
Layer normalisation is different from batch normalisation. In layer normalisation, we normalise across the channels of each token: the channels of a token are first normalised (we calculate the \(z\)-score), then scaled (weighted) and shifted (a bias is added). Batch normalisation, on the other hand, operates along each channel over the entire batch of tokens: each channel is normalised, scaled and shifted across all tokens in the batch, and this is done separately for each channel.
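To make the distinction concrete, let \(x_{t,c}\) denote channel \(c\) of token \(t\) (this notation, and the symbols \(\mu\) and \(\sigma\), are introduced here for illustration only):
\[
\text{LayerNorm:}\ \ y_{t,c} = w_c\,\frac{x_{t,c} - \mu_t}{\sigma_t} + b_c,
\qquad
\text{BatchNorm:}\ \ y_{t,c} = \gamma_c\,\frac{x_{t,c} - \mu_c}{\sigma_c} + \beta_c,
\]
where \(\mu_t\) and \(\sigma_t\) are computed over the channels of token \(t\), while \(\mu_c\) and \(\sigma_c\) are computed over all tokens in the batch for channel \(c\).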
Layer normalisation can be summarised with the following diagram:
Here, an array of channels is supplied as input. If the layer normalisation is done after the token embedding and positional encoding phase (see the GPT-2 parallelisation series), the inputs will be the embeddings for a line of tokens, where each line contains \(T\) tokens and each token is embedded using \(C\) channels. When the input is an intermediate attention layer, the inputs will be the transformed vectors in a vector space with \(C\) dimensions.
For each group of channels, we first calculate the sample mean and the biased sample standard deviation (not the unbiased standard deviation; see Why divide by (n-1) when calculating sample variance?). This is the reduction parallel pattern, which we will see frequently in subsequent articles. During reduction, we use multiple threads to reduce an array of values into a single value; summation (required to calculate the mean) and finding the maximum value are typical reduction operations.
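As an aside, the reduction pattern on its own can be sketched as follows (a minimal illustration using an OpenMP reduction clause; this is not the code used in this article, which performs the reduction with SIMD registers instead):

#include <stdint.h>

/* Reduce an array of floats into a single sum.  OpenMP gives each
 * thread a private partial sum and combines the partial sums when
 * the loop finishes. */
float sum_reduce(const float *values, uint32_t n)
{
    float sum = 0.0f;
    #pragma omp parallel for reduction(+:sum)
    for (uint32_t i = 0; i < n; ++i)
        sum += values[i];
    return sum;
}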
Once we have the sample means (\(\bar{x}\)) and standard deviations (\(s\)), we use these to normalise the corresponding group of channels by calculating the \(z\)-scores of the channel values. Hence, for each array of channel values \(x\), we calculate the \(z\)-scores as:
\[z = \frac{x - \bar{x}}{s}.\]
Once we have the \(z\)-scores, we scale the channels using the weights and add a bias to shift them.
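Putting the normalisation, scaling and shifting together, the output for a token's channel vector \(x\) is
\[
y = w \odot \frac{x - \bar{x}}{s} + b,
\]
where \(\odot\) denotes element-wise multiplication across the \(C\) channels, and \(w\) and \(b\) are the weights and biases of the layer being normalised.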
We use different weights and biases depending on the layer we are normalising. For instance, the layer normalisation after the token embedding and positional encoding phase is the \(0\)-th layer, so it uses the first of the normalisation weights and biases (ln1w and ln1b, see Memory layout of GPT-2 parameters). There are \(L\) weights and biases, each with \(C\) channels, where \(L\) is the number of layers. Note that there is a second phase of layer normalisation, which uses a different set of weights and biases, i.e., ln2w and ln2b, and a final, fully-connected layer normalisation, which uses lnfw and lnfb. All of these have the same number of channels.
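For illustration, selecting the parameter group for layer \(l\) could look like the following sketch (the helper name and calling convention are assumptions made here, not code from this series; the actual layout is described in Memory layout of GPT-2 parameters):

#include <stddef.h>
#include <stdint.h>

/* ln1w and ln1b (likewise ln2w/ln2b) each store L contiguous groups
 * of C floats, one group per layer.  This hypothetical helper returns
 * the group belonging to layer l. */
static const float *layer_params(const float *base,
                                 uint32_t l, uint32_t num_channels)
{
    return base + (size_t)l * num_channels;
}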
The layer normalisation function is implemented as follows:
void layernorm_tokens(
float * restrict output,
float * restrict means,
float * restrict rstds,
const float * restrict input,
const float * restrict weights,
const float * restrict biases,
uint32_t num_channels,
uint32_t num_tokens
)
{
#if defined(__AVX__)
    layernorm_tokens_avx(
        output, means, rstds, input, weights,
        biases, num_channels, num_tokens
    );
#elif defined(__SSE__) && defined(__SSE3__)
    layernorm_tokens_sse(
        output, means, rstds, input, weights,
        biases, num_channels, num_tokens
    );
#else
    layernorm_tokens_seq(
        output, means, rstds, input, weights,
        biases, num_channels, num_tokens
    );
#endif
}
The following is a sequential implementation of layer normalisation (albeit with the option to parallelise the token loop using OpenMP when OpenMP multi-threading is available).
void layernorm_tokens_seq(
float * restrict output,
float * restrict means,
float * restrict rstds,
const float * restrict input,
const float * restrict weights,
const float * restrict biases,
uint32_t num_channels,
uint32_t num_tokens
)
{
#if defined(DEBUG) && !defined(__SSE__)
    printf("Cannot print correct floating point values without SSE\n");
#endif
#pragma omp parallel for \
    firstprivate(output, means, rstds, input, \
                 weights, biases, num_channels)
    for (uint32_t i = 0; i < num_tokens; ++i) {
        uint32_t idx = i * num_channels;
        const float *input_ptr = input + idx;
        const float *i_ptr = input_ptr;

        /* Reduction: sum the channels of token i to get the mean. */
        float sum = 0.0f;
        for (uint32_t j = 0; j < num_channels; ++j)
            sum += *i_ptr++;
        float mean = sum / num_channels;
#if defined(DEBUG) && defined(__SSE__)
        if (i == 0)
            printf("sum: %.*e\nmean: %.*e\n",
                   DECIMAL_DIG, sum, DECIMAL_DIG, mean);
#endif
        /* Second pass: biased variance and its reciprocal square root. */
        i_ptr = input_ptr;
        sum = 0.0f;
        for (uint32_t j = 0; j < num_channels; ++j) {
            float deviation = *i_ptr++ - mean;
            sum += deviation * deviation;
        }
        float biased_variance = sum / num_channels;
        float rstd = 0.0f;
        Q_rsqrt(&rstd, biased_variance);
#if defined(DEBUG) && defined(__SSE__)
        if (i == 0)
            printf("biased_variance: %.*e\nrstd: %.*e\n",
                   DECIMAL_DIG, biased_variance, DECIMAL_DIG, rstd);
#endif
        /* Normalise, scale and shift each channel. */
        i_ptr = input_ptr;
        float *o_ptr = output + idx;
        const float *w_ptr = weights;
        const float *b_ptr = biases;
        for (uint32_t j = 0; j < num_channels; ++j)
            *o_ptr++ = (*i_ptr++ - mean) * rstd *
                       *w_ptr++ + *b_ptr++;
        means[i] = mean;
        rstds[i] = rstd;
    }
}
Since we are experimenting with enabling and disabling SIMD instructions, we could not rely on the existing C system libraries for these experiments: the installed sqrtf() function was compiled with SIMD instructions enabled. Hence, for these experiments we use the fast inverse square root implementation from the Quake 3D Engine.
/* Quake inverse square root function.
 * When SSE is disabled we cannot use the sqrtf() library function,
 * so approximate 1/sqrt(x) instead. */
void Q_rsqrt(float *output, float input)
{
    union {
        float f;
        uint32_t i;
    } conv = { .f = input };
    conv.i = 0x5f3759df - (conv.i >> 1);               /* initial guess */
    conv.f *= 1.5F - (input * 0.5F * conv.f * conv.f); /* one Newton-Raphson step */
    *output = conv.f;
}
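As a quick sanity check (a sketch, not part of the article's code; it assumes the Q_rsqrt definition above and a build where sqrtf() is usable), we can compare the approximation against the library result:

#include <math.h>
#include <stdio.h>

int main(void)
{
    float x = 2.0f;
    float approx = 0.0f;
    Q_rsqrt(&approx, x);
    /* One Newton-Raphson step keeps the relative error below roughly
     * 0.2%, i.e. about three significant digits. */
    printf("Q_rsqrt: %.9g, 1.0f/sqrtf: %.9g\n", approx, 1.0f / sqrtf(x));
    return 0;
}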
Similarly, note that we had to disable printing of floating-point values when SSE is disabled, since the C system libraries I have installed use SIMD instructions for printing floating-point values.
To vectorise the implementation, we first do as much as we can with AVX intrinsics, which process 8 channels at a time, then with SSE intrinsics, which process 4 channels at a time, and finally process the remaining channels using scalar instructions.
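The trickiest step in the code below is the horizontal reduction of a vector register into a single float. In isolation it looks like the following sketch (hsum256 is a hypothetical helper introduced here for illustration; the actual implementation inlines this pattern):

#include <immintrin.h>

/* Collapse the 8 lanes of an AVX register into one float: add the
 * upper and lower 128-bit halves, then two horizontal adds leave the
 * total in lane 0. */
static inline float hsum256(__m256 v)
{
    __m128 high = _mm256_extractf128_ps(v, 1);  /* upper 4 lanes */
    __m128 low  = _mm256_castps256_ps128(v);    /* lower 4 lanes */
    __m128 sum  = _mm_add_ps(high, low);        /* 4 partial sums */
    sum = _mm_hadd_ps(sum, sum);                /* 2 partial sums */
    sum = _mm_hadd_ps(sum, sum);                /* total in every lane */
    return _mm_cvtss_f32(sum);
}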
#if defined(__AVX__)
void layernorm_tokens_avx(
float * restrict output,
float * restrict means,
float * restrict rstds,
const float * restrict input,
const float * restrict weights,
const float * restrict biases,
uint32_t num_channels,
uint32_t num_tokens
)
{
    uint32_t n = num_channels;
    const uint32_t num_avx_blocks = n >> 3;
    const uint32_t num_avx_items = num_avx_blocks << 3;
    n -= num_avx_items;

    const uint32_t num_sse_blocks = n >> 2;
    const uint32_t num_sse_items = num_sse_blocks << 2;
#pragma omp parallel for \
    firstprivate(output, means, rstds, input, \
                 weights, biases, num_channels, \
                 num_avx_items, num_sse_items)
    for (uint32_t i = 0; i < num_tokens; ++i) {
        const uint32_t idx = i * num_channels;
        const float *input_ptr = input + idx;
        const float *i_ptr = input_ptr;
        const float *i_end = i_ptr + num_channels;
        const float *i_avx_end = i_ptr + num_avx_items;
        const float *i_sse_end = i_avx_end + num_sse_items;

        /* Sum the channels: 8 at a time with AVX, then 4 at a time
         * with SSE, then the remainder with scalar code. */
        __m256 sum_avx = _mm256_loadu_ps(i_ptr);
        i_ptr += 8;
        while (i_ptr != i_avx_end) {
            sum_avx = _mm256_add_ps(
                sum_avx, _mm256_loadu_ps(i_ptr)
            );
            i_ptr += 8;
        }
        __m128 sum_avx_high = _mm256_extractf128_ps(sum_avx, 1);
        __m128 sum_avx_low = _mm256_castps256_ps128(sum_avx);
        __m128 sum_sse = _mm_add_ps(sum_avx_high, sum_avx_low);

        while (i_ptr != i_sse_end) {
            sum_sse = _mm_add_ps(sum_sse, _mm_loadu_ps(i_ptr));
            i_ptr += 4;
        }
        sum_sse = _mm_hadd_ps(sum_sse, sum_sse);
        sum_sse = _mm_hadd_ps(sum_sse, sum_sse);

        float sum = 0.0f;
        _mm_store_ss(&sum, sum_sse);

        while (i_ptr != i_end)
            sum += *i_ptr++;

        float mean = sum / num_channels;
#if defined(DEBUG)
        if (i == 0)
            printf("sum: %.*e\nmean: %.*e\n",
                   DECIMAL_DIG, sum, DECIMAL_DIG, mean);
#endif
        /* Second pass: accumulate squared deviations for the
         * biased variance. */
        i_ptr = input_ptr;
        __m256 mean_avx = _mm256_set1_ps(mean);
        __m256 temp_avx = _mm256_sub_ps(
            _mm256_loadu_ps(i_ptr),
            mean_avx
        );
        sum_avx = _mm256_mul_ps(temp_avx, temp_avx);
        i_ptr += 8;
        while (i_ptr != i_avx_end) {
            temp_avx = _mm256_sub_ps(
                _mm256_loadu_ps(i_ptr),
                mean_avx
            );
            sum_avx = _mm256_add_ps(
                sum_avx,
                _mm256_mul_ps(temp_avx, temp_avx)
            );
            i_ptr += 8;
        }
        sum_avx_high = _mm256_extractf128_ps(sum_avx, 1);
        sum_avx_low = _mm256_castps256_ps128(sum_avx);
        sum_sse = _mm_add_ps(sum_avx_high, sum_avx_low);

        __m128 mean_sse = _mm_set1_ps(mean);
        __m128 temp_sse;
        while (i_ptr != i_sse_end) {
            temp_sse = _mm_sub_ps(
                _mm_loadu_ps(i_ptr),
                mean_sse
            );
            sum_sse = _mm_add_ps(
                sum_sse,
                _mm_mul_ps(temp_sse, temp_sse)
            );
            i_ptr += 4;
        }
        sum_sse = _mm_hadd_ps(sum_sse, sum_sse);
        sum_sse = _mm_hadd_ps(sum_sse, sum_sse);
        _mm_store_ss(&sum, sum_sse);

        while (i_ptr != i_end) {
            const float deviation = *i_ptr++ - mean;
            sum += deviation * deviation;
        }
        const float biased_variance = sum / num_channels;
        float rstd = 0.0f;
        Q_rsqrt(&rstd, biased_variance);
#if defined(DEBUG)
        if (i == 0)
            printf("biased_variance: %.*e\nrstd: %.*e\n",
                   DECIMAL_DIG, biased_variance, DECIMAL_DIG, rstd);
#endif
        /* Normalise, scale and shift: z = (x - mean) * rstd, then
         * z * weight + bias, written to the output. */
        i_ptr = input_ptr;
        float *o_ptr = output + idx;
        const float *w_ptr = weights;
        const float *b_ptr = biases;
        __m256 rstd_avx = _mm256_set1_ps(rstd);
        while (i_ptr != i_avx_end) {
            temp_avx = _mm256_sub_ps(
                _mm256_loadu_ps(i_ptr),
                mean_avx
            );
            temp_avx = _mm256_mul_ps(temp_avx, rstd_avx);
            temp_avx = _mm256_mul_ps(
                temp_avx,
                _mm256_loadu_ps(w_ptr)
            );
            _mm256_storeu_ps(
                o_ptr,
                _mm256_add_ps(
                    temp_avx,
                    _mm256_loadu_ps(b_ptr)
                )
            );
            i_ptr += 8;
            o_ptr += 8;
            w_ptr += 8;
            b_ptr += 8;
        }
        __m128 rstd_sse = _mm_set1_ps(rstd);
        while (i_ptr != i_sse_end) {
            temp_sse = _mm_sub_ps(_mm_loadu_ps(i_ptr), mean_sse);
            temp_sse = _mm_mul_ps(temp_sse, rstd_sse);
            temp_sse = _mm_mul_ps(temp_sse, _mm_loadu_ps(w_ptr));
            _mm_storeu_ps(
                o_ptr,
                _mm_add_ps(temp_sse, _mm_loadu_ps(b_ptr))
            );
            i_ptr += 4;
            o_ptr += 4;
            w_ptr += 4;
            b_ptr += 4;
        }
        while (i_ptr != i_end)
            *o_ptr++ = (*i_ptr++ - mean) * rstd *
                       *w_ptr++ + *b_ptr++;
        means[i] = mean;
        rstds[i] = rstd;
    }
}
#endif
The SSE-only implementation (used when AVX is not available) is as follows:
#if defined(__SSE__) && defined(__SSE3__)
void layernorm_tokens_sse(
float * restrict output,
float * restrict means,
float * restrict rstds,
const float * restrict input,
const float * restrict weights,
const float * restrict biases,
uint32_t num_channels,
uint32_t num_tokens
)
{
    const uint32_t num_sse_blocks = num_channels >> 2;
    const uint32_t num_sse_items = num_sse_blocks << 2;
#pragma omp parallel for \
    firstprivate(output, means, rstds, input, \
                 weights, biases, num_channels, num_sse_items)
    for (uint32_t i = 0; i < num_tokens; ++i) {
        const uint32_t idx = i * num_channels;
        const float *input_ptr = input + idx;
        const float *i_ptr = input_ptr;
        const float *i_end = i_ptr + num_channels;
        const float *i_sse_end = i_ptr + num_sse_items;

        /* Sum the channels: 4 at a time with SSE, then the
         * remainder with scalar code. */
        __m128 sum_sse = _mm_loadu_ps(i_ptr);
        i_ptr += 4;
        while (i_ptr != i_sse_end) {
            sum_sse = _mm_add_ps(sum_sse, _mm_loadu_ps(i_ptr));
            i_ptr += 4;
        }
        sum_sse = _mm_hadd_ps(sum_sse, sum_sse);
        sum_sse = _mm_hadd_ps(sum_sse, sum_sse);

        float sum = 0.0f;
        _mm_store_ss(&sum, sum_sse);

        while (i_ptr != i_end)
            sum += *i_ptr++;

        float mean = sum / num_channels;
#if defined(DEBUG)
        if (i == 0)
            printf("sum: %.*e\nmean: %.*e\n",
                   DECIMAL_DIG, sum, DECIMAL_DIG, mean);
#endif
        /* Second pass: accumulate squared deviations for the
         * biased variance. */
        i_ptr = input_ptr;
        __m128 mean_sse = _mm_set1_ps(mean);
        __m128 temp_sse = _mm_sub_ps(_mm_loadu_ps(i_ptr), mean_sse);
        sum_sse = _mm_mul_ps(temp_sse, temp_sse);
        i_ptr += 4;
        while (i_ptr != i_sse_end) {
            temp_sse = _mm_sub_ps(_mm_loadu_ps(i_ptr), mean_sse);
            sum_sse = _mm_add_ps(
                sum_sse,
                _mm_mul_ps(temp_sse, temp_sse)
            );
            i_ptr += 4;
        }
        sum_sse = _mm_hadd_ps(sum_sse, sum_sse);
        sum_sse = _mm_hadd_ps(sum_sse, sum_sse);
        _mm_store_ss(&sum, sum_sse);

        while (i_ptr != i_end) {
            const float deviation = *i_ptr++ - mean;
            sum += deviation * deviation;
        }
        const float biased_variance = sum / num_channels;
        float rstd = 0.0f;
        Q_rsqrt(&rstd, biased_variance);
#if defined(DEBUG)
        if (i == 0)
            printf("biased_variance: %.*e\nrstd: %.*e\n",
                   DECIMAL_DIG, biased_variance, DECIMAL_DIG, rstd);
#endif
        /* Normalise, scale and shift each channel. */
        i_ptr = input_ptr;
        float *o_ptr = output + idx;
        const float *w_ptr = weights;
        const float *b_ptr = biases;
        __m128 rstd_sse = _mm_set1_ps(rstd);
        while (i_ptr != i_sse_end) {
            temp_sse = _mm_sub_ps(_mm_loadu_ps(i_ptr), mean_sse);
            temp_sse = _mm_mul_ps(temp_sse, rstd_sse);
            temp_sse = _mm_mul_ps(temp_sse, _mm_loadu_ps(w_ptr));
            _mm_storeu_ps(
                o_ptr,
                _mm_add_ps(temp_sse, _mm_loadu_ps(b_ptr))
            );
            i_ptr += 4;
            o_ptr += 4;
            w_ptr += 4;
            b_ptr += 4;
        }
        while (i_ptr != i_end)
            *o_ptr++ = (*i_ptr++ - mean) * rstd *
                       *w_ptr++ + *b_ptr++;
        means[i] = mean;
        rstds[i] = rstd;
    }
}
#endif
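For completeness, here is a usage sketch (not from the series; the token count, the test values and the identity weights are made up for illustration, while \(C = 768\) matches the GPT-2 small channel count):

#include <stdint.h>
#include <stdlib.h>

int main(void)
{
    const uint32_t T = 64, C = 768;       /* hypothetical batch of tokens */
    float *input   = malloc((size_t)T * C * sizeof(float));
    float *output  = malloc((size_t)T * C * sizeof(float));
    float *means   = malloc(T * sizeof(float));
    float *rstds   = malloc(T * sizeof(float));
    float *weights = malloc(C * sizeof(float));
    float *biases  = calloc(C, sizeof(float));

    for (uint32_t k = 0; k < T * C; ++k)
        input[k] = (float)(k % C) / C;    /* arbitrary test values */
    for (uint32_t j = 0; j < C; ++j)
        weights[j] = 1.0f;                /* identity scale, zero shift */

    /* Dispatches to the AVX, SSE or sequential implementation
     * depending on what was enabled at compile time. */
    layernorm_tokens(output, means, rstds, input,
                     weights, biases, C, T);

    free(input); free(output); free(means); free(rstds);
    free(weights); free(biases);
    return 0;
}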
Performance evaluation and CUDA implementation to be written.