Skip to content

Commit

Permalink
Allow update codec while continue to load the model.
Browse files Browse the repository at this point in the history
  • Loading branch information
liuliu committed Aug 19, 2024
1 parent 7d06fad commit 89251c4
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 7 deletions.
18 changes: 14 additions & 4 deletions nnc/ModelBuilder.swift
Original file line number Diff line number Diff line change
Expand Up @@ -125,10 +125,15 @@ public class AnyModelBuilder {
CCV_NNC_DATA_TRANSFER_FORWARD, nil, CmdParamsFactory.factory.newParams(), 0),
ccv_nnc_no_hint, 0, &input, 1, tensorOut, 1, nil)
return Int32(CCV_IO_FINAL)
case .continue(let name):
case let .continue(name, codec):
var params = params
guard let codec = codec, var options = options?.pointee else {
return ccv_nnc_tensor_read(
readerHelper.sqlite, name, options, 0, &params, tensorOut)
}
options.decode = codec.decode
return ccv_nnc_tensor_read(
readerHelper.sqlite, name, options, 0, &params, tensorOut)
readerHelper.sqlite, name, &options, 0, &params, tensorOut)
case .fail:
return Int32(CCV_IO_ERROR)
}
Expand Down Expand Up @@ -184,10 +189,15 @@ public class AnyModelBuilder {
CCV_NNC_DATA_TRANSFER_FORWARD, nil, CmdParamsFactory.factory.newParams(), 0),
ccv_nnc_no_hint, 0, &input, 1, tensorOut, 1, nil)
return Int32(CCV_IO_FINAL)
case .continue(let name):
case let .continue(name, codec):
var params = params
guard let codec = codec, var options = options?.pointee else {
return ccv_nnc_tensor_read(
readerHelper.sqlite, name, options, 0, &params, tensorOut)
}
options.decode = codec.decode
return ccv_nnc_tensor_read(
readerHelper.sqlite, name, options, 0, &params, tensorOut)
readerHelper.sqlite, name, &options, 0, &params, tensorOut)
case .fail:
return Int32(CCV_IO_ERROR)
}
Expand Down
11 changes: 8 additions & 3 deletions nnc/Store.swift
Original file line number Diff line number Diff line change
Expand Up @@ -4468,7 +4468,7 @@ extension DynamicGraph {
}
public enum ModelReaderResult {
/// Continue to load parameter with the given name.
case `continue`(String)
case `continue`(String, codec: Codec? = nil)
/// The parameter is loaded, no futher operation need.
case final(NNC.AnyTensor)
/// Nothing is loaded.
Expand Down Expand Up @@ -4539,10 +4539,15 @@ extension DynamicGraph {
CCV_NNC_DATA_TRANSFER_FORWARD, nil, CmdParamsFactory.factory.newParams(), 0),
ccv_nnc_no_hint, 0, &input, 1, tensorOut, 1, nil)
return Int32(CCV_IO_FINAL)
case .continue(let name):
case let .continue(name, codec):
var params = params
guard let codec = codec, var options = options?.pointee else {
return ccv_nnc_tensor_read(
readerHelper.sqlite, name, options, 0, &params, tensorOut)
}
options.decode = codec.decode
return ccv_nnc_tensor_read(
readerHelper.sqlite, name, options, 0, &params, tensorOut)
readerHelper.sqlite, name, &options, 0, &params, tensorOut)
case .fail:
return Int32(CCV_IO_ERROR)
}
Expand Down

0 comments on commit 89251c4

Please sign in to comment.