cp-library

This documentation is automatically generated by competitive-verifier/competitive-verifier

View the Project on GitHub pt13762104/cp-library

:heavy_check_mark: Matrix multiplication (AVX2) (yoshi_likes_e4/matmul.test.cpp)

Code

// @brief Matrix multiplication (AVX2)
#define PROBLEM "https://judge.yosupo.jp/problem/matrix_product"
#pragma GCC optimize("Ofast,unroll-loops")
#include <chrono>
#include <cstring>
#include <immintrin.h>
#include <iostream>
#include <type_traits>
#include <vector>
using namespace std;
namespace fastio {
static constexpr size_t buf_size = 1 << 18;
static constexpr size_t integer_size = 20;
static constexpr size_t block_size = 10000;

static char inbuf[buf_size + 1] = {};
static char outbuf[buf_size + 1] = {};
static char block_str[block_size * 4 + 1] = {};

static constexpr uint64_t power10[] = {1,
                                       10,
                                       100,
                                       1000,
                                       10000,
                                       100000,
                                       1000000,
                                       10000000,
                                       100000000,
                                       1000000000,
                                       10000000000,
                                       100000000000,
                                       1000000000000,
                                       10000000000000,
                                       100000000000000,
                                       1000000000000000,
                                       10000000000000000,
                                       100000000000000000,
                                       1000000000000000000,
                                       10000000000000000000u};

struct Scanner {
private:
  size_t pos, end;

  void load() {
    end = fread(inbuf, 1, buf_size, stdin);
    inbuf[end] = '\0';
  }
  void reload() {
    size_t len = end - pos;
    memmove(inbuf, inbuf + pos, len);
    end = len + fread(inbuf + len, 1, buf_size - len, stdin);
    inbuf[end] = '\0';
    pos = 0;
  }
  void skip_space() {
    while (inbuf[pos] <= ' ') {
      if (__builtin_expect(++pos == end, 0))
        reload();
    }
  }
  char get_next() { return inbuf[pos++]; }
  char get_next_nonspace() {
    skip_space();
    return inbuf[pos++];
  }

public:
  Scanner() { load(); }

  void scan(char &c) { c = get_next_nonspace(); }
  void scan(std::string &s) {
    skip_space();
    s = "";
    do {
      size_t start = pos;
      while (inbuf[pos] > ' ')
        pos++;
      s += std::string(inbuf + start, inbuf + pos);
      if (inbuf[pos] != '\0')
        break;
      reload();
    } while (true);
  }

  template <class T>
  typename std::enable_if<std::is_integral<T>::value, void>::type scan(T &x) {
    char c = get_next_nonspace();
    if (__builtin_expect(pos + integer_size >= end, 0))
      reload();
    bool neg = false;
    if (c == '-')
      neg = true, x = 0;
    else
      x = c & 15;
    while ((c = get_next()) >= '0')
      x = x * 10 + (c & 15);
    if (neg)
      x = -x;
  }

  template <class Head, class... Others>
  void scan(Head &head, Others &...others) {
    scan(head);
    scan(others...);
  }

  template <class T> Scanner &operator>>(T &x) {
    scan(x);
    return *this;
  }
};

struct Printer {
private:
  size_t pos = 0;

  void flush() {
    fwrite(outbuf, 1, pos, stdout);
    pos = 0;
  }

  void pre_calc() {
    for (size_t i = 0; i < block_size; i++) {
      size_t j = 4, k = i;
      while (j--) {
        block_str[i * 4 + j] = k % 10 + '0';
        k /= 10;
      }
    }
  }

