Šta je novo?

m1 amx intruction set provaljen

bmaxa

Poštovan
Učlanjen(a)
22.01.2021
Poruke
880
Poena
95
Elem haker je uradio reverse engenering Apple biblioteka, pa sistemom
probaj/pogresi provalio instrukcije.
C makroi nema ni asembler ni kompajler, mozete napraviti
asembler od prilozenog. Sve je lepo dokumentovano.
AMX je kao avx512 samo ima i matricne operacije.
8 X registara po 8 64 bitnih brojeva, 8 Y registara isto,
i rezultat u Z registar koji je matrica 8x8x64.
 
Uzeo sam kao pocetak napusten lib u Rustu, pa krenuo da implementiram instrukciju po instrukciju.
I evo primera prve instrukcije, (Rust).
fma64, najkorisnija instrukcija, koja mnozi x*y+z i to ide po indeksu po redovima. Dakle da bi izmonizili sve redove, potrebno
je pozvati instrukciju 8 puta. U primeru pozivam za prvi i zadnji red.
Kod:
use amx::{prelude::*, XBytes, XRow, YBytes, YRow, ZRow};


fn main() {
    unsafe {
        let mut ctx = amx::AmxCtx::new().unwrap();

        let in_x: Vec<u16> = vec![1;256];
        let in_y: Vec<u16> = vec![3;256];
        let mut in_xf: Vec<f64> = vec![1.0;64];
        let mut in_yf: Vec<f64> = vec![3.0;64];
        let in_zf: Vec<f64> = vec![2.0;64*8];
        for i in 0..64 {
          for j in 0..8{
            in_xf[i] += i as f64;
            in_yf[i] += i as f64;
          }
        }
        ctx.clear();
        ctx.set0();

        for i in 0..8 {
            //ctx.load512(&in_x[i * 32], XRow(i));
            //ctx.load512(&in_y[i * 32], YRow(i));
            ctx.load512(&in_xf[i*8], XRow(i));
            ctx.load512(&in_yf[i*8], YRow(i));
        }
        for i in 0..64 {
            ctx.load512(&in_zf[i*8], ZRow(i));
        }

//        println!("x = {:?}", *(in_x.as_ptr() as *const [[u16; 32]; 8]));
//        println!("y = {:?}", *(in_y.as_ptr() as *const [[u16; 32]; 8]));
       let got_x = std::mem::transmute::<_,[[f64;8];8]>(ctx.read_x());
       let got_y = std::mem::transmute::<_,[[f64;8];8]>(ctx.read_y());
       println!("X");
       printA::<8,8>(&got_x);
       println!("Y");
       printA::<8,8>(&got_y);
/*
            ctx.outer_product_u32_xy_to_z(
                Some(XBytes(x_offset)),
                Some(YBytes(y_offset)),
                ZRow(z_index),
                false, // don't accumulate
            );
            ctx.reduce_u32_to_z();
*/
            ctx.fma64_z(0);
            ctx.fma64_z(7);
/*
*/
//            let got_z = std::mem::transmute::<_,[[u32;16];64]>(ctx.read_z());
            let got_z = std::mem::transmute::<_,[[f64;8];64]>(ctx.read_z());
            println!("Z");
            printA::<64,8>(&got_z);

    }
}
fn printA<const rows:usize,const cols:usize>(a:&[[f64;cols];rows]){
  for i in 0..rows {
    println!("{:?}", a[i])
  }
}
Izlaz izgleda ovako:
Kod:
X
[1.0, 9.0, 17.0, 25.0, 33.0, 41.0, 49.0, 57.0]
[65.0, 73.0, 81.0, 89.0, 97.0, 105.0, 113.0, 121.0]
[129.0, 137.0, 145.0, 153.0, 161.0, 169.0, 177.0, 185.0]
[193.0, 201.0, 209.0, 217.0, 225.0, 233.0, 241.0, 249.0]
[257.0, 265.0, 273.0, 281.0, 289.0, 297.0, 305.0, 313.0]
[321.0, 329.0, 337.0, 345.0, 353.0, 361.0, 369.0, 377.0]
[385.0, 393.0, 401.0, 409.0, 417.0, 425.0, 433.0, 441.0]
[449.0, 457.0, 465.0, 473.0, 481.0, 489.0, 497.0, 505.0]
Y
[3.0, 11.0, 19.0, 27.0, 35.0, 43.0, 51.0, 59.0]
[67.0, 75.0, 83.0, 91.0, 99.0, 107.0, 115.0, 123.0]
[131.0, 139.0, 147.0, 155.0, 163.0, 171.0, 179.0, 187.0]
[195.0, 203.0, 211.0, 219.0, 227.0, 235.0, 243.0, 251.0]
[259.0, 267.0, 275.0, 283.0, 291.0, 299.0, 307.0, 315.0]
[323.0, 331.0, 339.0, 347.0, 355.0, 363.0, 371.0, 379.0]
[387.0, 395.0, 403.0, 411.0, 419.0, 427.0, 435.0, 443.0]
[451.0, 459.0, 467.0, 475.0, 483.0, 491.0, 499.0, 507.0]
Z
[5.0, 29.0, 53.0, 77.0, 101.0, 125.0, 149.0, 173.0]
[2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0]
[2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0]
[2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0]
[2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0]
[2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0]
[2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0]
[202501.0, 206109.0, 209717.0, 213325.0, 216933.0, 220541.0, 224149.0, 227757.0]
[13.0, 101.0, 189.0, 277.0, 365.0, 453.0, 541.0, 629.0]
[2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0]
[2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0]
[2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0]
[2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0]
[2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0]
[2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0]
[206093.0, 209765.0, 213437.0, 217109.0, 220781.0, 224453.0, 228125.0, 231797.0]
[21.0, 173.0, 325.0, 477.0, 629.0, 781.0, 933.0, 1085.0]
[2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0]
[2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0]
[2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0]
[2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0]
[2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0]
[2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0]
[209685.0, 213421.0, 217157.0, 220893.0, 224629.0, 228365.0, 232101.0, 235837.0]
[29.0, 245.0, 461.0, 677.0, 893.0, 1109.0, 1325.0, 1541.0]
[2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0]
[2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0]
[2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0]
[2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0]
[2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0]
[2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0]
[213277.0, 217077.0, 220877.0, 224677.0, 228477.0, 232277.0, 236077.0, 239877.0]
[37.0, 317.0, 597.0, 877.0, 1157.0, 1437.0, 1717.0, 1997.0]
[2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0]
[2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0]
[2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0]
[2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0]
[2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0]
[2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0]
[216869.0, 220733.0, 224597.0, 228461.0, 232325.0, 236189.0, 240053.0, 243917.0]
[45.0, 389.0, 733.0, 1077.0, 1421.0, 1765.0, 2109.0, 2453.0]
[2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0]
[2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0]
[2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0]
[2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0]
[2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0]
[2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0]
[220461.0, 224389.0, 228317.0, 232245.0, 236173.0, 240101.0, 244029.0, 247957.0]
[53.0, 461.0, 869.0, 1277.0, 1685.0, 2093.0, 2501.0, 2909.0]
[2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0]
[2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0]
[2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0]
[2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0]
[2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0]
[2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0]
[224053.0, 228045.0, 232037.0, 236029.0, 240021.0, 244013.0, 248005.0, 251997.0]
[61.0, 533.0, 1005.0, 1477.0, 1949.0, 2421.0, 2893.0, 3365.0]
[2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0]
[2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0]
[2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0]
[2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0]
[2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0]
[2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0, 2.0]
[227645.0, 231701.0, 235757.0, 239813.0, 243869.0, 247925.0, 251981.0, 256037.0]
Ako nekog zanima, mogu da stavim na github kad zavrsim, ali treba Apple M1 ili M2 procesor.
Sad radim na deljenju i korenovanju preko sabiranja i mnozenja, posto je ovo
u sustini namenjeno za neuronske mreze...
 
