diff --git a/swiftwinrt/Resources/Support/Aggregation.swift b/swiftwinrt/Resources/Support/Aggregation.swift index cf0f839f..d5086287 100644 --- a/swiftwinrt/Resources/Support/Aggregation.swift +++ b/swiftwinrt/Resources/Support/Aggregation.swift @@ -11,7 +11,7 @@ import Foundation // to the swift object, to ensure it doesn't get cleaned up. The Swift object in turn holds a strong // reference to this object so that it stays alive. @_spi(WinRTInternal) -public final class WinRTClassWeakReference { +public final class WinRTClassWeakReference { fileprivate weak var instance: Class? public init(_ instance: Class){ self.instance = instance @@ -41,7 +41,7 @@ extension WinRTClassWeakReference: CustomAddRef { } extension WinRTClassWeakReference: AnyObjectWrapper { - var obj: AnyObject? { instance } + var obj: WinRTObject? { instance } } @_spi(WinRTInternal) diff --git a/swiftwinrt/Resources/Support/IInspectable.swift b/swiftwinrt/Resources/Support/IInspectable.swift index 04fb211b..16ae7bbe 100644 --- a/swiftwinrt/Resources/Support/IInspectable.swift +++ b/swiftwinrt/Resources/Support/IInspectable.swift @@ -40,7 +40,7 @@ open class IInspectable: IUnknown { // } // internal typealias Composable = IBaseNoOverrides protocol AnyObjectWrapper { - var obj: AnyObject? { get } + var obj: WinRTObject? { get } } public enum __ABI_ { @@ -53,7 +53,13 @@ public enum __ABI_ { } else { let vtblPtr = withUnsafeMutablePointer(to: &IInspectableVTable) { $0 } let cAbi = C_IInspectable(lpVtbl: vtblPtr) - super.init(cAbi, swift as AnyObject) + // For WinRTObject types, hold a weak reference to the swift obj since + // it will hold a strong reference to ourselves + if let winrtObj = swift as? WinRTObject { + super.init(cAbi, WinRTClassWeakReference(winrtObj)) + } else { + super.init(cAbi, swift as AnyObject) + } } } @@ -62,6 +68,16 @@ public enum __ABI_ { if let swiftAbi = swiftObj as? IInspectable { let abi: UnsafeMutablePointer = RawPointer(swiftAbi) return try body(abi) + } else if let winrtObjWrapper = swiftObj as? AnyObjectWrapper, + let winrtObj = winrtObjWrapper.obj { + if let identity = winrtObj.identity { + return try body(identity.get()) + } + return try super.toABI{ + winrtObj.identity = .init($0) + (winrtObjWrapper as? CustomAddRef)?.release() + return try body($0) + } } else { return try super.toABI(body) } @@ -116,8 +132,9 @@ public enum __ABI_ { GetRuntimeClassName: { guard let instance = AnyWrapper.tryUnwrapFrom(raw: $0) else { return E_INVALIDARG } - guard let winrtClass = instance as? WinRTClass else { - let string = String(reflecting: type(of: instance)) + let winrtObj = (instance as? AnyObjectWrapper)?.obj ?? instance + guard let winrtClass = winrtObj as? WinRTClass else { + let string = String(reflecting: type(of: winrtObj)) $1!.pointee = try! HString(string).detach() return S_OK } diff --git a/swiftwinrt/Resources/Support/WinRTProtocols.swift b/swiftwinrt/Resources/Support/WinRTProtocols.swift index ac6e0a51..d7acb0e7 100644 --- a/swiftwinrt/Resources/Support/WinRTProtocols.swift +++ b/swiftwinrt/Resources/Support/WinRTProtocols.swift @@ -13,8 +13,27 @@ public protocol IWinRTObject: AnyObject { public protocol WinRTInterface: AnyObject, CustomQueryInterface { } -open class WinRTClass : CustomQueryInterface, Equatable { - public init() {} +open class WinRTObject: Equatable { + public init() {} + + public static func == (lhs: WinRTObject, rhs: WinRTObject) -> Bool { + return lhs.identity?.get() == rhs.identity?.get() + } + + internal var identity: ComPtr? + + @_spi(WinRTImplements) + open func queryInterface(_ iid: SUPPORT_MODULE.IID) -> IUnknownRef? { + // the vtables for IInspectable will properly respond to all default QueryInterface + // calls for IInspectable and IUnknown. The queryInterface method exists for subclasses + // to provide custom handling of QueryInterface calls for other interfaces that the code + // gen is unaware of + return nil + } +} + +open class WinRTClass : WinRTObject, CustomQueryInterface { + override public init() { super.init() } @_spi(WinRTInternal) public init(_ ptr: SUPPORT_MODULE.IInspectable) { @@ -35,10 +54,8 @@ open class WinRTClass : CustomQueryInterface, Equatable { @_spi(WinRTInternal) public internal(set) var _inner: SUPPORT_MODULE.IInspectable! - var identity: ComPtr? - @_spi(WinRTImplements) - open func queryInterface(_ iid: SUPPORT_MODULE.IID) -> IUnknownRef? { + override open func queryInterface(_ iid: SUPPORT_MODULE.IID) -> IUnknownRef? { SUPPORT_MODULE.queryInterface(self, iid) } diff --git a/tests/test_app/IdentityTests.swift b/tests/test_app/IdentityTests.swift new file mode 100644 index 00000000..e028a7de --- /dev/null +++ b/tests/test_app/IdentityTests.swift @@ -0,0 +1,78 @@ +import test_component +import XCTest + +class MyObj: WinRTObject { + override public init() { + super.init() + } +} + +class ABIHelper { + static func toABI(_ obj1: __ABI_.AnyWrapper, _ obj2: __ABI_.AnyWrapper, _ body: (UnsafeMutablePointer, UnsafeMutablePointer) -> Void) throws { + try obj1.toABI { abi1 in + try obj2.toABI { abi2 in + body(abi1, abi2) + } + } + } +} + +class IdentityTests: XCTestCase { + public func testIdentity() throws { + let obj = MyObj() + let wrapper = try XCTUnwrap(__ABI_.AnyWrapper(obj)) + let wrapper2 = try XCTUnwrap(__ABI_.AnyWrapper(obj)) + + // Since `identity` is internal and we don't have access to internals we just check that pointers from different wrappers + // are the same + try ABIHelper.toABI(wrapper, wrapper2) { (abi1, abi2) in + XCTAssertEqual(abi1, abi2) + } + } + + public func testIdentityCopyTo() throws { + let obj = MyObj() + let wrapper = try XCTUnwrap(__ABI_.AnyWrapper(obj)) + let wrapper2 = try XCTUnwrap(__ABI_.AnyWrapper(obj)) + var copy: UnsafeMutablePointer? + var copy2: UnsafeMutablePointer? + defer { + _ = copy?.pointee.lpVtbl.pointee.Release(copy) + _ = copy2?.pointee.lpVtbl.pointee.Release(copy2) + } + + wrapper.copyTo(©) + wrapper2.copyTo(©2) + XCTAssertEqual(copy, copy2) + } + + // This test verifies that even after releasing the last reference + // from WinRT, that the original identity still persists and is valid. + // This is common in collection accessors, where the object is retrieved + // and then immediately released. + public func testIdentityAfterRelease() throws { + let obj = MyObj() + var copy2: UnsafeMutablePointer? + + var wrapper:__ABI_.AnyWrapper? = try XCTUnwrap(__ABI_.AnyWrapper(obj)) + var copy: UnsafeMutablePointer? + + wrapper?.copyTo(©) + _ = copy?.pointee.lpVtbl.pointee.Release(copy) + wrapper = nil + + + let wrapper2 = try XCTUnwrap(__ABI_.AnyWrapper(obj)) + wrapper2.copyTo(©2) + XCTAssertEqual(copy, copy2) + _ = copy2?.pointee.lpVtbl.pointee.Release(copy2) + } +} + +var identityTests: [XCTestCaseEntry] = [ + testCase([ + ("testIdentity", IdentityTests.testIdentity), + ("testIdentityCopyTo", IdentityTests.testIdentityCopyTo), + ("testIdentityAfterRelease", IdentityTests.testIdentityAfterRelease) + ]) +] \ No newline at end of file diff --git a/tests/test_app/MemoryManagementTests.swift b/tests/test_app/MemoryManagementTests.swift index cefb306a..aeec6d01 100644 --- a/tests/test_app/MemoryManagementTests.swift +++ b/tests/test_app/MemoryManagementTests.swift @@ -51,6 +51,24 @@ class MemoryManagementTests : XCTestCase { }() XCTAssertNil(weakDerived) } + + func testWinRTObject() throws { + weak var weakDerived: MyObj? = nil + try { + let obj = MyObj() + weakDerived = obj + + // wrapping MyObj will create the identity pointer. make sure + // this doesn't cause a leak + let wrapper = try XCTUnwrap(__ABI_.AnyWrapper(obj)) + var copy: UnsafeMutablePointer? + defer { + _ = copy?.pointee.lpVtbl.pointee.Release(copy) + } + wrapper.copyTo(©) + }() + XCTAssertNil(weakDerived) + } } var memoryManagementTests: [XCTestCaseEntry] = [ @@ -59,5 +77,6 @@ var memoryManagementTests: [XCTestCaseEntry] = [ ("testNonAggregatedObject", MemoryManagementTests.testNonAggregatedObject), ("testReturningAggregatedObject", MemoryManagementTests.testReturningAggregatedObject), ("testReturningNonAggregatedObject", MemoryManagementTests.testReturningNonAggregatedObject), + ("testWinRTObject", MemoryManagementTests.testWinRTObject) ]) ] diff --git a/tests/test_app/main.swift b/tests/test_app/main.swift index d68c5128..72cd4397 100644 --- a/tests/test_app/main.swift +++ b/tests/test_app/main.swift @@ -460,7 +460,7 @@ var tests: [XCTestCaseEntry] = [ ("testUnicode", SwiftWinRTTests.testUnicode), ("testErrorInfo", SwiftWinRTTests.testErrorInfo), ]) -] + valueBoxingTests + eventTests + collectionTests + aggregationTests + asyncTests + memoryManagementTests + bufferTests +] + valueBoxingTests + eventTests + collectionTests + aggregationTests + asyncTests + memoryManagementTests + bufferTests + identityTests RoInitialize(RO_INIT_MULTITHREADED) XCTMain(tests)