  static constexpr size_t get_integer_size(uint64_t n) {
    if (n >= power10[10]) {
      if (n >= power10[19])
        return 20;
      if (n >= power10[18])
        return 19;
      if (n >= power10[17])
        return 18;
      if (n >= power10[16])
        return 17;
      if (n >= power10[15])
        return 16;
      if (n >= power10[14])
        return 15;
      if (n >= power10[13])
        return 14;
      if (n >= power10[12])
        return 13;
      if (n >= power10[11])
        return 12;
      return 11;
    } else {
      if (n >= power10[9])
        return 10;
      if (n >= power10[8])
        return 9;
      if (n >= power10[7])
        return 8;
      if (n >= power10[6])
        return 7;
      if (n >= power10[5])
        return 6;
      if (n >= power10[4])
        return 5;
      if (n >= power10[3])
        return 4;
      if (n >= power10[2])
        return 3;
      if (n >= power10[1])
        return 2;
      return 1;
    }
  }

public:
  Printer() { pre_calc(); }
  ~Printer() { flush(); }

  void print(char c) {
    outbuf[pos++] = c;
    if (__builtin_expect(pos == buf_size, 0))
      flush();
  }
  void print(const char *s) {
    while (*s != 0) {
      outbuf[pos++] = *s++;
      // if (pos == buf_size) flush();
      if (__builtin_expect(pos == buf_size, 0))
        flush();
    }
  }
  void print(const std::string &s) {
    for (auto c : s) {
      outbuf[pos++] = c;
      // if (pos == buf_size) flush();
      if (__builtin_expect(pos == buf_size, 0))
        flush();
    }
  }

  template <class T>
  typename std::enable_if<std::is_integral<T>::value, void>::type print(T x) {
    if (__builtin_expect(pos + integer_size >= buf_size, 0))
      flush();
    if (x < 0)
      print('-'), x = -x;
    size_t digit = get_integer_size(x);
    size_t len = digit;
    while (len >= 4) {
      len -= 4;
      memcpy(outbuf + pos + len, block_str + (x % block_size) * 4, 4);
      x /= block_size;
    }
    memcpy(outbuf + pos, block_str + x * 4 + (4 - len), len);
    pos += digit;
  }

  template <class Head, class... Others>
  void print(const Head &head, const Others &...others) {
    print(head);
    print(' ');
    print(others...);
  }

  template <class... Args> void println(const Args &...args) {
    print(args...);
    print('\n');
  }

  template <class T> Printer &operator<<(const T &x) {
    print(x);
    return *this;
  }
};
}; // namespace fastio

fastio::Scanner fin;
fastio::Printer fout;
#include <cassert>
using namespace std;
const int SZ = 1024;
const int s1 = 64, s3 = 32, NAIVE = 128;
chrono::high_resolution_clock Clock;
int64_t modmod8 = 8ULL * 998244353 * 998244353;
alignas(32) unsigned A[SZ * SZ], B[SZ * SZ];
alignas(32) uint64_t CC[SZ * SZ];
alignas(32) uint64_t bb[NAIVE * NAIVE];
alignas(32) __m256i tmp[NAIVE * NAIVE / 4], *Bv = (__m256i *)bb;
const __m256i MM8 = {modmod8, modmod8, modmod8, modmod8};
static inline __attribute__((always_inline)) __m256i shrink(__m256i a) {
  __m256i mask = _mm256_cmpgt_epi64(a - MM8, __m256i{});
  return a - (MM8 & mask);
}
#define LOAD(x)                                                                \
  t##x = _mm256_load_si256(                                                    \
      (__m256i *)(c + ((I_TILE * NAIVE + i) * l + K_TILE * NAIVE + i3) +       \
                  x * 4))
#define MADD(x)                                                                \
  t##x = _mm256_add_epi64(                                                     \
      t##x, _mm256_mul_epu32(aik, tmp[(i3 / s3) * 8 * NAIVE + 8 * k + x]))