Evo implementacije za 1/x i sqrt x za Apple AMX instrukcije. Koristio sam u fazonu 8 64 bitnih f64 ko AVX 512 :p
Sve intrukcije sam sam napravio binarnim kodiranje u libu .Mislim da je ovo prvo nesto ozbiljno sto moze sad da se izgugla.
Kod:
fn rcp(&mut self,row:&[f64;8])->[f64;8]{
      let mut rc = [0.0;8];
      let one = [1.0;8];
      let zero = [0.0;8];
      let mut magic:[u64;8] = [0x7FDE6238502484BA;8];
      let mut zv = [0.0;8];
      for (i,mut v) in magic.iter_mut().enumerate() {
        unsafe {*v -= std::mem::transmute::<_,u64>(row[i]);}
      }
      unsafe {
        self.load512(&one,ZRow(1));
        self.load512(&magic,XRow(1));
        self.load512(row,YRow(1));
      }
      self.fms64_vec(1,1,1);
      self.extr_y(1,1);
      for _ in 0..3 {
        self.fma64_vec_x(1,1);
        self.fma64_vec(1,1,1);
        self.extr_x(1,1);
        self.extr_xy(0,1);
        unsafe {self.load512(&zero,ZRow(2));}
        self.fma64_vec(2,0,1);
        self.extr_y(2,1);
      }
      self.extr_xy(1,1);
      self.extr_y(1,1);
      self.fma64_vec(1,1,1);
      unsafe { self.store512(&mut rc,ZRow(1));}
      rc
    }
    fn sqrt(& mut self, s:&[f64;8])->[f64;8] {
      let mut rc = [0.0;8];
      let mut a = *s;
      let mut sqr10 = [1.0;8];
      let mut in_x = [0.0;8];
      let mut in_z = [0.0;8];
      let zero = [0.0;8];
      let zero_point_five = [0.5;8];
      unsafe {
          self.load512(s,XRow(1));
          self.load512(&zero_point_five,YRow(1));
          self.load512(&zero,ZRow(1));
      }
      self.fma64_vec(1,1,1);
      self.extr_x(1,1);
      for (i,mut a) in a.iter_mut().enumerate() {
        while *a > 100.0 {
          *a *= 0.001;
          sqr10[i]*= 10.0;
        }
        if *a < 1.0 { *a = 1.0 }
      }
      for (i,a) in a.iter().enumerate(){
        (in_x[i],in_z[i]) = if *a < 10.0 {
          (0.28, 0.89)
        } else {
          (0.89, 2.8)
        }
      }
      unsafe {
      self.load512(&in_x,XRow(0));
      self.load512(&a,YRow(0));
      self.load512(&in_z,ZRow(0));
      }
      self.fma64_vec(0,0,0);
      self.extr_x(0,0);
      unsafe {
      self.load512(&sqr10,YRow(0));
      self.load512(&zero,ZRow(0));
      }
      self.fma64_vec(0,0,0); // we have estimate
      unsafe {
        self.load512(&zero_point_five,XRow(7));
        self.store512(&mut a,ZRow(0));;
      }
      for i in 0..6 {
        let rcp = self.rcp(&a);
        unsafe{
          self.load512(&a,ZRow(0));
          self.load512(s,XRow(0));
          self.load512(&rcp,YRow(0));
        }
        self.fma64_vec(0,0,0);

        self.extr_y(0,0);
        unsafe { self.load512(&zero,ZRow(0));}
        self.fma64_vec(0,7,0);
        unsafe { self.store512(&mut a,ZRow(0));}
      }
      unsafe {self.store512(&mut rc,ZRow(0));}
      rc
    }
