This documentation is automatically generated by competitive-verifier/competitive-verifier
// @brief Matrix multiplication (AVX2)
#define PROBLEM "https://judge.yosupo.jp/problem/matrix_product"
#pragma GCC optimize("Ofast,unroll-loops")
#include <chrono>
#include <cstring>
#include <immintrin.h>
#include <iostream>
#include <type_traits>
#include <vector>
using namespace std;
namespace fastio {
static constexpr size_t buf_size = 1 << 18;
static constexpr size_t integer_size = 20;
static constexpr size_t block_size = 10000;
static char inbuf[buf_size + 1] = {};
static char outbuf[buf_size + 1] = {};
static char block_str[block_size * 4 + 1] = {};
static constexpr uint64_t power10[] = {1,
10,
100,
1000,
10000,
100000,
1000000,
10000000,
100000000,
1000000000,
10000000000,
100000000000,
1000000000000,
10000000000000,
100000000000000,
1000000000000000,
10000000000000000,
100000000000000000,
1000000000000000000,
10000000000000000000u};
struct Scanner {
private:
size_t pos, end;
void load() {
end = fread(inbuf, 1, buf_size, stdin);
inbuf[end] = '\0';
}
void reload() {
size_t len = end - pos;
memmove(inbuf, inbuf + pos, len);
end = len + fread(inbuf + len, 1, buf_size - len, stdin);
inbuf[end] = '\0';
pos = 0;
}
void skip_space() {
while (inbuf[pos] <= ' ') {
if (__builtin_expect(++pos == end, 0))
reload();
}
}
char get_next() { return inbuf[pos++]; }
char get_next_nonspace() {
skip_space();
return inbuf[pos++];
}
public:
Scanner() { load(); }
void scan(char &c) { c = get_next_nonspace(); }
void scan(std::string &s) {
skip_space();
s = "";
do {
size_t start = pos;
while (inbuf[pos] > ' ')
pos++;
s += std::string(inbuf + start, inbuf + pos);
if (inbuf[pos] != '\0')
break;
reload();
} while (true);
}
template <class T>
typename std::enable_if<std::is_integral<T>::value, void>::type scan(T &x) {
char c = get_next_nonspace();
if (__builtin_expect(pos + integer_size >= end, 0))
reload();
bool neg = false;
if (c == '-')
neg = true, x = 0;
else
x = c & 15;
while ((c = get_next()) >= '0')
x = x * 10 + (c & 15);
if (neg)
x = -x;
}
template <class Head, class... Others>
void scan(Head &head, Others &...others) {
scan(head);
scan(others...);
}
template <class T> Scanner &operator>>(T &x) {
scan(x);
return *this;
}
};
struct Printer {
private:
size_t pos = 0;
void flush() {
fwrite(outbuf, 1, pos, stdout);
pos = 0;
}
void pre_calc() {
for (size_t i = 0; i < block_size; i++) {
size_t j = 4, k = i;
while (j--) {
block_str[i * 4 + j] = k % 10 + '0';
k /= 10;
}
}
}
static constexpr size_t get_integer_size(uint64_t n) {
if (n >= power10[10]) {
if (n >= power10[19])
return 20;
if (n >= power10[18])
return 19;
if (n >= power10[17])
return 18;
if (n >= power10[16])
return 17;
if (n >= power10[15])
return 16;
if (n >= power10[14])
return 15;
if (n >= power10[13])
return 14;
if (n >= power10[12])
return 13;
if (n >= power10[11])
return 12;
return 11;
} else {
if (n >= power10[9])
return 10;
if (n >= power10[8])
return 9;
if (n >= power10[7])
return 8;
if (n >= power10[6])
return 7;
if (n >= power10[5])
return 6;
if (n >= power10[4])
return 5;
if (n >= power10[3])
return 4;
if (n >= power10[2])
return 3;
if (n >= power10[1])
return 2;
return 1;
}
}
public:
Printer() { pre_calc(); }
~Printer() { flush(); }
void print(char c) {
outbuf[pos++] = c;
if (__builtin_expect(pos == buf_size, 0))
flush();
}
void print(const char *s) {
while (*s != 0) {
outbuf[pos++] = *s++;
// if (pos == buf_size) flush();
if (__builtin_expect(pos == buf_size, 0))
flush();
}
}
void print(const std::string &s) {
for (auto c : s) {
outbuf[pos++] = c;
// if (pos == buf_size) flush();
if (__builtin_expect(pos == buf_size, 0))
flush();
}
}
template <class T>
typename std::enable_if<std::is_integral<T>::value, void>::type print(T x) {
if (__builtin_expect(pos + integer_size >= buf_size, 0))
flush();
if (x < 0)
print('-'), x = -x;
size_t digit = get_integer_size(x);
size_t len = digit;
while (len >= 4) {
len -= 4;
memcpy(outbuf + pos + len, block_str + (x % block_size) * 4, 4);
x /= block_size;
}
memcpy(outbuf + pos, block_str + x * 4 + (4 - len), len);
pos += digit;
}
template <class Head, class... Others>
void print(const Head &head, const Others &...others) {
print(head);
print(' ');
print(others...);
}
template <class... Args> void println(const Args &...args) {
print(args...);
print('\n');
}
template <class T> Printer &operator<<(const T &x) {
print(x);
return *this;
}
};
}; // namespace fastio
fastio::Scanner fin;
fastio::Printer fout;
#include <cassert>
using namespace std;
const int SZ = 1024;
const int s1 = 64, s3 = 32, NAIVE = 128;
chrono::high_resolution_clock Clock;
int64_t modmod8 = 8ULL * 998244353 * 998244353;
alignas(32) unsigned A[SZ * SZ], B[SZ * SZ];
alignas(32) uint64_t CC[SZ * SZ];
alignas(32) uint64_t bb[NAIVE * NAIVE];
alignas(32) __m256i tmp[NAIVE * NAIVE / 4], *Bv = (__m256i *)bb;
const __m256i MM8 = {modmod8, modmod8, modmod8, modmod8};
static inline __attribute__((always_inline)) __m256i shrink(__m256i a) {
__m256i mask = _mm256_cmpgt_epi64(a - MM8, __m256i{});
return a - (MM8 & mask);
}
#define LOAD(x) \
t##x = _mm256_load_si256( \
(__m256i *)(c + ((I_TILE * NAIVE + i) * l + K_TILE * NAIVE + i3) + \
x * 4))
#define MADD(x) \
t##x = _mm256_add_epi64( \
t##x, _mm256_mul_epu32(aik, tmp[(i3 / s3) * 8 * NAIVE + 8 * k + x]))
#define STOR(x) \
_mm256_store_si256( \
(__m256i *)(c + ((I_TILE * NAIVE + i) * l + K_TILE * NAIVE + i3) + \
x * 4), \
t##x)
#define TMP(x) \
tmp[(i3 / s3) * 8 * NAIVE + 8 * k + x] = Bv[k * NAIVE / 4 + i3 / 4 + x]
void write_b_to_tmp(const unsigned *b, int J_TILE, int K_TILE, int k) {
for (int i = 0; i < NAIVE; i++)
for (int j = 0; j < NAIVE; j++)
bb[i * NAIVE + j] = b[(J_TILE * NAIVE + i) * k + K_TILE * NAIVE + j];
for (int i3 = 0; i3 < NAIVE; i3 += s3)
for (int k = 0; k < NAIVE; k++)
TMP(0), TMP(1), TMP(2), TMP(3), TMP(4), TMP(5), TMP(6), TMP(7);
}
void avx2_emm_tile(const unsigned *a, uint64_t *c, int I_TILE, int J_TILE,
int K_TILE, int n, int l) {
for (int i3 = 0; i3 < NAIVE; i3 += s3)
for (int i1 = 0; i1 < NAIVE; i1 += s1) {
for (int i = 0; i < NAIVE; i++) {
__m256i t0, t1, t2, t3, t4, t5, t6, t7;
LOAD(0), LOAD(1), LOAD(2), LOAD(3);
LOAD(4), LOAD(5), LOAD(6), LOAD(7);
for (int k = i1; k < i1 + s1; k++) {
if (k % 8 == 0) {
t0 = shrink(t0), t1 = shrink(t1), t2 = shrink(t2), t3 = shrink(t3),
t4 = shrink(t4), t5 = shrink(t5), t6 = shrink(t6), t7 = shrink(t7);
}
__m256i aik = _mm256_set1_epi32(
a[(I_TILE * NAIVE + i) * n + J_TILE * NAIVE + k]);
MADD(0), MADD(1), MADD(2), MADD(3);
MADD(4), MADD(5), MADD(6), MADD(7);
}
STOR(0), STOR(1), STOR(2), STOR(3);
STOR(4), STOR(5), STOR(6), STOR(7);
}
}
}
void emm(const unsigned *a, const unsigned *b, uint64_t *c, int m, int n,
int k) {
for (int J_TILE = 0; J_TILE * NAIVE < n; J_TILE++)
for (int K_TILE = 0; K_TILE * NAIVE < k; K_TILE++) {
write_b_to_tmp(b, J_TILE, K_TILE, k);
for (int I_TILE = 0; I_TILE * NAIVE < m; I_TILE++)
avx2_emm_tile(a, c, I_TILE, J_TILE, K_TILE, n, k);
}
}
signed main() {
int n, m, p;
fin >> n >> m >> p;
for (int i = 0; i < n; i++)
for (int j = 0; j < m; j++)
fin >> A[i * SZ + j];
for (int i = 0; i < m; i++)
for (int j = 0; j < p; j++)
fin >> B[i * SZ + j];
emm(A, B, CC, SZ, SZ, SZ);
for (int i = 0; i < n; i++)
for (int j = 0; j < p; j++) {
fout << CC[i * SZ + j] % 998244353;
fout << (j == p - 1 ? '\n' : ' ');
}
}
#line 1 "yoshi_likes_e4/matmul.test.cpp"
// @brief Matrix multiplication (AVX2)
#define PROBLEM "https://judge.yosupo.jp/problem/matrix_product"
#pragma GCC optimize("Ofast,unroll-loops")
#include <chrono>
#include <cstring>
#include <immintrin.h>
#include <iostream>
#include <type_traits>
#include <vector>
using namespace std;
namespace fastio {
static constexpr size_t buf_size = 1 << 18;
static constexpr size_t integer_size = 20;
static constexpr size_t block_size = 10000;
static char inbuf[buf_size + 1] = {};
static char outbuf[buf_size + 1] = {};
static char block_str[block_size * 4 + 1] = {};
static constexpr uint64_t power10[] = {1,
10,
100,
1000,
10000,
100000,
1000000,
10000000,
100000000,
1000000000,
10000000000,
100000000000,
1000000000000,
10000000000000,
100000000000000,
1000000000000000,
10000000000000000,
100000000000000000,
1000000000000000000,
10000000000000000000u};
struct Scanner {
private:
size_t pos, end;
void load() {
end = fread(inbuf, 1, buf_size, stdin);
inbuf[end] = '\0';
}
void reload() {
size_t len = end - pos;
memmove(inbuf, inbuf + pos, len);
end = len + fread(inbuf + len, 1, buf_size - len, stdin);
inbuf[end] = '\0';
pos = 0;
}
void skip_space() {
while (inbuf[pos] <= ' ') {
if (__builtin_expect(++pos == end, 0))
reload();
}
}
char get_next() { return inbuf[pos++]; }
char get_next_nonspace() {
skip_space();
return inbuf[pos++];
}
public:
Scanner() { load(); }
void scan(char &c) { c = get_next_nonspace(); }
void scan(std::string &s) {
skip_space();
s = "";
do {
size_t start = pos;
while (inbuf[pos] > ' ')
pos++;
s += std::string(inbuf + start, inbuf + pos);
if (inbuf[pos] != '\0')
break;
reload();
} while (true);
}
template <class T>
typename std::enable_if<std::is_integral<T>::value, void>::type scan(T &x) {
char c = get_next_nonspace();
if (__builtin_expect(pos + integer_size >= end, 0))
reload();
bool neg = false;
if (c == '-')
neg = true, x = 0;
else
x = c & 15;
while ((c = get_next()) >= '0')
x = x * 10 + (c & 15);
if (neg)
x = -x;
}
template <class Head, class... Others>
void scan(Head &head, Others &...others) {
scan(head);
scan(others...);
}
template <class T> Scanner &operator>>(T &x) {
scan(x);
return *this;
}
};
struct Printer {
private:
size_t pos = 0;
void flush() {
fwrite(outbuf, 1, pos, stdout);
pos = 0;
}
void pre_calc() {
for (size_t i = 0; i < block_size; i++) {
size_t j = 4, k = i;
while (j--) {
block_str[i * 4 + j] = k % 10 + '0';
k /= 10;
}
}
}
static constexpr size_t get_integer_size(uint64_t n) {
if (n >= power10[10]) {
if (n >= power10[19])
return 20;
if (n >= power10[18])
return 19;
if (n >= power10[17])
return 18;
if (n >= power10[16])
return 17;
if (n >= power10[15])
return 16;
if (n >= power10[14])
return 15;
if (n >= power10[13])
return 14;
if (n >= power10[12])
return 13;
if (n >= power10[11])
return 12;
return 11;
} else {
if (n >= power10[9])
return 10;
if (n >= power10[8])
return 9;
if (n >= power10[7])
return 8;
if (n >= power10[6])
return 7;
if (n >= power10[5])
return 6;
if (n >= power10[4])
return 5;
if (n >= power10[3])
return 4;
if (n >= power10[2])
return 3;
if (n >= power10[1])
return 2;
return 1;
}
}
public:
Printer() { pre_calc(); }
~Printer() { flush(); }
void print(char c) {
outbuf[pos++] = c;
if (__builtin_expect(pos == buf_size, 0))
flush();
}
void print(const char *s) {
while (*s != 0) {
outbuf[pos++] = *s++;
// if (pos == buf_size) flush();
if (__builtin_expect(pos == buf_size, 0))
flush();
}
}
void print(const std::string &s) {
for (auto c : s) {
outbuf[pos++] = c;
// if (pos == buf_size) flush();
if (__builtin_expect(pos == buf_size, 0))
flush();
}
}
template <class T>
typename std::enable_if<std::is_integral<T>::value, void>::type print(T x) {
if (__builtin_expect(pos + integer_size >= buf_size, 0))
flush();
if (x < 0)
print('-'), x = -x;
size_t digit = get_integer_size(x);
size_t len = digit;
while (len >= 4) {
len -= 4;
memcpy(outbuf + pos + len, block_str + (x % block_size) * 4, 4);
x /= block_size;
}
memcpy(outbuf + pos, block_str + x * 4 + (4 - len), len);
pos += digit;
}
template <class Head, class... Others>
void print(const Head &head, const Others &...others) {
print(head);
print(' ');
print(others...);
}
template <class... Args> void println(const Args &...args) {
print(args...);
print('\n');
}
template <class T> Printer &operator<<(const T &x) {
print(x);
return *this;
}
};
}; // namespace fastio
fastio::Scanner fin;
fastio::Printer fout;
#include <cassert>
using namespace std;
const int SZ = 1024;
const int s1 = 64, s3 = 32, NAIVE = 128;
chrono::high_resolution_clock Clock;
int64_t modmod8 = 8ULL * 998244353 * 998244353;
alignas(32) unsigned A[SZ * SZ], B[SZ * SZ];
alignas(32) uint64_t CC[SZ * SZ];
alignas(32) uint64_t bb[NAIVE * NAIVE];
alignas(32) __m256i tmp[NAIVE * NAIVE / 4], *Bv = (__m256i *)bb;
const __m256i MM8 = {modmod8, modmod8, modmod8, modmod8};
static inline __attribute__((always_inline)) __m256i shrink(__m256i a) {
__m256i mask = _mm256_cmpgt_epi64(a - MM8, __m256i{});
return a - (MM8 & mask);
}
#define LOAD(x) \
t##x = _mm256_load_si256( \
(__m256i *)(c + ((I_TILE * NAIVE + i) * l + K_TILE * NAIVE + i3) + \
x * 4))
#define MADD(x) \
t##x = _mm256_add_epi64( \
t##x, _mm256_mul_epu32(aik, tmp[(i3 / s3) * 8 * NAIVE + 8 * k + x]))
#define STOR(x) \
_mm256_store_si256( \
(__m256i *)(c + ((I_TILE * NAIVE + i) * l + K_TILE * NAIVE + i3) + \
x * 4), \
t##x)
#define TMP(x) \
tmp[(i3 / s3) * 8 * NAIVE + 8 * k + x] = Bv[k * NAIVE / 4 + i3 / 4 + x]
void write_b_to_tmp(const unsigned *b, int J_TILE, int K_TILE, int k) {
for (int i = 0; i < NAIVE; i++)
for (int j = 0; j < NAIVE; j++)
bb[i * NAIVE + j] = b[(J_TILE * NAIVE + i) * k + K_TILE * NAIVE + j];
for (int i3 = 0; i3 < NAIVE; i3 += s3)
for (int k = 0; k < NAIVE; k++)
TMP(0), TMP(1), TMP(2), TMP(3), TMP(4), TMP(5), TMP(6), TMP(7);
}
void avx2_emm_tile(const unsigned *a, uint64_t *c, int I_TILE, int J_TILE,
int K_TILE, int n, int l) {
for (int i3 = 0; i3 < NAIVE; i3 += s3)
for (int i1 = 0; i1 < NAIVE; i1 += s1) {
for (int i = 0; i < NAIVE; i++) {
__m256i t0, t1, t2, t3, t4, t5, t6, t7;
LOAD(0), LOAD(1), LOAD(2), LOAD(3);
LOAD(4), LOAD(5), LOAD(6), LOAD(7);
for (int k = i1; k < i1 + s1; k++) {
if (k % 8 == 0) {
t0 = shrink(t0), t1 = shrink(t1), t2 = shrink(t2), t3 = shrink(t3),
t4 = shrink(t4), t5 = shrink(t5), t6 = shrink(t6), t7 = shrink(t7);
}
__m256i aik = _mm256_set1_epi32(
a[(I_TILE * NAIVE + i) * n + J_TILE * NAIVE + k]);
MADD(0), MADD(1), MADD(2), MADD(3);
MADD(4), MADD(5), MADD(6), MADD(7);
}
STOR(0), STOR(1), STOR(2), STOR(3);
STOR(4), STOR(5), STOR(6), STOR(7);
}
}
}
void emm(const unsigned *a, const unsigned *b, uint64_t *c, int m, int n,
int k) {
for (int J_TILE = 0; J_TILE * NAIVE < n; J_TILE++)
for (int K_TILE = 0; K_TILE * NAIVE < k; K_TILE++) {
write_b_to_tmp(b, J_TILE, K_TILE, k);
for (int I_TILE = 0; I_TILE * NAIVE < m; I_TILE++)
avx2_emm_tile(a, c, I_TILE, J_TILE, K_TILE, n, k);
}
}
signed main() {
int n, m, p;
fin >> n >> m >> p;
for (int i = 0; i < n; i++)
for (int j = 0; j < m; j++)
fin >> A[i * SZ + j];
for (int i = 0; i < m; i++)
for (int j = 0; j < p; j++)
fin >> B[i * SZ + j];
emm(A, B, CC, SZ, SZ, SZ);
for (int i = 0; i < n; i++)
for (int j = 0; j < p; j++) {
fout << CC[i * SZ + j] % 998244353;
fout << (j == p - 1 ? '\n' : ' ');
}
}
Env | Name | Status | Elapsed | Memory |
---|---|---|---|---|
g++ | example_00 |
![]() |
82 ms | 15 MB |
g++ | example_01 |
![]() |
77 ms | 16 MB |
g++ | example_02 |
![]() |
77 ms | 16 MB |
g++ | max_random_00 |
![]() |
122 ms | 21 MB |
g++ | max_random_01 |
![]() |
121 ms | 21 MB |
g++ | max_random_02 |
![]() |
121 ms | 21 MB |
g++ | random_00 |
![]() |
105 ms | 21 MB |
g++ | random_01 |
![]() |
100 ms | 20 MB |
g++ | random_02 |
![]() |
95 ms | 21 MB |
g++ | signed_overflow_00 |
![]() |
83 ms | 16 MB |
g++ | small_00 |
![]() |
79 ms | 15 MB |
g++ | small_01 |
![]() |
79 ms | 15 MB |
g++ | small_02 |
![]() |
79 ms | 16 MB |
g++ | small_03 |
![]() |
84 ms | 15 MB |
g++ | small_04 |
![]() |
79 ms | 15 MB |
g++ | small_05 |
![]() |
78 ms | 16 MB |
g++ | small_06 |
![]() |
80 ms | 19 MB |
g++ | small_07 |
![]() |
79 ms | 16 MB |
g++ | small_08 |
![]() |
79 ms | 18 MB |
g++ | small_09 |
![]() |
79 ms | 21 MB |
g++ | unsigned_overflow_00 |
![]() |
80 ms | 16 MB |