From 89251c4643283f66a494c55033128a3a71ec5209 Mon Sep 17 00:00:00 2001 From: Liu Liu Date: Mon, 19 Aug 2024 16:33:40 -0400 Subject: [PATCH] Allow update codec while continue to load the model. --- nnc/ModelBuilder.swift | 18 ++++++++++++++---- nnc/Store.swift | 11 ++++++++--- 2 files changed, 22 insertions(+), 7 deletions(-) diff --git a/nnc/ModelBuilder.swift b/nnc/ModelBuilder.swift index 53198e501ea..01dd0045a2c 100644 --- a/nnc/ModelBuilder.swift +++ b/nnc/ModelBuilder.swift @@ -125,10 +125,15 @@ public class AnyModelBuilder { CCV_NNC_DATA_TRANSFER_FORWARD, nil, CmdParamsFactory.factory.newParams(), 0), ccv_nnc_no_hint, 0, &input, 1, tensorOut, 1, nil) return Int32(CCV_IO_FINAL) - case .continue(let name): + case let .continue(name, codec): var params = params + guard let codec = codec, var options = options?.pointee else { + return ccv_nnc_tensor_read( + readerHelper.sqlite, name, options, 0, ¶ms, tensorOut) + } + options.decode = codec.decode return ccv_nnc_tensor_read( - readerHelper.sqlite, name, options, 0, ¶ms, tensorOut) + readerHelper.sqlite, name, &options, 0, ¶ms, tensorOut) case .fail: return Int32(CCV_IO_ERROR) } @@ -184,10 +189,15 @@ public class AnyModelBuilder { CCV_NNC_DATA_TRANSFER_FORWARD, nil, CmdParamsFactory.factory.newParams(), 0), ccv_nnc_no_hint, 0, &input, 1, tensorOut, 1, nil) return Int32(CCV_IO_FINAL) - case .continue(let name): + case let .continue(name, codec): var params = params + guard let codec = codec, var options = options?.pointee else { + return ccv_nnc_tensor_read( + readerHelper.sqlite, name, options, 0, ¶ms, tensorOut) + } + options.decode = codec.decode return ccv_nnc_tensor_read( - readerHelper.sqlite, name, options, 0, ¶ms, tensorOut) + readerHelper.sqlite, name, &options, 0, ¶ms, tensorOut) case .fail: return Int32(CCV_IO_ERROR) } diff --git a/nnc/Store.swift b/nnc/Store.swift index bde32723f57..a1917b4c5d0 100644 --- a/nnc/Store.swift +++ b/nnc/Store.swift @@ -4468,7 +4468,7 @@ extension DynamicGraph { } public enum ModelReaderResult { /// Continue to load parameter with the given name. - case `continue`(String) + case `continue`(String, codec: Codec? = nil) /// The parameter is loaded, no futher operation need. case final(NNC.AnyTensor) /// Nothing is loaded. @@ -4539,10 +4539,15 @@ extension DynamicGraph { CCV_NNC_DATA_TRANSFER_FORWARD, nil, CmdParamsFactory.factory.newParams(), 0), ccv_nnc_no_hint, 0, &input, 1, tensorOut, 1, nil) return Int32(CCV_IO_FINAL) - case .continue(let name): + case let .continue(name, codec): var params = params + guard let codec = codec, var options = options?.pointee else { + return ccv_nnc_tensor_read( + readerHelper.sqlite, name, options, 0, ¶ms, tensorOut) + } + options.decode = codec.decode return ccv_nnc_tensor_read( - readerHelper.sqlite, name, options, 0, ¶ms, tensorOut) + readerHelper.sqlite, name, &options, 0, ¶ms, tensorOut) case .fail: return Int32(CCV_IO_ERROR) }