ovako se koristi:
Kod:
            let two = [65536.0;8];
            ctx.fma64_vec(0,0,0);
            ctx.fma64_vec(7,7,7);
                 let start = Instant::now();
                 let mut sum = 0.0;
            for _ in 0..1000000 {
              sum = ctx.sqrt(&two).iter().fold(sum,|a,b|a+b);;
            }
                 let end = start.elapsed();
                 let diff = (end.as_secs()*1000000000+end.subsec_nanos() as u64) as f64 / 1000000000.0;
                 println!("simd time {} sum {}",diff, sum);
 
Poslednja izmena:
I na kraju sam optimizovao ovo, prebacio u registre, mada int operacije nisam mogao komplet, ali zadovoljan sam sada. umesto
stek nizova koristim sada registre, pa je malo komplikovanije za koriscenje.
Kod:
    fn rcp(&mut self,zrow_in:u64,zrow_out:u64){
      let one = [1.0;8];
      let zero = [0.0;8];
      let mut magic:[u64;8] = [0x7FDE6238502484BA;8];
      let mut row = [0;8];
      unsafe { self.store512(&mut row,ZRow(zrow_in as usize));}
      for (i,mut v) in magic.iter_mut().enumerate() {
        *v -= row[i];
      }
      unsafe {
        self.load512(&one,ZRow(1));
        self.load512(&magic,XRow(1));
        self.load512(&row,YRow(1));
        self.load512(&zero,XRow(2));
      }
      self.fms64_vec(1,1,1);
      self.extr_y(1,1);
      for _ in 0..3 {
        self.fma64_vec_x(1,1);
        self.fma64_vec(1,1,1);
        self.extr_x(1,1);
        self.extr_xy(0,1);
        self.fma64_vec_x(2,2);
        self.fma64_vec(2,0,1);
        self.extr_y(2,1);
      }
      self.extr_x(1,1);
      self.fma64_vec(1,1,1);
      self.extr_x(1,0);
      self.fma64_vec_x(zrow_out,0);
    }
    fn sqrt(& mut self, zrow_in:u64,zrow_out:u64){
      let mut a = [0.0f32;8];
      let mut number = [0.0f64;8];
      unsafe {self.store512(&mut number,ZRow(zrow_in as usize));}
      for (ind,v) in number.iter().enumerate() {
        a[ind] = *v as f32;
      }
      let mut i = [0u32;8];
      for (ind,v) in a.iter().enumerate() {
        unsafe {i[ind] = std::mem::transmute::<_,u32>(*v);}
      }
      for mut i in i.iter_mut() {
        *i = 0x5f3759df - (*i >> 1);
      }
      for (ind,v) in i.iter().enumerate() {
        unsafe{ a[ind] = std::mem::transmute::<_,f32>(*v);}
      }
      for mut v in a.iter_mut() {
        *v = *v * ( 1.5 - ( 0.5 * *v * *v * *v));
      }
      for (ind,v) in a.iter().enumerate() {
        number[ind] = *v as f64;
      }
      let zero = [0.0f64;8];
      let three = [3.0f64;8];
      let zero_point_five = [0.5f64;8];
      unsafe {
        self.load512(&number,ZRow(60));
        self.load512(&zero,ZRow(63));
        self.load512(&three,ZRow(62));
        self.load512(&zero_point_five,ZRow(61));
      }
      for _ in 0..8 {
        self.extr_y(60,0);// xn -> X
        self.extr_x(60,0);// xn -> Y
        self.extr_x(63,7);
        self.fma64_vec_x(0,7);// 0 -> Z
        self.fma64_vec(0,0,0);// xn ^2
        self.extr_x(zrow_in,0);// s -> X
        self.extr_y(0,0); // Z -> Y
        self.extr_x(62,7);
        self.fma64_vec_x(0,7);// 3 -> Z
        self.fms64_vec(0,0,0);// 3 - s * xn ^ 2
        self.extr_x(60,0);// xn -> X
        self.extr_y(0,0);// Z -> Y
        self.extr_x(63,7);
        self.fma64_vec_x(0,7); // 0 -> Z
        self.fma64_vec(0,0,0);// xn * (3- s * xn ^2)
        self.extr_x(0,0);// Z -> X
        self.extr_y(61,0); // 0.5 -> Y
        self.extr_x(63,7);
        self.fma64_vec_x(0,7);// 0 -> Z
        self.fma64_vec(0,0,0);// xn * (3 - s * xn^2)/2
        self.extr_x(0,0);
        self.fma64_vec_x(60,0); // result -> Z[60]
      }
      self.extr_y(zrow_in,0);// s -> Y
      self.extr_x(63,7);
      self.fma64_vec_x(zrow_out,7);// 0 -> Z
      self.fma64_vec(zrow_out,0,0);// s * 1/sqrt(s)
    }
