diff --git a/src/RefElemData_polynomial.jl b/src/RefElemData_polynomial.jl index c941b421..1f54d2dc 100644 --- a/src/RefElemData_polynomial.jl +++ b/src/RefElemData_polynomial.jl @@ -345,11 +345,11 @@ function RefElemData(elem::Quad, M1D = Vq1D' * diagm(wq1D) * Vq1D # form kronecker products of multidimensional matrices to invert/multiply - VDM = kronecker(VDM_1D, VDM_1D) - invVDM = kronecker(invVDM_1D, invVDM_1D) - invM = kronecker(invM_1D, invM_1D) + VDM = kron(VDM_1D, VDM_1D) + invVDM = kron(invVDM_1D, invVDM_1D) + invM = kron(invM_1D, invM_1D) - M = kronecker(M1D, M1D) + M = kron(M1D, M1D) _, Vr, Vs = basis(elem, N, r, s) Dr, Ds = (A -> A * invVDM).((Vr, Vs)) @@ -363,7 +363,7 @@ function RefElemData(elem::Quad, # quadrature nodes - build from 1D nodes. rq, sq, wq = tensor_product_quadrature(elem, approximation_type.data.quad_rule_1D...) - Vq = kronecker(Vq1D, Vq1D) # vandermonde(elem, N, rq, sq, tq) * invVDM + Vq = kron(Vq1D, Vq1D) # vandermonde(elem, N, rq, sq, tq) * invVDM Pq = invM * (Vq' * diagm(wq)) Vf = vandermonde(elem, N, rf, sf) * invVDM @@ -372,7 +372,7 @@ function RefElemData(elem::Quad, # plotting nodes rp1D = LinRange(-1, 1, Nplot + 1) Vp1D = vandermonde(Line(), N, rp1D) / VDM_1D - Vp = kronecker(Vp1D, Vp1D) + Vp = kron(Vp1D, Vp1D) rp, sp = vec.(StartUpDG.NodesAndModes.meshgrid(rp1D, rp1D)) return RefElemData(elem, approximation_type, N, fv, V1, @@ -406,11 +406,14 @@ function RefElemData(elem::Hex, M1D = Vq1D' * diagm(wq1D) * Vq1D # form kronecker products of multidimensional matrices to invert/multiply - VDM = kronecker(VDM_1D, VDM_1D, VDM_1D) - invVDM = kronecker(invVDM_1D, invVDM_1D, invVDM_1D) - invM = kronecker(invM_1D, invM_1D, invM_1D) + # use dense matrix "kron" if N is 4 or lower; use memory-saving "kronecker" otherwise + build_kronecker_product = (N < 5) ? kron : kronecker + + VDM = build_kronecker_product(VDM_1D, VDM_1D, VDM_1D) + invVDM = build_kronecker_product(invVDM_1D, invVDM_1D, invVDM_1D) + invM = build_kronecker_product(invM_1D, invM_1D, invM_1D) - M = kronecker(M1D, M1D, M1D) + M = build_kronecker_product(M1D, M1D, M1D) _, Vr, Vs, Vt = basis(elem, N, r, s, t) Dr, Ds, Dt = (A -> A * invVDM).((Vr, Vs, Vt)) @@ -424,7 +427,7 @@ function RefElemData(elem::Hex, # quadrature nodes - build from 1D nodes. rq, sq, tq, wq = tensor_product_quadrature(elem, approximation_type.data.quad_rule_1D...) - Vq = kronecker(Vq1D, Vq1D, Vq1D) # vandermonde(elem, N, rq, sq, tq) * invVDM + Vq = build_kronecker_product(Vq1D, Vq1D, Vq1D) # vandermonde(elem, N, rq, sq, tq) * invVDM Pq = invM * (Vq' * diagm(wq)) Vf = vandermonde(elem, N, rf, sf, tf) * invVDM @@ -433,7 +436,7 @@ function RefElemData(elem::Hex, # plotting nodes rp1D = LinRange(-1, 1, Nplot + 1) Vp1D = vandermonde(Line(), N, rp1D) / VDM_1D - Vp = kronecker(Vp1D, Vp1D, Vp1D) + Vp = build_kronecker_product(Vp1D, Vp1D, Vp1D) rp, sp, tp = vec.(StartUpDG.NodesAndModes.meshgrid(rp1D, rp1D, rp1D)) return RefElemData(elem, approximation_type, N, fv, V1,