Skip to content

Commit

Permalink
Merge pull request #5 from qoncept/swift-3
Browse files Browse the repository at this point in the history
Update for Swift 3
  • Loading branch information
koher authored Oct 24, 2016
2 parents 7871c93 + e2c0e9e commit df8ea90
Show file tree
Hide file tree
Showing 55 changed files with 1,232 additions and 1,329 deletions.
7 changes: 7 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,13 @@ DerivedData
*.ipa
*.xcuserstate

# SwiftPackageManager
#
.build
/*.xcodeproj/xcshareddata/
/*.xcodeproj/project.xcworkspace/xcuserdata/
/*.xcodeproj/xcuserdata/

# CocoaPods
#
# We recommend against adding the Pods directory to your .gitignore. However
Expand Down
16 changes: 8 additions & 8 deletions MNIST/AppDelegate.swift
Original file line number Diff line number Diff line change
Expand Up @@ -6,30 +6,30 @@ class AppDelegate: UIResponder, UIApplicationDelegate {
var window: UIWindow?


func application(application: UIApplication, didFinishLaunchingWithOptions launchOptions: [NSObject: AnyObject]?) -> Bool {
func application(_ application: UIApplication, didFinishLaunchingWithOptions launchOptions: [UIApplicationLaunchOptionsKey: Any]?) -> Bool {
// Override point for customization after application launch.
return true
}

func applicationWillResignActive(application: UIApplication) {
func applicationWillResignActive(_ application: UIApplication) {
// Sent when the application is about to move from active to inactive state. This can occur for certain types of temporary interruptions (such as an incoming phone call or SMS message) or when the user quits the application and it begins the transition to the background state.
// Use this method to pause ongoing tasks, disable timers, and throttle down OpenGL ES frame rates. Games should use this method to pause the game.
// Use this method to pause ongoing tasks, disable timers, and invalidate graphics rendering callbacks. Games should use this method to pause the game.
}

func applicationDidEnterBackground(application: UIApplication) {
func applicationDidEnterBackground(_ application: UIApplication) {
// Use this method to release shared resources, save user data, invalidate timers, and store enough application state information to restore your application to its current state in case it is terminated later.
// If your application supports background execution, this method is called instead of applicationWillTerminate: when the user quits.
}

func applicationWillEnterForeground(application: UIApplication) {
// Called as part of the transition from the background to the inactive state; here you can undo many of the changes made on entering the background.
func applicationWillEnterForeground(_ application: UIApplication) {
// Called as part of the transition from the background to the active state; here you can undo many of the changes made on entering the background.
}

func applicationDidBecomeActive(application: UIApplication) {
func applicationDidBecomeActive(_ application: UIApplication) {
// Restart any tasks that were paused (or not yet started) while the application was inactive. If the application was previously in the background, optionally refresh the user interface.
}

func applicationWillTerminate(application: UIApplication) {
func applicationWillTerminate(_ application: UIApplication) {
// Called when the application is about to terminate. Save data if appropriate. See also applicationDidEnterBackground:.
}

Expand Down
4 changes: 2 additions & 2 deletions MNIST/Canvas.swift
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,11 @@ import CoreGraphics
struct Canvas {
var lines: [Line] = [Line()]

mutating func draw(point: CGPoint) {
mutating func draw(_ point: CGPoint) {
lines[lines.endIndex - 1].points.append(point)
}

mutating func newLine() {
lines.append(Line())
}
}
}
30 changes: 15 additions & 15 deletions MNIST/CanvasView.swift
Original file line number Diff line number Diff line change
Expand Up @@ -12,33 +12,33 @@ class CanvasView: UIView {

var image: UIImage {
UIGraphicsBeginImageContext(bounds.size)
layer.renderInContext(UIGraphicsGetCurrentContext()!)
layer.render(in: UIGraphicsGetCurrentContext()!)
let result = UIGraphicsGetImageFromCurrentImageContext()
UIGraphicsEndImageContext()
return result
return result!
}

override func drawRect(rect: CGRect) {
override func draw(_ rect: CGRect) {
let context = UIGraphicsGetCurrentContext()
for line in canvas.lines {
CGContextSetLineWidth(context, 20.0)
CGContextSetStrokeColorWithColor(context, UIColor(colorLiteralRed: 0.0, green: 0.0, blue: 0.0, alpha: 1.0).CGColor)
CGContextSetLineCap(context, .Round)
CGContextSetLineJoin(context, .Round)
for (index, point) in line.points.enumerate() {
context?.setLineWidth(20.0)
context?.setStrokeColor(UIColor(colorLiteralRed: 0.0, green: 0.0, blue: 0.0, alpha: 1.0).cgColor)
context?.setLineCap(.round)
context?.setLineJoin(.round)
for (index, point) in line.points.enumerated() {
if index == 0 {
CGContextMoveToPoint(context, point.x, point.y)
context?.move(to: CGPoint(x: point.x, y: point.y))
} else {
CGContextAddLineToPoint(context, point.x, point.y)
context?.addLine(to: CGPoint(x: point.x, y: point.y))
}
}
}
CGContextStrokePath(context)
context?.strokePath()
}

func onPanGesture(gestureRecognizer: UIPanGestureRecognizer) {
canvas.draw(gestureRecognizer.locationInView(self))
if gestureRecognizer.state == .Ended {
func onPanGesture(_ gestureRecognizer: UIPanGestureRecognizer) {
canvas.draw(gestureRecognizer.location(in: self))
if gestureRecognizer.state == .ended {
canvas.newLine()
}

Expand All @@ -49,4 +49,4 @@ class CanvasView: UIView {
canvas = Canvas()
setNeedsDisplay()
}
}
}
36 changes: 18 additions & 18 deletions MNIST/Classifier.swift
Original file line number Diff line number Diff line change
Expand Up @@ -11,36 +11,36 @@ public struct Classifier {
public let W_fc2: Tensor
public let b_fc2: Tensor

public func classify(x_image: Tensor) -> Int {
let h_conv1 = (x_image.conv2d(filter: W_conv1, strides: [1, 1, 1]) + b_conv1).relu
public func classify(_ x_image: Tensor) -> Int {
let h_conv1 = (x_image.conv2d(filter: W_conv1, strides: [1, 1, 1]) + b_conv1).relu()
let h_pool1 = h_conv1.maxPool(kernelSize: [2, 2, 1], strides: [2, 2, 1])

let h_conv2 = (h_pool1.conv2d(filter: W_conv2, strides: [1, 1, 1]) + b_conv2).relu
let h_conv2 = (h_pool1.conv2d(filter: W_conv2, strides: [1, 1, 1]) + b_conv2).relu()
let h_pool2 = h_conv2.maxPool(kernelSize: [2, 2, 1], strides: [2, 2, 1])

let h_pool2_flat = h_pool2.reshape([1, 7 * 7 * 64])
let h_fc1 = (h_pool2_flat.matmul(W_fc1) + b_fc1).relu
let h_pool2_flat = h_pool2.reshaped([1, 7 * 7 * 64])
let h_fc1 = (h_pool2_flat.matmul(W_fc1) + b_fc1).relu()

let y_conv = (h_fc1.matmul(W_fc2) + b_fc2).softmax
let y_conv = (h_fc1.matmul(W_fc2) + b_fc2).softmax()

return y_conv.elements.enumerate().maxElement { $0.1 < $1.1 }!.0
return y_conv.elements.enumerated().max { $0.1 < $1.1 }!.0
}
}

extension Classifier {
public init(path: String) {
W_conv1 = Tensor(shape: [5, 5, 1, 32], elements: loadFloatArray(directory: path, file: "W_conv1"))
b_conv1 = Tensor(shape: [32], elements: loadFloatArray(directory: path, file: "b_conv1"))
W_conv2 = Tensor(shape: [5, 5, 32, 64], elements: loadFloatArray(directory: path, file: "W_conv2"))
b_conv2 = Tensor(shape: [64], elements: loadFloatArray(directory: path, file: "b_conv2"))
W_fc1 = Tensor(shape: [7 * 7 * 64, 1024], elements: loadFloatArray(directory: path, file: "W_fc1"))
b_fc1 = Tensor(shape: [1024], elements: loadFloatArray(directory: path, file: "b_fc1"))
W_fc2 = Tensor(shape: [1024, 10], elements: loadFloatArray(directory: path, file: "W_fc2"))
b_fc2 = Tensor(shape: [10], elements: loadFloatArray(directory: path, file: "b_fc2"))
W_conv1 = Tensor(shape: [5, 5, 1, 32], elements: loadFloatArray(path, file: "W_conv1"))
b_conv1 = Tensor(shape: [32], elements: loadFloatArray(path, file: "b_conv1"))
W_conv2 = Tensor(shape: [5, 5, 32, 64], elements: loadFloatArray(path, file: "W_conv2"))
b_conv2 = Tensor(shape: [64], elements: loadFloatArray(path, file: "b_conv2"))
W_fc1 = Tensor(shape: [7 * 7 * 64, 1024], elements: loadFloatArray(path, file: "W_fc1"))
b_fc1 = Tensor(shape: [1024], elements: loadFloatArray(path, file: "b_fc1"))
W_fc2 = Tensor(shape: [1024, 10], elements: loadFloatArray(path, file: "W_fc2"))
b_fc2 = Tensor(shape: [10], elements: loadFloatArray(path, file: "b_fc2"))
}
}

private func loadFloatArray(directory directory: String, file: String) -> [Float] {
let data = NSData(contentsOfFile: directory.stringByAppendingPathComponent(file))!
return Array(UnsafeBufferPointer(start: UnsafeMutablePointer<Float>(data.bytes), count: data.length / 4))
private func loadFloatArray(_ directory: String, file: String) -> [Float] {
let data = try! Data(contentsOf: URL(fileURLWithPath: directory.stringByAppendingPathComponent(file)))
return Array(UnsafeBufferPointer(start: UnsafeMutablePointer<Float>(mutating: (data as NSData).bytes.bindMemory(to: Float.self, capacity: data.count)), count: data.count / 4))
}
22 changes: 10 additions & 12 deletions MNIST/Info.plist
Original file line number Diff line number Diff line change
Expand Up @@ -2,16 +2,6 @@
<!DOCTYPE plist PUBLIC "-//Apple//DTD PLIST 1.0//EN" "http://www.apple.com/DTDs/PropertyList-1.0.dtd">
<plist version="1.0">
<dict>
<key>NSAppTransportSecurity</key>
<dict>
<key>NSExceptionDomains</key>
<dict>
<key>yann.lecun.com</key>
<string></string>
</dict>
<key>NSAllowsArbitraryLoads</key>
<true/>
</dict>
<key>CFBundleDevelopmentRegion</key>
<string>en</string>
<key>CFBundleExecutable</key>
Expand All @@ -26,8 +16,6 @@
<string>APPL</string>
<key>CFBundleShortVersionString</key>
<string>1.0</string>
<key>CFBundleSignature</key>
<string>????</string>
<key>CFBundleVersion</key>
<string>1</string>
<key>LSRequiresIPhoneOS</key>
Expand All @@ -53,5 +41,15 @@
<string>UIInterfaceOrientationLandscapeLeft</string>
<string>UIInterfaceOrientationLandscapeRight</string>
</array>
<key>NSAppTransportSecurity</key>
<dict>
<key>NSExceptionDomains</key>
<dict>
<key>yann.lecun.com</key>
<string></string>
</dict>
<key>NSAllowsArbitraryLoads</key>
<true/>
</dict>
</dict>
</plist>
6 changes: 3 additions & 3 deletions MNIST/String.swift
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import Foundation

extension String {
public func stringByAppendingPathComponent(str: String) -> String {
return (self as NSString).stringByAppendingPathComponent(str)
public func stringByAppendingPathComponent(_ str: String) -> String {
return (self as NSString).appendingPathComponent(str)
}
}
}
20 changes: 10 additions & 10 deletions MNIST/ViewController.swift
Original file line number Diff line number Diff line change
Expand Up @@ -5,31 +5,31 @@ class ViewController: UIViewController {
@IBOutlet private var canvasView: CanvasView!

private let inputSize = 28
private let classifier = Classifier(path: NSBundle.mainBundle().resourcePath!)
private let classifier = Classifier(path: Bundle.main.resourcePath!)

@IBAction func onPressClassifyButton(sender: UIButton) {
@IBAction func onPressClassifyButton(_ sender: UIButton) {
let input: Tensor
do {
let image = canvasView.image

let cgImage = image.CGImage!
let cgImage = image.cgImage!

var pixels = [UInt8](count: inputSize * inputSize, repeatedValue: 0)
var pixels = [UInt8](repeating: 0, count: inputSize * inputSize)

let context = CGBitmapContextCreate(&pixels, inputSize, inputSize, 8, inputSize, CGColorSpaceCreateDeviceGray()!, CGBitmapInfo.ByteOrderDefault.rawValue)!
CGContextClearRect(context, CGRect(x: 0.0, y: 0.0, width: CGFloat(inputSize), height: CGFloat(inputSize)))
let context = CGContext(data: &pixels, width: inputSize, height: inputSize, bitsPerComponent: 8, bytesPerRow: inputSize, space: CGColorSpaceCreateDeviceGray(), bitmapInfo: 0)!
context.clear(CGRect(x: 0.0, y: 0.0, width: CGFloat(inputSize), height: CGFloat(inputSize)))

let rect = CGRect(x: 0.0, y: 0.0, width: CGFloat(inputSize), height: CGFloat(inputSize))
CGContextDrawImage(context, rect, cgImage)
context.draw(cgImage, in: rect)

input = Tensor(shape: [Dimension(inputSize), Dimension(inputSize), 1], elements: pixels.map { -(Float($0) / 255.0 - 0.5) + 0.5 })
}

let estimatedLabel = classifier.classify(input)

let alertController = UIAlertController(title: "\(estimatedLabel)", message: nil, preferredStyle: .Alert)
alertController.addAction(UIAlertAction(title: "Dismiss", style: .Default) { _ in self.canvasView.clear() })
presentViewController(alertController, animated: true, completion: nil)
let alertController = UIAlertController(title: "\(estimatedLabel)", message: nil, preferredStyle: .alert)
alertController.addAction(UIAlertAction(title: "Dismiss", style: .default) { _ in self.canvasView.clear() })
present(alertController, animated: true, completion: nil)
}
}

4 changes: 2 additions & 2 deletions MNISTTests/Array.swift
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
extension Array {
func grouped(count: Int) -> [[Element]] {
func grouped(_ count: Int) -> [[Element]] {
var result: [[Element]] = []
var group: [Element] = []
for element in self {
Expand All @@ -11,4 +11,4 @@ extension Array {
}
return result
}
}
}
17 changes: 0 additions & 17 deletions MNISTTests/ClassifierTest.swift

This file was deleted.

32 changes: 32 additions & 0 deletions MNISTTests/ClassifierTests.swift
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
import XCTest
import TensorSwift
@testable import MNIST

class ClassifierTests: XCTestCase {
func testClassify() {

let classifier = Classifier(path: Bundle(for: ViewController.self).resourcePath!)
let (images, labels) = downloadTestData()

let count = 1000

let xArray: [[Float]] = images.withUnsafeBytes { ptr in
[UInt8](UnsafeBufferPointer(start: UnsafePointer<UInt8>(ptr + 16), count: 28 * 28 * count))
.map { Float($0) / 255.0 }
.grouped(28 * 28)
}

let yArray: [Int] = labels.withUnsafeBytes { ptr in
[UInt8](UnsafeBufferPointer(start: UnsafePointer<UInt8>(ptr + 8), count: count))
.map { Int($0) }
}

let accuracy = Float(zip(xArray, yArray)
.reduce(0) { $0 + (classifier.classify(Tensor(shape: [28, 28, 1], elements: $1.0)) == $1.1 ? 1 : 0) })
/ Float(yArray.count)

print("accuracy: \(accuracy)")

XCTAssertGreaterThan(accuracy, 0.97)
}
}
Loading

0 comments on commit df8ea90

Please sign in to comment.