E sad, primer koriscenja:
Kod:
            let mut two:[f64;8] = [123456789.01;8];
            let two1 = two;
            ctx.fma64_vec(0,0,0);
            ctx.fma64_vec(7,7,7);
                 let start = Instant::now();
                 let mut sum = 0.0;
            ctx.load512(&two,ZRow(50));
            for _ in 0..1000000 {
              ctx.sqrt(50,51);
              ctx.store512(&mut two,ZRow(51));
              sum+=two.iter().sum::<f64>();
            }
                 let end = start.elapsed();
                 let diff = (end.as_secs()*1000000000+end.subsec_nanos() as u64) as f64 / 1000000000.0;
                 println!("simd time {} sum {}",diff, sum);
                 let start = Instant::now();
                 let mut sum = 0.0;
            for i in 0..1000000 {
              for v in two1 {
                sum+=v.sqrt();
              }
            }
                 let end = start.elapsed();
                 let diff = (end.as_secs()*1000000000+end.subsec_nanos() as u64) as f64 / 1000000000.0;
                 println!("seq time {} sum {}",diff,sum);
            ctx.sqrt(50,51);
            ctx.store512(&mut two,ZRow(51));
            println!("sqrt\n{:?}",two);
            let sqrt=two1[0].sqrt();
            println!("sqrtseq\n{:?}",sqrt);
            let mut rcp = [2.0;8];
            ctx.load512(&rcp,ZRow(63));
            ctx.rcp(63,63);
            ctx.store512(&mut rcp,ZRow(63));
            println!("rcp\n{:?}",rcp);
 