#define STOR(x)                                                                \
  _mm256_store_si256(                                                          \
      (__m256i *)(c + ((I_TILE * NAIVE + i) * l + K_TILE * NAIVE + i3) +       \
                  x * 4),                                                      \
      t##x)
#define TMP(x)                                                                 \
  tmp[(i3 / s3) * 8 * NAIVE + 8 * k + x] = Bv[k * NAIVE / 4 + i3 / 4 + x]

void write_b_to_tmp(const unsigned *b, int J_TILE, int K_TILE, int k) {
  for (int i = 0; i < NAIVE; i++)
    for (int j = 0; j < NAIVE; j++)
      bb[i * NAIVE + j] = b[(J_TILE * NAIVE + i) * k + K_TILE * NAIVE + j];
  for (int i3 = 0; i3 < NAIVE; i3 += s3)
    for (int k = 0; k < NAIVE; k++)
      TMP(0), TMP(1), TMP(2), TMP(3), TMP(4), TMP(5), TMP(6), TMP(7);
}
void avx2_emm_tile(const unsigned *a, uint64_t *c, int I_TILE, int J_TILE,
                   int K_TILE, int n, int l) {
  for (int i3 = 0; i3 < NAIVE; i3 += s3)
    for (int i1 = 0; i1 < NAIVE; i1 += s1) {
      for (int i = 0; i < NAIVE; i++) {
        __m256i t0, t1, t2, t3, t4, t5, t6, t7;
        LOAD(0), LOAD(1), LOAD(2), LOAD(3);
        LOAD(4), LOAD(5), LOAD(6), LOAD(7);
        for (int k = i1; k < i1 + s1; k++) {
          if (k % 8 == 0) {
            t0 = shrink(t0), t1 = shrink(t1), t2 = shrink(t2), t3 = shrink(t3),
            t4 = shrink(t4), t5 = shrink(t5), t6 = shrink(t6), t7 = shrink(t7);
          }
          __m256i aik = _mm256_set1_epi32(
              a[(I_TILE * NAIVE + i) * n + J_TILE * NAIVE + k]);
          MADD(0), MADD(1), MADD(2), MADD(3);
          MADD(4), MADD(5), MADD(6), MADD(7);
        }
        STOR(0), STOR(1), STOR(2), STOR(3);
        STOR(4), STOR(5), STOR(6), STOR(7);
      }
    }
}
void emm(const unsigned *a, const unsigned *b, uint64_t *c, int m, int n,
         int k) {
  for (int J_TILE = 0; J_TILE * NAIVE < n; J_TILE++)
    for (int K_TILE = 0; K_TILE * NAIVE < k; K_TILE++) {
      write_b_to_tmp(b, J_TILE, K_TILE, k);
      for (int I_TILE = 0; I_TILE * NAIVE < m; I_TILE++)
        avx2_emm_tile(a, c, I_TILE, J_TILE, K_TILE, n, k);
    }
}
signed main() {
  int n, m, p;
  fin >> n >> m >> p;
  for (int i = 0; i < n; i++)
    for (int j = 0; j < m; j++)
      fin >> A[i * SZ + j];
  for (int i = 0; i < m; i++)
    for (int j = 0; j < p; j++)
      fin >> B[i * SZ + j];
  emm(A, B, CC, SZ, SZ, SZ);
  for (int i = 0; i < n; i++)
    for (int j = 0; j < p; j++) {
      fout << CC[i * SZ + j] % 998244353;
      fout << (j == p - 1 ? '\n' : ' ');
    }
}
#line 1 "yoshi_likes_e4/matmul.test.cpp"
// @brief Matrix multiplication (AVX2)
#define PROBLEM "https://judge.yosupo.jp/problem/matrix_product"
#pragma GCC optimize("Ofast,unroll-loops")
#include <chrono>
#include <cstring>
#include <immintrin.h>
#include <iostream>
#include <type_traits>
#include <vector>
using namespace std;
namespace fastio {
static constexpr size_t buf_size = 1 << 18;
static constexpr size_t integer_size = 20;
static constexpr size_t block_size = 10000;

static char inbuf[buf_size + 1] = {};
static char outbuf[buf_size + 1] = {};
static char block_str[block_size * 4 + 1] = {};

static constexpr uint64_t power10[] = {1,
                                       10,
                                       100,
                                       1000,
                                       10000,
                                       100000,
                                       1000000,
                                       10000000,
                                       100000000,
                                       1000000000,
                                       10000000000,
                                       100000000000,
                                       1000000000000,
                                       10000000000000,
                                       100000000000000,
                                       1000000000000000,
                                       10000000000000000,
                                       100000000000000000,
                                       1000000000000000000,
                                       10000000000000000000u};

struct Scanner {
private:
  size_t pos, end;

  void load() {
    end = fread(inbuf, 1, buf_size, stdin);
    inbuf[end] = '\0';
  }
  void reload() {
    size_t len = end - pos;
    memmove(inbuf, inbuf + pos, len);
    end = len + fread(inbuf + len, 1, buf_size - len, stdin);
    inbuf[end] = '\0';
    pos = 0;
  }
  void skip_space() {
    while (inbuf[pos] <= ' ') {
      if (__builtin_expect(++pos == end, 0))
        reload();
    }
  }
  char get_next() { return inbuf[pos++]; }
  char get_next_nonspace() {
    skip_space();
    return inbuf[pos++];
  }

public:
  Scanner() { load(); }

  void scan(char &c) { c = get_next_nonspace(); }
  void scan(std::string &s) {
    skip_space();
    s = "";
    do {
      size_t start = pos;
      while (inbuf[pos] > ' ')
        pos++;
      s += std::string(inbuf + start, inbuf + pos);
      if (inbuf[pos] != '\0')
        break;
      reload();
    } while (true);
  }

  template <class T>
  typename std::enable_if<std::is_integral<T>::value, void>::type scan(T &x) {
    char c = get_next_nonspace();
    if (__builtin_expect(pos + integer_size >= end, 0))
      reload();
    bool neg = false;
    if (c == '-')
      neg = true, x = 0;
    else
      x = c & 15;
    while ((c = get_next()) >= '0')
      x = x * 10 + (c & 15);
    if (neg)
      x = -x;
  }

  template <class Head, class... Others>
  void scan(Head &head, Others &...others) {
    scan(head);
    scan(others...);
  }

  template <class T> Scanner &operator>>(T &x) {
    scan(x);
    return *this;
  }
};

struct Printer {
private:
  size_t pos = 0;

  void flush() {
    fwrite(outbuf, 1, pos, stdout);
    pos = 0;
  }

  void pre_calc() {
    for (size_t i = 0; i < block_size; i++) {
      size_t j = 4, k = i;
      while (j--) {
        block_str[i * 4 + j] = k % 10 + '0';
        k /= 10;
      }
    }
  }

  static constexpr size_t get_integer_size(uint64_t n) {
    if (n >= power10[10]) {
      if (n >= power10[19])
        return 20;
      if (n >= power10[18])
        return 19;
      if (n >= power10[17])
        return 18;
      if (n >= power10[16])
        return 17;
      if (n >= power10[15])
        return 16;
      if (n >= power10[14])
        return 15;
      if (n >= power10[13])
        return 14;
      if (n >= power10[12])
        return 13;
      if (n >= power10[11])
        return 12;
      return 11;
    } else {
      if (n >= power10[9])
        return 10;
      if (n >= power10[8])
        return 9;
      if (n >= power10[7])
        return 8;
      if (n >= power10[6])
        return 7;
      if (n >= power10[5])
        return 6;
      if (n >= power10[4])
        return 5;
      if (n >= power10[3])
        return 4;
      if (n >= power10[2])
        return 3;
      if (n >= power10[1])
        return 2;
      return 1;
    }
  }

public:
  Printer() { pre_calc(); }
  ~Printer() { flush(); }

  void print(char c) {
    outbuf[pos++] = c;
    if (__builtin_expect(pos == buf_size, 0))
      flush();
  }
  void print(const char *s) {
    while (*s != 0) {
      outbuf[pos++] = *s++;
      // if (pos == buf_size) flush();
      if (__builtin_expect(pos == buf_size, 0))
        flush();
    }
  }
  void print(const std::string &s) {
    for (auto c : s) {
      outbuf[pos++] = c;
      // if (pos == buf_size) flush();
      if (__builtin_expect(pos == buf_size, 0))
        flush();
    }
  }

  template <class T>
  typename std::enable_if<std::is_integral<T>::value, void>::type print(T x) {
    if (__builtin_expect(pos + integer_size >= buf_size, 0))
      flush();
    if (x < 0)
      print('-'), x = -x;
    size_t digit = get_integer_size(x);
    size_t len = digit;
    while (len >= 4) {
      len -= 4;
      memcpy(outbuf + pos + len, block_str + (x % block_size) * 4, 4);
      x /= block_size;
    }
    memcpy(outbuf + pos, block_str + x * 4 + (4 - len), len);
    pos += digit;
  }

  template <class Head, class... Others>
  void print(const Head &head, const Others &...others) {
    print(head);
    print(' ');
    print(others...);
  }

  template <class... Args> void println(const Args &...args) {
    print(args...);
    print('\n');
  }

  template <class T> Printer &operator<<(const T &x) {
    print(x);
    return *this;
  }
};
}; // namespace fastio

fastio::Scanner fin;
fastio::Printer fout;
#include <cassert>
using namespace std;
const int SZ = 1024;
const int s1 = 64, s3 = 32, NAIVE = 128;
chrono::high_resolution_clock Clock;
int64_t modmod8 = 8ULL * 998244353 * 998244353;
alignas(32) unsigned A[SZ * SZ], B[SZ * SZ];
alignas(32) uint64_t CC[SZ * SZ];
alignas(32) uint64_t bb[NAIVE * NAIVE];
alignas(32) __m256i tmp[NAIVE * NAIVE / 4], *Bv = (__m256i *)bb;
const __m256i MM8 = {modmod8, modmod8, modmod8, modmod8};
static inline __attribute__((always_inline)) __m256i shrink(__m256i a) {
  __m256i mask = _mm256_cmpgt_epi64(a - MM8, __m256i{});
  return a - (MM8 & mask);
}
#define LOAD(x)                                                                \
  t##x = _mm256_load_si256(                                                    \
      (__m256i *)(c + ((I_TILE * NAIVE + i) * l + K_TILE * NAIVE + i3) +       \
                  x * 4))
