Hi, thanks for the feedback! That's a good point. I did compare against torch, but at a high enough sequence length (~1024) the torch version starts to OOM because it has to materialize the full S^2 matrix in global memory. At small sequence lengths torch does win, purely on its optimised cuBLAS matmuls.
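For a rough feel of the quadratic growth, here is a back-of-envelope sketch. The batch size, head count and fp32 element size are assumptions picked for illustration, not the settings from my benchmark, and `score_matrix_bytes` is just a hypothetical helper name:

```python
def score_matrix_bytes(seq_len, batch=1, heads=1, bytes_per_elem=4):
    """Bytes needed to hold one S x S matrix per (batch, head) pair.

    All defaults are illustrative assumptions (fp32, one batch, one head).
    """
    return batch * heads * seq_len * seq_len * bytes_per_elem


# Assumed shapes for illustration only: batch=8, heads=16, fp32.
for s in (256, 1024, 4096):
    mib = score_matrix_bytes(s, batch=8, heads=16) / 2**20
    print(f"S={s}: {mib:.0f} MiB")
# S=256: 32 MiB
# S=1024: 512 MiB
# S=4096: 8192 MiB
```

This counts only a single S^2 buffer. Any extra copies the unfused version keeps around (e.g. a separate softmax output, or tensors saved for backward) would add to it, so the real peak depends on the actual setup.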