Breaking Bad Cooking GIF by Morphin
 
ispravka sqrt funkcije tako da je duplo brza, nema potrebe da racunam prvi korak njutnove aproksimacije u f32 plus jos sto je bila pogresna tako
da je sada u duplo manje koraka kao f64
Kod:
    fn sqrt(& mut self, zrow_in:u64,zrow_out:u64){
      let mut a = [0.0f32;8];
      let mut number = [0.0f64;8];
      unsafe {self.store512(&mut number,ZRow(zrow_in as usize));}
      for (ind,v) in number.iter().enumerate() {
        a[ind] = *v as f32;
      }
      let mut i = [0u32;8];
      for (ind,v) in a.iter().enumerate() {
        unsafe {i[ind] = std::mem::transmute::<_,u32>(*v);}
      }
      for mut i in i.iter_mut() {
        *i = 0x5f3759df - (*i >> 1);
      }
      for (ind,v) in i.iter().enumerate() {
        unsafe{ a[ind] = std::mem::transmute::<_,f32>(*v);}
      }
      for (ind,v) in a.iter().enumerate() {
        number[ind] = *v as f64;
      }
      let zero = [0.0f64;8];
      let three = [3.0f64;8];
      let zero_point_five = [0.5f64;8];
      unsafe {
        self.load512(&number,ZRow(60));
        self.load512(&zero,ZRow(63));
        self.load512(&three,ZRow(62));
        self.load512(&zero_point_five,ZRow(61));
      }
      for _ in 0..4 {
        self.extr_y(60,0);// xn -> X
        self.extr_x(60,0);// xn -> Y
        self.extr_x(63,7);
        self.fma64_vec_x(0,7);// 0 -> Z
        self.fma64_vec(0,0,0);// xn ^2
        self.extr_x(zrow_in,0);// s -> X
        self.extr_y(0,0); // Z -> Y
        self.extr_x(62,7);
        self.fma64_vec_x(0,7);// 3 -> Z
        self.fms64_vec(0,0,0);// 3 - s * xn ^ 2
        self.extr_x(60,0);// xn -> X
        self.extr_y(0,0);// Z -> Y
        self.extr_x(63,7);
        self.fma64_vec_x(0,7); // 0 -> Z
        self.fma64_vec(0,0,0);// xn * (3- s * xn ^2)
        self.extr_x(0,0);// Z -> X
        self.extr_y(61,0); // 0.5 -> Y
        self.extr_x(63,7);
        self.fma64_vec_x(0,7);// 0 -> Z
        self.fma64_vec(0,0,0);// xn * (3 - s * xn^2)/2
        self.extr_x(0,0);
        self.fma64_vec_x(60,0); // result -> Z[60]
      }
      self.extr_y(zrow_in,0);// s -> Y
      self.extr_x(63,7);
      self.fma64_vec_x(zrow_out,7);// 0 -> Z
      self.extr_x(60,0);
      self.fma64_vec(zrow_out,0,0);// s * 1/sqrt(s)
    }
 
Interesantna i dosta duboka tema :)

Cisto da pokusam da dam neki doprinos:
1) Za microbenchmarke bih preporucion criterion crate koji automatizuje pustanje tvojih funkcija dovoljan broj puta da se stekne statisticka sigurnost, meri wall clock vs. cpu time, plot-uje rezultate itd. Bilo bi interesantno uporediti rezultate naivne implementacije i ovoga sto si ti uradio, kao sto i radis kroz println!.
2) Ne znam mnogo o Apple Sillicon-u a pogotovo ne o SIMD-u na istom. Bilo bi interesantno uporediti perf u odnosu na simd sqrt na x86 (e.g. __m256d _mm256_sqrt_pd).
 