#define MADD(x)                                                                \
  t##x = _mm256_add_epi64(                                                     \
      t##x, _mm256_mul_epu32(aik, tmp[(i3 / s3) * 8 * NAIVE + 8 * k + x]))
#define STOR(x)                                                                \
  _mm256_store_si256(                                                          \
      (__m256i *)(c + ((I_TILE * NAIVE + i) * l + K_TILE * NAIVE + i3) +       \
                  x * 4),                                                      \
      t##x)
#define TMP(x)                                                                 \
  tmp[(i3 / s3) * 8 * NAIVE + 8 * k + x] = Bv[k * NAIVE / 4 + i3 / 4 + x]

void write_b_to_tmp(const unsigned *b, int J_TILE, int K_TILE, int k) {
  for (int i = 0; i < NAIVE; i++)
    for (int j = 0; j < NAIVE; j++)
      bb[i * NAIVE + j] = b[(J_TILE * NAIVE + i) * k + K_TILE * NAIVE + j];
  for (int i3 = 0; i3 < NAIVE; i3 += s3)
    for (int k = 0; k < NAIVE; k++)
      TMP(0), TMP(1), TMP(2), TMP(3), TMP(4), TMP(5), TMP(6), TMP(7);
}
void avx2_emm_tile(const unsigned *a, uint64_t *c, int I_TILE, int J_TILE,
                   int K_TILE, int n, int l) {
  for (int i3 = 0; i3 < NAIVE; i3 += s3)
    for (int i1 = 0; i1 < NAIVE; i1 += s1) {
      for (int i = 0; i < NAIVE; i++) {
        __m256i t0, t1, t2, t3, t4, t5, t6, t7;
        LOAD(0), LOAD(1), LOAD(2), LOAD(3);
        LOAD(4), LOAD(5), LOAD(6), LOAD(7);
        for (int k = i1; k < i1 + s1; k++) {
          if (k % 8 == 0) {
            t0 = shrink(t0), t1 = shrink(t1), t2 = shrink(t2), t3 = shrink(t3),
            t4 = shrink(t4), t5 = shrink(t5), t6 = shrink(t6), t7 = shrink(t7);
          }
          __m256i aik = _mm256_set1_epi32(
              a[(I_TILE * NAIVE + i) * n + J_TILE * NAIVE + k]);
          MADD(0), MADD(1), MADD(2), MADD(3);
          MADD(4), MADD(5), MADD(6), MADD(7);
        }
        STOR(0), STOR(1), STOR(2), STOR(3);
        STOR(4), STOR(5), STOR(6), STOR(7);
      }
    }
}
void emm(const unsigned *a, const unsigned *b, uint64_t *c, int m, int n,
         int k) {
  for (int J_TILE = 0; J_TILE * NAIVE < n; J_TILE++)
    for (int K_TILE = 0; K_TILE * NAIVE < k; K_TILE++) {
      write_b_to_tmp(b, J_TILE, K_TILE, k);
      for (int I_TILE = 0; I_TILE * NAIVE < m; I_TILE++)
        avx2_emm_tile(a, c, I_TILE, J_TILE, K_TILE, n, k);
    }
}
signed main() {
  int n, m, p;
  fin >> n >> m >> p;
  for (int i = 0; i < n; i++)
    for (int j = 0; j < m; j++)
      fin >> A[i * SZ + j];
  for (int i = 0; i < m; i++)
    for (int j = 0; j < p; j++)
      fin >> B[i * SZ + j];
  emm(A, B, CC, SZ, SZ, SZ);
  for (int i = 0; i < n; i++)
    for (int j = 0; j < p; j++) {
      fout << CC[i * SZ + j] % 998244353;
      fout << (j == p - 1 ? '\n' : ' ');
    }
}

Test cases

Env Name Status Elapsed Memory
g++ example_00 :heavy_check_mark: AC 82 ms 15 MB
g++ example_01 :heavy_check_mark: AC 77 ms 16 MB
g++ example_02 :heavy_check_mark: AC 77 ms 16 MB
g++ max_random_00 :heavy_check_mark: AC 122 ms 21 MB
g++ max_random_01 :heavy_check_mark: AC 121 ms 21 MB
g++ max_random_02 :heavy_check_mark: AC 121 ms 21 MB
g++ random_00 :heavy_check_mark: AC 105 ms 21 MB
g++ random_01 :heavy_check_mark: AC 100 ms 20 MB
g++ random_02 :heavy_check_mark: AC 95 ms 21 MB
g++ signed_overflow_00 :heavy_check_mark: AC 83 ms 16 MB
g++ small_00 :heavy_check_mark: AC 79 ms 15 MB
g++ small_01 :heavy_check_mark: AC 79 ms 15 MB
g++ small_02 :heavy_check_mark: AC 79 ms 16 MB
g++ small_03 :heavy_check_mark: AC 84 ms 15 MB
g++ small_04 :heavy_check_mark: AC 79 ms 15 MB
g++ small_05 :heavy_check_mark: AC 78 ms 16 MB
g++ small_06 :heavy_check_mark: AC 80 ms 19 MB
g++ small_07 :heavy_check_mark: AC 79 ms 16 MB
g++ small_08 :heavy_check_mark: AC 79 ms 18 MB
g++ small_09 :heavy_check_mark: AC 79 ms 21 MB
g++ unsigned_overflow_00 :heavy_check_mark: AC 80 ms 16 MB
Back to top page