Skip to content

Commit

Permalink
pytorch api added and working for cuda
Browse files Browse the repository at this point in the history
  • Loading branch information
mpatrou committed Jan 19, 2024
1 parent 71481f2 commit 301bae9
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 21 deletions.
4 changes: 2 additions & 2 deletions example/sphere_pytorch_prototype.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,13 +10,13 @@


device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
device = torch.device('mps')
#device = torch.device('mps')

print("device",device)

def make_kernel(model, q_vectors,device):
"""Instantiate the python kernel with input *q_vectors*"""
q_input = kt.PyInput(q_vectors, dtype=torch.float64)
q_input = kt.PyInput(q_vectors, dtype=torch.double)
return kt.PyKernel(model.info, q_input, device = device)


Expand Down
35 changes: 16 additions & 19 deletions sasmodels/kerneltorch.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@ def __init__(self, q_vectors, dtype):
self.q[:, 1] = q_vectors[1]
else:
# Create empty tensor
self.q = torch.tensor(np.empty(self.nq, dtype=np.float32))
self.q = torch.DoubleTensor(np.empty(self.nq))
self.q[:self.nq] = q_vectors[0]

def release(self):
Expand Down Expand Up @@ -124,7 +124,7 @@ def __init__(self, model_info, q_input, device):
self.dtype = np.dtype('d')
self.info = model_info
self.q_input = q_input
self.res = np.empty(q_input.nq, np.float32)
self.res = np.empty(q_input.nq, np.double)
self.dim = '2d' if q_input.is_2d else '1d'

partable = model_info.parameters
Expand All @@ -138,7 +138,7 @@ def __init__(self, model_info, q_input, device):
# through the loop. Arguments to the kernel and volume functions
# will use views into this vector, relying on the fact that a
# an array of no dimensions acts like a scalar.
parameter_vector = np.empty(len(partable.call_parameters)-2, np.float64)
parameter_vector = np.empty(len(partable.call_parameters)-2, np.double)

# Create views into the array to hold the arguments.
offset = 0
Expand Down Expand Up @@ -224,10 +224,10 @@ def _loops(parameters, form, form_volume, form_radius, nq, call_details,
# mesh, we update the components with the polydispersity values before
# calling the respective functions.
n_pars = len(parameters)
parameters = torch.tensor(parameters, dtype=torch.float32).to(device)

parameters = torch.DoubleTensor(parameters)
print(parameters)
#parameters[:] = values[2:n_pars+2]
parameters[:] = torch.tensor(values[2:n_pars+2], dtype=torch.float32)
parameters[:] = torch.DoubleTensor(values[2:n_pars+2]).to(device)

print("parameters",parameters)
if call_details.num_active == 0:
Expand All @@ -238,18 +238,15 @@ def _loops(parameters, form, form_volume, form_radius, nq, call_details,

else:
#transform to tensor flow
pd_value = torch.tensor(values[2+n_pars:2+n_pars + call_details.num_weights], dtype=torch.float32)
pd_weight = torch.tensor(values[2+n_pars + call_details.num_weights:], dtype=torch.float32)

#print("pd_value",pd_value)
#print("pd_weight",pd_weight)
pd_value = torch.DoubleTensor(values[2+n_pars:2+n_pars + call_details.num_weights]).to(device)
pd_weight = torch.DoubleTensor(values[2+n_pars + call_details.num_weights:]).to(device)

weight_norm = 0.0
weighted_form = 0.0
weighted_shell = 0.0
weighted_radius = 0.0
partial_weight = np.NaN
weight = np.NaN
weight_norm = torch.tensor(0.0,dtype=torch.double).to(device)
weighted_form = torch.tensor(0.0,dtype=torch.double).to(device)
weighted_shell = torch.tensor(0.0,dtype=torch.double).to(device)
weighted_radius = torch.tensor(0.0,dtype=torch.double).to(device)
partial_weight = torch.tensor(np.NaN).to(device)
weight = torch.tensor(np.NaN).to(device)

p0_par = call_details.pd_par[0]
p0_length = call_details.pd_length[0]
Expand All @@ -262,7 +259,7 @@ def _loops(parameters, form, form_volume, form_radius, nq, call_details,
pd_length = call_details.pd_length[:call_details.num_active]

#total = np.zeros(nq, np.float64)
total = torch.zeros(nq, dtype= torch.float32).to(device)
total = torch.zeros(nq, dtype= torch.double).to(device)

#print("ll", range(call_details.num_eval))
#parallel for loop
Expand All @@ -275,7 +272,7 @@ def _loops(parameters, form, form_volume, form_radius, nq, call_details,
pd_index = (loop_index//pd_stride)%pd_length
parameters[pd_par] = pd_value[pd_offset+pd_index]
#partial_weight = np.prod(pd_weight[pd_offset+pd_index][1:])
partial_weight = torch.prod(pd_weight[pd_offset+pd_index][1:])
partial_weight = torch.prod(pd_weight[pd_offset+pd_index][1:]).to(device)

p0_index = loop_index%p0_length

Expand Down

0 comments on commit 301bae9

Please sign in to comment.