Interesantna i dosta duboka tema :)

Cisto da pokusam da dam neki doprinos:
1) Za microbenchmarke bih preporucion criterion crate koji automatizuje pustanje tvojih funkcija dovoljan broj puta da se stekne statisticka sigurnost, meri wall clock vs. cpu time, plot-uje rezultate itd. Bilo bi interesantno uporediti rezultate naivne implementacije i ovoga sto si ti uradio, kao sto i radis kroz println!.
2) Ne znam mnogo o Apple Sillicon-u a pogotovo ne o SIMD-u na istom. Bilo bi interesantno uporediti perf u odnosu na simd sqrt na x86 (e.g. __m256d _mm256_sqrt_pd).
Ima Apple silikon Neon, kao i bilo koji Aarch64 procesor, ovo su AMX instrukcije, koje su namenjene izracunavanju
neuronskih mreza, i koje Apple krije ko zmija noge, kao neko tajno oruzje...
Znam za Criterion, to je Rust implementacija Haskell Criteriona...
Inace Rust ima implementaciju benchmarka, samo treba napisati benchmark programe.
 
Dodao sam jos instrukcija i malo ubrzao ove funkcije.
Kod:
    fn rcp(&mut self,zrow_in:u64,zrow_out:u64){
      let mut row = [0;8];
      let mut magic:[u64;8] = [0x7FDE6238502484BA;8];
      unsafe { self.store512(&mut row,ZRow(zrow_in as usize));}
      for (i,mut v) in magic.iter_mut().enumerate() {
        *v -= row[i];
      }
      unsafe {
        self.load512(&one,ZRow(1));
        self.load512(&magic,XRow(1));
      }
      self.extr_y(zrow_in,1);
      self.fms64_vec(1,1,1,0);
      self.extr_y(1,1);
      for _ in 0..3 {
        self.fma64_vec_x(1,1);
        self.fma64_vec(1,1,1,0);
        self.extr_x(1,1);
        self.extr_xy(0,1);
        self.fma64_vec_xy(2,0,1,0);
        self.extr_y(2,1);
      }
      self.extr_x(1,1);
      self.fma64_vec(1,1,1,0);
      self.extr_x(1,0);
      self.fma64_vec_x(zrow_out,0);
    }
    fn sqrt(& mut self, zrow_in:u64,zrow_out:u64){
      let mut a = [0.0f32;8];
      let mut number = [0.0f64;8];
      unsafe {self.store512(&mut number,ZRow(zrow_in as usize));}
      for (ind,v) in number.iter().enumerate() {
        a[ind] = *v as f32;
      }
      let mut i = [0u32;8];
      for (ind,v) in a.iter().enumerate() {
        unsafe {i[ind] = std::mem::transmute::<_,u32>(*v);}
      }
      for mut i in i.iter_mut() {
        *i = 0x5f3759df - (*i >> 1);
      }
      for (ind,v) in i.iter().enumerate() {
        unsafe{ a[ind] = std::mem::transmute::<_,f32>(*v);}
      }
      for (ind,v) in a.iter().enumerate() {
        number[ind] = *v as f64;
      }
      unsafe {
        self.load512(&number,ZRow(60));
        self.load512(&three,ZRow(62));
        self.load512(&zero_point_five,ZRow(61));
      }
      for _ in 0..4 {
        self.extr_y(60,0);// xn -> X
        self.extr_x(60,0);// xn -> Y
        self.fma64_vec_xy(0,0,0,0);// xn ^2
        self.extr_x(zrow_in,0);// s -> X
        self.extr_y(0,0); // Z -> Y
        self.extr_x(62,7);
        self.fma64_vec_x(0,7);// 3 -> Z
        self.fms64_vec(0,0,0,0);// 3 - s * xn ^ 2
        self.extr_x(60,0);// xn -> X
        self.extr_y(0,0);// Z -> Y
        self.fma64_vec_xy(0,0,0,0);// xn * (3- s * xn ^2)
        self.extr_x(0,0);// Z -> X
        self.extr_y(61,0); // 0.5 -> Y
        self.fma64_vec_xy(0,0,0,0);// xn * (3 - s * xn^2)/2
        self.extr_x(0,0);
        self.fma64_vec_x(60,0); // result -> Z[60]
      }
      self.extr_y(zrow_in,0);// s -> Y
      self.extr_x(60,0);
      self.fma64_vec_xy(zrow_out,0,0,0);// s * 1/sqrt(s)
    }
 
Btw, cisto da dodam, posto square root aproksimacija nije bas opste poznata stvar, bacite pogled na ovaj video:



Moze biti interesantno za gamer-e, posto je ovaj "trik" iskopan iz Quake 3 code base-a.

Ako dobro citam kod, Ovo sto bmaxa radi je implementacija te logike na nedokumentovanom Apple Silicon koprocesoru. Sve u svemu, bas egzoticno.

@bmaxa - Inace, nije mi bas jasna i dalje cela prica sa AMX-om. Ocekivao sam da nadjem pytorch backend koji koristi AMX, ali deluje mi da to nije slucaj. Sa druge strane, Intelov AMX je podrzan. Ne znam ko je predvidjen kao korisnik ovoga (sem Apple-a interno) ako ne postoji torch/tensorflow implementacija.

I btw, da li se ovo samo igras ili zaista radis na necemu gde su ti potrebne ove low level optimizacije?
 
@bmaxa - Inace, nije mi bas jasna i dalje cela prica sa AMX-om. Ocekivao sam da nadjem pytorch backend koji koristi AMX, ali deluje mi da to nije slucaj. Sa druge strane, Intelov AMX je podrzan. Ne znam ko je predvidjen kao korisnik ovoga (sem Apple-a interno) ako ne postoji torch/tensorflow implementacija.

I btw, da li se ovo samo igras ili zaista radis na necemu gde su ti potrebne ove low level optimizacije?
Planiram da radim nesto sa ovim. Nista ovo ne koristi, jer nema asemblera za instrukcije nego sve moras
binarno da kodiras.
|Jedini nacin je preko Apple biblioteka, osim ovoga :p
Pretabao u makro, i poboljsao kodiranje.
Evo kako izgleda:

Kod:
Kod:
macro_rules! op_in {
{$OP:tt , $operand:tt} => {
    asm!(
        ".align 8\n.word (0x201000 + ({op} << 5) + 0{operand} - ((0{operand} >> 4) * 6))",
        op = const $OP,
        operand = in(reg) $operand
    );}
}
/// Emit an AMX instruction with a 5-bit immediate.
macro_rules!op_imm {{ $OP: tt, $OPERAND: tt}=> {
    asm!(
        ".align 8\n.word 0x00201000 + ({op} << 5) + {operand}",
        op = const $OP,
        operand = const $OPERAND
    );}
}
dakle ima dve varijante, sa immediate operandom, tj konstantom, i sa ulaznim registrem.
Varijanta sa ulaznim registrem, koji je u formatu xnn na Aarch64 mora da se pretvori
u dekadni broj jer asembler ubacuje u obliku recimo x21.
Dakle hexadekadni kako ga vidi (zato se lepi 0), treba pretvoriti u dekadni 21.
I zato ova pretumbacija. Inace instrukcija, kako vidite, je zbir konstante, operacije
koja predstavlja instrukciju, i operanda koji dolazi u run time.
Operand sadrzi sve varijante instrukcija sa izborom registara i samu operaciju
koja treba da se izvrsi.
Instrukcija je validna potpuno, sve se generise u compile time.
evo finalne verzije rcp i sqrt:

