Commit ffc05e7f authored by Ryan Gutenkunst's avatar Ryan Gutenkunst
Browse files

Simply from_phi_linalg methods to reduce code duplication

parent a693e7dc
......@@ -1632,8 +1632,6 @@ class Spectrum(numpy.ma.masked_array):
data = numpy.zeros((nx+1,ny+1,nz+1))
dbeta1_zz, dbeta2_zz = cached_dbeta(nz, zz)
dbeta1_yy, dbeta2_yy = cached_dbeta(ny, yy)
dbeta1_xx, dbeta2_xx = cached_dbeta(nx, xx)
# Quick testing suggests that doing the x direction first for better
# memory alignment isn't worth much.
......@@ -1650,19 +1648,8 @@ class Spectrum(numpy.ma.masked_array):
term2 *= (kk+1)/((nz+1)*(nz+2))
over_z = term1 + term2
s_yy = (over_z[:,1:]-over_z[:,:-1])/(yy[nuax,1:]-yy[nuax,:-1])
c1_yy = (over_z[:,:-1] - s_yy*yy[nuax,:-1])/(ny+1)
term1_yy = np.dot(dbeta1_yy, c1_yy.T)
term2_yy = np.dot(dbeta2_yy, s_yy.T) * (np.arange(1,ny+2)[:,np.newaxis]/((ny+1)*(ny+2)))
over_y_all = term1_yy + term2_yy
s_xx_all = (over_y_all[:,1:]-over_y_all[:,:-1])/(xx[1:]-xx[:-1])
c1_xx_all = (over_y_all[:,:-1] - s_xx_all*xx[:-1])/(nx+1)
term1_all = np.dot(dbeta1_xx, c1_xx_all.T)
term2_all = np.dot(dbeta2_xx, s_xx_all.T) * (np.arange(1,nx+2)[:,np.newaxis]/((nx+1)*(nx+2)))
data[:,:,kk] = term1_all + term2_all
sub_fs = Spectrum._from_phi_2D_linalg(nx, ny, xx, yy, over_z, mask_corners=False)
data[:,:,kk] = sub_fs.data
fs = dadi.Spectrum(data, mask_corners=mask_corners)
return fs
......@@ -1680,48 +1667,17 @@ class Spectrum(numpy.ma.masked_array):
data = numpy.zeros((nx+1,ny+1,nz+1,na+1,nb+1))
dbeta1_bb, dbeta2_bb = cached_dbeta(nb, bb)
dbeta1_aa, dbeta2_aa = cached_dbeta(na, aa)
dbeta1_zz, dbeta2_zz = cached_dbeta(nz, zz)
dbeta1_yy, dbeta2_yy = cached_dbeta(ny, yy)
dbeta1_xx, dbeta2_xx = cached_dbeta(nx, xx)
s_bb = (phi[:,:,:,:,1:]-phi[:,:,:,:,:-1])/(bb[nuax,nuax,nuax,nuax,1:]-bb[nuax,nuax,nuax,nuax:-1])
c1_bb = (phi[:,:,:,:,:-1] - s_bb*bb[nuax,nuax,nuax,nuax,:-1])/(nb+1)
for mm in range(0, nb+1):
term1 = np.dot(c1_bb, dbeta1_bb[mm])
term2 = np.dot(s_bb, dbeta2_bb[bb])
term2 = np.dot(s_bb, dbeta2_bb[mm])
term2 *= (mm+1)/((nb+1)*(nb+2))
over_b = term1 + term2
s_aa = (over_b[:,:,:,1:]-over_b[:,:,:,:-1])/(aa[nuax,nuax,nuax,1:]-aa[nuax,nuax,nuax:-1])
c1_aa = (over_b[:,:,:,:-1] - s_aa*aa[nuax,nuax,nuax,:-1])/(na+1)
for ll in range(0, na+1):
term1 = np.dot(c1_aa, dbeta1_aa[ll])
term2 = np.dot(s_aa, dbeta2_aa[ll])
term2 *= (ll+1)/((na+1)*(na+2))
over_a = term1 + term2
s_zz = (over_a[:,:,1:]-over_a[:,:,:-1])/(zz[nuax,nuax,1:]-zz[nuax,nuax,:-1])
c1_zz = (over_a[:,:,:-1] - s_zz*zz[nuax,nuax,:-1])/(nz+1)
for kk in range(0, nz+1):
term1 = np.dot(c1_zz, dbeta1_zz[kk])
term2 = np.dot(s_zz, dbeta2_zz[kk])
term2 *= (kk+1)/((nz+1)*(nz+2))
over_z = term1 + term2
s_yy = (over_z[:,1:]-over_z[:,:-1])/(yy[nuax,1:]-yy[nuax,:-1])
c1_yy = (over_z[:,:-1] - s_yy*yy[nuax,:-1])/(ny+1)
term1_yy = np.dot(dbeta1_yy, c1_yy.T)
term2_yy = np.dot(dbeta2_yy, s_yy.T) * (np.arange(1,ny+2)[:,np.newaxis]/((ny+1)*(ny+2)))
over_y_all = term1_yy + term2_yy
s_xx_all = (over_y_all[:,1:]-over_y_all[:,:-1])/(xx[1:]-xx[:-1])
c1_xx_all = (over_y_all[:,:-1] - s_xx_all*xx[:-1])/(nx+1)
term1_all = np.dot(dbeta1_xx, c1_xx_all.T)
term2_all = np.dot(dbeta2_xx, s_xx_all.T) * (np.arange(1,nx+2)[:,np.newaxis]/((nx+1)*(nx+2)))
data[:,:,kk,ll,mm] = term1_all + term2_all
sub_fs = Spectrum._from_phi_4D_linalg(nx, ny, nz, na, xx, yy, zz, aa, over_b, mask_corners=False)
data[:,:,:,:,mm] = sub_fs.data
fs = dadi.Spectrum(data, mask_corners=mask_corners)
return fs
......@@ -1739,9 +1695,6 @@ class Spectrum(numpy.ma.masked_array):
data = numpy.zeros((nx+1,ny+1,nz+1,na+1))
dbeta1_aa, dbeta2_aa = cached_dbeta(na, aa)
dbeta1_zz, dbeta2_zz = cached_dbeta(nz, zz)
dbeta1_yy, dbeta2_yy = cached_dbeta(ny, yy)
dbeta1_xx, dbeta2_xx = cached_dbeta(nx, xx)
s_aa = (phi[:,:,:,1:]-phi[:,:,:,:-1])/(aa[nuax,nuax,nuax,1:]-aa[nuax,nuax,nuax:-1])
c1_aa = (phi[:,:,:,:-1] - s_aa*aa[nuax,nuax,nuax,:-1])/(na+1)
......@@ -1751,27 +1704,8 @@ class Spectrum(numpy.ma.masked_array):
term2 *= (ll+1)/((na+1)*(na+2))
over_a = term1 + term2
s_zz = (over_a[:,:,1:]-over_a[:,:,:-1])/(zz[nuax,nuax,1:]-zz[nuax,nuax,:-1])
c1_zz = (over_a[:,:,:-1] - s_zz*zz[nuax,nuax,:-1])/(nz+1)
for kk in range(0, nz+1):
term1 = np.dot(c1_zz, dbeta1_zz[kk])
term2 = np.dot(s_zz, dbeta2_zz[kk])
term2 *= (kk+1)/((nz+1)*(nz+2))
over_z = term1 + term2
s_yy = (over_z[:,1:]-over_z[:,:-1])/(yy[nuax,1:]-yy[nuax,:-1])
c1_yy = (over_z[:,:-1] - s_yy*yy[nuax,:-1])/(ny+1)
term1_yy = np.dot(dbeta1_yy, c1_yy.T)
term2_yy = np.dot(dbeta2_yy, s_yy.T) * (np.arange(1,ny+2)[:,np.newaxis]/((ny+1)*(ny+2)))
over_y_all = term1_yy + term2_yy
s_xx_all = (over_y_all[:,1:]-over_y_all[:,:-1])/(xx[1:]-xx[:-1])
c1_xx_all = (over_y_all[:,:-1] - s_xx_all*xx[:-1])/(nx+1)
term1_all = np.dot(dbeta1_xx, c1_xx_all.T)
term2_all = np.dot(dbeta2_xx, s_xx_all.T) * (np.arange(1,nx+2)[:,np.newaxis]/((nx+1)*(nx+2)))
data[:,:,kk,ll] = term1_all + term2_all
sub_fs = Spectrum._from_phi_3D_linalg(nx, ny, nz, xx, yy, zz, over_a, mask_corners=False)
data[:,:,:,ll] = sub_fs.data
fs = dadi.Spectrum(data, mask_corners=mask_corners)
return fs
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment