Commit 9513ea90 authored by Ryan Gutenkunst's avatar Ryan Gutenkunst
Browse files

Clean up move of cusparse to dadi distribution

parent 3741c3b8
......@@ -2,8 +2,8 @@ import atexit
import pycuda.autoinit
from pycuda.tools import clear_context_caches, make_default_context
from dadi.cuda.cusparse import cusparseCreate, cusparseDestroy
from skcuda.cublas import cublasCreate, cublasDestroy, cublasDgeam
from .cusparse import cusparseCreate, cusparseDestroy
ctx = make_default_context()
cusparse_handle = cusparseCreate()
......
#!/usr/bin/env python
"""
Python interface to CUSPARSE functions.
Python interface to CUSPARSE functions for solving batch tridiagonal systems.
Note: this module does not explicitly depend on PyCUDA.
The set up code for this module is heavily based on the cusparse module within scikit-cuda.
It is duplicated here because, as of version 0.5.3, the cusparse module within scikit-cuda raises an error upon import.
My pull request to fix this error and add the batch tridiagonal functions to scikit-cuda was not acted on for
over two months, hence I moved the code into the dadi distribution itself.
"""
from __future__ import absolute_import
import ctypes
import platform
import ctypes, platform, sys
from string import Template
import sys
from skcuda import cuda
#(base) PS C:\Users\rgute\Desktop\dadi-devel\tests> python test_CUDA.py -v C:\Users\rgute\anaconda3\lib\site-packages\skcuda\cublas.py:284: UserWarning: creating CUBLAS context to get version number
# warnings.warn('creating CUBLAS context to get version number')
#test_2d_const_params (__main__.CUDATestCase) ... C:\Users\rgute\anaconda3\lib\site-packages\pycuda\gpuarray.py:183: DeprecationWarning: np.asscalar(a) is deprecated since NumPy v1.16, use a.item() instead
# s = np.asscalar(s)
# Load library:
_linux_version_list = [10.2, 10.1, 10.0, 9.2, 9.1, 9.0, 8.0, 7.5, 7.0, 6.5, 6.0, 5.5, 5.0, 4.0]
......@@ -153,67 +152,6 @@ def cusparseDestroy(handle):
status = _libcusparse.cusparseDestroy(handle)
cusparseCheckStatus(status)
_libcusparse.cusparseGetVersion.restype = int
_libcusparse.cusparseGetVersion.argtypes = [ctypes.c_void_p, ctypes.c_void_p]
def cusparseGetVersion(handle):
    """
    Return CUSPARSE library version.

    Queries the loaded CUSPARSE library for its version number.

    Parameters
    ----------
    handle : int
        CUSPARSE library context.

    Returns
    -------
    version : int
        CUSPARSE library version number.
    """
    ver = ctypes.c_int()
    # The library writes the version through the pointer and returns a
    # status code, which cusparseCheckStatus converts into an exception.
    cusparseCheckStatus(
        _libcusparse.cusparseGetVersion(handle, ctypes.byref(ver)))
    return ver.value
_libcusparse.cusparseSetStream.restype = int
_libcusparse.cusparseSetStream.argtypes = [ctypes.c_void_p, ctypes.c_int]
def cusparseSetStream(handle, id):
    """
    Set the CUSPARSE stream in which kernels will run.

    Parameters
    ----------
    handle : int
        CUSPARSE library context.
    id : int
        Stream ID.
    """
    # NOTE: `id` shadows the builtin, but the parameter name is kept for
    # keyword-argument compatibility with existing callers.
    cusparseCheckStatus(_libcusparse.cusparseSetStream(handle, id))
_libcusparse.cusparseGetStream.restype = int
_libcusparse.cusparseGetStream.argtypes = [ctypes.c_void_p, ctypes.c_void_p]
def cusparseGetStream(handle):
    """
    Get the CUSPARSE stream in which kernels will run.

    Parameters
    ----------
    handle : int
        CUSPARSE library context.

    Returns
    -------
    id : int
        Stream ID.
    """
    # Renamed from `id` to avoid shadowing the builtin; the docstring
    # previously mis-described the return value as the library context.
    stream_id = ctypes.c_int()
    status = _libcusparse.cusparseGetStream(handle, ctypes.byref(stream_id))
    cusparseCheckStatus(status)
    return stream_id.value
gtsv2StridedBatch_bufferSizeExt_doc = Template(
"""
Calculate size of work buffer used by cusparse<t>gtsv2StridedBatch.
......
......@@ -26,7 +26,7 @@ class CUDATestCase(unittest.TestCase):
dadi.cuda_enabled(True)
phi_gpu = dadi.Integration.two_pops(phi.copy(), xx, *args,
enable_cuda_const=True)
enable_cuda_cached=True)
self.assertTrue(np.allclose(phi_cpu, phi_gpu))
......@@ -73,7 +73,7 @@ class CUDATestCase(unittest.TestCase):
m12, m13, m21, m23, m31, m32,
gamma1, gamma2, gamma3, h1, h2, h3,
theta0, initial_t, frozen1, frozen2,
frozen3, enable_cuda_const=True)
frozen3, enable_cuda_cached=True)
self.assertTrue(np.allclose(phi_cpu, phi_gpu))
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment