Skip to content

Commit

Permalink
chore: code qual
Browse files Browse the repository at this point in the history
  • Loading branch information
jlwalke2 committed Aug 29, 2024
1 parent 09f75dc commit d30c1f8
Show file tree
Hide file tree
Showing 3 changed files with 16 additions and 46 deletions.
16 changes: 8 additions & 8 deletions src/sasctl/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -571,11 +571,11 @@ def username(self):
def hostname(self):
return self._settings.get("domain")

def send(self, request, **kwargs):
def send(self, req, **kwargs):
if self.message_log.isEnabledFor(logging.DEBUG):
r = copy.deepcopy(request)
for filter in self.filters:
r = filter(r)
r = copy.deepcopy(req)
for filter_ in self.filters:
r = filter_(r)

self.message_log.debug(
"HTTP/1.1 {verb} {url}\n{headers}\nBody:\n{body}".format(
Expand All @@ -588,14 +588,14 @@ def send(self, request, **kwargs):
)
)
else:
self.message_log.info("HTTP/1.1 %s %s", request.method, request.url)
self.message_log.info("HTTP/1.1 %s %s", req.method, req.url)

response = super(Session, self).send(request, **kwargs)
response = super(Session, self).send(req, **kwargs)

if self.message_log.isEnabledFor(logging.DEBUG):
r = copy.deepcopy(response)
for filter in self.filters:
r = filter(r)
for filter_ in self.filters:
r = filter_(r)

self.message_log.debug(
"HTTP {status} {url}\n{headers}\nBody:\n{body}".format(
Expand Down
8 changes: 0 additions & 8 deletions src/sasctl/pzmm/write_json_files.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
import importlib
import json

# import math #not used
import pickle
import pickletools
import sys
Expand Down Expand Up @@ -1449,13 +1448,6 @@ def stat_dataset_to_dataframe(
Raised if an improper data format is provided.
"""
# If numpy inputs are supplied, then assume numpy is installed
try:
# noinspection PyPackageRequirements
import numpy as np
except ImportError:
np = None

# Convert target_value to numeric for creating binary probabilities
if isinstance(target_value, str):
target_value = float(target_value)
Expand Down
38 changes: 8 additions & 30 deletions src/sasctl/utils/model_info.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ def get_model_info(model, X, y=None):

# Most PyTorch models are actually subclasses of torch.nn.Module, so checking module
# name alone is not sufficient.
elif torch and isinstance(model, torch.nn.Module):
if torch and isinstance(model, torch.nn.Module):
return PyTorchModelInfo(model, X, y)

raise ValueError(f"Unrecognized model type {type(model)} received.")
Expand Down Expand Up @@ -200,7 +200,8 @@ class OnnxModelInfo(ModelInfo):
def __init__(self, model, X, y=None):
if onnx is None:
raise RuntimeError(
"The onnx package must be installed to work with ONNX models. Please `pip install onnx`."
"The onnx package must be installed to work with ONNX models. "
"Please `pip install onnx`."
)

self._model = model
Expand All @@ -214,38 +215,19 @@ def __init__(self, model, X, y=None):

if len(inputs) > 1:
warnings.warn(
f"The ONNX model has {len(inputs)} inputs but only the first input will be captured in Model Manager."
f"The ONNX model has {len(inputs)} inputs but only the first input "
f"will be captured in Model Manager."
)

if len(outputs) > 1:
warnings.warn(
f"The ONNX model has {len(outputs)} outputs but only the first input will be captured in Model Manager."
f"The ONNX model has {len(outputs)} outputs but only the first output "
f"will be captured in Model Manager."
)

self._X_df = inputs[0]
self._y_df = outputs[0]

# initializer (static params)

# for field in model.ListFields():
# doc_string
# domain
# metadata_props
# model_author
# model_license
# model_version
# producer_name
# producer_version
# training_info

# irVersion
# producerName
# producerVersion
# opsetImport

# # list of (FieldDescriptor, value)
# fields = model.ListFields()

@staticmethod
def _tensor_to_dataframe(tensor):
"""
Expand All @@ -272,7 +254,7 @@ def _tensor_to_dataframe(tensor):
name = tensor.get("name", "Var")
type_ = tensor["type"]

if not "tensorType" in type_:
if "tensorType" not in type_:
raise ValueError(f"Received an unexpected ONNX input type: {type_}.")

dtype = onnx.helper.tensor_dtype_to_np_dtype(type_["tensorType"]["elemType"])
Expand Down Expand Up @@ -374,8 +356,6 @@ def __init__(self, model, X, y=None):
raise ValueError(
f"Expected input data to be a numpy array or PyTorch tensor, received {type(X)}."
)
# if X.ndim != 2:
# raise ValueError(f"Expected input date with shape (n_samples, n_dim), received shape {X.shape}.")

# Ensure each input is a PyTorch Tensor
X = tuple(x if isinstance(x, torch.Tensor) else torch.tensor(x) for x in X)
Expand All @@ -395,8 +375,6 @@ def __init__(self, model, X, y=None):
)

self._model = model

# TODO: convert X and y to DF with arbitrary names
self._X = X
self._y = y

Expand Down

0 comments on commit d30c1f8

Please sign in to comment.