Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add support for identity #141

Closed
wants to merge 2 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions swiftwinrt/Resources/Support/Aggregation.swift
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ import Foundation
// to the swift object, to ensure it doesn't get cleaned up. The Swift object in turn holds a strong
// reference to this object so that it stays alive.
@_spi(WinRTInternal)
public final class WinRTClassWeakReference<Class: WinRTClass> {
public final class WinRTClassWeakReference<Class: WinRTObject> {
fileprivate weak var instance: Class?
public init(_ instance: Class){
self.instance = instance
Expand Down Expand Up @@ -41,7 +41,7 @@ extension WinRTClassWeakReference: CustomAddRef {
}

extension WinRTClassWeakReference: AnyObjectWrapper {
var obj: AnyObject? { instance }
var obj: WinRTObject? { instance }
}

@_spi(WinRTInternal)
Expand Down
25 changes: 21 additions & 4 deletions swiftwinrt/Resources/Support/IInspectable.swift
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ open class IInspectable: IUnknown {
// }
// internal typealias Composable = IBaseNoOverrides
protocol AnyObjectWrapper {
var obj: AnyObject? { get }
var obj: WinRTObject? { get }
}

public enum __ABI_ {
Expand All @@ -53,7 +53,13 @@ public enum __ABI_ {
} else {
let vtblPtr = withUnsafeMutablePointer(to: &IInspectableVTable) { $0 }
let cAbi = C_IInspectable(lpVtbl: vtblPtr)
super.init(cAbi, swift as AnyObject)
// For WinRTObject types, hold a weak reference to the swift obj since
// it will hold a strong reference to ourselves
if let winrtObj = swift as? WinRTObject {
super.init(cAbi, WinRTClassWeakReference(winrtObj))
} else {
super.init(cAbi, swift as AnyObject)
}
}
}

Expand All @@ -62,6 +68,16 @@ public enum __ABI_ {
if let swiftAbi = swiftObj as? IInspectable {
let abi: UnsafeMutablePointer<C_IInspectable> = RawPointer(swiftAbi)
return try body(abi)
} else if let winrtObjWrapper = swiftObj as? AnyObjectWrapper,
let winrtObj = winrtObjWrapper.obj {
if let identity = winrtObj.identity {
return try body(identity.get())
}
return try super.toABI{
winrtObj.identity = .init($0)
(winrtObjWrapper as? CustomAddRef)?.release()
return try body($0)
}
} else {
return try super.toABI(body)
}
Expand Down Expand Up @@ -116,8 +132,9 @@ public enum __ABI_ {

GetRuntimeClassName: {
guard let instance = AnyWrapper.tryUnwrapFrom(raw: $0) else { return E_INVALIDARG }
guard let winrtClass = instance as? WinRTClass else {
let string = String(reflecting: type(of: instance))
let winrtObj = (instance as? AnyObjectWrapper)?.obj ?? instance
guard let winrtClass = winrtObj as? WinRTClass else {
let string = String(reflecting: type(of: winrtObj))
$1!.pointee = try! HString(string).detach()
return S_OK
}
Expand Down
27 changes: 22 additions & 5 deletions swiftwinrt/Resources/Support/WinRTProtocols.swift
Original file line number Diff line number Diff line change
Expand Up @@ -13,8 +13,27 @@ public protocol IWinRTObject: AnyObject {
public protocol WinRTInterface: AnyObject, CustomQueryInterface {
}

open class WinRTClass : CustomQueryInterface, Equatable {
public init() {}
open class WinRTObject: Equatable {
public init() {}

public static func == (lhs: WinRTObject, rhs: WinRTObject) -> Bool {
return lhs.identity?.get() == rhs.identity?.get()
}

internal var identity: ComPtr<C_IInspectable>?

@_spi(WinRTImplements)
open func queryInterface(_ iid: SUPPORT_MODULE.IID) -> IUnknownRef? {
// the vtables for IInspectable will properly respond to all default QueryInterface
// calls for IInspectable and IUnknown. The queryInterface method exists for subclasses
// to provide custom handling of QueryInterface calls for other interfaces that the code
// gen is unaware of
return nil
}
}

open class WinRTClass : WinRTObject, CustomQueryInterface {
override public init() { super.init() }

@_spi(WinRTInternal)
public init(_ ptr: SUPPORT_MODULE.IInspectable) {
Expand All @@ -35,10 +54,8 @@ open class WinRTClass : CustomQueryInterface, Equatable {
@_spi(WinRTInternal)
public internal(set) var _inner: SUPPORT_MODULE.IInspectable!

var identity: ComPtr<C_IInspectable>?

@_spi(WinRTImplements)
open func queryInterface(_ iid: SUPPORT_MODULE.IID) -> IUnknownRef? {
override open func queryInterface(_ iid: SUPPORT_MODULE.IID) -> IUnknownRef? {
SUPPORT_MODULE.queryInterface(self, iid)
}

Expand Down
78 changes: 78 additions & 0 deletions tests/test_app/IdentityTests.swift
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
import test_component
import XCTest

class MyObj: WinRTObject {
override public init() {
super.init()
}
}

class ABIHelper {
static func toABI(_ obj1: __ABI_.AnyWrapper, _ obj2: __ABI_.AnyWrapper, _ body: (UnsafeMutablePointer<C_IInspectable>, UnsafeMutablePointer<C_IInspectable>) -> Void) throws {
try obj1.toABI { abi1 in
try obj2.toABI { abi2 in
body(abi1, abi2)
}
}
}
}

class IdentityTests: XCTestCase {
public func testIdentity() throws {
let obj = MyObj()
let wrapper = try XCTUnwrap(__ABI_.AnyWrapper(obj))
let wrapper2 = try XCTUnwrap(__ABI_.AnyWrapper(obj))

// Since `identity` is internal and we don't have access to internals we just check that pointers from different wrappers
// are the same
try ABIHelper.toABI(wrapper, wrapper2) { (abi1, abi2) in
XCTAssertEqual(abi1, abi2)
}
}

public func testIdentityCopyTo() throws {
let obj = MyObj()
let wrapper = try XCTUnwrap(__ABI_.AnyWrapper(obj))
let wrapper2 = try XCTUnwrap(__ABI_.AnyWrapper(obj))
var copy: UnsafeMutablePointer<C_IInspectable>?
var copy2: UnsafeMutablePointer<C_IInspectable>?
defer {
_ = copy?.pointee.lpVtbl.pointee.Release(copy)
_ = copy2?.pointee.lpVtbl.pointee.Release(copy2)
}

wrapper.copyTo(&copy)
wrapper2.copyTo(&copy2)
XCTAssertEqual(copy, copy2)
}

// This test verifies that even after releasing the last reference
// from WinRT, that the original identity still persists and is valid.
// This is common in collection accessors, where the object is retrieved
// and then immediately released.
public func testIdentityAfterRelease() throws {
let obj = MyObj()
var copy2: UnsafeMutablePointer<C_IInspectable>?

var wrapper:__ABI_.AnyWrapper? = try XCTUnwrap(__ABI_.AnyWrapper(obj))
var copy: UnsafeMutablePointer<C_IInspectable>?

wrapper?.copyTo(&copy)
_ = copy?.pointee.lpVtbl.pointee.Release(copy)
wrapper = nil


let wrapper2 = try XCTUnwrap(__ABI_.AnyWrapper(obj))
wrapper2.copyTo(&copy2)
XCTAssertEqual(copy, copy2)
_ = copy2?.pointee.lpVtbl.pointee.Release(copy2)
}
}

var identityTests: [XCTestCaseEntry] = [
testCase([
("testIdentity", IdentityTests.testIdentity),
("testIdentityCopyTo", IdentityTests.testIdentityCopyTo),
("testIdentityAfterRelease", IdentityTests.testIdentityAfterRelease)
])
]
19 changes: 19 additions & 0 deletions tests/test_app/MemoryManagementTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,24 @@ class MemoryManagementTests : XCTestCase {
}()
XCTAssertNil(weakDerived)
}

func testWinRTObject() throws {
weak var weakDerived: MyObj? = nil
try {
let obj = MyObj()
weakDerived = obj

// wrapping MyObj will create the identity pointer. make sure
// this doesn't cause a leak
let wrapper = try XCTUnwrap(__ABI_.AnyWrapper(obj))
var copy: UnsafeMutablePointer<C_IInspectable>?
defer {
_ = copy?.pointee.lpVtbl.pointee.Release(copy)
}
wrapper.copyTo(&copy)
}()
XCTAssertNil(weakDerived)
}
}

var memoryManagementTests: [XCTestCaseEntry] = [
Expand All @@ -59,5 +77,6 @@ var memoryManagementTests: [XCTestCaseEntry] = [
("testNonAggregatedObject", MemoryManagementTests.testNonAggregatedObject),
("testReturningAggregatedObject", MemoryManagementTests.testReturningAggregatedObject),
("testReturningNonAggregatedObject", MemoryManagementTests.testReturningNonAggregatedObject),
("testWinRTObject", MemoryManagementTests.testWinRTObject)
])
]
2 changes: 1 addition & 1 deletion tests/test_app/main.swift
Original file line number Diff line number Diff line change
Expand Up @@ -460,7 +460,7 @@ var tests: [XCTestCaseEntry] = [
("testUnicode", SwiftWinRTTests.testUnicode),
("testErrorInfo", SwiftWinRTTests.testErrorInfo),
])
] + valueBoxingTests + eventTests + collectionTests + aggregationTests + asyncTests + memoryManagementTests + bufferTests
] + valueBoxingTests + eventTests + collectionTests + aggregationTests + asyncTests + memoryManagementTests + bufferTests + identityTests

RoInitialize(RO_INIT_MULTITHREADED)
XCTMain(tests)
Loading