diff --git a/SwiftDraft/ReliantFramework/ReliantFramework/ReliantContext.swift b/SwiftDraft/ReliantFramework/ReliantFramework/ReliantContext.swift index cb55481..3f5ac00 100644 --- a/SwiftDraft/ReliantFramework/ReliantFramework/ReliantContext.swift +++ b/SwiftDraft/ReliantFramework/ReliantFramework/ReliantContext.swift @@ -14,7 +14,6 @@ public protocol ReliantContext { static func createContext() -> ContextType } - public protocol ReliantContextHolder { typealias ContextType : ReliantContext var context:ContextType { get } diff --git a/SwiftDraft/ReliantFramework/ReliantFramework/RelyOn.swift b/SwiftDraft/ReliantFramework/ReliantFramework/RelyOn.swift index 5659f57..85d65b4 100644 --- a/SwiftDraft/ReliantFramework/ReliantFramework/RelyOn.swift +++ b/SwiftDraft/ReliantFramework/ReliantFramework/RelyOn.swift @@ -8,17 +8,31 @@ import Foundation -class ContextCache { - static let sharedInstance:ContextCache = ContextCache() - var cache:Dictionary = Dictionary() +struct ContextCache { + // If we're only going to use a static cache, we might as well use a struct? + static var standard = [String:Any]() + + static var substitutions = [String:Any.Type]() } public func relyOn(type:T.Type) -> T.ContextType { - if let result = ContextCache.sharedInstance.cache[String(type)] { + let typeKey = String(type) + if let result = ContextCache.standard[typeKey] { return result as! T.ContextType } else { - let result = type.createContext() - ContextCache.sharedInstance.cache[String(type)] = result + var result:T.ContextType + + if let substitutionType = ContextCache.substitutions[typeKey] as? T.Type { + result = substitutionType.createContext() + } else { + result = type.createContext() + } + + ContextCache.standard[String(type)] = result return result; } +} + +public func relyOnSubstitute(type:T.Type)(_ otherType:T.Type) { + ContextCache.substitutions[String(type)] = otherType } \ No newline at end of file diff --git a/SwiftDraft/ReliantFramework/ReliantFrameworkTests/ReliantFrameworkTests.swift b/SwiftDraft/ReliantFramework/ReliantFrameworkTests/ReliantFrameworkTests.swift index 8fb42b6..5509d80 100644 --- a/SwiftDraft/ReliantFramework/ReliantFrameworkTests/ReliantFrameworkTests.swift +++ b/SwiftDraft/ReliantFramework/ReliantFrameworkTests/ReliantFrameworkTests.swift @@ -54,17 +54,25 @@ struct ContextNeedingContext : ReliantContext { return ContextNeedingContext() } } +class SubWaver : Waver { + func wave(reason: String) -> String { + return "Substitute waving" + } +} - - - +class SubstituteContext : SimpleReferenceContext { + override init() { + super.init() + waver = SubWaver() + } +} class ReliantFrameworkTests: XCTestCase { override func setUp() { super.setUp() ReliantFrameworkTestsHelper.sharedInsance.reset() - ContextCache.sharedInstance.cache = Dictionary() + ContextCache.standard.removeAll() } func testRelyOnReturnsSameInsanceEveryTimeForReferenceTypes() { @@ -84,6 +92,17 @@ class ReliantFrameworkTests: XCTestCase { XCTAssertEqual(needed.needy.decorateGreeting(), "Oh! Hello Needy") } - + func testSubstitutions() { + relyOnSubstitute(SimpleReferenceContext)(SubstituteContext) + XCTAssertTrue(ContextCache.substitutions.contains({ (key, value) -> Bool in + return key == String(SimpleReferenceContext) && value == SubstituteContext.self + })) + let context = relyOn(SimpleReferenceContext) + + // Failing test. + // The createContext() function is static and thus final. Actual context type seems + // to be correct, but createContext() is called on original context class + XCTAssertTrue(context is SubstituteContext) + } } diff --git a/SwiftDraft/ReliantFramework/ReliantFrameworkTests/SimpleReferenceContext.swift b/SwiftDraft/ReliantFramework/ReliantFrameworkTests/SimpleReferenceContext.swift index 7bc1bc3..54ce590 100644 --- a/SwiftDraft/ReliantFramework/ReliantFrameworkTests/SimpleReferenceContext.swift +++ b/SwiftDraft/ReliantFramework/ReliantFrameworkTests/SimpleReferenceContext.swift @@ -12,8 +12,8 @@ import Reliant class SimpleReferenceContext : ReliantContext { private let bothWorlds = BothWorlds(prefix:"Hello") - let waver:Waver - let greeter:Greeter + var waver:Waver + var greeter:Greeter init() { ReliantFrameworkTestsHelper.sharedInsance.markInitCalled()