Kod:
Kod:
  fn rcp(&mut self,zrow_in:u64,zrow_out:u64){
      let mut row = [0;8];
      let mut magic:[u64;8] = [0x7FDE6238502484BA;8];
      unsafe { self.store512(&mut row,ZRow(zrow_in as usize));}
      for (i,mut v) in magic.iter_mut().enumerate() {
        *v -= row[i];
      }
      unsafe {
        self.load512(&one,ZRow(1));
        self.load512(&magic,XRow(1));
      }
      self.extr_y(zrow_in,1);
      self.fms64_vec(1,1,1,0);
      self.extr_y(1,1);
      for _ in 0..3 {
        self.fma64_vec_x(1,1);
        self.fma64_vec(1,1,1,0);
        self.extr_x(1,1);
        self.extr_xy(0,1);
        self.fma64_vec_xy(2,0,1,0);
        self.extr_y(2,1);
      }
      self.extr_x(1,1);
      self.fma64_vec(1,1,1,0);
      self.extr_x(1,0);
      self.fma64_vec_x(zrow_out,0);
    }
   fn sqrt(& mut self, zrow_in:u64,zrow_out:u64){                                                                                                          [15/1817]
      let mut a = [0.0f32;8];
      let mut number = [0.0f64;8];
      unsafe {self.store512(&mut number,ZRow(zrow_in as usize));}
      for (ind,v) in number.iter().enumerate() {
        a[ind] = *v as f32;
      }
      for mut v in a.iter_mut() {
        unsafe {
          let mut v = std::mem::transmute::<_,*mut u32>(v);
          *v = 0x5f3759df - (*v >> 1);
        }
      }
      for (ind,v) in a.iter().enumerate() {
        number[ind] = *v as f64;
      }
      unsafe {
        self.load512(&number,ZRow(60));
        self.load512(&three,ZRow(62));
        self.load512(&zero_point_five,ZRow(61));
      }
      for _ in 0..3 {
        self.extr_y(60,0);// xn -> X
        self.extr_x(60,0);// xn -> Y
        self.fma64_vec_xy(0,0,0,0);// xn ^2
        self.extr_x(zrow_in,0);// s -> X
        self.extr_y(0,0); // Z -> Y
        self.extr_x(62,7);
        self.fma64_vec_x(0,7);// 3 -> Z
        self.fms64_vec(0,0,0,0);// 3 - s * xn ^ 2
        self.extr_x(60,0);// xn -> X
        self.extr_y(0,0);// Z -> Y
        self.fma64_vec_xy(0,0,0,0);// xn * (3- s * xn ^2)
        self.extr_x(0,0);// Z -> X
        self.extr_y(61,0); // 0.5 -> Y
        self.fma64_vec_xy(0,0,0,0);// xn * (3 - s * xn^2)/2
        self.extr_x(0,0);
        self.fma64_vec_x(60,0); // result -> Z[60]
      }
      self.extr_y(zrow_in,0);// s -> Y
      self.extr_x(60,0);
      self.fma64_vec_xy(zrow_out,0,0,0);// s * 1/sqrt(s)
    }
Sto se performansi tice, najsporije su upravo ove vektorske
dok su matricne 10 puta brze, tako da cu izaci
kad napravim sa matricnim instrukcijama.
 
Interesantno je sto je Apple ove instrukcije ubacio u konzumerske procesore i to 2020., dok je Intel
sa istom pricom uleteo januara 2023. i to samo u Xeon procesorima. AMD kaska, jos ih nema..
nisu nesto na Apple u vektorskom obliku, nema saturacije dovoljne pa mora da se digne vise
threadova, bilo na jednom koru, bilo na vise korova, no ove matricne saturisu mnogo
bolje, na single thtead.
 
10ns za store je puno. Na 3.2ghz to je ~30 cycles. Opet, ne znam kakva je arhitektura CPU-a. Pretpostavljam da je AMX koprocesor nakacen na istu cache hirerhiju?

Btw, mislim da bi bilo super da napises neki blog post ili makar detaljniji readme.md. Ovako je skoro pa nemoguce pratiti sta sve radis (a zanimljivo je).
 
10ns za store je puno. Na 3.2ghz to je ~30 cycles. Opet, ne znam kakva je arhitektura CPU-a. Pretpostavljam da je AMX koprocesor nakacen na istu cache hirerhiju?

Btw, mislim da bi bilo super da napises neki blog post ili makar detaljniji readme.md. Ovako je skoro pa nemoguce pratiti sta sve radis (a zanimljivo je).
10ns kad je u l1, no merim 26ns. readme cu pisati kad zavrsim. Inace i 1.5ns za instrukciju je ogromno, to leti
u range milisekundi samo nakon milion instrukcija... kako god posto je lejtensi isti za sve instrukcije
onda je 16bit 4 puta brze od 64 bit, a 2 puta od 32bit i matricne puta 8 jos na sve to od vektoriskih.
 
E sad je odralo, duplo i vise brze sa amx u mnozenju matrica.
Sabiranje i oduzimanje je sporije zbog velikog broja load/store operacija, ali u paketu ispada duplo
brze i sa sabiranjem i oduzimanjem :p
 
Nazad
Vrh Dno