Skip to content

Commit

Permalink
Merge pull request #17 from qoncept/dev
Browse files Browse the repository at this point in the history
Replace combinations of `guard` and `fatalError` with `precondition`s
  • Loading branch information
koher committed May 25, 2016
2 parents 77a671c + 04361ac commit 7871c93
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 18 deletions.
16 changes: 8 additions & 8 deletions TensorSwift/Tensor.swift
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ public struct Tensor {

public init(shape: Shape, elements: [Element]) {
let volume = shape.volume
guard elements.count >= volume else { fatalError("`elements.count` must be greater than or equal to `shape.volume`: elements.count = \(elements.count), shape.volume = \(shape.volume)") }
precondition(elements.count >= volume, "`elements.count` must be greater than or equal to `shape.volume`: elements.count = \(elements.count), shape.volume = \(shape.volume)")
self.shape = shape
self.elements = (elements.count == volume) ? elements : Array(elements[0..<volume])
}
Expand Down Expand Up @@ -68,7 +68,7 @@ internal func commutativeBinaryOperation(lhs: Tensor, _ rhs: Tensor, operation:
let rSize = rhs.shape.dimensions.count

if lSize == rSize {
guard lhs.shape == rhs.shape else { fatalError("Incompatible shapes of tensors: lhs.shape = \(lhs.shape), rhs.shape = \(rhs.shape)") }
precondition(lhs.shape == rhs.shape, "Incompatible shapes of tensors: lhs.shape = \(lhs.shape), rhs.shape = \(rhs.shape)")
return Tensor(shape: lhs.shape, elements: zipMap(lhs.elements, rhs.elements, operation: operation))
}

Expand All @@ -91,13 +91,13 @@ internal func noncommutativeBinaryOperation(lhs: Tensor, _ rhs: Tensor, operatio
let rSize = rhs.shape.dimensions.count

if lSize == rSize {
guard lhs.shape == rhs.shape else { fatalError("Incompatible shapes of tensors: lhs.shape = \(lhs.shape), rhs.shape = \(rhs.shape)") }
precondition(lhs.shape == rhs.shape, "Incompatible shapes of tensors: lhs.shape = \(lhs.shape), rhs.shape = \(rhs.shape)")
return Tensor(shape: lhs.shape, elements: zipMap(lhs.elements, rhs.elements, operation: operation))
} else if lSize < rSize {
guard hasSuffix(array: rhs.shape.dimensions, suffix: lhs.shape.dimensions) else { fatalError("Incompatible shapes of tensors: lhs.shape = \(lhs.shape), rhs.shape = \(rhs.shape)") }
precondition(hasSuffix(array: rhs.shape.dimensions, suffix: lhs.shape.dimensions), "Incompatible shapes of tensors: lhs.shape = \(lhs.shape), rhs.shape = \(rhs.shape)")
return Tensor(shape: rhs.shape, elements: zipMapRepeat(rhs.elements, lhs.elements, operation: { operation($1, $0) }))
} else {
guard hasSuffix(array: lhs.shape.dimensions, suffix: rhs.shape.dimensions) else { fatalError("Incompatible shapes of tensors: lhs.shape = \(lhs.shape), rhs.shape = \(rhs.shape)") }
precondition(hasSuffix(array: lhs.shape.dimensions, suffix: rhs.shape.dimensions), "Incompatible shapes of tensors: lhs.shape = \(lhs.shape), rhs.shape = \(rhs.shape)")
return Tensor(shape: lhs.shape, elements: zipMapRepeat(lhs.elements, rhs.elements, operation: operation))
}
}
Expand Down Expand Up @@ -136,9 +136,9 @@ public func /(lhs: Float, rhs: Tensor) -> Tensor {

extension Tensor { // Matrix
public func matmul(tensor: Tensor) -> Tensor {
guard shape.dimensions.count == 2 else { fatalError("This tensor is not a matrix: shape = \(shape)") }
guard tensor.shape.dimensions.count == 2 else { fatalError("The given tensor is not a matrix: shape = \(tensor.shape)") }
guard tensor.shape.dimensions[0] == shape.dimensions[1] else { fatalError("Incompatible shapes of matrices: self.shape = \(shape), tensor.shape = \(tensor.shape)") }
precondition(shape.dimensions.count == 2, "This tensor is not a matrix: shape = \(shape)")
precondition(tensor.shape.dimensions.count == 2, "The given tensor is not a matrix: shape = \(tensor.shape)")
precondition(tensor.shape.dimensions[0] == shape.dimensions[1], "Incompatible shapes of matrices: self.shape = \(shape), tensor.shape = \(tensor.shape)")

#if os(iOS) || os(OSX)
let result = Tensor(shape: [shape.dimensions[0], tensor.shape.dimensions[1]])
Expand Down
20 changes: 10 additions & 10 deletions TensorSwift/TensorNN.swift
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,11 @@ extension Tensor {

extension Tensor {
public func maxPool(kernelSize kernelSize: [Int], strides: [Int]) -> Tensor { // padding = Same
guard shape.dimensions.count == 3 else { fatalError("`shape.dimensions.count` must be 3: \(shape.dimensions.count)") }
guard kernelSize.count == 3 else { fatalError("`ksize.count` must be 3: \(kernelSize.count)") }
guard kernelSize[2] == 1 else { fatalError("`ksize[3]` != 1 is not supported: \(kernelSize[2])") }
guard strides.count == 3 else { fatalError("`strides.count` must be 3: \(strides.count)") }
guard strides[2] == 1 else { fatalError("`strides[2]` != 1 is not supported: \(strides[2])") }
precondition(shape.dimensions.count == 3, "`shape.dimensions.count` must be 3: \(shape.dimensions.count)")
precondition(kernelSize.count == 3, "`ksize.count` must be 3: \(kernelSize.count)")
precondition(kernelSize[2] == 1, "`ksize[3]` != 1 is not supported: \(kernelSize[2])")
precondition(strides.count == 3, "`strides.count` must be 3: \(strides.count)")
precondition(strides[2] == 1, "`strides[2]` != 1 is not supported: \(strides[2])")

let inRows = shape.dimensions[0].value
let inCols = shape.dimensions[1].value
Expand Down Expand Up @@ -77,11 +77,11 @@ extension Tensor {
public func conv2d(filter filter: Tensor, strides: [Int]) -> Tensor { // padding = Same
let inChannels = filter.shape.dimensions[2].value

guard shape.dimensions.count == 3 else { fatalError("`shape.dimensions.count` must be 3: \(shape.dimensions.count)") }
guard filter.shape.dimensions.count == 4 else { fatalError("`filter.shape.dimensions.count` must be 4: \(filter.shape.dimensions.count)") }
guard strides.count == 3 else { fatalError("`strides.count` must be 3: \(strides.count)") }
guard strides[2] == 1 else { fatalError("`strides[2]` must be 1") }
guard shape.dimensions[2].value == inChannels else { fatalError("The number of channels of this tensor and the filter are not compatible: \(shape.dimensions[2]) != \(inChannels)") }
precondition(shape.dimensions.count == 3, "`shape.dimensions.count` must be 3: \(shape.dimensions.count)")
precondition(filter.shape.dimensions.count == 4, "`filter.shape.dimensions.count` must be 4: \(filter.shape.dimensions.count)")
precondition(strides.count == 3, "`strides.count` must be 3: \(strides.count)")
precondition(strides[2] == 1, "`strides[2]` must be 1")
precondition(shape.dimensions[2].value == inChannels, "The number of channels of this tensor and the filter are not compatible: \(shape.dimensions[2]) != \(inChannels)")

let inRows = shape.dimensions[0].value
let inCols = shape.dimensions[1].value
Expand Down

0 comments on commit 7871c93

Please sign in to comment.