
Commit

Merge branch 'refactor' into dataset_tasks
nmichlo committed Jun 3, 2021
2 parents cade4f3 + f375da5 commit 029f64a
Showing 18 changed files with 43 additions and 43 deletions.
2 changes: 1 addition & 1 deletion disent/frameworks/helper/latent_distributions.py
@@ -34,7 +34,7 @@
 
 from disent.frameworks.helper.util import compute_ave_loss
 from disent.nn.loss.kl import kl_loss
-from disent.nn.reductions import loss_reduction
+from disent.nn.loss.reduction import loss_reduction
 
 
 # ========================================================================= #
4 changes: 2 additions & 2 deletions disent/frameworks/helper/reconstructions.py
@@ -34,8 +34,8 @@
 
 from disent.frameworks.helper.util import compute_ave_loss
 from disent.nn.modules import DisentModule
-from disent.nn.reductions import batch_loss_reduction
-from disent.nn.reductions import loss_reduction
+from disent.nn.loss.reduction import batch_loss_reduction
+from disent.nn.loss.reduction import loss_reduction
 from disent.nn.transform import FftKernel
 
 
2 changes: 1 addition & 1 deletion disent/frameworks/vae/_unsupervised__dfcvae.py
@@ -36,7 +36,7 @@
 from torchvision.models import vgg19_bn
 from torch.nn import functional as F
 
-from disent.nn.reductions import get_mean_loss_scale
+from disent.nn.loss.reduction import get_mean_loss_scale
 from disent.frameworks.helper.util import compute_ave_loss
 from disent.frameworks.vae._unsupervised__betavae import BetaVae
 from disent.nn.transform.functional import check_tensor
2 changes: 1 addition & 1 deletion disent/frameworks/vae/_unsupervised__dipvae.py
@@ -30,7 +30,7 @@
 
 from disent.frameworks.helper.util import compute_ave_loss_and_logs
 from disent.frameworks.vae._unsupervised__betavae import BetaVae
-from disent.util.math import torch_cov_matrix
+from disent.nn.functional import torch_cov_matrix
 
 
 # ========================================================================= #
5 changes: 2 additions & 3 deletions disent/frameworks/vae/experimental/_unsupervised__dorvae.py
@@ -33,10 +33,9 @@
 from disent.frameworks.helper.reconstructions import make_reconstruction_loss
 from disent.frameworks.helper.reconstructions import ReconLossHandler
 from disent.frameworks.vae._supervised__tvae import TripletVae
-from disent.frameworks.vae._unsupervised__betavae import BetaVae
 from disent.frameworks.vae._weaklysupervised__adavae import AdaVae
-from disent.util.math_loss import torch_mse_rank_loss
-from disent.util.math_loss import spearman_rank_loss
+from disent.nn.loss.softsort import torch_mse_rank_loss
+from disent.nn.loss.softsort import spearman_rank_loss
 from experiment.util.hydra_utils import instantiate_recursive
 
 
4 changes: 2 additions & 2 deletions disent/metrics/_flatness_components.py
@@ -35,8 +35,8 @@
 from disent.metrics._flatness import filter_inactive_factors
 from disent.util import iter_chunks
 from disent.util import to_numpy
-from disent.util.math import torch_mean_generalized
-from disent.util.math import torch_pca
+from disent.nn.functional import torch_mean_generalized
+from disent.nn.functional import torch_pca
 
 
 log = logging.getLogger(__name__)
11 changes: 6 additions & 5 deletions disent/util/math.py → disent/nn/functional/__init__.py
@@ -21,19 +21,20 @@
 # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 # SOFTWARE.
 # ~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~=~
 
+import logging
+import warnings
 from typing import List
 from typing import Optional
 from typing import Union
 
-import logging
 import numpy as np
 import torch
 
-from disent.util.math_generic import generic_as_int32
-from disent.util.math_generic import generic_max
-from disent.util.math_generic import TypeGenericTensor
-from disent.util.math_generic import TypeGenericTorch
+from disent.nn.functional._generic_tensors import generic_as_int32
+from disent.nn.functional._generic_tensors import generic_max
+from disent.nn.functional._generic_tensors import TypeGenericTensor
+from disent.nn.functional._generic_tensors import TypeGenericTorch
 
 
 log = logging.getLogger(__name__)
File renamed without changes.
File renamed without changes.
File renamed without changes.
6 changes: 3 additions & 3 deletions disent/nn/transform/_augment.py
@@ -34,9 +34,9 @@
 
 import disent
 from disent.nn.modules import DisentModule
-from disent.util.math import torch_box_kernel_2d
-from disent.util.math import torch_conv2d_channel_wise_fft
-from disent.util.math import torch_gaussian_kernel_2d
+from disent.nn.functional import torch_box_kernel_2d
+from disent.nn.functional import torch_conv2d_channel_wise_fft
+from disent.nn.functional import torch_gaussian_kernel_2d
 
 
 # ========================================================================= #
4 changes: 2 additions & 2 deletions experiment/exp/05_adversarial_data/run_01_sort_loss.py
@@ -27,8 +27,8 @@
 from torch.utils.data import DataLoader
 
 import experiment.exp.util as H
-from disent.util.math_loss import multi_spearman_rank_loss
-from disent.util.math_loss import torch_soft_rank
+from disent.nn.loss.softsort import multi_spearman_rank_loss
+from disent.nn.loss.softsort import torch_soft_rank
 
 
 # ========================================================================= #
(changed file, path not shown in this view)
@@ -30,9 +30,9 @@
 
 import torch.nn.functional as F
 import experiment.exp.util as H
-from disent.util.math import torch_conv2d_channel_wise_fft
-from disent.util.math import torch_box_kernel_2d
-from disent.util.math import torch_gaussian_kernel_2d
+from disent.nn.functional import torch_conv2d_channel_wise_fft
+from disent.nn.functional import torch_box_kernel_2d
+from disent.nn.functional import torch_gaussian_kernel_2d
 
 
 # ========================================================================= #
(changed file, path not shown in this view)
@@ -42,8 +42,8 @@
 from disent.nn.modules import DisentModule
 from disent.util import make_box_str
 from disent.util import seed
-from disent.util.math import torch_conv2d_channel_wise_fft
-from disent.util.math_loss import spearman_rank_loss
+from disent.nn.functional import torch_conv2d_channel_wise_fft
+from disent.nn.loss.softsort import spearman_rank_loss
 from experiment.run import hydra_append_progress_callback
 from experiment.run import hydra_check_cuda
 from experiment.run import hydra_make_logger
2 changes: 1 addition & 1 deletion experiment/exp/util/_loss.py
@@ -26,7 +26,7 @@
 import torch_optimizer
 from torch.nn import functional as F
 
-from disent.nn.reductions import batch_loss_reduction
+from disent.nn.loss.reduction import batch_loss_reduction
 
 
 # ========================================================================= #
20 changes: 10 additions & 10 deletions tests/test_math.py
@@ -30,18 +30,18 @@
 
 from disent.data.groundtruth import XYSquaresData
 from disent.dataset.groundtruth import GroundTruthDataset
+from disent.nn.functional import torch_conv2d_channel_wise
+from disent.nn.functional import torch_conv2d_channel_wise_fft
+from disent.nn.functional import torch_corr_matrix
+from disent.nn.functional import torch_cov_matrix
+from disent.nn.functional import torch_dct
+from disent.nn.functional import torch_dct2
+from disent.nn.functional import torch_gaussian_kernel_2d
+from disent.nn.functional import torch_idct
+from disent.nn.functional import torch_idct2
+from disent.nn.functional import torch_mean_generalized
 from disent.nn.transform import ToStandardisedTensor
-from disent.util.math import torch_conv2d_channel_wise
-from disent.util.math import torch_conv2d_channel_wise_fft
 from disent.util import to_numpy
-from disent.util.math import torch_dct
-from disent.util.math import torch_dct2
-from disent.util.math import torch_gaussian_kernel_2d
-from disent.util.math import torch_idct
-from disent.util.math import torch_idct2
-from disent.util.math import torch_corr_matrix
-from disent.util.math import torch_cov_matrix
-from disent.util.math import torch_mean_generalized
 
 
 # ========================================================================= #
10 changes: 5 additions & 5 deletions tests/test_math_generic.py
@@ -26,11 +26,11 @@
 import pytest
 import torch
 
-from disent.util.math_generic import generic_as_int32
-from disent.util.math_generic import generic_max
-from disent.util.math_generic import generic_min
-from disent.util.math_generic import generic_ndim
-from disent.util.math_generic import generic_shape
+from disent.nn.functional._generic_tensors import generic_as_int32
+from disent.nn.functional._generic_tensors import generic_max
+from disent.nn.functional._generic_tensors import generic_min
+from disent.nn.functional._generic_tensors import generic_ndim
+from disent.nn.functional._generic_tensors import generic_shape
 
 
 # ========================================================================= #
4 changes: 2 additions & 2 deletions tests/test_transform.py
@@ -26,8 +26,8 @@
 
 from disent.nn.transform import FftGaussianBlur
 from disent.nn.transform._augment import _expand_to_min_max_tuples
-from disent.util.math import torch_gaussian_kernel
-from disent.util.math import torch_gaussian_kernel_2d
+from disent.nn.functional import torch_gaussian_kernel
+from disent.nn.functional import torch_gaussian_kernel_2d
 
 
 # ========================================================================= #
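
Read together, the hunks above amount to a handful of module relocations. The summary and migration sketch below are inferred only from the import lines shown on this page; they are illustrative rather than part of the commit, and assume the disent package at this revision is installed.

# Module relocations observed in this merge (old import path -> new import path):
#   disent.nn.reductions     -> disent.nn.loss.reduction
#   disent.util.math         -> disent.nn.functional
#   disent.util.math_loss    -> disent.nn.loss.softsort
#   disent.util.math_generic -> disent.nn.functional._generic_tensors

# Hypothetical downstream migration: point existing imports at the new locations,
# using only names that appear in the diff above.
from disent.nn.loss.reduction import batch_loss_reduction, loss_reduction
from disent.nn.loss.softsort import spearman_rank_loss
from disent.nn.functional import torch_cov_matrix, torch_gaussian_kernel_2d
from disent.nn.functional._generic_tensors import generic_max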
